Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 242 additions & 1 deletion pychron/cloud/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

from pychron.globals import globalv


DEFAULT_TIMEOUT = 10


Expand All @@ -49,6 +48,32 @@ class CloudFingerprintRejected(CloudAPIError):
"""


class CloudDeviceCodePending(CloudAPIError):
"""Device-code poll: admin has not approved yet (HTTP 425).

Workstation should sleep ``interval_seconds`` and poll again.
"""


class CloudDeviceCodeDenied(CloudAPIError):
"""Device-code poll: admin explicitly denied the request (HTTP 403).

Terminal — workstation must stop polling and ask the admin to start
a new request.
"""


class CloudDeviceCodeExpired(CloudAPIError):
"""Device-code poll terminal failure (HTTP 410).

Server collapses several lifecycle states into a uniform
``expired_token`` to deny enumeration oracles, so the client can't
distinguish ``not_found`` / ``expired`` / ``already_consumed`` /
``lab_vanished`` / ``scope_mismatch`` either. Workstation must stop
polling and start over.
"""


class CloudNetworkError(CloudAPIError):
"""Transport-level failure (DNS, TCP, TLS, timeout, non-JSON body)."""

Expand Down Expand Up @@ -327,6 +352,222 @@ def register_ssh_key(base_url, token, public_key, title=None, timeout=DEFAULT_TI
)


class DeviceCodeStart(object):
"""Result of ``POST /api/v1/forgejo/device-codes``.

The ``device_code`` is the polling secret; the ``user_code`` is the
short admin-typed code shown in the workstation UI alongside the
``verification_url``. Both plaintext fields are returned exactly
once — only hashes are persisted server-side.
"""

__slots__ = (
"device_code",
"user_code",
"verification_url",
"verification_url_complete",
"expires_at",
"interval_seconds",
"raw",
)

def __init__(
self,
device_code,
user_code,
verification_url,
verification_url_complete,
expires_at,
interval_seconds,
raw,
):
self.device_code = device_code
self.user_code = user_code
self.verification_url = verification_url
self.verification_url_complete = verification_url_complete
self.expires_at = expires_at
self.interval_seconds = interval_seconds
self.raw = raw


class DeviceCodePollSuccess(object):
"""Successful device-code poll. The minted ``api_token`` is plaintext
and is returned exactly once; the caller must persist it to the OS
keyring before losing the reference. ``ssh_key`` is the same shape
that :func:`register_ssh_key` returns so the orchestrator can reuse
the existing persist/apply path.
"""

__slots__ = (
"api_token",
"lab",
"api_base_url",
"default_metadata_repo",
"ssh_host_alias",
"ssh_key",
"raw",
)

def __init__(
self,
api_token,
lab,
api_base_url,
default_metadata_repo,
ssh_host_alias,
ssh_key,
raw,
):
self.api_token = api_token
self.lab = lab
self.api_base_url = api_base_url
self.default_metadata_repo = default_metadata_repo
self.ssh_host_alias = ssh_host_alias or {}
self.ssh_key = ssh_key
self.raw = raw


def start_device_code(base_url, public_key, hostname, timeout=DEFAULT_TIMEOUT):
"""POST a workstation public key to start a device-code grant.

Endpoint is unauthenticated. Maps:

- 201 → :class:`DeviceCodeStart`
- 400 → :class:`CloudFingerprintRejected` (malformed pubkey)
- other 4xx/5xx → :class:`CloudAPIError`
- transport / non-JSON → :class:`CloudNetworkError`
"""
if not base_url:
raise CloudAPIError("api_base_url is empty")
if not public_key:
raise CloudAPIError("public_key is empty")
if not hostname:
raise CloudAPIError("hostname is empty")

url = _join(base_url, "/api/v1/forgejo/device-codes")
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
payload = {"public_key": public_key.strip(), "hostname": hostname}

try:
resp = requests.post(
url,
headers=headers,
json=payload,
timeout=timeout,
verify=globalv.cert_file,
)
except requests.RequestException as exc:
raise CloudNetworkError("device-code start transport failure: {}".format(exc))

if resp.status_code == 400:
raise CloudFingerprintRejected("server rejected key (HTTP 400): {}".format(resp.text[:200]))
if resp.status_code != 201:
raise CloudAPIError(
"device-code start returned HTTP {}: {}".format(resp.status_code, resp.text[:200])
)

try:
body = resp.json()
except ValueError as exc:
raise CloudNetworkError("device-code start returned non-JSON body: {}".format(exc))

# Strip the secret from `raw` before exposing it. Callers who
# serialize DeviceCodeStart.raw for debugging would otherwise leak
# both the device_code (polling secret) and user_code into logs/disk.
safe_raw = {k: v for k, v in body.items() if k not in ("device_code", "user_code")}
return DeviceCodeStart(
device_code=body.get("device_code", ""),
user_code=body.get("user_code", ""),
verification_url=body.get("verification_url", ""),
verification_url_complete=body.get("verification_url_complete", ""),
expires_at=body.get("expires_at", ""),
interval_seconds=int(body.get("interval_seconds") or 5),
raw=safe_raw,
)


def poll_device_code(base_url, device_code, timeout=DEFAULT_TIMEOUT):
"""Poll a device-code grant. Unauthenticated — the device_code is the credential.

Maps:

- 200 → :class:`DeviceCodePollSuccess`
- 425 → :class:`CloudDeviceCodePending` (keep polling)
- 403 → :class:`CloudDeviceCodeDenied` (terminal — admin denied)
- 410 → :class:`CloudDeviceCodeExpired` (terminal — uniform server response
for not-found / expired / already-consumed / lab-vanished /
scope-mismatch)
- 400 → :class:`CloudFingerprintRejected`
- other 4xx/5xx → :class:`CloudAPIError`
- transport / non-JSON → :class:`CloudNetworkError`
"""
if not base_url:
raise CloudAPIError("api_base_url is empty")
if not device_code:
raise CloudAPIError("device_code is empty")

url = _join(base_url, "/api/v1/forgejo/device-codes/poll")
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
payload = {"device_code": device_code}

try:
resp = requests.post(
url,
headers=headers,
json=payload,
timeout=timeout,
verify=globalv.cert_file,
)
except requests.RequestException as exc:
raise CloudNetworkError("device-code poll transport failure: {}".format(exc))

if resp.status_code == 425:
raise CloudDeviceCodePending("authorization_pending")
if resp.status_code == 403:
raise CloudDeviceCodeDenied("access_denied")
if resp.status_code == 410:
raise CloudDeviceCodeExpired("expired_token")
if resp.status_code == 400:
raise CloudFingerprintRejected("server rejected key (HTTP 400): {}".format(resp.text[:200]))
if resp.status_code != 200:
raise CloudAPIError(
"device-code poll returned HTTP {}: {}".format(resp.status_code, resp.text[:200])
)

try:
body = resp.json()
except ValueError as exc:
raise CloudNetworkError("device-code poll returned non-JSON body: {}".format(exc))

ssh_key_payload = body.get("ssh_key") or {}
ssh_key = SSHKeyRegistration(
bot_username=ssh_key_payload.get("bot_username", ""),
fingerprint=ssh_key_payload.get("fingerprint", ""),
default_metadata_repo=ssh_key_payload.get("default_metadata_repo", ""),
ssh_host_alias=ssh_key_payload.get("ssh_host_alias") or body.get("ssh_host_alias") or {},
raw=ssh_key_payload,
)

# Strip the plaintext token from `raw` before exposing it.
safe_raw = {k: v for k, v in body.items() if k != "api_token"}

return DeviceCodePollSuccess(
api_token=body.get("api_token", ""),
lab=body.get("lab", ""),
api_base_url=body.get("api_base_url", "") or base_url,
default_metadata_repo=body.get("default_metadata_repo"),
ssh_host_alias=body.get("ssh_host_alias") or {},
ssh_key=ssh_key,
raw=safe_raw,
)


def revoke_workstation_token(base_url, token, timeout=DEFAULT_TIMEOUT):
"""Revoke the calling token via ``DELETE /api/v1/forgejo/tokens/self``.

Expand Down
Loading