diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index 22e489d18..d19c33d0f 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -1545,6 +1545,8 @@ def shared(paxs): iax._panel_sharex_group = True iax._sharex_setup(bottom) # parent is bottom-most paxs = shared(self._panel_dict["top"]) + if paxs and self.figure._sharex > 0: + self._panel_sharex_group = True for iax in paxs: iax._panel_sharex_group = True iax._sharex_setup(bottom) @@ -1559,6 +1561,8 @@ def shared(paxs): iax._panel_sharey_group = True iax._sharey_setup(left) # parent is left-most paxs = shared(self._panel_dict["right"]) + if paxs and self.figure._sharey > 0: + self._panel_sharey_group = True for iax in paxs: iax._panel_sharey_group = True iax._sharey_setup(left) @@ -3261,6 +3265,27 @@ def _is_panel_group_member(self, other: "Axes") -> bool: # Not in the same panel group return False + def _label_key(self, side: str) -> str: + """ + Map requested side name to the correct tick_params key across mpl versions. + + This accounts for the API change around Matplotlib 3.10 where labeltop/labelbottom + became first-class tick parameter keys. For older versions, these map to + labelright/labelleft respectively. + """ + from packaging import version + from ..internals import _version_mpl + + # TODO: internal deprecation warning when we drop 3.9, we need to remove this + + use_new = version.parse(str(_version_mpl)) >= version.parse("3.10") + if side == "labeltop": + return "labeltop" if use_new else "labelright" + if side == "labelbottom": + return "labelbottom" if use_new else "labelleft" + # "labelleft" and "labelright" are stable across versions + return side + def _is_ticklabel_on(self, side: str) -> bool: """ Check if tick labels are on for the specified sides. @@ -3274,10 +3299,8 @@ def _is_ticklabel_on(self, side: str) -> bool: label = "label1" if side in ["labelright", "labeltop"]: label = "label2" - for tick in axis.get_major_ticks(): - if getattr(tick, label).get_visible(): - return True - return False + + return axis.get_tick_params().get(self._label_key(side), False) @docstring._snippet_manager def inset(self, *args, **kwargs): diff --git a/ultraplot/axes/geo.py b/ultraplot/axes/geo.py index 896bc0a6d..15c5f9a43 100644 --- a/ultraplot/axes/geo.py +++ b/ultraplot/axes/geo.py @@ -652,27 +652,16 @@ def _apply_axis_sharing(self): or to the *right* of the leftmost panel. But the sharing level used for the leftmost and bottommost is the *figure* sharing level. """ - # Handle X axis sharing - if self._sharex: - self._handle_axis_sharing( - source_axis=self._sharex._lonaxis, - target_axis=self._lonaxis, - ) - # Handle Y axis sharing - if self._sharey: - self._handle_axis_sharing( - source_axis=self._sharey._lataxis, - target_axis=self._lataxis, - ) + # Share interval x + if self._sharex and self.figure._sharex >= 2: + self._lonaxis.set_view_interval(*self._sharex._lonaxis.get_view_interval()) + self._lonaxis.set_minor_locator(self._sharex._lonaxis.get_minor_locator()) - # This block is apart of the draw sequence as the - # gridliner object is created late in the - # build chain. - if not self.stale: - return - if self.figure._get_sharing_level() == 0: - return + # Share interval y + if self._sharey and self.figure._sharey >= 2: + self._lataxis.set_view_interval(*self._sharey._lataxis.get_view_interval()) + self._lataxis.set_minor_locator(self._sharey._lataxis.get_minor_locator()) def _get_gridliner_labels( self, @@ -691,38 +680,36 @@ def _toggle_gridliner_labels( labelright=None, geo=None, ): - # For BasemapAxes the gridlines are dicts with key as the coordinate and keys the line and label - # We override the dict here assuming the labels are mut excl due to the N S E W extra chars + """ + Toggle visibility of gridliner labels for each direction. + + Parameters + ---------- + labeltop, labelbottom, labelleft, labelright : bool or None + Whether to show labels on each side. If None, do not change. + geo : optional + Not used in this method. + """ + # Ensure gridlines_major is fully initialized if any(i is None for i in self.gridlines_major): return + gridlabels = self._get_gridliner_labels( bottom=labelbottom, top=labeltop, left=labelleft, right=labelright ) - bools = [labelbottom, labeltop, labelleft, labelright] - directions = "bottom top left right".split() - for direction, toggle in zip(directions, bools): + + toggles = { + "bottom": labelbottom, + "top": labeltop, + "left": labelleft, + "right": labelright, + } + + for direction, toggle in toggles.items(): if toggle is None: continue for label in gridlabels.get(direction, []): - label.set_visible(toggle) - - def _handle_axis_sharing( - self, - source_axis: "GeoAxes", - target_axis: "GeoAxes", - ): - """ - Helper method to handle axis sharing for both X and Y axes. - - Args: - source_axis: The source axis to share from - target_axis: The target axis to apply sharing to - """ - # Copy view interval and minor locator from source to target - - if self.figure._get_sharing_level() >= 2: - target_axis.set_view_interval(*source_axis.get_view_interval()) - target_axis.set_minor_locator(source_axis.get_minor_locator()) + label.set_visible(bool(toggle) or toggle in ("x", "y")) @override def draw(self, renderer=None, *args, **kwargs): @@ -1441,6 +1428,7 @@ def _is_ticklabel_on(self, side: str) -> bool: """ # Deal with different cartopy versions left_labels, right_labels, bottom_labels, top_labels = self._get_side_labels() + if self.gridlines_major is None: return False elif side == "labelleft": diff --git a/ultraplot/axes/polar.py b/ultraplot/axes/polar.py index d66e3e2ea..94950179d 100644 --- a/ultraplot/axes/polar.py +++ b/ultraplot/axes/polar.py @@ -4,6 +4,11 @@ """ import inspect +try: + from typing import override +except: + from typing_extensions import override + import matplotlib.projections.polar as mpolar import numpy as np @@ -138,6 +143,11 @@ def __init__(self, *args, **kwargs): for axis in (self.xaxis, self.yaxis): axis.set_tick_params(which="both", size=0) + @override + def _apply_axis_sharing(self): + # Not implemented. Silently pass + return + def _update_formatter(self, x, *, formatter=None, formatter_kw=None): """ Update the gridline label formatter. diff --git a/ultraplot/axes/shared.py b/ultraplot/axes/shared.py index 57d5abe0b..8b434645a 100644 --- a/ultraplot/axes/shared.py +++ b/ultraplot/axes/shared.py @@ -212,7 +212,6 @@ def _share_axis_with(self, other: "Axes", *, which: str): ) self._shared_axes[which].join(self, other) - # Get axis objects this_axis = getattr(self, f"{which}axis") other_axis = getattr(other, f"{which}axis") @@ -227,7 +226,7 @@ def _share_axis_with(self, other: "Axes", *, which: str): get_autoscale = getattr(other, f"get_autoscale{which}_on") lim0, lim1 = limits - set_lim(lim0, lim1, emit=False, auto=get_autoscale()) + set_lim(lim0, lim1, emit=False, auto=get_autoscale()) # Set scale - # Set scale + # Override scale this_axis._scale = other_axis._scale diff --git a/ultraplot/figure.py b/ultraplot/figure.py index 981fa2424..d28e929c8 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -6,6 +6,7 @@ import inspect import os from numbers import Integral +from packaging import version try: from typing import List @@ -20,6 +21,11 @@ import matplotlib.transforms as mtransforms import numpy as np +try: + from typing import override +except: + from typing_extensions import override + from . import axes as paxes from . import constructor from . import gridspec as pgridspec @@ -477,6 +483,21 @@ def _canvas_preprocess(self, *args, **kwargs): return canvas +def _clear_border_cache(func): + """ + Decorator that clears the border cache after function execution. + """ + + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + result = func(self, *args, **kwargs) + if hasattr(self, "_cached_border_axes"): + delattr(self, "_cached_border_axes") + return result + + return wrapper + + class Figure(mfigure.Figure): """ The `~matplotlib.figure.Figure` subclass used by ultraplot. @@ -801,6 +822,217 @@ def __init__( # NOTE: This ignores user-input rc_mode. self.format(rc_kw=rc_kw, rc_mode=1, skip_axes=True, **kw_format) + @override + def draw(self, renderer): + # implement the tick sharing here + # should be shareable --> either all cartesian or all geographic + # but no mixing (panels can be mixed) + # check which ticks are on for x or y and push the labels to the + # outer most on a given column or row. + # we can use get_border_axes for the outermost plots and then collect their outermost panels that are not colorbars + self._share_ticklabels(axis="x") + self._share_ticklabels(axis="y") + super().draw(renderer) + + def _share_ticklabels(self, *, axis: str) -> None: + """ + Tick label sharing is determined at the figure level. While + each subplot controls the limits, we are dealing with the ticklabels + here as the complexity is easier to deal with. + axis: str 'x' or 'y', row or columns to update + """ + if not self.stale: + return + + outer_axes = self._get_border_axes() + sides = ("top", "bottom") if axis == "x" else ("left", "right") + + # Group axes by row (for x) or column (for y) + axes = list(self._iter_axes(panels=True, hidden=False)) + groups = self._group_axes_by_axis(axes, axis) + + # Version-dependent label name mapping for reading back params + label_keys = self._label_key_map() + + # Process each group independently + for _, group_axes in groups.items(): + # Build baseline from MAIN axes only (exclude panels) + baseline, skip_group = self._compute_baseline_tick_state( + group_axes, axis, label_keys + ) + if skip_group: + continue + + # Apply baseline to all axes in the group (including panels) + for axi in group_axes: + # Respect figure border sides and panel opposite sides + masked = self._apply_border_mask(axi, baseline, sides, outer_axes) + + # Determine sharing level for this axes + if self._effective_share_level(axi, axis, sides) < 3: + continue + + # Apply to geo/cartesian appropriately + self._set_ticklabel_state(axi, axis, masked) + + self.stale = True + + def _label_key_map(self): + """ + Return a mapping for version-dependent label keys for Matplotlib tick params. + """ + first_axi = next(self._iter_axes(panels=True), None) + if first_axi is None: + return { + "labelleft": "labelleft", + "labelright": "labelright", + "labeltop": "labeltop", + "labelbottom": "labelbottom", + } + return { + name: first_axi._label_key(name) + for name in ("labelleft", "labelright", "labeltop", "labelbottom") + } + + def _group_axes_by_axis(self, axes, axis: str): + """ + Group axes by row (x) or column (y). Panels included; invalid subplotspec skipped. + """ + from collections import defaultdict + + def _group_key(ax): + ss = ax.get_subplotspec() + return ss.rowspan.start if axis == "x" else ss.colspan.start + + groups = defaultdict(list) + for axi in axes: + try: + key = _group_key(axi) + except Exception: + # If we can't get a subplotspec, skip grouping for this axes + continue + groups[key].append(axi) + return groups + + def _compute_baseline_tick_state(self, group_axes, axis: str, label_keys): + """ + Build a baseline ticklabel visibility dict from MAIN axes (panels excluded). + Returns (baseline_dict, skip_group: bool). Emits warnings when encountering + unsupported or mixed subplot types. + """ + baseline = {} + subplot_types = set() + unsupported_found = False + sides = ("top", "bottom") if axis == "x" else ("left", "right") + + for axi in group_axes: + # Only main axes "vote" + if getattr(axi, "_panel_side", None): + continue + + # Supported axes types + if not isinstance( + axi, (paxes.CartesianAxes, paxes._CartopyAxes, paxes._BasemapAxes) + ): + warnings._warn_ultraplot( + f"Tick label sharing not implemented for {type(axi)} subplots." + ) + unsupported_found = True + break + + subplot_types.add(type(axi)) + + # Collect label visibility state + if isinstance(axi, paxes.CartesianAxes): + params = getattr(axi, f"{axis}axis").get_tick_params() + for side in sides: + key = label_keys[f"label{side}"] + if params.get(key): + baseline[key] = params[key] + elif isinstance(axi, paxes.GeoAxes): + for side in sides: + key = f"label{side}" + if axi._is_ticklabel_on(key): + baseline[key] = axi._is_ticklabel_on(key) + + if unsupported_found: + return {}, True + + # We cannot mix types (yet) within a group + if len(subplot_types) > 1: + warnings._warn_ultraplot( + "Tick label sharing not implemented for mixed subplot types." + ) + return {}, True + + return baseline, False + + def _apply_border_mask( + self, axi, baseline: dict, sides: tuple[str, str], outer_axes + ): + """ + Apply figure-border constraints and panel opposite-side suppression. + Keeps label key mapping per-axis for cartesian. + """ + from .axes.cartesian import OPPOSITE_SIDE + + masked = baseline.copy() + for side in sides: + label = f"label{side}" + if isinstance(axi, paxes.CartesianAxes): + # Use per-axis version-mapped key when writing + label = axi._label_key(label) + + # Only keep labels on true figure borders + if axi not in outer_axes[side]: + masked[label] = False + + # For panels, suppress labels on their opposite side + if ( + getattr(axi, "_panel_side", None) + and OPPOSITE_SIDE[axi._panel_side] == side + ): + masked[label] = False + + return masked + + def _effective_share_level(self, axi, axis: str, sides: tuple[str, str]) -> int: + """ + Compute the effective share level for an axes, considering panel groups and + adjacent panels. Fixes the original variable leak by checking any relevant side. + """ + level = getattr(self, f"_share{axis}") + # If figure-level sharing is disabled (0/False), don't promote due to panels + if not level or (isinstance(level, (int, float)) and level < 1): + return level + + # Panel group-level sharing + if getattr(axi, f"_panel_share{axis}_group", None): + return 3 + + # Panel member sharing + if getattr(axi, "_panel_side", None) and getattr(axi, f"_share{axis}", None): + return 3 + + # Adjacent panels on any relevant side + panel_dict = getattr(axi, "_panel_dict", {}) + for side in sides: + side_panels = panel_dict.get(side) or [] + if side_panels and getattr(side_panels[0], f"_share{axis}", False): + return 3 + + return level + + def _set_ticklabel_state(self, axi, axis: str, state: dict): + """Apply the computed ticklabel state to cartesian or geo axes.""" + if state: + # Normalize "x"/"y" values to booleans for both Geo and Cartesian axes + cleaned = {k: (True if v in ("x", "y") else v) for k, v in state.items()} + if isinstance(axi, paxes.GeoAxes): + axi._toggle_gridliner_labels(**cleaned) + else: + getattr(axi, f"{axis}axis").set_tick_params(**cleaned) + def _context_adjusting(self, cache=True): """ Prevent re-running auto layout steps due to draws triggered by figure @@ -918,10 +1150,10 @@ def _get_align_axes(self, side): options = grid.T[:, ::-1] uids = set() for option in options: - idx = np.where(option > 0)[0] + idx = np.where(option != None)[0] if idx.size > 0: first = idx.min() - number = option[first].astype(int) + number = option[first].number uids.add(number) axs = [] # Collect correct axes @@ -953,8 +1185,9 @@ def _get_border_axes( if gs is None: return border_axes - # Skip colorbars or panels etc - all_axes = [axi for axi in self.axes if axi.number is not None] + all_axes = [] + for axi in self._iter_axes(panels=True): + all_axes.append(axi) # Handle empty cases nrows, ncols = gs.nrows, gs.ncols @@ -966,26 +1199,51 @@ def _get_border_axes( # Reconstruct the grid based on axis locations. Note that # spanning axes will fit into one of the boxes. Check # this with unittest to see how empty axes are handles - grid, grid_axis_type, seen_axis_type = _get_subplot_layout( - gs, - all_axes, - same_type=same_type, - ) + + gs = self.axes[0].get_gridspec() + shape = (gs.nrows_total, gs.ncols_total) + grid = np.zeros(shape, dtype=object) + grid.fill(None) + grid_axis_type = np.zeros(shape, dtype=int) + seen_axis_type = dict() + ax_type_mapping = dict() + for axi in self._iter_axes(panels=True, hidden=True): + gs = axi.get_subplotspec() + x, y = np.unravel_index(gs.num1, shape) + span = gs._get_rows_columns() + + xleft, xright, yleft, yright = span + xspan = xright - xleft + 1 + yspan = yright - yleft + 1 + number = axi.number + axis_type = type(axi) + if isinstance(axi, (paxes.GeoAxes)): + axis_type = axi.projection + if axis_type not in seen_axis_type: + seen_axis_type[axis_type] = len(seen_axis_type) + type_number = seen_axis_type[axis_type] + ax_type_mapping[axi] = type_number + if axi.get_visible(): + grid[x : x + xspan, y : y + yspan] = axi + grid_axis_type[x : x + xspan, y : y + yspan] = type_number # We check for all axes is they are a border or not # Note we could also write the crawler in a way where # it find the borders by moving around in the grid, without spawning on each axis point. We may change # this in the future for axi in all_axes: - axis_type = seen_axis_type.get(type(axi), 1) + axis_type = ax_type_mapping[axi] + number = axi.number + if axi.number is None: + number = -axi._panel_parent.number crawler = _Crawler( ax=axi, grid=grid, - target=axi.number, + target=number, axis_type=axis_type, grid_axis_type=grid_axis_type, ) for direction, is_border in crawler.find_edges(): - if is_border: + if is_border and axi not in border_axes[direction]: border_axes[direction].append(axi) self._cached_border_axes = border_axes return border_axes @@ -1079,12 +1337,7 @@ def _get_renderer(self): renderer = canvas.get_renderer() return renderer - def _get_sharing_level(self): - """ - We take the average here as the sharex and sharey should be the same value. In case this changes in the future we can track down the error easily - """ - return 0.5 * (self.figure._sharex + self.figure._sharey) - + @_clear_border_cache def _add_axes_panel(self, ax, side=None, **kwargs): """ Add an axes panel. @@ -1116,6 +1369,13 @@ def _add_axes_panel(self, ax, side=None, **kwargs): raise RuntimeError("The gridspec must be active.") kw = _pop_params(kwargs, gs._insert_panel_slot) ss, share = gs._insert_panel_slot(side, ax, **kw) + # Guard: GeoAxes with non-rectilinear projections cannot share with panels + if isinstance(ax, paxes.GeoAxes) and not ax._is_rectilinear(): + if share: + warnings._warn_ultraplot( + "Panel sharing disabled for non-rectilinear GeoAxes projections." + ) + share = False kwargs["autoshare"] = False kwargs.setdefault("number", False) # power users might number panels pax = self.add_subplot(ss, **kwargs) @@ -1127,8 +1387,70 @@ def _add_axes_panel(self, ax, side=None, **kwargs): axis = pax.yaxis if side in ("left", "right") else pax.xaxis getattr(axis, "tick_" + side)() # set tick and tick label position axis.set_label_position(side) # set label position + # Sync limits and formatters with parent when sharing to ensure consistent ticks + # Copy limits for the shared axis + # Note: for non-geo axes this is handled by auto sharing + if share and isinstance(ax, paxes.GeoAxes): + # Align with backend: for GeoAxes, use lon/lat degree formatters on panels. + # Otherwise, copy the parent's axis formatters. + fmt_key = "deglat" if side in ("left", "right") else "deglon" + axis.set_major_formatter(constructor.Formatter(fmt_key)) + # Update limits + axis._set_lim( + *getattr(ax, f"get_{'y' if side in ('left','right') else 'x'}lim")(), + auto=True, + ) + # Push main axes tick labels to the outside relative to the added panel + # Skip this for filled panels (colorbars/legends) + if not kw.get("filled", False) and share: + if isinstance(ax, paxes.GeoAxes): + if side == "top": + ax._toggle_gridliner_labels(labeltop=False) + elif side == "bottom": + ax._toggle_gridliner_labels(labelbottom=False) + elif side == "left": + ax._toggle_gridliner_labels(labelleft=False) + elif side == "right": + ax._toggle_gridliner_labels(labelright=False) + else: + if side == "top": + ax.xaxis.set_tick_params(**{ax._label_key("labeltop"): False}) + elif side == "bottom": + ax.xaxis.set_tick_params(**{ax._label_key("labelbottom"): False}) + elif side == "left": + ax.yaxis.set_tick_params(**{ax._label_key("labelleft"): False}) + elif side == "right": + ax.yaxis.set_tick_params(**{ax._label_key("labelright"): False}) + + # Panel labels: prefer outside only for non-sharing top/right; otherwise keep off + if side == "top": + if not share: + pax.xaxis.set_tick_params( + **{ + pax._label_key("labeltop"): True, + pax._label_key("labelbottom"): False, + } + ) + else: + on = ax.xaxis.get_tick_params()[ax._label_key("labeltop")] + pax.xaxis.set_tick_params(**{pax._label_key("labeltop"): on}) + ax.yaxis.set_tick_params(labeltop=False) + elif side == "right": + if not share: + pax.yaxis.set_tick_params( + **{ + pax._label_key("labelright"): True, + pax._label_key("labelleft"): False, + } + ) + else: + on = ax.yaxis.get_tick_params()[ax._label_key("labelright")] + pax.yaxis.set_tick_params(**{pax._label_key("labelright"): on}) + ax.yaxis.set_tick_params(**{ax._label_key("labelright"): False}) + return pax + @_clear_border_cache def _add_figure_panel( self, side=None, span=None, row=None, col=None, rows=None, cols=None, **kwargs ): @@ -1163,6 +1485,7 @@ def _add_figure_panel( pax._panel_parent = None return pax + @_clear_border_cache def _add_subplot(self, *args, **kwargs): """ The driver function for adding single subplots. @@ -1271,9 +1594,6 @@ def _add_subplot(self, *args, **kwargs): if ax.number: self._subplot_dict[ax.number] = ax - # Invalidate border axes cache - if hasattr(self, "_cached_border_axes"): - delattr(self, "_cached_border_axes") return ax def _unshare_axes(self): @@ -1288,56 +1608,6 @@ def _unshare_axes(self): if isinstance(ax, paxes.GeoAxes) and hasattr(ax, "set_global"): ax.set_global() - def _share_labels_with_others(self, *, which="both"): - """ - Helpers function to ensure the labels - are shared for rectilinear GeoAxes. - """ - # Only apply sharing of labels when we are - # actually sharing labels. - if self._get_sharing_level() == 0: - return - # Turn all labels off - # Note: this action performs it for all the axes in - # the figure. We use the stale here to only perform - # it once as it is an expensive action. - # The axis will be a border if it is either - # (a) on the edge - # (b) not next to a subplot - # (c) not next to a subplot of the same kind - border_axes = self._get_border_axes() - # Recode: - recoded = {} - for direction, axes in border_axes.items(): - for axi in axes: - recoded[axi] = recoded.get(axi, []) + [direction] - - are_ticks_on = False - default = dict( - labelleft=are_ticks_on, - labelright=are_ticks_on, - labeltop=are_ticks_on, - labelbottom=are_ticks_on, - ) - for axi in self._iter_axes(hidden=False, panels=False, children=False): - # Turn the ticks on or off depending on the position - sides = recoded.get(axi, []) - turn_on_or_off = default.copy() - - for side in sides: - sidelabel = f"label{side}" - is_label_on = axi._is_ticklabel_on(sidelabel) - if is_label_on: - # When we are a border an the labels are on - # we keep them on - assert sidelabel in turn_on_or_off - turn_on_or_off[sidelabel] = True - - if isinstance(axi, paxes.GeoAxes): - axi._toggle_gridliner_labels(**turn_on_or_off) - else: - axi._apply_axis_sharing() - def _toggle_axis_sharing( self, *, @@ -1400,19 +1670,19 @@ def get_key(ax): # shared axes behave consistently. if which == "x": other._sharex = ref - ref.xaxis.major = other.xaxis.major - ref.xaxis.minor = other.xaxis.minor - lim = other.get_xlim() - ref.set_xlim(*lim, emit=False, auto=other.get_autoscalex_on()) - ref.xaxis._scale = other.xaxis._scale + other.xaxis.major = ref.xaxis.major + other.xaxis.minor = ref.xaxis.minor + lim = ref.get_xlim() + other.set_xlim(*lim, emit=False, auto=ref.get_autoscalex_on()) + other.xaxis._scale = ref.xaxis._scale if which == "y": # This logic is from sharey other._sharey = ref - ref.yaxis.major = other.yaxis.major - ref.yaxis.minor = other.yaxis.minor - lim = other.get_ylim() - ref.set_ylim(*lim, emit=False, auto=other.get_autoscaley_on()) - ref.yaxis._scale = other.yaxis._scale + other.yaxis.major = ref.yaxis.major + other.yaxis.minor = ref.yaxis.minor + lim = ref.get_ylim() + other.set_ylim(*lim, emit=False, auto=ref.get_autoscaley_on()) + other.yaxis._scale = ref.yaxis._scale def _add_subplots( self, @@ -1753,6 +2023,7 @@ def _update_super_title(self, title, **kwargs): if title is not None: self._suptitle.set_text(title) + @_clear_border_cache @docstring._concatenate_inherited @docstring._snippet_manager def add_axes(self, rect, **kwargs): @@ -1847,7 +2118,6 @@ def _align_content(): # noqa: E306 # subsequent tight layout really weird. Have to resize twice. _draw_content() if not gs: - print("hello") return if aspect: gs._auto_layout_aspect() @@ -1993,12 +2263,6 @@ def format( } ax.format(rc_kw=rc_kw, rc_mode=rc_mode, skip_figure=True, **kw, **kwargs) ax.number = store_old_number - # When we apply formatting to all axes, we need - # to potentially adjust the labels. - - if len(axs) == len(self.axes) and self._get_sharing_level() > 0: - self._share_labels_with_others() - # Warn unused keyword argument(s) kw = { key: value @@ -2010,53 +2274,6 @@ def format( f"Ignoring unused projection-specific format() keyword argument(s): {kw}" # noqa: E501 ) - def _share_labels_with_others(self, *, which="both"): - """ - Helpers function to ensure the labels - are shared for rectilinear GeoAxes. - """ - # Turn all labels off - # Note: this action performs it for all the axes in - # the figure. We use the stale here to only perform - # it once as it is an expensive action. - border_axes = self._get_border_axes(same_type=False) - # Recode: - recoded = {} - for direction, axes in border_axes.items(): - for axi in axes: - recoded[axi] = recoded.get(axi, []) + [direction] - - # We turn off the tick labels when the scale and - # ticks are shared (level > 0) - are_ticks_on = False - default = dict( - labelleft=are_ticks_on, - labelright=are_ticks_on, - labeltop=are_ticks_on, - labelbottom=are_ticks_on, - ) - for axi in self._iter_axes(hidden=False, panels=False, children=False): - # Turn the ticks on or off depending on the position - sides = recoded.get(axi, []) - turn_on_or_off = default.copy() - # The axis will be a border if it is either - # (a) on the edge - # (b) not next to a subplot - # (c) not next to a subplot of the same kind - for side in sides: - sidelabel = f"label{side}" - is_label_on = axi._is_ticklabel_on(sidelabel) - if is_label_on: - # When we are a border an the labels are on - # we keep them on - assert sidelabel in turn_on_or_off - turn_on_or_off[sidelabel] = True - - if isinstance(axi, paxes.GeoAxes): - axi._toggle_gridliner_labels(**turn_on_or_off) - else: - axi.tick_params(which=which, **turn_on_or_off) - @docstring._concatenate_inherited @docstring._snippet_manager def colorbar( diff --git a/ultraplot/gridspec.py b/ultraplot/gridspec.py index 159cac2c5..eaa20a5fb 100644 --- a/ultraplot/gridspec.py +++ b/ultraplot/gridspec.py @@ -203,11 +203,12 @@ def _get_grid_span(self, hidden=False) -> (int, int, int, int): """ gs = self.get_gridspec() nrows, ncols = gs.nrows_total, gs.ncols_total - if not hidden: + if hidden: + x, y = np.unravel_index(self.num1, (nrows, ncols)) + else: nrows, ncols = gs.nrows, gs.ncols - # Use num1 or num2 - decoded = gs._decode_indices(self.num1) - x, y = np.unravel_index(decoded, (nrows, ncols)) + decoded = gs._decode_indices(self.num1) + x, y = np.unravel_index(decoded, (nrows, ncols)) span = self._get_rows_columns() xspan = span[1] - span[0] + 1 # inclusive @@ -411,13 +412,15 @@ def _normalize_index(key, size, axis=None): # noqa: E306 num1, num2 = self._encode_indices(num1, num2) return _SubplotSpec(self, num1, num2) - def _encode_indices(self, *args, which=None): + def _encode_indices(self, *args, which=None, panel=False): """ - Convert indices from the "unhidden" gridspec geometry into indices for the + Convert indices from the selected gridspec geometry into indices for the total geometry. If `which` is not passed these should be flattened indices. + When `panel` is True, indices are interpreted relative to panel slots + along the specified axis; otherwise they refer to non-panel slots. """ nums = [] - idxs = self._get_indices(which) + idxs = self._get_indices(which=which, panel=panel) for arg in args: try: nums.append(idxs[arg]) @@ -425,13 +428,15 @@ def _encode_indices(self, *args, which=None): raise ValueError(f"Invalid gridspec index {arg}.") return nums[0] if len(nums) == 1 else nums - def _decode_indices(self, *args, which=None): + def _decode_indices(self, *args, which=None, panel=False): """ - Convert indices from the total geometry into the "unhidden" gridspec + Convert indices from the total geometry into the selected gridspec geometry. If `which` is not passed these should be flattened indices. + When `panel` is True, indices are interpreted relative to panel slots + along the specified axis; otherwise they refer to non-panel slots. """ nums = [] - idxs = self._get_indices(which) + idxs = self._get_indices(which=which, panel=panel) for arg in args: try: nums.append(idxs.index(arg)) @@ -1552,17 +1557,29 @@ def __getitem__(self, key): f"{self.__class__.__name__} has no gridspec, cannot index with {key!r}." ) # Build grid with None for empty slots - grid = np.full((gs.nrows_total, gs.ncols_total), None, dtype=object) + from .utils import _get_subplot_layout + + print(self) + grid = _get_subplot_layout(gs, [i for i in self])[0] + + # Determine if along each axis this grid consists only of panel slots + used_rows = set() + used_cols = set() for ax in self: - spec = ax.get_subplotspec() - x1, x2, y1, y2 = spec._get_rows_columns(ncols=gs.ncols_total) - grid[x1 : x2 + 1, y1 : y2 + 1] = ax + ss = ax.get_subplotspec().get_topmost_subplotspec() + r1, r2, c1, c2 = ss._get_rows_columns() + used_rows.update(range(r1, r2 + 1)) + used_cols.update(range(c1, c2 + 1)) + panel_h = all(gs._hpanels[i] for i in used_rows) if used_rows else False + panel_w = all(gs._wpanels[i] for i in used_cols) if used_cols else False new_key = [] for which, keyi in zip("hw", key): try: - encoded_keyi = gs._encode_indices(keyi, which=which) - except: + panel_flag = panel_h if which == "h" else panel_w + encoded_keyi = gs._encode_indices(keyi, which=which, panel=panel_flag) + print(encoded_keyi) + except Exception: raise IndexError( f"Attempted to access {key=} for gridspec {grid.shape=}" ) @@ -1573,6 +1590,7 @@ def __getitem__(self, key): objs = [obj for obj in objs.flat if obj is not None] elif not isinstance(objs, list): objs = [objs] + print(objs) if len(objs) == 1: return objs[0] diff --git a/ultraplot/tests/conftest.py b/ultraplot/tests/conftest.py index e6848abaa..db2482d90 100644 --- a/ultraplot/tests/conftest.py +++ b/ultraplot/tests/conftest.py @@ -3,7 +3,6 @@ import warnings, logging logging.getLogger("matplotlib").setLevel(logging.ERROR) - SEED = 51423 diff --git a/ultraplot/tests/test_2dplots.py b/ultraplot/tests/test_2dplots.py index 13f084c64..a2b75319d 100644 --- a/ultraplot/tests/test_2dplots.py +++ b/ultraplot/tests/test_2dplots.py @@ -30,12 +30,12 @@ def test_auto_diverging1(rng): """ # Test with basic data fig = uplt.figure() - # fig.format(collabels=('Auto sequential', 'Auto diverging'), suptitle='Default') ax = fig.subplot(121) ax.pcolor(rng.random((10, 10)) * 5, colorbar="b") ax = fig.subplot(122) ax.pcolor(rng.random((10, 10)) * 5 - 3.5, colorbar="b") fig.format(toplabels=("Sequential", "Diverging")) + fig.canvas.draw() return fig diff --git a/ultraplot/tests/test_axes.py b/ultraplot/tests/test_axes.py index a04c2233a..370f2c520 100644 --- a/ultraplot/tests/test_axes.py +++ b/ultraplot/tests/test_axes.py @@ -8,6 +8,46 @@ from ultraplot.internals.warnings import UltraPlotWarning +@pytest.mark.parametrize( + "side,row_sel,col_sel,expected_len,fmt_kwargs", + [ + ("right", slice(None), -1, 2, {"yticklabelloc": "l"}), + ("left", slice(None), -1, 2, {"yticklabelloc": "l"}), + ("top", -1, slice(None), 2, {"xticklabelloc": "b"}), + ("bottom", -1, slice(None), 2, {"xticklabelloc": "b"}), + ], +) +@pytest.mark.mpl_image_compare +def test_panel_only_gridspec_indexing_panels( + side, row_sel, col_sel, expected_len, fmt_kwargs +): + """ + Ensure indexing works for grids that consist only of panel axes across sides. + For left/right panels, we index the last panel column with pax[:, -1]. + For top/bottom panels, we index the last panel row with pax[-1, :]. + """ + fig, ax = uplt.subplots(nrows=2, ncols=2) + pax = ax.panel(side) + + # Should be able to index the desired panel slice without raising + sub = pax[row_sel, col_sel] + + # It should return the expected number of panel axes + try: + n = len(sub) + except TypeError: + pytest.fail("Expected a SubplotGrid selection, got a single Axes.") + else: + assert n == expected_len + + # And formatting should work on the selection + sub.format(**fmt_kwargs) + + # Draw to finalize layout and return figure for image comparison + fig.canvas.draw() + return fig + + @pytest.mark.parametrize( "value", [ @@ -352,7 +392,7 @@ def test_sharing_labels_top_right(): [3, 4, 5], [3, 4, 0], ], - 3, # default sharing level + True, # default sharing level {"xticklabelloc": "t", "yticklabelloc": "r"}, [1, 3, 4], # y-axis labels visible indices [0, 1, 4], # x-axis labels visible indices @@ -405,6 +445,7 @@ def check_state(ax, numbers, state, which): # Format axes with the specified tick label locations ax.format(**tick_loc) + fig.canvas.draw() # needed for sharing labels # Calculate the indices where labels should be hidden all_indices = list(range(len(ax))) diff --git a/ultraplot/tests/test_colorbar.py b/ultraplot/tests/test_colorbar.py index 0e19ad8c7..6781e3b81 100644 --- a/ultraplot/tests/test_colorbar.py +++ b/ultraplot/tests/test_colorbar.py @@ -363,6 +363,54 @@ def test_label_placement_fig_colorbar2(): return fig +def test_colorbar_does_not_promote_panel_group_with_share_false(): + """ + Colorbars should not affect panel group membership, and panels should + not promote sharing when the figure-level share is disabled. + """ + fig, ax = uplt.subplots(nrows=2, share=False) + ax[0].panel("right") + ax[0].colorbar("magma", loc="top") + fig.canvas.draw() + assert ax[0]._panel_sharey_group is False + + +def test_legend_does_not_promote_panel_group_with_share_false(): + """ + Legends should not affect panel group membership, and panels should + not promote sharing when the figure-level share is disabled. + """ + fig, ax = uplt.subplots(ncols=2, share=False) + ax[0].panel("top") + ax[0].legend(loc="right") + fig.canvas.draw() + assert ax[0]._panel_sharex_group is False + + +def test_border_axes_update_after_panel_with_colorbar_and_legend(): + """ + Adding a panel should update border axes cache even if colorbars/legends exist. + The main axes should no longer be considered the outermost on that side; the + new panel should be instead. + """ + fig, axs = uplt.subplots() + axi = axs[0] + # Add guides that could affect layout + axi.colorbar("magma", loc="top") + axi.legend(loc="right") + + before = fig._get_border_axes() + pax = axi.panel("right") + fig.canvas.draw() + after = fig._get_border_axes() + + # Right border before: main axes is outermost + assert axi in before.get("right", []) + # Right border after: main axes is no longer outermost; panel is + assert axi not in after.get("right", []) + assert pax in after.get("right", []) + + @pytest.mark.parametrize( ("labelloc", "cbarloc"), product( diff --git a/ultraplot/tests/test_figure.py b/ultraplot/tests/test_figure.py index 0e92f8f2f..cffa3c7f6 100644 --- a/ultraplot/tests/test_figure.py +++ b/ultraplot/tests/test_figure.py @@ -58,7 +58,17 @@ def test_unsharing_different_rectilinear(): """ with pytest.warns(uplt.internals.warnings.UltraPlotWarning): fig, ax = uplt.subplots(ncols=2, proj=("cyl", "merc"), share="all") - uplt.close(fig) + + +def test_get_renderer_basic(): + """ + Test that _get_renderer returns a renderer object. + """ + fig, ax = uplt.subplots() + renderer = fig._get_renderer() + # Renderer should not be None and should have draw_path method + assert renderer is not None + assert hasattr(renderer, "draw_path") def test_figure_sharing_toggle(): diff --git a/ultraplot/tests/test_geographic.py b/ultraplot/tests/test_geographic.py index 35789a54d..30911c176 100644 --- a/ultraplot/tests/test_geographic.py +++ b/ultraplot/tests/test_geographic.py @@ -296,6 +296,7 @@ def are_labels_on(ax, which=["top", "bottom", "right", "left"]) -> tuple[bool]: settings = dict(land=True, ocean=True, labels="both") fig, ax = uplt.subplots(layout, share="all", proj="cyl") ax.format(**settings) + fig.canvas.draw() # needed for sharing labels for axi in ax: state = are_labels_on(axi) expectation = expectations[axi.number - 1] @@ -354,6 +355,80 @@ def test_toggle_gridliner_labels(): uplt.close(fig) +def test_geo_panel_group_respects_figure_share(): + """ + Ensure that panel-only configurations do not promote sharing when figure-level + sharing is disabled, and do promote when figure-level sharing is enabled for GeoAxes. + """ + # Right-only panels with share=False should NOT mark y panel-group + fig, ax = uplt.subplots(nrows=2, proj="cyl", share=False) + ax[0].panel("right") + fig.canvas.draw() + assert ax[0]._panel_sharey_group is False + + # Right-only panels with share='labels' SHOULD mark y panel-group + fig2, ax2 = uplt.subplots(nrows=2, proj="cyl", share="labels") + ax2[0].panel("right") + fig2.canvas.draw() + assert ax2[0]._panel_sharey_group is True + + # Top-only panels with share=False should NOT mark x panel-group + fig3, ax3 = uplt.subplots(ncols=2, proj="cyl", share=False) + ax3[0].panel("top") + fig3.canvas.draw() + assert ax3[0]._panel_sharex_group is False + + # Top-only panels with share='labels' SHOULD mark x panel-group + fig4, ax4 = uplt.subplots(ncols=2, proj="cyl", share="labels") + ax4[0].panel("top") + fig4.canvas.draw() + assert ax4[0]._panel_sharex_group is True + + +def test_geo_panel_share_flag_controls_membership(): + """ + Panels created with share=False should not join panel share groups even when + the figure has sharing enabled, for GeoAxes as well. + """ + # Y panels: right-only with panel share=False + fig, ax = uplt.subplots(nrows=2, proj="cyl", share="labels") + ax[0].panel("right", share=False) + fig.canvas.draw() + assert ax[0]._panel_sharey_group is False + + # X panels: top-only with panel share=False + fig2, ax2 = uplt.subplots(ncols=2, proj="cyl", share="labels") + ax2[0].panel("top", share=False) + fig2.canvas.draw() + assert ax2[0]._panel_sharex_group is False + + +def test_geo_non_rectilinear_right_panel_forces_no_share_and_warns(): + """ + Non-rectilinear Geo projections should not allow panel sharing; adding a right panel + should warn and force panel share=False, and not promote the main axes to y panel group. + """ + fig, ax = uplt.subplots(nrows=1, proj="aeqd", share="labels") + with pytest.warns(uplt.warnings.UltraPlotWarning): + pax = ax[0].panel("right") # should warn and force share=False internally + fig.canvas.draw() + assert ax[0]._panel_sharey_group is False + assert pax._panel_share is False + + +def test_geo_non_rectilinear_top_panel_forces_no_share_and_warns(): + """ + Non-rectilinear Geo projections should not allow panel sharing; adding a top panel + should warn and force panel share=False, and not promote the main axes to x panel group. + """ + fig, ax = uplt.subplots(ncols=1, proj="aeqd", share="labels") + with pytest.warns(uplt.warnings.UltraPlotWarning): + pax = ax[0].panel("top") # should warn and force share=False internally + fig.canvas.draw() + assert ax[0]._panel_sharex_group is False + assert pax._panel_share is False + + def test_sharing_geo_limits(): """ Test that we can share limits on GeoAxes @@ -491,7 +566,8 @@ def test_get_gridliner_labels_cartopy(): uplt.close(fig) -def test_sharing_levels(): +@pytest.mark.parametrize("level", [0, 1, 2, 3, 4]) +def test_sharing_levels(level): """ We can share limits or labels. We check if we can do both for the GeoAxes. @@ -515,7 +591,6 @@ def test_sharing_levels(): x = np.array([0, 10]) y = np.array([0, 10]) - sharing_levels = [0, 1, 2, 3, 4] lonlim = latlim = np.array((-10, 10)) def assert_views_are_sharing(ax): @@ -551,46 +626,42 @@ def assert_views_are_sharing(ax): l2 = np.linalg.norm( np.asarray(latview) - np.asarray(target_lat), ) - level = ax.figure._get_sharing_level() + level = ax.figure._sharex if level <= 1: share_x = share_y = False assert np.allclose(l1, 0) == share_x assert np.allclose(l2, 0) == share_y - for level in sharing_levels: - fig, ax = uplt.subplots(ncols=2, nrows=2, proj="cyl", share=level) - ax.format(labels="both") - for axi in ax: - axi.format( - lonlim=lonlim * axi.number, - latlim=latlim * axi.number, - ) + fig, ax = uplt.subplots(ncols=2, nrows=2, proj="cyl", share=level) + ax.format(labels="both") + for axi in ax: + axi.format( + lonlim=lonlim * axi.number, + latlim=latlim * axi.number, + ) - fig.canvas.draw() - for idx, axi in enumerate(ax): - axi.plot(x * (idx + 1), y * (idx + 1)) - - fig.canvas.draw() # need this to update the labels - # All the labels should be on - for axi in ax: - side_labels = axi._get_gridliner_labels( - left=True, - right=True, - top=True, - bottom=True, - ) - s = 0 - for dir, labels in side_labels.items(): - s += any([label.get_visible() for label in labels]) - - assert_views_are_sharing(axi) - # When we share the labels but not the limits, - # we expect all ticks to be on - if level == 0: - assert s == 4 - else: - assert s == 2 - uplt.close(fig) + fig.canvas.draw() + for idx, axi in enumerate(ax): + axi.plot(x * (idx + 1), y * (idx + 1)) + + # All the labels should be on + for axi in ax: + + s = sum( + [ + 1 if axi._is_ticklabel_on(side) else 0 + for side in "labeltop labelbottom labelleft labelright".split() + ] + ) + + assert_views_are_sharing(axi) + # When we share the labels but not the limits, + # we expect all ticks to be on + if level > 2: + assert s == 2 + else: + assert s == 4 + uplt.close(fig) @pytest.mark.mpl_image_compare @@ -616,8 +687,10 @@ def test_cartesian_and_geo(rng): ax.format(land=True, lonlim=(-10, 10), latlim=(-10, 10)) ax[0].pcolormesh(rng.random((10, 10))) ax[1].scatter(*rng.random((2, 100))) - ax[0]._apply_axis_sharing() - assert mocked.call_count == 2 + fig.canvas.draw() + assert ( + mocked.call_count >= 2 + ) # needs to be called at least twice; one for each axis return fig @@ -676,21 +749,38 @@ def test_check_tricontourf(): def test_panels_geo(): fig, ax = uplt.subplots(proj="cyl") ax.format(labels=True) - for dir in "top bottom right left".split(): + dirs = "top bottom right left".split() + for dir in dirs: pax = ax.panel_axes(dir) - match dir: - case "top": - assert len(pax.get_xticklabels()) > 0 - assert len(pax.get_yticklabels()) > 0 - case "bottom": - assert len(pax.get_xticklabels()) > 0 - assert len(pax.get_yticklabels()) > 0 - case "left": - assert len(pax.get_xticklabels()) > 0 - assert len(pax.get_yticklabels()) > 0 - case "right": - assert len(pax.get_xticklabels()) > 0 - assert len(pax.get_yticklabels()) > 0 + fig.canvas.draw() + pax = ax[0]._panel_dict["left"][-1] + assert pax._is_ticklabel_on("labelleft") # should not error + assert not pax._is_ticklabel_on("labelright") + assert not pax._is_ticklabel_on("labeltop") + assert pax._is_ticklabel_on("labelbottom") + + pax = ax[0]._panel_dict["top"][-1] + assert pax._is_ticklabel_on("labelleft") # should not error + assert not pax._is_ticklabel_on("labelright") + assert not pax._is_ticklabel_on("labeltop") + assert not pax._is_ticklabel_on("labelbottom") + + pax = ax[0]._panel_dict["bottom"][-1] + assert pax._is_ticklabel_on("labelleft") # should not error + assert not pax._is_ticklabel_on("labelright") + assert not pax._is_ticklabel_on("labeltop") + assert pax._is_ticklabel_on("labelbottom") + + pax = ax[0]._panel_dict["right"][-1] + assert not pax._is_ticklabel_on("labelleft") # should not error + assert not pax._is_ticklabel_on("labelright") + assert not pax._is_ticklabel_on("labeltop") + assert pax._is_ticklabel_on("labelbottom") + + for dir in dirs: + not ax[0]._is_ticklabel_on(f"label{dir}") + + return fig @pytest.mark.mpl_image_compare @@ -717,10 +807,11 @@ def test_geo_with_panels(rng): elevation = np.clip(elevation, 0, 4000) fig, ax = uplt.subplots(nrows=2, proj="cyl") - pax = ax[0].panel("r") - pax.barh(lat_zoom, elevation.sum(axis=1)) - pax = ax[1].panel("r") - pax.barh(lat_zoom - 30, elevation.sum(axis=1)) + ax.format(lonlabels="r") # by default they are off + pax = ax.panel("r") + z = elevation.sum() + pax[0].barh(lat_zoom, elevation.sum(axis=1)) + pax[1].barh(lat_zoom - 30, elevation.sum(axis=1)) ax[0].pcolormesh( lon_zoom, lat_zoom, @@ -807,6 +898,7 @@ def are_labels_on(ax, which=("top", "bottom", "right", "left")) -> tuple[bool]: h = ax.imshow(data)[0] ax.format(land=True, labels="both") # need this otherwise no labels are printed fig.colorbar(h, loc="r") + fig.canvas.draw() # needed to invoke axis sharing expectations = ( [True, False, False, True], diff --git a/ultraplot/tests/test_inset.py b/ultraplot/tests/test_inset.py index ea1bf76af..9a1dfc611 100644 --- a/ultraplot/tests/test_inset.py +++ b/ultraplot/tests/test_inset.py @@ -7,6 +7,7 @@ def test_inset_basic(): # spacing, aspect ratios, and axis sharing gs = uplt.GridSpec(nrows=2, ncols=2) fig = uplt.figure(refwidth=1.5, share=False) + fig.canvas.draw() for ss, side in zip(gs, "tlbr"): ax = fig.add_subplot(ss) px = ax.panel_axes(side, width="3em") diff --git a/ultraplot/tests/test_subplots.py b/ultraplot/tests/test_subplots.py index 9a6d6d10d..13b402951 100644 --- a/ultraplot/tests/test_subplots.py +++ b/ultraplot/tests/test_subplots.py @@ -19,6 +19,18 @@ def test_align_labels(): return fig +@pytest.mark.mpl_image_compare +@pytest.mark.parametrize("share", [0, 1, 2, 3, 4]) +def test_all_share_levels(share): + N = 10 + x = np.arange(N) + fig, ax = uplt.subplots(nrows=2, ncols=2, share=share) + ax[0].plot(x, x) + ax[-1].plot(x * 1000, x * 1000) + ax.format(xlabel="xlabel", ylabel="ylabel", suptitle=f"Share level={share}") + return fig + + @pytest.mark.mpl_image_compare def test_share_all_basic(): """ @@ -290,29 +302,53 @@ def test_panel_sharing_top_right(layout): for dir in "left right top bottom".split(): pax = ax[0].panel(dir) fig.canvas.draw() # force redraw tick labels - for dir, paxs in ax[0]._panel_dict.items(): - # Since we are sharing some of the ticks - # should be hidden depending on where the panel is - # in the grid - for pax in paxs: - match dir: - case "left": - assert pax._is_ticklabel_on("labelleft") - assert pax._is_ticklabel_on("labelbottom") - case "top": - assert pax._is_ticklabel_on("labeltop") == False - assert pax._is_ticklabel_on("labelbottom") == False - assert pax._is_ticklabel_on("labelleft") - case "right": - print(pax._is_ticklabel_on("labelright")) - assert pax._is_ticklabel_on("labelright") == False - assert pax._is_ticklabel_on("labelbottom") - case "bottom": - assert pax._is_ticklabel_on("labelleft") - assert pax._is_ticklabel_on("labelbottom") == False - - # The sharing axis is not showing any ticks - assert ax[0]._is_ticklabel_on(dir) == False + + # Main panel: ticks are off + assert not ax[0]._is_ticklabel_on("labelleft") + assert not ax[0]._is_ticklabel_on("labelright") + assert not ax[0]._is_ticklabel_on("labeltop") + assert not ax[0]._is_ticklabel_on("labelbottom") + + # For panels the inside ticks are off + panel = ax[0]._panel_dict["left"][-1] + assert panel._is_ticklabel_on("labelleft") + assert panel._is_ticklabel_on("labelbottom") + assert not panel._is_ticklabel_on("labelright") + assert not panel._is_ticklabel_on("labeltop") + + panel = ax[0]._panel_dict["top"][-1] + assert panel._is_ticklabel_on("labelleft") + assert not panel._is_ticklabel_on("labelbottom") + assert not panel._is_ticklabel_on("labelright") + assert not panel._is_ticklabel_on("labeltop") + + panel = ax[0]._panel_dict["right"][-1] + assert not panel._is_ticklabel_on("labelleft") + assert panel._is_ticklabel_on("labelbottom") + assert not panel._is_ticklabel_on("labelright") + assert not panel._is_ticklabel_on("labeltop") + + panel = ax[0]._panel_dict["bottom"][-1] + assert panel._is_ticklabel_on("labelleft") + assert not panel._is_ticklabel_on("labelbottom") + assert not panel._is_ticklabel_on("labelright") + assert not panel._is_ticklabel_on("labeltop") + + assert not ax[1]._is_ticklabel_on("labelleft") + assert not ax[1]._is_ticklabel_on("labelright") + assert not ax[1]._is_ticklabel_on("labeltop") + assert not ax[1]._is_ticklabel_on("labelbottom") + + assert ax[2]._is_ticklabel_on("labelleft") + assert not ax[2]._is_ticklabel_on("labelright") + assert not ax[2]._is_ticklabel_on("labeltop") + assert ax[2]._is_ticklabel_on("labelbottom") + + assert not ax[3]._is_ticklabel_on("labelleft") + assert not ax[3]._is_ticklabel_on("labelright") + assert not ax[3]._is_ticklabel_on("labeltop") + assert ax[3]._is_ticklabel_on("labelbottom") + return fig @@ -330,6 +366,69 @@ def test_uneven_span_subplots(rng): @pytest.mark.mpl_image_compare +def test_uneven_span_subplots(rng): + fig = uplt.figure(refwidth=1, refnum=5, span=False) + axs = fig.subplots([[1, 1, 2], [3, 4, 2], [3, 4, 5]], hratios=[2.2, 1, 1]) + axs.format(xlabel="xlabel", ylabel="ylabel", suptitle="Complex SubplotGrid") + axs[0].format(ec="black", fc="gray1", lw=1.4) + axs[1, 1:].format(fc="blush") + axs[1, :1].format(fc="sky blue") + axs[-1, -1].format(fc="gray4", grid=False) + axs[0].plot((rng.random((50, 10)) - 0.5).cumsum(axis=0), cycle="Grays_r", lw=2) + return fig + + +@pytest.mark.parametrize("share_panels", [True, False]) +def test_panel_ticklabels_all_sides_share_and_no_share(share_panels): + # 2x2 grid; add panels on all sides of the first axes + fig, ax = uplt.subplots(nrows=2, ncols=2) + axi = ax[0] + + # Create panels on all sides with configurable sharing + pax_left = axi.panel("left", share=share_panels) + pax_right = axi.panel("right", share=share_panels) + pax_top = axi.panel("top", share=share_panels) + pax_bottom = axi.panel("bottom", share=share_panels) + + # Force draw so ticklabel state is resolved + fig.canvas.draw() + + def assert_panel(axi_panel, side, share_flag): + on_left = axi_panel._is_ticklabel_on("labelleft") + on_right = axi_panel._is_ticklabel_on("labelright") + on_top = axi_panel._is_ticklabel_on("labeltop") + on_bottom = axi_panel._is_ticklabel_on("labelbottom") + + # Inside (toward the main) must be off in all cases + if side == "left": + # Inside is right + assert not on_right + elif side == "right": + # Inside is left + assert not on_left + elif side == "top": + # Inside is bottom + assert not on_bottom + elif side == "bottom": + # Inside is top + assert not on_top + + if not share_flag: + # For non-sharing panels, prefer outside labels on for top/right + if side == "right": + assert on_right + if side == "top": + assert on_top + # For left/bottom non-sharing, we don't enforce outside on here + # (baseline may keep left/bottom on the main) + + # Check each panel side + assert_panel(pax_left, "left", share_panels) + assert_panel(pax_right, "right", share_panels) + assert_panel(pax_top, "top", share_panels) + assert_panel(pax_bottom, "bottom", share_panels) + + def test_non_rectangular_outside_labels_top(): """ Check that non-rectangular layouts work with outside labels. @@ -361,5 +460,174 @@ def test_outside_labels_with_panels(): for idx in range(5): ax[0].panel("left") ax.format(leftlabels=["A", "B"]) - uplt.show(block=1) return fig + + +def test_panel_group_membership_respects_figure_share_flags(): + """ + Ensure that panel-only configurations do not promote sharing when figure-level + sharing is disabled, and do promote when figure-level sharing is enabled. + """ + # Right-only panels with share=False should NOT mark y panel-group + fig, ax = uplt.subplots(nrows=2, share=False) + ax[0].panel("right") + fig.canvas.draw() + assert ax[0]._panel_sharey_group is False + + # Right-only panels with share='labels' SHOULD mark y panel-group + fig2, ax2 = uplt.subplots(nrows=2, share="labels") + ax2[0].panel("right") + fig2.canvas.draw() + assert ax2[0]._panel_sharey_group is True + + # Top-only panels with share=False should NOT mark x panel-group + fig3, ax3 = uplt.subplots(ncols=2, share=False) + ax3[0].panel("top") + fig3.canvas.draw() + assert ax3[0]._panel_sharex_group is False + + # Top-only panels with share='labels' SHOULD mark x panel-group + fig4, ax4 = uplt.subplots(ncols=2, share="labels") + ax4[0].panel("top") + fig4.canvas.draw() + assert ax4[0]._panel_sharex_group is True + + +def test_panel_share_flag_controls_group_membership(): + """ + Panels created with share=False should not join panel share groups even when + the figure has sharing enabled. + """ + # Y panels: right-only with panel share=False + fig, ax = uplt.subplots(nrows=2, share="labels") + ax[0].panel("right", share=False) + fig.canvas.draw() + assert ax[0]._panel_sharey_group is False + + # X panels: top-only with panel share=False + fig2, ax2 = uplt.subplots(ncols=2, share="labels") + ax2[0].panel("top", share=False) + fig2.canvas.draw() + assert ax2[0]._panel_sharex_group is False + + +def test_ticklabels_with_guides_share_true_cartesian(): + """ + With share=True, tick labels should only appear on bottom row and left column + even when colorbars and legends are present on borders. + """ + rng = np.random.default_rng(0) + fig, ax = uplt.subplots(nrows=2, ncols=2, share=True) + m = ax[0].pcolormesh(rng.random((8, 8)), colorbar="r") # outer right colorbar + ax[3].legend(loc="bottom") # bottom legend + fig.canvas.draw() + for i, axi in enumerate(ax): + on_left = axi._is_ticklabel_on("labelleft") + on_right = axi._is_ticklabel_on("labelright") + on_top = axi._is_ticklabel_on("labeltop") + on_bottom = axi._is_ticklabel_on("labelbottom") + + # Left column indices: 0, 2 + if i % 2 == 0: + assert on_left + assert not on_right + else: + assert not on_left + assert not on_right + + # Bottom row indices: 2, 3 + if i // 2 == 1: + assert on_bottom + assert not on_top + else: + assert not on_bottom + assert not on_top + + +def test_ticklabels_with_guides_share_true_geo(): + """ + With share=True on GeoAxes, tick labels should only appear on bottom row and left column + even when colorbars and legends are present on borders. + """ + rng = np.random.default_rng(1) + fig, ax = uplt.subplots(nrows=2, ncols=2, share=True, proj="cyl") + ax.format(labels="both", land=True) # ensure gridliner labels can be toggled + ax[0].pcolormesh(rng.random((10, 10)), colorbar="r") # outer right colorbar + ax[3].legend(loc="bottom") # bottom legend + fig.canvas.draw() + for i, axi in enumerate(ax): + on_left = axi._is_ticklabel_on("labelleft") + on_right = axi._is_ticklabel_on("labelright") + on_top = axi._is_ticklabel_on("labeltop") + on_bottom = axi._is_ticklabel_on("labelbottom") + if i == 0: + assert on_left + assert on_top + assert not on_bottom + assert not on_right + elif i == 1: + assert not on_left + assert on_top + assert not on_bottom + assert on_right + elif i == 2: + assert on_left + assert not on_top + assert on_bottom + assert not on_right + else: # i == 3 + assert not on_left + assert not on_top + assert on_bottom + assert on_right + + +def test_deep_panel_stacks_border_detection(): + """ + Multiple stacked panels on the same side should mark only the outermost panel + as the figure border for that side. The main axes should not be considered a + border once a panel exists on that side. + """ + fig, axs = uplt.subplots() + axi = axs[0] + # Stack multiple right panels + p1 = axi.panel("right") + p2 = axi.panel("right") + p3 = axi.panel("right") # outermost + # Stack multiple top panels + t1 = axi.panel("top") + t2 = axi.panel("top") # outermost + fig.canvas.draw() + + borders = fig._get_border_axes(force_recalculate=True) + # Main axes should not be the border on right/top anymore + assert axi not in borders.get("right", []) + assert axi not in borders.get("top", []) + # Outermost panels should be borders + assert p3 in borders.get("right", []) + assert t2 in borders.get("top", []) + + +def test_right_panel_and_right_colorbar_border_priority(): + """ + When both a right panel and a right colorbar exist, the colorbar (added last) + should be considered the outermost border on the right. The main axes should + not be listed as a right border. Accept either the panel or the colorbar + container as the right border, depending on backend/implementation details. + """ + rng = np.random.default_rng(0) + fig, axs = uplt.subplots() + axi = axs[0] + # Add a right panel first + pax = axi.panel("right") + # Add a right colorbar after plotting, making it the outermost right object + m = axi.pcolormesh(rng.random((5, 5))) + cbar = axi.colorbar(m, loc="right") + fig.canvas.draw() + + borders = fig._get_border_axes(force_recalculate=True) + right_borders = borders.get("right", []) + # Main axes should not be the right border anymore + assert axi not in right_borders + # Either the panel or the colorbar axes should be recognized as a right border + assert (pax in right_borders) or (cbar.ax in right_borders) diff --git a/ultraplot/utils.py b/ultraplot/utils.py index 1b1b97a95..be2a439a8 100644 --- a/ultraplot/utils.py +++ b/ultraplot/utils.py @@ -918,7 +918,8 @@ def _get_subplot_layout( axis types. This function is used internally to determine the layout of axes in a GridSpec. """ - grid = np.zeros((gs.nrows, gs.ncols)) + grid = np.zeros((gs.nrows_total, gs.ncols_total), dtype=object) + grid.fill(None) grid_axis_type = np.zeros((gs.nrows, gs.ncols)) # Collect grouper based on kinds of axes. This # would allow us to share labels across types @@ -928,7 +929,7 @@ def _get_subplot_layout( for axi in all_axes: # Infer coordinate from grdispec spec = axi.get_subplotspec() - spans = spec._get_grid_span() + spans = spec._get_grid_span(hidden=True) rowspan = spans[:2] colspan = spans[-2:] @@ -936,7 +937,7 @@ def _get_subplot_layout( grid[ slice(*rowspan), slice(*colspan), - ] = axi.number + ] = axi # Allow grouping of mixed types axis_type = 1 @@ -996,22 +997,28 @@ def find_edge_for( direction: str, d: tuple[int, int], ) -> tuple[str, bool]: - from itertools import product - """ Setup search for a specific direction. """ + from itertools import product + # Retrieve where the axis is in the grid spec = self.ax.get_subplotspec() - spans = spec._get_grid_span() + shape = (spec.get_gridspec().nrows_total, spec.get_gridspec().ncols_total) + x, y = np.unravel_index(spec.num1, shape) + spans = spec._get_rows_columns() rowspan = spans[:2] colspan = spans[-2:] - xs = range(*rowspan) - ys = range(*colspan) + + a = rowspan[1] - rowspan[0] + b = colspan[1] - colspan[0] + xs = range(x, x + a + 1) + ys = range(y, y + b + 1) + is_border = False - for x, y in product(xs, ys): - pos = (x, y) + for xl, yl in product(xs, ys): + pos = (xl, yl) if self.is_border(pos, d): is_border = True break @@ -1026,27 +1033,34 @@ def is_border( Recursively move over the grid by following the direction. """ x, y = pos - # Check if we are at an edge of the grid (out-of-bounds). - if x < 0: - return True - elif x > self.grid.shape[0] - 1: + # Edge of grid (out-of-bounds) + if not (0 <= x < self.grid.shape[0] and 0 <= y < self.grid.shape[1]): return True - if y < 0: - return True - elif y > self.grid.shape[1] - 1: - return True + cell = self.grid[x, y] + dx, dy = direction + if cell is None: + return self.is_border((x + dx, y + dy), direction) + if getattr(cell, "_colorbar_fill", None) is not None: + return self.is_border((x + dx, y + dy), direction) - if self.grid[x, y] == 0 or self.grid_axis_type[x, y] != self.axis_type: - return True + if hasattr(cell, "_panel_hidden") and cell._panel_hidden: + return self.is_border((x + dx, y + dy), direction) - # Check if we reached a plot or an internal edge - if self.grid[x, y] != self.target and self.grid[x, y] > 0: - return self._check_ranges(direction, other=self.grid[x, y]) + if self.grid_axis_type[x, y] != self.axis_type: + # Allow traversing across the parent<->panel interface even when types differ + # e.g., GeoAxes main with cartesian panel or vice versa + if getattr(self.ax, "_panel_parent", None) is cell: + return self._check_ranges(direction, other=cell) + if getattr(cell, "_panel_parent", None) is self.ax: + return self._check_ranges(direction, other=cell) + return False - dx, dy = direction - pos = (x + dx, y + dy) - return self.is_border(pos, direction) + # Internal edge or plot reached + if cell != self.ax: + return self._check_ranges(direction, other=cell) + + return self.is_border((x + dx, y + dy), direction) def _check_ranges( self, @@ -1065,14 +1079,15 @@ def _check_ranges( can share x. """ this_spec = self.ax.get_subplotspec() - other_spec = self.ax.figure._subplot_dict[other].get_subplotspec() + other_spec = other.get_subplotspec() # Get the row and column spans of both axes - this_span = this_spec._get_grid_span() + this_span = this_spec._get_rows_columns() this_rowspan = this_span[:2] this_colspan = this_span[-2:] - other_span = other_spec._get_grid_span() + other_span = other_spec._get_grid_span(hidden=True) + other_span = other_spec._get_rows_columns() other_rowspan = other_span[:2] other_colspan = other_span[-2:] @@ -1089,7 +1104,30 @@ def _check_ranges( other_start, other_stop = other_rowspan if this_start == other_start and this_stop == other_stop: - return False # not a border + # We may hit an internal border if we are at + # the interface with a panel that is not sharing + dmap = { + (-1, 0): "bottom", + (1, 0): "top", + (0, -1): "left", + (0, 1): "right", + } + side = dmap[direction] + if self.ax.number is None: # panel + panel_side = getattr(self.ax, "_panel_side", None) + # Non-sharing panels: border only on their outward side + if not getattr(self.ax, "_panel_share", False): + return side == panel_side + # Sharing panels: border only if this is the outward side and this + # panel is the outer-most panel for that side relative to its parent. + parent = self.ax._panel_parent + panels = parent._panel_dict.get(panel_side, []) + if side == panel_side and panels and panels[-1] is self.ax: + return True + else: # main axis + if other._panel_parent and not other._panel_share: + return True + return False return True