Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion xrspatial/reproject/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,17 @@ def reproject(
The output ``attrs['crs']`` is in WKT format.
If vertical transformation was applied, ``attrs['vertical_crs']``
records the target vertical datum.

The output y coordinate is always emitted in descending order
(top-down, north-up) regardless of the input direction. This
matches the standard raster convention and the output of common
GIS libraries.

Non-spatial coords from the input (such as a scalar ``time``
coord or a non-dimension coord that is not aligned to the
spatial dims) are carried through to the output. Coords that
are aligned to the input y or x dims are dropped because their
values do not apply to the rebuilt grid.
"""
_validate_raster(raster, func_name='reproject', name='raster',
ndim=(2, 3))
Expand Down Expand Up @@ -697,9 +708,27 @@ def reproject(
out_coords = {ydim: y_coords, xdim: x_coords}
if band_dim in raster.coords:
out_coords[band_dim] = raster.coords[band_dim]
# Carry forward non-spatial coords (e.g. scalar 'time' coord).
# Skip coords aligned to the rebuilt spatial dims because their
# values do not apply to the new grid.
for cname, cval in raster.coords.items():
if cname in (ydim, xdim, band_dim):
continue
if ydim in cval.dims or xdim in cval.dims:
continue
out_coords[cname] = cval
else:
out_dims = [ydim, xdim]
out_coords = {ydim: y_coords, xdim: x_coords}
# Carry forward non-spatial coords (e.g. scalar 'time' coord).
# Skip coords aligned to the rebuilt spatial dims because their
# values do not apply to the new grid.
for cname, cval in raster.coords.items():
if cname in (ydim, xdim):
continue
if ydim in cval.dims or xdim in cval.dims:
continue
out_coords[cname] = cval

result = xr.DataArray(
result_data,
Expand Down Expand Up @@ -1388,6 +1417,15 @@ def merge(
Returns
-------
xr.DataArray
The output y coordinate is always emitted in descending order
(top-down, north-up) regardless of the input direction. This
matches the standard raster convention and the output of common
GIS libraries.

Non-spatial coords from the first raster (such as a scalar
``time`` coord) are carried through to the output. Coords
aligned to the spatial dims are dropped because their values
do not apply to the merged grid.
"""
if not rasters:
raise ValueError("merge(): rasters list must not be empty")
Expand Down Expand Up @@ -1506,6 +1544,17 @@ def merge(
ydim = rasters[0].dims[-2]
xdim = rasters[0].dims[-1]

out_coords = {ydim: y_coords, xdim: x_coords}
# Carry forward non-spatial coords from the first raster (e.g. scalar
# 'time' coord). Skip coords aligned to the rebuilt spatial dims
# because their values do not apply to the new grid.
for cname, cval in rasters[0].coords.items():
if cname in (ydim, xdim):
continue
if ydim in cval.dims or xdim in cval.dims:
continue
out_coords[cname] = cval

# Carry the first raster's attrs forward (matches the default
# strategy='first'). Drop attrs describing the old grid: `transform`,
# `res`, and the duplicate `crs_wkt` are no longer accurate.
Expand All @@ -1519,7 +1568,7 @@ def merge(
result = xr.DataArray(
result_data,
dims=[ydim, xdim],
coords={ydim: y_coords, xdim: x_coords},
coords=out_coords,
name=name or rasters[0].name or 'merged',
attrs=out_attrs,
)
Expand Down
102 changes: 102 additions & 0 deletions xrspatial/tests/test_reproject.py
Original file line number Diff line number Diff line change
Expand Up @@ -2326,3 +2326,105 @@ def test_dask_cupy_reproject_matches_numpy(self):
np.testing.assert_allclose(
np_result[finite], dc_vals[finite], rtol=1e-5, atol=1e-5,
)


class TestCoordsPreservation:
"""Non-spatial coords pass through reproject() and merge()."""

def _small_raster(self, name='test'):
from xrspatial.tests.test_reproject import _make_raster
data = np.random.RandomState(0).rand(8, 8).astype(np.float64)
return _make_raster(data, name=name)

def test_reproject_preserves_scalar_time_coord(self):
from xrspatial.reproject import reproject
raster = self._small_raster()
ts = np.datetime64('2024-01-15')
raster = raster.assign_coords(time=ts)

out = reproject(raster, 'EPSG:3857')
assert 'time' in out.coords
assert out.coords['time'].values == ts

def test_reproject_preserves_non_spatial_string_coord(self):
from xrspatial.reproject import reproject
raster = self._small_raster()
raster = raster.assign_coords(source='tile_a')

out = reproject(raster, 'EPSG:3857')
assert 'source' in out.coords
assert str(out.coords['source'].values) == 'tile_a'

def test_reproject_drops_stale_y_coord_alias(self):
from xrspatial.reproject import reproject
raster = self._small_raster()
# 'latitude' is a non-dim coord aligned to the y dim.
latitude = ('y', raster.coords['y'].values.copy())
raster = raster.assign_coords(latitude=latitude)
assert 'latitude' in raster.coords

out = reproject(raster, 'EPSG:3857')
# The new grid's y values do not match the stale 'latitude'
# values, so it must be dropped.
assert 'latitude' not in out.coords

def test_reproject_preserves_band_coord(self):
from xrspatial.reproject import reproject
data = np.random.RandomState(1).rand(8, 8, 3).astype(np.float64)
y = np.linspace(1, -1, 8)
x = np.linspace(-1, 1, 8)
raster = xr.DataArray(
data, dims=['y', 'x', 'band'],
coords={'y': y, 'x': x, 'band': ['R', 'G', 'B']},
attrs={'crs': 'EPSG:4326', 'nodata': np.nan},
)

out = reproject(raster, 'EPSG:3857')
assert 'band' in out.coords
assert list(out.coords['band'].values) == ['R', 'G', 'B']

def test_merge_preserves_first_raster_scalar_coord(self):
from xrspatial.reproject import merge
r1 = self._small_raster(name='r1')
r2 = self._small_raster(name='r2')
ts = np.datetime64('2024-06-01')
r1 = r1.assign_coords(time=ts)

out = merge([r1, r2], target_crs='EPSG:4326')
assert 'time' in out.coords
assert out.coords['time'].values == ts

def test_reproject_y_descending_regardless_of_input(self):
from xrspatial.reproject import reproject
# Build a y-ascending input (override default y direction)
data = np.random.RandomState(2).rand(8, 8).astype(np.float64)
y_asc = np.linspace(-1, 1, 8) # ascending
x = np.linspace(-1, 1, 8)
raster = xr.DataArray(
data, dims=['y', 'x'],
coords={'y': y_asc, 'x': x},
attrs={'crs': 'EPSG:4326', 'nodata': np.nan},
)

out = reproject(raster, 'EPSG:3857')
y_out = out.coords['y'].values
# Strictly descending (top-down, north-up).
assert np.all(np.diff(y_out) < 0), (
f"Output y must be descending, got {y_out}"
)

@pytest.mark.skipif(not HAS_DASK, reason="dask required")
def test_reproject_y_descending_dask(self):
from xrspatial.reproject import reproject
data = np.random.RandomState(3).rand(8, 8).astype(np.float64)
y_asc = np.linspace(-1, 1, 8)
x = np.linspace(-1, 1, 8)
raster = xr.DataArray(
da.from_array(data, chunks=4), dims=['y', 'x'],
coords={'y': y_asc, 'x': x},
attrs={'crs': 'EPSG:4326', 'nodata': np.nan},
)

out = reproject(raster, 'EPSG:3857')
y_out = out.coords['y'].values
assert np.all(np.diff(y_out) < 0)