Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions xrspatial/hydro/flow_path_d8.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,8 @@ def flow_path_d8(flow_dir: xr.DataArray,
raster-scan order wins.
"""
_validate_raster(flow_dir, func_name='flow_path', name='flow_dir')
_validate_raster(start_points, func_name='flow_path',
name='start_points')

fd_data = flow_dir.data
sp_data = start_points.data
Expand Down
2 changes: 2 additions & 0 deletions xrspatial/hydro/flow_path_dinf.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,8 @@ def flow_path_dinf(flow_dir_dinf: xr.DataArray,
"""
_validate_raster(flow_dir_dinf, func_name='flow_path_dinf',
name='flow_dir_dinf')
_validate_raster(start_points, func_name='flow_path_dinf',
name='start_points')

fd_data = flow_dir_dinf.data
sp_data = start_points.data
Expand Down
2 changes: 2 additions & 0 deletions xrspatial/hydro/flow_path_mfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,8 @@ def flow_path_mfd(flow_dir_mfd: xr.DataArray,
"""
_validate_raster(flow_dir_mfd, func_name='flow_path_mfd',
name='flow_dir_mfd', ndim=3)
_validate_raster(start_points, func_name='flow_path_mfd',
name='start_points')

data = flow_dir_mfd.data
sp_data = start_points.data
Expand Down
2 changes: 2 additions & 0 deletions xrspatial/hydro/snap_pour_point_d8.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,8 @@ def snap_pour_point_d8(flow_accum: xr.DataArray,
locations. Non-pour-point cells are NaN.
"""
_validate_raster(flow_accum, func_name='snap_pour_point', name='flow_accum')
_validate_raster(pour_points, func_name='snap_pour_point',
name='pour_points')

fa_data = flow_accum.data
pp_data = pour_points.data
Expand Down
1 change: 1 addition & 0 deletions xrspatial/hydro/stream_link_d8.py
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,7 @@ def stream_link_d8(flow_dir: xr.DataArray,
integer ID. Non-stream cells are NaN.
"""
_validate_raster(flow_dir, func_name='stream_link', name='flow_dir')
_validate_raster(flow_accum, func_name='stream_link', name='flow_accum')

fd_data = flow_dir.data
fa_data = flow_accum.data
Expand Down
3 changes: 3 additions & 0 deletions xrspatial/hydro/stream_link_dinf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1280,6 +1280,9 @@ def stream_link_dinf(flow_dir_dinf: xr.DataArray,
_validate_raster(flow_dir_dinf,
func_name='stream_link_dinf',
name='flow_dir_dinf')
_validate_raster(flow_accum,
func_name='stream_link_dinf',
name='flow_accum')

fd_data = flow_dir_dinf.data
fa_data = flow_accum.data
Expand Down
2 changes: 2 additions & 0 deletions xrspatial/hydro/stream_link_mfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,6 +1014,8 @@ def stream_link_mfd(fractions: xr.DataArray,
"""
_validate_raster(fractions, func_name='stream_link_mfd',
name='fractions', ndim=3)
_validate_raster(flow_accum, func_name='stream_link_mfd',
name='flow_accum')

data = fractions.data

Expand Down
1 change: 1 addition & 0 deletions xrspatial/hydro/stream_order_d8.py
Original file line number Diff line number Diff line change
Expand Up @@ -1594,6 +1594,7 @@ def stream_order_d8(flow_dir: xr.DataArray,
of Geology, 74(1), 17-37.
"""
_validate_raster(flow_dir, func_name='stream_order', name='flow_dir')
_validate_raster(flow_accum, func_name='stream_order', name='flow_accum')

method = ordering.lower()
if method not in ('strahler', 'shreve'):
Expand Down
3 changes: 3 additions & 0 deletions xrspatial/hydro/stream_order_dinf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1934,6 +1934,9 @@ def stream_order_dinf(flow_dir_dinf: xr.DataArray,
_validate_raster(flow_dir_dinf,
func_name='stream_order_dinf',
name='flow_dir_dinf')
_validate_raster(flow_accum,
func_name='stream_order_dinf',
name='flow_accum')

method = method.lower()
if method not in ('strahler', 'shreve'):
Expand Down
2 changes: 2 additions & 0 deletions xrspatial/hydro/stream_order_mfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1495,6 +1495,8 @@ def stream_order_mfd(fractions: xr.DataArray,
"""
_validate_raster(fractions, func_name='stream_order_mfd',
name='fractions', ndim=3)
_validate_raster(flow_accum, func_name='stream_order_mfd',
name='flow_accum')

method = method.lower()
if method not in ('strahler', 'shreve'):
Expand Down
137 changes: 137 additions & 0 deletions xrspatial/hydro/tests/test_validate_secondary_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""Tests for issue #1425: hydro public APIs validate secondary DataArray args.

Each function below previously validated only its primary raster. Passing
a non-DataArray, a 1-D DataArray, or `None` for the secondary arg raised
a confusing AttributeError / IndexError from inside the implementation.
The fix adds `_validate_raster` on the secondary arg in the public API.
"""

import numpy as np
import pytest
import xarray as xr

from xrspatial.hydro import (
flow_direction_d8,
flow_direction_dinf,
flow_direction_mfd,
flow_path_d8,
flow_path_dinf,
flow_path_mfd,
snap_pour_point_d8,
stream_link_d8,
stream_link_dinf,
stream_link_mfd,
stream_order_d8,
stream_order_dinf,
stream_order_mfd,
watershed_d8,
watershed_dinf,
watershed_mfd,
)
from xrspatial.tests.general_checks import create_test_raster


def _elev():
"""Small bowl elevation raster used as the primary input."""
return create_test_raster(np.array([
[9, 9, 9, 9, 9],
[9, 8, 7, 6, 9],
[9, 7, 5, 4, 9],
[9, 6, 4, 3, 9],
[9, 9, 9, 9, 9],
], dtype=np.float64))


# ---------------------------------------------------------------------------
# watershed_*
# ---------------------------------------------------------------------------

class TestWatershedPourPoints:
def test_watershed_d8_rejects_non_dataarray_pour_points(self):
fd = flow_direction_d8(_elev())
with pytest.raises(TypeError):
watershed_d8(fd, pour_points=np.zeros((5, 5)))

def test_watershed_dinf_rejects_non_dataarray_pour_points(self):
fd = flow_direction_dinf(_elev())
with pytest.raises(TypeError):
watershed_dinf(fd, pour_points=np.zeros((5, 5)))

def test_watershed_mfd_rejects_non_dataarray_pour_points(self):
fd = flow_direction_mfd(_elev())
with pytest.raises(TypeError):
watershed_mfd(fd, pour_points=np.zeros((5, 5)))


# ---------------------------------------------------------------------------
# snap_pour_point_d8
# ---------------------------------------------------------------------------

class TestSnapPourPoint:
def test_rejects_non_dataarray_pour_points(self):
fa = create_test_raster(np.ones((5, 5), dtype=np.float64))
with pytest.raises(TypeError):
snap_pour_point_d8(fa, pour_points=np.zeros((5, 5)))


# ---------------------------------------------------------------------------
# flow_path_*
# ---------------------------------------------------------------------------

class TestFlowPathStartPoints:
def test_flow_path_d8_rejects_non_dataarray_start_points(self):
fd = flow_direction_d8(_elev())
with pytest.raises(TypeError):
flow_path_d8(fd, start_points=np.zeros((5, 5)))

def test_flow_path_dinf_rejects_non_dataarray_start_points(self):
fd = flow_direction_dinf(_elev())
with pytest.raises(TypeError):
flow_path_dinf(fd, start_points=np.zeros((5, 5)))

def test_flow_path_mfd_rejects_non_dataarray_start_points(self):
fd = flow_direction_mfd(_elev())
with pytest.raises(TypeError):
flow_path_mfd(fd, start_points=np.zeros((5, 5)))


# ---------------------------------------------------------------------------
# stream_link_*
# ---------------------------------------------------------------------------

class TestStreamLinkFlowAccum:
def test_stream_link_d8_rejects_non_dataarray_flow_accum(self):
fd = flow_direction_d8(_elev())
with pytest.raises(TypeError):
stream_link_d8(fd, flow_accum=np.zeros((5, 5)))

def test_stream_link_dinf_rejects_non_dataarray_flow_accum(self):
fd = flow_direction_dinf(_elev())
with pytest.raises(TypeError):
stream_link_dinf(fd, flow_accum=np.zeros((5, 5)))

def test_stream_link_mfd_rejects_non_dataarray_flow_accum(self):
fd = flow_direction_mfd(_elev())
with pytest.raises(TypeError):
stream_link_mfd(fd, flow_accum=np.zeros((5, 5)))


# ---------------------------------------------------------------------------
# stream_order_*
# ---------------------------------------------------------------------------

class TestStreamOrderFlowAccum:
def test_stream_order_d8_rejects_non_dataarray_flow_accum(self):
fd = flow_direction_d8(_elev())
with pytest.raises(TypeError):
stream_order_d8(fd, flow_accum=np.zeros((5, 5)))

def test_stream_order_dinf_rejects_non_dataarray_flow_accum(self):
fd = flow_direction_dinf(_elev())
with pytest.raises(TypeError):
stream_order_dinf(fd, flow_accum=np.zeros((5, 5)))

def test_stream_order_mfd_rejects_non_dataarray_flow_accum(self):
fd = flow_direction_mfd(_elev())
with pytest.raises(TypeError):
stream_order_mfd(fd, flow_accum=np.zeros((5, 5)))
1 change: 1 addition & 0 deletions xrspatial/hydro/watershed_d8.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,7 @@ def watershed_d8(flow_dir: xr.DataArray,
NaN for nodata or cells not reaching any pour point.
"""
_validate_raster(flow_dir, func_name='watershed', name='flow_dir')
_validate_raster(pour_points, func_name='watershed', name='pour_points')

data = flow_dir.data
pp_data = pour_points.data
Expand Down
2 changes: 2 additions & 0 deletions xrspatial/hydro/watershed_dinf.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,8 @@ def watershed_dinf(flow_dir_dinf: xr.DataArray,
"""
_validate_raster(flow_dir_dinf, func_name='watershed_dinf',
name='flow_dir_dinf')
_validate_raster(pour_points, func_name='watershed_dinf',
name='pour_points')

data = flow_dir_dinf.data
pp_data = pour_points.data
Expand Down
2 changes: 2 additions & 0 deletions xrspatial/hydro/watershed_mfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,8 @@ def watershed_mfd(flow_dir_mfd: xr.DataArray,
"""
_validate_raster(flow_dir_mfd, func_name='watershed_mfd',
name='flow_dir_mfd', ndim=3)
_validate_raster(pour_points, func_name='watershed_mfd',
name='pour_points')

data = flow_dir_mfd.data
pp_data = pour_points.data
Expand Down
Loading