diff --git a/pyforestscan/visualize.py b/pyforestscan/visualize.py index fd7b6c2..89b2670 100644 --- a/pyforestscan/visualize.py +++ b/pyforestscan/visualize.py @@ -8,7 +8,7 @@ def plot_2d(points, x_dim='X', alpha=1.0, point_size=1, fig_size=None, fig_title=None, slice_dim=None, slice_val=0.0, - slice_tolerance=5 + slice_tolerance=5, save_fname=None ): """ Plots a 2D scatter plot of data points with customizable axes, coloring, and display settings. @@ -29,6 +29,7 @@ def plot_2d(points, x_dim='X', slice_val (float): The coordinate value at which to slice the slice_dim dimension. slice_tolerance (float): Allowed absolute difference when matching slice_val. (Helps with floating-point comparisons.) + save_fname (str): If provided, will be forwarded to `plt.savefig` to save the figure. Returns: None @@ -79,10 +80,13 @@ def plot_2d(points, x_dim='X', plt.ylabel(y_dim) plt.title(fig_title) plt.colorbar(label=colorbar_label) + if save_fname is not None: + plt.savefig(save_fname, dpi=300, bbox_inches='tight') plt.show() -def plot_metric(title, metric, extent, metric_name=None, cmap='viridis', fig_size=None): +def plot_metric(title, metric, extent, metric_name=None, cmap='viridis', fig_size=None, + save_fname=None): """ Plots a given metric using the provided data and configuration. @@ -100,6 +104,8 @@ def plot_metric(title, metric, extent, metric_name=None, cmap='viridis', fig_siz Colormap to be used for the plot. Default is 'viridis'. :param fig_size: tuple, optional Tuple specifying the size of the figure (width, height). Default is calculated based on the extent. + :param save_fname: str, optional + If provided, will be forwarded to `plt.savefig` to save the figure. :return: None :rtype: None """ @@ -123,11 +129,13 @@ def plot_metric(title, metric, extent, metric_name=None, cmap='viridis', fig_siz plt.title(title) plt.xlabel('X') plt.ylabel('Y') + if save_fname is not None: + plt.savefig(save_fname, dpi=300, bbox_inches='tight') plt.show() def plot_pad(pad, slice_index=None, axis='x', cmap='viridis', - hag_values=None, horizontal_values=None, title=None): + hag_values=None, horizontal_values=None, title=None, save_fname=None): """ Plots the plant area density (PAD) data as a 2D image visualization. @@ -153,6 +161,8 @@ def plot_pad(pad, slice_index=None, axis='x', cmap='viridis', title: Optional; A string specifying the title of the plot. If None, an appropriate default title will be generated based on the input parameters. + save_fname: Optional; A string that will be forwarded to `plt.savefig` to save + the figure. Returns: None. The function visualizes the PAD data using matplotlib. @@ -227,4 +237,6 @@ def plot_pad(pad, slice_index=None, axis='x', cmap='viridis', plt.xlabel(horizontal_axis_label) plt.ylabel('dZ') plt.tight_layout() + if save_fname is not None: + plt.savefig(save_fname, dpi=300, bbox_inches='tight') plt.show() \ No newline at end of file