diff --git a/src/petab_gui/C.py b/src/petab_gui/C.py index d566728..7a727c3 100644 --- a/src/petab_gui/C.py +++ b/src/petab_gui/C.py @@ -14,6 +14,17 @@ "datasetId": {"type": np.object_, "optional": True}, "replicateId": {"type": np.object_, "optional": True}, }, + "simulation": { + "observableId": {"type": np.object_, "optional": False}, + "preequilibrationConditionId": {"type": np.object_, "optional": True}, + "simulationConditionId": {"type": np.object_, "optional": False}, + "time": {"type": np.float64, "optional": False}, + "simulation": {"type": np.float64, "optional": False}, + "observableParameters": {"type": np.object_, "optional": True}, + "noiseParameters": {"type": np.object_, "optional": True}, + "datasetId": {"type": np.object_, "optional": True}, + "replicateId": {"type": np.object_, "optional": True}, + }, "observable": { "observableId": {"type": np.object_, "optional": False}, "observableName": {"type": np.object_, "optional": True}, @@ -42,6 +53,25 @@ "conditionId": {"type": np.object_, "optional": False}, "conditionName": {"type": np.object_, "optional": False}, }, + "visualization": { + "plotId": {"type": np.object_, "optional": False}, + "plotName": {"type": np.object_, "optional": True}, + "plotTypeSimulation": { + "type": np.object_, + "optional": True, + }, + "plotTypeData": {"type": np.object_, "optional": True}, + "datasetId": {"type": np.object_, "optional": True}, + "xValues": {"type": np.object_, "optional": True}, + "xOffset": {"type": np.float64, "optional": True}, + "xLabel": {"type": np.object_, "optional": True}, + "xScale": {"type": np.object_, "optional": True}, + "yValues": {"type": np.object_, "optional": True}, + "yOffset": {"type": np.float64, "optional": True}, + "yLabel": {"type": np.object_, "optional": True}, + "yScale": {"type": np.object_, "optional": True}, + "legendEntry": {"type": np.object_, "optional": True}, + } } CONFIG = { diff --git a/src/petab_gui/commands.py b/src/petab_gui/commands.py index 518cf67..9860fb9 100644 --- a/src/petab_gui/commands.py +++ b/src/petab_gui/commands.py @@ -141,12 +141,19 @@ def redo(self): df = self.model._data_frame if self.add_mode: - position = df.shape[0] - 1 # insert *before* the auto-row + position = 0 if df.empty else df.shape[0] - 1 # insert *before* the auto-row self.model.beginInsertRows( QModelIndex(), position, position + len(self.row_indices) - 1 ) + # save dtypes + dtypes = df.dtypes.copy() for _i, idx in enumerate(self.row_indices): df.loc[idx] = [np.nan] * df.shape[1] + # set dtypes + if np.any(dtypes != df.dtypes): + for col, dtype in dtypes.items(): + if dtype != df.dtypes[col]: + df[col] = df[col].astype(dtype) self.model.endInsertRows() else: self.model.beginRemoveRows( diff --git a/src/petab_gui/controllers/mother_controller.py b/src/petab_gui/controllers/mother_controller.py index e18290f..75d8c39 100644 --- a/src/petab_gui/controllers/mother_controller.py +++ b/src/petab_gui/controllers/mother_controller.py @@ -22,7 +22,11 @@ from ..models import PEtabModel from ..settings_manager import SettingsDialog, settings_manager -from ..utils import CaptureLogHandler, process_file +from ..utils import ( + CaptureLogHandler, + get_selected, + process_file, +) from ..views import TaskBar from .logger_controller import LoggerController from .sbml_controller import SbmlController @@ -31,6 +35,7 @@ MeasurementController, ObservableController, ParameterController, + VisualizationController, ) from .utils import ( RecentFilesManager, @@ -91,6 +96,20 @@ def __init__(self, view, model: PEtabModel): self.undo_stack, self, ) + self.visualization_controller = VisualizationController( + self.view.visualization_dock, + self.model.visualization, + self.logger, + self.undo_stack, + self, + ) + self.simulation_controller = MeasurementController( + self.view.simulation_dock, + self.model.simulation, + self.logger, + self.undo_stack, + self, + ) self.sbml_controller = SbmlController( self.view.sbml_viewer, self.model.sbml, self.logger, self ) @@ -100,6 +119,8 @@ def __init__(self, view, model: PEtabModel): self.parameter_controller, self.condition_controller, self.sbml_controller, + self.visualization_controller, + self.simulation_controller, ] # Recent Files self.recent_files_manager = RecentFilesManager(max_files=10) @@ -109,6 +130,8 @@ def __init__(self, view, model: PEtabModel): "observable": False, "parameter": False, "condition": False, + "visualization": False, + "simulation": False, } self.sbml_checkbox_states = {"sbml": False, "antimony": False} self.unsaved_changes = False @@ -120,13 +143,15 @@ def __init__(self, view, model: PEtabModel): self.setup_connections() self.setup_task_bar() self.setup_context_menu() + self.plotter = None + self.init_plotter() def setup_context_menu(self): """Sets up context menus for the tables.""" - self.measurement_controller.setup_context_menu(self.actions) - self.observable_controller.setup_context_menu(self.actions) - self.parameter_controller.setup_context_menu(self.actions) - self.condition_controller.setup_context_menu(self.actions) + for controller in self.controllers: + if controller == self.sbml_controller: + continue + controller.setup_context_menu(self.actions) def setup_task_bar(self): """Create shortcuts for the main window.""" @@ -169,9 +194,11 @@ def setup_connections(self): ) # Maybe Move to a Plot Model self.view.measurement_dock.table_view.selectionModel().selectionChanged.connect( - self.handle_selection_changed + self._on_table_selection_changed + ) + self.view.simulation_dock.table_view.selectionModel().selectionChanged.connect( + self._on_simulation_selection_changed ) - self.model.measurement.dataChanged.connect(self.handle_data_changed) # Unsaved Changes self.model.measurement.something_changed.connect( self.unsaved_changes_change @@ -185,6 +212,12 @@ def setup_connections(self): self.model.condition.something_changed.connect( self.unsaved_changes_change ) + self.model.visualization.something_changed.connect( + self.unsaved_changes_change + ) + self.model.simulation.something_changed.connect( + self.unsaved_changes_change + ) self.model.sbml.something_changed.connect(self.unsaved_changes_change) # Visibility self.sync_visibility_with_actions() @@ -198,6 +231,14 @@ def setup_connections(self): self.sbml_controller.overwritten_model.connect( self.parameter_controller.update_handler_sbml ) + # overwrite signals + for controller in [ + # self.measurement_controller, + self.condition_controller + ]: + controller.overwritten_df.connect( + self.init_plotter + ) def setup_actions(self): """Setup actions for the main controller.""" @@ -301,8 +342,9 @@ def setup_actions(self): self.filter_input.setPlaceholderText("Filter...") filter_layout.addWidget(self.filter_input) for table_n, table_name in zip( - ["m", "p", "o", "c"], - ["measurement", "parameter", "observable", "condition"], + ["m", "p", "o", "c", "v", "s"], + ["measurement", "parameter", "observable", "condition", + "visualization", "simulation"], strict=False, ): tool_button = QToolButton() @@ -325,7 +367,8 @@ def setup_actions(self): self.filter_input.textChanged.connect(self.filter_table) # show/hide elements - for element in ["measurement", "observable", "parameter", "condition"]: + for element in ["measurement", "observable", "parameter", + "condition", "visualization", "simulation"]: actions[f"show_{element}"] = QAction( f"{element.capitalize()} Table", self.view ) @@ -396,6 +439,8 @@ def sync_visibility_with_actions(self): "condition": self.view.condition_dock, "logger": self.view.logger_dock, "plot": self.view.plot_dock, + "visualization": self.view.visualization_dock, + "simulation": self.view.simulation_dock, } for key, dock in dock_map.items(): @@ -558,6 +603,10 @@ def _open_file(self, actionable, file_path, sep, mode): self.parameter_controller.open_table(file_path, sep, mode) elif actionable == "condition": self.condition_controller.open_table(file_path, sep, mode) + elif actionable == "visualization": + self.visualization_controller.open_table(file_path, sep, mode) + elif actionable == "simulation": + self.simulation_controller.open_table(file_path, sep, mode) elif actionable == "data_matrix": self.measurement_controller.process_data_matrix_file( file_path, mode, sep @@ -604,6 +653,14 @@ def open_yaml_and_load_files(self, yaml_path=None, mode="overwrite"): self.condition_controller.open_table( yaml_dir / yaml_content["problems"][0]["condition_files"][0] ) + # Visualization is optional + vis_path = yaml_content["problems"][0].get("visualization_files") + if vis_path: + self.visualization_controller.open_table( + yaml_dir / vis_path[0] + ) + else: + self.visualization_controller.clear_table() self.logger.log_message( "All files opened successfully from the YAML configuration.", color="green", @@ -721,6 +778,10 @@ def active_controller(self): return self.parameter_controller if active_widget == self.view.condition_dock.table_view: return self.condition_controller + if active_widget == self.view.visualization_dock.table_view: + return self.visualization_controller + if active_widget == self.view.simulation_dock.table_view: + return self.simulation_controller return None def delete_rows(self): @@ -799,3 +860,62 @@ def replace(self): if self.view.find_replace_bar is None: self.view.create_find_replace_bar() self.view.toggle_replace() + + def init_plotter(self): + """(Re-)initialize the plotter.""" + self.view.plot_dock.initialize( + self.measurement_controller.proxy_model, + self.simulation_controller.proxy_model, + self.condition_controller.proxy_model, + ) + self.plotter = self.view.plot_dock + self.plotter.highlighter.click_callback = self._on_plot_point_clicked + + def _on_plot_point_clicked(self, x, y, label): + # Extract observable ID from label, if formatted like 'obsId (label)' + meas_proxy = self.measurement_controller.proxy_model + obs = label + + x_axis_col = "time" + y_axis_col = "measurement" + observable_col = "observableId" + + def column_index(name): + for col in range(meas_proxy.columnCount()): + if ( + meas_proxy.headerData(col, Qt.Horizontal) + == name + ): + return col + raise ValueError(f"Column '{name}' not found.") + + x_col = column_index(x_axis_col) + y_col = column_index(y_axis_col) + obs_col = column_index(observable_col) + + for row in range(meas_proxy.rowCount()): + row_obs = meas_proxy.index(row, obs_col).data() + row_x = meas_proxy.index(row, x_col).data() + row_y = meas_proxy.index(row, y_col).data() + try: + row_x, row_y = float(row_x), float(row_y) + except ValueError: + continue + if row_obs == obs and row_x == x and row_y == y: + self.measurement_controller.view.table_view.selectRow(row) + break + + def _on_table_selection_changed(self, selected, deselected): + """Highlight the cells selected in measurement table.""" + selected_rows = get_selected( + self.measurement_controller.view.table_view + ) + self.plotter.highlight_from_selection(selected_rows) + + def _on_simulation_selection_changed(self, selected, deselected): + selected_rows = get_selected(self.simulation_controller.view.table_view) + self.plotter.highlight_from_selection( + selected_rows, + proxy=self.simulation_controller.proxy_model, + y_axis_col="simulation" + ) diff --git a/src/petab_gui/controllers/table_controllers.py b/src/petab_gui/controllers/table_controllers.py index 4156a86..100e51a 100644 --- a/src/petab_gui/controllers/table_controllers.py +++ b/src/petab_gui/controllers/table_controllers.py @@ -68,6 +68,8 @@ def __init__( self.undo_stack = undo_stack self.model.undo_stack = undo_stack self.check_petab_lint_mode = True + if model.table_type in ["simulation", "visualization"]: + self.check_petab_lint_mode = False self.mother_controller = mother_controller self.view.table_view.setModel(self.proxy_model) self.setup_connections() @@ -149,7 +151,9 @@ def open_table(self, file_path=None, separator=None, mode="overwrite"): if actionable in ["yaml", "sbml", "data_matrix", None]: # no table return try: - if self.model.table_type == "measurement": + if self.model.table_type in [ + "measurement", "visualization", "simulation" + ]: new_df = pd.read_csv(file_path, sep=separator) else: new_df = pd.read_csv(file_path, sep=separator, index_col=0) @@ -175,8 +179,6 @@ def open_table(self, file_path=None, separator=None, mode="overwrite"): self.model.reset_invalid_cells() def overwrite_df(self, new_df: pd.DataFrame): - # TODO: Mother controller connects to overwritten_df signal. Set df - # in petabProblem and unsaved changes to True """Overwrite the DataFrame of the model with the data from the view.""" self.proxy_model.setSourceModel(None) self.model.beginResetModel() @@ -318,9 +320,10 @@ def copy_to_clipboard(self): def paste_from_clipboard(self): """Paste the clipboard content to the currently selected cells.""" + old_lint = self.check_petab_lint_mode self.check_petab_lint_mode = False self.view.paste_from_clipboard() - self.check_petab_lint_mode = True + self.check_petab_lint_mode = old_lint try: self.check_petab_lint() except Exception as e: @@ -1134,3 +1137,27 @@ def check_petab_lint( condition_df=condition_df, model=sbml_model, ) + + +class VisualizationController(TableController): + """Controller of the Visualization table.""" + + def __init__( + self, + view: TableViewer, + model: PandasTableModel, + logger, + undo_stack, + mother_controller, + ): + """Initialize the table controller. + + See class:`TableController` for details. + """ + super().__init__( + view=view, + model=model, + logger=logger, + undo_stack=undo_stack, + mother_controller=mother_controller + ) diff --git a/src/petab_gui/models/__init__.py b/src/petab_gui/models/__init__.py index be23554..5c6af4a 100644 --- a/src/petab_gui/models/__init__.py +++ b/src/petab_gui/models/__init__.py @@ -7,10 +7,13 @@ from .pandas_table_model import ( ConditionModel, + IndexedPandasTableModel, MeasurementModel, ObservableModel, + PandasTableFilterProxy, PandasTableModel, ParameterModel, + VisualizationModel, ) from .petab_model import PEtabModel from .sbml_model import SbmlViewerModel diff --git a/src/petab_gui/models/pandas_table_model.py b/src/petab_gui/models/pandas_table_model.py index 4ff68e7..af88f07 100644 --- a/src/petab_gui/models/pandas_table_model.py +++ b/src/petab_gui/models/pandas_table_model.py @@ -138,10 +138,10 @@ def data(self, index, role=Qt.DisplayRole): if column == 0: return f"New {self.table_type}" return "" - if column == 0: + if column == 0 and self._has_named_index: value = self._data_frame.index[row] return str(value) - value = self._data_frame.iloc[row, column - 1] + value = self._data_frame.iloc[row, column - self.column_offset] if is_invalid(value): return "" return str(value) @@ -188,9 +188,9 @@ def headerData(self, section, orientation, role=Qt.DisplayRole): if role != Qt.DisplayRole: return None if orientation == Qt.Horizontal: - if section == 0: + if section == 0 and self._has_named_index: return self._data_frame.index.name - return self._data_frame.columns[section - 1] + return self._data_frame.columns[section - self.column_offset] if orientation == Qt.Vertical: return str(section) return None @@ -1054,11 +1054,12 @@ class MeasurementModel(PandasTableModel): possibly_new_condition = Signal(str) # Signal for new condition possibly_new_observable = Signal(str) # Signal for new observable - def __init__(self, data_frame, parent=None): + def __init__(self, data_frame, type: str = "measurement", parent=None): + allowed_columns = COLUMNS[type] super().__init__( data_frame=data_frame, - allowed_columns=COLUMNS["measurement"], - table_type="measurement", + allowed_columns=allowed_columns, + table_type=type, parent=parent, ) @@ -1080,39 +1081,6 @@ def get_default_values(self, index, changed: dict | None = None): else: command.redo() - def data(self, index, role=Qt.DisplayRole): - """Return the data at the given index and role for the View.""" - if not index.isValid(): - return None - row, column = index.row(), index.column() - if role == Qt.DisplayRole or role == Qt.EditRole: - if row == self._data_frame.shape[0]: - if column == 0: - return f"New {self.table_type}" - return "" - value = self._data_frame.iloc[row, column] - if is_invalid(value): - return "" - return str(value) - if role == Qt.BackgroundRole: - return self.determine_background_color(row, column) - if role == Qt.ForegroundRole: - # Return yellow text if this cell is a match - if (row, column) in self.highlighted_cells: - return QApplication.palette().color(QPalette.HighlightedText) - return QBrush(QColor(0, 0, 0)) # Default black text - return None - - def headerData(self, section, orientation, role=Qt.DisplayRole): - """Return the header data for the given section, orientation.""" - if role != Qt.DisplayRole: - return None - if orientation == Qt.Horizontal: - return self._data_frame.columns[section] - if orientation == Qt.Vertical: - return str(section) - return None - def return_column_index(self, column_name): """Return the index of a column.""" if column_name in self._data_frame.columns: @@ -1196,3 +1164,15 @@ def setDataFromText(self, text, start_row, start_column): @property def _invalid_cells(self): return self.source_model._invalid_cells + + +class VisualizationModel(PandasTableModel): + """Table model for the visualization data.""" + + def __init__(self, data_frame, parent=None): + super().__init__( + data_frame=data_frame, + allowed_columns=COLUMNS["visualization"], + table_type="visualization", + parent=parent, + ) diff --git a/src/petab_gui/models/petab_model.py b/src/petab_gui/models/petab_model.py index 0f5f857..7944b15 100644 --- a/src/petab_gui/models/petab_model.py +++ b/src/petab_gui/models/petab_model.py @@ -11,6 +11,7 @@ MeasurementModel, ObservableModel, ParameterModel, + VisualizationModel, ) from .sbml_model import SbmlViewerModel @@ -59,6 +60,11 @@ def __init__( ) self.measurement = MeasurementModel( data_frame=self.problem.measurement_df, + type="measurement", + ) + self.simulation = MeasurementModel( + data_frame=None, + type="simulation", ) self.observable = ObservableModel( data_frame=self.problem.observable_df, @@ -69,6 +75,9 @@ def __init__( self.condition = ConditionModel( data_frame=self.problem.condition_df, ) + self.visualization = VisualizationModel( + data_frame=self.problem.visualization_df, + ) @property def models(self): @@ -142,5 +151,6 @@ def current_petab_problem(self) -> petab.Problem: measurement_df=self.measurement.get_df(), observable_df=self.observable.get_df(), parameter_df=self.parameter.get_df(), + visualization_df=self.visualization.get_df(), model=self.sbml.get_current_sbml_model(), ) diff --git a/src/petab_gui/utils.py b/src/petab_gui/utils.py index bcd9188..07ce5da 100644 --- a/src/petab_gui/utils.py +++ b/src/petab_gui/utils.py @@ -410,12 +410,16 @@ def process_file(filepath, logger): # Case 3.2: Identify the table type based on header content if {"observableId", "measurement", "time"}.issubset(header): return "measurement", separator + if {"observableId", "simulation", "time"}.issubset(header): + return "simulation", separator if {"observableId", "observableFormula"}.issubset(header): return "observable", separator if "parameterId" in header: return "parameter", separator if "conditionId" in header or "\ufeffconditionId" in header: return "condition", separator + if "plotId" in header: + return "visualization", separator logger.log_message( f"Unrecognized table type for file: {filepath}. Uploading as " f"data matrix.", diff --git a/src/petab_gui/views/main_view.py b/src/petab_gui/views/main_view.py index ecfdf7a..8c5da4a 100644 --- a/src/petab_gui/views/main_view.py +++ b/src/petab_gui/views/main_view.py @@ -15,8 +15,8 @@ from ..settings_manager import settings_manager from .find_replace_bar import FindReplaceBar from .logger import Logger -from .measurement_plot import MeasuremenPlotter from .sbml_view import SbmlViewer +from .simple_plot_view import MeasurementPlotter from .table_view import TableViewer @@ -54,7 +54,9 @@ def __init__(self): self.logger_dock = QDockWidget("Info") self.logger_dock.setObjectName("logger_dock") self.logger_dock.setWidget(self.logger_views[1]) - self.plot_dock = MeasuremenPlotter(self) + self.plot_dock = MeasurementPlotter(self) + self.visualization_dock = TableViewer("Visualization Table") + self.simulation_dock = TableViewer("Simulation Table") self.dock_visibility = { self.condition_dock: self.condition_dock.isVisible(), @@ -63,6 +65,8 @@ def __init__(self): self.parameter_dock: self.parameter_dock.isVisible(), self.logger_dock: self.logger_dock.isVisible(), self.plot_dock: self.plot_dock.isVisible(), + self.visualization_dock: self.visualization_dock.isVisible(), + self.simulation_dock: self.simulation_dock.isVisible(), } self.default_view() self.condition_dock.visibilityChanged.connect( @@ -79,6 +83,12 @@ def __init__(self): ) self.logger_dock.visibilityChanged.connect(self.save_dock_visibility) self.plot_dock.visibilityChanged.connect(self.save_dock_visibility) + self.visualization_dock.visibilityChanged.connect( + self.save_dock_visibility + ) + self.simulation_dock.visibilityChanged.connect( + self.save_dock_visibility + ) # Allow docking in multiple areas self.data_tab.setDockOptions(QMainWindow.AllowNestedDocks) @@ -105,19 +115,21 @@ def default_view(self): # Get available geometry available_rect = self.data_tab.contentsRect() width = available_rect.width() // 2 - height = available_rect.height() // 3 + height = available_rect.height() // 4 x_left = available_rect.left() x_right = x_left + width - y_positions = [available_rect.top() + i * height for i in range(3)] + y_positions = [available_rect.top() + i * height for i in range(4)] # Define dock + positions layout = [ (self.measurement_dock, x_left, y_positions[0]), (self.parameter_dock, x_left, y_positions[1]), (self.logger_dock, x_left, y_positions[2]), + (self.visualization_dock, x_left, y_positions[3]), (self.observable_dock, x_right, self.measurement_dock), (self.condition_dock, x_right, self.parameter_dock), (self.plot_dock, x_right, self.logger_dock), + (self.simulation_dock, x_right, self.visualization_dock), ] for dock, x, y in layout: diff --git a/src/petab_gui/views/simple_plot_view.py b/src/petab_gui/views/simple_plot_view.py index e69de29..c755e15 100644 --- a/src/petab_gui/views/simple_plot_view.py +++ b/src/petab_gui/views/simple_plot_view.py @@ -0,0 +1,316 @@ +from collections import defaultdict + +import qtawesome as qta +from matplotlib import pyplot as plt +from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as FigureCanvas +from matplotlib.backends.backend_qtagg import NavigationToolbar2QT +from PySide6.QtCore import QObject, QRunnable, Qt, QThreadPool, QTimer, Signal +from PySide6.QtWidgets import ( + QDockWidget, + QMenu, + QTabWidget, + QToolButton, + QVBoxLayout, + QWidget, +) + +from .utils import proxy_to_dataframe + + +class PlotWorkerSignals(QObject): + finished = Signal(object) # Emits final Figure + + +class PlotWorker(QRunnable): + def __init__(self, vis_df, cond_df, meas_df, sim_df): + super().__init__() + self.vis_df = vis_df + self.cond_df = cond_df + self.meas_df = meas_df + self.sim_df = sim_df + self.signals = PlotWorkerSignals() + + def run(self): + import petab.v1.visualize as petab_vis # Ensure this is thread-local + plt.close("all") + + if self.meas_df.empty or self.cond_df.empty: + self.signals.finished.emit(None) + return + sim_df = self.sim_df.copy() + if sim_df.empty: + sim_df = None + + try: + if self.vis_df is not None: + petab_vis.plot_with_vis_spec( + self.vis_df, + self.cond_df, + self.meas_df, + sim_df, + ) + fig = plt.gcf() + self.signals.finished.emit(fig) + return + except Exception as e: + print(f"Invalid Visualisation DF: {e}") + + # Fallback + plt.close("all") + petab_vis.plot_without_vis_spec( + self.cond_df, + measurements_df=self.meas_df, + simulations_df=sim_df, + ) + fig = plt.gcf() + fig.subplots_adjust(left=0.12, bottom=0.15, right=0.95, top=0.9, wspace=0.3, hspace=0.4) + self.signals.finished.emit(fig) + + +class PlotWidget(FigureCanvas): + def __init__(self): + self.fig, self.axes = plt.subplots() + super().__init__(self.fig) + + +class MeasurementPlotter(QDockWidget): + def __init__(self, parent=None): + super().__init__("Measurement Plot", parent) + self.setObjectName("plot_dock") + + self.meas_proxy = None + self.sim_proxy = None + self.cond_proxy = None + self.highlighter = MeasurementHighlighter() + + self.dock_widget = QWidget(self) + self.layout = QVBoxLayout(self.dock_widget) + self.layout.setContentsMargins(0, 0, 0, 0) + self.layout.setSpacing(2) + self.setWidget(self.dock_widget) + self.tab_widget = QTabWidget() + self.layout.addWidget(self.tab_widget) + self.update_timer = QTimer(self) + self.update_timer.setSingleShot(True) + self.update_timer.timeout.connect(self.plot_it) + self.observable_to_subplot = {} + + def initialize(self, meas_proxy, sim_proxy, cond_proxy): + self.meas_proxy = meas_proxy + self.cond_proxy = cond_proxy + self.sim_proxy = sim_proxy + self.vis_df = None + + # Connect data changes + self.meas_proxy.dataChanged.connect(self._debounced_plot) + self.meas_proxy.rowsInserted.connect(self._debounced_plot) + self.meas_proxy.rowsRemoved.connect(self._debounced_plot) + self.cond_proxy.dataChanged.connect(self._debounced_plot) + self.cond_proxy.rowsInserted.connect(self._debounced_plot) + self.cond_proxy.rowsRemoved.connect(self._debounced_plot) + self.sim_proxy.dataChanged.connect(self._debounced_plot) + self.sim_proxy.rowsInserted.connect(self._debounced_plot) + self.sim_proxy.rowsRemoved.connect(self._debounced_plot) + + self.plot_it() + + def plot_it(self): + if not self.meas_proxy or not self.cond_proxy: + return + + measurements_df = proxy_to_dataframe(self.meas_proxy) + simulations_df = proxy_to_dataframe(self.sim_proxy) + conditions_df = proxy_to_dataframe(self.cond_proxy) + + worker = PlotWorker( + self.vis_df, conditions_df, measurements_df, simulations_df + ) + worker.signals.finished.connect(self._update_tabs) + QThreadPool.globalInstance().start(worker) + + def _update_tabs(self, fig: plt.Figure): + # Clean previous tabs + self.tab_widget.clear() + if fig is None: + # Fallback: show one empty plot tab + empty_fig, _ = plt.subplots() + empty_canvas = FigureCanvas(empty_fig) + empty_toolbar = CustomNavigationToolbar(empty_canvas, self) + + tab = QWidget() + layout = QVBoxLayout(tab) + layout.setContentsMargins(0, 0, 0, 0) + layout.setSpacing(2) + layout.addWidget(empty_toolbar) + layout.addWidget(empty_canvas) + + self.tab_widget.addTab(tab, "All Plots") + return + + # Full figure tab + full_canvas = FigureCanvas(fig) + full_toolbar = CustomNavigationToolbar(full_canvas, self) + + full_tab = QWidget() + full_layout = QVBoxLayout(full_tab) + full_layout.setContentsMargins(0, 0, 0, 0) + full_layout.setSpacing(2) + full_layout.addWidget(full_toolbar) + full_layout.addWidget(full_canvas) + + self.tab_widget.addTab(full_tab, "All Plots") + + # One tab per Axes + for idx, ax in enumerate(fig.axes): + # Create a new figure and copy Axes content + sub_fig, sub_ax = plt.subplots(constrained_layout=True) + for line in ax.get_lines(): + sub_ax.plot( + line.get_xdata(), + line.get_ydata(), + label=line.get_label(), + linestyle=line.get_linestyle(), + marker=line.get_marker(), + color=line.get_color(), + alpha=line.get_alpha(), + picker=True, + ) + sub_ax.set_title(ax.get_title()) + sub_ax.set_xlabel(ax.get_xlabel()) + sub_ax.set_ylabel(ax.get_ylabel()) + handles, labels = ax.get_legend_handles_labels() + if handles: + sub_ax.legend(handles=handles, labels=labels, loc="best") + + sub_canvas = FigureCanvas(sub_fig) + sub_toolbar = CustomNavigationToolbar(sub_canvas, self) + + sub_tab = QWidget() + sub_layout = QVBoxLayout(sub_tab) + sub_layout.setContentsMargins(0, 0, 0, 0) + sub_layout.setSpacing(2) + sub_layout.addWidget(sub_toolbar) + sub_layout.addWidget(sub_canvas) + + self.tab_widget.addTab(sub_tab, f"Subplot {idx + 1}") + if ax.get_title(): + obs_id = ax.get_title() + elif ax.get_legend_handles_labels()[1]: + obs_id = ax.get_legend_handles_labels()[1][0] + obs_id = obs_id.split(" ")[-1] + else: + obs_id = f"subplot_{idx}" + + self.observable_to_subplot[obs_id] = idx + # Also register the original ax from the full figure (main tab) + self.highlighter.register_subplot(ax, idx) + # Register subplot canvas + self.highlighter.register_subplot(sub_ax, idx) + self.highlighter.connect_picking(sub_canvas) + + def highlight_from_selection(self, selected_rows: list[int], proxy=None, y_axis_col="measurement"): + proxy = proxy or self.meas_proxy + if not proxy: + return + + # x_axis_col = self.x_axis_selector.currentText() + x_axis_col = "time" + y_axis_col = "measurement" if proxy == self.meas_proxy else "simulation" + observable_col = "observableId" + + def column_index(name): + for col in range(proxy.columnCount()): + if proxy.headerData(col, Qt.Horizontal) == name: + return col + raise ValueError(f"Column '{name}' not found in proxy.") + + x_col = column_index(x_axis_col) + y_col = column_index(y_axis_col) + obs_col = column_index(observable_col) + + grouped_points = {} # subplot_idx → list of (x, y) + + for row in selected_rows: + x = proxy.index(row, x_col).data() + y = proxy.index(row, y_col).data() + try: + x = float(x) + y = float(y) + except ValueError: + pass + obs = proxy.index(row, obs_col).data() + subplot_idx = self.observable_to_subplot.get(obs) + if subplot_idx is not None: + grouped_points.setdefault(subplot_idx, []).append((x, y)) + + for subplot_idx, points in grouped_points.items(): + self.highlighter.update_highlight(subplot_idx, points) + + def _debounced_plot(self): + self.update_timer.start(1000) + + def update_visualization(self, plot_data): + print("OK") + return + + +class MeasurementHighlighter: + def __init__(self): + self.highlight_scatters = defaultdict(list) # (subplot index) → scatter artist + self.point_index_map = {} # (subplot index, observableId, x, y) → row index + self.click_callback = None + + def register_subplot(self, ax, subplot_idx): + scatter = ax.scatter( + [], [], s=80, edgecolors='black', facecolors='none', zorder=5 + ) + self.highlight_scatters[subplot_idx].append(scatter) + + def update_highlight(self, subplot_idx, points: list[tuple[float, float]]): + """Update highlighted points on one subplot.""" + for scatter in self.highlight_scatters.get(subplot_idx, []): + if points: + x, y = zip(*points, strict=False) + scatter.set_offsets(list(zip(x, y, strict=False))) + else: + scatter.set_offsets([]) + scatter.figure.canvas.draw_idle() + + def connect_picking(self, canvas): + canvas.mpl_connect("pick_event", self._on_pick) + + def _on_pick(self, event): + if not callable(self.click_callback): + return + + artist = event.artist + if not hasattr(artist, "get_xdata"): + return + + ind = event.ind + xdata = artist.get_xdata() + ydata = artist.get_ydata() + ax = artist.axes + + # Try to recover the label from the legend (handle → label mapping) + label = ax.get_legend().texts[1].get_text().split()[-1] + + for i in ind: + x = xdata[i] + y = ydata[i] + self.click_callback(x, y, label) + + +class CustomNavigationToolbar(NavigationToolbar2QT): + def __init__(self, canvas, parent): + super().__init__(canvas, parent) + + self.settings_btn = QToolButton(self) + self.settings_btn.setIcon(qta.icon("mdi6.cog-outline")) + self.settings_btn.setPopupMode(QToolButton.InstantPopup) + self.settings_menu = QMenu(self.settings_btn) + self.settings_menu.addAction("Option 1") + self.settings_menu.addAction("Option 2") + self.settings_btn.setMenu(self.settings_menu) + + self.addWidget(self.settings_btn) diff --git a/src/petab_gui/views/task_bar.py b/src/petab_gui/views/task_bar.py index 0137028..8d910b8 100644 --- a/src/petab_gui/views/task_bar.py +++ b/src/petab_gui/views/task_bar.py @@ -117,6 +117,8 @@ def __init__(self, parent, actions): self.menu.addAction(actions["show_condition"]) self.menu.addAction(actions["show_logger"]) self.menu.addAction(actions["show_plot"]) + self.menu.addAction(actions["show_visualization"]) + self.menu.addAction(actions["show_simulation"]) self.menu.addSeparator() self.menu.addAction(actions["reset_view"]) self.menu.addAction(actions["clear_log"]) diff --git a/src/petab_gui/views/utils.py b/src/petab_gui/views/utils.py new file mode 100644 index 0000000..66a4549 --- /dev/null +++ b/src/petab_gui/views/utils.py @@ -0,0 +1,40 @@ +import pandas as pd +from PySide6.QtCore import Qt + + +def proxy_to_dataframe(proxy_model): + rows = proxy_model.rowCount() + cols = proxy_model.columnCount() + + headers = [proxy_model.headerData(c, Qt.Horizontal) for c in range(cols)] + data = [] + + for r in range(rows-1): + row = { + headers[c]: proxy_model.index(r, c).data() + for c in range(cols) + } + for key, value in row.items(): + if isinstance(value, str) and value == "": + row[key] = None + data.append(row) + if not data: + return pd.DataFrame() + if proxy_model.source_model.table_type == "condition": + data = pd.DataFrame(data).set_index("conditionId") + elif proxy_model.source_model.table_type == "observable": + data = pd.DataFrame(data).set_index("observableId") + elif proxy_model.source_model.table_type == "parameter": + data = pd.DataFrame(data).set_index("parameterId") + elif proxy_model.source_model.table_type == "measurement": + # turn measurement and time to float + data = pd.DataFrame(data) + data["measurement"] = data["measurement"].astype(float) + data["time"] = data["time"].astype(float) + elif proxy_model.source_model.table_type == "simulation": + # turn simulation and time to float + data = pd.DataFrame(data) + data["simulation"] = data["simulation"].astype(float) + data["time"] = data["time"].astype(float) + + return data