Skip to content
30 changes: 27 additions & 3 deletions optimas/diagnostics/exploration_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,7 @@ def plot_history(
select: Optional[Dict] = None,
sort: Optional[Dict] = None,
top: Optional[Dict] = None,
parnames_as_titles: Optional[bool] = True,
show_top_evaluation_indices: Optional[bool] = False,
show_legend: Optional[bool] = False,
subplot_spec: Optional[SubplotSpec] = None,
Expand All @@ -617,6 +618,9 @@ def plot_history(
e.g. {'f': False} sort simulations according to f descendingly.
top: int, optional
Highlight the 'top' evaluations of every objective.
parnames_as_titles : bool, optional
Whether to show the parameter names as titles (instead of labels)
for a more compact layout.
show_top_evaluation_indices : bool, optional
Whether to show the indices of the top evaluations.
show_legend : bool, optional
Expand Down Expand Up @@ -674,13 +678,25 @@ def plot_history(

# Make figure
nplots = len(parnames)
if parnames_as_titles is True:
hspace = 0.40
else:
hspace = None

if subplot_spec is None:
fig = plt.figure(**figure_kw)
gs = GridSpec(nplots, 2, width_ratios=[0.8, 0.2], wspace=0.05)
gs = GridSpec(
nplots, 2, width_ratios=[0.8, 0.2], wspace=0.05, hspace=hspace
)
else:
fig = plt.gcf()
gs = GridSpecFromSubplotSpec(
nplots, 2, subplot_spec, width_ratios=[0.8, 0.2], wspace=0.05
nplots,
2,
subplot_spec,
width_ratios=[0.8, 0.2],
wspace=0.05,
hspace=hspace,
)

# Actual plotting
Expand Down Expand Up @@ -792,7 +808,15 @@ def plot_history(
ax_histy.set_ylim(ax_scatter.get_ylim())

# Tuning axes and labels
ax_scatter.set_ylabel(parnames[i])
if parnames_as_titles is True:
ax_scatter.set_title(
parnames[i].replace("_", " "),
fontdict={"fontsize": "x-small"},
loc="right",
pad=2,
)
else:
ax_scatter.set_ylabel(parnames[i])

if i != nplots - 1:
ax_scatter.tick_params(labelbottom=False)
Expand Down
12 changes: 8 additions & 4 deletions optimas/utils/ax/ax_model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from matplotlib.figure import Figure
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec, SubplotSpec
from matplotlib.axes import Axes
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Ax utilities for model building
try:
Expand Down Expand Up @@ -358,6 +359,7 @@ def plot_contour(
range_x: Optional[List[float]] = None,
range_y: Optional[List[float]] = None,
mode: Optional[Literal["mean", "sem", "both"]] = "mean",
cbar_location: Optional[Literal["top", "right"]] = "top",
show_trials: Optional[bool] = True,
show_contour: Optional[bool] = True,
show_contour_labels: Optional[bool] = False,
Expand Down Expand Up @@ -393,6 +395,8 @@ def plot_contour(
mode : str, optional.
Whether to plot the ``"mean"`` of the model, the standard error of
the mean ``"sem"``, or ``"both"``. By default, ``"mean"``.
cbar_location : str, optional.
Set position of the colorbar. By default, ``"top"``.
show_trials : bool
Whether to show the trials used to build the model. By default,
``True``.
Expand Down Expand Up @@ -503,7 +507,9 @@ def plot_contour(
# colormesh
pcolormesh_kw = dict(pcolormesh_kw or {})
im = ax.pcolormesh(xaxis, yaxis, f, shading="auto", **pcolormesh_kw)
cbar = plt.colorbar(im, ax=ax, location="top")
divider = make_axes_locatable(ax)
cax = divider.append_axes(cbar_location, size="2.5%", pad=0.1)
cbar = plt.colorbar(im, cax=cax, location=cbar_location)
cbar.set_label(labels[i])
ax.set(xlabel=param_x, ylabel=param_y)
# contour
Expand All @@ -518,9 +524,7 @@ def plot_contour(
linestyles="solid",
)
if show_contour_labels:
ax.clabel(
cset, inline=True, fmt="%1.1f", fontsize="xx-small"
)
ax.clabel(cset, inline=True, fontsize="xx-small")
if show_trials:
ax.scatter(
trials[param_x], trials[param_y], s=8, c="black", marker="o"
Expand Down