diff --git a/ultraplot/axes/geo.py b/ultraplot/axes/geo.py index 267acb206..f08bc48cf 100644 --- a/ultraplot/axes/geo.py +++ b/ultraplot/axes/geo.py @@ -2,6 +2,8 @@ """ Axes filled with cartographic projections. """ +from __future__ import annotations + import copy import inspect from functools import partial @@ -12,8 +14,8 @@ except ImportError: # From Python 3.5 from typing_extensions import override - -from collections.abc import MutableMapping +from collections.abc import Iterator, MutableMapping, Sequence +from typing import Any, Optional, Protocol import matplotlib.axis as maxis import matplotlib.path as mpath @@ -55,6 +57,16 @@ __all__ = ["GeoAxes"] +# Basemap gridlines are dicts keyed by location containing (lines, labels). +GridlineDict = MutableMapping[float, tuple[list[Any], list[mtext.Text]]] +_GRIDLINER_PAD_SCALE = 2.0 # points; matches tick size visually +_MINOR_TICK_SCALE = 0.6 # relative to major tick length +_BASEMAP_LABEL_SIZE_SCALE = 0.5 # empirical scaling for label offset +_BASEMAP_LABEL_Y_SCALE = 0.65 # empirical spacing to mimic cartopy +_BASEMAP_LABEL_X_SCALE = 0.25 # empirical spacing to mimic cartopy +_CARTOPY_LABEL_SIDES = ("labelleft", "labelright", "labelbottom", "labeltop", "geo") +_BASEMAP_LABEL_SIDES = ("labelleft", "labelright", "labeltop", "labelbottom", "geo") + # Format docstring _format_docstring = """ @@ -217,17 +229,60 @@ class _GeoLabel(object): Optionally omit overlapping check if an rc setting is disabled. """ - def check_overlapping(self, *args, **kwargs): + def check_overlapping(self, *args: Any, **kwargs: Any) -> bool: if rc["grid.checkoverlap"]: return super().check_overlapping(*args, **kwargs) else: return False -# Add monkey patch to gridliner module if cgridliner is not None and hasattr(cgridliner, "Label"): # only recent versions - _cls = type("Label", (_GeoLabel, cgridliner.Label), {}) - cgridliner.Label = _cls + + class _CartopyLabel(_GeoLabel, cgridliner.Label): + """Label class with configurable overlap checks.""" + + class _CartopyGridliner(cgridliner.Gridliner): + """ + Gridliner subclass to localize cartopy quirks in one place. + """ + + LabelClass = _CartopyLabel + + def _generate_labels(self) -> Iterator[_CartopyLabel]: + """Yield label objects, reusing cached instances when possible.""" + for label in self._all_labels: + yield label + + while True: + new_artist = mtext.Text() + new_artist.set_figure(self.axes.figure) + new_artist.axes = self.axes + + new_label = self.LabelClass(new_artist, None, None, None) + self._all_labels.append(new_label) + + yield new_label + + def _axes_domain(self, *args: Any, **kwargs: Any) -> tuple[Any, Any]: + x_range, y_range = super()._axes_domain(*args, **kwargs) + if _version_cartopy < "0.18": + lon_0 = self.axes.projection.proj4_params.get("lon_0", 0) + x_range = np.asarray(x_range) + lon_0 + return x_range, y_range + + def _draw_gridliner(self, *args: Any, **kwargs: Any) -> Any: # noqa: E306 + result = super()._draw_gridliner(*args, **kwargs) + if _version_cartopy >= "0.18": + lon_lim, _ = self._axes_domain() + if abs(np.diff(lon_lim)) == abs(np.diff(self.crs.x_limits)): + for collection in self.xline_artists: + if not getattr(collection, "_cartopy_fix", False): + collection.get_paths().pop(-1) + collection._cartopy_fix = True + return result + +else: + _CartopyGridliner = None class _GeoAxis(object): @@ -240,7 +295,7 @@ class _GeoAxis(object): # NOTE: Due to cartopy bug (https://github.com/SciTools/cartopy/issues/1564) # we store presistent longitude and latitude locators on axes, then *call* # them whenever set_extent is called and apply *fixed* locators. - def __init__(self, axes): + def __init__(self, axes: "GeoAxes") -> None: self.axes = axes self.major = maxis.Ticker() self.minor = maxis.Ticker() @@ -256,7 +311,7 @@ def __init__(self, axes): and _version_cartopy >= "0.18" ) - def _get_extent(self): + def _get_extent(self) -> tuple[float, float, float, float]: # Try to get extent but bail out for projections where this is # impossible. So far just transverse Mercator try: @@ -266,7 +321,7 @@ def _get_extent(self): return (-180 + lon0, 180 + lon0, -90, 90) @staticmethod - def _pad_ticks(ticks, vmin, vmax): + def _pad_ticks(ticks: np.ndarray, vmin: float, vmax: float) -> np.ndarray: # Wrap up to the longitude/latitude range to avoid # giant lists of 10,000 gridline locations. if len(ticks) == 0: @@ -282,50 +337,56 @@ def _pad_ticks(ticks, vmin, vmax): ticks = np.concatenate((ticks_lo, ticks, ticks_hi)) return ticks - def get_scale(self): + def get_scale(self) -> str: return "linear" - def get_tick_space(self): + def get_tick_space(self) -> int: return 9 # longstanding default of nbins=9 - def get_major_formatter(self): + def get_major_formatter(self) -> mticker.Formatter | None: return self.major.formatter - def get_major_locator(self): + def get_major_locator(self) -> mticker.Locator | None: return self.major.locator - def get_minor_locator(self): + def get_minor_locator(self) -> mticker.Locator | None: return self.minor.locator - def get_majorticklocs(self): + def get_majorticklocs(self) -> np.ndarray: return self._get_ticklocs(self.major.locator) - def get_minorticklocs(self): + def get_minorticklocs(self) -> np.ndarray: return self._get_ticklocs(self.minor.locator) - def set_major_formatter(self, formatter, default=False): + def set_major_formatter( + self, formatter: mticker.Formatter, default: bool = False + ) -> None: # NOTE: Cartopy formatters check Formatter.axis.axes.projection # in order to implement special projection-dependent behavior. self.major.formatter = formatter formatter.set_axis(self) self.isDefault_majfmt = default - def set_major_locator(self, locator, default=False): + def set_major_locator( + self, locator: mticker.Locator, default: bool = False + ) -> None: self.major.locator = locator if self.major.formatter: self.major.formatter._set_locator(locator) locator.set_axis(self) self.isDefault_majloc = default - def set_minor_locator(self, locator, default=False): + def set_minor_locator( + self, locator: mticker.Locator, default: bool = False + ) -> None: self.minor.locator = locator locator.set_axis(self) self.isDefault_majfmt = default - def set_view_interval(self, vmin, vmax): + def set_view_interval(self, vmin: float, vmax: float) -> None: self._interval = (vmin, vmax) - def _copy_locator_properties(self, other: "_GeoAxis"): + def _copy_locator_properties(self, other: "_GeoAxis") -> None: """ This function copies the locator properties. It is used when the @self is sharing with @other. @@ -353,6 +414,382 @@ def _copy_locator_properties(self, other: "_GeoAxis"): setattr(other, prop, this_prop) +class _GridlinerAdapter(Protocol): + """ + Lightweight facade used to normalize cartopy and basemap gridliner behavior. + These adapters let GeoAxes apply gridline label toggles and styles without + backend-specific branching. + """ + + def labels_for_sides( + self, + *, + bottom: bool | str | None = None, + top: bool | str | None = None, + left: bool | str | None = None, + right: bool | str | None = None, + ) -> dict[str, list[mtext.Text]]: ... + + def toggle_labels( + self, + *, + labelleft: bool | str | None = None, + labelright: bool | str | None = None, + labelbottom: bool | str | None = None, + labeltop: bool | str | None = None, + geo: bool | str | None = None, + ) -> None: ... + + def apply_style( + self, + *, + axis: str = "both", + pad: float | None = None, + labelsize: float | str | None = None, + labelcolor: Any = None, + labelrotation: float | None = None, + linecolor: Any = None, + linewidth: float | None = None, + ) -> None: ... + + def tick_positions( + self, axis: str, *, lonaxis: "_GeoAxis", lataxis: "_GeoAxis" + ) -> np.ndarray: ... + + def is_label_on(self, side: str) -> bool: ... + + +class _CartopyGridlinerProtocol(Protocol): + """ + Structural protocol for the subset of cartopy Gridliner attributes we use. + This keeps type hints tight without importing cartopy at runtime. + """ + + collection_kwargs: dict[str, Any] + xlabel_style: dict[str, Any] + ylabel_style: dict[str, Any] + xlocator: mticker.Locator + ylocator: mticker.Locator + xpadding: float | None + ypadding: float | None + xlines: bool + ylines: bool + x_inline: bool | None + y_inline: bool | None + rotate_labels: bool | None + inline_labels: bool | str | None + geo_labels: bool | str | None + left_label_artists: list[mtext.Text] + right_label_artists: list[mtext.Text] + bottom_label_artists: list[mtext.Text] + top_label_artists: list[mtext.Text] + xline_artists: list[Any] + + def _axes_domain(self, *args: Any, **kwargs: Any) -> tuple[Any, Any]: ... + def _draw_gridliner(self, *args: Any, **kwargs: Any) -> Any: ... + + +class _CartopyGridlinerAdapter(_GridlinerAdapter): + """ + Adapter for cartopy's Gridliner, translating common label/style operations + into the Gridliner API while hiding cartopy version differences. + """ + + def __init__(self, gridliner: Optional[_CartopyGridlinerProtocol]) -> None: + self.gridliner = gridliner + + @staticmethod + def _side_labels() -> tuple[str, str, str, str]: + # Cartopy label attribute names vary by version. + if _version_cartopy >= "0.18": + left_labels = "left_labels" + right_labels = "right_labels" + bottom_labels = "bottom_labels" + top_labels = "top_labels" + else: # cartopy < 0.18 + left_labels = "ylabels_left" + right_labels = "ylabels_right" + bottom_labels = "xlabels_bottom" + top_labels = "xlabels_top" + return (left_labels, right_labels, bottom_labels, top_labels) + + def labels_for_sides( + self, + *, + bottom: bool | str | None = None, + top: bool | str | None = None, + left: bool | str | None = None, + right: bool | str | None = None, + ) -> dict[str, list[mtext.Text]]: + sides = {} + gl = self.gridliner + if gl is None: + return sides + for dir, side in zip( + "bottom top left right".split(), [bottom, top, left, right] + ): + if side != True: + continue + sides[dir] = getattr(gl, f"{dir}_label_artists") + return sides + + def toggle_labels( + self, + *, + labelleft: bool | str | None = None, + labelright: bool | str | None = None, + labelbottom: bool | str | None = None, + labeltop: bool | str | None = None, + geo: bool | str | None = None, + ) -> None: + gl = self.gridliner + if gl is None: + return + side_labels = self._side_labels() + togglers = (labelleft, labelright, labelbottom, labeltop) + for toggle, side in zip(togglers, side_labels): + if toggle is not None: + setattr(gl, side, toggle) + if geo is not None: # only cartopy 0.20 supported but harmless + setattr(gl, "geo_labels", geo) + + def apply_style( + self, + *, + axis: str = "both", + pad: float | None = None, + labelsize: float | str | None = None, + labelcolor: Any = None, + labelrotation: float | None = None, + linecolor: Any = None, + linewidth: float | None = None, + ) -> None: + gl = self.gridliner + if gl is None: + return + + def _apply_label_style(style: dict[str, Any]) -> None: + if labelcolor is not None: + style["color"] = labelcolor + if labelsize is not None: + style["fontsize"] = labelsize + if labelrotation is not None: + style["rotation"] = labelrotation + + # Cartopy line styling is stored in the collection kwargs. + if linecolor is not None: + gl.collection_kwargs["color"] = linecolor + if linewidth is not None: + gl.collection_kwargs["linewidth"] = linewidth + if axis in ("x", "both"): + _apply_label_style(gl.xlabel_style) + if pad is not None and hasattr(gl, "xpadding"): + gl.xpadding = pad + if axis in ("y", "both"): + _apply_label_style(gl.ylabel_style) + if pad is not None and hasattr(gl, "ypadding"): + gl.ypadding = pad + + def tick_positions( + self, axis: str, *, lonaxis: _GeoAxis, lataxis: _GeoAxis + ) -> np.ndarray: + gl = self.gridliner + if gl is None: + return np.asarray([]) + if axis == "x": + locator = gl.xlocator + if locator is None: + return np.asarray([]) + return lonaxis._get_ticklocs(locator) + if axis == "y": + locator = gl.ylocator + if locator is None: + return np.asarray([]) + return lataxis._get_ticklocs(locator) + raise ValueError(f"Invalid axis: {axis!r}") + + def is_label_on(self, side: str) -> bool: + gl = self.gridliner + if gl is None: + return False + left_labels, right_labels, bottom_labels, top_labels = self._side_labels() + if side == "labelleft": + return getattr(gl, left_labels) + elif side == "labelright": + return getattr(gl, right_labels) + elif side == "labelbottom": + return getattr(gl, bottom_labels) + elif side == "labeltop": + return getattr(gl, top_labels) + else: + raise ValueError(f"Invalid side: {side}") + + +class _BasemapGridlinerAdapter(_GridlinerAdapter): + """ + Adapter for basemap meridian/parallel dictionaries, emulating the subset + of cartopy Gridliner behavior needed by GeoAxes (labels, toggles, styling). + """ + + def __init__( + self, + lonlines: GridlineDict | None, + latlines: GridlineDict | None, + ) -> None: + self.lonlines = lonlines + self.latlines = latlines + + def labels_for_sides( + self, + *, + bottom: bool | str | None = None, + top: bool | str | None = None, + left: bool | str | None = None, + right: bool | str | None = None, + ) -> dict[str, list[mtext.Text]]: + directions = "left right top bottom".split() + bools = [left, right, top, bottom] + sides = {} + for direction, is_on in zip(directions, bools): + if is_on is None: + continue + gl = self.lonlines + if direction in ["left", "right"]: + gl = self.latlines + for loc, (lines, labels) in (gl or {}).items(): + for label in labels: + # Determine side by label position (Basemap clusters by location). + position = label.get_position() + match direction: + case "top" if position[1] > 0: + add = True + case "bottom" if position[1] < 0: + add = True + case "left" if position[0] < 0: + add = True + case "right" if position[0] > 0: + add = True + case _: + add = False + if add: + sides.setdefault(direction, []).append(label) + return sides + + def toggle_labels( + self, + *, + labelleft: bool | str | None = None, + labelright: bool | str | None = None, + labelbottom: bool | str | None = None, + labeltop: bool | str | None = None, + geo: bool | str | None = None, + ) -> None: + labels = self.labels_for_sides( + bottom=labelbottom, top=labeltop, left=labelleft, right=labelright + ) + toggles = { + "bottom": labelbottom, + "top": labeltop, + "left": labelleft, + "right": labelright, + } + for direction, toggle in toggles.items(): + if toggle is None: + continue + for label in labels.get(direction, []): + label.set_visible(bool(toggle) or toggle in ("x", "y")) + + def apply_style( + self, + *, + axis: str = "both", + pad: float | None = None, + labelsize: float | str | None = None, + labelcolor: Any = None, + labelrotation: float | None = None, + linecolor: Any = None, + linewidth: float | None = None, + ) -> None: + pad # unused for basemap gridlines + targets = [] + if axis in ("x", "both"): + targets.append(self.lonlines) + if axis in ("y", "both"): + targets.append(self.latlines) + for gl in targets: + for loc, (lines, labels) in (gl or {}).items(): + # Basemap stores line artists and label text separately. + for line in lines: + if linecolor is not None and hasattr(line, "set_color"): + line.set_color(linecolor) + if linewidth is not None and hasattr(line, "set_linewidth"): + line.set_linewidth(linewidth) + for label in labels: + if labelcolor is not None: + label.set_color(labelcolor) + if labelsize is not None: + label.set_fontsize(labelsize) + if labelrotation is not None: + label.set_rotation(labelrotation) + + def tick_positions( + self, axis: str, *, lonaxis: _GeoAxis, lataxis: _GeoAxis + ) -> np.ndarray: + lonaxis, lataxis # unused; tick positions are stored in dict keys + if axis == "x": + locator = self.lonlines + elif axis == "y": + locator = self.latlines + else: + raise ValueError(f"Invalid axis: {axis!r}") + if not locator: + return np.asarray([]) + return np.asarray(list(locator.keys())) + + def is_label_on(self, side: str) -> bool: + def group_labels( + labels: list[mtext.Text], + which: str, + labelbottom: bool | str | None = None, + labeltop: bool | str | None = None, + labelleft: bool | str | None = None, + labelright: bool | str | None = None, + ) -> dict[str, list[mtext.Text]]: + group = {} + for label in labels: + position = label.get_position() + target = None + if which == "x": + if labelbottom is not None and position[1] < 0: + target = "labelbottom" + elif labeltop is not None and position[1] >= 0: + target = "labeltop" + else: + if labelleft is not None and position[0] < 0: + target = "labelleft" + elif labelright is not None and position[0] >= 0: + target = "labelright" + if target is not None: + group[target] = group.get(target, []) + [label] + return group + + gl = self.lonlines + which = "x" + if side in ["labelleft", "labelright"]: + gl = self.latlines + which = "y" + for loc, (line, labels) in (gl or {}).items(): + grouped = group_labels( + labels=labels, + which=which, + **{side: True}, + ) + for label in grouped.get(side, []): + if label.get_visible(): + return True + return False + + class _LonAxis(_GeoAxis): """ Axis with default longitude locator. @@ -363,7 +800,7 @@ class _LonAxis(_GeoAxis): # NOTE: Basemap accepts tick formatters with drawmeridians(fmt=Formatter()) # Try to use cartopy formatter if cartopy installed. Otherwise use # default builtin basemap formatting. - def __init__(self, axes): + def __init__(self, axes: "GeoAxes") -> None: super().__init__(axes) if self._use_dms: locator = formatter = "dmslon" @@ -376,7 +813,7 @@ def __init__(self, axes): self.set_major_locator(constructor.Locator(locator), default=True) self.set_minor_locator(mticker.AutoMinorLocator(), default=True) - def _get_ticklocs(self, locator): + def _get_ticklocs(self, locator: mticker.Locator) -> np.ndarray: # Prevent ticks from looping around # NOTE: Cartopy 0.17 formats numbers offset by eps with the cardinal indicator # (e.g. 0 degrees for map centered on 180 degrees). So skip in that case. @@ -413,7 +850,7 @@ def _get_ticklocs(self, locator): return ticks - def get_view_interval(self): + def get_view_interval(self) -> tuple[float, float]: # NOTE: ultraplot tries to set its *own* view intervals to avoid dateline # weirdness, but if rc['geo.extent'] is 'auto' the interval will be unset. # In this case we use _get_extent() as a backup. @@ -431,7 +868,7 @@ class _LatAxis(_GeoAxis): axis_name = "lat" - def __init__(self, axes, latmax=90): + def __init__(self, axes: "GeoAxes", latmax: float = 90) -> None: # NOTE: Need to pass projection because lataxis/lonaxis are # initialized before geoaxes is initialized, because format() needs # the axes and format() is called by ultraplot.axes.Axes.__init__() @@ -445,7 +882,7 @@ def __init__(self, axes, latmax=90): self.set_major_locator(constructor.Locator(locator), default=True) self.set_minor_locator(mticker.AutoMinorLocator(), default=True) - def _get_ticklocs(self, locator): + def _get_ticklocs(self, locator: mticker.Locator) -> np.ndarray: # Adjust latitude ticks to fix bug in some projections. Harmless for basemap. # NOTE: Maybe this was fixed by cartopy 0.18? eps = 1e-10 @@ -467,20 +904,64 @@ def _get_ticklocs(self, locator): return ticks - def get_latmax(self): + def get_latmax(self) -> float: return self._latmax - def get_view_interval(self): + def get_view_interval(self) -> tuple[float, float]: interval = self._interval if interval is None: extent = self._get_extent() interval = extent[2:] # latitudes return interval - def set_latmax(self, latmax): + def set_latmax(self, latmax: float) -> None: self._latmax = latmax +def _gridliner_sides_from_arrays( + lonarray: Sequence[bool | None] | None, + latarray: Sequence[bool | None] | None, + *, + order: Sequence[str], + allow_xy: bool, + include_false: bool, +) -> dict[str, bool | str]: + """ + Map lon/lat label arrays to gridliner toggle flags. + + Parameters + ---------- + allow_xy + Use "x"/"y" to preserve axis-specific toggles when only one of lon/lat + is enabled for a given side (cartopy behavior). + include_false + Include explicit False entries to actively hide existing labels instead + of leaving previous state untouched (backend-dependent behavior). + """ + if lonarray is None or latarray is None: + return {} + sides: dict[str, bool | str] = {} + for side, lon, lat in zip(order, lonarray, latarray): + value: bool | str | None = None + if allow_xy: + if lon and lat: + value = True + elif lon: + value = "x" + elif lat: + value = "y" + elif include_false and (lon is not None or lat is not None): + value = False + else: + if lon or lat: + value = True + elif include_false and (lon is not None or lat is not None): + value = False + if value is not None: + sides[side] = value + return sides + + class GeoAxes(shared._SharedAxes, plot.PlotAxes): """ Axes subclass for plotting in geographic projections. Uses either cartopy @@ -509,7 +990,7 @@ class GeoAxes(shared._SharedAxes, plot.PlotAxes): """ @docstring._snippet_manager - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: """ Parameters ---------- @@ -535,17 +1016,19 @@ def __init__(self, *args, **kwargs): ultraplot.figure.Figure.subplot ultraplot.figure.Figure.add_subplot """ + # Cache of backend-specific gridliner adapters (major/minor). + self._gridliner_adapters: dict[str, _GridlinerAdapter] = {} super().__init__(*args, **kwargs) @override - def _sharey_limits(self, sharey: "GeoAxes"): + def _sharey_limits(self, sharey: "GeoAxes") -> None: return self._share_limits_with(sharey, which="y") @override - def _sharex_limits(self, sharex: "GeoAxes"): + def _sharex_limits(self, sharex: "GeoAxes") -> None: return self._share_limits_with(sharex, which="x") - def _share_limits_with(self, other: "GeoAxes", which: str): + def _share_limits_with(self, other: "GeoAxes", which: str) -> None: """ Safely share limits and tickers without resetting things. """ @@ -563,7 +1046,7 @@ def _share_limits_with(self, other: "GeoAxes", which: str): getattr(self, f"share{which}")(other) this_ax._copy_locator_properties(other_ax) - def _is_rectilinear(self): + def _is_rectilinear(self) -> bool: return _is_rectilinear_projection(self) def __share_axis_setup( @@ -573,7 +1056,7 @@ def __share_axis_setup( which: str, labels: bool, limits: bool, - ): + ) -> None: level = getattr(self.figure, f"_share{which}") if getattr(self, f"_panel_share{which}_group") and self._is_panel_group_member( other @@ -595,7 +1078,9 @@ def __share_axis_setup( self._share_limits_with(other, which=which) @override - def _sharey_setup(self, sharey, *, labels=True, limits=True): + def _sharey_setup( + self, sharey: "GeoAxes", *, labels: bool = True, limits: bool = True + ) -> None: """ Configure shared axes accounting for panels. The input is the 'parent' axes, from which this one will draw its properties. @@ -604,12 +1089,14 @@ def _sharey_setup(self, sharey, *, labels=True, limits=True): return self.__share_axis_setup(sharey, which="y", labels=labels, limits=limits) @override - def _sharex_setup(self, sharex, *, labels=True, limits=True): + def _sharex_setup( + self, sharex: "GeoAxes", *, labels: bool = True, limits: bool = True + ) -> None: # Share panels across *different* subplots super()._sharex_setup(sharex, labels=labels, limits=limits) return self.__share_axis_setup(sharex, which="x", labels=labels, limits=limits) - def _toggle_ticks(self, label: "str | None", which: str): + def _toggle_ticks(self, label: str | None, which: str) -> None: """ Ticks are controlled by matplotlib independent of the backend. We can toggle ticks on and of depending on the desired position. """ @@ -647,7 +1134,113 @@ def _toggle_ticks(self, label: "str | None", which: str): f"Not toggling {label=}. Input was not understood. Valid values are ['left', 'right', 'top', 'bottom', 'all', 'both']" ) - def _apply_axis_sharing(self): + def _set_gridliner_adapter( + self, which: str, adapter: Optional[_GridlinerAdapter] + ) -> None: + if adapter is None: + self._gridliner_adapters.pop(which, None) + else: + self._gridliner_adapters[which] = adapter + + def _get_gridliner_adapter(self, which: str) -> Optional[_GridlinerAdapter]: + return self._gridliner_adapters.get(which) + + def _gridliner_adapter( + self, which: str, *, create: bool = True + ) -> Optional[_GridlinerAdapter]: + """ + Return a cached gridliner adapter, optionally creating it via the backend + builder when missing. + """ + adapter = self._get_gridliner_adapter(which) + if adapter is None and create: + builder = getattr(self, "_build_gridliner_adapter", None) + if builder is not None: + adapter = builder(which) + self._set_gridliner_adapter(which, adapter) + return adapter + + def _iter_gridliner_adapters(self, which: str) -> Iterator[_GridlinerAdapter]: + """ + Yield available gridliner adapters for the requested tick selection. + """ + if which in ("major", "both"): + adapter = self._gridliner_adapter("major") + if adapter is not None: + yield adapter + if which in ("minor", "both"): + adapter = self._gridliner_adapter("minor") + if adapter is not None: + yield adapter + + def _gridliner_tick_positions( + self, axis: str, *, which: str = "major" + ) -> np.ndarray: + """ + Return tick positions from the backend gridliner for a given axis. + """ + if axis not in ("x", "y"): + raise ValueError(f"Invalid axis: {axis!r}") + adapter = self._gridliner_adapter(which) + if adapter is None: + return np.asarray([]) + return adapter.tick_positions( + axis, lonaxis=self._lonaxis, lataxis=self._lataxis + ) + + @override + def tick_params(self, *args: Any, **kwargs: Any) -> Any: + """ + Apply tick parameters and mirror a subset of settings onto the backend + gridliner artists so gridline labels respond to common tick tweaks. + """ + result = super().tick_params(*args, **kwargs) + + axis = kwargs.get("axis", "both") + which = kwargs.get("which", "major") + pad = kwargs.get("pad", None) + labelsize = kwargs.get("labelsize", None) + labelcolor = kwargs.get( + "labelcolor", kwargs.get("colors", kwargs.get("color", None)) + ) + labelrotation = kwargs.get("labelrotation", None) + linecolor = kwargs.get("colors", kwargs.get("color", None)) + linewidth = kwargs.get("width", kwargs.get("linewidth", None)) + + adapters = tuple(self._iter_gridliner_adapters(which)) + if not adapters: + return result + + for adapter in adapters: + adapter.apply_style( + axis=axis, + pad=pad, + labelsize=labelsize, + labelcolor=labelcolor, + labelrotation=labelrotation, + linecolor=linecolor, + linewidth=linewidth, + ) + + # Toggle label visibility for major gridliners when requested. + if which in ("major", "both"): + adapter = self._gridliner_adapter("major") + toggles = {} + if axis in ("x", "both"): + for key in ("labelbottom", "labeltop"): + if key in kwargs: + toggles[key] = kwargs[key] + if axis in ("y", "both"): + for key in ("labelleft", "labelright"): + if key in kwargs: + toggles[key] = kwargs[key] + if toggles and adapter is not None: + adapter.toggle_labels(**toggles) + + self.stale = True + return result + + def _apply_axis_sharing(self) -> None: """ Enforce the "shared" axis labels and axis tick labels. If this is not called at drawtime, "shared" labels can be inadvertantly turned off. @@ -690,7 +1283,7 @@ def _apply_axis_sharing(self): self._lataxis.set_view_interval(*self._sharey._lataxis.get_view_interval()) self._lataxis.set_minor_locator(self._sharey._lataxis.get_minor_locator()) - def _apply_aspect_and_adjust_panels(self, *, tol=1e-9): + def _apply_aspect_and_adjust_panels(self, *, tol: float = 1e-9) -> None: """ Apply aspect and then align panels to the adjusted axes box. @@ -702,7 +1295,7 @@ def _apply_aspect_and_adjust_panels(self, *, tol=1e-9): self.apply_aspect() self._adjust_panel_positions(tol=tol) - def _adjust_panel_positions(self, *, tol=1e-9): + def _adjust_panel_positions(self, *, tol: float = 1e-9) -> None: """ Adjust panel positions to align with the aspect-constrained main axes. After apply_aspect() shrinks the main axes, panels should flank the actual @@ -828,23 +1421,32 @@ def _adjust_panel_positions(self, *, tol=1e-9): def _get_gridliner_labels( self, - bottom=None, - top=None, - left=None, - right=None, - ): - raise NotImplementedError("Should be implemented by Cartopy or Basemap Axes") + bottom: bool | str | None = None, + top: bool | str | None = None, + left: bool | str | None = None, + right: bool | str | None = None, + ) -> dict[str, list[mtext.Text]]: + adapter = self._gridliner_adapter("major") + if adapter is None: + return {} + return adapter.labels_for_sides( + bottom=bottom, + top=top, + left=left, + right=right, + ) def _toggle_gridliner_labels( self, - labeltop=None, - labelbottom=None, - labelleft=None, - labelright=None, - geo=None, - ): + labeltop: bool | str | None = None, + labelbottom: bool | str | None = None, + labelleft: bool | str | None = None, + labelright: bool | str | None = None, + geo: bool | str | None = None, + ) -> None: """ - Toggle visibility of gridliner labels for each direction. + Toggle visibility of gridliner labels for each direction via the backend + adapter. Parameters ---------- @@ -853,29 +1455,29 @@ def _toggle_gridliner_labels( geo : optional Not used in this method. """ - # Ensure gridlines_major is fully initialized - if any(i is None for i in self.gridlines_major): + adapter = self._gridliner_adapter("major") + if adapter is None: return - - gridlabels = self._get_gridliner_labels( - bottom=labelbottom, top=labeltop, left=labelleft, right=labelright + adapter.toggle_labels( + labelleft=labelleft, + labelright=labelright, + labelbottom=labelbottom, + labeltop=labeltop, + geo=geo, ) - toggles = { - "bottom": labelbottom, - "top": labeltop, - "left": labelleft, - "right": labelright, - } - - for direction, toggle in toggles.items(): - if toggle is None: - continue - for label in gridlabels.get(direction, []): - label.set_visible(bool(toggle) or toggle in ("x", "y")) + @override + def _is_ticklabel_on(self, side: str) -> bool: + """ + Check if tick labels are visible on the requested side via the backend adapter. + """ + adapter = self._gridliner_adapter("major") + if adapter is None: + return False + return adapter.is_label_on(side) @override - def draw(self, renderer=None, *args, **kwargs): + def draw(self, renderer: Any = None, *args: Any, **kwargs: Any) -> None: # Perform extra post-processing steps # NOTE: In *principle* axis sharing application step goes here. But should # already be complete because auto_layout() (called by figure pre-processor) @@ -883,7 +1485,7 @@ def draw(self, renderer=None, *args, **kwargs): self._apply_axis_sharing() super().draw(renderer, *args, **kwargs) - def _get_lonticklocs(self, which="major"): + def _get_lonticklocs(self, which: str = "major") -> np.ndarray: """ Retrieve longitude tick locations. """ @@ -898,7 +1500,7 @@ def _get_lonticklocs(self, which="major"): lines = axis.get_minorticklocs() return lines - def _get_latticklocs(self, which="major"): + def _get_latticklocs(self, which: str = "major") -> np.ndarray: """ Retrieve latitude tick locations. """ @@ -909,7 +1511,7 @@ def _get_latticklocs(self, which="major"): lines = axis.get_minorticklocs() return lines - def _set_view_intervals(self, extent): + def _set_view_intervals(self, extent: Sequence[float]) -> None: """ Update view intervals for lon and lat axis. """ @@ -917,7 +1519,7 @@ def _set_view_intervals(self, extent): self._lataxis.set_view_interval(*extent[2:]) @staticmethod - def _to_label_array(arg, lon=True): + def _to_label_array(arg: Any, lon: bool = True) -> list[bool | None]: """ Convert labels argument to length-5 boolean array. """ @@ -952,6 +1554,7 @@ def _to_label_array(arg, lon=True): for char in string: array["lrbtg".index(char)] = True if rc["grid.geolabels"] and any(array): + # Geo labels only apply if any edge labels are enabled. array[4] = True # possibly toggle geo spine labels elif not any(isinstance(_, str) for _ in array): if len(array) == 1: @@ -964,68 +1567,393 @@ def _to_label_array(arg, lon=True): if rc["grid.geolabels"] else None ) - array.append(b) - if len(array) != 5: - raise ValueError(f"Invald boolean label array length {len(array)}.") - else: - raise ValueError(f"Invalid {which}label spec: {arg}.") - return array + array.append(b) + if len(array) != 5: + raise ValueError(f"Invald boolean label array length {len(array)}.") + else: + raise ValueError(f"Invalid {which}label spec: {arg}.") + return array + + def _format_init_basemap_boundary(self) -> None: + """ + Initialize basemap boundaries before format triggers gridline work. + + Basemap can create a hidden boundary when gridlines are drawn before the + map boundary is initialized, so we force initialization here. + """ + if self._name != "basemap" or self._map_boundary is not None: + return + if self.projection.projection in self._proj_non_rectangular: + patch = self.projection.drawmapboundary(ax=self) + self._map_boundary = patch + else: + self.projection.set_axes_limits(self) # initialize aspect ratio + self._map_boundary = object() # sentinel + + def _format_rc_context( + self, + kwargs: MutableMapping[str, Any], + *, + ticklen: Any, + labelcolor: Any, + labelsize: Any, + labelweight: Any, + ) -> tuple[dict[str, Any], int, Any]: + """ + Pop rc overrides and prepare context settings for format(). + """ + rc_kw, rc_mode = _pop_rc(kwargs) + ticklen = _not_none(ticklen, rc_kw.get("tick.len", None)) + labelcolor = _not_none(labelcolor, kwargs.get("color", None)) + if labelcolor is not None: + rc_kw["grid.labelcolor"] = labelcolor + if labelsize is not None: + rc_kw["grid.labelsize"] = labelsize + if labelweight is not None: + rc_kw["grid.labelweight"] = labelweight + return rc_kw, rc_mode, ticklen + + def _format_normalize_label_inputs( + self, + *, + labels: Any, + lonlabels: Any, + latlabels: Any, + loninline: bool | None, + latinline: bool | None, + inlinelabels: bool | None, + ) -> tuple[Any, Any]: + """ + Normalize label inputs before rc context is applied. + """ + lonlabels = _not_none(lonlabels, labels) + latlabels = _not_none(latlabels, labels) + if "0.18" <= _version_cartopy < "0.20": + lonlabels = _not_none(lonlabels, loninline, inlinelabels) + latlabels = _not_none(latlabels, latinline, inlinelabels) + return lonlabels, latlabels + + def _format_resolve_label_arrays( + self, *, labels: Any, lonlabels: Any, latlabels: Any + ) -> tuple[Any, Any, list[bool | None], list[bool | None]]: + """ + Resolve label toggles and return label arrays for gridliners. + """ + if lonlabels is None and latlabels is None: + labels = _not_none(labels, rc.find("grid.labels", context=True)) + lonlabels = labels + latlabels = labels + else: + lonlabels = _not_none(lonlabels, labels) + latlabels = _not_none(latlabels, labels) + + self._toggle_ticks(lonlabels, "x") + self._toggle_ticks(latlabels, "y") + lonarray = self._to_label_array(lonlabels, lon=True) + latarray = self._to_label_array(latlabels, lon=False) + return lonlabels, latlabels, lonarray, latarray + + def _format_update_latmax(self, latmax: float | None) -> None: + """ + Update the latitude gridline cutoff. + """ + latmax = _not_none(latmax, rc.find("grid.latmax", context=True)) + if latmax is not None: + self._lataxis.set_latmax(latmax) + + def _format_update_major_locators( + self, + *, + lonlocator: Any, + lonlines: Any, + latlocator: Any, + latlines: Any, + lonlocator_kw: MutableMapping | None, + lonlines_kw: MutableMapping | None, + latlocator_kw: MutableMapping | None, + latlines_kw: MutableMapping | None, + ) -> None: + """ + Update major longitude/latitude locators. + """ + lonlocator = _not_none(lonlocator=lonlocator, lonlines=lonlines) + latlocator = _not_none(latlocator=latlocator, latlines=latlines) + if lonlocator is not None: + lonlocator_kw = _not_none( + lonlocator_kw=lonlocator_kw, + lonlines_kw=lonlines_kw, + default={}, + ) + locator = constructor.Locator(lonlocator, **lonlocator_kw) + self._lonaxis.set_major_locator(locator) + if latlocator is not None: + latlocator_kw = _not_none( + latlocator_kw=latlocator_kw, + latlines_kw=latlines_kw, + default={}, + ) + locator = constructor.Locator(latlocator, **latlocator_kw) + self._lataxis.set_major_locator(locator) + + def _format_update_minor_locators( + self, + *, + lonminorlocator: Any, + lonminorlines: Any, + latminorlocator: Any, + latminorlines: Any, + lonminorlocator_kw: MutableMapping | None, + lonminorlines_kw: MutableMapping | None, + latminorlocator_kw: MutableMapping | None, + latminorlines_kw: MutableMapping | None, + ) -> None: + """ + Update minor longitude/latitude locators. + """ + lonminorlocator = _not_none( + lonminorlocator=lonminorlocator, lonminorlines=lonminorlines + ) + latminorlocator = _not_none( + latminorlocator=latminorlocator, latminorlines=latminorlines + ) + if lonminorlocator is not None: + lonminorlocator_kw = _not_none( + lonminorlocator_kw=lonminorlocator_kw, + lonminorlines_kw=lonminorlines_kw, + default={}, + ) + locator = constructor.Locator(lonminorlocator, **lonminorlocator_kw) + self._lonaxis.set_minor_locator(locator) + if latminorlocator is not None: + latminorlocator_kw = _not_none( + latminorlocator_kw=latminorlocator_kw, + latminorlines_kw=latminorlines_kw, + default={}, + ) + locator = constructor.Locator(latminorlocator, **latminorlocator_kw) + self._lataxis.set_minor_locator(locator) + + def _format_resolve_gridline_params( + self, + *, + loninline: bool | None, + latinline: bool | None, + inlinelabels: bool | None, + rotatelabels: bool | None, + labelrotation: float | None, + lonlabelrotation: float | None, + latlabelrotation: float | None, + labelpad: Any, + dms: bool | None, + nsteps: int | None, + ) -> tuple[ + bool | None, + bool | None, + bool | None, + float | None, + float | None, + Any, + bool | None, + int | None, + ]: + """ + Resolve gridline-related parameters with rc defaults. + """ + loninline = _not_none( + loninline, inlinelabels, rc.find("grid.inlinelabels", context=True) + ) + latinline = _not_none( + latinline, inlinelabels, rc.find("grid.inlinelabels", context=True) + ) + rotatelabels = _not_none( + rotatelabels, rc.find("grid.rotatelabels", context=True) + ) + lonlabelrotation = _not_none(lonlabelrotation, labelrotation) + latlabelrotation = _not_none(latlabelrotation, labelrotation) + labelpad = _not_none(labelpad, rc.find("grid.labelpad", context=True)) + dms = _not_none(dms, rc.find("grid.dmslabels", context=True)) + nsteps = _not_none(nsteps, rc.find("grid.nsteps", context=True)) + return ( + loninline, + latinline, + rotatelabels, + lonlabelrotation, + latlabelrotation, + labelpad, + dms, + nsteps, + ) + + def _format_update_formatters( + self, + *, + lonformatter: Any, + latformatter: Any, + lonformatter_kw: MutableMapping | None, + latformatter_kw: MutableMapping | None, + dms: bool | None, + ) -> None: + """ + Update longitude/latitude formatters and DMS flags. + """ + if lonformatter is not None: + lonformatter_kw = lonformatter_kw or {} + formatter = constructor.Formatter(lonformatter, **lonformatter_kw) + self._lonaxis.set_major_formatter(formatter) + if latformatter is not None: + latformatter_kw = latformatter_kw or {} + formatter = constructor.Formatter(latformatter, **latformatter_kw) + self._lataxis.set_major_formatter(formatter) + if dms is not None: # harmless if these are not GeoLocators + self._lonaxis.get_major_formatter()._dms = dms + self._lataxis.get_major_formatter()._dms = dms + self._lonaxis.get_major_locator()._dms = dms + self._lataxis.get_major_locator()._dms = dms + + def _format_apply_grid_updates( + self, + *, + lonlim: tuple[float | None, float | None] | None, + latlim: tuple[float | None, float | None] | None, + boundinglat: float | None, + longrid: bool | None, + latgrid: bool | None, + longridminor: bool | None, + latgridminor: bool | None, + lonarray: Sequence[bool | None], + latarray: Sequence[bool | None], + loninline: bool | None, + latinline: bool | None, + rotatelabels: bool | None, + lonlabelrotation: float | None, + latlabelrotation: float | None, + labelpad: Any, + nsteps: int | None, + ) -> tuple[tuple[float | None, float | None], tuple[float | None, float | None]]: + """ + Apply extent, features, and gridline updates for format(). + """ + lonlim = _not_none(lonlim, default=(None, None)) + latlim = _not_none(latlim, default=(None, None)) + self._update_extent(lonlim=lonlim, latlim=latlim, boundinglat=boundinglat) + self._update_features() + self._update_major_gridlines( + longrid=longrid, + latgrid=latgrid, # gridline toggles + lonarray=lonarray, + latarray=latarray, # label toggles + loninline=loninline, + latinline=latinline, + rotatelabels=rotatelabels, + lonlabelrotation=lonlabelrotation, + latlabelrotation=latlabelrotation, + labelpad=labelpad, + nsteps=nsteps, + ) + self._update_minor_gridlines( + longrid=longridminor, + latgrid=latgridminor, + nsteps=nsteps, + ) + return lonlim, latlim + + def _format_apply_ticklen( + self, + *, + lonlim: tuple[float | None, float | None], + latlim: tuple[float | None, float | None], + boundinglat: float | None, + ticklen: Any, + lonticklen: Any, + latticklen: Any, + ) -> None: + """ + Apply tick length updates, including any extent refresh for geoticks. + """ + lonticklen = _not_none(lonticklen, ticklen) + latticklen = _not_none(latticklen, ticklen) + + if lonticklen or latticklen: + # Only add warning when ticks are given + if _is_rectilinear_projection(self): + self._add_geoticks("x", lonticklen, ticklen) + self._add_geoticks("y", latticklen, ticklen) + # If latlim is set to None it resets + # the view; this affects the visible range + # we need to force this to prevent + # side effects + if latlim == (None, None): + latlim = self._lataxis.get_view_interval() + if lonlim == (None, None): + lonlim = self._lonaxis.get_view_interval() + self._update_extent( + lonlim=lonlim, latlim=latlim, boundinglat=boundinglat + ) + else: + warnings._warn_ultraplot( + f"Projection is not rectilinear. Ignoring {lonticklen=} and {latticklen=} settings." + ) + # Format flow: + # 1) init basemap boundary + # 2) enter rc context and resolve label/locator/formatter inputs + # 3) apply extent, features, and gridlines + # 4) apply tick lengths and defer to parent format @docstring._snippet_manager def format( self, *, - extent=None, - round=None, - lonlim=None, - latlim=None, - boundinglat=None, - longrid=None, - latgrid=None, - longridminor=None, - latgridminor=None, - ticklen=None, - lonticklen=None, - latticklen=None, - latmax=None, - nsteps=None, - lonlocator=None, - lonlines=None, - latlocator=None, - latlines=None, - lonminorlocator=None, - lonminorlines=None, - latminorlocator=None, - latminorlines=None, - lonlocator_kw=None, - lonlines_kw=None, - latlocator_kw=None, - latlines_kw=None, - lonminorlocator_kw=None, - lonminorlines_kw=None, - latminorlocator_kw=None, - latminorlines_kw=None, - lonformatter=None, - latformatter=None, - lonformatter_kw=None, - latformatter_kw=None, - labels=None, - latlabels=None, - lonlabels=None, - rotatelabels=None, - labelrotation=None, - lonlabelrotation=None, - latlabelrotation=None, - loninline=None, - latinline=None, - inlinelabels=None, - dms=None, - labelpad=None, - labelcolor=None, - labelsize=None, - labelweight=None, - **kwargs, - ): + extent: str | None = None, + round: bool | None = None, + lonlim: tuple[float | None, float | None] | None = None, + latlim: tuple[float | None, float | None] | None = None, + boundinglat: float | None = None, + longrid: bool | None = None, + latgrid: bool | None = None, + longridminor: bool | None = None, + latgridminor: bool | None = None, + ticklen: Any = None, + lonticklen: Any = None, + latticklen: Any = None, + latmax: float | None = None, + nsteps: int | None = None, + lonlocator: Any = None, + lonlines: Any = None, + latlocator: Any = None, + latlines: Any = None, + lonminorlocator: Any = None, + lonminorlines: Any = None, + latminorlocator: Any = None, + latminorlines: Any = None, + lonlocator_kw: MutableMapping | None = None, + lonlines_kw: MutableMapping | None = None, + latlocator_kw: MutableMapping | None = None, + latlines_kw: MutableMapping | None = None, + lonminorlocator_kw: MutableMapping | None = None, + lonminorlines_kw: MutableMapping | None = None, + latminorlocator_kw: MutableMapping | None = None, + latminorlines_kw: MutableMapping | None = None, + lonformatter: Any = None, + latformatter: Any = None, + lonformatter_kw: MutableMapping | None = None, + latformatter_kw: MutableMapping | None = None, + labels: Any = None, + latlabels: Any = None, + lonlabels: Any = None, + rotatelabels: bool | None = None, + labelrotation: float | None = None, + lonlabelrotation: float | None = None, + latlabelrotation: float | None = None, + loninline: bool | None = None, + latinline: bool | None = None, + inlinelabels: bool | None = None, + dms: bool | None = None, + labelpad: Any = None, + labelcolor: Any = None, + labelsize: Any = None, + labelweight: Any = None, + **kwargs: Any, + ) -> None: """ Modify map limits, longitude and latitude gridlines, geographic features, and more. @@ -1045,38 +1973,22 @@ def format( ultraplot.axes.Axes.format ultraplot.config.Configurator.context """ - # Initialize map boundary - # WARNING: Normal workflow is Axes.format() does 'universal' tasks including - # updating the map boundary (in the future may also handle gridlines). However - # drawing gridlines before basemap map boundary will call set_axes_limits() - # which initializes a boundary hidden from external access. So we must call - # it here. Must do this between mpl.Axes.__init__() and base.Axes.format(). - # - if self._name == "basemap" and self._map_boundary is None: - if self.projection.projection in self._proj_non_rectangular: - patch = self.projection.drawmapboundary(ax=self) - self._map_boundary = patch - else: - self.projection.set_axes_limits(self) # initialize aspect ratio - self._map_boundary = object() # sentinel - - # Initiate context block - rc_kw, rc_mode = _pop_rc(kwargs) - ticklen = _not_none( - ticklen, rc_kw.get("tick.len", None) - ) # Don't pop this as it will only plot on a singular axis - lonlabels = _not_none(lonlabels, labels) - latlabels = _not_none(latlabels, labels) - if "0.18" <= _version_cartopy < "0.20": - lonlabels = _not_none(lonlabels, loninline, inlinelabels) - latlabels = _not_none(latlabels, latinline, inlinelabels) - labelcolor = _not_none(labelcolor, kwargs.get("color", None)) - if labelcolor is not None: - rc_kw["grid.labelcolor"] = labelcolor - if labelsize is not None: - rc_kw["grid.labelsize"] = labelsize - if labelweight is not None: - rc_kw["grid.labelweight"] = labelweight + self._format_init_basemap_boundary() + lonlabels, latlabels = self._format_normalize_label_inputs( + labels=labels, + lonlabels=lonlabels, + latlabels=latlabels, + loninline=loninline, + latinline=latinline, + inlinelabels=inlinelabels, + ) + rc_kw, rc_mode, ticklen = self._format_rc_context( + kwargs, + ticklen=ticklen, + labelcolor=labelcolor, + labelsize=labelsize, + labelweight=labelweight, + ) with rc.context(rc_kw, mode=rc_mode): # Apply extent mode first # NOTE: We deprecate autoextent on _CartopyAxes with _rename_kwargs which @@ -1090,151 +2002,93 @@ def format( # NOTE: Cartopy 0.18 and 0.19 inline labels require any of # top, bottom, left, or right to be toggled then ignores them. # Later versions of cartopy permit both or neither labels. - if lonlabels is None and latlabels is None: - labels = _not_none(labels, rc.find("grid.labels", context=True)) - lonlabels = labels - latlabels = labels - else: - lonlabels = _not_none(lonlabels, labels) - latlabels = _not_none(latlabels, labels) - # Set the ticks - self._toggle_ticks(lonlabels, "x") - self._toggle_ticks(latlabels, "y") - lonarray = self._to_label_array(lonlabels, lon=True) - latarray = self._to_label_array(latlabels, lon=False) - - # Update max latitude - latmax = _not_none(latmax, rc.find("grid.latmax", context=True)) - if latmax is not None: - self._lataxis.set_latmax(latmax) - - # Update major locators - lonlocator = _not_none(lonlocator=lonlocator, lonlines=lonlines) - latlocator = _not_none(latlocator=latlocator, latlines=latlines) - if lonlocator is not None: - lonlocator_kw = _not_none( - lonlocator_kw=lonlocator_kw, - lonlines_kw=lonlines_kw, - default={}, + lonlabels, latlabels, lonarray, latarray = ( + self._format_resolve_label_arrays( + labels=labels, + lonlabels=lonlabels, + latlabels=latlabels, ) - locator = constructor.Locator(lonlocator, **lonlocator_kw) - self._lonaxis.set_major_locator(locator) - if latlocator is not None: - latlocator_kw = _not_none( - latlocator_kw=latlocator_kw, - latlines_kw=latlines_kw, - default={}, - ) - locator = constructor.Locator(latlocator, **latlocator_kw) - self._lataxis.set_major_locator(locator) - - # Update minor locators - lonminorlocator = _not_none( - lonminorlocator=lonminorlocator, lonminorlines=lonminorlines ) - latminorlocator = _not_none( - latminorlocator=latminorlocator, latminorlines=latminorlines + self._format_update_latmax(latmax) + self._format_update_major_locators( + lonlocator=lonlocator, + lonlines=lonlines, + latlocator=latlocator, + latlines=latlines, + lonlocator_kw=lonlocator_kw, + lonlines_kw=lonlines_kw, + latlocator_kw=latlocator_kw, + latlines_kw=latlines_kw, ) - if lonminorlocator is not None: - lonminorlocator_kw = _not_none( - lonminorlocator_kw=lonminorlocator_kw, - lonminorlines_kw=lonminorlines_kw, - default={}, - ) - locator = constructor.Locator(lonminorlocator, **lonminorlocator_kw) - self._lonaxis.set_minor_locator(locator) - if latminorlocator is not None: - latminorlocator_kw = _not_none( - latminorlocator_kw=latminorlocator_kw, - latminorlines_kw=latminorlines_kw, - default={}, - ) - locator = constructor.Locator(latminorlocator, **latminorlocator_kw) - self._lataxis.set_minor_locator(locator) - - # Update formatters - loninline = _not_none( - loninline, inlinelabels, rc.find("grid.inlinelabels", context=True) - ) # noqa: E501 - latinline = _not_none( - latinline, inlinelabels, rc.find("grid.inlinelabels", context=True) - ) # noqa: E501 - rotatelabels = _not_none( - rotatelabels, rc.find("grid.rotatelabels", context=True) - ) # noqa: E501 - lonlabelrotation = _not_none(lonlabelrotation, labelrotation) - latlabelrotation = _not_none(latlabelrotation, labelrotation) - labelpad = _not_none(labelpad, rc.find("grid.labelpad", context=True)) - dms = _not_none(dms, rc.find("grid.dmslabels", context=True)) - nsteps = _not_none(nsteps, rc.find("grid.nsteps", context=True)) - lon0 = self._get_lon0() - - if lonformatter is not None: - lonformatter_kw = lonformatter_kw or {} - formatter = constructor.Formatter(lonformatter, **lonformatter_kw) - self._lonaxis.set_major_formatter(formatter) - if latformatter is not None: - latformatter_kw = latformatter_kw or {} - formatter = constructor.Formatter(latformatter, **latformatter_kw) - self._lataxis.set_major_formatter(formatter) - if dms is not None: # harmless if these are not GeoLocators - self._lonaxis.get_major_formatter()._dms = dms - self._lataxis.get_major_formatter()._dms = dms - self._lonaxis.get_major_locator()._dms = dms - self._lataxis.get_major_locator()._dms = dms - - # Apply worker extent, feature, and gridline functions - lonlim = _not_none(lonlim, default=(None, None)) - latlim = _not_none(latlim, default=(None, None)) - self._update_extent(lonlim=lonlim, latlim=latlim, boundinglat=boundinglat) - self._update_features() - self._update_major_gridlines( - longrid=longrid, - latgrid=latgrid, # gridline toggles - lonarray=lonarray, - latarray=latarray, # label toggles + self._format_update_minor_locators( + lonminorlocator=lonminorlocator, + lonminorlines=lonminorlines, + latminorlocator=latminorlocator, + latminorlines=latminorlines, + lonminorlocator_kw=lonminorlocator_kw, + lonminorlines_kw=lonminorlines_kw, + latminorlocator_kw=latminorlocator_kw, + latminorlines_kw=latminorlines_kw, + ) + ( + loninline, + latinline, + rotatelabels, + lonlabelrotation, + latlabelrotation, + labelpad, + dms, + nsteps, + ) = self._format_resolve_gridline_params( loninline=loninline, latinline=latinline, + inlinelabels=inlinelabels, rotatelabels=rotatelabels, + labelrotation=labelrotation, lonlabelrotation=lonlabelrotation, latlabelrotation=latlabelrotation, labelpad=labelpad, + dms=dms, nsteps=nsteps, ) - self._update_minor_gridlines( - longrid=longridminor, - latgrid=latgridminor, + self._format_update_formatters( + lonformatter=lonformatter, + latformatter=latformatter, + lonformatter_kw=lonformatter_kw, + latformatter_kw=latformatter_kw, + dms=dms, + ) + lonlim, latlim = self._format_apply_grid_updates( + lonlim=lonlim, + latlim=latlim, + boundinglat=boundinglat, + longrid=longrid, + latgrid=latgrid, + longridminor=longridminor, + latgridminor=latgridminor, + lonarray=lonarray, + latarray=latarray, + loninline=loninline, + latinline=latinline, + rotatelabels=rotatelabels, + lonlabelrotation=lonlabelrotation, + latlabelrotation=latlabelrotation, + labelpad=labelpad, nsteps=nsteps, ) - # Set tick lengths for flat projections - lonticklen = _not_none(lonticklen, ticklen) - latticklen = _not_none(latticklen, ticklen) - - if lonticklen or latticklen: - # Only add warning when ticks are given - if _is_rectilinear_projection(self): - self._add_geoticks("x", lonticklen, ticklen) - self._add_geoticks("y", latticklen, ticklen) - # If latlim is set to None it resets - # the view; this affects the visible range - # we need to force this to prevent - # side effects - if latlim == (None, None): - latlim = self._lataxis.get_view_interval() - if lonlim == (None, None): - lonlim = self._lonaxis.get_view_interval() - self._update_extent( - lonlim=lonlim, latlim=latlim, boundinglat=boundinglat - ) - else: - warnings._warn_ultraplot( - f"Projection is not rectilinear. Ignoring {lonticklen=} and {latticklen=} settings." - ) + self._format_apply_ticklen( + lonlim=lonlim, + latlim=latlim, + boundinglat=boundinglat, + ticklen=ticklen, + lonticklen=lonticklen, + latticklen=latticklen, + ) # Parent format method super().format(rc_kw=rc_kw, rc_mode=rc_mode, **kwargs) - def _add_geoticks(self, x_or_y, itick, ticklen): + def _add_geoticks(self, x_or_y: str, itick: Any, ticklen: Any) -> None: """ Add tick marks to the geographic axes. @@ -1257,32 +2111,19 @@ def _add_geoticks(self, x_or_y, itick, ticklen): # Skip if no tick size specified if size is None: return + # Convert unit spec to points and apply rc scaling factor. size = units(size) * rc["tick.len"] ax = getattr(self, f"{x_or_y}axis") - # Get the tick positions based on the locator - gl = self.gridlines_major - # Note: set_xticks points to a different method than self.[x/y]axis.set_ticks - # from the mpl backend. For basemap we are adding the ticks to the mpl backend - # and for cartopy we are simple using their functions by showing the axis. - if isinstance(gl, tuple): - locator = gl[0] if x_or_y == "x" else gl[1] - tick_positions = np.asarray(list(locator.keys())) - # Turn off the ticks otherwise they are double for - # basemap (different from cartopy) + # Get the tick positions based on the backend gridliner (adapter-aware). + adapter = self._gridliner_adapter("major") + is_basemap = self._name == "basemap" + tick_positions = self._gridliner_tick_positions(x_or_y, which="major") + if is_basemap: + # Turn off the ticks otherwise they are double for basemap. ax.set_major_formatter(mticker.NullFormatter()) - else: - if x_or_y == "x": - lim = self._lonaxis.get_view_interval() - locator = gl.xlocator - tick_positions = self._lonaxis._get_ticklocs(locator) - else: - lim = self._lataxis.get_view_interval() - locator = gl.ylocator - tick_positions = self._lataxis._get_ticklocs(locator) - # Always show the ticks ax.set_ticks(tick_positions) ax.set_visible(True) @@ -1290,7 +2131,11 @@ def _add_geoticks(self, x_or_y, itick, ticklen): # Note: set grid_alpha to 0 as it is controlled through the gridlines_major # object (which is not the same ticker) params = ax.get_tick_params() - sizes = [size, 0.6 * size if isinstance(size, (int, float)) else size] + # Minor ticks are shortened relative to major ticks. + sizes = [ + size, + _MINOR_TICK_SCALE * size if isinstance(size, (int, float)) else size, + ] for size, which in zip(sizes, ["major", "minor"]): params.update({"length": size}) params.pop("grid_alpha", None) @@ -1302,16 +2147,24 @@ def _add_geoticks(self, x_or_y, itick, ticklen): ) # Apply tick parameters # Move the labels outwards if specified - if hasattr(gl, f"{x_or_y}padding"): - setattr(gl, f"{x_or_y}padding", 2 * size) - elif isinstance(gl, tuple): - # For basemap backends, emulate the label placement - # like how cartopy does this - self._add_gridline_labels(ax, gl, padding=size) + gl = getattr(self, "_gridlines_major", None) + if gl is not None and hasattr(gl, f"{x_or_y}padding"): + # Cartopy gridliner padding is in points; scale matches tick size visually. + setattr(gl, f"{x_or_y}padding", _GRIDLINER_PAD_SCALE * size) + elif is_basemap and isinstance(adapter, _BasemapGridlinerAdapter): + # For basemap backends, emulate the label placement like cartopy. + self._add_gridline_labels( + ax, (adapter.lonlines, adapter.latlines), padding=size + ) self.stale = True - def _add_gridline_labels(self, ax, gl, padding=8): + def _add_gridline_labels( + self, + ax: maxis.Axis, + gl: tuple[GridlineDict, GridlineDict], + padding: float | int = 8, + ) -> None: """ This function is intended for the Basemap backend and mirrors the label placement behavior of Cartopy. @@ -1345,9 +2198,9 @@ def _add_gridline_labels(self, ax, gl, padding=8): which_line = 1 if shift_scale == 1 else 2 tickline = getattr(tick, f"tick{which_line}line") position = np.array(label.get_position()) - # Magic numbers are judged by eye (not great) + # Convert points to display units using DPI (72 points per inch). size = ( - 0.5 + _BASEMAP_LABEL_SIZE_SCALE * (tick._size + label.get_fontsize() + padding) * self.figure.dpi / 72 @@ -1359,7 +2212,10 @@ def _add_gridline_labels(self, ax, gl, padding=8): if which == "x": # Move y position - position[1] = offset[1] + shift_scale * size * 0.65 + # Empirical scaling to mimic cartopy label spacing. + position[1] = ( + offset[1] + shift_scale * size * _BASEMAP_LABEL_Y_SCALE + ) ha = "center" va = "top" if shift_scale == 1 else "bottom" if shift_scale == 1: @@ -1369,7 +2225,10 @@ def _add_gridline_labels(self, ax, gl, padding=8): else: # Move x position - position[0] = offset[0] + shift_scale * size * 0.25 + # Empirical scaling to mimic cartopy label spacing. + position[0] = ( + offset[0] + shift_scale * size * _BASEMAP_LABEL_X_SCALE + ) ha = "left" if shift_scale == 1 else "right" va = "center" if shift_scale == 1: @@ -1394,7 +2253,7 @@ def _add_gridline_labels(self, ax, gl, padding=8): label.set_visible(False) @property - def gridlines_major(self): + def gridlines_major(self) -> Any: """ The cartopy `~cartopy.mpl.gridliner.Gridliner` used for major gridlines or a 2-tuple containing the @@ -1403,13 +2262,17 @@ def gridlines_major(self): and :func:`~mpl_toolkits.basemap.Basemap.drawparallels`. This can be used for customization and debugging. """ + # Refresh adapters so external access sees up-to-date gridliner state. + builder = getattr(self, "_build_gridliner_adapter", None) + if builder is not None: + self._set_gridliner_adapter("major", builder("major")) if self._name == "basemap": return (self._lonlines_major, self._latlines_major) else: return self._gridlines_major @property - def gridlines_minor(self): + def gridlines_minor(self) -> Any: """ The cartopy `~cartopy.mpl.gridliner.Gridliner` used for minor gridlines or a 2-tuple containing the @@ -1418,13 +2281,17 @@ def gridlines_minor(self): and :func:`~mpl_toolkits.basemap.Basemap.drawparallels`. This can be used for customization and debugging. """ + # Refresh adapters so external access sees up-to-date gridliner state. + builder = getattr(self, "_build_gridliner_adapter", None) + if builder is not None: + self._set_gridliner_adapter("minor", builder("minor")) if self._name == "basemap": return (self._lonlines_minor, self._latlines_minor) else: return self._gridlines_minor @property - def projection(self): + def projection(self) -> Any: """ The cartopy `~cartopy.crs.Projection` or basemap `~mpl_toolkits.basemap.Basemap` instance associated with this axes. @@ -1432,7 +2299,7 @@ def projection(self): return self._map_projection @projection.setter - def projection(self, map_projection): + def projection(self, map_projection: Any) -> None: cls = self._proj_class if not isinstance(map_projection, cls): raise ValueError(f"Projection must be a {cls} instance.") @@ -1469,7 +2336,7 @@ class _CartopyAxes(GeoAxes, _GeoAxes): # NOTE: The rename argument wrapper belongs here instead of format() because # these arguments were previously only accepted during initialization. @warnings._rename_kwargs("0.10", circular="round", autoextent="extent") - def __init__(self, *args, map_projection=None, **kwargs): + def __init__(self, *args: Any, map_projection: Any = None, **kwargs: Any) -> None: """ Parameters ---------- @@ -1501,7 +2368,7 @@ def __init__(self, *args, map_projection=None, **kwargs): axis.set_tick_params(which="both", size=0) # prevent extra label offset @staticmethod - def _get_circle_path(N=100): + def _get_circle_path(N: int = 100) -> mpath.Path: """ Return a circle `~matplotlib.path.Path` used as the outline for polar stereographic, azimuthal equidistant, Lambert conformal, and gnomonic @@ -1513,131 +2380,97 @@ def _get_circle_path(N=100): verts = np.vstack([np.sin(theta), np.cos(theta)]).T return mpath.Path(verts * radius + center) - def _get_global_extent(self): + def _get_global_extent(self) -> list[float]: """ Return the global extent with meridian properly shifted. """ lon0 = self._get_lon0() return [-180 + lon0, 180 + lon0, -90, 90] - def _get_lon0(self): + def _get_lon0(self) -> float: """ Get the central longitude. Default is ``0``. """ return self.projection.proj4_params.get("lon_0", 0) - def _init_gridlines(self): - """ - Create monkey patched "major" and "minor" gridliners managed by ultraplot. + def gridlines( + self, + crs: Any = None, + draw_labels: bool | str | None = False, + xlocs: mticker.Locator | Sequence[float] | None = None, + ylocs: mticker.Locator | Sequence[float] | None = None, + dms: bool = False, + x_inline: bool | None = None, + y_inline: bool | None = None, + auto_inline: bool = True, + xformatter: Any = None, + yformatter: Any = None, + xlim: Sequence[float] | None = None, + ylim: Sequence[float] | None = None, + rotate_labels: bool | float | None = None, + xlabel_style: MutableMapping[str, Any] | None = None, + ylabel_style: MutableMapping[str, Any] | None = None, + labels_bbox_style: MutableMapping[str, Any] | None = None, + xpadding: float | None = 5, + ypadding: float | None = 5, + offset_angle: float = 25, + auto_update: bool | None = None, + formatter_kwargs: MutableMapping[str, Any] | None = None, + **kwargs: Any, + ) -> _CartopyGridlinerProtocol: + """ + Override cartopy gridlines to use a local Gridliner subclass. """ - - # Cartopy < 0.18 monkey patch. Helps filter valid coordates to lon_0 +/- 180 - def _axes_domain(self, *args, **kwargs): - x_range, y_range = type(self)._axes_domain(self, *args, **kwargs) - if _version_cartopy < "0.18": - lon_0 = self.axes.projection.proj4_params.get("lon_0", 0) - x_range = np.asarray(x_range) + lon_0 - return x_range, y_range - - # Cartopy >= 0.18 monkey patch. Fixes issue where cartopy draws an overlapping - # dateline gridline (e.g. polar maps). See the nx -= 1 line in _draw_gridliner - def _draw_gridliner(self, *args, **kwargs): # noqa: E306 - result = type(self)._draw_gridliner(self, *args, **kwargs) - if _version_cartopy >= "0.18": - lon_lim, _ = self._axes_domain() - if abs(np.diff(lon_lim)) == abs(np.diff(self.crs.x_limits)): - for collection in self.xline_artists: - if not getattr(collection, "_cartopy_fix", False): - collection.get_paths().pop(-1) - collection._cartopy_fix = True - return result - - # Return the gridliner with monkey patch - gl = self.gridlines(crs=ccrs.PlateCarree()) - gl._axes_domain = _axes_domain.__get__(gl) - gl._draw_gridliner = _draw_gridliner.__get__(gl) - gl.xlines = gl.ylines = False + if crs is None: + crs = ccrs.PlateCarree(globe=self.projection.globe) + gridliner_cls = _CartopyGridliner or cgridliner.Gridliner + gl = gridliner_cls( + self, + crs=crs, + draw_labels=draw_labels, + xlocator=xlocs, + ylocator=ylocs, + collection_kwargs=kwargs, + dms=dms, + x_inline=x_inline, + y_inline=y_inline, + auto_inline=auto_inline, + xformatter=xformatter, + yformatter=yformatter, + xlim=xlim, + ylim=ylim, + rotate_labels=rotate_labels, + xlabel_style=xlabel_style, + ylabel_style=ylabel_style, + labels_bbox_style=labels_bbox_style, + xpadding=xpadding, + ypadding=ypadding, + offset_angle=offset_angle, + auto_update=auto_update, + formatter_kwargs=formatter_kwargs, + ) + self.add_artist(gl) return gl - @override - def _get_gridliner_labels( - self, - bottom=None, - top=None, - left=None, - right=None, - ) -> dict[str, list[mtext.Text]]: - sides = {} - for dir, side in zip( - "bottom top left right".split(), [bottom, top, left, right] - ): - if side != True: - continue - if self.gridlines_major is None: - continue - sides[dir] = getattr(self.gridlines_major, f"{dir}_label_artists") - return sides - - @staticmethod - def _get_side_labels() -> tuple: - if _version_cartopy >= "0.18": - left_labels = "left_labels" - right_labels = "right_labels" - bottom_labels = "bottom_labels" - top_labels = "top_labels" - else: # cartopy < 0.18 - left_labels = "ylabels_left" - right_labels = "ylabels_right" - bottom_labels = "xlabels_bottom" - top_labels = "xlabels_top" - return (left_labels, right_labels, bottom_labels, top_labels) - - @override - def _is_ticklabel_on(self, side: str) -> bool: + def _init_gridlines(self) -> _CartopyGridlinerProtocol: """ - Helper function to check if tick labels are on for a given side. + Create "major" and "minor" gridliners managed by ultraplot. """ - # Deal with different cartopy versions - left_labels, right_labels, bottom_labels, top_labels = self._get_side_labels() - - if self.gridlines_major is None: - return False - elif side == "labelleft": - return getattr(self.gridlines_major, left_labels) - elif side == "labelright": - return getattr(self.gridlines_major, right_labels) - elif side == "labelbottom": - return getattr(self.gridlines_major, bottom_labels) - elif side == "labeltop": - return getattr(self.gridlines_major, top_labels) - else: - raise ValueError(f"Invalid side: {side}") - @override - def _toggle_gridliner_labels( - self, - labelleft=None, - labelright=None, - labelbottom=None, - labeltop=None, - geo=None, - ): - """ - Toggle gridliner labels across different cartopy versions. - """ - # Retrieve the property name depending - # on cartopy version. - side_labels = _CartopyAxes._get_side_labels() - togglers = (labelleft, labelright, labelbottom, labeltop) - gl = self.gridlines_major + # Return gridliner using our subclass to isolate cartopy quirks. + gl = self.gridlines(crs=ccrs.PlateCarree()) + gl.xlines = gl.ylines = False + return gl - for toggle, side in zip(togglers, side_labels): - if toggle is not None: - setattr(gl, side, toggle) - if geo is not None: # only cartopy 0.20 supported but harmless - setattr(gl, "geo_labels", geo) + def _build_gridliner_adapter( + self, which: str = "major" + ) -> Optional[_GridlinerAdapter]: + gl = getattr(self, f"_gridlines_{which}", None) + if gl is None: + return None + return _CartopyGridlinerAdapter(gl) - def _update_background(self, **kwargs): + def _update_background(self, **kwargs: Any) -> None: """ Update the map background patches. This is called in `Axes.format`. """ @@ -1656,7 +2489,7 @@ def _update_background(self, **kwargs): self.background_patch.update(kw_face) self.outline_patch.update(kw_edge) - def _update_boundary(self, round=None): + def _update_boundary(self, round: bool | None = None) -> None: """ Update the map boundary path. """ @@ -1672,7 +2505,9 @@ def _update_boundary(self, round=None): else: warnings._warn_ultraplot("Failed to reset round map boundary.") - def _update_extent_mode(self, extent=None, boundinglat=None): + def _update_extent_mode( + self, extent: str | None = None, boundinglat: float | None = None + ) -> None: """ Update the extent mode. """ @@ -1706,7 +2541,12 @@ def _update_extent_mode(self, extent=None, boundinglat=None): self.set_autoscalex_on(True) self.set_autoscaley_on(True) - def _update_extent(self, lonlim=None, latlim=None, boundinglat=None): + def _update_extent( + self, + lonlim: tuple[float | None, float | None] | None = None, + latlim: tuple[float | None, float | None] | None = None, + boundinglat: float | None = None, + ) -> None: """ Set the projection extent. """ @@ -1769,7 +2609,7 @@ def _update_extent(self, lonlim=None, latlim=None, boundinglat=None): extent = lonlim + latlim self.set_extent(extent, crs=ccrs.PlateCarree()) - def _update_features(self): + def _update_features(self) -> None: """ Update geographic features. """ @@ -1824,12 +2664,12 @@ def _update_features(self): def _update_gridlines( self, - gl, - which="major", - longrid=None, - latgrid=None, - nsteps=None, - ): + gl: _CartopyGridlinerProtocol, + which: str = "major", + longrid: bool | None = None, + latgrid: bool | None = None, + nsteps: int | None = None, + ) -> None: """ Update gridliner object with axis locators, and toggle gridlines on and off. """ @@ -1865,18 +2705,18 @@ def _update_gridlines( def _update_major_gridlines( self, - longrid=None, - latgrid=None, - lonarray=None, - latarray=None, - loninline=None, - latinline=None, - labelpad=None, - rotatelabels=None, - lonlabelrotation=None, - latlabelrotation=None, - nsteps=None, - ): + longrid: bool | None = None, + latgrid: bool | None = None, + lonarray: Sequence[bool | None] | None = None, + latarray: Sequence[bool | None] | None = None, + loninline: bool | None = None, + latinline: bool | None = None, + labelpad: Any = None, + rotatelabels: bool | None = None, + lonlabelrotation: float | None = None, + latlabelrotation: float | None = None, + nsteps: int | None = None, + ) -> None: """ Update major gridlines. """ @@ -1940,24 +2780,27 @@ def _update_major_gridlines( f"{type(self.projection).__name__} projection." ) lonarray = [False] * 5 - sides = dict() - # The ordering of these sides are important. The arrays are ordered lrbtg - for side, lon, lat in zip( - "labelleft labelright labelbottom labeltop geo".split(), lonarray, latarray - ): - sides[side] = None - if lon and lat: - sides[side] = True - elif lon: - sides[side] = "x" - elif lat: - sides[side] = "y" - elif lon is not None or lat is not None: - sides[side] = False + # The ordering of these sides are important. The arrays are ordered lrbtg. + sides = _gridliner_sides_from_arrays( + lonarray, + latarray, + order=_CARTOPY_LABEL_SIDES, + allow_xy=True, + include_false=True, + ) + if not sides and lonarray is not None and latarray is not None: + # Preserve legacy behavior by calling the toggle even for no-op arrays. + sides = {side: None for side in _CARTOPY_LABEL_SIDES} if sides: self._toggle_gridliner_labels(**sides) + self._set_gridliner_adapter("major", self._build_gridliner_adapter("major")) - def _update_minor_gridlines(self, longrid=None, latgrid=None, nsteps=None): + def _update_minor_gridlines( + self, + longrid: bool | None = None, + latgrid: bool | None = None, + nsteps: int | None = None, + ) -> None: """ Update minor gridlines. """ @@ -1971,8 +2814,9 @@ def _update_minor_gridlines(self, longrid=None, latgrid=None, nsteps=None): latgrid=latgrid, nsteps=nsteps, ) + self._set_gridliner_adapter("minor", self._build_gridliner_adapter("minor")) - def get_extent(self, crs=None): + def get_extent(self, crs: Any = None) -> Sequence[float]: # Get extent and try to repair longitude bounds. if crs is None: crs = ccrs.PlateCarree() @@ -1987,7 +2831,7 @@ def get_extent(self, crs=None): return extent @override - def draw(self, renderer=None, *args, **kwargs): + def draw(self, renderer: Any = None, *args: Any, **kwargs: Any) -> None: """ Override draw to adjust panel positions for cartopy axes. @@ -1998,7 +2842,7 @@ def draw(self, renderer=None, *args, **kwargs): super().draw(renderer, *args, **kwargs) self._adjust_panel_positions(tol=self._PANEL_TOL) - def get_tightbbox(self, renderer, *args, **kwargs): + def get_tightbbox(self, renderer: Any, *args: Any, **kwargs: Any) -> Any: # Perform extra post-processing steps # For now this just draws the gridliners self._apply_axis_sharing() @@ -2037,7 +2881,7 @@ def get_tightbbox(self, renderer, *args, **kwargs): return super().get_tightbbox(renderer, *args, **kwargs) - def set_extent(self, extent, crs=None): + def set_extent(self, extent: Sequence[float], crs: Any = None) -> Any: # Fix paths, so axes tight bounding box gets correct box! From this issue: # https://github.com/SciTools/cartopy/issues/1207#issuecomment-439975083 # Also record the requested longitude latitude extent so we can use these @@ -2063,7 +2907,7 @@ def set_extent(self, extent, crs=None): self.background_patch._path = clipped_path return super().set_extent(extent, crs=crs) - def set_global(self): + def set_global(self) -> Any: # Set up "global" extent and update _LatAxis and _LonAxis view intervals result = super().set_global() self._set_view_intervals(self._get_global_extent()) @@ -2095,7 +2939,7 @@ class _BasemapAxes(GeoAxes): ) _PANEL_TOL = 1e-6 - def __init__(self, *args, map_projection=None, **kwargs): + def __init__(self, *args: Any, map_projection: Any = None, **kwargs: Any) -> None: """ Parameters ---------- @@ -2144,7 +2988,7 @@ def __init__(self, *args, map_projection=None, **kwargs): self._turnoff_tick_labels(self._lonlines_major) self._turnoff_tick_labels(self._latlines_major) - def get_tightbbox(self, renderer, *args, **kwargs): + def get_tightbbox(self, renderer: Any, *args: Any, **kwargs: Any) -> Any: """ Get tight bounding box, adjusting panel positions after aspect is applied. @@ -2157,7 +3001,7 @@ def get_tightbbox(self, renderer, *args, **kwargs): return super().get_tightbbox(renderer, *args, **kwargs) @override - def draw(self, renderer=None, *args, **kwargs): + def draw(self, renderer: Any = None, *args: Any, **kwargs: Any) -> None: """ Override draw to adjust panel positions for basemap axes. @@ -2167,7 +3011,7 @@ def draw(self, renderer=None, *args, **kwargs): super().draw(renderer, *args, **kwargs) self._adjust_panel_positions(tol=self._PANEL_TOL) - def _turnoff_tick_labels(self, locator: mticker.Formatter): + def _turnoff_tick_labels(self, locator: GridlineDict) -> None: """ For GeoAxes with are dealing with a duality. Basemap axes behave differently than Cartopy axes and vice versa. UltraPlot abstracts away from these by providing GeoAxes. For basemap axes we need to turn off the tick labels as they will be handles by GeoAxis """ @@ -2179,48 +3023,14 @@ def _turnoff_tick_labels(self, locator: mticker.Formatter): if isinstance(object, mtext.Text): object.set_visible(False) - def _get_gridliner_labels( - self, - bottom=None, - top=None, - left=None, - right=None, - ): - directions = "left right top bottom".split() - bools = [left, right, top, bottom] - sides = {} - for direction, is_on in zip(directions, bools): - if is_on is None: - continue - gl = self.gridlines_major[0] - if direction in ["left", "right"]: - gl = self.gridlines_major[1] - for loc, (lines, labels) in gl.items(): - for label in labels: - position = label.get_position() - match direction: - case "top" if position[1] > 0: - add = True - case "bottom" if position[1] < 0: - add = True - case "left" if position[0] < 0: - add = True - case "right" if position[0] > 0: - add = True - case _: - add = False - if add: - sides.setdefault(direction, []).append(label) - return sides - - def _get_lon0(self): + def _get_lon0(self) -> float: """ Get the central longitude. """ return getattr(self.projection, "projparams", {}).get("lon_0", 0) @staticmethod - def _iter_gridlines(dict_): + def _iter_gridlines(dict_: GridlineDict | None) -> Iterator[Any]: """ Iterate over longitude latitude lines. """ @@ -2230,7 +3040,16 @@ def _iter_gridlines(dict_): for obj in pj: yield obj - def _update_background(self, **kwargs): + def _build_gridliner_adapter( + self, which: str = "major" + ) -> Optional[_GridlinerAdapter]: + lonlines = getattr(self, f"_lonlines_{which}", None) + latlines = getattr(self, f"_latlines_{which}", None) + if lonlines is None or latlines is None: + return None + return _BasemapGridlinerAdapter(lonlines, latlines) + + def _update_background(self, **kwargs: Any) -> None: """ Update the map boundary patches. This is called in `Axes.format`. """ @@ -2249,7 +3068,7 @@ def _update_background(self, **kwargs): for spine in self.spines.values(): spine.update(kw_edge) - def _update_boundary(self, round=None): + def _update_boundary(self, round: bool | None = None) -> None: """ No-op. Boundary mode cannot be changed in basemap. """ @@ -2263,7 +3082,9 @@ def _update_boundary(self, round=None): "instead (e.g. using the uplt.subplots() dictionary keyword 'proj_kw')." ) - def _update_extent_mode(self, extent=None, boundinglat=None): # noqa: U100 + def _update_extent_mode( + self, extent: str | None = None, boundinglat: float | None = None + ) -> None: # noqa: U100 """ No-op. Extent mode cannot be changed in basemap. """ @@ -2280,7 +3101,12 @@ def _update_extent_mode(self, extent=None, boundinglat=None): # noqa: U100 "in basemap projections. Please consider switching to cartopy." ) - def _update_extent(self, lonlim=None, latlim=None, boundinglat=None): + def _update_extent( + self, + lonlim: tuple[float | None, float | None] | None = None, + latlim: tuple[float | None, float | None] | None = None, + boundinglat: float | None = None, + ) -> None: """ No-op. Map bounds cannot be changed in basemap. """ @@ -2297,7 +3123,7 @@ def _update_extent(self, lonlim=None, latlim=None, boundinglat=None): "'width', or 'height'." ) - def _update_features(self): + def _update_features(self) -> None: """ Update geographic features. """ @@ -2329,14 +3155,14 @@ def _update_features(self): def _update_gridlines( self, - which="major", - longrid=None, - latgrid=None, - lonarray=None, - latarray=None, - lonlabelrotation=None, - latlabelrotation=None, - ): + which: str = "major", + longrid: bool | None = None, + latgrid: bool | None = None, + lonarray: Sequence[bool | None] | None = None, + latarray: Sequence[bool | None] | None = None, + lonlabelrotation: float | None = None, + latlabelrotation: float | None = None, + ) -> None: """ Apply changes to the basemap axes. """ @@ -2416,18 +3242,18 @@ def _update_gridlines( def _update_major_gridlines( self, - longrid=None, - latgrid=None, - lonarray=None, - latarray=None, - loninline=None, - latinline=None, - rotatelabels=None, - lonlabelrotation=None, - latlabelrotation=None, - labelpad=None, - nsteps=None, - ): + longrid: bool | None = None, + latgrid: bool | None = None, + lonarray: Sequence[bool | None] | None = None, + latarray: Sequence[bool | None] | None = None, + loninline: bool | None = None, + latinline: bool | None = None, + rotatelabels: bool | None = None, + lonlabelrotation: float | None = None, + latlabelrotation: float | None = None, + labelpad: Any = None, + nsteps: int | None = None, + ) -> None: """ Update major gridlines. """ @@ -2441,15 +3267,23 @@ def _update_major_gridlines( lonlabelrotation=lonlabelrotation, latlabelrotation=latlabelrotation, ) - sides = {} - for side, lonon, laton in zip( - "labelleft labelright labeltop labelbottom geo".split(), lonarray, latarray - ): - if lonon or laton: - sides[side] = True - self._toggle_gridliner_labels(**sides) + sides = _gridliner_sides_from_arrays( + lonarray, + latarray, + order=_BASEMAP_LABEL_SIDES, + allow_xy=False, + include_false=False, + ) + if sides: + self._toggle_gridliner_labels(**sides) + self._set_gridliner_adapter("major", self._build_gridliner_adapter("major")) - def _update_minor_gridlines(self, longrid=None, latgrid=None, nsteps=None): + def _update_minor_gridlines( + self, + longrid: bool | None = None, + latgrid: bool | None = None, + nsteps: int | None = None, + ) -> None: """ Update minor gridlines. """ @@ -2465,6 +3299,7 @@ def _update_minor_gridlines(self, longrid=None, latgrid=None, nsteps=None): lonlabelrotation=None, latlabelrotation=None, ) + self._set_gridliner_adapter("minor", self._build_gridliner_adapter("minor")) # Set isDefault_majloc, etc. to True for both axes # NOTE: This cannot be done inside _update_gridlines or minor gridlines # will not update to reflect new major gridline locations. @@ -2473,70 +3308,13 @@ def _update_minor_gridlines(self, longrid=None, latgrid=None, nsteps=None): axis.isDefault_majloc = True axis.isDefault_minloc = True - @override - def _is_ticklabel_on(self, side: str) -> bool: - # For basemap object, the text is organized - # as a dictionary. The keys are the numerical - # location values, and the values are a list - # where the version item is the tick and the - # the rest are mtext.Text objects. The labels - # are clustereed on the location per axis. - # This means that top and bottom labels are assigned - # to the same numerical loc. - # We therefore create a mapping per direction to make - # it more semantically logical. - def group_labels( - labels: list[mtext.Text], - which: str, - labelbottom=None, - labeltop=None, - labelleft=None, - labelright=None, - ) -> dict[str, list[mtext.Text]]: - group = {} - # We take zero here as a baseline - for label in labels: - position = label.get_position() - target = None - if which == "x": - if labelbottom is not None and position[1] < 0: - target = "labelbottom" - elif labeltop is not None and position[1] >= 0: - target = "labeltop" - else: - if labelleft is not None and position[0] < 0: - target = "labelleft" - elif labelright is not None and position[0] >= 0: - target = "labelright" - if target is not None: - group[target] = group.get(target, []) + [label] - return group - - gl = self.gridlines_major[0] - which = "x" - if side in ["labelleft", "labelright"]: - gl = self.gridlines_major[1] - which = "y" - # Group the text object based on their location - grouped = {} - for loc, (line, labels) in gl.items(): - labels = group_labels( - labels=labels, - which=which, - **{side: True}, - ) - for label in labels.get(side, []): - if label.get_visible(): - return True - return False - # Apply signature obfuscation after storing previous signature GeoAxes._format_signatures[GeoAxes] = inspect.signature(GeoAxes.format) GeoAxes.format = docstring._obfuscate_kwargs(GeoAxes.format) -def _is_rectilinear_projection(ax): +def _is_rectilinear_projection(ax: Any) -> bool: """Check if the axis has a flat projection (works with Cartopy).""" # Determine what the projection function is # Create a square and determine if the lengths are preserved diff --git a/ultraplot/tests/test_geographic.py b/ultraplot/tests/test_geographic.py index f1efed6ec..a57a6904c 100644 --- a/ultraplot/tests/test_geographic.py +++ b/ultraplot/tests/test_geographic.py @@ -615,6 +615,144 @@ def test_get_gridliner_labels_cartopy(): uplt.close(fig) +def test_get_gridliner_labels_basemap(): + fig, ax = uplt.subplots(proj="cyl", backend="basemap") + ax.format(labels="both", lonlines=30, latlines=30) + fig.canvas.draw() # ensure labels are positioned + labels = ax[0]._get_gridliner_labels(bottom=True, top=True, left=True, right=True) + assert labels.get("bottom") + assert labels.get("top") + assert labels.get("left") + assert labels.get("right") + uplt.close(fig) + + +def test_toggle_gridliner_labels_basemap(): + fig, ax = uplt.subplots(proj="cyl", backend="basemap") + ax[0].format(labels="both", lonlines=30, latlines=30) + fig.canvas.draw() + + ax[0]._toggle_gridliner_labels( + labelbottom=False, + labeltop=True, + labelleft=True, + labelright=True, + ) + labels = ax[0]._get_gridliner_labels(bottom=True, top=True, left=True, right=True) + assert labels.get("bottom") + assert labels.get("top") + assert labels.get("left") + assert labels.get("right") + assert all(not label.get_visible() for label in labels["bottom"]) + assert any(label.get_visible() for label in labels["top"]) + assert any(label.get_visible() for label in labels["left"]) + assert any(label.get_visible() for label in labels["right"]) + uplt.close(fig) + + +@pytest.mark.parametrize("backend", ["cartopy", "basemap"]) +def test_tick_params_updates_gridliner(backend): + fig, ax = uplt.subplots(proj="cyl", backend=backend) + ax[0].format(lonlines=30, latlines=30, labels=True, grid=True) + ax[0].tick_params( + labelcolor="red", + labelsize=8, + labelrotation=15, + pad=6, + colors="blue", + width=1.5, + labelbottom=False, + labelleft=False, + ) + + assert not ax[0]._is_ticklabel_on("labelbottom") + assert not ax[0]._is_ticklabel_on("labelleft") + + if ax[0]._name == "cartopy": + gl = ax[0].gridlines_major + assert gl.collection_kwargs.get("color") == "blue" + assert gl.collection_kwargs.get("linewidth") == 1.5 + assert gl.xlabel_style.get("color") == "red" + assert gl.ylabel_style.get("color") == "red" + assert gl.xlabel_style.get("fontsize") == 8 + assert gl.ylabel_style.get("fontsize") == 8 + assert gl.xlabel_style.get("rotation") == 15 + assert gl.ylabel_style.get("rotation") == 15 + if hasattr(gl, "xpadding"): + assert gl.xpadding == 6 + if hasattr(gl, "ypadding"): + assert gl.ypadding == 6 + else: # basemap + from matplotlib import colors as mcolors + from matplotlib import text as mtext + + lonlines, latlines = ax[0].gridlines_major + label_colors = [] + label_sizes = [] + label_rotations = [] + line_colors = [] + line_widths = [] + for grid in (lonlines, latlines): + for _, (lines, labels) in grid.items(): + for line in lines: + if hasattr(line, "get_color"): + line_colors.append(mcolors.to_rgba(line.get_color())) + if hasattr(line, "get_linewidth"): + line_widths.append(line.get_linewidth()) + for label in labels: + if isinstance(label, mtext.Text): + label_colors.append(mcolors.to_rgba(label.get_color())) + label_sizes.append(label.get_fontsize()) + label_rotations.append(label.get_rotation()) + expected_label_color = mcolors.to_rgba("red") + expected_line_color = mcolors.to_rgba("blue") + assert label_colors and all(c == expected_label_color for c in label_colors) + assert label_sizes and all(np.isclose(s, 8) for s in label_sizes) + assert label_rotations and all(np.isclose(r, 15) for r in label_rotations) + assert line_colors and all(c == expected_line_color for c in line_colors) + assert line_widths and all(np.isclose(w, 1.5) for w in line_widths) + + uplt.close(fig) + + +@pytest.mark.parametrize("backend", ["cartopy", "basemap"]) +def test_gridliner_adapter_refresh(backend): + fig, ax = uplt.subplots(proj="cyl", backend=backend) + ax[0].format(lonlines=30, latlines=30, labels=True) + assert ax[0]._gridliner_adapter("major", create=False) is not None + + ax[0]._gridliner_adapters.pop("major", None) + assert ax[0]._gridliner_adapter("major", create=False) is None + _ = ax[0].gridlines_major + assert ax[0]._gridliner_adapter("major", create=False) is not None + uplt.close(fig) + + +@pytest.mark.parametrize("backend", ["cartopy", "basemap"]) +def test_gridliner_tick_positions(backend): + fig, ax = uplt.subplots(proj="cyl", backend=backend) + ax[0].format(lonlines=30, latlines=30, labels=True, grid=True) + fig.canvas.draw() + lon_positions = ax[0]._gridliner_tick_positions("x", which="major") + lat_positions = ax[0]._gridliner_tick_positions("y", which="major") + assert len(lon_positions) > 0 + assert len(lat_positions) > 0 + + if ax[0]._name == "cartopy": + expected_lon = ax[0]._get_lonticklocs() + expected_lat = ax[0]._get_latticklocs() + assert np.allclose(lon_positions, expected_lon) + assert np.allclose(lat_positions, expected_lat) + else: # basemap + lonlines, latlines = ax[0].gridlines_major + expected_lon = np.sort(np.asarray(list(lonlines.keys()))) + expected_lat = np.sort(np.asarray(list(latlines.keys()))) + assert np.allclose(np.sort(lon_positions), expected_lon) + assert np.allclose(np.sort(lat_positions), expected_lat) + + uplt.close(fig) + + @pytest.mark.parametrize("level", [0, 1, 2, 3, 4]) def test_sharing_levels(level): """