diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index fb26a228ef..e86689ebf4 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -180,6 +180,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): **unitwaveformdensitymapwidget_kwargs, ) col_counter += 1 + ax_waveform_density.set_xlabel(None) ax_waveform_density.set_ylabel(None) if sorting_analyzer.has_extension("correlograms"): diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index 9543cbf734..a4260ac752 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -83,8 +83,8 @@ def __init__( templates = ext_templates.get_templates(unit_ids=unit_ids) bin_min = np.min(templates) * 1.3 bin_max = np.max(templates) * 1.3 - bin_size = (bin_max - bin_min) / 100 - bins = np.arange(bin_min, bin_max, bin_size) + num_bins = 100 + bins = np.linspace(bin_min, bin_max, num_bins + 1) # 2d histograms if same_axis: @@ -121,14 +121,9 @@ def __init__( wfs = wfs_ # make histogram density - wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1) - hist2d = np.zeros((wfs_flat.shape[1], bins.size)) - indexes0 = np.arange(wfs_flat.shape[1]) - - wf_bined = np.floor((wfs_flat - bin_min) / bin_size).astype("int32") - wf_bined = wf_bined.clip(0, bins.size - 1) - for d in wf_bined: - hist2d[indexes0, d] += 1 + wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1) # num_spikes x (num_channels * timepoints) + hists_per_timepoint = [np.histogram(one_timepoint, bins=bins)[0] for one_timepoint in wfs_flat.T] + hist2d = np.stack(hists_per_timepoint) if same_axis: if all_hist2d is None: @@ -162,6 +157,7 @@ def __init__( bin_min=bin_min, bin_max=bin_max, all_hist2d=all_hist2d, + sampling_frequency=sorting_analyzer.sampling_frequency, templates_flat=templates_flat, template_width=wfs.shape[1], ) @@ -169,53 +165,45 @@ def __init__( BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - import matplotlib.pyplot as plt from .utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) - if backend_kwargs["axes"] is not None or backend_kwargs["ax"] is not None: - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - else: - if dp.same_axis: - num_axes = 1 - else: - num_axes = len(dp.unit_ids) + if backend_kwargs["axes"] is None and backend_kwargs["ax"] is None: backend_kwargs["ncols"] = 1 - backend_kwargs["num_axes"] = num_axes - self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + backend_kwargs["num_axes"] = 1 if dp.same_axis else len(dp.unit_ids) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + freq_khz = dp.sampling_frequency / 1000 # samples / msec if dp.same_axis: - ax = self.ax hist2d = dp.all_hist2d - im = ax.imshow( + x_max = len(hist2d) / freq_khz # in milliseconds + self.ax.imshow( hist2d.T, interpolation="nearest", origin="lower", aspect="auto", - extent=(0, hist2d.shape[0], dp.bin_min, dp.bin_max), + extent=(0, x_max, dp.bin_min, dp.bin_max), cmap="hot", ) else: - for unit_index, unit_id in enumerate(dp.unit_ids): + for ax, unit_id in zip(self.axes.flatten(), dp.unit_ids): hist2d = dp.all_hist2d[unit_id] - ax = self.axes.flatten()[unit_index] - im = ax.imshow( + x_max = len(hist2d) / freq_khz # in milliseconds + ax.imshow( hist2d.T, interpolation="nearest", origin="lower", aspect="auto", - extent=(0, hist2d.shape[0], dp.bin_min, dp.bin_max), + extent=(0, x_max, dp.bin_min, dp.bin_max), cmap="hot", ) for unit_index, unit_id in enumerate(dp.unit_ids): - if dp.same_axis: - ax = self.ax - else: - ax = self.axes.flatten()[unit_index] + ax = self.ax if dp.same_axis else self.axes.flatten()[unit_index] color = dp.unit_colors[unit_id] - ax.plot(dp.templates_flat[unit_id], color=color, lw=1) + x = np.arange(len(dp.templates_flat[unit_id])) / freq_khz + ax.plot(x, dp.templates_flat[unit_id], color=color, lw=1) # final cosmetics for unit_index, unit_id in enumerate(dp.unit_ids): @@ -228,11 +216,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): chan_inds = dp.channel_inds[unit_id] for i, chan_ind in enumerate(chan_inds): if i != 0: - ax.axvline(i * dp.template_width, color="w", lw=3) + ax.axvline(i * dp.template_width / freq_khz, color="w", lw=3) channel_id = dp.channel_ids[chan_ind] - x = i * dp.template_width + dp.template_width // 2 + x = (i + 0.5) * dp.template_width / freq_khz y = (dp.bin_max + dp.bin_min) / 2.0 ax.text(x, y, f"chan_id {channel_id}", color="w", ha="center", va="center") - ax.set_xticks([]) + ax.set_xlabel("Time [ms]") ax.set_ylabel(f"unit_id {unit_id}")