diff --git a/test/test_nmt.py b/test/test_nmt.py index 50db1187..c24e8d96 100644 --- a/test/test_nmt.py +++ b/test/test_nmt.py @@ -1,21 +1,139 @@ import time import unittest +import can import canopen +from canopen.nmt import COMMAND_TO_STATE, NMT_STATES, NMT_COMMANDS, NmtError from .util import SAMPLE_EDS +class TestNmtBase(unittest.TestCase): + def setUp(self): + node_id = 2 + self.node_id = node_id + self.nmt = canopen.nmt.NmtBase(node_id) + + def test_send_command(self): + dataset = ( + "OPERATIONAL", + "PRE-OPERATIONAL", + "SLEEP", + "STANDBY", + "STOPPED", + ) + for cmd in dataset: + with self.subTest(cmd=cmd): + code = NMT_COMMANDS[cmd] + self.nmt.send_command(code) + expected = NMT_STATES[COMMAND_TO_STATE[code]] + self.assertEqual(self.nmt.state, expected) + + def test_state_getset(self): + for state in NMT_STATES.values(): + with self.subTest(state=state): + self.nmt.state = state + self.assertEqual(self.nmt.state, state) + + def test_state_set_invalid(self): + with self.assertRaisesRegex(ValueError, "INVALID"): + self.nmt.state = "INVALID" + + +class TestNmtMaster(unittest.TestCase): + NODE_ID = 2 + COB_ID = 0x700 + NODE_ID + PERIOD = 0.01 + TIMEOUT = PERIOD * 2 + + def setUp(self): + bus = can.ThreadSafeBus( + interface="virtual", + channel="test", + receive_own_messages=True, + ) + net = canopen.Network(bus) + net.connect() + with self.assertLogs(): + node = net.add_node(self.NODE_ID, SAMPLE_EDS) + + self.bus = bus + self.net = net + self.node = node + + def tearDown(self): + self.net.disconnect() + + def test_nmt_master_no_heartbeat(self): + with self.assertRaisesRegex(NmtError, "heartbeat"): + self.node.nmt.wait_for_heartbeat(self.TIMEOUT) + with self.assertRaisesRegex(NmtError, "boot-up"): + self.node.nmt.wait_for_bootup(self.TIMEOUT) + + def test_nmt_master_on_heartbeat(self): + # Skip the special INITIALISING case. + for code in [st for st in NMT_STATES if st != 0]: + with self.subTest(code=code): + task = self.net.send_periodic(self.COB_ID, [code], self.PERIOD) + try: + actual = self.node.nmt.wait_for_heartbeat(self.TIMEOUT) + finally: + task.stop() + expected = NMT_STATES[code] + self.assertEqual(actual, expected) + + def test_nmt_master_on_heartbeat_initialising(self): + task = self.net.send_periodic(self.COB_ID, [0], self.PERIOD) + self.addCleanup(task.stop) + self.node.nmt.wait_for_bootup(self.TIMEOUT) + state = self.node.nmt.wait_for_heartbeat(self.TIMEOUT) + self.assertEqual(state, "PRE-OPERATIONAL") + + @unittest.expectedFailure + def test_nmt_master_on_heartbeat_unknown_state(self): + task = self.net.send_periodic(self.COB_ID, [0xcb], self.PERIOD) + self.addCleanup(task.stop) + state = self.node.nmt.wait_for_heartbeat(self.TIMEOUT) + # Expect the high bit to be masked out, and and unknown state string to + # be returned. + self.assertEqual(state, "UNKNOWN STATE '75'") + + def test_nmt_master_add_heartbeat_callback(self): + from threading import Event + event = Event() + state = None + def hook(st): + nonlocal state + state = st + event.set() + self.node.nmt.add_heartbeat_callback(hook) + self.net.send_message(self.COB_ID, bytes([127])) + self.assertTrue(event.wait(self.TIMEOUT)) + self.assertEqual(state, 127) + + def test_nmt_master_node_guarding(self): + self.node.nmt.start_node_guarding(self.PERIOD) + msg = self.bus.recv(self.TIMEOUT) + self.assertIsNotNone(msg) + self.assertEqual(msg.arbitration_id, self.COB_ID) + self.assertEqual(msg.dlc, 0) + + self.node.nmt.stop_node_guarding() + self.assertIsNone(self.bus.recv(self.TIMEOUT)) + + class TestNmtSlave(unittest.TestCase): def setUp(self): self.network1 = canopen.Network() self.network1.connect("test", interface="virtual") - self.remote_node = self.network1.add_node(2, SAMPLE_EDS) + with self.assertLogs(): + self.remote_node = self.network1.add_node(2, SAMPLE_EDS) self.network2 = canopen.Network() self.network2.connect("test", interface="virtual") - self.local_node = self.network2.create_node(2, SAMPLE_EDS) - self.remote_node2 = self.network1.add_node(3, SAMPLE_EDS) - self.local_node2 = self.network2.create_node(3, SAMPLE_EDS) + with self.assertLogs(): + self.local_node = self.network2.create_node(2, SAMPLE_EDS) + self.remote_node2 = self.network1.add_node(3, SAMPLE_EDS) + self.local_node2 = self.network2.create_node(3, SAMPLE_EDS) def tearDown(self): self.network1.disconnect()