Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 64 additions & 27 deletions PyStageLinQ/MessageClasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def read_network_string(self, frame, start_offset):

if data_stop > len(frame):
# Out of bounds
return
raise PyStageLinQError.INVALIDLENGTH

return data_stop, frame[data_start:data_stop].decode(encoding="UTF-16be")

Expand All @@ -63,6 +63,9 @@ def __init__(self):
self.connection_type = None
self.token = StageLinQToken()
self.length = 0
self.min_length = (
self.magic_flag_length + StageLinQToken.TOKENLENGTH + self.network_len_size
)

def encode_frame(self, discovery_data: StageLinQDiscoveryData) -> bytes:
if self.verify_discovery_data(discovery_data) != PyStageLinQError.STAGELINQOK:
Expand Down Expand Up @@ -102,6 +105,9 @@ def get(self):
)

def decode_frame(self, frame):
if len(frame) < self.length:
return PyStageLinQError.INVALIDLENGTH

# Local Constants
token_start = self.magic_flag_stop
token_length = StageLinQToken.TOKENLENGTH
Expand All @@ -118,23 +124,34 @@ def decode_frame(self, frame):
):
return PyStageLinQError.MAGICFLAGNOTFOUND

token_valid = self.token.set_token(
self.token.set_token(
int.from_bytes(frame[token_start:token_stop], byteorder="big")
)

if token_valid != PyStageLinQError.STAGELINQOK:
return token_valid
try:
connection_type_start, self.device_name = self.read_network_string(
frame, device_name_size_start
)

connection_type_start, self.device_name = self.read_network_string(
frame, device_name_size_start
)
sw_name_start, self.connection_type = self.read_network_string(
frame, connection_type_start
)
sw_version_start, self.sw_name = self.read_network_string(frame, sw_name_start)
port_start, self.sw_version = self.read_network_string(frame, sw_version_start)
sw_name_start, self.connection_type = self.read_network_string(
frame, connection_type_start
)

sw_version_start, self.sw_name = self.read_network_string(
frame, sw_name_start
)
port_start, self.sw_version = self.read_network_string(
frame, sw_version_start
)

except PyStageLinQError.INVALIDLENGTH:
return PyStageLinQError.INVALIDLENGTH

port_stop = port_start + self.port_length

if len(frame) < port_stop:
return PyStageLinQError.INVALIDLENGTH

self.Port = int.from_bytes(frame[port_start:port_stop], byteorder="big")
self.length = port_stop
return PyStageLinQError.STAGELINQOK
Expand All @@ -146,11 +163,15 @@ class StageLinQServiceAnnouncement(StageLinQMessage):
Port: int | type(None)

def __init__(self):
self.Token = None
self.Token = StageLinQToken()
self.Service = None
self.Port = None
self.length = None

self.min_length = (
self.magic_flag_length + StageLinQToken.TOKENLENGTH + self.network_len_size
)

def encode_frame(
self, service_announcement_data: StageLinQServiceAnnouncementData
) -> bytes:
Expand All @@ -163,8 +184,9 @@ def encode_frame(
return request_frame

def decode_frame(self, frame):
if len(frame) < 4:
return PyStageLinQError.INVALIDFRAME
if len(frame) < self.min_length:
return PyStageLinQError.INVALIDLENGTH

# Verify frame type
if (
frame[self.magic_flag_start : self.magic_flag_stop]
Expand All @@ -176,10 +198,22 @@ def decode_frame(self, frame):
token_stop = token_start + StageLinQToken.TOKENLENGTH
service_name_start = token_stop

self.Token = frame[token_start:token_stop]
self.Token.set_token(
(0).from_bytes(frame[token_start:token_stop], byteorder="big")
)

try:
port_start, self.Service = self.read_network_string(
frame, service_name_start
)
except PyStageLinQError.INVALIDLENGTH:
return PyStageLinQError.INVALIDLENGTH

port_start, self.Service = self.read_network_string(frame, service_name_start)
port_stop = port_start + self.port_length

if len(frame) < port_stop:
return PyStageLinQError.INVALIDLENGTH

self.Port = int.from_bytes(frame[port_start:port_stop], byteorder="big")
self.length = port_stop
return PyStageLinQError.STAGELINQOK
Expand All @@ -196,8 +230,8 @@ class StageLinQReference(StageLinQMessage):
Reference: int | type(None)

def __init__(self):
self.OwnToken = None
self.DeviceToken = None
self.OwnToken = StageLinQToken()
self.DeviceToken = StageLinQToken()
self.Reference = None
self.length = (
self.magic_flag_length + StageLinQToken.TOKENLENGTH * 2 + self.reference_len
Expand All @@ -209,15 +243,13 @@ def encode_frame(reference_data) -> StageLinQReferenceData:
request_frame += reference_data.OwnToken.get_token().to_bytes(
StageLinQToken.TOKENLENGTH, byteorder="big"
)
request_frame += reference_data.DeviceToken.get_token().to_bytes(
StageLinQToken.TOKENLENGTH, byteorder="big"
)
request_frame += 0x00.to_bytes(StageLinQToken.TOKENLENGTH, byteorder="big")
request_frame += reference_data.Reference.to_bytes(8, byteorder="big")
return request_frame

def decode_frame(self, frame):
if len(frame) < self.magic_flag_length:
return PyStageLinQError.INVALIDFRAME
if len(frame) < self.length:
return PyStageLinQError.INVALIDLENGTH

# Verify frame type
if (
Expand All @@ -233,8 +265,13 @@ def decode_frame(self, frame):
reference_start = device_token_stop
reference_stop = reference_start + self.reference_len

self.OwnToken = frame[own_token_start:own_token_stop]
self.DeviceToken = frame[device_token_start:device_token_stop]
self.OwnToken.set_token(
(0).from_bytes(frame[own_token_start:own_token_stop], byteorder="big")
)

self.DeviceToken.set_token(
(0).from_bytes(frame[device_token_start:device_token_stop], byteorder="big")
)
self.Reference = int.from_bytes(
frame[reference_start:reference_stop], byteorder="big"
)
Expand Down Expand Up @@ -267,7 +304,7 @@ def encode_frame(service_request_data) -> StageLinQServiceRequestService:
def decode_frame(self, frame):
# Verify frame
if len(frame) < self.length:
return PyStageLinQError.INVALIDFRAME
return PyStageLinQError.INVALIDLENGTH
if (
frame[self.magic_flag_start : self.magic_flag_stop]
!= StageLinQMessageIDs.StageLinQServiceRequestData
Expand Down
42 changes: 29 additions & 13 deletions PyStageLinQ/Network.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@
from __future__ import annotations
import asyncio
from . import EngineServices
from .DataClasses import (
StageLinQServiceAnnouncementData,
StageLinQReferenceData,
StageLinQServiceRequestService,
)
from .MessageClasses import *
from . import Token
from typing import Callable
from typing import Callable, Tuple, List, Any


class StageLinQService:
Expand Down Expand Up @@ -40,6 +45,10 @@ def __init__(

self.service_found_callback = service_found_callback

self.remaining_data = None

self.debug = []

def get_services(self) -> list[EngineServices]:
if self.services_available:
return_list = []
Expand Down Expand Up @@ -123,11 +132,15 @@ async def _receive_frames(
raise RuntimeError(
f"Remote socket for IP:{self.Ip} Port:{self.Port} closed!"
)
frames = self.decode_multiframe(response)

if self.remaining_data is not None:
response = b"".join([self.remaining_data, response])
frames, self.remaining_data = self.decode_multiframe(response)
if frames is None:
# Something went wrong during decoding, lets throw away the frame and hope it doesn't happen again
print(f"Error while decoding the frame")
return False
self.last_frame = response
return frames

async def _handle_frames(self, frames: bytes) -> None:
Expand All @@ -142,7 +155,9 @@ async def _handle_frames(self, frames: bytes) -> None:
self._set_device_token(frame)

elif type(frame) is StageLinQReferenceData:
asyncio.create_task(self.send_reference_message())
# Do not send a new frame if
if frame.OwnToken != self.OwnToken.get_token().to_bytes(16, "big"):
asyncio.create_task(self.send_reference_message())

if adding_services:
await self._handle_new_services()
Expand Down Expand Up @@ -185,13 +200,9 @@ async def send_reference_message(self) -> None:
@staticmethod
def decode_multiframe(
frame: bytes,
) -> list[
StageLinQServiceAnnouncementData
| StageLinQReferenceData
| StageLinQServiceRequestService
]:
) -> tuple[list[Any], None | bytes] | None:
subframes = []
while len(frame) > 4:
while len(frame) >= 4:
match (int.from_bytes(frame[0:4], byteorder="big")):
case 0:
data = StageLinQServiceAnnouncement()
Expand All @@ -201,11 +212,16 @@ def decode_multiframe(
data = StageLinQRequestServices()
case _:
# invalid data, return
return
if data.decode_frame(frame) != PyStageLinQError.STAGELINQOK:
return None
return None
decode_status = data.decode_frame(frame)

if decode_status != PyStageLinQError.STAGELINQOK:
if decode_status == PyStageLinQError.INVALIDLENGTH:
return subframes, frame
else:
return None

subframes.append(data.get())
frame = frame[data.get_len() :]

return subframes
return subframes, None
9 changes: 3 additions & 6 deletions PyStageLinQ/Token.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,16 @@ def _get_randomized_bytes(length: int) -> bytes:
def get_token(self) -> int:
return self.token

def set_token(self, token: int) -> PyStageLinQError:
def set_token(self, token: int) -> None:
if type(token) == int:
if self.validate_token(token) == PyStageLinQError.STAGELINQOK:
self.token = token
ret = PyStageLinQError.STAGELINQOK
else:
# Token could not be Validated
ret = PyStageLinQError.INVALIDTOKEN
raise PyStageLinQError.INVALIDTOKEN
else:
# Token is not of type int
ret = PyStageLinQError.INVALIDTOKENTYPE

return ret
raise PyStageLinQError.INVALIDTOKENTYPE

@staticmethod
def validate_token(token: int) -> PyStageLinQError:
Expand Down
2 changes: 1 addition & 1 deletion tests/Main.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def state_map_data_print(data):
def main():
global PrimeGo
PrimeGo = PyStageLinQ.PyStageLinQ(
new_device_found_callback, name="Jaxcie StageLinQ"
new_device_found_callback, name="Jaxcie StageLinQ", ip="255.255.255.255"
)
PrimeGo.start_standalone()

Expand Down
14 changes: 7 additions & 7 deletions tests/unit/unit_Network.py → tests/unit/test_unit_Network.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ class reader:
reader_dummy = reader()
dummy_stagelinq_service.reader = reader_dummy

decode_multiframe_mock = Mock(side_effect=[None])
decode_multiframe_mock = Mock(side_effect=[[None, None]])
monkeypatch.setattr(
dummy_stagelinq_service, "decode_multiframe", decode_multiframe_mock
)
Expand All @@ -313,7 +313,7 @@ class reader:
reader_dummy = reader()
dummy_stagelinq_service.reader = reader_dummy

decode_multiframe_mock = Mock(side_effect=[frames_data])
decode_multiframe_mock = Mock(side_effect=[[frames_data, None]])
monkeypatch.setattr(
dummy_stagelinq_service, "decode_multiframe", decode_multiframe_mock
)
Expand Down Expand Up @@ -586,11 +586,11 @@ class writer_dummy:


def test_decode_multiframe_no_frame(dummy_stagelinq_service):
assert dummy_stagelinq_service.decode_multiframe(bytes()) == []
assert dummy_stagelinq_service.decode_multiframe(bytes()) == ([], None)


def test_decode_multiframe_short_frame(dummy_stagelinq_service):
assert dummy_stagelinq_service.decode_multiframe(b"1234") == []
assert dummy_stagelinq_service.decode_multiframe(b"0000") == None


def test_decode_multiframe_service_announcement(dummy_stagelinq_service, monkeypatch):
Expand All @@ -612,7 +612,7 @@ class service_announcement_dummy:
PyStageLinQ.Network, "StageLinQServiceAnnouncement", service_announcement_dummy
)

assert dummy_stagelinq_service.decode_multiframe(service) == [frame_data]
assert dummy_stagelinq_service.decode_multiframe(service) == ([frame_data], None)

service_announcement_mock.decode_frame.assert_called_once_with(service)
service_announcement_mock.get.assert_called_once()
Expand All @@ -637,7 +637,7 @@ class service_announcement_dummy:
PyStageLinQ.Network, "StageLinQReference", service_announcement_dummy
)

assert dummy_stagelinq_service.decode_multiframe(service) == [frame_data]
assert dummy_stagelinq_service.decode_multiframe(service) == ([frame_data], None)

service_announcement_mock.decode_frame.assert_called_once_with(service)
service_announcement_mock.get.assert_called_once()
Expand All @@ -663,7 +663,7 @@ class service_announcement_dummy:
PyStageLinQ.Network, "StageLinQRequestServices", service_announcement_dummy
)

assert dummy_stagelinq_service.decode_multiframe(service) == [frame_data]
assert dummy_stagelinq_service.decode_multiframe(service) == ([frame_data], None)

service_announcement_mock.decode_frame.assert_called_once_with(service)
service_announcement_mock.get.assert_called_once()
Expand Down
File renamed without changes.
8 changes: 5 additions & 3 deletions tests/unit/unit_Token.py → tests/unit/test_unit_Token.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def mock_random_msb1(length):
def test_set_token_wrong_input_type(token):
# test type
testValue = None
assert token.set_token(testValue) == PyStageLinQError.INVALIDTOKENTYPE
with pytest.raises(Exception):
token.set_token(testValue)


def test_set_token_valid_input(token, monkeypatch):
Expand All @@ -55,7 +56,7 @@ def ret_ok(_):

monkeypatch.setattr(token, "validate_token", ret_ok)
testValue = 0
assert token.set_token(testValue) == PyStageLinQError.STAGELINQOK
token.set_token(testValue)
assert token.get_token() == testValue


Expand All @@ -65,7 +66,8 @@ def ret_nok(_):

testValue = 0
monkeypatch.setattr(token, "validate_token", ret_nok)
assert token.set_token(testValue) == PyStageLinQError.INVALIDTOKEN
with pytest.raises(Exception):
assert token.set_token(testValue)


def test_validate_token_ok(token):
Expand Down
File renamed without changes.
Loading