diff --git a/pychron/cloud/api_client.py b/pychron/cloud/api_client.py
index a6abe97df..f6b51c0d4 100644
--- a/pychron/cloud/api_client.py
+++ b/pychron/cloud/api_client.py
@@ -26,7 +26,6 @@
 from pychron.globals import globalv
 
 
-
 DEFAULT_TIMEOUT = 10
 
 
@@ -49,6 +48,32 @@ class CloudFingerprintRejected(CloudAPIError):
     """
 
 
+class CloudDeviceCodePending(CloudAPIError):
+    """Device-code poll: admin has not approved yet (HTTP 425).
+
+    Workstation should sleep ``interval_seconds`` and poll again.
+    """
+
+
+class CloudDeviceCodeDenied(CloudAPIError):
+    """Device-code poll: admin explicitly denied the request (HTTP 403).
+
+    Terminal — workstation must stop polling and ask the admin to start
+    a new request.
+    """
+
+
+class CloudDeviceCodeExpired(CloudAPIError):
+    """Device-code poll terminal failure (HTTP 410).
+
+    Server collapses several lifecycle states into a uniform
+    ``expired_token`` to deny enumeration oracles, so the client can't
+    distinguish ``not_found`` / ``expired`` / ``already_consumed`` /
+    ``lab_vanished`` / ``scope_mismatch`` either. Workstation must stop
+    polling and start over.
+    """
+
+
 class CloudNetworkError(CloudAPIError):
     """Transport-level failure (DNS, TCP, TLS, timeout, non-JSON body)."""
 
@@ -327,6 +352,236 @@ def register_ssh_key(base_url, token, public_key, title=None, timeout=DEFAULT_TI
     )
 
 
+class DeviceCodeStart(object):
+    """Result of ``POST /api/v1/forgejo/device-codes``.
+
+    The ``device_code`` is the polling secret; the ``user_code`` is the
+    short admin-typed code shown in the workstation UI alongside the
+    ``verification_url``. Both plaintext fields are returned exactly
+    once — only hashes are persisted server-side.
+    """
+
+    __slots__ = (
+        "device_code",
+        "user_code",
+        "verification_url",
+        "verification_url_complete",
+        "expires_at",
+        "interval_seconds",
+        "raw",
+    )
+
+    def __init__(
+        self,
+        device_code,
+        user_code,
+        verification_url,
+        verification_url_complete,
+        expires_at,
+        interval_seconds,
+        raw,
+    ):
+        self.device_code = device_code
+        self.user_code = user_code
+        self.verification_url = verification_url
+        self.verification_url_complete = verification_url_complete
+        self.expires_at = expires_at
+        self.interval_seconds = interval_seconds
+        self.raw = raw
+
+
+class DeviceCodePollSuccess(object):
+    """Successful device-code poll. The minted ``api_token`` is plaintext
+    and is returned exactly once; the caller must persist it to the OS
+    keyring before losing the reference. ``ssh_key`` is the same shape
+    that :func:`register_ssh_key` returns so the orchestrator can reuse
+    the existing persist/apply path.
+    """
+
+    __slots__ = (
+        "api_token",
+        "lab",
+        "api_base_url",
+        "default_metadata_repo",
+        "ssh_host_alias",
+        "ssh_key",
+        "raw",
+    )
+
+    def __init__(
+        self,
+        api_token,
+        lab,
+        api_base_url,
+        default_metadata_repo,
+        ssh_host_alias,
+        ssh_key,
+        raw,
+    ):
+        self.api_token = api_token
+        self.lab = lab
+        self.api_base_url = api_base_url
+        self.default_metadata_repo = default_metadata_repo
+        self.ssh_host_alias = ssh_host_alias or {}
+        self.ssh_key = ssh_key
+        self.raw = raw
+
+
+def start_device_code(base_url, public_key, hostname, timeout=DEFAULT_TIMEOUT):
+    """POST a workstation public key to start a device-code grant.
+
+    Endpoint is unauthenticated. Maps:
+
+    - 201 → :class:`DeviceCodeStart`
+    - 400 → :class:`CloudFingerprintRejected` (malformed pubkey)
+    - other 4xx/5xx → :class:`CloudAPIError`
+    - transport / non-JSON / non-object JSON body → :class:`CloudNetworkError`
+    """
+    if not base_url:
+        raise CloudAPIError("api_base_url is empty")
+    if not public_key:
+        raise CloudAPIError("public_key is empty")
+    if not hostname:
+        raise CloudAPIError("hostname is empty")
+
+    url = _join(base_url, "/api/v1/forgejo/device-codes")
+    headers = {
+        "Accept": "application/json",
+        "Content-Type": "application/json",
+    }
+    payload = {"public_key": public_key.strip(), "hostname": hostname}
+
+    try:
+        resp = requests.post(
+            url,
+            headers=headers,
+            json=payload,
+            timeout=timeout,
+            verify=globalv.cert_file,
+        )
+    except requests.RequestException as exc:
+        raise CloudNetworkError("device-code start transport failure: {}".format(exc)) from exc
+
+    if resp.status_code == 400:
+        raise CloudFingerprintRejected("server rejected key (HTTP 400): {}".format(resp.text[:200]))
+    if resp.status_code != 201:
+        raise CloudAPIError(
+            "device-code start returned HTTP {}: {}".format(resp.status_code, resp.text[:200])
+        )
+
+    try:
+        body = resp.json()
+    except ValueError as exc:
+        raise CloudNetworkError("device-code start returned non-JSON body: {}".format(exc)) from exc
+    if not isinstance(body, dict):
+        raise CloudNetworkError("device-code start returned a non-object JSON body")
+
+    # A malformed interval must not abort enrollment; fall back to the
+    # same 5 s default used when the field is absent.
+    try:
+        interval_seconds = int(body.get("interval_seconds") or 5)
+    except (TypeError, ValueError):
+        interval_seconds = 5
+
+    # Strip the secret from `raw` before exposing it. Callers who
+    # serialize DeviceCodeStart.raw for debugging would otherwise leak
+    # both the device_code (polling secret) and user_code into logs/disk.
+    safe_raw = {k: v for k, v in body.items() if k not in ("device_code", "user_code")}
+    return DeviceCodeStart(
+        device_code=body.get("device_code", ""),
+        user_code=body.get("user_code", ""),
+        verification_url=body.get("verification_url", ""),
+        verification_url_complete=body.get("verification_url_complete", ""),
+        expires_at=body.get("expires_at", ""),
+        interval_seconds=interval_seconds,
+        raw=safe_raw,
+    )
+
+
+def poll_device_code(base_url, device_code, timeout=DEFAULT_TIMEOUT):
+    """Poll a device-code grant. Unauthenticated — the device_code is the credential.
+
+    Maps:
+
+    - 200 → :class:`DeviceCodePollSuccess`
+    - 425 → :class:`CloudDeviceCodePending` (keep polling)
+    - 403 → :class:`CloudDeviceCodeDenied` (terminal — admin denied)
+    - 410 → :class:`CloudDeviceCodeExpired` (terminal — uniform server response
+      for not-found / expired / already-consumed / lab-vanished /
+      scope-mismatch)
+    - 400 → :class:`CloudFingerprintRejected`
+    - other 4xx/5xx → :class:`CloudAPIError`
+    - transport / non-JSON / non-object JSON body → :class:`CloudNetworkError`
+    """
+    if not base_url:
+        raise CloudAPIError("api_base_url is empty")
+    if not device_code:
+        raise CloudAPIError("device_code is empty")
+
+    url = _join(base_url, "/api/v1/forgejo/device-codes/poll")
+    headers = {
+        "Accept": "application/json",
+        "Content-Type": "application/json",
+    }
+    payload = {"device_code": device_code}
+
+    try:
+        resp = requests.post(
+            url,
+            headers=headers,
+            json=payload,
+            timeout=timeout,
+            verify=globalv.cert_file,
+        )
+    except requests.RequestException as exc:
+        raise CloudNetworkError("device-code poll transport failure: {}".format(exc)) from exc
+
+    if resp.status_code == 425:
+        raise CloudDeviceCodePending("authorization_pending")
+    if resp.status_code == 403:
+        raise CloudDeviceCodeDenied("access_denied")
+    if resp.status_code == 410:
+        raise CloudDeviceCodeExpired("expired_token")
+    if resp.status_code == 400:
+        raise CloudFingerprintRejected("server rejected key (HTTP 400): {}".format(resp.text[:200]))
+    if resp.status_code != 200:
+        raise CloudAPIError(
+            "device-code poll returned HTTP {}: {}".format(resp.status_code, resp.text[:200])
+        )
+
+    try:
+        body = resp.json()
+    except ValueError as exc:
+        raise CloudNetworkError("device-code poll returned non-JSON body: {}".format(exc)) from exc
+    if not isinstance(body, dict):
+        raise CloudNetworkError("device-code poll returned a non-object JSON body")
+
+    # Tolerate a missing or malformed ``ssh_key`` section rather than crash.
+    ssh_key_payload = body.get("ssh_key")
+    if not isinstance(ssh_key_payload, dict):
+        ssh_key_payload = {}
+    ssh_key = SSHKeyRegistration(
+        bot_username=ssh_key_payload.get("bot_username", ""),
+        fingerprint=ssh_key_payload.get("fingerprint", ""),
+        default_metadata_repo=ssh_key_payload.get("default_metadata_repo", ""),
+        ssh_host_alias=ssh_key_payload.get("ssh_host_alias") or body.get("ssh_host_alias") or {},
+        raw=ssh_key_payload,
+    )
+
+    # Strip the plaintext token from `raw` before exposing it.
+    safe_raw = {k: v for k, v in body.items() if k != "api_token"}
+
+    return DeviceCodePollSuccess(
+        api_token=body.get("api_token", ""),
+        lab=body.get("lab", ""),
+        api_base_url=body.get("api_base_url", "") or base_url,
+        default_metadata_repo=body.get("default_metadata_repo"),
+        ssh_host_alias=body.get("ssh_host_alias") or {},
+        ssh_key=ssh_key,
+        raw=safe_raw,
+    )
+
+
 def revoke_workstation_token(base_url, token, timeout=DEFAULT_TIMEOUT):
     """Revoke the calling token via ``DELETE /api/v1/forgejo/tokens/self``.
 
diff --git a/pychron/cloud/tasks/preferences.py b/pychron/cloud/tasks/preferences.py index def09b86f..cc3d63a43 100644 --- a/pychron/cloud/tasks/preferences.py +++ b/pychron/cloud/tasks/preferences.py @@ -31,12 +31,15 @@ import logging from envisage.ui.tasks.preferences_pane import PreferencesPane +from pyface.api import GUI from traits.api import Bool, Button, Password, Str from traitsui.api import Color, Group, HGroup, Item, VGroup, View from pychron.cloud.api_client import ( CloudAPIError, CloudAuthError, + CloudDeviceCodeDenied, + CloudDeviceCodeExpired, CloudNetworkError, whoami, ) @@ -46,6 +49,8 @@ set_token, ) from pychron.cloud.workstation_setup import ( + DeviceEnrollmentCancelled, + KeyringWriteFailedError, WorkstationSetup, WorkstationSetupError, switch_lab as wipe_for_switch_lab, @@ -75,6 +80,25 @@ class CloudPreferences(BasePreferencesHelper): reonboard_button = Button("Re-onboard workstation") revoke_button = Button("Revoke this workstation") switch_lab_button = Button("Switch lab (destructive)") + + # Device-code enrollment (RFC 8628-style). The technician clicks + # ``enroll_via_device_code_button``; the workstation contacts + # pychronAPI, displays ``_pending_user_code`` + ``_pending_verification_url`` + # for the technician to read out to the admin, and polls in a + # background thread until the admin approves. + enroll_via_device_code_button = Button("Start device-code enrollment") + cancel_enrollment_button = Button("Cancel enrollment") + _pending_user_code = Str + _pending_verification_url = Str + _pending_active = Bool(False) + _should_cancel_enrollment = Bool(False) + + # Surfaced on KeyringWriteFailedError so the technician can paste + # the (still-in-memory) token into a password manager. Cleared + # whenever a fresh enrollment starts. 
+ _recovery_token = Str + _recovery_lab = Str + _remote_status = Str _remote_status_color = Color @@ -87,8 +111,8 @@ def _initialize(self, *args, **kw): def _is_preference_trait(self, trait_name): # api_token must never be written to the .cfg — it lives in the OS - # keyring. The transient remote-status traits and the lifecycle - # buttons also stay out. + # keyring. The transient remote-status, enrollment progress, and + # lifecycle-button traits also stay out. if trait_name in ( "api_token", "_remote_status", @@ -97,6 +121,14 @@ def _is_preference_trait(self, trait_name): "reonboard_button", "revoke_button", "switch_lab_button", + "enroll_via_device_code_button", + "cancel_enrollment_button", + "_pending_user_code", + "_pending_verification_url", + "_pending_active", + "_should_cancel_enrollment", + "_recovery_token", + "_recovery_lab", ): return False return super(CloudPreferences, self)._is_preference_trait(trait_name) @@ -156,6 +188,144 @@ def _test_connection_fired(self): self._remote_status = "OK ({} / {})".format(info.kind or "?", info.lab or "?") self._remote_status_color = normalize_color_name("green") + # -- device-code enrollment --------------------------------------- + + def _enroll_via_device_code_button_fired(self): + """Kick off a device-code grant in a background thread. + + The worker thread updates ``_pending_user_code`` and + ``_pending_verification_url`` so the technician can read them + out to the admin, then polls until completion. + """ + if self._pending_active: + return + self._remote_status_color = normalize_color_name("red") + if not self.api_base_url: + self._remote_status = "Set API Base URL first" + return + + self._should_cancel_enrollment = False + self._pending_user_code = "" + self._pending_verification_url = "" + self._recovery_token = "" + self._recovery_lab = "" + self._pending_active = True + self._remote_status = "Starting enrollment..." 
+ self._remote_status_color = normalize_color_name("orange") + + import threading + + threading.Thread( + target=self._enrollment_worker, + name="pychron-cloud-device-code", + daemon=True, + ).start() + + def _on_device_code_user_code( + self, user_code, verification_url, verification_url_complete, expires_at + ): + """Worker-thread callback: surface the user_code + URL in the pane. + + Trait writes from non-UI threads are dispatched to the UI thread + by the Pyface event loop, so the operator sees the code as soon + as the server returns it. + """ + self._pending_user_code = user_code + self._pending_verification_url = verification_url + self._remote_status = "Show {} to admin at {}".format(user_code, verification_url) + self._remote_status_color = normalize_color_name("orange") + + def _enrollment_worker(self): + api_base_url = self.api_base_url + try: + setup = WorkstationSetup.from_device_code( + api_base_url, + on_user_code=self._on_device_code_user_code, + should_cancel=lambda: self._should_cancel_enrollment, + ) + except DeviceEnrollmentCancelled: + GUI.invoke_later(self._apply_enrollment_terminal, "Enrollment cancelled", "red") + return + except CloudDeviceCodeDenied: + GUI.invoke_later( + self._apply_enrollment_terminal, + "Admin denied — ask for a new request", + "red", + ) + return + except CloudDeviceCodeExpired: + GUI.invoke_later(self._apply_enrollment_terminal, "Code expired — start over", "red") + return + except CloudAuthError: + GUI.invoke_later(self._apply_enrollment_terminal, "Auth rejected", "red") + return + except CloudNetworkError as exc: + logger.warning("device-code enrollment transport failure: %s", exc) + GUI.invoke_later(self._apply_enrollment_terminal, "Unreachable", "red") + return + except KeyringWriteFailedError as exc: + # Server already minted; we hold the only copy. 
Hand the + # plaintext to the UI thread for display — and DO NOT log + # the exception (its message intentionally omits the token + # but defense-in-depth: log only the type name). + logger.warning( + "device-code enrollment keyring write failed: %s", + type(exc).__name__, + ) + GUI.invoke_later(self._apply_keyring_recovery, exc.lab_name, exc.api_token) + return + except (CloudAPIError, WorkstationSetupError) as exc: + logger.warning("device-code enrollment failed: %s", type(exc).__name__) + GUI.invoke_later(self._apply_enrollment_terminal, "Enrollment failed", "red") + return + + GUI.invoke_later(self._apply_enrollment_success, setup) + + def _apply_enrollment_success(self, setup): + """Run on the UI thread. Persistent-trait writes (api_base_url, + lab_name) fire BasePreferencesHelper listeners that call into + Envisage's preferences node, which expects single-threaded + access — so we dispatch them here rather than from the worker. + """ + self.api_base_url = setup.api_base_url + self.lab_name = setup.lab_name + self._load_token_from_keyring() + self._remote_status = "Enrolled as {}".format(setup.lab_name) + self._remote_status_color = normalize_color_name("green") + self._reset_pending() + + def _apply_enrollment_terminal(self, message, color): + self._remote_status = message + self._remote_status_color = normalize_color_name(color) + self._reset_pending() + + def _apply_keyring_recovery(self, lab_name, api_token): + """Display the still-in-memory token so the technician can copy + it into a password manager. This is the recovery path for the + single-use polling secret being already consumed server-side + but not persisted locally. 
+ """ + self._recovery_lab = lab_name + self._recovery_token = api_token + self._remote_status = ( + "Keyring write failed — copy the token below and store it " + "manually before closing this window" + ) + self._remote_status_color = normalize_color_name("red") + self._reset_pending() + + def _reset_pending(self): + self._pending_user_code = "" + self._pending_verification_url = "" + self._pending_active = False + self._should_cancel_enrollment = False + + def _cancel_enrollment_button_fired(self): + if not self._pending_active: + return + self._should_cancel_enrollment = True + self._remote_status = "Cancelling..." + # -- P6 buttons --------------------------------------------------- def _build_setup(self): @@ -274,6 +444,57 @@ def traits_view(self): show_border=True, label="Pychron Cloud (pychronAPI)", ) + enroll = VGroup( + HGroup( + Item( + "enroll_via_device_code_button", + show_label=False, + enabled_when="not _pending_active", + tooltip="Contact pychronAPI for a single-use device code, " + "then read the displayed code to your lab admin. They will " + "approve from any phone or laptop browser; this workstation " + "polls until they do.", + ), + Item( + "cancel_enrollment_button", + show_label=False, + enabled_when="_pending_active", + ), + ), + HGroup( + Item( + "_pending_user_code", + style="readonly", + label="Code", + visible_when="_pending_active", + ), + Item( + "_pending_verification_url", + style="readonly", + label="Approve at", + visible_when="_pending_active", + ), + ), + HGroup( + Item( + "_recovery_token", + style="readonly", + label="RECOVERY TOKEN", + tooltip="Keyring write failed — copy this token into a " + "password manager before closing the window. 
The polling " + "secret is single-use, so this is the only copy.", + visible_when="_recovery_token != ''", + ), + Item( + "_recovery_lab", + style="readonly", + label="for lab", + visible_when="_recovery_token != ''", + ), + ), + show_border=True, + label="Enroll via Device Code", + ) lifecycle = VGroup( HGroup( Item("reonboard_button", show_label=False), @@ -283,7 +504,7 @@ def traits_view(self): show_border=True, label="Workstation Lifecycle", ) - return View(Group(creds, lifecycle)) + return View(Group(creds, enroll, lifecycle)) # ============= EOF ============================================= diff --git a/pychron/cloud/workstation_setup.py b/pychron/cloud/workstation_setup.py index fbbd05fac..a06e875d2 100644 --- a/pychron/cloud/workstation_setup.py +++ b/pychron/cloud/workstation_setup.py @@ -45,13 +45,20 @@ import json import logging import os +import time from pychron.cloud.api_client import ( CloudAPIError, + CloudDeviceCodeDenied, + CloudDeviceCodeExpired, + CloudDeviceCodePending, CloudFingerprintRejected, + poll_device_code, register_ssh_key, revoke_workstation_token, + start_device_code, ) +from pychron.cloud.keyring_store import set_token as keyring_set_token from pychron.cloud.paths import ( ensure_pychron_dirs, host_slug, @@ -73,7 +80,6 @@ read_public_key, ) - logger = logging.getLogger(__name__) @@ -81,6 +87,36 @@ class WorkstationSetupError(Exception): """Raised when onboarding cannot complete.""" +class DeviceEnrollmentCancelled(WorkstationSetupError): + """The polling loop returned because ``should_cancel`` went True. + + Raised so the UI can distinguish a user-cancelled enrollment (offer + to start over) from a server-side denial / expiry (offer to ask the + admin for a new approval). + """ + + +class KeyringWriteFailedError(WorkstationSetupError): + """OS keyring write failed during enrollment. 
+ + The polling secret is single-use, so by the time this fires the + server has already minted credentials and the workstation has the + only copy in memory. ``api_token`` and ``lab_name`` are exposed as + attributes so the UI can render them for the technician to paste + into a password manager. The exception's ``__str__`` deliberately + OMITS the token to keep it out of log files — callers must reach + into the attributes if they want to display it. + """ + + def __init__(self, lab_name, api_token): + super().__init__( + "could not save api_token to OS keyring; UI must surface " + "the token for manual capture" + ) + self.lab_name = lab_name + self.api_token = api_token + + class WorkstationSetup(object): """Onboard the current workstation against a pychronAPI lab. @@ -95,6 +131,102 @@ def __init__(self, api_base_url, api_token, lab_name, host=None): self.lab_name = lab_name self.host = host or host_slug() + # -- device-code enrollment ---------------------------------------- + + @classmethod + def from_device_code( + cls, + api_base_url, + on_user_code, + should_cancel=None, + host=None, + sleep=time.sleep, + ): + """Orchestrate an RFC 8628-style device-code enrollment end-to-end. + + Sequence: + + 1. ``ensure_pychron_dirs`` + ``ensure_keypair`` (matches + :meth:`run` — re-uses any existing local keypair). + 2. ``POST /api/v1/forgejo/device-codes`` to get a polling secret + + the short user_code the technician will read out to the admin. + 3. Invoke ``on_user_code(user_code, verification_url, + verification_url_complete, expires_at)`` exactly once. The UI + is expected to display the code and the URL. + 4. Poll on ``interval_seconds`` until one of: + + * **success** — admin approved; persist registration, + SSH-config, known_hosts, and OS-keyring token; return a + populated :class:`WorkstationSetup`. + * **denied** — :class:`CloudDeviceCodeDenied` re-raised. + * **expired** — :class:`CloudDeviceCodeExpired` re-raised. 
+ * **cancelled** — ``should_cancel()`` returned True → + :class:`DeviceEnrollmentCancelled`. + + The ``sleep`` parameter is dependency-injected so tests can pass + a no-op without burning real wall time. + """ + if not api_base_url: + raise WorkstationSetupError("api_base_url is empty") + host = host or host_slug() + + ensure_pychron_dirs() + ensure_keypair(host) + public_key = read_public_key(host) + + start = start_device_code(api_base_url, public_key, host) + on_user_code( + start.user_code, + start.verification_url, + start.verification_url_complete, + start.expires_at, + ) + + interval = max(1, int(start.interval_seconds or 5)) + cancel = should_cancel or (lambda: False) + + while True: + if cancel(): + raise DeviceEnrollmentCancelled("enrollment cancelled by user") + try: + success = poll_device_code(api_base_url, start.device_code) + except CloudDeviceCodePending: + sleep(interval) + continue + except CloudDeviceCodeDenied: + raise + except CloudDeviceCodeExpired: + raise + except CloudFingerprintRejected: + raise + break + + if not success.api_token: + raise WorkstationSetupError("server did not return an api_token") + if not success.lab: + raise WorkstationSetupError("server did not return a lab name") + + api_base_url = success.api_base_url or api_base_url + setup = cls( + api_base_url=api_base_url, + api_token=success.api_token, + lab_name=success.lab, + host=host, + ) + setup._persist_registration(success.ssh_key) + setup._apply_ssh_config(success.ssh_key) + + # Stash the token in the OS keyring last. A failure here would + # normally be silent (the helper logs + returns False), but in + # the device-code flow the polling secret is single-use, so a + # silent loss leaves the technician unable to recover. Raise a + # typed error carrying the token on attributes (NOT in str()) + # so the UI can present it for manual capture without leaking + # it to log files. 
+ if not keyring_set_token(success.lab, success.api_token): + raise KeyringWriteFailedError(lab_name=success.lab, api_token=success.api_token) + return setup + # -- public entry -------------------------------------------------- def run(self): diff --git a/test/cloud/test_api_client.py b/test/cloud/test_api_client.py index 3c3c4db18..2a63ae0f8 100644 --- a/test/cloud/test_api_client.py +++ b/test/cloud/test_api_client.py @@ -79,5 +79,216 @@ def test_can_register_ssh_key_false_when_scope_missing(self): self.assertFalse(info.can_register_ssh_key()) +_START_BODY = { + "device_code": "dvc_xyz", + "user_code": "ABCD-EFGH", + "verification_url": "https://api.example/device", + "verification_url_complete": "https://api.example/device?user_code=ABCD-EFGH", + "expires_at": "2026-05-09T12:00:00Z", + "interval_seconds": 5, +} + + +def _poll_body(**overrides): + body = { + "api_token": "pcy_NMGRL_xyz", + "lab": "NMGRL", + "api_base_url": "https://api.example", + "default_metadata_repo": None, + "ssh_host_alias": { + "alias": "pychron-NMGRL", + "real_host": "repo.example", + "port": 222, + "known_hosts_line": "repo.example ssh-rsa AAAA", + }, + "ssh_key": { + "bot_username": "bot-NMGRL-deadbeef", + "fingerprint": "SHA256:abc", + "rotated": False, + "default_metadata_repo": None, + "ssh_host_alias": { + "alias": "pychron-NMGRL", + "real_host": "repo.example", + "port": 222, + "known_hosts_line": "repo.example ssh-rsa AAAA", + }, + }, + } + body.update(overrides) + return body + + +class TestStartDeviceCode(unittest.TestCase): + URL = "https://api.example" + PUBKEY = "ssh-ed25519 AAAA test@host" + HOST = "lab-mac-01" + + def _call(self): + return api_client.start_device_code(self.URL, self.PUBKEY, self.HOST) + + def test_success_returns_start_result(self): + with patch.object(api_client.requests, "post", return_value=_resp(201, _START_BODY)): + r = self._call() + self.assertEqual(r.device_code, "dvc_xyz") + self.assertEqual(r.user_code, "ABCD-EFGH") + 
self.assertEqual(r.verification_url, "https://api.example/device") + self.assertEqual(r.interval_seconds, 5) + + def test_post_url_matches_endpoint(self): + with patch.object(api_client.requests, "post", return_value=_resp(201, _START_BODY)) as p: + self._call() + self.assertEqual( + p.call_args[0][0], + "https://api.example/api/v1/forgejo/device-codes", + ) + + def test_no_authorization_header(self): + """Endpoint is unauthenticated. Sending a stale Authorization header + could leak it on misconfigured proxies; client must omit one.""" + with patch.object(api_client.requests, "post", return_value=_resp(201, _START_BODY)) as p: + self._call() + headers = p.call_args.kwargs["headers"] + self.assertNotIn("Authorization", headers) + self.assertNotIn("authorization", {k.lower() for k in headers}) + + def test_400_raises_fingerprint_rejected(self): + with patch.object(api_client.requests, "post", return_value=_resp(400, {"detail": "bad"})): + with self.assertRaises(api_client.CloudFingerprintRejected): + self._call() + + def test_500_raises_api_error(self): + with patch.object(api_client.requests, "post", return_value=_resp(500, {"detail": "boom"})): + with self.assertRaises(api_client.CloudAPIError): + self._call() + + def test_transport_error_raises_network_error(self): + with patch.object( + api_client.requests, + "post", + side_effect=requests.ConnectionError("boom"), + ): + with self.assertRaises(api_client.CloudNetworkError): + self._call() + + def test_non_json_body_raises_network_error(self): + with patch.object(api_client.requests, "post", return_value=_resp(201, None)): + with self.assertRaises(api_client.CloudNetworkError): + self._call() + + def test_empty_args_raise_api_error(self): + for args in [ + ("", self.PUBKEY, self.HOST), + (self.URL, "", self.HOST), + (self.URL, self.PUBKEY, ""), + ]: + with self.assertRaises(api_client.CloudAPIError): + api_client.start_device_code(*args) + + def test_secrets_stripped_from_raw(self): + """``DeviceCodeStart.raw`` 
is exposed for debugging. Both the + device_code (polling secret) and the user_code (admin-facing + but not meant for logs) must be stripped.""" + with patch.object(api_client.requests, "post", return_value=_resp(201, _START_BODY)): + r = self._call() + self.assertNotIn("device_code", r.raw) + self.assertNotIn("user_code", r.raw) + self.assertIn("verification_url", r.raw) + + +class TestPollDeviceCode(unittest.TestCase): + URL = "https://api.example" + DEVICE_CODE = "dvc_xyz" + + def _call(self): + return api_client.poll_device_code(self.URL, self.DEVICE_CODE) + + def test_success_returns_poll_result(self): + with patch.object(api_client.requests, "post", return_value=_resp(200, _poll_body())): + r = self._call() + self.assertEqual(r.api_token, "pcy_NMGRL_xyz") + self.assertEqual(r.lab, "NMGRL") + self.assertEqual(r.api_base_url, "https://api.example") + self.assertEqual(r.ssh_key.bot_username, "bot-NMGRL-deadbeef") + self.assertEqual(r.ssh_key.alias, "pychron-NMGRL") + self.assertEqual(r.ssh_key.port, 222) + + def test_post_url_matches_endpoint(self): + with patch.object(api_client.requests, "post", return_value=_resp(200, _poll_body())) as p: + self._call() + self.assertEqual( + p.call_args[0][0], + "https://api.example/api/v1/forgejo/device-codes/poll", + ) + + def test_no_authorization_header(self): + with patch.object(api_client.requests, "post", return_value=_resp(200, _poll_body())) as p: + self._call() + headers = p.call_args.kwargs["headers"] + self.assertNotIn("Authorization", headers) + + def test_425_raises_pending(self): + with patch.object(api_client.requests, "post", return_value=_resp(425, {})): + with self.assertRaises(api_client.CloudDeviceCodePending): + self._call() + + def test_403_raises_denied(self): + with patch.object(api_client.requests, "post", return_value=_resp(403, {})): + with self.assertRaises(api_client.CloudDeviceCodeDenied): + self._call() + + def test_410_raises_expired(self): + with patch.object(api_client.requests, "post", 
return_value=_resp(410, {})): + with self.assertRaises(api_client.CloudDeviceCodeExpired): + self._call() + + def test_400_raises_fingerprint_rejected(self): + with patch.object(api_client.requests, "post", return_value=_resp(400, {})): + with self.assertRaises(api_client.CloudFingerprintRejected): + self._call() + + def test_500_raises_api_error(self): + with patch.object(api_client.requests, "post", return_value=_resp(500, {})): + with self.assertRaises(api_client.CloudAPIError) as cm: + self._call() + # Make sure 5xx isn't accidentally caught as one of the + # device-code-specific subclasses. + self.assertNotIsInstance(cm.exception, api_client.CloudDeviceCodePending) + self.assertNotIsInstance(cm.exception, api_client.CloudDeviceCodeDenied) + self.assertNotIsInstance(cm.exception, api_client.CloudDeviceCodeExpired) + + def test_transport_error_raises_network_error(self): + with patch.object( + api_client.requests, "post", side_effect=requests.ConnectionError("boom") + ): + with self.assertRaises(api_client.CloudNetworkError): + self._call() + + def test_non_json_body_raises_network_error(self): + with patch.object(api_client.requests, "post", return_value=_resp(200, None)): + with self.assertRaises(api_client.CloudNetworkError): + self._call() + + def test_empty_args_raise_api_error(self): + for args in [("", self.DEVICE_CODE), (self.URL, "")]: + with self.assertRaises(api_client.CloudAPIError): + api_client.poll_device_code(*args) + + def test_api_token_stripped_from_raw(self): + """``DeviceCodePollSuccess.raw`` is exposed for debugging. 
The + plaintext bearer token must NOT be in it — only on the dedicated + `.api_token` attribute that callers treat as a secret.""" + with patch.object(api_client.requests, "post", return_value=_resp(200, _poll_body())): + r = self._call() + self.assertEqual(r.api_token, "pcy_NMGRL_xyz") + self.assertNotIn("api_token", r.raw) + + def test_falls_back_to_caller_base_url_when_server_omits_it(self): + body = _poll_body() + body.pop("api_base_url") + with patch.object(api_client.requests, "post", return_value=_resp(200, body)): + r = self._call() + self.assertEqual(r.api_base_url, self.URL) + + if __name__ == "__main__": unittest.main() diff --git a/test/cloud/test_device_code_setup.py b/test/cloud/test_device_code_setup.py new file mode 100644 index 000000000..3a21789c0 --- /dev/null +++ b/test/cloud/test_device_code_setup.py @@ -0,0 +1,253 @@ +"""Tests for WorkstationSetup.from_device_code end-to-end flow.""" + +import json +import os +import tempfile +import unittest +from unittest.mock import MagicMock, patch + +from pychron.cloud import api_client, workstation_setup + + +_START_BODY = { + "device_code": "dvc_xyz", + "user_code": "ABCD-EFGH", + "verification_url": "https://api.example/device", + "verification_url_complete": "https://api.example/device?user_code=ABCD-EFGH", + "expires_at": "2026-05-09T12:00:00Z", + "interval_seconds": 1, +} + + +def _poll_body(): + return { + "api_token": "pcy_NMGRL_xyz", + "lab": "NMGRL", + "api_base_url": "https://api.example", + "default_metadata_repo": None, + "ssh_host_alias": { + "alias": "pychron-NMGRL", + "real_host": "repo.example", + "port": 2222, + "known_hosts_line": "[repo.example]:2222 ssh-ed25519 AAAA", + }, + "ssh_key": { + "bot_username": "bot-NMGRL-deadbeef", + "fingerprint": "SHA256:abc", + "rotated": False, + "default_metadata_repo": None, + "ssh_host_alias": { + "alias": "pychron-NMGRL", + "real_host": "repo.example", + "port": 2222, + "known_hosts_line": "[repo.example]:2222 ssh-ed25519 AAAA", + }, + }, + } + + 
def _resp(status_code, body):
    """Return a ``MagicMock`` shaped like the HTTP responses the tests script.

    The mock carries ``status_code`` unchanged and ``text`` set to
    ``str(body)``. Passing ``body=None`` makes ``json()`` raise
    ``ValueError`` (a non-JSON payload); any other ``body`` is returned
    verbatim from ``json()``.
    """
    r = MagicMock()
    r.status_code = status_code
    r.text = str(body)
    if body is None:
        r.json.side_effect = ValueError("not json")
    else:
        r.json.return_value = body
    return r


class FromDeviceCodeTestCase(unittest.TestCase):
    """Drive ``WorkstationSetup.from_device_code`` with ``requests.post`` mocked.

    Each test scripts the sequence of HTTP responses through
    ``post.side_effect`` and redirects ``~`` into a per-test scratch
    directory so anything written under ``~/.pychron`` and ``~/.ssh`` can be
    inspected and is removed afterwards.
    """

    URL = "https://api.example"

    def setUp(self):
        """Create the scratch home directory and point ``expanduser`` at it."""
        self.tmp = tempfile.mkdtemp()
        # Redirect ~ → tmp so ~/.pychron and ~/.ssh land in scratch space.
        self._patcher = patch(
            "pychron.cloud.paths.os.path.expanduser",
            self._fake_expanduser,
        )
        self._patcher.start()
        self.addCleanup(self._patcher.stop)
        self.addCleanup(self._rmtree, self.tmp)

    def _fake_expanduser(self, path):
        """Stand-in for ``os.path.expanduser`` that maps a leading ``~`` to ``self.tmp``.

        Only an initial ``~`` is rewritten, like the real function. Replacing
        every ``~`` would corrupt paths that merely contain one, e.g. a
        Windows 8.3 short temp dir such as ``C:\\Users\\RUNNER~1\\...``
        being passed through ``expanduser`` a second time.
        """
        if path.startswith("~"):
            return self.tmp + path[1:]
        return path

    def _rmtree(self, path):
        """Best-effort removal of the scratch directory (errors are ignored)."""
        import shutil

        shutil.rmtree(path, ignore_errors=True)

    def test_happy_path_pending_then_success(self):
        """One pending poll (425) followed by success (200) completes enrollment."""
        seen_codes = []

        def on_user_code(uc, vu, vu_complete, exp):
            seen_codes.append((uc, vu, vu_complete, exp))

        with (
            patch.object(api_client.requests, "post") as post,
            patch.object(workstation_setup, "keyring_set_token", return_value=True) as kr,
        ):
            post.side_effect = [
                _resp(201, _START_BODY),
                _resp(425, {}),  # pending
                _resp(200, _poll_body()),
            ]
            sleeps = []
            setup = workstation_setup.WorkstationSetup.from_device_code(
                self.URL,
                on_user_code=on_user_code,
                sleep=lambda s: sleeps.append(s),
                host="testhost",
            )

            # Callback fired exactly once with the user_code + URL.
            self.assertEqual(len(seen_codes), 1)
            self.assertEqual(seen_codes[0][0], "ABCD-EFGH")
            self.assertEqual(seen_codes[0][1], "https://api.example/device")

            # One sleep between pending and success (interval_seconds=1).
            self.assertEqual(sleeps, [1])

            # Returned setup populated.
            self.assertEqual(setup.api_token, "pcy_NMGRL_xyz")
            self.assertEqual(setup.lab_name, "NMGRL")
            self.assertEqual(setup.api_base_url, "https://api.example")

            # Keypair, registration.json, known_hosts, and ~/.ssh/config all written.
            priv = os.path.join(self.tmp, ".pychron", "keys", "pychron_testhost")
            self.assertTrue(os.path.isfile(priv))
            self.assertTrue(os.path.isfile(priv + ".pub"))

            reg_path = os.path.join(self.tmp, ".pychron", "registration.json")
            self.assertTrue(os.path.isfile(reg_path))
            with open(reg_path) as f:
                self.assertEqual(json.load(f)["bot_username"], "bot-NMGRL-deadbeef")

            kh = os.path.join(self.tmp, ".pychron", "known_hosts")
            with open(kh) as f:
                self.assertIn("[repo.example]:2222", f.read())

            ssh_cfg = os.path.join(self.tmp, ".ssh", "config")
            with open(ssh_cfg) as f:
                self.assertIn("Host pychron-NMGRL", f.read())

            # Keyring write happened with the right (lab, token).
            kr.assert_called_once_with("NMGRL", "pcy_NMGRL_xyz")

            # Polling endpoints hit. Start was the first call, and both polls
            # (the pending one and the successful one) came after it.
            start_url = "https://api.example/api/v1/forgejo/device-codes"
            poll_url = "https://api.example/api/v1/forgejo/device-codes/poll"
            requested = [recorded[0][0] for recorded in post.call_args_list]
            self.assertEqual(requested, [start_url, poll_url, poll_url])

            # No Authorization header on either unauthenticated call.
            for recorded in post.call_args_list:
                self.assertNotIn("Authorization", recorded.kwargs["headers"])

    def test_denied_propagates_no_artifacts_persisted(self):
        """A 403 poll raises ``CloudDeviceCodeDenied`` and writes no registration."""
        with patch.object(api_client.requests, "post") as post:
            post.side_effect = [
                _resp(201, _START_BODY),
                _resp(403, {}),  # admin denied
            ]
            with self.assertRaises(api_client.CloudDeviceCodeDenied):
                workstation_setup.WorkstationSetup.from_device_code(
                    self.URL,
                    on_user_code=lambda *a: None,
                    sleep=lambda s: None,
                    host="testhost",
                )
            # Keypair was generated (start_device_code happened) but no
            # registration / SSH config persisted because we never reached
            # the success branch.
            priv = os.path.join(self.tmp, ".pychron", "keys", "pychron_testhost")
            self.assertTrue(os.path.isfile(priv))
            reg_path = os.path.join(self.tmp, ".pychron", "registration.json")
            self.assertFalse(os.path.isfile(reg_path))

    def test_expired_propagates(self):
        """A 410 poll after one pending poll raises ``CloudDeviceCodeExpired``."""
        with patch.object(api_client.requests, "post") as post:
            post.side_effect = [
                _resp(201, _START_BODY),
                _resp(425, {}),
                _resp(410, {}),
            ]
            with self.assertRaises(api_client.CloudDeviceCodeExpired):
                workstation_setup.WorkstationSetup.from_device_code(
                    self.URL,
                    on_user_code=lambda *a: None,
                    sleep=lambda s: None,
                    host="testhost",
                )

    def test_should_cancel_raises_DeviceEnrollmentCancelled(self):
        """``should_cancel`` returning true mid-poll raises ``DeviceEnrollmentCancelled``."""
        ticks = {"n": 0}

        def cancel():
            ticks["n"] += 1
            return ticks["n"] >= 2  # cancel on second tick

        with patch.object(api_client.requests, "post") as post:
            post.side_effect = [
                _resp(201, _START_BODY),
                _resp(425, {}),  # pending → loop sleeps then re-checks cancel
            ]
            with self.assertRaises(workstation_setup.DeviceEnrollmentCancelled):
                workstation_setup.WorkstationSetup.from_device_code(
                    self.URL,
                    on_user_code=lambda *a: None,
                    should_cancel=cancel,
                    sleep=lambda s: None,
                    host="testhost",
                )

    def test_keyring_failure_raises_typed_error_token_not_in_str(self):
        """Single-use polling secret was already consumed; if the keyring
        write fails silently the technician would lose the credential.
        Surface as ``KeyringWriteFailedError`` whose ``__str__`` does
        NOT contain the token (so it can be safely logged) but whose
        ``.api_token`` / ``.lab_name`` attributes carry the plaintext
        for the UI to display.
        """
        with (
            patch.object(api_client.requests, "post") as post,
            patch.object(workstation_setup, "keyring_set_token", return_value=False),
        ):
            post.side_effect = [
                _resp(201, _START_BODY),
                _resp(200, _poll_body()),
            ]
            with self.assertRaises(workstation_setup.KeyringWriteFailedError) as cm:
                workstation_setup.WorkstationSetup.from_device_code(
                    self.URL,
                    on_user_code=lambda *a: None,
                    sleep=lambda s: None,
                    host="testhost",
                )
            # Token NOT leaked through str(exc) — protects log files.
            self.assertNotIn("pcy_NMGRL_xyz", str(cm.exception))
            # But available on attributes for the UI.
            self.assertEqual(cm.exception.api_token, "pcy_NMGRL_xyz")
            self.assertEqual(cm.exception.lab_name, "NMGRL")
            # Still a WorkstationSetupError subclass for any callers
            # catching the broader type.
            self.assertIsInstance(cm.exception, workstation_setup.WorkstationSetupError)

    def test_empty_api_base_url_aborts_before_any_io(self):
        """An empty base URL raises before any HTTP call or keypair file is made."""
        with patch.object(api_client.requests, "post") as post:
            with self.assertRaises(workstation_setup.WorkstationSetupError):
                workstation_setup.WorkstationSetup.from_device_code(
                    "",
                    on_user_code=lambda *a: None,
                    sleep=lambda s: None,
                    host="testhost",
                )
            post.assert_not_called()
            priv = os.path.join(self.tmp, ".pychron", "keys", "pychron_testhost")
            self.assertFalse(os.path.isfile(priv))


if __name__ == "__main__":
    unittest.main()