From 5659a2b5e45fa9a51ed2ad20ce94caeffb3ed647 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 1 Apr 2025 18:04:57 +0200 Subject: [PATCH 01/39] Take xgcm.Grid and adapt for our uses Note that the datasets.py file was also taken from xgcm --- parcels/comodo.py | 148 +++++++ parcels/gridv4.py | 635 +++++++++++++++++++++++++++ tests/v4/__init__.py | 0 tests/v4/datasets.py | 364 ++++++++++++++++ tests/v4/test_grid.py | 965 ++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 2112 insertions(+) create mode 100644 parcels/comodo.py create mode 100644 parcels/gridv4.py create mode 100644 tests/v4/__init__.py create mode 100644 tests/v4/datasets.py create mode 100644 tests/v4/test_grid.py diff --git a/parcels/comodo.py b/parcels/comodo.py new file mode 100644 index 0000000000..c014b7f031 --- /dev/null +++ b/parcels/comodo.py @@ -0,0 +1,148 @@ +from collections import OrderedDict + +# Representation of axis shifts +axis_shift_left = -0.5 +axis_shift_right = 0.5 +axis_shift_center = 0 +# Characterizes valid shifts only +valid_axis_shifts = [axis_shift_left, axis_shift_right, axis_shift_center] + + +def assert_valid_comodo(ds): + """Verify that the dataset meets comodo conventions + + Parameters + ---------- + ds : xarray.dataset + """ + + # TODO: implement + assert True + + +def get_all_axes(ds): + axes = set() + for d in ds.dims: + if "axis" in ds[d].attrs: + axes.add(ds[d].attrs["axis"]) + return axes + + +def get_axis_coords(ds, axis_name): + """Find the name of the coordinates associated with a comodo axis. + + Parameters + ---------- + ds : xarray.dataset or xarray.dataarray + axis_name : str + The name of the axis to find (e.g. 'X') + + Returns + ------- + coord_name : list + The names of the coordinate matching that axis + """ + + coord_names = [] + for d in ds.dims: + axis = ds[d].attrs.get("axis") + if axis == axis_name: + coord_names.append(d) + return coord_names + + +def get_axis_positions_and_coords(ds, axis_name): + coord_names = get_axis_coords(ds, axis_name) + ncoords = len(coord_names) + if ncoords == 0: + # didn't find anything for this axis + raise ValueError("Couldn't find any coordinates for axis %s" % axis_name) + + # now figure out what type of coordinates these are: + # center, left, right, or outer + coords = {name: ds[name] for name in coord_names} + + # some tortured logic for dealing with malformed c_grid_axis_shift + # attributes such as produced by old versions of xmitgcm. + # This should be a float (either -0.5 or 0.5) + # this function returns that, or True of the attribute is set to + # anything at all + def _maybe_fix_type(attr): + if attr is not None: + try: + return float(attr) + except TypeError: + return True + + axis_shift = { + name: _maybe_fix_type(coord.attrs.get("c_grid_axis_shift")) + for name, coord in coords.items() + } + coord_len = {name: len(coord) for name, coord in coords.items()} + + # look for the center coord, which is required + # this list will potentially contain "center", "inner", and "outer" points + coords_without_axis_shift = { + name: coord_len[name] for name, shift in axis_shift.items() if not shift + } + if len(coords_without_axis_shift) == 0: + raise ValueError("Couldn't find a center coordinate for axis %s" % axis_name) + elif len(coords_without_axis_shift) > 1: + raise ValueError( + "Found two coordinates without " + "`c_grid_axis_shift` attribute for axis %s" % axis_name + ) + center_coord_name = list(coords_without_axis_shift)[0] + # the length of the center coord is key to decoding the other coords + axis_len = coord_len[center_coord_name] + + # now we can start filling in the information about the different coords + axis_coords = OrderedDict() + axis_coords["center"] = center_coord_name + + # now check the other coords + coord_names.remove(center_coord_name) + for name in coord_names: + shift = axis_shift[name] + clen = coord_len[name] + if clen == axis_len + 1: + axis_coords["outer"] = name + elif clen == axis_len - 1: + axis_coords["inner"] = name + elif shift == axis_shift_left: + if clen == axis_len: + axis_coords["left"] = name + else: + raise ValueError( + "Left coordinate %s has incompatible " + "length %g (axis_len=%g)" % (name, clen, axis_len) + ) + elif shift == axis_shift_right: + if clen == axis_len: + axis_coords["right"] = name + else: + raise ValueError( + "Right coordinate %s has incompatible " + "length %g (axis_len=%g)" % (name, clen, axis_len) + ) + else: + if shift not in valid_axis_shifts: + # string representing valid axis shifts + valids = str(valid_axis_shifts)[1:-1] + + raise ValueError( + "Coordinate %s has invalid " + "`c_grid_axis_shift` attribute `%s`. " + "`c_grid_axis_shift` must be one of: %s" + % (name, repr(shift), valids) + ) + else: + raise ValueError( + "Coordinate %s has missing " + "`c_grid_axis_shift` attribute `%s`" % (name, repr(shift)) + ) + return axis_coords + + +def _assert_data_on_grid(da): + pass diff --git a/parcels/gridv4.py b/parcels/gridv4.py new file mode 100644 index 0000000000..7fd1f1e52b --- /dev/null +++ b/parcels/gridv4.py @@ -0,0 +1,635 @@ +"""This Grid object is adapted from xgcm.Grid, removing a lot of the code that is not needed for Parcels.""" + +import functools +import inspect +import itertools +import operator +import warnings +from collections import OrderedDict + +import numpy as np +import xarray as xr +from dask.array import Array as Dask_Array + +from . import comodo + +# from .duck_array_ops import _apply_boundary_condition, _pad_array, concatenate +# from .grid_ufunc import ( +# GridUFunc, +# _check_data_input, +# _GridUFuncSignature, +# _has_chunked_core_dims, +# _maybe_unpack_vector_component, +# _reattach_coords, +# apply_as_grid_ufunc, +# ) +# from .metrics import iterate_axis_combinations +# from .padding import pad +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) + +try: + import numba # type: ignore + + from .transform import conservative_interpolation, linear_interpolation +except ImportError: + numba = None + +_VALID_BOUNDARY = [None, "fill", "extend", "periodic"] + + +def _maybe_promote_str_to_list(a): + # TODO: improve this + if isinstance(a, str): + return [a] + else: + return a + + +class Axis: + """ + An object that represents a group of coordinates that all lie along the same + physical dimension but at different positions with respect to a grid cell. + There are four possible positions: + + Center + |------o-------|------o-------|------o-------|------o-------| + [0] [1] [2] [3] + + Left + |------o-------|------o-------|------o-------|------o-------| + [0] [1] [2] [3] + + Right + |------o-------|------o-------|------o-------|------o-------| + [0] [1] [2] [3] + + Inner + |------o-------|------o-------|------o-------|------o-------| + [0] [1] [2] + + Outer + |------o-------|------o-------|------o-------|------o-------| + [0] [1] [2] [3] [4] + + The `center` position is the only one without the `c_grid_axis_shift` + attribute, which must be present for the other four. However, the actual + value of `c_grid_axis_shift` is ignored for `inner` and `outer`, which are + differentiated by their length. + """ + + def __init__( + self, + ds, + axis_name, + periodic=True, + default_shifts={}, + coords=None, + boundary=None, + fill_value=None, + ): + """ + Create a new Axis object from an input dataset. + + Parameters + ---------- + ds : xarray.Dataset + Contains the relevant grid information. Coordinate attributes + should conform to Comodo conventions [1]_. + axis_name : str + The name of the axis (should match axis attribute) + periodic : bool, optional + Whether the domain is periodic along this axis + default_shifts : dict, optional + Default mapping from and to grid positions + (e.g. `{'center': 'left'}`). Will be inferred if not specified. + coords : dict, optional + Mapping of axis positions to coordinate names + (e.g. `{'center': 'XC', 'left: 'XG'}`) + boundary : str or dict, optional, + boundary can either be one of {None, 'fill', 'extend', 'extrapolate', 'periodic'} + + * None: Do not apply any boundary conditions. Raise an error if + boundary conditions are required for the operation. + * 'fill': Set values outside the array boundary to fill_value + (i.e. a Dirichlet boundary condition.) + * 'extend': Set values outside the array to the nearest array + value. (i.e. a limited form of Neumann boundary condition where + the difference at the boundary will be zero.) + * 'extrapolate': Set values by extrapolating linearly from the two + points nearest to the edge + * 'periodic' : Wrap arrays around. Equivalent to setting `periodic=True` + This sets the default value. It can be overriden by specifying the + boundary kwarg when calling specific methods. + fill_value : float, optional + The value to use in the boundary condition when `boundary='fill'`. + + REFERENCES + ---------- + .. [1] Comodo Conventions https://web.archive.org/web/20160417032300/http://pycomodo.forge.imag.fr/norm.html + """ + + self._ds = ds + self.name = axis_name + self._periodic = periodic + if boundary not in _VALID_BOUNDARY: + raise ValueError(f"Expected 'boundary' to be one of {_VALID_BOUNDARY}. Received {boundary!r} instead.") + self.boundary = boundary + if fill_value is not None and not isinstance(fill_value, (int, float)): + raise ValueError("Expected 'fill_value' to be a number.") + self.fill_value = fill_value if fill_value is not None else 0.0 + + if coords: + # use specified coords + self.coords = {pos: name for pos, name in coords.items()} + else: + # fall back on comodo conventions + self.coords = comodo.get_axis_positions_and_coords(ds, axis_name) + + # self.coords is a dictionary with the following structure + # key: position_name {'center' ,'left' ,'right', 'outer', 'inner'} + # value: name of the dimension + + # set default position shifts + fallback_shifts = { + "center": ("left", "right", "outer", "inner"), + "left": ("center",), + "right": ("center",), + "outer": ("center",), + "inner": ("center",), + } + self._default_shifts = {} + for pos in self.coords: + # use user-specified value if present + if pos in default_shifts: + self._default_shifts[pos] = default_shifts[pos] + else: + for possible_shift in fallback_shifts[pos]: + if possible_shift in self.coords: + self._default_shifts[pos] = possible_shift + break + + ######################################################################## + # DEVELOPER DOCUMENTATION + # + # The attributes below are my best attempt to represent grid topology + # in a general way. The data structures are complicated, but I can't + # think of any way to simplify them. + # + # self._facedim (str) is the name of a dimension (e.g. 'face') or None. + # If it is None, that means that the grid topology is _simple_, i.e. + # that this is not a cubed-sphere grid or similar. For example: + # + # ds.dims == ('time', 'lat', 'lon') + # + # If _facedim is set to a dimension name, that means that shifting + # grid positions requires exchanging data among multiple "faces" + # (a.k.a. "tiles", "facets", etc.). For this to work, there must be a + # dimension corresponding to the different faces. This is `_facedim`. + # For example: + # + # ds.dims == ('time', 'face', 'lat', 'lon') + # + # In this case, `self._facedim == 'face'` + # + # We initialize all of this to None and let the `Grid` class handle + # setting these attributes for complex geometries. + self._facedim = None + # + # `self._connections` is a dictionary. It contains information about the + # connectivity among this axis and other axes. + # It should have the structure + # + # {facedim_index: ((left_facedim_index, left_axis, left_reverse), + # (right_facedim_index, right_axis, right_reverse)} + # + # `facedim_index` : a value used to index the `self._facedim` dimension + # (If `self._facedim` is `None`, then there should be only one key in + # `facedim_index` and that key should be `None`.) + # `left_facedim_index` : the facedim index of the neighbor to the left. + # (If `self._facedim` is `None`, this must also be `None`.) + # `left_axis` : an `Axis` object for the values to the left of this axis + # `left_reverse` : bool, whether the connection should be reversed. By + # default, the left side of this axis will be connected to the right + # side of the neighboring axis. `left_reverse` overrides this and + # instead connects to the left side of the neighboring axis + self._connections = {None: (None, None)} + + # now we implement periodic coordinates by setting appropriate + # connections + if periodic: + self._connections = {None: ((None, self, False), (None, self, False))} + + def __repr__(self): + is_periodic = "periodic" if self._periodic else "not periodic" + summary = ["" % (self.name, is_periodic, self.boundary)] + summary.append("Axis Coordinates:") + summary += self._coord_desc() + return "\n".join(summary) + + def _coord_desc(self): + summary = [] + for name, cname in self.coords.items(): + coord_info = " * %-8s %s" % (name, cname) + if name in self._default_shifts: + coord_info += " --> %s" % self._default_shifts[name] + summary.append(coord_info) + return summary + + def _get_position_name(self, da): + """Return the position and name of the axis coordinate in a DataArray.""" + for position, coord_name in self.coords.items(): + # TODO: should we have more careful checking of alignment here? + if coord_name in da.dims: + return position, coord_name + + raise KeyError("None of the DataArray's dims %s were found in axis " "coords." % repr(da.dims)) + + def _get_axis_dim_num(self, da): + """Return the dimension number of the axis coordinate in a DataArray.""" + _, coord_name = self._get_position_name(da) + return da.get_axis_num(coord_name) + + +class Grid: + """ + An object with multiple :class:`xgcm.Axis` objects representing different + independent axes. + """ + + def __init__( + self, + ds, + check_dims=True, + periodic=True, + default_shifts={}, + face_connections=None, + coords=None, + metrics=None, + boundary=None, + fill_value=None, + ): + """ + Create a new Grid object from an input dataset. + + Parameters + ---------- + ds : xarray.Dataset + Contains the relevant grid information. Coordinate attributes + should conform to Comodo conventions [1]_. + check_dims : bool, optional + Whether to check the compatibility of input data dimensions before + performing grid operations. + periodic : {True, False, list} + Whether the grid is periodic (i.e. "wrap-around"). If a list is + specified (e.g. ``['X', 'Y']``), the axis names in the list will be + be periodic and any other axes founds will be assumed non-periodic. + default_shifts : dict + A dictionary of dictionaries specifying default grid position + shifts (e.g. ``{'X': {'center': 'left', 'left': 'center'}}``) + face_connections : dict + Grid topology + coords : dict, optional + Specifies positions of dimension names along axes X, Y, Z, e.g + ``{'X': {'center': 'XC', 'left: 'XG'}}``. + Each key should be an axis name (e.g., `X`, `Y`, or `Z`) and map + to a dictionary which maps positions (`center`, `left`, `right`, + `outer`, `inner`) to dimension names in the dataset + (in the example above, `XC` is at the `center` position and `XG` + at the `left` position along the `X` axis). + If the values are not present in ``ds`` or are not dimensions, + an error will be raised. + metrics : dict, optional + Specification of grid metrics mapping axis names (X, Y, Z) to corresponding + metric variable names in the dataset + (e.g. {('X',):['dx_t'], ('X', 'Y'):['area_tracer', 'area_u']} + for the cell distance in the x-direction ``dx_t`` and the + horizontal cell areas ``area_tracer`` and ``area_u``, located at + different grid positions). + boundary : {None, 'fill', 'extend', 'extrapolate', dict}, optional + A flag indicating how to handle boundaries: + + * None: Do not apply any boundary conditions. Raise an error if + boundary conditions are required for the operation. + * 'fill': Set values outside the array boundary to fill_value + (i.e. a Dirichlet boundary condition.) + * 'extend': Set values outside the array to the nearest array + value. (i.e. a limited form of Neumann boundary condition.) + * 'extrapolate': Set values by extrapolating linearly from the two + points nearest to the edge + Optionally a dict mapping axis name to seperate values for each axis + can be passed. + fill_value : {float, dict}, optional + The value to use in boundary conditions with `boundary='fill'`. + Optionally a dict mapping axis name to seperate values for each axis + can be passed. + + REFERENCES + ---------- + .. [1] Comodo Conventions https://web.archive.org/web/20160417032300/http://pycomodo.forge.imag.fr/norm.html + """ + self._ds = ds + self._check_dims = check_dims + + # Deprecation Warnigns + warnings.warn( + "The `xgcm.Axis` class will be deprecated in the future. " + "Please make sure to use the `xgcm.Grid` methods for your work instead.", + category=DeprecationWarning, + ) + # This will show up every time, but I think that is fine + + if boundary: + warnings.warn( + "The `boundary` argument will be renamed " + "to `padding` to better reflect the process " + "of array padding and avoid confusion with " + "physical boundary conditions (e.g. ocean land boundary).", + category=DeprecationWarning, + ) + + # Deprecation Warnigns + if periodic: + warnings.warn( + "The `periodic` argument will be deprecated. " + "To preserve previous behavior supply `boundary = 'periodic'.", + category=DeprecationWarning, + ) + + if fill_value: + warnings.warn( + "The default fill_value will be changed to nan (from 0.0 previously) " + "in future versions. Provide `fill_value=0.0` to preserve previous behavior.", + category=DeprecationWarning, + ) + + extrapolate_warning = False + if boundary == "extrapolate": + extrapolate_warning = True + if isinstance(boundary, dict): + if any([k == "extrapolate" for k in boundary.keys()]): + extrapolate_warning = True + if extrapolate_warning: + warnings.warn( + "The `boundary='extrapolate'` option will no longer be supported in future releases.", + category=DeprecationWarning, + ) + + if coords: + all_axes = coords.keys() + else: + all_axes = comodo.get_all_axes(ds) + coords = {} + + # check coords input validity + for axis, positions in coords.items(): + for pos, dim in positions.items(): + if not (dim in ds.variables or dim in ds.dims): + raise ValueError( + f"Could not find dimension `{dim}` (for the `{pos}` position on axis `{axis}`) in input dataset." + ) + if dim not in ds.dims: + raise ValueError( + f"Input `{dim}` (for the `{pos}` position on axis `{axis}`) is not a dimension in the input datasets `ds`." + ) + + # Convert all inputs to axes-kwarg mappings + # TODO We need a way here to check valid input. Maybe also in _as_axis_kwargs? + # Parse axis properties + boundary = self._as_axis_kwarg_mapping(boundary, axes=all_axes) + fill_value = self._as_axis_kwarg_mapping(fill_value, axes=all_axes) + # TODO: In the future we want this the only place where we store these. + # TODO: This info needs to then be accessible to e.g. pad() + + # Parse list input. This case does only apply to periodic. + # Since we plan on deprecating it soon handle it here, so we can easily + # remove it later + if isinstance(periodic, list): + periodic = {axname: True for axname in periodic} + periodic = self._as_axis_kwarg_mapping(periodic, axes=all_axes) + + # Set properties on grid object. + self._facedim = list(face_connections.keys())[0] if face_connections else None + self._connections = face_connections if face_connections else None + # TODO: I think of the face connection data as grid not axes properties, since they almost by defintion + # TODO: involve multiple axes. In a future PR we should remove this info from the axes + # TODO: but make sure to properly port the checking functionality! + + # Populate axes. Much of this is just for backward compatibility. + self.axes = OrderedDict() + for axis_name in all_axes: + # periodic + is_periodic = periodic.get(axis_name, False) + + # default_shifts + if axis_name in default_shifts: + axis_default_shifts = default_shifts[axis_name] + else: + axis_default_shifts = {} + + # boundary + if isinstance(boundary, dict): + axis_boundary = boundary.get(axis_name, None) + elif isinstance(boundary, str) or boundary is None: + axis_boundary = boundary + else: + raise ValueError( + f"boundary={boundary} is invalid. Please specify a dictionary " + "mapping axis name to a boundary option; a string or None." + ) + + if isinstance(fill_value, dict): + axis_fillvalue = fill_value.get(axis_name, None) # TODO: This again sets defaults. Dont do that here. + elif isinstance(fill_value, (int, float)) or fill_value is None: + axis_fillvalue = fill_value + else: + raise ValueError( + f"fill_value={fill_value} is invalid. Please specify a dictionary " + "mapping axis name to a boundary option; a number or None." + ) + + self.axes[axis_name] = Axis( + ds, + axis_name, + is_periodic, + default_shifts=axis_default_shifts, + coords=coords.get(axis_name), + boundary=axis_boundary, + fill_value=axis_fillvalue, + ) + + if face_connections is not None: + self._assign_face_connections(face_connections) + + self._metrics = {} + + if metrics is not None: + for key, value in metrics.items(): + self.set_metrics(key, value) + + def _as_axis_kwarg_mapping( + self, + kwargs: Union[Any, Dict[str, Any]], + axes: Optional[Iterable[str]] = None, + ax_property_name=None, + default_value: Optional[Any] = None, + ) -> Dict[str, Any]: + """Convert kwarg input into dict for each available axis + E.g. for a grid with 2 axes for the keyword argument `periodic` + periodic = True --> periodic = {'X': True, 'Y':True} + or if not all axes are provided, the other axes will be parsed as defaults (None) + periodic = {'X':True} --> periodic={'X': True, 'Y':None} + """ + if axes is None: + axes = self.axes + + parsed_kwargs: Dict[str, Any] = dict() + + if isinstance(kwargs, dict): + parsed_kwargs = kwargs + else: + for axname in axes: + parsed_kwargs[axname] = kwargs + + # Check axis properties for values that were not provided (before using the default) + if ax_property_name is not None: + for axname in axes: + if axname not in parsed_kwargs.keys() or parsed_kwargs[axname] is None: + ax_property = getattr(self.axes[axname], ax_property_name) + parsed_kwargs[axname] = ax_property + + # if None set to default value. + parsed_kwargs_w_defaults = {k: default_value if v is None else v for k, v in parsed_kwargs.items()} + # At this point the output should be guaranteed to have an entry per existing axis. + # If neither a default value was given, nor an axis property was found, the value will be mapped to None. + + # temporary hack to get periodic conditions from axis + if ax_property_name == "boundary": + for axname in axes: + if self.axes[axname]._periodic: + if axname not in parsed_kwargs_w_defaults.keys(): + parsed_kwargs_w_defaults[axname] = "periodic" + + return parsed_kwargs_w_defaults + + def _assign_face_connections(self, fc): + """Check a dictionary of face connections to make sure all the links are + consistent. + """ + + if len(fc) > 1: + raise ValueError("Only one face dimension is supported for now. " "Instead found %r" % repr(fc.keys())) + + # we will populate this with the axes we find in face_connections + axis_connections = {} + + facedim = list(fc.keys())[0] + assert facedim in self._ds + + face_links = fc[facedim] + for fidx, face_axis_links in face_links.items(): + for axis, axis_links in face_axis_links.items(): + # initialize the axis dict if necssary + if axis not in axis_connections: + axis_connections[axis] = {} + link_left, link_right = axis_links + + def check_neighbor(link, position): + if link is None: + return + idx, ax, rev = link + # need to swap position if the link is reversed + correct_position = int(not position) if rev else position + try: + neighbor_link = face_links[idx][ax][correct_position] + except (KeyError, IndexError): + raise KeyError( + "Couldn't find a face link for face %r" + "in axis %r at position %r" % (idx, ax, correct_position) + ) + idx_n, ax_n, rev_n = neighbor_link + if ax not in self.axes: + raise KeyError("axis %r is not a valid axis" % ax) + if ax_n not in self.axes: + raise KeyError("axis %r is not a valid axis" % ax_n) + if idx not in self._ds[facedim].values: + raise IndexError("%r is not a valid index for face" "dimension %r" % (idx, facedim)) + if idx_n not in self._ds[facedim].values: + raise IndexError("%r is not a valid index for face" "dimension %r" % (idx, facedim)) + # check for consistent links from / to neighbor + if (idx_n != fidx) or (ax_n != axis) or (rev_n != rev): + raise ValueError( + "Face link mismatch: neighbor doesn't" + " correctly link back to this face. " + "face: %r, axis: %r, position: %r, " + "rev: %r, link: %r, neighbor_link: %r" % (fidx, axis, position, rev, link, neighbor_link) + ) + # convert the axis name to an acutal axis object + actual_axis = self.axes[ax] + return idx, actual_axis, rev + + left = check_neighbor(link_left, 1) + right = check_neighbor(link_right, 0) + axis_connections[axis][fidx] = (left, right) + + for axis, axis_links in axis_connections.items(): + self.axes[axis]._facedim = facedim + self.axes[axis]._connections = axis_links + + def set_metrics(self, key, value, overwrite=False): + metric_axes = frozenset(_maybe_promote_str_to_list(key)) + axes_not_found = [ma for ma in metric_axes if ma not in self.axes] + if len(axes_not_found) > 0: + raise KeyError(f"Metric axes {axes_not_found!r} not compatible with grid axes {tuple(self.axes)!r}") + + metric_value = _maybe_promote_str_to_list(value) + for metric_varname in metric_value: + if metric_varname not in self._ds.variables: + raise KeyError(f"Metric variable {metric_varname} not found in dataset.") + + existing_metric_axes = set(self._metrics.keys()) + if metric_axes in existing_metric_axes: + value_exist = self._metrics.get(metric_axes) + # resetting coords avoids potential broadcasting / alignment issues + value_new = self._ds[metric_varname].reset_coords(drop=True) + did_overwrite = False + # go through each existing value until data array with matching dimensions is selected + for idx, ve in enumerate(value_exist): + # double check if dimensions match + if set(value_new.dims) == set(ve.dims): + if overwrite: + # replace existing data array with new data array input + self._metrics[metric_axes][idx] = value_new + did_overwrite = True + else: + raise ValueError( + f"Metric variable {ve.name} with dimensions {ve.dims} already assigned in metrics." + f" Overwrite {ve.name} with {metric_varname} by setting overwrite=True." + ) + # if no existing value matches new value dimension-wise, just append new value + if not did_overwrite: + self._metrics[metric_axes].append(value_new) + else: + # no existing metrics for metric_axes yet; initialize empty list + self._metrics[metric_axes] = [] + for metric_varname in metric_value: + metric_var = self._ds[metric_varname].reset_coords(drop=True) + self._metrics[metric_axes].append(metric_var) + + def __repr__(self): + summary = [""] + for name, axis in self.axes.items(): + is_periodic = "periodic" if axis._periodic else "not periodic" + summary.append("%s Axis (%s, boundary=%r):" % (name, is_periodic, axis.boundary)) + summary += axis._coord_desc() + return "\n".join(summary) diff --git a/tests/v4/__init__.py b/tests/v4/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/v4/datasets.py b/tests/v4/datasets.py new file mode 100644 index 0000000000..135021c8c5 --- /dev/null +++ b/tests/v4/datasets.py @@ -0,0 +1,364 @@ +"""Datasets vendored from xgcm test suite for the testing of grids.""" + +import numpy as np +import pytest +import xarray as xr + +# example from comodo website +# https://web.archive.org/web/20160417032300/http://pycomodo.forge.imag.fr/norm.html +# netcdf example { +# dimensions: +# ni = 9 ; +# ni_u = 10 ; +# variables: +# float ni(ni) ; +# ni:axis = "X" ; +# ni:standard_name = "x_grid_index" ; +# ni:long_name = "x-dimension of the grid" ; +# ni:c_grid_dynamic_range = "2:8" ; +# float ni_u(ni_u) ; +# ni_u:axis = "X" ; +# ni_u:standard_name = "x_grid_index_at_u_location" ; +# ni_u:long_name = "x-dimension of the grid" ; +# ni_u:c_grid_dynamic_range = "3:8" ; +# ni_u:c_grid_axis_shift = -0.5 ; +# data: +# ni = 1, 2, 3, 4, 5, 6, 7, 8, 9 ; +# ni_u = 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5 ; +# } + +N = 100 +datasets = { + # the comodo example, with renamed dimensions + "1d_outer": xr.Dataset( + {"data_c": (["XC"], np.random.rand(9)), "data_g": (["XG"], np.random.rand(10))}, + coords={ + "XC": ( + ["XC"], + np.arange(1, 10), + { + "axis": "X", + "standard_name": "x_grid_index", + "long_name": "x-dimension of the grid", + "c_grid_dynamic_range": "2:8", + }, + ), + "XG": ( + ["XG"], + np.arange(0.5, 10), + { + "axis": "X", + "standard_name": "x_grid_index_at_u_location", + "long_name": "x-dimension of the grid", + "c_grid_dynamic_range": "3:8", + "c_grid_axis_shift": -0.5, + }, + ), + }, + ), + "1d_inner": xr.Dataset( + {"data_c": (["XC"], np.random.rand(9)), "data_g": (["XG"], np.random.rand(8))}, + coords={ + "XC": ( + ["XC"], + np.arange(1, 10), + { + "axis": "X", + "standard_name": "x_grid_index", + "long_name": "x-dimension of the grid", + "c_grid_dynamic_range": "2:8", + }, + ), + "XG": ( + ["XG"], + np.arange(1.5, 9), + { + "axis": "X", + "standard_name": "x_grid_index_at_u_location", + "long_name": "x-dimension of the grid", + "c_grid_dynamic_range": "3:8", + "c_grid_axis_shift": -0.5, + }, + ), + }, + ), + # my own invention + "1d_left": xr.Dataset( + {"data_g": (["XG"], np.random.rand(N)), "data_c": (["XC"], np.random.rand(N))}, + coords={ + "XG": ( + ["XG"], + 2 * np.pi / N * np.arange(0, N), + {"axis": "X", "c_grid_axis_shift": -0.5}, + ), + "XC": (["XC"], 2 * np.pi / N * (np.arange(0, N) + 0.5), {"axis": "X"}), + }, + ), + "1d_right": xr.Dataset( + {"data_g": (["XG"], np.random.rand(N)), "data_c": (["XC"], np.random.rand(N))}, + coords={ + "XG": ( + ["XG"], + 2 * np.pi / N * np.arange(1, N + 1), + {"axis": "X", "c_grid_axis_shift": 0.5}, + ), + "XC": (["XC"], 2 * np.pi / N * (np.arange(0, N) - 0.5), {"axis": "X"}), + }, + ), + "2d_left": xr.Dataset( + { + "data_g": (["YG", "XG"], np.random.rand(2 * N, N)), + "data_c": (["YC", "XC"], np.random.rand(2 * N, N)), + }, + coords={ + "XG": ( + ["XG"], + 2 * np.pi / N * np.arange(0, N), + {"axis": "X", "c_grid_axis_shift": -0.5}, + ), + "XC": (["XC"], 2 * np.pi / N * (np.arange(0, N) + 0.5), {"axis": "X"}), + "YG": ( + ["YG"], + 2 * np.pi / (2 * N) * np.arange(0, 2 * N), + {"axis": "Y", "c_grid_axis_shift": -0.5}, + ), + "YC": ( + ["YC"], + 2 * np.pi / (2 * N) * (np.arange(0, 2 * N) + 0.5), + {"axis": "Y"}, + ), + }, + ), +} + +# include periodicity +datasets_with_periodicity = { + "nonperiodic_1d_outer": (datasets["1d_outer"], False), + "nonperiodic_1d_inner": (datasets["1d_inner"], False), + "periodic_1d_left": (datasets["1d_left"], True), + "nonperiodic_1d_left": (datasets["1d_left"], False), + "periodic_1d_right": (datasets["1d_right"], True), + "nonperiodic_1d_right": (datasets["1d_right"], False), + "periodic_2d_left": (datasets["2d_left"], True), + "nonperiodic_2d_left": (datasets["2d_left"], False), + "xperiodic_2d_left": (datasets["2d_left"], ["X"]), + "yperiodic_2d_left": (datasets["2d_left"], ["Y"]), +} + +expected_values = { + "nonperiodic_1d_outer": {"axes": {"X": {"center": "XC", "outer": "XG"}}}, + "nonperiodic_1d_inner": {"axes": {"X": {"center": "XC", "inner": "XG"}}}, + "periodic_1d_left": {"axes": {"X": {"center": "XC", "left": "XG"}}}, + "nonperiodic_1d_left": {"axes": {"X": {"center": "XC", "left": "XG"}}}, + "periodic_1d_right": { + "axes": {"X": {"center": "XC", "right": "XG"}}, + "shift": True, + }, + "nonperiodic_1d_right": { + "axes": {"X": {"center": "XC", "right": "XG"}}, + "shift": True, + }, + "periodic_2d_left": { + "axes": { + "X": {"center": "XC", "left": "XG"}, + "Y": {"center": "YC", "left": "YG"}, + } + }, + "nonperiodic_2d_left": { + "axes": { + "X": {"center": "XC", "left": "XG"}, + "Y": {"center": "YC", "left": "YG"}, + } + }, + "xperiodic_2d_left": { + "axes": { + "X": {"center": "XC", "left": "XG"}, + "Y": {"center": "YC", "left": "YG"}, + } + }, + "yperiodic_2d_left": { + "axes": { + "X": {"center": "XC", "left": "XG"}, + "Y": {"center": "YC", "left": "YG"}, + } + }, +} + + +@pytest.fixture(scope="module", params=datasets_with_periodicity.keys()) +def all_datasets(request): + ds, periodic = datasets_with_periodicity[request.param] + return ds, periodic, expected_values[request.param] + + +@pytest.fixture( + scope="module", + params=[ + "nonperiodic_1d_outer", + "nonperiodic_1d_inner", + "nonperiodic_1d_left", + "nonperiodic_1d_right", + ], +) +def nonperiodic_1d(request): + ds, periodic = datasets_with_periodicity[request.param] + return ds, periodic, expected_values[request.param] + + +@pytest.fixture(scope="module", params=["periodic_1d_left", "periodic_1d_right"]) +def periodic_1d(request): + ds, periodic = datasets_with_periodicity[request.param] + return ds, periodic, expected_values[request.param] + + +@pytest.fixture( + scope="module", + params=[ + "periodic_2d_left", + "nonperiodic_2d_left", + "xperiodic_2d_left", + "yperiodic_2d_left", + ], +) +def all_2d(request): + ds, periodic = datasets_with_periodicity[request.param] + return ds, periodic, expected_values[request.param] + + +@pytest.fixture(scope="module", params=["periodic_2d_left"]) +def periodic_2d(request): + ds, periodic = datasets_with_periodicity[request.param] + return ds, periodic, expected_values[request.param] + + +@pytest.fixture( + scope="module", + params=["nonperiodic_2d_left", "xperiodic_2d_left", "yperiodic_2d_left"], +) +def nonperiodic_2d(request): + ds, periodic = datasets_with_periodicity[request.param] + return ds, periodic, expected_values[request.param] + + +def datasets_grid_metric(grid_type): + """Uniform grid test dataset. + Should eventually be extended to nonuniform grid""" + xt = np.arange(4) + xu = xt + 0.5 + yt = np.arange(5) + yu = yt + 0.5 + zt = np.arange(6) + zw = zt + 0.5 + t = np.arange(10) + + def data_generator(): + return np.random.rand(len(xt), len(yt), len(t), len(zt)) + + # Need to add a tracer here to get the tracer dimsuffix + tr = xr.DataArray(data_generator(), coords=[("xt", xt), ("yt", yt), ("time", t), ("zt", zt)]) + + u_b = xr.DataArray(data_generator(), coords=[("xu", xu), ("yu", yu), ("time", t), ("zt", zt)]) + + v_b = xr.DataArray(data_generator(), coords=[("xu", xu), ("yu", yu), ("time", t), ("zt", zt)]) + + u_c = xr.DataArray(data_generator(), coords=[("xu", xu), ("yt", yt), ("time", t), ("zt", zt)]) + + v_c = xr.DataArray(data_generator(), coords=[("xt", xt), ("yu", yu), ("time", t), ("zt", zt)]) + + wt = xr.DataArray(data_generator(), coords=[("xt", xt), ("yt", yt), ("time", t), ("zw", zw)]) + + # maybe also add some other combo of x,t y,t arrays.... + timeseries = xr.DataArray(np.random.rand(len(t)), coords=[("time", t)]) + + # northeast distance + dx = 0.3 + dy = 2 + dz = 20 + + dx_ne = xr.DataArray(np.ones([len(xt), len(yt)]) * dx - 0.1, coords=[("xu", xu), ("yu", yu)]) + dx_n = xr.DataArray(np.ones([len(xt), len(yt)]) * dx - 0.2, coords=[("xt", xt), ("yu", yu)]) + dx_e = xr.DataArray(np.ones([len(xt), len(yt)]) * dx - 0.3, coords=[("xu", xu), ("yt", yt)]) + dx_t = xr.DataArray(np.ones([len(xt), len(yt)]) * dx - 0.4, coords=[("xt", xt), ("yt", yt)]) + + dy_ne = xr.DataArray(np.ones([len(xt), len(yt)]) * dy + 0.1, coords=[("xu", xu), ("yu", yu)]) + dy_n = xr.DataArray(np.ones([len(xt), len(yt)]) * dy + 0.2, coords=[("xt", xt), ("yu", yu)]) + dy_e = xr.DataArray(np.ones([len(xt), len(yt)]) * dy + 0.3, coords=[("xu", xu), ("yt", yt)]) + dy_t = xr.DataArray(np.ones([len(xt), len(yt)]) * dy + 0.4, coords=[("xt", xt), ("yt", yt)]) + + # dz elements at horizontal tracer points + dz_t = xr.DataArray(data_generator() * dz, coords=[("xt", xt), ("yt", yt), ("time", t), ("zt", zt)]) + dz_w = xr.DataArray(data_generator() * dz, coords=[("xt", xt), ("yt", yt), ("time", t), ("zw", zw)]) + # dz elements at velocity points + dz_w_ne = xr.DataArray(data_generator() * dz, coords=[("xu", xu), ("yu", yu), ("time", t), ("zw", zw)]) + dz_w_n = xr.DataArray(data_generator() * dz, coords=[("xt", xt), ("yu", yu), ("time", t), ("zw", zw)]) + dz_w_e = xr.DataArray(data_generator() * dz, coords=[("xu", xu), ("yt", yt), ("time", t), ("zw", zw)]) + + # Make sure the areas are not just the product of x and y distances + area_ne = (dx_ne * dy_ne) + 0.1 + area_n = (dx_n * dy_n) + 0.2 + area_e = (dx_e * dy_e) + 0.3 + area_t = (dx_t * dy_t) + 0.4 + + # calculate volumes, but again add small differences. + volume_t = (dx_t * dy_t * dz_t) + 0.25 + + def _add_metrics(obj): + obj = obj.copy() + for name, data in [ + ("dx_ne", dx_ne), + ("dx_n", dx_n), + ("dx_e", dx_e), + ("dx_t", dx_t), + ("dy_ne", dy_ne), + ("dy_n", dy_n), + ("dy_e", dy_e), + ("dy_t", dy_t), + ("dz_t", dz_t), + ("dz_w", dz_w), + ("dz_w_ne", dz_w_ne), + ("dz_w_n", dz_w_n), + ("dz_w_e", dz_w_e), + ("area_ne", area_ne), + ("area_n", area_n), + ("area_e", area_e), + ("area_t", area_t), + ("volume_t", volume_t), + ]: + obj.coords[name] = data + obj.coords[name].attrs["tracked_name"] = name + # add xgcm attrs + for ii in ["xu", "xt"]: + obj[ii].attrs["axis"] = "X" + for ii in ["yu", "yt"]: + obj[ii].attrs["axis"] = "Y" + for ii in ["zt", "zw"]: + obj[ii].attrs["axis"] = "Z" + for ii in ["time"]: + obj[ii].attrs["axis"] = "T" + for ii in ["xu", "yu", "zw"]: + obj[ii].attrs["c_grid_axis_shift"] = 0.5 + return obj + + coords = { + "X": {"center": "xt", "right": "xu"}, + "Y": {"center": "yt", "right": "yu"}, + "Z": {"center": "zt", "right": "zw"}, + } + + metrics = { + ("X",): ["dx_t", "dx_n", "dx_e", "dx_ne"], + ("Y",): ["dy_t", "dy_n", "dy_e", "dy_ne"], + ("Z",): ["dz_t", "dz_w", "dz_w_ne", "dz_w_n", "dz_w_e"], + ("X", "Y"): ["area_t", "area_n", "area_e", "area_ne"], + ("X", "Y", "Z"): ["volume_t"], + } + + # combine to different grid configurations (B and C grid) + if grid_type == "B": + ds = _add_metrics(xr.Dataset({"u": u_b, "v": v_b, "wt": wt, "tracer": tr, "timeseries": timeseries})) + elif grid_type == "C": + ds = _add_metrics(xr.Dataset({"u": u_c, "v": v_c, "wt": wt, "tracer": tr, "timeseries": timeseries})) + else: + raise ValueError("Invalid input [%s] for `grid_type`. Only supports `B` and `C` at the moment " % grid_type) + + return ds, coords, metrics diff --git a/tests/v4/test_grid.py b/tests/v4/test_grid.py new file mode 100644 index 0000000000..4fbbc090a1 --- /dev/null +++ b/tests/v4/test_grid.py @@ -0,0 +1,965 @@ +import numpy as np +import pytest +import xarray as xr + +from parcels.gridv4 import Axis, Grid + +from tests.v4.datasets import all_2d # noqa: F401 +from tests.v4.datasets import all_datasets # noqa: F401 +from tests.v4.datasets import datasets # noqa: F401 +from tests.v4.datasets import datasets_grid_metric # noqa: F401 +from tests.v4.datasets import nonperiodic_1d # noqa: F401 +from tests.v4.datasets import nonperiodic_2d # noqa: F401 +from tests.v4.datasets import periodic_1d # noqa: F401 +from tests.v4.datasets import periodic_2d # noqa: F401 + + +# helper function to produce axes from datasets +def _get_axes(ds): + all_axes = {ds[c].attrs["axis"] for c in ds.dims if "axis" in ds[c].attrs} + axis_objs = {ax: Axis(ds, ax) for ax in all_axes} + return axis_objs + + +def test_create_axis(all_datasets): + ds, periodic, expected = all_datasets + axis_objs = _get_axes(ds) + for ax_expected, coords_expected in expected["axes"].items(): + assert ax_expected in axis_objs + this_axis = axis_objs[ax_expected] + for axis_name, coord_name in coords_expected.items(): + assert axis_name in this_axis.coords + assert this_axis.coords[axis_name] == coord_name + + +def _assert_axes_equal(ax1, ax2): + assert ax1.name == ax2.name + for pos, coord in ax1.coords.items(): + assert pos in ax2.coords + assert coord == ax2.coords[pos] + assert ax1._periodic == ax2._periodic + assert ax1._default_shifts == ax2._default_shifts + assert ax1._facedim == ax2._facedim + # TODO: make this work... + # assert ax1._connections == ax2._connections + + +def test_create_axis_no_comodo(all_datasets): + ds, periodic, expected = all_datasets + axis_objs = _get_axes(ds) + + # now strip out the metadata + ds_noattr = ds.copy() + for var in ds.variables: + ds_noattr[var].attrs.clear() + + for axis_name, axis_coords in expected["axes"].items(): + # now create the axis from scratch with no attributes + ax2 = Axis(ds_noattr, axis_name, coords=axis_coords) + # and compare to the one created with attributes + ax1 = axis_objs[axis_name] + + assert ax1.name == ax2.name + for pos, coord_name in ax1.coords.items(): + assert pos in ax2.coords + assert coord_name == ax2.coords[pos] + assert ax1._periodic == ax2._periodic + assert ax1._default_shifts == ax2._default_shifts + assert ax1._facedim == ax2._facedim + + +def test_create_axis_no_coords(all_datasets): + ds, periodic, expected = all_datasets + axis_objs = _get_axes(ds) + + ds_drop = ds.drop_vars(list(ds.coords)) + + for axis_name, axis_coords in expected["axes"].items(): + # now create the axis from scratch with no attributes OR coords + ax2 = Axis(ds_drop, axis_name, coords=axis_coords) + # and compare to the one created with attributes + ax1 = axis_objs[axis_name] + + assert ax1.name == ax2.name + for pos, coord in ax1.coords.items(): + assert pos in ax2.coords + assert ax1._periodic == ax2._periodic + assert ax1._default_shifts == ax2._default_shifts + assert ax1._facedim == ax2._facedim + + +def test_axis_repr(all_datasets): + ds, periodic, expected = all_datasets + axis_objs = _get_axes(ds) + for ax_name, axis in axis_objs.items(): + r = repr(axis).split("\n") + assert r[0].startswith(" dim_line_diff = 1 + # inner --> dim_line_diff = -1 + dim_len_diff = len(ds.XG) - len(ds.XC) + + if from_center: + to = (set(expected["axes"]["X"].keys()) - {"center"}).pop() + da = ds.data_c + else: + to = "center" + da = ds.data_g + + shift = expected.get("shift") or False + + # need boundary condition for everything but outer to center + if (boundary is None) and ( + dim_len_diff == 0 or (dim_len_diff == 1 and from_center) or (dim_len_diff == -1 and not from_center) + ): + with pytest.raises(ValueError): + data_left, data_right = axis._get_neighbor_data_pairs(da, to, boundary=boundary) + else: + data_left, data_right = axis._get_neighbor_data_pairs(da, to, boundary=boundary) + if ((dim_len_diff == 1) and not from_center) or ((dim_len_diff == -1) and from_center): + expected_left = da.data[:-1] + expected_right = da.data[1:] + elif ((dim_len_diff == 1) and from_center) or ((dim_len_diff == -1) and not from_center): + expected_left = _pad_left(da.data, boundary) + expected_right = _pad_right(da.data, boundary) + elif (shift and not from_center) or (not shift and from_center): + expected_right = da.data + expected_left = _pad_left(da.data, boundary)[:-1] + else: + expected_left = da.data + expected_right = _pad_right(da.data, boundary)[1:] + + np.testing.assert_allclose(data_left, expected_left) + np.testing.assert_allclose(data_right, expected_right) + + +@pytest.mark.parametrize("boundary", ["extend", "fill", pytest.param("extrapolate", marks=pytest.mark.xfail)]) +def test_axis_cumsum(nonperiodic_1d, boundary): + ds, periodic, expected = nonperiodic_1d + axis = Axis(ds, "X", periodic=periodic) + + axis_expected = expected["axes"]["X"] + + cumsum_g = axis.cumsum(ds.data_g, to="center", boundary=boundary) + assert cumsum_g.dims == ds.data_c.dims + # check default "to" + assert cumsum_g.equals(axis.cumsum(ds.data_g, boundary=boundary)) + + to = set(axis_expected).difference({"center"}).pop() + cumsum_c = axis.cumsum(ds.data_c, to=to, boundary=boundary) + assert cumsum_c.dims == ds.data_g.dims + # check default "to" + assert cumsum_c.equals(axis.cumsum(ds.data_c, boundary=boundary)) + + cumsum_c_raw = np.cumsum(ds.data_c.data) + cumsum_g_raw = np.cumsum(ds.data_g.data) + + if to == "right": + np.testing.assert_allclose(cumsum_c.data, cumsum_c_raw) + fill_value = 0.0 if boundary == "fill" else cumsum_g_raw[0] + np.testing.assert_allclose(cumsum_g.data, np.hstack([fill_value, cumsum_g_raw[:-1]])) + elif to == "left": + np.testing.assert_allclose(cumsum_g.data, cumsum_g_raw) + fill_value = 0.0 if boundary == "fill" else cumsum_c_raw[0] + np.testing.assert_allclose(cumsum_c.data, np.hstack([fill_value, cumsum_c_raw[:-1]])) + elif to == "inner": + np.testing.assert_allclose(cumsum_c.data, cumsum_c_raw[:-1]) + fill_value = 0.0 if boundary == "fill" else cumsum_g_raw[0] + np.testing.assert_allclose(cumsum_g.data, np.hstack([fill_value, cumsum_g_raw])) + elif to == "outer": + np.testing.assert_allclose(cumsum_g.data, cumsum_g_raw[:-1]) + fill_value = 0.0 if boundary == "fill" else cumsum_c_raw[0] + np.testing.assert_allclose(cumsum_c.data, np.hstack([fill_value, cumsum_c_raw])) + + # not much point doing this...we don't have the right test datasets + # to really test the errors + # other_positions = {'left', 'right', 'inner', 'outer'}.difference({to}) + # for pos in other_positions: + # with pytest.raises(KeyError): + # axis.cumsum(ds.data_c, to=pos, boundary=boundary) + + +@pytest.mark.parametrize( + "varname, axis_name, to, roll, roll_axis, swap_order", + [ + ("data_c", "X", "left", 1, 1, False), + ("data_c", "Y", "left", 1, 0, False), + ("data_g", "X", "center", -1, 1, True), + ("data_g", "Y", "center", -1, 0, True), + ], +) +def test_axis_neighbor_pairs_2d(periodic_2d, varname, axis_name, to, roll, roll_axis, swap_order): + ds, _, _ = periodic_2d + + axis = Axis(ds, axis_name) + + data = ds[varname] + data_left, data_right = axis._get_neighbor_data_pairs(data, to) + if swap_order: + data_left, data_right = data_right, data_left + np.testing.assert_allclose(data_left, np.roll(data.data, roll, axis=roll_axis)) + np.testing.assert_allclose(data_right, data.data) + + +@pytest.mark.parametrize("boundary", ["extend", "fill", pytest.param("extrapolate", marks=pytest.mark.xfail)]) +@pytest.mark.parametrize("from_center", [True, False]) +def test_axis_diff_and_interp_nonperiodic_1d(nonperiodic_1d, boundary, from_center): + ds, periodic, expected = nonperiodic_1d + axis = Axis(ds, "X", periodic=periodic) + + dim_len_diff = len(ds.XG) - len(ds.XC) + + if from_center: + to = (set(expected["axes"]["X"].keys()) - {"center"}).pop() + coord_to = "XG" + da = ds.data_c + else: + to = "center" + coord_to = "XC" + da = ds.data_g + + shift = expected.get("shift") or False + + data = da.data + if (dim_len_diff == 1 and not from_center) or (dim_len_diff == -1 and from_center): + data_left = data[:-1] + data_right = data[1:] + elif (dim_len_diff == 1 and from_center) or (dim_len_diff == -1 and not from_center): + data_left = _pad_left(data, boundary) + data_right = _pad_right(data, boundary) + elif (shift and not from_center) or (not shift and from_center): + data_left = _pad_left(data[:-1], boundary) + data_right = data + else: + data_left = data + data_right = _pad_right(data[1:], boundary) + + # interpolate + data_interp_expected = xr.DataArray( + 0.5 * (data_left + data_right), dims=[coord_to], coords={coord_to: ds[coord_to]} + ) + data_interp = axis.interp(da, to, boundary=boundary) + assert data_interp_expected.equals(data_interp) + # check without "to" specified + assert data_interp.equals(axis.interp(da, boundary=boundary)) + + # difference + data_diff_expected = xr.DataArray(data_right - data_left, dims=[coord_to], coords={coord_to: ds[coord_to]}) + data_diff = axis.diff(da, to, boundary=boundary) + assert data_diff_expected.equals(data_diff) + # check without "to" specified + assert data_diff.equals(axis.diff(da, boundary=boundary)) + + # max + data_max_expected = xr.DataArray( + np.maximum(data_right, data_left), + dims=[coord_to], + coords={coord_to: ds[coord_to]}, + ) + data_max = axis.max(da, to, boundary=boundary) + assert data_max_expected.equals(data_max) + # check without "to" specified + assert data_max.equals(axis.max(da, boundary=boundary)) + + # min + data_min_expected = xr.DataArray( + np.minimum(data_right, data_left), + dims=[coord_to], + coords={coord_to: ds[coord_to]}, + ) + data_min = axis.min(da, to, boundary=boundary) + assert data_min_expected.equals(data_min) + # check without "to" specified + assert data_min.equals(axis.min(da, boundary=boundary)) + + +# this mega test covers all options for 2D data + + +@pytest.mark.parametrize("boundary", ["extend", "fill", pytest.param("extrapolate", marks=pytest.mark.xfail)]) +@pytest.mark.parametrize("axis_name", ["X", "Y"]) +@pytest.mark.parametrize("varname, this, to", [("data_c", "center", "left"), ("data_g", "left", "center")]) +def test_axis_diff_and_interp_nonperiodic_2d(all_2d, boundary, axis_name, varname, this, to): + ds, periodic, _ = all_2d + + try: + ax_periodic = axis_name in periodic + except TypeError: + ax_periodic = periodic + + boundary_arg = boundary if not ax_periodic else None + axis = Axis(ds, axis_name, periodic=ax_periodic, boundary=boundary_arg) + da = ds[varname] + + # everything is left shift + data = ds[varname].data + + axis_num = da.get_axis_num(axis.coords[this]) + + # lookups for numpy.pad + numpy_pad_arg = {"extend": "edge", "fill": "constant"} + # args for numpy.pad + pad_left = (1, 0) + pad_right = (0, 1) + pad_none = (0, 0) + + if this == "center": + if ax_periodic: + data_left = np.roll(data, 1, axis=axis_num) + else: + pad_width = [pad_left if i == axis_num else pad_none for i in range(data.ndim)] + the_slice = tuple([slice(0, -1) if i == axis_num else slice(None) for i in range(data.ndim)]) + data_left = np.pad(data, pad_width, numpy_pad_arg[boundary])[the_slice] + data_right = data + elif this == "left": + if ax_periodic: + data_left = data + data_right = np.roll(data, -1, axis=axis_num) + else: + pad_width = [pad_right if i == axis_num else pad_none for i in range(data.ndim)] + the_slice = tuple([slice(1, None) if i == axis_num else slice(None) for i in range(data.ndim)]) + data_right = np.pad(data, pad_width, numpy_pad_arg[boundary])[the_slice] + data_left = data + + data_interp = 0.5 * (data_left + data_right) + data_diff = data_right - data_left + + # determine new dims + dims = list(da.dims) + dims[axis_num] = axis.coords[to] + coords = {dim: ds[dim] for dim in dims} + + da_interp_expected = xr.DataArray(data_interp, dims=dims, coords=coords) + da_diff_expected = xr.DataArray(data_diff, dims=dims, coords=coords) + + da_interp = axis.interp(da, to) + da_diff = axis.diff(da, to) + + assert da_interp_expected.equals(da_interp) + assert da_diff_expected.equals(da_diff) + + if boundary_arg is not None: + if boundary == "extend": + bad_boundary = "fill" + elif boundary == "fill": + bad_boundary = "extend" + + da_interp_wrong = axis.interp(da, to, boundary=bad_boundary) + assert not da_interp_expected.equals(da_interp_wrong) + da_diff_wrong = axis.diff(da, to, boundary=bad_boundary) + assert not da_diff_expected.equals(da_diff_wrong) + + +def test_axis_errors(): + ds = datasets["1d_left"] + + ds_noattr = ds.copy() + del ds_noattr.XC.attrs["axis"] + with pytest.raises(ValueError, match="Couldn't find a center coordinate for axis X"): + _ = Axis(ds_noattr, "X", periodic=True) + + del ds_noattr.XG.attrs["axis"] + with pytest.raises(ValueError, match="Couldn't find any coordinates for axis X"): + _ = Axis(ds_noattr, "X", periodic=True) + + ds_chopped = ds.copy().isel(XG=slice(None, 3)) + del ds_chopped["data_g"] + with pytest.raises(ValueError, match="coordinate XG has incompatible length"): + _ = Axis(ds_chopped, "X", periodic=True) + + ds_chopped.XG.attrs["c_grid_axis_shift"] = -0.5 + with pytest.raises(ValueError, match="coordinate XG has incompatible length"): + _ = Axis(ds_chopped, "X", periodic=True) + + del ds_chopped.XG.attrs["c_grid_axis_shift"] + with pytest.raises( + ValueError, + match="Found two coordinates without `c_grid_axis_shift` attribute for axis X", + ): + _ = Axis(ds_chopped, "X", periodic=True) + + ax = Axis(ds, "X", periodic=True) + + with pytest.raises(ValueError, match="Can't get neighbor pairs for the same position."): + ax.interp(ds.data_c, "center") + + with pytest.raises(ValueError, match="This axis doesn't contain a `right` position"): + ax.interp(ds.data_c, "right") + + # This case is broken, need to fix! + # with pytest.raises( + # ValueError, match="`boundary=fill` is not allowed " "with periodic axis X." + # ): + # ax.interp(ds.data_c, "left", boundary="fill") + + +@pytest.mark.parametrize( + "boundary", + [ + None, + "fill", + "extend", + pytest.param("extrapolate", marks=pytest.mark.xfail(strict=True)), + {"X": "fill", "Y": "extend"}, + ], +) +@pytest.mark.parametrize("fill_value", [None, 0, 1.0]) +def test_grid_create(all_datasets, boundary, fill_value): + ds, periodic, expected = all_datasets + grid = Grid(ds, periodic=periodic) + assert grid is not None + for ax in grid.axes.values(): + assert ax.boundary is None + grid = Grid(ds, periodic=periodic, boundary=boundary, fill_value=fill_value) + for name, ax in grid.axes.items(): + if isinstance(boundary, dict): + expected = boundary.get(name) + else: + expected = boundary + assert ax.boundary == expected + + if fill_value is None: + expected = 0.0 + elif isinstance(fill_value, dict): + expected = fill_value.get(name) + else: + expected = fill_value + assert ax.fill_value == expected + + +def test_create_grid_no_comodo(all_datasets): + ds, periodic, expected = all_datasets + grid_expected = Grid(ds, periodic=periodic) + + ds_noattr = ds.copy() + for var in ds.variables: + ds_noattr[var].attrs.clear() + + coords = expected["axes"] + grid = Grid(ds_noattr, periodic=periodic, coords=coords) + + for axis_name_expected in grid_expected.axes: + axis_expected = grid_expected.axes[axis_name_expected] + axis_actual = grid.axes[axis_name_expected] + _assert_axes_equal(axis_expected, axis_actual) + + +def test_grid_no_coords(periodic_1d): + """Ensure that you can use xgcm with Xarray datasets that don't have dimension coordinates.""" + ds, periodic, expected = periodic_1d + ds_nocoords = ds.drop_vars(list(ds.dims.keys())) + + coords = expected["axes"] + grid = Grid(ds_nocoords, periodic=periodic, coords=coords) + + diff = grid.diff(ds["data_c"], "X") + assert len(diff.coords) == 0 + interp = grid.interp(ds["data_c"], "X") + assert len(interp.coords) == 0 + + +def test_grid_repr(all_datasets): + ds, periodic, _ = all_datasets + grid = Grid(ds, periodic=periodic) + r = repr(grid).split("\n") + assert r[0] == "" + + +def test_grid_ops(all_datasets): + """ + Check that we get the same answer using Axis or Grid objects + """ + ds, periodic, _ = all_datasets + grid = Grid(ds, periodic=periodic) + + for axis_name in grid.axes.keys(): + try: + ax_periodic = axis_name in periodic + except TypeError: + ax_periodic = periodic + axis = Axis(ds, axis_name, periodic=ax_periodic) + + bcs = [None] if ax_periodic else ["fill", "extend"] + for varname in ["data_c", "data_g"]: + for boundary in bcs: + da_interp = grid.interp(ds[varname], axis_name, boundary=boundary) + da_interp_ax = axis.interp(ds[varname], boundary=boundary) + assert da_interp.equals(da_interp_ax) + + da_diff = grid.diff(ds[varname], axis_name, boundary=boundary) + da_diff_ax = axis.diff(ds[varname], boundary=boundary) + assert da_diff.equals(da_diff_ax) + + if boundary is not None: + da_cumsum = grid.cumsum(ds[varname], axis_name, boundary=boundary) + da_cumsum_ax = axis.cumsum(ds[varname], boundary=boundary) + assert da_cumsum.equals(da_cumsum_ax) + + +@pytest.mark.parametrize("func", ["interp", "max", "min", "diff", "cumsum"]) +@pytest.mark.parametrize("periodic", ["True", "False", ["X"], ["Y"], ["X", "Y"]]) +@pytest.mark.parametrize( + "boundary", + [ + "fill", + pytest.param("extrapolate", marks=pytest.mark.xfail(strict=True)), + "extend", + {"X": "fill", "Y": "extend"}, + {"X": "extend", "Y": "fill"}, + ], +) +def test_multi_axis_input(all_datasets, func, periodic, boundary): + ds, periodic_unused, expected_unused = all_datasets + grid = Grid(ds, periodic=periodic) + axes = list(grid.axes.keys()) + for varname in ["data_c", "data_g"]: + serial = ds[varname] + for axis in axes: + boundary_axis = boundary + if isinstance(boundary, dict): + boundary_axis = boundary[axis] + serial = getattr(grid, func)(serial, axis, boundary=boundary_axis) + full = getattr(grid, func)(ds[varname], axes, boundary=boundary) + xr.testing.assert_allclose(serial, full) + + +@pytest.mark.parametrize( + "func", + ["interp", "max", "min", "diff", "cumsum"], +) +@pytest.mark.parametrize( + "boundary", + [ + "fill", + pytest.param("extrapolate", marks=pytest.mark.xfail), + "extend", + {"X": "fill", "Y": "extend"}, + {"X": "extend", "Y": "fill"}, + ], +) +def test_dask_vs_eager(all_datasets, func, boundary): + ds, coords, metrics = datasets_grid_metric("C") + grid = Grid(ds, coords=coords) + grid_method = getattr(grid, func) + eager_result = grid_method(ds.tracer, "X", boundary=boundary) + + ds = ds.chunk({"xt": 1, "yt": 1, "time": 1, "zt": 1}) + grid = Grid(ds, coords=coords) + grid_method = getattr(grid, func) + dask_result = grid_method(ds.tracer, "X", boundary=boundary).compute() + + xr.testing.assert_allclose(dask_result, eager_result) + + +def test_grid_dict_input_boundary_fill(nonperiodic_1d): + """Test axis kwarg input functionality using dict input""" + ds, _, _ = nonperiodic_1d + grid_direct = Grid(ds, periodic=False, boundary="fill", fill_value=5) + grid_dict = Grid(ds, periodic=False, boundary={"X": "fill"}, fill_value={"X": 5}) + assert grid_direct.axes["X"].fill_value == grid_dict.axes["X"].fill_value + assert grid_direct.axes["X"].boundary == grid_dict.axes["X"].boundary + + +def test_invalid_boundary_error(): + ds = datasets["1d_left"] + with pytest.raises(ValueError): + Axis(ds, "X", boundary="bad") + with pytest.raises(ValueError): + Grid(ds, boundary="bad") + with pytest.raises(ValueError): + Grid(ds, boundary={"X": "bad"}) + with pytest.raises(ValueError): + Grid(ds, boundary={"X": 0}) + with pytest.raises(ValueError): + Grid(ds, boundary=0) + + +def test_invalid_fill_value_error(): + ds = datasets["1d_left"] + with pytest.raises(ValueError): + Axis(ds, "X", fill_value="x") + with pytest.raises(ValueError): + Grid(ds, fill_value="bad") + with pytest.raises(ValueError): + Grid(ds, fill_value={"X": "bad"}) + + +@pytest.mark.parametrize( + "funcname", + [ + "diff", + "interp", + "min", + "max", + "integrate", + "average", + "cumsum", + "cumint", + "derivative", + # TODO: we can get rid of many of these after the release. With the grid_ufunc logic many of these go through the same codepath + # e.g. diff/interp/min/max all are the same, so we can probably reduce this to diff, cumsum, integrate, derivative, cumint + ], +) +@pytest.mark.parametrize("gridtype", ["B", "C"]) +def test_keep_coords(funcname, gridtype): + ds, coords, metrics = datasets_grid_metric(gridtype) + ds = ds.assign_coords(yt_bis=ds["yt"], xt_bis=ds["xt"]) + grid = Grid(ds, coords=coords, metrics=metrics) + func = getattr(grid, funcname) + for axis_name in grid.axes.keys(): + result = func(ds.tracer, axis_name) + base_coords = list(result.dims) + augmented_coords = [c for c in ds.coords if set(ds[c].dims).issubset(result.dims) and c not in result.dims] + + if funcname in ["integrate", "average"]: + assert set(result.coords) == set(base_coords + augmented_coords) + else: + assert set(result.coords) == set(base_coords) + + # TODO: why is the behavior different for integrate and average? + if funcname not in ["integrate", "average"]: + result = func(ds.tracer, axis_name, keep_coords=False) + assert set(result.coords) == set(base_coords) + + result = func(ds.tracer, axis_name, keep_coords=True) + assert set(result.coords) == set(base_coords + augmented_coords) + + +def test_keep_coords_deprecation(): + ds, coords, metrics = datasets_grid_metric("B") + ds = ds.assign_coords(yt_bis=ds["yt"], xt_bis=ds["xt"]) + grid = Grid(ds, coords=coords, metrics=metrics) + for axis_name in grid.axes.keys(): + with pytest.warns(DeprecationWarning): + grid.diff(ds.tracer, axis_name, keep_coords=False) + + +def test_boundary_kwarg_same_as_grid_constructor_kwarg(): + ds = datasets["2d_left"] + grid1 = Grid(ds, periodic=False) + grid2 = Grid(ds, periodic=False, boundary={"X": "fill", "Y": "fill"}) + + actual1 = grid1.interp(ds.data_g, ("X", "Y"), boundary={"X": "fill", "Y": "fill"}) + actual2 = grid2.interp(ds.data_g, ("X", "Y")) + + xr.testing.assert_identical(actual1, actual2) + + +@pytest.mark.parametrize( + "metric_axes,metric_name", + [ + (["Y", "X"], "area_n"), + ("X", "dx_t"), + ("Y", "dy_ne"), + (["Y", "X"], "dy_n"), + (["X"], "tracer"), + ], +) +@pytest.mark.parametrize("periodic", [True, False]) +@pytest.mark.parametrize( + "boundary, boundary_expected", + [ + ({"X": "fill", "Y": "fill"}, {"X": "fill", "Y": "fill"}), + ({"X": "extend", "Y": "extend"}, {"X": "extend", "Y": "extend"}), + pytest.param( + {"X": "extrapolate", "Y": "extrapolate"}, + {"X": "extrapolate", "Y": "extrapolate"}, + marks=pytest.mark.xfail(reason="padding via extrapolation not yet supported in grid_ufunc refactor"), + ), + ("fill", {"X": "fill", "Y": "fill"}), + ("extend", {"X": "extend", "Y": "extend"}), + pytest.param( + "extrapolate", + {"X": "extrapolate", "Y": "extrapolate"}, + marks=pytest.mark.xfail(reason="padding via extrapolation not yet supported in grid_ufunc refactor"), + ), + ({"X": "extend", "Y": "fill"}, {"X": "extend", "Y": "fill"}), + pytest.param( + {"X": "extrapolate", "Y": "fill"}, + {"X": "extrapolate", "Y": "fill"}, + marks=pytest.mark.xfail(reason="padding via extrapolation not yet supported in grid_ufunc refactor"), + ), + pytest.param( + "fill", + {"X": "fill", "Y": "extend"}, + marks=pytest.mark.xfail, + id="boundary not equal to boundary_expected", + ), + ], +) +@pytest.mark.parametrize("fill_value", [None, 0.1]) +def test_interp_like(metric_axes, metric_name, periodic, boundary, boundary_expected, fill_value): + ds, coords, _ = datasets_grid_metric("C") + grid = Grid(ds, coords=coords, periodic=periodic) + grid.set_metrics(metric_axes, metric_name) + metric_available = grid._metrics.get(frozenset(metric_axes), None) + metric_available = metric_available[0] + interp_metric = grid.interp_like(metric_available, ds.u, boundary=boundary, fill_value=fill_value) + expected_metric = grid.interp(ds[metric_name], metric_axes, boundary=boundary_expected, fill_value=fill_value) + + xr.testing.assert_allclose(interp_metric, expected_metric) + + +def test_input_not_dims(): + data = np.random.rand(4, 5) + coord = np.random.rand(4, 5) + ds = xr.DataArray(data, dims=["x", "y"], coords={"c": (["x", "y"], coord)}).to_dataset(name="data") + msg = r"is not a dimension in the input dataset" + with pytest.raises(ValueError, match=msg): + Grid(ds, coords={"X": {"center": "c"}}) + + +def test_input_dim_notfound(): + data = np.random.rand(4, 5) + coord = np.random.rand(4, 5) + ds = xr.DataArray(data, dims=["x", "y"], coords={"c": (["x", "y"], coord)}).to_dataset(name="data") + msg = r"Could not find dimension `other` \(for the `center` position on axis `X`\) in input dataset." + with pytest.raises(ValueError, match=msg): + Grid(ds, coords={"X": {"center": "other"}}) + + +@pytest.mark.parametrize( + "funcname", + [ + "interp", + "diff", + "min", + "max", + "cumsum", + "derivative", + "cumint", + ], +) +@pytest.mark.parametrize( + "boundary", + ["fill", "extend"], +) +@pytest.mark.parametrize( + "fill_value", + [0, 10, None], +) +def test_boundary_global_input(funcname, boundary, fill_value): + """Test that globally defined boundary values result in + the same output as when the parameters are defined the grid methods + """ + ds, coords, metrics = datasets_grid_metric("C") + axis = "X" + # Test results by globally specifying fill value/boundary on grid object + grid_global = Grid( + ds, + coords=coords, + metrics=metrics, + periodic=False, + boundary=boundary, + fill_value=fill_value, + ) + func_global = getattr(grid_global, funcname) + global_result = func_global(ds.tracer, axis) + + # Test results by manually specifying fill value/boundary on grid method + grid_manual = Grid(ds, coords=coords, metrics=metrics, periodic=False, boundary=boundary) + func_manual = getattr(grid_manual, funcname) + manual_result = func_manual(ds.tracer, axis, boundary=boundary, fill_value=fill_value) + xr.testing.assert_allclose(global_result, manual_result) + + +class TestInputErrorGridMethods: + def test_multiple_keys_vector_input(self): + ds, _, _ = datasets_grid_metric("C") + grid = Grid(ds) + msg = "Vector components provided as dictionaries should contain exactly one key/value pair. .*?" + with pytest.raises( + ValueError, + match=msg, + ): + grid.diff({"X": xr.DataArray(), "Y": xr.DataArray()}, "X") + + def test_wrong_input_type_scalar(self): + ds, _, _ = datasets_grid_metric("C") + grid = Grid(ds) + msg = "All data arguments must be either a DataArray or Dictionary .*?" + with pytest.raises( + TypeError, + match=msg, + ): + grid.diff("not_a_dataarray", "X") + + def test_wrong_input_type_vector(self): + ds, _, _ = datasets_grid_metric("C") + grid = Grid(ds) + msg = "Dictionary inputs must have a DataArray as value. Got .*?" + with pytest.raises( + TypeError, + match=msg, + ): + grid.diff({"X": "not_a_dataarray"}, "X") + + def test_wrong_axis_vector_input_axis(self): + ds, _, _ = datasets_grid_metric("C") + grid = Grid(ds) + msg = "Vector component with unknown axis provided. Grid has axes .*?" + with pytest.raises( + ValueError, + match=msg, + ): + grid.diff({"wrong": xr.DataArray()}, "X") + + +class TestInputErrorApplyAsGridUfunc: + def test_multiple_keys_vector_input(self): + ds, _, _ = datasets_grid_metric("C") + grid = Grid(ds) + msg = "Vector components provided as dictionaries should contain exactly one key/value pair. .*?" + with pytest.raises( + ValueError, + match=msg, + ): + grid.apply_as_grid_ufunc(lambda x: x, {"X": xr.DataArray(), "Y": xr.DataArray()}, "X") + + def test_wrong_input_type_scalar(self): + ds, _, _ = datasets_grid_metric("C") + grid = Grid(ds) + msg = "All data arguments must be either a DataArray or Dictionary .*?" + with pytest.raises( + TypeError, + match=msg, + ): + grid.apply_as_grid_ufunc(lambda x: x, "not_a_dataarray", "X") + + def test_wrong_input_type_vector(self): + ds, _, _ = datasets_grid_metric("C") + grid = Grid(ds) + msg = "Dictionary inputs must have a DataArray as value. Got .*?" + with pytest.raises( + TypeError, + match=msg, + ): + grid.apply_as_grid_ufunc(lambda x: x, {"X": "not_a_dataarray"}, "X") + + def test_wrong_axis_vector_input_axis(self): + ds, _, _ = datasets_grid_metric("C") + grid = Grid(ds) + msg = "Vector component with unknown axis provided. Grid has axes .*?" + with pytest.raises( + ValueError, + match=msg, + ): + grid.apply_as_grid_ufunc(lambda x: x, {"wrong": xr.DataArray()}, "X") + + def test_vector_input_data_other_mismatch(self): + ds, _, _ = datasets_grid_metric("C") + grid = Grid(ds) + msg = "When providing multiple input arguments, `other_component`" " needs to provide one dictionary per input" + with pytest.raises( + ValueError, + match=msg, + ): + # Passing 3 args and 2 other components should fail. + grid.apply_as_grid_ufunc( + lambda x: x, + {"X": xr.DataArray()}, + {"Y": xr.DataArray()}, + {"Z": xr.DataArray()}, + axis="X", + other_component=[{"X": xr.DataArray()}, {"Y": xr.DataArray()}], + ) + + def test_wrong_input_type_vector_multi_input(self): + ds, _, _ = datasets_grid_metric("C") + grid = Grid(ds) + msg = "Dictionary inputs must have a DataArray as value. Got .*?" + with pytest.raises( + TypeError, + match=msg, + ): + # Passing 3 args and 2 other components should fail. + grid.apply_as_grid_ufunc( + lambda x: x, + {"X": xr.DataArray()}, + {"Y": "not_a_data_array"}, + axis="X", + other_component=[{"X": xr.DataArray()}, {"Y": xr.DataArray()}], + ) + + def test_wrong_axis_vector_input_axis_multi_input(self): + ds, _, _ = datasets_grid_metric("C") + grid = Grid(ds) + msg = "Vector component with unknown axis provided. Grid has axes .*?" + with pytest.raises( + ValueError, + match=msg, + ): + # Passing 3 args and 2 other components should fail. + grid.apply_as_grid_ufunc( + lambda x: x, + {"X": xr.DataArray()}, + {"Y": xr.DataArray()}, + axis="X", + other_component=[{"wrong": xr.DataArray()}, {"Y": xr.DataArray()}], + ) From f26a3082a87788beb110447187bd22cc33103811 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 1 Apr 2025 18:19:58 +0200 Subject: [PATCH 02/39] Adapt xgcm tests Remove tests that are no longer relevant --- tests/v4/test_grid.py | 696 +----------------------------------------- 1 file changed, 2 insertions(+), 694 deletions(-) diff --git a/tests/v4/test_grid.py b/tests/v4/test_grid.py index 4fbbc090a1..684c994e65 100644 --- a/tests/v4/test_grid.py +++ b/tests/v4/test_grid.py @@ -107,37 +107,6 @@ def test_get_position_name(all_datasets): assert axis._get_position_name(da) == (position, coord) -def test_axis_wrap_and_replace_2d(periodic_2d): - ds, periodic, expected = periodic_2d - axis_objs = _get_axes(ds) - - da_xc_yc = 0 * ds.XC * ds.YC + 1 - da_xc_yg = 0 * ds.XC * ds.YG + 1 - da_xg_yc = 0 * ds.XG * ds.YC + 1 - - da_xc_yg_test = axis_objs["Y"]._wrap_and_replace_coords(da_xc_yc, da_xc_yc.data, "left") - assert da_xc_yg.equals(da_xc_yg_test) - - da_xg_yc_test = axis_objs["X"]._wrap_and_replace_coords(da_xc_yc, da_xc_yc.data, "left") - assert da_xg_yc.equals(da_xg_yc_test) - - -def test_axis_wrap_and_replace_nonperiodic(nonperiodic_1d): - ds, periodic, expected = nonperiodic_1d - axis = Axis(ds, "X") - - da_c = 0 * ds.XC + 1 - da_g = 0 * ds.XG + 1 - - to = (set(expected["axes"]["X"].keys()) - {"center"}).pop() - - da_g_test = axis._wrap_and_replace_coords(da_c, da_g.data, to) - assert da_g.equals(da_g_test) - - da_c_test = axis._wrap_and_replace_coords(da_g, da_c.data, "center") - assert da_c.equals(da_c_test) - - # helper functions for padding arrays # this feels silly...I'm basically just re-coding the function in order to # test it @@ -151,271 +120,6 @@ def _pad_right(data, boundary, fill_value=0.0): return np.hstack([data, pad_val]) -@pytest.mark.parametrize( - "boundary", - [None, "extend", "fill", pytest.param("extrapolate", marks=pytest.mark.xfail)], -) -@pytest.mark.parametrize("from_center", [True, False]) -def test_axis_neighbor_pairs_nonperiodic_1d(nonperiodic_1d, boundary, from_center): - ds, periodic, expected = nonperiodic_1d - axis = Axis(ds, "X", periodic=periodic) - - # detect whether this is an outer or inner case - # outer --> dim_line_diff = 1 - # inner --> dim_line_diff = -1 - dim_len_diff = len(ds.XG) - len(ds.XC) - - if from_center: - to = (set(expected["axes"]["X"].keys()) - {"center"}).pop() - da = ds.data_c - else: - to = "center" - da = ds.data_g - - shift = expected.get("shift") or False - - # need boundary condition for everything but outer to center - if (boundary is None) and ( - dim_len_diff == 0 or (dim_len_diff == 1 and from_center) or (dim_len_diff == -1 and not from_center) - ): - with pytest.raises(ValueError): - data_left, data_right = axis._get_neighbor_data_pairs(da, to, boundary=boundary) - else: - data_left, data_right = axis._get_neighbor_data_pairs(da, to, boundary=boundary) - if ((dim_len_diff == 1) and not from_center) or ((dim_len_diff == -1) and from_center): - expected_left = da.data[:-1] - expected_right = da.data[1:] - elif ((dim_len_diff == 1) and from_center) or ((dim_len_diff == -1) and not from_center): - expected_left = _pad_left(da.data, boundary) - expected_right = _pad_right(da.data, boundary) - elif (shift and not from_center) or (not shift and from_center): - expected_right = da.data - expected_left = _pad_left(da.data, boundary)[:-1] - else: - expected_left = da.data - expected_right = _pad_right(da.data, boundary)[1:] - - np.testing.assert_allclose(data_left, expected_left) - np.testing.assert_allclose(data_right, expected_right) - - -@pytest.mark.parametrize("boundary", ["extend", "fill", pytest.param("extrapolate", marks=pytest.mark.xfail)]) -def test_axis_cumsum(nonperiodic_1d, boundary): - ds, periodic, expected = nonperiodic_1d - axis = Axis(ds, "X", periodic=periodic) - - axis_expected = expected["axes"]["X"] - - cumsum_g = axis.cumsum(ds.data_g, to="center", boundary=boundary) - assert cumsum_g.dims == ds.data_c.dims - # check default "to" - assert cumsum_g.equals(axis.cumsum(ds.data_g, boundary=boundary)) - - to = set(axis_expected).difference({"center"}).pop() - cumsum_c = axis.cumsum(ds.data_c, to=to, boundary=boundary) - assert cumsum_c.dims == ds.data_g.dims - # check default "to" - assert cumsum_c.equals(axis.cumsum(ds.data_c, boundary=boundary)) - - cumsum_c_raw = np.cumsum(ds.data_c.data) - cumsum_g_raw = np.cumsum(ds.data_g.data) - - if to == "right": - np.testing.assert_allclose(cumsum_c.data, cumsum_c_raw) - fill_value = 0.0 if boundary == "fill" else cumsum_g_raw[0] - np.testing.assert_allclose(cumsum_g.data, np.hstack([fill_value, cumsum_g_raw[:-1]])) - elif to == "left": - np.testing.assert_allclose(cumsum_g.data, cumsum_g_raw) - fill_value = 0.0 if boundary == "fill" else cumsum_c_raw[0] - np.testing.assert_allclose(cumsum_c.data, np.hstack([fill_value, cumsum_c_raw[:-1]])) - elif to == "inner": - np.testing.assert_allclose(cumsum_c.data, cumsum_c_raw[:-1]) - fill_value = 0.0 if boundary == "fill" else cumsum_g_raw[0] - np.testing.assert_allclose(cumsum_g.data, np.hstack([fill_value, cumsum_g_raw])) - elif to == "outer": - np.testing.assert_allclose(cumsum_g.data, cumsum_g_raw[:-1]) - fill_value = 0.0 if boundary == "fill" else cumsum_c_raw[0] - np.testing.assert_allclose(cumsum_c.data, np.hstack([fill_value, cumsum_c_raw])) - - # not much point doing this...we don't have the right test datasets - # to really test the errors - # other_positions = {'left', 'right', 'inner', 'outer'}.difference({to}) - # for pos in other_positions: - # with pytest.raises(KeyError): - # axis.cumsum(ds.data_c, to=pos, boundary=boundary) - - -@pytest.mark.parametrize( - "varname, axis_name, to, roll, roll_axis, swap_order", - [ - ("data_c", "X", "left", 1, 1, False), - ("data_c", "Y", "left", 1, 0, False), - ("data_g", "X", "center", -1, 1, True), - ("data_g", "Y", "center", -1, 0, True), - ], -) -def test_axis_neighbor_pairs_2d(periodic_2d, varname, axis_name, to, roll, roll_axis, swap_order): - ds, _, _ = periodic_2d - - axis = Axis(ds, axis_name) - - data = ds[varname] - data_left, data_right = axis._get_neighbor_data_pairs(data, to) - if swap_order: - data_left, data_right = data_right, data_left - np.testing.assert_allclose(data_left, np.roll(data.data, roll, axis=roll_axis)) - np.testing.assert_allclose(data_right, data.data) - - -@pytest.mark.parametrize("boundary", ["extend", "fill", pytest.param("extrapolate", marks=pytest.mark.xfail)]) -@pytest.mark.parametrize("from_center", [True, False]) -def test_axis_diff_and_interp_nonperiodic_1d(nonperiodic_1d, boundary, from_center): - ds, periodic, expected = nonperiodic_1d - axis = Axis(ds, "X", periodic=periodic) - - dim_len_diff = len(ds.XG) - len(ds.XC) - - if from_center: - to = (set(expected["axes"]["X"].keys()) - {"center"}).pop() - coord_to = "XG" - da = ds.data_c - else: - to = "center" - coord_to = "XC" - da = ds.data_g - - shift = expected.get("shift") or False - - data = da.data - if (dim_len_diff == 1 and not from_center) or (dim_len_diff == -1 and from_center): - data_left = data[:-1] - data_right = data[1:] - elif (dim_len_diff == 1 and from_center) or (dim_len_diff == -1 and not from_center): - data_left = _pad_left(data, boundary) - data_right = _pad_right(data, boundary) - elif (shift and not from_center) or (not shift and from_center): - data_left = _pad_left(data[:-1], boundary) - data_right = data - else: - data_left = data - data_right = _pad_right(data[1:], boundary) - - # interpolate - data_interp_expected = xr.DataArray( - 0.5 * (data_left + data_right), dims=[coord_to], coords={coord_to: ds[coord_to]} - ) - data_interp = axis.interp(da, to, boundary=boundary) - assert data_interp_expected.equals(data_interp) - # check without "to" specified - assert data_interp.equals(axis.interp(da, boundary=boundary)) - - # difference - data_diff_expected = xr.DataArray(data_right - data_left, dims=[coord_to], coords={coord_to: ds[coord_to]}) - data_diff = axis.diff(da, to, boundary=boundary) - assert data_diff_expected.equals(data_diff) - # check without "to" specified - assert data_diff.equals(axis.diff(da, boundary=boundary)) - - # max - data_max_expected = xr.DataArray( - np.maximum(data_right, data_left), - dims=[coord_to], - coords={coord_to: ds[coord_to]}, - ) - data_max = axis.max(da, to, boundary=boundary) - assert data_max_expected.equals(data_max) - # check without "to" specified - assert data_max.equals(axis.max(da, boundary=boundary)) - - # min - data_min_expected = xr.DataArray( - np.minimum(data_right, data_left), - dims=[coord_to], - coords={coord_to: ds[coord_to]}, - ) - data_min = axis.min(da, to, boundary=boundary) - assert data_min_expected.equals(data_min) - # check without "to" specified - assert data_min.equals(axis.min(da, boundary=boundary)) - - -# this mega test covers all options for 2D data - - -@pytest.mark.parametrize("boundary", ["extend", "fill", pytest.param("extrapolate", marks=pytest.mark.xfail)]) -@pytest.mark.parametrize("axis_name", ["X", "Y"]) -@pytest.mark.parametrize("varname, this, to", [("data_c", "center", "left"), ("data_g", "left", "center")]) -def test_axis_diff_and_interp_nonperiodic_2d(all_2d, boundary, axis_name, varname, this, to): - ds, periodic, _ = all_2d - - try: - ax_periodic = axis_name in periodic - except TypeError: - ax_periodic = periodic - - boundary_arg = boundary if not ax_periodic else None - axis = Axis(ds, axis_name, periodic=ax_periodic, boundary=boundary_arg) - da = ds[varname] - - # everything is left shift - data = ds[varname].data - - axis_num = da.get_axis_num(axis.coords[this]) - - # lookups for numpy.pad - numpy_pad_arg = {"extend": "edge", "fill": "constant"} - # args for numpy.pad - pad_left = (1, 0) - pad_right = (0, 1) - pad_none = (0, 0) - - if this == "center": - if ax_periodic: - data_left = np.roll(data, 1, axis=axis_num) - else: - pad_width = [pad_left if i == axis_num else pad_none for i in range(data.ndim)] - the_slice = tuple([slice(0, -1) if i == axis_num else slice(None) for i in range(data.ndim)]) - data_left = np.pad(data, pad_width, numpy_pad_arg[boundary])[the_slice] - data_right = data - elif this == "left": - if ax_periodic: - data_left = data - data_right = np.roll(data, -1, axis=axis_num) - else: - pad_width = [pad_right if i == axis_num else pad_none for i in range(data.ndim)] - the_slice = tuple([slice(1, None) if i == axis_num else slice(None) for i in range(data.ndim)]) - data_right = np.pad(data, pad_width, numpy_pad_arg[boundary])[the_slice] - data_left = data - - data_interp = 0.5 * (data_left + data_right) - data_diff = data_right - data_left - - # determine new dims - dims = list(da.dims) - dims[axis_num] = axis.coords[to] - coords = {dim: ds[dim] for dim in dims} - - da_interp_expected = xr.DataArray(data_interp, dims=dims, coords=coords) - da_diff_expected = xr.DataArray(data_diff, dims=dims, coords=coords) - - da_interp = axis.interp(da, to) - da_diff = axis.diff(da, to) - - assert da_interp_expected.equals(da_interp) - assert da_diff_expected.equals(da_diff) - - if boundary_arg is not None: - if boundary == "extend": - bad_boundary = "fill" - elif boundary == "fill": - bad_boundary = "extend" - - da_interp_wrong = axis.interp(da, to, boundary=bad_boundary) - assert not da_interp_expected.equals(da_interp_wrong) - da_diff_wrong = axis.diff(da, to, boundary=bad_boundary) - assert not da_diff_expected.equals(da_diff_wrong) - - def test_axis_errors(): ds = datasets["1d_left"] @@ -444,14 +148,6 @@ def test_axis_errors(): ): _ = Axis(ds_chopped, "X", periodic=True) - ax = Axis(ds, "X", periodic=True) - - with pytest.raises(ValueError, match="Can't get neighbor pairs for the same position."): - ax.interp(ds.data_c, "center") - - with pytest.raises(ValueError, match="This axis doesn't contain a `right` position"): - ax.interp(ds.data_c, "right") - # This case is broken, need to fix! # with pytest.raises( # ValueError, match="`boundary=fill` is not allowed " "with periodic axis X." @@ -511,17 +207,12 @@ def test_create_grid_no_comodo(all_datasets): def test_grid_no_coords(periodic_1d): - """Ensure that you can use xgcm with Xarray datasets that don't have dimension coordinates.""" + """Ensure that you can use Grid with Xarray datasets that don't have dimension coordinates.""" ds, periodic, expected = periodic_1d ds_nocoords = ds.drop_vars(list(ds.dims.keys())) coords = expected["axes"] - grid = Grid(ds_nocoords, periodic=periodic, coords=coords) - - diff = grid.diff(ds["data_c"], "X") - assert len(diff.coords) == 0 - interp = grid.interp(ds["data_c"], "X") - assert len(interp.coords) == 0 + Grid(ds_nocoords, periodic=periodic, coords=coords) def test_grid_repr(all_datasets): @@ -531,92 +222,6 @@ def test_grid_repr(all_datasets): assert r[0] == "" -def test_grid_ops(all_datasets): - """ - Check that we get the same answer using Axis or Grid objects - """ - ds, periodic, _ = all_datasets - grid = Grid(ds, periodic=periodic) - - for axis_name in grid.axes.keys(): - try: - ax_periodic = axis_name in periodic - except TypeError: - ax_periodic = periodic - axis = Axis(ds, axis_name, periodic=ax_periodic) - - bcs = [None] if ax_periodic else ["fill", "extend"] - for varname in ["data_c", "data_g"]: - for boundary in bcs: - da_interp = grid.interp(ds[varname], axis_name, boundary=boundary) - da_interp_ax = axis.interp(ds[varname], boundary=boundary) - assert da_interp.equals(da_interp_ax) - - da_diff = grid.diff(ds[varname], axis_name, boundary=boundary) - da_diff_ax = axis.diff(ds[varname], boundary=boundary) - assert da_diff.equals(da_diff_ax) - - if boundary is not None: - da_cumsum = grid.cumsum(ds[varname], axis_name, boundary=boundary) - da_cumsum_ax = axis.cumsum(ds[varname], boundary=boundary) - assert da_cumsum.equals(da_cumsum_ax) - - -@pytest.mark.parametrize("func", ["interp", "max", "min", "diff", "cumsum"]) -@pytest.mark.parametrize("periodic", ["True", "False", ["X"], ["Y"], ["X", "Y"]]) -@pytest.mark.parametrize( - "boundary", - [ - "fill", - pytest.param("extrapolate", marks=pytest.mark.xfail(strict=True)), - "extend", - {"X": "fill", "Y": "extend"}, - {"X": "extend", "Y": "fill"}, - ], -) -def test_multi_axis_input(all_datasets, func, periodic, boundary): - ds, periodic_unused, expected_unused = all_datasets - grid = Grid(ds, periodic=periodic) - axes = list(grid.axes.keys()) - for varname in ["data_c", "data_g"]: - serial = ds[varname] - for axis in axes: - boundary_axis = boundary - if isinstance(boundary, dict): - boundary_axis = boundary[axis] - serial = getattr(grid, func)(serial, axis, boundary=boundary_axis) - full = getattr(grid, func)(ds[varname], axes, boundary=boundary) - xr.testing.assert_allclose(serial, full) - - -@pytest.mark.parametrize( - "func", - ["interp", "max", "min", "diff", "cumsum"], -) -@pytest.mark.parametrize( - "boundary", - [ - "fill", - pytest.param("extrapolate", marks=pytest.mark.xfail), - "extend", - {"X": "fill", "Y": "extend"}, - {"X": "extend", "Y": "fill"}, - ], -) -def test_dask_vs_eager(all_datasets, func, boundary): - ds, coords, metrics = datasets_grid_metric("C") - grid = Grid(ds, coords=coords) - grid_method = getattr(grid, func) - eager_result = grid_method(ds.tracer, "X", boundary=boundary) - - ds = ds.chunk({"xt": 1, "yt": 1, "time": 1, "zt": 1}) - grid = Grid(ds, coords=coords) - grid_method = getattr(grid, func) - dask_result = grid_method(ds.tracer, "X", boundary=boundary).compute() - - xr.testing.assert_allclose(dask_result, eager_result) - - def test_grid_dict_input_boundary_fill(nonperiodic_1d): """Test axis kwarg input functionality using dict input""" ds, _, _ = nonperiodic_1d @@ -650,122 +255,6 @@ def test_invalid_fill_value_error(): Grid(ds, fill_value={"X": "bad"}) -@pytest.mark.parametrize( - "funcname", - [ - "diff", - "interp", - "min", - "max", - "integrate", - "average", - "cumsum", - "cumint", - "derivative", - # TODO: we can get rid of many of these after the release. With the grid_ufunc logic many of these go through the same codepath - # e.g. diff/interp/min/max all are the same, so we can probably reduce this to diff, cumsum, integrate, derivative, cumint - ], -) -@pytest.mark.parametrize("gridtype", ["B", "C"]) -def test_keep_coords(funcname, gridtype): - ds, coords, metrics = datasets_grid_metric(gridtype) - ds = ds.assign_coords(yt_bis=ds["yt"], xt_bis=ds["xt"]) - grid = Grid(ds, coords=coords, metrics=metrics) - func = getattr(grid, funcname) - for axis_name in grid.axes.keys(): - result = func(ds.tracer, axis_name) - base_coords = list(result.dims) - augmented_coords = [c for c in ds.coords if set(ds[c].dims).issubset(result.dims) and c not in result.dims] - - if funcname in ["integrate", "average"]: - assert set(result.coords) == set(base_coords + augmented_coords) - else: - assert set(result.coords) == set(base_coords) - - # TODO: why is the behavior different for integrate and average? - if funcname not in ["integrate", "average"]: - result = func(ds.tracer, axis_name, keep_coords=False) - assert set(result.coords) == set(base_coords) - - result = func(ds.tracer, axis_name, keep_coords=True) - assert set(result.coords) == set(base_coords + augmented_coords) - - -def test_keep_coords_deprecation(): - ds, coords, metrics = datasets_grid_metric("B") - ds = ds.assign_coords(yt_bis=ds["yt"], xt_bis=ds["xt"]) - grid = Grid(ds, coords=coords, metrics=metrics) - for axis_name in grid.axes.keys(): - with pytest.warns(DeprecationWarning): - grid.diff(ds.tracer, axis_name, keep_coords=False) - - -def test_boundary_kwarg_same_as_grid_constructor_kwarg(): - ds = datasets["2d_left"] - grid1 = Grid(ds, periodic=False) - grid2 = Grid(ds, periodic=False, boundary={"X": "fill", "Y": "fill"}) - - actual1 = grid1.interp(ds.data_g, ("X", "Y"), boundary={"X": "fill", "Y": "fill"}) - actual2 = grid2.interp(ds.data_g, ("X", "Y")) - - xr.testing.assert_identical(actual1, actual2) - - -@pytest.mark.parametrize( - "metric_axes,metric_name", - [ - (["Y", "X"], "area_n"), - ("X", "dx_t"), - ("Y", "dy_ne"), - (["Y", "X"], "dy_n"), - (["X"], "tracer"), - ], -) -@pytest.mark.parametrize("periodic", [True, False]) -@pytest.mark.parametrize( - "boundary, boundary_expected", - [ - ({"X": "fill", "Y": "fill"}, {"X": "fill", "Y": "fill"}), - ({"X": "extend", "Y": "extend"}, {"X": "extend", "Y": "extend"}), - pytest.param( - {"X": "extrapolate", "Y": "extrapolate"}, - {"X": "extrapolate", "Y": "extrapolate"}, - marks=pytest.mark.xfail(reason="padding via extrapolation not yet supported in grid_ufunc refactor"), - ), - ("fill", {"X": "fill", "Y": "fill"}), - ("extend", {"X": "extend", "Y": "extend"}), - pytest.param( - "extrapolate", - {"X": "extrapolate", "Y": "extrapolate"}, - marks=pytest.mark.xfail(reason="padding via extrapolation not yet supported in grid_ufunc refactor"), - ), - ({"X": "extend", "Y": "fill"}, {"X": "extend", "Y": "fill"}), - pytest.param( - {"X": "extrapolate", "Y": "fill"}, - {"X": "extrapolate", "Y": "fill"}, - marks=pytest.mark.xfail(reason="padding via extrapolation not yet supported in grid_ufunc refactor"), - ), - pytest.param( - "fill", - {"X": "fill", "Y": "extend"}, - marks=pytest.mark.xfail, - id="boundary not equal to boundary_expected", - ), - ], -) -@pytest.mark.parametrize("fill_value", [None, 0.1]) -def test_interp_like(metric_axes, metric_name, periodic, boundary, boundary_expected, fill_value): - ds, coords, _ = datasets_grid_metric("C") - grid = Grid(ds, coords=coords, periodic=periodic) - grid.set_metrics(metric_axes, metric_name) - metric_available = grid._metrics.get(frozenset(metric_axes), None) - metric_available = metric_available[0] - interp_metric = grid.interp_like(metric_available, ds.u, boundary=boundary, fill_value=fill_value) - expected_metric = grid.interp(ds[metric_name], metric_axes, boundary=boundary_expected, fill_value=fill_value) - - xr.testing.assert_allclose(interp_metric, expected_metric) - - def test_input_not_dims(): data = np.random.rand(4, 5) coord = np.random.rand(4, 5) @@ -782,184 +271,3 @@ def test_input_dim_notfound(): msg = r"Could not find dimension `other` \(for the `center` position on axis `X`\) in input dataset." with pytest.raises(ValueError, match=msg): Grid(ds, coords={"X": {"center": "other"}}) - - -@pytest.mark.parametrize( - "funcname", - [ - "interp", - "diff", - "min", - "max", - "cumsum", - "derivative", - "cumint", - ], -) -@pytest.mark.parametrize( - "boundary", - ["fill", "extend"], -) -@pytest.mark.parametrize( - "fill_value", - [0, 10, None], -) -def test_boundary_global_input(funcname, boundary, fill_value): - """Test that globally defined boundary values result in - the same output as when the parameters are defined the grid methods - """ - ds, coords, metrics = datasets_grid_metric("C") - axis = "X" - # Test results by globally specifying fill value/boundary on grid object - grid_global = Grid( - ds, - coords=coords, - metrics=metrics, - periodic=False, - boundary=boundary, - fill_value=fill_value, - ) - func_global = getattr(grid_global, funcname) - global_result = func_global(ds.tracer, axis) - - # Test results by manually specifying fill value/boundary on grid method - grid_manual = Grid(ds, coords=coords, metrics=metrics, periodic=False, boundary=boundary) - func_manual = getattr(grid_manual, funcname) - manual_result = func_manual(ds.tracer, axis, boundary=boundary, fill_value=fill_value) - xr.testing.assert_allclose(global_result, manual_result) - - -class TestInputErrorGridMethods: - def test_multiple_keys_vector_input(self): - ds, _, _ = datasets_grid_metric("C") - grid = Grid(ds) - msg = "Vector components provided as dictionaries should contain exactly one key/value pair. .*?" - with pytest.raises( - ValueError, - match=msg, - ): - grid.diff({"X": xr.DataArray(), "Y": xr.DataArray()}, "X") - - def test_wrong_input_type_scalar(self): - ds, _, _ = datasets_grid_metric("C") - grid = Grid(ds) - msg = "All data arguments must be either a DataArray or Dictionary .*?" - with pytest.raises( - TypeError, - match=msg, - ): - grid.diff("not_a_dataarray", "X") - - def test_wrong_input_type_vector(self): - ds, _, _ = datasets_grid_metric("C") - grid = Grid(ds) - msg = "Dictionary inputs must have a DataArray as value. Got .*?" - with pytest.raises( - TypeError, - match=msg, - ): - grid.diff({"X": "not_a_dataarray"}, "X") - - def test_wrong_axis_vector_input_axis(self): - ds, _, _ = datasets_grid_metric("C") - grid = Grid(ds) - msg = "Vector component with unknown axis provided. Grid has axes .*?" - with pytest.raises( - ValueError, - match=msg, - ): - grid.diff({"wrong": xr.DataArray()}, "X") - - -class TestInputErrorApplyAsGridUfunc: - def test_multiple_keys_vector_input(self): - ds, _, _ = datasets_grid_metric("C") - grid = Grid(ds) - msg = "Vector components provided as dictionaries should contain exactly one key/value pair. .*?" - with pytest.raises( - ValueError, - match=msg, - ): - grid.apply_as_grid_ufunc(lambda x: x, {"X": xr.DataArray(), "Y": xr.DataArray()}, "X") - - def test_wrong_input_type_scalar(self): - ds, _, _ = datasets_grid_metric("C") - grid = Grid(ds) - msg = "All data arguments must be either a DataArray or Dictionary .*?" - with pytest.raises( - TypeError, - match=msg, - ): - grid.apply_as_grid_ufunc(lambda x: x, "not_a_dataarray", "X") - - def test_wrong_input_type_vector(self): - ds, _, _ = datasets_grid_metric("C") - grid = Grid(ds) - msg = "Dictionary inputs must have a DataArray as value. Got .*?" - with pytest.raises( - TypeError, - match=msg, - ): - grid.apply_as_grid_ufunc(lambda x: x, {"X": "not_a_dataarray"}, "X") - - def test_wrong_axis_vector_input_axis(self): - ds, _, _ = datasets_grid_metric("C") - grid = Grid(ds) - msg = "Vector component with unknown axis provided. Grid has axes .*?" - with pytest.raises( - ValueError, - match=msg, - ): - grid.apply_as_grid_ufunc(lambda x: x, {"wrong": xr.DataArray()}, "X") - - def test_vector_input_data_other_mismatch(self): - ds, _, _ = datasets_grid_metric("C") - grid = Grid(ds) - msg = "When providing multiple input arguments, `other_component`" " needs to provide one dictionary per input" - with pytest.raises( - ValueError, - match=msg, - ): - # Passing 3 args and 2 other components should fail. - grid.apply_as_grid_ufunc( - lambda x: x, - {"X": xr.DataArray()}, - {"Y": xr.DataArray()}, - {"Z": xr.DataArray()}, - axis="X", - other_component=[{"X": xr.DataArray()}, {"Y": xr.DataArray()}], - ) - - def test_wrong_input_type_vector_multi_input(self): - ds, _, _ = datasets_grid_metric("C") - grid = Grid(ds) - msg = "Dictionary inputs must have a DataArray as value. Got .*?" - with pytest.raises( - TypeError, - match=msg, - ): - # Passing 3 args and 2 other components should fail. - grid.apply_as_grid_ufunc( - lambda x: x, - {"X": xr.DataArray()}, - {"Y": "not_a_data_array"}, - axis="X", - other_component=[{"X": xr.DataArray()}, {"Y": xr.DataArray()}], - ) - - def test_wrong_axis_vector_input_axis_multi_input(self): - ds, _, _ = datasets_grid_metric("C") - grid = Grid(ds) - msg = "Vector component with unknown axis provided. Grid has axes .*?" - with pytest.raises( - ValueError, - match=msg, - ): - # Passing 3 args and 2 other components should fail. - grid.apply_as_grid_ufunc( - lambda x: x, - {"X": xr.DataArray()}, - {"Y": xr.DataArray()}, - axis="X", - other_component=[{"wrong": xr.DataArray()}, {"Y": xr.DataArray()}], - ) From b181643e5a01ea9e9958e6b04daff794168702b4 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 1 Apr 2025 18:27:03 +0200 Subject: [PATCH 03/39] Remove unused functions and TODOs --- pyproject.toml | 1 + tests/v4/test_grid.py | 39 +++++++++++++-------------------------- 2 files changed, 14 insertions(+), 26 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1120a148c3..d890114e53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -188,6 +188,7 @@ ignore = [ "D205", # do not use bare except, specify exception instead "E722", + "F811", # TODO: These bugbear issues are to be resolved diff --git a/tests/v4/test_grid.py b/tests/v4/test_grid.py index 684c994e65..ca7251675d 100644 --- a/tests/v4/test_grid.py +++ b/tests/v4/test_grid.py @@ -3,15 +3,16 @@ import xarray as xr from parcels.gridv4 import Axis, Grid - -from tests.v4.datasets import all_2d # noqa: F401 -from tests.v4.datasets import all_datasets # noqa: F401 -from tests.v4.datasets import datasets # noqa: F401 -from tests.v4.datasets import datasets_grid_metric # noqa: F401 -from tests.v4.datasets import nonperiodic_1d # noqa: F401 -from tests.v4.datasets import nonperiodic_2d # noqa: F401 -from tests.v4.datasets import periodic_1d # noqa: F401 -from tests.v4.datasets import periodic_2d # noqa: F401 +from tests.v4.datasets import ( + all_2d, # noqa: F401 + all_datasets, # noqa: F401 + datasets, + datasets_grid_metric, # noqa: F401 + nonperiodic_1d, # noqa: F401 + nonperiodic_2d, # noqa: F401 + periodic_1d, # noqa: F401 + periodic_2d, # noqa: F401 +) # helper function to produce axes from datasets @@ -81,7 +82,7 @@ def test_create_axis_no_coords(all_datasets): ax1 = axis_objs[axis_name] assert ax1.name == ax2.name - for pos, coord in ax1.coords.items(): + for pos in ax1.coords.keys(): assert pos in ax2.coords assert ax1._periodic == ax2._periodic assert ax1._default_shifts == ax2._default_shifts @@ -91,35 +92,21 @@ def test_create_axis_no_coords(all_datasets): def test_axis_repr(all_datasets): ds, periodic, expected = all_datasets axis_objs = _get_axes(ds) - for ax_name, axis in axis_objs.items(): + for axis in axis_objs.values(): r = repr(axis).split("\n") assert r[0].startswith(" Date: Tue, 1 Apr 2025 18:32:27 +0200 Subject: [PATCH 04/39] Rename file to xgcm_datasets.py --- tests/v4/test_grid.py | 2 +- tests/v4/{datasets.py => xgcm_datasets.py} | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) rename tests/v4/{datasets.py => xgcm_datasets.py} (99%) diff --git a/tests/v4/test_grid.py b/tests/v4/test_grid.py index ca7251675d..476b7c0c56 100644 --- a/tests/v4/test_grid.py +++ b/tests/v4/test_grid.py @@ -3,7 +3,7 @@ import xarray as xr from parcels.gridv4 import Axis, Grid -from tests.v4.datasets import ( +from tests.v4.xgcm_datasets import ( all_2d, # noqa: F401 all_datasets, # noqa: F401 datasets, diff --git a/tests/v4/datasets.py b/tests/v4/xgcm_datasets.py similarity index 99% rename from tests/v4/datasets.py rename to tests/v4/xgcm_datasets.py index 135021c8c5..d875848b25 100644 --- a/tests/v4/datasets.py +++ b/tests/v4/xgcm_datasets.py @@ -242,7 +242,8 @@ def nonperiodic_2d(request): def datasets_grid_metric(grid_type): """Uniform grid test dataset. - Should eventually be extended to nonuniform grid""" + Should eventually be extended to nonuniform grid + """ xt = np.arange(4) xu = xt + 0.5 yt = np.arange(5) From b2ffce7ccbbf9a541c5d9b9a842958b792c931f0 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 1 Apr 2025 18:36:00 +0200 Subject: [PATCH 05/39] run pre-commit --- parcels/comodo.py | 28 +++++++--------------------- parcels/gridv4.py | 46 ++++++++++++++-------------------------------- 2 files changed, 21 insertions(+), 53 deletions(-) diff --git a/parcels/comodo.py b/parcels/comodo.py index c014b7f031..b20ae8a969 100644 --- a/parcels/comodo.py +++ b/parcels/comodo.py @@ -15,7 +15,6 @@ def assert_valid_comodo(ds): ---------- ds : xarray.dataset """ - # TODO: implement assert True @@ -42,7 +41,6 @@ def get_axis_coords(ds, axis_name): coord_name : list The names of the coordinate matching that axis """ - coord_names = [] for d in ds.dims: axis = ds[d].attrs.get("axis") @@ -74,24 +72,16 @@ def _maybe_fix_type(attr): except TypeError: return True - axis_shift = { - name: _maybe_fix_type(coord.attrs.get("c_grid_axis_shift")) - for name, coord in coords.items() - } + axis_shift = {name: _maybe_fix_type(coord.attrs.get("c_grid_axis_shift")) for name, coord in coords.items()} coord_len = {name: len(coord) for name, coord in coords.items()} # look for the center coord, which is required # this list will potentially contain "center", "inner", and "outer" points - coords_without_axis_shift = { - name: coord_len[name] for name, shift in axis_shift.items() if not shift - } + coords_without_axis_shift = {name: coord_len[name] for name, shift in axis_shift.items() if not shift} if len(coords_without_axis_shift) == 0: raise ValueError("Couldn't find a center coordinate for axis %s" % axis_name) elif len(coords_without_axis_shift) > 1: - raise ValueError( - "Found two coordinates without " - "`c_grid_axis_shift` attribute for axis %s" % axis_name - ) + raise ValueError("Found two coordinates without " "`c_grid_axis_shift` attribute for axis %s" % axis_name) center_coord_name = list(coords_without_axis_shift)[0] # the length of the center coord is key to decoding the other coords axis_len = coord_len[center_coord_name] @@ -114,16 +104,14 @@ def _maybe_fix_type(attr): axis_coords["left"] = name else: raise ValueError( - "Left coordinate %s has incompatible " - "length %g (axis_len=%g)" % (name, clen, axis_len) + "Left coordinate %s has incompatible " "length %g (axis_len=%g)" % (name, clen, axis_len) ) elif shift == axis_shift_right: if clen == axis_len: axis_coords["right"] = name else: raise ValueError( - "Right coordinate %s has incompatible " - "length %g (axis_len=%g)" % (name, clen, axis_len) + "Right coordinate %s has incompatible " "length %g (axis_len=%g)" % (name, clen, axis_len) ) else: if shift not in valid_axis_shifts: @@ -133,13 +121,11 @@ def _maybe_fix_type(attr): raise ValueError( "Coordinate %s has invalid " "`c_grid_axis_shift` attribute `%s`. " - "`c_grid_axis_shift` must be one of: %s" - % (name, repr(shift), valids) + "`c_grid_axis_shift` must be one of: %s" % (name, repr(shift), valids) ) else: raise ValueError( - "Coordinate %s has missing " - "`c_grid_axis_shift` attribute `%s`" % (name, repr(shift)) + "Coordinate %s has missing " "`c_grid_axis_shift` attribute `%s`" % (name, repr(shift)) ) return axis_coords diff --git a/parcels/gridv4.py b/parcels/gridv4.py index 7fd1f1e52b..9a3efc8afa 100644 --- a/parcels/gridv4.py +++ b/parcels/gridv4.py @@ -1,17 +1,8 @@ """This Grid object is adapted from xgcm.Grid, removing a lot of the code that is not needed for Parcels.""" -import functools -import inspect -import itertools -import operator import warnings from collections import OrderedDict - -import numpy as np -import xarray as xr -from dask.array import Array as Dask_Array - -from . import comodo +from collections.abc import Iterable # from .duck_array_ops import _apply_boundary_condition, _pad_array, concatenate # from .grid_ufunc import ( @@ -27,17 +18,10 @@ # from .padding import pad from typing import ( Any, - Callable, - Dict, - Iterable, - List, - Mapping, - Optional, - Sequence, - Tuple, - Union, ) +from . import comodo + try: import numba # type: ignore @@ -134,11 +118,10 @@ def __init__( fill_value : float, optional The value to use in the boundary condition when `boundary='fill'`. - REFERENCES + References ---------- .. [1] Comodo Conventions https://web.archive.org/web/20160417032300/http://pycomodo.forge.imag.fr/norm.html """ - self._ds = ds self.name = axis_name self._periodic = periodic @@ -253,7 +236,7 @@ def _get_position_name(self, da): if coord_name in da.dims: return position, coord_name - raise KeyError("None of the DataArray's dims %s were found in axis " "coords." % repr(da.dims)) + raise KeyError("None of the DataArray's dims %s were found in axis coords." % repr(da.dims)) def _get_axis_dim_num(self, da): """Return the dimension number of the axis coordinate in a DataArray.""" @@ -334,7 +317,7 @@ def __init__( Optionally a dict mapping axis name to seperate values for each axis can be passed. - REFERENCES + References ---------- .. [1] Comodo Conventions https://web.archive.org/web/20160417032300/http://pycomodo.forge.imag.fr/norm.html """ @@ -479,11 +462,11 @@ def __init__( def _as_axis_kwarg_mapping( self, - kwargs: Union[Any, Dict[str, Any]], - axes: Optional[Iterable[str]] = None, + kwargs: Any | dict[str, Any], + axes: Iterable[str] | None = None, ax_property_name=None, - default_value: Optional[Any] = None, - ) -> Dict[str, Any]: + default_value: Any | None = None, + ) -> dict[str, Any]: """Convert kwarg input into dict for each available axis E.g. for a grid with 2 axes for the keyword argument `periodic` periodic = True --> periodic = {'X': True, 'Y':True} @@ -493,7 +476,7 @@ def _as_axis_kwarg_mapping( if axes is None: axes = self.axes - parsed_kwargs: Dict[str, Any] = dict() + parsed_kwargs: dict[str, Any] = dict() if isinstance(kwargs, dict): parsed_kwargs = kwargs @@ -526,9 +509,8 @@ def _assign_face_connections(self, fc): """Check a dictionary of face connections to make sure all the links are consistent. """ - if len(fc) > 1: - raise ValueError("Only one face dimension is supported for now. " "Instead found %r" % repr(fc.keys())) + raise ValueError("Only one face dimension is supported for now. Instead found %r" % repr(fc.keys())) # we will populate this with the axes we find in face_connections axis_connections = {} @@ -563,9 +545,9 @@ def check_neighbor(link, position): if ax_n not in self.axes: raise KeyError("axis %r is not a valid axis" % ax_n) if idx not in self._ds[facedim].values: - raise IndexError("%r is not a valid index for face" "dimension %r" % (idx, facedim)) + raise IndexError("%r is not a valid index for face dimension %r" % (idx, facedim)) if idx_n not in self._ds[facedim].values: - raise IndexError("%r is not a valid index for face" "dimension %r" % (idx, facedim)) + raise IndexError("%r is not a valid index for face dimension %r" % (idx, facedim)) # check for consistent links from / to neighbor if (idx_n != fidx) or (ax_n != axis) or (rev_n != rev): raise ValueError( From 5e9a145682e4ec6b00588f700bcf8c3f33fc3ca9 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 1 Apr 2025 18:42:23 +0200 Subject: [PATCH 06/39] Move v4 grid files --- parcels/v4/__init__.py | 0 parcels/{ => v4}/comodo.py | 12 ++++-------- parcels/{gridv4.py => v4/grid.py} | 7 ------- tests/v4/test_grid.py | 2 +- 4 files changed, 5 insertions(+), 16 deletions(-) create mode 100644 parcels/v4/__init__.py rename parcels/{ => v4}/comodo.py (88%) rename parcels/{gridv4.py => v4/grid.py} (99%) diff --git a/parcels/v4/__init__.py b/parcels/v4/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/parcels/comodo.py b/parcels/v4/comodo.py similarity index 88% rename from parcels/comodo.py rename to parcels/v4/comodo.py index b20ae8a969..1c4a5a6f29 100644 --- a/parcels/comodo.py +++ b/parcels/v4/comodo.py @@ -81,7 +81,7 @@ def _maybe_fix_type(attr): if len(coords_without_axis_shift) == 0: raise ValueError("Couldn't find a center coordinate for axis %s" % axis_name) elif len(coords_without_axis_shift) > 1: - raise ValueError("Found two coordinates without " "`c_grid_axis_shift` attribute for axis %s" % axis_name) + raise ValueError("Found two coordinates without `c_grid_axis_shift` attribute for axis %s" % axis_name) center_coord_name = list(coords_without_axis_shift)[0] # the length of the center coord is key to decoding the other coords axis_len = coord_len[center_coord_name] @@ -103,15 +103,13 @@ def _maybe_fix_type(attr): if clen == axis_len: axis_coords["left"] = name else: - raise ValueError( - "Left coordinate %s has incompatible " "length %g (axis_len=%g)" % (name, clen, axis_len) - ) + raise ValueError("Left coordinate %s has incompatible length %g (axis_len=%g)" % (name, clen, axis_len)) elif shift == axis_shift_right: if clen == axis_len: axis_coords["right"] = name else: raise ValueError( - "Right coordinate %s has incompatible " "length %g (axis_len=%g)" % (name, clen, axis_len) + "Right coordinate %s has incompatible length %g (axis_len=%g)" % (name, clen, axis_len) ) else: if shift not in valid_axis_shifts: @@ -124,9 +122,7 @@ def _maybe_fix_type(attr): "`c_grid_axis_shift` must be one of: %s" % (name, repr(shift), valids) ) else: - raise ValueError( - "Coordinate %s has missing " "`c_grid_axis_shift` attribute `%s`" % (name, repr(shift)) - ) + raise ValueError("Coordinate %s has missing `c_grid_axis_shift` attribute `%s`" % (name, repr(shift))) return axis_coords diff --git a/parcels/gridv4.py b/parcels/v4/grid.py similarity index 99% rename from parcels/gridv4.py rename to parcels/v4/grid.py index 9a3efc8afa..5a442d25e9 100644 --- a/parcels/gridv4.py +++ b/parcels/v4/grid.py @@ -22,13 +22,6 @@ from . import comodo -try: - import numba # type: ignore - - from .transform import conservative_interpolation, linear_interpolation -except ImportError: - numba = None - _VALID_BOUNDARY = [None, "fill", "extend", "periodic"] diff --git a/tests/v4/test_grid.py b/tests/v4/test_grid.py index 476b7c0c56..96ae7be046 100644 --- a/tests/v4/test_grid.py +++ b/tests/v4/test_grid.py @@ -2,7 +2,7 @@ import pytest import xarray as xr -from parcels.gridv4 import Axis, Grid +from parcels.v4.grid import Axis, Grid from tests.v4.xgcm_datasets import ( all_2d, # noqa: F401 all_datasets, # noqa: F401 From 899f4e3b9f1baea257b1f924816b1ebd60157b90 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 1 Apr 2025 18:59:31 +0200 Subject: [PATCH 07/39] Run pre-commit and manual fixes --- parcels/v4/comodo.py | 20 ++++++++---------- parcels/v4/grid.py | 43 +++++++++++++++++++++++---------------- tests/v4/xgcm_datasets.py | 2 +- 3 files changed, 36 insertions(+), 29 deletions(-) diff --git a/parcels/v4/comodo.py b/parcels/v4/comodo.py index 1c4a5a6f29..cf82d50330 100644 --- a/parcels/v4/comodo.py +++ b/parcels/v4/comodo.py @@ -54,7 +54,7 @@ def get_axis_positions_and_coords(ds, axis_name): ncoords = len(coord_names) if ncoords == 0: # didn't find anything for this axis - raise ValueError("Couldn't find any coordinates for axis %s" % axis_name) + raise ValueError(f"Couldn't find any coordinates for axis {axis_name}") # now figure out what type of coordinates these are: # center, left, right, or outer @@ -79,9 +79,9 @@ def _maybe_fix_type(attr): # this list will potentially contain "center", "inner", and "outer" points coords_without_axis_shift = {name: coord_len[name] for name, shift in axis_shift.items() if not shift} if len(coords_without_axis_shift) == 0: - raise ValueError("Couldn't find a center coordinate for axis %s" % axis_name) + raise ValueError(f"Couldn't find a center coordinate for axis {axis_name}") elif len(coords_without_axis_shift) > 1: - raise ValueError("Found two coordinates without `c_grid_axis_shift` attribute for axis %s" % axis_name) + raise ValueError(f"Found two coordinates without `c_grid_axis_shift` attribute for axis {axis_name}") center_coord_name = list(coords_without_axis_shift)[0] # the length of the center coord is key to decoding the other coords axis_len = coord_len[center_coord_name] @@ -103,26 +103,24 @@ def _maybe_fix_type(attr): if clen == axis_len: axis_coords["left"] = name else: - raise ValueError("Left coordinate %s has incompatible length %g (axis_len=%g)" % (name, clen, axis_len)) + raise ValueError(f"Left coordinate {name} has incompatible length {clen:g} (axis_len={axis_len:g})") elif shift == axis_shift_right: if clen == axis_len: axis_coords["right"] = name else: - raise ValueError( - "Right coordinate %s has incompatible length %g (axis_len=%g)" % (name, clen, axis_len) - ) + raise ValueError(f"Right coordinate {name} has incompatible length {clen:g} (axis_len={axis_len:g})") else: if shift not in valid_axis_shifts: # string representing valid axis shifts valids = str(valid_axis_shifts)[1:-1] raise ValueError( - "Coordinate %s has invalid " - "`c_grid_axis_shift` attribute `%s`. " - "`c_grid_axis_shift` must be one of: %s" % (name, repr(shift), valids) + f"Coordinate {name} has invalid " + f"`c_grid_axis_shift` attribute `{shift!r}`. " + f"`c_grid_axis_shift` must be one of: {valids}" ) else: - raise ValueError("Coordinate %s has missing `c_grid_axis_shift` attribute `%s`" % (name, repr(shift))) + raise ValueError(f"Coordinate {name} has missing `c_grid_axis_shift` attribute `{shift!r}`") return axis_coords diff --git a/parcels/v4/grid.py b/parcels/v4/grid.py index 5a442d25e9..6b86f9e632 100644 --- a/parcels/v4/grid.py +++ b/parcels/v4/grid.py @@ -70,7 +70,7 @@ def __init__( ds, axis_name, periodic=True, - default_shifts={}, + default_shifts=None, coords=None, boundary=None, fill_value=None, @@ -115,6 +115,8 @@ def __init__( ---------- .. [1] Comodo Conventions https://web.archive.org/web/20160417032300/http://pycomodo.forge.imag.fr/norm.html """ + if default_shifts is None: + default_shifts = {} self._ds = ds self.name = axis_name self._periodic = periodic @@ -208,7 +210,7 @@ def __init__( def __repr__(self): is_periodic = "periodic" if self._periodic else "not periodic" - summary = ["" % (self.name, is_periodic, self.boundary)] + summary = [f""] summary.append("Axis Coordinates:") summary += self._coord_desc() return "\n".join(summary) @@ -216,9 +218,9 @@ def __repr__(self): def _coord_desc(self): summary = [] for name, cname in self.coords.items(): - coord_info = " * %-8s %s" % (name, cname) + coord_info = f" * {name:<8} {cname}" if name in self._default_shifts: - coord_info += " --> %s" % self._default_shifts[name] + coord_info += f" --> {self._default_shifts[name]}" summary.append(coord_info) return summary @@ -229,7 +231,7 @@ def _get_position_name(self, da): if coord_name in da.dims: return position, coord_name - raise KeyError("None of the DataArray's dims %s were found in axis coords." % repr(da.dims)) + raise KeyError(f"None of the DataArray's dims {da.dims!r} were found in axis coords.") def _get_axis_dim_num(self, da): """Return the dimension number of the axis coordinate in a DataArray.""" @@ -248,7 +250,7 @@ def __init__( ds, check_dims=True, periodic=True, - default_shifts={}, + default_shifts=None, face_connections=None, coords=None, metrics=None, @@ -314,6 +316,8 @@ def __init__( ---------- .. [1] Comodo Conventions https://web.archive.org/web/20160417032300/http://pycomodo.forge.imag.fr/norm.html """ + if default_shifts is None: + default_shifts = {} self._ds = ds self._check_dims = check_dims @@ -322,6 +326,7 @@ def __init__( "The `xgcm.Axis` class will be deprecated in the future. " "Please make sure to use the `xgcm.Grid` methods for your work instead.", category=DeprecationWarning, + stacklevel=2, ) # This will show up every time, but I think that is fine @@ -332,6 +337,7 @@ def __init__( "of array padding and avoid confusion with " "physical boundary conditions (e.g. ocean land boundary).", category=DeprecationWarning, + stacklevel=2, ) # Deprecation Warnigns @@ -340,6 +346,7 @@ def __init__( "The `periodic` argument will be deprecated. " "To preserve previous behavior supply `boundary = 'periodic'.", category=DeprecationWarning, + stacklevel=2, ) if fill_value: @@ -347,6 +354,7 @@ def __init__( "The default fill_value will be changed to nan (from 0.0 previously) " "in future versions. Provide `fill_value=0.0` to preserve previous behavior.", category=DeprecationWarning, + stacklevel=2, ) extrapolate_warning = False @@ -359,6 +367,7 @@ def __init__( warnings.warn( "The `boundary='extrapolate'` option will no longer be supported in future releases.", category=DeprecationWarning, + stacklevel=2, ) if coords: @@ -503,7 +512,7 @@ def _assign_face_connections(self, fc): consistent. """ if len(fc) > 1: - raise ValueError("Only one face dimension is supported for now. Instead found %r" % repr(fc.keys())) + raise ValueError(f"Only one face dimension is supported for now. Instead found {repr(fc.keys())!r}") # we will populate this with the axes we find in face_connections axis_connections = {} @@ -529,25 +538,25 @@ def check_neighbor(link, position): neighbor_link = face_links[idx][ax][correct_position] except (KeyError, IndexError): raise KeyError( - "Couldn't find a face link for face %r" - "in axis %r at position %r" % (idx, ax, correct_position) + f"Couldn't find a face link for face {idx!r}" + f"in axis {ax!r} at position {correct_position!r}" ) idx_n, ax_n, rev_n = neighbor_link if ax not in self.axes: - raise KeyError("axis %r is not a valid axis" % ax) + raise KeyError(f"axis {ax!r} is not a valid axis") if ax_n not in self.axes: - raise KeyError("axis %r is not a valid axis" % ax_n) + raise KeyError(f"axis {ax_n!r} is not a valid axis") if idx not in self._ds[facedim].values: - raise IndexError("%r is not a valid index for face dimension %r" % (idx, facedim)) + raise IndexError(f"{idx!r} is not a valid index for face dimension {facedim!r}") if idx_n not in self._ds[facedim].values: - raise IndexError("%r is not a valid index for face dimension %r" % (idx, facedim)) + raise IndexError(f"{idx!r} is not a valid index for face dimension {facedim!r}") # check for consistent links from / to neighbor - if (idx_n != fidx) or (ax_n != axis) or (rev_n != rev): + if (idx_n != fidx) or (ax_n != axis) or (rev_n != rev): # noqa: B023 # TODO: fix? raise ValueError( "Face link mismatch: neighbor doesn't" " correctly link back to this face. " - "face: %r, axis: %r, position: %r, " - "rev: %r, link: %r, neighbor_link: %r" % (fidx, axis, position, rev, link, neighbor_link) + f"face: {fidx!r}, axis: {axis!r}, position: {position!r}, " # noqa: B023 # TODO: fix? + f"rev: {rev!r}, link: {link!r}, neighbor_link: {neighbor_link!r}" ) # convert the axis name to an acutal axis object actual_axis = self.axes[ax] @@ -605,6 +614,6 @@ def __repr__(self): summary = [""] for name, axis in self.axes.items(): is_periodic = "periodic" if axis._periodic else "not periodic" - summary.append("%s Axis (%s, boundary=%r):" % (name, is_periodic, axis.boundary)) + summary.append(f"{name} Axis ({is_periodic}, boundary={axis.boundary!r}):") summary += axis._coord_desc() return "\n".join(summary) diff --git a/tests/v4/xgcm_datasets.py b/tests/v4/xgcm_datasets.py index d875848b25..17c46f29c1 100644 --- a/tests/v4/xgcm_datasets.py +++ b/tests/v4/xgcm_datasets.py @@ -360,6 +360,6 @@ def _add_metrics(obj): elif grid_type == "C": ds = _add_metrics(xr.Dataset({"u": u_c, "v": v_c, "wt": wt, "tracer": tr, "timeseries": timeseries})) else: - raise ValueError("Invalid input [%s] for `grid_type`. Only supports `B` and `C` at the moment " % grid_type) + raise ValueError(f"Invalid input [{grid_type}] for `grid_type`. Only supports `B` and `C` at the moment ") return ds, coords, metrics From b552b0341494838cf2e7dafe6e7c9c0e7bf950ea Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 1 Apr 2025 19:03:30 +0200 Subject: [PATCH 08/39] Update reprs and remove warning --- parcels/v4/grid.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/parcels/v4/grid.py b/parcels/v4/grid.py index 6b86f9e632..43e1e49962 100644 --- a/parcels/v4/grid.py +++ b/parcels/v4/grid.py @@ -210,7 +210,7 @@ def __init__( def __repr__(self): is_periodic = "periodic" if self._periodic else "not periodic" - summary = [f""] + summary = [f""] summary.append("Axis Coordinates:") summary += self._coord_desc() return "\n".join(summary) @@ -241,7 +241,7 @@ def _get_axis_dim_num(self, da): class Grid: """ - An object with multiple :class:`xgcm.Axis` objects representing different + An object with multiple :class:`parcels.Axis` objects representing different independent axes. """ @@ -321,15 +321,6 @@ def __init__( self._ds = ds self._check_dims = check_dims - # Deprecation Warnigns - warnings.warn( - "The `xgcm.Axis` class will be deprecated in the future. " - "Please make sure to use the `xgcm.Grid` methods for your work instead.", - category=DeprecationWarning, - stacklevel=2, - ) - # This will show up every time, but I think that is fine - if boundary: warnings.warn( "The `boundary` argument will be renamed " @@ -611,7 +602,7 @@ def set_metrics(self, key, value, overwrite=False): self._metrics[metric_axes].append(metric_var) def __repr__(self): - summary = [""] + summary = [""] for name, axis in self.axes.items(): is_periodic = "periodic" if axis._periodic else "not periodic" summary.append(f"{name} Axis ({is_periodic}, boundary={axis.boundary!r}):") From 0bffe9e64b6de02daefbe3ba297dac1a1506b558 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 1 Apr 2025 19:14:28 +0200 Subject: [PATCH 09/39] Remove unused imports --- parcels/v4/grid.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/parcels/v4/grid.py b/parcels/v4/grid.py index 43e1e49962..382cfa6511 100644 --- a/parcels/v4/grid.py +++ b/parcels/v4/grid.py @@ -3,19 +3,6 @@ import warnings from collections import OrderedDict from collections.abc import Iterable - -# from .duck_array_ops import _apply_boundary_condition, _pad_array, concatenate -# from .grid_ufunc import ( -# GridUFunc, -# _check_data_input, -# _GridUFuncSignature, -# _has_chunked_core_dims, -# _maybe_unpack_vector_component, -# _reattach_coords, -# apply_as_grid_ufunc, -# ) -# from .metrics import iterate_axis_combinations -# from .padding import pad from typing import ( Any, ) From ca7dcfc965513fb6886e471570d3e3cdbc5d92ba Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 1 Apr 2025 19:32:08 +0200 Subject: [PATCH 10/39] remove Grid._ti from old grid --- parcels/grid.py | 1 - 1 file changed, 1 deletion(-) diff --git a/parcels/grid.py b/parcels/grid.py index 9332dcb909..1163c7bda7 100644 --- a/parcels/grid.py +++ b/parcels/grid.py @@ -40,7 +40,6 @@ def __init__( time_origin: TimeConverter | None, mesh: Mesh, ): - self._ti = -1 lon = np.array(lon) lat = np.array(lat) time = np.zeros(1, dtype=np.float64) if time is None else time From bdc7041499113fdb7ebb12af26a731fbdb4ceb34 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 1 Apr 2025 19:54:14 +0200 Subject: [PATCH 11/39] Add grid adapter test cases and mock implementation --- parcels/v4/gridadapter.py | 48 ++++++++++++++++++++++++++ tests/v4/test_gridadapter.py | 67 ++++++++++++++++++++++++++++++++++++ 2 files changed, 115 insertions(+) create mode 100644 parcels/v4/gridadapter.py create mode 100644 tests/v4/test_gridadapter.py diff --git a/parcels/v4/gridadapter.py b/parcels/v4/gridadapter.py new file mode 100644 index 0000000000..4f9a8d35d4 --- /dev/null +++ b/parcels/v4/gridadapter.py @@ -0,0 +1,48 @@ +from parcels.v4.grid import Grid + + +class GridAdapter(Grid): + def __init__(self, ds, *args, **kwargs): + super().__init__(ds, *args, **kwargs) + + @property + def lon(self): ... + + @property + def lat(self): ... + + @property + def depth(self): ... + + @property + def time(self): ... + + @property + def xdim(self): ... + + @property + def ydim(self): ... + + @property + def zdim(self): ... + @property + def tdim(self): ... + + @property + def time_origin(self): ... + + @property + def mesh(self): ... # ? hmmm + + @property + def zonal_periodic(self): ... # ? hmmm + + @property + def lonlat_minmax(self): ... # ? hmmm + + @staticmethod + def create_grid(lon, lat, depth, time, time_origin, mesh, **kwargs): ... # ? hmmm + + def _check_zonal_periodic(self): ... # ? hmmm + + def _add_Sdepth_periodic_halo(self, zonal, meridional, halosize): ... # ? hmmm diff --git a/tests/v4/test_gridadapter.py b/tests/v4/test_gridadapter.py new file mode 100644 index 0000000000..9b2e7eb2ee --- /dev/null +++ b/tests/v4/test_gridadapter.py @@ -0,0 +1,67 @@ +from collections import namedtuple + +import numpy as np +import pytest +import xarray as xr + +from parcels.v4.gridadapter import GridAdapter + +N = 100 + +ds_2d_left = xr.Dataset( + { + "data_g": (["YG", "XG"], np.random.rand(2 * N, N)), + "data_c": (["YC", "XC"], np.random.rand(2 * N, N)), + }, + coords={ + "XG": ( + ["XG"], + 2 * np.pi / N * np.arange(0, N), + {"axis": "X", "c_grid_axis_shift": -0.5}, + ), + "XC": (["XC"], 2 * np.pi / N * (np.arange(0, N) + 0.5), {"axis": "X"}), + "YG": ( + ["YG"], + 2 * np.pi / (2 * N) * np.arange(0, 2 * N), + {"axis": "Y", "c_grid_axis_shift": -0.5}, + ), + "YC": ( + ["YC"], + 2 * np.pi / (2 * N) * (np.arange(0, 2 * N) + 0.5), + {"axis": "Y"}, + ), + }, +) + + +@pytest.fixture +def grid(): + return + + +TestCase = namedtuple("TestCase", ["Grid", "attr", "expected"]) + +test_cases = [ + TestCase(ds_2d_left, "lon", ds_2d_left.XC.values), + TestCase(ds_2d_left, "lat", ds_2d_left.YC.values), + TestCase(ds_2d_left, "depth", None), + TestCase(ds_2d_left, "time", None), + TestCase(ds_2d_left, "xdim", N), + TestCase(ds_2d_left, "ydim", 2 * N), + TestCase(ds_2d_left, "zdim", 1), + TestCase(ds_2d_left, "tdim", 1), +] + + +def assert_equal(actual, expected): + if expected is None: + assert actual is None + else: + assert np.allclose(actual, expected) + + +@pytest.mark.parametrize("ds, attr, expected", test_cases) +def test_grid_adapter_properties(ds, attr, expected): + adapter = GridAdapter(ds, periodic=False) + actual = getattr(adapter, attr) + assert_equal(actual, expected) From ff62de75600bc2b31b85896f21a0ddaf5a54ef85 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 1 Apr 2025 20:06:28 +0200 Subject: [PATCH 12/39] Define GridAdapter tdim, xdim, ydim and zdim --- parcels/v4/gridadapter.py | 33 ++++++++++++++++++++++++++++----- tests/v4/test_gridadapter.py | 8 ++++---- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/parcels/v4/gridadapter.py b/parcels/v4/gridadapter.py index 4f9a8d35d4..9a8aac4ff3 100644 --- a/parcels/v4/gridadapter.py +++ b/parcels/v4/gridadapter.py @@ -1,4 +1,22 @@ -from parcels.v4.grid import Grid +from parcels.v4.grid import Axis, Grid + + +def get_dimensionality(axis: Axis | None) -> int: + if axis is None: + return 1 + first_coord = list(axis.coords.items())[0] + pos, coord = first_coord + + pos_to_dim = { # TODO: These could do with being explicitly tested + "center": lambda x: x, + "left": lambda x: x, + "right": lambda x: x, + "inner": lambda x: x + 1, + "outer": lambda x: x - 1, + } + + n = axis._ds[coord].size + return pos_to_dim[pos](n) class GridAdapter(Grid): @@ -18,15 +36,20 @@ def depth(self): ... def time(self): ... @property - def xdim(self): ... + def xdim(self): + return get_dimensionality(self.axes.get("X")) @property - def ydim(self): ... + def ydim(self): + return get_dimensionality(self.axes.get("Y")) @property - def zdim(self): ... + def zdim(self): + return get_dimensionality(self.axes.get("Z")) + @property - def tdim(self): ... + def tdim(self): + return get_dimensionality(self.axes.get("T")) @property def time_origin(self): ... diff --git a/tests/v4/test_gridadapter.py b/tests/v4/test_gridadapter.py index 9b2e7eb2ee..8479b64095 100644 --- a/tests/v4/test_gridadapter.py +++ b/tests/v4/test_gridadapter.py @@ -42,10 +42,10 @@ def grid(): TestCase = namedtuple("TestCase", ["Grid", "attr", "expected"]) test_cases = [ - TestCase(ds_2d_left, "lon", ds_2d_left.XC.values), - TestCase(ds_2d_left, "lat", ds_2d_left.YC.values), - TestCase(ds_2d_left, "depth", None), - TestCase(ds_2d_left, "time", None), + # TestCase(ds_2d_left, "lon", ds_2d_left.XC.values), + # TestCase(ds_2d_left, "lat", ds_2d_left.YC.values), + # TestCase(ds_2d_left, "depth", None), + # TestCase(ds_2d_left, "time", None), TestCase(ds_2d_left, "xdim", N), TestCase(ds_2d_left, "ydim", 2 * N), TestCase(ds_2d_left, "zdim", 1), From f175eea3d5c91e803c7e438e875fc7492bdb90a5 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 1 Apr 2025 20:10:50 +0200 Subject: [PATCH 13/39] Add time to test_gridadapter dataset --- tests/v4/test_gridadapter.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/v4/test_gridadapter.py b/tests/v4/test_gridadapter.py index 8479b64095..e431327d5e 100644 --- a/tests/v4/test_gridadapter.py +++ b/tests/v4/test_gridadapter.py @@ -7,11 +7,12 @@ from parcels.v4.gridadapter import GridAdapter N = 100 +T = 10 ds_2d_left = xr.Dataset( { - "data_g": (["YG", "XG"], np.random.rand(2 * N, N)), - "data_c": (["YC", "XC"], np.random.rand(2 * N, N)), + "data_g": (["time", "YG", "XG"], np.random.rand(10, 2 * N, N)), + "data_c": (["time", "YC", "XC"], np.random.rand(10, 2 * N, N)), }, coords={ "XG": ( @@ -30,6 +31,7 @@ 2 * np.pi / (2 * N) * (np.arange(0, 2 * N) + 0.5), {"axis": "Y"}, ), + "time": (["time"], np.arange(T), {"axis": "T"}), }, ) @@ -49,7 +51,7 @@ def grid(): TestCase(ds_2d_left, "xdim", N), TestCase(ds_2d_left, "ydim", 2 * N), TestCase(ds_2d_left, "zdim", 1), - TestCase(ds_2d_left, "tdim", 1), + TestCase(ds_2d_left, "tdim", T), ] From f7fa370e4142ae91393fc9ddcc2ce7716fd18f86 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 1 Apr 2025 20:12:41 +0200 Subject: [PATCH 14/39] Add Z dim to test_gridadapter dataset --- tests/v4/test_gridadapter.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/tests/v4/test_gridadapter.py b/tests/v4/test_gridadapter.py index e431327d5e..e889825144 100644 --- a/tests/v4/test_gridadapter.py +++ b/tests/v4/test_gridadapter.py @@ -11,8 +11,8 @@ ds_2d_left = xr.Dataset( { - "data_g": (["time", "YG", "XG"], np.random.rand(10, 2 * N, N)), - "data_c": (["time", "YC", "XC"], np.random.rand(10, 2 * N, N)), + "data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, 3 * N, 2 * N, N)), + "data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, 3 * N, 2 * N, N)), }, coords={ "XG": ( @@ -31,6 +31,16 @@ 2 * np.pi / (2 * N) * (np.arange(0, 2 * N) + 0.5), {"axis": "Y"}, ), + "ZG": ( + ["ZG"], + np.arange(3 * N), + {"axis": "Z", "c_grid_axis_shift": -0.5}, + ), + "ZC": ( + ["ZC"], + np.arange(3 * N) + 0.5, + {"axis": "Z"}, + ), "time": (["time"], np.arange(T), {"axis": "T"}), }, ) @@ -50,7 +60,7 @@ def grid(): # TestCase(ds_2d_left, "time", None), TestCase(ds_2d_left, "xdim", N), TestCase(ds_2d_left, "ydim", 2 * N), - TestCase(ds_2d_left, "zdim", 1), + TestCase(ds_2d_left, "zdim", 3 * N), TestCase(ds_2d_left, "tdim", T), ] From df93007de996cea8e9d22f6fa901d2525c5b32d6 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 1 Apr 2025 20:23:25 +0200 Subject: [PATCH 15/39] Define GridAdapter lon, lat, depth, and time Assuming for now that left f-points are given for all Grids (which I think is an assumption that held for the grid class in Parcels v3) --- parcels/v4/gridadapter.py | 28 ++++++++++++++++++++++++---- tests/v4/test_gridadapter.py | 15 +++++---------- 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/parcels/v4/gridadapter.py b/parcels/v4/gridadapter.py index 9a8aac4ff3..9ab8779f82 100644 --- a/parcels/v4/gridadapter.py +++ b/parcels/v4/gridadapter.py @@ -1,3 +1,5 @@ +import numpy.typing as npt + from parcels.v4.grid import Axis, Grid @@ -19,21 +21,39 @@ def get_dimensionality(axis: Axis | None) -> int: return pos_to_dim[pos](n) +def get_left_fpoints(axis: Axis) -> npt.NDArray: + return axis._ds[axis.coords["left"]].values + + +def get_time(axis: Axis) -> npt.NDArray: + return axis._ds[axis.coords["center"]].values + + class GridAdapter(Grid): def __init__(self, ds, *args, **kwargs): super().__init__(ds, *args, **kwargs) @property - def lon(self): ... + def lon(self): + return get_left_fpoints(self.axes["X"]) @property - def lat(self): ... + def lat(self): + return get_left_fpoints(self.axes["Y"]) @property - def depth(self): ... + def depth(self): + try: + return get_left_fpoints(self.axes["Z"]) + except KeyError: + return None @property - def time(self): ... + def time(self): + try: + return get_time(self.axes["T"]) + except KeyError: + return None @property def xdim(self): diff --git a/tests/v4/test_gridadapter.py b/tests/v4/test_gridadapter.py index e889825144..cb0d19d352 100644 --- a/tests/v4/test_gridadapter.py +++ b/tests/v4/test_gridadapter.py @@ -46,18 +46,13 @@ ) -@pytest.fixture -def grid(): - return - - TestCase = namedtuple("TestCase", ["Grid", "attr", "expected"]) test_cases = [ - # TestCase(ds_2d_left, "lon", ds_2d_left.XC.values), - # TestCase(ds_2d_left, "lat", ds_2d_left.YC.values), - # TestCase(ds_2d_left, "depth", None), - # TestCase(ds_2d_left, "time", None), + TestCase(ds_2d_left, "lon", ds_2d_left.XG.values), + TestCase(ds_2d_left, "lat", ds_2d_left.YG.values), + TestCase(ds_2d_left, "depth", ds_2d_left.ZG.values), + TestCase(ds_2d_left, "time", ds_2d_left.time.values), TestCase(ds_2d_left, "xdim", N), TestCase(ds_2d_left, "ydim", 2 * N), TestCase(ds_2d_left, "zdim", 3 * N), @@ -69,7 +64,7 @@ def assert_equal(actual, expected): if expected is None: assert actual is None else: - assert np.allclose(actual, expected) + np.testing.assert_allclose(actual, expected) @pytest.mark.parametrize("ds, attr, expected", test_cases) From 82b827146d0f9ea94f3213446101a5872a7c0797 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 1 Apr 2025 20:27:21 +0200 Subject: [PATCH 16/39] Limit scope of keyerror except blocks --- parcels/v4/gridadapter.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/parcels/v4/gridadapter.py b/parcels/v4/gridadapter.py index 9ab8779f82..c127919cdb 100644 --- a/parcels/v4/gridadapter.py +++ b/parcels/v4/gridadapter.py @@ -35,25 +35,35 @@ def __init__(self, ds, *args, **kwargs): @property def lon(self): - return get_left_fpoints(self.axes["X"]) + try: + axis = self.axes["X"] + except KeyError: + return None + return get_left_fpoints(axis) @property def lat(self): - return get_left_fpoints(self.axes["Y"]) + try: + axis = self.axes["Y"] + except KeyError: + return None + return get_left_fpoints(axis) @property def depth(self): try: - return get_left_fpoints(self.axes["Z"]) + axis = self.axes["Z"] except KeyError: return None + return get_left_fpoints(axis) @property def time(self): try: - return get_time(self.axes["T"]) + axis = self.axes["T"] except KeyError: return None + return get_time(axis) @property def xdim(self): From 6ef7fc65b09cb3f1bd0c59da99ce38770ceb5698 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 1 Apr 2025 20:44:35 +0200 Subject: [PATCH 17/39] Define GridAdapter time_origin Also define equality for TimeConverter objects --- parcels/tools/converters.py | 3 +++ parcels/v4/gridadapter.py | 4 +++- tests/v4/test_gridadapter.py | 4 ++++ 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/parcels/tools/converters.py b/parcels/tools/converters.py index 4b60204e1f..36cf0c1bf4 100644 --- a/parcels/tools/converters.py +++ b/parcels/tools/converters.py @@ -61,6 +61,9 @@ def __init__(self, time_origin: float | np.datetime64 | np.timedelta64 | cftime. elif isinstance(time_origin, cftime.datetime): self.calendar = time_origin.calendar + def __eq__(self, other): + return self.time_origin == other.time_origin and self.calendar == other.calendar + def reltime(self, time: TimeConverter | np.datetime64 | np.timedelta64 | cftime.datetime) -> float | npt.NDArray: """Method to compute the difference, in seconds, between a time and the time_origin of the TimeConverter diff --git a/parcels/v4/gridadapter.py b/parcels/v4/gridadapter.py index c127919cdb..b75c7c54db 100644 --- a/parcels/v4/gridadapter.py +++ b/parcels/v4/gridadapter.py @@ -1,5 +1,6 @@ import numpy.typing as npt +from parcels.tools.converters import TimeConverter from parcels.v4.grid import Axis, Grid @@ -82,7 +83,8 @@ def tdim(self): return get_dimensionality(self.axes.get("T")) @property - def time_origin(self): ... + def time_origin(self): + return TimeConverter(self.time[0]) @property def mesh(self): ... # ? hmmm diff --git a/tests/v4/test_gridadapter.py b/tests/v4/test_gridadapter.py index cb0d19d352..5bfd56fba8 100644 --- a/tests/v4/test_gridadapter.py +++ b/tests/v4/test_gridadapter.py @@ -4,6 +4,7 @@ import pytest import xarray as xr +from parcels.tools.converters import TimeConverter from parcels.v4.gridadapter import GridAdapter N = 100 @@ -57,12 +58,15 @@ TestCase(ds_2d_left, "ydim", 2 * N), TestCase(ds_2d_left, "zdim", 3 * N), TestCase(ds_2d_left, "tdim", T), + TestCase(ds_2d_left, "time_origin", TimeConverter(ds_2d_left.time.values[0])), ] def assert_equal(actual, expected): if expected is None: assert actual is None + elif isinstance(expected, TimeConverter): + assert actual == expected else: np.testing.assert_allclose(actual, expected) From 1ec179a15745952a52a6adb4da7c430ba5ce37e0 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 2 Apr 2025 16:57:28 +0200 Subject: [PATCH 18/39] Update adapter to return zero array --- parcels/v4/gridadapter.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/parcels/v4/gridadapter.py b/parcels/v4/gridadapter.py index b75c7c54db..0664848088 100644 --- a/parcels/v4/gridadapter.py +++ b/parcels/v4/gridadapter.py @@ -1,3 +1,4 @@ +import numpy as np import numpy.typing as npt from parcels.tools.converters import TimeConverter @@ -39,7 +40,7 @@ def lon(self): try: axis = self.axes["X"] except KeyError: - return None + return np.zeros(1) return get_left_fpoints(axis) @property @@ -47,7 +48,7 @@ def lat(self): try: axis = self.axes["Y"] except KeyError: - return None + return np.zeros(1) return get_left_fpoints(axis) @property @@ -55,7 +56,7 @@ def depth(self): try: axis = self.axes["Z"] except KeyError: - return None + return np.zeros(1) return get_left_fpoints(axis) @property @@ -63,7 +64,7 @@ def time(self): try: axis = self.axes["T"] except KeyError: - return None + return np.zeros(1) return get_time(axis) @property From ec49bc0143b62715952de39e6038aff2afb4775a Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 2 Apr 2025 16:58:26 +0200 Subject: [PATCH 19/39] Add _z4d to gridadapter --- parcels/v4/gridadapter.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/parcels/v4/gridadapter.py b/parcels/v4/gridadapter.py index 0664848088..7b1fd5602e 100644 --- a/parcels/v4/gridadapter.py +++ b/parcels/v4/gridadapter.py @@ -1,3 +1,5 @@ +from typing import Literal + import numpy as np import numpy.typing as npt @@ -87,6 +89,10 @@ def tdim(self): def time_origin(self): return TimeConverter(self.time[0]) + @property + def _z4d(self) -> Literal[0, 1]: + return 1 if self.depth.shape == 4 else 0 + @property def mesh(self): ... # ? hmmm From 86277303834ac2a345d96a45c7bf80c4b76b4c9c Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 3 Apr 2025 14:19:32 +0200 Subject: [PATCH 20/39] Move vendored file to vendor folder --- tests/v4/test_grid.py | 2 +- tests/vendor/__init__.py | 0 tests/{v4 => vendor}/xgcm_datasets.py | 0 3 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 tests/vendor/__init__.py rename tests/{v4 => vendor}/xgcm_datasets.py (100%) diff --git a/tests/v4/test_grid.py b/tests/v4/test_grid.py index 96ae7be046..4f265578c3 100644 --- a/tests/v4/test_grid.py +++ b/tests/v4/test_grid.py @@ -3,7 +3,7 @@ import xarray as xr from parcels.v4.grid import Axis, Grid -from tests.v4.xgcm_datasets import ( +from tests.vendor.xgcm_datasets import ( all_2d, # noqa: F401 all_datasets, # noqa: F401 datasets, diff --git a/tests/vendor/__init__.py b/tests/vendor/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/v4/xgcm_datasets.py b/tests/vendor/xgcm_datasets.py similarity index 100% rename from tests/v4/xgcm_datasets.py rename to tests/vendor/xgcm_datasets.py From 86f4d3876f284dd66aaaa58ac296060acd3010b8 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 3 Apr 2025 14:20:21 +0200 Subject: [PATCH 21/39] Update test suite repr --- tests/v4/test_grid.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/v4/test_grid.py b/tests/v4/test_grid.py index 4f265578c3..7f6f28a905 100644 --- a/tests/v4/test_grid.py +++ b/tests/v4/test_grid.py @@ -94,7 +94,7 @@ def test_axis_repr(all_datasets): axis_objs = _get_axes(ds) for axis in axis_objs.values(): r = repr(axis).split("\n") - assert r[0].startswith("" + assert r[0] == "" def test_grid_dict_input_boundary_fill(nonperiodic_1d): From c6de1cdad51e29cc1ca40bb7a4b2b806fa80b228 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 3 Apr 2025 16:35:49 +0200 Subject: [PATCH 22/39] Add curvilinear test grid datasets --- tests/v4/grid_datasets.py | 101 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 101 insertions(+) create mode 100644 tests/v4/grid_datasets.py diff --git a/tests/v4/grid_datasets.py b/tests/v4/grid_datasets.py new file mode 100644 index 0000000000..14e0656f5b --- /dev/null +++ b/tests/v4/grid_datasets.py @@ -0,0 +1,101 @@ +import numpy as np +import xarray as xr + +N = 30 + + +def rotated_curvilinear_grid(): + XG = np.arange(N) + YG = np.arange(2 * N) + LON, LAT = np.meshgrid(XG, YG) + + angle = -np.pi / 24 + rotation = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]]) + + # rotate the LON and LAT grids + LON, LAT = np.einsum("ji, mni -> jmn", rotation, np.dstack([LON, LAT])) + + return xr.Dataset( + { + "data_g": (["YG", "XG"], np.random.rand(2 * N, N)), + "data_c": (["YC", "XC"], np.random.rand(2 * N, N)), + }, + coords={ + "XG": (["XG"], XG, {"axis": "X", "c_grid_axis_shift": -0.5}), + "YG": (["YG"], YG, {"axis": "Y", "c_grid_axis_shift": -0.5}), + "XC": (["XC"], XG + 0.5, {"axis": "X"}), + "YC": (["YC"], YG + 0.5, {"axis": "Y"}), + "lon": ( + ["YG", "XG"], + LON, + {"axis": "X", "c_grid_axis_shift": -0.5}, # ? Needed? + ), + "lat": ( + ["YG", "XG"], + LAT, + {"axis": "Y", "c_grid_axis_shift": -0.5}, # ? Needed? + ), + }, + ) + + +def _cartesion_to_polar(x, y): + r = np.sqrt(x**2 + y**2) + theta = np.arctan2(y, x) + return r, theta + + +def _polar_to_cartesian(r, theta): + x = r * np.cos(theta) + y = r * np.sin(theta) + return x, y + + +def unrolled_cone_curvilinear_grid(): + # Not a great unrolled cone, but this is good enough for testing + # you can use matplotlib pcolormesh to plot + XG = np.arange(N) + YG = np.arange(2 * N) * 0.25 + + pivot = -10, 0 + LON, LAT = np.meshgrid(XG, YG) + + new_lon_lat = [] + + min_lon = np.min(XG) + for lon, lat in zip(LON.flatten(), LAT.flatten(), strict=True): + r, _ = _cartesion_to_polar(lon - pivot[0], lat - pivot[1]) + _, theta = _cartesion_to_polar(min_lon - pivot[0], lat - pivot[1]) + theta *= 1.2 + r *= 1.2 + lon, lat = _polar_to_cartesian(r, theta) + new_lon_lat.append((lon + pivot[0], lat + pivot[1])) + + new_lon, new_lat = zip(*new_lon_lat, strict=True) + LON, LAT = np.array(new_lon).reshape(LON.shape), np.array(new_lat).reshape(LAT.shape) + + return xr.Dataset( + { + "data_g": (["YG", "XG"], np.random.rand(2 * N, N)), + "data_c": (["YC", "XC"], np.random.rand(2 * N, N)), + }, + coords={ + "XG": (["XG"], XG, {"axis": "X", "c_grid_axis_shift": -0.5}), + "YG": (["YG"], YG, {"axis": "Y", "c_grid_axis_shift": -0.5}), + "XC": (["XC"], XG + 0.5, {"axis": "X"}), + "YC": (["YC"], YG + 0.5, {"axis": "Y"}), + "lon": ( + ["YG", "XG"], + LON, + {"axis": "X", "c_grid_axis_shift": -0.5}, # ? Needed? + ), + "lat": ( + ["YG", "XG"], + LAT, + {"axis": "Y", "c_grid_axis_shift": -0.5}, # ? Needed? + ), + }, + ) + + +datasets = {"2d_left_rotated": rotated_curvilinear_grid(), "2d_left_unrolled_cone": unrolled_cone_curvilinear_grid()} From 78889e11df9d226c8d2021a229344847fef5cbd3 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 3 Apr 2025 16:48:28 +0200 Subject: [PATCH 23/39] Update grid datasets with depth and time --- tests/v4/grid_datasets.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/tests/v4/grid_datasets.py b/tests/v4/grid_datasets.py index 14e0656f5b..b9f15c2b05 100644 --- a/tests/v4/grid_datasets.py +++ b/tests/v4/grid_datasets.py @@ -2,6 +2,7 @@ import xarray as xr N = 30 +T = 10 def rotated_curvilinear_grid(): @@ -17,14 +18,26 @@ def rotated_curvilinear_grid(): return xr.Dataset( { - "data_g": (["YG", "XG"], np.random.rand(2 * N, N)), - "data_c": (["YC", "XC"], np.random.rand(2 * N, N)), + "data_g": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)), + "data_c": (["ZC", "YC", "XC"], np.random.rand(3 * N, 2 * N, N)), }, coords={ "XG": (["XG"], XG, {"axis": "X", "c_grid_axis_shift": -0.5}), "YG": (["YG"], YG, {"axis": "Y", "c_grid_axis_shift": -0.5}), "XC": (["XC"], XG + 0.5, {"axis": "X"}), "YC": (["YC"], YG + 0.5, {"axis": "Y"}), + "ZG": ( + ["ZG"], + np.arange(3 * N), + {"axis": "Z", "c_grid_axis_shift": -0.5}, + ), + "ZC": ( + ["ZC"], + np.arange(3 * N) + 0.5, + {"axis": "Z"}, + ), + "depth": (["ZG"], np.arange(3 * N), {"axis": "Z"}), + "time": (["time"], np.arange(T), {"axis": "T"}), "lon": ( ["YG", "XG"], LON, @@ -76,14 +89,26 @@ def unrolled_cone_curvilinear_grid(): return xr.Dataset( { - "data_g": (["YG", "XG"], np.random.rand(2 * N, N)), - "data_c": (["YC", "XC"], np.random.rand(2 * N, N)), + "data_g": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)), + "data_c": (["ZC", "YC", "XC"], np.random.rand(3 * N, 2 * N, N)), }, coords={ "XG": (["XG"], XG, {"axis": "X", "c_grid_axis_shift": -0.5}), "YG": (["YG"], YG, {"axis": "Y", "c_grid_axis_shift": -0.5}), "XC": (["XC"], XG + 0.5, {"axis": "X"}), "YC": (["YC"], YG + 0.5, {"axis": "Y"}), + "ZG": ( + ["ZG"], + np.arange(3 * N), + {"axis": "Z", "c_grid_axis_shift": -0.5}, + ), + "ZC": ( + ["ZC"], + np.arange(3 * N) + 0.5, + {"axis": "Z"}, + ), + "depth": (["ZG"], np.arange(3 * N), {"axis": "Z"}), + "time": (["time"], np.arange(T), {"axis": "T"}), "lon": ( ["YG", "XG"], LON, From 4c56c06a1472c63a0c7ac26cee3d69a2b4c30b06 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 3 Apr 2025 17:02:59 +0200 Subject: [PATCH 24/39] Move dataset and start with new test --- tests/v4/grid_datasets.py | 40 +++++++++++++++- tests/v4/test_gridadapter.py | 89 +++++++++++++++--------------------- 2 files changed, 77 insertions(+), 52 deletions(-) diff --git a/tests/v4/grid_datasets.py b/tests/v4/grid_datasets.py index b9f15c2b05..5a27e16904 100644 --- a/tests/v4/grid_datasets.py +++ b/tests/v4/grid_datasets.py @@ -123,4 +123,42 @@ def unrolled_cone_curvilinear_grid(): ) -datasets = {"2d_left_rotated": rotated_curvilinear_grid(), "2d_left_unrolled_cone": unrolled_cone_curvilinear_grid()} +datasets = { + "2d_left_rotated": rotated_curvilinear_grid(), + "ds_2d_left": xr.Dataset( + { + "data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, 3 * N, 2 * N, N)), + "data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, 3 * N, 2 * N, N)), + }, + coords={ + "XG": ( + ["XG"], + 2 * np.pi / N * np.arange(0, N), + {"axis": "X", "c_grid_axis_shift": -0.5}, + ), + "XC": (["XC"], 2 * np.pi / N * (np.arange(0, N) + 0.5), {"axis": "X"}), + "YG": ( + ["YG"], + 2 * np.pi / (2 * N) * np.arange(0, 2 * N), + {"axis": "Y", "c_grid_axis_shift": -0.5}, + ), + "YC": ( + ["YC"], + 2 * np.pi / (2 * N) * (np.arange(0, 2 * N) + 0.5), + {"axis": "Y"}, + ), + "ZG": ( + ["ZG"], + np.arange(3 * N), + {"axis": "Z", "c_grid_axis_shift": -0.5}, + ), + "ZC": ( + ["ZC"], + np.arange(3 * N) + 0.5, + {"axis": "Z"}, + ), + "time": (["time"], np.arange(T), {"axis": "T"}), + }, + ), + "2d_left_unrolled_cone": unrolled_cone_curvilinear_grid(), +} diff --git a/tests/v4/test_gridadapter.py b/tests/v4/test_gridadapter.py index 5bfd56fba8..b38936df88 100644 --- a/tests/v4/test_gridadapter.py +++ b/tests/v4/test_gridadapter.py @@ -2,63 +2,25 @@ import numpy as np import pytest -import xarray as xr +from numpy.testing import assert_array_equal +from parcels.grid import Grid as OldGrid from parcels.tools.converters import TimeConverter from parcels.v4.gridadapter import GridAdapter - -N = 100 -T = 10 - -ds_2d_left = xr.Dataset( - { - "data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, 3 * N, 2 * N, N)), - "data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, 3 * N, 2 * N, N)), - }, - coords={ - "XG": ( - ["XG"], - 2 * np.pi / N * np.arange(0, N), - {"axis": "X", "c_grid_axis_shift": -0.5}, - ), - "XC": (["XC"], 2 * np.pi / N * (np.arange(0, N) + 0.5), {"axis": "X"}), - "YG": ( - ["YG"], - 2 * np.pi / (2 * N) * np.arange(0, 2 * N), - {"axis": "Y", "c_grid_axis_shift": -0.5}, - ), - "YC": ( - ["YC"], - 2 * np.pi / (2 * N) * (np.arange(0, 2 * N) + 0.5), - {"axis": "Y"}, - ), - "ZG": ( - ["ZG"], - np.arange(3 * N), - {"axis": "Z", "c_grid_axis_shift": -0.5}, - ), - "ZC": ( - ["ZC"], - np.arange(3 * N) + 0.5, - {"axis": "Z"}, - ), - "time": (["time"], np.arange(T), {"axis": "T"}), - }, -) - +from tests.v4.grid_datasets import N, T, datasets TestCase = namedtuple("TestCase", ["Grid", "attr", "expected"]) test_cases = [ - TestCase(ds_2d_left, "lon", ds_2d_left.XG.values), - TestCase(ds_2d_left, "lat", ds_2d_left.YG.values), - TestCase(ds_2d_left, "depth", ds_2d_left.ZG.values), - TestCase(ds_2d_left, "time", ds_2d_left.time.values), - TestCase(ds_2d_left, "xdim", N), - TestCase(ds_2d_left, "ydim", 2 * N), - TestCase(ds_2d_left, "zdim", 3 * N), - TestCase(ds_2d_left, "tdim", T), - TestCase(ds_2d_left, "time_origin", TimeConverter(ds_2d_left.time.values[0])), + TestCase(datasets["ds_2d_left"], "lon", datasets["ds_2d_left"].XG.values), + TestCase(datasets["ds_2d_left"], "lat", datasets["ds_2d_left"].YG.values), + TestCase(datasets["ds_2d_left"], "depth", datasets["ds_2d_left"].ZG.values), + TestCase(datasets["ds_2d_left"], "time", datasets["ds_2d_left"].time.values), + TestCase(datasets["ds_2d_left"], "xdim", N), + TestCase(datasets["ds_2d_left"], "ydim", 2 * N), + TestCase(datasets["ds_2d_left"], "zdim", 3 * N), + TestCase(datasets["ds_2d_left"], "tdim", T), + TestCase(datasets["ds_2d_left"], "time_origin", TimeConverter(datasets["ds_2d_left"].time.values[0])), ] @@ -72,7 +34,32 @@ def assert_equal(actual, expected): @pytest.mark.parametrize("ds, attr, expected", test_cases) -def test_grid_adapter_properties(ds, attr, expected): +def test_grid_adapter_properties_ground_truth(ds, attr, expected): adapter = GridAdapter(ds, periodic=False) actual = getattr(adapter, attr) assert_equal(actual, expected) + + +@pytest.mark.parametrize("ds", datasets.values()) +def test_grid_adapter_against_old(ds): + adapter = GridAdapter(ds, periodic=False) + + grid = OldGrid.create_grid( + lon=ds.lon.values, + lat=ds.lat.values, + depth=ds.depth.values, + time=ds.time.values, + time_origin=TimeConverter(ds.time.values[0]), + mesh="spherical", + ) + assert grid.lon.shape == adapter.lon.shape + assert grid.lat.shape == adapter.lat.shape + assert grid.depth.shape == adapter.depth.shape + assert grid.time.shape == adapter.time.shape + + assert_array_equal(grid.lon, adapter.lon) + assert_array_equal(grid.lat, adapter.lat) + assert_array_equal(grid.depth, adapter.depth) + assert_array_equal(grid.time, adapter.time) + + assert grid.time_origin == adapter.time_origin From 29568b9907a8b2a069dc29d1bf8c47857047b4a6 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 3 Apr 2025 17:20:55 +0200 Subject: [PATCH 25/39] Update test --- tests/v4/grid_datasets.py | 3 +++ tests/v4/test_gridadapter.py | 37 ++++++++++++++++++++++-------------- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/tests/v4/grid_datasets.py b/tests/v4/grid_datasets.py index 5a27e16904..72928fc604 100644 --- a/tests/v4/grid_datasets.py +++ b/tests/v4/grid_datasets.py @@ -157,6 +157,9 @@ def unrolled_cone_curvilinear_grid(): np.arange(3 * N) + 0.5, {"axis": "Z"}, ), + "lon": (["XG"], 2 * np.pi / N * np.arange(0, N)), + "lat": (["YG"], 2 * np.pi / (2 * N) * np.arange(0, 2 * N)), + "depth": (["ZG"], np.arange(3 * N)), "time": (["time"], np.arange(T), {"axis": "T"}), }, ), diff --git a/tests/v4/test_gridadapter.py b/tests/v4/test_gridadapter.py index b38936df88..f83d7812da 100644 --- a/tests/v4/test_gridadapter.py +++ b/tests/v4/test_gridadapter.py @@ -2,7 +2,7 @@ import numpy as np import pytest -from numpy.testing import assert_array_equal +from numpy.testing import assert_allclose from parcels.grid import Grid as OldGrid from parcels.tools.converters import TimeConverter @@ -29,8 +29,11 @@ def assert_equal(actual, expected): assert actual is None elif isinstance(expected, TimeConverter): assert actual == expected + elif isinstance(expected, np.ndarray): + assert actual.shape == expected.shape + assert_allclose(actual, expected) else: - np.testing.assert_allclose(actual, expected) + assert_allclose(actual, expected) @pytest.mark.parametrize("ds, attr, expected", test_cases) @@ -40,8 +43,22 @@ def test_grid_adapter_properties_ground_truth(ds, attr, expected): assert_equal(actual, expected) +@pytest.mark.parametrize( + "attr", + [ + "lon", + "lat", + "depth", + "time", + "xdim", + "ydim", + "zdim", + "tdim", + "time_origin", + ], +) @pytest.mark.parametrize("ds", datasets.values()) -def test_grid_adapter_against_old(ds): +def test_grid_adapter_against_old(ds, attr): adapter = GridAdapter(ds, periodic=False) grid = OldGrid.create_grid( @@ -52,14 +69,6 @@ def test_grid_adapter_against_old(ds): time_origin=TimeConverter(ds.time.values[0]), mesh="spherical", ) - assert grid.lon.shape == adapter.lon.shape - assert grid.lat.shape == adapter.lat.shape - assert grid.depth.shape == adapter.depth.shape - assert grid.time.shape == adapter.time.shape - - assert_array_equal(grid.lon, adapter.lon) - assert_array_equal(grid.lat, adapter.lat) - assert_array_equal(grid.depth, adapter.depth) - assert_array_equal(grid.time, adapter.time) - - assert grid.time_origin == adapter.time_origin + actual = getattr(adapter, attr) + expected = getattr(grid, attr) + assert_equal(actual, expected) From 70588bcec80c269effc8074631239c497fb183fa Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 3 Apr 2025 17:24:03 +0200 Subject: [PATCH 26/39] Use lon, lat, depth arrays on underlying dataset --- parcels/v4/gridadapter.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/parcels/v4/gridadapter.py b/parcels/v4/gridadapter.py index 7b1fd5602e..369303550b 100644 --- a/parcels/v4/gridadapter.py +++ b/parcels/v4/gridadapter.py @@ -25,10 +25,6 @@ def get_dimensionality(axis: Axis | None) -> int: return pos_to_dim[pos](n) -def get_left_fpoints(axis: Axis) -> npt.NDArray: - return axis._ds[axis.coords["left"]].values - - def get_time(axis: Axis) -> npt.NDArray: return axis._ds[axis.coords["center"]].values @@ -40,26 +36,26 @@ def __init__(self, ds, *args, **kwargs): @property def lon(self): try: - axis = self.axes["X"] + _ = self.axes["X"] except KeyError: return np.zeros(1) - return get_left_fpoints(axis) + return self._ds["lon"].values @property def lat(self): try: - axis = self.axes["Y"] + _ = self.axes["Y"] except KeyError: return np.zeros(1) - return get_left_fpoints(axis) + return self._ds["lat"].values @property def depth(self): try: - axis = self.axes["Z"] + _ = self.axes["Z"] except KeyError: return np.zeros(1) - return get_left_fpoints(axis) + return self._ds["depth"].values @property def time(self): From faf7a6c942377b8c2855ee259fdf0d371404af33 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 3 Apr 2025 17:28:17 +0200 Subject: [PATCH 27/39] Add grid_type property to grid adapter --- parcels/v4/gridadapter.py | 19 +++++++++++++++++++ tests/v4/test_gridadapter.py | 15 +++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/parcels/v4/gridadapter.py b/parcels/v4/gridadapter.py index 369303550b..091f61328d 100644 --- a/parcels/v4/gridadapter.py +++ b/parcels/v4/gridadapter.py @@ -3,6 +3,7 @@ import numpy as np import numpy.typing as npt +from parcels.grid import CurvilinearSGrid, CurvilinearZGrid, RectilinearSGrid, RectilinearZGrid from parcels.tools.converters import TimeConverter from parcels.v4.grid import Axis, Grid @@ -98,6 +99,24 @@ def zonal_periodic(self): ... # ? hmmm @property def lonlat_minmax(self): ... # ? hmmm + @property + def grid_type(self): + """This class is created *purely* for compatibility with v3 code and will be removed + or changed in future. + + TODO: Remove + """ + if len(self.lon.shape) <= 1: + if self.depth is None or len(self.depth.shape) <= 1: + return RectilinearZGrid + else: + return RectilinearSGrid + else: + if self.depth is None or len(self.depth.shape) <= 1: + return CurvilinearZGrid + else: + return CurvilinearSGrid + @staticmethod def create_grid(lon, lat, depth, time, time_origin, mesh, **kwargs): ... # ? hmmm diff --git a/tests/v4/test_gridadapter.py b/tests/v4/test_gridadapter.py index f83d7812da..edbd5adbb4 100644 --- a/tests/v4/test_gridadapter.py +++ b/tests/v4/test_gridadapter.py @@ -72,3 +72,18 @@ def test_grid_adapter_against_old(ds, attr): actual = getattr(adapter, attr) expected = getattr(grid, attr) assert_equal(actual, expected) + + +@pytest.mark.parametrize("ds", datasets.values()) +def test_grid_adapter_against_old_grid_type(ds): + adapter = GridAdapter(ds, periodic=False) + + grid = OldGrid.create_grid( + lon=ds.lon.values, + lat=ds.lat.values, + depth=ds.depth.values, + time=ds.time.values, + time_origin=TimeConverter(ds.time.values[0]), + mesh="spherical", + ) + assert isinstance(grid, adapter.grid_type) From b434eaa6eea2ddd17eaa2350a734ab2b6f116781 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 3 Apr 2025 17:30:45 +0200 Subject: [PATCH 28/39] Rename test file --- tests/v4/{test_grid.py => test_grid_vendored.py} | 2 ++ 1 file changed, 2 insertions(+) rename tests/v4/{test_grid.py => test_grid_vendored.py} (99%) diff --git a/tests/v4/test_grid.py b/tests/v4/test_grid_vendored.py similarity index 99% rename from tests/v4/test_grid.py rename to tests/v4/test_grid_vendored.py index 7f6f28a905..c47422fbd6 100644 --- a/tests/v4/test_grid.py +++ b/tests/v4/test_grid_vendored.py @@ -1,3 +1,5 @@ +"""Test cases that have been vendored from xgcm.""" + import numpy as np import pytest import xarray as xr From 9fdd9933791a94f389ed609977d67ad52ecd7457 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 4 Apr 2025 15:34:55 +0200 Subject: [PATCH 29/39] Add mesh kwarg to adapter --- parcels/v4/gridadapter.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/parcels/v4/gridadapter.py b/parcels/v4/gridadapter.py index 091f61328d..4240fb416b 100644 --- a/parcels/v4/gridadapter.py +++ b/parcels/v4/gridadapter.py @@ -31,8 +31,9 @@ def get_time(axis: Axis) -> npt.NDArray: class GridAdapter(Grid): - def __init__(self, ds, *args, **kwargs): + def __init__(self, ds, mesh="flat", *args, **kwargs): super().__init__(ds, *args, **kwargs) + self.mesh = mesh @property def lon(self): From 1fd9b143acc941a8179fb41d570b50e5361ea960 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 4 Apr 2025 16:37:01 +0200 Subject: [PATCH 30/39] Remove mesh property from adapter --- parcels/v4/gridadapter.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/parcels/v4/gridadapter.py b/parcels/v4/gridadapter.py index 4240fb416b..9af6079a8d 100644 --- a/parcels/v4/gridadapter.py +++ b/parcels/v4/gridadapter.py @@ -91,9 +91,6 @@ def time_origin(self): def _z4d(self) -> Literal[0, 1]: return 1 if self.depth.shape == 4 else 0 - @property - def mesh(self): ... # ? hmmm - @property def zonal_periodic(self): ... # ? hmmm From 42d3715ed4559583fc1c025ae9216d366210047c Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 4 Apr 2025 16:45:24 +0200 Subject: [PATCH 31/39] Remove kwargs from Grid.create_grid --- parcels/grid.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/parcels/grid.py b/parcels/grid.py index 1163c7bda7..060e2fa355 100644 --- a/parcels/grid.py +++ b/parcels/grid.py @@ -111,7 +111,6 @@ def create_grid( time, time_origin, mesh: Mesh, - **kwargs, ): lon = np.array(lon) lat = np.array(lat) @@ -121,14 +120,14 @@ def create_grid( if len(lon.shape) <= 1: if depth is None or len(depth.shape) <= 1: - return RectilinearZGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh, **kwargs) + return RectilinearZGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh) else: - return RectilinearSGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh, **kwargs) + return RectilinearSGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh) else: if depth is None or len(depth.shape) <= 1: - return CurvilinearZGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh, **kwargs) + return CurvilinearZGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh) else: - return CurvilinearSGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh, **kwargs) + return CurvilinearSGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh) def _check_zonal_periodic(self): if self.zonal_periodic or self.mesh == "flat" or self.lon.size == 1: From 98436c681be38f068d3f643e18ae34f03ea8bc77 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 4 Apr 2025 16:50:49 +0200 Subject: [PATCH 32/39] Update isinstance to GridType check --- parcels/particleset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/parcels/particleset.py b/parcels/particleset.py index ec452f858b..d5e7800ea6 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -13,7 +13,7 @@ from parcels._compat import MPI from parcels.application_kernels.advection import AdvectionRK4 from parcels.field import Field -from parcels.grid import CurvilinearGrid, GridType +from parcels.grid import GridType from parcels.interaction.interactionkernel import InteractionKernel from parcels.interaction.neighborsearch import ( BruteFlatNeighborSearch, @@ -430,7 +430,7 @@ def populate_indices(self): may be quite expensive. """ for i, grid in enumerate(self.fieldset.gridset.grids): - if not isinstance(grid, CurvilinearGrid): + if grid._gtype != GridType.CurvilinearGrid: continue tree_data = np.stack((grid.lon.flat, grid.lat.flat), axis=-1) From d34a45c2263629fe1a149dadab908d537812b32e Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 4 Apr 2025 16:54:23 +0200 Subject: [PATCH 33/39] Update grid_type on adapter to use enum --- parcels/v4/gridadapter.py | 13 +++++++------ tests/v4/test_gridadapter.py | 16 +--------------- 2 files changed, 8 insertions(+), 21 deletions(-) diff --git a/parcels/v4/gridadapter.py b/parcels/v4/gridadapter.py index 9af6079a8d..f97ea17ba5 100644 --- a/parcels/v4/gridadapter.py +++ b/parcels/v4/gridadapter.py @@ -3,7 +3,6 @@ import numpy as np import numpy.typing as npt -from parcels.grid import CurvilinearSGrid, CurvilinearZGrid, RectilinearSGrid, RectilinearZGrid from parcels.tools.converters import TimeConverter from parcels.v4.grid import Axis, Grid @@ -98,22 +97,24 @@ def zonal_periodic(self): ... # ? hmmm def lonlat_minmax(self): ... # ? hmmm @property - def grid_type(self): + def _gtype(self): """This class is created *purely* for compatibility with v3 code and will be removed or changed in future. TODO: Remove """ + from parcels.grid import GridType + if len(self.lon.shape) <= 1: if self.depth is None or len(self.depth.shape) <= 1: - return RectilinearZGrid + return GridType.RectilinearZGrid else: - return RectilinearSGrid + return GridType.RectilinearSGrid else: if self.depth is None or len(self.depth.shape) <= 1: - return CurvilinearZGrid + return GridType.CurvilinearZGrid else: - return CurvilinearSGrid + return GridType.CurvilinearSGrid @staticmethod def create_grid(lon, lat, depth, time, time_origin, mesh, **kwargs): ... # ? hmmm diff --git a/tests/v4/test_gridadapter.py b/tests/v4/test_gridadapter.py index edbd5adbb4..44cd777b39 100644 --- a/tests/v4/test_gridadapter.py +++ b/tests/v4/test_gridadapter.py @@ -55,6 +55,7 @@ def test_grid_adapter_properties_ground_truth(ds, attr, expected): "zdim", "tdim", "time_origin", + "_gtype", ], ) @pytest.mark.parametrize("ds", datasets.values()) @@ -72,18 +73,3 @@ def test_grid_adapter_against_old(ds, attr): actual = getattr(adapter, attr) expected = getattr(grid, attr) assert_equal(actual, expected) - - -@pytest.mark.parametrize("ds", datasets.values()) -def test_grid_adapter_against_old_grid_type(ds): - adapter = GridAdapter(ds, periodic=False) - - grid = OldGrid.create_grid( - lon=ds.lon.values, - lat=ds.lat.values, - depth=ds.depth.values, - time=ds.time.values, - time_origin=TimeConverter(ds.time.values[0]), - mesh="spherical", - ) - assert isinstance(grid, adapter.grid_type) From bc0f7b089ede5ed3ab03027d8d6a36ce18ecdc6f Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 4 Apr 2025 16:54:40 +0200 Subject: [PATCH 34/39] Remove GridCode alias --- parcels/grid.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/parcels/grid.py b/parcels/grid.py index 060e2fa355..ed8c67e295 100644 --- a/parcels/grid.py +++ b/parcels/grid.py @@ -10,7 +10,6 @@ "CurvilinearSGrid", "CurvilinearZGrid", "Grid", - "GridCode", "GridType", "RectilinearSGrid", "RectilinearZGrid", @@ -24,11 +23,6 @@ class GridType(IntEnum): CurvilinearSGrid = 3 -# GridCode has been renamed to GridType for consistency. -# TODO: Remove alias in Parcels v4 -GridCode = GridType - - class Grid: """Grid class that defines a (spatial and temporal) grid on which Fields are defined.""" From 182a22b5e991db090c38d0ee49fe9c88493ffb43 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 4 Apr 2025 17:04:11 +0200 Subject: [PATCH 35/39] Add lonlat_minmax to GridAdapter --- parcels/v4/gridadapter.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/parcels/v4/gridadapter.py b/parcels/v4/gridadapter.py index f97ea17ba5..99e34a2c65 100644 --- a/parcels/v4/gridadapter.py +++ b/parcels/v4/gridadapter.py @@ -4,7 +4,8 @@ import numpy.typing as npt from parcels.tools.converters import TimeConverter -from parcels.v4.grid import Axis, Grid +from parcels.v4.grid import Axis +from parcels.v4.grid import Grid as NewGrid def get_dimensionality(axis: Axis | None) -> int: @@ -29,11 +30,20 @@ def get_time(axis: Axis) -> npt.NDArray: return axis._ds[axis.coords["center"]].values -class GridAdapter(Grid): +class GridAdapter(NewGrid): def __init__(self, ds, mesh="flat", *args, **kwargs): super().__init__(ds, *args, **kwargs) self.mesh = mesh + self.lonlat_minmax = np.array( + [ + np.nanmin(self._ds["lon"]), + np.nanmax(self._ds["lon"]), + np.nanmin(self._ds["lat"]), + np.nanmax(self._ds["lat"]), + ] + ) + @property def lon(self): try: @@ -93,9 +103,6 @@ def _z4d(self) -> Literal[0, 1]: @property def zonal_periodic(self): ... # ? hmmm - @property - def lonlat_minmax(self): ... # ? hmmm - @property def _gtype(self): """This class is created *purely* for compatibility with v3 code and will be removed From f52ef6380057a68ec9b668fe198959155f5548b9 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 10 Apr 2025 16:35:56 +0200 Subject: [PATCH 36/39] review feedback --- tests/vendor/xgcm_datasets.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/vendor/xgcm_datasets.py b/tests/vendor/xgcm_datasets.py index 17c46f29c1..6f8082ef9f 100644 --- a/tests/vendor/xgcm_datasets.py +++ b/tests/vendor/xgcm_datasets.py @@ -82,7 +82,6 @@ ), }, ), - # my own invention "1d_left": xr.Dataset( {"data_g": (["XG"], np.random.rand(N)), "data_c": (["XC"], np.random.rand(N))}, coords={ From f3475cc29c850baaec4b6e250f8a6cd92f65270e Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 11 Apr 2025 16:11:45 +0200 Subject: [PATCH 37/39] Move dataset functions --- {tests/v4 => parcels/_datasets/structured}/grid_datasets.py | 2 ++ tests/v4/__init__.py | 0 tests/v4/test_gridadapter.py | 2 +- 3 files changed, 3 insertions(+), 1 deletion(-) rename {tests/v4 => parcels/_datasets/structured}/grid_datasets.py (99%) delete mode 100644 tests/v4/__init__.py diff --git a/tests/v4/grid_datasets.py b/parcels/_datasets/structured/grid_datasets.py similarity index 99% rename from tests/v4/grid_datasets.py rename to parcels/_datasets/structured/grid_datasets.py index 72928fc604..7f95aff504 100644 --- a/tests/v4/grid_datasets.py +++ b/parcels/_datasets/structured/grid_datasets.py @@ -1,3 +1,5 @@ +"""Datasets focussing on grid geometry""" + import numpy as np import xarray as xr diff --git a/tests/v4/__init__.py b/tests/v4/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/v4/test_gridadapter.py b/tests/v4/test_gridadapter.py index 44cd777b39..1fdcc66745 100644 --- a/tests/v4/test_gridadapter.py +++ b/tests/v4/test_gridadapter.py @@ -4,10 +4,10 @@ import pytest from numpy.testing import assert_allclose +from parcels._datasets.structured.grid_datasets import N, T, datasets from parcels.grid import Grid as OldGrid from parcels.tools.converters import TimeConverter from parcels.v4.gridadapter import GridAdapter -from tests.v4.grid_datasets import N, T, datasets TestCase = namedtuple("TestCase", ["Grid", "attr", "expected"]) From 324047f599364ad594a303fc8a358367d8a39888 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 15 Apr 2025 10:47:51 +0200 Subject: [PATCH 38/39] Patch grid._gtype call and pixi run tests-notebooks --- parcels/particleset.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/parcels/particleset.py b/parcels/particleset.py index d5e7800ea6..2a8d625ca0 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -430,7 +430,7 @@ def populate_indices(self): may be quite expensive. """ for i, grid in enumerate(self.fieldset.gridset.grids): - if grid._gtype != GridType.CurvilinearGrid: + if grid._gtype not in [GridType.CurvilinearZGrid, GridType.CurvilinearSGrid]: continue tree_data = np.stack((grid.lon.flat, grid.lat.flat), axis=-1) diff --git a/pyproject.toml b/pyproject.toml index d890114e53..361d8a8d36 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ platforms = ["win-64", "linux-64", "osx-64", "osx-arm64"] [tool.pixi.tasks] tests = "pytest" -tests-notebooks = "pytest -v -s --nbval-lax -k 'not documentation and not tutorial_periodic_boundaries and not tutorial_timevaryingdepthdimensions and not tutorial_particle_field_interaction and not tutorial_croco_3D and not tutorial_nemo_3D and not tutorial_analyticaladvection'" # TODO v4: Mirror ci.yml for notebooks being run +tests-notebooks = "pytest -v -s --nbval-lax -k 'not documentation and not tutorial_periodic_boundaries and not tutorial_timevaryingdepthdimensions and not tutorial_particle_field_interaction and not tutorial_croco_3D and not tutorial_nemo_3D and not tutorial_analyticaladvection' docs/examples" # TODO v4: Mirror ci.yml for notebooks being run coverage = "coverage run -m pytest && coverage html" typing = "mypy parcels" pre-commit = "pre-commit run --all-files" From ca24d17028952483064504a181c11e1014dd9d98 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 15 Apr 2025 11:05:18 +0200 Subject: [PATCH 39/39] remove doubled up TimeConverter.__eq__ --- parcels/tools/converters.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/parcels/tools/converters.py b/parcels/tools/converters.py index 36cf0c1bf4..4b60204e1f 100644 --- a/parcels/tools/converters.py +++ b/parcels/tools/converters.py @@ -61,9 +61,6 @@ def __init__(self, time_origin: float | np.datetime64 | np.timedelta64 | cftime. elif isinstance(time_origin, cftime.datetime): self.calendar = time_origin.calendar - def __eq__(self, other): - return self.time_origin == other.time_origin and self.calendar == other.calendar - def reltime(self, time: TimeConverter | np.datetime64 | np.timedelta64 | cftime.datetime) -> float | npt.NDArray: """Method to compute the difference, in seconds, between a time and the time_origin of the TimeConverter