From abf5e81f77dcf14de2760a69e24464555168b77d Mon Sep 17 00:00:00 2001 From: jakeross Date: Sun, 10 May 2026 12:54:17 -0600 Subject: [PATCH] feat(cloud): persist device-flow Postgres credentials into DVC prefs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes the last gap in the device-flow onboarding so a technician can go from "fresh workstation" to "DVC connected" without typing, pasting, or running any CLI other than the Pychron Preferences UI: - pychron/cloud/api_client.py — DeviceCodePollSuccess gains database_url and database_role slots. poll_device_code() parses them out of the response and strips database_url from the safe_raw debug-friendly dict so the embedded password cannot leak into caller logs. - pychron/cloud/dvc_credentials.py (NEW) — pure helpers wrapping the parse → CSV → favorites round-trip: * parse_database_url() decomposes a postgresql:// URL into the fields a DVCConnectionItem needs, percent-decoding the password. * build_dvc_connection_csv() serializes the parsed result as a positional CSV that DVCConnectionItem(attrs=...) will rehydrate on next DVC startup. The host field carries host:port so non- default-port Cloud SQL connections work — the underlying CSV schema has no separate port attribute. * merge_dvc_connection_favorites() either replaces an existing cloud- favorite or appends a new one, and demotes any other row that previously held default=True so only one default favorite is active. * apply_db_credentials_to_prefs() ties it together and pushes the list back to pychron.dvc.connection.favorites via the standard apptools preferences adapter. - pychron/cloud/workstation_setup.py — WorkstationSetup gets database_url, database_role, and default_metadata_repo attributes; from_device_code() populates them from the poll- success body. None on either field is a legitimate state — the workstation runs HTTP-only when the bridge has no staged credential. - pychron/cloud/tasks/preferences.py: * New _registration_status field (Registered / Partial / Unregistered) bound to a CustomLabel so the technician sees the workstation's onboarding state at a glance on pane open. * The Start-device-code-enrollment button now refuses to proceed when the workstation is already onboarded (registration.json + keyring token both present) without an explicit confirmation dialog — protects against an accidental tap rotating the SSH keypair and burning a fresh device-code slot. * After a successful enrollment, _persist_db_credentials_from_setup writes the staged Postgres credentials into the DVC connection favorites and the same whoami probe the manual Test Connection button uses runs automatically so the technician sees an immediate end-to-end pass / fail without an extra click. A parse failure is non-fatal — surfaced via the status badge so the rest of the enrollment isn't rolled back. * Re-onboard / revoke / switch-lab buttons all refresh the registration status indicator on completion. - test/cloud/test_dvc_credentials.py (NEW) — 28 unit tests covering parse_database_url, build_dvc_connection_csv, merge_dvc_connection_favorites, _row_set_field (caveman regression for the silent no-op bug), and apply_db_credentials_to_prefs end-to-end against a fake prefs adapter. - test/cloud/test_device_code_setup.py — 3 new tests: db credential fields propagate to the returned WorkstationSetup, missing fields leave attrs at None, and the plaintext password is stripped from DeviceCodePollSuccess.raw the same way api_token already is. Total: 171 cloud tests passing (was 140, +31 new). The companion server-side staging endpoint and admin CLI live in pychronAPI feat/workstation-db-credentials (PR #42). Co-Authored-By: Claude Opus 4.7 (1M context) --- pychron/cloud/api_client.py | 23 +- pychron/cloud/dvc_credentials.py | 327 +++++++++++++++++++++++++++ pychron/cloud/tasks/preferences.py | 221 ++++++++++++++++-- pychron/cloud/workstation_setup.py | 14 ++ test/cloud/test_device_code_setup.py | 94 ++++++++ test/cloud/test_dvc_credentials.py | 302 +++++++++++++++++++++++++ 6 files changed, 964 insertions(+), 17 deletions(-) create mode 100644 pychron/cloud/dvc_credentials.py create mode 100644 test/cloud/test_dvc_credentials.py diff --git a/pychron/cloud/api_client.py b/pychron/cloud/api_client.py index f6b51c0d4..3fb139f23 100644 --- a/pychron/cloud/api_client.py +++ b/pychron/cloud/api_client.py @@ -396,6 +396,14 @@ class DeviceCodePollSuccess(object): keyring before losing the reference. ``ssh_key`` is the same shape that :func:`register_ssh_key` returns so the orchestrator can reuse the existing persist/apply path. + + ``database_url`` carries a per-workstation Postgres connection URL + (``postgresql://role:password@host:port/dbname``) when the off- + cluster admin tool has staged a credential via the bridge's + bootstrap-only ``/internal/workstation-credentials`` endpoint. + ``None`` when no credential is pending — the workstation runs in + HTTP-only mode. Returned exactly once; the staging row is DELETED + on this read so the password is not recoverable later. """ __slots__ = ( @@ -405,6 +413,8 @@ class DeviceCodePollSuccess(object): "default_metadata_repo", "ssh_host_alias", "ssh_key", + "database_url", + "database_role", "raw", ) @@ -417,6 +427,8 @@ def __init__( ssh_host_alias, ssh_key, raw, + database_url=None, + database_role=None, ): self.api_token = api_token self.lab = lab @@ -424,6 +436,8 @@ def __init__( self.default_metadata_repo = default_metadata_repo self.ssh_host_alias = ssh_host_alias or {} self.ssh_key = ssh_key + self.database_url = database_url or None + self.database_role = database_role or None self.raw = raw @@ -554,8 +568,11 @@ def poll_device_code(base_url, device_code, timeout=DEFAULT_TIMEOUT): raw=ssh_key_payload, ) - # Strip the plaintext token from `raw` before exposing it. - safe_raw = {k: v for k, v in body.items() if k != "api_token"} + # Strip the plaintext token AND the database_url (which embeds the + # Postgres role's password) from `raw` before exposing it. Callers + # who serialize `raw` for debugging would otherwise leak both + # bearer secrets into logs/disk. + safe_raw = {k: v for k, v in body.items() if k not in ("api_token", "database_url")} return DeviceCodePollSuccess( api_token=body.get("api_token", ""), @@ -565,6 +582,8 @@ def poll_device_code(base_url, device_code, timeout=DEFAULT_TIMEOUT): ssh_host_alias=body.get("ssh_host_alias") or {}, ssh_key=ssh_key, raw=safe_raw, + database_url=body.get("database_url") or None, + database_role=body.get("database_role") or None, ) diff --git a/pychron/cloud/dvc_credentials.py b/pychron/cloud/dvc_credentials.py new file mode 100644 index 000000000..b576ba57a --- /dev/null +++ b/pychron/cloud/dvc_credentials.py @@ -0,0 +1,327 @@ +# =============================================================================== +# Copyright 2026 Jake Ross +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =============================================================================== +"""DVC connection-prefs persistence for device-flow enrollment. + +After a successful device-code poll the workstation receives a +``database_url`` of the form +``postgresql://role:password@host:port/dbname``. This module parses +that URL and writes the result as a ``DVCConnectionItem`` favorite to +the ``pychron.dvc.connection`` Envisage preference node so the next +DVC startup picks up the new credentials without any manual paste. + +Kept as pure functions on the Envisage prefs adapter so the unit tests +can exercise the CSV / favorites round-trip without spinning up the +full Traits / Envisage stack. +""" + +from __future__ import absolute_import + +import logging +from urllib.parse import unquote, urlparse + +logger = logging.getLogger(__name__) + + +# Order MUST match DVCConnectionItem.attributes in +# pychron/dvc/tasks/dvc_preferences.py. CSV is positional — we cannot +# pass kwargs to the on-disk format. +_DVC_CONNECTION_ATTRS = ( + "name", + "kind", + "username", + "host", + "dbname", + "password", + "enabled", + "default", + "path", + "organization", + "meta_repo_name", + "meta_repo_dir", + "timeout", + "repository_root", + "connection_method", + "cloudsql_instance_connection_name", + "cloudsql_ip_type", + "cloudsql_service_account_email", + "cloudsql_service_account_key_path", +) + +# Sentinel used to mark favorites added by the device-flow path so a +# re-enrollment for the same lab REPLACES rather than stacking entries. +CLOUD_FAV_PREFIX = "cloud-" + + +class DatabaseUrlParseError(ValueError): + """Raised when ``database_url`` cannot be parsed into the fields a + ``DVCConnectionItem`` needs.""" + + +def parse_database_url(url): + """Parse a ``postgresql://`` URL into the components a + ``DVCConnectionItem`` needs. + + Returns a dict with keys ``host``, ``port`` (int or None), + ``username``, ``password``, ``dbname``. Percent-encoded userinfo + components (per RFC 3986) are decoded so the workstation gets the + raw password the server-side admin tool actually generated. + + Raises :class:`DatabaseUrlParseError` on malformed input — the + caller is expected to fall back to leaving prefs unchanged so the + technician is not silently locked out. + """ + if not url: + raise DatabaseUrlParseError("empty url") + parts = urlparse(url) + if parts.scheme not in ("postgresql", "postgres"): + raise DatabaseUrlParseError("expected postgresql:// scheme, got {!r}".format(parts.scheme)) + if not parts.hostname: + raise DatabaseUrlParseError("url is missing host") + dbname = parts.path.lstrip("/") if parts.path else "" + if not dbname: + raise DatabaseUrlParseError("url is missing database name") + return { + "host": parts.hostname, + "port": parts.port, + "username": unquote(parts.username) if parts.username else "", + "password": unquote(parts.password) if parts.password else "", + "dbname": dbname, + } + + +def _row_to_csv(values): + """Join the favorite's positional fields with commas. Mirrors + :func:`pychron.core.helpers.strtools.to_csv_str` so the resulting + CSV round-trips through ``DVCConnectionItem.__init__(attrs=...)``. + + A password that contains a literal comma would corrupt the CSV. + The pychronAPI admin CLI uses a ``[a-zA-Z0-9]`` alphabet so this + cannot happen for credentials minted via the device flow, but we + raise loudly if it ever does so the caller knows the favorite is + unsafe to write. + """ + out = [] + for attr, value in zip(_DVC_CONNECTION_ATTRS, values): + s = "" if value is None else str(value) + if "," in s: + raise DatabaseUrlParseError( + "{} contains a literal comma which would corrupt the " + "CSV-encoded favorites preference".format(attr) + ) + out.append(s) + return ",".join(out) + + +def build_dvc_connection_csv( + parsed, + name, + organization="", + meta_repo_name="", + meta_repo_dir="", + repository_root="", +): + """Serialize a parsed ``database_url`` as a positional-CSV row that + ``DVCConnectionItem(attrs=)`` can re-hydrate. + + Marked ``enabled=True`` and ``default=True`` so the next DVC + startup picks the new entry without further user action. + """ + if "host" not in parsed: + raise DatabaseUrlParseError("parsed url is missing host") + # DVCConnectionItem has no separate port attribute — the SQLAlchemy + # URL builder in pychron/database/core/database_adapter.py:557 + # interpolates ``host`` directly into the connection string, so + # encode port as ``host:port``. Skipping the port silently demotes + # everything to the dialect default (5432 for postgresql) which + # corrupts connections to non-default Cloud SQL ports. + host = parsed.get("host", "") + port = parsed.get("port") + if host and port: + host = "{}:{}".format(host, port) + values = [ + name, # name + "postgresql", # kind + parsed.get("username", ""), # username + host, # host (host[:port]) + parsed.get("dbname", ""), # dbname + parsed.get("password", ""), # password + "True", # enabled + "True", # default + "", # path (sqlite-only) + organization, # organization + meta_repo_name, # meta_repo_name + meta_repo_dir, # meta_repo_dir + "5", # timeout + repository_root, # repository_root + "direct", # connection_method + "", # cloudsql_instance_connection_name + "public", # cloudsql_ip_type + "", # cloudsql_service_account_email + "", # cloudsql_service_account_key_path + ] + return _row_to_csv(values) + + +def _favorite_name(row): + """First field of a favorites CSV is the user-visible name. Used to + de-duplicate when re-enrolling the same lab.""" + if not row: + return "" + return row.split(",", 1)[0] + + +def merge_dvc_connection_favorites(existing, new_row, replace_name): + """Return the new favorites list with ``replace_name`` (if any + matching row exists) replaced by ``new_row``, or with ``new_row`` + appended otherwise. Existing rows whose name matches but is not the + replacement target are left alone — the user may have set up other + connections by hand. + + Also strips the ``default=True`` flag from any other row that had + it, since CSV position 8 (zero-indexed 7) is ``default``. We only + want one default favorite at a time. + """ + out = [] + replaced = False + new_default = _row_field(new_row, 7) == "True" + for row in existing or []: + name = _favorite_name(row) + if name == replace_name: + out.append(new_row) + replaced = True + continue + if new_default: + row = _row_set_field(row, 7, "False") + out.append(row) + if not replaced: + out.append(new_row) + return out + + +def _row_field(row, idx): + parts = row.split(",") + if idx < len(parts): + return parts[idx] + return "" + + +def _row_set_field(row, idx, value): + """Set the ``idx``-th comma-separated field of ``row`` to ``value``, + extending the row with empty fields if it is shorter than ``idx``. + + Older saved favorites may have been written before + ``DVCConnectionItem.attributes`` grew its current set of fields, + so a short row is the common case rather than an exception. + Silently dropping the update would leave a stale ``default=True`` + on a prior favorite when re-enrolling, demoting the new + cloud-minted credential to non-default and breaking the + no-manual-paste contract. + """ + parts = row.split(",") + while len(parts) <= idx: + parts.append("") + parts[idx] = value + return ",".join(parts) + + +def apply_db_credentials_to_prefs( + preferences, + database_url, + database_role=None, + lab_name="", + organization="", + meta_repo_name="", + meta_repo_dir="", + repository_root="", +): + """Write a parsed ``database_url`` into the + ``pychron.dvc.connection.favorites`` pref node as a new (or + replacing) ``DVCConnectionItem`` favorite. + + ``preferences`` is the Envisage preferences adapter -- anything + with ``get(key, default=None)``, ``set(key, value)``, and + ``flush()``. Tests pass a fake. + + Returns the canonical favorite name on success, or ``None`` when + there is no credential to apply (``database_url`` is falsy). + """ + if not database_url: + return None + parsed = parse_database_url(database_url) + name = _favorite_name_for_lab(lab_name) + new_row = build_dvc_connection_csv( + parsed, + name=name, + organization=organization, + meta_repo_name=meta_repo_name, + meta_repo_dir=meta_repo_dir, + repository_root=repository_root, + ) + + raw = preferences.get("pychron.dvc.connection.favorites", "") or "" + existing = _split_favorites(raw) + merged = merge_dvc_connection_favorites(existing, new_row, replace_name=name) + preferences.set("pychron.dvc.connection.favorites", _join_favorites(merged)) + preferences.flush() + logger.info( + "applied DVC connection favorite name=%s host=%s db=%s role=%s", + name, + parsed.get("host", ""), + parsed.get("dbname", ""), + database_role or parsed.get("username", ""), + ) + return name + + +def _favorite_name_for_lab(lab_name): + safe = "".join(c for c in (lab_name or "default") if c.isalnum() or c in "-_") + return "{}{}".format(CLOUD_FAV_PREFIX, safe or "default") + + +def _split_favorites(raw): + """Envisage stores a List trait as a Python-repr-ish string. The + ``FavoritesPreferencesHelper`` round-trips through + ``self.favorites = [...]`` which Envisage serializes / deserializes + on its own. When we read raw via ``preferences.get(...)`` we may + see either a list (already deserialized) or a string we must + parse. + """ + if raw is None: + return [] + if isinstance(raw, list): + return [str(item) for item in raw] + if isinstance(raw, str): + s = raw.strip() + if not s: + return [] + # Envisage's PreferencesHelper writes List traits as + # repr-like strings — try literal_eval first. + try: + import ast + + parsed = ast.literal_eval(s) + except (ValueError, SyntaxError): + return [s] + if isinstance(parsed, list): + return [str(item) for item in parsed] + return [s] + return [str(raw)] + + +def _join_favorites(items): + """Inverse of :func:`_split_favorites`. Envisage will store this + string into the preferences node and re-deserialize it on read.""" + return repr(list(items)) diff --git a/pychron/cloud/tasks/preferences.py b/pychron/cloud/tasks/preferences.py index 5c67e2d09..2fe06b3a7 100644 --- a/pychron/cloud/tasks/preferences.py +++ b/pychron/cloud/tasks/preferences.py @@ -30,11 +30,40 @@ import logging +import os +from urllib.parse import urlparse, urlunparse + from envisage.ui.tasks.preferences_pane import PreferencesPane from pyface.api import GUI +from pyface.image_resource import ImageResource +from pyface.ui_traits import Image from traits.api import Bool, Button, File, Password, Str from traitsui.api import Color, Group, HGroup, ImageEditor, Item, VGroup, View + +_BLANK_QR_IMAGE = ImageResource("blank") + + +def _swap_origin(url, new_origin): + """Replace scheme+netloc of ``url`` with that of ``new_origin``. + + Server-supplied ``verification_url`` may point at a misconfigured + host (e.g. ``api.example.com``); the workstation operator can + override the public-facing host without redeploying the API. + Returns ``url`` unchanged if either side fails to parse. + """ + if not (url and new_origin): + return url + try: + p = urlparse(url) + np = urlparse(new_origin) + except ValueError: + return url + if not np.netloc: + return url + return urlunparse((np.scheme or p.scheme, np.netloc, p.path, p.params, p.query, p.fragment)) + + from pychron.cloud.api_client import ( CloudAPIError, CloudAuthError, @@ -43,6 +72,10 @@ CloudNetworkError, whoami, ) +from pychron.cloud.dvc_credentials import ( + DatabaseUrlParseError, + apply_db_credentials_to_prefs, +) from pychron.cloud.keyring_store import ( delete_token, get_token, @@ -54,6 +87,7 @@ KeyringWriteFailedError, WorkstationSetup, WorkstationSetupError, + load_registration, switch_lab as wipe_for_switch_lab, ) from pychron.core.confirmation import confirmation_dialog @@ -76,6 +110,7 @@ class CloudPreferences(BasePreferencesHelper): api_base_url = Str lab_name = Str api_token = Password + verification_url_override = Str test_connection = Button reonboard_button = Button("Re-onboard workstation") @@ -96,6 +131,7 @@ class CloudPreferences(BasePreferencesHelper): # user_code by hand. Empty string until the server returns the # `verification_url_complete` payload. _pending_qr_path = File + _pending_qr_image = Image(_BLANK_QR_IMAGE) _pending_active = Bool(False) _should_cancel_enrollment = Bool(False) @@ -108,12 +144,43 @@ class CloudPreferences(BasePreferencesHelper): _remote_status = Str _remote_status_color = Color + # Surfaces "Registered" / "Unregistered" on pane open so the + # technician can tell at a glance whether the workstation is + # already onboarded. Derived from the on-disk + # ``~/.pychron/registration.json`` + keyring token. + _registration_status = Str + _registration_status_color = Color + def _remote_status_color_default(self): return normalize_color_name("red") + def _registration_status_color_default(self): + return normalize_color_name("red") + def _initialize(self, *args, **kw): super(CloudPreferences, self)._initialize(*args, **kw) self._load_token_from_keyring() + self._refresh_registration_status() + + def _refresh_registration_status(self): + """Update the Registered/Unregistered indicator based on local state. + + Considered "Registered" iff a registration.json exists AND the + keyring carries a token for the configured lab. Either alone is + a half-onboarded state (e.g. keyring wipe + stale json) — call + that "Partial" so the technician knows to re-onboard. + """ + reg = load_registration() + token = get_token(self.lab_name) if self.lab_name else "" + if reg and token: + self._registration_status = "Registered" + self._registration_status_color = normalize_color_name("green") + elif reg or token: + self._registration_status = "Partial — re-onboard recommended" + self._registration_status_color = normalize_color_name("orange") + else: + self._registration_status = "Unregistered" + self._registration_status_color = normalize_color_name("red") def _is_preference_trait(self, trait_name): # api_token must never be written to the .cfg — it lives in the OS @@ -123,6 +190,8 @@ def _is_preference_trait(self, trait_name): "api_token", "_remote_status", "_remote_status_color", + "_registration_status", + "_registration_status_color", "test_connection", "reonboard_button", "revoke_button", @@ -132,6 +201,7 @@ def _is_preference_trait(self, trait_name): "_pending_user_code", "_pending_verification_url", "_pending_qr_path", + "_pending_qr_image", "_pending_active", "_should_cancel_enrollment", "_recovery_token", @@ -150,6 +220,7 @@ def _lab_name_changed(self, old, new): # there so the user sees the right token without re-entering it. if old != new: self._load_token_from_keyring() + self._refresh_registration_status() def _api_token_changed(self, old, new): if not self.lab_name: @@ -211,6 +282,24 @@ def _enroll_via_device_code_button_fired(self): self._remote_status = "Set API Base URL first" return + # Re-registration guardrail: a workstation that already has a + # registration.json + keyring token is functional; an admin + # tap on the button could otherwise silently rotate the SSH + # key and burn a fresh device-code slot. Require explicit + # confirmation before continuing. + existing_reg = load_registration() + existing_token = get_token(self.lab_name) if self.lab_name else "" + if existing_reg and existing_token: + if not confirmation_dialog( + "This workstation is already registered with Pychron Cloud " + "as lab '{}'. Re-enrolling will rotate the SSH keypair and " + "mint a new API token. Continue?".format(self.lab_name or "?"), + title="Re-register workstation", + ): + self._remote_status = "Already registered — cancelled" + self._remote_status_color = normalize_color_name("orange") + return + self._should_cancel_enrollment = False self._pending_user_code = "" self._pending_verification_url = "" @@ -232,24 +321,39 @@ def _enroll_via_device_code_button_fired(self): def _on_device_code_user_code( self, user_code, verification_url, verification_url_complete, expires_at ): - """Worker-thread callback: surface the user_code + URL + QR in the pane. - - Trait writes from non-UI threads are dispatched to the UI thread - by the Pyface event loop, so the operator sees the code as soon - as the server returns it. QR generation runs on this thread - (small file, ~hundreds of microseconds for a typical URL); a - failure is non-fatal — the typed code + URL still work. - """ - self._pending_user_code = user_code - self._pending_verification_url = verification_url + if self.verification_url_override: + verification_url = _swap_origin(verification_url, self.verification_url_override) + verification_url_complete = _swap_origin( + verification_url_complete, self.verification_url_override + ) try: - self._pending_qr_path = make_qr_for_device_code( + qr_path = make_qr_for_device_code( verification_url_complete, host_slug=self.lab_name or "default" ) except Exception as exc: logger.warning("device-code QR generation failed: %s", exc) - self._pending_qr_path = "" - self._remote_status = "Show {} to admin at {}".format(user_code, verification_url) + qr_path = "" + if qr_path: + d, n = os.path.split(qr_path) + qr_image = ImageResource(name=n, search_path=[d]) + else: + qr_image = _BLANK_QR_IMAGE + status = "Show {} to admin at {}".format(user_code, verification_url) + GUI.invoke_later( + self._apply_pending_user_code, + user_code, + verification_url, + qr_path, + qr_image, + status, + ) + + def _apply_pending_user_code(self, user_code, verification_url, qr_path, qr_image, status): + self._pending_user_code = user_code + self._pending_verification_url = verification_url + self._pending_qr_path = qr_path + self._pending_qr_image = qr_image + self._remote_status = status self._remote_status_color = normalize_color_name("orange") def _enrollment_worker(self): @@ -321,9 +425,71 @@ def _apply_enrollment_success(self, setup): self.api_base_url = setup.api_base_url self.lab_name = setup.lab_name self._load_token_from_keyring() - self._remote_status = "Enrolled as {}".format(setup.lab_name) + # Persist Postgres credentials (if the bridge had one staged) + # to ``pychron.dvc.connection.favorites`` so DVC startup picks + # them up on the next run with no manual paste. ``None`` is a + # legitimate state — the workstation runs HTTP-only. + db_applied = self._persist_db_credentials_from_setup(setup) + if db_applied: + self._remote_status = "Enrolled as {} (DB credentials applied)".format(setup.lab_name) + else: + self._remote_status = "Enrolled as {}".format(setup.lab_name) self._remote_status_color = normalize_color_name("green") + self._refresh_registration_status() self._reset_pending() + # Run the same whoami probe the manual "Test Connection" button + # uses, so the technician sees an immediate, end-to-end pass / + # fail without an extra click. Failures here do NOT roll back + # the enrollment — the credentials are already minted and + # persisted; surface the failure as a status badge instead. + self._test_connection_fired() + + def _persist_db_credentials_from_setup(self, setup): + """Apply :attr:`WorkstationSetup.database_url` to the DVC + connection prefs. Returns True when something was written. + + Errors are caught + logged + surfaced via remote_status so a + bad URL does not roll back the rest of enrollment (the cloud + prefs + ssh + keyring are already on disk by the time we get + here). + """ + if not getattr(setup, "database_url", None): + return False + # Pull MetaData metadata off the setup if available so the + # favorite carries an organization + meta_repo_name. Falls + # back to lab_name when the server omitted the block. + meta = getattr(setup, "default_metadata_repo", None) or {} + repo_id = meta.get("repository_identifier", "") if isinstance(meta, dict) else "" + organization = "" + meta_repo_name = "" + if "/" in repo_id: + organization, meta_repo_name = repo_id.split("/", 1) + elif repo_id: + meta_repo_name = repo_id + organization = organization or setup.lab_name or "" + try: + apply_db_credentials_to_prefs( + self.preferences, + database_url=setup.database_url, + database_role=setup.database_role, + lab_name=setup.lab_name, + organization=organization, + meta_repo_name=meta_repo_name, + ) + except DatabaseUrlParseError as exc: + logger.warning( + "device-code DB credential parse failed (skipping DVC prefs): %s", + exc, + ) + self._remote_status = "Enrolled — DB URL malformed, prefs unchanged" + self._remote_status_color = normalize_color_name("orange") + return False + except Exception as exc: # defensive — never abort enrollment + logger.warning("device-code DB credential persist failed: %s", exc) + self._remote_status = "Enrolled — DB prefs write failed" + self._remote_status_color = normalize_color_name("orange") + return False + return True def _apply_enrollment_terminal(self, message, color): self._remote_status = message @@ -349,6 +515,7 @@ def _reset_pending(self): self._pending_user_code = "" self._pending_verification_url = "" self._pending_qr_path = "" + self._pending_qr_image = _BLANK_QR_IMAGE self._pending_active = False self._should_cancel_enrollment = False @@ -383,6 +550,7 @@ def _reonboard_button_fired(self): return self._remote_status = "Re-onboarded" self._remote_status_color = normalize_color_name("green") + self._refresh_registration_status() def _revoke_button_fired(self): self._remote_status_color = normalize_color_name("red") @@ -409,6 +577,7 @@ def _revoke_button_fired(self): if self.lab_name: delete_token(self.lab_name) self.trait_setq(api_token="") + self._refresh_registration_status() def _switch_lab_button_fired(self): self._remote_status_color = normalize_color_name("red") @@ -429,6 +598,7 @@ def _switch_lab_button_fired(self): self.trait_setq(api_token="", lab_name="", api_base_url="") self._remote_status = "Wiped; configure new lab above" self._remote_status_color = normalize_color_name("orange") + self._refresh_registration_status() class CloudPreferencesPane(PreferencesPane): @@ -465,6 +635,16 @@ def traits_view(self): resizable=True, label="API Token", ), + Item( + "verification_url_override", + tooltip="Optional. If the server returns a verification_url " + "with the wrong public host (e.g. api.example.com), set the " + "correct origin here (e.g. https://console.pychronlabs.com). " + "The scheme+host is swapped; the path and user_code query are " + "preserved.", + resizable=True, + label="Verification URL Override", + ), HGroup( test_connection_item(), CustomLabel( @@ -473,6 +653,15 @@ def traits_view(self): color_name="_remote_status_color", ), ), + HGroup( + CustomLabel( + "_registration_status", + width=240, + color_name="_registration_status_color", + ), + label="Status", + show_border=False, + ), show_border=True, label="Pychron Cloud (pychronAPI)", ) @@ -509,9 +698,11 @@ def traits_view(self): ), HGroup( Item( - "_pending_qr_path", + "_pending_qr_image", show_label=False, editor=ImageEditor(), + width=300, + height=300, tooltip="Scan with the admin's phone to open the " "verification page with the user_code pre-filled.", visible_when="_pending_qr_path != ''", diff --git a/pychron/cloud/workstation_setup.py b/pychron/cloud/workstation_setup.py index aa05a2541..a9fe6908b 100644 --- a/pychron/cloud/workstation_setup.py +++ b/pychron/cloud/workstation_setup.py @@ -140,6 +140,17 @@ def __init__(self, api_base_url, api_token, lab_name, host=None): self.api_token = api_token self.lab_name = lab_name self.host = host or host_slug() + # Populated by :meth:`from_device_code` when the bridge has a + # staged Postgres credential for this api_token. ``None`` means + # the workstation runs HTTP-only — DVC connection prefs are + # left untouched. + self.database_url = None + self.database_role = None + # Default-MetaData repo metadata so the prefs pane can write + # ``pychron.dvc.connection`` favorites with the right org + + # meta_repo_name without re-deriving them from + # ``registration.json``. + self.default_metadata_repo = None # -- device-code enrollment ---------------------------------------- @@ -248,6 +259,9 @@ def from_device_code( lab_name=success.lab, host=host, ) + setup.database_url = success.database_url + setup.database_role = success.database_role + setup.default_metadata_repo = success.default_metadata_repo setup._persist_registration(success.ssh_key) setup._apply_ssh_config(success.ssh_key) diff --git a/test/cloud/test_device_code_setup.py b/test/cloud/test_device_code_setup.py index 33777191f..fdd1dee9c 100644 --- a/test/cloud/test_device_code_setup.py +++ b/test/cloud/test_device_code_setup.py @@ -302,3 +302,97 @@ def test_empty_api_base_url_aborts_before_any_io(self): if __name__ == "__main__": unittest.main() + + +class FromDeviceCodeDbCredentialsTestCase(unittest.TestCase): + """The poll-success body now optionally carries a ``database_url`` + + ``database_role`` minted by the off-cluster admin tool. The + orchestrator must surface those onto the returned + ``WorkstationSetup`` so the prefs pane can persist them to the + DVC connection favorites — without leaking the URL into + ``DeviceCodePollSuccess.raw`` (which is exposed for debug logs). + """ + + URL = "https://api.example" + + def setUp(self): + self.tmp = tempfile.mkdtemp() + self._patcher = patch( + "pychron.cloud.paths.os.path.expanduser", + lambda p: p.replace("~", self.tmp), + ) + self._patcher.start() + self.addCleanup(self._patcher.stop) + + def _rmtree(): + import shutil + + shutil.rmtree(self.tmp, ignore_errors=True) + + self.addCleanup(_rmtree) + + def _poll_body_with_db(self): + body = _poll_body() + body["database_url"] = "postgresql://wkstn_x:Pa55@10.0.1.5:5432/nmgrl?sslmode=require" + body["database_role"] = "wkstn_x" + return body + + def test_db_credential_fields_propagate_to_setup(self): + with ( + patch.object(api_client.requests, "post") as post, + patch.object(workstation_setup, "keyring_set_token", return_value=True), + ): + post.side_effect = [ + _resp(201, _START_BODY), + _resp(200, self._poll_body_with_db()), + ] + setup = workstation_setup.WorkstationSetup.from_device_code( + self.URL, + on_user_code=lambda *a: None, + sleep=lambda s: None, + host="testhost", + ) + + self.assertEqual( + setup.database_url, + "postgresql://wkstn_x:Pa55@10.0.1.5:5432/nmgrl?sslmode=require", + ) + self.assertEqual(setup.database_role, "wkstn_x") + + def test_no_db_credential_leaves_setup_attrs_none(self): + """When the bridge does not stage a credential, the setup + carries ``None`` for both DB fields — the prefs pane uses + this to skip writing DVC favorites and leave the existing + connection list untouched.""" + with ( + patch.object(api_client.requests, "post") as post, + patch.object(workstation_setup, "keyring_set_token", return_value=True), + ): + post.side_effect = [ + _resp(201, _START_BODY), + _resp(200, _poll_body()), # no database_url + ] + setup = workstation_setup.WorkstationSetup.from_device_code( + self.URL, + on_user_code=lambda *a: None, + sleep=lambda s: None, + host="testhost", + ) + + self.assertIsNone(setup.database_url) + self.assertIsNone(setup.database_role) + + def test_database_url_stripped_from_raw_debug_field(self): + """The plaintext password embedded in ``database_url`` must + not survive into the ``raw`` dict that callers may log for + debugging — same defensive treatment we give ``api_token``.""" + with patch.object(api_client.requests, "post") as post: + post.side_effect = [ + _resp(200, self._poll_body_with_db()), + ] + success = api_client.poll_device_code(self.URL, "dvc_xyz") + + self.assertNotIn("database_url", success.raw) + self.assertNotIn("api_token", success.raw) + # But the typed attribute still carries it for the orchestrator. + self.assertIn("Pa55", success.database_url) diff --git a/test/cloud/test_dvc_credentials.py b/test/cloud/test_dvc_credentials.py new file mode 100644 index 000000000..14043b234 --- /dev/null +++ b/test/cloud/test_dvc_credentials.py @@ -0,0 +1,302 @@ +# =============================================================================== +# Copyright 2026 Jake Ross +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =============================================================================== +"""Tests for the device-flow → DVC connection-prefs persistence path.""" + +from __future__ import absolute_import + +import unittest + +from pychron.cloud.dvc_credentials import ( + DatabaseUrlParseError, + _favorite_name_for_lab, + _row_set_field, + apply_db_credentials_to_prefs, + build_dvc_connection_csv, + merge_dvc_connection_favorites, + parse_database_url, +) + + +class FakePreferences(object): + """Minimal apptools.preferences-like adapter for unit tests.""" + + def __init__(self, initial=None): + self._store = dict(initial or {}) + self._flushed = False + + def get(self, key, default=None): + return self._store.get(key, default) + + def set(self, key, value): + self._store[key] = value + + def flush(self): + self._flushed = True + + +class ParseDatabaseUrlTestCase(unittest.TestCase): + def test_basic_url_round_trip(self): + out = parse_database_url("postgresql://wkstn_x:secret@10.0.1.5:5432/nmgrl") + self.assertEqual(out["host"], "10.0.1.5") + self.assertEqual(out["port"], 5432) + self.assertEqual(out["username"], "wkstn_x") + self.assertEqual(out["password"], "secret") + self.assertEqual(out["dbname"], "nmgrl") + + def test_postgres_alias_scheme_accepted(self): + out = parse_database_url("postgres://user:pw@h/db") + self.assertEqual(out["dbname"], "db") + + def test_no_port_defaults_to_none(self): + out = parse_database_url("postgresql://user:pw@h/db") + self.assertIsNone(out["port"]) + + def test_percent_encoded_password_decoded(self): + out = parse_database_url("postgresql://user:p%40ss@h:5432/db") + self.assertEqual(out["password"], "p@ss") + + def test_empty_url_raises(self): + with self.assertRaises(DatabaseUrlParseError): + parse_database_url("") + + def test_non_postgres_scheme_raises(self): + with self.assertRaises(DatabaseUrlParseError): + parse_database_url("mysql://u:p@h/db") + + def test_missing_host_raises(self): + with self.assertRaises(DatabaseUrlParseError): + parse_database_url("postgresql:///db") + + def test_missing_dbname_raises(self): + with self.assertRaises(DatabaseUrlParseError): + parse_database_url("postgresql://u:p@h/") + + +class BuildDvcConnectionCsvTestCase(unittest.TestCase): + def test_csv_field_order_matches_attributes(self): + parsed = parse_database_url("postgresql://wkstn_x:secret@10.0.1.5:5432/nmgrl") + csv = build_dvc_connection_csv( + parsed, + name="cloud-nmgrl", + organization="nmgrl", + meta_repo_name="MetaData", + ) + parts = csv.split(",") + self.assertEqual(parts[0], "cloud-nmgrl") + self.assertEqual(parts[1], "postgresql") + self.assertEqual(parts[2], "wkstn_x") + self.assertEqual(parts[3], "10.0.1.5:5432") # host:port + self.assertEqual(parts[4], "nmgrl") + self.assertEqual(parts[5], "secret") + self.assertEqual(parts[6], "True") # enabled + self.assertEqual(parts[7], "True") # default + self.assertEqual(parts[9], "nmgrl") # organization + self.assertEqual(parts[10], "MetaData") # meta_repo_name + + def test_port_appended_to_host(self): + """Caveman finding: port was lost. Host field MUST carry + host:port so non-default-port Cloud SQL connections work.""" + parsed = parse_database_url("postgresql://u:p@db.lab.example.com:6543/nmgrl") + csv = build_dvc_connection_csv(parsed, name="cloud-nmgrl") + self.assertIn("db.lab.example.com:6543", csv) + + def test_no_port_omits_colon(self): + parsed = parse_database_url("postgresql://u:p@db.lab/nmgrl") + csv = build_dvc_connection_csv(parsed, name="cloud-nmgrl") + parts = csv.split(",") + self.assertEqual(parts[3], "db.lab") + self.assertNotIn(":", parts[3]) + + def test_password_with_comma_rejected(self): + """A literal comma in the password would corrupt the CSV + positional encoding. Caller must catch + abort rather than + write a poison favorite.""" + parsed = parse_database_url("postgresql://u:p@h/db") + parsed["password"] = "a,b" + with self.assertRaises(DatabaseUrlParseError): + build_dvc_connection_csv(parsed, name="cloud-x") + + +class MergeFavoritesTestCase(unittest.TestCase): + def test_appends_when_no_match(self): + existing = ["myhand,postgresql,me,h,db,p,True,False,,"] + new_row = "cloud-nmgrl,postgresql,wkstn,h,db,p,True,True,," + out = merge_dvc_connection_favorites(existing, new_row, "cloud-nmgrl") + self.assertEqual(len(out), 2) + self.assertEqual(out[1], new_row) + + def test_replaces_matching_row(self): + existing = [ + "myhand,postgresql,me,h,db,p,True,False,,", + "cloud-nmgrl,postgresql,old,h,db,old,True,True,,", + ] + new_row = "cloud-nmgrl,postgresql,new,h,db,new,True,True,," + out = merge_dvc_connection_favorites(existing, new_row, "cloud-nmgrl") + self.assertEqual(len(out), 2) + self.assertEqual(out[1], new_row) + + def test_clears_default_flag_on_other_rows(self): + """Adding a new default=True favorite must demote any other + row that previously held the default flag — only one default + at a time.""" + existing = [ + "myhand,postgresql,me,h,db,p,True,True,,", + ] + new_row = "cloud-nmgrl,postgresql,wkstn,h,db,p,True,True,," + out = merge_dvc_connection_favorites(existing, new_row, "cloud-nmgrl") + self.assertIn("True,False", out[0]) # default field flipped + self.assertIn("cloud-nmgrl", out[1]) + + def test_short_row_extended_when_clearing_default(self): + """Caveman finding: _row_set_field used to silently no-op + when idx >= len(parts), leaving stale default=True flags on + legacy short-format rows.""" + existing = ["legacy,postgresql,me,h"] # 4 fields only + new_row = "cloud-x,postgresql,wkstn,h,db,p,True,True,," + out = merge_dvc_connection_favorites(existing, new_row, "cloud-x") + # The old short row was extended so position 7 now holds the + # new default=False marker rather than being missing. + legacy_parts = out[0].split(",") + self.assertGreater(len(legacy_parts), 7) + self.assertEqual(legacy_parts[7], "False") + + def test_no_default_change_when_new_row_isnt_default(self): + """Edge case: if the new row isn't default, leave existing + default flags alone.""" + existing = ["myhand,postgresql,me,h,db,p,True,True,,"] + new_row = "cloud-x,postgresql,wkstn,h,db,p,True,False,," + out = merge_dvc_connection_favorites(existing, new_row, "cloud-x") + self.assertIn("True,True", out[0]) # original default preserved + + +class RowSetFieldTestCase(unittest.TestCase): + def test_extends_short_row(self): + """Regression: `_row_set_field` used to drop updates whose + index exceeded the row length.""" + out = _row_set_field("a,b", 5, "X") + parts = out.split(",") + self.assertEqual(len(parts), 6) + self.assertEqual(parts[5], "X") + # In-range positions preserved. + self.assertEqual(parts[0], "a") + self.assertEqual(parts[1], "b") + + def test_in_range_update(self): + out = _row_set_field("a,b,c,d", 2, "X") + self.assertEqual(out, "a,b,X,d") + + +class FavoriteNameForLabTestCase(unittest.TestCase): + def test_safe_lab(self): + self.assertEqual(_favorite_name_for_lab("nmgrl"), "cloud-nmgrl") + + def test_strips_unsafe_chars(self): + self.assertEqual(_favorite_name_for_lab("nm grl/!"), "cloud-nmgrl") + + def test_empty_lab_falls_back(self): + self.assertEqual(_favorite_name_for_lab(""), "cloud-default") + + def test_none_lab_falls_back(self): + self.assertEqual(_favorite_name_for_lab(None), "cloud-default") + + +class ApplyDbCredentialsToPrefsTestCase(unittest.TestCase): + def test_writes_favorite_into_empty_prefs(self): + prefs = FakePreferences() + name = apply_db_credentials_to_prefs( + prefs, + database_url="postgresql://wkstn_x:secret@10.0.1.5:5432/nmgrl", + lab_name="nmgrl", + organization="nmgrl", + meta_repo_name="MetaData", + ) + self.assertEqual(name, "cloud-nmgrl") + favs = prefs.get("pychron.dvc.connection.favorites") + self.assertIsNotNone(favs) + # Round-trip through repr — that's what _join_favorites uses. + import ast + + items = ast.literal_eval(favs) + self.assertEqual(len(items), 1) + self.assertIn("cloud-nmgrl", items[0]) + self.assertIn("10.0.1.5:5432", items[0]) + self.assertIn("secret", items[0]) + self.assertTrue(prefs._flushed) + + def test_replaces_prior_cloud_favorite(self): + """Re-enrolling the same lab must REPLACE the prior cloud-* + favorite, not stack them.""" + prefs = FakePreferences() + apply_db_credentials_to_prefs( + prefs, + database_url="postgresql://wkstn_old:old@h:5432/nmgrl", + lab_name="nmgrl", + ) + apply_db_credentials_to_prefs( + prefs, + database_url="postgresql://wkstn_new:new@h:5432/nmgrl", + lab_name="nmgrl", + ) + import ast + + items = ast.literal_eval(prefs.get("pychron.dvc.connection.favorites")) + cloud_items = [r for r in items if r.startswith("cloud-nmgrl,")] + self.assertEqual(len(cloud_items), 1) + self.assertIn("new", cloud_items[0]) + self.assertNotIn("old", cloud_items[0]) + + def test_preserves_user_defined_favorites(self): + legacy = "myhand,postgresql,me,otherhost,otherdb,mypw,True,True,," + prefs = FakePreferences( + initial={ + "pychron.dvc.connection.favorites": repr([legacy]), + } + ) + apply_db_credentials_to_prefs( + prefs, + database_url="postgresql://wkstn_x:secret@10.0.1.5:5432/nmgrl", + lab_name="nmgrl", + ) + import ast + + items = ast.literal_eval(prefs.get("pychron.dvc.connection.favorites")) + # legacy row preserved (not deleted) + self.assertTrue(any(r.startswith("myhand,") for r in items)) + # new row appended + self.assertTrue(any(r.startswith("cloud-nmgrl,") for r in items)) + # legacy default flag flipped to False (only one default) + legacy_row = [r for r in items if r.startswith("myhand,")][0] + self.assertEqual(legacy_row.split(",")[7], "False") + + def test_none_url_returns_none_without_writing(self): + prefs = FakePreferences() + out = apply_db_credentials_to_prefs(prefs, database_url=None, lab_name="nmgrl") + self.assertIsNone(out) + self.assertFalse(prefs._flushed) + self.assertIsNone(prefs.get("pychron.dvc.connection.favorites")) + + def test_malformed_url_raises(self): + prefs = FakePreferences() + with self.assertRaises(DatabaseUrlParseError): + apply_db_credentials_to_prefs( + prefs, + database_url="not-a-url", + lab_name="nmgrl", + ) + + +if __name__ == "__main__": # pragma: no cover + unittest.main()