diff --git a/xrspatial/flood.py b/xrspatial/flood.py index efc983ab..e0adb94a 100644 --- a/xrspatial/flood.py +++ b/xrspatial/flood.py @@ -328,8 +328,8 @@ def _cn_runoff_numpy(p, cn): s = (25400.0 / cn) - 254.0 ia = 0.2 * s q = np.where(p > ia, (p - ia) ** 2 / (p + 0.8 * s), 0.0) - # propagate NaN from rainfall - q = np.where(np.isnan(p), np.nan, q) + # propagate NaN from rainfall or curve number + q = np.where(np.isnan(p) | np.isnan(cn), np.nan, q) return q @@ -340,7 +340,7 @@ def _cn_runoff_cupy(p, cn): s = (25400.0 / cn) - 254.0 ia = 0.2 * s q = cp.where(p > ia, (p - ia) ** 2 / (p + 0.8 * s), 0.0) - q = cp.where(cp.isnan(p), cp.nan, q) + q = cp.where(cp.isnan(p) | cp.isnan(cn), cp.nan, q) return q @@ -349,7 +349,7 @@ def _cn_runoff_dask(p, cn): s = (25400.0 / cn) - 254.0 ia = 0.2 * s q = _da.where(p > ia, (p - ia) ** 2 / (p + 0.8 * s), 0.0) - q = _da.where(_da.isnan(p), np.nan, q) + q = _da.where(_da.isnan(p) | _da.isnan(cn), np.nan, q) return q diff --git a/xrspatial/tests/test_flood.py b/xrspatial/tests/test_flood.py index e3c7304e..b30f4e91 100644 --- a/xrspatial/tests/test_flood.py +++ b/xrspatial/tests/test_flood.py @@ -362,6 +362,32 @@ def test_numpy_equals_dask_cupy(self): expected_results=result_np.data) +def test_cn_runoff_nan_curve_number_1104(): + """NaN in curve_number should produce NaN output, not 0. + + Regression test for #1104: P > NaN is always False, so np.where + took the else-branch and wrote 0.0 instead of NaN. + """ + rainfall = xr.DataArray( + np.array([[100.0, 100.0, 100.0]], dtype=np.float64) + ) + cn_data = np.array([[80.0, np.nan, 90.0]], dtype=np.float64) + cn_raster = xr.DataArray(cn_data) + + result = curve_number_runoff(rainfall, curve_number=cn_raster) + data = result.data + if hasattr(data, 'compute'): + data = data.compute() + data = np.asarray(data) + + # Cell 0 (CN=80): valid runoff + assert np.isfinite(data[0, 0]) and data[0, 0] > 0 + # Cell 1 (CN=NaN): must be NaN, not 0 + assert np.isnan(data[0, 1]), f"expected NaN, got {data[0, 1]}" + # Cell 2 (CN=90): valid runoff + assert np.isfinite(data[0, 2]) and data[0, 2] > 0 + + # =================================================================== # travel_time # ===================================================================