diff --git a/pyproject.toml b/pyproject.toml index cc8c984..3b65a2b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,12 +5,12 @@ authors = [ { name = "Chadwick Boulay", email = "chadwick.boulay@gmail.com" }, ] readme = "README.md" -requires-python = ">=3.10.15" +requires-python = ">=3.10" dynamic = ["version"] dependencies = [ - "ezmsg-baseproc>=1.1.0", - "ezmsg-sigproc>=2.4.0", - "ezmsg>=3.6.1", + "ezmsg>=3.7.3", + "ezmsg-baseproc>=1.5.1", + "ezmsg-sigproc>=2.17.0", "sparse>=0.17.0", "numpy>=2.2.6", ] diff --git a/src/ezmsg/event/peak.py b/src/ezmsg/event/peak.py index 739c0d5..42e8def 100644 --- a/src/ezmsg/event/peak.py +++ b/src/ezmsg/event/peak.py @@ -151,6 +151,24 @@ def _process(self, message: AxisArray) -> AxisArray: # Take note of how many samples were prepended. We will need this later when we modify `overs`. n_prepended = self._state.data.shape[0] + if data.shape[0] == 0: + # No data at all (empty buffer + empty message). Return empty sparse output. + result = sparse.COO( + np.zeros((data.ndim, 0), dtype=np.int64), + data=np.array([], dtype=data.dtype if self.settings.return_peak_val else bool), + shape=data.shape, + ) + return replace(message, data=result) + + if n_prepended == 0: + # No reference sample from previous iteration (e.g. first message after an empty-data reset). + # Duplicate the first sample as the reference, matching the convention that _reset_state + # stores data[:1] so it gets prepended on the next call. + data = xp.concat((data[:1], data), axis=0) + n_prepended = 1 + if self._state.data_raw is not None: + self._state.data_raw = xp.concat((self._state.data_raw[:1], self._state.data_raw), axis=0) + # Identify which data points are over threshold overs = data >= self.settings.threshold if self.settings.threshold >= 0 else data <= self.settings.threshold diff --git a/src/ezmsg/event/poissonevents.py b/src/ezmsg/event/poissonevents.py index caaf101..9667861 100644 --- a/src/ezmsg/event/poissonevents.py +++ b/src/ezmsg/event/poissonevents.py @@ -174,6 +174,24 @@ def _process(self, message: AxisArray) -> AxisArray: bin_duration = message.axes["time"].gain total_samples = n_bins * int(bin_duration * self.settings.output_fs) + if n_bins == 0: + ch_ax = message.get_axis_idx("ch") + n_channels = message.data.shape[ch_ax] + event_array = sparse.COO( + coords=np.zeros((2, 0), dtype=np.int64), + data=np.zeros(0, dtype=np.int8), + shape=(0, n_channels), + ) + return replace( + message, + data=event_array, + dims=["time", "ch"], + axes={ + **message.axes, + "time": replace(message.axes["time"], gain=1 / self.settings.output_fs), + }, + ) + # Get rates array with shape (n_bins, n_channels), contiguous for numba rates_array = message.data / bin_duration if self.settings.assume_counts else message.data if time_ax != 0: diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..6d7e299 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,44 @@ +"""Shared test fixtures and helpers for ezmsg-event tests.""" + +import numpy as np +import sparse +from ezmsg.util.messages.axisarray import AxisArray + +FS = 30_000.0 +N_CH = 4 +CHUNK_LEN = 600 # 20ms at 30kHz + + +def make_dense_msg(n_time: int, n_ch: int = N_CH, fs: float = FS, offset: float = 0.0) -> AxisArray: + """Create a dense AxisArray message with random data.""" + return AxisArray( + data=np.random.randn(n_time, n_ch), + dims=["time", "ch"], + axes={"time": AxisArray.Axis.TimeAxis(fs=fs, offset=offset)}, + ) + + +def make_sparse_event_msg( + n_time: int, n_ch: int = N_CH, fs: float = FS, density: float = 0.01, offset: float = 0.0 +) -> AxisArray: + """Create a sparse AxisArray message (COO format).""" + if n_time == 0: + s = sparse.COO(coords=np.zeros((2, 0), dtype=np.int64), data=np.zeros(0, dtype=bool), shape=(0, n_ch)) + else: + s = sparse.random((n_time, n_ch), density=density, random_state=np.random.default_rng(42)) > 0 + return AxisArray( + data=s, + dims=["time", "ch"], + axes={"time": AxisArray.Axis.TimeAxis(fs=fs, offset=offset)}, + ) + + +def make_rate_msg( + n_time: int, n_ch: int = N_CH, rate_hz: float = 50.0, bin_fs: float = 50.0, offset: float = 0.0 +) -> AxisArray: + """Create a dense AxisArray with firing-rate data (for PoissonEventTransformer).""" + return AxisArray( + data=np.full((n_time, n_ch), rate_hz), + dims=["time", "ch"], + axes={"time": AxisArray.Axis.TimeAxis(fs=bin_fs, offset=offset)}, + ) diff --git a/tests/test_binned.py b/tests/test_binned.py index 594f40e..69c2797 100644 --- a/tests/test_binned.py +++ b/tests/test_binned.py @@ -2,6 +2,7 @@ import numpy as np import sparse +from conftest import CHUNK_LEN, FS, N_CH, make_sparse_event_msg from ezmsg.util.messages.axisarray import AxisArray from ezmsg.event.binned import BinnedEventAggregator, BinnedEventAggregatorSettings @@ -52,3 +53,37 @@ def test_event_rate_binned(): stacked = AxisArray.concatenate(*out_msgs, dim="time") assert stacked.data.shape == expected.shape assert np.array_equal(stacked.data, expected.todense() / bin_dur) + + +def test_binned_event_aggregator_empty_time_after_init(): + """Normal → empty → normal: mid-stream empty message.""" + proc = BinnedEventAggregator(settings=BinnedEventAggregatorSettings(bin_duration=0.02)) + + msg1 = make_sparse_event_msg(CHUNK_LEN, offset=0.0) + msg_empty = make_sparse_event_msg(0, offset=CHUNK_LEN / FS) + msg2 = make_sparse_event_msg(CHUNK_LEN, offset=CHUNK_LEN / FS) + + out1 = proc(msg1) + assert out1.data.ndim == 2 + + out_empty = proc(msg_empty) + assert out_empty.data.ndim == 2 + + out2 = proc(msg2) + assert out2.data.ndim == 2 + assert out2.data.shape[1] == N_CH + + +def test_binned_event_aggregator_empty_time_first(): + """Empty → normal: empty first message triggers _reset_state on empty data.""" + proc = BinnedEventAggregator(settings=BinnedEventAggregatorSettings(bin_duration=0.02)) + + msg_empty = make_sparse_event_msg(0, offset=0.0) + msg_normal = make_sparse_event_msg(CHUNK_LEN, offset=0.0) + + out_empty = proc(msg_empty) + assert out_empty.data.ndim == 2 + + out_normal = proc(msg_normal) + assert out_normal.data.ndim == 2 + assert out_normal.data.shape[1] == N_CH diff --git a/tests/test_kernel_activation.py b/tests/test_kernel_activation.py index 508ce26..7b89750 100644 --- a/tests/test_kernel_activation.py +++ b/tests/test_kernel_activation.py @@ -3,6 +3,7 @@ import numpy as np import pytest import sparse +from conftest import CHUNK_LEN, FS, N_CH, make_sparse_event_msg from ezmsg.util.messages.axisarray import AxisArray from ezmsg.event.kernel_activation import ( @@ -406,3 +407,45 @@ def test_sum_aggregation(self): result = activation(message) assert result.data[0, 0] == 3.0 # Sum of all events in bin + + +class TestEmptyTime: + """Test handling of length-0 time dimension inputs.""" + + @pytest.mark.parametrize( + "kernel_type", + [ActivationKernelType.EXPONENTIAL, ActivationKernelType.ALPHA, ActivationKernelType.COUNT], + ) + def test_empty_time_after_init(self, kernel_type: ActivationKernelType): + """Normal → empty → normal: mid-stream empty message.""" + proc = BinnedKernelActivation(BinnedKernelActivationSettings(kernel_type=kernel_type, bin_duration=0.02)) + msg1 = make_sparse_event_msg(CHUNK_LEN, offset=0.0) + msg_empty = make_sparse_event_msg(0, offset=CHUNK_LEN / FS) + msg2 = make_sparse_event_msg(CHUNK_LEN, offset=CHUNK_LEN / FS) + + out1 = proc(msg1) + assert out1.data.ndim == 2 + + out_empty = proc(msg_empty) + assert out_empty.data.shape[0] == 0 or out_empty.data.ndim == 2 + + out2 = proc(msg2) + assert out2.data.ndim == 2 + assert out2.data.shape[1] == N_CH + + @pytest.mark.parametrize( + "kernel_type", + [ActivationKernelType.EXPONENTIAL, ActivationKernelType.ALPHA, ActivationKernelType.COUNT], + ) + def test_empty_time_first(self, kernel_type: ActivationKernelType): + """Empty → normal: empty first message triggers _reset_state on empty data.""" + proc = BinnedKernelActivation(BinnedKernelActivationSettings(kernel_type=kernel_type, bin_duration=0.02)) + msg_empty = make_sparse_event_msg(0, offset=0.0) + msg_normal = make_sparse_event_msg(CHUNK_LEN, offset=0.0) + + out_empty = proc(msg_empty) + assert out_empty.data.ndim == 2 + + out_normal = proc(msg_normal) + assert out_normal.data.ndim == 2 + assert out_normal.data.shape[1] == N_CH diff --git a/tests/test_kernel_insert.py b/tests/test_kernel_insert.py index 9976eb2..4f74482 100644 --- a/tests/test_kernel_insert.py +++ b/tests/test_kernel_insert.py @@ -1,7 +1,9 @@ """Unit tests for ezmsg.event.kernel_insert module.""" import numpy as np +import pytest import sparse +from conftest import CHUNK_LEN, FS, N_CH, make_sparse_event_msg from ezmsg.util.messages.axisarray import AxisArray from ezmsg.event.kernel import ArrayKernel, MultiKernel @@ -280,3 +282,46 @@ def test_zero_length_chunk(self): result = inserter(message) assert result.data.shape == (0, 2) + + +class TestSparseKernelInserterEmptyTime: + """Test handling of length-0 time dimension inputs.""" + + @pytest.mark.parametrize("use_kernel", [False, True]) + def test_empty_time_after_init(self, use_kernel: bool): + """Normal → empty → normal: mid-stream empty message.""" + from ezmsg.event.kernel import FunctionalKernel, exponential_kernel + + kernel = FunctionalKernel(exponential_kernel, sigma=0.001, fs=FS) if use_kernel else None + proc = SparseKernelInserter(SparseKernelInserterSettings(kernel=kernel)) + + msg1 = make_sparse_event_msg(CHUNK_LEN, offset=0.0) + msg_empty = make_sparse_event_msg(0, offset=CHUNK_LEN / FS) + msg2 = make_sparse_event_msg(CHUNK_LEN, offset=CHUNK_LEN / FS) + + out1 = proc(msg1) + assert out1.data.shape == (CHUNK_LEN, N_CH) + + out_empty = proc(msg_empty) + assert out_empty.data.shape[0] == 0 + assert out_empty.data.shape[1] == N_CH + + out2 = proc(msg2) + assert out2.data.shape == (CHUNK_LEN, N_CH) + + @pytest.mark.parametrize("use_kernel", [False, True]) + def test_empty_time_first(self, use_kernel: bool): + """Empty → normal: empty first message triggers _reset_state on empty data.""" + from ezmsg.event.kernel import FunctionalKernel, exponential_kernel + + kernel = FunctionalKernel(exponential_kernel, sigma=0.001, fs=FS) if use_kernel else None + proc = SparseKernelInserter(SparseKernelInserterSettings(kernel=kernel)) + + msg_empty = make_sparse_event_msg(0, offset=0.0) + msg_normal = make_sparse_event_msg(CHUNK_LEN, offset=0.0) + + out_empty = proc(msg_empty) + assert out_empty.data.shape[0] == 0 + + out_normal = proc(msg_normal) + assert out_normal.data.shape == (CHUNK_LEN, N_CH) diff --git a/tests/test_peak.py b/tests/test_peak.py index 7d1adb4..16d6f4c 100644 --- a/tests/test_peak.py +++ b/tests/test_peak.py @@ -6,12 +6,13 @@ import numpy as np import pytest import sparse +from conftest import CHUNK_LEN, FS, N_CH, make_dense_msg from ezmsg.util.messagecodec import message_log from ezmsg.util.messagelogger import MessageLogger from ezmsg.util.messages.chunker import ArrayChunker, array_chunker from ezmsg.util.terminate import TerminateOnTotal -from ezmsg.event.peak import ThresholdCrossing, ThresholdCrossingTransformer +from ezmsg.event.peak import ThresholdCrossing, ThresholdCrossingTransformer, ThresholdSettings from ezmsg.event.util.simulate import generate_white_noise_with_events @@ -138,3 +139,52 @@ def test_system(): assert isinstance(msg.data, sparse.SparseArray) assert msg.axes["time"].gain == 1 / fs assert np.round(msg.axes["time"].offset, 3) == np.round(msg_ix * chunk_dur, 3) + + +@pytest.mark.parametrize("return_peak_val", [False, True]) +@pytest.mark.parametrize("auto_scale_tau", [0.0, 0.5]) +def test_threshold_crossing_empty_time_after_init(return_peak_val: bool, auto_scale_tau: float): + """Normal → empty → normal: mid-stream empty message should not crash or corrupt state.""" + proc = ThresholdCrossingTransformer( + ThresholdSettings( + threshold=-3.5, + return_peak_val=return_peak_val, + auto_scale_tau=auto_scale_tau, + ) + ) + msg1 = make_dense_msg(CHUNK_LEN, offset=0.0) + msg_empty = make_dense_msg(0, offset=CHUNK_LEN / FS) + msg2 = make_dense_msg(CHUNK_LEN, offset=CHUNK_LEN / FS) + + out1 = proc(msg1) + assert isinstance(out1.data, sparse.SparseArray) + + out_empty = proc(msg_empty) + assert isinstance(out_empty.data, sparse.SparseArray) + assert out_empty.data.shape[1] == N_CH + + out2 = proc(msg2) + assert isinstance(out2.data, sparse.SparseArray) + assert out2.data.shape[1] == N_CH + + +@pytest.mark.parametrize("return_peak_val", [False, True]) +@pytest.mark.parametrize("auto_scale_tau", [0.0, 0.5]) +def test_threshold_crossing_empty_time_first(return_peak_val: bool, auto_scale_tau: float): + """Empty → normal: empty first message triggers _reset_state on empty data.""" + proc = ThresholdCrossingTransformer( + ThresholdSettings( + threshold=-3.5, + return_peak_val=return_peak_val, + auto_scale_tau=auto_scale_tau, + ) + ) + msg_empty = make_dense_msg(0, offset=0.0) + msg_normal = make_dense_msg(CHUNK_LEN, offset=0.0) + + out_empty = proc(msg_empty) + assert isinstance(out_empty.data, sparse.SparseArray) + + out_normal = proc(msg_normal) + assert isinstance(out_normal.data, sparse.SparseArray) + assert out_normal.data.shape[1] == N_CH diff --git a/tests/test_poissonevents.py b/tests/test_poissonevents.py index a419481..e29817b 100644 --- a/tests/test_poissonevents.py +++ b/tests/test_poissonevents.py @@ -1,6 +1,7 @@ import numpy as np import pytest import sparse +from conftest import FS, N_CH, make_rate_msg from ezmsg.util.messages.axisarray import AxisArray, CoordinateAxis, LinearAxis from ezmsg.event.poissonevents import ( @@ -348,3 +349,35 @@ def test_statistical_properties(self): # Allow 50% tolerance due to finite sample size assert actual_mean_isi > expected_mean_isi * 0.5 assert actual_mean_isi < expected_mean_isi * 2.0 + + def test_empty_time_after_init(self): + """Normal → empty → normal: mid-stream empty message.""" + proc = PoissonEventTransformer(PoissonEventSettings(output_fs=FS)) + + msg1 = make_rate_msg(5, offset=0.0) # 5 bins at 50 Hz = 100ms + msg_empty = make_rate_msg(0, offset=0.1) + msg2 = make_rate_msg(5, offset=0.1) + + out1 = proc(msg1) + assert isinstance(out1.data, sparse.SparseArray) + + out_empty = proc(msg_empty) + assert isinstance(out_empty.data, sparse.SparseArray) + + out2 = proc(msg2) + assert isinstance(out2.data, sparse.SparseArray) + assert out2.data.shape[1] == N_CH + + def test_empty_time_first(self): + """Empty → normal: empty first message triggers _reset_state on empty data.""" + proc = PoissonEventTransformer(PoissonEventSettings(output_fs=FS)) + + msg_empty = make_rate_msg(0, offset=0.0) + msg_normal = make_rate_msg(5, offset=0.0) + + out_empty = proc(msg_empty) + assert isinstance(out_empty.data, sparse.SparseArray) + + out_normal = proc(msg_normal) + assert isinstance(out_normal.data, sparse.SparseArray) + assert out_normal.data.shape[1] == N_CH diff --git a/tests/test_rate.py b/tests/test_rate.py index 9c7c43e..861e7cd 100644 --- a/tests/test_rate.py +++ b/tests/test_rate.py @@ -2,6 +2,7 @@ import numpy as np import sparse +from conftest import CHUNK_LEN, FS, N_CH, make_sparse_event_msg from ezmsg.util.messages.axisarray import AxisArray from ezmsg.event.rate import EventRateSettings, Rate @@ -57,3 +58,37 @@ def test_event_rate_composite(): expected = np.sum(s_proc, axis=1) / bin_dur assert stack.data.shape == expected.shape assert np.allclose(stack.data, expected) + + +def test_rate_empty_time_after_init(): + """Normal → empty → normal: mid-stream empty message.""" + proc = Rate(EventRateSettings(bin_duration=0.02)) + + msg1 = make_sparse_event_msg(CHUNK_LEN, offset=0.0) + msg_empty = make_sparse_event_msg(0, offset=CHUNK_LEN / FS) + msg2 = make_sparse_event_msg(CHUNK_LEN, offset=CHUNK_LEN / FS) + + out1 = proc(msg1) + assert out1.data.ndim == 2 + + out_empty = proc(msg_empty) + assert out_empty.data.shape[0] == 0 or out_empty.data.ndim == 2 + + out2 = proc(msg2) + assert out2.data.ndim == 2 + assert out2.data.shape[1] == N_CH + + +def test_rate_empty_time_first(): + """Empty → normal: empty first message triggers _reset_state on empty data.""" + proc = Rate(EventRateSettings(bin_duration=0.02)) + + msg_empty = make_sparse_event_msg(0, offset=0.0) + msg_normal = make_sparse_event_msg(CHUNK_LEN, offset=0.0) + + out_empty = proc(msg_empty) + assert out_empty.data.ndim == 2 + + out_normal = proc(msg_normal) + assert out_normal.data.ndim == 2 + assert out_normal.data.shape[1] == N_CH diff --git a/tests/test_refractory.py b/tests/test_refractory.py index af89703..f0e4d39 100644 --- a/tests/test_refractory.py +++ b/tests/test_refractory.py @@ -1,9 +1,10 @@ import numpy as np import pytest import sparse +from conftest import CHUNK_LEN, FS, N_CH, make_sparse_event_msg from ezmsg.util.messages.axisarray import AxisArray -from ezmsg.event.refractory import RefractoryTransformer +from ezmsg.event.refractory import RefractorySettings, RefractoryTransformer class TestRefractoryTransformer: @@ -221,3 +222,36 @@ def test_different_time_axis_positions(self, time_axis_position: int): # Both events are far enough apart to pass assert result.data.nnz == 2 + + def test_empty_time_after_init(self): + """Normal → empty → normal: mid-stream empty message.""" + proc = RefractoryTransformer(RefractorySettings(dur=0.001)) + + msg1 = make_sparse_event_msg(CHUNK_LEN, offset=0.0) + msg_empty = make_sparse_event_msg(0, offset=CHUNK_LEN / FS) + msg2 = make_sparse_event_msg(CHUNK_LEN, offset=CHUNK_LEN / FS) + + out1 = proc(msg1) + assert isinstance(out1.data, sparse.SparseArray) + + out_empty = proc(msg_empty) + assert isinstance(out_empty.data, sparse.SparseArray) + assert out_empty.data.shape[0] == 0 + + out2 = proc(msg2) + assert isinstance(out2.data, sparse.SparseArray) + assert out2.data.shape[1] == N_CH + + def test_empty_time_first(self): + """Empty → normal: empty first message triggers _reset_state on empty data.""" + proc = RefractoryTransformer(RefractorySettings(dur=0.001)) + + msg_empty = make_sparse_event_msg(0, offset=0.0) + msg_normal = make_sparse_event_msg(CHUNK_LEN, offset=0.0) + + out_empty = proc(msg_empty) + assert isinstance(out_empty.data, sparse.SparseArray) + + out_normal = proc(msg_normal) + assert isinstance(out_normal.data, sparse.SparseArray) + assert out_normal.data.shape[1] == N_CH