diff --git a/src/rasterix/lib.py b/src/rasterix/lib.py index d82497e..a3b0e34 100644 --- a/src/rasterix/lib.py +++ b/src/rasterix/lib.py @@ -1,9 +1,14 @@ """Shared library utilities for rasterix.""" import logging +from typing import NotRequired, TypedDict from affine import Affine +# https://github.com/zarr-conventions/spatial +_ZARR_SPATIAL_CONVENTION_UUID = "689b58e2-cf7b-45e0-9fff-9cfc0883d6b4" + + # Define TRACE level (lower than DEBUG) TRACE = 5 logging.addLevelName(TRACE, "TRACE") @@ -104,3 +109,72 @@ def affine_from_stac_proj_metadata(metadata: dict) -> Affine | None: a, b, c, d, e, f = transform[:6] return Affine(a, b, c, d, e, f) + + +_ZarrConventionRegistration = TypedDict("_ZarrConventionRegistration", {"spatial:": str}) + +_ZarrSpatialMetadata = TypedDict( + "_ZarrSpatialMetadata", + { + "zarr_conventions": NotRequired[list[_ZarrConventionRegistration | dict]], + "spatial:transform": NotRequired[list[float]], + "spatial:transform_type": NotRequired[str], + "spatial:registration": NotRequired[str], + }, +) + + +def _has_spatial_zarr_convention(metadata: _ZarrSpatialMetadata) -> bool: + zarr_conventions = metadata.get("zarr_conventions") + if not zarr_conventions: + return False + for entry in zarr_conventions: + if isinstance(entry, dict) and ( + entry.get("uuid") == _ZARR_SPATIAL_CONVENTION_UUID or entry.get("name") == "spatial:" + ): + return True + return False + + +def affine_from_spatial_zarr_convention(metadata: dict) -> Affine | None: + """Extract Affine transform from Zarr spatial convention metadata. + + See https://github.com/zarr-conventions/spatial for the full specification. + + Parameters + ---------- + metadata : dict + Dictionary containing Zarr spatial convention metadata. + + Returns + ------- + Affine or None + Affine transformation matrix if minimal Zarr spatial metadata is found, None otherwise. + + Examples + -------- + >>> ds: xr.Dataset = ... + >>> affine = affine_from_spatial_zarr_convention(ds.attrs) + """ + possibly_spatial_metadata: _ZarrSpatialMetadata = metadata # type: ignore[assignment] + + if _has_spatial_zarr_convention(possibly_spatial_metadata): + if transform := possibly_spatial_metadata.get("spatial:transform"): + if len(transform) < 6: + raise ValueError(f"spatial:transform must have at least 6 elements, got {len(transform)}") + + transform_type = possibly_spatial_metadata.get("spatial:transform_type", "affine") + if transform_type != "affine": + raise NotImplementedError( + f"Unsupported spatial:transform_type {transform_type!r}; only 'affine' is supported." + ) + + registration = possibly_spatial_metadata.get("spatial:registration", "pixel") + if registration != "pixel": + raise NotImplementedError( + f"Unsupported spatial:registration {registration!r}; only 'pixel' is supported." + ) + + return Affine(*map(float, transform[:6])) + + return None diff --git a/src/rasterix/raster_index.py b/src/rasterix/raster_index.py index 88d65a4..89c7591 100644 --- a/src/rasterix/raster_index.py +++ b/src/rasterix/raster_index.py @@ -22,7 +22,7 @@ from rasterix.odc_compat import BoundingBox, bbox_intersection, bbox_union, maybe_int, snap_grid from rasterix.options import get_options as get_rasterix_options from rasterix.rioxarray_compat import guess_dims -from rasterix.utils import get_affine +from rasterix.utils import get_affine, get_crs_from_proj_zarr_convention T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset") @@ -88,13 +88,17 @@ def assign_index( affine = get_affine(obj, x_dim=x_dim, y_dim=y_dim, clear_transform=True) + detected_crs = obj.proj.crs if crs else None + if detected_crs is None: + detected_crs = get_crs_from_proj_zarr_convention(obj) + index = RasterIndex.from_transform( affine, width=obj.sizes[x_dim], height=obj.sizes[y_dim], x_dim=x_dim, y_dim=y_dim, - crs=obj.proj.crs if crs else None, + crs=detected_crs, ) coords = Coordinates.from_xindex(index) return obj.assign_coords(coords) diff --git a/src/rasterix/utils.py b/src/rasterix/utils.py index 307f43b..857e18f 100644 --- a/src/rasterix/utils.py +++ b/src/rasterix/utils.py @@ -1,7 +1,18 @@ +from typing import NotRequired, TypedDict + import xarray as xr from affine import Affine +from pyproj import CRS + +from rasterix.lib import ( + affine_from_spatial_zarr_convention, + affine_from_stac_proj_metadata, + affine_from_tiepoint_and_scale, + logger, +) -from rasterix.lib import affine_from_stac_proj_metadata, affine_from_tiepoint_and_scale, logger +# https://github.com/zarr-conventions/geo-proj +_ZARR_GEO_PROJ_CONVENTION_UUID = "f17cb550-5864-4468-aeb7-f3180cfb622f" def get_grid_mapping_var(obj: xr.Dataset | xr.DataArray) -> xr.DataArray | None: @@ -58,7 +69,7 @@ def get_affine( del grid_mapping_var.attrs["GeoTransform"] return Affine.from_gdal(*map(float, transform.split(" "))) - # Check for STAC and GeoTIFF metadata in DataArray attrs + # Check for STAC, GeoTIFF, or spatial zarr convention metadata in DataArray attrs attrs = obj.attrs if isinstance(obj, xr.DataArray) else {} # Try to extract affine from STAC proj:transform @@ -80,6 +91,13 @@ def get_affine( return affine + # Try to extract from spatial zarr convention attributes + if affine := affine_from_spatial_zarr_convention(attrs): + logger.trace("Creating affine from spatial zarr convention attributes") + if clear_transform: + del attrs["spatial:transform"] + return affine + # Fall back to computing from coordinate arrays logger.trace(f"Creating affine from coordinate arrays {x_dim=!r} and {y_dim=!r}") if x_dim not in obj.coords or y_dim not in obj.coords: @@ -106,3 +124,57 @@ def get_affine( return Affine.translation( x[0].item() - dx / 2, (y[0] if dy < 0 else y[-1]).item() - dy / 2 ) * Affine.scale(dx, dy) + + +_ZarrConventionRegistration = TypedDict("_ZarrConventionRegistration", {"proj:": str}) + +_ZarrProjMetadata = TypedDict( + "_ZarrProjMetadata", + { + "zarr_conventions": NotRequired[list[_ZarrConventionRegistration | dict]], + "proj:code": NotRequired[str], + "proj:wkt2": NotRequired[str], + "proj:projjson": NotRequired[object], + }, +) + + +def _has_proj_zarr_convention(metadata: _ZarrProjMetadata) -> bool: + zarr_conventions = metadata.get("zarr_conventions") + if not zarr_conventions: + return False + for entry in zarr_conventions: + if isinstance(entry, dict) and ( + entry.get("uuid") == _ZARR_GEO_PROJ_CONVENTION_UUID or entry.get("name") == "proj:" + ): + return True + return False + + +def get_crs_from_proj_zarr_convention(obj: xr.Dataset | xr.DataArray) -> CRS | None: + """Extract CRS from Zarr proj: convention metadata if present. + + See https://github.com/zarr-conventions/geo-proj for more details. + + Parameters + ---------- + obj: xr.Dataset or xr.DataArray + The Xarray object to extract CRS from. + + Returns + ------- + CRS or None + The extracted CRS object, or None if not found. + """ + metadata: _ZarrProjMetadata = obj.attrs # type: ignore[assignment] + + if not _has_proj_zarr_convention(metadata): + return None + + if code := metadata.get("proj:code"): + return CRS.from_string(code) + if wkt2 := metadata.get("proj:wkt2"): + return CRS.from_wkt(wkt2) + if projjson := metadata.get("proj:projjson"): + return CRS.from_user_input(projjson) + return None diff --git a/tests/test_raster_index.py b/tests/test_raster_index.py index 7031d9e..98d6b43 100644 --- a/tests/test_raster_index.py +++ b/tests/test_raster_index.py @@ -567,6 +567,82 @@ def test_assign_index_with_stac_proj_transform_9_elements(): assert actual_affine == expected_affine +@pytest.mark.parametrize( + "convention_spec", + [ + {"name": "spatial:"}, # optional + {"uuid": "689b58e2-cf7b-45e0-9fff-9cfc0883d6b4"}, # mandatory + ], +) +def test_assign_index_with_spatial_zarr_convention(convention_spec: dict[str, str]): + da = xr.DataArray( + np.ones((100, 100)), + dims=("y", "x"), + attrs={ + "zarr_conventions": [convention_spec], + "spatial:transform": [30.0, 0.0, 323400.0, 0.0, 30.0, 4268400.0], + }, + ) + + result = assign_index(da) + + # Check that the index was created + assert isinstance(result.xindexes["x"], RasterIndex) + assert isinstance(result.xindexes["y"], RasterIndex) + + # Verify the affine transform + expected_affine = Affine(30.0, 0.0, 323400.0, 0.0, 30.0, 4268400.0) + actual_affine = result.xindexes["x"].transform() + assert actual_affine == expected_affine + + # Verify spatial:transform attribute is removed + assert "spatial:transform" not in result.attrs + + +def test_assign_index_with_spatial_zarr_convention_too_few_raises(): + da = xr.DataArray( + np.ones((100, 100)), + dims=("y", "x"), + attrs={ + "zarr_conventions": [{"name": "spatial:"}], + "spatial:transform": [30.0, 0.0, 323400.0, 0.0, 30.0], + }, + ) + + with pytest.raises(ValueError, match="spatial:transform must have at least 6 elements"): + assign_index(da) + + +def test_assign_index_with_spatial_zarr_convention_transform_type_not_implemented(): + da = xr.DataArray( + np.ones((100, 100)), + dims=("y", "x"), + attrs={ + "zarr_conventions": [{"name": "spatial:"}], + "spatial:transform_type": "not_affine", + "spatial:transform": [30.0, 0.0, 323400.0, 0.0, 30.0, 4268400.0], + }, + ) + + with pytest.raises(NotImplementedError, match="Unsupported spatial:transform_type"): + assign_index(da) + + +def test_assign_index_with_spatial_zarr_convention_registration_not_implemented(): + da = xr.DataArray( + np.ones((100, 100)), + dims=("y", "x"), + attrs={ + "zarr_conventions": [{"name": "spatial:"}], + "spatial:registration": "not_pixel", + "spatial:transform": [30.0, 0.0, 323400.0, 0.0, 30.0, 4268400.0], + }, + ) + + with pytest.raises(NotImplementedError, match="Unsupported spatial:registration"): + assign_index(da) + + def test_assign_index_no_coords_no_metadata(): """Test that assign_index raises error when coords are missing and no transform metadata.""" da = xr.DataArray(np.ones((10, 10)), dims=("y", "x")) @@ -709,6 +785,60 @@ def test_raster_index_from_stac_proj_metadata_with_crs(): assert index.crs.to_epsg() == 32610 +@pytest.mark.parametrize( + "convention_spec", + [ + {"name": "proj:"}, # optional + {"uuid": "f17cb550-5864-4468-aeb7-f3180cfb622f"}, # mandatory + ], +) +def test_assign_index_proj_zarr_convention_code(convention_spec: dict[str, str]): + ds = xr.DataArray( + np.ones((3, 4)), + dims=("y", "x"), + attrs={ + "zarr_conventions": [convention_spec, {"name": "spatial:"}], + "proj:code": "EPSG:4326", + "spatial:transform": [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + }, + ) + indexed = assign_index(ds) + assert indexed.xindexes["x"].crs is not None + assert indexed.xindexes["x"].crs.to_epsg() == 4326 + + +def test_assign_index_proj_zarr_convention_wkt2(): + crs = pyproj.CRS.from_epsg(3857) + ds = xr.DataArray( + np.ones((3, 4)), + dims=("y", "x"), + attrs={ + "zarr_conventions": [{"name": "proj:"}, {"name": "spatial:"}], + "proj:wkt2": crs.to_wkt(), + "spatial:transform": [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + }, + ) + indexed = assign_index(ds) + assert indexed.xindexes["x"].crs is not None + assert indexed.xindexes["x"].crs.to_epsg() == 3857 + + +def test_assign_index_proj_zarr_convention_projjson(): + crs = pyproj.CRS.from_epsg(32610) + ds = xr.DataArray( + np.ones((3, 4)), + dims=("y", "x"), + attrs={ + "zarr_conventions": [{"name": "proj:"}, {"name": "spatial:"}], + "proj:projjson": crs.to_json_dict(), + "spatial:transform": [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + }, + ) + indexed = assign_index(ds) + assert indexed.xindexes["x"].crs is not None + assert indexed.xindexes["x"].crs.to_epsg() == 32610 + + @pytest.fixture def edge_case_ds(): """Create a 100x100 grid covering x=[0, 10], y=[-10, 0]."""