diff --git a/docs/source/reference/geotiff.rst b/docs/source/reference/geotiff.rst index 93d82685..e58a6f16 100644 --- a/docs/source/reference/geotiff.rst +++ b/docs/source/reference/geotiff.rst @@ -94,3 +94,14 @@ silently falls back to CPU. XRSPATIAL_GEOTIFF_STRICT=1 pytest xrspatial/geotiff/tests/ See issue #1662 for the audit and the full list of affected call sites. + +VRT missing sources +=================== + +``read_vrt`` accepts ``missing_sources='warn'`` or ``'raise'``. The default +``'warn'`` preserves the historical behavior: unreadable source files emit +:class:`xrspatial.geotiff.GeoTIFFFallbackWarning`, the returned DataArray +contains ``attrs['vrt_holes']``, and the mosaic is returned with holes. +Use ``missing_sources='raise'`` when a partial mosaic should fail the +pipeline immediately. ``XRSPATIAL_GEOTIFF_STRICT=1`` still raises in +``'warn'`` mode so CI environments can enforce fail-fast behavior globally. diff --git a/xrspatial/geotiff/__init__.py b/xrspatial/geotiff/__init__.py index 67a5fade..3b1ba3e4 100644 --- a/xrspatial/geotiff/__init__.py +++ b/xrspatial/geotiff/__init__.py @@ -3760,7 +3760,8 @@ def read_vrt(source: str, *, name: str | None = None, chunks: int | tuple | None = None, gpu: bool = False, - max_pixels: int | None = None) -> xr.DataArray: + max_pixels: int | None = None, + missing_sources: str = 'warn') -> xr.DataArray: """Read a GDAL Virtual Raster Table (.vrt) into an xarray.DataArray. The VRT's source GeoTIFFs are read via windowed reads and assembled @@ -3789,6 +3790,12 @@ def read_vrt(source: str, *, assembled VRT region. None uses the reader default (~1 billion). Matches ``open_geotiff`` / ``read_geotiff_dask`` / ``read_geotiff_gpu``. + missing_sources : {'warn', 'raise'}, default 'warn' + Policy for unreadable source files referenced by the VRT. ``'warn'`` + preserves the historical behavior: emit ``GeoTIFFFallbackWarning``, + record ``attrs['vrt_holes']``, and return a partial mosaic. + ``'raise'`` fails immediately. ``XRSPATIAL_GEOTIFF_STRICT=1`` also + raises, even when ``missing_sources='warn'``. Returns ------- @@ -3828,8 +3835,15 @@ def read_vrt(source: str, *, # default (eager read), so allow it through here. chunks = _validate_chunks_arg(chunks, allow_none=True) - arr, vrt = _read_vrt_internal(source, window=window, band=band, - max_pixels=max_pixels) + if missing_sources not in ('warn', 'raise'): + raise ValueError( + f"missing_sources must be 'warn' or 'raise', got " + f"{missing_sources!r}") + + arr, vrt = _read_vrt_internal( + source, window=window, band=band, max_pixels=max_pixels, + missing_sources=missing_sources, + ) if name is None: import os diff --git a/xrspatial/geotiff/_vrt.py b/xrspatial/geotiff/_vrt.py index 5a792ffa..74271f25 100644 --- a/xrspatial/geotiff/_vrt.py +++ b/xrspatial/geotiff/_vrt.py @@ -610,7 +610,8 @@ def _resample_nearest(src_arr: np.ndarray, def read_vrt(vrt_path: str, *, window=None, band: int | None = None, - max_pixels: int | None = None) -> tuple[np.ndarray, VRTDataset]: + max_pixels: int | None = None, + missing_sources: str = 'warn') -> tuple[np.ndarray, VRTDataset]: """Read a VRT file by assembling pixel data from its source files. Parameters @@ -621,6 +622,14 @@ def read_vrt(vrt_path: str, *, window=None, (row_start, col_start, row_stop, col_stop) for windowed read. band : int or None Band index (0-based). None returns all bands. + max_pixels : int or None + Maximum allowed pixel count (width * height * samples) for the + assembled VRT region. None uses the reader default. + missing_sources : {'warn', 'raise'} + Policy for unreadable source files referenced by the VRT. + ``'warn'`` emits ``GeoTIFFFallbackWarning`` and records + ``vrt.holes`` unless ``XRSPATIAL_GEOTIFF_STRICT=1`` is set. + ``'raise'`` fails immediately. Returns ------- @@ -633,6 +642,10 @@ def read_vrt(vrt_path: str, *, window=None, vrt_dir = os.path.dirname(os.path.abspath(vrt_path)) vrt = parse_vrt(xml_str, vrt_dir) + if missing_sources not in ('warn', 'raise'): + raise ValueError( + f"missing_sources must be 'warn' or 'raise', got " + f"{missing_sources!r}") # Validate ``band`` against the parsed band count. Python list # indexing would silently accept negative values (``vrt.bands[-1]`` @@ -892,13 +905,14 @@ def read_vrt(vrt_path: str, *, window=None, # See issue #1734. import warnings from . import _geotiff_strict_mode, GeoTIFFFallbackWarning - if _geotiff_strict_mode(): + if missing_sources == 'raise' or _geotiff_strict_mode(): raise warnings.warn( f"VRT source {src.filename!r} could not be read " f"({type(e).__name__}: {e}); skipping. The output " f"mosaic will have a hole at this tile. Inspect " f"``DataArray.attrs['vrt_holes']`` or set " + f"missing_sources='raise' or " f"XRSPATIAL_GEOTIFF_STRICT=1 to raise instead.", GeoTIFFFallbackWarning, stacklevel=2, diff --git a/xrspatial/geotiff/tests/test_vrt_missing_sources_policy_1799.py b/xrspatial/geotiff/tests/test_vrt_missing_sources_policy_1799.py new file mode 100644 index 00000000..217e62a2 --- /dev/null +++ b/xrspatial/geotiff/tests/test_vrt_missing_sources_policy_1799.py @@ -0,0 +1,51 @@ +"""VRT missing-source handling has an explicit policy (#1799).""" +from __future__ import annotations + +import pytest + +from xrspatial.geotiff import read_vrt +from xrspatial.geotiff import GeoTIFFFallbackWarning + + +def _write_missing_source_vrt(path): + path.write_text( + '\n' + ' \n' + ' \n' + ' missing.tif' + '\n' + ' 1\n' + ' \n' + ' \n' + ' \n' + ' \n' + '\n' + ) + + +def test_read_vrt_missing_sources_warns_and_records_hole(tmp_path): + vrt = tmp_path / "tmp_1799_missing_warn.vrt" + _write_missing_source_vrt(vrt) + + with pytest.warns(GeoTIFFFallbackWarning, match="could not be read"): + da = read_vrt(str(vrt), missing_sources='warn') + + assert 'vrt_holes' in da.attrs + assert da.attrs['vrt_holes'][0]['source'].endswith('missing.tif') + + +def test_read_vrt_missing_sources_raise_fails_fast(tmp_path): + vrt = tmp_path / "tmp_1799_missing_raise.vrt" + _write_missing_source_vrt(vrt) + + with pytest.raises((OSError, ValueError)): + read_vrt(str(vrt), missing_sources='raise') + + +def test_read_vrt_missing_sources_validates_policy(tmp_path): + vrt = tmp_path / "tmp_1799_missing_bad_policy.vrt" + _write_missing_source_vrt(vrt) + + with pytest.raises(ValueError, match="missing_sources"): + read_vrt(str(vrt), missing_sources='ignore') +