Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 122 additions & 4 deletions test/test_nmt.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,139 @@
import time
import unittest

import can
import canopen
from canopen.nmt import COMMAND_TO_STATE, NMT_STATES, NMT_COMMANDS, NmtError
from .util import SAMPLE_EDS


class TestNmtBase(unittest.TestCase):
def setUp(self):
node_id = 2
self.node_id = node_id
self.nmt = canopen.nmt.NmtBase(node_id)

def test_send_command(self):
dataset = (
"OPERATIONAL",
"PRE-OPERATIONAL",
"SLEEP",
"STANDBY",
"STOPPED",
)
for cmd in dataset:
with self.subTest(cmd=cmd):
code = NMT_COMMANDS[cmd]
self.nmt.send_command(code)
expected = NMT_STATES[COMMAND_TO_STATE[code]]
self.assertEqual(self.nmt.state, expected)

def test_state_getset(self):
for state in NMT_STATES.values():
with self.subTest(state=state):
self.nmt.state = state
self.assertEqual(self.nmt.state, state)

def test_state_set_invalid(self):
with self.assertRaisesRegex(ValueError, "INVALID"):
self.nmt.state = "INVALID"


class TestNmtMaster(unittest.TestCase):
NODE_ID = 2
COB_ID = 0x700 + NODE_ID
PERIOD = 0.01
TIMEOUT = PERIOD * 2

def setUp(self):
bus = can.ThreadSafeBus(
interface="virtual",
channel="test",
receive_own_messages=True,
)
net = canopen.Network(bus)
net.connect()
with self.assertLogs():
node = net.add_node(self.NODE_ID, SAMPLE_EDS)

self.bus = bus
self.net = net
self.node = node

def tearDown(self):
self.net.disconnect()

def test_nmt_master_no_heartbeat(self):
with self.assertRaisesRegex(NmtError, "heartbeat"):
self.node.nmt.wait_for_heartbeat(self.TIMEOUT)
with self.assertRaisesRegex(NmtError, "boot-up"):
self.node.nmt.wait_for_bootup(self.TIMEOUT)

def test_nmt_master_on_heartbeat(self):
# Skip the special INITIALISING case.
for code in [st for st in NMT_STATES if st != 0]:
with self.subTest(code=code):
task = self.net.send_periodic(self.COB_ID, [code], self.PERIOD)
try:
actual = self.node.nmt.wait_for_heartbeat(self.TIMEOUT)
finally:
task.stop()
expected = NMT_STATES[code]
self.assertEqual(actual, expected)

def test_nmt_master_on_heartbeat_initialising(self):
task = self.net.send_periodic(self.COB_ID, [0], self.PERIOD)
self.addCleanup(task.stop)
self.node.nmt.wait_for_bootup(self.TIMEOUT)
state = self.node.nmt.wait_for_heartbeat(self.TIMEOUT)
self.assertEqual(state, "PRE-OPERATIONAL")

@unittest.expectedFailure
def test_nmt_master_on_heartbeat_unknown_state(self):
task = self.net.send_periodic(self.COB_ID, [0xcb], self.PERIOD)
self.addCleanup(task.stop)
state = self.node.nmt.wait_for_heartbeat(self.TIMEOUT)
# Expect the high bit to be masked out, and and unknown state string to
# be returned.
self.assertEqual(state, "UNKNOWN STATE '75'")

def test_nmt_master_add_heartbeat_callback(self):
from threading import Event
event = Event()
state = None
def hook(st):
nonlocal state
state = st
event.set()
self.node.nmt.add_heartbeat_callback(hook)
self.net.send_message(self.COB_ID, bytes([127]))
self.assertTrue(event.wait(self.TIMEOUT))
self.assertEqual(state, 127)

def test_nmt_master_node_guarding(self):
self.node.nmt.start_node_guarding(self.PERIOD)
msg = self.bus.recv(self.TIMEOUT)
self.assertIsNotNone(msg)
self.assertEqual(msg.arbitration_id, self.COB_ID)
self.assertEqual(msg.dlc, 0)

self.node.nmt.stop_node_guarding()
self.assertIsNone(self.bus.recv(self.TIMEOUT))


class TestNmtSlave(unittest.TestCase):
def setUp(self):
self.network1 = canopen.Network()
self.network1.connect("test", interface="virtual")
self.remote_node = self.network1.add_node(2, SAMPLE_EDS)
with self.assertLogs():
self.remote_node = self.network1.add_node(2, SAMPLE_EDS)

self.network2 = canopen.Network()
self.network2.connect("test", interface="virtual")
self.local_node = self.network2.create_node(2, SAMPLE_EDS)
self.remote_node2 = self.network1.add_node(3, SAMPLE_EDS)
self.local_node2 = self.network2.create_node(3, SAMPLE_EDS)
with self.assertLogs():
self.local_node = self.network2.create_node(2, SAMPLE_EDS)
self.remote_node2 = self.network1.add_node(3, SAMPLE_EDS)
self.local_node2 = self.network2.create_node(3, SAMPLE_EDS)

def tearDown(self):
self.network1.disconnect()
Expand Down