From b3e1e8fdcd1683fdea2ed5a5180fcc01938e32c8 Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Sun, 22 Mar 2026 08:10:02 -0700 Subject: [PATCH] Add nvJPEG GPU acceleration for JPEG-compressed GeoTIFFs (#1050) Wire JPEG (TIFF tag 7) into the GPU decode and encode pipelines. When libnvjpeg.so is available, read_geotiff(gpu=True) and write_geotiff(gpu=True, compression='jpeg') use nvJPEG for hardware-accelerated JPEG on GPU. Falls back to Pillow on CPU when nvJPEG is not installed. Changes: - _gpu_decode.py: Add _find_nvjpeg_lib/_get_nvjpeg lazy discovery, _try_nvjpeg_batch_decode for GPU reads, _nvjpeg_batch_encode for GPU writes. Hook tag 7 into gpu_decode_tiles and gpu_compress_tiles. - _writer.py: Add 'jpeg' to _compression_tag map. Handle JPEG in _write_tiled and _write_stripped (call jpeg_compress directly, skip predictor). Validate uint8 dtype and 1/3 band constraint. - __init__.py: Update docstrings with JPEG option. - README.md: Add nvJPEG to GPU codecs list and JPEG write example. - tests/test_jpeg.py: 13 tests covering codec round trips, tag wiring, tiled/stripped write-read, public API, and dtype validation. - tests/test_writer.py: Fix test_unsupported_compression (was using 'jpeg' as the unsupported example). 
--- README.md | 3 +- xrspatial/geotiff/__init__.py | 8 +- xrspatial/geotiff/_gpu_decode.py | 347 ++++++++++++++++++++++++- xrspatial/geotiff/_writer.py | 33 ++- xrspatial/geotiff/tests/test_jpeg.py | 151 +++++++++++ xrspatial/geotiff/tests/test_writer.py | 2 +- 6 files changed, 533 insertions(+), 11 deletions(-) create mode 100644 xrspatial/geotiff/tests/test_jpeg.py diff --git a/README.md b/README.md index ccd4bcaf..a65af93c 100644 --- a/README.md +++ b/README.md @@ -158,12 +158,13 @@ read_geotiff('mosaic.vrt') # VRT mosaic (auto-detected write_geotiff(cupy_array, 'out.tif') # auto-detects GPU write_geotiff(data, 'out.tif', gpu=True) # force GPU compress +write_geotiff(data, 'ortho.tif', compression='jpeg') # JPEG for orthophotos write_vrt('mosaic.vrt', ['tile1.tif', 'tile2.tif']) # generate VRT ``` **Compression codecs:** Deflate, LZW (Numba JIT), ZSTD, PackBits, JPEG (Pillow), uncompressed -**GPU codecs:** Deflate and ZSTD via nvCOMP batch API; LZW via Numba CUDA kernels +**GPU codecs:** Deflate and ZSTD via nvCOMP; LZW via Numba CUDA; JPEG via nvJPEG **Features:** - Tiled, stripped, BigTIFF, multi-band (RGB/RGBA), sub-byte (1/2/4/12-bit) diff --git a/xrspatial/geotiff/__init__.py b/xrspatial/geotiff/__init__.py index 2940ba4c..57001366 100644 --- a/xrspatial/geotiff/__init__.py +++ b/xrspatial/geotiff/__init__.py @@ -319,7 +319,10 @@ def write_geotiff(data: xr.DataArray | np.ndarray, path: str, *, nodata : float, int, or None NoData value. compression : str - 'none', 'deflate', or 'lzw'. + 'none', 'deflate', 'lzw', 'jpeg', 'packbits', or 'zstd'. + JPEG is lossy and only supports uint8 data (1 or 3 bands). + With ``gpu=True``, JPEG uses nvJPEG for GPU-accelerated + encode/decode when available, falling back to Pillow on CPU. tiled : bool Use tiled layout (default True). tile_size : int @@ -756,7 +759,8 @@ def write_geotiff_gpu(data, path: str, *, nodata : float, int, or None NoData value. 
compression : str - 'zstd' (default, fastest on GPU), 'deflate', or 'none'. + 'zstd' (default, fastest on GPU), 'deflate', 'jpeg', or 'none'. + JPEG uses nvJPEG when available, falling back to Pillow. tile_size : int Tile size in pixels (default 256). predictor : bool diff --git a/xrspatial/geotiff/_gpu_decode.py b/xrspatial/geotiff/_gpu_decode.py index 93f2ae1a..45dda422 100644 --- a/xrspatial/geotiff/_gpu_decode.py +++ b/xrspatial/geotiff/_gpu_decode.py @@ -891,6 +891,311 @@ class _NvcompDeflateDecompOpts(ctypes.Structure): return None +# --------------------------------------------------------------------------- +# nvJPEG batch decode/encode (optional, GPU-accelerated JPEG) +# --------------------------------------------------------------------------- + +def _find_nvjpeg_lib(): + """Find and load libnvjpeg.so from the CUDA toolkit. Returns CDLL or None.""" + import ctypes + import os + + search_paths = [ + 'libnvjpeg.so', # system LD_LIBRARY_PATH + ] + + # CUDA toolkit path + cuda_home = os.environ.get('CUDA_HOME', os.environ.get('CUDA_PATH', '')) + if cuda_home: + for subdir in ('lib64', 'lib'): + search_paths.append(os.path.join(cuda_home, subdir, 'libnvjpeg.so')) + + # Conda env + conda_prefix = os.environ.get('CONDA_PREFIX', '') + if conda_prefix: + search_paths.append(os.path.join(conda_prefix, 'lib', 'libnvjpeg.so')) + + # Common CUDA toolkit install locations + for ver_dir in ('/usr/local/cuda/lib64', '/usr/local/cuda/lib'): + search_paths.append(os.path.join(ver_dir, 'libnvjpeg.so')) + + for path in search_paths: + try: + return ctypes.CDLL(path) + except OSError: + continue + return None + + +_nvjpeg_lib = None +_nvjpeg_checked = False + + +def _get_nvjpeg(): + """Get the nvJPEG library handle (cached). 
Returns CDLL or None.""" + global _nvjpeg_lib, _nvjpeg_checked + if not _nvjpeg_checked: + _nvjpeg_checked = True + _nvjpeg_lib = _find_nvjpeg_lib() + return _nvjpeg_lib + + +# nvJPEG status codes +_NVJPEG_STATUS_SUCCESS = 0 + +# nvJPEG output formats (values mirror nvjpegOutputFormat_t / nvjpegInputFormat_t in nvjpeg.h) +_NVJPEG_OUTPUT_RGB = 3 # planar RGB (needs three channel pointers) +_NVJPEG_OUTPUT_RGBI = 5 # interleaved RGB (R0G0B0 R1G1B1 ...) written to channel 0 +_NVJPEG_OUTPUT_UNCHANGED = 0 # native colorspace, planar; one plane for grayscale JPEGs + +# nvJPEG backend +_NVJPEG_BACKEND_DEFAULT = 0 +_NVJPEG_BACKEND_GPU_HYBRID = 2 + + +def _try_nvjpeg_batch_decode(compressed_tiles, tile_width, tile_height, + samples): + """Try batch JPEG decode via nvJPEG. Returns CuPy buffer or None. + + Decodes the JPEG tiles on the GPU one at a time with nvjpegDecode. Returns None + if nvJPEG is unavailable or any decode fails. + """ + lib = _get_nvjpeg() + if lib is None: + return None + + import ctypes + import cupy + + try: + n_tiles = len(compressed_tiles) + tile_pixels = tile_width * tile_height + tile_bytes = tile_pixels * samples # JPEG is always uint8 + + # nvJPEG handle type (opaque pointer) + nvjpeg_handle = ctypes.c_void_p() + + # nvjpegCreateSimple(&handle) + create_fn = getattr(lib, 'nvjpegCreateSimple', None) + if create_fn is None: + return None + create_fn.restype = ctypes.c_int + status = create_fn(ctypes.byref(nvjpeg_handle)) + if status != _NVJPEG_STATUS_SUCCESS: + return None + + try: + # Create JPEG state: nvjpegJpegStateCreate(handle, &state) + jpeg_state = ctypes.c_void_p() + state_create = getattr(lib, 'nvjpegJpegStateCreate') + state_create.restype = ctypes.c_int + status = state_create(nvjpeg_handle, ctypes.byref(jpeg_state)) + if status != _NVJPEG_STATUS_SUCCESS: + return None + + try: + # Decode tiles one at a time using the simple API. + # nvJPEG batch API requires more setup; the simple decode + # is still GPU-accelerated and avoids complex state management. 
+ output_format = _NVJPEG_OUTPUT_RGBI if samples == 3 else _NVJPEG_OUTPUT_UNCHANGED + + # nvjpegImage_t: array of 4 channel pointers + 4 pitches + class _NvjpegImage(ctypes.Structure): + _fields_ = [ + ('channel', ctypes.c_void_p * 4), + ('pitch', ctypes.c_size_t * 4), + ] + + d_all = cupy.empty(n_tiles * tile_bytes, dtype=cupy.uint8) + + decode_fn = getattr(lib, 'nvjpegDecode') + decode_fn.restype = ctypes.c_int + + for i, tile_data in enumerate(compressed_tiles): + d_out = d_all[i * tile_bytes:(i + 1) * tile_bytes] + + nv_img = _NvjpegImage() + nv_img.channel[0] = ctypes.c_void_p(d_out.data.ptr) + for ch in range(1, 4): + nv_img.channel[ch] = ctypes.c_void_p(0) + nv_img.pitch[0] = ctypes.c_size_t(tile_width * samples) + for ch in range(1, 4): + nv_img.pitch[ch] = ctypes.c_size_t(0) + + src = tile_data if isinstance(tile_data, bytes) else bytes(tile_data) + + status = decode_fn( + nvjpeg_handle, + jpeg_state, + ctypes.c_char_p(src), + ctypes.c_size_t(len(src)), + ctypes.c_int(output_format), + ctypes.byref(nv_img), + ctypes.c_void_p(0), # default CUDA stream + ) + if status != _NVJPEG_STATUS_SUCCESS: + return None + + cupy.cuda.Device().synchronize() + return d_all + + finally: + destroy_state = getattr(lib, 'nvjpegJpegStateDestroy', None) + if destroy_state is not None: + destroy_state(jpeg_state) + finally: + destroy_fn = getattr(lib, 'nvjpegDestroy', None) + if destroy_fn is not None: + destroy_fn(nvjpeg_handle) + + except Exception: + return None + + +def _nvjpeg_batch_encode(d_tile_bufs, tile_width, tile_height, samples, + quality=75): + """Encode tiles as JPEG on GPU via nvJPEG. Returns list of bytes or None. + + Each tile must be a CuPy uint8 array of interleaved pixel data. 
+ """ + lib = _get_nvjpeg() + if lib is None: + return None + + import ctypes + import cupy + + try: + n_tiles = len(d_tile_bufs) + + nvjpeg_handle = ctypes.c_void_p() + create_fn = getattr(lib, 'nvjpegCreateSimple', None) + if create_fn is None: + return None + create_fn.restype = ctypes.c_int + status = create_fn(ctypes.byref(nvjpeg_handle)) + if status != _NVJPEG_STATUS_SUCCESS: + return None + + try: + # Create encoder state and params + enc_state = ctypes.c_void_p() + enc_state_create = getattr(lib, 'nvjpegEncoderStateCreate', None) + if enc_state_create is None: + return None + enc_state_create.restype = ctypes.c_int + status = enc_state_create( + nvjpeg_handle, ctypes.byref(enc_state), + ctypes.c_void_p(0)) # default stream + if status != _NVJPEG_STATUS_SUCCESS: + return None + + try: + enc_params = ctypes.c_void_p() + params_create = getattr(lib, 'nvjpegEncoderParamsCreate') + params_create.restype = ctypes.c_int + status = params_create( + nvjpeg_handle, ctypes.byref(enc_params), + ctypes.c_void_p(0)) + if status != _NVJPEG_STATUS_SUCCESS: + return None + + try: + # Set quality + set_quality = getattr(lib, 'nvjpegEncoderParamsSetQuality') + set_quality.restype = ctypes.c_int + set_quality(enc_params, ctypes.c_int(quality), + ctypes.c_void_p(0)) + + # Set interleaved sampling + set_sampling = getattr(lib, 'nvjpegEncoderParamsSetSamplingFactors', None) + # 0 = NVJPEG_CSS_444 + if set_sampling is not None: + set_sampling.restype = ctypes.c_int + set_sampling(enc_params, ctypes.c_int(0), + ctypes.c_void_p(0)) + + class _NvjpegImage(ctypes.Structure): + _fields_ = [ + ('channel', ctypes.c_void_p * 4), + ('pitch', ctypes.c_size_t * 4), + ] + + # Choose input format + input_format = _NVJPEG_OUTPUT_RGBI if samples == 3 else _NVJPEG_OUTPUT_UNCHANGED + + encode_fn = getattr(lib, 'nvjpegEncodeImage') + encode_fn.restype = ctypes.c_int + + retrieve_fn = getattr(lib, 'nvjpegEncodeRetrieveBitstream') + retrieve_fn.restype = ctypes.c_int + + result = [] + for d_tile in 
d_tile_bufs: + nv_img = _NvjpegImage() + nv_img.channel[0] = ctypes.c_void_p(d_tile.data.ptr) + for ch in range(1, 4): + nv_img.channel[ch] = ctypes.c_void_p(0) + nv_img.pitch[0] = ctypes.c_size_t(tile_width * samples) + for ch in range(1, 4): + nv_img.pitch[ch] = ctypes.c_size_t(0) + + status = encode_fn( + nvjpeg_handle, enc_state, enc_params, + ctypes.byref(nv_img), + ctypes.c_int(input_format), + ctypes.c_int(tile_width), + ctypes.c_int(tile_height), + ctypes.c_void_p(0), # default stream + ) + if status != _NVJPEG_STATUS_SUCCESS: + return None + + cupy.cuda.Device().synchronize() + + # Get compressed size + length = ctypes.c_size_t(0) + status = retrieve_fn( + nvjpeg_handle, enc_state, + ctypes.c_void_p(0), # NULL = query size + ctypes.byref(length), + ctypes.c_void_p(0), + ) + if status != _NVJPEG_STATUS_SUCCESS: + return None + + # Retrieve compressed data + buf = ctypes.create_string_buffer(length.value) + status = retrieve_fn( + nvjpeg_handle, enc_state, + buf, + ctypes.byref(length), + ctypes.c_void_p(0), + ) + if status != _NVJPEG_STATUS_SUCCESS: + return None + + result.append(buf.raw[:length.value]) + + return result + + finally: + params_destroy = getattr(lib, 'nvjpegEncoderParamsDestroy', None) + if params_destroy is not None: + params_destroy(enc_params) + finally: + state_destroy = getattr(lib, 'nvjpegEncoderStateDestroy', None) + if state_destroy is not None: + state_destroy(enc_state) + finally: + destroy_fn = getattr(lib, 'nvjpegDestroy', None) + if destroy_fn is not None: + destroy_fn(nvjpeg_handle) + + except Exception: + return None + + # --------------------------------------------------------------------------- # High-level GPU decode pipeline # --------------------------------------------------------------------------- @@ -1182,6 +1487,30 @@ def gpu_decode_tiles( ) cuda.synchronize() + elif compression == 7: # JPEG + # Try nvJPEG GPU decode first, fall back to CPU Pillow + nvjpeg_result = _try_nvjpeg_batch_decode( + compressed_tiles, 
tile_width, tile_height, samples) + if nvjpeg_result is not None: + d_decomp = nvjpeg_result + decomp_offsets = np.arange(n_tiles, dtype=np.int64) * tile_bytes + d_decomp_offsets = cupy.asarray(decomp_offsets) + else: + from ._compression import jpeg_decompress + raw_host = np.empty(n_tiles * tile_bytes, dtype=np.uint8) + for i, tile in enumerate(compressed_tiles): + start = i * tile_bytes + decoded = np.frombuffer( + jpeg_decompress(tile, tile_width, tile_height, samples), + dtype=np.uint8) + n = min(len(decoded), tile_bytes) + raw_host[start:start + n] = decoded[:n] + if n < tile_bytes: + raw_host[start + n:start + tile_bytes] = 0 + d_decomp = cupy.asarray(raw_host) + decomp_offsets = np.arange(n_tiles, dtype=np.int64) * tile_bytes + d_decomp_offsets = cupy.asarray(decomp_offsets) + elif compression == 1: # Uncompressed raw_host = np.empty(n_tiles * tile_bytes, dtype=np.uint8) for i, tile in enumerate(compressed_tiles): @@ -1550,9 +1879,25 @@ def gpu_compress_tiles(d_image, tile_width, tile_height, d_tile_buf, d_tmp, tile_width * samples, total_rows, dtype.itemsize) cuda.synchronize() - # Split into per-tile buffers for nvCOMP + # Split into per-tile buffers d_tiles = [d_tile_buf[i * tile_bytes:(i + 1) * tile_bytes] for i in range(n_tiles)] + # JPEG: try nvJPEG encode, fall back to Pillow + if compression == 7: + result = _nvjpeg_batch_encode(d_tiles, tile_width, tile_height, samples) + if result is not None: + return result + # Fallback: CPU Pillow encode + from ._compression import jpeg_compress + cpu_buf = d_tile_buf.get() + result = [] + for i in range(n_tiles): + start = i * tile_bytes + tile_data = bytes(cpu_buf[start:start + tile_bytes]) + result.append(jpeg_compress(tile_data, tile_width, tile_height, + samples)) + return result + # Try nvCOMP batch compress result = _nvcomp_batch_compress(d_tiles, None, tile_bytes, compression, n_tiles) diff --git a/xrspatial/geotiff/_writer.py b/xrspatial/geotiff/_writer.py index ae7658ab..ae6805d1 100644 --- 
a/xrspatial/geotiff/_writer.py +++ b/xrspatial/geotiff/_writer.py @@ -8,11 +8,13 @@ from ._compression import ( COMPRESSION_DEFLATE, + COMPRESSION_JPEG, COMPRESSION_LZW, COMPRESSION_NONE, COMPRESSION_PACKBITS, COMPRESSION_ZSTD, compress, + jpeg_compress, predictor_encode, ) from ._dtypes import ( @@ -65,6 +67,7 @@ def _compression_tag(compression_name: str) -> int: 'none': COMPRESSION_NONE, 'deflate': COMPRESSION_DEFLATE, 'lzw': COMPRESSION_LZW, + 'jpeg': COMPRESSION_JPEG, 'packbits': COMPRESSION_PACKBITS, 'zstd': COMPRESSION_ZSTD, } @@ -310,15 +313,18 @@ def _write_stripped(data: np.ndarray, compression: int, predictor: bool, r1 = min(r0 + rows_per_strip, height) strip_rows = r1 - r0 - if predictor and compression != COMPRESSION_NONE: + if compression == COMPRESSION_JPEG: + strip_data = np.ascontiguousarray(data[r0:r1]).tobytes() + compressed = jpeg_compress(strip_data, width, strip_rows, samples) + elif predictor and compression != COMPRESSION_NONE: strip_arr = np.ascontiguousarray(data[r0:r1]) buf = strip_arr.view(np.uint8).ravel().copy() buf = predictor_encode(buf, width, strip_rows, bytes_per_sample * samples) strip_data = buf.tobytes() + compressed = compress(strip_data, compression) else: strip_data = np.ascontiguousarray(data[r0:r1]).tobytes() - - compressed = compress(strip_data, compression) + compressed = compress(strip_data, compression) rel_offsets.append(current_offset) byte_counts.append(len(compressed)) @@ -384,14 +390,18 @@ def _write_tiled(data: np.ndarray, compression: int, predictor: bool, else: tile_arr = np.ascontiguousarray(tile_slice) - if predictor and compression != COMPRESSION_NONE: + if compression == COMPRESSION_JPEG: + # JPEG: no predictor, use jpeg_compress directly + tile_data = tile_arr.tobytes() + compressed = jpeg_compress(tile_data, tw, th, samples) + elif predictor and compression != COMPRESSION_NONE: buf = tile_arr.view(np.uint8).ravel().copy() buf = predictor_encode(buf, tw, th, bytes_per_sample * samples) tile_data = 
buf.tobytes() + compressed = compress(tile_data, compression) else: tile_data = tile_arr.tobytes() - - compressed = compress(tile_data, compression) + compressed = compress(tile_data, compression) rel_offsets.append(current_offset) byte_counts.append(len(compressed)) @@ -780,6 +790,17 @@ def write(data: np.ndarray, path: str, *, """ comp_tag = _compression_tag(compression) + # JPEG validation: only uint8, 1 or 3 bands + if comp_tag == COMPRESSION_JPEG: + samples = data.shape[2] if data.ndim == 3 else 1 + if data.dtype != np.uint8: + raise ValueError( + f"JPEG compression requires uint8 data, got {data.dtype}. " + f"JPEG is lossy and only supports 8-bit unsigned data.") + if samples not in (1, 3): + raise ValueError( + f"JPEG compression requires 1 or 3 bands, got {samples}") + # Build pixel data parts parts = [] diff --git a/xrspatial/geotiff/tests/test_jpeg.py b/xrspatial/geotiff/tests/test_jpeg.py new file mode 100644 index 00000000..535be202 --- /dev/null +++ b/xrspatial/geotiff/tests/test_jpeg.py @@ -0,0 +1,151 @@ +"""Tests for JPEG compression support (issue #1050).""" +from __future__ import annotations + +import numpy as np +import pytest +import xarray as xr + +from xrspatial.geotiff._compression import ( + COMPRESSION_JPEG, + jpeg_compress, + jpeg_decompress, +) +from xrspatial.geotiff._writer import write, _compression_tag +from xrspatial.geotiff._reader import read_to_array + + +class TestJpegCodec: + """Low-level JPEG compress/decompress round trips.""" + + def test_grayscale_round_trip(self): + rng = np.random.RandomState(1050) + arr = rng.randint(0, 256, (32, 32), dtype=np.uint8) + compressed = jpeg_compress(arr.tobytes(), 32, 32, samples=1) + decoded = np.frombuffer( + jpeg_decompress(compressed, 32, 32, samples=1), dtype=np.uint8 + ).reshape(32, 32) + # JPEG is lossy: check approximate match + assert decoded.shape == arr.shape + assert np.abs(decoded.astype(int) - arr.astype(int)).mean() < 10 + + def test_rgb_round_trip(self): + # Use a smooth 
gradient -- random noise is the worst case for JPEG + y = np.linspace(50, 200, 32, dtype=np.uint8) + x = np.linspace(50, 200, 32, dtype=np.uint8) + r = np.outer(y, np.ones(32, dtype=np.uint8)) + g = np.outer(np.ones(32, dtype=np.uint8), x) + b = np.full((32, 32), 128, dtype=np.uint8) + arr = np.stack([r, g, b], axis=2) + compressed = jpeg_compress(arr.tobytes(), 32, 32, samples=3) + decoded = np.frombuffer( + jpeg_decompress(compressed, 32, 32, samples=3), dtype=np.uint8 + ).reshape(32, 32, 3) + assert decoded.shape == arr.shape + assert np.abs(decoded.astype(int) - arr.astype(int)).mean() < 10 + + def test_quality_affects_size(self): + rng = np.random.RandomState(1050) + arr = rng.randint(0, 256, (32, 32), dtype=np.uint8) + data = arr.tobytes() + low_q = jpeg_compress(data, 32, 32, samples=1, quality=10) + high_q = jpeg_compress(data, 32, 32, samples=1, quality=95) + assert len(low_q) < len(high_q) + + def test_invalid_samples_raises(self): + with pytest.raises(ValueError, match="1 or 3 bands"): + jpeg_compress(b'\x00' * 64, 4, 4, samples=2) + + +class TestCompressionTagJpeg: + """Verify JPEG is wired into the writer's compression tag map.""" + + def test_jpeg_tag_value(self): + assert _compression_tag('jpeg') == COMPRESSION_JPEG + assert _compression_tag('JPEG') == COMPRESSION_JPEG + + def test_tag_value_is_7(self): + assert COMPRESSION_JPEG == 7 + + +class TestJpegWriteRoundTrip: + """Write JPEG-compressed GeoTIFFs and read them back.""" + + def test_grayscale_tiled(self, tmp_path): + rng = np.random.RandomState(1050) + expected = rng.randint(50, 200, (32, 32), dtype=np.uint8) + path = str(tmp_path / 'gray_1050_tiled.tif') + write(expected, path, compression='jpeg', tiled=True, tile_size=16) + + arr, geo = read_to_array(path) + assert arr.shape == expected.shape + assert arr.dtype == np.uint8 + # JPEG is lossy, check approximate + assert np.abs(arr.astype(int) - expected.astype(int)).mean() < 10 + + def test_grayscale_stripped(self, tmp_path): + rng = 
np.random.RandomState(1050) + expected = rng.randint(50, 200, (32, 32), dtype=np.uint8) + path = str(tmp_path / 'gray_1050_stripped.tif') + write(expected, path, compression='jpeg', tiled=False) + + arr, geo = read_to_array(path) + assert arr.shape == expected.shape + assert np.abs(arr.astype(int) - expected.astype(int)).mean() < 10 + + def test_rgb_tiled(self, tmp_path): + # Smooth gradient for predictable JPEG behavior + y = np.linspace(50, 200, 32, dtype=np.uint8) + x = np.linspace(50, 200, 32, dtype=np.uint8) + r = np.outer(y, np.ones(32, dtype=np.uint8)) + g = np.outer(np.ones(32, dtype=np.uint8), x) + b = np.full((32, 32), 128, dtype=np.uint8) + expected = np.stack([r, g, b], axis=2) + path = str(tmp_path / 'rgb_1050_tiled.tif') + write(expected, path, compression='jpeg', tiled=True, tile_size=16) + + arr, geo = read_to_array(path) + assert arr.shape == expected.shape + assert np.abs(arr.astype(int) - expected.astype(int)).mean() < 10 + + +class TestJpegValidation: + """Verify that JPEG rejects invalid input.""" + + def test_float_data_rejected(self, tmp_path): + arr = np.zeros((8, 8), dtype=np.float32) + path = str(tmp_path / 'bad_1050.tif') + with pytest.raises(ValueError, match="uint8"): + write(arr, path, compression='jpeg') + + def test_uint16_data_rejected(self, tmp_path): + arr = np.zeros((8, 8), dtype=np.uint16) + path = str(tmp_path / 'bad16_1050.tif') + with pytest.raises(ValueError, match="uint8"): + write(arr, path, compression='jpeg') + + def test_4band_rejected(self, tmp_path): + arr = np.zeros((8, 8, 4), dtype=np.uint8) + path = str(tmp_path / 'bad4b_1050.tif') + with pytest.raises(ValueError, match="1 or 3 bands"): + write(arr, path, compression='jpeg') + + +class TestWriteGeotiffJpeg: + """Test the public write_geotiff API with compression='jpeg'.""" + + def test_write_geotiff_jpeg(self, tmp_path): + from xrspatial.geotiff import write_geotiff, read_geotiff + + rng = np.random.RandomState(1050) + data = rng.randint(50, 200, (32, 32), 
dtype=np.uint8) + da = xr.DataArray( + data, dims=['y', 'x'], + coords={'y': np.arange(32, dtype=float), + 'x': np.arange(32, dtype=float)}, + ) + path = str(tmp_path / 'api_1050.tif') + write_geotiff(da, path, compression='jpeg', tile_size=16) + + result = read_geotiff(path) + assert result.shape == (32, 32) + assert np.abs(result.values.astype(int) - data.astype(int)).mean() < 10 diff --git a/xrspatial/geotiff/tests/test_writer.py b/xrspatial/geotiff/tests/test_writer.py index a016f49f..8a33a375 100644 --- a/xrspatial/geotiff/tests/test_writer.py +++ b/xrspatial/geotiff/tests/test_writer.py @@ -101,4 +101,4 @@ class TestWriteInvalidInput: def test_unsupported_compression(self, tmp_path): arr = np.zeros((4, 4), dtype=np.float32) with pytest.raises(ValueError, match="Unsupported compression"): - write(arr, str(tmp_path / 'bad.tif'), compression='jpeg') + write(arr, str(tmp_path / 'bad.tif'), compression='bzip2')