Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 85 additions & 8 deletions smpclient/mcuboot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
from enum import IntEnum, IntFlag, unique
from functools import cached_property
from io import BufferedReader, BytesIO
from typing import Dict, Final, List
from typing import Annotated, Any, Dict, Final, List, Union

from intelhex import hex2bin # type: ignore
from pydantic import Field, GetCoreSchemaHandler
from pydantic.dataclasses import dataclass
from pydantic_core import CoreSchema, core_schema

IMAGE_MAGIC: Final = 0x96F3B83D
IMAGE_HEADER_SIZE: Final = 32
Expand Down Expand Up @@ -60,12 +62,21 @@ class IMAGE_F(IntFlag):

@unique
class IMAGE_TLV(IntEnum):
"""Image trailer TLV types."""
"""Image trailer TLV types.

Specification: https://docs.mcuboot.com/design.html#image-format
"""

KEYHASH = 0x01
"""hash of the public key"""
"""Hash of the public key"""
PUBKEY = 0x02
"""Public key"""
SHA256 = 0x10
"""SHA256 of image hdr and body"""
SHA384 = 0x11
"""SHA384 of image hdr and body"""
SHA512 = 0x12
"""SHA512 of image hdr and body"""
RSA2048_PSS = 0x20
"""RSA2048 of hash output"""
ECDSA224 = 0x21
Expand All @@ -76,6 +87,8 @@ class IMAGE_TLV(IntEnum):
"""RSA3072 of hash output"""
ED25519 = 0x24
"""ED25519 of hash output"""
SIG_PURE = 0x25
"""Signature prepared over full image rather than digest"""
ENC_RSA2048 = 0x30
"""Key encrypted with RSA-OAEP-2048"""
ENC_KW = 0x31
Expand All @@ -84,10 +97,69 @@ class IMAGE_TLV(IntEnum):
"""Key encrypted with ECIES-P256"""
ENC_X25519 = 0x33
"""Key encrypted with ECIES-X25519"""
ENC_X25519_SHA512 = 0x34
"""Key exchange using X25519 with SHA512 MAC"""
DEPENDENCY = 0x40
"""Image depends on other image"""
SEC_CNT = 0x50
"""security counter"""
"""Security counter"""
BOOT_RECORD = 0x60
"""Measured boot record"""
DECOMP_SIZE = 0x70
"""Decompressed image size excluding header/TLVs"""
DECOMP_SHA = 0x71
"""Decompressed image hash matching format of compressed slot"""
DECOMP_SIGNATURE = 0x72
"""Decompressed image signature matching compressed format"""
COMP_DEC_SIZE = 0x73
"""Compressed decrypted image size"""
UUID_VID = 0x80
"""Vendor unique identifier"""
UUID_CID = 0x81
"""Device class unique identifier"""


class VendorTLV(int):
"""Vendor-defined TLV type in reserved ranges (0xXXA0-0xXXFE).

Vendor reserved TLVs occupy ranges from 0xXXA0 to 0xXXFE, where XX
represents any upper byte value. Examples include ranges 0x00A0-0x00FF,
0x01A0-0x01FF, and 0x02A0-0x02FF, continuing through 0xFFA0-0xFFFE.
"""

def __new__(cls, value: int) -> 'VendorTLV':
"""Create a new VendorTLV, validating the range."""
lower_byte = value & 0xFF
if not (0xA0 <= lower_byte <= 0xFE):
raise ValueError(
f"VendorTLV 0x{value:02x} must have lower byte in range 0xA0-0xFE, "
f"got 0x{lower_byte:02x}"
)
return int.__new__(cls, value)

@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> CoreSchema:
def validate(value: int) -> VendorTLV:
return cls(value)

return core_schema.no_info_after_validator_function(
validate,
core_schema.int_schema(),
)


ImageTLVType = Annotated[Union[IMAGE_TLV, VendorTLV, int], Field(union_mode="left_to_right")]
"""TLV type that accepts standard IMAGE_TLV enums, vendor-defined TLVs, or any integer.

This uses Pydantic's "left to right" union mode to:
1. First try to match against IMAGE_TLV enum values
2. Then try to validate as a VendorTLV (0xXXA0-0xXXFE ranges)
3. Finally accept any integer as a fallback

This ensures backward compatibility and supports future TLV types without validation errors.
"""


@dataclass(frozen=True)
Expand Down Expand Up @@ -189,7 +261,7 @@ def load_from(file: BytesIO | BufferedReader) -> 'ImageTLVInfo':
class ImageTLV:
"""A TLV header - type and length."""

type: IMAGE_TLV
type: ImageTLVType
len: int
"""Data length (not including TLV header)."""

Expand All @@ -209,7 +281,12 @@ def __post_init__(self) -> None:
raise MCUBootImageError(f"TLV requires length {self.header.len}, got {len(self.value)}")

def __str__(self) -> str:
return f"{self.header.type.name}={self.value.hex()}"
type_name = (
self.header.type.name
if isinstance(self.header.type, IMAGE_TLV)
else f"0x{self.header.type:02x}"
)
return f"{type_name}={self.value.hex()}"


@dataclass(frozen=True)
Expand All @@ -221,7 +298,7 @@ class ImageInfo:
tlvs: List[ImageTLVValue]
file: str | None = None

def get_tlv(self, tlv: IMAGE_TLV) -> ImageTLVValue:
def get_tlv(self, tlv: ImageTLVType) -> ImageTLVValue:
"""Get a TLV from the image or raise `TLVNotFound`."""
if tlv in self._map_tlv_type_to_value:
return self._map_tlv_type_to_value[tlv]
Expand Down Expand Up @@ -263,7 +340,7 @@ def load_file(path: str) -> 'ImageInfo':
return ImageInfo(file=path, header=image_header, tlv_info=tlv_info, tlvs=tlvs)

@cached_property
def _map_tlv_type_to_value(self) -> Dict[IMAGE_TLV, ImageTLVValue]:
def _map_tlv_type_to_value(self) -> Dict[int, ImageTLVValue]:
return {tlv.header.type: tlv for tlv in self.tlvs}

def __str__(self) -> str:
Expand Down
119 changes: 119 additions & 0 deletions tests/test_mcuboot_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
IMAGE_TLV_INFO_MAGIC,
ImageHeader,
ImageInfo,
ImageTLV,
ImageTLVType,
ImageTLVValue,
ImageVersion,
VendorTLV,
)


Expand Down Expand Up @@ -113,3 +117,118 @@ def test_ImageVersion() -> None:
assert v.revision == 0xFFFF
assert v.build_num == 0xFFFFFFFF
assert str(v) == "1.255.65535-build4294967295"


def test_pubkey_tlv_exists() -> None:
"""Test that PUBKEY (0x02) TLV type exists.

https://github.com/intercreate/smpclient/issues/83
"""
assert IMAGE_TLV.PUBKEY == 0x02
assert IMAGE_TLV.PUBKEY.name == "PUBKEY"


def test_standard_tlv_coercion() -> None:
"""Test that standard TLV values are coerced to IMAGE_TLV enum."""
# PUBKEY (the bug fix!)
tlv = ImageTLV(type=0x02, len=256)
assert isinstance(tlv.type, IMAGE_TLV)
assert tlv.type == IMAGE_TLV.PUBKEY
assert tlv.type.name == "PUBKEY"

# SHA256
tlv = ImageTLV(type=0x10, len=32)
assert isinstance(tlv.type, IMAGE_TLV)
assert tlv.type == IMAGE_TLV.SHA256

# SHA384
tlv = ImageTLV(type=0x11, len=48)
assert isinstance(tlv.type, IMAGE_TLV)
assert tlv.type == IMAGE_TLV.SHA384


def test_vendor_tlv_validation() -> None:
"""Test that vendor TLV ranges are validated correctly."""
# Lower byte 0xA0-0xFE should be valid vendor TLVs
tlv = ImageTLV(type=0xA0, len=16)
assert isinstance(tlv.type, VendorTLV)
assert tlv.type == 0xA0

tlv = ImageTLV(type=0xFE, len=8)
assert isinstance(tlv.type, VendorTLV)
assert tlv.type == 0xFE

# Multi-byte vendor TLVs
tlv = ImageTLV(type=0x01A0, len=16)
assert isinstance(tlv.type, VendorTLV)
assert tlv.type == 0x01A0

tlv = ImageTLV(type=0xFFFE, len=4)
assert isinstance(tlv.type, VendorTLV)
assert tlv.type == 0xFFFE


def test_unknown_tlv_fallback() -> None:
"""Test that unknown TLV types fall back to int without error."""
# This should not raise a validation error
tlv = ImageTLV(type=0x99, len=8)
assert isinstance(tlv.type, int)
assert tlv.type == 0x99

# Another unknown type
tlv = ImageTLV(type=0x05, len=4)
assert isinstance(tlv.type, int)
assert tlv.type == 0x05


def test_tlv_type_union_order() -> None:
"""Test that union resolution follows left-to-right order."""
from pydantic import TypeAdapter

adapter: TypeAdapter[ImageTLVType] = TypeAdapter(ImageTLVType)

# Standard TLV should match IMAGE_TLV first
result = adapter.validate_python(0x02)
assert isinstance(result, IMAGE_TLV)
assert result == IMAGE_TLV.PUBKEY

# Vendor TLV should validate
result = adapter.validate_python(0xA0)
assert isinstance(result, int)
assert result == 0xA0

# Unknown TLV should fallback to int
result = adapter.validate_python(0x99)
assert isinstance(result, int)
assert result == 0x99


def test_tlv_value_str_standard() -> None:
"""Test __str__ with standard IMAGE_TLV enum types."""
# PUBKEY
tlv_header = ImageTLV(type=0x02, len=4)
tlv_value = ImageTLVValue(header=tlv_header, value=b"\x00\x01\x02\x03")
assert str(tlv_value) == "PUBKEY=00010203"

# SHA256
tlv_header = ImageTLV(type=0x10, len=4)
tlv_value = ImageTLVValue(header=tlv_header, value=b"\xAA\xBB\xCC\xDD")
assert str(tlv_value) == "SHA256=aabbccdd"


def test_tlv_value_str_vendor() -> None:
"""Test __str__ with vendor TLV types (should show hex)."""
tlv_header = ImageTLV(type=0xA0, len=4)
tlv_value = ImageTLVValue(header=tlv_header, value=b"\xFF\xFF\xFF\xFF")
assert str(tlv_value) == "0xa0=ffffffff"

tlv_header = ImageTLV(type=0xFE, len=2)
tlv_value = ImageTLVValue(header=tlv_header, value=b"\x12\x34")
assert str(tlv_value) == "0xfe=1234"


def test_tlv_value_str_unknown() -> None:
"""Test __str__ with unknown TLV types (should show hex)."""
tlv_header = ImageTLV(type=0x99, len=4)
tlv_value = ImageTLVValue(header=tlv_header, value=b"\xDE\xAD\xBE\xEF")
assert str(tlv_value) == "0x99=deadbeef"
Loading