# NOTE(review): this span of the original patch also carried fragments that
# are incomplete in this view and are therefore not reproduced here:
#   (a) the tail of _gpu_compress_to_part (ends with _write_bytes(...)),
#   (b) the read_vrt() dispatch hunk that routes chunks=... and gpu=False
#       calls to _read_vrt_dask() before the eager _read_vrt_internal path,
#   (c) the start of the _reader.py hunks switching the per-tile dimension
#       sanity check from the caller's max_pixels to MAX_PIXELS_DEFAULT
#       (issue #1823).
# Those belong to functions whose full bodies are outside this view.


def _vrt_effective_dtype(vrt, band):
    """Return the dtype a VRT read is expected to materialize.

    Parameters
    ----------
    vrt : parsed VRT object
        Result of ``parse_vrt``; must expose ``bands``, each band having
        ``dtype`` (a ``np.dtype``), ``sources`` (each with ``scale`` and
        ``offset``) and ``nodata``.
    band : int or None
        Single band index to inspect, or ``None`` to combine all bands
        via ``np.result_type``.

    Raises
    ------
    ValueError
        If the VRT declares no bands at all.
    """
    selected = [vrt.bands[band]] if band is not None else vrt.bands
    if not selected:
        # NOTE(review): the original message lost an XML tag to paste
        # mangling; it presumably read "no <VRTRasterBand> elements".
        # Keeping the surviving text verbatim -- confirm upstream.
        raise ValueError(
            "VRT has no elements; cannot determine "
            "output dtype"
        )
    effective = []
    for vrt_band in selected:
        dt = vrt_band.dtype
        # Any source with a non-identity scale/offset forces a float64
        # read for that band.
        for src in vrt_band.sources:
            scaled = src.scale is not None and src.scale != 1.0
            offset = src.offset is not None and src.offset != 0.0
            if scaled or offset:
                dt = np.dtype(np.float64)
                break
        # An integer band whose nodata value is representable in its own
        # dtype is promoted to float64 -- presumably so nodata can be
        # surfaced as NaN without colliding with valid data (TODO confirm
        # against the eager reader). Non-integral or unparseable nodata
        # values leave the dtype unchanged.
        if dt.kind in ('u', 'i') and vrt_band.nodata is not None:
            try:
                if isinstance(vrt_band.nodata, (int, np.integer)):
                    nd = int(vrt_band.nodata)
                else:
                    nf = float(vrt_band.nodata)
                    nd = int(nf) if np.isfinite(nf) and nf.is_integer() else None
                if nd is not None:
                    info = np.iinfo(dt)
                    if info.min <= nd <= info.max:
                        dt = np.dtype(np.float64)
            except (TypeError, ValueError):
                # Best-effort: an exotic nodata representation simply
                # skips the promotion.
                pass
        effective.append(dt)
    return np.result_type(*effective)


def _read_vrt_dask(source: str, *, dtype=None, window=None, band=None,
                   name=None, chunks=None, max_pixels=None,
                   missing_sources='warn'):
    """Build a truly lazy dask-backed VRT DataArray from window tasks.

    Each chunk becomes a delayed call back into the eager ``read_vrt``
    path for just that window, so no source pixels are read until the
    graph is computed.

    Parameters
    ----------
    source : str
        Path to the ``.vrt`` file.
    dtype : optional
        Explicit output dtype; validated against the VRT's effective
        dtype via ``_validate_dtype_cast``.
    window : tuple or None
        ``(row0, col0, row1, col1)`` half-open output window in VRT
        pixel coordinates; ``None`` reads the full extent.
    band : int or None
        Single band to read, or ``None`` for all bands.
    name : str or None
        DataArray name; defaults to the source file's basename.
    chunks : int or (int, int)
        Chunk height/width in pixels.
    max_pixels : int or None
        Per-call output pixel budget; defaults to ``MAX_PIXELS_DEFAULT``.
    missing_sources : {'warn', 'raise'}
        Forwarded to the per-chunk eager reads.
    """
    import os
    import dask
    import dask.array as da
    from ._reader import _check_dimensions, MAX_PIXELS_DEFAULT
    from ._vrt import parse_vrt

    with open(source, 'r') as f:
        xml_str = f.read()
    vrt_dir = os.path.dirname(os.path.abspath(source))
    vrt = parse_vrt(xml_str, vrt_dir)

    if band is not None:
        # bool is an int subclass; reject it explicitly.
        if not isinstance(band, (int, np.integer)) or isinstance(band, bool):
            raise ValueError(f"band must be a non-negative int, got {band!r}")
        if band < 0 or band >= len(vrt.bands):
            raise ValueError(
                f"band index {band} out of range for VRT with "
                f"{len(vrt.bands)} band(s)")

    if window is not None:
        win_r0, win_c0, win_r1, win_c1 = window
        if (win_r0 < 0 or win_c0 < 0
                or win_r1 > vrt.height or win_c1 > vrt.width
                or win_r0 >= win_r1 or win_c0 >= win_c1):
            raise ValueError(
                f"window={window} is outside the VRT extent "
                f"({vrt.height}x{vrt.width}) or has non-positive size.")
    else:
        win_r0, win_c0, win_r1, win_c1 = 0, 0, vrt.height, vrt.width

    height = win_r1 - win_r0
    width = win_c1 - win_c0
    n_bands = len([vrt.bands[band]] if band is not None else vrt.bands)
    if max_pixels is None:
        max_pixels = MAX_PIXELS_DEFAULT
    _check_dimensions(width, height, n_bands, max_pixels)

    # Compute the effective dtype once (the original evaluated it twice
    # when an explicit dtype was supplied): it is both the default output
    # dtype and the "from" side of an explicit cast validation.
    effective_dtype = _vrt_effective_dtype(vrt, band)
    if dtype is not None:
        out_dtype = np.dtype(dtype)
        _validate_dtype_cast(effective_dtype, out_dtype)
    else:
        out_dtype = effective_dtype

    if isinstance(chunks, int):
        ch_h = ch_w = chunks
    else:
        ch_h, ch_w = chunks

    # Match read_geotiff_dask's graph-size guard. Each VRT chunk becomes a
    # delayed task, so tiny chunks over very large VRT extents can OOM the
    # driver during graph construction before any source read executes.
    _MAX_DASK_CHUNKS = 50_000
    n_chunks = ((height + ch_h - 1) // ch_h) * ((width + ch_w - 1) // ch_w)
    if n_chunks > _MAX_DASK_CHUNKS:
        import math
        scale = math.sqrt(n_chunks / _MAX_DASK_CHUNKS)
        suggested_h = int(math.ceil(ch_h * scale))
        suggested_w = int(math.ceil(ch_w * scale))
        raise ValueError(
            f"read_vrt: chunks=({ch_h}, {ch_w}) on a {height}x{width} "
            f"VRT window would produce {n_chunks:,} dask tasks, exceeding "
            f"the {_MAX_DASK_CHUNKS:,}-task cap. Pass a larger chunks=... "
            f"value explicitly (e.g. chunks=({suggested_h}, "
            f"{suggested_w}) keeps the task count under the cap)."
        )

    rows = list(range(0, height, ch_h))
    cols = list(range(0, width, ch_w))
    out_has_band_axis = band is None and n_bands > 1

    @dask.delayed
    def _read_chunk(chunk_window):
        # Re-enter the eager reader for just this window; gpu=False and
        # chunks=None guarantee a materialized numpy-backed result.
        chunk_da = read_vrt(
            source, dtype=dtype, window=chunk_window, band=band,
            chunks=None, gpu=False, max_pixels=max_pixels,
            missing_sources=missing_sources,
        )
        arr = np.asarray(chunk_da.values)
        if arr.dtype != out_dtype:
            arr = arr.astype(out_dtype)
        return arr

    # Assemble the chunk grid row-major; edge chunks are clipped to the
    # window, and chunk windows are shifted back into VRT coordinates.
    dask_rows = []
    for r0 in rows:
        r1 = min(r0 + ch_h, height)
        dask_cols = []
        for c0 in cols:
            c1 = min(c0 + ch_w, width)
            chunk_window = (r0 + win_r0, c0 + win_c0,
                            r1 + win_r0, c1 + win_c0)
            shape = ((r1 - r0, c1 - c0, n_bands)
                     if out_has_band_axis else (r1 - r0, c1 - c0))
            dask_cols.append(da.from_delayed(
                _read_chunk(chunk_window), shape=shape, dtype=out_dtype))
        dask_rows.append(da.concatenate(dask_cols, axis=1))
    dask_arr = da.concatenate(dask_rows, axis=0)

    coords = {}
    gt = vrt.geo_transform
    if gt is not None:
        origin_x, res_x, _, origin_y, _, res_y = gt
        # 'point' rasters anchor coordinates at the pixel corner; area
        # rasters at the pixel center (the extra half-pixel shift).
        if vrt.raster_type == 'point':
            x_shift = win_c0 * res_x
            y_shift = win_r0 * res_y
        else:
            x_shift = (win_c0 + 0.5) * res_x
            y_shift = (win_r0 + 0.5) * res_y
        coords = {
            'x': np.arange(width, dtype=np.float64) * res_x + origin_x + x_shift,
            'y': np.arange(height, dtype=np.float64) * res_y + origin_y + y_shift,
        }

    attrs = {}
    if vrt.crs_wkt:
        epsg = _wkt_to_epsg(vrt.crs_wkt)
        if epsg is not None:
            attrs['crs'] = epsg
        attrs['crs_wkt'] = vrt.crs_wkt
    if vrt.raster_type == 'point':
        attrs['raster_type'] = 'point'
    if vrt.bands:
        # With no explicit band, the first band's nodata is surfaced.
        band_idx_for_nodata = band if band is not None else 0
        nodata = vrt.bands[band_idx_for_nodata].nodata
        if nodata is not None:
            attrs['nodata'] = nodata
    if gt is not None:
        origin_x, res_x, _, origin_y, _, res_y = gt
        # Affine transform translated to the window's top-left corner.
        attrs['transform'] = (
            float(res_x), 0.0, float(origin_x) + win_c0 * float(res_x),
            0.0, float(res_y), float(origin_y) + win_r0 * float(res_y),
        )

    if name is None:
        name = os.path.splitext(os.path.basename(source))[0]
    if out_has_band_axis:
        dims = ['y', 'x', 'band']
        coords['band'] = np.arange(n_bands)
    else:
        dims = ['y', 'x']
    return xr.DataArray(dask_arr, dims=dims, coords=coords,
                        name=name, attrs=attrs)
mmap slicing is bounded by the file size, but the slice @@ -2016,10 +2021,14 @@ def _fetch_decode_cog_http_tiles( # A windowed HTTP read of a multi-billion-pixel COG only allocates # the window, so capping the full image would reject legitimate # tiled reads. The full-image cap still applies for whole-file - # reads (window is None). The single-tile budget always applies. + # reads (window is None). The per-tile dim check below guards the + # TIFF header against absurd ``TileWidth`` / ``TileLength`` values + # (e.g. 2**31) and uses ``MAX_PIXELS_DEFAULT`` so a caller's small + # ``max_pixels`` -- intended as an output-window budget -- does not + # reject normal 256x256 tiles. See #1823. if window is None: _check_dimensions(width, height, samples, max_pixels) - _check_dimensions(tw, th, samples, max_pixels) + _check_dimensions(tw, th, samples, MAX_PIXELS_DEFAULT) # Reject malformed TIFFs whose declared tile grid exceeds the supplied # TileOffsets length. See issue #1219. diff --git a/xrspatial/geotiff/tests/test_read_vrt_lazy_chunks_1798.py b/xrspatial/geotiff/tests/test_read_vrt_lazy_chunks_1798.py new file mode 100644 index 00000000..d4cb89d8 --- /dev/null +++ b/xrspatial/geotiff/tests/test_read_vrt_lazy_chunks_1798.py @@ -0,0 +1,63 @@ +"""read_vrt(chunks=...) 
should build lazy window tasks (#1798).""" +from __future__ import annotations + +import os +import warnings + +import numpy as np +import pytest + +from xrspatial.geotiff import to_geotiff, read_vrt + + +def _write_vrt(vrt_path, source_name): + vrt_path.write_text( + '\n' + ' \n' + ' \n' + f' {source_name}' + '\n' + ' 1\n' + ' \n' + ' \n' + ' \n' + ' \n' + '\n' + ) + + +def test_read_vrt_chunks_matches_eager_values(tmp_path): + arr = np.arange(24, dtype=np.float32).reshape(4, 6) + src = tmp_path / "tmp_1798_source.tif" + to_geotiff(arr, str(src), compression='none') + vrt = tmp_path / "tmp_1798_source.vrt" + _write_vrt(vrt, os.path.basename(src)) + + eager = read_vrt(str(vrt)) + lazy = read_vrt(str(vrt), chunks=2) + + assert lazy.data.chunks == ((2, 2), (2, 2, 2)) + np.testing.assert_array_equal(lazy.compute().values, eager.values) + + +def test_read_vrt_chunks_does_not_read_sources_during_construction(tmp_path): + vrt = tmp_path / "tmp_1798_missing_source.vrt" + _write_vrt(vrt, "missing.tif") + + with warnings.catch_warnings(record=True) as caught: + lazy = read_vrt(str(vrt), chunks=2) + + assert caught == [] + assert hasattr(lazy.data, 'compute') + + +def test_read_vrt_chunks_rejects_excessive_task_count(tmp_path): + vrt = tmp_path / "tmp_1798_huge_extent.vrt" + vrt.write_text( + '\n' + ' \n' + '\n' + ) + + with pytest.raises(ValueError, match="task cap"): + read_vrt(str(vrt), chunks=1, max_pixels=20_000_000_000) diff --git a/xrspatial/geotiff/tests/test_vrt_source_tile_check_1823.py b/xrspatial/geotiff/tests/test_vrt_source_tile_check_1823.py new file mode 100644 index 00000000..b02dd2fb --- /dev/null +++ b/xrspatial/geotiff/tests/test_vrt_source_tile_check_1823.py @@ -0,0 +1,113 @@ +"""Regression tests for #1823. + +PR #1803 forwarded the caller's ``max_pixels`` to ``read_to_array`` inside +the VRT source loop so that a tiny VRT output could not force a huge +source decode (#1796). The output-window check enforces that. 
A separate +per-tile dimension check was incorrectly using the same ``max_pixels`` +value, so a caller setting ``max_pixels`` as an output budget (e.g. +10,000) would also fail the per-tile sanity check on every normal source +whose default tile size is 256x256 (= 65,536 pixels). + +The #1796 protection remains: the output-window check still catches a +tiny VRT output that asks for a large source window. +""" +from __future__ import annotations + +import os +import tempfile + +import numpy as np +import pytest + +from xrspatial.geotiff import to_geotiff +from xrspatial.geotiff._reader import PixelSafetyLimitError +from xrspatial.geotiff._vrt import read_vrt + + +def _write_normal_tile_source(td: str) -> str: + """10x10 uint8 source -- ``to_geotiff`` pads to a 256x256 tile.""" + src = os.path.join(td, 'src.tif') + to_geotiff(np.zeros((10, 10), dtype=np.uint8), src, compression='none') + return src + + +def _write_vrt(td: str, *, dst_x_size: int, dst_y_size: int, + raster_x: int = 100, raster_y: int = 100, + src_x_size: int = 10, src_y_size: int = 10) -> str: + vrt = os.path.join(td, 'mosaic.vrt') + xml = ( + f'\n' + f' \n' + f' \n' + f' src.tif\n' + f' 1\n' + f' \n' + f' \n' + f' \n' + f' \n' + f'\n' + ) + with open(vrt, 'w') as f: + f.write(xml) + return vrt + + +class TestPerTileCheckDoesNotUseCallerBudget: + """Per-tile dim sanity must not reject normal 256x256 source tiles + when the caller's ``max_pixels`` is a small output-budget value.""" + + def test_normal_tile_source_with_small_max_pixels(self): + with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as td: + _write_normal_tile_source(td) + vrt = _write_vrt(td, dst_x_size=100, dst_y_size=100) + arr, _ = read_vrt(vrt, max_pixels=10_000) + assert arr.shape == (100, 100) + + def test_normal_tile_source_with_tiny_max_pixels(self): + """An output budget below a single tile must still succeed when + the requested output window itself fits.""" + with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) 
as td: + _write_normal_tile_source(td) + # Output 5x5 = 25 pixels; max_pixels = 100 fits 25 with room. + vrt = _write_vrt(td, dst_x_size=5, dst_y_size=5, + raster_x=5, raster_y=5) + arr, _ = read_vrt(vrt, max_pixels=100) + assert arr.shape == (5, 5) + + +class TestOutputWindowCheckStillEnforced: + """The output-window check at the source read still rejects an + over-budget read; the #1796 protection is preserved.""" + + def test_output_window_exceeds_max_pixels_still_rejected(self): + with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as td: + src = os.path.join(td, 'src.tif') + to_geotiff(np.arange(16, dtype=np.uint8).reshape(4, 4), + src, compression='none') + vrt = _write_vrt(td, dst_x_size=1, dst_y_size=1, + raster_x=1, raster_y=1, + src_x_size=4, src_y_size=4) + # SrcRect 4x4 = 16 pixels > max_pixels=1 → output check fires. + with pytest.raises(ValueError, match="exceed"): + read_vrt(vrt, max_pixels=1) + + +class TestPerTileCheckStillRejectsCraftedHeader: + """A pathological ``TileWidth``/``TileLength`` must still fail at + the per-tile sanity check, which uses ``MAX_PIXELS_DEFAULT``.""" + + def test_per_tile_check_caps_at_default(self, monkeypatch): + """Lower ``MAX_PIXELS_DEFAULT`` to verify the per-tile call site + is wired to it (rather than to the caller's budget).""" + from xrspatial.geotiff import _reader as reader_mod + + monkeypatch.setattr(reader_mod, "MAX_PIXELS_DEFAULT", 100) + with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as td: + _write_normal_tile_source(td) + vrt = _write_vrt(td, dst_x_size=100, dst_y_size=100) + # 256x256 tile > patched MAX_PIXELS_DEFAULT=100 → per-tile + # check fires regardless of caller's max_pixels (1e9 here). + with pytest.raises(PixelSafetyLimitError, match="65,536"): + read_vrt(vrt, max_pixels=1_000_000_000)