Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion xrspatial/geotiff/_compression.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,48 @@
"""Compression codecs: deflate (zlib) and LZW (Numba), plus horizontal predictor."""
from __future__ import annotations

import threading
import zlib

import numpy as np

from xrspatial.utils import ngjit

# -- Optional libdeflate backend --------------------------------------------
#
# When the ``libdeflate`` package is installed, ``deflate_compress`` routes
# through it: libdeflate is typically 1.5-2x faster than ``zlib`` at the
# same compression level, and GDAL >= 3.7 already uses it when available
# so our installs match throughput. Output is wire-compatible (zlib
# format), so encoded streams round-trip through stdlib ``zlib.decompress``
# unchanged.
#
# libdeflate's ``Compressor`` objects are not thread-safe, so we keep one
# per (thread, level) pair via ``threading.local``. The writer drives
# compression from a ``ThreadPoolExecutor``; per-thread caching avoids
# allocating a fresh compressor per strip/tile.
try: # pragma: no cover - exercised only when libdeflate is installed
import libdeflate as _libdeflate
_HAVE_LIBDEFLATE = True
except ImportError:
_libdeflate = None
_HAVE_LIBDEFLATE = False

_libdeflate_thread_local = threading.local()


def _libdeflate_compressor(level: int):
    """Return a thread-local libdeflate ``Compressor`` for *level*.

    libdeflate ``Compressor`` objects are not thread-safe, so each thread
    keeps its own ``{level: Compressor}`` cache on
    ``_libdeflate_thread_local``; a compressor is built at most once per
    (thread, level) pair.
    """
    try:
        cache = _libdeflate_thread_local.cache
    except AttributeError:
        # First call on this thread: lazily create the per-thread cache.
        cache = _libdeflate_thread_local.cache = {}
    try:
        return cache[level]
    except KeyError:
        compressor = cache[level] = _libdeflate.Compressor(level)
        return compressor

# -- Decompression-bomb defenses ---------------------------------------------
#
# A malicious TIFF can declare a small strip/tile compressed payload that
Expand Down Expand Up @@ -98,7 +134,14 @@ def deflate_decompress(data: bytes, expected_size: int = 0) -> bytes:


def deflate_compress(data: bytes, level: int = 6) -> bytes:
    """Compress data with deflate/zlib.

    Uses ``libdeflate`` when installed (1.5-2x faster than ``zlib``) and
    falls back to ``zlib.compress`` otherwise. Output is wire-compatible
    either way: the stdlib ``zlib.decompress`` accepts both.
    """
    if not _HAVE_LIBDEFLATE:
        return zlib.compress(data, level)
    compressor = _libdeflate_compressor(level)
    return compressor.compress(data, _libdeflate.Format.ZLIB)


Expand Down
146 changes: 102 additions & 44 deletions xrspatial/geotiff/_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,14 @@ def _compression_tag(compression_name: str) -> int:
#: override.
_MAX_OVERVIEW_LEVELS = 8

#: Total uncompressed payload (bytes) below which the strip and tile
#: writers stay sequential. The thread-pool startup cost dominates on
#: small rasters; above this size the per-block compression cost more
#: than pays for it. 4 MiB was chosen empirically on a 20-core box:
#: parallel becomes a net win around ~2 MiB, and the 4 MiB margin keeps
#: a few-tile / two-strip layout from incurring a slowdown.
_PARALLEL_MIN_BYTES = 4 * 1024 * 1024


def _validate_overview_levels(overview_levels, height=None, width=None):
"""Validate and normalise an explicit ``overview_levels`` list.
Expand Down Expand Up @@ -651,12 +659,50 @@ def _build_ifd(tags: list[tuple], overflow_base: int,
# Strip writer
# ---------------------------------------------------------------------------

def _prepare_strip(data, i, rows_per_strip, height, width, samples, dtype,
                   bytes_per_sample, predictor: int, compression,
                   compression_level=None, max_z_error: float = 0.0):
    """Extract strip *i* from *data* and return its compressed bytes.

    A pure function of its arguments (no shared mutable state), so it is
    safe to call concurrently from worker threads.
    """
    row_start = i * rows_per_strip
    # The last strip may be short when height is not a multiple of
    # rows_per_strip.
    row_stop = min(row_start + rows_per_strip, height)
    n_rows = row_stop - row_start
    block = np.ascontiguousarray(data[row_start:row_stop])

    # JPEG consumes the raw pixel bytes directly -- no predictor stage.
    if compression == COMPRESSION_JPEG:
        return jpeg_compress(block.tobytes(), width, n_rows, samples)

    # Horizontal/float predictor applies only ahead of real compression.
    if predictor != 1 and compression != COMPRESSION_NONE:
        raw = block.view(np.uint8).ravel().copy()
        raw = _apply_predictor_encode(
            raw, predictor, width, n_rows, bytes_per_sample, samples)
        payload = raw.tobytes()
    else:
        payload = block.tobytes()

    if compression == COMPRESSION_JPEG2000:
        from ._compression import jpeg2000_compress
        return jpeg2000_compress(
            payload, width, n_rows, samples=samples, dtype=dtype)
    if compression == COMPRESSION_LERC:
        from ._compression import lerc_compress
        return lerc_compress(
            payload, width, n_rows, samples=samples, dtype=dtype,
            max_z_error=max_z_error)
    if compression_level is None:
        return compress(payload, compression)
    return compress(payload, compression, level=compression_level)


def _write_stripped(data: np.ndarray, compression: int, predictor: int,
rows_per_strip: int = 256,
compression_level: int | None = None,
max_z_error: float = 0.0) -> tuple[list, list, list]:
"""Compress data as strips.

For compressed formats (deflate, lzw, zstd, lz4, ...) strips are
compressed in parallel using a thread pool: zlib, zstandard, lz4,
and the Numba LZW kernel all release the GIL during compression.

Returns
-------
(offsets_placeholder, byte_counts, compressed_chunks)
Expand All @@ -668,53 +714,60 @@ def _write_stripped(data: np.ndarray, compression: int, predictor: int,
dtype = data.dtype
bytes_per_sample = dtype.itemsize

strips = []
num_strips = math.ceil(height / rows_per_strip)

total_bytes = int(data.nbytes)

# Sequential path: uncompressed, few strips, or small payload. The
# threshold mirrors the tile writer so we don't pay thread-pool
# overhead on tiny rasters.
use_parallel = (
compression != COMPRESSION_NONE
and num_strips > 2
and total_bytes > _PARALLEL_MIN_BYTES
)

if not use_parallel:
strips = []
rel_offsets = []
byte_counts = []
current_offset = 0
for i in range(num_strips):
compressed = _prepare_strip(
data, i, rows_per_strip, height, width, samples, dtype,
bytes_per_sample, predictor, compression,
compression_level, max_z_error,
)
rel_offsets.append(current_offset)
byte_counts.append(len(compressed))
strips.append(compressed)
current_offset += len(compressed)
return rel_offsets, byte_counts, strips

# Parallel strip compression -- zlib/zstd/lz4/LZW all release the GIL.
from concurrent.futures import ThreadPoolExecutor
import os

n_workers = min(num_strips, os.cpu_count() or 4)
with ThreadPoolExecutor(max_workers=n_workers) as pool:
compressed_strips = list(pool.map(
lambda i: _prepare_strip(
data, i, rows_per_strip, height, width, samples, dtype,
bytes_per_sample, predictor, compression,
compression_level, max_z_error,
),
range(num_strips),
))

rel_offsets = []
byte_counts = []
current_offset = 0

num_strips = math.ceil(height / rows_per_strip)
for i in range(num_strips):
r0 = i * rows_per_strip
r1 = min(r0 + rows_per_strip, height)
strip_rows = r1 - r0

if compression == COMPRESSION_JPEG:
strip_data = np.ascontiguousarray(data[r0:r1]).tobytes()
compressed = jpeg_compress(strip_data, width, strip_rows, samples)
elif predictor != 1 and compression != COMPRESSION_NONE:
strip_arr = np.ascontiguousarray(data[r0:r1])
buf = strip_arr.view(np.uint8).ravel().copy()
buf = _apply_predictor_encode(
buf, predictor, width, strip_rows, bytes_per_sample, samples)
strip_data = buf.tobytes()
if compression_level is None:
compressed = compress(strip_data, compression)
else:
compressed = compress(strip_data, compression, level=compression_level)
else:
strip_data = np.ascontiguousarray(data[r0:r1]).tobytes()

if compression == COMPRESSION_JPEG2000:
from ._compression import jpeg2000_compress
compressed = jpeg2000_compress(
strip_data, width, strip_rows, samples=samples, dtype=dtype)
elif compression == COMPRESSION_LERC:
from ._compression import lerc_compress
compressed = lerc_compress(
strip_data, width, strip_rows, samples=samples, dtype=dtype,
max_z_error=max_z_error)
elif compression_level is None:
compressed = compress(strip_data, compression)
else:
compressed = compress(strip_data, compression, level=compression_level)

for cs in compressed_strips:
rel_offsets.append(current_offset)
byte_counts.append(len(compressed))
strips.append(compressed)
current_offset += len(compressed)
byte_counts.append(len(cs))
current_offset += len(cs)

return rel_offsets, byte_counts, strips
return rel_offsets, byte_counts, compressed_strips


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -841,8 +894,13 @@ def _write_tiled(data: np.ndarray, compression: int, predictor: int,

return rel_offsets, byte_counts, tiles

if n_tiles <= 4:
# Very few tiles: sequential (thread pool overhead not worth it)
# Sequential path: very few tiles, or small total payload. A previous
# ``n_tiles <= 4`` cutoff sent ``tile_size=1024`` writes on a 2048x2048
# image down the serial path (n_tiles=4) and made them ~8x slower than
# the parallel path. Switching to a bytes-based threshold lets
# large-tile writes parallelize while still skipping the pool on
# small rasters where its setup cost dominates.
if n_tiles <= 2 or int(data.nbytes) <= _PARALLEL_MIN_BYTES:
tiles = []
rel_offsets = []
byte_counts = []
Expand Down
Loading
Loading