From b3ba9629b4fdcc0f29afee8a1c662dc71adbf856 Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Sun, 3 May 2026 15:58:36 -0700 Subject: [PATCH] Add _validate_raster on secondary DataArray args in hydro (#1425) Each of the following public functions previously validated only its primary raster. A non-DataArray, 1-D DataArray, or None for the secondary arg raised a confusing AttributeError / IndexError from inside the implementation rather than a clean TypeError / ValueError from the public API. Functions touched (secondary arg in parens): - watershed_d8 / watershed_dinf / watershed_mfd (pour_points) - snap_pour_point_d8 (pour_points) - flow_path_d8 / flow_path_dinf / flow_path_mfd (start_points) - stream_link_d8 / stream_link_dinf / stream_link_mfd (flow_accum) - stream_order_d8 / stream_order_dinf / stream_order_mfd (flow_accum) 13 new tests in test_validate_secondary_args.py confirm each function rejects a numpy ndarray for its secondary arg with a TypeError. --- xrspatial/hydro/flow_path_d8.py | 2 + xrspatial/hydro/flow_path_dinf.py | 2 + xrspatial/hydro/flow_path_mfd.py | 2 + xrspatial/hydro/snap_pour_point_d8.py | 2 + xrspatial/hydro/stream_link_d8.py | 1 + xrspatial/hydro/stream_link_dinf.py | 3 + xrspatial/hydro/stream_link_mfd.py | 2 + xrspatial/hydro/stream_order_d8.py | 1 + xrspatial/hydro/stream_order_dinf.py | 3 + xrspatial/hydro/stream_order_mfd.py | 2 + .../tests/test_validate_secondary_args.py | 137 ++++++++++++++++++ xrspatial/hydro/watershed_d8.py | 1 + xrspatial/hydro/watershed_dinf.py | 2 + xrspatial/hydro/watershed_mfd.py | 2 + 14 files changed, 162 insertions(+) create mode 100644 xrspatial/hydro/tests/test_validate_secondary_args.py diff --git a/xrspatial/hydro/flow_path_d8.py b/xrspatial/hydro/flow_path_d8.py index 56e36d24..468b4b8f 100644 --- a/xrspatial/hydro/flow_path_d8.py +++ b/xrspatial/hydro/flow_path_d8.py @@ -458,6 +458,8 @@ def flow_path_d8(flow_dir: xr.DataArray, raster-scan order wins. 
""" _validate_raster(flow_dir, func_name='flow_path', name='flow_dir') + _validate_raster(start_points, func_name='flow_path', + name='start_points') fd_data = flow_dir.data sp_data = start_points.data diff --git a/xrspatial/hydro/flow_path_dinf.py b/xrspatial/hydro/flow_path_dinf.py index 6e82fb88..d9d3e178 100644 --- a/xrspatial/hydro/flow_path_dinf.py +++ b/xrspatial/hydro/flow_path_dinf.py @@ -423,6 +423,8 @@ def flow_path_dinf(flow_dir_dinf: xr.DataArray, """ _validate_raster(flow_dir_dinf, func_name='flow_path_dinf', name='flow_dir_dinf') + _validate_raster(start_points, func_name='flow_path_dinf', + name='start_points') fd_data = flow_dir_dinf.data sp_data = start_points.data diff --git a/xrspatial/hydro/flow_path_mfd.py b/xrspatial/hydro/flow_path_mfd.py index 2743e3fe..6e992e9d 100644 --- a/xrspatial/hydro/flow_path_mfd.py +++ b/xrspatial/hydro/flow_path_mfd.py @@ -418,6 +418,8 @@ def flow_path_mfd(flow_dir_mfd: xr.DataArray, """ _validate_raster(flow_dir_mfd, func_name='flow_path_mfd', name='flow_dir_mfd', ndim=3) + _validate_raster(start_points, func_name='flow_path_mfd', + name='start_points') data = flow_dir_mfd.data sp_data = start_points.data diff --git a/xrspatial/hydro/snap_pour_point_d8.py b/xrspatial/hydro/snap_pour_point_d8.py index 91cb7b46..97576dc4 100644 --- a/xrspatial/hydro/snap_pour_point_d8.py +++ b/xrspatial/hydro/snap_pour_point_d8.py @@ -565,6 +565,8 @@ def snap_pour_point_d8(flow_accum: xr.DataArray, locations. Non-pour-point cells are NaN. """ _validate_raster(flow_accum, func_name='snap_pour_point', name='flow_accum') + _validate_raster(pour_points, func_name='snap_pour_point', + name='pour_points') fa_data = flow_accum.data pp_data = pour_points.data diff --git a/xrspatial/hydro/stream_link_d8.py b/xrspatial/hydro/stream_link_d8.py index 2d66a862..aceb0e26 100644 --- a/xrspatial/hydro/stream_link_d8.py +++ b/xrspatial/hydro/stream_link_d8.py @@ -1103,6 +1103,7 @@ def stream_link_d8(flow_dir: xr.DataArray, integer ID. 
Non-stream cells are NaN. """ _validate_raster(flow_dir, func_name='stream_link', name='flow_dir') + _validate_raster(flow_accum, func_name='stream_link', name='flow_accum') fd_data = flow_dir.data fa_data = flow_accum.data diff --git a/xrspatial/hydro/stream_link_dinf.py b/xrspatial/hydro/stream_link_dinf.py index e5222ab8..a7ee9490 100644 --- a/xrspatial/hydro/stream_link_dinf.py +++ b/xrspatial/hydro/stream_link_dinf.py @@ -1280,6 +1280,9 @@ def stream_link_dinf(flow_dir_dinf: xr.DataArray, _validate_raster(flow_dir_dinf, func_name='stream_link_dinf', name='flow_dir_dinf') + _validate_raster(flow_accum, + func_name='stream_link_dinf', + name='flow_accum') fd_data = flow_dir_dinf.data fa_data = flow_accum.data diff --git a/xrspatial/hydro/stream_link_mfd.py b/xrspatial/hydro/stream_link_mfd.py index 82bafd23..85004e30 100644 --- a/xrspatial/hydro/stream_link_mfd.py +++ b/xrspatial/hydro/stream_link_mfd.py @@ -1014,6 +1014,8 @@ def stream_link_mfd(fractions: xr.DataArray, """ _validate_raster(fractions, func_name='stream_link_mfd', name='fractions', ndim=3) + _validate_raster(flow_accum, func_name='stream_link_mfd', + name='flow_accum') data = fractions.data diff --git a/xrspatial/hydro/stream_order_d8.py b/xrspatial/hydro/stream_order_d8.py index 3f3d19de..68e09905 100644 --- a/xrspatial/hydro/stream_order_d8.py +++ b/xrspatial/hydro/stream_order_d8.py @@ -1594,6 +1594,7 @@ def stream_order_d8(flow_dir: xr.DataArray, of Geology, 74(1), 17-37. 
""" _validate_raster(flow_dir, func_name='stream_order', name='flow_dir') + _validate_raster(flow_accum, func_name='stream_order', name='flow_accum') method = ordering.lower() if method not in ('strahler', 'shreve'): diff --git a/xrspatial/hydro/stream_order_dinf.py b/xrspatial/hydro/stream_order_dinf.py index c99a234f..8e8c985f 100644 --- a/xrspatial/hydro/stream_order_dinf.py +++ b/xrspatial/hydro/stream_order_dinf.py @@ -1934,6 +1934,9 @@ def stream_order_dinf(flow_dir_dinf: xr.DataArray, _validate_raster(flow_dir_dinf, func_name='stream_order_dinf', name='flow_dir_dinf') + _validate_raster(flow_accum, + func_name='stream_order_dinf', + name='flow_accum') method = method.lower() if method not in ('strahler', 'shreve'): diff --git a/xrspatial/hydro/stream_order_mfd.py b/xrspatial/hydro/stream_order_mfd.py index b7bd3768..12e13f9c 100644 --- a/xrspatial/hydro/stream_order_mfd.py +++ b/xrspatial/hydro/stream_order_mfd.py @@ -1495,6 +1495,8 @@ def stream_order_mfd(fractions: xr.DataArray, """ _validate_raster(fractions, func_name='stream_order_mfd', name='fractions', ndim=3) + _validate_raster(flow_accum, func_name='stream_order_mfd', + name='flow_accum') method = method.lower() if method not in ('strahler', 'shreve'): diff --git a/xrspatial/hydro/tests/test_validate_secondary_args.py b/xrspatial/hydro/tests/test_validate_secondary_args.py new file mode 100644 index 00000000..e5ecb3e4 --- /dev/null +++ b/xrspatial/hydro/tests/test_validate_secondary_args.py @@ -0,0 +1,137 @@ +"""Tests for issue #1425: hydro public APIs validate secondary DataArray args. + +Each function below previously validated only its primary raster. Passing +a non-DataArray, a 1-D DataArray, or `None` for the secondary arg raised +a confusing AttributeError / IndexError from inside the implementation. +The fix adds `_validate_raster` on the secondary arg in the public API. 
+""" + +import numpy as np +import pytest +import xarray as xr + +from xrspatial.hydro import ( + flow_direction_d8, + flow_direction_dinf, + flow_direction_mfd, + flow_path_d8, + flow_path_dinf, + flow_path_mfd, + snap_pour_point_d8, + stream_link_d8, + stream_link_dinf, + stream_link_mfd, + stream_order_d8, + stream_order_dinf, + stream_order_mfd, + watershed_d8, + watershed_dinf, + watershed_mfd, +) +from xrspatial.tests.general_checks import create_test_raster + + +def _elev(): + """Small bowl elevation raster used as the primary input.""" + return create_test_raster(np.array([ + [9, 9, 9, 9, 9], + [9, 8, 7, 6, 9], + [9, 7, 5, 4, 9], + [9, 6, 4, 3, 9], + [9, 9, 9, 9, 9], + ], dtype=np.float64)) + + +# --------------------------------------------------------------------------- +# watershed_* +# --------------------------------------------------------------------------- + +class TestWatershedPourPoints: + def test_watershed_d8_rejects_non_dataarray_pour_points(self): + fd = flow_direction_d8(_elev()) + with pytest.raises(TypeError): + watershed_d8(fd, pour_points=np.zeros((5, 5))) + + def test_watershed_dinf_rejects_non_dataarray_pour_points(self): + fd = flow_direction_dinf(_elev()) + with pytest.raises(TypeError): + watershed_dinf(fd, pour_points=np.zeros((5, 5))) + + def test_watershed_mfd_rejects_non_dataarray_pour_points(self): + fd = flow_direction_mfd(_elev()) + with pytest.raises(TypeError): + watershed_mfd(fd, pour_points=np.zeros((5, 5))) + + +# --------------------------------------------------------------------------- +# snap_pour_point_d8 +# --------------------------------------------------------------------------- + +class TestSnapPourPoint: + def test_rejects_non_dataarray_pour_points(self): + fa = create_test_raster(np.ones((5, 5), dtype=np.float64)) + with pytest.raises(TypeError): + snap_pour_point_d8(fa, pour_points=np.zeros((5, 5))) + + +# --------------------------------------------------------------------------- +# flow_path_* +# 
--------------------------------------------------------------------------- + +class TestFlowPathStartPoints: + def test_flow_path_d8_rejects_non_dataarray_start_points(self): + fd = flow_direction_d8(_elev()) + with pytest.raises(TypeError): + flow_path_d8(fd, start_points=np.zeros((5, 5))) + + def test_flow_path_dinf_rejects_non_dataarray_start_points(self): + fd = flow_direction_dinf(_elev()) + with pytest.raises(TypeError): + flow_path_dinf(fd, start_points=np.zeros((5, 5))) + + def test_flow_path_mfd_rejects_non_dataarray_start_points(self): + fd = flow_direction_mfd(_elev()) + with pytest.raises(TypeError): + flow_path_mfd(fd, start_points=np.zeros((5, 5))) + + +# --------------------------------------------------------------------------- +# stream_link_* +# --------------------------------------------------------------------------- + +class TestStreamLinkFlowAccum: + def test_stream_link_d8_rejects_non_dataarray_flow_accum(self): + fd = flow_direction_d8(_elev()) + with pytest.raises(TypeError): + stream_link_d8(fd, flow_accum=np.zeros((5, 5))) + + def test_stream_link_dinf_rejects_non_dataarray_flow_accum(self): + fd = flow_direction_dinf(_elev()) + with pytest.raises(TypeError): + stream_link_dinf(fd, flow_accum=np.zeros((5, 5))) + + def test_stream_link_mfd_rejects_non_dataarray_flow_accum(self): + fd = flow_direction_mfd(_elev()) + with pytest.raises(TypeError): + stream_link_mfd(fd, flow_accum=np.zeros((5, 5))) + + +# --------------------------------------------------------------------------- +# stream_order_* +# --------------------------------------------------------------------------- + +class TestStreamOrderFlowAccum: + def test_stream_order_d8_rejects_non_dataarray_flow_accum(self): + fd = flow_direction_d8(_elev()) + with pytest.raises(TypeError): + stream_order_d8(fd, flow_accum=np.zeros((5, 5))) + + def test_stream_order_dinf_rejects_non_dataarray_flow_accum(self): + fd = flow_direction_dinf(_elev()) + with pytest.raises(TypeError): + 
stream_order_dinf(fd, flow_accum=np.zeros((5, 5))) + + def test_stream_order_mfd_rejects_non_dataarray_flow_accum(self): + fd = flow_direction_mfd(_elev()) + with pytest.raises(TypeError): + stream_order_mfd(fd, flow_accum=np.zeros((5, 5))) diff --git a/xrspatial/hydro/watershed_d8.py b/xrspatial/hydro/watershed_d8.py index f040c87a..f65b57ad 100644 --- a/xrspatial/hydro/watershed_d8.py +++ b/xrspatial/hydro/watershed_d8.py @@ -1045,6 +1045,7 @@ def watershed_d8(flow_dir: xr.DataArray, NaN for nodata or cells not reaching any pour point. """ _validate_raster(flow_dir, func_name='watershed', name='flow_dir') + _validate_raster(pour_points, func_name='watershed', name='pour_points') data = flow_dir.data pp_data = pour_points.data diff --git a/xrspatial/hydro/watershed_dinf.py b/xrspatial/hydro/watershed_dinf.py index 1b994b19..d1e7b240 100644 --- a/xrspatial/hydro/watershed_dinf.py +++ b/xrspatial/hydro/watershed_dinf.py @@ -686,6 +686,8 @@ def watershed_dinf(flow_dir_dinf: xr.DataArray, """ _validate_raster(flow_dir_dinf, func_name='watershed_dinf', name='flow_dir_dinf') + _validate_raster(pour_points, func_name='watershed_dinf', + name='pour_points') data = flow_dir_dinf.data pp_data = pour_points.data diff --git a/xrspatial/hydro/watershed_mfd.py b/xrspatial/hydro/watershed_mfd.py index 6be3a0f5..6dfbd783 100644 --- a/xrspatial/hydro/watershed_mfd.py +++ b/xrspatial/hydro/watershed_mfd.py @@ -668,6 +668,8 @@ def watershed_mfd(flow_dir_mfd: xr.DataArray, """ _validate_raster(flow_dir_mfd, func_name='watershed_mfd', name='flow_dir_mfd', ndim=3) + _validate_raster(pour_points, func_name='watershed_mfd', + name='pour_points') data = flow_dir_mfd.data pp_data = pour_points.data