From 73bd7da76919af7b5cd48c0998fdd55e27492a7b Mon Sep 17 00:00:00 2001 From: Jaxc Date: Wed, 29 Nov 2023 13:47:32 +0100 Subject: [PATCH 1/6] Reformat code according to Python Black --- tests/Main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/Main.py b/tests/Main.py index 4778592..6ed9878 100644 --- a/tests/Main.py +++ b/tests/Main.py @@ -45,7 +45,7 @@ def state_map_data_print(data): def main(): global PrimeGo PrimeGo = PyStageLinQ.PyStageLinQ( - new_device_found_callback, name="Jaxcie StageLinQ" + new_device_found_callback, name="Jaxcie StageLinQ", ip="255.255.255.255" ) PrimeGo.start_standalone() From 04e29ad35be31d73e0648707335216268b13337c Mon Sep 17 00:00:00 2001 From: Jaxc Date: Wed, 29 Nov 2023 16:57:12 +0100 Subject: [PATCH 2/6] Fix a bunch of issues discovered in Linux --- PyStageLinQ/MessageClasses.py | 58 +++++++++++++++++++++++++---------- PyStageLinQ/Network.py | 38 +++++++++++++++-------- 2 files changed, 67 insertions(+), 29 deletions(-) diff --git a/PyStageLinQ/MessageClasses.py b/PyStageLinQ/MessageClasses.py index 3025b59..6884e53 100644 --- a/PyStageLinQ/MessageClasses.py +++ b/PyStageLinQ/MessageClasses.py @@ -39,7 +39,7 @@ def read_network_string(self, frame, start_offset): if data_stop > len(frame): # Out of bounds - return + raise PyStageLinQError.INVALIDLENGTH return data_stop, frame[data_start:data_stop].decode(encoding="UTF-16be") @@ -102,6 +102,9 @@ def get(self): ) def decode_frame(self, frame): + if len(frame) < self.length: + return PyStageLinQError.INVALIDLENGTH + # Local Constants token_start = self.magic_flag_stop token_length = StageLinQToken.TOKENLENGTH @@ -125,16 +128,26 @@ def decode_frame(self, frame): if token_valid != PyStageLinQError.STAGELINQOK: return token_valid - connection_type_start, self.device_name = self.read_network_string( - frame, device_name_size_start - ) - sw_name_start, self.connection_type = self.read_network_string( - frame, connection_type_start - ) - sw_version_start, self.sw_name = self.read_network_string(frame, sw_name_start) - port_start, self.sw_version = self.read_network_string(frame, sw_version_start) + try: + connection_type_start, self.device_name = self.read_network_string( + frame, device_name_size_start + ) + + sw_name_start, self.connection_type = self.read_network_string( + frame, connection_type_start + ) + + sw_version_start, self.sw_name = self.read_network_string(frame, sw_name_start) + port_start, self.sw_version = self.read_network_string(frame, sw_version_start) + + except PyStageLinQError.INVALIDLENGTH: + return PyStageLinQError.INVALIDLENGTH port_stop = port_start + self.port_length + + if len(frame) < port_stop: + return PyStageLinQError.INVALIDLENGTH + self.Port = int.from_bytes(frame[port_start:port_stop], byteorder="big") self.length = port_stop return PyStageLinQError.STAGELINQOK @@ -151,6 +164,10 @@ def __init__(self): self.Port = None self.length = None + self.min_length = ( + self.magic_flag_length + StageLinQToken.TOKENLENGTH + self.network_len_size + ) + def encode_frame( self, service_announcement_data: StageLinQServiceAnnouncementData ) -> bytes: @@ -163,8 +180,9 @@ def encode_frame( return request_frame def decode_frame(self, frame): - if len(frame) < 4: - return PyStageLinQError.INVALIDFRAME + if len(frame) < self.min_length: + return PyStageLinQError.INVALIDLENGTH + # Verify frame type if ( frame[self.magic_flag_start : self.magic_flag_stop] @@ -178,8 +196,16 @@ def decode_frame(self, frame): self.Token = frame[token_start:token_stop] - port_start, self.Service = self.read_network_string(frame, service_name_start) + try: + port_start, self.Service = self.read_network_string(frame, service_name_start) + except PyStageLinQError.INVALIDLENGTH: + return PyStageLinQError.INVALIDLENGTH + port_stop = port_start + self.port_length + + if len(frame) < port_stop: + return PyStageLinQError.INVALIDLENGTH + self.Port = int.from_bytes(frame[port_start:port_stop], byteorder="big") self.length = port_stop return PyStageLinQError.STAGELINQOK @@ -209,15 +235,15 @@ def encode_frame(reference_data) -> StageLinQReferenceData: request_frame += reference_data.OwnToken.get_token().to_bytes( StageLinQToken.TOKENLENGTH, byteorder="big" ) - request_frame += reference_data.DeviceToken.get_token().to_bytes( + request_frame += 0x00.to_bytes( StageLinQToken.TOKENLENGTH, byteorder="big" ) request_frame += reference_data.Reference.to_bytes(8, byteorder="big") return request_frame def decode_frame(self, frame): - if len(frame) < self.magic_flag_length: - return PyStageLinQError.INVALIDFRAME + if len(frame) < self.length: + return PyStageLinQError.INVALIDLENGTH # Verify frame type if ( @@ -267,7 +293,7 @@ def encode_frame(service_request_data) -> StageLinQServiceRequestService: def decode_frame(self, frame): # Verify frame if len(frame) < self.length: - return PyStageLinQError.INVALIDFRAME + return PyStageLinQError.INVALIDLENGTH if ( frame[self.magic_flag_start : self.magic_flag_stop] != StageLinQMessageIDs.StageLinQServiceRequestData diff --git a/PyStageLinQ/Network.py b/PyStageLinQ/Network.py index 2c3456b..5593eb2 100644 --- a/PyStageLinQ/Network.py +++ b/PyStageLinQ/Network.py @@ -5,9 +5,10 @@ from __future__ import annotations import asyncio from . import EngineServices +from .DataClasses import StageLinQServiceAnnouncementData, StageLinQReferenceData, StageLinQServiceRequestService from .MessageClasses import * from . import Token -from typing import Callable +from typing import Callable, Tuple, List, Any class StageLinQService: @@ -40,6 +41,10 @@ def __init__( self.service_found_callback = service_found_callback + self.remaining_data = None + + self.debug = [] + def get_services(self) -> list[EngineServices]: if self.services_available: return_list = [] @@ -123,11 +128,15 @@ async def _receive_frames( raise RuntimeError( f"Remote socket for IP:{self.Ip} Port:{self.Port} closed!" ) - frames = self.decode_multiframe(response) + + if self.remaining_data is not None: + response = b''.join([self.remaining_data, response]) + frames, self.remaining_data = self.decode_multiframe(response) if frames is None: # Something went wrong during decoding, lets throw away the frame and hope it doesn't happen again print(f"Error while decoding the frame") return False + self.last_frame = response return frames async def _handle_frames(self, frames: bytes) -> None: @@ -142,7 +151,9 @@ async def _handle_frames(self, frames: bytes) -> None: self._set_device_token(frame) elif type(frame) is StageLinQReferenceData: - asyncio.create_task(self.send_reference_message()) + # Do not send a new frame if + if frame.OwnToken != self.OwnToken.get_token().to_bytes(16, "big"): + asyncio.create_task(self.send_reference_message()) if adding_services: await self._handle_new_services() @@ -185,13 +196,9 @@ async def send_reference_message(self) -> None: @staticmethod def decode_multiframe( frame: bytes, - ) -> list[ - StageLinQServiceAnnouncementData - | StageLinQReferenceData - | StageLinQServiceRequestService - ]: + ) -> tuple[list[Any], None | bytes ] | None: subframes = [] - while len(frame) > 4: + while len(frame) >= 4: match (int.from_bytes(frame[0:4], byteorder="big")): case 0: data = StageLinQServiceAnnouncement() @@ -201,11 +208,16 @@ def decode_multiframe( data = StageLinQRequestServices() case _: # invalid data, return - return - if data.decode_frame(frame) != PyStageLinQError.STAGELINQOK: - return None + return None + decode_status = data.decode_frame(frame) + + if decode_status != PyStageLinQError.STAGELINQOK: + if decode_status == PyStageLinQError.INVALIDLENGTH: + return subframes, frame + else: + return None subframes.append(data.get()) frame = frame[data.get_len() :] - return subframes + return subframes, None From 1c25ba1655573fa16b0c71138cf152b31a4b56a2 Mon Sep 17 00:00:00 2001 From: Jaxc Date: Thu, 30 Nov 2023 14:11:29 +0100 Subject: [PATCH 3/6] Change messages classes to use token class solves #22 --- PyStageLinQ/MessageClasses.py | 13 +++++++------ PyStageLinQ/Token.py | 8 +++----- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/PyStageLinQ/MessageClasses.py b/PyStageLinQ/MessageClasses.py index 6884e53..a2702a2 100644 --- a/PyStageLinQ/MessageClasses.py +++ b/PyStageLinQ/MessageClasses.py @@ -159,7 +159,7 @@ class StageLinQServiceAnnouncement(StageLinQMessage): Port: int | type(None) def __init__(self): - self.Token = None + self.Token = StageLinQToken() self.Service = None self.Port = None self.length = None @@ -194,7 +194,7 @@ def decode_frame(self, frame): token_stop = token_start + StageLinQToken.TOKENLENGTH service_name_start = token_stop - self.Token = frame[token_start:token_stop] + self.Token.set_token((0).from_bytes(frame[token_start:token_stop], byteorder="big")) try: port_start, self.Service = self.read_network_string(frame, service_name_start) @@ -222,8 +222,8 @@ class StageLinQReference(StageLinQMessage): Reference: int | type(None) def __init__(self): - self.OwnToken = None - self.DeviceToken = None + self.OwnToken = StageLinQToken() + self.DeviceToken = StageLinQToken() self.Reference = None self.length = ( self.magic_flag_length + StageLinQToken.TOKENLENGTH * 2 + self.reference_len @@ -259,8 +259,9 @@ def decode_frame(self, frame): reference_start = device_token_stop reference_stop = reference_start + self.reference_len - self.OwnToken = frame[own_token_start:own_token_stop] - self.DeviceToken = frame[device_token_start:device_token_stop] + self.OwnToken.set_token((0).from_bytes(frame[own_token_start:own_token_stop], byteorder="big")) + + self.DeviceToken.set_token((0).from_bytes(frame[device_token_start:device_token_stop], byteorder="big")) self.Reference = int.from_bytes( frame[reference_start:reference_stop], byteorder="big" ) diff --git a/PyStageLinQ/Token.py b/PyStageLinQ/Token.py index fe342ed..44bb9b6 100644 --- a/PyStageLinQ/Token.py +++ b/PyStageLinQ/Token.py @@ -39,19 +39,17 @@ def _get_randomized_bytes(length: int) -> bytes: def get_token(self) -> int: return self.token - def set_token(self, token: int) -> PyStageLinQError: + def set_token(self, token: int) -> None: if type(token) == int: if self.validate_token(token) == PyStageLinQError.STAGELINQOK: self.token = token - ret = PyStageLinQError.STAGELINQOK else: # Token could not be Validated - ret = PyStageLinQError.INVALIDTOKEN + raise PyStageLinQError.INVALIDTOKEN else: # Token is not of type int - ret = PyStageLinQError.INVALIDTOKENTYPE + raise PyStageLinQError.INVALIDTOKENTYPE - return ret @staticmethod def validate_token(token: int) -> PyStageLinQError: From ea5c0b92d52ab4da3b09789ee577c25e5d14880c Mon Sep 17 00:00:00 2001 From: Jaxc Date: Thu, 30 Nov 2023 14:17:26 +0100 Subject: [PATCH 4/6] Fix unittests --- ...Services.py => test_unit_EngineServices.py} | 0 .../{unit_Network.py => test_unit_Network.py} | 14 +++++++------- ...PyStageLinQ.py => test_unit_PyStageLinQ.py} | 0 .../unit/{unit_Token.py => test_unit_Token.py} | 8 +++++--- .../{unit_device.py => test_unit_device.py} | 0 ...y => test_unit_messageClasses_Discovery.py} | 0 ....py => test_unit_messageClasses_Message.py} | 3 ++- ...y => test_unit_messageClasses_Reference.py} | 18 +++++++----------- ...est_unit_messageClasses_RequestServices.py} | 4 ++-- ...unit_messageClasses_ServiceAnnouncement.py} | 12 +++++------- 10 files changed, 28 insertions(+), 31 deletions(-) rename tests/unit/{unit_EngineServices.py => test_unit_EngineServices.py} (100%) rename tests/unit/{unit_Network.py => test_unit_Network.py} (95%) rename tests/unit/{unit_PyStageLinQ.py => test_unit_PyStageLinQ.py} (100%) rename tests/unit/{unit_Token.py => test_unit_Token.py} (87%) rename tests/unit/{unit_device.py => test_unit_device.py} (100%) rename tests/unit/{unit_messageClasses_Discovery.py => test_unit_messageClasses_Discovery.py} (100%) rename tests/unit/{unit_messageClasses_Message.py => test_unit_messageClasses_Message.py} (91%) rename tests/unit/{unit_messageClasses_Reference.py => test_unit_messageClasses_Reference.py} (82%) rename tests/unit/{unit_messageClasses_RequestServices.py => test_unit_messageClasses_RequestServices.py} (92%) rename tests/unit/{unit_messageClasses_ServiceAnnouncement.py => test_unit_messageClasses_ServiceAnnouncement.py} (88%) diff --git a/tests/unit/unit_EngineServices.py b/tests/unit/test_unit_EngineServices.py similarity index 100% rename from tests/unit/unit_EngineServices.py rename to tests/unit/test_unit_EngineServices.py diff --git a/tests/unit/unit_Network.py b/tests/unit/test_unit_Network.py similarity index 95% rename from tests/unit/unit_Network.py rename to tests/unit/test_unit_Network.py index c6336f0..34bedf0 100644 --- a/tests/unit/unit_Network.py +++ b/tests/unit/test_unit_Network.py @@ -293,7 +293,7 @@ class reader: reader_dummy = reader() dummy_stagelinq_service.reader = reader_dummy - decode_multiframe_mock = Mock(side_effect=[None]) + decode_multiframe_mock = Mock(side_effect=[[None, None]]) monkeypatch.setattr( dummy_stagelinq_service, "decode_multiframe", decode_multiframe_mock ) @@ -313,7 +313,7 @@ class reader: reader_dummy = reader() dummy_stagelinq_service.reader = reader_dummy - decode_multiframe_mock = Mock(side_effect=[frames_data]) + decode_multiframe_mock = Mock(side_effect=[[frames_data, None]]) monkeypatch.setattr( dummy_stagelinq_service, "decode_multiframe", decode_multiframe_mock ) @@ -586,11 +586,11 @@ class writer_dummy: def test_decode_multiframe_no_frame(dummy_stagelinq_service): - assert dummy_stagelinq_service.decode_multiframe(bytes()) == [] + assert dummy_stagelinq_service.decode_multiframe(bytes()) == ([], None) def test_decode_multiframe_short_frame(dummy_stagelinq_service): - assert dummy_stagelinq_service.decode_multiframe(b"1234") == [] + assert dummy_stagelinq_service.decode_multiframe(b"0000") == None def test_decode_multiframe_service_announcement(dummy_stagelinq_service, monkeypatch): @@ -612,7 +612,7 @@ class service_announcement_dummy: PyStageLinQ.Network, "StageLinQServiceAnnouncement", service_announcement_dummy ) - assert dummy_stagelinq_service.decode_multiframe(service) == [frame_data] + assert dummy_stagelinq_service.decode_multiframe(service) == ([frame_data], None) service_announcement_mock.decode_frame.assert_called_once_with(service) service_announcement_mock.get.assert_called_once() @@ -637,7 +637,7 @@ class service_announcement_dummy: PyStageLinQ.Network, "StageLinQReference", service_announcement_dummy ) - assert dummy_stagelinq_service.decode_multiframe(service) == [frame_data] + assert dummy_stagelinq_service.decode_multiframe(service) == ([frame_data], None) service_announcement_mock.decode_frame.assert_called_once_with(service) service_announcement_mock.get.assert_called_once() @@ -663,7 +663,7 @@ class service_announcement_dummy: PyStageLinQ.Network, "StageLinQRequestServices", service_announcement_dummy ) - assert dummy_stagelinq_service.decode_multiframe(service) == [frame_data] + assert dummy_stagelinq_service.decode_multiframe(service) == ([frame_data], None) service_announcement_mock.decode_frame.assert_called_once_with(service) service_announcement_mock.get.assert_called_once() diff --git a/tests/unit/unit_PyStageLinQ.py b/tests/unit/test_unit_PyStageLinQ.py similarity index 100% rename from tests/unit/unit_PyStageLinQ.py rename to tests/unit/test_unit_PyStageLinQ.py diff --git a/tests/unit/unit_Token.py b/tests/unit/test_unit_Token.py similarity index 87% rename from tests/unit/unit_Token.py rename to tests/unit/test_unit_Token.py index 2cee217..3933231 100644 --- a/tests/unit/unit_Token.py +++ b/tests/unit/test_unit_Token.py @@ -46,7 +46,8 @@ def mock_random_msb1(length): def test_set_token_wrong_input_type(token): # test type testValue = None - assert token.set_token(testValue) == PyStageLinQError.INVALIDTOKENTYPE + with pytest.raises(Exception): + token.set_token(testValue) def test_set_token_valid_input(token, monkeypatch): @@ -55,7 +56,7 @@ def ret_ok(_): monkeypatch.setattr(token, "validate_token", ret_ok) testValue = 0 - assert token.set_token(testValue) == PyStageLinQError.STAGELINQOK + token.set_token(testValue) assert token.get_token() == testValue @@ -65,7 +66,8 @@ def ret_nok(_): testValue = 0 monkeypatch.setattr(token, "validate_token", ret_nok) - assert token.set_token(testValue) == PyStageLinQError.INVALIDTOKEN + with pytest.raises(Exception): + assert token.set_token(testValue) def test_validate_token_ok(token): diff --git a/tests/unit/unit_device.py b/tests/unit/test_unit_device.py similarity index 100% rename from tests/unit/unit_device.py rename to tests/unit/test_unit_device.py diff --git a/tests/unit/unit_messageClasses_Discovery.py b/tests/unit/test_unit_messageClasses_Discovery.py similarity index 100% rename from tests/unit/unit_messageClasses_Discovery.py rename to tests/unit/test_unit_messageClasses_Discovery.py diff --git a/tests/unit/unit_messageClasses_Message.py b/tests/unit/test_unit_messageClasses_Message.py similarity index 91% rename from tests/unit/unit_messageClasses_Message.py rename to tests/unit/test_unit_messageClasses_Message.py index 1fb8fff..823426f 100644 --- a/tests/unit/unit_messageClasses_Message.py +++ b/tests/unit/test_unit_messageClasses_Message.py @@ -30,7 +30,8 @@ def test_write_network_string(stagelinq_message): def test_read_network_string_incorrect_length(stagelinq_message): test_data = (20).to_bytes(4, byteorder="big") + "hello".encode(encoding="UTF-16be") - assert stagelinq_message.read_network_string(test_data, 0) is None + with pytest.raises(Exception): + stagelinq_message.read_network_string(test_data, 0) def test_read_network_string_valid_input(stagelinq_message): diff --git a/tests/unit/unit_messageClasses_Reference.py b/tests/unit/test_unit_messageClasses_Reference.py similarity index 82% rename from tests/unit/unit_messageClasses_Reference.py rename to tests/unit/test_unit_messageClasses_Reference.py index 5994caf..0dbc94e 100644 --- a/tests/unit/unit_messageClasses_Reference.py +++ b/tests/unit/test_unit_messageClasses_Reference.py @@ -32,8 +32,8 @@ def stagelinq_reference(): def test_init_values(stagelinq_reference): - assert stagelinq_reference.OwnToken is None - assert stagelinq_reference.DeviceToken is None + assert type(stagelinq_reference.OwnToken) is PyStageLinQ.Token.StageLinQToken + assert type(stagelinq_reference.DeviceToken) is PyStageLinQ.Token.StageLinQToken assert stagelinq_reference.Reference is None assert stagelinq_reference.length == 44 assert stagelinq_reference.port_length == 2 @@ -56,20 +56,20 @@ def test_encode_frame(stagelinq_reference, owntoken, devicetoken): == test_output[0:4] ) assert owntoken.get_token().to_bytes(16, byteorder="big") == test_output[4:20] - assert devicetoken.get_token().to_bytes(16, byteorder="big") == test_output[20:36] + assert (0).to_bytes(16, byteorder="big") == test_output[20:36] assert (313).to_bytes(8, byteorder="big") == test_output[36:44] def test_decode_frame_invalid_magic_flag_length(stagelinq_reference): assert ( stagelinq_reference.decode_frame(random.randbytes(3)) - == PyStageLinQError.INVALIDFRAME + == PyStageLinQError.INVALIDLENGTH ) def test_decode_frame_invalid_frame_id(stagelinq_reference): assert ( - stagelinq_reference.decode_frame("airJ".encode()) + stagelinq_reference.decode_frame(("airJ"*20).encode()) == PyStageLinQError.INVALIDFRAME ) @@ -84,12 +84,8 @@ def test_decode_frame_valid_input(stagelinq_reference, owntoken, devicetoken): assert stagelinq_reference.decode_frame(dummy_frame) == PyStageLinQError.STAGELINQOK - assert stagelinq_reference.OwnToken == owntoken.get_token().to_bytes( - 16, byteorder="big" - ) - assert stagelinq_reference.DeviceToken == devicetoken.get_token().to_bytes( - 16, byteorder="big" - ) + assert stagelinq_reference.OwnToken.get_token() == owntoken.get_token() + assert stagelinq_reference.DeviceToken.get_token() == devicetoken.get_token() assert stagelinq_reference.Reference == 313 diff --git a/tests/unit/unit_messageClasses_RequestServices.py b/tests/unit/test_unit_messageClasses_RequestServices.py similarity index 92% rename from tests/unit/unit_messageClasses_RequestServices.py rename to tests/unit/test_unit_messageClasses_RequestServices.py index 6ee31b3..6022195 100644 --- a/tests/unit/unit_messageClasses_RequestServices.py +++ b/tests/unit/test_unit_messageClasses_RequestServices.py @@ -43,10 +43,10 @@ def test_encode_frame(stagelinq_request_services, dummy_token): assert dummy_token.get_token().to_bytes(16, byteorder="big") == test_output[4:20] -def test_decode_frame_invalid_magic_flag_length(stagelinq_request_services): +def test_decode_frame_invalid_length(stagelinq_request_services): assert ( stagelinq_request_services.decode_frame(random.randbytes(3)) - == PyStageLinQError.INVALIDFRAME + == PyStageLinQError.INVALIDLENGTH ) diff --git a/tests/unit/unit_messageClasses_ServiceAnnouncement.py b/tests/unit/test_unit_messageClasses_ServiceAnnouncement.py similarity index 88% rename from tests/unit/unit_messageClasses_ServiceAnnouncement.py rename to tests/unit/test_unit_messageClasses_ServiceAnnouncement.py index 618f040..4dbbd11 100644 --- a/tests/unit/unit_messageClasses_ServiceAnnouncement.py +++ b/tests/unit/test_unit_messageClasses_ServiceAnnouncement.py @@ -25,7 +25,7 @@ def stagelinq_service_announcement(): def test_init_values(stagelinq_service_announcement): - assert stagelinq_service_announcement.Token is None + assert type(stagelinq_service_announcement.Token) is PyStageLinQ.Token.StageLinQToken assert stagelinq_service_announcement.Service is None assert stagelinq_service_announcement.Port is None assert stagelinq_service_announcement.length is None @@ -65,16 +65,16 @@ def write_network_string_mock(string): assert dummy_port.to_bytes(2, byteorder="big") == test_output[31:33] -def test_decode_frame_invalid_magic_flag_length(stagelinq_service_announcement): +def test_decode_frame_invalid_length(stagelinq_service_announcement): assert ( stagelinq_service_announcement.decode_frame(random.randbytes(3)) - == PyStageLinQError.INVALIDFRAME + == PyStageLinQError.INVALIDLENGTH ) def test_decode_frame_invalid_frame_id(stagelinq_service_announcement): assert ( - stagelinq_service_announcement.decode_frame("airJ".encode()) + stagelinq_service_announcement.decode_frame(("airJ"*20).encode()) == PyStageLinQError.INVALIDFRAME ) @@ -104,9 +104,7 @@ def read_network_string_mock(_, start_offset): == PyStageLinQError.STAGELINQOK ) - assert stagelinq_service_announcement.Token == dummy_token.get_token().to_bytes( - 16, byteorder="big" - ) + assert stagelinq_service_announcement.Token.get_token() == dummy_token.get_token() assert stagelinq_service_announcement.Port.to_bytes( 2, byteorder="big" ) == dummy_port.to_bytes(2, byteorder="big") From 48ce737801b0c66c32a78e888f452327329975c2 Mon Sep 17 00:00:00 2001 From: Jaxc Date: Thu, 30 Nov 2023 14:21:54 +0100 Subject: [PATCH 5/6] Fix black formatting --- PyStageLinQ/MessageClasses.py | 28 +++++++++++++------ PyStageLinQ/Network.py | 10 +++++-- PyStageLinQ/Token.py | 1 - .../test_unit_messageClasses_Reference.py | 2 +- ...unit_messageClasses_ServiceAnnouncement.py | 6 ++-- 5 files changed, 31 insertions(+), 16 deletions(-) diff --git a/PyStageLinQ/MessageClasses.py b/PyStageLinQ/MessageClasses.py index a2702a2..3422936 100644 --- a/PyStageLinQ/MessageClasses.py +++ b/PyStageLinQ/MessageClasses.py @@ -137,8 +137,12 @@ def decode_frame(self, frame): frame, connection_type_start ) - sw_version_start, self.sw_name = self.read_network_string(frame, sw_name_start) - port_start, self.sw_version = self.read_network_string(frame, sw_version_start) + sw_version_start, self.sw_name = self.read_network_string( + frame, sw_name_start + ) + port_start, self.sw_version = self.read_network_string( + frame, sw_version_start + ) except PyStageLinQError.INVALIDLENGTH: return PyStageLinQError.INVALIDLENGTH @@ -194,10 +198,14 @@ def decode_frame(self, frame): token_stop = token_start + StageLinQToken.TOKENLENGTH service_name_start = token_stop - self.Token.set_token((0).from_bytes(frame[token_start:token_stop], byteorder="big")) + self.Token.set_token( + (0).from_bytes(frame[token_start:token_stop], byteorder="big") + ) try: - port_start, self.Service = self.read_network_string(frame, service_name_start) + port_start, self.Service = self.read_network_string( + frame, service_name_start + ) except PyStageLinQError.INVALIDLENGTH: return PyStageLinQError.INVALIDLENGTH @@ -235,9 +243,7 @@ def encode_frame(reference_data) -> StageLinQReferenceData: request_frame += reference_data.OwnToken.get_token().to_bytes( StageLinQToken.TOKENLENGTH, byteorder="big" ) - request_frame += 0x00.to_bytes( - StageLinQToken.TOKENLENGTH, byteorder="big" - ) + request_frame += 0x00.to_bytes(StageLinQToken.TOKENLENGTH, byteorder="big") request_frame += reference_data.Reference.to_bytes(8, byteorder="big") return request_frame @@ -259,9 +265,13 @@ def decode_frame(self, frame): reference_start = device_token_stop reference_stop = reference_start + self.reference_len - self.OwnToken.set_token((0).from_bytes(frame[own_token_start:own_token_stop], byteorder="big")) + self.OwnToken.set_token( + (0).from_bytes(frame[own_token_start:own_token_stop], byteorder="big") + ) - self.DeviceToken.set_token((0).from_bytes(frame[device_token_start:device_token_stop], byteorder="big")) + self.DeviceToken.set_token( + (0).from_bytes(frame[device_token_start:device_token_stop], byteorder="big") + ) self.Reference = int.from_bytes( frame[reference_start:reference_stop], byteorder="big" ) diff --git a/PyStageLinQ/Network.py b/PyStageLinQ/Network.py index 5593eb2..b1c3442 100644 --- a/PyStageLinQ/Network.py +++ b/PyStageLinQ/Network.py @@ -5,7 +5,11 @@ from __future__ import annotations import asyncio from . import EngineServices -from .DataClasses import StageLinQServiceAnnouncementData, StageLinQReferenceData, StageLinQServiceRequestService +from .DataClasses import ( + StageLinQServiceAnnouncementData, + StageLinQReferenceData, + StageLinQServiceRequestService, +) from .MessageClasses import * from . import Token from typing import Callable, Tuple, List, Any @@ -130,7 +134,7 @@ async def _receive_frames( ) if self.remaining_data is not None: - response = b''.join([self.remaining_data, response]) + response = b"".join([self.remaining_data, response]) frames, self.remaining_data = self.decode_multiframe(response) if frames is None: # Something went wrong during decoding, lets throw away the frame and hope it doesn't happen again @@ -196,7 +200,7 @@ async def send_reference_message(self) -> None: @staticmethod def decode_multiframe( frame: bytes, - ) -> tuple[list[Any], None | bytes ] | None: + ) -> tuple[list[Any], None | bytes] | None: subframes = [] while len(frame) >= 4: match (int.from_bytes(frame[0:4], byteorder="big")): diff --git a/PyStageLinQ/Token.py b/PyStageLinQ/Token.py index 44bb9b6..b1e117a 100644 --- a/PyStageLinQ/Token.py +++ b/PyStageLinQ/Token.py @@ -50,7 +50,6 @@ def set_token(self, token: int) -> None: # Token is not of type int raise PyStageLinQError.INVALIDTOKENTYPE - @staticmethod def validate_token(token: int) -> PyStageLinQError: # The token is validated by converting it to a 16 byte array and then back to an int. If the value is the same diff --git a/tests/unit/test_unit_messageClasses_Reference.py b/tests/unit/test_unit_messageClasses_Reference.py index 0dbc94e..c24bfe3 100644 --- a/tests/unit/test_unit_messageClasses_Reference.py +++ b/tests/unit/test_unit_messageClasses_Reference.py @@ -69,7 +69,7 @@ def test_decode_frame_invalid_magic_flag_length(stagelinq_reference): def test_decode_frame_invalid_frame_id(stagelinq_reference): assert ( - stagelinq_reference.decode_frame(("airJ"*20).encode()) + stagelinq_reference.decode_frame(("airJ" * 20).encode()) == PyStageLinQError.INVALIDFRAME ) diff --git a/tests/unit/test_unit_messageClasses_ServiceAnnouncement.py b/tests/unit/test_unit_messageClasses_ServiceAnnouncement.py index 4dbbd11..39891fc 100644 --- a/tests/unit/test_unit_messageClasses_ServiceAnnouncement.py +++ b/tests/unit/test_unit_messageClasses_ServiceAnnouncement.py @@ -25,7 +25,9 @@ def stagelinq_service_announcement(): def test_init_values(stagelinq_service_announcement): - assert type(stagelinq_service_announcement.Token) is PyStageLinQ.Token.StageLinQToken + assert ( + type(stagelinq_service_announcement.Token) is PyStageLinQ.Token.StageLinQToken + ) assert stagelinq_service_announcement.Service is None assert stagelinq_service_announcement.Port is None assert stagelinq_service_announcement.length is None @@ -74,7 +76,7 @@ def test_decode_frame_invalid_length(stagelinq_service_announcement): def test_decode_frame_invalid_frame_id(stagelinq_service_announcement): assert ( - stagelinq_service_announcement.decode_frame(("airJ"*20).encode()) + stagelinq_service_announcement.decode_frame(("airJ" * 20).encode()) == PyStageLinQError.INVALIDFRAME ) From 8d0a797b60bdc68722f1790f05bba857c60876d8 Mon Sep 17 00:00:00 2001 From: Jaxc Date: Thu, 30 Nov 2023 15:12:54 +0100 Subject: [PATCH 6/6] Fix length bug in DiscoveryMessageClass --- PyStageLinQ/MessageClasses.py | 8 ++++---- tests/unit/test_unit_messageClasses_Discovery.py | 8 ++------ 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/PyStageLinQ/MessageClasses.py b/PyStageLinQ/MessageClasses.py index 3422936..f1b3614 100644 --- a/PyStageLinQ/MessageClasses.py +++ b/PyStageLinQ/MessageClasses.py @@ -63,6 +63,9 @@ def __init__(self): self.connection_type = None self.token = StageLinQToken() self.length = 0 + self.min_length = ( + self.magic_flag_length + StageLinQToken.TOKENLENGTH + self.network_len_size + ) def encode_frame(self, discovery_data: StageLinQDiscoveryData) -> bytes: if self.verify_discovery_data(discovery_data) != PyStageLinQError.STAGELINQOK: @@ -121,13 +124,10 @@ def decode_frame(self, frame): ): return PyStageLinQError.MAGICFLAGNOTFOUND - token_valid = self.token.set_token( + self.token.set_token( int.from_bytes(frame[token_start:token_stop], byteorder="big") ) - if token_valid != PyStageLinQError.STAGELINQOK: - return token_valid - try: connection_type_start, self.device_name = self.read_network_string( frame, device_name_size_start diff --git a/tests/unit/test_unit_messageClasses_Discovery.py b/tests/unit/test_unit_messageClasses_Discovery.py index 63ae0bb..697b589 100644 --- a/tests/unit/test_unit_messageClasses_Discovery.py +++ b/tests/unit/test_unit_messageClasses_Discovery.py @@ -155,10 +155,8 @@ def set_token_invalid_type(_): monkeypatch.setattr(stagelinq_discovery.token, "set_token", set_token_invalid_type) - assert ( + with pytest.raises(Exception): stagelinq_discovery.decode_frame("airD".encode()) - == PyStageLinQError.INVALIDTOKENTYPE - ) def test_decode_frame_invalid_token_length(stagelinq_discovery, monkeypatch): @@ -167,10 +165,8 @@ def set_token_invalid(_): monkeypatch.setattr(stagelinq_discovery.token, "set_token", set_token_invalid) - assert ( + with pytest.raises(Exception): stagelinq_discovery.decode_frame("airD".encode()) - == PyStageLinQError.INVALIDTOKEN - ) def test_decode_frame_valid_input(stagelinq_discovery, monkeypatch, dummy_port):