Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 33 additions & 8 deletions ultraplot/figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,17 +892,42 @@ def _parse_proj(
def _get_align_axes(self, side):
"""
Return the main axes along the edge of the figure.

For 'left'/'right': select one extreme axis per row (leftmost/rightmost).
For 'top'/'bottom': select one extreme axis per column (topmost/bottommost).
"""
x, y = "xy" if side in ("left", "right") else "yx"
axs = self._subplot_dict.values()
axs = tuple(self._subplot_dict.values())
if not axs:
return []
ranges = np.array([ax._range_subplotspec(x) for ax in axs])
edge = ranges[:, 0].min() if side in ("left", "top") else ranges[:, 1].max()
idx = 0 if side in ("left", "top") else 1
axs = [ax for ax in axs if ax._range_subplotspec(x)[idx] == edge]
axs = [ax for ax in sorted(axs, key=lambda ax: ax._range_subplotspec(y)[0])]
axs = [ax for ax in axs if ax.get_visible()]
if side not in ("left", "right", "top", "bottom"):
raise ValueError(f"Invalid side {side!r}.")
from .utils import _get_subplot_layout

grid = _get_subplot_layout(
self._gridspec, list(self._iter_axes(panels=False, hidden=False))
)[0]
# From the @side we find the first non-zero
# entry in each row or column and collect the axes
if side == "left":
options = grid
elif side == "right":
options = grid[:, ::-1]
elif side == "top":
options = grid.T
else: # bottom
options = grid.T[:, ::-1]
uids = set()
for option in options:
idx = np.where(option > 0)[0]
if idx.size > 0:
first = idx.min()
number = option[first].astype(int)
uids.add(number)
axs = []
# Collect correct axes
for axi in self._iter_axes():
if axi.number in uids and axi not in axs:
axs.append(axi)
return axs

def _get_border_axes(
Expand Down
36 changes: 36 additions & 0 deletions ultraplot/tests/test_subplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,3 +327,39 @@ def test_uneven_span_subplots(rng):
axs[-1, -1].format(fc="gray4", grid=False)
axs[0].plot((rng.random((50, 10)) - 0.5).cumsum(axis=0), cycle="Grays_r", lw=2)
return fig


@pytest.mark.mpl_image_compare
def test_non_rectangular_outside_labels_top():
"""
Check that non-rectangular layouts work with outside labels.
"""
layout = [
[1, 1, 2, 2],
[0, 3, 3, 0],
[4, 4, 5, 5],
]

fig, ax = uplt.subplots(
layout,
)
ax.format(rightlabels=[2, 3, 5])
ax.format(bottomlabels=[4, 5])
ax.format(leftlabels=[1, 3, 4])
ax.format(toplabels=[1, 2])
return fig


@pytest.mark.mpl_image_compare
def test_outside_labels_with_panels():
fig, ax = uplt.subplots(
ncols=2,
nrows=2,
)
# Create extreme case where we add a lot of panels
# This should push the left labels further left
for idx in range(5):
ax[0].panel("left")
ax.format(leftlabels=["A", "B"])
uplt.show(block=1)
return fig