Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit"
"source.organizeImports.isort": "explicit"
},
},
"python.testing.pytestArgs": [
Expand Down
2 changes: 1 addition & 1 deletion envr-default
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ PYTHON_VENV=.venv
[ADD_TO_PATH]

[ALIASES]
lint=black --check . && isort --check-only . && flake8 . && pydoclint smpclient && mypy .
lint=black --check . && isort --check-only --diff . && flake8 . && pydoclint smpclient && mypy .
test=coverage erase && pytest --cov --maxfail=1
40 changes: 38 additions & 2 deletions smpclient/transport/udp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import logging
from socket import AF_INET6
from typing import Final

from smp import header as smphdr
Expand All @@ -13,22 +14,51 @@

logger = logging.getLogger(__name__)

IPV4_HEADER_SIZE: Final = 20
"""Minimum IPv4 header size in bytes."""

IPV6_HEADER_SIZE: Final = 40
"""IPv6 header size in bytes."""

UDP_HEADER_SIZE: Final = 8
"""UDP header size in bytes."""

IPV4_UDP_OVERHEAD: Final = IPV4_HEADER_SIZE + UDP_HEADER_SIZE
"""Total overhead (28 bytes) to subtract from MTU to get maximum UDP payload (MSS) for IPv4.

Per RFC 8085 section 3.2, applications must subtract IP and UDP header sizes from the
PMTU to avoid fragmentation."""

IPV6_UDP_OVERHEAD: Final = IPV6_HEADER_SIZE + UDP_HEADER_SIZE
"""Total overhead (48 bytes) to subtract from MTU to get maximum UDP payload (MSS) for IPv6.

Per RFC 8085 section 3.2, applications must subtract IP and UDP header sizes from the
PMTU to avoid fragmentation."""


class SMPUDPTransport(SMPTransport):
def __init__(self, mtu: int = 1500) -> None:
"""Initialize the SMP UDP transport.

Args:
mtu: The Maximum Transmission Unit (MTU) in 8-bit bytes.
mtu: The Maximum Transmission Unit (MTU) of the link layer in bytes.
IP and UDP header overhead will be subtracted to calculate the maximum
UDP payload size (MSS) to avoid fragmentation per RFC 8085 section 3.2.
"""
self._mtu = mtu
self._is_ipv6 = False

self._client: Final = UDPClient()

@override
async def connect(self, address: str, timeout_s: float, port: int = 1337) -> None:
logger.debug(f"Connecting to {address=} {port=}")
await asyncio.wait_for(self._client.connect(Addr(host=address, port=port)), timeout_s)

if sock := self._client._transport.get_extra_info('socket'):
self._is_ipv6 = sock.family == AF_INET6
logger.debug(f"Detected {'IPv6' if self._is_ipv6 else 'IPv4'} connection")

logger.info(f"Connected to {address=} {port=}")

@override
Expand Down Expand Up @@ -104,4 +134,10 @@ def mtu(self) -> int:
@override
@property
def max_unencoded_size(self) -> int:
return self._mtu
"""Maximum UDP payload size (MSS) to avoid fragmentation.

Subtracts IPv4/IPv6 and UDP header overhead from MTU per RFC 8085 section 3.2.
The IP version is auto-detected after connection.
"""
overhead = IPV6_UDP_OVERHEAD if self._is_ipv6 else IPV4_UDP_OVERHEAD
return self._mtu - overhead
51 changes: 50 additions & 1 deletion tests/test_smp_udp_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from smpclient.exceptions import SMPClientException
from smpclient.requests.os_management import EchoWrite
from smpclient.transport._udp_client import Addr, UDPClient
from smpclient.transport.udp import SMPUDPTransport
from smpclient.transport.udp import IPV4_UDP_OVERHEAD, IPV6_UDP_OVERHEAD, SMPUDPTransport


def test_init() -> None:
Expand All @@ -27,6 +27,10 @@ async def test_connect(_: MagicMock) -> None:
t = SMPUDPTransport()
t._client = cast(MagicMock, t._client) # type: ignore

# Mock _transport for IPv4/IPv6 detection
t._client._transport = MagicMock()
t._client._transport.get_extra_info.return_value = None

await t.connect("192.168.0.1", 0.001)
t._client.connect.assert_awaited_once_with(Addr(host="192.168.0.1", port=1337))

Expand Down Expand Up @@ -110,3 +114,48 @@ async def test_send_and_receive() -> None:
await t.send_and_receive(message)
send_mock.assert_awaited_once_with(message)
receive_mock.assert_awaited_once()


def test_max_unencoded_size_ipv4() -> None:
"""Test MSS calculation for IPv4 (default)."""
t = SMPUDPTransport(mtu=1500)
# Before connection, defaults to IPv4
assert t.max_unencoded_size == 1500 - IPV4_UDP_OVERHEAD
assert t.max_unencoded_size == 1472


def test_max_unencoded_size_custom_mtu() -> None:
"""Test MSS calculation with custom MTU."""
t = SMPUDPTransport(mtu=512)
assert t.max_unencoded_size == 512 - IPV4_UDP_OVERHEAD
assert t.max_unencoded_size == 484


@pytest.mark.asyncio
async def test_ipv4_detection_real_socket() -> None:
"""Test IPv4 auto-detection with real socket connection."""
t = SMPUDPTransport(mtu=1500)

# Create a real UDP connection to localhost IPv4
await t.connect("127.0.0.1", 1.0)

assert t._is_ipv6 is False
assert t.max_unencoded_size == 1500 - IPV4_UDP_OVERHEAD
assert t.max_unencoded_size == 1472

await t.disconnect()


@pytest.mark.asyncio
async def test_ipv6_detection_real_socket() -> None:
"""Test IPv6 auto-detection with real socket connection."""
t = SMPUDPTransport(mtu=1500)

# Create a real UDP connection to localhost IPv6
await t.connect("::1", 1.0)

assert t._is_ipv6 is True
assert t.max_unencoded_size == 1500 - IPV6_UDP_OVERHEAD
assert t.max_unencoded_size == 1452

await t.disconnect()
Loading