Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions xrspatial/reproject/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import numpy as np
import xarray as xr

from xrspatial.utils import _validate_raster

from ._crs_utils import _detect_nodata, _detect_source_crs, _resolve_crs
from ._grid import (
_chunk_bounds,
Expand Down Expand Up @@ -522,11 +524,8 @@ def reproject(
If vertical transformation was applied, ``attrs['vertical_crs']``
records the target vertical datum.
"""
if not isinstance(raster, xr.DataArray):
raise TypeError(
f"reproject(): raster must be an xr.DataArray, "
f"got {type(raster).__name__}"
)
_validate_raster(raster, func_name='reproject', name='raster',
ndim=(2, 3))

_validate_grid_params(
resolution=resolution,
Expand Down Expand Up @@ -1360,6 +1359,10 @@ def merge(
if not rasters:
raise ValueError("merge(): rasters list must not be empty")

for i, r in enumerate(rasters):
_validate_raster(r, func_name='merge', name=f'rasters[{i}]',
ndim=(2, 3))

_validate_grid_params(
resolution=resolution,
bounds=bounds,
Expand Down
5 changes: 5 additions & 0 deletions xrspatial/reproject/_vertical.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,11 @@ def geoid_height_raster(raster, model='EGM96'):
"""
import xarray as xr

from xrspatial.utils import _validate_raster

_validate_raster(raster, func_name='geoid_height_raster',
name='raster', ndim=(2, 3))

data, left, top, res_x, res_y, h, w = _load_geoid(model)

y = raster.coords[raster.dims[-2]].values.astype(np.float64)
Expand Down
47 changes: 46 additions & 1 deletion xrspatial/tests/test_reproject.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ def test_missing_crs_raises(self):

def test_non_dataarray_raises(self):
from xrspatial.reproject import reproject
with pytest.raises(TypeError, match="xr.DataArray"):
with pytest.raises(TypeError, match="xarray.DataArray"):
reproject(np.zeros((4, 4)), 'EPSG:4326')

def test_output_has_crs_attr(self):
Expand Down Expand Up @@ -1519,6 +1519,51 @@ def test_numpy_chunk_source_window_guard(self):
assert result.shape[0] > 0 and result.shape[1] > 0


# =====================================================================
# Issue #1431: _validate_raster on public API inputs
# =====================================================================

class TestValidateRasterInputs:
"""reproject(), merge(), geoid_height_raster() validate inputs (#1431)."""

def test_reproject_rejects_1d_dataarray(self):
from xrspatial.reproject import reproject
bad = xr.DataArray(np.zeros(5, dtype=np.float64), dims=('y',))
with pytest.raises(ValueError, match=r"must be 2D ?or 3D"):
reproject(bad, 'EPSG:4326')

def test_reproject_rejects_complex_dtype(self):
from xrspatial.reproject import reproject
bad = xr.DataArray(
np.zeros((4, 4), dtype=np.complex128),
dims=('y', 'x'),
coords={'y': np.arange(4), 'x': np.arange(4)},
)
with pytest.raises(ValueError, match="real numeric"):
reproject(bad, 'EPSG:4326')

def test_merge_rejects_non_dataarray_element(self):
from xrspatial.reproject import merge
good = xr.DataArray(
np.zeros((4, 4), dtype=np.float64),
dims=('y', 'x'),
coords={'y': np.arange(4), 'x': np.arange(4)},
)
with pytest.raises(TypeError, match="xarray.DataArray"):
merge([good, np.zeros((4, 4))])

def test_geoid_height_raster_rejects_non_dataarray(self):
from xrspatial.reproject import geoid_height_raster
with pytest.raises(TypeError, match="xarray.DataArray"):
geoid_height_raster(np.zeros((4, 4)))

def test_geoid_height_raster_rejects_1d_dataarray(self):
from xrspatial.reproject import geoid_height_raster
bad = xr.DataArray(np.zeros(5, dtype=np.float64), dims=('y',))
with pytest.raises(ValueError, match=r"must be 2D ?or 3D"):
geoid_height_raster(bad)


# =====================================================================
# Issue #1433: grid/bounds/precision parameter validation
# =====================================================================
Expand Down