Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ dependencies = [
"pandas>=2",
"numpy>=2",
"xproj>=0.2.0",
"xarray>=2025",
"xarray @ git+https://github.com/dcherian/xarray.git@fix-coord-transform-indexing",
]
dynamic=["version"]

Expand Down Expand Up @@ -75,6 +75,9 @@ docs = [
[tool.hatch]
version.source = "vcs"

[tool.hatch.metadata]
allow-direct-references = true

[tool.hatch.build]
hooks.vcs.version-file = "src/rasterix/_version.py"

Expand Down
11 changes: 7 additions & 4 deletions src/rasterix/raster_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,17 +310,20 @@ def isel( # type: ignore[override]
def sel(self, labels, method=None, tolerance=None):
coord_name = self.axis_transform.coord_name
label = labels[coord_name]

transform = self.axis_transform
if isinstance(label, slice):
if label.start is None:
label = slice(0, label.stop, label.step)
label = slice(
label.start or transform.forward({coord_name: 0})[coord_name],
label.stop or transform.forward({coord_name: transform.size})[coord_name],
label.step,
)
if label.step is None:
# continuous interval slice indexing (preserves the index)
pos = self.transform.reverse({coord_name: np.array([label.start, label.stop])})
# np.round rounds to even, this way we round upwards
pos = np.floor(pos[self.dim] + 0.5).astype("int")
new_start = max(pos[0], 0)
new_stop = min(pos[1], self.axis_transform.size)
new_stop = min(pos[1] + 1, self.axis_transform.size)
return IndexSelResult({self.dim: slice(new_start, new_stop)})
else:
# otherwise convert to basic (array) indexing
Expand Down
165 changes: 165 additions & 0 deletions src/rasterix/strategies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
"""Hypothesis strategies for generating label-based indexers."""

from collections.abc import Hashable
from typing import Any

import numpy as np
import pandas as pd
import xarray as xr
from hypothesis import note
from hypothesis import strategies as st
from xarray.core.indexes import Indexes
from xarray.testing.strategies import (
basic_indexers,
outer_array_indexers,
vectorized_indexers,
)


def pos_to_label_indexer(idx: pd.Index, idxr: int | slice | np.ndarray, *, use_scalar: bool = True) -> Any:
"""Convert a positional indexer to a label-based indexer.

Parameters
----------
idx : pd.Index
The pandas Index to use for label lookup.
idxr : int | slice | np.ndarray
The positional indexer (integer, slice, or array of integers).
use_scalar : bool, optional
If True, attempt to convert scalar values to Python scalars. Default is True.

Returns
-------
Any
The label-based indexer (scalar, slice, or array of labels).
"""
if isinstance(idxr, slice):
return slice(
None if idxr.start is None else idx[idxr.start],
# FIXME: This will never go past the label range
None if idxr.stop is None else idx[min(idxr.stop, idx.size - 1)],
)
elif isinstance(idxr, np.ndarray):
# Convert array of position indices to array of label values
return idx[idxr].values
else:
val = idx[idxr]
if use_scalar:
try:
# pass python scalars occasionally
val = val.item()
except Exception:
note(f"casting {val!r} to item() failed")
pass
return val


@st.composite
def basic_label_indexers(draw, /, *, indexes: Indexes) -> dict[Hashable, float | slice]:
"""Generate label-based indexers by converting position indexers to labels.

This works in label space by using the coordinate Index values.

Parameters
----------
draw : callable
The Hypothesis draw function (automatically provided by @st.composite).
indexes : Indexes
Dictionary mapping dimension names to their associated indexes

Returns
-------
dict[Hashable, float | slice]
Label-based indexers as a dict with keys from sizes.keys().
Values are either float (for scalar labels) or slice (for label ranges).
"""
idxs = indexes.get_unique()
assert all(isinstance(idx, xr.indexes.PandasIndex) for idx in idxs)

# FIXME: this should be indexes.sizes!
sizes = indexes.dims

pos_indexer = draw(basic_indexers(sizes=sizes))
pdindexes = indexes.to_pandas_indexes()

label_indexer = {
dim: pos_to_label_indexer(pdindexes[dim], idx, use_scalar=draw(st.booleans()))
for dim, idx in pos_indexer.items()
}
return label_indexer


@st.composite
def outer_array_label_indexers(draw, /, *, indexes: Indexes) -> dict[Hashable, np.ndarray]:
"""Generate label-based outer array indexers by converting position indexers to labels.

This works in label space by using the coordinate Index values.

Parameters
----------
draw : callable
The Hypothesis draw function (automatically provided by @st.composite).
indexes : Indexes
Dictionary mapping dimension names to their associated indexes

Returns
-------
dict[Hashable, np.ndarray]
Label-based indexers as a dict with keys from indexes.
Values are numpy arrays of label values for each dimension.
"""
idxs = indexes.get_unique()
assert all(isinstance(idx, xr.indexes.PandasIndex) for idx in idxs)

# FIXME: this should be indexes.sizes!
sizes = indexes.dims

pos_indexer = draw(outer_array_indexers(sizes=sizes))
pdindexes = indexes.to_pandas_indexes()

label_indexer = {
dim: pos_to_label_indexer(pdindexes[dim], idx, use_scalar=False) for dim, idx in pos_indexer.items()
}
return label_indexer


@st.composite
def vectorized_label_indexers(draw, /, *, indexes: Indexes, **kwargs) -> dict[Hashable, xr.DataArray]:
"""Generate label-based vectorized indexers by converting position indexers to labels.

This works in label space by using the coordinate Index values.

Parameters
----------
draw : callable
The Hypothesis draw function (automatically provided by @st.composite).
indexes : Indexes
Dictionary mapping dimension names to their associated indexes
**kwargs : dict
Additional keyword arguments to pass to vectorized_indexers

Returns
-------
dict[Hashable, xr.DataArray]
Label-based indexers as a dict with keys from indexes.
Values are DataArrays of label values for each dimension.
"""
idxs = indexes.get_unique()
assert all(isinstance(idx, xr.indexes.PandasIndex) for idx in idxs)

# FIXME: this should be indexes.sizes!
sizes = indexes.dims

pos_indexer = draw(vectorized_indexers(sizes=sizes, **kwargs))
pdindexes = indexes.to_pandas_indexes()

label_indexer = {}
for dim, idx_array in pos_indexer.items():
# Convert each position in the array to its corresponding label
# Flatten, index, then reshape back to original shape
flat_indices = idx_array.values.ravel()
flat_labels = pdindexes[dim][flat_indices].values
label_values = flat_labels.reshape(idx_array.shape)
label_indexer[dim] = xr.DataArray(label_values, dims=idx_array.dims)

return label_indexer
Loading
Loading