diff --git a/smpclient/mcuboot.py b/smpclient/mcuboot.py index 6f10c50..039767c 100644 --- a/smpclient/mcuboot.py +++ b/smpclient/mcuboot.py @@ -11,10 +11,12 @@ from enum import IntEnum, IntFlag, unique from functools import cached_property from io import BufferedReader, BytesIO -from typing import Dict, Final, List +from typing import Annotated, Any, Dict, Final, List, Union from intelhex import hex2bin # type: ignore +from pydantic import Field, GetCoreSchemaHandler from pydantic.dataclasses import dataclass +from pydantic_core import CoreSchema, core_schema IMAGE_MAGIC: Final = 0x96F3B83D IMAGE_HEADER_SIZE: Final = 32 @@ -60,12 +62,21 @@ class IMAGE_F(IntFlag): @unique class IMAGE_TLV(IntEnum): - """Image trailer TLV types.""" + """Image trailer TLV types. + + Specification: https://docs.mcuboot.com/design.html#image-format + """ KEYHASH = 0x01 - """hash of the public key""" + """Hash of the public key""" + PUBKEY = 0x02 + """Public key""" SHA256 = 0x10 """SHA256 of image hdr and body""" + SHA384 = 0x11 + """SHA384 of image hdr and body""" + SHA512 = 0x12 + """SHA512 of image hdr and body""" RSA2048_PSS = 0x20 """RSA2048 of hash output""" ECDSA224 = 0x21 @@ -76,6 +87,8 @@ class IMAGE_TLV(IntEnum): """RSA3072 of hash output""" ED25519 = 0x24 """ED25519 of hash output""" + SIG_PURE = 0x25 + """Signature prepared over full image rather than digest""" ENC_RSA2048 = 0x30 """Key encrypted with RSA-OAEP-2048""" ENC_KW = 0x31 @@ -84,10 +97,69 @@ class IMAGE_TLV(IntEnum): """Key encrypted with ECIES-P256""" ENC_X25519 = 0x33 """Key encrypted with ECIES-X25519""" + ENC_X25519_SHA512 = 0x34 + """Key exchange using X25519 with SHA512 MAC""" DEPENDENCY = 0x40 """Image depends on other image""" SEC_CNT = 0x50 - """security counter""" + """Security counter""" + BOOT_RECORD = 0x60 + """Measured boot record""" + DECOMP_SIZE = 0x70 + """Decompressed image size excluding header/TLVs""" + DECOMP_SHA = 0x71 + """Decompressed image hash matching format of compressed slot""" + DECOMP_SIGNATURE = 0x72 + """Decompressed image signature matching compressed format""" + COMP_DEC_SIZE = 0x73 + """Compressed decrypted image size""" + UUID_VID = 0x80 + """Vendor unique identifier""" + UUID_CID = 0x81 + """Device class unique identifier""" + + +class VendorTLV(int): + """Vendor-defined TLV type in reserved ranges (0xXXA0-0xXXFE). + + Vendor reserved TLVs occupy ranges from 0xXXA0 to 0xXXFE, where XX + represents any upper byte value. Examples include ranges 0x00A0-0x00FF, + 0x01A0-0x01FF, and 0x02A0-0x02FF, continuing through 0xFFA0-0xFFFE. + """ + + def __new__(cls, value: int) -> 'VendorTLV': + """Create a new VendorTLV, validating the range.""" + lower_byte = value & 0xFF + if not (0xA0 <= lower_byte <= 0xFE): + raise ValueError( + f"VendorTLV 0x{value:02x} must have lower byte in range 0xA0-0xFE, " + f"got 0x{lower_byte:02x}" + ) + return int.__new__(cls, value) + + @classmethod + def __get_pydantic_core_schema__( + cls, _source_type: Any, _handler: GetCoreSchemaHandler + ) -> CoreSchema: + def validate(value: int) -> VendorTLV: + return cls(value) + + return core_schema.no_info_after_validator_function( + validate, + core_schema.int_schema(), + ) + + +ImageTLVType = Annotated[Union[IMAGE_TLV, VendorTLV, int], Field(union_mode="left_to_right")] +"""TLV type that accepts standard IMAGE_TLV enums, vendor-defined TLVs, or any integer. + +This uses Pydantic's "left to right" union mode to: +1. First try to match against IMAGE_TLV enum values +2. Then try to validate as a VendorTLV (0xXXA0-0xXXFE ranges) +3. Finally accept any integer as a fallback + +This ensures backward compatibility and supports future TLV types without validation errors. +""" @dataclass(frozen=True) @@ -189,7 +261,7 @@ def load_from(file: BytesIO | BufferedReader) -> 'ImageTLVInfo': class ImageTLV: """A TLV header - type and length.""" - type: IMAGE_TLV + type: ImageTLVType len: int """Data length (not including TLV header).""" @@ -209,7 +281,12 @@ def __post_init__(self) -> None: raise MCUBootImageError(f"TLV requires length {self.header.len}, got {len(self.value)}") def __str__(self) -> str: - return f"{self.header.type.name}={self.value.hex()}" + type_name = ( + self.header.type.name + if isinstance(self.header.type, IMAGE_TLV) + else f"0x{self.header.type:02x}" + ) + return f"{type_name}={self.value.hex()}" @dataclass(frozen=True) @@ -221,7 +298,7 @@ class ImageInfo: tlvs: List[ImageTLVValue] file: str | None = None - def get_tlv(self, tlv: IMAGE_TLV) -> ImageTLVValue: + def get_tlv(self, tlv: ImageTLVType) -> ImageTLVValue: """Get a TLV from the image or raise `TLVNotFound`.""" if tlv in self._map_tlv_type_to_value: return self._map_tlv_type_to_value[tlv] @@ -263,7 +340,7 @@ def load_file(path: str) -> 'ImageInfo': return ImageInfo(file=path, header=image_header, tlv_info=tlv_info, tlvs=tlvs) @cached_property - def _map_tlv_type_to_value(self) -> Dict[IMAGE_TLV, ImageTLVValue]: + def _map_tlv_type_to_value(self) -> Dict[int, ImageTLVValue]: return {tlv.header.type: tlv for tlv in self.tlvs} def __str__(self) -> str: diff --git a/tests/test_mcuboot_tools.py b/tests/test_mcuboot_tools.py index 1171639..91ffa94 100644 --- a/tests/test_mcuboot_tools.py +++ b/tests/test_mcuboot_tools.py @@ -14,7 +14,11 @@ IMAGE_TLV_INFO_MAGIC, ImageHeader, ImageInfo, + ImageTLV, + ImageTLVType, + ImageTLVValue, ImageVersion, + VendorTLV, ) @@ -113,3 +117,118 @@ def test_ImageVersion() -> None: assert v.revision == 0xFFFF assert v.build_num == 0xFFFFFFFF assert str(v) == "1.255.65535-build4294967295" + + +def test_pubkey_tlv_exists() -> None: + """Test that PUBKEY (0x02) TLV type exists. + + https://github.com/intercreate/smpclient/issues/83 + """ + assert IMAGE_TLV.PUBKEY == 0x02 + assert IMAGE_TLV.PUBKEY.name == "PUBKEY" + + +def test_standard_tlv_coercion() -> None: + """Test that standard TLV values are coerced to IMAGE_TLV enum.""" + # PUBKEY (the bug fix!) + tlv = ImageTLV(type=0x02, len=256) + assert isinstance(tlv.type, IMAGE_TLV) + assert tlv.type == IMAGE_TLV.PUBKEY + assert tlv.type.name == "PUBKEY" + + # SHA256 + tlv = ImageTLV(type=0x10, len=32) + assert isinstance(tlv.type, IMAGE_TLV) + assert tlv.type == IMAGE_TLV.SHA256 + + # SHA384 + tlv = ImageTLV(type=0x11, len=48) + assert isinstance(tlv.type, IMAGE_TLV) + assert tlv.type == IMAGE_TLV.SHA384 + + +def test_vendor_tlv_validation() -> None: + """Test that vendor TLV ranges are validated correctly.""" + # Lower byte 0xA0-0xFE should be valid vendor TLVs + tlv = ImageTLV(type=0xA0, len=16) + assert isinstance(tlv.type, VendorTLV) + assert tlv.type == 0xA0 + + tlv = ImageTLV(type=0xFE, len=8) + assert isinstance(tlv.type, VendorTLV) + assert tlv.type == 0xFE + + # Multi-byte vendor TLVs + tlv = ImageTLV(type=0x01A0, len=16) + assert isinstance(tlv.type, VendorTLV) + assert tlv.type == 0x01A0 + + tlv = ImageTLV(type=0xFFFE, len=4) + assert isinstance(tlv.type, VendorTLV) + assert tlv.type == 0xFFFE + + +def test_unknown_tlv_fallback() -> None: + """Test that unknown TLV types fall back to int without error.""" + # This should not raise a validation error + tlv = ImageTLV(type=0x99, len=8) + assert isinstance(tlv.type, int) + assert tlv.type == 0x99 + + # Another unknown type + tlv = ImageTLV(type=0x05, len=4) + assert isinstance(tlv.type, int) + assert tlv.type == 0x05 + + +def test_tlv_type_union_order() -> None: + """Test that union resolution follows left-to-right order.""" + from pydantic import TypeAdapter + + adapter: TypeAdapter[ImageTLVType] = TypeAdapter(ImageTLVType) + + # Standard TLV should match IMAGE_TLV first + result = adapter.validate_python(0x02) + assert isinstance(result, IMAGE_TLV) + assert result == IMAGE_TLV.PUBKEY + + # Vendor TLV should validate + result = adapter.validate_python(0xA0) + assert isinstance(result, int) + assert result == 0xA0 + + # Unknown TLV should fallback to int + result = adapter.validate_python(0x99) + assert isinstance(result, int) + assert result == 0x99 + + +def test_tlv_value_str_standard() -> None: + """Test __str__ with standard IMAGE_TLV enum types.""" + # PUBKEY + tlv_header = ImageTLV(type=0x02, len=4) + tlv_value = ImageTLVValue(header=tlv_header, value=b"\x00\x01\x02\x03") + assert str(tlv_value) == "PUBKEY=00010203" + + # SHA256 + tlv_header = ImageTLV(type=0x10, len=4) + tlv_value = ImageTLVValue(header=tlv_header, value=b"\xAA\xBB\xCC\xDD") + assert str(tlv_value) == "SHA256=aabbccdd" + + +def test_tlv_value_str_vendor() -> None: + """Test __str__ with vendor TLV types (should show hex).""" + tlv_header = ImageTLV(type=0xA0, len=4) + tlv_value = ImageTLVValue(header=tlv_header, value=b"\xFF\xFF\xFF\xFF") + assert str(tlv_value) == "0xa0=ffffffff" + + tlv_header = ImageTLV(type=0xFE, len=2) + tlv_value = ImageTLVValue(header=tlv_header, value=b"\x12\x34") + assert str(tlv_value) == "0xfe=1234" + + +def test_tlv_value_str_unknown() -> None: + """Test __str__ with unknown TLV types (should show hex).""" + tlv_header = ImageTLV(type=0x99, len=4) + tlv_value = ImageTLVValue(header=tlv_header, value=b"\xDE\xAD\xBE\xEF") + assert str(tlv_value) == "0x99=deadbeef"