From 824896b73e997c474e8816631cf514177186646f Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Wed, 29 Apr 2026 07:21:57 -0700 Subject: [PATCH] Support TIFF predictor=3 on the CPU write path (#1313) to_geotiff and the streaming writer now accept predictor as bool or int. False/0/1 means none, True/2 keeps horizontal differencing, and 3 selects the floating-point predictor. The encoder (fp_predictor_encode) already existed in _compression.py but was not reachable from the writer. Predictor 3 is rejected for non-float dtypes. Predictor=True still emits TIFF predictor 2, so existing callers are unaffected. --- xrspatial/geotiff/__init__.py | 25 ++- xrspatial/geotiff/_writer.py | 93 +++++--- .../tests/test_predictor_fp_write_1313.py | 203 ++++++++++++++++++ 3 files changed, 287 insertions(+), 34 deletions(-) create mode 100644 xrspatial/geotiff/tests/test_predictor_fp_write_1313.py diff --git a/xrspatial/geotiff/__init__.py b/xrspatial/geotiff/__init__.py index e4545db8..3c02b675 100644 --- a/xrspatial/geotiff/__init__.py +++ b/xrspatial/geotiff/__init__.py @@ -406,7 +406,7 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *, compression_level: int | None = None, tiled: bool = True, tile_size: int = 256, - predictor: bool = False, + predictor: bool | int = False, cog: bool = False, overview_levels: list[int] | None = None, overview_resampling: str = 'mean', @@ -453,8 +453,11 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *, Use tiled layout (default True). tile_size : int Tile size in pixels (default 256). - predictor : bool - Use horizontal differencing predictor. + predictor : bool or int + TIFF predictor. ``False``/``0``/``1`` -> none, ``True``/``2`` -> + horizontal differencing (good for integer data), ``3`` -> + floating-point predictor (float dtypes only; typically gives + better deflate/zstd ratios on float data than predictor 2). cog : bool Write as Cloud Optimized GeoTIFF. overview_levels : list[int] or None @@ -707,7 +710,8 @@ def _write_single_tile(chunk_data, path, geo_transform, epsg, wkt, def _write_vrt_tiled(data, vrt_path, *, crs=None, nodata=None, compression='zstd', compression_level=None, - tile_size=256, predictor=False, bigtiff=None): + tile_size=256, predictor: bool | int = False, + bigtiff=None): """Write a DataArray as a directory of tiled GeoTIFFs with a VRT index. This enables streaming dask arrays to disk without materializing the @@ -1223,7 +1227,7 @@ def write_geotiff_gpu(data, path: str, *, compression: str = 'zstd', compression_level: int | None = None, tile_size: int = 256, - predictor: bool = False, + predictor: bool | int = False, cog: bool = False, overview_levels: list[int] | None = None, overview_resampling: str = 'mean') -> None: @@ -1258,8 +1262,10 @@ def write_geotiff_gpu(data, path: str, *, currently ignored -- nvCOMP does not expose level control. tile_size : int Tile size in pixels (default 256). - predictor : bool - Apply horizontal differencing predictor. + predictor : bool or int + TIFF predictor. ``False``/``0``/``1`` -> none, ``True``/``2`` -> + horizontal differencing, ``3`` -> floating-point predictor + (float dtypes only). cog : bool Write as Cloud Optimized GeoTIFF with overviews. overview_levels : list[int] or None @@ -1278,6 +1284,7 @@ def write_geotiff_gpu(data, path: str, *, from ._gpu_decode import gpu_compress_tiles, make_overview_gpu from ._writer import ( _compression_tag, _assemble_tiff, _write_bytes, + normalize_predictor, GeoTransform as _GT, ) from ._dtypes import numpy_to_tiff_dtype @@ -1328,7 +1335,7 @@ def write_geotiff_gpu(data, path: str, *, np_dtype = np.dtype(str(arr.dtype)) # cupy dtype -> numpy dtype comp_tag = _compression_tag(compression) - pred_val = 2 if predictor else 1 + pred_val = normalize_predictor(predictor, np_dtype, comp_tag) def _gpu_compress_to_part(gpu_arr, w, h, spp): """Compress a GPU array into a (stub, w, h, offsets, counts, tiles) tuple.""" @@ -1366,7 +1373,7 @@ def _gpu_compress_to_part(gpu_arr, w, h, spp): parts.append(_gpu_compress_to_part(current, ow, oh, samples)) file_bytes = _assemble_tiff( - width, height, np_dtype, comp_tag, predictor, True, tile_size, + width, height, np_dtype, comp_tag, pred_val, True, tile_size, parts, geo_transform, epsg, nodata, is_cog=(cog and len(parts) > 1), raster_type=raster_type) diff --git a/xrspatial/geotiff/_writer.py b/xrspatial/geotiff/_writer.py index 7101106d..13311a94 100644 --- a/xrspatial/geotiff/_writer.py +++ b/xrspatial/geotiff/_writer.py @@ -17,6 +17,7 @@ COMPRESSION_PACKBITS, COMPRESSION_ZSTD, compress, + fp_predictor_encode, jpeg_compress, predictor_encode, ) @@ -67,6 +68,42 @@ BO = '<' +def normalize_predictor(predictor, dtype, compression: int) -> int: + """Normalize a user-supplied predictor value to a TIFF predictor int. + + Accepts ``False``/``True`` (legacy) and integers ``1``/``2``/``3``. + Returns ``1`` (no predictor), ``2`` (horizontal differencing), or ``3`` + (floating-point predictor). + """ + if predictor is False or predictor == 0: + return 1 + if predictor is True or predictor == 2: + return 2 + if predictor == 1: + return 1 + if predictor == 3: + if np.dtype(dtype).kind != 'f': + raise ValueError( + "predictor=3 (floating-point) requires float data, " + f"got dtype={np.dtype(dtype)}") + return 3 + raise ValueError( + f"predictor must be False/True or 1/2/3, got {predictor!r}") + + +def _apply_predictor_encode(buf: np.ndarray, predictor: int, + width: int, height: int, + bytes_per_sample: int, samples: int) -> np.ndarray: + """Apply the chosen predictor to a flat uint8 buffer.""" + if predictor == 2: + return predictor_encode(buf, width, height, + bytes_per_sample * samples) + if predictor == 3: + return fp_predictor_encode(buf, width * samples, height, + bytes_per_sample) + return buf + + def _compression_tag(compression_name: str) -> int: """Convert compression name to TIFF tag value.""" _map = { @@ -303,7 +340,7 @@ def _build_ifd(tags: list[tuple], overflow_base: int, # Strip writer # --------------------------------------------------------------------------- -def _write_stripped(data: np.ndarray, compression: int, predictor: bool, +def _write_stripped(data: np.ndarray, compression: int, predictor: int, rows_per_strip: int = 256, compression_level: int | None = None) -> tuple[list, list, list]: """Compress data as strips. @@ -333,10 +370,11 @@ def _write_stripped(data: np.ndarray, compression: int, predictor: bool, if compression == COMPRESSION_JPEG: strip_data = np.ascontiguousarray(data[r0:r1]).tobytes() compressed = jpeg_compress(strip_data, width, strip_rows, samples) - elif predictor and compression != COMPRESSION_NONE: + elif predictor != 1 and compression != COMPRESSION_NONE: strip_arr = np.ascontiguousarray(data[r0:r1]) buf = strip_arr.view(np.uint8).ravel().copy() - buf = predictor_encode(buf, width, strip_rows, bytes_per_sample * samples) + buf = _apply_predictor_encode( + buf, predictor, width, strip_rows, bytes_per_sample, samples) strip_data = buf.tobytes() if compression_level is None: compressed = compress(strip_data, compression) @@ -371,7 +409,7 @@ def _write_stripped(data: np.ndarray, compression: int, predictor: bool, # --------------------------------------------------------------------------- def _prepare_tile(data, tr, tc, th, tw, height, width, samples, dtype, - bytes_per_sample, predictor, compression, + bytes_per_sample, predictor: int, compression, compression_level=None): """Extract, pad, and compress a single tile. Thread-safe.""" r0 = tr * th @@ -400,9 +438,10 @@ def _prepare_tile(data, tr, tc, th, tw, height, width, samples, dtype, if compression == COMPRESSION_JPEG: tile_data = tile_arr.tobytes() return jpeg_compress(tile_data, tw, th, samples) - elif predictor and compression != COMPRESSION_NONE: + elif predictor != 1 and compression != COMPRESSION_NONE: buf = tile_arr.view(np.uint8).ravel().copy() - buf = predictor_encode(buf, tw, th, bytes_per_sample * samples) + buf = _apply_predictor_encode( + buf, predictor, tw, th, bytes_per_sample, samples) tile_data = buf.tobytes() else: tile_data = tile_arr.tobytes() @@ -420,7 +459,7 @@ def _prepare_tile(data, tr, tc, th, tw, height, width, samples, dtype, return compress(tile_data, compression, level=compression_level) -def _write_tiled(data: np.ndarray, compression: int, predictor: bool, +def _write_tiled(data: np.ndarray, compression: int, predictor: int, tile_size: int = 256, compression_level: int | None = None) -> tuple[list, list, list]: """Compress data as tiles, using parallel compression. @@ -538,7 +577,7 @@ def _write_tiled(data: np.ndarray, compression: int, predictor: bool, # --------------------------------------------------------------------------- def _assemble_tiff(width: int, height: int, dtype: np.dtype, - compression: int, predictor: bool, + compression: int, predictor: int, tiled: bool, tile_size: int, pixel_data_parts: list[tuple], geo_transform: GeoTransform | None, @@ -594,7 +633,7 @@ def _assemble_tiff(width: int, height: int, dtype: np.dtype, geo_tags_dict.pop(TAG_MODEL_TIEPOINT, None) # Compression tag for predictor - pred_val = 2 if (predictor and compression != COMPRESSION_NONE) else 1 + pred_val = predictor if compression != COMPRESSION_NONE else 1 # Build IFDs for each resolution level ifd_specs = [] @@ -916,7 +955,7 @@ def write(data: np.ndarray, path: str, *, compression_level: int | None = None, tiled: bool = True, tile_size: int = 256, - predictor: bool = False, + predictor: bool | int = False, cog: bool = False, overview_levels: list[int] | None = None, overview_resampling: str = 'mean', @@ -947,8 +986,10 @@ def write(data: np.ndarray, path: str, *, Use tiled layout (vs strips). tile_size : int Tile width and height. - predictor : bool - Use horizontal differencing predictor. + predictor : bool or int + TIFF predictor. ``False``/``0``/``1`` -> none, ``True``/``2`` -> + horizontal differencing, ``3`` -> floating-point predictor + (float dtypes only). cog : bool Write as Cloud Optimized GeoTIFF. overview_levels : list of int or None @@ -956,6 +997,7 @@ def write(data: np.ndarray, path: str, *, Only used if cog=True. If None and cog=True, auto-generate. """ comp_tag = _compression_tag(compression) + pred_int = normalize_predictor(predictor, data.dtype, comp_tag) # JPEG validation: only uint8, 1 or 3 bands if comp_tag == COMPRESSION_JPEG: @@ -973,10 +1015,10 @@ def write(data: np.ndarray, path: str, *, # Full resolution if tiled: - rel_off, bc, comp_data = _write_tiled(data, comp_tag, predictor, tile_size, + rel_off, bc, comp_data = _write_tiled(data, comp_tag, pred_int, tile_size, compression_level=compression_level) else: - rel_off, bc, comp_data = _write_stripped(data, comp_tag, predictor, + rel_off, bc, comp_data = _write_stripped(data, comp_tag, pred_int, compression_level=compression_level) h, w = data.shape[:2] @@ -999,16 +1041,16 @@ def write(data: np.ndarray, path: str, *, current = _make_overview(current, method=overview_resampling) oh, ow = current.shape[:2] if tiled: - o_off, o_bc, o_data = _write_tiled(current, comp_tag, predictor, + o_off, o_bc, o_data = _write_tiled(current, comp_tag, pred_int, tile_size, compression_level=compression_level) else: - o_off, o_bc, o_data = _write_stripped(current, comp_tag, predictor, + o_off, o_bc, o_data = _write_stripped(current, comp_tag, pred_int, compression_level=compression_level) parts.append((current, ow, oh, o_off, o_bc, o_data)) file_bytes = _assemble_tiff( - w, h, data.dtype, comp_tag, predictor, tiled, tile_size, + w, h, data.dtype, comp_tag, pred_int, tiled, tile_size, parts, geo_transform, crs_epsg, nodata, is_cog=cog, raster_type=raster_type, crs_wkt=crs_wkt, gdal_metadata_xml=gdal_metadata_xml, @@ -1030,15 +1072,15 @@ def write(data: np.ndarray, path: str, *, def _compress_block(arr, block_w, block_h, samples, dtype, bytes_per_sample, - predictor, compression, compression_level=None): + predictor: int, compression, compression_level=None): """Compress a tile or strip. *arr* must be contiguous and correctly sized.""" if compression == COMPRESSION_JPEG: return jpeg_compress(arr.tobytes(), block_w, block_h, samples) - if predictor and compression != COMPRESSION_NONE: + if predictor != 1 and compression != COMPRESSION_NONE: buf = arr.view(np.uint8).ravel().copy() - buf = predictor_encode(buf, block_w, block_h, - bytes_per_sample * samples) + buf = _apply_predictor_encode( + buf, predictor, block_w, block_h, bytes_per_sample, samples) raw_data = buf.tobytes() else: raw_data = arr.tobytes() @@ -1069,7 +1111,7 @@ def write_streaming(dask_data, path: str, *, compression_level: int | None = None, tiled: bool = True, tile_size: int = 256, - predictor: bool = False, + predictor: bool | int = False, raster_type: int = 1, x_resolution: float | None = None, y_resolution: float | None = None, @@ -1109,6 +1151,7 @@ def write_streaming(dask_data, path: str, *, bits_per_sample, sample_format = numpy_to_tiff_dtype(out_dtype) bytes_per_sample = out_dtype.itemsize comp_tag = _compression_tag(compression) + pred_int = normalize_predictor(predictor, out_dtype, comp_tag) if comp_tag == COMPRESSION_JPEG: if out_dtype != np.uint8: @@ -1166,7 +1209,7 @@ def write_streaming(dask_data, path: str, *, extra_vals = [0] * n_extra tags.append((TAG_EXTRA_SAMPLES, SHORT, n_extra, extra_vals)) - pred_val = 2 if (predictor and comp_tag != COMPRESSION_NONE) else 1 + pred_val = pred_int if comp_tag != COMPRESSION_NONE else 1 if pred_val != 1: tags.append((TAG_PREDICTOR, SHORT, 1, pred_val)) @@ -1313,7 +1356,7 @@ def write_streaming(dask_data, path: str, *, compressed = _compress_block( tile_arr, tw, th, samples, out_dtype, - bytes_per_sample, predictor, comp_tag, + bytes_per_sample, pred_int, comp_tag, compression_level) actual_offsets.append(current_offset) @@ -1350,7 +1393,7 @@ def write_streaming(dask_data, path: str, *, compressed = _compress_block( np.ascontiguousarray(strip_np), width, strip_rows, samples, out_dtype, - bytes_per_sample, predictor, comp_tag, + bytes_per_sample, pred_int, comp_tag, compression_level) actual_offsets.append(current_offset) diff --git a/xrspatial/geotiff/tests/test_predictor_fp_write_1313.py b/xrspatial/geotiff/tests/test_predictor_fp_write_1313.py new file mode 100644 index 00000000..d168a209 --- /dev/null +++ b/xrspatial/geotiff/tests/test_predictor_fp_write_1313.py @@ -0,0 +1,203 @@ +"""Tests for floating-point predictor (predictor=3) write path. + +Issue #1313: ``to_geotiff`` previously accepted ``predictor: bool`` and +emitted only TIFF predictor 2 (horizontal differencing). Predictor 3 +(byte-swizzled differencing per TIFF Technical Note 3) gives noticeably +better deflate/zstd ratios on float data and is what most GDAL/rasterio +workflows use for elevation rasters. These tests cover the new write +path end-to-end. +""" +from __future__ import annotations + +import struct + +import numpy as np +import pytest +import xarray as xr + +from xrspatial.geotiff import open_geotiff, to_geotiff +from xrspatial.geotiff._writer import normalize_predictor + + +def _smooth_float(shape, dtype): + """Smooth surface where FP predictor is expected to help compression.""" + y, x = np.mgrid[0:shape[0], 0:shape[1]].astype(dtype) + return (np.sin(x / 40) * np.cos(y / 40) * 100 + (x + y) / 4).astype(dtype) + + +def _da(arr): + h, w = arr.shape[:2] + coords = { + 'x': np.arange(w, dtype=np.float64) * 10.0, + 'y': np.arange(h, dtype=np.float64) * 10.0, + } + if arr.ndim == 2: + return xr.DataArray(arr, dims=('y', 'x'), coords=coords) + # 3D: (band, y, x) + return xr.DataArray(arr, dims=('band', 'y', 'x'), coords=coords) + + +def _read_predictor_tag(path): + """Read the TIFF Predictor tag (id=317) directly from a file's IFD.""" + with open(path, 'rb') as f: + header = f.read(8) + assert header[:2] == b'II', "test fixture writes little-endian" + magic = struct.unpack(' predictor 1 (none) + + +@pytest.mark.parametrize('dtype', [np.float32, np.float64]) +@pytest.mark.parametrize('compression', ['deflate', 'zstd']) +@pytest.mark.parametrize('tiled', [True, False]) +def test_predictor3_round_trip(tmp_path, dtype, compression, tiled): + arr = _smooth_float((96, 128), dtype) + da = _da(arr) + path = tmp_path / f'fp_pred_1313_{np.dtype(dtype).name}_{compression}.tif' + to_geotiff(da, str(path), compression=compression, predictor=3, + tiled=tiled) + + out = open_geotiff(str(path)) + np.testing.assert_array_equal(out.values, arr) + + # File should advertise predictor=3 in the TIFF tag + assert _read_predictor_tag(str(path)) == 3 + + +def test_predictor3_better_than_predictor2_on_smooth_floats(tmp_path): + arr = _smooth_float((512, 512), np.float32) + da = _da(arr) + p2 = tmp_path / 'pred2_1313.tif' + p3 = tmp_path / 'pred3_1313.tif' + to_geotiff(da, str(p2), compression='deflate', predictor=2) + to_geotiff(da, str(p3), compression='deflate', predictor=3) + + # FP predictor exists precisely because it compresses smooth floats + # better than horizontal differencing. If this regresses, something + # is wrong with the encoder wiring. + assert p3.stat().st_size < p2.stat().st_size + + +def test_predictor3_rejects_integer_dtype(tmp_path): + arr = np.zeros((8, 8), dtype=np.int32) + da = _da(arr) + path = tmp_path / 'bad_int_1313.tif' + with pytest.raises(ValueError, match='predictor=3'): + to_geotiff(da, str(path), compression='deflate', predictor=3) + + +def test_predictor_legacy_bool_unchanged(tmp_path): + """``predictor=True`` keeps emitting TIFF predictor 2.""" + arr = _smooth_float((32, 32), np.float32) + da = _da(arr) + path = tmp_path / 'legacy_true_1313.tif' + to_geotiff(da, str(path), compression='deflate', predictor=True) + assert _read_predictor_tag(str(path)) == 2 + + out = open_geotiff(str(path)) + np.testing.assert_array_equal(out.values, arr) + + +def test_predictor_false_emits_no_tag(tmp_path): + arr = _smooth_float((32, 32), np.float32) + da = _da(arr) + path = tmp_path / 'legacy_false_1313.tif' + to_geotiff(da, str(path), compression='deflate', predictor=False) + # Tag absent OR set to 1 -> both mean "no predictor" + tag = _read_predictor_tag(str(path)) + assert tag in (None, 1) + + +def test_predictor3_with_compression_none_is_silent(tmp_path): + """compression='none' suppresses any predictor (matches predictor=2 behavior).""" + arr = _smooth_float((16, 16), np.float32) + da = _da(arr) + path = tmp_path / 'pred3_nocomp_1313.tif' + to_geotiff(da, str(path), compression='none', predictor=3) + + tag = _read_predictor_tag(str(path)) + assert tag in (None, 1), \ + "predictor must be suppressed when compression=none" + + out = open_geotiff(str(path)) + np.testing.assert_array_equal(out.values, arr) + + +def test_normalize_predictor_table(): + """The bool|int -> int normalization mapping.""" + f32 = np.dtype('float32') + i32 = np.dtype('int32') + deflate = 8 # COMPRESSION_DEFLATE + + assert normalize_predictor(False, f32, deflate) == 1 + assert normalize_predictor(0, f32, deflate) == 1 + assert normalize_predictor(1, f32, deflate) == 1 + assert normalize_predictor(True, f32, deflate) == 2 + assert normalize_predictor(2, f32, deflate) == 2 + assert normalize_predictor(3, f32, deflate) == 3 + + with pytest.raises(ValueError, match='predictor=3'): + normalize_predictor(3, i32, deflate) + + with pytest.raises(ValueError, match='predictor must be'): + normalize_predictor(99, f32, deflate) + + +def test_predictor3_streaming_dask(tmp_path): + """Dask-backed input takes the streaming path; predictor=3 must work.""" + da_module = pytest.importorskip('dask.array') + arr = _smooth_float((128, 192), np.float32) + dask_arr = da_module.from_array(arr, chunks=(64, 96)) + da = xr.DataArray( + dask_arr, dims=('y', 'x'), + coords={'x': np.arange(192, dtype=np.float64) * 10.0, + 'y': np.arange(128, dtype=np.float64) * 10.0}, + ) + path = tmp_path / 'pred3_streaming_1313.tif' + to_geotiff(da, str(path), compression='deflate', predictor=3, + tile_size=64) + + assert _read_predictor_tag(str(path)) == 3 + out = open_geotiff(str(path)) + np.testing.assert_array_equal(out.values, arr) + + +def test_predictor3_multiband_round_trip(tmp_path): + """Multi-band float predictor=3 round-trip. + + Issue #1247 fixed the read side; this checks the write side now + round-trips correctly for the multi-band case where the row swizzle + has to use ``width * samples`` lanes, not ``width``. + """ + h, w = 48, 64 + arr = np.stack([ + _smooth_float((h, w), np.float32), + _smooth_float((h, w), np.float32) * 1.5, + _smooth_float((h, w), np.float32) - 10.0, + ], axis=0) # (3, h, w) + da = xr.DataArray( + arr, dims=('band', 'y', 'x'), + coords={'x': np.arange(w, dtype=np.float64) * 10.0, + 'y': np.arange(h, dtype=np.float64) * 10.0}, + ) + path = tmp_path / 'pred3_multiband_1313.tif' + to_geotiff(da, str(path), compression='deflate', predictor=3) + + assert _read_predictor_tag(str(path)) == 3 + out = open_geotiff(str(path)) + # open_geotiff returns (y, x, band) for multi-band; reorder for compare + if out.ndim == 3 and out.shape[-1] == 3: + out_arr = np.moveaxis(out.values, -1, 0) + else: + out_arr = out.values + np.testing.assert_array_equal(out_arr, arr)