diff --git a/xrspatial/hydro/fill_d8.py b/xrspatial/hydro/fill_d8.py index e41541d9..7fd97283 100644 --- a/xrspatial/hydro/fill_d8.py +++ b/xrspatial/hydro/fill_d8.py @@ -566,6 +566,12 @@ def fill_d8(dem: xr.DataArray, """ _validate_raster(dem, func_name='fill', name='dem') + if z_limit is not None and not (np.isfinite(z_limit) and z_limit >= 0): + raise ValueError( + "z_limit must be a non-negative finite number or None, got %s" + % (z_limit,) + ) + data = dem.data if isinstance(data, np.ndarray): diff --git a/xrspatial/hydro/flow_direction_mfd.py b/xrspatial/hydro/flow_direction_mfd.py index eb9c5f2d..f91f6a02 100644 --- a/xrspatial/hydro/flow_direction_mfd.py +++ b/xrspatial/hydro/flow_direction_mfd.py @@ -385,8 +385,10 @@ def flow_direction_mfd(agg: xr.DataArray, _validate_boundary(boundary) if p is not None: - if p <= 0: - raise ValueError("p must be a positive number, got %s" % p) + if not (np.isfinite(p) and p > 0): + raise ValueError( + "p must be a positive finite number, got %s" % p + ) p_fixed = float(p) else: p_fixed = -1.0 # sentinel for adaptive mode diff --git a/xrspatial/hydro/hand_d8.py b/xrspatial/hydro/hand_d8.py index d2ae14a9..4c2200b8 100644 --- a/xrspatial/hydro/hand_d8.py +++ b/xrspatial/hydro/hand_d8.py @@ -969,6 +969,11 @@ def hand_d8(flow_dir: xr.DataArray, _validate_raster(flow_accum, func_name='hand', name='flow_accum') _validate_raster(elevation, func_name='hand', name='elevation') + if not np.isfinite(threshold): + raise ValueError( + "threshold must be a finite number, got %s" % threshold + ) + fd_data = flow_dir.data fa_data = flow_accum.data el_data = elevation.data diff --git a/xrspatial/hydro/hand_dinf.py b/xrspatial/hydro/hand_dinf.py index 09e2f0dc..ce5c7640 100644 --- a/xrspatial/hydro/hand_dinf.py +++ b/xrspatial/hydro/hand_dinf.py @@ -674,6 +674,11 @@ def hand_dinf(flow_dir_dinf: xr.DataArray, _validate_raster(flow_accum, func_name='hand_dinf', name='flow_accum') _validate_raster(elevation, func_name='hand_dinf', name='elevation') + if not np.isfinite(threshold): + raise ValueError( + "threshold must be a finite number, got %s" % threshold + ) + fd_data = flow_dir_dinf.data fa_data = flow_accum.data el_data = elevation.data diff --git a/xrspatial/hydro/hand_mfd.py b/xrspatial/hydro/hand_mfd.py index 23d49507..7d199fb4 100644 --- a/xrspatial/hydro/hand_mfd.py +++ b/xrspatial/hydro/hand_mfd.py @@ -695,6 +695,11 @@ def hand_mfd(flow_dir_mfd: xr.DataArray, _validate_raster(flow_accum, func_name='hand_mfd', name='flow_accum') _validate_raster(elevation, func_name='hand_mfd', name='elevation') + if not np.isfinite(threshold): + raise ValueError( + "threshold must be a finite number, got %s" % threshold + ) + data = flow_dir_mfd.data fa_data = flow_accum.data el_data = elevation.data diff --git a/xrspatial/hydro/snap_pour_point_d8.py b/xrspatial/hydro/snap_pour_point_d8.py index 91cb7b46..7e49baac 100644 --- a/xrspatial/hydro/snap_pour_point_d8.py +++ b/xrspatial/hydro/snap_pour_point_d8.py @@ -566,6 +566,12 @@ def snap_pour_point_d8(flow_accum: xr.DataArray, """ _validate_raster(flow_accum, func_name='snap_pour_point', name='flow_accum') + if not isinstance(search_radius, (int, np.integer)) or search_radius < 1: + raise ValueError( + "search_radius must be a positive integer, got %r" + % (search_radius,) + ) + fa_data = flow_accum.data pp_data = pour_points.data diff --git a/xrspatial/hydro/tests/test_flow_direction_mfd.py b/xrspatial/hydro/tests/test_flow_direction_mfd.py index a689ebbb..cc9a64b4 100644 --- a/xrspatial/hydro/tests/test_flow_direction_mfd.py +++ b/xrspatial/hydro/tests/test_flow_direction_mfd.py @@ -273,9 +273,9 @@ def test_fixed_exponent_high(): def test_invalid_p(): data = np.ones((4, 5), dtype=np.float64) agg = create_test_raster(data) - with pytest.raises(ValueError, match="p must be a positive number"): + with pytest.raises(ValueError, match="positive finite"): flow_direction_mfd(agg, p=-1.0) - with pytest.raises(ValueError, match="p must be a positive number"): + with pytest.raises(ValueError, match="positive finite"): flow_direction_mfd(agg, p=0.0) diff --git a/xrspatial/hydro/tests/test_validate_scalar_params.py b/xrspatial/hydro/tests/test_validate_scalar_params.py new file mode 100644 index 00000000..95222a62 --- /dev/null +++ b/xrspatial/hydro/tests/test_validate_scalar_params.py @@ -0,0 +1,124 @@ +"""Tests for issue #1427: hydro scalar parameter validation. + +Several public functions accept scalar parameters that previously did not +reject NaN/Inf or out-of-range values, leading to silent all-NaN output or +no-op behavior. +""" + +import numpy as np +import pytest + +from xrspatial.hydro import ( + fill_d8, + flow_direction_d8, + flow_direction_dinf, + flow_direction_mfd, + hand_d8, + hand_dinf, + hand_mfd, + snap_pour_point_d8, +) +from xrspatial.tests.general_checks import create_test_raster + + +def _elev(): + return create_test_raster(np.array([ + [9, 9, 9, 9, 9], + [9, 8, 7, 6, 9], + [9, 7, 5, 4, 9], + [9, 6, 4, 3, 9], + [9, 9, 9, 9, 9], + ], dtype=np.float64)) + + +def _stream_inputs(method): + elev = _elev() + if method == 'd8': + fd = flow_direction_d8(elev) + elif method == 'dinf': + fd = flow_direction_dinf(elev) + else: + fd = flow_direction_mfd(elev) + fa = create_test_raster(np.ones((5, 5), dtype=np.float64)) + el = elev + return fd, fa, el + + +# --------------------------------------------------------------------------- +# flow_direction_mfd p +# --------------------------------------------------------------------------- + +class TestFlowDirectionMfdP: + @pytest.mark.parametrize("p", [float('nan'), float('inf'), float('-inf')]) + def test_rejects_non_finite_p(self, p): + with pytest.raises(ValueError, match="positive finite"): + flow_direction_mfd(_elev(), p=p) + + @pytest.mark.parametrize("p", [0, -1, -0.5]) + def test_rejects_non_positive_p(self, p): + with pytest.raises(ValueError, match="positive finite"): + flow_direction_mfd(_elev(), p=p) + + def test_accepts_positive_finite_p(self): + result = flow_direction_mfd(_elev(), p=1.5) + assert result.shape == (8, 5, 5) + + +# --------------------------------------------------------------------------- +# snap_pour_point_d8 search_radius +# --------------------------------------------------------------------------- + +class TestSnapPourPointSearchRadius: + @pytest.mark.parametrize("r", [0, -1, -5]) + def test_rejects_non_positive(self, r): + fa = create_test_raster(np.ones((5, 5), dtype=np.float64)) + pp = create_test_raster(np.full((5, 5), np.nan, dtype=np.float64)) + with pytest.raises(ValueError, match="positive integer"): + snap_pour_point_d8(fa, pp, search_radius=r) + + @pytest.mark.parametrize("r", [5.5, float('nan'), float('inf')]) + def test_rejects_non_int(self, r): + fa = create_test_raster(np.ones((5, 5), dtype=np.float64)) + pp = create_test_raster(np.full((5, 5), np.nan, dtype=np.float64)) + with pytest.raises(ValueError, match="positive integer"): + snap_pour_point_d8(fa, pp, search_radius=r) + + +# --------------------------------------------------------------------------- +# hand_* threshold +# --------------------------------------------------------------------------- + +class TestHandThreshold: + @pytest.mark.parametrize("method,fn", [ + ('d8', hand_d8), + ('dinf', hand_dinf), + ('mfd', hand_mfd), + ]) + @pytest.mark.parametrize("t", [float('nan'), float('inf'), float('-inf')]) + def test_rejects_non_finite_threshold(self, method, fn, t): + fd, fa, el = _stream_inputs(method) + with pytest.raises(ValueError, match="threshold must be a finite"): + fn(fd, fa, el, threshold=t) + + +# --------------------------------------------------------------------------- +# fill_d8 z_limit +# --------------------------------------------------------------------------- + +class TestFillZLimit: + @pytest.mark.parametrize("z", [float('nan'), float('inf'), -0.5, -100]) + def test_rejects_bad_z_limit(self, z): + with pytest.raises(ValueError, match="z_limit"): + fill_d8(_elev(), z_limit=z) + + def test_accepts_none(self): + result = fill_d8(_elev(), z_limit=None) + assert result.shape == (5, 5) + + def test_accepts_zero(self): + result = fill_d8(_elev(), z_limit=0) + assert result.shape == (5, 5) + + def test_accepts_positive(self): + result = fill_d8(_elev(), z_limit=1.0) + assert result.shape == (5, 5)