Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions xrspatial/geotiff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
compression_level: int | None = None,
tiled: bool = True,
tile_size: int = 256,
predictor: bool = False,
predictor: bool | int = False,
cog: bool = False,
overview_levels: list[int] | None = None,
overview_resampling: str = 'mean',
Expand Down Expand Up @@ -453,8 +453,11 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
Use tiled layout (default True).
tile_size : int
Tile size in pixels (default 256).
predictor : bool
Use horizontal differencing predictor.
predictor : bool or int
TIFF predictor. ``False``/``0``/``1`` -> none, ``True``/``2`` ->
horizontal differencing (good for integer data), ``3`` ->
floating-point predictor (float dtypes only; typically gives
better deflate/zstd ratios on float data than predictor 2).
cog : bool
Write as Cloud Optimized GeoTIFF.
overview_levels : list[int] or None
Expand Down Expand Up @@ -707,7 +710,8 @@ def _write_single_tile(chunk_data, path, geo_transform, epsg, wkt,

def _write_vrt_tiled(data, vrt_path, *, crs=None, nodata=None,
compression='zstd', compression_level=None,
tile_size=256, predictor=False, bigtiff=None):
tile_size=256, predictor: bool | int = False,
bigtiff=None):
"""Write a DataArray as a directory of tiled GeoTIFFs with a VRT index.

This enables streaming dask arrays to disk without materializing the
Expand Down Expand Up @@ -1223,7 +1227,7 @@ def write_geotiff_gpu(data, path: str, *,
compression: str = 'zstd',
compression_level: int | None = None,
tile_size: int = 256,
predictor: bool = False,
predictor: bool | int = False,
cog: bool = False,
overview_levels: list[int] | None = None,
overview_resampling: str = 'mean') -> None:
Expand Down Expand Up @@ -1258,8 +1262,10 @@ def write_geotiff_gpu(data, path: str, *,
currently ignored -- nvCOMP does not expose level control.
tile_size : int
Tile size in pixels (default 256).
predictor : bool
Apply horizontal differencing predictor.
predictor : bool or int
TIFF predictor. ``False``/``0``/``1`` -> none, ``True``/``2`` ->
horizontal differencing, ``3`` -> floating-point predictor
(float dtypes only).
cog : bool
Write as Cloud Optimized GeoTIFF with overviews.
overview_levels : list[int] or None
Expand All @@ -1278,6 +1284,7 @@ def write_geotiff_gpu(data, path: str, *,
from ._gpu_decode import gpu_compress_tiles, make_overview_gpu
from ._writer import (
_compression_tag, _assemble_tiff, _write_bytes,
normalize_predictor,
GeoTransform as _GT,
)
from ._dtypes import numpy_to_tiff_dtype
Expand Down Expand Up @@ -1328,7 +1335,7 @@ def write_geotiff_gpu(data, path: str, *,
np_dtype = np.dtype(str(arr.dtype)) # cupy dtype -> numpy dtype

comp_tag = _compression_tag(compression)
pred_val = 2 if predictor else 1
pred_val = normalize_predictor(predictor, np_dtype, comp_tag)

def _gpu_compress_to_part(gpu_arr, w, h, spp):
"""Compress a GPU array into a (stub, w, h, offsets, counts, tiles) tuple."""
Expand Down Expand Up @@ -1366,7 +1373,7 @@ def _gpu_compress_to_part(gpu_arr, w, h, spp):
parts.append(_gpu_compress_to_part(current, ow, oh, samples))

file_bytes = _assemble_tiff(
width, height, np_dtype, comp_tag, predictor, True, tile_size,
width, height, np_dtype, comp_tag, pred_val, True, tile_size,
parts, geo_transform, epsg, nodata,
is_cog=(cog and len(parts) > 1),
raster_type=raster_type)
Expand Down
93 changes: 68 additions & 25 deletions xrspatial/geotiff/_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
COMPRESSION_PACKBITS,
COMPRESSION_ZSTD,
compress,
fp_predictor_encode,
jpeg_compress,
predictor_encode,
)
Expand Down Expand Up @@ -67,6 +68,42 @@
BO = '<'


def normalize_predictor(predictor, dtype, compression: int) -> int:
    """Normalize a user-supplied predictor value to a TIFF predictor int.

    Accepts ``False``/``True`` (legacy) and integers ``1``/``2``/``3``.
    Returns ``1`` (no predictor), ``2`` (horizontal differencing), or ``3``
    (floating-point predictor).
    """
    # Identity checks come first: ``True == 1`` and ``False == 0`` in
    # Python, so the legacy bool spellings must be resolved before the
    # integer comparisons below.
    if predictor is True:
        return 2
    if predictor is False:
        return 1
    if predictor in (0, 1):
        return 1
    if predictor == 2:
        return 2
    if predictor == 3:
        # The floating-point predictor (TIFF Technical Note 3) only makes
        # sense on IEEE float samples; reject other dtypes loudly.
        if np.dtype(dtype).kind != 'f':
            raise ValueError(
                "predictor=3 (floating-point) requires float data, "
                f"got dtype={np.dtype(dtype)}")
        return 3
    raise ValueError(
        f"predictor must be False/True or 1/2/3, got {predictor!r}")


def _apply_predictor_encode(buf: np.ndarray, predictor: int,
width: int, height: int,
bytes_per_sample: int, samples: int) -> np.ndarray:
"""Apply the chosen predictor to a flat uint8 buffer."""
if predictor == 2:
return predictor_encode(buf, width, height,
bytes_per_sample * samples)
if predictor == 3:
return fp_predictor_encode(buf, width * samples, height,
bytes_per_sample)
return buf


def _compression_tag(compression_name: str) -> int:
"""Convert compression name to TIFF tag value."""
_map = {
Expand Down Expand Up @@ -303,7 +340,7 @@ def _build_ifd(tags: list[tuple], overflow_base: int,
# Strip writer
# ---------------------------------------------------------------------------

def _write_stripped(data: np.ndarray, compression: int, predictor: bool,
def _write_stripped(data: np.ndarray, compression: int, predictor: int,
rows_per_strip: int = 256,
compression_level: int | None = None) -> tuple[list, list, list]:
"""Compress data as strips.
Expand Down Expand Up @@ -333,10 +370,11 @@ def _write_stripped(data: np.ndarray, compression: int, predictor: bool,
if compression == COMPRESSION_JPEG:
strip_data = np.ascontiguousarray(data[r0:r1]).tobytes()
compressed = jpeg_compress(strip_data, width, strip_rows, samples)
elif predictor and compression != COMPRESSION_NONE:
elif predictor != 1 and compression != COMPRESSION_NONE:
strip_arr = np.ascontiguousarray(data[r0:r1])
buf = strip_arr.view(np.uint8).ravel().copy()
buf = predictor_encode(buf, width, strip_rows, bytes_per_sample * samples)
buf = _apply_predictor_encode(
buf, predictor, width, strip_rows, bytes_per_sample, samples)
strip_data = buf.tobytes()
if compression_level is None:
compressed = compress(strip_data, compression)
Expand Down Expand Up @@ -371,7 +409,7 @@ def _write_stripped(data: np.ndarray, compression: int, predictor: bool,
# ---------------------------------------------------------------------------

def _prepare_tile(data, tr, tc, th, tw, height, width, samples, dtype,
bytes_per_sample, predictor, compression,
bytes_per_sample, predictor: int, compression,
compression_level=None):
"""Extract, pad, and compress a single tile. Thread-safe."""
r0 = tr * th
Expand Down Expand Up @@ -400,9 +438,10 @@ def _prepare_tile(data, tr, tc, th, tw, height, width, samples, dtype,
if compression == COMPRESSION_JPEG:
tile_data = tile_arr.tobytes()
return jpeg_compress(tile_data, tw, th, samples)
elif predictor and compression != COMPRESSION_NONE:
elif predictor != 1 and compression != COMPRESSION_NONE:
buf = tile_arr.view(np.uint8).ravel().copy()
buf = predictor_encode(buf, tw, th, bytes_per_sample * samples)
buf = _apply_predictor_encode(
buf, predictor, tw, th, bytes_per_sample, samples)
tile_data = buf.tobytes()
else:
tile_data = tile_arr.tobytes()
Expand All @@ -420,7 +459,7 @@ def _prepare_tile(data, tr, tc, th, tw, height, width, samples, dtype,
return compress(tile_data, compression, level=compression_level)


def _write_tiled(data: np.ndarray, compression: int, predictor: bool,
def _write_tiled(data: np.ndarray, compression: int, predictor: int,
tile_size: int = 256,
compression_level: int | None = None) -> tuple[list, list, list]:
"""Compress data as tiles, using parallel compression.
Expand Down Expand Up @@ -538,7 +577,7 @@ def _write_tiled(data: np.ndarray, compression: int, predictor: bool,
# ---------------------------------------------------------------------------

def _assemble_tiff(width: int, height: int, dtype: np.dtype,
compression: int, predictor: bool,
compression: int, predictor: int,
tiled: bool, tile_size: int,
pixel_data_parts: list[tuple],
geo_transform: GeoTransform | None,
Expand Down Expand Up @@ -594,7 +633,7 @@ def _assemble_tiff(width: int, height: int, dtype: np.dtype,
geo_tags_dict.pop(TAG_MODEL_TIEPOINT, None)

# Compression tag for predictor
pred_val = 2 if (predictor and compression != COMPRESSION_NONE) else 1
pred_val = predictor if compression != COMPRESSION_NONE else 1

# Build IFDs for each resolution level
ifd_specs = []
Expand Down Expand Up @@ -916,7 +955,7 @@ def write(data: np.ndarray, path: str, *,
compression_level: int | None = None,
tiled: bool = True,
tile_size: int = 256,
predictor: bool = False,
predictor: bool | int = False,
cog: bool = False,
overview_levels: list[int] | None = None,
overview_resampling: str = 'mean',
Expand Down Expand Up @@ -947,15 +986,18 @@ def write(data: np.ndarray, path: str, *,
Use tiled layout (vs strips).
tile_size : int
Tile width and height.
predictor : bool
Use horizontal differencing predictor.
predictor : bool or int
TIFF predictor. ``False``/``0``/``1`` -> none, ``True``/``2`` ->
horizontal differencing, ``3`` -> floating-point predictor
(float dtypes only).
cog : bool
Write as Cloud Optimized GeoTIFF.
overview_levels : list of int or None
Overview decimation factors (e.g. [2, 4, 8]).
Only used if cog=True. If None and cog=True, auto-generate.
"""
comp_tag = _compression_tag(compression)
pred_int = normalize_predictor(predictor, data.dtype, comp_tag)

# JPEG validation: only uint8, 1 or 3 bands
if comp_tag == COMPRESSION_JPEG:
Expand All @@ -973,10 +1015,10 @@ def write(data: np.ndarray, path: str, *,

# Full resolution
if tiled:
rel_off, bc, comp_data = _write_tiled(data, comp_tag, predictor, tile_size,
rel_off, bc, comp_data = _write_tiled(data, comp_tag, pred_int, tile_size,
compression_level=compression_level)
else:
rel_off, bc, comp_data = _write_stripped(data, comp_tag, predictor,
rel_off, bc, comp_data = _write_stripped(data, comp_tag, pred_int,
compression_level=compression_level)

h, w = data.shape[:2]
Expand All @@ -999,16 +1041,16 @@ def write(data: np.ndarray, path: str, *,
current = _make_overview(current, method=overview_resampling)
oh, ow = current.shape[:2]
if tiled:
o_off, o_bc, o_data = _write_tiled(current, comp_tag, predictor,
o_off, o_bc, o_data = _write_tiled(current, comp_tag, pred_int,
tile_size,
compression_level=compression_level)
else:
o_off, o_bc, o_data = _write_stripped(current, comp_tag, predictor,
o_off, o_bc, o_data = _write_stripped(current, comp_tag, pred_int,
compression_level=compression_level)
parts.append((current, ow, oh, o_off, o_bc, o_data))

file_bytes = _assemble_tiff(
w, h, data.dtype, comp_tag, predictor, tiled, tile_size,
w, h, data.dtype, comp_tag, pred_int, tiled, tile_size,
parts, geo_transform, crs_epsg, nodata, is_cog=cog,
raster_type=raster_type, crs_wkt=crs_wkt,
gdal_metadata_xml=gdal_metadata_xml,
Expand All @@ -1030,15 +1072,15 @@ def write(data: np.ndarray, path: str, *,


def _compress_block(arr, block_w, block_h, samples, dtype, bytes_per_sample,
predictor, compression, compression_level=None):
predictor: int, compression, compression_level=None):
"""Compress a tile or strip. *arr* must be contiguous and correctly sized."""
if compression == COMPRESSION_JPEG:
return jpeg_compress(arr.tobytes(), block_w, block_h, samples)

if predictor and compression != COMPRESSION_NONE:
if predictor != 1 and compression != COMPRESSION_NONE:
buf = arr.view(np.uint8).ravel().copy()
buf = predictor_encode(buf, block_w, block_h,
bytes_per_sample * samples)
buf = _apply_predictor_encode(
buf, predictor, block_w, block_h, bytes_per_sample, samples)
raw_data = buf.tobytes()
else:
raw_data = arr.tobytes()
Expand Down Expand Up @@ -1069,7 +1111,7 @@ def write_streaming(dask_data, path: str, *,
compression_level: int | None = None,
tiled: bool = True,
tile_size: int = 256,
predictor: bool = False,
predictor: bool | int = False,
raster_type: int = 1,
x_resolution: float | None = None,
y_resolution: float | None = None,
Expand Down Expand Up @@ -1109,6 +1151,7 @@ def write_streaming(dask_data, path: str, *,
bits_per_sample, sample_format = numpy_to_tiff_dtype(out_dtype)
bytes_per_sample = out_dtype.itemsize
comp_tag = _compression_tag(compression)
pred_int = normalize_predictor(predictor, out_dtype, comp_tag)

if comp_tag == COMPRESSION_JPEG:
if out_dtype != np.uint8:
Expand Down Expand Up @@ -1166,7 +1209,7 @@ def write_streaming(dask_data, path: str, *,
extra_vals = [0] * n_extra
tags.append((TAG_EXTRA_SAMPLES, SHORT, n_extra, extra_vals))

pred_val = 2 if (predictor and comp_tag != COMPRESSION_NONE) else 1
pred_val = pred_int if comp_tag != COMPRESSION_NONE else 1
if pred_val != 1:
tags.append((TAG_PREDICTOR, SHORT, 1, pred_val))

Expand Down Expand Up @@ -1313,7 +1356,7 @@ def write_streaming(dask_data, path: str, *,

compressed = _compress_block(
tile_arr, tw, th, samples, out_dtype,
bytes_per_sample, predictor, comp_tag,
bytes_per_sample, pred_int, comp_tag,
compression_level)

actual_offsets.append(current_offset)
Expand Down Expand Up @@ -1350,7 +1393,7 @@ def write_streaming(dask_data, path: str, *,
compressed = _compress_block(
np.ascontiguousarray(strip_np),
width, strip_rows, samples, out_dtype,
bytes_per_sample, predictor, comp_tag,
bytes_per_sample, pred_int, comp_tag,
compression_level)

actual_offsets.append(current_offset)
Expand Down
Loading
Loading