diff --git a/parcels/_datasets/structured/grid_datasets.py b/parcels/_datasets/structured/grid_datasets.py new file mode 100644 index 0000000000..7f95aff504 --- /dev/null +++ b/parcels/_datasets/structured/grid_datasets.py @@ -0,0 +1,169 @@ +"""Datasets focussing on grid geometry""" + +import numpy as np +import xarray as xr + +N = 30 +T = 10 + + +def rotated_curvilinear_grid(): + XG = np.arange(N) + YG = np.arange(2 * N) + LON, LAT = np.meshgrid(XG, YG) + + angle = -np.pi / 24 + rotation = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]]) + + # rotate the LON and LAT grids + LON, LAT = np.einsum("ji, mni -> jmn", rotation, np.dstack([LON, LAT])) + + return xr.Dataset( + { + "data_g": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)), + "data_c": (["ZC", "YC", "XC"], np.random.rand(3 * N, 2 * N, N)), + }, + coords={ + "XG": (["XG"], XG, {"axis": "X", "c_grid_axis_shift": -0.5}), + "YG": (["YG"], YG, {"axis": "Y", "c_grid_axis_shift": -0.5}), + "XC": (["XC"], XG + 0.5, {"axis": "X"}), + "YC": (["YC"], YG + 0.5, {"axis": "Y"}), + "ZG": ( + ["ZG"], + np.arange(3 * N), + {"axis": "Z", "c_grid_axis_shift": -0.5}, + ), + "ZC": ( + ["ZC"], + np.arange(3 * N) + 0.5, + {"axis": "Z"}, + ), + "depth": (["ZG"], np.arange(3 * N), {"axis": "Z"}), + "time": (["time"], np.arange(T), {"axis": "T"}), + "lon": ( + ["YG", "XG"], + LON, + {"axis": "X", "c_grid_axis_shift": -0.5}, # ? Needed? + ), + "lat": ( + ["YG", "XG"], + LAT, + {"axis": "Y", "c_grid_axis_shift": -0.5}, # ? Needed? + ), + }, + ) + + +def _cartesion_to_polar(x, y): + r = np.sqrt(x**2 + y**2) + theta = np.arctan2(y, x) + return r, theta + + +def _polar_to_cartesian(r, theta): + x = r * np.cos(theta) + y = r * np.sin(theta) + return x, y + + +def unrolled_cone_curvilinear_grid(): + # Not a great unrolled cone, but this is good enough for testing + # you can use matplotlib pcolormesh to plot + XG = np.arange(N) + YG = np.arange(2 * N) * 0.25 + + pivot = -10, 0 + LON, LAT = np.meshgrid(XG, YG) + + new_lon_lat = [] + + min_lon = np.min(XG) + for lon, lat in zip(LON.flatten(), LAT.flatten(), strict=True): + r, _ = _cartesion_to_polar(lon - pivot[0], lat - pivot[1]) + _, theta = _cartesion_to_polar(min_lon - pivot[0], lat - pivot[1]) + theta *= 1.2 + r *= 1.2 + lon, lat = _polar_to_cartesian(r, theta) + new_lon_lat.append((lon + pivot[0], lat + pivot[1])) + + new_lon, new_lat = zip(*new_lon_lat, strict=True) + LON, LAT = np.array(new_lon).reshape(LON.shape), np.array(new_lat).reshape(LAT.shape) + + return xr.Dataset( + { + "data_g": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)), + "data_c": (["ZC", "YC", "XC"], np.random.rand(3 * N, 2 * N, N)), + }, + coords={ + "XG": (["XG"], XG, {"axis": "X", "c_grid_axis_shift": -0.5}), + "YG": (["YG"], YG, {"axis": "Y", "c_grid_axis_shift": -0.5}), + "XC": (["XC"], XG + 0.5, {"axis": "X"}), + "YC": (["YC"], YG + 0.5, {"axis": "Y"}), + "ZG": ( + ["ZG"], + np.arange(3 * N), + {"axis": "Z", "c_grid_axis_shift": -0.5}, + ), + "ZC": ( + ["ZC"], + np.arange(3 * N) + 0.5, + {"axis": "Z"}, + ), + "depth": (["ZG"], np.arange(3 * N), {"axis": "Z"}), + "time": (["time"], np.arange(T), {"axis": "T"}), + "lon": ( + ["YG", "XG"], + LON, + {"axis": "X", "c_grid_axis_shift": -0.5}, # ? Needed? + ), + "lat": ( + ["YG", "XG"], + LAT, + {"axis": "Y", "c_grid_axis_shift": -0.5}, # ? Needed? + ), + }, + ) + + +datasets = { + "2d_left_rotated": rotated_curvilinear_grid(), + "ds_2d_left": xr.Dataset( + { + "data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, 3 * N, 2 * N, N)), + "data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, 3 * N, 2 * N, N)), + }, + coords={ + "XG": ( + ["XG"], + 2 * np.pi / N * np.arange(0, N), + {"axis": "X", "c_grid_axis_shift": -0.5}, + ), + "XC": (["XC"], 2 * np.pi / N * (np.arange(0, N) + 0.5), {"axis": "X"}), + "YG": ( + ["YG"], + 2 * np.pi / (2 * N) * np.arange(0, 2 * N), + {"axis": "Y", "c_grid_axis_shift": -0.5}, + ), + "YC": ( + ["YC"], + 2 * np.pi / (2 * N) * (np.arange(0, 2 * N) + 0.5), + {"axis": "Y"}, + ), + "ZG": ( + ["ZG"], + np.arange(3 * N), + {"axis": "Z", "c_grid_axis_shift": -0.5}, + ), + "ZC": ( + ["ZC"], + np.arange(3 * N) + 0.5, + {"axis": "Z"}, + ), + "lon": (["XG"], 2 * np.pi / N * np.arange(0, N)), + "lat": (["YG"], 2 * np.pi / (2 * N) * np.arange(0, 2 * N)), + "depth": (["ZG"], np.arange(3 * N)), + "time": (["time"], np.arange(T), {"axis": "T"}), + }, + ), + "2d_left_unrolled_cone": unrolled_cone_curvilinear_grid(), +} diff --git a/parcels/grid.py b/parcels/grid.py index 9332dcb909..ed8c67e295 100644 --- a/parcels/grid.py +++ b/parcels/grid.py @@ -10,7 +10,6 @@ "CurvilinearSGrid", "CurvilinearZGrid", "Grid", - "GridCode", "GridType", "RectilinearSGrid", "RectilinearZGrid", @@ -24,11 +23,6 @@ class GridType(IntEnum): CurvilinearSGrid = 3 -# GridCode has been renamed to GridType for consistency. -# TODO: Remove alias in Parcels v4 -GridCode = GridType - - class Grid: """Grid class that defines a (spatial and temporal) grid on which Fields are defined.""" @@ -40,7 +34,6 @@ def __init__( time_origin: TimeConverter | None, mesh: Mesh, ): - self._ti = -1 lon = np.array(lon) lat = np.array(lat) time = np.zeros(1, dtype=np.float64) if time is None else time @@ -112,7 +105,6 @@ def create_grid( time, time_origin, mesh: Mesh, - **kwargs, ): lon = np.array(lon) lat = np.array(lat) @@ -122,14 +114,14 @@ def create_grid( if len(lon.shape) <= 1: if depth is None or len(depth.shape) <= 1: - return RectilinearZGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh, **kwargs) + return RectilinearZGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh) else: - return RectilinearSGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh, **kwargs) + return RectilinearSGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh) else: if depth is None or len(depth.shape) <= 1: - return CurvilinearZGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh, **kwargs) + return CurvilinearZGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh) else: - return CurvilinearSGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh, **kwargs) + return CurvilinearSGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh) def _check_zonal_periodic(self): if self.zonal_periodic or self.mesh == "flat" or self.lon.size == 1: diff --git a/parcels/particleset.py b/parcels/particleset.py index ec452f858b..2a8d625ca0 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -13,7 +13,7 @@ from parcels._compat import MPI from parcels.application_kernels.advection import AdvectionRK4 from parcels.field import Field -from parcels.grid import CurvilinearGrid, GridType +from parcels.grid import GridType from parcels.interaction.interactionkernel import InteractionKernel from parcels.interaction.neighborsearch import ( BruteFlatNeighborSearch, @@ -430,7 +430,7 @@ def populate_indices(self): may be quite expensive. """ for i, grid in enumerate(self.fieldset.gridset.grids): - if not isinstance(grid, CurvilinearGrid): + if grid._gtype not in [GridType.CurvilinearZGrid, GridType.CurvilinearSGrid]: continue tree_data = np.stack((grid.lon.flat, grid.lat.flat), axis=-1) diff --git a/parcels/v4/__init__.py b/parcels/v4/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/parcels/v4/comodo.py b/parcels/v4/comodo.py new file mode 100644 index 0000000000..cf82d50330 --- /dev/null +++ b/parcels/v4/comodo.py @@ -0,0 +1,128 @@ +from collections import OrderedDict + +# Representation of axis shifts +axis_shift_left = -0.5 +axis_shift_right = 0.5 +axis_shift_center = 0 +# Characterizes valid shifts only +valid_axis_shifts = [axis_shift_left, axis_shift_right, axis_shift_center] + + +def assert_valid_comodo(ds): + """Verify that the dataset meets comodo conventions + + Parameters + ---------- + ds : xarray.dataset + """ + # TODO: implement + assert True + + +def get_all_axes(ds): + axes = set() + for d in ds.dims: + if "axis" in ds[d].attrs: + axes.add(ds[d].attrs["axis"]) + return axes + + +def get_axis_coords(ds, axis_name): + """Find the name of the coordinates associated with a comodo axis. + + Parameters + ---------- + ds : xarray.dataset or xarray.dataarray + axis_name : str + The name of the axis to find (e.g. 'X') + + Returns + ------- + coord_name : list + The names of the coordinate matching that axis + """ + coord_names = [] + for d in ds.dims: + axis = ds[d].attrs.get("axis") + if axis == axis_name: + coord_names.append(d) + return coord_names + + +def get_axis_positions_and_coords(ds, axis_name): + coord_names = get_axis_coords(ds, axis_name) + ncoords = len(coord_names) + if ncoords == 0: + # didn't find anything for this axis + raise ValueError(f"Couldn't find any coordinates for axis {axis_name}") + + # now figure out what type of coordinates these are: + # center, left, right, or outer + coords = {name: ds[name] for name in coord_names} + + # some tortured logic for dealing with malformed c_grid_axis_shift + # attributes such as produced by old versions of xmitgcm. + # This should be a float (either -0.5 or 0.5) + # this function returns that, or True of the attribute is set to + # anything at all + def _maybe_fix_type(attr): + if attr is not None: + try: + return float(attr) + except TypeError: + return True + + axis_shift = {name: _maybe_fix_type(coord.attrs.get("c_grid_axis_shift")) for name, coord in coords.items()} + coord_len = {name: len(coord) for name, coord in coords.items()} + + # look for the center coord, which is required + # this list will potentially contain "center", "inner", and "outer" points + coords_without_axis_shift = {name: coord_len[name] for name, shift in axis_shift.items() if not shift} + if len(coords_without_axis_shift) == 0: + raise ValueError(f"Couldn't find a center coordinate for axis {axis_name}") + elif len(coords_without_axis_shift) > 1: + raise ValueError(f"Found two coordinates without `c_grid_axis_shift` attribute for axis {axis_name}") + center_coord_name = list(coords_without_axis_shift)[0] + # the length of the center coord is key to decoding the other coords + axis_len = coord_len[center_coord_name] + + # now we can start filling in the information about the different coords + axis_coords = OrderedDict() + axis_coords["center"] = center_coord_name + + # now check the other coords + coord_names.remove(center_coord_name) + for name in coord_names: + shift = axis_shift[name] + clen = coord_len[name] + if clen == axis_len + 1: + axis_coords["outer"] = name + elif clen == axis_len - 1: + axis_coords["inner"] = name + elif shift == axis_shift_left: + if clen == axis_len: + axis_coords["left"] = name + else: + raise ValueError(f"Left coordinate {name} has incompatible length {clen:g} (axis_len={axis_len:g})") + elif shift == axis_shift_right: + if clen == axis_len: + axis_coords["right"] = name + else: + raise ValueError(f"Right coordinate {name} has incompatible length {clen:g} (axis_len={axis_len:g})") + else: + if shift not in valid_axis_shifts: + # string representing valid axis shifts + valids = str(valid_axis_shifts)[1:-1] + + raise ValueError( + f"Coordinate {name} has invalid " + f"`c_grid_axis_shift` attribute `{shift!r}`. " + f"`c_grid_axis_shift` must be one of: {valids}" + ) + else: + raise ValueError(f"Coordinate {name} has missing `c_grid_axis_shift` attribute `{shift!r}`") + return axis_coords + + +def _assert_data_on_grid(da): + pass diff --git a/parcels/v4/grid.py b/parcels/v4/grid.py new file mode 100644 index 0000000000..382cfa6511 --- /dev/null +++ b/parcels/v4/grid.py @@ -0,0 +1,597 @@ +"""This Grid object is adapted from xgcm.Grid, removing a lot of the code that is not needed for Parcels.""" + +import warnings +from collections import OrderedDict +from collections.abc import Iterable +from typing import ( + Any, +) + +from . import comodo + +_VALID_BOUNDARY = [None, "fill", "extend", "periodic"] + + +def _maybe_promote_str_to_list(a): + # TODO: improve this + if isinstance(a, str): + return [a] + else: + return a + + +class Axis: + """ + An object that represents a group of coordinates that all lie along the same + physical dimension but at different positions with respect to a grid cell. + There are four possible positions: + + Center + |------o-------|------o-------|------o-------|------o-------| + [0] [1] [2] [3] + + Left + |------o-------|------o-------|------o-------|------o-------| + [0] [1] [2] [3] + + Right + |------o-------|------o-------|------o-------|------o-------| + [0] [1] [2] [3] + + Inner + |------o-------|------o-------|------o-------|------o-------| + [0] [1] [2] + + Outer + |------o-------|------o-------|------o-------|------o-------| + [0] [1] [2] [3] [4] + + The `center` position is the only one without the `c_grid_axis_shift` + attribute, which must be present for the other four. However, the actual + value of `c_grid_axis_shift` is ignored for `inner` and `outer`, which are + differentiated by their length. + """ + + def __init__( + self, + ds, + axis_name, + periodic=True, + default_shifts=None, + coords=None, + boundary=None, + fill_value=None, + ): + """ + Create a new Axis object from an input dataset. + + Parameters + ---------- + ds : xarray.Dataset + Contains the relevant grid information. Coordinate attributes + should conform to Comodo conventions [1]_. + axis_name : str + The name of the axis (should match axis attribute) + periodic : bool, optional + Whether the domain is periodic along this axis + default_shifts : dict, optional + Default mapping from and to grid positions + (e.g. `{'center': 'left'}`). Will be inferred if not specified. + coords : dict, optional + Mapping of axis positions to coordinate names + (e.g. `{'center': 'XC', 'left: 'XG'}`) + boundary : str or dict, optional, + boundary can either be one of {None, 'fill', 'extend', 'extrapolate', 'periodic'} + + * None: Do not apply any boundary conditions. Raise an error if + boundary conditions are required for the operation. + * 'fill': Set values outside the array boundary to fill_value + (i.e. a Dirichlet boundary condition.) + * 'extend': Set values outside the array to the nearest array + value. (i.e. a limited form of Neumann boundary condition where + the difference at the boundary will be zero.) + * 'extrapolate': Set values by extrapolating linearly from the two + points nearest to the edge + * 'periodic' : Wrap arrays around. Equivalent to setting `periodic=True` + This sets the default value. It can be overriden by specifying the + boundary kwarg when calling specific methods. + fill_value : float, optional + The value to use in the boundary condition when `boundary='fill'`. + + References + ---------- + .. [1] Comodo Conventions https://web.archive.org/web/20160417032300/http://pycomodo.forge.imag.fr/norm.html + """ + if default_shifts is None: + default_shifts = {} + self._ds = ds + self.name = axis_name + self._periodic = periodic + if boundary not in _VALID_BOUNDARY: + raise ValueError(f"Expected 'boundary' to be one of {_VALID_BOUNDARY}. Received {boundary!r} instead.") + self.boundary = boundary + if fill_value is not None and not isinstance(fill_value, (int, float)): + raise ValueError("Expected 'fill_value' to be a number.") + self.fill_value = fill_value if fill_value is not None else 0.0 + + if coords: + # use specified coords + self.coords = {pos: name for pos, name in coords.items()} + else: + # fall back on comodo conventions + self.coords = comodo.get_axis_positions_and_coords(ds, axis_name) + + # self.coords is a dictionary with the following structure + # key: position_name {'center' ,'left' ,'right', 'outer', 'inner'} + # value: name of the dimension + + # set default position shifts + fallback_shifts = { + "center": ("left", "right", "outer", "inner"), + "left": ("center",), + "right": ("center",), + "outer": ("center",), + "inner": ("center",), + } + self._default_shifts = {} + for pos in self.coords: + # use user-specified value if present + if pos in default_shifts: + self._default_shifts[pos] = default_shifts[pos] + else: + for possible_shift in fallback_shifts[pos]: + if possible_shift in self.coords: + self._default_shifts[pos] = possible_shift + break + + ######################################################################## + # DEVELOPER DOCUMENTATION + # + # The attributes below are my best attempt to represent grid topology + # in a general way. The data structures are complicated, but I can't + # think of any way to simplify them. + # + # self._facedim (str) is the name of a dimension (e.g. 'face') or None. + # If it is None, that means that the grid topology is _simple_, i.e. + # that this is not a cubed-sphere grid or similar. For example: + # + # ds.dims == ('time', 'lat', 'lon') + # + # If _facedim is set to a dimension name, that means that shifting + # grid positions requires exchanging data among multiple "faces" + # (a.k.a. "tiles", "facets", etc.). For this to work, there must be a + # dimension corresponding to the different faces. This is `_facedim`. + # For example: + # + # ds.dims == ('time', 'face', 'lat', 'lon') + # + # In this case, `self._facedim == 'face'` + # + # We initialize all of this to None and let the `Grid` class handle + # setting these attributes for complex geometries. + self._facedim = None + # + # `self._connections` is a dictionary. It contains information about the + # connectivity among this axis and other axes. + # It should have the structure + # + # {facedim_index: ((left_facedim_index, left_axis, left_reverse), + # (right_facedim_index, right_axis, right_reverse)} + # + # `facedim_index` : a value used to index the `self._facedim` dimension + # (If `self._facedim` is `None`, then there should be only one key in + # `facedim_index` and that key should be `None`.) + # `left_facedim_index` : the facedim index of the neighbor to the left. + # (If `self._facedim` is `None`, this must also be `None`.) + # `left_axis` : an `Axis` object for the values to the left of this axis + # `left_reverse` : bool, whether the connection should be reversed. By + # default, the left side of this axis will be connected to the right + # side of the neighboring axis. `left_reverse` overrides this and + # instead connects to the left side of the neighboring axis + self._connections = {None: (None, None)} + + # now we implement periodic coordinates by setting appropriate + # connections + if periodic: + self._connections = {None: ((None, self, False), (None, self, False))} + + def __repr__(self): + is_periodic = "periodic" if self._periodic else "not periodic" + summary = [f""] + summary.append("Axis Coordinates:") + summary += self._coord_desc() + return "\n".join(summary) + + def _coord_desc(self): + summary = [] + for name, cname in self.coords.items(): + coord_info = f" * {name:<8} {cname}" + if name in self._default_shifts: + coord_info += f" --> {self._default_shifts[name]}" + summary.append(coord_info) + return summary + + def _get_position_name(self, da): + """Return the position and name of the axis coordinate in a DataArray.""" + for position, coord_name in self.coords.items(): + # TODO: should we have more careful checking of alignment here? + if coord_name in da.dims: + return position, coord_name + + raise KeyError(f"None of the DataArray's dims {da.dims!r} were found in axis coords.") + + def _get_axis_dim_num(self, da): + """Return the dimension number of the axis coordinate in a DataArray.""" + _, coord_name = self._get_position_name(da) + return da.get_axis_num(coord_name) + + +class Grid: + """ + An object with multiple :class:`parcels.Axis` objects representing different + independent axes. + """ + + def __init__( + self, + ds, + check_dims=True, + periodic=True, + default_shifts=None, + face_connections=None, + coords=None, + metrics=None, + boundary=None, + fill_value=None, + ): + """ + Create a new Grid object from an input dataset. + + Parameters + ---------- + ds : xarray.Dataset + Contains the relevant grid information. Coordinate attributes + should conform to Comodo conventions [1]_. + check_dims : bool, optional + Whether to check the compatibility of input data dimensions before + performing grid operations. + periodic : {True, False, list} + Whether the grid is periodic (i.e. "wrap-around"). If a list is + specified (e.g. ``['X', 'Y']``), the axis names in the list will be + be periodic and any other axes founds will be assumed non-periodic. + default_shifts : dict + A dictionary of dictionaries specifying default grid position + shifts (e.g. ``{'X': {'center': 'left', 'left': 'center'}}``) + face_connections : dict + Grid topology + coords : dict, optional + Specifies positions of dimension names along axes X, Y, Z, e.g + ``{'X': {'center': 'XC', 'left: 'XG'}}``. + Each key should be an axis name (e.g., `X`, `Y`, or `Z`) and map + to a dictionary which maps positions (`center`, `left`, `right`, + `outer`, `inner`) to dimension names in the dataset + (in the example above, `XC` is at the `center` position and `XG` + at the `left` position along the `X` axis). + If the values are not present in ``ds`` or are not dimensions, + an error will be raised. + metrics : dict, optional + Specification of grid metrics mapping axis names (X, Y, Z) to corresponding + metric variable names in the dataset + (e.g. {('X',):['dx_t'], ('X', 'Y'):['area_tracer', 'area_u']} + for the cell distance in the x-direction ``dx_t`` and the + horizontal cell areas ``area_tracer`` and ``area_u``, located at + different grid positions). + boundary : {None, 'fill', 'extend', 'extrapolate', dict}, optional + A flag indicating how to handle boundaries: + + * None: Do not apply any boundary conditions. Raise an error if + boundary conditions are required for the operation. + * 'fill': Set values outside the array boundary to fill_value + (i.e. a Dirichlet boundary condition.) + * 'extend': Set values outside the array to the nearest array + value. (i.e. a limited form of Neumann boundary condition.) + * 'extrapolate': Set values by extrapolating linearly from the two + points nearest to the edge + Optionally a dict mapping axis name to seperate values for each axis + can be passed. + fill_value : {float, dict}, optional + The value to use in boundary conditions with `boundary='fill'`. + Optionally a dict mapping axis name to seperate values for each axis + can be passed. + + References + ---------- + .. [1] Comodo Conventions https://web.archive.org/web/20160417032300/http://pycomodo.forge.imag.fr/norm.html + """ + if default_shifts is None: + default_shifts = {} + self._ds = ds + self._check_dims = check_dims + + if boundary: + warnings.warn( + "The `boundary` argument will be renamed " + "to `padding` to better reflect the process " + "of array padding and avoid confusion with " + "physical boundary conditions (e.g. ocean land boundary).", + category=DeprecationWarning, + stacklevel=2, + ) + + # Deprecation Warnigns + if periodic: + warnings.warn( + "The `periodic` argument will be deprecated. " + "To preserve previous behavior supply `boundary = 'periodic'.", + category=DeprecationWarning, + stacklevel=2, + ) + + if fill_value: + warnings.warn( + "The default fill_value will be changed to nan (from 0.0 previously) " + "in future versions. Provide `fill_value=0.0` to preserve previous behavior.", + category=DeprecationWarning, + stacklevel=2, + ) + + extrapolate_warning = False + if boundary == "extrapolate": + extrapolate_warning = True + if isinstance(boundary, dict): + if any([k == "extrapolate" for k in boundary.keys()]): + extrapolate_warning = True + if extrapolate_warning: + warnings.warn( + "The `boundary='extrapolate'` option will no longer be supported in future releases.", + category=DeprecationWarning, + stacklevel=2, + ) + + if coords: + all_axes = coords.keys() + else: + all_axes = comodo.get_all_axes(ds) + coords = {} + + # check coords input validity + for axis, positions in coords.items(): + for pos, dim in positions.items(): + if not (dim in ds.variables or dim in ds.dims): + raise ValueError( + f"Could not find dimension `{dim}` (for the `{pos}` position on axis `{axis}`) in input dataset." + ) + if dim not in ds.dims: + raise ValueError( + f"Input `{dim}` (for the `{pos}` position on axis `{axis}`) is not a dimension in the input datasets `ds`." + ) + + # Convert all inputs to axes-kwarg mappings + # TODO We need a way here to check valid input. Maybe also in _as_axis_kwargs? + # Parse axis properties + boundary = self._as_axis_kwarg_mapping(boundary, axes=all_axes) + fill_value = self._as_axis_kwarg_mapping(fill_value, axes=all_axes) + # TODO: In the future we want this the only place where we store these. + # TODO: This info needs to then be accessible to e.g. pad() + + # Parse list input. This case does only apply to periodic. + # Since we plan on deprecating it soon handle it here, so we can easily + # remove it later + if isinstance(periodic, list): + periodic = {axname: True for axname in periodic} + periodic = self._as_axis_kwarg_mapping(periodic, axes=all_axes) + + # Set properties on grid object. + self._facedim = list(face_connections.keys())[0] if face_connections else None + self._connections = face_connections if face_connections else None + # TODO: I think of the face connection data as grid not axes properties, since they almost by defintion + # TODO: involve multiple axes. In a future PR we should remove this info from the axes + # TODO: but make sure to properly port the checking functionality! + + # Populate axes. Much of this is just for backward compatibility. + self.axes = OrderedDict() + for axis_name in all_axes: + # periodic + is_periodic = periodic.get(axis_name, False) + + # default_shifts + if axis_name in default_shifts: + axis_default_shifts = default_shifts[axis_name] + else: + axis_default_shifts = {} + + # boundary + if isinstance(boundary, dict): + axis_boundary = boundary.get(axis_name, None) + elif isinstance(boundary, str) or boundary is None: + axis_boundary = boundary + else: + raise ValueError( + f"boundary={boundary} is invalid. Please specify a dictionary " + "mapping axis name to a boundary option; a string or None." + ) + + if isinstance(fill_value, dict): + axis_fillvalue = fill_value.get(axis_name, None) # TODO: This again sets defaults. Dont do that here. + elif isinstance(fill_value, (int, float)) or fill_value is None: + axis_fillvalue = fill_value + else: + raise ValueError( + f"fill_value={fill_value} is invalid. Please specify a dictionary " + "mapping axis name to a boundary option; a number or None." + ) + + self.axes[axis_name] = Axis( + ds, + axis_name, + is_periodic, + default_shifts=axis_default_shifts, + coords=coords.get(axis_name), + boundary=axis_boundary, + fill_value=axis_fillvalue, + ) + + if face_connections is not None: + self._assign_face_connections(face_connections) + + self._metrics = {} + + if metrics is not None: + for key, value in metrics.items(): + self.set_metrics(key, value) + + def _as_axis_kwarg_mapping( + self, + kwargs: Any | dict[str, Any], + axes: Iterable[str] | None = None, + ax_property_name=None, + default_value: Any | None = None, + ) -> dict[str, Any]: + """Convert kwarg input into dict for each available axis + E.g. for a grid with 2 axes for the keyword argument `periodic` + periodic = True --> periodic = {'X': True, 'Y':True} + or if not all axes are provided, the other axes will be parsed as defaults (None) + periodic = {'X':True} --> periodic={'X': True, 'Y':None} + """ + if axes is None: + axes = self.axes + + parsed_kwargs: dict[str, Any] = dict() + + if isinstance(kwargs, dict): + parsed_kwargs = kwargs + else: + for axname in axes: + parsed_kwargs[axname] = kwargs + + # Check axis properties for values that were not provided (before using the default) + if ax_property_name is not None: + for axname in axes: + if axname not in parsed_kwargs.keys() or parsed_kwargs[axname] is None: + ax_property = getattr(self.axes[axname], ax_property_name) + parsed_kwargs[axname] = ax_property + + # if None set to default value. + parsed_kwargs_w_defaults = {k: default_value if v is None else v for k, v in parsed_kwargs.items()} + # At this point the output should be guaranteed to have an entry per existing axis. + # If neither a default value was given, nor an axis property was found, the value will be mapped to None. + + # temporary hack to get periodic conditions from axis + if ax_property_name == "boundary": + for axname in axes: + if self.axes[axname]._periodic: + if axname not in parsed_kwargs_w_defaults.keys(): + parsed_kwargs_w_defaults[axname] = "periodic" + + return parsed_kwargs_w_defaults + + def _assign_face_connections(self, fc): + """Check a dictionary of face connections to make sure all the links are + consistent. + """ + if len(fc) > 1: + raise ValueError(f"Only one face dimension is supported for now. Instead found {repr(fc.keys())!r}") + + # we will populate this with the axes we find in face_connections + axis_connections = {} + + facedim = list(fc.keys())[0] + assert facedim in self._ds + + face_links = fc[facedim] + for fidx, face_axis_links in face_links.items(): + for axis, axis_links in face_axis_links.items(): + # initialize the axis dict if necssary + if axis not in axis_connections: + axis_connections[axis] = {} + link_left, link_right = axis_links + + def check_neighbor(link, position): + if link is None: + return + idx, ax, rev = link + # need to swap position if the link is reversed + correct_position = int(not position) if rev else position + try: + neighbor_link = face_links[idx][ax][correct_position] + except (KeyError, IndexError): + raise KeyError( + f"Couldn't find a face link for face {idx!r}" + f"in axis {ax!r} at position {correct_position!r}" + ) + idx_n, ax_n, rev_n = neighbor_link + if ax not in self.axes: + raise KeyError(f"axis {ax!r} is not a valid axis") + if ax_n not in self.axes: + raise KeyError(f"axis {ax_n!r} is not a valid axis") + if idx not in self._ds[facedim].values: + raise IndexError(f"{idx!r} is not a valid index for face dimension {facedim!r}") + if idx_n not in self._ds[facedim].values: + raise IndexError(f"{idx!r} is not a valid index for face dimension {facedim!r}") + # check for consistent links from / to neighbor + if (idx_n != fidx) or (ax_n != axis) or (rev_n != rev): # noqa: B023 # TODO: fix? + raise ValueError( + "Face link mismatch: neighbor doesn't" + " correctly link back to this face. " + f"face: {fidx!r}, axis: {axis!r}, position: {position!r}, " # noqa: B023 # TODO: fix? + f"rev: {rev!r}, link: {link!r}, neighbor_link: {neighbor_link!r}" + ) + # convert the axis name to an acutal axis object + actual_axis = self.axes[ax] + return idx, actual_axis, rev + + left = check_neighbor(link_left, 1) + right = check_neighbor(link_right, 0) + axis_connections[axis][fidx] = (left, right) + + for axis, axis_links in axis_connections.items(): + self.axes[axis]._facedim = facedim + self.axes[axis]._connections = axis_links + + def set_metrics(self, key, value, overwrite=False): + metric_axes = frozenset(_maybe_promote_str_to_list(key)) + axes_not_found = [ma for ma in metric_axes if ma not in self.axes] + if len(axes_not_found) > 0: + raise KeyError(f"Metric axes {axes_not_found!r} not compatible with grid axes {tuple(self.axes)!r}") + + metric_value = _maybe_promote_str_to_list(value) + for metric_varname in metric_value: + if metric_varname not in self._ds.variables: + raise KeyError(f"Metric variable {metric_varname} not found in dataset.") + + existing_metric_axes = set(self._metrics.keys()) + if metric_axes in existing_metric_axes: + value_exist = self._metrics.get(metric_axes) + # resetting coords avoids potential broadcasting / alignment issues + value_new = self._ds[metric_varname].reset_coords(drop=True) + did_overwrite = False + # go through each existing value until data array with matching dimensions is selected + for idx, ve in enumerate(value_exist): + # double check if dimensions match + if set(value_new.dims) == set(ve.dims): + if overwrite: + # replace existing data array with new data array input + self._metrics[metric_axes][idx] = value_new + did_overwrite = True + else: + raise ValueError( + f"Metric variable {ve.name} with dimensions {ve.dims} already assigned in metrics." + f" Overwrite {ve.name} with {metric_varname} by setting overwrite=True." + ) + # if no existing value matches new value dimension-wise, just append new value + if not did_overwrite: + self._metrics[metric_axes].append(value_new) + else: + # no existing metrics for metric_axes yet; initialize empty list + self._metrics[metric_axes] = [] + for metric_varname in metric_value: + metric_var = self._ds[metric_varname].reset_coords(drop=True) + self._metrics[metric_axes].append(metric_var) + + def __repr__(self): + summary = [""] + for name, axis in self.axes.items(): + is_periodic = "periodic" if axis._periodic else "not periodic" + summary.append(f"{name} Axis ({is_periodic}, boundary={axis.boundary!r}):") + summary += axis._coord_desc() + return "\n".join(summary) diff --git a/parcels/v4/gridadapter.py b/parcels/v4/gridadapter.py new file mode 100644 index 0000000000..99e34a2c65 --- /dev/null +++ b/parcels/v4/gridadapter.py @@ -0,0 +1,131 @@ +from typing import Literal + +import numpy as np +import numpy.typing as npt + +from parcels.tools.converters import TimeConverter +from parcels.v4.grid import Axis +from parcels.v4.grid import Grid as NewGrid + + +def get_dimensionality(axis: Axis | None) -> int: + if axis is None: + return 1 + first_coord = list(axis.coords.items())[0] + pos, coord = first_coord + + pos_to_dim = { # TODO: These could do with being explicitly tested + "center": lambda x: x, + "left": lambda x: x, + "right": lambda x: x, + "inner": lambda x: x + 1, + "outer": lambda x: x - 1, + } + + n = axis._ds[coord].size + return pos_to_dim[pos](n) + + +def get_time(axis: Axis) -> npt.NDArray: + return axis._ds[axis.coords["center"]].values + + +class GridAdapter(NewGrid): + def __init__(self, ds, mesh="flat", *args, **kwargs): + super().__init__(ds, *args, **kwargs) + self.mesh = mesh + + self.lonlat_minmax = np.array( + [ + np.nanmin(self._ds["lon"]), + np.nanmax(self._ds["lon"]), + np.nanmin(self._ds["lat"]), + np.nanmax(self._ds["lat"]), + ] + ) + + @property + def lon(self): + try: + _ = self.axes["X"] + except KeyError: + return np.zeros(1) + return self._ds["lon"].values + + @property + def lat(self): + try: + _ = self.axes["Y"] + except KeyError: + return np.zeros(1) + return self._ds["lat"].values + + @property + def depth(self): + try: + _ = self.axes["Z"] + except KeyError: + return np.zeros(1) + return self._ds["depth"].values + + @property + def time(self): + try: + axis = self.axes["T"] + except KeyError: + return np.zeros(1) + return get_time(axis) + + @property + def xdim(self): + return get_dimensionality(self.axes.get("X")) + + @property + def ydim(self): + return get_dimensionality(self.axes.get("Y")) + + @property + def zdim(self): + return get_dimensionality(self.axes.get("Z")) + + @property + def tdim(self): + return get_dimensionality(self.axes.get("T")) + + @property + def time_origin(self): + return TimeConverter(self.time[0]) + + @property + def _z4d(self) -> Literal[0, 1]: + return 1 if self.depth.shape == 4 else 0 + + @property + def zonal_periodic(self): ... # ? hmmm + + @property + def _gtype(self): + """This class is created *purely* for compatibility with v3 code and will be removed + or changed in future. + + TODO: Remove + """ + from parcels.grid import GridType + + if len(self.lon.shape) <= 1: + if self.depth is None or len(self.depth.shape) <= 1: + return GridType.RectilinearZGrid + else: + return GridType.RectilinearSGrid + else: + if self.depth is None or len(self.depth.shape) <= 1: + return GridType.CurvilinearZGrid + else: + return GridType.CurvilinearSGrid + + @staticmethod + def create_grid(lon, lat, depth, time, time_origin, mesh, **kwargs): ... # ? hmmm + + def _check_zonal_periodic(self): ... # ? hmmm + + def _add_Sdepth_periodic_halo(self, zonal, meridional, halosize): ... # ? hmmm diff --git a/pyproject.toml b/pyproject.toml index 1120a148c3..361d8a8d36 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ platforms = ["win-64", "linux-64", "osx-64", "osx-arm64"] [tool.pixi.tasks] tests = "pytest" -tests-notebooks = "pytest -v -s --nbval-lax -k 'not documentation and not tutorial_periodic_boundaries and not tutorial_timevaryingdepthdimensions and not tutorial_particle_field_interaction and not tutorial_croco_3D and not tutorial_nemo_3D and not tutorial_analyticaladvection'" # TODO v4: Mirror ci.yml for notebooks being run +tests-notebooks = "pytest -v -s --nbval-lax -k 'not documentation and not tutorial_periodic_boundaries and not tutorial_timevaryingdepthdimensions and not tutorial_particle_field_interaction and not tutorial_croco_3D and not tutorial_nemo_3D and not tutorial_analyticaladvection' docs/examples" # TODO v4: Mirror ci.yml for notebooks being run coverage = "coverage run -m pytest && coverage html" typing = "mypy parcels" pre-commit = "pre-commit run --all-files" @@ -188,6 +188,7 @@ ignore = [ "D205", # do not use bare except, specify exception instead "E722", + "F811", # TODO: These bugbear issues are to be resolved diff --git a/tests/v4/test_grid_vendored.py b/tests/v4/test_grid_vendored.py new file mode 100644 index 0000000000..c47422fbd6 --- /dev/null +++ b/tests/v4/test_grid_vendored.py @@ -0,0 +1,262 @@ +"""Test cases that have been vendored from xgcm.""" + +import numpy as np +import pytest +import xarray as xr + +from parcels.v4.grid import Axis, Grid +from tests.vendor.xgcm_datasets import ( + all_2d, # noqa: F401 + all_datasets, # noqa: F401 + datasets, + datasets_grid_metric, # noqa: F401 + nonperiodic_1d, # noqa: F401 + nonperiodic_2d, # noqa: F401 + periodic_1d, # noqa: F401 + periodic_2d, # noqa: F401 +) + + +# helper function to produce axes from datasets +def _get_axes(ds): + all_axes = {ds[c].attrs["axis"] for c in ds.dims if "axis" in ds[c].attrs} + axis_objs = {ax: Axis(ds, ax) for ax in all_axes} + return axis_objs + + +def test_create_axis(all_datasets): + ds, periodic, expected = all_datasets + axis_objs = _get_axes(ds) + for ax_expected, coords_expected in expected["axes"].items(): + assert ax_expected in axis_objs + this_axis = axis_objs[ax_expected] + for axis_name, coord_name in coords_expected.items(): + assert axis_name in this_axis.coords + assert this_axis.coords[axis_name] == coord_name + + +def _assert_axes_equal(ax1, ax2): + assert ax1.name == ax2.name + for pos, coord in ax1.coords.items(): + assert pos in ax2.coords + assert coord == ax2.coords[pos] + assert ax1._periodic == ax2._periodic + assert ax1._default_shifts == ax2._default_shifts + assert ax1._facedim == ax2._facedim + # TODO: make this work... + # assert ax1._connections == ax2._connections + + +def test_create_axis_no_comodo(all_datasets): + ds, periodic, expected = all_datasets + axis_objs = _get_axes(ds) + + # now strip out the metadata + ds_noattr = ds.copy() + for var in ds.variables: + ds_noattr[var].attrs.clear() + + for axis_name, axis_coords in expected["axes"].items(): + # now create the axis from scratch with no attributes + ax2 = Axis(ds_noattr, axis_name, coords=axis_coords) + # and compare to the one created with attributes + ax1 = axis_objs[axis_name] + + assert ax1.name == ax2.name + for pos, coord_name in ax1.coords.items(): + assert pos in ax2.coords + assert coord_name == ax2.coords[pos] + assert ax1._periodic == ax2._periodic + assert ax1._default_shifts == ax2._default_shifts + assert ax1._facedim == ax2._facedim + + +def test_create_axis_no_coords(all_datasets): + ds, periodic, expected = all_datasets + axis_objs = _get_axes(ds) + + ds_drop = ds.drop_vars(list(ds.coords)) + + for axis_name, axis_coords in expected["axes"].items(): + # now create the axis from scratch with no attributes OR coords + ax2 = Axis(ds_drop, axis_name, coords=axis_coords) + # and compare to the one created with attributes + ax1 = axis_objs[axis_name] + + assert ax1.name == ax2.name + for pos in ax1.coords.keys(): + assert pos in ax2.coords + assert ax1._periodic == ax2._periodic + assert ax1._default_shifts == ax2._default_shifts + assert ax1._facedim == ax2._facedim + + +def test_axis_repr(all_datasets): + ds, periodic, expected = all_datasets + axis_objs = _get_axes(ds) + for axis in axis_objs.values(): + r = repr(axis).split("\n") + assert r[0].startswith("" + + +def test_grid_dict_input_boundary_fill(nonperiodic_1d): + """Test axis kwarg input functionality using dict input""" + ds, _, _ = nonperiodic_1d + grid_direct = Grid(ds, periodic=False, boundary="fill", fill_value=5) + grid_dict = Grid(ds, periodic=False, boundary={"X": "fill"}, fill_value={"X": 5}) + assert grid_direct.axes["X"].fill_value == grid_dict.axes["X"].fill_value + assert grid_direct.axes["X"].boundary == grid_dict.axes["X"].boundary + + +def test_invalid_boundary_error(): + ds = datasets["1d_left"] + with pytest.raises(ValueError): + Axis(ds, "X", boundary="bad") + with pytest.raises(ValueError): + Grid(ds, boundary="bad") + with pytest.raises(ValueError): + Grid(ds, boundary={"X": "bad"}) + with pytest.raises(ValueError): + Grid(ds, boundary={"X": 0}) + with pytest.raises(ValueError): + Grid(ds, boundary=0) + + +def test_invalid_fill_value_error(): + ds = datasets["1d_left"] + with pytest.raises(ValueError): + Axis(ds, "X", fill_value="x") + with pytest.raises(ValueError): + Grid(ds, fill_value="bad") + with pytest.raises(ValueError): + Grid(ds, fill_value={"X": "bad"}) + + +def test_input_not_dims(): + data = np.random.rand(4, 5) + coord = np.random.rand(4, 5) + ds = xr.DataArray(data, dims=["x", "y"], coords={"c": (["x", "y"], coord)}).to_dataset(name="data") + msg = r"is not a dimension in the input dataset" + with pytest.raises(ValueError, match=msg): + Grid(ds, coords={"X": {"center": "c"}}) + + +def test_input_dim_notfound(): + data = np.random.rand(4, 5) + coord = np.random.rand(4, 5) + ds = xr.DataArray(data, dims=["x", "y"], coords={"c": (["x", "y"], coord)}).to_dataset(name="data") + msg = r"Could not find dimension `other` \(for the `center` position on axis `X`\) in input dataset." + with pytest.raises(ValueError, match=msg): + Grid(ds, coords={"X": {"center": "other"}}) diff --git a/tests/v4/test_gridadapter.py b/tests/v4/test_gridadapter.py new file mode 100644 index 0000000000..1fdcc66745 --- /dev/null +++ b/tests/v4/test_gridadapter.py @@ -0,0 +1,75 @@ +from collections import namedtuple + +import numpy as np +import pytest +from numpy.testing import assert_allclose + +from parcels._datasets.structured.grid_datasets import N, T, datasets +from parcels.grid import Grid as OldGrid +from parcels.tools.converters import TimeConverter +from parcels.v4.gridadapter import GridAdapter + +TestCase = namedtuple("TestCase", ["Grid", "attr", "expected"]) + +test_cases = [ + TestCase(datasets["ds_2d_left"], "lon", datasets["ds_2d_left"].XG.values), + TestCase(datasets["ds_2d_left"], "lat", datasets["ds_2d_left"].YG.values), + TestCase(datasets["ds_2d_left"], "depth", datasets["ds_2d_left"].ZG.values), + TestCase(datasets["ds_2d_left"], "time", datasets["ds_2d_left"].time.values), + TestCase(datasets["ds_2d_left"], "xdim", N), + TestCase(datasets["ds_2d_left"], "ydim", 2 * N), + TestCase(datasets["ds_2d_left"], "zdim", 3 * N), + TestCase(datasets["ds_2d_left"], "tdim", T), + TestCase(datasets["ds_2d_left"], "time_origin", TimeConverter(datasets["ds_2d_left"].time.values[0])), +] + + +def assert_equal(actual, expected): + if expected is None: + assert actual is None + elif isinstance(expected, TimeConverter): + assert actual == expected + elif isinstance(expected, np.ndarray): + assert actual.shape == expected.shape + assert_allclose(actual, expected) + else: + assert_allclose(actual, expected) + + +@pytest.mark.parametrize("ds, attr, expected", test_cases) +def test_grid_adapter_properties_ground_truth(ds, attr, expected): + adapter = GridAdapter(ds, periodic=False) + actual = getattr(adapter, attr) + assert_equal(actual, expected) + + +@pytest.mark.parametrize( + "attr", + [ + "lon", + "lat", + "depth", + "time", + "xdim", + "ydim", + "zdim", + "tdim", + "time_origin", + "_gtype", + ], +) +@pytest.mark.parametrize("ds", datasets.values()) +def test_grid_adapter_against_old(ds, attr): + adapter = GridAdapter(ds, periodic=False) + + grid = OldGrid.create_grid( + lon=ds.lon.values, + lat=ds.lat.values, + depth=ds.depth.values, + time=ds.time.values, + time_origin=TimeConverter(ds.time.values[0]), + mesh="spherical", + ) + actual = getattr(adapter, attr) + expected = getattr(grid, attr) + assert_equal(actual, expected) diff --git a/tests/vendor/__init__.py b/tests/vendor/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/vendor/xgcm_datasets.py b/tests/vendor/xgcm_datasets.py new file mode 100644 index 0000000000..6f8082ef9f --- /dev/null +++ b/tests/vendor/xgcm_datasets.py @@ -0,0 +1,364 @@ +"""Datasets vendored from xgcm test suite for the testing of grids.""" + +import numpy as np +import pytest +import xarray as xr + +# example from comodo website +# https://web.archive.org/web/20160417032300/http://pycomodo.forge.imag.fr/norm.html +# netcdf example { +# dimensions: +# ni = 9 ; +# ni_u = 10 ; +# variables: +# float ni(ni) ; +# ni:axis = "X" ; +# ni:standard_name = "x_grid_index" ; +# ni:long_name = "x-dimension of the grid" ; +# ni:c_grid_dynamic_range = "2:8" ; +# float ni_u(ni_u) ; +# ni_u:axis = "X" ; +# ni_u:standard_name = "x_grid_index_at_u_location" ; +# ni_u:long_name = "x-dimension of the grid" ; +# ni_u:c_grid_dynamic_range = "3:8" ; +# ni_u:c_grid_axis_shift = -0.5 ; +# data: +# ni = 1, 2, 3, 4, 5, 6, 7, 8, 9 ; +# ni_u = 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5 ; +# } + +N = 100 +datasets = { + # the comodo example, with renamed dimensions + "1d_outer": xr.Dataset( + {"data_c": (["XC"], np.random.rand(9)), "data_g": (["XG"], np.random.rand(10))}, + coords={ + "XC": ( + ["XC"], + np.arange(1, 10), + { + "axis": "X", + "standard_name": "x_grid_index", + "long_name": "x-dimension of the grid", + "c_grid_dynamic_range": "2:8", + }, + ), + "XG": ( + ["XG"], + np.arange(0.5, 10), + { + "axis": "X", + "standard_name": "x_grid_index_at_u_location", + "long_name": "x-dimension of the grid", + "c_grid_dynamic_range": "3:8", + "c_grid_axis_shift": -0.5, + }, + ), + }, + ), + "1d_inner": xr.Dataset( + {"data_c": (["XC"], np.random.rand(9)), "data_g": (["XG"], np.random.rand(8))}, + coords={ + "XC": ( + ["XC"], + np.arange(1, 10), + { + "axis": "X", + "standard_name": "x_grid_index", + "long_name": "x-dimension of the grid", + "c_grid_dynamic_range": "2:8", + }, + ), + "XG": ( + ["XG"], + np.arange(1.5, 9), + { + "axis": "X", + "standard_name": "x_grid_index_at_u_location", + "long_name": "x-dimension of the grid", + "c_grid_dynamic_range": "3:8", + "c_grid_axis_shift": -0.5, + }, + ), + }, + ), + "1d_left": xr.Dataset( + {"data_g": (["XG"], np.random.rand(N)), "data_c": (["XC"], np.random.rand(N))}, + coords={ + "XG": ( + ["XG"], + 2 * np.pi / N * np.arange(0, N), + {"axis": "X", "c_grid_axis_shift": -0.5}, + ), + "XC": (["XC"], 2 * np.pi / N * (np.arange(0, N) + 0.5), {"axis": "X"}), + }, + ), + "1d_right": xr.Dataset( + {"data_g": (["XG"], np.random.rand(N)), "data_c": (["XC"], np.random.rand(N))}, + coords={ + "XG": ( + ["XG"], + 2 * np.pi / N * np.arange(1, N + 1), + {"axis": "X", "c_grid_axis_shift": 0.5}, + ), + "XC": (["XC"], 2 * np.pi / N * (np.arange(0, N) - 0.5), {"axis": "X"}), + }, + ), + "2d_left": xr.Dataset( + { + "data_g": (["YG", "XG"], np.random.rand(2 * N, N)), + "data_c": (["YC", "XC"], np.random.rand(2 * N, N)), + }, + coords={ + "XG": ( + ["XG"], + 2 * np.pi / N * np.arange(0, N), + {"axis": "X", "c_grid_axis_shift": -0.5}, + ), + "XC": (["XC"], 2 * np.pi / N * (np.arange(0, N) + 0.5), {"axis": "X"}), + "YG": ( + ["YG"], + 2 * np.pi / (2 * N) * np.arange(0, 2 * N), + {"axis": "Y", "c_grid_axis_shift": -0.5}, + ), + "YC": ( + ["YC"], + 2 * np.pi / (2 * N) * (np.arange(0, 2 * N) + 0.5), + {"axis": "Y"}, + ), + }, + ), +} + +# include periodicity +datasets_with_periodicity = { + "nonperiodic_1d_outer": (datasets["1d_outer"], False), + "nonperiodic_1d_inner": (datasets["1d_inner"], False), + "periodic_1d_left": (datasets["1d_left"], True), + "nonperiodic_1d_left": (datasets["1d_left"], False), + "periodic_1d_right": (datasets["1d_right"], True), + "nonperiodic_1d_right": (datasets["1d_right"], False), + "periodic_2d_left": (datasets["2d_left"], True), + "nonperiodic_2d_left": (datasets["2d_left"], False), + "xperiodic_2d_left": (datasets["2d_left"], ["X"]), + "yperiodic_2d_left": (datasets["2d_left"], ["Y"]), +} + +expected_values = { + "nonperiodic_1d_outer": {"axes": {"X": {"center": "XC", "outer": "XG"}}}, + "nonperiodic_1d_inner": {"axes": {"X": {"center": "XC", "inner": "XG"}}}, + "periodic_1d_left": {"axes": {"X": {"center": "XC", "left": "XG"}}}, + "nonperiodic_1d_left": {"axes": {"X": {"center": "XC", "left": "XG"}}}, + "periodic_1d_right": { + "axes": {"X": {"center": "XC", "right": "XG"}}, + "shift": True, + }, + "nonperiodic_1d_right": { + "axes": {"X": {"center": "XC", "right": "XG"}}, + "shift": True, + }, + "periodic_2d_left": { + "axes": { + "X": {"center": "XC", "left": "XG"}, + "Y": {"center": "YC", "left": "YG"}, + } + }, + "nonperiodic_2d_left": { + "axes": { + "X": {"center": "XC", "left": "XG"}, + "Y": {"center": "YC", "left": "YG"}, + } + }, + "xperiodic_2d_left": { + "axes": { + "X": {"center": "XC", "left": "XG"}, + "Y": {"center": "YC", "left": "YG"}, + } + }, + "yperiodic_2d_left": { + "axes": { + "X": {"center": "XC", "left": "XG"}, + "Y": {"center": "YC", "left": "YG"}, + } + }, +} + + +@pytest.fixture(scope="module", params=datasets_with_periodicity.keys()) +def all_datasets(request): + ds, periodic = datasets_with_periodicity[request.param] + return ds, periodic, expected_values[request.param] + + +@pytest.fixture( + scope="module", + params=[ + "nonperiodic_1d_outer", + "nonperiodic_1d_inner", + "nonperiodic_1d_left", + "nonperiodic_1d_right", + ], +) +def nonperiodic_1d(request): + ds, periodic = datasets_with_periodicity[request.param] + return ds, periodic, expected_values[request.param] + + +@pytest.fixture(scope="module", params=["periodic_1d_left", "periodic_1d_right"]) +def periodic_1d(request): + ds, periodic = datasets_with_periodicity[request.param] + return ds, periodic, expected_values[request.param] + + +@pytest.fixture( + scope="module", + params=[ + "periodic_2d_left", + "nonperiodic_2d_left", + "xperiodic_2d_left", + "yperiodic_2d_left", + ], +) +def all_2d(request): + ds, periodic = datasets_with_periodicity[request.param] + return ds, periodic, expected_values[request.param] + + +@pytest.fixture(scope="module", params=["periodic_2d_left"]) +def periodic_2d(request): + ds, periodic = datasets_with_periodicity[request.param] + return ds, periodic, expected_values[request.param] + + +@pytest.fixture( + scope="module", + params=["nonperiodic_2d_left", "xperiodic_2d_left", "yperiodic_2d_left"], +) +def nonperiodic_2d(request): + ds, periodic = datasets_with_periodicity[request.param] + return ds, periodic, expected_values[request.param] + + +def datasets_grid_metric(grid_type): + """Uniform grid test dataset. + Should eventually be extended to nonuniform grid + """ + xt = np.arange(4) + xu = xt + 0.5 + yt = np.arange(5) + yu = yt + 0.5 + zt = np.arange(6) + zw = zt + 0.5 + t = np.arange(10) + + def data_generator(): + return np.random.rand(len(xt), len(yt), len(t), len(zt)) + + # Need to add a tracer here to get the tracer dimsuffix + tr = xr.DataArray(data_generator(), coords=[("xt", xt), ("yt", yt), ("time", t), ("zt", zt)]) + + u_b = xr.DataArray(data_generator(), coords=[("xu", xu), ("yu", yu), ("time", t), ("zt", zt)]) + + v_b = xr.DataArray(data_generator(), coords=[("xu", xu), ("yu", yu), ("time", t), ("zt", zt)]) + + u_c = xr.DataArray(data_generator(), coords=[("xu", xu), ("yt", yt), ("time", t), ("zt", zt)]) + + v_c = xr.DataArray(data_generator(), coords=[("xt", xt), ("yu", yu), ("time", t), ("zt", zt)]) + + wt = xr.DataArray(data_generator(), coords=[("xt", xt), ("yt", yt), ("time", t), ("zw", zw)]) + + # maybe also add some other combo of x,t y,t arrays.... + timeseries = xr.DataArray(np.random.rand(len(t)), coords=[("time", t)]) + + # northeast distance + dx = 0.3 + dy = 2 + dz = 20 + + dx_ne = xr.DataArray(np.ones([len(xt), len(yt)]) * dx - 0.1, coords=[("xu", xu), ("yu", yu)]) + dx_n = xr.DataArray(np.ones([len(xt), len(yt)]) * dx - 0.2, coords=[("xt", xt), ("yu", yu)]) + dx_e = xr.DataArray(np.ones([len(xt), len(yt)]) * dx - 0.3, coords=[("xu", xu), ("yt", yt)]) + dx_t = xr.DataArray(np.ones([len(xt), len(yt)]) * dx - 0.4, coords=[("xt", xt), ("yt", yt)]) + + dy_ne = xr.DataArray(np.ones([len(xt), len(yt)]) * dy + 0.1, coords=[("xu", xu), ("yu", yu)]) + dy_n = xr.DataArray(np.ones([len(xt), len(yt)]) * dy + 0.2, coords=[("xt", xt), ("yu", yu)]) + dy_e = xr.DataArray(np.ones([len(xt), len(yt)]) * dy + 0.3, coords=[("xu", xu), ("yt", yt)]) + dy_t = xr.DataArray(np.ones([len(xt), len(yt)]) * dy + 0.4, coords=[("xt", xt), ("yt", yt)]) + + # dz elements at horizontal tracer points + dz_t = xr.DataArray(data_generator() * dz, coords=[("xt", xt), ("yt", yt), ("time", t), ("zt", zt)]) + dz_w = xr.DataArray(data_generator() * dz, coords=[("xt", xt), ("yt", yt), ("time", t), ("zw", zw)]) + # dz elements at velocity points + dz_w_ne = xr.DataArray(data_generator() * dz, coords=[("xu", xu), ("yu", yu), ("time", t), ("zw", zw)]) + dz_w_n = xr.DataArray(data_generator() * dz, coords=[("xt", xt), ("yu", yu), ("time", t), ("zw", zw)]) + dz_w_e = xr.DataArray(data_generator() * dz, coords=[("xu", xu), ("yt", yt), ("time", t), ("zw", zw)]) + + # Make sure the areas are not just the product of x and y distances + area_ne = (dx_ne * dy_ne) + 0.1 + area_n = (dx_n * dy_n) + 0.2 + area_e = (dx_e * dy_e) + 0.3 + area_t = (dx_t * dy_t) + 0.4 + + # calculate volumes, but again add small differences. + volume_t = (dx_t * dy_t * dz_t) + 0.25 + + def _add_metrics(obj): + obj = obj.copy() + for name, data in [ + ("dx_ne", dx_ne), + ("dx_n", dx_n), + ("dx_e", dx_e), + ("dx_t", dx_t), + ("dy_ne", dy_ne), + ("dy_n", dy_n), + ("dy_e", dy_e), + ("dy_t", dy_t), + ("dz_t", dz_t), + ("dz_w", dz_w), + ("dz_w_ne", dz_w_ne), + ("dz_w_n", dz_w_n), + ("dz_w_e", dz_w_e), + ("area_ne", area_ne), + ("area_n", area_n), + ("area_e", area_e), + ("area_t", area_t), + ("volume_t", volume_t), + ]: + obj.coords[name] = data + obj.coords[name].attrs["tracked_name"] = name + # add xgcm attrs + for ii in ["xu", "xt"]: + obj[ii].attrs["axis"] = "X" + for ii in ["yu", "yt"]: + obj[ii].attrs["axis"] = "Y" + for ii in ["zt", "zw"]: + obj[ii].attrs["axis"] = "Z" + for ii in ["time"]: + obj[ii].attrs["axis"] = "T" + for ii in ["xu", "yu", "zw"]: + obj[ii].attrs["c_grid_axis_shift"] = 0.5 + return obj + + coords = { + "X": {"center": "xt", "right": "xu"}, + "Y": {"center": "yt", "right": "yu"}, + "Z": {"center": "zt", "right": "zw"}, + } + + metrics = { + ("X",): ["dx_t", "dx_n", "dx_e", "dx_ne"], + ("Y",): ["dy_t", "dy_n", "dy_e", "dy_ne"], + ("Z",): ["dz_t", "dz_w", "dz_w_ne", "dz_w_n", "dz_w_e"], + ("X", "Y"): ["area_t", "area_n", "area_e", "area_ne"], + ("X", "Y", "Z"): ["volume_t"], + } + + # combine to different grid configurations (B and C grid) + if grid_type == "B": + ds = _add_metrics(xr.Dataset({"u": u_b, "v": v_b, "wt": wt, "tracer": tr, "timeseries": timeseries})) + elif grid_type == "C": + ds = _add_metrics(xr.Dataset({"u": u_c, "v": v_c, "wt": wt, "tracer": tr, "timeseries": timeseries})) + else: + raise ValueError(f"Invalid input [{grid_type}] for `grid_type`. Only supports `B` and `C` at the moment ") + + return ds, coords, metrics