diff --git a/xrspatial/reproject/__init__.py b/xrspatial/reproject/__init__.py index ac9f93cd..4ecb031f 100644 --- a/xrspatial/reproject/__init__.py +++ b/xrspatial/reproject/__init__.py @@ -423,6 +423,7 @@ def _reproject_chunk_cupy( window = window.compute() if not isinstance(window, cp.ndarray): window = cp.asarray(window) + orig_dtype = window.dtype window = window.astype(cp.float64) # Adjust coordinates relative to window (stays on GPU if CuPy) @@ -432,16 +433,22 @@ def _reproject_chunk_cupy( if _use_native_cuda: # Coordinates are already CuPy arrays -- use native CUDA kernels # (nodata->NaN conversion is handled inside _resample_cupy_native) - return _resample_cupy_native(window, local_row, local_col, - resampling=resampling, nodata=nodata) + result = _resample_cupy_native(window, local_row, local_col, + resampling=resampling, nodata=nodata) + else: + # CPU coordinates -- convert sentinel nodata to NaN before map_coordinates + if not np.isnan(nodata): + window = window.copy() + window[window == nodata] = cp.nan - # CPU coordinates -- convert sentinel nodata to NaN before map_coordinates - if not np.isnan(nodata): - window = window.copy() - window[window == nodata] = cp.nan + result = _resample_cupy(window, local_row, local_col, + resampling=resampling, nodata=nodata) - return _resample_cupy(window, local_row, local_col, - resampling=resampling, nodata=nodata) + # Clamp and cast back for integer source dtypes (parity with numpy path) + if np.issubdtype(orig_dtype, np.integer): + info = np.iinfo(orig_dtype) + result = cp.clip(cp.round(result), info.min, info.max).astype(orig_dtype) + return result # --------------------------------------------------------------------------- @@ -1298,15 +1305,25 @@ def _reproject_dask( src_footprint_tgt=src_footprint_tgt, ) + # Pick the template dtype to match the eager path: integer sources + # round-trip back to their original dtype after clamping; floats stay + # float64. Without this, dask claims float64 meta but the chunks + # actually return the integer dtype, producing inconsistent output. + src_dtype = np.dtype(raster.dtype) + if np.issubdtype(src_dtype, np.integer): + out_dtype = src_dtype + else: + out_dtype = np.dtype(np.float64) + template = da.empty( - out_shape, dtype=np.float64, chunks=(row_chunks, col_chunks) + out_shape, dtype=out_dtype, chunks=(row_chunks, col_chunks) ) return da.map_blocks( bound_adapter, template, - dtype=np.float64, - meta=np.array((), dtype=np.float64), + dtype=out_dtype, + meta=np.array((), dtype=out_dtype), ) @@ -1583,7 +1600,7 @@ def _merge_block_adapter( src_wkt_list, tgt_wkt, out_bounds, out_shape, resampling, nodata, strategy, precision, - src_footprints_tgt, + src_footprints_tgt, same_crs_list, ): """``map_blocks`` adapter for merge.""" info = block_info[0] @@ -1598,14 +1615,30 @@ def _merge_block_adapter( if (src_footprints_tgt[i] is not None and not _bounds_overlap(cb, src_footprints_tgt[i])): continue - reprojected = _reproject_chunk_numpy( - raster_data_list[i], - src_bounds_list[i], src_shape_list[i], y_desc_list[i], - src_wkt_list[i], tgt_wkt, - cb, chunk_shape, - resampling, nodata, precision, - ) - arrays.append(reprojected) + + placed = None + if same_crs_list[i]: + # Same-CRS path: direct pixel placement (no resampling). + # Mirrors the eager merge so dask matches numpy bit-for-bit. + src_data = raster_data_list[i] + if hasattr(src_data, 'compute'): + src_data = src_data.compute() + placed = _place_same_crs( + np.asarray(src_data), + src_bounds_list[i], src_shape_list[i], y_desc_list[i], + cb, chunk_shape, nodata, + ) + if placed is not None: + arrays.append(placed) + else: + reprojected = _reproject_chunk_numpy( + raster_data_list[i], + src_bounds_list[i], src_shape_list[i], y_desc_list[i], + src_wkt_list[i], tgt_wkt, + cb, chunk_shape, + resampling, nodata, precision, + ) + arrays.append(reprojected) if not arrays: return np.full(chunk_shape, nodata, dtype=np.float64) @@ -1636,6 +1669,15 @@ def _merge_dask( for i in range(len(raster_infos)) ] + # Precompute CRS-equality flags so per-block adapters can shortcut to + # direct pixel placement (matches the eager _merge_inmemory path). + from ._crs_utils import _crs_from_wkt + tgt_crs = _crs_from_wkt(tgt_wkt) + same_crs_list = [ + bool(_crs_from_wkt(wkt_list[i]) == tgt_crs) + for i in range(len(raster_infos)) + ] + # Bind via partial to prevent map_blocks from adding dask arrays # in data_list as whole-array dependencies. bound_adapter = functools.partial( @@ -1653,6 +1695,7 @@ def _merge_dask( strategy=strategy, precision=16, src_footprints_tgt=footprints, + same_crs_list=same_crs_list, ) template = da.empty( diff --git a/xrspatial/tests/test_reproject.py b/xrspatial/tests/test_reproject.py index 704469c4..3fec9bf1 100644 --- a/xrspatial/tests/test_reproject.py +++ b/xrspatial/tests/test_reproject.py @@ -1709,3 +1709,189 @@ def test_detect_nodata_accepts_finite(self): from xrspatial.reproject._crs_utils import _detect_nodata r = xr.DataArray(np.zeros((4, 4)), dims=('y', 'x')) assert _detect_nodata(r, nodata=-9999) == -9999.0 + + +# --------------------------------------------------------------------------- +# Backend parity: dask dtype + same-CRS dask merge + cupy +# --------------------------------------------------------------------------- + +@pytest.mark.skipif(not HAS_DASK, reason="dask required") +class TestDaskDtypeParity: + """Dask reproject should preserve source integer dtype (matches numpy).""" + + def test_dask_reproject_int8_preserves_dtype(self): + from xrspatial.reproject import reproject + data = np.arange(64, dtype=np.int8).reshape(8, 8) + coords = {'y': np.linspace(5, -5, 8), 'x': np.linspace(-5, 5, 8)} + attrs = {'crs': 'EPSG:4326', 'nodata': -1} + raster = xr.DataArray( + da.from_array(data, chunks=(4, 4)), + dims=['y', 'x'], coords=coords, attrs=attrs, + ) + result = reproject(raster, 'EPSG:4326', resolution=1.0) + # Lazy meta dtype should match + assert result.data.dtype == np.int8 + # Computed dtype should also match + assert result.compute().dtype == np.int8 + + def test_dask_reproject_uint16_preserves_dtype(self): + from xrspatial.reproject import reproject + data = (np.arange(64, dtype=np.uint16) * 100).reshape(8, 8) + coords = {'y': np.linspace(5, -5, 8), 'x': np.linspace(-5, 5, 8)} + attrs = {'crs': 'EPSG:4326', 'nodata': 0} + raster = xr.DataArray( + da.from_array(data, chunks=(4, 4)), + dims=['y', 'x'], coords=coords, attrs=attrs, + ) + result = reproject(raster, 'EPSG:4326', resolution=1.0) + assert result.data.dtype == np.uint16 + assert result.compute().dtype == np.uint16 + + def test_dask_reproject_float32_stays_float64(self): + """Float input still upcasts to float64 (existing behaviour guard).""" + from xrspatial.reproject import reproject + data = np.random.RandomState(0).rand(8, 8).astype(np.float32) + coords = {'y': np.linspace(5, -5, 8), 'x': np.linspace(-5, 5, 8)} + attrs = {'crs': 'EPSG:4326', 'nodata': np.nan} + raster = xr.DataArray( + da.from_array(data, chunks=(4, 4)), + dims=['y', 'x'], coords=coords, attrs=attrs, + ) + result = reproject(raster, 'EPSG:4326', resolution=1.0) + assert result.data.dtype == np.float64 + assert result.compute().dtype == np.float64 + + +@pytest.mark.skipif(not HAS_DASK, reason="dask required") +class TestMergeDaskParity: + """Dask merge should match the eager numpy merge.""" + + def test_merge_dask_same_crs_matches_eager(self): + """Same-CRS merge should be bit-equal between eager and dask paths. + + Source and output resolutions match (within 1%) so + ``_place_same_crs`` activates in both paths -- direct pixel copy + means the dask result must equal the eager result bit-for-bit. + """ + from xrspatial.reproject import merge + # 16 pixels with center-to-center spacing of exactly 1.0 -> bounds + # extend half a pixel past coords, source resolution matches output. + a_data = np.arange(256, dtype=np.float64).reshape(16, 16) + b_data = (np.arange(256, dtype=np.float64) * 2).reshape(16, 16) + a = _make_raster(a_data, x_range=(-7.5, 7.5), y_range=(-7.5, 7.5)) + b = _make_raster(b_data, x_range=(8.5, 23.5), y_range=(-7.5, 7.5)) + + eager = merge([a, b], resolution=1.0).compute().values + + a_dask = a.copy() + b_dask = b.copy() + a_dask.data = da.from_array(a_data, chunks=(8, 8)) + b_dask.data = da.from_array(b_data, chunks=(8, 8)) + dasked = merge( + [a_dask, b_dask], resolution=1.0, chunk_size=8, + ).compute().values + + assert eager.shape == dasked.shape + eager_nan = np.isnan(eager) + dask_nan = np.isnan(dasked) + np.testing.assert_array_equal(eager_nan, dask_nan) + # Finite values must be bit-equal: same-CRS path is direct copy + np.testing.assert_array_equal(eager[~eager_nan], dasked[~dask_nan]) + + def test_merge_dask_different_crs_matches_eager(self): + """Different-CRS merge should match within float tolerance.""" + from xrspatial.reproject import merge + a_data = np.arange(256, dtype=np.float64).reshape(16, 16) + b_data = (np.arange(256, dtype=np.float64) + 100.0).reshape(16, 16) + # One in WGS84, one in Web Mercator (forces reprojection) + a = _make_raster(a_data, crs='EPSG:4326', + x_range=(-10, 0), y_range=(-5, 5)) + # Build a Web-Mercator tile that overlaps the target + b = _make_raster(b_data, crs='EPSG:3857', + x_range=(0, 1_000_000), y_range=(-500_000, 500_000)) + + eager = merge( + [a, b], target_crs='EPSG:4326', resolution=1.0, + ).compute().values + + a_dask = a.copy() + b_dask = b.copy() + a_dask.data = da.from_array(a_data, chunks=(8, 8)) + b_dask.data = da.from_array(b_data, chunks=(8, 8)) + dasked = merge( + [a_dask, b_dask], target_crs='EPSG:4326', + resolution=1.0, chunk_size=8, + ).compute().values + + assert eager.shape == dasked.shape + np.testing.assert_array_equal(np.isnan(eager), np.isnan(dasked)) + finite = np.isfinite(eager) + if finite.any(): + np.testing.assert_allclose( + eager[finite], dasked[finite], rtol=1e-10, atol=1e-10, + ) + + +@pytest.mark.skipif(not HAS_CUPY, reason="cupy required") +class TestCupyReprojectParity: + """End-to-end cupy backend parity checks.""" + + def test_cupy_reproject_matches_numpy(self): + from xrspatial.reproject import reproject + data = np.random.RandomState(7).rand(32, 32).astype(np.float64) + coords = {'y': np.linspace(55, 45, 32), 'x': np.linspace(-5, 5, 32)} + attrs = {'crs': 'EPSG:4326', 'nodata': np.nan} + + np_raster = xr.DataArray(data, dims=['y', 'x'], + coords=coords, attrs=attrs) + cp_raster = xr.DataArray(cp.asarray(data), dims=['y', 'x'], + coords=coords, attrs=attrs) + np_result = reproject(np_raster, 'EPSG:3857').values + cp_result_arr = reproject(cp_raster, 'EPSG:3857').data + # cupy DataArray: pull through .get() to avoid implicit numpy convert + if hasattr(cp_result_arr, 'get'): + cp_vals = cp_result_arr.get() + else: + cp_vals = np.asarray(cp_result_arr) + + assert np_result.shape == cp_vals.shape + np.testing.assert_array_equal( + np.isnan(np_result), np.isnan(cp_vals), + ) + finite = np.isfinite(np_result) + if finite.any(): + np.testing.assert_allclose( + np_result[finite], cp_vals[finite], rtol=1e-5, atol=1e-5, + ) + + @pytest.mark.skipif(not HAS_DASK, reason="dask required") + def test_dask_cupy_reproject_matches_numpy(self): + from xrspatial.reproject import reproject + data = np.random.RandomState(11).rand(32, 32).astype(np.float64) + coords = {'y': np.linspace(55, 45, 32), 'x': np.linspace(-5, 5, 32)} + attrs = {'crs': 'EPSG:4326', 'nodata': np.nan} + + np_raster = xr.DataArray(data, dims=['y', 'x'], + coords=coords, attrs=attrs) + dc_raster = xr.DataArray( + da.from_array(cp.asarray(data), chunks=(16, 16)), + dims=['y', 'x'], coords=coords, attrs=attrs, + ) + np_result = reproject(np_raster, 'EPSG:3857').values + dc_arr = reproject(dc_raster, 'EPSG:3857').data + if hasattr(dc_arr, 'compute'): + dc_arr = dc_arr.compute() + if hasattr(dc_arr, 'get'): + dc_vals = dc_arr.get() + else: + dc_vals = np.asarray(dc_arr) + + assert np_result.shape == dc_vals.shape + np.testing.assert_array_equal( + np.isnan(np_result), np.isnan(dc_vals), + ) + finite = np.isfinite(np_result) + if finite.any(): + np.testing.assert_allclose( + np_result[finite], dc_vals[finite], rtol=1e-5, atol=1e-5, + )