diff --git a/instrumentserver/base.py b/instrumentserver/base.py
index dbb36ca..5b489b0 100644
--- a/instrumentserver/base.py
+++ b/instrumentserver/base.py
@@ -1,8 +1,10 @@
import zmq
import json
+import logging
from .blueprints import to_dict, deserialize_obj
+logger = logging.getLogger(__name__)
def encode(data):
return json.dumps(to_dict(data))
@@ -12,12 +14,45 @@ def decode(data):
return deserialize_obj(json.loads(data))
-def send(socket, data):
- return socket.send_string(encode(data))
+def send(socket, data, use_string=True):
+ payload = encode(data)
+ if use_string:
+ return socket.send_string(payload)
+ else:
+ return socket.send(payload.encode('utf-8'))
def recv(socket):
- return decode(socket.recv_string())
+ # Try multipart receive first (ROUTER replies)
+ parts = socket.recv_multipart()
+ while socket.getsockopt(zmq.RCVMORE):
+ leftover = socket.recv()
+ logger.warning(f"Additional part found in recv: {leftover}")
+ if len(parts) == 1:
+ data = parts[0]
+ elif len(parts) == 2 and parts[0] == b'': # optional empty delimiter
+ data = parts[1]
+ else:
+ data = parts[-1] # assume last part is the actual message
+ return decode(data)
+
+
+def send_router(socket, identity, message):
+ socket.setsockopt(zmq.SNDTIMEO, 5000)
+ socket.setsockopt(zmq.LINGER, 0)
+ payload = encode(message).encode('utf-8')
+ socket.send_multipart([identity, b'', payload])
+
+
+def recv_router(socket):
+ parts = socket.recv_multipart()
+ if len(parts) == 2:
+ identity, payload = parts
+ elif len(parts) == 3 and parts[1] == b'':
+ identity, payload = parts[0], parts[2]
+ else:
+ raise ValueError(f"Malformed ROUTER message: {parts}")
+ return identity, decode(payload)
def sendBroadcast(socket, name, message):
@@ -30,7 +65,7 @@ def sendBroadcast(socket, name, message):
:param messages: The data to send.
"""
socket.send_string(name, flags=zmq.SNDMORE)
- socket.send_string(encode(message))
+ socket.send(encode(message).encode('utf-8'))
def recvMultipart(socket):
diff --git a/instrumentserver/blueprints.py b/instrumentserver/blueprints.py
index f51e7dd..e8524dd 100644
--- a/instrumentserver/blueprints.py
+++ b/instrumentserver/blueprints.py
@@ -59,9 +59,8 @@
from typing import Union, Optional, List, Dict, Callable, Tuple, Any, get_args, cast
import numpy as np
-import qcodes as qc
from qcodes import (
- Station, Instrument, InstrumentChannel, Parameter, ParameterWithSetpoints, ChannelTuple)
+ Station, Instrument, InstrumentChannel, Parameter, ParameterWithSetpoints)
from qcodes.instrument.base import InstrumentBase
from qcodes.utils.validators import Validator
diff --git a/instrumentserver/client/core.py b/instrumentserver/client/core.py
index 7757596..27f65fb 100644
--- a/instrumentserver/client/core.py
+++ b/instrumentserver/client/core.py
@@ -1,8 +1,9 @@
import logging
import warnings
import zmq
+import uuid
-from instrumentserver import DEFAULT_PORT, QtCore
+from instrumentserver import DEFAULT_PORT
from instrumentserver.base import send, recv
from instrumentserver.server.core import ServerResponse
@@ -10,8 +11,6 @@
logger = logging.getLogger(__name__)
-# TODO: allow for the client to operate as context manager.
-
class BaseClient:
"""Simple client for the StationServer.
@@ -21,12 +20,12 @@ class BaseClient:
:param host: The host address of the server, defaults to localhost.
:param port: The port of the server, defaults to the value of DEFAULT_PORT.
:param connect: If true, the server connects as it is being constructed, defaults to True.
- :param timeout: Amount of time that the client waits for an answer before declaring timeout in ms.
- Defaults to 5000.
+ :param timeout: Amount of time that the client waits for an answer before declaring timeout in seconds.
+ Defaults to 20s.
:param raise_exceptions: If true the client will raise an exception when the server sends one to it, defaults to True.
"""
- def __init__(self, host='localhost', port=DEFAULT_PORT, connect=True, timeout=5000, raise_exceptions=True):
+ def __init__(self, host='localhost', port=DEFAULT_PORT, connect=True, timeout=20, raise_exceptions=True):
self.connected = False
self.context = None
self.socket = None
@@ -34,8 +33,7 @@ def __init__(self, host='localhost', port=DEFAULT_PORT, connect=True, timeout=50
self.port = port
self.addr = f"tcp://{host}:{port}"
self.raise_exceptions = raise_exceptions
- #: Timeout for server replies.
- self.recv_timeout = timeout
+ self.recv_timeout_ms = int(timeout * 1e3)
if connect:
self.connect()
@@ -51,8 +49,9 @@ def __exit__(self, exc_type, exc_val, exc_tb):
def connect(self):
logger.info(f"Connecting to {self.addr}")
self.context = zmq.Context()
- self.socket = self.context.socket(zmq.REQ)
- self.socket.setsockopt(zmq.RCVTIMEO, self.recv_timeout)
+ self.socket = self.context.socket(zmq.DEALER)
+ self.socket.setsockopt(zmq.RCVTIMEO, self.recv_timeout_ms)
+ self.socket.setsockopt(zmq.IDENTITY, uuid.uuid4().hex.encode()) #todo: more meaningful id?
self.socket.connect(self.addr)
self.connected = True
@@ -66,37 +65,54 @@ def ask(self, message):
ret = recv(self.socket)
logger.debug(f"Response received.")
logger.debug(f"Response: {str(ret)}")
-
- if isinstance(ret, ServerResponse):
- err = ret.error
- if err is not None:
- if isinstance(err, str):
- logger.error(err)
- elif isinstance(err, Warning):
- warnings.warn(err)
- elif isinstance(err, Exception):
- if self.raise_exceptions:
- raise err
- else:
- logger.error(f'Server raised the following exception: {err}')
- else:
- if self.raise_exceptions:
- raise TypeError(f'Unknown Error Type: {str(err)}')
- else:
- logger.error(f'Unknown Error Type: {str(err)}')
+ except zmq.error.Again:
+ self._reset_connection()
+ if self.raise_exceptions:
+ raise RuntimeError("Server did not reply before timeout.")
+ else:
+ logger.error("Server did not reply before timeout.")
+ return None
+
+ if isinstance(ret, ServerResponse):
+ err = ret.error
+ if err is not None:
+ self._handle_server_error(err)
return ret.message
- except zmq.error.Again as e:
- # if there is a timeout, close the socket and connect again
- self.socket.close()
+ return ret
+
+ def _reset_connection(self):
+ try:
+ if self.socket is not None:
+ self.socket.close(linger=0)
+ finally:
+ self.connected = False
self.connect()
+
+ def _handle_server_error(self, err):
+ if isinstance(err, str):
+ logger.error(err)
if self.raise_exceptions:
- raise RuntimeError(f'Server did not reply before timeout.')
- else:
- logger.error(f'Server did not reply before timeout.')
-
+ raise RuntimeError(err)
+ elif isinstance(err, Warning):
+ warnings.warn(err)
+ elif isinstance(err, Exception):
+ if self.raise_exceptions:
+ raise err
+ logger.error(f"Server raised exception: {err}")
+ else:
+ msg = f"Unknown error type from server: {err!r}"
+ if self.raise_exceptions:
+ raise TypeError(msg)
+ logger.error(msg)
+
def disconnect(self):
- self.socket.close()
+ if self.socket is not None:
+ try:
+ self.socket.close(linger=0)
+ except Exception:
+ pass
+ self.socket = None
self.connected = False
diff --git a/instrumentserver/client/proxy.py b/instrumentserver/client/proxy.py
index 6a7b0cf..f4e8423 100644
--- a/instrumentserver/client/proxy.py
+++ b/instrumentserver/client/proxy.py
@@ -6,10 +6,13 @@
"""
import inspect
import json
+import yaml
import logging
import os
from types import MethodType
from typing import Any, Union, Optional, Dict, List
+import threading
+from contextlib import contextmanager
import qcodes as qc
import zmq
@@ -60,7 +63,10 @@ def __init__(self, *args,
if remotePath is not None and bluePrint is None:
self.remotePath = remotePath
- self.bp = self._getBluePrintFromServer(self.remotePath)
+ if self.cli is None:
+ self.bp = self._getBluePrintFromServer(self.remotePath)
+ else:
+ self.bp = self.cli.getBluePrint(self.remotePath)
elif bluePrint is not None:
self.bp = bluePrint
self.remotePath = self.bp.path
@@ -167,11 +173,11 @@ def _remoteGet(self):
class ProxyInstrumentModule(ProxyMixin, InstrumentBase):
- """Construct a proxy module using the given blue print. Each proxy
+ """Construct a proxy module using the given blueprint. Each proxy
instantiation represents a virtual module (instrument of submodule of
instrument).
- :param bluePrint: The blue print that the describes the module.
+ :param bluePrint: The blueprint that the describes the module.
:param host: The name of the host where the server lives.
:param port: The port number of the server.
"""
@@ -192,24 +198,54 @@ def __init__(self, name: str, *args,
if cli is None:
self.cli = Client(host=host, port=port)
+ for mn in self.bp.methods.keys():
+ if mn == 'remove_parameter':
+ def remove_parameter(obj, name: str):
+ obj.cli.call(f'{obj.remotePath}.remove_parameter', name)
+ obj.update()
+
+ self.remove_parameter = MethodType(remove_parameter, self)
+
self.parameters.pop('IDN', None) # we will redefine this later
# When a new parameter or method is added to client, qcodes checks if that item exists or not. This is done
# by calling __getattr__ method. The problem is that when that method gets called and cannot find that item it
# creates it, generating an infinite loop. This flag stops that. It should be set to True before doing any change
# to the proxy object and set to False after the change is done.
- self.is_updating = True
- self.update()
self.is_updating = False
+ with self._updating():
+ self.update()
+
+ @contextmanager
+ def _updating(self):
+ old = self.is_updating
+ self.is_updating = True
+ try:
+ yield
+ finally:
+ self.is_updating = old
+
def initKwargsFromBluePrint(self, bp):
return {}
def update(self):
+ self.cli.invalidateBlueprint(self.remotePath)
self.bp = self.cli.getBluePrint(self.remotePath)
self._getProxyParameters()
self._getProxyMethods()
self._getProxySubmodules()
+
+ def set_parameters(self, **param_dict:dict):
+ """
+ Set instrument parameters in batch with a dict, keyed by parameter names.
+
+ """
+ for k, v in param_dict.items():
+ try:
+ self.parameters[k](v)
+ except KeyError:
+ raise KeyError(f"{self.bp.instrument_module_class} instrument does not have parameter '{k}'")
def add_parameter(self, name: str, *arg, **kw):
"""Add a parameter to the proxy instrument.
@@ -255,10 +291,9 @@ def _getProxyParameters(self) -> None:
for pn, p in self.bp.parameters.items():
if pn not in self.parameters:
pbp = self.cli.getBluePrint(f"{self.remotePath}.{pn}")
- self.is_updating = True
- super().add_parameter(pbp.name, ProxyParameter, cli=self.cli, host=self.host,
- port=self.port, bluePrint=pbp, setpoints_instrument=self)
- self.is_updating = False
+ with self._updating():
+ super().add_parameter(pbp.name, ProxyParameter, cli=self.cli, host=self.host,
+ port=self.port, bluePrint=pbp, setpoints_instrument=self)
delKeys = []
for pn in self.parameters.keys():
@@ -275,11 +310,10 @@ def _getProxyMethods(self):
"""
for n, m in self.bp.methods.items():
if not hasattr(self, n):
- self.is_updating = True
- fun = self._makeProxyMethod(m)
- setattr(self, n, MethodType(fun, self))
- self.functions[n] = getattr(self, n)
- self.is_updating = False
+ with self._updating():
+ fun = self._makeProxyMethod(m)
+ setattr(self, n, MethodType(fun, self))
+ self.functions[n] = getattr(self, n)
def _makeProxyMethod(self, bp: MethodBluePrint):
def wrap(*a, **k):
@@ -313,7 +347,7 @@ def wrap(*a, **k):
# make sure the method knows the wrap function.
# TODO: this is not complete!
globs = {'wrap': wrap, 'qcodes': qc}
- exec(new_func_str, globs)
+ _ret = exec(new_func_str, globs)
fun = globs[bp.name]
fun.__doc__ = bp.docstring
return globs[bp.name]
@@ -378,12 +412,20 @@ def __getattr__(self, item):
class Client(BaseClient):
"""Client with common server requests as convenience functions."""
+ def __init__(self, host='localhost', port=DEFAULT_PORT, connect=True, timeout=20, raise_exceptions=True):
+ super().__init__(host, port, connect, timeout, raise_exceptions)
+ self._bp_cache = {}
+ self._bp_cache_lock = threading.Lock()
def list_instruments(self) -> Dict[str, str]:
""" Get the existing instruments on the server.
"""
- msg = ServerInstruction(operation=Operation.get_existing_instruments)
- return self.ask(msg)
+ message = ServerInstruction(operation=Operation.get_existing_instruments)
+ try:
+ return self.ask(message)
+ except Exception as e:
+ logger.error(f"Failed to send or receive message to server at {self.host}:{self.port}", exc_info=True)
+ raise RuntimeError("Communication with server failed. See logs for details.") from e
def find_or_create_instrument(self, name: str, instrument_class: Optional[str] = None,
*args: Any, **kwargs: Any) -> ProxyInstrumentModule:
@@ -439,11 +481,38 @@ def get_instrument(self, name):
return ProxyInstrumentModule(name=name, cli=self, remotePath=name)
def getBluePrint(self, path):
+ """
+ get blueprint from server
+ :param path:
+ :return:
+ """
+ with self._bp_cache_lock:
+ bp = self._bp_cache.get(path)
+ if bp is not None:
+ return bp
+
msg = ServerInstruction(
operation=Operation.get_blueprint,
requested_path=path,
)
- return self.ask(msg)
+ bp = self.ask(msg)
+ with self._bp_cache_lock:
+ self._bp_cache[path] = bp
+ return bp
+
+ def invalidateBlueprint(self, path=None):
+ """
+ invalidate a parameter in the blueprint cache
+ :param path:
+ :return:
+ """
+ with self._bp_cache_lock:
+ if path is None:
+ self._bp_cache.clear()
+ else:
+ for k in list(self._bp_cache):
+ if k == path or k.startswith(path + '.'):
+ del self._bp_cache[k]
def get_snapshot(self, instrument: str | None = None, *args, **kwargs):
msg = ServerInstruction(
@@ -564,7 +633,7 @@ def __init__(self, parent=None,
host='localhost',
port=DEFAULT_PORT,
connect=True,
- timeout=5000,
+ timeout=5,
raise_exceptions=True):
# Calling the parents like this ensures that the arguments arrive to the parents properly.
_QtAdapter.__init__(self, parent=parent)
diff --git a/instrumentserver/server/application.py b/instrumentserver/server/application.py
index f963680..e4ff90f 100644
--- a/instrumentserver/server/application.py
+++ b/instrumentserver/server/application.py
@@ -510,8 +510,8 @@ def __init__(self, startServer: Optional[bool] = True,
self.setWindowTitle('Instrument server')
# A test client, just a simple helper object.
- self.client = EmbeddedClient(raise_exceptions=False, timeout=5000000)
- self.client.recv_timeout = 10_000
+ self.client = EmbeddedClient(raise_exceptions=False, timeout=5000)
+ self.client.recv_timeout_ms = 10_000
# Central widget is simply a tab container.
self.tabs = DetachableTabWidget(self)
@@ -910,7 +910,7 @@ def parameterToHtml(bp: ParameterBluePrint, headerLevel=None):
# FIXME: We deleted the validator since there is no real easy way of deserializing them. It would be a good idea to
# have them here though
#
Validator: {html.escape(str(bp.vals))}
- var = """Doc: {html.escape(str(bp.docstring))}
+ var = f"""Doc: {html.escape(str(bp.docstring))}
"""
diff --git a/instrumentserver/server/core.py b/instrumentserver/server/core.py
index cd19688..1d6b3c9 100644
--- a/instrumentserver/server/core.py
+++ b/instrumentserver/server/core.py
@@ -13,16 +13,22 @@
# operations for adding parameters/submodules/functions
# TODO: can we also create methods remotely?
+# TODO: client white list
+
import os
import importlib
-import inspect
+import json
import logging
import random
-import json
+import queue
+import socket
+
from pathlib import Path
from dataclasses import dataclass, field, fields
from enum import Enum, unique
from typing import Dict, Any, Union, Optional, Tuple, List, Callable
+from concurrent.futures import ThreadPoolExecutor
+import threading
import zmq
@@ -38,7 +44,7 @@
INSTRUMENT_MODULE_BASE_CLASSES, PARAMETER_BASE_CLASSES, Operation,
InstrumentCreationSpec, CallSpec, ParameterSerializeSpec, ServerInstruction, ServerResponse,)
-from ..base import send, recv, sendBroadcast
+from ..base import send_router, recv_router, sendBroadcast
from ..helpers import nestedAttributeFromString, objectClassPath, typeClassPath
__author__ = 'Wolfgang Pfaff', 'Chao Zhou'
@@ -156,6 +162,17 @@ def __init__(self,
f"'{n}', args: {str(args)}, "
f"kwargs: {str(kw)})'.")
)
+
+ # a queue for responses that are ready to be sent to client
+ self._response_queue = queue.Queue()
+ # a socket pair for immediate wakeup of the main thread that sends response to client
+ self._wakeup_r, self._wakeup_w = socket.socketpair()
+ self._wakeup_r.setblocking(False)
+ self._wakeup_w.setblocking(False)
+
+ # Per-instrument locks to avoid races when multiple threads talk to the same instrument concurrently
+ self._instrument_locks: dict[str, threading.RLock] = {}
+ self._instrument_locks_lock = threading.Lock()
def _runInitScript(self):
if os.path.exists(self.initScript):
@@ -173,12 +190,17 @@ def startServer(self) -> bool:
logger.info(f"Starting server.")
logger.info(f"The safe word is: {self.SAFEWORD}")
context = zmq.Context()
- socket = context.socket(zmq.REP)
+ socket = context.socket(zmq.ROUTER)
+ # make a zmq poller for detecting activate sockets
+ poller = zmq.Poller()
+ poller.register(socket, zmq.POLLIN)
+ poller.register(self._wakeup_r, zmq.POLLIN)
for a in self.listenAddresses:
addr = f"tcp://{a}:{self.port}"
socket.bind(addr)
logger.info(f"Listening at {addr}")
+ self.serverStarted.emit(addr)
# creating and binding publishing socket to broadcast changes
broadcastAddr = f"tcp://*:{self.broadcastPort}"
@@ -197,81 +219,132 @@ def startServer(self) -> bool:
if self.initScript not in ['', None]:
logger.info(f"Running init script")
self._runInitScript()
- self.serverStarted.emit(addr)
-
- while self.serverRunning:
- message = recv(socket)
- message_ok = True
- response_to_client: ServerResponse | tuple[ServerResponse, str] | None = None
- response_log = None
-
- # Allow the test client from within the same process to make sure the
- # server shuts down. This is
- if message == self.SAFEWORD:
- response_log = 'Server has received the safeword and will shut down.'
- response_to_client = ServerResponse(message=response_log)
- self.serverRunning = False
- logger.warning(response_log)
-
- elif self.allowUserShutdown and message == 'SHUTDOWN':
- response_log = 'Server shutdown requested by client.'
- response_to_client = ServerResponse(message=response_log)
- self.serverRunning = False
- logger.warning(response_log)
-
- # If the message is a string we just echo it back.
- # This is used for testing sometimes, but has no functionality.
- elif isinstance(message, str):
- response_log = f"Server has received: {message}. No further action."
- response_to_client = ServerResponse(message=response_log)
- logger.debug(response_log)
-
- # We assume this is a valid instruction set now.
- elif isinstance(message, ServerInstruction):
- instruction = message
+
+ # create a thread pool for handling incoming client requests concurrently
+ with ThreadPoolExecutor() as pool:
+ while self.serverRunning or not self._response_queue.empty():
try:
- instruction.validate()
- logger.debug(f"Received request for operation: "
- f"{str(instruction.operation)}")
- logger.debug(f"Instruction received: "
- f"{str(instruction)}")
+ # check if there is either incoming request from client, or a processing worker has finished
+ socks = dict(poller.poll(10))
+
+ # handle router socket events (incoming requests)
+ if self.serverRunning and socket in socks and (socks[socket] & zmq.POLLIN):
+ identity, message = recv_router(socket)
+ pool.submit(self._handleRouterMessage, identity, message)
+
+ # handle wakeup events (one or more workers finished)
+ if self._wakeup_r in socks and (socks[self._wakeup_r] & zmq.POLLIN):
+ # Drain the wakeup pipe so it doesn't stay "always readable"
+ try:
+ # Read whatever is there; content doesn't matter
+ self._wakeup_r.recv(1024)
+ except BlockingIOError:
+ pass
+
+ # drain completed responses from workers
+ while True:
+ try:
+ identity, response_to_client, response_log, shutdown = self._response_queue.get_nowait()
+ except queue.Empty:
+ break
+
+ try:
+ send_router(socket, identity, response_to_client)
+ except Exception as e:
+ logger.error(f"Failed to send response to client: {e}")
+
+ # emit log signal
+ self.messageReceived.emit(str(response_to_client.message), response_log)
+
+ # flip the shutdown flag in the main thread
+ if shutdown:
+ self.serverRunning = False
+
except Exception as e:
- message_ok = False
- response_log = f'Received invalid message. Error raised: {str(e)}'
- response_to_client = ServerResponse(message=None, error=e)
- logger.warning(response_log)
-
- if message_ok:
- # We don't need to use a try-block here, because
- # errors are already handled in executeServerInstruction.
- response_to_client = self.executeServerInstruction(instruction)
- response_log = f"Response to client: {str(response_to_client)}"
- if response_to_client.error is None:
- logger.debug(f"Response sent to client.")
- logger.debug(response_log)
- else:
- logger.warning(response_log)
-
- else:
- response_log = f"Invalid message type."
- response_to_client = ServerResponse(message=None, error=response_log)
- logger.warning(f"Invalid message type: {type(message)}.")
- logger.debug(f"Invalid message received: {str(message)}")
-
- send(socket, response_to_client)
-
- self.messageReceived.emit(str(message), response_log)
-
- if self.pollingThread is not None and isinstance(self.pollingThread,QtCore.QThread):
- self.pollingThread.quit()
- logger.info("Polling thread finished")
+ logger.exception(f"Unexpected error in server loop: {e}")
+ break
- self.broadcastSocket.close()
socket.close()
+ self._wakeup_r.close()
+ self._wakeup_w.close()
+ self.broadcastSocket.close()
self.finished.emit()
+ logger.info("StationServer shut down cleanly.")
return True
+
+ def _handleRouterMessage(self, identity, message):
+ """
+ Handle a router message and put the response message in the response queue.
+
+ """
+ message_ok = True
+ response_to_client = None
+ response_log = None
+ shutdown = False # flag for letting the main thread shut down the server
+
+ # Allow the test client from within the same process to make sure the
+ # server shuts down.
+ if message == self.SAFEWORD:
+ response_log = 'Server has received the safeword and will shut down.'
+ response_to_client = ServerResponse(message=response_log)
+ shutdown = True
+ logger.warning(response_log)
+
+ elif self.allowUserShutdown and message == 'SHUTDOWN':
+ response_log = 'Server shutdown requested by client.'
+ response_to_client = ServerResponse(message=response_log)
+ shutdown = True
+ logger.warning(response_log)
+
+ # If the message is a string we just echo it back.
+ # This is used for testing sometimes, but has no functionality.
+ elif isinstance(message, str):
+ response_log = f"Server has received: {message}. No further action."
+ response_to_client = ServerResponse(message=response_log)
+ logger.debug(response_log)
+
+ # We assume this is a valid instruction set now.
+ elif isinstance(message, ServerInstruction):
+ instruction = message
+ try:
+ instruction.validate()
+ logger.debug(f"Received request for operation: "
+ f"{str(instruction.operation)}")
+ logger.debug(f"Instruction received: "
+ f"{str(instruction)}")
+ except Exception as e:
+ message_ok = False
+ response_log = f'Received invalid message. Error raised: {str(e)}'
+ response_to_client = ServerResponse(message=None, error=e)
+ logger.warning(response_log)
- def executeServerInstruction(self, instruction: ServerInstruction) -> ServerResponse:
+ if message_ok:
+ # We don't need to use a try-block here, because
+ # errors are already handled in executeServerInstruction.
+ response_to_client = self.executeServerInstruction(instruction)
+ response_log = f"Response to client: {str(response_to_client)}"
+ if response_to_client.error is None:
+ logger.debug(f"Response sent to client.")
+ logger.debug(response_log)
+ else:
+ logger.warning(response_log)
+
+ else:
+ response_log = f"Invalid message type."
+ response_to_client = ServerResponse(message=None, error=response_log)
+ logger.warning(f"Invalid message type: {type(message)}.")
+ logger.debug(f"Invalid message received: {str(message)}")
+
+ self._response_queue.put((identity, response_to_client, response_log, shutdown))
+ # wake up the server loop so it can send the response immediately
+ try:
+ self._wakeup_w.send(b"\0")
+ except OSError:
+ # If we're shutting down / socket closed, ignore
+ pass
+
+ def executeServerInstruction(self, instruction: ServerInstruction) \
+ -> Tuple[ServerResponse, str]:
"""
This is the interpreter function that the server will call to translate the
dictionary received from the proxy to instrument calls.
@@ -338,44 +411,65 @@ def _createInstrument(self, spec: InstrumentCreationSpec) -> None:
args = [] if spec.args is None else spec.args
kwargs = dict() if spec.kwargs is None else spec.kwargs
-
- new_instrument = qc.find_or_create_instrument(
- cls, spec.name, *args, **kwargs)
- if new_instrument.name not in self.station.components:
- self.station.add_component(new_instrument)
-
- self.instrumentCreated.emit(bluePrintFromInstrumentModule(new_instrument.name, new_instrument),
- args, kwargs)
+
+ # lock based on the intended instrument name
+ lock = self._get_lock_for_target(spec.name)
+ if lock is None:
+ # in case name isn't in station yet, just guard creation with the dict lock
+ lock = self._instrument_locks_lock # coarse but fine for this rare operation
+
+ with lock:
+ new_instrument = qc.find_or_create_instrument(
+ cls, spec.name, *args, **kwargs)
+
+ if new_instrument.name not in self.station.components:
+ self.station.add_component(new_instrument)
+
+ self.instrumentCreated.emit(bluePrintFromInstrumentModule(new_instrument.name, new_instrument),
+ args, kwargs)
def _callObject(self, spec: CallSpec) -> Any:
"""Call some callable found in the station."""
obj = nestedAttributeFromString(self.station, spec.target)
args = spec.args if spec.args is not None else []
kwargs = spec.kwargs if spec.kwargs is not None else {}
- ret = obj(*args, **kwargs)
-
- # Check if a new parameter is being created.
- self._newOrDeleteParameterDetection(spec, args, kwargs)
-
- if isinstance(obj, Parameter):
- if len(args) > 0:
- self.parameterSet.emit(spec.target, args[0])
-
- # Broadcast changes in parameter values.
- self._broadcastParameterChange(ParameterBroadcastBluePrint(spec.target, 'parameter-update', args[0]))
+
+ def _invoke():
+ ret = obj(*args, **kwargs)
+
+ # Check if a new parameter is being created.
+ self._newOrDeleteParameterDetection(spec, args, kwargs)
+
+ if isinstance(obj, Parameter):
+ if len(args) > 0:
+ self.parameterSet.emit(spec.target, args[0])
+
+ # Broadcast changes in parameter values.
+ self._broadcastParameterChange(ParameterBroadcastBluePrint(spec.target, 'parameter-update', args[0]))
+ else:
+ self.parameterGet.emit(spec.target, ret)
+
+ # Broadcast calls of parameters.
+ self._broadcastParameterChange(ParameterBroadcastBluePrint(spec.target, 'parameter-call', ret))
else:
- self.parameterGet.emit(spec.target, ret)
-
- # Broadcast calls of parameters.
- self._broadcastParameterChange(ParameterBroadcastBluePrint(spec.target, 'parameter-call', ret))
+ self.funcCalled.emit(spec.target, args, kwargs, ret)
+
+ return ret
+
+ # Get the appropriate per-instrument lock, if any
+ lock = self._get_lock_for_target(spec.target)
+ if lock is None:
+ # Not an instrument (e.g. Station-level call); just invoke
+ return _invoke()
else:
- self.funcCalled.emit(spec.target, args, kwargs, ret)
-
- return ret
+ # Serialize access to this instrument across threads
+ with lock:
+ return _invoke()
def _getBluePrint(self, path: str) -> Union[InstrumentModuleBluePrint,
ParameterBluePrint,
MethodBluePrint]:
+ logger.debug(f"Fetching blueprint for: {path}")
obj = nestedAttributeFromString(self.station, path)
if isinstance(obj, tuple(INSTRUMENT_MODULE_BASE_CLASSES)):
instrument_blueprint = bluePrintFromInstrumentModule(path, obj)
@@ -464,7 +558,32 @@ def _newOrDeleteParameterDetection(self, spec, args, kwargs):
pb = ParameterBroadcastBluePrint(name,
'parameter-deletion')
self._broadcastParameterChange(pb)
-
+
+ def _get_lock_for_target(self, target: str) -> Optional[threading.RLock]:
+ """
+ Given a call target like 'dac1.ch1.offset' or 'awg.ch2.set_sq_wave',
+ return a per-instrument lock if the root is one of the station components.
+ Otherwise, return None (no locking needed).
+ """
+ # todo: here we assume each instrument can only be used by one thread at a time, which is generally the safer option.
+ # There might exists hardware that actually supports independent, concurrent control of different channels,
+ # in which case we might want to add a tag to the instrument and disable the locking here.
+ if not target:
+ return None
+
+ # First token before the first dot: assumed to be instrument name
+ root = target.split('.')[0]
+
+ # Only lock if this actually corresponds to an instrument in the station
+ if root not in self.station.components:
+ return None
+
+ with self._instrument_locks_lock:
+ lock = self._instrument_locks.get(root)
+ if lock is None:
+ lock = threading.RLock()
+ self._instrument_locks[root] = lock
+ return lock
def startServer(port: int = 5555,
allowUserShutdown: bool = False,
diff --git a/instrumentserver/testing/dummy_instruments/generic.py b/instrumentserver/testing/dummy_instruments/generic.py
index 8f8cc54..5147323 100644
--- a/instrumentserver/testing/dummy_instruments/generic.py
+++ b/instrumentserver/testing/dummy_instruments/generic.py
@@ -92,13 +92,21 @@ def __init__(self, name: str, *args, **kwargs):
super().__init__(name, *args, **kwargs)
self.random = np.random.randint(10000)
+ self._param1 = 1
+ self._param2 = 2
+
+ self.add_parameter('random_int', get_cmd=self.get_random)
+ self.add_parameter('param1', get_cmd=lambda : self._param1, set_cmd=lambda p: setattr(self, '_param1', p))
+ self.add_parameter('param2', get_cmd=lambda : self._param2, set_cmd=lambda p: setattr(self, '_param2', p))
+
def get_random(self):
return self.random
- def get_random_timeout(self):
- time.sleep(10)
- return self.random
+ def get_random_timeout(self, wait_time=10):
+ time.sleep(wait_time)
+ return self.get_random()
+
class DummyInstrumentRandomNumber(Instrument):
diff --git a/instrumentserver/testing/test_async_requests/__init__.py b/instrumentserver/testing/test_async_requests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/instrumentserver/testing/test_async_requests/demo_concurrency.py b/instrumentserver/testing/test_async_requests/demo_concurrency.py
new file mode 100644
index 0000000..70f161d
--- /dev/null
+++ b/instrumentserver/testing/test_async_requests/demo_concurrency.py
@@ -0,0 +1,65 @@
+from instrumentserver.client import Client
+import sys
+import time
+
+'''
+Simple concurrency demo.
+
+Usage (server already running):
+
+Terminal A (long-running call on dummy1):
+ python demo_concurrency.py ramp
+
+Terminal B (start while A is still running):
+
+ # Case 1: same instrument -> should block behind ramp
+ python demo_concurrency.py same
+
+ # Case 2: different instrument -> should return immediately
+ python demo_concurrency.py other
+
+
+
+This mimics the case when one client is ramping bias voltage, while another client wants to change a parameter of
+a different instrument. Or more commonly, a client is ramping bias voltage, and we want to view parameter of an instrument
+in the server gui (which also is basically another client that runs in a different thread.)
+'''
+
+if __name__ == "__main__":
+ role = sys.argv[1] if len(sys.argv) > 1 else "ramp"
+ print(f"[demo] role = {role}")
+
+ cli = Client(timeout=50, port=5555)
+
+ # We only create what we need for the role, but this is cheap anyway
+ dummy1 = cli.find_or_create_instrument(
+ "test1",
+ "instrumentserver.testing.dummy_instruments.generic.DummyInstrumentTimeout",
+ )
+ dummy2 = cli.find_or_create_instrument(
+ "test2",
+ "instrumentserver.testing.dummy_instruments.generic.DummyInstrumentTimeout",
+ )
+
+ t0 = time.time()
+
+ if role == "ramp": # within a single process, operations are always blocking
+ print("[ramp] dummy1.get_random_timeout(10)")
+ print(dummy1.get_random_timeout(10))
+ print("[after ramp] dummy2.get_random()")
+ print(dummy2.get_random())
+
+ elif role == "same": # from a different process, operations on the same instrument are still blocked
+ print("[same] dummy1.get_random() (same instrument as ramp)")
+ print(dummy1.get_random())
+
+ elif role == "other": # from a different process, operations on a different instrument are NOT blocked
+ print("[other] dummy2.get_random() (different instrument)")
+ print(dummy2.get_random())
+
+ else:
+ print(f"Unknown role {role!r}. Use 'ramp', 'same', or 'other'.")
+
+ print(f"[{role}] took {time.time() - t0:.3f} s")
+
+
diff --git a/setup.py b/setup.py
index 2b56e62..ef4848a 100644
--- a/setup.py
+++ b/setup.py
@@ -15,5 +15,13 @@
"instrumentserver-detached = instrumentserver.apps:detachedServerScript",
"instrumentserver-param-manager = instrumentserver.apps:parameterManagerScript",
"instrumentserver-listener = instrumentserver.monitoring.listener:startListener",
- ]}
+ ]},
+ install_requires = [
+ 'zmq',
+ 'qcodes',
+ 'qtpy',
+ 'pyqt5',
+ 'bokeh',
+ 'scipy'
+ ]
)