diff --git a/qcodes/instrument/base.py b/qcodes/instrument/base.py index 546750d4054a..d12c713fee22 100644 --- a/qcodes/instrument/base.py +++ b/qcodes/instrument/base.py @@ -73,6 +73,8 @@ class Instrument(Metadatable, DelegateAttributes, NestedAttrAccess): shared_kwargs = () + _all_instruments = {} + def __new__(cls, *args, server_name='', **kwargs): """Figure out whether to create a base instrument or proxy.""" if server_name is None: @@ -207,16 +209,35 @@ def record_instance(cls, instance): """ Record (a weak ref to) an instance in a class's instance list. + Also records the instance in list of *all* instruments, and verifies + that there are no other instruments with the same name. + Args: instance (Union[Instrument, RemoteInstrument]): Note: we *do not* check that instance is actually an instance of ``cls``. This is important, because a ``RemoteInstrument`` should function as an instance of the instrument it proxies. + + Raises: + KeyError: if another instance with the same name is already present """ + wr = weakref.ref(instance) + name = instance.name + + # First insert this instrument in the record of *all* instruments + # making sure its name is unique + existing_wr = cls._all_instruments.get(name) + if existing_wr and existing_wr(): + raise KeyError('Another instrument has the name: {}'.format(name)) + + cls._all_instruments[name] = wr + + # Then add it to the record for this specific subclass, using ``_type`` + # to make sure we're not recording it in a base class instance list if getattr(cls, '_type', None) is not cls: cls._type = cls cls._instances = [] - cls._instances.append(weakref.ref(instance)) + cls._instances.append(wr) @classmethod def instances(cls): @@ -251,6 +272,74 @@ def remove_instance(cls, instance): if wr in cls._instances: cls._instances.remove(wr) + # remove from all_instruments too, but don't depend on the + # name to do it, in case name has changed or been deleted + all_ins = cls._all_instruments + for name, ref in list(all_ins.items()): + if ref is wr: + del all_ins[name] + + @classmethod + def find_instrument(cls, name, instrument_class=None): + """ + Find an existing instrument by name. + + Args: + name (str) + instrument_class (Optional[class]): The type of instrument + you are looking for. + + Returns: + Union[Instrument, RemoteInstrument] + + Raises: + KeyError: if no instrument of that name was found, or if its + reference is invalid (dead). + TypeError: if a specific class was requested but a different + type was found + """ + ins = cls._all_instruments[name]() + + if ins is None: + del cls._all_instruments[name] + raise KeyError('Instrument {} has been removed'.format(name)) + + if instrument_class is not None: + if not isinstance(ins, instrument_class): + raise TypeError( + 'Instrument {} is {} but {} was requested'.format( + name, type(ins), instrument_class)) + + return ins + + @classmethod + def find_component(cls, name_attr, instrument_class=None): + """ + Find a component of an existing instrument by name and attribute. + + Args: + name_attr (str): A string in nested attribute format: + .[.] and so on. + For example, can be a parameter name, + or a method name. + instrument_class (Optional[class]): The type of instrument + you are looking for this component within. + + Returns: + Any: The component requested. + """ + + if '.' in name_attr: + name, attr = name_attr.split('.', 1) + ins = cls.find_instrument(name, instrument_class=instrument_class) + return ins.getattr(attr) + + else: + # allow find_component to return the whole instrument, + # if no attribute was specified, for maximum generality. + return cls.find_instrument(name_attr, + instrument_class=instrument_class) + def add_parameter(self, name, parameter_class=StandardParameter, **kwargs): """ diff --git a/qcodes/instrument/remote.py b/qcodes/instrument/remote.py index c6ef2377df13..d65409845560 100644 --- a/qcodes/instrument/remote.py +++ b/qcodes/instrument/remote.py @@ -62,9 +62,11 @@ def __init__(self, *args, instrument_class=None, server_name='', self._args = args self._kwargs = kwargs - instrument_class.record_instance(self) self.connect() + # must come after connect() because that sets self.name + instrument_class.record_instance(self) + def connect(self): """Create the instrument on the server and replicate its API here.""" @@ -181,6 +183,23 @@ def instances(self): """ return self._instrument_class.instances() + def find_instrument(self, name, instrument_class=None): + """ + Find an existing instrument by name. + + Args: + name (str) + + Returns: + Union[Instrument, RemoteInstrument] + + Raises: + KeyError: if no instrument of that name was found, or if its + reference is invalid (dead). + """ + return self._instrument_class.find_instrument( + name, instrument_class=instrument_class) + def close(self): """Irreversibly close and tear down the server & remote instruments.""" if hasattr(self, '_manager'): diff --git a/qcodes/instrument/server.py b/qcodes/instrument/server.py index 39b3f5a71f1c..2dc90bf65109 100644 --- a/qcodes/instrument/server.py +++ b/qcodes/instrument/server.py @@ -113,6 +113,13 @@ def __init__(self, query_queue, response_queue, shared_kwargs): self.instruments = {} self.next_id = 0 + # Ensure no references of instruments defined in the main process + # are copied to the server process. With the spawn multiprocessing + # method this is not an issue, as the class is reimported in the + # new process, but with fork it can be a problem ironically. + from qcodes.instrument.base import Instrument + Instrument._all_instruments = {} + self.run_event_loop() def handle_new_id(self): diff --git a/qcodes/tests/instrument_mocks.py b/qcodes/tests/instrument_mocks.py index 685b5ed380bc..73b5866b679b 100644 --- a/qcodes/tests/instrument_mocks.py +++ b/qcodes/tests/instrument_mocks.py @@ -74,6 +74,23 @@ def meter_get(self, parameter): elif parameter[:5] == 'echo ': return self.fmt(float(parameter[5:])) + # alias because we need new names when we instantiate an instrument + # locally at the same time as remotely + def gateslocal_set(self, parameter, value): + return self.gates_set(parameter, value) + + def gateslocal_get(self, parameter): + return self.gates_get(parameter) + + def sourcelocal_set(self, parameter, value): + return self.source_set(parameter, value) + + def sourcelocal_get(self, parameter): + return self.source_get(parameter) + + def meterlocal_get(self, parameter): + return self.meter_get(parameter) + class ParamNoDoc: @@ -115,8 +132,8 @@ def add5(self, b): class MockGates(MockInstTester): - def __init__(self, model=None, **kwargs): - super().__init__('gates', model=model, delay=0.001, **kwargs) + def __init__(self, name='gates', model=None, **kwargs): + super().__init__(name, model=model, delay=0.001, **kwargs) for i in range(3): cmdbase = 'c{}'.format(i) @@ -164,8 +181,8 @@ def slow_neg_set(self, val): class MockSource(MockInstTester): - def __init__(self, model=None, **kwargs): - super().__init__('source', model=model, delay=0.001, **kwargs) + def __init__(self, name='source', model=None, **kwargs): + super().__init__(name, model=model, delay=0.001, **kwargs) self.add_parameter('amplitude', get_cmd='ampl?', set_cmd='ampl:{:.4f}', get_parser=float, @@ -175,8 +192,8 @@ def __init__(self, model=None, **kwargs): class MockMeter(MockInstTester): - def __init__(self, model=None, **kwargs): - super().__init__('meter', model=model, delay=0.001, **kwargs) + def __init__(self, name='meter', model=None, **kwargs): + super().__init__(name, model=model, delay=0.001, **kwargs) self.add_parameter('amplitude', get_cmd='ampl?', get_parser=float) self.add_function('echo', call_cmd='echo {:.2f}?', diff --git a/qcodes/tests/test_instrument.py b/qcodes/tests/test_instrument.py index 77069cf986bb..7299c1c9bad8 100644 --- a/qcodes/tests/test_instrument.py +++ b/qcodes/tests/test_instrument.py @@ -16,7 +16,8 @@ from qcodes.process.helpers import kill_processes from .instrument_mocks import (AMockModel, MockInstTester, - MockGates, MockSource, MockMeter, DummyInstrument) + MockGates, MockSource, MockMeter, + DummyInstrument) from .common import strip_qc @@ -94,7 +95,8 @@ def test_unpicklable(self): def test_slow_set(self): # at least for now, need a local instrument to test logging - gatesLocal = MockGates(model=self.model, server_name=None) + gatesLocal = MockGates(model=self.model, server_name=None, + name='gateslocal') for param, logcount in (('chan0slow', 2), ('chan0slow2', 2), ('chan0slow3', 0), ('chan0slow4', 1), ('chan0slow5', 0)): @@ -123,10 +125,10 @@ def test_max_delay_errors(self): # need to talk to the hardware, so these need to be included # from the beginning when the instrument is created on the # server. - GatesBadDelayType(model=self.model) + GatesBadDelayType(model=self.model, name='gatesBDT') with self.assertRaises(ValueError): - GatesBadDelayValue(model=self.model) + GatesBadDelayValue(model=self.model, name='gatesBDV') def check_ts(self, ts_str): now = datetime.now().strftime('%Y-%m-%d %H:%M:%S') @@ -137,11 +139,42 @@ def test_instances(self): for instrument in instruments: for other_instrument in instruments: instances = instrument.instances() + # check that each instrument is in only its own + # instances list + # also test type checking in find_instrument, + # but we need to use find_component so it executes + # on the server if other_instrument is instrument: self.assertIn(instrument, instances) + + name2 = other_instrument.find_component( + instrument.name + '.name', + other_instrument._instrument_class) + self.assertEqual(name2, instrument.name) else: self.assertNotIn(other_instrument, instances) + with self.assertRaises(TypeError): + other_instrument.find_component( + instrument.name + '.name', + other_instrument._instrument_class) + + # check that we can find each instrument from any other + # find_instrument is explicitly mapped in RemoteInstrument + # so this call gets executed in the main process + self.assertEqual( + instrument, + other_instrument.find_instrument(instrument.name)) + + # but find_component is not, so it executes on the server + self.assertEqual( + instrument.name, + other_instrument.find_component(instrument.name + '.name')) + + # check that we can find this instrument from the base class + self.assertEqual(instrument, + Instrument.find_instrument(instrument.name)) + # somehow instances never go away... there are always 3 # extra references to every instrument object, so del doesn't # work. For this reason, instrument tests should take @@ -149,6 +182,20 @@ def test_instances(self): # so we can't test that the list of defined instruments is actually # *only* what we want to see defined. + def test_instance_name_uniqueness(self): + with self.assertRaises(KeyError): + MockGates(model=self.model) + + def test_remove_instance(self): + self.gates.close() + self.assertEqual(self.gates.instances(), []) + with self.assertRaises(KeyError): + Instrument.find_instrument('gates') + + type(self).gates = MockGates(model=self.model) + self.assertEqual(self.gates.instances(), [self.gates]) + self.assertEqual(Instrument.find_instrument('gates'), self.gates) + def test_mock_instrument(self): gates, source, meter = self.gates, self.source, self.meter @@ -268,9 +315,11 @@ def test_mock_instrument_errors(self): gates.ask('ampl?') with self.assertRaises(TypeError): - MockInstrument('', delay='forever') + MockInstrument('mockbaddelay1', delay='forever') with self.assertRaises(TypeError): - MockInstrument('', delay=-1) + # TODO: since this instrument didn't work, it should be OK + # to use the same name again... how do we allow that? + MockInstrument('mockbaddelay2', delay=-1) # TODO: when an error occurs during constructing an instrument, # we don't have the instrument but its server doesn't know to stop. @@ -314,7 +363,8 @@ def test_sweep_steps_edge_case(self): # but we should handle it # at least for now, need a local instrument to check logging source = self.sourceLocal = MockSource(model=self.model, - server_name=None) + server_name=None, + name='sourcelocal') source.add_parameter('amplitude2', get_cmd='ampl?', set_cmd='ampl:{}', get_parser=float, vals=MultiType(Numbers(0, 1), Strings()), @@ -801,16 +851,18 @@ def test_add_function(self): class TestLocalMock(TestCase): - def setUp(self): - self.model = AMockModel() + @classmethod + def setUpClass(cls): + cls.model = AMockModel() - self.gates = MockGates(self.model, server_name=None) - self.source = MockSource(self.model, server_name=None) - self.meter = MockMeter(self.model, server_name=None) + cls.gates = MockGates(model=cls.model, server_name=None) + cls.source = MockSource(model=cls.model, server_name=None) + cls.meter = MockMeter(model=cls.model, server_name=None) - def tearDown(self): - self.model.close() - for instrument in [self.gates, self.source, self.meter]: + @classmethod + def tearDownClass(cls): + cls.model.close() + for instrument in [cls.gates, cls.source, cls.meter]: instrument.close() def test_local(self): @@ -823,6 +875,35 @@ def test_local(self): with self.assertRaises(ValueError): self.gates.ask('knock knock? Oh never mind.') + def test_instances(self): + # copied from the main (server-based) version + # make sure it all works the same here + instruments = [self.gates, self.source, self.meter] + for instrument in instruments: + for other_instrument in instruments: + instances = instrument.instances() + # check that each instrument is in only its own + # instances list + if other_instrument is instrument: + self.assertIn(instrument, instances) + else: + self.assertNotIn(other_instrument, instances) + + # check that we can find each instrument from any other + # use find_component here to test that it rolls over to + # find_instrument if only a name is given + self.assertEqual( + instrument, + other_instrument.find_component(instrument.name)) + + self.assertEqual( + instrument.name, + other_instrument.find_component(instrument.name + '.name')) + + # check that we can find this instrument from the base class + self.assertEqual(instrument, + Instrument.find_instrument(instrument.name)) + class TestModelAttrAccess(TestCase): diff --git a/qcodes/tests/test_loop.py b/qcodes/tests/test_loop.py index 3f5f5e4ae7f6..2ae8ecd0fa55 100644 --- a/qcodes/tests/test_loop.py +++ b/qcodes/tests/test_loop.py @@ -91,7 +91,9 @@ def test_background_and_datamanager(self): def test_local_instrument(self): # a local instrument should work in a foreground loop, but # not in a background loop (should give a RuntimeError) + self.gates.close() # so we don't have two gates with same name gates_local = MockGates(model=self.model, server_name=None) + self.gates = gates_local c1 = gates_local.chan1 loop_local = Loop(c1[1:5:1], 0.001).each(c1) diff --git a/qcodes/tests/test_visa.py b/qcodes/tests/test_visa.py index 6105e2f2d95e..c09e82f5ea8c 100644 --- a/qcodes/tests/test_visa.py +++ b/qcodes/tests/test_visa.py @@ -114,6 +114,8 @@ def test_ask_write_local(self): for arg in self.args3: self.assertIn(arg, e.exception.args) + mv.close() + def test_ask_write_server(self): # same thing as above but Joe is on a server now... mv = MockVisa('Joe') @@ -149,6 +151,8 @@ def test_ask_write_server(self): for arg in self.args3: self.assertIn(repr(arg), e.exception.args[0]) + mv.close() + @patch('qcodes.instrument.visa.visa.ResourceManager') def test_visa_backend(self, rm_mock): address_opened = [None] @@ -168,12 +172,13 @@ def open_resource(self, address): self.assertEqual(rm_mock.call_args, ((),)) self.assertEqual(address_opened[0], None) - MockBackendVisaInstrument('name', server_name=None, address='ASRL2') + MockBackendVisaInstrument('name2', server_name=None, address='ASRL2') self.assertEqual(rm_mock.call_count, 2) self.assertEqual(rm_mock.call_args, ((),)) self.assertEqual(address_opened[0], 'ASRL2') - MockBackendVisaInstrument('name', server_name=None, address='ASRL3@py') + MockBackendVisaInstrument('name3', server_name=None, + address='ASRL3@py') self.assertEqual(rm_mock.call_count, 3) self.assertEqual(rm_mock.call_args, (('@py',),)) self.assertEqual(address_opened[0], 'ASRL3')