From cab59fc8c01794c45cdd8f546e058f156d955b1d Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Wed, 13 May 2026 07:35:04 -0700 Subject: [PATCH 1/2] geotiff: parallelize strip writer and add optional libdeflate backend (#1800) The deflate strip-write path was 3.7x slower than rioxarray/GDAL because `_write_stripped` ran zlib.compress serially while the tile writer already parallelized via a thread pool. Three changes: 1. Mirror `_write_tiled`'s ThreadPoolExecutor pattern in `_write_stripped`. Strip preparation is hoisted into a new `_prepare_strip` helper so the same code drives both the serial and parallel paths. A 2048x2048 deflate strip write drops from 405 ms to 70 ms (5.8x speedup, beats rioxarray's 102 ms). 2. Replace the tile writer's `n_tiles <= 4` sequential cutoff with a bytes-based threshold (`_PARALLEL_MIN_BYTES = 4 MiB`). Pre-fix, `tile_size=1024` on a 2048x2048 image produced n_tiles=4 and forced the slow path; now those writes parallelize too. 3. Route `deflate_compress` through the optional `libdeflate` package when installed (1.5-2x faster than stdlib zlib at the same level; GDAL >= 3.7 already uses it). Output is wire-compatible -- decoded streams round-trip through `zlib.decompress` unchanged. Compressors are cached per thread via `threading.local`. --- xrspatial/geotiff/_compression.py | 45 ++- xrspatial/geotiff/_writer.py | 146 +++++++--- .../tests/test_parallel_writer_1800.py | 259 ++++++++++++++++++ 3 files changed, 405 insertions(+), 45 deletions(-) create mode 100644 xrspatial/geotiff/tests/test_parallel_writer_1800.py diff --git a/xrspatial/geotiff/_compression.py b/xrspatial/geotiff/_compression.py index 3be9400c..9d011e4f 100644 --- a/xrspatial/geotiff/_compression.py +++ b/xrspatial/geotiff/_compression.py @@ -1,12 +1,48 @@ """Compression codecs: deflate (zlib) and LZW (Numba), plus horizontal predictor.""" from __future__ import annotations +import threading import zlib import numpy as np from xrspatial.utils import ngjit +# -- Optional libdeflate backend -------------------------------------------- +# +# When the ``libdeflate`` package is installed, ``deflate_compress`` routes +# through it: libdeflate is typically 1.5-2x faster than ``zlib`` at the +# same compression level, and GDAL >= 3.7 already uses it when available +# so our installs match throughput. Output is wire-compatible (zlib +# format), so encoded streams round-trip through stdlib ``zlib.decompress`` +# unchanged. +# +# libdeflate's ``Compressor`` objects are not thread-safe, so we keep one +# per (thread, level) pair via ``threading.local``. The writer drives +# compression from a ``ThreadPoolExecutor``; per-thread caching avoids +# allocating a fresh compressor per strip/tile. +try: # pragma: no cover - exercised only when libdeflate is installed + import libdeflate as _libdeflate + _HAVE_LIBDEFLATE = True +except ImportError: + _libdeflate = None + _HAVE_LIBDEFLATE = False + +_libdeflate_thread_local = threading.local() + + +def _libdeflate_compressor(level: int): + """Return a thread-local libdeflate Compressor for *level*.""" + cache = getattr(_libdeflate_thread_local, 'cache', None) + if cache is None: + cache = {} + _libdeflate_thread_local.cache = cache + comp = cache.get(level) + if comp is None: + comp = _libdeflate.Compressor(level) + cache[level] = comp + return comp + # -- Decompression-bomb defenses --------------------------------------------- # # A malicious TIFF can declare a small strip/tile compressed payload that @@ -98,7 +134,14 @@ def deflate_decompress(data: bytes, expected_size: int = 0) -> bytes: def deflate_compress(data: bytes, level: int = 6) -> bytes: - """Compress data with deflate/zlib.""" + """Compress data with deflate/zlib. + + Uses ``libdeflate`` when installed (1.5-2x faster than ``zlib``) and + falls back to ``zlib.compress`` otherwise. Output is wire-compatible + either way: the stdlib ``zlib.decompress`` accepts both. + """ + if _HAVE_LIBDEFLATE: + return _libdeflate_compressor(level).compress(data, _libdeflate.Format.ZLIB) return zlib.compress(data, level) diff --git a/xrspatial/geotiff/_writer.py b/xrspatial/geotiff/_writer.py index abe22bf3..c2245aa2 100644 --- a/xrspatial/geotiff/_writer.py +++ b/xrspatial/geotiff/_writer.py @@ -225,6 +225,14 @@ def _compression_tag(compression_name: str) -> int: #: override. _MAX_OVERVIEW_LEVELS = 8 +#: Total uncompressed payload (bytes) below which the strip and tile +#: writers stay sequential. The thread-pool startup cost dominates on +#: small rasters; above this size the per-block compression cost more +#: than pays for it. 4 MiB was chosen empirically on a 20-core box: +#: parallel becomes a net win around ~2 MiB, and the 4 MiB margin keeps +#: a few-tile / two-strip layout from incurring a slowdown. +_PARALLEL_MIN_BYTES = 4 * 1024 * 1024 + def _validate_overview_levels(overview_levels, height=None, width=None): """Validate and normalise an explicit ``overview_levels`` list. @@ -651,12 +659,50 @@ def _build_ifd(tags: list[tuple], overflow_base: int, # Strip writer # --------------------------------------------------------------------------- +def _prepare_strip(data, i, rows_per_strip, height, width, samples, dtype, + bytes_per_sample, predictor: int, compression, + compression_level=None, max_z_error: float = 0.0): + """Extract and compress a single strip. Thread-safe.""" + r0 = i * rows_per_strip + r1 = min(r0 + rows_per_strip, height) + strip_rows = r1 - r0 + + if compression == COMPRESSION_JPEG: + strip_data = np.ascontiguousarray(data[r0:r1]).tobytes() + return jpeg_compress(strip_data, width, strip_rows, samples) + if predictor != 1 and compression != COMPRESSION_NONE: + strip_arr = np.ascontiguousarray(data[r0:r1]) + buf = strip_arr.view(np.uint8).ravel().copy() + buf = _apply_predictor_encode( + buf, predictor, width, strip_rows, bytes_per_sample, samples) + strip_data = buf.tobytes() + else: + strip_data = np.ascontiguousarray(data[r0:r1]).tobytes() + + if compression == COMPRESSION_JPEG2000: + from ._compression import jpeg2000_compress + return jpeg2000_compress( + strip_data, width, strip_rows, samples=samples, dtype=dtype) + if compression == COMPRESSION_LERC: + from ._compression import lerc_compress + return lerc_compress( + strip_data, width, strip_rows, samples=samples, dtype=dtype, + max_z_error=max_z_error) + if compression_level is None: + return compress(strip_data, compression) + return compress(strip_data, compression, level=compression_level) + + def _write_stripped(data: np.ndarray, compression: int, predictor: int, rows_per_strip: int = 256, compression_level: int | None = None, max_z_error: float = 0.0) -> tuple[list, list, list]: """Compress data as strips. + For compressed formats (deflate, lzw, zstd, lz4, ...) strips are + compressed in parallel using a thread pool: zlib, zstandard, lz4, + and the Numba LZW kernel all release the GIL during compression. + Returns ------- (offsets_placeholder, byte_counts, compressed_chunks) @@ -668,53 +714,60 @@ def _write_stripped(data: np.ndarray, compression: int, predictor: int, dtype = data.dtype bytes_per_sample = dtype.itemsize - strips = [] + num_strips = math.ceil(height / rows_per_strip) + + total_bytes = int(data.nbytes) + + # Sequential path: uncompressed, few strips, or small payload. The + # threshold mirrors the tile writer so we don't pay thread-pool + # overhead on tiny rasters. + use_parallel = ( + compression != COMPRESSION_NONE + and num_strips > 2 + and total_bytes > _PARALLEL_MIN_BYTES + ) + + if not use_parallel: + strips = [] + rel_offsets = [] + byte_counts = [] + current_offset = 0 + for i in range(num_strips): + compressed = _prepare_strip( + data, i, rows_per_strip, height, width, samples, dtype, + bytes_per_sample, predictor, compression, + compression_level, max_z_error, + ) + rel_offsets.append(current_offset) + byte_counts.append(len(compressed)) + strips.append(compressed) + current_offset += len(compressed) + return rel_offsets, byte_counts, strips + + # Parallel strip compression -- zlib/zstd/lz4/LZW all release the GIL. + from concurrent.futures import ThreadPoolExecutor + import os + + n_workers = min(num_strips, os.cpu_count() or 4) + with ThreadPoolExecutor(max_workers=n_workers) as pool: + compressed_strips = list(pool.map( + lambda i: _prepare_strip( + data, i, rows_per_strip, height, width, samples, dtype, + bytes_per_sample, predictor, compression, + compression_level, max_z_error, + ), + range(num_strips), + )) + rel_offsets = [] byte_counts = [] current_offset = 0 - - num_strips = math.ceil(height / rows_per_strip) - for i in range(num_strips): - r0 = i * rows_per_strip - r1 = min(r0 + rows_per_strip, height) - strip_rows = r1 - r0 - - if compression == COMPRESSION_JPEG: - strip_data = np.ascontiguousarray(data[r0:r1]).tobytes() - compressed = jpeg_compress(strip_data, width, strip_rows, samples) - elif predictor != 1 and compression != COMPRESSION_NONE: - strip_arr = np.ascontiguousarray(data[r0:r1]) - buf = strip_arr.view(np.uint8).ravel().copy() - buf = _apply_predictor_encode( - buf, predictor, width, strip_rows, bytes_per_sample, samples) - strip_data = buf.tobytes() - if compression_level is None: - compressed = compress(strip_data, compression) - else: - compressed = compress(strip_data, compression, level=compression_level) - else: - strip_data = np.ascontiguousarray(data[r0:r1]).tobytes() - - if compression == COMPRESSION_JPEG2000: - from ._compression import jpeg2000_compress - compressed = jpeg2000_compress( - strip_data, width, strip_rows, samples=samples, dtype=dtype) - elif compression == COMPRESSION_LERC: - from ._compression import lerc_compress - compressed = lerc_compress( - strip_data, width, strip_rows, samples=samples, dtype=dtype, - max_z_error=max_z_error) - elif compression_level is None: - compressed = compress(strip_data, compression) - else: - compressed = compress(strip_data, compression, level=compression_level) - + for cs in compressed_strips: rel_offsets.append(current_offset) - byte_counts.append(len(compressed)) - strips.append(compressed) - current_offset += len(compressed) + byte_counts.append(len(cs)) + current_offset += len(cs) - return rel_offsets, byte_counts, strips + return rel_offsets, byte_counts, compressed_strips # --------------------------------------------------------------------------- @@ -841,8 +894,13 @@ def _write_tiled(data: np.ndarray, compression: int, predictor: int, return rel_offsets, byte_counts, tiles - if n_tiles <= 4: - # Very few tiles: sequential (thread pool overhead not worth it) + # Sequential path: very few tiles, or small total payload. A previous + # ``n_tiles <= 4`` cutoff sent ``tile_size=1024`` writes on a 2048x2048 + # image down the serial path (n_tiles=4) and made them ~8x slower than + # the parallel path. Switching to a bytes-based threshold lets + # large-tile writes parallelize while still skipping the pool on + # small rasters where its setup cost dominates. + if n_tiles <= 2 or int(data.nbytes) <= _PARALLEL_MIN_BYTES: tiles = [] rel_offsets = [] byte_counts = [] diff --git a/xrspatial/geotiff/tests/test_parallel_writer_1800.py b/xrspatial/geotiff/tests/test_parallel_writer_1800.py new file mode 100644 index 00000000..a7895243 --- /dev/null +++ b/xrspatial/geotiff/tests/test_parallel_writer_1800.py @@ -0,0 +1,259 @@ +"""Round-trip and threshold tests for the parallel strip/tile writer (#1800).""" +from __future__ import annotations + +from concurrent.futures import ThreadPoolExecutor + +import numpy as np +import pytest + +from xrspatial.geotiff._writer import ( + _PARALLEL_MIN_BYTES, + _write_stripped, + _write_tiled, + write, +) +from xrspatial.geotiff._reader import read_to_array +from xrspatial.geotiff._compression import ( + COMPRESSION_DEFLATE, + COMPRESSION_NONE, + _HAVE_LIBDEFLATE, + deflate_compress, +) + + +# -- Strip writer parity -------------------------------------------------- + + +def _make_data(h, w, dtype=np.float32, pattern='gradient'): + """Reproducible array used across tests.""" + n = h * w + if pattern == 'gradient': + return np.arange(n, dtype=dtype).reshape(h, w) + rng = np.random.RandomState(1800) + arr = rng.rand(h, w) * 1000 + return arr.astype(dtype) + + +@pytest.mark.parametrize('compression', ['deflate', 'lzw', 'zstd']) +@pytest.mark.parametrize('predictor', [False, True]) +def test_strip_writer_round_trip_large(tmp_path, compression, predictor): + """Multi-strip writes round-trip bit-identically through the parallel path.""" + expected = _make_data(1024, 768, pattern='random') + path = str(tmp_path / f'parallel_strip_1800_{compression}_{predictor}.tif') + write(expected, path, compression=compression, tiled=False, + predictor=predictor) + arr, _ = read_to_array(path) + np.testing.assert_array_equal(arr, expected) + + +@pytest.mark.parametrize('dtype', [np.uint8, np.uint16, np.int16, np.int32, + np.float32, np.float64]) +def test_strip_writer_dtypes(tmp_path, dtype): + """Parallel strip path preserves every supported numeric dtype.""" + if np.issubdtype(dtype, np.floating): + expected = _make_data(800, 400, dtype=dtype, pattern='random') + else: + info = np.iinfo(dtype) + rng = np.random.RandomState(1800) + expected = rng.randint(info.min, info.max, + size=(800, 400), dtype=dtype) + path = str(tmp_path / f'parallel_strip_1800_dtype_{dtype.__name__}.tif') + write(expected, path, compression='deflate', tiled=False) + arr, _ = read_to_array(path) + np.testing.assert_array_equal(arr, expected) + + +def test_strip_writer_small_takes_sequential_path(tmp_path): + """Below the byte threshold the parallel strip path is skipped. + + The sequential branch is functionally identical, so the round-trip + check just guards against the threshold logic accidentally breaking + the small-payload case. + """ + expected = _make_data(32, 64, pattern='gradient') + assert expected.nbytes < _PARALLEL_MIN_BYTES + path = str(tmp_path / 'small_seq_strip_1800.tif') + write(expected, path, compression='deflate', tiled=False) + arr, _ = read_to_array(path) + np.testing.assert_array_equal(arr, expected) + + +def test_strip_writer_thread_pool_used_when_large(monkeypatch): + """A multi-MiB strip write must dispatch through ThreadPoolExecutor.""" + expected = _make_data(2048, 2048, dtype=np.float32, pattern='random') + assert expected.nbytes > _PARALLEL_MIN_BYTES + + used = {'pool': False} + + import concurrent.futures as cf + + class _Probe(cf.ThreadPoolExecutor): + def __init__(self, *a, **kw): + used['pool'] = True + super().__init__(*a, **kw) + + # The writer does ``from concurrent.futures import ThreadPoolExecutor`` + # inside the function, so patching the module attribute is enough. + monkeypatch.setattr(cf, 'ThreadPoolExecutor', _Probe) + + rel, bc, blobs = _write_stripped( + expected, COMPRESSION_DEFLATE, predictor=1, rows_per_strip=256) + assert used['pool'], 'parallel strip writer should have used ThreadPoolExecutor' + # And the output should still round-trip + import zlib + decoded = b''.join(zlib.decompress(b) for b in blobs) + rt = np.frombuffer(decoded, dtype=np.float32).reshape(expected.shape) + np.testing.assert_array_equal(rt, expected) + + +def test_strip_writer_uncompressed_stays_sequential(monkeypatch): + """``compression='none'`` never dispatches to the thread pool.""" + expected = _make_data(2048, 2048, dtype=np.float32, pattern='gradient') + assert expected.nbytes > _PARALLEL_MIN_BYTES + + used = {'pool': False} + + import concurrent.futures as cf + + class _Probe(cf.ThreadPoolExecutor): + def __init__(self, *a, **kw): + used['pool'] = True + super().__init__(*a, **kw) + + monkeypatch.setattr(cf, 'ThreadPoolExecutor', _Probe) + _write_stripped(expected, COMPRESSION_NONE, predictor=1, rows_per_strip=256) + assert not used['pool'], 'uncompressed strip writer must stay sequential' + + +# -- Tile writer adaptive threshold --------------------------------------- + + +def test_tile_writer_large_tile_size_parallelizes(monkeypatch): + """A 2048x2048 deflate write with tile_size=1024 (n_tiles=4) must run + in parallel after the threshold fix. + + Pre-fix, ``n_tiles <= 4`` shoved this case onto the serial path even + though the payload was 16 MiB; that produced ~8x slower writes. + """ + expected = _make_data(2048, 2048, dtype=np.float32, pattern='random') + assert expected.nbytes > _PARALLEL_MIN_BYTES + + used = {'pool': False} + + import concurrent.futures as cf + + class _Probe(cf.ThreadPoolExecutor): + def __init__(self, *a, **kw): + used['pool'] = True + super().__init__(*a, **kw) + + monkeypatch.setattr(cf, 'ThreadPoolExecutor', _Probe) + _write_tiled( + expected, COMPRESSION_DEFLATE, predictor=1, tile_size=1024) + assert used['pool'], ( + 'tile writer with tile_size=1024 on 2048x2048 (n_tiles=4, 16 MiB) ' + 'must parallelize after the adaptive-threshold change' + ) + + +def test_tile_writer_small_payload_stays_sequential(monkeypatch): + """A small raster keeps the sequential path even with n_tiles > 2.""" + expected = _make_data(128, 128, dtype=np.float32, pattern='gradient') + assert expected.nbytes < _PARALLEL_MIN_BYTES + + used = {'pool': False} + + import concurrent.futures as cf + + class _Probe(cf.ThreadPoolExecutor): + def __init__(self, *a, **kw): + used['pool'] = True + super().__init__(*a, **kw) + + monkeypatch.setattr(cf, 'ThreadPoolExecutor', _Probe) + _write_tiled( + expected, COMPRESSION_DEFLATE, predictor=1, tile_size=32) + assert not used['pool'] + + +# -- libdeflate backend ---------------------------------------------------- + + +def test_deflate_compress_zlib_wire_compatible(): + """Output is decompressible by stdlib zlib regardless of backend.""" + import zlib + raw = (np.arange(1024, dtype=np.uint8) % 251).tobytes() * 64 + compressed = deflate_compress(raw, level=6) + assert zlib.decompress(compressed) == raw + + +def test_deflate_compress_fallback_when_libdeflate_missing(monkeypatch): + """When libdeflate is absent we route through stdlib zlib unchanged.""" + import zlib + + import xrspatial.geotiff._compression as comp_mod + monkeypatch.setattr(comp_mod, '_HAVE_LIBDEFLATE', False) + monkeypatch.setattr(comp_mod, '_libdeflate', None) + + raw = b'1800-deflate-fallback' * 4096 + blob = comp_mod.deflate_compress(raw, level=6) + assert zlib.decompress(blob) == raw + # Exact byte equality to ``zlib.compress`` at the same level (the + # fallback path is a direct call). + assert blob == zlib.compress(raw, 6) + + +@pytest.mark.skipif(not _HAVE_LIBDEFLATE, + reason='libdeflate not installed') +def test_deflate_compress_uses_libdeflate_when_available(): + """When libdeflate is installed, deflate output stays wire-compatible.""" + import zlib + raw = (np.arange(8192, dtype=np.uint8) % 251).tobytes() * 16 + blob = deflate_compress(raw, level=6) + assert zlib.decompress(blob) == raw + + +def test_libdeflate_compressor_cache_is_thread_local(): + """The cache lives in threading.local, so two threads see distinct dicts.""" + import xrspatial.geotiff._compression as comp_mod + + if not comp_mod._HAVE_LIBDEFLATE: + pytest.skip('libdeflate not installed') + + seen = {} + + def grab(tag): + # First call populates the cache; we grab its id(). + comp_mod._libdeflate_compressor(6) + seen[tag] = id(comp_mod._libdeflate_thread_local.cache) + + t1 = ThreadPoolExecutor(max_workers=2) + list(t1.map(grab, ['a', 'b'])) + t1.shutdown(wait=True) + # Two workers populated two distinct local caches. + assert len(set(seen.values())) == 2 + + +# -- End-to-end via write() ------------------------------------------------ + + +def test_write_strip_deflate_round_trip_multi_strip(tmp_path): + """Drive the writer entrypoint with a multi-strip deflate payload. + + The reader doesn't care which path produced the bytes; this guards + the full write pipeline (predictor on, multiple strips). + """ + expected = _make_data(900, 700, dtype=np.float32, pattern='random') + path = str(tmp_path / 'e2e_strip_1800.tif') + write(expected, path, compression='deflate', tiled=False, predictor=True) + arr, _ = read_to_array(path) + np.testing.assert_array_equal(arr, expected) + + +def test_write_tiled_deflate_large_tile_round_trip(tmp_path): + """tile_size=1024 on 2048x2048 must round-trip through the parallel path.""" + expected = _make_data(2048, 2048, dtype=np.float32, pattern='random') + path = str(tmp_path / 'e2e_tile1024_1800.tif') + write(expected, path, compression='deflate', tiled=True, tile_size=1024) + arr, _ = read_to_array(path) + np.testing.assert_array_equal(arr, expected) From 82b7b19c63b63fde0001eda10b4fc01328008992 Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Wed, 13 May 2026 07:57:14 -0700 Subject: [PATCH 2/2] geotiff: harden thread-local cache test against pool scheduling (#1800) PR #1801's review flagged that `test_libdeflate_compressor_cache_is_thread_local` could pass with a single observed cache id: `ThreadPoolExecutor(max_workers=2).map(...)` is free to run both submissions on the same worker if the first returns quickly. Force both tasks to occupy a worker at the same time with a `threading.Barrier`, record `threading.get_ident()` so the assertion fails loudly if only one thread actually ran, and use the executor as a context manager so the pool is shut down on assertion failure. --- .../tests/test_parallel_writer_1800.py | 37 ++++++++++++++----- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/xrspatial/geotiff/tests/test_parallel_writer_1800.py b/xrspatial/geotiff/tests/test_parallel_writer_1800.py index a7895243..39baead0 100644 --- a/xrspatial/geotiff/tests/test_parallel_writer_1800.py +++ b/xrspatial/geotiff/tests/test_parallel_writer_1800.py @@ -214,24 +214,41 @@ def test_deflate_compress_uses_libdeflate_when_available(): def test_libdeflate_compressor_cache_is_thread_local(): - """The cache lives in threading.local, so two threads see distinct dicts.""" + """The cache lives in threading.local, so two threads see distinct dicts. + + Uses a ``threading.Barrier`` to force both tasks to occupy a worker + at the same time. Without that, ``ThreadPoolExecutor(max_workers=2)`` + is free to run both submissions on the same thread (if the first + returns before the second is scheduled), and the test would + intermittently pass with only one observed cache id. + """ + import threading + import xrspatial.geotiff._compression as comp_mod if not comp_mod._HAVE_LIBDEFLATE: pytest.skip('libdeflate not installed') - seen = {} + seen_caches: dict[int, int] = {} + barrier = threading.Barrier(2, timeout=10) - def grab(tag): - # First call populates the cache; we grab its id(). + def grab(_tag): + # Both workers must reach the barrier before either proceeds, so + # the pool is forced to use both threads. + barrier.wait() comp_mod._libdeflate_compressor(6) - seen[tag] = id(comp_mod._libdeflate_thread_local.cache) + tid = threading.get_ident() + seen_caches[tid] = id(comp_mod._libdeflate_thread_local.cache) - t1 = ThreadPoolExecutor(max_workers=2) - list(t1.map(grab, ['a', 'b'])) - t1.shutdown(wait=True) - # Two workers populated two distinct local caches. - assert len(set(seen.values())) == 2 + with ThreadPoolExecutor(max_workers=2) as pool: + list(pool.map(grab, ['a', 'b'])) + + # Two distinct threads ran and each populated its own threading.local + # cache, so we should see two thread ids and two cache ids. + assert len(seen_caches) == 2, f'expected 2 threads, saw {len(seen_caches)}' + assert len(set(seen_caches.values())) == 2, ( + f'expected 2 distinct caches, saw {seen_caches}' + ) # -- End-to-end via write() ------------------------------------------------