From 6e7d48396b6bf213b2d0f914c0f2e5fe64926114 Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Wed, 29 Apr 2026 17:29:35 -0700 Subject: [PATCH] Reject complex dtypes in _validate_raster() (#1384) `_validate_raster` checked `np.issubdtype(dtype, np.number)` to enforce a numeric dtype. `complex64` and `complex128` are subtypes of `np.number`, so the check passed for complex DataArrays even though every consumer expects real-valued raster data. Downstream operations either raised a confusing TypeError mid-kernel or silently dropped the imaginary part. Tighten the dtype check to also exclude `np.complexfloating` and update the docstring to say "real numeric". Add tests covering complex64/128 rejection and float32/float64/int32/int64/uint8 acceptance. --- xrspatial/tests/test_utils.py | 22 ++++++++++++++++++++++ xrspatial/utils.py | 13 +++++++++---- 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/xrspatial/tests/test_utils.py b/xrspatial/tests/test_utils.py index b0e9af15..4cc631a8 100644 --- a/xrspatial/tests/test_utils.py +++ b/xrspatial/tests/test_utils.py @@ -95,3 +95,25 @@ def fake_get_dataarray_resolution(arr): utils.warn_if_unit_mismatch(da) assert len(w) == 0, "Expected no warnings when vertical units are angles" + + +# --------------------------------------------------------------------------- +# _validate_raster dtype handling +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("dtype", [np.complex64, np.complex128]) +def test_validate_raster_rejects_complex_dtype(dtype): + """Complex dtypes are not real numeric and must be rejected.""" + raster = xr.DataArray(np.zeros((4, 4), dtype=dtype)) + with pytest.raises(ValueError, match="real numeric"): + utils._validate_raster(raster, func_name="example") + + +@pytest.mark.parametrize( + "dtype", [np.float32, np.float64, np.int32, np.int64, np.uint8], +) +def test_validate_raster_accepts_real_numeric_dtypes(dtype): + """Integer and float dtypes pass the default numeric check.""" + raster = xr.DataArray(np.zeros((4, 4), dtype=dtype)) + # Should not raise. + utils._validate_raster(raster, func_name="example") diff --git a/xrspatial/utils.py b/xrspatial/utils.py index 26013b81..1d0e7d75 100644 --- a/xrspatial/utils.py +++ b/xrspatial/utils.py @@ -63,7 +63,9 @@ def _validate_raster( ndim : int, tuple of int, or None Allowed number of dimensions. ``None`` skips the check. numeric : bool - If True, require a numeric dtype (int or float). + If True, require a real numeric dtype (integer or float). + Complex dtypes are rejected because xrspatial operations + assume real-valued raster data. integer_only : bool If True, require an integer dtype specifically. @@ -97,10 +99,13 @@ def _validate_raster( f"got {agg.dtype}" ) else: - if not np.issubdtype(agg.dtype, np.number): + if ( + not np.issubdtype(agg.dtype, np.number) + or np.issubdtype(agg.dtype, np.complexfloating) + ): raise ValueError( - f"{func_name}(): `{name}` must have a numeric dtype " - f"(integer or float), got {agg.dtype}" + f"{func_name}(): `{name}` must have a real numeric " + f"dtype (integer or float), got {agg.dtype}" )