diff --git a/xrspatial/reproject/__init__.py b/xrspatial/reproject/__init__.py index 4f67f282..c3327faf 100644 --- a/xrspatial/reproject/__init__.py +++ b/xrspatial/reproject/__init__.py @@ -530,6 +530,17 @@ def reproject( The output ``attrs['crs']`` is in WKT format. If vertical transformation was applied, ``attrs['vertical_crs']`` records the target vertical datum. + + The output y coordinate is always emitted in descending order + (top-down, north-up) regardless of the input direction. This + matches the standard raster convention and the output of common + GIS libraries. + + Non-spatial coords from the input (such as a scalar ``time`` + coord or a non-dimension coord that is not aligned to the + spatial dims) are carried through to the output. Coords that + are aligned to the input y or x dims are dropped because their + values do not apply to the rebuilt grid. """ _validate_raster(raster, func_name='reproject', name='raster', ndim=(2, 3)) @@ -697,9 +708,27 @@ def reproject( out_coords = {ydim: y_coords, xdim: x_coords} if band_dim in raster.coords: out_coords[band_dim] = raster.coords[band_dim] + # Carry forward non-spatial coords (e.g. scalar 'time' coord). + # Skip coords aligned to the rebuilt spatial dims because their + # values do not apply to the new grid. + for cname, cval in raster.coords.items(): + if cname in (ydim, xdim, band_dim): + continue + if ydim in cval.dims or xdim in cval.dims: + continue + out_coords[cname] = cval else: out_dims = [ydim, xdim] out_coords = {ydim: y_coords, xdim: x_coords} + # Carry forward non-spatial coords (e.g. scalar 'time' coord). + # Skip coords aligned to the rebuilt spatial dims because their + # values do not apply to the new grid. + for cname, cval in raster.coords.items(): + if cname in (ydim, xdim): + continue + if ydim in cval.dims or xdim in cval.dims: + continue + out_coords[cname] = cval result = xr.DataArray( result_data, @@ -1388,6 +1417,15 @@ def merge( Returns ------- xr.DataArray + The output y coordinate is always emitted in descending order + (top-down, north-up) regardless of the input direction. This + matches the standard raster convention and the output of common + GIS libraries. + + Non-spatial coords from the first raster (such as a scalar + ``time`` coord) are carried through to the output. Coords + aligned to the spatial dims are dropped because their values + do not apply to the merged grid. """ if not rasters: raise ValueError("merge(): rasters list must not be empty") @@ -1506,6 +1544,17 @@ def merge( ydim = rasters[0].dims[-2] xdim = rasters[0].dims[-1] + out_coords = {ydim: y_coords, xdim: x_coords} + # Carry forward non-spatial coords from the first raster (e.g. scalar + # 'time' coord). Skip coords aligned to the rebuilt spatial dims + # because their values do not apply to the new grid. + for cname, cval in rasters[0].coords.items(): + if cname in (ydim, xdim): + continue + if ydim in cval.dims or xdim in cval.dims: + continue + out_coords[cname] = cval + # Carry the first raster's attrs forward (matches the default # strategy='first'). Drop attrs describing the old grid: `transform`, # `res`, and the duplicate `crs_wkt` are no longer accurate. @@ -1519,7 +1568,7 @@ def merge( result = xr.DataArray( result_data, dims=[ydim, xdim], - coords={ydim: y_coords, xdim: x_coords}, + coords=out_coords, name=name or rasters[0].name or 'merged', attrs=out_attrs, ) diff --git a/xrspatial/tests/test_reproject.py b/xrspatial/tests/test_reproject.py index 851e848f..632ff995 100644 --- a/xrspatial/tests/test_reproject.py +++ b/xrspatial/tests/test_reproject.py @@ -2326,3 +2326,105 @@ def test_dask_cupy_reproject_matches_numpy(self): np.testing.assert_allclose( np_result[finite], dc_vals[finite], rtol=1e-5, atol=1e-5, ) + + +class TestCoordsPreservation: + """Non-spatial coords pass through reproject() and merge().""" + + def _small_raster(self, name='test'): + from xrspatial.tests.test_reproject import _make_raster + data = np.random.RandomState(0).rand(8, 8).astype(np.float64) + return _make_raster(data, name=name) + + def test_reproject_preserves_scalar_time_coord(self): + from xrspatial.reproject import reproject + raster = self._small_raster() + ts = np.datetime64('2024-01-15') + raster = raster.assign_coords(time=ts) + + out = reproject(raster, 'EPSG:3857') + assert 'time' in out.coords + assert out.coords['time'].values == ts + + def test_reproject_preserves_non_spatial_string_coord(self): + from xrspatial.reproject import reproject + raster = self._small_raster() + raster = raster.assign_coords(source='tile_a') + + out = reproject(raster, 'EPSG:3857') + assert 'source' in out.coords + assert str(out.coords['source'].values) == 'tile_a' + + def test_reproject_drops_stale_y_coord_alias(self): + from xrspatial.reproject import reproject + raster = self._small_raster() + # 'latitude' is a non-dim coord aligned to the y dim. + latitude = ('y', raster.coords['y'].values.copy()) + raster = raster.assign_coords(latitude=latitude) + assert 'latitude' in raster.coords + + out = reproject(raster, 'EPSG:3857') + # The new grid's y values do not match the stale 'latitude' + # values, so it must be dropped. + assert 'latitude' not in out.coords + + def test_reproject_preserves_band_coord(self): + from xrspatial.reproject import reproject + data = np.random.RandomState(1).rand(8, 8, 3).astype(np.float64) + y = np.linspace(1, -1, 8) + x = np.linspace(-1, 1, 8) + raster = xr.DataArray( + data, dims=['y', 'x', 'band'], + coords={'y': y, 'x': x, 'band': ['R', 'G', 'B']}, + attrs={'crs': 'EPSG:4326', 'nodata': np.nan}, + ) + + out = reproject(raster, 'EPSG:3857') + assert 'band' in out.coords + assert list(out.coords['band'].values) == ['R', 'G', 'B'] + + def test_merge_preserves_first_raster_scalar_coord(self): + from xrspatial.reproject import merge + r1 = self._small_raster(name='r1') + r2 = self._small_raster(name='r2') + ts = np.datetime64('2024-06-01') + r1 = r1.assign_coords(time=ts) + + out = merge([r1, r2], target_crs='EPSG:4326') + assert 'time' in out.coords + assert out.coords['time'].values == ts + + def test_reproject_y_descending_regardless_of_input(self): + from xrspatial.reproject import reproject + # Build a y-ascending input (override default y direction) + data = np.random.RandomState(2).rand(8, 8).astype(np.float64) + y_asc = np.linspace(-1, 1, 8) # ascending + x = np.linspace(-1, 1, 8) + raster = xr.DataArray( + data, dims=['y', 'x'], + coords={'y': y_asc, 'x': x}, + attrs={'crs': 'EPSG:4326', 'nodata': np.nan}, + ) + + out = reproject(raster, 'EPSG:3857') + y_out = out.coords['y'].values + # Strictly descending (top-down, north-up). + assert np.all(np.diff(y_out) < 0), ( + f"Output y must be descending, got {y_out}" + ) + + @pytest.mark.skipif(not HAS_DASK, reason="dask required") + def test_reproject_y_descending_dask(self): + from xrspatial.reproject import reproject + data = np.random.RandomState(3).rand(8, 8).astype(np.float64) + y_asc = np.linspace(-1, 1, 8) + x = np.linspace(-1, 1, 8) + raster = xr.DataArray( + da.from_array(data, chunks=4), dims=['y', 'x'], + coords={'y': y_asc, 'x': x}, + attrs={'crs': 'EPSG:4326', 'nodata': np.nan}, + ) + + out = reproject(raster, 'EPSG:3857') + y_out = out.coords['y'].values + assert np.all(np.diff(y_out) < 0)