diff --git a/xrspatial/pathfinding.py b/xrspatial/pathfinding.py index 76beb12e..9e836fdf 100644 --- a/xrspatial/pathfinding.py +++ b/xrspatial/pathfinding.py @@ -16,12 +16,18 @@ from xrspatial.cost_distance import _heap_push, _heap_pop from xrspatial.utils import ( + _validate_raster, get_dataarray_resolution, ngjit, has_cuda_and_cupy, is_cupy_array, is_dask_cupy, has_dask_array, ) NONE = -1 +# Maximum waypoint count for multi_stop_search. optimize_order builds an +# N x N distance matrix and runs N(N-1)/2 A* calls (O(N^3) when stitched +# with 2-opt), so unbounded N is a CPU DoS. +_MAX_WAYPOINTS = 1000 + def _get_pixel_id(point, raster, xdim=None, ydim=None): # get location in `raster` pixel space for `point` in y-x coordinate space @@ -894,8 +900,12 @@ def a_star_search(surface: xr.DataArray, >>> path_agg = a_star_search(agg, start, goal, barriers, 'lon', 'lat') """ - if surface.ndim != 2: - raise ValueError("input `surface` must be 2D") + _validate_raster(surface, func_name='a_star_search', + name='surface', ndim=2) + + if friction is not None: + _validate_raster(friction, func_name='a_star_search', + name='friction', ndim=2) if surface.dims != (y, x): raise ValueError("`surface.coords` should be named as coordinates:" @@ -1370,12 +1380,23 @@ def multi_stop_search(surface: xr.DataArray, unreachable. """ # --- Input validation --- - if surface.ndim != 2: - raise ValueError("input `surface` must be 2D") + _validate_raster(surface, func_name='multi_stop_search', + name='surface', ndim=2) + + if friction is not None: + _validate_raster(friction, func_name='multi_stop_search', + name='friction', ndim=2) if len(waypoints) < 2: raise ValueError("at least 2 waypoints are required") + if len(waypoints) > _MAX_WAYPOINTS: + raise ValueError( + f"multi_stop_search() supports at most {_MAX_WAYPOINTS} " + f"waypoints, got {len(waypoints)}. optimize_order is " + f"O(N^3) so larger lists can hang the worker." + ) + for idx, wp in enumerate(waypoints): if len(wp) != 2: raise ValueError( diff --git a/xrspatial/tests/test_pathfinding.py b/xrspatial/tests/test_pathfinding.py index 91eab85e..4fccd29f 100644 --- a/xrspatial/tests/test_pathfinding.py +++ b/xrspatial/tests/test_pathfinding.py @@ -1004,3 +1004,55 @@ def test_multi_stop_cupy_matches_numpy(): path_np.values, equal_nan=True, atol=1e-10, ) + + +# ===================================================================== +# Issue #1439: input validation +# ===================================================================== + +import xarray as _xr_for_validation + + +class TestPathfindingInputValidation: + """a_star_search / multi_stop_search reject bad surface and waypoint cap (#1439).""" + + @staticmethod + def _good_surface(): + return _xr_for_validation.DataArray( + np.zeros((10, 10), dtype=np.float64), + dims=('y', 'x'), + coords={'y': np.arange(10), 'x': np.arange(10)}, + ) + + def test_a_star_rejects_non_dataarray_surface(self): + from xrspatial.pathfinding import a_star_search + with pytest.raises(TypeError, match="xarray.DataArray"): + a_star_search(np.zeros((10, 10)), (0, 0), (5, 5)) + + def test_a_star_rejects_complex_dtype_surface(self): + from xrspatial.pathfinding import a_star_search + bad = _xr_for_validation.DataArray( + np.zeros((10, 10), dtype=np.complex128), + dims=('y', 'x'), + coords={'y': np.arange(10), 'x': np.arange(10)}, + ) + with pytest.raises(ValueError, match="real numeric"): + a_star_search(bad, (0, 0), (5, 5)) + + def test_a_star_rejects_non_dataarray_friction(self): + from xrspatial.pathfinding import a_star_search + s = self._good_surface() + with pytest.raises(TypeError, match="xarray.DataArray"): + a_star_search(s, (0, 0), (5, 5), friction=np.ones((10, 10))) + + def test_multi_stop_rejects_non_dataarray_surface(self): + from xrspatial.pathfinding import multi_stop_search + with pytest.raises(TypeError, match="xarray.DataArray"): + multi_stop_search(np.zeros((10, 10)), [(0, 0), (5, 5)]) + + def test_multi_stop_caps_waypoints(self): + from xrspatial.pathfinding import multi_stop_search, _MAX_WAYPOINTS + s = self._good_surface() + too_many = [(i % 10, (i * 7) % 10) for i in range(_MAX_WAYPOINTS + 1)] + with pytest.raises(ValueError, match=f"at most {_MAX_WAYPOINTS}"): + multi_stop_search(s, too_many)