From 621a94830dfc544336555284d84b8c6f34445fca Mon Sep 17 00:00:00 2001
From: Joe
Date: Tue, 18 Mar 2025 16:22:39 -0400
Subject: [PATCH 01/46] [WIP] Draft implementation of the XField and XFieldSet
 classes

Ultimately, yes, these will replace Field and Fieldset. For the time being,
because this is a massive overhaul, keeping XField and XFieldSet in their own
files is a small step towards leveraging xarray and uxarray natively under the
hood.
---
 parcels/xfield.py    | 311 +++++++++++++++++++++++++++++++++++++++++++
 parcels/xfieldset.py |  91 +++++++++++++
 2 files changed, 402 insertions(+)
 create mode 100644 parcels/xfield.py
 create mode 100644 parcels/xfieldset.py

diff --git a/parcels/xfield.py b/parcels/xfield.py
new file mode 100644
index 0000000000..f07eb7005c
--- /dev/null
+++ b/parcels/xfield.py
@@ -0,0 +1,311 @@
import collections
import inspect
import math
import warnings
from typing import TYPE_CHECKING, Callable, Union, cast

import dask.array as da
import numpy as np
import uxarray as ux
import xarray as xr

import parcels.tools.interpolation_utils as i_u
from parcels._compat import add_note
from parcels._interpolation import (
    InterpolationContext2D,
    InterpolationContext3D,
    get_2d_interpolator_registry,
    get_3d_interpolator_registry,
)
from parcels._typing import (
    GridIndexingType,
    InterpMethod,
    InterpMethodOption,
    Mesh,
    VectorType,
    assert_valid_gridindexingtype,
    assert_valid_interp_method,
)
from parcels.tools._helpers import default_repr, field_repr, should_calculate_next_ti
from parcels.tools.converters import (
    TimeConverter,
    UnitConverter,
    unitconverters_map,
)
from parcels.tools.statuscodes import (
    AllParcelsErrorCodes,
    FieldOutOfBoundError,
    FieldOutOfBoundSurfaceError,
    FieldSamplingError,
    _raise_field_out_of_bound_error,
)
from parcels.tools.warnings import FieldSetWarning

# from ._index_search import _search_indices_curvilinear, _search_indices_rectilinear, _search_time_index

if TYPE_CHECKING:
    import numpy.typing as npt

    from parcels.xfieldset import XFieldSet

__all__ = ["XField", "XVectorField"]


def _isParticle(key):
    return hasattr(key, "obs_written")


def _deal_with_errors(error, key, vector_type: VectorType):
    if _isParticle(key):
        key.state = AllParcelsErrorCodes[type(error)]
    elif _isParticle(key[-1]):
        key[-1].state = AllParcelsErrorCodes[type(error)]
    else:
        raise RuntimeError(f"{error}. Error could not be handled because particle was not part of the Field Sampling.")

    if vector_type and "3D" in vector_type:
        return (0, 0, 0)
    elif vector_type == "2D":
        return (0, 0)
    else:
        return 0
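
# Editor's aside (illustrative sketch, not part of the patch): _deal_with_errors
# is intended to wrap field sampling, in the spirit of how the existing parcels
# Field handles AllParcelsErrorCodes, e.g.
#
#     try:
#         return self._interp_method(ti, zi, ei, t, z, y, x)
#     except tuple(AllParcelsErrorCodes.keys()) as error:
#         return _deal_with_errors(error, key, vector_type=None)
#
# Here `key` is either the particle itself or a tuple whose last element is the
# particle that triggered the sampling, so a failure sets the particle's state
# code instead of raising.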
+ * The "location" attribute must be set to one of the following to define which pairing of points a field is associated with. + * "node" + * "face" + * "x_edge" + * "y_edge" + * For an A-Grid, the "location" attribute must be set to / is assumed to be "node" (node_lat,node_lon). + * For a C-Grid, the "location" setting for a field has the following interpretation: + * "node" ~> the field is associated with the vorticity points (node_lat, node_lon) + * "face" ~> the field is associated with the tracer points (face_lat, face_lon) + * "x_edge" ~> the field is associated with the u-velocity points (face_lat, node_lon) + * "y_edge" ~> the field is associated with the v-velocity points (node_lat, face_lon) + + When using a uxarray.UxDataArray object, + * The uxarray.UxDataArray.UxGrid object must have the "Conventions" attribute set to "UGRID-1.0" + and the uxarray.UxDataArray object must comply with the UGRID conventions. + See https://ugrid-conventions.github.io/ugrid-conventions/ for more information. + + """ + + @staticmethod + def _interp_template( + self, + ti: int, + zi: int, + ei: int, + t: Union[np.float32,np.float64], + z: Union[np.float32,np.float64], + y: Union[np.float32,np.float64], + x: Union[np.float32,np.float64] + )-> Union[np.float32,np.float64]: + """ Template function used for the signature check of the lateral interpolation methods.""" + return 0.0 + + def _validate_interp_function(self, func: Callable): + """Ensures that the function has the correct signature.""" + expected_params = ["ti", "zi", "ei", "t", "z", "y", "x"] + expected_return_types = (np.float32,np.float64) + + sig = inspect.signature(func) + params = list(sig.parameters.keys()) + + # Check the parameter names and count + if params != expected_params: + raise TypeError( + f"Function must have parameters {expected_params}, but got {params}" + ) + + # Check return annotation if present + return_annotation = sig.return_annotation + if return_annotation not in (inspect.Signature.empty, *expected_return_types): + raise TypeError( + f"Function must return a float, but got {return_annotation}" + ) + + def __init__( + self, + name: str, + data: xr.DataArray | ux.UxDataArray, + fieldtype=None, + interp_method: Callable | None = None, + allow_time_extrapolation: bool | None = None, + ): + + self.name = name + self.data = data + + self._validate_dataarray(data) + + self._parent_mesh = data.attributes["mesh"] + self._location = data.attributes["location"] + # Set the vertical location + if "nz1" in data.dims: + self._vertical_location = "center" + elif "nz" in data.dims: + self._vertical_location = "face" + + # Setting the interpolation method dynamically + if interp_method is None: + self._interp_method = self._interp_template # Default to method that returns 0 always + else: + self._validate_interp_function(interp_method) + self._interp_method = interp_method + + self.igrid = -1 # Default the grid index to -1 + self.fieldtype = self.name if fieldtype is None else fieldtype + + self.fieldset: XFieldSet | None = None + if allow_time_extrapolation is None: + self.allow_time_extrapolation = True if len(self.data["time"]) == 1 else False + else: + self.allow_time_extrapolation = allow_time_extrapolation + + def _validate_dataarray(self): + """ Verifies that all the required attributes are present in the xarray.DataArray or + uxarray.UxDataArray object.""" + + # Validate dimensions + if not( "nz1" in self.data.dims or "nz" in self.data.dims ): + raise ValueError( + f"Field {self.name} is missing a 'nz1' or 'nz' dimension in 
    """

    @staticmethod
    def _interp_template(
        ti: int,
        zi: int,
        ei: int,
        t: Union[np.float32, np.float64],
        z: Union[np.float32, np.float64],
        y: Union[np.float32, np.float64],
        x: Union[np.float32, np.float64],
    ) -> Union[np.float32, np.float64]:
        """Template function used for the signature check of the lateral interpolation methods."""
        return 0.0

    def _validate_interp_function(self, func: Callable):
        """Ensures that the function has the correct signature."""
        expected_params = ["ti", "zi", "ei", "t", "z", "y", "x"]
        expected_return_types = (np.float32, np.float64)

        sig = inspect.signature(func)
        params = list(sig.parameters.keys())

        # Check the parameter names and count
        if params != expected_params:
            raise TypeError(f"Function must have parameters {expected_params}, but got {params}")

        # Check the return annotation, if present
        return_annotation = sig.return_annotation
        if return_annotation not in (inspect.Signature.empty, *expected_return_types):
            raise TypeError(f"Function must return a float, but got {return_annotation}")

    def __init__(
        self,
        name: str,
        data: xr.DataArray | ux.UxDataArray,
        fieldtype=None,
        interp_method: Callable | None = None,
        allow_time_extrapolation: bool | None = None,
    ):
        self.name = name
        self.data = data

        self._validate_dataarray()

        self._parent_mesh = data.attrs["mesh"]
        self._location = data.attrs["location"]
        # Set the vertical location
        if "nz1" in data.dims:
            self._vertical_location = "center"
        elif "nz" in data.dims:
            self._vertical_location = "face"

        # Set the interpolation method dynamically
        if interp_method is None:
            self._interp_method = self._interp_template  # Default to a method that always returns 0
        else:
            self._validate_interp_function(interp_method)
            self._interp_method = interp_method

        self.igrid = -1  # Default the grid index to -1
        self.fieldtype = self.name if fieldtype is None else fieldtype

        self.fieldset: XFieldSet | None = None
        if allow_time_extrapolation is None:
            self.allow_time_extrapolation = len(self.data["time"]) == 1
        else:
            self.allow_time_extrapolation = allow_time_extrapolation

    def _validate_dataarray(self):
        """Verifies that all the required dimensions and attributes are present in the
        xarray.DataArray or uxarray.UxDataArray object."""
        # Validate dimensions
        if not ("nz1" in self.data.dims or "nz" in self.data.dims):
            raise ValueError(
                f"Field {self.name} is missing a 'nz1' or 'nz' dimension in the field's metadata. "
                "This dimension is required for xarray.DataArray objects."
            )

        if "time" not in self.data.dims:
            raise ValueError(
                f"Field {self.name} is missing a 'time' dimension in the field's metadata. "
                "This dimension is required for xarray.DataArray objects."
            )

        # Validate attributes
        required_keys = ["location", "mesh"]
        for key in required_keys:
            if key not in self.data.attrs:
                raise ValueError(
                    f"Field {self.name} is missing a '{key}' attribute in the field's metadata. "
                    "This attribute is required for xarray.DataArray objects."
                )

        if isinstance(self.data, ux.UxDataArray):
            self._validate_uxgrid()

    def _validate_uxgrid(self):
        """Verifies that all the required attributes are present in the UxDataArray's uxgrid object."""
        if "Conventions" not in self.data.uxgrid.attrs:
            raise ValueError(
                f"Field {self.name} is missing a 'Conventions' attribute in the field's metadata. "
                "This attribute is required for uxarray.UxDataArray objects."
            )
        if self.data.uxgrid.attrs["Conventions"] != "UGRID-1.0":
            raise ValueError(
                f"Field {self.name} has a 'Conventions' attribute that is not 'UGRID-1.0'. "
                "This attribute is required for uxarray.UxDataArray objects. "
                "See https://ugrid-conventions.github.io/ugrid-conventions/ for more information."
            )

    def __getattr__(self, key: str):
        return getattr(self.data, key)

    def __contains__(self, key: str):
        return key in self.data
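
# Editor's aside (hypothetical sketch, not part of the patch): a custom
# interpolation method only needs to match the template's parameter names to
# pass _validate_interp_function. For example, a stand-in that returns the
# value stored at element `ei` of a (time, nz1, n_node)-shaped DataArray `da`:
#
#     def nearest_node(ti: int, zi: int, ei: int, t, z, y, x) -> np.float64:
#         return np.float64(da.values[ti, zi, ei])
#
#     field = XField("u", da, interp_method=nearest_node)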
+ """ + return 0.5 * (A[0] * (B[1] - C[1]) + B[0] * (C[1] - A[1]) + C[0] * (A[1] - B[1])) \ No newline at end of file diff --git a/parcels/xfieldset.py b/parcels/xfieldset.py new file mode 100644 index 0000000000..478d132026 --- /dev/null +++ b/parcels/xfieldset.py @@ -0,0 +1,91 @@ +import importlib.util +import os +import sys +import warnings +from glob import glob + +import numpy as np + +from parcels._typing import GridIndexingType, InterpMethodOption, Mesh +from parcels.xfield import XField, XVectorField +from parcels.particlefile import ParticleFile +from parcels.tools._helpers import fieldset_repr, default_repr +from parcels.tools.converters import TimeConverter +from parcels.tools.warnings import FieldSetWarning + +import xarray as xr +import uxarray as ux + +__all__ = ["FieldSet"] + + +class XFieldSet: + """XFieldSet class that holds hydrodynamic data needed to execute particles. + + Parameters + ---------- + ds : xarray.Dataset | uxarray.UxDataset) + xarray.Dataset and/or uxarray.UxDataset objects containing the field data. + + Notes + ----- + The `ds` object is a xarray.Dataset or uxarray.UxDataset object. + In XArray terminology, the (Ux)Dataset holds multiple (Ux)DataArray objects. + Each (Ux)DataArray object is a single "field" that is associated with their own + dimensions and coordinates within the (Ux)Dataset. + + A (Ux)Dataset object is associated with a single mesh, which can have multiple + types of "points" (multiple "grids") (e.g. for UxDataSets, these are "face_lon", + "face_lat", "node_lon", "node_lat", "edge_lon", "edge_lat"). Each (Ux)DataArray is + registered to a specific set of points on the mesh. + + For UxDataset objects, each `UXDataArray.attributes` field dictionary contains + the necessary metadata to help determine which set of points a field is registered + to and what parent model the field is associated with. Parcels uses this metadata + during execution for interpolation. Each `UXDataArray.attributes` field dictionary + must have: + * "location" key set to "face", "node", or "edge" to define which pairing of points a field is associated with. + * "mesh" key to define which parent model the fields are associated with (e.g. "fesom_mesh", "icon_mesh") + + """ + + def __init__(self, ds: xr.Dataset | ux.UxDataset): + self.ds = ds + + # Create pointers to each (Ux)DataArray + for field in self.ds.data_vars: + setattr(self, field, XField(field,self.ds[field])) + + self._add_UVfield() + + def add_vector_field(self, vfield): + """Add a :class:`parcels.field.VectorField` object to the FieldSet. + + Parameters + ---------- + vfield : parcels.XVectorField + class:`parcels.xfieldset.XVectorField` object to be added + """ + setattr(self, vfield.name, vfield) + for v in vfield.__dict__.values(): + if isinstance(v, XField) and (v not in self.get_fields()): + self.add_field(v) + + def get_fields(self) -> list[XField | XVectorField]: + """Returns a list of all the :class:`parcels.field.Field` and :class:`parcels.field.VectorField` + objects associated with this FieldSet. 
+ """ + fields = [] + for v in self.__dict__.values(): + if type(v) in [XField, XVectorField]: + if v not in fields: + fields.append(v) + return fields + + def _add_UVfield(self): + if not hasattr(self, "UV") and hasattr(self, "u") and hasattr(self, "v"): + self.add_xvector_field(XVectorField("UV", self.u, self.v)) + if not hasattr(self, "UVW") and hasattr(self, "w"): + self.add_xvector_field(XVectorField("UVW", self.u, self.v, self.w)) + + From 36e164a24be7f1c8f236b78cedc777ef36e3761e Mon Sep 17 00:00:00 2001 From: Joe Date: Fri, 21 Mar 2025 15:33:17 -0400 Subject: [PATCH 02/46] Add notebook that illustrates created xarray and uxarray objects for stommel gyre example --- docs/examples/tutorial_stommel_uxarray.ipynb | 1691 ++++++++++++++++++ 1 file changed, 1691 insertions(+) create mode 100644 docs/examples/tutorial_stommel_uxarray.ipynb diff --git a/docs/examples/tutorial_stommel_uxarray.ipynb b/docs/examples/tutorial_stommel_uxarray.ipynb new file mode 100644 index 0000000000..9dd5b0042b --- /dev/null +++ b/docs/examples/tutorial_stommel_uxarray.ipynb @@ -0,0 +1,1691 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Stommel Gyre on Unstructured Grid\n", + "This tutorial walks through creating a UXArray dataset using the Stommel Gyre analytical solution for a closed rectangular domain on a beta-plane" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": "(function(root) {\n function now() {\n return new Date();\n }\n\n const force = true;\n const py_version = '3.6.0'.replace('rc', '-rc.').replace('.dev', '-dev.');\n const reloading = false;\n const Bokeh = root.Bokeh;\n\n // Set a timeout for this load but only if we are not already initializing\n if (typeof (root._bokeh_timeout) === \"undefined\" || (force || !root._bokeh_is_initializing)) {\n root._bokeh_timeout = Date.now() + 5000;\n root._bokeh_failed_load = false;\n }\n\n function run_callbacks() {\n try {\n root._bokeh_onload_callbacks.forEach(function(callback) {\n if (callback != null)\n callback();\n });\n } finally {\n delete root._bokeh_onload_callbacks;\n }\n console.debug(\"Bokeh: all callbacks have finished\");\n }\n\n function load_libs(css_urls, js_urls, js_modules, js_exports, callback) {\n if (css_urls == null) css_urls = [];\n if (js_urls == null) js_urls = [];\n if (js_modules == null) js_modules = [];\n if (js_exports == null) js_exports = {};\n\n root._bokeh_onload_callbacks.push(callback);\n\n if (root._bokeh_is_loading > 0) {\n // Don't load bokeh if it is still initializing\n console.debug(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n return null;\n } else if (js_urls.length === 0 && js_modules.length === 0 && Object.keys(js_exports).length === 0) {\n // There is nothing to load\n run_callbacks();\n return null;\n }\n\n function on_load() {\n root._bokeh_is_loading--;\n if (root._bokeh_is_loading === 0) {\n console.debug(\"Bokeh: all BokehJS libraries/stylesheets loaded\");\n run_callbacks()\n }\n }\n window._bokeh_on_load = on_load\n\n function on_error(e) {\n const src_el = e.srcElement\n console.error(\"failed to load \" + (src_el.href || src_el.src));\n }\n\n const skip = [];\n if (window.requirejs) {\n window.requirejs.config({'packages': {}, 'paths': {}, 'shim': {}});\n root._bokeh_is_loading = css_urls.length + 0;\n } else {\n root._bokeh_is_loading = 

From 36e164a24be7f1c8f236b78cedc777ef36e3761e Mon Sep 17 00:00:00 2001
From: Joe
Date: Fri, 21 Mar 2025 15:33:17 -0400
Subject: [PATCH 02/46] Add notebook that illustrates created xarray and
 uxarray objects for stommel gyre example

---
 docs/examples/tutorial_stommel_uxarray.ipynb | 1691 ++++++++++++++++++
 1 file changed, 1691 insertions(+)
 create mode 100644 docs/examples/tutorial_stommel_uxarray.ipynb

diff --git a/docs/examples/tutorial_stommel_uxarray.ipynb b/docs/examples/tutorial_stommel_uxarray.ipynb
new file mode 100644
index 0000000000..9dd5b0042b
--- /dev/null
+++ b/docs/examples/tutorial_stommel_uxarray.ipynb
@@ -0,0 +1,1691 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Stommel Gyre on Unstructured Grid\n",
    "This tutorial walks through creating a UXArray dataset using the Stommel Gyre analytical solution for a closed rectangular domain on a beta-plane."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/joe/miniconda3/envs/parcels/lib/python3.13/site-packages/uxarray/grid/coordinates.py:255: UserWarning: This cannot be guaranteed to work correctly on concave polygons\n",
      "  warnings.warn(\"This cannot be guaranteed to work correctly on concave polygons\")\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       ":Path   [Longitude,Latitude]"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def stommel_fieldset_uxarray(xdim=200, ydim=200):\n",
    "    \"\"\"Simulate a periodic current along a western boundary, with significantly\n",
    "    larger velocities along the western edge than the rest of the region\n",
    "\n",
    "    The original test description can be found in: N. Fabbroni, 2009,\n",
    "    Numerical Simulation of Passive tracers dispersion in the sea,\n",
    "    Ph.D. dissertation, University of Bologna\n",
    "    http://amsdottorato.unibo.it/1733/1/Fabbroni_Nicoletta_Tesi.pdf\n",
    "    \"\"\"\n",
    "    import uxarray as ux\n",
    "    import numpy as np\n",
    "    import math\n",
    "    import pandas as pd\n",
    "\n",
    "    a = b = 66666 * 1e3\n",
    "    scalefac = 0.00025  # to scale for physically meaningful velocities\n",
    "\n",
    "    # Coordinates of the test fieldset\n",
    "    # Crowd points to the west edge of the domain\n",
    "    # using a polynomial map on x-direction\n",
    "    x = np.linspace(0, 1, xdim, dtype=np.float32)\n",
    "    lon, lat = np.meshgrid(\n",
    "        a * x,\n",
    "        np.linspace(0, b, ydim, dtype=np.float32)\n",
    "    )\n",
    "    points = (lon.flatten() / 1111111.111111111, lat.flatten() / 1111111.111111111)\n",
    "\n",
    "    # Create the grid\n",
    "    uxgrid = ux.Grid.from_points(points, method=\"regional_delaunay\")\n",
    "    uxgrid.construct_face_centers()\n",
    "\n",
    "    # Define arrays U (zonal), V (meridional) and P (sea surface height)\n",
    "    U = np.zeros((1, 1, lat.size), dtype=np.float32)\n",
    "    V = np.zeros((1, 1, lat.size), dtype=np.float32)\n",
    "    P = np.zeros((1, 1, lat.size), dtype=np.float32)\n",
    "\n",
    "    beta = 2e-11\n",
    "    r = 1 / (11.6 * 86400)\n",
    "    es = r / (beta * a)\n",
    "\n",
    "    i = 0\n",
    "    for x, y in zip(lon.flatten(), lat.flatten()):\n",
    "        xi = x / a\n",
    "        yi = y / b\n",
    "        P[0, 0, i] = (\n",
    "            (1 - math.exp(-xi / es) - xi)\n",
    "            * math.pi\n",
    "            * np.sin(math.pi * yi)\n",
    "            * scalefac\n",
    "        )\n",
    "        U[0, 0, i] = (\n",
    "            -(1 - math.exp(-xi / es) - xi)\n",
    "            * math.pi**2\n",
    "            * np.cos(math.pi * yi)\n",
    "            * scalefac\n",
    "        )\n",
    "        V[0, 0, i] = (\n",
    "            (math.exp(-xi / es) / es - 1)\n",
    "            * math.pi\n",
    "            * np.sin(math.pi * yi)\n",
    "            * scalefac\n",
    "        )\n",
    "        i += 1\n",
    "\n",
    "    u = ux.UxDataArray(\n",
    "        data=U,\n",
    "        name='u',\n",
    "        uxgrid=uxgrid,\n",
    "        dims=[\"time\", \"nz1\", \"n_node\"],\n",
    "        coords=dict(\n",
    "            time=([\"time\"], pd.to_datetime(['2000-01-01'])),\n",
    "            nz1=([\"nz1\"], [0]),\n",
    "        ),\n",
    "        attrs=dict(\n",
    "            description=\"zonal velocity\",\n",
    "            units=\"m/s\",\n",
    "            location=\"node\",\n",
    "            mesh=\"delaunay\",\n",
    "        ),\n",
    "    )\n",
    "    v = ux.UxDataArray(\n",
    "        data=V,\n",
    "        name='v',\n",
    "        uxgrid=uxgrid,\n",
    "        dims=[\"time\", \"nz1\", \"n_node\"],\n",
    "        coords=dict(\n",
    "            time=([\"time\"], pd.to_datetime(['2000-01-01'])),\n",
    "            nz1=([\"nz1\"], [0]),\n",
    "        ),\n",
    "        attrs=dict(\n",
    "            description=\"meridional velocity\",\n",
    "            units=\"m/s\",\n",
    "            location=\"node\",\n",
    "            mesh=\"delaunay\",\n",
    "        ),\n",
    "    )\n",
    "    p = ux.UxDataArray(\n",
    "        data=P,\n",
    "        name='p',\n",
    "        uxgrid=uxgrid,\n",
    "        dims=[\"time\", \"nz1\", \"n_node\"],\n",
    "        coords=dict(\n",
    "            time=([\"time\"], pd.to_datetime(['2000-01-01'])),\n",
    "            nz1=([\"nz1\"], [0]),\n",
    "        ),\n",
    "        attrs=dict(\n",
    "
description=\"pressure\",\n", + " units=\"N/m^2\",\n", + " location=\"node\",\n", + " mesh=\"delaunay\",\n", + " ),\n", + " )\n", + "\n", + "\n", + " return ux.UxDataset(\n", + " {'u':u, 'v':v, 'p': p}, \n", + " uxgrid=uxgrid\n", + " )\n", + "\n", + "uxds = stommel_fieldset_uxarray(50,50)\n", + "\n", + "uxds.uxgrid.plot(\n", + " line_width=0.5,\n", + " height=500,\n", + " width=1000,\n", + " title=\"Regional Delaunay Regions\",\n", + ")\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "def stommel_fieldset_xarray(xdim=200, ydim=200, grid_type=\"A\"):\n", + " \"\"\"Simulate a periodic current along a western boundary, with significantly\n", + " larger velocities along the western edge than the rest of the region\n", + "\n", + " The original test description can be found in: N. Fabbroni, 2009,\n", + " Numerical Simulation of Passive tracers dispersion in the sea,\n", + " Ph.D. dissertation, University of Bologna\n", + " http://amsdottorato.unibo.it/1733/1/Fabbroni_Nicoletta_Tesi.pdf\n", + " \"\"\"\n", + " import xarray as xr\n", + " import numpy as np\n", + " import math\n", + " import pandas as pd\n", + "\n", + " a = b = 10000 * 1e3\n", + " scalefac = 0.05 # to scale for physically meaningful velocities\n", + " dx, dy = a / xdim, b / ydim\n", + "\n", + " # Coordinates of the test fieldset (on A-grid in deg)\n", + " lon = np.linspace(0, a, xdim, dtype=np.float32)\n", + " lat = np.linspace(0, b, ydim, dtype=np.float32)\n", + "\n", + " # Define arrays U (zonal), V (meridional) and P (sea surface height)\n", + " U = np.zeros((1,1,lat.size, lon.size), dtype=np.float32)\n", + " V = np.zeros((1,1,lat.size, lon.size), dtype=np.float32)\n", + " P = np.zeros((1,1,lat.size, lon.size), dtype=np.float32)\n", + "\n", + " beta = 2e-11\n", + " r = 1 / (11.6 * 86400)\n", + " es = r / (beta * a)\n", + "\n", + " for j in range(lat.size):\n", + " for i in range(lon.size):\n", + " xi = lon[i] / a\n", + " yi = lat[j] / b\n", + " P[...,j, i] = (\n", + " (1 - math.exp(-xi / es) - xi)\n", + " * math.pi\n", + " * np.sin(math.pi * yi)\n", + " * scalefac\n", + " )\n", + " if grid_type == \"A\":\n", + " U[...,j, i] = (\n", + " -(1 - math.exp(-xi / es) - xi)\n", + " * math.pi**2\n", + " * np.cos(math.pi * yi)\n", + " * scalefac\n", + " )\n", + " V[...,j, i] = (\n", + " (math.exp(-xi / es) / es - 1)\n", + " * math.pi\n", + " * np.sin(math.pi * yi)\n", + " * scalefac\n", + " )\n", + "\n", + " time = pd.to_datetime(['2000-01-01'])\n", + " z = [0]\n", + " if grid_type == \"C\":\n", + " V[...,:, 1:] = (P[...,:, 1:] - P[...,:, 0:-1]) / dx * a\n", + " U[...,1:, :] = -(P[...,1:, :] - P[...,0:-1, :]) / dy * b\n", + " u_dims = [\"time\",\"nz1\",\"face_lat\", \"node_lon\"]\n", + " u_lat = lat\n", + " u_lon = lon - dx * 0.5\n", + " u_location = \"x_edge\"\n", + " v_dims = [\"time\",\"nz1\",\"node_lat\", \"face_lon\"]\n", + " v_lat = lat - dy * 0.5\n", + " v_lon = lon\n", + " v_location = \"y_edge\"\n", + " p_dims = [\"time\",\"nz1\",\"face_lat\", \"face_lon\"]\n", + " p_lat = lat\n", + " p_lon = lon\n", + " p_location = \"face\"\n", + " \n", + " else:\n", + " u_dims = [\"time\",\"nz1\",\"node_lat\", \"node_lon\"]\n", + " v_dims = [\"time\",\"nz1\",\"node_lat\", \"node_lon\"]\n", + " p_dims = [\"time\",\"nz1\",\"node_lat\", \"node_lon\"]\n", + " u_lat = lat\n", + " u_lon = lon\n", + " v_lat = lat\n", + " v_lon = lon\n", + " u_location = \"node\"\n", + " v_location = \"node\"\n", + " p_lat = lat\n", + " p_lon = lon\n", + " p_location = \"node\"\n", + "\n", + " 
u = xr.DataArray(\n", + " data=U,\n", + " name='u',\n", + " dims=u_dims,\n", + " coords = [time,z,u_lat,u_lon],\n", + " attrs=dict(\n", + " description=\"zonal velocity\",\n", + " units=\"m/s\",\n", + " location=u_location,\n", + " mesh=f\"Arakawa-{grid_type}\",\n", + " ),\n", + " )\n", + " v = xr.DataArray(\n", + " data=V,\n", + " name='v',\n", + " dims=v_dims,\n", + " coords = [time,z,v_lat,v_lon],\n", + " attrs=dict(\n", + " description=\"meridional velocity\",\n", + " units=\"m/s\",\n", + " location=v_location,\n", + " mesh=f\"Arakawa-{grid_type}\",\n", + " ),\n", + " )\n", + " p = xr.DataArray(\n", + " data=P,\n", + " name='p',\n", + " dims=p_dims,\n", + " coords = [time,z,p_lat,p_lon],\n", + " attrs=dict(\n", + " description=\"pressure\",\n", + " units=\"N/m^2\",\n", + " location=p_location,\n", + " mesh=f\"Arakawa-{grid_type}\",\n", + " ),\n", + " )\n", + "\n", + " return xr.Dataset(\n", + " {'u':u, 'v':v, 'p': p}\n", + " )\n", + "\n", + "ds_arakawa_a = stommel_fieldset_xarray(50,50,\"A\")\n", + "ds_arakawa_c = stommel_fieldset_xarray(50,50,\"C\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.Dataset> Size: 30kB\n",
+       "Dimensions:   (time: 1, nz1: 1, node_lat: 50, node_lon: 50)\n",
+       "Coordinates:\n",
+       "  * time      (time) datetime64[ns] 8B 2000-01-01\n",
+       "  * nz1       (nz1) int64 8B 0\n",
+       "  * node_lat  (node_lat) float32 200B 0.0 2.041e+05 ... 9.796e+06 1e+07\n",
+       "  * node_lon  (node_lon) float32 200B 0.0 2.041e+05 ... 9.796e+06 1e+07\n",
+       "Data variables:\n",
+       "    u         (time, nz1, node_lat, node_lon) float32 10kB -0.0 -0.4752 ... 0.0\n",
+       "    v         (time, nz1, node_lat, node_lon) float32 10kB 0.0 0.0 ... 1.373e-08\n",
+       "    p         (time, nz1, node_lat, node_lon) float32 10kB 0.0 0.0 ... -0.0
" + ], + "text/plain": [ + " Size: 30kB\n", + "Dimensions: (time: 1, nz1: 1, node_lat: 50, node_lon: 50)\n", + "Coordinates:\n", + " * time (time) datetime64[ns] 8B 2000-01-01\n", + " * nz1 (nz1) int64 8B 0\n", + " * node_lat (node_lat) float32 200B 0.0 2.041e+05 ... 9.796e+06 1e+07\n", + " * node_lon (node_lon) float32 200B 0.0 2.041e+05 ... 9.796e+06 1e+07\n", + "Data variables:\n", + " u (time, nz1, node_lat, node_lon) float32 10kB -0.0 -0.4752 ... 0.0\n", + " v (time, nz1, node_lat, node_lon) float32 10kB 0.0 0.0 ... 1.373e-08\n", + " p (time, nz1, node_lat, node_lon) float32 10kB 0.0 0.0 ... -0.0" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds_arakawa_a" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'description': 'zonal velocity',\n", + " 'units': 'm/s',\n", + " 'location': 'node',\n", + " 'mesh': 'Arakawa-A'}" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds_arakawa_a[\"u\"].attrs" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.Dataset> Size: 31kB\n",
+       "Dimensions:   (time: 1, nz1: 1, face_lat: 50, node_lon: 50, node_lat: 50,\n",
+       "               face_lon: 50)\n",
+       "Coordinates:\n",
+       "  * time      (time) datetime64[ns] 8B 2000-01-01\n",
+       "  * nz1       (nz1) int64 8B 0\n",
+       "  * face_lat  (face_lat) float32 200B 0.0 2.041e+05 ... 9.796e+06 1e+07\n",
+       "  * node_lon  (node_lon) float32 200B -1e+05 1.041e+05 ... 9.696e+06 9.9e+06\n",
+       "  * node_lat  (node_lat) float32 200B -1e+05 1.041e+05 ... 9.696e+06 9.9e+06\n",
+       "  * face_lon  (face_lon) float32 200B 0.0 2.041e+05 ... 9.796e+06 1e+07\n",
+       "Data variables:\n",
+       "    u         (time, nz1, face_lat, node_lon) float32 10kB 0.0 0.0 ... 0.0\n",
+       "    v         (time, nz1, node_lat, face_lon) float32 10kB 0.0 0.0 ... 1.401e-08\n",
+       "    p         (time, nz1, face_lat, face_lon) float32 10kB 0.0 0.0 ... -0.0
" + ], + "text/plain": [ + " Size: 31kB\n", + "Dimensions: (time: 1, nz1: 1, face_lat: 50, node_lon: 50, node_lat: 50,\n", + " face_lon: 50)\n", + "Coordinates:\n", + " * time (time) datetime64[ns] 8B 2000-01-01\n", + " * nz1 (nz1) int64 8B 0\n", + " * face_lat (face_lat) float32 200B 0.0 2.041e+05 ... 9.796e+06 1e+07\n", + " * node_lon (node_lon) float32 200B -1e+05 1.041e+05 ... 9.696e+06 9.9e+06\n", + " * node_lat (node_lat) float32 200B -1e+05 1.041e+05 ... 9.696e+06 9.9e+06\n", + " * face_lon (face_lon) float32 200B 0.0 2.041e+05 ... 9.796e+06 1e+07\n", + "Data variables:\n", + " u (time, nz1, face_lat, node_lon) float32 10kB 0.0 0.0 ... 0.0\n", + " v (time, nz1, node_lat, face_lon) float32 10kB 0.0 0.0 ... 1.401e-08\n", + " p (time, nz1, face_lat, face_lon) float32 10kB 0.0 0.0 ... -0.0" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds_arakawa_c" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Size: 8B\n", + "array(1142.17017624)\n", + " Size: 8B\n", + "array(1.04820871)\n", + " Size: 8B\n", + "array(108.96400321)\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "min_length_scale = 1111111.111111111*np.sqrt(np.min(uxds.uxgrid.face_areas))\n", + "print(min_length_scale)\n", + "\n", + "max_v = np.sqrt(uxds['u']**2 + uxds['v']**2).max()\n", + "print(max_v)\n", + "\n", + "cfl = 0.1\n", + "dt = cfl * min_length_scale / max_v\n", + "print(dt)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n", + "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n", + "\u001b[1;31mClick here for more info. \n", + "\u001b[1;31mView Jupyter log for further details." + ] + } + ], + "source": [ + "import uxarray as ux\n", + "from datetime import timedelta\n", + "from parcels import (\n", + " UXFieldSet,\n", + " ParticleSet,\n", + " Particle,\n", + " UxAdvectionEuler\n", + ")\n", + "import numpy as np\n", + "\n", + "npart = 10\n", + "fieldset = UXFieldSet(uxds)\n", + "# pset = ParticleSet(\n", + "# fieldset, \n", + "# pclass=Particle, \n", + "# lon=np.linspace(1, 59, npart), \n", + "# lat=np.zeros(npart)+30)\n", + "# pset.execute(UxAdvectionEuler, runtime=timedelta(hours=24), dt=timedelta(seconds=dt))\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "parcels", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 5578201679f8fd5e16384718940781a31c3cef53 Mon Sep 17 00:00:00 2001 From: Joe Date: Fri, 21 Mar 2025 16:06:34 -0400 Subject: [PATCH 03/46] Bring in remaining API from Field to XField. 
---
 docs/examples/tutorial_stommel_uxarray.ipynb | 1339 +-----------------
 parcels/xfield.py                            |  148 +-
 2 files changed, 158 insertions(+), 1329 deletions(-)

diff --git a/docs/examples/tutorial_stommel_uxarray.ipynb b/docs/examples/tutorial_stommel_uxarray.ipynb
index 9dd5b0042b..cc7187e985 100644
--- a/docs/examples/tutorial_stommel_uxarray.ipynb
+++ b/docs/examples/tutorial_stommel_uxarray.ipynb
@@ -10,273 +10,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 1,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [

[Omitted: the deleted cell outputs (minified Bokeh/HoloViews loader JavaScript, Jupyter comm-manager shims, and HTML widget markup), plus a uxarray UserWarning from uxarray/grid/coordinates.py that the computation "cannot be guaranteed to work correctly on concave polygons".]
\n", - "" - ], - "text/plain": [ - ":Path [Longitude,Latitude]" - ] - }, - "execution_count": 1, - "metadata": { - "application/vnd.holoviews_exec.v0+json": { - "id": "9b0baffb-0afd-4cba-a35f-55f20dcbe9f8" - } - }, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "def stommel_fieldset_uxarray(xdim=200, ydim=200):\n", " \"\"\"Simulate a periodic current along a western boundary, with significantly\n", @@ -410,7 +146,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -548,1073 +284,36 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
<xarray.Dataset> Size: 30kB\n",
-       "Dimensions:   (time: 1, nz1: 1, node_lat: 50, node_lon: 50)\n",
-       "Coordinates:\n",
-       "  * time      (time) datetime64[ns] 8B 2000-01-01\n",
-       "  * nz1       (nz1) int64 8B 0\n",
-       "  * node_lat  (node_lat) float32 200B 0.0 2.041e+05 ... 9.796e+06 1e+07\n",
-       "  * node_lon  (node_lon) float32 200B 0.0 2.041e+05 ... 9.796e+06 1e+07\n",
-       "Data variables:\n",
-       "    u         (time, nz1, node_lat, node_lon) float32 10kB -0.0 -0.4752 ... 0.0\n",
-       "    v         (time, nz1, node_lat, node_lon) float32 10kB 0.0 0.0 ... 1.373e-08\n",
-       "    p         (time, nz1, node_lat, node_lon) float32 10kB 0.0 0.0 ... -0.0
" - ], - "text/plain": [ - " Size: 30kB\n", - "Dimensions: (time: 1, nz1: 1, node_lat: 50, node_lon: 50)\n", - "Coordinates:\n", - " * time (time) datetime64[ns] 8B 2000-01-01\n", - " * nz1 (nz1) int64 8B 0\n", - " * node_lat (node_lat) float32 200B 0.0 2.041e+05 ... 9.796e+06 1e+07\n", - " * node_lon (node_lon) float32 200B 0.0 2.041e+05 ... 9.796e+06 1e+07\n", - "Data variables:\n", - " u (time, nz1, node_lat, node_lon) float32 10kB -0.0 -0.4752 ... 0.0\n", - " v (time, nz1, node_lat, node_lon) float32 10kB 0.0 0.0 ... 1.373e-08\n", - " p (time, nz1, node_lat, node_lon) float32 10kB 0.0 0.0 ... -0.0" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "ds_arakawa_a" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'description': 'zonal velocity',\n", - " 'units': 'm/s',\n", - " 'location': 'node',\n", - " 'mesh': 'Arakawa-A'}" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "ds_arakawa_a[\"u\"].attrs" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
<xarray.Dataset> Size: 31kB\n",
-       "Dimensions:   (time: 1, nz1: 1, face_lat: 50, node_lon: 50, node_lat: 50,\n",
-       "               face_lon: 50)\n",
-       "Coordinates:\n",
-       "  * time      (time) datetime64[ns] 8B 2000-01-01\n",
-       "  * nz1       (nz1) int64 8B 0\n",
-       "  * face_lat  (face_lat) float32 200B 0.0 2.041e+05 ... 9.796e+06 1e+07\n",
-       "  * node_lon  (node_lon) float32 200B -1e+05 1.041e+05 ... 9.696e+06 9.9e+06\n",
-       "  * node_lat  (node_lat) float32 200B -1e+05 1.041e+05 ... 9.696e+06 9.9e+06\n",
-       "  * face_lon  (face_lon) float32 200B 0.0 2.041e+05 ... 9.796e+06 1e+07\n",
-       "Data variables:\n",
-       "    u         (time, nz1, face_lat, node_lon) float32 10kB 0.0 0.0 ... 0.0\n",
-       "    v         (time, nz1, node_lat, face_lon) float32 10kB 0.0 0.0 ... 1.401e-08\n",
-       "    p         (time, nz1, face_lat, face_lon) float32 10kB 0.0 0.0 ... -0.0
" - ], - "text/plain": [ - " Size: 31kB\n", - "Dimensions: (time: 1, nz1: 1, face_lat: 50, node_lon: 50, node_lat: 50,\n", - " face_lon: 50)\n", - "Coordinates:\n", - " * time (time) datetime64[ns] 8B 2000-01-01\n", - " * nz1 (nz1) int64 8B 0\n", - " * face_lat (face_lat) float32 200B 0.0 2.041e+05 ... 9.796e+06 1e+07\n", - " * node_lon (node_lon) float32 200B -1e+05 1.041e+05 ... 9.696e+06 9.9e+06\n", - " * node_lat (node_lat) float32 200B -1e+05 1.041e+05 ... 9.696e+06 9.9e+06\n", - " * face_lon (face_lon) float32 200B 0.0 2.041e+05 ... 9.796e+06 1e+07\n", - "Data variables:\n", - " u (time, nz1, face_lat, node_lon) float32 10kB 0.0 0.0 ... 0.0\n", - " v (time, nz1, node_lat, face_lon) float32 10kB 0.0 0.0 ... 1.401e-08\n", - " p (time, nz1, face_lat, face_lon) float32 10kB 0.0 0.0 ... -0.0" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "ds_arakawa_c" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " Size: 8B\n", - "array(1142.17017624)\n", - " Size: 8B\n", - "array(1.04820871)\n", - " Size: 8B\n", - "array(108.96400321)\n" - ] - } - ], + "outputs": [], "source": [ "import numpy as np\n", "min_length_scale = 1111111.111111111*np.sqrt(np.min(uxds.uxgrid.face_areas))\n", @@ -1630,21 +329,9 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "ename": "", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n", - "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n", - "\u001b[1;31mClick here for more info. \n", - "\u001b[1;31mView Jupyter log for further details." - ] - } - ], + "outputs": [], "source": [ "import uxarray as ux\n", "from datetime import timedelta\n", diff --git a/parcels/xfield.py b/parcels/xfield.py index f07eb7005c..b39b9067bd 100644 --- a/parcels/xfield.py +++ b/parcels/xfield.py @@ -114,8 +114,8 @@ class XField: def _interp_template( self, ti: int, - zi: int, ei: int, + bcoords: np.ndarray, t: Union[np.float32,np.float64], z: Union[np.float32,np.float64], y: Union[np.float32,np.float64], @@ -126,7 +126,7 @@ def _interp_template( def _validate_interp_function(self, func: Callable): """Ensures that the function has the correct signature.""" - expected_params = ["ti", "zi", "ei", "t", "z", "y", "x"] + expected_params = ["ti", "ei", "bcoords", "t", "z", "y", "x"] expected_return_types = (np.float32,np.float64) sig = inspect.signature(func) @@ -160,7 +160,9 @@ def __init__( self._validate_dataarray(data) self._parent_mesh = data.attributes["mesh"] + self._mesh_type = data.attributes["mesh_type"] self._location = data.attributes["location"] + # Set the vertical location if "nz1" in data.dims: self._vertical_location = "center" @@ -177,12 +179,152 @@ def __init__( self.igrid = -1 # Default the grid index to -1 self.fieldtype = self.name if fieldtype is None else fieldtype + if self._mesh_type == "flat" or (self.fieldtype not in unitconverters_map.keys()): + self.units = UnitConverter() + elif self._mesh_type == "spherical": + self.units = unitconverters_map[self.fieldtype] + else: + raise ValueError("Unsupported mesh type in data array attributes. 
Choose either: 'spherical' or 'flat'") + self.fieldset: XFieldSet | None = None if allow_time_extrapolation is None: self.allow_time_extrapolation = True if len(self.data["time"]) == 1 else False else: self.allow_time_extrapolation = allow_time_extrapolation + def __repr__(self): + return field_repr(self) + + @property + def grid(self): + if type(self.data) is ux.UxDataArray: + return self.data.uxgrid + else: + return self.data # To do : need to decide on what to return for xarray.DataArray objects + + @property + def lat(self): + if type(self.data) is ux.UxDataArray: + if self._location == "node": + return self.data.uxgrid.node_lat + elif self._location == "face": + return self.data.uxgrid.face_lat + elif self._location == "edge": + return self.data.uxgrid.edge_lat + else: + if self._location == "node": + return self.data.node_lat + elif self._location == "face": + return self.data.face_lat + elif self._location == "x_edge": + return self.data.face_lat + elif self._location == "y_edge": + return self.data.node_lat + + @property + def lon(self): + if type(self.data) is ux.UxDataArray: + if self._location == "node": + return self.data.uxgrid.node_lon + elif self._location == "face": + return self.data.uxgrid.face_lon + elif self._location == "edge": + return self.data.uxgrid.edge_lon + else: + if self._location == "node": + return self.data.node_lon + elif self._location == "face": + return self.data.face_lon + elif self._location == "x_edge": + return self.data.node_lon + elif self._location == "y_edge": + return self.data.face_lon + + @property + def depth(self): + if type(self.data) is ux.UxDataArray: + if self._vertical_location == "center": + return self.data.uxgrid.nz1 + elif self._vertical_location == "face": + return self.data.uxgrid.nz + else: + if self._vertical_location == "center": + return self.data.nz1 + elif self._vertical_location == "face": + return self.data.nz + + @property + def interp_method(self): + return self._interp_method + + @interp_method.setter + def interp_method(self, method: Callable): + self._validate_interp_function(method) + self._interp_method = method + + # @property + # def gridindexingtype(self): + # return self._gridindexingtype + def _search_indices(self, time, z, y, x, ei=None, search2D=False): + + tau, ti = self._search_time_index(time) # To do : Need to implement this method + + if type(self.data) is ux.UxDataArray: + bcoords, ei = self._search_indices_unstructured(z, y, x, ei=ei, search2D=search2D) # To do : Need to implement this method + else: + bcoords, ei = self._search_indices_structured(z, y, x, ei=ei, search2D=search2D) # To do : Need to implement this method + return bcoords, ei, ti + + def _interpolate(self, time, z, y, x, ei=None): + + try: + bcoords, ei, ti = self._search_indices(time, z, y, x, ei=ei) + val = self._interp_method(ti, ei, bcoords, time, z, y, x) + + if np.isnan(val): + # Detect Out-of-bounds sampling and raise exception + _raise_field_out_of_bound_error(z, y, x) + else: + return val + + except (FieldSamplingError, FieldOutOfBoundError, FieldOutOfBoundSurfaceError) as e: + e = add_note(e, f"Error interpolating field '{self.name}'.", before=True) + raise e + + def _check_velocitysampling(self): + if self.name in ["U", "V", "W"]: + warnings.warn( + "Sampling of velocities should normally be done using fieldset.UV or fieldset.UVW object; tread carefully", + RuntimeWarning, + stacklevel=2, + ) + + def __getitem__(self, key): + self._check_velocitysampling() + try: + if _isParticle(key): + return self.eval(key.time, key.depth, 
key.lat, key.lon, key.ei)
+            else:
+                return self.eval(*key)
+        except tuple(AllParcelsErrorCodes.keys()) as error:
+            return _deal_with_errors(error, key, vector_type=None)
+
+    def eval(self, time, z, y, x, ei=None, applyConversion=True):
+        """Interpolate field values in space and time.
+
+        We interpolate linearly in time and apply implicit unit
+        conversion to the result. Spatial interpolation is delegated
+        to the field's assigned interp_method.
+        """
+
+        value = self._interpolate(time, z, y, x, ei=ei)
+
+        if applyConversion:
+            return self.units.to_target(value, z, y, x)
+        else:
+            return value
+
+
     def _validate_dataarray(self):
         """ Verifies that all the required attributes are present in the xarray.DataArray or
         uxarray.UxDataArray object."""
@@ -201,7 +343,7 @@ def _validate_dataarray(self):
             )
 
         # Validate attributes
-        required_keys = ["location", "mesh"]
+        required_keys = ["location", "mesh", "mesh_type"]
         for key in required_keys:
             if key not in self.data.attrs.keys():
                 raise ValueError(

From 10e29cbd666a4f459f4be66dcc7c3d4c4a909478 Mon Sep 17 00:00:00 2001
From: Joe
Date: Mon, 24 Mar 2025 12:48:33 -0400
Subject: [PATCH 04/46] Remove fieldtype attribute

---
 parcels/xfield.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/parcels/xfield.py b/parcels/xfield.py
index b39b9067bd..f9b8415631 100644
--- a/parcels/xfield.py
+++ b/parcels/xfield.py
@@ -87,7 +87,7 @@ class XField:
 
     The xarray.DataArray or uxarray.UxDataArray object contains the field data and metadata.
     * dims: (time, [nz1 | nz], [face_lat | node_lat | edge_lat], [face_lon | node_lon | edge_lon])
-    * attrs: (location, mesh)
+    * attrs: (location, mesh, mesh_type)
 
     When using a xarray.DataArray object,
     * The xarray.DataArray object must have the "location" and "mesh" attributes set.
@@ -149,7 +149,6 @@ def __init__(
         self,
         name: str,
         data: xr.DataArray | ux.UxDataArray,
-        fieldtype=None,
         interp_method: Callable | None = None,
         allow_time_extrapolation: bool | None = None,
     ):
@@ -177,12 +176,11 @@ def __init__(
             self._interp_method = interp_method
 
         self.igrid = -1 # Default the grid index to -1
-        self.fieldtype = self.name if fieldtype is None else fieldtype
 
-        if self._mesh_type == "flat" or (self.fieldtype not in unitconverters_map.keys()):
+        if self._mesh_type == "flat" or (self.name not in unitconverters_map.keys()):
            self.units = UnitConverter()
         elif self._mesh_type == "spherical":
-            self.units = unitconverters_map[self.fieldtype]
+            self.units = unitconverters_map[self.name]
         else:
             raise ValueError("Unsupported mesh type in data array attributes. 
Choose either: 'spherical' or 'flat'") @@ -323,8 +321,7 @@ def eval(self, time, z, y, x, ei=None, applyConversion=True): return self.units.to_target(value, z, y, x) else: return value - - + def _validate_dataarray(self): """ Verifies that all the required attributes are present in the xarray.DataArray or uxarray.UxDataArray object.""" From dcf31a2ea926dc4f1d45c5b0ebd927bd89f86a54 Mon Sep 17 00:00:00 2001 From: Joe Date: Mon, 24 Mar 2025 12:53:06 -0400 Subject: [PATCH 05/46] Add ravel/unravel index --- parcels/xfield.py | 76 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 75 insertions(+), 1 deletion(-) diff --git a/parcels/xfield.py b/parcels/xfield.py index f9b8415631..4831992148 100644 --- a/parcels/xfield.py +++ b/parcels/xfield.py @@ -251,6 +251,25 @@ def depth(self): elif self._vertical_location == "face": return self.data.nz + @property + def nx(self): + if type(self.data) is xr.DataArray: + if "face_lon" in self.data.dims: + return self.data.sizes["face_lon"] + elif "node_lon" in self.data.dims: + return self.data.sizes["node_lon"] + else: + return 0 # To do : Discuss what we want to return for uxdataarray obj + @property + def ny(self): + if type(self.data) is xr.DataArray: + if "face_lat" in self.data.dims: + return self.data.sizes["face_lat"] + elif "node_lat" in self.data.dims: + return self.data.sizes["node_lat"] + else: + return 0 # To do : Discuss what we want to return for uxdataarray obj + @property def interp_method(self): return self._interp_method @@ -321,7 +340,62 @@ def eval(self, time, z, y, x, ei=None, applyConversion=True): return self.units.to_target(value, z, y, x) else: return value - + + def _rescale_and_set_minmax(self, data): + data[np.isnan(data)] = 0 + return data + + def ravel_index(self, zi, yi, xi): + """Return the flat index of the given grid points. + Only used when working with fields on a structured grid. + + Parameters + ---------- + zi : int + z index + yi : int + y index + xi : int + x index + + Returns + ------- + int + flat index + """ + if type(self.data) is xr.DataArray: + return xi + self.nx * (yi + self.ny * zi) + else: + return None + + def unravel_index(self, ei): + """Return the zi, yi, xi indices for a given flat index. + Only used when working with fields on a structured grid. + + Parameters + ---------- + ei : int + The flat index to be unraveled. + + Returns + ------- + zi : int + The z index. + yi : int + The y index. + xi : int + The x index. 
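+
+        Examples
+        --------
+        A hypothetical round-trip through ravel_index (the index values are
+        illustrative, and ``ei`` is looked up by the field's grid id):
+
+        >>> ei = field.ravel_index(zi=0, yi=2, xi=3)  # doctest: +SKIP
+        >>> field.unravel_index({field.igrid: ei})  # doctest: +SKIP
+        (0, 2, 3)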
+ """ + if type(self.data) is xr.DataArray: + _ei = ei[self.igrid] + zi = _ei // (self.nx * self.ny) + _ei = _ei % (self.nx * self.ny) + yi = _ei // self.nx + xi = _ei % self.nx + return zi, yi, xi + else: + return None,None,None # To do : Discuss what we want to return for uxdataarray + def _validate_dataarray(self): """ Verifies that all the required attributes are present in the xarray.DataArray or uxarray.UxDataArray object.""" From 65c573513cf90a8e51d076e06f29364cc12f9bf4 Mon Sep 17 00:00:00 2001 From: Joe Date: Mon, 24 Mar 2025 13:16:10 -0400 Subject: [PATCH 06/46] Add ravel/unravel; Add vector field interpolation support and access via particle as key --- parcels/xfield.py | 116 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 111 insertions(+), 5 deletions(-) diff --git a/parcels/xfield.py b/parcels/xfield.py index 4831992148..87491bf6a9 100644 --- a/parcels/xfield.py +++ b/parcels/xfield.py @@ -292,11 +292,11 @@ def _search_indices(self, time, z, y, x, ei=None, search2D=False): bcoords, ei = self._search_indices_structured(z, y, x, ei=ei, search2D=search2D) # To do : Need to implement this method return bcoords, ei, ti - def _interpolate(self, time, z, y, x, ei=None): + def _interpolate(self, time, z, y, x, ei): try: - bcoords, ei, ti = self._search_indices(time, z, y, x, ei=ei) - val = self._interp_method(ti, ei, bcoords, time, z, y, x) + bcoords, _ei, ti = self._search_indices(time, z, y, x, ei=ei) + val = self._interp_method(ti, _ei, bcoords, time, z, y, x) if np.isnan(val): # Detect Out-of-bounds sampling and raise exception @@ -333,8 +333,12 @@ def eval(self, time, z, y, x, ei=None, applyConversion=True): conversion to the result. Note that we defer to scipy.interpolate to perform spatial interpolation. """ + if ei is None: + _ei = 0 + else: + _ei = ei[self.igrid] - value = self._interpolate(time, z, y, x, ei=ei) + value = self._interpolate(time, z, y, x, ei=_ei) if applyConversion: return self.units.to_target(value, z, y, x) @@ -451,7 +455,52 @@ def __contains__(self, key: str): class XVectorField: """XVectorField class that holds vector field data needed to execute particles.""" - def __init__(self, name: str, U: XField, V: XField, W: XField | None = None): + + + @staticmethod + def _vector_interp_template( + self, + ti: int, + ei: int, + bcoords: np.ndarray, + t: Union[np.float32,np.float64], + z: Union[np.float32,np.float64], + y: Union[np.float32,np.float64], + x: Union[np.float32,np.float64] + )-> Union[np.float32,np.float64]: + """ Template function used for the signature check of the lateral interpolation methods.""" + return 0.0 + + def _validate_vector_interp_function(self, func: Callable): + """Ensures that the function has the correct signature.""" + expected_params = ["ti", "ei", "bcoords", "t", "z", "y", "x"] + expected_return_types = (np.float32,np.float64) + + sig = inspect.signature(func) + params = list(sig.parameters.keys()) + + # Check the parameter names and count + if params != expected_params: + raise TypeError( + f"Function must have parameters {expected_params}, but got {params}" + ) + + # Check return annotation if present + return_annotation = sig.return_annotation + if return_annotation not in (inspect.Signature.empty, *expected_return_types): + raise TypeError( + f"Function must return a float, but got {return_annotation}" + ) + + def __init__( + self, + name: str, + U: XField, + V: XField, + W: XField | None = None, + vector_interp_method: Callable | None = None + ): + self.name = name self.U = U self.V = V @@ -462,6 +511,13 @@ def 
__init__(self, name: str, U: XField, V: XField, W: XField | None = None):
 
         else:
             self.vector_type = "2D"
 
+        # Setting the interpolation method dynamically
+        if vector_interp_method is None:
+            self._vector_interp_method = None
+        else:
+            self._validate_vector_interp_function(vector_interp_method)
+            self._vector_interp_method = vector_interp_method
+
     def __repr__(self):
         return f"""<{type(self).__name__}>
             name: {self.name!r}
@@ -469,6 +525,14 @@ def __repr__(self):
             V: {default_repr(self.V)}
             W: {default_repr(self.W)}"""
 
+    @property
+    def vector_interp_method(self):
+        return self._vector_interp_method
+
+    @vector_interp_method.setter
+    def vector_interp_method(self, method: Callable):
+        self._validate_vector_interp_function(method)
+        self._vector_interp_method = method
 
     # @staticmethod
     # To do : def _check_grid_dimensions(grid1, grid2):
     #     return (
     #         np.allclose(grid1.lon, grid2.lon)
     #         and np.allclose(grid1.lat, grid2.lat)
     #         and np.allclose(grid1.depth, grid2.depth)
     #         and np.allclose(grid1.time, grid2.time)
     #     )
+    def _interpolate(self, time, z, y, x, ei):
+
+        # Index search is delegated to the U component field
+        bcoords, _ei, ti = self.U._search_indices(time, z, y, x, ei=ei)
+
+        if self._vector_interp_method is None:
+            u = self.U.eval(time, z, y, x, _ei, applyConversion=False)
+            v = self.V.eval(time, z, y, x, _ei, applyConversion=False)
+            if "3D" in self.vector_type:
+                w = self.W.eval(time, z, y, x, _ei, applyConversion=False)
+                return (u, v, w)
+            else:
+                return (u, v, 0)
+        else:
+            (u, v, w) = self._vector_interp_method(ti, _ei, bcoords, time, z, y, x)
+            return (u, v, w)
+
+
+    def eval(self, time, z, y, x, ei=None, applyConversion=True):
+
+        if ei is None:
+            _ei = 0
+        else:
+            _ei = ei[self.U.igrid]
+
+        (u, v, w) = self._interpolate(time, z, y, x, _ei)
+
+        if applyConversion:
+            u = self.U.units.to_target(u, z, y, x)
+            v = self.V.units.to_target(v, z, y, x)
+            if "3D" in self.vector_type:
+                w = self.W.units.to_target(w, z, y, x)
+
+        return (u, v, w)
+
+    def __getitem__(self, key):
+        try:
+            if _isParticle(key):
+                return self.eval(key.time, key.depth, key.lat, key.lon, key.ei)
+            else:
+                return self.eval(*key)
+        except tuple(AllParcelsErrorCodes.keys()) as error:
+            return _deal_with_errors(error, key, vector_type=self.vector_type)
 
 # Private helper routines
 
From c6a9340f82b4df9019c64414998badc2690fb84c Mon Sep 17 00:00:00 2001
From: Joe
Date: Mon, 24 Mar 2025 14:18:17 -0400
Subject: [PATCH 07/46] Add basic implementation for xfieldset

---
 parcels/xfieldset.py | 248 ++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 247 insertions(+), 1 deletion(-)

diff --git a/parcels/xfieldset.py b/parcels/xfieldset.py
index 478d132026..3b62ad9a0d 100644
--- a/parcels/xfieldset.py
+++ b/parcels/xfieldset.py
@@ -16,7 +16,7 @@
 import xarray as xr
 import uxarray as ux
 
-__all__ = ["FieldSet"]
+__all__ = ["XFieldSet"]
 
 
 class XFieldSet:
@@ -52,12 +52,99 @@ class XFieldSet:
 
     def __init__(self, ds: xr.Dataset | ux.UxDataset):
         self.ds = ds
+        self._completed: bool = False
 
         # Create pointers to each (Ux)DataArray
         for field in self.ds.data_vars:
             setattr(self, field, XField(field,self.ds[field]))
 
         self._add_UVfield()
 
+    def __repr__(self):
+        return fieldset_repr(self)
+
+    # @property
+    # def particlefile(self):
+    #     return self._particlefile
+
+    # @staticmethod
+    # def checkvaliddimensionsdict(dims):
+    #     for d in dims:
+    #         if d not in ["lon", "lat", "depth", "time"]:
+    #             raise NameError(f"{d} is not a valid key in the dimensions dictionary")
+
+    def add_field(self, field: XField, name: str | None = None):
+        """Add a :class:`parcels.xfield.XField` object to the FieldSet.
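+
+        A minimal usage sketch (hypothetical file and variable names; the
+        DataArray must already carry the metadata that XField expects)::
+
+            import xarray as xr
+
+            ds = xr.open_dataset("fields.nc")
+            fieldset = XFieldSet(ds)
+            fieldset.add_field(XField("bathy", ds["bathy"]), name="bathymetry")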
+
+        Parameters
+        ----------
+        field : parcels.xfield.XField
+            Field object to be added
+        name : str, optional
+            Name of the :class:`parcels.xfield.XField` object to be added. Defaults
+            to the name in the Field object.
+
+
+        Examples
+        --------
+        For usage examples see the following tutorial:
+
+        * `Unit converters <../examples/tutorial_unitconverters.ipynb>`__
+
+        """
+        if self._completed:
+            raise RuntimeError(
+                "FieldSet has already been completed. Are you trying to add a Field after you've created the ParticleSet?"
+            )
+        name = field.name if name is None else name
+
+        if hasattr(self, name):  # check if Field with same name already exists when adding new Field
+            raise RuntimeError(f"FieldSet already has a Field with name '{name}'")
+        else:
+            setattr(self, name, field)
+
+    def add_constant_field(self, name: str, value: float, mesh: Mesh = "flat"):
+        """Wrapper function to add a Field that is constant in space,
+        useful e.g. when using constant horizontal diffusivity.
+
+        Parameters
+        ----------
+        name : str
+            Name of the :class:`parcels.xfield.XField` object to be added
+        value : float
+            Value of the constant field (stored as 32-bit float)
+        mesh : str
+            String indicating the type of mesh coordinates and
+            units used during velocity interpolation, see also `this tutorial <../examples/tutorial_unitconverters.ipynb>`__:
+
+            1. spherical: Lat and lon in degree, with a
+               correction for zonal velocity U near the poles.
+            2. flat (default): No conversion, lat/lon are assumed to be in m.
+        """
+        import pandas as pd
+
+        time = pd.to_datetime(['2000-01-01'])
+        values = np.zeros((1, 1, 1, 1), dtype=np.float32) + value
+        data = xr.DataArray(
+            data=values,
+            name=name,
+            dims=("time", "nz1", "node_lat", "node_lon"),
+            coords=dict(time=time, nz1=[0], node_lat=[0.0], node_lon=[0.0]),
+            attrs=dict(
+                description="spatially constant field",
+                units="null",
+                location="node",
+                mesh="constant",
+                mesh_type=mesh,
+            ),
+        )
+        self.add_field(
+            XField(
+                name,
+                data,
+                interp_method=None,  # To do : Need to define an interpolation method for constants
+                allow_time_extrapolation=True,
+            )
+        )
+
+    def add_vector_field(self, vfield):
+        """Add a :class:`parcels.xfield.XVectorField` object to the FieldSet.
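+
+        A minimal sketch (assumes the FieldSet already holds XFields named
+        ``u`` and ``v``)::
+
+            uv = XVectorField("UV", fieldset.u, fieldset.v)
+            fieldset.add_vector_field(uv)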
@@ -88,4 +175,163 @@ def _add_UVfield(self): if not hasattr(self, "UVW") and hasattr(self, "w"): self.add_xvector_field(XVectorField("UVW", self.u, self.v, self.w)) + def _check_complete(self): + assert self.u, 'FieldSet does not have a Field named "u"' + assert self.v, 'FieldSet does not have a Field named "v"' + for attr, value in vars(self).items(): + if type(value) is XField: + assert value.name == attr, f"Field {value.name}.name ({attr}) is not consistent" + + self._add_UVfield() + + self._completed = True + + @classmethod + def _parse_wildcards(cls, paths, filenames, var): + if not isinstance(paths, list): + paths = sorted(glob(str(paths))) + if len(paths) == 0: + notfound_paths = filenames[var] if isinstance(filenames, dict) and var in filenames else filenames + raise OSError(f"FieldSet files not found for variable {var}: {notfound_paths}") + for fp in paths: + if not os.path.exists(fp): + raise OSError(f"FieldSet file not found: {fp}") + return paths + + # @classmethod + # def from_netcdf( + # cls, + # filenames, + # variables, + # dimensions, + # fieldtype=None, + # mesh: Mesh = "spherical", + # allow_time_extrapolation: bool | None = None, + # **kwargs, + # ): + + # @classmethod + # def from_nemo( + # cls, + # filenames, + # variables, + # dimensions, + # mesh: Mesh = "spherical", + # allow_time_extrapolation: bool | None = None, + # tracer_interp_method: InterpMethodOption = "cgrid_tracer", + # **kwargs, + # ): + + # @classmethod + # def from_mitgcm( + # cls, + # filenames, + # variables, + # dimensions, + # mesh: Mesh = "spherical", + # allow_time_extrapolation: bool | None = None, + # tracer_interp_method: InterpMethodOption = "cgrid_tracer", + # **kwargs, + # ): + + # @classmethod + # def from_croco( + # cls, + # filenames, + # variables, + # dimensions, + # hc: float | None = None, + # mesh="spherical", + # allow_time_extrapolation=None, + # tracer_interp_method="cgrid_tracer", + # **kwargs, + # ): + + # @classmethod + # def from_c_grid_dataset( + # cls, + # filenames, + # variables, + # dimensions, + # mesh: Mesh = "spherical", + # allow_time_extrapolation: bool | None = None, + # tracer_interp_method: InterpMethodOption = "cgrid_tracer", + # gridindexingtype: GridIndexingType = "nemo", + # **kwargs, + # ): + + + # @classmethod + # def from_mom5( + # cls, + # filenames, + # variables, + # dimensions, + # mesh: Mesh = "spherical", + # allow_time_extrapolation: bool | None = None, + # tracer_interp_method: InterpMethodOption = "bgrid_tracer", + # **kwargs, + # ): + + # @classmethod + # def from_a_grid_dataset(cls, filenames, variables, dimensions, **kwargs): + + # @classmethod + # def from_b_grid_dataset( + # cls, + # filenames, + # variables, + # dimensions, + # mesh: Mesh = "spherical", + # allow_time_extrapolation: bool | None = None, + # tracer_interp_method: InterpMethodOption = "bgrid_tracer", + # **kwargs, + # ): + + def add_constant(self, name, value): + """Add a constant to the FieldSet. Note that all constants are + stored as 32-bit floats. 
+ + Parameters + ---------- + name : str + Name of the constant + value : + Value of the constant (stored as 32-bit float) + + + Examples + -------- + Tutorials using fieldset.add_constant: + `Analytical advection <../examples/tutorial_analyticaladvection.ipynb>`__ + `Diffusion <../examples/tutorial_diffusion.ipynb>`__ + `Periodic boundaries <../examples/tutorial_periodic_boundaries.ipynb>`__ + """ + setattr(self, name, value) + + # def computeTimeChunk(self, time=0.0, dt=1): + # """Load a chunk of three data time steps into the FieldSet. + # This is used when FieldSet uses data imported from netcdf, + # with default option deferred_load. The loaded time steps are at or immediatly before time + # and the two time steps immediately following time if dt is positive (and inversely for negative dt) + + # Parameters + # ---------- + # time : + # Time around which the FieldSet data are to be loaded. + # Time is provided as a double, relatively to Fieldset.time_origin. + # Default is 0. + # dt : + # time step of the integration scheme, needed to set the direction of time chunk loading. + # Default is 1. + # """ + # nextTime = np.inf if dt > 0 else -np.inf + # if abs(nextTime) == np.inf or np.isnan(nextTime): # Second happens when dt=0 + # return nextTime + # else: + # nSteps = int((nextTime - time) / dt) + # if nSteps == 0: + # return nextTime + # else: + # return time + nSteps * dt From da8489396dc7af44a2a325e56a400715821f033f Mon Sep 17 00:00:00 2001 From: Joe Date: Tue, 25 Mar 2025 14:35:36 -0400 Subject: [PATCH 08/46] Add draft of unstructured grid index search. Non-working draft of structured grid index search has been brought in as well. This needs a bit of work. --- parcels/xfield.py | 155 ++++++++++++++++++++++++++++++---------------- 1 file changed, 100 insertions(+), 55 deletions(-) diff --git a/parcels/xfield.py b/parcels/xfield.py index 87491bf6a9..988055cd68 100644 --- a/parcels/xfield.py +++ b/parcels/xfield.py @@ -7,6 +7,8 @@ import numpy as np import xarray as xr import uxarray as ux +from uxarray.grid.neighbors import _barycentric_coordinates + import parcels.tools.interpolation_utils as i_u from parcels._compat import add_note @@ -42,7 +44,7 @@ import inspect from typing import Callable, Union -#from ._index_search import _search_indices_curvilinear, _search_indices_rectilinear, _search_time_index +from ._index_search import _search_indices_curvilinear, _search_indices_rectilinear, _search_time_index if TYPE_CHECKING: import numpy.typing as npt @@ -149,6 +151,7 @@ def __init__( self, name: str, data: xr.DataArray | ux.UxDataArray, + mesh_type: Mesh = "flat", interp_method: Callable | None = None, allow_time_extrapolation: bool | None = None, ): @@ -159,7 +162,7 @@ def __init__( self._validate_dataarray(data) self._parent_mesh = data.attributes["mesh"] - self._mesh_type = data.attributes["mesh_type"] + self._mesh_type = mesh_type self._location = data.attributes["location"] # Set the vertical location @@ -184,12 +187,16 @@ def __init__( else: raise ValueError("Unsupported mesh type in data array attributes. 
Choose either: 'spherical' or 'flat'")
 
-        self.fieldset: XFieldSet | None = None
         if allow_time_extrapolation is None:
             self.allow_time_extrapolation = True if len(self.data["time"]) == 1 else False
         else:
             self.allow_time_extrapolation = allow_time_extrapolation
 
+        if type(self.data) is ux.UxDataArray:
+            self._spatialhash = self.data.uxgrid.get_spatial_hash()
+        else:
+            self._spatialhash = None
+
     def __repr__(self):
         return field_repr(self)
 
@@ -269,7 +276,13 @@ def ny(self):
             return self.data.sizes["node_lat"]
         else:
             return 0 # To do : Discuss what we want to return for uxdataarray obj
-
+    @property
+    def n_face(self):
+        if type(self.data) is ux.UxDataArray:
+            return self.data.uxgrid.n_face
+        else:
+            return 0 # To do : Discuss what we want to return for dataarray obj
+
     @property
     def interp_method(self):
         return self._interp_method
@@ -282,6 +295,80 @@ def interp_method(self, method: Callable):
     # @property
     # def gridindexingtype(self):
     #     return self._gridindexingtype
+
+    def _get_ux_barycentric_coordinates(self, y, x, fi):
+        """Compute the barycentric coordinates of the point (x, y) within the face fi,
+        along with the reconstruction error; used to test face membership on unstructured grids."""
+
+        # Check if particle is in the same face, otherwise search again.
+        n_nodes = self.data.uxgrid.n_nodes_per_face[fi].to_numpy()
+        node_ids = self.data.uxgrid.face_node_connectivity[fi, 0:n_nodes]
+        nodes = np.column_stack(
+            (
+                np.deg2rad(self.data.uxgrid.node_lon[node_ids].to_numpy()),
+                np.deg2rad(self.data.uxgrid.node_lat[node_ids].to_numpy()),
+            )
+        )
+
+        coord = np.deg2rad([x, y])
+        bcoord = np.asarray(_barycentric_coordinates(nodes, coord))
+        err = abs(np.dot(bcoord, nodes[:, 0]) - coord[0]) + abs(
+            np.dot(bcoord, nodes[:, 1]) - coord[1]
+        )
+        return bcoord, err
+
+    def _search_indices_unstructured(self, z, y, x, ei=None, search2D=False):
+
+        tol = 1e-10
+        if ei is None:
+            # Search using global search
+            fi, bcoords = self._spatialhash.query([[x, y]])  # Get the face id for the particle
+            if fi == -1:
+                raise FieldOutOfBoundError(z, y, x)  # To do : how to handle lost particle ??
+            # To do : Do the vertical grid search
+            # zi = self._vertical_search(z)
+            zi = 0  # For now
+            return bcoords, self.ravel_index(zi, 0, fi)
+        else:
+            zi, fi = self.unravel_index(ei)  # Get the z and face indices of the particle
+            # Search using nearest neighbors
+            bcoords, err = self._get_ux_barycentric_coordinates(y, x, fi)
+
+            if ((bcoords >= 0).all()) and ((bcoords <= 1.0).all()) and err < tol:
+                # To do: Do the vertical grid search
+                return bcoords, ei
+            else:
+                # In this case we need to search the neighbors
+                for neighbor in self.data.uxgrid.face_face_connectivity[fi, :]:
+                    if neighbor < 0:  # Skip the fill values that pad boundary faces
+                        continue
+                    bcoords, err = self._get_ux_barycentric_coordinates(y, x, neighbor)
+                    if ((bcoords >= 0).all()) and ((bcoords <= 1.0).all()) and err < tol:
+                        # To do: Do the vertical grid search
+                        return bcoords, self.ravel_index(zi, 0, neighbor)
+
+        # If we reach this point, the particle left its neighborhood, so we fall back to a global search
+        fi, bcoords = self._spatialhash.query([[x, y]])  # Get the face id for the particle
+        if fi == -1:
+            raise FieldOutOfBoundError(z, y, x)  # To do : how to handle lost particle ??
+        return bcoords, self.ravel_index(zi, 0, fi)
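+
+    # A minimal usage sketch of the two-stage search above (hypothetical file
+    # names; assumes a UGRID-1.0 compliant unstructured dataset). The first
+    # call does a global spatial-hash query; later calls pass the cached flat
+    # index back in, so only the enclosing face and its neighbors are tested:
+    #
+    #     import uxarray as ux
+    #     uxds = ux.open_dataset("grid.nc", "data.nc")
+    #     field = XField("temp", uxds["temp"])
+    #     bcoords, ei = field._search_indices_unstructured(0.0, 45.0, -30.0)
+    #     bcoords, ei = field._search_indices_unstructured(0.0, 45.01, -30.01, ei=ei)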
+
+
+    def _search_indices_structured(self, z, y, x, ei=None, search2D=False):
+
+        # To do : determine grid type from xarray.coords shapes
+        # Rectilinear uses a 1-D array for lat and lon
+        # Curvilinear uses a 2-D array for lat and lon
+        if self.grid._gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]:
+            (zeta, eta, xsi, zi, yi, xi) = _search_indices_rectilinear(
+                self, z, y, x, ei=ei, search2D=search2D
+            )
+        else:
+            (zeta, eta, xsi, zi, yi, xi) = _search_indices_curvilinear(
+                self, z, y, x, ei=ei, search2D=search2D
+            )
+
+        # To do : Calculate barycentric coordinates from zeta, eta, xsi
+
+        return (zeta, eta, xsi, zi, yi, xi)
+
     def _search_indices(self, time, z, y, x, ei=None, search2D=False):
 
         tau, ti = self._search_time_index(time)  # To do : Need to implement this method
@@ -334,7 +421,7 @@ def eval(self, time, z, y, x, ei=None, applyConversion=True):
         scipy.interpolate to perform spatial interpolation.
         """
         if ei is None:
-            _ei = 0
+            _ei = None
         else:
             _ei = ei[self.igrid]
 
@@ -351,7 +438,6 @@ def _rescale_and_set_minmax(self, data):
 
     def ravel_index(self, zi, yi, xi):
         """Return the flat index of the given grid points.
-        Only used when working with fields on a structured grid.
 
         Parameters
         ----------
@@ -360,7 +446,7 @@ def ravel_index(self, zi, yi, xi):
         yi : int
             y index
         xi : int
-            x index
+            x index. When using an unstructured grid, this is the face index (fi)
 
         Returns
         -------
@@ -370,7 +456,7 @@ def ravel_index(self, zi, yi, xi):
         if type(self.data) is xr.DataArray:
             return xi + self.nx * (yi + self.ny * zi)
         else:
-            return None
+            return xi + self.n_face * zi
 
     def unravel_index(self, ei):
         """Return the zi, yi, xi indices for a given flat index.
+        On a structured grid this returns (zi, yi, xi); on an unstructured grid it returns (zi, fi).
 
         Parameters
         ----------
@@ -398,7 +484,10 @@ def unravel_index(self, ei):
             xi = _ei % self.nx
             return zi, yi, xi
         else:
-            return None,None,None # To do : Discuss what we want to return for uxdataarray
+            zi = ei // self.n_face
+            fi = ei % self.n_face
+            return zi, fi
 
     def _validate_dataarray(self):
         """ Verifies that all the required attributes are present in the xarray.DataArray or
@@ -418,7 +507,7 @@ def _validate_dataarray(self):
             )
 
         # Validate attributes
-        required_keys = ["location", "mesh", "mesh_type"]
+        required_keys = ["location", "mesh"]
         for key in required_keys:
             if key not in self.data.attrs.keys():
                 raise ValueError(
@@ -583,48 +672,4 @@ def __getitem__(self, key):
         else:
             return self.eval(*key)
         except tuple(AllParcelsErrorCodes.keys()) as error:
-            return _deal_with_errors(error, key, vector_type=self.vector_type)
-
-
-# Private helper routines
-def _barycentric_coordinates(nodes, point):
-    """
-    Compute the barycentric coordinates of a point P inside a convex polygon using area-based weights.
-    So that this method generalizes to n-sided polygons, we use the Waschpress points as the generalized
-    barycentric coordinates, which is only valid for convex polygons.
-
-    Parameters
-    ----------
-    nodes : numpy.ndarray
-       Spherical coordinates (lon,lat) of each corner node of a face
-    point : numpy.ndarray
-       Spherical coordinates (lon,lat) of the point
-    Returns
-    -------
-    numpy.ndarray
-        Barycentric coordinates corresponding to each vertex.
- - """ - n = len(nodes) - sum_wi = 0 - w = [] - - for i in range(0, n): - vim1 = nodes[i - 1] - vi = nodes[i] - vi1 = nodes[(i + 1) % n] - a0 = _triangle_area(vim1, vi, vi1) - a1 = _triangle_area(point, vim1, vi) - a2 = _triangle_area(point, vi, vi1) - sum_wi += a0 / (a1 * a2) - w.append(a0 / (a1 * a2)) - - barycentric_coords = [w_i / sum_wi for w_i in w] - - return barycentric_coords - -def _triangle_area(A, B, C): - """ - Compute the area of a triangle given by three points. - """ - return 0.5 * (A[0] * (B[1] - C[1]) + B[0] * (C[1] - A[1]) + C[0] * (A[1] - B[1])) \ No newline at end of file + return _deal_with_errors(error, key, vector_type=self.vector_type) \ No newline at end of file From 7514e82cca2b7278f232a64968eb84002e1e4598 Mon Sep 17 00:00:00 2001 From: Joe Date: Thu, 27 Mar 2025 11:53:30 -0400 Subject: [PATCH 09/46] Swap grid for Field in _search_time_index This is done since the Field object holds either an xr.DataArray or ux.UxDataArray which contain the grid, dimensions, and coordinates. With this, the time argument is now set to be a datetime object. --- parcels/_index_search.py | 62 +- parcels/field.py | 1405 ++++++++++++-------------------------- parcels/fieldset.py | 1112 ++++++------------------------ parcels/xfield.py | 675 ------------------ parcels/xfieldset.py | 337 --------- 5 files changed, 692 insertions(+), 2899 deletions(-) delete mode 100644 parcels/xfield.py delete mode 100644 parcels/xfieldset.py diff --git a/parcels/_index_search.py b/parcels/_index_search.py index ddfbadd173..28a40c34f9 100644 --- a/parcels/_index_search.py +++ b/parcels/_index_search.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING import numpy as np +from datetime import datetime from parcels._typing import ( GridIndexingType, @@ -21,73 +22,80 @@ if TYPE_CHECKING: from .field import Field - from .grid import Grid + #from .grid import Grid -def _search_time_index(grid: Grid, time: float, allow_time_extrapolation=True): +def _search_time_index(field: Field, time: datetime , allow_time_extrapolation=True): """Find and return the index and relative coordinate in the time array associated with a given time. + Parameters + ---------- + field: Field + + time: datetime + This is the amount of time, in seconds (time_delta), in unix epoch Note that we normalize to either the first or the last index if the sampled value is outside the time value range. 
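+
+    Example (a worked case, under the conventions above): for field times
+    [t0, t1, t2] and a query time halfway between t0 and t1, this returns
+    (tau=0.5, ti=0), so a sample can then be formed as
+    (1 - tau) * data[ti] + tau * data[ti + 1].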
""" - if not allow_time_extrapolation and (time < grid.time[0] or time > grid.time[-1]): + if not allow_time_extrapolation and (time < field.data.time[0] or time > field.data.time[-1]): _raise_time_extrapolation_error(time, field=None) - time_index = grid.time <= time + time_index = field.data.time <= time if time_index.all(): # If given time > last known field time, use # the last field frame without interpolation - ti = len(grid.time) - 1 + ti = len(field.data.time) - 1 + elif np.logical_not(time_index).all(): # If given time < any time in the field, use # the first field frame without interpolation ti = 0 else: ti = int(time_index.argmin() - 1) if time_index.any() else 0 - if grid.tdim == 1: + if len(field.data.time)== 1: tau = 0 - elif ti == len(grid.time) - 1: + elif ti == len(field.data.time) - 1: tau = 1 else: - tau = (time - grid.time[ti]) / (grid.time[ti + 1] - grid.time[ti]) if grid.time[ti] != grid.time[ti + 1] else 0 + tau = (time - field.data.time[ti]).total_seconds() / (field.data.time[ti + 1] - field.data.time[ti]).total_seconds() if field.data.time[ti] != field.data.time[ti + 1] else 0 return tau, ti -def search_indices_vertical_z(grid: Grid, gridindexingtype: GridIndexingType, z: float): - if grid.depth[-1] > grid.depth[0]: - if z < grid.depth[0]: +def search_indices_vertical_z(depth, gridindexingtype: GridIndexingType, z: float): + if depth[-1] > depth[0]: + if z < depth[0]: # Since MOM5 is indexed at cell bottom, allow z at depth[0] - dz where dz = (depth[1] - depth[0]) - if gridindexingtype == "mom5" and z > 2 * grid.depth[0] - grid.depth[1]: - return (-1, z / grid.depth[0]) + if gridindexingtype == "mom5" and z > 2 * depth[0] - depth[1]: + return (-1, z / depth[0]) else: _raise_field_out_of_bound_surface_error(z, None, None) - elif z > grid.depth[-1]: + elif z > depth[-1]: # In case of CROCO, allow particles in last (uppermost) layer using depth[-1] if gridindexingtype in ["croco"] and z < 0: return (-2, 1) _raise_field_out_of_bound_error(z, None, None) - depth_indices = grid.depth < z - if z >= grid.depth[-1]: - zi = len(grid.depth) - 2 + depth_indices = depth < z + if z >= depth[-1]: + zi = len(depth) - 2 else: - zi = depth_indices.argmin() - 1 if z > grid.depth[0] else 0 + zi = depth_indices.argmin() - 1 if z > depth[0] else 0 else: - if z > grid.depth[0]: + if z > depth[0]: _raise_field_out_of_bound_surface_error(z, None, None) - elif z < grid.depth[-1]: + elif z < depth[-1]: _raise_field_out_of_bound_error(z, None, None) - depth_indices = grid.depth > z - if z <= grid.depth[-1]: - zi = len(grid.depth) - 2 + depth_indices = depth > z + if z <= depth[-1]: + zi = len(depth) - 2 else: - zi = depth_indices.argmin() - 1 if z < grid.depth[0] else 0 - zeta = (z - grid.depth[zi]) / (grid.depth[zi + 1] - grid.depth[zi]) + zi = depth_indices.argmin() - 1 if z < depth[0] else 0 + zeta = (z - depth[zi]) / (depth[zi + 1] - depth[zi]) while zeta > 1: zi += 1 - zeta = (z - grid.depth[zi]) / (grid.depth[zi + 1] - grid.depth[zi]) + zeta = (z - depth[zi]) / (depth[zi + 1] - depth[zi]) while zeta < 0: zi -= 1 - zeta = (z - grid.depth[zi]) / (grid.depth[zi + 1] - grid.depth[zi]) + zeta = (z - depth[zi]) / (depth[zi + 1] - depth[zi]) return (zi, zeta) diff --git a/parcels/field.py b/parcels/field.py index 761cee0f82..988055cd68 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -6,6 +6,9 @@ import dask.array as da import numpy as np import xarray as xr +import uxarray as ux +from uxarray.grid.neighbors import _barycentric_coordinates + import parcels.tools.interpolation_utils as 
i_u from parcels._compat import add_note @@ -38,19 +41,17 @@ _raise_field_out_of_bound_error, ) from parcels.tools.warnings import FieldSetWarning +import inspect +from typing import Callable, Union from ._index_search import _search_indices_curvilinear, _search_indices_rectilinear, _search_time_index -from .fieldfilebuffer import ( - NetcdfFileBuffer, -) -from .grid import Grid, GridType if TYPE_CHECKING: import numpy.typing as npt - from parcels.fieldset import FieldSet + from parcels.xfieldset import XFieldSet -__all__ = ["Field", "VectorField"] +__all__ = ["XField", "XVectorField"] def _isParticle(key): @@ -74,518 +75,322 @@ def _deal_with_errors(error, key, vector_type: VectorType): return (0, 0) else: return 0 + +class XField: + """The XField class that holds scalar field data. + The `XField` object is a wrapper around a xarray.DataArray or uxarray.UxDataArray object. + Additionally, it holds a dynamic Callable procedure that is used to interpolate the field data. + During initialization, the user can supply a custom interpolation method that is used to interpolate the field data, + so long as the interpolation method has the correct signature. + + Notes + ----- + + + The xarray.DataArray or uxarray.UxDataArray object contains the field data and metadata. + * dims: (time, [nz1 | nz], [face_lat | node_lat | edge_lat], [face_lon | node_lon | edge_lon]) + * attrs: (location, mesh, mesh_type) + + When using a xarray.DataArray object, + * The xarray.DataArray object must have the "location" and "mesh" attributes set. + * The "location" attribute must be set to one of the following to define which pairing of points a field is associated with. + * "node" + * "face" + * "x_edge" + * "y_edge" + * For an A-Grid, the "location" attribute must be set to / is assumed to be "node" (node_lat,node_lon). + * For a C-Grid, the "location" setting for a field has the following interpretation: + * "node" ~> the field is associated with the vorticity points (node_lat, node_lon) + * "face" ~> the field is associated with the tracer points (face_lat, face_lon) + * "x_edge" ~> the field is associated with the u-velocity points (face_lat, node_lon) + * "y_edge" ~> the field is associated with the v-velocity points (node_lat, face_lon) + + When using a uxarray.UxDataArray object, + * The uxarray.UxDataArray.UxGrid object must have the "Conventions" attribute set to "UGRID-1.0" + and the uxarray.UxDataArray object must comply with the UGRID conventions. + See https://ugrid-conventions.github.io/ugrid-conventions/ for more information. - -def _croco_from_z_to_sigma_scipy(fieldset, time, z, y, x, particle): - """Calculate local sigma level of the particle, by linearly interpolating the - scaling function that maps sigma to depth (using local ocean depth H, - sea-surface Zeta and stretching parameters Cs_w and hc). 
- See also https://croco-ocean.gitlabpages.inria.fr/croco_doc/model/model.grid.html#vertical-grid-parameters """ - h = fieldset.H.eval(time, 0, y, x, particle=particle, applyConversion=False) - zeta = fieldset.Zeta.eval(time, 0, y, x, particle=particle, applyConversion=False) - sigma_levels = fieldset.U.grid.depth - z0 = fieldset.hc * sigma_levels + (h - fieldset.hc) * fieldset.Cs_w.data[0, :, 0, 0] - zvec = z0 + zeta * (1 + (z0 / h)) - zinds = zvec <= z - if z >= zvec[-1]: - zi = len(zvec) - 2 - else: - zi = zinds.argmin() - 1 if z >= zvec[0] else 0 - - return sigma_levels[zi] + (z - zvec[zi]) * (sigma_levels[zi + 1] - sigma_levels[zi]) / (zvec[zi + 1] - zvec[zi]) - - -class Field: - """Class that encapsulates access to field data. - - Parameters - ---------- - name : str - Name of the field - data : np.ndarray - 2D, 3D or 4D numpy array of field data with shape [ydim, xdim], [zdim, ydim, xdim], [tdim, ydim, xdim] or [tdim, zdim, ydim, xdim], - lon : np.ndarray or list - Longitude coordinates (numpy vector or array) of the field (only if grid is None) - lat : np.ndarray or list - Latitude coordinates (numpy vector or array) of the field (only if grid is None) - depth : np.ndarray or list - Depth coordinates (numpy vector or array) of the field (only if grid is None) - time : np.ndarray - Time coordinates (numpy vector) of the field (only if grid is None) - mesh : str - String indicating the type of mesh coordinates and - units used during velocity interpolation: (only if grid is None) - - 1. spherical: Lat and lon in degree, with a - correction for zonal velocity U near the poles. - 2. flat (default): No conversion, lat/lon are assumed to be in m. - grid : parcels.grid.Grid - :class:`parcels.grid.Grid` object containing all the lon, lat depth, time - mesh and time_origin information. Can be constructed from any of the Grid objects - fieldtype : str - Type of Field to be used for UnitConverter (either 'U', 'V', 'Kh_zonal', 'Kh_meridional' or None) - time_origin : parcels.tools.converters.TimeConverter - Time origin of the time axis (only if grid is None) - interp_method : str - Method for interpolation. Options are 'linear' (default), 'nearest', - 'linear_invdist_land_tracer', 'cgrid_velocity', 'cgrid_tracer' and 'bgrid_velocity' - allow_time_extrapolation : bool - boolean whether to allow for extrapolation in time - (i.e. 
beyond the last available time snapshot) - """ + @staticmethod + def _interp_template( + self, + ti: int, + ei: int, + bcoords: np.ndarray, + t: Union[np.float32,np.float64], + z: Union[np.float32,np.float64], + y: Union[np.float32,np.float64], + x: Union[np.float32,np.float64] + )-> Union[np.float32,np.float64]: + """ Template function used for the signature check of the lateral interpolation methods.""" + return 0.0 + + def _validate_interp_function(self, func: Callable): + """Ensures that the function has the correct signature.""" + expected_params = ["ti", "ei", "bcoords", "t", "z", "y", "x"] + expected_return_types = (np.float32,np.float64) + + sig = inspect.signature(func) + params = list(sig.parameters.keys()) + + # Check the parameter names and count + if params != expected_params: + raise TypeError( + f"Function must have parameters {expected_params}, but got {params}" + ) - allow_time_extrapolation: bool + # Check return annotation if present + return_annotation = sig.return_annotation + if return_annotation not in (inspect.Signature.empty, *expected_return_types): + raise TypeError( + f"Function must return a float, but got {return_annotation}" + ) def __init__( self, - name: str | tuple[str, str], - data, - lon=None, - lat=None, - depth=None, - time=None, - grid=None, - mesh: Mesh = "flat", - fieldtype=None, - time_origin: TimeConverter | None = None, - interp_method: InterpMethod = "linear", + name: str, + data: xr.DataArray | ux.UxDataArray, + mesh_type: Mesh = "flat", + interp_method: Callable | None = None, allow_time_extrapolation: bool | None = None, - gridindexingtype: GridIndexingType = "nemo", - data_full_zdim=None, ): - if not isinstance(name, tuple): - self.name = name - else: - self.name = name[0] + + self.name = name self.data = data - if grid: - self._grid = grid + + self._validate_dataarray(data) + + self._parent_mesh = data.attributes["mesh"] + self._mesh_type = mesh_type + self._location = data.attributes["location"] + + # Set the vertical location + if "nz1" in data.dims: + self._vertical_location = "center" + elif "nz" in data.dims: + self._vertical_location = "face" + + # Setting the interpolation method dynamically + if interp_method is None: + self._interp_method = self._interp_template # Default to method that returns 0 always else: - if (time is not None) and isinstance(time[0], np.datetime64): - time_origin = TimeConverter(time[0]) - time = np.array([time_origin.reltime(t) for t in time]) - else: - time_origin = TimeConverter(0) - self._grid = Grid.create_grid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh) - self.igrid = -1 - self.fieldtype = self.name if fieldtype is None else fieldtype - if self.grid.mesh == "flat" or (self.fieldtype not in unitconverters_map.keys()): + self._validate_interp_function(interp_method) + self._interp_method = interp_method + + self.igrid = -1 # Default the grid index to -1 + + if self._mesh_type == "flat" or (self.name not in unitconverters_map.keys()): self.units = UnitConverter() - elif self.grid.mesh == "spherical": - self.units = unitconverters_map[self.fieldtype] - else: - raise ValueError("Unsupported mesh type. 
Choose either: 'spherical' or 'flat'") - if isinstance(interp_method, dict): - if self.name in interp_method: - self.interp_method = interp_method[self.name] - else: - raise RuntimeError(f"interp_method is a dictionary but {name} is not in it") + elif self._mesh_type == "spherical": + self.units = unitconverters_map[self.name] else: - self.interp_method = interp_method - assert_valid_gridindexingtype(gridindexingtype) - self._gridindexingtype = gridindexingtype - if self.interp_method in ["bgrid_velocity", "bgrid_w_velocity", "bgrid_tracer"] and self.grid._gtype in [ - GridType.RectilinearSGrid, - GridType.CurvilinearSGrid, - ]: - warnings.warn( - "General s-levels are not supported in B-grid. RectilinearSGrid and CurvilinearSGrid can still be used to deal with shaved cells, but the levels must be horizontal.", - FieldSetWarning, - stacklevel=2, - ) - - self.fieldset: FieldSet | None = None + raise ValueError("Unsupported mesh type in data array attributes. Choose either: 'spherical' or 'flat'") + if allow_time_extrapolation is None: - self.allow_time_extrapolation = True if len(self.grid.time) == 1 else False + self.allow_time_extrapolation = True if len(self.data["time"]) == 1 else False else: self.allow_time_extrapolation = allow_time_extrapolation - self.data = self._reshape(self.data) - - # Hack around the fact that NaN and ridiculously large values - # propagate in SciPy's interpolators - self.data[np.isnan(self.data)] = 0.0 - - # data_full_zdim is the vertical dimension of the complete field data, ignoring the indices. - # (data_full_zdim = grid.zdim if no indices are used, for A- and C-grids and for some B-grids). It is used for the B-grid, - # since some datasets do not provide the deeper level of data (which is ignored by the interpolation). 
- self.data_full_zdim = data_full_zdim + if type(self.data) is ux.UxDataArray: + self._spatialhash = self.data.uxgrid.get_spatial_hash() + else: + self._spatialhash = None - def __repr__(self) -> str: + def __repr__(self): return field_repr(self) @property def grid(self): - return self._grid - - @property - def lon(self): - """Lon defined on the Grid object""" - return self.grid.lon - + if type(self.data) is ux.UxDataArray: + return self.data.uxgrid + else: + return self.data # To do : need to decide on what to return for xarray.DataArray objects + @property def lat(self): - """Lat defined on the Grid object""" - return self.grid.lat + if type(self.data) is ux.UxDataArray: + if self._location == "node": + return self.data.uxgrid.node_lat + elif self._location == "face": + return self.data.uxgrid.face_lat + elif self._location == "edge": + return self.data.uxgrid.edge_lat + else: + if self._location == "node": + return self.data.node_lat + elif self._location == "face": + return self.data.face_lat + elif self._location == "x_edge": + return self.data.face_lat + elif self._location == "y_edge": + return self.data.node_lat @property - def depth(self): - """Depth defined on the Grid object""" - return self.grid.depth + def lon(self): + if type(self.data) is ux.UxDataArray: + if self._location == "node": + return self.data.uxgrid.node_lon + elif self._location == "face": + return self.data.uxgrid.face_lon + elif self._location == "edge": + return self.data.uxgrid.edge_lon + else: + if self._location == "node": + return self.data.node_lon + elif self._location == "face": + return self.data.face_lon + elif self._location == "x_edge": + return self.data.node_lon + elif self._location == "y_edge": + return self.data.face_lon @property - def interp_method(self): - return self._interp_method - - @interp_method.setter - def interp_method(self, value): - assert_valid_interp_method(value) - self._interp_method = value + def depth(self): + if type(self.data) is ux.UxDataArray: + if self._vertical_location == "center": + return self.data.uxgrid.nz1 + elif self._vertical_location == "face": + return self.data.uxgrid.nz + else: + if self._vertical_location == "center": + return self.data.nz1 + elif self._vertical_location == "face": + return self.data.nz @property - def gridindexingtype(self): - return self._gridindexingtype - - @classmethod - def _get_dim_filenames(cls, filenames, dim): - if isinstance(filenames, str) or not isinstance(filenames, collections.abc.Iterable): - return [filenames] - elif isinstance(filenames, dict): - assert dim in filenames.keys(), "filename dimension keys must be lon, lat, depth or data" - filename = filenames[dim] - if isinstance(filename, str): - return [filename] - else: - return filename + def nx(self): + if type(self.data) is xr.DataArray: + if "face_lon" in self.data.dims: + return self.data.sizes["face_lon"] + elif "node_lon" in self.data.dims: + return self.data.sizes["node_lon"] else: - return filenames - - @staticmethod - def _collect_time(data_filenames, dimensions, indices): - time = [] - for fname in data_filenames: - with NetcdfFileBuffer(fname, dimensions, indices) as filebuffer: - ftime = filebuffer.time - time.append(ftime) - time = np.concatenate(time).ravel() - if time.size == 1 and time[0] is None: - time[0] = 0 - time_origin = TimeConverter(time[0]) - time = time_origin.reltime(time) - - return time, time_origin - - @classmethod - def from_netcdf( - cls, - filenames, - variable, - dimensions, - grid=None, - mesh: Mesh = "spherical", - 
allow_time_extrapolation: bool | None = None, - **kwargs, - ) -> "Field": - """Create field from netCDF file. - - Parameters - ---------- - filenames : list of str or dict - list of filenames to read for the field. filenames can be a list ``[files]`` or - a dictionary ``{dim:[files]}`` (if lon, lat, depth and/or data not stored in same files as data) - In the latter case, time values are in filenames[data] - variable : dict, tuple of str or str - Dict or tuple mapping field name to variable name in the NetCDF file. - dimensions : dict - Dictionary mapping variable names for the relevant dimensions in the NetCDF file - mesh : - String indicating the type of mesh coordinates and - units used during velocity interpolation: - - 1. spherical (default): Lat and lon in degree, with a - correction for zonal velocity U near the poles. - 2. flat: No conversion, lat/lon are assumed to be in m. - allow_time_extrapolation : bool - boolean whether to allow for extrapolation in time - (i.e. beyond the last available time snapshot) - Default is False if dimensions includes time, else True - gridindexingtype : str - The type of gridindexing. Either 'nemo' (default), 'mitgcm', 'mom5', 'pop', or 'croco' are supported. - See also the Grid indexing documentation on oceanparcels.org - grid : - (Default value = None) - **kwargs : - Keyword arguments passed to the :class:`Field` constructor. - """ - if isinstance(variable, str): # for backward compatibility with Parcels < 2.0.0 - variable = (variable, variable) - elif isinstance(variable, dict): - assert ( - len(variable) == 1 - ), "Field.from_netcdf() supports only one variable at a time. Use FieldSet.from_netcdf() for multiple variables." - variable = tuple(variable.items())[0] - assert ( - len(variable) == 2 - ), "The variable tuple must have length 2. 
Use FieldSet.from_netcdf() for multiple variables" - - data_filenames = cls._get_dim_filenames(filenames, "data") - lonlat_filename = cls._get_dim_filenames(filenames, "lon") - if isinstance(filenames, dict): - assert len(lonlat_filename) == 1 - if lonlat_filename != cls._get_dim_filenames(filenames, "lat"): - raise NotImplementedError( - "longitude and latitude dimensions are currently processed together from one single file" - ) - lonlat_filename = lonlat_filename[0] - if "depth" in dimensions: - depth_filename = cls._get_dim_filenames(filenames, "depth") - if isinstance(filenames, dict) and len(depth_filename) != 1: - raise NotImplementedError("Vertically adaptive meshes not implemented for from_netcdf()") - depth_filename = depth_filename[0] - - gridindexingtype = kwargs.get("gridindexingtype", "nemo") - - indices: dict[str, npt.NDArray] = {} - - interp_method: InterpMethod = kwargs.pop("interp_method", "linear") - if type(interp_method) is dict: - if variable[0] in interp_method: - interp_method = interp_method[variable[0]] - else: - raise RuntimeError(f"interp_method is a dictionary but {variable[0]} is not in it") - interp_method = cast(InterpMethodOption, interp_method) - - if "lon" in dimensions and "lat" in dimensions: - with NetcdfFileBuffer( - lonlat_filename, - dimensions, - indices, - gridindexingtype=gridindexingtype, - ) as filebuffer: - lat, lon = filebuffer.latlon - indices = filebuffer.indices - # Check if parcels_mesh has been explicitly set in file - if "parcels_mesh" in filebuffer.dataset.attrs: - mesh = filebuffer.dataset.attrs["parcels_mesh"] + return 0 # To do : Discuss what we want to return for uxdataarray obj + @property + def ny(self): + if type(self.data) is xr.DataArray: + if "face_lat" in self.data.dims: + return self.data.sizes["face_lat"] + elif "node_lat" in self.data.dims: + return self.data.sizes["node_lat"] else: - lon = 0 - lat = 0 - mesh = "flat" - - if "depth" in dimensions: - with NetcdfFileBuffer( - depth_filename, - dimensions, - indices, - interp_method=interp_method, - gridindexingtype=gridindexingtype, - ) as filebuffer: - filebuffer.name = variable[1] - depth = filebuffer.depth + return 0 # To do : Discuss what we want to return for uxdataarray obj + @property + def n_face(self): + if type(self.data) is ux.uxDataArray: + return self.data.uxgrid.n_face else: - indices["depth"] = np.array([0]) - depth = np.zeros(1) - - if len(data_filenames) > 1 and "time" not in dimensions: - raise RuntimeError("Multiple files given but no time dimension specified") - - if grid is None: - # Concatenate time variable to determine overall dimension - # across multiple files - if "time" in dimensions: - time, time_origin = cls._collect_time(data_filenames, dimensions, indices) - grid = Grid.create_grid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh) - else: # e.g. for the CROCO CS_w field, see https://github.com/OceanParcels/Parcels/issues/1831 - grid = Grid.create_grid(lon, lat, depth, np.array([0.0]), time_origin=TimeConverter(0.0), mesh=mesh) - data_filenames = [data_filenames[0]] - - if "time" in indices: - warnings.warn( - "time dimension in indices is not necessary anymore. 
It is then ignored.", FieldSetWarning, stacklevel=2 + return 0 # To do : Discuss what we want to return for dataarray obj + + @property + def interp_method(self): + return self._interp_method + + @interp_method.setter + def interp_method(self, method: Callable): + self._validate_interp_function(method) + self._interp_method = method + + # @property + # def gridindexingtype(self): + # return self._gridindexingtype + + def _get_ux_barycentric_coordinates(self, y, x, fi): + "Checks if a point is inside a given face id. Used for unstructured grids." + + # Check if particle is in the same face, otherwise search again. + n_nodes = self.data.uxgrid.n_nodes_per_face[fi].to_numpy() + node_ids = self.data.uxgrid.face_node_connectivity[fi, 0:n_nodes] + nodes = np.column_stack( + ( + np.deg2rad(self.data.uxgrid.node_lon[node_ids].to_numpy()), + np.deg2rad(self.data.uxgrid.node_lat[node_ids].to_numpy()), ) + ) - with NetcdfFileBuffer( # type: ignore[operator] - data_filenames, - dimensions, - indices, - interp_method=interp_method, - ) as filebuffer: - # If Field.from_netcdf is called directly, it may not have a 'data' dimension - # In that case, assume that 'name' is the data dimension - filebuffer.name = variable[1] - buffer_data = filebuffer.data - if len(buffer_data.shape) == 4: - errormessage = ( - f"Field {filebuffer.name} expecting a data shape of [tdim={grid.tdim}, zdim={grid.zdim}, " - f"ydim={grid.ydim}, xdim={grid.xdim }] " - f"but got shape {buffer_data.shape}." + coord = np.deg2rad([x, y]) + bcoord = np.asarray(_barycentric_coordinates(nodes, coord)) + err = abs(np.dot(bcoord, nodes[:, 0]) - coord[0]) + abs( + np.dot(bcoord, nodes[:, 1]) - coord[1] ) - assert buffer_data.shape[0] == grid.tdim, errormessage - assert buffer_data.shape[2] == grid.ydim, errormessage - assert buffer_data.shape[3] == grid.xdim, errormessage - - data = buffer_data + return bcoord, err - if allow_time_extrapolation is None: - allow_time_extrapolation = False if "time" in dimensions else True - - return cls( - variable, - data, - grid=grid, - allow_time_extrapolation=allow_time_extrapolation, - interp_method=interp_method, - **kwargs, - ) - @classmethod - def from_xarray( - cls, - da: xr.DataArray, - name: str, - dimensions, - mesh: Mesh = "spherical", - allow_time_extrapolation: bool | None = None, - **kwargs, - ): - """Create field from xarray Variable. + def _search_indices_unstructured(self, z, y, x, ei=None, search2D=False): - Parameters - ---------- - da : xr.DataArray - Xarray DataArray - name : str - Name of the Field - dimensions : dict - Dictionary mapping variable names for the relevant dimensions in the DataArray - mesh : str - String indicating the type of mesh coordinates and - units used during velocity interpolation: - - 1. spherical (default): Lat and lon in degree, with a - correction for zonal velocity U near the poles. - 2. flat: No conversion, lat/lon are assumed to be in m. - allow_time_extrapolation : bool - boolean whether to allow for extrapolation in time - (i.e. beyond the last available time snapshot) - Default is False if dimensions includes time, else True - **kwargs : - Keyword arguments passed to the :class:`Field` constructor. 
- """ - data = da.data - interp_method = kwargs.pop("interp_method", "linear") - - time = da[dimensions["time"]].values if "time" in dimensions else np.array([0.0]) - depth = da[dimensions["depth"]].values if "depth" in dimensions else np.array([0]) - lon = da[dimensions["lon"]].values - lat = da[dimensions["lat"]].values - - time_origin = TimeConverter(time[0]) - time = time_origin.reltime(time) # type: ignore[assignment] - - grid = Grid.create_grid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh) - return cls( - name, - data, - grid=grid, - allow_time_extrapolation=allow_time_extrapolation, - interp_method=interp_method, - **kwargs, - ) - - def _reshape(self, data): - # Ensure that field data is the right data type - if not isinstance(data, (np.ndarray)): - data = np.array(data) - - if self.grid.xdim == 1 or self.grid.ydim == 1: - data = np.squeeze(data) # First remove all length-1 dimensions in data, so that we can add them below - if self.grid.xdim == 1 and len(data.shape) < 4: - data = np.expand_dims(data, axis=-1) - if self.grid.ydim == 1 and len(data.shape) < 4: - data = np.expand_dims(data, axis=-2) - if self.grid.tdim == 1: - if len(data.shape) < 4: - data = data.reshape(sum(((1,), data.shape), ())) - if self.grid.zdim == 1: - if len(data.shape) == 4: - data = data.reshape(sum(((data.shape[0],), data.shape[2:]), ())) - if len(data.shape) == 4: - errormessage = f"Field {self.name} expecting a data shape of [tdim, zdim, ydim, xdim]. " - assert data.shape[0] == self.grid.tdim, errormessage - assert data.shape[2] == self.grid.ydim, errormessage - assert data.shape[3] == self.grid.xdim, errormessage - if self.gridindexingtype == "pop": - assert data.shape[1] == self.grid.zdim or data.shape[1] == self.grid.zdim - 1, errormessage - else: - assert data.shape[1] == self.grid.zdim, errormessage + tol = 1e-10 + if ei is None: + # Search using global search + fi, bcoords = self._spatialhash.query([[x,y]]) # Get the face id for the particle + if fi == -1: + raise FieldOutOfBoundError(z, y, x) # To do : how to handle lost particle ?? + # To do : Do the vertical grid search + # zi = self._vertical_search(z) + zi = 0 # For now + return bcoords, self.ravel_index(zi, 0, fi) else: - assert data.shape == ( - self.grid.tdim, - self.grid.ydim, - self.grid.xdim, - ), f"Field {self.name} expecting a data shape of [tdim, ydim, xdim]. 
" + zi, fi = self.unravel_index(ei[self.igrid]) # Get the z, and face index of the particle + # Search using nearest neighbors + bcoords, err = self._get_ux_barycentric_coordinates(y, x, fi) - return data - - def _search_indices(self, time, z, y, x, particle=None, search2D=False): - tau, ti = _search_time_index(self.grid, time, self.allow_time_extrapolation) - - if self.grid._gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]: + if ((bcoords >= 0).all()) and ((bcoords <= 1.0).all()) and err < tol: + # To do: Do the vertical grid search + return bcoords, ei + else: + # In this case we need to search the neighbors + for neighbor in self.data.uxgrid.face_face_connectivity[fi,:]: + bcoords, err = self._get_ux_barycentric_coordinates(y, x, neighbor) + if ((bcoords >= 0).all()) and ((bcoords <= 1.0).all()) and err < tol: + # To do: Do the vertical grid search + return bcoords, self.ravel_index(zi, 0, neighbor) + + # If we reach this point, we do a global search as a last ditch effort the particle is out of bounds + fi, bcoords = self._spatialhash.query([[x,y]]) # Get the face id for the particle + if fi == -1: + raise FieldOutOfBoundError(z, y, x) # To do : how to handle lost particle ?? + + + def _search_indices_structured(self, z, y, x, ei=None, search2D=False): + + # To do, determine grid type from xarray.coords shapes + # Rectilinear uses 1-D array for lat and lon + # Curvilinear uses 2-D array for lat and lon + if self.grid._gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]: (zeta, eta, xsi, zi, yi, xi) = _search_indices_rectilinear( - self, time, z, y, x, ti, particle=particle, search2D=search2D + self, z, y, x,particle=particle, search2D=search2D ) else: (zeta, eta, xsi, zi, yi, xi) = _search_indices_curvilinear( - self, time, z, y, x, ti, particle=particle, search2D=search2D + self, z, y, x, ei=ei, search2D=search2D ) - return (tau, zeta, eta, xsi, ti, zi, yi, xi) - def _interpolator2D(self, time, z, y, x, particle=None): - """Impelement 2D interpolation with coordinate transformations as seen in Delandmeter and Van Sebille (2019), 10.5194/gmd-12-3571-2019..""" - try: - f = get_2d_interpolator_registry()[self.interp_method] - except KeyError: - if self.interp_method == "cgrid_velocity": - raise RuntimeError( - f"{self.name} is a scalar field. cgrid_velocity interpolation method should be used for vector fields (e.g. 
FieldSet.UV)" - ) - else: - raise RuntimeError(self.interp_method + " is not implemented for 2D grids") + # To do : Calcualte barycentric coordinates from zeta, eta, xsi - (tau, _, eta, xsi, ti, _, yi, xi) = self._search_indices(time, z, y, x, particle=particle) + return (zeta, eta, xsi, zi, yi, xi) + + def _search_indices(self, time, z, y, x, ei=None, search2D=False): - ctx = InterpolationContext2D(self.data, tau, eta, xsi, ti, yi, xi) - return f(ctx) - - def _interpolator3D(self, time, z, y, x, particle=None): - """Impelement 3D interpolation with coordinate transformations as seen in Delandmeter and Van Sebille (2019), 10.5194/gmd-12-3571-2019..""" - try: - f = get_3d_interpolator_registry()[self.interp_method] - except KeyError: - raise RuntimeError(self.interp_method + " is not implemented for 3D grids") + tau, ti = self._search_time_index(time) # To do : Need to implement this method - (tau, zeta, eta, xsi, ti, zi, yi, xi) = self._search_indices(time, z, y, x, particle=particle) - - ctx = InterpolationContext3D(self.data, tau, zeta, eta, xsi, ti, zi, yi, xi, self.gridindexingtype) - return f(ctx) + if type(self.data) is ux.UxDataArray: + bcoords, ei = self._search_indices_unstructured(z, y, x, ei=ei, search2D=search2D) # To do : Need to implement this method + else: + bcoords, ei = self._search_indices_structured(z, y, x, ei=ei, search2D=search2D) # To do : Need to implement this method + return bcoords, ei, ti + + def _interpolate(self, time, z, y, x, ei): - def _interpolate(self, time, z, y, x, particle=None): - """Interpolate spatial field values.""" try: - if self.grid.zdim == 1: - val = self._interpolator2D(time, z, y, x, particle=particle) - else: - val = self._interpolator3D(time, z, y, x, particle=particle) + bcoords, _ei, ti = self._search_indices(time, z, y, x, ei=ei) + val = self._interp_method(ti, _ei, bcoords, time, z, y, x) if np.isnan(val): # Detect Out-of-bounds sampling and raise exception _raise_field_out_of_bound_error(z, y, x) else: return val - + except (FieldSamplingError, FieldOutOfBoundError, FieldOutOfBoundSurfaceError) as e: e = add_note(e, f"Error interpolating field '{self.name}'.", before=True) raise e @@ -607,24 +412,26 @@ def __getitem__(self, key): return self.eval(*key) except tuple(AllParcelsErrorCodes.keys()) as error: return _deal_with_errors(error, key, vector_type=None) - - def eval(self, time, z, y, x, particle=None, applyConversion=True): + + def eval(self, time, z, y, x, ei=None, applyConversion=True): """Interpolate field values in space and time. We interpolate linearly in time and apply implicit unit conversion to the result. Note that we defer to scipy.interpolate to perform spatial interpolation. """ - if self.gridindexingtype == "croco" and self not in [self.fieldset.H, self.fieldset.Zeta]: - z = _croco_from_z_to_sigma_scipy(self.fieldset, time, z, y, x, particle=particle) + if ei is None: + _ei = None + else: + _ei = ei[self.igrid] - value = self._interpolate(time, z, y, x, particle=particle) + value = self._interpolate(time, z, y, x, ei=_ei) if applyConversion: return self.units.to_target(value, z, y, x) else: return value - + def _rescale_and_set_minmax(self, data): data[np.isnan(data)] = 0 return data @@ -639,17 +446,21 @@ def ravel_index(self, zi, yi, xi): yi : int y index xi : int - x index + x index. 
When using an unstructured grid, this is the face index (fi) Returns ------- int flat index """ - return xi + self.grid.xdim * (yi + self.grid.ydim * zi) + if type(self.data) is xr.DataArray: + return xi + self.nx * (yi + self.ny * zi) + else: + return xi + self.n_face*zi def unravel_index(self, ei): """Return the zi, yi, xi indices for a given flat index. + Only used when working with fields on a structured grid. Parameters ---------- @@ -665,48 +476,136 @@ def unravel_index(self, ei): xi : int The x index. """ - _ei = ei[self.igrid] - zi = _ei // (self.grid.xdim * self.grid.ydim) - _ei = _ei % (self.grid.xdim * self.grid.ydim) - yi = _ei // self.grid.xdim - xi = _ei % self.grid.xdim - return zi, yi, xi - - -class VectorField: - """Class VectorField stores 2 or 3 fields which defines together a vector field. - This enables to interpolate them as one single vector field in the kernels. - - Parameters - ---------- - name : str - Name of the vector field - U : parcels.field.Field - field defining the zonal component - V : parcels.field.Field - field defining the meridional component - W : parcels.field.Field - field defining the vertical component (default: None) - """ + if type(self.data) is xr.DataArray: + _ei = ei[self.igrid] + zi = _ei // (self.nx * self.ny) + _ei = _ei % (self.nx * self.ny) + yi = _ei // self.nx + xi = _ei % self.nx + return zi, yi, xi + else: + _ei = ei[self.igrid] + zi = _ei // self.n_face + fi = _ei % self.n_face + return zi, fi + + def _validate_dataarray(self): + """ Verifies that all the required attributes are present in the xarray.DataArray or + uxarray.UxDataArray object.""" + + # Validate dimensions + if not( "nz1" in self.data.dims or "nz" in self.data.dims ): + raise ValueError( + f"Field {self.name} is missing a 'nz1' or 'nz' dimension in the field's metadata. " + "This attribute is required for xarray.DataArray objects." + ) + + if not( "time" in self.data.dims ): + raise ValueError( + f"Field {self.name} is missing a 'time' dimension in the field's metadata. " + "This attribute is required for xarray.DataArray objects." + ) + + # Validate attributes + required_keys = ["location", "mesh"] + for key in required_keys: + if key not in self.data.attrs.keys(): + raise ValueError( + f"Field {self.name} is missing a '{key}' attribute in the field's metadata. " + "This attribute is required for xarray.DataArray objects." + ) + + if type(self.data) is ux.UxDataArray: + self._validate_uxgrid() + + + def _validate_uxgrid(self): + """ Verifies that all the required attributes are present in the uxarray.UxDataArray.UxGrid object.""" + + if "Conventions" not in self.data.uxgrid.attrs.keys(): + raise ValueError( + f"Field {self.name} is missing a 'Conventions' attribute in the field's metadata. " + "This attribute is required for uxarray.UxDataArray objects." + ) + if self.data.uxgrid.attrs["Conventions"] != "UGRID-1.0": + raise ValueError( + f"Field {self.name} has a 'Conventions' attribute that is not 'UGRID-1.0'. " + "This attribute is required for uxarray.UxDataArray objects." + "See https://ugrid-conventions.github.io/ugrid-conventions/ for more information." 
+class XVectorField:
+    """XVectorField class that holds vector field data needed to execute particles."""
+
+    @staticmethod
+    def _vector_interp_template(
+        ti: int,
+        ei: int,
+        bcoords: np.ndarray,
+        t: Union[np.float32, np.float64],
+        z: Union[np.float32, np.float64],
+        y: Union[np.float32, np.float64],
+        x: Union[np.float32, np.float64]
+    ) -> Union[np.float32, np.float64]:
+        """Template function used for the signature check of the vector interpolation methods."""
+        return 0.0
+
+    def _validate_vector_interp_function(self, func: Callable):
+        """Ensures that the function has the correct signature."""
+        expected_params = ["ti", "ei", "bcoords", "t", "z", "y", "x"]
+        expected_return_types = (np.float32, np.float64)
+
+        sig = inspect.signature(func)
+        params = list(sig.parameters.keys())
+
+        # Check the parameter names and count
+        if params != expected_params:
+            raise TypeError(
+                f"Function must have parameters {expected_params}, but got {params}"
+            )

-    def __init__(self, name: str, U: Field, V: Field, W: Field | None = None):
+        # Check return annotation if present
+        return_annotation = sig.return_annotation
+        if return_annotation not in (inspect.Signature.empty, *expected_return_types):
+            raise TypeError(
+                f"Function must return a float, but got {return_annotation}"
+            )
+
+    def __init__(
+        self,
+        name: str,
+        U: XField,
+        V: XField,
+        W: XField | None = None,
+        vector_interp_method: Callable | None = None
+    ):
+
         self.name = name
         self.U = U
         self.V = V
         self.W = W
-        if self.U.gridindexingtype == "croco" and self.W:
-            self.vector_type: VectorType = "3DSigma"
-        elif self.W:
+
+        if self.W:
             self.vector_type = "3D"
         else:
             self.vector_type = "2D"
-        self.gridindexingtype = U.gridindexingtype
-        if self.U.interp_method == "cgrid_velocity":
-            assert self.V.interp_method == "cgrid_velocity", "Interpolation methods of U and V are not the same."
-            assert self._check_grid_dimensions(U.grid, V.grid), "Dimensions of U and V are not the same."
-        if W is not None and self.U.gridindexingtype != "croco":
-            assert W.interp_method == "cgrid_velocity", "Interpolation methods of U and W are not the same."
-            assert self._check_grid_dimensions(U.grid, W.grid), "Dimensions of U and W are not the same."
+
+        # Setting the interpolation method dynamically
+        if vector_interp_method is None:
+            self._vector_interp_method = None
+        else:
+            self._validate_vector_interp_function(vector_interp_method)
+            self._vector_interp_method = vector_interp_method

     def __repr__(self):
         return f"""<{type(self).__name__}>
@@ -715,472 +614,62 @@ def __repr__(self):
     V: {default_repr(self.V)}
     W: {default_repr(self.W)}"""

-    @staticmethod
-    def _check_grid_dimensions(grid1, grid2):
-        return (
-            np.allclose(grid1.lon, grid2.lon)
-            and np.allclose(grid1.lat, grid2.lat)
-            and np.allclose(grid1.depth, grid2.depth)
-            and np.allclose(grid1.time, grid2.time)
-        )
-
-    def c_grid_interpolation2D(self, time, z, y, x, particle=None, applyConversion=True):
-        grid = self.U.grid
-        (tau, _, eta, xsi, ti, zi, yi, xi) = self.U._search_indices(time, z, y, x, particle=particle)
-
-        if grid._gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]:
-            px = np.array([grid.lon[xi], grid.lon[xi + 1], grid.lon[xi + 1], grid.lon[xi]])
-            py = np.array([grid.lat[yi], grid.lat[yi], grid.lat[yi + 1], grid.lat[yi + 1]])
-        else:
-            px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]])
-            py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]])
-
-        if grid.mesh == "spherical":
-            px[0] = px[0] + 360 if px[0] < x - 225 else px[0]
-            px[0] = px[0] - 360 if px[0] > x + 225 else px[0]
-            px[1:] = np.where(px[1:] - px[0] > 180, px[1:] - 360, px[1:])
-            px[1:] = np.where(-px[1:] + px[0] > 180, px[1:] + 360, px[1:])
-        xx = (1 - xsi) * (1 - eta) * px[0] + xsi * (1 - eta) * px[1] + xsi * eta * px[2] + (1 - xsi) * eta * px[3]
-        assert abs(xx - x) < 1e-4
-        c1 = i_u._geodetic_distance(py[0], py[1], px[0], px[1], grid.mesh, np.dot(i_u.phi2D_lin(0.0, xsi), py))
-        c2 = i_u._geodetic_distance(py[1], py[2], px[1], px[2], grid.mesh, np.dot(i_u.phi2D_lin(eta, 1.0), py))
-        c3 = i_u._geodetic_distance(py[2], py[3], px[2], px[3], grid.mesh, np.dot(i_u.phi2D_lin(1.0, xsi), py))
-        c4 = i_u._geodetic_distance(py[3], py[0], px[3], px[0], grid.mesh, np.dot(i_u.phi2D_lin(eta, 0.0), py))
-
-        def _calc_UV(ti, yi, xi):
-            if grid.zdim == 1:
-                if self.gridindexingtype == "nemo":
-                    U0 = self.U.data[ti, yi + 1, xi] * c4
-                    U1 = self.U.data[ti, yi + 1, xi + 1] * c2
-                    V0 = self.V.data[ti, yi, xi + 1] * c1
-                    V1 = self.V.data[ti, yi + 1, xi + 1] * c3
-                elif self.gridindexingtype in ["mitgcm", "croco"]:
-                    U0 = self.U.data[ti, yi, xi] * c4
-                    U1 = self.U.data[ti, yi, xi + 1] * c2
-                    V0 = self.V.data[ti, yi, xi] * c1
-                    V1 = self.V.data[ti, yi + 1, xi] * c3
-            else:
-                if self.gridindexingtype == "nemo":
-                    U0 = self.U.data[ti, zi, yi + 1, xi] * c4
-                    U1 = self.U.data[ti, zi, yi + 1, xi + 1] * c2
-                    V0 = self.V.data[ti, zi, yi, xi + 1] * c1
-                    V1 = self.V.data[ti, zi, yi + 1, xi + 1] * c3
-                elif self.gridindexingtype in ["mitgcm", "croco"]:
-                    U0 = self.U.data[ti, zi, yi, xi] * c4
-                    U1 = self.U.data[ti, zi, yi, xi + 1] * c2
-                    V0 = self.V.data[ti, zi, yi, xi] * c1
-                    V1 = self.V.data[ti, zi, yi + 1, xi] * c3
-            U = (1 - xsi) * U0 + xsi * U1
-            V = (1 - eta) * V0 + eta * V1
-            rad = np.pi / 180.0
-            deg2m = 1852 * 60.0
-            if applyConversion:
-                meshJac = (deg2m * deg2m * math.cos(rad * y)) if grid.mesh == "spherical" else 1
+    @property
+    def vector_interp_method(self):
+        return self._vector_interp_method
+
+    @vector_interp_method.setter
+    def vector_interp_method(self, method: Callable):
+        self._validate_vector_interp_function(method)
+        self._vector_interp_method = method
+
+    # @staticmethod
+    # To do : def _check_grid_dimensions(grid1, grid2):
+    #     return (
+    #         np.allclose(grid1.lon, grid2.lon)
+    #         and np.allclose(grid1.lat, grid2.lat)
+    #         and np.allclose(grid1.depth, grid2.depth)
+    #         and np.allclose(grid1.time, grid2.time)
+    #     )
+    def _interpolate(self, time, z, y, x, ei):
+
+        bcoords, _ei, ti = self._search_indices(time, z, y, x, ei=ei)
+
+        if self._vector_interp_method is None:
+            u = self.U.eval(time, z, y, x, _ei, applyConversion=False)
+            v = self.V.eval(time, z, y, x, _ei, applyConversion=False)
+            if "3D" in self.vector_type:
+                w = self.W.eval(time, z, y, x, _ei, applyConversion=False)
+                return (u, v, w)
             else:
-                meshJac = deg2m if grid.mesh == "spherical" else 1
-
-            jac = i_u._compute_jacobian_determinant(py, px, eta, xsi) * meshJac
-
-            u = (
-                (-(1 - eta) * U - (1 - xsi) * V) * px[0]
-                + ((1 - eta) * U - xsi * V) * px[1]
-                + (eta * U + xsi * V) * px[2]
-                + (-eta * U + (1 - xsi) * V) * px[3]
-            ) / jac
-            v = (
-                (-(1 - eta) * U - (1 - xsi) * V) * py[0]
-                + ((1 - eta) * U - xsi * V) * py[1]
-                + (eta * U + xsi * V) * py[2]
-                + (-eta * U + (1 - xsi) * V) * py[3]
-            ) / jac
-            if isinstance(u, da.core.Array):
-                u = u.compute()
-                v = v.compute()
-            return (u, v)
-
-        u, v = _calc_UV(ti, yi, xi)
-        if should_calculate_next_ti(ti, tau, self.U.grid.tdim):
-            ut1, vt1 = _calc_UV(ti + 1, yi, xi)
-            u = (1 - tau) * u + tau * ut1
-            v = (1 - tau) * v + tau * vt1
-        return (u, v)
-
-    def c_grid_interpolation3D_full(self, time, z, y, x, particle=None):
-        grid = self.U.grid
-        (tau, zeta, eta, xsi, ti, zi, yi, xi) = self.U._search_indices(time, z, y, x, particle=particle)
-
-        if grid._gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]:
-            px = np.array([grid.lon[xi], grid.lon[xi + 1], grid.lon[xi + 1], grid.lon[xi]])
-            py = np.array([grid.lat[yi], grid.lat[yi], grid.lat[yi + 1], grid.lat[yi + 1]])
+                return (u, v, 0)
         else:
-            px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]])
-            py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]])
-
-        if grid.mesh == "spherical":
-            px[0] = px[0] + 360 if px[0] < x - 225 else px[0]
-            px[0] = px[0] - 360 if px[0] > x + 225 else px[0]
-            px[1:] = np.where(px[1:] - px[0] > 180, px[1:] - 360, px[1:])
-            px[1:] = np.where(-px[1:] + px[0] > 180, px[1:] + 360, px[1:])
-            xx = (1 - xsi) * (1 - eta) * px[0] + xsi * (1 - eta) * px[1] + xsi * eta * px[2] + (1 - xsi) * eta * px[3]
-            assert abs(xx - x) < 1e-4
-
-        px = np.concatenate((px, px))
-        py = np.concatenate((py, py))
-        if grid._z4d:
-            pz = np.array(
-                [
-                    grid.depth[0, zi, yi, xi],
-                    grid.depth[0, zi, yi, xi + 1],
-                    grid.depth[0, zi, yi + 1, xi + 1],
-                    grid.depth[0, zi, yi + 1, xi],
-                    grid.depth[0, zi + 1, yi, xi],
-                    grid.depth[0, zi + 1, yi, xi + 1],
-                    grid.depth[0, zi + 1, yi + 1, xi + 1],
-                    grid.depth[0, zi + 1, yi + 1, xi],
-                ]
-            )
-        else:
-            pz = np.array(
-                [
-                    grid.depth[zi, yi, xi],
-                    grid.depth[zi, yi, xi + 1],
-                    grid.depth[zi, yi + 1, xi + 1],
-                    grid.depth[zi, yi + 1, xi],
-                    grid.depth[zi + 1, yi, xi],
-                    grid.depth[zi + 1, yi, xi + 1],
-                    grid.depth[zi + 1, yi + 1, xi + 1],
-                    grid.depth[zi + 1, yi + 1, xi],
-                ]
-            )
+            (u, v, w) = self._vector_interp_method(ti, _ei, bcoords, time, z, y, x)
+            return (u, v, w)

-        u0 = self.U.data[ti, zi, yi + 1, xi]
-        u1 = self.U.data[ti, zi, yi + 1, xi + 1]
-        v0 = self.V.data[ti, zi, yi, xi + 1]
-        v1 = self.V.data[ti, zi, yi + 1, xi + 1]
-        w0 = self.W.data[ti, zi, yi + 1, xi + 1]
-        w1 = self.W.data[ti, zi + 1, yi + 1, xi + 1]
-
-        if should_calculate_next_ti(ti, tau, self.U.grid.tdim):
-            u0 = (1
- tau) * u0 + tau * self.U.data[ti + 1, zi, yi + 1, xi] - u1 = (1 - tau) * u1 + tau * self.U.data[ti + 1, zi, yi + 1, xi + 1] - v0 = (1 - tau) * v0 + tau * self.V.data[ti + 1, zi, yi, xi + 1] - v1 = (1 - tau) * v1 + tau * self.V.data[ti + 1, zi, yi + 1, xi + 1] - w0 = (1 - tau) * w0 + tau * self.W.data[ti + 1, zi, yi + 1, xi + 1] - w1 = (1 - tau) * w1 + tau * self.W.data[ti + 1, zi + 1, yi + 1, xi + 1] - - U0 = u0 * i_u.jacobian3D_lin_face(pz, py, px, zeta, eta, 0, "zonal", grid.mesh) - U1 = u1 * i_u.jacobian3D_lin_face(pz, py, px, zeta, eta, 1, "zonal", grid.mesh) - V0 = v0 * i_u.jacobian3D_lin_face(pz, py, px, zeta, 0, xsi, "meridional", grid.mesh) - V1 = v1 * i_u.jacobian3D_lin_face(pz, py, px, zeta, 1, xsi, "meridional", grid.mesh) - W0 = w0 * i_u.jacobian3D_lin_face(pz, py, px, 0, eta, xsi, "vertical", grid.mesh) - W1 = w1 * i_u.jacobian3D_lin_face(pz, py, px, 1, eta, xsi, "vertical", grid.mesh) - - # Computing fluxes in half left hexahedron -> flux_u05 - xx = [ - px[0], - (px[0] + px[1]) / 2, - (px[2] + px[3]) / 2, - px[3], - px[4], - (px[4] + px[5]) / 2, - (px[6] + px[7]) / 2, - px[7], - ] - yy = [ - py[0], - (py[0] + py[1]) / 2, - (py[2] + py[3]) / 2, - py[3], - py[4], - (py[4] + py[5]) / 2, - (py[6] + py[7]) / 2, - py[7], - ] - zz = [ - pz[0], - (pz[0] + pz[1]) / 2, - (pz[2] + pz[3]) / 2, - pz[3], - pz[4], - (pz[4] + pz[5]) / 2, - (pz[6] + pz[7]) / 2, - pz[7], - ] - flux_u0 = u0 * i_u.jacobian3D_lin_face(zz, yy, xx, 0.5, 0.5, 0, "zonal", grid.mesh) - flux_v0_halfx = v0 * i_u.jacobian3D_lin_face(zz, yy, xx, 0.5, 0, 0.5, "meridional", grid.mesh) - flux_v1_halfx = v1 * i_u.jacobian3D_lin_face(zz, yy, xx, 0.5, 1, 0.5, "meridional", grid.mesh) - flux_w0_halfx = w0 * i_u.jacobian3D_lin_face(zz, yy, xx, 0, 0.5, 0.5, "vertical", grid.mesh) - flux_w1_halfx = w1 * i_u.jacobian3D_lin_face(zz, yy, xx, 1, 0.5, 0.5, "vertical", grid.mesh) - flux_u05 = flux_u0 + flux_v0_halfx - flux_v1_halfx + flux_w0_halfx - flux_w1_halfx - - # Computing fluxes in half front hexahedron -> flux_v05 - xx = [ - px[0], - px[1], - (px[1] + px[2]) / 2, - (px[0] + px[3]) / 2, - px[4], - px[5], - (px[5] + px[6]) / 2, - (px[4] + px[7]) / 2, - ] - yy = [ - py[0], - py[1], - (py[1] + py[2]) / 2, - (py[0] + py[3]) / 2, - py[4], - py[5], - (py[5] + py[6]) / 2, - (py[4] + py[7]) / 2, - ] - zz = [ - pz[0], - pz[1], - (pz[1] + pz[2]) / 2, - (pz[0] + pz[3]) / 2, - pz[4], - pz[5], - (pz[5] + pz[6]) / 2, - (pz[4] + pz[7]) / 2, - ] - flux_u0_halfy = u0 * i_u.jacobian3D_lin_face(zz, yy, xx, 0.5, 0.5, 0, "zonal", grid.mesh) - flux_u1_halfy = u1 * i_u.jacobian3D_lin_face(zz, yy, xx, 0.5, 0.5, 1, "zonal", grid.mesh) - flux_v0 = v0 * i_u.jacobian3D_lin_face(zz, yy, xx, 0.5, 0, 0.5, "meridional", grid.mesh) - flux_w0_halfy = w0 * i_u.jacobian3D_lin_face(zz, yy, xx, 0, 0.5, 0.5, "vertical", grid.mesh) - flux_w1_halfy = w1 * i_u.jacobian3D_lin_face(zz, yy, xx, 1, 0.5, 0.5, "vertical", grid.mesh) - flux_v05 = flux_u0_halfy - flux_u1_halfy + flux_v0 + flux_w0_halfy - flux_w1_halfy - - # Computing fluxes in half lower hexahedron -> flux_w05 - xx = [ - px[0], - px[1], - px[2], - px[3], - (px[0] + px[4]) / 2, - (px[1] + px[5]) / 2, - (px[2] + px[6]) / 2, - (px[3] + px[7]) / 2, - ] - yy = [ - py[0], - py[1], - py[2], - py[3], - (py[0] + py[4]) / 2, - (py[1] + py[5]) / 2, - (py[2] + py[6]) / 2, - (py[3] + py[7]) / 2, - ] - zz = [ - pz[0], - pz[1], - pz[2], - pz[3], - (pz[0] + pz[4]) / 2, - (pz[1] + pz[5]) / 2, - (pz[2] + pz[6]) / 2, - (pz[3] + pz[7]) / 2, - ] - flux_u0_halfz = u0 * i_u.jacobian3D_lin_face(zz, yy, xx, 0.5, 0.5, 0, "zonal", 
grid.mesh) - flux_u1_halfz = u1 * i_u.jacobian3D_lin_face(zz, yy, xx, 0.5, 0.5, 1, "zonal", grid.mesh) - flux_v0_halfz = v0 * i_u.jacobian3D_lin_face(zz, yy, xx, 0.5, 0, 0.5, "meridional", grid.mesh) - flux_v1_halfz = v1 * i_u.jacobian3D_lin_face(zz, yy, xx, 0.5, 1, 0.5, "meridional", grid.mesh) - flux_w0 = w0 * i_u.jacobian3D_lin_face(zz, yy, xx, 0, 0.5, 0.5, "vertical", grid.mesh) - flux_w05 = flux_u0_halfz - flux_u1_halfz + flux_v0_halfz - flux_v1_halfz + flux_w0 - - surf_u05 = i_u.jacobian3D_lin_face(pz, py, px, 0.5, 0.5, 0.5, "zonal", grid.mesh) - jac_u05 = i_u.jacobian3D_lin_face(pz, py, px, zeta, eta, 0.5, "zonal", grid.mesh) - U05 = flux_u05 / surf_u05 * jac_u05 - - surf_v05 = i_u.jacobian3D_lin_face(pz, py, px, 0.5, 0.5, 0.5, "meridional", grid.mesh) - jac_v05 = i_u.jacobian3D_lin_face(pz, py, px, zeta, 0.5, xsi, "meridional", grid.mesh) - V05 = flux_v05 / surf_v05 * jac_v05 - - surf_w05 = i_u.jacobian3D_lin_face(pz, py, px, 0.5, 0.5, 0.5, "vertical", grid.mesh) - jac_w05 = i_u.jacobian3D_lin_face(pz, py, px, 0.5, eta, xsi, "vertical", grid.mesh) - W05 = flux_w05 / surf_w05 * jac_w05 - - jac = i_u.jacobian3D_lin(pz, py, px, zeta, eta, xsi, grid.mesh) - dxsidt = i_u.interpolate(i_u.phi1D_quad, [U0, U05, U1], xsi) / jac - detadt = i_u.interpolate(i_u.phi1D_quad, [V0, V05, V1], eta) / jac - dzetdt = i_u.interpolate(i_u.phi1D_quad, [W0, W05, W1], zeta) / jac - - dphidxsi, dphideta, dphidzet = i_u.dphidxsi3D_lin(zeta, eta, xsi) - - u = np.dot(dphidxsi, px) * dxsidt + np.dot(dphideta, px) * detadt + np.dot(dphidzet, px) * dzetdt - v = np.dot(dphidxsi, py) * dxsidt + np.dot(dphideta, py) * detadt + np.dot(dphidzet, py) * dzetdt - w = np.dot(dphidxsi, pz) * dxsidt + np.dot(dphideta, pz) * detadt + np.dot(dphidzet, pz) * dzetdt - - if isinstance(u, da.core.Array): - u = u.compute() - v = v.compute() - w = w.compute() - return (u, v, w) + + def eval(self, time, z, y, x, ei=None, applyConversion=True): - def c_grid_interpolation3D(self, ti, z, y, x, time, particle=None, applyConversion=True): - """Perform C grid interpolation in 3D. :: - - +---+---+---+ - | |V1 | | - +---+---+---+ - |U0 | |U1 | - +---+---+---+ - | |V0 | | - +---+---+---+ - - The interpolation is done in the following by - interpolating linearly U depending on the longitude coordinate and - interpolating linearly V depending on the latitude coordinate. - Curvilinear grids are treated properly, since the element is projected to a rectilinear parent element. 
- """ - if self.U.grid._gtype in [GridType.RectilinearSGrid, GridType.CurvilinearSGrid]: - (u, v, w) = self.c_grid_interpolation3D_full(time, z, y, x, particle=particle) + if ei is None: + _ei = 0 else: - if self.gridindexingtype == "croco": - z = _croco_from_z_to_sigma_scipy(self.fieldset, time, z, y, x, particle=particle) - (u, v) = self.c_grid_interpolation2D(time, z, y, x, particle=particle) - w = self.W.eval(time, z, y, x, particle=particle, applyConversion=False) - if applyConversion: + _ei = ei[self.igrid] + + (u,v,w) = self._interpolate(time, z, y, x, _ei) + + if applyConversion: + u = self.U.units.to_target(u, z, y, x) + v = self.V.units.to_target(v, z, y, x) + if "3D" in self.vector_type: w = self.W.units.to_target(w, z, y, x) + return (u, v, w) - def _is_land2D(self, di, yi, xi): - if self.U.data.ndim == 3: - if di < np.shape(self.U.data)[0]: - return np.isclose(self.U.data[di, yi, xi], 0.0) and np.isclose(self.V.data[di, yi, xi], 0.0) - else: - return True - else: - if di < self.U.grid.zdim and yi < np.shape(self.U.data)[-2] and xi < np.shape(self.U.data)[-1]: - return np.isclose(self.U.data[0, di, yi, xi], 0.0) and np.isclose(self.V.data[0, di, yi, xi], 0.0) - else: - return True - - def slip_interpolation(self, time, z, y, x, particle=None, applyConversion=True): - (_, zeta, eta, xsi, ti, zi, yi, xi) = self.U._search_indices(time, z, y, x, particle=particle) - di = ti if self.U.grid.zdim == 1 else zi # general third dimension - - f_u, f_v, f_w = 1, 1, 1 - if ( - self._is_land2D(di, yi, xi) - and self._is_land2D(di, yi, xi + 1) - and self._is_land2D(di + 1, yi, xi) - and self._is_land2D(di + 1, yi, xi + 1) - and eta > 0 - ): - if self.U.interp_method == "partialslip": - f_u = f_u * (0.5 + 0.5 * eta) / eta - if self.vector_type == "3D": - f_w = f_w * (0.5 + 0.5 * eta) / eta - elif self.U.interp_method == "freeslip": - f_u = f_u / eta - if self.vector_type == "3D": - f_w = f_w / eta - if ( - self._is_land2D(di, yi + 1, xi) - and self._is_land2D(di, yi + 1, xi + 1) - and self._is_land2D(di + 1, yi + 1, xi) - and self._is_land2D(di + 1, yi + 1, xi + 1) - and eta < 1 - ): - if self.U.interp_method == "partialslip": - f_u = f_u * (1 - 0.5 * eta) / (1 - eta) - if self.vector_type == "3D": - f_w = f_w * (1 - 0.5 * eta) / (1 - eta) - elif self.U.interp_method == "freeslip": - f_u = f_u / (1 - eta) - if self.vector_type == "3D": - f_w = f_w / (1 - eta) - if ( - self._is_land2D(di, yi, xi) - and self._is_land2D(di, yi + 1, xi) - and self._is_land2D(di + 1, yi, xi) - and self._is_land2D(di + 1, yi + 1, xi) - and xsi > 0 - ): - if self.U.interp_method == "partialslip": - f_v = f_v * (0.5 + 0.5 * xsi) / xsi - if self.vector_type == "3D": - f_w = f_w * (0.5 + 0.5 * xsi) / xsi - elif self.U.interp_method == "freeslip": - f_v = f_v / xsi - if self.vector_type == "3D": - f_w = f_w / xsi - if ( - self._is_land2D(di, yi, xi + 1) - and self._is_land2D(di, yi + 1, xi + 1) - and self._is_land2D(di + 1, yi, xi + 1) - and self._is_land2D(di + 1, yi + 1, xi + 1) - and xsi < 1 - ): - if self.U.interp_method == "partialslip": - f_v = f_v * (1 - 0.5 * xsi) / (1 - xsi) - if self.vector_type == "3D": - f_w = f_w * (1 - 0.5 * xsi) / (1 - xsi) - elif self.U.interp_method == "freeslip": - f_v = f_v / (1 - xsi) - if self.vector_type == "3D": - f_w = f_w / (1 - xsi) - if self.U.grid.zdim > 1: - if ( - self._is_land2D(di, yi, xi) - and self._is_land2D(di, yi, xi + 1) - and self._is_land2D(di, yi + 1, xi) - and self._is_land2D(di, yi + 1, xi + 1) - and zeta > 0 - ): - if self.U.interp_method == "partialslip": - 
f_u = f_u * (0.5 + 0.5 * zeta) / zeta
-                    f_v = f_v * (0.5 + 0.5 * zeta) / zeta
-                elif self.U.interp_method == "freeslip":
-                    f_u = f_u / zeta
-                    f_v = f_v / zeta
-            if (
-                self._is_land2D(di + 1, yi, xi)
-                and self._is_land2D(di + 1, yi, xi + 1)
-                and self._is_land2D(di + 1, yi + 1, xi)
-                and self._is_land2D(di + 1, yi + 1, xi + 1)
-                and zeta < 1
-            ):
-                if self.U.interp_method == "partialslip":
-                    f_u = f_u * (1 - 0.5 * zeta) / (1 - zeta)
-                    f_v = f_v * (1 - 0.5 * zeta) / (1 - zeta)
-                elif self.U.interp_method == "freeslip":
-                    f_u = f_u / (1 - zeta)
-                    f_v = f_v / (1 - zeta)
-
-        u = f_u * self.U.eval(time, z, y, x, particle, applyConversion=applyConversion)
-        v = f_v * self.V.eval(time, z, y, x, particle, applyConversion=applyConversion)
-        if self.vector_type == "3D":
-            w = f_w * self.W.eval(time, z, y, x, particle, applyConversion=applyConversion)
-            return u, v, w
-        else:
-            return u, v
-
-    def eval(self, time, z, y, x, particle=None, applyConversion=True):
-        if self.U.interp_method in ["partialslip", "freeslip"]:
-            return self.slip_interpolation(time, z, y, x, particle=particle, applyConversion=applyConversion)
-
-        if self.U.interp_method not in ["cgrid_velocity"]:
-            u = self.U.eval(time, z, y, x, particle=particle, applyConversion=False)
-            v = self.V.eval(time, z, y, x, particle=particle, applyConversion=False)
-            if applyConversion:
-                u = self.U.units.to_target(u, z, y, x)
-                v = self.V.units.to_target(v, z, y, x)
-        elif self.U.interp_method == "cgrid_velocity":
-            (u, v) = self.c_grid_interpolation2D(time, z, y, x, particle=particle, applyConversion=applyConversion)
-        if "3D" in self.vector_type:
-            w = self.W.eval(time, z, y, x, particle=particle, applyConversion=applyConversion)
-            return (u, v, w)
-        else:
-            return (u, v)
-
-    def __getitem__(self, key):
+    def __getitem__(self, key):
         try:
             if _isParticle(key):
-                return self.eval(key.time, key.depth, key.lat, key.lon, key)
+                return self.eval(key.time, key.depth, key.lat, key.lon, key.ei)
             else:
                 return self.eval(*key)
         except tuple(AllParcelsErrorCodes.keys()) as error:
-            return _deal_with_errors(error, key, vector_type=self.vector_type)
+            return _deal_with_errors(error, key, vector_type=self.vector_type)
\ No newline at end of file
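Before moving on to the fieldset changes, a sketch of how a user-supplied vector interpolator is expected to plug in (the function body is a placeholder; only the parameter list (ti, ei, bcoords, t, z, y, x) is prescribed, by `_validate_vector_interp_function`, and U, V, W are assumed to be existing XFields)::

    def my_vector_interp(ti, ei, bcoords, t, z, y, x):
        # Hypothetical: reconstruct (u, v, w) at time level ti on element ei,
        # weighting the element's data with the barycentric coordinates bcoords.
        return (0.0, 0.0, 0.0)

    UVW = XVectorField("UVW", U, V, W)
    UVW.vector_interp_method = my_vector_interp  # the setter runs the signature check
    u, v, w = UVW.eval(0.0, 5.0, 52.0, 4.0)      # equivalently UVW[0.0, 5.0, 52.0, 4.0]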
diff --git a/parcels/fieldset.py b/parcels/fieldset.py
index cc8707c20a..3b62ad9a0d 100644
--- a/parcels/fieldset.py
+++ b/parcels/fieldset.py
@@ -7,145 +7,72 @@
 import numpy as np

 from parcels._typing import GridIndexingType, InterpMethodOption, Mesh
-from parcels.field import Field, VectorField
-from parcels.grid import Grid
-from parcels.gridset import GridSet
+from parcels.xfield import XField, XVectorField
 from parcels.particlefile import ParticleFile
-from parcels.tools._helpers import fieldset_repr
+from parcels.tools._helpers import fieldset_repr, default_repr
 from parcels.tools.converters import TimeConverter
 from parcels.tools.warnings import FieldSetWarning

-__all__ = ["FieldSet"]
+import xarray as xr
+import uxarray as ux

+__all__ = ["XFieldSet"]

-class FieldSet:
-    """FieldSet class that holds hydrodynamic data needed to execute particles.
+class XFieldSet:
+    """XFieldSet class that holds hydrodynamic data needed to execute particles.
+
     Parameters
     ----------
-    U : parcels.field.Field
-        Field object for zonal velocity component
-    V : parcels.field.Field
-        Field object for meridional velocity component
-    fields : dict mapping str to Field
-        Additional fields to include in the FieldSet. These fields can be used
-        in custom kernels.
+    ds : xarray.Dataset | uxarray.UxDataset
+        xarray.Dataset and/or uxarray.UxDataset objects containing the field data.
+
+    Notes
+    -----
+    The `ds` object is an xarray.Dataset or a uxarray.UxDataset.
+    In xarray terminology, the (Ux)Dataset holds multiple (Ux)DataArray objects.
+    Each (Ux)DataArray is a single "field" with its own dimensions and
+    coordinates within the (Ux)Dataset.
+
+    A (Ux)Dataset is associated with a single mesh, which can have multiple
+    types of "points" (multiple "grids") (e.g. for UxDatasets, these are "face_lon",
+    "face_lat", "node_lon", "node_lat", "edge_lon", "edge_lat"). Each (Ux)DataArray is
+    registered to a specific set of points on the mesh.
+
+    For UxDataset objects, each `UxDataArray.attrs` dictionary contains the
+    metadata needed to determine which set of points a field is registered to
+    and which parent model the field is associated with. Parcels uses this
+    metadata during execution for interpolation. Each `UxDataArray.attrs`
+    dictionary must have:
+    * a "location" key set to "face", "node", or "edge", defining which pairing of points the field is associated with.
+    * a "mesh" key defining which parent model the fields are associated with (e.g. "fesom_mesh", "icon_mesh").
     """

-    def __init__(self, U: Field | None, V: Field | None, fields=None):
-        self.gridset = GridSet()
-        self._completed: bool = False
-        self._particlefile: ParticleFile | None = None
-        if U:
-            self.add_field(U, "U")
-            # see #1663 for type-ignore reason
-            self.time_origin = self.U.grid.time_origin if isinstance(self.U, Field) else self.U[0].grid.time_origin  # type: ignore
-        if V:
-            self.add_field(V, "V")
+    def __init__(self, ds: xr.Dataset | ux.UxDataset):
+        self.ds = ds

-        # Add additional fields as attributes
-        if fields:
-            for name, field in fields.items():
-                self.add_field(field, name)
+        self._completed: bool = False
+        # Create pointers to each (Ux)DataArray
+        for field in self.ds.data_vars:
+            setattr(self, field, XField(field, self.ds[field]))

         self._add_UVfield()

     def __repr__(self):
         return fieldset_repr(self)
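A minimal sketch of the intended construction path (variable names, sizes, and the mesh name are illustrative; the lowercase "u"/"v" names matter because `_add_UVfield` looks for those attributes)::

    import numpy as np
    import pandas as pd
    import xarray as xr

    dims = ("time", "nz1", "node_lat", "node_lon")
    attrs = {"location": "node", "mesh": "example_mesh"}
    coords = {"time": pd.to_datetime(["2000-01-01"])}
    zeros = np.zeros((1, 1, 4, 6), dtype=np.float32)

    ds = xr.Dataset(
        {
            "u": xr.DataArray(zeros, dims=dims, coords=coords, attrs=attrs),
            "v": xr.DataArray(zeros, dims=dims, coords=coords, attrs=attrs),
        }
    )
    fieldset = XFieldSet(ds)  # exposes fieldset.u and fieldset.v, and builds fieldset.UV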
-
-    @property
-    def particlefile(self):
-        return self._particlefile
-
-    @staticmethod
-    def checkvaliddimensionsdict(dims):
-        for d in dims:
-            if d not in ["lon", "lat", "depth", "time"]:
-                raise NameError(f"{d} is not a valid key in the dimensions dictionary")
-
-    @classmethod
-    def from_data(
-        cls,
-        data,
-        dimensions,
-        mesh: Mesh = "spherical",
-        allow_time_extrapolation: bool | None = None,
-        **kwargs,
-    ):
-        """Initialise FieldSet object from raw data.
-
-        Parameters
-        ----------
-        data :
-            Dictionary mapping field names to numpy arrays.
-            Note that at least a 'U' and 'V' numpy array need to be given, and that
-            the built-in Advection kernels assume that U and V are in m/s
-            Data shape is either [ydim, xdim], [zdim, ydim, xdim], [tdim, ydim, xdim] or [tdim, zdim, ydim, xdim],
-        dimensions : dict
-            Dictionary mapping field dimensions (lon,
-            lat, depth, time) to numpy arrays.
-            Note that dimensions can also be a dictionary of dictionaries if
-            dimension names are different for each variable
-            (e.g. dimensions['U'], dimensions['V'], etc).
-        mesh : str
-            String indicating the type of mesh coordinates and
-            units used during velocity interpolation, see also `this tutorial <../examples/tutorial_unitconverters.ipynb>`__:
-
-            1. spherical (default): Lat and lon in degree, with a
-               correction for zonal velocity U near the poles.
-            2. flat: No conversion, lat/lon are assumed to be in m.
-        allow_time_extrapolation : bool
-            boolean whether to allow for extrapolation
-            (i.e. beyond the last available time snapshot)
-            Default is False if dimensions includes time, else True
-        **kwargs :
-            Keyword arguments passed to the :class:`Field` constructor.
-
-        Examples
-        --------
-        For usage examples see the following tutorials:
-
-        * `Analytical advection <../examples/tutorial_analyticaladvection.ipynb>`__
-
-        * `Diffusion <../examples/tutorial_diffusion.ipynb>`__
-
-        * `Interpolation <../examples/tutorial_interpolation.ipynb>`__
-
-        * `Unit converters <../examples/tutorial_unitconverters.ipynb>`__
-        """
-        fields = {}
-        for name, datafld in data.items():
-            # Use dimensions[name] if dimensions is a dict of dicts
-            dims = dimensions[name] if name in dimensions else dimensions
-            cls.checkvaliddimensionsdict(dims)
-
-            if allow_time_extrapolation is None:
-                allow_time_extrapolation = False if "time" in dims else True
-
-            lon = dims["lon"]
-            lat = dims["lat"]
-            depth = np.zeros(1, dtype=np.float32) if "depth" not in dims else dims["depth"]
-            time = np.zeros(1, dtype=np.float64) if "time" not in dims else dims["time"]
-            time = np.array(time)
-            if isinstance(time[0], np.datetime64):
-                time_origin = TimeConverter(time[0])
-                time = np.array([time_origin.reltime(t) for t in time])
-            else:
-                time_origin = kwargs.pop("time_origin", TimeConverter(0))
-            grid = Grid.create_grid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh)
-
-            fields[name] = Field(
-                name,
-                datafld,
-                grid=grid,
-                allow_time_extrapolation=allow_time_extrapolation,
-                **kwargs,
-            )
-        u = fields.pop("U", None)
-        v = fields.pop("V", None)
-        return cls(u, v, fields=fields)
-
-    def add_field(self, field: Field, name: str | None = None):
+
+    # @property
+    # def particlefile(self):
+    #     return self._particlefile
+
+    # @staticmethod
+    # def checkvaliddimensionsdict(dims):
+    #     for d in dims:
+    #         if d not in ["lon", "lat", "depth", "time"]:
+    #             raise NameError(f"{d} is not a valid key in the dimensions dictionary")
+
+    def add_field(self, field: XField, name: str | None = None):
         """Add a :class:`parcels.field.Field` object to the FieldSet.

         Parameters
@@ -174,7 +101,6 @@ def add_field(self, field: Field, name: str | None = None):
             raise RuntimeError(f"FieldSet already has a Field with name '{name}'")
         else:
             setattr(self, name, field)
-        self.gridset.add_grid(field)

     def add_constant_field(self, name: str, value: float, mesh: Mesh = "flat"):
         """Wrapper function to add a Field that is constant in space,
@@ -194,69 +120,68 @@ def add_constant_field(self, name: str, value: float, mesh: Mesh = "flat"):
               correction for zonal velocity U near the poles.
            2. flat: No conversion, lat/lon are assumed to be in m.
         """
-        self.add_field(Field(name, value, lon=0, lat=0, mesh=mesh))
+        import pandas as pd
+
+        time = pd.to_datetime(["2000-01-01"])
+        values = np.zeros((1, 1, 1, 1), dtype=np.float32) + value
+        data = xr.DataArray(
+            data=values,
+            name=name,
+            dims=("time", "nz1", "node_lat", "node_lon"),
+            coords=[time, [0], [0], [0]],
+            attrs=dict(
+                description="null",
+                units="null",
+                location="node",
+                mesh="constant",
+                mesh_type=mesh,
+            ))
+        self.add_field(
+            XField(
+                name,
+                data,
+                interp_method=None,  # To do : Need to define an interpolation method for constants
+                allow_time_extrapolation=True
+            )
+        )
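From the caller's side the wrapper keeps its old one-liner feel; the difference from `add_constant` is that the value becomes a sampleable field rather than a plain attribute (sampling it meaningfully still depends on the constant interpolation method flagged as To do above)::

    fieldset.add_constant_field("Kh_zonal", 10.0, mesh="flat")  # fieldset.Kh_zonal is an XField
    fieldset.add_constant("hc", 250.0)                          # plain attribute, fieldset.hc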
     def add_vector_field(self, vfield):
         """Add a :class:`parcels.field.VectorField` object to the FieldSet.

         Parameters
         ----------
-        vfield : parcels.VectorField
-            class:`parcels.field.VectorField` object to be added
+        vfield : parcels.XVectorField
+            class:`parcels.xfield.XVectorField` object to be added
         """
         setattr(self, vfield.name, vfield)
         for v in vfield.__dict__.values():
-            if isinstance(v, Field) and (v not in self.get_fields()):
+            if isinstance(v, XField) and (v not in self.get_fields()):
                 self.add_field(v)

+    def get_fields(self) -> list[XField | XVectorField]:
+        """Returns a list of all the :class:`parcels.xfield.XField` and :class:`parcels.xfield.XVectorField`
+        objects associated with this FieldSet.
+        """
+        fields = []
+        for v in self.__dict__.values():
+            if type(v) in [XField, XVectorField]:
+                if v not in fields:
+                    fields.append(v)
+        return fields
+
     def _add_UVfield(self):
-        if not hasattr(self, "UV") and hasattr(self, "U") and hasattr(self, "V"):
-            self.add_vector_field(VectorField("UV", self.U, self.V))
-        if not hasattr(self, "UVW") and hasattr(self, "W"):
-            self.add_vector_field(VectorField("UVW", self.U, self.V, self.W))
+        if not hasattr(self, "UV") and hasattr(self, "u") and hasattr(self, "v"):
+            self.add_vector_field(XVectorField("UV", self.u, self.v))
+        if not hasattr(self, "UVW") and hasattr(self, "w"):
+            self.add_vector_field(XVectorField("UVW", self.u, self.v, self.w))

     def _check_complete(self):
-        assert self.U, 'FieldSet does not have a Field named "U"'
-        assert self.V, 'FieldSet does not have a Field named "V"'
+        assert self.u, 'FieldSet does not have a Field named "u"'
+        assert self.v, 'FieldSet does not have a Field named "v"'
         for attr, value in vars(self).items():
-            if type(value) is Field:
+            if type(value) is XField:
                 assert value.name == attr, f"Field {value.name}.name ({attr}) is not consistent"

-        def check_velocityfields(U, V, W):
-            if (U.interp_method == "cgrid_velocity" and V.interp_method != "cgrid_velocity") or (
-                U.interp_method != "cgrid_velocity" and V.interp_method == "cgrid_velocity"
-            ):
-                raise ValueError("If one of U,V.interp_method='cgrid_velocity', the other should be too")
-
-            if "linear_invdist_land_tracer" in [U.interp_method, V.interp_method]:
-                raise NotImplementedError(
-                    "interp_method='linear_invdist_land_tracer' is not implemented for U and V Fields"
-                )
-
-            if U.interp_method == "cgrid_velocity":
-                if U.grid.xdim == 1 or U.grid.ydim == 1 or V.grid.xdim == 1 or V.grid.ydim == 1:
-                    raise NotImplementedError(
-                        "C-grid velocities require longitude and latitude dimensions at least length 2"
-                    )
-
-            if U.gridindexingtype not in ["nemo", "mitgcm", "mom5", "pop", "croco"]:
-                raise ValueError("Field.gridindexing has to be one of 'nemo', 'mitgcm', 'mom5', 'pop' or 'croco'")
-
-            if V.gridindexingtype != U.gridindexingtype or (W and W.gridindexingtype != U.gridindexingtype):
-                raise ValueError("Not all velocity Fields have the same gridindexingtype")
-
-        W = self.W if hasattr(self, "W") else None
-        check_velocityfields(self.U, self.V, W)
-
-        for g in self.gridset.grids:
-            g._check_zonal_periodic()
-            if len(g.time) == 1:
-                continue
-            assert isinstance(
-                g.time_origin.time_origin, type(self.time_origin.time_origin)
-            ), "time origins of different grids must be have the same type"
-            g.time = g.time + self.time_origin.reltime(g.time_origin)
-            g._time_origin = self.time_origin

         self._add_UVfield()
         self._completed = True
@@ -272,713 +197,96 @@ def _parse_wildcards(cls, paths, filenames, var):
             if not os.path.exists(fp):
                 raise OSError(f"FieldSet file not found: {fp}")
         return paths
-
-    @classmethod
-    def from_netcdf(
-        cls,
-        filenames,
-        variables,
-        dimensions,
- fieldtype=None, - mesh: Mesh = "spherical", - allow_time_extrapolation: bool | None = None, - **kwargs, - ): - """Initialises FieldSet object from NetCDF files. - - Parameters - ---------- - filenames : - Dictionary mapping variables to file(s). The - filepath may contain wildcards to indicate multiple files - or be a list of file. - filenames can be a list ``[files]``, a dictionary ``{var:[files]}``, - a dictionary ``{dim:[files]}`` (if lon, lat, depth and/or data not stored in same files as data), - or a dictionary of dictionaries ``{var:{dim:[files]}}``. - time values are in ``filenames[data]`` - variables : dict - Dictionary mapping variables to variable names in the netCDF file(s). - Note that the built-in Advection kernels assume that U and V are in m/s - dimensions : dict - Dictionary mapping data dimensions (lon, - lat, depth, time, data) to dimensions in the netCF file(s). - Note that dimensions can also be a dictionary of dictionaries if - dimension names are different for each variable - (e.g. dimensions['U'], dimensions['V'], etc). - fieldtype : - Optional dictionary mapping fields to fieldtypes to be used for UnitConverter. - (either 'U', 'V', 'Kh_zonal', 'Kh_meridional' or None) (Default value = None) - mesh : str - String indicating the type of mesh coordinates and - units used during velocity interpolation, see also `this tutorial <../examples/tutorial_unitconverters.ipynb>`__: - - 1. spherical (default): Lat and lon in degree, with a - correction for zonal velocity U near the poles. - 2. flat: No conversion, lat/lon are assumed to be in m. - allow_time_extrapolation : bool - boolean whether to allow for extrapolation - (i.e. beyond the last available time snapshot) - Default is False if dimensions includes time, else True - interp_method : str - Method for interpolation. Options are 'linear' (default), 'nearest', - 'linear_invdist_land_tracer', 'cgrid_velocity', 'cgrid_tracer' and 'bgrid_velocity' - gridindexingtype : str - The type of gridindexing. Either 'nemo' (default), 'mitgcm', 'mom5', 'pop', or 'croco' are supported. - See also the Grid indexing documentation on oceanparcels.org - **kwargs : - Keyword arguments passed to the :class:`parcels.Field` constructor. 
- - - Examples - -------- - For usage examples see the following tutorials: - - * `Basic Parcels setup <../examples/parcels_tutorial.ipynb>`__ - - * `Argo floats <../examples/tutorial_Argofloats.ipynb>`__ - - * `Time-evolving depth dimensions <../examples/tutorial_timevaryingdepthdimensions.ipynb>`__ - - """ - fields: dict[str, Field] = {} - for var, name in variables.items(): - # Resolve all matching paths for the current variable - paths = filenames[var] if type(filenames) is dict and var in filenames else filenames - if type(paths) is not dict: - paths = cls._parse_wildcards(paths, filenames, var) - else: - for dim, p in paths.items(): - paths[dim] = cls._parse_wildcards(p, filenames, var) - - # Use dimensions[var] if it's a dict of dicts - dims = dimensions[var] if var in dimensions else dimensions - cls.checkvaliddimensionsdict(dims) - fieldtype = fieldtype[var] if (fieldtype and var in fieldtype) else fieldtype - - grid = None - - fields[var] = Field.from_netcdf( - paths, - (var, name), - dims, - grid=grid, - mesh=mesh, - allow_time_extrapolation=allow_time_extrapolation, - fieldtype=fieldtype, - **kwargs, - ) - - u = fields.pop("U", None) - v = fields.pop("V", None) - return cls(u, v, fields=fields) - - @classmethod - def from_nemo( - cls, - filenames, - variables, - dimensions, - mesh: Mesh = "spherical", - allow_time_extrapolation: bool | None = None, - tracer_interp_method: InterpMethodOption = "cgrid_tracer", - **kwargs, - ): - """Initialises FieldSet object from NetCDF files of Curvilinear NEMO fields. - - See `here <../examples/tutorial_nemo_curvilinear.ipynb>`__ - for a detailed tutorial on the setup for 2D NEMO fields and `here <../examples/tutorial_nemo_3D.ipynb>`__ - for the tutorial on the setup for 3D NEMO fields. - - See `here <../examples/documentation_indexing.ipynb>`__ - for a more detailed explanation of the different methods that can be used for c-grid datasets. - - Parameters - ---------- - filenames : - Dictionary mapping variables to file(s). The - filepath may contain wildcards to indicate multiple files, - or be a list of file. - filenames can be a list ``[files]``, a dictionary ``{var:[files]}``, - a dictionary ``{dim:[files]}`` (if lon, lat, depth and/or data not stored in same files as data), - or a dictionary of dictionaries ``{var:{dim:[files]}}`` - time values are in ``filenames[data]`` - variables : dict - Dictionary mapping variables to variable names in the netCDF file(s). - Note that the built-in Advection kernels assume that U and V are in m/s - dimensions : dict - Dictionary mapping data dimensions (lon, - lat, depth, time, data) to dimensions in the netCF file(s). - Note that dimensions can also be a dictionary of dictionaries if - dimension names are different for each variable. - Watch out: NEMO is discretised on a C-grid: - U and V velocities are not located on the same nodes (see https://www.nemo-ocean.eu/doc/node19.html). :: - - +-----------------------------+-----------------------------+-----------------------------+ - | | V[k,j+1,i+1] | | - +-----------------------------+-----------------------------+-----------------------------+ - |U[k,j+1,i] |W[k:k+2,j+1,i+1],T[k,j+1,i+1]|U[k,j+1,i+1] | - +-----------------------------+-----------------------------+-----------------------------+ - | | V[k,j,i+1] | | - +-----------------------------+-----------------------------+-----------------------------+ - - To interpolate U, V velocities on the C-grid, Parcels needs to read the f-nodes, - which are located on the corners of the cells. 
- (for indexing details: https://www.nemo-ocean.eu/doc/img360.png ) - In 3D, the depth is the one corresponding to W nodes - The gridindexingtype is set to 'nemo'. See also the Grid indexing documentation on oceanparcels.org - fieldtype : - Optional dictionary mapping fields to fieldtypes to be used for UnitConverter. - (either 'U', 'V', 'Kh_zonal', 'Kh_meridional' or None) - mesh : str - String indicating the type of mesh coordinates and - units used during velocity interpolation, see also `this tutorial <../examples/tutorial_unitconverters.ipynb>`__: - - 1. spherical (default): Lat and lon in degree, with a - correction for zonal velocity U near the poles. - 2. flat: No conversion, lat/lon are assumed to be in m. - allow_time_extrapolation : bool - boolean whether to allow for extrapolation - (i.e. beyond the last available time snapshot) - Default is False if dimensions includes time, else True - tracer_interp_method : str - Method for interpolation of tracer fields. It is recommended to use 'cgrid_tracer' (default) - Note that in the case of from_nemo() and from_c_grid_dataset(), the velocity fields are default to 'cgrid_velocity' - **kwargs : - Keyword arguments passed to the :func:`Fieldset.from_c_grid_dataset` constructor. - - """ - if kwargs.pop("gridindexingtype", "nemo") != "nemo": - raise ValueError( - "gridindexingtype must be 'nemo' in FieldSet.from_nemo(). Use FieldSet.from_c_grid_dataset otherwise" - ) - fieldset = cls.from_c_grid_dataset( - filenames, - variables, - dimensions, - mesh=mesh, - allow_time_extrapolation=allow_time_extrapolation, - tracer_interp_method=tracer_interp_method, - gridindexingtype="nemo", - **kwargs, - ) - return fieldset - - @classmethod - def from_mitgcm( - cls, - filenames, - variables, - dimensions, - mesh: Mesh = "spherical", - allow_time_extrapolation: bool | None = None, - tracer_interp_method: InterpMethodOption = "cgrid_tracer", - **kwargs, - ): - """Initialises FieldSet object from NetCDF files of MITgcm fields. - All parameters and keywords are exactly the same as for FieldSet.from_nemo(), except that - gridindexing is set to 'mitgcm' for grids that have the shape:: - - +-----------------------------+-----------------------------+-----------------------------+ - | | V[k,j+1,i] | | - +-----------------------------+-----------------------------+-----------------------------+ - |U[k,j,i] | W[k-1:k,j,i], T[k,j,i] |U[k,j,i+1] | - +-----------------------------+-----------------------------+-----------------------------+ - | | V[k,j,i] | | - +-----------------------------+-----------------------------+-----------------------------+ - - For indexing details: https://mitgcm.readthedocs.io/en/latest/algorithm/algorithm.html#spatial-discretization-of-the-dynamical-equations - Note that vertical velocity (W) is assumed positive in the positive z direction (which is upward in MITgcm) - """ - if kwargs.pop("gridindexingtype", "mitgcm") != "mitgcm": - raise ValueError( - "gridindexingtype must be 'mitgcm' in FieldSet.from_mitgcm(). 
Use FieldSet.from_c_grid_dataset otherwise" - ) - fieldset = cls.from_c_grid_dataset( - filenames, - variables, - dimensions, - mesh=mesh, - allow_time_extrapolation=allow_time_extrapolation, - tracer_interp_method=tracer_interp_method, - gridindexingtype="mitgcm", - **kwargs, - ) - return fieldset - - @classmethod - def from_croco( - cls, - filenames, - variables, - dimensions, - hc: float | None = None, - mesh="spherical", - allow_time_extrapolation=None, - tracer_interp_method="cgrid_tracer", - **kwargs, - ): - """Initialises FieldSet object from NetCDF files of CROCO fields. - All parameters and keywords are exactly the same as for FieldSet.from_nemo(), except that - in order to scale the vertical coordinate in CROCO, the following fields are required: - the bathymetry (``h``), the sea-surface height (``zeta``), the S-coordinate stretching curves - at W-points (``Cs_w``), and the stretching parameter (``hc``). - The horizontal interpolation uses the MITgcm grid indexing as described in FieldSet.from_mitgcm(). - - In 3D, when there is a ``depth`` dimension, the sigma grid scaling means that FieldSet.from_croco() - requires variables ``H: h`` and ``Zeta: zeta``, ``Cs_w: Cs_w``, as well as the stretching parameter ``hc`` - (as an extra input) parameter to work. - - See `the CROCO 3D tutorial <../examples/tutorial_croco_3D.ipynb>`__ for more infomation. - """ - if kwargs.pop("gridindexingtype", "croco") != "croco": - raise ValueError( - "gridindexingtype must be 'croco' in FieldSet.from_croco(). Use FieldSet.from_c_grid_dataset otherwise" - ) - - dimsU = dimensions["U"] if "U" in dimensions else dimensions - croco3D = True if "depth" in dimsU else False - - if croco3D: - if "W" in variables and variables["W"] == "omega": - warnings.warn( - "Note that Parcels expects 'w' for vertical velicites in 3D CROCO fields.\nSee https://docs.oceanparcels.org/en/latest/examples/tutorial_croco_3D.html for more information", - FieldSetWarning, - stacklevel=2, - ) - if "H" not in variables: - raise ValueError("FieldSet.from_croco() requires a bathymetry field 'H' for 3D CROCO fields") - if "Zeta" not in variables: - raise ValueError("FieldSet.from_croco() requires a free-surface field 'Zeta' for 3D CROCO fields") - if "Cs_w" not in variables: - raise ValueError( - "FieldSet.from_croco() requires the S-coordinate stretching curves at W-points 'Cs_w' for 3D CROCO fields" - ) - - interp_method = {} - for v in variables: - if v in ["U", "V"]: - interp_method[v] = "cgrid_velocity" - elif v in ["W", "H"]: - interp_method[v] = "linear" - else: - interp_method[v] = tracer_interp_method - - # Suppress the warning about the velocity interpolation since it is ok for CROCO - warnings.filterwarnings( - "ignore", - "Sampling of velocities should normally be done using fieldset.UV or fieldset.UVW object; tread carefully", - ) - - fieldset = cls.from_netcdf( - filenames, - variables, - dimensions, - mesh=mesh, - allow_time_extrapolation=allow_time_extrapolation, - interp_method=interp_method, - gridindexingtype="croco", - **kwargs, - ) - if croco3D: - if hc is None: - raise ValueError("FieldSet.from_croco() requires the hc parameter for 3D CROCO fields") - fieldset.add_constant("hc", hc) - return fieldset - - @classmethod - def from_c_grid_dataset( - cls, - filenames, - variables, - dimensions, - mesh: Mesh = "spherical", - allow_time_extrapolation: bool | None = None, - tracer_interp_method: InterpMethodOption = "cgrid_tracer", - gridindexingtype: GridIndexingType = "nemo", - **kwargs, - ): - """Initialises FieldSet 
object from NetCDF files of Curvilinear NEMO fields. - - See `here <../examples/documentation_indexing.ipynb>`__ - for a more detailed explanation of the different methods that can be used for c-grid datasets. - - Parameters - ---------- - filenames : - Dictionary mapping variables to file(s). The - filepath may contain wildcards to indicate multiple files, - or be a list of file. - filenames can be a list ``[files]``, a dictionary ``{var:[files]}``, - a dictionary ``{dim:[files]}`` (if lon, lat, depth and/or data not stored in same files as data), - or a dictionary of dictionaries ``{var:{dim:[files]}}`` - time values are in ``filenames[data]`` - variables : dict - Dictionary mapping variables to variable - names in the netCDF file(s). - dimensions : dict - Dictionary mapping data dimensions (lon, - lat, depth, time, data) to dimensions in the netCF file(s). - Note that dimensions can also be a dictionary of dictionaries if - dimension names are different for each variable. - Watch out: NEMO is discretised on a C-grid: - U and V velocities are not located on the same nodes (see https://www.nemo-ocean.eu/doc/node19.html ). :: - - +-----------------------------+-----------------------------+-----------------------------+ - | | V[k,j+1,i+1] | | - +-----------------------------+-----------------------------+-----------------------------+ - |U[k,j+1,i] |W[k:k+2,j+1,i+1],T[k,j+1,i+1]|U[k,j+1,i+1] | - +-----------------------------+-----------------------------+-----------------------------+ - | | V[k,j,i+1] | | - +-----------------------------+-----------------------------+-----------------------------+ - - To interpolate U, V velocities on the C-grid, Parcels needs to read the f-nodes, - which are located on the corners of the cells. - (for indexing details: https://www.nemo-ocean.eu/doc/img360.png ) - In 3D, the depth is the one corresponding to W nodes. - fieldtype : - Optional dictionary mapping fields to fieldtypes to be used for UnitConverter. - (either 'U', 'V', 'Kh_zonal', 'Kh_meridional' or None) - mesh : str - String indicating the type of mesh coordinates and - units used during velocity interpolation, see also `this tutorial <../examples/tutorial_unitconverters.ipynb>`__: - - 1. spherical (default): Lat and lon in degree, with a - correction for zonal velocity U near the poles. - 2. flat: No conversion, lat/lon are assumed to be in m. - allow_time_extrapolation : bool - boolean whether to allow for extrapolation - (i.e. beyond the last available time snapshot) - Default is False if dimensions includes time, else True - tracer_interp_method : str - Method for interpolation of tracer fields. It is recommended to use 'cgrid_tracer' (default) - Note that in the case of from_nemo() and from_c_grid_dataset(), the velocity fields are default to 'cgrid_velocity' - gridindexingtype : str - The type of gridindexing. Set to 'nemo' in FieldSet.from_nemo(), 'mitgcm' in FieldSet.from_mitgcm() or 'croco' in FieldSet.from_croco(). - See also the Grid indexing documentation on oceanparcels.org (Default value = 'nemo') - **kwargs : - Keyword arguments passed to the :func:`Fieldset.from_netcdf` constructor. - """ - if "U" in dimensions and "V" in dimensions and dimensions["U"] != dimensions["V"]: - raise ValueError( - "On a C-grid, the dimensions of velocities should be the corners (f-points) of the cells, so the same for U and V. 
" - "See also https://docs.oceanparcels.org/en/latest/examples/documentation_indexing.html" - ) - if "U" in dimensions and "W" in dimensions and dimensions["U"] != dimensions["W"]: - raise ValueError( - "On a C-grid, the dimensions of velocities should be the corners (f-points) of the cells, so the same for U, V and W. " - "See also https://docs.oceanparcels.org/en/latest/examples/documentation_indexing.html" - ) - if "interp_method" in kwargs.keys(): - raise TypeError("On a C-grid, the interpolation method for velocities should not be overridden") - - interp_method = {} - for v in variables: - if v in ["U", "V", "W"]: - interp_method[v] = "cgrid_velocity" - else: - interp_method[v] = tracer_interp_method - - return cls.from_netcdf( - filenames, - variables, - dimensions, - mesh=mesh, - allow_time_extrapolation=allow_time_extrapolation, - interp_method=interp_method, - gridindexingtype=gridindexingtype, - **kwargs, - ) - - @classmethod - def from_mom5( - cls, - filenames, - variables, - dimensions, - mesh: Mesh = "spherical", - allow_time_extrapolation: bool | None = None, - tracer_interp_method: InterpMethodOption = "bgrid_tracer", - **kwargs, - ): - """Initialises FieldSet object from NetCDF files of MOM5 fields. - - Parameters - ---------- - filenames : - Dictionary mapping variables to file(s). The - filepath may contain wildcards to indicate multiple files, - or be a list of file. - filenames can be a list ``[files]``, a dictionary ``{var:[files]}``, - a dictionary ``{dim:[files]}`` (if lon, lat, depth and/or data not stored in same files as data), - or a dictionary of dictionaries ``{var:{dim:[files]}}`` - time values are in ``filenames[data]`` - variables : dict - Dictionary mapping variables to variable names in the netCDF file(s). - Note that the built-in Advection kernels assume that U and V are in m/s - dimensions : dict - Dictionary mapping data dimensions (lon, - lat, depth, time, data) to dimensions in the netCF file(s). - Note that dimensions can also be a dictionary of dictionaries if - dimension names are different for each variable. :: - - +-------------------------------+-------------------------------+-------------------------------+ - |U[k,j+1,i],V[k,j+1,i] | |U[k,j+1,i+1],V[k,j+1,i+1] | - +-------------------------------+-------------------------------+-------------------------------+ - | |W[k-1:k+1,j+1,i+1],T[k,j+1,i+1]| | - +-------------------------------+-------------------------------+-------------------------------+ - |U[k,j,i],V[k,j,i] | |U[k,j,i+1],V[k,j,i+1] | - +-------------------------------+-------------------------------+-------------------------------+ - - In 2D: U and V nodes are on the cell vertices and interpolated bilinearly as a A-grid. - T node is at the cell centre and interpolated constant per cell as a C-grid. - In 3D: U and V nodes are at the middle of the cell vertical edges, - They are interpolated bilinearly (independently of z) in the cell. - W nodes are at the centre of the horizontal interfaces, but below the U and V. - They are interpolated linearly (as a function of z) in the cell. - Note that W is normally directed upward in MOM5, but Parcels requires W - in the positive z-direction (downward) so W is multiplied by -1. - T node is at the cell centre, and constant per cell. - fieldtype : - Optional dictionary mapping fields to fieldtypes to be used for UnitConverter. 
- (either 'U', 'V', 'Kh_zonal', 'Kh_meridional' or None) - mesh : str - String indicating the type of mesh coordinates and - units used during velocity interpolation, see also the `Unit converters tutorial <../examples/tutorial_unitconverters.ipynb>`__: - - 1. spherical (default): Lat and lon in degree, with a - correction for zonal velocity U near the poles. - 2. flat: No conversion, lat/lon are assumed to be in m. - allow_time_extrapolation : bool - boolean whether to allow for extrapolation - (i.e. beyond the last available time snapshot) - Default is False if dimensions includes time, else True - tracer_interp_method : str - Method for interpolation of tracer fields. It is recommended to use 'bgrid_tracer' (default) - Note that in the case of from_mom5() and from_b_grid_dataset(), the velocity fields are default to 'bgrid_velocity' - **kwargs : - Keyword arguments passed to the :func:`Fieldset.from_b_grid_dataset` constructor. - """ - fieldset = cls.from_b_grid_dataset( - filenames, - variables, - dimensions, - mesh=mesh, - allow_time_extrapolation=allow_time_extrapolation, - tracer_interp_method=tracer_interp_method, - gridindexingtype="mom5", - **kwargs, - ) - return fieldset - - @classmethod - def from_a_grid_dataset(cls, filenames, variables, dimensions, **kwargs): - """ - Load a FieldSet from an A-grid dataset, which is the default grid type. - - Parameters - ---------- - filenames : - Path(s) to the input files. - variables : - Dictionary of the variables in the NetCDF file. - dimensions : - Dictionary of the dimensions in the NetCDF file. - **kwargs : - Additional keyword arguments for `from_netcdf()`. - - Returns - ------- - FieldSet - A FieldSet object. - """ - return cls.from_netcdf(filenames, variables, dimensions, **kwargs) - - @classmethod - def from_b_grid_dataset( - cls, - filenames, - variables, - dimensions, - mesh: Mesh = "spherical", - allow_time_extrapolation: bool | None = None, - tracer_interp_method: InterpMethodOption = "bgrid_tracer", - **kwargs, - ): - """Initialises FieldSet object from NetCDF files of Bgrid fields. - - Parameters - ---------- - filenames : - Dictionary mapping variables to file(s). The - filepath may contain wildcards to indicate multiple files, - or be a list of file. - filenames can be a list ``[files]``, a dictionary ``{var:[files]}``, - a dictionary ``{dim:[files]}`` (if lon, lat, depth and/or data not stored in same files as data), - or a dictionary of dictionaries ``{var:{dim:[files]}}`` - time values are in ``filenames[data]`` - variables : dict - Dictionary mapping variables to variable - names in the netCDF file(s). - dimensions : dict - Dictionary mapping data dimensions (lon, - lat, depth, time, data) to dimensions in the netCF file(s). - Note that dimensions can also be a dictionary of dictionaries if - dimension names are different for each variable. - U and V velocity nodes are not located as W velocity and T tracer nodes (see http://www2.cesm.ucar.edu/models/cesm1.0/pop2/doc/sci/POPRefManual.pdf ). 
:: - - +-----------------------------+-----------------------------+-----------------------------+ - |U[k,j+1,i],V[k,j+1,i] | |U[k,j+1,i+1],V[k,j+1,i+1] | - +-----------------------------+-----------------------------+-----------------------------+ - | |W[k:k+2,j+1,i+1],T[k,j+1,i+1]| | - +-----------------------------+-----------------------------+-----------------------------+ - |U[k,j,i],V[k,j,i] | |U[k,j,i+1],V[k,j,i+1] | - +-----------------------------+-----------------------------+-----------------------------+ - - In 2D: U and V nodes are on the cell vertices and interpolated bilinearly as a A-grid. - T node is at the cell centre and interpolated constant per cell as a C-grid. - In 3D: U and V nodes are at the midlle of the cell vertical edges, - They are interpolated bilinearly (independently of z) in the cell. - W nodes are at the centre of the horizontal interfaces. - They are interpolated linearly (as a function of z) in the cell. - T node is at the cell centre, and constant per cell. - fieldtype : - Optional dictionary mapping fields to fieldtypes to be used for UnitConverter. - (either 'U', 'V', 'Kh_zonal', 'Kh_meridional' or None) - mesh : str - String indicating the type of mesh coordinates and - units used during velocity interpolation, see also `this tutorial <../examples/tutorial_unitconverters.ipynb>`__: - - 1. spherical (default): Lat and lon in degree, with a - correction for zonal velocity U near the poles. - 2. flat: No conversion, lat/lon are assumed to be in m. - allow_time_extrapolation : bool - boolean whether to allow for extrapolation - (i.e. beyond the last available time snapshot) - Default is False if dimensions includes time, else True - tracer_interp_method : str - Method for interpolation of tracer fields. It is recommended to use 'bgrid_tracer' (default) - Note that in the case of from_b_grid_dataset(), the velocity fields are default to 'bgrid_velocity' - **kwargs : - Keyword arguments passed to the :func:`Fieldset.from_netcdf` constructor. - """ - if "U" in dimensions and "V" in dimensions and dimensions["U"] != dimensions["V"]: - raise ValueError( - "On a B-grid, the dimensions of velocities should be the (top) corners of the grid cells, so the same for U and V. " - "See also https://docs.oceanparcels.org/en/latest/examples/documentation_indexing.html" - ) - if "U" in dimensions and "W" in dimensions and dimensions["U"] != dimensions["W"]: - raise ValueError( - "On a B-grid, the dimensions of velocities should be the (top) corners of the grid cells, so the same for U, V and W. " - "See also https://docs.oceanparcels.org/en/latest/examples/documentation_indexing.html" - ) - - interp_method = {} - for v in variables: - if v in ["U", "V"]: - interp_method[v] = "bgrid_velocity" - elif v in ["W"]: - interp_method[v] = "bgrid_w_velocity" - else: - interp_method[v] = tracer_interp_method - - return cls.from_netcdf( - filenames, - variables, - dimensions, - mesh=mesh, - allow_time_extrapolation=allow_time_extrapolation, - interp_method=interp_method, - **kwargs, - ) - - @classmethod - def from_xarray_dataset(cls, ds, variables, dimensions, mesh="spherical", allow_time_extrapolation=None, **kwargs): - """Initialises FieldSet data from xarray Datasets. - - Parameters - ---------- - ds : xr.Dataset - xarray Dataset. - Note that the built-in Advection kernels assume that U and V are in m/s - variables : dict - Dictionary mapping parcels variable names to data variables in the xarray Dataset. 
- dimensions : dict - Dictionary mapping data dimensions (lon, - lat, depth, time, data) to dimensions in the xarray Dataset. - Note that dimensions can also be a dictionary of dictionaries if - dimension names are different for each variable - (e.g. dimensions['U'], dimensions['V'], etc). - fieldtype : - Optional dictionary mapping fields to fieldtypes to be used for UnitConverter. - (either 'U', 'V', 'Kh_zonal', 'Kh_meridional' or None) - mesh : str - String indicating the type of mesh coordinates and - units used during velocity interpolation, see also `this tutorial <../examples/tutorial_unitconverters.ipynb>`__: - - 1. spherical (default): Lat and lon in degree, with a - correction for zonal velocity U near the poles. - 2. flat: No conversion, lat/lon are assumed to be in m. - allow_time_extrapolation : bool - boolean whether to allow for extrapolation - (i.e. beyond the last available time snapshot) - Default is False if dimensions includes time, else True - **kwargs : - Keyword arguments passed to the :func:`Field.from_xarray` constructor. - """ - fields = {} - - for var, name in variables.items(): - dims = dimensions[var] if var in dimensions else dimensions - cls.checkvaliddimensionsdict(dims) - - fields[var] = Field.from_xarray( - ds[name], - var, - dims, - mesh=mesh, - allow_time_extrapolation=allow_time_extrapolation, - **kwargs, - ) - u = fields.pop("U", None) - v = fields.pop("V", None) - return cls(u, v, fields=fields) - - @classmethod - def from_modulefile(cls, filename, modulename="create_fieldset", **kwargs): - """Initialises FieldSet data from a file containing a python module file with a create_fieldset() function. - - Parameters - ---------- - filename: path to a python file containing at least a function which returns a FieldSet object. - modulename: name of the function in the python file that returns a FieldSet object. Default is "create_fieldset". - """ - # check if filename exists - if not os.path.exists(filename): - raise OSError(f"FieldSet module file {filename} does not exist") - - # Importing the source file directly (following https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly) - spec = importlib.util.spec_from_file_location(modulename, filename) - fieldset_module = importlib.util.module_from_spec(spec) - sys.modules[modulename] = fieldset_module - spec.loader.exec_module(fieldset_module) - - if not hasattr(fieldset_module, modulename): - raise OSError(f"{filename} does not contain a {modulename} function") - fieldset = getattr(fieldset_module, modulename)(**kwargs) - if not isinstance(fieldset, FieldSet): - raise OSError(f"Module {filename}.{modulename} does not return a FieldSet object") - return fieldset - - def get_fields(self) -> list[Field | VectorField]: - """Returns a list of all the :class:`parcels.field.Field` and :class:`parcels.field.VectorField` - objects associated with this FieldSet. 
- """ - fields = [] - for v in self.__dict__.values(): - if type(v) in [Field, VectorField]: - if v not in fields: - fields.append(v) - return fields + + # @classmethod + # def from_netcdf( + # cls, + # filenames, + # variables, + # dimensions, + # fieldtype=None, + # mesh: Mesh = "spherical", + # allow_time_extrapolation: bool | None = None, + # **kwargs, + # ): + + # @classmethod + # def from_nemo( + # cls, + # filenames, + # variables, + # dimensions, + # mesh: Mesh = "spherical", + # allow_time_extrapolation: bool | None = None, + # tracer_interp_method: InterpMethodOption = "cgrid_tracer", + # **kwargs, + # ): + + # @classmethod + # def from_mitgcm( + # cls, + # filenames, + # variables, + # dimensions, + # mesh: Mesh = "spherical", + # allow_time_extrapolation: bool | None = None, + # tracer_interp_method: InterpMethodOption = "cgrid_tracer", + # **kwargs, + # ): + + # @classmethod + # def from_croco( + # cls, + # filenames, + # variables, + # dimensions, + # hc: float | None = None, + # mesh="spherical", + # allow_time_extrapolation=None, + # tracer_interp_method="cgrid_tracer", + # **kwargs, + # ): + + # @classmethod + # def from_c_grid_dataset( + # cls, + # filenames, + # variables, + # dimensions, + # mesh: Mesh = "spherical", + # allow_time_extrapolation: bool | None = None, + # tracer_interp_method: InterpMethodOption = "cgrid_tracer", + # gridindexingtype: GridIndexingType = "nemo", + # **kwargs, + # ): + + + # @classmethod + # def from_mom5( + # cls, + # filenames, + # variables, + # dimensions, + # mesh: Mesh = "spherical", + # allow_time_extrapolation: bool | None = None, + # tracer_interp_method: InterpMethodOption = "bgrid_tracer", + # **kwargs, + # ): + + # @classmethod + # def from_a_grid_dataset(cls, filenames, variables, dimensions, **kwargs): + + # @classmethod + # def from_b_grid_dataset( + # cls, + # filenames, + # variables, + # dimensions, + # mesh: Mesh = "spherical", + # allow_time_extrapolation: bool | None = None, + # tracer_interp_method: InterpMethodOption = "bgrid_tracer", + # **kwargs, + # ): def add_constant(self, name, value): """Add a constant to the FieldSet. Note that all constants are @@ -1001,29 +309,29 @@ def add_constant(self, name, value): """ setattr(self, name, value) - def computeTimeChunk(self, time=0.0, dt=1): - """Load a chunk of three data time steps into the FieldSet. - This is used when FieldSet uses data imported from netcdf, - with default option deferred_load. The loaded time steps are at or immediatly before time - and the two time steps immediately following time if dt is positive (and inversely for negative dt) - - Parameters - ---------- - time : - Time around which the FieldSet data are to be loaded. - Time is provided as a double, relatively to Fieldset.time_origin. - Default is 0. - dt : - time step of the integration scheme, needed to set the direction of time chunk loading. - Default is 1. - """ - nextTime = np.inf if dt > 0 else -np.inf - - if abs(nextTime) == np.inf or np.isnan(nextTime): # Second happens when dt=0 - return nextTime - else: - nSteps = int((nextTime - time) / dt) - if nSteps == 0: - return nextTime - else: - return time + nSteps * dt + # def computeTimeChunk(self, time=0.0, dt=1): + # """Load a chunk of three data time steps into the FieldSet. + # This is used when FieldSet uses data imported from netcdf, + # with default option deferred_load. 
The loaded time steps are at or immediatly before time + # and the two time steps immediately following time if dt is positive (and inversely for negative dt) + + # Parameters + # ---------- + # time : + # Time around which the FieldSet data are to be loaded. + # Time is provided as a double, relatively to Fieldset.time_origin. + # Default is 0. + # dt : + # time step of the integration scheme, needed to set the direction of time chunk loading. + # Default is 1. + # """ + # nextTime = np.inf if dt > 0 else -np.inf + + # if abs(nextTime) == np.inf or np.isnan(nextTime): # Second happens when dt=0 + # return nextTime + # else: + # nSteps = int((nextTime - time) / dt) + # if nSteps == 0: + # return nextTime + # else: + # return time + nSteps * dt diff --git a/parcels/xfield.py b/parcels/xfield.py deleted file mode 100644 index 988055cd68..0000000000 --- a/parcels/xfield.py +++ /dev/null @@ -1,675 +0,0 @@ -import collections -import math -import warnings -from typing import TYPE_CHECKING, cast - -import dask.array as da -import numpy as np -import xarray as xr -import uxarray as ux -from uxarray.grid.neighbors import _barycentric_coordinates - - -import parcels.tools.interpolation_utils as i_u -from parcels._compat import add_note -from parcels._interpolation import ( - InterpolationContext2D, - InterpolationContext3D, - get_2d_interpolator_registry, - get_3d_interpolator_registry, -) -from parcels._typing import ( - GridIndexingType, - InterpMethod, - InterpMethodOption, - Mesh, - VectorType, - assert_valid_gridindexingtype, - assert_valid_interp_method, -) -from parcels.tools._helpers import default_repr, field_repr, should_calculate_next_ti -from parcels.tools.converters import ( - TimeConverter, - UnitConverter, - unitconverters_map, -) -from parcels.tools.statuscodes import ( - AllParcelsErrorCodes, - FieldOutOfBoundError, - FieldOutOfBoundSurfaceError, - FieldSamplingError, - _raise_field_out_of_bound_error, -) -from parcels.tools.warnings import FieldSetWarning -import inspect -from typing import Callable, Union - -from ._index_search import _search_indices_curvilinear, _search_indices_rectilinear, _search_time_index - -if TYPE_CHECKING: - import numpy.typing as npt - - from parcels.xfieldset import XFieldSet - -__all__ = ["XField", "XVectorField"] - - -def _isParticle(key): - if hasattr(key, "obs_written"): - return True - else: - return False - - -def _deal_with_errors(error, key, vector_type: VectorType): - if _isParticle(key): - key.state = AllParcelsErrorCodes[type(error)] - elif _isParticle(key[-1]): - key[-1].state = AllParcelsErrorCodes[type(error)] - else: - raise RuntimeError(f"{error}. Error could not be handled because particle was not part of the Field Sampling.") - - if vector_type and "3D" in vector_type: - return (0, 0, 0) - elif vector_type == "2D": - return (0, 0) - else: - return 0 - -class XField: - """The XField class that holds scalar field data. - The `XField` object is a wrapper around a xarray.DataArray or uxarray.UxDataArray object. - Additionally, it holds a dynamic Callable procedure that is used to interpolate the field data. - During initialization, the user can supply a custom interpolation method that is used to interpolate the field data, - so long as the interpolation method has the correct signature. - - Notes - ----- - - - The xarray.DataArray or uxarray.UxDataArray object contains the field data and metadata. 
- * dims: (time, [nz1 | nz], [face_lat | node_lat | edge_lat], [face_lon | node_lon | edge_lon]) - * attrs: (location, mesh, mesh_type) - - When using a xarray.DataArray object, - * The xarray.DataArray object must have the "location" and "mesh" attributes set. - * The "location" attribute must be set to one of the following to define which pairing of points a field is associated with. - * "node" - * "face" - * "x_edge" - * "y_edge" - * For an A-Grid, the "location" attribute must be set to / is assumed to be "node" (node_lat,node_lon). - * For a C-Grid, the "location" setting for a field has the following interpretation: - * "node" ~> the field is associated with the vorticity points (node_lat, node_lon) - * "face" ~> the field is associated with the tracer points (face_lat, face_lon) - * "x_edge" ~> the field is associated with the u-velocity points (face_lat, node_lon) - * "y_edge" ~> the field is associated with the v-velocity points (node_lat, face_lon) - - When using a uxarray.UxDataArray object, - * The uxarray.UxDataArray.UxGrid object must have the "Conventions" attribute set to "UGRID-1.0" - and the uxarray.UxDataArray object must comply with the UGRID conventions. - See https://ugrid-conventions.github.io/ugrid-conventions/ for more information. - - """ - - @staticmethod - def _interp_template( - self, - ti: int, - ei: int, - bcoords: np.ndarray, - t: Union[np.float32,np.float64], - z: Union[np.float32,np.float64], - y: Union[np.float32,np.float64], - x: Union[np.float32,np.float64] - )-> Union[np.float32,np.float64]: - """ Template function used for the signature check of the lateral interpolation methods.""" - return 0.0 - - def _validate_interp_function(self, func: Callable): - """Ensures that the function has the correct signature.""" - expected_params = ["ti", "ei", "bcoords", "t", "z", "y", "x"] - expected_return_types = (np.float32,np.float64) - - sig = inspect.signature(func) - params = list(sig.parameters.keys()) - - # Check the parameter names and count - if params != expected_params: - raise TypeError( - f"Function must have parameters {expected_params}, but got {params}" - ) - - # Check return annotation if present - return_annotation = sig.return_annotation - if return_annotation not in (inspect.Signature.empty, *expected_return_types): - raise TypeError( - f"Function must return a float, but got {return_annotation}" - ) - - def __init__( - self, - name: str, - data: xr.DataArray | ux.UxDataArray, - mesh_type: Mesh = "flat", - interp_method: Callable | None = None, - allow_time_extrapolation: bool | None = None, - ): - - self.name = name - self.data = data - - self._validate_dataarray(data) - - self._parent_mesh = data.attributes["mesh"] - self._mesh_type = mesh_type - self._location = data.attributes["location"] - - # Set the vertical location - if "nz1" in data.dims: - self._vertical_location = "center" - elif "nz" in data.dims: - self._vertical_location = "face" - - # Setting the interpolation method dynamically - if interp_method is None: - self._interp_method = self._interp_template # Default to method that returns 0 always - else: - self._validate_interp_function(interp_method) - self._interp_method = interp_method - - self.igrid = -1 # Default the grid index to -1 - - if self._mesh_type == "flat" or (self.name not in unitconverters_map.keys()): - self.units = UnitConverter() - elif self._mesh_type == "spherical": - self.units = unitconverters_map[self.name] - else: - raise ValueError("Unsupported mesh type in data array attributes. 
Choose either: 'spherical' or 'flat'") - - if allow_time_extrapolation is None: - self.allow_time_extrapolation = True if len(self.data["time"]) == 1 else False - else: - self.allow_time_extrapolation = allow_time_extrapolation - - if type(self.data) is ux.UxDataArray: - self._spatialhash = self.data.uxgrid.get_spatial_hash() - else: - self._spatialhash = None - - def __repr__(self): - return field_repr(self) - - @property - def grid(self): - if type(self.data) is ux.UxDataArray: - return self.data.uxgrid - else: - return self.data # To do : need to decide on what to return for xarray.DataArray objects - - @property - def lat(self): - if type(self.data) is ux.UxDataArray: - if self._location == "node": - return self.data.uxgrid.node_lat - elif self._location == "face": - return self.data.uxgrid.face_lat - elif self._location == "edge": - return self.data.uxgrid.edge_lat - else: - if self._location == "node": - return self.data.node_lat - elif self._location == "face": - return self.data.face_lat - elif self._location == "x_edge": - return self.data.face_lat - elif self._location == "y_edge": - return self.data.node_lat - - @property - def lon(self): - if type(self.data) is ux.UxDataArray: - if self._location == "node": - return self.data.uxgrid.node_lon - elif self._location == "face": - return self.data.uxgrid.face_lon - elif self._location == "edge": - return self.data.uxgrid.edge_lon - else: - if self._location == "node": - return self.data.node_lon - elif self._location == "face": - return self.data.face_lon - elif self._location == "x_edge": - return self.data.node_lon - elif self._location == "y_edge": - return self.data.face_lon - - @property - def depth(self): - if type(self.data) is ux.UxDataArray: - if self._vertical_location == "center": - return self.data.uxgrid.nz1 - elif self._vertical_location == "face": - return self.data.uxgrid.nz - else: - if self._vertical_location == "center": - return self.data.nz1 - elif self._vertical_location == "face": - return self.data.nz - - @property - def nx(self): - if type(self.data) is xr.DataArray: - if "face_lon" in self.data.dims: - return self.data.sizes["face_lon"] - elif "node_lon" in self.data.dims: - return self.data.sizes["node_lon"] - else: - return 0 # To do : Discuss what we want to return for uxdataarray obj - @property - def ny(self): - if type(self.data) is xr.DataArray: - if "face_lat" in self.data.dims: - return self.data.sizes["face_lat"] - elif "node_lat" in self.data.dims: - return self.data.sizes["node_lat"] - else: - return 0 # To do : Discuss what we want to return for uxdataarray obj - @property - def n_face(self): - if type(self.data) is ux.uxDataArray: - return self.data.uxgrid.n_face - else: - return 0 # To do : Discuss what we want to return for dataarray obj - - @property - def interp_method(self): - return self._interp_method - - @interp_method.setter - def interp_method(self, method: Callable): - self._validate_interp_function(method) - self._interp_method = method - - # @property - # def gridindexingtype(self): - # return self._gridindexingtype - - def _get_ux_barycentric_coordinates(self, y, x, fi): - "Checks if a point is inside a given face id. Used for unstructured grids." - - # Check if particle is in the same face, otherwise search again. 
- n_nodes = self.data.uxgrid.n_nodes_per_face[fi].to_numpy()
- node_ids = self.data.uxgrid.face_node_connectivity[fi, 0:n_nodes]
- nodes = np.column_stack(
- (
- np.deg2rad(self.data.uxgrid.node_lon[node_ids].to_numpy()),
- np.deg2rad(self.data.uxgrid.node_lat[node_ids].to_numpy()),
- )
- )
-
- coord = np.deg2rad([x, y])
- bcoord = np.asarray(_barycentric_coordinates(nodes, coord))
- err = abs(np.dot(bcoord, nodes[:, 0]) - coord[0]) + abs(
- np.dot(bcoord, nodes[:, 1]) - coord[1]
- )
- return bcoord, err
-
-
- def _search_indices_unstructured(self, z, y, x, ei=None, search2D=False):
-
- tol = 1e-10
- if ei is None:
- # Search using global search
- fi, bcoords = self._spatialhash.query([[x,y]]) # Get the face id for the particle
- if fi == -1:
- raise FieldOutOfBoundError(z, y, x) # To do : how to handle lost particle ??
- # To do : Do the vertical grid search
- # zi = self._vertical_search(z)
- zi = 0 # For now
- return bcoords, self.ravel_index(zi, 0, fi)
- else:
- zi, fi = self.unravel_index(ei[self.igrid]) # Get the z and face indices of the particle
- # Search using nearest neighbors
- bcoords, err = self._get_ux_barycentric_coordinates(y, x, fi)
-
- if ((bcoords >= 0).all()) and ((bcoords <= 1.0).all()) and err < tol:
- # To do: Do the vertical grid search
- return bcoords, ei
- else:
- # In this case we need to search the neighbors
- for neighbor in self.data.uxgrid.face_face_connectivity[fi,:]:
- bcoords, err = self._get_ux_barycentric_coordinates(y, x, neighbor)
- if ((bcoords >= 0).all()) and ((bcoords <= 1.0).all()) and err < tol:
- # To do: Do the vertical grid search
- return bcoords, self.ravel_index(zi, 0, neighbor)
-
- # If we reach this point, we do a global search as a last-ditch effort; if the face is still not found, the particle is out of bounds
- fi, bcoords = self._spatialhash.query([[x,y]]) # Get the face id for the particle
- if fi == -1:
- raise FieldOutOfBoundError(z, y, x) # To do : how to handle lost particle ??
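# Illustrative aside (a hedged sketch, not part of the module): for a planar
# triangular face, the containment test used in _search_indices_unstructured
# reduces to solving a 2x2 linear system for the barycentric weights; the
# point lies inside the face iff every weight is in [0, 1]. uxarray's
# _barycentric_coordinates handles general polygonal faces; this toy version
# only shows the underlying test.
import numpy as np

def _triangle_barycentric_weights(nodes, p):
    """nodes: (3, 2) array of (lon, lat) vertices; p: (2,) query point."""
    a, b, c = nodes
    edges = np.array([b - a, c - a]).T      # 2x2 matrix, edge vectors as columns
    w1, w2 = np.linalg.solve(edges, p - a)  # local coordinates of p in the face
    return np.array([1.0 - w1 - w2, w1, w2])

w = _triangle_barycentric_weights(
    np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]), np.array([0.25, 0.25])
)
assert (w >= 0).all() and (w <= 1.0).all()  # the point is inside this face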
- - - def _search_indices_structured(self, z, y, x, ei=None, search2D=False): - - # To do, determine grid type from xarray.coords shapes - # Rectilinear uses 1-D array for lat and lon - # Curvilinear uses 2-D array for lat and lon - if self.grid._gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]: - (zeta, eta, xsi, zi, yi, xi) = _search_indices_rectilinear( - self, z, y, x,particle=particle, search2D=search2D - ) - else: - (zeta, eta, xsi, zi, yi, xi) = _search_indices_curvilinear( - self, z, y, x, ei=ei, search2D=search2D - ) - - # To do : Calcualte barycentric coordinates from zeta, eta, xsi - - return (zeta, eta, xsi, zi, yi, xi) - - def _search_indices(self, time, z, y, x, ei=None, search2D=False): - - tau, ti = self._search_time_index(time) # To do : Need to implement this method - - if type(self.data) is ux.UxDataArray: - bcoords, ei = self._search_indices_unstructured(z, y, x, ei=ei, search2D=search2D) # To do : Need to implement this method - else: - bcoords, ei = self._search_indices_structured(z, y, x, ei=ei, search2D=search2D) # To do : Need to implement this method - return bcoords, ei, ti - - def _interpolate(self, time, z, y, x, ei): - - try: - bcoords, _ei, ti = self._search_indices(time, z, y, x, ei=ei) - val = self._interp_method(ti, _ei, bcoords, time, z, y, x) - - if np.isnan(val): - # Detect Out-of-bounds sampling and raise exception - _raise_field_out_of_bound_error(z, y, x) - else: - return val - - except (FieldSamplingError, FieldOutOfBoundError, FieldOutOfBoundSurfaceError) as e: - e = add_note(e, f"Error interpolating field '{self.name}'.", before=True) - raise e - - def _check_velocitysampling(self): - if self.name in ["U", "V", "W"]: - warnings.warn( - "Sampling of velocities should normally be done using fieldset.UV or fieldset.UVW object; tread carefully", - RuntimeWarning, - stacklevel=2, - ) - - def __getitem__(self, key): - self._check_velocitysampling() - try: - if _isParticle(key): - return self.eval(key.time, key.depth, key.lat, key.lon, key) - else: - return self.eval(*key) - except tuple(AllParcelsErrorCodes.keys()) as error: - return _deal_with_errors(error, key, vector_type=None) - - def eval(self, time, z, y, x, ei=None, applyConversion=True): - """Interpolate field values in space and time. - - We interpolate linearly in time and apply implicit unit - conversion to the result. Note that we defer to - scipy.interpolate to perform spatial interpolation. - """ - if ei is None: - _ei = None - else: - _ei = ei[self.igrid] - - value = self._interpolate(time, z, y, x, ei=_ei) - - if applyConversion: - return self.units.to_target(value, z, y, x) - else: - return value - - def _rescale_and_set_minmax(self, data): - data[np.isnan(data)] = 0 - return data - - def ravel_index(self, zi, yi, xi): - """Return the flat index of the given grid points. - - Parameters - ---------- - zi : int - z index - yi : int - y index - xi : int - x index. When using an unstructured grid, this is the face index (fi) - - Returns - ------- - int - flat index - """ - if type(self.data) is xr.DataArray: - return xi + self.nx * (yi + self.ny * zi) - else: - return xi + self.n_face*zi - - def unravel_index(self, ei): - """Return the zi, yi, xi indices for a given flat index. - Only used when working with fields on a structured grid. - - Parameters - ---------- - ei : int - The flat index to be unraveled. - - Returns - ------- - zi : int - The z index. - yi : int - The y index. - xi : int - The x index. 
- """ - if type(self.data) is xr.DataArray: - _ei = ei[self.igrid] - zi = _ei // (self.nx * self.ny) - _ei = _ei % (self.nx * self.ny) - yi = _ei // self.nx - xi = _ei % self.nx - return zi, yi, xi - else: - _ei = ei[self.igrid] - zi = _ei // self.n_face - fi = _ei % self.n_face - return zi, fi - - def _validate_dataarray(self): - """ Verifies that all the required attributes are present in the xarray.DataArray or - uxarray.UxDataArray object.""" - - # Validate dimensions - if not( "nz1" in self.data.dims or "nz" in self.data.dims ): - raise ValueError( - f"Field {self.name} is missing a 'nz1' or 'nz' dimension in the field's metadata. " - "This attribute is required for xarray.DataArray objects." - ) - - if not( "time" in self.data.dims ): - raise ValueError( - f"Field {self.name} is missing a 'time' dimension in the field's metadata. " - "This attribute is required for xarray.DataArray objects." - ) - - # Validate attributes - required_keys = ["location", "mesh"] - for key in required_keys: - if key not in self.data.attrs.keys(): - raise ValueError( - f"Field {self.name} is missing a '{key}' attribute in the field's metadata. " - "This attribute is required for xarray.DataArray objects." - ) - - if type(self.data) is ux.UxDataArray: - self._validate_uxgrid() - - - def _validate_uxgrid(self): - """ Verifies that all the required attributes are present in the uxarray.UxDataArray.UxGrid object.""" - - if "Conventions" not in self.data.uxgrid.attrs.keys(): - raise ValueError( - f"Field {self.name} is missing a 'Conventions' attribute in the field's metadata. " - "This attribute is required for uxarray.UxDataArray objects." - ) - if self.data.uxgrid.attrs["Conventions"] != "UGRID-1.0": - raise ValueError( - f"Field {self.name} has a 'Conventions' attribute that is not 'UGRID-1.0'. " - "This attribute is required for uxarray.UxDataArray objects." - "See https://ugrid-conventions.github.io/ugrid-conventions/ for more information." 
- ) - - - def __getattr__(self, key: str): - return getattr(self.data, key) - - def __contains__(self, key: str): - return key in self.data - - -class XVectorField: - """XVectorField class that holds vector field data needed to execute particles.""" - - - @staticmethod - def _vector_interp_template( - self, - ti: int, - ei: int, - bcoords: np.ndarray, - t: Union[np.float32,np.float64], - z: Union[np.float32,np.float64], - y: Union[np.float32,np.float64], - x: Union[np.float32,np.float64] - )-> Union[np.float32,np.float64]: - """ Template function used for the signature check of the lateral interpolation methods.""" - return 0.0 - - def _validate_vector_interp_function(self, func: Callable): - """Ensures that the function has the correct signature.""" - expected_params = ["ti", "ei", "bcoords", "t", "z", "y", "x"] - expected_return_types = (np.float32,np.float64) - - sig = inspect.signature(func) - params = list(sig.parameters.keys()) - - # Check the parameter names and count - if params != expected_params: - raise TypeError( - f"Function must have parameters {expected_params}, but got {params}" - ) - - # Check return annotation if present - return_annotation = sig.return_annotation - if return_annotation not in (inspect.Signature.empty, *expected_return_types): - raise TypeError( - f"Function must return a float, but got {return_annotation}" - ) - - def __init__( - self, - name: str, - U: XField, - V: XField, - W: XField | None = None, - vector_interp_method: Callable | None = None - ): - - self.name = name - self.U = U - self.V = V - self.W = W - - if self.W: - self.vector_type = "3D" - else: - self.vector_type = "2D" - - # Setting the interpolation method dynamically - if vector_interp_method is None: - self._vector_interp_method = None - else: - self._validate_vector_interp_function(vector_interp_method) - self._interp_method = vector_interp_method - - def __repr__(self): - return f"""<{type(self).__name__}> - name: {self.name!r} - U: {default_repr(self.U)} - V: {default_repr(self.V)} - W: {default_repr(self.W)}""" - - @property - def vector_interp_method(self): - return self._vector_interp_method - - @vector_interp_method.setter - def vector_interp_method(self, method: Callable): - self._validate_vector_interp_function(method) - self._vector_interp_method = method - - # @staticmethod - # To do : def _check_grid_dimensions(grid1, grid2): - # return ( - # np.allclose(grid1.lon, grid2.lon) - # and np.allclose(grid1.lat, grid2.lat) - # and np.allclose(grid1.depth, grid2.depth) - # and np.allclose(grid1.time, grid2.time) - # ) - def _interpolate(self, time, z, y, x, ei): - - bcoords, _ei, ti = self._search_indices(time, z, y, x, ei=ei) - - if self._vector_interp_method is None: - u = self.U.eval(time, z, y, x, _ei, applyConversion=False) - v = self.V.eval(time, z, y, x, _ei, applyConversion=False) - if "3D" in self.vector_type: - w = self.W.eval(time, z, y, x, ei, applyConversion=False) - return (u, v, w) - else: - return (u, v, 0) - else: - (u,v,w) = self._vector_interp_method(ti, _ei, bcoords, time, z, y, x) - return (u, v, w) - - - def eval(self, time, z, y, x, ei=None, applyConversion=True): - - if ei is None: - _ei = 0 - else: - _ei = ei[self.igrid] - - (u,v,w) = self._interpolate(time, z, y, x, _ei) - - if applyConversion: - u = self.U.units.to_target(u, z, y, x) - v = self.V.units.to_target(v, z, y, x) - if "3D" in self.vector_type: - w = self.W.units.to_target(w, z, y, x) - - return (u, v, w) - - def __getitem__(self,key): - try: - if _isParticle(key): - return self.eval(key.time, 
key.depth, key.lat, key.lon, key.ei) - else: - return self.eval(*key) - except tuple(AllParcelsErrorCodes.keys()) as error: - return _deal_with_errors(error, key, vector_type=self.vector_type) \ No newline at end of file diff --git a/parcels/xfieldset.py b/parcels/xfieldset.py deleted file mode 100644 index 3b62ad9a0d..0000000000 --- a/parcels/xfieldset.py +++ /dev/null @@ -1,337 +0,0 @@ -import importlib.util -import os -import sys -import warnings -from glob import glob - -import numpy as np - -from parcels._typing import GridIndexingType, InterpMethodOption, Mesh -from parcels.xfield import XField, XVectorField -from parcels.particlefile import ParticleFile -from parcels.tools._helpers import fieldset_repr, default_repr -from parcels.tools.converters import TimeConverter -from parcels.tools.warnings import FieldSetWarning - -import xarray as xr -import uxarray as ux - -__all__ = ["XFieldSet"] - - -class XFieldSet: - """XFieldSet class that holds hydrodynamic data needed to execute particles. - - Parameters - ---------- - ds : xarray.Dataset | uxarray.UxDataset) - xarray.Dataset and/or uxarray.UxDataset objects containing the field data. - - Notes - ----- - The `ds` object is a xarray.Dataset or uxarray.UxDataset object. - In XArray terminology, the (Ux)Dataset holds multiple (Ux)DataArray objects. - Each (Ux)DataArray object is a single "field" that is associated with their own - dimensions and coordinates within the (Ux)Dataset. - - A (Ux)Dataset object is associated with a single mesh, which can have multiple - types of "points" (multiple "grids") (e.g. for UxDataSets, these are "face_lon", - "face_lat", "node_lon", "node_lat", "edge_lon", "edge_lat"). Each (Ux)DataArray is - registered to a specific set of points on the mesh. - - For UxDataset objects, each `UXDataArray.attributes` field dictionary contains - the necessary metadata to help determine which set of points a field is registered - to and what parent model the field is associated with. Parcels uses this metadata - during execution for interpolation. Each `UXDataArray.attributes` field dictionary - must have: - * "location" key set to "face", "node", or "edge" to define which pairing of points a field is associated with. - * "mesh" key to define which parent model the fields are associated with (e.g. "fesom_mesh", "icon_mesh") - - """ - - def __init__(self, ds: xr.Dataset | ux.UxDataset): - self.ds = ds - - self._completed: bool = False - # Create pointers to each (Ux)DataArray - for field in self.ds.data_vars: - setattr(self, field, XField(field,self.ds[field])) - - self._add_UVfield() - - def __repr__(self): - return fieldset_repr(self) - - # @property - # def particlefile(self): - # return self._particlefile - - # @staticmethod - # def checkvaliddimensionsdict(dims): - # for d in dims: - # if d not in ["lon", "lat", "depth", "time"]: - # raise NameError(f"{d} is not a valid key in the dimensions dictionary") - - def add_field(self, field: XField, name: str | None = None): - """Add a :class:`parcels.field.Field` object to the FieldSet. - - Parameters - ---------- - field : parcels.field.Field - Field object to be added - name : str - Name of the :class:`parcels.field.Field` object to be added. Defaults - to name in Field object. - - - Examples - -------- - For usage examples see the following tutorials: - - * `Unit converters <../examples/tutorial_unitconverters.ipynb>`__ (Default value = None) - - """ - if self._completed: - raise RuntimeError( - "FieldSet has already been completed. 
Are you trying to add a Field after you've created the ParticleSet?" - ) - name = field.name if name is None else name - - if hasattr(self, name): # check if Field with same name already exists when adding new Field - raise RuntimeError(f"FieldSet already has a Field with name '{name}'") - else: - setattr(self, name, field) - - def add_constant_field(self, name: str, value: float, mesh: Mesh = "flat"): - """Wrapper function to add a Field that is constant in space, - useful e.g. when using constant horizontal diffusivity - - Parameters - ---------- - name : str - Name of the :class:`parcels.field.Field` object to be added - value : float - Value of the constant field (stored as 32-bit float) - mesh : str - String indicating the type of mesh coordinates and - units used during velocity interpolation, see also `this tutorial <../examples/tutorial_unitconverters.ipynb>`__: - - 1. spherical (default): Lat and lon in degree, with a - correction for zonal velocity U near the poles. - 2. flat: No conversion, lat/lon are assumed to be in m. - """ - import pandas as pd - - time = pd.to_datetime(['2000-01-01']) - values = np.zeros((1,1,1,1), dtype=np.float32) + value - data = xr.DataArray( - data=values, - name=name, - dims='null', - coords = [time,[0],[0],[0]], - attrs=dict( - description="null", - units="null", - location="node", - mesh=f"constant", - mesh_type=mesh - )) - self.add_field( - XField( - name, - data, - interp_method=None, # To do : Need to define an interpolation method for constants - allow_time_extrapolation=True - ) - ) - - def add_vector_field(self, vfield): - """Add a :class:`parcels.field.VectorField` object to the FieldSet. - - Parameters - ---------- - vfield : parcels.XVectorField - class:`parcels.xfieldset.XVectorField` object to be added - """ - setattr(self, vfield.name, vfield) - for v in vfield.__dict__.values(): - if isinstance(v, XField) and (v not in self.get_fields()): - self.add_field(v) - - def get_fields(self) -> list[XField | XVectorField]: - """Returns a list of all the :class:`parcels.field.Field` and :class:`parcels.field.VectorField` - objects associated with this FieldSet. 
- """ - fields = [] - for v in self.__dict__.values(): - if type(v) in [XField, XVectorField]: - if v not in fields: - fields.append(v) - return fields - - def _add_UVfield(self): - if not hasattr(self, "UV") and hasattr(self, "u") and hasattr(self, "v"): - self.add_xvector_field(XVectorField("UV", self.u, self.v)) - if not hasattr(self, "UVW") and hasattr(self, "w"): - self.add_xvector_field(XVectorField("UVW", self.u, self.v, self.w)) - - def _check_complete(self): - assert self.u, 'FieldSet does not have a Field named "u"' - assert self.v, 'FieldSet does not have a Field named "v"' - for attr, value in vars(self).items(): - if type(value) is XField: - assert value.name == attr, f"Field {value.name}.name ({attr}) is not consistent" - - self._add_UVfield() - - self._completed = True - - @classmethod - def _parse_wildcards(cls, paths, filenames, var): - if not isinstance(paths, list): - paths = sorted(glob(str(paths))) - if len(paths) == 0: - notfound_paths = filenames[var] if isinstance(filenames, dict) and var in filenames else filenames - raise OSError(f"FieldSet files not found for variable {var}: {notfound_paths}") - for fp in paths: - if not os.path.exists(fp): - raise OSError(f"FieldSet file not found: {fp}") - return paths - - # @classmethod - # def from_netcdf( - # cls, - # filenames, - # variables, - # dimensions, - # fieldtype=None, - # mesh: Mesh = "spherical", - # allow_time_extrapolation: bool | None = None, - # **kwargs, - # ): - - # @classmethod - # def from_nemo( - # cls, - # filenames, - # variables, - # dimensions, - # mesh: Mesh = "spherical", - # allow_time_extrapolation: bool | None = None, - # tracer_interp_method: InterpMethodOption = "cgrid_tracer", - # **kwargs, - # ): - - # @classmethod - # def from_mitgcm( - # cls, - # filenames, - # variables, - # dimensions, - # mesh: Mesh = "spherical", - # allow_time_extrapolation: bool | None = None, - # tracer_interp_method: InterpMethodOption = "cgrid_tracer", - # **kwargs, - # ): - - # @classmethod - # def from_croco( - # cls, - # filenames, - # variables, - # dimensions, - # hc: float | None = None, - # mesh="spherical", - # allow_time_extrapolation=None, - # tracer_interp_method="cgrid_tracer", - # **kwargs, - # ): - - # @classmethod - # def from_c_grid_dataset( - # cls, - # filenames, - # variables, - # dimensions, - # mesh: Mesh = "spherical", - # allow_time_extrapolation: bool | None = None, - # tracer_interp_method: InterpMethodOption = "cgrid_tracer", - # gridindexingtype: GridIndexingType = "nemo", - # **kwargs, - # ): - - - # @classmethod - # def from_mom5( - # cls, - # filenames, - # variables, - # dimensions, - # mesh: Mesh = "spherical", - # allow_time_extrapolation: bool | None = None, - # tracer_interp_method: InterpMethodOption = "bgrid_tracer", - # **kwargs, - # ): - - # @classmethod - # def from_a_grid_dataset(cls, filenames, variables, dimensions, **kwargs): - - # @classmethod - # def from_b_grid_dataset( - # cls, - # filenames, - # variables, - # dimensions, - # mesh: Mesh = "spherical", - # allow_time_extrapolation: bool | None = None, - # tracer_interp_method: InterpMethodOption = "bgrid_tracer", - # **kwargs, - # ): - - def add_constant(self, name, value): - """Add a constant to the FieldSet. Note that all constants are - stored as 32-bit floats. 
- - Parameters - ---------- - name : str - Name of the constant - value : - Value of the constant (stored as 32-bit float) - - - Examples - -------- - Tutorials using fieldset.add_constant: - `Analytical advection <../examples/tutorial_analyticaladvection.ipynb>`__ - `Diffusion <../examples/tutorial_diffusion.ipynb>`__ - `Periodic boundaries <../examples/tutorial_periodic_boundaries.ipynb>`__ - """ - setattr(self, name, value) - - # def computeTimeChunk(self, time=0.0, dt=1): - # """Load a chunk of three data time steps into the FieldSet. - # This is used when FieldSet uses data imported from netcdf, - # with default option deferred_load. The loaded time steps are at or immediatly before time - # and the two time steps immediately following time if dt is positive (and inversely for negative dt) - - # Parameters - # ---------- - # time : - # Time around which the FieldSet data are to be loaded. - # Time is provided as a double, relatively to Fieldset.time_origin. - # Default is 0. - # dt : - # time step of the integration scheme, needed to set the direction of time chunk loading. - # Default is 1. - # """ - # nextTime = np.inf if dt > 0 else -np.inf - - # if abs(nextTime) == np.inf or np.isnan(nextTime): # Second happens when dt=0 - # return nextTime - # else: - # nSteps = int((nextTime - time) / dt) - # if nSteps == 0: - # return nextTime - # else: - # return time + nSteps * dt From 4465dc42f02437965982a3e238a548676a9aecc9 Mon Sep 17 00:00:00 2001 From: Joe Date: Thu, 27 Mar 2025 11:55:27 -0400 Subject: [PATCH 10/46] Enforce time as a datetime object in field.eval call-stack --- parcels/field.py | 67 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 46 insertions(+), 21 deletions(-) diff --git a/parcels/field.py b/parcels/field.py index 988055cd68..873cded96a 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -8,6 +8,7 @@ import xarray as xr import uxarray as ux from uxarray.grid.neighbors import _barycentric_coordinates +from datetime import datetime import parcels.tools.interpolation_utils as i_u @@ -49,10 +50,15 @@ if TYPE_CHECKING: import numpy.typing as npt - from parcels.xfieldset import XFieldSet + from parcels.fieldset import FieldSet -__all__ = ["XField", "XVectorField"] +__all__ = ["Field", "VectorField", "GridType"] +class GridType(IntEnum): + RectilinearZGrid = 0 + RectilinearSGrid = 1 + CurvilinearZGrid = 2 + CurvilinearSGrid = 3 def _isParticle(key): if hasattr(key, "obs_written"): @@ -76,9 +82,9 @@ def _deal_with_errors(error, key, vector_type: VectorType): else: return 0 -class XField: - """The XField class that holds scalar field data. - The `XField` object is a wrapper around a xarray.DataArray or uxarray.UxDataArray object. +class Field: + """The Field class that holds scalar field data. + The `Field` object is a wrapper around a xarray.DataArray or uxarray.UxDataArray object. Additionally, it holds a dynamic Callable procedure that is used to interpolate the field data. During initialization, the user can supply a custom interpolation method that is used to interpolate the field data, so long as the interpolation method has the correct signature. 
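# Hedged illustration: a user-supplied interpolator must match the signature
# enforced by _validate_interp_function exactly (same parameter names, same
# order; a later patch in this series adds a `tau` parameter). The body below
# is a hypothetical placeholder; a real interpolator would index into the
# field's DataArray using ti/ei and the barycentric weights.
import numpy as np

def nearest_node_interp(ti: int, ei: int, bcoords: np.ndarray,
                        t: np.float64, z: np.float64,
                        y: np.float64, x: np.float64) -> np.float64:
    # Nearest-node flavour: follow the largest barycentric weight.
    # (Placeholder: returns the weight itself rather than field data.)
    return np.float64(bcoords.max())

# A user would attach it via the validated setter, e.g.:
#     field.interp_method = nearest_node_interp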
@@ -194,8 +200,32 @@ def __init__( if type(self.data) is ux.UxDataArray: self._spatialhash = self.data.uxgrid.get_spatial_hash() + self._gtype = None else: self._spatialhash = None + # Set the grid type + if "x_g" in self.data.coords : + lon = self.data.x_g + else: + lon = self.data.x_c + + if "nz1" in self.data.coords : + depth = self.data.nz1 + elif "nz" in self.data.coords : + depth = self.data.nz + else : + depth = None + + if len(lon.shape) <= 1: + if depth is None or len(depth.shape) <=1: + self._gtype = GridType.RectilinearZGrid + else: + self._gtype = GridType.RectilinearSGrid + else: + if depth is None or len(depth.shape) <=1: + self._gtype = GridType.CurvilinearZGrid + else: + self._gtype = GridType.CurvilinearSGrid def __repr__(self): return field_repr(self) @@ -353,23 +383,18 @@ def _search_indices_unstructured(self, z, y, x, ei=None, search2D=False): def _search_indices_structured(self, z, y, x, ei=None, search2D=False): - # To do, determine grid type from xarray.coords shapes - # Rectilinear uses 1-D array for lat and lon - # Curvilinear uses 2-D array for lat and lon - if self.grid._gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]: + if self._gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]: (zeta, eta, xsi, zi, yi, xi) = _search_indices_rectilinear( - self, z, y, x,particle=particle, search2D=search2D + self, z, y, x,ei=ei, search2D=search2D ) else: (zeta, eta, xsi, zi, yi, xi) = _search_indices_curvilinear( self, z, y, x, ei=ei, search2D=search2D ) - # To do : Calcualte barycentric coordinates from zeta, eta, xsi - return (zeta, eta, xsi, zi, yi, xi) - def _search_indices(self, time, z, y, x, ei=None, search2D=False): + def _search_indices(self, time: datetime, z, y, x, ei=None, search2D=False): tau, ti = self._search_time_index(time) # To do : Need to implement this method @@ -379,7 +404,7 @@ def _search_indices(self, time, z, y, x, ei=None, search2D=False): bcoords, ei = self._search_indices_structured(z, y, x, ei=ei, search2D=search2D) # To do : Need to implement this method return bcoords, ei, ti - def _interpolate(self, time, z, y, x, ei): + def _interpolate(self, time: datetime, z, y, x, ei): try: bcoords, _ei, ti = self._search_indices(time, z, y, x, ei=ei) @@ -413,7 +438,7 @@ def __getitem__(self, key): except tuple(AllParcelsErrorCodes.keys()) as error: return _deal_with_errors(error, key, vector_type=None) - def eval(self, time, z, y, x, ei=None, applyConversion=True): + def eval(self, time: datetime, z, y, x, ei=None, applyConversion=True): """Interpolate field values in space and time. 
We interpolate linearly in time and apply implicit unit @@ -542,11 +567,11 @@ def __contains__(self, key: str): return key in self.data -class XVectorField: - """XVectorField class that holds vector field data needed to execute particles.""" +class VectorField: + """VectorField class that holds vector field data needed to execute particles.""" - @staticmethod + @staticmethod def _vector_interp_template( self, ti: int, @@ -584,9 +609,9 @@ def _validate_vector_interp_function(self, func: Callable): def __init__( self, name: str, - U: XField, - V: XField, - W: XField | None = None, + U: Field, + V: Field, + W: Field | None = None, vector_interp_method: Callable | None = None ): From de2ab3f78cca5b2ff74fca2d029df76c2723af30 Mon Sep 17 00:00:00 2001 From: Joe Date: Thu, 27 Mar 2025 11:55:44 -0400 Subject: [PATCH 11/46] Set the time_origin to the minimum time dimension in the fieldset --- parcels/fieldset.py | 40 +++++++++++++++++++++++----------------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/parcels/fieldset.py b/parcels/fieldset.py index 3b62ad9a0d..af4388428f 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -7,7 +7,7 @@ import numpy as np from parcels._typing import GridIndexingType, InterpMethodOption, Mesh -from parcels.xfield import XField, XVectorField +from parcels.field import Field, VectorField from parcels.particlefile import ParticleFile from parcels.tools._helpers import fieldset_repr, default_repr from parcels.tools.converters import TimeConverter @@ -16,11 +16,11 @@ import xarray as xr import uxarray as ux -__all__ = ["XFieldSet"] +__all__ = ["FieldSet"] -class XFieldSet: - """XFieldSet class that holds hydrodynamic data needed to execute particles. +class FieldSet: + """FieldSet class that holds hydrodynamic data needed to execute particles. Parameters ---------- @@ -55,10 +55,17 @@ def __init__(self, ds: xr.Dataset | ux.UxDataset): self._completed: bool = False # Create pointers to each (Ux)DataArray for field in self.ds.data_vars: - setattr(self, field, XField(field,self.ds[field])) + setattr(self, field, Field(field,self.ds[field])) + # To do : Set the "time_origin" as a datetime object that is the minimum `time` in all of the Field objects + + if "time" in self.ds.coords: + self.time_origin = self.ds.time.min().data + else: + raise ValueError("FieldSet must have a 'time' coordinate") self._add_UVfield() + def __repr__(self): return fieldset_repr(self) @@ -72,7 +79,7 @@ def __repr__(self): # if d not in ["lon", "lat", "depth", "time"]: # raise NameError(f"{d} is not a valid key in the dimensions dictionary") - def add_field(self, field: XField, name: str | None = None): + def add_field(self, field: Field, name: str | None = None): """Add a :class:`parcels.field.Field` object to the FieldSet. Parameters @@ -120,9 +127,8 @@ def add_constant_field(self, name: str, value: float, mesh: Mesh = "flat"): correction for zonal velocity U near the poles. 2. flat: No conversion, lat/lon are assumed to be in m. 
""" - import pandas as pd - time = pd.to_datetime(['2000-01-01']) + time = 0.0 values = np.zeros((1,1,1,1), dtype=np.float32) + value data = xr.DataArray( data=values, @@ -137,7 +143,7 @@ def add_constant_field(self, name: str, value: float, mesh: Mesh = "flat"): mesh_type=mesh )) self.add_field( - XField( + Field( name, data, interp_method=None, # To do : Need to define an interpolation method for constants @@ -150,36 +156,36 @@ def add_vector_field(self, vfield): Parameters ---------- - vfield : parcels.XVectorField - class:`parcels.xfieldset.XVectorField` object to be added + vfield : parcels.VectorField + class:`parcels.FieldSet.VectorField` object to be added """ setattr(self, vfield.name, vfield) for v in vfield.__dict__.values(): - if isinstance(v, XField) and (v not in self.get_fields()): + if isinstance(v, Field) and (v not in self.get_fields()): self.add_field(v) - def get_fields(self) -> list[XField | XVectorField]: + def get_fields(self) -> list[Field | VectorField]: """Returns a list of all the :class:`parcels.field.Field` and :class:`parcels.field.VectorField` objects associated with this FieldSet. """ fields = [] for v in self.__dict__.values(): - if type(v) in [XField, XVectorField]: + if type(v) in [Field, VectorField]: if v not in fields: fields.append(v) return fields def _add_UVfield(self): if not hasattr(self, "UV") and hasattr(self, "u") and hasattr(self, "v"): - self.add_xvector_field(XVectorField("UV", self.u, self.v)) + self.add_Vector_field(VectorField("UV", self.u, self.v)) if not hasattr(self, "UVW") and hasattr(self, "w"): - self.add_xvector_field(XVectorField("UVW", self.u, self.v, self.w)) + self.add_Vector_field(VectorField("UVW", self.u, self.v, self.w)) def _check_complete(self): assert self.u, 'FieldSet does not have a Field named "u"' assert self.v, 'FieldSet does not have a Field named "v"' for attr, value in vars(self).items(): - if type(value) is XField: + if type(value) is Field: assert value.name == attr, f"Field {value.name}.name ({attr}) is not consistent" self._add_UVfield() From 737d2589d0c044d725134040f83f33afd69792db Mon Sep 17 00:00:00 2001 From: Joe Date: Thu, 27 Mar 2025 12:15:45 -0400 Subject: [PATCH 12/46] Port over rectilinear search. 
--- parcels/_index_search.py | 53 ++++++++++++++++++++-------------------- parcels/field.py | 52 ++++++++++++++++++++++++++------------- 2 files changed, 62 insertions(+), 43 deletions(-) diff --git a/parcels/_index_search.py b/parcels/_index_search.py index 28a40c34f9..5bacf2f8e5 100644 --- a/parcels/_index_search.py +++ b/parcels/_index_search.py @@ -99,6 +99,7 @@ def search_indices_vertical_z(depth, gridindexingtype: GridIndexingType, z: floa return (zi, zeta) +## joe@fluidnumerics.com : 3/27/25 : To do : Still need to implement the search_indices_vertical_s function def search_indices_vertical_s( grid: Grid, interp_method: InterpMethodOption, @@ -175,32 +176,31 @@ def search_indices_vertical_s( def _search_indices_rectilinear( - field: Field, time: float, z: float, y: float, x: float, ti: int, particle=None, search2D=False + field: Field, time: datetime, z: float, y: float, x: float, ti: int, ei:int=None, search2D=False ): - grid = field.grid - - if grid.xdim > 1 and (not grid.zonal_periodic): - if x < grid.lonlat_minmax[0] or x > grid.lonlat_minmax[1]: + # To do : If ei is provided, check if particle is in the same cell + if field.xdim > 1 and (not field.zonal_periodic): + if x < grid.lonlat_minmax[0] or x > grid.lonlat_minmax[1]: # To do : implement lonlat_minmax at field level _raise_field_out_of_bound_error(z, y, x) - if grid.ydim > 1 and (y < grid.lonlat_minmax[2] or y > grid.lonlat_minmax[3]): + if field.ydim > 1 and (y < grid.lonlat_minmax[2] or y > grid.lonlat_minmax[3]): # To do : implement lonlat_minmax at field level _raise_field_out_of_bound_error(z, y, x) - if grid.xdim > 1: - if grid.mesh != "spherical": - lon_index = grid.lon < x + if field.xdim > 1: + if field._mesh_type != "spherical": + lon_index = field.lon < x if lon_index.all(): - xi = len(grid.lon) - 2 + xi = len(field.lon) - 2 else: xi = lon_index.argmin() - 1 if lon_index.any() else 0 - xsi = (x - grid.lon[xi]) / (grid.lon[xi + 1] - grid.lon[xi]) + xsi = (x - field.lon[xi]) / (field.lon[xi + 1] - field.lon[xi]) if xsi < 0: xi -= 1 - xsi = (x - grid.lon[xi]) / (grid.lon[xi + 1] - grid.lon[xi]) + xsi = (x - field.lon[xi]) / (field.lon[xi + 1] - field.lon[xi]) elif xsi > 1: xi += 1 - xsi = (x - grid.lon[xi]) / (grid.lon[xi + 1] - grid.lon[xi]) + xsi = (x - field.lon[xi]) / (field.lon[xi + 1] - field.lon[xi]) else: - lon_fixed = grid.lon.copy() + lon_fixed = field.lon.copy() indices = lon_fixed >= lon_fixed[0] if not indices.all(): lon_fixed[indices.argmin() :] += 360 @@ -222,32 +222,33 @@ def _search_indices_rectilinear( else: xi, xsi = -1, 0 - if grid.ydim > 1: - lat_index = grid.lat < y + if field.ydim > 1: + lat_index = field.lat < y if lat_index.all(): - yi = len(grid.lat) - 2 + yi = len(field.lat) - 2 else: yi = lat_index.argmin() - 1 if lat_index.any() else 0 - eta = (y - grid.lat[yi]) / (grid.lat[yi + 1] - grid.lat[yi]) + eta = (y - field.lat[yi]) / (field.lat[yi + 1] - field.lat[yi]) if eta < 0: yi -= 1 - eta = (y - grid.lat[yi]) / (grid.lat[yi + 1] - grid.lat[yi]) + eta = (y - field.lat[yi]) / (field.lat[yi + 1] - field.lat[yi]) elif eta > 1: yi += 1 - eta = (y - grid.lat[yi]) / (grid.lat[yi + 1] - grid.lat[yi]) + eta = (y - field.lat[yi]) / (field.lat[yi + 1] - field.lat[yi]) else: yi, eta = -1, 0 - if grid.zdim > 1 and not search2D: - if grid._gtype == GridType.RectilinearZGrid: + if field.zdim > 1 and not search2D: + if field._gtype == GridType.RectilinearZGrid: try: (zi, zeta) = search_indices_vertical_z(field.grid, field.gridindexingtype, z) except FieldOutOfBoundError: 
_raise_field_out_of_bound_error(z, y, x) except FieldOutOfBoundSurfaceError: _raise_field_out_of_bound_surface_error(z, y, x) - elif grid._gtype == GridType.RectilinearSGrid: + elif field._gtype == GridType.RectilinearSGrid: + ## joe@fluidnumerics.com : 3/27/25 : To do : Still need to implement the search_indices_vertical_s function (zi, zeta) = search_indices_vertical_s(field.grid, field.interp_method, time, z, y, x, ti, yi, xi, eta, xsi) else: zi, zeta = -1, 0 @@ -255,12 +256,12 @@ def _search_indices_rectilinear( if not ((0 <= xsi <= 1) and (0 <= eta <= 1) and (0 <= zeta <= 1)): _raise_field_sampling_error(z, y, x) - if particle: - particle.ei[field.igrid] = field.ravel_index(zi, yi, xi) + _ei = field.ravel_index(zi, yi, xi) - return (zeta, eta, xsi, zi, yi, xi) + return (zeta, eta, xsi, _ei) +## joe@fluidnumerics.com : 3/27/25 : To do : Still need to implement the search_indices_curvilinear def _search_indices_curvilinear(field: Field, time, z, y, x, ti, particle=None, search2D=False): if particle: zi, yi, xi = field.unravel_index(particle.ei) diff --git a/parcels/field.py b/parcels/field.py index 873cded96a..9b30e1e1ac 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -124,6 +124,7 @@ def _interp_template( ti: int, ei: int, bcoords: np.ndarray, + tau: Union[np.float32,np.float64], t: Union[np.float32,np.float64], z: Union[np.float32,np.float64], y: Union[np.float32,np.float64], @@ -134,7 +135,7 @@ def _interp_template( def _validate_interp_function(self, func: Callable): """Ensures that the function has the correct signature.""" - expected_params = ["ti", "ei", "bcoords", "t", "z", "y", "x"] + expected_params = ["ti", "ei", "bcoords", "tau", "t", "z", "y", "x"] expected_return_types = (np.float32,np.float64) sig = inspect.signature(func) @@ -289,7 +290,7 @@ def depth(self): return self.data.nz @property - def nx(self): + def xdim(self): if type(self.data) is xr.DataArray: if "face_lon" in self.data.dims: return self.data.sizes["face_lon"] @@ -298,7 +299,7 @@ def nx(self): else: return 0 # To do : Discuss what we want to return for uxdataarray obj @property - def ny(self): + def ydim(self): if type(self.data) is xr.DataArray: if "face_lat" in self.data.dims: return self.data.sizes["face_lat"] @@ -306,6 +307,16 @@ def ny(self): return self.data.sizes["node_lat"] else: return 0 # To do : Discuss what we want to return for uxdataarray obj + + @property + def zdim(self): + if "nz1" in self.data.dims: + return self.data.sizes["nz1"] + elif "nz" in self.data.dims: + return self.data.sizes["nz"] + else: + return 0 + @property def n_face(self): if type(self.data) is ux.uxDataArray: @@ -388,27 +399,34 @@ def _search_indices_structured(self, z, y, x, ei=None, search2D=False): self, z, y, x,ei=ei, search2D=search2D ) else: - (zeta, eta, xsi, zi, yi, xi) = _search_indices_curvilinear( - self, z, y, x, ei=ei, search2D=search2D - ) + ## joe@fluidnumerics.com : 3/27/25 : To do : Still need to implement the search_indices_curvilinear + # (zeta, eta, xsi, zi, yi, xi) = _search_indices_curvilinear( + # self, z, y, x, ei=ei, search2D=search2D + # ) + raise NotImplementedError("Curvilinear grid search not implemented yet") return (zeta, eta, xsi, zi, yi, xi) def _search_indices(self, time: datetime, z, y, x, ei=None, search2D=False): - tau, ti = self._search_time_index(time) # To do : Need to implement this method + tau, ti = _search_time_index(self,time,self.allow_time_extrapolation) + + if ei is None: + _ei = None + else: + _ei = ei[self.igrid] if type(self.data) is ux.UxDataArray: - bcoords, ei = 
self._search_indices_unstructured(z, y, x, ei=ei, search2D=search2D) # To do : Need to implement this method + bcoords, ei = self._search_indices_unstructured(z, y, x, ei=_ei, search2D=search2D) else: - bcoords, ei = self._search_indices_structured(z, y, x, ei=ei, search2D=search2D) # To do : Need to implement this method - return bcoords, ei, ti + bcoords, ei = self._search_indices_structured(z, y, x, ei=_ei, search2D=search2D) + return bcoords, ei, tau, ti def _interpolate(self, time: datetime, z, y, x, ei): try: - bcoords, _ei, ti = self._search_indices(time, z, y, x, ei=ei) - val = self._interp_method(ti, _ei, bcoords, time, z, y, x) + bcoords, _ei, tau, ti = self._search_indices(time, z, y, x, ei=ei) + val = self._interp_method(ti, _ei, bcoords, tau, time, z, y, x) if np.isnan(val): # Detect Out-of-bounds sampling and raise exception @@ -479,7 +497,7 @@ def ravel_index(self, zi, yi, xi): flat index """ if type(self.data) is xr.DataArray: - return xi + self.nx * (yi + self.ny * zi) + return xi + self.xdim * (yi + self.ydim * zi) else: return xi + self.n_face*zi @@ -503,10 +521,10 @@ def unravel_index(self, ei): """ if type(self.data) is xr.DataArray: _ei = ei[self.igrid] - zi = _ei // (self.nx * self.ny) - _ei = _ei % (self.nx * self.ny) - yi = _ei // self.nx - xi = _ei % self.nx + zi = _ei // (self.xdim * self.ydim) + _ei = _ei % (self.xdim * self.ydim) + yi = _ei // self.xdim + xi = _ei % self.xdim return zi, yi, xi else: _ei = ei[self.igrid] From 6b9fe524be8e0339fe9c67c5619627ec574eac8b Mon Sep 17 00:00:00 2001 From: Joe Date: Mon, 31 Mar 2025 09:34:04 -0400 Subject: [PATCH 13/46] Change vector field component names to uppercase --- parcels/fieldset.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/parcels/fieldset.py b/parcels/fieldset.py index af4388428f..45c6752292 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -176,14 +176,14 @@ def get_fields(self) -> list[Field | VectorField]: return fields def _add_UVfield(self): - if not hasattr(self, "UV") and hasattr(self, "u") and hasattr(self, "v"): - self.add_Vector_field(VectorField("UV", self.u, self.v)) - if not hasattr(self, "UVW") and hasattr(self, "w"): - self.add_Vector_field(VectorField("UVW", self.u, self.v, self.w)) + if not hasattr(self, "UV") and hasattr(self, "U") and hasattr(self, "V"): + self.add_Vector_field(VectorField("UV", self.U, self.V)) + if not hasattr(self, "UVW") and hasattr(self, "W"): + self.add_Vector_field(VectorField("UVW", self.U, self.V, self.W)) def _check_complete(self): - assert self.u, 'FieldSet does not have a Field named "u"' - assert self.v, 'FieldSet does not have a Field named "v"' + assert self.U, 'FieldSet does not have a Field named "U"' + assert self.V, 'FieldSet does not have a Field named "V"' for attr, value in vars(self).items(): if type(value) is Field: assert value.name == attr, f"Field {value.name}.name ({attr}) is not consistent" From 5e73050e678eda359387a600e5f0591b26ae2ea0 Mon Sep 17 00:00:00 2001 From: Joe Date: Mon, 31 Mar 2025 10:07:57 -0400 Subject: [PATCH 14/46] Add simple tests --- v4-tests/test_uxarray_fieldset.py | 27 +++++++++++++++++++++++++++ v4-tests/test_xarray_fieldset.py | 0 2 files changed, 27 insertions(+) create mode 100644 v4-tests/test_uxarray_fieldset.py create mode 100644 v4-tests/test_xarray_fieldset.py diff --git a/v4-tests/test_uxarray_fieldset.py b/v4-tests/test_uxarray_fieldset.py new file mode 100644 index 0000000000..f1bdd7df30 --- /dev/null +++ b/v4-tests/test_uxarray_fieldset.py @@ -0,0 +1,27 @@ + +import 
uxarray as ux
+from datetime import timedelta
+from parcels import (
+    FieldSet,
+    ParticleSet,
+    Particle,
+)
+import os
+
+# Get path of this script
+V4_TEST_DATA = f"{os.path.dirname(__file__)}/test_data"
+
+def test_fesom_fieldset():
+    # Load a FESOM dataset
+    grid_path = f"{V4_TEST_DATA}/fesom_channel.nc"
+    data_path = [
+        f"{V4_TEST_DATA}/u.fesom_channel.nc",
+        f"{V4_TEST_DATA}/v.fesom_channel.nc",
+        f"{V4_TEST_DATA}/w.fesom_channel.nc",
+    ]
+    ds = ux.open_mfdataset(grid_path, data_path)
+    ds = ds.rename_vars({"u": "U", "v": "V", "w": "W"})
+    fieldset = FieldSet(ds)
+    fieldset._check_complete()
+    # Check that the fieldset has the expected properties
+    assert fieldset.ds == ds
diff --git a/v4-tests/test_xarray_fieldset.py b/v4-tests/test_xarray_fieldset.py
new file mode 100644
index 0000000000..e69de29bb2
From 6e7c60a51e2717d24502ec0c07cba4cf6e761aad Mon Sep 17 00:00:00 2001
From: Joe
Date: Mon, 31 Mar 2025 10:08:20 -0400
Subject: [PATCH 15/46] Add uxarray to environment

---
 environment.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/environment.yml b/environment.yml
index 82f505f8e2..e51cd059fb 100644
--- a/environment.yml
+++ b/environment.yml
@@ -18,6 +18,7 @@ dependencies: #! Keep in sync with [tool.pixi.dependencies] in pyproject.toml
   - dask>=2.0
   - scikit-learn
   - zarr>=2.11.0,!=2.18.0,<3
+  - uxarray>=2025.3.0

 # Notebooks
 - trajan
From 7b5aba1135bbd129e2a8b743414dd0e3800f00c2 Mon Sep 17 00:00:00 2001
From: Joe
Date: Mon, 31 Mar 2025 10:08:57 -0400
Subject: [PATCH 16/46] Resolve bugs for simple fesom fieldset loading

---
 parcels/field.py    | 7 ++++---
 parcels/fieldset.py | 4 ++--
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/parcels/field.py b/parcels/field.py
index 9b30e1e1ac..cb77dc854d 100644
--- a/parcels/field.py
+++ b/parcels/field.py
@@ -44,6 +44,7 @@
 from parcels.tools.warnings import FieldSetWarning
 import inspect
 from typing import Callable, Union
+from enum import IntEnum
 
 from ._index_search import _search_indices_curvilinear, _search_indices_rectilinear, _search_time_index
@@ -166,11 +167,11 @@ def __init__(
         self.name = name
         self.data = data
 
-        self._validate_dataarray(data)
+        self._validate_dataarray()
 
-        self._parent_mesh = data.attributes["mesh"]
+        self._parent_mesh = data.attrs["mesh"]
         self._mesh_type = mesh_type
-        self._location = data.attributes["location"]
+        self._location = data.attrs["location"]
 
         # Set the vertical location
         if "nz1" in data.dims:
diff --git a/parcels/fieldset.py b/parcels/fieldset.py
index 45c6752292..1ebf688303 100644
--- a/parcels/fieldset.py
+++ b/parcels/fieldset.py
@@ -177,9 +177,9 @@ def get_fields(self) -> list[Field | VectorField]:
         return fields
 
     def _add_UVfield(self):
         if not hasattr(self, "UV") and hasattr(self, "U") and hasattr(self, "V"):
-            self.add_Vector_field(VectorField("UV", self.U, self.V))
+            self.add_vector_field(VectorField("UV", self.U, self.V))
         if not hasattr(self, "UVW") and hasattr(self, "W"):
-            self.add_Vector_field(VectorField("UVW", self.U, self.V, self.W))
+            self.add_vector_field(VectorField("UVW", self.U, self.V, self.W))
 
     def _check_complete(self):
         assert self.U, 'FieldSet does not have a Field named "U"'
From 812875ad77099074b34d0f34c66b072c780e531d Mon Sep 17 00:00:00 2001
From: Joe
Date: Mon, 31 Mar 2025 10:52:29 -0400
Subject: [PATCH 17/46] Add dimrange function and gridset_size property to
 fieldset

The fieldset.dimrange function is meant to be a replacement for
fieldset.grid.dimrange. Functionally, it operates in the same way but
relies on the fields underneath the (u)xarray.(Ux)dataset.
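As a rough usage sketch (file names are illustrative, and this assumes a
UGRID-compliant UxDataset with a "time" coordinate, as in the v4 tests):

    import uxarray as ux
    from parcels import FieldSet

    # hypothetical FESOM files; any dataset with the expected dims works
    ds = ux.open_mfdataset("fesom.mesh.nc", ["u.nc", "v.nc", "w.nc"])
    fieldset = FieldSet(ds)

    tmin, tmax = fieldset.dimrange("time")    # time range shared by all fields
    mindepth = fieldset.dimrange("depth")[0]  # e.g. a default particle release depth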
Currently, this is incompatible with a fieldset constructed using the `fieldset.add_field` function, since add_field does not update the `ds`. Need to think on how we want to track this. The `fieldset.gridset_size` property is set to be equivalent to the number of fields (dataarrays) stored under the fieldset. This property is meant to replace the `fieldset.gridset.size` property. --- parcels/fieldset.py | 37 +++++++++++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/parcels/fieldset.py b/parcels/fieldset.py index 1ebf688303..5aed3d8964 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -56,8 +56,9 @@ def __init__(self, ds: xr.Dataset | ux.UxDataset): # Create pointers to each (Ux)DataArray for field in self.ds.data_vars: setattr(self, field, Field(field,self.ds[field])) - # To do : Set the "time_origin" as a datetime object that is the minimum `time` in all of the Field objects + self._gridset_size = len(self.ds.data_vars) + if "time" in self.ds.coords: self.time_origin = self.ds.time.min().data else: @@ -69,6 +70,33 @@ def __init__(self, ds: xr.Dataset | ux.UxDataset): def __repr__(self): return fieldset_repr(self) + def dimrange(self, dim): + """Returns maximum value of a dimension (lon, lat, depth or time) + on 'left' side and minimum value on 'right' side for all grids + in a gridset. Useful for finding e.g. longitude range that + overlaps on all grids in a gridset. + """ + maxleft, minright = (-np.inf, np.inf) + dim2ds = { + "depth": ["nz1","nz"], + "lat": ["node_lat", "face_lat", "edge_lat"], + "lon": ["node_lon", "face_lon", "edge_lon"], + "time": ["time"] + } + for field in self.ds.data_vars: + for d in dim2ds[dim]: # check all possible dimensions + if d in self.ds[field].dims: + if dim == "depth": + maxleft = max(maxleft, self.ds[field][d].min().data) + minright = min(minright, self.ds[field][d].max().data) + else: + maxleft = max(maxleft, self.ds[field][d].data[0]) + minright = min(minright, self.ds[field][d].data[-1]) + maxleft = 0 if maxleft == -np.inf else maxleft # if all len(dim) == 1 + minright = 0 if minright == np.inf else minright # if all len(dim) == 1 + + return maxleft, minright + # @property # def particlefile(self): # return self._particlefile @@ -78,7 +106,10 @@ def __repr__(self): # for d in dims: # if d not in ["lon", "lat", "depth", "time"]: # raise NameError(f"{d} is not a valid key in the dimensions dictionary") - + @property + def gridset_size(self): + return self._gridset_size + def add_field(self, field: Field, name: str | None = None): """Add a :class:`parcels.field.Field` object to the FieldSet. @@ -98,6 +129,7 @@ def add_field(self, field: Field, name: str | None = None): * `Unit converters <../examples/tutorial_unitconverters.ipynb>`__ (Default value = None) """ + if self._completed: raise RuntimeError( "FieldSet has already been completed. Are you trying to add a Field after you've created the ParticleSet?" 
@@ -108,6 +140,7 @@ def add_field(self, field: Field, name: str | None = None): raise RuntimeError(f"FieldSet already has a Field with name '{name}'") else: setattr(self, name, field) + self._gridset_size += 1 def add_constant_field(self, name: str, value: float, mesh: Mesh = "flat"): """Wrapper function to add a Field that is constant in space, From 3136dbf74dfca60feb0f63769f7757bd9af942b9 Mon Sep 17 00:00:00 2001 From: Joe Date: Mon, 31 Mar 2025 11:15:28 -0400 Subject: [PATCH 18/46] Add support for list of (U)xarray.(Ux)datasets in fieldset This commit allows us to define a fieldset using a list of UXarray.UxDatasets and Xarray.Datasets. Currently, because `fieldset.add_field` is used to add pointers to each field, the field names between the variables in each list item must be unique --- parcels/fieldset.py | 54 +++++++++++++++++++------------ v4-tests/test_uxarray_fieldset.py | 21 ++++++++++-- 2 files changed, 52 insertions(+), 23 deletions(-) diff --git a/parcels/fieldset.py b/parcels/fieldset.py index 5aed3d8964..6e8875e8bc 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -15,6 +15,7 @@ import xarray as xr import uxarray as ux +from typing import List, Union __all__ = ["FieldSet"] @@ -49,21 +50,30 @@ class FieldSet: """ - def __init__(self, ds: xr.Dataset | ux.UxDataset): - self.ds = ds + def __init__(self, datasets: List[Union[xr.Dataset,ux.UxDataset]]): + self.datasets = datasets self._completed: bool = False + self._gridset_size: int = 0 + self._fieldnames = [] + time_origin = None # Create pointers to each (Ux)DataArray - for field in self.ds.data_vars: - setattr(self, field, Field(field,self.ds[field])) - - self._gridset_size = len(self.ds.data_vars) - - if "time" in self.ds.coords: - self.time_origin = self.ds.time.min().data - else: - raise ValueError("FieldSet must have a 'time' coordinate") - + for ds in datasets: + for field in ds.data_vars: + self.add_field(Field(field, ds[field]),field) + self._gridset_size += 1 + self._fieldnames.append(field) + #setattr(self, field, Field(field,self.ds[field]), field) + + if "time" in ds.coords: + if time_origin is None: + time_origin = ds.time.min().data + else: + time_origin = min(time_origin, ds.time.min().data) + else: + raise ValueError("Each dataset must have a 'time' coordinate") + + self.time_origin = time_origin self._add_UVfield() @@ -83,15 +93,16 @@ def dimrange(self, dim): "lon": ["node_lon", "face_lon", "edge_lon"], "time": ["time"] } - for field in self.ds.data_vars: - for d in dim2ds[dim]: # check all possible dimensions - if d in self.ds[field].dims: - if dim == "depth": - maxleft = max(maxleft, self.ds[field][d].min().data) - minright = min(minright, self.ds[field][d].max().data) - else: - maxleft = max(maxleft, self.ds[field][d].data[0]) - minright = min(minright, self.ds[field][d].data[-1]) + for ds in self.datasets: + for field in ds.data_vars: + for d in dim2ds[dim]: # check all possible dimensions + if d in ds[field].dims: + if dim == "depth": + maxleft = max(maxleft, ds[field][d].min().data) + minright = min(minright, ds[field][d].max().data) + else: + maxleft = max(maxleft, ds[field][d].data[0]) + minright = min(minright, ds[field][d].data[-1]) maxleft = 0 if maxleft == -np.inf else maxleft # if all len(dim) == 1 minright = 0 if minright == np.inf else minright # if all len(dim) == 1 @@ -141,6 +152,7 @@ def add_field(self, field: Field, name: str | None = None): else: setattr(self, name, field) self._gridset_size += 1 + self._fieldnames.append(name) def add_constant_field(self, name: str, value: 
float, mesh: Mesh = "flat"): """Wrapper function to add a Field that is constant in space, diff --git a/v4-tests/test_uxarray_fieldset.py b/v4-tests/test_uxarray_fieldset.py index f1bdd7df30..014e72f143 100644 --- a/v4-tests/test_uxarray_fieldset.py +++ b/v4-tests/test_uxarray_fieldset.py @@ -11,6 +11,7 @@ # Get path of this script V4_TEST_DATA = f"{os.path.dirname(__file__)}/test_data" + def test_fesom_fieldset(): # Load a FESOM dataset grid_path = f"{V4_TEST_DATA}/fesom_channel.nc" @@ -21,7 +22,23 @@ def test_fesom_fieldset(): ] ds = ux.open_mfdataset(grid_path, data_path) ds = ds.rename_vars({"u": "U", "v": "V", "w": "W"}) - fieldset = FieldSet(ds) + fieldset = FieldSet([ds]) fieldset._check_complete() # Check that the fieldset has the expected properties - assert fieldset.ds == ds + assert fieldset.datasets[0] == ds + +def test_fesom_in_particleset(): + grid_path = f"{V4_TEST_DATA}/fesom_channel.nc" + data_path = [ + f"{V4_TEST_DATA}/u.fesom_channel.nc", + f"{V4_TEST_DATA}/v.fesom_channel.nc", + f"{V4_TEST_DATA}/w.fesom_channel.nc", + ] + ds = ux.open_mfdataset(grid_path, data_path) + ds = ds.rename_vars({"u": "U", "v": "V", "w": "W"}) + fieldset = FieldSet([ds]) + pset = ParticleSet(fieldset, pclass=Particle) + + +# if __name__ == "__main__": +# test_fesom_in_particleset() \ No newline at end of file From 2e0a1f8409156e794d10ddb04844840443190987 Mon Sep 17 00:00:00 2001 From: Joe Date: Mon, 31 Mar 2025 11:18:17 -0400 Subject: [PATCH 19/46] Switch to using updated fieldset api for gridset_size and dimrange --- parcels/particleset.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/parcels/particleset.py b/parcels/particleset.py index ec452f858b..563a17ba6d 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -122,7 +122,7 @@ def ArrayClass_init(self, *args, **kwargs): if type(self).ngrids.initial < 0: numgrids = ngrids if numgrids is None and fieldset is not None: - numgrids = fieldset.gridset.size + numgrids = fieldset.gridset_size assert numgrids is not None, "Neither fieldsets nor number of grids are specified - exiting." type(self).ngrids.initial = numgrids self.ngrids = type(self).ngrids.initial @@ -148,7 +148,7 @@ def ArrayClass_init(self, *args, **kwargs): pid_orig = np.arange(lon.size) if depth is None: - mindepth = self.fieldset.gridset.dimrange("depth")[0] + mindepth = self.fieldset.dimrange("depth")[0] depth = np.ones(lon.size) * mindepth else: depth = convert_to_flat_array(depth) @@ -191,7 +191,7 @@ def ArrayClass_init(self, *args, **kwargs): self._repeatkwargs = kwargs self._repeatkwargs.pop("partition_function", None) - ngrids = fieldset.gridset.size + ngrids = fieldset.gridset_size # Variables used for interaction kernels. 
inter_dist_horiz = None @@ -962,7 +962,7 @@ def execute( if runtime is not None and endtime is not None: raise RuntimeError("Only one of (endtime, runtime) can be specified") - mintime, maxtime = self.fieldset.gridset.dimrange("time") + mintime, maxtime = self.fieldset.dimrange("time") default_release_time = mintime if dt >= 0 else maxtime if np.any(np.isnan(self.particledata.data["time"])): @@ -980,7 +980,7 @@ def execute( if runtime is not None: endtime = starttime + runtime * np.sign(dt) elif endtime is None: - mintime, maxtime = self.fieldset.gridset.dimrange("time") + mintime, maxtime = self.fieldset.dimrange("time") endtime = maxtime if dt >= 0 else mintime if (abs(endtime - starttime) < 1e-5 or runtime == 0) and dt == 0: From a9b7f879d99739a20210bdea35162ebd46dadaff Mon Sep 17 00:00:00 2001 From: Joe Date: Mon, 31 Mar 2025 13:17:58 -0400 Subject: [PATCH 20/46] Add interpolation methods with test --- parcels/application_kernels/__init__.py | 1 + parcels/application_kernels/interpolation.py | 54 ++++++++++++++++++++ v4-tests/test_uxarray_fieldset.py | 24 ++++++++- 3 files changed, 77 insertions(+), 2 deletions(-) create mode 100644 parcels/application_kernels/interpolation.py diff --git a/parcels/application_kernels/__init__.py b/parcels/application_kernels/__init__.py index d308bcb1ec..b99e4f2544 100644 --- a/parcels/application_kernels/__init__.py +++ b/parcels/application_kernels/__init__.py @@ -1,3 +1,4 @@ from .advection import * from .advectiondiffusion import * from .interaction import * +from .interpolation import * \ No newline at end of file diff --git a/parcels/application_kernels/interpolation.py b/parcels/application_kernels/interpolation.py new file mode 100644 index 0000000000..2cc13a7347 --- /dev/null +++ b/parcels/application_kernels/interpolation.py @@ -0,0 +1,54 @@ +"""Collection of pre-built interpolation kernels.""" + +import math +from typing import Union +import numpy as np + +from parcels.tools.statuscodes import StatusCode +from parcels.field import Field + +__all__ = [ + "UXPiecewiseConstantFace", + "UXPiecewiseLinearNode", +] + +def UXPiecewiseConstantFace( + field: Field, + ti: int, + ei: int, + bcoords: np.ndarray, + tau: Union[np.float32,np.float64], + t: Union[np.float32,np.float64], + z: Union[np.float32,np.float64], + y: Union[np.float32,np.float64], + x: Union[np.float32,np.float64] + ): + """ + Piecewise constant interpolation kernel for face registered data. + This interpolation method is appropriate for fields that are + face registered, such as u,v in FESOM. + """ + # To do : handle vertical interpolation + zi, fi = field.unravel_index(ei) + return field.data[ti, zi, fi] + +def UXPiecewiseLinearNode( + field: Field, + ti: int, + ei: int, + bcoords: np.ndarray, + tau: Union[np.float32,np.float64], + t: Union[np.float32,np.float64], + z: Union[np.float32,np.float64], + y: Union[np.float32,np.float64], + x: Union[np.float32,np.float64] + ): + """ + Piecewise linear interpolation kernel for node registered data. This + interpolation method is appropriate for fields that are node registered + such as the vertical velocity w in FESOM. 
+ """ + # To do: handle vertical interpolation + zi, fi = field.unravel_index(ei) + node_ids = field.data.uxgrid.face_node_connectivity[fi,:] + return np.dot(field.data[ti, zi, node_ids],bcoords) diff --git a/v4-tests/test_uxarray_fieldset.py b/v4-tests/test_uxarray_fieldset.py index 014e72f143..f9a9cc2ae5 100644 --- a/v4-tests/test_uxarray_fieldset.py +++ b/v4-tests/test_uxarray_fieldset.py @@ -5,6 +5,8 @@ FieldSet, ParticleSet, Particle, + UXPiecewiseConstantFace, + UXPiecewiseLinearNode, ) import os @@ -40,5 +42,23 @@ def test_fesom_in_particleset(): pset = ParticleSet(fieldset, pclass=Particle) -# if __name__ == "__main__": -# test_fesom_in_particleset() \ No newline at end of file +def test_set_interp_methods(): + grid_path = f"{V4_TEST_DATA}/fesom_channel.nc" + data_path = [ + f"{V4_TEST_DATA}/u.fesom_channel.nc", + f"{V4_TEST_DATA}/v.fesom_channel.nc", + f"{V4_TEST_DATA}/w.fesom_channel.nc", + ] + ds = ux.open_mfdataset(grid_path, data_path) + ds = ds.rename_vars({"u": "U", "v": "V", "w": "W"}) + fieldset = FieldSet([ds]) + # Set the interpolation method for each field + fieldset.U.interp_method = UXPiecewiseConstantFace + fieldset.V.interp_method = UXPiecewiseConstantFace + fieldset.W.interp_method = UXPiecewiseLinearNode + + # pset = ParticleSet(fieldset, pclass=Particle) + # pset.execute(associate_interp_function, endtime=timedelta(days=1), dt=timedelta(hours=1)) + +if __name__ == "__main__": + test_set_interp_methods() \ No newline at end of file From 3c2488fa26d5a2a15fd8c53ac4b45da980346f76 Mon Sep 17 00:00:00 2001 From: Joe Date: Mon, 31 Mar 2025 13:24:53 -0400 Subject: [PATCH 21/46] Fix method to verify interpolation method call signature. --- parcels/field.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/parcels/field.py b/parcels/field.py index cb77dc854d..26cec8ac06 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -45,7 +45,6 @@ import inspect from typing import Callable, Union from enum import IntEnum - from ._index_search import _search_indices_curvilinear, _search_indices_rectilinear, _search_time_index if TYPE_CHECKING: @@ -134,26 +133,27 @@ def _interp_template( """ Template function used for the signature check of the lateral interpolation methods.""" return 0.0 - def _validate_interp_function(self, func: Callable): + def _validate_interp_function(self, func: Callable) -> bool: """Ensures that the function has the correct signature.""" - expected_params = ["ti", "ei", "bcoords", "tau", "t", "z", "y", "x"] - expected_return_types = (np.float32,np.float64) + template_sig = inspect.signature(self._interp_template) + func_sig = inspect.signature(func) - sig = inspect.signature(func) - params = list(sig.parameters.keys()) + if len(template_sig.parameters) != len(func_sig.parameters): + return False - # Check the parameter names and count - if params != expected_params: - raise TypeError( - f"Function must have parameters {expected_params}, but got {params}" - ) + for ((name1, param1), (name2, param2)) in zip(template_sig.parameters.items(), func_sig.parameters.items()): + if param1.kind != param2.kind: + return False + if param1.annotation != param2.annotation: + return False - # Check return annotation if present - return_annotation = sig.return_annotation - if return_annotation not in (inspect.Signature.empty, *expected_return_types): - raise TypeError( - f"Function must return a float, but got {return_annotation}" - ) + return_annotation = func_sig.return_annotation + template_return = 
template_sig.return_annotation + + if return_annotation != template_return: + return False + + return True def __init__( self, From 75f7916dc27908075f234d40d7faa8b3216cb848 Mon Sep 17 00:00:00 2001 From: Joe Date: Mon, 31 Mar 2025 13:30:21 -0400 Subject: [PATCH 22/46] Allow eval to accept particle, rather than ei This limits the changes required in advection.py. Still need to check that the VectorField.igrid value is set appropriately. --- parcels/field.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/parcels/field.py b/parcels/field.py index 26cec8ac06..69b6191853 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -457,17 +457,17 @@ def __getitem__(self, key): except tuple(AllParcelsErrorCodes.keys()) as error: return _deal_with_errors(error, key, vector_type=None) - def eval(self, time: datetime, z, y, x, ei=None, applyConversion=True): + def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True): """Interpolate field values in space and time. We interpolate linearly in time and apply implicit unit conversion to the result. Note that we defer to scipy.interpolate to perform spatial interpolation. """ - if ei is None: + if particle is None: _ei = None else: - _ei = ei[self.igrid] + _ei = particle.ei[self.igrid] value = self._interpolate(time, z, y, x, ei=_ei) @@ -683,7 +683,7 @@ def _interpolate(self, time, z, y, x, ei): u = self.U.eval(time, z, y, x, _ei, applyConversion=False) v = self.V.eval(time, z, y, x, _ei, applyConversion=False) if "3D" in self.vector_type: - w = self.W.eval(time, z, y, x, ei, applyConversion=False) + w = self.W.eval(time, z, y, x, _ei, applyConversion=False) return (u, v, w) else: return (u, v, 0) From f04d729e43708df94c9d6006ff678e31baf8a10f Mon Sep 17 00:00:00 2001 From: Joe Date: Tue, 1 Apr 2025 14:43:37 -0400 Subject: [PATCH 23/46] Move tests to tests/v4 --- {v4-tests => tests/v4}/test_uxarray_fieldset.py | 0 {v4-tests => tests/v4}/test_xarray_fieldset.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename {v4-tests => tests/v4}/test_uxarray_fieldset.py (100%) rename {v4-tests => tests/v4}/test_xarray_fieldset.py (100%) diff --git a/v4-tests/test_uxarray_fieldset.py b/tests/v4/test_uxarray_fieldset.py similarity index 100% rename from v4-tests/test_uxarray_fieldset.py rename to tests/v4/test_uxarray_fieldset.py diff --git a/v4-tests/test_xarray_fieldset.py b/tests/v4/test_xarray_fieldset.py similarity index 100% rename from v4-tests/test_xarray_fieldset.py rename to tests/v4/test_xarray_fieldset.py From 95c701c58538309c7d6287467dd7389bd95f5272 Mon Sep 17 00:00:00 2001 From: Joe Date: Wed, 2 Apr 2025 10:47:41 -0400 Subject: [PATCH 24/46] Formatting fixes --- parcels/_index_search.py | 65 ++--- parcels/application_kernels/__init__.py | 2 +- parcels/application_kernels/interpolation.py | 51 ++-- parcels/field.py | 254 ++++++++----------- parcels/fieldset.py | 83 +++--- tests/v4/test_uxarray_fieldset.py | 17 +- 6 files changed, 215 insertions(+), 257 deletions(-) diff --git a/parcels/_index_search.py b/parcels/_index_search.py index 5bacf2f8e5..60122ca682 100644 --- a/parcels/_index_search.py +++ b/parcels/_index_search.py @@ -1,9 +1,9 @@ from __future__ import annotations +from datetime import datetime from typing import TYPE_CHECKING import numpy as np -from datetime import datetime from parcels._typing import ( GridIndexingType, @@ -22,10 +22,10 @@ if TYPE_CHECKING: from .field import Field - #from .grid import Grid + # from .grid import Grid -def _search_time_index(field: Field, time: datetime , 
allow_time_extrapolation=True): +def _search_time_index(field: Field, time: datetime, allow_time_extrapolation=True): """Find and return the index and relative coordinate in the time array associated with a given time. Parameters @@ -52,12 +52,17 @@ def _search_time_index(field: Field, time: datetime , allow_time_extrapolation=T ti = 0 else: ti = int(time_index.argmin() - 1) if time_index.any() else 0 - if len(field.data.time)== 1: + if len(field.data.time) == 1: tau = 0 elif ti == len(field.data.time) - 1: tau = 1 else: - tau = (time - field.data.time[ti]).total_seconds() / (field.data.time[ti + 1] - field.data.time[ti]).total_seconds() if field.data.time[ti] != field.data.time[ti + 1] else 0 + tau = ( + (time - field.data.time[ti]).total_seconds() + / (field.data.time[ti + 1] - field.data.time[ti]).total_seconds() + if field.data.time[ti] != field.data.time[ti + 1] + else 0 + ) return tau, ti @@ -101,7 +106,7 @@ def search_indices_vertical_z(depth, gridindexingtype: GridIndexingType, z: floa ## joe@fluidnumerics.com : 3/27/25 : To do : Still need to implement the search_indices_vertical_s function def search_indices_vertical_s( - grid: Grid, + field: Field, interp_method: InterpMethodOption, time: float, z: float, @@ -116,32 +121,32 @@ def search_indices_vertical_s( if interp_method in ["bgrid_velocity", "bgrid_w_velocity", "bgrid_tracer"]: xsi = 1 eta = 1 - if time < grid.time[ti]: + if time < field.time[ti]: ti -= 1 - if grid._z4d: # type: ignore[attr-defined] - if ti == len(grid.time) - 1: + if field._z4d: # type: ignore[attr-defined] + if ti == len(field.time) - 1: depth_vector = ( - (1 - xsi) * (1 - eta) * grid.depth[-1, :, yi, xi] - + xsi * (1 - eta) * grid.depth[-1, :, yi, xi + 1] - + xsi * eta * grid.depth[-1, :, yi + 1, xi + 1] - + (1 - xsi) * eta * grid.depth[-1, :, yi + 1, xi] + (1 - xsi) * (1 - eta) * field.depth[-1, :, yi, xi] + + xsi * (1 - eta) * field.depth[-1, :, yi, xi + 1] + + xsi * eta * field.depth[-1, :, yi + 1, xi + 1] + + (1 - xsi) * eta * field.depth[-1, :, yi + 1, xi] ) else: dv2 = ( - (1 - xsi) * (1 - eta) * grid.depth[ti : ti + 2, :, yi, xi] - + xsi * (1 - eta) * grid.depth[ti : ti + 2, :, yi, xi + 1] - + xsi * eta * grid.depth[ti : ti + 2, :, yi + 1, xi + 1] - + (1 - xsi) * eta * grid.depth[ti : ti + 2, :, yi + 1, xi] + (1 - xsi) * (1 - eta) * field.depth[ti : ti + 2, :, yi, xi] + + xsi * (1 - eta) * field.depth[ti : ti + 2, :, yi, xi + 1] + + xsi * eta * field.depth[ti : ti + 2, :, yi + 1, xi + 1] + + (1 - xsi) * eta * field.depth[ti : ti + 2, :, yi + 1, xi] ) - tt = (time - grid.time[ti]) / (grid.time[ti + 1] - grid.time[ti]) + tt = (time - field.time[ti]) / (field.time[ti + 1] - field.time[ti]) assert tt >= 0 and tt <= 1, "Vertical s grid is being wrongly interpolated in time" depth_vector = dv2[0, :] * (1 - tt) + dv2[1, :] * tt else: depth_vector = ( - (1 - xsi) * (1 - eta) * grid.depth[:, yi, xi] - + xsi * (1 - eta) * grid.depth[:, yi, xi + 1] - + xsi * eta * grid.depth[:, yi + 1, xi + 1] - + (1 - xsi) * eta * grid.depth[:, yi + 1, xi] + (1 - xsi) * (1 - eta) * field.depth[:, yi, xi] + + xsi * (1 - eta) * field.depth[:, yi, xi + 1] + + xsi * eta * field.depth[:, yi + 1, xi + 1] + + (1 - xsi) * eta * field.depth[:, yi + 1, xi] ) z = np.float32(z) # type: ignore # TODO: remove type ignore once we migrate to float64 @@ -176,13 +181,15 @@ def search_indices_vertical_s( def _search_indices_rectilinear( - field: Field, time: datetime, z: float, y: float, x: float, ti: int, ei:int=None, search2D=False + field: Field, time: datetime, z: float, y: float, x: 
float, ti: int, ei: int | None = None, search2D=False ): # To do : If ei is provided, check if particle is in the same cell if field.xdim > 1 and (not field.zonal_periodic): - if x < grid.lonlat_minmax[0] or x > grid.lonlat_minmax[1]: # To do : implement lonlat_minmax at field level + if x < field.lonlat_minmax[0] or x > field.lonlat_minmax[1]: # To do : implement lonlat_minmax at field level _raise_field_out_of_bound_error(z, y, x) - if field.ydim > 1 and (y < grid.lonlat_minmax[2] or y > grid.lonlat_minmax[3]): # To do : implement lonlat_minmax at field level + if field.ydim > 1 and ( + y < field.lonlat_minmax[2] or y > field.lonlat_minmax[3] + ): # To do : implement lonlat_minmax at field level _raise_field_out_of_bound_error(z, y, x) if field.xdim > 1: @@ -266,8 +273,8 @@ def _search_indices_curvilinear(field: Field, time, z, y, x, ti, particle=None, if particle: zi, yi, xi = field.unravel_index(particle.ei) else: - xi = int(field.grid.xdim / 2) - 1 - yi = int(field.grid.ydim / 2) - 1 + xi = int(field.xdim / 2) - 1 + yi = int(field.ydim / 2) - 1 xsi = eta = -1.0 grid = field.grid invA = np.array([[1, 0, 0, 0], [-1, 1, 0, 0], [-1, 0, 0, 1], [1, -1, 1, -1]]) @@ -275,12 +282,12 @@ def _search_indices_curvilinear(field: Field, time, z, y, x, ti, particle=None, it = 0 tol = 1.0e-10 if not grid.zonal_periodic: - if x < grid.lonlat_minmax[0] or x > grid.lonlat_minmax[1]: + if x < field.lonlat_minmax[0] or x > field.lonlat_minmax[1]: if grid.lon[0, 0] < grid.lon[0, -1]: _raise_field_out_of_bound_error(z, y, x) elif x < grid.lon[0, 0] and x > grid.lon[0, -1]: # This prevents from crashing in [160, -160] _raise_field_out_of_bound_error(z, y, x) - if y < grid.lonlat_minmax[2] or y > grid.lonlat_minmax[3]: + if y < field.lonlat_minmax[2] or y > field.lonlat_minmax[3]: _raise_field_out_of_bound_error(z, y, x) while xsi < -tol or xsi > 1 + tol or eta < -tol or eta > 1 + tol: diff --git a/parcels/application_kernels/__init__.py b/parcels/application_kernels/__init__.py index b99e4f2544..cd933324ed 100644 --- a/parcels/application_kernels/__init__.py +++ b/parcels/application_kernels/__init__.py @@ -1,4 +1,4 @@ from .advection import * from .advectiondiffusion import * from .interaction import * -from .interpolation import * \ No newline at end of file +from .interpolation import * diff --git a/parcels/application_kernels/interpolation.py b/parcels/application_kernels/interpolation.py index 2cc13a7347..eea65b1c6e 100644 --- a/parcels/application_kernels/interpolation.py +++ b/parcels/application_kernels/interpolation.py @@ -1,10 +1,7 @@ """Collection of pre-built interpolation kernels.""" -import math -from typing import Union import numpy as np -from parcels.tools.statuscodes import StatusCode from parcels.field import Field __all__ = [ @@ -12,19 +9,20 @@ "UXPiecewiseLinearNode", ] + def UXPiecewiseConstantFace( - field: Field, - ti: int, - ei: int, - bcoords: np.ndarray, - tau: Union[np.float32,np.float64], - t: Union[np.float32,np.float64], - z: Union[np.float32,np.float64], - y: Union[np.float32,np.float64], - x: Union[np.float32,np.float64] - ): + field: Field, + ti: int, + ei: int, + bcoords: np.ndarray, + tau: np.float32 | np.float64, + t: np.float32 | np.float64, + z: np.float32 | np.float64, + y: np.float32 | np.float64, + x: np.float32 | np.float64, +): """ - Piecewise constant interpolation kernel for face registered data. + Piecewise constant interpolation kernel for face registered data. This interpolation method is appropriate for fields that are face registered, such as u,v in FESOM. 
""" @@ -32,17 +30,18 @@ def UXPiecewiseConstantFace( zi, fi = field.unravel_index(ei) return field.data[ti, zi, fi] + def UXPiecewiseLinearNode( - field: Field, - ti: int, - ei: int, - bcoords: np.ndarray, - tau: Union[np.float32,np.float64], - t: Union[np.float32,np.float64], - z: Union[np.float32,np.float64], - y: Union[np.float32,np.float64], - x: Union[np.float32,np.float64] - ): + field: Field, + ti: int, + ei: int, + bcoords: np.ndarray, + tau: np.float32 | np.float64, + t: np.float32 | np.float64, + z: np.float32 | np.float64, + y: np.float32 | np.float64, + x: np.float32 | np.float64, +): """ Piecewise linear interpolation kernel for node registered data. This interpolation method is appropriate for fields that are node registered @@ -50,5 +49,5 @@ def UXPiecewiseLinearNode( """ # To do: handle vertical interpolation zi, fi = field.unravel_index(ei) - node_ids = field.data.uxgrid.face_node_connectivity[fi,:] - return np.dot(field.data[ti, zi, node_ids],bcoords) + node_ids = field.data.uxgrid.face_node_connectivity[fi, :] + return np.dot(field.data[ti, zi, node_ids], bcoords) diff --git a/parcels/field.py b/parcels/field.py index 69b6191853..72f908942d 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -1,36 +1,22 @@ -import collections -import math +import inspect import warnings -from typing import TYPE_CHECKING, cast +from collections.abc import Callable +from datetime import datetime +from enum import IntEnum +from typing import TYPE_CHECKING -import dask.array as da import numpy as np -import xarray as xr import uxarray as ux +import xarray as xr from uxarray.grid.neighbors import _barycentric_coordinates -from datetime import datetime - -import parcels.tools.interpolation_utils as i_u from parcels._compat import add_note -from parcels._interpolation import ( - InterpolationContext2D, - InterpolationContext3D, - get_2d_interpolator_registry, - get_3d_interpolator_registry, -) from parcels._typing import ( - GridIndexingType, - InterpMethod, - InterpMethodOption, Mesh, VectorType, - assert_valid_gridindexingtype, - assert_valid_interp_method, ) -from parcels.tools._helpers import default_repr, field_repr, should_calculate_next_ti +from parcels.tools._helpers import default_repr, field_repr from parcels.tools.converters import ( - TimeConverter, UnitConverter, unitconverters_map, ) @@ -41,18 +27,14 @@ FieldSamplingError, _raise_field_out_of_bound_error, ) -from parcels.tools.warnings import FieldSetWarning -import inspect -from typing import Callable, Union -from enum import IntEnum -from ._index_search import _search_indices_curvilinear, _search_indices_rectilinear, _search_time_index + +from ._index_search import _search_indices_rectilinear, _search_time_index if TYPE_CHECKING: - import numpy.typing as npt + pass - from parcels.fieldset import FieldSet +__all__ = ["Field", "GridType", "VectorField"] -__all__ = ["Field", "VectorField", "GridType"] class GridType(IntEnum): RectilinearZGrid = 0 @@ -60,6 +42,7 @@ class GridType(IntEnum): CurvilinearZGrid = 2 CurvilinearSGrid = 3 + def _isParticle(key): if hasattr(key, "obs_written"): return True @@ -81,18 +64,17 @@ def _deal_with_errors(error, key, vector_type: VectorType): return (0, 0) else: return 0 - + + class Field: - """The Field class that holds scalar field data. - The `Field` object is a wrapper around a xarray.DataArray or uxarray.UxDataArray object. + """The Field class that holds scalar field data. + The `Field` object is a wrapper around a xarray.DataArray or uxarray.UxDataArray object. 
Additionally, it holds a dynamic Callable procedure that is used to interpolate the field data. During initialization, the user can supply a custom interpolation method that is used to interpolate the field data, so long as the interpolation method has the correct signature. - + Notes ----- - - The xarray.DataArray or uxarray.UxDataArray object contains the field data and metadata. * dims: (time, [nz1 | nz], [face_lat | node_lat | edge_lat], [face_lon | node_lon | edge_lon]) * attrs: (location, mesh, mesh_type) @@ -103,7 +85,7 @@ class Field: * "node" * "face" * "x_edge" - * "y_edge" + * "y_edge" * For an A-Grid, the "location" attribute must be set to / is assumed to be "node" (node_lat,node_lon). * For a C-Grid, the "location" setting for a field has the following interpretation: * "node" ~> the field is associated with the vorticity points (node_lat, node_lon) @@ -112,7 +94,7 @@ class Field: * "y_edge" ~> the field is associated with the v-velocity points (node_lat, face_lon) When using a uxarray.UxDataArray object, - * The uxarray.UxDataArray.UxGrid object must have the "Conventions" attribute set to "UGRID-1.0" + * The uxarray.UxDataArray.UxGrid object must have the "Conventions" attribute set to "UGRID-1.0" and the uxarray.UxDataArray object must comply with the UGRID conventions. See https://ugrid-conventions.github.io/ugrid-conventions/ for more information. @@ -124,15 +106,15 @@ def _interp_template( ti: int, ei: int, bcoords: np.ndarray, - tau: Union[np.float32,np.float64], - t: Union[np.float32,np.float64], - z: Union[np.float32,np.float64], - y: Union[np.float32,np.float64], - x: Union[np.float32,np.float64] - )-> Union[np.float32,np.float64]: - """ Template function used for the signature check of the lateral interpolation methods.""" + tau: np.float32 | np.float64, + t: np.float32 | np.float64, + z: np.float32 | np.float64, + y: np.float32 | np.float64, + x: np.float32 | np.float64, + ) -> np.float32 | np.float64: + """Template function used for the signature check of the lateral interpolation methods.""" return 0.0 - + def _validate_interp_function(self, func: Callable) -> bool: """Ensures that the function has the correct signature.""" template_sig = inspect.signature(self._interp_template) @@ -141,7 +123,9 @@ def _validate_interp_function(self, func: Callable) -> bool: if len(template_sig.parameters) != len(func_sig.parameters): return False - for ((name1, param1), (name2, param2)) in zip(template_sig.parameters.items(), func_sig.parameters.items()): + for (_name1, param1), (_name2, param2) in zip( + template_sig.parameters.items(), func_sig.parameters.items(), strict=False + ): if param1.kind != param2.kind: return False if param1.annotation != param2.annotation: @@ -163,7 +147,6 @@ def __init__( interp_method: Callable | None = None, allow_time_extrapolation: bool | None = None, ): - self.name = name self.data = data @@ -181,12 +164,12 @@ def __init__( # Setting the interpolation method dynamically if interp_method is None: - self._interp_method = self._interp_template # Default to method that returns 0 always + self._interp_method = self._interp_template # Default to method that returns 0 always else: self._validate_interp_function(interp_method) self._interp_method = interp_method - self.igrid = -1 # Default the grid index to -1 + self.igrid = -1 # Default the grid index to -1 if self._mesh_type == "flat" or (self.name not in unitconverters_map.keys()): self.units = UnitConverter() @@ -194,7 +177,7 @@ def __init__( self.units = unitconverters_map[self.name] else: raise 
ValueError("Unsupported mesh type in data array attributes. Choose either: 'spherical' or 'flat'") - + if allow_time_extrapolation is None: self.allow_time_extrapolation = True if len(self.data["time"]) == 1 else False else: @@ -206,28 +189,32 @@ def __init__( else: self._spatialhash = None # Set the grid type - if "x_g" in self.data.coords : + if "x_g" in self.data.coords: lon = self.data.x_g else: lon = self.data.x_c - - if "nz1" in self.data.coords : + + if "nz1" in self.data.coords: depth = self.data.nz1 - elif "nz" in self.data.coords : + elif "nz" in self.data.coords: depth = self.data.nz - else : + else: depth = None if len(lon.shape) <= 1: - if depth is None or len(depth.shape) <=1: + if depth is None or len(depth.shape) <= 1: self._gtype = GridType.RectilinearZGrid else: self._gtype = GridType.RectilinearSGrid else: - if depth is None or len(depth.shape) <=1: + if depth is None or len(depth.shape) <= 1: self._gtype = GridType.CurvilinearZGrid else: - self._gtype = GridType.CurvilinearSGrid + self._gtype = GridType.CurvilinearSGrid + + self._lonlat_minmax = np.array( + [np.nanmin(self.lon), np.nanmax(self.lon), np.nanmin(self.lat), np.nanmax(self.lat)], dtype=np.float32 + ) def __repr__(self): return field_repr(self) @@ -237,8 +224,12 @@ def grid(self): if type(self.data) is ux.UxDataArray: return self.data.uxgrid else: - return self.data # To do : need to decide on what to return for xarray.DataArray objects - + return self.data # To do : need to decide on what to return for xarray.DataArray objects + + @property + def lonlat_minmax(self): + return self._lonlat_minmax + @property def lat(self): if type(self.data) is ux.UxDataArray: @@ -298,7 +289,8 @@ def xdim(self): elif "node_lon" in self.data.dims: return self.data.sizes["node_lon"] else: - return 0 # To do : Discuss what we want to return for uxdataarray obj + return 0 # To do : Discuss what we want to return for uxdataarray obj + @property def ydim(self): if type(self.data) is xr.DataArray: @@ -307,8 +299,8 @@ def ydim(self): elif "node_lat" in self.data.dims: return self.data.sizes["node_lat"] else: - return 0 # To do : Discuss what we want to return for uxdataarray obj - + return 0 # To do : Discuss what we want to return for uxdataarray obj + @property def zdim(self): if "nz1" in self.data.dims: @@ -317,17 +309,17 @@ def zdim(self): return self.data.sizes["nz"] else: return 0 - + @property def n_face(self): if type(self.data) is ux.uxDataArray: return self.data.uxgrid.n_face else: - return 0 # To do : Discuss what we want to return for dataarray obj - + return 0 # To do : Discuss what we want to return for dataarray obj + @property def interp_method(self): - return self._interp_method + return self._interp_method @interp_method.setter def interp_method(self, method: Callable): @@ -339,8 +331,7 @@ def interp_method(self, method: Callable): # return self._gridindexingtype def _get_ux_barycentric_coordinates(self, y, x, fi): - "Checks if a point is inside a given face id. Used for unstructured grids." - + """Checks if a point is inside a given face id. Used for unstructured grids.""" # Check if particle is in the same face, otherwise search again. 
n_nodes = self.data.uxgrid.n_nodes_per_face[fi].to_numpy()
         node_ids = self.data.uxgrid.face_node_connectivity[fi, 0:n_nodes]
@@ -353,52 +344,44 @@ def _get_ux_barycentric_coordinates(self, y, x, fi):
 
         coord = np.deg2rad([x, y])
         bcoord = np.asarray(_barycentric_coordinates(nodes, coord))
-        err = abs(np.dot(bcoord, nodes[:, 0]) - coord[0]) + abs(
-            np.dot(bcoord, nodes[:, 1]) - coord[1]
-        )
+        err = abs(np.dot(bcoord, nodes[:, 0]) - coord[0]) + abs(np.dot(bcoord, nodes[:, 1]) - coord[1])
 
         return bcoord, err
 
-
     def _search_indices_unstructured(self, z, y, x, ei=None, search2D=False):
-
         tol = 1e-10
         if ei is None:
             # Search using global search
-            fi, bcoords = self._spatialhash.query([[x,y]]) # Get the face id for the particle
+            fi, bcoords = self._spatialhash.query([[x, y]])  # Get the face id for the particle
             if fi == -1:
-                raise FieldOutOfBoundError(z, y, x) # To do : how to handle lost particle ??
+                raise FieldOutOfBoundError(z, y, x)  # To do : how to handle lost particle ??
             # To do : Do the vertical grid search
             # zi = self._vertical_search(z)
-            zi = 0 # For now
+            zi = 0  # For now
             return bcoords, self.ravel_index(zi, 0, fi)
         else:
-            zi, fi = self.unravel_index(ei[self.igrid]) # Get the z, and face index of the particle
+            zi, fi = self.unravel_index(ei[self.igrid])  # Get the z, and face index of the particle
             # Search using nearest neighbors
             bcoords, err = self._get_ux_barycentric_coordinates(y, x, fi)
-            if ((bcoords >= 0).all()) and ((bcoords <= 1.0).all()) and err < tol:
+            if ((bcoords >= 0).all()) and ((bcoords <= 1.0).all()) and err < tol:
                 # To do: Do the vertical grid search
                 return bcoords, ei
             else:
                 # In this case we need to search the neighbors
-                for neighbor in self.data.uxgrid.face_face_connectivity[fi,:]:
+                for neighbor in self.data.uxgrid.face_face_connectivity[fi, :]:
                     bcoords, err = self._get_ux_barycentric_coordinates(y, x, neighbor)
-                    if ((bcoords >= 0).all()) and ((bcoords <= 1.0).all()) and err < tol:
+                    if ((bcoords >= 0).all()) and ((bcoords <= 1.0).all()) and err < tol:
                         # To do: Do the vertical grid search
                         return bcoords, self.ravel_index(zi, 0, neighbor)
 
             # If we reach this point, we do a global search as a last-ditch effort; if that also fails, the particle is out of bounds
-            fi, bcoords = self._spatialhash.query([[x,y]]) # Get the face id for the particle
+            fi, bcoords = self._spatialhash.query([[x, y]])  # Get the face id for the particle
             if fi == -1:
-                raise FieldOutOfBoundError(z, y, x) # To do : how to handle lost particle ??
+                raise FieldOutOfBoundError(z, y, x)  # To do : how to handle lost particle ??
def _search_indices_structured(self, z, y, x, ei=None, search2D=False): - - if self._gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]: - (zeta, eta, xsi, zi, yi, xi) = _search_indices_rectilinear( - self, z, y, x,ei=ei, search2D=search2D - ) + if self._gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]: + (zeta, eta, xsi, zi, yi, xi) = _search_indices_rectilinear(self, z, y, x, ei=ei, search2D=search2D) else: ## joe@fluidnumerics.com : 3/27/25 : To do : Still need to implement the search_indices_curvilinear # (zeta, eta, xsi, zi, yi, xi) = _search_indices_curvilinear( @@ -407,10 +390,9 @@ def _search_indices_structured(self, z, y, x, ei=None, search2D=False): raise NotImplementedError("Curvilinear grid search not implemented yet") return (zeta, eta, xsi, zi, yi, xi) - - def _search_indices(self, time: datetime, z, y, x, ei=None, search2D=False): - tau, ti = _search_time_index(self,time,self.allow_time_extrapolation) + def _search_indices(self, time: datetime, z, y, x, ei=None, search2D=False): + tau, ti = _search_time_index(self, time, self.allow_time_extrapolation) if ei is None: _ei = None @@ -421,10 +403,9 @@ def _search_indices(self, time: datetime, z, y, x, ei=None, search2D=False): bcoords, ei = self._search_indices_unstructured(z, y, x, ei=_ei, search2D=search2D) else: bcoords, ei = self._search_indices_structured(z, y, x, ei=_ei, search2D=search2D) - return bcoords, ei, tau, ti - - def _interpolate(self, time: datetime, z, y, x, ei): + return bcoords, ei, tau, ti + def _interpolate(self, time: datetime, z, y, x, ei): try: bcoords, _ei, tau, ti = self._search_indices(time, z, y, x, ei=ei) val = self._interp_method(ti, _ei, bcoords, tau, time, z, y, x) @@ -434,7 +415,7 @@ def _interpolate(self, time: datetime, z, y, x, ei): _raise_field_out_of_bound_error(z, y, x) else: return val - + except (FieldSamplingError, FieldOutOfBoundError, FieldOutOfBoundSurfaceError) as e: e = add_note(e, f"Error interpolating field '{self.name}'.", before=True) raise e @@ -456,7 +437,7 @@ def __getitem__(self, key): return self.eval(*key) except tuple(AllParcelsErrorCodes.keys()) as error: return _deal_with_errors(error, key, vector_type=None) - + def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True): """Interpolate field values in space and time. @@ -475,7 +456,7 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True): return self.units.to_target(value, z, y, x) else: return value - + def _rescale_and_set_minmax(self, data): data[np.isnan(data)] = 0 return data @@ -500,7 +481,7 @@ def ravel_index(self, zi, yi, xi): if type(self.data) is xr.DataArray: return xi + self.xdim * (yi + self.ydim * zi) else: - return xi + self.n_face*zi + return xi + self.n_face * zi def unravel_index(self, ei): """Return the zi, yi, xi indices for a given flat index. @@ -534,22 +515,22 @@ def unravel_index(self, ei): return zi, fi def _validate_dataarray(self): - """ Verifies that all the required attributes are present in the xarray.DataArray or - uxarray.UxDataArray object.""" - + """Verifies that all the required attributes are present in the xarray.DataArray or + uxarray.UxDataArray object. + """ # Validate dimensions - if not( "nz1" in self.data.dims or "nz" in self.data.dims ): + if not ("nz1" in self.data.dims or "nz" in self.data.dims): raise ValueError( f"Field {self.name} is missing a 'nz1' or 'nz' dimension in the field's metadata. " "This attribute is required for xarray.DataArray objects." 
) - - if not( "time" in self.data.dims ): + + if "time" not in self.data.dims: raise ValueError( f"Field {self.name} is missing a 'time' dimension in the field's metadata. " "This attribute is required for xarray.DataArray objects." ) - + # Validate attributes required_keys = ["location", "mesh"] for key in required_keys: @@ -558,14 +539,12 @@ def _validate_dataarray(self): f"Field {self.name} is missing a '{key}' attribute in the field's metadata. " "This attribute is required for xarray.DataArray objects." ) - + if type(self.data) is ux.UxDataArray: self._validate_uxgrid() - def _validate_uxgrid(self): - """ Verifies that all the required attributes are present in the uxarray.UxDataArray.UxGrid object.""" - + """Verifies that all the required attributes are present in the uxarray.UxDataArray.UxGrid object.""" if "Conventions" not in self.data.uxgrid.attrs.keys(): raise ValueError( f"Field {self.name} is missing a 'Conventions' attribute in the field's metadata. " @@ -578,62 +557,50 @@ def _validate_uxgrid(self): "See https://ugrid-conventions.github.io/ugrid-conventions/ for more information." ) - def __getattr__(self, key: str): return getattr(self.data, key) def __contains__(self, key: str): return key in self.data - + class VectorField: """VectorField class that holds vector field data needed to execute particles.""" - @staticmethod def _vector_interp_template( self, ti: int, ei: int, bcoords: np.ndarray, - t: Union[np.float32,np.float64], - z: Union[np.float32,np.float64], - y: Union[np.float32,np.float64], - x: Union[np.float32,np.float64] - )-> Union[np.float32,np.float64]: - """ Template function used for the signature check of the lateral interpolation methods.""" + t: np.float32 | np.float64, + z: np.float32 | np.float64, + y: np.float32 | np.float64, + x: np.float32 | np.float64, + ) -> np.float32 | np.float64: + """Template function used for the signature check of the lateral interpolation methods.""" return 0.0 - + def _validate_vector_interp_function(self, func: Callable): """Ensures that the function has the correct signature.""" expected_params = ["ti", "ei", "bcoords", "t", "z", "y", "x"] - expected_return_types = (np.float32,np.float64) + expected_return_types = (np.float32, np.float64) sig = inspect.signature(func) params = list(sig.parameters.keys()) # Check the parameter names and count if params != expected_params: - raise TypeError( - f"Function must have parameters {expected_params}, but got {params}" - ) + raise TypeError(f"Function must have parameters {expected_params}, but got {params}") # Check return annotation if present return_annotation = sig.return_annotation if return_annotation not in (inspect.Signature.empty, *expected_return_types): - raise TypeError( - f"Function must return a float, but got {return_annotation}" - ) - + raise TypeError(f"Function must return a float, but got {return_annotation}") + def __init__( - self, - name: str, - U: Field, - V: Field, - W: Field | None = None, - vector_interp_method: Callable | None = None - ): - + self, name: str, U: Field, V: Field, W: Field | None = None, vector_interp_method: Callable | None = None + ): self.name = name self.U = U self.V = V @@ -660,7 +627,7 @@ def __repr__(self): @property def vector_interp_method(self): - return self._vector_interp_method + return self._vector_interp_method @vector_interp_method.setter def vector_interp_method(self, method: Callable): @@ -676,7 +643,6 @@ def vector_interp_method(self, method: Callable): # and np.allclose(grid1.time, grid2.time) # ) def _interpolate(self, 
time, z, y, x, ei): - bcoords, _ei, ti = self._search_indices(time, z, y, x, ei=ei) if self._vector_interp_method is None: @@ -688,32 +654,30 @@ def _interpolate(self, time, z, y, x, ei): else: return (u, v, 0) else: - (u,v,w) = self._vector_interp_method(ti, _ei, bcoords, time, z, y, x) + (u, v, w) = self._vector_interp_method(ti, _ei, bcoords, time, z, y, x) return (u, v, w) - def eval(self, time, z, y, x, ei=None, applyConversion=True): - if ei is None: _ei = 0 else: _ei = ei[self.igrid] - (u,v,w) = self._interpolate(time, z, y, x, _ei) + (u, v, w) = self._interpolate(time, z, y, x, _ei) if applyConversion: u = self.U.units.to_target(u, z, y, x) v = self.V.units.to_target(v, z, y, x) if "3D" in self.vector_type: w = self.W.units.to_target(w, z, y, x) - + return (u, v, w) - def __getitem__(self,key): + def __getitem__(self, key): try: if _isParticle(key): return self.eval(key.time, key.depth, key.lat, key.lon, key.ei) else: return self.eval(*key) except tuple(AllParcelsErrorCodes.keys()) as error: - return _deal_with_errors(error, key, vector_type=self.vector_type) \ No newline at end of file + return _deal_with_errors(error, key, vector_type=self.vector_type) diff --git a/parcels/fieldset.py b/parcels/fieldset.py index 6e8875e8bc..c700ab6873 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -1,28 +1,20 @@ -import importlib.util import os -import sys -import warnings from glob import glob import numpy as np +import uxarray as ux +import xarray as xr -from parcels._typing import GridIndexingType, InterpMethodOption, Mesh +from parcels._typing import Mesh from parcels.field import Field, VectorField -from parcels.particlefile import ParticleFile -from parcels.tools._helpers import fieldset_repr, default_repr -from parcels.tools.converters import TimeConverter -from parcels.tools.warnings import FieldSetWarning - -import xarray as xr -import uxarray as ux -from typing import List, Union +from parcels.tools._helpers import fieldset_repr __all__ = ["FieldSet"] class FieldSet: """FieldSet class that holds hydrodynamic data needed to execute particles. - + Parameters ---------- ds : xarray.Dataset | uxarray.UxDataset) @@ -30,27 +22,27 @@ class FieldSet: Notes ----- - The `ds` object is a xarray.Dataset or uxarray.UxDataset object. - In XArray terminology, the (Ux)Dataset holds multiple (Ux)DataArray objects. + The `ds` object is a xarray.Dataset or uxarray.UxDataset object. + In XArray terminology, the (Ux)Dataset holds multiple (Ux)DataArray objects. Each (Ux)DataArray object is a single "field" that is associated with their own dimensions and coordinates within the (Ux)Dataset. A (Ux)Dataset object is associated with a single mesh, which can have multiple - types of "points" (multiple "grids") (e.g. for UxDataSets, these are "face_lon", - "face_lat", "node_lon", "node_lat", "edge_lon", "edge_lat"). Each (Ux)DataArray is + types of "points" (multiple "grids") (e.g. for UxDataSets, these are "face_lon", + "face_lat", "node_lon", "node_lat", "edge_lon", "edge_lat"). Each (Ux)DataArray is registered to a specific set of points on the mesh. - For UxDataset objects, each `UXDataArray.attributes` field dictionary contains + For UxDataset objects, each `UXDataArray.attributes` field dictionary contains the necessary metadata to help determine which set of points a field is registered to and what parent model the field is associated with. Parcels uses this metadata during execution for interpolation. 
Each `UXDataArray.attributes` field dictionary - must have: + must have: * "location" key set to "face", "node", or "edge" to define which pairing of points a field is associated with. * "mesh" key to define which parent model the fields are associated with (e.g. "fesom_mesh", "icon_mesh") """ - def __init__(self, datasets: List[Union[xr.Dataset,ux.UxDataset]]): + def __init__(self, datasets: list[xr.Dataset | ux.UxDataset]): self.datasets = datasets self._completed: bool = False @@ -60,10 +52,10 @@ def __init__(self, datasets: List[Union[xr.Dataset,ux.UxDataset]]): # Create pointers to each (Ux)DataArray for ds in datasets: for field in ds.data_vars: - self.add_field(Field(field, ds[field]),field) + self.add_field(Field(field, ds[field]), field) self._gridset_size += 1 self._fieldnames.append(field) - #setattr(self, field, Field(field,self.ds[field]), field) + # setattr(self, field, Field(field,self.ds[field]), field) if "time" in ds.coords: if time_origin is None: @@ -72,14 +64,13 @@ def __init__(self, datasets: List[Union[xr.Dataset,ux.UxDataset]]): time_origin = min(time_origin, ds.time.min().data) else: raise ValueError("Each dataset must have a 'time' coordinate") - + self.time_origin = time_origin self._add_UVfield() - def __repr__(self): return fieldset_repr(self) - + def dimrange(self, dim): """Returns maximum value of a dimension (lon, lat, depth or time) on 'left' side and minimum value on 'right' side for all grids @@ -88,14 +79,14 @@ def dimrange(self, dim): """ maxleft, minright = (-np.inf, np.inf) dim2ds = { - "depth": ["nz1","nz"], + "depth": ["nz1", "nz"], "lat": ["node_lat", "face_lat", "edge_lat"], "lon": ["node_lon", "face_lon", "edge_lon"], - "time": ["time"] + "time": ["time"], } for ds in self.datasets: for field in ds.data_vars: - for d in dim2ds[dim]: # check all possible dimensions + for d in dim2ds[dim]: # check all possible dimensions if d in ds[field].dims: if dim == "depth": maxleft = max(maxleft, ds[field][d].min().data) @@ -107,7 +98,7 @@ def dimrange(self, dim): minright = 0 if minright == np.inf else minright # if all len(dim) == 1 return maxleft, minright - + # @property # def particlefile(self): # return self._particlefile @@ -120,7 +111,7 @@ def dimrange(self, dim): @property def gridset_size(self): return self._gridset_size - + def add_field(self, field: Field, name: str | None = None): """Add a :class:`parcels.field.Field` object to the FieldSet. @@ -140,7 +131,6 @@ def add_field(self, field: Field, name: str | None = None): * `Unit converters <../examples/tutorial_unitconverters.ipynb>`__ (Default value = None) """ - if self._completed: raise RuntimeError( "FieldSet has already been completed. Are you trying to add a Field after you've created the ParticleSet?" @@ -151,7 +141,7 @@ def add_field(self, field: Field, name: str | None = None): raise RuntimeError(f"FieldSet already has a Field with name '{name}'") else: setattr(self, name, field) - self._gridset_size += 1 + self._gridset_size += 1 self._fieldnames.append(name) def add_constant_field(self, name: str, value: float, mesh: Mesh = "flat"): @@ -172,27 +162,21 @@ def add_constant_field(self, name: str, value: float, mesh: Mesh = "flat"): correction for zonal velocity U near the poles. 2. flat: No conversion, lat/lon are assumed to be in m. 
""" - time = 0.0 - values = np.zeros((1,1,1,1), dtype=np.float32) + value + values = np.zeros((1, 1, 1, 1), dtype=np.float32) + value data = xr.DataArray( data=values, name=name, - dims='null', - coords = [time,[0],[0],[0]], - attrs=dict( - description="null", - units="null", - location="node", - mesh=f"constant", - mesh_type=mesh - )) + dims="null", + coords=[time, [0], [0], [0]], + attrs=dict(description="null", units="null", location="node", mesh="constant", mesh_type=mesh), + ) self.add_field( Field( name, data, - interp_method=None, # To do : Need to define an interpolation method for constants - allow_time_extrapolation=True + interp_method=None, # To do : Need to define an interpolation method for constants + allow_time_extrapolation=True, ) ) @@ -219,7 +203,7 @@ def get_fields(self) -> list[Field | VectorField]: if v not in fields: fields.append(v) return fields - + def _add_UVfield(self): if not hasattr(self, "UV") and hasattr(self, "U") and hasattr(self, "V"): self.add_vector_field(VectorField("UV", self.U, self.V)) @@ -248,7 +232,7 @@ def _parse_wildcards(cls, paths, filenames, var): if not os.path.exists(fp): raise OSError(f"FieldSet file not found: {fp}") return paths - + # @classmethod # def from_netcdf( # cls, @@ -260,7 +244,7 @@ def _parse_wildcards(cls, paths, filenames, var): # allow_time_extrapolation: bool | None = None, # **kwargs, # ): - + # @classmethod # def from_nemo( # cls, @@ -272,7 +256,7 @@ def _parse_wildcards(cls, paths, filenames, var): # tracer_interp_method: InterpMethodOption = "cgrid_tracer", # **kwargs, # ): - + # @classmethod # def from_mitgcm( # cls, @@ -311,7 +295,6 @@ def _parse_wildcards(cls, paths, filenames, var): # **kwargs, # ): - # @classmethod # def from_mom5( # cls, diff --git a/tests/v4/test_uxarray_fieldset.py b/tests/v4/test_uxarray_fieldset.py index f9a9cc2ae5..27016d0ebe 100644 --- a/tests/v4/test_uxarray_fieldset.py +++ b/tests/v4/test_uxarray_fieldset.py @@ -1,14 +1,14 @@ +import os import uxarray as ux -from datetime import timedelta + from parcels import ( FieldSet, - ParticleSet, Particle, + ParticleSet, UXPiecewiseConstantFace, UXPiecewiseLinearNode, ) -import os # Get path of this script V4_TEST_DATA = f"{os.path.dirname(__file__)}/test_data" @@ -29,6 +29,7 @@ def test_fesom_fieldset(): # Check that the fieldset has the expected properties assert fieldset.datasets[0] == ds + def test_fesom_in_particleset(): grid_path = f"{V4_TEST_DATA}/fesom_channel.nc" data_path = [ @@ -39,7 +40,10 @@ def test_fesom_in_particleset(): ds = ux.open_mfdataset(grid_path, data_path) ds = ds.rename_vars({"u": "U", "v": "V", "w": "W"}) fieldset = FieldSet([ds]) + # Check that the fieldset has the expected properties + assert fieldset.datasets[0] == ds pset = ParticleSet(fieldset, pclass=Particle) + assert pset.fieldset == fieldset def test_set_interp_methods(): @@ -57,8 +61,9 @@ def test_set_interp_methods(): fieldset.V.interp_method = UXPiecewiseConstantFace fieldset.W.interp_method = UXPiecewiseLinearNode - # pset = ParticleSet(fieldset, pclass=Particle) - # pset.execute(associate_interp_function, endtime=timedelta(days=1), dt=timedelta(hours=1)) + +# pset = ParticleSet(fieldset, pclass=Particle) +# pset.execute(associate_interp_function, endtime=timedelta(days=1), dt=timedelta(hours=1)) if __name__ == "__main__": - test_set_interp_methods() \ No newline at end of file + test_set_interp_methods() From b0967d9e3353ee3e5070ee57498111eb93ea46f7 Mon Sep 17 00:00:00 2001 From: Joe Date: Fri, 4 Apr 2025 14:58:20 -0400 Subject: [PATCH 25/46] Fix a few tests 
in test_fieldset up to fieldset.from_data --- parcels/field.py | 58 ++++++++--------- parcels/fieldset.py | 141 +++++++++++++++++++++++++++++++++++++++-- tests/test_fieldset.py | 15 +++-- 3 files changed, 167 insertions(+), 47 deletions(-) diff --git a/parcels/field.py b/parcels/field.py index 72f908942d..01fdd60615 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -179,7 +179,7 @@ def __init__( raise ValueError("Unsupported mesh type in data array attributes. Choose either: 'spherical' or 'flat'") if allow_time_extrapolation is None: - self.allow_time_extrapolation = True if len(self.data["time"]) == 1 else False + self.allow_time_extrapolation = True if len(getattr(self.data, "time", [])) == 1 else False else: self.allow_time_extrapolation = allow_time_extrapolation @@ -191,13 +191,17 @@ def __init__( # Set the grid type if "x_g" in self.data.coords: lon = self.data.x_g - else: + elif "x_c" in self.data.coords: lon = self.data.x_c + else: + lon = self.data.lon if "nz1" in self.data.coords: depth = self.data.nz1 elif "nz" in self.data.coords: depth = self.data.nz + elif "depth" in self.data.coords: + depth = self.data.depth else: depth = None @@ -240,14 +244,7 @@ def lat(self): elif self._location == "edge": return self.data.uxgrid.edge_lat else: - if self._location == "node": - return self.data.node_lat - elif self._location == "face": - return self.data.face_lat - elif self._location == "x_edge": - return self.data.face_lat - elif self._location == "y_edge": - return self.data.node_lat + return self.data.lat @property def lon(self): @@ -259,14 +256,7 @@ def lon(self): elif self._location == "edge": return self.data.uxgrid.edge_lon else: - if self._location == "node": - return self.data.node_lon - elif self._location == "face": - return self.data.face_lon - elif self._location == "x_edge": - return self.data.node_lon - elif self._location == "y_edge": - return self.data.face_lon + return self.data.lon @property def depth(self): @@ -276,10 +266,7 @@ def depth(self): elif self._vertical_location == "face": return self.data.uxgrid.nz else: - if self._vertical_location == "center": - return self.data.nz1 - elif self._vertical_location == "face": - return self.data.nz + return self.data.depth @property def xdim(self): @@ -288,6 +275,8 @@ def xdim(self): return self.data.sizes["face_lon"] elif "node_lon" in self.data.dims: return self.data.sizes["node_lon"] + else: + return self.data.sizes["lon"] else: return 0 # To do : Discuss what we want to return for uxdataarray obj @@ -298,6 +287,8 @@ def ydim(self): return self.data.sizes["face_lat"] elif "node_lat" in self.data.dims: return self.data.sizes["node_lat"] + else: + return self.data.sizes["lat"] else: return 0 # To do : Discuss what we want to return for uxdataarray obj @@ -518,18 +509,19 @@ def _validate_dataarray(self): """Verifies that all the required attributes are present in the xarray.DataArray or uxarray.UxDataArray object. """ - # Validate dimensions - if not ("nz1" in self.data.dims or "nz" in self.data.dims): - raise ValueError( - f"Field {self.name} is missing a 'nz1' or 'nz' dimension in the field's metadata. " - "This attribute is required for xarray.DataArray objects." - ) + if isinstance(self.data, ux.UxDataArray): + # Validate dimensions + if not ("nz1" in self.data.dims or "nz" in self.data.dims): + raise ValueError( + f"Field {self.name} is missing a 'nz1' or 'nz' dimension in the field's metadata. " + "This attribute is required for xarray.DataArray objects." 
+ ) - if "time" not in self.data.dims: - raise ValueError( - f"Field {self.name} is missing a 'time' dimension in the field's metadata. " - "This attribute is required for xarray.DataArray objects." - ) + if "time" not in self.data.dims: + raise ValueError( + f"Field {self.name} is missing a 'time' dimension in the field's metadata. " + "This attribute is required for xarray.DataArray objects." + ) # Validate attributes required_keys = ["location", "mesh"] diff --git a/parcels/fieldset.py b/parcels/fieldset.py index c700ab6873..3cd2a5a053 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -55,7 +55,6 @@ def __init__(self, datasets: list[xr.Dataset | ux.UxDataset]): self.add_field(Field(field, ds[field]), field) self._gridset_size += 1 self._fieldnames.append(field) - # setattr(self, field, Field(field,self.ds[field]), field) if "time" in ds.coords: if time_origin is None: @@ -63,7 +62,7 @@ def __init__(self, datasets: list[xr.Dataset | ux.UxDataset]): else: time_origin = min(time_origin, ds.time.min().data) else: - raise ValueError("Each dataset must have a 'time' coordinate") + time_origin = 0.0 self.time_origin = time_origin self._add_UVfield() @@ -103,11 +102,12 @@ def dimrange(self, dim): # def particlefile(self): # return self._particlefile - # @staticmethod - # def checkvaliddimensionsdict(dims): - # for d in dims: - # if d not in ["lon", "lat", "depth", "time"]: - # raise NameError(f"{d} is not a valid key in the dimensions dictionary") + @staticmethod + def checkvaliddimensionsdict(dims): + for d in dims: + if d not in ["lon", "lat", "depth", "time"]: + raise NameError(f"{d} is not a valid key in the dimensions dictionary") + @property def gridset_size(self): return self._gridset_size @@ -233,6 +233,133 @@ def _parse_wildcards(cls, paths, filenames, var): raise OSError(f"FieldSet file not found: {fp}") return paths + @classmethod + def from_data( + cls, + data, + dimensions, + mesh: Mesh = "spherical", + allow_time_extrapolation: bool | None = None, + **kwargs, + ): + """Initialise FieldSet object from raw data. Assumes structured grid and uses xarray.dataarray objects. + + Parameters + ---------- + data : + Dictionary mapping field names to numpy arrays. + Note that at least a 'U' and 'V' numpy array need to be given, and that + the built-in Advection kernels assume that U and V are in m/s. + Data shape is either [ydim, xdim], [zdim, ydim, xdim], [tdim, ydim, xdim] or [tdim, zdim, ydim, xdim], + dimensions : dict + Dictionary mapping field dimensions (lon, + lat, depth, time) to numpy arrays. + Note that dimensions can also be a dictionary of dictionaries if + dimension names are different for each variable + (e.g. dimensions['U'], dimensions['V'], etc). + mesh : str + String indicating the type of mesh coordinates and + units used during velocity interpolation, see also `this tutorial <../examples/tutorial_unitconverters.ipynb>`__: + + 1. spherical (default): Lat and lon in degree, with a + correction for zonal velocity U near the poles. + 2. flat: No conversion, lat/lon are assumed to be in m. + allow_time_extrapolation : bool + boolean whether to allow for extrapolation + (i.e. beyond the last available time snapshot) + Default is False if dimensions includes time, else True + **kwargs : + Keyword arguments passed to the :class:`Field` constructor. 
+ + Examples + -------- + For usage examples see the following tutorials: + + * `Analytical advection <../examples/tutorial_analyticaladvection.ipynb>`__ + + * `Diffusion <../examples/tutorial_diffusion.ipynb>`__ + + * `Interpolation <../examples/tutorial_interpolation.ipynb>`__ + + * `Unit converters <../examples/tutorial_unitconverters.ipynb>`__ + """ + fields = {} + for name, datafld in data.items(): + # Use dimensions[name] if dimensions is a dict of dicts + dims = dimensions[name] if name in dimensions else dimensions + cls.checkvaliddimensionsdict(dims) + + if allow_time_extrapolation is None: + allow_time_extrapolation = False if "time" in dims else True + + lon = dims["lon"] + lat = dims["lat"] + depth = np.zeros(1, dtype=np.float32) if "depth" not in dims else dims["depth"] + time = np.zeros(1, dtype=np.float64) if "time" not in dims else dims["time"] + time = np.array(time) + + if len(datafld.shape) == 2: + coords = [lat, lon] + elif len(datafld.shape) == 3: + if "time" not in dims: + coords = [depth, lat, lon] + else: + coords = [time, lat, lon] + else: + coords = [time, depth, lat, lon] + + fields[name] = xr.DataArray( + data=datafld, + name=name, + dims=dims, + coords=coords, + attrs=dict( + description="Created with fieldset.from_data", + units="", + location="node", + mesh="Arakawa-A", + ), + ) + + return cls([xr.Dataset(fields)]) + + @classmethod + def from_xarray_dataset(cls, ds, variables, dimensions, mesh="spherical", allow_time_extrapolation=None, **kwargs): + """Initialises FieldSet data from xarray Datasets. + + Parameters + ---------- + ds : xr.Dataset + xarray Dataset. + Note that the built-in Advection kernels assume that U and V are in m/s + variables : dict + Dictionary mapping parcels variable names to data variables in the xarray Dataset. + dimensions : dict + Dictionary mapping data dimensions (lon, + lat, depth, time, data) to dimensions in the xarray Dataset. + Note that dimensions can also be a dictionary of dictionaries if + dimension names are different for each variable + (e.g. dimensions['U'], dimensions['V'], etc). + mesh : str + String indicating the type of mesh coordinates and + units used during velocity interpolation, see also `this tutorial <../examples/tutorial_unitconverters.ipynb>`__: + + 1. spherical (default): Lat and lon in degree, with a + correction for zonal velocity U near the poles. + 2. flat: No conversion, lat/lon are assumed to be in m. + allow_time_extrapolation : bool + boolean whether to allow for extrapolation + (i.e. beyond the last available time snapshot) + Default is False if dimensions includes time, else True + **kwargs : + Keyword arguments passed to the :func:`Field.from_xarray` constructor. 
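+
+        Examples
+        --------
+        A minimal sketch; the file, variable, and dimension names below are
+        illustrative only, not a tested recipe::
+
+            ds = xr.open_dataset("model_output.nc")  # hypothetical file
+            variables = {"U": "uo", "V": "vo"}
+            dimensions = {"lon": "longitude", "lat": "latitude", "time": "time"}
+            fieldset = FieldSet.from_xarray_dataset(ds, variables, dimensions)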
+ """ + for var, _name in variables.items(): + dims = dimensions[var] if var in dimensions else dimensions + cls.checkvaliddimensionsdict(dims) + + return cls([ds]) + # @classmethod # def from_netcdf( # cls, diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index c461d2770f..83333b1c77 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -86,10 +86,10 @@ def test_fieldset_from_data(xdim, ydim): """Simple test for fieldset initialisation from data.""" data, dimensions = generate_fieldset_data(xdim, ydim) fieldset = FieldSet.from_data(data, dimensions) - assert len(fieldset.U.data.shape) == 3 - assert len(fieldset.V.data.shape) == 3 - assert np.allclose(fieldset.U.data[0, :], data["U"], rtol=1e-12) - assert np.allclose(fieldset.V.data[0, :], data["V"], rtol=1e-12) + assert len(fieldset.U.data.shape) == 2 + assert len(fieldset.V.data.shape) == 2 + assert np.allclose(fieldset.U.data, data["U"], rtol=1e-12) + assert np.allclose(fieldset.V.data, data["V"], rtol=1e-12) @pytest.mark.v4remove @@ -111,9 +111,10 @@ def test_fieldset_from_data_timedims(ttype, tdim): dimensions["time"] = [np.datetime64("2018-01-01") + np.timedelta64(t, "D") for t in range(tdim)] fieldset = FieldSet.from_data(data, dimensions) for i, dtime in enumerate(dimensions["time"]): - assert fieldset.U.grid.time_origin.fulltime(fieldset.U.grid.time[i]) == dtime + assert fieldset.U.time[i] == dtime +@pytest.mark.v4alpha @pytest.mark.parametrize("xdim", [100, 200]) @pytest.mark.parametrize("ydim", [100, 50]) def test_fieldset_from_data_different_dimensions(xdim, ydim): @@ -134,8 +135,8 @@ def test_fieldset_from_data_different_dimensions(xdim, ydim): } fieldset = FieldSet.from_data(data, dimensions) - assert len(fieldset.U.data.shape) == 3 - assert len(fieldset.V.data.shape) == 3 + assert len(fieldset.U.data.shape) == 4 + assert len(fieldset.V.data.shape) == 4 assert len(fieldset.P.data.shape) == 4 assert fieldset.P.data.shape == (tdim, zdim, ydim / 2, xdim / 2) assert np.allclose(fieldset.U.data, 0.0, rtol=1e-12) From 71ec7e4c667936acde06272d2d665de1350c1105 Mon Sep 17 00:00:00 2001 From: Joe Date: Mon, 7 Apr 2025 14:53:36 -0400 Subject: [PATCH 26/46] Fix fieldset.from_data for test_fieldset_from_data_timedims test This requires adding a length one time dimension when dims comes in with a time dimension, but the data is a 2d array. Additionally, the dims need to be reordered to match the coords field in xarray so that `time` is first, followed by `depth` (if present), `lat`, and then `lon` --- parcels/fieldset.py | 15 ++++++++++++--- parcels/tools/_helpers.py | 3 +-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/parcels/fieldset.py b/parcels/fieldset.py index 3cd2a5a053..0353af9283 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -296,22 +296,31 @@ def from_data( lat = dims["lat"] depth = np.zeros(1, dtype=np.float32) if "depth" not in dims else dims["depth"] time = np.zeros(1, dtype=np.float64) if "time" not in dims else dims["time"] - time = np.array(time) if len(datafld.shape) == 2: - coords = [lat, lon] + if "time" in dims: + coords = [time, lat, lon] + datafld = datafld[np.newaxis, ...] 
+ dims_xr = {"time": time, "lat": lat, "lon": lon} + else: + coords = [lat, lon] + dims_xr = {"lat": lat, "lon": lon} + elif len(datafld.shape) == 3: if "time" not in dims: coords = [depth, lat, lon] + dims_xr = {"depth": depth, "lat": lat, "lon": lon} else: coords = [time, lat, lon] + dims_xr = {"time": time, "lat": lat, "lon": lon} else: coords = [time, depth, lat, lon] + dims_xr = {"time": time, "depth": depth, "lat": lat, "lon": lon} fields[name] = xr.DataArray( data=datafld, name=name, - dims=dims, + dims=dims_xr, coords=coords, attrs=dict( description="Created with fieldset.from_data", diff --git a/parcels/tools/_helpers.py b/parcels/tools/_helpers.py index 7f26273bb6..74f4bdfb33 100644 --- a/parcels/tools/_helpers.py +++ b/parcels/tools/_helpers.py @@ -72,9 +72,8 @@ def field_repr(field: Field) -> str: """Return a pretty repr for Field""" out = f"""<{type(field).__name__}> name : {field.name!r} - grid : {field.grid!r} + data : {field.data!r} extrapolate time: {field.allow_time_extrapolation!r} - gridindexingtype: {field.gridindexingtype!r} """ return textwrap.dedent(out).strip() From 06bd5446d2a2000f74d8aa4c6ac3ecc9bf7046a7 Mon Sep 17 00:00:00 2001 From: Joe Date: Mon, 7 Apr 2025 15:08:29 -0400 Subject: [PATCH 27/46] Add pytest markers for test_fieldset --- tests/test_fieldset.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index 83333b1c77..a5632bc962 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -80,6 +80,7 @@ def multifile_fieldset(tmp_path): return FieldSet.from_netcdf(files, variables, dimensions) +@pytest.mark.v4alpha @pytest.mark.parametrize("xdim", [100, 200]) @pytest.mark.parametrize("ydim", [100, 200]) def test_fieldset_from_data(xdim, ydim): @@ -101,6 +102,7 @@ def test_fieldset_vmin_vmax(): assert np.isclose(np.amax(fieldset.U.data), 7) +@pytest.mark.v4alpha @pytest.mark.parametrize("ttype", ["float", "datetime64"]) @pytest.mark.parametrize("tdim", [1, 20]) def test_fieldset_from_data_timedims(ttype, tdim): @@ -111,7 +113,8 @@ def test_fieldset_from_data_timedims(ttype, tdim): dimensions["time"] = [np.datetime64("2018-01-01") + np.timedelta64(t, "D") for t in range(tdim)] fieldset = FieldSet.from_data(data, dimensions) for i, dtime in enumerate(dimensions["time"]): - assert fieldset.U.time[i] == dtime + print(fieldset.U, dtime) + assert fieldset.U.time[i].data == dtime @pytest.mark.v4alpha @@ -144,6 +147,7 @@ def test_fieldset_from_data_different_dimensions(xdim, ydim): assert np.allclose(fieldset.P.data, 2.0, rtol=1e-12) +@pytest.mark.v4alpha def test_fieldset_from_modulefile(): nemo_fname = str(TEST_DATA / "fieldset_nemo.py") nemo_error_fname = str(TEST_DATA / "fieldset_nemo_error.py") @@ -162,6 +166,7 @@ def test_fieldset_from_modulefile(): FieldSet.from_modulefile(nemo_error_fname, modulename="none_returning_function") +@pytest.mark.v4alpha def test_field_from_netcdf_fieldtypes(): filenames = { "varU": { @@ -187,6 +192,7 @@ def test_field_from_netcdf_fieldtypes(): assert isinstance(fset.varU.units, GeographicPolar) +@pytest.mark.v4alpha def test_fieldset_from_agrid_dataset(): filenames = { "lon": str(TEST_DATA / "mask_nemo_cross_180lon.nc"), @@ -198,6 +204,7 @@ def test_fieldset_from_agrid_dataset(): FieldSet.from_a_grid_dataset(filenames, variable, dimensions) +@pytest.mark.v4remove def test_fieldset_from_cgrid_interpmethod(): filenames = { "lon": str(TEST_DATA / "mask_nemo_cross_180lon.nc"), @@ -212,6 +219,7 @@ def 
test_fieldset_from_cgrid_interpmethod(): FieldSet.from_c_grid_dataset(filenames, variable, dimensions, interp_method="partialslip") +@pytest.mark.v4alpha @pytest.mark.parametrize("calltype", ["from_data", "from_nemo"]) def test_illegal_dimensionsdict(calltype): with pytest.raises(NameError): @@ -227,6 +235,7 @@ def test_illegal_dimensionsdict(calltype): FieldSet.from_nemo(filenames, variables, dimensions) +@pytest.mark.v4alpha @pytest.mark.parametrize("xdim", [100, 200]) @pytest.mark.parametrize("ydim", [100, 200]) def test_add_field(xdim, ydim, tmpdir): @@ -237,6 +246,7 @@ def test_add_field(xdim, ydim, tmpdir): assert fieldset.newfld.data.shape == fieldset.U.data.shape +@pytest.mark.v4alpha @pytest.mark.parametrize("dupobject", ["same", "new"]) def test_add_duplicate_field(dupobject): data, dimensions = generate_fieldset_data(100, 100) @@ -251,6 +261,7 @@ def test_add_duplicate_field(dupobject): fieldset.add_field(field2) +@pytest.mark.v4alpha @pytest.mark.parametrize("fieldtype", ["normal", "vector"]) def test_add_field_after_pset(fieldtype): data, dimensions = generate_fieldset_data(100, 100) @@ -272,6 +283,7 @@ def test_fieldset_samegrids_from_file(multifile_fieldset): assert multifile_fieldset.U.grid == multifile_fieldset.V.grid +@pytest.mark.v4alpha @pytest.mark.parametrize("gridtype", ["A", "C"]) def test_fieldset_dimlength1_cgrid(gridtype): fieldset = FieldSet.from_data({"U": 0, "V": 0}, {"lon": 0, "lat": 0}) @@ -293,6 +305,7 @@ def assign_dataset_timestamp_dim(ds, timestamp): return ds +@pytest.mark.v4alpha def test_fieldset_diffgrids_from_file(tmp_path): """Test for subsetting fieldset from file using indices dict.""" stem = "test_subsets" @@ -323,6 +336,7 @@ def test_fieldset_diffgrids_from_file(tmp_path): assert fieldset.U.grid != fieldset.V.grid +@pytest.mark.v4alpha def test_fieldset_diffgrids_from_file_data(multifile_fieldset): """Test for subsetting fieldset from file using indices dict.""" data, dimensions = generate_fieldset_data(100, 100) @@ -336,6 +350,7 @@ def test_fieldset_diffgrids_from_file_data(multifile_fieldset): assert multifile_fieldset.U.grid != multifile_fieldset.B.grid +@pytest.mark.v4alpha def test_fieldset_samegrids_from_data(): """Test for subsetting fieldset from file using indices dict.""" data, dimensions = generate_fieldset_data(100, 100) @@ -351,6 +366,7 @@ def addConst(particle, fieldset, time): # pragma: no cover particle.lon = particle.lon + fieldset.movewest + fieldset.moveeast +@pytest.mark.v4alpha def test_fieldset_constant(): data, dimensions = generate_fieldset_data(100, 100) fieldset = FieldSet.from_data(data, dimensions) @@ -365,6 +381,7 @@ def test_fieldset_constant(): assert abs(pset.lon[0] - (0.5 + westval + eastval)) < 1e-4 +@pytest.mark.v4alpha @pytest.mark.parametrize("swapUV", [False, True]) def test_vector_fields(swapUV): lon = np.linspace(0.0, 10.0, 12, dtype=np.float32) @@ -388,6 +405,7 @@ def test_vector_fields(swapUV): assert abs(pset.lat[0] - 0.5) < 1e-9 +@pytest.mark.v4alpha def test_add_second_vector_field(): lon = np.linspace(0.0, 10.0, 12, dtype=np.float32) lat = np.linspace(0.0, 10.0, 10, dtype=np.float32) @@ -539,6 +557,7 @@ def sampleTemp(particle, fieldset, time): # pragma: no cover assert np.allclose(pset.d[0], 1.0) +@pytest.mark.v4alpha @pytest.mark.parametrize("tdim", [10, None]) def test_fieldset_from_xarray(tdim): def generate_dataset(xdim, ydim, zdim=1, tdim=1): @@ -593,6 +612,7 @@ def test_fieldset_frompop(): pset.execute(AdvectionRK4, runtime=3, dt=1) +@pytest.mark.v4alpha def 
test_fieldset_from_data_gridtypes(): """Simple test for fieldset initialisation from data.""" xdim, ydim, zdim = 20, 10, 4 From 7b9e9d38a4fa5ea820ce1e2654d91c7d6fa5cbeb Mon Sep 17 00:00:00 2001 From: Joe Date: Mon, 7 Apr 2025 15:11:22 -0400 Subject: [PATCH 28/46] Mark tests with `from_netcdf` for removal Data coming from netcdf should be read in with the appropriate xarray or uxarray calls. I've marked tests with `from_xarray` for v4alpha which should be sufficient for the same desired functionality. Model specific loaders I've also retained for v4alpha. --- tests/test_field.py | 3 +++ tests/test_fieldset.py | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_field.py b/tests/test_field.py index 7d7ec1ce36..989e204bb9 100644 --- a/tests/test_field.py +++ b/tests/test_field.py @@ -11,6 +11,7 @@ from tests.utils import TEST_DATA +@pytest.mark.v4alpha def test_field_from_netcdf_variables(): filename = str(TEST_DATA / "perlinfieldsU.nc") dims = {"lon": "x", "lat": "y"} @@ -30,6 +31,7 @@ def test_field_from_netcdf_variables(): f3 = Field.from_netcdf(filename, variable, dims) +@pytest.mark.v4remove def test_field_from_netcdf(): filenames = { "lon": str(TEST_DATA / "mask_nemo_cross_180lon.nc"), @@ -41,6 +43,7 @@ def test_field_from_netcdf(): Field.from_netcdf(filenames, variable, dimensions, interp_method="cgrid_velocity") +@pytest.mark.v4remove @pytest.mark.parametrize( "calendar, cftime_datetime", zip(_get_cftime_calendars(), _get_cftime_datetimes(), strict=True) ) diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index a5632bc962..d73f5f3342 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -58,6 +58,7 @@ def to_xarray_dataset(data: dict[str, np.array], dimensions: dict[str, np.array] ) +@pytest.mark.v4remove @pytest.fixture def multifile_fieldset(tmp_path): stem = "test_subsets" @@ -305,7 +306,7 @@ def assign_dataset_timestamp_dim(ds, timestamp): return ds -@pytest.mark.v4alpha +@pytest.mark.v4remove def test_fieldset_diffgrids_from_file(tmp_path): """Test for subsetting fieldset from file using indices dict.""" stem = "test_subsets" From 9b04830b2ae7705594c7bc67e86e297d5baa3597 Mon Sep 17 00:00:00 2001 From: Joe Date: Mon, 7 Apr 2025 15:14:36 -0400 Subject: [PATCH 29/46] Remove test_grids grid.py and gridset.py will be removed in a future commit since the grid information is stored as part of the xarray or uxarray dataarray/dataset --- tests/test_grids.py | 992 -------------------------------------------- 1 file changed, 992 deletions(-) delete mode 100644 tests/test_grids.py diff --git a/tests/test_grids.py b/tests/test_grids.py deleted file mode 100644 index 59b6537dca..0000000000 --- a/tests/test_grids.py +++ /dev/null @@ -1,992 +0,0 @@ -import math -from datetime import timedelta - -import numpy as np -import pytest -import xarray as xr - -from parcels import ( - AdvectionRK4, - AdvectionRK4_3D, - CurvilinearZGrid, - Field, - FieldSet, - Particle, - ParticleSet, - RectilinearSGrid, - RectilinearZGrid, - StatusCode, - UnitConverter, - Variable, -) -from tests.utils import TEST_DATA - - -def test_multi_structured_grids(): - def temp_func(lon, lat): - return 20 + lat / 1000.0 + 2 * np.sin(lon * 2 * np.pi / 5000.0) - - a = 10000 - b = 10000 - - # Grid 0 - xdim_g0 = 201 - ydim_g0 = 201 - # Coordinates of the test fieldset (on A-grid in deg) - lon_g0 = np.linspace(0, a, xdim_g0, dtype=np.float32) - lat_g0 = np.linspace(0, b, ydim_g0, dtype=np.float32) - time_g0 = np.linspace(0.0, 1000.0, 2, dtype=np.float64) - grid_0 = 
RectilinearZGrid(lon_g0, lat_g0, time=time_g0) - - # Grid 1 - xdim_g1 = 51 - ydim_g1 = 51 - # Coordinates of the test fieldset (on A-grid in deg) - lon_g1 = np.linspace(0, a, xdim_g1, dtype=np.float32) - lat_g1 = np.linspace(0, b, ydim_g1, dtype=np.float32) - time_g1 = np.linspace(0.0, 1000.0, 2, dtype=np.float64) - grid_1 = RectilinearZGrid(lon_g1, lat_g1, time=time_g1) - - u_data = np.ones((time_g0.size, lat_g0.size, lon_g0.size), dtype=np.float32) - u_data = 2 * u_data - u_field = Field("U", u_data, grid=grid_0) - - temp0_data = np.empty((time_g0.size, lat_g0.size, lon_g0.size), dtype=np.float32) - for i in range(lon_g0.size): - for j in range(lat_g0.size): - temp0_data[:, j, i] = temp_func(lon_g0[i], lat_g0[j]) - temp0_field = Field("temp0", temp0_data, grid=grid_0) - - v_data = np.zeros((time_g1.size, lat_g1.size, lon_g1.size), dtype=np.float32) - v_field = Field("V", v_data, grid=grid_1) - - temp1_data = np.empty((time_g1.size, lat_g1.size, lon_g1.size), dtype=np.float32) - for i in range(lon_g1.size): - for j in range(lat_g1.size): - temp1_data[:, j, i] = temp_func(lon_g1[i], lat_g1[j]) - temp1_field = Field("temp1", temp1_data, grid=grid_1) - - other_fields = {} - other_fields["temp0"] = temp0_field - other_fields["temp1"] = temp1_field - - fieldset = FieldSet(u_field, v_field, fields=other_fields) - - def sampleTemp(particle, fieldset, time): # pragma: no cover - # Note that fieldset.temp is interpolated at time=time+dt. - # Indeed, sampleTemp is called at time=time, but the result is written - # at time=time+dt, after the Kernel update - particle.temp0 = fieldset.temp0[time + particle.dt, particle.depth, particle.lat, particle.lon] - particle.temp1 = fieldset.temp1[time + particle.dt, particle.depth, particle.lat, particle.lon] - - MyParticle = Particle.add_variables( - [Variable("temp0", dtype=np.float32, initial=20.0), Variable("temp1", dtype=np.float32, initial=20.0)] - ) - - pset = ParticleSet.from_list(fieldset, MyParticle, lon=[3001], lat=[5001], repeatdt=1) - - pset.execute(AdvectionRK4 + pset.Kernel(sampleTemp), runtime=3, dt=1) - - # check if particle xi and yi are different for the two grids - # xi check from unraveled index - assert np.all( - [fieldset.U.unravel_index(pset[i].ei)[2] != fieldset.V.unravel_index(pset[i].ei)[2] for i in range(3)] - ) - # yi check from unraveled index - assert np.all( - [fieldset.U.unravel_index(pset[i].ei)[1] != fieldset.V.unravel_index(pset[i].ei)[1] for i in range(3)] - ) - # advect without updating temperature to test particle deletion - pset.remove_indices(np.array([1])) - pset.execute(AdvectionRK4, runtime=1, dt=1) - - assert np.all([np.isclose(p.temp0, p.temp1, atol=1e-3) for p in pset]) - - -def test_time_format_in_grid(): - lon = np.linspace(0, 1, 2, dtype=np.float32) - lat = np.linspace(0, 1, 2, dtype=np.float32) - time = np.array([np.datetime64("2000-01-01")] * 2) - with pytest.raises(AssertionError, match="Time vector"): - RectilinearZGrid(lon, lat, time=time) - - -@pytest.mark.v4remove -@pytest.mark.xfail(reason="negate_depth removed in v4") -def test_negate_depth(): - depth = np.linspace(0, 5, 10, dtype=np.float32) - fieldset = FieldSet.from_data( - {"U": np.zeros((10, 1, 1)), "V": np.zeros((10, 1, 1))}, {"lon": [0], "lat": [0], "depth": depth} - ) - assert np.all(fieldset.gridset.grids[0].depth == depth) - fieldset.U.grid.negate_depth() - assert np.all(fieldset.gridset.grids[0].depth == -depth) - - -def test_avoid_repeated_grids(): - lon_g0 = np.linspace(0, 1000, 11, dtype=np.float32) - lat_g0 = np.linspace(0, 1000, 11, 
dtype=np.float32) - time_g0 = np.linspace(0, 1000, 2, dtype=np.float64) - grid_0 = RectilinearZGrid(lon_g0, lat_g0, time=time_g0) - - lon_g1 = np.linspace(0, 1000, 21, dtype=np.float32) - lat_g1 = np.linspace(0, 1000, 21, dtype=np.float32) - time_g1 = np.linspace(0, 1000, 2, dtype=np.float64) - grid_1 = RectilinearZGrid(lon_g1, lat_g1, time=time_g1) - - u_data = np.zeros((time_g0.size, lat_g0.size, lon_g0.size), dtype=np.float32) - u_field = Field("U", u_data, grid=grid_0) - - v_data = np.zeros((time_g1.size, lat_g1.size, lon_g1.size), dtype=np.float32) - v_field = Field("V", v_data, grid=grid_1) - - temp0_field = Field("temp", u_data, lon=lon_g0, lat=lat_g0, time=time_g0) - - other_fields = {} - other_fields["temp"] = temp0_field - - fieldset = FieldSet(u_field, v_field, fields=other_fields) - assert fieldset.gridset.size == 2 - assert fieldset.U.grid is fieldset.temp.grid - assert fieldset.V.grid is not fieldset.U.grid - - -@pytest.mark.v4alpha -@pytest.mark.xfail(reason="Calls fieldset.add_periodic_halo(). Should adapt this test case.") -def test_multigrids_pointer(): - lon_g0 = np.linspace(0, 1e4, 21, dtype=np.float32) - lat_g0 = np.linspace(0, 1000, 2, dtype=np.float32) - depth_g0 = np.zeros((5, lat_g0.size, lon_g0.size), dtype=np.float32) - - def bath_func(lon): - return lon / 1000.0 + 10 - - bath = bath_func(lon_g0) - - zdim = depth_g0.shape[0] - for i in range(lon_g0.size): - for k in range(zdim): - depth_g0[k, :, i] = bath[i] * k / (zdim - 1) - - grid_0 = RectilinearSGrid(lon_g0, lat_g0, depth=depth_g0) - grid_1 = RectilinearSGrid(lon_g0, lat_g0, depth=depth_g0) - - u_data = np.zeros((zdim, lat_g0.size, lon_g0.size), dtype=np.float32) - v_data = np.zeros((zdim, lat_g0.size, lon_g0.size), dtype=np.float32) - w_data = np.zeros((zdim, lat_g0.size, lon_g0.size), dtype=np.float32) - - u_field = Field("U", u_data, grid=grid_0) - v_field = Field("V", v_data, grid=grid_0) - w_field = Field("W", w_data, grid=grid_1) - - fieldset = FieldSet(u_field, v_field, fields={"W": w_field}) - fieldset.add_periodic_halo(zonal=3, meridional=2) # unit test of halo for SGrid - - assert u_field.grid == v_field.grid - assert u_field.grid == w_field.grid # w_field.grid is now supposed to be grid_1 - - pset = ParticleSet.from_list(fieldset, Particle, lon=[0], lat=[0], depth=[1]) - - for _ in range(10): - pset.execute(AdvectionRK4_3D, runtime=1000, dt=500) - - -@pytest.mark.parametrize("z4d", ["True", "False"]) -def test_rectilinear_s_grid_sampling(z4d): - lon_g0 = np.linspace(-3e4, 3e4, 61, dtype=np.float32) - lat_g0 = np.linspace(0, 1000, 2, dtype=np.float32) - time_g0 = np.linspace(0, 1000, 2, dtype=np.float64) - if z4d: - depth_g0 = np.zeros((time_g0.size, 5, lat_g0.size, lon_g0.size), dtype=np.float32) - else: - depth_g0 = np.zeros((5, lat_g0.size, lon_g0.size), dtype=np.float32) - - def bath_func(lon): - bath = (lon <= -2e4) * 20.0 - bath += (lon > -2e4) * (lon < 2e4) * (110.0 + 90 * np.sin(lon / 2e4 * np.pi / 2.0)) - bath += (lon >= 2e4) * 200.0 - return bath - - bath = bath_func(lon_g0) - - zdim = depth_g0.shape[-3] - for i in range(depth_g0.shape[-1]): - for k in range(zdim): - if z4d: - depth_g0[:, k, :, i] = bath[i] * k / (zdim - 1) - else: - depth_g0[k, :, i] = bath[i] * k / (zdim - 1) - - grid = RectilinearSGrid(lon_g0, lat_g0, depth=depth_g0, time=time_g0) - - u_data = np.zeros((grid.tdim, grid.zdim, grid.ydim, grid.xdim), dtype=np.float32) - v_data = np.zeros((grid.tdim, grid.zdim, grid.ydim, grid.xdim), dtype=np.float32) - temp_data = np.zeros((grid.tdim, grid.zdim, grid.ydim, grid.xdim), 
dtype=np.float32) - for k in range(1, zdim): - temp_data[:, k, :, :] = k / (zdim - 1.0) - u_field = Field("U", u_data, grid=grid) - v_field = Field("V", v_data, grid=grid) - temp_field = Field("temp", temp_data, grid=grid) - - other_fields = {} - other_fields["temp"] = temp_field - fieldset = FieldSet(u_field, v_field, fields=other_fields) - - def sampleTemp(particle, fieldset, time): # pragma: no cover - particle.temp = fieldset.temp[time, particle.depth, particle.lat, particle.lon] - - MyParticle = Particle.add_variable("temp", dtype=np.float32, initial=20.0) - - lon = 400 - lat = 0 - ratio = 0.3 - pset = ParticleSet.from_list(fieldset, MyParticle, lon=[lon], lat=[lat], depth=[bath_func(lon) * ratio]) - - pset.execute(pset.Kernel(sampleTemp), runtime=1) - assert np.allclose(pset.temp[0], ratio, atol=1e-4) - - -def test_rectilinear_s_grids_advect1(): - # Constant water transport towards the east. check that the particle stays at the same relative depth (z/bath) - lon_g0 = np.linspace(0, 1e4, 21, dtype=np.float32) - lat_g0 = np.linspace(0, 1000, 2, dtype=np.float32) - depth_g0 = np.zeros((lon_g0.size, lat_g0.size, 5), dtype=np.float32) - - def bath_func(lon): - return lon / 1000.0 + 10 - - bath = bath_func(lon_g0) - - for i in range(depth_g0.shape[0]): - for k in range(depth_g0.shape[2]): - depth_g0[i, :, k] = bath[i] * k / (depth_g0.shape[2] - 1) - depth_g0 = depth_g0.transpose() - - grid = RectilinearSGrid(lon_g0, lat_g0, depth=depth_g0) - - zdim = depth_g0.shape[0] - u_data = np.zeros((zdim, lat_g0.size, lon_g0.size), dtype=np.float32) - v_data = np.zeros((zdim, lat_g0.size, lon_g0.size), dtype=np.float32) - w_data = np.zeros((zdim, lat_g0.size, lon_g0.size), dtype=np.float32) - for i in range(lon_g0.size): - u_data[:, :, i] = 1 * 10 / bath[i] - for k in range(zdim): - w_data[k, :, i] = u_data[k, :, i] * depth_g0[k, :, i] / bath[i] * 1e-3 - - u_field = Field("U", u_data, grid=grid) - v_field = Field("V", v_data, grid=grid) - w_field = Field("W", w_data, grid=grid) - - fieldset = FieldSet(u_field, v_field, fields={"W": w_field}) - - lon = np.zeros(11) - lat = np.zeros(11) - ratio = [min(i / 10.0, 0.99) for i in range(11)] - depth = bath_func(lon) * ratio - pset = ParticleSet.from_list(fieldset, Particle, lon=lon, lat=lat, depth=depth) - - pset.execute(AdvectionRK4_3D, runtime=10000, dt=500) - assert np.allclose(pset.depth / bath_func(pset.lon), ratio) - - -def test_rectilinear_s_grids_advect2(): - # Move particle towards the east, check relative depth evolution - lon_g0 = np.linspace(0, 1e4, 21, dtype=np.float32) - lat_g0 = np.linspace(0, 1000, 2, dtype=np.float32) - depth_g0 = np.zeros((5, lat_g0.size, lon_g0.size), dtype=np.float32) - - def bath_func(lon): - return lon / 1000.0 + 10 - - bath = bath_func(lon_g0) - - zdim = depth_g0.shape[0] - for i in range(lon_g0.size): - for k in range(zdim): - depth_g0[k, :, i] = bath[i] * k / (zdim - 1) - - grid = RectilinearSGrid(lon_g0, lat_g0, depth=depth_g0) - - u_data = np.zeros((zdim, lat_g0.size, lon_g0.size), dtype=np.float32) - v_data = np.zeros((zdim, lat_g0.size, lon_g0.size), dtype=np.float32) - rel_depth_data = np.zeros((zdim, lat_g0.size, lon_g0.size), dtype=np.float32) - for k in range(1, zdim): - rel_depth_data[k, :, :] = k / (zdim - 1.0) - - u_field = Field("U", u_data, grid=grid) - v_field = Field("V", v_data, grid=grid) - rel_depth_field = Field("relDepth", rel_depth_data, grid=grid) - fieldset = FieldSet(u_field, v_field, fields={"relDepth": rel_depth_field}) - - MyParticle = Particle.add_variable("relDepth", dtype=np.float32, 
initial=20.0) - - def moveEast(particle, fieldset, time): # pragma: no cover - particle_dlon += 5 * particle.dt # noqa - particle.relDepth = fieldset.relDepth[time, particle.depth, particle.lat, particle.lon] - - depth = 0.9 - pset = ParticleSet.from_list(fieldset, MyParticle, lon=[0], lat=[0], depth=[depth]) - - kernel = pset.Kernel(moveEast) - for _ in range(10): - pset.execute(kernel, runtime=100, dt=50) - assert np.allclose(pset.relDepth[0], depth / bath_func(pset.lon[0])) - - -def test_curvilinear_grids(): - x = np.linspace(0, 1e3, 7, dtype=np.float32) - y = np.linspace(0, 1e3, 5, dtype=np.float32) - (xx, yy) = np.meshgrid(x, y) - - r = np.sqrt(xx * xx + yy * yy) - theta = np.arctan2(yy, xx) - theta = theta + np.pi / 6.0 - - lon = r * np.cos(theta) - lat = r * np.sin(theta) - time = np.array([0, 86400], dtype=np.float64) - grid = CurvilinearZGrid(lon, lat, time=time) - - u_data = np.ones((2, y.size, x.size), dtype=np.float32) - v_data = np.zeros((2, y.size, x.size), dtype=np.float32) - u_data[0, :, :] = lon[:, :] + lat[:, :] - u_field = Field("U", u_data, grid=grid) - v_field = Field("V", v_data, grid=grid) - fieldset = FieldSet(u_field, v_field) - - def sampleSpeed(particle, fieldset, time): # pragma: no cover - u, v = fieldset.UV[time, particle.depth, particle.lat, particle.lon] - particle.speed = math.sqrt(u * u + v * v) - - MyParticle = Particle.add_variable("speed", dtype=np.float32, initial=0.0) - - pset = ParticleSet.from_list(fieldset, MyParticle, lon=[400, -200], lat=[600, 600]) - pset.execute(pset.Kernel(sampleSpeed), runtime=1) - assert np.allclose(pset.speed[0], 1000) - - -def test_nemo_grid(): - filenames = { - "U": { - "lon": str(TEST_DATA / "mask_nemo_cross_180lon.nc"), - "lat": str(TEST_DATA / "mask_nemo_cross_180lon.nc"), - "data": str(TEST_DATA / "Uu_eastward_nemo_cross_180lon.nc"), - }, - "V": { - "lon": str(TEST_DATA / "mask_nemo_cross_180lon.nc"), - "lat": str(TEST_DATA / "mask_nemo_cross_180lon.nc"), - "data": str(TEST_DATA / "Vv_eastward_nemo_cross_180lon.nc"), - }, - } - variables = {"U": "U", "V": "V"} - dimensions = {"lon": "glamf", "lat": "gphif"} - fieldset = FieldSet.from_nemo(filenames, variables, dimensions) - - # test ParticleSet.from_field on curvilinear grids - ParticleSet.from_field(fieldset, Particle, start_field=fieldset.U, size=5) - - def sampleVel(particle, fieldset, time): # pragma: no cover - (particle.zonal, particle.meridional) = fieldset.UV[time, particle.depth, particle.lat, particle.lon] - - MyParticle = Particle.add_variables( - [Variable("zonal", dtype=np.float32, initial=0.0), Variable("meridional", dtype=np.float32, initial=0.0)] - ) - - lonp = 175.5 - latp = 81.5 - pset = ParticleSet.from_list(fieldset, MyParticle, lon=[lonp], lat=[latp]) - pset.execute(pset.Kernel(sampleVel), runtime=1) - u = fieldset.U.units.to_source(pset.zonal[0], 0, latp, lonp) - v = fieldset.V.units.to_source(pset.meridional[0], 0, latp, lonp) - assert abs(u - 1) < 1e-4 - assert abs(v) < 1e-4 - - -def test_advect_nemo(): - filenames = { - "U": { - "lon": str(TEST_DATA / "mask_nemo_cross_180lon.nc"), - "lat": str(TEST_DATA / "mask_nemo_cross_180lon.nc"), - "data": str(TEST_DATA / "Uu_eastward_nemo_cross_180lon.nc"), - }, - "V": { - "lon": str(TEST_DATA / "mask_nemo_cross_180lon.nc"), - "lat": str(TEST_DATA / "mask_nemo_cross_180lon.nc"), - "data": str(TEST_DATA / "Vv_eastward_nemo_cross_180lon.nc"), - }, - } - variables = {"U": "U", "V": "V"} - dimensions = {"lon": "glamf", "lat": "gphif"} - fieldset = FieldSet.from_nemo(filenames, variables, dimensions) - - lonp 
= 175.5 - latp = 81.5 - pset = ParticleSet.from_list(fieldset, Particle, lon=[lonp], lat=[latp]) - pset.execute(AdvectionRK4, runtime=timedelta(days=2), dt=timedelta(hours=6)) - assert abs(pset.lat[0] - latp) < 1e-3 - - -@pytest.mark.parametrize("time", [True, False]) -def test_cgrid_uniform_2dvel(time): - lon = np.array([[0, 2], [0.4, 1.5]]) - lat = np.array([[0, -0.5], [0.8, 0.5]]) - U = np.array([[-99, -99], [4.4721359549995793e-01, 1.3416407864998738e00]]) - V = np.array([[-99, 1.2126781251816650e00], [-99, 1.2278812270298409e00]]) - - if time: - U = np.stack((U, U)) - V = np.stack((V, V)) - dimensions = {"lat": lat, "lon": lon, "time": np.array([0, 10])} - else: - dimensions = {"lat": lat, "lon": lon} - data = {"U": np.array(U, dtype=np.float32), "V": np.array(V, dtype=np.float32)} - fieldset = FieldSet.from_data(data, dimensions, mesh="flat") - fieldset.U.interp_method = "cgrid_velocity" - fieldset.V.interp_method = "cgrid_velocity" - - def sampleVel(particle, fieldset, time): # pragma: no cover - (particle.zonal, particle.meridional) = fieldset.UV[time, particle.depth, particle.lat, particle.lon] - - MyParticle = Particle.add_variables( - [Variable("zonal", dtype=np.float32, initial=0.0), Variable("meridional", dtype=np.float32, initial=0.0)] - ) - - pset = ParticleSet.from_list(fieldset, MyParticle, lon=0.7, lat=0.3) - pset.execute(pset.Kernel(sampleVel), runtime=1) - assert (pset[0].zonal - 1) < 1e-6 - assert (pset[0].meridional - 1) < 1e-6 - - -@pytest.mark.v4alpha -@pytest.mark.parametrize("vert_mode", ["zlev"]) # , "slev1", "slev2"]) # TODO v4: time-varying depth not supported yet -@pytest.mark.parametrize("time", [True, False]) -def test_cgrid_uniform_3dvel(vert_mode, time): - lon = np.array([[0, 2], [0.4, 1.5]]) - lat = np.array([[0, -0.5], [0.8, 0.5]]) - - u0 = 4.4721359549995793e-01 - u1 = 1.3416407864998738e00 - v0 = 1.2126781251816650e00 - v1 = 1.2278812270298409e00 - w0 = 1 - w1 = 1 - - if vert_mode == "zlev": - depth = np.array([0, 1]) - elif vert_mode == "slev1": - depth = np.array([[[0, 0], [0, 0]], [[1, 1], [1, 1]]]) - elif vert_mode == "slev2": - depth = np.array([[[-1, -0.6], [-1.1257142857142859, -0.9]], [[1, 1.5], [0.50857142857142845, 0.8]]]) - w0 = 1.0483007922296661e00 - w1 = 1.3098951476312375e00 - - U = np.array([[[-99, -99], [u0, u1]], [[-99, -99], [-99, -99]]]) - V = np.array([[[-99, v0], [-99, v1]], [[-99, -99], [-99, -99]]]) - W = np.array([[[-99, -99], [-99, w0]], [[-99, -99], [-99, w1]]]) - - if time: - U = np.stack((U, U)) - V = np.stack((V, V)) - W = np.stack((W, W)) - dimensions = {"lat": lat, "lon": lon, "depth": depth, "time": np.array([0, 10])} - else: - dimensions = {"lat": lat, "lon": lon, "depth": depth} - data = {"U": np.array(U, dtype=np.float32), "V": np.array(V, dtype=np.float32), "W": np.array(W, dtype=np.float32)} - fieldset = FieldSet.from_data(data, dimensions, mesh="flat") - fieldset.U.interp_method = "cgrid_velocity" - fieldset.V.interp_method = "cgrid_velocity" - fieldset.W.interp_method = "cgrid_velocity" - - def sampleVel(particle, fieldset, time): # pragma: no cover - (particle.zonal, particle.meridional, particle.vertical) = fieldset.UVW[ - time, particle.depth, particle.lat, particle.lon - ] - - MyParticle = Particle.add_variables( - [ - Variable("zonal", dtype=np.float32, initial=0.0), - Variable("meridional", dtype=np.float32, initial=0.0), - Variable("vertical", dtype=np.float32, initial=0.0), - ] - ) - - pset = ParticleSet.from_list(fieldset, MyParticle, lon=0.7, lat=0.3, depth=0.2) - pset.execute(pset.Kernel(sampleVel), 
runtime=1) - assert abs(pset[0].zonal - 1) < 1e-6 - assert abs(pset[0].meridional - 1) < 1e-6 - assert abs(pset[0].vertical - 1) < 1e-6 - - -@pytest.mark.parametrize("vert_mode", ["zlev", "slev1"]) -@pytest.mark.parametrize("time", [True, False]) -def test_cgrid_uniform_3dvel_spherical(vert_mode, time): - dim_file = xr.open_dataset(TEST_DATA / "mask_nemo_cross_180lon.nc") - u_file = xr.open_dataset(TEST_DATA / "Uu_eastward_nemo_cross_180lon.nc") - v_file = xr.open_dataset(TEST_DATA / "Vv_eastward_nemo_cross_180lon.nc") - j = 4 - i = 11 - lon = np.array(dim_file.glamf[0, j : j + 2, i : i + 2]) - lat = np.array(dim_file.gphif[0, j : j + 2, i : i + 2]) - U = np.array(u_file.U[0, j : j + 2, i : i + 2]) - V = np.array(v_file.V[0, j : j + 2, i : i + 2]) - trash = np.zeros((2, 2)) - U = np.stack((U, trash)) - V = np.stack((V, trash)) - w0 = 1 - w1 = 1 - W = np.array([[[-99, -99], [-99, w0]], [[-99, -99], [-99, w1]]]) - - if vert_mode == "zlev": - depth = np.array([0, 1]) - elif vert_mode == "slev1": - depth = np.array([[[0, 0], [0, 0]], [[1, 1], [1, 1]]]) - - if time: - U = np.stack((U, U)) - V = np.stack((V, V)) - W = np.stack((W, W)) - dimensions = {"lat": lat, "lon": lon, "depth": depth, "time": np.array([0, 10])} - else: - dimensions = {"lat": lat, "lon": lon, "depth": depth} - data = {"U": np.array(U, dtype=np.float32), "V": np.array(V, dtype=np.float32), "W": np.array(W, dtype=np.float32)} - fieldset = FieldSet.from_data(data, dimensions, mesh="spherical") - fieldset.U.interp_method = "cgrid_velocity" - fieldset.V.interp_method = "cgrid_velocity" - fieldset.W.interp_method = "cgrid_velocity" - - def sampleVel(particle, fieldset, time): # pragma: no cover - (particle.zonal, particle.meridional, particle.vertical) = fieldset.UVW[ - time, particle.depth, particle.lat, particle.lon - ] - - MyParticle = Particle.add_variables( - [ - Variable("zonal", dtype=np.float32, initial=0.0), - Variable("meridional", dtype=np.float32, initial=0.0), - Variable("vertical", dtype=np.float32, initial=0.0), - ] - ) - - lonp = 179.8 - latp = 81.35 - pset = ParticleSet.from_list(fieldset, MyParticle, lon=lonp, lat=latp, depth=0.2) - pset.execute(pset.Kernel(sampleVel), runtime=1) - pset.zonal[0] = fieldset.U.units.to_source(pset.zonal[0], 0, latp, lonp) - pset.meridional[0] = fieldset.V.units.to_source(pset.meridional[0], 0, latp, lonp) - assert abs(pset[0].zonal - 1) < 1e-3 - assert abs(pset[0].meridional) < 1e-3 - assert abs(pset[0].vertical - 1) < 1e-3 - - -@pytest.mark.v4alpha -@pytest.mark.xfail(reason="From_pop is not supported during v4-alpha development. 
This will be reconsidered in v4.") -@pytest.mark.parametrize("vert_discretisation", ["zlevel", "slevel", "slevel2"]) -def test_popgrid(vert_discretisation): - if vert_discretisation == "zlevel": - w_dep = "w_dep" - elif vert_discretisation == "slevel": - w_dep = "w_deps" # same as zlevel, but defined as slevel - elif vert_discretisation == "slevel2": - w_dep = "w_deps2" # contains shaved cells - - filenames = str(TEST_DATA / "POPtestdata_time.nc") - variables = {"U": "U", "V": "V", "W": "W", "T": "T"} - dimensions = {"lon": "lon", "lat": "lat", "depth": w_dep, "time": "time"} - - fieldset = FieldSet.from_pop(filenames, variables, dimensions, mesh="flat") - - def sampleVel(particle, fieldset, time): # pragma: no cover - (particle.zonal, particle.meridional, particle.vert) = fieldset.UVW[particle] - particle.tracer = fieldset.T[particle] - - def OutBoundsError(particle, fieldset, time): # pragma: no cover - if particle.state == StatusCode.ErrorOutOfBounds: - particle.out_of_bounds = 1 - particle_ddepth -= 3 # noqa - particle.state = StatusCode.Success - - MyParticle = Particle.add_variables( - [ - Variable("zonal", dtype=np.float32, initial=0.0), - Variable("meridional", dtype=np.float32, initial=0.0), - Variable("vert", dtype=np.float32, initial=0.0), - Variable("tracer", dtype=np.float32, initial=0.0), - Variable("out_of_bounds", dtype=np.float32, initial=0.0), - ] - ) - - pset = ParticleSet.from_list(fieldset, MyParticle, lon=[3, 5, 1], lat=[3, 5, 1], depth=[3, 7, 11]) - pset.execute(pset.Kernel(sampleVel) + OutBoundsError, runtime=1) - if vert_discretisation == "slevel2": - assert np.isclose(pset.vert[0], 0.0) - assert np.isclose(pset.zonal[0], 0.0) - assert np.isclose(pset.tracer[0], 99.0) - assert np.isclose(pset.vert[1], -0.0066666666) - assert np.isclose(pset.zonal[1], 0.015) - assert np.isclose(pset.tracer[1], 1.0) - assert pset.out_of_bounds[0] == 0 - assert pset.out_of_bounds[1] == 0 - assert pset.out_of_bounds[2] == 1 - else: - assert np.allclose(pset.zonal, 0.015) - assert np.allclose(pset.meridional, 0.01) - assert np.allclose(pset.vert, -0.01) - assert np.allclose(pset.tracer, 1) - - -@pytest.mark.parametrize("gridindexingtype", ["mitgcm", "nemo"]) -@pytest.mark.parametrize("coordtype", ["rectilinear", "curvilinear"]) -def test_cgrid_indexing(gridindexingtype, coordtype): - xdim, ydim = 151, 201 - a = b = 20000 # domain size - lon = np.linspace(-a / 2, a / 2, xdim, dtype=np.float32) - lat = np.linspace(-b / 2, b / 2, ydim, dtype=np.float32) - dx, dy = lon[2] - lon[1], lat[2] - lat[1] - omega = 2 * np.pi / timedelta(days=1).total_seconds() - - index_signs = {"nemo": -1, "mitgcm": 1} - isign = index_signs[gridindexingtype] - - def rotate_coords(lon, lat, alpha=0): - rotmat = np.array([[np.cos(alpha), np.sin(alpha)], [-np.sin(alpha), np.cos(alpha)]]) - lons, lats = np.meshgrid(lon, lat) - rotated = np.einsum("ji, mni -> jmn", rotmat, np.dstack([lons, lats])) - return rotated[0], rotated[1] - - if coordtype == "rectilinear": - alpha = 0 - elif coordtype == "curvilinear": - alpha = 15 * np.pi / 180 - lon, lat = rotate_coords(lon, lat, alpha) - - def calc_r_phi(ln, lt): - return np.sqrt(ln**2 + lt**2), np.arctan2(ln, lt) - - if coordtype == "rectilinear": - - def calculate_UVR(lat, lon, dx, dy, omega, alpha): - U = np.zeros((lat.size, lon.size), dtype=np.float32) - V = np.zeros((lat.size, lon.size), dtype=np.float32) - R = np.zeros((lat.size, lon.size), dtype=np.float32) - for i in range(lon.size): - for j in range(lat.size): - r, phi = calc_r_phi(lon[i], lat[j]) - R[j, i] = r - r, 
phi = calc_r_phi(lon[i] + isign * dx / 2, lat[j]) - V[j, i] = -omega * r * np.sin(phi) - r, phi = calc_r_phi(lon[i], lat[j] + isign * dy / 2) - U[j, i] = omega * r * np.cos(phi) - return U, V, R - elif coordtype == "curvilinear": - - def calculate_UVR(lat, lon, dx, dy, omega, alpha): - U = np.zeros(lat.shape, dtype=np.float32) - V = np.zeros(lat.shape, dtype=np.float32) - R = np.zeros(lat.shape, dtype=np.float32) - for i in range(lat.shape[1]): - for j in range(lat.shape[0]): - r, phi = calc_r_phi(lon[j, i], lat[j, i]) - R[j, i] = r - r, phi = calc_r_phi( - lon[j, i] + isign * (dx / 2) * np.cos(alpha), lat[j, i] - isign * (dx / 2) * np.sin(alpha) - ) - V[j, i] = np.sin(alpha) * (omega * r * np.cos(phi)) + np.cos(alpha) * (-omega * r * np.sin(phi)) - r, phi = calc_r_phi( - lon[j, i] + isign * (dy / 2) * np.sin(alpha), lat[j, i] + isign * (dy / 2) * np.cos(alpha) - ) - U[j, i] = np.cos(alpha) * (omega * r * np.cos(phi)) - np.sin(alpha) * (-omega * r * np.sin(phi)) - return U, V, R - - U, V, R = calculate_UVR(lat, lon, dx, dy, omega, alpha) - - data = {"U": U, "V": V, "R": R} - dimensions = {"lon": lon, "lat": lat} - fieldset = FieldSet.from_data(data, dimensions, mesh="flat", gridindexingtype=gridindexingtype) - fieldset.U.interp_method = "cgrid_velocity" - fieldset.V.interp_method = "cgrid_velocity" - - def UpdateR(particle, fieldset, time): # pragma: no cover - if time == 0: - particle.radius_start = fieldset.R[time, particle.depth, particle.lat, particle.lon] - particle.radius = fieldset.R[time, particle.depth, particle.lat, particle.lon] - - MyParticle = Particle.add_variables( - [Variable("radius", dtype=np.float32, initial=0.0), Variable("radius_start", dtype=np.float32, initial=0.0)] - ) - - pset = ParticleSet(fieldset, pclass=MyParticle, lon=0, lat=4e3, time=0) - - pset.execute(pset.Kernel(UpdateR) + AdvectionRK4, runtime=timedelta(hours=14), dt=timedelta(minutes=5)) - assert np.allclose(pset.radius, pset.radius_start, atol=10) - - -@pytest.mark.parametrize("gridindexingtype", ["mitgcm", "nemo"]) -@pytest.mark.parametrize("withtime", [False, True]) -def test_cgrid_indexing_3D(gridindexingtype, withtime): - xdim = zdim = 201 - ydim = 2 - a = c = 20000 # domain size - b = 2 - lon = np.linspace(-a / 2, a / 2, xdim, dtype=np.float32) - lat = np.linspace(-b / 2, b / 2, ydim, dtype=np.float32) - depth = np.linspace(-c / 2, c / 2, zdim, dtype=np.float32) - dx, dz = lon[1] - lon[0], depth[1] - depth[0] - omega = 2 * np.pi / timedelta(days=1).total_seconds() - if withtime: - time = np.linspace(0, 24 * 60 * 60, 10) - dimensions = {"lon": lon, "lat": lat, "depth": depth, "time": time} - dsize = (time.size, depth.size, lat.size, lon.size) - else: - dimensions = {"lon": lon, "lat": lat, "depth": depth} - dsize = (depth.size, lat.size, lon.size) - - hindex_signs = {"nemo": -1, "mitgcm": 1} - hsign = hindex_signs[gridindexingtype] - - def calc_r_phi(ln, dp): - # r = np.sqrt(ln ** 2 + dp ** 2) - # phi = np.arcsin(dp/r) if r > 0 else 0 - return np.sqrt(ln**2 + dp**2), np.arctan2(ln, dp) - - def populate_UVWR(lat, lon, depth, dx, dz, omega): - U = np.zeros(dsize, dtype=np.float32) - V = np.zeros(dsize, dtype=np.float32) - W = np.zeros(dsize, dtype=np.float32) - R = np.zeros(dsize, dtype=np.float32) - - for i in range(lon.size): - for j in range(lat.size): - for k in range(depth.size): - r, phi = calc_r_phi(lon[i], depth[k]) - if withtime: - R[:, k, j, i] = r - else: - R[k, j, i] = r - r, phi = calc_r_phi(lon[i] + hsign * dx / 2, depth[k]) - if withtime: - W[:, k, j, i] = -omega * r * np.sin(phi) - 
else: - W[k, j, i] = -omega * r * np.sin(phi) - r, phi = calc_r_phi(lon[i], depth[k] + dz / 2) - if withtime: - U[:, k, j, i] = omega * r * np.cos(phi) - else: - U[k, j, i] = omega * r * np.cos(phi) - return U, V, W, R - - U, V, W, R = populate_UVWR(lat, lon, depth, dx, dz, omega) - data = {"U": U, "V": V, "W": W, "R": R} - fieldset = FieldSet.from_data(data, dimensions, mesh="flat", gridindexingtype=gridindexingtype) - fieldset.U.interp_method = "cgrid_velocity" - fieldset.V.interp_method = "cgrid_velocity" - fieldset.W.interp_method = "cgrid_velocity" - - def UpdateR(particle, fieldset, time): # pragma: no cover - if time == 0: - particle.radius_start = fieldset.R[time, particle.depth, particle.lat, particle.lon] - particle.radius = fieldset.R[time, particle.depth, particle.lat, particle.lon] - - MyParticle = Particle.add_variables( - [Variable("radius", dtype=np.float32, initial=0.0), Variable("radius_start", dtype=np.float32, initial=0.0)] - ) - - pset = ParticleSet(fieldset, pclass=MyParticle, depth=4e3, lon=0, lat=0, time=0) - - pset.execute(pset.Kernel(UpdateR) + AdvectionRK4_3D, runtime=timedelta(hours=14), dt=timedelta(minutes=5)) - assert np.allclose(pset.radius, pset.radius_start, atol=10) - - -@pytest.mark.parametrize("gridindexingtype", ["pop", "mom5"]) -@pytest.mark.parametrize("withtime", [False, True]) -def test_bgrid_indexing_3D(gridindexingtype, withtime): - xdim = zdim = 201 - ydim = 2 - a = c = 20000 # domain size - b = 2 - lon = np.linspace(-a / 2, a / 2, xdim, dtype=np.float32) - lat = np.linspace(-b / 2, b / 2, ydim, dtype=np.float32) - depth = np.linspace(-c / 2, c / 2, zdim, dtype=np.float32) - dx, dz = lon[1] - lon[0], depth[1] - depth[0] - omega = 2 * np.pi / timedelta(days=1).total_seconds() - if withtime: - time = np.linspace(0, 24 * 60 * 60, 10) - dimensions = {"lon": lon, "lat": lat, "depth": depth, "time": time} - dsize = (time.size, depth.size, lat.size, lon.size) - else: - dimensions = {"lon": lon, "lat": lat, "depth": depth} - dsize = (depth.size, lat.size, lon.size) - - vindex_signs = {"pop": 1, "mom5": -1} - vsign = vindex_signs[gridindexingtype] - - def calc_r_phi(ln, dp): - return np.sqrt(ln**2 + dp**2), np.arctan2(ln, dp) - - def populate_UVWR(lat, lon, depth, dx, dz, omega): - U = np.zeros(dsize, dtype=np.float32) - V = np.zeros(dsize, dtype=np.float32) - W = np.zeros(dsize, dtype=np.float32) - R = np.zeros(dsize, dtype=np.float32) - - for i in range(lon.size): - for j in range(lat.size): - for k in range(depth.size): - r, phi = calc_r_phi(lon[i], depth[k]) - if withtime: - R[:, k, j, i] = r - else: - R[k, j, i] = r - r, phi = calc_r_phi(lon[i] - dx / 2, depth[k]) - if withtime: - W[:, k, j, i] = -omega * r * np.sin(phi) - else: - W[k, j, i] = -omega * r * np.sin(phi) - # Since Parcels loads as dimensions only the depth of W-points - # and lon/lat of UV-points, W-points are similarly interpolated - # in MOM5 and POP. Indexing is shifted for UV-points. 
- r, phi = calc_r_phi(lon[i], depth[k] + vsign * dz / 2) - if withtime: - U[:, k, j, i] = omega * r * np.cos(phi) - else: - U[k, j, i] = omega * r * np.cos(phi) - return U, V, W, R - - U, V, W, R = populate_UVWR(lat, lon, depth, dx, dz, omega) - data = {"U": U, "V": V, "W": W, "R": R} - fieldset = FieldSet.from_data(data, dimensions, mesh="flat", gridindexingtype=gridindexingtype) - fieldset.U.interp_method = "bgrid_velocity" - fieldset.V.interp_method = "bgrid_velocity" - fieldset.W.interp_method = "bgrid_w_velocity" - - def UpdateR(particle, fieldset, time): # pragma: no cover - if time == 0: - particle.radius_start = fieldset.R[time, particle.depth, particle.lat, particle.lon] - particle.radius = fieldset.R[time, particle.depth, particle.lat, particle.lon] - - MyParticle = Particle.add_variables( - [Variable("radius", dtype=np.float32, initial=0.0), Variable("radius_start", dtype=np.float32, initial=0.0)] - ) - - pset = ParticleSet(fieldset, pclass=MyParticle, depth=-9.995e3, lon=0, lat=0, time=0) - - pset.execute(pset.Kernel(UpdateR) + AdvectionRK4_3D, runtime=timedelta(hours=14), dt=timedelta(minutes=5)) - assert np.allclose(pset.radius, pset.radius_start, atol=10) - - -@pytest.mark.parametrize( - "gridindexingtype", - [ - pytest.param( - "mom5", - marks=[ - pytest.mark.v4alpha, - pytest.mark.xfail(reason="https://github.com/OceanParcels/Parcels/pull/1936#issuecomment-2717408483"), - ], - ) - ], -) # TODO v4: add pop in params? -@pytest.mark.parametrize("extrapolation", [True, False]) -def test_bgrid_interpolation(gridindexingtype, extrapolation): - xi, yi = 3, 2 - if extrapolation: - zi = 0 if gridindexingtype == "mom5" else -1 - else: - zi = 2 - if gridindexingtype == "mom5": - ufile = str(TEST_DATA / "access-om2-01_u.nc") - vfile = str(TEST_DATA / "access-om2-01_v.nc") - wfile = str(TEST_DATA / "access-om2-01_wt.nc") - - filenames = { - "U": {"lon": ufile, "lat": ufile, "depth": wfile, "data": ufile}, - "V": {"lon": ufile, "lat": ufile, "depth": wfile, "data": vfile}, - "W": {"lon": ufile, "lat": ufile, "depth": wfile, "data": wfile}, - } - - variables = {"U": "u", "V": "v", "W": "wt"} - - dimensions = { - "U": {"lon": "xu_ocean", "lat": "yu_ocean", "depth": "sw_ocean", "time": "time"}, - "V": {"lon": "xu_ocean", "lat": "yu_ocean", "depth": "sw_ocean", "time": "time"}, - "W": {"lon": "xu_ocean", "lat": "yu_ocean", "depth": "sw_ocean", "time": "time"}, - } - - fieldset = FieldSet.from_mom5(filenames, variables, dimensions) - ds_u = xr.open_dataset(ufile) - ds_v = xr.open_dataset(vfile) - ds_w = xr.open_dataset(wfile) - u = ds_u.u.isel(time=0, st_ocean=zi, yu_ocean=yi, xu_ocean=xi) - v = ds_v.v.isel(time=0, st_ocean=zi, yu_ocean=yi, xu_ocean=xi) - w = ds_w.wt.isel(time=0, sw_ocean=zi, yt_ocean=yi, xt_ocean=xi) - - elif gridindexingtype == "pop": - datafname = str(TEST_DATA / "popdata.nc") - coordfname = str(TEST_DATA / "popcoordinates.nc") - filenames = { - "U": {"lon": coordfname, "lat": coordfname, "depth": coordfname, "data": datafname}, - "V": {"lon": coordfname, "lat": coordfname, "depth": coordfname, "data": datafname}, - "W": {"lon": coordfname, "lat": coordfname, "depth": coordfname, "data": datafname}, - } - - variables = {"U": "UVEL", "V": "VVEL", "W": "WVEL"} - dimensions = {"lon": "U_LON_2D", "lat": "U_LAT_2D", "depth": "w_dep"} - - fieldset = FieldSet.from_pop(filenames, variables, dimensions) - dsc = xr.open_dataset(coordfname) - dsd = xr.open_dataset(datafname) - u = dsd.UVEL.isel(k=zi, j=yi, i=xi) - v = dsd.VVEL.isel(k=zi, j=yi, i=xi) - w = dsd.WVEL.isel(k=zi, j=yi, 
i=xi) - - fieldset.U.units = UnitConverter() - fieldset.V.units = UnitConverter() - - def VelocityInterpolator(particle, fieldset, time): # pragma: no cover - particle.Uvel = fieldset.U[time, particle.depth, particle.lat, particle.lon] - particle.Vvel = fieldset.V[time, particle.depth, particle.lat, particle.lon] - particle.Wvel = fieldset.W[time, particle.depth, particle.lat, particle.lon] - - myParticle = Particle.add_variables( - [ - Variable("Uvel", dtype=np.float32, initial=0.0), - Variable("Vvel", dtype=np.float32, initial=0.0), - Variable("Wvel", dtype=np.float32, initial=0.0), - ] - ) - - for pointtype in ["U", "V", "W"]: - if gridindexingtype == "pop": - if pointtype in ["U", "V"]: - lons = dsc.U_LON_2D[yi, xi].values - lats = dsc.U_LAT_2D[yi, xi].values - deps = dsc.depth_t[zi].values - elif pointtype == "W": - lons = dsc.T_LON_2D[yi, xi].values - lats = dsc.T_LAT_2D[yi, xi].values - deps = dsc.w_dep[zi].values - if extrapolation: - deps = 5499.0 - elif gridindexingtype == "mom5": - if pointtype in ["U", "V"]: - lons = u.xu_ocean.data.reshape(1) - lats = u.yu_ocean.data.reshape(1) - deps = u.st_ocean.data.reshape(1) - elif pointtype == "W": - lons = w.xt_ocean.data.reshape(1) - lats = w.yt_ocean.data.reshape(1) - deps = w.sw_ocean.data.reshape(1) - if extrapolation: - deps = 0 - - pset = ParticleSet.from_list(fieldset=fieldset, pclass=myParticle, lon=lons, lat=lats, depth=deps) - pset.execute(VelocityInterpolator, runtime=1) - - convfactor = 0.01 if gridindexingtype == "pop" else 1.0 - if pointtype in ["U", "V"]: - assert np.allclose(pset.Uvel[0], u * convfactor) - assert np.allclose(pset.Vvel[0], v * convfactor) - elif pointtype == "W": - if extrapolation: - assert np.allclose(pset.Wvel[0], 0, atol=1e-9) - else: - assert np.allclose(pset.Wvel[0], w * convfactor) From 0357602a9e9f4fcb4ef3dec026aba66161fa66ec Mon Sep 17 00:00:00 2001 From: Joe Date: Mon, 7 Apr 2025 15:15:38 -0400 Subject: [PATCH 30/46] Mark interpolation tests for removal Following a recent meeting of the Parcels team, the _interpolation.py module will likely go away in favor of a `Callable` attribute for the interpolation method.
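For illustration, the kind of pattern this points toward (a sketch only, not the final API; the parameter list mirrors the draft interpolator template, and `fieldset` here is assumed to be an existing FieldSet):

    def piecewise_constant(field, ti, zi, ei, t, z, y, x):
        # Hypothetical interpolator: return the raw value stored at the
        # enclosing element, with no lateral or vertical weighting.
        return field.data.values[ti, zi, ei]

    # The field holds the callable directly instead of a registry key:
    fieldset.U.interp_method = piecewise_constant

Any callable with a matching signature could then be swapped in at runtime, which removes the need for a central interpolator registry.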
--- tests/test_interpolation.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index 9b76481598..b49cb38cf3 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -56,6 +56,7 @@ def data_2d(): return create_interpolation_data().isel(depth=0).values +@pytest.mark.v4remove @pytest.mark.parametrize( "func, eta, xsi, expected", [ @@ -74,6 +75,7 @@ def test_raw_2d_interpolation(data_2d, func, eta, xsi, expected): assert func(ctx) == expected +@pytest.mark.v4remove @pytest.mark.usefixtures("tmp_interpolator_registry") def test_interpolator_override(): fieldset = create_fieldset_zeros_3d() @@ -86,6 +88,7 @@ def test_interpolator(ctx: interpolation.InterpolationContext3D): fieldset.U[0, 0.5, 0.5, 0.5] +@pytest.mark.v4remove @pytest.mark.usefixtures("tmp_interpolator_registry") def test_full_depth_provided_to_interpolators(): """The full depth needs to be provided to the interpolation schemes as some interpolators @@ -105,6 +108,7 @@ def test_interpolator2(ctx: interpolation.InterpolationContext3D): fieldset.U[0.5, 0.5, 0.5, 0.5] +@pytest.mark.v4remove @pytest.mark.parametrize( "interp_method", [ From f3903e4be16dd05dd363c2ab9434ef318fcbbd78 Mon Sep 17 00:00:00 2001 From: Joe Date: Mon, 7 Apr 2025 15:21:13 -0400 Subject: [PATCH 31/46] Mark tests for v4alpha --- tests/test_advection.py | 11 +++++++++++ tests/test_diffusion.py | 4 ++++ 2 files changed, 15 insertions(+) diff --git a/tests/test_advection.py b/tests/test_advection.py index a62fe878dd..92f812f524 100644 --- a/tests/test_advection.py +++ b/tests/test_advection.py @@ -57,6 +57,7 @@ def depth(): return np.linspace(0, 30, zdim, dtype=np.float32) +@pytest.mark.v4alpha def test_advection_zonal(lon, lat, depth): """Particles at high latitude move geographically faster due to the pole correction in `GeographicPolar`.""" npart = 10 @@ -88,6 +89,7 @@ def test_advection_zonal(lon, lat, depth): assert (np.diff(pset3D.lon) > 1.0e-4).all() +@pytest.mark.v4alpha def test_advection_meridional(lon, lat): """Particles at high latitude move geographically faster due to the pole correction in `GeographicPolar`.""" npart = 10 @@ -101,6 +103,7 @@ def test_advection_meridional(lon, lat): assert np.allclose(np.diff(pset.lat), delta_lat, rtol=1.0e-4) +@pytest.mark.v4alpha def test_advection_3D(): """Flat 2D zonal flow that increases linearly with depth from 0 m/s to 1 m/s.""" xdim = ydim = zdim = 2 @@ -122,6 +125,7 @@ def test_advection_3D(): assert np.allclose(pset.depth * pset.time, pset.lon, atol=1.0e-1) +@pytest.mark.v4alpha @pytest.mark.parametrize("direction", ["up", "down"]) @pytest.mark.parametrize("wErrorThroughSurface", [True, False]) def test_advection_3D_outofbounds(direction, wErrorThroughSurface): @@ -167,6 +171,7 @@ def SubmergeParticle(particle, fieldset, time): # pragma: no cover assert len(pset) == 0 +@pytest.mark.v4alpha @pytest.mark.parametrize("rk45_tol", [10, 100]) def test_advection_RK45(lon, lat, rk45_tol): npart = 10 @@ -324,6 +329,7 @@ def test_advection_periodic_zonal_meridional(): assert abs(pset.lat[0] - 0.15) < 0.1 +@pytest.mark.v4alpha @pytest.mark.parametrize("u", [-0.3, np.array(0.2)]) @pytest.mark.parametrize("v", [0.2, np.array(1)]) @pytest.mark.parametrize("w", [None, -0.2, np.array(0.7)]) @@ -401,6 +407,7 @@ def fieldset_stationary(): return create_fieldset_stationary() +@pytest.mark.v4alpha @pytest.mark.parametrize( "method, rtol, diffField", [ @@ -435,6 +442,7 @@ def test_stationary_eddy(fieldset_stationary, method, rtol, diffField): 
assert np.allclose(pset.lat, exp_lat, rtol=rtol) +@pytest.mark.v4alpha def test_stationary_eddy_vertical(): npart = 1 lon = np.linspace(12000, 21000, npart) @@ -505,6 +513,7 @@ def fieldset_moving(): return create_fieldset_moving() +@pytest.mark.v4alpha @pytest.mark.parametrize( "method, rtol, diffField", [ @@ -582,6 +591,7 @@ def fieldset_decaying(): return create_fieldset_decaying() +@pytest.mark.v4alpha @pytest.mark.parametrize( "method, rtol, diffField", [ @@ -622,6 +632,7 @@ def test_decaying_eddy(fieldset_decaying, method, rtol, diffField): assert np.allclose(pset.lat, exp_lat, rtol=rtol) +@pytest.mark.v4alpha def test_analyticalAgrid(): lon = np.arange(0, 15, dtype=np.float32) lat = np.arange(0, 15, dtype=np.float32) diff --git a/tests/test_diffusion.py b/tests/test_diffusion.py index 389174c759..267597d6ea 100644 --- a/tests/test_diffusion.py +++ b/tests/test_diffusion.py @@ -17,6 +17,7 @@ from tests.utils import create_fieldset_zeros_conversion +@pytest.mark.v4alpha @pytest.mark.parametrize("mesh", ["spherical", "flat"]) def test_fieldKh_Brownian(mesh): xdim = 200 @@ -50,6 +51,7 @@ def test_fieldKh_Brownian(mesh): assert np.allclose(np.mean(lats), 0, atol=tol) +@pytest.mark.v4alpha @pytest.mark.parametrize("mesh", ["spherical", "flat"]) @pytest.mark.parametrize("kernel", [AdvectionDiffusionM1, AdvectionDiffusionEM]) def test_fieldKh_SpatiallyVaryingDiffusion(mesh, kernel): @@ -83,6 +85,7 @@ def test_fieldKh_SpatiallyVaryingDiffusion(mesh, kernel): assert stats.skew(lons) > stats.skew(lats) +@pytest.mark.v4alpha @pytest.mark.parametrize("lambd", [1, 5]) def test_randomexponential(lambd): fieldset = create_fieldset_zeros_conversion() @@ -109,6 +112,7 @@ def vertical_randomexponential(particle, fieldset, time): # pragma: no cover assert np.allclose(np.mean(depth), expected_mean, rtol=0.1) +@pytest.mark.v4alpha @pytest.mark.parametrize("mu", [0.8 * np.pi, np.pi]) @pytest.mark.parametrize("kappa", [2, 4]) def test_randomvonmises(mu, kappa): From f360ded9fa578640e38239d3baccf8574991bf10 Mon Sep 17 00:00:00 2001 From: Joe Date: Mon, 7 Apr 2025 15:27:08 -0400 Subject: [PATCH 32/46] Mark tests with markers for fieldset_sampling Tests marked v4future are those that need a major rewrite to be consistent with the new field and fieldset API; I suspect the functionality they test is something we will want to maintain.
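For these marks to be usable without pytest warning about unknown markers, they need to be registered once; a minimal sketch, assuming registration lives in conftest.py (the one-line descriptions are mine):

    def pytest_configure(config):
        # Register the v4 triage markers so `pytest -m <marker>` selects cleanly.
        config.addinivalue_line("markers", "v4alpha: expected to pass for the v4 alpha")
        config.addinivalue_line("markers", "v4future: needs a major rewrite against the new field/fieldset API")
        config.addinivalue_line("markers", "v4remove: slated for removal in v4")

With that in place, `pytest -m v4future` runs only the rewrite-pending tests, and `pytest -m "not v4future"` excludes them.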
--- tests/test_fieldset_sampling.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/test_fieldset_sampling.py b/tests/test_fieldset_sampling.py index c8033a658d..d0e0c9e60a 100644 --- a/tests/test_fieldset_sampling.py +++ b/tests/test_fieldset_sampling.py @@ -85,6 +85,7 @@ def fieldset_geometric_polar(): return create_fieldset_geometric_polar() +@pytest.mark.v4alpha def test_fieldset_sample(fieldset): """Sample the fieldset using indexing notation.""" xdim, ydim = 120, 80 @@ -98,6 +99,7 @@ def test_fieldset_sample(fieldset): assert np.allclose(u_s, lat, rtol=1e-5) +@pytest.mark.v4alpha def test_fieldset_sample_eval(fieldset): """Sample the fieldset using the explicit eval function.""" xdim, ydim = 60, 60 @@ -120,6 +122,7 @@ def test_fieldset_polar_with_halo(fieldset_geometric_polar): assert pset.lon[0] == 0.0 +@pytest.mark.v4alpha @pytest.mark.parametrize("zdir", [-1, 1]) def test_verticalsampling(zdir): dims = (4, 2, 2) @@ -171,6 +174,7 @@ def test_pset_from_field(): assert np.allclose(pdens / sum(pdens.flatten()), startfield / sum(startfield.flatten()), atol=1e-2) +@pytest.mark.v4alpha def test_nearest_neighbor_interpolation2D(): npart = 81 dims = (2, 2) @@ -193,6 +197,7 @@ def test_nearest_neighbor_interpolation2D(): assert np.allclose(pset.p[(pset.lon > 0.5) | (pset.lat < 0.5)], 0.0, rtol=1e-5) +@pytest.mark.v4alpha def test_nearest_neighbor_interpolation3D(): npart = 81 dims = (2, 2, 2) @@ -219,6 +224,7 @@ def test_nearest_neighbor_interpolation3D(): assert np.allclose(pset.p[(pset.lon > 0.5) | (pset.lat < 0.5) & (pset.depth < 0.5)], 0.0, rtol=1e-5) +@pytest.mark.v4future @pytest.mark.parametrize("withDepth", [True, False]) @pytest.mark.parametrize("arrtype", ["ones", "rand"]) def test_inversedistance_nearland(withDepth, arrtype): @@ -260,6 +266,7 @@ def test_inversedistance_nearland(withDepth, arrtype): assert success +@pytest.mark.v4future @pytest.mark.parametrize("boundaryslip", ["freeslip", "partialslip"]) @pytest.mark.parametrize("withW", [False, True]) @pytest.mark.parametrize("withT", [False, True]) @@ -311,6 +318,7 @@ def test_partialslip_nearland_zonal(boundaryslip, withW, withT): assert np.allclose([p.depth for p in pset], 0.1) +@pytest.mark.v4future @pytest.mark.parametrize("boundaryslip", ["freeslip", "partialslip"]) @pytest.mark.parametrize("withW", [False, True]) def test_partialslip_nearland_meridional(boundaryslip, withW): @@ -353,6 +361,7 @@ def test_partialslip_nearland_meridional(boundaryslip, withW): assert np.allclose([p.depth for p in pset], 0.1) +@pytest.mark.v4future @pytest.mark.parametrize("boundaryslip", ["freeslip", "partialslip"]) def test_partialslip_nearland_vertical(boundaryslip): npart = 20 @@ -381,6 +390,7 @@ def test_partialslip_nearland_vertical(boundaryslip): assert np.allclose([p.lat for p in pset], 0.1) +@pytest.mark.v4alpha def test_fieldset_sample_particle(): """Sample the fieldset using an array of particles.""" npart = 120 @@ -403,6 +413,7 @@ def test_fieldset_sample_particle(): assert np.allclose(pset.u, lat, rtol=1e-6) +@pytest.mark.v4alpha def test_fieldset_sample_geographic(fieldset_geometric): """Sample a fieldset with conversion to geographic units (degrees).""" npart = 120 @@ -419,6 +430,7 @@ def test_fieldset_sample_geographic(fieldset_geometric): assert np.allclose(pset.u, lat, rtol=1e-6) +@pytest.mark.v4alpha def test_fieldset_sample_geographic_noconvert(fieldset_geometric): """Sample a fieldset without conversion to geographic units.""" npart = 120 @@ -435,6 +447,7 @@ def 
test_fieldset_sample_geographic_noconvert(fieldset_geometric): assert np.allclose(pset.u, lat * 1000 * 1.852 * 60, rtol=1e-6) +@pytest.mark.v4alpha def test_fieldset_sample_geographic_polar(fieldset_geometric_polar): """Sample a fieldset with conversion to geographic units and a pole correction.""" npart = 120 @@ -451,6 +464,7 @@ def test_fieldset_sample_geographic_polar(fieldset_geometric_polar): assert np.allclose(pset.u, lat, rtol=1e-2) +@pytest.mark.v4alpha def test_meridionalflow_spherical(): """Create uniform NORTHWARD flow on spherical earth and advect particles. @@ -480,6 +494,7 @@ def test_meridionalflow_spherical(): assert pset.lon[1] - lonstart[1] < 1e-4 +@pytest.mark.v4alpha def test_zonalflow_spherical(): """Create uniform EASTWARD flow on spherical earth and advect particles. @@ -516,6 +531,7 @@ def test_zonalflow_spherical(): assert abs(pset.p[1] - p_fld) < 1e-4 +@pytest.mark.v4alpha def test_random_field(): """Sampling test that tests for overshoots by sampling a field of random numbers between 0 and 1.""" xdim, ydim = 20, 20 @@ -540,6 +556,7 @@ def test_random_field(): assert (sampled >= 0.0).all() +@pytest.mark.v4alpha @pytest.mark.parametrize("allow_time_extrapolation", [True, False]) def test_sampling_out_of_bounds_time(allow_time_extrapolation): xdim, ydim, tdim = 10, 10, 10 @@ -658,6 +675,7 @@ def test_sample(particle, fieldset, time): # pragma: no cover assert np.allclose(pset.sample_var, 10.0) +@pytest.mark.v4future @pytest.mark.parametrize("npart", [1, 10]) def test_sampling_multigrids_non_vectorfield(npart): xdim, ydim = 100, 200 @@ -697,6 +715,7 @@ def test_sample(particle, fieldset, time): # pragma: no cover assert np.allclose(pset.sample_var, 10.0) +@pytest.mark.v4future @pytest.mark.parametrize("ugridfactor", [1, 10]) def test_sampling_multiple_grid_sizes(ugridfactor): xdim, ydim = 10, 20 @@ -724,6 +743,7 @@ def test_sampling_multiple_grid_sizes(ugridfactor): assert np.all((0 <= pset.xi) & (pset.xi < xdim * ugridfactor)) +@pytest.mark.v4future def test_multiple_grid_addlater_error(): xdim, ydim = 10, 20 U = Field( From 29fb96418b2bf73cdc59b0f9faf9a386a1254d1c Mon Sep 17 00:00:00 2001 From: Joe Date: Wed, 9 Apr 2025 11:01:18 -0400 Subject: [PATCH 33/46] Add FESOM periodic channel data --- parcels/tools/exampledata_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/parcels/tools/exampledata_utils.py b/parcels/tools/exampledata_utils.py index ce2d0a41b5..66713a2238 100644 --- a/parcels/tools/exampledata_utils.py +++ b/parcels/tools/exampledata_utils.py @@ -32,6 +32,12 @@ "decaying_moving_eddyU.nc", "decaying_moving_eddyV.nc", ], + "FESOM_periodic_channel": [ + "fesom_channel.nc", + "u.fesom_channel.nc", + "v.fesom_channel.nc", + "w.fesom_channel.nc", + ], "NemoCurvilinear_data": [ "U_purely_zonal-ORCA025_grid_U.nc4", "V_purely_zonal-ORCA025_grid_V.nc4", From 58a138efc651ea155c981780a25ac9880cec19f9 Mon Sep 17 00:00:00 2001 From: Joe Date: Wed, 9 Apr 2025 11:02:44 -0400 Subject: [PATCH 34/46] Add particleset execute test for fesom channel --- tests/v4/test_uxarray_fieldset.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/tests/v4/test_uxarray_fieldset.py b/tests/v4/test_uxarray_fieldset.py index 27016d0ebe..cf5faca149 100644 --- a/tests/v4/test_uxarray_fieldset.py +++ b/tests/v4/test_uxarray_fieldset.py @@ -1,4 +1,5 @@ import os +from datetime import timedelta import uxarray as ux @@ -62,8 +63,19 @@ def test_set_interp_methods(): fieldset.W.interp_method = UXPiecewiseLinearNode -# pset = 
ParticleSet(fieldset, pclass=Particle) -# pset.execute(associate_interp_function, endtime=timedelta(days=1), dt=timedelta(hours=1)) - -if __name__ == "__main__": - test_set_interp_methods() +def test_fesom_channel(): + grid_path = f"{V4_TEST_DATA}/fesom_channel.nc" + data_path = [ + f"{V4_TEST_DATA}/u.fesom_channel.nc", + f"{V4_TEST_DATA}/v.fesom_channel.nc", + f"{V4_TEST_DATA}/w.fesom_channel.nc", + ] + ds = ux.open_mfdataset(grid_path, data_path) + ds = ds.rename_vars({"u": "U", "v": "V", "w": "W"}) + fieldset = FieldSet([ds]) + # Set the interpolation method for each field + fieldset.U.interp_method = UXPiecewiseConstantFace + fieldset.V.interp_method = UXPiecewiseConstantFace + fieldset.W.interp_method = UXPiecewiseLinearNode + pset = ParticleSet(fieldset, pclass=Particle) + pset.execute(endtime=timedelta(days=1), dt=timedelta(hours=1)) From 62cd3249c09e83386a4711e31747ffce89f74ad0 Mon Sep 17 00:00:00 2001 From: Joe Date: Wed, 9 Apr 2025 11:08:30 -0400 Subject: [PATCH 35/46] Add grid attribute to field. I have currently allowed for the grid attribute to be None, since parcels.Grid is not implemented on this branch. Once parcels.Grid is added, we can add this as a valid type for the grid attribute. --- parcels/field.py | 57 +++++++++++++++++++++------------------------ parcels/fieldset.py | 5 +++- 2 files changed, 30 insertions(+), 32 deletions(-) diff --git a/parcels/field.py b/parcels/field.py index 01fdd60615..cf677bdd0a 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -143,24 +143,21 @@ def __init__( self, name: str, data: xr.DataArray | ux.UxDataArray, + grid: ux.UxGrid | None = None, # To do : Once parcels.Grid class is added, allow for it to be passed here mesh_type: Mesh = "flat", interp_method: Callable | None = None, allow_time_extrapolation: bool | None = None, ): self.name = name self.data = data + self.grid = grid self._validate_dataarray() self._parent_mesh = data.attrs["mesh"] self._mesh_type = mesh_type self._location = data.attrs["location"] - - # Set the vertical location - if "nz1" in data.dims: - self._vertical_location = "center" - elif "nz" in data.dims: - self._vertical_location = "face" + self._vertical_location = None # Setting the interpolation method dynamically if interp_method is None: @@ -184,9 +181,14 @@ def __init__( self.allow_time_extrapolation = allow_time_extrapolation if type(self.data) is ux.UxDataArray: - self._spatialhash = self.data.uxgrid.get_spatial_hash() + self._spatialhash = self.grid.get_spatial_hash() self._gtype = None - else: + # Set the vertical location + if "nz1" in data.dims: + self._vertical_location = "center" + elif "nz" in data.dims: + self._vertical_location = "face" + else: # To do : This bit probably needs an overhaul once the parcels.Grid class is integrated. 
self._spatialhash = None # Set the grid type if "x_g" in self.data.coords: @@ -223,13 +225,6 @@ def __init__( def __repr__(self): return field_repr(self) - @property - def grid(self): - if type(self.data) is ux.UxDataArray: - return self.data.uxgrid - else: - return self.data # To do : need to decide on what to return for xarray.DataArray objects - @property def lonlat_minmax(self): return self._lonlat_minmax @@ -238,11 +233,11 @@ def lonlat_minmax(self): def lat(self): if type(self.data) is ux.UxDataArray: if self._location == "node": - return self.data.uxgrid.node_lat + return self.grid.node_lat elif self._location == "face": - return self.data.uxgrid.face_lat + return self.grid.face_lat elif self._location == "edge": - return self.data.uxgrid.edge_lat + return self.grid.edge_lat else: return self.data.lat @@ -250,11 +245,11 @@ def lat(self): def lon(self): if type(self.data) is ux.UxDataArray: if self._location == "node": - return self.data.uxgrid.node_lon + return self.grid.node_lon elif self._location == "face": - return self.data.uxgrid.face_lon + return self.grid.face_lon elif self._location == "edge": - return self.data.uxgrid.edge_lon + return self.grid.edge_lon else: return self.data.lon @@ -262,9 +257,9 @@ def lon(self): def depth(self): if type(self.data) is ux.UxDataArray: if self._vertical_location == "center": - return self.data.uxgrid.nz1 + return self.grid.nz1 elif self._vertical_location == "face": - return self.data.uxgrid.nz + return self.grid.nz else: return self.data.depth @@ -304,7 +299,7 @@ def zdim(self): @property def n_face(self): if type(self.data) is ux.uxDataArray: - return self.data.uxgrid.n_face + return self.grid.n_face else: return 0 # To do : Discuss what we want to return for dataarray obj @@ -324,12 +319,12 @@ def interp_method(self, method: Callable): def _get_ux_barycentric_coordinates(self, y, x, fi): """Checks if a point is inside a given face id. Used for unstructured grids.""" # Check if particle is in the same face, otherwise search again. - n_nodes = self.data.uxgrid.n_nodes_per_face[fi].to_numpy() - node_ids = self.data.uxgrid.face_node_connectivity[fi, 0:n_nodes] + n_nodes = self.grid.n_nodes_per_face[fi].to_numpy() + node_ids = self.grid.face_node_connectivity[fi, 0:n_nodes] nodes = np.column_stack( ( - np.deg2rad(self.data.uxgrid.node_lon[node_ids].to_numpy()), - np.deg2rad(self.data.uxgrid.node_lat[node_ids].to_numpy()), + np.deg2rad(self.grid.node_lon[node_ids].to_numpy()), + np.deg2rad(self.grid.node_lat[node_ids].to_numpy()), ) ) @@ -359,7 +354,7 @@ def _search_indices_unstructured(self, z, y, x, ei=None, search2D=False): return bcoords, ei else: # In this case we need to search the neighbors - for neighbor in self.data.uxgrid.face_face_connectivity[fi, :]: + for neighbor in self.grid.face_face_connectivity[fi, :]: bcoords, err = self._get_ux_barycentric_coordinates(y, x, neighbor) if ((bcoords >= 0).all()) and ((bcoords <= 1.0).all()) and err < tol: # To do: Do the vertical grid search @@ -537,12 +532,12 @@ def _validate_dataarray(self): def _validate_uxgrid(self): """Verifies that all the required attributes are present in the uxarray.UxDataArray.UxGrid object.""" - if "Conventions" not in self.data.uxgrid.attrs.keys(): + if "Conventions" not in self.grid.attrs.keys(): raise ValueError( f"Field {self.name} is missing a 'Conventions' attribute in the field's metadata. " "This attribute is required for uxarray.UxDataArray objects." 
) - if self.data.uxgrid.attrs["Conventions"] != "UGRID-1.0": + if self.grid.attrs["Conventions"] != "UGRID-1.0": raise ValueError( f"Field {self.name} has a 'Conventions' attribute that is not 'UGRID-1.0'. " "This attribute is required for uxarray.UxDataArray objects." diff --git a/parcels/fieldset.py b/parcels/fieldset.py index 0353af9283..64332c4004 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -52,7 +52,10 @@ def __init__(self, datasets: list[xr.Dataset | ux.UxDataset]): # Create pointers to each (Ux)DataArray for ds in datasets: for field in ds.data_vars: - self.add_field(Field(field, ds[field]), field) + if type(ds[field]) is ux.UxDataArray: + self.add_field(Field(field, ds[field], grid=ds[field].uxgrid), field) + else: + self.add_field(Field(field, ds[field]), field) self._gridset_size += 1 self._fieldnames.append(field) From 884a6202d38edd6093e38025d9353313ee833660 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 9 Apr 2025 15:19:14 +0000 Subject: [PATCH 36/46] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/examples/tutorial_stommel_uxarray.ipynb | 154 +++++++++---------- parcels/field.py | 2 +- tests/test_fieldset.py | 1 - 3 files changed, 71 insertions(+), 86 deletions(-) diff --git a/docs/examples/tutorial_stommel_uxarray.ipynb b/docs/examples/tutorial_stommel_uxarray.ipynb index cc7187e985..d5bc4317cf 100644 --- a/docs/examples/tutorial_stommel_uxarray.ipynb +++ b/docs/examples/tutorial_stommel_uxarray.ipynb @@ -23,10 +23,11 @@ " Ph.D. dissertation, University of Bologna\n", " http://amsdottorato.unibo.it/1733/1/Fabbroni_Nicoletta_Tesi.pdf\n", " \"\"\"\n", - " import uxarray as ux\n", - " import numpy as np\n", " import math\n", + "\n", + " import numpy as np\n", " import pandas as pd\n", + " import uxarray as ux\n", "\n", " a = b = 66666 * 1e3\n", " scalefac = 0.00025 # to scale for physically meaningful velocities\n", @@ -35,20 +36,17 @@ " # Crowd points to the west edge of the domain\n", " # using a polyonmial map on x-direction\n", " x = np.linspace(0, 1, xdim, dtype=np.float32)\n", - " lon, lat = np.meshgrid(\n", - " a*x,\n", - " np.linspace(0, b, ydim, dtype=np.float32)\n", - " )\n", - " points = (lon.flatten()/1111111.111111111,lat.flatten()/1111111.111111111)\n", - " \n", + " lon, lat = np.meshgrid(a * x, np.linspace(0, b, ydim, dtype=np.float32))\n", + " points = (lon.flatten() / 1111111.111111111, lat.flatten() / 1111111.111111111)\n", + "\n", " # Create the grid\n", " uxgrid = ux.Grid.from_points(points, method=\"regional_delaunay\")\n", " uxgrid.construct_face_centers()\n", "\n", " # Define arrays U (zonal), V (meridional) and P (sea surface height)\n", - " U = np.zeros((1,1,lat.size), dtype=np.float32)\n", - " V = np.zeros((1,1,lat.size), dtype=np.float32)\n", - " P = np.zeros((1,1,lat.size), dtype=np.float32)\n", + " U = np.zeros((1, 1, lat.size), dtype=np.float32)\n", + " V = np.zeros((1, 1, lat.size), dtype=np.float32)\n", + " P = np.zeros((1, 1, lat.size), dtype=np.float32)\n", "\n", " beta = 2e-11\n", " r = 1 / (11.6 * 86400)\n", @@ -58,33 +56,27 @@ " for x, y in zip(lon.flatten(), lat.flatten()):\n", " xi = x / a\n", " yi = y / b\n", - " P[0,0,i] = (\n", - " (1 - math.exp(-xi / es) - xi)\n", - " * math.pi\n", - " * np.sin(math.pi * yi)\n", - " * scalefac\n", + " P[0, 0, i] = (\n", + " (1 - math.exp(-xi / es) - xi) * math.pi * np.sin(math.pi * yi) * scalefac\n", " )\n", - " U[0,0,i] = (\n", + " U[0, 0, i] = (\n", " 
-(1 - math.exp(-xi / es) - xi)\n", " * math.pi**2\n", " * np.cos(math.pi * yi)\n", " * scalefac\n", " )\n", - " V[0,0,i] = (\n", - " (math.exp(-xi / es) / es - 1)\n", - " * math.pi\n", - " * np.sin(math.pi * yi)\n", - " * scalefac\n", + " V[0, 0, i] = (\n", + " (math.exp(-xi / es) / es - 1) * math.pi * np.sin(math.pi * yi) * scalefac\n", " )\n", - " i+=1\n", + " i += 1\n", "\n", " u = ux.UxDataArray(\n", " data=U,\n", - " name='u',\n", + " name=\"u\",\n", " uxgrid=uxgrid,\n", - " dims=[\"time\",\"nz1\",\"n_node\"],\n", - " coords = dict(\n", - " time=([\"time\"], pd.to_datetime(['2000-01-01'])),\n", + " dims=[\"time\", \"nz1\", \"n_node\"],\n", + " coords=dict(\n", + " time=([\"time\"], pd.to_datetime([\"2000-01-01\"])),\n", " nz1=([\"nz1\"], [0]),\n", " ),\n", " attrs=dict(\n", @@ -96,11 +88,11 @@ " )\n", " v = ux.UxDataArray(\n", " data=V,\n", - " name='v',\n", + " name=\"v\",\n", " uxgrid=uxgrid,\n", - " dims=[\"time\",\"nz1\",\"n_node\"],\n", - " coords = dict(\n", - " time=([\"time\"], pd.to_datetime(['2000-01-01'])),\n", + " dims=[\"time\", \"nz1\", \"n_node\"],\n", + " coords=dict(\n", + " time=([\"time\"], pd.to_datetime([\"2000-01-01\"])),\n", " nz1=([\"nz1\"], [0]),\n", " ),\n", " attrs=dict(\n", @@ -112,11 +104,11 @@ " )\n", " p = ux.UxDataArray(\n", " data=P,\n", - " name='p',\n", + " name=\"p\",\n", " uxgrid=uxgrid,\n", - " dims=[\"time\",\"nz1\",\"n_node\"],\n", - " coords = dict(\n", - " time=([\"time\"], pd.to_datetime(['2000-01-01'])),\n", + " dims=[\"time\", \"nz1\", \"n_node\"],\n", + " coords=dict(\n", + " time=([\"time\"], pd.to_datetime([\"2000-01-01\"])),\n", " nz1=([\"nz1\"], [0]),\n", " ),\n", " attrs=dict(\n", @@ -127,21 +119,17 @@ " ),\n", " )\n", "\n", + " return ux.UxDataset({\"u\": u, \"v\": v, \"p\": p}, uxgrid=uxgrid)\n", "\n", - " return ux.UxDataset(\n", - " {'u':u, 'v':v, 'p': p}, \n", - " uxgrid=uxgrid\n", - " )\n", "\n", - "uxds = stommel_fieldset_uxarray(50,50)\n", + "uxds = stommel_fieldset_uxarray(50, 50)\n", "\n", "uxds.uxgrid.plot(\n", " line_width=0.5,\n", " height=500,\n", " width=1000,\n", " title=\"Regional Delaunay Regions\",\n", - ")\n", - "\n" + ")" ] }, { @@ -159,10 +147,11 @@ " Ph.D. 
dissertation, University of Bologna\n", " http://amsdottorato.unibo.it/1733/1/Fabbroni_Nicoletta_Tesi.pdf\n", " \"\"\"\n", - " import xarray as xr\n", - " import numpy as np\n", " import math\n", + "\n", + " import numpy as np\n", " import pandas as pd\n", + " import xarray as xr\n", "\n", " a = b = 10000 * 1e3\n", " scalefac = 0.05 # to scale for physically meaningful velocities\n", @@ -173,9 +162,9 @@ " lat = np.linspace(0, b, ydim, dtype=np.float32)\n", "\n", " # Define arrays U (zonal), V (meridional) and P (sea surface height)\n", - " U = np.zeros((1,1,lat.size, lon.size), dtype=np.float32)\n", - " V = np.zeros((1,1,lat.size, lon.size), dtype=np.float32)\n", - " P = np.zeros((1,1,lat.size, lon.size), dtype=np.float32)\n", + " U = np.zeros((1, 1, lat.size, lon.size), dtype=np.float32)\n", + " V = np.zeros((1, 1, lat.size, lon.size), dtype=np.float32)\n", + " P = np.zeros((1, 1, lat.size, lon.size), dtype=np.float32)\n", "\n", " beta = 2e-11\n", " r = 1 / (11.6 * 86400)\n", @@ -185,48 +174,48 @@ " for i in range(lon.size):\n", " xi = lon[i] / a\n", " yi = lat[j] / b\n", - " P[...,j, i] = (\n", + " P[..., j, i] = (\n", " (1 - math.exp(-xi / es) - xi)\n", " * math.pi\n", " * np.sin(math.pi * yi)\n", " * scalefac\n", " )\n", " if grid_type == \"A\":\n", - " U[...,j, i] = (\n", + " U[..., j, i] = (\n", " -(1 - math.exp(-xi / es) - xi)\n", " * math.pi**2\n", " * np.cos(math.pi * yi)\n", " * scalefac\n", " )\n", - " V[...,j, i] = (\n", + " V[..., j, i] = (\n", " (math.exp(-xi / es) / es - 1)\n", " * math.pi\n", " * np.sin(math.pi * yi)\n", " * scalefac\n", " )\n", "\n", - " time = pd.to_datetime(['2000-01-01'])\n", + " time = pd.to_datetime([\"2000-01-01\"])\n", " z = [0]\n", " if grid_type == \"C\":\n", - " V[...,:, 1:] = (P[...,:, 1:] - P[...,:, 0:-1]) / dx * a\n", - " U[...,1:, :] = -(P[...,1:, :] - P[...,0:-1, :]) / dy * b\n", - " u_dims = [\"time\",\"nz1\",\"face_lat\", \"node_lon\"]\n", + " V[..., :, 1:] = (P[..., :, 1:] - P[..., :, 0:-1]) / dx * a\n", + " U[..., 1:, :] = -(P[..., 1:, :] - P[..., 0:-1, :]) / dy * b\n", + " u_dims = [\"time\", \"nz1\", \"face_lat\", \"node_lon\"]\n", " u_lat = lat\n", " u_lon = lon - dx * 0.5\n", " u_location = \"x_edge\"\n", - " v_dims = [\"time\",\"nz1\",\"node_lat\", \"face_lon\"]\n", + " v_dims = [\"time\", \"nz1\", \"node_lat\", \"face_lon\"]\n", " v_lat = lat - dy * 0.5\n", " v_lon = lon\n", " v_location = \"y_edge\"\n", - " p_dims = [\"time\",\"nz1\",\"face_lat\", \"face_lon\"]\n", + " p_dims = [\"time\", \"nz1\", \"face_lat\", \"face_lon\"]\n", " p_lat = lat\n", " p_lon = lon\n", " p_location = \"face\"\n", - " \n", + "\n", " else:\n", - " u_dims = [\"time\",\"nz1\",\"node_lat\", \"node_lon\"]\n", - " v_dims = [\"time\",\"nz1\",\"node_lat\", \"node_lon\"]\n", - " p_dims = [\"time\",\"nz1\",\"node_lat\", \"node_lon\"]\n", + " u_dims = [\"time\", \"nz1\", \"node_lat\", \"node_lon\"]\n", + " v_dims = [\"time\", \"nz1\", \"node_lat\", \"node_lon\"]\n", + " p_dims = [\"time\", \"nz1\", \"node_lat\", \"node_lon\"]\n", " u_lat = lat\n", " u_lon = lon\n", " v_lat = lat\n", @@ -239,9 +228,9 @@ "\n", " u = xr.DataArray(\n", " data=U,\n", - " name='u',\n", + " name=\"u\",\n", " dims=u_dims,\n", - " coords = [time,z,u_lat,u_lon],\n", + " coords=[time, z, u_lat, u_lon],\n", " attrs=dict(\n", " description=\"zonal velocity\",\n", " units=\"m/s\",\n", @@ -251,9 +240,9 @@ " )\n", " v = xr.DataArray(\n", " data=V,\n", - " name='v',\n", + " name=\"v\",\n", " dims=v_dims,\n", - " coords = [time,z,v_lat,v_lon],\n", + " coords=[time, z, v_lat, v_lon],\n", " 
attrs=dict(\n", " description=\"meridional velocity\",\n", " units=\"m/s\",\n", @@ -263,9 +252,9 @@ " )\n", " p = xr.DataArray(\n", " data=P,\n", - " name='p',\n", + " name=\"p\",\n", " dims=p_dims,\n", - " coords = [time,z,p_lat,p_lon],\n", + " coords=[time, z, p_lat, p_lon],\n", " attrs=dict(\n", " description=\"pressure\",\n", " units=\"N/m^2\",\n", @@ -274,12 +263,11 @@ " ),\n", " )\n", "\n", - " return xr.Dataset(\n", - " {'u':u, 'v':v, 'p': p}\n", - " )\n", + " return xr.Dataset({\"u\": u, \"v\": v, \"p\": p})\n", + "\n", "\n", - "ds_arakawa_a = stommel_fieldset_xarray(50,50,\"A\")\n", - "ds_arakawa_c = stommel_fieldset_xarray(50,50,\"C\")\n" + "ds_arakawa_a = stommel_fieldset_xarray(50, 50, \"A\")\n", + "ds_arakawa_c = stommel_fieldset_xarray(50, 50, \"C\")" ] }, { @@ -316,10 +304,11 @@ "outputs": [], "source": [ "import numpy as np\n", - "min_length_scale = 1111111.111111111*np.sqrt(np.min(uxds.uxgrid.face_areas))\n", + "\n", + "min_length_scale = 1111111.111111111 * np.sqrt(np.min(uxds.uxgrid.face_areas))\n", "print(min_length_scale)\n", "\n", - "max_v = np.sqrt(uxds['u']**2 + uxds['v']**2).max()\n", + "max_v = np.sqrt(uxds[\"u\"] ** 2 + uxds[\"v\"] ** 2).max()\n", "print(max_v)\n", "\n", "cfl = 0.1\n", @@ -333,24 +322,21 @@ "metadata": {}, "outputs": [], "source": [ - "import uxarray as ux\n", "from datetime import timedelta\n", - "from parcels import (\n", - " UXFieldSet,\n", - " ParticleSet,\n", - " Particle,\n", - " UxAdvectionEuler\n", - ")\n", + "\n", "import numpy as np\n", + "import uxarray as ux\n", + "\n", + "from parcels import Particle, ParticleSet, UxAdvectionEuler, UXFieldSet\n", "\n", "npart = 10\n", "fieldset = UXFieldSet(uxds)\n", "# pset = ParticleSet(\n", - "# fieldset, \n", - "# pclass=Particle, \n", - "# lon=np.linspace(1, 59, npart), \n", + "# fieldset,\n", + "# pclass=Particle,\n", + "# lon=np.linspace(1, 59, npart),\n", "# lat=np.zeros(npart)+30)\n", - "# pset.execute(UxAdvectionEuler, runtime=timedelta(hours=24), dt=timedelta(seconds=dt))\n" + "# pset.execute(UxAdvectionEuler, runtime=timedelta(hours=24), dt=timedelta(seconds=dt))" ] } ], diff --git a/parcels/field.py b/parcels/field.py index 26252208b2..8416b7afc4 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -228,7 +228,7 @@ def __repr__(self): @property def lonlat_minmax(self): return self._lonlat_minmax - + @property def units(self): return self._units diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index c479f1cbd6..c9cdc02c37 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -190,7 +190,6 @@ def test_field_from_netcdf_fieldtypes(): @pytest.mark.v4alpha - def test_fieldset_from_agrid_dataset(): filenames = { "lon": str(TEST_DATA / "mask_nemo_cross_180lon.nc"), From 9c8fbdad302f8121b7646fc1ad9cd8d1d2f3e10c Mon Sep 17 00:00:00 2001 From: Joe Date: Wed, 9 Apr 2025 11:22:29 -0400 Subject: [PATCH 37/46] Formatting fixes --- parcels/field.py | 2 +- tests/test_fieldset.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/parcels/field.py b/parcels/field.py index 26252208b2..8416b7afc4 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -228,7 +228,7 @@ def __repr__(self): @property def lonlat_minmax(self): return self._lonlat_minmax - + @property def units(self): return self._units diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index c479f1cbd6..9824695297 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -13,6 +13,7 @@ Variable, ) from parcels.field import Field, VectorField +from parcels.tools.converters 
import GeographicPolar, UnitConverter from tests.utils import TEST_DATA @@ -190,7 +191,6 @@ def test_field_from_netcdf_fieldtypes(): @pytest.mark.v4alpha - def test_fieldset_from_agrid_dataset(): filenames = { "lon": str(TEST_DATA / "mask_nemo_cross_180lon.nc"), @@ -269,9 +269,9 @@ def test_add_field_after_pset(fieldtype): field2 = Field("field2", fieldset.U.data, lon=fieldset.U.lon, lat=fieldset.U.lat) vfield = VectorField("vfield", field1, field2) with pytest.raises(RuntimeError): - if field_type == "normal": + if fieldtype == "normal": fieldset.add_field(field1) - elif field_type == "vector": + elif fieldtype == "vector": fieldset.add_vector_field(vfield) From 4bf5096b64674c3249dee5b26b5e7b1570555d06 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 9 Apr 2025 18:16:55 +0200 Subject: [PATCH 38/46] Update test to use fesom download Also move the data to a pytest fixture :) --- parcels/field.py | 2 +- tests/v4/test_uxarray_fieldset.py | 68 +++++++++++-------------------- 2 files changed, 24 insertions(+), 46 deletions(-) diff --git a/parcels/field.py b/parcels/field.py index 8416b7afc4..1703653456 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -143,7 +143,7 @@ def __init__( self, name: str, data: xr.DataArray | ux.UxDataArray, - grid: ux.UxGrid | None = None, # To do : Once parcels.Grid class is added, allow for it to be passed here + grid: ux.Grid | None = None, # To do : Once parcels.Grid class is added, allow for it to be passed here mesh_type: Mesh = "flat", interp_method: Callable | None = None, allow_time_extrapolation: bool | None = None, diff --git a/tests/v4/test_uxarray_fieldset.py b/tests/v4/test_uxarray_fieldset.py index cf5faca149..0e168fd098 100644 --- a/tests/v4/test_uxarray_fieldset.py +++ b/tests/v4/test_uxarray_fieldset.py @@ -1,6 +1,6 @@ -import os from datetime import timedelta +import pytest import uxarray as ux from parcels import ( @@ -9,70 +9,48 @@ ParticleSet, UXPiecewiseConstantFace, UXPiecewiseLinearNode, + download_example_dataset, ) -# Get path of this script -V4_TEST_DATA = f"{os.path.dirname(__file__)}/test_data" - -def test_fesom_fieldset(): - # Load a FESOM dataset - grid_path = f"{V4_TEST_DATA}/fesom_channel.nc" +@pytest.fixture +def ds_fesom_channel() -> ux.UxDataset: + fesom_path = download_example_dataset("FESOM_periodic_channel") + grid_path = f"{fesom_path}/fesom_channel.nc" data_path = [ - f"{V4_TEST_DATA}/u.fesom_channel.nc", - f"{V4_TEST_DATA}/v.fesom_channel.nc", - f"{V4_TEST_DATA}/w.fesom_channel.nc", + f"{fesom_path}/u.fesom_channel.nc", + f"{fesom_path}/v.fesom_channel.nc", + f"{fesom_path}/w.fesom_channel.nc", ] - ds = ux.open_mfdataset(grid_path, data_path) - ds = ds.rename_vars({"u": "U", "v": "V", "w": "W"}) - fieldset = FieldSet([ds]) + ds = ux.open_mfdataset(grid_path, data_path).rename_vars({"u": "U", "v": "V", "w": "W"}) + return ds + + +def test_fesom_fieldset(ds_fesom_channel): + fieldset = FieldSet([ds_fesom_channel]) fieldset._check_complete() # Check that the fieldset has the expected properties - assert fieldset.datasets[0] == ds + assert fieldset.datasets[0] == ds_fesom_channel -def test_fesom_in_particleset(): - grid_path = f"{V4_TEST_DATA}/fesom_channel.nc" - data_path = [ - f"{V4_TEST_DATA}/u.fesom_channel.nc", - f"{V4_TEST_DATA}/v.fesom_channel.nc", - f"{V4_TEST_DATA}/w.fesom_channel.nc", - ] - ds = ux.open_mfdataset(grid_path, data_path) - ds = ds.rename_vars({"u": "U", "v": "V", "w": "W"}) - fieldset = FieldSet([ds]) +def test_fesom_in_particleset(ds_fesom_channel): 
+ fieldset = FieldSet([ds_fesom_channel]) # Check that the fieldset has the expected properties - assert fieldset.datasets[0] == ds + assert fieldset.datasets[0] == ds_fesom_channel pset = ParticleSet(fieldset, pclass=Particle) assert pset.fieldset == fieldset -def test_set_interp_methods(): - grid_path = f"{V4_TEST_DATA}/fesom_channel.nc" - data_path = [ - f"{V4_TEST_DATA}/u.fesom_channel.nc", - f"{V4_TEST_DATA}/v.fesom_channel.nc", - f"{V4_TEST_DATA}/w.fesom_channel.nc", - ] - ds = ux.open_mfdataset(grid_path, data_path) - ds = ds.rename_vars({"u": "U", "v": "V", "w": "W"}) - fieldset = FieldSet([ds]) +def test_set_interp_methods(ds_fesom_channel): + fieldset = FieldSet([ds_fesom_channel]) # Set the interpolation method for each field fieldset.U.interp_method = UXPiecewiseConstantFace fieldset.V.interp_method = UXPiecewiseConstantFace fieldset.W.interp_method = UXPiecewiseLinearNode -def test_fesom_channel(): - grid_path = f"{V4_TEST_DATA}/fesom_channel.nc" - data_path = [ - f"{V4_TEST_DATA}/u.fesom_channel.nc", - f"{V4_TEST_DATA}/v.fesom_channel.nc", - f"{V4_TEST_DATA}/w.fesom_channel.nc", - ] - ds = ux.open_mfdataset(grid_path, data_path) - ds = ds.rename_vars({"u": "U", "v": "V", "w": "W"}) - fieldset = FieldSet([ds]) +def test_fesom_channel(ds_fesom_channel): + fieldset = FieldSet([ds_fesom_channel]) # Set the interpolation method for each field fieldset.U.interp_method = UXPiecewiseConstantFace fieldset.V.interp_method = UXPiecewiseConstantFace From 47360381585d635b1db0f543a551c9ea3dee8165 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Thu, 10 Apr 2025 16:29:40 +0200 Subject: [PATCH 39/46] Add xfail markers --- tests/test_advection.py | 11 +++++++++++ tests/test_diffusion.py | 4 ++++ tests/test_field.py | 3 +++ tests/test_fieldset.py | 21 +++++++++++++++++++++ tests/test_fieldset_sampling.py | 19 +++++++++++++++++++ tests/test_interpolation.py | 4 ++++ 6 files changed, 62 insertions(+) diff --git a/tests/test_advection.py b/tests/test_advection.py index 92f812f524..f0f710637e 100644 --- a/tests/test_advection.py +++ b/tests/test_advection.py @@ -58,6 +58,7 @@ def depth(): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") def test_advection_zonal(lon, lat, depth): """Particles at high latitude move geographically faster due to the pole correction in `GeographicPolar`.""" npart = 10 @@ -90,6 +91,7 @@ def test_advection_zonal(lon, lat, depth): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") def test_advection_meridional(lon, lat): """Particles at high latitude move geographically faster due to the pole correction in `GeographicPolar`.""" npart = 10 @@ -104,6 +106,7 @@ def test_advection_meridional(lon, lat): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") def test_advection_3D(): """Flat 2D zonal flow that increases linearly with depth from 0 m/s to 1 m/s.""" xdim = ydim = zdim = 2 @@ -126,6 +129,7 @@ def test_advection_3D(): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") @pytest.mark.parametrize("direction", ["up", "down"]) @pytest.mark.parametrize("wErrorThroughSurface", [True, False]) def test_advection_3D_outofbounds(direction, wErrorThroughSurface): @@ -172,6 +176,7 @@ def SubmergeParticle(particle, fieldset, time): # pragma: no cover @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") @pytest.mark.parametrize("rk45_tol", [10, 100]) def test_advection_RK45(lon, lat, rk45_tol): npart = 10 @@ -330,6 +335,7 @@ def test_advection_periodic_zonal_meridional(): @pytest.mark.v4alpha 
+@pytest.mark.xfail(reason="GH1946") @pytest.mark.parametrize("u", [-0.3, np.array(0.2)]) @pytest.mark.parametrize("v", [0.2, np.array(1)]) @pytest.mark.parametrize("w", [None, -0.2, np.array(0.7)]) @@ -408,6 +414,7 @@ def fieldset_stationary(): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") @pytest.mark.parametrize( "method, rtol, diffField", [ @@ -443,6 +450,7 @@ def test_stationary_eddy(fieldset_stationary, method, rtol, diffField): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") def test_stationary_eddy_vertical(): npart = 1 lon = np.linspace(12000, 21000, npart) @@ -514,6 +522,7 @@ def fieldset_moving(): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") @pytest.mark.parametrize( "method, rtol, diffField", [ @@ -592,6 +601,7 @@ def fieldset_decaying(): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") @pytest.mark.parametrize( "method, rtol, diffField", [ @@ -633,6 +643,7 @@ def test_decaying_eddy(fieldset_decaying, method, rtol, diffField): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") def test_analyticalAgrid(): lon = np.arange(0, 15, dtype=np.float32) lat = np.arange(0, 15, dtype=np.float32) diff --git a/tests/test_diffusion.py b/tests/test_diffusion.py index 267597d6ea..1cfb71e018 100644 --- a/tests/test_diffusion.py +++ b/tests/test_diffusion.py @@ -18,6 +18,7 @@ @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") @pytest.mark.parametrize("mesh", ["spherical", "flat"]) def test_fieldKh_Brownian(mesh): xdim = 200 @@ -52,6 +53,7 @@ def test_fieldKh_Brownian(mesh): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") @pytest.mark.parametrize("mesh", ["spherical", "flat"]) @pytest.mark.parametrize("kernel", [AdvectionDiffusionM1, AdvectionDiffusionEM]) def test_fieldKh_SpatiallyVaryingDiffusion(mesh, kernel): @@ -86,6 +88,7 @@ def test_fieldKh_SpatiallyVaryingDiffusion(mesh, kernel): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") @pytest.mark.parametrize("lambd", [1, 5]) def test_randomexponential(lambd): fieldset = create_fieldset_zeros_conversion() @@ -113,6 +116,7 @@ def vertical_randomexponential(particle, fieldset, time): # pragma: no cover @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") @pytest.mark.parametrize("mu", [0.8 * np.pi, np.pi]) @pytest.mark.parametrize("kappa", [2, 4]) def test_randomvonmises(mu, kappa): diff --git a/tests/test_field.py b/tests/test_field.py index 989e204bb9..2a29cdf690 100644 --- a/tests/test_field.py +++ b/tests/test_field.py @@ -12,6 +12,7 @@ @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") def test_field_from_netcdf_variables(): filename = str(TEST_DATA / "perlinfieldsU.nc") dims = {"lon": "x", "lat": "y"} @@ -32,6 +33,7 @@ def test_field_from_netcdf_variables(): @pytest.mark.v4remove +@pytest.mark.xfail(reason="GH1946") def test_field_from_netcdf(): filenames = { "lon": str(TEST_DATA / "mask_nemo_cross_180lon.nc"), @@ -44,6 +46,7 @@ def test_field_from_netcdf(): @pytest.mark.v4remove +@pytest.mark.xfail(reason="GH1946") @pytest.mark.parametrize( "calendar, cftime_datetime", zip(_get_cftime_calendars(), _get_cftime_datetimes(), strict=True) ) diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index 9824695297..10a8311a0b 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -56,6 +56,7 @@ def to_xarray_dataset(data: dict[str, np.array], dimensions: dict[str, np.array] @pytest.mark.v4remove +@pytest.mark.xfail(reason="GH1946") @pytest.fixture def multifile_fieldset(tmp_path): stem = "test_subsets" @@ -79,6 +80,7 @@ def 
multifile_fieldset(tmp_path): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") @pytest.mark.parametrize("xdim", [100, 200]) @pytest.mark.parametrize("ydim", [100, 200]) def test_fieldset_from_data(xdim, ydim): @@ -101,6 +103,7 @@ def test_fieldset_vmin_vmax(): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") @pytest.mark.parametrize("ttype", ["float", "datetime64"]) @pytest.mark.parametrize("tdim", [1, 20]) def test_fieldset_from_data_timedims(ttype, tdim): @@ -116,6 +119,7 @@ def test_fieldset_from_data_timedims(ttype, tdim): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") @pytest.mark.parametrize("xdim", [100, 200]) @pytest.mark.parametrize("ydim", [100, 50]) def test_fieldset_from_data_different_dimensions(xdim, ydim): @@ -146,6 +150,7 @@ def test_fieldset_from_data_different_dimensions(xdim, ydim): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") def test_fieldset_from_modulefile(): nemo_fname = str(TEST_DATA / "fieldset_nemo.py") nemo_error_fname = str(TEST_DATA / "fieldset_nemo_error.py") @@ -165,6 +170,7 @@ def test_fieldset_from_modulefile(): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") def test_field_from_netcdf_fieldtypes(): filenames = { "varU": { @@ -191,6 +197,7 @@ def test_field_from_netcdf_fieldtypes(): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") def test_fieldset_from_agrid_dataset(): filenames = { "lon": str(TEST_DATA / "mask_nemo_cross_180lon.nc"), @@ -203,6 +210,7 @@ def test_fieldset_from_agrid_dataset(): @pytest.mark.v4remove +@pytest.mark.xfail(reason="GH1946") def test_fieldset_from_cgrid_interpmethod(): filenames = { "lon": str(TEST_DATA / "mask_nemo_cross_180lon.nc"), @@ -218,6 +226,7 @@ def test_fieldset_from_cgrid_interpmethod(): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") @pytest.mark.parametrize("calltype", ["from_data", "from_nemo"]) def test_illegal_dimensionsdict(calltype): with pytest.raises(NameError): @@ -234,6 +243,7 @@ def test_illegal_dimensionsdict(calltype): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") @pytest.mark.parametrize("xdim", [100, 200]) @pytest.mark.parametrize("ydim", [100, 200]) def test_add_field(xdim, ydim, tmpdir): @@ -245,6 +255,7 @@ def test_add_field(xdim, ydim, tmpdir): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") @pytest.mark.parametrize("dupobject", ["same", "new"]) def test_add_duplicate_field(dupobject): data, dimensions = generate_fieldset_data(100, 100) @@ -260,6 +271,7 @@ def test_add_duplicate_field(dupobject): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") @pytest.mark.parametrize("fieldtype", ["normal", "vector"]) def test_add_field_after_pset(fieldtype): data, dimensions = generate_fieldset_data(100, 100) @@ -282,6 +294,7 @@ def test_fieldset_samegrids_from_file(multifile_fieldset): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") @pytest.mark.parametrize("gridtype", ["A", "C"]) def test_fieldset_dimlength1_cgrid(gridtype): fieldset = FieldSet.from_data({"U": 0, "V": 0}, {"lon": 0, "lat": 0}) @@ -304,6 +317,7 @@ def assign_dataset_timestamp_dim(ds, timestamp): @pytest.mark.v4remove +@pytest.mark.xfail(reason="GH1946") def test_fieldset_diffgrids_from_file(tmp_path): """Test for subsetting fieldset from file using indices dict.""" stem = "test_subsets" @@ -335,6 +349,7 @@ def test_fieldset_diffgrids_from_file(tmp_path): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") def test_fieldset_diffgrids_from_file_data(multifile_fieldset): """Test for subsetting fieldset from file using indices 
dict.""" data, dimensions = generate_fieldset_data(100, 100) @@ -349,6 +364,7 @@ def test_fieldset_diffgrids_from_file_data(multifile_fieldset): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") def test_fieldset_samegrids_from_data(): """Test for subsetting fieldset from file using indices dict.""" data, dimensions = generate_fieldset_data(100, 100) @@ -365,6 +381,7 @@ def addConst(particle, fieldset, time): # pragma: no cover @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") def test_fieldset_constant(): data, dimensions = generate_fieldset_data(100, 100) fieldset = FieldSet.from_data(data, dimensions) @@ -380,6 +397,7 @@ def test_fieldset_constant(): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") @pytest.mark.parametrize("swapUV", [False, True]) def test_vector_fields(swapUV): lon = np.linspace(0.0, 10.0, 12, dtype=np.float32) @@ -404,6 +422,7 @@ def test_vector_fields(swapUV): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") def test_add_second_vector_field(): lon = np.linspace(0.0, 10.0, 12, dtype=np.float32) lat = np.linspace(0.0, 10.0, 10, dtype=np.float32) @@ -556,6 +575,7 @@ def sampleTemp(particle, fieldset, time): # pragma: no cover @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") @pytest.mark.parametrize("tdim", [10, None]) def test_fieldset_from_xarray(tdim): def generate_dataset(xdim, ydim, zdim=1, tdim=1): @@ -611,6 +631,7 @@ def test_fieldset_frompop(): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") def test_fieldset_from_data_gridtypes(): """Simple test for fieldset initialisation from data.""" xdim, ydim, zdim = 20, 10, 4 diff --git a/tests/test_fieldset_sampling.py b/tests/test_fieldset_sampling.py index d0e0c9e60a..ab4e5af4e7 100644 --- a/tests/test_fieldset_sampling.py +++ b/tests/test_fieldset_sampling.py @@ -86,6 +86,7 @@ def fieldset_geometric_polar(): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") def test_fieldset_sample(fieldset): """Sample the fieldset using indexing notation.""" xdim, ydim = 120, 80 @@ -100,6 +101,7 @@ def test_fieldset_sample(fieldset): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") def test_fieldset_sample_eval(fieldset): """Sample the fieldset using the explicit eval function.""" xdim, ydim = 60, 60 @@ -123,6 +125,7 @@ def test_fieldset_polar_with_halo(fieldset_geometric_polar): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") @pytest.mark.parametrize("zdir", [-1, 1]) def test_verticalsampling(zdir): dims = (4, 2, 2) @@ -198,6 +201,7 @@ def test_nearest_neighbor_interpolation2D(): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") def test_nearest_neighbor_interpolation3D(): npart = 81 dims = (2, 2, 2) @@ -225,6 +229,7 @@ def test_nearest_neighbor_interpolation3D(): @pytest.mark.v4future +@pytest.mark.xfail(reason="GH1946") @pytest.mark.parametrize("withDepth", [True, False]) @pytest.mark.parametrize("arrtype", ["ones", "rand"]) def test_inversedistance_nearland(withDepth, arrtype): @@ -267,6 +272,7 @@ def test_inversedistance_nearland(withDepth, arrtype): @pytest.mark.v4future +@pytest.mark.xfail(reason="GH1946") @pytest.mark.parametrize("boundaryslip", ["freeslip", "partialslip"]) @pytest.mark.parametrize("withW", [False, True]) @pytest.mark.parametrize("withT", [False, True]) @@ -319,6 +325,7 @@ def test_partialslip_nearland_zonal(boundaryslip, withW, withT): @pytest.mark.v4future +@pytest.mark.xfail(reason="GH1946") @pytest.mark.parametrize("boundaryslip", ["freeslip", "partialslip"]) @pytest.mark.parametrize("withW", [False, True]) def 
test_partialslip_nearland_meridional(boundaryslip, withW): @@ -362,6 +369,7 @@ def test_partialslip_nearland_meridional(boundaryslip, withW): @pytest.mark.v4future +@pytest.mark.xfail(reason="GH1946") @pytest.mark.parametrize("boundaryslip", ["freeslip", "partialslip"]) def test_partialslip_nearland_vertical(boundaryslip): npart = 20 @@ -391,6 +399,7 @@ def test_partialslip_nearland_vertical(boundaryslip): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") def test_fieldset_sample_particle(): """Sample the fieldset using an array of particles.""" npart = 120 @@ -414,6 +423,7 @@ def test_fieldset_sample_particle(): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") def test_fieldset_sample_geographic(fieldset_geometric): """Sample a fieldset with conversion to geographic units (degrees).""" npart = 120 @@ -431,6 +441,7 @@ def test_fieldset_sample_geographic(fieldset_geometric): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") def test_fieldset_sample_geographic_noconvert(fieldset_geometric): """Sample a fieldset without conversion to geographic units.""" npart = 120 @@ -448,6 +459,7 @@ def test_fieldset_sample_geographic_noconvert(fieldset_geometric): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") def test_fieldset_sample_geographic_polar(fieldset_geometric_polar): """Sample a fieldset with conversion to geographic units and a pole correction.""" npart = 120 @@ -465,6 +477,7 @@ def test_fieldset_sample_geographic_polar(fieldset_geometric_polar): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") def test_meridionalflow_spherical(): """Create uniform NORTHWARD flow on spherical earth and advect particles. @@ -495,6 +508,7 @@ def test_meridionalflow_spherical(): @pytest.mark.v4alpha +@pytest.mark.xfail(reason="GH1946") def test_zonalflow_spherical(): """Create uniform EASTWARD flow on spherical earth and advect particles. 
@@ -532,6 +546,7 @@ def test_zonalflow_spherical():
 
 
 @pytest.mark.v4alpha
+@pytest.mark.xfail(reason="GH1946")
 def test_random_field():
     """Sampling test that tests for overshoots by sampling a field of random numbers between 0 and 1."""
     xdim, ydim = 20, 20
@@ -557,6 +572,7 @@ def test_random_field():
 
 
 @pytest.mark.v4alpha
+@pytest.mark.xfail(reason="GH1946")
 @pytest.mark.parametrize("allow_time_extrapolation", [True, False])
 def test_sampling_out_of_bounds_time(allow_time_extrapolation):
     xdim, ydim, tdim = 10, 10, 10
@@ -676,6 +692,7 @@ def test_sample(particle, fieldset, time):  # pragma: no cover
 
 
 @pytest.mark.v4future
+@pytest.mark.xfail(reason="GH1946")
 @pytest.mark.parametrize("npart", [1, 10])
 def test_sampling_multigrids_non_vectorfield(npart):
     xdim, ydim = 100, 200
@@ -716,6 +733,7 @@ def test_sample(particle, fieldset, time):  # pragma: no cover
 
 
 @pytest.mark.v4future
+@pytest.mark.xfail(reason="GH1946")
 @pytest.mark.parametrize("ugridfactor", [1, 10])
 def test_sampling_multiple_grid_sizes(ugridfactor):
     xdim, ydim = 10, 20
@@ -744,6 +762,7 @@ def test_sampling_multiple_grid_sizes(ugridfactor):
 
 
 @pytest.mark.v4future
+@pytest.mark.xfail(reason="GH1946")
 def test_multiple_grid_addlater_error():
     xdim, ydim = 10, 20
     U = Field(
diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py
index b49cb38cf3..cac7abc893 100644
--- a/tests/test_interpolation.py
+++ b/tests/test_interpolation.py
@@ -57,6 +57,7 @@ def data_2d():
 
 
 @pytest.mark.v4remove
+@pytest.mark.xfail(reason="GH1946")
 @pytest.mark.parametrize(
     "func, eta, xsi, expected",
     [
@@ -76,6 +77,7 @@ def test_raw_2d_interpolation(data_2d, func, eta, xsi, expected):
 
 
 @pytest.mark.v4remove
+@pytest.mark.xfail(reason="GH1946")
 @pytest.mark.usefixtures("tmp_interpolator_registry")
 def test_interpolator_override():
     fieldset = create_fieldset_zeros_3d()
@@ -89,6 +91,7 @@ def test_interpolator(ctx: interpolation.InterpolationContext3D):
 
 
 @pytest.mark.v4remove
+@pytest.mark.xfail(reason="GH1946")
 @pytest.mark.usefixtures("tmp_interpolator_registry")
 def test_full_depth_provided_to_interpolators():
     """The full depth needs to be provided to the interpolation schemes as some interpolators
@@ -109,6 +112,7 @@ def test_interpolator2(ctx: interpolation.InterpolationContext3D):
 
 
 @pytest.mark.v4remove
+@pytest.mark.xfail(reason="GH1946")
 @pytest.mark.parametrize(
     "interp_method",
     [

From 8c8e947170e9b902ce250479b5bbe7df275009c5 Mon Sep 17 00:00:00 2001
From: Joe
Date: Fri, 11 Apr 2025 12:58:38 -0400
Subject: [PATCH 40/46] Change "to do" to "TODO"

---
 parcels/_index_search.py                     | 14 +++++-------
 parcels/application_kernels/interpolation.py |  4 ++--
 parcels/field.py                             | 24 ++++++++++----------
 parcels/fieldset.py                          |  2 +-
 4 files changed, 21 insertions(+), 23 deletions(-)

diff --git a/parcels/_index_search.py b/parcels/_index_search.py
index 60122ca682..8af53bee27 100644
--- a/parcels/_index_search.py
+++ b/parcels/_index_search.py
@@ -104,7 +104,7 @@ def search_indices_vertical_z(depth, gridindexingtype: GridIndexingType, z: floa
     return (zi, zeta)
 
 
-## joe@fluidnumerics.com : 3/27/25 : To do : Still need to implement the search_indices_vertical_s function
+## TODO : Still need to implement the search_indices_vertical_s function
 def search_indices_vertical_s(
     field: Field,
     interp_method: InterpMethodOption,
@@ -183,13 +183,11 @@ def search_indices_vertical_s(
 def _search_indices_rectilinear(
     field: Field, time: datetime, z: float, y: float, x: float, ti: int, ei: int | None = None, search2D=False
 ):
-    # To do : If ei is provided, check if particle is in the same cell
+    # TODO : If ei is provided, check if particle is in the same cell
     if field.xdim > 1 and (not field.zonal_periodic):
-        if x < field.lonlat_minmax[0] or x > field.lonlat_minmax[1]:  # To do : implement lonlat_minmax at field level
+        if x < field.lonlat_minmax[0] or x > field.lonlat_minmax[1]:
             _raise_field_out_of_bound_error(z, y, x)
-    if field.ydim > 1 and (
-        y < field.lonlat_minmax[2] or y > field.lonlat_minmax[3]
-    ):  # To do : implement lonlat_minmax at field level
+    if field.ydim > 1 and (y < field.lonlat_minmax[2] or y > field.lonlat_minmax[3]):
         _raise_field_out_of_bound_error(z, y, x)
 
     if field.xdim > 1:
@@ -255,7 +253,7 @@
         except FieldOutOfBoundSurfaceError:
             _raise_field_out_of_bound_surface_error(z, y, x)
     elif field._gtype == GridType.RectilinearSGrid:
-        ## joe@fluidnumerics.com : 3/27/25 : To do : Still need to implement the search_indices_vertical_s function
+        ## TODO : Still need to implement the search_indices_vertical_s function
         (zi, zeta) = search_indices_vertical_s(field.grid, field.interp_method, time, z, y, x, ti, yi, xi, eta, xsi)
     else:
         zi, zeta = -1, 0
@@ -268,7 +266,7 @@
     return (zeta, eta, xsi, _ei)
 
 
-## joe@fluidnumerics.com : 3/27/25 : To do : Still need to implement the search_indices_curvilinear
+## TODO : Still need to implement the search_indices_curvilinear
 def _search_indices_curvilinear(field: Field, time, z, y, x, ti, particle=None, search2D=False):
     if particle:
         zi, yi, xi = field.unravel_index(particle.ei)
diff --git a/parcels/application_kernels/interpolation.py b/parcels/application_kernels/interpolation.py
index eea65b1c6e..a1abf14064 100644
--- a/parcels/application_kernels/interpolation.py
+++ b/parcels/application_kernels/interpolation.py
@@ -26,7 +26,7 @@ def UXPiecewiseConstantFace(
     This interpolation method is appropriate for fields that are face registered, such as u,v in FESOM.
     """
-    # To do : handle vertical interpolation
+    # TODO joe : handle vertical interpolation
     zi, fi = field.unravel_index(ei)
     return field.data[ti, zi, fi]
 
@@ -47,7 +47,7 @@
     interpolation method is appropriate for fields that are node registered such as the vertical velocity w in FESOM.
     """
-    # To do: handle vertical interpolation
+    # TODO joe : handle vertical interpolation
     zi, fi = field.unravel_index(ei)
     node_ids = field.data.uxgrid.face_node_connectivity[fi, :]
     return np.dot(field.data[ti, zi, node_ids], bcoords)
diff --git a/parcels/field.py b/parcels/field.py
index 8416b7afc4..1e231cc7b9 100644
--- a/parcels/field.py
+++ b/parcels/field.py
@@ -143,7 +143,7 @@ def __init__(
         self,
         name: str,
         data: xr.DataArray | ux.UxDataArray,
-        grid: ux.UxGrid | None = None,  # To do : Once parcels.Grid class is added, allow for it to be passed here
+        grid: ux.UxGrid | None = None,  # TODO Nick : Once parcels.Grid class is added, allow for it to be passed here
         mesh_type: Mesh = "flat",
         interp_method: Callable | None = None,
         allow_time_extrapolation: bool | None = None,
@@ -188,7 +188,7 @@ def __init__(
             self._vertical_location = "center"
         elif "nz" in data.dims:
             self._vertical_location = "face"
-        else:  # To do : This bit probably needs an overhaul once the parcels.Grid class is integrated.
+        else:  # TODO Nick : This bit probably needs an overhaul once the parcels.Grid class is integrated.
             self._spatialhash = None
         # Set the grid type
         if "x_g" in self.data.coords:
@@ -283,7 +283,7 @@ def xdim(self):
             else:
                 return self.data.sizes["lon"]
         else:
-            return 0  # To do : Discuss what we want to return for uxdataarray obj
+            return 0  # TODO : Discuss what we want to return as xdim for uxdataarray obj
 
     @property
     def ydim(self):
@@ -295,7 +295,7 @@ def ydim(self):
             else:
                 return self.data.sizes["lat"]
         else:
-            return 0  # To do : Discuss what we want to return for uxdataarray obj
+            return 0  # TODO : Discuss what we want to return as ydim for uxdataarray obj
 
     @property
     def zdim(self):
@@ -311,7 +311,7 @@ def n_face(self):
         if type(self.data) is ux.uxDataArray:
             return self.grid.n_face
         else:
-            return 0  # To do : Discuss what we want to return for dataarray obj
+            return 0  # TODO : Discuss what we want to return as n_face for dataarray obj
 
     @property
     def interp_method(self):
@@ -345,8 +345,8 @@ def _search_indices_unstructured(self, z, y, x, ei=None, search2D=False):
             # Search using global search
             fi, bcoords = self._spatialhash.query([[x, y]])  # Get the face id for the particle
             if fi == -1:
-                raise FieldOutOfBoundError(z, y, x)  # To do : how to handle lost particle ??
-            # To do : Do the vertical grid search
+                raise FieldOutOfBoundError(z, y, x)
+            # TODO Joe : Do the vertical grid search
             # zi = self._vertical_search(z)
             zi = 0  # For now
             return bcoords, self.ravel_index(zi, 0, fi)
@@ -356,26 +356,26 @@ def _search_indices_unstructured(self, z, y, x, ei=None, search2D=False):
             bcoords, err = self._get_ux_barycentric_coordinates(y, x, fi)
             if ((bcoords >= 0).all()) and ((bcoords <= 1.0).all()) and err < tol:
-                # To do: Do the vertical grid search
+                # TODO Joe : Do the vertical grid search
                 return bcoords, ei
             else:
                 # In this case we need to search the neighbors
                 for neighbor in self.grid.face_face_connectivity[fi, :]:
                     bcoords, err = self._get_ux_barycentric_coordinates(y, x, neighbor)
                     if ((bcoords >= 0).all()) and ((bcoords <= 1.0).all()) and err < tol:
-                        # To do: Do the vertical grid search
+                        # TODO Joe: Do the vertical grid search
                         return bcoords, self.ravel_index(zi, 0, neighbor)
 
             # If we reach this point, we do a global search as a last-ditch effort; if that also fails, the particle is out of bounds
             fi, bcoords = self._spatialhash.query([[x, y]])  # Get the face id for the particle
             if fi == -1:
-                raise FieldOutOfBoundError(z, y, x)  # To do : how to handle lost particle ??
+                raise FieldOutOfBoundError(z, y, x)
 
     def _search_indices_structured(self, z, y, x, ei=None, search2D=False):
         if self._gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]:
             (zeta, eta, xsi, zi, yi, xi) = _search_indices_rectilinear(self, z, y, x, ei=ei, search2D=search2D)
         else:
-            ## joe@fluidnumerics.com : 3/27/25 : To do : Still need to implement the search_indices_curvilinear
+            ## TODO : Still need to implement the search_indices_curvilinear
             # (zeta, eta, xsi, zi, yi, xi) = _search_indices_curvilinear(
             #     self, z, y, x, ei=ei, search2D=search2D
             # )
@@ -628,7 +628,7 @@ def vector_interp_method(self, method: Callable):
         self._vector_interp_method = method
 
     # @staticmethod
-    # To do : def _check_grid_dimensions(grid1, grid2):
+    # TODO : def _check_grid_dimensions(grid1, grid2):
    #     return (
    #         np.allclose(grid1.lon, grid2.lon)
    #         and np.allclose(grid1.lat, grid2.lat)
diff --git a/parcels/fieldset.py b/parcels/fieldset.py
index 64332c4004..f283b962c3 100644
--- a/parcels/fieldset.py
+++ b/parcels/fieldset.py
@@ -178,7 +178,7 @@ def add_constant_field(self, name: str, value: float, mesh: Mesh = "flat"):
             Field(
                 name,
                 data,
-                interp_method=None,  # To do : Need to define an interpolation method for constants
+                interp_method=None,  # TODO : Need to define an interpolation method for constants
                 allow_time_extrapolation=True,
             )
         )

From 12557d959ae560b78d110666c930d0644f4f999c Mon Sep 17 00:00:00 2001
From: Joe
Date: Fri, 11 Apr 2025 13:02:56 -0400
Subject: [PATCH 41/46] Move _validate_* functions to be pure functions

---
 parcels/field.py | 92 +++++++++++++++++++++++++-----------------------
 1 file changed, 47 insertions(+), 45 deletions(-)

diff --git a/parcels/field.py b/parcels/field.py
index 1e231cc7b9..9a78fc578d 100644
--- a/parcels/field.py
+++ b/parcels/field.py
@@ -152,7 +152,7 @@ def __init__(
         self.data = data
         self.grid = grid
 
-        self._validate_dataarray()
+        _validate_dataarray(data, name)
 
         self._parent_mesh = data.attrs["mesh"]
         self._mesh_type = mesh_type
@@ -506,50 +506,6 @@ def unravel_index(self, ei):
         fi = _ei % self.n_face
         return zi, fi
 
-    def _validate_dataarray(self):
-        """Verifies that all the required attributes are present in the xarray.DataArray or
-        uxarray.UxDataArray object.
-        """
-        if isinstance(self.data, ux.UxDataArray):
-            # Validate dimensions
-            if not ("nz1" in self.data.dims or "nz" in self.data.dims):
-                raise ValueError(
-                    f"Field {self.name} is missing a 'nz1' or 'nz' dimension in the field's metadata. "
-                    "This attribute is required for xarray.DataArray objects."
-                )
-
-            if "time" not in self.data.dims:
-                raise ValueError(
-                    f"Field {self.name} is missing a 'time' dimension in the field's metadata. "
-                    "This attribute is required for xarray.DataArray objects."
-                )
-
-            # Validate attributes
-            required_keys = ["location", "mesh"]
-            for key in required_keys:
-                if key not in self.data.attrs.keys():
-                    raise ValueError(
-                        f"Field {self.name} is missing a '{key}' attribute in the field's metadata. "
-                        "This attribute is required for xarray.DataArray objects."
-                    )
-
-        if type(self.data) is ux.UxDataArray:
-            self._validate_uxgrid()
-
-    def _validate_uxgrid(self):
-        """Verifies that all the required attributes are present in the uxarray.UxDataArray.UxGrid object."""
-        if "Conventions" not in self.grid.attrs.keys():
-            raise ValueError(
-                f"Field {self.name} is missing a 'Conventions' attribute in the field's metadata. "
-                "This attribute is required for uxarray.UxDataArray objects."
-            )
-        if self.grid.attrs["Conventions"] != "UGRID-1.0":
-            raise ValueError(
-                f"Field {self.name} has a 'Conventions' attribute that is not 'UGRID-1.0'. "
-                "This attribute is required for uxarray.UxDataArray objects."
-                "See https://ugrid-conventions.github.io/ugrid-conventions/ for more information."
-            )
-
     def __getattr__(self, key: str):
         return getattr(self.data, key)
 
@@ -674,3 +630,49 @@ def __getitem__(self, key):
             return self.eval(*key)
         except tuple(AllParcelsErrorCodes.keys()) as error:
             return _deal_with_errors(error, key, vector_type=self.vector_type)
+
+
+def _validate_dataarray(data, name):
+    """Verifies that all the required attributes are present in the xarray.DataArray or
+    uxarray.UxDataArray object.
+    """
+    if isinstance(data, ux.UxDataArray):
+        # Validate dimensions
+        if not ("nz1" in data.dims or "nz" in data.dims):
+            raise ValueError(
+                f"Field {name} is missing a 'nz1' or 'nz' dimension in the field's metadata. "
+                "This attribute is required for xarray.DataArray objects."
+            )
+
+        if "time" not in data.dims:
+            raise ValueError(
+                f"Field {name} is missing a 'time' dimension in the field's metadata. "
+                "This attribute is required for xarray.DataArray objects."
+            )
+
+        # Validate attributes
+        required_keys = ["location", "mesh"]
+        for key in required_keys:
+            if key not in data.attrs.keys():
+                raise ValueError(
+                    f"Field {name} is missing a '{key}' attribute in the field's metadata. "
+                    "This attribute is required for xarray.DataArray objects."
+                )
+
+    if type(data) is ux.UxDataArray:
+        _validate_uxgrid(data.uxgrid, name)
+
+
+def _validate_uxgrid(grid, name):
+    """Verifies that all the required attributes are present in the uxarray.UxDataArray.UxGrid object."""
+    if "Conventions" not in grid.attrs.keys():
+        raise ValueError(
+            f"Field {name} is missing a 'Conventions' attribute in the field's metadata. "
+            "This attribute is required for uxarray.UxDataArray objects."
+        )
+    if grid.attrs["Conventions"] != "UGRID-1.0":
+        raise ValueError(
+            f"Field {name} has a 'Conventions' attribute that is not 'UGRID-1.0'. "
+            "This attribute is required for uxarray.UxDataArray objects."
+            "See https://ugrid-conventions.github.io/ugrid-conventions/ for more information."
+        )

From 81793d980c6adee0c310374be4ae33e750b79886 Mon Sep 17 00:00:00 2001
From: Joe
Date: Fri, 11 Apr 2025 13:07:47 -0400
Subject: [PATCH 42/46] Return gridset_size as the number of fields

---
 parcels/fieldset.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/parcels/fieldset.py b/parcels/fieldset.py
index f283b962c3..aafa3960d0 100644
--- a/parcels/fieldset.py
+++ b/parcels/fieldset.py
@@ -46,7 +46,6 @@ def __init__(self, datasets: list[xr.Dataset | ux.UxDataset]):
         self.datasets = datasets
 
         self._completed: bool = False
-        self._gridset_size: int = 0
         self._fieldnames = []
         time_origin = None
         # Create pointers to each (Ux)DataArray
@@ -56,7 +55,6 @@ def __init__(self, datasets: list[xr.Dataset | ux.UxDataset]):
                     self.add_field(Field(field, ds[field], grid=ds[field].uxgrid), field)
                 else:
                     self.add_field(Field(field, ds[field]), field)
-                self._gridset_size += 1
                 self._fieldnames.append(field)
 
             if "time" in ds.coords:
@@ -113,7 +111,7 @@ def checkvaliddimensionsdict(dims):
 
     @property
     def gridset_size(self):
-        return self._gridset_size
+        return len(self._fieldnames)
 
     def add_field(self, field: Field, name: str | None = None):
         """Add a :class:`parcels.field.Field` object to the FieldSet.
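Making the validators pure functions also makes them unit-testable in isolation: `_validate_uxgrid` only reads `grid.attrs`, so any object carrying an `attrs` dict can stand in for a full uxarray grid. A minimal sketch of such a test, assuming the private helper stays importable from `parcels.field` (the import path and test name here are illustrative, not part of this series):

import pytest
from types import SimpleNamespace

from parcels.field import _validate_uxgrid


def test_validate_uxgrid_rejects_non_ugrid_conventions():
    # A duck-typed stand-in for a uxarray UxGrid: only `attrs` is read.
    fake_grid = SimpleNamespace(attrs={"Conventions": "CF-1.8"})
    with pytest.raises(ValueError, match="UGRID-1.0"):
        _validate_uxgrid(fake_grid, "U")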
From 9d3a7e28907ec56428f9e9efffeee35941a3890c Mon Sep 17 00:00:00 2001
From: Joe
Date: Fri, 11 Apr 2025 13:11:20 -0400
Subject: [PATCH 43/46] Remove precision requirement on constant field

---
 parcels/fieldset.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/parcels/fieldset.py b/parcels/fieldset.py
index aafa3960d0..fe87780287 100644
--- a/parcels/fieldset.py
+++ b/parcels/fieldset.py
@@ -145,7 +145,7 @@ def add_field(self, field: Field, name: str | None = None):
         self._gridset_size += 1
         self._fieldnames.append(name)
 
-    def add_constant_field(self, name: str, value: float, mesh: Mesh = "flat"):
+    def add_constant_field(self, name: str, value, mesh: Mesh = "flat"):
         """Wrapper function to add a Field that is constant in space,
         useful e.g. when using constant horizontal diffusivity
 
@@ -153,8 +153,8 @@ def add_constant_field(self, name: str, value: float, mesh: Mesh = "flat"):
         ----------
         name : str
             Name of the :class:`parcels.field.Field` object to be added
-        value : float
-            Value of the constant field (stored as 32-bit float)
+        value :
+            Value of the constant field
         mesh : str
             String indicating the type of mesh coordinates and
             units used during velocity interpolation, see also `this tutorial <../examples/tutorial_unitconverters.ipynb>`__:
@@ -164,7 +164,7 @@
            2. flat: No conversion, lat/lon are assumed to be in m.
         """
         time = 0.0
-        values = np.zeros((1, 1, 1, 1), dtype=np.float32) + value
+        values = np.full((1, 1, 1, 1), value)
         data = xr.DataArray(
             data=values,
             name=name,

From 89e18bd93c889368f563f83ff101f34deee06693 Mon Sep 17 00:00:00 2001
From: Joe
Date: Fri, 11 Apr 2025 13:12:32 -0400
Subject: [PATCH 44/46] Remove parse_wildcards

---
 parcels/fieldset.py | 15 ---------------
 1 file changed, 15 deletions(-)

diff --git a/parcels/fieldset.py b/parcels/fieldset.py
index fe87780287..18480fffb3 100644
--- a/parcels/fieldset.py
+++ b/parcels/fieldset.py
@@ -1,6 +1,3 @@
-import os
-from glob import glob
-
 import numpy as np
 import uxarray as ux
 import xarray as xr
@@ -222,18 +219,6 @@ def _check_complete(self):
 
         self._completed = True
 
-    @classmethod
-    def _parse_wildcards(cls, paths, filenames, var):
-        if not isinstance(paths, list):
-            paths = sorted(glob(str(paths)))
-        if len(paths) == 0:
-            notfound_paths = filenames[var] if isinstance(filenames, dict) and var in filenames else filenames
-            raise OSError(f"FieldSet files not found for variable {var}: {notfound_paths}")
-        for fp in paths:
-            if not os.path.exists(fp):
-                raise OSError(f"FieldSet file not found: {fp}")
-        return paths
-
     @classmethod
     def from_data(
         cls,

From 0af7de934fdd73f9bf149d1fcbafdb6942f0af9c Mon Sep 17 00:00:00 2001
From: Joe
Date: Mon, 14 Apr 2025 14:37:31 -0400
Subject: [PATCH 45/46] Remove fieldset.from_data

---
 parcels/fieldset.py | 148 --------------------------------------------
 1 file changed, 148 deletions(-)

diff --git a/parcels/fieldset.py b/parcels/fieldset.py
index 18480fffb3..9c8ee9315f 100644
--- a/parcels/fieldset.py
+++ b/parcels/fieldset.py
@@ -219,154 +219,6 @@ def _check_complete(self):
 
         self._completed = True
 
-    @classmethod
-    def from_data(
-        cls,
-        data,
-        dimensions,
-        mesh: Mesh = "spherical",
-        allow_time_extrapolation: bool | None = None,
-        **kwargs,
-    ):
-        """Initialise FieldSet object from raw data. Assumes structured grid and uses xarray.dataarray objects.
-
-        Parameters
-        ----------
-        data :
-            Dictionary mapping field names to numpy arrays.
-            Note that at least a 'U' and 'V' numpy array need to be given, and that
-            the built-in Advection kernels assume that U and V are in m/s.
-            Data shape is either [ydim, xdim], [zdim, ydim, xdim], [tdim, ydim, xdim] or [tdim, zdim, ydim, xdim],
-        dimensions : dict
-            Dictionary mapping field dimensions (lon,
-            lat, depth, time) to numpy arrays.
-            Note that dimensions can also be a dictionary of dictionaries if
-            dimension names are different for each variable
-            (e.g. dimensions['U'], dimensions['V'], etc).
-        mesh : str
-            String indicating the type of mesh coordinates and
-            units used during velocity interpolation, see also `this tutorial <../examples/tutorial_unitconverters.ipynb>`__:
-
-            1. spherical (default): Lat and lon in degree, with a
-               correction for zonal velocity U near the poles.
-            2. flat: No conversion, lat/lon are assumed to be in m.
-        allow_time_extrapolation : bool
-            boolean whether to allow for extrapolation
-            (i.e. beyond the last available time snapshot)
-            Default is False if dimensions includes time, else True
-        **kwargs :
-            Keyword arguments passed to the :class:`Field` constructor.
-
-        Examples
-        --------
-        For usage examples see the following tutorials:
-
-        * `Analytical advection <../examples/tutorial_analyticaladvection.ipynb>`__
-
-        * `Diffusion <../examples/tutorial_diffusion.ipynb>`__
-
-        * `Interpolation <../examples/tutorial_interpolation.ipynb>`__
-
-        * `Unit converters <../examples/tutorial_unitconverters.ipynb>`__
-        """
-        fields = {}
-        for name, datafld in data.items():
-            # Use dimensions[name] if dimensions is a dict of dicts
-            dims = dimensions[name] if name in dimensions else dimensions
-            cls.checkvaliddimensionsdict(dims)
-
-            if allow_time_extrapolation is None:
-                allow_time_extrapolation = False if "time" in dims else True
-
-            lon = dims["lon"]
-            lat = dims["lat"]
-            depth = np.zeros(1, dtype=np.float32) if "depth" not in dims else dims["depth"]
-            time = np.zeros(1, dtype=np.float64) if "time" not in dims else dims["time"]
-
-            if len(datafld.shape) == 2:
-                if "time" in dims:
-                    coords = [time, lat, lon]
-                    datafld = datafld[np.newaxis, ...]
-                    dims_xr = {"time": time, "lat": lat, "lon": lon}
-                else:
-                    coords = [lat, lon]
-                    dims_xr = {"lat": lat, "lon": lon}
-
-            elif len(datafld.shape) == 3:
-                if "time" not in dims:
-                    coords = [depth, lat, lon]
-                    dims_xr = {"depth": depth, "lat": lat, "lon": lon}
-                else:
-                    coords = [time, lat, lon]
-                    dims_xr = {"time": time, "lat": lat, "lon": lon}
-            else:
-                coords = [time, depth, lat, lon]
-                dims_xr = {"time": time, "depth": depth, "lat": lat, "lon": lon}
-
-            fields[name] = xr.DataArray(
-                data=datafld,
-                name=name,
-                dims=dims_xr,
-                coords=coords,
-                attrs=dict(
-                    description="Created with fieldset.from_data",
-                    units="",
-                    location="node",
-                    mesh="Arakawa-A",
-                ),
-            )
-
-        return cls([xr.Dataset(fields)])
-
-    @classmethod
-    def from_xarray_dataset(cls, ds, variables, dimensions, mesh="spherical", allow_time_extrapolation=None, **kwargs):
-        """Initialises FieldSet data from xarray Datasets.
-
-        Parameters
-        ----------
-        ds : xr.Dataset
-            xarray Dataset.
-            Note that the built-in Advection kernels assume that U and V are in m/s
-        variables : dict
-            Dictionary mapping parcels variable names to data variables in the xarray Dataset.
-        dimensions : dict
-            Dictionary mapping data dimensions (lon,
-            lat, depth, time, data) to dimensions in the xarray Dataset.
-            Note that dimensions can also be a dictionary of dictionaries if
-            dimension names are different for each variable
-            (e.g. dimensions['U'], dimensions['V'], etc).
-        mesh : str
-            String indicating the type of mesh coordinates and
-            units used during velocity interpolation, see also `this tutorial <../examples/tutorial_unitconverters.ipynb>`__:
-
-            1. spherical (default): Lat and lon in degree, with a
-               correction for zonal velocity U near the poles.
-            2. flat: No conversion, lat/lon are assumed to be in m.
-        allow_time_extrapolation : bool
-            boolean whether to allow for extrapolation
-            (i.e. beyond the last available time snapshot)
-            Default is False if dimensions includes time, else True
-        **kwargs :
-            Keyword arguments passed to the :func:`Field.from_xarray` constructor.
-        """
-        for var, _name in variables.items():
-            dims = dimensions[var] if var in dimensions else dimensions
-            cls.checkvaliddimensionsdict(dims)
-
-        return cls([ds])
-
-    # @classmethod
-    # def from_netcdf(
-    #     cls,
-    #     filenames,
-    #     variables,
-    #     dimensions,
-    #     fieldtype=None,
-    #     mesh: Mesh = "spherical",
-    #     allow_time_extrapolation: bool | None = None,
-    #     **kwargs,
-    # ):
-
     @classmethod
     # def from_nemo(
     #     cls,

From 819a077fa46319d4ee1c4d74b7646d581edd1678 Mon Sep 17 00:00:00 2001
From: Joe
Date: Mon, 14 Apr 2025 15:52:54 -0400
Subject: [PATCH 46/46] Change test markers; add TODO to remove `from_data` calls

---
 tests/test_fieldset.py | 44 +++++++++++++++++++++---------------------
 1 file changed, 22 insertions(+), 22 deletions(-)

diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py
index 10a8311a0b..ed70d46d89 100644
--- a/tests/test_fieldset.py
+++ b/tests/test_fieldset.py
@@ -79,7 +79,7 @@ def multifile_fieldset(tmp_path):
     return FieldSet.from_netcdf(files, variables, dimensions)
 
 
-@pytest.mark.v4alpha
+@pytest.mark.v4remove
 @pytest.mark.xfail(reason="GH1946")
 @pytest.mark.parametrize("xdim", [100, 200])
 @pytest.mark.parametrize("ydim", [100, 200])
@@ -102,7 +102,7 @@ def test_fieldset_vmin_vmax():
     assert np.isclose(np.amax(fieldset.U.data), 7)
 
 
-@pytest.mark.v4alpha
+@pytest.mark.v4remove
 @pytest.mark.xfail(reason="GH1946")
 @pytest.mark.parametrize("ttype", ["float", "datetime64"])
 @pytest.mark.parametrize("tdim", [1, 20])
@@ -118,7 +118,7 @@ def test_fieldset_from_data_timedims(ttype, tdim):
         assert fieldset.U.time[i].data == dtime
 
 
-@pytest.mark.v4alpha
+@pytest.mark.v4remove
 @pytest.mark.xfail(reason="GH1946")
 @pytest.mark.parametrize("xdim", [100, 200])
 @pytest.mark.parametrize("ydim", [100, 50])
@@ -225,9 +225,9 @@ def test_fieldset_from_cgrid_interpmethod():
         FieldSet.from_c_grid_dataset(filenames, variable, dimensions, interp_method="partialslip")
 
 
-@pytest.mark.v4alpha
+@pytest.mark.v4future
 @pytest.mark.xfail(reason="GH1946")
-@pytest.mark.parametrize("calltype", ["from_data", "from_nemo"])
+@pytest.mark.parametrize("calltype", ["from_nemo"])
 def test_illegal_dimensionsdict(calltype):
     with pytest.raises(NameError):
         if calltype == "from_data":
@@ -248,7 +248,7 @@ def test_illegal_dimensionsdict(calltype):
 @pytest.mark.parametrize("ydim", [100, 200])
 def test_add_field(xdim, ydim, tmpdir):
     data, dimensions = generate_fieldset_data(xdim, ydim)
-    fieldset = FieldSet.from_data(data, dimensions)
+    fieldset = FieldSet.from_data(data, dimensions)  # TODO : Remove from_data
     field = Field("newfld", fieldset.U.data, lon=fieldset.U.lon, lat=fieldset.U.lat)
     fieldset.add_field(field)
     assert fieldset.newfld.data.shape == fieldset.U.data.shape
@@ -259,7 +259,7 @@ def test_add_field(xdim, ydim, tmpdir):
 @pytest.mark.parametrize("dupobject", ["same", "new"])
 def test_add_duplicate_field(dupobject):
     data, dimensions = generate_fieldset_data(100, 100)
-    fieldset = FieldSet.from_data(data, dimensions)
+    fieldset = FieldSet.from_data(data, dimensions)  # TODO : Remove from_data
     field = Field("newfld", fieldset.U.data, lon=fieldset.U.lon, lat=fieldset.U.lat)
     fieldset.add_field(field)
     with pytest.raises(RuntimeError):
@@ -275,7 +275,7 @@ def test_add_duplicate_field(dupobject):
 @pytest.mark.parametrize("fieldtype", ["normal", "vector"])
 def test_add_field_after_pset(fieldtype):
     data, dimensions = generate_fieldset_data(100, 100)
-    fieldset = FieldSet.from_data(data, dimensions)
+    fieldset = FieldSet.from_data(data, dimensions)  # TODO : Remove from_data
     pset = ParticleSet(fieldset, Particle, lon=0, lat=0)  # noqa ; to trigger fieldset._check_complete
     field1 = Field("field1", fieldset.U.data, lon=fieldset.U.lon, lat=fieldset.U.lat)
     field2 = Field("field2", fieldset.U.data, lon=fieldset.U.lon, lat=fieldset.U.lat)
@@ -297,7 +297,7 @@ def test_fieldset_samegrids_from_file(multifile_fieldset):
 @pytest.mark.xfail(reason="GH1946")
 @pytest.mark.parametrize("gridtype", ["A", "C"])
 def test_fieldset_dimlength1_cgrid(gridtype):
-    fieldset = FieldSet.from_data({"U": 0, "V": 0}, {"lon": 0, "lat": 0})
+    fieldset = FieldSet.from_data({"U": 0, "V": 0}, {"lon": 0, "lat": 0})  # TODO : Remove from_data
     if gridtype == "C":
         fieldset.U.interp_method = "cgrid_velocity"
         fieldset.V.interp_method = "cgrid_velocity"
@@ -353,7 +353,7 @@ def test_fieldset_diffgrids_from_file(tmp_path):
 def test_fieldset_diffgrids_from_file_data(multifile_fieldset):
     """Test for subsetting fieldset from file using indices dict."""
     data, dimensions = generate_fieldset_data(100, 100)
-    field_U = FieldSet.from_data(data, dimensions).U
+    field_U = FieldSet.from_data(data, dimensions).U  # TODO : Remove from_data
     field_U.name = "B"
     multifile_fieldset.add_field(field_U, "B")
 
@@ -368,7 +368,7 @@ def test_fieldset_diffgrids_from_file_data(multifile_fieldset):
 def test_fieldset_samegrids_from_data():
     """Test for subsetting fieldset from file using indices dict."""
     data, dimensions = generate_fieldset_data(100, 100)
-    fieldset1 = FieldSet.from_data(data, dimensions)
+    fieldset1 = FieldSet.from_data(data, dimensions)  # TODO : Remove from_data
     field_data = fieldset1.U
     field_data.name = "B"
     fieldset1.add_field(field_data, "B")
@@ -384,7 +384,7 @@ def addConst(particle, fieldset, time):  # pragma: no cover
 @pytest.mark.xfail(reason="GH1946")
 def test_fieldset_constant():
     data, dimensions = generate_fieldset_data(100, 100)
-    fieldset = FieldSet.from_data(data, dimensions)
+    fieldset = FieldSet.from_data(data, dimensions)  # TODO : Remove from_data
     westval = -0.2
     eastval = 0.3
     fieldset.add_constant("movewest", westval)
@@ -406,7 +406,7 @@ def test_vector_fields(swapUV):
     V = np.zeros((10, 12), dtype=np.float32)
     data = {"U": U, "V": V}
     dimensions = {"U": {"lat": lat, "lon": lon}, "V": {"lat": lat, "lon": lon}}
-    fieldset = FieldSet.from_data(data, dimensions, mesh="flat")
+    fieldset = FieldSet.from_data(data, dimensions, mesh="flat")  # TODO : Remove from_data
     if swapUV:  # we test that we can freely edit whatever UV field
         UV = VectorField("UV", fieldset.V, fieldset.U)
         fieldset.add_vector_field(UV)
@@ -430,11 +430,11 @@ def test_add_second_vector_field():
     V = np.zeros((10, 12), dtype=np.float32)
     data = {"U": U, "V": V}
     dimensions = {"U": {"lat": lat, "lon": lon}, "V": {"lat": lat, "lon": lon}}
-    fieldset = FieldSet.from_data(data, dimensions, mesh="flat")
+    fieldset = FieldSet.from_data(data, dimensions, mesh="flat")  # TODO : Remove from_data
 
     data2 = {"U2": U, "V2": V}
     dimensions2 = {"lon": [ln + 0.1 for ln in lon], "lat": [lt - 0.1 for lt in lat]}
-    fieldset2 = FieldSet.from_data(data2, dimensions2, mesh="flat")
+    fieldset2 = FieldSet.from_data(data2, dimensions2, mesh="flat")  # TODO : Remove from_data
 
     UV2 = VectorField("UV2", fieldset2.U2, fieldset2.V2)
     fieldset.add_vector_field(UV2)
@@ -464,11 +464,11 @@ def test_timestamps(datetype, tmpdir):
         dims1["time"] = np.arange("2005-02-01", "2005-02-11", dtype="datetime64[D]")
         dims2["time"] = np.arange("2005-02-11", "2005-02-15", dtype="datetime64[D]")
 
-    fieldset1 = FieldSet.from_data(data1, dims1)
+    fieldset1 = FieldSet.from_data(data1, dims1)  # TODO : Remove from_data
     fieldset1.U.data[0, :, :] = 2.0
     fieldset1.write(tmpdir.join("file1"))
 
-    fieldset2 = FieldSet.from_data(data2, dims2)
+    fieldset2 = FieldSet.from_data(data2, dims2)  # TODO : Remove from_data
     fieldset2.U.data[0, :, :] = 0.0
     fieldset2.write(tmpdir.join("file2"))
 
@@ -535,7 +535,7 @@ def temp_func(time):
     data = {"U": U, "V": V, "W": W, "temp": temp, "D": D}
     fieldset = FieldSet.from_data(
         data, dimensions, mesh="flat", time_periodic=time_periodic, allow_time_extrapolation=True
-    )
+    )  # TODO : Remove from_data
 
     def sampleTemp(particle, fieldset, time):  # pragma: no cover
         particle.temp = fieldset.temp[time, particle.depth, particle.lat, particle.lon]
@@ -650,7 +650,7 @@ def test_fieldset_from_data_gridtypes():
             depth_s[k, :, :] = depth[k]
 
     # Rectilinear Z grid
-    fieldset = FieldSet.from_data(data, dimensions, mesh="flat")
+    fieldset = FieldSet.from_data(data, dimensions, mesh="flat")  # TODO : Remove from_data
     pset = ParticleSet(fieldset, Particle, [0, 0], [0, 0], [0, 0.4])
    pset.execute(AdvectionRK4, runtime=1.5, dt=0.5)
     plon = pset.lon
@@ -661,7 +661,7 @@ def test_fieldset_from_data_gridtypes():
     # Rectilinear S grid
     dimensions["depth"] = depth_s
-    fieldset = FieldSet.from_data(data, dimensions, mesh="flat")
+    fieldset = FieldSet.from_data(data, dimensions, mesh="flat")  # TODO : Remove from_data
     pset = ParticleSet(fieldset, Particle, [0, 0], [0, 0], [0, 0.4])
     pset.execute(AdvectionRK4, runtime=1.5, dt=0.5)
     assert np.allclose(plon, pset.lon)
 
     # Curvilinear Z grid
     dimensions["lon"] = lonm
     dimensions["lat"] = latm
     dimensions["depth"] = depth
-    fieldset = FieldSet.from_data(data, dimensions, mesh="flat")
+    fieldset = FieldSet.from_data(data, dimensions, mesh="flat")  # TODO : Remove from_data
     pset = ParticleSet(fieldset, Particle, [0, 0], [0, 0], [0, 0.4])
     pset.execute(AdvectionRK4, runtime=1.5, dt=0.5)
     assert np.allclose(plon, pset.lon)
 
     # Curvilinear S grid
     dimensions["depth"] = depth_s
-    fieldset = FieldSet.from_data(data, dimensions, mesh="flat")
+    fieldset = FieldSet.from_data(data, dimensions, mesh="flat")  # TODO : Remove from_data
     pset = ParticleSet(fieldset, Particle, [0, 0], [0, 0], [0, 0.4])
     pset.execute(AdvectionRK4, runtime=1.5, dt=0.5)
     assert np.allclose(plon, pset.lon)
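With `from_data` removed, the construction path the flagged tests will eventually migrate to is the list-of-datasets constructor at the top of `fieldset.py`. A rough sketch of that replacement pattern, reusing the `location`/`mesh` attribute conventions from the deleted `from_data` body; the dimension names and attribute values here are illustrative and may still shift as this series evolves:

import numpy as np
import xarray as xr

from parcels.fieldset import FieldSet

# Build one DataArray per field, carrying the metadata that Field reads
# (the "location" and "mesh" attrs, plus a "time" coordinate).
time = np.zeros(1, dtype=np.float64)
lat = np.linspace(0.0, 1.0, 10)
lon = np.linspace(0.0, 1.0, 12)
attrs = {"location": "node", "mesh": "Arakawa-A"}  # assumed attribute values

u = xr.DataArray(
    np.zeros((1, lat.size, lon.size), dtype=np.float32),
    dims=("time", "lat", "lon"),
    coords={"time": time, "lat": lat, "lon": lon},
    attrs=attrs,
    name="U",
)
v = u.copy().rename("V")

# The constructor takes a list of (Ux)Datasets and registers each data
# variable as a Field on the new FieldSet.
fieldset = FieldSet([xr.Dataset({"U": u, "V": v})])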