From d56b3ccbab5d0b52873c9efa8f637e1f4bb663a3 Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Tue, 28 Apr 2026 06:37:12 -0700 Subject: [PATCH] Enforce equal band shapes in true_color() (#1293) Add validate_arrays(r, g, b) so mismatched band shapes raise a clean ValueError up front instead of producing a generic broadcast error from numpy, a misleading CuPy error, or a silently misaligned dask cube. Brings true_color() in line with every other public function in multispectral.py. --- xrspatial/multispectral.py | 2 ++ xrspatial/tests/test_multispectral.py | 44 +++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/xrspatial/multispectral.py b/xrspatial/multispectral.py index 4fa5c9dd..7e6fcd8c 100644 --- a/xrspatial/multispectral.py +++ b/xrspatial/multispectral.py @@ -1793,6 +1793,8 @@ def true_color(r, g, b, nodata=1, c=10.0, th=0.125, name='true_color'): _validate_raster(g, func_name='true_color', name='g') _validate_raster(b, func_name='true_color', name='b') + validate_arrays(r, g, b) + mapper = ArrayTypeFunctionMapping( numpy_func=_true_color_numpy, dask_func=_true_color_dask, diff --git a/xrspatial/tests/test_multispectral.py b/xrspatial/tests/test_multispectral.py index 08eccec4..e2054f85 100644 --- a/xrspatial/tests/test_multispectral.py +++ b/xrspatial/tests/test_multispectral.py @@ -840,6 +840,50 @@ def test_true_color_gpu_memory_guard_raises_when_oversized(monkeypatch): multispectral._check_true_color_gpu_memory(100_000, 100_000) +def test_true_color_mismatched_shapes_raises(): + red = xr.DataArray(np.ones((4, 4), dtype=np.float32), dims=['y', 'x']) + red = red.assign_coords(y=np.arange(4), x=np.arange(4)) + green = xr.DataArray(np.ones((4, 5), dtype=np.float32), dims=['y', 'x']) + green = green.assign_coords(y=np.arange(4), x=np.arange(5)) + blue = xr.DataArray(np.ones((4, 4), dtype=np.float32), dims=['y', 'x']) + blue = blue.assign_coords(y=np.arange(4), x=np.arange(4)) + + with pytest.raises(ValueError, match='equal shapes'): + true_color(red, green, blue) + + +@dask_array_available +def test_true_color_mismatched_shapes_raises_dask(): + import dask.array as da + red = xr.DataArray( + da.ones((4, 4), chunks=(2, 2), dtype=np.float32), dims=['y', 'x']) + red = red.assign_coords(y=np.arange(4), x=np.arange(4)) + green = xr.DataArray( + da.ones((4, 5), chunks=(2, 2), dtype=np.float32), dims=['y', 'x']) + green = green.assign_coords(y=np.arange(4), x=np.arange(5)) + blue = xr.DataArray( + da.ones((4, 4), chunks=(2, 2), dtype=np.float32), dims=['y', 'x']) + blue = blue.assign_coords(y=np.arange(4), x=np.arange(4)) + + with pytest.raises(ValueError, match='equal shapes'): + true_color(red, green, blue) + + +def test_true_color_mismatched_backends_raises(): + pytest.importorskip('dask.array') + import dask.array as da + red = xr.DataArray(np.ones((4, 4), dtype=np.float32), dims=['y', 'x']) + red = red.assign_coords(y=np.arange(4), x=np.arange(4)) + green = xr.DataArray( + da.ones((4, 4), chunks=(2, 2), dtype=np.float32), dims=['y', 'x']) + green = green.assign_coords(y=np.arange(4), x=np.arange(4)) + blue = xr.DataArray(np.ones((4, 4), dtype=np.float32), dims=['y', 'x']) + blue = blue.assign_coords(y=np.arange(4), x=np.arange(4)) + + with pytest.raises(ValueError, match='same type'): + true_color(red, green, blue) + + # NDSI ---------- @pytest.fixture def expected_ndsi():