diff --git a/pychron/cloud/api_client.py b/pychron/cloud/api_client.py index f6b51c0d4..3fb139f23 100644 --- a/pychron/cloud/api_client.py +++ b/pychron/cloud/api_client.py @@ -396,6 +396,14 @@ class DeviceCodePollSuccess(object): keyring before losing the reference. ``ssh_key`` is the same shape that :func:`register_ssh_key` returns so the orchestrator can reuse the existing persist/apply path. + + ``database_url`` carries a per-workstation Postgres connection URL + (``postgresql://role:password@host:port/dbname``) when the off- + cluster admin tool has staged a credential via the bridge's + bootstrap-only ``/internal/workstation-credentials`` endpoint. + ``None`` when no credential is pending — the workstation runs in + HTTP-only mode. Returned exactly once; the staging row is DELETED + on this read so the password is not recoverable later. """ __slots__ = ( @@ -405,6 +413,8 @@ class DeviceCodePollSuccess(object): "default_metadata_repo", "ssh_host_alias", "ssh_key", + "database_url", + "database_role", "raw", ) @@ -417,6 +427,8 @@ def __init__( ssh_host_alias, ssh_key, raw, + database_url=None, + database_role=None, ): self.api_token = api_token self.lab = lab @@ -424,6 +436,8 @@ def __init__( self.default_metadata_repo = default_metadata_repo self.ssh_host_alias = ssh_host_alias or {} self.ssh_key = ssh_key + self.database_url = database_url or None + self.database_role = database_role or None self.raw = raw @@ -554,8 +568,11 @@ def poll_device_code(base_url, device_code, timeout=DEFAULT_TIMEOUT): raw=ssh_key_payload, ) - # Strip the plaintext token from `raw` before exposing it. - safe_raw = {k: v for k, v in body.items() if k != "api_token"} + # Strip the plaintext token AND the database_url (which embeds the + # Postgres role's password) from `raw` before exposing it. Callers + # who serialize `raw` for debugging would otherwise leak both + # bearer secrets into logs/disk. + safe_raw = {k: v for k, v in body.items() if k not in ("api_token", "database_url")} return DeviceCodePollSuccess( api_token=body.get("api_token", ""), @@ -565,6 +582,8 @@ def poll_device_code(base_url, device_code, timeout=DEFAULT_TIMEOUT): ssh_host_alias=body.get("ssh_host_alias") or {}, ssh_key=ssh_key, raw=safe_raw, + database_url=body.get("database_url") or None, + database_role=body.get("database_role") or None, ) diff --git a/pychron/cloud/dvc_credentials.py b/pychron/cloud/dvc_credentials.py new file mode 100644 index 000000000..b576ba57a --- /dev/null +++ b/pychron/cloud/dvc_credentials.py @@ -0,0 +1,327 @@ +# =============================================================================== +# Copyright 2026 Jake Ross +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =============================================================================== +"""DVC connection-prefs persistence for device-flow enrollment. + +After a successful device-code poll the workstation receives a +``database_url`` of the form +``postgresql://role:password@host:port/dbname``. This module parses +that URL and writes the result as a ``DVCConnectionItem`` favorite to +the ``pychron.dvc.connection`` Envisage preference node so the next +DVC startup picks up the new credentials without any manual paste. + +Kept as pure functions on the Envisage prefs adapter so the unit tests +can exercise the CSV / favorites round-trip without spinning up the +full Traits / Envisage stack. +""" + +from __future__ import absolute_import + +import logging +from urllib.parse import unquote, urlparse + +logger = logging.getLogger(__name__) + + +# Order MUST match DVCConnectionItem.attributes in +# pychron/dvc/tasks/dvc_preferences.py. CSV is positional — we cannot +# pass kwargs to the on-disk format. +_DVC_CONNECTION_ATTRS = ( + "name", + "kind", + "username", + "host", + "dbname", + "password", + "enabled", + "default", + "path", + "organization", + "meta_repo_name", + "meta_repo_dir", + "timeout", + "repository_root", + "connection_method", + "cloudsql_instance_connection_name", + "cloudsql_ip_type", + "cloudsql_service_account_email", + "cloudsql_service_account_key_path", +) + +# Sentinel used to mark favorites added by the device-flow path so a +# re-enrollment for the same lab REPLACES rather than stacking entries. +CLOUD_FAV_PREFIX = "cloud-" + + +class DatabaseUrlParseError(ValueError): + """Raised when ``database_url`` cannot be parsed into the fields a + ``DVCConnectionItem`` needs.""" + + +def parse_database_url(url): + """Parse a ``postgresql://`` URL into the components a + ``DVCConnectionItem`` needs. + + Returns a dict with keys ``host``, ``port`` (int or None), + ``username``, ``password``, ``dbname``. Percent-encoded userinfo + components (per RFC 3986) are decoded so the workstation gets the + raw password the server-side admin tool actually generated. + + Raises :class:`DatabaseUrlParseError` on malformed input — the + caller is expected to fall back to leaving prefs unchanged so the + technician is not silently locked out. + """ + if not url: + raise DatabaseUrlParseError("empty url") + parts = urlparse(url) + if parts.scheme not in ("postgresql", "postgres"): + raise DatabaseUrlParseError("expected postgresql:// scheme, got {!r}".format(parts.scheme)) + if not parts.hostname: + raise DatabaseUrlParseError("url is missing host") + dbname = parts.path.lstrip("/") if parts.path else "" + if not dbname: + raise DatabaseUrlParseError("url is missing database name") + return { + "host": parts.hostname, + "port": parts.port, + "username": unquote(parts.username) if parts.username else "", + "password": unquote(parts.password) if parts.password else "", + "dbname": dbname, + } + + +def _row_to_csv(values): + """Join the favorite's positional fields with commas. Mirrors + :func:`pychron.core.helpers.strtools.to_csv_str` so the resulting + CSV round-trips through ``DVCConnectionItem.__init__(attrs=...)``. + + A password that contains a literal comma would corrupt the CSV. + The pychronAPI admin CLI uses a ``[a-zA-Z0-9]`` alphabet so this + cannot happen for credentials minted via the device flow, but we + raise loudly if it ever does so the caller knows the favorite is + unsafe to write. + """ + out = [] + for attr, value in zip(_DVC_CONNECTION_ATTRS, values): + s = "" if value is None else str(value) + if "," in s: + raise DatabaseUrlParseError( + "{} contains a literal comma which would corrupt the " + "CSV-encoded favorites preference".format(attr) + ) + out.append(s) + return ",".join(out) + + +def build_dvc_connection_csv( + parsed, + name, + organization="", + meta_repo_name="", + meta_repo_dir="", + repository_root="", +): + """Serialize a parsed ``database_url`` as a positional-CSV row that + ``DVCConnectionItem(attrs=)`` can re-hydrate. + + Marked ``enabled=True`` and ``default=True`` so the next DVC + startup picks the new entry without further user action. + """ + if "host" not in parsed: + raise DatabaseUrlParseError("parsed url is missing host") + # DVCConnectionItem has no separate port attribute — the SQLAlchemy + # URL builder in pychron/database/core/database_adapter.py:557 + # interpolates ``host`` directly into the connection string, so + # encode port as ``host:port``. Skipping the port silently demotes + # everything to the dialect default (5432 for postgresql) which + # corrupts connections to non-default Cloud SQL ports. + host = parsed.get("host", "") + port = parsed.get("port") + if host and port: + host = "{}:{}".format(host, port) + values = [ + name, # name + "postgresql", # kind + parsed.get("username", ""), # username + host, # host (host[:port]) + parsed.get("dbname", ""), # dbname + parsed.get("password", ""), # password + "True", # enabled + "True", # default + "", # path (sqlite-only) + organization, # organization + meta_repo_name, # meta_repo_name + meta_repo_dir, # meta_repo_dir + "5", # timeout + repository_root, # repository_root + "direct", # connection_method + "", # cloudsql_instance_connection_name + "public", # cloudsql_ip_type + "", # cloudsql_service_account_email + "", # cloudsql_service_account_key_path + ] + return _row_to_csv(values) + + +def _favorite_name(row): + """First field of a favorites CSV is the user-visible name. Used to + de-duplicate when re-enrolling the same lab.""" + if not row: + return "" + return row.split(",", 1)[0] + + +def merge_dvc_connection_favorites(existing, new_row, replace_name): + """Return the new favorites list with ``replace_name`` (if any + matching row exists) replaced by ``new_row``, or with ``new_row`` + appended otherwise. Existing rows whose name matches but is not the + replacement target are left alone — the user may have set up other + connections by hand. + + Also strips the ``default=True`` flag from any other row that had + it, since CSV position 8 (zero-indexed 7) is ``default``. We only + want one default favorite at a time. + """ + out = [] + replaced = False + new_default = _row_field(new_row, 7) == "True" + for row in existing or []: + name = _favorite_name(row) + if name == replace_name: + out.append(new_row) + replaced = True + continue + if new_default: + row = _row_set_field(row, 7, "False") + out.append(row) + if not replaced: + out.append(new_row) + return out + + +def _row_field(row, idx): + parts = row.split(",") + if idx < len(parts): + return parts[idx] + return "" + + +def _row_set_field(row, idx, value): + """Set the ``idx``-th comma-separated field of ``row`` to ``value``, + extending the row with empty fields if it is shorter than ``idx``. + + Older saved favorites may have been written before + ``DVCConnectionItem.attributes`` grew its current set of fields, + so a short row is the common case rather than an exception. + Silently dropping the update would leave a stale ``default=True`` + on a prior favorite when re-enrolling, demoting the new + cloud-minted credential to non-default and breaking the + no-manual-paste contract. + """ + parts = row.split(",") + while len(parts) <= idx: + parts.append("") + parts[idx] = value + return ",".join(parts) + + +def apply_db_credentials_to_prefs( + preferences, + database_url, + database_role=None, + lab_name="", + organization="", + meta_repo_name="", + meta_repo_dir="", + repository_root="", +): + """Write a parsed ``database_url`` into the + ``pychron.dvc.connection.favorites`` pref node as a new (or + replacing) ``DVCConnectionItem`` favorite. + + ``preferences`` is the Envisage preferences adapter -- anything + with ``get(key, default=None)``, ``set(key, value)``, and + ``flush()``. Tests pass a fake. + + Returns the canonical favorite name on success, or ``None`` when + there is no credential to apply (``database_url`` is falsy). + """ + if not database_url: + return None + parsed = parse_database_url(database_url) + name = _favorite_name_for_lab(lab_name) + new_row = build_dvc_connection_csv( + parsed, + name=name, + organization=organization, + meta_repo_name=meta_repo_name, + meta_repo_dir=meta_repo_dir, + repository_root=repository_root, + ) + + raw = preferences.get("pychron.dvc.connection.favorites", "") or "" + existing = _split_favorites(raw) + merged = merge_dvc_connection_favorites(existing, new_row, replace_name=name) + preferences.set("pychron.dvc.connection.favorites", _join_favorites(merged)) + preferences.flush() + logger.info( + "applied DVC connection favorite name=%s host=%s db=%s role=%s", + name, + parsed.get("host", ""), + parsed.get("dbname", ""), + database_role or parsed.get("username", ""), + ) + return name + + +def _favorite_name_for_lab(lab_name): + safe = "".join(c for c in (lab_name or "default") if c.isalnum() or c in "-_") + return "{}{}".format(CLOUD_FAV_PREFIX, safe or "default") + + +def _split_favorites(raw): + """Envisage stores a List trait as a Python-repr-ish string. The + ``FavoritesPreferencesHelper`` round-trips through + ``self.favorites = [...]`` which Envisage serializes / deserializes + on its own. When we read raw via ``preferences.get(...)`` we may + see either a list (already deserialized) or a string we must + parse. + """ + if raw is None: + return [] + if isinstance(raw, list): + return [str(item) for item in raw] + if isinstance(raw, str): + s = raw.strip() + if not s: + return [] + # Envisage's PreferencesHelper writes List traits as + # repr-like strings — try literal_eval first. + try: + import ast + + parsed = ast.literal_eval(s) + except (ValueError, SyntaxError): + return [s] + if isinstance(parsed, list): + return [str(item) for item in parsed] + return [s] + return [str(raw)] + + +def _join_favorites(items): + """Inverse of :func:`_split_favorites`. Envisage will store this + string into the preferences node and re-deserialize it on read.""" + return repr(list(items)) diff --git a/pychron/cloud/tasks/preferences.py b/pychron/cloud/tasks/preferences.py index 5c67e2d09..2fe06b3a7 100644 --- a/pychron/cloud/tasks/preferences.py +++ b/pychron/cloud/tasks/preferences.py @@ -30,11 +30,40 @@ import logging +import os +from urllib.parse import urlparse, urlunparse + from envisage.ui.tasks.preferences_pane import PreferencesPane from pyface.api import GUI +from pyface.image_resource import ImageResource +from pyface.ui_traits import Image from traits.api import Bool, Button, File, Password, Str from traitsui.api import Color, Group, HGroup, ImageEditor, Item, VGroup, View + +_BLANK_QR_IMAGE = ImageResource("blank") + + +def _swap_origin(url, new_origin): + """Replace scheme+netloc of ``url`` with that of ``new_origin``. + + Server-supplied ``verification_url`` may point at a misconfigured + host (e.g. ``api.example.com``); the workstation operator can + override the public-facing host without redeploying the API. + Returns ``url`` unchanged if either side fails to parse. + """ + if not (url and new_origin): + return url + try: + p = urlparse(url) + np = urlparse(new_origin) + except ValueError: + return url + if not np.netloc: + return url + return urlunparse((np.scheme or p.scheme, np.netloc, p.path, p.params, p.query, p.fragment)) + + from pychron.cloud.api_client import ( CloudAPIError, CloudAuthError, @@ -43,6 +72,10 @@ CloudNetworkError, whoami, ) +from pychron.cloud.dvc_credentials import ( + DatabaseUrlParseError, + apply_db_credentials_to_prefs, +) from pychron.cloud.keyring_store import ( delete_token, get_token, @@ -54,6 +87,7 @@ KeyringWriteFailedError, WorkstationSetup, WorkstationSetupError, + load_registration, switch_lab as wipe_for_switch_lab, ) from pychron.core.confirmation import confirmation_dialog @@ -76,6 +110,7 @@ class CloudPreferences(BasePreferencesHelper): api_base_url = Str lab_name = Str api_token = Password + verification_url_override = Str test_connection = Button reonboard_button = Button("Re-onboard workstation") @@ -96,6 +131,7 @@ class CloudPreferences(BasePreferencesHelper): # user_code by hand. Empty string until the server returns the # `verification_url_complete` payload. _pending_qr_path = File + _pending_qr_image = Image(_BLANK_QR_IMAGE) _pending_active = Bool(False) _should_cancel_enrollment = Bool(False) @@ -108,12 +144,43 @@ class CloudPreferences(BasePreferencesHelper): _remote_status = Str _remote_status_color = Color + # Surfaces "Registered" / "Unregistered" on pane open so the + # technician can tell at a glance whether the workstation is + # already onboarded. Derived from the on-disk + # ``~/.pychron/registration.json`` + keyring token. + _registration_status = Str + _registration_status_color = Color + def _remote_status_color_default(self): return normalize_color_name("red") + def _registration_status_color_default(self): + return normalize_color_name("red") + def _initialize(self, *args, **kw): super(CloudPreferences, self)._initialize(*args, **kw) self._load_token_from_keyring() + self._refresh_registration_status() + + def _refresh_registration_status(self): + """Update the Registered/Unregistered indicator based on local state. + + Considered "Registered" iff a registration.json exists AND the + keyring carries a token for the configured lab. Either alone is + a half-onboarded state (e.g. keyring wipe + stale json) — call + that "Partial" so the technician knows to re-onboard. + """ + reg = load_registration() + token = get_token(self.lab_name) if self.lab_name else "" + if reg and token: + self._registration_status = "Registered" + self._registration_status_color = normalize_color_name("green") + elif reg or token: + self._registration_status = "Partial — re-onboard recommended" + self._registration_status_color = normalize_color_name("orange") + else: + self._registration_status = "Unregistered" + self._registration_status_color = normalize_color_name("red") def _is_preference_trait(self, trait_name): # api_token must never be written to the .cfg — it lives in the OS @@ -123,6 +190,8 @@ def _is_preference_trait(self, trait_name): "api_token", "_remote_status", "_remote_status_color", + "_registration_status", + "_registration_status_color", "test_connection", "reonboard_button", "revoke_button", @@ -132,6 +201,7 @@ def _is_preference_trait(self, trait_name): "_pending_user_code", "_pending_verification_url", "_pending_qr_path", + "_pending_qr_image", "_pending_active", "_should_cancel_enrollment", "_recovery_token", @@ -150,6 +220,7 @@ def _lab_name_changed(self, old, new): # there so the user sees the right token without re-entering it. if old != new: self._load_token_from_keyring() + self._refresh_registration_status() def _api_token_changed(self, old, new): if not self.lab_name: @@ -211,6 +282,24 @@ def _enroll_via_device_code_button_fired(self): self._remote_status = "Set API Base URL first" return + # Re-registration guardrail: a workstation that already has a + # registration.json + keyring token is functional; an admin + # tap on the button could otherwise silently rotate the SSH + # key and burn a fresh device-code slot. Require explicit + # confirmation before continuing. + existing_reg = load_registration() + existing_token = get_token(self.lab_name) if self.lab_name else "" + if existing_reg and existing_token: + if not confirmation_dialog( + "This workstation is already registered with Pychron Cloud " + "as lab '{}'. Re-enrolling will rotate the SSH keypair and " + "mint a new API token. Continue?".format(self.lab_name or "?"), + title="Re-register workstation", + ): + self._remote_status = "Already registered — cancelled" + self._remote_status_color = normalize_color_name("orange") + return + self._should_cancel_enrollment = False self._pending_user_code = "" self._pending_verification_url = "" @@ -232,24 +321,39 @@ def _enroll_via_device_code_button_fired(self): def _on_device_code_user_code( self, user_code, verification_url, verification_url_complete, expires_at ): - """Worker-thread callback: surface the user_code + URL + QR in the pane. - - Trait writes from non-UI threads are dispatched to the UI thread - by the Pyface event loop, so the operator sees the code as soon - as the server returns it. QR generation runs on this thread - (small file, ~hundreds of microseconds for a typical URL); a - failure is non-fatal — the typed code + URL still work. - """ - self._pending_user_code = user_code - self._pending_verification_url = verification_url + if self.verification_url_override: + verification_url = _swap_origin(verification_url, self.verification_url_override) + verification_url_complete = _swap_origin( + verification_url_complete, self.verification_url_override + ) try: - self._pending_qr_path = make_qr_for_device_code( + qr_path = make_qr_for_device_code( verification_url_complete, host_slug=self.lab_name or "default" ) except Exception as exc: logger.warning("device-code QR generation failed: %s", exc) - self._pending_qr_path = "" - self._remote_status = "Show {} to admin at {}".format(user_code, verification_url) + qr_path = "" + if qr_path: + d, n = os.path.split(qr_path) + qr_image = ImageResource(name=n, search_path=[d]) + else: + qr_image = _BLANK_QR_IMAGE + status = "Show {} to admin at {}".format(user_code, verification_url) + GUI.invoke_later( + self._apply_pending_user_code, + user_code, + verification_url, + qr_path, + qr_image, + status, + ) + + def _apply_pending_user_code(self, user_code, verification_url, qr_path, qr_image, status): + self._pending_user_code = user_code + self._pending_verification_url = verification_url + self._pending_qr_path = qr_path + self._pending_qr_image = qr_image + self._remote_status = status self._remote_status_color = normalize_color_name("orange") def _enrollment_worker(self): @@ -321,9 +425,71 @@ def _apply_enrollment_success(self, setup): self.api_base_url = setup.api_base_url self.lab_name = setup.lab_name self._load_token_from_keyring() - self._remote_status = "Enrolled as {}".format(setup.lab_name) + # Persist Postgres credentials (if the bridge had one staged) + # to ``pychron.dvc.connection.favorites`` so DVC startup picks + # them up on the next run with no manual paste. ``None`` is a + # legitimate state — the workstation runs HTTP-only. + db_applied = self._persist_db_credentials_from_setup(setup) + if db_applied: + self._remote_status = "Enrolled as {} (DB credentials applied)".format(setup.lab_name) + else: + self._remote_status = "Enrolled as {}".format(setup.lab_name) self._remote_status_color = normalize_color_name("green") + self._refresh_registration_status() self._reset_pending() + # Run the same whoami probe the manual "Test Connection" button + # uses, so the technician sees an immediate, end-to-end pass / + # fail without an extra click. Failures here do NOT roll back + # the enrollment — the credentials are already minted and + # persisted; surface the failure as a status badge instead. + self._test_connection_fired() + + def _persist_db_credentials_from_setup(self, setup): + """Apply :attr:`WorkstationSetup.database_url` to the DVC + connection prefs. Returns True when something was written. + + Errors are caught + logged + surfaced via remote_status so a + bad URL does not roll back the rest of enrollment (the cloud + prefs + ssh + keyring are already on disk by the time we get + here). + """ + if not getattr(setup, "database_url", None): + return False + # Pull MetaData metadata off the setup if available so the + # favorite carries an organization + meta_repo_name. Falls + # back to lab_name when the server omitted the block. + meta = getattr(setup, "default_metadata_repo", None) or {} + repo_id = meta.get("repository_identifier", "") if isinstance(meta, dict) else "" + organization = "" + meta_repo_name = "" + if "/" in repo_id: + organization, meta_repo_name = repo_id.split("/", 1) + elif repo_id: + meta_repo_name = repo_id + organization = organization or setup.lab_name or "" + try: + apply_db_credentials_to_prefs( + self.preferences, + database_url=setup.database_url, + database_role=setup.database_role, + lab_name=setup.lab_name, + organization=organization, + meta_repo_name=meta_repo_name, + ) + except DatabaseUrlParseError as exc: + logger.warning( + "device-code DB credential parse failed (skipping DVC prefs): %s", + exc, + ) + self._remote_status = "Enrolled — DB URL malformed, prefs unchanged" + self._remote_status_color = normalize_color_name("orange") + return False + except Exception as exc: # defensive — never abort enrollment + logger.warning("device-code DB credential persist failed: %s", exc) + self._remote_status = "Enrolled — DB prefs write failed" + self._remote_status_color = normalize_color_name("orange") + return False + return True def _apply_enrollment_terminal(self, message, color): self._remote_status = message @@ -349,6 +515,7 @@ def _reset_pending(self): self._pending_user_code = "" self._pending_verification_url = "" self._pending_qr_path = "" + self._pending_qr_image = _BLANK_QR_IMAGE self._pending_active = False self._should_cancel_enrollment = False @@ -383,6 +550,7 @@ def _reonboard_button_fired(self): return self._remote_status = "Re-onboarded" self._remote_status_color = normalize_color_name("green") + self._refresh_registration_status() def _revoke_button_fired(self): self._remote_status_color = normalize_color_name("red") @@ -409,6 +577,7 @@ def _revoke_button_fired(self): if self.lab_name: delete_token(self.lab_name) self.trait_setq(api_token="") + self._refresh_registration_status() def _switch_lab_button_fired(self): self._remote_status_color = normalize_color_name("red") @@ -429,6 +598,7 @@ def _switch_lab_button_fired(self): self.trait_setq(api_token="", lab_name="", api_base_url="") self._remote_status = "Wiped; configure new lab above" self._remote_status_color = normalize_color_name("orange") + self._refresh_registration_status() class CloudPreferencesPane(PreferencesPane): @@ -465,6 +635,16 @@ def traits_view(self): resizable=True, label="API Token", ), + Item( + "verification_url_override", + tooltip="Optional. If the server returns a verification_url " + "with the wrong public host (e.g. api.example.com), set the " + "correct origin here (e.g. https://console.pychronlabs.com). " + "The scheme+host is swapped; the path and user_code query are " + "preserved.", + resizable=True, + label="Verification URL Override", + ), HGroup( test_connection_item(), CustomLabel( @@ -473,6 +653,15 @@ def traits_view(self): color_name="_remote_status_color", ), ), + HGroup( + CustomLabel( + "_registration_status", + width=240, + color_name="_registration_status_color", + ), + label="Status", + show_border=False, + ), show_border=True, label="Pychron Cloud (pychronAPI)", ) @@ -509,9 +698,11 @@ def traits_view(self): ), HGroup( Item( - "_pending_qr_path", + "_pending_qr_image", show_label=False, editor=ImageEditor(), + width=300, + height=300, tooltip="Scan with the admin's phone to open the " "verification page with the user_code pre-filled.", visible_when="_pending_qr_path != ''", diff --git a/pychron/cloud/workstation_setup.py b/pychron/cloud/workstation_setup.py index aa05a2541..a9fe6908b 100644 --- a/pychron/cloud/workstation_setup.py +++ b/pychron/cloud/workstation_setup.py @@ -140,6 +140,17 @@ def __init__(self, api_base_url, api_token, lab_name, host=None): self.api_token = api_token self.lab_name = lab_name self.host = host or host_slug() + # Populated by :meth:`from_device_code` when the bridge has a + # staged Postgres credential for this api_token. ``None`` means + # the workstation runs HTTP-only — DVC connection prefs are + # left untouched. + self.database_url = None + self.database_role = None + # Default-MetaData repo metadata so the prefs pane can write + # ``pychron.dvc.connection`` favorites with the right org + + # meta_repo_name without re-deriving them from + # ``registration.json``. + self.default_metadata_repo = None # -- device-code enrollment ---------------------------------------- @@ -248,6 +259,9 @@ def from_device_code( lab_name=success.lab, host=host, ) + setup.database_url = success.database_url + setup.database_role = success.database_role + setup.default_metadata_repo = success.default_metadata_repo setup._persist_registration(success.ssh_key) setup._apply_ssh_config(success.ssh_key) diff --git a/test/cloud/test_device_code_setup.py b/test/cloud/test_device_code_setup.py index 33777191f..fdd1dee9c 100644 --- a/test/cloud/test_device_code_setup.py +++ b/test/cloud/test_device_code_setup.py @@ -302,3 +302,97 @@ def test_empty_api_base_url_aborts_before_any_io(self): if __name__ == "__main__": unittest.main() + + +class FromDeviceCodeDbCredentialsTestCase(unittest.TestCase): + """The poll-success body now optionally carries a ``database_url`` + + ``database_role`` minted by the off-cluster admin tool. The + orchestrator must surface those onto the returned + ``WorkstationSetup`` so the prefs pane can persist them to the + DVC connection favorites — without leaking the URL into + ``DeviceCodePollSuccess.raw`` (which is exposed for debug logs). + """ + + URL = "https://api.example" + + def setUp(self): + self.tmp = tempfile.mkdtemp() + self._patcher = patch( + "pychron.cloud.paths.os.path.expanduser", + lambda p: p.replace("~", self.tmp), + ) + self._patcher.start() + self.addCleanup(self._patcher.stop) + + def _rmtree(): + import shutil + + shutil.rmtree(self.tmp, ignore_errors=True) + + self.addCleanup(_rmtree) + + def _poll_body_with_db(self): + body = _poll_body() + body["database_url"] = "postgresql://wkstn_x:Pa55@10.0.1.5:5432/nmgrl?sslmode=require" + body["database_role"] = "wkstn_x" + return body + + def test_db_credential_fields_propagate_to_setup(self): + with ( + patch.object(api_client.requests, "post") as post, + patch.object(workstation_setup, "keyring_set_token", return_value=True), + ): + post.side_effect = [ + _resp(201, _START_BODY), + _resp(200, self._poll_body_with_db()), + ] + setup = workstation_setup.WorkstationSetup.from_device_code( + self.URL, + on_user_code=lambda *a: None, + sleep=lambda s: None, + host="testhost", + ) + + self.assertEqual( + setup.database_url, + "postgresql://wkstn_x:Pa55@10.0.1.5:5432/nmgrl?sslmode=require", + ) + self.assertEqual(setup.database_role, "wkstn_x") + + def test_no_db_credential_leaves_setup_attrs_none(self): + """When the bridge does not stage a credential, the setup + carries ``None`` for both DB fields — the prefs pane uses + this to skip writing DVC favorites and leave the existing + connection list untouched.""" + with ( + patch.object(api_client.requests, "post") as post, + patch.object(workstation_setup, "keyring_set_token", return_value=True), + ): + post.side_effect = [ + _resp(201, _START_BODY), + _resp(200, _poll_body()), # no database_url + ] + setup = workstation_setup.WorkstationSetup.from_device_code( + self.URL, + on_user_code=lambda *a: None, + sleep=lambda s: None, + host="testhost", + ) + + self.assertIsNone(setup.database_url) + self.assertIsNone(setup.database_role) + + def test_database_url_stripped_from_raw_debug_field(self): + """The plaintext password embedded in ``database_url`` must + not survive into the ``raw`` dict that callers may log for + debugging — same defensive treatment we give ``api_token``.""" + with patch.object(api_client.requests, "post") as post: + post.side_effect = [ + _resp(200, self._poll_body_with_db()), + ] + success = api_client.poll_device_code(self.URL, "dvc_xyz") + + self.assertNotIn("database_url", success.raw) + self.assertNotIn("api_token", success.raw) + # But the typed attribute still carries it for the orchestrator. + self.assertIn("Pa55", success.database_url) diff --git a/test/cloud/test_dvc_credentials.py b/test/cloud/test_dvc_credentials.py new file mode 100644 index 000000000..14043b234 --- /dev/null +++ b/test/cloud/test_dvc_credentials.py @@ -0,0 +1,302 @@ +# =============================================================================== +# Copyright 2026 Jake Ross +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =============================================================================== +"""Tests for the device-flow → DVC connection-prefs persistence path.""" + +from __future__ import absolute_import + +import unittest + +from pychron.cloud.dvc_credentials import ( + DatabaseUrlParseError, + _favorite_name_for_lab, + _row_set_field, + apply_db_credentials_to_prefs, + build_dvc_connection_csv, + merge_dvc_connection_favorites, + parse_database_url, +) + + +class FakePreferences(object): + """Minimal apptools.preferences-like adapter for unit tests.""" + + def __init__(self, initial=None): + self._store = dict(initial or {}) + self._flushed = False + + def get(self, key, default=None): + return self._store.get(key, default) + + def set(self, key, value): + self._store[key] = value + + def flush(self): + self._flushed = True + + +class ParseDatabaseUrlTestCase(unittest.TestCase): + def test_basic_url_round_trip(self): + out = parse_database_url("postgresql://wkstn_x:secret@10.0.1.5:5432/nmgrl") + self.assertEqual(out["host"], "10.0.1.5") + self.assertEqual(out["port"], 5432) + self.assertEqual(out["username"], "wkstn_x") + self.assertEqual(out["password"], "secret") + self.assertEqual(out["dbname"], "nmgrl") + + def test_postgres_alias_scheme_accepted(self): + out = parse_database_url("postgres://user:pw@h/db") + self.assertEqual(out["dbname"], "db") + + def test_no_port_defaults_to_none(self): + out = parse_database_url("postgresql://user:pw@h/db") + self.assertIsNone(out["port"]) + + def test_percent_encoded_password_decoded(self): + out = parse_database_url("postgresql://user:p%40ss@h:5432/db") + self.assertEqual(out["password"], "p@ss") + + def test_empty_url_raises(self): + with self.assertRaises(DatabaseUrlParseError): + parse_database_url("") + + def test_non_postgres_scheme_raises(self): + with self.assertRaises(DatabaseUrlParseError): + parse_database_url("mysql://u:p@h/db") + + def test_missing_host_raises(self): + with self.assertRaises(DatabaseUrlParseError): + parse_database_url("postgresql:///db") + + def test_missing_dbname_raises(self): + with self.assertRaises(DatabaseUrlParseError): + parse_database_url("postgresql://u:p@h/") + + +class BuildDvcConnectionCsvTestCase(unittest.TestCase): + def test_csv_field_order_matches_attributes(self): + parsed = parse_database_url("postgresql://wkstn_x:secret@10.0.1.5:5432/nmgrl") + csv = build_dvc_connection_csv( + parsed, + name="cloud-nmgrl", + organization="nmgrl", + meta_repo_name="MetaData", + ) + parts = csv.split(",") + self.assertEqual(parts[0], "cloud-nmgrl") + self.assertEqual(parts[1], "postgresql") + self.assertEqual(parts[2], "wkstn_x") + self.assertEqual(parts[3], "10.0.1.5:5432") # host:port + self.assertEqual(parts[4], "nmgrl") + self.assertEqual(parts[5], "secret") + self.assertEqual(parts[6], "True") # enabled + self.assertEqual(parts[7], "True") # default + self.assertEqual(parts[9], "nmgrl") # organization + self.assertEqual(parts[10], "MetaData") # meta_repo_name + + def test_port_appended_to_host(self): + """Caveman finding: port was lost. Host field MUST carry + host:port so non-default-port Cloud SQL connections work.""" + parsed = parse_database_url("postgresql://u:p@db.lab.example.com:6543/nmgrl") + csv = build_dvc_connection_csv(parsed, name="cloud-nmgrl") + self.assertIn("db.lab.example.com:6543", csv) + + def test_no_port_omits_colon(self): + parsed = parse_database_url("postgresql://u:p@db.lab/nmgrl") + csv = build_dvc_connection_csv(parsed, name="cloud-nmgrl") + parts = csv.split(",") + self.assertEqual(parts[3], "db.lab") + self.assertNotIn(":", parts[3]) + + def test_password_with_comma_rejected(self): + """A literal comma in the password would corrupt the CSV + positional encoding. Caller must catch + abort rather than + write a poison favorite.""" + parsed = parse_database_url("postgresql://u:p@h/db") + parsed["password"] = "a,b" + with self.assertRaises(DatabaseUrlParseError): + build_dvc_connection_csv(parsed, name="cloud-x") + + +class MergeFavoritesTestCase(unittest.TestCase): + def test_appends_when_no_match(self): + existing = ["myhand,postgresql,me,h,db,p,True,False,,"] + new_row = "cloud-nmgrl,postgresql,wkstn,h,db,p,True,True,," + out = merge_dvc_connection_favorites(existing, new_row, "cloud-nmgrl") + self.assertEqual(len(out), 2) + self.assertEqual(out[1], new_row) + + def test_replaces_matching_row(self): + existing = [ + "myhand,postgresql,me,h,db,p,True,False,,", + "cloud-nmgrl,postgresql,old,h,db,old,True,True,,", + ] + new_row = "cloud-nmgrl,postgresql,new,h,db,new,True,True,," + out = merge_dvc_connection_favorites(existing, new_row, "cloud-nmgrl") + self.assertEqual(len(out), 2) + self.assertEqual(out[1], new_row) + + def test_clears_default_flag_on_other_rows(self): + """Adding a new default=True favorite must demote any other + row that previously held the default flag — only one default + at a time.""" + existing = [ + "myhand,postgresql,me,h,db,p,True,True,,", + ] + new_row = "cloud-nmgrl,postgresql,wkstn,h,db,p,True,True,," + out = merge_dvc_connection_favorites(existing, new_row, "cloud-nmgrl") + self.assertIn("True,False", out[0]) # default field flipped + self.assertIn("cloud-nmgrl", out[1]) + + def test_short_row_extended_when_clearing_default(self): + """Caveman finding: _row_set_field used to silently no-op + when idx >= len(parts), leaving stale default=True flags on + legacy short-format rows.""" + existing = ["legacy,postgresql,me,h"] # 4 fields only + new_row = "cloud-x,postgresql,wkstn,h,db,p,True,True,," + out = merge_dvc_connection_favorites(existing, new_row, "cloud-x") + # The old short row was extended so position 7 now holds the + # new default=False marker rather than being missing. + legacy_parts = out[0].split(",") + self.assertGreater(len(legacy_parts), 7) + self.assertEqual(legacy_parts[7], "False") + + def test_no_default_change_when_new_row_isnt_default(self): + """Edge case: if the new row isn't default, leave existing + default flags alone.""" + existing = ["myhand,postgresql,me,h,db,p,True,True,,"] + new_row = "cloud-x,postgresql,wkstn,h,db,p,True,False,," + out = merge_dvc_connection_favorites(existing, new_row, "cloud-x") + self.assertIn("True,True", out[0]) # original default preserved + + +class RowSetFieldTestCase(unittest.TestCase): + def test_extends_short_row(self): + """Regression: `_row_set_field` used to drop updates whose + index exceeded the row length.""" + out = _row_set_field("a,b", 5, "X") + parts = out.split(",") + self.assertEqual(len(parts), 6) + self.assertEqual(parts[5], "X") + # In-range positions preserved. + self.assertEqual(parts[0], "a") + self.assertEqual(parts[1], "b") + + def test_in_range_update(self): + out = _row_set_field("a,b,c,d", 2, "X") + self.assertEqual(out, "a,b,X,d") + + +class FavoriteNameForLabTestCase(unittest.TestCase): + def test_safe_lab(self): + self.assertEqual(_favorite_name_for_lab("nmgrl"), "cloud-nmgrl") + + def test_strips_unsafe_chars(self): + self.assertEqual(_favorite_name_for_lab("nm grl/!"), "cloud-nmgrl") + + def test_empty_lab_falls_back(self): + self.assertEqual(_favorite_name_for_lab(""), "cloud-default") + + def test_none_lab_falls_back(self): + self.assertEqual(_favorite_name_for_lab(None), "cloud-default") + + +class ApplyDbCredentialsToPrefsTestCase(unittest.TestCase): + def test_writes_favorite_into_empty_prefs(self): + prefs = FakePreferences() + name = apply_db_credentials_to_prefs( + prefs, + database_url="postgresql://wkstn_x:secret@10.0.1.5:5432/nmgrl", + lab_name="nmgrl", + organization="nmgrl", + meta_repo_name="MetaData", + ) + self.assertEqual(name, "cloud-nmgrl") + favs = prefs.get("pychron.dvc.connection.favorites") + self.assertIsNotNone(favs) + # Round-trip through repr — that's what _join_favorites uses. + import ast + + items = ast.literal_eval(favs) + self.assertEqual(len(items), 1) + self.assertIn("cloud-nmgrl", items[0]) + self.assertIn("10.0.1.5:5432", items[0]) + self.assertIn("secret", items[0]) + self.assertTrue(prefs._flushed) + + def test_replaces_prior_cloud_favorite(self): + """Re-enrolling the same lab must REPLACE the prior cloud-* + favorite, not stack them.""" + prefs = FakePreferences() + apply_db_credentials_to_prefs( + prefs, + database_url="postgresql://wkstn_old:old@h:5432/nmgrl", + lab_name="nmgrl", + ) + apply_db_credentials_to_prefs( + prefs, + database_url="postgresql://wkstn_new:new@h:5432/nmgrl", + lab_name="nmgrl", + ) + import ast + + items = ast.literal_eval(prefs.get("pychron.dvc.connection.favorites")) + cloud_items = [r for r in items if r.startswith("cloud-nmgrl,")] + self.assertEqual(len(cloud_items), 1) + self.assertIn("new", cloud_items[0]) + self.assertNotIn("old", cloud_items[0]) + + def test_preserves_user_defined_favorites(self): + legacy = "myhand,postgresql,me,otherhost,otherdb,mypw,True,True,," + prefs = FakePreferences( + initial={ + "pychron.dvc.connection.favorites": repr([legacy]), + } + ) + apply_db_credentials_to_prefs( + prefs, + database_url="postgresql://wkstn_x:secret@10.0.1.5:5432/nmgrl", + lab_name="nmgrl", + ) + import ast + + items = ast.literal_eval(prefs.get("pychron.dvc.connection.favorites")) + # legacy row preserved (not deleted) + self.assertTrue(any(r.startswith("myhand,") for r in items)) + # new row appended + self.assertTrue(any(r.startswith("cloud-nmgrl,") for r in items)) + # legacy default flag flipped to False (only one default) + legacy_row = [r for r in items if r.startswith("myhand,")][0] + self.assertEqual(legacy_row.split(",")[7], "False") + + def test_none_url_returns_none_without_writing(self): + prefs = FakePreferences() + out = apply_db_credentials_to_prefs(prefs, database_url=None, lab_name="nmgrl") + self.assertIsNone(out) + self.assertFalse(prefs._flushed) + self.assertIsNone(prefs.get("pychron.dvc.connection.favorites")) + + def test_malformed_url_raises(self): + prefs = FakePreferences() + with self.assertRaises(DatabaseUrlParseError): + apply_db_credentials_to_prefs( + prefs, + database_url="not-a-url", + lab_name="nmgrl", + ) + + +if __name__ == "__main__": # pragma: no cover + unittest.main()