Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions xrspatial/hydro/fill_d8.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,12 @@ def fill_d8(dem: xr.DataArray,
"""
_validate_raster(dem, func_name='fill', name='dem')

if z_limit is not None and not (np.isfinite(z_limit) and z_limit >= 0):
raise ValueError(
"z_limit must be a non-negative finite number or None, got %s"
% (z_limit,)
)

data = dem.data

if isinstance(data, np.ndarray):
Expand Down
6 changes: 4 additions & 2 deletions xrspatial/hydro/flow_direction_mfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,10 @@ def flow_direction_mfd(agg: xr.DataArray,
_validate_boundary(boundary)

if p is not None:
if p <= 0:
raise ValueError("p must be a positive number, got %s" % p)
if not (np.isfinite(p) and p > 0):
raise ValueError(
"p must be a positive finite number, got %s" % p
)
p_fixed = float(p)
else:
p_fixed = -1.0 # sentinel for adaptive mode
Expand Down
5 changes: 5 additions & 0 deletions xrspatial/hydro/hand_d8.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,11 @@ def hand_d8(flow_dir: xr.DataArray,
_validate_raster(flow_accum, func_name='hand', name='flow_accum')
_validate_raster(elevation, func_name='hand', name='elevation')

if not np.isfinite(threshold):
raise ValueError(
"threshold must be a finite number, got %s" % threshold
)

fd_data = flow_dir.data
fa_data = flow_accum.data
el_data = elevation.data
Expand Down
5 changes: 5 additions & 0 deletions xrspatial/hydro/hand_dinf.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,11 @@ def hand_dinf(flow_dir_dinf: xr.DataArray,
_validate_raster(flow_accum, func_name='hand_dinf', name='flow_accum')
_validate_raster(elevation, func_name='hand_dinf', name='elevation')

if not np.isfinite(threshold):
raise ValueError(
"threshold must be a finite number, got %s" % threshold
)

fd_data = flow_dir_dinf.data
fa_data = flow_accum.data
el_data = elevation.data
Expand Down
5 changes: 5 additions & 0 deletions xrspatial/hydro/hand_mfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,11 @@ def hand_mfd(flow_dir_mfd: xr.DataArray,
_validate_raster(flow_accum, func_name='hand_mfd', name='flow_accum')
_validate_raster(elevation, func_name='hand_mfd', name='elevation')

if not np.isfinite(threshold):
raise ValueError(
"threshold must be a finite number, got %s" % threshold
)

data = flow_dir_mfd.data
fa_data = flow_accum.data
el_data = elevation.data
Expand Down
6 changes: 6 additions & 0 deletions xrspatial/hydro/snap_pour_point_d8.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,12 @@ def snap_pour_point_d8(flow_accum: xr.DataArray,
"""
_validate_raster(flow_accum, func_name='snap_pour_point', name='flow_accum')

if not isinstance(search_radius, (int, np.integer)) or search_radius < 1:
raise ValueError(
"search_radius must be a positive integer, got %r"
% (search_radius,)
)

fa_data = flow_accum.data
pp_data = pour_points.data

Expand Down
4 changes: 2 additions & 2 deletions xrspatial/hydro/tests/test_flow_direction_mfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,9 @@ def test_fixed_exponent_high():
def test_invalid_p():
data = np.ones((4, 5), dtype=np.float64)
agg = create_test_raster(data)
with pytest.raises(ValueError, match="p must be a positive number"):
with pytest.raises(ValueError, match="positive finite"):
flow_direction_mfd(agg, p=-1.0)
with pytest.raises(ValueError, match="p must be a positive number"):
with pytest.raises(ValueError, match="positive finite"):
flow_direction_mfd(agg, p=0.0)


Expand Down
124 changes: 124 additions & 0 deletions xrspatial/hydro/tests/test_validate_scalar_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""Tests for issue #1427: hydro scalar parameter validation.

Several public functions accept scalar parameters that previously did not
reject NaN/Inf or out-of-range values, leading to silent all-NaN output or
no-op behavior.
"""

import numpy as np
import pytest

from xrspatial.hydro import (
fill_d8,
flow_direction_d8,
flow_direction_dinf,
flow_direction_mfd,
hand_d8,
hand_dinf,
hand_mfd,
snap_pour_point_d8,
)
from xrspatial.tests.general_checks import create_test_raster


def _elev():
return create_test_raster(np.array([
[9, 9, 9, 9, 9],
[9, 8, 7, 6, 9],
[9, 7, 5, 4, 9],
[9, 6, 4, 3, 9],
[9, 9, 9, 9, 9],
], dtype=np.float64))


def _stream_inputs(method):
elev = _elev()
if method == 'd8':
fd = flow_direction_d8(elev)
elif method == 'dinf':
fd = flow_direction_dinf(elev)
else:
fd = flow_direction_mfd(elev)
fa = create_test_raster(np.ones((5, 5), dtype=np.float64))
el = elev
return fd, fa, el


# ---------------------------------------------------------------------------
# flow_direction_mfd p
# ---------------------------------------------------------------------------

class TestFlowDirectionMfdP:
@pytest.mark.parametrize("p", [float('nan'), float('inf'), float('-inf')])
def test_rejects_non_finite_p(self, p):
with pytest.raises(ValueError, match="positive finite"):
flow_direction_mfd(_elev(), p=p)

@pytest.mark.parametrize("p", [0, -1, -0.5])
def test_rejects_non_positive_p(self, p):
with pytest.raises(ValueError, match="positive finite"):
flow_direction_mfd(_elev(), p=p)

def test_accepts_positive_finite_p(self):
result = flow_direction_mfd(_elev(), p=1.5)
assert result.shape == (8, 5, 5)


# ---------------------------------------------------------------------------
# snap_pour_point_d8 search_radius
# ---------------------------------------------------------------------------

class TestSnapPourPointSearchRadius:
@pytest.mark.parametrize("r", [0, -1, -5])
def test_rejects_non_positive(self, r):
fa = create_test_raster(np.ones((5, 5), dtype=np.float64))
pp = create_test_raster(np.full((5, 5), np.nan, dtype=np.float64))
with pytest.raises(ValueError, match="positive integer"):
snap_pour_point_d8(fa, pp, search_radius=r)

@pytest.mark.parametrize("r", [5.5, float('nan'), float('inf')])
def test_rejects_non_int(self, r):
fa = create_test_raster(np.ones((5, 5), dtype=np.float64))
pp = create_test_raster(np.full((5, 5), np.nan, dtype=np.float64))
with pytest.raises(ValueError, match="positive integer"):
snap_pour_point_d8(fa, pp, search_radius=r)


# ---------------------------------------------------------------------------
# hand_* threshold
# ---------------------------------------------------------------------------

class TestHandThreshold:
@pytest.mark.parametrize("method,fn", [
('d8', hand_d8),
('dinf', hand_dinf),
('mfd', hand_mfd),
])
@pytest.mark.parametrize("t", [float('nan'), float('inf'), float('-inf')])
def test_rejects_non_finite_threshold(self, method, fn, t):
fd, fa, el = _stream_inputs(method)
with pytest.raises(ValueError, match="threshold must be a finite"):
fn(fd, fa, el, threshold=t)


# ---------------------------------------------------------------------------
# fill_d8 z_limit
# ---------------------------------------------------------------------------

class TestFillZLimit:
@pytest.mark.parametrize("z", [float('nan'), float('inf'), -0.5, -100])
def test_rejects_bad_z_limit(self, z):
with pytest.raises(ValueError, match="z_limit"):
fill_d8(_elev(), z_limit=z)

def test_accepts_none(self):
result = fill_d8(_elev(), z_limit=None)
assert result.shape == (5, 5)

def test_accepts_zero(self):
result = fill_d8(_elev(), z_limit=0)
assert result.shape == (5, 5)

def test_accepts_positive(self):
result = fill_d8(_elev(), z_limit=1.0)
assert result.shape == (5, 5)
Loading