From d10ed514fcd8e5a1be3fbeb1b9428064dfaf62af Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Mon, 6 Oct 2025 16:14:10 +0200 Subject: [PATCH 01/33] Add curved curved_quiver --- ultraplot/axes/plot.py | 784 ++++++++++++++++++++++++++++++++++- ultraplot/tests/test_plot.py | 32 ++ 2 files changed, 803 insertions(+), 13 deletions(-) diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index 9cbc89050..3121456db 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -31,10 +31,13 @@ import matplotlib.ticker as mticker import matplotlib.pyplot as mplt import matplotlib as mpl +from matplotlib.streamplot import StreamplotSet from packaging import version import numpy as np import numpy.ma as ma +from matplotlib.streamplot import StreamplotSet + from .. import colors as pcolors from .. import constructor, utils from ..config import rc @@ -170,6 +173,50 @@ docstring._snippet_manager["plot.args_1d_shared"] = _args_1d_shared_docstring docstring._snippet_manager["plot.args_2d_shared"] = _args_2d_shared_docstring +_curved_quiver_docstring = """ +Draws curved vector field arrows (streamlines with arrows) for 2D vector fields. + +Parameters +---------- +x, y : 1D or 2D arrays + Grid coordinates. +u, v : 2D arrays + Vector components. +c : color or 2D array, optional + Streamline color. +density : float or (float, float), optional + Controls the closeness of streamlines. +grains : int or (int, int), optional + Number of seed points in x and y. +linewidth : float or 2D array, optional + Width of streamlines. +color : color or 2D array, optional + Streamline color. +cmap, norm : optional + Colormap and normalization for array colors. +arrowsize : float, optional + Arrow size scaling. +arrowstyle : str, optional + Arrow style specification. +transform : optional + Matplotlib transform. +zorder : float, optional + Z-order for lines/arrows. +start_points : (N, 2) array, optional + Starting points for streamlines. +integration_direction : {'forward', 'backward', 'both'}, optional + Direction to integrate streamlines. +broken_streamlines : bool, optional + If True, streamlines may terminate early. + +Returns +------- +StreamplotSet + Container with attributes: + - lines: LineCollection of streamlines + - arrows: PatchCollection of arrows +""" +docstring._snippet_manager["plot.curved_quiver"] = _args_2d_shared_docstring # Auto colorbar and legend docstring _guide_docstring = """ @@ -1493,28 +1540,648 @@ def _inside_seaborn_call(): return False +# The following helper classes and functions for curved_quiver are based on the +# work in the `dfm_tools` repository. +# Original file: https://github.com/Deltares/dfm_tools/blob/829e76f48ebc42460aae118cc190147a595a5f26/dfm_tools/modplot.py + + +class _TerminateTrajectory(Exception): + pass + + +from dataclasses import dataclass + + +@dataclass +class CurvedQuiverSet(StreamplotSet): + lines: object + arrows: object + + +class _DomainMap(object): + """Map representing different coordinate systems. + + Coordinate definitions: + * axes-coordinates goes from 0 to 1 in the domain. + * data-coordinates are specified by the input x-y coordinates. + * grid-coordinates goes from 0 to N and 0 to M for an N x M grid, + where N and M match the shape of the input data. + * mask-coordinates goes from 0 to N and 0 to M for an N x M mask, + where N and M are user-specified to control the density of + streamlines. + + This class also has methods for adding trajectories to the + StreamMask. Before adding a trajectory, run `start_trajectory` to + keep track of regions crossed by a given trajectory. Later, if you + decide the trajectory is bad (e.g., if the trajectory is very + short) just call `undo_trajectory`. + """ + + def __init__(self, grid, mask): + self.grid = grid + self.mask = mask + + # Constants for conversion between grid- and mask-coordinates + self.x_grid2mask = (mask.nx - 1) / grid.nx + self.y_grid2mask = (mask.ny - 1) / grid.ny + self.x_mask2grid = 1.0 / self.x_grid2mask + self.y_mask2grid = 1.0 / self.y_grid2mask + + self.x_data2grid = 1.0 / grid.dx + self.y_data2grid = 1.0 / grid.dy + + def grid2mask(self, xi, yi): + """Return nearest space in mask-coords from given grid-coords.""" + return (int((xi * self.x_grid2mask) + 0.5), int((yi * self.y_grid2mask) + 0.5)) + + def mask2grid(self, xm, ym): + return xm * self.x_mask2grid, ym * self.y_mask2grid + + def data2grid(self, xd, yd): + return xd * self.x_data2grid, yd * self.y_data2grid + + def grid2data(self, xg, yg): + return xg / self.x_data2grid, yg / self.y_data2grid + + def start_trajectory(self, xg, yg): + xm, ym = self.grid2mask(xg, yg) + self.mask._start_trajectory(xm, ym) + + def reset_start_point(self, xg, yg): + xm, ym = self.grid2mask(xg, yg) + self.mask._current_xy = (xm, ym) + + def update_trajectory(self, xg, yg): + xm, ym = self.grid2mask(xg, yg) + # self.mask._update_trajectory(xm, ym) + + def undo_trajectory(self): + self.mask._undo_trajectory() + + +class _Grid(object): + """Grid of data.""" + + def __init__(self, x, y): + if x.ndim == 1: + pass + elif x.ndim == 2: + x_row = x[0, :] + if not np.allclose(x_row, x): + raise ValueError("The rows of 'x' must be equal") + x = x_row + else: + raise ValueError("'x' can have at maximum 2 dimensions") + + if y.ndim == 1: + pass + elif y.ndim == 2: + y_col = y[:, 0] + if not np.allclose(y_col, y.T): + raise ValueError("The columns of 'y' must be equal") + y = y_col + else: + raise ValueError("'y' can have at maximum 2 dimensions") + + self.nx = len(x) + self.ny = len(y) + self.dx = x[1] - x[0] + self.dy = y[1] - y[0] + self.x_origin = x[0] + self.y_origin = y[0] + self.width = x[-1] - x[0] + self.height = y[-1] - y[0] + + @property + def shape(self): + return self.ny, self.nx + + def within_grid(self, xi, yi): + """Return True if point is a valid index of grid.""" + # Note that xi/yi can be floats; so, for example, we can't simply check + # `xi < self.nx` since `xi` can be `self.nx - 1 < xi < self.nx` + return xi >= 0 and xi <= self.nx - 1 and yi >= 0 and yi <= self.ny - 1 + + +class _StreamMask(object): + """Mask to keep track of discrete regions crossed by streamlines. + + The resolution of this grid determines the approximate spacing + between trajectories. Streamlines are only allowed to pass through + zeroed cells: When a streamline enters a cell, that cell is set to + 1, and no new streamlines are allowed to enter. + """ + + def __init__(self, density): + if np.isscalar(density): + if density <= 0: + raise ValueError("If a scalar, 'density' must be positive") + self.nx = self.ny = int(30 * density) + else: + if len(density) != 2: + raise ValueError("'density' can have at maximum 2 dimensions") + self.nx = int(30 * density[0]) + self.ny = int(30 * density[1]) + + self._mask = np.zeros((self.ny, self.nx)) + self.shape = self._mask.shape + self._current_xy = None + + def __getitem__(self, *args): + return self._mask.__getitem__(*args) + + def _start_trajectory(self, xm, ym): + """Start recording streamline trajectory""" + self._traj = [] + self._update_trajectory(xm, ym) + + def _undo_trajectory(self): + """Remove current trajectory from mask""" + for t in self._traj: + self._mask.__setitem__(t, 0) + + def _update_trajectory(self, xm, ym): + """Update current trajectory position in mask. + + If the new position has already been filled, raise + `InvalidIndexError`. + """ + # if self._current_xy != (xm, ym): + # if self[ym, xm] == 0: + self._traj.append((ym, xm)) + self._mask[ym, xm] = 1 + self._current_xy = (xm, ym) + # else: + # raise InvalidIndexError + + +def _get_integrator(u, v, dmap, minlength, resolution, magnitude): + # rescale velocity onto grid-coordinates for integrations. + u, v = dmap.data2grid(u, v) + + # speed (path length) will be in axes-coordinates + u_ax = u / dmap.grid.nx + v_ax = v / dmap.grid.ny + speed = np.ma.sqrt(u_ax**2 + v_ax**2) + + def forward_time(xi, yi): + ds_dt = _interpgrid(speed, xi, yi) + if ds_dt == 0: + raise TerminateTrajectory() + dt_ds = 1.0 / ds_dt + ui = _interpgrid(u, xi, yi) + vi = _interpgrid(v, xi, yi) + return ui * dt_ds, vi * dt_ds + + def integrate(x0, y0): + """Return x, y grid-coordinates of trajectory based on starting point. + + Integrate both forward and backward in time from starting point + in grid coordinates. Integration is terminated when a trajectory + reaches a domain boundary or when it crosses into an already + occupied cell in the StreamMask. The resulting trajectory is + None if it is shorter than `minlength`. + """ + stotal, x_traj, y_traj = 0.0, [], [] + dmap.start_trajectory(x0, y0) + dmap.reset_start_point(x0, y0) + stotal, x_traj, y_traj, m_total, hit_edge = _integrate_rk12( + x0, y0, dmap, forward_time, resolution, magnitude + ) + + if len(x_traj) > 1: + return (x_traj, y_traj), hit_edge + else: + # reject short trajectories + dmap.undo_trajectory() + return None + + return integrate + + +def _integrate_rk12(x0, y0, dmap, f, resolution, magnitude): + """2nd-order Runge-Kutta algorithm with adaptive step size. + + This method is also referred to as the improved Euler's method, or + Heun's method. This method is favored over higher-order methods + because: + + 1. To get decent looking trajectories and to sample every mask cell + on the trajectory we need a small timestep, so a lower order + solver doesn't hurt us unless the data is *very* high + resolution. In fact, for cases where the user inputs data + smaller or of similar grid size to the mask grid, the higher + order corrections are negligible because of the very fast linear + interpolation used in `interpgrid`. + + 2. For high resolution input data (i.e. beyond the mask + resolution), we must reduce the timestep. Therefore, an + adaptive timestep is more suited to the problem as this would be + very hard to judge automatically otherwise. + + This integrator is about 1.5 - 2x as fast as both the RK4 and RK45 + solvers in most setups on my machine. I would recommend removing + the other two to keep things simple. + """ + # This error is below that needed to match the RK4 integrator. It + # is set for visual reasons -- too low and corners start + # appearing ugly and jagged. Can be tuned. + maxerror = 0.003 + + # This limit is important (for all integrators) to avoid the + # trajectory skipping some mask cells. We could relax this + # condition if we use the code which is commented out below to + # increment the location gradually. However, due to the efficient + # nature of the interpolation, this doesn't boost speed by much + # for quite a bit of complexity. + maxds = min(1.0 / dmap.mask.nx, 1.0 / dmap.mask.ny, 0.1) + ds = maxds + + stotal = 0 + xi = x0 + yi = y0 + xf_traj = [] + yf_traj = [] + m_total = [] + hit_edge = False + + while dmap.grid.within_grid(xi, yi): + xf_traj.append(xi) + yf_traj.append(yi) + m_total.append(_interpgrid(magnitude, xi, yi)) + + try: + k1x, k1y = f(xi, yi) + k2x, k2y = f(xi + ds * k1x, yi + ds * k1y) + except IndexError: + # Out of the domain on one of the intermediate integration steps. + # Take an Euler step to the boundary to improve neatness. + ds, xf_traj, yf_traj = _euler_step(xf_traj, yf_traj, dmap, f) + stotal += ds + hit_edge = True + break + except TerminateTrajectory: + break + + dx1 = ds * k1x + dy1 = ds * k1y + dx2 = ds * 0.5 * (k1x + k2x) + dy2 = ds * 0.5 * (k1y + k2y) + + nx, ny = dmap.grid.shape + # Error is normalized to the axes coordinates + error = np.sqrt(((dx2 - dx1) / nx) ** 2 + ((dy2 - dy1) / ny) ** 2) + + # Only save step if within error tolerance + if error < maxerror: + xi += dx2 + yi += dy2 + dmap.update_trajectory(xi, yi) + if not dmap.grid.within_grid(xi, yi): + hit_edge = True + if (stotal + ds) > resolution * np.mean(m_total): + break + stotal += ds + + # recalculate stepsize based on step error + if error == 0: + ds = maxds + else: + ds = min(maxds, 0.85 * ds * (maxerror / error) ** 0.5) + + return stotal, xf_traj, yf_traj, m_total, hit_edge + + +def _euler_step(xf_traj, yf_traj, dmap, f): + """Simple Euler integration step that extends streamline to boundary.""" + ny, nx = dmap.grid.shape + xi = xf_traj[-1] + yi = yf_traj[-1] + cx, cy = f(xi, yi) + + if cx == 0: + dsx = np.inf + elif cx < 0: + dsx = xi / -cx + else: + dsx = (nx - 1 - xi) / cx + + if cy == 0: + dsy = np.inf + elif cy < 0: + dsy = yi / -cy + else: + dsy = (ny - 1 - yi) / cy + + ds = min(dsx, dsy) + + xf_traj.append(xi + cx * ds) + yf_traj.append(yi + cy * ds) + + return ds, xf_traj, yf_traj + + +def _interpgrid(a, xi, yi): + """Fast 2D, linear interpolation on an integer grid""" + Ny, Nx = np.shape(a) + + if isinstance(xi, np.ndarray): + x = xi.astype(int) + y = yi.astype(int) + + # Check that xn, yn don't exceed max index + xn = np.clip(x + 1, 0, Nx - 1) + yn = np.clip(y + 1, 0, Ny - 1) + else: + x = int(xi) + y = int(yi) + xn = min(x + 1, Nx - 1) + yn = min(y + 1, Ny - 1) + + a00 = a[y, x] + a01 = a[y, xn] + a10 = a[yn, x] + a11 = a[yn, xn] + + xt = xi - x + yt = yi - y + + a0 = a00 * (1 - xt) + a01 * xt + a1 = a10 * (1 - xt) + a11 * xt + ai = a0 * (1 - yt) + a1 * yt + + if not isinstance(xi, np.ndarray): + if np.ma.is_masked(ai): + raise TerminateTrajectory + return ai + + +def _gen_starting_points(x, y, grains): + eps = np.finfo(np.float32).eps + tmp_x = np.linspace(x.min() + eps, x.max() - eps, grains) + tmp_y = np.linspace(y.min() + eps, y.max() - eps, grains) + xs = np.tile(tmp_x, grains) + ys = np.repeat(tmp_y, grains) + seed_points = np.array([list(xs), list(ys)]) + return seed_points.T + + class PlotAxes(base.Axes): """ The second lowest-level `~matplotlib.axes.Axes` subclass used by ultraplot. Implements all plotting overrides. """ - def __init__(self, *args, **kwargs): + @docstring._snippet_manager + def curved_quiver( + self, + x, + y, + u, + v, + linewidth=None, + color=None, + cmap=None, + norm=None, + arrowsize=1, + arrowstyle="-|>", + transform=None, + zorder=None, + start_points=None, + scale=1.0, + grains=15, + density=10, + arrow_at_end=True, + ): + """ + %(plot.curved_quiver)s + + Notes + ----- + The implementation of this function is based on the `dfm_tools` repository. + Original file: https://github.com/Deltares/dfm_tools/blob/829e76f48ebc42460aae118cc190147a595a5f26/dfm_tools/modplot.py """ - Parameters - ---------- - *args, **kwargs - Passed to `ultraplot.axes.Axes`. + grid = _Grid(x, y) + mask = _StreamMask(density) + dmap = _DomainMap(grid, mask) + + if zorder is None: + zorder = mlines.Line2D.zorder + + # default to data coordinates + if transform is None: + transform = self.transData + + if color is None: + color = self._get_lines.get_next_color() + + if linewidth is None: + linewidth = rc["lines.linewidth"] + + line_kw = {} + arrow_kw = dict(arrowstyle=arrowstyle, mutation_scale=10 * arrowsize) + + use_multicolor_lines = isinstance(color, np.ndarray) + if use_multicolor_lines: + if color.shape != grid.shape: + raise ValueError( + "If 'color' is given, must have the shape of 'Grid(x,y)'" + ) + line_colors = [] + color = np.ma.masked_invalid(color) + else: + line_kw["color"] = color + arrow_kw["color"] = color + + if isinstance(linewidth, np.ndarray): + if linewidth.shape != grid.shape: + raise ValueError( + "If 'linewidth' is given, must have the shape of 'Grid(x,y)'" + ) + line_kw["linewidth"] = [] + else: + line_kw["linewidth"] = linewidth + arrow_kw["linewidth"] = linewidth + + line_kw["zorder"] = zorder + arrow_kw["zorder"] = zorder + + ## Sanity checks. + if u.shape != grid.shape or v.shape != grid.shape: + raise ValueError("'u' and 'v' must be of shape 'Grid(x,y)'") + + u = np.ma.masked_invalid(u) + v = np.ma.masked_invalid(v) + magnitude = np.sqrt(u**2 + v**2) + magnitude /= np.max(magnitude) + + resolution = scale / grains + minlength = 0.9 * resolution + + integrate = _get_integrator(u, v, dmap, minlength, resolution, magnitude) + + trajectories = [] + edges = [] + + if start_points is None: + start_points = _gen_starting_points(x, y, grains) + + sp2 = np.asanyarray(start_points, dtype=float).copy() + + # Check if start_points are outside the data boundaries + for xs, ys in sp2: + if not ( + grid.x_origin <= xs <= grid.x_origin + grid.width + and grid.y_origin <= ys <= grid.y_origin + grid.height + ): + raise ValueError( + "Starting point ({}, {}) outside of data " + "boundaries".format(xs, ys) + ) + + # Convert start_points from data to array coords + # Shift the seed points from the bottom left of the data so that + # data2grid works properly. + sp2[:, 0] -= grid.x_origin + sp2[:, 1] -= grid.y_origin + + for xs, ys in sp2: + xg, yg = dmap.data2grid(xs, ys) + t = integrate(xg, yg) + if t is not None: + trajectories.append(t[0]) + edges.append(t[1]) + + if use_multicolor_lines: + if norm is None: + norm = mcolors.Normalize(color.min(), color.max()) + if cmap is None: + cmap = constructor.Colormap(rc["image.cmap"]) + else: + cmap = mcm.get_cmap(cmap) + + streamlines = [] + arrows = [] + for t, edge in zip(trajectories, edges): + tgx = np.array(t[0]) + tgy = np.array(t[1]) + + # Rescale from grid-coordinates to data-coordinates. + tx, ty = dmap.grid2data(*np.array(t)) + tx += grid.x_origin + ty += grid.y_origin + + points = np.transpose([tx, ty]).reshape(-1, 1, 2) + streamlines.extend(np.hstack([points[:-1], points[1:]])) + + if len(tx) < 2: + continue + + # Add arrows + s = np.cumsum(np.sqrt(np.diff(tx) ** 2 + np.diff(ty) ** 2)) + if arrow_at_end: + if len(tx) < 2: + continue + + arrow_tail = (tx[-1], ty[-1]) + + # Extrapolate to find arrow head + xg, yg = dmap.data2grid(tx[-1] - grid.x_origin, ty[-1] - grid.y_origin) + + ui = _interpgrid(u, xg, yg) + vi = _interpgrid(v, xg, yg) + + norm_v = np.sqrt(ui**2 + vi**2) + if norm_v > 0: + ui /= norm_v + vi /= norm_v + + if len(s) > 0: + # use average segment length + arrow_length = arrowsize * (s[-1] / len(s)) + else: + # fallback for very short streamlines + arrow_length = arrowsize * 0.1 * np.mean([grid.dx, grid.dy]) + + arrow_head = (tx[-1] + ui * arrow_length, ty[-1] + vi * arrow_length) + n = len(s) - 1 if len(s) > 0 else 0 + else: + n = np.searchsorted(s, s[-1] / 2.0) + arrow_tail = (tx[n], ty[n]) + arrow_head = (np.mean(tx[n : n + 2]), np.mean(ty[n : n + 2])) + + if isinstance(linewidth, np.ndarray): + line_widths = _interpgrid(linewidth, tgx, tgy)[:-1] + line_kw["linewidth"].extend(line_widths) + arrow_kw["linewidth"] = line_widths[n] + + if use_multicolor_lines: + color_values = _interpgrid(color, tgx, tgy)[:-1] + line_colors.append(color_values) + arrow_kw["color"] = cmap(norm(color_values[n])) + + if not edge: + p = mpatches.FancyArrowPatch( + arrow_tail, arrow_head, transform=transform, **arrow_kw + ) + else: + continue + + ds = np.sqrt( + (arrow_tail[0] - arrow_head[0]) ** 2 + + (arrow_tail[1] - arrow_head[1]) ** 2 + ) + if ds < 1e-15: + continue # remove vanishingly short arrows that cause Patch to fail + + self.add_patch(p) + arrows.append(p) + + lc = mcollections.LineCollection(streamlines, transform=transform, **line_kw) + lc.sticky_edges.x[:] = [grid.x_origin, grid.x_origin + grid.width] + lc.sticky_edges.y[:] = [grid.y_origin, grid.y_origin + grid.height] - See also - -------- - matplotlib.axes.Axes - ultraplot.axes.Axes - ultraplot.axes.CartesianAxes - ultraplot.axes.PolarAxes - ultraplot.axes.GeoAxes + if use_multicolor_lines: + lc.set_array(np.ma.hstack(line_colors)) + lc.set_cmap(cmap) + lc.set_norm(norm) + + self.add_collection(lc) + self.autoscale_view() + + ac = mcollections.PatchCollection(arrows) + stream_container = CurvedQuiverSet(lc, ac) + return stream_container + + + def _add_plot_elements( + self, + streamlines, + arrows, + grid, + line_kw, + arrow_kw, + use_multicolor_lines, + line_colors, + cmap, + norm, + transform, + ): """ - super().__init__(*args, **kwargs) + Add line and arrow collections to the axes, set sticky edges, and return the container. + """ + lc = mcollections.LineCollection(streamlines, transform=transform, **line_kw) + lc.sticky_edges.x[:] = [grid.x_origin, grid.x_origin + grid.width] + lc.sticky_edges.y[:] = [grid.y_origin, grid.y_origin + grid.height] + if use_multicolor_lines: + lc.set_array(np.ma.hstack(line_colors)) + lc.set_cmap(cmap) + lc.set_norm(norm) + self.add_collection(lc) + self.autoscale_view() + ac = mcollections.PatchCollection(arrows) + stream_container = CurvedQuiverSet(lc, ac) + return stream_container def _call_native(self, name, *args, **kwargs): """ @@ -5359,6 +6026,8 @@ def tripcolor(self, *args, **kwargs): # Update kwargs and handle cmap kw.update(_pop_props(kw, "collection")) + + center_levels = kw.pop("center_levels", None) kw = self._parse_cmap( triangulation.x, triangulation.y, z, center_levels=center_levels, **kw @@ -5481,3 +6150,92 @@ def _iter_arg_cols(self, *args, label=None, labels=None, values=None, **kwargs): # Rename the shorthands boxes = warnings._rename_objs("0.8.0", boxes=box) violins = warnings._rename_objs("0.8.0", violins=violin) + +# +def _setup_grid_and_mask(x, y, density): + """ + Helper for `curved_quiver`. + + Initializes and returns the grid, stream mask, and domain map objects for the vector field. + """ + grid = _Grid(x, y) + mask = _StreamMask(density) + dmap = _DomainMap(grid, mask) + return grid, mask, dmap + +def _validate_vector_shapes(u, v, grid): + """ + Helper for `curved_quiver`. + + Validates that the shapes of `u` and `v` match the grid shape. Raises ValueError if not. + """ + if u.shape != grid.shape or v.shape != grid.shape: + raise ValueError("'u' and 'v' must be of shape 'Grid(x,y)'") + +def _normalize_magnitude(u, v): + """ + Helper for `curved_quiver`. + + Computes and returns the normalized magnitude array for the vector field. + """ + u = np.ma.masked_invalid(u) + v = np.ma.masked_invalid(v) + magnitude = np.sqrt(u**2 + v**2) + magnitude /= np.max(magnitude) + return magnitude + +def _generate_start_points(x, y, grains, start_points, grid): + """ + Helper for `curved_quiver`. + + Generates or validates starting points for streamlines, ensuring they are within grid boundaries. + Returns points in grid coordinates. + """ + if start_points is None: + start_points = _gen_starting_points(x, y, grains) + sp2 = np.asanyarray(start_points, dtype=float).copy() + for xs, ys in sp2: + if not ( + grid.x_origin <= xs <= grid.x_origin + grid.width + and grid.y_origin <= ys <= grid.y_origin + grid.height + ): + raise ValueError( + "Starting point ({}, {}) outside of data boundaries".format(xs, ys) + ) + sp2[:, 0] -= grid.x_origin + sp2[:, 1] -= grid.y_origin + return sp2 + +def _calculate_trajectories(sp2, dmap, integrate): + """ + Helper for `curved_quiver`. + + Integrates trajectories from each starting point using the provided integrator. + Returns lists of trajectories and edges. + """ + trajectories = [] + edges = [] + for xs, ys in sp2: + xg, yg = dmap.data2grid(xs, ys) + t = integrate(xg, yg) + if t is not None: + trajectories.append(t[0]) + edges.append(t[1]) + return trajectories, edges + +def _handle_multicolor_lines(color, norm, cmap, grid): + """ + Helper for `curved_quiver`. + + Prepares colormap and normalization for multicolor lines. + Returns updated color, norm, cmap, and a list for line colors. + """ + line_colors = [] + if norm is None: + norm = mcolors.Normalize(color.min(), color.max()) + if cmap is None: + cmap = constructor.Colormap(rc["image.cmap"]) + else: + cmap = mcm.get_cmap(cmap) + color = np.ma.masked_invalid(color) + return color, norm, cmap, line_colors diff --git a/ultraplot/tests/test_plot.py b/ultraplot/tests/test_plot.py index 9f43a144b..a7ee6ec10 100644 --- a/ultraplot/tests/test_plot.py +++ b/ultraplot/tests/test_plot.py @@ -447,3 +447,35 @@ def test_inhomogeneous_violin(rng): for violin in violins: assert violin.get_paths() # Ensure paths are created return fig + + +@pytest.mark.mpl_image_compare +def test_curved_quiver(rng): + # Create a grid + x = np.linspace(-4, 4, 20) + y = np.linspace(-3, 3, 20) + X, Y = np.meshgrid(x, y) + + # Define a rotational vector field (circular flow) + U = -Y + V = X + speed = np.sqrt(U**2 + V**2) + + # Create a figure and axes + fig, axs = uplt.subplots(ncols=3, sharey=True, figsize=(12, 4)) + + # Left plot: matplotlib's streamplot + axs[0].streamplot(X, Y, U, V, color=speed) + axs[0].set_title("streamplot (native)") + + # Middle plot: quiver + axs[1].quiver(X, Y, U, V, speed) + axs[1].set_title("quiver") + + # Right plot: curved_quiver + m = axs[2].curved_quiver( + X, Y, U, V, color=speed, arrow_at_end=True, scale=2.0, grains=10 + ) + axs[2].set_title("curved_quiver") + fig.colorbar(m.lines, ax=axs[:], label="speed") + return fig From 5b0a0eb96bb6cc9d7de2a77f0d8d4019da28197d Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Mon, 6 Oct 2025 17:02:40 +0200 Subject: [PATCH 02/33] add unittests --- ultraplot/tests/test_plot.py | 90 ++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/ultraplot/tests/test_plot.py b/ultraplot/tests/test_plot.py index a7ee6ec10..47703a580 100644 --- a/ultraplot/tests/test_plot.py +++ b/ultraplot/tests/test_plot.py @@ -479,3 +479,93 @@ def test_curved_quiver(rng): axs[2].set_title("curved_quiver") fig.colorbar(m.lines, ax=axs[:], label="speed") return fig + + +def test_setup_grid_and_mask(): + x = np.linspace(0, 1, 5) + y = np.linspace(0, 1, 5) + grid, mask, dmap = uplt.axes.plot._setup_grid_and_mask(x, y, density=5) + assert grid.shape == (5, 5) + assert hasattr(mask, "shape") + assert hasattr(dmap, "grid") + assert hasattr(dmap, "mask") + + +def test_validate_vector_shapes_pass(): + x = np.linspace(0, 1, 3) + y = np.linspace(0, 1, 3) + grid, _, _ = uplt.axes.plot._setup_grid_and_mask(x, y, density=3) + u = np.ones(grid.shape) + v = np.ones(grid.shape) + # Should not raise + uplt.axes.plot._validate_vector_shapes(u, v, grid) + + +def test_validate_vector_shapes_fail(): + x = np.linspace(0, 1, 3) + y = np.linspace(0, 1, 3) + grid, _, _ = uplt.axes.plot._setup_grid_and_mask(x, y, density=3) + u = np.ones((2, 2)) + v = np.ones(grid.shape) + import pytest + + with pytest.raises(ValueError): + uplt.axes.plot._validate_vector_shapes(u, v, grid) + + +def test_normalize_magnitude(): + u = np.array([[1, 2], [3, 4]]) + v = np.array([[4, 3], [2, 1]]) + mag = uplt.axes.plot._normalize_magnitude(u, v) + assert np.allclose(np.max(mag), 1.0) + assert mag.shape == u.shape + + +def test_generate_start_points(): + x = np.linspace(0, 1, 5) + y = np.linspace(0, 1, 5) + grid, _, _ = uplt.axes.plot._setup_grid_and_mask(x, y, density=5) + sp2 = uplt.axes.plot._generate_start_points( + x, y, grains=5, start_points=None, grid=grid + ) + assert sp2.shape[1] == 2 + # Should raise if outside boundaries + import pytest + + bad_points = np.array([[10, 10]]) + with pytest.raises(ValueError): + uplt.axes.plot._generate_start_points( + x, y, grains=5, start_points=bad_points, grid=grid + ) + + +def test_calculate_trajectories(): + # Use a dummy integrator that returns a fixed trajectory + def dummy_integrate(xg, yg): + return ([np.array([xg, xg + 1]), np.array([yg, yg + 1])], False) + + sp2 = np.array([[0, 0], [1, 1]]) + + class DummyDMap: + def data2grid(self, xs, ys): + return xs, ys + + trajectories, edges = uplt.axes.plot._calculate_trajectories( + sp2, DummyDMap(), dummy_integrate + ) + assert len(trajectories) == 2 + assert len(edges) == 2 + + +def test_handle_multicolor_lines(): + color = np.array([[0, 1], [2, 3]]) + norm = None + cmap = None + grid = mock.Mock() + out_color, out_norm, out_cmap, line_colors = ( + uplt.axes.plot._handle_multicolor_lines(color, norm, cmap, grid) + ) + assert out_color.shape == color.shape + assert hasattr(out_norm, "autoscale") + assert hasattr(out_cmap, "__call__") + assert isinstance(line_colors, list) From e4b1023276deb91899a40c4cdb6ad8e94a889c84 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Mon, 6 Oct 2025 17:04:51 +0200 Subject: [PATCH 03/33] update docstrings --- ultraplot/axes/plot.py | 10 ++-------- ultraplot/tests/test_plot.py | 25 +++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index 3121456db..638fd8beb 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -182,7 +182,7 @@ Grid coordinates. u, v : 2D arrays Vector components. -c : color or 2D array, optional +color : color or 2D array, optional Streamline color. density : float or (float, float), optional Controls the closeness of streamlines. @@ -190,8 +190,6 @@ Number of seed points in x and y. linewidth : float or 2D array, optional Width of streamlines. -color : color or 2D array, optional - Streamline color. cmap, norm : optional Colormap and normalization for array colors. arrowsize : float, optional @@ -204,14 +202,10 @@ Z-order for lines/arrows. start_points : (N, 2) array, optional Starting points for streamlines. -integration_direction : {'forward', 'backward', 'both'}, optional - Direction to integrate streamlines. -broken_streamlines : bool, optional - If True, streamlines may terminate early. Returns ------- -StreamplotSet +CurvedQuiverSet Container with attributes: - lines: LineCollection of streamlines - arrows: PatchCollection of arrows diff --git a/ultraplot/tests/test_plot.py b/ultraplot/tests/test_plot.py index 47703a580..e6330f308 100644 --- a/ultraplot/tests/test_plot.py +++ b/ultraplot/tests/test_plot.py @@ -482,6 +482,10 @@ def test_curved_quiver(rng): def test_setup_grid_and_mask(): + """ + Test that _setup_grid_and_mask creates grid, mask, and domain map objects + with expected attributes and shapes for a simple input. + """ x = np.linspace(0, 1, 5) y = np.linspace(0, 1, 5) grid, mask, dmap = uplt.axes.plot._setup_grid_and_mask(x, y, density=5) @@ -492,6 +496,9 @@ def test_setup_grid_and_mask(): def test_validate_vector_shapes_pass(): + """ + Test that _validate_vector_shapes passes when u and v match the grid shape. + """ x = np.linspace(0, 1, 3) y = np.linspace(0, 1, 3) grid, _, _ = uplt.axes.plot._setup_grid_and_mask(x, y, density=3) @@ -502,6 +509,9 @@ def test_validate_vector_shapes_pass(): def test_validate_vector_shapes_fail(): + """ + Test that _validate_vector_shapes raises ValueError when u and v do not match the grid shape. + """ x = np.linspace(0, 1, 3) y = np.linspace(0, 1, 3) grid, _, _ = uplt.axes.plot._setup_grid_and_mask(x, y, density=3) @@ -514,6 +524,9 @@ def test_validate_vector_shapes_fail(): def test_normalize_magnitude(): + """ + Test that _normalize_magnitude returns a normalized array with max value 1.0 and correct shape. + """ u = np.array([[1, 2], [3, 4]]) v = np.array([[4, 3], [2, 1]]) mag = uplt.axes.plot._normalize_magnitude(u, v) @@ -522,6 +535,10 @@ def test_normalize_magnitude(): def test_generate_start_points(): + """ + Test that _generate_start_points returns valid grid coordinates for seed points, + and raises ValueError for points outside the grid boundaries. + """ x = np.linspace(0, 1, 5) y = np.linspace(0, 1, 5) grid, _, _ = uplt.axes.plot._setup_grid_and_mask(x, y, density=5) @@ -540,6 +557,11 @@ def test_generate_start_points(): def test_calculate_trajectories(): + """ + Test that _calculate_trajectories calls the integrator for each seed point + and returns lists of trajectories and edges of correct length. + """ + # Use a dummy integrator that returns a fixed trajectory def dummy_integrate(xg, yg): return ([np.array([xg, xg + 1]), np.array([yg, yg + 1])], False) @@ -558,6 +580,9 @@ def data2grid(self, xs, ys): def test_handle_multicolor_lines(): + """ + Test that _handle_multicolor_lines returns masked color array, norm, cmap, and an empty line_colors list. + """ color = np.array([[0, 1], [2, 3]]) norm = None cmap = None From ce7bb374de43a651a9a20071fde736e1b20b6a8c Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Mon, 6 Oct 2025 17:07:51 +0200 Subject: [PATCH 04/33] black formatting --- ultraplot/axes/plot.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index 638fd8beb..5d1e460e0 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -2147,7 +2147,6 @@ def curved_quiver( stream_container = CurvedQuiverSet(lc, ac) return stream_container - def _add_plot_elements( self, streamlines, @@ -6021,7 +6020,6 @@ def tripcolor(self, *args, **kwargs): # Update kwargs and handle cmap kw.update(_pop_props(kw, "collection")) - center_levels = kw.pop("center_levels", None) kw = self._parse_cmap( triangulation.x, triangulation.y, z, center_levels=center_levels, **kw @@ -6145,6 +6143,7 @@ def _iter_arg_cols(self, *args, label=None, labels=None, values=None, **kwargs): boxes = warnings._rename_objs("0.8.0", boxes=box) violins = warnings._rename_objs("0.8.0", violins=violin) + # def _setup_grid_and_mask(x, y, density): """ @@ -6157,6 +6156,7 @@ def _setup_grid_and_mask(x, y, density): dmap = _DomainMap(grid, mask) return grid, mask, dmap + def _validate_vector_shapes(u, v, grid): """ Helper for `curved_quiver`. @@ -6166,6 +6166,7 @@ def _validate_vector_shapes(u, v, grid): if u.shape != grid.shape or v.shape != grid.shape: raise ValueError("'u' and 'v' must be of shape 'Grid(x,y)'") + def _normalize_magnitude(u, v): """ Helper for `curved_quiver`. @@ -6178,6 +6179,7 @@ def _normalize_magnitude(u, v): magnitude /= np.max(magnitude) return magnitude + def _generate_start_points(x, y, grains, start_points, grid): """ Helper for `curved_quiver`. @@ -6200,6 +6202,7 @@ def _generate_start_points(x, y, grains, start_points, grid): sp2[:, 1] -= grid.y_origin return sp2 + def _calculate_trajectories(sp2, dmap, integrate): """ Helper for `curved_quiver`. @@ -6217,6 +6220,7 @@ def _calculate_trajectories(sp2, dmap, integrate): edges.append(t[1]) return trajectories, edges + def _handle_multicolor_lines(color, norm, cmap, grid): """ Helper for `curved_quiver`. From 14b7ce621461f4cdaffcbd3b60ace7a8d7ea5a6c Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Mon, 6 Oct 2025 17:08:34 +0200 Subject: [PATCH 05/33] rm dup import --- ultraplot/axes/plot.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index 5d1e460e0..ac2c3d487 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -36,8 +36,6 @@ import numpy as np import numpy.ma as ma -from matplotlib.streamplot import StreamplotSet - from .. import colors as pcolors from .. import constructor, utils from ..config import rc From 7b1dc1948b7693524c215db624d5f2b7350c849f Mon Sep 17 00:00:00 2001 From: Casper van Elteren Date: Mon, 6 Oct 2025 17:11:27 +0200 Subject: [PATCH 06/33] Update ultraplot/axes/plot.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- ultraplot/axes/plot.py | 29 ----------------------------- 1 file changed, 29 deletions(-) diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index ac2c3d487..50930d802 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -2145,35 +2145,6 @@ def curved_quiver( stream_container = CurvedQuiverSet(lc, ac) return stream_container - def _add_plot_elements( - self, - streamlines, - arrows, - grid, - line_kw, - arrow_kw, - use_multicolor_lines, - line_colors, - cmap, - norm, - transform, - ): - """ - Add line and arrow collections to the axes, set sticky edges, and return the container. - """ - lc = mcollections.LineCollection(streamlines, transform=transform, **line_kw) - lc.sticky_edges.x[:] = [grid.x_origin, grid.x_origin + grid.width] - lc.sticky_edges.y[:] = [grid.y_origin, grid.y_origin + grid.height] - if use_multicolor_lines: - lc.set_array(np.ma.hstack(line_colors)) - lc.set_cmap(cmap) - lc.set_norm(norm) - self.add_collection(lc) - self.autoscale_view() - ac = mcollections.PatchCollection(arrows) - stream_container = CurvedQuiverSet(lc, ac) - return stream_container - def _call_native(self, name, *args, **kwargs): """ Call the plotting method and redirect internal calls to native methods. From 578dd2a047efba2a581bec9d8692032911d644ac Mon Sep 17 00:00:00 2001 From: Casper van Elteren Date: Mon, 6 Oct 2025 17:11:35 +0200 Subject: [PATCH 07/33] Update ultraplot/axes/plot.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- ultraplot/axes/plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index 50930d802..50c95a227 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -1719,7 +1719,7 @@ def _get_integrator(u, v, dmap, minlength, resolution, magnitude): def forward_time(xi, yi): ds_dt = _interpgrid(speed, xi, yi) if ds_dt == 0: - raise TerminateTrajectory() + raise _TerminateTrajectory() dt_ds = 1.0 / ds_dt ui = _interpgrid(u, xi, yi) vi = _interpgrid(v, xi, yi) From a3290c1644fcee3e9b4437475397caab5920031c Mon Sep 17 00:00:00 2001 From: Casper van Elteren Date: Mon, 6 Oct 2025 17:11:44 +0200 Subject: [PATCH 08/33] Update ultraplot/axes/plot.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- ultraplot/axes/plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index 50c95a227..9b605fe0f 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -1904,7 +1904,7 @@ def _interpgrid(a, xi, yi): if not isinstance(xi, np.ndarray): if np.ma.is_masked(ai): - raise TerminateTrajectory + raise _TerminateTrajectory return ai From 28b20a1bad63297a5365b01aea8425b73c79a3ca Mon Sep 17 00:00:00 2001 From: Casper van Elteren Date: Mon, 6 Oct 2025 17:11:53 +0200 Subject: [PATCH 09/33] Update ultraplot/tests/test_plot.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- ultraplot/tests/test_plot.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ultraplot/tests/test_plot.py b/ultraplot/tests/test_plot.py index e6330f308..baf650a51 100644 --- a/ultraplot/tests/test_plot.py +++ b/ultraplot/tests/test_plot.py @@ -547,7 +547,6 @@ def test_generate_start_points(): ) assert sp2.shape[1] == 2 # Should raise if outside boundaries - import pytest bad_points = np.array([[10, 10]]) with pytest.raises(ValueError): From e27e9e371726fe0bd8fa30303810be6ecba0fb63 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Mon, 6 Oct 2025 17:12:47 +0200 Subject: [PATCH 10/33] mv import up --- ultraplot/tests/test_plot.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ultraplot/tests/test_plot.py b/ultraplot/tests/test_plot.py index baf650a51..d412cce32 100644 --- a/ultraplot/tests/test_plot.py +++ b/ultraplot/tests/test_plot.py @@ -517,7 +517,6 @@ def test_validate_vector_shapes_fail(): grid, _, _ = uplt.axes.plot._setup_grid_and_mask(x, y, density=3) u = np.ones((2, 2)) v = np.ones(grid.shape) - import pytest with pytest.raises(ValueError): uplt.axes.plot._validate_vector_shapes(u, v, grid) From 8ab593bb94f824cf9f319bea5064fc507009e4df Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Tue, 7 Oct 2025 17:38:20 +0200 Subject: [PATCH 11/33] refactor and move down --- ultraplot/axes/plot.py | 897 ++++++++++++++++++----------------------- 1 file changed, 402 insertions(+), 495 deletions(-) diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index 9b605fe0f..49785295b 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -12,6 +12,7 @@ from typing import Any, Union, Iterable, Optional +from dataclasses import dataclass from typing import Any, Union from collections.abc import Callable from collections.abc import Iterable @@ -1532,391 +1533,6 @@ def _inside_seaborn_call(): return False -# The following helper classes and functions for curved_quiver are based on the -# work in the `dfm_tools` repository. -# Original file: https://github.com/Deltares/dfm_tools/blob/829e76f48ebc42460aae118cc190147a595a5f26/dfm_tools/modplot.py - - -class _TerminateTrajectory(Exception): - pass - - -from dataclasses import dataclass - - -@dataclass -class CurvedQuiverSet(StreamplotSet): - lines: object - arrows: object - - -class _DomainMap(object): - """Map representing different coordinate systems. - - Coordinate definitions: - * axes-coordinates goes from 0 to 1 in the domain. - * data-coordinates are specified by the input x-y coordinates. - * grid-coordinates goes from 0 to N and 0 to M for an N x M grid, - where N and M match the shape of the input data. - * mask-coordinates goes from 0 to N and 0 to M for an N x M mask, - where N and M are user-specified to control the density of - streamlines. - - This class also has methods for adding trajectories to the - StreamMask. Before adding a trajectory, run `start_trajectory` to - keep track of regions crossed by a given trajectory. Later, if you - decide the trajectory is bad (e.g., if the trajectory is very - short) just call `undo_trajectory`. - """ - - def __init__(self, grid, mask): - self.grid = grid - self.mask = mask - - # Constants for conversion between grid- and mask-coordinates - self.x_grid2mask = (mask.nx - 1) / grid.nx - self.y_grid2mask = (mask.ny - 1) / grid.ny - self.x_mask2grid = 1.0 / self.x_grid2mask - self.y_mask2grid = 1.0 / self.y_grid2mask - - self.x_data2grid = 1.0 / grid.dx - self.y_data2grid = 1.0 / grid.dy - - def grid2mask(self, xi, yi): - """Return nearest space in mask-coords from given grid-coords.""" - return (int((xi * self.x_grid2mask) + 0.5), int((yi * self.y_grid2mask) + 0.5)) - - def mask2grid(self, xm, ym): - return xm * self.x_mask2grid, ym * self.y_mask2grid - - def data2grid(self, xd, yd): - return xd * self.x_data2grid, yd * self.y_data2grid - - def grid2data(self, xg, yg): - return xg / self.x_data2grid, yg / self.y_data2grid - - def start_trajectory(self, xg, yg): - xm, ym = self.grid2mask(xg, yg) - self.mask._start_trajectory(xm, ym) - - def reset_start_point(self, xg, yg): - xm, ym = self.grid2mask(xg, yg) - self.mask._current_xy = (xm, ym) - - def update_trajectory(self, xg, yg): - xm, ym = self.grid2mask(xg, yg) - # self.mask._update_trajectory(xm, ym) - - def undo_trajectory(self): - self.mask._undo_trajectory() - - -class _Grid(object): - """Grid of data.""" - - def __init__(self, x, y): - if x.ndim == 1: - pass - elif x.ndim == 2: - x_row = x[0, :] - if not np.allclose(x_row, x): - raise ValueError("The rows of 'x' must be equal") - x = x_row - else: - raise ValueError("'x' can have at maximum 2 dimensions") - - if y.ndim == 1: - pass - elif y.ndim == 2: - y_col = y[:, 0] - if not np.allclose(y_col, y.T): - raise ValueError("The columns of 'y' must be equal") - y = y_col - else: - raise ValueError("'y' can have at maximum 2 dimensions") - - self.nx = len(x) - self.ny = len(y) - self.dx = x[1] - x[0] - self.dy = y[1] - y[0] - self.x_origin = x[0] - self.y_origin = y[0] - self.width = x[-1] - x[0] - self.height = y[-1] - y[0] - - @property - def shape(self): - return self.ny, self.nx - - def within_grid(self, xi, yi): - """Return True if point is a valid index of grid.""" - # Note that xi/yi can be floats; so, for example, we can't simply check - # `xi < self.nx` since `xi` can be `self.nx - 1 < xi < self.nx` - return xi >= 0 and xi <= self.nx - 1 and yi >= 0 and yi <= self.ny - 1 - - -class _StreamMask(object): - """Mask to keep track of discrete regions crossed by streamlines. - - The resolution of this grid determines the approximate spacing - between trajectories. Streamlines are only allowed to pass through - zeroed cells: When a streamline enters a cell, that cell is set to - 1, and no new streamlines are allowed to enter. - """ - - def __init__(self, density): - if np.isscalar(density): - if density <= 0: - raise ValueError("If a scalar, 'density' must be positive") - self.nx = self.ny = int(30 * density) - else: - if len(density) != 2: - raise ValueError("'density' can have at maximum 2 dimensions") - self.nx = int(30 * density[0]) - self.ny = int(30 * density[1]) - - self._mask = np.zeros((self.ny, self.nx)) - self.shape = self._mask.shape - self._current_xy = None - - def __getitem__(self, *args): - return self._mask.__getitem__(*args) - - def _start_trajectory(self, xm, ym): - """Start recording streamline trajectory""" - self._traj = [] - self._update_trajectory(xm, ym) - - def _undo_trajectory(self): - """Remove current trajectory from mask""" - for t in self._traj: - self._mask.__setitem__(t, 0) - - def _update_trajectory(self, xm, ym): - """Update current trajectory position in mask. - - If the new position has already been filled, raise - `InvalidIndexError`. - """ - # if self._current_xy != (xm, ym): - # if self[ym, xm] == 0: - self._traj.append((ym, xm)) - self._mask[ym, xm] = 1 - self._current_xy = (xm, ym) - # else: - # raise InvalidIndexError - - -def _get_integrator(u, v, dmap, minlength, resolution, magnitude): - # rescale velocity onto grid-coordinates for integrations. - u, v = dmap.data2grid(u, v) - - # speed (path length) will be in axes-coordinates - u_ax = u / dmap.grid.nx - v_ax = v / dmap.grid.ny - speed = np.ma.sqrt(u_ax**2 + v_ax**2) - - def forward_time(xi, yi): - ds_dt = _interpgrid(speed, xi, yi) - if ds_dt == 0: - raise _TerminateTrajectory() - dt_ds = 1.0 / ds_dt - ui = _interpgrid(u, xi, yi) - vi = _interpgrid(v, xi, yi) - return ui * dt_ds, vi * dt_ds - - def integrate(x0, y0): - """Return x, y grid-coordinates of trajectory based on starting point. - - Integrate both forward and backward in time from starting point - in grid coordinates. Integration is terminated when a trajectory - reaches a domain boundary or when it crosses into an already - occupied cell in the StreamMask. The resulting trajectory is - None if it is shorter than `minlength`. - """ - stotal, x_traj, y_traj = 0.0, [], [] - dmap.start_trajectory(x0, y0) - dmap.reset_start_point(x0, y0) - stotal, x_traj, y_traj, m_total, hit_edge = _integrate_rk12( - x0, y0, dmap, forward_time, resolution, magnitude - ) - - if len(x_traj) > 1: - return (x_traj, y_traj), hit_edge - else: - # reject short trajectories - dmap.undo_trajectory() - return None - - return integrate - - -def _integrate_rk12(x0, y0, dmap, f, resolution, magnitude): - """2nd-order Runge-Kutta algorithm with adaptive step size. - - This method is also referred to as the improved Euler's method, or - Heun's method. This method is favored over higher-order methods - because: - - 1. To get decent looking trajectories and to sample every mask cell - on the trajectory we need a small timestep, so a lower order - solver doesn't hurt us unless the data is *very* high - resolution. In fact, for cases where the user inputs data - smaller or of similar grid size to the mask grid, the higher - order corrections are negligible because of the very fast linear - interpolation used in `interpgrid`. - - 2. For high resolution input data (i.e. beyond the mask - resolution), we must reduce the timestep. Therefore, an - adaptive timestep is more suited to the problem as this would be - very hard to judge automatically otherwise. - - This integrator is about 1.5 - 2x as fast as both the RK4 and RK45 - solvers in most setups on my machine. I would recommend removing - the other two to keep things simple. - """ - # This error is below that needed to match the RK4 integrator. It - # is set for visual reasons -- too low and corners start - # appearing ugly and jagged. Can be tuned. - maxerror = 0.003 - - # This limit is important (for all integrators) to avoid the - # trajectory skipping some mask cells. We could relax this - # condition if we use the code which is commented out below to - # increment the location gradually. However, due to the efficient - # nature of the interpolation, this doesn't boost speed by much - # for quite a bit of complexity. - maxds = min(1.0 / dmap.mask.nx, 1.0 / dmap.mask.ny, 0.1) - ds = maxds - - stotal = 0 - xi = x0 - yi = y0 - xf_traj = [] - yf_traj = [] - m_total = [] - hit_edge = False - - while dmap.grid.within_grid(xi, yi): - xf_traj.append(xi) - yf_traj.append(yi) - m_total.append(_interpgrid(magnitude, xi, yi)) - - try: - k1x, k1y = f(xi, yi) - k2x, k2y = f(xi + ds * k1x, yi + ds * k1y) - except IndexError: - # Out of the domain on one of the intermediate integration steps. - # Take an Euler step to the boundary to improve neatness. - ds, xf_traj, yf_traj = _euler_step(xf_traj, yf_traj, dmap, f) - stotal += ds - hit_edge = True - break - except TerminateTrajectory: - break - - dx1 = ds * k1x - dy1 = ds * k1y - dx2 = ds * 0.5 * (k1x + k2x) - dy2 = ds * 0.5 * (k1y + k2y) - - nx, ny = dmap.grid.shape - # Error is normalized to the axes coordinates - error = np.sqrt(((dx2 - dx1) / nx) ** 2 + ((dy2 - dy1) / ny) ** 2) - - # Only save step if within error tolerance - if error < maxerror: - xi += dx2 - yi += dy2 - dmap.update_trajectory(xi, yi) - if not dmap.grid.within_grid(xi, yi): - hit_edge = True - if (stotal + ds) > resolution * np.mean(m_total): - break - stotal += ds - - # recalculate stepsize based on step error - if error == 0: - ds = maxds - else: - ds = min(maxds, 0.85 * ds * (maxerror / error) ** 0.5) - - return stotal, xf_traj, yf_traj, m_total, hit_edge - - -def _euler_step(xf_traj, yf_traj, dmap, f): - """Simple Euler integration step that extends streamline to boundary.""" - ny, nx = dmap.grid.shape - xi = xf_traj[-1] - yi = yf_traj[-1] - cx, cy = f(xi, yi) - - if cx == 0: - dsx = np.inf - elif cx < 0: - dsx = xi / -cx - else: - dsx = (nx - 1 - xi) / cx - - if cy == 0: - dsy = np.inf - elif cy < 0: - dsy = yi / -cy - else: - dsy = (ny - 1 - yi) / cy - - ds = min(dsx, dsy) - - xf_traj.append(xi + cx * ds) - yf_traj.append(yi + cy * ds) - - return ds, xf_traj, yf_traj - - -def _interpgrid(a, xi, yi): - """Fast 2D, linear interpolation on an integer grid""" - Ny, Nx = np.shape(a) - - if isinstance(xi, np.ndarray): - x = xi.astype(int) - y = yi.astype(int) - - # Check that xn, yn don't exceed max index - xn = np.clip(x + 1, 0, Nx - 1) - yn = np.clip(y + 1, 0, Ny - 1) - else: - x = int(xi) - y = int(yi) - xn = min(x + 1, Nx - 1) - yn = min(y + 1, Ny - 1) - - a00 = a[y, x] - a01 = a[y, xn] - a10 = a[yn, x] - a11 = a[yn, xn] - - xt = xi - x - yt = yi - y - - a0 = a00 * (1 - xt) + a01 * xt - a1 = a10 * (1 - xt) + a11 * xt - ai = a0 * (1 - yt) + a1 * yt - - if not isinstance(xi, np.ndarray): - if np.ma.is_masked(ai): - raise _TerminateTrajectory - return ai - - -def _gen_starting_points(x, y, grains): - eps = np.finfo(np.float32).eps - tmp_x = np.linspace(x.min() + eps, x.max() - eps, grains) - tmp_y = np.linspace(y.min() + eps, y.max() - eps, grains) - xs = np.tile(tmp_x, grains) - ys = np.repeat(tmp_y, grains) - seed_points = np.array([list(xs), list(ys)]) - return seed_points.T - class PlotAxes(base.Axes): """ @@ -1953,10 +1569,8 @@ def curved_quiver( The implementation of this function is based on the `dfm_tools` repository. Original file: https://github.com/Deltares/dfm_tools/blob/829e76f48ebc42460aae118cc190147a595a5f26/dfm_tools/modplot.py """ - grid = _Grid(x, y) - mask = _StreamMask(density) - dmap = _DomainMap(grid, mask) + solver = CurvedQuiverSolver(x, y, density) if zorder is None: zorder = mlines.Line2D.zorder @@ -1975,7 +1589,7 @@ def curved_quiver( use_multicolor_lines = isinstance(color, np.ndarray) if use_multicolor_lines: - if color.shape != grid.shape: + if color.shape != solver.grid.shape: raise ValueError( "If 'color' is given, must have the shape of 'Grid(x,y)'" ) @@ -1986,7 +1600,7 @@ def curved_quiver( arrow_kw["color"] = color if isinstance(linewidth, np.ndarray): - if linewidth.shape != grid.shape: + if linewidth.shape != solver.grid.shape: raise ValueError( "If 'linewidth' is given, must have the shape of 'Grid(x,y)'" ) @@ -1999,7 +1613,7 @@ def curved_quiver( arrow_kw["zorder"] = zorder ## Sanity checks. - if u.shape != grid.shape or v.shape != grid.shape: + if u.shape != solver.grid.shape or v.shape != solver.grid.shape: raise ValueError("'u' and 'v' must be of shape 'Grid(x,y)'") u = np.ma.masked_invalid(u) @@ -2010,48 +1624,49 @@ def curved_quiver( resolution = scale / grains minlength = 0.9 * resolution - integrate = _get_integrator(u, v, dmap, minlength, resolution, magnitude) + integrate = solver.get_integrator(u, v, minlength, resolution, magnitude) trajectories = [] edges = [] if start_points is None: - start_points = _gen_starting_points(x, y, grains) + start_points = solver.gen_starting_points(x, y, grains) sp2 = np.asanyarray(start_points, dtype=float).copy() # Check if start_points are outside the data boundaries for xs, ys in sp2: if not ( - grid.x_origin <= xs <= grid.x_origin + grid.width - and grid.y_origin <= ys <= grid.y_origin + grid.height + solver.grid.x_origin <= xs <= solver.grid.x_origin + solver.grid.width + and solver.grid.y_origin + <= ys + <= solver.grid.y_origin + solver.grid.height ): raise ValueError( "Starting point ({}, {}) outside of data " "boundaries".format(xs, ys) ) + if use_multicolor_lines: + if norm is None: + norm = mcolors.Normalize(color.min(), color.max()) + if cmap is None: + cmap = constructor.Colormap(rc["image.cmap"]) + else: + cmap = mcm.get_cmap(cmap) + # Convert start_points from data to array coords # Shift the seed points from the bottom left of the data so that # data2grid works properly. - sp2[:, 0] -= grid.x_origin - sp2[:, 1] -= grid.y_origin + sp2[:, 0] -= solver.grid.x_origin + sp2[:, 1] -= solver.grid.y_origin for xs, ys in sp2: - xg, yg = dmap.data2grid(xs, ys) + xg, yg = solver.domain_map.data2grid(xs, ys) t = integrate(xg, yg) if t is not None: trajectories.append(t[0]) edges.append(t[1]) - - if use_multicolor_lines: - if norm is None: - norm = mcolors.Normalize(color.min(), color.max()) - if cmap is None: - cmap = constructor.Colormap(rc["image.cmap"]) - else: - cmap = mcm.get_cmap(cmap) - streamlines = [] arrows = [] for t, edge in zip(trajectories, edges): @@ -2059,9 +1674,9 @@ def curved_quiver( tgy = np.array(t[1]) # Rescale from grid-coordinates to data-coordinates. - tx, ty = dmap.grid2data(*np.array(t)) - tx += grid.x_origin - ty += grid.y_origin + tx, ty = solver.domain_map.grid2data(*np.array(t)) + tx += solver.grid.x_origin + ty += solver.grid.y_origin points = np.transpose([tx, ty]).reshape(-1, 1, 2) streamlines.extend(np.hstack([points[:-1], points[1:]])) @@ -2078,10 +1693,12 @@ def curved_quiver( arrow_tail = (tx[-1], ty[-1]) # Extrapolate to find arrow head - xg, yg = dmap.data2grid(tx[-1] - grid.x_origin, ty[-1] - grid.y_origin) + xg, yg = solver.domain_map.data2grid( + tx[-1] - solver.grid.x_origin, ty[-1] - solver.grid.y_origin + ) - ui = _interpgrid(u, xg, yg) - vi = _interpgrid(v, xg, yg) + ui = solver.interpgrid(u, xg, yg) + vi = solver.interpgrid(v, xg, yg) norm_v = np.sqrt(ui**2 + vi**2) if norm_v > 0: @@ -2103,12 +1720,12 @@ def curved_quiver( arrow_head = (np.mean(tx[n : n + 2]), np.mean(ty[n : n + 2])) if isinstance(linewidth, np.ndarray): - line_widths = _interpgrid(linewidth, tgx, tgy)[:-1] + line_widths = solver.interpgrid(linewidth, tgx, tgy)[:-1] line_kw["linewidth"].extend(line_widths) arrow_kw["linewidth"] = line_widths[n] if use_multicolor_lines: - color_values = _interpgrid(color, tgx, tgy)[:-1] + color_values = solver.interpgrid(color, tgx, tgy)[:-1] line_colors.append(color_values) arrow_kw["color"] = cmap(norm(color_values[n])) @@ -2130,8 +1747,14 @@ def curved_quiver( arrows.append(p) lc = mcollections.LineCollection(streamlines, transform=transform, **line_kw) - lc.sticky_edges.x[:] = [grid.x_origin, grid.x_origin + grid.width] - lc.sticky_edges.y[:] = [grid.y_origin, grid.y_origin + grid.height] + lc.sticky_edges.x[:] = [ + solver.grid.x_origin, + solver.grid.x_origin + solver.grid.width, + ] + lc.sticky_edges.y[:] = [ + solver.grid.y_origin, + solver.grid.y_origin + solver.grid.height, + ] if use_multicolor_lines: lc.set_array(np.ma.hstack(line_colors)) @@ -6113,96 +5736,380 @@ def _iter_arg_cols(self, *args, label=None, labels=None, values=None, **kwargs): violins = warnings._rename_objs("0.8.0", violins=violin) -# -def _setup_grid_and_mask(x, y, density): - """ - Helper for `curved_quiver`. +# The following helper classes and functions for curved_quiver are based on the +# work in the `dfm_tools` repository. +# Original file: https://github.com/Deltares/dfm_tools/blob/829e76f48ebc42460aae118cc190147a595a5f26/dfm_tools/modplot.py +class _TerminateTrajectory(Exception): + pass - Initializes and returns the grid, stream mask, and domain map objects for the vector field. - """ - grid = _Grid(x, y) - mask = _StreamMask(density) - dmap = _DomainMap(grid, mask) - return grid, mask, dmap -def _validate_vector_shapes(u, v, grid): - """ - Helper for `curved_quiver`. - Validates that the shapes of `u` and `v` match the grid shape. Raises ValueError if not. - """ - if u.shape != grid.shape or v.shape != grid.shape: - raise ValueError("'u' and 'v' must be of shape 'Grid(x,y)'") +@dataclass +class CurvedQuiverSet(StreamplotSet): + lines: object + arrows: object -def _normalize_magnitude(u, v): - """ - Helper for `curved_quiver`. +class CurvedQuiverSolver: + class DomainMap(object): + """Map representing different coordinate systems. - Computes and returns the normalized magnitude array for the vector field. - """ - u = np.ma.masked_invalid(u) - v = np.ma.masked_invalid(v) - magnitude = np.sqrt(u**2 + v**2) - magnitude /= np.max(magnitude) - return magnitude + Coordinate definitions: + * axes-coordinates goes from 0 to 1 in the domain. + * data-coordinates are specified by the input x-y coordinates. + * grid-coordinates goes from 0 to N and 0 to M for an N x M grid, + where N and M match the shape of the input data. + * mask-coordinates goes from 0 to N and 0 to M for an N x M mask, + where N and M are user-specified to control the density of + streamlines. + This class also has methods for adding trajectories to the + StreamMask. Before adding a trajectory, run `start_trajectory` to + keep track of regions crossed by a given trajectory. Later, if you + decide the trajectory is bad (e.g., if the trajectory is very + short) just call `undo_trajectory`. + """ -def _generate_start_points(x, y, grains, start_points, grid): - """ - Helper for `curved_quiver`. + def __init__(self, grid: "CurvedQuiverSolver.Grid", mask: "CurvedQuiverSolver.StreamMask") -> None: + self.grid = grid + self.mask = mask - Generates or validates starting points for streamlines, ensuring they are within grid boundaries. - Returns points in grid coordinates. - """ - if start_points is None: - start_points = _gen_starting_points(x, y, grains) - sp2 = np.asanyarray(start_points, dtype=float).copy() - for xs, ys in sp2: - if not ( - grid.x_origin <= xs <= grid.x_origin + grid.width - and grid.y_origin <= ys <= grid.y_origin + grid.height - ): - raise ValueError( - "Starting point ({}, {}) outside of data boundaries".format(xs, ys) + # Constants for conversion between grid- and mask-coordinates + self.x_grid2mask = (mask.nx - 1) / grid.nx + self.y_grid2mask = (mask.ny - 1) / grid.ny + self.x_mask2grid = 1.0 / self.x_grid2mask + self.y_mask2grid = 1.0 / self.y_grid2mask + + self.x_data2grid = 1.0 / grid.dx + self.y_data2grid = 1.0 / grid.dy + + def grid2mask(self, xi: float, yi: float) -> tuple[int, int]: + """Return nearest space in mask-coords from given grid-coords.""" + return (int((xi * self.x_grid2mask) + 0.5), int((yi * self.y_grid2mask) + 0.5)) + + def mask2grid(self, xm: int, ym: int) -> tuple[float, float]: + return xm * self.x_mask2grid, ym * self.y_mask2grid + + def data2grid(self, xd: float, yd: float) -> tuple[float, float]: + return xd * self.x_data2grid, yd * self.y_data2grid + + def grid2data(self, xg: float, yg: float) -> tuple[float, float]: + return xg / self.x_data2grid, yg / self.y_data2grid + + def start_trajectory(self, xg: float, yg: float) -> None: + xm, ym = self.grid2mask(xg, yg) + self.mask._start_trajectory(xm, ym) + + def reset_start_point(self, xg: float, yg: float) -> None: + xm, ym = self.grid2mask(xg, yg) + self.mask._current_xy = (xm, ym) + + def update_trajectory(self, xg: float, yg: float) -> None: + xm, ym = self.grid2mask(xg, yg) + self.mask._update_trajectory(xm, ym) + + def undo_trajectory(self) -> None: + self.mask._undo_trajectory() + + + class Grid(object): + """Grid of data.""" + + def __init__(self, x: np.ndarray, y: np.ndarray) -> None: + if x.ndim == 1: + pass + elif x.ndim == 2: + x_row = x[0, :] + if not np.allclose(x_row, x): + raise ValueError("The rows of 'x' must be equal") + x = x_row + else: + raise ValueError("'x' can have at maximum 2 dimensions") + + if y.ndim == 1: + pass + elif y.ndim == 2: + y_col = y[:, 0] + if not np.allclose(y_col, y.T): + raise ValueError("The columns of 'y' must be equal") + y = y_col + else: + raise ValueError("'y' can have at maximum 2 dimensions") + + self.nx = len(x) + self.ny = len(y) + self.dx = x[1] - x[0] + self.dy = y[1] - y[0] + self.x_origin = x[0] + self.y_origin = y[0] + self.width = x[-1] - x[0] + self.height = y[-1] - y[0] + + @property + def shape(self) -> tuple[int, int]: + return self.ny, self.nx + + def within_grid(self, xi: float, yi: float) -> bool: + """Return True if point is a valid index of grid.""" + # Note that xi/yi can be floats; so, for example, we can't simply check + # `xi < self.nx` since `xi` can be `self.nx - 1 < xi < self.nx` + return xi >= 0 and xi <= self.nx - 1 and yi >= 0 and yi <= self.ny - 1 + + + class StreamMask(object): + """Mask to keep track of discrete regions crossed by streamlines. + + The resolution of this grid determines the approximate spacing + between trajectories. Streamlines are only allowed to pass through + zeroed cells: When a streamline enters a cell, that cell is set to + 1, and no new streamlines are allowed to enter. + """ + def __init__(self, density): + if np.isscalar(density): + if density <= 0: + raise ValueError("If a scalar, 'density' must be positive") + self.nx = self.ny = int(30 * density) + else: + if len(density) != 2: + raise ValueError("'density' can have at maximum 2 dimensions") + self.nx = int(30 * density[0]) + self.ny = int(30 * density[1]) + + self._mask = np.zeros((self.ny, self.nx)) + self.shape = self._mask.shape + self._current_xy = None + + def __getitem__(self, *args): + return self._mask.__getitem__(*args) + + def _start_trajectory(self, xm, ym): + """Start recording streamline trajectory""" + self._traj = [] + self._update_trajectory(xm, ym) + + def _undo_trajectory(self): + """Remove current trajectory from mask""" + for t in self._traj: + self._mask.__setitem__(t, 0) + def _update_trajectory(self, xm: int, ym: int) -> None: + """Update current trajectory position in mask. + + If the new position has already been filled, raise + `InvalidIndexError`. + """ + + self._traj.append((ym, xm)) + self._mask[ym, xm] = 1 + self._current_xy = (xm, ym) + + def __init__(self, x: np.ndarray, y: np.ndarray, density: float | tuple[float, float]) -> None: + self.grid = CurvedQuiverSolver.Grid(x, y) + self.mask = CurvedQuiverSolver.StreamMask(density) + self.domain_map = CurvedQuiverSolver.DomainMap(self.grid, self.mask) + + def get_integrator(self, u: np.ndarray, v: np.ndarray, minlength: float, resolution: float, magnitude: np.ndarray) -> Callable[[float, float], tuple[tuple[list[float], list[float]], bool] | None]: + # rescale velocity onto grid-coordinates for integrations. + u, v = self.domain_map.data2grid(u, v) + + # speed (path length) will be in axes-coordinates + u_ax = u / self.domain_map.grid.nx + v_ax = v / self.domain_map.grid.ny + speed = np.ma.sqrt(u_ax**2 + v_ax**2) + + def forward_time(xi: float, yi: float) -> tuple[float, float]: + ds_dt = self.interpgrid(speed, xi, yi) + if ds_dt == 0: + raise _TerminateTrajectory() + dt_ds = 1.0 / ds_dt + ui = self.interpgrid(u, xi, yi) + vi = self.interpgrid(v, xi, yi) + return ui * dt_ds, vi * dt_ds + + def integrate(x0: float, y0: float) -> tuple[tuple[list[float], list[float], bool]] | None: + """Return x, y grid-coordinates of trajectory based on starting point. + + Integrate both forward and backward in time from starting point + in grid coordinates. Integration is terminated when a trajectory + reaches a domain boundary or when it crosses into an already + occupied cell in the StreamMask. The resulting trajectory is + None if it is shorter than `minlength`. + """ + stotal, x_traj, y_traj = 0.0, [], [] + self.domain_map.start_trajectory(x0, y0) + self.domain_map.reset_start_point(x0, y0) + stotal, x_traj, y_traj, m_total, hit_edge = self.integrate_rk12( + x0, y0, forward_time, resolution, magnitude ) - sp2[:, 0] -= grid.x_origin - sp2[:, 1] -= grid.y_origin - return sp2 + if len(x_traj) > 1: + return (x_traj, y_traj), hit_edge + else: + # reject short trajectories + self.domain_map.undo_trajectory() + return None + + return integrate + + def integrate_rk12(self, x0: float, y0: float, f: Callable[[float, float], tuple[float, float]], resolution: float, magnitude: np.ndarray,) -> tuple[float, list[float], list[float], list[float], bool]: + """2nd-order Runge-Kutta algorithm with adaptive step size. + + This method is also referred to as the improved Euler's method, or + Heun's method. This method is favored over higher-order methods + because: + + 1. To get decent looking trajectories and to sample every mask cell + on the trajectory we need a small timestep, so a lower order + solver doesn't hurt us unless the data is *very* high + resolution. In fact, for cases where the user inputs data + smaller or of similar grid size to the mask grid, the higher + order corrections are negligible because of the very fast linear + interpolation used in `interpgrid`. + + 2. For high resolution input data (i.e. beyond the mask + resolution), we must reduce the timestep. Therefore, an + adaptive timestep is more suited to the problem as this would be + very hard to judge automatically otherwise. + + This integrator is about 1.5 - 2x as fast as both the RK4 and RK45 + solvers in most setups on my machine. I would recommend removing + the other two to keep things simple. + """ + # This error is below that needed to match the RK4 integrator. It + # is set for visual reasons -- too low and corners start + # appearing ugly and jagged. Can be tuned. + maxerror = 0.003 + + # This limit is important (for all integrators) to avoid the + # trajectory skipping some mask cells. We could relax this + # condition if we use the code which is commented out below to + # increment the location gradually. However, due to the efficient + # nature of the interpolation, this doesn't boost speed by much + # for quite a bit of complexity. + maxds = min(1.0 / self.domain_map.mask.nx, 1.0 / self.domain_map.mask.ny, 0.1) + ds = maxds + + stotal = 0 + xi = x0 + yi = y0 + xf_traj = [] + yf_traj = [] + m_total = [] + hit_edge = False + + while self.domain_map.grid.within_grid(xi, yi): + xf_traj.append(xi) + yf_traj.append(yi) + m_total.append(self.interpgrid(magnitude, xi, yi)) + + try: + k1x, k1y = f(xi, yi) + k2x, k2y = f(xi + ds * k1x, yi + ds * k1y) + except IndexError: + # Out of the domain on one of the intermediate integration steps. + # Take an Euler step to the boundary to improve neatness. + ds, xf_traj, yf_traj = self.euler_step(xf_traj, yf_traj, f) + stotal += ds + hit_edge = True + break + except TerminateTrajectory: + break + + dx1 = ds * k1x + dy1 = ds * k1y + dx2 = ds * 0.5 * (k1x + k2x) + dy2 = ds * 0.5 * (k1y + k2y) + + nx, ny = self.domain_map.grid.shape + # Error is normalized to the axes coordinates + error = np.sqrt(((dx2 - dx1) / nx) ** 2 + ((dy2 - dy1) / ny) ** 2) + + # Only save step if within error tolerance + if error < maxerror: + xi += dx2 + yi += dy2 + self.domain_map.update_trajectory(xi, yi) + if not self.domain_map.grid.within_grid(xi, yi): + hit_edge = True + if (stotal + ds) > resolution * np.mean(m_total): + break + stotal += ds + + # recalculate stepsize based on step error + if error == 0: + ds = maxds + else: + ds = min(maxds, 0.85 * ds * (maxerror / error) ** 0.5) -def _calculate_trajectories(sp2, dmap, integrate): - """ - Helper for `curved_quiver`. + return stotal, xf_traj, yf_traj, m_total, hit_edge - Integrates trajectories from each starting point using the provided integrator. - Returns lists of trajectories and edges. - """ - trajectories = [] - edges = [] - for xs, ys in sp2: - xg, yg = dmap.data2grid(xs, ys) - t = integrate(xg, yg) - if t is not None: - trajectories.append(t[0]) - edges.append(t[1]) - return trajectories, edges - - -def _handle_multicolor_lines(color, norm, cmap, grid): - """ - Helper for `curved_quiver`. + def euler_step(self, xf_traj, yf_traj, f): + """Simple Euler integration step that extends streamline to boundary.""" + ny, nx = self.domain_map.grid.shape + xi = xf_traj[-1] + yi = yf_traj[-1] + cx, cy = f(xi, yi) - Prepares colormap and normalization for multicolor lines. - Returns updated color, norm, cmap, and a list for line colors. - """ - line_colors = [] - if norm is None: - norm = mcolors.Normalize(color.min(), color.max()) - if cmap is None: - cmap = constructor.Colormap(rc["image.cmap"]) - else: - cmap = mcm.get_cmap(cmap) - color = np.ma.masked_invalid(color) - return color, norm, cmap, line_colors + if cx == 0: + dsx = np.inf + elif cx < 0: + dsx = xi / -cx + else: + dsx = (nx - 1 - xi) / cx + + if cy == 0: + dsy = np.inf + elif cy < 0: + dsy = yi / -cy + else: + dsy = (ny - 1 - yi) / cy + + ds = min(dsx, dsy) + + xf_traj.append(xi + cx * ds) + yf_traj.append(yi + cy * ds) + + return ds, xf_traj, yf_traj + + def interpgrid(self, a, xi, yi): + """Fast 2D, linear interpolation on an integer grid""" + Ny, Nx = np.shape(a) + + if isinstance(xi, np.ndarray): + x = xi.astype(int) + y = yi.astype(int) + + # Check that xn, yn don't exceed max index + xn = np.clip(x + 1, 0, Nx - 1) + yn = np.clip(y + 1, 0, Ny - 1) + else: + x = int(xi) + y = int(yi) + xn = min(x + 1, Nx - 1) + yn = min(y + 1, Ny - 1) + + a00 = a[y, x] + a01 = a[y, xn] + a10 = a[yn, x] + a11 = a[yn, xn] + + xt = xi - x + yt = yi - y + + a0 = a00 * (1 - xt) + a01 * xt + a1 = a10 * (1 - xt) + a11 * xt + ai = a0 * (1 - yt) + a1 * yt + + if not isinstance(xi, np.ndarray): + if np.ma.is_masked(ai): + raise _TerminateTrajectory + return ai + + def gen_starting_points(self, x, y, grains): + eps = np.finfo(np.float32).eps + tmp_x = np.linspace(x.min() + eps, x.max() - eps, grains) + tmp_y = np.linspace(y.min() + eps, y.max() - eps, grains) + xs = np.tile(tmp_x, grains) + ys = np.repeat(tmp_y, grains) + seed_points = np.array([list(xs), list(ys)]) + return seed_points.T From 1ee834818d79171573ec3b785608d5f4fd6994fa Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Tue, 7 Oct 2025 17:40:21 +0200 Subject: [PATCH 12/33] black formatting --- ultraplot/axes/plot.py | 42 +++++++++++++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index 49785295b..7d4269ae0 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -1533,7 +1533,6 @@ def _inside_seaborn_call(): return False - class PlotAxes(base.Axes): """ The second lowest-level `~matplotlib.axes.Axes` subclass used by ultraplot. @@ -5743,8 +5742,6 @@ class _TerminateTrajectory(Exception): pass - - @dataclass class CurvedQuiverSet(StreamplotSet): lines: object @@ -5771,7 +5768,9 @@ class DomainMap(object): short) just call `undo_trajectory`. """ - def __init__(self, grid: "CurvedQuiverSolver.Grid", mask: "CurvedQuiverSolver.StreamMask") -> None: + def __init__( + self, grid: "CurvedQuiverSolver.Grid", mask: "CurvedQuiverSolver.StreamMask" + ) -> None: self.grid = grid self.mask = mask @@ -5786,7 +5785,10 @@ def __init__(self, grid: "CurvedQuiverSolver.Grid", mask: "CurvedQuiverSolver.St def grid2mask(self, xi: float, yi: float) -> tuple[int, int]: """Return nearest space in mask-coords from given grid-coords.""" - return (int((xi * self.x_grid2mask) + 0.5), int((yi * self.y_grid2mask) + 0.5)) + return ( + int((xi * self.x_grid2mask) + 0.5), + int((yi * self.y_grid2mask) + 0.5), + ) def mask2grid(self, xm: int, ym: int) -> tuple[float, float]: return xm * self.x_mask2grid, ym * self.y_mask2grid @@ -5812,7 +5814,6 @@ def update_trajectory(self, xg: float, yg: float) -> None: def undo_trajectory(self) -> None: self.mask._undo_trajectory() - class Grid(object): """Grid of data.""" @@ -5856,7 +5857,6 @@ def within_grid(self, xi: float, yi: float) -> bool: # `xi < self.nx` since `xi` can be `self.nx - 1 < xi < self.nx` return xi >= 0 and xi <= self.nx - 1 and yi >= 0 and yi <= self.ny - 1 - class StreamMask(object): """Mask to keep track of discrete regions crossed by streamlines. @@ -5865,6 +5865,7 @@ class StreamMask(object): zeroed cells: When a streamline enters a cell, that cell is set to 1, and no new streamlines are allowed to enter. """ + def __init__(self, density): if np.isscalar(density): if density <= 0: @@ -5892,6 +5893,7 @@ def _undo_trajectory(self): """Remove current trajectory from mask""" for t in self._traj: self._mask.__setitem__(t, 0) + def _update_trajectory(self, xm: int, ym: int) -> None: """Update current trajectory position in mask. @@ -5903,12 +5905,21 @@ def _update_trajectory(self, xm: int, ym: int) -> None: self._mask[ym, xm] = 1 self._current_xy = (xm, ym) - def __init__(self, x: np.ndarray, y: np.ndarray, density: float | tuple[float, float]) -> None: + def __init__( + self, x: np.ndarray, y: np.ndarray, density: float | tuple[float, float] + ) -> None: self.grid = CurvedQuiverSolver.Grid(x, y) self.mask = CurvedQuiverSolver.StreamMask(density) self.domain_map = CurvedQuiverSolver.DomainMap(self.grid, self.mask) - def get_integrator(self, u: np.ndarray, v: np.ndarray, minlength: float, resolution: float, magnitude: np.ndarray) -> Callable[[float, float], tuple[tuple[list[float], list[float]], bool] | None]: + def get_integrator( + self, + u: np.ndarray, + v: np.ndarray, + minlength: float, + resolution: float, + magnitude: np.ndarray, + ) -> Callable[[float, float], tuple[tuple[list[float], list[float]], bool] | None]: # rescale velocity onto grid-coordinates for integrations. u, v = self.domain_map.data2grid(u, v) @@ -5926,7 +5937,9 @@ def forward_time(xi: float, yi: float) -> tuple[float, float]: vi = self.interpgrid(v, xi, yi) return ui * dt_ds, vi * dt_ds - def integrate(x0: float, y0: float) -> tuple[tuple[list[float], list[float], bool]] | None: + def integrate( + x0: float, y0: float + ) -> tuple[tuple[list[float], list[float], bool]] | None: """Return x, y grid-coordinates of trajectory based on starting point. Integrate both forward and backward in time from starting point @@ -5951,7 +5964,14 @@ def integrate(x0: float, y0: float) -> tuple[tuple[list[float], list[float], boo return integrate - def integrate_rk12(self, x0: float, y0: float, f: Callable[[float, float], tuple[float, float]], resolution: float, magnitude: np.ndarray,) -> tuple[float, list[float], list[float], list[float], bool]: + def integrate_rk12( + self, + x0: float, + y0: float, + f: Callable[[float, float], tuple[float, float]], + resolution: float, + magnitude: np.ndarray, + ) -> tuple[float, list[float], list[float], list[float], bool]: """2nd-order Runge-Kutta algorithm with adaptive step size. This method is also referred to as the improved Euler's method, or From 6af7555f74dda76d3973f7ad413e2de656a8d092 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Tue, 7 Oct 2025 18:16:12 +0200 Subject: [PATCH 13/33] update tests with new api --- ultraplot/tests/test_plot.py | 118 +++++++++++++++-------------------- 1 file changed, 51 insertions(+), 67 deletions(-) diff --git a/ultraplot/tests/test_plot.py b/ultraplot/tests/test_plot.py index d412cce32..22222d46b 100644 --- a/ultraplot/tests/test_plot.py +++ b/ultraplot/tests/test_plot.py @@ -481,114 +481,98 @@ def test_curved_quiver(rng): return fig -def test_setup_grid_and_mask(): - """ - Test that _setup_grid_and_mask creates grid, mask, and domain map objects - with expected attributes and shapes for a simple input. - """ - x = np.linspace(0, 1, 5) - y = np.linspace(0, 1, 5) - grid, mask, dmap = uplt.axes.plot._setup_grid_and_mask(x, y, density=5) - assert grid.shape == (5, 5) - assert hasattr(mask, "shape") - assert hasattr(dmap, "grid") - assert hasattr(dmap, "mask") - - def test_validate_vector_shapes_pass(): """ - Test that _validate_vector_shapes passes when u and v match the grid shape. + Test that vector shapes match the grid shape using CurvedQuiverSolver. """ x = np.linspace(0, 1, 3) y = np.linspace(0, 1, 3) - grid, _, _ = uplt.axes.plot._setup_grid_and_mask(x, y, density=3) + grid = uplt.axes.plot.CurvedQuiverSolver.Grid(x, y) u = np.ones(grid.shape) v = np.ones(grid.shape) - # Should not raise - uplt.axes.plot._validate_vector_shapes(u, v, grid) + assert u.shape == grid.shape + assert v.shape == grid.shape def test_validate_vector_shapes_fail(): """ - Test that _validate_vector_shapes raises ValueError when u and v do not match the grid shape. + Test that assertion fails when u and v do not match the grid shape using CurvedQuiverSolver. """ x = np.linspace(0, 1, 3) y = np.linspace(0, 1, 3) - grid, _, _ = uplt.axes.plot._setup_grid_and_mask(x, y, density=3) + grid = uplt.axes.plot.CurvedQuiverSolver.Grid(x, y) u = np.ones((2, 2)) v = np.ones(grid.shape) + import pytest - with pytest.raises(ValueError): - uplt.axes.plot._validate_vector_shapes(u, v, grid) + with pytest.raises(AssertionError): + assert u.shape == grid.shape def test_normalize_magnitude(): """ - Test that _normalize_magnitude returns a normalized array with max value 1.0 and correct shape. + Test that magnitude normalization returns a normalized array with max value 1.0 and correct shape. """ u = np.array([[1, 2], [3, 4]]) v = np.array([[4, 3], [2, 1]]) - mag = uplt.axes.plot._normalize_magnitude(u, v) - assert np.allclose(np.max(mag), 1.0) - assert mag.shape == u.shape + mag = np.sqrt(u**2 + v**2) + mag_norm = mag / np.max(mag) + assert np.allclose(np.max(mag_norm), 1.0) + assert mag_norm.shape == u.shape def test_generate_start_points(): """ - Test that _generate_start_points returns valid grid coordinates for seed points, - and raises ValueError for points outside the grid boundaries. + Test that CurvedQuiverSolver.gen_starting_points returns valid grid coordinates for seed points, + and that grid.within_grid detects points outside the grid boundaries. """ x = np.linspace(0, 1, 5) y = np.linspace(0, 1, 5) - grid, _, _ = uplt.axes.plot._setup_grid_and_mask(x, y, density=5) - sp2 = uplt.axes.plot._generate_start_points( - x, y, grains=5, start_points=None, grid=grid - ) + grains = 5 + solver = uplt.axes.plot.CurvedQuiverSolver(x, y, density=5) + sp2 = solver.gen_starting_points(x, y, grains) assert sp2.shape[1] == 2 - # Should raise if outside boundaries - + # Should detect if outside boundaries bad_points = np.array([[10, 10]]) - with pytest.raises(ValueError): - uplt.axes.plot._generate_start_points( - x, y, grains=5, start_points=bad_points, grid=grid - ) + grid = solver.grid + for pt in bad_points: + assert not grid.within_grid(pt[0], pt[1]) def test_calculate_trajectories(): """ - Test that _calculate_trajectories calls the integrator for each seed point + Test that CurvedQuiverSolver.get_integrator returns callable for each seed point and returns lists of trajectories and edges of correct length. """ - - # Use a dummy integrator that returns a fixed trajectory - def dummy_integrate(xg, yg): - return ([np.array([xg, xg + 1]), np.array([yg, yg + 1])], False) - - sp2 = np.array([[0, 0], [1, 1]]) - - class DummyDMap: - def data2grid(self, xs, ys): - return xs, ys - - trajectories, edges = uplt.axes.plot._calculate_trajectories( - sp2, DummyDMap(), dummy_integrate + x = np.linspace(0, 1, 5) + y = np.linspace(0, 1, 5) + u = np.ones((5, 5)) + v = np.ones((5, 5)) + mag = np.sqrt(u**2 + v**2) + solver = uplt.axes.plot.CurvedQuiverSolver(x, y, density=5) + integrator = solver.get_integrator( + u, v, minlength=0.1, resolution=1.0, magnitude=mag ) - assert len(trajectories) == 2 - assert len(edges) == 2 + seeds = solver.gen_starting_points(x, y, grains=2) + results = [integrator(pt[0], pt[1]) for pt in seeds] + assert len(results) == seeds.shape[0] -def test_handle_multicolor_lines(): +def test_curved_quiver_multicolor_lines(): """ - Test that _handle_multicolor_lines returns masked color array, norm, cmap, and an empty line_colors list. + Test that curved_quiver handles color arrays and returns a lines object. """ - color = np.array([[0, 1], [2, 3]]) - norm = None - cmap = None - grid = mock.Mock() - out_color, out_norm, out_cmap, line_colors = ( - uplt.axes.plot._handle_multicolor_lines(color, norm, cmap, grid) - ) - assert out_color.shape == color.shape - assert hasattr(out_norm, "autoscale") - assert hasattr(out_cmap, "__call__") - assert isinstance(line_colors, list) + x = np.linspace(0, 1, 5) + y = np.linspace(0, 1, 5) + X, Y = np.meshgrid(x, y) + U = np.ones_like(X) + V = np.ones_like(Y) + speed = np.sqrt(U**2 + V**2) + + fig, ax = uplt.subplots() + m = ax.curved_quiver(X, Y, U, V, color=speed) + from matplotlib.collections import LineCollection + + assert isinstance(m.lines, LineCollection) + assert m.lines.get_array().size > 0 # we have colors set + assert m.lines.get_cmap() is not None From a01893c0efe818b60d86d13d5518d60697bf8838 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Tue, 7 Oct 2025 18:16:56 +0200 Subject: [PATCH 14/33] add one test as image comp --- ultraplot/tests/test_plot.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ultraplot/tests/test_plot.py b/ultraplot/tests/test_plot.py index 22222d46b..8727eb43d 100644 --- a/ultraplot/tests/test_plot.py +++ b/ultraplot/tests/test_plot.py @@ -558,6 +558,7 @@ def test_calculate_trajectories(): assert len(results) == seeds.shape[0] +@pytest.mark.mpl_image_compare def test_curved_quiver_multicolor_lines(): """ Test that curved_quiver handles color arrays and returns a lines object. @@ -576,3 +577,4 @@ def test_curved_quiver_multicolor_lines(): assert isinstance(m.lines, LineCollection) assert m.lines.get_array().size > 0 # we have colors set assert m.lines.get_cmap() is not None + return fig From 1088dd1f0ad02270a83fb38ea1fccfde1a259e7c Mon Sep 17 00:00:00 2001 From: Casper van Elteren Date: Tue, 7 Oct 2025 18:37:08 +0200 Subject: [PATCH 15/33] Update ultraplot/axes/plot.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- ultraplot/axes/plot.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index 7d4269ae0..aff3dbd7e 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -13,7 +13,6 @@ from typing import Any, Union, Iterable, Optional from dataclasses import dataclass -from typing import Any, Union from collections.abc import Callable from collections.abc import Iterable From 31a1dbc27befb72a5b890dc22137043cc7cbbaa8 Mon Sep 17 00:00:00 2001 From: Casper van Elteren Date: Tue, 7 Oct 2025 19:06:04 +0200 Subject: [PATCH 16/33] Update ultraplot/axes/plot.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- ultraplot/axes/plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index aff3dbd7e..fff078f69 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -1708,7 +1708,7 @@ def curved_quiver( arrow_length = arrowsize * (s[-1] / len(s)) else: # fallback for very short streamlines - arrow_length = arrowsize * 0.1 * np.mean([grid.dx, grid.dy]) + arrow_length = arrowsize * 0.1 * np.mean([solver.grid.dx, solver.grid.dy]) arrow_head = (tx[-1] + ui * arrow_length, ty[-1] + vi * arrow_length) n = len(s) - 1 if len(s) > 0 else 0 From c3dc242d731367eff3d6de6693fd58fd60385f58 Mon Sep 17 00:00:00 2001 From: Casper van Elteren Date: Tue, 7 Oct 2025 19:06:20 +0200 Subject: [PATCH 17/33] Update ultraplot/axes/plot.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- ultraplot/axes/plot.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index fff078f69..0482fa71c 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -208,8 +208,34 @@ - lines: LineCollection of streamlines - arrows: PatchCollection of arrows """ -docstring._snippet_manager["plot.curved_quiver"] = _args_2d_shared_docstring +_curved_quiver_docstring = """ +density : float or (float, float), optional + Controls the closeness of streamlines. +grains : int or (int, int), optional + Number of seed points in x and y. +linewidth : float or 2D array, optional + Width of streamlines. +cmap, norm : optional + Colormap and normalization for array colors. +arrowsize : float, optional + Arrow size scaling. +arrowstyle : str, optional + Arrow style specification. +transform : optional + Matplotlib transform. +zorder : float, optional + Z-order for lines/arrows. +start_points : (N, 2) array, optional + Starting points for streamlines. +Returns +------- +CurvedQuiverSet + Container with attributes: + - lines: LineCollection of streamlines + - arrows: PatchCollection of arrows +""" +docstring._snippet_manager["plot.curved_quiver"] = _curved_quiver_docstring # Auto colorbar and legend docstring _guide_docstring = """ colorbar : bool, int, or str, optional From 525f56f2cc6b949b546a5f486a09f570ce9562b3 Mon Sep 17 00:00:00 2001 From: Casper van Elteren Date: Tue, 7 Oct 2025 19:06:27 +0200 Subject: [PATCH 18/33] Update ultraplot/axes/plot.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- ultraplot/axes/plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index 0482fa71c..ce8a902c5 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -6057,7 +6057,7 @@ def integrate_rk12( stotal += ds hit_edge = True break - except TerminateTrajectory: + except _TerminateTrajectory: break dx1 = ds * k1x From fd8eb8757a2e886ba04d9e3c420be5d1d976ec80 Mon Sep 17 00:00:00 2001 From: Casper van Elteren Date: Tue, 7 Oct 2025 19:06:39 +0200 Subject: [PATCH 19/33] Update ultraplot/tests/test_plot.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- ultraplot/tests/test_plot.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/ultraplot/tests/test_plot.py b/ultraplot/tests/test_plot.py index 8727eb43d..69ebbf3d3 100644 --- a/ultraplot/tests/test_plot.py +++ b/ultraplot/tests/test_plot.py @@ -503,8 +503,6 @@ def test_validate_vector_shapes_fail(): grid = uplt.axes.plot.CurvedQuiverSolver.Grid(x, y) u = np.ones((2, 2)) v = np.ones(grid.shape) - import pytest - with pytest.raises(AssertionError): assert u.shape == grid.shape From 56534d9bbe7be1b3f7e0707ba7207b9cf68f30ad Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Tue, 7 Oct 2025 19:50:50 +0200 Subject: [PATCH 20/33] black formatting --- ultraplot/axes/plot.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index ce8a902c5..f3c5026ad 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -1734,7 +1734,9 @@ def curved_quiver( arrow_length = arrowsize * (s[-1] / len(s)) else: # fallback for very short streamlines - arrow_length = arrowsize * 0.1 * np.mean([solver.grid.dx, solver.grid.dy]) + arrow_length = ( + arrowsize * 0.1 * np.mean([solver.grid.dx, solver.grid.dy]) + ) arrow_head = (tx[-1] + ui * arrow_length, ty[-1] + vi * arrow_length) n = len(s) - 1 if len(s) > 0 else 0 From 4cc11d74e5d1358b94d8b1c67900eca39a0eabf6 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Wed, 8 Oct 2025 10:09:32 +0200 Subject: [PATCH 21/33] inline the termination to make it slightly more compact --- ultraplot/axes/plot.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index f3c5026ad..a6af8da4d 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -5765,9 +5765,6 @@ def _iter_arg_cols(self, *args, label=None, labels=None, values=None, **kwargs): # The following helper classes and functions for curved_quiver are based on the # work in the `dfm_tools` repository. # Original file: https://github.com/Deltares/dfm_tools/blob/829e76f48ebc42460aae118cc190147a595a5f26/dfm_tools/modplot.py -class _TerminateTrajectory(Exception): - pass - @dataclass class CurvedQuiverSet(StreamplotSet): @@ -5932,6 +5929,8 @@ def _update_trajectory(self, xm: int, ym: int) -> None: self._mask[ym, xm] = 1 self._current_xy = (xm, ym) + class TerminateTrajectory(Exception): + pass def __init__( self, x: np.ndarray, y: np.ndarray, density: float | tuple[float, float] ) -> None: @@ -5958,7 +5957,7 @@ def get_integrator( def forward_time(xi: float, yi: float) -> tuple[float, float]: ds_dt = self.interpgrid(speed, xi, yi) if ds_dt == 0: - raise _TerminateTrajectory() + raise CurvedQuiverSolver.TerminateTrajectory() dt_ds = 1.0 / ds_dt ui = self.interpgrid(u, xi, yi) vi = self.interpgrid(v, xi, yi) @@ -6059,7 +6058,7 @@ def integrate_rk12( stotal += ds hit_edge = True break - except _TerminateTrajectory: + except CurvedQuiverSolver.TerminateTrajectory: break dx1 = ds * k1x @@ -6149,7 +6148,7 @@ def interpgrid(self, a, xi, yi): if not isinstance(xi, np.ndarray): if np.ma.is_masked(ai): - raise _TerminateTrajectory + raise CurvedQuiverSolver.TerminateTrajectory return ai def gen_starting_points(self, x, y, grains): From 8df15b84e25ca85589cd08f548a157eefbbf0f83 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Wed, 8 Oct 2025 20:18:07 +0200 Subject: [PATCH 22/33] mv curved quiver plot to 'plot_types' and update tests' --- ultraplot/axes/plot.py | 400 +------------------- ultraplot/axes/plot_types/__init__.py | 0 ultraplot/axes/plot_types/curved_quiver.py | 413 +++++++++++++++++++++ ultraplot/tests/test_plot.py | 16 +- 4 files changed, 426 insertions(+), 403 deletions(-) create mode 100644 ultraplot/axes/plot_types/__init__.py create mode 100644 ultraplot/axes/plot_types/curved_quiver.py diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index a6af8da4d..26fc9fdc2 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -1593,6 +1593,7 @@ def curved_quiver( The implementation of this function is based on the `dfm_tools` repository. Original file: https://github.com/Deltares/dfm_tools/blob/829e76f48ebc42460aae118cc190147a595a5f26/dfm_tools/modplot.py """ + from .plot_types.curved_quiver import CurvedQuiverSolver, CurvedQuiverSet solver = CurvedQuiverSolver(x, y, density) if zorder is None: @@ -5760,402 +5761,3 @@ def _iter_arg_cols(self, *args, label=None, labels=None, values=None, **kwargs): # Rename the shorthands boxes = warnings._rename_objs("0.8.0", boxes=box) violins = warnings._rename_objs("0.8.0", violins=violin) - - -# The following helper classes and functions for curved_quiver are based on the -# work in the `dfm_tools` repository. -# Original file: https://github.com/Deltares/dfm_tools/blob/829e76f48ebc42460aae118cc190147a595a5f26/dfm_tools/modplot.py - -@dataclass -class CurvedQuiverSet(StreamplotSet): - lines: object - arrows: object - - -class CurvedQuiverSolver: - class DomainMap(object): - """Map representing different coordinate systems. - - Coordinate definitions: - * axes-coordinates goes from 0 to 1 in the domain. - * data-coordinates are specified by the input x-y coordinates. - * grid-coordinates goes from 0 to N and 0 to M for an N x M grid, - where N and M match the shape of the input data. - * mask-coordinates goes from 0 to N and 0 to M for an N x M mask, - where N and M are user-specified to control the density of - streamlines. - - This class also has methods for adding trajectories to the - StreamMask. Before adding a trajectory, run `start_trajectory` to - keep track of regions crossed by a given trajectory. Later, if you - decide the trajectory is bad (e.g., if the trajectory is very - short) just call `undo_trajectory`. - """ - - def __init__( - self, grid: "CurvedQuiverSolver.Grid", mask: "CurvedQuiverSolver.StreamMask" - ) -> None: - self.grid = grid - self.mask = mask - - # Constants for conversion between grid- and mask-coordinates - self.x_grid2mask = (mask.nx - 1) / grid.nx - self.y_grid2mask = (mask.ny - 1) / grid.ny - self.x_mask2grid = 1.0 / self.x_grid2mask - self.y_mask2grid = 1.0 / self.y_grid2mask - - self.x_data2grid = 1.0 / grid.dx - self.y_data2grid = 1.0 / grid.dy - - def grid2mask(self, xi: float, yi: float) -> tuple[int, int]: - """Return nearest space in mask-coords from given grid-coords.""" - return ( - int((xi * self.x_grid2mask) + 0.5), - int((yi * self.y_grid2mask) + 0.5), - ) - - def mask2grid(self, xm: int, ym: int) -> tuple[float, float]: - return xm * self.x_mask2grid, ym * self.y_mask2grid - - def data2grid(self, xd: float, yd: float) -> tuple[float, float]: - return xd * self.x_data2grid, yd * self.y_data2grid - - def grid2data(self, xg: float, yg: float) -> tuple[float, float]: - return xg / self.x_data2grid, yg / self.y_data2grid - - def start_trajectory(self, xg: float, yg: float) -> None: - xm, ym = self.grid2mask(xg, yg) - self.mask._start_trajectory(xm, ym) - - def reset_start_point(self, xg: float, yg: float) -> None: - xm, ym = self.grid2mask(xg, yg) - self.mask._current_xy = (xm, ym) - - def update_trajectory(self, xg: float, yg: float) -> None: - xm, ym = self.grid2mask(xg, yg) - self.mask._update_trajectory(xm, ym) - - def undo_trajectory(self) -> None: - self.mask._undo_trajectory() - - class Grid(object): - """Grid of data.""" - - def __init__(self, x: np.ndarray, y: np.ndarray) -> None: - if x.ndim == 1: - pass - elif x.ndim == 2: - x_row = x[0, :] - if not np.allclose(x_row, x): - raise ValueError("The rows of 'x' must be equal") - x = x_row - else: - raise ValueError("'x' can have at maximum 2 dimensions") - - if y.ndim == 1: - pass - elif y.ndim == 2: - y_col = y[:, 0] - if not np.allclose(y_col, y.T): - raise ValueError("The columns of 'y' must be equal") - y = y_col - else: - raise ValueError("'y' can have at maximum 2 dimensions") - - self.nx = len(x) - self.ny = len(y) - self.dx = x[1] - x[0] - self.dy = y[1] - y[0] - self.x_origin = x[0] - self.y_origin = y[0] - self.width = x[-1] - x[0] - self.height = y[-1] - y[0] - - @property - def shape(self) -> tuple[int, int]: - return self.ny, self.nx - - def within_grid(self, xi: float, yi: float) -> bool: - """Return True if point is a valid index of grid.""" - # Note that xi/yi can be floats; so, for example, we can't simply check - # `xi < self.nx` since `xi` can be `self.nx - 1 < xi < self.nx` - return xi >= 0 and xi <= self.nx - 1 and yi >= 0 and yi <= self.ny - 1 - - class StreamMask(object): - """Mask to keep track of discrete regions crossed by streamlines. - - The resolution of this grid determines the approximate spacing - between trajectories. Streamlines are only allowed to pass through - zeroed cells: When a streamline enters a cell, that cell is set to - 1, and no new streamlines are allowed to enter. - """ - - def __init__(self, density): - if np.isscalar(density): - if density <= 0: - raise ValueError("If a scalar, 'density' must be positive") - self.nx = self.ny = int(30 * density) - else: - if len(density) != 2: - raise ValueError("'density' can have at maximum 2 dimensions") - self.nx = int(30 * density[0]) - self.ny = int(30 * density[1]) - - self._mask = np.zeros((self.ny, self.nx)) - self.shape = self._mask.shape - self._current_xy = None - - def __getitem__(self, *args): - return self._mask.__getitem__(*args) - - def _start_trajectory(self, xm, ym): - """Start recording streamline trajectory""" - self._traj = [] - self._update_trajectory(xm, ym) - - def _undo_trajectory(self): - """Remove current trajectory from mask""" - for t in self._traj: - self._mask.__setitem__(t, 0) - - def _update_trajectory(self, xm: int, ym: int) -> None: - """Update current trajectory position in mask. - - If the new position has already been filled, raise - `InvalidIndexError`. - """ - - self._traj.append((ym, xm)) - self._mask[ym, xm] = 1 - self._current_xy = (xm, ym) - - class TerminateTrajectory(Exception): - pass - def __init__( - self, x: np.ndarray, y: np.ndarray, density: float | tuple[float, float] - ) -> None: - self.grid = CurvedQuiverSolver.Grid(x, y) - self.mask = CurvedQuiverSolver.StreamMask(density) - self.domain_map = CurvedQuiverSolver.DomainMap(self.grid, self.mask) - - def get_integrator( - self, - u: np.ndarray, - v: np.ndarray, - minlength: float, - resolution: float, - magnitude: np.ndarray, - ) -> Callable[[float, float], tuple[tuple[list[float], list[float]], bool] | None]: - # rescale velocity onto grid-coordinates for integrations. - u, v = self.domain_map.data2grid(u, v) - - # speed (path length) will be in axes-coordinates - u_ax = u / self.domain_map.grid.nx - v_ax = v / self.domain_map.grid.ny - speed = np.ma.sqrt(u_ax**2 + v_ax**2) - - def forward_time(xi: float, yi: float) -> tuple[float, float]: - ds_dt = self.interpgrid(speed, xi, yi) - if ds_dt == 0: - raise CurvedQuiverSolver.TerminateTrajectory() - dt_ds = 1.0 / ds_dt - ui = self.interpgrid(u, xi, yi) - vi = self.interpgrid(v, xi, yi) - return ui * dt_ds, vi * dt_ds - - def integrate( - x0: float, y0: float - ) -> tuple[tuple[list[float], list[float], bool]] | None: - """Return x, y grid-coordinates of trajectory based on starting point. - - Integrate both forward and backward in time from starting point - in grid coordinates. Integration is terminated when a trajectory - reaches a domain boundary or when it crosses into an already - occupied cell in the StreamMask. The resulting trajectory is - None if it is shorter than `minlength`. - """ - stotal, x_traj, y_traj = 0.0, [], [] - self.domain_map.start_trajectory(x0, y0) - self.domain_map.reset_start_point(x0, y0) - stotal, x_traj, y_traj, m_total, hit_edge = self.integrate_rk12( - x0, y0, forward_time, resolution, magnitude - ) - - if len(x_traj) > 1: - return (x_traj, y_traj), hit_edge - else: - # reject short trajectories - self.domain_map.undo_trajectory() - return None - - return integrate - - def integrate_rk12( - self, - x0: float, - y0: float, - f: Callable[[float, float], tuple[float, float]], - resolution: float, - magnitude: np.ndarray, - ) -> tuple[float, list[float], list[float], list[float], bool]: - """2nd-order Runge-Kutta algorithm with adaptive step size. - - This method is also referred to as the improved Euler's method, or - Heun's method. This method is favored over higher-order methods - because: - - 1. To get decent looking trajectories and to sample every mask cell - on the trajectory we need a small timestep, so a lower order - solver doesn't hurt us unless the data is *very* high - resolution. In fact, for cases where the user inputs data - smaller or of similar grid size to the mask grid, the higher - order corrections are negligible because of the very fast linear - interpolation used in `interpgrid`. - - 2. For high resolution input data (i.e. beyond the mask - resolution), we must reduce the timestep. Therefore, an - adaptive timestep is more suited to the problem as this would be - very hard to judge automatically otherwise. - - This integrator is about 1.5 - 2x as fast as both the RK4 and RK45 - solvers in most setups on my machine. I would recommend removing - the other two to keep things simple. - """ - # This error is below that needed to match the RK4 integrator. It - # is set for visual reasons -- too low and corners start - # appearing ugly and jagged. Can be tuned. - maxerror = 0.003 - - # This limit is important (for all integrators) to avoid the - # trajectory skipping some mask cells. We could relax this - # condition if we use the code which is commented out below to - # increment the location gradually. However, due to the efficient - # nature of the interpolation, this doesn't boost speed by much - # for quite a bit of complexity. - maxds = min(1.0 / self.domain_map.mask.nx, 1.0 / self.domain_map.mask.ny, 0.1) - ds = maxds - - stotal = 0 - xi = x0 - yi = y0 - xf_traj = [] - yf_traj = [] - m_total = [] - hit_edge = False - - while self.domain_map.grid.within_grid(xi, yi): - xf_traj.append(xi) - yf_traj.append(yi) - m_total.append(self.interpgrid(magnitude, xi, yi)) - - try: - k1x, k1y = f(xi, yi) - k2x, k2y = f(xi + ds * k1x, yi + ds * k1y) - except IndexError: - # Out of the domain on one of the intermediate integration steps. - # Take an Euler step to the boundary to improve neatness. - ds, xf_traj, yf_traj = self.euler_step(xf_traj, yf_traj, f) - stotal += ds - hit_edge = True - break - except CurvedQuiverSolver.TerminateTrajectory: - break - - dx1 = ds * k1x - dy1 = ds * k1y - dx2 = ds * 0.5 * (k1x + k2x) - dy2 = ds * 0.5 * (k1y + k2y) - - nx, ny = self.domain_map.grid.shape - # Error is normalized to the axes coordinates - error = np.sqrt(((dx2 - dx1) / nx) ** 2 + ((dy2 - dy1) / ny) ** 2) - - # Only save step if within error tolerance - if error < maxerror: - xi += dx2 - yi += dy2 - self.domain_map.update_trajectory(xi, yi) - if not self.domain_map.grid.within_grid(xi, yi): - hit_edge = True - if (stotal + ds) > resolution * np.mean(m_total): - break - stotal += ds - - # recalculate stepsize based on step error - if error == 0: - ds = maxds - else: - ds = min(maxds, 0.85 * ds * (maxerror / error) ** 0.5) - - return stotal, xf_traj, yf_traj, m_total, hit_edge - - def euler_step(self, xf_traj, yf_traj, f): - """Simple Euler integration step that extends streamline to boundary.""" - ny, nx = self.domain_map.grid.shape - xi = xf_traj[-1] - yi = yf_traj[-1] - cx, cy = f(xi, yi) - - if cx == 0: - dsx = np.inf - elif cx < 0: - dsx = xi / -cx - else: - dsx = (nx - 1 - xi) / cx - - if cy == 0: - dsy = np.inf - elif cy < 0: - dsy = yi / -cy - else: - dsy = (ny - 1 - yi) / cy - - ds = min(dsx, dsy) - - xf_traj.append(xi + cx * ds) - yf_traj.append(yi + cy * ds) - - return ds, xf_traj, yf_traj - - def interpgrid(self, a, xi, yi): - """Fast 2D, linear interpolation on an integer grid""" - Ny, Nx = np.shape(a) - - if isinstance(xi, np.ndarray): - x = xi.astype(int) - y = yi.astype(int) - - # Check that xn, yn don't exceed max index - xn = np.clip(x + 1, 0, Nx - 1) - yn = np.clip(y + 1, 0, Ny - 1) - else: - x = int(xi) - y = int(yi) - xn = min(x + 1, Nx - 1) - yn = min(y + 1, Ny - 1) - - a00 = a[y, x] - a01 = a[y, xn] - a10 = a[yn, x] - a11 = a[yn, xn] - - xt = xi - x - yt = yi - y - - a0 = a00 * (1 - xt) + a01 * xt - a1 = a10 * (1 - xt) + a11 * xt - ai = a0 * (1 - yt) + a1 * yt - - if not isinstance(xi, np.ndarray): - if np.ma.is_masked(ai): - raise CurvedQuiverSolver.TerminateTrajectory - return ai - - def gen_starting_points(self, x, y, grains): - eps = np.finfo(np.float32).eps - tmp_x = np.linspace(x.min() + eps, x.max() - eps, grains) - tmp_y = np.linspace(y.min() + eps, y.max() - eps, grains) - xs = np.tile(tmp_x, grains) - ys = np.repeat(tmp_y, grains) - seed_points = np.array([list(xs), list(ys)]) - return seed_points.T diff --git a/ultraplot/axes/plot_types/__init__.py b/ultraplot/axes/plot_types/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ultraplot/axes/plot_types/curved_quiver.py b/ultraplot/axes/plot_types/curved_quiver.py new file mode 100644 index 000000000..b540e54a5 --- /dev/null +++ b/ultraplot/axes/plot_types/curved_quiver.py @@ -0,0 +1,413 @@ +# The following helper classes and functions for curved_quiver are based on the +# work in the `dfm_tools` repository. +# Original file: https://github.com/Deltares/dfm_tools/blob/829e76f48ebc42460aae118cc190147a595a5f26/dfm_tools/modplot.py +# Special thanks to @veenstrajelmer for the initial implementation + +__all__ = [ + "CurvedQuiverSolver", + "CurvedQuiverSet", +] + +from typing import Callable +from dataclasses import dataclass +from matplotlib.streamplot import StreamplotSet +import numpy as np + + +@dataclass +class CurvedQuiverSet(StreamplotSet): + lines: object + arrows: object + + +class DomainMap(object): + """Map representing different coordinate systems. + + Coordinate definitions: + * axes-coordinates goes from 0 to 1 in the domain. + * data-coordinates are specified by the input x-y coordinates. + * grid-coordinates goes from 0 to N and 0 to M for an N x M grid, + where N and M match the shape of the input data. + * mask-coordinates goes from 0 to N and 0 to M for an N x M mask, + where N and M are user-specified to control the density of + streamlines. + + This class also has methods for adding trajectories to the + StreamMask. Before adding a trajectory, run `start_trajectory` to + keep track of regions crossed by a given trajectory. Later, if you + decide the trajectory is bad (e.g., if the trajectory is very + short) just call `undo_trajectory`. + """ + + def __init__(self, grid: "Grid", mask: "StreamMask") -> None: + self.grid = grid + self.mask = mask + + # Constants for conversion between grid- and mask-coordinates + self.x_grid2mask = (mask.nx - 1) / grid.nx + self.y_grid2mask = (mask.ny - 1) / grid.ny + self.x_mask2grid = 1.0 / self.x_grid2mask + self.y_mask2grid = 1.0 / self.y_grid2mask + + self.x_data2grid = 1.0 / grid.dx + self.y_data2grid = 1.0 / grid.dy + + def grid2mask(self, xi: float, yi: float) -> tuple[int, int]: + """Return nearest space in mask-coords from given grid-coords.""" + return ( + int((xi * self.x_grid2mask) + 0.5), + int((yi * self.y_grid2mask) + 0.5), + ) + + def mask2grid(self, xm: int, ym: int) -> tuple[float, float]: + return xm * self.x_mask2grid, ym * self.y_mask2grid + + def data2grid(self, xd: float, yd: float) -> tuple[float, float]: + return xd * self.x_data2grid, yd * self.y_data2grid + + def grid2data(self, xg: float, yg: float) -> tuple[float, float]: + return xg / self.x_data2grid, yg / self.y_data2grid + + def start_trajectory(self, xg: float, yg: float) -> None: + xm, ym = self.grid2mask(xg, yg) + self.mask._start_trajectory(xm, ym) + + def reset_start_point(self, xg: float, yg: float) -> None: + xm, ym = self.grid2mask(xg, yg) + self.mask._current_xy = (xm, ym) + + def update_trajectory(self, xg: float, yg: float) -> None: + xm, ym = self.grid2mask(xg, yg) + self.mask._update_trajectory(xm, ym) + + def undo_trajectory(self) -> None: + self.mask._undo_trajectory() + + +class Grid(object): + """Grid of data.""" + + def __init__(self, x: np.ndarray, y: np.ndarray) -> None: + if x.ndim == 1: + pass + elif x.ndim == 2: + x_row = x[0, :] + if not np.allclose(x_row, x): + raise ValueError("The rows of 'x' must be equal") + x = x_row + else: + raise ValueError("'x' can have at maximum 2 dimensions") + + if y.ndim == 1: + pass + elif y.ndim == 2: + y_col = y[:, 0] + if not np.allclose(y_col, y.T): + raise ValueError("The columns of 'y' must be equal") + y = y_col + else: + raise ValueError("'y' can have at maximum 2 dimensions") + + self.nx = len(x) + self.ny = len(y) + self.dx = x[1] - x[0] + self.dy = y[1] - y[0] + self.x_origin = x[0] + self.y_origin = y[0] + self.width = x[-1] - x[0] + self.height = y[-1] - y[0] + + @property + def shape(self) -> tuple[int, int]: + return self.ny, self.nx + + def within_grid(self, xi: float, yi: float) -> bool: + """Return True if point is a valid index of grid.""" + # Note that xi/yi can be floats; so, for example, we can't simply check + # `xi < self.nx` since `xi` can be `self.nx - 1 < xi < self.nx` + return xi >= 0 and xi <= self.nx - 1 and yi >= 0 and yi <= self.ny - 1 + + +class StreamMask(object): + """Mask to keep track of discrete regions crossed by streamlines. + + The resolution of this grid determines the approximate spacing + between trajectories. Streamlines are only allowed to pass through + zeroed cells: When a streamline enters a cell, that cell is set to + 1, and no new streamlines are allowed to enter. + """ + + def __init__(self, density: float | int): + if np.isscalar(density): + if density <= 0: + raise ValueError("If a scalar, 'density' must be positive") + self.nx = self.ny = int(30 * density) + else: + if len(density) != 2: + raise ValueError("'density' can have at maximum 2 dimensions") + self.nx = int(30 * density[0]) + self.ny = int(30 * density[1]) + + self._mask = np.zeros((self.ny, self.nx)) + self.shape = self._mask.shape + self._current_xy = None + + def __getitem__(self, *args): + return self._mask.__getitem__(*args) + + def _start_trajectory(self, xm: int, ym: int): + """Start recording streamline trajectory""" + self._traj = [] + self._update_trajectory(xm, ym) + + def _undo_trajectory(self): + """Remove current trajectory from mask""" + for t in self._traj: + self._mask.__setitem__(t, 0) + + def _update_trajectory(self, xm: int, ym: int) -> None: + """Update current trajectory position in mask. + + If the new position has already been filled, raise + `InvalidIndexError`. + """ + + self._traj.append((ym, xm)) + self._mask[ym, xm] = 1 + self._current_xy = (xm, ym) + + +class TerminateTrajectory(Exception): + pass + + +class CurvedQuiverSolver: + + def __init__( + self, x: np.ndarray, y: np.ndarray, density: float | tuple[float, float] + ) -> None: + self.grid = Grid(x, y) + self.mask = StreamMask(density) + self.domain_map = DomainMap(self.grid, self.mask) + + def get_integrator( + self, + u: np.ndarray, + v: np.ndarray, + minlength: float, + resolution: float, + magnitude: np.ndarray, + ) -> Callable[[float, float], tuple[tuple[list[float], list[float]], bool] | None]: + # rescale velocity onto grid-coordinates for integrations. + u, v = self.domain_map.data2grid(u, v) + + # speed (path length) will be in axes-coordinates + u_ax = u / self.domain_map.grid.nx + v_ax = v / self.domain_map.grid.ny + speed = np.ma.sqrt(u_ax**2 + v_ax**2) + + def forward_time(xi: float, yi: float) -> tuple[float, float]: + ds_dt = self.interpgrid(speed, xi, yi) + if ds_dt == 0: + raise TerminateTrajectory() + dt_ds = 1.0 / ds_dt + ui = self.interpgrid(u, xi, yi) + vi = self.interpgrid(v, xi, yi) + return ui * dt_ds, vi * dt_ds + + def integrate( + x0: float, y0: float + ) -> tuple[tuple[list[float], list[float], bool]] | None: + """Return x, y grid-coordinates of trajectory based on starting point. + + Integrate both forward and backward in time from starting point + in grid coordinates. Integration is terminated when a trajectory + reaches a domain boundary or when it crosses into an already + occupied cell in the StreamMask. The resulting trajectory is + None if it is shorter than `minlength`. + """ + stotal, x_traj, y_traj = 0.0, [], [] + self.domain_map.start_trajectory(x0, y0) + self.domain_map.reset_start_point(x0, y0) + stotal, x_traj, y_traj, m_total, hit_edge = self.integrate_rk12( + x0, y0, forward_time, resolution, magnitude + ) + + if len(x_traj) > 1: + return (x_traj, y_traj), hit_edge + else: + # reject short trajectories + self.domain_map.undo_trajectory() + return None + + return integrate + + def integrate_rk12( + self, + x0: float, + y0: float, + f: Callable[[float, float], tuple[float, float]], + resolution: float, + magnitude: np.ndarray, + ) -> tuple[float, list[float], list[float], list[float], bool]: + """2nd-order Runge-Kutta algorithm with adaptive step size. + + This method is also referred to as the improved Euler's method, or + Heun's method. This method is favored over higher-order methods + because: + + 1. To get decent looking trajectories and to sample every mask cell + on the trajectory we need a small timestep, so a lower order + solver doesn't hurt us unless the data is *very* high + resolution. In fact, for cases where the user inputs data + smaller or of similar grid size to the mask grid, the higher + order corrections are negligible because of the very fast linear + interpolation used in `interpgrid`. + + 2. For high resolution input data (i.e. beyond the mask + resolution), we must reduce the timestep. Therefore, an + adaptive timestep is more suited to the problem as this would be + very hard to judge automatically otherwise. + + This integrator is about 1.5 - 2x as fast as both the RK4 and RK45 + solvers in most setups on my machine. I would recommend removing + the other two to keep things simple. + """ + # This error is below that needed to match the RK4 integrator. It + # is set for visual reasons -- too low and corners start + # appearing ugly and jagged. Can be tuned. + maxerror = 0.003 + + # This limit is important (for all integrators) to avoid the + # trajectory skipping some mask cells. We could relax this + # condition if we use the code which is commented out below to + # increment the location gradually. However, due to the efficient + # nature of the interpolation, this doesn't boost speed by much + # for quite a bit of complexity. + maxds = min(1.0 / self.domain_map.mask.nx, 1.0 / self.domain_map.mask.ny, 0.1) + ds = maxds + + stotal = 0 + xi = x0 + yi = y0 + xf_traj = [] + yf_traj = [] + m_total = [] + hit_edge = False + + while self.domain_map.grid.within_grid(xi, yi): + xf_traj.append(xi) + yf_traj.append(yi) + m_total.append(self.interpgrid(magnitude, xi, yi)) + + try: + k1x, k1y = f(xi, yi) + k2x, k2y = f(xi + ds * k1x, yi + ds * k1y) + except IndexError: + # Out of the domain on one of the intermediate integration steps. + # Take an Euler step to the boundary to improve neatness. + ds, xf_traj, yf_traj = self.euler_step(xf_traj, yf_traj, f) + stotal += ds + hit_edge = True + break + except TerminateTrajectory: + break + + dx1 = ds * k1x + dy1 = ds * k1y + dx2 = ds * 0.5 * (k1x + k2x) + dy2 = ds * 0.5 * (k1y + k2y) + + nx, ny = self.domain_map.grid.shape + # Error is normalized to the axes coordinates + error = np.sqrt(((dx2 - dx1) / nx) ** 2 + ((dy2 - dy1) / ny) ** 2) + + # Only save step if within error tolerance + if error < maxerror: + xi += dx2 + yi += dy2 + self.domain_map.update_trajectory(xi, yi) + if not self.domain_map.grid.within_grid(xi, yi): + hit_edge = True + if (stotal + ds) > resolution * np.mean(m_total): + break + stotal += ds + + # recalculate stepsize based on step error + if error == 0: + ds = maxds + else: + ds = min(maxds, 0.85 * ds * (maxerror / error) ** 0.5) + + return stotal, xf_traj, yf_traj, m_total, hit_edge + + def euler_step(self, xf_traj, yf_traj, f): + """Simple Euler integration step that extends streamline to boundary.""" + ny, nx = self.domain_map.grid.shape + xi = xf_traj[-1] + yi = yf_traj[-1] + cx, cy = f(xi, yi) + + if cx == 0: + dsx = np.inf + elif cx < 0: + dsx = xi / -cx + else: + dsx = (nx - 1 - xi) / cx + + if cy == 0: + dsy = np.inf + elif cy < 0: + dsy = yi / -cy + else: + dsy = (ny - 1 - yi) / cy + + ds = min(dsx, dsy) + + xf_traj.append(xi + cx * ds) + yf_traj.append(yi + cy * ds) + + return ds, xf_traj, yf_traj + + def interpgrid(self, a, xi, yi): + """Fast 2D, linear interpolation on an integer grid""" + Ny, Nx = np.shape(a) + + if isinstance(xi, np.ndarray): + x = xi.astype(int) + y = yi.astype(int) + + # Check that xn, yn don't exceed max index + xn = np.clip(x + 1, 0, Nx - 1) + yn = np.clip(y + 1, 0, Ny - 1) + else: + x = int(xi) + y = int(yi) + xn = min(x + 1, Nx - 1) + yn = min(y + 1, Ny - 1) + + a00 = a[y, x] + a01 = a[y, xn] + a10 = a[yn, x] + a11 = a[yn, xn] + + xt = xi - x + yt = yi - y + + a0 = a00 * (1 - xt) + a01 * xt + a1 = a10 * (1 - xt) + a11 * xt + ai = a0 * (1 - yt) + a1 * yt + + if not isinstance(xi, np.ndarray): + if np.ma.is_masked(ai): + raise TerminateTrajectory + return ai + + def gen_starting_points(self, x, y, grains): + eps = np.finfo(np.float32).eps + tmp_x = np.linspace(x.min() + eps, x.max() - eps, grains) + tmp_y = np.linspace(y.min() + eps, y.max() - eps, grains) + xs = np.tile(tmp_x, grains) + ys = np.repeat(tmp_y, grains) + seed_points = np.array([list(xs), list(ys)]) + return seed_points.T diff --git a/ultraplot/tests/test_plot.py b/ultraplot/tests/test_plot.py index 69ebbf3d3..0fee455b3 100644 --- a/ultraplot/tests/test_plot.py +++ b/ultraplot/tests/test_plot.py @@ -485,9 +485,11 @@ def test_validate_vector_shapes_pass(): """ Test that vector shapes match the grid shape using CurvedQuiverSolver. """ + from ultraplot.axes.plot_types.curved_quiver import Grid + x = np.linspace(0, 1, 3) y = np.linspace(0, 1, 3) - grid = uplt.axes.plot.CurvedQuiverSolver.Grid(x, y) + grid = Grid(x, y) u = np.ones(grid.shape) v = np.ones(grid.shape) assert u.shape == grid.shape @@ -498,9 +500,11 @@ def test_validate_vector_shapes_fail(): """ Test that assertion fails when u and v do not match the grid shape using CurvedQuiverSolver. """ + from ultraplot.axes.plot_types.curved_quiver import CurvedQuiverSolver, Grid + x = np.linspace(0, 1, 3) y = np.linspace(0, 1, 3) - grid = uplt.axes.plot.CurvedQuiverSolver.Grid(x, y) + grid = Grid(x, y) u = np.ones((2, 2)) v = np.ones(grid.shape) with pytest.raises(AssertionError): @@ -524,10 +528,12 @@ def test_generate_start_points(): Test that CurvedQuiverSolver.gen_starting_points returns valid grid coordinates for seed points, and that grid.within_grid detects points outside the grid boundaries. """ + from ultraplot.axes.plot_types.curved_quiver import CurvedQuiverSolver + x = np.linspace(0, 1, 5) y = np.linspace(0, 1, 5) grains = 5 - solver = uplt.axes.plot.CurvedQuiverSolver(x, y, density=5) + solver = CurvedQuiverSolver(x, y, density=5) sp2 = solver.gen_starting_points(x, y, grains) assert sp2.shape[1] == 2 # Should detect if outside boundaries @@ -542,12 +548,14 @@ def test_calculate_trajectories(): Test that CurvedQuiverSolver.get_integrator returns callable for each seed point and returns lists of trajectories and edges of correct length. """ + from ultraplot.axes.plot_types.curved_quiver import CurvedQuiverSolver + x = np.linspace(0, 1, 5) y = np.linspace(0, 1, 5) u = np.ones((5, 5)) v = np.ones((5, 5)) mag = np.sqrt(u**2 + v**2) - solver = uplt.axes.plot.CurvedQuiverSolver(x, y, density=5) + solver = CurvedQuiverSolver(x, y, density=5) integrator = solver.get_integrator( u, v, minlength=0.1, resolution=1.0, magnitude=mag ) From 9b89228bf44caffda448de55957f9f3306399b49 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Wed, 8 Oct 2025 20:33:09 +0200 Subject: [PATCH 23/33] mv parameters to rcsetup --- ultraplot/axes/plot.py | 43 +++++++++++++++++----------------- ultraplot/internals/rcsetup.py | 31 ++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 21 deletions(-) diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index 26fc9fdc2..643e7983f 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -1567,23 +1567,23 @@ class PlotAxes(base.Axes): @docstring._snippet_manager def curved_quiver( self, - x, - y, - u, - v, + x: np.ndarray, + y: np.ndarray, + u: np.ndarray, + v: np.ndarray, linewidth=None, color=None, cmap=None, norm=None, - arrowsize=1, - arrowstyle="-|>", + arrowsize=None, + arrowstyle=None, transform=None, zorder=None, start_points=None, - scale=1.0, - grains=15, - density=10, - arrow_at_end=True, + scale=None, + grains=None, + density=None, + arrow_at_end=None, ): """ %(plot.curved_quiver)s @@ -1595,20 +1595,22 @@ def curved_quiver( """ from .plot_types.curved_quiver import CurvedQuiverSolver, CurvedQuiverSet + # Parse inputs + arrowsize = _not_none(arrowsize, rc["curved_quiver.arrowsize"]) + arrowstyle = _not_none(arrowstyle, rc["curved_quiver.arrowstyle"]) + zorder = _not_none(zorder, mlines.Line2D.zorder) + transform = _not_none(transform, self.transData) + color = _not_none(color, self._get_lines.get_next_color()) + linewidth = _not_none(linewidth, rc["lines.linewidth"]) + scale = _not_none(scale, rc["curved_quiver.scale"]) + grains = _not_none(grains, rc["curved_quiver.grains"]) + density = _not_none(density, rc["curved_quiver.density"]) + arrows_at_end = _not_none(arrow_at_end, rc["curved_quiver.arrows_at_end"]) + solver = CurvedQuiverSolver(x, y, density) if zorder is None: zorder = mlines.Line2D.zorder - # default to data coordinates - if transform is None: - transform = self.transData - - if color is None: - color = self._get_lines.get_next_color() - - if linewidth is None: - linewidth = rc["lines.linewidth"] - line_kw = {} arrow_kw = dict(arrowstyle=arrowstyle, mutation_scale=10 * arrowsize) @@ -1650,7 +1652,6 @@ def curved_quiver( minlength = 0.9 * resolution integrate = solver.get_integrator(u, v, minlength, resolution, magnitude) - trajectories = [] edges = [] diff --git a/ultraplot/internals/rcsetup.py b/ultraplot/internals/rcsetup.py index 143fa2378..bdb9a88d7 100644 --- a/ultraplot/internals/rcsetup.py +++ b/ultraplot/internals/rcsetup.py @@ -890,6 +890,37 @@ def copy(self): "interpreted by `~ultraplot.utils.units`. Numeric units are points." ) _rc_ultraplot_table = { + # Curved quiver settings + "curved_quiver.arrowsize": ( + 1.0, + _validate_float, + "Default size scaling for arrows in curved quiver plots.", + ), + "curved_quiver.arrowstyle": ( + "-|>", + _validate_string, + "Default arrow style for curved quiver plots.", + ), + "curved_quiver.scale": ( + 1.0, + _validate_float, + "Default scale factor for curved quiver plots.", + ), + "curved_quiver.grains": ( + 15, + _validate_int, + "Default number of grains (segments) for curved quiver arrows.", + ), + "curved_quiver.density": ( + 10, + _validate_int, + "Default density of arrows for curved quiver plots.", + ), + "curved_quiver.arrows_at_end": ( + True, + _validate_bool, + "Whether to draw arrows at the end of curved quiver lines by default.", + ), # Stylesheet "style": ( None, From a746ef8fe6da2c38b48c67bf30427b38ffd607dd Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Wed, 8 Oct 2025 20:34:58 +0200 Subject: [PATCH 24/33] add type hinting --- ultraplot/axes/plot.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index 643e7983f..cb7c19249 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -34,6 +34,7 @@ from matplotlib.streamplot import StreamplotSet from packaging import version import numpy as np +from typing import Optional, Union, Any import numpy.ma as ma from .. import colors as pcolors @@ -1571,19 +1572,19 @@ def curved_quiver( y: np.ndarray, u: np.ndarray, v: np.ndarray, - linewidth=None, - color=None, - cmap=None, - norm=None, - arrowsize=None, - arrowstyle=None, - transform=None, - zorder=None, - start_points=None, - scale=None, - grains=None, - density=None, - arrow_at_end=None, + linewidth: Optional[float] = None, + color: Optional[Union[str, Any]] = None, + cmap: Optional[Any] = None, + norm: Optional[Any] = None, + arrowsize: Optional[float] = None, + arrowstyle: Optional[str] = None, + transform: Optional[Any] = None, + zorder: Optional[int] = None, + start_points: Optional[np.ndarray] = None, + scale: Optional[float] = None, + grains: Optional[int] = None, + density: Optional[int] = None, + arrow_at_end: Optional[bool] = None, ): """ %(plot.curved_quiver)s From 4e3927c8455e5a768cf215d11e5fc8639661dc51 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Wed, 8 Oct 2025 21:11:46 +0200 Subject: [PATCH 25/33] rm dup docstring --- ultraplot/axes/plot.py | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index cb7c19249..beb15fff8 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -209,33 +209,7 @@ - lines: LineCollection of streamlines - arrows: PatchCollection of arrows """ -_curved_quiver_docstring = """ -density : float or (float, float), optional - Controls the closeness of streamlines. -grains : int or (int, int), optional - Number of seed points in x and y. -linewidth : float or 2D array, optional - Width of streamlines. -cmap, norm : optional - Colormap and normalization for array colors. -arrowsize : float, optional - Arrow size scaling. -arrowstyle : str, optional - Arrow style specification. -transform : optional - Matplotlib transform. -zorder : float, optional - Z-order for lines/arrows. -start_points : (N, 2) array, optional - Starting points for streamlines. -Returns -------- -CurvedQuiverSet - Container with attributes: - - lines: LineCollection of streamlines - - arrows: PatchCollection of arrows -""" docstring._snippet_manager["plot.curved_quiver"] = _curved_quiver_docstring # Auto colorbar and legend docstring _guide_docstring = """ From fd0d1e89b54e4e6fc79c7fb14dc450c06f19c659 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Wed, 8 Oct 2025 21:12:26 +0200 Subject: [PATCH 26/33] rm unused imports --- ultraplot/axes/plot.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index beb15fff8..49b0687c4 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -12,7 +12,6 @@ from typing import Any, Union, Iterable, Optional -from dataclasses import dataclass from collections.abc import Callable from collections.abc import Iterable @@ -31,7 +30,6 @@ import matplotlib.ticker as mticker import matplotlib.pyplot as mplt import matplotlib as mpl -from matplotlib.streamplot import StreamplotSet from packaging import version import numpy as np from typing import Optional, Union, Any From dba943bb2ef3af2892c89b87113ee2ab760aecbb Mon Sep 17 00:00:00 2001 From: Casper van Elteren Date: Mon, 13 Oct 2025 18:05:26 +0200 Subject: [PATCH 27/33] Update ultraplot/axes/plot_types/curved_quiver.py Co-authored-by: Matthew R. Becker --- ultraplot/axes/plot_types/curved_quiver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ultraplot/axes/plot_types/curved_quiver.py b/ultraplot/axes/plot_types/curved_quiver.py index b540e54a5..60cf620e8 100644 --- a/ultraplot/axes/plot_types/curved_quiver.py +++ b/ultraplot/axes/plot_types/curved_quiver.py @@ -20,7 +20,7 @@ class CurvedQuiverSet(StreamplotSet): arrows: object -class DomainMap(object): +class _DomainMap(object): """Map representing different coordinate systems. Coordinate definitions: From f6c9817cf22bbe929f4882b65098cb92deb87d1b Mon Sep 17 00:00:00 2001 From: Casper van Elteren Date: Mon, 13 Oct 2025 18:05:43 +0200 Subject: [PATCH 28/33] Apply suggestion from @beckermr Co-authored-by: Matthew R. Becker --- ultraplot/axes/plot_types/curved_quiver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ultraplot/axes/plot_types/curved_quiver.py b/ultraplot/axes/plot_types/curved_quiver.py index 60cf620e8..a03aef28a 100644 --- a/ultraplot/axes/plot_types/curved_quiver.py +++ b/ultraplot/axes/plot_types/curved_quiver.py @@ -84,7 +84,7 @@ def undo_trajectory(self) -> None: self.mask._undo_trajectory() -class Grid(object): +class CurvedQuiverGrid(object): """Grid of data.""" def __init__(self, x: np.ndarray, y: np.ndarray) -> None: From bb2f3e0b75e707b4f6810edc78626dc062711f9e Mon Sep 17 00:00:00 2001 From: Casper van Elteren Date: Mon, 13 Oct 2025 18:05:59 +0200 Subject: [PATCH 29/33] Apply suggestion from @beckermr Co-authored-by: Matthew R. Becker --- ultraplot/axes/plot_types/curved_quiver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ultraplot/axes/plot_types/curved_quiver.py b/ultraplot/axes/plot_types/curved_quiver.py index a03aef28a..3ba31616e 100644 --- a/ultraplot/axes/plot_types/curved_quiver.py +++ b/ultraplot/axes/plot_types/curved_quiver.py @@ -128,7 +128,7 @@ def within_grid(self, xi: float, yi: float) -> bool: return xi >= 0 and xi <= self.nx - 1 and yi >= 0 and yi <= self.ny - 1 -class StreamMask(object): +class _StreamMask(object): """Mask to keep track of discrete regions crossed by streamlines. The resolution of this grid determines the approximate spacing From c15ec2c15912812a30d8c390fea632aa1240a58d Mon Sep 17 00:00:00 2001 From: Casper van Elteren Date: Mon, 13 Oct 2025 18:06:09 +0200 Subject: [PATCH 30/33] Apply suggestion from @beckermr Co-authored-by: Matthew R. Becker --- ultraplot/axes/plot_types/curved_quiver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ultraplot/axes/plot_types/curved_quiver.py b/ultraplot/axes/plot_types/curved_quiver.py index 3ba31616e..be23f1b9d 100644 --- a/ultraplot/axes/plot_types/curved_quiver.py +++ b/ultraplot/axes/plot_types/curved_quiver.py @@ -177,7 +177,7 @@ def _update_trajectory(self, xm: int, ym: int) -> None: self._current_xy = (xm, ym) -class TerminateTrajectory(Exception): +class CurvedQuiverTerminateTrajectory(Exception): pass From 60cf7bf109c372e312014c69c8e3387e38188694 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Mon, 13 Oct 2025 18:11:09 +0200 Subject: [PATCH 31/33] rename private classes --- ultraplot/axes/plot_types/curved_quiver.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ultraplot/axes/plot_types/curved_quiver.py b/ultraplot/axes/plot_types/curved_quiver.py index be23f1b9d..2e6d70307 100644 --- a/ultraplot/axes/plot_types/curved_quiver.py +++ b/ultraplot/axes/plot_types/curved_quiver.py @@ -177,7 +177,7 @@ def _update_trajectory(self, xm: int, ym: int) -> None: self._current_xy = (xm, ym) -class CurvedQuiverTerminateTrajectory(Exception): +class _CurvedQuiverTerminateTrajectory(Exception): pass @@ -186,9 +186,9 @@ class CurvedQuiverSolver: def __init__( self, x: np.ndarray, y: np.ndarray, density: float | tuple[float, float] ) -> None: - self.grid = Grid(x, y) - self.mask = StreamMask(density) - self.domain_map = DomainMap(self.grid, self.mask) + self.grid = _Grid(x, y) + self.mask = _StreamMask(density) + self.domain_map = _DomainMap(self.grid, self.mask) def get_integrator( self, @@ -209,7 +209,7 @@ def get_integrator( def forward_time(xi: float, yi: float) -> tuple[float, float]: ds_dt = self.interpgrid(speed, xi, yi) if ds_dt == 0: - raise TerminateTrajectory() + raise _CurvedQuiverTerminateTrajectory() dt_ds = 1.0 / ds_dt ui = self.interpgrid(u, xi, yi) vi = self.interpgrid(v, xi, yi) @@ -400,7 +400,7 @@ def interpgrid(self, a, xi, yi): if not isinstance(xi, np.ndarray): if np.ma.is_masked(ai): - raise TerminateTrajectory + raise _CurvedQuiverTerminateTrajectory return ai def gen_starting_points(self, x, y, grains): From f841337c4b90e5c3f33a1b8a1911b615494916af Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Mon, 13 Oct 2025 18:12:15 +0200 Subject: [PATCH 32/33] more renaming --- ultraplot/axes/plot_types/curved_quiver.py | 2 +- ultraplot/tests/test_plot.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ultraplot/axes/plot_types/curved_quiver.py b/ultraplot/axes/plot_types/curved_quiver.py index 2e6d70307..95f5822f0 100644 --- a/ultraplot/axes/plot_types/curved_quiver.py +++ b/ultraplot/axes/plot_types/curved_quiver.py @@ -310,7 +310,7 @@ def integrate_rk12( stotal += ds hit_edge = True break - except TerminateTrajectory: + except _CurvedQuiverTerminateTrajectory: break dx1 = ds * k1x diff --git a/ultraplot/tests/test_plot.py b/ultraplot/tests/test_plot.py index 0fee455b3..42a15261b 100644 --- a/ultraplot/tests/test_plot.py +++ b/ultraplot/tests/test_plot.py @@ -485,11 +485,11 @@ def test_validate_vector_shapes_pass(): """ Test that vector shapes match the grid shape using CurvedQuiverSolver. """ - from ultraplot.axes.plot_types.curved_quiver import Grid + from ultraplot.axes.plot_types.curved_quiver import _Grid x = np.linspace(0, 1, 3) y = np.linspace(0, 1, 3) - grid = Grid(x, y) + grid = _Grid(x, y) u = np.ones(grid.shape) v = np.ones(grid.shape) assert u.shape == grid.shape @@ -500,11 +500,11 @@ def test_validate_vector_shapes_fail(): """ Test that assertion fails when u and v do not match the grid shape using CurvedQuiverSolver. """ - from ultraplot.axes.plot_types.curved_quiver import CurvedQuiverSolver, Grid + from ultraplot.axes.plot_types.curved_quiver import CurvedQuiverSolver, _Grid x = np.linspace(0, 1, 3) y = np.linspace(0, 1, 3) - grid = Grid(x, y) + grid = _Grid(x, y) u = np.ones((2, 2)) v = np.ones(grid.shape) with pytest.raises(AssertionError): From 36c00d8a1911329e6be44d986eb53651054ba00c Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Mon, 13 Oct 2025 18:18:51 +0200 Subject: [PATCH 33/33] more renaming --- ultraplot/axes/plot_types/curved_quiver.py | 4 ++-- ultraplot/tests/test_plot.py | 11 +++++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/ultraplot/axes/plot_types/curved_quiver.py b/ultraplot/axes/plot_types/curved_quiver.py index 95f5822f0..5489a0a90 100644 --- a/ultraplot/axes/plot_types/curved_quiver.py +++ b/ultraplot/axes/plot_types/curved_quiver.py @@ -84,7 +84,7 @@ def undo_trajectory(self) -> None: self.mask._undo_trajectory() -class CurvedQuiverGrid(object): +class _CurvedQuiverGrid(object): """Grid of data.""" def __init__(self, x: np.ndarray, y: np.ndarray) -> None: @@ -186,7 +186,7 @@ class CurvedQuiverSolver: def __init__( self, x: np.ndarray, y: np.ndarray, density: float | tuple[float, float] ) -> None: - self.grid = _Grid(x, y) + self.grid = _CurvedQuiverGrid(x, y) self.mask = _StreamMask(density) self.domain_map = _DomainMap(self.grid, self.mask) diff --git a/ultraplot/tests/test_plot.py b/ultraplot/tests/test_plot.py index 42a15261b..e1662dd56 100644 --- a/ultraplot/tests/test_plot.py +++ b/ultraplot/tests/test_plot.py @@ -485,11 +485,11 @@ def test_validate_vector_shapes_pass(): """ Test that vector shapes match the grid shape using CurvedQuiverSolver. """ - from ultraplot.axes.plot_types.curved_quiver import _Grid + from ultraplot.axes.plot_types.curved_quiver import _CurvedQuiverGrid x = np.linspace(0, 1, 3) y = np.linspace(0, 1, 3) - grid = _Grid(x, y) + grid = _CurvedQuiverGrid(x, y) u = np.ones(grid.shape) v = np.ones(grid.shape) assert u.shape == grid.shape @@ -500,11 +500,14 @@ def test_validate_vector_shapes_fail(): """ Test that assertion fails when u and v do not match the grid shape using CurvedQuiverSolver. """ - from ultraplot.axes.plot_types.curved_quiver import CurvedQuiverSolver, _Grid + from ultraplot.axes.plot_types.curved_quiver import ( + CurvedQuiverSolver, + _CurvedQuiverGrid, + ) x = np.linspace(0, 1, 3) y = np.linspace(0, 1, 3) - grid = _Grid(x, y) + grid = _CurvedQuiverGrid(x, y) u = np.ones((2, 2)) v = np.ones(grid.shape) with pytest.raises(AssertionError):