diff --git a/ultraplot/figure.py b/ultraplot/figure.py index d44f31e61..981fa2424 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -892,17 +892,42 @@ def _parse_proj( def _get_align_axes(self, side): """ Return the main axes along the edge of the figure. + + For 'left'/'right': select one extreme axis per row (leftmost/rightmost). + For 'top'/'bottom': select one extreme axis per column (topmost/bottommost). """ - x, y = "xy" if side in ("left", "right") else "yx" - axs = self._subplot_dict.values() + axs = tuple(self._subplot_dict.values()) if not axs: return [] - ranges = np.array([ax._range_subplotspec(x) for ax in axs]) - edge = ranges[:, 0].min() if side in ("left", "top") else ranges[:, 1].max() - idx = 0 if side in ("left", "top") else 1 - axs = [ax for ax in axs if ax._range_subplotspec(x)[idx] == edge] - axs = [ax for ax in sorted(axs, key=lambda ax: ax._range_subplotspec(y)[0])] - axs = [ax for ax in axs if ax.get_visible()] + if side not in ("left", "right", "top", "bottom"): + raise ValueError(f"Invalid side {side!r}.") + from .utils import _get_subplot_layout + + grid = _get_subplot_layout( + self._gridspec, list(self._iter_axes(panels=False, hidden=False)) + )[0] + # From the @side we find the first non-zero + # entry in each row or column and collect the axes + if side == "left": + options = grid + elif side == "right": + options = grid[:, ::-1] + elif side == "top": + options = grid.T + else: # bottom + options = grid.T[:, ::-1] + uids = set() + for option in options: + idx = np.where(option > 0)[0] + if idx.size > 0: + first = idx.min() + number = option[first].astype(int) + uids.add(number) + axs = [] + # Collect correct axes + for axi in self._iter_axes(): + if axi.number in uids and axi not in axs: + axs.append(axi) return axs def _get_border_axes( diff --git a/ultraplot/tests/test_subplots.py b/ultraplot/tests/test_subplots.py index e215a90ee..9a6d6d10d 100644 --- a/ultraplot/tests/test_subplots.py +++ b/ultraplot/tests/test_subplots.py @@ -327,3 +327,39 @@ def test_uneven_span_subplots(rng): axs[-1, -1].format(fc="gray4", grid=False) axs[0].plot((rng.random((50, 10)) - 0.5).cumsum(axis=0), cycle="Grays_r", lw=2) return fig + + +@pytest.mark.mpl_image_compare +def test_non_rectangular_outside_labels_top(): + """ + Check that non-rectangular layouts work with outside labels. + """ + layout = [ + [1, 1, 2, 2], + [0, 3, 3, 0], + [4, 4, 5, 5], + ] + + fig, ax = uplt.subplots( + layout, + ) + ax.format(rightlabels=[2, 3, 5]) + ax.format(bottomlabels=[4, 5]) + ax.format(leftlabels=[1, 3, 4]) + ax.format(toplabels=[1, 2]) + return fig + + +@pytest.mark.mpl_image_compare +def test_outside_labels_with_panels(): + fig, ax = uplt.subplots( + ncols=2, + nrows=2, + ) + # Create extreme case where we add a lot of panels + # This should push the left labels further left + for idx in range(5): + ax[0].panel("left") + ax.format(leftlabels=["A", "B"]) + uplt.show(block=1) + return fig