Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions src/rasterix/rasterize/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,25 @@
import numpy as np
import xarray as xr

from ..raster_index import RasterIndex
from ..utils import get_affine, get_grid_mapping_var
from .utils import XAXIS, YAXIS, clip_to_bbox, is_in_memory, prepare_for_dask

if TYPE_CHECKING:
import dask_geopandas
from affine import Affine

__all__ = ["rasterize", "geometry_mask", "geometry_clip"]


def _get_affine(obj: xr.Dataset | xr.DataArray, *, x_dim: str, y_dim: str) -> Affine:
"""Get affine transform, preferring RasterIndex if available."""
idx = obj.xindexes.get(x_dim)
if isinstance(idx, RasterIndex):
return idx.transform()
return get_affine(obj, x_dim=x_dim, y_dim=y_dim)


Engine = Literal["rasterio", "rusterize", "exactextract"]


Expand Down Expand Up @@ -222,7 +233,7 @@ def rasterize(
if clip:
obj = clip_to_bbox(obj, geometries, xdim=xdim, ydim=ydim)

affine = get_affine(obj, x_dim=xdim, y_dim=ydim)
affine = _get_affine(obj, x_dim=xdim, y_dim=ydim)
engine_merge_alg = _normalize_merge_alg(merge_alg, resolved_engine)

rasterize_geometries, dask_rasterize_wrapper = _get_rasterize_funcs(resolved_engine)
Expand Down Expand Up @@ -370,7 +381,7 @@ def geometry_mask(
if clip:
obj = clip_to_bbox(obj, geometries, xdim=xdim, ydim=ydim)

affine = get_affine(obj, x_dim=xdim, y_dim=ydim)
affine = _get_affine(obj, x_dim=xdim, y_dim=ydim)

np_geometry_mask, dask_mask_wrapper = _get_mask_funcs(resolved_engine)

Expand Down
70 changes: 42 additions & 28 deletions src/rasterix/rasterize/exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from exactextract import exact_extract
from exactextract.raster import NumPyRasterSource

from ..rasterize.core import _get_affine
from ..utils import get_grid_mapping_var
from .utils import clip_to_bbox, geometries_as_dask_array, is_in_memory

Expand Down Expand Up @@ -103,9 +104,9 @@ def get_dtype(coverage_weight: CoverageWeights, geometries):


def np_coverage(
x: np.ndarray,
y: np.ndarray,
affine,
*,
shape: tuple[int, int],
geometries: gpd.GeoDataFrame,
strategy: Strategy = DEFAULT_STRATEGY,
coverage_weight: CoverageWeights = "fraction",
Expand All @@ -115,8 +116,7 @@ def np_coverage(
if len(geometries.columns) > 1:
raise ValueError("Require a single geometries column or a GeoSeries.")

shape = (y.size, x.size)
raster = xy_to_raster_source(x, y, srs_wkt=geometries.crs.to_wkt())
raster = affine_to_raster_source(affine, shape, srs_wkt=geometries.crs.to_wkt())
result = exact_extract(
rast=raster,
vec=geometries,
Expand Down Expand Up @@ -161,43 +161,59 @@ def np_coverage(

def coverage_np_dask_wrapper(
geom_array: np.ndarray,
x: np.ndarray,
y: np.ndarray,
x_offsets: np.ndarray,
y_offsets: np.ndarray,
x_sizes: np.ndarray,
y_sizes: np.ndarray,
affine,
coverage_weight: CoverageWeights,
strategy: Strategy,
crs,
) -> np.ndarray:
shape = (y_sizes.item(), x_sizes.item())
chunk_affine = affine * affine.translation(x_offsets.item(), y_offsets.item())
return np_coverage(
x=x.squeeze(axis=(GEOM_AXIS, Y_AXIS)),
y=y.squeeze(axis=(GEOM_AXIS, X_AXIS)),
chunk_affine,
shape=shape,
geometries=gpd.GeoDataFrame(geometry=geom_array.squeeze(axis=(X_AXIS, Y_AXIS)), crs=crs),
coverage_weight=coverage_weight,
)


def dask_coverage(
x: dask.array.Array,
y: dask.array.Array,
affine,
*,
chunks: tuple[tuple[int, ...], tuple[int, ...]],
geom_array: dask.array.Array,
coverage_weight: CoverageWeights = "fraction",
strategy: Strategy = DEFAULT_STRATEGY,
crs: Any,
) -> dask.array.Array:
import dask.array
from dask.array import from_array

if any(c == 1 for c in x.chunks[0]) or any(c == 1 for c in y.chunks[0]):
y_chunks, x_chunks = chunks

if any(c == 1 for c in x_chunks) or any(c == 1 for c in y_chunks):
raise ValueError("exactextract does not support a chunksize of 1. Please rechunk to avoid this")

x_sizes = from_array(x_chunks, chunks=1)
y_sizes = from_array(y_chunks, chunks=1)
x_offsets = from_array(np.cumulative_sum(x_chunks[:-1], include_initial=True), chunks=1)
y_offsets = from_array(np.cumulative_sum(y_chunks[:-1], include_initial=True), chunks=1)

out = dask.array.map_blocks(
coverage_np_dask_wrapper,
geom_array[:, np.newaxis, np.newaxis],
x[np.newaxis, np.newaxis, :],
y[np.newaxis, :, np.newaxis],
x_offsets[np.newaxis, np.newaxis, :],
y_offsets[np.newaxis, :, np.newaxis],
x_sizes[np.newaxis, np.newaxis, :],
y_sizes[np.newaxis, :, np.newaxis],
affine=affine,
crs=crs,
coverage_weight=coverage_weight,
strategy=strategy,
chunks=(*geom_array.chunks, *y.chunks, *x.chunks),
chunks=(*geom_array.chunks, tuple(y_chunks), tuple(x_chunks)),
meta=sparse.COO(
[], data=np.array([], dtype=get_dtype(coverage_weight, geom_array)), shape=(0, 0, 0), fill_value=0
),
Expand Down Expand Up @@ -290,32 +306,30 @@ def coverage(

if clip:
obj = clip_to_bbox(obj, geometries, xdim=xdim, ydim=ydim)

affine = _get_affine(obj, x_dim=xdim, y_dim=ydim)
shape = (obj.sizes[ydim], obj.sizes[xdim])

if is_in_memory(obj=obj, geometries=geometries):
out = np_coverage(
x=obj[xdim].data,
y=obj[ydim].data,
affine,
shape=shape,
geometries=geometries,
coverage_weight=coverage_weight,
strategy=strategy,
)
geom_array = geometries.to_numpy().squeeze(axis=1)
else:
from dask.array import Array, from_array

geom_dask_array = geometries_as_dask_array(geometries)
if not isinstance(obj[xdim].data, Array):
dask_x = from_array(obj[xdim].data, chunks=obj.chunksizes.get(xdim, -1))
else:
dask_x = obj[xdim].data

if not isinstance(obj[ydim].data, Array):
dask_y = from_array(obj[ydim].data, chunks=obj.chunksizes.get(ydim, -1))
else:
dask_y = obj[ydim].data
chunks = (
obj.chunksizes.get(ydim, (obj.sizes[ydim],)),
obj.chunksizes.get(xdim, (obj.sizes[xdim],)),
)

out = dask_coverage(
x=dask_x,
y=dask_y,
affine,
chunks=chunks,
geom_array=geom_dask_array,
crs=geometries.crs,
coverage_weight=coverage_weight,
Expand Down
4 changes: 4 additions & 0 deletions tests/test_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ def test_coverage_weights(
ds = ds.isel(x=xslicer, y=yslicer)
if xchunks is not None or ychunks is not None:
ds = ds.chunk({"x": xchunks, "y": ychunks})
if not indexed:
# Ensure coordinate arrays are in memory so get_affine doesn't trigger dask compute
ds.coords["x"].load()
ds.coords["y"].load()

with raise_if_dask_computes():
actual = coverage(ds, geometries[["geometry"]], coverage_weight=coverage_weight)
Expand Down
33 changes: 33 additions & 0 deletions tests/test_rasterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import xproj # noqa
from xarray.tests import raise_if_dask_computes

from rasterix import RasterIndex, assign_index
from rasterix.rasterize import geometry_clip, geometry_mask, rasterize

pytestmark = pytest.mark.filterwarnings("ignore:variable '.*' has non-conforming '_FillValue'")
Expand Down Expand Up @@ -121,3 +122,35 @@ def test_geometry_clip(engine, dataset):
assert clipped is not None
# Basic check that clipping worked - masked values outside geometries
assert clipped["u"].isnull().any()


@pytest.fixture
def raster_index_dataset(dataset):
"""Same grid as ``dataset`` but with a RasterIndex on the spatial dims."""
ds = assign_index(dataset, x_dim="longitude", y_dim="latitude")
assert isinstance(ds.xindexes["longitude"], RasterIndex)
return ds


def test_rasterize_with_raster_index(engine, raster_index_dataset, dataset):
world = gpd.read_file(geodatasets.get_path("naturalearth land"))
kwargs = dict(xdim="longitude", ydim="latitude", engine=engine)

expected = rasterize(dataset, world[["geometry"]], **kwargs)
result = rasterize(raster_index_dataset, world[["geometry"]], **kwargs)

xr.testing.assert_equal(result, expected)
assert isinstance(result.xindexes["longitude"], RasterIndex)
assert isinstance(result.xindexes["latitude"], RasterIndex)


def test_geometry_mask_with_raster_index(engine, raster_index_dataset, dataset):
world = gpd.read_file(geodatasets.get_path("naturalearth land"))
kwargs = dict(xdim="longitude", ydim="latitude", engine=engine)

expected = geometry_mask(dataset, world[["geometry"]], **kwargs)
result = geometry_mask(raster_index_dataset, world[["geometry"]], **kwargs)

xr.testing.assert_equal(result, expected)
assert isinstance(result.xindexes["longitude"], RasterIndex)
assert isinstance(result.xindexes["latitude"], RasterIndex)
Loading