Review notes for this patch (free text ahead of the first "diff --git" header;
git-apply and patch skip it).

* exact.py / coverage_np_dask_wrapper: `strategy` was accepted but never passed
  on, so the dask path always ran np_coverage with its default strategy. It is
  now forwarded.
* exact.py / dask_coverage: `np.cumulative_sum(..., include_initial=True)` is
  only available in NumPy >= 2.1; the chunk offsets are now built with
  `np.cumsum((0, *chunks[:-1]))`, which yields the same values.
* Docstrings/comments were added to the new helpers and tests; hunk headers
  were recounted accordingly.
* NOTE(review): line breaks and blank context lines were restored from a
  whitespace-collapsed copy, guided by the hunk line counts. Verify the context
  against the target tree before applying. The `index` blob ids are the ones
  from the original patch and do not reflect the amendments above.

diff --git a/src/rasterix/rasterize/core.py b/src/rasterix/rasterize/core.py
index 98e8863..6286d0a 100644
--- a/src/rasterix/rasterize/core.py
+++ b/src/rasterix/rasterize/core.py
@@ -8,14 +8,30 @@
 import numpy as np
 import xarray as xr
 
+from ..raster_index import RasterIndex
 from ..utils import get_affine, get_grid_mapping_var
 from .utils import XAXIS, YAXIS, clip_to_bbox, is_in_memory, prepare_for_dask
 
 if TYPE_CHECKING:
     import dask_geopandas
+    from affine import Affine
 
 __all__ = ["rasterize", "geometry_mask", "geometry_clip"]
 
+
+def _get_affine(obj: xr.Dataset | xr.DataArray, *, x_dim: str, y_dim: str) -> Affine:
+    """Get affine transform, preferring RasterIndex if available.
+
+    Only the index registered for ``x_dim`` is inspected. When it is a
+    ``RasterIndex`` its ``transform()`` result is returned; otherwise the
+    transform is computed by ``get_affine(obj, x_dim=x_dim, y_dim=y_dim)``.
+    """
+    idx = obj.xindexes.get(x_dim)
+    if isinstance(idx, RasterIndex):
+        return idx.transform()
+    return get_affine(obj, x_dim=x_dim, y_dim=y_dim)
+
+
 Engine = Literal["rasterio", "rusterize", "exactextract"]
 
 
@@ -222,7 +238,7 @@ def rasterize(
     if clip:
         obj = clip_to_bbox(obj, geometries, xdim=xdim, ydim=ydim)
 
-    affine = get_affine(obj, x_dim=xdim, y_dim=ydim)
+    affine = _get_affine(obj, x_dim=xdim, y_dim=ydim)
     engine_merge_alg = _normalize_merge_alg(merge_alg, resolved_engine)
 
     rasterize_geometries, dask_rasterize_wrapper = _get_rasterize_funcs(resolved_engine)
@@ -370,7 +386,7 @@ def geometry_mask(
     if clip:
         obj = clip_to_bbox(obj, geometries, xdim=xdim, ydim=ydim)
 
-    affine = get_affine(obj, x_dim=xdim, y_dim=ydim)
+    affine = _get_affine(obj, x_dim=xdim, y_dim=ydim)
 
     np_geometry_mask, dask_mask_wrapper = _get_mask_funcs(resolved_engine)
 
diff --git a/src/rasterix/rasterize/exact.py b/src/rasterix/rasterize/exact.py
index e4c1dcc..99a2138 100644
--- a/src/rasterix/rasterize/exact.py
+++ b/src/rasterix/rasterize/exact.py
@@ -10,6 +10,7 @@
 from exactextract import exact_extract
 from exactextract.raster import NumPyRasterSource
 
+from ..rasterize.core import _get_affine
 from ..utils import get_grid_mapping_var
 from .utils import clip_to_bbox, geometries_as_dask_array, is_in_memory
 
@@ -103,9 +104,9 @@ def get_dtype(coverage_weight: CoverageWeights, geometries):
 
 
 def np_coverage(
-    x: np.ndarray,
-    y: np.ndarray,
+    affine,
     *,
+    shape: tuple[int, int],
     geometries: gpd.GeoDataFrame,
     strategy: Strategy = DEFAULT_STRATEGY,
     coverage_weight: CoverageWeights = "fraction",
@@ -115,8 +116,7 @@ def np_coverage(
     if len(geometries.columns) > 1:
         raise ValueError("Require a single geometries column or a GeoSeries.")
 
-    shape = (y.size, x.size)
-    raster = xy_to_raster_source(x, y, srs_wkt=geometries.crs.to_wkt())
+    raster = affine_to_raster_source(affine, shape, srs_wkt=geometries.crs.to_wkt())
     result = exact_extract(
         rast=raster,
         vec=geometries,
@@ -161,43 +161,71 @@ def np_coverage(
 
 def coverage_np_dask_wrapper(
     geom_array: np.ndarray,
-    x: np.ndarray,
-    y: np.ndarray,
+    x_offsets: np.ndarray,
+    y_offsets: np.ndarray,
+    x_sizes: np.ndarray,
+    y_sizes: np.ndarray,
+    affine,
     coverage_weight: CoverageWeights,
     strategy: Strategy,
     crs,
 ) -> np.ndarray:
+    """Compute coverage for one block; called via ``dask.array.map_blocks``.
+
+    ``dask_coverage`` builds the offset and size arrays with ``chunks=1``, so
+    each of them arrives here as a single-element block and ``.item()`` gives
+    this block's pixel offset / size along x and y.
+    """
+    shape = (y_sizes.item(), x_sizes.item())
+    # Compose the full-grid transform with a translation by the block's pixel offset.
+    chunk_affine = affine * affine.translation(x_offsets.item(), y_offsets.item())
     return np_coverage(
-        x=x.squeeze(axis=(GEOM_AXIS, Y_AXIS)),
-        y=y.squeeze(axis=(GEOM_AXIS, X_AXIS)),
+        chunk_affine,
+        shape=shape,
         geometries=gpd.GeoDataFrame(geometry=geom_array.squeeze(axis=(X_AXIS, Y_AXIS)), crs=crs),
         coverage_weight=coverage_weight,
+        strategy=strategy,
     )
 
 
 def dask_coverage(
-    x: dask.array.Array,
-    y: dask.array.Array,
+    affine,
     *,
+    chunks: tuple[tuple[int, ...], tuple[int, ...]],
     geom_array: dask.array.Array,
     coverage_weight: CoverageWeights = "fraction",
     strategy: Strategy = DEFAULT_STRATEGY,
     crs: Any,
 ) -> dask.array.Array:
     import dask.array
+    from dask.array import from_array
 
-    if any(c == 1 for c in x.chunks[0]) or any(c == 1 for c in y.chunks[0]):
+    y_chunks, x_chunks = chunks
+
+    if any(c == 1 for c in x_chunks) or any(c == 1 for c in y_chunks):
         raise ValueError("exactextract does not support a chunksize of 1. Please rechunk to avoid this")
 
+    # One single-element block per spatial chunk: its size and its starting
+    # pixel offset (running total of the preceding chunk sizes). ``np.cumsum``
+    # with a leading 0 is used instead of ``np.cumulative_sum`` because the
+    # latter only exists in NumPy >= 2.1.
+    x_sizes = from_array(x_chunks, chunks=1)
+    y_sizes = from_array(y_chunks, chunks=1)
+    x_offsets = from_array(np.cumsum((0, *x_chunks[:-1])), chunks=1)
+    y_offsets = from_array(np.cumsum((0, *y_chunks[:-1])), chunks=1)
+
     out = dask.array.map_blocks(
         coverage_np_dask_wrapper,
         geom_array[:, np.newaxis, np.newaxis],
-        x[np.newaxis, np.newaxis, :],
-        y[np.newaxis, :, np.newaxis],
+        x_offsets[np.newaxis, np.newaxis, :],
+        y_offsets[np.newaxis, :, np.newaxis],
+        x_sizes[np.newaxis, np.newaxis, :],
+        y_sizes[np.newaxis, :, np.newaxis],
+        affine=affine,
         crs=crs,
         coverage_weight=coverage_weight,
         strategy=strategy,
-        chunks=(*geom_array.chunks, *y.chunks, *x.chunks),
+        chunks=(*geom_array.chunks, tuple(y_chunks), tuple(x_chunks)),
         meta=sparse.COO(
             [], data=np.array([], dtype=get_dtype(coverage_weight, geom_array)), shape=(0, 0, 0), fill_value=0
         ),
@@ -290,32 +318,30 @@ def coverage(
 
     if clip:
         obj = clip_to_bbox(obj, geometries, xdim=xdim, ydim=ydim)
+
+    affine = _get_affine(obj, x_dim=xdim, y_dim=ydim)
+    shape = (obj.sizes[ydim], obj.sizes[xdim])
+
     if is_in_memory(obj=obj, geometries=geometries):
         out = np_coverage(
-            x=obj[xdim].data,
-            y=obj[ydim].data,
+            affine,
+            shape=shape,
             geometries=geometries,
             coverage_weight=coverage_weight,
             strategy=strategy,
         )
         geom_array = geometries.to_numpy().squeeze(axis=1)
     else:
-        from dask.array import Array, from_array
-
         geom_dask_array = geometries_as_dask_array(geometries)
 
-        if not isinstance(obj[xdim].data, Array):
-            dask_x = from_array(obj[xdim].data, chunks=obj.chunksizes.get(xdim, -1))
-        else:
-            dask_x = obj[xdim].data
-        if not isinstance(obj[ydim].data, Array):
-            dask_y = from_array(obj[ydim].data, chunks=obj.chunksizes.get(ydim, -1))
-        else:
-            dask_y = obj[ydim].data
+        chunks = (
+            obj.chunksizes.get(ydim, (obj.sizes[ydim],)),
+            obj.chunksizes.get(xdim, (obj.sizes[xdim],)),
+        )
 
         out = dask_coverage(
-            x=dask_x,
-            y=dask_y,
+            affine,
+            chunks=chunks,
             geom_array=geom_dask_array,
             crs=geometries.crs,
             coverage_weight=coverage_weight,
diff --git a/tests/test_exact.py b/tests/test_exact.py
index 17c20a2..5b8fb20 100644
--- a/tests/test_exact.py
+++ b/tests/test_exact.py
@@ -120,6 +120,10 @@ def test_coverage_weights(
     ds = ds.isel(x=xslicer, y=yslicer)
     if xchunks is not None or ychunks is not None:
         ds = ds.chunk({"x": xchunks, "y": ychunks})
+    if not indexed:
+        # Ensure coordinate arrays are in memory so get_affine doesn't trigger dask compute
+        ds.coords["x"].load()
+        ds.coords["y"].load()
 
     with raise_if_dask_computes():
         actual = coverage(ds, geometries[["geometry"]], coverage_weight=coverage_weight)
diff --git a/tests/test_rasterize.py b/tests/test_rasterize.py
index 073987e..1e00125 100644
--- a/tests/test_rasterize.py
+++ b/tests/test_rasterize.py
@@ -6,6 +6,7 @@
 import xproj  # noqa
 from xarray.tests import raise_if_dask_computes
 
+from rasterix import RasterIndex, assign_index
 from rasterix.rasterize import geometry_clip, geometry_mask, rasterize
 
 pytestmark = pytest.mark.filterwarnings("ignore:variable '.*' has non-conforming '_FillValue'")
@@ -121,3 +122,37 @@ def test_geometry_clip(engine, dataset):
     assert clipped is not None
     # Basic check that clipping worked - masked values outside geometries
     assert clipped["u"].isnull().any()
+
+
+@pytest.fixture
+def raster_index_dataset(dataset):
+    """Same grid as ``dataset`` but with a RasterIndex on the spatial dims."""
+    ds = assign_index(dataset, x_dim="longitude", y_dim="latitude")
+    assert isinstance(ds.xindexes["longitude"], RasterIndex)
+    return ds
+
+
+def test_rasterize_with_raster_index(engine, raster_index_dataset, dataset):
+    """``rasterize`` gives equal output with a RasterIndex and keeps the index on both dims."""
+    world = gpd.read_file(geodatasets.get_path("naturalearth land"))
+    kwargs = dict(xdim="longitude", ydim="latitude", engine=engine)
+
+    expected = rasterize(dataset, world[["geometry"]], **kwargs)
+    result = rasterize(raster_index_dataset, world[["geometry"]], **kwargs)
+
+    xr.testing.assert_equal(result, expected)
+    assert isinstance(result.xindexes["longitude"], RasterIndex)
+    assert isinstance(result.xindexes["latitude"], RasterIndex)
+
+
+def test_geometry_mask_with_raster_index(engine, raster_index_dataset, dataset):
+    """``geometry_mask`` gives equal output with a RasterIndex and keeps the index on both dims."""
+    world = gpd.read_file(geodatasets.get_path("naturalearth land"))
+    kwargs = dict(xdim="longitude", ydim="latitude", engine=engine)
+
+    expected = geometry_mask(dataset, world[["geometry"]], **kwargs)
+    result = geometry_mask(raster_index_dataset, world[["geometry"]], **kwargs)
+
+    xr.testing.assert_equal(result, expected)
+    assert isinstance(result.xindexes["longitude"], RasterIndex)
+    assert isinstance(result.xindexes["latitude"], RasterIndex)