Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xrspatial/geotiff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
transform_tuple as _transform_tuple,
transform_from_attr as _transform_from_attr,
coords_to_transform as _coords_to_transform,
require_transform_for_georeferenced as _require_transform_for_georeferenced,
)
from ._geotags import GeoTransform, RASTER_PIXEL_IS_AREA, RASTER_PIXEL_IS_POINT
from ._reader import UnsafeURLError
Expand Down
80 changes: 74 additions & 6 deletions xrspatial/geotiff/_coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,49 @@ def transform_from_attr(attr_val) -> 'GeoTransform | None':
)


def require_transform_for_georeferenced(
da: xr.DataArray, geo_transform
) -> None:
"""Raise if ``da`` carries spatial coords but no transform was derived.

Used by the writer entry points (#1945). A DataArray whose spatial
dim names appear in ``da.coords`` is an explicit caller request for a
georeferenced output. Silently falling through to a non-georeferenced
TIFF -- which is what the old code did for 1x1 inputs and inputs with
a degenerate axis -- corrupted round-trips. If the writer cannot
recover a transform from coords *and* the caller did not supply
``attrs['transform']``, fail closed instead.

``geo_transform`` is the value the writer has already resolved (from
``attrs['transform']`` first, then from coord arrays). If it's not
``None`` we have a transform and there's nothing to check.
"""
if geo_transform is not None:
return
if da.ndim == 3:
spatial = tuple(d for d in da.dims if d not in _BAND_DIM_NAMES)
if len(spatial) == 2:
ydim, xdim = spatial[0], spatial[1]
else:
ydim = da.dims[-2]
xdim = da.dims[-1]
else:
ydim = da.dims[-2]
xdim = da.dims[-1]
if xdim in da.coords and ydim in da.coords:
raise ValueError(
f"Cannot infer GeoTIFF transform from a "
f"{tuple(da.sizes.values())} array with spatial coords on "
f"both axes: both axes are degenerate (1x1), so neither "
f"coord array carries a pixel size step. 1xN and Nx1 inputs "
f"recover the pixel size from the non-degenerate axis, but "
f"a 1x1 cannot. Supply the affine transform explicitly via "
f"``attrs['transform']`` (rasterio 6-tuple "
f"``(px, 0, ox, 0, py, oy)``) or drop the coords if a "
f"non-georeferenced TIFF is desired."
)


def coords_to_transform(da: xr.DataArray) -> 'GeoTransform | None':
"""Infer GeoTransform from DataArray coordinates.

Expand Down Expand Up @@ -274,7 +317,11 @@ def coords_to_transform(da: xr.DataArray) -> 'GeoTransform | None':
x = da.coords[xdim].values
y = da.coords[ydim].values

if len(x) < 2 or len(y) < 2:
# 1x1 has no pixel-size signal on either axis. The caller must supply
# ``attrs['transform']`` (handled by the writer before calling us).
# Returning ``None`` lets the writer detect this and raise rather than
# silently writing a non-georeferenced TIFF (#1945).
if len(x) < 2 and len(y) < 2:
return None

# GeoTIFF only supports an affine transform; non-uniform spacing
Expand All @@ -298,11 +345,32 @@ def _is_regular(coord, name):
f"GeoTIFF requires an affine transform."
)

_is_regular(x, "x")
_is_regular(y, "y")

pixel_width = float(x[1] - x[0])
pixel_height = float(y[1] - y[0])
# Degenerate-axis fallback (#1945). When one axis has length 1, we
# can't read a step off it (``coord[1] - coord[0]`` is undefined),
# so we recover the per-axis pixel size from the non-degenerate
# axis and assume square pixels for the degenerate one. That matches
# how every other geospatial reader handles 1xN / Nx1 strips. The
# earlier behaviour — bailing out and silently writing a
# non-georeferenced TIFF — broke round-trips for legitimate
# single-scanline / single-profile rasters.
if len(x) >= 2:
_is_regular(x, "x")
pixel_width = float(x[1] - x[0])
else:
pixel_width = None
if len(y) >= 2:
_is_regular(y, "y")
pixel_height = float(y[1] - y[0])
else:
pixel_height = None

if pixel_width is None:
# Borrow magnitude from y; x increases left-to-right by convention.
pixel_width = abs(pixel_height)
if pixel_height is None:
# Borrow magnitude from x; y decreases top-to-bottom by convention,
# so flip sign.
pixel_height = -abs(pixel_width)

is_point = da.attrs.get('raster_type') == 'point'
if is_point:
Expand Down
9 changes: 9 additions & 0 deletions xrspatial/geotiff/_writers/eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .._coords import (
_BAND_DIM_NAMES,
coords_to_transform as _coords_to_transform,
require_transform_for_georeferenced as _require_transform_for_georeferenced,
transform_from_attr as _transform_from_attr,
)
from .._crs import _validate_crs_fallback, _wkt_to_epsg
Expand Down Expand Up @@ -517,6 +518,11 @@ def to_geotiff(data: xr.DataArray | np.ndarray,
geo_transform = _transform_from_attr(data.attrs.get('transform'))
if geo_transform is None:
geo_transform = _coords_to_transform(data)
# Fail closed when coords are present but no transform could be
# derived (e.g. 1x1 without ``attrs['transform']``) instead of
# silently writing a non-georeferenced TIFF that round-trips back
# with integer pixel coords (#1945).
_require_transform_for_georeferenced(data, geo_transform)
if epsg is None and crs is None:
crs_attr = data.attrs.get('crs')
if isinstance(crs_attr, str):
Expand Down Expand Up @@ -837,6 +843,9 @@ def _write_vrt_tiled(data, vrt_path, *, crs=None, nodata=None,
geo_transform = _transform_from_attr(data.attrs.get('transform'))
if geo_transform is None:
geo_transform = _coords_to_transform(data)
# Match the to_geotiff fail-closed guard so VRT writes don't
# silently produce non-georeferenced tiles either (#1945).
_require_transform_for_georeferenced(data, geo_transform)
# Pull the same rich-tag set that to_geotiff forwards to
# ``write`` so per-tile files under the VRT carry it too.
_rich = _extract_rich_tags(data.attrs)
Expand Down
6 changes: 6 additions & 0 deletions xrspatial/geotiff/_writers/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .._coords import (
_BAND_DIM_NAMES,
coords_to_transform as _coords_to_transform,
require_transform_for_georeferenced as _require_transform_for_georeferenced,
transform_from_attr as _transform_from_attr,
)
from .._crs import _validate_crs_fallback, _wkt_to_epsg
Expand Down Expand Up @@ -348,6 +349,11 @@ def write_geotiff_gpu(data: xr.DataArray | cupy.ndarray | np.ndarray,
geo_transform = _transform_from_attr(data.attrs.get('transform'))
if geo_transform is None:
geo_transform = _coords_to_transform(data)
# Match the CPU writer's fail-closed guard: an array with spatial
# coords but no derivable transform (e.g. 1x1 without
# ``attrs['transform']``) must not silently round-trip as a
# non-georeferenced TIFF (#1945).
_require_transform_for_georeferenced(data, geo_transform)
# Resolve CRS the same way the CPU writer does. attrs['crs'] may
# be an int EPSG or a WKT string; attrs['crs_wkt'] only carries
# WKT. Without the WKT branch the GPU writer silently drops CRS
Expand Down
Loading
Loading