From f9753eaac6646e9495bac71d6d524c1393486869 Mon Sep 17 00:00:00 2001 From: CHAO ZHOU Date: Sun, 9 Nov 2025 01:08:07 -0500 Subject: [PATCH 1/2] ported over the code for threaded server that uses the router/dealer zmq pattern; changed default unit for timeout from milliseconds to seconds - also made `proxy.Client` cache instrument blueprints --- instrumentserver/base.py | 43 +++- instrumentserver/blueprints.py | 3 +- instrumentserver/client/core.py | 88 ++++--- instrumentserver/client/proxy.py | 107 +++++++-- instrumentserver/server/application.py | 6 +- instrumentserver/server/core.py | 214 ++++++++++++------ .../testing/dummy_instruments/generic.py | 14 +- .../testing/test_async_requests/__init__.py | 0 .../test_async_requests/test_client.py | 39 ++++ setup.py | 10 +- 10 files changed, 384 insertions(+), 140 deletions(-) create mode 100644 instrumentserver/testing/test_async_requests/__init__.py create mode 100644 instrumentserver/testing/test_async_requests/test_client.py diff --git a/instrumentserver/base.py b/instrumentserver/base.py index dbb36ca..5b489b0 100644 --- a/instrumentserver/base.py +++ b/instrumentserver/base.py @@ -1,8 +1,10 @@ import zmq import json +import logging from .blueprints import to_dict, deserialize_obj +logger = logging.getLogger(__name__) def encode(data): return json.dumps(to_dict(data)) @@ -12,12 +14,45 @@ def decode(data): return deserialize_obj(json.loads(data)) -def send(socket, data): - return socket.send_string(encode(data)) +def send(socket, data, use_string=True): + payload = encode(data) + if use_string: + return socket.send_string(payload) + else: + return socket.send(payload.encode('utf-8')) def recv(socket): - return decode(socket.recv_string()) + # Try multipart receive first (ROUTER replies) + parts = socket.recv_multipart() + while socket.getsockopt(zmq.RCVMORE): + leftover = socket.recv() + logger.warning(f"Additional part found in recv: {leftover}") + if len(parts) == 1: + data = parts[0] + elif len(parts) == 2 and 
parts[0] == b'': # optional empty delimiter + data = parts[1] + else: + data = parts[-1] # assume last part is the actual message + return decode(data) + + +def send_router(socket, identity, message): + socket.setsockopt(zmq.SNDTIMEO, 5000) + socket.setsockopt(zmq.LINGER, 0) + payload = encode(message).encode('utf-8') + socket.send_multipart([identity, b'', payload]) + + +def recv_router(socket): + parts = socket.recv_multipart() + if len(parts) == 2: + identity, payload = parts + elif len(parts) == 3 and parts[1] == b'': + identity, payload = parts[0], parts[2] + else: + raise ValueError(f"Malformed ROUTER message: {parts}") + return identity, decode(payload) def sendBroadcast(socket, name, message): @@ -30,7 +65,7 @@ def sendBroadcast(socket, name, message): :param messages: The data to send. """ socket.send_string(name, flags=zmq.SNDMORE) - socket.send_string(encode(message)) + socket.send(encode(message).encode('utf-8')) def recvMultipart(socket): diff --git a/instrumentserver/blueprints.py b/instrumentserver/blueprints.py index f51e7dd..e8524dd 100644 --- a/instrumentserver/blueprints.py +++ b/instrumentserver/blueprints.py @@ -59,9 +59,8 @@ from typing import Union, Optional, List, Dict, Callable, Tuple, Any, get_args, cast import numpy as np -import qcodes as qc from qcodes import ( - Station, Instrument, InstrumentChannel, Parameter, ParameterWithSetpoints, ChannelTuple) + Station, Instrument, InstrumentChannel, Parameter, ParameterWithSetpoints) from qcodes.instrument.base import InstrumentBase from qcodes.utils.validators import Validator diff --git a/instrumentserver/client/core.py b/instrumentserver/client/core.py index 7757596..27f65fb 100644 --- a/instrumentserver/client/core.py +++ b/instrumentserver/client/core.py @@ -1,8 +1,9 @@ import logging import warnings import zmq +import uuid -from instrumentserver import DEFAULT_PORT, QtCore +from instrumentserver import DEFAULT_PORT from instrumentserver.base import send, recv from 
instrumentserver.server.core import ServerResponse @@ -10,8 +11,6 @@ logger = logging.getLogger(__name__) -# TODO: allow for the client to operate as context manager. - class BaseClient: """Simple client for the StationServer. @@ -21,12 +20,12 @@ class BaseClient: :param host: The host address of the server, defaults to localhost. :param port: The port of the server, defaults to the value of DEFAULT_PORT. :param connect: If true, the server connects as it is being constructed, defaults to True. - :param timeout: Amount of time that the client waits for an answer before declaring timeout in ms. - Defaults to 5000. + :param timeout: Amount of time that the client waits for an answer before declaring timeout in seconds. + Defaults to 20s. :param raise_exceptions: If true the client will raise an exception when the server sends one to it, defaults to True. """ - def __init__(self, host='localhost', port=DEFAULT_PORT, connect=True, timeout=5000, raise_exceptions=True): + def __init__(self, host='localhost', port=DEFAULT_PORT, connect=True, timeout=20, raise_exceptions=True): self.connected = False self.context = None self.socket = None @@ -34,8 +33,7 @@ def __init__(self, host='localhost', port=DEFAULT_PORT, connect=True, timeout=50 self.port = port self.addr = f"tcp://{host}:{port}" self.raise_exceptions = raise_exceptions - #: Timeout for server replies. - self.recv_timeout = timeout + self.recv_timeout_ms = int(timeout * 1e3) if connect: self.connect() @@ -51,8 +49,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): def connect(self): logger.info(f"Connecting to {self.addr}") self.context = zmq.Context() - self.socket = self.context.socket(zmq.REQ) - self.socket.setsockopt(zmq.RCVTIMEO, self.recv_timeout) + self.socket = self.context.socket(zmq.DEALER) + self.socket.setsockopt(zmq.RCVTIMEO, self.recv_timeout_ms) + self.socket.setsockopt(zmq.IDENTITY, uuid.uuid4().hex.encode()) #todo: more meaningful id? 
self.socket.connect(self.addr) self.connected = True @@ -66,37 +65,54 @@ def ask(self, message): ret = recv(self.socket) logger.debug(f"Response received.") logger.debug(f"Response: {str(ret)}") - - if isinstance(ret, ServerResponse): - err = ret.error - if err is not None: - if isinstance(err, str): - logger.error(err) - elif isinstance(err, Warning): - warnings.warn(err) - elif isinstance(err, Exception): - if self.raise_exceptions: - raise err - else: - logger.error(f'Server raised the following exception: {err}') - else: - if self.raise_exceptions: - raise TypeError(f'Unknown Error Type: {str(err)}') - else: - logger.error(f'Unknown Error Type: {str(err)}') + except zmq.error.Again: + self._reset_connection() + if self.raise_exceptions: + raise RuntimeError("Server did not reply before timeout.") + else: + logger.error("Server did not reply before timeout.") + return None + + if isinstance(ret, ServerResponse): + err = ret.error + if err is not None: + self._handle_server_error(err) return ret.message - except zmq.error.Again as e: - # if there is a timeout, close the socket and connect again - self.socket.close() + return ret + + def _reset_connection(self): + try: + if self.socket is not None: + self.socket.close(linger=0) + finally: + self.connected = False self.connect() + + def _handle_server_error(self, err): + if isinstance(err, str): + logger.error(err) if self.raise_exceptions: - raise RuntimeError(f'Server did not reply before timeout.') - else: - logger.error(f'Server did not reply before timeout.') - + raise RuntimeError(err) + elif isinstance(err, Warning): + warnings.warn(err) + elif isinstance(err, Exception): + if self.raise_exceptions: + raise err + logger.error(f"Server raised exception: {err}") + else: + msg = f"Unknown error type from server: {err!r}" + if self.raise_exceptions: + raise TypeError(msg) + logger.error(msg) + def disconnect(self): - self.socket.close() + if self.socket is not None: + try: + self.socket.close(linger=0) + except 
Exception: + pass + self.socket = None self.connected = False diff --git a/instrumentserver/client/proxy.py b/instrumentserver/client/proxy.py index 6a7b0cf..f4e8423 100644 --- a/instrumentserver/client/proxy.py +++ b/instrumentserver/client/proxy.py @@ -6,10 +6,13 @@ """ import inspect import json +import yaml import logging import os from types import MethodType from typing import Any, Union, Optional, Dict, List +import threading +from contextlib import contextmanager import qcodes as qc import zmq @@ -60,7 +63,10 @@ def __init__(self, *args, if remotePath is not None and bluePrint is None: self.remotePath = remotePath - self.bp = self._getBluePrintFromServer(self.remotePath) + if self.cli is None: + self.bp = self._getBluePrintFromServer(self.remotePath) + else: + self.bp = self.cli.getBluePrint(self.remotePath) elif bluePrint is not None: self.bp = bluePrint self.remotePath = self.bp.path @@ -167,11 +173,11 @@ def _remoteGet(self): class ProxyInstrumentModule(ProxyMixin, InstrumentBase): - """Construct a proxy module using the given blue print. Each proxy + """Construct a proxy module using the given blueprint. Each proxy instantiation represents a virtual module (instrument of submodule of instrument). - :param bluePrint: The blue print that the describes the module. + :param bluePrint: The blueprint that describes the module. :param host: The name of the host where the server lives. :param port: The port number of the server. """ @@ -192,24 +198,54 @@ def __init__(self, name: str, *args, if cli is None: self.cli = Client(host=host, port=port) + for mn in self.bp.methods.keys(): + if mn == 'remove_parameter': + def remove_parameter(obj, name: str): + obj.cli.call(f'{obj.remotePath}.remove_parameter', name) + obj.update() + + self.remove_parameter = MethodType(remove_parameter, self) + self.parameters.pop('IDN', None) # we will redefine this later # When a new parameter or method is added to client, qcodes checks if that item exists or not. 
This is done # by calling __getattr__ method. The problem is that when that method gets called and cannot find that item it # creates it, generating an infinite loop. This flag stops that. It should be set to True before doing any change # to the proxy object and set to False after the change is done. - self.is_updating = True - self.update() self.is_updating = False + with self._updating(): + self.update() + + @contextmanager + def _updating(self): + old = self.is_updating + self.is_updating = True + try: + yield + finally: + self.is_updating = old + def initKwargsFromBluePrint(self, bp): return {} def update(self): + self.cli.invalidateBlueprint(self.remotePath) self.bp = self.cli.getBluePrint(self.remotePath) self._getProxyParameters() self._getProxyMethods() self._getProxySubmodules() + + def set_parameters(self, **param_dict:dict): + """ + Set instrument parameters in batch with a dict, keyed by parameter names. + + """ + for k, v in param_dict.items(): + try: + self.parameters[k](v) + except KeyError: + raise KeyError(f"{self.bp.instrument_module_class} instrument does not have parameter '{k}'") def add_parameter(self, name: str, *arg, **kw): """Add a parameter to the proxy instrument. 
@@ -255,10 +291,9 @@ def _getProxyParameters(self) -> None: for pn, p in self.bp.parameters.items(): if pn not in self.parameters: pbp = self.cli.getBluePrint(f"{self.remotePath}.{pn}") - self.is_updating = True - super().add_parameter(pbp.name, ProxyParameter, cli=self.cli, host=self.host, - port=self.port, bluePrint=pbp, setpoints_instrument=self) - self.is_updating = False + with self._updating(): + super().add_parameter(pbp.name, ProxyParameter, cli=self.cli, host=self.host, + port=self.port, bluePrint=pbp, setpoints_instrument=self) delKeys = [] for pn in self.parameters.keys(): @@ -275,11 +310,10 @@ def _getProxyMethods(self): """ for n, m in self.bp.methods.items(): if not hasattr(self, n): - self.is_updating = True - fun = self._makeProxyMethod(m) - setattr(self, n, MethodType(fun, self)) - self.functions[n] = getattr(self, n) - self.is_updating = False + with self._updating(): + fun = self._makeProxyMethod(m) + setattr(self, n, MethodType(fun, self)) + self.functions[n] = getattr(self, n) def _makeProxyMethod(self, bp: MethodBluePrint): def wrap(*a, **k): @@ -313,7 +347,7 @@ def wrap(*a, **k): # make sure the method knows the wrap function. # TODO: this is not complete! globs = {'wrap': wrap, 'qcodes': qc} - exec(new_func_str, globs) + _ret = exec(new_func_str, globs) fun = globs[bp.name] fun.__doc__ = bp.docstring return globs[bp.name] @@ -378,12 +412,20 @@ def __getattr__(self, item): class Client(BaseClient): """Client with common server requests as convenience functions.""" + def __init__(self, host='localhost', port=DEFAULT_PORT, connect=True, timeout=20, raise_exceptions=True): + super().__init__(host, port, connect, timeout, raise_exceptions) + self._bp_cache = {} + self._bp_cache_lock = threading.Lock() def list_instruments(self) -> Dict[str, str]: """ Get the existing instruments on the server. 
""" - msg = ServerInstruction(operation=Operation.get_existing_instruments) - return self.ask(msg) + message = ServerInstruction(operation=Operation.get_existing_instruments) + try: + return self.ask(message) + except Exception as e: + logger.error(f"Failed to send or receive message to server at {self.host}:{self.port}", exc_info=True) + raise RuntimeError("Communication with server failed. See logs for details.") from e def find_or_create_instrument(self, name: str, instrument_class: Optional[str] = None, *args: Any, **kwargs: Any) -> ProxyInstrumentModule: @@ -439,11 +481,38 @@ def get_instrument(self, name): return ProxyInstrumentModule(name=name, cli=self, remotePath=name) def getBluePrint(self, path): + """ + get blueprint from server + :param path: + :return: + """ + with self._bp_cache_lock: + bp = self._bp_cache.get(path) + if bp is not None: + return bp + msg = ServerInstruction( operation=Operation.get_blueprint, requested_path=path, ) - return self.ask(msg) + bp = self.ask(msg) + with self._bp_cache_lock: + self._bp_cache[path] = bp + return bp + + def invalidateBlueprint(self, path=None): + """ + invalidate a parameter in the blueprint cache + :param path: + :return: + """ + with self._bp_cache_lock: + if path is None: + self._bp_cache.clear() + else: + for k in list(self._bp_cache): + if k == path or k.startswith(path + '.'): + del self._bp_cache[k] def get_snapshot(self, instrument: str | None = None, *args, **kwargs): msg = ServerInstruction( @@ -564,7 +633,7 @@ def __init__(self, parent=None, host='localhost', port=DEFAULT_PORT, connect=True, - timeout=5000, + timeout=5, raise_exceptions=True): # Calling the parents like this ensures that the arguments arrive to the parents properly. 
_QtAdapter.__init__(self, parent=parent) diff --git a/instrumentserver/server/application.py b/instrumentserver/server/application.py index f963680..e4ff90f 100644 --- a/instrumentserver/server/application.py +++ b/instrumentserver/server/application.py @@ -510,8 +510,8 @@ def __init__(self, startServer: Optional[bool] = True, self.setWindowTitle('Instrument server') # A test client, just a simple helper object. - self.client = EmbeddedClient(raise_exceptions=False, timeout=5000000) - self.client.recv_timeout = 10_000 + self.client = EmbeddedClient(raise_exceptions=False, timeout=5000) + self.client.recv_timeout_ms = 10_000 # Central widget is simply a tab container. self.tabs = DetachableTabWidget(self) @@ -910,7 +910,7 @@ def parameterToHtml(bp: ParameterBluePrint, headerLevel=None): # FIXME: We deleted the validator since there is no real easy way of deserializing them. It would be a good idea to # have them here though #
  • Validator: {html.escape(str(bp.vals))}
  • - var = """
  • Doc: {html.escape(str(bp.docstring))}
  • + var = f"""
  • Doc: {html.escape(str(bp.docstring))}
  • """ diff --git a/instrumentserver/server/core.py b/instrumentserver/server/core.py index cd19688..72d0f7b 100644 --- a/instrumentserver/server/core.py +++ b/instrumentserver/server/core.py @@ -13,16 +13,22 @@ # operations for adding parameters/submodules/functions # TODO: can we also create methods remotely? +# TODO: client white list + import os import importlib -import inspect +import json import logging import random -import json +import queue +import socket + from pathlib import Path from dataclasses import dataclass, field, fields from enum import Enum, unique from typing import Dict, Any, Union, Optional, Tuple, List, Callable +from concurrent.futures import ThreadPoolExecutor +import threading import zmq @@ -38,7 +44,7 @@ INSTRUMENT_MODULE_BASE_CLASSES, PARAMETER_BASE_CLASSES, Operation, InstrumentCreationSpec, CallSpec, ParameterSerializeSpec, ServerInstruction, ServerResponse,) -from ..base import send, recv, sendBroadcast +from ..base import send_router, recv_router, sendBroadcast from ..helpers import nestedAttributeFromString, objectClassPath, typeClassPath __author__ = 'Wolfgang Pfaff', 'Chao Zhou' @@ -156,6 +162,13 @@ def __init__(self, f"'{n}', args: {str(args)}, " f"kwargs: {str(kw)})'.") ) + + # a queue for responses that are ready to be sent to client + self._response_queue = queue.Queue() + # a socket pair for immediate wakeup of the main thread that sends response to client + self._wakeup_r, self._wakeup_w = socket.socketpair() + self._wakeup_r.setblocking(False) + self._wakeup_w.setblocking(False) def _runInitScript(self): if os.path.exists(self.initScript): @@ -173,12 +186,17 @@ def startServer(self) -> bool: logger.info(f"Starting server.") logger.info(f"The safe word is: {self.SAFEWORD}") context = zmq.Context() - socket = context.socket(zmq.REP) + socket = context.socket(zmq.ROUTER) + # make a zmq poller for detecting activate sockets + poller = zmq.Poller() + poller.register(socket, zmq.POLLIN) + poller.register(self._wakeup_r, 
zmq.POLLIN) for a in self.listenAddresses: addr = f"tcp://{a}:{self.port}" socket.bind(addr) logger.info(f"Listening at {addr}") + self.serverStarted.emit(addr) # creating and binding publishing socket to broadcast changes broadcastAddr = f"tcp://*:{self.broadcastPort}" @@ -197,81 +215,132 @@ def startServer(self) -> bool: if self.initScript not in ['', None]: logger.info(f"Running init script") self._runInitScript() - self.serverStarted.emit(addr) - - while self.serverRunning: - message = recv(socket) - message_ok = True - response_to_client: ServerResponse | tuple[ServerResponse, str] | None = None - response_log = None - - # Allow the test client from within the same process to make sure the - # server shuts down. This is - if message == self.SAFEWORD: - response_log = 'Server has received the safeword and will shut down.' - response_to_client = ServerResponse(message=response_log) - self.serverRunning = False - logger.warning(response_log) - - elif self.allowUserShutdown and message == 'SHUTDOWN': - response_log = 'Server shutdown requested by client.' - response_to_client = ServerResponse(message=response_log) - self.serverRunning = False - logger.warning(response_log) - - # If the message is a string we just echo it back. - # This is used for testing sometimes, but has no functionality. - elif isinstance(message, str): - response_log = f"Server has received: {message}. No further action." - response_to_client = ServerResponse(message=response_log) - logger.debug(response_log) - - # We assume this is a valid instruction set now. 
- elif isinstance(message, ServerInstruction): - instruction = message + + # create a thread pool for handling incoming client requests concurrently + with ThreadPoolExecutor() as pool: + while self.serverRunning or not self._response_queue.empty(): try: - instruction.validate() - logger.debug(f"Received request for operation: " - f"{str(instruction.operation)}") - logger.debug(f"Instruction received: " - f"{str(instruction)}") + # check if there is either incoming request from client, or a processing worker has finished + socks = dict(poller.poll(10)) + + # handle router socket events (incoming requests) + if self.serverRunning and socket in socks and (socks[socket] & zmq.POLLIN): + identity, message = recv_router(socket) + pool.submit(self._handleRouterMessage, identity, message) + + # handle wakeup events (one or more workers finished) + if self._wakeup_r in socks and (socks[self._wakeup_r] & zmq.POLLIN): + # Drain the wakeup pipe so it doesn't stay "always readable" + try: + # Read whatever is there; content doesn't matter + self._wakeup_r.recv(1024) + except BlockingIOError: + pass + + # drain completed responses from workers + while True: + try: + identity, response_to_client, response_log, shutdown = self._response_queue.get_nowait() + except queue.Empty: + break + + try: + send_router(socket, identity, response_to_client) + except Exception as e: + logger.error(f"Failed to send response to client: {e}") + + # emit log signal + self.messageReceived.emit(str(response_to_client.message), response_log) + + # flip the shutdown flag in the main thread + if shutdown: + self.serverRunning = False + except Exception as e: - message_ok = False - response_log = f'Received invalid message. Error raised: {str(e)}' - response_to_client = ServerResponse(message=None, error=e) - logger.warning(response_log) - - if message_ok: - # We don't need to use a try-block here, because - # errors are already handled in executeServerInstruction. 
- response_to_client = self.executeServerInstruction(instruction) - response_log = f"Response to client: {str(response_to_client)}" - if response_to_client.error is None: - logger.debug(f"Response sent to client.") - logger.debug(response_log) - else: - logger.warning(response_log) - - else: - response_log = f"Invalid message type." - response_to_client = ServerResponse(message=None, error=response_log) - logger.warning(f"Invalid message type: {type(message)}.") - logger.debug(f"Invalid message received: {str(message)}") - - send(socket, response_to_client) - - self.messageReceived.emit(str(message), response_log) - - if self.pollingThread is not None and isinstance(self.pollingThread,QtCore.QThread): - self.pollingThread.quit() - logger.info("Polling thread finished") + logger.exception(f"Unexpected error in server loop: {e}") + break - self.broadcastSocket.close() socket.close() + self._wakeup_r.close() + self._wakeup_w.close() + self.broadcastSocket.close() self.finished.emit() + logger.info("StationServer shut down cleanly.") return True + + def _handleRouterMessage(self, identity, message): + """ + Handle a router message and put the response message in the response queue. + + """ + message_ok = True + response_to_client = None + response_log = None + shutdown = False # flag for letting the main thread shut down the server + + # Allow the test client from within the same process to make sure the + # server shuts down. + if message == self.SAFEWORD: + response_log = 'Server has received the safeword and will shut down.' + response_to_client = ServerResponse(message=response_log) + shutdown = True + logger.warning(response_log) + + elif self.allowUserShutdown and message == 'SHUTDOWN': + response_log = 'Server shutdown requested by client.' + response_to_client = ServerResponse(message=response_log) + shutdown = True + logger.warning(response_log) + + # If the message is a string we just echo it back. 
+ # This is used for testing sometimes, but has no functionality. + elif isinstance(message, str): + response_log = f"Server has received: {message}. No further action." + response_to_client = ServerResponse(message=response_log) + logger.debug(response_log) + + # We assume this is a valid instruction set now. + elif isinstance(message, ServerInstruction): + instruction = message + try: + instruction.validate() + logger.debug(f"Received request for operation: " + f"{str(instruction.operation)}") + logger.debug(f"Instruction received: " + f"{str(instruction)}") + except Exception as e: + message_ok = False + response_log = f'Received invalid message. Error raised: {str(e)}' + response_to_client = ServerResponse(message=None, error=e) + logger.warning(response_log) + + if message_ok: + # We don't need to use a try-block here, because + # errors are already handled in executeServerInstruction. + response_to_client = self.executeServerInstruction(instruction) + response_log = f"Response to client: {str(response_to_client)}" + if response_to_client.error is None: + logger.debug(f"Response sent to client.") + logger.debug(response_log) + else: + logger.warning(response_log) - def executeServerInstruction(self, instruction: ServerInstruction) -> ServerResponse: + else: + response_log = f"Invalid message type." 
+ response_to_client = ServerResponse(message=None, error=response_log) + logger.warning(f"Invalid message type: {type(message)}.") + logger.debug(f"Invalid message received: {str(message)}") + + self._response_queue.put((identity, response_to_client, response_log, shutdown)) + # wake up the server loop so it can send the response immediately + try: + self._wakeup_w.send(b"\0") + except OSError: + # If we're shutting down / socket closed, ignore + pass + + def executeServerInstruction(self, instruction: ServerInstruction) \ + -> Tuple[ServerResponse, str]: """ This is the interpreter function that the server will call to translate the dictionary received from the proxy to instrument calls. @@ -376,6 +445,7 @@ def _callObject(self, spec: CallSpec) -> Any: def _getBluePrint(self, path: str) -> Union[InstrumentModuleBluePrint, ParameterBluePrint, MethodBluePrint]: + logger.debug(f"Fetching blueprint for: {path}") obj = nestedAttributeFromString(self.station, path) if isinstance(obj, tuple(INSTRUMENT_MODULE_BASE_CLASSES)): instrument_blueprint = bluePrintFromInstrumentModule(path, obj) diff --git a/instrumentserver/testing/dummy_instruments/generic.py b/instrumentserver/testing/dummy_instruments/generic.py index 8f8cc54..5147323 100644 --- a/instrumentserver/testing/dummy_instruments/generic.py +++ b/instrumentserver/testing/dummy_instruments/generic.py @@ -92,13 +92,21 @@ def __init__(self, name: str, *args, **kwargs): super().__init__(name, *args, **kwargs) self.random = np.random.randint(10000) + self._param1 = 1 + self._param2 = 2 + + self.add_parameter('random_int', get_cmd=self.get_random) + self.add_parameter('param1', get_cmd=lambda : self._param1, set_cmd=lambda p: setattr(self, '_param1', p)) + self.add_parameter('param2', get_cmd=lambda : self._param2, set_cmd=lambda p: setattr(self, '_param2', p)) + def get_random(self): return self.random - def get_random_timeout(self): - time.sleep(10) - return self.random + def get_random_timeout(self, wait_time=10): + 
time.sleep(wait_time) + return self.get_random() + class DummyInstrumentRandomNumber(Instrument): diff --git a/instrumentserver/testing/test_async_requests/__init__.py b/instrumentserver/testing/test_async_requests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/instrumentserver/testing/test_async_requests/test_client.py b/instrumentserver/testing/test_async_requests/test_client.py new file mode 100644 index 0000000..f11f4a7 --- /dev/null +++ b/instrumentserver/testing/test_async_requests/test_client.py @@ -0,0 +1,39 @@ +from instrumentserver.client import Client + + +''' +A simple test script for the concurrence feature on the server. + +With the server started, run the full code below in one console, +then comment out the `dummy1.get_random_timeout` line, run the code in a new console, the `dummy2.get_random` should +be able to return immediately. +Without concurrence on the server, the `dummy2.get_random` in the new console won't return until the dummy1 in the first +console is done. + + +This mimics the case when on client is ramping bias voltage, while another client wants to change a parameter of +a different instrument. Or more commonly, a client is ramping bias voltage, and we want to view parameter of an instrument +in the server gui (which also is basically another client that runs in a different console.) 
+''' + +if __name__ == "__main__": + cli = Client(timeout=50, port=5555) + import time + t0 = time.time() + dummy1 = cli.find_or_create_instrument('test1', + 'instrumentserver.testing.dummy_instruments.generic.DummyInstrumentTimeout') + dummy2 = cli.find_or_create_instrument('test2', + 'instrumentserver.testing.dummy_instruments.generic.DummyInstrumentTimeout') + + # print(dummy1.get_random_timeout(10)) + print(dummy1.get_random()) + print(dummy2.get_random()) + + + # for i in range(20): + # print(dummy1.get_random()) + # print(dummy2.get_random()) + + print(f"took {time.time() - t0} seconds") + + diff --git a/setup.py b/setup.py index 2b56e62..ef4848a 100644 --- a/setup.py +++ b/setup.py @@ -15,5 +15,13 @@ "instrumentserver-detached = instrumentserver.apps:detachedServerScript", "instrumentserver-param-manager = instrumentserver.apps:parameterManagerScript", "instrumentserver-listener = instrumentserver.monitoring.listener:startListener", - ]} + ]}, + install_requires = [ + 'zmq', + 'qcodes', + 'qtpy', + 'pyqt5', + 'bokeh', + 'scipy' + ] ) From ce4190df2999207015d193c1ca1797bce1ec6033 Mon Sep 17 00:00:00 2001 From: CHAO ZHOU Date: Sun, 9 Nov 2025 16:44:35 -0500 Subject: [PATCH 2/2] implemented per-instrument lock on the server such that each instrument can only be called from one thread at a time. 
- also made a better demo for the usage of threaded server --- instrumentserver/server/core.py | 103 +++++++++++++----- .../test_async_requests/demo_concurrency.py | 65 +++++++++++ .../test_async_requests/test_client.py | 39 ------- 3 files changed, 141 insertions(+), 66 deletions(-) create mode 100644 instrumentserver/testing/test_async_requests/demo_concurrency.py delete mode 100644 instrumentserver/testing/test_async_requests/test_client.py diff --git a/instrumentserver/server/core.py b/instrumentserver/server/core.py index 72d0f7b..1d6b3c9 100644 --- a/instrumentserver/server/core.py +++ b/instrumentserver/server/core.py @@ -169,6 +169,10 @@ def __init__(self, self._wakeup_r, self._wakeup_w = socket.socketpair() self._wakeup_r.setblocking(False) self._wakeup_w.setblocking(False) + + # Per-instrument locks to avoid races when multiple threads talk to the same instrument concurrently + self._instrument_locks: dict[str, threading.RLock] = {} + self._instrument_locks_lock = threading.Lock() def _runInitScript(self): if os.path.exists(self.initScript): @@ -407,40 +411,60 @@ def _createInstrument(self, spec: InstrumentCreationSpec) -> None: args = [] if spec.args is None else spec.args kwargs = dict() if spec.kwargs is None else spec.kwargs - - new_instrument = qc.find_or_create_instrument( - cls, spec.name, *args, **kwargs) - if new_instrument.name not in self.station.components: - self.station.add_component(new_instrument) - - self.instrumentCreated.emit(bluePrintFromInstrumentModule(new_instrument.name, new_instrument), - args, kwargs) + + # lock based on the intended instrument name + lock = self._get_lock_for_target(spec.name) + if lock is None: + # in case name isn't in station yet, just guard creation with the dict lock + lock = self._instrument_locks_lock # coarse but fine for this rare operation + + with lock: + new_instrument = qc.find_or_create_instrument( + cls, spec.name, *args, **kwargs) + + if new_instrument.name not in self.station.components: + 
self.station.add_component(new_instrument) + + self.instrumentCreated.emit(bluePrintFromInstrumentModule(new_instrument.name, new_instrument), + args, kwargs) def _callObject(self, spec: CallSpec) -> Any: """Call some callable found in the station.""" obj = nestedAttributeFromString(self.station, spec.target) args = spec.args if spec.args is not None else [] kwargs = spec.kwargs if spec.kwargs is not None else {} - ret = obj(*args, **kwargs) - - # Check if a new parameter is being created. - self._newOrDeleteParameterDetection(spec, args, kwargs) - - if isinstance(obj, Parameter): - if len(args) > 0: - self.parameterSet.emit(spec.target, args[0]) - - # Broadcast changes in parameter values. - self._broadcastParameterChange(ParameterBroadcastBluePrint(spec.target, 'parameter-update', args[0])) + + def _invoke(): + ret = obj(*args, **kwargs) + + # Check if a new parameter is being created. + self._newOrDeleteParameterDetection(spec, args, kwargs) + + if isinstance(obj, Parameter): + if len(args) > 0: + self.parameterSet.emit(spec.target, args[0]) + + # Broadcast changes in parameter values. + self._broadcastParameterChange(ParameterBroadcastBluePrint(spec.target, 'parameter-update', args[0])) + else: + self.parameterGet.emit(spec.target, ret) + + # Broadcast calls of parameters. + self._broadcastParameterChange(ParameterBroadcastBluePrint(spec.target, 'parameter-call', ret)) else: - self.parameterGet.emit(spec.target, ret) - - # Broadcast calls of parameters. - self._broadcastParameterChange(ParameterBroadcastBluePrint(spec.target, 'parameter-call', ret)) + self.funcCalled.emit(spec.target, args, kwargs, ret) + + return ret + + # Get the appropriate per-instrument lock, if any + lock = self._get_lock_for_target(spec.target) + if lock is None: + # Not an instrument (e.g. 
Station-level call); just invoke + return _invoke() else: - self.funcCalled.emit(spec.target, args, kwargs, ret) - - return ret + # Serialize access to this instrument across threads + with lock: + return _invoke() def _getBluePrint(self, path: str) -> Union[InstrumentModuleBluePrint, ParameterBluePrint, @@ -534,7 +558,32 @@ def _newOrDeleteParameterDetection(self, spec, args, kwargs): pb = ParameterBroadcastBluePrint(name, 'parameter-deletion') self._broadcastParameterChange(pb) - + + def _get_lock_for_target(self, target: str) -> Optional[threading.RLock]: + """ + Given a call target like 'dac1.ch1.offset' or 'awg.ch2.set_sq_wave', + return a per-instrument lock if the root is one of the station components. + Otherwise, return None (no locking needed). + """ + # todo: here we assume each instrument can only be used by one thread at a time, which is generally the safer option. + # There might exist hardware that actually supports independent, concurrent control of different channels, + # in which case we might want to add a tag to the instrument and disable the locking here. + if not target: + return None + + # First token before the first dot: assumed to be instrument name + root = target.split('.')[0] + + # Only lock if this actually corresponds to an instrument in the station + if root not in self.station.components: + return None + + with self._instrument_locks_lock: + lock = self._instrument_locks.get(root) + if lock is None: + lock = threading.RLock() + self._instrument_locks[root] = lock + return lock def startServer(port: int = 5555, allowUserShutdown: bool = False, diff --git a/instrumentserver/testing/test_async_requests/demo_concurrency.py b/instrumentserver/testing/test_async_requests/demo_concurrency.py new file mode 100644 index 0000000..70f161d --- /dev/null +++ b/instrumentserver/testing/test_async_requests/demo_concurrency.py @@ -0,0 +1,65 @@ +from instrumentserver.client import Client +import sys +import time + +''' +Simple concurrency demo.
+ +Usage (server already running): + +Terminal A (long-running call on dummy1): + python demo_concurrency.py ramp + +Terminal B (start while A is still running): + + # Case 1: same instrument -> should block behind ramp + python demo_concurrency.py same + + # Case 2: different instrument -> should return immediately + python demo_concurrency.py other + + + +This mimics the case when one client is ramping bias voltage, while another client wants to change a parameter of +a different instrument. Or more commonly, a client is ramping bias voltage, and we want to view a parameter of an instrument +in the server gui (which is also basically another client that runs in a different thread.) +''' + +if __name__ == "__main__": + role = sys.argv[1] if len(sys.argv) > 1 else "ramp" + print(f"[demo] role = {role}") + + cli = Client(timeout=50, port=5555) + + # We create both instruments regardless of the role; this is cheap anyway + dummy1 = cli.find_or_create_instrument( + "test1", + "instrumentserver.testing.dummy_instruments.generic.DummyInstrumentTimeout", + ) + dummy2 = cli.find_or_create_instrument( + "test2", + "instrumentserver.testing.dummy_instruments.generic.DummyInstrumentTimeout", + ) + + t0 = time.time() + + if role == "ramp": # within a single process, operations are always blocking + print("[ramp] dummy1.get_random_timeout(10)") + print(dummy1.get_random_timeout(10)) + print("[after ramp] dummy2.get_random()") + print(dummy2.get_random()) + + elif role == "same": # from a different process, operations on the same instrument are still blocked + print("[same] dummy1.get_random() (same instrument as ramp)") + print(dummy1.get_random()) + + elif role == "other": # from a different process, operations on a different instrument are NOT blocked + print("[other] dummy2.get_random() (different instrument)") + print(dummy2.get_random()) + + else: + print(f"Unknown role {role!r}.
Use 'ramp', 'same', or 'other'.") + + print(f"[{role}] took {time.time() - t0:.3f} s") + + diff --git a/instrumentserver/testing/test_async_requests/test_client.py b/instrumentserver/testing/test_async_requests/test_client.py deleted file mode 100644 index f11f4a7..0000000 --- a/instrumentserver/testing/test_async_requests/test_client.py +++ /dev/null @@ -1,39 +0,0 @@ -from instrumentserver.client import Client - - -''' -A simple test script for the concurrence feature on the server. - -With the server started, run the full code below in one console, -then comment out the `dummy1.get_random_timeout` line, run the code in a new console, the `dummy2.get_random` should -be able to return immediately. -Without concurrence on the server, the `dummy2.get_random` in the new console won't return until the dummy1 in the first -console is done. - - -This mimics the case when on client is ramping bias voltage, while another client wants to change a parameter of -a different instrument. Or more commonly, a client is ramping bias voltage, and we want to view parameter of an instrument -in the server gui (which also is basically another client that runs in a different console.) -''' - -if __name__ == "__main__": - cli = Client(timeout=50, port=5555) - import time - t0 = time.time() - dummy1 = cli.find_or_create_instrument('test1', - 'instrumentserver.testing.dummy_instruments.generic.DummyInstrumentTimeout') - dummy2 = cli.find_or_create_instrument('test2', - 'instrumentserver.testing.dummy_instruments.generic.DummyInstrumentTimeout') - - # print(dummy1.get_random_timeout(10)) - print(dummy1.get_random()) - print(dummy2.get_random()) - - - # for i in range(20): - # print(dummy1.get_random()) - # print(dummy2.get_random()) - - print(f"took {time.time() - t0} seconds") - -