diff --git a/xrspatial/geotiff/__init__.py b/xrspatial/geotiff/__init__.py index fab13e335..0908ccfc8 100644 --- a/xrspatial/geotiff/__init__.py +++ b/xrspatial/geotiff/__init__.py @@ -2737,8 +2737,13 @@ def read_geotiff_gpu(source: str, *, CuPy-backed DataArray that stays on device memory. No CPU->GPU transfer needed for downstream xrspatial GPU operations. - With ``chunks=``, returns a Dask+CuPy DataArray for out-of-core - GPU pipelines. + With ``chunks=``, returns a Dask+CuPy DataArray with real + out-of-core memory bounds: each chunk reads only the tiles for its + window (via the CPU dask path) and uploads the result to the + device, so peak GPU memory is one chunk rather than the whole + raster. The trade-off is per-chunk CPU decode rather than bulk GPU + decode; for rasters that fit on the device, ``chunks=None`` keeps + the full GPU-decode fast path. Requires: cupy, numba with CUDA support. @@ -2765,8 +2770,12 @@ def read_geotiff_gpu(source: str, *, multi-band files, 2D for single-band). Selecting a single band yields a 2D DataArray. chunks : int, tuple, or None - If set, return a Dask-chunked CuPy DataArray. int for square - chunks, (row, col) tuple for rectangular. + If set, return a Dask-chunked CuPy DataArray decoded one chunk + at a time. int for square chunks, (row, col) tuple for + rectangular. Each chunk task reads only the tiles overlapping + its window (CPU decode) and uploads the result to the device, + so peak GPU memory is bounded by chunk size. ``chunks=None`` + (default) decodes the full raster on the GPU in one pass. name : str or None Name for the DataArray. max_pixels : int or None @@ -2847,6 +2856,24 @@ def read_geotiff_gpu(source: str, *, "cupy is required for GPU reads. " "Install it with: pip install cupy-cuda12x") + # When ``chunks=`` is set, bound peak GPU memory to chunk size by + # building a Dask+CuPy graph that decodes one chunk at a time. The + # CPU dask path already lays out a window-per-chunk delayed graph + # (parses TIFF metadata once, decodes only the tiles overlapping + # each chunk window, handles HTTP/fsspec/local/sparse/planar=2/ + # MinIsWhite/nodata/orientation). Reusing it and uploading each + # block to the device via ``map_blocks(cupy.asarray)`` gives real + # out-of-core behaviour for the read; the trade-off is per-chunk + # CPU decode rather than the eager path's bulk GPU decode. Users + # who want full GPU-side decode (and have device memory for the + # whole image) pass ``chunks=None``. See issue #1876. + if chunks is not None: + return _read_geotiff_gpu_chunked( + source, dtype=dtype, chunks=chunks, + overview_level=overview_level, window=window, band=band, + name=name, max_pixels=max_pixels, + ) + from ._reader import ( _FileSource, _check_dimensions, MAX_PIXELS_DEFAULT, _coerce_path, _resolve_masked_fill, @@ -3325,16 +3352,383 @@ def _read_once(): result = xr.DataArray(arr_gpu, dims=dims, coords=coords, name=name, attrs=attrs) - if chunks is not None: - if isinstance(chunks, int): - chunk_dict = {'y': chunks, 'x': chunks} - else: - chunk_dict = {'y': chunks[0], 'x': chunks[1]} - result = result.chunk(chunk_dict) + # ``chunks=`` is handled at function entry via + # ``_read_geotiff_gpu_chunked`` for real out-of-core support; this + # eager path always returns a non-chunked CuPy-backed DataArray. return result +def _gds_chunk_path_available(source, ifd, has_sparse_tile, orientation): + """Return True iff a direct-to-GPU per-chunk decode is possible. + + The disk->GPU per-chunk path requires: + + - KvikIO present (so ``_try_kvikio_read_tiles`` can DMA tiles to VRAM). + - A local file path (no HTTP/fsspec source). + - A tiled layout (no strip-only file). + - PlanarConfiguration=1 (chunky); planar=2 would need per-band tile + grids and per-band crops. + - No sparse tiles, since the GPU decoders skip the bytes-zero-fill + handling the CPU reader does for them. + - Orientation == 1, since a non-default orientation needs the full + array on hand to apply the transform. + - PhotometricInterpretation != 0 (MinIsWhite needs an inversion + pass that lives in the eager path). + """ + if not isinstance(source, str): + return False + if source.startswith(('http://', 'https://')): + return False + try: + from ._reader import _is_fsspec_uri + if _is_fsspec_uri(source): + return False + except Exception: + pass + try: + import importlib.util + if importlib.util.find_spec('kvikio') is None: + return False + except Exception: + return False + if not ifd.is_tiled: + return False + if has_sparse_tile: + return False + if ifd.planar_config != 1: + return False + if orientation != 1: + return False + if ifd.photometric == 0: + return False + return True + + +def _decode_window_gpu_direct(file_path, all_offsets, all_byte_counts, + tw, th, full_w, compression, predictor, + file_dtype, samples, byte_order, + r0, c0, r1, c1): + """Decode a window's tile subset disk->GPU. + + Picks just the tiles overlapping ``(r0..r1, c0..c1)`` from the full + tile sequence, runs them through ``gpu_decode_tiles_from_file`` + (which tries KvikIO GDS first, then a CPU mmap + ``gpu_decode_tiles`` + fallback), and crops the assembled sub-image to the requested window. + Peak device memory is the sub-tile-grid bounding box, not the full + raster. + + Called from inside a ``dask.delayed`` per-chunk task, so it runs + once per chunk and only pulls the tiles that chunk needs from disk. + """ + from ._gpu_decode import gpu_decode_tiles, gpu_decode_tiles_from_file + + tiles_across = (full_w + tw - 1) // tw + + ty_start = r0 // th + ty_end = (r1 - 1) // th + 1 + tx_start = c0 // tw + tx_end = (c1 - 1) // tw + 1 + + sub_tiles_across = tx_end - tx_start + sub_tiles_down = ty_end - ty_start + sub_h = sub_tiles_down * th + sub_w = sub_tiles_across * tw + + indices = [ty * tiles_across + tx + for ty in range(ty_start, ty_end) + for tx in range(tx_start, tx_end)] + sub_offsets = [all_offsets[i] for i in indices] + sub_byte_counts = [all_byte_counts[i] for i in indices] + + arr_gpu = gpu_decode_tiles_from_file( + file_path, sub_offsets, sub_byte_counts, + tw, th, sub_w, sub_h, + compression, predictor, file_dtype, samples, + byte_order=byte_order, + ) + + if arr_gpu is None: + # ``gpu_decode_tiles_from_file`` returns None when KvikIO is not + # usable on the host. Open the file via mmap, slice out just the + # bytes for these tiles, and run the GPU decoder on those. + from ._reader import _FileSource + src = _FileSource(file_path) + try: + data = src.read_all() + compressed_tiles = [ + bytes(data[sub_offsets[i]:sub_offsets[i] + sub_byte_counts[i]]) + for i in range(len(sub_offsets)) + ] + finally: + src.close() + arr_gpu = gpu_decode_tiles( + compressed_tiles, tw, th, sub_w, sub_h, + compression, predictor, file_dtype, samples, + byte_order=byte_order, + ) + + crop_r0 = r0 - ty_start * th + crop_c0 = c0 - tx_start * tw + crop_r1 = crop_r0 + (r1 - r0) + crop_c1 = crop_c0 + (c1 - c0) + if samples > 1: + return arr_gpu[crop_r0:crop_r1, crop_c0:crop_c1, :] + return arr_gpu[crop_r0:crop_r1, crop_c0:crop_c1] + + +def _read_geotiff_gpu_chunked(source, *, dtype, chunks, overview_level, + window, band, name, max_pixels): + """Lazy Dask+CuPy backend for ``read_geotiff_gpu(chunks=...)``. + + Two paths produce the same shape of dask graph: + + 1. **Direct disk->GPU** when KvikIO is available and the file is a + local, tiled, chunky GeoTIFF with no sparse tiles and a trivial + orientation. Each chunk task picks the tile subset for its + window, DMA's those tiles to the device via GDS, decodes on the + GPU, and crops. Peak GPU memory is one chunk and the file bytes + never traverse host memory. + + 2. **CPU decode + GPU upload** for everything else (HTTP / fsspec, + no KvikIO, planar=2, sparse, MinIsWhite, non-trivial orientation, + stripped layouts). Reuses ``read_geotiff_dask`` to build the + per-chunk windowed delayed graph and ``map_blocks(cupy.asarray)`` + to upload each block. Peak GPU memory is still one chunk; the + cost is per-chunk CPU decode rather than GDS DMA. + + Both paths are real out-of-core for device memory. + """ + import cupy + import dask + import dask.array as da_mod + + from ._reader import _FileSource, _coerce_path + from ._header import parse_header, parse_all_ifds, select_overview_ifd + from ._dtypes import resolve_bits_per_sample, tiff_dtype_to_numpy + from ._geotags import extract_geo_info_with_overview_inheritance + + src_path = _coerce_path(source) + + # Try the disk->GPU path. Parse metadata once; if the file does not + # qualify, fall through to the CPU-decode path. Any unexpected + # exception during the qualification probe also falls through so we + # never lose the ability to return a result. + try: + if isinstance(src_path, str) and not src_path.startswith( + ('http://', 'https://')): + fs = _FileSource(src_path) + try: + raw = fs.read_all() + finally: + fs.close() + header = parse_header(raw) + ifds = parse_all_ifds(raw, header) + if not ifds: + raise ValueError("No IFDs found in TIFF file") + ifd = select_overview_ifd(ifds, overview_level) + geo_info = extract_geo_info_with_overview_inheritance( + ifd, ifds, raw, header.byte_order, + ) + orientation = ifd.orientation + has_sparse_tile = ( + ifd.tile_byte_counts is not None + and any(bc == 0 for bc in ifd.tile_byte_counts) + ) + if _gds_chunk_path_available( + src_path, ifd, has_sparse_tile, orientation): + return _read_geotiff_gpu_chunked_gds( + src_path, ifd, geo_info, header, + dtype=dtype, chunks=chunks, window=window, band=band, + name=name, max_pixels=max_pixels, + ) + except Exception: + # GDS qualification failed; fall back to the CPU path. The + # error would otherwise be unrelated to what the user asked + # for (the CPU path re-parses metadata anyway). + pass + + cpu_da = read_geotiff_dask( + source, dtype=dtype, chunks=chunks, + overview_level=overview_level, window=window, band=band, + max_pixels=max_pixels, name=name, + ) + + cpu_dask_arr = cpu_da.data + + def _upload(block): + return cupy.asarray(block) + + gpu_dask_arr = cpu_dask_arr.map_blocks( + _upload, + dtype=cpu_dask_arr.dtype, + meta=cupy.empty((0,) * cpu_dask_arr.ndim, dtype=cpu_dask_arr.dtype), + ) + + return xr.DataArray( + gpu_dask_arr, dims=cpu_da.dims, coords=cpu_da.coords, + name=cpu_da.name, attrs=dict(cpu_da.attrs), + ) + + +def _read_geotiff_gpu_chunked_gds(source, ifd, geo_info, header, *, + dtype, chunks, window, band, name, + max_pixels): + """Build a Dask+CuPy graph that decodes each chunk disk->GPU. + + Caller must have verified that the source qualifies via + ``_gds_chunk_path_available``. Each chunk task pulls only the tile + subset overlapping its window via KvikIO GDS (or an mmap fallback + inside ``gpu_decode_tiles_from_file``) and crops on device. + """ + import cupy + import dask + import dask.array as da_mod + + from ._reader import _check_dimensions, MAX_PIXELS_DEFAULT + from ._header import validate_tile_layout + from ._dtypes import resolve_bits_per_sample, tiff_dtype_to_numpy + + if max_pixels is None: + max_pixels = MAX_PIXELS_DEFAULT + + full_h = ifd.height + full_w = ifd.width + samples = ifd.samples_per_pixel + bps = resolve_bits_per_sample(ifd.bits_per_sample) + file_dtype = tiff_dtype_to_numpy(bps, ifd.sample_format) + tw = ifd.tile_width + th = ifd.tile_height + compression = ifd.compression + predictor = ifd.predictor + byte_order = header.byte_order + offsets = list(ifd.tile_offsets) + byte_counts = list(ifd.tile_byte_counts) + + _check_dimensions(full_w, full_h, samples, max_pixels) + _check_dimensions(tw, th, samples, max_pixels) + validate_tile_layout(ifd) + + # Window restricts the visible region; offsets are computed relative + # to the windowed origin so chunks line up with the user's request. + if window is not None: + w_r0, w_c0, w_r1, w_c1 = window + if (w_r0 < 0 or w_c0 < 0 or w_r1 > full_h or w_c1 > full_w + or w_r0 >= w_r1 or w_c0 >= w_c1): + raise ValueError( + f"window={window} is out of bounds for image " + f"{full_w}x{full_h}.") + out_h, out_w = w_r1 - w_r0, w_c1 - w_c0 + win_r0, win_c0 = w_r0, w_c0 + else: + out_h, out_w = full_h, full_w + win_r0, win_c0 = 0, 0 + + if isinstance(chunks, int): + ch_h = ch_w = chunks + else: + ch_h, ch_w = chunks + if ch_h <= 0 or ch_w <= 0: + raise ValueError(f"Invalid chunks: {chunks}") + + # Validate band kwarg against the file's band count. + n_bands_out = samples if samples > 1 else 0 + if band is not None: + if n_bands_out == 0: + if band != 0: + raise IndexError( + f"band={band} requested but file is single-band.") + elif band < 0 or band >= n_bands_out: + raise IndexError( + f"band={band} out of range for {n_bands_out}-band file.") + + # Wrap the big tile-offset/byte-count tuples in a single Delayed so + # every chunk task takes them as one graph input rather than burning + # them into every task's pickled closure. + meta_key = dask.delayed( + (offsets, byte_counts), pure=True, + ) + + nodata = geo_info.nodata + + @dask.delayed + def _chunk_task(meta, r0, c0, r1, c1): + all_offsets, all_byte_counts = meta + arr = _decode_window_gpu_direct( + source, all_offsets, all_byte_counts, + tw, th, full_w, compression, predictor, + file_dtype, samples, byte_order, + r0, c0, r1, c1, + ) + if nodata is not None: + arr = _apply_nodata_mask_gpu(arr, nodata) + if dtype is not None: + target = np.dtype(dtype) + _validate_dtype_cast(np.dtype(str(arr.dtype)), target) + arr = arr.astype(target) + if band is not None and arr.ndim == 3: + arr = arr[:, :, band] + return arr + + # Determine declared dtype for the dask graph. Nodata masking + # promotes integer rasters to float64; mirror the CPU dask path. + declared_dtype = file_dtype + if nodata is not None and file_dtype.kind in ('u', 'i'): + if np.isfinite(nodata) and float(nodata).is_integer(): + info = np.iinfo(file_dtype) + if info.min <= int(nodata) <= info.max: + declared_dtype = np.dtype('float64') + if dtype is not None: + declared_dtype = np.dtype(dtype) + + out_has_band_axis = band is None and n_bands_out > 0 + + blocks_rows = [] + for r0 in range(0, out_h, ch_h): + r1 = min(r0 + ch_h, out_h) + blocks_cols = [] + for c0 in range(0, out_w, ch_w): + c1 = min(c0 + ch_w, out_w) + if out_has_band_axis: + block_shape = (r1 - r0, c1 - c0, n_bands_out) + else: + block_shape = (r1 - r0, c1 - c0) + # Convert chunk coords to file-space coords. + block = da_mod.from_delayed( + _chunk_task(meta_key, + r0 + win_r0, c0 + win_c0, + r1 + win_r0, c1 + win_c0), + shape=block_shape, + dtype=declared_dtype, + meta=cupy.empty((0,) * len(block_shape), + dtype=declared_dtype), + ) + blocks_cols.append(block) + blocks_rows.append(da_mod.concatenate(blocks_cols, axis=1)) + dask_arr = da_mod.concatenate(blocks_rows, axis=0) + + # Build coords/attrs that match read_geotiff_dask's output. + coords = _coords_from_geo_info(geo_info, out_h, out_w, window=window) + if out_has_band_axis: + dims = ['y', 'x', 'band'] + coords['band'] = np.arange(n_bands_out) + else: + dims = ['y', 'x'] + + attrs = {} + _populate_attrs_from_geo_info(attrs, geo_info, window=window) + if nodata is not None: + attrs['nodata'] = nodata + + if name is None: + import os + name = os.path.splitext(os.path.basename(source))[0] + + return xr.DataArray( + dask_arr, dims=dims, coords=coords, name=name, attrs=attrs, + ) + + def write_geotiff_gpu(data: xr.DataArray | cupy.ndarray | np.ndarray, path: str | BinaryIO, *, crs: int | str | None = None, diff --git a/xrspatial/geotiff/tests/test_gpu_chunks_out_of_core_1876.py b/xrspatial/geotiff/tests/test_gpu_chunks_out_of_core_1876.py new file mode 100644 index 000000000..506b2791e --- /dev/null +++ b/xrspatial/geotiff/tests/test_gpu_chunks_out_of_core_1876.py @@ -0,0 +1,276 @@ +"""Regression tests for issue #1876. + +``read_geotiff_gpu(chunks=...)`` used to eagerly decode the full raster +into a single CuPy array and then call ``.chunk()`` on the resulting +DataArray, so peak GPU memory was the whole-image size even though the +docstring advertised an out-of-core pipeline. The function now +dispatches to a real Dask+CuPy graph that decodes one chunk window at +a time and uploads each block to the device, so peak GPU memory is +bounded by chunk size. + +Two paths back that promise. When ``kvikio`` is available and the file +is a local, tiled, chunky, non-sparse GeoTIFF with trivial orientation, +each chunk task pulls its tile subset directly from disk to GPU via +GDS. Otherwise the per-chunk window is decoded on CPU via the existing +``read_geotiff_dask`` graph and uploaded with ``cupy.asarray``. Both +paths keep peak device memory at one chunk. +""" +from __future__ import annotations + +import importlib.util + +import numpy as np +import pytest + + +def _gpu_available() -> bool: + if importlib.util.find_spec("cupy") is None: + return False + try: + import cupy + return bool(cupy.cuda.is_available()) + except Exception: + return False + + +def _kvikio_available() -> bool: + return importlib.util.find_spec("kvikio") is not None + + +_HAS_GPU = _gpu_available() +_HAS_KVIKIO = _kvikio_available() +_gpu_only = pytest.mark.skipif( + not _HAS_GPU, reason="cupy + CUDA required", +) +_gds_only = pytest.mark.skipif( + not (_HAS_GPU and _HAS_KVIKIO), + reason="cupy + CUDA + kvikio required for GDS path", +) + + +@pytest.fixture +def small_raster_path_1876(tmp_path): + from xrspatial.geotiff import to_geotiff + import xarray as xr + + arr = np.arange(32 * 32, dtype=np.float32).reshape(32, 32) + da = xr.DataArray(arr, dims=['y', 'x'], + attrs={'crs': 4326, + 'transform': (1.0, 0, 0, 0, -1.0, 32.0)}) + path = str(tmp_path / 'small_raster_1876.tif') + to_geotiff(da, path, compression='deflate', tile_size=16) + return path + + +@pytest.fixture +def multi_band_path_1876(tmp_path): + from xrspatial.geotiff import to_geotiff + import xarray as xr + + rng = np.random.RandomState(42) + arr = rng.rand(3, 32, 32).astype(np.float32) + da = xr.DataArray(arr, dims=['band', 'y', 'x'], + attrs={'crs': 4326, + 'transform': (1.0, 0, 0, 0, -1.0, 32.0)}) + path = str(tmp_path / 'multi_band_1876.tif') + to_geotiff(da, path, compression='deflate', tile_size=16) + return path + + +@_gpu_only +def test_read_geotiff_gpu_chunks_yields_dask_cupy_chunks(small_raster_path_1876): + """Each block of the returned dask array must be a cupy array, not + a numpy array and not a single eager cupy block.""" + import cupy + import dask.array as da_mod + + from xrspatial.geotiff import read_geotiff_gpu + + result = read_geotiff_gpu(small_raster_path_1876, chunks=8) + + assert isinstance(result.data, da_mod.Array), ( + f"expected dask Array, got {type(result.data).__name__}" + ) + assert isinstance(result.data._meta, cupy.ndarray), ( + f"expected cupy chunks, got meta={type(result.data._meta).__name__}" + ) + assert result.data.numblocks == (4, 4) + + block = result.data.blocks[0, 0].compute() + assert isinstance(block, cupy.ndarray) + assert block.shape == (8, 8) + + +@_gpu_only +def test_read_geotiff_gpu_chunks_values_match_eager(small_raster_path_1876): + """Lazy chunked result must equal the eager GPU result element-wise.""" + import cupy + + from xrspatial.geotiff import read_geotiff_gpu + + eager = read_geotiff_gpu(small_raster_path_1876) + chunked = read_geotiff_gpu(small_raster_path_1876, chunks=8) + + eager_np = cupy.asnumpy(eager.data) + chunked_np = cupy.asnumpy(chunked.compute().data) + np.testing.assert_array_equal(eager_np, chunked_np) + + +@_gpu_only +def test_read_geotiff_gpu_no_chunks_returns_eager_cupy(small_raster_path_1876): + """``chunks=None`` keeps the eager GPU decode path.""" + import cupy + + from xrspatial.geotiff import read_geotiff_gpu + + result = read_geotiff_gpu(small_raster_path_1876) + + assert isinstance(result.data, cupy.ndarray) + + +@_gpu_only +def test_open_geotiff_gpu_chunks_propagates_to_dask(small_raster_path_1876): + """``open_geotiff(gpu=True, chunks=...)`` must return the same + Dask+CuPy result as the direct read.""" + import cupy + import dask.array as da_mod + + from xrspatial.geotiff import open_geotiff + + result = open_geotiff(small_raster_path_1876, gpu=True, chunks=8) + + assert isinstance(result.data, da_mod.Array) + assert isinstance(result.data._meta, cupy.ndarray) + + +@_gpu_only +def test_read_geotiff_gpu_chunks_preserves_attrs(small_raster_path_1876): + """Geo attrs (transform, crs) must survive the dask path.""" + from xrspatial.geotiff import read_geotiff_gpu + + result = read_geotiff_gpu(small_raster_path_1876, chunks=8) + assert 'transform' in result.attrs + assert 'crs' in result.attrs + + +@_gds_only +def test_read_geotiff_gpu_chunks_uses_gds_path_when_available( + small_raster_path_1876, monkeypatch): + """When kvikio is installed and the file qualifies (local + tiled + + chunky + no sparse + orientation=1 + photometric!=0), each chunk + task must call the direct disk->GPU decoder rather than detouring + through ``read_geotiff_dask``.""" + from xrspatial.geotiff import read_geotiff_gpu + from xrspatial import geotiff as gtmod + + direct_calls = {'n': 0} + real_direct = gtmod._decode_window_gpu_direct + + def _spy(*args, **kwargs): + direct_calls['n'] += 1 + return real_direct(*args, **kwargs) + + monkeypatch.setattr(gtmod, '_decode_window_gpu_direct', _spy) + + result = read_geotiff_gpu(small_raster_path_1876, chunks=8) + result.compute() + + assert direct_calls['n'] == 16, ( + f"expected one disk->GPU call per chunk (4x4 = 16); " + f"got {direct_calls['n']}" + ) + + +@_gpu_only +def test_read_geotiff_gpu_chunks_window_subset(small_raster_path_1876): + """A window on the dask path produces the same values as a window + on the eager path.""" + import cupy + + from xrspatial.geotiff import read_geotiff_gpu + + eager = read_geotiff_gpu(small_raster_path_1876, window=(4, 4, 24, 28)) + chunked = read_geotiff_gpu(small_raster_path_1876, chunks=8, + window=(4, 4, 24, 28)) + + eager_np = cupy.asnumpy(eager.data) + chunked_np = cupy.asnumpy(chunked.compute().data) + assert eager_np.shape == (20, 24) + np.testing.assert_array_equal(eager_np, chunked_np) + + +@_gpu_only +def test_read_geotiff_gpu_chunks_multi_band(multi_band_path_1876): + """Multi-band tiled files chunk along (y, x) with a band axis.""" + import cupy + import dask.array as da_mod + + from xrspatial.geotiff import read_geotiff_gpu + + result = read_geotiff_gpu(multi_band_path_1876, chunks=16) + assert isinstance(result.data, da_mod.Array) + assert isinstance(result.data._meta, cupy.ndarray) + assert result.sizes['band'] == 3 + + eager = read_geotiff_gpu(multi_band_path_1876) + eager_np = cupy.asnumpy(eager.data) + chunked_np = cupy.asnumpy(result.compute().data) + np.testing.assert_allclose(eager_np, chunked_np, rtol=1e-5) + + +@_gpu_only +def test_read_geotiff_gpu_chunks_single_band_selection(multi_band_path_1876): + """``band=k`` collapses to a 2D Dask+CuPy DataArray.""" + import cupy + + from xrspatial.geotiff import read_geotiff_gpu + + result = read_geotiff_gpu(multi_band_path_1876, chunks=16, band=1) + assert result.ndim == 2 + assert isinstance(result.data._meta, cupy.ndarray) + + eager = read_geotiff_gpu(multi_band_path_1876, band=1) + eager_np = cupy.asnumpy(eager.data) + chunked_np = cupy.asnumpy(result.compute().data) + np.testing.assert_allclose(eager_np, chunked_np, rtol=1e-5) + + +@_gpu_only +def test_read_geotiff_gpu_chunks_fallback_when_kvikio_absent( + small_raster_path_1876, monkeypatch): + """When kvikio is reported missing, the chunked path falls back to + the CPU-decode + cupy.asarray graph and still produces a Dask+CuPy + DataArray with correct values.""" + import cupy + import importlib.util as _ilu + + from xrspatial.geotiff import read_geotiff_gpu + from xrspatial import geotiff as gtmod + + original_find_spec = _ilu.find_spec + + def _fake_find_spec(name, *a, **kw): + if name == 'kvikio': + return None + return original_find_spec(name, *a, **kw) + + monkeypatch.setattr(_ilu, 'find_spec', _fake_find_spec) + + direct_calls = {'n': 0} + real_direct = gtmod._decode_window_gpu_direct + + def _spy(*args, **kwargs): + direct_calls['n'] += 1 + return real_direct(*args, **kwargs) + + monkeypatch.setattr(gtmod, '_decode_window_gpu_direct', _spy) + + result = read_geotiff_gpu(small_raster_path_1876, chunks=8) + computed = result.compute() + assert direct_calls['n'] == 0 + assert isinstance(computed.data, cupy.ndarray) + + eager = read_geotiff_gpu(small_raster_path_1876) + np.testing.assert_array_equal( + cupy.asnumpy(eager.data), cupy.asnumpy(computed.data), + )