Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 29 additions & 15 deletions src/petab_gui/controllers/mother_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import qtawesome as qta
import yaml
from PySide6.QtCore import Qt, QUrl
from PySide6.QtCore import Qt, QTimer, QUrl
from PySide6.QtGui import QAction, QDesktopServices, QKeySequence, QUndoStack
from PySide6.QtWidgets import (
QFileDialog,
Expand Down Expand Up @@ -231,13 +231,19 @@ def setup_connections(self):
self.sbml_controller.overwritten_model.connect(
self.parameter_controller.update_handler_sbml
)
# overwrite signals
# Plotting update. Regulated through a Timer
self._plot_update_timer = QTimer()
self._plot_update_timer.setSingleShot(True)
self._plot_update_timer.setInterval(0)
self._plot_update_timer.timeout.connect(self.init_plotter)
for controller in [
# self.measurement_controller,
self.condition_controller
self.measurement_controller,
self.condition_controller,
self.visualization_controller,
self.simulation_controller,
]:
controller.overwritten_df.connect(
self.init_plotter
self._schedule_plot_update
)

def setup_actions(self):
Expand Down Expand Up @@ -871,19 +877,23 @@ def init_plotter(self):
self.plotter = self.view.plot_dock
self.plotter.highlighter.click_callback = self._on_plot_point_clicked

def _on_plot_point_clicked(self, x, y, label):
def _on_plot_point_clicked(self, x, y, label, data_type):
# Extract observable ID from label, if formatted like 'obsId (label)'
meas_proxy = self.measurement_controller.proxy_model
proxy = self.measurement_controller.proxy_model
view = self.measurement_controller.view.table_view
if data_type == "simulation":
proxy = self.simulation_controller.proxy_model
view = self.simulation_controller.view.table_view
obs = label

x_axis_col = "time"
y_axis_col = "measurement"
y_axis_col = data_type
observable_col = "observableId"

def column_index(name):
for col in range(meas_proxy.columnCount()):
for col in range(proxy.columnCount()):
if (
meas_proxy.headerData(col, Qt.Horizontal)
proxy.headerData(col, Qt.Horizontal)
== name
):
return col
Expand All @@ -893,16 +903,16 @@ def column_index(name):
y_col = column_index(y_axis_col)
obs_col = column_index(observable_col)

for row in range(meas_proxy.rowCount()):
row_obs = meas_proxy.index(row, obs_col).data()
row_x = meas_proxy.index(row, x_col).data()
row_y = meas_proxy.index(row, y_col).data()
for row in range(proxy.rowCount()):
row_obs = proxy.index(row, obs_col).data()
row_x = proxy.index(row, x_col).data()
row_y = proxy.index(row, y_col).data()
try:
row_x, row_y = float(row_x), float(row_y)
except ValueError:
continue
if row_obs == obs and row_x == x and row_y == y:
self.measurement_controller.view.table_view.selectRow(row)
view.selectRow(row)
break

def _on_table_selection_changed(self, selected, deselected):
Expand All @@ -919,3 +929,7 @@ def _on_simulation_selection_changed(self, selected, deselected):
proxy=self.simulation_controller.proxy_model,
y_axis_col="simulation"
)

def _schedule_plot_update(self):
"""Start the plot schedule timer."""
self._plot_update_timer.start()
102 changes: 88 additions & 14 deletions src/petab_gui/views/simple_plot_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from matplotlib import pyplot as plt
from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.backends.backend_qtagg import NavigationToolbar2QT
from matplotlib.container import ErrorbarContainer
from PySide6.QtCore import QObject, QRunnable, Qt, QThreadPool, QTimer, Signal
from PySide6.QtGui import QAction
from PySide6.QtWidgets import (
QDockWidget,
QMenu,
Expand All @@ -22,12 +24,13 @@ class PlotWorkerSignals(QObject):


class PlotWorker(QRunnable):
def __init__(self, vis_df, cond_df, meas_df, sim_df):
def __init__(self, vis_df, cond_df, meas_df, sim_df, group_by):
super().__init__()
self.vis_df = vis_df
self.cond_df = cond_df
self.meas_df = meas_df
self.sim_df = sim_df
self.group_by = group_by
self.signals = PlotWorkerSignals()

def run(self):
Expand Down Expand Up @@ -61,6 +64,7 @@ def run(self):
self.cond_df,
measurements_df=self.meas_df,
simulations_df=sim_df,
group_by=self.group_by,
)
fig = plt.gcf()
fig.subplots_adjust(left=0.12, bottom=0.15, right=0.95, top=0.9, wspace=0.3, hspace=0.4)
Expand All @@ -77,6 +81,7 @@ class MeasurementPlotter(QDockWidget):
def __init__(self, parent=None):
super().__init__("Measurement Plot", parent)
self.setObjectName("plot_dock")
self.options_manager = ToolbarOptionManager()

self.meas_proxy = None
self.sim_proxy = None
Expand All @@ -102,6 +107,7 @@ def initialize(self, meas_proxy, sim_proxy, cond_proxy):
self.vis_df = None

# Connect data changes
self.options_manager.option_changed.connect(self._debounced_plot)
self.meas_proxy.dataChanged.connect(self._debounced_plot)
self.meas_proxy.rowsInserted.connect(self._debounced_plot)
self.meas_proxy.rowsRemoved.connect(self._debounced_plot)
Expand All @@ -121,16 +127,26 @@ def plot_it(self):
measurements_df = proxy_to_dataframe(self.meas_proxy)
simulations_df = proxy_to_dataframe(self.sim_proxy)
conditions_df = proxy_to_dataframe(self.cond_proxy)
group_by = self.options_manager.get_option()
# group_by different value in petab.visualize
if group_by == "condition":
group_by = "simulation"

worker = PlotWorker(
self.vis_df, conditions_df, measurements_df, simulations_df
self.vis_df,
conditions_df,
measurements_df,
simulations_df,
group_by
)
worker.signals.finished.connect(self._update_tabs)
QThreadPool.globalInstance().start(worker)

def _update_tabs(self, fig: plt.Figure):
# Clean previous tabs
self.tab_widget.clear()
# Clear Highlighter
self.highlighter.clear_highlight()
if fig is None:
# Fallback: show one empty plot tab
empty_fig, _ = plt.subplots()
Expand Down Expand Up @@ -164,11 +180,18 @@ def _update_tabs(self, fig: plt.Figure):
for idx, ax in enumerate(fig.axes):
# Create a new figure and copy Axes content
sub_fig, sub_ax = plt.subplots(constrained_layout=True)
for line in ax.get_lines():
handles, labels = ax.get_legend_handles_labels()
for handle, label in zip(handles, labels, strict=False):
if isinstance(handle, ErrorbarContainer):
line = handle.lines[0]
elif isinstance(handle, plt.Line2D):
line = handle
else:
continue
sub_ax.plot(
line.get_xdata(),
line.get_ydata(),
label=line.get_label(),
label=label,
linestyle=line.get_linestyle(),
marker=line.get_marker(),
color=line.get_color(),
Expand All @@ -178,9 +201,7 @@ def _update_tabs(self, fig: plt.Figure):
sub_ax.set_title(ax.get_title())
sub_ax.set_xlabel(ax.get_xlabel())
sub_ax.set_ylabel(ax.get_ylabel())
handles, labels = ax.get_legend_handles_labels()
if handles:
sub_ax.legend(handles=handles, labels=labels, loc="best")
sub_ax.legend()

sub_canvas = FigureCanvas(sub_fig)
sub_toolbar = CustomNavigationToolbar(sub_canvas, self)
Expand All @@ -202,20 +223,18 @@ def _update_tabs(self, fig: plt.Figure):
obs_id = f"subplot_{idx}"

self.observable_to_subplot[obs_id] = idx
# Also register the original ax from the full figure (main tab)
self.highlighter.register_subplot(ax, idx)
# Register subplot canvas
self.highlighter.register_subplot(sub_ax, idx)
# Also register the original ax from the full figure (main tab)
self.highlighter.connect_picking(sub_canvas)

def highlight_from_selection(self, selected_rows: list[int], proxy=None, y_axis_col="measurement"):
proxy = proxy or self.meas_proxy
if not proxy:
return

# x_axis_col = self.x_axis_selector.currentText()
x_axis_col = "time"
y_axis_col = "measurement" if proxy == self.meas_proxy else "simulation"
observable_col = "observableId"

def column_index(name):
Expand Down Expand Up @@ -260,6 +279,9 @@ def __init__(self):
self.point_index_map = {} # (subplot index, observableId, x, y) → row index
self.click_callback = None

def clear_highlight(self):
self.highlight_scatters = defaultdict(list)

def register_subplot(self, ax, subplot_idx):
scatter = ax.scatter(
[], [], s=80, edgecolors='black', facecolors='none', zorder=5
Expand Down Expand Up @@ -293,24 +315,76 @@ def _on_pick(self, event):
ax = artist.axes

# Try to recover the label from the legend (handle → label mapping)
label = ax.get_legend().texts[1].get_text().split()[-1]
handles, labels = ax.get_legend_handles_labels()
label = None
for h, l in zip(handles, labels, strict=False):
if h is artist:
label_parts = l.split()
if label_parts[-1] == "simulation":
data_type = "simulation"
label = label_parts[-2]
else:
data_type = "measurement"
label = label_parts[-1]
break

for i in ind:
x = xdata[i]
y = ydata[i]
self.click_callback(x, y, label)
self.click_callback(x, y, label, data_type)


class ToolbarOptionManager(QObject):
"""A Manager, synchronizing the selected option across all toolbars."""

option_changed = Signal(str)
_instance = None
_initialized = False

def __new__(cls):
if cls._instance is None:
cls._instance = super(ToolbarOptionManager, cls).__new__(cls)
return cls._instance

def __init__(self):
# Ensure QObject.__init__ runs only once
if not self._initialized:
super().__init__()
self._selected_option = "observable"
ToolbarOptionManager._initialized = True

def set_option(self, option):
if option != self._selected_option:
self._selected_option = option
self.option_changed.emit(option)

def get_option(self):
return self._selected_option


class CustomNavigationToolbar(NavigationToolbar2QT):
def __init__(self, canvas, parent):
super().__init__(canvas, parent)
self.manager = ToolbarOptionManager()

self.settings_btn = QToolButton(self)
self.settings_btn.setIcon(qta.icon("mdi6.cog-outline"))
self.settings_btn.setPopupMode(QToolButton.InstantPopup)
self.settings_menu = QMenu(self.settings_btn)
self.settings_menu.addAction("Option 1")
self.settings_menu.addAction("Option 2")
self.groupy_by_options = {
grp: QAction(f"Groupy by {grp}", self)
for grp in ["observable", "dataset", "condition"]
}
for grp, action in self.groupy_by_options.items():
Comment on lines +374 to +378
Copy link

Copilot AI May 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The UI label text 'Groupy by' appears to be a typo. Consider changing it to 'Group by' for clarity.

Suggested change
self.groupy_by_options = {
grp: QAction(f"Groupy by {grp}", self)
for grp in ["observable", "dataset", "condition"]
}
for grp, action in self.groupy_by_options.items():
self.group_by_options = {
grp: QAction(f"Group by {grp}", self)
for grp in ["observable", "dataset", "condition"]
}
for grp, action in self.group_by_options.items():

Copilot uses AI. Check for mistakes.
action.setCheckable(True)
action.triggered.connect(lambda _, grp=grp: self.manager.set_option(grp))
self.settings_menu.addAction(action)
self.manager.option_changed.connect(self.update_checked_state)
self.update_checked_state(self.manager.get_option())
self.settings_btn.setMenu(self.settings_menu)

self.addWidget(self.settings_btn)

def update_checked_state(self, selected_option):
for action in self.groupy_by_options.values():
action.setChecked(action.text() == f"Groupy by {selected_option}")