Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 63 additions & 20 deletions xrspatial/reproject/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,7 @@ def _reproject_chunk_cupy(
window = window.compute()
if not isinstance(window, cp.ndarray):
window = cp.asarray(window)
orig_dtype = window.dtype
window = window.astype(cp.float64)

# Adjust coordinates relative to window (stays on GPU if CuPy)
Expand All @@ -432,16 +433,22 @@ def _reproject_chunk_cupy(
if _use_native_cuda:
# Coordinates are already CuPy arrays -- use native CUDA kernels
# (nodata->NaN conversion is handled inside _resample_cupy_native)
return _resample_cupy_native(window, local_row, local_col,
resampling=resampling, nodata=nodata)
result = _resample_cupy_native(window, local_row, local_col,
resampling=resampling, nodata=nodata)
else:
# CPU coordinates -- convert sentinel nodata to NaN before map_coordinates
if not np.isnan(nodata):
window = window.copy()
window[window == nodata] = cp.nan

# CPU coordinates -- convert sentinel nodata to NaN before map_coordinates
if not np.isnan(nodata):
window = window.copy()
window[window == nodata] = cp.nan
result = _resample_cupy(window, local_row, local_col,
resampling=resampling, nodata=nodata)

return _resample_cupy(window, local_row, local_col,
resampling=resampling, nodata=nodata)
# Clamp and cast back for integer source dtypes (parity with numpy path)
if np.issubdtype(orig_dtype, np.integer):
info = np.iinfo(orig_dtype)
result = cp.clip(cp.round(result), info.min, info.max).astype(orig_dtype)
return result


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -1298,15 +1305,25 @@ def _reproject_dask(
src_footprint_tgt=src_footprint_tgt,
)

# Pick the template dtype to match the eager path: integer sources
# round-trip back to their original dtype after clamping; floats stay
# float64. Without this, dask claims float64 meta but the chunks
# actually return the integer dtype, producing inconsistent output.
src_dtype = np.dtype(raster.dtype)
if np.issubdtype(src_dtype, np.integer):
out_dtype = src_dtype
else:
out_dtype = np.dtype(np.float64)

template = da.empty(
out_shape, dtype=np.float64, chunks=(row_chunks, col_chunks)
out_shape, dtype=out_dtype, chunks=(row_chunks, col_chunks)
)

return da.map_blocks(
bound_adapter,
template,
dtype=np.float64,
meta=np.array((), dtype=np.float64),
dtype=out_dtype,
meta=np.array((), dtype=out_dtype),
)


Expand Down Expand Up @@ -1583,7 +1600,7 @@ def _merge_block_adapter(
src_wkt_list, tgt_wkt,
out_bounds, out_shape,
resampling, nodata, strategy, precision,
src_footprints_tgt,
src_footprints_tgt, same_crs_list,
):
"""``map_blocks`` adapter for merge."""
info = block_info[0]
Expand All @@ -1598,14 +1615,30 @@ def _merge_block_adapter(
if (src_footprints_tgt[i] is not None
and not _bounds_overlap(cb, src_footprints_tgt[i])):
continue
reprojected = _reproject_chunk_numpy(
raster_data_list[i],
src_bounds_list[i], src_shape_list[i], y_desc_list[i],
src_wkt_list[i], tgt_wkt,
cb, chunk_shape,
resampling, nodata, precision,
)
arrays.append(reprojected)

placed = None
if same_crs_list[i]:
# Same-CRS path: direct pixel placement (no resampling).
# Mirrors the eager merge so dask matches numpy bit-for-bit.
src_data = raster_data_list[i]
if hasattr(src_data, 'compute'):
src_data = src_data.compute()
placed = _place_same_crs(
np.asarray(src_data),
src_bounds_list[i], src_shape_list[i], y_desc_list[i],
cb, chunk_shape, nodata,
)
if placed is not None:
arrays.append(placed)
else:
reprojected = _reproject_chunk_numpy(
raster_data_list[i],
src_bounds_list[i], src_shape_list[i], y_desc_list[i],
src_wkt_list[i], tgt_wkt,
cb, chunk_shape,
resampling, nodata, precision,
)
arrays.append(reprojected)

if not arrays:
return np.full(chunk_shape, nodata, dtype=np.float64)
Expand Down Expand Up @@ -1636,6 +1669,15 @@ def _merge_dask(
for i in range(len(raster_infos))
]

# Precompute CRS-equality flags so per-block adapters can shortcut to
# direct pixel placement (matches the eager _merge_inmemory path).
from ._crs_utils import _crs_from_wkt
tgt_crs = _crs_from_wkt(tgt_wkt)
same_crs_list = [
bool(_crs_from_wkt(wkt_list[i]) == tgt_crs)
for i in range(len(raster_infos))
]

# Bind via partial to prevent map_blocks from adding dask arrays
# in data_list as whole-array dependencies.
bound_adapter = functools.partial(
Expand All @@ -1653,6 +1695,7 @@ def _merge_dask(
strategy=strategy,
precision=16,
src_footprints_tgt=footprints,
same_crs_list=same_crs_list,
)

template = da.empty(
Expand Down
186 changes: 186 additions & 0 deletions xrspatial/tests/test_reproject.py
Original file line number Diff line number Diff line change
Expand Up @@ -1709,3 +1709,189 @@ def test_detect_nodata_accepts_finite(self):
from xrspatial.reproject._crs_utils import _detect_nodata
r = xr.DataArray(np.zeros((4, 4)), dims=('y', 'x'))
assert _detect_nodata(r, nodata=-9999) == -9999.0


# ---------------------------------------------------------------------------
# Backend parity: dask dtype + same-CRS dask merge + cupy
# ---------------------------------------------------------------------------

@pytest.mark.skipif(not HAS_DASK, reason="dask required")
class TestDaskDtypeParity:
    """Dask reproject should preserve source integer dtype (matches numpy)."""

    @staticmethod
    def _lazy_raster(values, nodata):
        """Wrap *values* in a dask-backed DataArray on the shared 8x8 grid."""
        grid = {'y': np.linspace(5, -5, 8), 'x': np.linspace(-5, 5, 8)}
        meta = {'crs': 'EPSG:4326', 'nodata': nodata}
        return xr.DataArray(
            da.from_array(values, chunks=(4, 4)),
            dims=['y', 'x'], coords=grid, attrs=meta,
        )

    def test_dask_reproject_int8_preserves_dtype(self):
        from xrspatial.reproject import reproject
        values = np.arange(64, dtype=np.int8).reshape(8, 8)
        raster = self._lazy_raster(values, nodata=-1)
        result = reproject(raster, 'EPSG:4326', resolution=1.0)
        # Both the lazy meta dtype and the computed dtype must round-trip.
        assert result.data.dtype == np.int8
        assert result.compute().dtype == np.int8

    def test_dask_reproject_uint16_preserves_dtype(self):
        from xrspatial.reproject import reproject
        values = (np.arange(64, dtype=np.uint16) * 100).reshape(8, 8)
        raster = self._lazy_raster(values, nodata=0)
        result = reproject(raster, 'EPSG:4326', resolution=1.0)
        assert result.data.dtype == np.uint16
        assert result.compute().dtype == np.uint16

    def test_dask_reproject_float32_stays_float64(self):
        """Float input still upcasts to float64 (existing behaviour guard)."""
        from xrspatial.reproject import reproject
        values = np.random.RandomState(0).rand(8, 8).astype(np.float32)
        raster = self._lazy_raster(values, nodata=np.nan)
        result = reproject(raster, 'EPSG:4326', resolution=1.0)
        assert result.data.dtype == np.float64
        assert result.compute().dtype == np.float64


@pytest.mark.skipif(not HAS_DASK, reason="dask required")
class TestMergeDaskParity:
    """Dask merge should match the eager numpy merge."""

    @staticmethod
    def _chunked(raster, values):
        """Return a copy of *raster* whose data is an 8x8-chunked dask array."""
        lazy = raster.copy()
        lazy.data = da.from_array(values, chunks=(8, 8))
        return lazy

    def test_merge_dask_same_crs_matches_eager(self):
        """Same-CRS merge should be bit-equal between eager and dask paths.

        Source and output resolutions match (within 1%) so
        ``_place_same_crs`` activates in both paths -- direct pixel copy
        means the dask result must equal the eager result bit-for-bit.
        """
        from xrspatial.reproject import merge
        # 16 pixels with center-to-center spacing of exactly 1.0 -> bounds
        # extend half a pixel past coords, source resolution matches output.
        left_vals = np.arange(256, dtype=np.float64).reshape(16, 16)
        right_vals = (np.arange(256, dtype=np.float64) * 2).reshape(16, 16)
        left = _make_raster(left_vals, x_range=(-7.5, 7.5), y_range=(-7.5, 7.5))
        right = _make_raster(right_vals, x_range=(8.5, 23.5), y_range=(-7.5, 7.5))

        eager = merge([left, right], resolution=1.0).compute().values
        dasked = merge(
            [self._chunked(left, left_vals), self._chunked(right, right_vals)],
            resolution=1.0, chunk_size=8,
        ).compute().values

        assert eager.shape == dasked.shape
        eager_mask = np.isnan(eager)
        dask_mask = np.isnan(dasked)
        np.testing.assert_array_equal(eager_mask, dask_mask)
        # Finite values must be bit-equal: same-CRS path is direct copy
        np.testing.assert_array_equal(eager[~eager_mask], dasked[~dask_mask])

    def test_merge_dask_different_crs_matches_eager(self):
        """Different-CRS merge should match within float tolerance."""
        from xrspatial.reproject import merge
        left_vals = np.arange(256, dtype=np.float64).reshape(16, 16)
        right_vals = (np.arange(256, dtype=np.float64) + 100.0).reshape(16, 16)
        # One in WGS84, one in Web Mercator (forces reprojection)
        left = _make_raster(left_vals, crs='EPSG:4326',
                            x_range=(-10, 0), y_range=(-5, 5))
        # Build a Web-Mercator tile that overlaps the target
        right = _make_raster(right_vals, crs='EPSG:3857',
                             x_range=(0, 1_000_000), y_range=(-500_000, 500_000))

        eager = merge(
            [left, right], target_crs='EPSG:4326', resolution=1.0,
        ).compute().values
        dasked = merge(
            [self._chunked(left, left_vals), self._chunked(right, right_vals)],
            target_crs='EPSG:4326', resolution=1.0, chunk_size=8,
        ).compute().values

        assert eager.shape == dasked.shape
        np.testing.assert_array_equal(np.isnan(eager), np.isnan(dasked))
        finite = np.isfinite(eager)
        if finite.any():
            np.testing.assert_allclose(
                eager[finite], dasked[finite], rtol=1e-10, atol=1e-10,
            )


@pytest.mark.skipif(not HAS_CUPY, reason="cupy required")
class TestCupyReprojectParity:
    """End-to-end cupy backend parity checks."""

    @staticmethod
    def _to_host(arr):
        """Materialize *arr* (possibly dask- and/or cupy-backed) as numpy."""
        if hasattr(arr, 'compute'):
            arr = arr.compute()
        # cupy arrays: pull through .get() to avoid implicit numpy convert
        if hasattr(arr, 'get'):
            return arr.get()
        return np.asarray(arr)

    @staticmethod
    def _assert_parity(expected, actual):
        """Shared NaN-mask and finite-value comparison at 1e-5 tolerance."""
        assert expected.shape == actual.shape
        np.testing.assert_array_equal(
            np.isnan(expected), np.isnan(actual),
        )
        finite = np.isfinite(expected)
        if finite.any():
            np.testing.assert_allclose(
                expected[finite], actual[finite], rtol=1e-5, atol=1e-5,
            )

    def test_cupy_reproject_matches_numpy(self):
        from xrspatial.reproject import reproject
        values = np.random.RandomState(7).rand(32, 32).astype(np.float64)
        coords = {'y': np.linspace(55, 45, 32), 'x': np.linspace(-5, 5, 32)}
        attrs = {'crs': 'EPSG:4326', 'nodata': np.nan}

        np_raster = xr.DataArray(values, dims=['y', 'x'],
                                 coords=coords, attrs=attrs)
        cp_raster = xr.DataArray(cp.asarray(values), dims=['y', 'x'],
                                 coords=coords, attrs=attrs)
        np_result = reproject(np_raster, 'EPSG:3857').values
        cp_vals = self._to_host(reproject(cp_raster, 'EPSG:3857').data)

        self._assert_parity(np_result, cp_vals)

    @pytest.mark.skipif(not HAS_DASK, reason="dask required")
    def test_dask_cupy_reproject_matches_numpy(self):
        from xrspatial.reproject import reproject
        values = np.random.RandomState(11).rand(32, 32).astype(np.float64)
        coords = {'y': np.linspace(55, 45, 32), 'x': np.linspace(-5, 5, 32)}
        attrs = {'crs': 'EPSG:4326', 'nodata': np.nan}

        np_raster = xr.DataArray(values, dims=['y', 'x'],
                                 coords=coords, attrs=attrs)
        dc_raster = xr.DataArray(
            da.from_array(cp.asarray(values), chunks=(16, 16)),
            dims=['y', 'x'], coords=coords, attrs=attrs,
        )
        np_result = reproject(np_raster, 'EPSG:3857').values
        dc_vals = self._to_host(reproject(dc_raster, 'EPSG:3857').data)

        self._assert_parity(np_result, dc_vals)
Loading