Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 39 additions & 4 deletions instrumentserver/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import zmq
import json
import logging

from .blueprints import to_dict, deserialize_obj

logger = logging.getLogger(__name__)

def encode(data):
return json.dumps(to_dict(data))
Expand All @@ -12,12 +14,45 @@ def decode(data):
return deserialize_obj(json.loads(data))


def send(socket, data):
return socket.send_string(encode(data))
def send(socket, data, use_string=True):
payload = encode(data)
if use_string:
return socket.send_string(payload)
else:
return socket.send(payload.encode('utf-8'))


def recv(socket):
return decode(socket.recv_string())
# Try multipart receive first (ROUTER replies)
parts = socket.recv_multipart()
while socket.getsockopt(zmq.RCVMORE):
leftover = socket.recv()
logger.warning(f"Additional part found in recv: {leftover}")
if len(parts) == 1:
data = parts[0]
elif len(parts) == 2 and parts[0] == b'': # optional empty delimiter
data = parts[1]
else:
data = parts[-1] # assume last part is the actual message
return decode(data)


def send_router(socket, identity, message):
socket.setsockopt(zmq.SNDTIMEO, 5000)
socket.setsockopt(zmq.LINGER, 0)
payload = encode(message).encode('utf-8')
socket.send_multipart([identity, b'', payload])


def recv_router(socket):
parts = socket.recv_multipart()
if len(parts) == 2:
identity, payload = parts
elif len(parts) == 3 and parts[1] == b'':
identity, payload = parts[0], parts[2]
else:
raise ValueError(f"Malformed ROUTER message: {parts}")
return identity, decode(payload)


def sendBroadcast(socket, name, message):
Expand All @@ -30,7 +65,7 @@ def sendBroadcast(socket, name, message):
:param messages: The data to send.
"""
socket.send_string(name, flags=zmq.SNDMORE)
socket.send_string(encode(message))
socket.send(encode(message).encode('utf-8'))


def recvMultipart(socket):
Expand Down
3 changes: 1 addition & 2 deletions instrumentserver/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,8 @@
from typing import Union, Optional, List, Dict, Callable, Tuple, Any, get_args, cast

import numpy as np
import qcodes as qc
from qcodes import (
Station, Instrument, InstrumentChannel, Parameter, ParameterWithSetpoints, ChannelTuple)
Station, Instrument, InstrumentChannel, Parameter, ParameterWithSetpoints)
from qcodes.instrument.base import InstrumentBase
from qcodes.utils.validators import Validator

Expand Down
88 changes: 52 additions & 36 deletions instrumentserver/client/core.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import logging
import warnings
import zmq
import uuid

from instrumentserver import DEFAULT_PORT, QtCore
from instrumentserver import DEFAULT_PORT
from instrumentserver.base import send, recv
from instrumentserver.server.core import ServerResponse


logger = logging.getLogger(__name__)


# TODO: allow for the client to operate as context manager.


class BaseClient:
"""Simple client for the StationServer.
Expand All @@ -21,21 +20,20 @@ class BaseClient:
:param host: The host address of the server, defaults to localhost.
:param port: The port of the server, defaults to the value of DEFAULT_PORT.
:param connect: If true, the server connects as it is being constructed, defaults to True.
:param timeout: Amount of time that the client waits for an answer before declaring timeout in ms.
Defaults to 5000.
:param timeout: Amount of time that the client waits for an answer before declaring timeout in seconds.
Defaults to 20s.
:param raise_exceptions: If true the client will raise an exception when the server sends one to it, defaults to True.
"""

def __init__(self, host='localhost', port=DEFAULT_PORT, connect=True, timeout=5000, raise_exceptions=True):
def __init__(self, host='localhost', port=DEFAULT_PORT, connect=True, timeout=20, raise_exceptions=True):
self.connected = False
self.context = None
self.socket = None
self.host = host
self.port = port
self.addr = f"tcp://{host}:{port}"
self.raise_exceptions = raise_exceptions
#: Timeout for server replies.
self.recv_timeout = timeout
self.recv_timeout_ms = int(timeout * 1e3)

if connect:
self.connect()
Expand All @@ -51,8 +49,9 @@ def __exit__(self, exc_type, exc_val, exc_tb):
def connect(self):
logger.info(f"Connecting to {self.addr}")
self.context = zmq.Context()
self.socket = self.context.socket(zmq.REQ)
self.socket.setsockopt(zmq.RCVTIMEO, self.recv_timeout)
self.socket = self.context.socket(zmq.DEALER)
self.socket.setsockopt(zmq.RCVTIMEO, self.recv_timeout_ms)
self.socket.setsockopt(zmq.IDENTITY, uuid.uuid4().hex.encode()) #todo: more meaningful id?
self.socket.connect(self.addr)
self.connected = True

Expand All @@ -66,37 +65,54 @@ def ask(self, message):
ret = recv(self.socket)
logger.debug(f"Response received.")
logger.debug(f"Response: {str(ret)}")

if isinstance(ret, ServerResponse):
err = ret.error
if err is not None:
if isinstance(err, str):
logger.error(err)
elif isinstance(err, Warning):
warnings.warn(err)
elif isinstance(err, Exception):
if self.raise_exceptions:
raise err
else:
logger.error(f'Server raised the following exception: {err}')
else:
if self.raise_exceptions:
raise TypeError(f'Unknown Error Type: {str(err)}')
else:
logger.error(f'Unknown Error Type: {str(err)}')
except zmq.error.Again:
self._reset_connection()
if self.raise_exceptions:
raise RuntimeError("Server did not reply before timeout.")
else:
logger.error("Server did not reply before timeout.")
return None

if isinstance(ret, ServerResponse):
err = ret.error
if err is not None:
self._handle_server_error(err)
return ret.message

except zmq.error.Again as e:
# if there is a timeout, close the socket and connect again
self.socket.close()
return ret

def _reset_connection(self):
try:
if self.socket is not None:
self.socket.close(linger=0)
finally:
self.connected = False
self.connect()

def _handle_server_error(self, err):
if isinstance(err, str):
logger.error(err)
if self.raise_exceptions:
raise RuntimeError(f'Server did not reply before timeout.')
else:
logger.error(f'Server did not reply before timeout.')

raise RuntimeError(err)
elif isinstance(err, Warning):
warnings.warn(err)
elif isinstance(err, Exception):
if self.raise_exceptions:
raise err
logger.error(f"Server raised exception: {err}")
else:
msg = f"Unknown error type from server: {err!r}"
if self.raise_exceptions:
raise TypeError(msg)
logger.error(msg)

def disconnect(self):
self.socket.close()
if self.socket is not None:
try:
self.socket.close(linger=0)
except Exception:
pass
self.socket = None
self.connected = False


Expand Down
Loading