Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 186 additions & 0 deletions xrspatial/geotiff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3789,6 +3789,185 @@ def _gpu_compress_to_part(gpu_arr, w, h, spp):
_write_bytes(file_bytes, path)


def _vrt_effective_dtype(vrt, band):
    """Return the dtype a VRT read is expected to materialize."""
    bands = vrt.bands if band is None else [vrt.bands[band]]
    if not bands:
        raise ValueError(
            "VRT has no <VRTRasterBand> elements; cannot determine "
            "output dtype"
        )

    def _band_dtype(vrt_band):
        # Any source carrying a non-identity scale/offset forces float64
        # for that band, since the scaled values need not fit the
        # declared integer type.
        dt = vrt_band.dtype
        has_scaling = any(
            (src.scale is not None and src.scale != 1.0)
            or (src.offset is not None and src.offset != 0.0)
            for src in vrt_band.sources
        )
        if has_scaling:
            return np.dtype(np.float64)
        # An integer band with a representable integer nodata value is
        # promoted to float64 (presumably so nodata pixels can be masked
        # downstream); an unrepresentable nodata can never appear in the
        # data, so the band keeps its declared dtype.
        if dt.kind not in ('u', 'i') or vrt_band.nodata is None:
            return dt
        try:
            nodata = vrt_band.nodata
            if isinstance(nodata, (int, np.integer)):
                nd_int = int(nodata)
            else:
                nd_float = float(nodata)
                nd_int = (int(nd_float)
                          if np.isfinite(nd_float) and nd_float.is_integer()
                          else None)
            if nd_int is not None:
                limits = np.iinfo(dt)
                if limits.min <= nd_int <= limits.max:
                    return np.dtype(np.float64)
        except (TypeError, ValueError):
            # Unparseable nodata: fall back to the declared dtype.
            pass
        return dt

    return np.result_type(*(_band_dtype(b) for b in bands))


def _read_vrt_dask(source: str, *, dtype=None, window=None, band=None,
                   name=None, chunks=None, max_pixels=None,
                   missing_sources='warn'):
    """Build a truly lazy dask-backed VRT DataArray from window tasks.

    Each chunk becomes one delayed ``read_vrt`` call over the matching
    sub-window, so no source raster referenced by the VRT is opened
    until the resulting dask array is computed.

    Parameters
    ----------
    source : str
        Path to the VRT file.
    dtype : str or np.dtype, optional
        Requested output dtype; when given it is validated as a safe
        cast from the VRT's effective dtype.
    window : tuple of int, optional
        ``(row0, col0, row1, col1)`` half-open window in full-VRT pixel
        coordinates; ``None`` reads the whole extent.
    band : int, optional
        Zero-based band index; ``None`` reads all bands.
    name : str, optional
        Name for the returned DataArray; defaults to the source file's
        basename without extension.
    chunks : int or (int, int)
        Chunk height and width in pixels (a single int is used for
        both). Must not be ``None`` on this code path.
    max_pixels : int, optional
        Per-call pixel budget; defaults to ``MAX_PIXELS_DEFAULT``.
    missing_sources : {'warn', 'raise'}
        Forwarded to each per-chunk eager read.

    Returns
    -------
    xr.DataArray
        Lazy dask-backed array with ``y``/``x`` (and, for multi-band
        reads, ``band``) dims and CRS/transform/nodata attrs taken from
        the VRT metadata.

    Raises
    ------
    ValueError
        For an out-of-range ``band``, an invalid ``window``, or a chunk
        grid that would exceed the dask task cap.
    """
    import os
    import dask
    import dask.array as da
    from ._reader import _check_dimensions, MAX_PIXELS_DEFAULT
    from ._vrt import parse_vrt

    # Parse the VRT eagerly: only the XML itself is touched here, never
    # the source rasters it references.
    with open(source, 'r') as f:
        xml_str = f.read()
    vrt_dir = os.path.dirname(os.path.abspath(source))
    vrt = parse_vrt(xml_str, vrt_dir)

    if band is not None:
        # bool is an int subclass; reject it explicitly.
        if not isinstance(band, (int, np.integer)) or isinstance(band, bool):
            raise ValueError(f"band must be a non-negative int, got {band!r}")
        if band < 0 or band >= len(vrt.bands):
            raise ValueError(
                f"band index {band} out of range for VRT with "
                f"{len(vrt.bands)} band(s)")

    if window is not None:
        win_r0, win_c0, win_r1, win_c1 = window
        if (win_r0 < 0 or win_c0 < 0
                or win_r1 > vrt.height or win_c1 > vrt.width
                or win_r0 >= win_r1 or win_c0 >= win_c1):
            raise ValueError(
                f"window={window} is outside the VRT extent "
                f"({vrt.height}x{vrt.width}) or has non-positive size.")
    else:
        win_r0, win_c0, win_r1, win_c1 = 0, 0, vrt.height, vrt.width

    height = win_r1 - win_r0
    width = win_c1 - win_c0
    n_bands = len([vrt.bands[band]] if band is not None else vrt.bands)
    if max_pixels is None:
        max_pixels = MAX_PIXELS_DEFAULT
    _check_dimensions(width, height, n_bands, max_pixels)

    # Output dtype is either caller-forced (validated below) or derived
    # from the VRT band metadata.
    out_dtype = np.dtype(dtype) if dtype is not None else _vrt_effective_dtype(vrt, band)
    if dtype is not None:
        _validate_dtype_cast(_vrt_effective_dtype(vrt, band), out_dtype)

    if isinstance(chunks, int):
        ch_h = ch_w = chunks
    else:
        ch_h, ch_w = chunks

    # Match read_geotiff_dask's graph-size guard. Each VRT chunk becomes a
    # delayed task, so tiny chunks over very large VRT extents can OOM the
    # driver during graph construction before any source read executes.
    _MAX_DASK_CHUNKS = 50_000
    n_chunks = ((height + ch_h - 1) // ch_h) * ((width + ch_w - 1) // ch_w)
    if n_chunks > _MAX_DASK_CHUNKS:
        import math
        scale = math.sqrt(n_chunks / _MAX_DASK_CHUNKS)
        suggested_h = int(math.ceil(ch_h * scale))
        suggested_w = int(math.ceil(ch_w * scale))
        raise ValueError(
            f"read_vrt: chunks=({ch_h}, {ch_w}) on a {height}x{width} "
            f"VRT window would produce {n_chunks:,} dask tasks, exceeding "
            f"the {_MAX_DASK_CHUNKS:,}-task cap. Pass a larger chunks=... "
            f"value explicitly (e.g. chunks=({suggested_h}, "
            f"{suggested_w}) keeps the task count under the cap)."
        )

    rows = list(range(0, height, ch_h))
    cols = list(range(0, width, ch_w))
    out_has_band_axis = band is None and n_bands > 1

    @dask.delayed
    def _read_chunk(chunk_window):
        # Eager read of a single sub-window; runs only at compute time.
        chunk_da = read_vrt(
            source, dtype=dtype, window=chunk_window, band=band,
            chunks=None, gpu=False, max_pixels=max_pixels,
            missing_sources=missing_sources,
        )
        arr = np.asarray(chunk_da.values)
        if arr.dtype != out_dtype:
            # Defensive: keep every chunk at the graph-declared dtype.
            arr = arr.astype(out_dtype)
        return arr

    # Assemble the chunk grid row by row. Chunk windows are expressed in
    # full-VRT coordinates, hence the win_r0/win_c0 offsets.
    dask_rows = []
    for r0 in rows:
        r1 = min(r0 + ch_h, height)
        dask_cols = []
        for c0 in cols:
            c1 = min(c0 + ch_w, width)
            chunk_window = (r0 + win_r0, c0 + win_c0,
                            r1 + win_r0, c1 + win_c0)
            shape = ((r1 - r0, c1 - c0, n_bands)
                     if out_has_band_axis else (r1 - r0, c1 - c0))
            dask_cols.append(da.from_delayed(
                _read_chunk(chunk_window), shape=shape, dtype=out_dtype))
        dask_rows.append(da.concatenate(dask_cols, axis=1))
    dask_arr = da.concatenate(dask_rows, axis=0)

    # Coordinates: 'point' rasters use the origin as-is; otherwise a
    # half-pixel shift is applied (area convention), matching the eager
    # read path. Both are translated to the requested window.
    coords = {}
    gt = vrt.geo_transform
    if gt is not None:
        origin_x, res_x, _, origin_y, _, res_y = gt
        if vrt.raster_type == 'point':
            x_shift = win_c0 * res_x
            y_shift = win_r0 * res_y
        else:
            x_shift = (win_c0 + 0.5) * res_x
            y_shift = (win_r0 + 0.5) * res_y
        coords = {
            'x': np.arange(width, dtype=np.float64) * res_x + origin_x + x_shift,
            'y': np.arange(height, dtype=np.float64) * res_y + origin_y + y_shift,
        }

    attrs = {}
    if vrt.crs_wkt:
        epsg = _wkt_to_epsg(vrt.crs_wkt)
        if epsg is not None:
            attrs['crs'] = epsg
        attrs['crs_wkt'] = vrt.crs_wkt
    if vrt.raster_type == 'point':
        attrs['raster_type'] = 'point'
    if vrt.bands:
        # Multi-band reads report the first band's nodata value.
        band_idx_for_nodata = band if band is not None else 0
        nodata = vrt.bands[band_idx_for_nodata].nodata
        if nodata is not None:
            attrs['nodata'] = nodata
    if gt is not None:
        origin_x, res_x, _, origin_y, _, res_y = gt
        # Affine transform translated to the window's upper-left corner.
        attrs['transform'] = (
            float(res_x), 0.0, float(origin_x) + win_c0 * float(res_x),
            0.0, float(res_y), float(origin_y) + win_r0 * float(res_y),
        )

    if name is None:
        name = os.path.splitext(os.path.basename(source))[0]
    if out_has_band_axis:
        dims = ['y', 'x', 'band']
        coords['band'] = np.arange(n_bands)
    else:
        dims = ['y', 'x']
    return xr.DataArray(dask_arr, dims=dims, coords=coords,
                        name=name, attrs=attrs)


def read_vrt(source: str, *,
dtype: str | np.dtype | None = None,
window: tuple | None = None,
Expand Down Expand Up @@ -3876,6 +4055,13 @@ def read_vrt(source: str, *,
f"missing_sources must be 'warn' or 'raise', got "
f"{missing_sources!r}")

if chunks is not None and not gpu:
return _read_vrt_dask(
source, dtype=dtype, window=window, band=band, name=name,
chunks=chunks, max_pixels=max_pixels,
missing_sources=missing_sources,
)

arr, vrt = _read_vrt_internal(
source, window=window, band=band, max_pixels=max_pixels,
missing_sources=missing_sources,
Expand Down
19 changes: 14 additions & 5 deletions xrspatial/geotiff/_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1560,9 +1560,14 @@ def _read_tiles(data: bytes, ifd: IFD, header: TIFFHeader,
raise ValueError(
f"Invalid tile dimensions: TileWidth={tw}, TileLength={th}")

# Reject crafted tile dims that would force huge per-tile allocations.
# A single tile's decoded bytes must also fit under the pixel budget.
_check_dimensions(tw, th, samples, max_pixels)
# Reject crafted tile dims (e.g. TileWidth = 2**31). This guards the
# TIFF header against malformed values; it is not the caller's output
# budget. The output-window check below uses ``max_pixels`` and is
# what enforces the user's per-call memory cap. The source-read path
# under ``read_vrt`` (#1796) relies on that output check to honour a
# small caller ``max_pixels`` against a normal-tile source; see
# #1823.
_check_dimensions(tw, th, samples, MAX_PIXELS_DEFAULT)

# Per-tile compressed-byte cap (issue #1664). Same env var as the
# HTTP path. mmap slicing is bounded by the file size, but the slice
Expand Down Expand Up @@ -2016,10 +2021,14 @@ def _fetch_decode_cog_http_tiles(
# A windowed HTTP read of a multi-billion-pixel COG only allocates
# the window, so capping the full image would reject legitimate
# tiled reads. The full-image cap still applies for whole-file
# reads (window is None). The single-tile budget always applies.
# reads (window is None). The per-tile dim check below guards the
# TIFF header against absurd ``TileWidth`` / ``TileLength`` values
# (e.g. 2**31) and uses ``MAX_PIXELS_DEFAULT`` so a caller's small
# ``max_pixels`` -- intended as an output-window budget -- does not
# reject normal 256x256 tiles. See #1823.
if window is None:
_check_dimensions(width, height, samples, max_pixels)
_check_dimensions(tw, th, samples, max_pixels)
_check_dimensions(tw, th, samples, MAX_PIXELS_DEFAULT)

# Reject malformed TIFFs whose declared tile grid exceeds the supplied
# TileOffsets length. See issue #1219.
Expand Down
63 changes: 63 additions & 0 deletions xrspatial/geotiff/tests/test_read_vrt_lazy_chunks_1798.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""read_vrt(chunks=...) should build lazy window tasks (#1798)."""
from __future__ import annotations

import os
import warnings

import numpy as np
import pytest

from xrspatial.geotiff import to_geotiff, read_vrt


def _write_vrt(vrt_path, source_name):
    """Write a minimal single-band 6x4 Float32 VRT pointing at *source_name*.

    The source filename is declared ``relativeToVRT="1"``, so it is
    resolved relative to the VRT's own directory -- tests use this to
    reference either a neighbouring GeoTIFF or a deliberately missing
    file.
    """
    vrt_path.write_text(
        '<VRTDataset rasterXSize="6" rasterYSize="4">\n'
        ' <VRTRasterBand dataType="Float32" band="1">\n'
        ' <SimpleSource>\n'
        f' <SourceFilename relativeToVRT="1">{source_name}'
        '</SourceFilename>\n'
        ' <SourceBand>1</SourceBand>\n'
        ' <SrcRect xOff="0" yOff="0" xSize="6" ySize="4"/>\n'
        ' <DstRect xOff="0" yOff="0" xSize="6" ySize="4"/>\n'
        ' </SimpleSource>\n'
        ' </VRTRasterBand>\n'
        '</VRTDataset>\n'
    )


def test_read_vrt_chunks_matches_eager_values(tmp_path):
    """A chunked read must agree with the eager read, value for value."""
    data = np.arange(24, dtype=np.float32).reshape(4, 6)
    tif_path = tmp_path / "tmp_1798_source.tif"
    to_geotiff(data, str(tif_path), compression='none')
    vrt_path = tmp_path / "tmp_1798_source.vrt"
    _write_vrt(vrt_path, os.path.basename(tif_path))

    expected = read_vrt(str(vrt_path))
    chunked = read_vrt(str(vrt_path), chunks=2)

    # 4x6 extent with chunks=2 -> a 2x3 chunk grid of 2x2 tiles.
    assert chunked.data.chunks == ((2, 2), (2, 2, 2))
    np.testing.assert_array_equal(chunked.compute().values, expected.values)


def test_read_vrt_chunks_does_not_read_sources_during_construction(tmp_path):
    """Graph construction must not open (or warn about) source rasters."""
    vrt = tmp_path / "tmp_1798_missing_source.vrt"
    _write_vrt(vrt, "missing.tif")

    with warnings.catch_warnings(record=True) as caught:
        # Record every warning. Without this, the default filters (or a
        # "once" filter already triggered by an earlier test) could
        # swallow a missing-source warning and let the no-warning
        # assertion below pass vacuously.
        warnings.simplefilter("always")
        lazy = read_vrt(str(vrt), chunks=2)

    # No warning means the missing source was never touched while the
    # dask graph was built.
    assert caught == []
    assert hasattr(lazy.data, 'compute')


def test_read_vrt_chunks_rejects_excessive_task_count(tmp_path):
    """An oversized chunk grid is rejected before any graph is built."""
    vrt = tmp_path / "tmp_1798_huge_extent.vrt"
    huge_vrt_xml = (
        '<VRTDataset rasterXSize="100000" rasterYSize="100000">\n'
        ' <VRTRasterBand dataType="Byte" band="1"/>\n'
        '</VRTDataset>\n'
    )
    vrt.write_text(huge_vrt_xml)

    # 100000x100000 with chunks=1 would mean 10 billion tasks; the
    # guard must raise even though max_pixels allows the extent.
    with pytest.raises(ValueError, match="task cap"):
        read_vrt(str(vrt), chunks=1, max_pixels=20_000_000_000)
Loading
Loading