From 5c62e3093952720efc7fc839b95d49a2f80f19ba Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Wed, 13 May 2026 08:56:28 -0700 Subject: [PATCH 1/4] geotiff: build lazy dask graph for read_vrt(chunks=) (closes #1814) Before this change, read_vrt(chunks=...) materialised the full VRT mosaic on host RAM and then wrapped the resulting numpy array via .chunk(). chunks= gave no memory protection, and gpu=True + chunks= still assembled the entire mosaic on the CPU before moving to the device. The chunked path now parses the VRT XML once up front, builds one dask.delayed per destination chunk window, and assembles them via from_delayed + da.concatenate. Each task calls the existing VRT internal reader with its own window= so only the sources intersecting that window are decoded. With gpu=True each block calls cupy.asarray before returning, so the dask array is dask-on-cupy from the start. A task-count cap (50,000, matching read_geotiff_dask) refuses chunk grids that would build a scheduler-busting graph and suggests a larger chunks= value. attrs['vrt_holes'] is not populated for chunked reads; the GeoTIFFFallbackWarning still fires from each worker. --- xrspatial/geotiff/__init__.py | 398 +++++++++++++++++- .../tests/test_vrt_lazy_chunks_1814.py | 271 ++++++++++++ 2 files changed, 663 insertions(+), 6 deletions(-) create mode 100644 xrspatial/geotiff/tests/test_vrt_lazy_chunks_1814.py diff --git a/xrspatial/geotiff/__init__.py b/xrspatial/geotiff/__init__.py index 588cc763..26db9d8c 100644 --- a/xrspatial/geotiff/__init__.py +++ b/xrspatial/geotiff/__init__.py @@ -3858,6 +3858,15 @@ def read_vrt(source: str, *, ``relativeToVRT='1'`` source that escapes the VRT directory (e.g. ``../../etc/passwd`` or a symlink to a file outside the directory) is rejected regardless of the allowlist. + + Lazy chunked reads (issue #1814): when ``chunks=`` is set, the + returned DataArray wraps a dask graph that decodes one chunk + window per task. Construction does not materialise any pixels; + only the VRT XML is parsed. The eager read populates + ``attrs['vrt_holes']`` from skipped sources; the chunked path does + not aggregate per-task hole records, so that attribute is not set + when ``chunks=`` is used. Each worker still emits + ``GeoTIFFFallbackWarning`` for missing sources. """ from ._reader import _coerce_path from ._vrt import read_vrt as _read_vrt_internal @@ -3876,6 +3885,26 @@ def read_vrt(source: str, *, f"missing_sources must be 'warn' or 'raise', got " f"{missing_sources!r}") + # Lazy chunked path (issue #1814). The eager call below materialises + # the full mosaic on host RAM and then wraps the array via + # ``.chunk()``, so chunks= gave no memory protection and gpu=True + + # chunks= still assembled the full mosaic on the CPU before moving to + # the device. When chunks= is set, dispatch to a delayed-per-window + # builder so each task decodes only the sources intersecting its + # destination window. + if chunks is not None: + return _read_vrt_chunked( + source, + window=window, + band=band, + name=name, + chunks=chunks, + gpu=gpu, + dtype=dtype, + max_pixels=max_pixels, + missing_sources=missing_sources, + ) + arr, vrt = _read_vrt_internal( source, window=window, band=band, max_pixels=max_pixels, missing_sources=missing_sources, @@ -4070,14 +4099,371 @@ def _sentinel_for_dtype(nodata_val, dtype): result = xr.DataArray(arr, dims=dims, coords=coords, name=name, attrs=attrs) - # Chunk for Dask (or Dask+CuPy if gpu=True) - if chunks is not None: - if isinstance(chunks, int): - chunk_dict = {'y': chunks, 'x': chunks} + # ``chunks is not None`` is handled by ``_read_vrt_chunked`` higher up + # in this function (issue #1814); reaching this point implies the + # eager path, so no post-decode chunking is needed. + return result + + +# Hard cap on the per-VRT chunk task count. Matches the +# ``_MAX_DASK_CHUNKS`` value used by ``read_geotiff_dask`` so the two +# entry points refuse the same scheduler-busting chunk grids. See +# issue #1814. +_MAX_VRT_DASK_CHUNKS = 50_000 + + +def _vrt_chunk_read(source, r0, c0, r1, c1, *, + band, max_pixels, missing_sources, + declared_dtype, gpu): + """Decode a single chunk window from a VRT. + + Called by ``dask.delayed`` from :func:`_read_vrt_chunked`. The + function reads only the destination window via the existing VRT + internal reader, applies the same integer-sentinel masking the + eager :func:`read_vrt` does post-decode, casts to the dtype the + dask graph declared up front, and optionally moves the block to + the GPU. + + Returning a ``numpy.ndarray`` (or ``cupy.ndarray`` when ``gpu`` is + set) whose shape and dtype match the ``shape=`` / ``dtype=`` kwargs + of the surrounding ``dask.array.from_delayed`` is the contract; a + mismatch would silently produce a wrong-shape dask array. + """ + from ._vrt import read_vrt as _read_vrt_internal + + arr, vrt = _read_vrt_internal( + source, window=(r0, c0, r1, c1), band=band, + max_pixels=max_pixels, missing_sources=missing_sources, + ) + + # Mirror the eager post-decode integer-sentinel masking in + # ``read_vrt``. The internal reader NaN-masks float source arrays + # inline but leaves integer sentinels untouched, so the eager path + # promotes to float64 when sentinels hit. Apply the same logic per + # chunk; the surrounding dask graph already declared float64 when + # any band has a representable integer sentinel, so any chunk that + # actually fires the mask returns a buffer whose dtype matches the + # declared one. + if arr.dtype.kind in ('u', 'i'): + if arr.ndim == 3 and band is None and vrt.bands: + int_arr = arr + int_dtype = int_arr.dtype + for i, vrt_band in enumerate(vrt.bands): + if i >= int_arr.shape[-1]: + break + sentinel = _vrt_sentinel_for_dtype(vrt_band.nodata, int_dtype) + if sentinel is None: + continue + mask = int_arr[..., i] == sentinel + if not mask.any(): + continue + if arr.dtype != np.float64: + arr = arr.astype(np.float64) + arr[..., i][mask] = np.nan else: - chunk_dict = {'y': chunks[0], 'x': chunks[1]} - result = result.chunk(chunk_dict) + band_idx = band if band is not None else 0 + nodata = None + if vrt.bands and 0 <= band_idx < len(vrt.bands): + nodata = vrt.bands[band_idx].nodata + if nodata is not None: + sentinel = _vrt_sentinel_for_dtype(nodata, arr.dtype) + if sentinel is not None: + mask = arr == sentinel + if mask.any(): + arr = arr.astype(np.float64) + arr[mask] = np.nan + + if declared_dtype is not None and arr.dtype != declared_dtype: + arr = arr.astype(declared_dtype) + + if gpu: + import cupy + arr = cupy.asarray(arr) + + return arr + + +def _vrt_sentinel_for_dtype(nodata_val, dtype): + """Return ``dtype``-cast sentinel for ``nodata_val`` or None. + + Module-level twin of the closure ``_sentinel_for_dtype`` defined + inside :func:`read_vrt`. Lifted to module scope so the per-chunk + helper :func:`_vrt_chunk_read` can call it without paying the cost + of re-binding the closure on every block. + """ + if nodata_val is None or dtype.kind not in ('u', 'i'): + return None + info = np.iinfo(dtype) + if isinstance(nodata_val, (int, np.integer)) and not isinstance( + nodata_val, bool): + nodata_int = int(nodata_val) + if info.min <= nodata_int <= info.max: + return dtype.type(nodata_int) + return None + try: + nodata_f = float(nodata_val) + except (TypeError, ValueError): + return None + if not (np.isfinite(nodata_f) and nodata_f.is_integer() + and info.min <= nodata_f <= info.max): + return None + return dtype.type(int(nodata_f)) + + +def _read_vrt_chunked(source, *, window, band, name, chunks, gpu, dtype, + max_pixels, missing_sources): + """Lazy ``read_vrt`` dispatch when ``chunks=`` is set (issue #1814). + + Parses the VRT XML once to recover the extent, CRS, GeoTransform, + and per-band metadata, then builds a dask graph with one task per + chunk window. Each task calls into the existing VRT internal reader + with its own ``window=`` so only the sources intersecting the + chunk's destination rectangle are decoded. + + The eager :func:`read_vrt` populates ``attrs['vrt_holes']`` from + skipped sources; the chunked path does not aggregate per-task hole + records, so that attribute is not set here. The underlying + ``GeoTIFFFallbackWarning`` still fires from each worker when a + source is missing. + """ + import os as _os + import dask + import dask.array as da + + from ._reader import MAX_PIXELS_DEFAULT + from ._vrt import parse_vrt + + # Parse the VRT XML up-front (cheap; no pixel decode). + with open(source, 'r') as f: + xml_str = f.read() + vrt_dir = _os.path.dirname(_os.path.abspath(source)) + vrt = parse_vrt(xml_str, vrt_dir) + + # Validate ``band`` against the parsed band count, matching the + # internal reader's contract so the failure mode is the same whether + # the user reads eagerly or chunked. + if band is not None: + if not isinstance(band, (int, np.integer)) or isinstance(band, bool): + raise ValueError( + f"band must be a non-negative int, got {band!r}") + if band < 0 or band >= len(vrt.bands): + raise ValueError( + f"band index {band} out of range for VRT with " + f"{len(vrt.bands)} band(s)") + + # Resolve the windowed extent against the VRT. + if window is not None: + r0, c0, r1, c1 = window + if (r0 < 0 or c0 < 0 + or r1 > vrt.height or c1 > vrt.width + or r0 >= r1 or c0 >= c1): + raise ValueError( + f"window={window} is outside the VRT extent " + f"({vrt.height}x{vrt.width}) or has non-positive size.") + win_r0, win_c0 = r0, c0 + full_h, full_w = r1 - r0, c1 - c0 + else: + win_r0, win_c0 = 0, 0 + full_h, full_w = vrt.height, vrt.width + + max_pixels_effective = ( + max_pixels if max_pixels is not None else MAX_PIXELS_DEFAULT + ) + + # Up-front pixel-count guard against the windowed extent. Mirrors + # the eager ``_vrt.read_vrt`` (which calls ``_check_dimensions`` on + # the full output shape) and ``read_geotiff_dask`` (which guards + # ``full_h * full_w * eff_bands`` before scheduling any task). Each + # chunk task additionally re-checks via ``max_pixels`` through the + # internal reader, but catching an oversized request up front saves + # the caller from a misleading per-chunk error. + eff_bands = 1 if band is not None else max(1, len(vrt.bands)) + if full_h * full_w * eff_bands > max_pixels_effective: + raise ValueError( + f"Requested region {full_h}x{full_w}x{eff_bands} exceeds " + f"max_pixels={max_pixels_effective:,}.") + + if isinstance(chunks, int): + ch_h = ch_w = chunks + else: + ch_h, ch_w = chunks + + # Refuse chunk grids that would build more tasks than the scheduler + # can hold without OOMing the driver. ``read_geotiff_dask`` uses the + # same cap with the same suggestion logic (see issue #1814 and the + # ``_MAX_DASK_CHUNKS`` guard upstream). + n_chunks = ((full_h + ch_h - 1) // ch_h) * ((full_w + ch_w - 1) // ch_w) + if n_chunks > _MAX_VRT_DASK_CHUNKS: + scale = math.sqrt(n_chunks / _MAX_VRT_DASK_CHUNKS) + suggested_h = int(math.ceil(ch_h * scale)) + suggested_w = int(math.ceil(ch_w * scale)) + raise ValueError( + f"read_vrt: chunks=({ch_h}, {ch_w}) on a " + f"{full_h}x{full_w} VRT region would produce {n_chunks:,} " + f"dask tasks, exceeding the {_MAX_VRT_DASK_CHUNKS:,}-task " + f"cap. Pass a larger chunks=... value explicitly (e.g. " + f"chunks=({suggested_h}, {suggested_w}) keeps the task " + f"count under the cap)." + ) + + # Select bands for shape/dtype declaration. + if band is not None: + selected_bands = [vrt.bands[band]] + else: + selected_bands = vrt.bands + + if not selected_bands: + raise ValueError( + "VRT has no elements; cannot determine " + "output dtype") + + # Compute the declared dtype. Match the internal reader's + # ``np.result_type`` over per-band effective dtypes, then widen to + # float64 if any selected band has an integer dtype with a + # representable nodata sentinel (the eager path promotes that case + # on mask hits; declaring float64 up front keeps every block's + # dtype consistent with the dask array's metadata regardless of + # whether the chunk actually contains sentinel pixels). + effective_dtypes = [] + for vrt_band in selected_bands: + eff = vrt_band.dtype + for src in vrt_band.sources: + scaled = src.scale is not None and src.scale != 1.0 + offset = src.offset is not None and src.offset != 0.0 + if scaled or offset: + eff = np.dtype(np.float64) + break + effective_dtypes.append(eff) + declared_dtype = np.result_type(*effective_dtypes) + + if declared_dtype.kind in ('u', 'i'): + promotes = False + for vrt_band in selected_bands: + if _vrt_sentinel_for_dtype(vrt_band.nodata, + declared_dtype) is not None: + promotes = True + break + if promotes: + declared_dtype = np.dtype(np.float64) + + out_has_band_axis = band is None and len(vrt.bands) > 1 + n_out_bands = len(selected_bands) + + # Build the dask graph: one ``from_delayed`` per chunk window. The + # destination coordinate space is the VRT's full extent (or the + # windowed extent), so chunk windows are computed relative to that + # space and translated to absolute VRT coords before being passed + # into the per-chunk reader. + rows = list(range(0, full_h, ch_h)) + cols = list(range(0, full_w, ch_w)) + + delayed_read = dask.delayed(_vrt_chunk_read) + + if gpu: + import cupy + meta = cupy.empty((0,) * (3 if out_has_band_axis else 2), + dtype=declared_dtype) + else: + meta = np.empty((0,) * (3 if out_has_band_axis else 2), + dtype=declared_dtype) + + dask_rows = [] + for r0c in rows: + r1c = min(r0c + ch_h, full_h) + dask_cols = [] + for c0c in cols: + c1c = min(c0c + ch_w, full_w) + if out_has_band_axis: + block_shape = (r1c - r0c, c1c - c0c, n_out_bands) + else: + block_shape = (r1c - r0c, c1c - c0c) + d = delayed_read( + source, + r0c + win_r0, c0c + win_c0, + r1c + win_r0, c1c + win_c0, + band=band, + max_pixels=max_pixels_effective, + missing_sources=missing_sources, + declared_dtype=declared_dtype, + gpu=gpu, + ) + block = da.from_delayed(d, shape=block_shape, + dtype=declared_dtype, meta=meta) + dask_cols.append(block) + dask_rows.append(da.concatenate(dask_cols, axis=1)) + dask_arr = da.concatenate(dask_rows, axis=0) + + # Optional user-requested dtype cast happens lazily on the dask + # array so the per-chunk decode dtype stays predictable. + if dtype is not None: + target = np.dtype(dtype) + _validate_dtype_cast(declared_dtype, target) + dask_arr = dask_arr.astype(target) + final_dtype = target + else: + final_dtype = declared_dtype + + # Coordinates: derive from the VRT GeoTransform and the windowed + # extent. Mirrors the eager branch in ``read_vrt`` so chunked and + # eager reads share the same x/y arrays. + gt = vrt.geo_transform + coords = {} + attrs = {} + if gt is not None: + origin_x, res_x, _, origin_y, _, res_y = gt + if vrt.raster_type == 'point': + x_shift = win_c0 * res_x + y_shift = win_r0 * res_y + else: + x_shift = (win_c0 + 0.5) * res_x + y_shift = (win_r0 + 0.5) * res_y + x = np.arange(full_w, dtype=np.float64) * res_x + origin_x + x_shift + y = np.arange(full_h, dtype=np.float64) * res_y + origin_y + y_shift + coords['y'] = y + coords['x'] = x + origin_x_out = float(origin_x) + win_c0 * float(res_x) + origin_y_out = float(origin_y) + win_r0 * float(res_y) + attrs['transform'] = ( + float(res_x), 0.0, origin_x_out, + 0.0, float(res_y), origin_y_out, + ) + + if vrt.crs_wkt: + epsg = _wkt_to_epsg(vrt.crs_wkt) + if epsg is not None: + attrs['crs'] = epsg + attrs['crs_wkt'] = vrt.crs_wkt + if vrt.raster_type == 'point': + attrs['raster_type'] = 'point' + + # Surface the nodata sentinel for the selected band. The chunked + # path does not aggregate ``vrt.holes`` across tasks (per-task holes + # would need to be reduced by an extra delayed; not done here, see + # issue #1814 note in the docstring). + nodata_meta = None + if vrt.bands: + band_idx_for_nodata = band if band is not None else 0 + nodata_meta = vrt.bands[band_idx_for_nodata].nodata + if nodata_meta is not None: + attrs['nodata'] = nodata_meta + + if out_has_band_axis: + dims = ['y', 'x', 'band'] + coords['band'] = np.arange(n_out_bands) + else: + dims = ['y', 'x'] + + if name is None: + name = _os.path.splitext(_os.path.basename(source))[0] + + result = xr.DataArray( + dask_arr, dims=dims, coords=coords, name=name, attrs=attrs, + ) + # Sanity: the declared dtype on the dask array is what we return. + assert result.dtype == final_dtype, ( + f"internal: result dtype {result.dtype} != declared {final_dtype}" + ) return result diff --git a/xrspatial/geotiff/tests/test_vrt_lazy_chunks_1814.py b/xrspatial/geotiff/tests/test_vrt_lazy_chunks_1814.py new file mode 100644 index 00000000..c9563267 --- /dev/null +++ b/xrspatial/geotiff/tests/test_vrt_lazy_chunks_1814.py @@ -0,0 +1,271 @@ +"""Lazy chunked read_vrt builds a real dask graph (issue #1814). + +The pre-fix ``read_vrt(chunks=...)`` materialised the full VRT mosaic +on host RAM, then wrapped the resulting numpy array via ``.chunk()``. +That defeated the purpose of ``chunks=`` for memory protection and +made ``gpu=True`` + ``chunks=`` even worse: the entire mosaic was +moved to the device before chunking. + +These tests cover the new lazy path: + +* construction does not decode any pixels; +* per-chunk decode happens at ``.compute()`` time; +* the resulting array is byte-identical to the eager read; +* the chunk task count is bounded so a typo in ``chunks=`` cannot + build a graph the scheduler refuses to dispatch. +""" +from __future__ import annotations + +import os +import tempfile + +import dask.array as da +import numpy as np +import pytest +import xarray as xr + +from xrspatial.geotiff import read_vrt, to_geotiff +from xrspatial.geotiff._vrt import write_vrt as _write_vrt_internal + + +def _gpu_available() -> bool: + try: + import cupy # noqa: F401 + except ImportError: + return False + try: + import cupy + return bool(cupy.cuda.is_available()) + except Exception: + return False + + +_HAS_GPU = _gpu_available() + + +@pytest.fixture +def single_tile_vrt(): + """One 128x128 float32 tile wrapped in a VRT.""" + arr = np.arange(128 * 128, dtype=np.float32).reshape(128, 128) + y = np.linspace(41.0, 40.0, 128) + x = np.linspace(-106.0, -105.0, 128) + raster = xr.DataArray(arr, dims=['y', 'x'], + coords={'y': y, 'x': x}, + attrs={'crs': 4326}) + td = tempfile.mkdtemp(prefix='tmp_1814_single_') + tile_path = os.path.join(td, 'tile.tif') + to_geotiff(raster, tile_path) + vrt_path = os.path.join(td, 'mosaic.vrt') + _write_vrt_internal(vrt_path, [tile_path]) + yield vrt_path, arr + + +@pytest.fixture +def two_by_two_vrt(): + """4-tile mosaic via the to_geotiff(.vrt, ...) dask path.""" + arr = np.arange(256 * 256, dtype=np.float32).reshape(256, 256) + y = np.linspace(41.0, 40.0, 256) + x = np.linspace(-106.0, -105.0, 256) + raster = xr.DataArray(arr, dims=['y', 'x'], + coords={'y': y, 'x': x}, + attrs={'crs': 4326}) + td = tempfile.mkdtemp(prefix='tmp_1814_2x2_') + vrt_path = os.path.join(td, 'mosaic.vrt') + # ``tile_size=128`` produces a 2x2 mosaic of 128x128 tiles. + to_geotiff(raster, vrt_path, tile_size=128) + yield vrt_path, arr + + +@pytest.fixture +def multiband_vrt(): + """3-band single-tile VRT.""" + rng = np.random.default_rng(1814) + arr = rng.random((64, 64, 3), dtype=np.float32) + y = np.linspace(41.0, 40.0, 64) + x = np.linspace(-106.0, -105.0, 64) + raster = xr.DataArray( + arr, + dims=['y', 'x', 'band'], + coords={'y': y, 'x': x, 'band': np.arange(3)}, + attrs={'crs': 4326}, + ) + td = tempfile.mkdtemp(prefix='tmp_1814_mb_') + tile_path = os.path.join(td, 'tile.tif') + to_geotiff(raster, tile_path) + vrt_path = os.path.join(td, 'mosaic.vrt') + _write_vrt_internal(vrt_path, [tile_path]) + yield vrt_path, arr + + +# --------------------------------------------------------------------------- +# 1. Construction is lazy: no pixels are decoded before .compute(). +# --------------------------------------------------------------------------- + +def test_chunks_builds_dask_array_with_multiple_blocks(two_by_two_vrt): + """``read_vrt(chunks=(N,N))`` returns a dask-backed DataArray + whose underlying array has more than one chunk along each spatial + axis. Before the fix the array was numpy-backed under + ``result.chunk()``, so this asserts the new lazy graph is in + play. + """ + vrt_path, _ = two_by_two_vrt + result = read_vrt(vrt_path, chunks=(64, 64)) + assert isinstance(result.data, da.Array), ( + f"expected dask Array, got {type(result.data).__name__}" + ) + # 256 / 64 = 4 blocks per axis. + assert result.data.numblocks == (4, 4), ( + f"expected 4x4 blocks, got {result.data.numblocks}" + ) + + +def test_chunks_is_lazy_does_not_call_internal_reader(monkeypatch, + two_by_two_vrt): + """Construction-time call count of the internal VRT reader is zero; + after ``.compute()`` it equals the chunk count. + """ + vrt_path, _ = two_by_two_vrt + + from xrspatial.geotiff import _vrt as vrt_module + + counter = {'calls': 0} + real_read = vrt_module.read_vrt + + def counting_read(*args, **kwargs): + counter['calls'] += 1 + return real_read(*args, **kwargs) + + monkeypatch.setattr(vrt_module, 'read_vrt', counting_read) + + result = read_vrt(vrt_path, chunks=(64, 64)) + + assert counter['calls'] == 0, ( + f"_read_vrt_internal called {counter['calls']} times before " + f".compute(); the chunked path leaked an eager decode" + ) + + computed = result.compute() + # 4 row blocks * 4 col blocks = 16 expected decodes. + assert counter['calls'] == 16, ( + f"expected 16 per-chunk decodes after compute, got {counter['calls']}" + ) + assert computed.shape == (256, 256) + + +# --------------------------------------------------------------------------- +# 2. Byte-identical to the eager path. +# --------------------------------------------------------------------------- + +def test_chunked_compute_matches_eager(two_by_two_vrt): + vrt_path, _ = two_by_two_vrt + eager = read_vrt(vrt_path) + chunked = read_vrt(vrt_path, chunks=(64, 64)).compute() + assert eager.shape == chunked.shape + assert np.array_equal(eager.values, chunked.values), ( + "chunked compute diverged from eager read" + ) + # Coords and key attrs must match too. + np.testing.assert_array_equal(eager['x'].values, chunked['x'].values) + np.testing.assert_array_equal(eager['y'].values, chunked['y'].values) + assert eager.attrs.get('transform') == chunked.attrs.get('transform') + assert eager.attrs.get('crs') == chunked.attrs.get('crs') + + +def test_chunked_single_tile_matches_eager(single_tile_vrt): + """Single-tile VRT (one source) should still match eager when + chunked. Exercises the path where many chunk windows hit the + same single source. + """ + vrt_path, _ = single_tile_vrt + eager = read_vrt(vrt_path) + chunked = read_vrt(vrt_path, chunks=(32, 32)).compute() + assert np.array_equal(eager.values, chunked.values) + + +# --------------------------------------------------------------------------- +# 3. Task-count cap. +# --------------------------------------------------------------------------- + +def test_chunks_task_cap_raises(two_by_two_vrt): + """``chunks=(1, 1)`` on a 256x256 VRT would build 65,536 tasks, + blowing past the 50,000-task cap. The reader should refuse with + a ValueError that names ``chunks=`` and suggests a larger size. + """ + vrt_path, _ = two_by_two_vrt + with pytest.raises(ValueError, match=r"chunks=.*task"): + read_vrt(vrt_path, chunks=(1, 1)) + + +# --------------------------------------------------------------------------- +# 4. Window + chunks: chunks tile the window, not the full extent. +# --------------------------------------------------------------------------- + +def test_window_plus_chunks_matches_eager(two_by_two_vrt): + """When both ``window=`` and ``chunks=`` are passed, the dask + graph must tile the window (not the full VRT extent). The output + shape and pixel values match an eager windowed read. + """ + vrt_path, _ = two_by_two_vrt + window = (32, 48, 160, 192) # 128 high, 144 wide + + eager = read_vrt(vrt_path, window=window) + chunked = read_vrt(vrt_path, window=window, chunks=(64, 64)) + + assert isinstance(chunked.data, da.Array) + # The chunk grid is sized off the window extent (128, 144) with + # chunks=64 => (2, 3) numblocks. + assert chunked.data.numblocks == (2, 3), ( + f"expected (2, 3) numblocks over the window, got " + f"{chunked.data.numblocks}" + ) + + computed = chunked.compute() + assert computed.shape == eager.shape == (128, 144) + assert np.array_equal(eager.values, computed.values) + + +# --------------------------------------------------------------------------- +# 5. GPU + chunks: each block is a cupy array. +# --------------------------------------------------------------------------- + +@pytest.mark.skipif(not _HAS_GPU, reason="cupy + CUDA required") +def test_gpu_plus_chunks_returns_dask_on_cupy(two_by_two_vrt): + """``read_vrt(gpu=True, chunks=...)`` must build a dask graph whose + blocks are cupy-backed (not numpy that gets cupy-wrapped at + compute time on the host). + """ + import cupy + + vrt_path, _ = two_by_two_vrt + result = read_vrt(vrt_path, gpu=True, chunks=(64, 64)) + + assert isinstance(result.data, da.Array) + assert isinstance(result.data._meta, cupy.ndarray), ( + f"expected cupy _meta, got " + f"{type(result.data._meta).__module__}." + f"{type(result.data._meta).__name__}" + ) + computed = result.compute() + assert isinstance(computed.data, cupy.ndarray) + + +# --------------------------------------------------------------------------- +# 6. Multi-band VRT + chunks. +# --------------------------------------------------------------------------- + +def test_multiband_plus_chunks_preserves_band_dim(multiband_vrt): + """3-band VRT read with ``chunks=`` keeps the band dimension on + every block and the assembled DataArray. + """ + vrt_path, src = multiband_vrt + result = read_vrt(vrt_path, chunks=(32, 32)) + + assert isinstance(result.data, da.Array) + assert result.dims == ('y', 'x', 'band') + assert result.shape == (64, 64, 3) + # Per-block shape on the band axis is 3 (whole band axis in one + # chunk because we did not pass a band-chunk size). + assert result.data.chunks[2] == (3,) + + computed = result.compute() + np.testing.assert_allclose(computed.values, src, rtol=0, atol=0) From ec214f1f59464ebc42a1923ef922bcffffb75a12 Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Wed, 13 May 2026 10:20:07 -0700 Subject: [PATCH 2/4] geotiff: propagate vrt_holes and gate float64 promotion on chunked path (#1822 review) Address Copilot review comments on PR #1822: (c) The eager read_vrt populates attrs['vrt_holes'] as the machine-readable partial-mosaic detection contract from #1734. The chunked path silently omitted it. Populate the attr from a parse-time os.path.exists sweep over every source referenced by the parsed VRT so callers switching from eager to chunked keep the contract. The check is a static approximation that catches the dominant missing-file case; codec-error holes still surface as per-task GeoTIFFFallbackWarning. (d) Document the static "any band declares nodata?" promotion check as an explicit approximation of the eager path's runtime mask.any(). The gate was already correct (no promotion when no band declares nodata) but the surrounding comment did not call out the parse-time-vs-runtime trade-off. Add a regression test pinning the no-nodata uint16 case to the source dtype. Defer the duplicated eager logic and N+1 VRT-XML re-parse via tracking issue #1825; TODO markers reference that issue at both call sites. --- xrspatial/geotiff/__init__.py | 77 +++++++++++++--- .../tests/test_vrt_lazy_chunks_1814.py | 89 +++++++++++++++++++ 2 files changed, 152 insertions(+), 14 deletions(-) diff --git a/xrspatial/geotiff/__init__.py b/xrspatial/geotiff/__init__.py index 26db9d8c..a913bf75 100644 --- a/xrspatial/geotiff/__init__.py +++ b/xrspatial/geotiff/__init__.py @@ -4131,6 +4131,9 @@ def _vrt_chunk_read(source, r0, c0, r1, c1, *, """ from ._vrt import read_vrt as _read_vrt_internal + # TODO(#1825): this re-parses the VRT XML and re-validates source + # paths once per chunk task. Plumb the parent's parsed VRT through + # the task graph to remove the N+1 parse cost. arr, vrt = _read_vrt_internal( source, window=(r0, c0, r1, c1), band=band, max_pixels=max_pixels, missing_sources=missing_sources, @@ -4220,11 +4223,16 @@ def _read_vrt_chunked(source, *, window, band, name, chunks, gpu, dtype, with its own ``window=`` so only the sources intersecting the chunk's destination rectangle are decoded. - The eager :func:`read_vrt` populates ``attrs['vrt_holes']`` from - skipped sources; the chunked path does not aggregate per-task hole - records, so that attribute is not set here. The underlying - ``GeoTIFFFallbackWarning`` still fires from each worker when a - source is missing. + ``attrs['vrt_holes']`` is populated from a parse-time + ``os.path.exists`` sweep over every source referenced by the parsed + VRT; this preserves the eager-path contract documented in #1734 so + callers switching from eager to chunked can still detect partial + mosaics by attribute lookup (rather than monitoring the + ``GeoTIFFFallbackWarning`` stream). The check is a static + approximation of the eager path's per-source decode-time exception + handling: it catches the dominant "missing file" case but does not + detect decode-time codec failures, which surface as per-task + ``GeoTIFFFallbackWarning`` from each worker. """ import os as _os import dask @@ -4319,11 +4327,25 @@ def _read_vrt_chunked(source, *, window, band, name, chunks, gpu, dtype, # Compute the declared dtype. Match the internal reader's # ``np.result_type`` over per-band effective dtypes, then widen to - # float64 if any selected band has an integer dtype with a - # representable nodata sentinel (the eager path promotes that case - # on mask hits; declaring float64 up front keeps every block's - # dtype consistent with the dask array's metadata regardless of - # whether the chunk actually contains sentinel pixels). + # float64 only when at least one selected band declares an integer + # nodata sentinel that round-trips through the band's dtype. + # + # The eager path (``read_vrt`` at lines ~4033-4064) defers the + # promotion to runtime: it scans every band's pixels and promotes + # only if at least one sentinel hits. The chunked path cannot + # afford that scan up front (it would require decoding the mosaic + # the dask graph was constructed to defer), so this is a + # parse-time approximation. The trade-off: + # * if a band declares nodata and no chunk contains the + # sentinel, the chunked output is float64 where the eager + # output would have stayed integer (acceptable: the user + # asked the source for nodata, so they expect NaN masking); + # * if a band does not declare nodata, both paths keep the + # source integer dtype (handled by the ``promotes is False`` + # fall-through below). + # See also Copilot review on PR #1822. + # TODO(#1825): share dtype + scale/offset/sentinel logic with + # the eager path instead of re-implementing it here. effective_dtypes = [] for vrt_band in selected_bands: eff = vrt_band.dtype @@ -4437,10 +4459,7 @@ def _read_vrt_chunked(source, *, window, band, name, chunks, gpu, dtype, if vrt.raster_type == 'point': attrs['raster_type'] = 'point' - # Surface the nodata sentinel for the selected band. The chunked - # path does not aggregate ``vrt.holes`` across tasks (per-task holes - # would need to be reduced by an extra delayed; not done here, see - # issue #1814 note in the docstring). + # Surface the nodata sentinel for the selected band. nodata_meta = None if vrt.bands: band_idx_for_nodata = band if band is not None else 0 @@ -4448,6 +4467,36 @@ def _read_vrt_chunked(source, *, window, band, name, chunks, gpu, dtype, if nodata_meta is not None: attrs['nodata'] = nodata_meta + # Static hole detection: mirror the eager-path ``attrs['vrt_holes']`` + # contract (#1734) by scanning every source referenced in the parsed + # VRT and recording the ones whose backing file does not exist on + # disk. The eager path discovers holes at decode time (per-source + # OSError / codec error) and aggregates them onto ``vrt.holes``; + # under chunked dispatch each per-task decode catches its own + # missing source and warns, but those records cannot be reduced + # back onto the parent DataArray without an extra synchronisation + # pass. The parse-time existence sweep catches the dominant + # missing-file case before scheduling and lets callers branch on + # ``"vrt_holes" in da.attrs`` exactly as with the eager reader. + # Empty list is omitted so the attr only appears when a hole is + # actually present. Each entry mirrors the eager schema: + # ``{'source', 'band', 'dst_rect', 'error'}``. + # TODO(#1825): the per-task path independently re-parses and + # re-resolves source paths; refactor to share the parent's scan. + chunked_holes: list[dict] = [] + for vrt_band in vrt.bands: + for src in vrt_band.sources: + if not _os.path.exists(src.filename): + chunked_holes.append({ + 'source': src.filename, + 'band': vrt_band.band_num, + 'dst_rect': (src.dst_rect.x_off, src.dst_rect.y_off, + src.dst_rect.x_size, src.dst_rect.y_size), + 'error': 'FileNotFoundError: source file not found', + }) + if chunked_holes: + attrs['vrt_holes'] = chunked_holes + if out_has_band_axis: dims = ['y', 'x', 'band'] coords['band'] = np.arange(n_out_bands) diff --git a/xrspatial/geotiff/tests/test_vrt_lazy_chunks_1814.py b/xrspatial/geotiff/tests/test_vrt_lazy_chunks_1814.py index c9563267..36bb91fa 100644 --- a/xrspatial/geotiff/tests/test_vrt_lazy_chunks_1814.py +++ b/xrspatial/geotiff/tests/test_vrt_lazy_chunks_1814.py @@ -269,3 +269,92 @@ def test_multiband_plus_chunks_preserves_band_dim(multiband_vrt): computed = result.compute() np.testing.assert_allclose(computed.values, src, rtol=0, atol=0) + + +# --------------------------------------------------------------------------- +# 7. Copilot review: ``attrs['vrt_holes']`` must propagate to the chunked +# path so users switching from eager to chunked keep the #1734 contract. +# --------------------------------------------------------------------------- + +def test_chunked_propagates_vrt_holes_when_source_missing(two_by_two_vrt): + """When a source referenced by the VRT does not exist on disk the + chunked reader must populate ``attrs['vrt_holes']`` with the same + schema the eager reader uses, so callers can branch on + ``"vrt_holes" in da.attrs`` regardless of which code path produced + the DataArray. + """ + import warnings + from xrspatial.geotiff import GeoTIFFFallbackWarning + + vrt_path, _ = two_by_two_vrt + vrt_dir = os.path.dirname(vrt_path) + # Remove one of the four source tiles. ``to_geotiff(.vrt, tile_size=128)`` + # writes tile files into a ``_tiles/`` subdirectory next to the + # .vrt; walk the tree for any .tif and unlink the first one. + tile_files = [] + for root, _dirs, files in os.walk(vrt_dir): + for f in files: + if f.endswith('.tif'): + tile_files.append(os.path.join(root, f)) + assert len(tile_files) >= 1 + os.unlink(tile_files[0]) + + with warnings.catch_warnings(): + warnings.simplefilter('ignore', GeoTIFFFallbackWarning) + result = read_vrt(vrt_path, chunks=(64, 64)) + + assert 'vrt_holes' in result.attrs, ( + "chunked path dropped vrt_holes contract from #1734" + ) + holes = result.attrs['vrt_holes'] + assert isinstance(holes, list) and len(holes) >= 1 + entry = holes[0] + # Schema parity with the eager path (see read_vrt at ~line 3963). + assert set(entry.keys()) >= {'source', 'band', 'dst_rect', 'error'} + assert isinstance(entry['dst_rect'], tuple) + assert len(entry['dst_rect']) == 4 + + +def test_chunked_no_vrt_holes_attr_when_complete(two_by_two_vrt): + """When every source is on disk the chunked reader must not set + ``attrs['vrt_holes']`` (eager parity: empty hole list is omitted). + """ + vrt_path, _ = two_by_two_vrt + result = read_vrt(vrt_path, chunks=(64, 64)) + assert 'vrt_holes' not in result.attrs + + +# --------------------------------------------------------------------------- +# 8. Copilot review: integer source with no declared nodata must keep its +# integer dtype through the chunked path (no spurious float64 promotion). +# --------------------------------------------------------------------------- + +def test_chunked_integer_no_nodata_keeps_source_dtype(): + """A uint16 source with no declared must produce a + uint16 chunked DataArray, not float64. The eager path stays integer + in this case because its runtime ``mask.any()`` is False; the + chunked path approximates with a static "any band declares nodata?" + check, which yields the same answer here. + """ + arr = np.arange(128 * 128, dtype=np.uint16).reshape(128, 128) + y = np.linspace(41.0, 40.0, 128) + x = np.linspace(-106.0, -105.0, 128) + raster = xr.DataArray(arr, dims=['y', 'x'], + coords={'y': y, 'x': x}, + attrs={'crs': 4326}) + td = tempfile.mkdtemp(prefix='tmp_1814_uint16_nonodata_') + tile_path = os.path.join(td, 'tile.tif') + to_geotiff(raster, tile_path) + vrt_path = os.path.join(td, 'mosaic.vrt') + # No ``nodata=`` passed: the VRT will not declare for + # this band, exercising the no-promotion branch. + _write_vrt_internal(vrt_path, [tile_path]) + + result = read_vrt(vrt_path, chunks=(32, 32)) + assert result.dtype == np.uint16, ( + f"expected uint16 (source dtype), got {result.dtype}; " + f"chunked path promoted to float64 despite no declared nodata" + ) + computed = result.compute() + assert computed.dtype == np.uint16 + np.testing.assert_array_equal(computed.values, arr) From 0927282d4115123d9578efbc432f292f9cf8d7fd Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Wed, 13 May 2026 10:13:49 -0700 Subject: [PATCH 3/4] geotiff: per-tile dim check uses default cap, not caller budget (#1823) PR #1803 forwarded the caller's max_pixels to read_to_array inside read_vrt's source loop so a tiny VRT output cannot force a huge source decode (#1796). The output-window check at the source read enforces that correctly. A separate per-tile dimension check at the same call sites also consumed the caller's max_pixels, so a caller setting max_pixels as an output budget (e.g. 10_000) failed the per-tile sanity check on any normal source whose default tile size is 256x256 (= 65_536 pixels). Use MAX_PIXELS_DEFAULT for the per-tile dim check at the two call sites in _read_tiles (local) and _read_tiles_cog_http (HTTP). The output-window check at the same functions continues to enforce the user-supplied max_pixels, preserving the #1796 protection. --- xrspatial/geotiff/_reader.py | 19 ++- .../tests/test_vrt_source_tile_check_1823.py | 113 ++++++++++++++++++ 2 files changed, 127 insertions(+), 5 deletions(-) create mode 100644 xrspatial/geotiff/tests/test_vrt_source_tile_check_1823.py diff --git a/xrspatial/geotiff/_reader.py b/xrspatial/geotiff/_reader.py index b3da32bd..c174c3fc 100644 --- a/xrspatial/geotiff/_reader.py +++ b/xrspatial/geotiff/_reader.py @@ -1560,9 +1560,14 @@ def _read_tiles(data: bytes, ifd: IFD, header: TIFFHeader, raise ValueError( f"Invalid tile dimensions: TileWidth={tw}, TileLength={th}") - # Reject crafted tile dims that would force huge per-tile allocations. - # A single tile's decoded bytes must also fit under the pixel budget. - _check_dimensions(tw, th, samples, max_pixels) + # Reject crafted tile dims (e.g. TileWidth = 2**31). This guards the + # TIFF header against malformed values; it is not the caller's output + # budget. The output-window check below uses ``max_pixels`` and is + # what enforces the user's per-call memory cap. The source-read path + # under ``read_vrt`` (#1796) relies on that output check to honour a + # small caller ``max_pixels`` against a normal-tile source; see + # #1823. + _check_dimensions(tw, th, samples, MAX_PIXELS_DEFAULT) # Per-tile compressed-byte cap (issue #1664). Same env var as the # HTTP path. mmap slicing is bounded by the file size, but the slice @@ -2016,10 +2021,14 @@ def _fetch_decode_cog_http_tiles( # A windowed HTTP read of a multi-billion-pixel COG only allocates # the window, so capping the full image would reject legitimate # tiled reads. The full-image cap still applies for whole-file - # reads (window is None). The single-tile budget always applies. + # reads (window is None). The per-tile dim check below guards the + # TIFF header against absurd ``TileWidth`` / ``TileLength`` values + # (e.g. 2**31) and uses ``MAX_PIXELS_DEFAULT`` so a caller's small + # ``max_pixels`` -- intended as an output-window budget -- does not + # reject normal 256x256 tiles. See #1823. if window is None: _check_dimensions(width, height, samples, max_pixels) - _check_dimensions(tw, th, samples, max_pixels) + _check_dimensions(tw, th, samples, MAX_PIXELS_DEFAULT) # Reject malformed TIFFs whose declared tile grid exceeds the supplied # TileOffsets length. See issue #1219. diff --git a/xrspatial/geotiff/tests/test_vrt_source_tile_check_1823.py b/xrspatial/geotiff/tests/test_vrt_source_tile_check_1823.py new file mode 100644 index 00000000..b02dd2fb --- /dev/null +++ b/xrspatial/geotiff/tests/test_vrt_source_tile_check_1823.py @@ -0,0 +1,113 @@ +"""Regression tests for #1823. + +PR #1803 forwarded the caller's ``max_pixels`` to ``read_to_array`` inside +the VRT source loop so that a tiny VRT output could not force a huge +source decode (#1796). The output-window check enforces that. A separate +per-tile dimension check was incorrectly using the same ``max_pixels`` +value, so a caller setting ``max_pixels`` as an output budget (e.g. +10,000) would also fail the per-tile sanity check on every normal source +whose default tile size is 256x256 (= 65,536 pixels). + +The #1796 protection remains: the output-window check still catches a +tiny VRT output that asks for a large source window. +""" +from __future__ import annotations + +import os +import tempfile + +import numpy as np +import pytest + +from xrspatial.geotiff import to_geotiff +from xrspatial.geotiff._reader import PixelSafetyLimitError +from xrspatial.geotiff._vrt import read_vrt + + +def _write_normal_tile_source(td: str) -> str: + """10x10 uint8 source -- ``to_geotiff`` pads to a 256x256 tile.""" + src = os.path.join(td, 'src.tif') + to_geotiff(np.zeros((10, 10), dtype=np.uint8), src, compression='none') + return src + + +def _write_vrt(td: str, *, dst_x_size: int, dst_y_size: int, + raster_x: int = 100, raster_y: int = 100, + src_x_size: int = 10, src_y_size: int = 10) -> str: + vrt = os.path.join(td, 'mosaic.vrt') + xml = ( + f'\n' + f' \n' + f' \n' + f' src.tif\n' + f' 1\n' + f' \n' + f' \n' + f' \n' + f' \n' + f'\n' + ) + with open(vrt, 'w') as f: + f.write(xml) + return vrt + + +class TestPerTileCheckDoesNotUseCallerBudget: + """Per-tile dim sanity must not reject normal 256x256 source tiles + when the caller's ``max_pixels`` is a small output-budget value.""" + + def test_normal_tile_source_with_small_max_pixels(self): + with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as td: + _write_normal_tile_source(td) + vrt = _write_vrt(td, dst_x_size=100, dst_y_size=100) + arr, _ = read_vrt(vrt, max_pixels=10_000) + assert arr.shape == (100, 100) + + def test_normal_tile_source_with_tiny_max_pixels(self): + """An output budget below a single tile must still succeed when + the requested output window itself fits.""" + with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as td: + _write_normal_tile_source(td) + # Output 5x5 = 25 pixels; max_pixels = 100 fits 25 with room. + vrt = _write_vrt(td, dst_x_size=5, dst_y_size=5, + raster_x=5, raster_y=5) + arr, _ = read_vrt(vrt, max_pixels=100) + assert arr.shape == (5, 5) + + +class TestOutputWindowCheckStillEnforced: + """The output-window check at the source read still rejects an + over-budget read; the #1796 protection is preserved.""" + + def test_output_window_exceeds_max_pixels_still_rejected(self): + with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as td: + src = os.path.join(td, 'src.tif') + to_geotiff(np.arange(16, dtype=np.uint8).reshape(4, 4), + src, compression='none') + vrt = _write_vrt(td, dst_x_size=1, dst_y_size=1, + raster_x=1, raster_y=1, + src_x_size=4, src_y_size=4) + # SrcRect 4x4 = 16 pixels > max_pixels=1 → output check fires. + with pytest.raises(ValueError, match="exceed"): + read_vrt(vrt, max_pixels=1) + + +class TestPerTileCheckStillRejectsCraftedHeader: + """A pathological ``TileWidth``/``TileLength`` must still fail at + the per-tile sanity check, which uses ``MAX_PIXELS_DEFAULT``.""" + + def test_per_tile_check_caps_at_default(self, monkeypatch): + """Lower ``MAX_PIXELS_DEFAULT`` to verify the per-tile call site + is wired to it (rather than to the caller's budget).""" + from xrspatial.geotiff import _reader as reader_mod + + monkeypatch.setattr(reader_mod, "MAX_PIXELS_DEFAULT", 100) + with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as td: + _write_normal_tile_source(td) + vrt = _write_vrt(td, dst_x_size=100, dst_y_size=100) + # 256x256 tile > patched MAX_PIXELS_DEFAULT=100 → per-tile + # check fires regardless of caller's max_pixels (1e9 here). + with pytest.raises(PixelSafetyLimitError, match="65,536"): + read_vrt(vrt, max_pixels=1_000_000_000) From e8b9643f760fd9797bbf8314e59ef526322ee4fd Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Wed, 13 May 2026 20:52:34 -0700 Subject: [PATCH 4/4] geotiff: clear _mmap_cache before unlink in #1814 hole test The Windows 3.14 CI job failed with WinError 32 when the test tried to unlink a tile referenced by the VRT: write_vrt() opens each tile via _FileSource to read its header, and _FileSource.close() only decrements the refcount in the shared _mmap_cache -- the mmap and file handle remain idle in the cache. POSIX allows unlink of an mmap'd file; Windows does not. Call _mmap_cache.clear() (the cache's existing test helper, which drops idle entries) immediately before os.unlink. Linux/macOS are unaffected. --- xrspatial/geotiff/tests/test_vrt_lazy_chunks_1814.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/xrspatial/geotiff/tests/test_vrt_lazy_chunks_1814.py b/xrspatial/geotiff/tests/test_vrt_lazy_chunks_1814.py index 36bb91fa..c692bdcd 100644 --- a/xrspatial/geotiff/tests/test_vrt_lazy_chunks_1814.py +++ b/xrspatial/geotiff/tests/test_vrt_lazy_chunks_1814.py @@ -285,6 +285,7 @@ def test_chunked_propagates_vrt_holes_when_source_missing(two_by_two_vrt): """ import warnings from xrspatial.geotiff import GeoTIFFFallbackWarning + from xrspatial.geotiff._reader import _mmap_cache vrt_path, _ = two_by_two_vrt vrt_dir = os.path.dirname(vrt_path) @@ -297,6 +298,10 @@ def test_chunked_propagates_vrt_holes_when_source_missing(two_by_two_vrt): if f.endswith('.tif'): tile_files.append(os.path.join(root, f)) assert len(tile_files) >= 1 + # write_vrt() opens each tile via _FileSource to read its header; + # _FileSource.close() decrements the refcount but the mmap stays + # cached. On Windows an active mmap blocks os.unlink (WinError 32). + _mmap_cache.clear() os.unlink(tile_files[0]) with warnings.catch_warnings():