From 4923288323cfc97bbb4f7d02ca8227255e66e598 Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Sun, 3 May 2026 16:07:51 -0700 Subject: [PATCH] Validate raster inputs in reproject public APIs (#1431) reproject() and geoid_height_raster() previously checked only isinstance(raster, xr.DataArray) and then accessed raster.coords / raster.dims, so 1-D DataArrays and complex-dtype rasters failed inside helpers with cryptic IndexError / KeyError. merge() did not validate elements at all, so a non-DataArray entry surfaced as AttributeError on .data. Replace the inline isinstance guard in reproject() and add per- element validation in merge() with _validate_raster (ndim=(2, 3)). Add the same call at the top of geoid_height_raster. Update existing test_non_dataarray_raises to match the new error message. 5 new tests in TestValidateRasterInputs. --- xrspatial/reproject/__init__.py | 13 +++++---- xrspatial/reproject/_vertical.py | 5 ++++ xrspatial/tests/test_reproject.py | 47 ++++++++++++++++++++++++++++++- 3 files changed, 59 insertions(+), 6 deletions(-) diff --git a/xrspatial/reproject/__init__.py b/xrspatial/reproject/__init__.py index 5d041bf2..a3a1abc9 100644 --- a/xrspatial/reproject/__init__.py +++ b/xrspatial/reproject/__init__.py @@ -14,6 +14,8 @@ import numpy as np import xarray as xr +from xrspatial.utils import _validate_raster + from ._crs_utils import _detect_nodata, _detect_source_crs, _resolve_crs from ._grid import ( _chunk_bounds, @@ -521,11 +523,8 @@ def reproject( If vertical transformation was applied, ``attrs['vertical_crs']`` records the target vertical datum. """ - if not isinstance(raster, xr.DataArray): - raise TypeError( - f"reproject(): raster must be an xr.DataArray, " - f"got {type(raster).__name__}" - ) + _validate_raster(raster, func_name='reproject', name='raster', + ndim=(2, 3)) _validate_resampling(resampling) @@ -1350,6 +1349,10 @@ def merge( if not rasters: raise ValueError("merge(): rasters list must not be empty") + for i, r in enumerate(rasters): + _validate_raster(r, func_name='merge', name=f'rasters[{i}]', + ndim=(2, 3)) + _validate_resampling(resampling) _validate_strategy(strategy) diff --git a/xrspatial/reproject/_vertical.py b/xrspatial/reproject/_vertical.py index 2762db46..9853e339 100644 --- a/xrspatial/reproject/_vertical.py +++ b/xrspatial/reproject/_vertical.py @@ -226,6 +226,11 @@ def geoid_height_raster(raster, model='EGM96'): """ import xarray as xr + from xrspatial.utils import _validate_raster + + _validate_raster(raster, func_name='geoid_height_raster', + name='raster', ndim=(2, 3)) + data, left, top, res_x, res_y, h, w = _load_geoid(model) y = raster.coords[raster.dims[-2]].values.astype(np.float64) diff --git a/xrspatial/tests/test_reproject.py b/xrspatial/tests/test_reproject.py index a3ad6b8e..5cedacbd 100644 --- a/xrspatial/tests/test_reproject.py +++ b/xrspatial/tests/test_reproject.py @@ -507,7 +507,7 @@ def test_missing_crs_raises(self): def test_non_dataarray_raises(self): from xrspatial.reproject import reproject - with pytest.raises(TypeError, match="xr.DataArray"): + with pytest.raises(TypeError, match="xarray.DataArray"): reproject(np.zeros((4, 4)), 'EPSG:4326') def test_output_has_crs_attr(self): @@ -1517,3 +1517,48 @@ def test_numpy_chunk_source_window_guard(self): ) result = reproject(raster, target_crs='EPSG:3857') assert result.shape[0] > 0 and result.shape[1] > 0 + + +# ===================================================================== +# Issue #1431: _validate_raster on public API inputs +# ===================================================================== + +class TestValidateRasterInputs: + """reproject(), merge(), geoid_height_raster() validate inputs (#1431).""" + + def test_reproject_rejects_1d_dataarray(self): + from xrspatial.reproject import reproject + bad = xr.DataArray(np.zeros(5, dtype=np.float64), dims=('y',)) + with pytest.raises(ValueError, match=r"must be 2D ?or 3D"): + reproject(bad, 'EPSG:4326') + + def test_reproject_rejects_complex_dtype(self): + from xrspatial.reproject import reproject + bad = xr.DataArray( + np.zeros((4, 4), dtype=np.complex128), + dims=('y', 'x'), + coords={'y': np.arange(4), 'x': np.arange(4)}, + ) + with pytest.raises(ValueError, match="real numeric"): + reproject(bad, 'EPSG:4326') + + def test_merge_rejects_non_dataarray_element(self): + from xrspatial.reproject import merge + good = xr.DataArray( + np.zeros((4, 4), dtype=np.float64), + dims=('y', 'x'), + coords={'y': np.arange(4), 'x': np.arange(4)}, + ) + with pytest.raises(TypeError, match="xarray.DataArray"): + merge([good, np.zeros((4, 4))]) + + def test_geoid_height_raster_rejects_non_dataarray(self): + from xrspatial.reproject import geoid_height_raster + with pytest.raises(TypeError, match="xarray.DataArray"): + geoid_height_raster(np.zeros((4, 4))) + + def test_geoid_height_raster_rejects_1d_dataarray(self): + from xrspatial.reproject import geoid_height_raster + bad = xr.DataArray(np.zeros(5, dtype=np.float64), dims=('y',)) + with pytest.raises(ValueError, match=r"must be 2D ?or 3D"): + geoid_height_raster(bad)