From 8e51f33fbb146e86cb4086c756e49d920f0a1ea5 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 1 Jul 2025 13:15:10 +0200 Subject: [PATCH 01/17] unrelated docstrings fix --- xarray/core/coordinate_transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/coordinate_transform.py b/xarray/core/coordinate_transform.py index 94b3b109e1e..d1e434c3d64 100644 --- a/xarray/core/coordinate_transform.py +++ b/xarray/core/coordinate_transform.py @@ -80,7 +80,7 @@ def equals(self, other: CoordinateTransform, **kwargs) -> bool: Parameters ---------- other : CoordinateTransform - The other Index object to compare with this object. + The other CoordinateTransform object to compare with this object. exclude : frozenset of hashable, optional Dimensions excluded from checking. It is None by default, (i.e., when this method is not called in the context of alignment). For a From f8e906ac35904e614e2d2bb056e485809d58d119 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 1 Jul 2025 13:15:40 +0200 Subject: [PATCH 02/17] add TreeIndex --- xarray/indexes/__init__.py | 2 + xarray/indexes/tree_index.py | 260 +++++++++++++++++++++++++++++++++++ 2 files changed, 262 insertions(+) create mode 100644 xarray/indexes/tree_index.py diff --git a/xarray/indexes/__init__.py b/xarray/indexes/__init__.py index c53a4b8c2ce..972dce764ff 100644 --- a/xarray/indexes/__init__.py +++ b/xarray/indexes/__init__.py @@ -11,6 +11,7 @@ PandasMultiIndex, ) from xarray.indexes.range_index import RangeIndex +from xarray.indexes.tree_index import TreeIndex __all__ = [ "CoordinateTransform", @@ -19,4 +20,5 @@ "PandasIndex", "PandasMultiIndex", "RangeIndex", + "TreeIndex", ] diff --git a/xarray/indexes/tree_index.py b/xarray/indexes/tree_index.py new file mode 100644 index 00000000000..b8de38aa327 --- /dev/null +++ b/xarray/indexes/tree_index.py @@ -0,0 +1,260 @@ +from __future__ import annotations + +import abc +from collections.abc import Hashable, Iterable, Mapping +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +import numpy as np + +from xarray.core.dataarray import DataArray +from xarray.core.indexes import Index +from xarray.core.indexing import IndexSelResult +from xarray.core.variable import Variable + +if TYPE_CHECKING: + from scipy.spatial import KDTree + + from xarray.core.types import Self + + +class TreeAdapter(abc.ABC): + """Lightweight adapter abstract class for plugging in 3rd-party structures + like :py:class:`scipy.spatial.KDTree` or :py:class:`sklearn.neighbors.KDTree` + into :py:class:`~xarray.indexes.TreeIndex`. + + """ + + @abc.abstractmethod + def __init__(self, points: np.ndarray, *, options: Mapping[str, Any]): + """ + Parameters + ---------- + points : ndarray of shape (n_points, n_coordinates) + Two-dimensional array of points/samples (rows) and their + corresponding coordinate labels (columns) to index. + """ + ... + + @abc.abstractmethod + def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """Query points. + + Parameters + ---------- + points: ndarray of shape (n_points, n_coordinates) + Two-dimensional array of points/samples (rows) and their + corresponding coordinate labels (columns) to query. + + Returns + ------- + distances : ndarray of shape (n_points) + Distances to the nearest neighbors. + indices : ndarray of shape (n_points) + Indices of the nearest neighbors in the array of the indexed + points. + """ + ... + + def equals(self, other: TreeAdapter) -> bool: + """Check equality with another TreeAdapter of the same kind. + + Parameters + ---------- + other : + The other TreeAdapter object to compare with this object. + + """ + raise NotImplementedError + + +class ScipyKDTreeAdapter(TreeAdapter): + """:py:class:`scipy.spatial.KDTree` adapter for :py:class:`~xarray.indexes.TreeIndex`.""" + + _kdtree: KDTree + + def __init__(self, points: np.ndarray, options: Mapping[str, Any]): + from scipy.spatial import KDTree + + self._kdtree = KDTree(points, **options) + + def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + return self._kdtree.query(points) + + def equals(self, other: ScipyKDTreeAdapter) -> bool: + return np.all(self._kdtree.data, other._kdtree.data) + + +def get_points(coords: Iterable[Variable | Any]) -> np.ndarray: + """Re-arrange data from a sequence of xarray coordinate variables or + labels into a 2-d array of shape (n_points, n_coordinates). + + """ + data = [c.values if isinstance(c, Variable | DataArray) else c for c in coords] + return np.stack([np.ravel(d) for d in data]).T + + +T_TreeAdapter = TypeVar("T_TreeAdapter", bound=TreeAdapter) + + +class TreeIndex(Index, Generic[T_TreeAdapter]): + """Xarray index for irregular, n-dimensional data. + + This index may be associated with a set of coordinate variables representing + the location of the data points in an n-dimensional space. All coordinates + must have the same shape and dimensions. The number of associated coordinate + variables must correspond to the number of dimensions of the space. + + This index supports label-based selection (nearest neighbor lookup). It also + has limited support for alignment. + + By default, this index relies on :py:class:`scipy.spatial.KDTree` for fast + lookup. + + Examples + -------- + TODO + + """ + + _tree_obj: T_TreeAdapter + _coord_names: tuple[Hashable, ...] + _dims: tuple[Hashable, ...] + _shape: tuple[int, ...] + + def __init__( + self, + tree_obj: T_TreeAdapter, + *, + coord_names: tuple[Hashable, ...], + dims: tuple[Hashable, ...], + shape: tuple[int, ...], + ): + # this constructor is "private" + assert isinstance(tree_obj, TreeAdapter) + self._tree_obj = tree_obj + + assert len(coord_names) == len(dims) == len(shape) + self._coord_names = coord_names + self._dims = dims + self._shape = shape + + @classmethod + def from_variables( + cls, + variables: Mapping[Any, Variable], + *, + options: Mapping[str, Any], + ) -> Self: + if len(set([var.dims for var in variables.values()])) > 1: + var_names = ",".join(vn for vn in variables) + raise ValueError( + f"variables {var_names} must all have the same dimensions and the same shape" + ) + + var0 = next(iter(variables.values())) + + if len(variables) != len(var0.dims): + raise ValueError( + f"the number of variables {len(variables)} doesn't match the number of dimensions {len(var0.dims)}" + ) + + opts = dict(options) + + tree_adapter_cls: type[T_TreeAdapter] = opts.pop("tree_adapter_cls", None) + if tree_adapter_cls is None: + tree_adapter_cls = ScipyKDTreeAdapter + + points = get_points(variables.values()) + + return cls( + tree_adapter_cls(points, options=opts), + coord_names=tuple(variables), + dims=var0.dims, + shape=var0.shape, + ) + + def equals( + self, other: Index, *, exclude: frozenset[Hashable] | None = None + ) -> bool: + if not isinstance(other, TreeIndex): + return False + if type(self._tree_obj) is not type(other._tree_obj): + return False + return self._tree_obj.equals(other._tree_obj) + + def _get_dim_indexers( + self, + indices: np.ndarray, + label_dims: tuple[Hashable, ...], + label_shape: tuple[int, ...], + ) -> dict[Hashable, np.ndarray]: + """Returns dimension indexers based on the query results (indices) and + the original label dimensions and shape. + + 1. Unravel the flat indices returned from the query + 2. Reshape the unraveled indices according to indexers shapes + 3. Wrap the indices in xarray.Variable objects. + + """ + dim_indexers = {} + + u_indices = list(np.unravel_index(indices.ravel(), self._shape)) + + for dim, ind in zip(self._dims, u_indices, strict=False): + dim_indexers[dim] = Variable(label_dims, ind.reshape(label_shape)) + + return dim_indexers + + def sel( + self, labels: dict[Any, Any], method=None, tolerance=None + ) -> IndexSelResult: + if method != "nearest": + raise ValueError( + "CoordinateTransformIndex only supports selection with method='nearest'" + ) + + missing_labels = set(self._coord_names) - set(labels) + if missing_labels: + missing_labels_str = ",".join([f"{name}" for name in missing_labels]) + raise ValueError(f"missing labels for coordinate(s): {missing_labels_str}.") + + if any(not isinstance(lbl, Variable | DataArray) for lbl in labels): + raise TypeError( + "TreeIndex only supports advanced (point-wise) indexing " + "with either xarray.DataArray or xarray.Variable objects." + ) + + if len(set([var.dims for var in labels.values()])) > 1: + raise ValueError( + "TreeIndex only supports advanced (point-wise) indexing " + "with xarray.DataArray or xarray.Variable objects of matching dimensions." + ) + + label0: DataArray | Variable = next(iter(labels.values())) + + points = get_points(labels[name] for name in self._coord_names) + _, indices = self._tree_obj.query(points) + + dim_indexers = self._get_dim_indexers(indices, label0.dims, label0.shape) + + return IndexSelResult(dim_indexers=dim_indexers) + + def rename( + self, + name_dict: Mapping[Any, Hashable], + dims_dict: Mapping[Any, Hashable], + ) -> Self: + if not set(self._coord_names) & set(name_dict) and not set(self._dims) & set( + dims_dict + ): + return self + + new_coord_names = tuple(name_dict.get(n, n) for n in self._coord_names) + new_dims = tuple(dims_dict.get(d, d) for d in self._dims) + + return type(self)( + self._tree_obj, + coord_names=new_coord_names, + dims=new_dims, + shape=self._shape, + ) From ae55d4981d949f043b67a9f479e7261dbc352c12 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 1 Jul 2025 16:25:01 +0200 Subject: [PATCH 03/17] a few fixes --- xarray/indexes/tree_index.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/xarray/indexes/tree_index.py b/xarray/indexes/tree_index.py index b8de38aa327..bcaac55f47b 100644 --- a/xarray/indexes/tree_index.py +++ b/xarray/indexes/tree_index.py @@ -81,7 +81,7 @@ def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]: return self._kdtree.query(points) def equals(self, other: ScipyKDTreeAdapter) -> bool: - return np.all(self._kdtree.data, other._kdtree.data) + return np.array_equal(self._kdtree.data, other._kdtree.data) def get_points(coords: Iterable[Variable | Any]) -> np.ndarray: @@ -209,16 +209,14 @@ def sel( self, labels: dict[Any, Any], method=None, tolerance=None ) -> IndexSelResult: if method != "nearest": - raise ValueError( - "CoordinateTransformIndex only supports selection with method='nearest'" - ) + raise ValueError("TreeIndex only supports selection with method='nearest'") missing_labels = set(self._coord_names) - set(labels) if missing_labels: missing_labels_str = ",".join([f"{name}" for name in missing_labels]) raise ValueError(f"missing labels for coordinate(s): {missing_labels_str}.") - if any(not isinstance(lbl, Variable | DataArray) for lbl in labels): + if any(not isinstance(lbl, Variable | DataArray) for lbl in labels.values()): raise TypeError( "TreeIndex only supports advanced (point-wise) indexing " "with either xarray.DataArray or xarray.Variable objects." From 534169f921d8347fa1ac990914e5c3968f14620c Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 2 Jul 2025 09:32:48 +0200 Subject: [PATCH 04/17] typing fixes --- xarray/indexes/tree_index.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/indexes/tree_index.py b/xarray/indexes/tree_index.py index bcaac55f47b..f55f8b0a6ec 100644 --- a/xarray/indexes/tree_index.py +++ b/xarray/indexes/tree_index.py @@ -55,7 +55,7 @@ def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]: """ ... - def equals(self, other: TreeAdapter) -> bool: + def equals(self, other: Self) -> bool: """Check equality with another TreeAdapter of the same kind. Parameters @@ -80,7 +80,7 @@ def __init__(self, points: np.ndarray, options: Mapping[str, Any]): def query(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]: return self._kdtree.query(points) - def equals(self, other: ScipyKDTreeAdapter) -> bool: + def equals(self, other: Self) -> bool: return np.array_equal(self._kdtree.data, other._kdtree.data) @@ -187,7 +187,7 @@ def _get_dim_indexers( indices: np.ndarray, label_dims: tuple[Hashable, ...], label_shape: tuple[int, ...], - ) -> dict[Hashable, np.ndarray]: + ) -> dict[Hashable, Variable]: """Returns dimension indexers based on the query results (indices) and the original label dimensions and shape. From 793341609053f6d75b8585c4a66d2b878d2200b7 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 2 Jul 2025 10:31:57 +0200 Subject: [PATCH 05/17] TreeIndex.sel: support scalar and array-like label --- xarray/indexes/tree_index.py | 68 +++++++++++++++++++++++++++--------- 1 file changed, 52 insertions(+), 16 deletions(-) diff --git a/xarray/indexes/tree_index.py b/xarray/indexes/tree_index.py index f55f8b0a6ec..a0f0a497c99 100644 --- a/xarray/indexes/tree_index.py +++ b/xarray/indexes/tree_index.py @@ -9,6 +9,7 @@ from xarray.core.dataarray import DataArray from xarray.core.indexes import Index from xarray.core.indexing import IndexSelResult +from xarray.core.utils import is_scalar from xarray.core.variable import Variable if TYPE_CHECKING: @@ -216,24 +217,59 @@ def sel( missing_labels_str = ",".join([f"{name}" for name in missing_labels]) raise ValueError(f"missing labels for coordinate(s): {missing_labels_str}.") - if any(not isinstance(lbl, Variable | DataArray) for lbl in labels.values()): - raise TypeError( - "TreeIndex only supports advanced (point-wise) indexing " - "with either xarray.DataArray or xarray.Variable objects." - ) - - if len(set([var.dims for var in labels.values()])) > 1: - raise ValueError( - "TreeIndex only supports advanced (point-wise) indexing " - "with xarray.DataArray or xarray.Variable objects of matching dimensions." - ) - - label0: DataArray | Variable = next(iter(labels.values())) - - points = get_points(labels[name] for name in self._coord_names) + # maybe convert labels into xarray Variable objects + xr_labels: dict[Any, Variable | DataArray] = {} + + for name, lbl in labels.items(): + if isinstance(lbl, Variable | DataArray): + xr_labels[name] = lbl + elif is_scalar(lbl): + xr_labels[name] = Variable((), lbl) + elif np.asarray(lbl).ndim == len(self._dims): + xr_labels[name] = Variable(self._dims, lbl) + else: + raise TypeError( + "invalid label type. TreeIndex only supports advanced (point-wise) indexing " + "with the following label value types:\n" + "- xarray.DataArray or xarray.Variable objects\n" + "- scalar types\n" + "- unlabelled array-like objects with the same number of dimensions " + f"than the {self._coord_names} coordinate variables ({len(self._dims)})" + ) + + # determine xarray label shape and dimensions + label_dims: tuple[Hashable, ...] = () + label_shape: tuple[int, ...] = () + + for name, lbl in xr_labels.items(): + if lbl.ndim > 0: + if label_dims and lbl.dims != label_dims: + raise ValueError( + f"label {name} has dimensions {lbl.dims} that conflict with " + f"other label dimensions {label_dims}" + ) + else: + label_dims = lbl.dims + if label_shape and lbl.shape != label_shape: + raise ValueError( + f"label {name} has shape {lbl.shape} that conflicts with " + f"other label shape {label_shape}" + ) + else: + label_shape = lbl.shape + + # maybe broadcast scalar xarray labels + if label_dims: + for name, lbl in xr_labels.items(): + if not lbl.dims: + xr_labels[name] = Variable( + label_dims, np.broadcast_to(lbl.values, label_shape) + ) + + points = get_points(xr_labels[name] for name in self._coord_names) _, indices = self._tree_obj.query(points) - dim_indexers = self._get_dim_indexers(indices, label0.dims, label0.shape) + dim_indexers = self._get_dim_indexers(indices, label_dims, label_shape) return IndexSelResult(dim_indexers=dim_indexers) From 8af432796c2a48185760c1377ca0b8add7a6207b Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 2 Jul 2025 11:32:00 +0200 Subject: [PATCH 06/17] add tests --- xarray/indexes/tree_index.py | 21 +++- xarray/tests/test_tree_index.py | 172 ++++++++++++++++++++++++++++++++ 2 files changed, 189 insertions(+), 4 deletions(-) create mode 100644 xarray/tests/test_tree_index.py diff --git a/xarray/indexes/tree_index.py b/xarray/indexes/tree_index.py index a0f0a497c99..67fd59b98f8 100644 --- a/xarray/indexes/tree_index.py +++ b/xarray/indexes/tree_index.py @@ -174,6 +174,19 @@ def from_variables( shape=var0.shape, ) + def create_variables( + self, variables: Mapping[Any, Variable] | None = None + ) -> dict[Any, Variable]: + if variables is not None: + for var in variables.values(): + # might need to update variable dimensions from the index object + # returned from TreeIndex.rename() + if var.dims != self._dims: + var.dims = self._dims + return dict(**variables) + else: + return {} + def equals( self, other: Index, *, exclude: frozenset[Hashable] | None = None ) -> bool: @@ -228,11 +241,11 @@ def sel( elif np.asarray(lbl).ndim == len(self._dims): xr_labels[name] = Variable(self._dims, lbl) else: - raise TypeError( - "invalid label type. TreeIndex only supports advanced (point-wise) indexing " - "with the following label value types:\n" + raise ValueError( + "invalid label value. TreeIndex only supports advanced (point-wise) indexing " + "with the following label value kinds:\n" "- xarray.DataArray or xarray.Variable objects\n" - "- scalar types\n" + "- scalar values\n" "- unlabelled array-like objects with the same number of dimensions " f"than the {self._coord_names} coordinate variables ({len(self._dims)})" ) diff --git a/xarray/tests/test_tree_index.py b/xarray/tests/test_tree_index.py new file mode 100644 index 00000000000..462a4eae8d6 --- /dev/null +++ b/xarray/tests/test_tree_index.py @@ -0,0 +1,172 @@ +import numpy as np +import pytest + +import xarray as xr +from xarray.indexes import TreeIndex +from xarray.tests import assert_identical + + +def test_tree_index_init() -> None: + from xarray.indexes.tree_index import ScipyKDTreeAdapter + + xx, yy = np.meshgrid([1.0, 2.0], [3.0, 4.0]) + ds = xr.Dataset(coords={"xx": (("y", "x"), xx), "yy": (("y", "x"), yy)}) + + ds_indexed1 = ds.set_xindex(("xx", "yy"), TreeIndex) + assert "xx" in ds_indexed1.xindexes + assert "yy" in ds_indexed1.xindexes + assert isinstance(ds_indexed1.xindexes["xx"], TreeIndex) + assert ds_indexed1.xindexes["xx"] is ds_indexed1.xindexes["yy"] + + ds_indexed2 = ds.set_xindex( + ("xx", "yy"), TreeIndex, tree_adapter_cls=ScipyKDTreeAdapter + ) + assert ds_indexed1.xindexes["xx"].equals(ds_indexed2.xindexes["yy"]) + + +def test_tree_index_init_errors() -> None: + xx, yy = np.meshgrid([1.0, 2.0], [3.0, 4.0]) + ds = xr.Dataset(coords={"xx": (("y", "x"), xx), "yy": (("y", "x"), yy)}) + + with pytest.raises(ValueError, match="number of variables"): + ds.set_xindex("xx", TreeIndex) + + ds2 = ds.assign_coords(yy=(("u", "v"), [[3.0, 3.0], [4.0, 4.0]])) + + with pytest.raises(ValueError, match="same dimensions"): + ds2.set_xindex(("xx", "yy"), TreeIndex) + + +def test_tree_index_sel() -> None: + xx, yy = np.meshgrid([1.0, 2.0], [3.0, 4.0]) + ds = xr.Dataset(coords={"xx": (("y", "x"), xx), "yy": (("y", "x"), yy)}).set_xindex( + ("xx", "yy"), TreeIndex + ) + + # 1-dimensiona labels + actual = ds.sel( + xx=xr.Variable("u", [1.1, 1.1, 1.1]), + yy=xr.Variable("u", [3.1, 3.1, 3.1]), + method="nearest", + ) + expected = xr.Dataset( + coords={"xx": ("u", [1.0, 1.0, 1.0]), "yy": ("u", [3.0, 3.0, 3.0])} + ) + assert_identical(actual, expected) + + # 2-dimensional labels + actual = ds.sel( + xx=xr.Variable(("u", "v"), [[1.1, 1.1, 1.1], [1.9, 1.9, 1.9]]), + yy=xr.Variable(("u", "v"), [[3.1, 3.1, 3.1], [3.9, 3.9, 3.9]]), + method="nearest", + ) + expected = xr.Dataset( + coords={ + "xx": (("u", "v"), [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), + "yy": (("u", "v"), [[3.0, 3.0, 3.0], [4.0, 4.0, 4.0]]), + }, + ) + assert_identical(actual, expected) + + # all scalar labels + actual = ds.sel(xx=1.1, yy=3.1, method="nearest") + expected = xr.Dataset(coords={"xx": 1.0, "yy": 3.0}) + assert_identical(actual, expected) + + # broadcast scalar to label shape and dimensions + actual = ds.sel(xx=1.1, yy=xr.Variable("u", [3.1, 3.1, 3.1]), method="nearest") + expected = ds.sel( + xx=xr.Variable("u", [1.1, 1.1, 1.1]), + yy=xr.Variable("u", [3.1, 3.1, 3.1]), + method="nearest", + ) + assert_identical(actual, expected) + + # unlabelled array-like labels with dimensions matching index coordinate dimensions + actual = ds.sel( + xx=[[1.1, 1.1, 1.1], [1.9, 1.9, 1.9]], + yy=[[3.1, 3.1, 3.1], [3.9, 3.9, 3.9]], + method="nearest", + ) + expected = ds.sel( + xx=xr.Variable(ds.xx.dims, [[1.1, 1.1, 1.1], [1.9, 1.9, 1.9]]), + yy=xr.Variable(ds.yy.dims, [[3.1, 3.1, 3.1], [3.9, 3.9, 3.9]]), + method="nearest", + ) + assert_identical(actual, expected) + + +def test_tree_index_sel_errors() -> None: + xx, yy = np.meshgrid([1.0, 2.0], [3.0, 4.0]) + ds = xr.Dataset(coords={"xx": (("y", "x"), xx), "yy": (("y", "x"), yy)}).set_xindex( + ("xx", "yy"), TreeIndex + ) + + with pytest.raises(ValueError, match="method='nearest'"): + ds.sel(xx=1.1, yy=3.1) + + with pytest.raises(ValueError, match="missing labels"): + ds.sel(xx=1.1, method="nearest") + + with pytest.raises(ValueError, match="invalid label value"): + # invalid array-like dimensions + ds.sel(xx=[1.1, 1.9], yy=[3.1, 3.9], method="nearest") + + with pytest.raises(ValueError, match=".*dimensions.*conflict"): + ds.sel( + xx=xr.Variable("u", [1.1, 1.1, 1.1]), + yy=xr.Variable("v", [3.1, 3.1, 3.1]), + method="nearest", + ) + + with pytest.raises(ValueError, match=".*shape.*conflict"): + ds.sel( + xx=xr.Variable("u", [1.1, 1.1, 1.1]), + yy=xr.Variable("u", [3.1, 3.1]), + method="nearest", + ) + + +def test_tree_index_equals() -> None: + xx1, yy1 = np.meshgrid([1.0, 2.0], [3.0, 4.0]) + ds1 = xr.Dataset( + coords={"xx": (("y", "x"), xx1), "yy": (("y", "x"), yy1)} + ).set_xindex(("xx", "yy"), TreeIndex) + + xx2, yy2 = np.meshgrid([1.0, 2.0], [3.0, 4.0]) + ds2 = xr.Dataset( + coords={"xx": (("y", "x"), xx2), "yy": (("y", "x"), yy2)} + ).set_xindex(("xx", "yy"), TreeIndex) + + xx3, yy3 = np.meshgrid([10.0, 20.0], [30.0, 40.0]) + ds3 = xr.Dataset( + coords={"xx": (("y", "x"), xx3), "yy": (("y", "x"), yy3)} + ).set_xindex(("xx", "yy"), TreeIndex) + + assert ds1.xindexes["xx"].equals(ds2.xindexes["xx"]) + assert not ds1.xindexes["xx"].equals(ds3.xindexes["xx"]) + + +def test_tree_index_rename() -> None: + xx, yy = np.meshgrid([1.0, 2.0], [3.0, 4.0]) + ds = xr.Dataset(coords={"xx": (("y", "x"), xx), "yy": (("y", "x"), yy)}).set_xindex( + ("xx", "yy"), TreeIndex + ) + + ds_renamed = ds.rename_dims(y="u").rename_vars(yy="uu") + assert "uu" in ds_renamed.xindexes + assert isinstance(ds_renamed.xindexes["uu"], TreeIndex) + assert ds_renamed.xindexes["xx"] is ds_renamed.xindexes["uu"] + + # check via sel() that uses coord names and dims under the hood + actual = ds_renamed.sel( + xx=[[1.1, 1.1, 1.1], [1.9, 1.9, 1.9]], + uu=[[3.1, 3.1, 3.1], [3.9, 3.9, 3.9]], + method="nearest", + ) + expected = ds_renamed.sel( + xx=xr.Variable(ds_renamed.xx.dims, [[1.1, 1.1, 1.1], [1.9, 1.9, 1.9]]), + uu=xr.Variable(ds_renamed.uu.dims, [[3.1, 3.1, 3.1], [3.9, 3.9, 3.9]]), + method="nearest", + ) + assert_identical(actual, expected) From 971944de267f85899850c67bcf4ba03707c817b0 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 2 Jul 2025 11:33:44 +0200 Subject: [PATCH 07/17] doc: update API reference --- doc/api.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/api.rst b/doc/api.rst index df6e87c0cf8..d943d84c5ed 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -1577,6 +1577,7 @@ Custom Indexes CFTimeIndex indexes.RangeIndex indexes.CoordinateTransformIndex + indexes.TreeIndex Creating custom indexes ----------------------- From f8854eb6f80b9d23cfb5269100b973e9f368a8c4 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 2 Jul 2025 11:55:27 +0200 Subject: [PATCH 08/17] add docstrings examples --- xarray/indexes/tree_index.py | 85 +++++++++++++++++++++++++++++++++++- 1 file changed, 84 insertions(+), 1 deletion(-) diff --git a/xarray/indexes/tree_index.py b/xarray/indexes/tree_index.py index 67fd59b98f8..21b2f294a8a 100644 --- a/xarray/indexes/tree_index.py +++ b/xarray/indexes/tree_index.py @@ -113,7 +113,90 @@ class TreeIndex(Index, Generic[T_TreeAdapter]): Examples -------- - TODO + An example using a dataset with 2-dimensional coordinates representing + irregularly spaced data points. + + >>> xx = [[1.0, 2.0], [3.0, 0.0]] + >>> yy = [[11.0, 21.0], [29.0, 9.0]] + >>> ds = xr.Dataset(coords={"xx": (("y", "x"), xx), "yy": (("y", "x"), yy)}) + >>> ds + Size: 64B + Dimensions: (y: 2, x: 2) + Coordinates: + xx (y, x) float64 32B 1.0 2.0 3.0 0.0 + yy (y, x) float64 32B 11.0 21.0 29.0 9.0 + Dimensions without coordinates: y, x + Data variables: + *empty* + + Create a TreeIndex from the "xx" and "yy" coordinate variables: + + >>> ds = ds.set_xindex(("xx", "yy"), xr.indexes.TreeIndex) + >>> ds + Size: 64B + Dimensions: (y: 2, x: 2) + Coordinates: + * xx (y, x) float64 32B 1.0 2.0 3.0 0.0 + * yy (y, x) float64 32B 11.0 21.0 29.0 9.0 + Dimensions without coordinates: y, x + Data variables: + *empty* + Indexes: + ┌ xx TreeIndex + └ yy + + Point-wise (nearest-neighbor) data selection using Xarray's advanced + indexing, i.e., using arbitrary dimension(s) for the Variable objects passed + as labels: + + >>> ds.sel( + ... xx=xr.Variable("points", [1.9, 0.1]), + ... yy=xr.Variable("points", [13.0, 8.0]), + ... method="nearest", + ... ) + Size: 32B + Dimensions: (points: 2) + Coordinates: + xx (points) float64 16B 1.0 0.0 + yy (points) float64 16B 11.0 9.0 + Dimensions without coordinates: points + Data variables: + *empty* + + Data selection with scalar labels: + + >>> ds.sel(xx=1.9, yy=13.0, method="nearest") + Size: 16B + Dimensions: () + Coordinates: + xx float64 8B 1.0 + yy float64 8B 11.0 + Data variables: + *empty* + + Data selection with broadcasted scalar labels: + + >>> ds.sel(xx=1.9, yy=xr.Variable("points", [13.0, 8.0]), method="nearest") + Size: 32B + Dimensions: (points: 2) + Coordinates: + xx (points) float64 16B 1.0 0.0 + yy (points) float64 16B 11.0 9.0 + Dimensions without coordinates: points + Data variables: + *empty* + + Data selection with array-like labels (implicit dimensions): + + >>> ds.sel(xx=[[1.9], [0.1]], yy=[[13.0], [8.0]], method="nearest") + Size: 32B + Dimensions: (y: 2, x: 1) + Coordinates: + xx (y, x) float64 16B 1.0 0.0 + yy (y, x) float64 16B 11.0 9.0 + Dimensions without coordinates: y, x + Data variables: + *empty* """ From 7ddc9c7eeb8ff7037f88971754f8982c46a52005 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 2 Jul 2025 12:00:02 +0200 Subject: [PATCH 09/17] doc: update what's new --- doc/whats-new.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 31d3a1077e7..c984a055f03 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -14,6 +14,9 @@ New Features ~~~~~~~~~~~~ - Expose :py:class:`~xarray.indexes.RangeIndex`, and :py:class:`~xarray.indexes.CoordinateTransformIndex` as public api under the ``xarray.indexes`` namespace. By `Deepak Cherian `_. +- New :py:class:`xarray.indexes.TreeIndex`, which by default uses :py:class:`scipy.spatial.KDTree` under the hood for + the selection of irregular, n-dimensional data (:pull:`10478`). + By `Benoit Bovy `_. Breaking changes From 95f79b02511c15252b4202af9a37adbf22fe52f2 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 2 Jul 2025 12:02:14 +0200 Subject: [PATCH 10/17] typo --- xarray/tests/test_tree_index.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_tree_index.py b/xarray/tests/test_tree_index.py index 462a4eae8d6..306d346f934 100644 --- a/xarray/tests/test_tree_index.py +++ b/xarray/tests/test_tree_index.py @@ -43,7 +43,7 @@ def test_tree_index_sel() -> None: ("xx", "yy"), TreeIndex ) - # 1-dimensiona labels + # 1-dimensional labels actual = ds.sel( xx=xr.Variable("u", [1.1, 1.1, 1.1]), yy=xr.Variable("u", [3.1, 3.1, 3.1]), From 0d8d6ac3d2a29bd38b2c518e0b056116610f09bc Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 2 Jul 2025 12:04:42 +0200 Subject: [PATCH 11/17] skip tests if scipy is not available --- xarray/tests/test_tree_index.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/tests/test_tree_index.py b/xarray/tests/test_tree_index.py index 306d346f934..4948de01349 100644 --- a/xarray/tests/test_tree_index.py +++ b/xarray/tests/test_tree_index.py @@ -5,6 +5,8 @@ from xarray.indexes import TreeIndex from xarray.tests import assert_identical +pytest.importorskip("scipy") + def test_tree_index_init() -> None: from xarray.indexes.tree_index import ScipyKDTreeAdapter From 99f941f174191959762fe176d753384c213e7c7f Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 2 Jul 2025 12:28:35 +0200 Subject: [PATCH 12/17] docstrings: add note about private __init__ --- xarray/indexes/tree_index.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/xarray/indexes/tree_index.py b/xarray/indexes/tree_index.py index 21b2f294a8a..c1a6a30e10c 100644 --- a/xarray/indexes/tree_index.py +++ b/xarray/indexes/tree_index.py @@ -111,6 +111,11 @@ class TreeIndex(Index, Generic[T_TreeAdapter]): By default, this index relies on :py:class:`scipy.spatial.KDTree` for fast lookup. + Do not use :py:meth:`~xarray.indexes.TreeIndex.__init__` directly. Instead + use :py:meth:`~xarray.Dataset.set_xindex` or + :py:meth:`~xarray.DataArray.set_xindex` to create and set the index from + existing coordinates (see the example below). + Examples -------- An example using a dataset with 2-dimensional coordinates representing From e867c7cf5b70edc2edec7c302e334a9ccbb1544d Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 2 Jul 2025 15:38:29 +0200 Subject: [PATCH 13/17] add TreeIndex._repr_inline_ Display the name of the wrapper index adapter type. --- xarray/indexes/tree_index.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/xarray/indexes/tree_index.py b/xarray/indexes/tree_index.py index c1a6a30e10c..1061306a69e 100644 --- a/xarray/indexes/tree_index.py +++ b/xarray/indexes/tree_index.py @@ -393,3 +393,7 @@ def rename( dims=new_dims, shape=self._shape, ) + + def _repr_inline_(self, max_width: int) -> str: + tree_obj_type = self._tree_obj.__class__.__name__ + return f"{self.__class__.__name__} ({tree_obj_type})" From 74affdcbd8359b7e1b33d35fb58656623d3922f9 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 2 Jul 2025 15:46:55 +0200 Subject: [PATCH 14/17] rename TreeIndex -> NDPointIndex --- doc/api.rst | 2 +- doc/whats-new.rst | 2 +- xarray/indexes/__init__.py | 4 +-- .../{tree_index.py => nd_point_index.py} | 24 ++++++++-------- ...t_tree_index.py => test_nd_point_index.py} | 28 +++++++++---------- 5 files changed, 31 insertions(+), 29 deletions(-) rename xarray/indexes/{tree_index.py => nd_point_index.py} (94%) rename xarray/tests/{test_tree_index.py => test_nd_point_index.py} (88%) diff --git a/doc/api.rst b/doc/api.rst index d943d84c5ed..0d722a4bec9 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -1577,7 +1577,7 @@ Custom Indexes CFTimeIndex indexes.RangeIndex indexes.CoordinateTransformIndex - indexes.TreeIndex + indexes.NDPointIndex Creating custom indexes ----------------------- diff --git a/doc/whats-new.rst b/doc/whats-new.rst index c984a055f03..2cb18eb283b 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -14,7 +14,7 @@ New Features ~~~~~~~~~~~~ - Expose :py:class:`~xarray.indexes.RangeIndex`, and :py:class:`~xarray.indexes.CoordinateTransformIndex` as public api under the ``xarray.indexes`` namespace. By `Deepak Cherian `_. -- New :py:class:`xarray.indexes.TreeIndex`, which by default uses :py:class:`scipy.spatial.KDTree` under the hood for +- New :py:class:`xarray.indexes.NDPointIndex`, which by default uses :py:class:`scipy.spatial.KDTree` under the hood for the selection of irregular, n-dimensional data (:pull:`10478`). By `Benoit Bovy `_. diff --git a/xarray/indexes/__init__.py b/xarray/indexes/__init__.py index 972dce764ff..2cba69607f3 100644 --- a/xarray/indexes/__init__.py +++ b/xarray/indexes/__init__.py @@ -10,15 +10,15 @@ PandasIndex, PandasMultiIndex, ) +from xarray.indexes.nd_point_index import NDPointIndex from xarray.indexes.range_index import RangeIndex -from xarray.indexes.tree_index import TreeIndex __all__ = [ "CoordinateTransform", "CoordinateTransformIndex", "Index", + "NDPointIndex", "PandasIndex", "PandasMultiIndex", "RangeIndex", - "TreeIndex", ] diff --git a/xarray/indexes/tree_index.py b/xarray/indexes/nd_point_index.py similarity index 94% rename from xarray/indexes/tree_index.py rename to xarray/indexes/nd_point_index.py index 1061306a69e..6c16981a1a7 100644 --- a/xarray/indexes/tree_index.py +++ b/xarray/indexes/nd_point_index.py @@ -21,7 +21,7 @@ class TreeAdapter(abc.ABC): """Lightweight adapter abstract class for plugging in 3rd-party structures like :py:class:`scipy.spatial.KDTree` or :py:class:`sklearn.neighbors.KDTree` - into :py:class:`~xarray.indexes.TreeIndex`. + into :py:class:`~xarray.indexes.NDPointIndex`. """ @@ -69,7 +69,7 @@ def equals(self, other: Self) -> bool: class ScipyKDTreeAdapter(TreeAdapter): - """:py:class:`scipy.spatial.KDTree` adapter for :py:class:`~xarray.indexes.TreeIndex`.""" + """:py:class:`scipy.spatial.KDTree` adapter for :py:class:`~xarray.indexes.NDPointIndex`.""" _kdtree: KDTree @@ -97,7 +97,7 @@ def get_points(coords: Iterable[Variable | Any]) -> np.ndarray: T_TreeAdapter = TypeVar("T_TreeAdapter", bound=TreeAdapter) -class TreeIndex(Index, Generic[T_TreeAdapter]): +class NDPointIndex(Index, Generic[T_TreeAdapter]): """Xarray index for irregular, n-dimensional data. This index may be associated with a set of coordinate variables representing @@ -111,7 +111,7 @@ class TreeIndex(Index, Generic[T_TreeAdapter]): By default, this index relies on :py:class:`scipy.spatial.KDTree` for fast lookup. - Do not use :py:meth:`~xarray.indexes.TreeIndex.__init__` directly. Instead + Do not use :py:meth:`~xarray.indexes.NDPointIndex.__init__` directly. Instead use :py:meth:`~xarray.Dataset.set_xindex` or :py:meth:`~xarray.DataArray.set_xindex` to create and set the index from existing coordinates (see the example below). @@ -134,9 +134,9 @@ class TreeIndex(Index, Generic[T_TreeAdapter]): Data variables: *empty* - Create a TreeIndex from the "xx" and "yy" coordinate variables: + Create a NDPointIndex from the "xx" and "yy" coordinate variables: - >>> ds = ds.set_xindex(("xx", "yy"), xr.indexes.TreeIndex) + >>> ds = ds.set_xindex(("xx", "yy"), xr.indexes.NDPointIndex) >>> ds Size: 64B Dimensions: (y: 2, x: 2) @@ -147,7 +147,7 @@ class TreeIndex(Index, Generic[T_TreeAdapter]): Data variables: *empty* Indexes: - ┌ xx TreeIndex + ┌ xx NDPointIndex └ yy Point-wise (nearest-neighbor) data selection using Xarray's advanced @@ -268,7 +268,7 @@ def create_variables( if variables is not None: for var in variables.values(): # might need to update variable dimensions from the index object - # returned from TreeIndex.rename() + # returned from NDPointIndex.rename() if var.dims != self._dims: var.dims = self._dims return dict(**variables) @@ -278,7 +278,7 @@ def create_variables( def equals( self, other: Index, *, exclude: frozenset[Hashable] | None = None ) -> bool: - if not isinstance(other, TreeIndex): + if not isinstance(other, NDPointIndex): return False if type(self._tree_obj) is not type(other._tree_obj): return False @@ -311,7 +311,9 @@ def sel( self, labels: dict[Any, Any], method=None, tolerance=None ) -> IndexSelResult: if method != "nearest": - raise ValueError("TreeIndex only supports selection with method='nearest'") + raise ValueError( + "NDPointIndex only supports selection with method='nearest'" + ) missing_labels = set(self._coord_names) - set(labels) if missing_labels: @@ -330,7 +332,7 @@ def sel( xr_labels[name] = Variable(self._dims, lbl) else: raise ValueError( - "invalid label value. TreeIndex only supports advanced (point-wise) indexing " + "invalid label value. NDPointIndex only supports advanced (point-wise) indexing " "with the following label value kinds:\n" "- xarray.DataArray or xarray.Variable objects\n" "- scalar values\n" diff --git a/xarray/tests/test_tree_index.py b/xarray/tests/test_nd_point_index.py similarity index 88% rename from xarray/tests/test_tree_index.py rename to xarray/tests/test_nd_point_index.py index 4948de01349..e322e21133c 100644 --- a/xarray/tests/test_tree_index.py +++ b/xarray/tests/test_nd_point_index.py @@ -2,26 +2,26 @@ import pytest import xarray as xr -from xarray.indexes import TreeIndex +from xarray.indexes import NDPointIndex from xarray.tests import assert_identical pytest.importorskip("scipy") def test_tree_index_init() -> None: - from xarray.indexes.tree_index import ScipyKDTreeAdapter + from xarray.indexes.nd_point_index import ScipyKDTreeAdapter xx, yy = np.meshgrid([1.0, 2.0], [3.0, 4.0]) ds = xr.Dataset(coords={"xx": (("y", "x"), xx), "yy": (("y", "x"), yy)}) - ds_indexed1 = ds.set_xindex(("xx", "yy"), TreeIndex) + ds_indexed1 = ds.set_xindex(("xx", "yy"), NDPointIndex) assert "xx" in ds_indexed1.xindexes assert "yy" in ds_indexed1.xindexes - assert isinstance(ds_indexed1.xindexes["xx"], TreeIndex) + assert isinstance(ds_indexed1.xindexes["xx"], NDPointIndex) assert ds_indexed1.xindexes["xx"] is ds_indexed1.xindexes["yy"] ds_indexed2 = ds.set_xindex( - ("xx", "yy"), TreeIndex, tree_adapter_cls=ScipyKDTreeAdapter + ("xx", "yy"), NDPointIndex, tree_adapter_cls=ScipyKDTreeAdapter ) assert ds_indexed1.xindexes["xx"].equals(ds_indexed2.xindexes["yy"]) @@ -31,18 +31,18 @@ def test_tree_index_init_errors() -> None: ds = xr.Dataset(coords={"xx": (("y", "x"), xx), "yy": (("y", "x"), yy)}) with pytest.raises(ValueError, match="number of variables"): - ds.set_xindex("xx", TreeIndex) + ds.set_xindex("xx", NDPointIndex) ds2 = ds.assign_coords(yy=(("u", "v"), [[3.0, 3.0], [4.0, 4.0]])) with pytest.raises(ValueError, match="same dimensions"): - ds2.set_xindex(("xx", "yy"), TreeIndex) + ds2.set_xindex(("xx", "yy"), NDPointIndex) def test_tree_index_sel() -> None: xx, yy = np.meshgrid([1.0, 2.0], [3.0, 4.0]) ds = xr.Dataset(coords={"xx": (("y", "x"), xx), "yy": (("y", "x"), yy)}).set_xindex( - ("xx", "yy"), TreeIndex + ("xx", "yy"), NDPointIndex ) # 1-dimensional labels @@ -101,7 +101,7 @@ def test_tree_index_sel() -> None: def test_tree_index_sel_errors() -> None: xx, yy = np.meshgrid([1.0, 2.0], [3.0, 4.0]) ds = xr.Dataset(coords={"xx": (("y", "x"), xx), "yy": (("y", "x"), yy)}).set_xindex( - ("xx", "yy"), TreeIndex + ("xx", "yy"), NDPointIndex ) with pytest.raises(ValueError, match="method='nearest'"): @@ -133,17 +133,17 @@ def test_tree_index_equals() -> None: xx1, yy1 = np.meshgrid([1.0, 2.0], [3.0, 4.0]) ds1 = xr.Dataset( coords={"xx": (("y", "x"), xx1), "yy": (("y", "x"), yy1)} - ).set_xindex(("xx", "yy"), TreeIndex) + ).set_xindex(("xx", "yy"), NDPointIndex) xx2, yy2 = np.meshgrid([1.0, 2.0], [3.0, 4.0]) ds2 = xr.Dataset( coords={"xx": (("y", "x"), xx2), "yy": (("y", "x"), yy2)} - ).set_xindex(("xx", "yy"), TreeIndex) + ).set_xindex(("xx", "yy"), NDPointIndex) xx3, yy3 = np.meshgrid([10.0, 20.0], [30.0, 40.0]) ds3 = xr.Dataset( coords={"xx": (("y", "x"), xx3), "yy": (("y", "x"), yy3)} - ).set_xindex(("xx", "yy"), TreeIndex) + ).set_xindex(("xx", "yy"), NDPointIndex) assert ds1.xindexes["xx"].equals(ds2.xindexes["xx"]) assert not ds1.xindexes["xx"].equals(ds3.xindexes["xx"]) @@ -152,12 +152,12 @@ def test_tree_index_equals() -> None: def test_tree_index_rename() -> None: xx, yy = np.meshgrid([1.0, 2.0], [3.0, 4.0]) ds = xr.Dataset(coords={"xx": (("y", "x"), xx), "yy": (("y", "x"), yy)}).set_xindex( - ("xx", "yy"), TreeIndex + ("xx", "yy"), NDPointIndex ) ds_renamed = ds.rename_dims(y="u").rename_vars(yy="uu") assert "uu" in ds_renamed.xindexes - assert isinstance(ds_renamed.xindexes["uu"], TreeIndex) + assert isinstance(ds_renamed.xindexes["uu"], NDPointIndex) assert ds_renamed.xindexes["xx"] is ds_renamed.xindexes["uu"] # check via sel() that uses coord names and dims under the hood From 37152804612a8cc2df7fa517d5ac02ff64778f36 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 2 Jul 2025 15:51:22 +0200 Subject: [PATCH 15/17] fix doctests --- xarray/indexes/nd_point_index.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/indexes/nd_point_index.py b/xarray/indexes/nd_point_index.py index 6c16981a1a7..11c68983a01 100644 --- a/xarray/indexes/nd_point_index.py +++ b/xarray/indexes/nd_point_index.py @@ -147,7 +147,7 @@ class NDPointIndex(Index, Generic[T_TreeAdapter]): Data variables: *empty* Indexes: - ┌ xx NDPointIndex + ┌ xx NDPointIndex (ScipyKDTreeAdapter) └ yy Point-wise (nearest-neighbor) data selection using Xarray's advanced From 4bf48ab511f911d6522bc4d3c500f47c007ee547 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 3 Jul 2025 08:41:25 +0200 Subject: [PATCH 16/17] a few tweaks --- xarray/indexes/nd_point_index.py | 23 ++++++++++++----------- xarray/tests/test_nd_point_index.py | 5 +++-- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/xarray/indexes/nd_point_index.py b/xarray/indexes/nd_point_index.py index 11c68983a01..da440fd56dd 100644 --- a/xarray/indexes/nd_point_index.py +++ b/xarray/indexes/nd_point_index.py @@ -101,9 +101,10 @@ class NDPointIndex(Index, Generic[T_TreeAdapter]): """Xarray index for irregular, n-dimensional data. This index may be associated with a set of coordinate variables representing - the location of the data points in an n-dimensional space. All coordinates - must have the same shape and dimensions. The number of associated coordinate - variables must correspond to the number of dimensions of the space. + the arbitrary location of data points in an n-dimensional space. All + coordinates must have the same shape and dimensions. The number of + associated coordinate variables must correspond to the number of dimensions + of the space. This index supports label-based selection (nearest neighbor lookup). It also has limited support for alignment. @@ -112,14 +113,13 @@ class NDPointIndex(Index, Generic[T_TreeAdapter]): lookup. Do not use :py:meth:`~xarray.indexes.NDPointIndex.__init__` directly. Instead - use :py:meth:`~xarray.Dataset.set_xindex` or - :py:meth:`~xarray.DataArray.set_xindex` to create and set the index from + use :py:meth:`xarray.Dataset.set_xindex` or + :py:meth:`xarray.DataArray.set_xindex` to create and set the index from existing coordinates (see the example below). Examples -------- - An example using a dataset with 2-dimensional coordinates representing - irregularly spaced data points. + An example using a dataset with 2-dimensional coordinates. >>> xx = [[1.0, 2.0], [3.0, 0.0]] >>> yy = [[11.0, 21.0], [29.0, 9.0]] @@ -134,7 +134,7 @@ class NDPointIndex(Index, Generic[T_TreeAdapter]): Data variables: *empty* - Create a NDPointIndex from the "xx" and "yy" coordinate variables: + Creation of a NDPointIndex from the "xx" and "yy" coordinate variables: >>> ds = ds.set_xindex(("xx", "yy"), xr.indexes.NDPointIndex) >>> ds @@ -244,7 +244,8 @@ def from_variables( if len(variables) != len(var0.dims): raise ValueError( - f"the number of variables {len(variables)} doesn't match the number of dimensions {len(var0.dims)}" + f"the number of variables {len(variables)} doesn't match " + f"the number of dimensions {len(var0.dims)}" ) opts = dict(options) @@ -267,8 +268,8 @@ def create_variables( ) -> dict[Any, Variable]: if variables is not None: for var in variables.values(): - # might need to update variable dimensions from the index object - # returned from NDPointIndex.rename() + # maybe re-sync variable dimensions with the index object + # returned by NDPointIndex.rename() if var.dims != self._dims: var.dims = self._dims return dict(**variables) diff --git a/xarray/tests/test_nd_point_index.py b/xarray/tests/test_nd_point_index.py index e322e21133c..cbfd3504bba 100644 --- a/xarray/tests/test_nd_point_index.py +++ b/xarray/tests/test_nd_point_index.py @@ -84,7 +84,7 @@ def test_tree_index_sel() -> None: ) assert_identical(actual, expected) - # unlabelled array-like labels with dimensions matching index coordinate dimensions + # implicit dimension array-like labels actual = ds.sel( xx=[[1.1, 1.1, 1.1], [1.9, 1.9, 1.9]], yy=[[3.1, 3.1, 3.1], [3.9, 3.9, 3.9]], @@ -160,7 +160,8 @@ def test_tree_index_rename() -> None: assert isinstance(ds_renamed.xindexes["uu"], NDPointIndex) assert ds_renamed.xindexes["xx"] is ds_renamed.xindexes["uu"] - # check via sel() that uses coord names and dims under the hood + # test via sel() with implicit dimension array-like labels, which relies on + # NDPointIndex._coord_names and NDPointIndex._dims internal attrs actual = ds_renamed.sel( xx=[[1.1, 1.1, 1.1], [1.9, 1.9, 1.9]], uu=[[3.1, 3.1, 3.1], [3.9, 3.9, 3.9]], From f41816781048f77ef6c38b536f5b8e9966176f48 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Fri, 4 Jul 2025 09:41:30 +0200 Subject: [PATCH 17/17] sel: simply broadcast labels against each other --- xarray/indexes/nd_point_index.py | 64 ++++++++++++++--------------- xarray/tests/test_nd_point_index.py | 24 +++++++---- 2 files changed, 46 insertions(+), 42 deletions(-) diff --git a/xarray/indexes/nd_point_index.py b/xarray/indexes/nd_point_index.py index da440fd56dd..283b8d7d676 100644 --- a/xarray/indexes/nd_point_index.py +++ b/xarray/indexes/nd_point_index.py @@ -11,6 +11,7 @@ from xarray.core.indexing import IndexSelResult from xarray.core.utils import is_scalar from xarray.core.variable import Variable +from xarray.structure.alignment import broadcast if TYPE_CHECKING: from scipy.spatial import KDTree @@ -179,7 +180,7 @@ class NDPointIndex(Index, Generic[T_TreeAdapter]): Data variables: *empty* - Data selection with broadcasted scalar labels: + Data selection with broadcasting the input labels: >>> ds.sel(xx=1.9, yy=xr.Variable("points", [13.0, 8.0]), method="nearest") Size: 32B @@ -191,6 +192,21 @@ class NDPointIndex(Index, Generic[T_TreeAdapter]): Data variables: *empty* + >>> da = xr.DataArray( + ... [[45.1, 53.3], [65.4, 78.2]], + ... coords={"u": [1.9, 0.1], "v": [13.0, 8.0]}, + ... dims=("u", "v"), + ... ) + >>> ds.sel(xx=da.u, yy=da.v, method="nearest") + Size: 64B + Dimensions: (u: 2, v: 2) + Coordinates: + xx (u, v) float64 32B 1.0 0.0 1.0 0.0 + yy (u, v) float64 32B 11.0 9.0 11.0 9.0 + Dimensions without coordinates: u, v + Data variables: + *empty* + Data selection with array-like labels (implicit dimensions): >>> ds.sel(xx=[[1.9], [0.1]], yy=[[13.0], [8.0]], method="nearest") @@ -321,16 +337,18 @@ def sel( missing_labels_str = ",".join([f"{name}" for name in missing_labels]) raise ValueError(f"missing labels for coordinate(s): {missing_labels_str}.") - # maybe convert labels into xarray Variable objects - xr_labels: dict[Any, Variable | DataArray] = {} + # maybe convert labels into xarray DataArray objects + xr_labels: dict[Any, DataArray] = {} for name, lbl in labels.items(): - if isinstance(lbl, Variable | DataArray): + if isinstance(lbl, DataArray): xr_labels[name] = lbl + elif isinstance(lbl, Variable): + xr_labels[name] = DataArray(lbl) elif is_scalar(lbl): - xr_labels[name] = Variable((), lbl) + xr_labels[name] = DataArray(lbl, dims=()) elif np.asarray(lbl).ndim == len(self._dims): - xr_labels[name] = Variable(self._dims, lbl) + xr_labels[name] = DataArray(lbl, dims=self._dims) else: raise ValueError( "invalid label value. NDPointIndex only supports advanced (point-wise) indexing " @@ -341,35 +359,13 @@ def sel( f"than the {self._coord_names} coordinate variables ({len(self._dims)})" ) - # determine xarray label shape and dimensions - label_dims: tuple[Hashable, ...] = () - label_shape: tuple[int, ...] = () - - for name, lbl in xr_labels.items(): - if lbl.ndim > 0: - if label_dims and lbl.dims != label_dims: - raise ValueError( - f"label {name} has dimensions {lbl.dims} that conflict with " - f"other label dimensions {label_dims}" - ) - else: - label_dims = lbl.dims - if label_shape and lbl.shape != label_shape: - raise ValueError( - f"label {name} has shape {lbl.shape} that conflicts with " - f"other label shape {label_shape}" - ) - else: - label_shape = lbl.shape - - # maybe broadcast scalar xarray labels - if label_dims: - for name, lbl in xr_labels.items(): - if not lbl.dims: - xr_labels[name] = Variable( - label_dims, np.broadcast_to(lbl.values, label_shape) - ) + # broadcast xarray labels against one another and determine labels shape and dimensions + broadcasted = broadcast(*xr_labels.values()) + label_dims = broadcasted[0].dims + label_shape = broadcasted[0].shape + xr_labels = dict(zip(xr_labels, broadcasted, strict=True)) + # get and return dimension indexers points = get_points(xr_labels[name] for name in self._coord_names) _, indices = self._tree_obj.query(points) diff --git a/xarray/tests/test_nd_point_index.py b/xarray/tests/test_nd_point_index.py index cbfd3504bba..eb497aa263f 100644 --- a/xarray/tests/test_nd_point_index.py +++ b/xarray/tests/test_nd_point_index.py @@ -84,6 +84,20 @@ def test_tree_index_sel() -> None: ) assert_identical(actual, expected) + # broadcast orthogonal 1-dimensional labels + actual = ds.sel( + xx=xr.Variable("u", [1.1, 1.1]), + yy=xr.Variable("v", [3.1, 3.1]), + method="nearest", + ) + expected = xr.Dataset( + coords={ + "xx": (("u", "v"), [[1.0, 1.0], [1.0, 1.0]]), + "yy": (("u", "v"), [[3.0, 3.0], [3.0, 3.0]]), + }, + ) + assert_identical(actual, expected) + # implicit dimension array-like labels actual = ds.sel( xx=[[1.1, 1.1, 1.1], [1.9, 1.9, 1.9]], @@ -114,14 +128,8 @@ def test_tree_index_sel_errors() -> None: # invalid array-like dimensions ds.sel(xx=[1.1, 1.9], yy=[3.1, 3.9], method="nearest") - with pytest.raises(ValueError, match=".*dimensions.*conflict"): - ds.sel( - xx=xr.Variable("u", [1.1, 1.1, 1.1]), - yy=xr.Variable("v", [3.1, 3.1, 3.1]), - method="nearest", - ) - - with pytest.raises(ValueError, match=".*shape.*conflict"): + # error while trying to broadcast labels + with pytest.raises(xr.AlignmentError, match=".*conflicting dimension sizes"): ds.sel( xx=xr.Variable("u", [1.1, 1.1, 1.1]), yy=xr.Variable("u", [3.1, 3.1]),