diff --git a/.gitignore b/.gitignore index a286365..8812707 100644 --- a/.gitignore +++ b/.gitignore @@ -69,6 +69,10 @@ playwright-report/ .parts/ *.partial +# Phase 2 W3a — runtime CA / cert dirs (controller + executor bootstrap) +.ca/ +.executor-certs/ + # Database (dev) *.db *.sqlite diff --git a/.gitleaks.toml b/.gitleaks.toml new file mode 100644 index 0000000..e9b6adc --- /dev/null +++ b/.gitleaks.toml @@ -0,0 +1,22 @@ +# gitleaks config — extends the default ruleset with project-specific allowlist. +# CI runs: gitleaks detect over the PR commit range. + +[extend] +useDefault = true + +[allowlist] +description = "modelpull false positives" + +# `key: ed25519.Ed25519PrivateKey` and similar dataclass field annotations in +# dlw.auth.* trip the default generic-api-key rule (a `key:` token followed by +# a high-entropy identifier). These are Python type annotations, not secrets — +# the actual CA / JWT keys are generated at runtime and persisted to +# chmod-600 files under ${DLW_CA_DIR}, never committed. +regexes = [ + '''ed25519\.Ed25519PrivateKey''', + '''ed25519\.Ed25519PublicKey''', +] + +# Test fixtures + auth modules legitimately contain the words "key", "token", +# "secret", "cert" in identifiers and docstrings; scope the regexes above +# rather than broad path excludes so real leaks are still caught. diff --git a/api/openapi.yaml b/api/openapi.yaml index bcf154f..eb0fe55 100644 --- a/api/openapi.yaml +++ b/api/openapi.yaml @@ -496,6 +496,21 @@ paths: name: X-Heartbeat-HMAC required: true schema: {type: string} + - in: header + name: X-HMAC-Timestamp + required: true + schema: {type: integer} + description: Unix epoch seconds; validated within ±5 min of server clock + - in: header + name: X-HMAC-Nonce + required: true + schema: {type: string} + description: 128-bit random hex nonce; replay window enforced by server + - in: header + name: X-HMAC-Signature + required: true + schema: {type: string} + description: HMAC-SHA256(hmac_seed_hex, body || X-HMAC-Timestamp || X-HMAC-Nonce) requestBody: required: true content: @@ -638,18 +653,39 @@ paths: schema: {type: string} post: tags: [executors] - summary: Renew JWT (proactive) + summary: Renew executor JWT and optionally rotate mTLS cert operationId: renewExecutorJwt + requestBody: + required: false + content: + application/json: + schema: + type: object + properties: + client_csr_pem: + type: string + nullable: true + description: X.509 CSR PEM for cert rotation; omit or null to renew JWT only responses: '200': - description: New JWT + description: Renewed credentials content: application/json: schema: type: object properties: executor_jwt: {type: string} - expires_at: {type: string, format: date-time} + jwt_renew_in_seconds: {type: integer} + client_cert_pem: + type: string + nullable: true + description: New PEM-encoded client cert; null if no CSR was provided + cert_renew_in_seconds: + type: integer + nullable: true + description: Seconds until next cert renewal; null if cert not rotated + '401': + description: Invalid or expired executor JWT / mTLS cert /executors/{executorId}/poll: parameters: @@ -1525,7 +1561,7 @@ components: # ===== Executor ===== ExecutorRegisterRequest: type: object - required: [host_id, executor_id_proposal, capabilities, client_csr] + required: [host_id, executor_id_proposal, capabilities, client_csr_pem] properties: host_id: {type: string} executor_id_proposal: {type: string} @@ -1565,22 +1601,26 @@ components: properties: alias: {type: string} source_id: {type: string} - client_csr: + client_csr_pem: type: string - description: X.509 CSR PEM + description: X.509 CSR in PEM format ExecutorRegisterResponse: type: object properties: executor_id: {type: string} epoch: {type: integer, format: int64} - client_cert: {type: string, description: PEM-encoded X.509} + client_cert_pem: {type: string, description: PEM-encoded X.509 client certificate} ca_chain: type: array items: {type: string} + description: CA certificate chain PEM strings executor_jwt: {type: string} + hmac_seed_hex: {type: string, description: 256-bit hex HMAC seed for heartbeat signatures} + cert_renew_in_seconds: {type: integer, description: Seconds until certificate renewal recommended} + jwt_renew_in_seconds: {type: integer, description: Seconds until JWT renewal recommended} jwt_signing_alg: {type: string, const: EdDSA} - next_renew_in_seconds: {type: integer} + next_renew_in_seconds: {type: integer, description: "Deprecated: use cert_renew_in_seconds"} HeartbeatRequest: type: object diff --git a/docs/operator/executor-runbook.md b/docs/operator/executor-runbook.md index c2bf258..9110db8 100644 --- a/docs/operator/executor-runbook.md +++ b/docs/operator/executor-runbook.md @@ -37,3 +37,52 @@ any paused subtasks that appeared after the original cancel. A future Phase 2 W3 release will add heartbeat-carried cancellation signals so executors abort in-flight downloads on chunk boundaries, reducing latency to sub-minute. + +## mTLS + Executor JWT + HMAC (Phase 2 W3a+) + +### Controller bootstrap + +On first launch the controller generates, under `${DLW_CA_DIR}` (default +`./.ca`, chmod 700): + +- `ca-cert.pem` / `ca-key.pem` — the self-signed CA (10-year validity). +- `server-cert.pem` / `server-key.pem` — the controller's TLS server cert + (SAN: localhost, $DLW_CONTROLLER_HOSTNAME, 127.0.0.1, ::1). +- `jwt-signing.pem` — Ed25519 JWT signing key. +- `enrollment.token` — 256-bit hex token (also logged once at INFO). + +Run uvicorn with TLS: + + uvicorn dlw.main:app --host 0.0.0.0 --port 8443 \ + --ssl-keyfile ${DLW_CA_DIR}/server-key.pem \ + --ssl-certfile ${DLW_CA_DIR}/server-cert.pem \ + --ssl-ca-certs ${DLW_CA_DIR}/ca-cert.pem \ + --ssl-cert-reqs 1 + +`--ssl-cert-reqs 1` (CERT_OPTIONAL) — the server requests a client cert but +does not reject connections that lack one at the TLS layer. `/register` +(enrollment-token auth, no client cert) and `/health/*` need this. The +application layer (`require_executor_mtls`) enforces the cert where required. + +### Enrolling an executor + +1. Copy the controller's enrollment token to the executor host out-of-band. +2. Set `DLW_EXECUTOR_ENROLLMENT_TOKEN` on the executor. +3. On first boot the executor generates a keypair, builds a CSR, calls + `/register`, and persists `client-cert.pem` / `client-key.pem` / + `ca-chain.pem` / `hmac-seed` under `${DLW_EXECUTOR_EXECUTOR_CERT_DIR}` + (default `./.executor-certs`). +4. Certs auto-renew (24h cert, 1h JWT) via the executor's renew loop. + +### `DLW_TLS_TRUSTED_PROXY` — security warning + +`DLW_TLS_TRUSTED_PROXY=1` makes the controller honor the +`X-Client-Cert-PEM` header instead of the direct TLS peer cert. Only +enable this when a real TLS-terminating reverse proxy sits in front AND +the uvicorn port is NOT directly reachable. With it on and no proxy, +anyone can forge the header. Default is `0` (direct uvicorn TLS only). + +### Host clock sync + +Heartbeats carry an HMAC timestamp validated within ±5 min. Run +`chrony` / `systemd-timesyncd` on all executor + controller hosts. diff --git a/docs/superpowers/plans/2026-05-14-phase-2-w3a-mtls-jwt-hmac.md b/docs/superpowers/plans/2026-05-14-phase-2-w3a-mtls-jwt-hmac.md new file mode 100644 index 0000000..0a284d2 --- /dev/null +++ b/docs/superpowers/plans/2026-05-14-phase-2-w3a-mtls-jwt-hmac.md @@ -0,0 +1,2598 @@ +# Phase 2 Week 3a — mTLS + Executor JWT + HMAC Heartbeat Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Replace executor-side bearer auth with SVID-style mTLS + Ed25519 JWT + HMAC heartbeat (roadmap §2.6 Day 1-3 — SEC-01 + SEC-04). Self-signed CA file-persisted under `${DLW_CA_DIR}`; `POST /register` (CSR signing) replaces W1 `/join`; `POST /{eid}/renew` for cert + JWT lifecycle; in-process nonce store for anti-replay. UI bearer auth retained. + +**Architecture:** Controller bootstraps a self-signed CA + Ed25519 JWT signing key + server cert at first startup (file-persisted, chmod 600). New `dlw.auth.{ca,jwt_signing,hmac_nonce}` modules + three chained FastAPI dependencies (`require_executor_mtls` → `require_executor_jwt` → `require_hmac_heartbeat`). `require_executor_epoch` is refactored to chain under the JWT dep and assert the path id matches the mTLS-authenticated identity (confused-deputy guard). Executor generates its own keypair + CSR, registers, persists cert/key/seed locally, and runs a third background loop to renew before expiry. uvicorn terminates TLS via `--ssl-*` flags. + +**Tech Stack:** `cryptography>=43,<44` (promoted from transitive to explicit), `pyjwt[crypto]>=2.9,<3.0` (NEW — EdDSA JWT). SQLAlchemy 2.x async + alembic. pytest with ephemeral-CA fixtures; one real-TLS e2e via a uvicorn subprocess. No new CI jobs. + +**Scope:** 11 tasks across 4 milestones. Branch `feat/phase-2-w3a-mtls-jwt-hmac` exists with the spec committed (`255a561`). Companion spec: `docs/superpowers/specs/2026-05-14-phase-2-w3a-mtls-jwt-hmac-design.md`. + +**Pre-flight:** Phase 2 W2b2 merged into `main` at `ba89a91`. Local PG 18 on `localhost:5433`. `uv` 0.11.9. Existing pytest baseline = 181 passed, 1 deselected. Alembic head `b1d5ea4944ba`. + +**Out-of-scope (deferred — see spec §1.2):** HF reverse-proxy (W3b); active/standby + chaos drill (W3c); OIDC / multi-tenant / UI auth (Phase 3); Vault/KMS for keys (Phase 3); CRL / cert-manager (Phase 3+); envelope encryption of `hmac_seed` (Phase 3); PG/Redis nonce store (Phase 3); dual-auth transition window (not needed — hard cutover). + +--- + +## File Structure + +After this plan: + +``` +modelpull/ +├── pyproject.toml MODIFY (+cryptography, +pyjwt[crypto]) +├── uv.lock MODIFY (uv add regenerates) +├── src/dlw/ +│ ├── alembic/versions/_p2w3a_hmac_seed.py NEW +│ ├── db/models/executor.py MODIFY (+hmac_seed_encrypted column) +│ ├── auth/ +│ │ ├── ca.py NEW (CA + sign_csr + fingerprint + ensure_server_cert) +│ │ ├── jwt_signing.py NEW (Ed25519 JWT sign/verify) +│ │ ├── hmac_nonce.py NEW (NonceStore + compute/verify_hmac) +│ │ ├── executor_mtls.py NEW (require_executor_mtls dep) +│ │ ├── executor_jwt_dep.py NEW (require_executor_jwt dep) +│ │ ├── hmac_heartbeat_dep.py NEW (require_hmac_heartbeat dep) +│ │ ├── executor_epoch.py MODIFY (refactor: chain under JWT, assert path id) +│ │ └── bearer.py (W1, unchanged — UI only) +│ ├── api/ +│ │ ├── executors.py MODIFY (/register + /renew NEW; /join DELETED; auth chains) +│ │ └── subtasks.py MODIFY (/report auth chain) +│ ├── schemas/executor.py MODIFY (+ExecutorRegister/RegistrationResponse/RenewResponse; -ExecutorJoin) +│ ├── services/executor_service.py MODIFY (join_executor → upsert_executor_with_cert) +│ ├── main.py MODIFY (lifespan bootstrap CA/JWT/nonce/enrollment) +│ ├── config.py MODIFY (+ca_dir, +enrollment_token, +controller_hostname, +tls_trusted_proxy) +│ └── executor/ +│ ├── cert.py NEW (build_csr / persist / load / fingerprint) +│ ├── auth_lifecycle.py NEW (AuthState / register / renew / load_or_register) +│ ├── client.py MODIFY (mTLS + JWT + HMAC; AuthState-driven) +│ ├── runner.py MODIFY (load_or_register bootstrap; 3rd bg task) +│ └── config.py MODIFY (+enrollment_token, +executor_cert_dir, +executor_ca_bundle) +├── tests/ +│ ├── conftest.py MODIFY (+ephemeral_ca, +client_cert_pair, +_signed_heartbeat_headers) +│ ├── auth/ +│ │ ├── test_ca.py NEW (4 cases) +│ │ ├── test_jwt_signing.py NEW (4 cases) +│ │ ├── test_hmac_nonce.py NEW (4 cases) +│ │ ├── test_executor_mtls_dep.py NEW (3 cases) +│ │ ├── test_executor_jwt_dep.py NEW (2 cases) +│ │ ├── test_hmac_heartbeat_dep.py NEW (4 cases) +│ │ └── test_executor_epoch.py MODIFY (+1 confused-deputy case; migrate /join setups) +│ ├── api/ +│ │ ├── test_register_endpoint.py NEW (3 cases) +│ │ ├── test_renew_endpoint.py NEW (2 cases) +│ │ ├── test_executors.py MODIFY (joined_executor → registered_executor; HMAC headers) +│ │ └── test_subtasks.py MODIFY (fixture migration) +│ ├── e2e/ +│ │ ├── test_executor_auth_e2e.py NEW (1 real-TLS case) +│ │ ├── test_executor_e2e.py MODIFY (register flow) +│ │ └── test_happy_path.py MODIFY (register flow) +│ └── services/test_executor_service.py MODIFY (upsert_executor_with_cert rename) +├── tools/lint_invariants.py MODIFY (+check_no_bearer_on_executor_routes) +├── api/openapi.yaml MODIFY (/register +/renew; -/join; HMAC headers) +└── docs/operator/ MODIFY (CA dir, enrollment token, uvicorn --ssl-*, proxy warning) +``` + +--- + +## Pre-flight checks + +- [ ] On branch `feat/phase-2-w3a-mtls-jwt-hmac`, spec committed (`git log --oneline -1` shows `255a561` or descendant). +- [ ] `main` at `ba89a91` (PR #11 merge): `git log main --oneline -1`. +- [ ] PG running on `localhost:5433` (`pg_isready -h localhost -p 5433`). +- [ ] `dlw` database at alembic head `b1d5ea4944ba` (W2b2): `uv run alembic current`. +- [ ] Existing pytest suite green: `uv run pytest -x` → 181 passed, 1 deselected. + +--- + +## Milestone 1 — Auth substrate + +After M1: deps added, `executors.hmac_seed_encrypted` column exists, and `dlw.auth.{ca,jwt_signing,hmac_nonce}` modules work with ~12 unit tests. No endpoint wiring yet. + +--- + +### Task 1: Dependencies + alembic migration + ORM column + +**Files:** +- Modify: `pyproject.toml`, `uv.lock` +- Create: `src/dlw/alembic/versions/_p2w3a_hmac_seed.py` +- Modify: `src/dlw/db/models/executor.py` +- Possibly modify: `tests/db/test_alembic.py` + +- [ ] **Step 1: Add the two runtime dependencies** + +``` +uv add "cryptography>=43,<44" "pyjwt[crypto]>=2.9,<3.0" +``` + +This updates `pyproject.toml` `dependencies` + regenerates `uv.lock`. Verify: + +``` +uv run python -c "import cryptography, jwt; print(cryptography.__version__, jwt.__version__)" +``` + +Expected: prints two version strings (e.g. `43.x.x 2.x.x`). + +- [ ] **Step 2: Generate the alembic revision** + +``` +uv run alembic revision -m "p2w3a hmac_seed" +``` + +Note the 12-char hex revision id. Open the new file. + +- [ ] **Step 3: Verify down_revision** + +Confirm: + +```python +revision: str = '' +down_revision: Union[str, None] = 'b1d5ea4944ba' +``` + +If `down_revision` differs, fix it. (W2b2 `last_paused_at` is the current head.) + +- [ ] **Step 4: Implement upgrade/downgrade** + +```python +"""p2w3a hmac_seed + +Revision ID: +Revises: b1d5ea4944ba +Create Date: +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +revision: str = '' +down_revision: Union[str, None] = 'b1d5ea4944ba' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "executors", + sa.Column("hmac_seed_encrypted", sa.LargeBinary(), nullable=True), + ) + + +def downgrade() -> None: + op.drop_column("executors", "hmac_seed_encrypted") +``` + +Replace both `` placeholders with the actual revision id. + +- [ ] **Step 5: Add the ORM column to `src/dlw/db/models/executor.py`** + +Read the file. Find the `Executor` class. After `cert_fingerprint` (W1 column), add: + +```python + # W3a §4: 256-bit HMAC seed for heartbeat anti-replay. "encrypted" in the + # name is forward-compatible — Phase 2 stores raw bytes; Phase 3 wraps with KMS. + hmac_seed_encrypted: Mapped[bytes | None] = mapped_column( + LargeBinary, nullable=True + ) +``` + +Add `LargeBinary` to the `from sqlalchemy import (...)` import line if absent. + +- [ ] **Step 6: Apply migration + verify** + +``` +uv run alembic upgrade head +psql -h localhost -p 5433 -U postgres -d dlw -c "\d executors" 2>&1 | grep hmac_seed_encrypted +``` + +Expected: prints a line containing `hmac_seed_encrypted | bytea`. + +- [ ] **Step 7: Verify downgrade reverses** + +``` +uv run alembic downgrade -1 +psql -h localhost -p 5433 -U postgres -d dlw -c "\d executors" 2>&1 | grep hmac_seed_encrypted +uv run alembic upgrade head +``` + +Expected: middle command returns nothing; final re-applies. + +- [ ] **Step 8: Update `tests/db/test_alembic.py` if it enumerates columns** + +Read `tests/db/test_alembic.py`. If it has an `EXPECTED_*` set listing `executors` columns, add `"hmac_seed_encrypted"`. Otherwise no change. + +- [ ] **Step 9: Run full suite** + +``` +uv run pytest -x +``` + +Expected: 181 passed, 1 deselected (deps + schema change, no behavior change). + +- [ ] **Step 10: Commit** + +```bash +git add pyproject.toml uv.lock src/dlw/alembic/versions/ src/dlw/db/models/executor.py tests/db/test_alembic.py +git commit -m "feat(db): p2w3a deps (cryptography+pyjwt) + alembic hmac_seed_encrypted (W3a M1)" +``` + +--- + +### Task 2: `dlw.auth.ca` — CA + CSR signing + server cert + +**Files:** +- Create: `src/dlw/auth/ca.py` +- Create: `tests/auth/__init__.py` (empty, if `tests/auth/` doesn't already have one) +- Create: `tests/auth/test_ca.py` + +- [ ] **Step 1: Write the failing tests** + +Create `tests/auth/test_ca.py`: + +```python +"""Tests for dlw.auth.ca (Phase 2 W3a §3.1).""" +from __future__ import annotations + +from cryptography import x509 +from cryptography.hazmat.primitives.asymmetric import ed25519 + +from dlw.auth.ca import ( + bootstrap_ca, + ensure_server_cert, + fingerprint_of, + sign_csr, +) + + +def _build_csr(executor_id: str) -> bytes: + """Helper: build an Ed25519 CSR for the given executor_id.""" + from cryptography.hazmat.primitives import serialization + from cryptography.x509.oid import NameOID + key = ed25519.Ed25519PrivateKey.generate() + csr = (x509.CertificateSigningRequestBuilder() + .subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, executor_id)])) + .sign(key, None)) + return csr.public_bytes(serialization.Encoding.PEM) + + +def test_bootstrap_ca_idempotent(tmp_path) -> None: + ca1 = bootstrap_ca(tmp_path) + ca2 = bootstrap_ca(tmp_path) + assert ca1.cert_pem == ca2.cert_pem + assert ca1.key_pem == ca2.key_pem + + +def test_sign_csr_returns_valid_client_cert(tmp_path) -> None: + ca = bootstrap_ca(tmp_path) + csr_pem = _build_csr("host-1-worker-1") + cert_pem = sign_csr(ca, csr_pem, "host-1-worker-1", ttl_hours=24) + cert = x509.load_pem_x509_certificate(cert_pem) + # CN matches + cn = cert.subject.get_attributes_for_oid(x509.oid.NameOID.COMMON_NAME)[0].value + assert cn == "host-1-worker-1" + # Signed by the CA + ca.cert.public_key().verify(cert.signature, cert.tbs_certificate_bytes) + # EKU = CLIENT_AUTH + eku = cert.extensions.get_extension_for_class(x509.ExtendedKeyUsage).value + assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH in eku + + +def test_fingerprint_of_is_deterministic_sha256(tmp_path) -> None: + ca = bootstrap_ca(tmp_path) + csr_pem = _build_csr("host-2-worker-1") + cert_pem = sign_csr(ca, csr_pem, "host-2-worker-1") + fp1 = fingerprint_of(cert_pem) + fp2 = fingerprint_of(cert_pem) + assert fp1 == fp2 + assert fp1.startswith("SHA256:") + assert len(fp1) == len("SHA256:") + 64 # hex sha256 + + +def test_ensure_server_cert_has_required_sans(tmp_path) -> None: + ca = bootstrap_ca(tmp_path) + cert_path, key_path = ensure_server_cert(ca, tmp_path, hostname="dlw-controller") + assert cert_path.exists() and key_path.exists() + cert = x509.load_pem_x509_certificate(cert_path.read_bytes()) + san = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value + dns_names = san.get_values_for_type(x509.DNSName) + ip_addrs = {str(ip) for ip in san.get_values_for_type(x509.IPAddress)} + assert "localhost" in dns_names + assert "dlw-controller" in dns_names + assert "127.0.0.1" in ip_addrs + assert "::1" in ip_addrs +``` + +- [ ] **Step 2: Run — verify ModuleNotFoundError** + +``` +uv run pytest tests/auth/test_ca.py -v +``` + +Expected: 4 collection errors, `ModuleNotFoundError: No module named 'dlw.auth.ca'`. + +- [ ] **Step 3: Implement `src/dlw/auth/ca.py`** + +```python +"""Self-signed CA + client cert signing + server cert (Phase 2 W3a §3.1).""" +from __future__ import annotations + +import datetime as _dt +import ipaddress +from dataclasses import dataclass +from pathlib import Path + +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import ed25519 +from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID + + +@dataclass(frozen=True) +class CABundle: + cert_pem: bytes + key_pem: bytes + cert: x509.Certificate + key: ed25519.Ed25519PrivateKey + + +def bootstrap_ca(ca_dir: Path) -> CABundle: + """Idempotent: load existing CA from disk, else generate + persist. + Files: ca-cert.pem, ca-key.pem (chmod 600). CA validity 10 years.""" + cert_path = ca_dir / "ca-cert.pem" + key_path = ca_dir / "ca-key.pem" + if cert_path.exists() and key_path.exists(): + cert_pem = cert_path.read_bytes() + key_pem = key_path.read_bytes() + cert = x509.load_pem_x509_certificate(cert_pem) + key = serialization.load_pem_private_key(key_pem, password=None) + if not isinstance(key, ed25519.Ed25519PrivateKey): + raise ValueError("CA key is not Ed25519 (file corrupted)") + return CABundle(cert_pem=cert_pem, key_pem=key_pem, cert=cert, key=key) + + key = ed25519.Ed25519PrivateKey.generate() + name = x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, "dlw-controller-ca"), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "modelpull"), + ]) + now = _dt.datetime.now(_dt.UTC) + cert = (x509.CertificateBuilder() + .subject_name(name).issuer_name(name) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now) + .not_valid_after(now + _dt.timedelta(days=3650)) + .add_extension(x509.BasicConstraints(ca=True, path_length=0), critical=True) + .add_extension( + x509.KeyUsage( + digital_signature=False, content_commitment=False, + key_encipherment=False, data_encipherment=False, + key_agreement=False, key_cert_sign=True, crl_sign=True, + encipher_only=False, decipher_only=False, + ), critical=True, + ) + .sign(key, None) + ) + cert_pem = cert.public_bytes(serialization.Encoding.PEM) + key_pem = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + ca_dir.mkdir(mode=0o700, parents=True, exist_ok=True) + cert_path.write_bytes(cert_pem) + cert_path.chmod(0o600) + key_path.write_bytes(key_pem) + key_path.chmod(0o600) + return CABundle(cert_pem=cert_pem, key_pem=key_pem, cert=cert, key=key) + + +def sign_csr(ca: CABundle, csr_pem: bytes, executor_id: str, + ttl_hours: int = 24) -> bytes: + """Sign an executor CSR. CN = executor_id; SAN URI:spiffe://dlw/executor/; + EKU = CLIENT_AUTH. Raises ValueError on invalid CSR signature.""" + csr = x509.load_pem_x509_csr(csr_pem) + if not csr.is_signature_valid: + raise ValueError("CSR signature invalid") + now = _dt.datetime.now(_dt.UTC) + cert = (x509.CertificateBuilder() + .subject_name(x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, executor_id), + ])) + .issuer_name(ca.cert.subject) + .public_key(csr.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now) + .not_valid_after(now + _dt.timedelta(hours=ttl_hours)) + .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True) + .add_extension( + x509.SubjectAlternativeName([ + x509.UniformResourceIdentifier(f"spiffe://dlw/executor/{executor_id}"), + ]), critical=False, + ) + .add_extension( + x509.KeyUsage( + digital_signature=True, content_commitment=False, + key_encipherment=False, data_encipherment=False, + key_agreement=False, key_cert_sign=False, crl_sign=False, + encipher_only=False, decipher_only=False, + ), critical=True, + ) + .add_extension( + x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH]), critical=True, + ) + .sign(ca.key, None) + ) + return cert.public_bytes(serialization.Encoding.PEM) + + +def fingerprint_of(cert_pem: bytes) -> str: + """SHA256 fingerprint as 'SHA256:' — stored on executors.cert_fingerprint.""" + cert = x509.load_pem_x509_certificate(cert_pem) + return f"SHA256:{cert.fingerprint(hashes.SHA256()).hex()}" + + +def ensure_server_cert(ca: CABundle, ca_dir: Path, + hostname: str = "dlw-controller") -> tuple[Path, Path]: + """Idempotent: load or generate server-cert.pem + server-key.pem (chmod 600). + CN = hostname. SAN = DNS:localhost, DNS:, IP:127.0.0.1, IP:::1. + TTL 10 years. EKU = SERVER_AUTH. Returns (cert_path, key_path).""" + cert_path = ca_dir / "server-cert.pem" + key_path = ca_dir / "server-key.pem" + if cert_path.exists() and key_path.exists(): + return cert_path, key_path + + key = ed25519.Ed25519PrivateKey.generate() + now = _dt.datetime.now(_dt.UTC) + cert = (x509.CertificateBuilder() + .subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, hostname)])) + .issuer_name(ca.cert.subject) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now) + .not_valid_after(now + _dt.timedelta(days=3650)) + .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True) + .add_extension( + x509.SubjectAlternativeName([ + x509.DNSName("localhost"), + x509.DNSName(hostname), + x509.IPAddress(ipaddress.ip_address("127.0.0.1")), + x509.IPAddress(ipaddress.ip_address("::1")), + ]), critical=False, + ) + .add_extension( + x509.ExtendedKeyUsage([ExtendedKeyUsageOID.SERVER_AUTH]), critical=True, + ) + .sign(ca.key, None) + ) + ca_dir.mkdir(mode=0o700, parents=True, exist_ok=True) + cert_path.write_bytes(cert.public_bytes(serialization.Encoding.PEM)) + cert_path.chmod(0o600) + key_path.write_bytes(key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + )) + key_path.chmod(0o600) + return cert_path, key_path +``` + +> Note: `cert.chmod(0o600)` on Windows is a no-op for the permission bits but does not error — the dev environment is Windows; CI is Linux where it takes effect. This matches how W3a's `.parts/` dirs already behave. + +- [ ] **Step 4: Run tests — verify all 4 pass** + +``` +uv run pytest tests/auth/test_ca.py -v +``` + +Expected: 4 passed. + +- [ ] **Step 5: Run full suite** + +``` +uv run pytest -x +``` + +Expected: 185 passed (181 + 4 new), 1 deselected. + +- [ ] **Step 6: Commit** + +```bash +git add src/dlw/auth/ca.py tests/auth/ +git commit -m "feat(auth): ca.py — self-signed CA + CSR signing + server cert (W3a M1)" +``` + +--- + +### Task 3: `dlw.auth.jwt_signing` + `dlw.auth.hmac_nonce` + +**Files:** +- Create: `src/dlw/auth/jwt_signing.py` +- Create: `src/dlw/auth/hmac_nonce.py` +- Create: `tests/auth/test_jwt_signing.py` +- Create: `tests/auth/test_hmac_nonce.py` + +- [ ] **Step 1: Write the failing tests** + +Create `tests/auth/test_jwt_signing.py`: + +```python +"""Tests for dlw.auth.jwt_signing (Phase 2 W3a §3.2).""" +from __future__ import annotations + +import time + +import jwt as _pyjwt +import pytest + +from dlw.auth.jwt_signing import bootstrap_keypair, sign, verify + + +def test_bootstrap_keypair_idempotent(tmp_path) -> None: + kp1 = bootstrap_keypair(tmp_path) + kp2 = bootstrap_keypair(tmp_path) + assert kp1.priv_pem == kp2.priv_pem + assert kp1.pub_pem == kp2.pub_pem + + +def test_sign_and_verify_roundtrip(tmp_path) -> None: + kp = bootstrap_keypair(tmp_path) + token = sign(kp, executor_id="host-1-worker-1", epoch=3, + scopes=["heartbeat", "poll"], ttl_seconds=3600) + claims = verify(kp, token) + assert claims["sub"] == "host-1-worker-1" + assert claims["epoch"] == 3 + assert claims["scope"] == "heartbeat poll" + assert claims["iss"] == "dlw-controller" + + +def test_verify_rejects_expired_token(tmp_path) -> None: + kp = bootstrap_keypair(tmp_path) + token = sign(kp, executor_id="e", epoch=1, scopes=["heartbeat"], + ttl_seconds=-10) # already expired + with pytest.raises(_pyjwt.PyJWTError): + verify(kp, token) + + +def test_verify_rejects_wrong_issuer(tmp_path) -> None: + kp = bootstrap_keypair(tmp_path) + # Hand-craft a token with a bad issuer using the same key. + now = int(time.time()) + bad = _pyjwt.encode( + {"iss": "evil", "sub": "e", "epoch": 1, "scope": "heartbeat", + "iat": now, "exp": now + 3600}, + kp.priv_pem.decode("utf-8"), algorithm="EdDSA", + ) + with pytest.raises(_pyjwt.PyJWTError): + verify(kp, bad) +``` + +Create `tests/auth/test_hmac_nonce.py`: + +```python +"""Tests for dlw.auth.hmac_nonce (Phase 2 W3a §3.3).""" +from __future__ import annotations + +import time + +from dlw.auth.hmac_nonce import NonceStore, compute_hmac, verify_hmac + + +_SEED = b"\x01" * 32 + + +def test_hmac_compute_and_verify_roundtrip() -> None: + body = b'{"health_score":100}' + sig = compute_hmac(_SEED, ts=1715739200, nonce="abc", body=body) + assert verify_hmac(_SEED, ts=1715739200, nonce="abc", body=body, + signature_hex=sig) + + +def test_hmac_verify_rejects_tampered_body() -> None: + body = b'{"health_score":100}' + sig = compute_hmac(_SEED, ts=1715739200, nonce="abc", body=body) + tampered = b'{"health_score":101}' + assert not verify_hmac(_SEED, ts=1715739200, nonce="abc", body=tampered, + signature_hex=sig) + + +def test_nonce_store_first_add_then_seen() -> None: + store = NonceStore(maxsize=100, ttl_seconds=300) + assert not store.seen("n1") + store.add("n1") + assert store.seen("n1") + + +def test_nonce_store_evicts_after_ttl(monkeypatch) -> None: + store = NonceStore(maxsize=100, ttl_seconds=10) + fake = [1000.0] + monkeypatch.setattr("dlw.auth.hmac_nonce.time.monotonic", lambda: fake[0]) + store.add("n1") + assert store.seen("n1") + fake[0] += 11 # advance past TTL + assert not store.seen("n1") +``` + +- [ ] **Step 2: Run — verify ModuleNotFoundError** + +``` +uv run pytest tests/auth/test_jwt_signing.py tests/auth/test_hmac_nonce.py -v +``` + +Expected: 8 collection errors, `ModuleNotFoundError`. + +- [ ] **Step 3: Implement `src/dlw/auth/jwt_signing.py`** + +```python +"""Ed25519 JWT signing for executor JWTs (Phase 2 W3a §3.2).""" +from __future__ import annotations + +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import jwt as _pyjwt +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ed25519 + + +@dataclass(frozen=True) +class JWTKeypair: + priv_pem: bytes + pub_pem: bytes + + +def bootstrap_keypair(ca_dir: Path) -> JWTKeypair: + """Idempotent: load or generate jwt-signing.pem (chmod 600, PKCS8 Ed25519).""" + priv_path = ca_dir / "jwt-signing.pem" + if priv_path.exists(): + priv_pem = priv_path.read_bytes() + priv = serialization.load_pem_private_key(priv_pem, password=None) + if not isinstance(priv, ed25519.Ed25519PrivateKey): + raise ValueError("JWT signing key is not Ed25519") + else: + priv = ed25519.Ed25519PrivateKey.generate() + priv_pem = priv.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + ca_dir.mkdir(mode=0o700, parents=True, exist_ok=True) + priv_path.write_bytes(priv_pem) + priv_path.chmod(0o600) + pub_pem = priv.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + return JWTKeypair(priv_pem=priv_pem, pub_pem=pub_pem) + + +def sign(kp: JWTKeypair, *, executor_id: str, epoch: int, + scopes: list[str], ttl_seconds: int = 3600) -> str: + """Sign an executor JWT. Returns compact JWS.""" + now = int(time.time()) + claims = { + "iss": "dlw-controller", + "sub": executor_id, + "epoch": epoch, + "scope": " ".join(scopes), + "iat": now, + "exp": now + ttl_seconds, + } + return _pyjwt.encode(claims, kp.priv_pem.decode("utf-8"), algorithm="EdDSA") + + +def verify(kp: JWTKeypair, token: str) -> dict[str, Any]: + """Decode + verify. Raises jwt.PyJWTError on any failure.""" + return _pyjwt.decode( + token, kp.pub_pem.decode("utf-8"), + algorithms=["EdDSA"], + issuer="dlw-controller", + options={"require": ["sub", "epoch", "scope", "exp", "iss", "iat"]}, + ) +``` + +- [ ] **Step 4: Implement `src/dlw/auth/hmac_nonce.py`** + +```python +"""HMAC heartbeat: nonce store + signature verify (Phase 2 W3a §3.3).""" +from __future__ import annotations + +import hashlib +import hmac as _hmac +import time +from collections import OrderedDict + + +class NonceStore: + """In-process LRU with timestamp-based eviction. asyncio single-threaded — + no lock needed. Restart loses state; replay defense is bounded by the + ±5min timestamp window enforced at the dependency layer.""" + + def __init__(self, *, maxsize: int = 10_000, ttl_seconds: int = 300) -> None: + self._maxsize = maxsize + self._ttl = ttl_seconds + self._data: OrderedDict[str, float] = OrderedDict() + + def _evict_expired(self) -> None: + cutoff = time.monotonic() - self._ttl + while self._data: + k, v = next(iter(self._data.items())) + if v >= cutoff: + break + self._data.popitem(last=False) + + def seen(self, nonce: str) -> bool: + self._evict_expired() + return nonce in self._data + + def add(self, nonce: str) -> None: + self._evict_expired() + if len(self._data) >= self._maxsize: + self._data.popitem(last=False) + self._data[nonce] = time.monotonic() + + +def compute_hmac(hmac_seed: bytes, *, ts: int, nonce: str, body: bytes) -> str: + """HMAC-SHA256(hmac_seed, f'{ts}:{nonce}:'.encode() + body). Hex string.""" + msg = f"{ts}:{nonce}:".encode("utf-8") + body + return _hmac.new(hmac_seed, msg, hashlib.sha256).hexdigest() + + +def verify_hmac(hmac_seed: bytes, *, ts: int, nonce: str, body: bytes, + signature_hex: str) -> bool: + expected = compute_hmac(hmac_seed, ts=ts, nonce=nonce, body=body) + return _hmac.compare_digest(expected, signature_hex) +``` + +- [ ] **Step 5: Run tests — verify all 8 pass** + +``` +uv run pytest tests/auth/test_jwt_signing.py tests/auth/test_hmac_nonce.py -v +``` + +Expected: 8 passed. + +- [ ] **Step 6: Run full suite** + +``` +uv run pytest -x +``` + +Expected: 193 passed (185 + 8 new), 1 deselected. + +- [ ] **Step 7: Commit** + +```bash +git add src/dlw/auth/jwt_signing.py src/dlw/auth/hmac_nonce.py tests/auth/test_jwt_signing.py tests/auth/test_hmac_nonce.py +git commit -m "feat(auth): jwt_signing (Ed25519 JWT) + hmac_nonce (NonceStore) (W3a M1)" +``` + +--- + +### Milestone 1 verification (self) + +- [ ] `cryptography` + `pyjwt` in `pyproject.toml`; `uv sync` clean. +- [ ] alembic head is the new revision; `executors.hmac_seed_encrypted` exists. +- [ ] `ca.py` / `jwt_signing.py` / `hmac_nonce.py` import cleanly; 16 unit tests pass. +- [ ] Full suite at 193. + +--- + +## Milestone 2 — Controller deps + endpoints + +After M2: three FastAPI dependencies + `require_executor_epoch` refactor + `/register` + `/renew` endpoints + `main.py` bootstrap. `/join` deleted. ~14 new tests. + +--- + +### Task 4: FastAPI dependencies + `require_executor_epoch` refactor + +**Files:** +- Create: `src/dlw/auth/executor_mtls.py`, `src/dlw/auth/executor_jwt_dep.py`, `src/dlw/auth/hmac_heartbeat_dep.py` +- Modify: `src/dlw/auth/executor_epoch.py` +- Modify: `tests/conftest.py` (+ephemeral_ca, +client_cert_pair fixtures) +- Create: `tests/auth/test_executor_mtls_dep.py`, `tests/auth/test_executor_jwt_dep.py`, `tests/auth/test_hmac_heartbeat_dep.py` +- Modify: `tests/auth/test_executor_epoch.py` (+confused-deputy case) + +- [ ] **Step 1: Add conftest fixtures** + +In `tests/conftest.py`, add at module level (after the existing fixtures): + +```python +@pytest.fixture(scope="session") +def ephemeral_ca(tmp_path_factory): + """One CA + JWT keypair per test session, in a temp dir.""" + from dlw.auth.ca import bootstrap_ca + from dlw.auth.jwt_signing import bootstrap_keypair + ca_dir = tmp_path_factory.mktemp("ca") + ca = bootstrap_ca(ca_dir) + jwt_kp = bootstrap_keypair(ca_dir) + return {"ca": ca, "jwt_keypair": jwt_kp, "ca_dir": ca_dir} + + +@pytest.fixture +def client_cert_pair(ephemeral_ca): + """Per-test client cert (executor 'test-executor-1') signed by the session CA. + Returns (cert_pem: bytes, key: Ed25519PrivateKey, executor_id: str).""" + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric import ed25519 + from cryptography import x509 + from cryptography.x509.oid import NameOID + from dlw.auth.ca import sign_csr + executor_id = "test-executor-1" + key = ed25519.Ed25519PrivateKey.generate() + csr = (x509.CertificateSigningRequestBuilder() + .subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, executor_id)])) + .sign(key, None)) + csr_pem = csr.public_bytes(serialization.Encoding.PEM) + cert_pem = sign_csr(ephemeral_ca["ca"], csr_pem, executor_id, ttl_hours=24) + return cert_pem, key, executor_id +``` + +(Don't remove or alter any existing conftest fixture.) + +- [ ] **Step 2: Write the failing dependency tests** + +Create `tests/auth/test_executor_mtls_dep.py`: + +```python +"""Tests for require_executor_mtls (Phase 2 W3a §3.4).""" +from __future__ import annotations + +import pytest +from fastapi import FastAPI, Depends +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import async_sessionmaker + +from dlw.auth.ca import fingerprint_of +from dlw.auth.executor_mtls import require_executor_mtls +from dlw.db.base import Base +from dlw.db.models.executor import Executor + + +@pytest.fixture(scope="module", autouse=True) +async def _create_tables(engine): + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + + +def _mini_app(): + app = FastAPI() + + @app.get("/whoami") + async def whoami(ex: Executor = Depends(require_executor_mtls)) -> dict: + return {"executor_id": ex.id} + + return app + + +@pytest.mark.slow +async def test_require_executor_mtls_via_trusted_proxy_header( + engine, client_cert_pair, monkeypatch, +) -> None: + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "1") + cert_pem, _key, executor_id = client_cert_pair + fp = fingerprint_of(cert_pem) + factory = async_sessionmaker(engine, expire_on_commit=False) + async with factory() as s: + s.add(Executor(id=executor_id, host_id="h", cert_fingerprint=fp, + status="healthy", epoch=1)) + await s.commit() + + app = _mini_app() + async with AsyncClient(transport=ASGITransport(app=app), + base_url="http://test") as c: + r = await c.get("/whoami", headers={ + "X-Client-Cert-PEM": cert_pem.decode("utf-8").replace("\n", "\\n"), + }) + assert r.status_code == 200 + assert r.json()["executor_id"] == executor_id + + +@pytest.mark.slow +async def test_require_executor_mtls_rejects_unknown_fingerprint( + client_cert_pair, monkeypatch, +) -> None: + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "1") + cert_pem, _key, _ = client_cert_pair # cert NOT inserted into DB + app = _mini_app() + async with AsyncClient(transport=ASGITransport(app=app), + base_url="http://test") as c: + r = await c.get("/whoami", headers={ + "X-Client-Cert-PEM": cert_pem.decode("utf-8").replace("\n", "\\n"), + }) + assert r.status_code == 401 + + +@pytest.mark.slow +async def test_require_executor_mtls_rejects_header_when_proxy_disabled( + client_cert_pair, monkeypatch, +) -> None: + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "0") + cert_pem, _key, _ = client_cert_pair + app = _mini_app() + async with AsyncClient(transport=ASGITransport(app=app), + base_url="http://test") as c: + r = await c.get("/whoami", headers={ + "X-Client-Cert-PEM": cert_pem.decode("utf-8").replace("\n", "\\n"), + }) + assert r.status_code == 401 +``` + +Create `tests/auth/test_executor_jwt_dep.py`: + +```python +"""Tests for require_executor_jwt (Phase 2 W3a §3.4).""" +from __future__ import annotations + +import pytest +from fastapi import FastAPI, Depends +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import async_sessionmaker + +from dlw.auth.ca import fingerprint_of +from dlw.auth.executor_jwt_dep import require_executor_jwt +from dlw.auth.jwt_signing import sign +from dlw.db.base import Base +from dlw.db.models.executor import Executor + + +@pytest.fixture(scope="module", autouse=True) +async def _create_tables(engine): + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + + +def _mini_app(jwt_keypair): + app = FastAPI() + app.state.jwt_keypair = jwt_keypair + + @app.get("/whoami") + async def whoami(ex: Executor = Depends(require_executor_jwt)) -> dict: + return {"executor_id": ex.id} + + return app + + +@pytest.mark.slow +async def test_require_executor_jwt_accepts_valid_token( + engine, ephemeral_ca, client_cert_pair, monkeypatch, +) -> None: + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "1") + cert_pem, _key, executor_id = client_cert_pair + fp = fingerprint_of(cert_pem) + factory = async_sessionmaker(engine, expire_on_commit=False) + async with factory() as s: + s.add(Executor(id=executor_id, host_id="h", cert_fingerprint=fp, + status="healthy", epoch=2)) + await s.commit() + token = sign(ephemeral_ca["jwt_keypair"], executor_id=executor_id, + epoch=2, scopes=["heartbeat"]) + + app = _mini_app(ephemeral_ca["jwt_keypair"]) + async with AsyncClient(transport=ASGITransport(app=app), + base_url="http://test") as c: + r = await c.get("/whoami", headers={ + "X-Client-Cert-PEM": cert_pem.decode("utf-8").replace("\n", "\\n"), + "Authorization": f"Bearer {token}", + }) + assert r.status_code == 200 + assert r.json()["executor_id"] == executor_id + + +@pytest.mark.slow +async def test_require_executor_jwt_rejects_sub_mismatch( + engine, ephemeral_ca, client_cert_pair, monkeypatch, +) -> None: + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "1") + cert_pem, _key, executor_id = client_cert_pair + fp = fingerprint_of(cert_pem) + factory = async_sessionmaker(engine, expire_on_commit=False) + async with factory() as s: + s.add(Executor(id=executor_id, host_id="h", cert_fingerprint=fp, + status="healthy", epoch=2)) + await s.commit() + # JWT for a DIFFERENT executor + token = sign(ephemeral_ca["jwt_keypair"], executor_id="other-executor", + epoch=2, scopes=["heartbeat"]) + + app = _mini_app(ephemeral_ca["jwt_keypair"]) + async with AsyncClient(transport=ASGITransport(app=app), + base_url="http://test") as c: + r = await c.get("/whoami", headers={ + "X-Client-Cert-PEM": cert_pem.decode("utf-8").replace("\n", "\\n"), + "Authorization": f"Bearer {token}", + }) + assert r.status_code == 401 +``` + +Create `tests/auth/test_hmac_heartbeat_dep.py`: + +```python +"""Tests for require_hmac_heartbeat (Phase 2 W3a §3.4).""" +from __future__ import annotations + +import secrets +import time + +import pytest +from fastapi import FastAPI, Depends, Request +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import async_sessionmaker + +from dlw.auth.ca import fingerprint_of +from dlw.auth.hmac_heartbeat_dep import require_hmac_heartbeat +from dlw.auth.hmac_nonce import NonceStore, compute_hmac +from dlw.auth.jwt_signing import sign +from dlw.db.base import Base +from dlw.db.models.executor import Executor + + +@pytest.fixture(scope="module", autouse=True) +async def _create_tables(engine): + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + + +_HMAC_SEED = b"\x02" * 32 + + +async def _seed_executor(engine, executor_id, fp): + factory = async_sessionmaker(engine, expire_on_commit=False) + async with factory() as s: + s.add(Executor(id=executor_id, host_id="h", cert_fingerprint=fp, + status="healthy", epoch=1, + hmac_seed_encrypted=_HMAC_SEED)) + await s.commit() + + +def _mini_app(jwt_keypair): + app = FastAPI() + app.state.jwt_keypair = jwt_keypair + app.state.nonce_store = NonceStore(maxsize=100, ttl_seconds=300) + + @app.post("/hb") + async def hb(request: Request, + ex: Executor = Depends(require_hmac_heartbeat)) -> dict: + return {"ok": True, "executor_id": ex.id} + + return app + + +def _hmac_headers(seed, body: bytes, *, ts: int | None = None, nonce: str | None = None): + ts = ts if ts is not None else int(time.time()) + nonce = nonce or secrets.token_hex(16) + sig = compute_hmac(seed, ts=ts, nonce=nonce, body=body) + return {"X-HMAC-Timestamp": str(ts), "X-HMAC-Nonce": nonce, + "X-HMAC-Signature": sig} + + +@pytest.mark.slow +async def test_hmac_heartbeat_accepts_valid_signature( + engine, ephemeral_ca, client_cert_pair, monkeypatch, +) -> None: + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "1") + cert_pem, _key, executor_id = client_cert_pair + fp = fingerprint_of(cert_pem) + await _seed_executor(engine, executor_id, fp) + token = sign(ephemeral_ca["jwt_keypair"], executor_id=executor_id, + epoch=1, scopes=["heartbeat"]) + body = b'{"health_score":100}' + app = _mini_app(ephemeral_ca["jwt_keypair"]) + async with AsyncClient(transport=ASGITransport(app=app), + base_url="http://test") as c: + r = await c.post("/hb", content=body, headers={ + "X-Client-Cert-PEM": cert_pem.decode("utf-8").replace("\n", "\\n"), + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + **_hmac_headers(_HMAC_SEED, body), + }) + assert r.status_code == 200 + + +@pytest.mark.slow +async def test_hmac_heartbeat_rejects_clock_skew( + engine, ephemeral_ca, client_cert_pair, monkeypatch, +) -> None: + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "1") + cert_pem, _key, executor_id = client_cert_pair + fp = fingerprint_of(cert_pem) + await _seed_executor(engine, executor_id, fp) + token = sign(ephemeral_ca["jwt_keypair"], executor_id=executor_id, + epoch=1, scopes=["heartbeat"]) + body = b'{"health_score":100}' + app = _mini_app(ephemeral_ca["jwt_keypair"]) + async with AsyncClient(transport=ASGITransport(app=app), + base_url="http://test") as c: + r = await c.post("/hb", content=body, headers={ + "X-Client-Cert-PEM": cert_pem.decode("utf-8").replace("\n", "\\n"), + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + **_hmac_headers(_HMAC_SEED, body, ts=int(time.time()) - 400), + }) + assert r.status_code == 401 + + +@pytest.mark.slow +async def test_hmac_heartbeat_rejects_replay( + engine, ephemeral_ca, client_cert_pair, monkeypatch, +) -> None: + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "1") + cert_pem, _key, executor_id = client_cert_pair + fp = fingerprint_of(cert_pem) + await _seed_executor(engine, executor_id, fp) + token = sign(ephemeral_ca["jwt_keypair"], executor_id=executor_id, + epoch=1, scopes=["heartbeat"]) + body = b'{"health_score":100}' + headers_base = { + "X-Client-Cert-PEM": cert_pem.decode("utf-8").replace("\n", "\\n"), + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + } + fixed_nonce = "fixed-replay-nonce" + hmac_h = _hmac_headers(_HMAC_SEED, body, nonce=fixed_nonce) + app = _mini_app(ephemeral_ca["jwt_keypair"]) + async with AsyncClient(transport=ASGITransport(app=app), + base_url="http://test") as c: + r1 = await c.post("/hb", content=body, headers={**headers_base, **hmac_h}) + r2 = await c.post("/hb", content=body, headers={**headers_base, **hmac_h}) + assert r1.status_code == 200 + assert r2.status_code == 401 # REPLAY_DETECTED + + +@pytest.mark.slow +async def test_hmac_heartbeat_rejects_tampered_body( + engine, ephemeral_ca, client_cert_pair, monkeypatch, +) -> None: + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "1") + cert_pem, _key, executor_id = client_cert_pair + fp = fingerprint_of(cert_pem) + await _seed_executor(engine, executor_id, fp) + token = sign(ephemeral_ca["jwt_keypair"], executor_id=executor_id, + epoch=1, scopes=["heartbeat"]) + signed_body = b'{"health_score":100}' + sent_body = b'{"health_score":999}' + app = _mini_app(ephemeral_ca["jwt_keypair"]) + async with AsyncClient(transport=ASGITransport(app=app), + base_url="http://test") as c: + r = await c.post("/hb", content=sent_body, headers={ + "X-Client-Cert-PEM": cert_pem.decode("utf-8").replace("\n", "\\n"), + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + **_hmac_headers(_HMAC_SEED, signed_body), # sig over signed_body + }) + assert r.status_code == 401 # HMAC_INVALID +``` + +In `tests/auth/test_executor_epoch.py`, ADD one case (keep existing tests, migrate any `/join`-based setup to a direct DB insert): + +```python +@pytest.mark.slow +async def test_require_executor_epoch_rejects_path_id_mismatch( + engine, ephemeral_ca, client_cert_pair, monkeypatch, +) -> None: + """mTLS+JWT authenticate executor A, but the URL path says executor B → + 401 EXECUTOR_ID_MISMATCH (confused-deputy guard).""" + from dlw.auth.ca import fingerprint_of + from dlw.auth.jwt_signing import sign + from fastapi import FastAPI, Depends, Path + from httpx import ASGITransport, AsyncClient + from sqlalchemy.ext.asyncio import async_sessionmaker + from dlw.auth.executor_epoch import require_executor_epoch + from dlw.db.models.executor import Executor + + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "1") + cert_pem, _key, executor_id = client_cert_pair # executor_id == "test-executor-1" + fp = fingerprint_of(cert_pem) + factory = async_sessionmaker(engine, expire_on_commit=False) + async with factory() as s: + s.add(Executor(id=executor_id, host_id="h", cert_fingerprint=fp, + status="healthy", epoch=1)) + await s.commit() + token = sign(ephemeral_ca["jwt_keypair"], executor_id=executor_id, + epoch=1, scopes=["heartbeat"]) + + app = FastAPI() + app.state.jwt_keypair = ephemeral_ca["jwt_keypair"] + + @app.post("/executors/{executor_id}/x") + async def x(executor_id: str = Path(...), + ex: Executor = Depends(require_executor_epoch)) -> dict: + return {"ok": True} + + async with AsyncClient(transport=ASGITransport(app=app), + base_url="http://test") as c: + # path says "other-executor" but cert+JWT are for "test-executor-1" + r = await c.post("/executors/other-executor/x", headers={ + "X-Client-Cert-PEM": cert_pem.decode("utf-8").replace("\n", "\\n"), + "Authorization": f"Bearer {token}", + "X-Executor-Epoch": "1", + }) + assert r.status_code == 401 +``` + +- [ ] **Step 3: Run — verify the dep tests fail with ModuleNotFoundError** + +``` +uv run pytest tests/auth/test_executor_mtls_dep.py tests/auth/test_executor_jwt_dep.py tests/auth/test_hmac_heartbeat_dep.py -v +``` + +Expected: collection errors for the three missing modules. + +- [ ] **Step 4: Implement `src/dlw/auth/executor_mtls.py`** + +```python +"""mTLS peer-cert dependency (Phase 2 W3a §3.4).""" +from __future__ import annotations + +import os + +from cryptography import x509 +from cryptography.hazmat.primitives import serialization +from fastapi import Depends, HTTPException, Request +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from dlw.api.tasks import _session +from dlw.auth.ca import fingerprint_of +from dlw.db.models.executor import Executor + + +def _extract_peer_cert(request: Request) -> bytes | None: + """Two paths: (a) direct uvicorn TLS — peercert in scope; (b) trusted-proxy + forwarded header — only honored when DLW_TLS_TRUSTED_PROXY=1.""" + transport = request.scope.get("transport") + if transport is not None: + peercert = transport.get_extra_info("peercert") if hasattr( + transport, "get_extra_info") else None + if peercert: + # uvicorn provides the DER-encoded peer cert + try: + cert = x509.load_der_x509_certificate(peercert) + return cert.public_bytes(serialization.Encoding.PEM) + except Exception: + pass + if os.environ.get("DLW_TLS_TRUSTED_PROXY") == "1": + header = request.headers.get("X-Client-Cert-PEM") + if header: + return header.replace("\\n", "\n").encode("utf-8") + return None + + +async def require_executor_mtls( + request: Request, + session: AsyncSession = Depends(_session), +) -> Executor: + """Validate mTLS peer cert + look up executor by fingerprint.""" + cert_pem = _extract_peer_cert(request) + if cert_pem is None: + raise HTTPException(401, detail="missing or invalid mTLS peer cert") + try: + fp = fingerprint_of(cert_pem) + except Exception as e: + raise HTTPException(401, detail=f"invalid client cert: {e}") from e + ex = (await session.execute( + select(Executor).where(Executor.cert_fingerprint == fp) + )).scalar_one_or_none() + if ex is None: + raise HTTPException(401, detail="cert fingerprint not registered") + return ex +``` + +- [ ] **Step 5: Implement `src/dlw/auth/executor_jwt_dep.py`** + +```python +"""Executor JWT dependency (Phase 2 W3a §3.4).""" +from __future__ import annotations + +import jwt as _pyjwt +from fastapi import Depends, Header, HTTPException, Request + +from dlw.auth.executor_mtls import require_executor_mtls +from dlw.auth.jwt_signing import verify +from dlw.db.models.executor import Executor + + +async def require_executor_jwt( + request: Request, + authorization: str | None = Header(default=None), + ex: Executor = Depends(require_executor_mtls), +) -> Executor: + if not authorization or not authorization.startswith("Bearer "): + raise HTTPException(401, detail="missing executor JWT") + token = authorization.split(" ", 1)[1] + try: + claims = verify(request.app.state.jwt_keypair, token) + except _pyjwt.PyJWTError as e: + raise HTTPException(401, detail=f"invalid JWT: {e}") from e + if claims["sub"] != ex.id: + raise HTTPException(401, detail="JWT sub mismatch") + return ex +``` + +- [ ] **Step 6: Implement `src/dlw/auth/hmac_heartbeat_dep.py`** + +```python +"""HMAC heartbeat dependency (Phase 2 W3a §3.4).""" +from __future__ import annotations + +import time + +from fastapi import Depends, Header, HTTPException, Request + +from dlw.auth.executor_jwt_dep import require_executor_jwt +from dlw.auth.hmac_nonce import verify_hmac +from dlw.db.models.executor import Executor + +_TIMESTAMP_SKEW_SECONDS = 300 + + +async def require_hmac_heartbeat( + request: Request, + x_hmac_timestamp: int = Header(..., alias="X-HMAC-Timestamp"), + x_hmac_nonce: str = Header(..., alias="X-HMAC-Nonce"), + x_hmac_signature: str = Header(..., alias="X-HMAC-Signature"), + ex: Executor = Depends(require_executor_jwt), +) -> Executor: + now = int(time.time()) + if abs(now - x_hmac_timestamp) > _TIMESTAMP_SKEW_SECONDS: + raise HTTPException(401, detail="CLOCK_SKEW") + store = request.app.state.nonce_store + if store.seen(x_hmac_nonce): + raise HTTPException(401, detail="REPLAY_DETECTED") + if ex.hmac_seed_encrypted is None: + raise HTTPException(401, detail="HMAC_SEED_MISSING — re-register") + hmac_seed = bytes(ex.hmac_seed_encrypted) + body = await request.body() + if not verify_hmac(hmac_seed, ts=x_hmac_timestamp, nonce=x_hmac_nonce, + body=body, signature_hex=x_hmac_signature): + raise HTTPException(401, detail="HMAC_INVALID") + store.add(x_hmac_nonce) + return ex +``` + +- [ ] **Step 7: Refactor `src/dlw/auth/executor_epoch.py`** + +Replace the whole file with: + +```python +"""require_executor_epoch — W1 fence-token dep, refactored for W3a §3.4. + +Under W3a it chains under require_executor_jwt: the Executor row is already +loaded + authenticated via the mTLS cert fingerprint. This dep adds two +checks on top: + 1. the path executor_id MUST equal the mTLS-authenticated identity + (confused-deputy guard — W3a closes the gap where a valid cert for + executor A could be used against /executors/B/...); + 2. X-Executor-Epoch MUST match the authenticated row's epoch (W1 fence). +""" +from __future__ import annotations + +from fastapi import Depends, Header, HTTPException, Path + +from dlw.auth.executor_jwt_dep import require_executor_jwt +from dlw.db.models.executor import Executor + + +async def require_executor_epoch( + executor_id: str = Path(..., description="Executor id from URL path"), + x_executor_epoch: int | None = Header(default=None, alias="X-Executor-Epoch"), + ex: Executor = Depends(require_executor_jwt), +) -> Executor: + """Return the mTLS+JWT-authenticated Executor row if the path id matches + and the epoch header matches; else 401.""" + if executor_id != ex.id: + raise HTTPException( + status_code=401, + detail={"code": "EXECUTOR_ID_MISMATCH", + "path": executor_id, "authenticated": ex.id}, + ) + if x_executor_epoch is None: + raise HTTPException(status_code=401, detail="missing X-Executor-Epoch header") + if ex.epoch != x_executor_epoch: + raise HTTPException( + status_code=401, + detail={"code": "EPOCH_MISMATCH", "expected": ex.epoch, + "got": x_executor_epoch}, + ) + return ex +``` + +> The old W1 version did its own `session.get(Executor, executor_id)`; the W3a version receives the already-loaded + authenticated row from `require_executor_jwt`. The `from dlw.api.tasks import _session` import is removed (no longer does its own lookup). + +- [ ] **Step 8: Run the dep tests — verify all pass** + +``` +uv run pytest tests/auth/ -v +``` + +Expected: all `tests/auth/` pass — the new dep tests + the existing W1 epoch tests (migrated) + the new confused-deputy case. If a W1 `test_executor_epoch.py` case constructed an executor via `/join`, replace that setup with a direct DB insert + a signed JWT (mirror the new case's setup). + +- [ ] **Step 9: Run full suite** + +``` +uv run pytest -x +``` + +Expected: 193 (M1) + ~13 new auth dep tests = ~206; existing W1 epoch tests may have been rewritten but count stays roughly same. Some W1 API tests (`test_executors.py`, `test_subtasks.py`) will now FAIL because their endpoints still use the old `require_executor_epoch` signature indirectly — that's expected; Task 6 migrates them. **For this task, run `uv run pytest tests/auth/ -x` and confirm green; the full-suite breakage in `tests/api/` is addressed in Task 6.** Note the count of failures so Task 6 can confirm they all clear. + +- [ ] **Step 10: Commit** + +```bash +git add src/dlw/auth/executor_mtls.py src/dlw/auth/executor_jwt_dep.py src/dlw/auth/hmac_heartbeat_dep.py src/dlw/auth/executor_epoch.py tests/conftest.py tests/auth/ +git commit -m "feat(auth): mTLS + JWT + HMAC FastAPI deps; executor_epoch confused-deputy guard (W3a M2)" +``` + +--- + +### Task 5: `/register` + `/renew` endpoints + `main.py` bootstrap + service rename + +**Files:** +- Modify: `src/dlw/schemas/executor.py` +- Modify: `src/dlw/services/executor_service.py` +- Modify: `src/dlw/api/executors.py` +- Modify: `src/dlw/main.py` +- Modify: `src/dlw/config.py` +- Create: `tests/api/test_register_endpoint.py`, `tests/api/test_renew_endpoint.py` + +- [ ] **Step 1: Add schemas to `src/dlw/schemas/executor.py`** + +Read the file. Add (and DELETE the W1 `ExecutorJoin` schema if present): + +```python +class ExecutorRegister(BaseModel): + host_id: str + executor_id_proposal: str + capabilities: dict[str, Any] = {} + client_csr_pem: str + + +class RegistrationResponse(BaseModel): + executor_id: str + epoch: int + client_cert_pem: str + ca_chain: list[str] + executor_jwt: str + hmac_seed_hex: str + cert_renew_in_seconds: int + jwt_renew_in_seconds: int + + +class RenewRequest(BaseModel): + # Optional: the executor includes a fresh CSR only when its cert is near + # expiry. When None, /renew refreshes the JWT only. + client_csr_pem: str | None = None + + +class RenewResponse(BaseModel): + executor_jwt: str + jwt_renew_in_seconds: int + client_cert_pem: str | None = None + cert_renew_in_seconds: int | None = None +``` + +Add `from typing import Any` to imports if absent. + +- [ ] **Step 2: Rename + extend `join_executor` in `src/dlw/services/executor_service.py`** + +Read the file. The W1 `join_executor` does a `pg_insert ... ON CONFLICT DO UPDATE` epoch-bump. Rename it to `upsert_executor_with_cert` and add `cert_fingerprint` + `hmac_seed` params: + +```python +async def upsert_executor_with_cert( + session: AsyncSession, + *, + executor_id: str, + host_id: str, + capabilities: dict[str, Any], + cert_fingerprint: str, + hmac_seed: bytes, +) -> Executor: + """W3a §3.8: INSERT-or-bump executor row, writing cert_fingerprint + + hmac_seed_encrypted. Same atomic epoch semantics as W1 join_executor: + epoch=1 on insert, +1 on conflict. Caller commits.""" + from sqlalchemy.dialects.postgresql import insert as pg_insert + + stmt = ( + pg_insert(Executor) + .values( + id=executor_id, host_id=host_id, capabilities=capabilities, + cert_fingerprint=cert_fingerprint, hmac_seed_encrypted=hmac_seed, + status="joining", epoch=1, + ) + .on_conflict_do_update( + index_elements=["id"], + set_={ + "host_id": host_id, + "capabilities": capabilities, + "cert_fingerprint": cert_fingerprint, + "hmac_seed_encrypted": hmac_seed, + "status": "joining", + "epoch": Executor.epoch + 1, + }, + ) + .returning(Executor) + ) + row = (await session.execute(stmt)).scalar_one() + return row +``` + +(Match the exact W1 `join_executor` structure — read it first. The W1 version may set `status="joining"` or `status="healthy"`; preserve whatever W1 did. The W3a additions are the two new columns. If W1's `join_executor` had different param names, keep its body shape and only add the cert + seed handling.) + +Keep a thin `join_executor` alias removed entirely — grep for callers and migrate them all to `upsert_executor_with_cert` (Task 6 + the test migration handle the test callers). + +- [ ] **Step 3: Rewrite the executor endpoints in `src/dlw/api/executors.py`** + +Read the file. DELETE the `POST /join` endpoint. Add imports: + +```python +import secrets +from cryptography import x509 +from dlw.auth.ca import fingerprint_of, sign_csr +from dlw.auth import jwt_signing +from dlw.auth.executor_jwt_dep import require_executor_jwt +from dlw.schemas.executor import ( + ExecutorRegister, RegistrationResponse, RenewRequest, RenewResponse, +) +from dlw.services.executor_service import upsert_executor_with_cert +``` + +Add `/register`: + +```python +@router.post("/register", status_code=status.HTTP_201_CREATED) +async def post_register( + body: ExecutorRegister, + request: Request, + x_enrollment_token: str = Header(..., alias="X-Enrollment-Token"), + session: AsyncSession = Depends(_session), +) -> RegistrationResponse: + """W3a §3.5: enrollment-token auth; signs CSR; returns cert + JWT + hmac_seed.""" + expected = request.app.state.enrollment_token + if not secrets.compare_digest(x_enrollment_token, expected): + raise HTTPException(401, detail="invalid enrollment token") + try: + cert_pem = sign_csr( + request.app.state.ca, + body.client_csr_pem.encode("utf-8"), + executor_id=body.executor_id_proposal, + ttl_hours=24, + ) + except ValueError as e: + raise HTTPException(422, detail=f"invalid CSR: {e}") from e + fp = fingerprint_of(cert_pem) + hmac_seed = secrets.token_bytes(32) + ex = await upsert_executor_with_cert( + session, executor_id=body.executor_id_proposal, + host_id=body.host_id, capabilities=body.capabilities, + cert_fingerprint=fp, hmac_seed=hmac_seed, + ) + token = jwt_signing.sign( + request.app.state.jwt_keypair, + executor_id=ex.id, epoch=ex.epoch, + scopes=["heartbeat", "poll", "report"], + ) + await session.commit() + return RegistrationResponse( + executor_id=ex.id, epoch=ex.epoch, + client_cert_pem=cert_pem.decode("utf-8"), + ca_chain=[request.app.state.ca.cert_pem.decode("utf-8")], + executor_jwt=token, + hmac_seed_hex=hmac_seed.hex(), + cert_renew_in_seconds=86100, + jwt_renew_in_seconds=3300, + ) +``` + +Add `/renew` — the executor sends an optional fresh CSR in the body (the controller cannot re-sign from a bare public key; a CSR is self-signed by the executor's private key, which the controller never holds). When `client_csr_pem` is present, sign a new cert; otherwise refresh the JWT only: + +```python +@router.post("/{executor_id}/renew") +async def post_renew( + executor_id: str, + body: RenewRequest, + request: Request, + ex: Executor = Depends(require_executor_jwt), + session: AsyncSession = Depends(_session), +) -> RenewResponse: + """W3a §3.5: always refresh the JWT; sign a new cert iff the request + carries a fresh CSR (the executor includes one when its cert is near + expiry — see Task 9's renew loop).""" + if executor_id != ex.id: + raise HTTPException(401, detail="path executor_id mismatch") + new_jwt = jwt_signing.sign( + request.app.state.jwt_keypair, + executor_id=ex.id, epoch=ex.epoch, + scopes=["heartbeat", "poll", "report"], + ) + new_cert_pem: str | None = None + new_cert_renew_in: int | None = None + if body.client_csr_pem: + try: + new_cert_bytes = sign_csr( + request.app.state.ca, + body.client_csr_pem.encode("utf-8"), + executor_id=ex.id, ttl_hours=24, + ) + except ValueError as e: + raise HTTPException(422, detail=f"invalid CSR: {e}") from e + new_cert_pem = new_cert_bytes.decode("utf-8") + ex.cert_fingerprint = fingerprint_of(new_cert_bytes) + new_cert_renew_in = 86100 + await session.commit() + return RenewResponse( + executor_jwt=new_jwt, jwt_renew_in_seconds=3300, + client_cert_pem=new_cert_pem, + cert_renew_in_seconds=new_cert_renew_in, + ) +``` + +Add `RenewRequest` to the imports from `dlw.schemas.executor` alongside `ExecutorRegister` / `RegistrationResponse` / `RenewResponse`. + +Migrate `/heartbeat` and `/poll` dependency chains in this same file: +- `/heartbeat`: remove `Depends(require_bearer)`; the handler param `executor: Executor = Depends(require_executor_epoch)` stays, ADD `_hmac: Executor = Depends(require_hmac_heartbeat)`. +- `/poll`: remove `Depends(require_bearer)`; `executor: Executor = Depends(require_executor_epoch)` stays (it now transitively requires mTLS+JWT). + +(The W1 `/heartbeat` and `/poll` already use `Depends(require_executor_epoch)` — after Task 4's refactor that dep already pulls mTLS+JWT. So the only change here is dropping `require_bearer` and adding the HMAC dep to `/heartbeat`.) + +- [ ] **Step 4: Bootstrap in `src/dlw/main.py`** + +Read `main.py`. In `lifespan`, BEFORE the W1 `run_recovery_routine` call, add: + +```python + from pathlib import Path + from dlw.auth.ca import bootstrap_ca, ensure_server_cert + from dlw.auth.jwt_signing import bootstrap_keypair + from dlw.auth.hmac_nonce import NonceStore + import secrets as _secrets + from dlw.config import get_settings as _gs + _settings = _gs() + ca_dir = Path(_settings.ca_dir) + ca_dir.mkdir(mode=0o700, parents=True, exist_ok=True) + _ca = bootstrap_ca(ca_dir) + ensure_server_cert(_ca, ca_dir, hostname=_settings.controller_hostname) + _jwt_kp = bootstrap_keypair(ca_dir) + # Enrollment token: env override > file > generate-and-persist. + if _settings.enrollment_token: + _enroll = _settings.enrollment_token + else: + _tok_path = ca_dir / "enrollment.token" + if _tok_path.exists(): + _enroll = _tok_path.read_text().strip() + else: + _enroll = _secrets.token_hex(32) + _tok_path.write_text(_enroll) + _tok_path.chmod(0o600) + logger.info("generated enrollment token (copy to executors): %s", _enroll) + app.state.ca = _ca + app.state.jwt_keypair = _jwt_kp + app.state.nonce_store = NonceStore(maxsize=10_000, ttl_seconds=300) + app.state.enrollment_token = _enroll +``` + +`app` is available in `lifespan(app)` — confirm the signature. Place the block so `app.state.*` is set before the app serves traffic. + +- [ ] **Step 5: Add config fields to `src/dlw/config.py`** + +In `Settings`, add: + +```python + # Phase 2 W3a — mTLS + JWT + HMAC + ca_dir: str = Field(default="./.ca") + enrollment_token: str = Field(default="") + controller_hostname: str = Field(default="dlw-controller") + tls_trusted_proxy: bool = Field(default=False) +``` + +Env vars: `DLW_CA_DIR`, `DLW_ENROLLMENT_TOKEN`, `DLW_CONTROLLER_HOSTNAME`, `DLW_TLS_TRUSTED_PROXY`. + +- [ ] **Step 6: Write the endpoint tests** + +Create `tests/api/test_register_endpoint.py` — 3 cases: `test_register_returns_cert_jwt_and_hmac_seed`, `test_register_rejects_invalid_enrollment_token`, `test_register_idempotent_on_reregister`. Use the `ephemeral_ca` fixture; the test app's `app.state.ca / jwt_keypair / enrollment_token` are set from it. Build a CSR with `_build_csr` (copy the helper from `test_ca.py` or hoist it into conftest). Assert: 201 + all 4 response fields populated; 401 on wrong token; epoch bumps on re-register. + +Create `tests/api/test_renew_endpoint.py` — 2 cases: `test_renew_returns_new_jwt_only_when_cert_fresh` (register, then renew with no CSR → `client_cert_pem` is None), `test_renew_returns_new_cert_when_csr_provided` (renew with a fresh CSR → returns a new cert). Both use the `DLW_TLS_TRUSTED_PROXY=1` header bypass for the mTLS dep. + +Write the full test bodies following the `tests/api/test_cancel_endpoint.py` pattern (httpx `AsyncClient` + `ASGITransport` + `_bootstrap` fixture that creates tables + seeds tenant/project/user/storage). The app fixture must set `app.state.ca / jwt_keypair / nonce_store / enrollment_token` — either let the real `lifespan` run (it bootstraps into a tmp `DLW_CA_DIR`) or set them manually on `create_app()`'s result. Prefer letting `lifespan` run with `DLW_CA_DIR` monkeypatched to a tmp dir. + +- [ ] **Step 7: Run the new endpoint tests** + +``` +uv run pytest tests/api/test_register_endpoint.py tests/api/test_renew_endpoint.py -v +``` + +Expected: 5 passed. + +- [ ] **Step 8: Commit** (the api/ suite is still broken — Task 6 fixes it) + +```bash +git add src/dlw/schemas/executor.py src/dlw/services/executor_service.py src/dlw/api/executors.py src/dlw/main.py src/dlw/config.py tests/api/test_register_endpoint.py tests/api/test_renew_endpoint.py +git commit -m "feat(api): /register + /renew endpoints + main bootstrap; delete /join (W3a M2)" +``` + +--- + +### Milestone 2 verification (self) + +- [ ] `tests/auth/` fully green. +- [ ] `tests/api/test_register_endpoint.py` + `test_renew_endpoint.py` green. +- [ ] `/join` is deleted from `executors.py` and `ExecutorJoin` from `schemas/executor.py`. +- [ ] `main.py lifespan` bootstraps CA + JWT key + nonce store + enrollment token onto `app.state`. +- [ ] `tests/api/test_executors.py` + `test_subtasks.py` are EXPECTED to be red here — Task 6 migrates them. + +--- + +## Milestone 3 — Endpoint auth migration + e2e + +After M3: all W1 executor/subtask test setups migrated to `/register`; the real-TLS e2e passes; full suite green. + +--- + +### Task 6: Migrate W1 executor + subtask test setups + +**Files:** +- Modify: `src/dlw/api/subtasks.py` +- Modify: `tests/api/test_executors.py`, `tests/api/test_subtasks.py` +- Modify: `tests/e2e/test_executor_e2e.py`, `tests/e2e/test_happy_path.py` +- Modify: `tests/services/test_executor_service.py` +- Modify: `tests/conftest.py` (+`_signed_heartbeat_headers` + `registered_executor` helper) + +- [ ] **Step 1: Migrate `subtasks.py` `/report` auth chain** + +Read `src/dlw/api/subtasks.py`. The `/report` endpoint currently uses `Depends(require_bearer)` (and possibly `require_executor_epoch`). Remove `require_bearer`; ensure the chain is `Depends(require_executor_jwt)` + `Depends(require_executor_epoch)` (the latter transitively pulls mTLS+JWT). Body shape unchanged. + +- [ ] **Step 2: Add conftest helpers** + +In `tests/conftest.py`, add a `registered_executor` async helper + `_signed_heartbeat_headers`: + +```python +async def register_test_executor(client, *, ca, jwt_keypair, enrollment_token, + executor_id="test-host-worker-1", + host_id="test-host"): + """Build a CSR, POST /register, return a dict with cert_pem, key, jwt, + hmac_seed, epoch. For use in API tests that need an authenticated executor.""" + # build CSR + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric import ed25519 + from cryptography import x509 + from cryptography.x509.oid import NameOID + key = ed25519.Ed25519PrivateKey.generate() + csr = (x509.CertificateSigningRequestBuilder() + .subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, executor_id)])) + .sign(key, None)) + csr_pem = csr.public_bytes(serialization.Encoding.PEM).decode("utf-8") + r = await client.post("/api/v1/executors/register", json={ + "host_id": host_id, "executor_id_proposal": executor_id, + "capabilities": {}, "client_csr_pem": csr_pem, + }, headers={"X-Enrollment-Token": enrollment_token}) + assert r.status_code in (200, 201), r.text + body = r.json() + return { + "executor_id": body["executor_id"], "epoch": body["epoch"], + "cert_pem": body["client_cert_pem"], "jwt": body["executor_jwt"], + "hmac_seed": bytes.fromhex(body["hmac_seed_hex"]), + "ca_chain": body["ca_chain"], + } + + +def signed_heartbeat_headers(reg: dict, body: bytes) -> dict[str, str]: + """Compute the mTLS-bypass + JWT + HMAC + epoch headers for a heartbeat + body, given the dict returned by register_test_executor.""" + import secrets as _s, time as _t + from dlw.auth.hmac_nonce import compute_hmac + ts = int(_t.time()) + nonce = _s.token_hex(16) + sig = compute_hmac(reg["hmac_seed"], ts=ts, nonce=nonce, body=body) + return { + "X-Client-Cert-PEM": reg["cert_pem"].replace("\n", "\\n"), + "Authorization": f"Bearer {reg['jwt']}", + "X-Executor-Epoch": str(reg["epoch"]), + "X-HMAC-Timestamp": str(ts), + "X-HMAC-Nonce": nonce, + "X-HMAC-Signature": sig, + "Content-Type": "application/json", + } +``` + +(These are plain helper functions, not fixtures — import them in the test files that need them.) + +- [ ] **Step 3: Migrate `tests/api/test_executors.py`** + +Read the file. The `joined_executor` fixture calls `POST /api/v1/executors/join`. Replace it with a `registered_executor` fixture that: +1. Ensures `DLW_TLS_TRUSTED_PROXY=1` (via the existing `_set_token`-style env fixture or a new one). +2. Lets the app's `lifespan` bootstrap the CA (monkeypatch `DLW_CA_DIR` to a tmp dir, OR set `app.state` manually). +3. Calls `register_test_executor(...)` and returns the reg dict. + +Every test that does `POST /heartbeat` / `/poll` now needs the mTLS+JWT(+HMAC) headers — use `signed_heartbeat_headers(reg, body)` for heartbeat and a plain JWT+cert+epoch header set for poll. The W1 `_TOKEN` shared-bearer fixture is removed for executor endpoints (UI endpoints keep it, but `test_executors.py` only hits executor endpoints). + +Migrate the unauthenticated-rejection test: it should now assert that a request with NO cert header gets 401. + +- [ ] **Step 4: Migrate `tests/api/test_subtasks.py`** + +Same pattern: the `/report` calls need cert + JWT + epoch headers (no HMAC — report isn't HMAC-protected). Use `register_test_executor` + a `report_headers(reg)` helper (cert + JWT + epoch only). + +- [ ] **Step 5: Migrate `tests/e2e/test_executor_e2e.py` + `tests/e2e/test_happy_path.py`** + +These run a fuller flow. The W1 `/join` call at the start becomes `/register`. The mocked controller responses / runner wiring change to carry the new auth. Read each file; the changes are mechanical (swap `/join` → `/register`, attach the new headers). If a test mocks `ControllerClient` directly, update the mock to the W3a client surface (Task 9 defines it — for now, mock at the same boundary). + +- [ ] **Step 6: Migrate `tests/services/test_executor_service.py`** + +`join_executor` is renamed to `upsert_executor_with_cert`. Update the import + call sites. The W1 INSERT-or-bump cases still apply — add two assertions per case: `ex.cert_fingerprint` is set, `ex.hmac_seed_encrypted` is set. Pass synthetic `cert_fingerprint="SHA256:..."` + `hmac_seed=b"\x00"*32` in the test calls. + +- [ ] **Step 7: Run the full suite** + +``` +uv run pytest -x +``` + +Expected: green. Count ≈ 193 (M1) + 13 (Task 4 deps) + 5 (Task 5 endpoints) + ~1 (Task 4 confused-deputy) = ~212, minus/plus the W1 test rewrites (same count, different bodies). The exact number depends on how the W1 fixtures were structured — the key check is **0 failures**. + +If failures remain, they're almost certainly missed header migrations in `test_executors.py` / `test_subtasks.py` — fix them. + +- [ ] **Step 8: Commit** + +```bash +git add src/dlw/api/subtasks.py tests/conftest.py tests/api/test_executors.py tests/api/test_subtasks.py tests/e2e/test_executor_e2e.py tests/e2e/test_happy_path.py tests/services/test_executor_service.py +git commit -m "feat(api): migrate executor + subtask endpoints + test setups to mTLS+JWT (W3a M3)" +``` + +--- + +### Task 7: Real-TLS e2e test + +**Files:** +- Create: `tests/e2e/test_executor_auth_e2e.py` + +- [ ] **Step 1: Write the e2e test** + +Create `tests/e2e/test_executor_auth_e2e.py`. This test spawns uvicorn in a subprocess with real `--ssl-*` flags and connects with an httpx client doing real mTLS. + +```python +"""Real-TLS e2e: register → heartbeat full flow (Phase 2 W3a §7.2).""" +from __future__ import annotations + +import asyncio +import os +import socket +import subprocess +import sys +import time + +import httpx +import pytest + +from dlw.auth.ca import bootstrap_ca, ensure_server_cert + + +def _free_port() -> int: + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.mark.slow +async def test_register_then_heartbeat_full_flow(tmp_path, test_db_name) -> None: + """Spawn uvicorn with real TLS; register an executor; send an HMAC-signed + heartbeat over mTLS. Verifies the uvicorn wiring + peer-cert extraction.""" + ca_dir = tmp_path / "ca" + ca_dir.mkdir() + ca = bootstrap_ca(ca_dir) + server_cert, server_key = ensure_server_cert(ca, ca_dir, hostname="localhost") + port = _free_port() + enrollment_token = "e2e-enrollment-token" + + env = { + **os.environ, + "DLW_CA_DIR": str(ca_dir), + "DLW_ENROLLMENT_TOKEN": enrollment_token, + "DLW_CONTROLLER_HOSTNAME": "localhost", + "DLW_DB_NAME": test_db_name, + # ... DB host/port/user from the conftest env pattern ... + } + proc = subprocess.Popen( + [sys.executable, "-m", "uvicorn", "dlw.main:app", + "--host", "127.0.0.1", "--port", str(port), + "--ssl-keyfile", str(server_key), + "--ssl-certfile", str(server_cert), + "--ssl-ca-certs", str(ca_dir / "ca-cert.pem"), + "--ssl-cert-reqs", "2"], + env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + ) + try: + # Wait for the server to be ready (poll /healthz over TLS). + base = f"https://localhost:{port}" + for _ in range(50): + try: + async with httpx.AsyncClient(verify=str(ca_dir / "ca-cert.pem")) as c: + # /healthz may not require mTLS; if it does, skip this probe + r = await c.get(f"{base}/healthz", timeout=1.0) + if r.status_code < 500: + break + except Exception: + await asyncio.sleep(0.2) + else: + out = proc.stdout.read().decode() if proc.stdout else "" + pytest.fail(f"uvicorn did not start: {out}") + + # Build a CSR + register (no mTLS for /register — enrollment token). + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric import ed25519 + from cryptography import x509 + from cryptography.x509.oid import NameOID + key = ed25519.Ed25519PrivateKey.generate() + csr = (x509.CertificateSigningRequestBuilder() + .subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "e2e-worker-1")])) + .sign(key, None)) + csr_pem = csr.public_bytes(serialization.Encoding.PEM).decode() + + async with httpx.AsyncClient(verify=str(ca_dir / "ca-cert.pem")) as c: + reg = await c.post(f"{base}/api/v1/executors/register", json={ + "host_id": "e2e-host", "executor_id_proposal": "e2e-worker-1", + "capabilities": {}, "client_csr_pem": csr_pem, + }, headers={"X-Enrollment-Token": enrollment_token}) + assert reg.status_code == 201, reg.text + body = reg.json() + + # Persist the issued client cert + key for the mTLS heartbeat call. + client_cert_path = tmp_path / "client-cert.pem" + client_key_path = tmp_path / "client-key.pem" + client_cert_path.write_text(body["client_cert_pem"]) + client_key_path.write_bytes(key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + )) + + # Heartbeat over real mTLS + JWT + HMAC. + import secrets, json + from dlw.auth.hmac_nonce import compute_hmac + hb_body = json.dumps({"health_score": 100, "parts_dir_bytes": 0}).encode() + ts = int(time.time()) + nonce = secrets.token_hex(16) + sig = compute_hmac(bytes.fromhex(body["hmac_seed_hex"]), + ts=ts, nonce=nonce, body=hb_body) + async with httpx.AsyncClient( + verify=str(ca_dir / "ca-cert.pem"), + cert=(str(client_cert_path), str(client_key_path)), + ) as c: + hb = await c.post( + f"{base}/api/v1/executors/e2e-worker-1/heartbeat", + content=hb_body, + headers={ + "Authorization": f"Bearer {body['executor_jwt']}", + "X-Executor-Epoch": str(body["epoch"]), + "X-HMAC-Timestamp": str(ts), + "X-HMAC-Nonce": nonce, + "X-HMAC-Signature": sig, + "Content-Type": "application/json", + }, + ) + assert hb.status_code == 200, hb.text + finally: + proc.terminate() + try: + proc.wait(timeout=5) + except subprocess.TimeoutExpired: + proc.kill() +``` + +> The test needs the DB env vars matching the conftest pattern — read `tests/conftest.py`'s `_pg_env()` and replicate the host/port/user/password into the subprocess `env`. The subprocess controller bootstraps its OWN `DLW_CA_DIR` but the test pre-creates the CA + server cert so the `verify=` path is known. Since `bootstrap_ca` is idempotent, the subprocess loads the same CA the test created. + +- [ ] **Step 2: Run the e2e test** + +``` +uv run pytest tests/e2e/test_executor_auth_e2e.py -v +``` + +Expected: 1 passed. If uvicorn fails to start, the test fails fast with the captured stdout. Common issues: missing `--ssl-*` file paths, the controller's own `lifespan` `bootstrap_ca` colliding with the pre-created files (it shouldn't — `bootstrap_ca` is idempotent and loads existing files). + +- [ ] **Step 3: Run full suite** + +``` +uv run pytest -x +``` + +Expected: green, +1 from the e2e. + +- [ ] **Step 4: Commit** + +```bash +git add tests/e2e/test_executor_auth_e2e.py +git commit -m "test(e2e): real-TLS register → heartbeat full flow (W3a M3)" +``` + +--- + +### Milestone 3 verification (self) + +- [ ] Full pytest suite green (0 failures). +- [ ] `tests/e2e/test_executor_auth_e2e.py` exercises real uvicorn TLS and passes. +- [ ] No `require_bearer` remains on any executor/subtask route. +- [ ] `git grep -n "require_bearer" src/dlw/api/` shows only `tasks.py` (UI). + +--- + +## Milestone 4 — Executor side + lint + OpenAPI + PR + +After M4: executor side does register/renew/HMAC; lint locks the bearer-free invariant; OpenAPI + runbook updated; PR open. + +--- + +### Task 8: Executor `cert.py` + `auth_lifecycle.py` + +**Files:** +- Create: `src/dlw/executor/cert.py` +- Create: `src/dlw/executor/auth_lifecycle.py` +- Create: `tests/executor/test_cert.py`, `tests/executor/test_auth_lifecycle.py` + +- [ ] **Step 1: Write failing tests for `cert.py`** + +Create `tests/executor/test_cert.py` — cases: `test_build_csr_returns_pem_and_key`, `test_persist_and_load_roundtrip`, `test_fingerprint_matches_controller_format`. `build_csr` returns `(csr_pem, key_pem)`; `persist` writes 4 files; `load` returns the tuple or None; `fingerprint` matches `dlw.auth.ca.fingerprint_of` output format. + +- [ ] **Step 2: Implement `src/dlw/executor/cert.py`** + +```python +"""Executor-side cert helpers (Phase 2 W3a §3.10).""" +from __future__ import annotations + +from pathlib import Path + +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import ed25519 +from cryptography.x509.oid import NameOID + + +def build_csr(executor_id: str) -> tuple[bytes, bytes]: + """Generate an Ed25519 keypair + CSR (CN=executor_id). + Returns (csr_pem, private_key_pem).""" + key = ed25519.Ed25519PrivateKey.generate() + csr = (x509.CertificateSigningRequestBuilder() + .subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, executor_id)])) + .sign(key, None)) + csr_pem = csr.public_bytes(serialization.Encoding.PEM) + key_pem = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + return csr_pem, key_pem + + +def persist(cert_dir: Path, *, cert_pem: bytes, key_pem: bytes, + ca_chain_pem: bytes, hmac_seed: bytes) -> None: + """Write client-cert.pem / client-key.pem / ca-chain.pem / hmac-seed + (chmod 600) into cert_dir (chmod 700).""" + cert_dir.mkdir(mode=0o700, parents=True, exist_ok=True) + for name, data in [ + ("client-cert.pem", cert_pem), ("client-key.pem", key_pem), + ("ca-chain.pem", ca_chain_pem), ("hmac-seed", hmac_seed), + ]: + p = cert_dir / name + p.write_bytes(data) + p.chmod(0o600) + + +def load(cert_dir: Path) -> tuple[bytes, bytes, bytes, bytes] | None: + """Return (cert_pem, key_pem, ca_chain_pem, hmac_seed) or None if absent.""" + paths = [cert_dir / n for n in + ("client-cert.pem", "client-key.pem", "ca-chain.pem", "hmac-seed")] + if not all(p.exists() for p in paths): + return None + return tuple(p.read_bytes() for p in paths) # type: ignore[return-value] + + +def fingerprint(cert_pem: bytes) -> str: + """SHA256: — same format as dlw.auth.ca.fingerprint_of.""" + cert = x509.load_pem_x509_certificate(cert_pem) + return f"SHA256:{cert.fingerprint(hashes.SHA256()).hex()}" +``` + +- [ ] **Step 3: Write failing tests for `auth_lifecycle.py`** + +Create `tests/executor/test_auth_lifecycle.py` — at minimum `test_load_or_register_first_run_calls_register` and `test_load_or_register_existing_loads_and_renews`. Mock the controller HTTP calls with `httpx.MockTransport` (W4 pattern). Verify `AuthState` fields populated correctly. + +- [ ] **Step 4: Implement `src/dlw/executor/auth_lifecycle.py`** + +```python +"""Executor auth lifecycle: register / renew / load (Phase 2 W3a §3.11).""" +from __future__ import annotations + +import datetime as _dt +from dataclasses import dataclass +from pathlib import Path + +import httpx +from cryptography import x509 + +from dlw.executor import cert as _cert + + +@dataclass +class AuthState: + executor_id: str + epoch: int + cert_pem: bytes + key_pem: bytes + ca_chain_pem: bytes + jwt: str + jwt_exp: _dt.datetime + cert_exp: _dt.datetime + hmac_seed: bytes + cert_dir: Path + + +def _parse_jwt_exp(token: str) -> _dt.datetime: + import jwt as _pyjwt + claims = _pyjwt.decode(token, options={"verify_signature": False}) + return _dt.datetime.fromtimestamp(claims["exp"], tz=_dt.UTC) + + +def _parse_cert_exp(cert_pem: bytes) -> _dt.datetime: + return x509.load_pem_x509_certificate(cert_pem).not_valid_after_utc + + +async def register(*, controller_url: str, ca_bundle_path: str | None, + enrollment_token: str, executor_id: str, host_id: str, + capabilities: dict, cert_dir: Path) -> AuthState: + csr_pem, key_pem = _cert.build_csr(executor_id) + verify = ca_bundle_path if ca_bundle_path else True + async with httpx.AsyncClient(verify=verify) as c: + r = await c.post(f"{controller_url}/api/v1/executors/register", json={ + "host_id": host_id, "executor_id_proposal": executor_id, + "capabilities": capabilities, + "client_csr_pem": csr_pem.decode("utf-8"), + }, headers={"X-Enrollment-Token": enrollment_token}) + r.raise_for_status() + body = r.json() + cert_pem = body["client_cert_pem"].encode("utf-8") + ca_chain_pem = "\n".join(body["ca_chain"]).encode("utf-8") + hmac_seed = bytes.fromhex(body["hmac_seed_hex"]) + _cert.persist(cert_dir, cert_pem=cert_pem, key_pem=key_pem, + ca_chain_pem=ca_chain_pem, hmac_seed=hmac_seed) + return AuthState( + executor_id=body["executor_id"], epoch=body["epoch"], + cert_pem=cert_pem, key_pem=key_pem, ca_chain_pem=ca_chain_pem, + jwt=body["executor_jwt"], jwt_exp=_parse_jwt_exp(body["executor_jwt"]), + cert_exp=_parse_cert_exp(cert_pem), hmac_seed=hmac_seed, + cert_dir=cert_dir, + ) + + +async def renew(state: AuthState, *, controller_url: str) -> AuthState: + """POST /{eid}/renew over mTLS. Include a fresh CSR iff cert TTL < 1h.""" + now = _dt.datetime.now(_dt.UTC) + payload: dict = {} + new_key_pem = state.key_pem + if state.cert_exp - now < _dt.timedelta(hours=1): + csr_pem, new_key_pem = _cert.build_csr(state.executor_id) + payload["client_csr_pem"] = csr_pem.decode("utf-8") + cert_file = state.cert_dir / "client-cert.pem" + key_file = state.cert_dir / "client-key.pem" + async with httpx.AsyncClient( + verify=str(state.cert_dir / "ca-chain.pem"), + cert=(str(cert_file), str(key_file)), + headers={"Authorization": f"Bearer {state.jwt}"}, + ) as c: + r = await c.post( + f"{controller_url}/api/v1/executors/{state.executor_id}/renew", + json=payload, + ) + r.raise_for_status() + body = r.json() + new_jwt = body["executor_jwt"] + cert_pem = state.cert_pem + cert_exp = state.cert_exp + if body.get("client_cert_pem"): + cert_pem = body["client_cert_pem"].encode("utf-8") + cert_exp = _parse_cert_exp(cert_pem) + _cert.persist(state.cert_dir, cert_pem=cert_pem, key_pem=new_key_pem, + ca_chain_pem=state.ca_chain_pem, hmac_seed=state.hmac_seed) + return AuthState( + executor_id=state.executor_id, epoch=state.epoch, + cert_pem=cert_pem, key_pem=new_key_pem, ca_chain_pem=state.ca_chain_pem, + jwt=new_jwt, jwt_exp=_parse_jwt_exp(new_jwt), cert_exp=cert_exp, + hmac_seed=state.hmac_seed, cert_dir=state.cert_dir, + ) + + +async def load_or_register(*, cert_dir: Path, controller_url: str, + ca_bundle_path: str | None, enrollment_token: str, + executor_id: str, host_id: str, + capabilities: dict) -> AuthState: + loaded = _cert.load(cert_dir) + if loaded is not None: + cert_pem, key_pem, ca_chain_pem, hmac_seed = loaded + # We have a cert but no JWT (JWT is never persisted). Re-register to + # get a fresh JWT — simpler + always correct than a JWT-only path. + # The existing cert is still valid; /register's upsert just bumps epoch. + # W3a simplification: always re-register on restart. + return await register( + controller_url=controller_url, ca_bundle_path=ca_bundle_path, + enrollment_token=enrollment_token, executor_id=executor_id, + host_id=host_id, capabilities=capabilities, cert_dir=cert_dir, + ) + return await register( + controller_url=controller_url, ca_bundle_path=ca_bundle_path, + enrollment_token=enrollment_token, executor_id=executor_id, + host_id=host_id, capabilities=capabilities, cert_dir=cert_dir, + ) +``` + +> **Simplification noted in code:** `load_or_register` always re-registers on restart (the JWT is never persisted, so a "load + renew" path would still need a valid JWT, which we don't have). Re-register is idempotent (epoch bumps — correct W1 fence semantics). The renew loop handles in-process refresh; restart goes through register. This is simpler and always correct. The two branches above are intentionally identical — kept separate so a future Phase-3 JWT-persistence optimization has an obvious seam. + +- [ ] **Step 5: Run the executor auth tests** + +``` +uv run pytest tests/executor/test_cert.py tests/executor/test_auth_lifecycle.py -v +``` + +Expected: all pass. + +- [ ] **Step 6: Run full suite** + +``` +uv run pytest -x +``` + +Expected: green, + the new executor-side tests. + +- [ ] **Step 7: Commit** + +```bash +git add src/dlw/executor/cert.py src/dlw/executor/auth_lifecycle.py tests/executor/test_cert.py tests/executor/test_auth_lifecycle.py +git commit -m "feat(executor): cert.py + auth_lifecycle.py (register/renew/load) (W3a M4)" +``` + +--- + +### Task 9: Executor `client.py` + `runner.py` + `config.py` + +**Files:** +- Modify: `src/dlw/executor/config.py` +- Modify: `src/dlw/executor/client.py` +- Modify: `src/dlw/executor/runner.py` +- Modify: `tests/executor/test_client.py`, `tests/executor/test_runner.py`, `tests/executor/test_runner_dispatch.py`, `tests/executor/test_runner_external_throttle.py` + +- [ ] **Step 1: Add config fields to `src/dlw/executor/config.py`** + +```python + # Phase 2 W3a — mTLS + JWT auth + enrollment_token: str = Field(default="") + executor_cert_dir: str = Field(default="~/.dlw/executor") + executor_ca_bundle: str = Field(default="") # runtime-defaults to {cert_dir}/ca-chain.pem +``` + +Env vars: `DLW_EXECUTOR_ENROLLMENT_TOKEN`, `DLW_EXECUTOR_EXECUTOR_CERT_DIR`, `DLW_EXECUTOR_EXECUTOR_CA_BUNDLE` (the existing `DLW_EXECUTOR_` prefix applies). + +- [ ] **Step 2: Rewrite `src/dlw/executor/client.py`** + +The `ControllerClient` takes an `AuthState` reference instead of a bearer token. Read the current file first. Key changes: +- Constructor: `ControllerClient(base_url, auth_state, timeout_seconds=30.0, _transport=None)` — drop `bearer_token`. +- Each request builds an `httpx.AsyncClient` with `verify=`, `cert=(, )`, `headers={"Authorization": f"Bearer {auth.jwt}"}`. (When `_transport` is injected for tests, `verify` / `cert` are ignored — MockTransport short-circuits.) +- `heartbeat(...)` additionally computes the HMAC headers via `compute_hmac` on the JSON body. +- Add `update_auth(new_state: AuthState)` to swap the auth ref in place (the renew loop calls this). +- `current_epoch()` reads `auth.epoch`. +- Preserve the `_transport` injection seam — tests rely on it. + +Write the full file. Mirror the W1 structure (tenacity retry decorator, `__aenter__`/`__aexit__`, the per-method request shape) but swap the auth. + +- [ ] **Step 3: Rewrite `src/dlw/executor/runner.py` auth bootstrap + renew loop** + +Read the file. Changes: +- `run()`: before the W1 `/join` (now removed), call `load_or_register(...)` → `self._auth`. Pass it into the client (`self._client.update_auth(self._auth)` or construct the client with it). +- Spawn a THIRD background task `_auth_renew_loop` alongside `_heartbeat_loop` + `_poll_and_execute_loop`. +- `_auth_renew_loop`: sleep until `min(jwt_exp - 5min, cert_exp - 1h)`, then `self._auth = await renew(self._auth, controller_url=...)` + `self._client.update_auth(self._auth)`. On exception: log + retry next cycle. On `_shutdown`: return. +- The W1 `EPOCH_MISMATCH` re-join path in `_poll_and_execute_loop`: generalize the 401 handler to call `load_or_register(...)` (which re-registers) instead of the deleted `/join`. + +```python +async def _auth_renew_loop(self) -> None: + from datetime import UTC, datetime, timedelta + from dlw.executor.auth_lifecycle import renew + while not self._shutdown.is_set(): + now = datetime.now(UTC) + jwt_due = self._auth.jwt_exp - timedelta(minutes=5) + cert_due = self._auth.cert_exp - timedelta(hours=1) + sleep_for = max(60, int((min(jwt_due, cert_due) - now).total_seconds())) + try: + await asyncio.wait_for(self._shutdown.wait(), timeout=sleep_for) + return + except asyncio.TimeoutError: + pass + try: + self._auth = await renew(self._auth, + controller_url=self._s.controller_url) + self._client.update_auth(self._auth) + except Exception as e: + logger.warning("auth renew failed: %s; retry next cycle", e) +``` + +- [ ] **Step 4: Migrate the executor test setups** + +`tests/executor/test_client.py` / `test_runner.py` / `test_runner_dispatch.py` / `test_runner_external_throttle.py` construct `ControllerClient` and `ExecutorRunner`. Update: +- `ControllerClient(...)` calls: pass a synthetic `AuthState` (build one with a self-signed cert + a fake JWT + a 32-byte seed) instead of `bearer_token`. +- `ExecutorRunner(...)` calls: the runner now does `load_or_register` in `run()` — for tests that don't call `run()` (most just test `_execute_subtask` / `_choose_downloader`), inject `self._auth` directly after construction or via a constructor param. Choose whichever is least invasive given the current test structure; the W2b1 tests pass `MagicMock()` downloaders, so a `MagicMock()` or a synthetic `AuthState` for `_auth` works. + +These are mechanical fixture edits — no test logic changes. + +- [ ] **Step 5: Run executor tests** + +``` +uv run pytest tests/executor/ -v +``` + +Expected: all pass. + +- [ ] **Step 6: Run full suite** + +``` +uv run pytest -x +``` + +Expected: green. + +- [ ] **Step 7: Commit** + +```bash +git add src/dlw/executor/config.py src/dlw/executor/client.py src/dlw/executor/runner.py tests/executor/ +git commit -m "feat(executor): client + runner mTLS/JWT/HMAC auth + renew loop (W3a M4)" +``` + +--- + +### Task 10: Lint + OpenAPI + operator runbook + +**Files:** +- Modify: `tools/lint_invariants.py` +- Modify: `api/openapi.yaml` +- Modify: `docs/operator/` + +- [ ] **Step 1: Add `check_no_bearer_on_executor_routes` to `tools/lint_invariants.py`** + +After the W2b2 helpers, add: + +```python +def check_no_bearer_on_executor_routes() -> list[str]: + """W3a §3.15: forbid Depends(require_bearer) in executor/subtask route files. + Those endpoints must use mTLS + JWT (not the UI shared-secret bearer).""" + errors: list[str] = [] + files = [ + ROOT / "src" / "dlw" / "api" / "executors.py", + ROOT / "src" / "dlw" / "api" / "subtasks.py", + ] + import ast as _ast + for f in files: + if not f.exists(): + continue + src = f.read_text(encoding="utf-8") + tree = _ast.parse(src) + for node in _ast.walk(tree): + # Depends(require_bearer) — a Call to Depends with require_bearer arg + if (isinstance(node, _ast.Call) + and isinstance(node.func, _ast.Name) + and node.func.id == "Depends" + and node.args + and isinstance(node.args[0], _ast.Name) + and node.args[0].id == "require_bearer"): + errors.append( + f"{f.relative_to(ROOT)}:{node.lineno}: " + f"require_bearer forbidden on executor/subtask routes " + f"(use mTLS + JWT)" + ) + return errors +``` + +Wire into `main()`: `failures.extend(check_no_bearer_on_executor_routes())`. + +- [ ] **Step 2: Run the lint** + +``` +python tools/lint_invariants.py +uv run pytest tools/test_lint_invariants.py -v +``` + +Expected: lint exits 0 (the M3 migration removed all `require_bearer` from those files); existing lint self-tests pass. + +- [ ] **Step 3: Update `api/openapi.yaml`** + +- Remove the `/executors/join` operation. +- Add `/executors/register` (POST, `X-Enrollment-Token` header, `ExecutorRegister` body, `RegistrationResponse` 201). +- Add `/executors/{executorId}/renew` (POST, mTLS + JWT, `RenewResponse` 200). +- On the `/executors/{executorId}/heartbeat` operation: document the `X-HMAC-Timestamp` / `X-HMAC-Nonce` / `X-HMAC-Signature` headers as required. +- Add `RegistrationResponse` / `RenewResponse` / `ExecutorRegister` schemas under `components/schemas`. + +Match the existing openapi.yaml indentation + style. The W2b2 changes show the pattern. + +- [ ] **Step 4: Update `docs/operator/`** + +Add a new file `docs/operator/executor-auth.md` (or append to `executor-runbook.md` if that's the established home — check `ls docs/operator/`): + +```markdown +## mTLS + Executor JWT + HMAC (Phase 2 W3a+) + +### Controller bootstrap + +On first launch the controller generates, under `${DLW_CA_DIR}` (default +`./.ca`, chmod 700): + +- `ca-cert.pem` / `ca-key.pem` — the self-signed CA (10-year validity). +- `server-cert.pem` / `server-key.pem` — the controller's TLS server cert + (SAN: localhost, $DLW_CONTROLLER_HOSTNAME, 127.0.0.1, ::1). +- `jwt-signing.pem` — Ed25519 JWT signing key. +- `enrollment.token` — 256-bit hex token (also logged once at INFO). + +Run uvicorn with TLS: + + uvicorn dlw.main:app --host 0.0.0.0 --port 8443 \ + --ssl-keyfile ${DLW_CA_DIR}/server-key.pem \ + --ssl-certfile ${DLW_CA_DIR}/server-cert.pem \ + --ssl-ca-certs ${DLW_CA_DIR}/ca-cert.pem \ + --ssl-cert-reqs 2 + +### Enrolling an executor + +1. Copy the controller's enrollment token to the executor host out-of-band. +2. Set `DLW_EXECUTOR_ENROLLMENT_TOKEN` on the executor. +3. On first boot the executor generates a keypair, builds a CSR, calls + `/register`, and persists `client-cert.pem` / `client-key.pem` / + `ca-chain.pem` / `hmac-seed` under `${DLW_EXECUTOR_CERT_DIR}` + (default `~/.dlw/executor`, chmod 700). +4. Certs auto-renew (24h cert, 1h JWT) via the executor's renew loop. + +### `DLW_TLS_TRUSTED_PROXY` — security warning + +`DLW_TLS_TRUSTED_PROXY=1` makes the controller honor the +`X-Client-Cert-PEM` header instead of the direct TLS peer cert. Only +enable this when a real TLS-terminating reverse proxy sits in front AND +the uvicorn port is NOT directly reachable. With it on and no proxy, +anyone can forge the header. Default is `0` (direct uvicorn TLS only). + +### Host clock sync + +Heartbeats carry an HMAC timestamp validated within ±5 min. Run +`chrony` / `systemd-timesyncd` on all executor + controller hosts. +``` + +- [ ] **Step 5: Final local verification** + +``` +python tools/lint_invariants.py +python tools/lint_no_direct_status_write.py +uv run pytest -x +``` + +Expected: both lints clean; pytest green. + +- [ ] **Step 6: Commit** + +```bash +git add tools/lint_invariants.py api/openapi.yaml docs/operator/ +git commit -m "ci(lint): forbid bearer on executor routes + OpenAPI + operator runbook (W3a M4)" +``` + +--- + +### Task 11: Push branch + open PR + monitor CI (controller does this) + +- [ ] **Step 1: Confirm branch state** + +```bash +git status +git log main..HEAD --oneline +``` + +Expected: clean working tree; ~11 commits (1 spec + 10 task commits). + +- [ ] **Step 2: Push** + +```bash +git push -u origin feat/phase-2-w3a-mtls-jwt-hmac +``` + +- [ ] **Step 3: Open the PR** + +```bash +gh pr create \ + --title "Phase 2 Week 3a — mTLS + executor JWT + HMAC heartbeat" \ + --body "$(cat <<'EOF' +## Summary + +W3a of `docs/v2.0/08-mvp-roadmap.md` §2.6 Day 1-3 — replaces executor-side bearer auth with SVID-style mTLS + Ed25519 JWT + HMAC heartbeat (SEC-01 + SEC-04): + +- **mTLS substrate.** Controller bootstraps a self-signed CA + server cert + Ed25519 JWT signing key under `${DLW_CA_DIR}` (file-persisted, chmod 600). New `POST /executors/register` (enrollment-token auth) signs an executor CSR; `POST /executors/{eid}/renew` refreshes the JWT (+ cert when a CSR is supplied). W1 `/join` deleted. +- **JWT + HMAC.** Three chained FastAPI deps — `require_executor_mtls` → `require_executor_jwt` → `require_hmac_heartbeat`. `require_executor_epoch` refactored to chain under the JWT dep and assert the path id matches the mTLS identity (confused-deputy guard). In-process nonce store bounds replay to a ±5min window. +- **Executor side.** New `cert.py` + `auth_lifecycle.py`; `client.py` does mTLS + JWT + HMAC; runner spawns a 3rd background loop for cert/JWT renewal. `load_or_register` re-registers on restart (idempotent epoch bump). +- **UI auth unchanged.** `/api/v1/tasks/*` keeps `require_bearer`; a new `check_no_bearer_on_executor_routes` lint locks executor routes onto mTLS+JWT. + +Spec: `docs/superpowers/specs/2026-05-14-phase-2-w3a-mtls-jwt-hmac-design.md`. +Plan: `docs/superpowers/plans/2026-05-14-phase-2-w3a-mtls-jwt-hmac.md`. + +W3b (HF reverse-proxy) and W3c (active/standby) are companion specs. + +## Test plan + +- [x] Backend pytest: baseline 181 + ~27 new + ~13-15 migrated W1 setups. Zero regressions. +- [x] `dlw.auth.{ca,jwt_signing,hmac_nonce}` modules + 3 FastAPI deps unit-tested. +- [x] `/register` + `/renew` endpoint tests. +- [x] One real-TLS e2e (uvicorn subprocess with `--ssl-*`): register → HMAC-signed heartbeat over mTLS. +- [x] alembic upgrade clean from W2b2 head; downgrade clean. +- [x] `tools/lint_invariants.py` `check_no_bearer_on_executor_routes` returns 0. +- [x] `cryptography` + `pyjwt[crypto]` added to `pyproject.toml` + `uv.lock`. +- [x] OpenAPI: `/register` + `/renew` added, `/join` removed, HMAC headers documented. + +## Out of scope (deferred — see spec §1.2) + +HF reverse-proxy (W3b); active/standby + chaos drill (W3c); OIDC / multi-tenant / UI auth (Phase 3); Vault/KMS for keys (Phase 3); CRL / cert-manager (Phase 3+); envelope encryption of hmac_seed (Phase 3); PG/Redis nonce store (Phase 3). + +🤖 Generated with [Claude Code](https://claude.com/claude-code) +EOF +)" +``` + +- [ ] **Step 4: Monitor CI** + +```bash +gh pr checks $(gh pr view --json number -q .number) --watch +``` + +Expected: 12 checks pass. If any fail: + +- **pytest** — the real-TLS e2e is the most fragile; if it fails on CI but passes locally, check the uvicorn subprocess startup timeout + the DB env var plumbing into the subprocess. +- **Invariant + cross-ref lint** — `check_no_bearer_on_executor_routes` may catch a missed `require_bearer`; remove it from the source. +- **OpenAPI lint** — spectral may flag the new operations; diff against the W2b2 OpenAPI change pattern. + +--- + +### Milestone 4 verification (self) + +- [ ] PR opened; CI 12/12 green. +- [ ] No diff outside the File Structure list (`gh pr diff --name-only`). +- [ ] `git grep require_bearer src/dlw/api/` shows only `tasks.py`. +- [ ] All new tests pass; no W1/W2a/W2b1/W2b2 regressions. + +--- + +## Definition of Done + +- [ ] All 10 implementation tasks committed on `feat/phase-2-w3a-mtls-jwt-hmac`. +- [ ] PR opened, CI 12/12 green. +- [ ] `cryptography` + `pyjwt[crypto]` pinned in `pyproject.toml`; `uv.lock` updated. +- [ ] 1 alembic migration applies + reverses clean; `executors.hmac_seed_encrypted` exists. +- [ ] `dlw.auth.{ca,jwt_signing,hmac_nonce}` + 3 FastAPI deps + `require_executor_epoch` refactor — all unit-tested. +- [ ] `/register` + `/renew` endpoints work; `/join` + `ExecutorJoin` deleted. +- [ ] `join_executor` → `upsert_executor_with_cert` (writes cert_fingerprint + hmac_seed). +- [ ] Executor side: `cert.py` + `auth_lifecycle.py` + rewritten `client.py` + `runner.py` 3rd loop. +- [ ] One real-TLS e2e passes. +- [ ] `check_no_bearer_on_executor_routes` lint reports 0; `git grep require_bearer src/dlw/api/` → only `tasks.py`. +- [ ] OpenAPI updated; operator runbook documents CA dir, enrollment token, uvicorn `--ssl-*`, `DLW_TLS_TRUSTED_PROXY` warning, clock-sync requirement. +- [ ] No new CI jobs. Full suite green. + +--- + +## Plan Revisions Log + +(Empty on first draft.) + +| Tag | Severity | Issue | Fix applied | +|-----|----------|-------|-------------| +| _(none yet)_ | | | | + +--- + +## References + +- Spec: `docs/superpowers/specs/2026-05-14-phase-2-w3a-mtls-jwt-hmac-design.md` +- Predecessor specs/plans: W1 / W2a / W2b1 / W2b2 under `docs/superpowers/{specs,plans}/` +- Security: `docs/v2.0/04-security-and-tenancy.md` §2.2; Protocol: `docs/v2.0/02-protocol.md` §4.1 +- Roadmap: `docs/v2.0/08-mvp-roadmap.md` §2.6 Phase 2 W3 Day 1-3 +- W2b2 PR (merged): https://github.com/l17728/modelpull/pull/11 (squash `ba89a91`) diff --git a/docs/superpowers/specs/2026-05-14-phase-2-w3a-mtls-jwt-hmac-design.md b/docs/superpowers/specs/2026-05-14-phase-2-w3a-mtls-jwt-hmac-design.md new file mode 100644 index 0000000..029ddff --- /dev/null +++ b/docs/superpowers/specs/2026-05-14-phase-2-w3a-mtls-jwt-hmac-design.md @@ -0,0 +1,605 @@ +# Phase 2 Week 3a — mTLS + Executor JWT + HMAC Heartbeat Design + +> **Status:** Draft (brainstormed 2026-05-14). +> **Companion plan:** `docs/superpowers/plans/2026-05-14-phase-2-w3a-mtls-jwt-hmac.md` (to be written by writing-plans skill after spec approval). +> **Roadmap source:** `docs/v2.0/08-mvp-roadmap.md` §2.6 — Phase 2 Week 3 Day 1-3 (mTLS CA + enrollment + JWT, then HMAC heartbeat). +> **Companion split (W3b / W3c):** HF reverse-proxy is W3b; Active/standby controller + chaos drill is W3c. Both depend on W3a's auth substrate. +> **Security source:** `docs/v2.0/04-security-and-tenancy.md` §2.2 (Executor auth, SEC-01 + SEC-04) + `docs/v2.0/02-protocol.md` §4.1 (heartbeat HMAC). + +--- + +## 1. Goal & Non-Goals + +### 1.1 Goal + +Replace the Phase-1 single-shared-bearer auth for **executor** endpoints with the SVID-style auth specified in `04 §2.2`: + +1. **mTLS substrate.** Controller bootstraps a self-signed CA at first startup (file-persisted under `${DLW_CA_DIR}`, chmod 600). A new `POST /api/v1/executors/register` endpoint (auth: enrollment token) accepts an executor-generated CSR + metadata, signs a 24h client cert, persists the cert fingerprint on `executors.cert_fingerprint`, and returns `(client_cert_pem, ca_chain, initial_jwt, hmac_seed)`. uvicorn loads the CA bundle at startup; a FastAPI dependency reads the verified peer cert and looks up the executor by fingerprint. + +2. **Executor JWT.** Controller generates an Ed25519 signing keypair at first startup (same dir as the CA). `POST /api/v1/executors/{eid}/renew` requires a valid client cert + current JWT and returns a fresh 1h JWT (plus a new client cert if the old cert is within 1h of expiry). JWT claims: `{iss, sub: executor_id, epoch, scope, iat, exp}`. + +3. **HMAC heartbeat.** Per `04 §2.2.4` / `02 §4.1`: every heartbeat body is signed `HMAC-SHA256(hmac_seed, f"{ts}:{nonce}:" + body)`; the executor sends `X-HMAC-Timestamp`, `X-HMAC-Nonce`, `X-HMAC-Signature` headers. The controller validates: timestamp within ±5 min, nonce not seen in an in-process LRU (size 10000, TTL 5 min), signature matches constant-time. + +4. **Executor-side lifecycle.** First boot: read enrollment token from env, generate an Ed25519 keypair, build a CSR, call `/register`, persist `client-cert.pem` + `client-key.pem` + `ca-chain.pem` + `hmac-seed` (chmod 600), cache the JWT in memory. A background renewal loop refreshes the JWT ~5 min before expiry and the client cert ~1 h before expiry. + +After W3a: executor endpoints (`/heartbeat`, `/poll`, `/report`) require mTLS + valid JWT + (heartbeat only) HMAC signature. UI endpoints (`/api/v1/tasks/*`) keep `require_bearer` — Phase 3 W3 (User OIDC) replaces those. + +### 1.2 Non-goals (deferred — explicit list) + +| Item | Where | +|---|---| +| HF reverse-proxy + executor route migration | **W3b** | +| Active/standby controller + chaos drill | **W3c** | +| OIDC + multi-tenant + UI auth migration | **Phase 3 W3** | +| `tenants.hf_tokens` envelope encryption + 5-min controller-memory cache | **Phase 3** (consumed by W3b) | +| Vault / k8s Secret integration for CA + signing key + enrollment token | **Phase 3** | +| cert-manager / Sigstore / CRL (cert revocation) | **Phase 3+** | +| Audit log of register / renew events | **Phase 3** | +| Per-tenant CA (each tenant has its own SVID issuer) | **Phase 3+** | +| Rotation of the CA itself (re-sign all executors) | **Phase 3** ops | +| Envelope encryption of `hmac_seed` at rest (column is forward-compatible; W3a stores raw bytes) | **Phase 3** | +| Dual-auth transition window (bearer + mTLS in parallel) | not needed — internal beta tolerates hard cutover | +| WebSocket auth (UI WS subscription) | Phase 1 W3 already uses bearer; revisit Phase 3 | +| PG-backed / Redis nonce store (survives controller restart) | **Phase 3** — W3a's in-process store is bounded by the ±5min window | + +--- + +## 2. Tech Stack Additions + +W3a **adds two runtime dependencies** (correcting an earlier assumption — `cryptography` is currently only a transitive dep of boto3/httpx, and no JWT library is present): + +| Package | Version pin | Why | +|---|---|---| +| `cryptography` | `>=43,<44` | Promoted from transitive to explicit — W3a uses `cryptography.x509` + `cryptography.hazmat` directly for CA generation, CSR signing, Ed25519 keys. Pinning it explicitly is correct hygiene now that it's a first-class dependency. | +| `pyjwt[crypto]` | `>=2.9,<3.0` | EdDSA (Ed25519) JWT signing + verification. PyJWT is lighter than `python-jose` and its `[crypto]` extra reuses the same `cryptography` backend. | + +No new dev dependencies (`pytest` + existing fixtures). No new CI jobs. uvicorn (already pinned) provides the TLS termination via its `--ssl-*` flags. + +`uv add cryptography pyjwt` updates `pyproject.toml` + `uv.lock`. Per the `feedback_uv_ci_version_pin` memory: the CI `uv` is pinned to 0.11.9 — these are standard `dependencies` entries, not PEP 735 groups, so no `--all-groups`-style incompatibility. + +--- + +## 3. Components + +### 3.1 New: `src/dlw/auth/ca.py` + +Self-signed CA generation, CSR signing, fingerprint extraction, **and the controller's own server cert** (for uvicorn TLS). + +```python +@dataclass(frozen=True) +class CABundle: + cert_pem: bytes + key_pem: bytes + cert: x509.Certificate + key: ed25519.Ed25519PrivateKey + + +def bootstrap_ca(ca_dir: Path) -> CABundle: + """Idempotent: load existing CA from disk, else generate + persist. + Files: ca-cert.pem, ca-key.pem (chmod 600). CA validity 10 years.""" + + +def sign_csr(ca: CABundle, csr_pem: bytes, executor_id: str, ttl_hours: int = 24) -> bytes: + """Sign an executor CSR. CN = executor_id. SAN carries + URI:spiffe://dlw/executor/. ExtendedKeyUsage = CLIENT_AUTH. + Raises ValueError on invalid CSR signature.""" + + +def fingerprint_of(cert_pem: bytes) -> str: + """SHA256 fingerprint as 'SHA256:' — stored on executors.cert_fingerprint.""" +``` + +Implementation notes: + +- CA + executor + server certs all use **Ed25519** keys (`ed25519.Ed25519PrivateKey`). `cryptography` signs Ed25519 certs with `.sign(key, None)` (the hash-algorithm arg is `None` for Ed25519). +- CA cert: `BasicConstraints(ca=True, path_length=0)`, `KeyUsage(key_cert_sign=True, crl_sign=True)`. +- Executor cert: `BasicConstraints(ca=False)`, `KeyUsage(digital_signature=True)`, `ExtendedKeyUsage([CLIENT_AUTH])`, SAN `URI:spiffe://dlw/executor/`. +- `sign_csr` validates `csr.is_signature_valid` before signing. + +#### 3.1.1 Server cert (`_ensure_server_cert`) + +`bootstrap_ca` is paired with a server-cert helper used at the same bootstrap point: + +```python +def ensure_server_cert(ca: CABundle, ca_dir: Path, + hostname: str = "dlw-controller") -> tuple[Path, Path]: + """Idempotent: load or generate server-cert.pem + server-key.pem (chmod 600). + + CN = hostname. SubjectAlternativeName MUST include: + - DNS:localhost + - DNS: (from $DLW_CONTROLLER_HOSTNAME, default 'dlw-controller') + - IP:127.0.0.1 + - IP:::1 + TTL = 10 years (matches CA — server cert is not rotated in Phase 2). + ExtendedKeyUsage = SERVER_AUTH. + + Returns (server_cert_path, server_key_path) for the uvicorn --ssl-* flags. + """ +``` + +The SAN list is the load-bearing detail: without `DNS:localhost` + `IP:127.0.0.1`, an executor connecting to `https://localhost:8443` fails httpx hostname verification. The implementer MUST include all four SAN entries. `hostname` comes from `DLW_CONTROLLER_HOSTNAME` env (default `dlw-controller`); operators set it to the real hostname in prod, and the cert is regenerated only if absent. + +### 3.2 New: `src/dlw/auth/jwt_signing.py` + +Ed25519 JWT signing + verification via **PyJWT**. + +```python +@dataclass(frozen=True) +class JWTKeypair: + priv_pem: bytes + pub_pem: bytes + + +def bootstrap_keypair(ca_dir: Path) -> JWTKeypair: + """Idempotent: load or generate jwt-signing.pem (chmod 600, PKCS8 Ed25519 + private key). Public key is derived on load.""" + + +def sign(kp: JWTKeypair, *, executor_id: str, epoch: int, + scopes: list[str], ttl_seconds: int = 3600) -> str: + """jwt.encode({iss:'dlw-controller', sub:executor_id, epoch, scope:' '.join(scopes), + iat, exp}, kp.priv_pem, algorithm='EdDSA').""" + + +def verify(kp: JWTKeypair, token: str) -> dict[str, Any]: + """jwt.decode(token, kp.pub_pem, algorithms=['EdDSA'], issuer='dlw-controller', + options={'require': ['sub','epoch','scope','exp','iss','iat']}). + Raises jwt.PyJWTError on any failure (signature / expiry / shape / issuer).""" +``` + +PyJWT API: `jwt.encode(payload, key, algorithm="EdDSA")` accepts a PEM-encoded Ed25519 private key; `jwt.decode(token, key, algorithms=["EdDSA"], ...)` accepts the PEM public key. The `exp` claim is validated automatically by PyJWT; `issuer=` triggers `iss` validation; `options={"require": [...]}` enforces claim presence. + +### 3.3 New: `src/dlw/auth/hmac_nonce.py` + +```python +class NonceStore: + """In-process LRU with timestamp-based eviction. asyncio single-threaded — + no lock needed. Restart loses state; replay defense is bounded by the + ±5min timestamp window enforced at the dependency layer.""" + + def __init__(self, *, maxsize: int = 10_000, ttl_seconds: int = 300) -> None: ... + def seen(self, nonce: str) -> bool: ... # evicts expired, then checks membership + def add(self, nonce: str) -> None: ... # evicts expired + LRU-trims, then inserts + + +def compute_hmac(hmac_seed: bytes, *, ts: int, nonce: str, body: bytes) -> str: + """HMAC-SHA256(hmac_seed, f'{ts}:{nonce}:'.encode() + body).hexdigest().""" + + +def verify_hmac(hmac_seed: bytes, *, ts: int, nonce: str, body: bytes, + signature_hex: str) -> bool: + """hmac.compare_digest(compute_hmac(...), signature_hex) — constant-time.""" +``` + +`NonceStore` uses an `OrderedDict[str, float]` keyed by nonce, value = `time.monotonic()` insertion time. `seen()` and `add()` both call `_evict_expired()` first (pop from the front while value < `now - ttl`). `add()` also LRU-trims when `len >= maxsize`. + +### 3.4 New: FastAPI dependencies + +Three chained dependencies. Each builds on the prior: + +**`src/dlw/auth/executor_mtls.py` — `require_executor_mtls`:** +- Reads the verified peer cert from `request.scope` (uvicorn TLS path: `request.scope["transport"].get_extra_info("peercert")` → DER → PEM) OR, when `DLW_TLS_TRUSTED_PROXY=1`, from the `X-Client-Cert-PEM` header (reverse-proxy / test path). +- Computes the fingerprint, looks up `executors` by `cert_fingerprint`. 401 on missing cert or unknown fingerprint. +- Returns the `Executor` row. + +**`src/dlw/auth/executor_jwt_dep.py` — `require_executor_jwt`:** +- `Depends(require_executor_mtls)` → has the `Executor` row. +- Reads `Authorization: Bearer `; `jwt_signing.verify(app.state.jwt_keypair, token)`. 401 on any `PyJWTError`. +- Asserts `claims["sub"] == ex.id`. 401 on mismatch. +- Returns the `Executor` row. + +**`src/dlw/auth/hmac_heartbeat_dep.py` — `require_hmac_heartbeat`:** +- `Depends(require_executor_jwt)` → has the `Executor` row. +- Reads `X-HMAC-Timestamp` / `X-HMAC-Nonce` / `X-HMAC-Signature` (all required headers — 422 if absent). +- `abs(now - ts) > 300` → 401 `CLOCK_SKEW`. +- `app.state.nonce_store.seen(nonce)` → 401 `REPLAY_DETECTED`. +- `verify_hmac(hmac_seed, ts=, nonce=, body=await request.body(), signature_hex=)` → 401 `HMAC_INVALID` on mismatch. +- On success: `nonce_store.add(nonce)`, return the `Executor` row. +- `hmac_seed` comes from `_decrypt_hmac_seed(ex)` — Phase 2 returns `bytes(ex.hmac_seed_encrypted)` raw; Phase 3 swaps for a KMS decrypt. 401 `HMAC_SEED_MISSING` if the column is NULL (executor must re-register). + +**`require_executor_epoch` is refactored** (not "unchanged"). W1's version takes the path `executor_id` + `X-Executor-Epoch` header and does a *fresh* DB lookup by path id. Under W3a that fresh lookup is a confused-deputy gap: an attacker with a valid cert + JWT for executor **A** could hit `/executors/B/heartbeat` with B's epoch, and the W1 dep would happily validate against B's row. The W3a `require_executor_epoch`: + +1. `Depends(require_executor_jwt)` → receives the `Executor` row already loaded from the mTLS cert fingerprint (call it `ex_mtls`). +2. Asserts the path `executor_id` parameter equals `ex_mtls.id` → 401 `EXECUTOR_ID_MISMATCH` otherwise. This binds the URL to the authenticated identity. +3. Compares `X-Executor-Epoch` against `ex_mtls.epoch` (W1's fence check) — using the already-loaded row, no second lookup. +4. Returns `ex_mtls`. + +The W1 fence semantics (epoch must match) are preserved; the change is *which* row the epoch is checked against (the mTLS-authenticated row, not a path-id lookup) plus the new path-vs-identity assertion. + +### 3.5 Modified: `src/dlw/api/executors.py` + +**Deleted:** `POST /api/v1/executors/join` (the W1 bearer-auth endpoint). + +**New `POST /api/v1/executors/register`** — enrollment-token auth, signs the CSR, INSERTs-or-bumps the executor row (mirrors W1 `join_executor`'s `pg_insert ... ON CONFLICT DO UPDATE` epoch-bump semantics), generates a 256-bit `hmac_seed`, returns `RegistrationResponse`. + +**New `POST /api/v1/executors/{eid}/renew`** — `Depends(require_executor_jwt)`. Signs a fresh JWT. If the peer cert TTL is within 1h, re-signs the cert (reusing the peer cert's public key) and updates `executors.cert_fingerprint`. Returns `RenewResponse`. + +**Modified `POST /api/v1/executors/{eid}/heartbeat`** — drops `Depends(require_bearer)`; the dependency chain is `require_executor_epoch` (W1) + `require_hmac_heartbeat` (which transitively pulls `require_executor_jwt` → `require_executor_mtls`). Body shape (`ExecutorHeartbeat`) unchanged. + +**Modified `POST /api/v1/executors/{eid}/poll`** — drops `Depends(require_bearer)`; chain is `require_executor_epoch` + `require_executor_jwt`. Body shape unchanged. + +### 3.6 Modified: `src/dlw/api/subtasks.py` + +`POST /api/v1/subtasks/{id}/report` — drops `Depends(require_bearer)`; chain is `require_executor_jwt` + `require_executor_epoch`. Body shape (`SubTaskReport`) unchanged. + +### 3.7 New schemas in `src/dlw/schemas/executor.py` + +```python +class ExecutorRegister(BaseModel): + host_id: str + executor_id_proposal: str + capabilities: dict[str, Any] = {} + client_csr_pem: str + + +class RegistrationResponse(BaseModel): + executor_id: str + epoch: int + client_cert_pem: str + ca_chain: list[str] + executor_jwt: str + hmac_seed_hex: str + cert_renew_in_seconds: int + jwt_renew_in_seconds: int + + +class RenewResponse(BaseModel): + executor_jwt: str + jwt_renew_in_seconds: int + client_cert_pem: str | None # non-null only when cert was rotated + cert_renew_in_seconds: int | None +``` + +The W1 `ExecutorJoin` schema is deleted alongside the `/join` endpoint. + +### 3.8 Modified: `src/dlw/services/executor_service.py` + +W1's `join_executor` (the `pg_insert ... ON CONFLICT DO UPDATE` INSERT-or-bump) is **renamed and extended** to `upsert_executor_with_cert(session, *, executor_id, host_id, capabilities, cert_fingerprint, hmac_seed)`. Same atomic INSERT-or-bump semantics — epoch starts at 1 on insert, bumps by 1 on conflict — but now also writes `cert_fingerprint` and `hmac_seed_encrypted`. `record_heartbeat` is unchanged (W2b1 already routes it through `transition_executor`). + +The W3a `/register` endpoint calls `upsert_executor_with_cert`. `join_executor`'s callers (W1 tests) migrate to `upsert_executor_with_cert` or to the `/register` HTTP path. + +### 3.9 Modified: `src/dlw/main.py` + +`lifespan` startup gains a bootstrap block (before the W1 recovery routine): + +```python +from dlw.auth.ca import bootstrap_ca, ensure_server_cert +from dlw.auth.jwt_signing import bootstrap_keypair +from dlw.auth.hmac_nonce import NonceStore + +ca_dir = Path(settings.ca_dir) +ca_dir.mkdir(mode=0o700, exist_ok=True) +ca = bootstrap_ca(ca_dir) +ensure_server_cert(ca, ca_dir, hostname=settings.controller_hostname) +jwt_kp = bootstrap_keypair(ca_dir) +enrollment_token = _ensure_enrollment_token(ca_dir, settings) +app.state.ca = ca +app.state.jwt_keypair = jwt_kp +app.state.nonce_store = NonceStore(maxsize=10_000, ttl_seconds=300) +app.state.enrollment_token = enrollment_token +``` + +`_ensure_enrollment_token`: if `DLW_ENROLLMENT_TOKEN` env is set, use it; else read `${ca_dir}/enrollment.token`; else generate a 256-bit hex token, write the file (chmod 600), and log it once at INFO so the operator can copy it out-of-band. + +The W2a/W2b2 `_sweep_loop_main` and W1 recovery routine are unchanged. uvicorn `--ssl-*` flags are passed at the CLI / deployment layer (documented in `docs/operator/`), not in `create_app`. + +### 3.10 New: `src/dlw/executor/cert.py` + +```python +def build_csr(executor_id: str) -> tuple[bytes, bytes]: + """Generate an Ed25519 keypair + CSR. CN = executor_id. + Returns (csr_pem, private_key_pem).""" + +def persist(cert_dir: Path, *, cert_pem: bytes, key_pem: bytes, + ca_chain_pem: bytes, hmac_seed: bytes) -> None: + """Write client-cert.pem / client-key.pem / ca-chain.pem / hmac-seed + (all chmod 600) into cert_dir (chmod 700).""" + +def load(cert_dir: Path) -> tuple[bytes, bytes, bytes, bytes] | None: + """Return (cert_pem, key_pem, ca_chain_pem, hmac_seed) or None if absent.""" + +def fingerprint(cert_pem: bytes) -> str: + """SHA256: — same format as controller's ca.fingerprint_of.""" +``` + +### 3.11 New: `src/dlw/executor/auth_lifecycle.py` + +```python +@dataclass +class AuthState: + executor_id: str + epoch: int + cert_pem: bytes + key_pem: bytes + ca_chain_pem: bytes + jwt: str + jwt_exp: datetime + cert_exp: datetime + hmac_seed: bytes + cert_dir: Path # for the httpx cert= file paths + + +async def register(*, controller_url, ca_bundle_path, enrollment_token, + executor_id, host_id, capabilities, cert_dir) -> AuthState: + """Build CSR, POST /register, persist cert+key+ca+seed to cert_dir, + return AuthState with the parsed JWT/cert expiry timestamps.""" + + +async def renew(state: AuthState, *, controller_url) -> AuthState: + """POST /{eid}/renew using the current cert (mTLS) + JWT. Update the JWT; + if the response carries a new cert, persist it + update the fingerprint + timestamps. Return a fresh AuthState.""" + + +async def load_or_register(*, cert_dir, controller_url, ca_bundle_path, + enrollment_token, executor_id, host_id, + capabilities) -> AuthState: + """If cert_dir has a persisted cert: load it + run renew() to refresh the + JWT. Else: run register(). Returns a ready AuthState.""" +``` + +### 3.12 Modified: `src/dlw/executor/client.py` + +`ControllerClient` takes an `AuthState` reference (mutable — the renewal loop updates it in place). Each request builds an `httpx.AsyncClient(verify=, cert=(, ), headers={"Authorization": f"Bearer {jwt}"})`. The `heartbeat` method additionally computes the HMAC headers (`X-HMAC-Timestamp` / `X-HMAC-Nonce` / `X-HMAC-Signature`) over the JSON body via `compute_hmac`. + +### 3.13 Modified: `src/dlw/executor/runner.py` + +`run()` gains an auth-bootstrap step before `/join` (now `load_or_register`) and spawns a **third** background task `_auth_renew_loop` alongside `_heartbeat_loop` and `_poll_and_execute_loop`: + +```python +async def _auth_renew_loop(self) -> None: + while not self._shutdown.is_set(): + now = datetime.now(UTC) + jwt_due = self._auth.jwt_exp - timedelta(minutes=5) + cert_due = self._auth.cert_exp - timedelta(hours=1) + sleep_for = max(60, int((min(jwt_due, cert_due) - now).total_seconds())) + try: + await asyncio.wait_for(self._shutdown.wait(), timeout=sleep_for) + return # shutdown + except asyncio.TimeoutError: + pass + try: + self._auth = await renew(self._auth, controller_url=self._s.controller_url) + self._client.update_auth(self._auth) + except Exception as e: + logger.warning("auth renew failed: %s; retry next cycle", e) +``` + +The W1 `EPOCH_MISMATCH` re-join path in `_poll_and_execute_loop` is generalized: on a 401 from a protected endpoint, the loop attempts `load_or_register` (which falls back to `/register` with the persisted enrollment token) before giving up — same structure as the existing rejoin logic. + +### 3.14 Modified: `src/dlw/executor/config.py` + `src/dlw/config.py` + +Executor `ExecutorSettings` adds: + +```python +enrollment_token: str = Field(default="", description="OOB enrollment token from operator.") +executor_cert_dir: str = Field(default="~/.dlw/executor") +executor_ca_bundle: str = Field(default="") # defaults to {cert_dir}/ca-chain.pem at runtime +``` + +Controller `Settings` adds: + +```python +ca_dir: str = Field(default="./.ca") +enrollment_token: str = Field(default="") # if empty, bootstrap generates one +controller_hostname: str = Field(default="dlw-controller") +tls_trusted_proxy: bool = Field(default=False) # DLW_TLS_TRUSTED_PROXY +``` + +### 3.15 Modified: `tools/lint_invariants.py` + +New helper `check_no_bearer_on_executor_routes`: AST-scans `src/dlw/api/executors.py` + `src/dlw/api/subtasks.py` for any `Depends(require_bearer)` in a route decorator's `dependencies=[...]` list or as a parameter default. Any hit → failure. Wired into `main()` next to the W2b2 helpers. This locks the migration in — a future change that re-adds bearer to an executor route gets caught. + +--- + +## 4. Schema Changes + +**One alembic migration** (`_p2w3a_hmac_seed.py`): + +```python +def upgrade() -> None: + op.add_column( + "executors", + sa.Column("hmac_seed_encrypted", sa.LargeBinary(), nullable=True), + ) + +def downgrade() -> None: + op.drop_column("executors", "hmac_seed_encrypted") +``` + +`down_revision` = `b1d5ea4944ba` (W2b2 `last_paused_at`). Nullable — pre-W3a executor rows survive with `hmac_seed_encrypted=NULL`; they re-register on next runner restart. ORM: add `hmac_seed_encrypted: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)` to `Executor`. + +`executors.cert_fingerprint` (W1) and `executors.epoch` (W1) already exist. No other DDL. + +--- + +## 5. Wire Format Changes + +### 5.1 New endpoint `POST /api/v1/executors/register` + +| Aspect | Value | +|---|---| +| Auth | `X-Enrollment-Token` header (no mTLS, no JWT) | +| Request | `ExecutorRegister` (host_id, executor_id_proposal, capabilities, client_csr_pem) | +| Response 201 | `RegistrationResponse` (cert + ca_chain + jwt + hmac_seed_hex + renew intervals) | +| 200 | Same body, returned when re-registering an existing executor_id (epoch bumped) | +| 401 | Invalid / missing enrollment token | +| 422 | Malformed CSR | + +### 5.2 New endpoint `POST /api/v1/executors/{eid}/renew` + +| Aspect | Value | +|---|---| +| Auth | mTLS peer cert + `Authorization: Bearer ` | +| Request | empty `{}` | +| Response 200 | `RenewResponse` (new jwt; new cert iff old cert TTL < 1h) | +| 401 | Missing/invalid mTLS cert, expired/invalid JWT, or `sub`/fingerprint mismatch | + +### 5.3 `POST /api/v1/executors/join` — DELETED + +The W1 bearer-auth `/join` endpoint and its `ExecutorJoin` schema are removed entirely. No transitional alias. + +### 5.4 Heartbeat headers (additive) + +`POST /api/v1/executors/{eid}/heartbeat` body shape unchanged; new **required** headers: `Authorization: Bearer ` (was: shared-secret bearer), `X-HMAC-Timestamp`, `X-HMAC-Nonce`, `X-HMAC-Signature`. `X-Executor-Epoch` (W1) still required. + +### 5.5 Poll / Report auth (no body change) + +`POST /api/v1/executors/{eid}/poll` and `POST /api/v1/subtasks/{id}/report` switch the `Authorization` header from the shared-secret bearer to the executor JWT. Body shapes unchanged. + +### 5.6 OpenAPI + +`api/openapi.yaml`: add `/executors/register` + `/executors/{executorId}/renew` operations; remove `/executors/join`; document the HMAC headers on the heartbeat operation; add `RegistrationResponse` / `RenewResponse` / `ExecutorRegister` schemas. The aspirational `04 §2.2.1` doc already sketches the register shape — align with it. + +--- + +## 6. Error Handling Matrix + +| Situation | Behaviour | +|---|---| +| `/register` with bad enrollment token | 401 `invalid enrollment token` | +| `/register` with malformed CSR | 422; `sign_csr` raises `ValueError`, endpoint maps to 422 | +| `/register` re-register of existing executor_id | 200; `upsert_executor_with_cert` bumps epoch + regenerates hmac_seed + new cert (W1 fence semantics — old-epoch in-flight subtasks get reclaimed by the W2a sweeper) | +| Protected endpoint, no mTLS peer cert | 401 `missing or invalid mTLS peer cert` (from `require_executor_mtls`) | +| Protected endpoint, peer cert fingerprint not in DB | 401 `cert fingerprint not registered` | +| Protected endpoint, expired JWT | 401 from `require_executor_jwt` (`PyJWTError`); executor's renew loop or 401-handler re-registers | +| Protected endpoint, JWT `sub` ≠ mTLS executor_id | 401 `JWT sub mismatch` — cert + token belong to different executors | +| Heartbeat, timestamp skew > 5min | 401 `CLOCK_SKEW` | +| Heartbeat, nonce already in store | 401 `REPLAY_DETECTED` | +| Heartbeat, signature mismatch | 401 `HMAC_INVALID` | +| Heartbeat, `hmac_seed_encrypted` is NULL (pre-W3a row) | 401 `HMAC_SEED_MISSING` — executor must re-register | +| Path `executor_id` ≠ mTLS-authenticated executor (confused deputy) | 401 `EXECUTOR_ID_MISMATCH` from the refactored `require_executor_epoch` (§3.4) | +| `/renew`, peer cert valid but JWT expired | 401 — `/renew` requires a *valid* JWT; if the JWT genuinely expired, the executor re-registers via the enrollment token | +| Controller restart loses nonce store | A captured heartbeat with a ≤5min-old timestamp replays exactly once post-restart; impact is limited to one idempotent `last_heartbeat_at` write — no state corruption | +| uvicorn TLS not configured but `DLW_TLS_TRUSTED_PROXY=0` | `require_executor_mtls` finds no peer cert + no trusted header → 401 on every protected request — fails closed | +| Executor clock ahead/behind controller | `CLOCK_SKEW` on heartbeat; operator runs `chrony`/`systemd-timesyncd` (standard) | +| Server cert SAN missing `localhost` / `127.0.0.1` | Executor httpx `verify=` fails the TLS handshake — `ensure_server_cert` MUST include all four SAN entries (§3.1.1) | + +--- + +## 7. Testing Strategy + +### 7.1 Unit + integration (~25 new cases) + +| # | File | Case | What it asserts | +|---|---|---|---| +| 1 | `tests/auth/test_ca.py` | `test_bootstrap_ca_idempotent` | Two calls on same dir → identical cert (load path) | +| 2 | same | `test_sign_csr_returns_valid_client_cert` | CN == executor_id, 24h TTL, signed by CA, EKU=CLIENT_AUTH | +| 3 | same | `test_fingerprint_of_is_deterministic_sha256` | "SHA256:" format, stable for same cert | +| 4 | same | `test_ensure_server_cert_has_required_sans` | SAN includes DNS:localhost, DNS:, IP:127.0.0.1, IP:::1 | +| 5 | `tests/auth/test_jwt_signing.py` | `test_bootstrap_keypair_idempotent` | Second call → same keypair | +| 6 | same | `test_sign_and_verify_roundtrip` | sign → verify returns matching claims | +| 7 | same | `test_verify_rejects_expired_token` | `exp` in the past → PyJWTError | +| 8 | same | `test_verify_rejects_wrong_issuer` | tampered `iss` → PyJWTError | +| 9 | `tests/auth/test_hmac_nonce.py` | `test_hmac_compute_and_verify_roundtrip` | compute → verify_hmac True | +| 10 | same | `test_hmac_verify_rejects_tampered_body` | body off by 1 byte → False | +| 11 | same | `test_nonce_store_first_add_then_seen` | add(n) → seen(n) True | +| 12 | same | `test_nonce_store_evicts_after_ttl` | monkeypatch monotonic, expire → seen() False | +| 13 | `tests/auth/test_executor_mtls_dep.py` | `test_require_executor_mtls_via_trusted_proxy_header` | `DLW_TLS_TRUSTED_PROXY=1` + `X-Client-Cert-PEM` → returns Executor row | +| 14 | same | `test_require_executor_mtls_rejects_unknown_fingerprint` | header with unregistered cert → 401 | +| 15 | same | `test_require_executor_mtls_rejects_header_when_proxy_disabled` | `DLW_TLS_TRUSTED_PROXY=0` + header → 401 (header ignored) | +| 16 | `tests/auth/test_executor_jwt_dep.py` | `test_require_executor_jwt_accepts_valid_token` | valid JWT + matching mTLS → Executor | +| 17 | same | `test_require_executor_jwt_rejects_sub_mismatch` | JWT `sub` ≠ mTLS executor_id → 401 | +| 17b | `tests/auth/test_executor_epoch.py` | `test_require_executor_epoch_rejects_path_id_mismatch` | mTLS+JWT for executor A, path `/executors/B/...` → 401 EXECUTOR_ID_MISMATCH (confused-deputy guard) | +| 18 | `tests/auth/test_hmac_heartbeat_dep.py` | `test_hmac_heartbeat_accepts_valid_signature` | mTLS + JWT + correct HMAC headers → passes | +| 19 | same | `test_hmac_heartbeat_rejects_clock_skew` | ts off by 400s → 401 CLOCK_SKEW | +| 20 | same | `test_hmac_heartbeat_rejects_replay` | same nonce twice → 2nd is 401 REPLAY_DETECTED | +| 21 | same | `test_hmac_heartbeat_rejects_tampered_body` | sig computed over body A, POST body B → 401 HMAC_INVALID | +| 22 | `tests/api/test_register_endpoint.py` | `test_register_returns_cert_jwt_and_hmac_seed` | enrollment token + CSR → 201 + all 4 fields populated | +| 23 | same | `test_register_rejects_invalid_enrollment_token` | wrong token → 401 | +| 24 | same | `test_register_idempotent_on_reregister` | same executor_id twice → epoch bumped, new fingerprint | +| 25 | `tests/api/test_renew_endpoint.py` | `test_renew_returns_new_jwt_only_when_cert_fresh` | cert TTL > 1h → `client_cert_pem` is null | +| 26 | same | `test_renew_returns_new_cert_when_under_1h` | deliberately short-TTL cert → renew returns a new cert | +| 27 | `tests/e2e/test_executor_auth_e2e.py` | `test_register_then_heartbeat_full_flow` | Real uvicorn TLS subprocess: register → heartbeat 200 with HMAC | + +Count ≈ 27. The single e2e (test 27) is the load-bearing wiring check; everything else is module-level. + +### 7.2 mTLS test strategy + +- **Units (tests 1-26):** bypass real TLS. `tests/conftest.py` gains `ephemeral_ca` (session-scoped — one CA per session) + `client_cert_pair` (per-test client cert signed by that CA) fixtures. Protected-endpoint tests set `DLW_TLS_TRUSTED_PROXY=1` (via `monkeypatch.setenv`) and send `X-Client-Cert-PEM`. +- **e2e (test 27):** spawns uvicorn in a subprocess with real `--ssl-*` flags pointing at the `ephemeral_ca`'s server cert; an httpx client connects with `verify=` + `cert=`. This is the only test that touches real TLS — it protects the uvicorn wiring + peer-cert extraction path. Slow (~1-2s) but runs once. +- **Why the two layers:** identical rationale to W2a/W2b1's "test pyramid" — module units stay fast (header bypass), one e2e guards the integration seam. + +### 7.3 Existing test migration + +`/join` deletion forces test-setup migration: + +- `tests/api/test_executors.py` — the `joined_executor` fixture → `registered_executor` (build CSR, call `/register`). Heartbeat tests gain HMAC headers (a `_signed_heartbeat_headers(auth_state, body)` helper in conftest). +- `tests/api/test_subtasks.py` — same fixture migration. +- `tests/e2e/test_executor_e2e.py` / `test_happy_path.py` — full flow now starts with `/register`. +- `tests/services/test_executor_service.py` — `join_executor` → `upsert_executor_with_cert` rename; the W1 INSERT-or-bump cases still apply, plus 2 new assertions (cert_fingerprint + hmac_seed written). +- `tests/auth/test_executor_epoch.py` — W1 fence dep unchanged; tests that constructed an executor via `/join` switch to a direct DB insert or `/register`. + +Estimated existing-test churn: ~12-15 setups across 5 files. All mechanical (fixture-level), no logic changes. + +### 7.4 Test infra + +- `cryptography` + `pyjwt[crypto]` — added as runtime deps (§2); `uv sync` picks them up. CI `uv` 0.11.9 handles plain `dependencies` entries fine. +- `ephemeral_ca` / `client_cert_pair` fixtures in `tests/conftest.py` — new, ~30 LOC. +- `_signed_heartbeat_headers` helper — new, computes the 3 HMAC headers for a given body + seed. +- No new pytest plugins. No new CI jobs. + +### 7.5 CI 12 checks expectations + +| Check | W3a impact | +|---|---| +| pytest | +27 new, ~12-15 modified W1 setups | +| Invariant + cross-ref lint | **+`check_no_bearer_on_executor_routes`** helper | +| OpenAPI lint | `/register` + `/renew` added, `/join` removed, HMAC headers documented | +| Markdown lint | spec/plan cross-ref 04 §2.2 + 02 §4.1 | +| Other 8 | no change | + +### 7.6 Not tested + +- Real NTP clock sync — `CLOCK_SKEW` tests monkeypatch `time.time`. +- High-concurrency nonce store (>10k) — Phase 3 P-001 baseline. +- CA rotation / forced re-enrollment — Phase 3 ops. +- CRL / cert revocation — Phase 3+. +- The `DLW_TLS_TRUSTED_PROXY=1` production-with-real-proxy path — only the header-bypass logic is unit-tested; a real nginx/Caddy front is operator-territory. + +--- + +## 8. Acceptance Criteria + +- [ ] `uv add cryptography pyjwt` — both pinned in `pyproject.toml`; `uv.lock` updated; `uv sync` clean. +- [ ] 1 alembic migration applies clean from W2b2 head (`b1d5ea4944ba`), reverses clean. `executors.hmac_seed_encrypted` exists. +- [ ] `dlw.auth.ca` (incl. `ensure_server_cert` with all 4 SANs) + `jwt_signing` + `hmac_nonce` modules + ~12 unit tests pass. +- [ ] `require_executor_mtls` / `require_executor_jwt` / `require_hmac_heartbeat` deps + ~9 unit tests pass. +- [ ] `POST /register` + `POST /{eid}/renew` endpoints + ~5 API tests pass. +- [ ] One e2e (`test_register_then_heartbeat_full_flow`) exercises real uvicorn TLS and passes. +- [ ] W1 `/join` deleted; `ExecutorJoin` schema deleted; all `/join`-calling test setups migrated. +- [ ] `join_executor` → `upsert_executor_with_cert` rename; writes `cert_fingerprint` + `hmac_seed_encrypted`. +- [ ] Executor side: `cert.py` / `auth_lifecycle.py` new; `client.py` / `runner.py` / `config.py` modified; runner spawns the 3rd background task. +- [ ] `tools/lint_invariants.py` `check_no_bearer_on_executor_routes` reports 0 on production tree. +- [ ] OpenAPI: `/register` + `/renew` documented; `/join` removed; heartbeat HMAC headers added; spectral clean. +- [ ] `docs/operator/` documents `${DLW_CA_DIR}` layout, enrollment-token OOB distribution, uvicorn `--ssl-*` flags, the `DLW_TLS_TRUSTED_PROXY` forgery warning, and host-clock-sync requirement. +- [ ] No new CI jobs. Full suite green: baseline 181 + ~27 new + ~12-15 modified W1 setups. + +--- + +## 9. Implementation Phasing (preview for plan) + +The plan will be written by the writing-plans skill after spec approval. Expected milestone shape (4 milestones, ~10 tasks): + +- **M1 — Auth substrate.** `uv add` deps + alembic + ORM column + `ca.py` (incl. server cert) + `jwt_signing.py` + `hmac_nonce.py` + ~12 unit tests. No endpoint wiring yet. +- **M2 — Controller deps + endpoints.** 3 FastAPI deps + `/register` + `/renew` endpoints + `main.py` bootstrap + `upsert_executor_with_cert` rename + ~14 unit/API tests. `/join` deleted; W1 test setups migrated. +- **M3 — Endpoint auth migration.** heartbeat/poll/report swap to mTLS+JWT(+HMAC); the e2e test; remaining W1 test-setup migration. +- **M4 — Executor side + lint + OpenAPI + PR.** `cert.py` / `auth_lifecycle.py` / `client.py` / `runner.py` / `config.py` + `check_no_bearer_on_executor_routes` lint + OpenAPI + operator runbook + PR. + +Branch: `feat/phase-2-w3a-mtls-jwt-hmac`. Branched off `main` at `ba89a91` (PR #11 merge). + +--- + +## 10. References + +- Spec source: brainstormed 2026-05-14 (this document). +- Roadmap: `docs/v2.0/08-mvp-roadmap.md` §2.6 Phase 2 W3 Day 1-3. +- Security: `docs/v2.0/04-security-and-tenancy.md` §2.2 (Executor auth, SEC-01 + SEC-04), §2.2.1-2.2.4. +- Protocol: `docs/v2.0/02-protocol.md` §4.1 (heartbeat HMAC). +- Invariants: `docs/v2.0/INVARIANTS.md` §B (security/tenancy) — INVARIANT 1, 4, 44. +- Predecessor specs: + - W1: `docs/superpowers/specs/2026-05-11-phase-2-week-1-fence-token-recovery-design.md` + - W2a: `docs/superpowers/specs/2026-05-13-phase-2-w2a-scheduler-state-machine-design.md` + - W2b1: `docs/superpowers/specs/2026-05-13-phase-2-w2b1-chunk-level-downloader-design.md` + - W2b2: `docs/superpowers/specs/2026-05-14-phase-2-w2b2-cancel-and-paused-external-design.md` +- W2b2 PR (merged): https://github.com/l17728/modelpull/pull/11 (squash `ba89a91`). diff --git a/pyproject.toml b/pyproject.toml index 4386ac2..ddf6acc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,8 @@ dependencies = [ "structlog>=24.4,<24.5", "httpx>=0.27,<0.28", "tenacity>=9.0,<10.0", + "cryptography>=43,<44", + "pyjwt[crypto]>=2.9,<3.0", ] [dependency-groups] diff --git a/src/dlw/alembic/versions/6f37b72630ce_p2w3a_hmac_seed.py b/src/dlw/alembic/versions/6f37b72630ce_p2w3a_hmac_seed.py new file mode 100644 index 0000000..20cd275 --- /dev/null +++ b/src/dlw/alembic/versions/6f37b72630ce_p2w3a_hmac_seed.py @@ -0,0 +1,29 @@ +"""p2w3a hmac_seed + +Revision ID: 6f37b72630ce +Revises: b1d5ea4944ba +Create Date: 2026-05-14 15:23:03.022914 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '6f37b72630ce' +down_revision: Union[str, None] = 'b1d5ea4944ba' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "executors", + sa.Column("hmac_seed_encrypted", sa.LargeBinary(), nullable=True), + ) + + +def downgrade() -> None: + op.drop_column("executors", "hmac_seed_encrypted") diff --git a/src/dlw/api/executors.py b/src/dlw/api/executors.py index b76a0c1..6b1551f 100644 --- a/src/dlw/api/executors.py +++ b/src/dlw/api/executors.py @@ -1,21 +1,25 @@ -"""Executors API: join / heartbeat / poll. - -Phase 2 W1 changes: - - heartbeat + poll now depend on require_executor_epoch (X-Executor-Epoch header - required, must match stored executor.epoch — else 401 EPOCH_MISMATCH). - - /join is unaffected (first contact; controller assigns epoch). - - poll passes executor.epoch to claim_one_subtask so the subtask row - captures the current fence. +"""Executors API: register / renew / heartbeat / poll. + +Phase 2 W3a changes: + - DELETE /join (W1) — replaced by /register with enrollment-token auth. + - ADD /register: signs CSR, issues cert + JWT + hmac_seed. + - ADD /{id}/renew: refreshes JWT; optionally re-signs cert if CSR provided. + - /heartbeat: bearer removed; now requires require_hmac_heartbeat + require_executor_epoch. + - /poll: bearer removed; require_executor_epoch (transitively pulls mTLS+JWT). """ from __future__ import annotations import json +import secrets -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, Header, HTTPException, Request, status from sqlalchemy.ext.asyncio import AsyncSession from dlw.api.tasks import _session -from dlw.auth.bearer import require_bearer +from dlw.auth.ca import fingerprint_of, sign_csr +from dlw.auth import jwt_signing +from dlw.auth.executor_jwt_dep import require_executor_jwt +from dlw.auth.hmac_heartbeat_dep import require_hmac_heartbeat from dlw.auth.executor_epoch import require_executor_epoch from dlw.db.models.executor import Executor from dlw.db.models.storage import StorageBackend @@ -23,29 +27,106 @@ from dlw.schemas.executor import ( AssignmentResponse, ExecutorHeartbeat, - ExecutorJoin, ExecutorRead, + ExecutorRegister, + RegistrationResponse, + RenewRequest, + RenewResponse, ) from dlw.schemas.storage import StorageConfig from dlw.schemas.subtask import SubTaskRead -from dlw.services.executor_service import join_executor, record_heartbeat +from dlw.services.executor_service import upsert_executor_with_cert, record_heartbeat from dlw.services.scheduler import claim_one_subtask router = APIRouter(prefix="/api/v1/executors", tags=["executors"]) -@router.post("/join", status_code=status.HTTP_201_CREATED, dependencies=[Depends(require_bearer)]) -async def post_join( - body: ExecutorJoin, session: AsyncSession = Depends(_session) -) -> ExecutorRead: - ex = await join_executor(session, body) +@router.post("/register", status_code=status.HTTP_201_CREATED) +async def post_register( + body: ExecutorRegister, + request: Request, + x_enrollment_token: str = Header(..., alias="X-Enrollment-Token"), + session: AsyncSession = Depends(_session), +) -> RegistrationResponse: + """W3a: enrollment-token auth; signs CSR; returns cert + JWT + hmac_seed.""" + expected = request.app.state.enrollment_token + if not secrets.compare_digest(x_enrollment_token, expected): + raise HTTPException(401, detail="invalid enrollment token") + try: + cert_pem = sign_csr( + request.app.state.ca, + body.client_csr_pem.encode("utf-8"), + executor_id=body.executor_id_proposal, + ttl_hours=24, + ) + except ValueError as e: + raise HTTPException(422, detail=f"invalid CSR: {e}") from e + fp = fingerprint_of(cert_pem) + hmac_seed = secrets.token_bytes(32) + ex = await upsert_executor_with_cert( + session, executor_id=body.executor_id_proposal, + host_id=body.host_id, capabilities=body.capabilities, + cert_fingerprint=fp, hmac_seed=hmac_seed, + ) + token = jwt_signing.sign( + request.app.state.jwt_keypair, + executor_id=ex.id, epoch=ex.epoch, + scopes=["heartbeat", "poll", "report"], + ) await session.commit() - return ExecutorRead.model_validate(ex) + return RegistrationResponse( + executor_id=ex.id, epoch=ex.epoch, + client_cert_pem=cert_pem.decode("utf-8"), + ca_chain=[request.app.state.ca.cert_pem.decode("utf-8")], + executor_jwt=token, + hmac_seed_hex=hmac_seed.hex(), + cert_renew_in_seconds=86100, + jwt_renew_in_seconds=3300, + ) + + +@router.post("/{executor_id}/renew") +async def post_renew( + executor_id: str, + body: RenewRequest, + request: Request, + ex: Executor = Depends(require_executor_jwt), + session: AsyncSession = Depends(_session), +) -> RenewResponse: + """W3a: always refresh JWT; sign a new cert iff the request carries a CSR.""" + if executor_id != ex.id: + raise HTTPException(401, detail="path executor_id mismatch") + new_jwt = jwt_signing.sign( + request.app.state.jwt_keypair, + executor_id=ex.id, epoch=ex.epoch, + scopes=["heartbeat", "poll", "report"], + ) + new_cert_pem: str | None = None + new_cert_renew_in: int | None = None + if body.client_csr_pem: + try: + new_cert_bytes = sign_csr( + request.app.state.ca, + body.client_csr_pem.encode("utf-8"), + executor_id=ex.id, ttl_hours=24, + ) + except ValueError as e: + raise HTTPException(422, detail=f"invalid CSR: {e}") from e + new_cert_pem = new_cert_bytes.decode("utf-8") + ex.cert_fingerprint = fingerprint_of(new_cert_bytes) + new_cert_renew_in = 86100 + await session.commit() + return RenewResponse( + executor_jwt=new_jwt, jwt_renew_in_seconds=3300, + client_cert_pem=new_cert_pem, + cert_renew_in_seconds=new_cert_renew_in, + ) -@router.post("/{executor_id}/heartbeat", dependencies=[Depends(require_bearer)]) +@router.post("/{executor_id}/heartbeat") async def post_heartbeat( body: ExecutorHeartbeat, + _hmac: Executor = Depends(require_hmac_heartbeat), executor: Executor = Depends(require_executor_epoch), session: AsyncSession = Depends(_session), ) -> ExecutorRead: @@ -57,7 +138,7 @@ async def post_heartbeat( return ExecutorRead.model_validate(ex) -@router.post("/{executor_id}/poll", dependencies=[Depends(require_bearer)]) +@router.post("/{executor_id}/poll") async def post_poll( executor: Executor = Depends(require_executor_epoch), session: AsyncSession = Depends(_session), diff --git a/src/dlw/api/subtasks.py b/src/dlw/api/subtasks.py index cbf2670..3f06d02 100644 --- a/src/dlw/api/subtasks.py +++ b/src/dlw/api/subtasks.py @@ -11,7 +11,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from dlw.api.tasks import _session -from dlw.auth.bearer import require_bearer +from dlw.auth.executor_jwt_dep import require_executor_jwt from dlw.db.models.executor import Executor from dlw.db.models.task import FileSubTask from dlw.schemas.subtask import SubTaskReport @@ -20,20 +20,22 @@ router = APIRouter(prefix="/api/v1/subtasks", tags=["subtasks"]) -@router.post("/{subtask_id}/report", dependencies=[Depends(require_bearer)]) +@router.post("/{subtask_id}/report") async def post_report( subtask_id: uuid.UUID, body: SubTaskReport, # W6-D: use default=None so missing header raises 401 (not FastAPI auto-422) x_executor_epoch: int | None = Header(default=None, alias="X-Executor-Epoch"), + auth_ex: Executor = Depends(require_executor_jwt), session: AsyncSession = Depends(_session), ) -> dict[str, str]: + """W3a: mTLS + JWT auth. The reporting executor must be the one the + subtask is assigned to (confused-deputy guard) and its epoch must match.""" if x_executor_epoch is None: raise HTTPException( status_code=401, detail="missing X-Executor-Epoch header", ) - # Phase 2 W1 fence: load the subtask's claimed executor and verify epoch. sub = await session.get(FileSubTask, subtask_id) if sub is None: raise HTTPException(status_code=404, detail=f"subtask {subtask_id} not found") @@ -41,17 +43,24 @@ async def post_report( raise HTTPException( status_code=409, detail=f"subtask {subtask_id} not assigned" ) - ex = await session.get(Executor, sub.executor_id) - if ex is None: + # W3a confused-deputy guard: the mTLS-authenticated executor MUST be the + # one the subtask is claimed by. + if sub.executor_id != auth_ex.id: raise HTTPException( - status_code=404, detail=f"executor {sub.executor_id} not found" + status_code=401, + detail={ + "code": "EXECUTOR_ID_MISMATCH", + "subtask_executor": sub.executor_id, + "authenticated": auth_ex.id, + }, ) - if ex.epoch != x_executor_epoch: + # Phase 2 W1 fence: verify epoch against the authenticated executor row. + if auth_ex.epoch != x_executor_epoch: raise HTTPException( status_code=401, detail={ "code": "EPOCH_MISMATCH", - "expected": ex.epoch, + "expected": auth_ex.epoch, "got": x_executor_epoch, }, ) diff --git a/src/dlw/auth/ca.py b/src/dlw/auth/ca.py new file mode 100644 index 0000000..824da75 --- /dev/null +++ b/src/dlw/auth/ca.py @@ -0,0 +1,161 @@ +"""Self-signed CA + client cert signing + server cert (Phase 2 W3a §3.1).""" +from __future__ import annotations + +import datetime as _dt +import ipaddress +from dataclasses import dataclass +from pathlib import Path + +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import ed25519 +from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID + + +@dataclass(frozen=True) +class CABundle: + cert_pem: bytes + key_pem: bytes + cert: x509.Certificate + key: ed25519.Ed25519PrivateKey + + +def bootstrap_ca(ca_dir: Path) -> CABundle: + """Idempotent: load existing CA from disk, else generate + persist. + Files: ca-cert.pem, ca-key.pem (chmod 600). CA validity 10 years.""" + cert_path = ca_dir / "ca-cert.pem" + key_path = ca_dir / "ca-key.pem" + if cert_path.exists() and key_path.exists(): + cert_pem = cert_path.read_bytes() + key_pem = key_path.read_bytes() + cert = x509.load_pem_x509_certificate(cert_pem) + key = serialization.load_pem_private_key(key_pem, password=None) + if not isinstance(key, ed25519.Ed25519PrivateKey): + raise ValueError("CA key is not Ed25519 (file corrupted)") + return CABundle(cert_pem=cert_pem, key_pem=key_pem, cert=cert, key=key) + + key = ed25519.Ed25519PrivateKey.generate() + name = x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, "dlw-controller-ca"), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "modelpull"), + ]) + now = _dt.datetime.now(_dt.UTC) + cert = (x509.CertificateBuilder() + .subject_name(name).issuer_name(name) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now) + .not_valid_after(now + _dt.timedelta(days=3650)) + .add_extension(x509.BasicConstraints(ca=True, path_length=0), critical=True) + .add_extension( + x509.KeyUsage( + digital_signature=False, content_commitment=False, + key_encipherment=False, data_encipherment=False, + key_agreement=False, key_cert_sign=True, crl_sign=True, + encipher_only=False, decipher_only=False, + ), critical=True, + ) + .sign(key, None) + ) + cert_pem = cert.public_bytes(serialization.Encoding.PEM) + key_pem = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + ca_dir.mkdir(mode=0o700, parents=True, exist_ok=True) + cert_path.write_bytes(cert_pem) + cert_path.chmod(0o600) + key_path.write_bytes(key_pem) + key_path.chmod(0o600) + return CABundle(cert_pem=cert_pem, key_pem=key_pem, cert=cert, key=key) + + +def sign_csr(ca: CABundle, csr_pem: bytes, executor_id: str, + ttl_hours: int = 24) -> bytes: + """Sign an executor CSR. CN = executor_id; SAN URI:spiffe://dlw/executor/; + EKU = CLIENT_AUTH. Raises ValueError on invalid CSR signature.""" + csr = x509.load_pem_x509_csr(csr_pem) + if not csr.is_signature_valid: + raise ValueError("CSR signature invalid") + now = _dt.datetime.now(_dt.UTC) + cert = (x509.CertificateBuilder() + .subject_name(x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, executor_id), + ])) + .issuer_name(ca.cert.subject) + .public_key(csr.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now) + .not_valid_after(now + _dt.timedelta(hours=ttl_hours)) + .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True) + .add_extension( + x509.SubjectAlternativeName([ + x509.UniformResourceIdentifier(f"spiffe://dlw/executor/{executor_id}"), + ]), critical=False, + ) + .add_extension( + x509.KeyUsage( + digital_signature=True, content_commitment=False, + key_encipherment=False, data_encipherment=False, + key_agreement=False, key_cert_sign=False, crl_sign=False, + encipher_only=False, decipher_only=False, + ), critical=True, + ) + .add_extension( + x509.ExtendedKeyUsage([ExtendedKeyUsageOID.CLIENT_AUTH]), critical=True, + ) + .sign(ca.key, None) + ) + return cert.public_bytes(serialization.Encoding.PEM) + + +def fingerprint_of(cert_pem: bytes) -> str: + """SHA256 fingerprint as 'SHA256:' — stored on executors.cert_fingerprint.""" + cert = x509.load_pem_x509_certificate(cert_pem) + return f"SHA256:{cert.fingerprint(hashes.SHA256()).hex()}" + + +def ensure_server_cert(ca: CABundle, ca_dir: Path, + hostname: str = "dlw-controller") -> tuple[Path, Path]: + """Idempotent: load or generate server-cert.pem + server-key.pem (chmod 600). + CN = hostname. SAN = DNS:localhost, DNS:, IP:127.0.0.1, IP:::1. + TTL 10 years. EKU = SERVER_AUTH. Returns (cert_path, key_path).""" + cert_path = ca_dir / "server-cert.pem" + key_path = ca_dir / "server-key.pem" + if cert_path.exists() and key_path.exists(): + return cert_path, key_path + + key = ed25519.Ed25519PrivateKey.generate() + now = _dt.datetime.now(_dt.UTC) + cert = (x509.CertificateBuilder() + .subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, hostname)])) + .issuer_name(ca.cert.subject) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now) + .not_valid_after(now + _dt.timedelta(days=3650)) + .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True) + .add_extension( + x509.SubjectAlternativeName([ + x509.DNSName("localhost"), + x509.DNSName(hostname), + x509.IPAddress(ipaddress.ip_address("127.0.0.1")), + x509.IPAddress(ipaddress.ip_address("::1")), + ]), critical=False, + ) + .add_extension( + x509.ExtendedKeyUsage([ExtendedKeyUsageOID.SERVER_AUTH]), critical=True, + ) + .sign(ca.key, None) + ) + ca_dir.mkdir(mode=0o700, parents=True, exist_ok=True) + cert_path.write_bytes(cert.public_bytes(serialization.Encoding.PEM)) + cert_path.chmod(0o600) + key_path.write_bytes(key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + )) + key_path.chmod(0o600) + return cert_path, key_path diff --git a/src/dlw/auth/executor_epoch.py b/src/dlw/auth/executor_epoch.py index a476264..22f19b0 100644 --- a/src/dlw/auth/executor_epoch.py +++ b/src/dlw/auth/executor_epoch.py @@ -1,43 +1,38 @@ -"""require_executor_epoch — Phase 2 W1 fence-token dependency. +"""require_executor_epoch — W1 fence-token dep, refactored for W3a §3.4. -Reads X-Executor-Epoch header + executor_id path param; looks up the executor -in DB; returns the row to the handler if epoch matches. 401 EPOCH_MISMATCH -otherwise. - -Compose with require_bearer at the route level — both run; order doesn't -matter (different concerns). +Under W3a it chains under require_executor_jwt: the Executor row is already +loaded + authenticated via the mTLS cert fingerprint. This dep adds: + 1. path executor_id MUST equal the mTLS-authenticated identity + (confused-deputy guard); + 2. X-Executor-Epoch MUST match the authenticated row's epoch (W1 fence). """ from __future__ import annotations from fastapi import Depends, Header, HTTPException, Path -from sqlalchemy.ext.asyncio import AsyncSession -from dlw.api.tasks import _session +from dlw.auth.executor_jwt_dep import require_executor_jwt from dlw.db.models.executor import Executor async def require_executor_epoch( executor_id: str = Path(..., description="Executor id from URL path"), - # W6-D: accept Optional so dep body can raise 401 (not FastAPI auto-422) on missing x_executor_epoch: int | None = Header(default=None, alias="X-Executor-Epoch"), - session: AsyncSession = Depends(_session), + ex: Executor = Depends(require_executor_jwt), ) -> Executor: - """Return the Executor row if header matches stored epoch; else 401/404.""" - if x_executor_epoch is None: + """Return the mTLS+JWT-authenticated Executor row if path id matches and + epoch header matches; else 401.""" + if executor_id != ex.id: raise HTTPException( status_code=401, - detail="missing X-Executor-Epoch header", + detail={"code": "EXECUTOR_ID_MISMATCH", + "path": executor_id, "authenticated": ex.id}, ) - ex = await session.get(Executor, executor_id) - if ex is None: - raise HTTPException(status_code=404, detail="executor not found") + if x_executor_epoch is None: + raise HTTPException(status_code=401, detail="missing X-Executor-Epoch header") if ex.epoch != x_executor_epoch: raise HTTPException( status_code=401, - detail={ - "code": "EPOCH_MISMATCH", - "expected": ex.epoch, - "got": x_executor_epoch, - }, + detail={"code": "EPOCH_MISMATCH", "expected": ex.epoch, + "got": x_executor_epoch}, ) return ex diff --git a/src/dlw/auth/executor_jwt_dep.py b/src/dlw/auth/executor_jwt_dep.py new file mode 100644 index 0000000..a40ed35 --- /dev/null +++ b/src/dlw/auth/executor_jwt_dep.py @@ -0,0 +1,26 @@ +"""Executor JWT dependency (Phase 2 W3a §3.4).""" +from __future__ import annotations + +import jwt as _pyjwt +from fastapi import Depends, Header, HTTPException, Request + +from dlw.auth.executor_mtls import require_executor_mtls +from dlw.auth.jwt_signing import verify +from dlw.db.models.executor import Executor + + +async def require_executor_jwt( + request: Request, + authorization: str | None = Header(default=None), + ex: Executor = Depends(require_executor_mtls), +) -> Executor: + if not authorization or not authorization.startswith("Bearer "): + raise HTTPException(401, detail="missing executor JWT") + token = authorization.split(" ", 1)[1] + try: + claims = verify(request.app.state.jwt_keypair, token) + except _pyjwt.PyJWTError as e: + raise HTTPException(401, detail=f"invalid JWT: {e}") from e + if claims["sub"] != ex.id: + raise HTTPException(401, detail="JWT sub mismatch") + return ex diff --git a/src/dlw/auth/executor_mtls.py b/src/dlw/auth/executor_mtls.py new file mode 100644 index 0000000..0c118ff --- /dev/null +++ b/src/dlw/auth/executor_mtls.py @@ -0,0 +1,65 @@ +"""mTLS peer-cert dependency (Phase 2 W3a §3.4).""" +from __future__ import annotations + +import os + +from cryptography import x509 +from cryptography.hazmat.primitives import serialization +from fastapi import Depends, HTTPException, Request +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from dlw.api.tasks import _session +from dlw.auth.ca import fingerprint_of +from dlw.db.models.executor import Executor + + +def _extract_peer_cert(request: Request) -> bytes | None: + """Two paths: (a) direct uvicorn TLS — transport injected into scope by the + HttpToolsProtocol patch in dlw.auth.uvicorn_tls_patch; (b) trusted-proxy + forwarded header — only honored when DLW_TLS_TRUSTED_PROXY=1. + + uvicorn (httptools) does not expose the asyncio transport in the ASGI scope + natively. The ``install_transport_scope_patch()`` call in ``main.lifespan`` + monkey-patches ``HttpToolsProtocol.on_headers_complete`` to inject + ``scope["transport"]`` before the request is dispatched. From the + transport we reach ``ssl_object.getpeercert(binary_form=True)`` → DER bytes. + """ + transport = request.scope.get("transport") + if transport is not None and hasattr(transport, "get_extra_info"): + # Prefer ssl_object.getpeercert(binary_form=True) — returns DER bytes. + # transport.get_extra_info("peercert") returns a dict (parsed), not DER. + ssl_obj = transport.get_extra_info("ssl_object") + if ssl_obj is not None: + try: + peercert_der = ssl_obj.getpeercert(binary_form=True) + if peercert_der: + cert = x509.load_der_x509_certificate(peercert_der) + return cert.public_bytes(serialization.Encoding.PEM) + except Exception: + pass + if os.environ.get("DLW_TLS_TRUSTED_PROXY") == "1": + header = request.headers.get("X-Client-Cert-PEM") + if header: + return header.replace("\\n", "\n").encode("utf-8") + return None + + +async def require_executor_mtls( + request: Request, + session: AsyncSession = Depends(_session), +) -> Executor: + """Validate mTLS peer cert + look up executor by fingerprint.""" + cert_pem = _extract_peer_cert(request) + if cert_pem is None: + raise HTTPException(401, detail="missing or invalid mTLS peer cert") + try: + fp = fingerprint_of(cert_pem) + except Exception as e: + raise HTTPException(401, detail=f"invalid client cert: {e}") from e + ex = (await session.execute( + select(Executor).where(Executor.cert_fingerprint == fp) + )).scalar_one_or_none() + if ex is None: + raise HTTPException(401, detail="cert fingerprint not registered") + return ex diff --git a/src/dlw/auth/hmac_heartbeat_dep.py b/src/dlw/auth/hmac_heartbeat_dep.py new file mode 100644 index 0000000..813b670 --- /dev/null +++ b/src/dlw/auth/hmac_heartbeat_dep.py @@ -0,0 +1,36 @@ +"""HMAC heartbeat dependency (Phase 2 W3a §3.4).""" +from __future__ import annotations + +import time + +from fastapi import Depends, Header, HTTPException, Request + +from dlw.auth.executor_jwt_dep import require_executor_jwt +from dlw.auth.hmac_nonce import verify_hmac +from dlw.db.models.executor import Executor + +_TIMESTAMP_SKEW_SECONDS = 300 + + +async def require_hmac_heartbeat( + request: Request, + x_hmac_timestamp: int = Header(..., alias="X-HMAC-Timestamp"), + x_hmac_nonce: str = Header(..., alias="X-HMAC-Nonce"), + x_hmac_signature: str = Header(..., alias="X-HMAC-Signature"), + ex: Executor = Depends(require_executor_jwt), +) -> Executor: + now = int(time.time()) + if abs(now - x_hmac_timestamp) > _TIMESTAMP_SKEW_SECONDS: + raise HTTPException(401, detail="CLOCK_SKEW") + store = request.app.state.nonce_store + if store.seen(x_hmac_nonce): + raise HTTPException(401, detail="REPLAY_DETECTED") + if ex.hmac_seed_encrypted is None: + raise HTTPException(401, detail="HMAC_SEED_MISSING — re-register") + hmac_seed = bytes(ex.hmac_seed_encrypted) + body = await request.body() + if not verify_hmac(hmac_seed, ts=x_hmac_timestamp, nonce=x_hmac_nonce, + body=body, signature_hex=x_hmac_signature): + raise HTTPException(401, detail="HMAC_INVALID") + store.add(x_hmac_nonce) + return ex diff --git a/src/dlw/auth/hmac_nonce.py b/src/dlw/auth/hmac_nonce.py new file mode 100644 index 0000000..7a6dfc8 --- /dev/null +++ b/src/dlw/auth/hmac_nonce.py @@ -0,0 +1,48 @@ +"""HMAC heartbeat: nonce store + signature verify (Phase 2 W3a §3.3).""" +from __future__ import annotations + +import hashlib +import hmac as _hmac +import time +from collections import OrderedDict + + +class NonceStore: + """In-process LRU with timestamp-based eviction. asyncio single-threaded — + no lock needed. Restart loses state; replay defense is bounded by the + ±5min timestamp window enforced at the dependency layer.""" + + def __init__(self, *, maxsize: int = 10_000, ttl_seconds: int = 300) -> None: + self._maxsize = maxsize + self._ttl = ttl_seconds + self._data: OrderedDict[str, float] = OrderedDict() + + def _evict_expired(self) -> None: + cutoff = time.monotonic() - self._ttl + while self._data: + k, v = next(iter(self._data.items())) + if v >= cutoff: + break + self._data.popitem(last=False) + + def seen(self, nonce: str) -> bool: + self._evict_expired() + return nonce in self._data + + def add(self, nonce: str) -> None: + self._evict_expired() + if len(self._data) >= self._maxsize: + self._data.popitem(last=False) + self._data[nonce] = time.monotonic() + + +def compute_hmac(hmac_seed: bytes, *, ts: int, nonce: str, body: bytes) -> str: + """HMAC-SHA256(hmac_seed, f'{ts}:{nonce}:'.encode() + body). Hex string.""" + msg = f"{ts}:{nonce}:".encode("utf-8") + body + return _hmac.new(hmac_seed, msg, hashlib.sha256).hexdigest() + + +def verify_hmac(hmac_seed: bytes, *, ts: int, nonce: str, body: bytes, + signature_hex: str) -> bool: + expected = compute_hmac(hmac_seed, ts=ts, nonce=nonce, body=body) + return _hmac.compare_digest(expected, signature_hex) diff --git a/src/dlw/auth/jwt_signing.py b/src/dlw/auth/jwt_signing.py new file mode 100644 index 0000000..17b25dc --- /dev/null +++ b/src/dlw/auth/jwt_signing.py @@ -0,0 +1,67 @@ +"""Ed25519 JWT signing for executor JWTs (Phase 2 W3a §3.2).""" +from __future__ import annotations + +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import jwt as _pyjwt +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ed25519 + + +@dataclass(frozen=True) +class JWTKeypair: + priv_pem: bytes + pub_pem: bytes + + +def bootstrap_keypair(ca_dir: Path) -> JWTKeypair: + """Idempotent: load or generate jwt-signing.pem (chmod 600, PKCS8 Ed25519).""" + priv_path = ca_dir / "jwt-signing.pem" + if priv_path.exists(): + priv_pem = priv_path.read_bytes() + priv = serialization.load_pem_private_key(priv_pem, password=None) + if not isinstance(priv, ed25519.Ed25519PrivateKey): + raise ValueError("JWT signing key is not Ed25519") + else: + priv = ed25519.Ed25519PrivateKey.generate() + priv_pem = priv.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + ca_dir.mkdir(mode=0o700, parents=True, exist_ok=True) + priv_path.write_bytes(priv_pem) + priv_path.chmod(0o600) + pub_pem = priv.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + return JWTKeypair(priv_pem=priv_pem, pub_pem=pub_pem) + + +def sign(kp: JWTKeypair, *, executor_id: str, epoch: int, + scopes: list[str], ttl_seconds: int = 3600) -> str: + """Sign an executor JWT. Returns compact JWS.""" + now = int(time.time()) + claims = { + "iss": "dlw-controller", + "sub": executor_id, + "epoch": epoch, + "scope": " ".join(scopes), + "iat": now, + "exp": now + ttl_seconds, + } + return _pyjwt.encode(claims, kp.priv_pem.decode("utf-8"), algorithm="EdDSA") + + +def verify(kp: JWTKeypair, token: str) -> dict[str, Any]: + """Decode + verify. Raises jwt.PyJWTError on any failure.""" + return _pyjwt.decode( + token, kp.pub_pem.decode("utf-8"), + algorithms=["EdDSA"], + issuer="dlw-controller", + options={"require": ["sub", "epoch", "scope", "exp", "iss", "iat"]}, + ) diff --git a/src/dlw/auth/uvicorn_tls_patch.py b/src/dlw/auth/uvicorn_tls_patch.py new file mode 100644 index 0000000..c8e0a90 --- /dev/null +++ b/src/dlw/auth/uvicorn_tls_patch.py @@ -0,0 +1,50 @@ +"""uvicorn transport injection for mTLS peer-cert extraction (Phase 2 W3a). + +uvicorn (httptools backend) does not expose the asyncio transport in the ASGI +``scope`` dict. This module monkey-patches ``HttpToolsProtocol.on_headers_complete`` +to inject ``scope["transport"] = self.transport`` before dispatching each request, +making the SSL transport available to ``_extract_peer_cert`` in executor_mtls.py. + +Call ``install_transport_scope_patch()`` once at application startup (from +``main.lifespan``). The patch is idempotent. +""" +from __future__ import annotations + +import logging + +logger = logging.getLogger(__name__) + +_patched = False + + +def install_transport_scope_patch() -> None: + """Idempotently monkey-patch HttpToolsProtocol to inject the asyncio + transport into the ASGI scope so mTLS peer certs are accessible.""" + global _patched + if _patched: + return + try: + from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol + + _orig = HttpToolsProtocol.on_headers_complete + + def _patched_on_headers_complete(self) -> None: # type: ignore[misc] + _orig(self) + # self.scope is the dict passed (by reference) to RequestResponseCycle + # and ultimately to the ASGI app. Injecting the transport here makes + # it available as request.scope["transport"] inside FastAPI endpoints. + if self.scope is not None: + self.scope["transport"] = self.transport + + HttpToolsProtocol.on_headers_complete = _patched_on_headers_complete + _patched = True + logger.debug("installed uvicorn transport scope patch for mTLS peer-cert extraction") + except ImportError: + # httptools not available — uvicorn falls back to h11. + # h11 path is rare (httptools is a dependency of uvicorn[standard]) + # and not supported for direct-TLS peer cert extraction. + logger.warning( + "httptools not available; uvicorn transport scope patch not installed. " + "mTLS peer-cert extraction via direct TLS will not work. " + "Use DLW_TLS_TRUSTED_PROXY=1 with a terminating proxy instead." + ) diff --git a/src/dlw/config.py b/src/dlw/config.py index 4f5cec4..350a58b 100644 --- a/src/dlw/config.py +++ b/src/dlw/config.py @@ -30,6 +30,12 @@ class Settings(BaseSettings): log_level: str = Field(default="INFO") + # Phase 2 W3a — mTLS + JWT + HMAC + ca_dir: str = Field(default="./.ca") + enrollment_token: str = Field(default="") + controller_hostname: str = Field(default="dlw-controller") + tls_trusted_proxy: bool = Field(default=False) + @property def db_url(self) -> str: auth = f"{self.db_user}:{self.db_password}" if self.db_password else self.db_user diff --git a/src/dlw/db/models/executor.py b/src/dlw/db/models/executor.py index 3a8973c..d2c0643 100644 --- a/src/dlw/db/models/executor.py +++ b/src/dlw/db/models/executor.py @@ -4,7 +4,7 @@ from datetime import datetime from typing import Any -from sqlalchemy import BigInteger, DateTime, ForeignKey, Integer, SmallInteger, String, func +from sqlalchemy import BigInteger, DateTime, ForeignKey, Integer, LargeBinary, SmallInteger, String, func from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column @@ -45,3 +45,8 @@ class Executor(Base): DateTime(timezone=True), server_default=func.now(), nullable=False ) deactivated_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + # W3a: 256-bit HMAC seed for heartbeat anti-replay. "encrypted" in the + # name is forward-compatible — Phase 2 stores raw bytes; Phase 3 wraps KMS. + hmac_seed_encrypted: Mapped[bytes | None] = mapped_column( + LargeBinary, nullable=True + ) diff --git a/src/dlw/executor/auth_lifecycle.py b/src/dlw/executor/auth_lifecycle.py new file mode 100644 index 0000000..758f8d6 --- /dev/null +++ b/src/dlw/executor/auth_lifecycle.py @@ -0,0 +1,114 @@ +"""Executor auth lifecycle: register / renew / load (Phase 2 W3a §3.11).""" +from __future__ import annotations + +import datetime as _dt +from dataclasses import dataclass +from pathlib import Path + +import httpx +import jwt as _pyjwt +from cryptography import x509 + +from dlw.executor import cert as _cert + + +@dataclass +class AuthState: + executor_id: str + epoch: int + cert_pem: bytes + key_pem: bytes + ca_chain_pem: bytes + jwt: str + jwt_exp: _dt.datetime + cert_exp: _dt.datetime + hmac_seed: bytes + cert_dir: Path + + +def _parse_jwt_exp(token: str) -> _dt.datetime: + claims = _pyjwt.decode(token, options={"verify_signature": False}) + return _dt.datetime.fromtimestamp(claims["exp"], tz=_dt.UTC) + + +def _parse_cert_exp(cert_pem: bytes) -> _dt.datetime: + return x509.load_pem_x509_certificate(cert_pem).not_valid_after_utc + + +async def register(*, controller_url: str, ca_bundle_path: str | None, + enrollment_token: str, executor_id: str, host_id: str, + capabilities: dict, cert_dir: Path) -> AuthState: + csr_pem, key_pem = _cert.build_csr(executor_id) + verify = ca_bundle_path if ca_bundle_path else True + async with httpx.AsyncClient(verify=verify) as c: + r = await c.post(f"{controller_url}/api/v1/executors/register", json={ + "host_id": host_id, "executor_id_proposal": executor_id, + "capabilities": capabilities, + "client_csr_pem": csr_pem.decode("utf-8"), + }, headers={"X-Enrollment-Token": enrollment_token}) + r.raise_for_status() + body = r.json() + cert_pem = body["client_cert_pem"].encode("utf-8") + ca_chain_pem = "\n".join(body["ca_chain"]).encode("utf-8") + hmac_seed = bytes.fromhex(body["hmac_seed_hex"]) + _cert.persist(cert_dir, cert_pem=cert_pem, key_pem=key_pem, + ca_chain_pem=ca_chain_pem, hmac_seed=hmac_seed) + return AuthState( + executor_id=body["executor_id"], epoch=body["epoch"], + cert_pem=cert_pem, key_pem=key_pem, ca_chain_pem=ca_chain_pem, + jwt=body["executor_jwt"], jwt_exp=_parse_jwt_exp(body["executor_jwt"]), + cert_exp=_parse_cert_exp(cert_pem), hmac_seed=hmac_seed, + cert_dir=cert_dir, + ) + + +async def renew(state: AuthState, *, controller_url: str) -> AuthState: + """POST /{eid}/renew over mTLS. Include a fresh CSR iff cert TTL < 1h.""" + now = _dt.datetime.now(_dt.UTC) + payload: dict = {} + new_key_pem = state.key_pem + if state.cert_exp - now < _dt.timedelta(hours=1): + csr_pem, new_key_pem = _cert.build_csr(state.executor_id) + payload["client_csr_pem"] = csr_pem.decode("utf-8") + cert_file = state.cert_dir / "client-cert.pem" + key_file = state.cert_dir / "client-key.pem" + async with httpx.AsyncClient( + verify=str(state.cert_dir / "ca-chain.pem"), + cert=(str(cert_file), str(key_file)), + headers={"Authorization": f"Bearer {state.jwt}"}, + ) as c: + r = await c.post( + f"{controller_url}/api/v1/executors/{state.executor_id}/renew", + json=payload, + ) + r.raise_for_status() + body = r.json() + new_jwt = body["executor_jwt"] + cert_pem = state.cert_pem + cert_exp = state.cert_exp + if body.get("client_cert_pem"): + cert_pem = body["client_cert_pem"].encode("utf-8") + cert_exp = _parse_cert_exp(cert_pem) + _cert.persist(state.cert_dir, cert_pem=cert_pem, key_pem=new_key_pem, + ca_chain_pem=state.ca_chain_pem, hmac_seed=state.hmac_seed) + return AuthState( + executor_id=state.executor_id, epoch=state.epoch, + cert_pem=cert_pem, key_pem=new_key_pem, ca_chain_pem=state.ca_chain_pem, + jwt=new_jwt, jwt_exp=_parse_jwt_exp(new_jwt), cert_exp=cert_exp, + hmac_seed=state.hmac_seed, cert_dir=state.cert_dir, + ) + + +async def load_or_register(*, cert_dir: Path, controller_url: str, + ca_bundle_path: str | None, enrollment_token: str, + executor_id: str, host_id: str, + capabilities: dict) -> AuthState: + """W3a simplification: always re-register on startup. The JWT is never + persisted, so a "load + renew" path would still need a valid JWT. Re-register + is idempotent (epoch bumps — correct W1 fence semantics). The renew loop + handles in-process refresh; restart goes through register.""" + return await register( + controller_url=controller_url, ca_bundle_path=ca_bundle_path, + enrollment_token=enrollment_token, executor_id=executor_id, + host_id=host_id, capabilities=capabilities, cert_dir=cert_dir, + ) diff --git a/src/dlw/executor/cert.py b/src/dlw/executor/cert.py new file mode 100644 index 0000000..d6dd6bb --- /dev/null +++ b/src/dlw/executor/cert.py @@ -0,0 +1,54 @@ +"""Executor-side cert helpers (Phase 2 W3a §3.10).""" +from __future__ import annotations + +from pathlib import Path + +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import ed25519 +from cryptography.x509.oid import NameOID + + +def build_csr(executor_id: str) -> tuple[bytes, bytes]: + """Generate an Ed25519 keypair + CSR (CN=executor_id). + Returns (csr_pem, private_key_pem).""" + key = ed25519.Ed25519PrivateKey.generate() + csr = (x509.CertificateSigningRequestBuilder() + .subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, executor_id)])) + .sign(key, None)) + csr_pem = csr.public_bytes(serialization.Encoding.PEM) + key_pem = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + return csr_pem, key_pem + + +def persist(cert_dir: Path, *, cert_pem: bytes, key_pem: bytes, + ca_chain_pem: bytes, hmac_seed: bytes) -> None: + """Write client-cert.pem / client-key.pem / ca-chain.pem / hmac-seed + (chmod 600) into cert_dir (chmod 700).""" + cert_dir.mkdir(mode=0o700, parents=True, exist_ok=True) + for name, data in [ + ("client-cert.pem", cert_pem), ("client-key.pem", key_pem), + ("ca-chain.pem", ca_chain_pem), ("hmac-seed", hmac_seed), + ]: + p = cert_dir / name + p.write_bytes(data) + p.chmod(0o600) + + +def load(cert_dir: Path) -> tuple[bytes, bytes, bytes, bytes] | None: + """Return (cert_pem, key_pem, ca_chain_pem, hmac_seed) or None if absent.""" + paths = [cert_dir / n for n in + ("client-cert.pem", "client-key.pem", "ca-chain.pem", "hmac-seed")] + if not all(p.exists() for p in paths): + return None + return tuple(p.read_bytes() for p in paths) # type: ignore[return-value] + + +def fingerprint(cert_pem: bytes) -> str: + """SHA256: — same format as dlw.auth.ca.fingerprint_of.""" + cert = x509.load_pem_x509_certificate(cert_pem) + return f"SHA256:{cert.fingerprint(hashes.SHA256()).hex()}" diff --git a/src/dlw/executor/cli.py b/src/dlw/executor/cli.py index 5f41498..c464a1f 100644 --- a/src/dlw/executor/cli.py +++ b/src/dlw/executor/cli.py @@ -36,9 +36,9 @@ async def _async_main(args: argparse.Namespace) -> int: format="%(asctime)s %(name)s %(levelname)s: %(message)s", ) settings = ExecutorSettings() - client = ControllerClient( - base_url=settings.controller_url, bearer_token=settings.bearer_token, - ) + # W3a: the client starts without an AuthState; ExecutorRunner.run() does + # load_or_register and then calls client.update_auth() before any request. + client = ControllerClient(base_url=settings.controller_url) stream = HfS3StreamDownloader(settings=settings) chunk = DirectOffsetDownloader(settings=settings) runner = ExecutorRunner( diff --git a/src/dlw/executor/client.py b/src/dlw/executor/client.py index f8fb7b1..9f10a07 100644 --- a/src/dlw/executor/client.py +++ b/src/dlw/executor/client.py @@ -1,13 +1,19 @@ """HTTP client wrapping the controller's executor + subtask endpoints. -Phase 2 W1 additions: - - Persists the executor's current epoch from /join response. - - Attaches X-Executor-Epoch header on heartbeat / poll / report. - - Caller (runner) should observe `current_epoch()` and react to 401 - EPOCH_MISMATCH by calling join() again. +Phase 2 W3a: auth is mTLS (client cert) + executor JWT + (heartbeat only) HMAC. +The client is driven by an AuthState (mutable — the runner's renew loop swaps +it in place via update_auth). Each request builds an httpx.AsyncClient with +verify= + cert=(, ) and an +Authorization: Bearer header. heartbeat additionally signs the body +with HMAC-SHA256 over the executor's hmac_seed. + +The W1 /join + epoch-persistence logic is gone — registration happens via +dlw.executor.auth_lifecycle, and the epoch lives on the AuthState. """ from __future__ import annotations +import secrets +import time import uuid from typing import Any, Self @@ -19,6 +25,9 @@ wait_exponential, ) +from dlw.auth.hmac_nonce import compute_hmac +from dlw.executor.auth_lifecycle import AuthState + _retry = retry( retry=retry_if_exception_type( @@ -31,69 +40,108 @@ class ControllerClient: - """Async HTTP client for controller endpoints (executor side).""" + """Async HTTP client for controller endpoints (executor side). + + Auth is carried by an AuthState. The runner's renew loop updates it in + place via update_auth(); each request reads the current state. + """ def __init__( self, base_url: str, - bearer_token: str, + auth_state: AuthState | None = None, timeout_seconds: float = 30.0, _transport: httpx.AsyncBaseTransport | None = None, + _extra_test_headers: dict[str, str] | None = None, ) -> None: + # auth_state may be None at construction (cli.py builds the client + # before run() does load_or_register); the runner calls update_auth() + # before any request is made. Requests crash if auth is still None. + # + # _extra_test_headers is a test-only seam: when an ASGI transport is + # injected (no real TLS), tests pass {"X-Client-Cert-PEM": ...} so the + # controller's trusted-proxy mTLS path can authenticate. Production + # leaves it None and relies on real mTLS via cert=(...). self._base_url = base_url.rstrip("/") - self._headers = {"Authorization": f"Bearer {bearer_token}"} - self._client = httpx.AsyncClient( - base_url=self._base_url, - headers=self._headers, - timeout=timeout_seconds, - transport=_transport, - ) - self._epoch: int | None = None # P2-W1 + self._auth = auth_state + self._timeout = timeout_seconds + self._transport = _transport + self._extra_test_headers = _extra_test_headers or {} async def __aenter__(self) -> Self: return self async def __aexit__(self, *exc_info: Any) -> None: - await self._client.aclose() - - def current_epoch(self) -> int | None: - """Returns the most recent epoch from /join, or None if not joined yet.""" - return self._epoch + # Per-request clients are created + closed inside each call; nothing to + # close here. Kept for API compatibility with the W1 context-manager use. + return None + + def update_auth(self, new_state: AuthState) -> None: + """Swap the AuthState (called by the runner's renew loop).""" + self._auth = new_state + + def current_epoch(self) -> int: + """Current executor epoch from the AuthState.""" + return self._auth.epoch + + def _make_client(self) -> httpx.AsyncClient: + """Build a per-request httpx client. When a test transport is injected, + cert/verify are skipped (MockTransport short-circuits the network).""" + if self._transport is not None: + return httpx.AsyncClient( + base_url=self._base_url, timeout=self._timeout, + transport=self._transport, + ) + cert_dir = self._auth.cert_dir + return httpx.AsyncClient( + base_url=self._base_url, timeout=self._timeout, + verify=str(cert_dir / "ca-chain.pem"), + cert=(str(cert_dir / "client-cert.pem"), + str(cert_dir / "client-key.pem")), + ) - def _epoch_headers(self) -> dict[str, str]: - if self._epoch is None: - return {} - return {"X-Executor-Epoch": str(self._epoch)} + def _auth_headers(self) -> dict[str, str]: + return { + "Authorization": f"Bearer {self._auth.jwt}", + "X-Executor-Epoch": str(self._auth.epoch), + **self._extra_test_headers, + } - async def _post( + async def _post_json( self, path: str, - json_body: dict[str, Any] | None = None, - extra_headers: dict[str, str] | None = None, + json_body: dict[str, Any] | None, + extra_headers: dict[str, str], ) -> dict[str, Any]: - headers = {**(extra_headers or {})} - @_retry async def _do() -> httpx.Response: - r = await self._client.post(path, json=json_body, headers=headers) - if 500 <= r.status_code < 600: - r.raise_for_status() - return r + async with self._make_client() as client: + r = await client.post(path, json=json_body, headers=extra_headers) + if 500 <= r.status_code < 600: + r.raise_for_status() + return r r = await _do() r.raise_for_status() return r.json() - async def join( - self, *, executor_id: str, host_id: str, capabilities: dict[str, Any] + async def _post_content( + self, + path: str, + content: bytes, + extra_headers: dict[str, str], ) -> dict[str, Any]: - body = await self._post("/api/v1/executors/join", { - "id": executor_id, "host_id": host_id, "capabilities": capabilities, - }) - epoch = body.get("epoch") - if isinstance(epoch, int): - self._epoch = epoch - return body + @_retry + async def _do() -> httpx.Response: + async with self._make_client() as client: + r = await client.post(path, content=content, headers=extra_headers) + if 500 <= r.status_code < 600: + r.raise_for_status() + return r + + r = await _do() + r.raise_for_status() + return r.json() async def heartbeat( self, @@ -103,22 +151,33 @@ async def heartbeat( parts_dir_bytes: int, disk_free_gb: int | None = None, ) -> dict[str, Any]: - body: dict[str, Any] = { + """POST /heartbeat with mTLS + JWT + HMAC. The body is sent as raw + content (not json=) so the HMAC signature covers the exact bytes.""" + import json as _json + body_dict: dict[str, Any] = { "health_score": health_score, "parts_dir_bytes": parts_dir_bytes, } if disk_free_gb is not None: - body["disk_free_gb"] = disk_free_gb - return await self._post( - f"/api/v1/executors/{executor_id}/heartbeat", - body, - extra_headers=self._epoch_headers(), + body_dict["disk_free_gb"] = disk_free_gb + body = _json.dumps(body_dict).encode("utf-8") + ts = int(time.time()) + nonce = secrets.token_hex(16) + sig = compute_hmac(self._auth.hmac_seed, ts=ts, nonce=nonce, body=body) + headers = { + **self._auth_headers(), + "Content-Type": "application/json", + "X-HMAC-Timestamp": str(ts), + "X-HMAC-Nonce": nonce, + "X-HMAC-Signature": sig, + } + return await self._post_content( + f"/api/v1/executors/{executor_id}/heartbeat", body, headers, ) async def poll(self, *, executor_id: str) -> dict[str, Any]: - return await self._post( - f"/api/v1/executors/{executor_id}/poll", - extra_headers=self._epoch_headers(), + return await self._post_json( + f"/api/v1/executors/{executor_id}/poll", None, self._auth_headers(), ) async def report( @@ -144,8 +203,6 @@ async def report( body["error"] = error if s3_key is not None: body["s3_key"] = s3_key - return await self._post( - f"/api/v1/subtasks/{subtask_id}/report", - body, - extra_headers=self._epoch_headers(), + return await self._post_json( + f"/api/v1/subtasks/{subtask_id}/report", body, self._auth_headers(), ) diff --git a/src/dlw/executor/config.py b/src/dlw/executor/config.py index af85550..6ae13bf 100644 --- a/src/dlw/executor/config.py +++ b/src/dlw/executor/config.py @@ -63,6 +63,20 @@ class ExecutorSettings(BaseSettings): description="Local staging dir for chunk-level downloads. Configure to a writable PV in prod.", ) + # Phase 2 W3a — mTLS + JWT auth + enrollment_token: str = Field( + default="", + description="OOB enrollment token from the operator (controller's DLW_ENROLLMENT_TOKEN).", + ) + executor_cert_dir: str = Field( + default="./.executor-certs", + description="Local dir for client-cert.pem / client-key.pem / ca-chain.pem / hmac-seed.", + ) + executor_ca_bundle: str = Field( + default="", + description="CA bundle path for httpx verify=; empty → defaults to {cert_dir}/ca-chain.pem at runtime.", + ) + @model_validator(mode="after") def _derive_host_id(self) -> "ExecutorSettings": """If host_id not set, derive from id by stripping any -worker-N suffix. diff --git a/src/dlw/executor/runner.py b/src/dlw/executor/runner.py index cdc1f65..7a1b0f9 100644 --- a/src/dlw/executor/runner.py +++ b/src/dlw/executor/runner.py @@ -1,14 +1,18 @@ -"""ExecutorRunner — async main loop joining heartbeat + poll-and-execute. +"""ExecutorRunner — async main loop: register, heartbeat, poll-and-execute, renew. -On startup: register via /join. Then runs two concurrent loops: - - Heartbeat every settings.heartbeat_interval_seconds +On startup: load-or-register via mTLS enrollment (W3a). Then runs three +concurrent loops: + - Heartbeat every settings.heartbeat_interval_seconds (mTLS + JWT + HMAC) - Poll every settings.poll_interval_seconds; if assigned, download + report + - Auth renew: refresh the JWT ~5min before expiry, the cert ~1h before expiry W3-A: shutdown does NOT cancel loops mid-iteration. The pacing wait inside -each loop reacts to _shutdown.set() instantly (asyncio.wait_for completes -immediately when the event is set). _execute_subtask completes its download -and report cycle before the loop exits — otherwise mid-flight subtasks would -be stuck in 'assigned' state forever. +each loop reacts to _shutdown.set() instantly. _execute_subtask completes its +download + report cycle before the loop exits — otherwise mid-flight subtasks +would be stuck in 'assigned' state forever. + +W3a: a 401 on poll triggers a re-register (generalizes the W1 EPOCH_MISMATCH +re-join path) — the executor's identity may have been reset controller-side. """ from __future__ import annotations @@ -16,7 +20,10 @@ import logging import shutil import uuid +from datetime import UTC, datetime, timedelta +from pathlib import Path +from dlw.executor.auth_lifecycle import AuthState, load_or_register, renew from dlw.executor.chunk_downloader import DirectOffsetDownloader, DiskFullError from dlw.executor.client import ControllerClient from dlw.executor.config import ExecutorSettings @@ -34,11 +41,14 @@ def __init__( client: ControllerClient, stream_downloader: HfS3StreamDownloader, chunk_downloader: DirectOffsetDownloader, + auth_state: AuthState | None = None, ) -> None: self._s = settings self._client = client self._stream_downloader = stream_downloader self._chunk_downloader = chunk_downloader + # When auth_state is provided (tests), run() skips load_or_register. + self._auth = auth_state self._shutdown = asyncio.Event() def _choose_downloader(self, file_size: int | None): @@ -50,30 +60,37 @@ def _choose_downloader(self, file_size: int | None): def request_shutdown(self) -> None: self._shutdown.set() + def _capabilities(self) -> dict: + return {"nic_speed_gbps": self._s.nic_speed_gbps, "region": self._s.region} + async def run(self) -> None: # W2b1 §3.2: clean up any stale .parts/ dirs from a previous crash. - # active_subtask_ids=set() removes everything — W2b1 has no resume. removed = startup_gc(self._s.parts_dir_path, active_subtask_ids=set()) if removed: logger.info("startup_gc removed %d stale parts dirs", removed) - # 1. Join (one-shot) - await self._client.join( - executor_id=self._s.id, - host_id=self._s.host_id, - capabilities={ - "nic_speed_gbps": self._s.nic_speed_gbps, - "region": self._s.region, - }, - ) + # W3a §3.13: auth bootstrap — load existing cert or register fresh. + if self._auth is None: + self._auth = await load_or_register( + cert_dir=Path(self._s.executor_cert_dir), + controller_url=self._s.controller_url, + ca_bundle_path=self._s.executor_ca_bundle or None, + enrollment_token=self._s.enrollment_token, + executor_id=self._s.id, + host_id=self._s.host_id, + capabilities=self._capabilities(), + ) + self._client.update_auth(self._auth) - # 2. Concurrent loops — both check self._shutdown.is_set() each iteration + # Three concurrent loops — each checks self._shutdown.is_set(). heartbeat_task = asyncio.create_task(self._heartbeat_loop()) poll_task = asyncio.create_task(self._poll_and_execute_loop()) + renew_task = asyncio.create_task(self._auth_renew_loop()) - # 3. Wait for shutdown signal then let loops exit naturally (W3-A) await self._shutdown.wait() - await asyncio.gather(heartbeat_task, poll_task, return_exceptions=True) + await asyncio.gather( + heartbeat_task, poll_task, renew_task, return_exceptions=True, + ) async def _heartbeat_loop(self) -> None: while not self._shutdown.is_set(): @@ -117,19 +134,14 @@ async def _poll_and_execute_loop(self) -> None: ) continue # immediately poll again — there may be more work except _httpx.HTTPStatusError as e: + # W3a: any 401 on poll → re-register. The executor identity may + # have been reset controller-side (epoch bump, cert rotation, + # or the row cleared). load_or_register is idempotent. if e.response.status_code == 401: - detail = None - try: - detail = e.response.json().get("detail") - except Exception: - pass - if isinstance(detail, dict) and detail.get("code") == "EPOCH_MISMATCH": - logger.warning( - "EPOCH_MISMATCH (expected=%s got=%s); re-joining", - detail.get("expected"), detail.get("got"), - ) - await self._rejoin() - continue + logger.warning("poll 401 (%s); re-registering", + _safe_detail(e)) + await self._reregister() + continue logger.warning("poll failed: %s", e) except Exception as e: logger.warning("poll failed: %s", e) @@ -141,19 +153,42 @@ async def _poll_and_execute_loop(self) -> None: except asyncio.TimeoutError: pass - async def _rejoin(self) -> None: - """Discard any in-flight state and re-issue /join (gets new epoch).""" + async def _auth_renew_loop(self) -> None: + """W3a §3.13: refresh the JWT ~5min before expiry, the cert ~1h before.""" + while not self._shutdown.is_set(): + assert self._auth is not None + now = datetime.now(UTC) + jwt_due = self._auth.jwt_exp - timedelta(minutes=5) + cert_due = self._auth.cert_exp - timedelta(hours=1) + sleep_for = max(60, int((min(jwt_due, cert_due) - now).total_seconds())) + try: + await asyncio.wait_for(self._shutdown.wait(), timeout=sleep_for) + return # shutdown + except asyncio.TimeoutError: + pass + try: + self._auth = await renew( + self._auth, controller_url=self._s.controller_url, + ) + self._client.update_auth(self._auth) + except Exception as e: + logger.warning("auth renew failed: %s; retry next cycle", e) + + async def _reregister(self) -> None: + """Discard in-flight auth state and re-register (fresh cert + JWT + epoch).""" try: - await self._client.join( + self._auth = await load_or_register( + cert_dir=Path(self._s.executor_cert_dir), + controller_url=self._s.controller_url, + ca_bundle_path=self._s.executor_ca_bundle or None, + enrollment_token=self._s.enrollment_token, executor_id=self._s.id, host_id=self._s.host_id, - capabilities={ - "nic_speed_gbps": self._s.nic_speed_gbps, - "region": self._s.region, - }, + capabilities=self._capabilities(), ) + self._client.update_auth(self._auth) except Exception as e: - logger.warning("rejoin failed: %s", e) + logger.warning("re-register failed: %s", e) async def _execute_subtask( self, *, subtask: dict, assignment_token: uuid.UUID, @@ -226,3 +261,10 @@ async def _execute_subtask( ) except Exception: logger.exception("report failure also failed for %s", sub_id) + + +def _safe_detail(e) -> str: + try: + return str(e.response.json().get("detail")) + except Exception: + return "" diff --git a/src/dlw/main.py b/src/dlw/main.py index b923f5c..2140901 100644 --- a/src/dlw/main.py +++ b/src/dlw/main.py @@ -30,6 +30,41 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: from dlw.services.recovery import run_recovery_routine factory = async_sessionmaker(get_engine(), expire_on_commit=False) + + # W3a: bootstrap CA + JWT signing key + server cert + nonce store + enrollment token. + # Install the uvicorn transport scope patch so _extract_peer_cert can read + # the mTLS peer cert from scope["transport"] in direct-TLS deployments. + from dlw.auth.uvicorn_tls_patch import install_transport_scope_patch + install_transport_scope_patch() + + from pathlib import Path + from dlw.auth.ca import bootstrap_ca, ensure_server_cert + from dlw.auth.jwt_signing import bootstrap_keypair + from dlw.auth.hmac_nonce import NonceStore + import secrets as _secrets + from dlw.config import get_settings as _gs + _settings = _gs() + _ca_dir = Path(_settings.ca_dir) + _ca_dir.mkdir(mode=0o700, parents=True, exist_ok=True) + _ca = bootstrap_ca(_ca_dir) + ensure_server_cert(_ca, _ca_dir, hostname=_settings.controller_hostname) + _jwt_kp = bootstrap_keypair(_ca_dir) + if _settings.enrollment_token: + _enroll = _settings.enrollment_token + else: + _tok_path = _ca_dir / "enrollment.token" + if _tok_path.exists(): + _enroll = _tok_path.read_text().strip() + else: + _enroll = _secrets.token_hex(32) + _tok_path.write_text(_enroll) + _tok_path.chmod(0o600) + logger.info("generated enrollment token (copy to executors): %s", _enroll) + app.state.ca = _ca + app.state.jwt_keypair = _jwt_kp + app.state.nonce_store = NonceStore(maxsize=10_000, ttl_seconds=300) + app.state.enrollment_token = _enroll + # W6-J: spec §7 says recovery failure aborts startup. Permissive dev mode # via DLW_STRICT_RECOVERY=false env override (defaults to strict). import os diff --git a/src/dlw/schemas/executor.py b/src/dlw/schemas/executor.py index 6589646..f7edfbb 100644 --- a/src/dlw/schemas/executor.py +++ b/src/dlw/schemas/executor.py @@ -10,19 +10,33 @@ from dlw.schemas.subtask import SubTaskRead -class ExecutorJoin(BaseModel): - """POST /api/v1/executors/join — first contact from new executor. - - NOTE (W2-H): Invariant 9 in `docs/v2.0/03-distributed-correctness.md` - requires id format `^[a-z0-9-]+-worker-\\d+$`. Week 2 does NOT enforce this - at the schema level — tests use shorter ids like 'exec-A' for brevity. - Phase 2 will add a `@field_validator("id")` enforcing the regex once mTLS - cert binding is in place (the cert CN should match the executor id). - """ - id: str = Field(min_length=1, max_length=64, examples=["host-12.local-worker-1"]) - host_id: str = Field(min_length=1, max_length=64) - cert_fingerprint: str = Field(default="placeholder-week2", max_length=128) - capabilities: dict[str, Any] = Field(default_factory=dict) +class ExecutorRegister(BaseModel): + host_id: str + executor_id_proposal: str + capabilities: dict[str, Any] = {} + client_csr_pem: str + + +class RegistrationResponse(BaseModel): + executor_id: str + epoch: int + client_cert_pem: str + ca_chain: list[str] + executor_jwt: str + hmac_seed_hex: str + cert_renew_in_seconds: int + jwt_renew_in_seconds: int + + +class RenewRequest(BaseModel): + client_csr_pem: str | None = None + + +class RenewResponse(BaseModel): + executor_jwt: str + jwt_renew_in_seconds: int + client_cert_pem: str | None = None + cert_renew_in_seconds: int | None = None class ExecutorHeartbeat(BaseModel): diff --git a/src/dlw/services/executor_service.py b/src/dlw/services/executor_service.py index 4de9868..446b110 100644 --- a/src/dlw/services/executor_service.py +++ b/src/dlw/services/executor_service.py @@ -5,6 +5,9 @@ Status resets to 'joining' on every rejoin so that 'unhealthy' (set by reclaim_stale_executors) flips back to 'joining' → 'healthy' on the next heartbeat. + +Phase 2 W3a: join_executor renamed to upsert_executor_with_cert; now accepts +cert_fingerprint + hmac_seed as explicit kwargs rather than from ExecutorJoin schema. """ from __future__ import annotations @@ -13,29 +16,37 @@ from sqlalchemy.ext.asyncio import AsyncSession from dlw.db.models.executor import Executor -from dlw.schemas.executor import ExecutorHeartbeat, ExecutorJoin - +from dlw.schemas.executor import ExecutorHeartbeat -async def join_executor(session: AsyncSession, body: ExecutorJoin) -> Executor: - """Atomic INSERT-or-bump. Returns the persisted Executor row with current epoch. - PG INSERT ... ON CONFLICT (id) DO UPDATE is atomic for the bump — two - concurrent join calls for the same id can never get the same epoch. - """ +async def upsert_executor_with_cert( + session: AsyncSession, + *, + executor_id: str, + host_id: str, + capabilities: dict, + cert_fingerprint: str, + hmac_seed: bytes, +) -> Executor: + """W3a: INSERT-or-bump executor row, writing cert_fingerprint + + hmac_seed_encrypted. Same atomic epoch semantics as W1 join_executor: + epoch=1 on insert, +1 on conflict; status='joining'. Caller commits.""" stmt = pg_insert(Executor).values( - id=body.id, - host_id=body.host_id, - cert_fingerprint=body.cert_fingerprint, - capabilities=body.capabilities, + id=executor_id, + host_id=host_id, + cert_fingerprint=cert_fingerprint, + hmac_seed_encrypted=hmac_seed, + capabilities=capabilities, status="joining", epoch=1, ).on_conflict_do_update( index_elements=["id"], set_=dict( status="joining", - host_id=body.host_id, - cert_fingerprint=body.cert_fingerprint, - capabilities=body.capabilities, + host_id=host_id, + cert_fingerprint=cert_fingerprint, + hmac_seed_encrypted=hmac_seed, + capabilities=capabilities, epoch=Executor.__table__.c.epoch + 1, ), ).returning(Executor) diff --git a/tests/api/test_executors.py b/tests/api/test_executors.py index 2639182..d6e50bb 100644 --- a/tests/api/test_executors.py +++ b/tests/api/test_executors.py @@ -1,4 +1,4 @@ -"""Tests for executors API: join / heartbeat / poll with bearer auth.""" +"""Tests for executors API: register / heartbeat / poll under mTLS + JWT + HMAC (W3a).""" from __future__ import annotations import pytest @@ -7,9 +7,15 @@ from dlw.config import get_settings from dlw.db.base import Base +from tests.conftest import ( + executor_request_headers, + register_test_executor, + signed_heartbeat_headers, +) _TOKEN = "test-bearer-token-12345" +_ENROLL = "test-enrollment-token-w3a-executors" @pytest.fixture(scope="module", autouse=True) @@ -36,7 +42,9 @@ async def _bootstrap(engine): @pytest.fixture(autouse=True) def _set_token(monkeypatch: pytest.MonkeyPatch): + """Bearer token for /tasks (UI) endpoints + trusted-proxy bypass for mTLS.""" monkeypatch.setenv("DLW_BEARER_TOKEN", _TOKEN) + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "1") get_settings.cache_clear() yield get_settings.cache_clear() @@ -44,62 +52,86 @@ def _set_token(monkeypatch: pytest.MonkeyPatch): @pytest.fixture def auth() -> dict[str, str]: + """Bearer header for /tasks (UI) endpoints — executor endpoints use mTLS+JWT.""" return {"Authorization": f"Bearer {_TOKEN}"} @pytest.fixture -async def client(): +async def client(ephemeral_ca): + """App with W3a auth state injected onto app.state from the session CA.""" from dlw.main import create_app + from dlw.auth.hmac_nonce import NonceStore app = create_app() + app.state.ca = ephemeral_ca["ca"] + app.state.jwt_keypair = ephemeral_ca["jwt_keypair"] + app.state.nonce_store = NonceStore(maxsize=1000, ttl_seconds=300) + app.state.enrollment_token = _ENROLL async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as c: yield c @pytest.fixture -async def joined_executor(client: AsyncClient, auth: dict[str, str]) -> tuple[str, int]: - """POST /join and return (executor_id, epoch). Used by tests that need fence headers.""" - r = await client.post("/api/v1/executors/join", json={ - "id": "fence-host-worker-1", "host_id": "fence-host", - }, headers=auth) - assert r.status_code == 201 - body = r.json() - return body["id"], body["epoch"] +async def registered_executor(client: AsyncClient) -> dict: + """Register a fence-test executor via /register; return the reg dict.""" + return await register_test_executor( + client, enrollment_token=_ENROLL, + executor_id="fence-host-worker-1", host_id="fence-host", + ) + + +async def _register_and_heartbeat(client, executor_id: str, host_id: str) -> dict: + """Register an executor + send one heartbeat (joining → healthy) so /poll + can claim work. Returns the reg dict.""" + reg = await register_test_executor( + client, enrollment_token=_ENROLL, + executor_id=executor_id, host_id=host_id, + ) + hb_body = b'{"health_score": 100, "parts_dir_bytes": 0, "disk_free_gb": 100}' + r = await client.post( + f"/api/v1/executors/{executor_id}/heartbeat", + content=hb_body, + headers=signed_heartbeat_headers(reg, hb_body), + ) + assert r.status_code == 200, r.text + return reg @pytest.mark.slow -async def test_executor_join_returns_201(client, auth) -> None: - r = await client.post("/api/v1/executors/join", json={ - "id": "exec-test-1", "host_id": "host-test", - }, headers=auth) - assert r.status_code == 201, r.text - body = r.json() - assert body["id"] == "exec-test-1" - assert body["status"] == "joining" +async def test_executor_register_returns_201(client) -> None: + reg = await register_test_executor( + client, enrollment_token=_ENROLL, + executor_id="exec-test-1", host_id="host-test", + ) + assert reg["executor_id"] == "exec-test-1" + assert reg["epoch"] == 1 + assert reg["cert_pem"].startswith("-----BEGIN CERTIFICATE-----") + assert reg["jwt"] + assert len(reg["hmac_seed"]) == 32 @pytest.mark.slow -async def test_executor_heartbeat_transitions_to_healthy(client, auth) -> None: - r_join = await client.post("/api/v1/executors/join", json={ - "id": "exec-hb-1", "host_id": "host-hb" - }, headers=auth) - epoch = r_join.json()["epoch"] - r = await client.post("/api/v1/executors/exec-hb-1/heartbeat", - json={"health_score": 95}, - headers={**auth, "X-Executor-Epoch": str(epoch)}) +async def test_executor_heartbeat_transitions_to_healthy(client) -> None: + reg = await register_test_executor( + client, enrollment_token=_ENROLL, + executor_id="exec-hb-1", host_id="host-hb", + ) + hb_body = b'{"health_score": 95}' + r = await client.post( + "/api/v1/executors/exec-hb-1/heartbeat", + content=hb_body, + headers=signed_heartbeat_headers(reg, hb_body), + ) assert r.status_code == 200, r.text assert r.json()["status"] == "healthy" assert r.json()["health_score"] == 95 @pytest.mark.slow -async def test_poll_returns_assigned_false_when_no_work(client, auth) -> None: - r_join = await client.post("/api/v1/executors/join", json={ - "id": "exec-poll-empty", "host_id": "host-pe" - }, headers=auth) - epoch = r_join.json()["epoch"] +async def test_poll_returns_assigned_false_when_no_work(client) -> None: + reg = await _register_and_heartbeat(client, "exec-poll-empty", "host-pe") r = await client.post("/api/v1/executors/exec-poll-empty/poll", - headers={**auth, "X-Executor-Epoch": str(epoch)}) - assert r.status_code == 200 + headers=executor_request_headers(reg)) + assert r.status_code == 200, r.text assert r.json()["assigned"] is False assert r.json()["subtask"] is None @@ -109,18 +141,9 @@ async def test_poll_returns_subtask_when_work_available(client, auth) -> None: await client.post("/api/v1/tasks", json={ "repo_id": "o/poll", "revision": "0" * 40, "storage_id": 1, }, headers=auth) - r_join = await client.post("/api/v1/executors/join", json={ - "id": "exec-poll-w", "host_id": "host-pw" - }, headers=auth) - epoch = r_join.json()["epoch"] - # W2a: claim_one_subtask requires status='healthy'|'degraded'; join sets - # 'joining'. One heartbeat transitions the executor to 'healthy'. - # W2b1: include disk_free_gb so disk pre-flight allows claiming. - await client.post("/api/v1/executors/exec-poll-w/heartbeat", - json={"health_score": 100, "parts_dir_bytes": 0, "disk_free_gb": 100}, - headers={**auth, "X-Executor-Epoch": str(epoch)}) + reg = await _register_and_heartbeat(client, "exec-poll-w", "host-pw") r = await client.post("/api/v1/executors/exec-poll-w/poll", - headers={**auth, "X-Executor-Epoch": str(epoch)}) + headers=executor_request_headers(reg)) assert r.status_code == 200, r.text body = r.json() assert body["assigned"] is True @@ -131,9 +154,30 @@ async def test_poll_returns_subtask_when_work_available(client, auth) -> None: @pytest.mark.slow async def test_unauthenticated_returns_401(client) -> None: - r = await client.post("/api/v1/executors/join", json={ - "id": "x", "host_id": "y" - }) + """No mTLS cert header → /heartbeat is rejected before the handler runs.""" + r = await client.post( + "/api/v1/executors/whoever/heartbeat", + content=b'{"health_score": 100}', + ) + assert r.status_code == 401 + + +@pytest.mark.slow +async def test_register_rejects_bad_enrollment_token(client) -> None: + """A wrong enrollment token on /register is rejected.""" + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric import ed25519 + from cryptography import x509 + from cryptography.x509.oid import NameOID + key = ed25519.Ed25519PrivateKey.generate() + csr = (x509.CertificateSigningRequestBuilder() + .subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "x")])) + .sign(key, None)) + csr_pem = csr.public_bytes(serialization.Encoding.PEM).decode("utf-8") + r = await client.post("/api/v1/executors/register", json={ + "host_id": "y", "executor_id_proposal": "x", + "capabilities": {}, "client_csr_pem": csr_pem, + }, headers={"X-Enrollment-Token": "wrong"}) assert r.status_code == 401 @@ -150,18 +194,10 @@ async def fake_hf(*args, **kwargs): # Drain any leftover pending subtasks from earlier tests so we can assert # the exact repo_id returned by this test's task. - r_drain_join = await client.post("/api/v1/executors/join", json={ - "id": "host-x-drain", "host_id": "host-x", - }, headers=auth) - drain_epoch = r_drain_join.json()["epoch"] - # W2a: claim_one_subtask requires healthy/degraded status; transition via heartbeat. - # W2b1: include disk_free_gb so disk pre-flight allows claiming during drain. - await client.post("/api/v1/executors/host-x-drain/heartbeat", - json={"health_score": 100, "parts_dir_bytes": 0, "disk_free_gb": 100}, - headers={**auth, "X-Executor-Epoch": str(drain_epoch)}) + drain = await _register_and_heartbeat(client, "host-x-drain", "host-x") for _ in range(20): # safety upper bound dr = await client.post("/api/v1/executors/host-x-drain/poll", - headers={**auth, "X-Executor-Epoch": str(drain_epoch)}) + headers=executor_request_headers(drain)) if not dr.json().get("assigned"): break @@ -173,20 +209,9 @@ async def fake_hf(*args, **kwargs): }, headers=auth) assert r.status_code == 201 - # Join an executor and transition to healthy before polling. - r_worker_join = await client.post("/api/v1/executors/join", json={ - "id": "host-x-worker-1", "host_id": "host-x", - }, headers=auth) - worker_epoch = r_worker_join.json()["epoch"] - # W2a: heartbeat transitions joining → healthy so poll can claim work. - # W2b1: include disk_free_gb so disk pre-flight allows claiming. - await client.post("/api/v1/executors/host-x-worker-1/heartbeat", - json={"health_score": 100, "parts_dir_bytes": 0, "disk_free_gb": 100}, - headers={**auth, "X-Executor-Epoch": str(worker_epoch)}) - - # Poll + worker = await _register_and_heartbeat(client, "host-x-worker-1", "host-x") pr = await client.post("/api/v1/executors/host-x-worker-1/poll", - headers={**auth, "X-Executor-Epoch": str(worker_epoch)}) + headers=executor_request_headers(worker)) assert pr.status_code == 200 body = pr.json() assert body["assigned"] is True @@ -198,26 +223,30 @@ async def fake_hf(*args, **kwargs): @pytest.mark.slow async def test_heartbeat_missing_epoch_header_returns_401( - client: AsyncClient, auth: dict[str, str], joined_executor, + client: AsyncClient, registered_executor: dict, ) -> None: - eid, _ = joined_executor + reg = registered_executor + hb_body = b'{"health_score": 100, "parts_dir_bytes": 0}' + headers = signed_heartbeat_headers(reg, hb_body) + del headers["X-Executor-Epoch"] r = await client.post( - f"/api/v1/executors/{eid}/heartbeat", - json={"health_score": 100, "parts_dir_bytes": 0}, - headers=auth, + f"/api/v1/executors/{reg['executor_id']}/heartbeat", + content=hb_body, headers=headers, ) assert r.status_code == 401 @pytest.mark.slow async def test_heartbeat_wrong_epoch_returns_EPOCH_MISMATCH( - client: AsyncClient, auth: dict[str, str], joined_executor, + client: AsyncClient, registered_executor: dict, ) -> None: - eid, epoch = joined_executor + reg = registered_executor + hb_body = b'{"health_score": 100, "parts_dir_bytes": 0}' + headers = signed_heartbeat_headers(reg, hb_body) + headers["X-Executor-Epoch"] = str(reg["epoch"] + 1) r = await client.post( - f"/api/v1/executors/{eid}/heartbeat", - json={"health_score": 100, "parts_dir_bytes": 0}, - headers={**auth, "X-Executor-Epoch": str(epoch + 1)}, + f"/api/v1/executors/{reg['executor_id']}/heartbeat", + content=hb_body, headers=headers, ) assert r.status_code == 401 assert r.json()["detail"]["code"] == "EPOCH_MISMATCH" @@ -225,44 +254,47 @@ async def test_heartbeat_wrong_epoch_returns_EPOCH_MISMATCH( @pytest.mark.slow async def test_heartbeat_correct_epoch_passes( - client: AsyncClient, auth: dict[str, str], joined_executor, + client: AsyncClient, registered_executor: dict, ) -> None: - eid, epoch = joined_executor + reg = registered_executor + hb_body = b'{"health_score": 100, "parts_dir_bytes": 0}' r = await client.post( - f"/api/v1/executors/{eid}/heartbeat", - json={"health_score": 100, "parts_dir_bytes": 0}, - headers={**auth, "X-Executor-Epoch": str(epoch)}, + f"/api/v1/executors/{reg['executor_id']}/heartbeat", + content=hb_body, headers=signed_heartbeat_headers(reg, hb_body), ) assert r.status_code == 200 assert r.json()["status"] == "healthy" @pytest.mark.slow -async def test_poll_after_rejoin_uses_new_epoch( - client: AsyncClient, auth: dict[str, str], -) -> None: - """Two consecutive joins for the same id: poll with old epoch fails.""" - r1 = await client.post("/api/v1/executors/join", json={ - "id": "rejoin-host-worker-1", "host_id": "rejoin-host", - }, headers=auth) - epoch_old = r1.json()["epoch"] - - r2 = await client.post("/api/v1/executors/join", json={ - "id": "rejoin-host-worker-1", "host_id": "rejoin-host", - }, headers=auth) - epoch_new = r2.json()["epoch"] - assert epoch_new == epoch_old + 1 +async def test_poll_after_reregister_uses_new_epoch(client: AsyncClient) -> None: + """Re-register the same executor_id: poll with the old epoch fails, new passes. + + W3a: re-register bumps epoch AND issues a fresh cert. The old cert's + fingerprint is overwritten on the executor row, so polling with the new + cert + old epoch exercises the W1 fence under the new auth chain.""" + reg1 = await register_test_executor( + client, enrollment_token=_ENROLL, + executor_id="rejoin-host-worker-1", host_id="rejoin-host", + ) + reg2 = await register_test_executor( + client, enrollment_token=_ENROLL, + executor_id="rejoin-host-worker-1", host_id="rejoin-host", + ) + assert reg2["epoch"] == reg1["epoch"] + 1 - # Old epoch rejected + # New cert + OLD epoch → EPOCH_MISMATCH. + stale = dict(reg2) + stale["epoch"] = reg1["epoch"] r3 = await client.post( "/api/v1/executors/rejoin-host-worker-1/poll", - headers={**auth, "X-Executor-Epoch": str(epoch_old)}, + headers=executor_request_headers(stale), ) assert r3.status_code == 401 - # New epoch accepted + # New cert + new epoch → accepted. r4 = await client.post( "/api/v1/executors/rejoin-host-worker-1/poll", - headers={**auth, "X-Executor-Epoch": str(epoch_new)}, + headers=executor_request_headers(reg2), ) assert r4.status_code == 200 diff --git a/tests/api/test_register_endpoint.py b/tests/api/test_register_endpoint.py new file mode 100644 index 0000000..4981a55 --- /dev/null +++ b/tests/api/test_register_endpoint.py @@ -0,0 +1,94 @@ +"""Tests for POST /api/v1/executors/register (Phase 2 W3a §3.5).""" +from __future__ import annotations + +import pytest +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ed25519 +from cryptography import x509 +from cryptography.x509.oid import NameOID +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import async_sessionmaker + +import dlw.db.models # noqa: F401 — registers all ORM models with Base.metadata +from dlw.db.base import Base + + +_ENROLL = "test-enrollment-token-w3a" + + +def _build_csr(executor_id: str) -> str: + key = ed25519.Ed25519PrivateKey.generate() + csr = (x509.CertificateSigningRequestBuilder() + .subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, executor_id)])) + .sign(key, None)) + return csr.public_bytes(serialization.Encoding.PEM).decode("utf-8") + + +@pytest.fixture(scope="module", autouse=True) +async def _create_tables(engine): + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + + +@pytest.fixture +async def client(ephemeral_ca, monkeypatch, tmp_path): + """App with W3a auth state set directly on app.state (skip the real + lifespan bootstrap — set ca/jwt_keypair/enrollment_token from ephemeral_ca).""" + from dlw.main import create_app + from dlw.auth.hmac_nonce import NonceStore + app = create_app() + app.state.ca = ephemeral_ca["ca"] + app.state.jwt_keypair = ephemeral_ca["jwt_keypair"] + app.state.nonce_store = NonceStore(maxsize=100, ttl_seconds=300) + app.state.enrollment_token = _ENROLL + async with AsyncClient(transport=ASGITransport(app=app), + base_url="http://test") as c: + yield c + + +@pytest.mark.slow +async def test_register_returns_cert_jwt_and_hmac_seed(client) -> None: + csr = _build_csr("reg-host-worker-1") + r = await client.post("/api/v1/executors/register", json={ + "host_id": "reg-host", "executor_id_proposal": "reg-host-worker-1", + "capabilities": {}, "client_csr_pem": csr, + }, headers={"X-Enrollment-Token": _ENROLL}) + assert r.status_code == 201, r.text + body = r.json() + assert body["executor_id"] == "reg-host-worker-1" + assert body["epoch"] == 1 + assert body["client_cert_pem"].startswith("-----BEGIN CERTIFICATE-----") + assert body["executor_jwt"] + assert len(bytes.fromhex(body["hmac_seed_hex"])) == 32 + assert body["ca_chain"] + + +@pytest.mark.slow +async def test_register_rejects_invalid_enrollment_token(client) -> None: + csr = _build_csr("reg-host-worker-2") + r = await client.post("/api/v1/executors/register", json={ + "host_id": "reg-host", "executor_id_proposal": "reg-host-worker-2", + "capabilities": {}, "client_csr_pem": csr, + }, headers={"X-Enrollment-Token": "wrong-token"}) + assert r.status_code == 401 + + +@pytest.mark.slow +async def test_register_idempotent_on_reregister(client) -> None: + csr1 = _build_csr("reg-host-worker-3") + r1 = await client.post("/api/v1/executors/register", json={ + "host_id": "reg-host", "executor_id_proposal": "reg-host-worker-3", + "capabilities": {}, "client_csr_pem": csr1, + }, headers={"X-Enrollment-Token": _ENROLL}) + assert r1.status_code == 201 + epoch1 = r1.json()["epoch"] + csr2 = _build_csr("reg-host-worker-3") + r2 = await client.post("/api/v1/executors/register", json={ + "host_id": "reg-host", "executor_id_proposal": "reg-host-worker-3", + "capabilities": {}, "client_csr_pem": csr2, + }, headers={"X-Enrollment-Token": _ENROLL}) + assert r2.status_code == 201 + assert r2.json()["epoch"] == epoch1 + 1 # epoch bumped on re-register diff --git a/tests/api/test_renew_endpoint.py b/tests/api/test_renew_endpoint.py new file mode 100644 index 0000000..7fc506a --- /dev/null +++ b/tests/api/test_renew_endpoint.py @@ -0,0 +1,92 @@ +"""Tests for POST /api/v1/executors/{eid}/renew (Phase 2 W3a §3.5).""" +from __future__ import annotations + +import pytest +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ed25519 +from cryptography import x509 +from cryptography.x509.oid import NameOID +from httpx import ASGITransport, AsyncClient + +import dlw.db.models # noqa: F401 — registers all ORM models with Base.metadata +from dlw.db.base import Base + + +_ENROLL = "test-enrollment-token-w3a-renew" + + +def _build_csr(executor_id: str) -> str: + key = ed25519.Ed25519PrivateKey.generate() + csr = (x509.CertificateSigningRequestBuilder() + .subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, executor_id)])) + .sign(key, None)) + return csr.public_bytes(serialization.Encoding.PEM).decode("utf-8") + + +@pytest.fixture(scope="module", autouse=True) +async def _create_tables(engine): + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + + +@pytest.fixture +async def client(ephemeral_ca, monkeypatch): + from dlw.main import create_app + from dlw.auth.hmac_nonce import NonceStore + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "1") + app = create_app() + app.state.ca = ephemeral_ca["ca"] + app.state.jwt_keypair = ephemeral_ca["jwt_keypair"] + app.state.nonce_store = NonceStore(maxsize=100, ttl_seconds=300) + app.state.enrollment_token = _ENROLL + async with AsyncClient(transport=ASGITransport(app=app), + base_url="http://test") as c: + yield c + + +async def _register(client, executor_id: str) -> dict: + csr = _build_csr(executor_id) + r = await client.post("/api/v1/executors/register", json={ + "host_id": "renew-host", "executor_id_proposal": executor_id, + "capabilities": {}, "client_csr_pem": csr, + }, headers={"X-Enrollment-Token": _ENROLL}) + assert r.status_code == 201, r.text + return r.json() + + +@pytest.mark.slow +async def test_renew_returns_new_jwt_only_when_no_csr(client) -> None: + reg = await _register(client, "renew-worker-1") + r = await client.post( + "/api/v1/executors/renew-worker-1/renew", json={}, + headers={ + "X-Client-Cert-PEM": reg["client_cert_pem"].replace("\n", "\\n"), + "Authorization": f"Bearer {reg['executor_jwt']}", + }, + ) + assert r.status_code == 200, r.text + body = r.json() + assert body["executor_jwt"] + assert body["client_cert_pem"] is None + + +@pytest.mark.slow +async def test_renew_returns_new_cert_when_csr_provided(client) -> None: + reg = await _register(client, "renew-worker-2") + fresh_csr = _build_csr("renew-worker-2") + r = await client.post( + "/api/v1/executors/renew-worker-2/renew", + json={"client_csr_pem": fresh_csr}, + headers={ + "X-Client-Cert-PEM": reg["client_cert_pem"].replace("\n", "\\n"), + "Authorization": f"Bearer {reg['executor_jwt']}", + }, + ) + assert r.status_code == 200, r.text + body = r.json() + assert body["executor_jwt"] + assert body["client_cert_pem"] is not None + assert body["client_cert_pem"].startswith("-----BEGIN CERTIFICATE-----") diff --git a/tests/api/test_subtasks.py b/tests/api/test_subtasks.py index 0af7f4a..ffb0079 100644 --- a/tests/api/test_subtasks.py +++ b/tests/api/test_subtasks.py @@ -1,4 +1,4 @@ -"""Tests for subtasks API: POST /report including double-report idempotency.""" +"""Tests for subtasks API: POST /report under mTLS + JWT (W3a).""" from __future__ import annotations import uuid @@ -9,9 +9,15 @@ from dlw.config import get_settings from dlw.db.base import Base +from tests.conftest import ( + executor_request_headers, + register_test_executor, + signed_heartbeat_headers, +) _TOKEN = "test-bearer-token-12345" +_ENROLL = "test-enrollment-token-w3a-subtasks" @pytest.fixture(scope="module", autouse=True) @@ -47,6 +53,7 @@ async def _cleanup_tasks(engine): @pytest.fixture(autouse=True) def _set_token(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("DLW_BEARER_TOKEN", _TOKEN) + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "1") get_settings.cache_clear() yield get_settings.cache_clear() @@ -54,60 +61,52 @@ def _set_token(monkeypatch: pytest.MonkeyPatch): @pytest.fixture def auth() -> dict[str, str]: + """Bearer header for /tasks (UI) endpoints.""" return {"Authorization": f"Bearer {_TOKEN}"} @pytest.fixture -async def client(): +async def client(ephemeral_ca): from dlw.main import create_app + from dlw.auth.hmac_nonce import NonceStore app = create_app() + app.state.ca = ephemeral_ca["ca"] + app.state.jwt_keypair = ephemeral_ca["jwt_keypair"] + app.state.nonce_store = NonceStore(maxsize=1000, ttl_seconds=300) + app.state.enrollment_token = _ENROLL async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as c: yield c -@pytest.fixture -async def joined_executor(client: AsyncClient, auth: dict[str, str]) -> tuple[str, int]: - """POST /join and return (executor_id, epoch). Used by tests that need fence headers.""" - r = await client.post("/api/v1/executors/join", json={ - "id": "sub-fence-worker-1", "host_id": "sub-fence-host", - }, headers=auth) - assert r.status_code == 201 - body = r.json() - return body["id"], body["epoch"] - - -async def _setup_assigned_subtask(client, auth, repo_id="o/sub-test") -> tuple[str, str, int]: - """Helper: create task → join executor → heartbeat → poll → return (subtask_id, exec_id, epoch). - - W2a: claim_one_subtask requires status='healthy'|'degraded'. A heartbeat - after join transitions the executor from 'joining' to 'healthy'. - """ +async def _setup_assigned_subtask( + client, auth, repo_id="o/sub-test", +) -> tuple[str, dict]: + """create task → register executor → heartbeat → poll → return (subtask_id, reg).""" await client.post("/api/v1/tasks", json={ "repo_id": repo_id, "revision": "0" * 40, "storage_id": 1, }, headers=auth) exec_id = f"ex-{repo_id.replace('/', '-')}" - rj = await client.post("/api/v1/executors/join", json={ - "id": exec_id, "host_id": "h" - }, headers=auth) - epoch = rj.json()["epoch"] + reg = await register_test_executor( + client, enrollment_token=_ENROLL, executor_id=exec_id, host_id="h", + ) + hb_body = b'{"health_score": 100, "parts_dir_bytes": 0, "disk_free_gb": 100}' await client.post(f"/api/v1/executors/{exec_id}/heartbeat", - json={"health_score": 100, "parts_dir_bytes": 0, "disk_free_gb": 100}, - headers={**auth, "X-Executor-Epoch": str(epoch)}) + content=hb_body, headers=signed_heartbeat_headers(reg, hb_body)) r = await client.post( f"/api/v1/executors/{exec_id}/poll", - headers={**auth, "X-Executor-Epoch": str(epoch)}, + headers=executor_request_headers(reg), ) - return r.json()["subtask"]["id"], exec_id, epoch + return r.json()["subtask"]["id"], reg @pytest.mark.slow async def test_report_succeeded_marks_subtask_done(client, auth) -> None: - sub_id, exec_id, epoch = await _setup_assigned_subtask(client, auth, "o/r1") + sub_id, reg = await _setup_assigned_subtask(client, auth, "o/r1") r = await client.post(f"/api/v1/subtasks/{sub_id}/report", json={ "status": "succeeded", "actual_sha256": "a" * 64, "bytes_downloaded": 1234, - }, headers={**auth, "X-Executor-Epoch": str(epoch)}) + }, headers=executor_request_headers(reg)) assert r.status_code == 200, r.text @@ -117,26 +116,23 @@ async def test_report_two_subtasks_succeed_then_task_succeeds(client, auth) -> N "repo_id": "o/full", "revision": "0" * 40, "storage_id": 1, }, headers=auth) task_id = create.json()["id"] - rj = await client.post("/api/v1/executors/join", json={ - "id": "ex-full", "host_id": "h" - }, headers=auth) - epoch = rj.json()["epoch"] - # W2a: heartbeat transitions joining → healthy before poll can claim work. - # W2b1: include disk_free_gb so disk pre-flight allows claiming. + reg = await register_test_executor( + client, enrollment_token=_ENROLL, executor_id="ex-full", host_id="h", + ) + hb_body = b'{"health_score": 100, "parts_dir_bytes": 0, "disk_free_gb": 100}' await client.post("/api/v1/executors/ex-full/heartbeat", - json={"health_score": 100, "parts_dir_bytes": 0, "disk_free_gb": 100}, - headers={**auth, "X-Executor-Epoch": str(epoch)}) + content=hb_body, headers=signed_heartbeat_headers(reg, hb_body)) sub_ids = [] for _ in range(2): r = await client.post( "/api/v1/executors/ex-full/poll", - headers={**auth, "X-Executor-Epoch": str(epoch)}, + headers=executor_request_headers(reg), ) sub_ids.append(r.json()["subtask"]["id"]) for sid in sub_ids: await client.post(f"/api/v1/subtasks/{sid}/report", json={ "status": "succeeded", "actual_sha256": "b" * 64, "bytes_downloaded": 100, - }, headers={**auth, "X-Executor-Epoch": str(epoch)}) + }, headers=executor_request_headers(reg)) r = await client.get(f"/api/v1/tasks/{task_id}", headers=auth) assert r.json()["status"] == "succeeded" assert r.json()["completed_at"] is not None @@ -148,23 +144,20 @@ async def test_report_one_failure_marks_task_failed(client, auth) -> None: "repo_id": "o/fail", "revision": "0" * 40, "storage_id": 1, }, headers=auth) task_id = create.json()["id"] - rj = await client.post("/api/v1/executors/join", json={ - "id": "ex-fail", "host_id": "h" - }, headers=auth) - epoch = rj.json()["epoch"] - # W2a: heartbeat transitions joining → healthy before poll can claim work. - # W2b1: include disk_free_gb so disk pre-flight allows claiming. + reg = await register_test_executor( + client, enrollment_token=_ENROLL, executor_id="ex-fail", host_id="h", + ) + hb_body = b'{"health_score": 100, "parts_dir_bytes": 0, "disk_free_gb": 100}' await client.post("/api/v1/executors/ex-fail/heartbeat", - json={"health_score": 100, "parts_dir_bytes": 0, "disk_free_gb": 100}, - headers={**auth, "X-Executor-Epoch": str(epoch)}) + content=hb_body, headers=signed_heartbeat_headers(reg, hb_body)) r = await client.post( "/api/v1/executors/ex-fail/poll", - headers={**auth, "X-Executor-Epoch": str(epoch)}, + headers=executor_request_headers(reg), ) sub_id = r.json()["subtask"]["id"] await client.post(f"/api/v1/subtasks/{sub_id}/report", json={ "status": "failed", "error": "disk full", - }, headers={**auth, "X-Executor-Epoch": str(epoch)}) + }, headers=executor_request_headers(reg)) r = await client.get(f"/api/v1/tasks/{task_id}", headers=auth) assert r.json()["status"] == "failed" assert "disk full" in r.json()["error_message"] @@ -172,14 +165,19 @@ async def test_report_one_failure_marks_task_failed(client, auth) -> None: @pytest.mark.slow async def test_report_unknown_subtask_returns_404(client, auth) -> None: + """Authenticated request to a random subtask id → 404.""" + reg = await register_test_executor( + client, enrollment_token=_ENROLL, executor_id="ex-404", host_id="h", + ) r = await client.post(f"/api/v1/subtasks/{uuid.uuid4()}/report", json={ "status": "succeeded", - }, headers={**auth, "X-Executor-Epoch": "1"}) + }, headers=executor_request_headers(reg)) assert r.status_code == 404 @pytest.mark.slow async def test_report_unauthenticated_returns_401(client) -> None: + """No mTLS cert → /report rejected before the handler runs.""" r = await client.post(f"/api/v1/subtasks/{uuid.uuid4()}/report", json={ "status": "succeeded", }) @@ -189,14 +187,14 @@ async def test_report_unauthenticated_returns_401(client) -> None: @pytest.mark.slow async def test_double_report_returns_409(client, auth) -> None: """W2-G: idempotency / illegal-transition guard.""" - sub_id, exec_id, epoch = await _setup_assigned_subtask(client, auth, "o/dup") + sub_id, reg = await _setup_assigned_subtask(client, auth, "o/dup") r1 = await client.post(f"/api/v1/subtasks/{sub_id}/report", json={ "status": "succeeded", "actual_sha256": "c" * 64, "bytes_downloaded": 100, - }, headers={**auth, "X-Executor-Epoch": str(epoch)}) + }, headers=executor_request_headers(reg)) assert r1.status_code == 200 r2 = await client.post(f"/api/v1/subtasks/{sub_id}/report", json={ "status": "succeeded", "actual_sha256": "c" * 64, "bytes_downloaded": 100, - }, headers={**auth, "X-Executor-Epoch": str(epoch)}) + }, headers=executor_request_headers(reg)) assert r2.status_code == 409, r2.text @@ -204,10 +202,16 @@ async def test_double_report_returns_409(client, auth) -> None: async def test_report_missing_epoch_header_returns_401( client: AsyncClient, auth: dict[str, str], ) -> None: + """Authenticated but no X-Executor-Epoch → 401.""" + reg = await register_test_executor( + client, enrollment_token=_ENROLL, executor_id="ex-noepoch", host_id="h", + ) + headers = executor_request_headers(reg) + del headers["X-Executor-Epoch"] r = await client.post( f"/api/v1/subtasks/{uuid.uuid4()}/report", json={"status": "succeeded", "bytes_downloaded": 100}, - headers=auth, + headers=headers, ) assert r.status_code == 401 @@ -216,29 +220,23 @@ async def test_report_missing_epoch_header_returns_401( async def test_report_stale_epoch_returns_EPOCH_MISMATCH( client: AsyncClient, auth: dict[str, str], ) -> None: - """Create task + join executor + claim subtask + report with stale epoch.""" - # Setup: create a task to generate subtasks + """Claim a subtask under epoch=1, re-register to epoch=2, report with the + stale epoch=1 (using the fresh cert) → EPOCH_MISMATCH.""" r = await client.post("/api/v1/tasks", json={ "repo_id": "o/fence-report", "revision": "0" * 40, "storage_id": 1, }, headers=auth) - task_id = r.json()["id"] - - # Join executor — get epoch=1 - rj = await client.post("/api/v1/executors/join", json={ - "id": "report-host-worker-1", "host_id": "report-host", - }, headers=auth) - epoch = rj.json()["epoch"] + assert r.status_code == 201 - # W2a: heartbeat transitions joining → healthy before poll can claim work. - # W2b1: include disk_free_gb so disk pre-flight allows claiming. + reg1 = await register_test_executor( + client, enrollment_token=_ENROLL, + executor_id="report-host-worker-1", host_id="report-host", + ) + hb_body = b'{"health_score": 100, "parts_dir_bytes": 0, "disk_free_gb": 100}' await client.post("/api/v1/executors/report-host-worker-1/heartbeat", - json={"health_score": 100, "parts_dir_bytes": 0, "disk_free_gb": 100}, - headers={**auth, "X-Executor-Epoch": str(epoch)}) - - # Claim subtask via poll + content=hb_body, headers=signed_heartbeat_headers(reg1, hb_body)) rp = await client.post( "/api/v1/executors/report-host-worker-1/poll", - headers={**auth, "X-Executor-Epoch": str(epoch)}, + headers=executor_request_headers(reg1), ) assert rp.status_code == 200, rp.text if not rp.json()["assigned"]: @@ -246,20 +244,23 @@ async def test_report_stale_epoch_returns_EPOCH_MISMATCH( subtask_id = rp.json()["subtask"]["id"] token = rp.json()["assignment_token"] - # Bump epoch (re-join → epoch=2) - rj2 = await client.post("/api/v1/executors/join", json={ - "id": "report-host-worker-1", "host_id": "report-host", - }, headers=auth) - assert rj2.json()["epoch"] == epoch + 1 + # Re-register → epoch=2 + fresh cert. + reg2 = await register_test_executor( + client, enrollment_token=_ENROLL, + executor_id="report-host-worker-1", host_id="report-host", + ) + assert reg2["epoch"] == reg1["epoch"] + 1 - # Report with STALE epoch (the one we claimed under) + # Report with the fresh cert but the STALE epoch. + stale = dict(reg2) + stale["epoch"] = reg1["epoch"] rr = await client.post( f"/api/v1/subtasks/{subtask_id}/report", json={ "status": "succeeded", "bytes_downloaded": 100, "actual_sha256": "a" * 64, "assignment_token": token, }, - headers={**auth, "X-Executor-Epoch": str(epoch)}, # stale! + headers=executor_request_headers(stale), ) assert rr.status_code == 401 assert rr.json()["detail"]["code"] == "EPOCH_MISMATCH" diff --git a/tests/auth/test_ca.py b/tests/auth/test_ca.py new file mode 100644 index 0000000..95d6f4f --- /dev/null +++ b/tests/auth/test_ca.py @@ -0,0 +1,66 @@ +"""Tests for dlw.auth.ca (Phase 2 W3a §3.1).""" +from __future__ import annotations + +from cryptography import x509 +from cryptography.hazmat.primitives.asymmetric import ed25519 + +from dlw.auth.ca import ( + bootstrap_ca, + ensure_server_cert, + fingerprint_of, + sign_csr, +) + + +def _build_csr(executor_id: str) -> bytes: + from cryptography.hazmat.primitives import serialization + from cryptography.x509.oid import NameOID + key = ed25519.Ed25519PrivateKey.generate() + csr = (x509.CertificateSigningRequestBuilder() + .subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, executor_id)])) + .sign(key, None)) + return csr.public_bytes(serialization.Encoding.PEM) + + +def test_bootstrap_ca_idempotent(tmp_path) -> None: + ca1 = bootstrap_ca(tmp_path) + ca2 = bootstrap_ca(tmp_path) + assert ca1.cert_pem == ca2.cert_pem + assert ca1.key_pem == ca2.key_pem + + +def test_sign_csr_returns_valid_client_cert(tmp_path) -> None: + ca = bootstrap_ca(tmp_path) + csr_pem = _build_csr("host-1-worker-1") + cert_pem = sign_csr(ca, csr_pem, "host-1-worker-1", ttl_hours=24) + cert = x509.load_pem_x509_certificate(cert_pem) + cn = cert.subject.get_attributes_for_oid(x509.oid.NameOID.COMMON_NAME)[0].value + assert cn == "host-1-worker-1" + ca.cert.public_key().verify(cert.signature, cert.tbs_certificate_bytes) + eku = cert.extensions.get_extension_for_class(x509.ExtendedKeyUsage).value + assert x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH in eku + + +def test_fingerprint_of_is_deterministic_sha256(tmp_path) -> None: + ca = bootstrap_ca(tmp_path) + csr_pem = _build_csr("host-2-worker-1") + cert_pem = sign_csr(ca, csr_pem, "host-2-worker-1") + fp1 = fingerprint_of(cert_pem) + fp2 = fingerprint_of(cert_pem) + assert fp1 == fp2 + assert fp1.startswith("SHA256:") + assert len(fp1) == len("SHA256:") + 64 + + +def test_ensure_server_cert_has_required_sans(tmp_path) -> None: + ca = bootstrap_ca(tmp_path) + cert_path, key_path = ensure_server_cert(ca, tmp_path, hostname="dlw-controller") + assert cert_path.exists() and key_path.exists() + cert = x509.load_pem_x509_certificate(cert_path.read_bytes()) + san = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value + dns_names = san.get_values_for_type(x509.DNSName) + ip_addrs = {str(ip) for ip in san.get_values_for_type(x509.IPAddress)} + assert "localhost" in dns_names + assert "dlw-controller" in dns_names + assert "127.0.0.1" in ip_addrs + assert "::1" in ip_addrs diff --git a/tests/auth/test_executor_epoch.py b/tests/auth/test_executor_epoch.py index 73f7f69..adad5cd 100644 --- a/tests/auth/test_executor_epoch.py +++ b/tests/auth/test_executor_epoch.py @@ -2,103 +2,124 @@ from __future__ import annotations import pytest -from fastapi import FastAPI, Depends +from fastapi import FastAPI, Depends, Path from httpx import ASGITransport, AsyncClient from sqlalchemy.ext.asyncio import async_sessionmaker -from dlw.config import get_settings +from dlw.auth.ca import fingerprint_of +from dlw.auth.jwt_signing import sign from dlw.db.base import Base from dlw.db.models.executor import Executor -_TOKEN = "test-bearer-token-12345" - - @pytest.fixture(scope="module", autouse=True) async def _bootstrap(engine): - """W6-M (revert W6-I): restore drop_all at module end. - - W6-I was over-cautious — pytest runs modules sequentially (not in parallel - by default), so module-level drop_all firing on engine teardown doesn't - affect other modules. Keeping the table left over causes test_alembic.py - to fail because alembic upgrade-head can't create tables that already exist. - """ + """Create tables once per module; drop at end.""" async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) - factory = async_sessionmaker(engine, expire_on_commit=False) - async with factory() as s: - # ON CONFLICT DO NOTHING keeps the seed idempotent (defensive) - from sqlalchemy.dialects.postgresql import insert as pg_insert - stmt = pg_insert(Executor).values( - id="probe-host-worker-1", host_id="probe-host", - cert_fingerprint="x", status="healthy", epoch=3, - ).on_conflict_do_nothing() - await s.execute(stmt) - await s.commit() yield async with engine.begin() as conn: await conn.run_sync(Base.metadata.drop_all) -@pytest.fixture(autouse=True) -def _set_token(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setenv("DLW_BEARER_TOKEN", _TOKEN) - get_settings.cache_clear() - yield - get_settings.cache_clear() - - -@pytest.fixture -def app(): - """Build a tiny app that mounts the dep on a probe endpoint.""" - from dlw.api.tasks import _session +def _build_app(jwt_keypair): + """Build a tiny app that mounts require_executor_epoch on a probe endpoint.""" from dlw.auth.executor_epoch import require_executor_epoch app = FastAPI() + app.state.jwt_keypair = jwt_keypair - @app.get("/probe/{executor_id}") - async def probe(executor: Executor = Depends(require_executor_epoch)): + @app.post("/probe/{executor_id}") + async def probe(executor_id: str = Path(...), + executor: Executor = Depends(require_executor_epoch)): return {"executor_id": executor.id, "epoch": executor.epoch} return app -@pytest.fixture -async def client(app): - async with AsyncClient( - transport=ASGITransport(app=app), base_url="http://test" - ) as c: - yield c +async def _seed_probe_executor(engine, executor_id, cert_fingerprint, epoch=3): + from sqlalchemy.dialects.postgresql import insert as pg_insert + factory = async_sessionmaker(engine, expire_on_commit=False) + async with factory() as s: + stmt = pg_insert(Executor).values( + id=executor_id, host_id="probe-host", + cert_fingerprint=cert_fingerprint, status="healthy", epoch=epoch, + ).on_conflict_do_update( + index_elements=["id"], + set_={"cert_fingerprint": cert_fingerprint, "epoch": epoch, + "status": "healthy"}, + ) + await s.execute(stmt) + await s.commit() @pytest.mark.slow -async def test_require_epoch_missing_header_returns_401(client: AsyncClient) -> None: - # W6-D: dep accepts Optional + raises 401 with custom detail (not FastAPI auto-422) - r = await client.get("/probe/probe-host-worker-1") +async def test_require_epoch_missing_header_returns_401( + engine, ephemeral_ca, client_cert_pair, monkeypatch, +) -> None: + """Missing X-Executor-Epoch → 401 (mTLS+JWT pass, epoch header absent).""" + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "1") + cert_pem, _key, executor_id = client_cert_pair + fp = fingerprint_of(cert_pem) + await _seed_probe_executor(engine, executor_id, fp, epoch=3) + token = sign(ephemeral_ca["jwt_keypair"], executor_id=executor_id, + epoch=3, scopes=["heartbeat"]) + app = _build_app(ephemeral_ca["jwt_keypair"]) + async with AsyncClient(transport=ASGITransport(app=app), + base_url="http://test") as c: + r = await c.post(f"/probe/{executor_id}", headers={ + "X-Client-Cert-PEM": cert_pem.decode("utf-8").replace("\n", "\\n"), + "Authorization": f"Bearer {token}", + }) assert r.status_code == 401 assert "missing X-Executor-Epoch" in r.json()["detail"] @pytest.mark.slow -async def test_require_epoch_unknown_executor_returns_404( - client: AsyncClient, +async def test_require_epoch_unknown_executor_returns_executor_id_mismatch( + engine, ephemeral_ca, client_cert_pair, monkeypatch, ) -> None: - r = await client.get( - "/probe/no-such-host-worker-99", - headers={"X-Executor-Epoch": "1"}, - ) - assert r.status_code == 404 - assert "executor not found" in r.json()["detail"] + """Path executor_id differs from mTLS-authenticated identity → 401 EXECUTOR_ID_MISMATCH. + + W3a refactor: confused-deputy guard fires before epoch check. + (Previously returned 404 'executor not found'; that DB lookup is removed.) + """ + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "1") + cert_pem, _key, executor_id = client_cert_pair + fp = fingerprint_of(cert_pem) + await _seed_probe_executor(engine, executor_id, fp, epoch=3) + token = sign(ephemeral_ca["jwt_keypair"], executor_id=executor_id, + epoch=3, scopes=["heartbeat"]) + app = _build_app(ephemeral_ca["jwt_keypair"]) + async with AsyncClient(transport=ASGITransport(app=app), + base_url="http://test") as c: + r = await c.post("/probe/no-such-host-worker-99", headers={ + "X-Client-Cert-PEM": cert_pem.decode("utf-8").replace("\n", "\\n"), + "Authorization": f"Bearer {token}", + "X-Executor-Epoch": "1", + }) + assert r.status_code == 401 + assert r.json()["detail"]["code"] == "EXECUTOR_ID_MISMATCH" @pytest.mark.slow async def test_require_epoch_mismatch_returns_EPOCH_MISMATCH( - client: AsyncClient, + engine, ephemeral_ca, client_cert_pair, monkeypatch, ) -> None: - r = await client.get( - "/probe/probe-host-worker-1", - headers={"X-Executor-Epoch": "2"}, - ) + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "1") + cert_pem, _key, executor_id = client_cert_pair + fp = fingerprint_of(cert_pem) + await _seed_probe_executor(engine, executor_id, fp, epoch=3) + token = sign(ephemeral_ca["jwt_keypair"], executor_id=executor_id, + epoch=3, scopes=["heartbeat"]) + app = _build_app(ephemeral_ca["jwt_keypair"]) + async with AsyncClient(transport=ASGITransport(app=app), + base_url="http://test") as c: + r = await c.post(f"/probe/{executor_id}", headers={ + "X-Client-Cert-PEM": cert_pem.decode("utf-8").replace("\n", "\\n"), + "Authorization": f"Bearer {token}", + "X-Executor-Epoch": "2", + }) assert r.status_code == 401 body = r.json() assert body["detail"]["code"] == "EPOCH_MISMATCH" @@ -108,11 +129,67 @@ async def test_require_epoch_mismatch_returns_EPOCH_MISMATCH( @pytest.mark.slow async def test_require_epoch_match_returns_executor_row( - client: AsyncClient, + engine, ephemeral_ca, client_cert_pair, monkeypatch, ) -> None: - r = await client.get( - "/probe/probe-host-worker-1", - headers={"X-Executor-Epoch": "3"}, - ) + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "1") + cert_pem, _key, executor_id = client_cert_pair + fp = fingerprint_of(cert_pem) + await _seed_probe_executor(engine, executor_id, fp, epoch=3) + token = sign(ephemeral_ca["jwt_keypair"], executor_id=executor_id, + epoch=3, scopes=["heartbeat"]) + app = _build_app(ephemeral_ca["jwt_keypair"]) + async with AsyncClient(transport=ASGITransport(app=app), + base_url="http://test") as c: + r = await c.post(f"/probe/{executor_id}", headers={ + "X-Client-Cert-PEM": cert_pem.decode("utf-8").replace("\n", "\\n"), + "Authorization": f"Bearer {token}", + "X-Executor-Epoch": "3", + }) assert r.status_code == 200 - assert r.json() == {"executor_id": "probe-host-worker-1", "epoch": 3} + assert r.json() == {"executor_id": executor_id, "epoch": 3} + + +@pytest.mark.slow +async def test_require_executor_epoch_rejects_path_id_mismatch( + engine, ephemeral_ca, client_cert_pair, monkeypatch, +) -> None: + """mTLS+JWT authenticate executor A, URL path says executor B → + 401 EXECUTOR_ID_MISMATCH (confused-deputy guard).""" + from dlw.auth.executor_epoch import require_executor_epoch + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "1") + cert_pem, _key, executor_id = client_cert_pair + fp = fingerprint_of(cert_pem) + from sqlalchemy.dialects.postgresql import insert as pg_insert + factory = async_sessionmaker(engine, expire_on_commit=False) + async with factory() as s: + stmt = pg_insert(Executor).values( + id=executor_id, host_id="h", cert_fingerprint=fp, + status="healthy", epoch=1, + ).on_conflict_do_update( + index_elements=["id"], + set_={"cert_fingerprint": fp, "epoch": 1, "status": "healthy"}, + ) + await s.execute(stmt) + await s.commit() + token = sign(ephemeral_ca["jwt_keypair"], executor_id=executor_id, + epoch=1, scopes=["heartbeat"]) + + app = FastAPI() + app.state.jwt_keypair = ephemeral_ca["jwt_keypair"] + + @app.post("/executors/{executor_id}/x") + async def x(executor_id: str = Path(...), + ex: Executor = Depends(require_executor_epoch)) -> dict: + return {"ok": True} + + async with AsyncClient(transport=ASGITransport(app=app), + base_url="http://test") as c: + r = await c.post("/executors/other-executor/x", headers={ + "X-Client-Cert-PEM": cert_pem.decode("utf-8").replace("\n", "\\n"), + "Authorization": f"Bearer {token}", + "X-Executor-Epoch": "1", + }) + assert r.status_code == 401 diff --git a/tests/auth/test_executor_jwt_dep.py b/tests/auth/test_executor_jwt_dep.py new file mode 100644 index 0000000..4c081e3 --- /dev/null +++ b/tests/auth/test_executor_jwt_dep.py @@ -0,0 +1,98 @@ +"""Tests for require_executor_jwt (Phase 2 W3a §3.4).""" +from __future__ import annotations + +import pytest +from fastapi import FastAPI, Depends +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import async_sessionmaker + +from dlw.auth.ca import fingerprint_of +from dlw.auth.executor_jwt_dep import require_executor_jwt +from dlw.auth.jwt_signing import sign +from dlw.db.base import Base +from dlw.db.models.executor import Executor + + +@pytest.fixture(scope="module", autouse=True) +async def _create_tables(engine): + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + + +def _mini_app(jwt_keypair): + app = FastAPI() + app.state.jwt_keypair = jwt_keypair + + @app.get("/whoami") + async def whoami(ex: Executor = Depends(require_executor_jwt)) -> dict: + return {"executor_id": ex.id} + + return app + + +@pytest.mark.slow +async def test_require_executor_jwt_accepts_valid_token( + engine, ephemeral_ca, client_cert_pair, monkeypatch, +) -> None: + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "1") + cert_pem, _key, executor_id = client_cert_pair + fp = fingerprint_of(cert_pem) + factory = async_sessionmaker(engine, expire_on_commit=False) + async with factory() as s: + from sqlalchemy.dialects.postgresql import insert as pg_insert + stmt = pg_insert(Executor).values( + id=executor_id, host_id="h", cert_fingerprint=fp, + status="healthy", epoch=2, + ).on_conflict_do_update( + index_elements=["id"], + set_={"cert_fingerprint": fp, "epoch": 2, "status": "healthy"}, + ) + await s.execute(stmt) + await s.commit() + token = sign(ephemeral_ca["jwt_keypair"], executor_id=executor_id, + epoch=2, scopes=["heartbeat"]) + + app = _mini_app(ephemeral_ca["jwt_keypair"]) + async with AsyncClient(transport=ASGITransport(app=app), + base_url="http://test") as c: + r = await c.get("/whoami", headers={ + "X-Client-Cert-PEM": cert_pem.decode("utf-8").replace("\n", "\\n"), + "Authorization": f"Bearer {token}", + }) + assert r.status_code == 200 + assert r.json()["executor_id"] == executor_id + + +@pytest.mark.slow +async def test_require_executor_jwt_rejects_sub_mismatch( + engine, ephemeral_ca, client_cert_pair, monkeypatch, +) -> None: + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "1") + cert_pem, _key, executor_id = client_cert_pair + fp = fingerprint_of(cert_pem) + factory = async_sessionmaker(engine, expire_on_commit=False) + async with factory() as s: + from sqlalchemy.dialects.postgresql import insert as pg_insert + stmt = pg_insert(Executor).values( + id=executor_id, host_id="h", cert_fingerprint=fp, + status="healthy", epoch=2, + ).on_conflict_do_update( + index_elements=["id"], + set_={"cert_fingerprint": fp, "epoch": 2, "status": "healthy"}, + ) + await s.execute(stmt) + await s.commit() + token = sign(ephemeral_ca["jwt_keypair"], executor_id="other-executor", + epoch=2, scopes=["heartbeat"]) + + app = _mini_app(ephemeral_ca["jwt_keypair"]) + async with AsyncClient(transport=ASGITransport(app=app), + base_url="http://test") as c: + r = await c.get("/whoami", headers={ + "X-Client-Cert-PEM": cert_pem.decode("utf-8").replace("\n", "\\n"), + "Authorization": f"Bearer {token}", + }) + assert r.status_code == 401 diff --git a/tests/auth/test_executor_mtls_dep.py b/tests/auth/test_executor_mtls_dep.py new file mode 100644 index 0000000..e5c766d --- /dev/null +++ b/tests/auth/test_executor_mtls_dep.py @@ -0,0 +1,91 @@ +"""Tests for require_executor_mtls (Phase 2 W3a §3.4).""" +from __future__ import annotations + +import pytest +from fastapi import FastAPI, Depends +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import async_sessionmaker + +from dlw.auth.ca import fingerprint_of +from dlw.auth.executor_mtls import require_executor_mtls +from dlw.db.base import Base +from dlw.db.models.executor import Executor + + +@pytest.fixture(scope="module", autouse=True) +async def _create_tables(engine): + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + + +def _mini_app(): + app = FastAPI() + + @app.get("/whoami") + async def whoami(ex: Executor = Depends(require_executor_mtls)) -> dict: + return {"executor_id": ex.id} + + return app + + +@pytest.mark.slow +async def test_require_executor_mtls_via_trusted_proxy_header( + engine, client_cert_pair, monkeypatch, +) -> None: + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "1") + cert_pem, _key, executor_id = client_cert_pair + fp = fingerprint_of(cert_pem) + factory = async_sessionmaker(engine, expire_on_commit=False) + async with factory() as s: + from sqlalchemy.dialects.postgresql import insert as pg_insert + stmt = pg_insert(Executor).values( + id=executor_id, host_id="h", cert_fingerprint=fp, + status="healthy", epoch=1, + ).on_conflict_do_update( + index_elements=["id"], + set_={"cert_fingerprint": fp, "epoch": 1, "status": "healthy"}, + ) + await s.execute(stmt) + await s.commit() + + app = _mini_app() + async with AsyncClient(transport=ASGITransport(app=app), + base_url="http://test") as c: + r = await c.get("/whoami", headers={ + "X-Client-Cert-PEM": cert_pem.decode("utf-8").replace("\n", "\\n"), + }) + assert r.status_code == 200 + assert r.json()["executor_id"] == executor_id + + +@pytest.mark.slow +async def test_require_executor_mtls_rejects_unknown_fingerprint( + client_cert_pair, monkeypatch, +) -> None: + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "1") + cert_pem, _key, _ = client_cert_pair + app = _mini_app() + async with AsyncClient(transport=ASGITransport(app=app), + base_url="http://test") as c: + r = await c.get("/whoami", headers={ + "X-Client-Cert-PEM": cert_pem.decode("utf-8").replace("\n", "\\n"), + }) + assert r.status_code == 401 + + +@pytest.mark.slow +async def test_require_executor_mtls_rejects_header_when_proxy_disabled( + client_cert_pair, monkeypatch, +) -> None: + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "0") + cert_pem, _key, _ = client_cert_pair + app = _mini_app() + async with AsyncClient(transport=ASGITransport(app=app), + base_url="http://test") as c: + r = await c.get("/whoami", headers={ + "X-Client-Cert-PEM": cert_pem.decode("utf-8").replace("\n", "\\n"), + }) + assert r.status_code == 401 diff --git a/tests/auth/test_hmac_heartbeat_dep.py b/tests/auth/test_hmac_heartbeat_dep.py new file mode 100644 index 0000000..40360f7 --- /dev/null +++ b/tests/auth/test_hmac_heartbeat_dep.py @@ -0,0 +1,162 @@ +"""Tests for require_hmac_heartbeat (Phase 2 W3a §3.4).""" +from __future__ import annotations + +import secrets +import time + +import pytest +from fastapi import FastAPI, Depends, Request +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import async_sessionmaker + +from dlw.auth.ca import fingerprint_of +from dlw.auth.hmac_heartbeat_dep import require_hmac_heartbeat +from dlw.auth.hmac_nonce import NonceStore, compute_hmac +from dlw.auth.jwt_signing import sign +from dlw.db.base import Base +from dlw.db.models.executor import Executor + + +@pytest.fixture(scope="module", autouse=True) +async def _create_tables(engine): + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + + +_HMAC_SEED = b"\x02" * 32 + + +async def _seed_executor(engine, executor_id, fp): + from sqlalchemy.dialects.postgresql import insert as pg_insert + factory = async_sessionmaker(engine, expire_on_commit=False) + async with factory() as s: + stmt = pg_insert(Executor).values( + id=executor_id, host_id="h", cert_fingerprint=fp, + status="healthy", epoch=1, hmac_seed_encrypted=_HMAC_SEED, + ).on_conflict_do_update( + index_elements=["id"], + set_={"cert_fingerprint": fp, "epoch": 1, "status": "healthy", + "hmac_seed_encrypted": _HMAC_SEED}, + ) + await s.execute(stmt) + await s.commit() + + +def _mini_app(jwt_keypair): + app = FastAPI() + app.state.jwt_keypair = jwt_keypair + app.state.nonce_store = NonceStore(maxsize=100, ttl_seconds=300) + + @app.post("/hb") + async def hb(request: Request, + ex: Executor = Depends(require_hmac_heartbeat)) -> dict: + return {"ok": True, "executor_id": ex.id} + + return app + + +def _hmac_headers(seed, body: bytes, *, ts=None, nonce=None): + ts = ts if ts is not None else int(time.time()) + nonce = nonce or secrets.token_hex(16) + sig = compute_hmac(seed, ts=ts, nonce=nonce, body=body) + return {"X-HMAC-Timestamp": str(ts), "X-HMAC-Nonce": nonce, + "X-HMAC-Signature": sig} + + +@pytest.mark.slow +async def test_hmac_heartbeat_accepts_valid_signature( + engine, ephemeral_ca, client_cert_pair, monkeypatch, +) -> None: + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "1") + cert_pem, _key, executor_id = client_cert_pair + fp = fingerprint_of(cert_pem) + await _seed_executor(engine, executor_id, fp) + token = sign(ephemeral_ca["jwt_keypair"], executor_id=executor_id, + epoch=1, scopes=["heartbeat"]) + body = b'{"health_score":100}' + app = _mini_app(ephemeral_ca["jwt_keypair"]) + async with AsyncClient(transport=ASGITransport(app=app), + base_url="http://test") as c: + r = await c.post("/hb", content=body, headers={ + "X-Client-Cert-PEM": cert_pem.decode("utf-8").replace("\n", "\\n"), + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + **_hmac_headers(_HMAC_SEED, body), + }) + assert r.status_code == 200 + + +@pytest.mark.slow +async def test_hmac_heartbeat_rejects_clock_skew( + engine, ephemeral_ca, client_cert_pair, monkeypatch, +) -> None: + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "1") + cert_pem, _key, executor_id = client_cert_pair + fp = fingerprint_of(cert_pem) + await _seed_executor(engine, executor_id, fp) + token = sign(ephemeral_ca["jwt_keypair"], executor_id=executor_id, + epoch=1, scopes=["heartbeat"]) + body = b'{"health_score":100}' + app = _mini_app(ephemeral_ca["jwt_keypair"]) + async with AsyncClient(transport=ASGITransport(app=app), + base_url="http://test") as c: + r = await c.post("/hb", content=body, headers={ + "X-Client-Cert-PEM": cert_pem.decode("utf-8").replace("\n", "\\n"), + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + **_hmac_headers(_HMAC_SEED, body, ts=int(time.time()) - 400), + }) + assert r.status_code == 401 + + +@pytest.mark.slow +async def test_hmac_heartbeat_rejects_replay( + engine, ephemeral_ca, client_cert_pair, monkeypatch, +) -> None: + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "1") + cert_pem, _key, executor_id = client_cert_pair + fp = fingerprint_of(cert_pem) + await _seed_executor(engine, executor_id, fp) + token = sign(ephemeral_ca["jwt_keypair"], executor_id=executor_id, + epoch=1, scopes=["heartbeat"]) + body = b'{"health_score":100}' + headers_base = { + "X-Client-Cert-PEM": cert_pem.decode("utf-8").replace("\n", "\\n"), + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + } + hmac_h = _hmac_headers(_HMAC_SEED, body, nonce="fixed-replay-nonce") + app = _mini_app(ephemeral_ca["jwt_keypair"]) + async with AsyncClient(transport=ASGITransport(app=app), + base_url="http://test") as c: + r1 = await c.post("/hb", content=body, headers={**headers_base, **hmac_h}) + r2 = await c.post("/hb", content=body, headers={**headers_base, **hmac_h}) + assert r1.status_code == 200 + assert r2.status_code == 401 + + +@pytest.mark.slow +async def test_hmac_heartbeat_rejects_tampered_body( + engine, ephemeral_ca, client_cert_pair, monkeypatch, +) -> None: + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "1") + cert_pem, _key, executor_id = client_cert_pair + fp = fingerprint_of(cert_pem) + await _seed_executor(engine, executor_id, fp) + token = sign(ephemeral_ca["jwt_keypair"], executor_id=executor_id, + epoch=1, scopes=["heartbeat"]) + signed_body = b'{"health_score":100}' + sent_body = b'{"health_score":999}' + app = _mini_app(ephemeral_ca["jwt_keypair"]) + async with AsyncClient(transport=ASGITransport(app=app), + base_url="http://test") as c: + r = await c.post("/hb", content=sent_body, headers={ + "X-Client-Cert-PEM": cert_pem.decode("utf-8").replace("\n", "\\n"), + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + **_hmac_headers(_HMAC_SEED, signed_body), + }) + assert r.status_code == 401 diff --git a/tests/auth/test_hmac_nonce.py b/tests/auth/test_hmac_nonce.py new file mode 100644 index 0000000..fd4e04c --- /dev/null +++ b/tests/auth/test_hmac_nonce.py @@ -0,0 +1,39 @@ +"""Tests for dlw.auth.hmac_nonce (Phase 2 W3a §3.3).""" +from __future__ import annotations + +from dlw.auth.hmac_nonce import NonceStore, compute_hmac, verify_hmac + + +_SEED = b"\x01" * 32 + + +def test_hmac_compute_and_verify_roundtrip() -> None: + body = b'{"health_score":100}' + sig = compute_hmac(_SEED, ts=1715739200, nonce="abc", body=body) + assert verify_hmac(_SEED, ts=1715739200, nonce="abc", body=body, + signature_hex=sig) + + +def test_hmac_verify_rejects_tampered_body() -> None: + body = b'{"health_score":100}' + sig = compute_hmac(_SEED, ts=1715739200, nonce="abc", body=body) + tampered = b'{"health_score":101}' + assert not verify_hmac(_SEED, ts=1715739200, nonce="abc", body=tampered, + signature_hex=sig) + + +def test_nonce_store_first_add_then_seen() -> None: + store = NonceStore(maxsize=100, ttl_seconds=300) + assert not store.seen("n1") + store.add("n1") + assert store.seen("n1") + + +def test_nonce_store_evicts_after_ttl(monkeypatch) -> None: + store = NonceStore(maxsize=100, ttl_seconds=10) + fake = [1000.0] + monkeypatch.setattr("dlw.auth.hmac_nonce.time.monotonic", lambda: fake[0]) + store.add("n1") + assert store.seen("n1") + fake[0] += 11 + assert not store.seen("n1") diff --git a/tests/auth/test_jwt_signing.py b/tests/auth/test_jwt_signing.py new file mode 100644 index 0000000..07b164f --- /dev/null +++ b/tests/auth/test_jwt_signing.py @@ -0,0 +1,47 @@ +"""Tests for dlw.auth.jwt_signing (Phase 2 W3a §3.2).""" +from __future__ import annotations + +import time + +import jwt as _pyjwt +import pytest + +from dlw.auth.jwt_signing import bootstrap_keypair, sign, verify + + +def test_bootstrap_keypair_idempotent(tmp_path) -> None: + kp1 = bootstrap_keypair(tmp_path) + kp2 = bootstrap_keypair(tmp_path) + assert kp1.priv_pem == kp2.priv_pem + assert kp1.pub_pem == kp2.pub_pem + + +def test_sign_and_verify_roundtrip(tmp_path) -> None: + kp = bootstrap_keypair(tmp_path) + token = sign(kp, executor_id="host-1-worker-1", epoch=3, + scopes=["heartbeat", "poll"], ttl_seconds=3600) + claims = verify(kp, token) + assert claims["sub"] == "host-1-worker-1" + assert claims["epoch"] == 3 + assert claims["scope"] == "heartbeat poll" + assert claims["iss"] == "dlw-controller" + + +def test_verify_rejects_expired_token(tmp_path) -> None: + kp = bootstrap_keypair(tmp_path) + token = sign(kp, executor_id="e", epoch=1, scopes=["heartbeat"], + ttl_seconds=-10) + with pytest.raises(_pyjwt.PyJWTError): + verify(kp, token) + + +def test_verify_rejects_wrong_issuer(tmp_path) -> None: + kp = bootstrap_keypair(tmp_path) + now = int(time.time()) + bad = _pyjwt.encode( + {"iss": "evil", "sub": "e", "epoch": 1, "scope": "heartbeat", + "iat": now, "exp": now + 3600}, + kp.priv_pem.decode("utf-8"), algorithm="EdDSA", + ) + with pytest.raises(_pyjwt.PyJWTError): + verify(kp, bad) diff --git a/tests/conftest.py b/tests/conftest.py index 80fa9c1..552352a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -165,3 +165,111 @@ async def _point_app_at_test_db(test_db_name: str, engine: AsyncEngine): os.environ[k] = v get_settings.cache_clear() await reset_engine() + + +@pytest.fixture(scope="session") +def ephemeral_ca(tmp_path_factory): + """One CA + JWT keypair per test session, in a temp dir.""" + from dlw.auth.ca import bootstrap_ca + from dlw.auth.jwt_signing import bootstrap_keypair + ca_dir = tmp_path_factory.mktemp("ca") + ca = bootstrap_ca(ca_dir) + jwt_kp = bootstrap_keypair(ca_dir) + return {"ca": ca, "jwt_keypair": jwt_kp, "ca_dir": ca_dir} + + +@pytest.fixture +def client_cert_pair(ephemeral_ca): + """Per-test client cert ('test-executor-1') signed by the session CA. + Returns (cert_pem: bytes, key: Ed25519PrivateKey, executor_id: str).""" + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric import ed25519 + from cryptography import x509 + from cryptography.x509.oid import NameOID + from dlw.auth.ca import sign_csr + executor_id = "test-executor-1" + key = ed25519.Ed25519PrivateKey.generate() + csr = (x509.CertificateSigningRequestBuilder() + .subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, executor_id)])) + .sign(key, None)) + csr_pem = csr.public_bytes(serialization.Encoding.PEM) + cert_pem = sign_csr(ephemeral_ca["ca"], csr_pem, executor_id, ttl_hours=24) + return cert_pem, key, executor_id + + +# --- W3a: helpers for executor auth in API tests (mTLS+JWT+HMAC via header bypass) --- + +async def register_test_executor( + client, *, enrollment_token: str, + executor_id: str = "test-host-worker-1", host_id: str = "test-host", +) -> dict: + """Build a CSR, POST /api/v1/executors/register, return a dict with + executor_id, epoch, cert_pem, jwt, hmac_seed, ca_chain. The caller's app + must have app.state.ca / jwt_keypair / nonce_store / enrollment_token set.""" + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric import ed25519 + from cryptography import x509 + from cryptography.x509.oid import NameOID + key = ed25519.Ed25519PrivateKey.generate() + csr = (x509.CertificateSigningRequestBuilder() + .subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, executor_id)])) + .sign(key, None)) + csr_pem = csr.public_bytes(serialization.Encoding.PEM).decode("utf-8") + r = await client.post("/api/v1/executors/register", json={ + "host_id": host_id, "executor_id_proposal": executor_id, + "capabilities": {}, "client_csr_pem": csr_pem, + }, headers={"X-Enrollment-Token": enrollment_token}) + assert r.status_code in (200, 201), r.text + body = r.json() + return { + "executor_id": body["executor_id"], "epoch": body["epoch"], + "cert_pem": body["client_cert_pem"], "jwt": body["executor_jwt"], + "hmac_seed": bytes.fromhex(body["hmac_seed_hex"]), + "ca_chain": body["ca_chain"], + } + + +def executor_request_headers(reg: dict) -> dict[str, str]: + """mTLS-bypass + JWT + epoch headers for poll / report (no HMAC).""" + return { + "X-Client-Cert-PEM": reg["cert_pem"].replace("\n", "\n"), + "Authorization": f"Bearer {reg['jwt']}", + "X-Executor-Epoch": str(reg["epoch"]), + } + + +def signed_heartbeat_headers(reg: dict, body: bytes) -> dict[str, str]: + """mTLS-bypass + JWT + epoch + HMAC headers for a heartbeat body.""" + import secrets as _s + import time as _t + from dlw.auth.hmac_nonce import compute_hmac + ts = int(_t.time()) + nonce = _s.token_hex(16) + sig = compute_hmac(reg["hmac_seed"], ts=ts, nonce=nonce, body=body) + return { + **executor_request_headers(reg), + "X-HMAC-Timestamp": str(ts), + "X-HMAC-Nonce": nonce, + "X-HMAC-Signature": sig, + "Content-Type": "application/json", + } + + +def make_fake_auth_state( + cert_dir, *, executor_id: str = "test-worker-1", epoch: int = 1, + jwt: str = "fake.jwt.token", hmac_seed: bytes = b"\x09" * 32, +): + """Build a minimal AuthState for executor-side tests that inject a + MockTransport (cert files are never read in that mode).""" + import datetime as _dt + from pathlib import Path as _Path + from dlw.executor.auth_lifecycle import AuthState + far = _dt.datetime.now(_dt.UTC) + _dt.timedelta(hours=24) + return AuthState( + executor_id=executor_id, epoch=epoch, + cert_pem=b"-----BEGIN CERTIFICATE-----\nfake\n-----END CERTIFICATE-----", + key_pem=b"-----BEGIN PRIVATE KEY-----\nfake\n-----END PRIVATE KEY-----", + ca_chain_pem=b"-----BEGIN CERTIFICATE-----\nfakeca\n-----END CERTIFICATE-----", + jwt=jwt, jwt_exp=far, cert_exp=far, + hmac_seed=hmac_seed, cert_dir=_Path(str(cert_dir)), + ) diff --git a/tests/e2e/test_executor_auth_e2e.py b/tests/e2e/test_executor_auth_e2e.py new file mode 100644 index 0000000..ea1ada5 --- /dev/null +++ b/tests/e2e/test_executor_auth_e2e.py @@ -0,0 +1,181 @@ +"""Real-TLS e2e: register → heartbeat full flow (Phase 2 W3a §7.2). + +Spawns uvicorn with actual --ssl-* flags and exercises mTLS end-to-end. +This is the single test that verifies the uvicorn TLS wiring + peer-cert +extraction path; all other W3a auth tests use the header-bypass. +""" +from __future__ import annotations + +import asyncio +import json +import os +import secrets +import socket +import subprocess +import sys +import time + +import httpx +import pytest + +from dlw.auth.ca import bootstrap_ca, ensure_server_cert +from dlw.auth.hmac_nonce import compute_hmac +from dlw.db.base import Base + + +def _free_port() -> int: + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture(scope="module", autouse=True) +async def _bootstrap(engine): + """Create tables at module start; drop all at end (mirrors other e2e modules + so each e2e module gets a clean schema for its seed data).""" + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + + +@pytest.mark.slow +async def test_register_then_heartbeat_full_flow(tmp_path, engine, test_db_name) -> None: + """Spawn uvicorn with real TLS; register an executor; send an HMAC-signed + heartbeat over mTLS. Verifies the uvicorn wiring + peer-cert extraction.""" + # 1. Pre-create the CA + server cert. The subprocess controller's lifespan + # calls bootstrap_ca on the same dir — it's idempotent, so it loads these. + ca_dir = tmp_path / "ca" + ca_dir.mkdir() + ca = bootstrap_ca(ca_dir) + server_cert, server_key = ensure_server_cert(ca, ca_dir, hostname="localhost") + + # 2. Seed the minimal FK rows the subprocess controller needs for startup + # and registration. Tables are created by the module-scoped _bootstrap fixture. + from dlw.db.models.storage import StorageBackend + from dlw.db.models.tenant import Project, Tenant, User + from sqlalchemy.ext.asyncio import async_sessionmaker + factory = async_sessionmaker(engine, expire_on_commit=False) + async with factory() as s: + s.add(Tenant(id=1, slug="d", display_name="D")) + await s.flush() + s.add(Project(id=1, tenant_id=1, name="d")) + s.add(User(id=1, tenant_id=1, oidc_subject="d", + email="d@l", role="tenant_admin")) + s.add(StorageBackend(id=1, tenant_id=1, name="d", + backend_type="s3", config_encrypted=b"")) + await s.commit() + + port = _free_port() + enrollment_token = "e2e-real-tls-enrollment-token" + + # 3. Spawn uvicorn with real TLS. Plumb the test-DB env vars so the + # subprocess controller connects to the SAME test database the + # `engine` fixture created. + env = { + **os.environ, + "DLW_CA_DIR": str(ca_dir), + "DLW_ENROLLMENT_TOKEN": enrollment_token, + "DLW_CONTROLLER_HOSTNAME": "localhost", + "DLW_DB_HOST": os.environ.get("DLW_TEST_PG_HOST", "localhost"), + "DLW_DB_PORT": os.environ.get("DLW_TEST_PG_PORT", "5433"), + "DLW_DB_USER": os.environ.get("DLW_TEST_PG_USER", "postgres"), + "DLW_DB_PASSWORD": os.environ.get("DLW_TEST_PG_PASSWORD", ""), + "DLW_DB_NAME": test_db_name, + "DLW_STRICT_RECOVERY": "false", # don't abort startup on recovery hiccups + } + proc = subprocess.Popen( + [sys.executable, "-m", "uvicorn", "dlw.main:app", + "--host", "127.0.0.1", "--port", str(port), + "--ssl-keyfile", str(server_key), + "--ssl-certfile", str(server_cert), + "--ssl-ca-certs", str(ca_dir / "ca-cert.pem"), + "--ssl-cert-reqs", "1"], # CERT_OPTIONAL: app layer enforces presence + env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + ) + try: + base = f"https://localhost:{port}" + ca_cert_path = str(ca_dir / "ca-cert.pem") + + # 4. Wait for the server to be ready. /healthz may or may not require + # mTLS — probe with the CA bundle; tolerate any non-connection error. + ready = False + for _ in range(60): + if proc.poll() is not None: + out = proc.stdout.read().decode() if proc.stdout else "" + pytest.fail(f"uvicorn exited early (code {proc.returncode}):\n{out}") + try: + async with httpx.AsyncClient(verify=ca_cert_path) as c: + r = await c.get(f"{base}/health/live", timeout=1.0) + if r.status_code < 500: + ready = True + break + except Exception: + await asyncio.sleep(0.3) + if not ready: + out = proc.stdout.read().decode() if proc.stdout else "" + pytest.fail(f"uvicorn did not become ready:\n{out}") + + # 5. Build a CSR + register (no client cert for /register — enrollment token). + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric import ed25519 + from cryptography import x509 + from cryptography.x509.oid import NameOID + key = ed25519.Ed25519PrivateKey.generate() + csr = (x509.CertificateSigningRequestBuilder() + .subject_name(x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, "e2e-tls-worker-1"), + ])) + .sign(key, None)) + csr_pem = csr.public_bytes(serialization.Encoding.PEM).decode() + + async with httpx.AsyncClient(verify=ca_cert_path) as c: + reg = await c.post(f"{base}/api/v1/executors/register", json={ + "host_id": "e2e-tls-host", + "executor_id_proposal": "e2e-tls-worker-1", + "capabilities": {}, "client_csr_pem": csr_pem, + }, headers={"X-Enrollment-Token": enrollment_token}) + assert reg.status_code == 201, reg.text + body = reg.json() + + # 6. Persist the issued client cert + key for the mTLS heartbeat call. + client_cert_path = tmp_path / "client-cert.pem" + client_key_path = tmp_path / "client-key.pem" + client_cert_path.write_text(body["client_cert_pem"]) + client_key_path.write_bytes(key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + )) + + # 7. Heartbeat over real mTLS + JWT + HMAC. + hb_body = json.dumps({"health_score": 100, "parts_dir_bytes": 0}).encode() + ts = int(time.time()) + nonce = secrets.token_hex(16) + sig = compute_hmac(bytes.fromhex(body["hmac_seed_hex"]), + ts=ts, nonce=nonce, body=hb_body) + async with httpx.AsyncClient( + verify=ca_cert_path, + cert=(str(client_cert_path), str(client_key_path)), + ) as c: + hb = await c.post( + f"{base}/api/v1/executors/e2e-tls-worker-1/heartbeat", + content=hb_body, + headers={ + "Authorization": f"Bearer {body['executor_jwt']}", + "X-Executor-Epoch": str(body["epoch"]), + "X-HMAC-Timestamp": str(ts), + "X-HMAC-Nonce": nonce, + "X-HMAC-Signature": sig, + "Content-Type": "application/json", + }, + ) + assert hb.status_code == 200, hb.text + assert hb.json()["status"] == "healthy" + finally: + proc.terminate() + try: + proc.wait(timeout=5) + except subprocess.TimeoutExpired: + proc.kill() diff --git a/tests/e2e/test_executor_e2e.py b/tests/e2e/test_executor_e2e.py index 487303e..ef94e92 100644 --- a/tests/e2e/test_executor_e2e.py +++ b/tests/e2e/test_executor_e2e.py @@ -11,6 +11,7 @@ import json import os import uuid +from pathlib import Path from unittest.mock import MagicMock import boto3 @@ -65,9 +66,13 @@ async def _bootstrap(engine): await conn.run_sync(Base.metadata.drop_all) +_ENROLL = "e2e-w4-enrollment-token" + + @pytest.fixture(autouse=True) def _set_token(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("DLW_BEARER_TOKEN", _TOKEN) + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "1") monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test") monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test") monkeypatch.setenv("AWS_DEFAULT_REGION", "us-east-1") @@ -78,7 +83,7 @@ def _set_token(monkeypatch: pytest.MonkeyPatch): @pytest.mark.slow async def test_e2e_hf_to_s3_full_pipeline( - monkeypatch: pytest.MonkeyPatch, + monkeypatch: pytest.MonkeyPatch, ephemeral_ca, ) -> None: """End-to-end with mocked HF + moto S3, no MockDownloader.""" # Mock HF metadata (controller side) @@ -114,7 +119,16 @@ def hf_handler(request: httpx.Request) -> httpx.Response: s3.create_bucket(Bucket=_BUCKET) from dlw.main import create_app + from dlw.auth.hmac_nonce import NonceStore + from dlw.executor.auth_lifecycle import AuthState + from tests.conftest import register_test_executor app = create_app() + # W3a: inject the auth substrate onto app.state (skip the lifespan + # bootstrap — this test drives the ASGI app directly, no real server). + app.state.ca = ephemeral_ca["ca"] + app.state.jwt_keypair = ephemeral_ca["jwt_keypair"] + app.state.nonce_store = NonceStore(maxsize=1000, ttl_seconds=300) + app.state.enrollment_token = _ENROLL asgi_transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient( @@ -130,10 +144,31 @@ def hf_handler(request: httpx.Request) -> httpx.Response: assert r.status_code == 201, r.text task_id = r.json()["id"] + # W3a: register the executor via mTLS enrollment, then build an + # AuthState the ControllerClient + runner drive off. + reg = await register_test_executor( + ctrl_client, enrollment_token=_ENROLL, + executor_id="e2e-w4-host-worker-1", host_id="e2e-w4-host", + ) + import datetime as _dt + _far = _dt.datetime.now(_dt.UTC) + _dt.timedelta(hours=24) + auth_state = AuthState( + executor_id=reg["executor_id"], epoch=reg["epoch"], + cert_pem=reg["cert_pem"].encode("utf-8"), + key_pem=b"unused-in-asgi-transport-mode", + ca_chain_pem="\n".join(reg["ca_chain"]).encode("utf-8"), + jwt=reg["jwt"], jwt_exp=_far, cert_exp=_far, + hmac_seed=reg["hmac_seed"], cert_dir=Path("."), + ) executor_client = ControllerClient( base_url="http://test", - bearer_token=_TOKEN, + auth_state=auth_state, _transport=asgi_transport, + # ASGI transport has no real TLS — feed the cert via the + # trusted-proxy header bypass (DLW_TLS_TRUSTED_PROXY=1). + _extra_test_headers={ + "X-Client-Cert-PEM": reg["cert_pem"].replace("\n", "\\n"), + }, ) settings = ExecutorSettings( id="e2e-w4-host-worker-1", @@ -154,6 +189,7 @@ def hf_handler(request: httpx.Request) -> httpx.Response: runner = ExecutorRunner( settings=settings, client=executor_client, stream_downloader=downloader, chunk_downloader=MagicMock(), + auth_state=auth_state, # skip load_or_register ) async with executor_client: diff --git a/tests/e2e/test_happy_path.py b/tests/e2e/test_happy_path.py index b60223b..1395a54 100644 --- a/tests/e2e/test_happy_path.py +++ b/tests/e2e/test_happy_path.py @@ -38,18 +38,30 @@ async def _bootstrap(engine): await conn.run_sync(Base.metadata.drop_all) +_ENROLL = "e2e-enrollment-token-happy-path" + + @pytest.fixture(autouse=True) def _set_token(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("DLW_BEARER_TOKEN", _TOKEN) + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "1") get_settings.cache_clear() yield get_settings.cache_clear() @pytest.mark.slow -async def test_full_task_lifecycle_via_http() -> None: +async def test_full_task_lifecycle_via_http(ephemeral_ca) -> None: from dlw.main import create_app + from dlw.auth.hmac_nonce import NonceStore + from tests.conftest import ( + executor_request_headers, register_test_executor, signed_heartbeat_headers, + ) app = create_app() + app.state.ca = ephemeral_ca["ca"] + app.state.jwt_keypair = ephemeral_ca["jwt_keypair"] + app.state.nonce_store = NonceStore(maxsize=1000, ttl_seconds=300) + app.state.enrollment_token = _ENROLL auth = {"Authorization": f"Bearer {_TOKEN}"} async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as c: @@ -64,20 +76,17 @@ async def test_full_task_lifecycle_via_http() -> None: task_id = r.json()["id"] assert r.json()["status"] == "pending" - # 2. Register a worker executor — capture epoch (P2-W1 fence) - r = await c.post("/api/v1/executors/join", json={ - "id": "e2e-worker-1", "host_id": "e2e-host", - "capabilities": {"nic_speed_gbps": 25}, - }, headers=auth) - assert r.status_code == 201, r.text - epoch = r.json()["epoch"] - fence = {**auth, "X-Executor-Epoch": str(epoch)} + # 2. Register a worker executor via mTLS enrollment (P2-W3a). + reg = await register_test_executor( + c, enrollment_token=_ENROLL, + executor_id="e2e-worker-1", host_id="e2e-host", + ) - # 3. Heartbeat (executor reports liveness) - # W2b1: include disk_free_gb so disk pre-flight allows claiming. + # 3. Heartbeat (executor reports liveness) — mTLS + JWT + HMAC. + hb_body = b'{"health_score": 100, "parts_dir_bytes": 0, "disk_free_gb": 100}' r = await c.post("/api/v1/executors/e2e-worker-1/heartbeat", - json={"health_score": 100, "parts_dir_bytes": 0, "disk_free_gb": 100}, - headers=fence) + content=hb_body, + headers=signed_heartbeat_headers(reg, hb_body)) assert r.status_code == 200 assert r.json()["status"] == "healthy" @@ -85,7 +94,8 @@ async def test_full_task_lifecycle_via_http() -> None: sub_ids: list[str] = [] tokens: list[str] = [] for _ in range(2): - r = await c.post("/api/v1/executors/e2e-worker-1/poll", headers=fence) + r = await c.post("/api/v1/executors/e2e-worker-1/poll", + headers=executor_request_headers(reg)) assert r.status_code == 200, r.text assert r.json()["assigned"] is True sub_ids.append(r.json()["subtask"]["id"]) @@ -93,7 +103,8 @@ async def test_full_task_lifecycle_via_http() -> None: assert len(set(sub_ids)) == 2 # 5. Third poll → no work - r = await c.post("/api/v1/executors/e2e-worker-1/poll", headers=fence) + r = await c.post("/api/v1/executors/e2e-worker-1/poll", + headers=executor_request_headers(reg)) assert r.json()["assigned"] is False # 6. Report success for both subtasks (with token + epoch verification) @@ -103,7 +114,7 @@ async def test_full_task_lifecycle_via_http() -> None: "assignment_token": tok, "actual_sha256": "f" * 64, "bytes_downloaded": 100_000_000, - }, headers=fence) + }, headers=executor_request_headers(reg)) assert r.status_code == 200, r.text # 7. Task should now be succeeded diff --git a/tests/executor/test_auth_lifecycle.py b/tests/executor/test_auth_lifecycle.py new file mode 100644 index 0000000..4c2e9d1 --- /dev/null +++ b/tests/executor/test_auth_lifecycle.py @@ -0,0 +1,91 @@ +"""Tests for dlw.executor.auth_lifecycle (Phase 2 W3a §3.11).""" +from __future__ import annotations + +import datetime as _dt +import json +import secrets + +import httpx +import pytest + +from dlw.auth.ca import bootstrap_ca, sign_csr +from dlw.auth.jwt_signing import bootstrap_keypair, sign as jwt_sign +from dlw.executor.auth_lifecycle import AuthState, load_or_register, register + + +def _make_controller_transport(ca, jwt_kp, *, executor_id: str, epoch: int = 1): + """A MockTransport that emulates POST /register: signs the CSR, returns + cert + JWT + hmac_seed.""" + def handler(request: httpx.Request) -> httpx.Response: + if request.url.path == "/api/v1/executors/register": + payload = json.loads(request.content) + csr_pem = payload["client_csr_pem"].encode("utf-8") + cert_pem = sign_csr(ca, csr_pem, executor_id, ttl_hours=24) + token = jwt_sign(jwt_kp, executor_id=executor_id, epoch=epoch, + scopes=["heartbeat", "poll", "report"]) + return httpx.Response(201, json={ + "executor_id": executor_id, "epoch": epoch, + "client_cert_pem": cert_pem.decode("utf-8"), + "ca_chain": [ca.cert_pem.decode("utf-8")], + "executor_jwt": token, + "hmac_seed_hex": secrets.token_bytes(32).hex(), + "cert_renew_in_seconds": 86100, + "jwt_renew_in_seconds": 3300, + }) + return httpx.Response(404) + return httpx.MockTransport(handler) + + +@pytest.mark.asyncio +async def test_register_persists_state_and_parses_expiry(tmp_path, monkeypatch) -> None: + ca = bootstrap_ca(tmp_path / "ca") + jwt_kp = bootstrap_keypair(tmp_path / "ca") + transport = _make_controller_transport(ca, jwt_kp, executor_id="reg-worker-1") + # Patch httpx.AsyncClient so auth_lifecycle's internal client uses our transport. + import dlw.executor.auth_lifecycle as al + orig = httpx.AsyncClient + def patched(*args, **kwargs): + kwargs.pop("verify", None) + kwargs.pop("cert", None) + return orig(*args, transport=transport, **kwargs) + monkeypatch.setattr(al.httpx, "AsyncClient", patched) + + cert_dir = tmp_path / "executor" + state = await register( + controller_url="http://controller", ca_bundle_path=None, + enrollment_token="tok", executor_id="reg-worker-1", + host_id="reg-host", capabilities={}, cert_dir=cert_dir, + ) + assert isinstance(state, AuthState) + assert state.executor_id == "reg-worker-1" + assert state.epoch == 1 + assert state.jwt + assert state.jwt_exp > _dt.datetime.now(_dt.UTC) + assert state.cert_exp > _dt.datetime.now(_dt.UTC) + assert len(state.hmac_seed) == 32 + # Persisted to disk + assert (cert_dir / "client-cert.pem").exists() + assert (cert_dir / "hmac-seed").exists() + + +@pytest.mark.asyncio +async def test_load_or_register_first_run_calls_register(tmp_path, monkeypatch) -> None: + ca = bootstrap_ca(tmp_path / "ca") + jwt_kp = bootstrap_keypair(tmp_path / "ca") + transport = _make_controller_transport(ca, jwt_kp, executor_id="lor-worker-1") + import dlw.executor.auth_lifecycle as al + orig = httpx.AsyncClient + def patched(*args, **kwargs): + kwargs.pop("verify", None) + kwargs.pop("cert", None) + return orig(*args, transport=transport, **kwargs) + monkeypatch.setattr(al.httpx, "AsyncClient", patched) + + cert_dir = tmp_path / "executor" # empty — first run + state = await load_or_register( + cert_dir=cert_dir, controller_url="http://controller", + ca_bundle_path=None, enrollment_token="tok", + executor_id="lor-worker-1", host_id="lor-host", capabilities={}, + ) + assert state.executor_id == "lor-worker-1" + assert (cert_dir / "client-cert.pem").exists() diff --git a/tests/executor/test_cert.py b/tests/executor/test_cert.py new file mode 100644 index 0000000..c1ba69a --- /dev/null +++ b/tests/executor/test_cert.py @@ -0,0 +1,45 @@ +"""Tests for dlw.executor.cert (Phase 2 W3a §3.10).""" +from __future__ import annotations + +from pathlib import Path + +from dlw.executor.cert import build_csr, fingerprint, load, persist + + +def test_build_csr_returns_pem_and_key() -> None: + csr_pem, key_pem = build_csr("host-1-worker-1") + assert csr_pem.startswith(b"-----BEGIN CERTIFICATE REQUEST-----") + assert key_pem.startswith(b"-----BEGIN PRIVATE KEY-----") + # CSR is parseable + has the right CN + from cryptography import x509 + csr = x509.load_pem_x509_csr(csr_pem) + cn = csr.subject.get_attributes_for_oid(x509.oid.NameOID.COMMON_NAME)[0].value + assert cn == "host-1-worker-1" + assert csr.is_signature_valid + + +def test_persist_and_load_roundtrip(tmp_path) -> None: + cert_dir = tmp_path / "executor" + persist(cert_dir, cert_pem=b"CERTDATA", key_pem=b"KEYDATA", + ca_chain_pem=b"CADATA", hmac_seed=b"\x01" * 32) + loaded = load(cert_dir) + assert loaded is not None + cert_pem, key_pem, ca_chain_pem, hmac_seed = loaded + assert cert_pem == b"CERTDATA" + assert key_pem == b"KEYDATA" + assert ca_chain_pem == b"CADATA" + assert hmac_seed == b"\x01" * 32 + + +def test_load_returns_none_when_absent(tmp_path) -> None: + assert load(tmp_path / "nonexistent") is None + + +def test_fingerprint_matches_controller_format(tmp_path) -> None: + # Sign a real cert via the controller-side CA so we can compare formats. + from dlw.auth.ca import bootstrap_ca, sign_csr, fingerprint_of + ca = bootstrap_ca(tmp_path / "ca") + csr_pem, _key = build_csr("fp-host-worker-1") + cert_pem = sign_csr(ca, csr_pem, "fp-host-worker-1") + assert fingerprint(cert_pem) == fingerprint_of(cert_pem) + assert fingerprint(cert_pem).startswith("SHA256:") diff --git a/tests/executor/test_client.py b/tests/executor/test_client.py index ac8050e..a80cd52 100644 --- a/tests/executor/test_client.py +++ b/tests/executor/test_client.py @@ -1,4 +1,4 @@ -"""Tests for ControllerClient using httpx MockTransport — no real network.""" +"""Tests for ControllerClient (W3a: mTLS + JWT + HMAC) using httpx MockTransport.""" from __future__ import annotations import json @@ -7,21 +7,23 @@ import httpx import pytest +from dlw.auth.hmac_nonce import verify_hmac from dlw.executor.client import ControllerClient +from tests.conftest import make_fake_auth_state + + +_HMAC_SEED = b"\x05" * 32 def _mock_handler(request: httpx.Request) -> httpx.Response: """Routes requests to canned responses based on URL path.""" path = request.url.path - body = json.loads(request.content) if request.content else {} - if path == "/api/v1/executors/join" and request.method == "POST": - return httpx.Response(201, json={ - "id": body["id"], "status": "joining", "health_score": 100, - }) if path.endswith("/heartbeat") and request.method == "POST": + body = json.loads(request.content) if request.content else {} return httpx.Response(200, json={ - "id": "x", "status": "healthy", "health_score": body.get("health_score", 100), + "id": "x", "status": "healthy", + "health_score": body.get("health_score", 100), }) if path.endswith("/poll") and request.method == "POST": return httpx.Response(200, json={ @@ -48,19 +50,18 @@ def transport() -> httpx.MockTransport: return httpx.MockTransport(_mock_handler) -@pytest.mark.slow -async def test_join_sends_correct_body(transport) -> None: - async with ControllerClient( - base_url="http://test", bearer_token="t", _transport=transport - ) as c: - r = await c.join(executor_id="ex-1", host_id="h", capabilities={"nic_speed_gbps": 10}) - assert r["status"] == "joining" +@pytest.fixture +def auth_state(tmp_path): + return make_fake_auth_state( + tmp_path, executor_id="ex-1", epoch=1, + jwt="header.payload.sig", hmac_seed=_HMAC_SEED, + ) @pytest.mark.slow -async def test_heartbeat_returns_state(transport) -> None: +async def test_heartbeat_returns_state(transport, auth_state) -> None: async with ControllerClient( - base_url="http://test", bearer_token="t", _transport=transport + base_url="http://test", auth_state=auth_state, _transport=transport, ) as c: r = await c.heartbeat(executor_id="ex-1", health_score=88, parts_dir_bytes=0) assert r["status"] == "healthy" @@ -68,9 +69,9 @@ async def test_heartbeat_returns_state(transport) -> None: @pytest.mark.slow -async def test_poll_returns_assignment(transport) -> None: +async def test_poll_returns_assignment(transport, auth_state) -> None: async with ControllerClient( - base_url="http://test", bearer_token="t", _transport=transport + base_url="http://test", auth_state=auth_state, _transport=transport, ) as c: r = await c.poll(executor_id="ex-1") assert r["assigned"] is True @@ -79,9 +80,9 @@ async def test_poll_returns_assignment(transport) -> None: @pytest.mark.slow -async def test_report_propagates_token(transport) -> None: +async def test_report_propagates_token(transport, auth_state) -> None: async with ControllerClient( - base_url="http://test", bearer_token="t", _transport=transport + base_url="http://test", auth_state=auth_state, _transport=transport, ) as c: r = await c.report( subtask_id=uuid.uuid4(), @@ -94,84 +95,81 @@ async def test_report_propagates_token(transport) -> None: @pytest.mark.slow -async def test_unauthenticated_returns_401(transport) -> None: - """ControllerClient should propagate 401 as an exception (caller decides retry).""" +async def test_unauthenticated_returns_401(auth_state) -> None: + """ControllerClient propagates 401 as an exception (caller decides retry).""" def unauth(_): - return httpx.Response(401, json={"detail": "missing bearer token"}) + return httpx.Response(401, json={"detail": "missing executor JWT"}) t = httpx.MockTransport(unauth) - async with ControllerClient(base_url="http://test", bearer_token="bad", _transport=t) as c: + async with ControllerClient( + base_url="http://test", auth_state=auth_state, _transport=t, + ) as c: with pytest.raises(httpx.HTTPStatusError): await c.heartbeat(executor_id="ex-1", health_score=100, parts_dir_bytes=0) @pytest.mark.slow -async def test_client_persists_epoch_from_join_response() -> None: - """After join(), client should store the response epoch internally.""" +async def test_client_attaches_jwt_and_epoch_headers(tmp_path) -> None: + """Every request carries Authorization: Bearer + X-Executor-Epoch.""" + seen: list[dict[str, str]] = [] + def handler(request: httpx.Request) -> httpx.Response: - return httpx.Response(201, json={ - "id": "h-w-1", "status": "joining", "health_score": 100, "epoch": 7, - }) + seen.append({k.lower(): v for k, v in request.headers.items()}) + return httpx.Response(200, json={"assigned": False}) - transport = httpx.MockTransport(handler) + state = make_fake_auth_state( + tmp_path, executor_id="h-w-1", epoch=7, jwt="jwt-token-7", + hmac_seed=_HMAC_SEED, + ) async with ControllerClient( - base_url="http://test", bearer_token="t", _transport=transport, + base_url="http://test", auth_state=state, + _transport=httpx.MockTransport(handler), ) as c: - await c.join(executor_id="h-w-1", host_id="h", capabilities={}) - assert c.current_epoch() == 7 + await c.poll(executor_id="h-w-1") + + assert seen[0]["authorization"] == "Bearer jwt-token-7" + assert seen[0]["x-executor-epoch"] == "7" @pytest.mark.slow -async def test_client_attaches_epoch_header_on_heartbeat() -> None: - """heartbeat must send X-Executor-Epoch matching the join response.""" - seen_headers: list[dict[str, str]] = [] +async def test_client_signs_heartbeat_with_hmac(tmp_path) -> None: + """heartbeat must carry X-HMAC-* headers and the signature must verify + against the executor's hmac_seed over the exact request body.""" + seen: dict[str, object] = {} def handler(request: httpx.Request) -> httpx.Response: - seen_headers.append({k.lower(): v for k, v in request.headers.items()}) - if request.url.path.endswith("/join"): - return httpx.Response(201, json={ - "id": "h-w-1", "status": "joining", "health_score": 100, "epoch": 5, - }) - return httpx.Response(200, json={ - "id": "h-w-1", "status": "healthy", "health_score": 100, "epoch": 5, - }) - - transport = httpx.MockTransport(handler) + seen["headers"] = {k.lower(): v for k, v in request.headers.items()} + seen["body"] = bytes(request.content) + return httpx.Response(200, json={"id": "h-w-1", "status": "healthy", + "health_score": 100}) + + state = make_fake_auth_state( + tmp_path, executor_id="h-w-1", epoch=3, jwt="jwt-3", + hmac_seed=_HMAC_SEED, + ) async with ControllerClient( - base_url="http://test", bearer_token="t", _transport=transport, + base_url="http://test", auth_state=state, + _transport=httpx.MockTransport(handler), ) as c: - await c.join(executor_id="h-w-1", host_id="h", capabilities={}) await c.heartbeat(executor_id="h-w-1", health_score=100, parts_dir_bytes=0) - assert "x-executor-epoch" in seen_headers[1] - assert seen_headers[1]["x-executor-epoch"] == "5" + h = seen["headers"] + assert "x-hmac-timestamp" in h + assert "x-hmac-nonce" in h + assert "x-hmac-signature" in h + assert verify_hmac( + _HMAC_SEED, + ts=int(h["x-hmac-timestamp"]), + nonce=h["x-hmac-nonce"], + body=seen["body"], + signature_hex=h["x-hmac-signature"], + ) @pytest.mark.slow -async def test_client_attaches_epoch_header_on_report() -> None: - seen_headers: list[dict[str, str]] = [] - import uuid as _uuid - - def handler(request: httpx.Request) -> httpx.Response: - seen_headers.append({k.lower(): v for k, v in request.headers.items()}) - if request.url.path.endswith("/join"): - return httpx.Response(201, json={ - "id": "h-w-1", "status": "joining", "health_score": 100, "epoch": 11, - }) - return httpx.Response(200, json={ - "subtask_status": "succeeded", "task_status": "succeeded", - }) - - transport = httpx.MockTransport(handler) - async with ControllerClient( - base_url="http://test", bearer_token="t", _transport=transport, - ) as c: - await c.join(executor_id="h-w-1", host_id="h", capabilities={}) - await c.report( - subtask_id=_uuid.uuid4(), - status="succeeded", - assignment_token=_uuid.uuid4(), - actual_sha256="a" * 64, - bytes_downloaded=4096, - ) - - assert seen_headers[1]["x-executor-epoch"] == "11" +async def test_current_epoch_reads_auth_state(tmp_path) -> None: + state = make_fake_auth_state(tmp_path, epoch=42) + c = ControllerClient(base_url="http://test", auth_state=state) + assert c.current_epoch() == 42 + new = make_fake_auth_state(tmp_path, epoch=43) + c.update_auth(new) + assert c.current_epoch() == 43 diff --git a/tests/executor/test_runner.py b/tests/executor/test_runner.py index abae2a4..36031e0 100644 --- a/tests/executor/test_runner.py +++ b/tests/executor/test_runner.py @@ -1,4 +1,4 @@ -"""Tests for ExecutorRunner main loop.""" +"""Tests for ExecutorRunner main loop (W3a: mTLS+JWT auth bootstrap).""" from __future__ import annotations import asyncio @@ -12,6 +12,7 @@ from dlw.executor.config import ExecutorSettings from dlw.executor.downloader import DownloadResult, HfS3StreamDownloader from dlw.executor.runner import ExecutorRunner +from tests.conftest import make_fake_auth_state @pytest.fixture @@ -24,11 +25,17 @@ def settings(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> ExecutorSetting return ExecutorSettings() +@pytest.fixture +def auth_state(tmp_path): + """A pre-built AuthState injected into the runner so run() skips + load_or_register (no HTTP call to a controller in these unit tests).""" + return make_fake_auth_state(tmp_path, executor_id="host-test-w1", epoch=1) + + @pytest.mark.slow -async def test_runner_join_then_heartbeat_in_idle(settings) -> None: +async def test_runner_heartbeats_in_idle(settings, auth_state) -> None: """When poll always returns assigned=False, runner heartbeats but does not download.""" client = MagicMock(spec=ControllerClient) - client.join = AsyncMock(return_value={"id": "host-test-w1", "status": "joining", "health_score": 100}) client.heartbeat = AsyncMock(return_value={"id": "x", "status": "healthy", "health_score": 100}) client.poll = AsyncMock(return_value={"assigned": False, "subtask": None, "assignment_token": None}) downloader = MagicMock(spec=HfS3StreamDownloader) @@ -36,26 +43,26 @@ async def test_runner_join_then_heartbeat_in_idle(settings) -> None: runner = ExecutorRunner( settings=settings, client=client, stream_downloader=downloader, chunk_downloader=MagicMock(), + auth_state=auth_state, ) task = asyncio.create_task(runner.run()) await asyncio.sleep(2.5) runner.request_shutdown() await asyncio.wait_for(task, timeout=5) - client.join.assert_awaited_once() + client.update_auth.assert_called_once() assert client.heartbeat.await_count >= 1 assert client.poll.await_count >= 1 downloader.download.assert_not_called() @pytest.mark.slow -async def test_runner_executes_assigned_subtask(settings) -> None: +async def test_runner_executes_assigned_subtask(settings, auth_state) -> None: """When poll returns an assignment, runner downloads + reports.""" sub_id = uuid.uuid4() token = uuid.uuid4() client = MagicMock(spec=ControllerClient) - client.join = AsyncMock(return_value={"id": "host-test-w1", "status": "joining", "health_score": 100}) client.heartbeat = AsyncMock(return_value={"id": "x", "status": "healthy", "health_score": 100}) poll_results = [ @@ -87,12 +94,12 @@ async def test_runner_executes_assigned_subtask(settings) -> None: ) downloader = MagicMock(spec=HfS3StreamDownloader) downloader.download = AsyncMock(return_value=download_result) - client.report = AsyncMock(return_value={"subtask_status": "succeeded", "task_status": "pending"}) runner = ExecutorRunner( settings=settings, client=client, stream_downloader=downloader, chunk_downloader=MagicMock(), + auth_state=auth_state, ) task = asyncio.create_task(runner.run()) await asyncio.sleep(2.5) @@ -108,13 +115,12 @@ async def test_runner_executes_assigned_subtask(settings) -> None: @pytest.mark.slow -async def test_runner_reports_failure_on_download_error(settings) -> None: +async def test_runner_reports_failure_on_download_error(settings, auth_state) -> None: """If downloader raises, runner reports status=failed with the error message.""" sub_id = uuid.uuid4() token = uuid.uuid4() client = MagicMock(spec=ControllerClient) - client.join = AsyncMock(return_value={"id": "host-test-w1", "status": "joining", "health_score": 100}) client.heartbeat = AsyncMock(return_value={"id": "x", "status": "healthy", "health_score": 100}) client.poll = AsyncMock(side_effect=[ { @@ -137,6 +143,7 @@ async def test_runner_reports_failure_on_download_error(settings) -> None: runner = ExecutorRunner( settings=settings, client=client, stream_downloader=downloader, chunk_downloader=MagicMock(), + auth_state=auth_state, ) task = asyncio.create_task(runner.run()) await asyncio.sleep(2.5) @@ -150,10 +157,9 @@ async def test_runner_reports_failure_on_download_error(settings) -> None: @pytest.mark.slow -async def test_runner_graceful_shutdown(settings) -> None: +async def test_runner_graceful_shutdown(settings, auth_state) -> None: """request_shutdown() during execution should cleanly cancel the loops.""" client = MagicMock(spec=ControllerClient) - client.join = AsyncMock(return_value={"id": "x", "status": "joining", "health_score": 100}) client.heartbeat = AsyncMock(return_value={"id": "x", "status": "healthy", "health_score": 100}) client.poll = AsyncMock(return_value={"assigned": False, "subtask": None, "assignment_token": None}) downloader = MagicMock(spec=HfS3StreamDownloader) @@ -161,6 +167,7 @@ async def test_runner_graceful_shutdown(settings) -> None: runner = ExecutorRunner( settings=settings, client=client, stream_downloader=downloader, chunk_downloader=MagicMock(), + auth_state=auth_state, ) task = asyncio.create_task(runner.run()) await asyncio.sleep(0.3) @@ -170,11 +177,10 @@ async def test_runner_graceful_shutdown(settings) -> None: @pytest.mark.slow async def test_runner_passes_assignment_with_repo_and_storage( - monkeypatch: pytest.MonkeyPatch, + monkeypatch: pytest.MonkeyPatch, tmp_path: Path, ) -> None: """W4: runner forwards repo_id/revision/storage_config from /poll to downloader.""" from dlw.executor.downloader import Assignment, DownloadResult - from dlw.schemas.storage import StorageConfig captured: dict[str, object] = {} @@ -189,9 +195,9 @@ async def download(self, *, assignment: Assignment) -> DownloadResult: ) class FakeClient: - async def join(self, **kw): pass - async def heartbeat(self, **kw): pass joined_polls = 0 + def update_auth(self, _state): pass + async def heartbeat(self, **kw): pass async def poll(self, **kw): FakeClient.joined_polls += 1 if FakeClient.joined_polls > 1: @@ -220,6 +226,7 @@ async def report(self, **kw): captured["report_kw"] = kw runner = ExecutorRunner( settings=settings, client=FakeClient(), stream_downloader=FakeDownloader(), chunk_downloader=MagicMock(), + auth_state=make_fake_auth_state(tmp_path, executor_id="host-r-worker-1"), ) run_task = asyncio.create_task(runner.run()) await asyncio.sleep(2) @@ -236,67 +243,61 @@ async def report(self, **kw): captured["report_kw"] = kw @pytest.mark.slow -async def test_runner_rejoins_on_epoch_mismatch() -> None: - """Runner: on 401 EPOCH_MISMATCH, abort current poll + re-join + continue.""" +async def test_runner_reregisters_on_poll_401( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path, +) -> None: + """W3a: on a 401 from poll, the runner re-registers (load_or_register) and + continues. Generalizes the W1 EPOCH_MISMATCH re-join path.""" import httpx as _httpx - import uuid as _u - from dlw.executor.runner import ExecutorRunner - from dlw.executor.config import ExecutorSettings + import dlw.executor.runner as runner_mod - class FakeClient: - def __init__(self): - self.calls: list[str] = [] - self._poll_returns_401_once = True - self._epoch: int | None = None + register_calls: list[int] = [] - def current_epoch(self): return self._epoch + async def fake_load_or_register(**kwargs): + register_calls.append(1) + return make_fake_auth_state( + tmp_path, executor_id=kwargs["executor_id"], + epoch=len(register_calls), + ) - async def __aenter__(self): return self - async def __aexit__(self, *a): pass - - async def join(self, *, executor_id, host_id, capabilities): - self.calls.append("join") - self._epoch = 2 if "join" in self.calls[:-1] else 1 - return {"id": executor_id, "epoch": self._epoch, "status": "joining", - "health_score": 100} + monkeypatch.setattr(runner_mod, "load_or_register", fake_load_or_register) + class FakeClient: + def __init__(self): + self.calls: list[str] = [] + self._poll_401_once = True + def update_auth(self, _state): + self.calls.append("update_auth") async def heartbeat(self, **kw): self.calls.append("heartbeat") - async def poll(self, **kw): self.calls.append("poll") - if self._poll_returns_401_once: - self._poll_returns_401_once = False + if self._poll_401_once: + self._poll_401_once = False req = _httpx.Request("POST", "http://x/poll") - resp = _httpx.Response( - 401, json={"detail": {"code": "EPOCH_MISMATCH", - "expected": 2, "got": 1}} - ) - raise _httpx.HTTPStatusError( - "401", request=req, response=resp, - ) + resp = _httpx.Response(401, json={"detail": "invalid JWT"}) + raise _httpx.HTTPStatusError("401", request=req, response=resp) return {"assigned": False} - async def report(self, **kw): pass - settings = ExecutorSettings( - id="rj-host-worker-1", host_id="rj-host", bearer_token="t", - heartbeat_interval_seconds=1, poll_interval_seconds=1, - ) - class FakeDl: async def download(self, **kw): raise AssertionError("downloader should NOT be invoked in this test") + settings = ExecutorSettings( + id="rj-host-worker-1", host_id="rj-host", bearer_token="t", + heartbeat_interval_seconds=1, poll_interval_seconds=1, + ) runner = ExecutorRunner( settings=settings, client=FakeClient(), stream_downloader=FakeDl(), chunk_downloader=MagicMock(), + # auth_state=None → run() calls (the monkeypatched) load_or_register. ) run_task = asyncio.create_task(runner.run()) - await asyncio.sleep(2.5) # let 1-2 poll cycles run + await asyncio.sleep(2.5) runner.request_shutdown() await asyncio.wait_for(run_task, timeout=3) - # join called at least twice (initial + after EPOCH_MISMATCH) - assert runner._client.calls.count("join") >= 2 + # load_or_register called at least twice: initial bootstrap + after the 401. + assert len(register_calls) >= 2 assert "poll" in runner._client.calls diff --git a/tests/services/test_executor_service.py b/tests/services/test_executor_service.py index 93b6968..69516d4 100644 --- a/tests/services/test_executor_service.py +++ b/tests/services/test_executor_service.py @@ -1,4 +1,4 @@ -"""Tests for executor_service: join + heartbeat upsert.""" +"""Tests for executor_service: upsert_executor_with_cert + heartbeat upsert (W3a).""" from __future__ import annotations import asyncio @@ -9,8 +9,19 @@ from dlw.db.base import Base from dlw.db.models.executor import Executor -from dlw.schemas.executor import ExecutorHeartbeat, ExecutorJoin -from dlw.services.executor_service import join_executor, record_heartbeat +from dlw.schemas.executor import ExecutorHeartbeat +from dlw.services.executor_service import record_heartbeat, upsert_executor_with_cert + + +_FP = "SHA256:" + "ab" * 32 +_SEED = b"\x07" * 32 + + +def _upsert_kwargs(executor_id: str, host_id: str, *, fp: str = _FP): + return dict( + executor_id=executor_id, host_id=host_id, capabilities={}, + cert_fingerprint=fp, hmac_seed=_SEED, + ) @pytest.fixture(scope="module", autouse=True) @@ -23,30 +34,33 @@ async def _create_tables(engine): @pytest.mark.slow -async def test_join_creates_executor(db_session: AsyncSession) -> None: - body = ExecutorJoin( - id="host-a-w1", host_id="host-a", capabilities={"nic_speed_gbps": 10}, +async def test_upsert_creates_executor_with_cert_and_seed(db_session: AsyncSession) -> None: + ex = await upsert_executor_with_cert( + db_session, + **_upsert_kwargs("host-a-w1", "host-a"), ) - ex = await join_executor(db_session, body) await db_session.commit() assert ex.id == "host-a-w1" assert ex.status == "joining" assert ex.health_score == 100 + assert ex.cert_fingerprint == _FP + assert bytes(ex.hmac_seed_encrypted) == _SEED @pytest.mark.slow -async def test_join_idempotent(db_session: AsyncSession) -> None: - body = ExecutorJoin(id="host-b-w1", host_id="host-b") - await join_executor(db_session, body) +async def test_upsert_idempotent(db_session: AsyncSession) -> None: + await upsert_executor_with_cert(db_session, **_upsert_kwargs("host-b-w1", "host-b")) await db_session.commit() - again = await join_executor(db_session, body) + again = await upsert_executor_with_cert( + db_session, **_upsert_kwargs("host-b-w1", "host-b"), + ) await db_session.commit() assert again.id == "host-b-w1" @pytest.mark.slow async def test_heartbeat_updates_health_and_timestamp(db_session: AsyncSession) -> None: - await join_executor(db_session, ExecutorJoin(id="host-c-w1", host_id="host-c")) + await upsert_executor_with_cert(db_session, **_upsert_kwargs("host-c-w1", "host-c")) await db_session.commit() before = datetime.now(UTC) ex = await record_heartbeat( @@ -76,41 +90,40 @@ async def env(): @pytest.mark.slow -async def test_join_first_time_returns_epoch_1(db_session: AsyncSession, env) -> None: - """First /join for a brand new executor_id assigns epoch=1.""" - body = ExecutorJoin(id="new-host-worker-1", host_id="new-host") - ex = await join_executor(db_session, body) +async def test_upsert_first_time_returns_epoch_1(db_session: AsyncSession, env) -> None: + """First upsert for a brand new executor_id assigns epoch=1.""" + ex = await upsert_executor_with_cert( + db_session, **_upsert_kwargs("new-host-worker-1", "new-host"), + ) assert ex.epoch == 1 @pytest.mark.slow -async def test_join_existing_executor_increments_epoch( +async def test_upsert_existing_executor_increments_epoch( db_session: AsyncSession, env, ) -> None: - """Repeated /join for same id bumps epoch atomically: 1 → 2 → 3.""" - body = ExecutorJoin(id="bump-host-worker-1", host_id="bump-host") - ex1 = await join_executor(db_session, body); await db_session.commit() + """Repeated upsert for same id bumps epoch atomically: 1 → 2 → 3.""" + kw = _upsert_kwargs("bump-host-worker-1", "bump-host") + ex1 = await upsert_executor_with_cert(db_session, **kw); await db_session.commit() assert ex1.epoch == 1 - ex2 = await join_executor(db_session, body); await db_session.commit() + ex2 = await upsert_executor_with_cert(db_session, **kw); await db_session.commit() assert ex2.epoch == 2 - ex3 = await join_executor(db_session, body); await db_session.commit() + ex3 = await upsert_executor_with_cert(db_session, **kw); await db_session.commit() assert ex3.epoch == 3 @pytest.mark.slow -async def test_join_concurrent_returns_distinct_epochs( - engine, env, -) -> None: - """asyncio.gather × 2 join calls for the same id must yield distinct epochs.""" +async def test_upsert_concurrent_returns_distinct_epochs(engine, env) -> None: + """asyncio.gather × 2 upsert calls for the same id must yield distinct epochs.""" from sqlalchemy.ext.asyncio import async_sessionmaker factory = async_sessionmaker(engine, expire_on_commit=False) - body = ExecutorJoin(id="race-host-worker-1", host_id="race-host") + kw = _upsert_kwargs("race-host-worker-1", "race-host") - async def join_in_own_session(): + async def upsert_in_own_session(): async with factory() as s: - ex = await join_executor(s, body) + ex = await upsert_executor_with_cert(s, **kw) await s.commit() return ex.epoch - epochs = await asyncio.gather(join_in_own_session(), join_in_own_session()) + epochs = await asyncio.gather(upsert_in_own_session(), upsert_in_own_session()) assert sorted(epochs) == [1, 2] # PG atomic; one wins INSERT, the other UPDATE diff --git a/tools/lint_invariants.py b/tools/lint_invariants.py index 9bcf914..b000931 100644 --- a/tools/lint_invariants.py +++ b/tools/lint_invariants.py @@ -197,6 +197,34 @@ def check_d10_host_affinity_test_owner() -> list[str]: return [] +def check_no_bearer_on_executor_routes() -> list[str]: + """W3a §3.15: forbid Depends(require_bearer) in executor/subtask route files. + Those endpoints must use mTLS + JWT (not the UI shared-secret bearer).""" + errors: list[str] = [] + files = [ + ROOT / "src" / "dlw" / "api" / "executors.py", + ROOT / "src" / "dlw" / "api" / "subtasks.py", + ] + import ast as _ast + for f in files: + if not f.exists(): + continue + tree = _ast.parse(f.read_text(encoding="utf-8")) + for node in _ast.walk(tree): + if (isinstance(node, _ast.Call) + and isinstance(node.func, _ast.Name) + and node.func.id == "Depends" + and node.args + and isinstance(node.args[0], _ast.Name) + and node.args[0].id == "require_bearer"): + errors.append( + f"{f.relative_to(ROOT)}:{node.lineno}: " + f"require_bearer forbidden on executor/subtask routes " + f"(use mTLS + JWT)" + ) + return errors + + def main() -> int: failures: list[str] = [] @@ -334,6 +362,7 @@ def main() -> int: failures.extend(check_subtask_status_domain()) failures.extend(check_task_status_domain()) failures.extend(check_d10_host_affinity_test_owner()) + failures.extend(check_no_bearer_on_executor_routes()) # --- Report --- if failures: diff --git a/uv.lock b/uv.lock index b55320f..4b31a0c 100644 --- a/uv.lock +++ b/uv.lock @@ -332,55 +332,31 @@ wheels = [ [[package]] name = "cryptography" -version = "48.0.0" +version = "43.0.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cffi", marker = "platform_python_implementation != 'PyPy'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/9f/a9/db8f313fdcd85d767d4973515e1db101f9c71f95fced83233de224673757/cryptography-48.0.0.tar.gz", hash = "sha256:5c3932f4436d1cccb036cb0eaef46e6e2db91035166f1ad6505c3c9d5a635920", size = 832984, upload-time = "2026-05-04T22:59:38.133Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/df/3d/01f6dd9190170a5a241e0e98c2d04be3664a9e6f5b9b872cde63aff1c3dd/cryptography-48.0.0-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:0c558d2cdffd8f4bbb30fc7134c74d2ca9a476f830bb053074498fbc86f41ed6", size = 8001587, upload-time = "2026-05-04T22:57:36.803Z" }, - { url = "https://files.pythonhosted.org/packages/b2/6e/e90527eef33f309beb811cf7c982c3aeffcce8e3edb178baa4ca3ae4a6fa/cryptography-48.0.0-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f5333311663ea94f75dd408665686aaf426563556bb5283554a3539177e03b8c", size = 4690433, upload-time = "2026-05-04T22:57:40.373Z" }, - { url = "https://files.pythonhosted.org/packages/90/04/673510ed51ddff56575f306cf1617d80411ee76831ccd3097599140efdfe/cryptography-48.0.0-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7995ef305d7165c3f11ae07f2517e5a4f1d5c18da1376a0a9ed496336b69e5f3", size = 4710620, upload-time = "2026-05-04T22:57:42.935Z" }, - { url = "https://files.pythonhosted.org/packages/14/d5/e9c4ef932c8d800490c34d8bd589d64a31d5890e27ec9e9ad532be893294/cryptography-48.0.0-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:40ba1f85eaa6959837b1d51c9767e230e14612eea4ef110ee8854ada22da1bf5", size = 4696283, upload-time = "2026-05-04T22:57:45.294Z" }, - { url = "https://files.pythonhosted.org/packages/0c/29/174b9dfb60b12d59ecfc6cfa04bc88c21b42a54f01b8aae09bb6e51e4c7f/cryptography-48.0.0-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:369a6348999f94bbd53435c894377b20ab95f25a9065c283570e70150d8abc3c", size = 5296573, upload-time = "2026-05-04T22:57:47.933Z" }, - { url = "https://files.pythonhosted.org/packages/95/38/0d29a6fd7d0d1373f0c0c88a04ba20e359b257753ac497564cd660fc1d55/cryptography-48.0.0-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:a0e692c683f4df67815a2d258b324e66f4738bd7a96a218c826dce4f4bd05d8f", size = 4743677, upload-time = "2026-05-04T22:57:50.067Z" }, - { url = "https://files.pythonhosted.org/packages/30/be/eef653013d5c63b6a490529e0316f9ac14a37602965d4903efed1399f32b/cryptography-48.0.0-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:18349bbc56f4743c8b12dc32e2bccb2cf83ee8b69a3bba74ef8ae857e26b3d25", size = 4330808, upload-time = "2026-05-04T22:57:52.301Z" }, - { url = "https://files.pythonhosted.org/packages/84/9e/500463e87abb7a0a0f9f256ec21123ecde0a7b5541a15e840ea54551fd81/cryptography-48.0.0-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:7e8eac43dfca5c4cccc6dad9a80504436fca53bb9bc3100a2386d730fbe6b602", size = 4695941, upload-time = "2026-05-04T22:57:54.603Z" }, - { url = "https://files.pythonhosted.org/packages/e3/dc/7303087450c2ec9e7fbb750e17c2abfbc658f23cbd0e54009509b7cc4091/cryptography-48.0.0-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:9ccdac7d40688ecb5a3b4a604b8a88c8002e3442d6c60aead1db2a89a041560c", size = 5252579, upload-time = "2026-05-04T22:57:57.207Z" }, - { url = "https://files.pythonhosted.org/packages/d0/c0/7101d3b7215edcdc90c45da544961fd8ed2d6448f77577460fa75a8443f7/cryptography-48.0.0-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:bd72e68b06bb1e96913f97dd4901119bc17f39d4586a5adf2d3e47bc2b9d58b5", size = 4743326, upload-time = "2026-05-04T22:57:59.535Z" }, - { url = "https://files.pythonhosted.org/packages/ac/d8/5b833bad13016f562ab9d063d68199a4bd121d18458e439515601d3357ec/cryptography-48.0.0-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:59baa2cb386c4f0b9905bd6eb4c2a79a69a128408fd31d32ca4d7102d4156321", size = 4826672, upload-time = "2026-05-04T22:58:01.996Z" }, - { url = "https://files.pythonhosted.org/packages/98/e1/7074eb8bf3c135558c73fc2bcf0f5633f912e6fb87e868a55c454080ef09/cryptography-48.0.0-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:9249e3cd978541d665967ac2cb2787fd6a62bddf1e75b3e347a594d7dacf4f74", size = 4972574, upload-time = "2026-05-04T22:58:03.968Z" }, - { url = "https://files.pythonhosted.org/packages/04/70/e5a1b41d325f797f39427aa44ef8baf0be500065ab6d8e10369d850d4a4f/cryptography-48.0.0-cp311-abi3-win32.whl", hash = "sha256:9c459db21422be75e2809370b829a87eb37f74cd785fc4aa9ea1e5f43b47cda4", size = 3294868, upload-time = "2026-05-04T22:58:06.467Z" }, - { url = "https://files.pythonhosted.org/packages/f4/ac/8ac51b4a5fc5932eb7ee5c517ba7dc8cd834f0048962b6b352f00f41ebf9/cryptography-48.0.0-cp311-abi3-win_amd64.whl", hash = "sha256:5b012212e08b8dd5edc78ef54da83dd9892fd9105323b3993eff6bea65dc21d7", size = 3817107, upload-time = "2026-05-04T22:58:08.845Z" }, - { url = "https://files.pythonhosted.org/packages/6b/84/70e3feea9feea87fd7cbe77efb2712ae1e3e6edf10749dc6e95f4e60e455/cryptography-48.0.0-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:3cb07a3ed6431663cd321ea8a000a1314c74211f823e4177fefa2255e057d1ec", size = 7986556, upload-time = "2026-05-04T22:58:11.172Z" }, - { url = "https://files.pythonhosted.org/packages/89/6e/18e07a618bb5442ba10cf4df16e99c071365528aa570dfcb8c02e25a303b/cryptography-48.0.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8c7378637d7d88016fa6791c159f698b3d3eed28ebf844ac36b9dc04a14dae18", size = 4684776, upload-time = "2026-05-04T22:58:13.712Z" }, - { url = "https://files.pythonhosted.org/packages/be/6a/4ea3b4c6c6759794d5ee2103c304a5076dc4b19ae1f9fe47dba439e159e9/cryptography-48.0.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cc90c0b39b2e3c65ef52c804b72e3c58f8a04ab2a1871272798e5f9572c17d20", size = 4698121, upload-time = "2026-05-04T22:58:16.448Z" }, - { url = "https://files.pythonhosted.org/packages/2f/59/6ff6ad6cae03bb887da2a5860b2c9805f8dac969ef01ce563336c49bd1d1/cryptography-48.0.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:76341972e1eff8b4bea859f09c0d3e64b96ce931b084f9b9b7db8ef364c30eff", size = 4690042, upload-time = "2026-05-04T22:58:18.544Z" }, - { url = "https://files.pythonhosted.org/packages/ca/b4/fc334ed8cfd705aca282fe4d8f5ae64a8e0f74932e9feecb344610cf6e4d/cryptography-48.0.0-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:55b7718303bf06a5753dcdccf2f3945cf18ad7bffde41b61226e4db31ab89a9c", size = 5282526, upload-time = "2026-05-04T22:58:20.75Z" }, - { url = "https://files.pythonhosted.org/packages/11/08/9f8c5386cc4cd90d8255c7cdd0f5baf459a08502a09de30dc51f553d38dc/cryptography-48.0.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:a64697c641c7b1b2178e573cbc31c7c6684cd56883a478d75143dbb7118036db", size = 4733116, upload-time = "2026-05-04T22:58:23.627Z" }, - { url = "https://files.pythonhosted.org/packages/b8/77/99307d7574045699f8805aa500fa0fb83422d115b5400a064ddd306d7750/cryptography-48.0.0-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:561215ea3879cb1cbbf272867e2efda62476f240fb58c64de6b393ae19246741", size = 4316030, upload-time = "2026-05-04T22:58:25.581Z" }, - { url = "https://files.pythonhosted.org/packages/fd/36/a608b98337af3cb2aff4818e406649d30572b7031918b04c87d979495348/cryptography-48.0.0-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:ad64688338ed4bc1a6618076ba75fd7194a5f1797ac60b47afe926285adb3166", size = 4689640, upload-time = "2026-05-04T22:58:27.747Z" }, - { url = "https://files.pythonhosted.org/packages/dd/a6/825010a291b4438aecc1f568bc428189fc1175515223632477c07dc0a6df/cryptography-48.0.0-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:906cbf0670286c6e0044156bc7d4af9cbb0ef6db9f73e52c3ec56ba6bdde5336", size = 5237657, upload-time = "2026-05-04T22:58:29.848Z" }, - { url = "https://files.pythonhosted.org/packages/b9/09/4e76a09b4caa29aad535ddc806f5d4c5d01885bd978bd984fbc6ca032cae/cryptography-48.0.0-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:ea8990436d914540a40ab24b6a77c0969695ed52f4a4874c5137ccf7045a7057", size = 4732362, upload-time = "2026-05-04T22:58:32.009Z" }, - { url = "https://files.pythonhosted.org/packages/18/78/444fa04a77d0cb95f417dda20d450e13c56ba8e5220fc892a1658f44f882/cryptography-48.0.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:c18684a7f0cc9a3cb60328f496b8e3372def7c5d2df39ac267878b05565aaaae", size = 4819580, upload-time = "2026-05-04T22:58:34.254Z" }, - { url = "https://files.pythonhosted.org/packages/38/85/ea67067c70a1fd4be2c63d35eeed82658023021affccc7b17705f8527dd2/cryptography-48.0.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:9be5aafa5736574f8f15f262adc81b2a9869e2cfe9014d52a44633905b40d52c", size = 4963283, upload-time = "2026-05-04T22:58:36.376Z" }, - { url = "https://files.pythonhosted.org/packages/75/54/cc6d0f3deac3e81c7f847e8a189a12b6cdd65059b43dad25d4316abd849a/cryptography-48.0.0-cp314-cp314t-win32.whl", hash = "sha256:c17dfe85494deaeddc5ce251aebd1d60bbe6afc8b62071bb0b469431a000124f", size = 3270954, upload-time = "2026-05-04T22:58:38.791Z" }, - { url = "https://files.pythonhosted.org/packages/49/67/cc947e288c0758a4e5473d1dcb743037ab7785541265a969240b8885441a/cryptography-48.0.0-cp314-cp314t-win_amd64.whl", hash = "sha256:27241b1dc9962e056062a8eef1991d02c3a24569c95975bd2322a8a52c6e5e12", size = 3797313, upload-time = "2026-05-04T22:58:40.746Z" }, - { url = "https://files.pythonhosted.org/packages/f2/63/61d4a4e1c6b6bab6ce1e213cd36a24c415d90e76d78c5eb8577c5541d2e8/cryptography-48.0.0-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:58d00498e8933e4a194f3076aee1b4a97dfec1a6da444535755822fe5d8b0b86", size = 7983482, upload-time = "2026-05-04T22:58:43.769Z" }, - { url = "https://files.pythonhosted.org/packages/d5/ac/f5b5995b87770c693e2596559ffafe195b4033a57f14a82268a2842953f3/cryptography-48.0.0-cp39-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:614d0949f4790582d2cc25553abd09dd723025f0c0e7c67376a1d77196743d6e", size = 4683266, upload-time = "2026-05-04T22:58:46.064Z" }, - { url = "https://files.pythonhosted.org/packages/ec/c6/8b14f67e18338fbc4adb76f66c001f5c3610b3e2d1837f268f47a347dbbb/cryptography-48.0.0-cp39-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7ce4bfae76319a532a2dc68f82cc32f5676ee792a983187dac07183690e5c66f", size = 4696228, upload-time = "2026-05-04T22:58:48.22Z" }, - { url = "https://files.pythonhosted.org/packages/ea/73/f808fbae9514bd91b47875b003f13e284c8c6bdfd904b7944e803937eec1/cryptography-48.0.0-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:2eb992bbd4661238c5a397594c83f5b4dc2bc5b848c365c8f991b6780efcc5c7", size = 4689097, upload-time = "2026-05-04T22:58:50.9Z" }, - { url = "https://files.pythonhosted.org/packages/93/01/d86632d7d28db8ae83221995752eeb6639ffb374c2d22955648cf8d52797/cryptography-48.0.0-cp39-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:22a5cb272895dce158b2cacdfdc3debd299019659f42947dbdac6f32d68fe832", size = 5283582, upload-time = "2026-05-04T22:58:53.017Z" }, - { url = "https://files.pythonhosted.org/packages/02/e1/50edc7a50334807cc4791fc4a0ce7468b4a1416d9138eab358bfc9a3d70b/cryptography-48.0.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:2b4d59804e8408e2fea7d1fbaf218e5ec984325221db76e6a241a9abd6cdd95c", size = 4730479, upload-time = "2026-05-04T22:58:55.611Z" }, - { url = "https://files.pythonhosted.org/packages/6f/af/99a582b1b1641ff5911ac559beb45097cf79efd4ead4657f578ef1af2d47/cryptography-48.0.0-cp39-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:984a20b0f62a26f48a3396c72e4bc34c66e356d356bf370053066b3b6d54634a", size = 4326481, upload-time = "2026-05-04T22:58:57.607Z" }, - { url = "https://files.pythonhosted.org/packages/90/ee/89aa26a06ef0a7d7611788ffd571a7c50e368cc6a4d5eef8b4884e866edb/cryptography-48.0.0-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:5a5ed8fde7a1d09376ca0b40e68cd59c69fe23b1f9768bd5824f54681626032a", size = 4688713, upload-time = "2026-05-04T22:59:00.077Z" }, - { url = "https://files.pythonhosted.org/packages/70/ba/bcb1b0bb7a33d4c7c0c4d4c7874b4a62ae4f56113a5f4baefa362dfb1f0f/cryptography-48.0.0-cp39-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:8cd666227ef7af430aa5914a9910e0ddd703e75f039cef0825cd0da71b6b711a", size = 5238165, upload-time = "2026-05-04T22:59:02.317Z" }, - { url = "https://files.pythonhosted.org/packages/c9/70/ca4003b1ce5ca3dc3186ada51908c8a9b9ff7d5cab83cc0d43ee14ec144f/cryptography-48.0.0-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:9071196d81abc88b3516ac8cdfad32e2b66dd4a5393a8e68a961e9161ddc6239", size = 4729947, upload-time = "2026-05-04T22:59:05.255Z" }, - { url = "https://files.pythonhosted.org/packages/44/a0/4ec7cf774207905aef1a8d11c3750d5a1db805eb380ee4e16df317870128/cryptography-48.0.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:1e2d54c8be6152856a36f0882ab231e70f8ec7f14e93cf87db8a2ed056bf160c", size = 4822059, upload-time = "2026-05-04T22:59:07.802Z" }, - { url = "https://files.pythonhosted.org/packages/1e/75/a2e55f99c16fcac7b5d6c1eb19ad8e00799854d6be5ca845f9259eae1681/cryptography-48.0.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a5da777e32ffed6f85a7b2b3f7c5cbc88c146bfcd0a1d7baf5fcc6c52ee35dd4", size = 4960575, upload-time = "2026-05-04T22:59:09.851Z" }, - { url = "https://files.pythonhosted.org/packages/b8/23/6e6f32143ab5d8b36ca848a502c4bcd477ae75b9e1677e3530d669062578/cryptography-48.0.0-cp39-abi3-win32.whl", hash = "sha256:77a2ccbbe917f6710e05ba9adaa25fb5075620bf3ea6fb751997875aff4ae4bd", size = 3279117, upload-time = "2026-05-04T22:59:12.019Z" }, - { url = "https://files.pythonhosted.org/packages/9d/9a/0fea98a70cf1749d41d738836f6349d97945f7c89433a259a6c2642eefeb/cryptography-48.0.0-cp39-abi3-win_amd64.whl", hash = "sha256:16cd65b9330583e4619939b3a3843eec1e6e789744bb01e7c7e2e62e33c239c8", size = 3792100, upload-time = "2026-05-04T22:59:14.884Z" }, +sdist = { url = "https://files.pythonhosted.org/packages/0d/05/07b55d1fa21ac18c3a8c79f764e2514e6f6a9698f1be44994f5adf0d29db/cryptography-43.0.3.tar.gz", hash = "sha256:315b9001266a492a6ff443b61238f956b214dbec9910a081ba5b6646a055a805", size = 686989, upload-time = "2024-10-18T15:58:32.918Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1f/f3/01fdf26701a26f4b4dbc337a26883ad5bccaa6f1bbbdd29cd89e22f18a1c/cryptography-43.0.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:bf7a1932ac4176486eab36a19ed4c0492da5d97123f1406cf15e41b05e787d2e", size = 6225303, upload-time = "2024-10-18T15:57:36.753Z" }, + { url = "https://files.pythonhosted.org/packages/a3/01/4896f3d1b392025d4fcbecf40fdea92d3df8662123f6835d0af828d148fd/cryptography-43.0.3-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63efa177ff54aec6e1c0aefaa1a241232dcd37413835a9b674b6e3f0ae2bfd3e", size = 3760905, upload-time = "2024-10-18T15:57:39.166Z" }, + { url = "https://files.pythonhosted.org/packages/0a/be/f9a1f673f0ed4b7f6c643164e513dbad28dd4f2dcdf5715004f172ef24b6/cryptography-43.0.3-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e1ce50266f4f70bf41a2c6dc4358afadae90e2a1e5342d3c08883df1675374f", size = 3977271, upload-time = "2024-10-18T15:57:41.227Z" }, + { url = "https://files.pythonhosted.org/packages/4e/49/80c3a7b5514d1b416d7350830e8c422a4d667b6d9b16a9392ebfd4a5388a/cryptography-43.0.3-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:443c4a81bb10daed9a8f334365fe52542771f25aedaf889fd323a853ce7377d6", size = 3746606, upload-time = "2024-10-18T15:57:42.903Z" }, + { url = "https://files.pythonhosted.org/packages/0e/16/a28ddf78ac6e7e3f25ebcef69ab15c2c6be5ff9743dd0709a69a4f968472/cryptography-43.0.3-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:74f57f24754fe349223792466a709f8e0c093205ff0dca557af51072ff47ab18", size = 3986484, upload-time = "2024-10-18T15:57:45.434Z" }, + { url = "https://files.pythonhosted.org/packages/01/f5/69ae8da70c19864a32b0315049866c4d411cce423ec169993d0434218762/cryptography-43.0.3-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9762ea51a8fc2a88b70cf2995e5675b38d93bf36bd67d91721c309df184f49bd", size = 3852131, upload-time = "2024-10-18T15:57:47.267Z" }, + { url = "https://files.pythonhosted.org/packages/fd/db/e74911d95c040f9afd3612b1f732e52b3e517cb80de8bf183be0b7d413c6/cryptography-43.0.3-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:81ef806b1fef6b06dcebad789f988d3b37ccaee225695cf3e07648eee0fc6b73", size = 4075647, upload-time = "2024-10-18T15:57:49.684Z" }, + { url = "https://files.pythonhosted.org/packages/56/48/7b6b190f1462818b324e674fa20d1d5ef3e24f2328675b9b16189cbf0b3c/cryptography-43.0.3-cp37-abi3-win32.whl", hash = "sha256:cbeb489927bd7af4aa98d4b261af9a5bc025bd87f0e3547e11584be9e9427be2", size = 2623873, upload-time = "2024-10-18T15:57:51.822Z" }, + { url = "https://files.pythonhosted.org/packages/eb/b1/0ebff61a004f7f89e7b65ca95f2f2375679d43d0290672f7713ee3162aff/cryptography-43.0.3-cp37-abi3-win_amd64.whl", hash = "sha256:f46304d6f0c6ab8e52770addfa2fc41e6629495548862279641972b6215451cd", size = 3068039, upload-time = "2024-10-18T15:57:54.426Z" }, + { url = "https://files.pythonhosted.org/packages/30/d5/c8b32c047e2e81dd172138f772e81d852c51f0f2ad2ae8a24f1122e9e9a7/cryptography-43.0.3-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:8ac43ae87929a5982f5948ceda07001ee5e83227fd69cf55b109144938d96984", size = 6222984, upload-time = "2024-10-18T15:57:56.174Z" }, + { url = "https://files.pythonhosted.org/packages/2f/78/55356eb9075d0be6e81b59f45c7b48df87f76a20e73893872170471f3ee8/cryptography-43.0.3-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:846da004a5804145a5f441b8530b4bf35afbf7da70f82409f151695b127213d5", size = 3762968, upload-time = "2024-10-18T15:57:58.206Z" }, + { url = "https://files.pythonhosted.org/packages/2a/2c/488776a3dc843f95f86d2f957ca0fc3407d0242b50bede7fad1e339be03f/cryptography-43.0.3-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f996e7268af62598f2fc1204afa98a3b5712313a55c4c9d434aef49cadc91d4", size = 3977754, upload-time = "2024-10-18T15:58:00.683Z" }, + { url = "https://files.pythonhosted.org/packages/7c/04/2345ca92f7a22f601a9c62961741ef7dd0127c39f7310dffa0041c80f16f/cryptography-43.0.3-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:f7b178f11ed3664fd0e995a47ed2b5ff0a12d893e41dd0494f406d1cf555cab7", size = 3749458, upload-time = "2024-10-18T15:58:02.225Z" }, + { url = "https://files.pythonhosted.org/packages/ac/25/e715fa0bc24ac2114ed69da33adf451a38abb6f3f24ec207908112e9ba53/cryptography-43.0.3-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:c2e6fc39c4ab499049df3bdf567f768a723a5e8464816e8f009f121a5a9f4405", size = 3988220, upload-time = "2024-10-18T15:58:04.331Z" }, + { url = "https://files.pythonhosted.org/packages/21/ce/b9c9ff56c7164d8e2edfb6c9305045fbc0df4508ccfdb13ee66eb8c95b0e/cryptography-43.0.3-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:e1be4655c7ef6e1bbe6b5d0403526601323420bcf414598955968c9ef3eb7d16", size = 3853898, upload-time = "2024-10-18T15:58:06.113Z" }, + { url = "https://files.pythonhosted.org/packages/2a/33/b3682992ab2e9476b9c81fff22f02c8b0a1e6e1d49ee1750a67d85fd7ed2/cryptography-43.0.3-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:df6b6c6d742395dd77a23ea3728ab62f98379eff8fb61be2744d4679ab678f73", size = 4076592, upload-time = "2024-10-18T15:58:08.673Z" }, + { url = "https://files.pythonhosted.org/packages/81/1e/ffcc41b3cebd64ca90b28fd58141c5f68c83d48563c88333ab660e002cd3/cryptography-43.0.3-cp39-abi3-win32.whl", hash = "sha256:d56e96520b1020449bbace2b78b603442e7e378a9b3bd68de65c782db1507995", size = 2623145, upload-time = "2024-10-18T15:58:10.264Z" }, + { url = "https://files.pythonhosted.org/packages/87/5c/3dab83cc4aba1f4b0e733e3f0c3e7d4386440d660ba5b1e3ff995feb734d/cryptography-43.0.3-cp39-abi3-win_amd64.whl", hash = "sha256:0c580952eef9bf68c4747774cde7ec1d85a6e61de97281f2dba83c7d2c806362", size = 3068026, upload-time = "2024-10-18T15:58:11.916Z" }, ] [[package]] @@ -391,11 +367,13 @@ dependencies = [ { name = "alembic" }, { name = "asyncpg" }, { name = "boto3" }, + { name = "cryptography" }, { name = "fastapi" }, { name = "httpx" }, { name = "huggingface-hub" }, { name = "pydantic" }, { name = "pydantic-settings" }, + { name = "pyjwt", extra = ["crypto"] }, { name = "sqlalchemy", extra = ["asyncio"] }, { name = "structlog" }, { name = "tenacity" }, @@ -418,11 +396,13 @@ requires-dist = [ { name = "alembic", specifier = ">=1.13,<1.14" }, { name = "asyncpg", specifier = ">=0.29,<0.30" }, { name = "boto3", specifier = ">=1.35,<2.0" }, + { name = "cryptography", specifier = ">=43,<44" }, { name = "fastapi", specifier = ">=0.115,<0.116" }, { name = "httpx", specifier = ">=0.27,<0.28" }, { name = "huggingface-hub", specifier = ">=0.26,<1.0" }, { name = "pydantic", specifier = ">=2.9,<2.11" }, { name = "pydantic-settings", specifier = ">=2.6,<2.7" }, + { name = "pyjwt", extras = ["crypto"], specifier = ">=2.9,<3.0" }, { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0,<2.1" }, { name = "structlog", specifier = ">=24.4,<24.5" }, { name = "tenacity", specifier = ">=9.0,<10.0" }, @@ -1009,6 +989,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f4/7e/a72dd26f3b0f4f2bf1dd8923c85f7ceb43172af56d63c7383eb62b332364/pygments-2.20.0-py3-none-any.whl", hash = "sha256:81a9e26dd42fd28a23a2d169d86d7ac03b46e2f8b59ed4698fb4785f946d0176", size = 1231151, upload-time = "2026-03-29T13:29:30.038Z" }, ] +[[package]] +name = "pyjwt" +version = "2.12.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c2/27/a3b6e5bf6ff856d2509292e95c8f57f0df7017cf5394921fc4e4ef40308a/pyjwt-2.12.1.tar.gz", hash = "sha256:c74a7a2adf861c04d002db713dd85f84beb242228e671280bf709d765b03672b", size = 102564, upload-time = "2026-03-13T19:27:37.25Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/7a/8dd906bd22e79e47397a61742927f6747fe93242ef86645ee9092e610244/pyjwt-2.12.1-py3-none-any.whl", hash = "sha256:28ca37c070cad8ba8cd9790cd940535d40274d22f80ab87f3ac6a713e6e8454c", size = 29726, upload-time = "2026-03-13T19:27:35.677Z" }, +] + +[package.optional-dependencies] +crypto = [ + { name = "cryptography" }, +] + [[package]] name = "pytest" version = "8.4.2"