From c15047b37859f3fd8f2fb812476dedf9fc070b13 Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Sun, 10 May 2026 21:32:23 -0700 Subject: [PATCH 1/5] Centralise GeoTIFF attrs population across all read backends (#1548) Before this change ``open_geotiff`` returned a different attrs set depending on which backend handled the read: - numpy (eager): full set including resolution, image_description, extra_samples, extra_tags, plus all the CRS-description fields and GDAL metadata. - dask: only crs, transform, raster_type, nodata. - cupy / dask+cupy: only crs, crs_wkt, transform, nodata. Round-tripping through dask or GPU silently dropped resolution and pass-through tags, and any downstream code that branched on ``attrs['x_resolution']`` or ``attrs['nodata']`` saw different behaviour per backend. Factor a single ``_populate_attrs_from_geo_info(attrs, geo_info, window=None)`` helper and call it from all four read sites (eager numpy, dask, GPU stripped, GPU tiled). The helper owns the entire attrs population so the read paths cannot diverge again. Each caller still sets ``attrs['nodata']`` next to its own masking step (presence in attrs signals "array has been NaN-masked"). Adds ``test_attrs_parity_1548.py`` with 5 tests pinning the contract: each backend matches the eager numpy attrs on the pass-through keys (x/y_resolution, resolution_unit, image_description, extra_samples) and the full attrs key set is identical across all available backends. The VRT read path (``read_vrt`` at line ~2405) operates on a different ``vrt`` object rather than ``geo_info`` and is not covered here; that divergence is pre-existing and out of scope. Closes #1548. --- xrspatial/geotiff/__init__.py | 217 ++++++++---------- .../geotiff/tests/test_attrs_parity_1548.py | 187 +++++++++++++++ 2 files changed, 288 insertions(+), 116 deletions(-) create mode 100644 xrspatial/geotiff/tests/test_attrs_parity_1548.py diff --git a/xrspatial/geotiff/__init__.py b/xrspatial/geotiff/__init__.py index 9f30299a0..f13b4442b 100644 --- a/xrspatial/geotiff/__init__.py +++ b/xrspatial/geotiff/__init__.py @@ -298,6 +298,103 @@ def _extent_to_window(transform, file_height, file_width, return (row_start, col_start, row_stop, col_stop) +def _populate_attrs_from_geo_info(attrs: dict, geo_info, *, + window=None) -> None: + """Populate ``attrs`` with all GeoTIFF metadata from ``geo_info``. + + Centralised so the eager numpy, dask, and GPU read paths emit the + same attrs keys for the same input file. Mutates ``attrs`` in place. + + The ``nodata`` attr is intentionally NOT set here because each caller + sets it next to its own nodata-masking step (the value's presence in + attrs signals "this array has been NaN-masked"). + + ``window`` is a ``(r0, c0, r1, c1)`` tuple for the eager windowed + read; when set, the emitted ``attrs['transform']`` shifts the origin + to the window's top-left. The dask and GPU paths do not use this -- + their windows are per-chunk inside the graph, not on the outer + DataArray. + """ + if geo_info.crs_epsg is not None: + attrs['crs'] = geo_info.crs_epsg + if geo_info.crs_wkt is not None: + attrs['crs_wkt'] = geo_info.crs_wkt + if geo_info.raster_type == RASTER_PIXEL_IS_POINT: + attrs['raster_type'] = 'point' + + src_t = geo_info.transform + if src_t is not None: + if window is not None: + r0, c0, _r1, _c1 = window + origin_x_w = float(src_t.origin_x) + c0 * float(src_t.pixel_width) + origin_y_w = float(src_t.origin_y) + r0 * float(src_t.pixel_height) + attrs['transform'] = ( + float(src_t.pixel_width), 0.0, origin_x_w, + 0.0, float(src_t.pixel_height), origin_y_w, + ) + else: + t_tuple = _transform_tuple(geo_info) + if t_tuple is not None: + attrs['transform'] = t_tuple + + if geo_info.crs_name is not None: + attrs['crs_name'] = geo_info.crs_name + if geo_info.geog_citation is not None: + attrs['geog_citation'] = geo_info.geog_citation + if geo_info.datum_code is not None: + attrs['datum_code'] = geo_info.datum_code + if geo_info.angular_units is not None: + attrs['angular_units'] = geo_info.angular_units + if geo_info.linear_units is not None: + attrs['linear_units'] = geo_info.linear_units + if geo_info.semi_major_axis is not None: + attrs['semi_major_axis'] = geo_info.semi_major_axis + if geo_info.inv_flattening is not None: + attrs['inv_flattening'] = geo_info.inv_flattening + if geo_info.projection_code is not None: + attrs['projection_code'] = geo_info.projection_code + if geo_info.vertical_epsg is not None: + attrs['vertical_crs'] = geo_info.vertical_epsg + if geo_info.vertical_citation is not None: + attrs['vertical_citation'] = geo_info.vertical_citation + if geo_info.vertical_units is not None: + attrs['vertical_units'] = geo_info.vertical_units + + if geo_info.gdal_metadata is not None: + attrs['gdal_metadata'] = geo_info.gdal_metadata + if geo_info.gdal_metadata_xml is not None: + attrs['gdal_metadata_xml'] = geo_info.gdal_metadata_xml + + if geo_info.extra_tags is not None: + attrs['extra_tags'] = geo_info.extra_tags + if geo_info.image_description is not None: + attrs['image_description'] = geo_info.image_description + if geo_info.extra_samples is not None: + attrs['extra_samples'] = geo_info.extra_samples + + if geo_info.x_resolution is not None: + attrs['x_resolution'] = geo_info.x_resolution + if geo_info.y_resolution is not None: + attrs['y_resolution'] = geo_info.y_resolution + if geo_info.resolution_unit is not None: + _unit_names = {1: 'none', 2: 'inch', 3: 'centimeter'} + attrs['resolution_unit'] = _unit_names.get( + geo_info.resolution_unit, str(geo_info.resolution_unit)) + + if geo_info.colormap is not None: + try: + from matplotlib.colors import ListedColormap + attrs['cmap'] = ListedColormap( + geo_info.colormap, name='tiff_palette') + attrs['colormap_rgba'] = geo_info.colormap + except ImportError: + attrs['colormap_rgba'] = geo_info.colormap + + if geo_info.extra_tags is not None: + for _tag_id, _tt, _tc, _tv in geo_info.extra_tags: + if _tag_id == 320: # TAG_COLORMAP + attrs['colormap'] = _tv + break def open_geotiff(source, *, dtype=None, window=None, @@ -440,103 +537,7 @@ def open_geotiff(source, *, dtype=None, window=None, name = os.path.splitext(os.path.basename(source))[0] attrs = {} - if geo_info.crs_epsg is not None: - attrs['crs'] = geo_info.crs_epsg - if geo_info.crs_wkt is not None: - attrs['crs_wkt'] = geo_info.crs_wkt - if geo_info.raster_type == RASTER_PIXEL_IS_POINT: - attrs['raster_type'] = 'point' - - # Preserve the source GeoTransform verbatim. For a windowed read the - # origin shifts to the window's top-left pixel so the transform stays - # consistent with the returned y/x coords. - src_t = geo_info.transform - if src_t is not None: - if window is not None: - r0, c0, _r1, _c1 = window - origin_x_w = float(src_t.origin_x) + c0 * float(src_t.pixel_width) - origin_y_w = float(src_t.origin_y) + r0 * float(src_t.pixel_height) - attrs['transform'] = ( - float(src_t.pixel_width), 0.0, origin_x_w, - 0.0, float(src_t.pixel_height), origin_y_w, - ) - else: - attrs['transform'] = _transform_tuple(geo_info) - - # CRS description fields - if geo_info.crs_name is not None: - attrs['crs_name'] = geo_info.crs_name - if geo_info.geog_citation is not None: - attrs['geog_citation'] = geo_info.geog_citation - if geo_info.datum_code is not None: - attrs['datum_code'] = geo_info.datum_code - if geo_info.angular_units is not None: - attrs['angular_units'] = geo_info.angular_units - if geo_info.linear_units is not None: - attrs['linear_units'] = geo_info.linear_units - if geo_info.semi_major_axis is not None: - attrs['semi_major_axis'] = geo_info.semi_major_axis - if geo_info.inv_flattening is not None: - attrs['inv_flattening'] = geo_info.inv_flattening - if geo_info.projection_code is not None: - attrs['projection_code'] = geo_info.projection_code - # Vertical CRS - if geo_info.vertical_epsg is not None: - attrs['vertical_crs'] = geo_info.vertical_epsg - if geo_info.vertical_citation is not None: - attrs['vertical_citation'] = geo_info.vertical_citation - if geo_info.vertical_units is not None: - attrs['vertical_units'] = geo_info.vertical_units - - # GDAL metadata (tag 42112) - if geo_info.gdal_metadata is not None: - attrs['gdal_metadata'] = geo_info.gdal_metadata - if geo_info.gdal_metadata_xml is not None: - attrs['gdal_metadata_xml'] = geo_info.gdal_metadata_xml - - # Extra (non-managed) TIFF tags for pass-through - if geo_info.extra_tags is not None: - attrs['extra_tags'] = geo_info.extra_tags - - # Friendly accessors for a few common pass-through tags. The raw - # entry stays in attrs['extra_tags'] so the writer can re-emit the - # exact bytes; users who tweak these convenience attrs can rely on - # to_geotiff to fold the new value into extra_tags before write. - if geo_info.image_description is not None: - attrs['image_description'] = geo_info.image_description - if geo_info.extra_samples is not None: - attrs['extra_samples'] = geo_info.extra_samples - - # Resolution / DPI metadata - if geo_info.x_resolution is not None: - attrs['x_resolution'] = geo_info.x_resolution - if geo_info.y_resolution is not None: - attrs['y_resolution'] = geo_info.y_resolution - if geo_info.resolution_unit is not None: - _unit_names = {1: 'none', 2: 'inch', 3: 'centimeter'} - attrs['resolution_unit'] = _unit_names.get( - geo_info.resolution_unit, str(geo_info.resolution_unit)) - - # Attach palette colormap for indexed-color TIFFs. The normalized - # RGBA triples drive matplotlib display; the raw uint16 ColorMap - # tag value lives in attrs['extra_tags'] for round-trip and is - # exposed here as attrs['colormap'] for convenience. - if geo_info.colormap is not None: - try: - from matplotlib.colors import ListedColormap - cmap = ListedColormap(geo_info.colormap, name='tiff_palette') - attrs['cmap'] = cmap - attrs['colormap_rgba'] = geo_info.colormap - except ImportError: - # matplotlib not available -- store raw RGBA tuples only - attrs['colormap_rgba'] = geo_info.colormap - - # Raw uint16 ColorMap tag value (3 * 2**bps entries, R-then-G-then-B) - if geo_info.extra_tags is not None: - for _tag_id, _tt, _tc, _tv in geo_info.extra_tags: - if _tag_id == 320: # TAG_COLORMAP - attrs['colormap'] = _tv - break + _populate_attrs_from_geo_info(attrs, geo_info, window=window) # Apply nodata mask: replace nodata sentinel values with NaN. # ``arr`` came from ``read_to_array``, which returns a freshly @@ -1410,15 +1411,9 @@ def read_geotiff_dask(source: str, *, dtype=None, chunks: int | tuple = 512, name = os.path.splitext(os.path.basename(source))[0] attrs = {} - if geo_info.crs_epsg is not None: - attrs['crs'] = geo_info.crs_epsg - if geo_info.raster_type == RASTER_PIXEL_IS_POINT: - attrs['raster_type'] = 'point' + _populate_attrs_from_geo_info(attrs, geo_info) if nodata is not None: attrs['nodata'] = nodata - transform_tuple = _transform_tuple(geo_info) - if transform_tuple is not None: - attrs['transform'] = transform_tuple if isinstance(chunks, int): ch_h = ch_w = chunks @@ -1859,11 +1854,7 @@ def read_geotiff_gpu(source: str, *, import os name = os.path.splitext(os.path.basename(source))[0] attrs = {} - if geo_info.crs_epsg is not None: - attrs['crs'] = geo_info.crs_epsg - t_tuple = _transform_tuple(geo_info) - if t_tuple is not None: - attrs['transform'] = t_tuple + _populate_attrs_from_geo_info(attrs, geo_info) # Apply nodata mask + record sentinel so the GPU read agrees # with the CPU eager path. Without this, integer rasters keep # the literal sentinel value and float rasters keep the @@ -2127,13 +2118,7 @@ def _read_once(): coords = _geo_to_coords(geo_info, out_h, out_w) attrs = {} - if geo_info.crs_epsg is not None: - attrs['crs'] = geo_info.crs_epsg - if geo_info.crs_wkt is not None: - attrs['crs_wkt'] = geo_info.crs_wkt - t_tuple = _transform_tuple(geo_info) - if t_tuple is not None: - attrs['transform'] = t_tuple + _populate_attrs_from_geo_info(attrs, geo_info) if nodata is not None: attrs['nodata'] = nodata diff --git a/xrspatial/geotiff/tests/test_attrs_parity_1548.py b/xrspatial/geotiff/tests/test_attrs_parity_1548.py new file mode 100644 index 000000000..720196f29 --- /dev/null +++ b/xrspatial/geotiff/tests/test_attrs_parity_1548.py @@ -0,0 +1,187 @@ +"""4-backend attrs parity tests for issue #1548. + +Before the fix, ``open_geotiff`` returned a different ``attrs`` set +depending on which backend handled the read: + +* numpy (eager): full set, including ``x_resolution``, ``y_resolution``, + ``resolution_unit``, ``extra_tags``, ``image_description``, + ``extra_samples``. +* dask: only ``crs``, ``transform``, ``raster_type``, ``nodata``. +* cupy / dask+cupy: only ``crs``, ``crs_wkt``, ``transform``, ``nodata``. + +The fix factors a single ``_populate_attrs_from_geo_info`` helper and +calls it from every read path, so all four backends now emit the same +keys with the same values for the same input file. + +These tests pin that contract. +""" +from __future__ import annotations + +import importlib.util + +import numpy as np +import pytest +import xarray as xr + +tifffile = pytest.importorskip("tifffile") + +from xrspatial.geotiff import open_geotiff, to_geotiff + + +def _gpu_available() -> bool: + if importlib.util.find_spec("cupy") is None: + return False + try: + import cupy + return bool(cupy.cuda.is_available()) + except Exception: + return False + + +_HAS_GPU = _gpu_available() +_gpu_only = pytest.mark.skipif(not _HAS_GPU, reason="cupy + CUDA required") + + +def _write_tiff_with_pass_through_tags(path): + """Write a tiled 2-band float32 TIFF that exercises the pass-through + TIFF tags called out in the issue. + + Uses tifffile's first-class ``resolution`` / ``resolutionunit`` / + ``description`` kwargs (its preferred path; ``extratags`` would be + silently dropped for the resolution rationals). ``metadata=None`` + suppresses tifffile's auto-generated shape JSON in ImageDescription + so the fixture description survives. ``ExtraSamples`` (338) is + auto-derived from the 2-band layout. + """ + arr = np.random.default_rng(seed=1548).random( + (64, 64, 2)).astype(np.float32) + tifffile.imwrite( + path, arr, photometric='minisblack', planarconfig='contig', + tile=(32, 32), compression='deflate', + resolution=(300, 300), resolutionunit=2, + description='issue-1548 parity fixture', + metadata=None, + ) + return arr + + +def _attrs_subset(attrs, keys): + """Slice attrs down to the keys we are comparing across backends. + + ``cmap`` is a matplotlib object whose equality is brittle; ``transform`` + is a tuple of floats so handle separately. + """ + return {k: attrs.get(k) for k in keys} + + +_PASS_THROUGH_KEYS = ( + 'x_resolution', + 'y_resolution', + 'resolution_unit', + 'image_description', + 'extra_samples', +) + + +def test_numpy_attrs_includes_pass_through_tags(tmp_path): + """Sanity baseline: the eager numpy path emits all pass-through keys.""" + path = str(tmp_path / 'attrs_parity_1548_baseline.tif') + _write_tiff_with_pass_through_tags(path) + + da = open_geotiff(path) + for key in _PASS_THROUGH_KEYS: + assert key in da.attrs, ( + f"eager numpy is the canonical reference and should always " + f"emit '{key}'; got attrs={sorted(da.attrs.keys())}" + ) + + +def test_dask_attrs_match_numpy(tmp_path): + """The dask read path now emits the same pass-through attrs as numpy.""" + path = str(tmp_path / 'attrs_parity_1548_dask.tif') + _write_tiff_with_pass_through_tags(path) + + np_da = open_geotiff(path) + dk_da = open_geotiff(path, chunks=32) + + np_subset = _attrs_subset(np_da.attrs, _PASS_THROUGH_KEYS) + dk_subset = _attrs_subset(dk_da.attrs, _PASS_THROUGH_KEYS) + + assert dk_subset == np_subset, ( + f"dask attrs diverge from numpy:\n" + f" numpy: {np_subset}\n" + f" dask : {dk_subset}" + ) + + +@_gpu_only +def test_cupy_attrs_match_numpy(tmp_path): + """Cupy / GPU read emits the same pass-through attrs as numpy.""" + path = str(tmp_path / 'attrs_parity_1548_cupy.tif') + _write_tiff_with_pass_through_tags(path) + + np_da = open_geotiff(path) + gpu_da = open_geotiff(path, gpu=True) + + np_subset = _attrs_subset(np_da.attrs, _PASS_THROUGH_KEYS) + gpu_subset = _attrs_subset(gpu_da.attrs, _PASS_THROUGH_KEYS) + + assert gpu_subset == np_subset, ( + f"cupy attrs diverge from numpy:\n" + f" numpy: {np_subset}\n" + f" cupy : {gpu_subset}" + ) + + +@_gpu_only +def test_dask_cupy_attrs_match_numpy(tmp_path): + """Combined dask+cupy read still emits the pass-through attrs.""" + path = str(tmp_path / 'attrs_parity_1548_dask_cupy.tif') + _write_tiff_with_pass_through_tags(path) + + np_da = open_geotiff(path) + combined = open_geotiff(path, gpu=True, chunks=32) + + np_subset = _attrs_subset(np_da.attrs, _PASS_THROUGH_KEYS) + combined_subset = _attrs_subset(combined.attrs, _PASS_THROUGH_KEYS) + + assert combined_subset == np_subset, ( + f"dask+cupy attrs diverge from numpy:\n" + f" numpy : {np_subset}\n" + f" dask+cupy : {combined_subset}" + ) + + +def test_all_backend_attrs_keysets_equal(tmp_path): + """Strong contract: the *set* of attrs keys is identical across all + available backends, not just the pass-through subset. + + This guards against any future read path silently dropping a + different attr that nobody happened to test for above. + """ + path = str(tmp_path / 'attrs_parity_1548_keysets.tif') + _write_tiff_with_pass_through_tags(path) + + np_keys = set(open_geotiff(path).attrs.keys()) + dk_keys = set(open_geotiff(path, chunks=32).attrs.keys()) + + backend_keys = {'numpy': np_keys, 'dask+numpy': dk_keys} + if _HAS_GPU: + backend_keys['cupy'] = set(open_geotiff(path, gpu=True).attrs.keys()) + backend_keys['dask+cupy'] = set( + open_geotiff(path, gpu=True, chunks=32).attrs.keys()) + + differences = { + name: keys ^ np_keys + for name, keys in backend_keys.items() + if keys != np_keys + } + assert not differences, ( + f"backend attrs keysets diverge from numpy:\n" + f" numpy keys: {sorted(np_keys)}\n" + f" diffs : " + + "\n ".join( + f"{name}: symmetric_diff={sorted(diff)}" + for name, diff in differences.items() + ) + ) From 7f6d529237277fbe76acc8e6fde15797a1f3dca2 Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Sun, 10 May 2026 21:44:48 -0700 Subject: [PATCH 2/5] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- xrspatial/geotiff/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/xrspatial/geotiff/__init__.py b/xrspatial/geotiff/__init__.py index f13b4442b..37c69fb46 100644 --- a/xrspatial/geotiff/__init__.py +++ b/xrspatial/geotiff/__init__.py @@ -298,8 +298,7 @@ def _extent_to_window(transform, file_height, file_width, return (row_start, col_start, row_stop, col_stop) -def _populate_attrs_from_geo_info(attrs: dict, geo_info, *, - window=None) -> None: +def _populate_attrs_from_geo_info(attrs: dict, geo_info, *, window=None) -> None: """Populate ``attrs`` with all GeoTIFF metadata from ``geo_info``. Centralised so the eager numpy, dask, and GPU read paths emit the From fa0c90abccf075a86898a740dc49dd2bfc99e2e1 Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Sun, 10 May 2026 21:45:00 -0700 Subject: [PATCH 3/5] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- xrspatial/geotiff/tests/test_attrs_parity_1548.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xrspatial/geotiff/tests/test_attrs_parity_1548.py b/xrspatial/geotiff/tests/test_attrs_parity_1548.py index 720196f29..c83105cbc 100644 --- a/xrspatial/geotiff/tests/test_attrs_parity_1548.py +++ b/xrspatial/geotiff/tests/test_attrs_parity_1548.py @@ -23,10 +23,10 @@ import pytest import xarray as xr -tifffile = pytest.importorskip("tifffile") - from xrspatial.geotiff import open_geotiff, to_geotiff +tifffile = pytest.importorskip("tifffile") + def _gpu_available() -> bool: if importlib.util.find_spec("cupy") is None: From 3faac730ab05878ba9136064638f7bc464e91c3b Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Sun, 10 May 2026 21:50:39 -0700 Subject: [PATCH 4/5] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- xrspatial/geotiff/tests/test_attrs_parity_1548.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xrspatial/geotiff/tests/test_attrs_parity_1548.py b/xrspatial/geotiff/tests/test_attrs_parity_1548.py index c83105cbc..9a3fb404e 100644 --- a/xrspatial/geotiff/tests/test_attrs_parity_1548.py +++ b/xrspatial/geotiff/tests/test_attrs_parity_1548.py @@ -66,10 +66,10 @@ def _write_tiff_with_pass_through_tags(path): def _attrs_subset(attrs, keys): - """Slice attrs down to the keys we are comparing across backends. + """Return a dict containing only the requested attr keys. - ``cmap`` is a matplotlib object whose equality is brittle; ``transform`` - is a tuple of floats so handle separately. + This helper performs a simple ``attrs.get`` lookup for each key and + does not special-case any values. """ return {k: attrs.get(k) for k in keys} From 5f0e8c9acd9f569dbb25816201e358813b1976f0 Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Sun, 10 May 2026 21:50:44 -0700 Subject: [PATCH 5/5] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- xrspatial/geotiff/tests/test_attrs_parity_1548.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/xrspatial/geotiff/tests/test_attrs_parity_1548.py b/xrspatial/geotiff/tests/test_attrs_parity_1548.py index 9a3fb404e..15a104eb5 100644 --- a/xrspatial/geotiff/tests/test_attrs_parity_1548.py +++ b/xrspatial/geotiff/tests/test_attrs_parity_1548.py @@ -21,9 +21,8 @@ import numpy as np import pytest -import xarray as xr -from xrspatial.geotiff import open_geotiff, to_geotiff +from xrspatial.geotiff import open_geotiff tifffile = pytest.importorskip("tifffile")