2 changes: 2 additions & 0 deletions setup.cfg
@@ -60,6 +60,8 @@ optional =
rtxpy
reproject =
pyproj
geotiff =
deflate
dask =
dask[array]
dask-geopandas
85 changes: 50 additions & 35 deletions xrspatial/geotiff/_compression.py
@@ -1,7 +1,7 @@
"""Compression codecs: deflate (zlib) and LZW (Numba), plus horizontal predictor."""
from __future__ import annotations

import threading
import warnings
import zlib

import numpy as np
@@ -10,38 +10,20 @@

# -- Optional libdeflate backend --------------------------------------------
#
# When the ``libdeflate`` package is installed, ``deflate_compress`` routes
# through it: libdeflate is typically 1.5-2x faster than ``zlib`` at the
# same compression level, and GDAL >= 3.7 already uses it when available
# so our installs match throughput. Output is wire-compatible (zlib
# format), so encoded streams round-trip through stdlib ``zlib.decompress``
# unchanged.
#
# libdeflate's ``Compressor`` objects are not thread-safe, so we keep one
# per (thread, level) pair via ``threading.local``. The writer drives
# compression from a ``ThreadPoolExecutor``; per-thread caching avoids
# allocating a fresh compressor per strip/tile.
try: # pragma: no cover - exercised only when libdeflate is installed
import libdeflate as _libdeflate
# When the ``deflate`` PyPI package (Christian Heimes's libdeflate binding)
# is installed, ``deflate_compress`` routes through it. On the 1 MB random
# float32 buffer used in our write bench, libdeflate level 6 is ~3x faster
# than stdlib ``zlib`` level 6 (7.6 ms vs 24 ms), matching what GDAL/libtiff
# achieves when GDAL is built against libdeflate (the default in recent
# rasterio/GDAL wheels). Output is wire-compatible (zlib format), so encoded
# streams round-trip through stdlib ``zlib.decompress`` unchanged.
try: # pragma: no cover - exercised only when the deflate package is installed
import deflate as _deflate
_HAVE_LIBDEFLATE = True
except ImportError:
_libdeflate = None
_deflate = None
_HAVE_LIBDEFLATE = False

_libdeflate_thread_local = threading.local()


def _libdeflate_compressor(level: int):
"""Return a thread-local libdeflate Compressor for *level*."""
cache = getattr(_libdeflate_thread_local, 'cache', None)
if cache is None:
cache = {}
_libdeflate_thread_local.cache = cache
comp = cache.get(level)
if comp is None:
comp = _libdeflate.Compressor(level)
cache[level] = comp
return comp

# -- Decompression-bomb defenses ---------------------------------------------
#
@@ -133,15 +115,41 @@ def deflate_decompress(data: bytes, expected_size: int = 0) -> bytes:
return bytes(out)


def deflate_compress(data: bytes, level: int = 6) -> bytes:
_zlib_fallback_warned = False


def deflate_compress(data: bytes, level: int = 6,
gil_friendly: bool = False) -> bytes:
"""Compress data with deflate/zlib.

Uses ``libdeflate`` when installed (1.5-2x faster than ``zlib``) and
Uses the ``deflate`` package (libdeflate binding) when installed
(~3x faster than stdlib ``zlib`` on typical raster payloads) and
falls back to ``zlib.compress`` otherwise. Output is wire-compatible
either way: the stdlib ``zlib.decompress`` accepts both.

``gil_friendly=True`` forces stdlib ``zlib`` regardless of libdeflate
availability. The ``deflate`` PyPI binding does not release the GIL
during compression (measured: 1.2x speedup across 8 threads vs zlib's
5x speedup), so the writer's parallel paths request the GIL-releasing
codec to keep thread-pool scaling. The sequential path leaves the
default, picking up libdeflate's per-call speedup.
"""
if _HAVE_LIBDEFLATE:
return _libdeflate_compressor(level).compress(data, _libdeflate.Format.ZLIB)
if _HAVE_LIBDEFLATE and not gil_friendly:
# ``deflate.zlib_compress`` returns ``bytearray``; cast for the
# ``-> bytes`` contract callers (and tests) rely on.
return bytes(_deflate.zlib_compress(data, level))
if not _HAVE_LIBDEFLATE:
global _zlib_fallback_warned
if not _zlib_fallback_warned:
_zlib_fallback_warned = True
warnings.warn(
"xrspatial.geotiff: the `deflate` package is not installed; "
"falling back to stdlib zlib for deflate-compressed writes "
"(~3x slower on typical rasters). Install with `pip install "
"deflate` or `pip install xarray-spatial[geotiff]` to recover "
"full throughput.",
stacklevel=2,
)
return zlib.compress(data, level)
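
A quick way to sanity-check the wire-compatibility claim above (a sketch, not part of the PR; it assumes only that the `deflate` PyPI binding may or may not be installed, and mirrors the fallback logic in the hunk):

```python
import zlib

try:
    import deflate  # the libdeflate binding added to the `geotiff` extra
except ImportError:
    deflate = None

payload = b"xrspatial strip payload " * 2048

if deflate is not None:
    # deflate.zlib_compress returns a bytearray; cast to bytes as the module does.
    blob = bytes(deflate.zlib_compress(payload, 6))
else:
    blob = zlib.compress(payload, 6)  # stdlib fallback path

# Either encoding is a plain zlib stream, so stdlib decompression accepts it --
# the guarantee readers (and GDAL) rely on.
assert zlib.decompress(blob) == payload
```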


@@ -1705,7 +1713,8 @@ def decompress(data, compression: int, expected_size: int = 0,
raise ValueError(f"Unsupported compression type: {compression}")


def compress(data: bytes, compression: int, level: int = 6) -> bytes:
def compress(data: bytes, compression: int, level: int = 6,
gil_friendly: bool = False) -> bytes:
"""Compress data based on TIFF compression tag.

Parameters
@@ -1717,6 +1726,12 @@ def compress(data: bytes, compression: int, level: int = 6) -> bytes:
level : int
Compression level (deflate: 1-9, zstd: 1-22, lz4: 0-16).
Ignored for codecs without level support.
gil_friendly : bool
When True, prefer codec variants that release the GIL (used by
the parallel writer paths so a thread pool actually scales). Only
affects deflate: stdlib ``zlib`` releases the GIL, the ``deflate``
package's binding does not. All other codecs already release the
GIL and ignore this flag.

Returns
-------
@@ -1725,7 +1740,7 @@
if compression == COMPRESSION_NONE:
return data
elif compression in (COMPRESSION_DEFLATE, COMPRESSION_ADOBE_DEFLATE):
return deflate_compress(data, level)
return deflate_compress(data, level, gil_friendly=gil_friendly)
elif compression == COMPRESSION_LZW:
return lzw_compress(data)
elif compression == COMPRESSION_PACKBITS:
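A minimal usage sketch of the new keyword as it lands in this file (module path and names are taken from the diff above; calling the private `_compression` module directly is purely illustrative):

```python
import zlib

from xrspatial.geotiff._compression import deflate_compress

strip = bytes(range(256)) * 512

# Default: routes through the `deflate` binding when installed (faster per call).
seq_blob = deflate_compress(strip, level=6)

# Parallel writer paths: force stdlib zlib, which releases the GIL, so a
# thread pool actually scales across strips/tiles.
par_blob = deflate_compress(strip, level=6, gil_friendly=True)

# Both encodings are plain zlib streams and round-trip to the same payload.
assert zlib.decompress(seq_blob) == zlib.decompress(par_blob) == strip
```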
33 changes: 22 additions & 11 deletions xrspatial/geotiff/_writer.py
@@ -661,7 +661,8 @@ def _build_ifd(tags: list[tuple], overflow_base: int,

def _prepare_strip(data, i, rows_per_strip, height, width, samples, dtype,
bytes_per_sample, predictor: int, compression,
compression_level=None, max_z_error: float = 0.0):
compression_level=None, max_z_error: float = 0.0,
gil_friendly: bool = False):
"""Extract and compress a single strip. Thread-safe."""
r0 = i * rows_per_strip
r1 = min(r0 + rows_per_strip, height)
@@ -689,8 +690,9 @@ def _prepare_strip(data, i, rows_per_strip, height, width, samples, dtype,
strip_data, width, strip_rows, samples=samples, dtype=dtype,
max_z_error=max_z_error)
if compression_level is None:
return compress(strip_data, compression)
return compress(strip_data, compression, level=compression_level)
return compress(strip_data, compression, gil_friendly=gil_friendly)
return compress(strip_data, compression, level=compression_level,
gil_friendly=gil_friendly)


def _write_stripped(data: np.ndarray, compression: int, predictor: int,
@@ -745,6 +747,10 @@ def _write_stripped(data: np.ndarray, compression: int, predictor: int,
return rel_offsets, byte_counts, strips

# Parallel strip compression -- zlib/zstd/lz4/LZW all release the GIL.
# ``gil_friendly=True`` keeps deflate on stdlib zlib here: the
# ``deflate`` (libdeflate) binding holds the GIL during compress, so
# 8 threads run effectively serially through it. Sequential callers
# still get libdeflate's per-call speedup (~3x).
from concurrent.futures import ThreadPoolExecutor
import os

@@ -755,6 +761,7 @@ def _write_stripped(data: np.ndarray, compression: int, predictor: int,
data, i, rows_per_strip, height, width, samples, dtype,
bytes_per_sample, predictor, compression,
compression_level, max_z_error,
gil_friendly=True,
),
range(num_strips),
))
@@ -776,7 +783,8 @@ def _write_stripped(data: np.ndarray, compression: int, predictor: int,

def _prepare_tile(data, tr, tc, th, tw, height, width, samples, dtype,
bytes_per_sample, predictor: int, compression,
compression_level=None, max_z_error: float = 0.0):
compression_level=None, max_z_error: float = 0.0,
gil_friendly: bool = False):
"""Extract, pad, and compress a single tile. Thread-safe."""
r0 = tr * th
c0 = tc * tw
@@ -822,8 +830,9 @@ def _prepare_tile(data, tr, tc, th, tw, height, width, samples, dtype,
tile_data, tw, th, samples=samples, dtype=dtype,
max_z_error=max_z_error)
if compression_level is None:
return compress(tile_data, compression)
return compress(tile_data, compression, level=compression_level)
return compress(tile_data, compression, gil_friendly=gil_friendly)
return compress(tile_data, compression, level=compression_level,
gil_friendly=gil_friendly)


def _write_tiled(data: np.ndarray, compression: int, predictor: int,
@@ -931,7 +940,7 @@ def _write_tiled(data: np.ndarray, compression: int, predictor: int,
pool.submit(
_prepare_tile, data, tr, tc, th, tw, height, width,
samples, dtype, bytes_per_sample, predictor, compression,
compression_level, max_z_error,
compression_level, max_z_error, True,
)
for tr, tc in tile_indices
]
@@ -1589,7 +1598,7 @@ def write(data: np.ndarray, path: str, *,

def _compress_block(arr, block_w, block_h, samples, dtype, bytes_per_sample,
predictor: int, compression, compression_level=None,
max_z_error: float = 0.0):
max_z_error: float = 0.0, gil_friendly: bool = False):
"""Compress a tile or strip. *arr* must be contiguous and correctly sized."""
if compression == COMPRESSION_JPEG:
return jpeg_compress(arr.tobytes(), block_w, block_h, samples)
@@ -1612,8 +1621,9 @@ def _compress_block(arr, block_w, block_h, samples, dtype, bytes_per_sample,
samples=samples, dtype=dtype,
max_z_error=max_z_error)
if compression_level is None:
return compress(raw_data, compression)
return compress(raw_data, compression, level=compression_level)
return compress(raw_data, compression, gil_friendly=gil_friendly)
return compress(raw_data, compression, level=compression_level,
gil_friendly=gil_friendly)


# ---------------------------------------------------------------------------
@@ -2081,7 +2091,8 @@ def write_streaming(dask_data, path: str, *,
_compress_block,
ta, tw, th, samples, out_dtype,
bytes_per_sample, pred_int, comp_tag,
compression_level, max_z_error)
compression_level, max_z_error,
True)
for ta in seg_tile_arrs
]
seg_compressed = [
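The "gil_friendly" comments above rest on a thread-scaling measurement. A rough way to reproduce the effect (a sketch; the buffer size, thread count, and repetition count are assumptions, not the PR's benchmark harness):

```python
import time
import zlib
from concurrent.futures import ThreadPoolExecutor

import numpy as np

# ~1 MiB of float32 noise, similar in spirit to the buffer cited in _compression.py.
buf = np.random.default_rng(0).random(1 << 18).astype(np.float32).tobytes()


def pooled_seconds(fn, workers=8, reps=32):
    """Wall time to push `reps` copies of the buffer through a thread pool."""
    with ThreadPoolExecutor(max_workers=workers) as pool:
        start = time.perf_counter()
        list(pool.map(fn, [buf] * reps))
        return time.perf_counter() - start


# stdlib zlib releases the GIL during compression, so 8 workers overlap.
print("zlib level 6, 8 threads:", pooled_seconds(lambda b: zlib.compress(b, 6)))

try:
    import deflate  # the libdeflate binding from the new `geotiff` extra
except ImportError:
    deflate = None

if deflate is not None:
    # The binding holds the GIL while compressing, so the pool runs largely serially.
    print("deflate binding level 6, 8 threads:",
          pooled_seconds(lambda b: deflate.zlib_compress(b, 6)))
```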
12 changes: 7 additions & 5 deletions xrspatial/geotiff/tests/test_compression_level.py
@@ -68,12 +68,14 @@ def test_deflate_level_9(self, tmp_path):
class TestLevelEffect:
"""Higher compression level produces a smaller or equal file."""

def _make_compressible(self, shape=(128, 128)):
"""Smooth, highly compressible float32 array."""
rng = np.random.default_rng(42)
# Smooth gradient + small noise -- compresses well
def _make_compressible(self, shape=(512, 512)):
"""Smooth, highly compressible float32 array.

Large + smooth so the level-9 vs level-1 gap is dominated by real
compression work, not codec heuristic noise on tiny inputs.
"""
y, x = np.mgrid[0:shape[0], 0:shape[1]]
arr = (y + x).astype(np.float32) + rng.standard_normal(shape).astype(np.float32) * 0.01
arr = (y + x).astype(np.float32)
return xr.DataArray(arr, dims=['y', 'x'])

def test_zstd_higher_level_not_larger(self, tmp_path):
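To see the rationale for the larger, noise-free fixture (a sketch using stdlib zlib as a stand-in for the writer's deflate path; exact byte counts will vary by platform):

```python
import zlib

import numpy as np

# Large smooth ramp, matching the new fixture.
y, x = np.mgrid[0:512, 0:512]
smooth = (y + x).astype(np.float32).tobytes()

# Small gradient plus noise, matching the old fixture.
y2, x2 = np.mgrid[0:128, 0:128]
rng = np.random.default_rng(42)
tiny = ((y2 + x2).astype(np.float32)
        + rng.standard_normal((128, 128)).astype(np.float32) * 0.01).tobytes()

for name, payload in [("512x512 smooth", smooth), ("128x128 noisy", tiny)]:
    sizes = {lvl: len(zlib.compress(payload, lvl)) for lvl in (1, 9)}
    # The smooth ramp leaves a clear level-1 vs level-9 gap; on the tiny noisy
    # array the gap is small enough for codec heuristics to dominate.
    print(name, sizes)
```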
50 changes: 6 additions & 44 deletions xrspatial/geotiff/tests/test_parallel_writer_1800.py
@@ -1,8 +1,6 @@
"""Round-trip and threshold tests for the parallel strip/tile writer (#1800)."""
from __future__ import annotations

from concurrent.futures import ThreadPoolExecutor

import numpy as np
import pytest

Expand Down Expand Up @@ -188,12 +186,14 @@ def test_deflate_compress_zlib_wire_compatible():


def test_deflate_compress_fallback_when_libdeflate_missing(monkeypatch):
"""When libdeflate is absent we route through stdlib zlib unchanged."""
"""When the deflate package is absent we route through stdlib zlib unchanged."""
import zlib

import xrspatial.geotiff._compression as comp_mod
monkeypatch.setattr(comp_mod, '_HAVE_LIBDEFLATE', False)
monkeypatch.setattr(comp_mod, '_libdeflate', None)
monkeypatch.setattr(comp_mod, '_deflate', None)
# Pre-set the one-shot warning latch so the fallback warning does not fire mid-test.
monkeypatch.setattr(comp_mod, '_zlib_fallback_warned', True)

raw = b'1800-deflate-fallback' * 4096
blob = comp_mod.deflate_compress(raw, level=6)
@@ -204,53 +204,15 @@ def test_deflate_compress_fallback_when_libdeflate_missing(monkeypatch):


@pytest.mark.skipif(not _HAVE_LIBDEFLATE,
reason='libdeflate not installed')
reason='deflate package not installed')
def test_deflate_compress_uses_libdeflate_when_available():
"""When libdeflate is installed, deflate output stays wire-compatible."""
"""When the deflate package is installed, output stays wire-compatible."""
import zlib
raw = (np.arange(8192, dtype=np.uint8) % 251).tobytes() * 16
blob = deflate_compress(raw, level=6)
assert zlib.decompress(blob) == raw


def test_libdeflate_compressor_cache_is_thread_local():
"""The cache lives in threading.local, so two threads see distinct dicts.

Uses a ``threading.Barrier`` to force both tasks to occupy a worker
at the same time. Without that, ``ThreadPoolExecutor(max_workers=2)``
is free to run both submissions on the same thread (if the first
returns before the second is scheduled), and the test would
intermittently pass with only one observed cache id.
"""
import threading

import xrspatial.geotiff._compression as comp_mod

if not comp_mod._HAVE_LIBDEFLATE:
pytest.skip('libdeflate not installed')

seen_caches: dict[int, int] = {}
barrier = threading.Barrier(2, timeout=10)

def grab(_tag):
# Both workers must reach the barrier before either proceeds, so
# the pool is forced to use both threads.
barrier.wait()
comp_mod._libdeflate_compressor(6)
tid = threading.get_ident()
seen_caches[tid] = id(comp_mod._libdeflate_thread_local.cache)

with ThreadPoolExecutor(max_workers=2) as pool:
list(pool.map(grab, ['a', 'b']))

# Two distinct threads ran and each populated its own threading.local
# cache, so we should see two thread ids and two cache ids.
assert len(seen_caches) == 2, f'expected 2 threads, saw {len(seen_caches)}'
assert len(set(seen_caches.values())) == 2, (
f'expected 2 distinct caches, saw {seen_caches}'
)


# -- End-to-end via write() ------------------------------------------------

