diff --git a/xrspatial/multispectral.py b/xrspatial/multispectral.py index 4fa5c9dd..7e6fcd8c 100644 --- a/xrspatial/multispectral.py +++ b/xrspatial/multispectral.py @@ -1793,6 +1793,8 @@ def true_color(r, g, b, nodata=1, c=10.0, th=0.125, name='true_color'): _validate_raster(g, func_name='true_color', name='g') _validate_raster(b, func_name='true_color', name='b') + validate_arrays(r, g, b) + mapper = ArrayTypeFunctionMapping( numpy_func=_true_color_numpy, dask_func=_true_color_dask, diff --git a/xrspatial/tests/test_multispectral.py b/xrspatial/tests/test_multispectral.py index 08eccec4..e2054f85 100644 --- a/xrspatial/tests/test_multispectral.py +++ b/xrspatial/tests/test_multispectral.py @@ -840,6 +840,50 @@ def test_true_color_gpu_memory_guard_raises_when_oversized(monkeypatch): multispectral._check_true_color_gpu_memory(100_000, 100_000) +def test_true_color_mismatched_shapes_raises(): + red = xr.DataArray(np.ones((4, 4), dtype=np.float32), dims=['y', 'x']) + red = red.assign_coords(y=np.arange(4), x=np.arange(4)) + green = xr.DataArray(np.ones((4, 5), dtype=np.float32), dims=['y', 'x']) + green = green.assign_coords(y=np.arange(4), x=np.arange(5)) + blue = xr.DataArray(np.ones((4, 4), dtype=np.float32), dims=['y', 'x']) + blue = blue.assign_coords(y=np.arange(4), x=np.arange(4)) + + with pytest.raises(ValueError, match='equal shapes'): + true_color(red, green, blue) + + +@dask_array_available +def test_true_color_mismatched_shapes_raises_dask(): + import dask.array as da + red = xr.DataArray( + da.ones((4, 4), chunks=(2, 2), dtype=np.float32), dims=['y', 'x']) + red = red.assign_coords(y=np.arange(4), x=np.arange(4)) + green = xr.DataArray( + da.ones((4, 5), chunks=(2, 2), dtype=np.float32), dims=['y', 'x']) + green = green.assign_coords(y=np.arange(4), x=np.arange(5)) + blue = xr.DataArray( + da.ones((4, 4), chunks=(2, 2), dtype=np.float32), dims=['y', 'x']) + blue = blue.assign_coords(y=np.arange(4), x=np.arange(4)) + + with pytest.raises(ValueError, match='equal shapes'): + true_color(red, green, blue) + + +def test_true_color_mismatched_backends_raises(): + pytest.importorskip('dask.array') + import dask.array as da + red = xr.DataArray(np.ones((4, 4), dtype=np.float32), dims=['y', 'x']) + red = red.assign_coords(y=np.arange(4), x=np.arange(4)) + green = xr.DataArray( + da.ones((4, 4), chunks=(2, 2), dtype=np.float32), dims=['y', 'x']) + green = green.assign_coords(y=np.arange(4), x=np.arange(4)) + blue = xr.DataArray(np.ones((4, 4), dtype=np.float32), dims=['y', 'x']) + blue = blue.assign_coords(y=np.arange(4), x=np.arange(4)) + + with pytest.raises(ValueError, match='same type'): + true_color(red, green, blue) + + # NDSI ---------- @pytest.fixture def expected_ndsi():