From 2aca2776f703193a820310c0993066f60cc219ed Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Mon, 4 May 2026 07:41:48 -0700 Subject: [PATCH] Fix dask reproject dtype and same-CRS dask merge The dask reproject path always built its output template with dtype float64, but the per-chunk numpy function clamps integer sources back to their original dtype. Dask then advertised float64 in its meta while the chunks underneath returned int. Pick the template dtype from the source instead, mirroring the eager path. Mirror the same integer round-trip in the cupy chunk for backend parity. The dask merge adapter always called the reproject chunk, even when source CRS matched target CRS. The eager merge has used the direct pixel placement shortcut (`_place_same_crs`) for that case, so dask and eager produced numerically different pixels for same-CRS inputs. Precompute `same_crs_list` in `_merge_dask` and pass it through, then let the per-block adapter try `_place_same_crs` first and fall back to the reproject chunk when resolutions are too far apart. Add tests for: - dask reproject preserving int8 / uint16 dtype - dask reproject keeping float32 -> float64 (regression guard) - same-CRS dask merge bit-equal to eager merge - different-CRS dask merge matching eager within rtol=1e-10 - end-to-end cupy parity vs numpy - end-to-end dask+cupy parity vs numpy Closes #1447 --- xrspatial/reproject/__init__.py | 83 +++++++++---- xrspatial/tests/test_reproject.py | 186 ++++++++++++++++++++++++++++++ 2 files changed, 249 insertions(+), 20 deletions(-) diff --git a/xrspatial/reproject/__init__.py b/xrspatial/reproject/__init__.py index ac9f93cd..4ecb031f 100644 --- a/xrspatial/reproject/__init__.py +++ b/xrspatial/reproject/__init__.py @@ -423,6 +423,7 @@ def _reproject_chunk_cupy( window = window.compute() if not isinstance(window, cp.ndarray): window = cp.asarray(window) + orig_dtype = window.dtype window = window.astype(cp.float64) # Adjust coordinates relative to window (stays on GPU if CuPy) @@ -432,16 +433,22 @@ def _reproject_chunk_cupy( if _use_native_cuda: # Coordinates are already CuPy arrays -- use native CUDA kernels # (nodata->NaN conversion is handled inside _resample_cupy_native) - return _resample_cupy_native(window, local_row, local_col, - resampling=resampling, nodata=nodata) + result = _resample_cupy_native(window, local_row, local_col, + resampling=resampling, nodata=nodata) + else: + # CPU coordinates -- convert sentinel nodata to NaN before map_coordinates + if not np.isnan(nodata): + window = window.copy() + window[window == nodata] = cp.nan - # CPU coordinates -- convert sentinel nodata to NaN before map_coordinates - if not np.isnan(nodata): - window = window.copy() - window[window == nodata] = cp.nan + result = _resample_cupy(window, local_row, local_col, + resampling=resampling, nodata=nodata) - return _resample_cupy(window, local_row, local_col, - resampling=resampling, nodata=nodata) + # Clamp and cast back for integer source dtypes (parity with numpy path) + if np.issubdtype(orig_dtype, np.integer): + info = np.iinfo(orig_dtype) + result = cp.clip(cp.round(result), info.min, info.max).astype(orig_dtype) + return result # --------------------------------------------------------------------------- @@ -1298,15 +1305,25 @@ def _reproject_dask( src_footprint_tgt=src_footprint_tgt, ) + # Pick the template dtype to match the eager path: integer sources + # round-trip back to their original dtype after clamping; floats stay + # float64. Without this, dask claims float64 meta but the chunks + # actually return the integer dtype, producing inconsistent output. + src_dtype = np.dtype(raster.dtype) + if np.issubdtype(src_dtype, np.integer): + out_dtype = src_dtype + else: + out_dtype = np.dtype(np.float64) + template = da.empty( - out_shape, dtype=np.float64, chunks=(row_chunks, col_chunks) + out_shape, dtype=out_dtype, chunks=(row_chunks, col_chunks) ) return da.map_blocks( bound_adapter, template, - dtype=np.float64, - meta=np.array((), dtype=np.float64), + dtype=out_dtype, + meta=np.array((), dtype=out_dtype), ) @@ -1583,7 +1600,7 @@ def _merge_block_adapter( src_wkt_list, tgt_wkt, out_bounds, out_shape, resampling, nodata, strategy, precision, - src_footprints_tgt, + src_footprints_tgt, same_crs_list, ): """``map_blocks`` adapter for merge.""" info = block_info[0] @@ -1598,14 +1615,30 @@ def _merge_block_adapter( if (src_footprints_tgt[i] is not None and not _bounds_overlap(cb, src_footprints_tgt[i])): continue - reprojected = _reproject_chunk_numpy( - raster_data_list[i], - src_bounds_list[i], src_shape_list[i], y_desc_list[i], - src_wkt_list[i], tgt_wkt, - cb, chunk_shape, - resampling, nodata, precision, - ) - arrays.append(reprojected) + + placed = None + if same_crs_list[i]: + # Same-CRS path: direct pixel placement (no resampling). + # Mirrors the eager merge so dask matches numpy bit-for-bit. + src_data = raster_data_list[i] + if hasattr(src_data, 'compute'): + src_data = src_data.compute() + placed = _place_same_crs( + np.asarray(src_data), + src_bounds_list[i], src_shape_list[i], y_desc_list[i], + cb, chunk_shape, nodata, + ) + if placed is not None: + arrays.append(placed) + else: + reprojected = _reproject_chunk_numpy( + raster_data_list[i], + src_bounds_list[i], src_shape_list[i], y_desc_list[i], + src_wkt_list[i], tgt_wkt, + cb, chunk_shape, + resampling, nodata, precision, + ) + arrays.append(reprojected) if not arrays: return np.full(chunk_shape, nodata, dtype=np.float64) @@ -1636,6 +1669,15 @@ def _merge_dask( for i in range(len(raster_infos)) ] + # Precompute CRS-equality flags so per-block adapters can shortcut to + # direct pixel placement (matches the eager _merge_inmemory path). + from ._crs_utils import _crs_from_wkt + tgt_crs = _crs_from_wkt(tgt_wkt) + same_crs_list = [ + bool(_crs_from_wkt(wkt_list[i]) == tgt_crs) + for i in range(len(raster_infos)) + ] + # Bind via partial to prevent map_blocks from adding dask arrays # in data_list as whole-array dependencies. bound_adapter = functools.partial( @@ -1653,6 +1695,7 @@ def _merge_dask( strategy=strategy, precision=16, src_footprints_tgt=footprints, + same_crs_list=same_crs_list, ) template = da.empty( diff --git a/xrspatial/tests/test_reproject.py b/xrspatial/tests/test_reproject.py index 704469c4..3fec9bf1 100644 --- a/xrspatial/tests/test_reproject.py +++ b/xrspatial/tests/test_reproject.py @@ -1709,3 +1709,189 @@ def test_detect_nodata_accepts_finite(self): from xrspatial.reproject._crs_utils import _detect_nodata r = xr.DataArray(np.zeros((4, 4)), dims=('y', 'x')) assert _detect_nodata(r, nodata=-9999) == -9999.0 + + +# --------------------------------------------------------------------------- +# Backend parity: dask dtype + same-CRS dask merge + cupy +# --------------------------------------------------------------------------- + +@pytest.mark.skipif(not HAS_DASK, reason="dask required") +class TestDaskDtypeParity: + """Dask reproject should preserve source integer dtype (matches numpy).""" + + def test_dask_reproject_int8_preserves_dtype(self): + from xrspatial.reproject import reproject + data = np.arange(64, dtype=np.int8).reshape(8, 8) + coords = {'y': np.linspace(5, -5, 8), 'x': np.linspace(-5, 5, 8)} + attrs = {'crs': 'EPSG:4326', 'nodata': -1} + raster = xr.DataArray( + da.from_array(data, chunks=(4, 4)), + dims=['y', 'x'], coords=coords, attrs=attrs, + ) + result = reproject(raster, 'EPSG:4326', resolution=1.0) + # Lazy meta dtype should match + assert result.data.dtype == np.int8 + # Computed dtype should also match + assert result.compute().dtype == np.int8 + + def test_dask_reproject_uint16_preserves_dtype(self): + from xrspatial.reproject import reproject + data = (np.arange(64, dtype=np.uint16) * 100).reshape(8, 8) + coords = {'y': np.linspace(5, -5, 8), 'x': np.linspace(-5, 5, 8)} + attrs = {'crs': 'EPSG:4326', 'nodata': 0} + raster = xr.DataArray( + da.from_array(data, chunks=(4, 4)), + dims=['y', 'x'], coords=coords, attrs=attrs, + ) + result = reproject(raster, 'EPSG:4326', resolution=1.0) + assert result.data.dtype == np.uint16 + assert result.compute().dtype == np.uint16 + + def test_dask_reproject_float32_stays_float64(self): + """Float input still upcasts to float64 (existing behaviour guard).""" + from xrspatial.reproject import reproject + data = np.random.RandomState(0).rand(8, 8).astype(np.float32) + coords = {'y': np.linspace(5, -5, 8), 'x': np.linspace(-5, 5, 8)} + attrs = {'crs': 'EPSG:4326', 'nodata': np.nan} + raster = xr.DataArray( + da.from_array(data, chunks=(4, 4)), + dims=['y', 'x'], coords=coords, attrs=attrs, + ) + result = reproject(raster, 'EPSG:4326', resolution=1.0) + assert result.data.dtype == np.float64 + assert result.compute().dtype == np.float64 + + +@pytest.mark.skipif(not HAS_DASK, reason="dask required") +class TestMergeDaskParity: + """Dask merge should match the eager numpy merge.""" + + def test_merge_dask_same_crs_matches_eager(self): + """Same-CRS merge should be bit-equal between eager and dask paths. + + Source and output resolutions match (within 1%) so + ``_place_same_crs`` activates in both paths -- direct pixel copy + means the dask result must equal the eager result bit-for-bit. + """ + from xrspatial.reproject import merge + # 16 pixels with center-to-center spacing of exactly 1.0 -> bounds + # extend half a pixel past coords, source resolution matches output. + a_data = np.arange(256, dtype=np.float64).reshape(16, 16) + b_data = (np.arange(256, dtype=np.float64) * 2).reshape(16, 16) + a = _make_raster(a_data, x_range=(-7.5, 7.5), y_range=(-7.5, 7.5)) + b = _make_raster(b_data, x_range=(8.5, 23.5), y_range=(-7.5, 7.5)) + + eager = merge([a, b], resolution=1.0).compute().values + + a_dask = a.copy() + b_dask = b.copy() + a_dask.data = da.from_array(a_data, chunks=(8, 8)) + b_dask.data = da.from_array(b_data, chunks=(8, 8)) + dasked = merge( + [a_dask, b_dask], resolution=1.0, chunk_size=8, + ).compute().values + + assert eager.shape == dasked.shape + eager_nan = np.isnan(eager) + dask_nan = np.isnan(dasked) + np.testing.assert_array_equal(eager_nan, dask_nan) + # Finite values must be bit-equal: same-CRS path is direct copy + np.testing.assert_array_equal(eager[~eager_nan], dasked[~dask_nan]) + + def test_merge_dask_different_crs_matches_eager(self): + """Different-CRS merge should match within float tolerance.""" + from xrspatial.reproject import merge + a_data = np.arange(256, dtype=np.float64).reshape(16, 16) + b_data = (np.arange(256, dtype=np.float64) + 100.0).reshape(16, 16) + # One in WGS84, one in Web Mercator (forces reprojection) + a = _make_raster(a_data, crs='EPSG:4326', + x_range=(-10, 0), y_range=(-5, 5)) + # Build a Web-Mercator tile that overlaps the target + b = _make_raster(b_data, crs='EPSG:3857', + x_range=(0, 1_000_000), y_range=(-500_000, 500_000)) + + eager = merge( + [a, b], target_crs='EPSG:4326', resolution=1.0, + ).compute().values + + a_dask = a.copy() + b_dask = b.copy() + a_dask.data = da.from_array(a_data, chunks=(8, 8)) + b_dask.data = da.from_array(b_data, chunks=(8, 8)) + dasked = merge( + [a_dask, b_dask], target_crs='EPSG:4326', + resolution=1.0, chunk_size=8, + ).compute().values + + assert eager.shape == dasked.shape + np.testing.assert_array_equal(np.isnan(eager), np.isnan(dasked)) + finite = np.isfinite(eager) + if finite.any(): + np.testing.assert_allclose( + eager[finite], dasked[finite], rtol=1e-10, atol=1e-10, + ) + + +@pytest.mark.skipif(not HAS_CUPY, reason="cupy required") +class TestCupyReprojectParity: + """End-to-end cupy backend parity checks.""" + + def test_cupy_reproject_matches_numpy(self): + from xrspatial.reproject import reproject + data = np.random.RandomState(7).rand(32, 32).astype(np.float64) + coords = {'y': np.linspace(55, 45, 32), 'x': np.linspace(-5, 5, 32)} + attrs = {'crs': 'EPSG:4326', 'nodata': np.nan} + + np_raster = xr.DataArray(data, dims=['y', 'x'], + coords=coords, attrs=attrs) + cp_raster = xr.DataArray(cp.asarray(data), dims=['y', 'x'], + coords=coords, attrs=attrs) + np_result = reproject(np_raster, 'EPSG:3857').values + cp_result_arr = reproject(cp_raster, 'EPSG:3857').data + # cupy DataArray: pull through .get() to avoid implicit numpy convert + if hasattr(cp_result_arr, 'get'): + cp_vals = cp_result_arr.get() + else: + cp_vals = np.asarray(cp_result_arr) + + assert np_result.shape == cp_vals.shape + np.testing.assert_array_equal( + np.isnan(np_result), np.isnan(cp_vals), + ) + finite = np.isfinite(np_result) + if finite.any(): + np.testing.assert_allclose( + np_result[finite], cp_vals[finite], rtol=1e-5, atol=1e-5, + ) + + @pytest.mark.skipif(not HAS_DASK, reason="dask required") + def test_dask_cupy_reproject_matches_numpy(self): + from xrspatial.reproject import reproject + data = np.random.RandomState(11).rand(32, 32).astype(np.float64) + coords = {'y': np.linspace(55, 45, 32), 'x': np.linspace(-5, 5, 32)} + attrs = {'crs': 'EPSG:4326', 'nodata': np.nan} + + np_raster = xr.DataArray(data, dims=['y', 'x'], + coords=coords, attrs=attrs) + dc_raster = xr.DataArray( + da.from_array(cp.asarray(data), chunks=(16, 16)), + dims=['y', 'x'], coords=coords, attrs=attrs, + ) + np_result = reproject(np_raster, 'EPSG:3857').values + dc_arr = reproject(dc_raster, 'EPSG:3857').data + if hasattr(dc_arr, 'compute'): + dc_arr = dc_arr.compute() + if hasattr(dc_arr, 'get'): + dc_vals = dc_arr.get() + else: + dc_vals = np.asarray(dc_arr) + + assert np_result.shape == dc_vals.shape + np.testing.assert_array_equal( + np.isnan(np_result), np.isnan(dc_vals), + ) + finite = np.isfinite(np_result) + if finite.any(): + np.testing.assert_allclose( + np_result[finite], dc_vals[finite], rtol=1e-5, atol=1e-5, + )