From af17320f248ed85ca8eb9703c17af293483a991f Mon Sep 17 00:00:00 2001 From: Jochem Rutgers <68805714+jhrutgers@users.noreply.github.com> Date: Fri, 5 Dec 2025 21:13:53 +0100 Subject: [PATCH 01/15] bump version --- version/version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version/version.txt b/version/version.txt index 7ec1d6db..024a3a24 100644 --- a/version/version.txt +++ b/version/version.txt @@ -1 +1 @@ -2.1.0 +2.1.1-alpha From d1a2460dab1911245f8ec4039b991dd6483f1604 Mon Sep 17 00:00:00 2001 From: Jochem Rutgers <68805714+jhrutgers@users.noreply.github.com> Date: Mon, 8 Dec 2025 17:50:05 +0100 Subject: [PATCH 02/15] improve connect/disconnect handling in protocol stack --- CHANGELOG.rst | 5 ++ python/libstored/asyncio/zmq.py | 15 +++-- python/libstored/protocol/protocol.py | 89 +++++++++++++++++++++++++-- python/libstored/protocol/zmq.py | 25 ++++---- 4 files changed, 111 insertions(+), 23 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 847925d9..5f574c86 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -26,6 +26,11 @@ Added ... +Fixed +````` + +- ``ZmqClient`` assertion during cleanup + .. _Unreleased: https://github.com/DEMCON/libstored/compare/v2.1.0...HEAD diff --git a/python/libstored/asyncio/zmq.py b/python/libstored/asyncio/zmq.py index e8bf2f76..45a00f00 100644 --- a/python/libstored/asyncio/zmq.py +++ b/python/libstored/asyncio/zmq.py @@ -1487,6 +1487,7 @@ async def connect(self, host : str | None=None, port : int | None=None, \ await self.req(f'm{m}') self.connected.trigger() + await self._stack.connected() if not default_state: await self.restore_state() @@ -1555,6 +1556,7 @@ async def disconnect(self): self.logger.debug('disconnect') self.disconnecting.trigger() + await self._stack.disconnected() await self.save_state() self._socket = None @@ -1581,12 +1583,13 @@ def close(self, *, block : typing.Literal[False], sync : typing.Literal[True]) - @Work.run_sync async def close(self): - '''Alias for disconnect().''' + '''Disconnect and release resources.''' await self.disconnect() + await self._stack.close() def __del__(self): - if self.is_connected(): - self.disconnect(sync=True) + if self.is_connected() or not self._stack.is_closed(): + self.close(sync=True) def sync(self) -> SyncZmqClient: s = SyncZmqClient(self) @@ -1597,15 +1600,15 @@ def __enter__(self): return self.sync() def __exit__(self, *args): - if self.is_connected(): - self.disconnect(sync=True) + if self.is_connected() or not self._stack.is_closed(): + self.close(sync=True) async def __aenter__(self): await self.connect() return self async def __aexit__(self, *args): - await self.disconnect() + await self.close() diff --git a/python/libstored/protocol/protocol.py b/python/libstored/protocol/protocol.py index 9b62979e..ed795a80 100644 --- a/python/libstored/protocol/protocol.py +++ b/python/libstored/protocol/protocol.py @@ -56,14 +56,15 @@ class ProtocolLayer: def __init__(self, *args, **kwargs): self._closed : bool = False + self._connected : bool = True super().__init__(*args, **kwargs) self.logger = logging.getLogger(self.__class__.__name__) self._down : ProtocolLayer | None = None self._up : ProtocolLayer | None = None - self._down_callback : ProtocolLayer.AsyncCallback = callback_factory(None) - self._up_callback : ProtocolLayer.AsyncCallback = callback_factory(None) + self._encode_callback : ProtocolLayer.AsyncCallback = callback_factory(None) + self._decode_callback : ProtocolLayer.AsyncCallback = callback_factory(None) self._activity : float = 0 self._async_except_hook = callback_factory(self.default_async_except_hook) @@ -83,7 +84,7 @@ def up(self, cb : ProtocolLayer.Callback | None) -> None: ''' Set a callback to be called when data is received from the lower layer. ''' - self._up_callback = callback_factory(cb) + self._decode_callback = callback_factory(cb) @property def down(self) -> ProtocolLayer | None: @@ -94,7 +95,32 @@ def down(self, cb : ProtocolLayer.Callback | None) -> None: ''' Set a callback to be called when data is received from the upper layer. ''' - self._down_callback = callback_factory(cb) + self._encode_callback = callback_factory(cb) + + async def connected(self) -> None: + ''' + Called when the connection is (re)connected. + ''' + self._connected = True + if self.up is not None: + await self.up.connected() + + async def disconnected(self) -> None: + ''' + Called when the connection is disconnected. + ''' + if not self.is_connected(): + return + + self._connected = False + if self.up is not None: + await self.up.disconnected() + + def is_connected(self) -> bool: + ''' + Return whether the connection is currently connected. + ''' + return self._connected async def encode(self, data : ProtocolLayer.Packet) -> None: ''' @@ -102,7 +128,7 @@ async def encode(self, data : ProtocolLayer.Packet) -> None: ''' self.activity() - await self._down_callback(data) + await self._encode_callback(data) if self.down is not None: await self.down.encode(data) @@ -113,7 +139,7 @@ async def decode(self, data : ProtocolLayer.Packet) -> None: ''' self.activity() - await self._up_callback(data) + await self._decode_callback(data) if self.up is not None: await self.up.decode(data) @@ -155,7 +181,14 @@ def last_activity(self) -> float: async def close(self) -> None: ''' Close the layer and release resources. + + Closing cannot be undone. ''' + if self._closed: + return + + await self.disconnected() + self._closed = True if self.down is not None: try: @@ -163,6 +196,12 @@ async def close(self) -> None: except BaseException as e: self.logger.warning(f'Exception while closing: {e}') + def is_closed(self) -> bool: + ''' + Return whether the layer is closed. + ''' + return self._closed + async def __aenter__(self): return self @@ -363,6 +402,10 @@ async def decode(self, data : ProtocolLayer.Packet) -> None: self._inMsg = False await super().decode(msg) + async def disconnected(self) -> None: + self._inMsg = False + await super().disconnected() + @property def mtu(self) -> int | None: m = super().mtu @@ -490,6 +533,10 @@ async def decode(self, data : ProtocolLayer.Packet) -> None: self._req = False await super().decode(data) + async def disconnected(self) -> None: + await super().disconnected() + self._req = False + class SegmentationLayer(ProtocolLayer): @@ -551,6 +598,10 @@ async def timeout(self) -> None: self._buffer = bytearray() await super().timeout() + async def disconnected(self) -> None: + await super().disconnected() + self._buffer = bytearray() + class DebugArqLayer(ProtocolLayer): @@ -695,6 +746,10 @@ def reset(self) -> None: self._reset = True self._request = [] + async def connected(self) -> None: + self.reset() + await super().connected() + async def retransmit(self) -> None: self.logger.debug('retransmit') if not self._req: @@ -894,6 +949,14 @@ async def close(self) -> None: await super().close() + async def connected(self) -> None: + await self._layers[-1].connected() + await super().connected() + + async def disconnected(self) -> None: + await self._layers[-1].disconnected() + await super().disconnected() + @property def mtu(self) -> int | None: return self._layers[0].mtu @@ -1106,6 +1169,15 @@ async def timeout(self) -> None: self._prev = None await super().timeout() + async def connected(self) -> None: + self._prev = None + await super().connected() + + async def disconnected(self) -> None: + self._decoding = None + self._decoding_esc = False + await super().disconnected() + class Aes256Layer(ProtocolLayer): @@ -1273,6 +1345,11 @@ def _iv(self, unified : bool) -> bytes: else: return b'B' + iv + async def connected(self) -> None: + self._encrypt = None + self._decrypt = None + await super().connected() + layer_types : list[typing.Type[ProtocolLayer]] = [ diff --git a/python/libstored/protocol/zmq.py b/python/libstored/protocol/zmq.py index 2865e08a..8c905ab2 100644 --- a/python/libstored/protocol/zmq.py +++ b/python/libstored/protocol/zmq.py @@ -61,12 +61,15 @@ def socket(self) -> zmq.asyncio.Socket: raise RuntimeError('ZMQ socket is closed') return self._socket - def mark_open(self) -> None: + async def mark_open(self) -> None: + if self.open: + return self._open = True + await self.connected() @property def open(self) -> bool: - return self._open + return self._open and self.is_connected() async def _recv_task(self) -> None: try: @@ -78,7 +81,7 @@ async def _recv_task(self) -> None: x = b''.join(await socket.recv_multipart()) if self.logger.getEffectiveLevel() <= logging.DEBUG: self.logger.debug(f'recv {x}') - self.mark_open() + await self.mark_open() await self._handle_recv(x) except asyncio.CancelledError: pass @@ -105,10 +108,9 @@ async def close(self) -> None: self._socket.close() self._socket = None - self.disconnected() await super().close() - def _check_sent(self) -> None: + async def _check_sent(self) -> None: if self._timeout_s is None: t = None else: @@ -129,14 +131,15 @@ def _check_sent(self) -> None: if self.open: self.logger.info('connection timed out') - self.disconnected() + await self.disconnected() return - def disconnected(self) -> None: + async def disconnected(self) -> None: self._open = False for f, _ in self._sent: f.cancel() self._sent = [] + await super().disconnected() async def _send(self, data : lprot.ProtocolLayer.Packet) -> None: if isinstance(data, str): @@ -144,7 +147,7 @@ async def _send(self, data : lprot.ProtocolLayer.Packet) -> None: elif isinstance(data, memoryview): data = data.cast('B') - self._check_sent() + await self._check_sent() if self.open: if self.logger.getEffectiveLevel() <= logging.DEBUG: @@ -210,7 +213,7 @@ async def _handle_recv(self, data : bytes) -> None: async def _recv_init(self) -> None: # Indicate that we are connected. - self.mark_open() + await self.mark_open() await self._send(b'') async def encode(self, data : lprot.ProtocolLayer.Packet) -> None: @@ -315,8 +318,8 @@ async def decode(self, data : lprot.ProtocolLayer.Packet) -> None: self._req = False await super().decode(data) - def disconnected(self) -> None: - super().disconnected() + async def disconnected(self) -> None: + await super().disconnected() self._req = False lprot.register_layer_type(ZmqServer) From 0bdf59543313e9c0a0aebd490b9eeca41d12d21f Mon Sep 17 00:00:00 2001 From: Jochem Rutgers <68805714+jhrutgers@users.noreply.github.com> Date: Thu, 11 Dec 2025 17:18:58 +0100 Subject: [PATCH 03/15] allow forcing OS --- include/libstored/macros.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/include/libstored/macros.h b/include/libstored/macros.h index 24603773..3be509f6 100644 --- a/include/libstored/macros.h +++ b/include/libstored/macros.h @@ -141,7 +141,9 @@ typedef SSIZE_T ssize_t; // Platform // -#ifdef __ZEPHYR__ +#if defined(STORED_OS_BAREMETAL) || defined(STORED_OS_GENERIC) +// Accept pre-defined setup. +#elif defined(__ZEPHYR__) # define STORED_OS_BAREMETAL 1 # include // By default, turn off; picolibc does not provide it by default. From aeed978b821073b147991e2f9d8dad92a8cd9a86 Mon Sep 17 00:00:00 2001 From: Jochem Rutgers <68805714+jhrutgers@users.noreply.github.com> Date: Wed, 17 Dec 2025 14:53:40 +0100 Subject: [PATCH 04/15] add ArqLayer --- CHANGELOG.rst | 2 +- python/libstored/protocol/protocol.py | 267 +++++++++++++++++++++++++- python/libstored/wrapper/serial.py | 2 +- python/libstored/wrapper/stdio.py | 2 +- sphinx/doc/py_py.rst | 34 ++-- version/version.txt | 2 +- 6 files changed, 284 insertions(+), 25 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 5f574c86..50a2d0ed 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -24,7 +24,7 @@ The format is based on `Keep a Changelog`_, and this project adheres to Added ````` -... +- ``libstored.protocol.ArqLayer`` for general-purpose ARQ. Fixed ````` diff --git a/python/libstored/protocol/protocol.py b/python/libstored/protocol/protocol.py index ed795a80..4d179f13 100644 --- a/python/libstored/protocol/protocol.py +++ b/python/libstored/protocol/protocol.py @@ -5,6 +5,7 @@ from __future__ import annotations import asyncio +from contourpy.util import data import crcmod import Crypto.Cipher.AES import Crypto.Random @@ -451,13 +452,13 @@ async def non_debug_data(self, data : ProtocolLayer.Packet) -> None: await self._socket.send(data) -class RepReqCheckLayer(ProtocolLayer): +class ReqRepCheckLayer(ProtocolLayer): ''' A ProtocolLayer that checks that requests and replies are matched. It triggers timeout() when a reply is not received in time. ''' - name = 'repreqcheck' + name = 'reqrepcheck' def __init__(self, timeout_s : float = 1, *args, **kwargs): super().__init__(*args, **kwargs) @@ -519,7 +520,7 @@ async def close(self) -> None: async def encode(self, data : ProtocolLayer.Packet) -> None: if self._req: - raise RuntimeError('RepReqCheckLayer encode called while previous request not yet handled') + raise RuntimeError('ReqRepCheckLayer encode called while previous request not yet handled') self._req = True self._retransmit_time = time.time() + self._timeout_s @@ -612,7 +613,7 @@ class DebugArqLayer(ProtocolLayer): name = 'arq' reset_flag = 0x80 - def __init__(self, timeout_s : float = 1, *args, **kwargs): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._req : bool = False self._request : list[bytes] = [] @@ -778,6 +779,261 @@ def mtu(self) -> int | None: +class ArqLayer(ProtocolLayer): + ''' + A ProtocolLayer that implements a general-purpose ARQ protocol. + ''' + + name = 'Arq' + + nop_flag = 0x40 + ack_flag = 0x80 + seq_mask = 0x3f + + def __init__(self, timeout_s : float | None=None, *args, keep_alive_s : float | None=None, **kwargs): + super().__init__(*args, **kwargs) + self._encode_lock : asyncio.Lock = asyncio.Lock() + self._retransmitter : asyncio.Task | None = None + self._keep_alive : asyncio.Task | None = None + self._timeout_s : float | None = None + self._keep_alive_s : float | None = None + self._reset() + self.timeout_s = timeout_s + self.keep_alive_s = keep_alive_s + + @property + def timeout_s(self) -> float | None: + return self._timeout_s + + @timeout_s.setter + def timeout_s(self, value : float | None) -> None: + self._timeout_s = value + + if value is None and self._retransmitter is not None: + self._retransmitter.cancel() + self._retransmitter = None + elif value is not None and self._retransmitter is None: + self._retransmitter = asyncio.create_task(self._retransmitter_task(), name=self.__class__.__name__) + + @property + def keep_alive_s(self) -> float | None: + return self._keep_alive_s + + @keep_alive_s.setter + def keep_alive_s(self, value : float | None) -> None: + self._keep_alive_s = value + + if value is not None and self.timeout_s is None: + self.timeout_s = value + + if value is None and self._keep_alive is not None: + self._keep_alive.cancel() + self._keep_alive = None + elif value is not None and self._keep_alive is None: + self._keep_alive = asyncio.create_task(self._keep_alive_task(), name=self.__class__.__name__) + + def _reset(self) -> None: + self._encode_queue : list[bytes] = [bytes([self.nop_flag])] + self._send_seq : int = self._next_seq(0) + self._recv_seq : int = 0 + self._sent : bool = False + self._pause_transmit : bool = False + self._t_sent : float = time.time() + + async def decode(self, data : ProtocolLayer.Packet) -> None: + if isinstance(data, str): + data = data.encode() + if isinstance(data, memoryview): + data = data.cast('B') + else: + data = memoryview(data).cast('B') + + resp = b'' + reset_handshake = False + do_transmit = False + do_decode = False + + assert not self._pause_transmit + + while len(data) > 0: + hdr = data[0] + hdr_seq = hdr & self.seq_mask + + if hdr & self.ack_flag: + if hdr_seq == 0: + reset_handshake = True + + if self.waiting_for_ack and hdr_seq == (self._encode_queue[0][0] & self.seq_mask): + # Ack received for sent data. + self._encode_queue.pop(0) + do_transmit = True + + if reset_handshake: + self._recv_seq = self._next_seq(0) + await super().connected() + + data = data[1:] + elif hdr_seq == 0: + # Reset handshake. + resp += bytes([self.ack_flag]) + data = b'' + + if not reset_handshake: + self._reset() + do_transmit = True + await super().disconnected() + elif hdr_seq == self._recv_seq: + # Next message. + resp += bytes([self.ack_flag | hdr_seq]) + self._recv_seq = self._next_seq(self._recv_seq) + do_decode = not (hdr & self.nop_flag) + do_transmit = True + data = data[1:] + elif self._next_seq(hdr_seq) == self._recv_seq: + # Duplicate message, re-ack it. + resp += bytes([self.ack_flag | hdr_seq]) + if hdr & self.nop_flag: + data = data[1:] + else: + # Already decoded. + data = b'' + else: + # Drop. + data = b'' + do_transmit = True + + if do_decode: + break + + if do_decode: + self._pause_transmit = True + try: + await super().decode(data) + finally: + self._pause_transmit = False + + if do_transmit or len(resp) > 0: + await self._transmit(resp) + + @property + def waiting_for_ack(self) -> bool: + return len(self._encode_queue) > 0 and self._sent + + async def encode(self, data : ProtocolLayer.Packet) -> None: + if len(data) == 0: + return + + is_idle = not self.waiting_for_ack + self._push_encode_queue(data) + if is_idle and not self._pause_transmit: + await self._transmit() + + def _push_encode_queue(self, data : ProtocolLayer.Packet) -> None: + if isinstance(data, str): + data = data.encode() + elif isinstance(data, memoryview): + data = data.cast('B') + + self._encode_queue.append(bytes([self._send_seq]) + data) + self._send_seq = self._next_seq(self._send_seq) + + def _next_seq(self, seq : int) -> int: + seq = (seq + 1) & self.seq_mask + if seq == 0: + seq = 1 + return seq + + async def _transmit(self, prefix : bytes = b'') -> bool: + async with self._encode_lock: + self._t_sent = time.time() + + if len(self._encode_queue) == 0: + if prefix == b'': + return False + await super().encode(prefix) + return True + + self._sent = True + assert self.waiting_for_ack + await super().encode(prefix + self._encode_queue[0]) + return True + + async def connected(self) -> None: + self._reset() + await self._transmit() + await super().connected() + + async def retransmit(self) -> None: + self.logger.debug('retransmit') + await self._transmit() + + async def timeout(self) -> None: + if self.waiting_for_ack: + await self.retransmit() + else: + await self.keep_alive() + + @property + def mtu(self) -> int | None: + m = super().mtu + if m is None or m <= 0: + return None + return max(1, m - 1) + + async def _retransmitter_task(self) -> None: + try: + while True: + if self._timeout_s is None: + return + + if not self.waiting_for_ack and self._sent: + await asyncio.sleep(self._timeout_s) + else: + dt = time.time() - self._t_sent + t_rem = self._timeout_s - dt + if t_rem <= 0: + await self.retransmit() + else: + await asyncio.sleep(t_rem) + except asyncio.CancelledError: + pass + except Exception as e: + await self.async_except(e) + raise + + async def keep_alive(self) -> None: + if self.waiting_for_ack: + return + + if len(self._encode_queue) > 0 and not self._sent: + await self._transmit() + return + + self.logger.debug('keep alive') + self._encode_queue.append(bytes([self._send_seq | self.nop_flag])) + self._send_seq = self._next_seq(self._send_seq) + await self._transmit() + + async def _keep_alive_task(self) -> None: + try: + while True: + if self._keep_alive_s is None: + return + + dt = time.time() - self._t_sent + t_rem = self._keep_alive_s - dt + if t_rem <= 0: + await self.keep_alive() + else: + await asyncio.sleep(t_rem) + except asyncio.CancelledError: + pass + except Exception as e: + await self.async_except(e) + raise + + + class Crc8Layer(ProtocolLayer): ''' ProtocolLayer to add and check integrity using a CRC8. @@ -1356,9 +1612,10 @@ async def connected(self) -> None: AsciiEscapeLayer, TerminalLayer, PubTerminalLayer, - RepReqCheckLayer, + ReqRepCheckLayer, SegmentationLayer, DebugArqLayer, + ArqLayer, Crc8Layer, Crc16Layer, Crc32Layer, diff --git a/python/libstored/wrapper/serial.py b/python/libstored/wrapper/serial.py index d9b59dd5..76b79e55 100644 --- a/python/libstored/wrapper/serial.py +++ b/python/libstored/wrapper/serial.py @@ -44,7 +44,7 @@ async def async_main(args : argparse.Namespace): stack = lprot.build_stack( ','.join([ f'zmq={args.zmqlisten}:{args.zmqport}', - 'repreqcheck', + 'reqrepcheck', re.sub(r'\bpubterm\b(,|$)', f'pubterm={args.zmqlisten}:{args.zmqport+1}\\1', args.stack)]) ) diff --git a/python/libstored/wrapper/stdio.py b/python/libstored/wrapper/stdio.py index b3d92baf..28efea89 100644 --- a/python/libstored/wrapper/stdio.py +++ b/python/libstored/wrapper/stdio.py @@ -43,7 +43,7 @@ async def async_main(args : argparse.Namespace): stack = lprot.build_stack( ','.join([ f'zmq={args.listen}:{args.port}', - 'repreqcheck', + 'reqrepcheck', re.sub(r'\bpubterm\b(,|$)', f'pubterm={args.listen}:{args.port+1}\\1', args.stack)]) ) diff --git a/sphinx/doc/py_py.rst b/sphinx/doc/py_py.rst index c4d0a979..fb3310ea 100644 --- a/sphinx/doc/py_py.rst +++ b/sphinx/doc/py_py.rst @@ -24,46 +24,48 @@ Protocol layers .. autoclass:: libstored.protocol.Aes256Layer -.. autoclass:: libstored.protocol.AsciiEscapeLayer +.. autoclass:: libstored.protocol.ArqLayer -.. autoclass:: libstored.protocol.TerminalLayer +.. autoclass:: libstored.protocol.AsciiEscapeLayer -.. autoclass:: libstored.protocol.PubTerminalLayer +.. autoclass:: libstored.protocol.Crc16Layer -.. autoclass:: libstored.protocol.RepReqCheckLayer +.. autoclass:: libstored.protocol.Crc32Layer -.. autoclass:: libstored.protocol.SegmentationLayer +.. autoclass:: libstored.protocol.Crc8Layer .. autoclass:: libstored.protocol.DebugArqLayer -.. autoclass:: libstored.protocol.Crc8Layer +.. autoclass:: libstored.protocol.FileLayer -.. autoclass:: libstored.protocol.Crc16Layer +.. autoclass:: libstored.protocol.LoopbackLayer -.. autoclass:: libstored.protocol.Crc32Layer +.. autoclass:: libstored.protocol.MuxLayer -.. autoclass:: libstored.protocol.LoopbackLayer +.. autoclass:: libstored.protocol.PrintLayer + +.. autoclass:: libstored.protocol.PubTerminalLayer .. autoclass:: libstored.protocol.RawLayer -.. autoclass:: libstored.protocol.MuxLayer +.. autoclass:: libstored.protocol.ReqRepCheckLayer + +.. autoclass:: libstored.protocol.SegmentationLayer + +.. autoclass:: libstored.protocol.SerialLayer .. autoclass:: libstored.protocol.StdinLayer .. autoclass:: libstored.protocol.StdioLayer -.. autoclass:: libstored.protocol.PrintLayer - -.. autoclass:: libstored.protocol.SerialLayer +.. autoclass:: libstored.protocol.TerminalLayer -.. autoclass:: libstored.protocol.FileLayer +.. autoclass:: libstored.protocol.ZmqServer .. autoclass:: libstored.protocol.ZmqSocketClient .. autoclass:: libstored.protocol.ZmqSocketServer -.. autoclass:: libstored.protocol.ZmqServer - Protocol stack -------------- diff --git a/version/version.txt b/version/version.txt index 024a3a24..a74d18b8 100644 --- a/version/version.txt +++ b/version/version.txt @@ -1 +1 @@ -2.1.1-alpha +2.2.0-alpha From 3589dc6703b3df5414e3220521b5d9213815b4c1 Mon Sep 17 00:00:00 2001 From: Jochem Rutgers <68805714+jhrutgers@users.noreply.github.com> Date: Wed, 17 Dec 2025 15:30:22 +0100 Subject: [PATCH 05/15] add black --- .vscode/extensions.json | 3 ++- .vscode/settings.json | 5 ++++- python/pyproject.toml | 3 +++ 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/.vscode/extensions.json b/.vscode/extensions.json index 10a1f6b9..38dee031 100644 --- a/.vscode/extensions.json +++ b/.vscode/extensions.json @@ -10,6 +10,7 @@ "streetsidesoftware.code-spell-checker", "discretegames.f5anything", "ms-vscode.cmake-tools", - "spadin.memento-inputs" + "spadin.memento-inputs", + "ms-python.black-formatter" ] } \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index 3ff625b1..18290749 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -35,4 +35,7 @@ "python.defaultInterpreterPath": "${workspaceFolder}/dist/venv/bin/python3", "vim.textwidth": 100, "vim.tabstop": 8, -} + "[python]": { + "editor.defaultFormatter": "ms-python.black-formatter", + } +} \ No newline at end of file diff --git a/python/pyproject.toml b/python/pyproject.toml index e3bc0318..af90604c 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -5,3 +5,6 @@ [build-system] requires = ["setuptools>=51", "wheel"] build-backend = "setuptools.build_meta" + +[tool.black] +line-length = 100 From df7556055067e72a15a253a1a7aee7671655810a Mon Sep 17 00:00:00 2001 From: Jochem Rutgers <68805714+jhrutgers@users.noreply.github.com> Date: Wed, 17 Dec 2025 15:43:40 +0100 Subject: [PATCH 06/15] apply formatting --- python/libstored/asyncio/csv.py | 303 ++-- python/libstored/asyncio/event.py | 60 +- python/libstored/asyncio/tk.py | 269 ++-- python/libstored/asyncio/worker.py | 151 +- python/libstored/asyncio/zmq.py | 1725 +++++++++++++---------- python/libstored/cli/__main__.py | 43 +- python/libstored/cmake/__main__.py | 79 +- python/libstored/exceptions.py | 41 +- python/libstored/generator/__main__.py | 574 ++++---- python/libstored/generator/dsl/types.py | 455 ++++-- python/libstored/gui/__main__.py | 770 ++++++---- python/libstored/heatshrink.py | 63 +- python/libstored/log/__main__.py | 127 +- python/libstored/protocol/file.py | 48 +- python/libstored/protocol/protocol.py | 781 +++++----- python/libstored/protocol/serial.py | 60 +- python/libstored/protocol/stdio.py | 145 +- python/libstored/protocol/util.py | 114 +- python/libstored/protocol/zmq.py | 197 ++- python/libstored/protocol/zmqcat.py | 48 +- python/libstored/tk.py | 61 +- python/libstored/wrapper/serial.py | 78 +- python/libstored/wrapper/stdio.py | 56 +- 23 files changed, 3717 insertions(+), 2531 deletions(-) diff --git a/python/libstored/asyncio/csv.py b/python/libstored/asyncio/csv.py index bb82852a..749638a0 100644 --- a/python/libstored/asyncio/csv.py +++ b/python/libstored/asyncio/csv.py @@ -17,25 +17,49 @@ from . import worker as laio_worker from . import zmq as laio_zmq + @overload -def generate_filename(filename : str | None=None, *, - add_timestamp : bool=False, ext : str='.csv', now : time.struct_time | float | None=None, - unique : bool=False) -> str: ... +def generate_filename( + filename: str | None = None, + *, + add_timestamp: bool = False, + ext: str = ".csv", + now: time.struct_time | float | None = None, + unique: bool = False, +) -> str: ... @overload -def generate_filename(*, base : str, - add_timestamp : bool=False, ext : str='.csv', now : time.struct_time | float | None=None, - unique : bool=False) -> str: ... +def generate_filename( + *, + base: str, + add_timestamp: bool = False, + ext: str = ".csv", + now: time.struct_time | float | None = None, + unique: bool = False, +) -> str: ... @overload -def generate_filename(filename : list[str] | str | None=None, *, base : list[str] | str | None=None, - add_timestamp : bool=False, ext : list[str] | str='.csv', now : time.struct_time | float | None=None, - unique : bool=False) -> str | list[str]: ... - -def generate_filename(filename : list[str] | str | None=None, *, base : list[str] | str | None=None, - add_timestamp : bool=False, ext : list[str] | str='.csv', now : time.struct_time | float | None=None, - unique : bool=False) -> str | list[str]: +def generate_filename( + filename: list[str] | str | None = None, + *, + base: list[str] | str | None = None, + add_timestamp: bool = False, + ext: list[str] | str = ".csv", + now: time.struct_time | float | None = None, + unique: bool = False, +) -> str | list[str]: ... + + +def generate_filename( + filename: list[str] | str | None = None, + *, + base: list[str] | str | None = None, + add_timestamp: bool = False, + ext: list[str] | str = ".csv", + now: time.struct_time | float | None = None, + unique: bool = False, +) -> str | list[str]: if filename is None and base is None: - raise ValueError('Specify filename and/or base') + raise ValueError("Specify filename and/or base") return_list = False @@ -84,7 +108,7 @@ def generate_filename(filename : list[str] | str | None=None, *, base : list[str # Append timestamps to the generated bases. if add_timestamp: for i in range(0, len(names)): - names[i] = (names[i][0] + '_%Y%m%dT%H%M%S%z', names[i][1]) + names[i] = (names[i][0] + "_%Y%m%dT%H%M%S%z", names[i][1]) # Time-format collected bases. for i in range(0, len(names)): @@ -94,7 +118,7 @@ def generate_filename(filename : list[str] | str | None=None, *, base : list[str if unique: for i in range(0, len(names)): suffix_nr = 1 - suffix = '' + suffix = "" n = names[i][0] + names[i][1] while True: # Check if the file already exists. @@ -115,7 +139,7 @@ def generate_filename(filename : list[str] | str | None=None, *, base : list[str # Pick another suffix and retry. suffix_nr += 1 - suffix = f'_{suffix_nr}' + suffix = f"_{suffix_nr}" n = names[i][0] + suffix + names[i][1] # Combine bases/exts. @@ -129,15 +153,23 @@ def generate_filename(filename : list[str] | str | None=None, *, base : list[str else: return names + class CsvExport(laio_worker.Work): - ''' + """ asyncio csv exporter via AsyncioWorker. - ''' - - def __init__(self, filename : str = 'out.csv', *, - auto_write : float | None=None, write_on_change : bool=True, auto_flush : float | None=1.0, - worker : laio_worker.AsyncioWorker | None=None, logger : logging.Logger | None=None, - **fmtargs): + """ + + def __init__( + self, + filename: str = "out.csv", + *, + auto_write: float | None = None, + write_on_change: bool = True, + auto_flush: float | None = 1.0, + worker: laio_worker.AsyncioWorker | None = None, + logger: logging.Logger | None = None, + **fmtargs, + ): super().__init__(worker=worker, logger=logger) self._out = io.StringIO() @@ -145,19 +177,19 @@ def __init__(self, filename : str = 'out.csv', *, self._filename = filename self._file_context = None self._file = None - self._objs : dict[laio_zmq.Object, typing.Any] = {} - self._t_last : float = 0 - self._t_update : float = 0 - self._coalesced : tuple[float, list[typing.Any]] = (0.0, []) + self._objs: dict[laio_zmq.Object, typing.Any] = {} + self._t_last: float = 0 + self._t_update: float = 0 + self._coalesced: tuple[float, list[typing.Any]] = (0.0, []) self._write_sem = asyncio.BoundedSemaphore(1) - self._auto_write_task : typing.Optional[asyncio.Task[None]] = None - self._auto_write : float | None = auto_write + self._auto_write_task: typing.Optional[asyncio.Task[None]] = None + self._auto_write: float | None = auto_write - self._write_on_change_task : typing.Optional[asyncio.Task[None]] = None + self._write_on_change_task: typing.Optional[asyncio.Task[None]] = None self._write_on_change = write_on_change and auto_write is None - self._auto_flush_task : typing.Optional[asyncio.Task[None]] = None + self._auto_flush_task: typing.Optional[asyncio.Task[None]] = None self._auto_flush = auto_flush @property @@ -165,20 +197,22 @@ def opened(self) -> bool: return self._file is not None @property - def file(self): # -> some aiofile type + def file(self): # -> some aiofile type if not self.opened: - raise RuntimeError('File not opened') + raise RuntimeError("File not opened") assert self._file is not None return self._file @overload async def open(self) -> None: ... @overload - def open(self, *, block : typing.Literal[False]) -> asyncio.Future[bool]: ... + def open(self, *, block: typing.Literal[False]) -> asyncio.Future[bool]: ... @overload - def open(self, *, sync : typing.Literal[True]) -> bool: ... + def open(self, *, sync: typing.Literal[True]) -> bool: ... @overload - def open(self, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[bool]: ... + def open( + self, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[bool]: ... @laio_worker.Work.run_sync @laio_worker.Work.locked @@ -189,21 +223,21 @@ async def _open(self): assert self.lock.has_lock() if self.opened: - raise RuntimeError('File already opened') + raise RuntimeError("File already opened") - if self._filename == '-': - self.logger.info('using stdout for CSV export') + if self._filename == "-": + self.logger.info("using stdout for CSV export") self._file_context = None self._file = aiofiles.stdout else: - self.logger.info('using %s for CSV export', self._filename) - self._file_context = aiofiles.open(self._filename, 'w', newline='', encoding='utf-8') + self.logger.info("using %s for CSV export", self._filename) + self._file_context = aiofiles.open(self._filename, "w", newline="", encoding="utf-8") self._file = await self._file_context.__aenter__() self._t_last = 0 self._t_update = 0 self._need_restart = True - self._coalesced : tuple[float, list[typing.Any]] = (0.0, []) + self._coalesced: tuple[float, list[typing.Any]] = (0.0, []) self._queue = [] self._update_auto_write(self._auto_write) @@ -217,11 +251,13 @@ async def __aenter__(self): @overload async def close(self) -> None: ... @overload - def close(self, *, block : typing.Literal[False]) -> asyncio.Future[bool]: ... + def close(self, *, block: typing.Literal[False]) -> asyncio.Future[bool]: ... @overload - def close(self, *, sync : typing.Literal[True]) -> bool: ... + def close(self, *, sync: typing.Literal[True]) -> bool: ... @overload - def close(self, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[bool]: ... + def close( + self, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[bool]: ... @laio_worker.Work.run_sync @laio_worker.Work.locked @@ -241,7 +277,7 @@ async def _close(self): self._update_write_on_change(False) self._update_auto_flush(None) except Exception as e: - self.logger.debug('ignore exception: %s', e) + self.logger.debug("ignore exception: %s", e) try: data = self._out.getvalue() @@ -251,7 +287,7 @@ async def _close(self): if data: await self._file.write(data) except Exception as e: - self.logger.debug('ignore exception: %s', e) + self.logger.debug("ignore exception: %s", e) self._file = None @@ -259,10 +295,10 @@ async def _close(self): try: await self._file_context.__aexit__(None, None, None) except Exception as e: - self.logger.debug('ignore exception: %s', e) + self.logger.debug("ignore exception: %s", e) finally: self._file_context = None - self.logger.debug('closed %s', self._filename) + self.logger.debug("closed %s", self._filename) async def __aexit__(self, exc_type, exc_value, traceback): await self.close() @@ -273,11 +309,11 @@ def __del__(self): @laio_worker.Work.locked async def _restart(self): - self.logger.debug('restart') + self.logger.debug("restart") self._out.truncate(0) self._out.seek(0) - header = ['t (s)'] + header = ["t (s)"] for obj in self._objs.keys(): header.append(obj.name) if len(header) > 1: @@ -287,23 +323,34 @@ async def _restart(self): self._need_restart = False @overload - async def write(self, t : float | None=None, *, flush : bool=False) -> None: ... + async def write(self, t: float | None = None, *, flush: bool = False) -> None: ... @overload - def write(self, t : float | None=None, *, flush : bool=False, block : typing.Literal[False]) -> asyncio.Future[None]: ... + def write( + self, t: float | None = None, *, flush: bool = False, block: typing.Literal[False] + ) -> asyncio.Future[None]: ... @overload - def write(self, t : float | None=None, *, flush : bool=False, sync : typing.Literal[True]) -> None: ... + def write( + self, t: float | None = None, *, flush: bool = False, sync: typing.Literal[True] + ) -> None: ... @overload - def write(self, t : float | None=None, *, flush : bool=False, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[None]: ... + def write( + self, + t: float | None = None, + *, + flush: bool = False, + block: typing.Literal[False], + sync: typing.Literal[True], + ) -> concurrent.futures.Future[None]: ... @laio_worker.Work.run_sync - async def write(self, t : float | None=None, *, flush : bool=False) -> None: + async def write(self, t: float | None = None, *, flush: bool = False) -> None: self._collect(t) await self._write() if flush: await self._flush() - def _collect(self, t : float | None=None) -> None: + def _collect(self, t: float | None = None) -> None: if t is None: t = self._t_update @@ -325,7 +372,7 @@ async def _write(self) -> None: assert not self.lock.has_lock() if not self.opened: - raise RuntimeError('File not opened') + raise RuntimeError("File not opened") if self._need_restart: await self._restart() @@ -336,18 +383,20 @@ async def _write(self) -> None: for t, row in queue: if t > self._t_last: - self.logger.debug('write t=%.6f', t) + self.logger.debug("write t=%.6f", t) self._t_last = t self._writer.writerow(row) @overload async def flush(self) -> None: ... @overload - def flush(self, *, block : typing.Literal[False]) -> asyncio.Future[None]: ... + def flush(self, *, block: typing.Literal[False]) -> asyncio.Future[None]: ... @overload - def flush(self, *, sync : typing.Literal[True]) -> None: ... + def flush(self, *, sync: typing.Literal[True]) -> None: ... @overload - def flush(self, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[None]: ... + def flush( + self, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[None]: ... @laio_worker.Work.run_sync async def flush(self) -> None: @@ -357,7 +406,7 @@ async def _flush(self) -> None: assert not self.lock.has_lock() if not self.opened: - raise RuntimeError('File not opened') + raise RuntimeError("File not opened") async with self.lock: data = self._out.getvalue() @@ -365,40 +414,48 @@ async def _flush(self) -> None: self._out.seek(0) if data: - self.logger.debug('flush') + self.logger.debug("flush") await self.file.write(data) await self.file.flush() - self.logger.debug('flushed') + self.logger.debug("flushed") @overload - async def add(self, obj : laio_zmq.Object) -> None: ... + async def add(self, obj: laio_zmq.Object) -> None: ... @overload - def add(self, obj : laio_zmq.Object, *, block : typing.Literal[False]) -> asyncio.Future[None]: ... + def add( + self, obj: laio_zmq.Object, *, block: typing.Literal[False] + ) -> asyncio.Future[None]: ... @overload - def add(self, obj : laio_zmq.Object, *, sync : typing.Literal[True]) -> None: ... + def add(self, obj: laio_zmq.Object, *, sync: typing.Literal[True]) -> None: ... @overload - def add(self, obj : laio_zmq.Object, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[None]: ... + def add( + self, obj: laio_zmq.Object, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[None]: ... @laio_worker.Work.run_sync @laio_worker.Work.locked - async def add(self, obj : laio_zmq.Object) -> None: + async def add(self, obj: laio_zmq.Object) -> None: if obj not in self._objs: self._objs[obj] = await obj.read() obj.register(lambda v, o=obj: self._on_object_update(o, v), self) self._need_restart = True @overload - async def remove(self, obj : laio_zmq.Object) -> None: ... + async def remove(self, obj: laio_zmq.Object) -> None: ... @overload - def remove(self, obj : laio_zmq.Object, *, block : typing.Literal[False]) -> asyncio.Future[None]: ... + def remove( + self, obj: laio_zmq.Object, *, block: typing.Literal[False] + ) -> asyncio.Future[None]: ... @overload - def remove(self, obj : laio_zmq.Object, *, sync : typing.Literal[True]) -> None: ... + def remove(self, obj: laio_zmq.Object, *, sync: typing.Literal[True]) -> None: ... @overload - def remove(self, obj : laio_zmq.Object, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[None]: ... + def remove( + self, obj: laio_zmq.Object, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[None]: ... @laio_worker.Work.run_sync @laio_worker.Work.locked - async def remove(self, obj : laio_zmq.Object) -> None: + async def remove(self, obj: laio_zmq.Object) -> None: if obj in self._objs: obj.unregister(self) del self._objs[obj] @@ -407,11 +464,13 @@ async def remove(self, obj : laio_zmq.Object) -> None: @overload async def clear(self) -> None: ... @overload - def clear(self, *, block : typing.Literal[False]) -> asyncio.Future[None]: ... + def clear(self, *, block: typing.Literal[False]) -> asyncio.Future[None]: ... @overload - def clear(self, *, sync : typing.Literal[True]) -> None: ... + def clear(self, *, sync: typing.Literal[True]) -> None: ... @overload - def clear(self, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[None]: ... + def clear( + self, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[None]: ... @laio_worker.Work.run_sync @laio_worker.Work.locked @@ -422,7 +481,7 @@ async def clear(self) -> None: self._objs.clear() self._need_restart = True - def _on_object_update(self, obj : laio_zmq.Object, value : typing.Any) -> None: + def _on_object_update(self, obj: laio_zmq.Object, value: typing.Any) -> None: if obj in self._objs: self._objs[obj] = value obj_t = obj.t.value @@ -431,33 +490,37 @@ def _on_object_update(self, obj : laio_zmq.Object, value : typing.Any) -> None: self._collect() @overload - def auto_write(self, interval_s : float | None) -> None: ... + def auto_write(self, interval_s: float | None) -> None: ... @overload - def auto_write(self, interval_s : float | None, *, block : typing.Literal[False]) -> asyncio.Future[None]: ... + def auto_write( + self, interval_s: float | None, *, block: typing.Literal[False] + ) -> asyncio.Future[None]: ... @overload - def auto_write(self, interval_s : float | None, *, sync : typing.Literal[True]) -> None: ... + def auto_write(self, interval_s: float | None, *, sync: typing.Literal[True]) -> None: ... @overload - def auto_write(self, interval_s : float | None, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[None]: ... + def auto_write( + self, interval_s: float | None, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[None]: ... @laio_worker.Work.thread_safe_async - def auto_write(self, interval_s : float | None) -> None: - ''' + def auto_write(self, interval_s: float | None) -> None: + """ Enable/disable automatic writing every interval_s seconds. If interval_s is None, automatic writing is disabled. - ''' + """ self._auto_write = interval_s if self.opened or interval_s is None: self._update_auto_write(self._auto_write) - def _update_auto_write(self, interval_s : float | None) -> None: + def _update_auto_write(self, interval_s: float | None) -> None: if self._auto_write_task is not None: self._auto_write_task.cancel() self._auto_write_task = None if interval_s is not None: if not self.opened: - raise RuntimeError('File not opened') + raise RuntimeError("File not opened") async def auto_write_task(): try: @@ -468,39 +531,45 @@ async def auto_write_task(): except asyncio.CancelledError: pass except: - self.logger.exception(f'Auto write task error') + self.logger.exception(f"Auto write task error") raise - self._auto_write_task = asyncio.create_task(auto_write_task(), name=f'{self.__class__.__name__} auto write') + self._auto_write_task = asyncio.create_task( + auto_write_task(), name=f"{self.__class__.__name__} auto write" + ) @overload - def write_on_change(self, enable : bool) -> None: ... + def write_on_change(self, enable: bool) -> None: ... @overload - def write_on_change(self, enable : bool, *, block : typing.Literal[False]) -> asyncio.Future[None]: ... + def write_on_change( + self, enable: bool, *, block: typing.Literal[False] + ) -> asyncio.Future[None]: ... @overload - def write_on_change(self, enable : bool, *, sync : typing.Literal[True]) -> None: ... + def write_on_change(self, enable: bool, *, sync: typing.Literal[True]) -> None: ... @overload - def write_on_change(self, enable : bool, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[None]: ... + def write_on_change( + self, enable: bool, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[None]: ... @laio_worker.Work.thread_safe_async - def write_on_change(self, enable : bool) -> None: - ''' + def write_on_change(self, enable: bool) -> None: + """ Enable/disable writing on object value change. If enabled, a write is performed whenever an object's value is updated. - ''' + """ self._write_on_change = enable if self.opened or not enable: self._update_write_on_change(enable) - def _update_write_on_change(self, enable : bool) -> None: + def _update_write_on_change(self, enable: bool) -> None: if self._write_on_change_task is not None: self._write_on_change_task.cancel() self._write_on_change_task = None if enable: if not self.opened: - raise RuntimeError('File not opened') + raise RuntimeError("File not opened") async def write_on_change_task(): try: @@ -510,39 +579,45 @@ async def write_on_change_task(): except asyncio.CancelledError: pass except: - self.logger.exception(f'Write on change task error') + self.logger.exception(f"Write on change task error") raise - self._write_on_change_task = asyncio.create_task(write_on_change_task(), name=f'{self.__class__.__name__} write on change') + self._write_on_change_task = asyncio.create_task( + write_on_change_task(), name=f"{self.__class__.__name__} write on change" + ) @overload - def auto_flush(self, interval_s : float | None) -> None: ... + def auto_flush(self, interval_s: float | None) -> None: ... @overload - def auto_flush(self, interval_s : float | None, *, block : typing.Literal[False]) -> asyncio.Future[None]: ... + def auto_flush( + self, interval_s: float | None, *, block: typing.Literal[False] + ) -> asyncio.Future[None]: ... @overload - def auto_flush(self, interval_s : float | None, *, sync : typing.Literal[True]) -> None: ... + def auto_flush(self, interval_s: float | None, *, sync: typing.Literal[True]) -> None: ... @overload - def auto_flush(self, interval_s : float | None, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[None]: ... + def auto_flush( + self, interval_s: float | None, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[None]: ... @laio_worker.Work.thread_safe_async - def auto_flush(self, interval_s : float | None) -> None: - ''' + def auto_flush(self, interval_s: float | None) -> None: + """ Enable/disable automatic flushing every interval_s seconds. If interval_s is None, automatic flushing is disabled. - ''' + """ self._auto_flush = interval_s if self.opened or interval_s is None: self._update_auto_flush(self._auto_flush) - def _update_auto_flush(self, interval_s : float | None) -> None: + def _update_auto_flush(self, interval_s: float | None) -> None: if self._auto_flush_task is not None: self._auto_flush_task.cancel() self._auto_flush_task = None if interval_s is not None: if not self.opened: - raise RuntimeError('File not opened') + raise RuntimeError("File not opened") async def auto_flush_task(): try: @@ -552,7 +627,9 @@ async def auto_flush_task(): except asyncio.CancelledError: pass except: - self.logger.exception(f'Auto flush task error') + self.logger.exception(f"Auto flush task error") raise - self._auto_flush_task = asyncio.create_task(auto_flush_task(), name=f'{self.__class__.__name__} auto flush') + self._auto_flush_task = asyncio.create_task( + auto_flush_task(), name=f"{self.__class__.__name__} auto flush" + ) diff --git a/python/libstored/asyncio/event.py b/python/libstored/asyncio/event.py index 2ef7ddd2..07ab3dfd 100644 --- a/python/libstored/asyncio/event.py +++ b/python/libstored/asyncio/event.py @@ -11,12 +11,13 @@ from . import worker as laio_worker from .. import exceptions as lexc + class Event: logger = logging.getLogger(__name__) - def __init__(self, event_name : str | None=None, *args, **kwargs): + def __init__(self, event_name: str | None = None, *args, **kwargs): super().__init__(*args, **kwargs) - self._callbacks : typing.Dict[typing.Hashable, typing.Callable] = {} + self._callbacks: typing.Dict[typing.Hashable, typing.Callable] = {} self._key = 0 self._queued = None self._paused = False @@ -24,12 +25,18 @@ def __init__(self, event_name : str | None=None, *args, **kwargs): self._lock = lexc.DeadlockChecker(threading.RLock()) def __repr__(self) -> str: - return f'{self.__class__.__name__}({self._event_name})' if self._event_name is not None else super().__repr__() + return ( + f"{self.__class__.__name__}({self._event_name})" + if self._event_name is not None + else super().__repr__() + ) def __str__(self) -> str: return self._event_name if self._event_name is not None else super().__str__() - def register(self, callback : typing.Callable, id : typing.Hashable | None=None) -> typing.Hashable: + def register( + self, callback: typing.Callable, id: typing.Hashable | None = None + ) -> typing.Hashable: with self._lock: if id is not None and id in self._callbacks: raise KeyError(f"Callback with id {id} already registered") @@ -42,7 +49,7 @@ def register(self, callback : typing.Callable, id : typing.Hashable | None=None) self._callbacks[id] = callback return id - def unregister(self, id : typing.Hashable): + def unregister(self, id: typing.Hashable): c = [] with self._lock: if id in self._callbacks: @@ -74,7 +81,7 @@ def paused(self) -> bool: return self._paused def trigger(self, *args, **kwargs): - self.logger.debug('trigger %s', repr(self)) + self.logger.debug("trigger %s", repr(self)) callbacks = [] with self._lock: @@ -95,7 +102,7 @@ def trigger(self, *args, **kwargs): else: callback(*bound.args, **bound.kwargs) except Exception as e: - self.logger.exception(f'Exception in {repr(self)} callback: {e}') + self.logger.exception(f"Exception in {repr(self)} callback: {e}") def __call__(self, *args, **kwargs): self.trigger(*args, **kwargs) @@ -118,11 +125,16 @@ def _safe_callback(): finally: self.unregister(key) + class ValueWrapper(Event): - def __init__(self, type : typing.Type, - get : typing.Callable[[], typing.Any] | None = None, - set : typing.Callable[[typing.Any], None] | None = None, - *args, **kwargs): + def __init__( + self, + type: typing.Type, + get: typing.Callable[[], typing.Any] | None = None, + set: typing.Callable[[typing.Any], None] | None = None, + *args, + **kwargs, + ): super().__init__(*args, **kwargs) self._type = type self._get = get @@ -140,7 +152,7 @@ def value(self) -> typing.Any: return x @value.setter - def value(self, value : typing.Any): + def value(self, value: typing.Any): if self._set is None: raise AttributeError("not writable") if value is not None and not isinstance(value, self.type): @@ -150,10 +162,10 @@ def value(self, value : typing.Any): def get(self) -> typing.Any: return self.value - def set(self, value : typing.Any): + def set(self, value: typing.Any): self.value = value - def trigger(self, value : typing.Any = None): + def trigger(self, value: typing.Any = None): if value is not None and not isinstance(value, self.type): raise TypeError(f"expected {self.type}, got {type(value)}") super().trigger(value if value is not None else self.value) @@ -171,8 +183,9 @@ def type(self) -> typing.Type: def __str__(self) -> str: return str(self.value) + class Value(ValueWrapper): - def __init__(self, type : typing.Type, initial : typing.Any=None, *args, **kwargs): + def __init__(self, type: typing.Type, initial: typing.Any = None, *args, **kwargs): super().__init__(type, self._get, self._set, *args, **kwargs) if initial is not None and not isinstance(initial, type): raise TypeError(f"expected {type}, got {type(initial)}") @@ -181,13 +194,14 @@ def __init__(self, type : typing.Type, initial : typing.Any=None, *args, **kwarg def _get(self) -> typing.Any: return self._value - def _set(self, value : typing.Any): + def _set(self, value: typing.Any): if self._value != value: self._value = value self.trigger() + class AsyncioRateLimit(laio_worker.Work, Event): - ''' + """ Event that can be triggered, but not more often than a specified minimum interval. If triggered more often, only the last trigger arguments are used, and the event @@ -195,9 +209,11 @@ class AsyncioRateLimit(laio_worker.Work, Event): The trigger() method is thread-safe. The timer callback is executed in the worker's event loop. - ''' + """ - def __init__(self, Hz : float | None=None, min_interval_s : float | None=None, *args, **kwargs): + def __init__( + self, Hz: float | None = None, min_interval_s: float | None = None, *args, **kwargs + ): super().__init__(*args, **kwargs) if Hz is not None and min_interval_s is not None: @@ -212,10 +228,10 @@ def __init__(self, Hz : float | None=None, min_interval_s : float | None=None, * assert min_interval_s is not None self._min_interval_s = max(0, min_interval_s) self._last_trigger = 0.0 - self._timer : asyncio.TimerHandle | None = None - self._args : tuple[tuple, dict] = ((), {}) + self._timer: asyncio.TimerHandle | None = None + self._args: tuple[tuple, dict] = ((), {}) - def unregister(self, id : typing.Hashable): + def unregister(self, id: typing.Hashable): super().unregister(id) if len(self) == 0 and self._timer is not None: self._timer.cancel() diff --git a/python/libstored/asyncio/tk.py b/python/libstored/asyncio/tk.py index dfada03b..d2704d26 100644 --- a/python/libstored/asyncio/tk.py +++ b/python/libstored/asyncio/tk.py @@ -22,10 +22,11 @@ from . import zmq as laio_zmq from .. import exceptions as lexc + class AsyncTk: - ''' + """ A thread running a tkinter mainloop. - ''' + """ def __init__(self, cb_init=None, *args, **kwargs): super().__init__(*args, **kwargs) @@ -47,7 +48,7 @@ def thread(self) -> threading.Thread | None: return self._thread def start(self): - ''' + """ Start the tkinter mainloop in a separate thread. Call only once. @@ -55,13 +56,13 @@ def start(self): Note that Tk is not fully thread-safe. This call is not recommended. Just call run() from the main thread instead. - ''' + """ if self._started: raise lexc.InvalidState("Mainloop already started") self.logger.debug("Starting mainloop") - self._thread = threading.Thread(target=self._run, daemon=False, name='AsyncTk') + self._thread = threading.Thread(target=self._run, daemon=False, name="AsyncTk") self._thread.start() while not self._started: time.sleep(0.1) @@ -70,11 +71,11 @@ def start(self): @property def root(self) -> tk.Tk: - ''' + """ Return the root Tk instance. Only to be called from within the Tk context (run()). - ''' + """ if threading.current_thread() != self.thread: raise lexc.InvalidState("Accessing tk from wrong thread") @@ -83,11 +84,11 @@ def root(self) -> tk.Tk: return self._root def run(self): - ''' + """ Run the tkinter mainloop. Call from the main thread. - ''' + """ if threading.current_thread() != threading.main_thread(): raise lexc.InvalidState("run() must be called from the main thread") @@ -101,17 +102,19 @@ def run(self): self._run_from_main = False def _run(self): - ''' + """ Run the tkinter mainloop. Call from the main thread, or via start(). - ''' + """ self._do_async = True self._started = True if not self._run_from_main: - self.logger.warning('Running Tk mainloop in a separate thread. This is not recommended, as Tk is not fully thread-safe.') + self.logger.warning( + "Running Tk mainloop in a separate thread. This is not recommended, as Tk is not fully thread-safe." + ) try: while True: @@ -121,9 +124,11 @@ def _run(self): try: self._root = tk.Tk() - self._root.report_callback_exception = lambda *args: self.logger.exception('Unhandled exception in Tk', exc_info=args) + self._root.report_callback_exception = lambda *args: self.logger.exception( + "Unhandled exception in Tk", exc_info=args + ) self._root.protocol("WM_DELETE_WINDOW", self._on_stop) - self._root.bind('<>', self._on_async_call) + self._root.bind("<>", self._on_async_call) init = None if self._cb_init is not None: @@ -149,17 +154,17 @@ def _run(self): gc.collect() self.logger.debug("thread exit") - def _on_stop(self, event = None): + def _on_stop(self, event=None): self._on_stopping() if self._root is not None: self._root.destroy() - def _on_stopping(self, event = None): + def _on_stopping(self, event=None): if self._do_async: self.logger.debug("Prevent further async calls") self._do_async = False - def _dump_referrers(self, obj, depth : int=3, indent : str=''): + def _dump_referrers(self, obj, depth: int = 3, indent: str = ""): if depth < 0 or obj is None: return @@ -167,10 +172,10 @@ def _dump_referrers(self, obj, depth : int=3, indent : str=''): if ref == []: return - self.logger.debug(f'{indent}{repr(obj)}: {len(ref)} referrers') + self.logger.debug(f"{indent}{repr(obj)}: {len(ref)} referrers") for r in ref: - self._dump_referrers(r, depth-1, indent + ' ') + self._dump_referrers(r, depth - 1, indent + " ") def _stop(self): if self._root is None: @@ -179,11 +184,11 @@ def _stop(self): self._root.quit() def stop(self): - ''' + """ Stop the tkinter mainloop and wait for the thread to exit. Thread-safe. - ''' + """ if self._thread is None: return @@ -199,11 +204,11 @@ def stop(self): self.logger.debug("Mainloop stopped") def wait(self, timeout=None): - ''' + """ Wait for the tkinter mainloop thread to exit. Thread-safe. - ''' + """ if self._thread is None: return @@ -217,13 +222,14 @@ def wait(self, timeout=None): self.logger.debug("Mainloop stopped") def is_running(self): - ''' + """ Check if the mainloop is running. Thread-safe. - ''' - return self._run_from_main or \ - (self._thread is not None and self._thread.is_alive() and self._root is not None) + """ + return self._run_from_main or ( + self._thread is not None and self._thread.is_alive() and self._root is not None + ) def __del__(self): self.stop() @@ -236,11 +242,11 @@ def __exit__(self, *args): self.stop() def execute(self, f, *args, **kwargs) -> concurrent.futures.Future: - ''' + """ Queue a function for the tkinter mainloop thread. Thread-safe. - ''' + """ if not self.is_running(): raise lexc.InvalidState("Mainloop is not running") @@ -248,7 +254,11 @@ def execute(self, f, *args, **kwargs) -> concurrent.futures.Future: self.logger.debug("Queueing async call to %s", f.__qualname__) future = concurrent.futures.Future() try: - self._queue.put((f, args, kwargs, future), block=True, timeout=lexc.DeadlockChecker.default_timeout_s) + self._queue.put( + (f, args, kwargs, future), + block=True, + timeout=lexc.DeadlockChecker.default_timeout_s, + ) except queue.Full: raise lexc.Deadlock("AsyncTk queue full") from None @@ -258,7 +268,7 @@ def execute(self, f, *args, **kwargs) -> concurrent.futures.Future: return future assert self._root is not None - self._root.event_generate('<>', when='tail') + self._root.event_generate("<>", when="tail") return future def _on_async_call(self, event): @@ -269,26 +279,28 @@ def _on_async_call(self, event): try: future.set_result(func(*args, **kwargs)) except BaseException as e: - self.logger.debug('Exception in async call to %s', func.__qualname__, exc_info=True) + self.logger.debug( + "Exception in async call to %s", func.__qualname__, exc_info=True + ) future.set_exception(e) except queue.Empty: pass - class Work: - ''' + """ Mixin class for all async Tk modules. - ''' - def __init__(self, atk : AsyncTk, logger : logging.Logger | None=None, *args, **kwargs): + """ + + def __init__(self, atk: AsyncTk, logger: logging.Logger | None = None, *args, **kwargs): super().__init__(*args, **kwargs) self.logger = logger if logger is not None else logging.getLogger(self.__class__.__name__) - self._atk : AsyncTk = atk - self._connections : dict[typing.Hashable, laio_event.Event] = {} + self._atk: AsyncTk = atk + self._connections: dict[typing.Hashable, laio_event.Event] = {} self._connections_key = 0 - if hasattr(self, 'bind') and callable(getattr(self, 'bind')): + if hasattr(self, "bind") and callable(getattr(self, "bind")): # Assume this is a tk widget. typing.cast(tk.Widget, self).bind("", self._on_destroy) else: @@ -300,7 +312,7 @@ def atk(self) -> AsyncTk: @staticmethod def tk_func(f) -> typing.Callable[..., typing.Any | asyncio.Future | concurrent.futures.Future]: - ''' + """ Decorator to mark a function to be executed in the tk context. The decorated function must be a regular function. @@ -308,10 +320,10 @@ def tk_func(f) -> typing.Callable[..., typing.Any | asyncio.Future | concurrent. Otherwise, the call is blocking, and the result is returned. When block=True is passed, the call is always blocking, and the result is returned. - ''' + """ @functools.wraps(f) - def tk_func(self : Work, *args, block : bool=False, **kwargs) -> typing.Any: + def tk_func(self: Work, *args, block: bool = False, **kwargs) -> typing.Any: # self.logger.debug(f'Scheduling {f} in tk') try: @@ -338,12 +350,15 @@ def tk_func(self : Work, *args, block : bool=False, **kwargs) -> typing.Any: else: return future except BaseException as e: - self.logger.debug(f'Exception {e} in scheduling tk function {f}') + self.logger.debug(f"Exception {e} in scheduling tk function {f}") raise + return tk_func @tk_func - def connect(self, event : laio_event.Event, callback : typing.Callable, *args, **kwargs) -> typing.Hashable: + def connect( + self, event: laio_event.Event, callback: typing.Callable, *args, **kwargs + ) -> typing.Hashable: k = (self, self._connections_key) self._connections_key += 1 k = event.register(callback, k, *args, **kwargs) @@ -352,7 +367,7 @@ def connect(self, event : laio_event.Event, callback : typing.Callable, *args, * return k @tk_func - def disconnect(self, id : typing.Hashable): + def disconnect(self, id: typing.Hashable): if id in self._connections: event = self._connections[id] del self._connections[id] @@ -376,33 +391,38 @@ def _on_destroy(self, event): def __del__(self): # Tk is not thread-safe. Ignore this check. - #assert threading.current_thread() == self._atk.thread + # assert threading.current_thread() == self._atk.thread self.cleanup() - class AsyncApp(Work, ttk.Frame): - ''' + """ A ttk application, running Tk in a separate thread, and an asyncio worker in another thread. Calling between contexts is thread-safe, as long functions are decorated with @tk_func or @worker_func. - ''' + """ - def __init__(self, atk : AsyncTk, worker : laio_worker.AsyncioWorker, *args, **kwargs): - ''' + def __init__(self, atk: AsyncTk, worker: laio_worker.AsyncioWorker, *args, **kwargs): + """ Initialize the application. Do not call directly. Use create() instead. - ''' + """ super().__init__(atk=atk, master=atk.root, *args, **kwargs) self._worker = worker - self.grid(sticky='nsew') + self.grid(sticky="nsew") self.root.columnconfigure(0, weight=1) self.root.rowconfigure(0, weight=1) class Context: - def __init__(self, cls : typing.Type, worker : laio_worker.AsyncioWorker | laio_worker.Work | None=None, *args, **kwargs): + def __init__( + self, + cls: typing.Type, + worker: laio_worker.AsyncioWorker | laio_worker.Work | None = None, + *args, + **kwargs, + ): global default_worker self.cls = cls @@ -421,13 +441,15 @@ def __init__(self, cls : typing.Type, worker : laio_worker.AsyncioWorker | laio_ self.atk = AsyncTk(cb_init=self._init) - def _init(self, atk : AsyncTk): - return self.cls(*self.args, **self.kwargs, atk=atk, worker=self.worker, logger=self.logger) + def _init(self, atk: AsyncTk): + return self.cls( + *self.args, **self.kwargs, atk=atk, worker=self.worker, logger=self.logger + ) def __enter__(self): # Disabled, as Tk is not fully thread-safe. # Just call run() from the main thread instead. - #self.atk.start() + # self.atk.start() return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -436,7 +458,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): @classmethod def create(cls, *arg, **kwargs) -> Context: - ''' + """ Create an instance of the application, running Tk in a separate thread, and an asyncio worker in another thread. Usage: @@ -447,12 +469,12 @@ async def stuff(): with App.create() as app: app.worker.execute(stuff()) app.atk.run() - ''' + """ return cls.Context(cls, *arg, **kwargs) @classmethod - def run(cls, *args, coro: typing.Coroutine | None=None, **kwargs): - ''' + def run(cls, *args, coro: typing.Coroutine | None = None, **kwargs): + """ Create and run an instance of the application, running Tk in the main thread, and an asyncio worker in another thread. When coro is provided, coro is started in the worker context. @@ -463,7 +485,7 @@ async def stuff(): return result stuff_result = App.run(stuff()) - ''' + """ with cls.create(*args, **kwargs) as context: res = None if coro is None else context.worker.execute(coro) context.atk.run() @@ -478,16 +500,18 @@ def root(self) -> tk.Tk: return self.atk.root @staticmethod - def worker_func(f) -> typing.Callable[..., typing.Any | asyncio.Future | concurrent.futures.Future]: - ''' + def worker_func( + f, + ) -> typing.Callable[..., typing.Any | asyncio.Future | concurrent.futures.Future]: + """ Decorator to mark a function to be executed in the worker context. The decorated function may be a coroutine function or a regular function. By default, a future is returned, unless block=True is passed. - ''' + """ @functools.wraps(f) - def worker_func(self : AsyncApp, *args, block : bool=False, **kwargs) -> typing.Any: + def worker_func(self: AsyncApp, *args, block: bool = False, **kwargs) -> typing.Any: # self.logger.debug(f'Scheduling {f} in worker') try: @@ -526,7 +550,7 @@ def worker_func(self : AsyncApp, *args, block : bool=False, **kwargs) -> typing. except asyncio.CancelledError: raise except BaseException as e: - self.logger.debug(f'Exception {e} in scheduling worker function {f}') + self.logger.debug(f"Exception {e} in scheduling worker function {f}") raise return worker_func @@ -535,25 +559,27 @@ def worker_func(self : AsyncApp, *args, block : bool=False, **kwargs) -> typing. def cleanup(self): super().cleanup() - self.logger.debug('Cleanup worker') + self.logger.debug("Cleanup worker") try: - self.worker.stop(lexc.DeadlockChecker.default_timeout_s + 1 \ - if lexc.DeadlockChecker.default_timeout_s is not None else None) + self.worker.stop( + lexc.DeadlockChecker.default_timeout_s + 1 + if lexc.DeadlockChecker.default_timeout_s is not None + else None + ) except TimeoutError: - self.logger.debug('Cleanup worker - forcing') + self.logger.debug("Cleanup worker - forcing") self.worker.cancel() def __del__(self): assert not self.worker.is_running() - class AsyncWidget(Work): - ''' + """ Mixin class for all async widgets. - ''' + """ - def __init__(self, app : AsyncApp, *args, **kwargs): + def __init__(self, app: AsyncApp, *args, **kwargs): super().__init__(atk=app.atk, *args, **kwargs) self._app = app @@ -570,15 +596,16 @@ def root(self) -> tk.Tk: return self._app.root @staticmethod - def worker_func(f) -> typing.Callable[..., typing.Any | asyncio.Future | concurrent.futures.Future]: + def worker_func( + f, + ) -> typing.Callable[..., typing.Any | asyncio.Future | concurrent.futures.Future]: return AsyncApp.worker_func(f) - class ZmqObjectEntry(AsyncWidget, ttk.Entry): - ''' + """ An Entry widget, bound to a libstored Object. - ''' + """ class State(enum.IntEnum): INIT = enum.auto() @@ -590,41 +617,50 @@ class State(enum.IntEnum): FOCUSED = enum.auto() EDITING = enum.auto() - def __init__(self, app : AsyncApp, parent : tk.Widget, obj : laio_zmq.Object, - rate_limit_Hz=3, *args, **kwargs): + def __init__( + self, + app: AsyncApp, + parent: tk.Widget, + obj: laio_zmq.Object, + rate_limit_Hz=3, + *args, + **kwargs, + ): super().__init__(app=app, master=parent, *args, **kwargs) self._obj = obj - self._updated : float = 0 + self._updated: float = 0 self._state = ZmqObjectEntry.State.INIT self._var = tk.StringVar() - self['textvariable'] = self._var - self._rate_limit = laio_event.AsyncioRateLimit(worker=self.worker, Hz=rate_limit_Hz, event_name=obj.name) + self["textvariable"] = self._var + self._rate_limit = laio_event.AsyncioRateLimit( + worker=self.worker, Hz=rate_limit_Hz, event_name=obj.name + ) self.connect(self._rate_limit, self._refresh) self.connect(self.obj.value_str, self._rate_limit) self.connect(self.obj.client.disconnected, self._refresh) - self.bind('', self._write) - self.bind('', self._write) - self.bind('', self._focus_in) - self.bind('', self._focus_out) - self.bind('', self._edit) + self.bind("", self._write) + self.bind("", self._write) + self.bind("", self._focus_in) + self.bind("", self._focus_out) + self.bind("", self._edit) def select_all(event): - event.widget.select_range(0, 'end') - event.widget.icursor('end') - return 'break' + event.widget.select_range(0, "end") + event.widget.icursor("end") + return "break" - self.bind('', select_all) - self.bind('', select_all) + self.bind("", select_all) + self.bind("", select_all) - self.bind('', self._revert) + self.bind("", self._revert) - self['justify'] = 'right' + self["justify"] = "right" self._set_state(ZmqObjectEntry.State.DEFAULT) self._refresh() def __repr__(self): - return f'ZmqObjectEntry({self.obj.name})@0x{id(self):x}' + return f"ZmqObjectEntry({self.obj.name})@0x{id(self):x}" @property def alive(self): @@ -653,7 +689,7 @@ def obj(self) -> laio_zmq.Object: def updated(self) -> bool: return self._state == ZmqObjectEntry.State.UPDATED - def _set_state(self, state : State): + def _set_state(self, state: State): if not self.alive: state = ZmqObjectEntry.State.DISCONNECTED @@ -666,8 +702,11 @@ def _set_state(self, state : State): if state == self._state: return - if self._state == ZmqObjectEntry.State.INVALID and state != ZmqObjectEntry.State.DISCONNECTED: - self._var.set('') + if ( + self._state == ZmqObjectEntry.State.INVALID + and state != ZmqObjectEntry.State.DISCONNECTED + ): + self._var.set("") self._state = state @@ -676,32 +715,32 @@ def _set_state(self, state : State): if self._state == ZmqObjectEntry.State.DISCONNECTED: # Freeze field. - self['state'] = 'disabled' - elif self['state'] == 'disabled': - self['state'] = 'normal' + self["state"] = "disabled" + elif self["state"] == "disabled": + self["state"] = "normal" if self._state == ZmqObjectEntry.State.DISCONNECTED: - self['foreground'] = 'gray' + self["foreground"] = "gray" elif self._state == ZmqObjectEntry.State.INVALID: - self['foreground'] = 'gray' - self._var.set('?') + self["foreground"] = "gray" + self._var.set("?") elif self._state == ZmqObjectEntry.State.VALID: - self['foreground'] = 'black' + self["foreground"] = "black" elif self._state == ZmqObjectEntry.State.UPDATED: - self['foreground'] = 'blue' + self["foreground"] = "blue" elif self._state == ZmqObjectEntry.State.FOCUSED: - self['foreground'] = 'black' + self["foreground"] = "black" elif self._state == ZmqObjectEntry.State.EDITING: - self['foreground'] = 'red' + self["foreground"] = "red" @AsyncApp.worker_func - async def refresh(self, acquire_alias : bool=False): + async def refresh(self, acquire_alias: bool = False): if self.alive: await self.obj.read(acquire_alias=acquire_alias) self._rate_limit.flush() @AsyncApp.tk_func - def _refresh(self, value : str | None = None): + def _refresh(self, value: str | None = None): if value is None: value = self.obj.value_str.value @@ -739,7 +778,7 @@ def _write(self, *args): self._set_state(ZmqObjectEntry.State.FOCUSED) x = self._var.get() - self.logger.debug(f'Write {x} to {self.obj.name}') + self.logger.debug(f"Write {x} to {self.obj.name}") self.obj.value_str.value = x self.obj.write(block=False) @@ -755,7 +794,7 @@ def _focus_out(self, *args): def _edit(self, e): try: - if e.keysym == 'Escape': + if e.keysym == "Escape": return except: pass @@ -768,6 +807,6 @@ def _revert(self, *args): if self.focused: value = self.obj.value_str.value - self._var.set(value if value is not None else '') + self._var.set(value if value is not None else "") else: self._refresh() diff --git a/python/libstored/asyncio/worker.py b/python/libstored/asyncio/worker.py index e4b018f8..4d94e2d0 100644 --- a/python/libstored/asyncio/worker.py +++ b/python/libstored/asyncio/worker.py @@ -14,14 +14,15 @@ import sys import typing -if sys.platform == 'win32' and sys.version_info < (3, 16): - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) # type: ignore +if sys.platform == "win32" and sys.version_info < (3, 16): + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) # type: ignore from .. import exceptions as lexc -default_worker : AsyncioWorker | None = None +default_worker: AsyncioWorker | None = None + +workers: set[AsyncioWorker] = set() -workers : set[AsyncioWorker] = set() # Do a graceful shutdown when the main thread exits. def monitor_workers(): @@ -32,9 +33,9 @@ def monitor_workers(): except TimeoutError: pass -monitor = threading.Thread(target=monitor_workers, daemon=False, name='AsyncioWorkerMonitor') -monitor.start() +monitor = threading.Thread(target=monitor_workers, daemon=False, name="AsyncioWorkerMonitor") +monitor.start() def current_worker() -> AsyncioWorker | None: @@ -45,18 +46,19 @@ def current_worker() -> AsyncioWorker | None: return None - class AsyncioWorker: - ''' + """ A worker thread running an asyncio event loop. - ''' + """ - def __init__(self, daemon : None | bool=False, name='AsyncioWorker', *args, **kwargs): + def __init__(self, daemon: None | bool = False, name="AsyncioWorker", *args, **kwargs): super().__init__(*args, **kwargs) self.logger = logging.getLogger(__class__.__name__) - self._loop : asyncio.AbstractEventLoop | None = None - self._started : bool = False - self._thread : threading.Thread | None = threading.Thread(target=self._run, daemon=daemon, name=name) + self._loop: asyncio.AbstractEventLoop | None = None + self._started: bool = False + self._thread: threading.Thread | None = threading.Thread( + target=self._run, daemon=daemon, name=name + ) self._thread.start() self.logger.debug("Waiting for event loop to start") while not self._started: @@ -77,7 +79,7 @@ def _run(self): assert self._thread == threading.current_thread() self._loop = asyncio.new_event_loop() asyncio.set_event_loop(self._loop) - self._loop.create_task(self._flag_started(), name=f'{self.__class__.__name__} flag') + self._loop.create_task(self._flag_started(), name=f"{self.__class__.__name__} flag") global default_worker if default_worker is None: @@ -104,7 +106,7 @@ def _run(self): except lexc.InvalidState: pass except: - self.logger.debug('Exception in %s during shutdown', t, exc_info=True) + self.logger.debug("Exception in %s during shutdown", t, exc_info=True) except Exception as e: self.logger.exception("Exception in event loop: %s", e, exc_info=True) raise @@ -123,11 +125,11 @@ async def _flag_started(self): self._started = True def make_default(self): - ''' + """ Make this worker the default worker. Thread-safe. - ''' + """ if not self.is_running(): raise lexc.InvalidState("Event loop is not running") @@ -135,12 +137,12 @@ def make_default(self): global default_worker default_worker = self - def cancel(self, timeout_s : float | None=None): - ''' + def cancel(self, timeout_s: float | None = None): + """ Cancel all tasks and stop the event loop and wait for the thread to exit. Thread-safe. - ''' + """ loop = self._loop if loop is None: @@ -156,12 +158,12 @@ def _cancel(self): for task in asyncio.all_tasks(self._loop): task.cancel() - def stop(self, timeout_s : float | None=None): - ''' + def stop(self, timeout_s: float | None = None): + """ Request to stop the event loop and wait for the thread to exit. Thread-safe. - ''' + """ loop = self._loop if loop is None: @@ -173,12 +175,12 @@ def stop(self, timeout_s : float | None=None): self.wait(timeout_s) - def wait(self, timeout_s : float | None=None): - ''' + def wait(self, timeout_s: float | None = None): + """ Wait for the event loop to complete all tasks. Thread-safe. - ''' + """ global default_worker @@ -205,11 +207,11 @@ def __del__(self): self.stop() def is_running(self) -> bool: - ''' + """ Return True if the event loop is running. Thread-safe. - ''' + """ return self._loop is not None @@ -222,12 +224,14 @@ def __exit__(self, *args): except KeyboardInterrupt: self.cancel() - def execute(self, f : typing.Callable | typing.Coroutine, *args, **kwargs) -> concurrent.futures.Future | asyncio.Future: - ''' + def execute( + self, f: typing.Callable | typing.Coroutine, *args, **kwargs + ) -> concurrent.futures.Future | asyncio.Future: + """ Schedule a coroutine or a callable to be executed in the event loop. Thread-safe and safe to call from within the event loop. - ''' + """ if not self.is_running(): raise lexc.InvalidState("Event loop is not running") @@ -236,14 +240,18 @@ def execute(self, f : typing.Callable | typing.Coroutine, *args, **kwargs) -> co if asyncio.coroutines.iscoroutine(f): if len(args) > 0 or len(kwargs) > 0: - raise TypeError("When passing a coroutine function, no additional arguments are supported") + raise TypeError( + "When passing a coroutine function, no additional arguments are supported" + ) coro = f elif asyncio.iscoroutinefunction(f): coro = f(*args, **kwargs) elif callable(f): coro = self._execute(f, *args, **kwargs) else: - raise TypeError("First argument must be a coroutine function or a callable returning a coroutine") + raise TypeError( + "First argument must be a coroutine function or a callable returning a coroutine" + ) # self.logger.debug("Scheduling coroutine") @@ -256,7 +264,7 @@ def execute(self, f : typing.Callable | typing.Coroutine, *args, **kwargs) -> co # Create coro in the worker thread context, and return async future. return asyncio.run_coroutine_threadsafe(self._log(coro, name), self._loop) - async def _log(self, coro : typing.Coroutine, name : str | None=None) -> typing.Any: + async def _log(self, coro: typing.Coroutine, name: str | None = None) -> typing.Any: try: if name: t = asyncio.current_task(self._loop) @@ -271,31 +279,36 @@ async def _log(self, coro : typing.Coroutine, name : str | None=None) -> typing. except lexc.Disconnected: raise except Exception as e: - self.logger.debug('Exception in %s', asyncio.current_task(self._loop), exc_info=True) + self.logger.debug("Exception in %s", asyncio.current_task(self._loop), exc_info=True) raise async def _execute(self, f, *args, **kwargs) -> typing.Any: return f(*args, **kwargs) -def silence_future(future : concurrent.futures.Future | asyncio.Future, logger : logging.Logger | None=None) -> concurrent.futures.Future | asyncio.Future: - ''' + +def silence_future( + future: concurrent.futures.Future | asyncio.Future, logger: logging.Logger | None = None +) -> concurrent.futures.Future | asyncio.Future: + """ Silences exceptions in a future by adding a done callback that retrieves the result. This prevents "unhandled exception in future" warnings. - ''' + """ + def _callback(fut: concurrent.futures.Future | asyncio.Future): try: fut.result() except Exception as e: if logger is not None: - logger.debug('Silenced exception in %s: %s', fut, e) + logger.debug("Silenced exception in %s: %s", fut, e) future.add_done_callback(_callback) return future -def run_sync(f : typing.Callable) -> typing.Callable: - ''' + +def run_sync(f: typing.Callable) -> typing.Callable: + """ Decorator to run an async function synchronously. If called from within an event loop, the coroutine is directly started, @@ -304,13 +317,13 @@ def run_sync(f : typing.Callable) -> typing.Callable: When block=False is passed, a future is returned instead. When sync is passed, it checks consistency with the detected (a)sync context. - ''' + """ @functools.wraps(f) - def run_sync(*args, block : bool=True, sync : bool | None=None, **kwargs) -> typing.Any: + def run_sync(*args, block: bool = True, sync: bool | None = None, **kwargs) -> typing.Any: assert asyncio.iscoroutinefunction(f) - self : typing.Any = args[0] if len(args) > 0 else None + self: typing.Any = args[0] if len(args) > 0 else None loop = None try: @@ -318,7 +331,9 @@ def run_sync(*args, block : bool=True, sync : bool | None=None, **kwargs) -> typ except RuntimeError: pass - assert sync is None or (loop is None) == sync or not block, 'sync argument contradicts current context' + assert ( + sync is None or (loop is None) == sync or not block + ), "sync argument contradicts current context" if loop is not None and not sync: # We are in an event loop, just start the coro. @@ -336,16 +351,16 @@ def run_sync(*args, block : bool=True, sync : bool | None=None, **kwargs) -> typ logger.debug("Running %s in worker %s", f.__qualname__, str(self.worker)) w = self.worker else: - if hasattr(self, 'logger'): + if hasattr(self, "logger"): logger = self.logger logger.debug("Running %s in default worker", f.__qualname__) global default_worker w = default_worker if w is None or not w.is_running(): - if hasattr(self, 'logger'): + if hasattr(self, "logger"): logger = self.logger - logger.debug('No worker running, creating new one') + logger.debug("No worker running, creating new one") w = AsyncioWorker() future = w.execute(f(*args, **kwargs)) @@ -360,13 +375,18 @@ def run_sync(*args, block : bool=True, sync : bool | None=None, **kwargs) -> typ return run_sync - class Work: - ''' + """ Mixin class for objects that have work to run by an AsyncioWorker. - ''' - - def __init__(self, worker : AsyncioWorker | None=None, logger : logging.Logger | None=None, *args, **kwargs): + """ + + def __init__( + self, + worker: AsyncioWorker | None = None, + logger: logging.Logger | None = None, + *args, + **kwargs, + ): super().__init__(*args, **kwargs) self.logger = logger or logging.getLogger(self.__class__.__name__) @@ -392,13 +412,13 @@ def loop(self) -> asyncio.AbstractEventLoop: @staticmethod @functools.wraps(run_sync) - def run_sync(f : typing.Callable) -> typing.Callable: + def run_sync(f: typing.Callable) -> typing.Callable: # This is just an alias of the global function. return run_sync(f) @staticmethod - def thread_safe_async(f : typing.Callable) -> typing.Callable: - ''' + def thread_safe_async(f: typing.Callable) -> typing.Callable: + """ Decorator to make a method thread-safe by executing it in the worker thread, without waiting for completion. @@ -406,7 +426,7 @@ def thread_safe_async(f : typing.Callable) -> typing.Callable: directly executed. The actual result is returned. Otherwise, a concurrent.futures.Future is returned. - ''' + """ assert not asyncio.iscoroutinefunction(f) @@ -420,20 +440,22 @@ def thread_safe_async(self, *args, **kwargs) -> typing.Any | concurrent.futures. return thread_safe_async @staticmethod - def thread_safe(f : typing.Callable) -> typing.Callable: - ''' + def thread_safe(f: typing.Callable) -> typing.Callable: + """ Decorator to make a method thread-safe by executing it in the worker thread. By default, the call blocks until the method has completed and the actual result is returned. If block=False is passed, a concurrent.futures.Future *may* be returned. - ''' + """ assert not asyncio.iscoroutinefunction(f) @functools.wraps(f) - def thread_safe(self, *args, block=True, **kwargs) -> typing.Any | concurrent.futures.Future: + def thread_safe( + self, *args, block=True, **kwargs + ) -> typing.Any | concurrent.futures.Future: x = Work.thread_safe_async(f)(self, *args, **kwargs) if isinstance(x, concurrent.futures.Future): if block: @@ -446,17 +468,18 @@ def thread_safe(self, *args, block=True, **kwargs) -> typing.Any | concurrent.fu return thread_safe @staticmethod - def locked(f : typing.Callable) -> typing.Callable: - '''Decorator to lock a method with the instance's lock.''' + def locked(f: typing.Callable) -> typing.Callable: + """Decorator to lock a method with the instance's lock.""" @functools.wraps(f) async def locked(self, *args, **kwargs): async with self.lock: return await f(self, *args, **kwargs) + return locked @property def lock(self) -> lexc.DeadlockChecker: - if not hasattr(self, '_lock'): + if not hasattr(self, "_lock"): self._lock = lexc.DeadlockChecker(asyncio.Lock()) return self._lock diff --git a/python/libstored/asyncio/zmq.py b/python/libstored/asyncio/zmq.py index 45a00f00..e046c4d0 100644 --- a/python/libstored/asyncio/zmq.py +++ b/python/libstored/asyncio/zmq.py @@ -32,17 +32,18 @@ from ..heatshrink import HeatshrinkDecoder from .. import exceptions as lexc + class ZmqClientWork(Work): - def __init__(self, client : ZmqClient, *args, **kwargs): + def __init__(self, client: ZmqClient, *args, **kwargs): super().__init__(worker=client.worker, *args, **kwargs) - self._client : ZmqClient | None = client + self._client: ZmqClient | None = client def alive(self) -> bool: - '''Check if this object is still alive, i.e. the client connection is still active.''' + """Check if this object is still alive, i.e. the client connection is still active.""" return self._client is not None def destroy(self): - '''Destroy this object, as the client connection is closed.''' + """Destroy this object, as the client connection is closed.""" self._client = None def __del__(self): @@ -50,12 +51,13 @@ def __del__(self): @property def client(self) -> ZmqClient: - '''The ZmqClient this object belongs to.''' + """The ZmqClient this object belongs to.""" if not self.alive(): - raise lexc.Disconnected('Object destroyed, client connection closed') + raise lexc.Disconnected("Object destroyed, client connection closed") assert self._client is not None return self._client + class Object(ZmqClientWork, Value): """A variable or function as handled by a ZmqClient @@ -63,47 +65,55 @@ class Object(ZmqClientWork, Value): """ @staticmethod - def create(s : str, client : ZmqClient) -> Object | None: - ''' + def create(s: str, client: ZmqClient) -> Object | None: + """ Create an Object from a List response line. Return None if the line is invalid. - ''' + """ - split = s.split('/', 1) + split = s.split("/", 1) if len(split) < 2: return None if len(split[0]) < 3: return None try: - return Object('/' + split[1], int(split[0][0:2], 16), int(split[0][2:], 16), client) + return Object("/" + split[1], int(split[0][0:2], 16), int(split[0][2:], 16), client) except ValueError: return None - def __init__(self, name : str, type : int, size : int, client : ZmqClient, *args, **kwargs): + def __init__(self, name: str, type: int, size: int, client: ZmqClient, *args, **kwargs): self._name = name self._type_id = type self._size = size - super().__init__(client=client, type=self.value_type, event_name=f'{name}/value', *args, **kwargs) - - self._format : str = '' - self._formatter : typing.Callable[..., str] | None = None - self._poller : asyncio.Task | None = None - self._poll_interval_s : float | None = None - - self.alias = Value(str, event_name=f'{name}/alias') - self.t = Value(float, event_name=f'{name}/t') - self.value_str = ValueWrapper(str, self._value_str_get, self._value_str_set, event_name=f'{name}/value_str') - self.polling = ValueWrapper(float, lambda: self.poll_interval, self._poll_set, event_name=f'{name}/polling') - self.format = ValueWrapper(str, self._format_get, self._format_set, event_name=f'{name}/format') - self.format.value = 'default' + super().__init__( + client=client, type=self.value_type, event_name=f"{name}/value", *args, **kwargs + ) + + self._format: str = "" + self._formatter: typing.Callable[..., str] | None = None + self._poller: asyncio.Task | None = None + self._poll_interval_s: float | None = None + + self.alias = Value(str, event_name=f"{name}/alias") + self.t = Value(float, event_name=f"{name}/t") + self.value_str = ValueWrapper( + str, self._value_str_get, self._value_str_set, event_name=f"{name}/value_str" + ) + self.polling = ValueWrapper( + float, lambda: self.poll_interval, self._poll_set, event_name=f"{name}/polling" + ) + self.format = ValueWrapper( + str, self._format_get, self._format_set, event_name=f"{name}/format" + ) + self.format.value = "default" @property def name(self) -> str: - '''The full name of this object.''' + """The full name of this object.""" return self._name def __str__(self) -> str: - return f'{self.name} = {repr(self.value)}' + return f"{self.name} = {repr(self.value)}" def destroy(self): super().destroy() @@ -114,19 +124,17 @@ def destroy(self): self.value = None - - ################################################# # Type @property def type_id(self): - '''The type code of this object.''' + """The type code of this object.""" return self._type_id @property def size(self): - '''The size of this object.''' + """The size of this object.""" return self._size FlagSigned = 0x8 @@ -153,7 +161,7 @@ def size(self): Blob = 1 String = 2 - Invalid = 0xff + Invalid = 0xFF def is_valid_type(self) -> bool: return self._type_id & 0x80 == 0 @@ -175,73 +183,75 @@ def is_special(self) -> bool: @property def type_name(self) -> str: - '''Get the type name as used in the store definition.''' + """Get the type name as used in the store definition.""" dtype = self._type_id & ~self.FlagFunction t = { - self.Int8: 'int8', - self.Uint8: 'uint8', - self.Int16: 'int16', - self.Uint16: 'uint16', - self.Int32: 'int32', - self.Uint32: 'uint32', - self.Int64: 'int64', - self.Uint64: 'uint64', - self.Float: 'float', - self.Double: 'double', - self.Pointer32: 'ptr32', - self.Pointer64: 'ptr64', - self.Bool: 'bool', - self.Blob: 'blob', - self.String: 'string', - self.Void: 'void', - }.get(dtype, '?') + self.Int8: "int8", + self.Uint8: "uint8", + self.Int16: "int16", + self.Uint16: "uint16", + self.Int32: "int32", + self.Uint32: "uint32", + self.Int64: "int64", + self.Uint64: "uint64", + self.Float: "float", + self.Double: "double", + self.Pointer32: "ptr32", + self.Pointer64: "ptr64", + self.Bool: "bool", + self.Blob: "blob", + self.String: "string", + self.Void: "void", + }.get(dtype, "?") if dtype in [self.Blob, self.String]: - t = f'{t}:{self.size}' - return f'({t})' if self.is_function() else t + t = f"{t}:{self.size}" + return f"({t})" if self.is_function() else t @property def value_type(self) -> typing.Type: - '''Get the Python type used for the value of this object.''' + """Get the Python type used for the value of this object.""" dtype = self._type_id & ~self.FlagFunction t = { - self.Int8: int, - self.Uint8: int, - self.Int16: int, - self.Uint16: int, - self.Int32: int, - self.Uint32: int, - self.Int64: int, - self.Uint64: int, - self.Float: float, - self.Double: float, - self.Pointer32: int, - self.Pointer64: int, - self.Bool: bool, - self.Blob: bytearray, - self.String: str, - self.Void: type(None), - }.get(dtype, type(None)) + self.Int8: int, + self.Uint8: int, + self.Int16: int, + self.Uint16: int, + self.Int32: int, + self.Uint32: int, + self.Int64: int, + self.Uint64: int, + self.Float: float, + self.Double: float, + self.Pointer32: int, + self.Pointer64: int, + self.Bool: bool, + self.Blob: bytearray, + self.String: str, + self.Void: type(None), + }.get(dtype, type(None)) return t - - ############################################### # Read @overload - async def short_name(self, acquire : bool=True) -> str: ... + async def short_name(self, acquire: bool = True) -> str: ... @overload - def short_name(self, acquire : bool=True, *, block : typing.Literal[False]) -> asyncio.Future[str]: ... + def short_name( + self, acquire: bool = True, *, block: typing.Literal[False] + ) -> asyncio.Future[str]: ... @overload - def short_name(self, acquire : bool=True, *, sync : typing.Literal[True]) -> str: ... + def short_name(self, acquire: bool = True, *, sync: typing.Literal[True]) -> str: ... @overload - def short_name(self, acquire : bool=True, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[str]: ... + def short_name( + self, acquire: bool = True, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[str]: ... @ZmqClientWork.run_sync - async def short_name(self, acquire : bool=True) -> str: - ''' + async def short_name(self, acquire: bool = True) -> str: + """ Get the alias of this object, or its full name if no alias is set. **Arguments** @@ -254,7 +264,7 @@ async def short_name(self, acquire : bool=True) -> str: **Raises** * `OperationFailed`: when the Alias command failed - ''' + """ if not self.alias.value is None: return self.alias.value @@ -271,17 +281,25 @@ async def short_name(self, acquire : bool=True) -> str: return self.name @overload - async def read(self, acquire_alias : bool=True) -> typing.Any: ... + async def read(self, acquire_alias: bool = True) -> typing.Any: ... @overload - def read(self, acquire_alias : bool=True, *, block : typing.Literal[False]) -> asyncio.Future[typing.Any]: ... + def read( + self, acquire_alias: bool = True, *, block: typing.Literal[False] + ) -> asyncio.Future[typing.Any]: ... @overload - def read(self, acquire_alias : bool=True, *, sync : typing.Literal[True]) -> typing.Any: ... + def read(self, acquire_alias: bool = True, *, sync: typing.Literal[True]) -> typing.Any: ... @overload - def read(self, acquire_alias : bool=True, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[typing.Any]: ... + def read( + self, + acquire_alias: bool = True, + *, + block: typing.Literal[False], + sync: typing.Literal[True], + ) -> concurrent.futures.Future[typing.Any]: ... @ZmqClientWork.run_sync - async def read(self, acquire_alias : bool=True) -> typing.Any: - ''' + async def read(self, acquire_alias: bool = True) -> typing.Any: + """ Read the value of this object from the server. **Arguments** @@ -294,19 +312,19 @@ async def read(self, acquire_alias : bool=True) -> typing.Any: **Raises** * `OperationFailed`: when the read operation failed - ''' + """ return await self._read(acquire_alias) - async def _read(self, acquire_alias : bool=True) -> typing.Any: + async def _read(self, acquire_alias: bool = True) -> typing.Any: name = await self.short_name(acquire_alias) t = time.time() - rep = await self.client.req(b'r' + name.encode()) + rep = await self.client.req(b"r" + name.encode()) return self.handle_read(rep, t) - def handle_read(self, rep : bytes, t=None) -> typing.Any: - '''Handle a read reply.''' + def handle_read(self, rep: bytes, t=None) -> typing.Any: + """Handle a read reply.""" - if rep == b'?': + if rep == b"?": return None try: self.set(self._decode(rep), t) @@ -314,20 +332,20 @@ def handle_read(self, rep : bytes, t=None) -> typing.Any: pass return self.value - def _decode_hex(self, data : bytes) -> bytearray: + def _decode_hex(self, data: bytes) -> bytearray: if len(data) % 2 == 1: - data = b'0' + data + data = b"0" + data res = bytearray() for i in range(0, len(data), 2): - res.append(int(data[i:i+2], 16)) + res.append(int(data[i : i + 2], 16)) return res @staticmethod - def _sign_extend(value : int, bits : int) -> int: + def _sign_extend(value: int, bits: int) -> int: sign_bit = 1 << (bits - 1) return (value & (sign_bit - 1)) - (value & sign_bit) - def _decode(self, rep : bytes) -> typing.Any: + def _decode(self, rep: bytes) -> typing.Any: dtype = self._type_id & ~self.FlagFunction if self.is_fixed(): binint = int(rep.decode(), 16) @@ -342,9 +360,9 @@ def _decode(self, rep : bytes) -> typing.Any: elif dtype == self.Int64: return self._sign_extend(binint, 64) elif dtype == self.Float: - return struct.unpack(' typing.Any: else: raise ValueError() elif dtype == self.Void: - return b'' + return b"" elif dtype == self.Blob: return self._decode_hex(rep) elif dtype == self.String: - return self._decode_hex(rep).partition(b'\x00')[0].decode() + return self._decode_hex(rep).partition(b"\x00")[0].decode() else: raise ValueError() def get(self) -> typing.Any: - '''Get the locally cached value of this object.''' + """Get the locally cached value of this object.""" return self.value - - ############################################### # Write @overload - async def write(self, value : typing.Any=None) -> None: ... + async def write(self, value: typing.Any = None) -> None: ... @overload - def write(self, value : typing.Any=None, *, block : typing.Literal[False]) -> asyncio.Future[None]: ... + def write( + self, value: typing.Any = None, *, block: typing.Literal[False] + ) -> asyncio.Future[None]: ... @overload - def write(self, value : typing.Any=None, *, sync : typing.Literal[True]) -> None: ... + def write(self, value: typing.Any = None, *, sync: typing.Literal[True]) -> None: ... @overload - def write(self, value : typing.Any=None, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[None]: ... + def write( + self, value: typing.Any = None, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[None]: ... @ZmqClientWork.run_sync - async def write(self, value : typing.Any=None) -> None: - ''' + async def write(self, value: typing.Any = None) -> None: + """ Write a value to this object on the server. **Arguments** @@ -394,7 +414,7 @@ async def write(self, value : typing.Any=None) -> None: **Raises** * `OperationFailed`: when the write operation failed * `ValueError`: when the value cannot be encoded - ''' + """ if value is not None: self.set(value) @@ -406,55 +426,57 @@ async def write(self, value : typing.Any=None) -> None: data = self._encode(value) name = await self.short_name() - req = b'w' + data + name.encode() + req = b"w" + data + name.encode() rep = await self.client.req(req) - if rep != b'!': - raise lexc.OperationFailed('Write failed') + if rep != b"!": + raise lexc.OperationFailed("Write failed") - def _encode_hex(self, data, zerostrip = False) -> bytes: - s = b''.join([b'%02x' % b for b in data]) + def _encode_hex(self, data, zerostrip=False) -> bytes: + s = b"".join([b"%02x" % b for b in data]) if zerostrip: - s = s.lstrip(b'0') - if s == b'': - s = b'0' + s = s.lstrip(b"0") + if s == b"": + s = b"0" return s - def _encode(self, value : typing.Any) -> bytes: + def _encode(self, value: typing.Any) -> bytes: dtype = self._type_id & ~self.FlagFunction if dtype == self.Void: - return b'' + return b"" elif dtype == self.Blob: return self._encode_hex(value) elif dtype == self.String: - return self._encode_hex(value.encode()) + b'00' + return self._encode_hex(value.encode()) + b"00" elif dtype == self.Pointer32: - return ('%x' % value).encode() + return ("%x" % value).encode() elif dtype == self.Pointer64: - return ('%x' % value).encode() + return ("%x" % value).encode() elif dtype == self.Bool: - return b'1' if value else b'0' + return b"1" if value else b"0" elif dtype == self.Float: - return self._encode_hex(struct.pack('>f', value)) + return self._encode_hex(struct.pack(">f", value)) elif dtype == self.Double: - return self._encode_hex(struct.pack('>d', value)) + return self._encode_hex(struct.pack(">d", value)) elif not self.is_int(): - raise TypeError('Invalid type for encoding') + raise TypeError("Invalid type for encoding") elif self.is_signed(): - return self._encode_hex(struct.pack('>q', value)[-self._size:], True) + return self._encode_hex(struct.pack(">q", value)[-self._size :], True) else: if value < 0: value += 1 << 64 - return self._encode_hex(struct.pack('>Q', value)[-self._size:], True) + return self._encode_hex(struct.pack(">Q", value)[-self._size :], True) @overload - def set(self, value : typing.Any, t : float | None=None) -> None: ... + def set(self, value: typing.Any, t: float | None = None) -> None: ... @overload - def set(self, value : typing.Any, t : float | None=None, *, block : typing.Literal[False]) -> concurrent.futures.Future[None] | None: ... + def set( + self, value: typing.Any, t: float | None = None, *, block: typing.Literal[False] + ) -> concurrent.futures.Future[None] | None: ... @ZmqClientWork.thread_safe - def set(self, value : typing.Any, t : float | None=None) -> None: - ''' + def set(self, value: typing.Any, t: float | None = None) -> None: + """ Set the value of this object, without actually writing it yet to the server. **Arguments** @@ -465,10 +487,10 @@ def set(self, value : typing.Any, t : float | None=None) -> None: **Result** * `None`: when `block = True` * otherwise a future - ''' + """ if type(value) != self.value_type: - raise TypeError(f'Expected value of type {self.value_type}, got {type(value)}') + raise TypeError(f"Expected value of type {self.value_type}, got {type(value)}") if t is None: t = time.time() @@ -479,12 +501,18 @@ def set(self, value : typing.Any, t : float | None=None) -> None: if not self.is_fixed(): if isinstance(value, str) or isinstance(value, bytes): - value = value[0:self.size] + value = value[0 : self.size] self.t.pause() self.t.value = t - if isinstance(value, float) and math.isnan(value) and self.type == float and self.value is not None and math.isnan(self.value): + if ( + isinstance(value, float) + and math.isnan(value) + and self.type == float + and self.value is not None + and math.isnan(self.value) + ): # Not updated, even though value != self._value would be True pass elif value != self.value: @@ -493,8 +521,6 @@ def set(self, value : typing.Any, t : float | None=None) -> None: self.t.resume() - - ############################################### # String conversion @@ -507,12 +533,12 @@ def _interpret_float(self, value): except ValueError: return float(self._interpret_int(value)) - def interpret(self, value : str) -> typing.Any: - '''Interpret a string as a value of the appropriate type for this object.''' + def interpret(self, value: str) -> typing.Any: + """Interpret a string as a value of the appropriate type for this object.""" - value = value.strip().replace(' ', '') + value = value.strip().replace(" ", "") - if not hasattr(self, '_interpret_map'): + if not hasattr(self, "_interpret_map"): self._interpret_map = { self.Int8: self._interpret_int, self.Uint8: self._interpret_int, @@ -524,9 +550,9 @@ def interpret(self, value : str) -> typing.Any: self.Uint64: self._interpret_int, self.Float: self._interpret_float, self.Double: self._interpret_float, - self.Pointer32: lambda x: int(x,0), - self.Pointer64: lambda x: int(x,0), - self.Bool: lambda x: x.lower() in ['true', '1'], + self.Pointer32: lambda x: int(x, 0), + self.Pointer64: lambda x: int(x, 0), + self.Bool: lambda x: x.lower() in ["true", "1"], self.Blob: lambda x: x.encode(), self.String: lambda x: x, self.Void: lambda x: bytes(), @@ -534,27 +560,27 @@ def interpret(self, value : str) -> typing.Any: return self._interpret_map.get(self._type_id & ~self.FlagFunction, lambda x: x)(value) - def _format_int(self, x : int) -> str: - return locale.format_string('%d', x, True) + def _format_int(self, x: int) -> str: + return locale.format_string("%d", x, True) - def _format_float(self, x : float, f : str, prec : int) -> str: - return locale.format_string(f'%.{prec}{f}', x, True) + def _format_float(self, x: float, f: str, prec: int) -> str: + return locale.format_string(f"%.{prec}{f}", x, True) - def _format_bytes(self, value : typing.Any) -> str: + def _format_bytes(self, value: typing.Any) -> str: value = self._encode(value).decode() - value = '0' * (self._size * 2 - len(value)) + value - res = '' + value = "0" * (self._size * 2 - len(value)) + value + res = "" for i in range(0, len(value), 2): if res != []: - res += ' ' - res += value[i:i+2] + res += " " + res += value[i : i + 2] return res def _format_get(self): - '''Get or set the format used to convert the value to a string.''' + """Get or set the format used to convert the value to a string.""" return self._format - def _format_set(self, f : str): + def _format_set(self, f: str): if self._format == f: return if not f in self.formats(): @@ -562,16 +588,16 @@ def _format_set(self, f : str): self._format = f - if f == 'hex': + if f == "hex": self._formatter = lambda x: hex(x & (1 << self._size * 8) - 1) - elif f == 'bin': + elif f == "bin": self._formatter = bin - elif f == 'bytes' or self._type_id & ~self.FlagFunction == self.Blob: + elif f == "bytes" or self._type_id & ~self.FlagFunction == self.Blob: self._formatter = self._format_bytes elif self._type_id & ~self.FlagFunction == self.Float: - self._formatter = lambda x: self._format_float(x, 'g', 6) + self._formatter = lambda x: self._format_float(x, "g", 6) elif self._type_id & ~self.FlagFunction == self.Double: - self._formatter = lambda x: self._format_float(x, 'g', 15) + self._formatter = lambda x: self._format_float(x, "g", 15) elif self._type_id & self.FlagInt: self._formatter = self._format_int else: @@ -581,32 +607,32 @@ def _format_set(self, f : str): self.format.trigger() def formats(self) -> list[str]: - '''Get the list of supported formats for this object.''' + """Get the list of supported formats for this object.""" - f = ['default', 'bytes'] + f = ["default", "bytes"] if self._type_id & ~self.FlagFunction == self.Blob: return f if self.is_int(): - f += ['hex', 'bin'] + f += ["hex", "bin"] return f def _value_str_get(self) -> str: - '''Get the string representation of the value of this object.''' + """Get the string representation of the value of this object.""" x = self.value if x is None: - return '' + return "" assert self._formatter is not None try: return self._formatter(x) except: - return '?' + return "?" - def _value_str_set(self, s : str): - '''Set the value of this object from a string representation.''' + def _value_str_set(self, s: str): + """Set the value of this object from a string representation.""" - if s == '': + if s == "": self.set(self.type(), block=False) else: try: @@ -614,23 +640,29 @@ def _value_str_set(self, s : str): except: self.value_str.trigger() - - ############################################### # Polling @overload - async def poll(self, interval_s : float | None=None) -> None: ... + async def poll(self, interval_s: float | None = None) -> None: ... @overload - def poll(self, interval_s : float | None=None, *, block : typing.Literal[False]) -> asyncio.Future[None]: ... + def poll( + self, interval_s: float | None = None, *, block: typing.Literal[False] + ) -> asyncio.Future[None]: ... @overload - def poll(self, interval_s : float | None=None, *, sync : typing.Literal[True]) -> None: ... + def poll(self, interval_s: float | None = None, *, sync: typing.Literal[True]) -> None: ... @overload - def poll(self, interval_s : float | None=None, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[None]: ... + def poll( + self, + interval_s: float | None = None, + *, + block: typing.Literal[False], + sync: typing.Literal[True], + ) -> concurrent.futures.Future[None]: ... @ZmqClientWork.run_sync - async def poll(self, interval_s : float | None=None): - '''Set up polling of this object. + async def poll(self, interval_s: float | None = None): + """Set up polling of this object. If interval_s is None (the default), stop polling. If interval_s is 0, poll as fast as possible. @@ -643,13 +675,13 @@ async def poll(self, interval_s : float | None=None): **Result** * `None`: when `block = True` * otherwise a future - ''' + """ if not self.alive(): - raise lexc.InvalidState('Object destroyed, client connection closed') + raise lexc.InvalidState("Object destroyed, client connection closed") if interval_s is not None and interval_s < 0: - raise ValueError('interval_s must be None or >= 0') + raise ValueError("interval_s must be None or >= 0") if interval_s is None: pass @@ -657,14 +689,14 @@ async def poll(self, interval_s : float | None=None): # Good enough. interval_s = float(interval_s) elif not isinstance(interval_s, float): - raise ValueError('interval_s must be None or a float') + raise ValueError("interval_s must be None or a float") self._poll_slow_stop() self._poll_interval_s = interval_s await self.client._poll(self, interval_s) self.polling.trigger() - def _poll_set(self, interval_s : float | None): + def _poll_set(self, interval_s: float | None): self.poll(interval_s, block=False) def _poll_slow_stop(self): @@ -672,50 +704,54 @@ def _poll_slow_stop(self): self._poller.cancel() self._poller = None - async def _poll_slow(self, interval_s : float): + async def _poll_slow(self, interval_s: float): self._poll_slow_stop() - self._poller = self.client.periodic(interval_s, self._read, name=f'poll {self.name}') + self._poller = self.client.periodic(interval_s, self._read, name=f"poll {self.name}") @property def poll_interval(self) -> float | None: - '''Get the current polling interval, or None if not polling.''' + """Get the current polling interval, or None if not polling.""" return self._poll_interval_s - - ############################################### # State def state(self) -> dict[str, dict[str, typing.Any]]: - '''Get the state of this object as a JSON-serializable dictionary.''' + """Get the state of this object as a JSON-serializable dictionary.""" default = True - s : dict[str, typing.Any] = {} + s: dict[str, typing.Any] = {} - if self.format.value != 'default': - s['format'] = self.format.value + if self.format.value != "default": + s["format"] = self.format.value default = False p = self.poll_interval if not p is None: - s['poll_interval'] = p + s["poll_interval"] = p default = False - return {} if default else { self.name: s } + return {} if default else {self.name: s} @overload - async def restore_state(self, state : dict[str, dict[str, typing.Any]]) -> None: ... + async def restore_state(self, state: dict[str, dict[str, typing.Any]]) -> None: ... @overload - def restore_state(self, state : dict[str, dict[str, typing.Any]], *, block : typing.Literal[False]) -> asyncio.Future[None]: ... + def restore_state( + self, state: dict[str, dict[str, typing.Any]], *, block: typing.Literal[False] + ) -> asyncio.Future[None]: ... @overload - def restore_state(self, state : dict[str, dict[str, typing.Any]], *, sync : typing.Literal[True]) -> None: ... + def restore_state( + self, state: dict[str, dict[str, typing.Any]], *, sync: typing.Literal[True] + ) -> None: ... @overload - def restore_state(self, state : dict[str, dict[str, typing.Any]], *, block : typing.Literal[True]) -> asyncio.Future[None]: ... + def restore_state( + self, state: dict[str, dict[str, typing.Any]], *, block: typing.Literal[True] + ) -> asyncio.Future[None]: ... @ZmqClientWork.run_sync - async def restore_state(self, state : dict): - ''' + async def restore_state(self, state: dict): + """ Restore the state of this object from a dictionary as returned by state(). **Arguments** @@ -725,42 +761,41 @@ async def restore_state(self, state : dict): **Result** * `None`: when `block = True` * otherwise a future - ''' + """ if not self.name in state: return if not self.alive(): - raise lexc.InvalidState('Object not connected') + raise lexc.InvalidState("Object not connected") s = state[self.name] try: - if 'format' in s: - self.format.value = s['format'] + if "format" in s: + self.format.value = s["format"] except ValueError: pass try: - if 'poll_interval' in s: - await self.poll(float(s['poll_interval'])) + if "poll_interval" in s: + await self.poll(float(s["poll_interval"])) except ValueError: pass - class Stream(ZmqClientWork): - def __init__(self, client : ZmqClient, name : str, raw : bool=False, *args, **kwargs): + def __init__(self, client: ZmqClient, name: str, raw: bool = False, *args, **kwargs): super().__init__(client=client, *args, **kwargs) self._raw = raw if not isinstance(name, str) or len(name) != 1: - raise ValueError('Invalid stream name ' + name) + raise ValueError("Invalid stream name " + name) self._name = name self._finishing = False self._flushing = False - self._decoder : HeatshrinkDecoder | None = None + self._decoder: HeatshrinkDecoder | None = None self._initialized = False self._compressed = False @@ -777,25 +812,29 @@ async def _init(self): return cap = await self.client.capabilities() - if not 's' in cap: - raise lexc.NotSupported('Stream capability missing') + if not "s" in cap: + raise lexc.NotSupported("Stream capability missing") - self._compressed = 'f' in cap + self._compressed = "f" in cap self._initialized = True await self.reset() @overload - async def poll(self, suffix : str='') -> str | bytes | bytearray: ... + async def poll(self, suffix: str = "") -> str | bytes | bytearray: ... @overload - def poll(self, suffix : str='', *, block : typing.Literal[False]) -> asyncio.Future[str | bytes | bytearray]: ... + def poll( + self, suffix: str = "", *, block: typing.Literal[False] + ) -> asyncio.Future[str | bytes | bytearray]: ... @overload - def poll(self, suffix : str='', *, sync : typing.Literal[True]) -> str | bytes | bytearray: ... + def poll(self, suffix: str = "", *, sync: typing.Literal[True]) -> str | bytes | bytearray: ... @overload - def poll(self, suffix : str='', *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[str | bytes | bytearray]: ... + def poll( + self, suffix: str = "", *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[str | bytes | bytearray]: ... @ZmqClientWork.run_sync - async def poll(self, suffix : str='') -> str | bytes | bytearray: - ''' + async def poll(self, suffix: str = "") -> str | bytes | bytearray: + """ Poll the stream for new data. **Arguments** @@ -805,12 +844,12 @@ async def poll(self, suffix : str='') -> str | bytes | bytearray: **Result** * `str | bytes | bytearray`: the new data when `block = True` * otherwise a future with this `str | bytes | bytearray` - ''' + """ await self._init() - req = b's' + (self.name + suffix).encode() + req = b"s" + (self.name + suffix).encode() return self._decode(await self.client.req(req)) - def _decode(self, x : bytes) -> str | bytes | bytearray: + def _decode(self, x: bytes) -> str | bytes | bytearray: if self._decoder is not None: x = self._decoder.fill(x) if self._finishing: @@ -820,20 +859,22 @@ def _decode(self, x : bytes) -> str | bytes | bytearray: if self.raw: return x else: - return x.decode(errors='backslashreplace') + return x.decode(errors="backslashreplace") @overload async def flush(self) -> None: ... @overload - def flush(self, *, block : typing.Literal[False]) -> asyncio.Future[None]: ... + def flush(self, *, block: typing.Literal[False]) -> asyncio.Future[None]: ... @overload - def flush(self, *, sync : typing.Literal[True]) -> None: ... + def flush(self, *, sync: typing.Literal[True]) -> None: ... @overload - def flush(self, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[None]: ... + def flush( + self, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[None]: ... @ZmqClientWork.run_sync async def flush(self) -> None: - ''' + """ Flush the stream, to finalize the compression, if any. **Arguments** @@ -842,25 +883,27 @@ async def flush(self) -> None: **Result** * `None`: when `block = True` * otherwise a future - ''' + """ if self._compressed and not self._flushing and not self._finishing: self._flushing = True - await self.client.req(b'f' + self.name.encode()) + await self.client.req(b"f" + self.name.encode()) self._flushing = False self._finishing = True @overload async def reset(self) -> None: ... @overload - def reset(self, *, block : typing.Literal[False]) -> asyncio.Future[None]: ... + def reset(self, *, block: typing.Literal[False]) -> asyncio.Future[None]: ... @overload - def reset(self, *, sync : typing.Literal[True]) -> None: ... + def reset(self, *, sync: typing.Literal[True]) -> None: ... @overload - def reset(self, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[None]: ... + def reset( + self, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[None]: ... @ZmqClientWork.run_sync async def reset(self) -> None: - ''' + """ Reset the compressed stream, when compression is enabled. **Arguments** @@ -869,12 +912,12 @@ async def reset(self) -> None: **Result** * `None`: when `block = True` * otherwise a future - ''' + """ if self._compressed: - await self.client.req(b'f' + self.name.encode()) + await self.client.req(b"f" + self.name.encode()) # Drop old data, as we missed the start of the stream. - await self.client.req(b's' + self.name.encode()) + await self.client.req(b"s" + self.name.encode()) self._reset() def _reset(self): @@ -884,25 +927,34 @@ def _reset(self): self._flushing = False - class Macro(ZmqClientWork): """Macro object as returned by ZmqClient.macro() Do not instantiate directly, but let ZmqClient acquire one for you. """ - def __init__(self, client : ZmqClient, macro : str | None=None, reqsep : bytes=b'\n', repsep : bytes=b' ', *args, **kwargs): + def __init__( + self, + client: ZmqClient, + macro: str | None = None, + reqsep: bytes = b"\n", + repsep: bytes = b" ", + *args, + **kwargs, + ): super().__init__(client=client, *args, **kwargs) if macro is not None and len(macro) != 1: - raise ValueError('Invalid macro name ' + macro) + raise ValueError("Invalid macro name " + macro) self._macro = None if macro is None else macro.encode() - self._cmds : dict[typing.Hashable, tuple[bytes, typing.Callable[[bytes, float | None], None] | None]] = {} + self._cmds: dict[ + typing.Hashable, tuple[bytes, typing.Callable[[bytes, float | None], None] | None] + ] = {} self._key = 0 if len(reqsep) != 1: - raise ValueError('Invalid request separator') + raise ValueError("Invalid request separator") self._reqsep = reqsep self._repsep = repsep @@ -918,18 +970,50 @@ def macro(self) -> bytes | None: return self._macro @overload - async def add(self, cmd : str, cb : typing.Callable[[bytes, float | None], None] | None=None, key : typing.Hashable | None=None) -> None: ... - @overload - def add(self, cmd : str, cb : typing.Callable[[bytes, float | None], None] | None=None, key : typing.Hashable | None=None, *, block : typing.Literal[False]) -> asyncio.Future[None]: ... - @overload - def add(self, cmd : str, cb : typing.Callable[[bytes, float | None], None] | None=None, key : typing.Hashable | None=None, *, sync : typing.Literal[True]) -> None: ... - @overload - def add(self, cmd : str, cb : typing.Callable[[bytes, float | None], None] | None=None, key : typing.Hashable | None=None, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[None]: ... + async def add( + self, + cmd: str, + cb: typing.Callable[[bytes, float | None], None] | None = None, + key: typing.Hashable | None = None, + ) -> None: ... + @overload + def add( + self, + cmd: str, + cb: typing.Callable[[bytes, float | None], None] | None = None, + key: typing.Hashable | None = None, + *, + block: typing.Literal[False], + ) -> asyncio.Future[None]: ... + @overload + def add( + self, + cmd: str, + cb: typing.Callable[[bytes, float | None], None] | None = None, + key: typing.Hashable | None = None, + *, + sync: typing.Literal[True], + ) -> None: ... + @overload + def add( + self, + cmd: str, + cb: typing.Callable[[bytes, float | None], None] | None = None, + key: typing.Hashable | None = None, + *, + block: typing.Literal[False], + sync: typing.Literal[True], + ) -> concurrent.futures.Future[None]: ... @ZmqClientWork.run_sync @ZmqClientWork.locked - async def add(self, cmd : str, cb : typing.Callable[[bytes, float | None], None] | None=None, key : typing.Hashable | None=None): - ''' + async def add( + self, + cmd: str, + cb: typing.Callable[[bytes, float | None], None] | None = None, + key: typing.Hashable | None = None, + ): + """ Add a command to this macro. **Arguments** @@ -941,7 +1025,7 @@ async def add(self, cmd : str, cb : typing.Callable[[bytes, float | None], None] **Result** * `bool`: True if the command was added successfully, False otherwise, when `block = True` * otherwise a future with this `bool` - ''' + """ if key is None: key = self._key @@ -966,21 +1050,25 @@ async def add(self, cmd : str, cb : typing.Callable[[bytes, float | None], None] except RuntimeError: # Rollback. await self._remove(key) - raise lexc.OperationFailed('Cannot add to macro') + raise lexc.OperationFailed("Cannot add to macro") @overload - async def remove(self, key : typing.Hashable) -> bool: ... + async def remove(self, key: typing.Hashable) -> bool: ... @overload - def remove(self, key : typing.Hashable, *, block : typing.Literal[False]) -> asyncio.Future[bool]: ... + def remove( + self, key: typing.Hashable, *, block: typing.Literal[False] + ) -> asyncio.Future[bool]: ... @overload - def remove(self, key : typing.Hashable, *, sync : typing.Literal[True]) -> bool: ... + def remove(self, key: typing.Hashable, *, sync: typing.Literal[True]) -> bool: ... @overload - def remove(self, key : typing.Hashable, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[bool]: ... + def remove( + self, key: typing.Hashable, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[bool]: ... @ZmqClientWork.run_sync @ZmqClientWork.locked - async def remove(self, key : typing.Hashable) -> bool: - ''' + async def remove(self, key: typing.Hashable) -> bool: + """ Remove a command from this macro. **Arguments** @@ -990,10 +1078,10 @@ async def remove(self, key : typing.Hashable) -> bool: **Result** * `bool`: True if the command was removed successfully, when `block = True` * otherwise a future with this `bool` - ''' + """ return await self._remove(key) - async def _remove(self, key : typing.Hashable) -> bool: + async def _remove(self, key: typing.Hashable) -> bool: if key in self._cmds: del self._cmds[key] await self._update() @@ -1004,16 +1092,18 @@ async def _remove(self, key : typing.Hashable) -> bool: @overload async def clear(self) -> None: ... @overload - def clear(self, *, block : typing.Literal[False]) -> asyncio.Future[None]: ... + def clear(self, *, block: typing.Literal[False]) -> asyncio.Future[None]: ... @overload - def clear(self, *, sync : typing.Literal[True]) -> None: ... + def clear(self, *, sync: typing.Literal[True]) -> None: ... @overload - def clear(self, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[None]: ... + def clear( + self, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[None]: ... @ZmqClientWork.run_sync @ZmqClientWork.locked async def clear(self) -> None: - ''' + """ Clear all commands from this macro. **Arguments** @@ -1022,7 +1112,7 @@ async def clear(self) -> None: **Result** * `None`: when `block = True` * otherwise a future - ''' + """ await self._clear() async def _clear(self): @@ -1035,31 +1125,33 @@ async def _update(self): if m is None: return - cmds = [b'm' + m] + cmds = [b"m" + m] first = True for c in self._cmds.values(): if not first: - cmds.append(b'e' + self._repsep) + cmds.append(b"e" + self._repsep) cmds.append(c[0]) first = False definition = self._reqsep.join(cmds) - if await self.client.req(definition) != b'!': - raise lexc.OperationFailed('Macro definition failed') + if await self.client.req(definition) != b"!": + raise lexc.OperationFailed("Macro definition failed") @overload async def run(self) -> None: ... @overload - def run(self, *, block : typing.Literal[False]) -> asyncio.Future[None]: ... + def run(self, *, block: typing.Literal[False]) -> asyncio.Future[None]: ... @overload - def run(self, *, sync : typing.Literal[True]) -> None: ... + def run(self, *, sync: typing.Literal[True]) -> None: ... @overload - def run(self, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[None]: ... + def run( + self, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[None]: ... @ZmqClientWork.run_sync @ZmqClientWork.locked async def run(self): - ''' + """ Run this macro. **Arguments** @@ -1068,7 +1160,7 @@ async def run(self): **Result** * `None`: when `block = True` * otherwise a future with this `bool` - ''' + """ await self._run() async def _run(self): @@ -1082,11 +1174,11 @@ async def _run(self): else: await self.client.req(c[0]) - def decode(self, rep : bytes, t : float | None=None, skip : int=0): + def decode(self, rep: bytes, t: float | None = None, skip: int = 0): cb = [x[1] for x in self._cmds.values()] values = rep.split(self._repsep) if len(cb) != len(values) + skip: - raise lexc.InvalidResponse('Unexpected number of responses') + raise lexc.InvalidResponse("Unexpected number of responses") for i in range(0, len(values)): f = cb[i + skip] @@ -1105,19 +1197,20 @@ def __iter__(self): return iter(self._cmds) - class Tracing(Macro): """Tracing command handling""" - def __init__(self, client : ZmqClient, stream : str='t', poll_interval_s : float=0, *args, **kwargs): - super().__init__(client=client, reqsep=b'\r', repsep=b';', *args, **kwargs) + def __init__( + self, client: ZmqClient, stream: str = "t", poll_interval_s: float = 0, *args, **kwargs + ): + super().__init__(client=client, reqsep=b"\r", repsep=b";", *args, **kwargs) - self._poll_interval_s : float = poll_interval_s - self._stream : Stream | str = stream - self._enabled : bool | None = None - self._decimate : int = 1 - self._partial : bytearray = bytearray() - self._task : asyncio.Task | None = None + self._poll_interval_s: float = poll_interval_s + self._stream: Stream | str = stream + self._enabled: bool | None = None + self._decimate: int = 1 + self._partial: bytearray = bytearray() + self._task: asyncio.Task | None = None async def _init(self): if self._enabled is not None: @@ -1126,18 +1219,18 @@ async def _init(self): self._enabled = False if not self.client.is_connected(): - raise lexc.InvalidState('Client not connected') + raise lexc.InvalidState("Client not connected") try: cap = await self.client.capabilities() - if 't' not in cap: - raise lexc.NotSupported('Tracing capability missing') - if 'm' not in cap: - raise lexc.NotSupported('Macro capability missing') - if 'e' not in cap: - raise lexc.NotSupported('Echo capability missing') - if 's' not in cap: - raise lexc.NotSupported('Stream capability missing') + if "t" not in cap: + raise lexc.NotSupported("Tracing capability missing") + if "m" not in cap: + raise lexc.NotSupported("Macro capability missing") + if "e" not in cap: + raise lexc.NotSupported("Echo capability missing") + if "s" not in cap: + raise lexc.NotSupported("Stream capability missing") if isinstance(self._stream, str): self._stream = self.client.stream(self._stream, raw=True) @@ -1146,22 +1239,22 @@ async def _init(self): # Start with sample separator. try: - await self.add('e\n', None, 'e') + await self.add("e\n", None, "e") except lexc.OperationFailed: - raise lexc.NotSupported('Cannot add echo command for tracing') + raise lexc.NotSupported("Cannot add echo command for tracing") # We must have a macro, not a simulated Macro instance. if self.macro is None: - raise lexc.NotSupported('Cannot get macro for tracing') + raise lexc.NotSupported("Cannot get macro for tracing") t = self.client.time() if t is None: - raise lexc.NotSupported('Cannot determine time stamp variable') + raise lexc.NotSupported("Cannot determine time stamp variable") try: - await self.add(f'r{await t.short_name()}', None, 't') + await self.add(f"r{await t.short_name()}", None, "t") except lexc.OperationFailed: - raise lexc.NotSupported('Cannot add time stamp command for tracing') + raise lexc.NotSupported("Cannot add time stamp command for tracing") await self._update_tracing(True) except: @@ -1172,7 +1265,7 @@ def __del__(self): try: self._enabled = False if self.client.is_connected(): - self.client.req(b't', sync=True, block=False) + self.client.req(b"t", sync=True, block=False) except: pass @@ -1194,7 +1287,7 @@ async def _update_tracing(self, force=False): if (force or self._enabled) and not enable: self._enabled = False - await self.client.req(b't') + await self.client.req(b"t") if self._task is not None: self._task.cancel() self._task = None @@ -1203,9 +1296,11 @@ async def _update_tracing(self, force=False): assert macro is not None assert isinstance(self._stream, Stream) - rep = await self.client.req(b't' + macro + self._stream.name.encode() + ('%x' % self.decimate).encode()) - if rep != b'!': - raise lexc.NotSupported('Cannot configure tracing') + rep = await self.client.req( + b"t" + macro + self._stream.name.encode() + ("%x" % self.decimate).encode() + ) + if rep != b"!": + raise lexc.NotSupported("Cannot configure tracing") await self._stream.reset() self._partial = bytearray() @@ -1216,7 +1311,7 @@ async def _update_tracing(self, force=False): self._task.cancel() self._task = None - self._task = self.client.periodic(self._poll_interval_s, self._process, name='tracing') + self._task = self.client.periodic(self._poll_interval_s, self._process, name="tracing") async def _clear(self): await super()._clear() @@ -1235,9 +1330,9 @@ def decimate(self) -> int: async def set_decimate(self, decimate: int): if decimate < 1: decimate = 1 - elif decimate > 0x7fffffff: + elif decimate > 0x7FFFFFFF: # Limit it somewhat to stay within 32 bit - decimate = 0x7fffffff + decimate = 0x7FFFFFFF self._decimate = decimate await self._update_tracing(True) @@ -1249,20 +1344,22 @@ def stream(self) -> Stream | None: @overload async def process(self) -> None: ... @overload - def process(self, *, block : typing.Literal[False]) -> asyncio.Future[None]: ... + def process(self, *, block: typing.Literal[False]) -> asyncio.Future[None]: ... @overload - def process(self, *, sync : typing.Literal[True]) -> None: ... + def process(self, *, sync: typing.Literal[True]) -> None: ... @overload - def process(self, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[None]: ... + def process( + self, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[None]: ... @Macro.run_sync @Macro.locked async def process(self): - '''Process new samples from the stream. + """Process new samples from the stream. This function is called automatically when polling is enabled. It can also be called manually to process samples immediately. - ''' + """ await self._process() @Macro.locked @@ -1274,15 +1371,15 @@ async def _process(self): assert not isinstance(x, str) self._process_data(x) - def _process_data(self, s : bytes | bytearray): - samples = (self._partial + s).split(b'\n;') + def _process_data(self, s: bytes | bytearray): + samples = (self._partial + s).split(b"\n;") self._partial = samples[-1] time = self.client.time() assert time is not None for sample in samples[0:-1]: # The first value is the time stamp. - t_data = sample.split(b';', 1) + t_data = sample.split(b";", 1) if len(t_data) < 2: # Empty sample. continue @@ -1298,19 +1395,26 @@ def __len__(self): return max(0, super().__len__() - 2) - class ZmqClient(Work): - ''' + """ Asynchronous ZMQ client. This client can connect to both the libstored.zmq_server.ZmqServer and stored::DebugZmqLayer. - ''' + """ - def __init__(self, host : str='localhost', port : int=lprot.default_port, - multi : bool=False, timeout : float | None=None, context : None | zmq.asyncio.Context=None, - t : str | None = None, use_state : str | None=None, - stack : str | lprot.ProtocolLayer | None=None, - *args, **kwargs): + def __init__( + self, + host: str = "localhost", + port: int = lprot.default_port, + multi: bool = False, + timeout: float | None = None, + context: None | zmq.asyncio.Context = None, + t: str | None = None, + use_state: str | None = None, + stack: str | lprot.ProtocolLayer | None = None, + *args, + **kwargs, + ): super().__init__(*args, **kwargs) self._context = context or zmq.asyncio.Context.instance() @@ -1320,8 +1424,8 @@ def __init__(self, host : str='localhost', port : int=lprot.default_port, self._timeout = timeout if timeout is None or timeout > 0 else None self._socket = None self._alias_lock = lexc.DeadlockChecker(asyncio.Lock()) - self._t : str | Object | None | bool = t - self._t0 : float = 0 + self._t: str | Object | None | bool = t + self._t0: float = 0 self._timestamp_to_time = lambda t: float(t) self._use_state = use_state @@ -1331,134 +1435,163 @@ def __init__(self, host : str='localhost', port : int=lprot.default_port, self._stack = stack else: self._stack = lprot.ProtocolLayer() - self._stack_encoded : bytearray | None = None - self._stack_decoded : bytearray | None = None + self._stack_encoded: bytearray | None = None + self._stack_decoded: bytearray | None = None self._stack.up = self._stack_up self._stack.down = self._stack_down - self._reset() # Events - self.connecting = Event('connecting') - self.connected = Event('connected') - self.disconnecting = Event('disconnecting') - self.disconnected = Event('disconnected') - - + self.connecting = Event("connecting") + self.connected = Event("connected") + self.disconnecting = Event("disconnecting") + self.disconnected = Event("disconnected") ############################################## # ZMQ connection handling @property def host(self) -> str: - '''Configured or currently connected host.''' + """Configured or currently connected host.""" return self._host @property def port(self) -> int: - '''Configured or currently connected port.''' + """Configured or currently connected port.""" return self._port @property def multi(self) -> bool: - ''' + """ Return whether the client uses a subset of the commands that are safe when multiple connections to the same ZMQ server are made. - ''' + """ return self._multi @property def context(self) -> zmq.asyncio.Context: - '''The ZMQ context used by this client.''' + """The ZMQ context used by this client.""" return self._context @property def socket(self) -> zmq.asyncio.Socket | None: - '''The ZMQ socket used by this client, or None if not connected.''' + """The ZMQ socket used by this client, or None if not connected.""" return self._socket def is_connected(self) -> bool: - '''Check if connected to the ZMQ server.''' + """Check if connected to the ZMQ server.""" return self.socket is not None def _reset(self): - self._capabilities : str | None = None - self._identification : str | None = None - self._version : str | None = None + self._capabilities: str | None = None + self._identification: str | None = None + self._version: str | None = None - self._available_aliases : list[str] | None = None - self._temporary_aliases : dict[str, Object] = {} - self._permanent_aliases : dict[str, tuple[Object, list[typing.Any]]] = {} + self._available_aliases: list[str] | None = None + self._temporary_aliases: dict[str, Object] = {} + self._permanent_aliases: dict[str, tuple[Object, list[typing.Any]]] = {} - self._available_macros : list[str] | None = None - self._used_macros : list[str] = [] - if hasattr(self, '_macros'): + self._available_macros: list[str] | None = None + self._used_macros: list[str] = [] + if hasattr(self, "_macros"): if not self._macros is None: for m in self._macros: m.destroy() - self._macros : list[Macro] = [] + self._macros: list[Macro] = [] self._t = None - if hasattr(self, '_objects'): + if hasattr(self, "_objects"): if not self._objects is None: for o in self._objects: o.destroy() - self._objects : typing.List[Object] | None = None + self._objects: typing.List[Object] | None = None - if hasattr(self, '_objects_attr'): + if hasattr(self, "_objects_attr"): for o in self._objects_attr: if hasattr(self, o): delattr(self, o) - self._objects_attr : set[str] = set() + self._objects_attr: set[str] = set() - if hasattr(self, '_streams'): + if hasattr(self, "_streams"): for s, o in self._streams.items(): o.destroy() - self._streams : typing.Dict[str, Stream] = {} + self._streams: typing.Dict[str, Stream] = {} - if hasattr(self, '_periodic_tasks'): + if hasattr(self, "_periodic_tasks"): for t in self._periodic_tasks: t.cancel() - self._periodic_tasks : set[asyncio.Task] = set() + self._periodic_tasks: set[asyncio.Task] = set() - if hasattr(self, '_monitor'): + if hasattr(self, "_monitor"): if self._monitor is not None: self._monitor.cancel() - self._monitor : asyncio.Task | None = None + self._monitor: asyncio.Task | None = None - if hasattr(self, '_req_task'): + if hasattr(self, "_req_task"): if self._req_task is not None: self._req_task.cancel() - self._req_task : asyncio.Task | None = None + self._req_task: asyncio.Task | None = None - if hasattr(self, '_fast_poll_task'): + if hasattr(self, "_fast_poll_task"): if self._fast_poll_task is not None: self._fast_poll_task.cancel() - self._fast_poll_task : asyncio.Task | None = None - self._fast_poll_macro : Macro | None = None - self._fast_poll_interval_s : float = self.fast_poll_threshold_s - - self._tracing : Tracing | bool | None = None - - @overload - async def connect(self, host : str | None=None, port : int | None=None, \ - multi : bool | None=None, default_state : bool=False) -> None: ... - @overload - def connect(self, host : str | None=None, port : int | None=None, \ - multi : bool | None=None, default_state : bool=False, *, block : typing.Literal[False]) -> asyncio.Future[None]: ... - @overload - def connect(self, host : str | None=None, port : int | None=None, \ - multi : bool | None=None, default_state : bool=False, *, sync : typing.Literal[True]) -> None: ... - @overload - def connect(self, host : str | None=None, port : int | None=None, \ - multi : bool | None=None, default_state : bool=False, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[None]: ... + self._fast_poll_task: asyncio.Task | None = None + self._fast_poll_macro: Macro | None = None + self._fast_poll_interval_s: float = self.fast_poll_threshold_s + + self._tracing: Tracing | bool | None = None + + @overload + async def connect( + self, + host: str | None = None, + port: int | None = None, + multi: bool | None = None, + default_state: bool = False, + ) -> None: ... + @overload + def connect( + self, + host: str | None = None, + port: int | None = None, + multi: bool | None = None, + default_state: bool = False, + *, + block: typing.Literal[False], + ) -> asyncio.Future[None]: ... + @overload + def connect( + self, + host: str | None = None, + port: int | None = None, + multi: bool | None = None, + default_state: bool = False, + *, + sync: typing.Literal[True], + ) -> None: ... + @overload + def connect( + self, + host: str | None = None, + port: int | None = None, + multi: bool | None = None, + default_state: bool = False, + *, + block: typing.Literal[False], + sync: typing.Literal[True], + ) -> concurrent.futures.Future[None]: ... @Work.run_sync - async def connect(self, host : str | None=None, port : int | None=None, \ - multi : bool | None=None, default_state : bool=False): - ''' + async def connect( + self, + host: str | None = None, + port: int | None = None, + multi: bool | None = None, + default_state: bool = False, + ): + """ Connect to the ZMQ server. **Arguments** @@ -1471,20 +1604,20 @@ async def connect(self, host : str | None=None, port : int | None=None, \ **Result** * `None`: when `block = True` * otherwise a future - ''' + """ await self._connect(host, port, multi) - if 'l' in await self.capabilities(): + if "l" in await self.capabilities(): await self.list() await self.find_time() - if 'm' in await self.capabilities(): + if "m" in await self.capabilities(): # Clear all existing macros. - macros = await self.req('m') - if macros != '?': + macros = await self.req("m") + if macros != "?": for m in macros: - await self.req(f'm{m}') + await self.req(f"m{m}") self.connected.trigger() await self._stack.connected() @@ -1493,9 +1626,11 @@ async def connect(self, host : str | None=None, port : int | None=None, \ await self.restore_state() @Work.locked - async def _connect(self, host : str | None=None, port : int | None=None, multi : bool | None=None): + async def _connect( + self, host: str | None = None, port: int | None = None, multi: bool | None = None + ): if self.is_connected(): - raise lexc.InvalidState('Already connected') + raise lexc.InvalidState("Already connected") if host is not None: self._host = host @@ -1511,14 +1646,16 @@ async def _connect(self, host : str | None=None, port : int | None=None, multi : try: if self._timeout is not None: - self.logger.debug(f'using a timeout of {self._timeout} s') + self.logger.debug(f"using a timeout of {self._timeout} s") self._socket.setsockopt(zmq.CONNECT_TIMEOUT, int(self._timeout * 1000)) self._socket.setsockopt(zmq.RCVTIMEO, int(self._timeout * 1000)) self._socket.setsockopt(zmq.SNDTIMEO, int(self._timeout * 1000)) - self.logger.debug(f'connect to tcp://{self._host}:{self._port}') - self._socket.connect(f'tcp://{self._host}:{self._port}') - self._monitor = asyncio.create_task(self._monitor_socket(), name=f'{self.__class__.__name__} monitor') + self.logger.debug(f"connect to tcp://{self._host}:{self._port}") + self._socket.connect(f"tcp://{self._host}:{self._port}") + self._monitor = asyncio.create_task( + self._monitor_socket(), name=f"{self.__class__.__name__} monitor" + ) except: s = self._socket self._socket = None @@ -1529,15 +1666,17 @@ async def _connect(self, host : str | None=None, port : int | None=None, multi : @overload async def disconnect(self) -> None: ... @overload - def disconnect(self, *, block : typing.Literal[False]) -> asyncio.Future[None]: ... + def disconnect(self, *, block: typing.Literal[False]) -> asyncio.Future[None]: ... @overload - def disconnect(self, *, sync : typing.Literal[True]) -> None: ... + def disconnect(self, *, sync: typing.Literal[True]) -> None: ... @overload - def disconnect(self, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[None]: ... + def disconnect( + self, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[None]: ... @Work.run_sync async def disconnect(self): - ''' + """ Disconnect from the ZMQ server. **Arguments** @@ -1546,7 +1685,7 @@ async def disconnect(self): **Result** * `None`: when `block = True` * otherwise a future - ''' + """ s = self._socket @@ -1554,7 +1693,7 @@ async def disconnect(self): # Not connected return - self.logger.debug('disconnect') + self.logger.debug("disconnect") self.disconnecting.trigger() await self._stack.disconnected() @@ -1575,15 +1714,17 @@ async def disconnect(self): @overload async def close(self) -> None: ... @overload - def close(self, *, block : typing.Literal[False]) -> asyncio.Future[None]: ... + def close(self, *, block: typing.Literal[False]) -> asyncio.Future[None]: ... @overload - def close(self, *, sync : typing.Literal[True]) -> None: ... + def close(self, *, sync: typing.Literal[True]) -> None: ... @overload - def close(self, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[None]: ... + def close( + self, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[None]: ... @Work.run_sync async def close(self): - '''Disconnect and release resources.''' + """Disconnect and release resources.""" await self.disconnect() await self._stack.close() @@ -1610,32 +1751,34 @@ async def __aenter__(self): async def __aexit__(self, *args): await self.close() - - ############################################## # Low-level req @overload - async def req(self, msg : bytes) -> bytes: ... + async def req(self, msg: bytes) -> bytes: ... @overload - async def req(self, msg : str) -> str: ... + async def req(self, msg: str) -> str: ... @overload - def req(self, msg : bytes, *, block : typing.Literal[False]) -> asyncio.Future[bytes]: ... + def req(self, msg: bytes, *, block: typing.Literal[False]) -> asyncio.Future[bytes]: ... @overload - def req(self, msg : str, *, block : typing.Literal[False]) -> asyncio.Future[str]: ... + def req(self, msg: str, *, block: typing.Literal[False]) -> asyncio.Future[str]: ... @overload - def req(self, msg : bytes, *, sync : typing.Literal[True]) -> bytes: ... + def req(self, msg: bytes, *, sync: typing.Literal[True]) -> bytes: ... @overload - def req(self, msg : str, *, sync : typing.Literal[True]) -> str: ... + def req(self, msg: str, *, sync: typing.Literal[True]) -> str: ... @overload - def req(self, msg : bytes, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[bytes]: ... + def req( + self, msg: bytes, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[bytes]: ... @overload - def req(self, msg : str, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[str]: ... + def req( + self, msg: str, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[str]: ... @Work.run_sync @Work.locked - async def req(self, msg : bytes | str) -> bytes | str: - ''' + async def req(self, msg: bytes | str) -> bytes | str: + """ Send a request to the ZMQ server and wait for a reply. **Arguments** @@ -1651,17 +1794,21 @@ async def req(self, msg : bytes | str) -> bytes | str: * `InvalidState`: when not connected * `Disconnected`: when the connection was lost during the request * `OperationFailed`: when the request failed or interrupted - ''' + """ if len(msg) == 0: - raise ValueError('Empty request') + raise ValueError("Empty request") try: if isinstance(msg, str): - self._req_task = asyncio.create_task(self._req(msg.encode()), name=f'{self.__class__.__name__} req') + self._req_task = asyncio.create_task( + self._req(msg.encode()), name=f"{self.__class__.__name__} req" + ) return (await self._req_task).decode() else: - self._req_task = asyncio.create_task(self._req(msg), name=f'{self.__class__.__name__} req') + self._req_task = asyncio.create_task( + self._req(msg), name=f"{self.__class__.__name__} req" + ) return await self._req_task except asyncio.CancelledError: if self._req_task is not None: @@ -1675,9 +1822,9 @@ async def req(self, msg : bytes | str) -> bytes | str: # The exception is not due to us being cancelled. Someone # just aborted the req. Raise another exception instead. if not self.is_connected(): - raise lexc.Disconnected('Request aborted') + raise lexc.Disconnected("Request aborted") else: - raise lexc.OperationFailed('Request aborted') + raise lexc.OperationFailed("Request aborted") raise finally: @@ -1687,79 +1834,79 @@ def _stack_clear(self) -> None: self._stack_encoded = None self._stack_decoded = None - def _stack_up(self, data : lprot.ProtocolLayer.Packet) -> None: + def _stack_up(self, data: lprot.ProtocolLayer.Packet) -> None: if isinstance(data, str): data = data.encode() elif isinstance(data, memoryview): - data = data.cast('B') + data = data.cast("B") if self._stack_decoded is None: self._stack_decoded = bytearray(data) else: self._stack_decoded.extend(data) - def _stack_down(self, data : lprot.ProtocolLayer.Packet) -> None: + def _stack_down(self, data: lprot.ProtocolLayer.Packet) -> None: if isinstance(data, str): data = data.encode() elif isinstance(data, memoryview): - data = data.cast('B') + data = data.cast("B") if self._stack_encoded is None: self._stack_encoded = bytearray(data) else: self._stack_encoded.extend(data) - async def _req(self, msg : bytes) -> bytes: + async def _req(self, msg: bytes) -> bytes: if not self.is_connected(): - raise lexc.InvalidState('Not connected') + raise lexc.InvalidState("Not connected") assert self._socket is not None self._stack_clear() await self._stack.encode(msg) if self._stack_encoded is None: - raise lexc.OperationFailed('Stack did not produce data') + raise lexc.OperationFailed("Stack did not produce data") if self.logger.getEffectiveLevel() <= logging.DEBUG: if self._stack_encoded != msg: - self.logger.debug('req %s -> %s', msg, bytes(self._stack_encoded)) + self.logger.debug("req %s -> %s", msg, bytes(self._stack_encoded)) else: - self.logger.debug('req %s', msg) + self.logger.debug("req %s", msg) await self._socket.send(self._stack_encoded) - rep = b''.join(await self._socket.recv_multipart()) + rep = b"".join(await self._socket.recv_multipart()) await self._stack.decode(rep) if rep and self._stack_decoded is None: - raise lexc.InvalidResponse('Stack did not decode data') + raise lexc.InvalidResponse("Stack did not decode data") - decoded = bytes(self._stack_decoded) if self._stack_decoded is not None else b'' + decoded = bytes(self._stack_decoded) if self._stack_decoded is not None else b"" if self.logger.getEffectiveLevel() <= logging.DEBUG: if self._stack_decoded != rep: - self.logger.debug('rep %s <- %s', decoded, rep) + self.logger.debug("rep %s <- %s", decoded, rep) else: - self.logger.debug('rep %s', decoded) + self.logger.debug("rep %s", decoded) return decoded - - ############################################## # Simple commands @overload async def capabilities(self) -> str: ... @overload - def capabilities(self, *, block : typing.Literal[False]) -> asyncio.Future[str]: ... + def capabilities(self, *, block: typing.Literal[False]) -> asyncio.Future[str]: ... @overload - def capabilities(self, *, sync : typing.Literal[True]) -> str: ... + def capabilities(self, *, sync: typing.Literal[True]) -> str: ... @overload - def capabilities(self, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[str]: ... + def capabilities( + self, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[str]: ... @Work.run_sync async def capabilities(self) -> str: - ''' + """ Get the capabilities of the connected ZMQ server. **Arguments** @@ -1768,29 +1915,31 @@ async def capabilities(self) -> str: **Result** * `str` with the capabilities when `block = True` * otherwise a future - ''' + """ if self._capabilities is None: - self._capabilities = await self.req('?') + self._capabilities = await self.req("?") assert self._capabilities is not None if self._multi: # Remove capabilities that are stateful at the embedded side. - self._capabilities = re.sub(r'[amstf]', '', self._capabilities) + self._capabilities = re.sub(r"[amstf]", "", self._capabilities) return self._capabilities @overload - async def echo(self, msg : str) -> str: ... + async def echo(self, msg: str) -> str: ... @overload - def echo(self, msg : str, *, block : typing.Literal[False]) -> asyncio.Future[str]: ... + def echo(self, msg: str, *, block: typing.Literal[False]) -> asyncio.Future[str]: ... @overload - def echo(self, msg : str, *, sync : typing.Literal[True]) -> str: ... + def echo(self, msg: str, *, sync: typing.Literal[True]) -> str: ... @overload - def echo(self, msg : str, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[str]: ... + def echo( + self, msg: str, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[str]: ... @Work.run_sync - async def echo(self, msg : str) -> str: - ''' + async def echo(self, msg: str) -> str: + """ Echo a message via the ZMQ server. **Arguments** @@ -1803,24 +1952,26 @@ async def echo(self, msg : str) -> str: **Raises** * `NotSupported`: when the echo command is not supported by the server - ''' - if 'e' not in await self.capabilities(): - raise lexc.NotSupported('Echo command not supported') + """ + if "e" not in await self.capabilities(): + raise lexc.NotSupported("Echo command not supported") - return (await self.req(b'e' + msg.encode())).decode() + return (await self.req(b"e" + msg.encode())).decode() @overload async def identification(self) -> str: ... @overload - def identification(self, *, block : typing.Literal[False]) -> asyncio.Future[str]: ... + def identification(self, *, block: typing.Literal[False]) -> asyncio.Future[str]: ... @overload - def identification(self, *, sync : typing.Literal[True]) -> str: ... + def identification(self, *, sync: typing.Literal[True]) -> str: ... @overload - def identification(self, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[str]: ... + def identification( + self, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[str]: ... @Work.run_sync async def identification(self) -> str: - ''' + """ Get the identification string. **Arguments** @@ -1829,19 +1980,19 @@ async def identification(self) -> str: **Result** * `str`: the identification string, which is empty when not supported, when `block = True` * otherwise a future with this `str` - ''' + """ if self._identification is not None: return self._identification - if not 'i' in await self.capabilities(): - self._identification = '' + if not "i" in await self.capabilities(): + self._identification = "" return self._identification try: - self._identification = (await self.req(b'i')).decode() + self._identification = (await self.req(b"i")).decode() except ValueError: - self._identification = '' + self._identification = "" assert self._identification is not None return self._identification @@ -1849,15 +2000,17 @@ async def identification(self) -> str: @overload async def version(self) -> str: ... @overload - def version(self, *, block : typing.Literal[False]) -> asyncio.Future[str]: ... + def version(self, *, block: typing.Literal[False]) -> asyncio.Future[str]: ... @overload - def version(self, *, sync : typing.Literal[True]) -> str: ... + def version(self, *, sync: typing.Literal[True]) -> str: ... @overload - def version(self, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[str]: ... + def version( + self, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[str]: ... @Work.run_sync async def version(self) -> str: - ''' + """ Get the version string. **Arguments** @@ -1866,30 +2019,34 @@ async def version(self) -> str: **Result** * `str`: the version string, which is empty when not supported, when `block = True` * otherwise a future with this `str` - ''' + """ if self._version is not None: return self._version try: - self._version = (await self.req(b'v')).decode() + self._version = (await self.req(b"v")).decode() except ValueError: - self._version = '' + self._version = "" assert self._version is not None return self._version @overload - async def read_mem(self, pointer : int, size : int) -> bytearray: ... + async def read_mem(self, pointer: int, size: int) -> bytearray: ... @overload - def read_mem(self, pointer : int, size : int, *, block : typing.Literal[False]) -> asyncio.Future[bytearray]: ... + def read_mem( + self, pointer: int, size: int, *, block: typing.Literal[False] + ) -> asyncio.Future[bytearray]: ... @overload - def read_mem(self, pointer : int, size : int, *, sync : typing.Literal[True]) -> bytearray: ... + def read_mem(self, pointer: int, size: int, *, sync: typing.Literal[True]) -> bytearray: ... @overload - def read_mem(self, pointer : int, size : int, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[bytearray]: ... + def read_mem( + self, pointer: int, size: int, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[bytearray]: ... @Work.run_sync - async def read_mem(self, pointer : int, size : int) -> bytearray: - ''' + async def read_mem(self, pointer: int, size: int) -> bytearray: + """ Read memory from the connected device. **Arguments** @@ -1904,37 +2061,46 @@ async def read_mem(self, pointer : int, size : int) -> bytearray: **Raises** * `NotSupported`: when the ReadMem command is not supported by the server * `OperationFailed`: when the ReadMem command failed - ''' + """ - if 'R' not in await self.capabilities(): - raise lexc.NotSupported('ReadMem command not supported') + if "R" not in await self.capabilities(): + raise lexc.NotSupported("ReadMem command not supported") - rep = await self.req(f'R{pointer:x} {size}') + rep = await self.req(f"R{pointer:x} {size}") - if rep == '?': - raise lexc.OperationFailed('ReadMem command failed') + if rep == "?": + raise lexc.OperationFailed("ReadMem command failed") if len(rep) & 1: # Odd number of bytes. - raise lexc.OperationFailed('Invalid ReadMem response') + raise lexc.OperationFailed("Invalid ReadMem response") res = bytearray() for i in range(0, len(rep), 2): - res.append(int(rep[i:i+2], 16)) + res.append(int(rep[i : i + 2], 16)) return res @overload - async def write_mem(self, pointer : int, data : bytearray) -> None: ... + async def write_mem(self, pointer: int, data: bytearray) -> None: ... @overload - def write_mem(self, pointer : int, data : bytearray, *, block : typing.Literal[False]) -> asyncio.Future[None]: ... + def write_mem( + self, pointer: int, data: bytearray, *, block: typing.Literal[False] + ) -> asyncio.Future[None]: ... @overload - def write_mem(self, pointer : int, data : bytearray, *, sync : typing.Literal[True]) -> None: ... + def write_mem(self, pointer: int, data: bytearray, *, sync: typing.Literal[True]) -> None: ... @overload - def write_mem(self, pointer : int, data : bytearray, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[None]: ... + def write_mem( + self, + pointer: int, + data: bytearray, + *, + block: typing.Literal[False], + sync: typing.Literal[True], + ) -> concurrent.futures.Future[None]: ... @Work.run_sync - async def write_mem(self, pointer : int, data : bytearray): - ''' + async def write_mem(self, pointer: int, data: bytearray): + """ Write memory to the connected device. **Arguments** @@ -1949,18 +2115,16 @@ async def write_mem(self, pointer : int, data : bytearray): **Raises** * `NotSupported`: when the WriteMem command is not supported by the server * `OperationFailed`: when the WriteMem command failed - ''' - if 'W' not in await self.capabilities(): - raise lexc.NotSupported('WriteMem command not supported') + """ + if "W" not in await self.capabilities(): + raise lexc.NotSupported("WriteMem command not supported") - req = f'W{pointer:x} ' + req = f"W{pointer:x} " for i in range(0, len(data)): - req += f'{data[i]:02x}' + req += f"{data[i]:02x}" rep = await self.req(req) - if rep != '!': - raise lexc.OperationFailed('WriteMem command failed') - - + if rep != "!": + raise lexc.OperationFailed("WriteMem command failed") ############################################## # Objects @@ -1975,15 +2139,17 @@ def objects(self) -> typing.List[Object]: @overload async def list(self) -> typing.List[Object]: ... @overload - def list(self, *, block : typing.Literal[False]) -> asyncio.Future[typing.List[Object]]: ... + def list(self, *, block: typing.Literal[False]) -> asyncio.Future[typing.List[Object]]: ... @overload - def list(self, *, sync : typing.Literal[True]) -> typing.List[Object]: ... + def list(self, *, sync: typing.Literal[True]) -> typing.List[Object]: ... @overload - def list(self, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[typing.List[Object]]: ... + def list( + self, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[typing.List[Object]]: ... @Work.run_sync async def list(self) -> typing.List[Object]: - ''' + """ List the objects available. **Arguments** @@ -1996,17 +2162,17 @@ async def list(self) -> typing.List[Object]: **Raises** * `NotSupported`: when the List command is not supported by the server * `InvalidResponse`: when the List command returned an invalid response - ''' + """ if not self._objects is None: return self.objects - if 'l' not in await self.capabilities(): - raise lexc.NotSupported('List command not supported') + if "l" not in await self.capabilities(): + raise lexc.NotSupported("List command not supported") res = [] - for o in (await self.req('l')).split('\n'): - if o == '': + for o in (await self.req("l")).split("\n"): + if o == "": continue obj = Object.create(o, self) if obj is None: @@ -2020,46 +2186,46 @@ async def list(self) -> typing.List[Object]: self._objects = res return self.objects - def _pyname(self, name : str) -> str: - '''Convert an object name to a valid Python attribute name.''' + def _pyname(self, name: str) -> str: + """Convert an object name to a valid Python attribute name.""" - n = re.sub(r'[^A-Za-z0-9/]+', '_', name) - n = re.sub(r'_*/+', '__', n) - n = re.sub(r'^__', '', n) - n = re.sub(r'^[^A-Za-z]_*', '_', n) - n = re.sub(r'_+$', '', n) + n = re.sub(r"[^A-Za-z0-9/]+", "_", name) + n = re.sub(r"_*/+", "__", n) + n = re.sub(r"^__", "", n) + n = re.sub(r"^[^A-Za-z]_*", "_", n) + n = re.sub(r"_+$", "", n) - if n == '': - n = 'obj' + if n == "": + n = "obj" if keyword.iskeyword(n): - n += '_obj' + n += "_obj" if hasattr(self, n): i = 1 - while hasattr(self, f'{n}_{i}'): + while hasattr(self, f"{n}_{i}"): i += 1 - n = f'{n}_{i}' + n = f"{n}_{i}" return n - def find(self, name : str, all=False) -> Object | typing.Set[Object] | None: - ''' + def find(self, name: str, all=False) -> Object | typing.Set[Object] | None: + """ Find object(s) by name. This functions uses the previously retrieved list of objects. - ''' + """ if self._objects is None: return None - chunks = name.split('/') + chunks = name.split("/") obj1 = set() obj2 = set() obj3 = set() obj4 = set() for o in self._objects: - ochunks = o.name.split('/') + ochunks = o.name.split("/") if len(chunks) != len(ochunks): continue @@ -2082,7 +2248,10 @@ def find(self, name : str, all=False) -> Object | typing.Set[Object] | None: # Case 2. match = True for i in range(0, len(ochunks)): - if re.fullmatch(re.sub(r'\\\?', '.', re.escape(ochunks[i])) + r'.*', chunks[i]) is None: + if ( + re.fullmatch(re.sub(r"\\\?", ".", re.escape(ochunks[i])) + r".*", chunks[i]) + is None + ): match = False break # It seems to match. Additional check: the object's chunk should not be longer, as it makes name ambiguous. @@ -2107,21 +2276,24 @@ def find(self, name : str, all=False) -> Object | typing.Set[Object] | None: else: exact = False if match: - obj3 = {(x,e) for x,e in obj3 if e >= exactLen} - best = max(obj3, key=lambda x: x[1], default=(None,0))[1] + obj3 = {(x, e) for x, e in obj3 if e >= exactLen} + best = max(obj3, key=lambda x: x[1], default=(None, 0))[1] if exactLen >= best: - obj3.add((o,exactLen)) + obj3.add((o, exactLen)) # Case 4. match = True for i in range(0, len(ochunks)): - if re.fullmatch(re.sub(r'\\\?', '.', re.escape(ochunks[i])) + r'.*', chunks[i]) is None: + if ( + re.fullmatch(re.sub(r"\\\?", ".", re.escape(ochunks[i])) + r".*", chunks[i]) + is None + ): match = False break if match: obj4.add(o) - obj = obj1 | obj2 | {x for x,e in obj3} | obj4 + obj = obj1 | obj2 | {x for x, e in obj3} | obj4 if all: return obj if len(obj1) == 1: @@ -2134,8 +2306,8 @@ def find(self, name : str, all=False) -> Object | typing.Set[Object] | None: else: return obj - def obj(self, x : str) -> Object: - '''Get an object by name.''' + def obj(self, x: str) -> Object: + """Get an object by name.""" try: return getattr(self, x) @@ -2153,8 +2325,6 @@ def obj(self, x : str) -> Object: def __getitem__(self, x): return self.obj(x) - - ############################################## # Time @@ -2169,15 +2339,17 @@ def time(self) -> Object | None: @overload async def find_time(self) -> Object | None: ... @overload - def find_time(self, *, block : typing.Literal[False]) -> asyncio.Future[Object | None]: ... + def find_time(self, *, block: typing.Literal[False]) -> asyncio.Future[Object | None]: ... @overload - def find_time(self, *, sync : typing.Literal[True]) -> Object | None: ... + def find_time(self, *, sync: typing.Literal[True]) -> Object | None: ... @overload - def find_time(self, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[Object | None]: ... + def find_time( + self, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[Object | None]: ... @Work.run_sync async def find_time(self) -> Object | None: - ''' + """ Find the time object. It should start with `/t`, and have a unit between parentheses, like `/t (s)`. @@ -2187,13 +2359,13 @@ async def find_time(self) -> Object | None: **Result** * `Object | None`: the time object when found, or None when not found, when `block = True` * otherwise a future with this `Object | None` - ''' + """ if isinstance(self._t, str): # Take the given time variable t = self.find(self._t) else: # Try finding /t (unit) - t = self.find('/t (') + t = self.find("/t (") # Not initialized. self._t = False @@ -2201,11 +2373,11 @@ async def find_time(self) -> Object | None: if t is None: # Not found, try the first /store/t (unit) for o in self.objects: - chunks = o.name.split('/', 4) + chunks = o.name.split("/", 4) if len(chunks) != 3: # Strange name continue - elif chunks[2].startswith('t ('): + elif chunks[2].startswith("t ("): # Got some t = o else: @@ -2232,14 +2404,14 @@ async def find_time(self) -> Object | None: return None # Try parse the unit - unit = re.sub(r'.*/t \((.*)\)$', r'\1', t.name) - if unit == 's': + unit = re.sub(r".*/t \((.*)\)$", r"\1", t.name) + if unit == "s": self._timestamp_to_time = lambda t: float(t - t0) + self._t0 - elif unit == 'ms': + elif unit == "ms": self._timestamp_to_time = lambda t: float(t - t0) / 1e3 + self._t0 - elif unit == 'us': + elif unit == "us": self._timestamp_to_time = lambda t: float(t - t0) / 1e6 + self._t0 - elif unit == 'ns': + elif unit == "ns": self._timestamp_to_time = lambda t: float(t - t0) / 1e9 + self._t0 else: # Don't know a conversion, just use the raw value. @@ -2250,10 +2422,10 @@ async def find_time(self) -> Object | None: # All set. self._t = t - self.logger.info('time object: %s', t.name) + self.logger.info("time object: %s", t.name) return self._t - def timestamp_to_time(self, t : float | None=None) -> float: + def timestamp_to_time(self, t: float | None = None) -> float: if not isinstance(self._t, Object): # No time object found. return time.time() @@ -2261,23 +2433,23 @@ def timestamp_to_time(self, t : float | None=None) -> float: # Override to implement arbitrary conversion. return self._timestamp_to_time(t if t is not None else self._t.value) - - ############################################## # Streams @overload async def streams(self) -> typing.List[str]: ... @overload - def streams(self, *, block : typing.Literal[False]) -> asyncio.Future[typing.List[str]]: ... + def streams(self, *, block: typing.Literal[False]) -> asyncio.Future[typing.List[str]]: ... @overload - def streams(self, *, sync : typing.Literal[True]) -> typing.List[str]: ... + def streams(self, *, sync: typing.Literal[True]) -> typing.List[str]: ... @overload - def streams(self, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[typing.List[str]]: ... + def streams( + self, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[typing.List[str]]: ... @Work.run_sync async def streams(self) -> typing.List[str]: - ''' + """ Get the list of available streams. **Arguments** @@ -2286,13 +2458,13 @@ async def streams(self) -> typing.List[str]: **Result** * `List[str]`: the list of stream names when `block = True` * otherwise a future with this `List[str]` - ''' + """ - if 's' not in await self.capabilities(): + if "s" not in await self.capabilities(): return [] - rep = await self.req(b's') - if rep == b'?': + rep = await self.req(b"s") + if rep == b"?": return [] else: return list(map(lambda b: chr(b), rep)) @@ -2310,11 +2482,11 @@ async def other_streams(self): return streams - def stream(self, s : str, raw : bool=False) -> Stream: - '''Get a Stream object for the given stream name.''' + def stream(self, s: str, raw: bool = False) -> Stream: + """Get a Stream object for the given stream name.""" if not isinstance(s, str) or len(s) != 1: - raise ValueError('Invalid stream name ' + s) + raise ValueError("Invalid stream name " + s) if s in self._streams: return self._streams[s] @@ -2322,29 +2494,58 @@ def stream(self, s : str, raw : bool=False) -> Stream: self._streams[s] = Stream(self, s, raw) return self._streams[s] - - ############################################## # Alias @overload - async def alias(self, obj : str | Object, prefer : str | None=None, - temporary : bool=True, permanentRef : typing.Any=None) -> str | None: ... - @overload - def alias(self, obj : str | Object, prefer : str | None=None, - temporary : bool=True, permanentRef : typing.Any=None, *, block : typing.Literal[False]) -> asyncio.Future[str | None]: ... - @overload - def alias(self, obj : str | Object, prefer : str | None=None, - temporary : bool=True, permanentRef : typing.Any=None, *, sync : typing.Literal[True]) -> str | None: ... - @overload - def alias(self, obj : str | Object, prefer : str | None=None, - temporary : bool=True, permanentRef : typing.Any=None, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[str | None]: ... + async def alias( + self, + obj: str | Object, + prefer: str | None = None, + temporary: bool = True, + permanentRef: typing.Any = None, + ) -> str | None: ... + @overload + def alias( + self, + obj: str | Object, + prefer: str | None = None, + temporary: bool = True, + permanentRef: typing.Any = None, + *, + block: typing.Literal[False], + ) -> asyncio.Future[str | None]: ... + @overload + def alias( + self, + obj: str | Object, + prefer: str | None = None, + temporary: bool = True, + permanentRef: typing.Any = None, + *, + sync: typing.Literal[True], + ) -> str | None: ... + @overload + def alias( + self, + obj: str | Object, + prefer: str | None = None, + temporary: bool = True, + permanentRef: typing.Any = None, + *, + block: typing.Literal[False], + sync: typing.Literal[True], + ) -> concurrent.futures.Future[str | None]: ... @Work.run_sync - async def alias(self, obj : str | Object, prefer : str | None=None, - temporary : bool=True, permanentRef : typing.Any=None) -> str | None: - - ''' + async def alias( + self, + obj: str | Object, + prefer: str | None = None, + temporary: bool = True, + permanentRef: typing.Any = None, + ) -> str | None: + """ Assign an alias to an object. **Arguments** @@ -2356,7 +2557,7 @@ async def alias(self, obj : str | Object, prefer : str | None=None, **Result** * `str | None`: the assigned alias, or None when no alias could be assigned, when `block = True` * otherwise a future with this `str | None` - ''' + """ if isinstance(obj, str): obj = self.obj(obj) @@ -2364,9 +2565,9 @@ async def alias(self, obj : str | Object, prefer : str | None=None, async with self._alias_lock: if self._available_aliases is None: # Not yet initialized - if 'a' in await self.capabilities(): - self._available_aliases = list(map(chr, range(0x20, 0x7f))) - self._available_aliases.remove('/') + if "a" in await self.capabilities(): + self._available_aliases = list(map(chr, range(0x20, 0x7F))) + self._available_aliases.remove("/") else: self._available_aliases = [] @@ -2410,45 +2611,47 @@ async def alias(self, obj : str | Object, prefer : str | None=None, return None return await self._acquire_alias(a, obj, temporary, permanentRef) - def _is_alias_available(self, a : str) -> bool: + def _is_alias_available(self, a: str) -> bool: return self._available_aliases is not None and a in self._available_aliases - def _is_temporary_alias(self, a : str) -> bool: + def _is_temporary_alias(self, a: str) -> bool: return a in self._temporary_aliases - def _is_alias_in_use(self, a : str) -> bool: + def _is_alias_in_use(self, a: str) -> bool: return a in self._temporary_aliases or a in self._permanent_aliases - def _inc_permanent_alias(self, a : str, permanentRef : typing.Any): + def _inc_permanent_alias(self, a: str, permanentRef: typing.Any): assert a in self._permanent_aliases - self.logger.debug(f'increment permanent alias {a} use') + self.logger.debug(f"increment permanent alias {a} use") self._permanent_aliases[a][1].append(permanentRef) - def _dec_permanent_alias(self, a : str, permanentRef : typing.Any): + def _dec_permanent_alias(self, a: str, permanentRef: typing.Any): assert a in self._permanent_aliases if permanentRef is None: - self.logger.debug(f'ignored decrement permanent alias {a} use') + self.logger.debug(f"ignored decrement permanent alias {a} use") return False try: self._permanent_aliases[a][1].remove(permanentRef) - self.logger.debug(f'decrement permanent alias {a} use') + self.logger.debug(f"decrement permanent alias {a} use") except ValueError: # Unknown ref. pass return self._permanent_aliases[a][1] == [] - async def _acquire_alias(self, a : str, obj : Object, temporary : bool, permanentRef : typing.Any) -> str | None: + async def _acquire_alias( + self, a: str, obj: Object, temporary: bool, permanentRef: typing.Any + ) -> str | None: assert not self._is_alias_in_use(a) assert self._available_aliases is not None if not (isinstance(a, str) and len(a) == 1): - raise ValueError('Invalid alias ' + a) + raise ValueError("Invalid alias " + a) available_upon_rollback = False if a in self._available_aliases: - self.logger.debug('available: ' + ''.join(self._available_aliases)) + self.logger.debug("available: " + "".join(self._available_aliases)) self._available_aliases.remove(a) available_upon_rollback = True @@ -2472,56 +2675,58 @@ async def _acquire_alias(self, a : str, obj : Object, temporary : bool, permanen # Success! if temporary: - self.logger.debug(f'new temporary alias {a} for {obj.name}') + self.logger.debug(f"new temporary alias {a} for {obj.name}") self._temporary_aliases[a] = obj else: - self.logger.debug(f'new permanent alias {a} for {obj.name}') + self.logger.debug(f"new permanent alias {a} for {obj.name}") self._permanent_aliases[a] = (obj, [permanentRef]) obj.alias.value = a return a - async def _set_alias(self, a : str, name : str) -> bool: - rep = await self.req(b'a' + a.encode() + name.encode()) - return rep == b'!' + async def _set_alias(self, a: str, name: str) -> bool: + rep = await self.req(b"a" + a.encode() + name.encode()) + return rep == b"!" - async def _reassign_alias(self, a : str, obj : Object, temporary : bool, permanentRef : typing.Any) -> str | None: + async def _reassign_alias( + self, a: str, obj: Object, temporary: bool, permanentRef: typing.Any + ) -> str | None: assert a in self._temporary_aliases or a in self._permanent_aliases assert not self._is_alias_available(a) assert self._available_aliases is not None if not self._release_alias(a, permanentRef): # Not allowed, still is use as permanent alias. - self.logger.debug(f'cannot release alias {a}; still in use') + self.logger.debug(f"cannot release alias {a}; still in use") else: if a in self._available_aliases: self._available_aliases.remove(a) if temporary: - self.logger.debug(f'reassigned temporary alias {a} to {obj.name}') + self.logger.debug(f"reassigned temporary alias {a} to {obj.name}") self._temporary_aliases[a] = obj else: - self.logger.debug(f'reassigned permanent alias {a} to {obj.name}') + self.logger.debug(f"reassigned permanent alias {a} to {obj.name}") self._permanent_aliases[a] = (obj, [permanentRef]) obj.alias.value = a return a - def _release_alias(self, alias : str, permanentRef : typing.Any = None) -> bool: + def _release_alias(self, alias: str, permanentRef: typing.Any = None) -> bool: assert self._available_aliases is not None obj = None if alias in self._temporary_aliases: obj = self._temporary_aliases[alias] del self._temporary_aliases[alias] - self.logger.debug(f'released temporary alias {alias}') + self.logger.debug(f"released temporary alias {alias}") elif alias in self._permanent_aliases: if not self._dec_permanent_alias(alias, permanentRef): # Do not release (yet). return False obj = self._permanent_aliases[alias][0] del self._permanent_aliases[alias] - self.logger.debug(f'released permanent alias {alias}') + self.logger.debug(f"released permanent alias {alias}") else: - self.logger.debug(f'released unused alias {alias}') + self.logger.debug(f"released unused alias {alias}") if not obj is None: obj.alias.value = None @@ -2539,23 +2744,34 @@ def _get_temporary_alias(self) -> str | None: keys = list(self._temporary_aliases.keys()) if not keys: return None - a = keys[0] # pick oldest one - self.logger.debug(f'stealing temporary alias {a}') + a = keys[0] # pick oldest one + self.logger.debug(f"stealing temporary alias {a}") self._release_alias(a) return a @overload - async def release_alias(self, alias : str, permanentRef=None) -> None: ... + async def release_alias(self, alias: str, permanentRef=None) -> None: ... @overload - def release_alias(self, alias : str, permanentRef=None, *, block : typing.Literal[False]) -> asyncio.Future[None]: ... + def release_alias( + self, alias: str, permanentRef=None, *, block: typing.Literal[False] + ) -> asyncio.Future[None]: ... @overload - def release_alias(self, alias : str, permanentRef=None, *, sync : typing.Literal[True]) -> None: ... + def release_alias( + self, alias: str, permanentRef=None, *, sync: typing.Literal[True] + ) -> None: ... @overload - def release_alias(self, alias : str, permanentRef=None, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[None]: ... + def release_alias( + self, + alias: str, + permanentRef=None, + *, + block: typing.Literal[False], + sync: typing.Literal[True], + ) -> concurrent.futures.Future[None]: ... @Work.run_sync - async def release_alias(self, alias : str, permanentRef=None): - ''' + async def release_alias(self, alias: str, permanentRef=None): + """ Release an alias. **Arguments** @@ -2565,31 +2781,42 @@ async def release_alias(self, alias : str, permanentRef=None): **Result** * `None`: when `block = True` * otherwise a future - ''' + """ if self._release_alias(alias, permanentRef): - await self.req(b'a' + alias.encode()) + await self.req(b"a" + alias.encode()) @Work.run_sync async def _print_alias_map(self): - '''Print the current alias map.''' + """Print the current alias map.""" async with self._alias_lock: if self._available_aliases is None: print("Not initialized") else: - print("Available aliases: " + ''.join(self._available_aliases)) + print("Available aliases: " + "".join(self._available_aliases)) if len(self._temporary_aliases) == 0: print("No temporary aliases") else: - print("Temporary aliases:\n\t" + '\n\t'.join([f'{a}: {o.name}' for a,o in self._temporary_aliases.items()])) + print( + "Temporary aliases:\n\t" + + "\n\t".join( + [f"{a}: {o.name}" for a, o in self._temporary_aliases.items()] + ) + ) if len(self._permanent_aliases) == 0: print("No permanent aliases") else: - print("Permanent aliases: \n\t" + '\n\t'.join([f'{a}: {o[0].name} ({len(o[1])})' for a,o in self._permanent_aliases.items()])) - - + print( + "Permanent aliases: \n\t" + + "\n\t".join( + [ + f"{a}: {o[0].name} ({len(o[1])})" + for a, o in self._permanent_aliases.items() + ] + ) + ) ############################################## # Macro @@ -2597,15 +2824,17 @@ async def _print_alias_map(self): @overload async def acquire_macro(self) -> str | None: ... @overload - def acquire_macro(self, *, block : typing.Literal[False]) -> asyncio.Future[str | None]: ... + def acquire_macro(self, *, block: typing.Literal[False]) -> asyncio.Future[str | None]: ... @overload - def acquire_macro(self, *, sync : typing.Literal[True]) -> str | None: ... + def acquire_macro(self, *, sync: typing.Literal[True]) -> str | None: ... @overload - def acquire_macro(self, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[str | None]: ... + def acquire_macro( + self, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[str | None]: ... @Work.run_sync async def acquire_macro(self) -> str | None: - ''' + """ Get a free macro name. In case there is no available macro name, `None` is returned. This can @@ -2618,16 +2847,16 @@ async def acquire_macro(self) -> str | None: **Result** * `str | None`: the macro name when `block = True` * otherwise a future with this `str | None` - ''' + """ if self._available_macros is None: # Not initialized yet. capabilities = await self.capabilities() - if 'm' not in capabilities: + if "m" not in capabilities: # Not supported. self._available_macros = [] else: - self._available_macros = list(map(chr, range(0x20, 0x7f))) + self._available_macros = list(map(chr, range(0x20, 0x7F))) for c in capabilities: self._available_macros.remove(c) @@ -2638,27 +2867,31 @@ async def acquire_macro(self) -> str | None: self._used_macros.append(m) return m - def macro(self, name : str | None, *args, **kwargs) -> Macro: - ''' + def macro(self, name: str | None, *args, **kwargs) -> Macro: + """ Create a macro object for the given macro name. - ''' + """ mo = Macro(self, name, *args, **kwargs) self._macros.append(mo) return mo @overload - async def release_macro(self, m : str | bytes | Macro) -> None: ... + async def release_macro(self, m: str | bytes | Macro) -> None: ... @overload - def release_macro(self, m : str | bytes | Macro, *, block : typing.Literal[False]) -> asyncio.Future[None]: ... + def release_macro( + self, m: str | bytes | Macro, *, block: typing.Literal[False] + ) -> asyncio.Future[None]: ... @overload - def release_macro(self, m : str | bytes | Macro, *, sync : typing.Literal[True]) -> None: ... + def release_macro(self, m: str | bytes | Macro, *, sync: typing.Literal[True]) -> None: ... @overload - def release_macro(self, m : str | bytes | Macro, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[None]: ... + def release_macro( + self, m: str | bytes | Macro, *, block: typing.Literal[False], sync: typing.Literal[True] + ) -> concurrent.futures.Future[None]: ... @Work.run_sync - async def release_macro(self, m : str | bytes | Macro): - ''' + async def release_macro(self, m: str | bytes | Macro): + """ Release a macro. **Arguments** @@ -2668,7 +2901,7 @@ async def release_macro(self, m : str | bytes | Macro): **Result** * `None`: when `block = True` * otherwise a future - ''' + """ macro = None mo = None @@ -2685,7 +2918,7 @@ async def release_macro(self, m : str | bytes | Macro): assert self._available_macros is not None self._used_macros.remove(macro) self._available_macros.append(macro) - await self.req(b'm' + macro.encode()) + await self.req(b"m" + macro.encode()) if mo is None: assert isinstance(macro, str) @@ -2699,19 +2932,28 @@ async def release_macro(self, m : str | bytes | Macro): self._macros.remove(mo) mo.destroy() - - ############################################## # Poll @overload - def periodic(self, interval_s : float, f : typing.Callable, *args, name : str | None=None) -> asyncio.Task: ... + def periodic( + self, interval_s: float, f: typing.Callable, *args, name: str | None = None + ) -> asyncio.Task: ... @overload - def periodic(self, interval_s : float, f : typing.Callable, *args, name : str | None=None, block : typing.Literal[False]) -> concurrent.futures.Future[asyncio.Task] | asyncio.Task: ... + def periodic( + self, + interval_s: float, + f: typing.Callable, + *args, + name: str | None = None, + block: typing.Literal[False], + ) -> concurrent.futures.Future[asyncio.Task] | asyncio.Task: ... @Work.thread_safe - def periodic(self, interval_s : float, f : typing.Callable, *args, name : str | None=None) -> asyncio.Task: - ''' + def periodic( + self, interval_s: float, f: typing.Callable, *args, name: str | None = None + ) -> asyncio.Task: + """ Run a function periodically while the client is alive. **Arguments** @@ -2724,14 +2966,16 @@ def periodic(self, interval_s : float, f : typing.Callable, *args, name : str | **Result** * `asyncio.Task`: the created periodic task when `block = True` * otherwise a future with this `asyncio.Task` - ''' + """ if not interval_s >= 0: - raise ValueError('interval_s must be non-negative') + raise ValueError("interval_s must be non-negative") if not asyncio.iscoroutinefunction(f): + async def _f(*args): f(*args) + coro = _f else: coro = f @@ -2740,18 +2984,18 @@ async def _f(*args): self._periodic_tasks.add(task) return task - async def _periodic(self, interval_s : float, coro : typing.Callable, *args, **kwargs): - name = '' + async def _periodic(self, interval_s: float, coro: typing.Callable, *args, **kwargs): + name = "" try: task = asyncio.current_task() assert task is not None name = task.get_name() - if name != '': - name = ' ' + name + if name != "": + name = " " + name t = time.time() while self.is_connected(): - self.logger.debug('periodic task%s', name) + self.logger.debug("periodic task%s", name) await coro(*args, **kwargs) if interval_s == 0: @@ -2768,13 +3012,13 @@ async def _periodic(self, interval_s : float, coro : typing.Callable, *args, **k except asyncio.CancelledError: pass except lexc.Disconnected as e: - self.logger.debug('periodic task%s stopped; %s', name, e) + self.logger.debug("periodic task%s stopped; %s", name, e) pass except lexc.InvalidState as e: if self.is_connected(): - self.logger.exception('exception in periodic task%s: %s', name, e) + self.logger.exception("exception in periodic task%s: %s", name, e) except Exception as e: - self.logger.exception('exception in periodic task%s: %s', name, e) + self.logger.exception("exception in periodic task%s: %s", name, e) finally: t = asyncio.current_task() try: @@ -2788,9 +3032,9 @@ async def _periodic(self, interval_s : float, coro : typing.Callable, *args, **k fast_poll_threshold_s = 0.9 slow_poll_threshold_s = 1.0 - async def _poll(self, o : Object, interval_s : float | None): + async def _poll(self, o: Object, interval_s: float | None): if o not in self.objects: - raise ValueError('Object not managed by this client') + raise ValueError("Object not managed by this client") if interval_s is None: # Stop slow polling, if any. @@ -2834,7 +3078,7 @@ async def _poll_fast_stop(self): self._fast_poll_task = None t.cancel() - async def _poll_fast(self, o : Object, interval_s : float): + async def _poll_fast(self, o: Object, interval_s: float): t = self.time() if t is None: # Cannot do fast polling without time object. @@ -2842,15 +3086,20 @@ async def _poll_fast(self, o : Object, interval_s : float): if self._fast_poll_macro is None: self._fast_poll_macro = self.macro(await self.acquire_macro()) - await self._fast_poll_macro.add(cmd=f'r{await t.short_name()}', cb=t.handle_read, key=self) + await self._fast_poll_macro.add( + cmd=f"r{await t.short_name()}", cb=t.handle_read, key=self + ) assert self._fast_poll_macro is not None if o not in self._fast_poll_macro: await self.alias(o, temporary=False, permanentRef=self._fast_poll_macro) a = await o.short_name() try: - await self._fast_poll_macro.add(cmd=f'r{a}', - cb=lambda x, _, t=t: o.handle_read(x, self.timestamp_to_time(t.value)), key=o) + await self._fast_poll_macro.add( + cmd=f"r{a}", + cb=lambda x, _, t=t: o.handle_read(x, self.timestamp_to_time(t.value)), + key=o, + ) except lexc.NotSupported: # Cannot do a fast poll. return False @@ -2860,7 +3109,9 @@ async def _poll_fast(self, o : Object, interval_s : float): await self._poll_fast_stop() if self._fast_poll_task is None or self._fast_poll_task.done(): - self._fast_poll_task = self.periodic(self._fast_poll_interval_s, self._poll_fast_task, name='poll fast') + self._fast_poll_task = self.periodic( + self._fast_poll_interval_s, self._poll_fast_task, name="poll fast" + ) return True @@ -2871,7 +3122,7 @@ async def _poll_fast_task(self): await self._fast_poll_macro.run() - async def _trace(self, o : Object, interval_s : float): + async def _trace(self, o: Object, interval_s: float): if self._tracing is False: # Not supported return False @@ -2887,7 +3138,7 @@ async def _trace(self, o : Object, interval_s : float): try: await self._tracing._init() except BaseException as e: - self.logger.info('cannot initialize tracing; %s', e) + self.logger.info("cannot initialize tracing; %s", e) self._tracing = False await self.release_macro(m) return False @@ -2897,21 +3148,19 @@ async def _trace(self, o : Object, interval_s : float): await self.alias(o, temporary=False, permanentRef=self._tracing) a = await o.short_name() try: - await self._tracing.add(cmd=f'r{a}', cb=o.handle_read, key=o) + await self._tracing.add(cmd=f"r{a}", cb=o.handle_read, key=o) except lexc.NotSupported: # Cannot do a trace. return False return True - - ############################################## # Socket monitor async def _monitor_socket(self): if not self.is_connected(): - self.logger.debug('Not connected, not starting socket monitor') + self.logger.debug("Not connected, not starting socket monitor") return assert self._socket is not None @@ -2922,23 +3171,21 @@ async def _monitor_socket(self): event = await monitor.recv_multipart() evt = zmq.utils.monitor.parse_monitor_message(event) self.logger.debug(f'socket event: {repr(evt["event"])}') - if evt['event'] == zmq.EVENT_DISCONNECTED: - self.logger.info('socket disconnected') + if evt["event"] == zmq.EVENT_DISCONNECTED: + self.logger.info("socket disconnected") await self.disconnect() except asyncio.CancelledError: pass except Exception as e: - self.logger.exception('exception in socket monitor: %s', e) + self.logger.exception("exception in socket monitor: %s", e) finally: monitor.close(0) - - ############################################## # State def state(self) -> dict: - '''Get the current state of the client.''' + """Get the current state of the client.""" if not self.is_connected: return {} @@ -2952,26 +3199,34 @@ def state(self) -> dict: objs.update(o.state()) s = { - 'identification': id, - 'version': self._version, - 'last': datetime.datetime.now(datetime.timezone.utc).isoformat(), - 'objects': objs + "identification": id, + "version": self._version, + "last": datetime.datetime.now(datetime.timezone.utc).isoformat(), + "objects": objs, } return {id: s} @overload - async def save_state(self, state_name : str | None=None) -> None: ... + async def save_state(self, state_name: str | None = None) -> None: ... @overload - def save_state(self, state_name : str | None=None, *, block : typing.Literal[False]) -> asyncio.Future[None]: ... + def save_state( + self, state_name: str | None = None, *, block: typing.Literal[False] + ) -> asyncio.Future[None]: ... @overload - def save_state(self, state_name : str | None=None, *, sync : typing.Literal[True]) -> None: ... + def save_state(self, state_name: str | None = None, *, sync: typing.Literal[True]) -> None: ... @overload - def save_state(self, state_name : str | None=None, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[None]: ... + def save_state( + self, + state_name: str | None = None, + *, + block: typing.Literal[False], + sync: typing.Literal[True], + ) -> concurrent.futures.Future[None]: ... @Work.run_sync - async def save_state(self, state_name : str | None=None) -> None: - ''' + async def save_state(self, state_name: str | None = None) -> None: + """ Save the current state to a file. **Arguments** @@ -2980,10 +3235,10 @@ async def save_state(self, state_name : str | None=None) -> None: **Result** * `None`: when `block = True` * otherwise a future - ''' + """ if not self.is_connected(): - raise lexc.InvalidState('Not connected') + raise lexc.InvalidState("Not connected") filename = self.state_file(state_name) if filename is None: @@ -2991,10 +3246,10 @@ async def save_state(self, state_name : str | None=None) -> None: s = self.state() - async with lexc.DeadlockChecker(filelock.AsyncFileLock(f'{filename}.lock')): + async with lexc.DeadlockChecker(filelock.AsyncFileLock(f"{filename}.lock")): state = {} try: - async with aiofiles.open(filename, 'r') as f: + async with aiofiles.open(filename, "r") as f: state = json.loads(await f.read()) except FileNotFoundError: pass @@ -3008,26 +3263,36 @@ async def save_state(self, state_name : str | None=None) -> None: else: state.update(s) - state['_version'] = libstored_version + state["_version"] = libstored_version os.makedirs(os.path.dirname(filename), exist_ok=True) - async with aiofiles.open(filename, 'w') as f: + async with aiofiles.open(filename, "w") as f: await f.write(json.dumps(state, indent=4, sort_keys=True)) - self.logger.debug('saved state to %s', filename) + self.logger.debug("saved state to %s", filename) @overload - async def restore_state(self, state_name : str | None=None) -> None: ... + async def restore_state(self, state_name: str | None = None) -> None: ... @overload - def restore_state(self, state_name : str | None=None, *, block : typing.Literal[False]) -> asyncio.Future[None]: ... + def restore_state( + self, state_name: str | None = None, *, block: typing.Literal[False] + ) -> asyncio.Future[None]: ... @overload - def restore_state(self, state_name : str | None=None, *, sync : typing.Literal[True]) -> None: ... + def restore_state( + self, state_name: str | None = None, *, sync: typing.Literal[True] + ) -> None: ... @overload - def restore_state(self, state_name : str | None=None, *, block : typing.Literal[False], sync : typing.Literal[True]) -> concurrent.futures.Future[None]: ... + def restore_state( + self, + state_name: str | None = None, + *, + block: typing.Literal[False], + sync: typing.Literal[True], + ) -> concurrent.futures.Future[None]: ... @Work.run_sync - async def restore_state(self, state_name : str | None=None): - ''' + async def restore_state(self, state_name: str | None = None): + """ Restore the state from a file. **Arguments** @@ -3036,10 +3301,10 @@ async def restore_state(self, state_name : str | None=None): **Result** * `None`: when `block = True` * otherwise a future - ''' + """ if not self.is_connected(): - raise lexc.InvalidState('Not connected') + raise lexc.InvalidState("Not connected") filename = self.state_file(state_name) if not filename: @@ -3053,31 +3318,31 @@ async def restore_state(self, state_name : str | None=None): if not obj: return - async with lexc.DeadlockChecker(filelock.AsyncFileLock(f'{filename}.lock')): + async with lexc.DeadlockChecker(filelock.AsyncFileLock(f"{filename}.lock")): try: - async with aiofiles.open(filename, 'r') as f: + async with aiofiles.open(filename, "r") as f: state = json.loads(await f.read()) except FileNotFoundError: - self.logger.debug('cannot restore state from %s; not found', filename) + self.logger.debug("cannot restore state from %s; not found", filename) return except json.JSONDecodeError as e: - self.logger.warning('cannot restore state from %s; invalid JSON: %s', filename, e) + self.logger.warning("cannot restore state from %s; invalid JSON: %s", filename, e) return if not id in state: return s = state[id] - if not 'objects' in s: + if not "objects" in s: return for o in obj: - await o.restore_state(s['objects']) + await o.restore_state(s["objects"]) - self.logger.debug('restored state from %s', filename) + self.logger.debug("restored state from %s", filename) - def state_file(self, state_name : str | None=None) -> str | None: - '''Get the state file name.''' + def state_file(self, state_name: str | None = None) -> str | None: + """Get the state file name.""" if not state_name: state_name = self._use_state @@ -3085,48 +3350,46 @@ def state_file(self, state_name : str | None=None) -> str | None: if not state_name: return None - return os.path.join(platformdirs.user_config_dir('libstored'), state_name + '.json') - + return os.path.join(platformdirs.user_config_dir("libstored"), state_name + ".json") class SyncObject: - ''' + """ A synchronous ZeroMQ client object. This class wraps the AsyncZmqClientObject to provide a synchronous interface. - ''' + """ - def __init__(self, obj : Object): + def __init__(self, obj: Object): self._obj = obj def __getattr__(self, name): return getattr(self._obj, name) - def read(self, acquire_alias : bool=True) -> typing.Any: - '''Read the value of the object.''' + def read(self, acquire_alias: bool = True) -> typing.Any: + """Read the value of the object.""" return self._obj.read(acquire_alias, sync=True) - def write(self, value : typing.Any = None) -> None: - '''Write a value to the object.''' + def write(self, value: typing.Any = None) -> None: + """Write a value to the object.""" return self._obj.write(value, sync=True) - class SyncZmqClient: - ''' + """ A synchronous ZeroMQ client. This class wraps the ZmqClient to provide a synchronous interface that is understood by static code analyzers. - ''' + """ - def __init__(self, client : ZmqClient): + def __init__(self, client: ZmqClient): self._client = client def __getattr__(self, name): return getattr(self._client, name) - def __getitem__(self, x : str) -> SyncObject: + def __getitem__(self, x: str) -> SyncObject: return SyncObject(self.obj(x)) def __enter__(self): diff --git a/python/libstored/cli/__main__.py b/python/libstored/cli/__main__.py index 7a9f5201..69e8acdf 100644 --- a/python/libstored/cli/__main__.py +++ b/python/libstored/cli/__main__.py @@ -14,31 +14,51 @@ from ..version import __version__ from .. import protocol as lprot + @run_sync -async def async_main(args : argparse.Namespace): +async def async_main(args: argparse.Namespace): stack = None if args.encrypted: stack = lprot.Aes256Layer(args.encrypted, reqrep=True) async with ZmqClient(args.server, args.port, multi=True, stack=stack) as client: - prefix = '> ' + prefix = "> " await aiofiles.stdout.write(prefix) await aiofiles.stdout.flush() async for line in aiofiles.stdin: line = line.strip() if len(line) > 0: - await aiofiles.stdout.write('< ' + await client.req(line) + '\n') + await aiofiles.stdout.write("< " + await client.req(line) + "\n") await aiofiles.stdout.write(prefix) await aiofiles.stdout.flush() + def main(): - parser = argparse.ArgumentParser(description='ZMQ command line client', prog=__package__) - parser.add_argument('-V', '--version', action='version', version=__version__) - parser.add_argument('-s', '--server', dest='server', type=str, default='localhost', help='ZMQ server to connect to') - parser.add_argument('-p', '--port', dest='port', type=int, default=lprot.default_port, help='port') - parser.add_argument('-v', '--verbose', dest='verbose', default=0, help='Enable verbose output', action='count') - parser.add_argument('-e', '--encrypt', dest='encrypted', type=str, default=None, - help='Enable AES-256 CTR encryption with the given pre-shared key file', metavar='file') + parser = argparse.ArgumentParser(description="ZMQ command line client", prog=__package__) + parser.add_argument("-V", "--version", action="version", version=__version__) + parser.add_argument( + "-s", + "--server", + dest="server", + type=str, + default="localhost", + help="ZMQ server to connect to", + ) + parser.add_argument( + "-p", "--port", dest="port", type=int, default=lprot.default_port, help="port" + ) + parser.add_argument( + "-v", "--verbose", dest="verbose", default=0, help="Enable verbose output", action="count" + ) + parser.add_argument( + "-e", + "--encrypt", + dest="encrypted", + type=str, + default=None, + help="Enable AES-256 CTR encryption with the given pre-shared key file", + metavar="file", + ) args = parser.parse_args() @@ -59,5 +79,6 @@ def main(): os._exit(0) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/python/libstored/cmake/__main__.py b/python/libstored/cmake/__main__.py index 6cf0ada9..d864a291 100644 --- a/python/libstored/cmake/__main__.py +++ b/python/libstored/cmake/__main__.py @@ -13,50 +13,66 @@ from ..version import __version__ -logger = logging.getLogger('cmake') +logger = logging.getLogger("cmake") # We are either installed by pip, so the sources are at here/../data, # or we are just running from the git repo, which is at here/../../.. here = os.path.dirname(__file__) -libstored_dir = os.path.normpath(os.path.abspath(os.path.join(here, '..', 'data'))) -if not os.path.isdir(libstored_dir) or os.path.isfile(os.path.join(libstored_dir, 'ignore')): - libstored_dir = os.path.normpath(os.path.abspath(os.path.join(here, '..', '..', '..'))) +libstored_dir = os.path.normpath(os.path.abspath(os.path.join(here, "..", "data"))) +if not os.path.isdir(libstored_dir) or os.path.isfile(os.path.join(libstored_dir, "ignore")): + libstored_dir = os.path.normpath(os.path.abspath(os.path.join(here, "..", "..", ".."))) + def escapebs(s): - return re.sub(r'\\', r'\\\\', s) + return re.sub(r"\\", r"\\\\", s) + def escapestr(s): - return re.sub(r'([\\"])', r'\\\1', s) + return re.sub(r'([\\"])', r"\\\1", s) + def generate_cmake(filename, defines): - jenv = jinja2.Environment( - loader = jinja2.FileSystemLoader( - os.path.join(libstored_dir, 'cmake') - )) + jenv = jinja2.Environment(loader=jinja2.FileSystemLoader(os.path.join(libstored_dir, "cmake"))) - jenv.filters['escapebs'] = escapebs - jenv.filters['escapestr'] = escapestr + jenv.filters["escapebs"] = escapebs + jenv.filters["escapestr"] = escapestr - tmpl = jenv.get_template('FindLibstored.cmake.tmpl') + tmpl = jenv.get_template("FindLibstored.cmake.tmpl") + + logger.info("Writing to %s...", filename) + with open(filename, "w") as f: + f.write( + tmpl.render( + python_executable=sys.executable, + libstored_dir=libstored_dir, + defines=defines, + ) + ) - logger.info('Writing to %s...', filename) - with open(filename, 'w') as f: - f.write(tmpl.render( - python_executable=sys.executable, - libstored_dir=libstored_dir, - defines=defines, - )) def main(): - parser = argparse.ArgumentParser(prog=__package__, - description='Generator for find_package(Libstored) in CMake', - formatter_class=argparse.ArgumentDefaultsHelpFormatter) - - parser.add_argument('-V', '--version', action='version', version=__version__) - parser.add_argument('-v', '--verbose', dest='verbose', default=0, help='enable verbose output', action='count') - parser.add_argument('-D', dest='define', metavar='key[=value]', default=[], nargs=1, action='append', help='CMake defines') - parser.add_argument('filename', default='FindLibstored.cmake', nargs='?', - type=str, help='Output filename') + parser = argparse.ArgumentParser( + prog=__package__, + description="Generator for find_package(Libstored) in CMake", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument("-V", "--version", action="version", version=__version__) + parser.add_argument( + "-v", "--verbose", dest="verbose", default=0, help="enable verbose output", action="count" + ) + parser.add_argument( + "-D", + dest="define", + metavar="key[=value]", + default=[], + nargs=1, + action="append", + help="CMake defines", + ) + parser.add_argument( + "filename", default="FindLibstored.cmake", nargs="?", type=str, help="Output filename" + ) args = parser.parse_args() if args.verbose == 0: @@ -69,10 +85,11 @@ def main(): defines = {} if args.define is not None: for d in args.define: - kv = d[0].split('=', 1) + ['ON'] + kv = d[0].split("=", 1) + ["ON"] defines[kv[0]] = kv[1] generate_cmake(args.filename, defines) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/python/libstored/exceptions.py b/python/libstored/exceptions.py index c77ff15e..24cc69c7 100644 --- a/python/libstored/exceptions.py +++ b/python/libstored/exceptions.py @@ -12,26 +12,33 @@ import typing import logging + class Disconnected(RuntimeError): pass + class OperationFailed(RuntimeError): pass + class InvalidState(RuntimeError): pass + class NotSupported(RuntimeError): pass + class InvalidResponse(ValueError): pass + class Deadlock(RuntimeError): pass + class DeadlockChecker: - ''' + """ Context manager to check for deadlocks when acquiring a lock. Usage: @@ -39,9 +46,9 @@ class DeadlockChecker: with DeadlockChecker(lock, timeout_s=5): # Critical section ... - ''' + """ - default_timeout_s : float | None = None + default_timeout_s: float | None = None class Type(enum.Enum): THREADING_LOCK = enum.auto() @@ -51,7 +58,7 @@ class Type(enum.Enum): ASYNCIO_FUTURE = enum.auto() CONCURRENT_FUTURE = enum.auto() - def __init__(self, lock : typing.Any, timeout_s : float | None=None): + def __init__(self, lock: typing.Any, timeout_s: float | None = None): self._lock = lock self._timeout_s = timeout_s if timeout_s is not None else self.default_timeout_s self._acquired = False @@ -72,23 +79,27 @@ def __init__(self, lock : typing.Any, timeout_s : float | None=None): elif isinstance(lock, concurrent.futures.Future): # concurrent.futures.Future self._type = self.Type.CONCURRENT_FUTURE - elif hasattr(lock, 'acquire') and callable(getattr(lock, 'acquire')): + elif hasattr(lock, "acquire") and callable(getattr(lock, "acquire")): # Looks like threading.Lock or threading.RLock self._type = self.Type.THREADING_LOCK else: raise TypeError("Unsupported lock type %s" % type(lock)) def _deadlock(self, type): - self.logger.critical(f"Deadlock detected: could not {type} {self._lock} within {self._timeout_s} seconds") + self.logger.critical( + f"Deadlock detected: could not {type} {self._lock} within {self._timeout_s} seconds" + ) raise Deadlock("Deadlock detected") from None def __enter__(self): if self._type != self.Type.THREADING_LOCK: - raise RuntimeError('Wrong access method') + raise RuntimeError("Wrong access method") - self._acquired = self._lock.acquire(timeout=self._timeout_s if self._timeout_s is not None else -1) + self._acquired = self._lock.acquire( + timeout=self._timeout_s if self._timeout_s is not None else -1 + ) if not self._acquired: - self._deadlock('acquire lock') + self._deadlock("acquire lock") return self def __exit__(self, exc_type, exc_value, traceback): @@ -103,10 +114,10 @@ async def __aenter__(self): elif self._type == self.Type.ASYNCIO_FILELOCK: await self._lock.acquire(timeout=self._timeout_s) else: - raise RuntimeError('Wrong access method') + raise RuntimeError("Wrong access method") self._acquired = True except asyncio.TimeoutError: - self._deadlock('acquire lock') + self._deadlock("acquire lock") return self async def __aexit__(self, exc_type, exc_value, traceback): @@ -132,18 +143,18 @@ def result(self): # We are in an event loop, wrap in asyncio future. return asyncio.wait_for(asyncio.wrap_future(self._lock), timeout=self._timeout_s) else: - raise RuntimeError('Wrong access method') + raise RuntimeError("Wrong access method") except asyncio.TimeoutError: - self._deadlock('acquire future') + self._deadlock("acquire future") def __await__(self): if self._type != self.Type.COROUTINE: - raise RuntimeError('Wrong access method') + raise RuntimeError("Wrong access method") try: return asyncio.wait_for(self._lock, timeout=self._timeout_s).__await__() except asyncio.TimeoutError: - self._deadlock('complete coroutine') + self._deadlock("complete coroutine") def has_lock(self): return self._acquired diff --git a/python/libstored/generator/__main__.py b/python/libstored/generator/__main__.py index e392ac50..ccde1dc7 100644 --- a/python/libstored/generator/__main__.py +++ b/python/libstored/generator/__main__.py @@ -31,124 +31,146 @@ # We are either installed by pip, so the sources are at generator_dir/../data, # or we are just running from the git repo, which is at generator_dir/../../.. -libstored_dir = os.path.normpath(os.path.abspath(os.path.join(generator_dir, '..', 'data'))) -if not os.path.isdir(libstored_dir) or os.path.isfile(os.path.join(libstored_dir, 'ignore')): - libstored_dir = os.path.normpath(os.path.abspath(os.path.join(generator_dir, '..', '..', '..'))) +libstored_dir = os.path.normpath(os.path.abspath(os.path.join(generator_dir, "..", "data"))) +if not os.path.isdir(libstored_dir) or os.path.isfile(os.path.join(libstored_dir, "ignore")): + libstored_dir = os.path.normpath(os.path.abspath(os.path.join(generator_dir, "..", "..", ".."))) + def is_variable(o): return isinstance(o, types.Variable) + def is_function(o): return isinstance(o, types.Function) + def has_function(os): for o in os: if is_function(o): return True return False + def is_blob(o): return o.isBlob() + def is_string(o): - return o.type == 'string' + return o.type == "string" + def is_pointer(o): - return o.type in ['ptr32', 'ptr64'] + return o.type in ["ptr32", "ptr64"] + def ctype(o): return { - 'bool': 'bool', - 'int8': 'int8_t', - 'uint8': 'uint8_t', - 'int16': 'int16_t', - 'uint16': 'uint16_t', - 'int32': 'int32_t', - 'uint32': 'uint32_t', - 'int64': 'int64_t', - 'uint64': 'uint64_t', - 'float': 'float', - 'double': 'double', - 'ptr32': 'void*', - 'ptr64': 'void*', - 'blob': 'void', - 'string': 'char' + "bool": "bool", + "int8": "int8_t", + "uint8": "uint8_t", + "int16": "int16_t", + "uint16": "uint16_t", + "int32": "int32_t", + "uint32": "uint32_t", + "int64": "int64_t", + "uint64": "uint64_t", + "float": "float", + "double": "double", + "ptr32": "void*", + "ptr64": "void*", + "blob": "void", + "string": "char", }[o.type] + def stype(o): t = { - 'bool': 'Type::Bool', - 'int8': 'Type::Int8', - 'uint8': 'Type::Uint8', - 'int16': 'Type::Int16', - 'uint16': 'Type::Uint16', - 'int32': 'Type::Int32', - 'uint32': 'Type::Uint32', - 'int64': 'Type::Int64', - 'uint64': 'Type::Uint64', - 'float': 'Type::Float', - 'double': 'Type::Double', - 'ptr32': 'Type::Pointer', - 'ptr64': 'Type::Pointer', - 'blob': 'Type::Blob', - 'string': 'Type::String' + "bool": "Type::Bool", + "int8": "Type::Int8", + "uint8": "Type::Uint8", + "int16": "Type::Int16", + "uint16": "Type::Uint16", + "int32": "Type::Int32", + "uint32": "Type::Uint32", + "int64": "Type::Int64", + "uint64": "Type::Uint64", + "float": "Type::Float", + "double": "Type::Double", + "ptr32": "Type::Pointer", + "ptr64": "Type::Pointer", + "blob": "Type::Blob", + "string": "Type::String", }[o.type] if is_function(o): - t = f'(Type::type)({t} | Type::FlagFunction)' + t = f"(Type::type)({t} | Type::FlagFunction)" return t + def vhdltype(o): return { - 'bool': 'std_logic', - 'int8': 'signed(7 downto 0)', - 'uint8': 'unsigned(7 downto 0)', - 'int16': 'signed(15 downto 0)', - 'uint16': 'unsigned(15 downto 0)', - 'int32': 'signed(31 downto 0)', - 'uint32': 'unsigned(31 downto 0)', - 'int64': 'signed(63 downto 0)', - 'uint64': 'unsigned(63 downto 0)', - 'float': 'std_logic_vector(31 downto 0)', - 'double': 'std_logic_vector(63 downto 0)', - 'ptr32': 'std_logic_vector(31 downto 0)', - 'ptr64': 'std_logic_vector(63 downto 0)', - 'blob': 'std_logic_vector(%d downto 0)' % (o.size * 8 - 1), - 'string': 'std_logic_vector(%d downto 0)' % (o.size * 8 - 1), + "bool": "std_logic", + "int8": "signed(7 downto 0)", + "uint8": "unsigned(7 downto 0)", + "int16": "signed(15 downto 0)", + "uint16": "unsigned(15 downto 0)", + "int32": "signed(31 downto 0)", + "uint32": "unsigned(31 downto 0)", + "int64": "signed(63 downto 0)", + "uint64": "unsigned(63 downto 0)", + "float": "std_logic_vector(31 downto 0)", + "double": "std_logic_vector(63 downto 0)", + "ptr32": "std_logic_vector(31 downto 0)", + "ptr64": "std_logic_vector(63 downto 0)", + "blob": "std_logic_vector(%d downto 0)" % (o.size * 8 - 1), + "string": "std_logic_vector(%d downto 0)" % (o.size * 8 - 1), }[o.type] + def vhdlinit(o): - b = lambda: 'x"' + (('00' * o.size) + reduce(lambda a, b: '%02x' % b + a, o.encode(o.init, False)))[-o.size * 2:] + '"' + b = ( + lambda: 'x"' + + (("00" * o.size) + reduce(lambda a, b: "%02x" % b + a, o.encode(o.init, False)))[ + -o.size * 2 : + ] + + '"' + ) b = None if o.init != None: - b = ('00' * o.size) - b += reduce(lambda a, b: a + ('%02x' % b), o.encode(o.init, False), '') + b = "00" * o.size + b += reduce(lambda a, b: a + ("%02x" % b), o.encode(o.init, False), "") b = f'x"{b[-o.size * 2:]}"' return { - 'bool': "(7 downto 0 => '0')" if o.init == None or b == 'x"00"' else "(7 downto 1 => '0', 0 => '1')", - 'int8': "(7 downto 0 => '0')" if o.init == None else b, - 'uint8': "(7 downto 0 => '0')" if o.init == None else b, - 'int16': "(15 downto 0 => '0')" if o.init == None else b, - 'uint16': "(15 downto 0 => '0')" if o.init == None else b, - 'int32': "(31 downto 0 => '0')" if o.init == None else b, - 'uint32': "(31 downto 0 => '0')" if o.init == None else b, - 'int64': "(63 downto 0 => '0')" if o.init == None else b, - 'uint64': "(63 downto 0 => '0')" if o.init == None else b, - 'float': "(31 downto 0 => '0')" if o.init == None else b, - 'double': "(63 downto 0 => '0')" if o.init == None else b, - 'ptr32': "(31 downto 0 => '0')" if o.init == None else b, - 'ptr64': "(63 downto 0 => '0')" if o.init == None else b, - 'blob': "(%d downto 0 => '0')" % (o.size * 8 - 1) if o.init == None else b, - 'string': "(%d downto 0 => '0')" % (o.size * 8 - 1) if o.init == None else b, + "bool": ( + "(7 downto 0 => '0')" + if o.init == None or b == 'x"00"' + else "(7 downto 1 => '0', 0 => '1')" + ), + "int8": "(7 downto 0 => '0')" if o.init == None else b, + "uint8": "(7 downto 0 => '0')" if o.init == None else b, + "int16": "(15 downto 0 => '0')" if o.init == None else b, + "uint16": "(15 downto 0 => '0')" if o.init == None else b, + "int32": "(31 downto 0 => '0')" if o.init == None else b, + "uint32": "(31 downto 0 => '0')" if o.init == None else b, + "int64": "(63 downto 0 => '0')" if o.init == None else b, + "uint64": "(63 downto 0 => '0')" if o.init == None else b, + "float": "(31 downto 0 => '0')" if o.init == None else b, + "double": "(63 downto 0 => '0')" if o.init == None else b, + "ptr32": "(31 downto 0 => '0')" if o.init == None else b, + "ptr64": "(63 downto 0 => '0')" if o.init == None else b, + "blob": "(%d downto 0 => '0')" % (o.size * 8 - 1) if o.init == None else b, + "string": "(%d downto 0 => '0')" % (o.size * 8 - 1) if o.init == None else b, }[o.type] + def vhdlstr(s): - return '(' + ', '.join(map(lambda c: 'x"%02x"' % c, s.encode())) + ')' + return "(" + ", ".join(map(lambda c: 'x"%02x"' % c, s.encode())) + ")" + def vhdlkey(o, store, littleEndian): - key = struct.pack(('<' if littleEndian else '>') + 'I', o.offset) + key = struct.pack(("<" if littleEndian else ">") + "I", o.offset) if store.buffer.size >= 0x1000000: pass elif store.buffer.size >= 0x10000: @@ -167,28 +189,32 @@ def vhdlkey(o, store, littleEndian): else: key = key[3:] - return 'x"' + ''.join(map(lambda x: '%02x' % x, key)) + '"' + return 'x"' + "".join(map(lambda x: "%02x" % x, key)) + '"' + def carray(a): - s = '' + s = "" line = 0 for i in a: - s += '0x%02x, ' % i + s += "0x%02x, " % i line += 1 if line >= 16: - s += '\n' + s += "\n" line = 0 return s + def escapebs(s): - return re.sub(r'\\', r'\\\\', s) + return re.sub(r"\\", r"\\\\", s) + def rtfstring(s): - return re.sub(r'([\\{}])', r'\\\1', s) + return re.sub(r"([\\{}])", r"\\\1", s) + def csvstring(s): needEscape = False - for c in ['\r','\n','"',',']: + for c in ["\r", "\n", '"', ","]: if c in s: needEscape = True @@ -197,46 +223,51 @@ def csvstring(s): return '"' + re.sub(r'"', r'""', s) + '"' + def pystring(s): return repr(str(s)) + def pyliteral(x): if isinstance(x, float): - return f'float(\'{x}\')' + return f"float('{x}')" elif isinstance(x, bool): - return 'True' if x else 'False' + return "True" if x else "False" elif isinstance(x, int): - return f'int({x})' + return f"int({x})" else: return repr(x) + def pyinit(o): if o.init is None: return None v = o.init type_map = { - 'bool': lambda: bool(v), - 'int8': lambda: int(v), - 'uint8': lambda: int(v), - 'int16': lambda: int(v), - 'uint16': lambda: int(v), - 'int32': lambda: int(v), - 'uint32': lambda: int(v), - 'int64': lambda: int(v), - 'uint64': lambda: int(v), - 'float': lambda: float(v), - 'double': lambda: float(v), - 'ptr32': lambda: int(v), - 'ptr64': lambda: int(v), - 'blob': lambda: None, - 'string': lambda: str(v), + "bool": lambda: bool(v), + "int8": lambda: int(v), + "uint8": lambda: int(v), + "int16": lambda: int(v), + "uint16": lambda: int(v), + "int32": lambda: int(v), + "uint32": lambda: int(v), + "int64": lambda: int(v), + "uint64": lambda: int(v), + "float": lambda: float(v), + "double": lambda: float(v), + "ptr32": lambda: int(v), + "ptr64": lambda: int(v), + "blob": lambda: None, + "string": lambda: str(v), } return type_map[o.type]() + def jsonstring(s): return json.dumps(s) + def yamlstring(s): """Return a YAML-safe scalar representation. @@ -248,58 +279,67 @@ def yamlstring(s): - For other types, fall back to jsonstring to ensure safe serialization. """ if s is None: - return 'null' + return "null" if isinstance(s, str): # Escape backslashes and quotes inside the string - esc = s.replace('\\', '\\\\').replace('"', '\\"') + esc = s.replace("\\", "\\\\").replace('"', '\\"') return '"' + esc + '"' elif isinstance(s, bool): - return 'true' if s else 'false' + return "true" if s else "false" elif isinstance(s, (int, float)): # Special float cases per YAML core schema conventions if isinstance(s, float): if math.isnan(s): - return '.nan' + return ".nan" if math.isinf(s): - return '.inf' if s > 0 else '-.inf' + return ".inf" if s > 0 else "-.inf" return str(s) # Fallback: JSON is valid YAML for these cases return jsonstring(s) + def tab_indent(s, num): - return ('\t' * num).join(s.splitlines(True)) + return ("\t" * num).join(s.splitlines(True)) + def model_name(model_file): return os.path.splitext(os.path.split(model_file)[1])[0] + def model_cname(model_file): s = model_name(model_file) - s = re.sub(r'[^A-Za-z0-9]+', '_', s) - s = re.sub(r'_([a-z])', lambda m: m.group(1).upper(), s) - s = re.sub(r'^_|_$', '', s) + s = re.sub(r"[^A-Za-z0-9]+", "_", s) + s = re.sub(r"_([a-z])", lambda m: m.group(1).upper(), s) + s = re.sub(r"^_|_$", "", s) s = s[0].upper() + s[1:] return s + def platform_win32(): - return sys.platform == 'win32' + return sys.platform == "win32" -def spdx(license='MPL-2.0', prefix=''): + +def spdx(license="MPL-2.0", prefix=""): # REUSE-IgnoreStart - return \ - f'{prefix}SPDX-FileCopyrightText: 2020-2025 Jochem Rutgers\n' + \ - f'{prefix}\n' + \ - f'{prefix}SPDX-License-Identifier: {license}\n' + return ( + f"{prefix}SPDX-FileCopyrightText: 2020-2025 Jochem Rutgers\n" + + f"{prefix}\n" + + f"{prefix}SPDX-License-Identifier: {license}\n" + ) # REUSE-IgnoreEnd + def sha1(file): - with open(file, 'rb') as f: + with open(file, "rb") as f: return hashlib.sha1(f.read()).hexdigest() + model_names = set() model_cnames = set() models = set() + ## # @brief Load a model from a file # @param filename The name of the file to load @@ -307,23 +347,29 @@ def sha1(file): # def load_model(filename, littleEndian=True, debug=False): meta = metamodel_from_file( - os.path.join(generator_dir, 'dsl', 'grammar.tx'), - classes=[types.Store, - types.Variable, types.Function, types.Scope, - types.BlobType, types.StringType, types.Immediate + os.path.join(generator_dir, "dsl", "grammar.tx"), + classes=[ + types.Store, + types.Variable, + types.Function, + types.Scope, + types.BlobType, + types.StringType, + types.Immediate, ], - debug=debug) + debug=debug, + ) mname = model_name(filename) if mname in model_names: - logger.critical(f'Model {mname} already exists') + logger.critical(f"Model {mname} already exists") sys.exit(1) model_names.add(mname) mcname = model_cname(filename) if mcname in model_cnames: - logger.critical(f'Model {mname}\'s class name {mcname} is ambiguous') + logger.critical(f"Model {mname}'s class name {mcname} is ambiguous") sys.exit(1) model_cnames.add(mcname) @@ -337,116 +383,123 @@ def load_model(filename, littleEndian=True, debug=False): models.add(model) return model + def generate_store(model_file, output_dir, littleEndian=True): logger.info(f"generating store {model_name(model_file)}") model = load_model(model_file, littleEndian) mname = model_name(model_file) - with open(model_file, 'rb') as f: - model.hash = hashlib.sha1(f.read().replace(b'\r\n', b'\n')).hexdigest() + with open(model_file, "rb") as f: + model.hash = hashlib.sha1(f.read().replace(b"\r\n", b"\n")).hexdigest() # create the output dir if it does not exist yet if not os.path.exists(output_dir): os.mkdir(output_dir) - if not os.path.exists(os.path.join(output_dir, 'include')): - os.mkdir(os.path.join(output_dir, 'include')) - if not os.path.exists(os.path.join(output_dir, 'src')): - os.mkdir(os.path.join(output_dir, 'src')) - if not os.path.exists(os.path.join(output_dir, 'doc')): - os.mkdir(os.path.join(output_dir, 'doc')) - if not os.path.exists(os.path.join(output_dir, 'rtl')): - os.mkdir(os.path.join(output_dir, 'rtl')) + if not os.path.exists(os.path.join(output_dir, "include")): + os.mkdir(os.path.join(output_dir, "include")) + if not os.path.exists(os.path.join(output_dir, "src")): + os.mkdir(os.path.join(output_dir, "src")) + if not os.path.exists(os.path.join(output_dir, "doc")): + os.mkdir(os.path.join(output_dir, "doc")) + if not os.path.exists(os.path.join(output_dir, "rtl")): + os.mkdir(os.path.join(output_dir, "rtl")) # now generate the code jenv = jinja2.Environment( - loader = jinja2.FileSystemLoader([ - os.path.join(libstored_dir, 'include', 'libstored'), - os.path.join(libstored_dir, 'src'), - os.path.join(libstored_dir, 'doc'), - os.path.join(libstored_dir, 'fpga', 'rtl'), - ]), - trim_blocks = True, - lstrip_blocks = True) - - jenv.globals['store'] = model - jenv.globals['win32'] = platform_win32() - jenv.filters['ctype'] = ctype - jenv.filters['stype'] = stype - jenv.filters['vhdltype'] = vhdltype - jenv.filters['vhdlinit'] = vhdlinit - jenv.filters['vhdlstr'] = vhdlstr - jenv.filters['vhdlkey'] = vhdlkey - jenv.filters['cname'] = types.cname - jenv.filters['carray'] = carray - jenv.filters['vhdlname'] = types.vhdlname - jenv.filters['len'] = len - jenv.filters['hasfunction'] = has_function - jenv.filters['rtfstring'] = rtfstring - jenv.filters['csvstring'] = csvstring - jenv.filters['pystring'] = pystring - jenv.filters['pyliteral'] = pyliteral - jenv.filters['pyinit'] = pyinit - jenv.filters['jsonstring'] = jsonstring - jenv.filters['yamlstring'] = yamlstring - jenv.filters['tab_indent'] = tab_indent - jenv.tests['variable'] = is_variable - jenv.tests['function'] = is_function - jenv.tests['blob'] = is_blob - jenv.tests['string'] = is_string - jenv.tests['pointer'] = is_pointer - - store_h_tmpl = jenv.get_template('store.h.tmpl') - store_cpp_tmpl = jenv.get_template('store.cpp.tmpl') - store_rtf_tmpl = jenv.get_template('store.rtf.tmpl') - store_csv_tmpl = jenv.get_template('store.csv.tmpl') - store_py_tmpl = jenv.get_template('store.py.tmpl') - store_vhd_tmpl = jenv.get_template('store.vhd.tmpl') - store_pkg_vhd_tmpl = jenv.get_template('store_pkg.vhd.tmpl') - store_yml_tmpl = jenv.get_template('store.yml.tmpl') - - with open(os.path.join(output_dir, 'include', mname + '.h'), 'w') as f: + loader=jinja2.FileSystemLoader( + [ + os.path.join(libstored_dir, "include", "libstored"), + os.path.join(libstored_dir, "src"), + os.path.join(libstored_dir, "doc"), + os.path.join(libstored_dir, "fpga", "rtl"), + ] + ), + trim_blocks=True, + lstrip_blocks=True, + ) + + jenv.globals["store"] = model + jenv.globals["win32"] = platform_win32() + jenv.filters["ctype"] = ctype + jenv.filters["stype"] = stype + jenv.filters["vhdltype"] = vhdltype + jenv.filters["vhdlinit"] = vhdlinit + jenv.filters["vhdlstr"] = vhdlstr + jenv.filters["vhdlkey"] = vhdlkey + jenv.filters["cname"] = types.cname + jenv.filters["carray"] = carray + jenv.filters["vhdlname"] = types.vhdlname + jenv.filters["len"] = len + jenv.filters["hasfunction"] = has_function + jenv.filters["rtfstring"] = rtfstring + jenv.filters["csvstring"] = csvstring + jenv.filters["pystring"] = pystring + jenv.filters["pyliteral"] = pyliteral + jenv.filters["pyinit"] = pyinit + jenv.filters["jsonstring"] = jsonstring + jenv.filters["yamlstring"] = yamlstring + jenv.filters["tab_indent"] = tab_indent + jenv.tests["variable"] = is_variable + jenv.tests["function"] = is_function + jenv.tests["blob"] = is_blob + jenv.tests["string"] = is_string + jenv.tests["pointer"] = is_pointer + + store_h_tmpl = jenv.get_template("store.h.tmpl") + store_cpp_tmpl = jenv.get_template("store.cpp.tmpl") + store_rtf_tmpl = jenv.get_template("store.rtf.tmpl") + store_csv_tmpl = jenv.get_template("store.csv.tmpl") + store_py_tmpl = jenv.get_template("store.py.tmpl") + store_vhd_tmpl = jenv.get_template("store.vhd.tmpl") + store_pkg_vhd_tmpl = jenv.get_template("store_pkg.vhd.tmpl") + store_yml_tmpl = jenv.get_template("store.yml.tmpl") + + with open(os.path.join(output_dir, "include", mname + ".h"), "w") as f: f.write(store_h_tmpl.render()) - with open(os.path.join(output_dir, 'src', mname + '.cpp'), 'w') as f: + with open(os.path.join(output_dir, "src", mname + ".cpp"), "w") as f: f.write(store_cpp_tmpl.render()) - with open(os.path.join(output_dir, 'doc', mname + '.rtf'), 'w') as f: + with open(os.path.join(output_dir, "doc", mname + ".rtf"), "w") as f: f.write(store_rtf_tmpl.render()) - with open(os.path.join(output_dir, 'doc', mname + '.rtf.license'), 'w') as f: - f.write(spdx('CC0-1.0')) + with open(os.path.join(output_dir, "doc", mname + ".rtf.license"), "w") as f: + f.write(spdx("CC0-1.0")) - with open(os.path.join(output_dir, 'doc', mname + '.csv'), 'w') as f: + with open(os.path.join(output_dir, "doc", mname + ".csv"), "w") as f: f.write(store_csv_tmpl.render()) - with open(os.path.join(output_dir, 'doc', mname + '.csv.license'), 'w') as f: - f.write(spdx('CC0-1.0')) + with open(os.path.join(output_dir, "doc", mname + ".csv.license"), "w") as f: + f.write(spdx("CC0-1.0")) - with open(os.path.join(output_dir, 'doc', mname + 'Meta.py'), 'w') as f: + with open(os.path.join(output_dir, "doc", mname + "Meta.py"), "w") as f: f.write(store_py_tmpl.render()) - with open(os.path.join(output_dir, 'doc', mname + '.yml'), 'w') as f: + with open(os.path.join(output_dir, "doc", mname + ".yml"), "w") as f: f.write(store_yml_tmpl.render()) - with open(os.path.join(output_dir, 'rtl', mname + '.vhd'), 'w') as f: + with open(os.path.join(output_dir, "rtl", mname + ".vhd"), "w") as f: f.write(store_vhd_tmpl.render()) - with open(os.path.join(output_dir, 'rtl', mname + '_pkg.vhd'), 'w') as f: + with open(os.path.join(output_dir, "rtl", mname + "_pkg.vhd"), "w") as f: f.write(store_pkg_vhd_tmpl.render()) - licenses_dir = os.path.join(output_dir, 'LICENSES') + licenses_dir = os.path.join(output_dir, "LICENSES") os.makedirs(licenses_dir, exist_ok=True) - shutil.copy(os.path.join(libstored_dir, 'LICENSES', 'CC0-1.0.txt'), licenses_dir) - shutil.copy(os.path.join(libstored_dir, 'LICENSES', 'MPL-2.0.txt'), licenses_dir) + shutil.copy(os.path.join(libstored_dir, "LICENSES", "CC0-1.0.txt"), licenses_dir) + shutil.copy(os.path.join(libstored_dir, "LICENSES", "MPL-2.0.txt"), licenses_dir) + def generate_cmake(libprefix, model_files, output_dir): logger.info("generating CMakeLists.txt") model_map = list(map(model_name, model_files)) try: - libstored_reldir = '${CMAKE_CURRENT_SOURCE_DIR}/' + os.path.relpath(libstored_dir, output_dir) + libstored_reldir = "${CMAKE_CURRENT_SOURCE_DIR}/" + os.path.relpath( + libstored_dir, output_dir + ) except: libstored_reldir = libstored_dir @@ -455,67 +508,87 @@ def generate_cmake(libprefix, model_files, output_dir): os.mkdir(output_dir) jenv = jinja2.Environment( - loader = jinja2.FileSystemLoader([ + loader=jinja2.FileSystemLoader( + [ os.path.join(libstored_dir), - os.path.join(libstored_dir, 'doc'), - os.path.join(libstored_dir, 'fpga', 'vivado'), - ]), - trim_blocks = True, - lstrip_blocks = True) - - jenv.filters['header'] = lambda m: f'include/{m}.h' - jenv.filters['src'] = lambda m: f'src/{m}.cpp' - jenv.filters['escapebs'] = escapebs - jenv.globals['sha1'] = lambda f: sha1(os.path.join(output_dir, f)) - - cmake_tmpl = jenv.get_template('CMakeLists.txt.tmpl') - vivado_tmpl = jenv.get_template('vivado.tcl.tmpl') - spdx_tmpl = jenv.get_template('libstored-src.spdx.tmpl') - sha1sum_tmpl = jenv.get_template('SHA1SUM.tmpl') - - with open(os.path.join(output_dir, 'CMakeLists.txt'), 'w') as f: - f.write(cmake_tmpl.render( - libstored_dir=libstored_reldir, - models=model_map, - libprefix=libprefix, - python_executable=sys.executable, - )) - - with open(os.path.join(output_dir, 'rtl', 'vivado.tcl'), 'w') as f: - f.write(vivado_tmpl.render( - libstored_dir=libstored_dir, - models=model_map, - libprefix=libprefix, - )) - - with open(os.path.join(output_dir, 'doc', 'SHA1SUM'), 'w') as f: - f.write(sha1sum_tmpl.render( - libstored_dir=libstored_dir, - models=models, - libprefix=libprefix, - )) - - with open(os.path.join(output_dir, 'doc', 'SHA1SUM.license'), 'w') as f: - f.write(spdx('CC0-1.0')) - - with open(os.path.join(output_dir, 'doc', 'libstored-src.spdx'), 'w') as f: - f.write(spdx_tmpl.render( - libstored_dir=libstored_dir, - models=models, - libprefix=libprefix, - libstored_version=libstored_version, - uuid=str(uuid.uuid4()), - timestamp=datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ') - )) + os.path.join(libstored_dir, "doc"), + os.path.join(libstored_dir, "fpga", "vivado"), + ] + ), + trim_blocks=True, + lstrip_blocks=True, + ) + + jenv.filters["header"] = lambda m: f"include/{m}.h" + jenv.filters["src"] = lambda m: f"src/{m}.cpp" + jenv.filters["escapebs"] = escapebs + jenv.globals["sha1"] = lambda f: sha1(os.path.join(output_dir, f)) + + cmake_tmpl = jenv.get_template("CMakeLists.txt.tmpl") + vivado_tmpl = jenv.get_template("vivado.tcl.tmpl") + spdx_tmpl = jenv.get_template("libstored-src.spdx.tmpl") + sha1sum_tmpl = jenv.get_template("SHA1SUM.tmpl") + + with open(os.path.join(output_dir, "CMakeLists.txt"), "w") as f: + f.write( + cmake_tmpl.render( + libstored_dir=libstored_reldir, + models=model_map, + libprefix=libprefix, + python_executable=sys.executable, + ) + ) + + with open(os.path.join(output_dir, "rtl", "vivado.tcl"), "w") as f: + f.write( + vivado_tmpl.render( + libstored_dir=libstored_dir, + models=model_map, + libprefix=libprefix, + ) + ) + + with open(os.path.join(output_dir, "doc", "SHA1SUM"), "w") as f: + f.write( + sha1sum_tmpl.render( + libstored_dir=libstored_dir, + models=models, + libprefix=libprefix, + ) + ) + + with open(os.path.join(output_dir, "doc", "SHA1SUM.license"), "w") as f: + f.write(spdx("CC0-1.0")) + + with open(os.path.join(output_dir, "doc", "libstored-src.spdx"), "w") as f: + f.write( + spdx_tmpl.render( + libstored_dir=libstored_dir, + models=models, + libprefix=libprefix, + libstored_version=libstored_version, + uuid=str(uuid.uuid4()), + timestamp=datetime.datetime.now(datetime.timezone.utc).strftime( + "%Y-%m-%dT%H:%M:%SZ" + ), + ) + ) + def main(): - parser = argparse.ArgumentParser(description='Store generator', prog=__package__) - parser.add_argument('-V', '--version',action='version', version=__version__) - parser.add_argument('-p', '--prefix', type=str, help='libstored prefix for cmake library target') - parser.add_argument('-b', '--big', help='generate for big-endian device (default=little)', action='store_true') - parser.add_argument('-v', '--verbose', dest='verbose', default=0, help='Enable verbose output', action='count') - parser.add_argument('store_file', type=str, nargs='+', help='store description to parse') - parser.add_argument('output_dir', type=str, help='output directory for generated files') + parser = argparse.ArgumentParser(description="Store generator", prog=__package__) + parser.add_argument("-V", "--version", action="version", version=__version__) + parser.add_argument( + "-p", "--prefix", type=str, help="libstored prefix for cmake library target" + ) + parser.add_argument( + "-b", "--big", help="generate for big-endian device (default=little)", action="store_true" + ) + parser.add_argument( + "-v", "--verbose", dest="verbose", default=0, help="Enable verbose output", action="count" + ) + parser.add_argument("store_file", type=str, nargs="+", help="store description to parse") + parser.add_argument("output_dir", type=str, help="output directory for generated files") args = parser.parse_args() @@ -527,12 +600,13 @@ def main(): logging.basicConfig(level=logging.DEBUG) global logger - logger = logging.getLogger('libstored') + logger = logging.getLogger("libstored") for f in args.store_file: generate_store(f, args.output_dir, not args.big) generate_cmake(args.prefix, args.store_file, args.output_dir) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/python/libstored/generator/dsl/types.py b/python/libstored/generator/dsl/types.py index ff64c0b8..b9062213 100644 --- a/python/libstored/generator/dsl/types.py +++ b/python/libstored/generator/dsl/types.py @@ -9,49 +9,170 @@ cnames = {} + def is_reserved_name(s): return s in [ # C++ - 'alignas', 'alignof', 'and', 'and_eq', 'asm', 'atomic_cancel', - 'atomic_commit', 'atomic_noexcept', 'auto', 'bitand', 'bitor', 'bool', - 'break', 'case', 'catch', 'char', 'char8_t', 'char16_t', 'char32_t', - 'class', 'compl', 'concept', 'const', 'consteval', 'constexpr', - 'constinit', 'const_cast', 'continue', 'co_await', 'co_return', - 'co_yield', 'decltype', 'default', 'delete', 'do', 'double', - 'dynamic_cast', 'else', 'enum', 'explicit', 'export', 'extern', - 'false', 'float', 'for', 'friend', 'goto', 'if', 'inline', 'int', - 'long', 'mutable', 'namespace', 'new', 'noexcept', 'not', 'not_eq', - 'nullptr', 'operator', 'or', 'or_eq', 'private', 'protected', 'public', - 'reflexpr', 'register', 'reinterpret_cast', 'requires', 'return', - 'short', 'signed', 'sizeof', 'static', 'static_assert', 'static_cast', - 'struct', 'switch', 'synchronized', 'template', 'this', 'thread_local', - 'throw', 'true', 'try', 'typedef', 'typeid', 'typename', 'union', - 'unsigned', 'using', 'virtual', 'void', 'volatile', 'wchar_t', 'while', - 'xor', 'xor_eq', + "alignas", + "alignof", + "and", + "and_eq", + "asm", + "atomic_cancel", + "atomic_commit", + "atomic_noexcept", + "auto", + "bitand", + "bitor", + "bool", + "break", + "case", + "catch", + "char", + "char8_t", + "char16_t", + "char32_t", + "class", + "compl", + "concept", + "const", + "consteval", + "constexpr", + "constinit", + "const_cast", + "continue", + "co_await", + "co_return", + "co_yield", + "decltype", + "default", + "delete", + "do", + "double", + "dynamic_cast", + "else", + "enum", + "explicit", + "export", + "extern", + "false", + "float", + "for", + "friend", + "goto", + "if", + "inline", + "int", + "long", + "mutable", + "namespace", + "new", + "noexcept", + "not", + "not_eq", + "nullptr", + "operator", + "or", + "or_eq", + "private", + "protected", + "public", + "reflexpr", + "register", + "reinterpret_cast", + "requires", + "return", + "short", + "signed", + "sizeof", + "static", + "static_assert", + "static_cast", + "struct", + "switch", + "synchronized", + "template", + "this", + "thread_local", + "throw", + "true", + "try", + "typedef", + "typeid", + "typename", + "union", + "unsigned", + "using", + "virtual", + "void", + "volatile", + "wchar_t", + "while", + "xor", + "xor_eq", # C++ non-keywords, but still tricky to use - 'override', 'final', + "override", + "final", # C - 'auto', 'break', 'case', 'char', 'const', 'continue', 'default', 'do', - 'double', 'else', 'enum', 'extern', 'float', 'for', 'goto', 'if', - 'inline', 'int', 'long', 'register', 'restrict', 'return', 'short', - 'signed', 'sizeof', 'static', 'struct', 'switch', 'typedef', 'union', - 'unsigned', 'void', 'volatile', 'while', '_Alignas', '_Alignof', - '_Atomic', '_Bool', '_Complex', '_Generic', '_Imaginary', '_Noreturn', - '_Static_assert', '_Thread_local', + "auto", + "break", + "case", + "char", + "const", + "continue", + "default", + "do", + "double", + "else", + "enum", + "extern", + "float", + "for", + "goto", + "if", + "inline", + "int", + "long", + "register", + "restrict", + "return", + "short", + "signed", + "sizeof", + "static", + "struct", + "switch", + "typedef", + "union", + "unsigned", + "void", + "volatile", + "while", + "_Alignas", + "_Alignof", + "_Atomic", + "_Bool", + "_Complex", + "_Generic", + "_Imaginary", + "_Noreturn", + "_Static_assert", + "_Thread_local", ] + def cname(s): if s in cnames: return cnames[s] c = s - c = re.sub(r'[^A-Za-z0-9/]+', '_', c) - c = re.sub(r'_*/+', '__', c) - c = re.sub(r'^__', '', c) - c = re.sub(r'^[^A-Za-z]_*', '_', c) - c = re.sub(r'_+$', '', c) + c = re.sub(r"[^A-Za-z0-9/]+", "_", c) + c = re.sub(r"_*/+", "__", c) + c = re.sub(r"^__", "", c) + c = re.sub(r"^[^A-Za-z]_*", "_", c) + c = re.sub(r"_+$", "", c) - if s == '': - c = 'obj' + if s == "": + c = "obj" if is_reserved_name(c): c += "_obj" @@ -59,94 +180,101 @@ def cname(s): u = c i = 2 while u in cnames.values(): - u = c + f'_{i}' + u = c + f"_{i}" i += 1 cnames[s] = u return u + vhdlnames = {} + def vhdlname(s): s = str(s) if s in vhdlnames: return vhdlnames[s] c = s - c = re.sub(r'\\', '\\\\', c) + c = re.sub(r"\\", "\\\\", c) - if s == '': - c = 'obj' + if s == "": + c = "obj" u = c i = 2 while u in vhdlnames.values(): - u = c + f' {i}' + u = c + f" {i}" i += 1 vhdlnames[s] = u return u + def csize(o): return { - 'bool': 1, - 'int8': 1, - 'uint8': 1, - 'int16': 2, - 'uint16': 2, - 'int32': 4, - 'uint32': 4, - 'int64': 8, - 'uint64': 8, - 'float': 4, - 'double': 8, - 'ptr32': 4, - 'ptr64': 8, - 'blob': 0, - 'string': 0, + "bool": 1, + "int8": 1, + "uint8": 1, + "int16": 2, + "uint16": 2, + "int32": 4, + "uint32": 4, + "int64": 8, + "uint64": 8, + "float": 4, + "double": 8, + "ptr32": 4, + "ptr64": 8, + "blob": 0, + "string": 0, }[o] + def typeflags(s, func=False): return { - 'bool': 0x20, - 'int8': 0x38, - 'uint8': 0x30, - 'int16': 0x39, - 'uint16': 0x31, - 'int32': 0x3b, - 'uint32': 0x33, - 'int64': 0x3f, - 'uint64': 0x37, - 'float': 0x2b, - 'double': 0x2f, - 'ptr32': 0x23, - 'ptr64': 0x27, - 'blob': 0x01, - 'string': 0x02, + "bool": 0x20, + "int8": 0x38, + "uint8": 0x30, + "int16": 0x39, + "uint16": 0x31, + "int32": 0x3B, + "uint32": 0x33, + "int64": 0x3F, + "uint64": 0x37, + "float": 0x2B, + "double": 0x2F, + "ptr32": 0x23, + "ptr64": 0x27, + "blob": 0x01, + "string": 0x02, }[s] + (0x40 if func else 0) + def object_name(s): - s = re.sub(r'\s+', ' ', s) + s = re.sub(r"\s+", " ", s) return s + class Directory(object): def __init__(self): self.data = [] self.longdata = [] def merge(self, root, h): -# print(f'merge {root} {h}') - for k,v in h.items(): + # print(f'merge {root} {h}') + for k, v in h.items(): if k in root: self.merge(root[k], v) else: root[k] = v -# print(f'merged into {root}') + + # print(f'merged into {root}') def convertHierarchy(self, xs): # xs is a list of pairs of ([name chunks], object) res = {} for x in xs: -# print(x) + # print(x) if len(x[0]) == 1: # No sub scope res[x[0][0]] = x[1] @@ -162,12 +290,12 @@ def hierarchical(self, objects): objects = sorted(map(lambda o: (o.name, o), objects), key=lambda x: x[0]) # Split in hierarchy. - objects = map(lambda x: (list(x[0] + '\x00'), x[1]), objects) + objects = map(lambda x: (list(x[0] + "\x00"), x[1]), objects) # Make hierarchy. objects = self.convertHierarchy(objects) -# print(objects) + # print(objects) return objects def encodeInt(self, i): @@ -176,11 +304,11 @@ def encodeInt(self, i): res = [] while i >= 0x80: - res.insert(0, i & 0x7f) + res.insert(0, i & 0x7F) i = i >> 7 res.insert(0, i) - for j in range(0,len(res)-1): + for j in range(0, len(res) - 1): res[j] += 0x80 return res @@ -196,40 +324,40 @@ def generateDict(self, h): if isinstance(h, Variable): return self.encodeType(h) + self.encodeInt(h.offset) elif isinstance(h, Function): -# print(f'function {h.name} {h.f}') + # print(f'function {h.name} {h.f}') return self.encodeType(h) + self.encodeInt(h.f) else: assert isinstance(h, dict) names = list(h.keys()) names.sort() -# print(names) + # print(names) if names == []: # end return [0] - elif names == ['\x00']: - return self.generateDict(h['\x00']) - elif names == ['/']: - return [ord('/')] + self.generateDict(h['/']) + elif names == ["\x00"]: + return self.generateDict(h["\x00"]) + elif names == ["/"]: + return [ord("/")] + self.generateDict(h["/"]) elif len(names) == 1 and isinstance(names[0], int): # skip skip = names[0] res = [] while skip > 0: - s = min(skip, 0x1f) + s = min(skip, 0x1F) res += [s] skip -= s return res + self.generateDict(h[names[0]]) else: # Choose pivot to compare to pivot = int(len(names) / 2) - if names[pivot] == '/': + if names[pivot] == "/": if pivot > 0: pivot -= 1 else: pivot += 1 - assert(names[pivot] != '/') + assert names[pivot] != "/" expr = self.generateDict(h[names[pivot]]) # Elements less than pivot @@ -242,7 +370,7 @@ def generateDict(self, h): # Elements greater than pivot g = {} - for n in names[pivot+1:]: + for n in names[pivot + 1 :]: g[n] = h[n] expr_g = self.generateDict(g) if expr_g == [0]: @@ -264,10 +392,10 @@ def stripUnambig(self, h): if not isinstance(h, dict): return if len(h) == 1: - if '/' in h: - self.stripUnambig(h['/']) + if "/" in h: + self.stripUnambig(h["/"]) return - elif '\x00' in h: + elif "\x00" in h: return # Drop at end of name: @@ -286,7 +414,7 @@ def stripUnambig(self, h): # Find the chain: h[k] -> v[/] -> vv={char: {\0: object}} # And replace by: k[k] -> vv -# print(f'strip {h}') + # print(f'strip {h}') for k in list(h.keys()): v = h[k] if not isinstance(v, dict): @@ -294,40 +422,49 @@ def stripUnambig(self, h): self.stripUnambig(v) if len(v) == 1: vk = list(v.keys())[0] - if vk == '\x00': + if vk == "\x00": continue vv = v[vk] - if vk == '/': + if vk == "/": if len(vv) == 1: # Scope with single leaf? vkk = list(vv.keys())[0] - if isinstance(vkk, str) and isinstance(vv[vkk], dict) and len(vv[vkk]) == 1 and '\0' in vv[vkk]: -# print(f'drop path of {v}') + if ( + isinstance(vkk, str) + and isinstance(vv[vkk], dict) + and len(vv[vkk]) == 1 + and "\0" in vv[vkk] + ): + # print(f'drop path of {v}') h[k] = vv[vkk] continue if len(vv) == 1: vvk = list(vv.keys())[0] if isinstance(vvk, int): -# print(f'skip more {vk}') + # print(f'skip more {vk}') h[k] = vv vv[vvk + 1] = vv[vvk] del vv[vvk] - elif vvk == '/' or vvk == '\x00': -# print(f'drop {vk}') + elif vvk == "/" or vvk == "\x00": + # print(f'drop {vk}') h[k] = vv else: -# print(f'skip {vk}') + # print(f'skip {vk}') v[1] = v[vk] del v[vk] -# print(f'stripped {h}') + + # print(f'stripped {h}') def generate(self, objects): h = self.hierarchical(objects) - self.longdata = [ord('/')] + self.generateDict(h) + [0] + self.longdata = [ord("/")] + self.generateDict(h) + [0] self.stripUnambig(h) - self.data = [ord('/')] + self.generateDict(h) + [0] + self.data = [ord("/")] + self.generateDict(h) + [0] + + # print(self.data) + class Buffer(object): def __init__(self): self.size = 0 @@ -351,7 +488,7 @@ def align(self, size, force=None): else: return size + a - (size % a) - def generate(self, initvars, defaultvars, littleEndian = True): + def generate(self, initvars, defaultvars, littleEndian=True): self.init = [] for v in initvars: size = v.buffersize() @@ -371,6 +508,7 @@ def generate(self, initvars, defaultvars, littleEndian = True): if self.size == 0: self.size = 1 + class ArrayLookup(object): def __init__(self, objects): # All objects should have the same name, but only differ in their array index. @@ -387,7 +525,7 @@ def __init__(self, objects): self._placeholders = len(objects[0].name_index) if len(objects) > 0: - self.name = '' + self.name = "" for i in range(0, len(objects[0].name_index)): ni = objects[0].name_index[i] if ni[1] == None: @@ -400,9 +538,9 @@ def __init__(self, objects): self.name = None def placeholders(self): - return 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'[:self._placeholders] + return "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"[: self._placeholders] - def _tree(self, objects, index = 0): + def _tree(self, objects, index=0): if len(objects) == 0: return None if len(objects) == 1: @@ -424,24 +562,25 @@ def _tree(self, objects, index = 0): return res def c_impl(self): - return self._c_impl(self.tree, 0, self.placeholders(), '\t\t\t') + return self._c_impl(self.tree, 0, self.placeholders(), "\t\t\t") def _c_impl(self, tree, index, placeholders, indent): if tree == None: - return f'{indent}return Variant();\n' + return f"{indent}return Variant();\n" if not isinstance(tree, dict): - return f'{indent}return this->{tree.cname}.variant();\n' + return f"{indent}return this->{tree.cname}.variant();\n" - res = f'{indent}switch({placeholders[index]}) {{\n' + res = f"{indent}switch({placeholders[index]}) {{\n" for i in tree.keys(): - res += f'{indent}case {i}:\n' - res += self._c_impl(tree[i], index + 1, placeholders, indent + '\t') - res += f'{indent}default:\n{indent}\treturn Variant();\n' - res += f'{indent}}}\n' + res += f"{indent}case {i}:\n" + res += self._c_impl(tree[i], index + 1, placeholders, indent + "\t") + res += f"{indent}default:\n{indent}\treturn Variant();\n" + res += f"{indent}}}\n" return res def c_decl(self): - return self.cname + '(' + ', '.join([f'int {x}' for x in self.placeholders()]) + ')' + return self.cname + "(" + ", ".join([f"int {x}" for x in self.placeholders()]) + ")" + class Store(object): def __init__(self, objects): @@ -464,10 +603,10 @@ def flattenScope(self, scope): for i in range(0, len(scope)): o = scope[i] if isinstance(o, Scope): -# print(f'flatten {o.name}') + # print(f'flatten {o.name}') flatten = self.flattenScope(self.expandArrays(o.objects)) for f in flatten: - f.setName(o.name + '/' + f.name) + f.setName(o.name + "/" + f.name) res += flatten else: res.append(o) @@ -487,11 +626,11 @@ def expandArrays(self, objects): os = [] for o in objects: if o.len > 1: -# print(f'expand {o.name}') + # print(f'expand {o.name}') for i in range(0, o.len): newo = self.copy(o) newo.len = 1 - newo.setName(newo.name + f'[{i}]') + newo.setName(newo.name + f"[{i}]") os.append(newo) else: os.append(o) @@ -501,7 +640,7 @@ def checkNames(self): for o in self.objects: if len(list(filter(lambda x: x.name == o.name, self.objects))) > 1: sys.exit(f'Duplicate name "{o.name}"') - if '//' in o.name: + if "//" in o.name: sys.exit(f'Empty scope name in "{o.name}"') def generateBuffer(self): @@ -552,14 +691,14 @@ def generateAxiAddresses(self): addr = 0 for o in self.objects: if isinstance(o, Variable): - assert(o.len == 1) + assert o.len == 1 if o.size <= 4: o.axi = addr addr += 4 class Object(object): - def __init__(self, parent, name, len = 0): + def __init__(self, parent, name, len=0): self.setName(name) self.len = len if isinstance(len, int) and len > 1 else 1 @@ -569,7 +708,7 @@ def setName(self, name): # Split our name to find common array indices across objects. def splitName(self): - chunks = re.split(r'\[(\d+)\]', self.name) + chunks = re.split(r"\[(\d+)\]", self.name) if len(chunks) == 1: # Nothing to split self.name_index = None @@ -577,7 +716,7 @@ def splitName(self): # chunks has now an alternating string/array index sequence. # Merge into pairs - if chunks[-1] == '': + if chunks[-1] == "": # drop empty element at the end (the name ends with an array index) chunks.pop() if len(chunks) % 2 == 1: @@ -587,7 +726,8 @@ def splitName(self): self.name_index = [] for i in range(0, len(chunks), 2): - self.name_index.append((chunks[i], chunks[i+1])) + self.name_index.append((chunks[i], chunks[i + 1])) + class Variable(Object): def __init__(self, parent, type, name): @@ -598,16 +738,18 @@ def __init__(self, parent, type, name): self.type = type.fixed.type self.size = csize(type.fixed.type) self.init = None if type.init is None or type.init.value == 0 else type.init.value - self.len = type.fixed.len if isinstance(type.fixed.len, int) and type.fixed.len > 1 else 1 + self.len = ( + type.fixed.len if isinstance(type.fixed.len, int) and type.fixed.len > 1 else 1 + ) else: self.type = type.blob.type self.size = type.blob.size self.init = None - if self.type == 'string' and type.init is not None: + if self.type == "string" and type.init is not None: self.init = bytes(str(type.init), "utf-8").decode("unicode_escape") l = len(self.init.encode()) if self.size < l: - sys.exit(f'String initializer is too long') + sys.exit(f"String initializer is too long") if l == 0: # Empty string, handle as default-initialized. self.init = None @@ -615,7 +757,7 @@ def __init__(self, parent, type, name): self.axi = None def isBlob(self): - return self.type in ['blob', 'string'] + return self.type in ["blob", "string"] def _encode_string(self, x): s = x.encode() @@ -623,28 +765,28 @@ def _encode_string(self, x): return s + bytes([0] * (self.buffersize() - len(s))) def encode(self, x, littleEndian=True): - endian = '<' if littleEndian else '>' + endian = "<" if littleEndian else ">" res = { - 'bool': lambda x: struct.pack(endian + '?', not x in [False, 'false', 0]), - 'int8': lambda x: struct.pack(endian + 'b', int(x)), - 'uint8': lambda x: struct.pack(endian + 'B', int(x)), - 'int16': lambda x: struct.pack(endian + 'h', int(x)), - 'uint16': lambda x: struct.pack(endian + 'H', int(x)), - 'int32': lambda x: struct.pack(endian + 'i', int(x)), - 'uint32': lambda x: struct.pack(endian + 'I', int(x)), - 'int64': lambda x: struct.pack(endian + 'q', int(x)), - 'uint64': lambda x: struct.pack(endian + 'Q', int(x)), - 'float': lambda x: struct.pack(endian + 'f', float(x)), - 'double': lambda x: struct.pack(endian + 'd', float(x)), - 'ptr32': lambda x: struct.pack(endian + 'L', int(x)), - 'ptr64': lambda x: struct.pack(endian + 'Q', int(x)), - 'blob': lambda x: bytearray(x), - 'string': self._encode_string + "bool": lambda x: struct.pack(endian + "?", not x in [False, "false", 0]), + "int8": lambda x: struct.pack(endian + "b", int(x)), + "uint8": lambda x: struct.pack(endian + "B", int(x)), + "int16": lambda x: struct.pack(endian + "h", int(x)), + "uint16": lambda x: struct.pack(endian + "H", int(x)), + "int32": lambda x: struct.pack(endian + "i", int(x)), + "uint32": lambda x: struct.pack(endian + "I", int(x)), + "int64": lambda x: struct.pack(endian + "q", int(x)), + "uint64": lambda x: struct.pack(endian + "Q", int(x)), + "float": lambda x: struct.pack(endian + "f", float(x)), + "double": lambda x: struct.pack(endian + "d", float(x)), + "ptr32": lambda x: struct.pack(endian + "L", int(x)), + "ptr64": lambda x: struct.pack(endian + "Q", int(x)), + "blob": lambda x: bytearray(x), + "string": self._encode_string, }[self.type](x) return res def buffersize(self): - if self.type == 'string': + if self.type == "string": return self.size + 1 else: return self.size @@ -652,6 +794,7 @@ def buffersize(self): def __str__(self): return self.name + class Function(Object): f = 1 @@ -668,7 +811,7 @@ def __init__(self, parent, type, name, len): self.size = type.blob.size def isBlob(self): - return self.type in ['blob', 'string'] + return self.type in ["blob", "string"] def bump(self): self.f = Function.f @@ -677,6 +820,7 @@ def bump(self): def __str__(self): return self.name + class Scope(Object): def __init__(self, parent, objects, name, len): super().__init__(self, name, len) @@ -686,6 +830,7 @@ def __init__(self, parent, objects, name, len): def __str__(self): return self.name + class BlobType(object): def __init__(self, parent, type, size, len): self.parent = parent @@ -693,27 +838,27 @@ def __init__(self, parent, type, size, len): self.len = len if len != None and len > 1 else 1 self.size = size if size > 0 else 0 + class StringType(BlobType): def __init__(self, parent, type, size, len): super().__init__(parent, type, size, len) + class Immediate(object): def __init__(self, parent, value): self.parent = parent if isinstance(value, str): - if value.lower() == 'true': + if value.lower() == "true": self.value = True - elif value.lower() == 'false': + elif value.lower() == "false": self.value = False - elif value.lower() == 'nan': - self.value = float('nan') - elif value.lower() in ['inf', 'infinity']: - self.value = float('inf') - elif value.lower() in ['-inf', '-infinity']: - self.value = float('-inf') + elif value.lower() == "nan": + self.value = float("nan") + elif value.lower() in ["inf", "infinity"]: + self.value = float("inf") + elif value.lower() in ["-inf", "-infinity"]: + self.value = float("-inf") else: self.value = int(value, 0) else: self.value = value - - diff --git a/python/libstored/gui/__main__.py b/python/libstored/gui/__main__.py index e0ff6b17..a46e2239 100644 --- a/python/libstored/gui/__main__.py +++ b/python/libstored/gui/__main__.py @@ -31,17 +31,18 @@ from .. import protocol as lprot - ##################################################################### # Style # + def darken_color(color, factor=0.9): - color = color.lstrip('#') + color = color.lstrip("#") lv = len(color) - rgb = tuple(int(color[i:i+lv//3], 16) for i in range(0, lv, lv//3)) + rgb = tuple(int(color[i : i + lv // 3], 16) for i in range(0, lv, lv // 3)) darker = tuple(int(c * factor) for c in rgb) - return '#{:02x}{:02x}{:02x}'.format(*darker) + return "#{:02x}{:02x}{:02x}".format(*darker) + class Style: root_width = 800 @@ -51,21 +52,21 @@ class Style: separator_padding = grid_padding * 10 - ##################################################################### # Plotter # + class PlotData: WINDOW_s = 30 def __init__(self): - self.t : list[float] = [] - self.values : list[float] = [] - self.connection : typing.Hashable | None = None - self.line : plt.Line2D | None = None # type: ignore + self.t: list[float] = [] + self.values: list[float] = [] + self.connection: typing.Hashable | None = None + self.line: plt.Line2D | None = None # type: ignore - def append(self, value : float, t=time.time()): + def append(self, value: float, t=time.time()): self.t.append(t) self.values.append(value) @@ -89,17 +90,18 @@ def cleanup(self): self.values = self.values[drop:] -plotter : Plotter | None = None +plotter: Plotter | None = None + class Plotter(laio_tk.Work): - available : bool = plt is not None - title : str | None = None + available: bool = plt is not None + title: str | None = None @classmethod def instance(cls) -> Plotter: - ''' + """ Get the singleton instance of the Plotter. - ''' + """ if not cls.available: raise RuntimeError("Matplotlib is not available") @@ -115,24 +117,26 @@ def __init__(self, *args, **kwargs): assert self.available, "Matplotlib is not available" super().__init__(*args, **kwargs) - self._data : dict[laio_zmq.Object, PlotData] = {} - self._fig : plt.Figure | None = None # type: ignore - self._ax : plt.Axes | None = None # type: ignore + self._data: dict[laio_zmq.Object, PlotData] = {} + self._fig: plt.Figure | None = None # type: ignore + self._ax: plt.Axes | None = None # type: ignore self._ready = False - self._changed : set[PlotData] = set() + self._changed: set[PlotData] = set() self._paused = False - self._timer : typing.Any = None + self._timer: typing.Any = None - self.plotting = laio_event.ValueWrapper(bool, self._plotting_get, event_name='plotting') - self.paused = laio_event.ValueWrapper(bool, self._paused_get, self._paused_set, event_name='paused') - self.closed = laio_event.Event('closed') + self.plotting = laio_event.ValueWrapper(bool, self._plotting_get, event_name="plotting") + self.paused = laio_event.ValueWrapper( + bool, self._paused_get, self._paused_set, event_name="paused" + ) + self.closed = laio_event.Event("closed") @laio_tk.Work.tk_func def start(self): - ''' + """ Start the plotter. - ''' + """ self._timer_start() @@ -151,9 +155,9 @@ def _timer_stop(self): @laio_tk.Work.tk_func def stop(self): - ''' + """ Stop the plotter. - ''' + """ self._timer_stop() @@ -163,7 +167,7 @@ def stop(self): if not self._ready: return - self.logger.debug('Closing plotter') + self.logger.debug("Closing plotter") self._ready = False assert plt is not None @@ -183,10 +187,10 @@ def __del__(self): plotter = None @laio_tk.Work.tk_func - def add(self, o : laio_zmq.Object): - ''' + def add(self, o: laio_zmq.Object): + """ Add a libstored.asyncio.Object to the plotter. - ''' + """ if o in self._data: return @@ -194,7 +198,7 @@ def add(self, o : laio_zmq.Object): if not o.is_fixed(): return - self.logger.debug(f'Plot {o.name}') + self.logger.debug(f"Plot {o.name}") if self._fig is None: assert plt @@ -222,7 +226,7 @@ def add(self, o : laio_zmq.Object): self.plotting.trigger() @laio_tk.Work.tk_func - def _update(self, o : laio_zmq.Object, value : typing.Any, t : float=time.time()): + def _update(self, o: laio_zmq.Object, value: typing.Any, t: float = time.time()): if o not in self._data: return if value is None: @@ -234,7 +238,7 @@ def _update(self, o : laio_zmq.Object, value : typing.Any, t : float=time.time() self._timer_start() def _update_plot(self): - self.logger.debug('Update plot') + self.logger.debug("Update plot") if len(self._changed) == 0 or self._paused: self._timer = None @@ -258,15 +262,15 @@ def _update_plot(self): assert self._timer is not None @laio_tk.Work.tk_func - def remove(self, o : laio_zmq.Object): - ''' + def remove(self, o: laio_zmq.Object): + """ Remove a libstored.asyncio.Object from the plotter. - ''' + """ if o not in self._data: return - self.logger.debug(f'Remove plot {o.name}') + self.logger.debug(f"Remove plot {o.name}") data = self._data[o] del self._data[o] @@ -290,9 +294,9 @@ def remove(self, o : laio_zmq.Object): @laio_tk.Work.tk_func def show(self): - ''' + """ Show the plotter window. - ''' + """ if self._fig is None or self._ax is None: return @@ -300,15 +304,15 @@ def show(self): if not self._ready: if self.title is not None: self._ax.set_title(self.title) - self._fig.canvas.manager.set_window_title(f'libstored GUI plots: {self.title}') + self._fig.canvas.manager.set_window_title(f"libstored GUI plots: {self.title}") else: - self._fig.canvas.manager.set_window_title(f'libstored GUI plots') + self._fig.canvas.manager.set_window_title(f"libstored GUI plots") self._ax.grid(True) - self._ax.set_xlabel('t (s)') + self._ax.set_xlabel("t (s)") self._update_legend() - self._fig.canvas.mpl_connect('close_event', lambda _: self.stop()) + self._fig.canvas.mpl_connect("close_event", lambda _: self.stop()) assert plt is not None plt.show(block=False) @@ -323,10 +327,10 @@ def _update_legend(self): self._ax.legend().set_draggable(True) @laio_tk.Work.tk_func - def pause(self, paused : bool=True): - ''' + def pause(self, paused: bool = True): + """ Pause or resume the plotter. - ''' + """ if self._paused == paused: return @@ -350,45 +354,70 @@ def toggle_pause(self): def _paused_get(self) -> bool: return self._paused - def _paused_set(self, x : bool): + def _paused_set(self, x: bool): self.pause(x) def _plotting_get(self): return len(self._data) > 0 - ##################################################################### # GUI elements # + class ClientConnection(laio_tk.AsyncWidget, ttk.Frame): - def __init__(self, app : laio_tk.AsyncApp, parent : ttk.Widget, client : laio_zmq.ZmqClient, clear_state : bool=False, *args, **kwargs): + def __init__( + self, + app: laio_tk.AsyncApp, + parent: ttk.Widget, + client: laio_zmq.ZmqClient, + clear_state: bool = False, + *args, + **kwargs, + ): super().__init__(app=app, master=parent, *args, **kwargs) self._client = client self._clear_state = clear_state self.columnconfigure(1, weight=1) - self._host_label = ttk.Label(self, text='Host:') - self._host_label.grid(row=0, column=0, sticky='e', padx=(0, Style.grid_padding), pady=Style.grid_padding) + self._host_label = ttk.Label(self, text="Host:") + self._host_label.grid( + row=0, column=0, sticky="e", padx=(0, Style.grid_padding), pady=Style.grid_padding + ) self._host = ltk.Entry(self, text=self.host) - self._host.grid(row=0, column=1, sticky='nswe', padx=Style.grid_padding, pady=Style.grid_padding) - - self._port_label = ttk.Label(self, text='Port:') - self._port_label.grid(row=0, column=2, sticky='e', padx=Style.grid_padding, pady=Style.grid_padding) - - self._port = ltk.Entry(self, text=str(self.port), hint=f'default: {lprot.default_port}', validation=r'^[0-9]{0,5}$') - self._port.grid(row=0, column=3, sticky='nswe', padx=Style.grid_padding, pady=Style.grid_padding) + self._host.grid( + row=0, column=1, sticky="nswe", padx=Style.grid_padding, pady=Style.grid_padding + ) + + self._port_label = ttk.Label(self, text="Port:") + self._port_label.grid( + row=0, column=2, sticky="e", padx=Style.grid_padding, pady=Style.grid_padding + ) + + self._port = ltk.Entry( + self, + text=str(self.port), + hint=f"default: {lprot.default_port}", + validation=r"^[0-9]{0,5}$", + ) + self._port.grid( + row=0, column=3, sticky="nswe", padx=Style.grid_padding, pady=Style.grid_padding + ) self._multi_var = tk.BooleanVar(value=client.multi) - self._multi = ttk.Checkbutton(self, text='Multi', variable=self._multi_var) - self._multi.grid(row=0, column=4, sticky='nswe', padx=Style.grid_padding, pady=Style.grid_padding) + self._multi = ttk.Checkbutton(self, text="Multi", variable=self._multi_var) + self._multi.grid( + row=0, column=4, sticky="nswe", padx=Style.grid_padding, pady=Style.grid_padding + ) - self._connect = ttk.Button(self, text='Connect') - self._connect.grid(row=0, column=5, sticky='nswe', padx=(Style.grid_padding, 0), pady=Style.grid_padding) - self._connect['command'] = self._on_connect_button + self._connect = ttk.Button(self, text="Connect") + self._connect.grid( + row=0, column=5, sticky="nswe", padx=(Style.grid_padding, 0), pady=Style.grid_padding + ) + self._connect["command"] = self._on_connect_button self.connect(self.client.connecting, self._on_connected) self.connect(self.client.disconnected, self._on_disconnected) @@ -413,24 +442,24 @@ def port(self) -> int: @laio_tk.AsyncApp.tk_func def _on_connected(self, *args, **kwargs): - self._connect['text'] = 'Disconnect' - self._host['state'] = 'disabled' - self._port['state'] = 'disabled' - self._multi['state'] = 'disabled' + self._connect["text"] = "Disconnect" + self._host["state"] = "disabled" + self._port["state"] = "disabled" + self._multi["state"] = "disabled" @laio_tk.AsyncApp.tk_func def _on_disconnected(self, *args, **kwargs): - self._connect['text'] = 'Connect' - self._host['state'] = 'normal' - self._port['state'] = 'normal' - self._multi['state'] = 'normal' - self.app.root.title(f'libstored GUI') + self._connect["text"] = "Connect" + self._host["state"] = "normal" + self._port["state"] = "normal" + self._multi["state"] = "normal" + self.app.root.title(f"libstored GUI") if plotter is not None: plotter.stop() def _on_connect_button(self): host = self._host.get() - if host == '': + if host == "": return try: @@ -444,64 +473,90 @@ def _on_connect_button(self): self._do_connect_button(host, port, self._multi_var.get()) @laio_tk.AsyncApp.worker_func - async def _do_connect_button(self, host : str, port : int, multi : bool): + async def _do_connect_button(self, host: str, port: int, multi: bool): if self.client.is_connected(): await self.client.disconnect() else: try: await self.client.connect(host, port, multi, default_state=self._clear_state) - self._after_connection(await self.client.identification(), await self.client.version()) + self._after_connection( + await self.client.identification(), await self.client.version() + ) except asyncio.CancelledError: raise except BaseException as e: - self.logger.warning('Connect failed: %s', e) + self.logger.warning("Connect failed: %s", e) @laio_tk.AsyncApp.tk_func - def _after_connection(self, identification : str, version : str): - self.logger.info(f'Connected to {identification} ({version})') - self.app.root.title(f'libstored GUI - {identification} ({version})') - + def _after_connection(self, identification: str, version: str): + self.logger.info(f"Connected to {identification} ({version})") + self.app.root.title(f"libstored GUI - {identification} ({version})") class ObjectRow(laio_tk.AsyncWidget, ttk.Frame): - def __init__(self, app : GUIClient, parent : ttk.Widget, obj : laio_zmq.Object, style : str | None=None, show_plot : bool=False, *args, **kwargs): + def __init__( + self, + app: GUIClient, + parent: ttk.Widget, + obj: laio_zmq.Object, + style: str | None = None, + show_plot: bool = False, + *args, + **kwargs, + ): super().__init__(app=app, master=parent, *args, **kwargs) self._obj = obj self.columnconfigure(0, weight=1) self._label = ttk.Label(self, text=obj.name) - self._label.grid(row=0, column=0, sticky='w', padx=(0, Style.grid_padding), pady=Style.grid_padding) + self._label.grid( + row=0, column=0, sticky="w", padx=(0, Style.grid_padding), pady=Style.grid_padding + ) self._show_plot = show_plot and Plotter.available and plotter and obj.is_fixed() self._plot_var = tk.BooleanVar(value=False) - self._plot = ttk.Checkbutton(self, variable=self._plot_var, command=self._on_plot_check_change) + self._plot = ttk.Checkbutton( + self, variable=self._plot_var, command=self._on_plot_check_change + ) - self._format = ttk.Combobox(self, values=obj.formats(), width=10, state='readonly') + self._format = ttk.Combobox(self, values=obj.formats(), width=10, state="readonly") self._format.set(obj.format) - self._format.bind('<>', lambda e: obj.format.set(self._format.get())) - self._format.grid(row=0, column=2, sticky='nswe', padx=Style.grid_padding, pady=Style.grid_padding) + self._format.bind("<>", lambda e: obj.format.set(self._format.get())) + self._format.grid( + row=0, column=2, sticky="nswe", padx=Style.grid_padding, pady=Style.grid_padding + ) app.connect(obj.format, self._format.set) - self._type = ttk.Label(self, text=obj.type_name, width=10, anchor='e') - self._type.grid(row=0, column=3, sticky='e', padx=Style.grid_padding, pady=Style.grid_padding) + self._type = ttk.Label(self, text=obj.type_name, width=10, anchor="e") + self._type.grid( + row=0, column=3, sticky="e", padx=Style.grid_padding, pady=Style.grid_padding + ) self._value = laio_tk.ZmqObjectEntry(app, self, obj) - self._value.grid(row=0, column=4, sticky='nswe', padx=Style.grid_padding, pady=Style.grid_padding) + self._value.grid( + row=0, column=4, sticky="nswe", padx=Style.grid_padding, pady=Style.grid_padding + ) self._poll_var = tk.BooleanVar(value=obj.polling.value is not None) app.connect(obj.polling, self._on_poll_obj_change) app.connect(obj.polling, self._on_poll_obj_change_csv) - self._poll = ttk.Checkbutton(self, variable=self._poll_var, command=self._on_poll_check_change) - self._poll.grid(row=0, column=5, sticky='nswe', padx=(Style.grid_padding, 0), pady=Style.grid_padding) + self._poll = ttk.Checkbutton( + self, variable=self._poll_var, command=self._on_poll_check_change + ) + self._poll.grid( + row=0, column=5, sticky="nswe", padx=(Style.grid_padding, 0), pady=Style.grid_padding + ) if obj.polling.value is not None: self._on_poll_obj_change_csv(obj.polling.value) if plotter: app.connect(plotter.closed, lambda: self._plot_var.set(False)) - self._refresh = ttk.Button(self, text='Refresh', command=self._value.refresh) - self._refresh.grid(row=0, column=6, sticky='nswe', padx=(Style.grid_padding, 0), pady=Style.grid_padding) + self._refresh = ttk.Button(self, text="Refresh", command=self._value.refresh) + self._refresh.grid( + row=0, column=6, sticky="nswe", padx=(Style.grid_padding, 0), pady=Style.grid_padding + ) if style is not None: self.style(style) @@ -512,25 +567,31 @@ def __init__(self, app : GUIClient, parent : ttk.Widget, obj : laio_zmq.Object, def obj(self) -> laio_zmq.Object: return self._obj - def style(self, style : str): - if style != '': - style += '.' + def style(self, style: str): + if style != "": + style += "." - self['style'] = f'{style}TFrame' - self._label['style'] = f'{style}TLabel' - self._plot['style'] = f'{style}TCheckbutton' - self._format['style'] = f'{style}TCombobox' - self._type['style'] = f'{style}TLabel' - self._value['style'] = f'{style}TEntry' - self._poll['style'] = f'{style}TCheckbutton' - self._refresh['style'] = f'{style}TButton' + self["style"] = f"{style}TFrame" + self._label["style"] = f"{style}TLabel" + self._plot["style"] = f"{style}TCheckbutton" + self._format["style"] = f"{style}TCombobox" + self._type["style"] = f"{style}TLabel" + self._value["style"] = f"{style}TEntry" + self._poll["style"] = f"{style}TCheckbutton" + self._refresh["style"] = f"{style}TButton" @laio_tk.AsyncApp.tk_func def _on_poll_obj_change(self, x): if x is not None: self._poll_var.set(True) if self._show_plot: - self._plot.grid(row=0, column=1, sticky='nswe', padx=(Style.grid_padding, 0), pady=Style.grid_padding) + self._plot.grid( + row=0, + column=1, + sticky="nswe", + padx=(Style.grid_padding, 0), + pady=Style.grid_padding, + ) else: self._poll_var.set(False) self._plot.grid_forget() @@ -572,33 +633,40 @@ def _on_disconnect(self): return self._plot_var.set(False) - self._plot['state'] = 'disabled' - self._format['state'] = 'disabled' + self._plot["state"] = "disabled" + self._format["state"] = "disabled" self._poll_var.set(False) - self._poll['state'] = 'disabled' - self._refresh['state'] = 'disabled' - + self._poll["state"] = "disabled" + self._refresh["state"] = "disabled" class ObjectList(ttk.Frame): - def __init__(self, app : GUIClient, parent : ttk.Widget, objects : list[laio_zmq.Object]=[], \ - filter : typing.Callable[[laio_zmq.Object], bool] | None=None, show_plot : bool=False, *args, **kwargs): + def __init__( + self, + app: GUIClient, + parent: ttk.Widget, + objects: list[laio_zmq.Object] = [], + filter: typing.Callable[[laio_zmq.Object], bool] | None = None, + show_plot: bool = False, + *args, + **kwargs, + ): super().__init__(parent, *args, **kwargs) self._app = app - self._objects : list[ObjectRow] = [] - self._filtered_objects : list[ObjectRow] = [] + self._objects: list[ObjectRow] = [] + self._filtered_objects: list[ObjectRow] = [] self._filter = filter self._show_plot = show_plot self.columnconfigure(0, weight=1) - self.filtered = laio_event.Event('filtered') - self.changed = laio_event.Event('changed') + self.filtered = laio_event.Event("filtered") + self.changed = laio_event.Event("changed") self.set_objects(objects) @property def objects(self) -> list[ObjectRow]: return self._filtered_objects - def set_objects(self, objects : list[laio_zmq.Object]): + def set_objects(self, objects: list[laio_zmq.Object]): for o in self._objects: o.destroy() self._objects = [] @@ -610,7 +678,7 @@ def set_objects(self, objects : list[laio_zmq.Object]): self.changed.trigger() self.filter() - def filter(self, f : typing.Callable[[laio_zmq.Object], bool] | None | bool=True): + def filter(self, f: typing.Callable[[laio_zmq.Object], bool] | None | bool = True): if f is False: return elif f is True: @@ -628,8 +696,8 @@ def filter(self, f : typing.Callable[[laio_zmq.Object], bool] | None | bool=True self._filtered_objects = [] for o in self._objects: if f(o.obj): - o.grid(column=0, row=row, sticky='nsew') - o.style('Even' if row % 2 == 0 else 'Odd') + o.grid(column=0, row=row, sticky="nsew") + o.style("Even" if row % 2 == 0 else "Odd") row += 1 self._filtered_objects.append(o) else: @@ -640,13 +708,14 @@ def filter(self, f : typing.Callable[[laio_zmq.Object], bool] | None | bool=True self.filtered.trigger() + class ScrollableFrame(ttk.Frame): - def __init__(self, parent : ttk.Widget, *args, **kwargs): + def __init__(self, parent: ttk.Widget, *args, **kwargs): super().__init__(parent, *args, **kwargs) self._canvas = tk.Canvas(self) - self._scrollbar = ttk.Scrollbar(self, orient='vertical', command=self._canvas.yview) - self._scrollbar.grid(column=1, row=0, sticky='ns', padx=(Style.grid_padding, 0)) + self._scrollbar = ttk.Scrollbar(self, orient="vertical", command=self._canvas.yview) + self._scrollbar.grid(column=1, row=0, sticky="ns", padx=(Style.grid_padding, 0)) self._content = ttk.Frame(self._canvas) self._content.bind("", self._update_scrollregion) @@ -654,9 +723,9 @@ def __init__(self, parent : ttk.Widget, *args, **kwargs): self._canvas.configure(yscrollcommand=self._scrollbar.set) s = ttk.Style() - self._canvas.configure(background=s.lookup('TFrame', 'background'), highlightthickness=0) + self._canvas.configure(background=s.lookup("TFrame", "background"), highlightthickness=0) - self._canvas.grid(column=0, row=0, sticky='nsew', padx=(0, Style.grid_padding)) + self._canvas.grid(column=0, row=0, sticky="nsew", padx=(0, Style.grid_padding)) self.columnconfigure(0, weight=1) self.rowconfigure(0, weight=1) @@ -671,7 +740,7 @@ def _fit_content_to_canvas(self, event): self._canvas.itemconfigure("inner_frame", width=event.width - 1) def _on_mousewheel(self, event): - self._canvas.yview_scroll(int(-1*(event.delta/120)), "units") + self._canvas.yview_scroll(int(-1 * (event.delta / 120)), "units") return "break" def _on_linux_scroll(self, event): @@ -682,15 +751,17 @@ def _on_linux_scroll(self, event): return "break" @staticmethod - def _children(widget : tk.BaseWidget) -> set[tk.BaseWidget]: - return set(widget.winfo_children()).union(*(ScrollableFrame._children(w) for w in widget.winfo_children())) + def _children(widget: tk.BaseWidget) -> set[tk.BaseWidget]: + return set(widget.winfo_children()).union( + *(ScrollableFrame._children(w) for w in widget.winfo_children()) + ) def bind_scroll(self): # find all children recursively and bind to mousewheel for child in self._children(self._canvas): - child.bind("", self._on_mousewheel) # Windows and macOS - child.bind("", self._on_linux_scroll) # Linux scroll up - child.bind("", self._on_linux_scroll) # Linux scroll down + child.bind("", self._on_mousewheel) # Windows and macOS + child.bind("", self._on_linux_scroll) # Linux scroll up + child.bind("", self._on_linux_scroll) # Linux scroll down def updated_content(self): self._update_scrollregion(None) @@ -704,59 +775,72 @@ def canvas(self) -> tk.Canvas: return self._canvas - class FilterEntry(ltk.Entry): - ''' + """ Regex filter on a given ObjectList. - ''' + """ - def __init__(self, parent : ttk.Widget, object_list : ObjectList, *args, **kwargs): - super().__init__(parent, hint='enter regex filter', *args, **kwargs) + def __init__(self, parent: ttk.Widget, object_list: ObjectList, *args, **kwargs): + super().__init__(parent, hint="enter regex filter", *args, **kwargs) self._object_list = object_list - self._var.trace_add('write', self._on_change) + self._var.trace_add("write", self._on_change) def _on_change(self, *args): text = self.text - if text == '': + if text == "": self._object_list.filter(None) else: try: regex = re.compile(text, re.IGNORECASE) - self['foreground'] = 'black' + self["foreground"] = "black" except re.error: - self['foreground'] = 'red' + self["foreground"] = "red" return - def f(o : laio_zmq.Object) -> bool: + def f(o: laio_zmq.Object) -> bool: return regex.search(o.name) is not None self._object_list.filter(f) - class Tools(laio_tk.Work, ttk.Frame): - def __init__(self, app : GUIClient, parent : ttk.Widget, \ - filter_objects : ObjectList, refresh_command : typing.Callable[[], typing.Any], *args, **kwargs): + def __init__( + self, + app: GUIClient, + parent: ttk.Widget, + filter_objects: ObjectList, + refresh_command: typing.Callable[[], typing.Any], + *args, + **kwargs, + ): super().__init__(atk=app.atk, master=parent, *args, **kwargs) self._app = app filter = FilterEntry(self, filter_objects) - filter.grid(column=0, row=0, sticky='nswe', padx=(0, Style.grid_padding), pady=Style.grid_padding) + filter.grid( + column=0, row=0, sticky="nswe", padx=(0, Style.grid_padding), pady=Style.grid_padding + ) self._plot = None if plotter is not None: - self._plot = ttk.Button(self, text='Show plots', command=self._on_plot, width=12) - self._plot.grid(column=1, row=0, sticky='nswe', padx=Style.grid_padding, pady=Style.grid_padding) + self._plot = ttk.Button(self, text="Show plots", command=self._on_plot, width=12) + self._plot.grid( + column=1, row=0, sticky="nswe", padx=Style.grid_padding, pady=Style.grid_padding + ) self.connect(plotter.plotting, self._on_plotter_update) self.connect(plotter.paused, self._on_plotter_update) self._on_plotter_update() - self._default_poll = ltk.Entry(self, text='1', hint='poll (s)', width=7, justify='right') - self._default_poll.grid(row=0, column=2, sticky='nswe', padx=Style.grid_padding, pady=Style.grid_padding) + self._default_poll = ltk.Entry(self, text="1", hint="poll (s)", width=7, justify="right") + self._default_poll.grid( + row=0, column=2, sticky="nswe", padx=Style.grid_padding, pady=Style.grid_padding + ) - refresh_all = ttk.Button(self, text='Refresh all', command=refresh_command) - refresh_all.grid(row=0, column=3, sticky='nswe', padx=(Style.grid_padding, 0), pady=Style.grid_padding) + refresh_all = ttk.Button(self, text="Refresh all", command=refresh_command) + refresh_all.grid( + row=0, column=3, sticky="nswe", padx=(Style.grid_padding, 0), pady=Style.grid_padding + ) self.columnconfigure(0, weight=1) @@ -768,22 +852,22 @@ def default_poll(self) -> float: except ValueError: v = 1.0 - self._default_poll.delete(0, 'end') - self._default_poll.insert(0, locale.format_string('%g', v, grouping=True)) + self._default_poll.delete(0, "end") + self._default_poll.insert(0, locale.format_string("%g", v, grouping=True)) return v def _on_plotter_update(self): assert self._plot is not None if not plotter or not plotter.plotting.value: - self._plot['state'] = 'disabled' - self._plot['text'] = 'No plots' + self._plot["state"] = "disabled" + self._plot["text"] = "No plots" else: - self._plot['state'] = 'normal' + self._plot["state"] = "normal" if plotter.paused.value: - self._plot['text'] = 'Resume plot' + self._plot["text"] = "Resume plot" else: - self._plot['text'] = 'Pause plot' + self._plot["text"] = "Pause plot" def _on_plot(self): assert plotter is not None @@ -794,23 +878,22 @@ def _on_plot(self): plotter.toggle_pause() - class Stream(laio_tk.AsyncWidget, tk.Toplevel): - def __init__(self, app : GUIClient, parent : ttk.Widget, name : str, *args, **kwargs): + def __init__(self, app: GUIClient, parent: ttk.Widget, name: str, *args, **kwargs): super().__init__(app=app, master=parent, *args, **kwargs) self._name = name self._stream = app.client.stream(name) - self._task : asyncio.Task | None = None - self.title(f'libstored GUI - stream {name}') + self._task: asyncio.Task | None = None + self.title(f"libstored GUI - stream {name}") self._out = tk.Text(self) - self._out.configure(state='disabled') - self._out.grid(column=0, row=0, sticky='nsew') + self._out.configure(state="disabled") + self._out.grid(column=0, row=0, sticky="nsew") self.columnconfigure(0, weight=1) self.rowconfigure(0, weight=1) - self._out.bind('', self._select_all) - self._out.bind('', self._select_all) + self._out.bind("", self._select_all) + self._out.bind("", self._select_all) self._start() @@ -819,14 +902,14 @@ def client(self) -> laio_zmq.ZmqClient: return typing.cast(GUIClient, self.app).client def _select_all(self, event): - event.widget.tag_add('sel', '1.0', 'end') - event.widget.mark_set('insert', 'end') - return 'break' + event.widget.tag_add("sel", "1.0", "end") + event.widget.mark_set("insert", "end") + return "break" @laio_tk.AsyncApp.worker_func async def _start(self): if self._task is None: - self._task = self.client.periodic(1.0, self._poll, name=f'stream {self._name}') + self._task = self.client.periodic(1.0, self._poll, name=f"stream {self._name}") @laio_tk.AsyncApp.worker_func async def _stop(self): @@ -842,12 +925,12 @@ async def _poll(self): self._append(x) @laio_tk.AsyncApp.tk_func - def _append(self, x : str): - self._out.configure(state='normal') + def _append(self, x: str): + self._out.configure(state="normal") self._out.insert(tk.END, x) if self.focus_get() != self._out: self._out.see(tk.END) - self._out.configure(state='disabled') + self._out.configure(state="disabled") @laio_tk.AsyncApp.tk_func def cleanup(self): @@ -856,20 +939,21 @@ def cleanup(self): super().cleanup() - class Streams(laio_tk.AsyncWidget, ttk.Frame): - def __init__(self, app : GUIClient, parent : ttk.Widget, client : laio_zmq.ZmqClient, *args, **kwargs): + def __init__( + self, app: GUIClient, parent: ttk.Widget, client: laio_zmq.ZmqClient, *args, **kwargs + ): super().__init__(app=app, master=parent, *args, **kwargs) self._client = client - self._streams : dict[str, dict] = {} - self._refresh = ttk.Button(self, text='Refresh streams', command=self._on_refresh) + self._streams: dict[str, dict] = {} + self._refresh = ttk.Button(self, text="Refresh streams", command=self._on_refresh) self.connect(self._client.connected, self._on_connect) self.connect(self._client.disconnected, self._on_disconnect) @laio_tk.AsyncApp.worker_func async def _on_connect(self): - if self._client.is_connected() and 's' in await self._client.capabilities(): + if self._client.is_connected() and "s" in await self._client.capabilities(): self._on_connected() else: self._on_disconnect() @@ -877,7 +961,9 @@ async def _on_connect(self): @laio_tk.AsyncApp.tk_func def _on_connected(self): if self._client.is_connected(): - self._refresh.grid(column=256, row=0, sticky='nswe', padx=(Style.grid_padding * 2, 0), pady=0) + self._refresh.grid( + column=256, row=0, sticky="nswe", padx=(Style.grid_padding * 2, 0), pady=0 + ) self._on_refresh() else: self._on_disconnect() @@ -897,23 +983,31 @@ async def _on_refresh(self): self._on_streams(await self._client.other_streams()) @laio_tk.AsyncApp.tk_func - def _on_streams(self, streams : list[str]): + def _on_streams(self, streams: list[str]): for s, sconf in list(self._streams.items()): if s not in streams: - if 'window' in sconf: - sconf['window'].destroy() - sconf['check'].destroy() + if "window" in sconf: + sconf["window"].destroy() + sconf["check"].destroy() del self._streams[s] for s in streams: if s not in self._streams: var = tk.BooleanVar(value=False) - check = ttk.Checkbutton(self, text=s, command=lambda s=s: self._show_stream(s), variable=var) - check.grid(row=0, column=len(self._streams), sticky='nswe', padx=(Style.grid_padding * 2, 0), pady=0) - self._streams[s] = {'check': check, 'var': var} + check = ttk.Checkbutton( + self, text=s, command=lambda s=s: self._show_stream(s), variable=var + ) + check.grid( + row=0, + column=len(self._streams), + sticky="nswe", + padx=(Style.grid_padding * 2, 0), + pady=0, + ) + self._streams[s] = {"check": check, "var": var} @laio_tk.AsyncApp.tk_func - def _show_stream(self, stream : str): + def _show_stream(self, stream: str): if stream not in self._streams: return @@ -922,101 +1016,102 @@ def _show_stream(self, stream : str): return sconf = self._streams[stream] - if 'check' not in sconf: + if "check" not in sconf: self._hide_stream(stream) return - if not sconf['check'].instate(['selected']): + if not sconf["check"].instate(["selected"]): self._hide_stream(stream) return - if 'window' not in sconf: + if "window" not in sconf: w = Stream(typing.cast(GUIClient, self.app), self, stream) - sconf['window'] = w + sconf["window"] = w w.protocol("WM_DELETE_WINDOW", lambda: self._hide_stream(stream)) @laio_tk.AsyncApp.tk_func - def _hide_stream(self, stream : str): + def _hide_stream(self, stream: str): if stream not in self._streams: return sconf = self._streams[stream] - if 'window' in sconf: - w = sconf['window'] - del sconf['window'] + if "window" in sconf: + w = sconf["window"] + del sconf["window"] w.destroy() - if 'check' in sconf and sconf['check'].instate(['selected']): - sconf['check'].state(['!selected']) - + if "check" in sconf and sconf["check"].instate(["selected"]): + sconf["check"].state(["!selected"]) class ManualCommand(laio_tk.AsyncWidget, ttk.Frame): - def __init__(self, app : GUIClient, parent : ttk.Widget, client : laio_zmq.ZmqClient, *args, **kwargs): + def __init__( + self, app: GUIClient, parent: ttk.Widget, client: laio_zmq.ZmqClient, *args, **kwargs + ): super().__init__(app=app, master=parent, *args, **kwargs) self._client = client self._empty = True - self._req = ltk.Entry(self, hint='enter command') - self._req.grid(column=0, row=0, sticky='nswe', padx=0, pady=Style.grid_padding) + self._req = ltk.Entry(self, hint="enter command") + self._req.grid(column=0, row=0, sticky="nswe", padx=0, pady=Style.grid_padding) self.columnconfigure(0, weight=1) - self._req.bind('', self._on_enter) - self._req.bind('', self._on_enter) + self._req.bind("", self._on_enter) + self._req.bind("", self._on_enter) def select_all(event): if isinstance(event.widget, ttk.Entry): - event.widget.select_range(0, 'end') - event.widget.icursor('end') - return 'break' + event.widget.select_range(0, "end") + event.widget.icursor("end") + return "break" elif isinstance(event.widget, tk.Text): - event.widget.tag_add('sel', '1.0', 'end') - event.widget.mark_set('insert', 'end') - return 'break' + event.widget.tag_add("sel", "1.0", "end") + event.widget.mark_set("insert", "end") + return "break" - self._req.bind('', select_all) - self._req.bind('', select_all) + self._req.bind("", select_all) + self._req.bind("", select_all) self._rep = tk.Text(self, height=5) - self._rep.configure(state='disabled') - self._rep.bind('', select_all) - self._rep.bind('', select_all) + self._rep.configure(state="disabled") + self._rep.bind("", select_all) + self._rep.bind("", select_all) - self._req.bind('', self._focus_in, add=True) - self.bind('', self._focus_out) + self._req.bind("", self._focus_in, add=True) + self.bind("", self._focus_out) streams = Streams(app, self, client) - streams.grid(column=1, row=0, sticky='nswe', padx=0, pady=Style.grid_padding) + streams.grid(column=1, row=0, sticky="nswe", padx=0, pady=Style.grid_padding) self.columnconfigure(1, weight=0) self._focus_out() def _on_enter(self, event): self._do_command(self._req.text) - return 'break' + return "break" @laio_tk.AsyncApp.worker_func - async def _do_command(self, command : str): + async def _do_command(self, command: str): if not self._client.is_connected(): return - if command == '': + if command == "": return try: self._response(await self._client.req(command)) except BaseException as e: - self.logger.exception(f'Manual command: {e}') + self.logger.exception(f"Manual command: {e}") @laio_tk.AsyncApp.tk_func - def _response(self, response : str): - self._rep.configure(state='normal') - self._rep.delete('1.0', 'end') - self._rep.insert('end', response) - self._rep.see('1.0') - self._rep.configure(state='disabled') + def _response(self, response: str): + self._rep.configure(state="normal") + self._rep.delete("1.0", "end") + self._rep.insert("end", response) + self._rep.see("1.0") + self._rep.configure(state="disabled") def _focus_in(self, *args): - self._rep.grid(column=0, row=1, sticky='nsew', pady=(Style.grid_padding, 0), columnspan=2) + self._rep.grid(column=0, row=1, sticky="nsew", pady=(Style.grid_padding, 0), columnspan=2) self.rowconfigure(1, weight=1) def _focus_out(self, *args): @@ -1024,20 +1119,27 @@ def _focus_out(self, *args): self.rowconfigure(1, weight=0) - ##################################################################### # GUI # + class GUIClient(laio_tk.AsyncApp): - def __init__(self, client : laio_zmq.ZmqClient, clear_state : bool=False, csv : laio_csv.CsvExport | None=None, *args, **kwargs): + def __init__( + self, + client: laio_zmq.ZmqClient, + clear_state: bool = False, + csv: laio_csv.CsvExport | None = None, + *args, + **kwargs, + ): super().__init__(*args, **kwargs) self._client = client self._client_connections = set() self._csv = csv - self.root.title(f'libstored GUI') - icon_path = os.path.join(os.path.dirname(__file__), 'twotone_bug_report_black_48dp.png') + self.root.title(f"libstored GUI") + icon_path = os.path.join(os.path.dirname(__file__), "twotone_bug_report_black_48dp.png") icon = tk.PhotoImage(file=icon_path) self.root.iconphoto(False, icon) @@ -1048,52 +1150,77 @@ def __init__(self, client : laio_zmq.ZmqClient, clear_state : bool=False, csv : w = Style.root_width h = Style.root_height - ws = self.root.winfo_screenwidth() # width of the screen - hs = self.root.winfo_screenheight() # height of the screen + ws = self.root.winfo_screenwidth() # width of the screen + hs = self.root.winfo_screenheight() # height of the screen - x = (ws/2) - (w/2) - y = (hs/2) - (h/2) + x = (ws / 2) - (w / 2) + y = (hs / 2) - (h / 2) - self.root.geometry('%dx%d+%d+%d' % (w, h, x, y)) + self.root.geometry("%dx%d+%d+%d" % (w, h, x, y)) s = ttk.Style() - s.theme_use('clam') - s.configure('TEntry', padding=(5, 4)) - odd_bg = darken_color(s.lookup('TFrame', 'background')) - s.configure('Odd.TFrame', background=odd_bg) - s.configure('Odd.TLabel', background=odd_bg) - s.configure('Odd.TCheckbutton', background=odd_bg, focuscolor=odd_bg, activebackground=odd_bg) + s.theme_use("clam") + s.configure("TEntry", padding=(5, 4)) + odd_bg = darken_color(s.lookup("TFrame", "background")) + s.configure("Odd.TFrame", background=odd_bg) + s.configure("Odd.TLabel", background=odd_bg) + s.configure( + "Odd.TCheckbutton", background=odd_bg, focuscolor=odd_bg, activebackground=odd_bg + ) s.map("Odd.TCheckbutton", background=[("active", odd_bg)]) - s.configure('Odd.TButton', width=8, padding=3) - - even_bg = s.lookup('TFrame', 'background') - s.configure('Even.TFrame') - s.configure('Even.TLabel', background=even_bg) - s.configure('Even.TCheckbutton', background=even_bg, focuscolor=even_bg, activebackground=even_bg) + s.configure("Odd.TButton", width=8, padding=3) + + even_bg = s.lookup("TFrame", "background") + s.configure("Even.TFrame") + s.configure("Even.TLabel", background=even_bg) + s.configure( + "Even.TCheckbutton", background=even_bg, focuscolor=even_bg, activebackground=even_bg + ) s.map("Even.TCheckbutton", background=[("active", even_bg)]) - s.configure('Even.TButton', width=8, padding=3) + s.configure("Even.TButton", width=8, padding=3) connect = ClientConnection(self, self, self.client, clear_state) - connect.grid(column=0, row=0, sticky='we', padx=Style.window_padding, pady=(Style.window_padding, Style.grid_padding)) + connect.grid( + column=0, + row=0, + sticky="we", + padx=Style.window_padding, + pady=(Style.window_padding, Style.grid_padding), + ) scrollable_objects = ScrollableFrame(self) self._objects = ObjectList(self, scrollable_objects.content) - self._objects.pack(fill='both', expand=True) - scrollable_objects.grid(column=0, row=2, sticky='nsew', padx=Style.window_padding) + self._objects.pack(fill="both", expand=True) + scrollable_objects.grid(column=0, row=2, sticky="nsew", padx=Style.window_padding) self.connect(self._objects.changed, scrollable_objects.bind_scroll) self.connect(self._objects.filtered, scrollable_objects.updated_content) - self._tools = Tools(app=self, parent=self, filter_objects=self._objects, refresh_command=self._refresh_all) - self._tools.grid(column=0, row=1, sticky='nswe', padx=Style.window_padding, pady=Style.grid_padding) + self._tools = Tools( + app=self, parent=self, filter_objects=self._objects, refresh_command=self._refresh_all + ) + self._tools.grid( + column=0, row=1, sticky="nswe", padx=Style.window_padding, pady=Style.grid_padding + ) self._scrollable_polled = ScrollableFrame(self) - self._polled_objects = ObjectList(self, self._scrollable_polled.content, filter=lambda o: o.polling.value is not None, show_plot=True) - self._polled_objects.pack(fill='both', expand=True) + self._polled_objects = ObjectList( + self, + self._scrollable_polled.content, + filter=lambda o: o.polling.value is not None, + show_plot=True, + ) + self._polled_objects.pack(fill="both", expand=True) self.connect(self._polled_objects.changed, self._scrollable_polled.bind_scroll) self.connect(self._polled_objects.filtered, self._scrollable_polled.updated_content) self._manual = ManualCommand(self, self, self.client) - self._manual.grid(column=0, row=4, sticky='nsew', padx=Style.window_padding, pady=(Style.grid_padding, Style.window_padding)) + self._manual.grid( + column=0, + row=4, + sticky="nsew", + padx=Style.window_padding, + pady=(Style.grid_padding, Style.window_padding), + ) self.columnconfigure(0, weight=1) self.rowconfigure(2, weight=2) @@ -1157,7 +1284,7 @@ def cleanup(self): plotter = None self.disconnect_all() - self.logger.debug('Close client') + self.logger.debug("Close client") self._close_async() super().cleanup() @@ -1165,7 +1292,7 @@ def cleanup(self): @laio_tk.AsyncApp.worker_func async def _close_async(self): - self.logger.debug('Closing client') + self.logger.debug("Closing client") await self.client.close() if self._csv is not None: await self._csv.close() @@ -1175,7 +1302,7 @@ def _refresh_all(self): self._refresh_objects([o.obj for o in self._objects.objects]) @laio_tk.AsyncApp.worker_func - async def _refresh_objects(self, objs : list[laio_zmq.Object]): + async def _refresh_objects(self, objs: list[laio_zmq.Object]): await asyncio.gather(*(o.read(acquire_alias=False) for o in objs)) @laio_tk.AsyncApp.tk_func @@ -1190,7 +1317,13 @@ def _resize_polled_objects(self, *args): height += o.winfo_reqheight() if height > 0: - self._scrollable_polled.grid(column=0, row=3, sticky='nsew', padx=Style.window_padding, pady=(Style.separator_padding, 0)) + self._scrollable_polled.grid( + column=0, + row=3, + sticky="nsew", + padx=Style.window_padding, + pady=(Style.separator_padding, 0), + ) self._scrollable_polled.update_idletasks() total_height = self.winfo_height() if total_height > 0: @@ -1206,51 +1339,99 @@ def default_poll(self) -> float: return self._tools.default_poll() - ##################################################################### # CLI # + def main(): - parser = argparse.ArgumentParser(prog=__package__, description='ZMQ GUI client', - formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('-V', '--version', action='version', version=__version__) - parser.add_argument('-s', '--server', dest='server', type=str, default='localhost', help='ZMQ server to connect to') - parser.add_argument('-p', '--port', dest='port', type=int, default=lprot.default_port, help='port') - parser.add_argument('-v', '--verbose', dest='verbose', default=0, help='Enable verbose output', action='count') - parser.add_argument('-m', '--multi', dest='multi', default=False, - help='Enable multi-mode; allow multiple simultaneous connections to the same target, ' + - 'but it is less efficient.', action='store_true') - parser.add_argument('-c', '--clearstate', dest='clear_state', default=False, - help='Clear previously saved state', action='store_true') - parser.add_argument('-D', '--deadlock', dest='deadlock', default=0, nargs='?', - help='Enable deadlock checks after x seconds', type=float, const=10.0) - parser.add_argument('-f', '--csv', dest='csv', default=None, nargs='?', - help='Log auto-refreshed data to csv file. ' + - 'The file is truncated upon startup and when the set of auto-refreshed objects change. ' + - 'The file name may include strftime() format codes.', const='log.csv') - parser.add_argument('-e', '--encrypt', dest='encrypted', type=str, default=None, - help='Enable AES-256 CTR encryption with the given pre-shared key file', metavar='file') + parser = argparse.ArgumentParser( + prog=__package__, + description="ZMQ GUI client", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-V", "--version", action="version", version=__version__) + parser.add_argument( + "-s", + "--server", + dest="server", + type=str, + default="localhost", + help="ZMQ server to connect to", + ) + parser.add_argument( + "-p", "--port", dest="port", type=int, default=lprot.default_port, help="port" + ) + parser.add_argument( + "-v", "--verbose", dest="verbose", default=0, help="Enable verbose output", action="count" + ) + parser.add_argument( + "-m", + "--multi", + dest="multi", + default=False, + help="Enable multi-mode; allow multiple simultaneous connections to the same target, " + + "but it is less efficient.", + action="store_true", + ) + parser.add_argument( + "-c", + "--clearstate", + dest="clear_state", + default=False, + help="Clear previously saved state", + action="store_true", + ) + parser.add_argument( + "-D", + "--deadlock", + dest="deadlock", + default=0, + nargs="?", + help="Enable deadlock checks after x seconds", + type=float, + const=10.0, + ) + parser.add_argument( + "-f", + "--csv", + dest="csv", + default=None, + nargs="?", + help="Log auto-refreshed data to csv file. " + + "The file is truncated upon startup and when the set of auto-refreshed objects change. " + + "The file name may include strftime() format codes.", + const="log.csv", + ) + parser.add_argument( + "-e", + "--encrypt", + dest="encrypted", + type=str, + default=None, + help="Enable AES-256 CTR encryption with the given pre-shared key file", + metavar="file", + ) args = parser.parse_args() - logging_config : dict[str, typing.Any] = { - 'format': '[%(asctime)s.%(msecs)03d] %(levelname)s %(name)s (%(threadName)s): %(message)s', - 'datefmt': '%H:%M:%S', + logging_config: dict[str, typing.Any] = { + "format": "[%(asctime)s.%(msecs)03d] %(levelname)s %(name)s (%(threadName)s): %(message)s", + "datefmt": "%H:%M:%S", } logger = logging.getLogger(__package__) if args.verbose == 0: - logging_config['level'] = logging.WARNING + logging_config["level"] = logging.WARNING elif args.verbose == 1: - logging_config['level'] = logging.INFO + logging_config["level"] = logging.INFO else: - logging_config['level'] = logging.DEBUG + logging_config["level"] = logging.DEBUG logging.basicConfig(**logging_config) if args.deadlock > 0: - logger.info(f'Enable deadlock checks after {args.deadlock} seconds') + logger.info(f"Enable deadlock checks after {args.deadlock} seconds") lexc.DeadlockChecker.default_timeout_s = args.deadlock csv = None @@ -1261,10 +1442,13 @@ def main(): stack = None if args.encrypted: stack = lprot.Aes256Layer(args.encrypted, reqrep=True) - logger.info(f'Enable AES-256 encryption with key file {args.encrypted}') + logger.info(f"Enable AES-256 encryption with key file {args.encrypted}") - client = laio_zmq.ZmqClient(host=args.server, port=args.port, multi=args.multi, use_state='gui', stack=stack) + client = laio_zmq.ZmqClient( + host=args.server, port=args.port, multi=args.multi, use_state="gui", stack=stack + ) GUIClient.run(worker=client.worker, client=client, clear_state=args.clear_state, csv=csv) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/python/libstored/heatshrink.py b/python/libstored/heatshrink.py index 195fd642..34dd4da3 100644 --- a/python/libstored/heatshrink.py +++ b/python/libstored/heatshrink.py @@ -5,22 +5,26 @@ from enum import Enum import logging + class HSD_sink_res(Enum): HSDR_SINK_OK = 0 HSDR_SINK_FULL = 1 HSDR_SINK_ERROR_NULL = -1 + class HSD_poll_res(Enum): HSDR_POLL_EMPTY = 0 HSDR_POLL_MORE = 1 HSDR_POLL_ERROR_NULL = -1 HSDR_POLL_ERROR_UNKNOWN = -2 + class HSD_finish_res(Enum): HSDR_FINISH_DONE = 0 HSDR_FINISH_MORE = 1 HSDR_FINISH_ERROR_NULL = -1 + class HSD_state(Enum): HSDS_TAG_BIT = 0 HSDS_YIELD_LITERAL = 1 @@ -30,15 +34,17 @@ class HSD_state(Enum): HSDS_BACKREF_COUNT_LSB = 5 HSDS_YIELD_BACKREF = 6 -NO_BITS = 0xffff + +NO_BITS = 0xFFFF + class HeatshrinkDecoder: - ''' + """ This is the decoder implementation of heatshrink: https://github.com/atomicobject/heatshrink Although there is a python wrapper available at https://github.com/eerimoq/pyheatshrink, this implementation exists here to break dependencies and compatibility issues. - ''' + """ logger = logging.getLogger(__name__) @@ -54,14 +60,14 @@ def fill(self, x): out_buf = bytearray() while True: if rem > 0: - res, size = self._sink(x[start:start + rem]) + res, size = self._sink(x[start : start + rem]) start += size rem -= size if self._poll(out_buf) == HSD_poll_res.HSDR_POLL_EMPTY and rem == 0: return out_buf - def finish(self, x = b''): + def finish(self, x=b""): out_buf = self.fill(x) while self._finish() == HSD_finish_res.HSDR_FINISH_MORE: @@ -81,7 +87,7 @@ def _reset(self): self._state = HSD_state.HSDS_TAG_BIT self._current_byte = 0 self._bit_index = 0 - self._buffers = bytearray(b'\0' * (self._input_buffer_size + 2 ** self._window_sz2)) + self._buffers = bytearray(b"\0" * (self._input_buffer_size + 2**self._window_sz2)) def _sink(self, x): rem = self._input_buffer_size - self._input_size @@ -89,7 +95,7 @@ def _sink(self, x): return (HSD_sink_res.HSDR_SINK_FULL, 0) size = min(len(x), rem) - self._buffers[self._input_size:size] = x[:size] + self._buffers[self._input_size : size] = x[:size] self._input_size += size return (HSD_sink_res.HSDR_SINK_OK, size) @@ -97,7 +103,7 @@ def _poll(self, out_buf): assert isinstance(out_buf, bytearray) while True: -# self.logger.debug('-- poll, state is %s, input_size %d', self._state, self._input_size) + # self.logger.debug('-- poll, state is %s, input_size %d', self._state, self._input_size) in_state = self._state if in_state == HSD_state.HSDS_TAG_BIT: self._state = self._st_tag_bit() @@ -138,8 +144,8 @@ def _st_yield_literal(self, out_buf): return HSD_state.HSDS_YIELD_LITERAL buf_i = self._input_buffer_size - mask = 2 ** self._window_sz2 - 1 - c = byte & 0xff + mask = 2**self._window_sz2 - 1 + c = byte & 0xFF self._buffers[buf_i + (self._head_index & mask)] = c self._head_index += 1 self._push_byte(out_buf, c) @@ -162,7 +168,9 @@ def _st_backref_index_lsb(self): self._output_index = (self._output_index | bits) + 1 br_bit_ct = self._lookahead_sz2 self._output_count = 0 - return HSD_state.HSDS_BACKREF_COUNT_MSB if br_bit_ct > 8 else HSD_state.HSDS_BACKREF_COUNT_LSB + return ( + HSD_state.HSDS_BACKREF_COUNT_MSB if br_bit_ct > 8 else HSD_state.HSDS_BACKREF_COUNT_LSB + ) def _st_backref_count_msb(self): br_bit_ct = self._lookahead_sz2 @@ -184,10 +192,10 @@ def _st_backref_count_lsb(self): def _st_yield_backref(self, out_buf): count = self._output_count buf_i = self._input_buffer_size - mask = 2 ** self._window_sz2 - 1 + mask = 2**self._window_sz2 - 1 neg_offset = self._output_index assert neg_offset <= mask + 1 - assert count <= 2 ** self._lookahead_sz2 + assert count <= 2**self._lookahead_sz2 for i in range(0, count): c = self._buffers[buf_i + ((self._head_index - neg_offset) & mask)] @@ -228,17 +236,30 @@ def _get_bits(self, count): def _finish(self): if self._state == HSD_state.HSDS_TAG_BIT: - return HSD_finish_res.HSDR_FINISH_DONE if self._input_size == 0 else HSD_finish_res.HSDR_FINISH_MORE - elif self._state == HSD_state.HSDS_BACKREF_INDEX_LSB or \ - self._state == HSD_state.HSDS_BACKREF_INDEX_MSB or \ - self._state == HSD_state.HSDS_BACKREF_COUNT_LSB or \ - self._state == HSD_state.HSDS_BACKREF_COUNT_MSB: - return HSD_finish_res.HSDR_FINISH_DONE if self._input_size == 0 else HSD_finish_res.HSDR_FINISH_MORE + return ( + HSD_finish_res.HSDR_FINISH_DONE + if self._input_size == 0 + else HSD_finish_res.HSDR_FINISH_MORE + ) + elif ( + self._state == HSD_state.HSDS_BACKREF_INDEX_LSB + or self._state == HSD_state.HSDS_BACKREF_INDEX_MSB + or self._state == HSD_state.HSDS_BACKREF_COUNT_LSB + or self._state == HSD_state.HSDS_BACKREF_COUNT_MSB + ): + return ( + HSD_finish_res.HSDR_FINISH_DONE + if self._input_size == 0 + else HSD_finish_res.HSDR_FINISH_MORE + ) elif self._state == HSD_state.HSDS_YIELD_LITERAL: - return HSD_finish_res.HSDR_FINISH_DONE if self._input_size == 0 else HSD_finish_res.HSDR_FINISH_MORE + return ( + HSD_finish_res.HSDR_FINISH_DONE + if self._input_size == 0 + else HSD_finish_res.HSDR_FINISH_MORE + ) else: return HSD_finish_res.HSDR_FINISH_MORE def _push_byte(self, out_buf, x): out_buf.append(x) - diff --git a/python/libstored/log/__main__.py b/python/libstored/log/__main__.py index 78fc7ed8..b9887962 100644 --- a/python/libstored/log/__main__.py +++ b/python/libstored/log/__main__.py @@ -14,12 +14,13 @@ from ..version import __version__ from .. import protocol as lprot + @run_sync -async def async_main(args : argparse.Namespace) -> int: +async def async_main(args: argparse.Namespace) -> int: global logger - filename : str = args.csv - if filename != '-': + filename: str = args.csv + if filename != "-": filename = generate_filename(filename, add_timestamp=args.timestamp, unique=args.unique) stack = None @@ -33,7 +34,7 @@ async def async_main(args : argparse.Namespace) -> int: try: obj = client[o] except ValueError: - logger.fatal('Unknown object: %s', o) + logger.fatal("Unknown object: %s", o) return 1 if obj not in objs: @@ -47,18 +48,18 @@ async def async_main(args : argparse.Namespace) -> int: try: obj = client[o] except ValueError: - logger.fatal('Unknown object: %s', o) + logger.fatal("Unknown object: %s", o) return 1 if obj not in objs: objs.append(obj) if not objs: - logger.error('No objects specified') + logger.error("No objects specified") return 1 for o in objs: - logger.info('Poll %s', o.name) + logger.info("Poll %s", o.name) await o.poll(args.interval) async with CsvExport(filename) as csv: @@ -66,40 +67,97 @@ async def async_main(args : argparse.Namespace) -> int: await csv.add(obj) if args.duration is not None: - logger.info('Start logging for %g s', args.duration) + logger.info("Start logging for %g s", args.duration) await asyncio.sleep(args.duration) else: - logger.info('Start logging') + logger.info("Start logging") await asyncio.Event().wait() return 0 + def main(): global logger - logger = logging.getLogger('log') - - parser = argparse.ArgumentParser(prog=__package__, - description='ZMQ command line logging client', formatter_class=argparse.ArgumentDefaultsHelpFormatter) - - parser.add_argument('-V', '--version', action='version', version=__version__) - parser.add_argument('-s', '--server', dest='server', type=str, default='localhost', help='ZMQ server to connect to') - parser.add_argument('-p', '--port', dest='port', type=int, default=lprot.default_port, help='port') - parser.add_argument('-v', '--verbose', dest='verbose', default=0, help='Enable verbose output', action='count') - parser.add_argument('-f', '--csv', dest='csv', default='-', - help='File to log to. The file name may include strftime() format codes.') - parser.add_argument('-t', '--timestamp', dest='timestamp', default=False, help='Append time stamp in csv file name', action='store_true') - parser.add_argument('-u', '--unique', dest='unique', default=False, - help='Make sure that the log filename is unique by appending a suffix', action='store_true') - parser.add_argument('-m', '--multi', dest='multi', default=False, - help='Enable multi-mode; allow multiple simultaneous connections to the same target, ' + - 'but it is less efficient.', action='store_true') - parser.add_argument('-i', '--interval', dest='interval', type=float, default=1, help='Poll interval (s)') - parser.add_argument('-d', '--duration', dest='duration', type=float, default=None, help='Poll duration (s)') - parser.add_argument('objects', metavar='obj', type=str, nargs='*', help='Object to poll') - parser.add_argument('-o', '--objectfile', dest='objectfile', type=str, action='append', help='File with list of objects to poll') - parser.add_argument('-e', '--encrypt', dest='encrypted', type=str, default=None, - help='Enable AES-256 CTR encryption with the given pre-shared key file', metavar='file') + logger = logging.getLogger("log") + + parser = argparse.ArgumentParser( + prog=__package__, + description="ZMQ command line logging client", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument("-V", "--version", action="version", version=__version__) + parser.add_argument( + "-s", + "--server", + dest="server", + type=str, + default="localhost", + help="ZMQ server to connect to", + ) + parser.add_argument( + "-p", "--port", dest="port", type=int, default=lprot.default_port, help="port" + ) + parser.add_argument( + "-v", "--verbose", dest="verbose", default=0, help="Enable verbose output", action="count" + ) + parser.add_argument( + "-f", + "--csv", + dest="csv", + default="-", + help="File to log to. The file name may include strftime() format codes.", + ) + parser.add_argument( + "-t", + "--timestamp", + dest="timestamp", + default=False, + help="Append time stamp in csv file name", + action="store_true", + ) + parser.add_argument( + "-u", + "--unique", + dest="unique", + default=False, + help="Make sure that the log filename is unique by appending a suffix", + action="store_true", + ) + parser.add_argument( + "-m", + "--multi", + dest="multi", + default=False, + help="Enable multi-mode; allow multiple simultaneous connections to the same target, " + + "but it is less efficient.", + action="store_true", + ) + parser.add_argument( + "-i", "--interval", dest="interval", type=float, default=1, help="Poll interval (s)" + ) + parser.add_argument( + "-d", "--duration", dest="duration", type=float, default=None, help="Poll duration (s)" + ) + parser.add_argument("objects", metavar="obj", type=str, nargs="*", help="Object to poll") + parser.add_argument( + "-o", + "--objectfile", + dest="objectfile", + type=str, + action="append", + help="File with list of objects to poll", + ) + parser.add_argument( + "-e", + "--encrypt", + dest="encrypted", + type=str, + default=None, + help="Enable AES-256 CTR encryption with the given pre-shared key file", + metavar="file", + ) args = parser.parse_args() @@ -114,9 +172,10 @@ def main(): try: res = async_main(args) except KeyboardInterrupt: - logger.info('Interrupted, exiting') + logger.info("Interrupted, exiting") sys.exit(res) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/python/libstored/protocol/file.py b/python/libstored/protocol/file.py index 9c68f026..ca5a04ac 100644 --- a/python/libstored/protocol/file.py +++ b/python/libstored/protocol/file.py @@ -6,49 +6,52 @@ import logging import os -if os.name == 'posix': +if os.name == "posix": import posix import select from . import protocol as lprot from . import util as lutil + class FileLayer(lprot.ProtocolLayer): - ''' + """ A protocol layer that reads/writes a file for I/O. - ''' + """ - name = 'file' + name = "file" - def __init__(self, file : str | tuple[str, str], *args, **kwargs): + def __init__(self, file: str | tuple[str, str], *args, **kwargs): super().__init__(*args, **kwargs) - read = self._posix_read if os.name == 'posix' else self._read - self._reader = lutil.Reader(read, thread_name=f'{self.__class__.__name__} reader') - self._writer = lutil.Writer(self._write, thread_name=f'{self.__class__.__name__} writer') - self._task : asyncio.Task | None = asyncio.create_task(self._reader_task(), name=f'{self.__class__.__name__} reader') + read = self._posix_read if os.name == "posix" else self._read + self._reader = lutil.Reader(read, thread_name=f"{self.__class__.__name__} reader") + self._writer = lutil.Writer(self._write, thread_name=f"{self.__class__.__name__} writer") + self._task: asyncio.Task | None = asyncio.create_task( + self._reader_task(), name=f"{self.__class__.__name__} reader" + ) if isinstance(file, str): file = (file, file) file_in, file_out = file - if os.name == 'posix': + if os.name == "posix": if not os.path.exists(file_in): os.mkfifo(file_in) if not os.path.exists(file_out): os.mkfifo(file_out) - self._file_in = os.fdopen(posix.open(file_in, posix.O_RDWR), 'rb') - self._file_out = os.fdopen(posix.open(file_out, posix.O_RDWR), 'wb') + self._file_in = os.fdopen(posix.open(file_in, posix.O_RDWR), "rb") + self._file_out = os.fdopen(posix.open(file_out, posix.O_RDWR), "wb") else: - self._file_in = open(file_in, 'rb') - self._file_out = open(file_out, 'wb') + self._file_in = open(file_in, "rb") + self._file_out = open(file_out, "wb") def _posix_read(self) -> bytes: f = self._file_in if f is None: - return b'' + return b"" while self._reader.running: res = select.select([f.fileno()], [], [], 1) @@ -57,20 +60,20 @@ def _posix_read(self) -> bytes: # Readable return f.read1(4096) - return b'' + return b"" def _read(self) -> bytes: f = self._file_in if f is None: - return b'' + return b"" return f.read1(4096) - def _write(self, data : bytes) -> None: + def _write(self, data: bytes) -> None: f = self._file_out if f is None: return - self.logger.debug('write %s', data) + self.logger.debug("write %s", data) f.write(data) f.flush() @@ -81,7 +84,7 @@ async def _reader_task(self) -> None: while self._reader.running: x = await self._reader.read() - self.logger.debug('read %s', x) + self.logger.debug("read %s", x) await self.decode(x) except asyncio.CancelledError: pass @@ -112,11 +115,11 @@ async def close(self) -> None: await super().close() - async def encode(self, data : lprot.ProtocolLayer.Packet) -> None: + async def encode(self, data: lprot.ProtocolLayer.Packet) -> None: if isinstance(data, str): data = data.encode() elif isinstance(data, memoryview): - data = data.cast('B') + data = data.cast("B") if not self._writer.running: await self._writer.start() @@ -126,4 +129,5 @@ async def encode(self, data : lprot.ProtocolLayer.Packet) -> None: await super().encode(data) + lprot.register_layer_type(FileLayer) diff --git a/python/libstored/protocol/protocol.py b/python/libstored/protocol/protocol.py index 4d179f13..f02ffc70 100644 --- a/python/libstored/protocol/protocol.py +++ b/python/libstored/protocol/protocol.py @@ -22,25 +22,31 @@ from .. import protocol as lprot from ..asyncio import worker as laio_worker -T = typing.TypeVar('T') +T = typing.TypeVar("T") -def callback_factory(f : typing.Callable[[T], typing.Any] | None) -> \ - typing.Callable[[T], typing.Coroutine[typing.Any, typing.Any, None]]: + +def callback_factory( + f: typing.Callable[[T], typing.Any] | None, +) -> typing.Callable[[T], typing.Coroutine[typing.Any, typing.Any, None]]: if f is None: - async def no_callback(x : T) -> None: + + async def no_callback(x: T) -> None: pass + return no_callback elif inspect.iscoroutinefunction(f): return f else: - async def callback(x : T) -> None: + + async def callback(x: T) -> None: f(x) + return callback class ProtocolLayer: - ''' + """ Base class for all protocol layers. Layers can be stacked (aka wrapped); the top (inner) layer is the @@ -48,31 +54,33 @@ class ProtocolLayer: Encoding means sending data down the stack (towards the physical layer), decoding means receiving data up the stack (towards the application layer). - ''' + """ - name = 'layer' - Packet : typing.TypeAlias = bytes | bytearray | memoryview | str - Callback : typing.TypeAlias = typing.Callable[[Packet], typing.Any] - AsyncCallback : typing.TypeAlias = typing.Callable[[Packet], typing.Coroutine[typing.Any, typing.Any, None]] + name = "layer" + Packet: typing.TypeAlias = bytes | bytearray | memoryview | str + Callback: typing.TypeAlias = typing.Callable[[Packet], typing.Any] + AsyncCallback: typing.TypeAlias = typing.Callable[ + [Packet], typing.Coroutine[typing.Any, typing.Any, None] + ] def __init__(self, *args, **kwargs): - self._closed : bool = False - self._connected : bool = True + self._closed: bool = False + self._connected: bool = True super().__init__(*args, **kwargs) self.logger = logging.getLogger(self.__class__.__name__) - self._down : ProtocolLayer | None = None - self._up : ProtocolLayer | None = None - self._encode_callback : ProtocolLayer.AsyncCallback = callback_factory(None) - self._decode_callback : ProtocolLayer.AsyncCallback = callback_factory(None) - self._activity : float = 0 + self._down: ProtocolLayer | None = None + self._up: ProtocolLayer | None = None + self._encode_callback: ProtocolLayer.AsyncCallback = callback_factory(None) + self._decode_callback: ProtocolLayer.AsyncCallback = callback_factory(None) + self._activity: float = 0 self._async_except_hook = callback_factory(self.default_async_except_hook) - def wrap(self, layer : ProtocolLayer) -> None: - ''' + def wrap(self, layer: ProtocolLayer) -> None: + """ Wrap this layer around another layer. - ''' + """ layer._down = self self._up = layer @@ -81,10 +89,10 @@ def up(self) -> ProtocolLayer | None: return self._up @up.setter - def up(self, cb : ProtocolLayer.Callback | None) -> None: - ''' + def up(self, cb: ProtocolLayer.Callback | None) -> None: + """ Set a callback to be called when data is received from the lower layer. - ''' + """ self._decode_callback = callback_factory(cb) @property @@ -92,24 +100,24 @@ def down(self) -> ProtocolLayer | None: return self._down @down.setter - def down(self, cb : ProtocolLayer.Callback | None) -> None: - ''' + def down(self, cb: ProtocolLayer.Callback | None) -> None: + """ Set a callback to be called when data is received from the upper layer. - ''' + """ self._encode_callback = callback_factory(cb) async def connected(self) -> None: - ''' + """ Called when the connection is (re)connected. - ''' + """ self._connected = True if self.up is not None: await self.up.connected() async def disconnected(self) -> None: - ''' + """ Called when the connection is disconnected. - ''' + """ if not self.is_connected(): return @@ -118,15 +126,15 @@ async def disconnected(self) -> None: await self.up.disconnected() def is_connected(self) -> bool: - ''' + """ Return whether the connection is currently connected. - ''' + """ return self._connected - async def encode(self, data : ProtocolLayer.Packet) -> None: - ''' + async def encode(self, data: ProtocolLayer.Packet) -> None: + """ Encode data for transmission. - ''' + """ self.activity() await self._encode_callback(data) @@ -134,10 +142,10 @@ async def encode(self, data : ProtocolLayer.Packet) -> None: if self.down is not None: await self.down.encode(data) - async def decode(self, data : ProtocolLayer.Packet) -> None: - ''' + async def decode(self, data: ProtocolLayer.Packet) -> None: + """ Decode data received from the lower layer. - ''' + """ self.activity() await self._decode_callback(data) @@ -147,32 +155,32 @@ async def decode(self, data : ProtocolLayer.Packet) -> None: @property def mtu(self) -> int | None: - ''' + """ Maximum Transmission Unit (MTU) for this layer. Return None when there is no limit. - ''' + """ if self.down is not None: return self.down.mtu else: return None async def timeout(self) -> None: - ''' + """ Trigger maintenance actions when a timeout occurs. - ''' + """ if self.down is not None: await self.down.timeout() def activity(self) -> None: - ''' + """ Mark that there was activity on this layer. - ''' + """ self._activity = time.time() def last_activity(self) -> float: - ''' + """ Get the time of the last activity on this layer or any lower layer. - ''' + """ a = 0 if self.down is not None: a = self.down.last_activity() @@ -180,11 +188,11 @@ def last_activity(self) -> float: return max(a, self._activity) async def close(self) -> None: - ''' + """ Close the layer and release resources. Closing cannot be undone. - ''' + """ if self._closed: return @@ -195,12 +203,12 @@ async def close(self) -> None: try: await self.down.close() except BaseException as e: - self.logger.warning(f'Exception while closing: {e}') + self.logger.warning(f"Exception while closing: {e}") def is_closed(self) -> bool: - ''' + """ Return whether the layer is closed. - ''' + """ return self._closed async def __aenter__(self): @@ -210,18 +218,22 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self.close() def __del__(self): - assert self._closed, f'ProtocolLayer {self.__class__.__name__} was not close()d upon deletion' + assert ( + self._closed + ), f"ProtocolLayer {self.__class__.__name__} was not close()d upon deletion" @property - def async_except_hook(self) -> typing.Callable[[BaseException], typing.Coroutine[typing.Any, typing.Any, None]]: + def async_except_hook( + self, + ) -> typing.Callable[[BaseException], typing.Coroutine[typing.Any, typing.Any, None]]: return self._async_except_hook @async_except_hook.setter - def async_except_hook(self, f : typing.Callable[[BaseException], typing.Any] | None) -> None: + def async_except_hook(self, f: typing.Callable[[BaseException], typing.Any] | None) -> None: self._async_except_hook = callback_factory(f) - async def default_async_except_hook(self, e : BaseException) -> None: - self.logger.exception(f'Async exception {e}', exc_info=(type(e), e, e.__traceback__)) + async def default_async_except_hook(self, e: BaseException) -> None: + self.logger.exception(f"Async exception {e}", exc_info=(type(e), e, e.__traceback__)) w = laio_worker.current_worker() if w is not None: w.cancel() @@ -229,56 +241,55 @@ async def default_async_except_hook(self, e : BaseException) -> None: asyncio.get_running_loop().stop() raise - async def async_except(self, e : BaseException) -> None: + async def async_except(self, e: BaseException) -> None: await self._async_except_hook(e) - class AsciiEscapeLayer(ProtocolLayer): - ''' + """ Layer that escapes non-printable ASCII characters. - ''' + """ - name = 'ascii' + name = "ascii" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - async def decode(self, data : ProtocolLayer.Packet) -> None: + async def decode(self, data: ProtocolLayer.Packet) -> None: if isinstance(data, str): data = data.encode() elif isinstance(data, memoryview): - data = data.cast('B') + data = data.cast("B") res = bytearray() esc = False for b in data: if esc: - if b == 0x7f: - res.append(0x7f) + if b == 0x7F: + res.append(0x7F) else: - res.append(b & 0x3f) + res.append(b & 0x3F) esc = False - elif b == 0x7f: + elif b == 0x7F: esc = True else: res.append(b) await super().decode(res) - async def encode(self, data : ProtocolLayer.Packet) -> None: + async def encode(self, data: ProtocolLayer.Packet) -> None: if isinstance(data, str): data = data.encode() elif isinstance(data, memoryview): - data = data.cast('B') + data = data.cast("B") res = bytearray() for b in data: if b < 0x20: - res += bytearray([0x7f, b | 0x40]) - elif b == 0x7f: - res += bytearray([0x7f, 0x7f]) + res += bytearray([0x7F, b | 0x40]) + elif b == 0x7F: + res += bytearray([0x7F, 0x7F]) else: res.append(b) @@ -292,74 +303,83 @@ def mtu(self) -> int | None: return max(1, int(m / 2)) - class TerminalLayer(ProtocolLayer): - ''' + """ A ProtocolLayer that encodes debug messages in terminal escape codes. Non-debug messages are passed through unchanged. - ''' - - name = 'term' - start = b'\x1b_' # APC - end = b'\x1b\\' # ST + """ - def __init__(self, fdout : str | int | ProtocolLayer.Callback | None=1, ignoreEscapesTillFirstEncode : bool=True, *args, **kwargs): + name = "term" + start = b"\x1b_" # APC + end = b"\x1b\\" # ST + + def __init__( + self, + fdout: str | int | ProtocolLayer.Callback | None = 1, + ignoreEscapesTillFirstEncode: bool = True, + *args, + **kwargs, + ): super().__init__(*args, **kwargs) - self.fdout : ProtocolLayer.Callback | None = None + self.fdout: ProtocolLayer.Callback | None = None if isinstance(fdout, str): fdout = int(fdout) if isinstance(fdout, int): if fdout == 1: + def fdout_stdout(x): if isinstance(x, (bytes, bytearray)): - x = x.decode(errors='replace') + x = x.decode(errors="replace") elif isinstance(x, memoryview): - x = x.tobytes().decode(errors='replace') + x = x.tobytes().decode(errors="replace") sys.stdout.write(x) sys.stdout.flush() + self.fdout = fdout_stdout else: + def fdout_stderr(x): if isinstance(x, (bytes, bytearray)): - x = x.decode(errors='replace') + x = x.decode(errors="replace") elif isinstance(x, memoryview): - x = x.tobytes().decode(errors='replace') + x = x.tobytes().decode(errors="replace") sys.stderr.write(x) sys.stderr.flush() + self.fdout = fdout_stderr else: self.fdout = fdout - self._data : bytearray = bytearray() - self._inMsg : bool = False - self._ignoreEscape : bool = ignoreEscapesTillFirstEncode + self._data: bytearray = bytearray() + self._inMsg: bool = False + self._ignoreEscape: bool = ignoreEscapesTillFirstEncode - async def non_debug_data(self, data : ProtocolLayer.Packet) -> None: + async def non_debug_data(self, data: ProtocolLayer.Packet) -> None: if len(data) > 0: if self.fdout is not None: self.fdout(data) - async def encode(self, data : ProtocolLayer.Packet) -> None: + async def encode(self, data: ProtocolLayer.Packet) -> None: if isinstance(data, str): data = data.encode() self._ignoreEscape = False await super().encode(self.start + bytes(data) + self.end) - async def inject(self, data : ProtocolLayer.Packet) -> None: - ''' + async def inject(self, data: ProtocolLayer.Packet) -> None: + """ Inject non-debug data down the stack. - ''' + """ if isinstance(data, str): data = data.encode() await super().encode(data) - async def decode(self, data : ProtocolLayer.Packet) -> None: - if data == b'': + async def decode(self, data: ProtocolLayer.Packet) -> None: + if data == b"": return if self._ignoreEscape and not self._inMsg: await self.non_debug_data(data) @@ -376,8 +396,8 @@ async def decode(self, data : ProtocolLayer.Packet) -> None: while True: if not self._inMsg: c = self._data.split(self.start, 1) - if c[0] != b'': - self.logger.debug('non-debug %s', bytes(c[0])) + if c[0] != b"": + self.logger.debug("non-debug %s", bytes(c[0])) await self.non_debug_data(c[0]) if len(c) == 1: @@ -397,8 +417,8 @@ async def decode(self, data : ProtocolLayer.Packet) -> None: # Got a full message. # Remove \r as they can be inserted automatically by Windows. # If \r is meant to be sent, escape it. - msg = c[0].replace(b'\r', b'') - self.logger.debug('extracted %s', bytes(msg)) + msg = c[0].replace(b"\r", b"") + self.logger.debug("extracted %s", bytes(msg)) self._data = c[1] self._inMsg = False await super().decode(msg) @@ -415,21 +435,26 @@ def mtu(self) -> int | None: return max(1, m - len(self.start) - len(self.end)) - class PubTerminalLayer(TerminalLayer): """ A TerminalLayer (term), that also forwards all non-debug data over a PUB socket. """ - name = 'pubterm' + name = "pubterm" default_port = lprot.default_port + 1 - def __init__(self, bind : str=f'*:{default_port}', *args, context : zmq.asyncio.Context | None=None, **kwargs): + def __init__( + self, + bind: str = f"*:{default_port}", + *args, + context: zmq.asyncio.Context | None = None, + **kwargs, + ): super().__init__(*args, **kwargs) - self._context : zmq.asyncio.Context = context or zmq.asyncio.Context.instance() - self._socket : zmq.asyncio.Socket | None = self.context.socket(zmq.PUB) + self._context: zmq.asyncio.Context = context or zmq.asyncio.Context.instance() + self._socket: zmq.asyncio.Socket | None = self.context.socket(zmq.PUB) assert self._socket is not None - self._socket.bind(f'tcp://{bind}') + self._socket.bind(f"tcp://{bind}") @property def context(self) -> zmq.asyncio.Context: @@ -446,42 +471,44 @@ async def close(self) -> None: await super().close() - async def non_debug_data(self, data : ProtocolLayer.Packet) -> None: + async def non_debug_data(self, data: ProtocolLayer.Packet) -> None: await super().non_debug_data(data) if len(data) > 0 and self._socket is not None: await self._socket.send(data) class ReqRepCheckLayer(ProtocolLayer): - ''' + """ A ProtocolLayer that checks that requests and replies are matched. It triggers timeout() when a reply is not received in time. - ''' + """ - name = 'reqrepcheck' + name = "reqrepcheck" - def __init__(self, timeout_s : float = 1, *args, **kwargs): + def __init__(self, timeout_s: float = 1, *args, **kwargs): super().__init__(*args, **kwargs) - self._req : bool = False - self._timeout_s : float = timeout_s - self._retransmit_time : float = 0 - self._retransmitter : asyncio.Task | None = asyncio.create_task(self._retransmitter_task(), name=self.__class__.__name__) + self._req: bool = False + self._timeout_s: float = timeout_s + self._retransmit_time: float = 0 + self._retransmitter: asyncio.Task | None = asyncio.create_task( + self._retransmitter_task(), name=self.__class__.__name__ + ) @property def timeout_s(self) -> float: return self._timeout_s @timeout_s.setter - def timeout_s(self, value : float) -> None: + def timeout_s(self, value: float) -> None: if not self._req: self._retransmit_time = time.time() + value self._timeout_s = value @property def req(self) -> bool: - ''' + """ Return if we are currently waiting for a reply to a request. - ''' + """ return self._req async def _retransmitter_task(self) -> None: @@ -518,17 +545,19 @@ async def close(self) -> None: await super().close() - async def encode(self, data : ProtocolLayer.Packet) -> None: + async def encode(self, data: ProtocolLayer.Packet) -> None: if self._req: - raise RuntimeError('ReqRepCheckLayer encode called while previous request not yet handled') + raise RuntimeError( + "ReqRepCheckLayer encode called while previous request not yet handled" + ) self._req = True self._retransmit_time = time.time() + self._timeout_s await super().encode(data) - async def decode(self, data : ProtocolLayer.Packet) -> None: + async def decode(self, data: ProtocolLayer.Packet) -> None: if not self._req: - self.logger.debug('Ignoring unexpected rep %s', data) + self.logger.debug("Ignoring unexpected rep %s", data) return self._req = False @@ -539,41 +568,40 @@ async def disconnected(self) -> None: self._req = False - class SegmentationLayer(ProtocolLayer): - ''' + """ A ProtocolLayer that segments and reassembles data according to a simple protocol with 'E' (end) and 'C' (continue) markers. - ''' + """ - name = 'segment' - end = b'E' - cont = b'C' + name = "segment" + end = b"E" + cont = b"C" - def __init__(self, mtu : str | int | None=None, *args, **kwargs): + def __init__(self, mtu: str | int | None = None, *args, **kwargs): super().__init__(*args, **kwargs) if isinstance(mtu, str): - mtu = int(mtu) + mtu = int(mtu) self._mtu = mtu if mtu is not None and mtu > 0 else None self._buffer = bytearray() - async def decode(self, data : ProtocolLayer.Packet) -> None: + async def decode(self, data: ProtocolLayer.Packet) -> None: if isinstance(data, str): data = data.encode() elif isinstance(data, memoryview): - data = data.cast('B') + data = data.cast("B") self._buffer += data[:-1] if data[-1:] == self.end: - self.logger.debug('reassembled %s', bytes(self._buffer)) + self.logger.debug("reassembled %s", bytes(self._buffer)) await super().decode(self._buffer) self._buffer = bytearray() - async def encode(self, data : ProtocolLayer.Packet) -> None: + async def encode(self, data: ProtocolLayer.Packet) -> None: if isinstance(data, str): data = data.encode() elif isinstance(data, memoryview): - data = bytearray(data.cast('B')) + data = bytearray(data.cast("B")) mtu = self._mtu if self._mtu is None: @@ -584,18 +612,18 @@ async def encode(self, data : ProtocolLayer.Packet) -> None: mtu = max(1, mtu - 1) for i in range(0, len(data), mtu): if i + mtu >= len(data): - await super().encode(data[i:i+mtu] + self.end) + await super().encode(data[i : i + mtu] + self.end) else: - await super().encode(data[i:i+mtu] + self.cont) + await super().encode(data[i : i + mtu] + self.cont) @property def mtu(self) -> int | None: return None async def timeout(self) -> None: - ''' + """ A DebugArqLayer below us is going to do a retransmit. Clear the buffer. - ''' + """ self._buffer = bytearray() await super().timeout() @@ -604,30 +632,29 @@ async def disconnected(self) -> None: self._buffer = bytearray() - class DebugArqLayer(ProtocolLayer): - ''' + """ A ProtocolLayer that implements a simple ARQ protocol for debugging. - ''' + """ - name = 'arq' + name = "arq" reset_flag = 0x80 def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._req : bool = False - self._request : list[bytes] = [] - self._reset : bool = True - self._syncing : bool = False - self._decode_seq : int = 1 - self._decode_seq_start : int = self._decode_seq - self._encode_lock : asyncio.Lock = asyncio.Lock() - - async def decode(self, data : ProtocolLayer.Packet) -> None: + self._req: bool = False + self._request: list[bytes] = [] + self._reset: bool = True + self._syncing: bool = False + self._decode_seq: int = 1 + self._decode_seq_start: int = self._decode_seq + self._encode_lock: asyncio.Lock = asyncio.Lock() + + async def decode(self, data: ProtocolLayer.Packet) -> None: if isinstance(data, str): data = data.encode() elif isinstance(data, memoryview): - data = data.cast('B') + data = data.cast("B") if len(data) == 0: return @@ -647,7 +674,7 @@ async def decode(self, data : ProtocolLayer.Packet) -> None: if len(msg) > 0: await super().decode(msg) else: - self.logger.debug(f'unexpected seq {seq} instead of {self._decode_seq}; dropped') + self.logger.debug(f"unexpected seq {seq} instead of {self._decode_seq}; dropped") if self._syncing and data[0] == self.reset_flag: self._syncing = False @@ -655,26 +682,26 @@ async def decode(self, data : ProtocolLayer.Packet) -> None: await self._encode(r) @staticmethod - def decode_seq(data : bytes | bytearray | memoryview) -> tuple[int, memoryview[int]]: + def decode_seq(data: bytes | bytearray | memoryview) -> tuple[int, memoryview[int]]: if isinstance(data, memoryview): - data = data.cast('B') + data = data.cast("B") seq = 0 if len(data) == 0: raise ValueError - seq = data[0] & 0x3f + seq = data[0] & 0x3F if data[0] & 0x40: if len(data) == 1: raise ValueError - seq = (seq << 7) | data[1] & 0x7f + seq = (seq << 7) | data[1] & 0x7F if data[1] & 0x80: if len(data) == 2: raise ValueError - seq = (seq << 7) | data[2] & 0x7f + seq = (seq << 7) | data[2] & 0x7F if data[2] & 0x80: if len(data) == 3: raise ValueError - seq = (seq << 7) | data[3] & 0x7f + seq = (seq << 7) | data[3] & 0x7F if data[3] & 0x80: raise ValueError return (seq, memoryview(data)[4:]) @@ -686,31 +713,29 @@ def decode_seq(data : bytes | bytearray | memoryview) -> tuple[int, memoryview[i return (seq, memoryview(data)[1:]) @staticmethod - def encode_seq(seq : int) -> bytes: + def encode_seq(seq: int) -> bytes: if seq < 0x40: - return bytes([seq & 0x3f]) + return bytes([seq & 0x3F]) if seq < 0x2000: - return bytes([ - 0x40 | ((seq >> 7) & 0x3f), - seq & 0x7f]) + return bytes([0x40 | ((seq >> 7) & 0x3F), seq & 0x7F]) if seq < 0x100000: - return bytes([ - 0x40 | ((seq >> 14) & 0x3f), - 0x80 | ((seq >> 7) & 0x7f), - seq & 0x7f]) + return bytes([0x40 | ((seq >> 14) & 0x3F), 0x80 | ((seq >> 7) & 0x7F), seq & 0x7F]) if seq < 0x8000000: - return bytes([ - 0x40 | ((seq >> 21) & 0x3f), - 0x80 | ((seq >> 14) & 0x7f), - 0x80 | ((seq >> 7) & 0x7f), - seq & 0x7f]) + return bytes( + [ + 0x40 | ((seq >> 21) & 0x3F), + 0x80 | ((seq >> 14) & 0x7F), + 0x80 | ((seq >> 7) & 0x7F), + seq & 0x7F, + ] + ) return DebugArqLayer.encode_seq(seq % 0x8000000) - async def encode(self, data : ProtocolLayer.Packet) -> None: + async def encode(self, data: ProtocolLayer.Packet) -> None: if isinstance(data, str): data = data.encode() elif isinstance(data, memoryview): - data = data.cast('B') + data = data.cast("B") if self._reset: self._reset = False @@ -732,7 +757,7 @@ async def encode(self, data : ProtocolLayer.Packet) -> None: if not self._syncing: await self._encode(request) - async def _encode(self, data : ProtocolLayer.Packet) -> None: + async def _encode(self, data: ProtocolLayer.Packet) -> None: async with self._encode_lock: await super().encode(data) @@ -752,7 +777,7 @@ async def connected(self) -> None: await super().connected() async def retransmit(self) -> None: - self.logger.debug('retransmit') + self.logger.debug("retransmit") if not self._req: self._decode_seq = self._decode_seq_start @@ -766,7 +791,7 @@ async def timeout(self) -> None: await self.retransmit() @staticmethod - def next_seq(seq : int) -> int: + def next_seq(seq: int) -> int: seq = (seq + 1) % 0x8000000 return 1 if seq == 0 else seq @@ -778,25 +803,26 @@ def mtu(self) -> int | None: return max(1, m - 4) - class ArqLayer(ProtocolLayer): - ''' + """ A ProtocolLayer that implements a general-purpose ARQ protocol. - ''' + """ - name = 'Arq' + name = "Arq" nop_flag = 0x40 ack_flag = 0x80 - seq_mask = 0x3f + seq_mask = 0x3F - def __init__(self, timeout_s : float | None=None, *args, keep_alive_s : float | None=None, **kwargs): + def __init__( + self, timeout_s: float | None = None, *args, keep_alive_s: float | None = None, **kwargs + ): super().__init__(*args, **kwargs) - self._encode_lock : asyncio.Lock = asyncio.Lock() - self._retransmitter : asyncio.Task | None = None - self._keep_alive : asyncio.Task | None = None - self._timeout_s : float | None = None - self._keep_alive_s : float | None = None + self._encode_lock: asyncio.Lock = asyncio.Lock() + self._retransmitter: asyncio.Task | None = None + self._keep_alive: asyncio.Task | None = None + self._timeout_s: float | None = None + self._keep_alive_s: float | None = None self._reset() self.timeout_s = timeout_s self.keep_alive_s = keep_alive_s @@ -806,21 +832,23 @@ def timeout_s(self) -> float | None: return self._timeout_s @timeout_s.setter - def timeout_s(self, value : float | None) -> None: + def timeout_s(self, value: float | None) -> None: self._timeout_s = value if value is None and self._retransmitter is not None: self._retransmitter.cancel() self._retransmitter = None elif value is not None and self._retransmitter is None: - self._retransmitter = asyncio.create_task(self._retransmitter_task(), name=self.__class__.__name__) + self._retransmitter = asyncio.create_task( + self._retransmitter_task(), name=self.__class__.__name__ + ) @property def keep_alive_s(self) -> float | None: return self._keep_alive_s @keep_alive_s.setter - def keep_alive_s(self, value : float | None) -> None: + def keep_alive_s(self, value: float | None) -> None: self._keep_alive_s = value if value is not None and self.timeout_s is None: @@ -830,25 +858,27 @@ def keep_alive_s(self, value : float | None) -> None: self._keep_alive.cancel() self._keep_alive = None elif value is not None and self._keep_alive is None: - self._keep_alive = asyncio.create_task(self._keep_alive_task(), name=self.__class__.__name__) + self._keep_alive = asyncio.create_task( + self._keep_alive_task(), name=self.__class__.__name__ + ) def _reset(self) -> None: - self._encode_queue : list[bytes] = [bytes([self.nop_flag])] - self._send_seq : int = self._next_seq(0) - self._recv_seq : int = 0 - self._sent : bool = False - self._pause_transmit : bool = False - self._t_sent : float = time.time() - - async def decode(self, data : ProtocolLayer.Packet) -> None: + self._encode_queue: list[bytes] = [bytes([self.nop_flag])] + self._send_seq: int = self._next_seq(0) + self._recv_seq: int = 0 + self._sent: bool = False + self._pause_transmit: bool = False + self._t_sent: float = time.time() + + async def decode(self, data: ProtocolLayer.Packet) -> None: if isinstance(data, str): data = data.encode() if isinstance(data, memoryview): - data = data.cast('B') + data = data.cast("B") else: - data = memoryview(data).cast('B') + data = memoryview(data).cast("B") - resp = b'' + resp = b"" reset_handshake = False do_transmit = False do_decode = False @@ -876,7 +906,7 @@ async def decode(self, data : ProtocolLayer.Packet) -> None: elif hdr_seq == 0: # Reset handshake. resp += bytes([self.ack_flag]) - data = b'' + data = b"" if not reset_handshake: self._reset() @@ -896,10 +926,10 @@ async def decode(self, data : ProtocolLayer.Packet) -> None: data = data[1:] else: # Already decoded. - data = b'' + data = b"" else: # Drop. - data = b'' + data = b"" do_transmit = True if do_decode: @@ -919,7 +949,7 @@ async def decode(self, data : ProtocolLayer.Packet) -> None: def waiting_for_ack(self) -> bool: return len(self._encode_queue) > 0 and self._sent - async def encode(self, data : ProtocolLayer.Packet) -> None: + async def encode(self, data: ProtocolLayer.Packet) -> None: if len(data) == 0: return @@ -928,27 +958,27 @@ async def encode(self, data : ProtocolLayer.Packet) -> None: if is_idle and not self._pause_transmit: await self._transmit() - def _push_encode_queue(self, data : ProtocolLayer.Packet) -> None: + def _push_encode_queue(self, data: ProtocolLayer.Packet) -> None: if isinstance(data, str): data = data.encode() elif isinstance(data, memoryview): - data = data.cast('B') + data = data.cast("B") self._encode_queue.append(bytes([self._send_seq]) + data) self._send_seq = self._next_seq(self._send_seq) - def _next_seq(self, seq : int) -> int: + def _next_seq(self, seq: int) -> int: seq = (seq + 1) & self.seq_mask if seq == 0: seq = 1 return seq - async def _transmit(self, prefix : bytes = b'') -> bool: + async def _transmit(self, prefix: bytes = b"") -> bool: async with self._encode_lock: self._t_sent = time.time() if len(self._encode_queue) == 0: - if prefix == b'': + if prefix == b"": return False await super().encode(prefix) return True @@ -964,7 +994,7 @@ async def connected(self) -> None: await super().connected() async def retransmit(self) -> None: - self.logger.debug('retransmit') + self.logger.debug("retransmit") await self._transmit() async def timeout(self) -> None: @@ -1009,7 +1039,7 @@ async def keep_alive(self) -> None: await self._transmit() return - self.logger.debug('keep alive') + self.logger.debug("keep alive") self._encode_queue.append(bytes([self._send_seq | self.nop_flag])) self._send_seq = self._next_seq(self._send_seq) await self._transmit() @@ -1033,41 +1063,40 @@ async def _keep_alive_task(self) -> None: raise - class Crc8Layer(ProtocolLayer): - ''' + """ ProtocolLayer to add and check integrity using a CRC8. - ''' + """ - name = 'crc8' + name = "crc8" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._crc = crcmod.mkCrcFun(0x1a6, 0xff, False, 0) + self._crc = crcmod.mkCrcFun(0x1A6, 0xFF, False, 0) - async def encode(self, data : ProtocolLayer.Packet) -> None: + async def encode(self, data: ProtocolLayer.Packet) -> None: if isinstance(data, str): data = data.encode() await super().encode(bytearray(data) + bytes([self.crc(data)])) - async def decode(self, data : ProtocolLayer.Packet) -> None: + async def decode(self, data: ProtocolLayer.Packet) -> None: if isinstance(data, str): data = data.encode() elif isinstance(data, memoryview): - data = data.cast('B') + data = data.cast("B") if len(data) == 0: return if self.crc(data[0:-1]) != data[-1]: -# self.logger.debug('invalid CRC, dropped ' + str(bytes(data))) + # self.logger.debug('invalid CRC, dropped ' + str(bytes(data))) return - self.logger.debug('valid CRC %s', bytes(data)) + self.logger.debug("valid CRC %s", bytes(data)) await super().decode(data[0:-1]) - def crc(self, data : ProtocolLayer.Packet) -> int: + def crc(self, data: ProtocolLayer.Packet) -> int: return self._crc(data) @property @@ -1079,41 +1108,40 @@ def mtu(self) -> int | None: return min(256, max(1, m - 1)) - class Crc16Layer(ProtocolLayer): - ''' + """ ProtocolLayer to add and check integrity using a CRC16. - ''' + """ - name = 'crc16' + name = "crc16" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._crc = crcmod.mkCrcFun(0x1baad, 0xffff, False, 0) + self._crc = crcmod.mkCrcFun(0x1BAAD, 0xFFFF, False, 0) async def encode(self, data: ProtocolLayer.Packet) -> None: if isinstance(data, str): data = data.encode() - await super().encode(bytearray(data) + struct.pack('>H', self.crc(data))) + await super().encode(bytearray(data) + struct.pack(">H", self.crc(data))) async def decode(self, data: ProtocolLayer.Packet) -> None: if isinstance(data, str): data = data.encode() elif isinstance(data, memoryview): - data = data.cast('B') + data = data.cast("B") if len(data) < 2: return - if self.crc(data[0:-2]) != struct.unpack('>H', data[-2:])[0]: -# self.logger.debug('invalid CRC, dropped ' + str(bytes(data))) + if self.crc(data[0:-2]) != struct.unpack(">H", data[-2:])[0]: + # self.logger.debug('invalid CRC, dropped ' + str(bytes(data))) return - self.logger.debug('valid CRC %s', bytes(data)) + self.logger.debug("valid CRC %s", bytes(data)) await super().decode(data[0:-2]) - def crc(self, data : ProtocolLayer.Packet) -> int: + def crc(self, data: ProtocolLayer.Packet) -> int: return self._crc(data) @property @@ -1126,39 +1154,39 @@ def mtu(self) -> int | None: class Crc32Layer(ProtocolLayer): - ''' + """ ProtocolLayer to add and check integrity using a CRC32. - ''' + """ - name = 'crc32' + name = "crc32" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._crc = crcmod.mkCrcFun(0x104c11db7, 0, True, 0xffffffff) + self._crc = crcmod.mkCrcFun(0x104C11DB7, 0, True, 0xFFFFFFFF) async def encode(self, data: ProtocolLayer.Packet) -> None: if isinstance(data, str): data = data.encode() - await super().encode(bytearray(data) + struct.pack('>I', self.crc(data))) + await super().encode(bytearray(data) + struct.pack(">I", self.crc(data))) async def decode(self, data: ProtocolLayer.Packet) -> None: if isinstance(data, str): data = data.encode() elif isinstance(data, memoryview): - data = data.cast('B') + data = data.cast("B") if len(data) < 4: return - if self.crc(data[0:-4]) != struct.unpack('>I', data[-4:])[0]: -# self.logger.debug('invalid CRC, dropped ' + str(bytes(data))) + if self.crc(data[0:-4]) != struct.unpack(">I", data[-4:])[0]: + # self.logger.debug('invalid CRC, dropped ' + str(bytes(data))) return - self.logger.debug('valid CRC %s', bytes(data)) + self.logger.debug("valid CRC %s", bytes(data)) await super().decode(data[0:-4]) - def crc(self, data : ProtocolLayer.Packet) -> int: + def crc(self, data: ProtocolLayer.Packet) -> int: return self._crc(data) @property @@ -1169,18 +1197,17 @@ def mtu(self) -> int | None: return max(1, m - 1) - class ProtocolStack(ProtocolLayer): - ''' + """ Composition of a stack of layers. The given stack assumes that the layers are already wrapped. At index 0, the application layer is expected, at index -1 the physical layer. - ''' + """ - name = 'stack' + name = "stack" - def __init__(self, layers : list[ProtocolLayer], *args, **kwargs): + def __init__(self, layers: list[ProtocolLayer], *args, **kwargs): super().__init__(*args, **kwargs) if layers == []: layers = [ProtocolLayer()] @@ -1188,10 +1215,10 @@ def __init__(self, layers : list[ProtocolLayer], *args, **kwargs): self._layers[-1].down = super().encode self._layers[0].up = super().decode - async def encode(self, data : ProtocolLayer.Packet) -> None: + async def encode(self, data: ProtocolLayer.Packet) -> None: await self._layers[0].encode(data) - async def decode(self, data : ProtocolLayer.Packet) -> None: + async def decode(self, data: ProtocolLayer.Packet) -> None: await self._layers[-1].decode(data) def last_activity(self) -> float: @@ -1201,7 +1228,7 @@ async def close(self) -> None: try: await self._layers[0].close() except BaseException as e: - self.logger.warning(f'Exception while closing: {e}') + self.logger.warning(f"Exception while closing: {e}") await super().close() @@ -1245,106 +1272,106 @@ def __iter__(self): return self.Iterator(self) - @typing.overload -def stack(layers : list[ProtocolLayer], /) -> ProtocolLayer: ... +def stack(layers: list[ProtocolLayer], /) -> ProtocolLayer: ... @typing.overload -def stack(layers : ProtocolLayer, /, *args) -> ProtocolLayer: ... +def stack(layers: ProtocolLayer, /, *args) -> ProtocolLayer: ... -def stack(layers : list[ProtocolLayer] | ProtocolLayer, /, *args) -> ProtocolLayer: - ''' + +def stack(layers: list[ProtocolLayer] | ProtocolLayer, /, *args) -> ProtocolLayer: + """ Create a ProtocolStack from a list of layers. - ''' + """ if isinstance(layers, ProtocolLayer): layers = [layers] + list(args) elif len(args) > 0: - raise ValueError('When layers is a list, no additional arguments are allowed') + raise ValueError("When layers is a list, no additional arguments are allowed") for i in range(len(layers) - 1): layers[i + 1].wrap(layers[i]) return ProtocolStack(layers) - class LoopbackLayer(ProtocolLayer): - ''' + """ A ProtocolLayer that loops back all data. - ''' + """ - name = 'loop' + name = "loop" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - async def encode(self, data : ProtocolLayer.Packet) -> None: + async def encode(self, data: ProtocolLayer.Packet) -> None: await self.decode(data) await super().encode(data) - class RawLayer(ProtocolLayer): - ''' + """ A ProtocolLayer that just forwards raw data as bytes. - ''' - name = 'raw' + """ + + name = "raw" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - async def encode(self, data : ProtocolLayer.Packet) -> None: + async def encode(self, data: ProtocolLayer.Packet) -> None: if isinstance(data, str): data = data.encode() await super().encode(data) - async def decode(self, data : ProtocolLayer.Packet) -> None: + async def decode(self, data: ProtocolLayer.Packet) -> None: if isinstance(data, str): data = data.encode() await super().decode(data) - class MuxLayer(ProtocolLayer): - ''' + """ A ProtocolLayer that multiplexes data to different upper layers based on an channel identifier. - ''' + """ - name = 'mux' - esc = 0x10 # DLE + name = "mux" + esc = 0x10 # DLE repeat = 0x15 # NAK - def __init__(self, default : int | str=0, repeat_interval : float=1, *args, **kwargs): + def __init__(self, default: int | str = 0, repeat_interval: float = 1, *args, **kwargs): super().__init__(*args, **kwargs) - self._layers : dict[int, ProtocolLayer] = {} + self._layers: dict[int, ProtocolLayer] = {} self._default = int(default, 0) if isinstance(default, str) else default self._repeat_interval = repeat_interval - self._prev : int | None = None - self._t_prev : float = 0 - self._decoding : int | None = None - self._decoding_esc : bool = False + self._prev: int | None = None + self._t_prev: float = 0 + self._decoding: int | None = None + self._decoding_esc: bool = False self.set(self._default, self) - def set(self, chan : int, layer : ProtocolLayer) -> None: - ''' + def set(self, chan: int, layer: ProtocolLayer) -> None: + """ Register an upper layer with a given channel identifier. - ''' + """ if chan < 0 or chan > 255: - raise ValueError('chan must be a single byte') + raise ValueError("chan must be a single byte") if chan in self._layers: - raise ValueError(f'channel {chan} already registered') + raise ValueError(f"channel {chan} already registered") if chan == self.esc or chan == self.repeat: - raise ValueError(f'channel {chan} is reserved') + raise ValueError(f"channel {chan} is reserved") self._layers[chan] = layer if layer is not self: + async def _encode(data): await self._encode(chan, data) + layer.down = _encode - def reset(self, chan : int) -> None: - ''' + def reset(self, chan: int) -> None: + """ Unregister an upper layer from a given channel identifier. - ''' + """ if chan not in self._layers: return @@ -1353,10 +1380,10 @@ def reset(self, chan : int) -> None: l.down = None del self._layers[chan] - async def encode(self, data : ProtocolLayer.Packet) -> None: + async def encode(self, data: ProtocolLayer.Packet) -> None: await self._encode(self._default, data) - async def _encode(self, chan : int, data : ProtocolLayer.Packet) -> None: + async def _encode(self, chan: int, data: ProtocolLayer.Packet) -> None: if isinstance(data, str): data = data.encode() elif isinstance(data, memoryview): @@ -1366,22 +1393,26 @@ async def _encode(self, chan : int, data : ProtocolLayer.Packet) -> None: return now = time.time() - prefix = b'' + prefix = b"" if self._decoding is None: prefix = bytes([self.esc, self.repeat]) - if self._prev is None or chan != self._prev or (now - self._t_prev) >= self._repeat_interval: + if ( + self._prev is None + or chan != self._prev + or (now - self._t_prev) >= self._repeat_interval + ): prefix = bytes([self.esc, chan]) self._t_prev = now self._prev = chan await super().encode(prefix + data.replace(bytes([self.esc]), bytes([self.esc, self.esc]))) - async def decode(self, data : ProtocolLayer.Packet) -> None: + async def decode(self, data: ProtocolLayer.Packet) -> None: if isinstance(data, str): data = data.encode() if not isinstance(data, memoryview): data = memoryview(data) - data = data.cast('B') + data = data.cast("B") start = 0 for i in range(len(data)): @@ -1406,13 +1437,13 @@ async def decode(self, data : ProtocolLayer.Packet) -> None: if not self._decoding_esc and start < len(data): await self._dispatch(data[start:]) - async def _dispatch(self, data : bytes | memoryview) -> None: + async def _dispatch(self, data: bytes | memoryview) -> None: chan = self._decoding if chan is None: - self.logger.debug('Current decoding channel unknown, dropped %s', bytes(data)) + self.logger.debug("Current decoding channel unknown, dropped %s", bytes(data)) return if chan not in self._layers: - self.logger.debug('No protocol stack for channel %d, dropped %s', chan, bytes(data)) + self.logger.debug("No protocol stack for channel %d, dropped %s", chan, bytes(data)) return layer = self._layers[chan] @@ -1435,9 +1466,8 @@ async def disconnected(self) -> None: await super().disconnected() - class Aes256Layer(ProtocolLayer): - ''' + """ A ProtocolLayer that adds AES-256 encryption/decryption. The unified mode allows using the same cipher state for encryption and decryption, which uses a @@ -1446,35 +1476,42 @@ class Aes256Layer(ProtocolLayer): A specific mode is the reqrep mode, which is a unified mode, but changes the IV for every request/reply pair. This is useful for a ZeroMQ REQ/REP pattern, where the server may handle multiple clients simultaneously, which does not allow having a single cipher state. - ''' + """ - name = 'aes256' + name = "aes256" - def __init__(self, key : bytes | str | None=None, *args, unified : bool=False, reqrep : bool=False, **kwargs): + def __init__( + self, + key: bytes | str | None = None, + *args, + unified: bool = False, + reqrep: bool = False, + **kwargs, + ): super().__init__(*args, **kwargs) self._encrypt = None self._decrypt = None - self._reqrep : bool = reqrep - self._unified : bool = unified or reqrep + self._reqrep: bool = reqrep + self._unified: bool = unified or reqrep if key is None: - raise ValueError('Key file or binary string must be provided for Aes256Layer') + raise ValueError("Key file or binary string must be provided for Aes256Layer") self.set_key(key) - def set_key(self, key : bytes | str) -> None: - ''' + def set_key(self, key: bytes | str) -> None: + """ Change the AES-256 key. The argument can be either a 32 byte binary string, or a filename containing the key. - ''' + """ if isinstance(key, str): - with open(key, 'rb') as f: + with open(key, "rb") as f: key = f.read() if len(key) != 32: - raise ValueError('Key must be 32 bytes for AES-256') + raise ValueError("Key must be 32 bytes for AES-256") self._key = key self._encrypt = None @@ -1482,16 +1519,16 @@ def set_key(self, key : bytes | str) -> None: @property def unified(self) -> bool: - ''' + """ Return if unified mode is enabled. - ''' + """ return self._unified or self._reqrep @unified.setter - def unified(self, enable : bool=True) -> None: - ''' + def unified(self, enable: bool = True) -> None: + """ Set unified mode. - ''' + """ self._unified = enable self._encrypt = None @@ -1502,16 +1539,16 @@ def unified(self, enable : bool=True) -> None: @property def reqrep(self) -> bool: - ''' + """ Return if reqrep mode is enabled. - ''' + """ return self._reqrep @reqrep.setter - def reqrep(self, enable : bool=True) -> None: - ''' + def reqrep(self, enable: bool = True) -> None: + """ Set reqrep mode. - ''' + """ self._reqrep = enable self._unified = enable @@ -1519,11 +1556,11 @@ def reqrep(self, enable : bool=True) -> None: if enable: self._decrypt = None - async def encode(self, data : ProtocolLayer.Packet) -> None: + async def encode(self, data: ProtocolLayer.Packet) -> None: if isinstance(data, str): data = data.encode() elif isinstance(data, memoryview): - data = data.cast('B') + data = data.cast("B") data = Crypto.Util.Padding.pad(data, 16) @@ -1538,40 +1575,42 @@ async def encode(self, data : ProtocolLayer.Packet) -> None: data = prefix + data await super().encode(data) - async def decode(self, data : ProtocolLayer.Packet) -> None: + async def decode(self, data: ProtocolLayer.Packet) -> None: if isinstance(data, str): data = data.encode() elif isinstance(data, memoryview): - data = data.cast('B') + data = data.cast("B") if len(data) > 16 and len(data) % 16 == 1: # Received IV for decryption iv = data[1:17] - cypher = Crypto.Cipher.AES.new(self._key, Crypto.Cipher.AES.MODE_CTR, nonce=b'', initial_value=iv) - if data[0:1] == b'U': - self.logger.debug('Received IV for unified operation') + cypher = Crypto.Cipher.AES.new( + self._key, Crypto.Cipher.AES.MODE_CTR, nonce=b"", initial_value=iv + ) + if data[0:1] == b"U": + self.logger.debug("Received IV for unified operation") self._decrypt = cypher self._encrypt = cypher self._unified = True self._reqrep = False - elif data[0:1] == b'B': - self.logger.debug('Received IV for decryption') + elif data[0:1] == b"B": + self.logger.debug("Received IV for decryption") self._decrypt = cypher self._unified = False self._reqrep = False else: - self.logger.debug('Invalid IV prefix') + self.logger.debug("Invalid IV prefix") self._decrypt = None data = data[17:] if self._decrypt is None: - self.logger.debug('Got data before IV, waiting for IV') + self.logger.debug("Got data before IV, waiting for IV") self._decrypt = None return if len(data) % 16 != 0: - self.logger.debug('Data length not multiple of 16, dropped') + self.logger.debug("Data length not multiple of 16, dropped") self._decrypt = None return @@ -1584,22 +1623,24 @@ async def decode(self, data : ProtocolLayer.Packet) -> None: try: data = Crypto.Util.Padding.unpad(data, 16) except ValueError: - self.logger.debug('Invalid padding, dropped') + self.logger.debug("Invalid padding, dropped") self._decrypt = None return await super().decode(data) - def _iv(self, unified : bool) -> bytes: + def _iv(self, unified: bool) -> bytes: iv = Crypto.Random.get_random_bytes(16) # Make sure not to wrap around the counter soon. - iv = bytes([iv[0] & 0x0f]) + iv[1:] - self._encrypt = Crypto.Cipher.AES.new(self._key, Crypto.Cipher.AES.MODE_CTR, nonce=b'', initial_value=iv) + iv = bytes([iv[0] & 0x0F]) + iv[1:] + self._encrypt = Crypto.Cipher.AES.new( + self._key, Crypto.Cipher.AES.MODE_CTR, nonce=b"", initial_value=iv + ) if unified: self._decrypt = self._encrypt - return b'U' + iv + return b"U" + iv else: - return b'B' + iv + return b"B" + iv async def connected(self) -> None: self._encrypt = None @@ -1607,8 +1648,7 @@ async def connected(self) -> None: await super().connected() - -layer_types : list[typing.Type[ProtocolLayer]] = [ +layer_types: list[typing.Type[ProtocolLayer]] = [ AsciiEscapeLayer, TerminalLayer, PubTerminalLayer, @@ -1625,46 +1665,51 @@ async def connected(self) -> None: Aes256Layer, ] -def register_layer_type(layer_type : typing.Type[ProtocolLayer]) -> None: - ''' + +def register_layer_type(layer_type: typing.Type[ProtocolLayer]) -> None: + """ Register a new protocol layer type. The layer type must be a subclass of ProtocolLayer with a unique name. - ''' + """ for lt in layer_types: if layer_type.name == lt.name: - raise ValueError(f'Layer type {layer_type.name} already registered') + raise ValueError(f"Layer type {layer_type.name} already registered") layer_types.append(layer_type) -def unregister_layer_type(layer_type : typing.Type[ProtocolLayer]) -> None: - ''' + +def unregister_layer_type(layer_type: typing.Type[ProtocolLayer]) -> None: + """ Unregister a protocol layer type. - ''' + """ for i, lt in enumerate(layer_types): if layer_type.name == lt.name: del layer_types[i] return - raise ValueError(f'Layer type {layer_type.name} not registered') + raise ValueError(f"Layer type {layer_type.name} not registered") -def get_layer_type(name : str) -> typing.Type[ProtocolLayer]: - ''' + +def get_layer_type(name: str) -> typing.Type[ProtocolLayer]: + """ Get a protocol layer type by name. - ''' + """ for lt in layer_types: if name == lt.name: return lt - raise ValueError(f'Unknown layer type {name}') + raise ValueError(f"Unknown layer type {name}") + def get_layer_types() -> list[typing.Type[ProtocolLayer]]: - ''' + """ Get the list of registered protocol layer types. - ''' + """ return layer_types.copy() -def build_stack(description : str) -> ProtocolLayer: - ''' + +def build_stack(description: str) -> ProtocolLayer: + """ Construct the protocol stack from a description. The description is a comma-separated string with layer ids. If the layer has @@ -1672,21 +1717,21 @@ def build_stack(description : str) -> ProtocolLayer: of the specified layers. Grammar: ( ( ``=`` ) ? ) (``,`` ( ``=`` ) ? ) * - ''' + """ - layers = description.split(',') + layers = description.split(",") if layers == []: # Dummy layer return ProtocolLayer() - stack : list[ProtocolLayer] = [] + stack: list[ProtocolLayer] = [] try: for l in layers: - name_arg = l.split('=') - if name_arg[0] == '': - raise ValueError(f'Missing layer type') + name_arg = l.split("=") + if name_arg[0] == "": + raise ValueError(f"Missing layer type") layer_type = get_layer_type(name_arg[0]) diff --git a/python/libstored/protocol/serial.py b/python/libstored/protocol/serial.py index b50c2733..a50edaf3 100644 --- a/python/libstored/protocol/serial.py +++ b/python/libstored/protocol/serial.py @@ -8,20 +8,23 @@ from . import protocol as lprot from . import util as lprot_util + class SerialLayer(lprot.ProtocolLayer): - name = 'serial' + name = "serial" - def __init__(self, *, drop_s : float | None=1, **kwargs): + def __init__(self, *, drop_s: float | None = 1, **kwargs): super().__init__() - self.logger.debug('Opening serial port %s', kwargs['port']) - self._serial : serial.Serial | None = None - self._writer : lprot_util.Writer | None = None + self.logger.debug("Opening serial port %s", kwargs["port"]) + self._serial: serial.Serial | None = None + self._writer: lprot_util.Writer | None = None self._encode_buffer = bytearray() - self._open : bool = True + self._open: bool = True - self._serial_task : asyncio.Task | None = asyncio.create_task(self._serial_run(drop_s, kwargs), name=self.__class__.__name__) + self._serial_task: asyncio.Task | None = asyncio.create_task( + self._serial_run(drop_s, kwargs), name=self.__class__.__name__ + ) @property def open(self) -> bool: @@ -29,42 +32,46 @@ def open(self) -> bool: def _read(self): if not self._open: - raise RuntimeError('Serial port closed') + raise RuntimeError("Serial port closed") assert self._serial is not None data = self._serial.read(max(1, self._serial.in_waiting)) - self.logger.debug('received %s', data) + self.logger.debug("received %s", data) return data - def _write(self, data : bytes) -> None: + def _write(self, data: bytes) -> None: if not self._open: - raise RuntimeError('Serial port closed') + raise RuntimeError("Serial port closed") assert self._serial is not None - self.logger.debug('send %s', data) + self.logger.debug("send %s", data) cnt = self._serial.write(data) assert cnt == len(data) self._serial.flush() - async def _serial_run(self, drop_s : float | None, serial_args : dict) -> None: + async def _serial_run(self, drop_s: float | None, serial_args: dict) -> None: try: # Only access _serial within the current asyncio loop. self._serial = await self._serial_open(**serial_args) - self.logger.debug('Serial port %s opened', serial_args['port']) + self.logger.debug("Serial port %s opened", serial_args["port"]) if drop_s is not None and drop_s > 0: await asyncio.sleep(drop_s) if self._serial.in_waiting > 0: data = self._serial.read(self._serial.in_waiting) - self.logger.debug('Flushing initial data: %s', data) + self.logger.debug("Flushing initial data: %s", data) # Only access self._serial read/write in reader/writer threads. - async with lprot_util.Writer(self._write, thread_name=f'{self.__class__.__name__}-writer') as writer: - async with lprot_util.Reader(self._read, thread_name=f'{self.__class__.__name__}-reader') as reader: + async with lprot_util.Writer( + self._write, thread_name=f"{self.__class__.__name__}-writer" + ) as writer: + async with lprot_util.Reader( + self._read, thread_name=f"{self.__class__.__name__}-reader" + ) as reader: try: if self._encode_buffer: - self.logger.debug('sending buffered %s', self._encode_buffer) + self.logger.debug("sending buffered %s", self._encode_buffer) await self._encode(self._encode_buffer) self._encode_buffer = bytearray() @@ -94,9 +101,9 @@ async def _serial_run(self, drop_s : float | None, serial_args : dict) -> None: if self._serial is not None: self._serial.close() self._serial = None - self.logger.debug('Closed serial port') + self.logger.debug("Closed serial port") - async def _serial_open(self, timeout_s : int=60, **kwargs) -> serial.Serial: + async def _serial_open(self, timeout_s: int = 60, **kwargs) -> serial.Serial: last_e = TimeoutError() for i in range(0, timeout_s): try: @@ -110,7 +117,7 @@ async def _serial_open(self, timeout_s : int=60, **kwargs) -> serial.Serial: # For unclear reasons, Windows sometimes reports the port as # being in use. That issue seems to clear automatically after # a while. - if 'PermissionError' not in str(e): + if "PermissionError" not in str(e): raise last_e = e @@ -118,23 +125,23 @@ async def _serial_open(self, timeout_s : int=60, **kwargs) -> serial.Serial: await asyncio.sleep(1) raise last_e - async def encode(self, data : lprot.ProtocolLayer.Packet) -> None: + async def encode(self, data: lprot.ProtocolLayer.Packet) -> None: if isinstance(data, str): data = data.encode() elif isinstance(data, memoryview): - data = data.cast('B') + data = data.cast("B") if not self.open: - self.logger.debug('Serial port closed; dropping data %s', data) + self.logger.debug("Serial port closed; dropping data %s", data) elif self._writer is None: - self.logger.debug('buffering %s', data) + self.logger.debug("buffering %s", data) self._encode_buffer += data else: await self._encode(data) await super().encode(data) - async def _encode(self, data : lprot.ProtocolLayer.Packet) -> None: + async def _encode(self, data: lprot.ProtocolLayer.Packet) -> None: if len(data) == 0: return @@ -160,4 +167,5 @@ async def close(self): await super().close() + lprot.register_layer_type(SerialLayer) diff --git a/python/libstored/protocol/stdio.py b/python/libstored/protocol/stdio.py index 4c3123f4..da4f1b59 100644 --- a/python/libstored/protocol/stdio.py +++ b/python/libstored/protocol/stdio.py @@ -18,13 +18,15 @@ libc = None + # Helper to clean up the child when python crashes. def set_pdeathsig_(libc, sig): - if os.name == 'posix': + if os.name == "posix": os.setsid() libc.prctl(1, sig) -def set_pdeathsig(sig = signal.SIGTERM): + +def set_pdeathsig(sig=signal.SIGTERM): global libc if libc is None: @@ -40,16 +42,19 @@ def set_pdeathsig(sig = signal.SIGTERM): return lambda: set_pdeathsig_(libc, sig) - class StdinLayer(lprot.ProtocolLayer): - '''A terminal layer that reads from stdin.''' + """A terminal layer that reads from stdin.""" - name = 'stdin' + name = "stdin" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._stdin_reader = lprot_util.Reader(self._from_stdin, thread_name=self.__class__.__name__) - self._reader_task : asyncio.Task | None = asyncio.create_task(self._reader_run(), name=self.__class__.__name__) + self._stdin_reader = lprot_util.Reader( + self._from_stdin, thread_name=self.__class__.__name__ + ) + self._reader_task: asyncio.Task | None = asyncio.create_task( + self._reader_run(), name=self.__class__.__name__ + ) def _from_stdin(self) -> str: return sys.stdin.read(1) @@ -80,49 +85,62 @@ async def close(self) -> None: await super().close() -lprot.register_layer_type(StdinLayer) +lprot.register_layer_type(StdinLayer) class StdioLayer(lprot.ProtocolLayer): - '''A protocol layer that runs a subprocess and connects to its stdin/stdout.''' + """A protocol layer that runs a subprocess and connects to its stdin/stdout.""" - name = 'stdio' + name = "stdio" def __init__(self, cmd, *args, **kwargs): super().__init__(*args) self._process = subprocess.Popen( - args=cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=sys.stderr, text=False, - preexec_fn = set_pdeathsig() if os.name == 'posix' else None, + args=cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=sys.stderr, + text=False, + preexec_fn=set_pdeathsig() if os.name == "posix" else None, shell=not os.path.exists(cmd[0] if isinstance(cmd, list) else cmd), - **kwargs) - - self._reader : lprot_util.Reader[bytes] = lprot_util.Reader(self._from_process, thread_name=f'{self.__class__.__name__}-reader') - self._writer : lprot_util.Writer[bytes] = lprot_util.Writer(self._to_process, thread_name=f'{self.__class__.__name__}-writer') - self._reader_task : asyncio.Task | None = asyncio.create_task(self._reader_run(), name=f'{self.__class__.__name__} reader') - self._writer_task : asyncio.Task | None = asyncio.create_task(self._writer.start(), name=f'{self.__class__.__name__} writer') - self._check_task : asyncio.Task | None = None + **kwargs, + ) + + self._reader: lprot_util.Reader[bytes] = lprot_util.Reader( + self._from_process, thread_name=f"{self.__class__.__name__}-reader" + ) + self._writer: lprot_util.Writer[bytes] = lprot_util.Writer( + self._to_process, thread_name=f"{self.__class__.__name__}-writer" + ) + self._reader_task: asyncio.Task | None = asyncio.create_task( + self._reader_run(), name=f"{self.__class__.__name__} reader" + ) + self._writer_task: asyncio.Task | None = asyncio.create_task( + self._writer.start(), name=f"{self.__class__.__name__} writer" + ) + self._check_task: asyncio.Task | None = None self.set_terminate_callback(lambda _: None) def _from_process(self) -> bytes: if self._process.stdout is None or self._process.stdout.closed: - raise RuntimeError('Process has no stdout anymore') + raise RuntimeError("Process has no stdout anymore") - x = self._process.stdout.read1(4096) # type: ignore - self.logger.debug('received %s', x) + x = self._process.stdout.read1(4096) # type: ignore + self.logger.debug("received %s", x) return x - def _to_process(self, data : bytes) -> None: + def _to_process(self, data: bytes) -> None: if self._process.stdin is None or self._process.stdin.closed: - raise RuntimeError('Process has no stdin anymore') + raise RuntimeError("Process has no stdin anymore") try: - self.logger.debug('send %s', data) - self._process.stdin.write(data) # type: ignore + self.logger.debug("send %s", data) + self._process.stdin.write(data) # type: ignore self._process.stdin.flush() except BaseException as e: - self.logger.info(f'Cannot write to stdin; shutdown: {e}') + self.logger.info(f"Cannot write to stdin; shutdown: {e}") try: self._process.stdin.close() except BrokenPipeError: @@ -139,23 +157,26 @@ async def _reader_run(self) -> None: except asyncio.CancelledError: pass except BaseException as e: - self.logger.error(f'StdioLayer process reader error: {e}') + self.logger.error(f"StdioLayer process reader error: {e}") await self.close() finally: await self._reader.stop() - def set_terminate_callback(self, f : Callable[[int], None | Coroutine[None, None, None]]) -> None: - ''' + def set_terminate_callback( + self, f: Callable[[int], None | Coroutine[None, None, None]] + ) -> None: + """ Set a callback function that is called when the process terminates. The function is called with the exit code as argument. - ''' + """ + async def check_task() -> None: try: while True: await asyncio.sleep(1) ret = self._process.poll() if ret is not None: - self.logger.error(f'Process terminated with exit code {ret}') + self.logger.error(f"Process terminated with exit code {ret}") if asyncio.iscoroutinefunction(f): await f(ret) else: @@ -170,20 +191,22 @@ async def check_task() -> None: if self._check_task is not None: self._check_task.cancel() - self._check_task = asyncio.create_task(check_task(), name=f'{self.__class__.__name__} check') + self._check_task = asyncio.create_task( + check_task(), name=f"{self.__class__.__name__} check" + ) - async def encode(self, data : lprot.ProtocolLayer.Packet) -> None: + async def encode(self, data: lprot.ProtocolLayer.Packet) -> None: if isinstance(data, str): data = data.encode() elif isinstance(data, memoryview): - data = data.cast('B') + data = data.cast("B") await self._writer.write(data) await super().encode(data) async def close(self) -> None: - self.logger.debug('Closing; terminate process') - if os.name == 'posix': + self.logger.debug("Closing; terminate process") + if os.name == "posix": try: os.killpg(os.getpgid(self._process.pid), signal.SIGTERM) except ProcessLookupError: @@ -212,55 +235,55 @@ async def close(self) -> None: await super().close() -lprot.register_layer_type(StdioLayer) +lprot.register_layer_type(StdioLayer) class PrintLayer(lprot.ProtocolLayer): - ''' + """ A protocol layer that prints all data sent and received. - ''' + """ - name = 'print' + name = "print" - def __init__(self, prefix : tuple | str | None=None, *args, **kwargs): + def __init__(self, prefix: tuple | str | None = None, *args, **kwargs): super().__init__(*args, **kwargs) self._print = self.default_print - self._prefix : tuple[str, str] = ('', '') + self._prefix: tuple[str, str] = ("", "") if isinstance(prefix, tuple): self._prefix = prefix elif isinstance(prefix, str): - p = prefix.split(',', 1) + p = prefix.split(",", 1) self._prefix = (p[0], p[1] if len(p) > 1 else p[0]) - def _format(self, data : lprot.ProtocolLayer.Packet) -> str: + def _format(self, data: lprot.ProtocolLayer.Packet) -> str: if isinstance(data, str): data = data.encode() elif isinstance(data, memoryview): - data = data.cast('B') + data = data.cast("B") - s = '' + s = "" for b in data: if 32 <= b <= 126 or b in (9, 10, 13): s += chr(b) else: - s += f'\\x{b:02x}' + s += f"\\x{b:02x}" return s - async def encode(self, data : lprot.ProtocolLayer.Packet) -> None: - await self.print(f'{self._prefix[0]}{self._format(data)}') + async def encode(self, data: lprot.ProtocolLayer.Packet) -> None: + await self.print(f"{self._prefix[0]}{self._format(data)}") await super().encode(data) - async def decode(self, data : lprot.ProtocolLayer.Packet) -> None: - await self.print(f'{self._prefix[1]}{self._format(data)}') + async def decode(self, data: lprot.ProtocolLayer.Packet) -> None: + await self.print(f"{self._prefix[1]}{self._format(data)}") await super().decode(data) - async def default_print(self, msg : str) -> None: + async def default_print(self, msg: str) -> None: await aiofiles.stdout.write(msg) await aiofiles.stdout.flush() - async def print(self, msg : str) -> None: + async def print(self, msg: str) -> None: await self._print(msg) @property @@ -268,18 +291,20 @@ def printer(self) -> Callable[[str], Awaitable[None]]: return self._print @printer.setter - def printer(self, func : Callable[[str], typing.Any]) -> None: + def printer(self, func: Callable[[str], typing.Any]) -> None: if asyncio.iscoroutinefunction(func): self._print = func else: - async def wrapper(msg : str) -> None: + + async def wrapper(msg: str) -> None: func(msg) + self._print = wrapper __all__ = [ - 'StdinLayer', - 'StdioLayer', - 'PrintLayer', - 'set_pdeathsig', + "StdinLayer", + "StdioLayer", + "PrintLayer", + "set_pdeathsig", ] diff --git a/python/libstored/protocol/util.py b/python/libstored/protocol/util.py index d05ea98a..54d599be 100644 --- a/python/libstored/protocol/util.py +++ b/python/libstored/protocol/util.py @@ -14,38 +14,45 @@ try: import fcntl, os + def set_blocking(stdout): fileno = stdout.fileno() fl = fcntl.fcntl(fileno, fcntl.F_GETFL) fcntl.fcntl(fileno, fcntl.F_SETFL, fl & ~os.O_NONBLOCK) + except: # Not supported def set_blocking(stdout): pass - class InfiniteStdoutBuffer: - ''' + """ A class that provides a non-blocking infinite buffer wrapper for stdout. - ''' + """ - def __init__(self, stdout : typing.TextIO | None=sys.__stdout__, cleanup : typing.Callable[[], None] | None= None): + def __init__( + self, + stdout: typing.TextIO | None = sys.__stdout__, + cleanup: typing.Callable[[], None] | None = None, + ): self.stdout = stdout - self._queue : queue.Queue[str] | None = queue.Queue() - self._closed : bool = False - self._cleanup : typing.Callable[[], None] | None = cleanup - self._thread : threading.Thread | None = threading.Thread(target=self._worker, daemon=True, name='InfiniteStdoutBufferWorker') + self._queue: queue.Queue[str] | None = queue.Queue() + self._closed: bool = False + self._cleanup: typing.Callable[[], None] | None = cleanup + self._thread: threading.Thread | None = threading.Thread( + target=self._worker, daemon=True, name="InfiniteStdoutBufferWorker" + ) atexit.register(self.close) self._thread.start() - def write(self, data : str) -> None: + def write(self, data: str) -> None: if self._closed or self._queue is None: self._write(data) else: self._queue.put(data) - def _write(self, data : str) -> None: + def _write(self, data: str) -> None: # This may block. if self.stdout is not None: self.stdout.write(data) @@ -65,7 +72,7 @@ def close(self) -> None: # Force wakeup assert self._queue is not None - self._queue.put('') + self._queue.put("") self._thread.join() self._thread = None @@ -99,9 +106,11 @@ def _worker(self): finally: queue.task_done() -def reset_stdout(old_stdout : typing.TextIO | None) -> None: + +def reset_stdout(old_stdout: typing.TextIO | None) -> None: sys.stdout = old_stdout + def set_infinite_stdout(): if isinstance(sys.stdout, InfiniteStdoutBuffer): return @@ -111,23 +120,24 @@ def set_infinite_stdout(): sys.stdout = InfiniteStdoutBuffer(old_stdout, lambda: reset_stdout(old_stdout)) -T = typing.TypeVar('T') +T = typing.TypeVar("T") + class Reader(typing.Generic[T]): - ''' + """ Asyncio single-reader from a blocking source. - ''' + """ - def __init__(self, f : typing.Callable[[], T], thread_name : str | None=None, *args, **kwargs): + def __init__(self, f: typing.Callable[[], T], thread_name: str | None = None, *args, **kwargs): super().__init__(*args, **kwargs) - self._f : typing.Callable[[], T] = f - self._thread : threading.Thread | None = None - self._queue : list[T] = [] - self._lock : threading.Lock = threading.Lock() - self._event : asyncio.Event = asyncio.Event() - self._loop : asyncio.AbstractEventLoop = asyncio.get_event_loop() - self._running : bool = False - self._thread_name : str = thread_name if thread_name else self.__class__.__name__ + self._f: typing.Callable[[], T] = f + self._thread: threading.Thread | None = None + self._queue: list[T] = [] + self._lock: threading.Lock = threading.Lock() + self._event: asyncio.Event = asyncio.Event() + self._loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() + self._running: bool = False + self._thread_name: str = thread_name if thread_name else self.__class__.__name__ async def start(self) -> None: if self._thread is not None: @@ -137,8 +147,9 @@ async def start(self) -> None: self._event.clear() self._loop = asyncio.get_event_loop() - self._thread = threading.Thread(target=self._thread_func, daemon=True, - name=self._thread_name) + self._thread = threading.Thread( + target=self._thread_func, daemon=True, name=self._thread_name + ) self._thread.start() await self._event.wait() @@ -146,7 +157,7 @@ async def start(self) -> None: def running(self) -> bool: return self._running - async def stop(self, join : bool=True) -> None: + async def stop(self, join: bool = True) -> None: self._running = False if self._thread is not None and self._thread is not threading.current_thread(): @@ -189,7 +200,7 @@ def _thread_func(self) -> None: if was_empty: self._wakeup() except BaseException as e: - logging.getLogger(self.__class__.__qualname__).exception(f'Reader thread error: {e}') + logging.getLogger(self.__class__.__qualname__).exception(f"Reader thread error: {e}") self._running = False finally: # Signal a blocking read() or stop(). @@ -200,7 +211,7 @@ def _wakeup(self) -> None: if not self._loop.is_closed(): asyncio.run_coroutine_threadsafe(self._wakeup_coro(), self._loop) except Exception as e: - logging.getLogger(self.__class__.__qualname__).error(f'Error waking up reader: {e}') + logging.getLogger(self.__class__.__qualname__).error(f"Error waking up reader: {e}") async def _wakeup_coro(self) -> None: self._event.set() @@ -210,7 +221,7 @@ async def read(self) -> T: while True: if not self._running: - raise RuntimeError('Reader not running') + raise RuntimeError("Reader not running") with self._lock: if self._queue: @@ -225,29 +236,31 @@ async def __await__(self) -> T: return await self.read() - class Writer(typing.Generic[T]): - ''' + """ Asyncio single-writer to a blocking sink. - ''' + """ - def __init__(self, f : typing.Callable[[T], None], thread_name : str | None=None, *args, **kwargs): + def __init__( + self, f: typing.Callable[[T], None], thread_name: str | None = None, *args, **kwargs + ): super().__init__(*args, **kwargs) - self._f : typing.Callable[[T], None] = f - self._thread : threading.Thread | None = None - self._queue : queue.Queue[T | None] = queue.Queue() - self._event : asyncio.Event = asyncio.Event() - self._loop : asyncio.AbstractEventLoop = asyncio.get_event_loop() - self._running : bool = False - self._thread_name : str = thread_name if thread_name else self.__class__.__name__ + self._f: typing.Callable[[T], None] = f + self._thread: threading.Thread | None = None + self._queue: queue.Queue[T | None] = queue.Queue() + self._event: asyncio.Event = asyncio.Event() + self._loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() + self._running: bool = False + self._thread_name: str = thread_name if thread_name else self.__class__.__name__ async def start(self) -> None: if self._thread is not None: return self._event.clear() self._loop = asyncio.get_event_loop() - self._thread = threading.Thread(target=self._thread_func, daemon=True, - name=self._thread_name) + self._thread = threading.Thread( + target=self._thread_func, daemon=True, name=self._thread_name + ) self._running = True self._thread.start() await self._event.wait() @@ -284,7 +297,7 @@ def __del__(self): self._queue.put(None) self._thread = None - def _write(self, x : T) -> None: + def _write(self, x: T) -> None: self._f(x) def _thread_func(self): @@ -297,7 +310,7 @@ def _thread_func(self): self._write(x) self._queue.task_done() except BaseException as e: - logging.getLogger(self.__class__.__qualname__).error(f'Writer thread error: {e}') + logging.getLogger(self.__class__.__qualname__).error(f"Writer thread error: {e}") self._running = False finally: self._wakeup() @@ -308,16 +321,15 @@ def _wakeup(self) -> None: async def _wakeup_coro(self) -> None: self._event.set() - async def write(self, x : T) -> None: + async def write(self, x: T) -> None: if not self._running: - raise RuntimeError('Writer not running') + raise RuntimeError("Writer not running") self._queue.put(x) - __all__ = [ - 'set_infinite_stdout', - 'Reader', - 'Writer', + "set_infinite_stdout", + "Reader", + "Writer", ] diff --git a/python/libstored/protocol/zmq.py b/python/libstored/protocol/zmq.py index 8c905ab2..b03e21b7 100644 --- a/python/libstored/protocol/zmq.py +++ b/python/libstored/protocol/zmq.py @@ -12,15 +12,17 @@ from .. import protocol as lprot + @overload def free_ports() -> int: ... @overload -def free_ports(num : typing.Literal[None]) -> int: ... +def free_ports(num: typing.Literal[None]) -> int: ... @overload -def free_ports(num : int) -> list[int]: ... +def free_ports(num: int) -> list[int]: ... + -def free_ports(num : int | None=None) -> list[int] | int: - ss : list[socketserver.TCPServer] = [] +def free_ports(num: int | None = None) -> list[int] | int: + ss: list[socketserver.TCPServer] = [] ports = [] for i in range(0, max(1, num) if num is not None else 1): @@ -34,22 +36,25 @@ def free_ports(num : int | None=None) -> list[int] | int: return ports if num is not None else ports[0] - class ZmqSocketBase(lprot.ProtocolLayer): - ''' + """ Generic ZMQ socket layer. - ''' + """ - default_timeout_s : float | None = 10 + default_timeout_s: float | None = 10 - def __init__(self, *args, type : int=zmq.DEALER, context : zmq.asyncio.Context | None=None, **kwargs): + def __init__( + self, *args, type: int = zmq.DEALER, context: zmq.asyncio.Context | None = None, **kwargs + ): super().__init__(*args, **kwargs) - self._context : zmq.asyncio.Context = context or zmq.asyncio.Context.instance() - self._socket : zmq.asyncio.Socket | None = self._context.socket(type) - self._recv : asyncio.Task | None = asyncio.create_task(self._recv_task(), name=f'{self.__class__.__name__} recv') - self._timeout_s : float | None = self.default_timeout_s - self._open : bool = False - self._sent : list[tuple[asyncio.Future, float]] = [] + self._context: zmq.asyncio.Context = context or zmq.asyncio.Context.instance() + self._socket: zmq.asyncio.Socket | None = self._context.socket(type) + self._recv: asyncio.Task | None = asyncio.create_task( + self._recv_task(), name=f"{self.__class__.__name__} recv" + ) + self._timeout_s: float | None = self.default_timeout_s + self._open: bool = False + self._sent: list[tuple[asyncio.Future, float]] = [] @property def context(self) -> zmq.asyncio.Context: @@ -58,7 +63,7 @@ def context(self) -> zmq.asyncio.Context: @property def socket(self) -> zmq.asyncio.Socket: if self._socket is None: - raise RuntimeError('ZMQ socket is closed') + raise RuntimeError("ZMQ socket is closed") return self._socket async def mark_open(self) -> None: @@ -78,9 +83,9 @@ async def _recv_task(self) -> None: await self._recv_init() while True: - x = b''.join(await socket.recv_multipart()) + x = b"".join(await socket.recv_multipart()) if self.logger.getEffectiveLevel() <= logging.DEBUG: - self.logger.debug(f'recv {x}') + self.logger.debug(f"recv {x}") await self.mark_open() await self._handle_recv(x) except asyncio.CancelledError: @@ -92,7 +97,7 @@ async def _recv_task(self) -> None: async def _recv_init(self) -> None: pass - async def _handle_recv(self, data : bytes) -> None: + async def _handle_recv(self, data: bytes) -> None: raise NotImplementedError() async def close(self) -> None: @@ -122,7 +127,7 @@ async def _check_sent(self) -> None: try: f.result() except Exception as e: - self.logger.warning(f'send error: {e}') + self.logger.warning(f"send error: {e}") continue if t is None or self._sent[0][1] > t: @@ -130,7 +135,7 @@ async def _check_sent(self) -> None: break if self.open: - self.logger.info('connection timed out') + self.logger.info("connection timed out") await self.disconnected() return @@ -141,17 +146,17 @@ async def disconnected(self) -> None: self._sent = [] await super().disconnected() - async def _send(self, data : lprot.ProtocolLayer.Packet) -> None: + async def _send(self, data: lprot.ProtocolLayer.Packet) -> None: if isinstance(data, str): data = data.encode() elif isinstance(data, memoryview): - data = data.cast('B') + data = data.cast("B") await self._check_sent() if self.open: if self.logger.getEffectiveLevel() <= logging.DEBUG: - self.logger.debug(f'send {bytes(data)}') + self.logger.debug(f"send {bytes(data)}") f = self.socket.send_multipart([data]) assert isinstance(f, asyncio.Future) self._sent.append((f, asyncio.get_running_loop().time())) @@ -161,44 +166,67 @@ def timeout_s(self) -> float | None: return self._timeout_s @timeout_s.setter - def timeout_s(self, value : float | None) -> None: + def timeout_s(self, value: float | None) -> None: self._timeout_s = value - class ZmqSocketClient(ZmqSocketBase): - ''' + """ Generic ZMQ client socket layer. This layer is expected to be at the bottom of the protocol stack. Received data is passed up the stack. - ''' + """ default_port = lprot.default_port - name = 'connect' + name = "connect" @overload - def __init__(self, *args, server : str='localhost', port : int=default_port, type : int=zmq.DEALER, context : zmq.asyncio.Context | None=None, **kwargs): ... + def __init__( + self, + *args, + server: str = "localhost", + port: int = default_port, + type: int = zmq.DEALER, + context: zmq.asyncio.Context | None = None, + **kwargs, + ): ... @overload - def __init__(self, connect : str, *args, context : zmq.asyncio.Context | None=None, type : int=zmq.DEALER, **kwargs): ... - - def __init__(self, connect : str | None=None, *args, server : str='localhost', port : int=default_port, **kwargs): + def __init__( + self, + connect: str, + *args, + context: zmq.asyncio.Context | None = None, + type: int = zmq.DEALER, + **kwargs, + ): ... + + def __init__( + self, + connect: str | None = None, + *args, + server: str = "localhost", + port: int = default_port, + **kwargs, + ): super().__init__(*args, **kwargs) server, port = self.parse_connect(connect, server, port) - self.logger.debug(f'connecting to {server}:{port}') - self.socket.connect(f'tcp://{server}:{port}') + self.logger.debug(f"connecting to {server}:{port}") + self.socket.connect(f"tcp://{server}:{port}") @staticmethod - def parse_connect(connect : str | None=None, default_server : str='*', default_port : int=default_port) -> tuple[str, int]: + def parse_connect( + connect: str | None = None, default_server: str = "*", default_port: int = default_port + ) -> tuple[str, int]: server = default_server port = default_port if connect is not None: - s = connect.split(':', 1) + s = connect.split(":", 1) if len(s) == 2: - if s[0] != '': + if s[0] != "": server = s[0] - if s[1] != '': + if s[1] != "": port = int(s[1]) else: try: @@ -208,60 +236,79 @@ def parse_connect(connect : str | None=None, default_server : str='*', default_p return (server, port) - async def _handle_recv(self, data : bytes) -> None: + async def _handle_recv(self, data: bytes) -> None: await self.decode(data) async def _recv_init(self) -> None: # Indicate that we are connected. await self.mark_open() - await self._send(b'') + await self._send(b"") - async def encode(self, data : lprot.ProtocolLayer.Packet) -> None: + async def encode(self, data: lprot.ProtocolLayer.Packet) -> None: await super()._send(data) await super().encode(data) -lprot.register_layer_type(ZmqSocketClient) +lprot.register_layer_type(ZmqSocketClient) class ZmqSocketServer(ZmqSocketBase): - ''' + """ Generic ZMQ server (listening) socket layer. This layer is expected to be at the top of the protocol stack. Received data is passed down the stack. - ''' + """ default_port = 0 - name = 'sock' + name = "sock" @overload - def __init__(self, *args, type : int=zmq.DEALER, listen : str='*', port : int=default_port, context : zmq.asyncio.Context | None=None, **kwargs): ... + def __init__( + self, + *args, + type: int = zmq.DEALER, + listen: str = "*", + port: int = default_port, + context: zmq.asyncio.Context | None = None, + **kwargs, + ): ... @overload - def __init__(self, bind : str, *args, type : int=zmq.DEALER, context : zmq.asyncio.Context | None=None, **kwargs): ... - - def __init__(self, bind : str | None=None, *args, listen : str='*', port : int=default_port, **kwargs): + def __init__( + self, + bind: str, + *args, + type: int = zmq.DEALER, + context: zmq.asyncio.Context | None = None, + **kwargs, + ): ... + + def __init__( + self, bind: str | None = None, *args, listen: str = "*", port: int = default_port, **kwargs + ): super().__init__(*args, **kwargs) listen, port, random_port = self.parse_bind(bind, listen, port) if random_port: - self.logger.info(f'listening to {listen}:{port}') + self.logger.info(f"listening to {listen}:{port}") else: - self.logger.debug(f'listening to {listen}:{port}') + self.logger.debug(f"listening to {listen}:{port}") - self.socket.bind(f'tcp://{listen}:{port}') + self.socket.bind(f"tcp://{listen}:{port}") @staticmethod - def parse_bind(bind : str | None=None, default_listen : str='*', default_port : int=default_port) -> tuple[str, int, bool]: + def parse_bind( + bind: str | None = None, default_listen: str = "*", default_port: int = default_port + ) -> tuple[str, int, bool]: listen = default_listen port = default_port if bind is not None: - s = bind.split(':', 1) + s = bind.split(":", 1) if len(s) == 2: - if s[0] != '': + if s[0] != "": listen = s[0] - if s[1] != '': + if s[1] != "": port = int(s[1]) else: try: @@ -275,45 +322,52 @@ def parse_bind(bind : str | None=None, default_listen : str='*', default_port : return (listen, port, random_port) - async def _handle_recv(self, data : bytes) -> None: + async def _handle_recv(self, data: bytes) -> None: await self.encode(data) - async def decode(self, data : lprot.ProtocolLayer.Packet) -> None: + async def decode(self, data: lprot.ProtocolLayer.Packet) -> None: await super()._send(data) await super().decode(data) -lprot.register_layer_type(ZmqSocketServer) +lprot.register_layer_type(ZmqSocketServer) class ZmqServer(ZmqSocketServer): - ''' + """ A ZMQ Server, for REQ/REP debug messages. This can be used to create a bridge from an arbitrary interface to ZMQ, which in turn can be used to connect a libstored.asyncio.ZmqClient to. - ''' + """ default_port = lprot.default_port - name = 'zmq' + name = "zmq" @overload - def __init__(self, *args, listen : str='*', port : int=default_port, context : zmq.asyncio.Context | None=None, **kwargs): ... + def __init__( + self, + *args, + listen: str = "*", + port: int = default_port, + context: zmq.asyncio.Context | None = None, + **kwargs, + ): ... @overload - def __init__(self, bind : str, *args, context : zmq.asyncio.Context | None=None, **kwargs): ... + def __init__(self, bind: str, *args, context: zmq.asyncio.Context | None = None, **kwargs): ... - def __init__(self, bind : str | None=None, *args, **kwargs): + def __init__(self, bind: str | None = None, *args, **kwargs): super().__init__(bind, *args, type=zmq.REP, **kwargs) - self._req : bool = False + self._req: bool = False - async def _handle_recv(self, data : bytes) -> None: - assert not self._req, 'ZmqServer received request while previous request not yet handled' + async def _handle_recv(self, data: bytes) -> None: + assert not self._req, "ZmqServer received request while previous request not yet handled" self._req = True await super()._handle_recv(data) - async def decode(self, data : lprot.ProtocolLayer.Packet) -> None: + async def decode(self, data: lprot.ProtocolLayer.Packet) -> None: if not self._req: - self.logger.debug('Ignoring unexpected rep %s', data) + self.logger.debug("Ignoring unexpected rep %s", data) return self._req = False await super().decode(data) @@ -322,4 +376,5 @@ async def disconnected(self) -> None: await super().disconnected() self._req = False + lprot.register_layer_type(ZmqServer) diff --git a/python/libstored/protocol/zmqcat.py b/python/libstored/protocol/zmqcat.py index 40f2615f..cb832216 100644 --- a/python/libstored/protocol/zmqcat.py +++ b/python/libstored/protocol/zmqcat.py @@ -12,23 +12,26 @@ from .. import protocol as lprot from ..asyncio.worker import AsyncioWorker, run_sync + @run_sync -async def async_main(args : argparse.Namespace) -> None: +async def async_main(args: argparse.Namespace) -> None: - if args.type == 'dealer': + if args.type == "dealer": type = zmq.DEALER - elif args.type == 'pair': + elif args.type == "pair": type = zmq.PAIR - elif args.type == 'req': + elif args.type == "req": type = zmq.REQ else: - raise ValueError(f'Unknown socket type: {args.type}') + raise ValueError(f"Unknown socket type: {args.type}") - stack = lprot.stack([ - lprot.PrintLayer(), - lprot.StdinLayer(), - lprot.ZmqSocketClient(server=args.host, port=int(args.port), type=type) - ]) + stack = lprot.stack( + [ + lprot.PrintLayer(), + lprot.StdinLayer(), + lprot.ZmqSocketClient(server=args.host, port=int(args.port), type=type), + ] + ) try: while True: @@ -39,16 +42,22 @@ async def async_main(args : argparse.Namespace) -> None: await stack.close() - def main(): - parser = argparse.ArgumentParser(prog=__package__, - description='ZMQ cat utility that fits nicely with libstored.protocol.ZmqSocketServer', - formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = argparse.ArgumentParser( + prog=__package__, + description="ZMQ cat utility that fits nicely with libstored.protocol.ZmqSocketServer", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) from ..version import __version__ - parser.add_argument('-s', dest='host', help='Server hostname', default='localhost') - parser.add_argument('-p', dest='port', help='Specify TCP port') - parser.add_argument('-t', dest='type', choices=['dealer', 'pair', 'req'], help='Socket type', default='dealer') - parser.add_argument('-v', dest='verbose', default=0, help='Enable verbose output', action='count') + + parser.add_argument("-s", dest="host", help="Server hostname", default="localhost") + parser.add_argument("-p", dest="port", help="Specify TCP port") + parser.add_argument( + "-t", dest="type", choices=["dealer", "pair", "req"], help="Socket type", default="dealer" + ) + parser.add_argument( + "-v", dest="verbose", default=0, help="Enable verbose output", action="count" + ) args = parser.parse_args() @@ -67,5 +76,6 @@ def main(): except KeyboardInterrupt: w.cancel() -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/python/libstored/tk.py b/python/libstored/tk.py index a9ba76ee..7b2a2c77 100644 --- a/python/libstored/tk.py +++ b/python/libstored/tk.py @@ -8,57 +8,66 @@ import tkinter.ttk as ttk import typing + class Entry(ttk.Entry): class State(enum.IntEnum): EMPTY = enum.auto() FILLED = enum.auto() FOCUSED = enum.auto() - def __init__(self, parent : ttk.Widget, text : str = '', hint : str = '', hint_color='gray', \ - validation : str | typing.Callable[[str], bool]='', *args, **kwargs): + def __init__( + self, + parent: ttk.Widget, + text: str = "", + hint: str = "", + hint_color="gray", + validation: str | typing.Callable[[str], bool] = "", + *args, + **kwargs + ): super().__init__(parent, *args, **kwargs) self._var = tk.StringVar(self, value=text) - self['textvariable'] = self._var + self["textvariable"] = self._var self._hint = hint self._hint_color = hint_color - self._foreground_color = self['foreground'] + self._foreground_color = self["foreground"] - self.bind('', self._focus_in) - self.bind('', self._focus_out) + self.bind("", self._focus_in) + self.bind("", self._focus_out) - self.bind('', self._select_all) - self.bind('', self._select_all) + self.bind("", self._select_all) + self.bind("", self._select_all) - if validation != '': + if validation != "": if isinstance(validation, str): self._validation = lambda s: re.compile(validation).match(s) is not None else: self._validation = validation - self['validate'] = 'key' - self['validatecommand'] = (self.register(self._validate), '%P') + self["validate"] = "key" + self["validatecommand"] = (self.register(self._validate), "%P") - self._state = Entry.State.EMPTY if text == '' else Entry.State.FILLED + self._state = Entry.State.EMPTY if text == "" else Entry.State.FILLED self._update_style() @property def text(self) -> str: if self._state == Entry.State.EMPTY: - return '' + return "" else: return self._var.get() @text.setter - def text(self, value : str): - if self._state == Entry.State.EMPTY and value == '': + def text(self, value: str): + if self._state == Entry.State.EMPTY and value == "": pass - elif self._state == Entry.State.EMPTY and value != '': + elif self._state == Entry.State.EMPTY and value != "": self._state = Entry.State.FILLED self._var.set(value) self._update_style() - elif self._state == Entry.State.FILLED and value == '': + elif self._state == Entry.State.FILLED and value == "": self._state = Entry.State.EMPTY self._update_style() else: @@ -69,24 +78,24 @@ def hint(self) -> str: return self._hint @hint.setter - def hint(self, value : str): + def hint(self, value: str): self._hint = value if self._state == Entry.State.EMPTY: self._var.set(value) def _update_style(self): if self._state == Entry.State.EMPTY: - self['foreground'] = self._hint_color + self["foreground"] = self._hint_color self._var.set(self._hint) else: - self['foreground'] = self._foreground_color + self["foreground"] = self._foreground_color def _focus_in(self, *args): if self._state == Entry.State.FOCUSED: return if self._state == Entry.State.EMPTY: - self._var.set('') + self._var.set("") self._state = Entry.State.FOCUSED self._update_style() @@ -95,7 +104,7 @@ def _focus_out(self, *args): if self._state != Entry.State.FOCUSED: return - if self._var.get() == '': + if self._var.get() == "": self._state = Entry.State.EMPTY else: self._state = Entry.State.FILLED @@ -103,11 +112,11 @@ def _focus_out(self, *args): self._update_style() def _select_all(self, event): - event.widget.select_range(0, 'end') - event.widget.icursor('end') - return 'break' + event.widget.select_range(0, "end") + event.widget.icursor("end") + return "break" - def _validate(self, proposed : str) -> bool: + def _validate(self, proposed: str) -> bool: if self._validation(proposed): return True else: diff --git a/python/libstored/wrapper/serial.py b/python/libstored/wrapper/serial.py index 76b79e55..5b9ca975 100644 --- a/python/libstored/wrapper/serial.py +++ b/python/libstored/wrapper/serial.py @@ -13,18 +13,49 @@ from .. import protocol as lprot from ..asyncio.worker import AsyncioWorker, run_sync + def main(): - parser = argparse.ArgumentParser(description='serial wrapper to ZMQ server', - formatter_class=argparse.ArgumentDefaultsHelpFormatter, prog=__package__) - parser.add_argument('-V', '--version', action='version', version=__version__) - parser.add_argument('-l', '--listen', dest='zmqlisten', type=str, default='*', help='ZMQ listen address') - parser.add_argument('-p', '--port', dest='zmqport', type=int, default=lprot.default_port, help='ZMQ port') - parser.add_argument('port', help='serial port') - parser.add_argument('baud', nargs='?', type=int, default=115200, help='baud rate') - parser.add_argument('-r', '--rtscts', dest='rtscts', default=False, help='RTS/CTS flow control', action='store_true') - parser.add_argument('-x', '--xonxoff', dest='xonxoff', default=False, help='XON/XOFF flow control', action='store_true') - parser.add_argument('-v', '--verbose', dest='verbose', default=0, help='Enable verbose output', action='count') - parser.add_argument('-S', '--stack', dest='stack', type=str, default='ascii,pubterm,stdin', help='protocol stack') + parser = argparse.ArgumentParser( + description="serial wrapper to ZMQ server", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + prog=__package__, + ) + parser.add_argument("-V", "--version", action="version", version=__version__) + parser.add_argument( + "-l", "--listen", dest="zmqlisten", type=str, default="*", help="ZMQ listen address" + ) + parser.add_argument( + "-p", "--port", dest="zmqport", type=int, default=lprot.default_port, help="ZMQ port" + ) + parser.add_argument("port", help="serial port") + parser.add_argument("baud", nargs="?", type=int, default=115200, help="baud rate") + parser.add_argument( + "-r", + "--rtscts", + dest="rtscts", + default=False, + help="RTS/CTS flow control", + action="store_true", + ) + parser.add_argument( + "-x", + "--xonxoff", + dest="xonxoff", + default=False, + help="XON/XOFF flow control", + action="store_true", + ) + parser.add_argument( + "-v", "--verbose", dest="verbose", default=0, help="Enable verbose output", action="count" + ) + parser.add_argument( + "-S", + "--stack", + dest="stack", + type=str, + default="ascii,pubterm,stdin", + help="protocol stack", + ) args = parser.parse_args() @@ -39,16 +70,26 @@ def main(): with AsyncioWorker() as w: try: + @run_sync - async def async_main(args : argparse.Namespace): + async def async_main(args: argparse.Namespace): stack = lprot.build_stack( - ','.join([ - f'zmq={args.zmqlisten}:{args.zmqport}', - 'reqrepcheck', - re.sub(r'\bpubterm\b(,|$)', f'pubterm={args.zmqlisten}:{args.zmqport+1}\\1', args.stack)]) + ",".join( + [ + f"zmq={args.zmqlisten}:{args.zmqport}", + "reqrepcheck", + re.sub( + r"\bpubterm\b(,|$)", + f"pubterm={args.zmqlisten}:{args.zmqport+1}\\1", + args.stack, + ), + ] ) + ) - serial = lprot.SerialLayer(port=args.port, baudrate=args.baud, rtscts=args.rtscts, xonxoff=args.xonxoff) + serial = lprot.SerialLayer( + port=args.port, baudrate=args.baud, rtscts=args.rtscts, xonxoff=args.xonxoff + ) serial.wrap(stack) try: while True: @@ -60,5 +101,6 @@ async def async_main(args : argparse.Namespace): except KeyboardInterrupt: w.cancel() -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/python/libstored/wrapper/stdio.py b/python/libstored/wrapper/stdio.py index 28efea89..4085cc52 100644 --- a/python/libstored/wrapper/stdio.py +++ b/python/libstored/wrapper/stdio.py @@ -13,16 +13,33 @@ from .. import protocol as lprot from ..asyncio.worker import AsyncioWorker, run_sync + def main(): - parser = argparse.ArgumentParser(description='stdin/stdout wrapper to ZMQ server', - formatter_class=argparse.ArgumentDefaultsHelpFormatter, prog=__package__) - parser.add_argument('-V', '--version', action='version', version=__version__) - parser.add_argument('-l', '--listen', dest='listen', type=str, default='*', help='listen address') - parser.add_argument('-p', '--port', dest='port', type=int, default=lprot.default_port, help='port') - parser.add_argument('-S', '--stack', dest='stack', type=str, default='ascii,pubterm,stdin', help='protocol stack') - parser.add_argument('-v', '--verbose', dest='verbose', default=0, help='Enable verbose output', action='count') - parser.add_argument('command') - parser.add_argument('args', nargs='*') + parser = argparse.ArgumentParser( + description="stdin/stdout wrapper to ZMQ server", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + prog=__package__, + ) + parser.add_argument("-V", "--version", action="version", version=__version__) + parser.add_argument( + "-l", "--listen", dest="listen", type=str, default="*", help="listen address" + ) + parser.add_argument( + "-p", "--port", dest="port", type=int, default=lprot.default_port, help="port" + ) + parser.add_argument( + "-S", + "--stack", + dest="stack", + type=str, + default="ascii,pubterm,stdin", + help="protocol stack", + ) + parser.add_argument( + "-v", "--verbose", dest="verbose", default=0, help="Enable verbose output", action="count" + ) + parser.add_argument("command") + parser.add_argument("args", nargs="*") args = parser.parse_args() @@ -38,14 +55,22 @@ def main(): ret = 0 with AsyncioWorker() as w: try: + @run_sync - async def async_main(args : argparse.Namespace): + async def async_main(args: argparse.Namespace): stack = lprot.build_stack( - ','.join([ - f'zmq={args.listen}:{args.port}', - 'reqrepcheck', - re.sub(r'\bpubterm\b(,|$)', f'pubterm={args.listen}:{args.port+1}\\1', args.stack)]) + ",".join( + [ + f"zmq={args.listen}:{args.port}", + "reqrepcheck", + re.sub( + r"\bpubterm\b(,|$)", + f"pubterm={args.listen}:{args.port+1}\\1", + args.stack, + ), + ] ) + ) stdio = lprot.StdioLayer(cmd=[args.command] + args.args) stdio.wrap(stack) @@ -64,5 +89,6 @@ async def at_term(code): sys.exit(ret) -if __name__ == '__main__': + +if __name__ == "__main__": main() From c0f18d3cb523621f920e29496fb7437744cdfe7f Mon Sep 17 00:00:00 2001 From: Jochem Rutgers <68805714+jhrutgers@users.noreply.github.com> Date: Wed, 17 Dec 2025 16:21:25 +0100 Subject: [PATCH 07/15] improve process termination handling --- python/libstored/protocol/stdio.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/python/libstored/protocol/stdio.py b/python/libstored/protocol/stdio.py index da4f1b59..9e35f8af 100644 --- a/python/libstored/protocol/stdio.py +++ b/python/libstored/protocol/stdio.py @@ -96,6 +96,8 @@ class StdioLayer(lprot.ProtocolLayer): def __init__(self, cmd, *args, **kwargs): super().__init__(*args) + self.logger.debug(f"Starting process: {cmd}") + self._process = subprocess.Popen( args=cmd, stdin=subprocess.PIPE, @@ -127,7 +129,17 @@ def _from_process(self) -> bytes: if self._process.stdout is None or self._process.stdout.closed: raise RuntimeError("Process has no stdout anymore") - x = self._process.stdout.read1(4096) # type: ignore + x = b"" + try: + x = self._process.stdout.read1(4096) # type: ignore + except BaseException: + pass + + if x is None or x == b"": + if self._process.poll() is not None: + self._process.stdout.close() + raise RuntimeError("Process has terminated") + self.logger.debug("received %s", x) return x @@ -177,6 +189,10 @@ async def check_task() -> None: ret = self._process.poll() if ret is not None: self.logger.error(f"Process terminated with exit code {ret}") + if self._process.stdin is not None: + self._process.stdin.close() + if self._process.stdout is not None: + self._process.stdout.close() if asyncio.iscoroutinefunction(f): await f(ret) else: @@ -211,7 +227,6 @@ async def close(self) -> None: os.killpg(os.getpgid(self._process.pid), signal.SIGTERM) except ProcessLookupError: pass - self._process.terminate() if self._reader_task is not None: self._reader_task.cancel() @@ -233,6 +248,7 @@ async def close(self) -> None: pass self._check_task = None + self._process.terminate() await super().close() From 91f0cdcae98b90a9748693b1603b35279e34dd0c Mon Sep 17 00:00:00 2001 From: Jochem Rutgers <68805714+jhrutgers@users.noreply.github.com> Date: Thu, 18 Dec 2025 15:55:25 +0100 Subject: [PATCH 08/15] fix warning --- python/libstored/generator/dsl/types.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/libstored/generator/dsl/types.py b/python/libstored/generator/dsl/types.py index b9062213..c11cbfe0 100644 --- a/python/libstored/generator/dsl/types.py +++ b/python/libstored/generator/dsl/types.py @@ -579,6 +579,7 @@ def _c_impl(self, tree, index, placeholders, indent): return res def c_decl(self): + assert self.cname is not None return self.cname + "(" + ", ".join([f"int {x}" for x in self.placeholders()]) + ")" @@ -754,7 +755,7 @@ def __init__(self, parent, type, name): # Empty string, handle as default-initialized. self.init = None self.len = type.blob.len - self.axi = None + self.axi: int | None = None def isBlob(self): return self.type in ["blob", "string"] From 5e9f37c00bf37b9db5527209a7ad9eff4a9fa447 Mon Sep 17 00:00:00 2001 From: Jochem Rutgers <68805714+jhrutgers@users.noreply.github.com> Date: Thu, 18 Dec 2025 21:43:00 +0100 Subject: [PATCH 09/15] fix sending buffer --- python/libstored/protocol/serial.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/libstored/protocol/serial.py b/python/libstored/protocol/serial.py index a50edaf3..2f706c03 100644 --- a/python/libstored/protocol/serial.py +++ b/python/libstored/protocol/serial.py @@ -70,13 +70,14 @@ async def _serial_run(self, drop_s: float | None, serial_args: dict) -> None: self._read, thread_name=f"{self.__class__.__name__}-reader" ) as reader: try: + self._writer = writer + if self._encode_buffer: self.logger.debug("sending buffered %s", self._encode_buffer) await self._encode(self._encode_buffer) self._encode_buffer = bytearray() - self._writer = writer - + self.logger.debug("ready") while self._open: data = await reader.read() await self.decode(data) From 26f43566542224d75d9cb4eb7671b116591011c6 Mon Sep 17 00:00:00 2001 From: Jochem Rutgers <68805714+jhrutgers@users.noreply.github.com> Date: Fri, 19 Dec 2025 01:21:21 +0100 Subject: [PATCH 10/15] improve reqrep disconnect handling --- python/libstored/protocol/protocol.py | 48 +++++++++++++++++++-------- 1 file changed, 35 insertions(+), 13 deletions(-) diff --git a/python/libstored/protocol/protocol.py b/python/libstored/protocol/protocol.py index f02ffc70..4071572d 100644 --- a/python/libstored/protocol/protocol.py +++ b/python/libstored/protocol/protocol.py @@ -485,25 +485,35 @@ class ReqRepCheckLayer(ProtocolLayer): name = "reqrepcheck" - def __init__(self, timeout_s: float = 1, *args, **kwargs): + def __init__( + self, timeout_s: float | None = 1, error_rep: bytes | None = b"?", *args, **kwargs + ): super().__init__(*args, **kwargs) self._req: bool = False - self._timeout_s: float = timeout_s + self._timeout_s: float | None = None + self._error_rep: bytes | None = error_rep self._retransmit_time: float = 0 - self._retransmitter: asyncio.Task | None = asyncio.create_task( - self._retransmitter_task(), name=self.__class__.__name__ - ) + self._retransmitter: asyncio.Task | None = None + self.timeout_s = timeout_s @property - def timeout_s(self) -> float: + def timeout_s(self) -> float | None: return self._timeout_s @timeout_s.setter - def timeout_s(self, value: float) -> None: - if not self._req: + def timeout_s(self, value: float | None) -> None: + if self._req and value is not None: self._retransmit_time = time.time() + value self._timeout_s = value + if value is None and self._retransmitter is not None: + self._retransmitter.cancel() + self._retransmitter = None + elif value is not None and self._retransmitter is None: + self._retransmitter = asyncio.create_task( + self._retransmitter_task(), name=self.__class__.__name__ + ) + @property def req(self) -> bool: """ @@ -514,15 +524,17 @@ def req(self) -> bool: async def _retransmitter_task(self) -> None: try: dt_s = self._timeout_s - while True: + while dt_s is not None: await asyncio.sleep(dt_s) if not self._req: continue + if self._timeout_s is None: + break now = time.time() dt_s = self._retransmit_time - now if dt_s <= 0: - self._retransmit_time = now + self._timeout_s dt_s = self._timeout_s + self._retransmit_time = now + dt_s await super().timeout() except asyncio.CancelledError: pass @@ -534,6 +546,15 @@ async def timeout(self) -> None: # Ignore timeouts from above, we are checking for retransmissions ourselves. pass + async def _send_error_rep(self) -> None: + if not self._req: + return + + if self._error_rep is not None: + await self.decode(self._error_rep) + else: + self._req = False + async def close(self) -> None: if self._retransmitter is not None: self._retransmitter.cancel() @@ -544,6 +565,7 @@ async def close(self) -> None: self._retransmitter = None await super().close() + await self._send_error_rep() async def encode(self, data: ProtocolLayer.Packet) -> None: if self._req: @@ -552,7 +574,7 @@ async def encode(self, data: ProtocolLayer.Packet) -> None: ) self._req = True - self._retransmit_time = time.time() + self._timeout_s + self._retransmit_time = time.time() + (self._timeout_s or 0) await super().encode(data) async def decode(self, data: ProtocolLayer.Packet) -> None: @@ -564,8 +586,8 @@ async def decode(self, data: ProtocolLayer.Packet) -> None: await super().decode(data) async def disconnected(self) -> None: + await self._send_error_rep() await super().disconnected() - self._req = False class SegmentationLayer(ProtocolLayer): @@ -935,7 +957,7 @@ async def decode(self, data: ProtocolLayer.Packet) -> None: if do_decode: break - if do_decode: + if do_decode and len(data) > 0: self._pause_transmit = True try: await super().decode(data) From 94588c6db710a161753c405df6929f95ab886b48 Mon Sep 17 00:00:00 2001 From: Jochem Rutgers <68805714+jhrutgers@users.noreply.github.com> Date: Mon, 22 Dec 2025 20:52:50 +0100 Subject: [PATCH 11/15] optimize --- include/libstored/protocol.h | 63 ++++++++++++++++++++++++++++-------- src/protocol.cpp | 3 +- 2 files changed, 51 insertions(+), 15 deletions(-) diff --git a/include/libstored/protocol.h b/include/libstored/protocol.h index d54fb121..238c7e58 100644 --- a/include/libstored/protocol.h +++ b/include/libstored/protocol.h @@ -269,6 +269,11 @@ class ProtocolLayer { return m_down; } +# ifdef STORED_COMPILER_GCC +# pragma GCC push_options +# pragma GCC optimize("-foptimize-sibling-calls") +# endif // GCC + /*! * \brief Decode a frame and forward the decoded frame to the upper layer. * @@ -276,8 +281,11 @@ class ProtocolLayer { */ virtual void decode(void* buffer, size_t len) { - if(up()) - up()->decode(buffer, len); + ProtocolLayer* p = up(); + if(!p) + return; + + p->decode(buffer, len); } /*! @@ -296,8 +304,11 @@ class ProtocolLayer { */ virtual void encode(void const* buffer, size_t len, bool last = true) { - if(down()) - down()->encode(buffer, len, last); + ProtocolLayer* p = down(); + if(!p) + return; + + p->encode(buffer, len, last); } /*! @@ -316,8 +327,11 @@ class ProtocolLayer { */ virtual void setPurgeableResponse(bool purgeable = true) { - if(down()) - down()->setPurgeableResponse(purgeable); + ProtocolLayer* p = down(); + if(!p) + return; + + p->setPurgeableResponse(purgeable); } /*! @@ -331,7 +345,11 @@ class ProtocolLayer { */ virtual size_t mtu() const { - return down() ? down()->mtu() : 0; + ProtocolLayer* p = down(); + if(!p) + return 0; + + return p->mtu(); } /*! @@ -346,7 +364,11 @@ class ProtocolLayer { */ virtual bool flush() { - return down() ? down()->flush() : true; + ProtocolLayer* p = down(); + if(!p) + return true; + + return p->flush(); } /*! @@ -354,8 +376,11 @@ class ProtocolLayer { */ virtual void reset() { - if(down()) - down()->reset(); + ProtocolLayer* p = down(); + if(!p) + return; + + p->reset(); } /*! @@ -363,8 +388,11 @@ class ProtocolLayer { */ virtual void connected() { - if(up()) - up()->connected(); + ProtocolLayer* p = up(); + if(!p) + return; + + p->connected(); } /*! @@ -372,10 +400,17 @@ class ProtocolLayer { */ virtual void disconnected() { - if(up()) - up()->disconnected(); + ProtocolLayer* p = up(); + if(!p) + return; + + p->disconnected(); } +# ifdef STORED_COMPILER_GCC +# pragma GCC pop_options +# endif // GCC + private: /*! \brief The layer above this one. */ ProtocolLayer* m_up; diff --git a/src/protocol.cpp b/src/protocol.cpp index 4cba5809..ff692e67 100644 --- a/src/protocol.cpp +++ b/src/protocol.cpp @@ -2690,7 +2690,8 @@ bool FileLayer::isOpen() const void FileLayer::encode(void const* buffer, size_t len, bool last) { if(m_fd_w == -1) { - setLastError(EBADF); + if(lastError() == 0) + setLastError(EBADF); done: base::encode(buffer, len, last); return; From ce25781fe6f89091096fa5fc5c23d41ca1da5d26 Mon Sep 17 00:00:00 2001 From: Jochem Rutgers <68805714+jhrutgers@users.noreply.github.com> Date: Mon, 22 Dec 2025 20:52:59 +0100 Subject: [PATCH 12/15] add req client --- python/libstored/protocol/zmq.py | 69 ++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/python/libstored/protocol/zmq.py b/python/libstored/protocol/zmq.py index b03e21b7..1ca7c4bd 100644 --- a/python/libstored/protocol/zmq.py +++ b/python/libstored/protocol/zmq.py @@ -252,6 +252,75 @@ async def encode(self, data: lprot.ProtocolLayer.Packet) -> None: lprot.register_layer_type(ZmqSocketClient) +class ZmqReqClient(ZmqSocketClient): + """ + Generic ZMQ client REQ/REP layer. + + This layer is expected to be at the bottom of the protocol stack. + Received data is passed up the stack. + """ + + default_port = lprot.default_port + name = "req" + + @overload + def __init__( + self, + *args, + server: str = "localhost", + port: int = default_port, + context: zmq.asyncio.Context | None = None, + **kwargs, + ): ... + @overload + def __init__( + self, + connect: str, + *args, + context: zmq.asyncio.Context | None = None, + **kwargs, + ): ... + + def __init__( + self, + connect: str | None = None, + *args, + server: str = "localhost", + port: int = default_port, + **kwargs, + ): + super().__init__(connect, *args, server=server, port=port, type=zmq.REQ, **kwargs) + self._lock = asyncio.Lock() + self._rep: asyncio.Future | None = None + + async def _recv_init(self) -> None: + pass + + async def _handle_recv(self, data: bytes) -> None: + assert self._rep is not None + assert not self._rep.done() + self._rep.set_result(data) + + async def encode(self, data: lprot.ProtocolLayer.Packet) -> None: + rep = await self.req(data) + await super().decode(rep) + + async def req(self, data: lprot.ProtocolLayer.Packet) -> lprot.ProtocolLayer.Packet: + async with self._lock: + await self.mark_open() + self._rep = asyncio.get_running_loop().create_future() + await super()._send(data) + rep = await self._rep + self._rep = None + return rep + + async def decode(self, data: lprot.ProtocolLayer.Packet) -> None: + pass + + +lprot.register_layer_type(ZmqReqClient) + + class ZmqSocketServer(ZmqSocketBase): """ Generic ZMQ server (listening) socket layer. From bccd5daefcbd6dd6f767d0dc2c308e2ec3445b8a Mon Sep 17 00:00:00 2001 From: Jochem Rutgers <68805714+jhrutgers@users.noreply.github.com> Date: Tue, 23 Dec 2025 11:17:19 +0100 Subject: [PATCH 13/15] fix empty packets through mux --- include/libstored/protocol.h | 1 + python/libstored/protocol/protocol.py | 21 +++++++++++++++------ src/protocol.cpp | 12 ++++++++++-- 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/include/libstored/protocol.h b/include/libstored/protocol.h index 238c7e58..19f2290c 100644 --- a/include/libstored/protocol.h +++ b/include/libstored/protocol.h @@ -1505,6 +1505,7 @@ class MuxLayer : public ProtocolLayer { ChannelId m_encodingChannel; ProtocolLayer* m_decodingChannel; bool m_decodingEsc; + bool m_encoding; }; namespace impl { diff --git a/python/libstored/protocol/protocol.py b/python/libstored/protocol/protocol.py index 4071572d..578a9f2d 100644 --- a/python/libstored/protocol/protocol.py +++ b/python/libstored/protocol/protocol.py @@ -1411,9 +1411,6 @@ async def _encode(self, chan: int, data: ProtocolLayer.Packet) -> None: elif isinstance(data, memoryview): data = data.tobytes() - if len(data) == 0: - return - now = time.time() prefix = b"" if self._decoding is None: @@ -1422,6 +1419,7 @@ async def _encode(self, chan: int, data: ProtocolLayer.Packet) -> None: self._prev is None or chan != self._prev or (now - self._t_prev) >= self._repeat_interval + or len(data) == 0 ): prefix = bytes([self.esc, chan]) @@ -1435,6 +1433,8 @@ async def decode(self, data: ProtocolLayer.Packet) -> None: if not isinstance(data, memoryview): data = memoryview(data) data = data.cast("B") + decoded = bytearray() + do_decode = self._decoding_esc start = 0 for i in range(len(data)): @@ -1444,20 +1444,29 @@ async def decode(self, data: ProtocolLayer.Packet) -> None: if data[i] == self.esc: # esc was in the data - await self._dispatch(bytes([self.esc])) + decoded.append(self.esc) elif data[i] == self.repeat: # Repeat last channel request self._prev = None else: # Switched channel + if do_decode: + await self._dispatch(decoded) + decoded = bytearray() + do_decode = True self._decoding = data[i] elif data[i] == self.esc: if i > start: - await self._dispatch(data[start:i]) + decoded += data[start:i] + do_decode = True self._decoding_esc = True if not self._decoding_esc and start < len(data): - await self._dispatch(data[start:]) + decoded += data[start:] + do_decode = True + + if do_decode: + await self._dispatch(decoded) async def _dispatch(self, data: bytes | memoryview) -> None: chan = self._decoding diff --git a/src/protocol.cpp b/src/protocol.cpp index ff692e67..18aa2426 100644 --- a/src/protocol.cpp +++ b/src/protocol.cpp @@ -1994,6 +1994,7 @@ MuxLayer::MuxLayer(ProtocolLayer* up, ProtocolLayer* down) , m_encodingChannel(Repeat) , m_decodingChannel() , m_decodingEsc() + , m_encoding() {} /*! @@ -2226,7 +2227,8 @@ void MuxLayer::encode_(ChannelId channel, void const* buffer, size_t len, bool l if(channel == Repeat) return; - if(channel != m_encodingChannel) { + if(channel != m_encodingChannel || (!m_encoding && !len && last)) { + m_encoding = true; uint8_t buf[2] = {Esc, channel}; base::encode(buf, sizeof(buf), false); m_encodingChannel = channel; @@ -2249,14 +2251,20 @@ void MuxLayer::encode_(ChannelId channel, void const* buffer, size_t len, bool l break; // Found an escape byte at position c. Repeat escape byte. + m_encoding = true; base::encode(buffer_ + i, c - i + 1, false); enc_last = c == len && last; base::encode(buffer_ + c, 1, enc_last); i = c + 1; } - if(!enc_last && (i < len || last)) + if(!enc_last && (i < len || last)) { + m_encoding = true; base::encode(buffer_ + i, len - i, last); + } + + if(last) + m_encoding = false; } size_t MuxLayer::mtu() const From 6f5bc6e120638378b12691f75a61fe7a2f8004dc Mon Sep 17 00:00:00 2001 From: Jochem Rutgers <68805714+jhrutgers@users.noreply.github.com> Date: Tue, 23 Dec 2025 13:53:22 +0100 Subject: [PATCH 14/15] fix empty message handling --- .vscode/c_cpp_properties.json | 3 ++- include/libstored/protocol.h | 16 ++++++++-------- src/protocol.cpp | 12 +++++++----- tests/test_protocol.cpp | 9 +++++++++ 4 files changed, 26 insertions(+), 14 deletions(-) diff --git a/.vscode/c_cpp_properties.json b/.vscode/c_cpp_properties.json index 4a1f1ce8..728f06e5 100644 --- a/.vscode/c_cpp_properties.json +++ b/.vscode/c_cpp_properties.json @@ -3,6 +3,7 @@ "includePath": [ "${workspaceFolder}/include", "${workspaceFolder}/dist/*/build/*-extern-prefix/src/**", + "${workspaceFolder}/dist/*/build/_deps/*-src/*/include", "${workspaceFolder}/examples/**", "${workspaceFolder}/tests/**" ], @@ -53,4 +54,4 @@ } ], "version": 4 -} +} \ No newline at end of file diff --git a/include/libstored/protocol.h b/include/libstored/protocol.h index 19f2290c..23ddbaf0 100644 --- a/include/libstored/protocol.h +++ b/include/libstored/protocol.h @@ -281,7 +281,7 @@ class ProtocolLayer { */ virtual void decode(void* buffer, size_t len) { - ProtocolLayer* p = up(); + ProtocolLayer* const p = up(); if(!p) return; @@ -304,7 +304,7 @@ class ProtocolLayer { */ virtual void encode(void const* buffer, size_t len, bool last = true) { - ProtocolLayer* p = down(); + ProtocolLayer* const p = down(); if(!p) return; @@ -327,7 +327,7 @@ class ProtocolLayer { */ virtual void setPurgeableResponse(bool purgeable = true) { - ProtocolLayer* p = down(); + ProtocolLayer* const p = down(); if(!p) return; @@ -345,7 +345,7 @@ class ProtocolLayer { */ virtual size_t mtu() const { - ProtocolLayer* p = down(); + ProtocolLayer const* const p = down(); if(!p) return 0; @@ -364,7 +364,7 @@ class ProtocolLayer { */ virtual bool flush() { - ProtocolLayer* p = down(); + ProtocolLayer* const p = down(); if(!p) return true; @@ -376,7 +376,7 @@ class ProtocolLayer { */ virtual void reset() { - ProtocolLayer* p = down(); + ProtocolLayer* const p = down(); if(!p) return; @@ -388,7 +388,7 @@ class ProtocolLayer { */ virtual void connected() { - ProtocolLayer* p = up(); + ProtocolLayer* const p = up(); if(!p) return; @@ -400,7 +400,7 @@ class ProtocolLayer { */ virtual void disconnected() { - ProtocolLayer* p = up(); + ProtocolLayer* const p = up(); if(!p) return; diff --git a/src/protocol.cpp b/src/protocol.cpp index 18aa2426..7c84e982 100644 --- a/src/protocol.cpp +++ b/src/protocol.cpp @@ -2169,6 +2169,7 @@ void MuxLayer::decode(void* buffer, size_t len) size_t out_start = 0; size_t out_end = 0; size_t in = 0; + bool do_decode = m_decodingEsc; while(in < len) { uint8_t b = buffer_[in++]; @@ -2178,13 +2179,16 @@ void MuxLayer::decode(void* buffer, size_t len) if(b == Esc) { // Escaped escape byte. buffer_[out_end++] = Esc; + do_decode = true; } else if(b == Repeat) { // Just a control command in between. // Repeat channel id on next encode. m_encodingChannel = Repeat; } else { // Switch channel. - decode_(buffer_ + out_start, out_end - out_start); + if(do_decode) + decode_(buffer_ + out_start, out_end - out_start); + do_decode = true; out_start = out_end = in; m_decodingChannel = channel(b); } @@ -2197,11 +2201,12 @@ void MuxLayer::decode(void* buffer, size_t len) size_t i = out_end++; if(unlikely(out_end != in)) buffer_[i] = b; + do_decode = true; } } } - if(likely(out_end > out_start)) + if(likely(out_end > out_start || do_decode)) decode_(buffer_ + out_start, out_end - out_start); } @@ -2210,9 +2215,6 @@ void MuxLayer::decode(void* buffer, size_t len) */ void MuxLayer::decode_(void* buffer, size_t len) { - if(!buffer || len == 0) - return; - if(m_decodingChannel == this) base::decode(buffer, len); else if(m_decodingChannel) diff --git a/tests/test_protocol.cpp b/tests/test_protocol.cpp index 8c4453a1..ffa939a3 100644 --- a/tests/test_protocol.cpp +++ b/tests/test_protocol.cpp @@ -1466,6 +1466,9 @@ TEST(MuxLayer, Encode) ch2.encode(" ch\x10 2", 6); EXPECT_EQ(ll.encoded().at(3), "\x10\x02 ch\x10\x10 2"); + + ch2.encode("", 0); + EXPECT_EQ(ll.encoded().at(4), "\x10\x02"); } TEST(MuxLayer, Decode) @@ -1504,6 +1507,12 @@ TEST(MuxLayer, Decode) DECODE(l, "\x10\x03 ch?"); EXPECT_EQ(ch0.decoded().size(), 3U); + + DECODE(l, "\x10\x02"); + EXPECT_EQ(ch2.decoded().at(3), ""); + + DECODE(l, "\x10\x03"); + EXPECT_EQ(ch2.decoded().size(), 4U); } TEST(Aes256Layer, EncodeDecode) From 77dbf34a67a7c560fe90ce51d7a214e5390d5fcb Mon Sep 17 00:00:00 2001 From: Jochem Rutgers <68805714+jhrutgers@users.noreply.github.com> Date: Tue, 23 Dec 2025 14:01:57 +0100 Subject: [PATCH 15/15] fix warning --- include/libstored/util.h | 236 ++++++++++++++++++++------------------- 1 file changed, 119 insertions(+), 117 deletions(-) diff --git a/include/libstored/util.h b/include/libstored/util.h index 7df7a4e9..19c0c936 100644 --- a/include/libstored/util.h +++ b/include/libstored/util.h @@ -38,7 +38,10 @@ */ #ifndef likely # ifdef __GNUC__ -# define likely(expr) __builtin_expect(!!(expr), 1) +# define likely(expr) \ + __builtin_expect( \ + !!(expr), /* NOLINT(readability-simplify-boolean-expr) */ \ + 1) # else # define likely(expr) (expr) # endif @@ -52,10 +55,10 @@ */ #ifndef unlikely # ifdef __GNUC__ -# define unlikely(expr) \ - __builtin_expect( \ - !!(expr), /* NOLINT(readability-simplify-boolean-expr) */ \ - 0) +# define unlikely(expr) \ + __builtin_expect( \ + !!(expr), /* NOLINT(readability-simplify-boolean-expr) */ \ + 0) # else # define unlikely(expr) (expr) # endif @@ -73,9 +76,9 @@ # define stored_yield() zth_yield() # endif # else -# define stored_yield() \ - do { /* NOLINT(cppcoreguidelines-avoid-do-while) */ \ - } while(false) +# define stored_yield() \ + do { /* NOLINT(cppcoreguidelines-avoid-do-while) */ \ + } while(false) # endif #endif @@ -125,11 +128,11 @@ #ifdef STORED_HAVE_VALGRIND # define STORED_MAKE_MEM_NOACCESS_VALGRIND(buffer, size) \ - (void)VALGRIND_MAKE_MEM_NOACCESS(buffer, size) + (void)VALGRIND_MAKE_MEM_NOACCESS(buffer, size) # define STORED_MAKE_MEM_UNDEFINED_VALGRIND(buffer, size) \ - (void)VALGRIND_MAKE_MEM_UNDEFINED(buffer, size) + (void)VALGRIND_MAKE_MEM_UNDEFINED(buffer, size) # define STORED_MAKE_MEM_DEFINED_VALGRIND(buffer, size) \ - (void)VALGRIND_MAKE_MEM_DEFINED(buffer, size) + (void)VALGRIND_MAKE_MEM_DEFINED(buffer, size) #else // !STORED_HAVE_VALGRIND # define STORED_MAKE_MEM_NOACCESS_VALGRIND(buffer, size) (void)0 # define STORED_MAKE_MEM_UNDEFINED_VALGRIND(buffer, size) (void)0 @@ -151,40 +154,40 @@ #if !defined(NDEBUG) \ && ((defined(STORED_HAVE_VALGRIND) && !defined(NVALGRIND)) || defined(STORED_ENABLE_ASAN)) -# define STORED_MAKE_MEM_NOACCESS(buffer, size) \ - do { /* NOLINT(cppcoreguidelines-avoid-do-while) */ \ - void* b_ = (void*)(buffer); \ - size_t s_ = (size_t)(size); \ - STORED_MAKE_MEM_NOACCESS_VALGRIND(b_, s_); \ - STORED_MAKE_MEM_NOACCESS_ASAN(b_, s_); \ - } while(0) - -# define STORED_MAKE_MEM_UNDEFINED(buffer, size) \ - do { /* NOLINT(cppcoreguidelines-avoid-do-while) */ \ - void* b_ = (void*)(buffer); \ - size_t s_ = (size_t)(size); \ - STORED_MAKE_MEM_UNDEFINED_VALGRIND(b_, s_); \ - STORED_MAKE_MEM_UNDEFINED_ASAN(b_, s_); \ - if(Config::Debug && !RUNNING_ON_VALGRIND && b_) \ - memset(b_, 0xef, s_); \ - } while(0) -# define STORED_MAKE_MEM_DEFINED(buffer, size) \ - do { /* NOLINT(cppcoreguidelines-avoid-do-while) */ \ - void* b_ = (void*)(buffer); \ - size_t s_ = (size_t)(size); \ - STORED_MAKE_MEM_DEFINED_VALGRIND(b_, s_); \ - STORED_MAKE_MEM_DEFINED_ASAN(b_, s_); \ - } while(0) +# define STORED_MAKE_MEM_NOACCESS(buffer, size) \ + do { /* NOLINT(cppcoreguidelines-avoid-do-while) */ \ + void* b_ = (void*)(buffer); \ + size_t s_ = (size_t)(size); \ + STORED_MAKE_MEM_NOACCESS_VALGRIND(b_, s_); \ + STORED_MAKE_MEM_NOACCESS_ASAN(b_, s_); \ + } while(0) + +# define STORED_MAKE_MEM_UNDEFINED(buffer, size) \ + do { /* NOLINT(cppcoreguidelines-avoid-do-while) */ \ + void* b_ = (void*)(buffer); \ + size_t s_ = (size_t)(size); \ + STORED_MAKE_MEM_UNDEFINED_VALGRIND(b_, s_); \ + STORED_MAKE_MEM_UNDEFINED_ASAN(b_, s_); \ + if(Config::Debug && !RUNNING_ON_VALGRIND && b_) \ + memset(b_, 0xef, s_); \ + } while(0) +# define STORED_MAKE_MEM_DEFINED(buffer, size) \ + do { /* NOLINT(cppcoreguidelines-avoid-do-while) */ \ + void* b_ = (void*)(buffer); \ + size_t s_ = (size_t)(size); \ + STORED_MAKE_MEM_DEFINED_VALGRIND(b_, s_); \ + STORED_MAKE_MEM_DEFINED_ASAN(b_, s_); \ + } while(0) #else -# define STORED_MAKE_MEM_NOACCESS(buffer, size) \ - do { /* NOLINT(cppcoreguidelines-avoid-do-while) */ \ - } while(0) -# define STORED_MAKE_MEM_UNDEFINED(buffer, size) \ - do { /* NOLINT(cppcoreguidelines-avoid-do-while) */ \ - } while(0) -# define STORED_MAKE_MEM_DEFINED(buffer, size) \ - do { /* NOLINT(cppcoreguidelines-avoid-do-while) */ \ - } while(0) +# define STORED_MAKE_MEM_NOACCESS(buffer, size) \ + do { /* NOLINT(cppcoreguidelines-avoid-do-while) */ \ + } while(0) +# define STORED_MAKE_MEM_UNDEFINED(buffer, size) \ + do { /* NOLINT(cppcoreguidelines-avoid-do-while) */ \ + } while(0) +# define STORED_MAKE_MEM_DEFINED(buffer, size) \ + do { /* NOLINT(cppcoreguidelines-avoid-do-while) */ \ + } while(0) #endif @@ -208,7 +211,7 @@ */ # ifndef DOXYGEN # define SFINAE_IS_FUNCTION(T, F, T_OK) \ - typename std::enable_if, T>::value, T_OK>::type + typename std::enable_if, T>::value, T_OK>::type # else # define SFINAE_IS_FUNCTION(T, F, T_OK) T_OK # endif @@ -218,10 +221,10 @@ # if defined(STORED_cplusplus) && STORED_cplusplus < 201103L && !defined(static_assert) \ && !defined(STORED_COMPILER_MSVC) -# define static_assert(expr, msg) \ - do { /* NOLINT(cppcoreguidelines-avoid-do-while) */ \ - typedef __attribute__((unused)) int static_assert_[(expr) ? 1 : -1]; \ - } while(0) +# define static_assert(expr, msg) \ + do { /* NOLINT(cppcoreguidelines-avoid-do-while) */ \ + typedef __attribute__((unused)) int static_assert_[(expr) ? 1 : -1]; \ + } while(0) # endif # ifndef STORED_CLASS_NOCOPY @@ -236,35 +239,35 @@ * \param Class the class this macro is embedded in */ # if STORED_cplusplus >= 201103L -# define STORED_CLASS_NOCOPY(Class) \ - public: \ - /*! \brief Deleted copy constructor. */ \ - Class(Class const&) = delete; \ - /*! \brief Default move constructor. */ \ - Class(Class&&) noexcept = default; /* NOLINT */ \ - /*! \brief Deleted assignment operator. */ \ - void operator=(Class const&) = delete; \ - /*! \brief Default move assignment operator. */ \ - Class& operator=(Class&&) noexcept = default; /* NOLINT */ +# define STORED_CLASS_NOCOPY(Class) \ +public: \ + /*! \brief Deleted copy constructor. */ \ + Class(Class const&) = delete; \ + /*! \brief Default move constructor. */ \ + Class(Class&&) noexcept = default; /* NOLINT */ \ + /*! \brief Deleted assignment operator. */ \ + void operator=(Class const&) = delete; \ + /*! \brief Default move assignment operator. */ \ + Class& operator=(Class&&) noexcept = default; /* NOLINT */ # else -# define STORED_CLASS_NOCOPY(Class) \ - private: \ - /*! \brief Deleted copy constructor. */ \ - Class(Class const&); \ - /*! \brief Deleted assignment operator. */ \ - void operator=(Class const&); +# define STORED_CLASS_NOCOPY(Class) \ +private: \ + /*! \brief Deleted copy constructor. */ \ + Class(Class const&); \ + /*! \brief Deleted assignment operator. */ \ + void operator=(Class const&); # endif # endif # if STORED_cplusplus >= 201103L && !defined(STORED_CLASS_DEFAULT_COPY_MOVE) -# define STORED_CLASS_DEFAULT_COPY_MOVE(type) \ - public: \ - type(type const&) = default; \ - type(type&&) noexcept = default; \ - type& operator=(type const&) = default; \ - type& operator=(type&&) noexcept = default; \ - \ - private: +# define STORED_CLASS_DEFAULT_COPY_MOVE(type) \ + public: \ + type(type const&) = default; \ + type(type&&) noexcept = default; \ + type& operator=(type const&) = default; \ + type& operator=(type&&) noexcept = default; \ + \ + private: # endif # ifndef CLASS_NO_WEAK_VTABLE @@ -273,19 +276,19 @@ * \details Use CLASS_NO_WEAK_VTABLE_DEF() in one .cpp file. */ # define CLASS_NO_WEAK_VTABLE \ - protected: \ - void force_to_translation_unit(); + protected: \ + void force_to_translation_unit(); /*! * \see CLASS_NO_WEAK_VTABLE */ // cppcheck-suppress-macro duplInheritedMember -# define CLASS_NO_WEAK_VTABLE_DEF(Class) \ - /*! \brief Dummy function to force the vtable of this \ - * class to this translation unit. Don't call. */ \ - void Class::force_to_translation_unit() \ - { \ - abort(); \ - } +# define CLASS_NO_WEAK_VTABLE_DEF(Class) \ + /*! \brief Dummy function to force the vtable of this \ + * class to this translation unit. Don't call. */ \ + void Class::force_to_translation_unit() \ + { \ + abort(); \ + } # endif namespace stored { @@ -294,19 +297,19 @@ namespace stored { * \brief Like \c assert(), but only emits code when #stored::Config::EnableAssert. */ # ifdef STORED_HAVE_ZTH -# define stored_assert(expr) \ - do { /* NOLINT(cppcoreguidelines-avoid-do-while) */ \ - if(::stored::Config::EnableAssert) { \ - zth_assert(expr); \ - } \ - } while(false) +# define stored_assert(expr) \ + do { /* NOLINT(cppcoreguidelines-avoid-do-while) */ \ + if(::stored::Config::EnableAssert) { \ + zth_assert(expr); \ + } \ + } while(false) # else -# define stored_assert(expr) \ - do { /* NOLINT(cppcoreguidelines-avoid-do-while) */ \ - if(::stored::Config::EnableAssert) { \ - assert(expr); \ - } \ - } while(false) +# define stored_assert(expr) \ + do { /* NOLINT(cppcoreguidelines-avoid-do-while) */ \ + if(::stored::Config::EnableAssert) { \ + assert(expr); \ + } \ + } while(false) # endif void swap_endian(void* buffer, size_t len) noexcept; @@ -835,10 +838,10 @@ using store_t = typename store::type; * * \hideinitializer */ -# define STORE_T(Impl, ...) \ - EXPAND(STORED_GET_MACRO_ARGN( \ - Impl, ##__VA_ARGS__, STORE_T_9, STORE_T_8, STORE_T_7, STORE_T_6, STORE_T_5, \ - STORE_T_4, STORE_T_3, STORE_T_2, STORE_T_1)(Impl, ##__VA_ARGS__)) +# define STORE_T(Impl, ...) \ + EXPAND(STORED_GET_MACRO_ARGN( \ + Impl, ##__VA_ARGS__, STORE_T_9, STORE_T_8, STORE_T_7, STORE_T_6, STORE_T_5, STORE_T_4, \ + STORE_T_3, STORE_T_2, STORE_T_1)(Impl, ##__VA_ARGS__)) // Make sure to match the number of template arguments of stored::store. # define STORE_BASE_CLASS_1(x) x @@ -851,12 +854,11 @@ using store_t = typename store::type; # define STORE_BASE_CLASS_8(x, ...) EXPAND(STORE_BASE_CLASS_7(__VA_ARGS__)) # define STORE_BASE_CLASS_9(x, ...) EXPAND(STORE_BASE_CLASS_8(__VA_ARGS__)) -# define STORE_CLASS_BASE(Impl, ...) \ - EXPAND(STORED_GET_MACRO_ARGN( \ - 0, Impl, ##__VA_ARGS__, STORE_BASE_CLASS_9, STORE_BASE_CLASS_8, \ - STORE_BASE_CLASS_7, STORE_BASE_CLASS_6, STORE_BASE_CLASS_5, STORE_BASE_CLASS_4, \ - STORE_BASE_CLASS_3, STORE_BASE_CLASS_2, \ - STORE_BASE_CLASS_1)(Impl, ##__VA_ARGS__)) +# define STORE_CLASS_BASE(Impl, ...) \ + EXPAND(STORED_GET_MACRO_ARGN( \ + 0, Impl, ##__VA_ARGS__, STORE_BASE_CLASS_9, STORE_BASE_CLASS_8, STORE_BASE_CLASS_7, \ + STORE_BASE_CLASS_6, STORE_BASE_CLASS_5, STORE_BASE_CLASS_4, STORE_BASE_CLASS_3, \ + STORE_BASE_CLASS_2, STORE_BASE_CLASS_1)(Impl, ##__VA_ARGS__)) # ifdef STORED_COMPILER_MSVC // https://developercommunity.visualstudio.com/t/compile-error-when-using-using-declaration-referen/486683 @@ -866,25 +868,25 @@ using store_t = typename store::type; # define STORE_CLASS_USING_BASE_TYPE(type, ...) using typename __VA_ARGS__::type; # endif -# define STORE_CLASS_(Impl, ...) \ - STORED_CLASS_NOCOPY(Impl) \ - STORED_CLASS_NEW_DELETE(Impl) \ - public: \ - typedef Impl self; \ - typedef __VA_ARGS__ base; \ - STORE_CLASS_USING_BASE_TYPE(root, __VA_ARGS__) \ - STORE_CLASS_USING_BASE_TYPE(Implementation, __VA_ARGS__) \ - \ - private: +# define STORE_CLASS_(Impl, ...) \ + STORED_CLASS_NOCOPY(Impl) \ + STORED_CLASS_NEW_DELETE(Impl) \ + public: \ + typedef Impl self; \ + typedef __VA_ARGS__ base; \ + STORE_CLASS_USING_BASE_TYPE(root, __VA_ARGS__) \ + STORE_CLASS_USING_BASE_TYPE(Implementation, __VA_ARGS__) \ + \ + private: /*! * \brief Class helper macro to get a store implementation class right. * \see #stored::store * \hideinitializer */ -# define STORE_CLASS(Impl, ...) \ - STORE_CLASS_(Impl, STORE_T(Impl, __VA_ARGS__)) \ - friend class STORE_CLASS_BASE(Impl, ##__VA_ARGS__); +# define STORE_CLASS(Impl, ...) \ + STORE_CLASS_(Impl, STORE_T(Impl, __VA_ARGS__)) \ + friend class STORE_CLASS_BASE(Impl, ##__VA_ARGS__); /*! * \brief Class helper macro to get a store wrapper class right.