# xrspatial/geotiff/_compression.py
# -- PackBits (simple RLE) ----------------------------------------------------

@ngjit
def _packbits_decode_kernel(src, src_len, dst, dst_cap):
    """Decode PackBits (TIFF compression tag 32773) into a uint8 buffer.

    Parameters
    ----------
    src : uint8 array
        Compressed bytes.
    src_len : int
        Number of valid bytes in ``src``.
    dst : uint8 array
        Pre-allocated output buffer.
    dst_cap : int
        Maximum bytes to write before bailing out. When ``0``, the kernel
        writes up to ``len(dst)`` and stops silently once the buffer is
        full, so the caller must size ``dst`` for the worst case.

    Returns
    -------
    int
        Number of bytes written. ``-1`` signals that a positive
        ``dst_cap`` was exceeded (the wrapper turns this into
        ``ValueError``).
    """
    out_pos = 0
    i = 0
    # A zero dst_cap means "no cap": the buffer length is the write limit
    # and hitting it truncates instead of reporting an error.
    if dst_cap > 0:
        write_limit = dst_cap
        enforce_cap = True
    else:
        write_limit = len(dst)
        enforce_cap = False

    while i < src_len:
        # Widen the header before doing arithmetic on it. If this function
        # ever runs as plain Python (presumably possible when numba's JIT
        # is disabled -- TODO confirm how ``ngjit`` is configured), ``src[i]``
        # is a NumPy uint8 scalar and NumPy 2 refuses ``257 - header``
        # because 257 does not fit in uint8.
        header = int(src[i])
        i += 1
        if header < 128:
            # Literal run of header + 1 bytes.
            count = header + 1
            # Don't read past end of src on a truncated stream.
            available = src_len - i
            if count > available:
                count = available
            for k in range(count):
                if out_pos >= write_limit:
                    if enforce_cap:
                        return -1
                    # write_limit equals len(dst); refuse to write past it.
                    return out_pos
                dst[out_pos] = src[i + k]
                out_pos += 1
            i += count
        elif header > 128:
            # Replicate run: header interpreted as signed yields n in
            # [-127, -1]; repeat count is 1 - n == 257 - header.
            count = 257 - header
            if i >= src_len:
                # Header without its data byte: truncated stream, stop.
                break
            byte_val = src[i]
            i += 1
            for _ in range(count):
                if out_pos >= write_limit:
                    if enforce_cap:
                        return -1
                    return out_pos
                dst[out_pos] = byte_val
                out_pos += 1
        # header == 128: no-op marker.

    return out_pos


def packbits_decompress(data: bytes, expected_size: int = 0) -> bytes:
    """Decompress PackBits (TIFF compression tag 32773).

    Parameters
    ----------
    data : bytes-like
        Compressed stream. ``bytes``, ``bytearray`` and ``memoryview`` are
        wrapped without copying where possible; anything else is first
        converted with ``bytes(data)``.
    expected_size : int, optional
        Expected decoded size. When it yields a positive cap from
        ``_max_output_with_margin`` (``expected_size * 1.05 + 1`` according
        to the error message below), decoding more than the cap raises.
        A non-positive cap disables the guard.

    Returns
    -------
    bytes
        The decoded payload. Empty input decodes to ``b""``.

    Raises
    ------
    ValueError
        If the decoded output would exceed the cap (decompression-bomb
        guard).
    """
    if isinstance(data, memoryview):
        # np.frombuffer needs a C-contiguous buffer; copy strided views,
        # which the previous bytes()-based implementation also accepted.
        raw = data if data.c_contiguous else data.tobytes()
    elif isinstance(data, (bytes, bytearray)):
        raw = data
    else:
        raw = bytes(data)
    src = np.frombuffer(raw, dtype=np.uint8)

    src_len = len(src)
    if src_len == 0:
        return b""

    cap = _max_output_with_margin(expected_size)
    if cap > 0:
        # dst sized exactly to the cap; the kernel returns -1 the moment
        # a write at position cap would be needed, so no sentinel byte
        # past the cap is required.
        dst = np.empty(cap, dtype=np.uint8)
        n_written = _packbits_decode_kernel(src, src_len, dst, cap)
    else:
        # No cap supplied. The densest PackBits encoding is a replicate
        # run: two source bytes (header + value) yield at most 128 output
        # bytes, i.e. 64:1. Literal runs never expand, and a dangling
        # final byte produces nothing, so 128 * (src_len // 2) always
        # fits the full output while keeping the allocation bounded.
        worst_case = 128 * (src_len // 2)
        dst = np.empty(worst_case, dtype=np.uint8)
        n_written = _packbits_decode_kernel(src, src_len, dst, 0)

    if n_written < 0:
        raise ValueError(
            f"packbits decode exceeded expected size: produced more than "
            f"{cap} bytes (cap = expected_size * 1.05 + 1, expected "
            f"{expected_size}). Likely a decompression bomb."
        )
    return bytes(dst[:n_written])
# xrspatial/geotiff/tests/test_packbits_jit_2048.py
"""PackBits decode JIT kernel coverage (issue #2048).

``packbits_decompress`` was reworked from a pure-Python ``while`` loop into a
numba ``@ngjit`` kernel wrapped by a thin bytes-in / bytes-out shim. These
tests pin the kernel against PackBits' boundary conditions: the 128-byte run
length boundary that switches literal/replicate encodings, max-length runs,
the ``-128`` no-op sentinel, empty input, and the decompression-bomb cap.

The pre-existing PackBits coverage (``test_features.py`` round-trips and
``test_decompression_caps.py`` bomb guard) keeps passing too; this file just
fills in the bit-exact edge cases that the JIT rewrite is most likely to
regress on.
"""
from __future__ import annotations

import pytest

from xrspatial.geotiff._compression import (
    packbits_compress,
    packbits_decompress,
)


# -- Bit-exact decode against known PackBits encodings -----------------------


def test_packbits_decode_empty_input_returns_empty_bytes():
    """An empty stream decodes to an empty payload."""
    decoded = packbits_decompress(b"")
    assert decoded == b""


def test_packbits_decode_single_literal_byte():
    """Header 0x00 means "copy the next 1 byte literally"."""
    stream = b"\x00\x42"
    assert packbits_decompress(stream) == b"\x42"


def test_packbits_decode_single_replicate_pair():
    """Header 0xFF (signed -1) means "repeat the next byte 2 times"."""
    stream = b"\xff\x42"
    assert packbits_decompress(stream) == b"\x42\x42"


def test_packbits_decode_noop_sentinel_is_skipped():
    """Header 0x80 (signed -128) is a no-op; data around it still decodes."""
    stream = b"\x80\x00\x42"
    assert packbits_decompress(stream) == b"\x42"


def test_packbits_decode_wikipedia_canonical_example():
    """Decode the vector from the PackBits Wikipedia article."""
    stream = bytes((
        0xFE, 0xAA,                    # replicate 0xAA x 3
        0x02, 0x80, 0x00, 0x2A,        # 3 literal bytes
        0xFD, 0xAA,                    # replicate 0xAA x 4
        0x03, 0x80, 0x00, 0x2A, 0x22,  # 4 literal bytes
        0xF7, 0xAA,                    # replicate 0xAA x 10
    ))
    want = b"".join((
        b"\xAA" * 3,
        b"\x80\x00\x2A",
        b"\xAA" * 4,
        b"\x80\x00\x2A\x22",
        b"\xAA" * 10,
    ))
    assert packbits_decompress(stream) == want


# -- Run-length boundary cases (128-byte switch) -----------------------------


def test_packbits_decode_max_literal_run_is_128_bytes():
    """Header 0x7F (127) announces the longest literal run: 128 bytes."""
    body = bytes(range(128))
    stream = b"\x7f" + body
    assert packbits_decompress(stream) == body


def test_packbits_decode_max_replicate_run_is_128_bytes():
    """Header 0x81 (signed -127) repeats the next byte 128 times."""
    decoded = packbits_decompress(b"\x81\x42")
    assert decoded == b"\x42" * 128


def test_packbits_decode_back_to_back_max_runs():
    """A max literal run followed directly by a max replicate run.

    Both runs land on the 128-byte boundary, exercising the
    literal -> replicate switch with no gap in between.
    """
    body = bytes(range(128))
    stream = b"\x7f" + body + b"\x81\x99"
    want = body + b"\x99" * 128
    assert packbits_decompress(stream) == want


# -- Round-trip parity with the (untouched) compressor -----------------------


_ROUNDTRIP_PAYLOADS = [
    b"",
    b"A",
    b"AAAA",
    b"ABCD" * 64,
    bytes(range(256)),
    b"\x00" * 1024,
    b"long literal stretch with no runs at all 1234567890" * 17,
]


@pytest.mark.parametrize("payload", _ROUNDTRIP_PAYLOADS)
def test_packbits_roundtrip(payload):
    """compress -> decompress reproduces the payload exactly."""
    packed = packbits_compress(payload)
    assert packbits_decompress(packed) == payload


# -- Decompression-bomb cap stays intact across the rewrite ------------------


def test_packbits_cap_rejects_oversized_expansion():
    """A 2-byte stream expanding to 128 bytes must trip a cap derived from 4."""
    bomb = b"\x81\x42"
    with pytest.raises(ValueError, match="packbits decode exceeded expected size"):
        packbits_decompress(bomb, expected_size=4)


def test_packbits_cap_allows_within_margin():
    """Cap is expected_size * 1.05 + 1 = 5, so a 4-byte decode passes."""
    stream = b"\xff\x42\xff\x43"
    decoded = packbits_decompress(stream, expected_size=4)
    assert decoded == b"BBCC"


@pytest.mark.parametrize(
    ("expected_size", "payload_bytes", "should_pass"),
    [
        # Cap = int(expected_size * 1.05) + 1.
        (1, 2, True),        # cap=2; a 2-byte decode lands on the cap.
        (100, 106, True),    # cap=106; a 106-byte decode equals the cap.
        (100, 107, False),   # cap=106; a 107-byte decode trips the guard.
    ],
)
def test_packbits_cap_boundary(expected_size, payload_bytes, should_pass):
    """Decoding exactly ``cap`` bytes is allowed; one byte more is rejected."""
    # Encode `payload_bytes` zeros as replicate runs of 128 plus a tail.
    whole_runs, remainder = divmod(payload_bytes, 128)
    stream = b"\x81\x00" * whole_runs
    if remainder == 1:
        # A single leftover byte is written as a one-byte literal run.
        stream += b"\x00\x00"
    elif remainder > 1:
        # Replicate run of `remainder` bytes: header = 257 - remainder.
        stream += bytes((257 - remainder, 0x00))

    if not should_pass:
        with pytest.raises(ValueError, match="packbits decode exceeded"):
            packbits_decompress(stream, expected_size=expected_size)
        return

    decoded = packbits_decompress(stream, expected_size=expected_size)
    assert decoded == b"\x00" * payload_bytes


def test_packbits_no_cap_when_expected_size_is_zero():
    """expected_size=0 disables the cap (backward-compat path)."""
    decoded = packbits_decompress(b"\x81\x42")
    assert decoded == b"\x42" * 128


# -- Truncated input must not read past src ----------------------------------


def test_packbits_truncated_literal_run_stops_at_src_end():
    """Header claims 4 literal bytes but only 2 follow.

    The decoder must not read past the end of src.
    """
    stream = b"\x03\x01\x02"
    decoded = packbits_decompress(stream)
    assert decoded == b"\x01\x02"


def test_packbits_truncated_replicate_header_stops_cleanly():
    """A replicate header without its data byte terminates the decode."""
    stream = b"\xfe"
    decoded = packbits_decompress(stream)
    assert decoded == b""