From b7bf04e914a57f1f22de7a0cd6870ef6a31f1493 Mon Sep 17 00:00:00 2001 From: Dima Tisnek Date: Thu, 14 Nov 2024 17:52:31 +0900 Subject: [PATCH 01/10] chore: type hints for Connection --- juju/client/connection.py | 60 +++++++++++++++++++++------------------ juju/client/connector.py | 10 ++++--- 2 files changed, 39 insertions(+), 31 deletions(-) diff --git a/juju/client/connection.py b/juju/client/connection.py index c842311c8..761bd82cf 100644 --- a/juju/client/connection.py +++ b/juju/client/connection.py @@ -1,5 +1,6 @@ # Copyright 2023 Canonical Ltd. # Licensed under the Apache V2, see LICENCE file for details. +from __future__ import annotations import base64 import json @@ -9,12 +10,13 @@ import warnings import weakref from http.client import HTTPSConnection -from typing import Dict, Literal, Optional, Sequence +from typing import Any, Literal, Sequence import macaroonbakery.bakery as bakery import macaroonbakery.httpbakery as httpbakery import websockets from dateutil.parser import parse +from typing_extensions import Self from juju import errors, jasyncio, tag, utils from juju.client import client @@ -27,19 +29,6 @@ log = logging.getLogger("juju.client.connection") -def facade_versions(name, versions): - """facade_versions returns a new object that correctly returns a object in - format expected by the connection facades inspection. - :param name: name of the facade - :param versions: versions to support by the facade - """ - if name.endswith("Facade"): - name = name[: -len("Facade")] - return { - name: {"versions": versions}, - } - - class Monitor: """Monitor helper class for our Connection class. @@ -59,7 +48,7 @@ class Monitor: DISCONNECTING = "disconnecting" DISCONNECTED = "disconnected" - def __init__(self, connection): + def __init__(self, connection: Connection): self.connection = weakref.ref(connection) self.reconnecting = jasyncio.Lock() self.close_called = jasyncio.Event() @@ -117,28 +106,40 @@ class Connection: MAX_FRAME_SIZE = 2**22 "Maximum size for a single frame. Defaults to 4MB." - facades: Dict[str, int] - _specified_facades: Dict[str, Sequence[int]] + facades: dict[str, int] + _specified_facades: dict[str, Sequence[int]] + bakery_client: Any + usertag: str | None + password: str | None + name: str + __request_id__: int + endpoints: list[tuple[str, str]] | None # Set by juju/controller.py + is_debug_log_connection: bool + monitor: Monitor + proxy: Any # Need to find types for this library + max_frame_size: int + _retries: int + _retry_backoff: float + uuid: str | None @classmethod async def connect( cls, endpoint=None, - uuid=None, - username=None, - password=None, + uuid: str | None = None, + username: str | None = None, + password: str | None = None, cacert=None, bakery_client=None, - max_frame_size=None, + max_frame_size: int | None = None, retries=3, retry_backoff=10, - specified_facades: Optional[ - Dict[str, Dict[Literal["versions"], Sequence[int]]] - ] = None, + specified_facades: dict[str, dict[Literal["versions"], Sequence[int]]] + | None = None, proxy=None, debug_log_conn=None, debug_log_params={}, - ): + ) -> Self: """Connect to the websocket. If uuid is None, the connection will be to the controller. Otherwise it @@ -270,7 +271,7 @@ def ws(self): return self._ws @property - def username(self): + def username(self) -> str | None: if not self.usertag: return None return self.usertag[len("user-") :] @@ -534,7 +535,7 @@ async def _do_ping(): log.debug("ping failed because of closed connection") pass - async def rpc(self, msg, encoder=None): + async def rpc(self, msg: dict, encoder=None) -> dict: """Make an RPC to the API. The message is encoded as JSON using the given encoder if any. :param msg: Parameters for the call (will be encoded as JSON). @@ -744,8 +745,13 @@ async def _try_endpoint(endpoint, cacert, delay): # only executed if inner loop's else did not continue # (i.e., inner loop did break due to successful connection) break + else: + # impossible, work around https://github.com/microsoft/pyright/issues/8791 + assert False # noqa: B011 + for task in tasks: task.cancel() + self._ws = result[0] self.addr = result[1] self.endpoint = result[2] diff --git a/juju/client/connector.py b/juju/client/connector.py index feb650d8c..c9be0cda4 100644 --- a/juju/client/connector.py +++ b/juju/client/connector.py @@ -1,8 +1,10 @@ # Copyright 2023 Canonical Ltd. # Licensed under the Apache V2, see LICENCE file for details. +from __future__ import annotations import copy import logging +from typing import Any import macaroonbakery.httpbakery as httpbakery from packaging import version @@ -33,9 +35,9 @@ class Connector: def __init__( self, - max_frame_size=None, - bakery_client=None, - jujudata=None, + max_frame_size: int | None = None, + bakery_client: Any | None = None, + jujudata: Any | None = None, ): """Initialize a connector that will use the given parameters by default when making a new connection @@ -52,7 +54,7 @@ def is_connected(self): """Report whether there is a currently connected controller or not""" return self._connection is not None - def connection(self): + def connection(self) -> Connection: """Return the current connection; raises an exception if there is no current connection. """ From f830d9fdfeeab418ca9b85bb669811262be4c58b Mon Sep 17 00:00:00 2001 From: Dima Tisnek Date: Fri, 15 Nov 2024 16:06:37 +0900 Subject: [PATCH 02/10] chore: basic, coarse type hints for all watcher deltas --- juju/client/overrides.py | 14 ++++++++++---- juju/delta.py | 14 +++++++++----- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/juju/client/overrides.py b/juju/client/overrides.py index 4b9446a13..1b5692555 100644 --- a/juju/client/overrides.py +++ b/juju/client/overrides.py @@ -1,8 +1,9 @@ # Copyright 2023 Canonical Ltd. # Licensed under the Apache V2, see LICENCE file for details. +from __future__ import annotations import re -from collections import namedtuple +from typing import Any, NamedTuple from . import _client, _definitions from .facade import ReturnMapping, Type, TypeEncoder @@ -22,6 +23,12 @@ ] +class _Change(NamedTuple): + entity: str + type: str + data: dict[str, Any] + + class Delta(Type): """A single websocket delta. @@ -42,12 +49,11 @@ class Delta(Type): _toSchema = {"deltas": "deltas"} _toPy = {"deltas": "deltas"} - def __init__(self, deltas=None): + def __init__(self, deltas: tuple[str, str, dict[str, Any]]): """:param deltas: [str, str, object]""" self.deltas = deltas - Change = namedtuple("Change", "entity type data") - change = Change(*self.deltas) + change = _Change(*self.deltas) self.entity = change.entity self.type = change.type diff --git a/juju/delta.py b/juju/delta.py index b32464e65..c82a20bd5 100644 --- a/juju/delta.py +++ b/juju/delta.py @@ -1,10 +1,12 @@ # Copyright 2023 Canonical Ltd. # Licensed under the Apache V2, see LICENCE file for details. +from __future__ import annotations -from .client import client +from . import model +from .client import client, overrides -def get_entity_delta(d): +def get_entity_delta(d: overrides.Delta): return _delta_types[d.entity](d.deltas) @@ -13,12 +15,14 @@ def get_entity_class(entity_type): class EntityDelta(client.Delta): - def get_id(self): + data: dict[str, str] + + def get_id(self) -> str: return self.data["id"] @classmethod - def get_entity_class(cls): - return None + def get_entity_class(cls) -> type[model.ModelEntity]: + raise NotImplementedError() class ActionDelta(EntityDelta): From a11dfa00cd0d023147c69eb505d85e1225901433 Mon Sep 17 00:00:00 2001 From: Dima Tisnek Date: Fri, 15 Nov 2024 16:17:27 +0900 Subject: [PATCH 03/10] chore: few type hints for model --- juju/model.py | 27 +++++++++++++++------------ juju/tag.py | 27 ++++++++++++++------------- 2 files changed, 29 insertions(+), 25 deletions(-) diff --git a/juju/model.py b/juju/model.py index 12f9edced..77bcead26 100644 --- a/juju/model.py +++ b/juju/model.py @@ -19,6 +19,7 @@ from datetime import datetime, timedelta from functools import partial from pathlib import Path +from typing import Any import websockets import yaml @@ -28,6 +29,7 @@ from .bundle import BundleHandler, get_charm_series, is_local_charm from .charmhub import CharmHub from .client import client, connector +from .client.connection import Connection from .client.overrides import Caveat, Macaroon from .constraints import parse as parse_constraints from .controller import ConnectedController, Controller @@ -257,7 +259,9 @@ def get_entity(self, entity_type, entity_id, history_index=-1, connected=True): class ModelEntity: """An object in the Model tree""" - def __init__(self, entity_id, model, history_index=-1, connected=True): + entity_id: str + + def __init__(self, entity_id: str, model: Model, history_index=-1, connected=True): """Initialize a new entity :param entity_id str: The unique id of the object in the model @@ -279,7 +283,7 @@ def __init__(self, entity_id, model, history_index=-1, connected=True): def __repr__(self): return f'<{type(self).__name__} entity_id="{self.entity_id}">' - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: """Fetch object attributes from the underlying data dict held in the model. @@ -615,7 +619,7 @@ def is_connected(self): """Reports whether the Model is currently connected.""" return self._connector.is_connected() - def connection(self): + def connection(self) -> Connection: """Return the current Connection object. It raises an exception if the Model is disconnected """ @@ -3227,16 +3231,16 @@ def make_archive(self, path): zf.close() return path - def _check_type(self, path): + def _check_type(self, path: str) -> str: """Check the path""" - s = os.stat(str(path)) + s = os.stat(path) if stat.S_ISDIR(s.st_mode) or stat.S_ISREG(s.st_mode): return path raise ValueError( "Invalid Charm at %s %s" % (path, "Invalid file type for a charm") ) - def _check_link(self, path): + def _check_link(self, path: str) -> None: link_path = os.readlink(path) if link_path[0] == "/": raise ValueError( @@ -3249,7 +3253,9 @@ def _check_link(self, path): "Invalid charm at %s %s" % (path, "Only internal symlinks are allowed") ) - def _write_symlink(self, zf, link_target, link_path): + def _write_symlink( + self, zf: zipfile.ZipFile, link_target: str, link_path: str + ) -> None: """Package symlinks with appropriate zipfile metadata.""" info = zipfile.ZipInfo() info.filename = link_path @@ -3259,11 +3265,8 @@ def _write_symlink(self, zf, link_target, link_path): info.external_attr = 2716663808 zf.writestr(info, link_target) - def _ignore(self, path): - if path == "build" or path.startswith("build/"): - return True - if path.startswith("."): - return True + def _ignore(self, path: str) -> bool: + return path == "build" or path.startswith("build/") or path.startswith(".") class ModelInfo(ModelEntity): diff --git a/juju/tag.py b/juju/tag.py index 957710288..5057d6398 100644 --- a/juju/tag.py +++ b/juju/tag.py @@ -1,63 +1,64 @@ # Copyright 2023 Canonical Ltd. # Licensed under the Apache V2, see LICENCE file for details. +from __future__ import annotations # TODO: Tags should be a proper class, so that we can distinguish whether # something is already a tag or not. For example, 'user-foo' is a valid # username, but is ambiguous with the already-tagged username 'foo'. -def _prefix(prefix, s): +def _prefix(prefix: str, s: str) -> str: if s and not s.startswith(prefix): return f"{prefix}{s}" return s -def untag(prefix, s): +def untag(prefix: str, s: str) -> str: if s and s.startswith(prefix): return s[len(prefix) :] return s -def cloud(cloud_name): +def cloud(cloud_name: str) -> str: return _prefix("cloud-", cloud_name) -def controller(controller_uuid): +def controller(controller_uuid: str) -> str: return _prefix("controller-", controller_uuid) -def credential(cloud, user, credential_name): +def credential(cloud: str, user: str, credential_name: str) -> str: credential_string = f"{cloud}_{user}_{credential_name}" return _prefix("cloudcred-", credential_string) -def model(model_uuid): +def model(model_uuid: str) -> str: return _prefix("model-", model_uuid) -def machine(machine_id): +def machine(machine_id: str) -> str: return _prefix("machine-", machine_id) -def user(username): +def user(username: str) -> str: return _prefix("user-", username) -def application(app_name): +def application(app_name: str) -> str: return _prefix("application-", app_name) -def storage(app_name): +def storage(app_name: str) -> str: return _prefix("storage-", app_name) -def unit(unit_name): +def unit(unit_name: str) -> str: return _prefix("unit-", unit_name.replace("/", "-")) -def action(action_uuid): +def action(action_uuid: str) -> str: return _prefix("action-", action_uuid) -def space(space_name): +def space(space_name: str) -> str: return _prefix("space-", space_name) From 1a23a601ecc46c7adb920db47d818d529d2f7eb8 Mon Sep 17 00:00:00 2001 From: Dima Tisnek Date: Fri, 15 Nov 2024 16:32:51 +0900 Subject: [PATCH 04/10] chore: simplify the websocket response queue --- juju/client/connection.py | 3 ++- juju/jasyncio.py | 3 ++- juju/model.py | 4 +-- juju/utils.py | 51 ++++++++++++++++++++++++--------------- 4 files changed, 38 insertions(+), 23 deletions(-) diff --git a/juju/client/connection.py b/juju/client/connection.py index 761bd82cf..085246aa3 100644 --- a/juju/client/connection.py +++ b/juju/client/connection.py @@ -121,6 +121,7 @@ class Connection: _retries: int _retry_backoff: float uuid: str | None + messages: IdQueue @classmethod async def connect( @@ -373,7 +374,7 @@ async def close(self, to_reconnect=False): if self.proxy is not None: self.proxy.close() - async def _recv(self, request_id): + async def _recv(self, request_id: int) -> dict[str, Any]: if not self.is_open: raise websockets.exceptions.ConnectionClosed( websockets.frames.Close( diff --git a/juju/jasyncio.py b/juju/jasyncio.py index 3590e9390..49d499142 100644 --- a/juju/jasyncio.py +++ b/juju/jasyncio.py @@ -21,6 +21,7 @@ ) from asyncio import ( CancelledError, + Task, create_task, wait, ) @@ -84,7 +85,7 @@ ROOT_LOGGER = logging.getLogger() -def create_task_with_handler(coro, task_name, logger=ROOT_LOGGER): +def create_task_with_handler(coro, task_name, logger=ROOT_LOGGER) -> Task: """Wrapper around "asyncio.create_task" to make sure the task exceptions are handled properly. diff --git a/juju/model.py b/juju/model.py index 77bcead26..445b4ab47 100644 --- a/juju/model.py +++ b/juju/model.py @@ -2917,7 +2917,7 @@ async def _get_source_api(self, url): async def wait_for_idle( self, - apps=None, + apps: list[str] | None = None, raise_on_error=True, raise_on_blocked=False, wait_for_active=False, @@ -2927,7 +2927,7 @@ async def wait_for_idle( status=None, wait_for_at_least_units=None, wait_for_exact_units=None, - ): + ) -> None: """Wait for applications in the model to settle into an idle state. :param List[str] apps: Optional list of specific app names to wait on. diff --git a/juju/utils.py b/juju/utils.py index 1692816f1..710fcc56e 100644 --- a/juju/utils.py +++ b/juju/utils.py @@ -1,13 +1,15 @@ # Copyright 2023 Canonical Ltd. # Licensed under the Apache V2, see LICENCE file for details. +from __future__ import annotations +import asyncio import base64 import os import textwrap import zipfile from collections import defaultdict -from functools import partial from pathlib import Path +from typing import Any import yaml from pyasn1.codec.der.encoder import encode @@ -20,11 +22,11 @@ async def execute_process(*cmd, log=None): """Wrapper around asyncio.create_subprocess_exec.""" - p = await jasyncio.create_subprocess_exec( + p = await asyncio.create_subprocess_exec( *cmd, - stdin=jasyncio.subprocess.PIPE, - stdout=jasyncio.subprocess.PIPE, - stderr=jasyncio.subprocess.PIPE, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, ) stdout, stderr = await p.communicate() if log: @@ -84,7 +86,7 @@ async def read_ssh_key(): """Attempt to read the local juju admin's public ssh key, so that it can be passed on to a model. """ - loop = jasyncio.get_running_loop() + loop = asyncio.get_running_loop() return await loop.run_in_executor(None, _read_ssh_key) @@ -93,20 +95,31 @@ class IdQueue: ID. """ - def __init__(self, maxsize=0): - self._queues = defaultdict(partial(jasyncio.Queue, maxsize)) - - async def get(self, id_): + _queues: dict[int, asyncio.Queue[dict[str, Any] | Exception]] + + def __init__(self): + self._queues = defaultdict(asyncio.Queue) + # FIXME cleanup needed. + # in some cases an Exception is put into the queue. + # if the main coro exits, this exception will be logged as "never awaited" + # we gotta do something about that to keep the output clean. + # + # Additionally, it's conceivable that a response is put in the queue + # and then an exception is put via put_all() + # the reader only ever fetches one item, and exception is "never awaited" + # rewrite put_all to replace the pending response instead. + + async def get(self, id_: int) -> dict[str, Any]: value = await self._queues[id_].get() del self._queues[id_] if isinstance(value, Exception): raise value return value - async def put(self, id_, value): + async def put(self, id_: int, value: dict[str, Any]): await self._queues[id_].put(value) - async def put_all(self, value): + async def put_all(self, value: Exception): for queue in self._queues.values(): await queue.put(value) @@ -120,9 +133,9 @@ async def block_until(*conditions, timeout=None, wait_period=0.5): async def _block(): while not all(c() for c in conditions): - await jasyncio.sleep(wait_period) + await asyncio.sleep(wait_period) - await jasyncio.shield(jasyncio.wait_for(_block(), timeout)) + await asyncio.shield(asyncio.wait_for(_block(), timeout)) async def block_until_with_coroutine( @@ -136,12 +149,12 @@ async def block_until_with_coroutine( async def _block(): while not await condition_coroutine(): - await jasyncio.sleep(wait_period) + await asyncio.sleep(wait_period) - await jasyncio.shield(jasyncio.wait_for(_block(), timeout=timeout)) + await asyncio.shield(asyncio.wait_for(_block(), timeout=timeout)) -async def wait_for_bundle(model, bundle, **kwargs): +async def wait_for_bundle(model, bundle: str | Path, **kwargs) -> None: """Helper to wait for just the apps in a specific bundle. Equivalent to loading the bundle, pulling out the app names, and calling:: @@ -156,8 +169,8 @@ async def wait_for_bundle(model, bundle, **kwargs): bundle = bundle_path / "bundle.yaml" except OSError: pass - bundle = yaml.safe_load(textwrap.dedent(bundle).strip()) - apps = list(bundle.get("applications", bundle.get("services")).keys()) + content: dict[str, Any] = yaml.safe_load(textwrap.dedent(bundle).strip()) + apps = list(content.get("applications", content.get("services")).keys()) await model.wait_for_idle(apps, **kwargs) From a399af9cf986fe3dbf50feb9baed44b758f7c89a Mon Sep 17 00:00:00 2001 From: Dima Tisnek Date: Mon, 18 Nov 2024 16:01:50 +0900 Subject: [PATCH 05/10] chore: better types for Type.rpc and Connection.rpc --- juju/client/connection.py | 17 +++++++++++++++-- juju/client/facade.py | 24 +++++++++++++++++------- pyproject.toml | 5 +++-- 3 files changed, 35 insertions(+), 11 deletions(-) diff --git a/juju/client/connection.py b/juju/client/connection.py index 085246aa3..781946068 100644 --- a/juju/client/connection.py +++ b/juju/client/connection.py @@ -16,13 +16,14 @@ import macaroonbakery.httpbakery as httpbakery import websockets from dateutil.parser import parse -from typing_extensions import Self +from typing_extensions import Self, overload from juju import errors, jasyncio, tag, utils from juju.client import client from juju.utils import IdQueue from juju.version import CLIENT_VERSION +from .facade import _JSON, _RICH_JSON, TypeEncoder from .facade_versions import client_facade_versions, known_unsupported_facades LEVELS = ["TRACE", "DEBUG", "INFO", "WARNING", "ERROR"] @@ -536,7 +537,19 @@ async def _do_ping(): log.debug("ping failed because of closed connection") pass - async def rpc(self, msg: dict, encoder=None) -> dict: + @overload + async def rpc( + self, msg: dict[str, _JSON], encoder: None = None + ) -> dict[str, _JSON]: ... + + @overload + async def rpc( + self, msg: dict[str, _RICH_JSON], encoder: TypeEncoder + ) -> dict[str, _JSON]: ... + + async def rpc( + self, msg: dict[str, Any], encoder: json.JSONEncoder | None = None + ) -> dict[str, Any]: """Make an RPC to the API. The message is encoded as JSON using the given encoder if any. :param msg: Parameters for the call (will be encoded as JSON). diff --git a/juju/client/facade.py b/juju/client/facade.py index 08c7a5242..81da9eef7 100644 --- a/juju/client/facade.py +++ b/juju/client/facade.py @@ -1,5 +1,6 @@ # Copyright 2023 Canonical Ltd. # Licensed under the Apache V2, see LICENCE file for details. +from __future__ import annotations import argparse import builtins @@ -13,13 +14,22 @@ from collections import defaultdict from glob import glob from pathlib import Path -from typing import Any, Dict, List, Mapping, Sequence +from typing import Any, Mapping, Sequence import packaging.version import typing_inspect +from typing_extensions import TypeAlias from . import codegen +# Plain JSON, what is received from Juju +_JSON_LEAF: TypeAlias = None | bool | int | float | str +_JSON: TypeAlias = "_JSON_LEAF|list[_JSON]|dict[str, _JSON]" + +# Type-enriched JSON, what can be sent to Juju +_RICH_LEAF: TypeAlias = "_JSON_LEAF|Type" +_RICH_JSON: TypeAlias = "_RICH_LEAF|list[_RICH_JSON]|dict[str, _RICH_JSON]" + _marker = object() JUJU_VERSION = re.compile(r"[0-9]+\.[0-9-]+[\.\-][0-9a-z]+(\.[0-9]+)?") @@ -634,7 +644,7 @@ class {name}Facade(Type): class TypeEncoder(json.JSONEncoder): - def default(self, obj): + def default(self, obj: _RICH_JSON) -> _JSON: if isinstance(obj, Type): return obj.serialize() return json.JSONEncoder.default(self, obj) @@ -653,7 +663,7 @@ def __eq__(self, other): return self.__dict__ == other.__dict__ - async def rpc(self, msg): + async def rpc(self, msg: dict[str, _RICH_JSON]) -> _JSON: result = await self.connection.rpc(msg, encoder=TypeEncoder) return result @@ -704,13 +714,13 @@ def _parse_nested_list_entry(expr, result_dict): return cls(**d) return None - def serialize(self): + def serialize(self) -> dict[str, _JSON]: d = {} for attr, tgt in self._toSchema.items(): d[tgt] = getattr(self, attr) return d - def to_json(self): + def to_json(self) -> str: return json.dumps(self.serialize(), cls=TypeEncoder, sort_keys=True) def __contains__(self, key): @@ -917,8 +927,8 @@ def generate_definitions(schemas): def generate_facades( - schemas: Dict[str, List[Schema]], -) -> Dict[str, Dict[int, codegen.Capture]]: + schemas: dict[str, list[Schema]], +) -> dict[str, dict[int, codegen.Capture]]: captures = defaultdict(codegen.Capture) # Build the Facade classes diff --git a/pyproject.toml b/pyproject.toml index 7b9e36e91..69cce1f7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -212,6 +212,7 @@ ignore = [ [tool.pyright] # These are tentative # include = ["**/*.py"] -pythonVersion = "3.8" # check no python > 3.8 features are used -pythonPlatform = "All" +pythonVersion = "3.10" typeCheckingMode = "strict" +useLibraryCodeForTypes = true +reportGeneralTypeIssues = true From 4b6b1cee12e2ab39cc8a1c821fcffe37f95683bf Mon Sep 17 00:00:00 2001 From: Dima Tisnek Date: Mon, 18 Nov 2024 16:09:07 +0900 Subject: [PATCH 06/10] chore: type alias for readability --- juju/client/connection.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/juju/client/connection.py b/juju/client/connection.py index 781946068..25de320c0 100644 --- a/juju/client/connection.py +++ b/juju/client/connection.py @@ -16,7 +16,7 @@ import macaroonbakery.httpbakery as httpbakery import websockets from dateutil.parser import parse -from typing_extensions import Self, overload +from typing_extensions import Self, TypeAlias, overload from juju import errors, jasyncio, tag, utils from juju.client import client @@ -26,6 +26,8 @@ from .facade import _JSON, _RICH_JSON, TypeEncoder from .facade_versions import client_facade_versions, known_unsupported_facades +SPECIFIED_FACADES: TypeAlias = dict[str, dict[Literal["versions"], Sequence[int]]] + LEVELS = ["TRACE", "DEBUG", "INFO", "WARNING", "ERROR"] log = logging.getLogger("juju.client.connection") @@ -136,8 +138,7 @@ async def connect( max_frame_size: int | None = None, retries=3, retry_backoff=10, - specified_facades: dict[str, dict[Literal["versions"], Sequence[int]]] - | None = None, + specified_facades: SPECIFIED_FACADES | None = None, proxy=None, debug_log_conn=None, debug_log_params={}, From e084a70c0858c60fd2cf3fd74752ffa681e9eed0 Mon Sep 17 00:00:00 2001 From: Dima Tisnek Date: Mon, 18 Nov 2024 16:19:44 +0900 Subject: [PATCH 07/10] chore: connection impl type hint --- juju/client/connection.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/juju/client/connection.py b/juju/client/connection.py index 25de320c0..80a52b207 100644 --- a/juju/client/connection.py +++ b/juju/client/connection.py @@ -27,6 +27,7 @@ from .facade_versions import client_facade_versions, known_unsupported_facades SPECIFIED_FACADES: TypeAlias = dict[str, dict[Literal["versions"], Sequence[int]]] +_WebSocket: TypeAlias = "websockets.legacy.client.WebSocketClientProtocol" LEVELS = ["TRACE", "DEBUG", "INFO", "WARNING", "ERROR"] log = logging.getLogger("juju.client.connection") @@ -125,6 +126,7 @@ class Connection: _retry_backoff: float uuid: str | None messages: IdQueue + _ws: _WebSocket | None @classmethod async def connect( @@ -303,7 +305,7 @@ def _get_ssl(self, cert=None): context.check_hostname = False return context - async def _open(self, endpoint, cacert): + async def _open(self, endpoint, cacert) -> tuple[_WebSocket, str, str, str]: if self.is_debug_log_connection: assert self.uuid url = f"wss://user-{self.username}:{self.password}@{endpoint}/model/{self.uuid}/log" @@ -726,7 +728,9 @@ async def _connect(self, endpoints): if len(endpoints) == 0: raise errors.JujuConnectionError("no endpoints to connect to") - async def _try_endpoint(endpoint, cacert, delay): + async def _try_endpoint( + endpoint, cacert, delay + ) -> tuple[_WebSocket, str, str, str]: if delay: await jasyncio.sleep(delay) return await self._open(endpoint, cacert) @@ -738,6 +742,8 @@ async def _try_endpoint(endpoint, cacert, delay): jasyncio.ensure_future(_try_endpoint(endpoint, cacert, 0.1 * i)) for i, (endpoint, cacert) in enumerate(endpoints) ] + result: tuple[_WebSocket, str, str, str] | None = None + for attempt in range(self._retries + 1): for task in jasyncio.as_completed(tasks): try: @@ -760,13 +766,12 @@ async def _try_endpoint(endpoint, cacert, delay): # only executed if inner loop's else did not continue # (i.e., inner loop did break due to successful connection) break - else: - # impossible, work around https://github.com/microsoft/pyright/issues/8791 - assert False # noqa: B011 for task in tasks: task.cancel() + assert result # loop raises or sets the result + self._ws = result[0] self.addr = result[1] self.endpoint = result[2] From 56bb71c9eaaf1965378101c2ecdd77df06d57a68 Mon Sep 17 00:00:00 2001 From: Dima Tisnek Date: Mon, 18 Nov 2024 16:24:41 +0900 Subject: [PATCH 08/10] chore: low-hanging fruit on __init__ --- juju/model.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/juju/model.py b/juju/model.py index 445b4ab47..ada4ac29b 100644 --- a/juju/model.py +++ b/juju/model.py @@ -261,7 +261,13 @@ class ModelEntity: entity_id: str - def __init__(self, entity_id: str, model: Model, history_index=-1, connected=True): + def __init__( + self, + entity_id: str, + model: Model, + history_index: int = -1, + connected: bool = True, + ): """Initialize a new entity :param entity_id str: The unique id of the object in the model From 6fdef4b1d846281ba3cccee3a6f3f1d22d9ca0f6 Mon Sep 17 00:00:00 2001 From: Dima Tisnek Date: Mon, 18 Nov 2024 16:51:16 +0900 Subject: [PATCH 09/10] chore: stringify type aliases, so that Python 3.8 doesn't get confused --- juju/client/connection.py | 2 +- juju/client/facade.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/juju/client/connection.py b/juju/client/connection.py index 80a52b207..edf029ae8 100644 --- a/juju/client/connection.py +++ b/juju/client/connection.py @@ -26,7 +26,7 @@ from .facade import _JSON, _RICH_JSON, TypeEncoder from .facade_versions import client_facade_versions, known_unsupported_facades -SPECIFIED_FACADES: TypeAlias = dict[str, dict[Literal["versions"], Sequence[int]]] +SPECIFIED_FACADES: TypeAlias = "dict[str, dict[Literal['versions'], Sequence[int]]]" _WebSocket: TypeAlias = "websockets.legacy.client.WebSocketClientProtocol" LEVELS = ["TRACE", "DEBUG", "INFO", "WARNING", "ERROR"] diff --git a/juju/client/facade.py b/juju/client/facade.py index 81da9eef7..f8e4bb57e 100644 --- a/juju/client/facade.py +++ b/juju/client/facade.py @@ -23,7 +23,7 @@ from . import codegen # Plain JSON, what is received from Juju -_JSON_LEAF: TypeAlias = None | bool | int | float | str +_JSON_LEAF: TypeAlias = "None | bool | int | float | str" _JSON: TypeAlias = "_JSON_LEAF|list[_JSON]|dict[str, _JSON]" # Type-enriched JSON, what can be sent to Juju From c7f5c9109be32b721d39e62a1a0f3f8b3fdcfad2 Mon Sep 17 00:00:00 2001 From: Dima Tisnek Date: Tue, 19 Nov 2024 09:21:38 +0900 Subject: [PATCH 10/10] chore: py3.8 setting for pyright, CamelCase type alias --- juju/client/connection.py | 16 ++++++++-------- juju/client/facade.py | 14 +++++++------- pyproject.toml | 4 ++-- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/juju/client/connection.py b/juju/client/connection.py index edf029ae8..e79f2ea7e 100644 --- a/juju/client/connection.py +++ b/juju/client/connection.py @@ -23,10 +23,10 @@ from juju.utils import IdQueue from juju.version import CLIENT_VERSION -from .facade import _JSON, _RICH_JSON, TypeEncoder +from .facade import TypeEncoder, _Json, _RichJson from .facade_versions import client_facade_versions, known_unsupported_facades -SPECIFIED_FACADES: TypeAlias = "dict[str, dict[Literal['versions'], Sequence[int]]]" +SpecifiedFacades: TypeAlias = "dict[str, dict[Literal['versions'], Sequence[int]]]" _WebSocket: TypeAlias = "websockets.legacy.client.WebSocketClientProtocol" LEVELS = ["TRACE", "DEBUG", "INFO", "WARNING", "ERROR"] @@ -140,7 +140,7 @@ async def connect( max_frame_size: int | None = None, retries=3, retry_backoff=10, - specified_facades: SPECIFIED_FACADES | None = None, + specified_facades: SpecifiedFacades | None = None, proxy=None, debug_log_conn=None, debug_log_params={}, @@ -542,17 +542,17 @@ async def _do_ping(): @overload async def rpc( - self, msg: dict[str, _JSON], encoder: None = None - ) -> dict[str, _JSON]: ... + self, msg: dict[str, _Json], encoder: None = None + ) -> dict[str, _Json]: ... @overload async def rpc( - self, msg: dict[str, _RICH_JSON], encoder: TypeEncoder - ) -> dict[str, _JSON]: ... + self, msg: dict[str, _RichJson], encoder: TypeEncoder + ) -> dict[str, _Json]: ... async def rpc( self, msg: dict[str, Any], encoder: json.JSONEncoder | None = None - ) -> dict[str, Any]: + ) -> dict[str, _Json]: """Make an RPC to the API. The message is encoded as JSON using the given encoder if any. :param msg: Parameters for the call (will be encoded as JSON). diff --git a/juju/client/facade.py b/juju/client/facade.py index f8e4bb57e..596f97aba 100644 --- a/juju/client/facade.py +++ b/juju/client/facade.py @@ -23,12 +23,12 @@ from . import codegen # Plain JSON, what is received from Juju -_JSON_LEAF: TypeAlias = "None | bool | int | float | str" -_JSON: TypeAlias = "_JSON_LEAF|list[_JSON]|dict[str, _JSON]" +_JsonLeaf: TypeAlias = "None | bool | int | float | str" +_Json: TypeAlias = "_JsonLeaf|list[_Json]|dict[str, _Json]" # Type-enriched JSON, what can be sent to Juju -_RICH_LEAF: TypeAlias = "_JSON_LEAF|Type" -_RICH_JSON: TypeAlias = "_RICH_LEAF|list[_RICH_JSON]|dict[str, _RICH_JSON]" +_RichLeaf: TypeAlias = "_JsonLeaf|Type" +_RichJson: TypeAlias = "_RichLeaf|list[_RichJson]|dict[str, _RichJson]" _marker = object() @@ -644,7 +644,7 @@ class {name}Facade(Type): class TypeEncoder(json.JSONEncoder): - def default(self, obj: _RICH_JSON) -> _JSON: + def default(self, obj: _RichJson) -> _Json: if isinstance(obj, Type): return obj.serialize() return json.JSONEncoder.default(self, obj) @@ -663,7 +663,7 @@ def __eq__(self, other): return self.__dict__ == other.__dict__ - async def rpc(self, msg: dict[str, _RICH_JSON]) -> _JSON: + async def rpc(self, msg: dict[str, _RichJson]) -> _Json: result = await self.connection.rpc(msg, encoder=TypeEncoder) return result @@ -714,7 +714,7 @@ def _parse_nested_list_entry(expr, result_dict): return cls(**d) return None - def serialize(self) -> dict[str, _JSON]: + def serialize(self) -> dict[str, _Json]: d = {} for attr, tgt in self._toSchema.items(): d[tgt] = getattr(self, attr) diff --git a/pyproject.toml b/pyproject.toml index 69cce1f7b..e55113be8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -211,8 +211,8 @@ ignore = [ [tool.pyright] # These are tentative -# include = ["**/*.py"] -pythonVersion = "3.10" +include = ["**/*.py"] +pythonVersion = "3.8" typeCheckingMode = "strict" useLibraryCodeForTypes = true reportGeneralTypeIssues = true