From d1c6b4f29be167dcb82f19ab43105f0ddb86a363 Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Mon, 4 May 2026 08:08:42 -0700 Subject: [PATCH] Preserve non-spatial coords through reproject() and merge() Closes #1454. reproject() and merge() rebuilt the output DataArray with only the y/x coords (and band coord on the 3D path). Other coords on the input were silently dropped: scalar coords like a single 'time' timestamp, non-dimension coords such as rioxarray's 'spatial_ref', and any extra dim coord unrelated to the spatial grid. Both functions now copy any coord whose dims do not include the rebuilt spatial dims. Coords aligned to ydim or xdim are still dropped because their values are stale after the grid is rebuilt. Also document in the reproject() and merge() Returns section that the output y coordinate is always emitted in descending order (top-down, north-up) regardless of the input direction. Adds TestCoordsPreservation covering scalar time coord, non-dim string coord, stale y-aligned coord drop, band coord round-trip, merge scalar coord, and y-descending output for numpy and dask. --- xrspatial/reproject/__init__.py | 51 ++++++++++++++- xrspatial/tests/test_reproject.py | 102 ++++++++++++++++++++++++++++++ 2 files changed, 152 insertions(+), 1 deletion(-) diff --git a/xrspatial/reproject/__init__.py b/xrspatial/reproject/__init__.py index 16f3c1ad..5df89e30 100644 --- a/xrspatial/reproject/__init__.py +++ b/xrspatial/reproject/__init__.py @@ -530,6 +530,17 @@ def reproject( The output ``attrs['crs']`` is in WKT format. If vertical transformation was applied, ``attrs['vertical_crs']`` records the target vertical datum. + + The output y coordinate is always emitted in descending order + (top-down, north-up) regardless of the input direction. This + matches the standard raster convention and the output of common + GIS libraries. + + Non-spatial coords from the input (such as a scalar ``time`` + coord or a non-dimension coord that is not aligned to the + spatial dims) are carried through to the output. Coords that + are aligned to the input y or x dims are dropped because their + values do not apply to the rebuilt grid. """ _validate_raster(raster, func_name='reproject', name='raster', ndim=(2, 3)) @@ -691,9 +702,27 @@ def reproject( out_coords = {ydim: y_coords, xdim: x_coords} if band_dim in raster.coords: out_coords[band_dim] = raster.coords[band_dim] + # Carry forward non-spatial coords (e.g. scalar 'time' coord). + # Skip coords aligned to the rebuilt spatial dims because their + # values do not apply to the new grid. + for cname, cval in raster.coords.items(): + if cname in (ydim, xdim, band_dim): + continue + if ydim in cval.dims or xdim in cval.dims: + continue + out_coords[cname] = cval else: out_dims = [ydim, xdim] out_coords = {ydim: y_coords, xdim: x_coords} + # Carry forward non-spatial coords (e.g. scalar 'time' coord). + # Skip coords aligned to the rebuilt spatial dims because their + # values do not apply to the new grid. + for cname, cval in raster.coords.items(): + if cname in (ydim, xdim): + continue + if ydim in cval.dims or xdim in cval.dims: + continue + out_coords[cname] = cval result = xr.DataArray( result_data, @@ -1372,6 +1401,15 @@ def merge( Returns ------- xr.DataArray + The output y coordinate is always emitted in descending order + (top-down, north-up) regardless of the input direction. This + matches the standard raster convention and the output of common + GIS libraries. + + Non-spatial coords from the first raster (such as a scalar + ``time`` coord) are carried through to the output. Coords + aligned to the spatial dims are dropped because their values + do not apply to the merged grid. """ if not rasters: raise ValueError("merge(): rasters list must not be empty") @@ -1485,10 +1523,21 @@ def merge( ydim = rasters[0].dims[-2] xdim = rasters[0].dims[-1] + out_coords = {ydim: y_coords, xdim: x_coords} + # Carry forward non-spatial coords from the first raster (e.g. scalar + # 'time' coord). Skip coords aligned to the rebuilt spatial dims + # because their values do not apply to the new grid. + for cname, cval in rasters[0].coords.items(): + if cname in (ydim, xdim): + continue + if ydim in cval.dims or xdim in cval.dims: + continue + out_coords[cname] = cval + result = xr.DataArray( result_data, dims=[ydim, xdim], - coords={ydim: y_coords, xdim: x_coords}, + coords=out_coords, name=name or 'merged', attrs={ 'crs': tgt_wkt, diff --git a/xrspatial/tests/test_reproject.py b/xrspatial/tests/test_reproject.py index 893b2150..0486e6b5 100644 --- a/xrspatial/tests/test_reproject.py +++ b/xrspatial/tests/test_reproject.py @@ -1940,3 +1940,105 @@ def test_dask_cupy_reproject_matches_numpy(self): np.testing.assert_allclose( np_result[finite], dc_vals[finite], rtol=1e-5, atol=1e-5, ) + + +class TestCoordsPreservation: + """Non-spatial coords pass through reproject() and merge().""" + + def _small_raster(self, name='test'): + from xrspatial.tests.test_reproject import _make_raster + data = np.random.RandomState(0).rand(8, 8).astype(np.float64) + return _make_raster(data, name=name) + + def test_reproject_preserves_scalar_time_coord(self): + from xrspatial.reproject import reproject + raster = self._small_raster() + ts = np.datetime64('2024-01-15') + raster = raster.assign_coords(time=ts) + + out = reproject(raster, 'EPSG:3857') + assert 'time' in out.coords + assert out.coords['time'].values == ts + + def test_reproject_preserves_non_spatial_string_coord(self): + from xrspatial.reproject import reproject + raster = self._small_raster() + raster = raster.assign_coords(source='tile_a') + + out = reproject(raster, 'EPSG:3857') + assert 'source' in out.coords + assert str(out.coords['source'].values) == 'tile_a' + + def test_reproject_drops_stale_y_coord_alias(self): + from xrspatial.reproject import reproject + raster = self._small_raster() + # 'latitude' is a non-dim coord aligned to the y dim. + latitude = ('y', raster.coords['y'].values.copy()) + raster = raster.assign_coords(latitude=latitude) + assert 'latitude' in raster.coords + + out = reproject(raster, 'EPSG:3857') + # The new grid's y values do not match the stale 'latitude' + # values, so it must be dropped. + assert 'latitude' not in out.coords + + def test_reproject_preserves_band_coord(self): + from xrspatial.reproject import reproject + data = np.random.RandomState(1).rand(8, 8, 3).astype(np.float64) + y = np.linspace(1, -1, 8) + x = np.linspace(-1, 1, 8) + raster = xr.DataArray( + data, dims=['y', 'x', 'band'], + coords={'y': y, 'x': x, 'band': ['R', 'G', 'B']}, + attrs={'crs': 'EPSG:4326', 'nodata': np.nan}, + ) + + out = reproject(raster, 'EPSG:3857') + assert 'band' in out.coords + assert list(out.coords['band'].values) == ['R', 'G', 'B'] + + def test_merge_preserves_first_raster_scalar_coord(self): + from xrspatial.reproject import merge + r1 = self._small_raster(name='r1') + r2 = self._small_raster(name='r2') + ts = np.datetime64('2024-06-01') + r1 = r1.assign_coords(time=ts) + + out = merge([r1, r2], target_crs='EPSG:4326') + assert 'time' in out.coords + assert out.coords['time'].values == ts + + def test_reproject_y_descending_regardless_of_input(self): + from xrspatial.reproject import reproject + # Build a y-ascending input (override default y direction) + data = np.random.RandomState(2).rand(8, 8).astype(np.float64) + y_asc = np.linspace(-1, 1, 8) # ascending + x = np.linspace(-1, 1, 8) + raster = xr.DataArray( + data, dims=['y', 'x'], + coords={'y': y_asc, 'x': x}, + attrs={'crs': 'EPSG:4326', 'nodata': np.nan}, + ) + + out = reproject(raster, 'EPSG:3857') + y_out = out.coords['y'].values + # Strictly descending (top-down, north-up). + assert np.all(np.diff(y_out) < 0), ( + f"Output y must be descending, got {y_out}" + ) + + @pytest.mark.skipif(not HAS_DASK, reason="dask required") + def test_reproject_y_descending_dask(self): + from xrspatial.reproject import reproject + data = np.random.RandomState(3).rand(8, 8).astype(np.float64) + y_asc = np.linspace(-1, 1, 8) + x = np.linspace(-1, 1, 8) + raster = xr.DataArray( + da.from_array(data, chunks=4), dims=['y', 'x'], + coords={'y': y_asc, 'x': x}, + attrs={'crs': 'EPSG:4326', 'nodata': np.nan}, + ) + + out = reproject(raster, 'EPSG:3857') + y_out = out.coords['y'].values + assert np.all(np.diff(y_out) < 0)