diff --git a/.vscode/settings.json b/.vscode/settings.json index 25c04eb..3db9dc6 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -7,7 +7,7 @@ "editor.defaultFormatter": "ms-python.black-formatter", "editor.formatOnSave": true, "editor.codeActionsOnSave": { - "source.organizeImports": "explicit" + "source.organizeImports.isort": "explicit" }, }, "python.testing.pytestArgs": [ diff --git a/envr-default b/envr-default index 6ce8b7e..15617a9 100644 --- a/envr-default +++ b/envr-default @@ -7,5 +7,5 @@ PYTHON_VENV=.venv [ADD_TO_PATH] [ALIASES] -lint=black --check . && isort --check-only . && flake8 . && pydoclint smpclient && mypy . +lint=black --check . && isort --check-only --diff . && flake8 . && pydoclint smpclient && mypy . test=coverage erase && pytest --cov --maxfail=1 \ No newline at end of file diff --git a/smpclient/transport/udp.py b/smpclient/transport/udp.py index 2823e44..c06cb6c 100644 --- a/smpclient/transport/udp.py +++ b/smpclient/transport/udp.py @@ -2,6 +2,7 @@ import asyncio import logging +from socket import AF_INET6 from typing import Final from smp import header as smphdr @@ -13,15 +14,39 @@ logger = logging.getLogger(__name__) +IPV4_HEADER_SIZE: Final = 20 +"""Minimum IPv4 header size in bytes.""" + +IPV6_HEADER_SIZE: Final = 40 +"""IPv6 header size in bytes.""" + +UDP_HEADER_SIZE: Final = 8 +"""UDP header size in bytes.""" + +IPV4_UDP_OVERHEAD: Final = IPV4_HEADER_SIZE + UDP_HEADER_SIZE +"""Total overhead (28 bytes) to subtract from MTU to get maximum UDP payload (MSS) for IPv4. + +Per RFC 8085 section 3.2, applications must subtract IP and UDP header sizes from the +PMTU to avoid fragmentation.""" + +IPV6_UDP_OVERHEAD: Final = IPV6_HEADER_SIZE + UDP_HEADER_SIZE +"""Total overhead (48 bytes) to subtract from MTU to get maximum UDP payload (MSS) for IPv6. + +Per RFC 8085 section 3.2, applications must subtract IP and UDP header sizes from the +PMTU to avoid fragmentation.""" + class SMPUDPTransport(SMPTransport): def __init__(self, mtu: int = 1500) -> None: """Initialize the SMP UDP transport. Args: - mtu: The Maximum Transmission Unit (MTU) in 8-bit bytes. + mtu: The Maximum Transmission Unit (MTU) of the link layer in bytes. + IP and UDP header overhead will be subtracted to calculate the maximum + UDP payload size (MSS) to avoid fragmentation per RFC 8085 section 3.2. """ self._mtu = mtu + self._is_ipv6 = False self._client: Final = UDPClient() @@ -29,6 +54,11 @@ def __init__(self, mtu: int = 1500) -> None: async def connect(self, address: str, timeout_s: float, port: int = 1337) -> None: logger.debug(f"Connecting to {address=} {port=}") await asyncio.wait_for(self._client.connect(Addr(host=address, port=port)), timeout_s) + + if sock := self._client._transport.get_extra_info('socket'): + self._is_ipv6 = sock.family == AF_INET6 + logger.debug(f"Detected {'IPv6' if self._is_ipv6 else 'IPv4'} connection") + logger.info(f"Connected to {address=} {port=}") @override @@ -104,4 +134,10 @@ def mtu(self) -> int: @override @property def max_unencoded_size(self) -> int: - return self._mtu + """Maximum UDP payload size (MSS) to avoid fragmentation. + + Subtracts IPv4/IPv6 and UDP header overhead from MTU per RFC 8085 section 3.2. + The IP version is auto-detected after connection. + """ + overhead = IPV6_UDP_OVERHEAD if self._is_ipv6 else IPV4_UDP_OVERHEAD + return self._mtu - overhead diff --git a/tests/test_smp_udp_transport.py b/tests/test_smp_udp_transport.py index 4b2eba5..3859bca 100644 --- a/tests/test_smp_udp_transport.py +++ b/tests/test_smp_udp_transport.py @@ -9,7 +9,7 @@ from smpclient.exceptions import SMPClientException from smpclient.requests.os_management import EchoWrite from smpclient.transport._udp_client import Addr, UDPClient -from smpclient.transport.udp import SMPUDPTransport +from smpclient.transport.udp import IPV4_UDP_OVERHEAD, IPV6_UDP_OVERHEAD, SMPUDPTransport def test_init() -> None: @@ -27,6 +27,10 @@ async def test_connect(_: MagicMock) -> None: t = SMPUDPTransport() t._client = cast(MagicMock, t._client) # type: ignore + # Mock _transport for IPv4/IPv6 detection + t._client._transport = MagicMock() + t._client._transport.get_extra_info.return_value = None + await t.connect("192.168.0.1", 0.001) t._client.connect.assert_awaited_once_with(Addr(host="192.168.0.1", port=1337)) @@ -110,3 +114,48 @@ async def test_send_and_receive() -> None: await t.send_and_receive(message) send_mock.assert_awaited_once_with(message) receive_mock.assert_awaited_once() + + +def test_max_unencoded_size_ipv4() -> None: + """Test MSS calculation for IPv4 (default).""" + t = SMPUDPTransport(mtu=1500) + # Before connection, defaults to IPv4 + assert t.max_unencoded_size == 1500 - IPV4_UDP_OVERHEAD + assert t.max_unencoded_size == 1472 + + +def test_max_unencoded_size_custom_mtu() -> None: + """Test MSS calculation with custom MTU.""" + t = SMPUDPTransport(mtu=512) + assert t.max_unencoded_size == 512 - IPV4_UDP_OVERHEAD + assert t.max_unencoded_size == 484 + + +@pytest.mark.asyncio +async def test_ipv4_detection_real_socket() -> None: + """Test IPv4 auto-detection with real socket connection.""" + t = SMPUDPTransport(mtu=1500) + + # Create a real UDP connection to localhost IPv4 + await t.connect("127.0.0.1", 1.0) + + assert t._is_ipv6 is False + assert t.max_unencoded_size == 1500 - IPV4_UDP_OVERHEAD + assert t.max_unencoded_size == 1472 + + await t.disconnect() + + +@pytest.mark.asyncio +async def test_ipv6_detection_real_socket() -> None: + """Test IPv6 auto-detection with real socket connection.""" + t = SMPUDPTransport(mtu=1500) + + # Create a real UDP connection to localhost IPv6 + await t.connect("::1", 1.0) + + assert t._is_ipv6 is True + assert t.max_unencoded_size == 1500 - IPV6_UDP_OVERHEAD + assert t.max_unencoded_size == 1452 + + await t.disconnect()