diff --git a/spatialpy/core/__init__.py b/spatialpy/core/__init__.py index bda84e58..b8f9d1dc 100644 --- a/spatialpy/core/__init__.py +++ b/spatialpy/core/__init__.py @@ -30,6 +30,7 @@ from .result import * from .spatialpyerror import * from .species import * +from .visualization import Visualization from .vtkreader import * _formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') diff --git a/spatialpy/core/domain.py b/spatialpy/core/domain.py index b6c763de..88649a91 100644 --- a/spatialpy/core/domain.py +++ b/spatialpy/core/domain.py @@ -27,6 +27,7 @@ from plotly.offline import init_notebook_mode, iplot from scipy.spatial import KDTree +from spatialpy.core.visualization import Visualization from spatialpy.core.spatialpyerror import DomainError class Domain(): @@ -204,8 +205,8 @@ def set_properties(self, geometry_ivar, type_id, vol=None, mass=None, nu=None, r """ Add a type definition to the domain. By default, all regions are set to type 0. - :param geometry_ivar: an instance of a :py:class:`spatialpy.core.geometry.Geometry` subclass. The 'inside()' method - of this object will be used to assign properties to points. + :param geometry_ivar: an instance of a :py:class:`spatialpy.core.geometry.Geometry` subclass. \ + The 'inside()' method of this object will be used to assign properties to points. :type geometry_ivar: spatialpy.core.geometry.Geometry :param type_id: The identifier for this type. @@ -270,8 +271,8 @@ def fill_with_particles(self, geometry_ivar, deltax, deltay=None, deltaz=None, x """ Fill a geometric shape with particles. - :param geometry_ivar: an instance of a :py:class:`spatialpy.core.geometry.Geometry` subclass. The 'inside()' method - of this object will be used to create add the particles. + :param geometry_ivar: an instance of a :py:class:`spatialpy.core.geometry.Geometry` subclass. \ + The 'inside()' method of this object will be used to create add the particles. :type geometry_ivar: spatialpy.core.geometry.Geometry :param deltax: Distance between particles on the x-axis. @@ -535,7 +536,7 @@ def calculate_vol(self): self.vol[v3] += t_vol / 4 self.vol[v4] += t_vol / 4 - def plot_types(self, width=None, height=None, colormap=None, size=5, title=None, + def plot_types(self, width=None, height=None, colormap=None, size=None, title=None, included_types_list=None, use_matplotlib=False, return_plotly_figure=False): ''' Plots the domain using plotly. Can only be viewed in a Jupyter Notebook. @@ -574,13 +575,10 @@ def plot_types(self, width=None, height=None, colormap=None, size=5, title=None, ''' from spatialpy.core.result import _plotly_iterate # pylint: disable=import-outside-toplevel - if use_matplotlib: - width = 6.4 if width in (None, "auto") else width - height = 4.8 if height in (None, "auto") else height - else: + if not use_matplotlib: if width in (None, "auto"): width = None if width == "auto" else 500 - if height is None: + if height in (None, "auto"): height = None if height == "auto" else 500 if not numpy.count_nonzero(self.vertices[:, 1]): @@ -593,32 +591,67 @@ def plot_types(self, width=None, height=None, colormap=None, size=5, title=None, self._get_type_name_mapping() types = {} + # Normalize volumes to [0, 1] + vols = (self.vol - numpy.min(self.vol))/numpy.ptp(self.vol) for i, type_id in enumerate(self.type_id): name = type_id[5:] if included_types_list is None or name in included_types_list: if name in types: types[name]['points'].append(self.vertices[i]) types[name]['data'].append(self.typeNdxMapping[type_id]) + types[name]['size_scale'] = numpy.append(types[name]['size_scale'], vols[i]) else: - types[name] = {"points":[self.vertices[i]], "data":[self.typeNdxMapping[type_id]]} + types[name] = { + "points": [self.vertices[i]], + "data": [self.typeNdxMapping[type_id]], + "size_scale": numpy.array([vols[i]]) + } if use_matplotlib: - import matplotlib.pyplot as plt # pylint: disable=import-outside-toplevel - - fig, ax = plt.subplots(figsize=(width, height)) - for name, data in types.items(): - x_coords = list(map(lambda point: point[0], data["points"])) - y_coords = list(map(lambda point: point[1], data["points"])) - - ax.scatter(x_coords, y_coords, label=name) - ax.grid(linestyle='--', linewidth=1) - ax.legend(loc='upper right', fontsize=12) - if title is not None: - ax.set_title(title) - - plt.axis('scaled') + if not isinstance(use_matplotlib, dict): + use_matplotlib = {} + use_matplotlib['limits'] = ( + (self.xlim[0] - 0.25, self.xlim[1] + 0.25), (self.ylim[0] - 0.25, self.ylim[1] + 0.25) + ) + + # Support for width, height, and title args + if width not in (None, "auto") and height not in (None, "auto"): + # TODO: Deprecation warning for width and height + plot_args = {"figsize": (width, height)} + + if "plot_args" in use_matplotlib: + for name, val in use_matplotlib['plot_args'].items(): + plot_args[name] = val + use_matplotlib['plot_args'] = plot_args + + base_group_args = {} + if colormap is not None: + base_group_args['cmap'] = colormap + base_group_args['vmin'] = 1 # minimum number of defined types + base_group_args['vmax'] = len(self.typeNdxMapping) # number of defined types + if size is not None: + base_group_args['s'] = size + + if "scatter_args" not in use_matplotlib: + use_matplotlib['scatter_args'] = {} + for type_id in self.typeNdxMapping.keys(): + type_id = type_id[5:] + group_args = base_group_args.copy() + if type_id in use_matplotlib['scatter_args']: + for name, val in use_matplotlib['scatter_args'][type_id].items(): + group_args[name] = val + use_matplotlib['scatter_args'][type_id] = group_args + + if title is not None: + use_matplotlib['title'] = title + + vis_obj = Visualization(data=types) + vis_obj.plot_scatter(**use_matplotlib) return + if size is None: + size = 5 + is_2d = self.dimensions == 2 trace_list = _plotly_iterate(types, size=size, property_name="type", @@ -860,7 +893,7 @@ def create_3D_domain(cls, xlim, ylim, zlim, nx, ny, nz, type_id=1, mass=1.0, :type fixed: bool :param \**kwargs: Additional keyword arguments passed to :py:class:`Domain`. - + :returns: Uniform 3D SpatialPy Domain object. :rtype: spatialpy.core.domain.Domain """ diff --git a/spatialpy/core/result.py b/spatialpy/core/result.py index a8408d8f..a32b6ee5 100644 --- a/spatialpy/core/result.py +++ b/spatialpy/core/result.py @@ -29,6 +29,7 @@ from plotly.offline import init_notebook_mode, iplot # from spatialpy.core.model import * +from spatialpy.core.visualization import Visualization from spatialpy.core.vtkreader import VTKReader from spatialpy.core.spatialpyerror import ResultError @@ -221,6 +222,9 @@ def __del__(self): def __map_property_to_type(self, property_name, data, included_types_list, points, p_ndx): types = {} if property_name == 'type': + # Normalize volumes to [0, 1] + vol = data['mass'] / data['rho'] + vols = (vol - numpy.min(vol))/numpy.ptp(vol) for i, val in enumerate(data['type']): name = self.model.domain.typeNameMapping[val][5:] @@ -228,9 +232,15 @@ def __map_property_to_type(self, property_name, data, included_types_list, point if name in types: types[name]['points'].append(points[i]) types[name]['data'].append(data[property_name][i]) + types[name]['size_scale'] = numpy.append(types[name]['size_scale'], vols[i]) else: - types[name] = {"points":[points[i]], "data":[data[property_name][i]]} - elif property_name == 'v': + types[name] = { + "points": [points[i]], + "data": [data[property_name][i]], + "size_scale": numpy.array([vols[i]]) + } + return types + if property_name == 'v': types[property_name] = { "points": points, "data" : [data[property_name][i][p_ndx] for i in range(0,len(data[property_name]))] @@ -326,7 +336,8 @@ def get_species(self, species, timepoints=None, concentration=False, determinist If set to True, the concentration (=copy_number/volume) is returned. Defaults to False :type concentration: bool - :param deterministic: Whether or not the species is deterministic (True) or stochastic (False). Defaults to False + :param deterministic: Whether or not the species is deterministic (True) or stochastic (False). \ + Defaults to False :type deterministic: bool :param debug: Whether or not debug information should be printed. Defaults to False @@ -614,7 +625,6 @@ def get_property(self, property_name, timepoints=None): t_index_arr = [t_index_arr] num_timepoints = 1 - if property_name == "v": ret = numpy.zeros((num_timepoints, num_voxel, 3)) else: @@ -630,7 +640,7 @@ def get_property(self, property_name, timepoints=None): return ret def plot_property(self, property_name, t_ndx=None, t_val=None, p_ndx=0, width=None, height=None, - colormap=None, size=5, title=None, animated=False, t_ndx_list=None, speed=1, + colormap=None, size=None, title=None, animated=False, t_ndx_list=None, speed=1, f_duration=500, t_duration=300, included_types_list=None, return_plotly_figure=False, use_matplotlib=False, debug=False): """ @@ -743,17 +753,47 @@ def plot_property(self, property_name, t_ndx=None, t_val=None, p_ndx=0, width=No if use_matplotlib: import matplotlib.pyplot as plt # pylint: disable=import-outside-toplevel + if not isinstance(use_matplotlib, dict): + use_matplotlib = {} + use_matplotlib['limits'] = ( + (self.model.domain.xlim[0] - 0.25, self.model.domain.xlim[1] + 0.25), + (self.model.domain.ylim[0] - 0.25, self.model.domain.ylim[1] + 0.25) + ) + + # Support for width, height, and title args + if width not in (None, "auto") and height not in (None, "auto"): + # TODO: Deprecation warning for width and height + plot_args = {"figsize": (width, height)} + + if "plot_args" in use_matplotlib: + for name, val in use_matplotlib['plot_args'].items(): + plot_args[name] = val + use_matplotlib['plot_args'] = plot_args + + base_group_args = {} + if colormap is not None: + base_group_args['cmap'] = colormap + base_group_args['vmin'] = 1 # minimum number of defined types + base_group_args['vmax'] = len(self.model.domain.typeNdxMapping) # number of defined types + if size is not None: + base_group_args['s'] = size + + if "scatter_args" not in use_matplotlib: + use_matplotlib['scatter_args'] = {} + for type_id in self.model.domain.typeNdxMapping.keys(): + type_id = type_id[5:] + group_args = base_group_args.copy() + if type_id in use_matplotlib['scatter_args']: + for name, val in use_matplotlib['scatter_args'][type_id].items(): + group_args[name] = val + use_matplotlib['scatter_args'][type_id] = group_args + + if title is not None: + use_matplotlib['title'] = title + if property_name == "type": - fig, ax = plt.subplots(figsize=(width, height)) - for name, data in types.items(): - x_coords = list(map(lambda point: point[0], data["points"])) - y_coords = list(map(lambda point: point[1], data["points"])) - - ax.scatter(x_coords, y_coords, label=name) - ax.grid(linestyle='--', linewidth=1) - ax.legend(loc='upper right', fontsize=12) - if title is not None: - ax.set_title(title) + vis_obj = Visualization(data=types) + vis_obj.plot_scatter(**use_matplotlib) else: if property_name == 'v': p_data = data[property_name] @@ -769,9 +809,12 @@ def plot_property(self, property_name, t_ndx=None, t_val=None, p_ndx=0, width=No if title is not None: plt.title(title) plt.grid(linestyle='--', linewidth=1) - plt.axis('scaled') + plt.axis('scaled') return + if size is None: + size = 5 + is_2d = self.model.domain.dimensions == 2 trace_list = _plotly_iterate(types, size=size, property_name=property_name, colormap=colormap, is_2d=is_2d) diff --git a/spatialpy/core/spatialpyerror.py b/spatialpy/core/spatialpyerror.py index ca32ff95..4fedb450 100644 --- a/spatialpy/core/spatialpyerror.py +++ b/spatialpy/core/spatialpyerror.py @@ -27,6 +27,11 @@ class ResultError(Exception): Class for exceptions in the results module. """ +class VisualizationError(Exception): + """ + Class for exceptions in the visualization module. + """ + class VTKReaderError(Exception): """ Bass class for exceptions in the vtkreader module. @@ -81,6 +86,9 @@ class SpeciesError(ModelError): # Result Exceptions +# Visualization Exceptions + + # VTKReader Exceptions class VTKReaderIOError(VTKReaderError): """ diff --git a/spatialpy/core/visualization.py b/spatialpy/core/visualization.py new file mode 100644 index 00000000..78df1fa5 --- /dev/null +++ b/spatialpy/core/visualization.py @@ -0,0 +1,186 @@ +''' +SpatialPy is a Python 3 package for simulation of +spatial deterministic/stochastic reaction-diffusion-advection problems +Copyright (C) 2019 - 2022 SpatialPy developers. + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU GENERAL PUBLIC LICENSE Version 3 as +published by the Free Software Foundation. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU GENERAL PUBLIC LICENSE Version 3 for more details. + +You should have received a copy of the GNU General Public License +along with this program. If not, see . +''' + +import math +import numpy + +try: + import matplotlib.pyplot as plt + mpl_import_err = None +except ImportError as err: + mpl_import_err = err + +from spatialpy.core.spatialpyerror import VisualizationError + +common_rgb_values=['#1f77b4','#ff7f0e','#2ca02c','#d62728','#9467bd','#8c564b','#e377c2','#7f7f7f', + '#bcbd22','#17becf','#ff0000','#00ff00','#0000ff','#ffff00','#00ffff','#ff00ff', + '#800000','#808000','#008000','#800080','#008080','#000080','#ff9999','#ffcc99', + '#ccff99','#cc99ff','#ffccff','#62666a','#8896bb','#77a096','#9d5a6c','#9d5a6c', + '#eabc75','#ff9600','#885300','#9172ad','#a1b9c4','#18749b','#dadecf','#c5b8a8', + '#000117','#13a8fe','#cf0060','#04354b','#0297a0','#037665','#eed284','#442244', + '#ffddee','#702afb'] + +def _get_coords(points): + labels = ["X-Axis", "Y-Axis", "Z-Axis"] + coords = [] + axes = [] + i = 0 + while len(coords) < 2 and i < 3: + vertex = list(map(lambda point: point[i], points)) + if numpy.count_nonzero(numpy.diff(vertex)) > 0: + coords.append(vertex) + axes.append(labels[i]) + i += 1 + if len(coords) < 2: + if labels[0] not in axes: + coords.insert(0, list(map(lambda point: point[0], points))) + axes.insert(0, labels[0]) + elif labels[1] not in axes: + coords.append(list(map(lambda point: point[1], points))) + axes.append(labels[1]) + return coords, axes + +def _validate_mplplot_args(args): + if args is None: + return {"figsize": (Visualization.MPL_WIDTH, Visualization.MPL_HEIGHT)} + + supported_splts_args = ["sharex", "sharey", "squeeze"] + supported_fig_args = ["figsize", "dpi", "facecolor", "edgecolor", "frameon"] + + kwargs = {} + for name, val in args.items(): + if name in supported_splts_args or name in supported_fig_args: + kwargs[name] = val + else: + from spatialpy.core import log # pylint: disable=import-outside-toplevel + logmsg = f"Un-supported key word argument: {name} is not currently supported" + log.warning(logmsg) + return kwargs + +class Visualization(): + + MPL_WIDTH = 6.4 + MPL_HEIGHT = 4.8 + MPL_SIZE = 40 + + def __init__(self, data): + self.data = data + + def __get_grid_shape(self, multiple_graphs): + if isinstance(multiple_graphs, tuple): + num_subplots = multiple_graphs[0] * multiple_graphs[1] + if num_subplots < len(self.data.keys()): + errmsg = f"The shape {multiple_graphs} of the graphs is to small for the given data" + raise VisualizationError(errmsg) + nrows = multiple_graphs[0] + ncols = multiple_graphs[1] + else: + nrows = math.ceil(len(self.data.keys()) / 2) + ncols = 2 + return nrows, ncols + + def __validate_mplscatter_args(self, args, name): + if args is None: + return {"s": Visualization.MPL_SIZE} + + supported_args = ["s", "cmap", "color", "marker", "vmin", "vmax"] + + group_args = args[name] + if "cmap" in group_args and "color" in group_args: + errmsg = f"scatter args for {name}: 'cmap' and 'color' cannot both be set." + raise VisualizationError(errmsg) + + kwargs = {} + for arg_name, val in group_args.items(): + if arg_name in supported_args: + kwargs[arg_name] = val + else: + from spatialpy.core import log # pylint: disable=import-outside-toplevel + logmsg = f"Un-supported key word argument: {arg_name} is not currently supported" + log.warning(logmsg) + if "s" not in kwargs: + if "size_scale" in self.data[name]: + mpl_smin = 18 + mpl_smax = 180 + kwargs['s'] = (mpl_smax - mpl_smin) * self.data[name]['size_scale'] + mpl_smin + else: + kwargs['s'] = Visualization.MPL_SIZE + return kwargs + + def plot_scatter(self, plot_args=None, scatter_args=None, multiple_graphs=False, title=None, limits=None): + """ + Visualize data using maplotlib scatter plots. + + :param plot_args: additional keyword arguments passed to :py:class:`matplotlib.pyplot.subplots` + :type plot_args: dict + + :param scatter_args: dict of additional keyword arguments passed to \ + :py:class:`matplotlib.pyplot.scatter` for each group. + :type scatter_args: dist + + :param multiple_graphs: if each data entry should be ploted separately or on the same plot. \ + If ploted separately a shape may be provided. + :type multiple_graphs: bool | tuple(nrows, ncols) + """ + if mpl_import_err is not None: + raise VisualizationError("Missing MatPlotLib dependency.") from mpl_import_err + + plot_args = _validate_mplplot_args(plot_args) + if multiple_graphs: + plot_args['nrows'], plot_args['ncols'] = self.__get_grid_shape(multiple_graphs) + + fig, axs = plt.subplots(**plot_args) + if multiple_graphs: + axs = axs.flatten() + if len(self.data.keys())%plot_args['ncols'] != 0: + fig.delaxes(axs[-1]) + axs = numpy.delete(axs, -1) + data_keys = list(self.data.keys()) + for index, ax in enumerate(axs): + name = data_keys[index] + coords, axes_labels = _get_coords(self.data[name]["points"]) + + group_args = self.__validate_mplscatter_args(scatter_args, name) + if "cmap" in group_args: + group_args['c'] = self.data[name]["data"] + + ax.scatter(*coords, label=name, **group_args) + ax.set_xlabel(axes_labels[0]) + ax.set_xlim(limits[0]) + ax.set_ylabel(axes_labels[1]) + ax.set_ylim(limits[1]) + ax.grid(linestyle='--', linewidth=1) + ax.legend(loc='upper right', fontsize=12) + if title is not None: + ax.set_title(title) + else: + for index, (name, data) in enumerate(self.data.items()): + x_coords = list(map(lambda point: point[0], data["points"])) + y_coords = list(map(lambda point: point[1], data["points"])) + + group_args = self.__validate_mplscatter_args(scatter_args, name) + if "cmap" in group_args: + group_args['c'] = data["data"] + + axs.scatter(x_coords, y_coords, label=name, **group_args) + axs.grid(linestyle='--', linewidth=1) + axs.legend(loc='upper right', fontsize=12) + if title is not None: + axs.set_title(title) + + plt.axis('scaled')