Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
5659a2b
Take xgcm.Grid and adapt for our uses
VeckoTheGecko Apr 1, 2025
f26a308
Adapt xgcm tests
VeckoTheGecko Apr 1, 2025
b181643
Remove unused functions and TODOs
VeckoTheGecko Apr 1, 2025
72fadde
Rename file to xgcm_datasets.py
VeckoTheGecko Apr 1, 2025
b2ffce7
run pre-commit
VeckoTheGecko Apr 1, 2025
5e9a145
Move v4 grid files
VeckoTheGecko Apr 1, 2025
899f4e3
Run pre-commit and manual fixes
VeckoTheGecko Apr 1, 2025
b552b03
Update reprs and remove warning
VeckoTheGecko Apr 1, 2025
0bffe9e
Remove unused imports
VeckoTheGecko Apr 1, 2025
ca7dcfc
remove Grid._ti from old grid
VeckoTheGecko Apr 1, 2025
bdc7041
Add grid adapter test cases and mock implementation
VeckoTheGecko Apr 1, 2025
ff62de7
Define GridAdapter tdim, xdim, ydim and zdim
VeckoTheGecko Apr 1, 2025
f175eea
Add time to test_gridadapter dataset
VeckoTheGecko Apr 1, 2025
f7fa370
Add Z dim to test_gridadapter dataset
VeckoTheGecko Apr 1, 2025
df93007
Define GridAdapter lon, lat, depth, and time
VeckoTheGecko Apr 1, 2025
82b8271
Limit scope of keyerror except blocks
VeckoTheGecko Apr 1, 2025
6ef7fc6
Define GridAdapter time_origin
VeckoTheGecko Apr 1, 2025
1ec179a
Update adapter to return zero array
VeckoTheGecko Apr 2, 2025
ec49bc0
Add _z4d to gridadapter
VeckoTheGecko Apr 2, 2025
8627730
Move vendored file to vendor folder
VeckoTheGecko Apr 3, 2025
86f4d38
Update test suite repr
VeckoTheGecko Apr 3, 2025
c6de1cd
Add curvilinear test grid datasets
VeckoTheGecko Apr 3, 2025
78889e1
Update grid datasets with depth and time
VeckoTheGecko Apr 3, 2025
4c56c06
Move dataset and start with new test
VeckoTheGecko Apr 3, 2025
29568b9
Update test
VeckoTheGecko Apr 3, 2025
70588bc
Use lon, lat, depth arrays on underlying dataset
VeckoTheGecko Apr 3, 2025
faf7a6c
Add grid_type property to grid adapter
VeckoTheGecko Apr 3, 2025
b434eaa
Rename test file
VeckoTheGecko Apr 3, 2025
9fdd993
Add mesh kwarg to adapter
VeckoTheGecko Apr 4, 2025
1fd9b14
Remove mesh property from adapter
VeckoTheGecko Apr 4, 2025
42d3715
Remove kwargs from Grid.create_grid
VeckoTheGecko Apr 4, 2025
98436c6
Update isinstance to GridType check
VeckoTheGecko Apr 4, 2025
d34a45c
Update grid_type on adapter to use enum
VeckoTheGecko Apr 4, 2025
bc0f7b0
Remove GridCode alias
VeckoTheGecko Apr 4, 2025
182a22b
Add lonlat_minmax to GridAdapter
VeckoTheGecko Apr 4, 2025
f52ef63
review feedback
VeckoTheGecko Apr 10, 2025
f3475cc
Move dataset functions
VeckoTheGecko Apr 11, 2025
324047f
Patch grid._gtype call and pixi run tests-notebooks
VeckoTheGecko Apr 15, 2025
ca24d17
remove doubled up TimeConverter.__eq__
VeckoTheGecko Apr 15, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 169 additions & 0 deletions parcels/_datasets/structured/grid_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
"""Datasets focussing on grid geometry"""

import numpy as np
import xarray as xr

N = 30
T = 10


def rotated_curvilinear_grid():
XG = np.arange(N)
YG = np.arange(2 * N)
LON, LAT = np.meshgrid(XG, YG)

angle = -np.pi / 24
rotation = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]])

# rotate the LON and LAT grids
LON, LAT = np.einsum("ji, mni -> jmn", rotation, np.dstack([LON, LAT]))

return xr.Dataset(
{
"data_g": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)),
"data_c": (["ZC", "YC", "XC"], np.random.rand(3 * N, 2 * N, N)),
},
coords={
"XG": (["XG"], XG, {"axis": "X", "c_grid_axis_shift": -0.5}),
"YG": (["YG"], YG, {"axis": "Y", "c_grid_axis_shift": -0.5}),
"XC": (["XC"], XG + 0.5, {"axis": "X"}),
"YC": (["YC"], YG + 0.5, {"axis": "Y"}),
"ZG": (
["ZG"],
np.arange(3 * N),
{"axis": "Z", "c_grid_axis_shift": -0.5},
),
"ZC": (
["ZC"],
np.arange(3 * N) + 0.5,
{"axis": "Z"},
),
"depth": (["ZG"], np.arange(3 * N), {"axis": "Z"}),
"time": (["time"], np.arange(T), {"axis": "T"}),
"lon": (
["YG", "XG"],
LON,
{"axis": "X", "c_grid_axis_shift": -0.5}, # ? Needed?
),
"lat": (
["YG", "XG"],
LAT,
{"axis": "Y", "c_grid_axis_shift": -0.5}, # ? Needed?
),
},
)


def _cartesion_to_polar(x, y):
r = np.sqrt(x**2 + y**2)
theta = np.arctan2(y, x)
return r, theta


def _polar_to_cartesian(r, theta):
x = r * np.cos(theta)
y = r * np.sin(theta)
return x, y


def unrolled_cone_curvilinear_grid():
# Not a great unrolled cone, but this is good enough for testing
# you can use matplotlib pcolormesh to plot
XG = np.arange(N)
YG = np.arange(2 * N) * 0.25

pivot = -10, 0
LON, LAT = np.meshgrid(XG, YG)

new_lon_lat = []

min_lon = np.min(XG)
for lon, lat in zip(LON.flatten(), LAT.flatten(), strict=True):
r, _ = _cartesion_to_polar(lon - pivot[0], lat - pivot[1])
_, theta = _cartesion_to_polar(min_lon - pivot[0], lat - pivot[1])
theta *= 1.2
r *= 1.2
lon, lat = _polar_to_cartesian(r, theta)
new_lon_lat.append((lon + pivot[0], lat + pivot[1]))

new_lon, new_lat = zip(*new_lon_lat, strict=True)
LON, LAT = np.array(new_lon).reshape(LON.shape), np.array(new_lat).reshape(LAT.shape)

return xr.Dataset(
{
"data_g": (["ZG", "YG", "XG"], np.random.rand(3 * N, 2 * N, N)),
"data_c": (["ZC", "YC", "XC"], np.random.rand(3 * N, 2 * N, N)),
},
coords={
"XG": (["XG"], XG, {"axis": "X", "c_grid_axis_shift": -0.5}),
"YG": (["YG"], YG, {"axis": "Y", "c_grid_axis_shift": -0.5}),
"XC": (["XC"], XG + 0.5, {"axis": "X"}),
"YC": (["YC"], YG + 0.5, {"axis": "Y"}),
"ZG": (
["ZG"],
np.arange(3 * N),
{"axis": "Z", "c_grid_axis_shift": -0.5},
),
"ZC": (
["ZC"],
np.arange(3 * N) + 0.5,
{"axis": "Z"},
),
"depth": (["ZG"], np.arange(3 * N), {"axis": "Z"}),
"time": (["time"], np.arange(T), {"axis": "T"}),
"lon": (
["YG", "XG"],
LON,
{"axis": "X", "c_grid_axis_shift": -0.5}, # ? Needed?
),
"lat": (
["YG", "XG"],
LAT,
{"axis": "Y", "c_grid_axis_shift": -0.5}, # ? Needed?
),
},
)


datasets = {
"2d_left_rotated": rotated_curvilinear_grid(),
"ds_2d_left": xr.Dataset(
{
"data_g": (["time", "ZG", "YG", "XG"], np.random.rand(T, 3 * N, 2 * N, N)),
"data_c": (["time", "ZC", "YC", "XC"], np.random.rand(T, 3 * N, 2 * N, N)),
},
coords={
"XG": (
["XG"],
2 * np.pi / N * np.arange(0, N),
{"axis": "X", "c_grid_axis_shift": -0.5},
),
"XC": (["XC"], 2 * np.pi / N * (np.arange(0, N) + 0.5), {"axis": "X"}),
"YG": (
["YG"],
2 * np.pi / (2 * N) * np.arange(0, 2 * N),
{"axis": "Y", "c_grid_axis_shift": -0.5},
),
"YC": (
["YC"],
2 * np.pi / (2 * N) * (np.arange(0, 2 * N) + 0.5),
{"axis": "Y"},
),
"ZG": (
["ZG"],
np.arange(3 * N),
{"axis": "Z", "c_grid_axis_shift": -0.5},
),
"ZC": (
["ZC"],
np.arange(3 * N) + 0.5,
{"axis": "Z"},
),
"lon": (["XG"], 2 * np.pi / N * np.arange(0, N)),
"lat": (["YG"], 2 * np.pi / (2 * N) * np.arange(0, 2 * N)),
"depth": (["ZG"], np.arange(3 * N)),
"time": (["time"], np.arange(T), {"axis": "T"}),
},
),
"2d_left_unrolled_cone": unrolled_cone_curvilinear_grid(),
}
16 changes: 4 additions & 12 deletions parcels/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
"CurvilinearSGrid",
"CurvilinearZGrid",
"Grid",
"GridCode",
"GridType",
"RectilinearSGrid",
"RectilinearZGrid",
Expand All @@ -24,11 +23,6 @@ class GridType(IntEnum):
CurvilinearSGrid = 3


# GridCode has been renamed to GridType for consistency.
# TODO: Remove alias in Parcels v4
GridCode = GridType


class Grid:
"""Grid class that defines a (spatial and temporal) grid on which Fields are defined."""

Expand All @@ -40,7 +34,6 @@ def __init__(
time_origin: TimeConverter | None,
mesh: Mesh,
):
self._ti = -1
lon = np.array(lon)
lat = np.array(lat)
time = np.zeros(1, dtype=np.float64) if time is None else time
Expand Down Expand Up @@ -112,7 +105,6 @@ def create_grid(
time,
time_origin,
mesh: Mesh,
**kwargs,
):
lon = np.array(lon)
lat = np.array(lat)
Expand All @@ -122,14 +114,14 @@ def create_grid(

if len(lon.shape) <= 1:
if depth is None or len(depth.shape) <= 1:
return RectilinearZGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh, **kwargs)
return RectilinearZGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh)
else:
return RectilinearSGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh, **kwargs)
return RectilinearSGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh)
else:
if depth is None or len(depth.shape) <= 1:
return CurvilinearZGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh, **kwargs)
return CurvilinearZGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh)
else:
return CurvilinearSGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh, **kwargs)
return CurvilinearSGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh)

def _check_zonal_periodic(self):
if self.zonal_periodic or self.mesh == "flat" or self.lon.size == 1:
Expand Down
4 changes: 2 additions & 2 deletions parcels/particleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from parcels._compat import MPI
from parcels.application_kernels.advection import AdvectionRK4
from parcels.field import Field
from parcels.grid import CurvilinearGrid, GridType
from parcels.grid import GridType
from parcels.interaction.interactionkernel import InteractionKernel
from parcels.interaction.neighborsearch import (
BruteFlatNeighborSearch,
Expand Down Expand Up @@ -430,7 +430,7 @@
may be quite expensive.
"""
for i, grid in enumerate(self.fieldset.gridset.grids):
if not isinstance(grid, CurvilinearGrid):
if grid._gtype not in [GridType.CurvilinearZGrid, GridType.CurvilinearSGrid]:

Check warning on line 433 in parcels/particleset.py

View check run for this annotation

Codecov / codecov/patch

parcels/particleset.py#L433

Added line #L433 was not covered by tests
continue

tree_data = np.stack((grid.lon.flat, grid.lat.flat), axis=-1)
Expand Down
Empty file added parcels/v4/__init__.py
Empty file.
128 changes: 128 additions & 0 deletions parcels/v4/comodo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from collections import OrderedDict

# Representation of axis shifts
axis_shift_left = -0.5
axis_shift_right = 0.5
axis_shift_center = 0
# Characterizes valid shifts only
valid_axis_shifts = [axis_shift_left, axis_shift_right, axis_shift_center]


def assert_valid_comodo(ds):
"""Verify that the dataset meets comodo conventions

Parameters
----------
ds : xarray.dataset
"""
# TODO: implement
assert True

Check warning on line 19 in parcels/v4/comodo.py

View check run for this annotation

Codecov / codecov/patch

parcels/v4/comodo.py#L19

Added line #L19 was not covered by tests


def get_all_axes(ds):
axes = set()
for d in ds.dims:
if "axis" in ds[d].attrs:
axes.add(ds[d].attrs["axis"])
return axes


def get_axis_coords(ds, axis_name):
"""Find the name of the coordinates associated with a comodo axis.

Parameters
----------
ds : xarray.dataset or xarray.dataarray
axis_name : str
The name of the axis to find (e.g. 'X')

Returns
-------
coord_name : list
The names of the coordinate matching that axis
"""
coord_names = []
for d in ds.dims:
axis = ds[d].attrs.get("axis")
if axis == axis_name:
coord_names.append(d)
return coord_names


def get_axis_positions_and_coords(ds, axis_name):
coord_names = get_axis_coords(ds, axis_name)
ncoords = len(coord_names)
if ncoords == 0:
# didn't find anything for this axis
raise ValueError(f"Couldn't find any coordinates for axis {axis_name}")

# now figure out what type of coordinates these are:
# center, left, right, or outer
coords = {name: ds[name] for name in coord_names}

# some tortured logic for dealing with malformed c_grid_axis_shift
# attributes such as produced by old versions of xmitgcm.
# This should be a float (either -0.5 or 0.5)
# this function returns that, or True of the attribute is set to
# anything at all
def _maybe_fix_type(attr):
if attr is not None:
try:
return float(attr)
except TypeError:
return True

Check warning on line 73 in parcels/v4/comodo.py

View check run for this annotation

Codecov / codecov/patch

parcels/v4/comodo.py#L72-L73

Added lines #L72 - L73 were not covered by tests

axis_shift = {name: _maybe_fix_type(coord.attrs.get("c_grid_axis_shift")) for name, coord in coords.items()}
coord_len = {name: len(coord) for name, coord in coords.items()}

# look for the center coord, which is required
# this list will potentially contain "center", "inner", and "outer" points
coords_without_axis_shift = {name: coord_len[name] for name, shift in axis_shift.items() if not shift}
if len(coords_without_axis_shift) == 0:
raise ValueError(f"Couldn't find a center coordinate for axis {axis_name}")
elif len(coords_without_axis_shift) > 1:
raise ValueError(f"Found two coordinates without `c_grid_axis_shift` attribute for axis {axis_name}")
center_coord_name = list(coords_without_axis_shift)[0]
# the length of the center coord is key to decoding the other coords
axis_len = coord_len[center_coord_name]

# now we can start filling in the information about the different coords
axis_coords = OrderedDict()
axis_coords["center"] = center_coord_name

# now check the other coords
coord_names.remove(center_coord_name)
for name in coord_names:
shift = axis_shift[name]
clen = coord_len[name]
if clen == axis_len + 1:
axis_coords["outer"] = name
elif clen == axis_len - 1:
axis_coords["inner"] = name
elif shift == axis_shift_left:
if clen == axis_len:
axis_coords["left"] = name
else:
raise ValueError(f"Left coordinate {name} has incompatible length {clen:g} (axis_len={axis_len:g})")
elif shift == axis_shift_right:
if clen == axis_len:
axis_coords["right"] = name
else:
raise ValueError(f"Right coordinate {name} has incompatible length {clen:g} (axis_len={axis_len:g})")

Check warning on line 111 in parcels/v4/comodo.py

View check run for this annotation

Codecov / codecov/patch

parcels/v4/comodo.py#L111

Added line #L111 was not covered by tests
else:
if shift not in valid_axis_shifts:

Check warning on line 113 in parcels/v4/comodo.py

View check run for this annotation

Codecov / codecov/patch

parcels/v4/comodo.py#L113

Added line #L113 was not covered by tests
# string representing valid axis shifts
valids = str(valid_axis_shifts)[1:-1]

Check warning on line 115 in parcels/v4/comodo.py

View check run for this annotation

Codecov / codecov/patch

parcels/v4/comodo.py#L115

Added line #L115 was not covered by tests

raise ValueError(

Check warning on line 117 in parcels/v4/comodo.py

View check run for this annotation

Codecov / codecov/patch

parcels/v4/comodo.py#L117

Added line #L117 was not covered by tests
f"Coordinate {name} has invalid "
f"`c_grid_axis_shift` attribute `{shift!r}`. "
f"`c_grid_axis_shift` must be one of: {valids}"
)
else:
raise ValueError(f"Coordinate {name} has missing `c_grid_axis_shift` attribute `{shift!r}`")

Check warning on line 123 in parcels/v4/comodo.py

View check run for this annotation

Codecov / codecov/patch

parcels/v4/comodo.py#L123

Added line #L123 was not covered by tests
return axis_coords


def _assert_data_on_grid(da):
pass

Check warning on line 128 in parcels/v4/comodo.py

View check run for this annotation

Codecov / codecov/patch

parcels/v4/comodo.py#L128

Added line #L128 was not covered by tests
Loading
Loading