Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions xrspatial/pathfinding.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,18 @@

from xrspatial.cost_distance import _heap_push, _heap_pop
from xrspatial.utils import (
_validate_raster,
get_dataarray_resolution, ngjit,
has_cuda_and_cupy, is_cupy_array, is_dask_cupy, has_dask_array,
)

NONE = -1

# Maximum waypoint count for multi_stop_search. optimize_order builds an
# N x N distance matrix and runs N(N-1)/2 A* calls (O(N^3) when stitched
# with 2-opt), so unbounded N is a CPU DoS.
_MAX_WAYPOINTS = 1000


def _get_pixel_id(point, raster, xdim=None, ydim=None):
# get location in `raster` pixel space for `point` in y-x coordinate space
Expand Down Expand Up @@ -894,8 +900,12 @@ def a_star_search(surface: xr.DataArray,
>>> path_agg = a_star_search(agg, start, goal, barriers, 'lon', 'lat')
"""

if surface.ndim != 2:
raise ValueError("input `surface` must be 2D")
_validate_raster(surface, func_name='a_star_search',
name='surface', ndim=2)

if friction is not None:
_validate_raster(friction, func_name='a_star_search',
name='friction', ndim=2)

if surface.dims != (y, x):
raise ValueError("`surface.coords` should be named as coordinates:"
Expand Down Expand Up @@ -1370,12 +1380,23 @@ def multi_stop_search(surface: xr.DataArray,
unreachable.
"""
# --- Input validation ---
if surface.ndim != 2:
raise ValueError("input `surface` must be 2D")
_validate_raster(surface, func_name='multi_stop_search',
name='surface', ndim=2)

if friction is not None:
_validate_raster(friction, func_name='multi_stop_search',
name='friction', ndim=2)

if len(waypoints) < 2:
raise ValueError("at least 2 waypoints are required")

if len(waypoints) > _MAX_WAYPOINTS:
raise ValueError(
f"multi_stop_search() supports at most {_MAX_WAYPOINTS} "
f"waypoints, got {len(waypoints)}. optimize_order is "
f"O(N^3) so larger lists can hang the worker."
)

for idx, wp in enumerate(waypoints):
if len(wp) != 2:
raise ValueError(
Expand Down
52 changes: 52 additions & 0 deletions xrspatial/tests/test_pathfinding.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,3 +1004,55 @@ def test_multi_stop_cupy_matches_numpy():
path_np.values,
equal_nan=True, atol=1e-10,
)


# =====================================================================
# Issue #1439: input validation
# =====================================================================

import xarray as _xr_for_validation


class TestPathfindingInputValidation:
"""a_star_search / multi_stop_search reject bad surface and waypoint cap (#1439)."""

@staticmethod
def _good_surface():
return _xr_for_validation.DataArray(
np.zeros((10, 10), dtype=np.float64),
dims=('y', 'x'),
coords={'y': np.arange(10), 'x': np.arange(10)},
)

def test_a_star_rejects_non_dataarray_surface(self):
from xrspatial.pathfinding import a_star_search
with pytest.raises(TypeError, match="xarray.DataArray"):
a_star_search(np.zeros((10, 10)), (0, 0), (5, 5))

def test_a_star_rejects_complex_dtype_surface(self):
from xrspatial.pathfinding import a_star_search
bad = _xr_for_validation.DataArray(
np.zeros((10, 10), dtype=np.complex128),
dims=('y', 'x'),
coords={'y': np.arange(10), 'x': np.arange(10)},
)
with pytest.raises(ValueError, match="real numeric"):
a_star_search(bad, (0, 0), (5, 5))

def test_a_star_rejects_non_dataarray_friction(self):
from xrspatial.pathfinding import a_star_search
s = self._good_surface()
with pytest.raises(TypeError, match="xarray.DataArray"):
a_star_search(s, (0, 0), (5, 5), friction=np.ones((10, 10)))

def test_multi_stop_rejects_non_dataarray_surface(self):
from xrspatial.pathfinding import multi_stop_search
with pytest.raises(TypeError, match="xarray.DataArray"):
multi_stop_search(np.zeros((10, 10)), [(0, 0), (5, 5)])

def test_multi_stop_caps_waypoints(self):
from xrspatial.pathfinding import multi_stop_search, _MAX_WAYPOINTS
s = self._good_surface()
too_many = [(i % 10, (i * 7) % 10) for i in range(_MAX_WAYPOINTS + 1)]
with pytest.raises(ValueError, match=f"at most {_MAX_WAYPOINTS}"):
multi_stop_search(s, too_many)
Loading