diff --git a/xrspatial/reproject/__init__.py b/xrspatial/reproject/__init__.py index ac9f93cd..33292f81 100644 --- a/xrspatial/reproject/__init__.py +++ b/xrspatial/reproject/__init__.py @@ -14,6 +14,8 @@ import numpy as np import xarray as xr +from xrspatial.utils import _validate_raster + from ._crs_utils import _detect_nodata, _detect_source_crs, _resolve_crs from ._grid import ( _chunk_bounds, @@ -522,11 +524,8 @@ def reproject( If vertical transformation was applied, ``attrs['vertical_crs']`` records the target vertical datum. """ - if not isinstance(raster, xr.DataArray): - raise TypeError( - f"reproject(): raster must be an xr.DataArray, " - f"got {type(raster).__name__}" - ) + _validate_raster(raster, func_name='reproject', name='raster', + ndim=(2, 3)) _validate_grid_params( resolution=resolution, @@ -1360,6 +1359,10 @@ def merge( if not rasters: raise ValueError("merge(): rasters list must not be empty") + for i, r in enumerate(rasters): + _validate_raster(r, func_name='merge', name=f'rasters[{i}]', + ndim=(2, 3)) + _validate_grid_params( resolution=resolution, bounds=bounds, diff --git a/xrspatial/reproject/_vertical.py b/xrspatial/reproject/_vertical.py index e0013e54..430cea53 100644 --- a/xrspatial/reproject/_vertical.py +++ b/xrspatial/reproject/_vertical.py @@ -239,6 +239,11 @@ def geoid_height_raster(raster, model='EGM96'): """ import xarray as xr + from xrspatial.utils import _validate_raster + + _validate_raster(raster, func_name='geoid_height_raster', + name='raster', ndim=(2, 3)) + data, left, top, res_x, res_y, h, w = _load_geoid(model) y = raster.coords[raster.dims[-2]].values.astype(np.float64) diff --git a/xrspatial/tests/test_reproject.py b/xrspatial/tests/test_reproject.py index 704469c4..9fbd268c 100644 --- a/xrspatial/tests/test_reproject.py +++ b/xrspatial/tests/test_reproject.py @@ -507,7 +507,7 @@ def test_missing_crs_raises(self): def test_non_dataarray_raises(self): from xrspatial.reproject import reproject - with pytest.raises(TypeError, match="xr.DataArray"): + with pytest.raises(TypeError, match="xarray.DataArray"): reproject(np.zeros((4, 4)), 'EPSG:4326') def test_output_has_crs_attr(self): @@ -1519,6 +1519,51 @@ def test_numpy_chunk_source_window_guard(self): assert result.shape[0] > 0 and result.shape[1] > 0 +# ===================================================================== +# Issue #1431: _validate_raster on public API inputs +# ===================================================================== + +class TestValidateRasterInputs: + """reproject(), merge(), geoid_height_raster() validate inputs (#1431).""" + + def test_reproject_rejects_1d_dataarray(self): + from xrspatial.reproject import reproject + bad = xr.DataArray(np.zeros(5, dtype=np.float64), dims=('y',)) + with pytest.raises(ValueError, match=r"must be 2D ?or 3D"): + reproject(bad, 'EPSG:4326') + + def test_reproject_rejects_complex_dtype(self): + from xrspatial.reproject import reproject + bad = xr.DataArray( + np.zeros((4, 4), dtype=np.complex128), + dims=('y', 'x'), + coords={'y': np.arange(4), 'x': np.arange(4)}, + ) + with pytest.raises(ValueError, match="real numeric"): + reproject(bad, 'EPSG:4326') + + def test_merge_rejects_non_dataarray_element(self): + from xrspatial.reproject import merge + good = xr.DataArray( + np.zeros((4, 4), dtype=np.float64), + dims=('y', 'x'), + coords={'y': np.arange(4), 'x': np.arange(4)}, + ) + with pytest.raises(TypeError, match="xarray.DataArray"): + merge([good, np.zeros((4, 4))]) + + def test_geoid_height_raster_rejects_non_dataarray(self): + from xrspatial.reproject import geoid_height_raster + with pytest.raises(TypeError, match="xarray.DataArray"): + geoid_height_raster(np.zeros((4, 4))) + + def test_geoid_height_raster_rejects_1d_dataarray(self): + from xrspatial.reproject import geoid_height_raster + bad = xr.DataArray(np.zeros(5, dtype=np.float64), dims=('y',)) + with pytest.raises(ValueError, match=r"must be 2D ?or 3D"): + geoid_height_raster(bad) + + # ===================================================================== # Issue #1433: grid/bounds/precision parameter validation # =====================================================================