From 93f66d44a80b74f88293d61e71585842996f1966 Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Mon, 4 May 2026 07:37:07 -0700 Subject: [PATCH] Preserve input attrs through reproject() and merge() Both public APIs in xrspatial.reproject built `out_attrs` from scratch, dropping units, long_name, scale_factor, add_offset, _FillValue, and any custom metadata the caller had attached. Merge input attrs forward and override only what the operation actually changes. Pop attrs that are stale after the transform (transform, res, crs_wkt) since carrying them through would be actively wrong. merge() now also falls back name to rasters[0].name before 'merged', matching how reproject() falls back to raster.name. Closes #1445 --- xrspatial/reproject/__init__.py | 31 +++++--- xrspatial/tests/test_reproject.py | 114 ++++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+), 9 deletions(-) diff --git a/xrspatial/reproject/__init__.py b/xrspatial/reproject/__init__.py index ac9f93cd..328d45f8 100644 --- a/xrspatial/reproject/__init__.py +++ b/xrspatial/reproject/__init__.py @@ -669,10 +669,16 @@ def reproject( ) ydim, xdim = _find_spatial_dims(raster) - out_attrs = { - 'crs': tgt_wkt, - 'nodata': nd, - } + # Carry input attrs forward so units, long_name, scale_factor, etc. + # survive the transform. Pop attrs that are stale after reprojection: + # the affine `transform` and grid `res` describe the old grid, and + # `crs_wkt` would duplicate (or contradict) the canonical `crs` we re-emit. + out_attrs = {**raster.attrs} + out_attrs.pop('transform', None) + out_attrs.pop('crs_wkt', None) + out_attrs.pop('res', None) + out_attrs['crs'] = tgt_wkt + out_attrs['nodata'] = nd if tgt_vertical_crs is not None: out_attrs['vertical_crs'] = tgt_vertical_crs @@ -1465,15 +1471,22 @@ def merge( ydim = rasters[0].dims[-2] xdim = rasters[0].dims[-1] + # Carry the first raster's attrs forward (matches the default + # strategy='first'). Drop attrs describing the old grid: `transform`, + # `res`, and the duplicate `crs_wkt` are no longer accurate. + out_attrs = {**rasters[0].attrs} + out_attrs.pop('transform', None) + out_attrs.pop('crs_wkt', None) + out_attrs.pop('res', None) + out_attrs['crs'] = tgt_wkt + out_attrs['nodata'] = nd + result = xr.DataArray( result_data, dims=[ydim, xdim], coords={ydim: y_coords, xdim: x_coords}, - name=name or 'merged', - attrs={ - 'crs': tgt_wkt, - 'nodata': nd, - }, + name=name or rasters[0].name or 'merged', + attrs=out_attrs, ) return result diff --git a/xrspatial/tests/test_reproject.py b/xrspatial/tests/test_reproject.py index 704469c4..451dc19c 100644 --- a/xrspatial/tests/test_reproject.py +++ b/xrspatial/tests/test_reproject.py @@ -1709,3 +1709,117 @@ def test_detect_nodata_accepts_finite(self): from xrspatial.reproject._crs_utils import _detect_nodata r = xr.DataArray(np.zeros((4, 4)), dims=('y', 'x')) assert _detect_nodata(r, nodata=-9999) == -9999.0 + + +class TestMetadataPreservation: + """reproject() and merge() must carry input attrs forward.""" + + @staticmethod + def _raster_with_attrs(extra_attrs=None, h=8, w=8, + crs='EPSG:4326', + x_range=(-1, 1), y_range=(-1, 1), + name='dem'): + data = np.ones((h, w), dtype=np.float64) + attrs = {'crs': crs, 'nodata': np.nan} + if extra_attrs: + attrs.update(extra_attrs) + y = np.linspace(y_range[1], y_range[0], h) + x = np.linspace(x_range[0], x_range[1], w) + return xr.DataArray( + data, dims=['y', 'x'], + coords={'y': y, 'x': x}, + name=name, attrs=attrs, + ) + + # reproject() ---------------------------------------------------------- + + def test_reproject_preserves_units_attr(self): + from xrspatial.reproject import reproject + raster = self._raster_with_attrs({'units': 'meters'}) + result = reproject(raster, 'EPSG:4326', resolution=0.25) + assert result.attrs.get('units') == 'meters' + + def test_reproject_preserves_scale_offset(self): + from xrspatial.reproject import reproject + raster = self._raster_with_attrs( + {'scale_factor': 0.1, 'add_offset': 10.0} + ) + result = reproject(raster, 'EPSG:4326', resolution=0.25) + assert result.attrs.get('scale_factor') == 0.1 + assert result.attrs.get('add_offset') == 10.0 + + def test_reproject_preserves_long_name(self): + from xrspatial.reproject import reproject + raster = self._raster_with_attrs({'long_name': 'elevation'}) + result = reproject(raster, 'EPSG:4326', resolution=0.25) + assert result.attrs.get('long_name') == 'elevation' + + def test_reproject_drops_stale_transform(self): + from xrspatial.reproject import reproject + raster = self._raster_with_attrs( + {'transform': (1.0, 0.0, 0.0, 0.0, -1.0, 0.0)} + ) + result = reproject(raster, 'EPSG:3857') + assert 'transform' not in result.attrs + + def test_reproject_drops_stale_res(self): + from xrspatial.reproject import reproject + raster = self._raster_with_attrs({'res': (1.0, 1.0)}) + result = reproject(raster, 'EPSG:3857') + assert 'res' not in result.attrs + + def test_reproject_overrides_crs(self): + from xrspatial.reproject import reproject + raster = self._raster_with_attrs(crs='EPSG:4326') + result = reproject(raster, 'EPSG:3857') + # Output crs is the new target CRS WKT, not the input EPSG:4326 + assert 'crs' in result.attrs + out_crs = result.attrs['crs'] + assert out_crs != 'EPSG:4326' + # WKT for 3857 mentions Mercator / pseudo-mercator + assert 'Mercator' in out_crs or '3857' in out_crs + + def test_reproject_drops_stale_crs_wkt(self): + from xrspatial.reproject import reproject + raster = self._raster_with_attrs({'crs_wkt': 'OLD_DUPLICATE_WKT'}) + result = reproject(raster, 'EPSG:3857') + assert 'crs_wkt' not in result.attrs + + # merge() -------------------------------------------------------------- + + def test_merge_preserves_first_raster_attrs(self): + from xrspatial.reproject import merge + a = self._raster_with_attrs( + {'units': 'm', 'long_name': 'elev'}, + x_range=(-5, 0), y_range=(-5, 5), name='dem_a', + ) + b = self._raster_with_attrs( + {'units': 'feet'}, + x_range=(0, 5), y_range=(-5, 5), name='dem_b', + ) + result = merge([a, b], resolution=1.0) + assert result.attrs.get('units') == 'm' + assert result.attrs.get('long_name') == 'elev' + + def test_merge_drops_stale_transform(self): + from xrspatial.reproject import merge + a = self._raster_with_attrs( + {'transform': (1.0, 0.0, 0.0, 0.0, -1.0, 0.0)}, + x_range=(-5, 0), y_range=(-5, 5), + ) + b = self._raster_with_attrs( + x_range=(0, 5), y_range=(-5, 5), + ) + result = merge([a, b], resolution=1.0) + assert 'transform' not in result.attrs + + def test_merge_name_falls_back_to_first_raster(self): + from xrspatial.reproject import merge + a = self._raster_with_attrs( + x_range=(-5, 0), y_range=(-5, 5), name='dem_a', + ) + b = self._raster_with_attrs( + x_range=(0, 5), y_range=(-5, 5), name='dem_b', + ) + result = merge([a, b], resolution=1.0) + assert result.name == 'dem_a'