From 15e6f7d420d286724c7c1e8a3d09e005d8d3d783 Mon Sep 17 00:00:00 2001 From: jakeross Date: Sun, 10 May 2026 13:20:38 -0600 Subject: [PATCH] feat(cloud): persist device-flow Cloud SQL IAM credentials into DVC prefs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the static-Postgres-password client work from PR #17 (held) with the Cloud SQL IAM auth path required by the latest BDD spec. No static Postgres password is parsed, persisted, or transmitted. The poll-success response now optionally carries a database_iam bundle minted off-cluster by the admin tool (companion server PR #43). This commit wires it through the client: - pychron/cloud/api_client.py — DeviceCodePollSuccess gains a database_iam dict slot. poll_device_code() parses the new block from the response body and strips it from the safe_raw debug dict so the embedded SA private key cannot leak into caller logs. - pychron/cloud/paths.py — new cloudsql_key_path(lab) helper that returns ~/.pychron/keys/cloudsql_.json. Lab name is filesystem-sanitized so a hostile / weird lab string cannot escape the keys directory. - pychron/cloud/iam_credentials.py (NEW) — pure helpers wrapping validate → write SA key → CSV → favorites: * _validate_iam_bundle: required-field check + ip_type public/private/psc + JSON-decode SA key + cross-check key.client_email == service_account_email (defends against bridge-side key-swap). * write_sa_key_file: atomic-ish write to ~/.pychron/keys/cloudsql_.json with 0600 on POSIX. * build_iam_dvc_csv: positional CSV that DVCConnectionItem(attrs=...) rehydrates with connection_method=cloudsql_iam, the four cloudsql_* fields populated, and username/password left empty. * merge_iam_dvc_favorites: REPLACES any prior cloud- row, demotes any other default=True favorite. _row_set_field() extends short legacy rows rather than silently no-op'ing. * apply_iam_credentials_to_prefs: end-to-end glue. - pychron/cloud/workstation_setup.py — WorkstationSetup gets database_iam + default_metadata_repo attributes; from_device_code populates them from the poll body. None means HTTP-only. - pychron/cloud/tasks/preferences.py: * New _registration_status field (Registered / Partial / Unregistered) bound to a CustomLabel. * Start-device-code-enrollment refuses to proceed when the workstation is already onboarded (registration.json + keyring token both present) without an explicit confirmation dialog. * After successful enrollment, _persist_iam_credentials_from_setup writes the SA key + DVC favorite, and the same whoami probe the manual Test Connection button uses runs automatically. A parse failure is non-fatal — surfaced via the status badge so enrollment isn't rolled back. * Re-onboard / revoke / switch-lab refresh the registration status indicator on completion. - test/cloud/test_iam_credentials.py (NEW) — 23 unit tests: write_sa_key_file (correct path, 0600 perms on POSIX, overwrite on re-enrollment, slug sanitization), build_iam_dvc_csv (CSV field ordering), merge_iam_dvc_favorites (append, replace, default flag demotion, short-row extension), _row_set_field regression, apply_iam_credentials_to_prefs end-to-end + rejection paths (missing field, invalid ip_type, mismatched key.client_email, malformed key JSON, non-service_account key). - test/cloud/test_device_code_setup.py — 3 new tests: IAM bundle propagates to WorkstationSetup, missing bundle leaves attr at None, database_iam stripped from raw debug dict. Total: 166 cloud tests passing (was 140, +26 new). Replaces PR #17. Server: pychronAPI feat/workstation-iam-credentials (PR #43). Co-Authored-By: Claude Opus 4.7 (1M context) --- pychron/cloud/api_client.py | 21 +- pychron/cloud/iam_credentials.py | 332 ++++++++++++++++++++++++ pychron/cloud/paths.py | 11 + pychron/cloud/tasks/preferences.py | 126 ++++++++- pychron/cloud/workstation_setup.py | 11 + test/cloud/test_device_code_setup.py | 106 ++++++++ test/cloud/test_iam_credentials.py | 365 +++++++++++++++++++++++++++ 7 files changed, 969 insertions(+), 3 deletions(-) create mode 100644 pychron/cloud/iam_credentials.py create mode 100644 test/cloud/test_iam_credentials.py diff --git a/pychron/cloud/api_client.py b/pychron/cloud/api_client.py index f6b51c0d4..6ab44307d 100644 --- a/pychron/cloud/api_client.py +++ b/pychron/cloud/api_client.py @@ -396,6 +396,15 @@ class DeviceCodePollSuccess(object): keyring before losing the reference. ``ssh_key`` is the same shape that :func:`register_ssh_key` returns so the orchestrator can reuse the existing persist/apply path. + + ``database_iam`` carries a per-workstation Cloud SQL IAM bundle + when the off-cluster admin tool has staged one via the bridge's + bootstrap-only ``/internal/workstation-iam-credentials`` endpoint. + Shape (dict): ``instance_connection_name``, ``database_name``, + ``service_account_email``, ``service_account_key_json``, + ``ip_type``. ``None`` means no bundle is pending — the workstation + runs HTTP-only mode. The staging row is DELETED on this read; the + SA key is not recoverable later. """ __slots__ = ( @@ -405,6 +414,7 @@ class DeviceCodePollSuccess(object): "default_metadata_repo", "ssh_host_alias", "ssh_key", + "database_iam", "raw", ) @@ -417,6 +427,7 @@ def __init__( ssh_host_alias, ssh_key, raw, + database_iam=None, ): self.api_token = api_token self.lab = lab @@ -424,6 +435,7 @@ def __init__( self.default_metadata_repo = default_metadata_repo self.ssh_host_alias = ssh_host_alias or {} self.ssh_key = ssh_key + self.database_iam = database_iam or None self.raw = raw @@ -554,8 +566,12 @@ def poll_device_code(base_url, device_code, timeout=DEFAULT_TIMEOUT): raw=ssh_key_payload, ) - # Strip the plaintext token from `raw` before exposing it. - safe_raw = {k: v for k, v in body.items() if k != "api_token"} + # Strip the plaintext token AND the database_iam bundle (which + # embeds a service-account private key) from `raw` before + # exposing it. Callers who serialize `raw` for debugging would + # otherwise leak both the bearer secret and the SA key into + # logs/disk. + safe_raw = {k: v for k, v in body.items() if k not in ("api_token", "database_iam")} return DeviceCodePollSuccess( api_token=body.get("api_token", ""), @@ -565,6 +581,7 @@ def poll_device_code(base_url, device_code, timeout=DEFAULT_TIMEOUT): ssh_host_alias=body.get("ssh_host_alias") or {}, ssh_key=ssh_key, raw=safe_raw, + database_iam=body.get("database_iam") or None, ) diff --git a/pychron/cloud/iam_credentials.py b/pychron/cloud/iam_credentials.py new file mode 100644 index 000000000..797f5a6f1 --- /dev/null +++ b/pychron/cloud/iam_credentials.py @@ -0,0 +1,332 @@ +# =============================================================================== +# Copyright 2026 Jake Ross +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =============================================================================== +"""DVC connection-prefs persistence for device-flow Cloud SQL IAM creds. + +After a successful device-code poll the workstation receives a +``database_iam`` dict shaped:: + + { + "instance_connection_name": "project:region:instance", + "database_name": "nmgrl", + "service_account_email": "wkstn-x@project.iam.gserviceaccount.com", + "service_account_key_json": "", + "ip_type": "public" | "private" | "psc", + } + +This module: + + 1. Writes ``service_account_key_json`` to + ``~/.pychron/keys/cloudsql_.json`` with 0600 permissions on + POSIX (Windows has no POSIX mode bits). + 2. Writes a ``DVCConnectionItem`` favorite to the + ``pychron.dvc.connection`` Envisage preference node with + ``connection_method=cloudsql_iam`` + the four ``cloudsql_*`` + fields populated. ``username`` / ``password`` are left empty — + the Cloud SQL Python Connector exchanges the SA key for a + short-lived OAuth token at every connect. + +Pure-function helpers exposed for unit testing without spinning up +the full Traits / Envisage stack. +""" + +from __future__ import absolute_import + +import ast +import json +import logging +import os + +from pychron.cloud.paths import cloudsql_key_path, ensure_pychron_dirs + +logger = logging.getLogger(__name__) + + +# Order MUST match DVCConnectionItem.attributes in +# pychron/dvc/tasks/dvc_preferences.py. CSV is positional. +_DVC_CONNECTION_ATTRS = ( + "name", + "kind", + "username", + "host", + "dbname", + "password", + "enabled", + "default", + "path", + "organization", + "meta_repo_name", + "meta_repo_dir", + "timeout", + "repository_root", + "connection_method", + "cloudsql_instance_connection_name", + "cloudsql_ip_type", + "cloudsql_service_account_email", + "cloudsql_service_account_key_path", +) + +# Sentinel name prefix used to mark favorites added by the device-flow +# IAM path. Re-enrolling the same lab REPLACES the prior cloud-* row. +CLOUD_FAV_PREFIX = "cloud-" + +# CloudSQL routing modes per DVCConnectionItem.cloudsql_ip_type Enum. +_VALID_IP_TYPES = ("public", "private", "psc") + + +class IamCredentialsError(Exception): + """Raised when an IAM bundle cannot be applied to prefs.""" + + +def _validate_iam_bundle(bundle): + """Lightweight shape check on the bridge response. Mirrors the + server's pydantic validators so a malformed bundle fails fast on + the client rather than landing a half-configured DVC favorite.""" + if not isinstance(bundle, dict): + raise IamCredentialsError("database_iam payload is not a dict") + for key in ( + "instance_connection_name", + "database_name", + "service_account_email", + "service_account_key_json", + ): + v = bundle.get(key) + if not isinstance(v, str) or not v: + raise IamCredentialsError("database_iam is missing required field {}".format(key)) + ip_type = bundle.get("ip_type", "public") or "public" + if ip_type not in _VALID_IP_TYPES: + raise IamCredentialsError( + "database_iam ip_type {!r} is not one of {}".format(ip_type, _VALID_IP_TYPES) + ) + # Verify the SA key file looks plausible — same surface the server + # validates. A workstation that writes a malformed SA key to disk + # cannot connect to Cloud SQL anyway, and the failure will be + # easier to diagnose at enrollment than at first DVC startup. + try: + key_payload = json.loads(bundle["service_account_key_json"]) + except json.JSONDecodeError as exc: + raise IamCredentialsError( + "database_iam service_account_key_json is not valid JSON: {}".format(exc) + ) + if not isinstance(key_payload, dict) or key_payload.get("type") != "service_account": + raise IamCredentialsError( + "database_iam service_account_key_json is not a service_account key" + ) + if key_payload.get("client_email") != bundle["service_account_email"]: + raise IamCredentialsError( + "database_iam SA key client_email does not match service_account_email" + ) + + +def write_sa_key_file(lab_name, key_json): + """Persist ``key_json`` to ``~/.pychron/keys/cloudsql_.json``. + + Returns the absolute path on success. Raises :class:`OSError` + propagated from the filesystem on failure. The file is written + with 0600 permissions on POSIX so it isn't world-readable. On + Windows POSIX mode bits don't apply; the file inherits parent + ACLs (the keys directory is created via ``ensure_pychron_dirs`` + which the caller is responsible for). + """ + ensure_pychron_dirs() + path = cloudsql_key_path(lab_name) + # Atomic-ish write: write to temp then replace, so a crash mid- + # write doesn't leave a partially-written key. + tmp = path + ".tmp" + with open(tmp, "w") as f: + f.write(key_json) + if os.name == "posix": + os.chmod(tmp, 0o600) + os.replace(tmp, path) + return path + + +def _row_to_csv(values): + out = [] + for attr, value in zip(_DVC_CONNECTION_ATTRS, values): + s = "" if value is None else str(value) + if "," in s: + raise IamCredentialsError( + "{} contains a literal comma which would corrupt the " + "CSV-encoded favorites preference".format(attr) + ) + out.append(s) + return ",".join(out) + + +def build_iam_dvc_csv( + bundle, + name, + sa_key_file_path, + organization="", + meta_repo_name="", + meta_repo_dir="", + repository_root="", +): + """Serialize an IAM bundle as a positional CSV row that + ``DVCConnectionItem(attrs=, load_names=False)`` rehydrates. + + ``connection_method`` is set to ``cloudsql_iam``; ``username`` and + ``password`` are empty (Cloud SQL Connector handles auth via the + SA key). Marked ``enabled=True`` and ``default=True`` so the next + DVC startup picks the new entry without further user action. + """ + values = [ + name, # name + "postgresql", # kind + "", # username (unused for IAM) + "", # host (unused for IAM) + bundle["database_name"], # dbname + "", # password (unused for IAM) + "True", # enabled + "True", # default + "", # path (sqlite-only) + organization, # organization + meta_repo_name, # meta_repo_name + meta_repo_dir, # meta_repo_dir + "5", # timeout + repository_root, # repository_root + "cloudsql_iam", # connection_method + bundle["instance_connection_name"], # cloudsql_instance_connection_name + bundle.get("ip_type", "public") or "public", # cloudsql_ip_type + bundle["service_account_email"], # cloudsql_service_account_email + sa_key_file_path, # cloudsql_service_account_key_path + ] + return _row_to_csv(values) + + +def _favorite_name_for_lab(lab_name): + safe = "".join(c for c in (lab_name or "default") if c.isalnum() or c in "-_") + return "{}{}".format(CLOUD_FAV_PREFIX, safe or "default") + + +def _row_field(row, idx): + parts = row.split(",") + if idx < len(parts): + return parts[idx] + return "" + + +def _row_set_field(row, idx, value): + """Set the ``idx``-th comma-separated field, extending short rows + with empty fields. A silent no-op on short rows would leave a + stale ``default=True`` flag on legacy favorites and demote the + new cloud-minted entry to non-default. + """ + parts = row.split(",") + while len(parts) <= idx: + parts.append("") + parts[idx] = value + return ",".join(parts) + + +def _favorite_name(row): + if not row: + return "" + return row.split(",", 1)[0] + + +def merge_iam_dvc_favorites(existing, new_row, replace_name): + """Replace any row whose name == ``replace_name``, otherwise + append. Demotes any other ``default=True`` favorite so only one + default is active at a time. + """ + out = [] + replaced = False + new_default = _row_field(new_row, 7) == "True" + for row in existing or []: + if _favorite_name(row) == replace_name: + out.append(new_row) + replaced = True + continue + if new_default: + row = _row_set_field(row, 7, "False") + out.append(row) + if not replaced: + out.append(new_row) + return out + + +def _split_favorites(raw): + if raw is None: + return [] + if isinstance(raw, list): + return [str(item) for item in raw] + if isinstance(raw, str): + s = raw.strip() + if not s: + return [] + try: + parsed = ast.literal_eval(s) + except (ValueError, SyntaxError): + return [s] + if isinstance(parsed, list): + return [str(item) for item in parsed] + return [s] + return [str(raw)] + + +def _join_favorites(items): + return repr(list(items)) + + +def apply_iam_credentials_to_prefs( + preferences, + bundle, + lab_name="", + organization="", + meta_repo_name="", + meta_repo_dir="", + repository_root="", +): + """End-to-end: validate bundle, write SA key file, push a + ``cloudsql_iam`` DVCConnectionItem favorite into + ``pychron.dvc.connection.favorites``. + + Returns the canonical favorite name on success, or ``None`` when + there is no bundle to apply (``bundle`` is falsy / empty dict). + + Raises :class:`IamCredentialsError` on a malformed bundle so the + caller can show the technician a clear failure rather than + silently writing a half-configured favorite. + """ + if not bundle: + return None + _validate_iam_bundle(bundle) + name = _favorite_name_for_lab(lab_name) + sa_path = write_sa_key_file(lab_name, bundle["service_account_key_json"]) + new_row = build_iam_dvc_csv( + bundle, + name=name, + sa_key_file_path=sa_path, + organization=organization, + meta_repo_name=meta_repo_name, + meta_repo_dir=meta_repo_dir, + repository_root=repository_root, + ) + + raw = preferences.get("pychron.dvc.connection.favorites", "") or "" + existing = _split_favorites(raw) + merged = merge_iam_dvc_favorites(existing, new_row, replace_name=name) + preferences.set("pychron.dvc.connection.favorites", _join_favorites(merged)) + preferences.flush() + logger.info( + "applied DVC IAM favorite name=%s instance=%s db=%s sa=%s", + name, + bundle["instance_connection_name"], + bundle["database_name"], + bundle["service_account_email"], + ) + return name diff --git a/pychron/cloud/paths.py b/pychron/cloud/paths.py index da7332863..e24ce22f4 100644 --- a/pychron/cloud/paths.py +++ b/pychron/cloud/paths.py @@ -79,6 +79,17 @@ def public_key_path(host=None): return key_path(host) + ".pub" +def cloudsql_key_path(lab): + """Path to the per-lab Cloud SQL service-account JSON key file. + + Lab name is filesystem-sanitized so a hostile / weird lab string + cannot escape the keys directory. Falls back to ``default`` when + no lab is supplied. + """ + safe = "".join(c for c in (lab or "default") if c.isalnum() or c in "-_") or "default" + return os.path.join(keys_dir(), "cloudsql_{}.json".format(safe)) + + def ensure_pychron_dirs(): """Create ``~/.pychron`` and ``~/.pychron/keys`` if missing. diff --git a/pychron/cloud/tasks/preferences.py b/pychron/cloud/tasks/preferences.py index 5c67e2d09..8baf07259 100644 --- a/pychron/cloud/tasks/preferences.py +++ b/pychron/cloud/tasks/preferences.py @@ -48,12 +48,17 @@ get_token, set_token, ) +from pychron.cloud.iam_credentials import ( + IamCredentialsError, + apply_iam_credentials_to_prefs, +) from pychron.cloud.qr import make_qr_for_device_code from pychron.cloud.workstation_setup import ( DeviceEnrollmentCancelled, KeyringWriteFailedError, WorkstationSetup, WorkstationSetupError, + load_registration, switch_lab as wipe_for_switch_lab, ) from pychron.core.confirmation import confirmation_dialog @@ -108,12 +113,42 @@ class CloudPreferences(BasePreferencesHelper): _remote_status = Str _remote_status_color = Color + # Surfaces "Registered" / "Partial" / "Unregistered" on pane open + # so the technician sees onboarding state at a glance. Derived + # from the on-disk ``~/.pychron/registration.json`` + keyring + # token. + _registration_status = Str + _registration_status_color = Color + def _remote_status_color_default(self): return normalize_color_name("red") + def _registration_status_color_default(self): + return normalize_color_name("red") + def _initialize(self, *args, **kw): super(CloudPreferences, self)._initialize(*args, **kw) self._load_token_from_keyring() + self._refresh_registration_status() + + def _refresh_registration_status(self): + """Update the Registered/Partial/Unregistered indicator based + on local state. Considered "Registered" iff a registration.json + exists AND the keyring carries a token for the configured lab. + Either alone is half-onboarded — call that "Partial" so the + technician knows to re-onboard. + """ + reg = load_registration() + token = get_token(self.lab_name) if self.lab_name else "" + if reg and token: + self._registration_status = "Registered" + self._registration_status_color = normalize_color_name("green") + elif reg or token: + self._registration_status = "Partial — re-onboard recommended" + self._registration_status_color = normalize_color_name("orange") + else: + self._registration_status = "Unregistered" + self._registration_status_color = normalize_color_name("red") def _is_preference_trait(self, trait_name): # api_token must never be written to the .cfg — it lives in the OS @@ -123,6 +158,8 @@ def _is_preference_trait(self, trait_name): "api_token", "_remote_status", "_remote_status_color", + "_registration_status", + "_registration_status_color", "test_connection", "reonboard_button", "revoke_button", @@ -150,6 +187,7 @@ def _lab_name_changed(self, old, new): # there so the user sees the right token without re-entering it. if old != new: self._load_token_from_keyring() + self._refresh_registration_status() def _api_token_changed(self, old, new): if not self.lab_name: @@ -211,6 +249,23 @@ def _enroll_via_device_code_button_fired(self): self._remote_status = "Set API Base URL first" return + # Re-registration guardrail: a workstation that already has a + # registration.json + keyring token is functional; an admin + # tap on the button could otherwise silently rotate the SSH + # key and burn a fresh device-code slot. + existing_reg = load_registration() + existing_token = get_token(self.lab_name) if self.lab_name else "" + if existing_reg and existing_token: + if not confirmation_dialog( + "This workstation is already registered with Pychron Cloud " + "as lab '{}'. Re-enrolling will rotate the SSH keypair and " + "mint a new API token. Continue?".format(self.lab_name or "?"), + title="Re-register workstation", + ): + self._remote_status = "Already registered — cancelled" + self._remote_status_color = normalize_color_name("orange") + return + self._should_cancel_enrollment = False self._pending_user_code = "" self._pending_verification_url = "" @@ -321,9 +376,66 @@ def _apply_enrollment_success(self, setup): self.api_base_url = setup.api_base_url self.lab_name = setup.lab_name self._load_token_from_keyring() - self._remote_status = "Enrolled as {}".format(setup.lab_name) + # Persist the Cloud SQL IAM bundle (if the bridge had one + # staged) into ``pychron.dvc.connection.favorites`` so DVC + # startup picks it up on the next run with no manual paste. + # ``None`` is a legitimate state — workstation runs HTTP-only. + iam_applied = self._persist_iam_credentials_from_setup(setup) + if iam_applied: + self._remote_status = "Enrolled as {} (Cloud SQL IAM configured)".format(setup.lab_name) + else: + self._remote_status = "Enrolled as {}".format(setup.lab_name) self._remote_status_color = normalize_color_name("green") + self._refresh_registration_status() self._reset_pending() + # Run the same whoami probe the manual "Test Connection" + # button uses, so the technician sees an immediate end-to-end + # pass / fail without an extra click. Failures here do NOT + # roll back enrollment — credentials are already minted + + # persisted. + self._test_connection_fired() + + def _persist_iam_credentials_from_setup(self, setup): + """Apply :attr:`WorkstationSetup.database_iam` to DVC prefs. + + Returns True when something was written. Errors are caught + + logged + surfaced via remote_status so a malformed bundle + does not roll back the rest of enrollment (cloud prefs + ssh + + keyring are already on disk by the time we get here). + """ + if not getattr(setup, "database_iam", None): + return False + meta = getattr(setup, "default_metadata_repo", None) or {} + repo_id = meta.get("repository_identifier", "") if isinstance(meta, dict) else "" + organization = "" + meta_repo_name = "" + if "/" in repo_id: + organization, meta_repo_name = repo_id.split("/", 1) + elif repo_id: + meta_repo_name = repo_id + organization = organization or setup.lab_name or "" + try: + apply_iam_credentials_to_prefs( + self.preferences, + bundle=setup.database_iam, + lab_name=setup.lab_name, + organization=organization, + meta_repo_name=meta_repo_name, + ) + except IamCredentialsError as exc: + logger.warning( + "device-code IAM bundle apply failed (skipping DVC prefs): %s", + exc, + ) + self._remote_status = "Enrolled — IAM bundle malformed, prefs unchanged" + self._remote_status_color = normalize_color_name("orange") + return False + except Exception as exc: # defensive + logger.warning("device-code IAM bundle persist failed: %s", exc) + self._remote_status = "Enrolled — IAM prefs write failed" + self._remote_status_color = normalize_color_name("orange") + return False + return True def _apply_enrollment_terminal(self, message, color): self._remote_status = message @@ -383,6 +495,7 @@ def _reonboard_button_fired(self): return self._remote_status = "Re-onboarded" self._remote_status_color = normalize_color_name("green") + self._refresh_registration_status() def _revoke_button_fired(self): self._remote_status_color = normalize_color_name("red") @@ -409,6 +522,7 @@ def _revoke_button_fired(self): if self.lab_name: delete_token(self.lab_name) self.trait_setq(api_token="") + self._refresh_registration_status() def _switch_lab_button_fired(self): self._remote_status_color = normalize_color_name("red") @@ -429,6 +543,7 @@ def _switch_lab_button_fired(self): self.trait_setq(api_token="", lab_name="", api_base_url="") self._remote_status = "Wiped; configure new lab above" self._remote_status_color = normalize_color_name("orange") + self._refresh_registration_status() class CloudPreferencesPane(PreferencesPane): @@ -473,6 +588,15 @@ def traits_view(self): color_name="_remote_status_color", ), ), + HGroup( + CustomLabel( + "_registration_status", + width=240, + color_name="_registration_status_color", + ), + label="Status", + show_border=False, + ), show_border=True, label="Pychron Cloud (pychronAPI)", ) diff --git a/pychron/cloud/workstation_setup.py b/pychron/cloud/workstation_setup.py index aa05a2541..07026bef7 100644 --- a/pychron/cloud/workstation_setup.py +++ b/pychron/cloud/workstation_setup.py @@ -140,6 +140,15 @@ def __init__(self, api_base_url, api_token, lab_name, host=None): self.api_token = api_token self.lab_name = lab_name self.host = host or host_slug() + # Populated by :meth:`from_device_code` when the bridge has a + # staged Cloud SQL IAM bundle for this api_token. ``None`` + # means HTTP-only — DVC connection prefs are left untouched. + self.database_iam = None + # Default-MetaData repo metadata so the prefs pane can write + # ``pychron.dvc.connection`` favorites with the right org + + # meta_repo_name without re-deriving them from + # ``registration.json``. + self.default_metadata_repo = None # -- device-code enrollment ---------------------------------------- @@ -248,6 +257,8 @@ def from_device_code( lab_name=success.lab, host=host, ) + setup.database_iam = success.database_iam + setup.default_metadata_repo = success.default_metadata_repo setup._persist_registration(success.ssh_key) setup._apply_ssh_config(success.ssh_key) diff --git a/test/cloud/test_device_code_setup.py b/test/cloud/test_device_code_setup.py index 33777191f..488bfa3f6 100644 --- a/test/cloud/test_device_code_setup.py +++ b/test/cloud/test_device_code_setup.py @@ -302,3 +302,109 @@ def test_empty_api_base_url_aborts_before_any_io(self): if __name__ == "__main__": unittest.main() + + +class FromDeviceCodeIamCredentialsTestCase(unittest.TestCase): + """The poll-success body now optionally carries a ``database_iam`` + bundle minted off-cluster by the admin tool. The orchestrator must + surface it onto the returned ``WorkstationSetup`` so the prefs + pane can persist the SA key + cloudsql_* favorite — without + leaking the bundle into ``DeviceCodePollSuccess.raw`` (which is + exposed for debug logs).""" + + URL = "https://api.example" + + def setUp(self): + self.tmp = tempfile.mkdtemp() + self._patcher = patch( + "pychron.cloud.paths.os.path.expanduser", + lambda p: p.replace("~", self.tmp), + ) + self._patcher.start() + self.addCleanup(self._patcher.stop) + + def _rmtree(): + import shutil + + shutil.rmtree(self.tmp, ignore_errors=True) + + self.addCleanup(_rmtree) + + def _iam_bundle(self): + return { + "instance_connection_name": "pychron-prod:us-central1:lab-db", + "database_name": "nmgrl", + "service_account_email": ("wkstn-x@pychron-prod.iam.gserviceaccount.com"), + "service_account_key_json": json.dumps( + { + "type": "service_account", + "client_email": ("wkstn-x@pychron-prod.iam.gserviceaccount.com"), + "private_key": ( + "-----BEGIN PRIVATE KEY-----\nFAKE\n-----END PRIVATE KEY-----\n" + ), + } + ), + "ip_type": "public", + } + + def _poll_body_with_iam(self): + body = _poll_body() + body["database_iam"] = self._iam_bundle() + return body + + def test_iam_bundle_propagates_to_setup(self): + with ( + patch.object(api_client.requests, "post") as post, + patch.object(workstation_setup, "keyring_set_token", return_value=True), + ): + post.side_effect = [ + _resp(201, _START_BODY), + _resp(200, self._poll_body_with_iam()), + ] + setup = workstation_setup.WorkstationSetup.from_device_code( + self.URL, + on_user_code=lambda *a: None, + sleep=lambda s: None, + host="testhost", + ) + self.assertIsNotNone(setup.database_iam) + self.assertEqual( + setup.database_iam["instance_connection_name"], + "pychron-prod:us-central1:lab-db", + ) + self.assertEqual( + setup.database_iam["service_account_email"], + "wkstn-x@pychron-prod.iam.gserviceaccount.com", + ) + + def test_no_iam_bundle_leaves_setup_attr_none(self): + with ( + patch.object(api_client.requests, "post") as post, + patch.object(workstation_setup, "keyring_set_token", return_value=True), + ): + post.side_effect = [ + _resp(201, _START_BODY), + _resp(200, _poll_body()), # no database_iam + ] + setup = workstation_setup.WorkstationSetup.from_device_code( + self.URL, + on_user_code=lambda *a: None, + sleep=lambda s: None, + host="testhost", + ) + self.assertIsNone(setup.database_iam) + + def test_database_iam_stripped_from_raw_debug_field(self): + """The SA private key embedded in ``database_iam`` must not + survive into the ``raw`` dict that callers may log for + debugging — same defensive treatment we give ``api_token``.""" + with patch.object(api_client.requests, "post") as post: + post.side_effect = [_resp(200, self._poll_body_with_iam())] + success = api_client.poll_device_code(self.URL, "dvc_xyz") + self.assertNotIn("database_iam", success.raw) + self.assertNotIn("api_token", success.raw) + # But the typed attribute carries the bundle for the orchestrator. + self.assertEqual( + success.database_iam["service_account_email"], + "wkstn-x@pychron-prod.iam.gserviceaccount.com", + ) diff --git a/test/cloud/test_iam_credentials.py b/test/cloud/test_iam_credentials.py new file mode 100644 index 000000000..1b40a3acf --- /dev/null +++ b/test/cloud/test_iam_credentials.py @@ -0,0 +1,365 @@ +# =============================================================================== +# Copyright 2026 Jake Ross +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =============================================================================== +"""Tests for the device-flow → DVC IAM-credential persistence path.""" + +from __future__ import absolute_import + +import ast +import json +import os +import tempfile +import unittest +from unittest.mock import patch + +from pychron.cloud.iam_credentials import ( + IamCredentialsError, + _favorite_name_for_lab, + _row_set_field, + apply_iam_credentials_to_prefs, + build_iam_dvc_csv, + merge_iam_dvc_favorites, + write_sa_key_file, +) + + +def _good_key(client_email="wkstn-x@pychron-prod.iam.gserviceaccount.com"): + return json.dumps( + { + "type": "service_account", + "project_id": "pychron-prod", + "private_key_id": "deadbeef", + "private_key": "-----BEGIN PRIVATE KEY-----\nFAKE\n-----END PRIVATE KEY-----\n", + "client_email": client_email, + "client_id": "111222333", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + } + ) + + +def _good_bundle(**overrides): + base = { + "instance_connection_name": "pychron-prod:us-central1:lab-db", + "database_name": "nmgrl", + "service_account_email": "wkstn-x@pychron-prod.iam.gserviceaccount.com", + "service_account_key_json": _good_key(), + "ip_type": "public", + } + base.update(overrides) + return base + + +class FakePreferences(object): + def __init__(self, initial=None): + self._store = dict(initial or {}) + self._flushed = False + + def get(self, key, default=None): + return self._store.get(key, default) + + def set(self, key, value): + self._store[key] = value + + def flush(self): + self._flushed = True + + +class _IsolatedHomeTestCase(unittest.TestCase): + """Redirects ``~`` to a tmpdir for tests that touch the SA key file.""" + + def setUp(self): + self.tmp = tempfile.mkdtemp() + self._patcher = patch( + "pychron.cloud.paths.os.path.expanduser", + lambda p: p.replace("~", self.tmp), + ) + self._patcher.start() + self.addCleanup(self._patcher.stop) + + def _rmtree(): + import shutil + + shutil.rmtree(self.tmp, ignore_errors=True) + + self.addCleanup(_rmtree) + + +class WriteSaKeyFileTestCase(_IsolatedHomeTestCase): + def test_writes_under_pychron_keys_with_lab_slug(self): + path = write_sa_key_file("nmgrl", _good_key()) + self.assertTrue(os.path.isfile(path)) + self.assertTrue(path.endswith("cloudsql_nmgrl.json")) + # File ended up under the redirected ~ → tmp. + self.assertTrue(path.startswith(self.tmp)) + with open(path) as f: + self.assertEqual(json.loads(f.read())["type"], "service_account") + + def test_secure_perms_on_posix(self): + if os.name != "posix": + self.skipTest("POSIX-only mode bits") + path = write_sa_key_file("nmgrl", _good_key()) + mode = os.stat(path).st_mode & 0o777 + self.assertEqual(mode, 0o600) + + def test_re_enrollment_overwrites(self): + write_sa_key_file("nmgrl", _good_key("first@p.iam.gserviceaccount.com")) + path = write_sa_key_file("nmgrl", _good_key("second@p.iam.gserviceaccount.com")) + with open(path) as f: + self.assertIn("second@p.iam.gserviceaccount.com", f.read()) + + def test_unsafe_lab_slug_sanitized(self): + path = write_sa_key_file("nm grl/!", _good_key()) + # Slashes / spaces stripped — no path-traversal escape. + self.assertTrue(path.endswith("cloudsql_nmgrl.json")) + + +class BuildIamDvcCsvTestCase(unittest.TestCase): + def test_field_order_matches_attributes(self): + csv = build_iam_dvc_csv( + _good_bundle(), + name="cloud-nmgrl", + sa_key_file_path="/home/lab/.pychron/keys/cloudsql_nmgrl.json", + organization="nmgrl", + meta_repo_name="MetaData", + ) + parts = csv.split(",") + self.assertEqual(parts[0], "cloud-nmgrl") + self.assertEqual(parts[1], "postgresql") + # username + host + password unset for IAM auth — Cloud SQL + # Connector handles auth via the SA key. + self.assertEqual(parts[2], "") + self.assertEqual(parts[3], "") + self.assertEqual(parts[4], "nmgrl") # dbname + self.assertEqual(parts[5], "") + self.assertEqual(parts[6], "True") # enabled + self.assertEqual(parts[7], "True") # default + self.assertEqual(parts[14], "cloudsql_iam") # connection_method + self.assertEqual(parts[15], "pychron-prod:us-central1:lab-db") + self.assertEqual(parts[16], "public") + self.assertEqual(parts[17], "wkstn-x@pychron-prod.iam.gserviceaccount.com") + self.assertEqual(parts[18], "/home/lab/.pychron/keys/cloudsql_nmgrl.json") + + +class MergeIamDvcFavoritesTestCase(unittest.TestCase): + def test_appends_when_no_match(self): + existing = ["myhand,postgresql,me,h,db,p,True,False,,,,,,,direct,,,,"] + new_row = ( + "cloud-nmgrl,postgresql,,,nmgrl,,True,True,,,,,5,,cloudsql_iam," + "pychron-prod:us-central1:lab-db,public," + "wkstn-x@pychron-prod.iam.gserviceaccount.com," + "/home/lab/.pychron/keys/cloudsql_nmgrl.json" + ) + out = merge_iam_dvc_favorites(existing, new_row, "cloud-nmgrl") + self.assertEqual(len(out), 2) + self.assertEqual(out[1], new_row) + + def test_replaces_matching_row(self): + existing = [ + "cloud-nmgrl,postgresql,,,nmgrl,,True,True,,,,,5,,cloudsql_iam," + "old-instance:r:i,public,old@p.iam.gserviceaccount.com,/old/path" + ] + new_row = ( + "cloud-nmgrl,postgresql,,,nmgrl,,True,True,,,,,5,,cloudsql_iam," + "new-instance:r:i,public,new@p.iam.gserviceaccount.com,/new/path" + ) + out = merge_iam_dvc_favorites(existing, new_row, "cloud-nmgrl") + self.assertEqual(len(out), 1) + self.assertIn("new-instance", out[0]) + + def test_clears_default_flag_on_other_rows(self): + existing = ["myhand,postgresql,me,h,db,p,True,True,,,,,,,direct,,,,"] + new_row = ( + "cloud-nmgrl,postgresql,,,nmgrl,,True,True,,,,,5,,cloudsql_iam," + "i:r:i,public,sa@p.iam.gserviceaccount.com,/p" + ) + out = merge_iam_dvc_favorites(existing, new_row, "cloud-nmgrl") + # Position 7 (default) on the legacy row must flip to False. + self.assertEqual(out[0].split(",")[7], "False") + + def test_short_row_extended_when_clearing_default(self): + existing = ["legacy,postgresql,me,h"] # 4 fields + new_row = ( + "cloud-x,postgresql,,,db,,True,True,,,,,5,,cloudsql_iam," + "i:r:i,public,sa@p.iam.gserviceaccount.com,/p" + ) + out = merge_iam_dvc_favorites(existing, new_row, "cloud-x") + legacy_parts = out[0].split(",") + self.assertGreater(len(legacy_parts), 7) + self.assertEqual(legacy_parts[7], "False") + + +class RowSetFieldTestCase(unittest.TestCase): + def test_extends_short_row(self): + out = _row_set_field("a,b", 5, "X") + parts = out.split(",") + self.assertEqual(len(parts), 6) + self.assertEqual(parts[5], "X") + self.assertEqual(parts[0], "a") + + def test_in_range_update(self): + out = _row_set_field("a,b,c,d", 2, "X") + self.assertEqual(out, "a,b,X,d") + + +class FavoriteNameForLabTestCase(unittest.TestCase): + def test_safe_lab(self): + self.assertEqual(_favorite_name_for_lab("nmgrl"), "cloud-nmgrl") + + def test_strips_unsafe_chars(self): + self.assertEqual(_favorite_name_for_lab("nm grl/!"), "cloud-nmgrl") + + def test_empty_lab_falls_back(self): + self.assertEqual(_favorite_name_for_lab(""), "cloud-default") + + +class ApplyIamCredentialsToPrefsTestCase(_IsolatedHomeTestCase): + def test_writes_favorite_and_sa_key(self): + prefs = FakePreferences() + name = apply_iam_credentials_to_prefs( + prefs, + bundle=_good_bundle(), + lab_name="nmgrl", + organization="nmgrl", + meta_repo_name="MetaData", + ) + self.assertEqual(name, "cloud-nmgrl") + + # SA key landed on disk. + sa_path = os.path.join(self.tmp, ".pychron", "keys", "cloudsql_nmgrl.json") + self.assertTrue(os.path.isfile(sa_path)) + + # Favorite wired the key path. + favs_raw = prefs.get("pychron.dvc.connection.favorites") + self.assertIsNotNone(favs_raw) + items = ast.literal_eval(favs_raw) + self.assertEqual(len(items), 1) + self.assertIn("cloudsql_iam", items[0]) + self.assertIn(sa_path, items[0]) + self.assertTrue(prefs._flushed) + + def test_replaces_prior_cloud_favorite(self): + prefs = FakePreferences() + apply_iam_credentials_to_prefs( + prefs, + bundle=_good_bundle( + instance_connection_name="old:r:i", + service_account_email="old@p.iam.gserviceaccount.com", + service_account_key_json=_good_key("old@p.iam.gserviceaccount.com"), + ), + lab_name="nmgrl", + ) + apply_iam_credentials_to_prefs( + prefs, + bundle=_good_bundle( + instance_connection_name="new:r:i", + service_account_email="new@p.iam.gserviceaccount.com", + service_account_key_json=_good_key("new@p.iam.gserviceaccount.com"), + ), + lab_name="nmgrl", + ) + items = ast.literal_eval(prefs.get("pychron.dvc.connection.favorites")) + cloud_items = [r for r in items if r.startswith("cloud-nmgrl,")] + self.assertEqual(len(cloud_items), 1) + self.assertIn("new:r:i", cloud_items[0]) + self.assertNotIn("old:r:i", cloud_items[0]) + + def test_preserves_user_defined_favorites(self): + prefs = FakePreferences( + initial={ + "pychron.dvc.connection.favorites": repr( + ["myhand,postgresql,me,otherhost,otherdb,mypw,True,True,,,,,,,direct,,,,"] + ), + } + ) + apply_iam_credentials_to_prefs(prefs, bundle=_good_bundle(), lab_name="nmgrl") + items = ast.literal_eval(prefs.get("pychron.dvc.connection.favorites")) + self.assertTrue(any(r.startswith("myhand,") for r in items)) + self.assertTrue(any(r.startswith("cloud-nmgrl,") for r in items)) + legacy_row = [r for r in items if r.startswith("myhand,")][0] + self.assertEqual(legacy_row.split(",")[7], "False") + + def test_none_bundle_returns_none_without_writing(self): + prefs = FakePreferences() + out = apply_iam_credentials_to_prefs(prefs, bundle=None, lab_name="nmgrl") + self.assertIsNone(out) + self.assertFalse(prefs._flushed) + + def test_missing_field_raises(self): + prefs = FakePreferences() + with self.assertRaises(IamCredentialsError): + apply_iam_credentials_to_prefs( + prefs, + bundle={ + "instance_connection_name": "x:r:i", + "database_name": "nmgrl", + "service_account_email": "wkstn-x@p.iam.gserviceaccount.com", + # service_account_key_json missing + "ip_type": "public", + }, + lab_name="nmgrl", + ) + + def test_invalid_ip_type_raises(self): + prefs = FakePreferences() + with self.assertRaises(IamCredentialsError): + apply_iam_credentials_to_prefs( + prefs, + bundle=_good_bundle(ip_type="carrier_pigeon"), + lab_name="nmgrl", + ) + + def test_mismatched_key_email_raises(self): + """Key file's client_email MUST match service_account_email or + the bundle is a key-swap attempt.""" + prefs = FakePreferences() + with self.assertRaises(IamCredentialsError): + apply_iam_credentials_to_prefs( + prefs, + bundle=_good_bundle( + service_account_email="wkstn-x@p.iam.gserviceaccount.com", + service_account_key_json=_good_key("wkstn-y@p.iam.gserviceaccount.com"), + ), + lab_name="nmgrl", + ) + + def test_malformed_key_json_raises(self): + prefs = FakePreferences() + with self.assertRaises(IamCredentialsError): + apply_iam_credentials_to_prefs( + prefs, + bundle=_good_bundle(service_account_key_json="not-json"), + lab_name="nmgrl", + ) + + def test_non_service_account_key_raises(self): + prefs = FakePreferences() + with self.assertRaises(IamCredentialsError): + apply_iam_credentials_to_prefs( + prefs, + bundle=_good_bundle( + service_account_key_json=json.dumps( + { + "type": "user_account", + "client_email": "wkstn-x@p.iam.gserviceaccount.com", + } + ) + ), + lab_name="nmgrl", + ) + + +if __name__ == "__main__": # pragma: no cover + unittest.main()