diff --git a/examples/stage/manifest.json b/examples/stage/manifest.json new file mode 100644 index 0000000..95902ac --- /dev/null +++ b/examples/stage/manifest.json @@ -0,0 +1,17 @@ +{ + "created_at": "2025-10-23T13:09:51.866481+00:00", + "envelope": { + "alg": "AES-256-GCM", + "filename": "mvpEX.stl", + "kdf": "HKDF-SHA256", + "kx": "X25519", + "payload_sha256": "409d11741321d4940018f304a2df598b4c175cb2f2933985ea99f91ff69f2078", + "sender_pub_ed25519": "2NWHb34FFmSUkc77VKhwk2rq9Nj8yhp95xaI9Uvl6GM=" + }, + "payload": { + "filename": "mvpEX.stl.psenc", + "sha256": "25f2eedf62aeebd95c525de870932e38c76e4c0909f636bebfd4f7338042949e" + }, + "policy_id": "nist-800-171", + "v": 1 +} \ No newline at end of file diff --git a/examples/stage/mvpEX.stl b/examples/stage/mvpEX.stl new file mode 100644 index 0000000..ee25e4b Binary files /dev/null and b/examples/stage/mvpEX.stl differ diff --git a/examples/stage/mvpEX.stl.psenc b/examples/stage/mvpEX.stl.psenc new file mode 100644 index 0000000..d9f4316 Binary files /dev/null and b/examples/stage/mvpEX.stl.psenc differ diff --git a/pyproject.toml b/pyproject.toml index ca01cb9..b103698 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ dependencies = [ "rich>=13.7,<15", "pydantic>=2.11,<2.12", "cryptography>=42,<46", - "paramiko>=3.4,<3.7", + "paramiko>=3.2", "requests>=2.31,<3.0", "watchdog>=4.0,<5.0", "docker>=6.1,<8.0", diff --git a/src/printshield/cli.py b/src/printshield/cli.py index bb34536..43ef7ac 100644 --- a/src/printshield/cli.py +++ b/src/printshield/cli.py @@ -24,6 +24,11 @@ build_provenance, sign_provenance, verify_provenance, default_manifest_path, save_manifest, ) +from .core.transfer.receive import receive_gate +from .core.transfer.sftp import sftp_upload_verify +from .core.transfer.https import https_upload_verify +from .core.transfer.octoprint import octoprint_upload + @dataclass class Context: @@ -65,6 +70,9 @@ def main_callback( prov_app = typer.Typer(help="Provenance sidecar tools for STL.") app.add_typer(prov_app, name="provenance") +transfer_app = typer.Typer(help="Secure file transfer to printers/servers.") +app.add_typer(transfer_app, name="transfer") + # --- Command stubs --- @app.command(help="Interactive bootstrap: keys, policy, audit path, endpoints, hardening.") @@ -271,13 +279,195 @@ def provenance_verify( typer.echo(f"Provenance FAIL — {res.reason}") raise typer.Exit(code=0 if res.ok else 1) -@app.command(help="Securely transfer to printer/server (SFTP/HTTPS/OctoPrint); integrity gate on receive.") -def transfer() -> None: - typer.echo("transfer: not implemented yet.") +# ============================ TRANSFER / RECIEVE / MONITOR ================================ +@transfer_app.command("sftp", help="Upload a bundle or envelope over SFTP with host-key check and SHA-256 verification.") +def transfer_sftp( + ctx: typer.Context, + input: Path = typer.Argument(..., exists=True, file_okay=True, dir_okay=False, readable=True), + host: str = typer.Option(..., "--host", "-H"), + user: str = typer.Option(..., "--user", "-u"), + dest_dir: str = typer.Option(..., "--dest-dir", "-d", help="Remote directory (POSIX)"), + dest_name: str = typer.Option(None, "--dest-name", help="Remote filename (defaults to source basename)"), + port: int = typer.Option(22, "--port"), + key_file: Path = typer.Option(None, "--key-file", help="Private key for auth (PEM)"), + password: str = typer.Option(None, "--password", help="Password (discouraged; use keys)"), + known_hosts: Path = typer.Option(None, "--known-hosts", help="known_hosts file to trust (merged with system)"), + hostkey_fingerprint: str = typer.Option(None, "--hostkey-fingerprint", help="Pin server key (OpenSSH SHA256:... form)"), + insecure_no_hostkey_check: bool = typer.Option(False, "--insecure-no-hostkey-check", help="Accept new/unknown host keys (NOT recommended)"), +): + c = cast(Context, ctx.obj) + loaded = load_config(c.config_path) + res = sftp_upload_verify( + input, + host=host, + username=user, + dest_dir=dest_dir, + dest_name=dest_name, + port=port, + key_filename=str(key_file) if key_file else None, + password=password, + known_hosts=str(known_hosts) if known_hosts else None, + hostkey_fingerprint=hostkey_fingerprint, + allow_unknown_hostkey=bool(insecure_no_hostkey_check), + ) + record_event(loaded.config.audit.path, "transfer.sftp", { + "input": str(input), + "host": host, + "remote_path": res.remote_path, + "size": res.size, + "sha256": res.local_sha256, + "hostkey_pinned": hostkey_fingerprint is not None, + "insecure_no_hostkey_check": bool(insecure_no_hostkey_check), + }) + if c.output_format == "json": + import json + typer.echo(json.dumps(res.__dict__, indent=2)) + else: + typer.echo(f"Remote: {res.remote_path}") + typer.echo(f"Size: {res.size} bytes") + typer.echo(f"SHA-256: {res.local_sha256}") + +@transfer_app.command("https", help="Upload a bundle/envelope over HTTPS with TLS + SHA-256 verification.") +def transfer_https( + ctx: typer.Context, + input: Path = typer.Argument(..., exists=True, file_okay=True, dir_okay=False, readable=True), + url: str = typer.Option(..., "--url", "-U", help="HTTPS endpoint to POST to."), + token: str = typer.Option(None, "--token", help="Bearer token for Authorization header."), + ca_cert: Path = typer.Option(None, "--ca-cert", help="Custom CA cert bundle (PEM)."), + insecure_no_verify_tls: bool = typer.Option(False, "--insecure-no-verify-tls", help="Disable TLS verification (NOT recommended)."), + timeout: float = typer.Option(30.0, "--timeout", help="Request timeout in seconds."), +): + c = cast(Context, ctx.obj) + loaded = load_config(c.config_path) + + res = https_upload_verify( + input, + url=url, + token=token, + ca_cert=str(ca_cert) if ca_cert is not None else None, + insecure_no_verify_tls=bool(insecure_no_verify_tls), + timeout=timeout, + ) + + record_event( + loaded.config.audit.path, + "transfer.https", + { + "input": str(input), + "url": res.url, + "sha256": res.sha256, + "stored_path": res.stored_path, + "insecure_no_verify_tls": bool(insecure_no_verify_tls), + "has_token": token is not None, + }, + ) + + if c.output_format == "json": + typer.echo(json.dumps(res.__dict__, indent=2)) + else: + typer.echo(f"Uploaded to: {res.url}") + typer.echo(f"SHA-256: {res.sha256}") + if res.stored_path: + typer.echo(f"Stored as: {res.stored_path}") + +@transfer_app.command("octoprint", help="Upload a bundle/envelope to OctoPrint via its REST API.") +def transfer_octoprint( + ctx: typer.Context, + input: Path = typer.Argument(..., exists=True, file_okay=True, dir_okay=False, readable=True), + base_url: str = typer.Option(..., "--url", "-U", help="Base URL of OctoPrint, e.g. https://octopi.local"), + api_key: str = typer.Option(..., "--api-key", "-K", help="OctoPrint API key (X-Api-Key)."), + location: str = typer.Option("local", "--location", "-L", help="Target location: local|sdcard"), + subpath: str = typer.Option(None, "--path", help="Remote subfolder within location."), + select: bool = typer.Option(False, "--select", help="Select file after upload."), + print_after: bool = typer.Option(False, "--print", help="Start printing after upload (implies select)."), + ca_cert: Path = typer.Option(None, "--ca-cert", help="Custom CA bundle (PEM)."), + insecure_no_verify_tls: bool = typer.Option(False, "--insecure-no-verify-tls", help="Disable TLS verification (NOT recommended)."), + timeout: float = typer.Option(30.0, "--timeout", help="Request timeout in seconds."), +): + c = cast(Context, ctx.obj) + loaded = load_config(c.config_path) + + res = octoprint_upload( + input, + base_url=base_url, + api_key=api_key, + location=location, + select=select or print_after, + print_after_select=print_after, + subpath=subpath, + ca_cert=str(ca_cert) if ca_cert is not None else None, + insecure_no_verify_tls=bool(insecure_no_verify_tls), + timeout=timeout, + ) + + record_event( + loaded.config.audit.path, + "transfer.octoprint", + { + "input": str(input), + "url": res.url, + "location": res.location, + "path": res.path, + "origin": res.origin, + "sha256": res.sha256, + "select": select or print_after, + "print": print_after, + "insecure_no_verify_tls": bool(insecure_no_verify_tls), + }, + ) + + if c.output_format == "json": + typer.echo(json.dumps(res.__dict__, indent=2)) + else: + typer.echo(f"Uploaded to OctoPrint {res.origin}:{res.path}") + typer.echo(f"SHA-256: {res.sha256}") + -@app.command(help="Server/printer-side staging: verify bundle, signatures, and policy before queueing.") -def receive() -> None: - typer.echo("receive: not implemented yet.") +@app.command(help="Server-side gate: verify policy, decrypt envelope, and stage plaintext for printing.") +def receive( + ctx: typer.Context, + input: Path = typer.Argument(..., exists=True, file_okay=True, dir_okay=False, readable=True, + help="Path to .pshieldpkg or .psenc"), + private_key: Path = typer.Option(..., "--private-key", "-k", help="Recipient X25519 private key (PEM)."), + expected_sender_pub: Path = typer.Option(None, "--expected-sender-pub", "-E", help="Require this ed25519 public key (PEM)."), + out_dir: Path = typer.Option(None, "--out-dir", "-d", help="Staging directory (created if missing). Defaults to .stage"), + require_policy: str = typer.Option(None, "--require-policy", help="If set, reject bundles whose manifest.policy_id != this"), +): + c = cast(Context, ctx.obj) + loaded = load_config(c.config_path) + + result = receive_gate( + input_path=input, + out_dir=out_dir, + recipient_private_key_pem=private_key, + expected_sender_pub_pem=expected_sender_pub, + require_policy_id=require_policy, + ) + + record_event( + loaded.config.audit.path, "receive", + { + "input": str(input), + "staged_dir": str(result.staged_dir), + "envelope": str(result.envelope_path), + "decrypted": str(result.decrypted_path), + "manifest": str(result.manifest_path) if result.manifest_path else None, + }, + ) + + if c.output_format == "json": + typer.echo(json.dumps({ + "staged_dir": str(result.staged_dir), + "envelope": str(result.envelope_path), + "decrypted": str(result.decrypted_path), + "manifest": str(result.manifest_path) if result.manifest_path else None, + }, indent=2)) + else: + typer.echo(f"Staged dir: {result.staged_dir}") + if result.manifest_path: + typer.echo(f"Manifest: {result.manifest_path}") + typer.echo(f"Envelope: {result.envelope_path}") + typer.echo(f"Decrypted: {result.decrypted_path}") @app.command(help="Watch directories/printers for changes and integrity drift; emit audit events.") def monitor() -> None: diff --git a/src/printshield/core/transfer/https.py b/src/printshield/core/transfer/https.py index e69de29..d73de2a 100644 --- a/src/printshield/core/transfer/https.py +++ b/src/printshield/core/transfer/https.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +import hashlib +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Dict, Any + +@dataclass +class HttpUploadResult: + url: str + sha256: str + stored_path: Optional[str] + +def _sha256_file(path: Path) -> str: + h = hashlib.sha256() + with path.open("rb") as f: + for chunk in iter(lambda: f.read(1024 * 1024), b""): + h.update(chunk) + return h.hexdigest() + +def https_upload_verify( + local_path: str | Path, + *, + url: str, + token: Optional[str] = None, + ca_cert: Optional[str | Path] = None, + insecure_no_verify_tls: bool = False, + timeout: float = 30.0, + extra_meta: Optional[Dict[str, Any]] = None, + session: Any = None, # for tests; if None, create a new requests.Session() +) -> HttpUploadResult: + """ + Upload local_path to HTTPS endpoint via POST multipart/form-data, + include SHA-256 in header/body, and require server to echo its computed SHA-256. + Integrity passes only if server's sha256 == local sha256. + """ + lp = Path(local_path) + local_hash = _sha256_file(lp) + + import requests # type: ignore + + verify: bool | str + if insecure_no_verify_tls: + verify = False + elif ca_cert is not None: + verify = str(ca_cert) + else: + verify = True # system CA bundle + + headers: Dict[str, str] = { + "X-PrintShield-SHA256": local_hash, + } + if token: + headers["Authorization"] = f"Bearer {token}" + + meta = dict(extra_meta or {}) + meta["sha256"] = local_hash + + sess = session or requests.Session() + with lp.open("rb") as f: + files = {"file": (lp.name, f)} + resp = sess.post( + url, + headers=headers, + files=files, + data=meta, + timeout=timeout, + verify=verify, + ) + + if not (200 <= resp.status_code < 300): + raise ValueError(f"HTTP upload failed: {resp.status_code}") + + # Expect JSON body with 'sha256' and optional 'stored_path' + try: + data = resp.json() + except Exception as e: + raise ValueError(f"Server response is not valid JSON: {e}") from e + + server_hash = data.get("sha256") + if server_hash != local_hash: + raise ValueError("Integrity check failed: server sha256 != local sha256") + + stored = data.get("stored_path") + return HttpUploadResult(url=url, sha256=local_hash, stored_path=stored) \ No newline at end of file diff --git a/src/printshield/core/transfer/octoprint.py b/src/printshield/core/transfer/octoprint.py index e69de29..99acb7e 100644 --- a/src/printshield/core/transfer/octoprint.py +++ b/src/printshield/core/transfer/octoprint.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +import hashlib +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Optional + + +@dataclass +class OctoPrintUploadResult: + url: str # full API URL used + location: str # 'local' or 'sdcard' + path: str # remote path within location + name: str # remote file name + origin: str # origin reported by OctoPrint ('local'/'sdcard') + sha256: str # local file SHA-256 + + +def _sha256_file(path: Path) -> str: + h = hashlib.sha256() + with path.open("rb") as f: + for chunk in iter(lambda: f.read(1024 * 1024), b""): + h.update(chunk) + return h.hexdigest() + + +def octoprint_upload( + local_path: str | Path, + *, + base_url: str, + api_key: str, + location: str = "local", # 'local' or 'sdcard' + select: bool = False, + print_after_select: bool = False, + subpath: Optional[str] = None, # remote folder within location + ca_cert: Optional[str | Path] = None, + insecure_no_verify_tls: bool = False, + timeout: float = 30.0, + session: Any = None, # requests.Session-like; used for tests +) -> OctoPrintUploadResult: + """ + Upload local_path to OctoPrint via POST /api/files/{location}. + + - Uses X-Api-Key header for auth. + - Attaches local SHA-256 as metadata in 'userdata' JSON. + - Optionally selects/prints the file. + """ + lp = Path(local_path) + local_hash = _sha256_file(lp) + + import requests # type: ignore + + url = base_url.rstrip("/") + f"/api/files/{location}" + + # TLS verification settings + verify: bool | str + if insecure_no_verify_tls: + verify = False + elif ca_cert is not None: + verify = str(ca_cert) + else: + verify = True + + headers: Dict[str, str] = { + "X-Api-Key": api_key, + } + + form: Dict[str, str] = {} + + # OctoPrint expects "true"/"false" (strings) for these flags + if select or print_after_select: + form["select"] = "true" + else: + form["select"] = "false" + form["print"] = "true" if print_after_select else "false" + + if subpath: + form["path"] = subpath + + # Store PrintShield metadata as userdata JSON + userdata = {"sha256": local_hash} + form["userdata"] = json.dumps(userdata) + + sess = session or requests.Session() + with lp.open("rb") as f: + files = {"file": (lp.name, f)} + resp = sess.post( + url, + headers=headers, + files=files, + data=form, + timeout=timeout, + verify=verify, + ) + + if not (200 <= resp.status_code < 300): + raise ValueError(f"OctoPrint upload failed: HTTP {resp.status_code}") + + try: + body = resp.json() + except Exception as e: + raise ValueError(f"OctoPrint response is not valid JSON: {e}") from e + + # Upload Response contains a 'files' dict keyed by origin ('local' or 'sdcard') + files_info = body.get("files", {}) + file_info = None + if isinstance(files_info, dict): + file_info = files_info.get(location) or (next(iter(files_info.values()), None) if files_info else None) + + if file_info is None and "file" in body: + file_info = body["file"] + + if not isinstance(file_info, dict): + # fall back to filename only + remote_path = lp.name + name = lp.name + origin = location + else: + remote_path = file_info.get("path") or lp.name + name = file_info.get("name") or lp.name + origin = file_info.get("origin") or location + + return OctoPrintUploadResult( + url=url, + location=location, + path=remote_path, + name=name, + origin=origin, + sha256=local_hash, + ) \ No newline at end of file diff --git a/src/printshield/core/transfer/receive.py b/src/printshield/core/transfer/receive.py new file mode 100644 index 0000000..17ea994 --- /dev/null +++ b/src/printshield/core/transfer/receive.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Tuple + +from .bundle import extract_bundle +from ..crypto.encrypt import read_envelope_header, decrypt_file + +@dataclass +class ReceiveResult: + staged_dir: Path + is_bundle: bool + envelope_path: Path + decrypted_path: Path + manifest_path: Path | None + +def _is_bundle(path: Path) -> bool: + return path.suffix.endswith(".pshieldpkg") or path.name.endswith(".pshieldpkg") + +def receive_gate( + input_path: str | Path, + out_dir: str | Path | None, + *, + recipient_private_key_pem: str | Path, + expected_sender_pub_pem: str | Path | None = None, + require_policy_id: str | None = None, +) -> ReceiveResult: + """ + Accept a .pshieldpkg (bundle) or a .psenc (envelope), verify, decrypt, and stage the plaintext. + Returns paths of staged dir, envelope, decrypted file, and manifest (if any). + """ + src = Path(input_path) + stage = Path(out_dir) if out_dir else src.with_suffix(src.suffix + ".stage") + stage.mkdir(parents=True, exist_ok=True) + + manifest_path: Path | None = None + envelope_path: Path + + if _is_bundle(src): + # Verify and extract bundle into staging directory (checksum + optional policy) + # extract_bundle() writes manifest.json and returns the extracted payload path (envelope) + envelope_path = extract_bundle(src, stage, require_policy_id=require_policy_id) + manifest_path = stage / "manifest.json" + else: + # Treat Input as direct envelope + envelope_path = src + + # Determine original plaintext filename from header for a predictable output name + hdr = read_envelope_header(envelope_path) + target_name = hdr.get("filename") or Path(envelope_path).stem + decrypted_path = stage / target_name + + # Decrypt with sender pinning if provided; decrypt_file validates AEAD + payload hash (+ sender sig if present) + decrypted_path = decrypt_file( + envelope_path, + recipient_private_key_pem=recipient_private_key_pem, + out_path=decrypted_path, + expected_sender_pub_pem=expected_sender_pub_pem, + ) + + return ReceiveResult( + staged_dir=stage, + is_bundle=_is_bundle(src), + envelope_path=envelope_path, + decrypted_path=decrypted_path, + manifest_path=manifest_path, + ) \ No newline at end of file diff --git a/src/printshield/core/transfer/sftp.py b/src/printshield/core/transfer/sftp.py index e69de29..6be042d 100644 --- a/src/printshield/core/transfer/sftp.py +++ b/src/printshield/core/transfer/sftp.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +import hashlib +import base64 +import paramiko +from dataclasses import dataclass +from pathlib import Path, PurePosixPath +from typing import Optional, Callable, Any + +@dataclass +class SftpResult: + remote_path: str + local_sha256: str + remote_sha256: str + size: int + +def _sha256_file(p: Path) -> str: + h = hashlib.sha256() + with p.open("rb") as f: + for chunk in iter(lambda: f.read(1024 * 1024), b""): + h.update(chunk) + return h.hexdigest() + +def _sha256_stream(read_iter: Callable[[], bytes]) -> str: + h = hashlib.sha256() + for chunk in iter(read_iter, b""): + h.update(chunk) + return h.hexdigest() + +def _ensure_remote_dirs(sftp: Any, dest_dir: PurePosixPath) -> None: + # Recursively mkdir -p on remote + parts = [] + for part in dest_dir.parts: + parts.append(part) + cur = PurePosixPath(*parts) + try: + sftp.stat(str(cur)) + except Exception: + sftp.mkdir(str(cur)) + +def _fingerprint_sha256_str(pkey: Any) -> str: + """ + Return OpenSSH-style SHA256 fingerprint string. + Paramiko 3.2+ exposes `pkey.fingerprint` already; fall back if missing. + """ + fp = getattr(pkey, "fingerprint", None) + if isinstance(fp, str): + return fp # I'm expecting something like SHA256:Base64 (hopefully) + # Incase something fucks up -> compute SHA256 over raw public blob and encode like OpenSSH + raw = pkey.asbytes() + return "SHA256:" + base64.b64encode(hashlib.sha256(raw).digest()).decode("ascii").rstrip("=") + +def sftp_upload_verify( + local_path: str | Path, + *, + host: str, + username: str, + dest_dir: str, + dest_name: Optional[str] = None, + port: int = 22, + key_filename: Optional[str | Path] = None, + password: Optional[str] = None, + known_hosts: Optional[str | Path] = None, + hostkey_fingerprint: Optional[str] = None, # e.g. 'SHA256:xxxx' + allow_unknown_hostkey: bool = False, # default SAFE: False + ssh_client_factory: Optional[Callable[[], Any]] = None, # for tests/mocking +) -> SftpResult: + """ + Upload local_path to sftp://username@host:port/dest_dir/dest_name, + then read back the remote file to verify SHA-256 equals local. + """ + lp = Path(local_path) + if dest_name is None: + dest_name = lp.name + remote_dir = PurePosixPath(dest_dir) + remote_path = str(remote_dir / dest_name) + + # Add a lazy import later + + ssh = (ssh_client_factory or paramiko.SSHClient)() + # Host key policy and Known Hosts + ssh.load_system_host_keys() + if known_hosts: + ssh.load_host_keys(str(known_hosts)) + policy = paramiko.AutoAddPolicy() if allow_unknown_hostkey else paramiko.RejectPolicy() + ssh.set_missing_host_key_policy(policy) + + ssh.connect( + hostname=host, + port=port, + username=username, + key_filename=str(key_filename) if key_filename else None, + password=password, + look_for_keys=False if key_filename or password else True, + allow_agent=True, + ) + + # Optional explicit fingerprint pinning (post-connect) + if hostkey_fingerprint: + pkey = ssh.get_transport().get_remote_server_key() + actual = _fingerprint_sha256_str(pkey) + if actual != hostkey_fingerprint: + try: + ssh.close() + finally: + raise ValueError(f"Host key fingerprint mismatch: expected {hostkey_fingerprint}, got {actual}") + + sftp = ssh.open_sftp() + + try: + _ensure_remote_dirs(sftp, remote_dir) + # Upload + sftp.put(str(lp), remote_path) + # Compute local hash + local_hash = _sha256_file(lp) + # Read back remote to verify + with sftp.open(remote_path, "rb") as rf: + remote_hash = _sha256_stream(lambda: rf.read(1024 * 1024)) + size = sftp.stat(remote_path).st_size + finally: + try: + sftp.close() + finally: + ssh.close() + + if local_hash != remote_hash: + raise ValueError("Integrity check failed: remote SHA-256 != local SHA-256") + + return SftpResult(remote_path=remote_path, local_sha256=local_hash, remote_sha256=remote_hash, size=size) \ No newline at end of file diff --git a/tests/unit/test_receive_bundle_gate.py b/tests/unit/test_receive_bundle_gate.py new file mode 100644 index 0000000..ca00210 --- /dev/null +++ b/tests/unit/test_receive_bundle_gate.py @@ -0,0 +1,68 @@ +from __future__ import annotations +from pathlib import Path + +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey +from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey +from typer.testing import CliRunner +from printshield.cli import app + +runner = CliRunner() + +def _gen_rx_and_sender(tmp: Path): + rx_sk = X25519PrivateKey.generate() + rx_pk = rx_sk.public_key() + sx_sk = Ed25519PrivateKey.generate() + sx_pk = sx_sk.public_key() + + rx_sk_p = tmp / "rx_sk.pem"; rx_pk_p = tmp / "rx_pk.pem" + sx_sk_p = tmp / "sx_sk.pem"; sx_pk_p = tmp / "sx_pk.pem" + + rx_sk_p.write_bytes(rx_sk.private_bytes( + serialization.Encoding.PEM, serialization.PrivateFormat.PKCS8, serialization.NoEncryption())) + rx_pk_p.write_bytes(rx_pk.public_bytes( + serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo)) + sx_sk_p.write_bytes(sx_sk.private_bytes( + serialization.Encoding.PEM, serialization.PrivateFormat.PKCS8, serialization.NoEncryption())) + sx_pk_p.write_bytes(sx_pk.public_bytes( + serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo)) + return rx_sk_p, rx_pk_p, sx_sk_p, sx_pk_p + +def test_receive_bundle_happy_path(tmp_path: Path): + rx_sk, rx_pk, sx_sk, sx_pk = _gen_rx_and_sender(tmp_path) + cfg = tmp_path / "ps.yaml" + cfg.write_text(f"policy_id: thesis\naudit:\n path: {tmp_path.as_posix()}/audit\n", encoding="utf-8") + + # plaintext + f = tmp_path / "part.stl" + f.write_text("solid p\nendsolid p\n", encoding="utf-8") + + # encrypt+bundle + enc = tmp_path / "part.stl.psenc" + r1 = runner.invoke(app, [ + "encrypt", str(f), + "--recipient-key", str(rx_pk), + "--sender-key", str(sx_sk), + "--output", str(enc), + ]) + assert r1.exit_code == 0, r1.output + + bundle = tmp_path / "part.stl.psenc.pshieldpkg" + r2 = runner.invoke(app, ["--config", str(cfg), "bundle", "create", str(enc), "--output", str(bundle)]) + assert r2.exit_code == 0, r2.output + + # receive gate + out_dir = tmp_path / "stage" + r3 = runner.invoke(app, [ + "--config", str(cfg), + "receive", str(bundle), + "--private-key", str(rx_sk), + "--expected-sender-pub", str(sx_pk), + "--out-dir", str(out_dir), + "--require-policy", "thesis", + ]) + assert r3.exit_code == 0, r3.output + # plaintext should be staged and match original + dec = out_dir / "part.stl" + assert dec.exists() + assert dec.read_bytes() == f.read_bytes() diff --git a/tests/unit/test_receive_envelope_gate.py b/tests/unit/test_receive_envelope_gate.py new file mode 100644 index 0000000..430750e --- /dev/null +++ b/tests/unit/test_receive_envelope_gate.py @@ -0,0 +1,58 @@ +from __future__ import annotations +from pathlib import Path + +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey +from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey +from typer.testing import CliRunner +from printshield.cli import app + +runner = CliRunner() + +def _gen_rx_and_sender(tmp: Path): + rx_sk = X25519PrivateKey.generate() + rx_pk = rx_sk.public_key() + sx_sk = Ed25519PrivateKey.generate() + sx_pk = sx_sk.public_key() + + rx_sk_p = tmp / "rx_sk.pem"; rx_pk_p = tmp / "rx_pk.pem" + sx_sk_p = tmp / "sx_sk.pem"; sx_pk_p = tmp / "sx_pk.pem" + + rx_sk_p.write_bytes(rx_sk.private_bytes( + serialization.Encoding.PEM, serialization.PrivateFormat.PKCS8, serialization.NoEncryption())) + rx_pk_p.write_bytes(rx_pk.public_bytes( + serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo)) + sx_sk_p.write_bytes(sx_sk.private_bytes( + serialization.Encoding.PEM, serialization.PrivateFormat.PKCS8, serialization.NoEncryption())) + sx_pk_p.write_bytes(sx_pk.public_bytes( + serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo)) + return rx_sk_p, rx_pk_p, sx_sk_p, sx_pk_p + +def test_receive_envelope_direct(tmp_path: Path): + rx_sk, rx_pk, sx_sk, sx_pk = _gen_rx_and_sender(tmp_path) + cfg = tmp_path / "ps.yaml" + cfg.write_text(f"policy_id: thesis\naudit:\n path: {tmp_path.as_posix()}/audit\n", encoding="utf-8") + + f = tmp_path / "p.stl" + f.write_text("solid p\nendsolid p\n", encoding="utf-8") + + enc = tmp_path / "p.stl.psenc" + r1 = runner.invoke(app, [ + "encrypt", str(f), + "--recipient-key", str(rx_pk), + "--sender-key", str(sx_sk), + "--output", str(enc), + ]) + assert r1.exit_code == 0, r1.output + + out_dir = tmp_path / "stage" + r2 = runner.invoke(app, [ + "--config", str(cfg), + "receive", str(enc), + "--private-key", str(rx_sk), + "--expected-sender-pub", str(sx_pk), + "--out-dir", str(out_dir), + ]) + assert r2.exit_code == 0, r2.output + dec = out_dir / "p.stl" + assert dec.read_bytes() == f.read_bytes() diff --git a/tests/unit/test_receive_policy_and_sender_pinning.py b/tests/unit/test_receive_policy_and_sender_pinning.py new file mode 100644 index 0000000..0c6dd12 --- /dev/null +++ b/tests/unit/test_receive_policy_and_sender_pinning.py @@ -0,0 +1,91 @@ +from __future__ import annotations +from pathlib import Path + +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey +from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey +from typer.testing import CliRunner +from printshield.cli import app + +runner = CliRunner() + +def _gen_rx() -> tuple[X25519PrivateKey, X25519PrivateKey]: + return X25519PrivateKey.generate(), X25519PrivateKey.generate() + +def _emit_pems_x25519(tmp: Path, sk: X25519PrivateKey): + sk_p = tmp / "sk.pem"; pk_p = tmp / "pk.pem" + sk_p.write_bytes(sk.private_bytes( + serialization.Encoding.PEM, serialization.PrivateFormat.PKCS8, serialization.NoEncryption())) + pk_p.write_bytes(sk.public_key().public_bytes( + serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo)) + return sk_p, pk_p + +# My idiot ass was overwriting the file, causing the test to fail :) +# def _emit_pems_ed25519(tmp: Path, sk: Ed25519PrivateKey): +# sk_p = tmp / "esk.pem"; pk_p = tmp / "epk.pem" +# sk_p.write_bytes(sk.private_bytes( +# serialization.Encoding.PEM, serialization.PrivateFormat.PKCS8, serialization.NoEncryption())) +# pk_p.write_bytes(sk.public_key().public_bytes( +# serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo)) +# return sk_p, pk_p + +def _emit_pems_ed25519(tmp: Path, sk: Ed25519PrivateKey, name: str): + sk_p = tmp / f"{name}_sk.pem" + pk_p = tmp / f"{name}_pk.pem" + sk_p.write_bytes(sk.private_bytes( + serialization.Encoding.PEM, serialization.PrivateFormat.PKCS8, serialization.NoEncryption())) + pk_p.write_bytes(sk.public_key().public_bytes( + serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo)) + return sk_p, pk_p + +def test_receive_rejects_policy_mismatch(tmp_path: Path): + rx_sk = X25519PrivateKey.generate() + sx_sk = Ed25519PrivateKey.generate() + rx_sk_p, rx_pk_p = _emit_pems_x25519(tmp_path, rx_sk) + sx_sk_p, sx_pk_p = _emit_pems_ed25519(tmp_path, sx_sk, "sk1") + + cfg = tmp_path / "ps.yaml" + cfg.write_text(f"policy_id: correct\naudit:\n path: {tmp_path.as_posix()}/audit\n", encoding="utf-8") + + f = tmp_path / "m.stl"; f.write_text("solid m\nendsolid m\n", encoding="utf-8") + enc = tmp_path / "m.stl.psenc" + runner.invoke(app, ["encrypt", str(f), "--recipient-key", str(rx_pk_p), "--sender-key", str(sx_sk_p), "--output", str(enc)]) + bundle = tmp_path / "m.stl.psenc.pshieldpkg" + runner.invoke(app, ["--config", str(cfg), "bundle", "create", str(enc), "--output", str(bundle)]) + + out_dir = tmp_path / "stage" + r = runner.invoke(app, [ + "--config", str(cfg), + "receive", str(bundle), + "--private-key", str(rx_sk_p), + "--expected-sender-pub", str(sx_pk_p), + "--out-dir", str(out_dir), + "--require-policy", "mismatch", + ]) + assert r.exit_code != 0 + +def test_receive_rejects_wrong_sender(tmp_path: Path): + rx_sk = X25519PrivateKey.generate() + sx1_sk = Ed25519PrivateKey.generate() + sx2_sk = Ed25519PrivateKey.generate() + + rx_sk_p, rx_pk_p = _emit_pems_x25519(tmp_path, rx_sk) + sx1_sk_p, sx1_pk_p = _emit_pems_ed25519(tmp_path, sx1_sk, "sx1") + sx2_sk_p, sx2_pk_p = _emit_pems_ed25519(tmp_path, sx2_sk, "sx2") + + cfg = tmp_path / "ps.yaml" + cfg.write_text(f"policy_id: thesis\naudit:\n path: {tmp_path.as_posix()}/audit\n", encoding="utf-8") + + f = tmp_path / "w.stl"; f.write_text("solid w\nendsolid w\n", encoding="utf-8") + enc = tmp_path / "w.stl.psenc" + runner.invoke(app, ["encrypt", str(f), "--recipient-key", str(rx_pk_p), "--sender-key", str(sx1_sk_p), "--output", str(enc)]) + out_dir = tmp_path / "stage" + + r = runner.invoke(app, [ + "--config", str(cfg), + "receive", str(enc), + "--private-key", str(rx_sk_p), + "--expected-sender-pub", str(sx2_pk_p), # wrong sender + "--out-dir", str(out_dir), + ]) + assert r.exit_code != 0 diff --git a/tests/unit/test_transfer_https_non_2xx.py b/tests/unit/test_transfer_https_non_2xx.py new file mode 100644 index 0000000..08a7eaa --- /dev/null +++ b/tests/unit/test_transfer_https_non_2xx.py @@ -0,0 +1,23 @@ +from __future__ import annotations +from pathlib import Path +import pytest +from printshield.core.transfer.https import https_upload_verify + +class _FakeRespFail: + def __init__(self): self.status_code = 500 + def json(self): return {"error": "boom"} + +class _FakeSessionFail: + def post(self, url, headers=None, files=None, data=None, timeout=None, verify=None): + return _FakeRespFail() + +def test_https_upload_verify_non_2xx(tmp_path: Path): + f = tmp_path / "job.pshieldpkg" + f.write_bytes(b"hi") + sess = _FakeSessionFail() + with pytest.raises(ValueError, match="HTTP upload failed"): + https_upload_verify( + f, + url="https://printshield.example/upload", + session=sess, + ) diff --git a/tests/unit/test_transfer_https_ok.py b/tests/unit/test_transfer_https_ok.py new file mode 100644 index 0000000..1f1024c --- /dev/null +++ b/tests/unit/test_transfer_https_ok.py @@ -0,0 +1,42 @@ +from __future__ import annotations +from pathlib import Path + +from printshield.core.transfer.https import https_upload_verify + +class _FakeRespOK: + def __init__(self, status_code: int, data: dict): + self.status_code = status_code + self._data = data + def json(self): + return self._data + +class _FakeSessionOK: + def __init__(self, expected_sha: str): + self.expected_sha = expected_sha + self.last_headers = None + self.last_url = None + def post(self, url, headers=None, files=None, data=None, timeout=None, verify=None): + self.last_url = url + self.last_headers = headers + # mimic server: echo sha256 it saw in body + sha = data.get("sha256") + assert sha == self.expected_sha + return _FakeRespOK(200, {"sha256": sha, "stored_path": "/spool/jobs/foo.pshieldpkg"}) + +def test_https_upload_verify_ok(tmp_path: Path): + f = tmp_path / "job.pshieldpkg" + f.write_bytes(b"hello-world") + # compute hash to set expectation + import hashlib + h = hashlib.sha256(b"hello-world").hexdigest() + sess = _FakeSessionOK(expected_sha=h) + + res = https_upload_verify( + f, + url="https://printshield.example/upload", + session=sess, + ) + assert res.url == "https://printshield.example/upload" + assert res.sha256 == h + assert res.stored_path == "/spool/jobs/foo.pshieldpkg" + assert sess.last_headers["X-PrintShield-SHA256"] == h diff --git a/tests/unit/test_transfer_https_sha_mismatch.py b/tests/unit/test_transfer_https_sha_mismatch.py new file mode 100644 index 0000000..c3e8dee --- /dev/null +++ b/tests/unit/test_transfer_https_sha_mismatch.py @@ -0,0 +1,25 @@ +from __future__ import annotations +from pathlib import Path +import pytest + +from printshield.core.transfer.https import https_upload_verify + +class _FakeRespBadHash: + def __init__(self): self.status_code = 200 + def json(self): + return {"sha256": "not-the-same", "stored_path": "/spool/bad"} + +class _FakeSessionBadHash: + def post(self, url, headers=None, files=None, data=None, timeout=None, verify=None): + return _FakeRespBadHash() + +def test_https_upload_verify_fails_on_mismatched_hash(tmp_path: Path): + f = tmp_path / "job.psenc" + f.write_bytes(b"payload") + sess = _FakeSessionBadHash() + with pytest.raises(ValueError, match="Integrity check failed"): + https_upload_verify( + f, + url="https://printshield.example/upload", + session=sess, + ) diff --git a/tests/unit/test_transfer_octoprint_bad_json.py b/tests/unit/test_transfer_octoprint_bad_json.py new file mode 100644 index 0000000..f31809f --- /dev/null +++ b/tests/unit/test_transfer_octoprint_bad_json.py @@ -0,0 +1,30 @@ +from __future__ import annotations +from pathlib import Path +import pytest + +from printshield.core.transfer.octoprint import octoprint_upload + + +class _FakeRespBadJSON: + def __init__(self): + self.status_code = 201 + def json(self): + raise ValueError("not json") + + +class _FakeSessionBadJSON: + def post(self, url, headers=None, files=None, data=None, timeout=None, verify=None): + return _FakeRespBadJSON() + + +def test_octoprint_upload_rejects_invalid_json(tmp_path: Path): + f = tmp_path / "job.pshieldpkg" + f.write_bytes(b"x") + sess = _FakeSessionBadJSON() + with pytest.raises(ValueError, match="OctoPrint response is not valid JSON"): + octoprint_upload( + f, + base_url="https://octopi.local", + api_key="APIKEY", + session=sess, + ) diff --git a/tests/unit/test_transfer_octoprint_non_2xx.py b/tests/unit/test_transfer_octoprint_non_2xx.py new file mode 100644 index 0000000..8172e3c --- /dev/null +++ b/tests/unit/test_transfer_octoprint_non_2xx.py @@ -0,0 +1,30 @@ +from __future__ import annotations +from pathlib import Path +import pytest + +from printshield.core.transfer.octoprint import octoprint_upload + + +class _FakeRespFail: + def __init__(self, status: int): + self.status_code = status + def json(self): + return {"error": "nope"} + + +class _FakeSessionFail: + def post(self, url, headers=None, files=None, data=None, timeout=None, verify=None): + return _FakeRespFail(500) + + +def test_octoprint_upload_rejects_non_2xx(tmp_path: Path): + f = tmp_path / "job.pshieldpkg" + f.write_bytes(b"x") + sess = _FakeSessionFail() + with pytest.raises(ValueError, match="OctoPrint upload failed"): + octoprint_upload( + f, + base_url="https://octopi.local", + api_key="APIKEY", + session=sess, + ) diff --git a/tests/unit/test_transfer_octoprint_ok.py b/tests/unit/test_transfer_octoprint_ok.py new file mode 100644 index 0000000..ee18d86 --- /dev/null +++ b/tests/unit/test_transfer_octoprint_ok.py @@ -0,0 +1,84 @@ +from __future__ import annotations +from pathlib import Path +import json +import hashlib + +from printshield.core.transfer.octoprint import octoprint_upload + + +class _FakeRespOK: + def __init__(self, status: int, body: dict): + self.status_code = status + self._body = body + + def json(self): + return self._body + + +class _FakeSessionOK: + def __init__(self): + self.last_url = None + self.last_headers = None + self.last_files = None + self.last_data = None + + def post(self, url, headers=None, files=None, data=None, timeout=None, verify=None): + self.last_url = url + self.last_headers = headers + self.last_files = files + self.last_data = data + # Simulate OctoPrint Upload Response + body = { + "files": { + "local": { + "path": "printshield/job.pshieldpkg", + "name": "job.pshieldpkg", + "origin": "local", + } + } + } + return _FakeRespOK(201, body) + + +def test_octoprint_upload_ok(tmp_path: Path): + f = tmp_path / "job.pshieldpkg" + f.write_bytes(b"hello-octoprint") + + # Compute expected hash + h = hashlib.sha256(b"hello-octoprint").hexdigest() + + sess = _FakeSessionOK() + res = octoprint_upload( + f, + base_url="https://octopi.local", + api_key="APIKEY123", + location="local", + select=True, + print_after_select=False, + subpath="printshield", + session=sess, + ) + + # Check result + assert res.url == "https://octopi.local/api/files/local" + assert res.location == "local" + assert res.path == "printshield/job.pshieldpkg" + assert res.name == "job.pshieldpkg" + assert res.origin == "local" + assert res.sha256 == h + + # Check request details + assert sess.last_url == "https://octopi.local/api/files/local" + assert sess.last_headers["X-Api-Key"] == "APIKEY123" + assert "file" in sess.last_files + filename, _fileobj = sess.last_files["file"] + assert filename == "job.pshieldpkg" + + # form data + data = sess.last_data + assert data["select"] == "true" + assert data["print"] == "false" + assert data["path"] == "printshield" + # userdata contains our sha256 + userdata = json.loads(data["userdata"]) + assert userdata["sha256"] == h diff --git a/tests/unit/test_transfer_sftp_fingerprint_mismatch.py b/tests/unit/test_transfer_sftp_fingerprint_mismatch.py new file mode 100644 index 0000000..4c52375 --- /dev/null +++ b/tests/unit/test_transfer_sftp_fingerprint_mismatch.py @@ -0,0 +1,38 @@ +from __future__ import annotations +from pathlib import Path +import pytest +from printshield.core.transfer.sftp import sftp_upload_verify + +class _FakeSSH_badFP: + def __init__(self): pass + def load_system_host_keys(self): pass + def load_host_keys(self, fn): pass + def set_missing_host_key_policy(self, p): pass + def connect(self, **kw): pass + class _T: + class _K: + def asbytes(self): return b"not-the-same" + def get_remote_server_key(self): return self._K() + def get_transport(self): return self._T() + class _S: + def stat(self, p): + class _SS: st_size = 6 + return _SS() + def mkdir(self, p): pass + def put(self, local, remote): pass + def open(self, p, mode): + from io import BytesIO + return BytesIO(b"abc123") + def close(self): pass + def open_sftp(self): return self._S() + def close(self): pass + +def test_transfer_sftp_fails_on_pinned_fingerprint_mismatch(tmp_path: Path): + f = tmp_path / "file.psenc" + f.write_text("abc123", encoding="utf-8") + with pytest.raises(ValueError): + sftp_upload_verify( + f, host="h", username="u", dest_dir="/x", + hostkey_fingerprint="SHA256:AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + ssh_client_factory=_FakeSSH_badFP, + ) diff --git a/tests/unit/test_transfer_sftp_mkdirs.py b/tests/unit/test_transfer_sftp_mkdirs.py new file mode 100644 index 0000000..0dd0524 --- /dev/null +++ b/tests/unit/test_transfer_sftp_mkdirs.py @@ -0,0 +1,37 @@ +from __future__ import annotations +from pathlib import Path +from printshield.core.transfer.sftp import sftp_upload_verify + +class _FakeSSH_dirs: + def __init__(self): + self.store = {} + def load_system_host_keys(self): pass + def load_host_keys(self, fn): pass + def set_missing_host_key_policy(self, p): pass + def connect(self, **kw): pass + class _T: + class _K: + def asbytes(self): return b"blob" + def get_remote_server_key(self): return self._K() + def get_transport(self): return self._T() + class _S: + def __init__(self, store): self.store = store + def stat(self, p): + if p not in self.store: raise IOError("nope") + class _SS: st_size = len(self.store[p]) + return _SS() + def mkdir(self, p): self.store[p] = b"" + def put(self, local, remote): self.store[remote] = Path(local).read_bytes() + def open(self, p, mode): + from io import BytesIO + return BytesIO(self.store[p]) + def close(self): pass + def open_sftp(self): return self._S(self.store) + def close(self): pass + +def test_transfer_sftp_creates_remote_dirs(tmp_path: Path): + f = tmp_path / "x.psenc"; f.write_bytes(b"data") + res = sftp_upload_verify( + f, host="h", username="u", dest_dir="/a/b/c", ssh_client_factory=_FakeSSH_dirs + ) + assert res.remote_path == "/a/b/c/x.psenc" diff --git a/tests/unit/test_transfer_sftp_ok.py b/tests/unit/test_transfer_sftp_ok.py new file mode 100644 index 0000000..4ab5dd3 --- /dev/null +++ b/tests/unit/test_transfer_sftp_ok.py @@ -0,0 +1,55 @@ +from __future__ import annotations +from pathlib import Path +from typer.testing import CliRunner +from printshield.cli import app +from printshield.core.transfer.sftp import sftp_upload_verify + +runner = CliRunner() + +class _FakePKey: + def __init__(self): self._raw = b"ssh-ed25519 AAA..." + def asbytes(self): return b"dummy-pub-blob" + +class _FakeTransport: + def get_remote_server_key(self): return _FakePKey() + +class _FakeSFTP: + def __init__(self, store): self.store = store + def stat(self, p): + data = self.store[p] + class _S: st_size = len(data) + return _S() + def mkdir(self, p): self.store.setdefault(p, b"") + def open(self, p, mode): + from io import BytesIO + if "r" in mode: + return BytesIO(self.store[p]) + raise NotImplementedError + def put(self, local, remote): + self.store[remote] = Path(local).read_bytes() + def close(self): pass + +class _FakeSSH: + def __init__(self): + self.store = {} + self.transport = _FakeTransport() + def load_system_host_keys(self): pass + def load_host_keys(self, fn): pass + def set_missing_host_key_policy(self, p): self._policy = p + def connect(self, **kw): pass + def get_transport(self): return self.transport + def open_sftp(self): return _FakeSFTP(self.store) + def close(self): pass + +def test_transfer_sftp_upload_and_verify(tmp_path: Path, monkeypatch): + # Generate a dummy file + f = tmp_path / "job.pshieldpkg" + f.write_text("abc123", encoding="utf-8") + + # Monkeypatch factory to return fake ssh client + from printshield.core.transfer import sftp as mod + res = sftp_upload_verify( + f, host="h", username="u", dest_dir="/var/spool/ps", ssh_client_factory=_FakeSSH + ) + assert res.remote_path == "/var/spool/ps/job.pshieldpkg" + assert res.local_sha256 == res.remote_sha256