Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ authors = [
{ name = "Chadwick Boulay", email = "chadwick.boulay@gmail.com" },
]
readme = "README.md"
requires-python = ">=3.10.15"
requires-python = ">=3.10"
dynamic = ["version"]
dependencies = [
"ezmsg-baseproc>=1.1.0",
"ezmsg-sigproc>=2.4.0",
"ezmsg>=3.6.1",
"ezmsg>=3.7.3",
"ezmsg-baseproc>=1.5.1",
"ezmsg-sigproc>=2.17.0",
"sparse>=0.17.0",
"numpy>=2.2.6",
]
Expand Down
18 changes: 18 additions & 0 deletions src/ezmsg/event/peak.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,24 @@ def _process(self, message: AxisArray) -> AxisArray:
# Take note of how many samples were prepended. We will need this later when we modify `overs`.
n_prepended = self._state.data.shape[0]

if data.shape[0] == 0:
# No data at all (empty buffer + empty message). Return empty sparse output.
result = sparse.COO(
np.zeros((data.ndim, 0), dtype=np.int64),
data=np.array([], dtype=data.dtype if self.settings.return_peak_val else bool),
shape=data.shape,
)
return replace(message, data=result)

if n_prepended == 0:
# No reference sample from previous iteration (e.g. first message after an empty-data reset).
# Duplicate the first sample as the reference, matching the convention that _reset_state
# stores data[:1] so it gets prepended on the next call.
data = xp.concat((data[:1], data), axis=0)
n_prepended = 1
if self._state.data_raw is not None:
self._state.data_raw = xp.concat((self._state.data_raw[:1], self._state.data_raw), axis=0)

# Identify which data points are over threshold
overs = data >= self.settings.threshold if self.settings.threshold >= 0 else data <= self.settings.threshold

Expand Down
18 changes: 18 additions & 0 deletions src/ezmsg/event/poissonevents.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,24 @@ def _process(self, message: AxisArray) -> AxisArray:
bin_duration = message.axes["time"].gain
total_samples = n_bins * int(bin_duration * self.settings.output_fs)

if n_bins == 0:
ch_ax = message.get_axis_idx("ch")
n_channels = message.data.shape[ch_ax]
event_array = sparse.COO(
coords=np.zeros((2, 0), dtype=np.int64),
data=np.zeros(0, dtype=np.int8),
shape=(0, n_channels),
)
return replace(
message,
data=event_array,
dims=["time", "ch"],
axes={
**message.axes,
"time": replace(message.axes["time"], gain=1 / self.settings.output_fs),
},
)

# Get rates array with shape (n_bins, n_channels), contiguous for numba
rates_array = message.data / bin_duration if self.settings.assume_counts else message.data
if time_ax != 0:
Expand Down
44 changes: 44 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Shared test fixtures and helpers for ezmsg-event tests."""

import numpy as np
import sparse
from ezmsg.util.messages.axisarray import AxisArray

FS = 30_000.0
N_CH = 4
CHUNK_LEN = 600 # 20ms at 30kHz


def make_dense_msg(n_time: int, n_ch: int = N_CH, fs: float = FS, offset: float = 0.0) -> AxisArray:
"""Create a dense AxisArray message with random data."""
return AxisArray(
data=np.random.randn(n_time, n_ch),
dims=["time", "ch"],
axes={"time": AxisArray.Axis.TimeAxis(fs=fs, offset=offset)},
)


def make_sparse_event_msg(
n_time: int, n_ch: int = N_CH, fs: float = FS, density: float = 0.01, offset: float = 0.0
) -> AxisArray:
"""Create a sparse AxisArray message (COO format)."""
if n_time == 0:
s = sparse.COO(coords=np.zeros((2, 0), dtype=np.int64), data=np.zeros(0, dtype=bool), shape=(0, n_ch))
else:
s = sparse.random((n_time, n_ch), density=density, random_state=np.random.default_rng(42)) > 0
return AxisArray(
data=s,
dims=["time", "ch"],
axes={"time": AxisArray.Axis.TimeAxis(fs=fs, offset=offset)},
)


def make_rate_msg(
n_time: int, n_ch: int = N_CH, rate_hz: float = 50.0, bin_fs: float = 50.0, offset: float = 0.0
) -> AxisArray:
"""Create a dense AxisArray with firing-rate data (for PoissonEventTransformer)."""
return AxisArray(
data=np.full((n_time, n_ch), rate_hz),
dims=["time", "ch"],
axes={"time": AxisArray.Axis.TimeAxis(fs=bin_fs, offset=offset)},
)
35 changes: 35 additions & 0 deletions tests/test_binned.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import sparse
from conftest import CHUNK_LEN, FS, N_CH, make_sparse_event_msg
from ezmsg.util.messages.axisarray import AxisArray

from ezmsg.event.binned import BinnedEventAggregator, BinnedEventAggregatorSettings
Expand Down Expand Up @@ -52,3 +53,37 @@ def test_event_rate_binned():
stacked = AxisArray.concatenate(*out_msgs, dim="time")
assert stacked.data.shape == expected.shape
assert np.array_equal(stacked.data, expected.todense() / bin_dur)


def test_binned_event_aggregator_empty_time_after_init():
"""Normal → empty → normal: mid-stream empty message."""
proc = BinnedEventAggregator(settings=BinnedEventAggregatorSettings(bin_duration=0.02))

msg1 = make_sparse_event_msg(CHUNK_LEN, offset=0.0)
msg_empty = make_sparse_event_msg(0, offset=CHUNK_LEN / FS)
msg2 = make_sparse_event_msg(CHUNK_LEN, offset=CHUNK_LEN / FS)

out1 = proc(msg1)
assert out1.data.ndim == 2

out_empty = proc(msg_empty)
assert out_empty.data.ndim == 2

out2 = proc(msg2)
assert out2.data.ndim == 2
assert out2.data.shape[1] == N_CH


def test_binned_event_aggregator_empty_time_first():
"""Empty → normal: empty first message triggers _reset_state on empty data."""
proc = BinnedEventAggregator(settings=BinnedEventAggregatorSettings(bin_duration=0.02))

msg_empty = make_sparse_event_msg(0, offset=0.0)
msg_normal = make_sparse_event_msg(CHUNK_LEN, offset=0.0)

out_empty = proc(msg_empty)
assert out_empty.data.ndim == 2

out_normal = proc(msg_normal)
assert out_normal.data.ndim == 2
assert out_normal.data.shape[1] == N_CH
43 changes: 43 additions & 0 deletions tests/test_kernel_activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import pytest
import sparse
from conftest import CHUNK_LEN, FS, N_CH, make_sparse_event_msg
from ezmsg.util.messages.axisarray import AxisArray

from ezmsg.event.kernel_activation import (
Expand Down Expand Up @@ -406,3 +407,45 @@ def test_sum_aggregation(self):
result = activation(message)

assert result.data[0, 0] == 3.0 # Sum of all events in bin


class TestEmptyTime:
"""Test handling of length-0 time dimension inputs."""

@pytest.mark.parametrize(
"kernel_type",
[ActivationKernelType.EXPONENTIAL, ActivationKernelType.ALPHA, ActivationKernelType.COUNT],
)
def test_empty_time_after_init(self, kernel_type: ActivationKernelType):
"""Normal → empty → normal: mid-stream empty message."""
proc = BinnedKernelActivation(BinnedKernelActivationSettings(kernel_type=kernel_type, bin_duration=0.02))
msg1 = make_sparse_event_msg(CHUNK_LEN, offset=0.0)
msg_empty = make_sparse_event_msg(0, offset=CHUNK_LEN / FS)
msg2 = make_sparse_event_msg(CHUNK_LEN, offset=CHUNK_LEN / FS)

out1 = proc(msg1)
assert out1.data.ndim == 2

out_empty = proc(msg_empty)
assert out_empty.data.shape[0] == 0 or out_empty.data.ndim == 2

out2 = proc(msg2)
assert out2.data.ndim == 2
assert out2.data.shape[1] == N_CH

@pytest.mark.parametrize(
"kernel_type",
[ActivationKernelType.EXPONENTIAL, ActivationKernelType.ALPHA, ActivationKernelType.COUNT],
)
def test_empty_time_first(self, kernel_type: ActivationKernelType):
"""Empty → normal: empty first message triggers _reset_state on empty data."""
proc = BinnedKernelActivation(BinnedKernelActivationSettings(kernel_type=kernel_type, bin_duration=0.02))
msg_empty = make_sparse_event_msg(0, offset=0.0)
msg_normal = make_sparse_event_msg(CHUNK_LEN, offset=0.0)

out_empty = proc(msg_empty)
assert out_empty.data.ndim == 2

out_normal = proc(msg_normal)
assert out_normal.data.ndim == 2
assert out_normal.data.shape[1] == N_CH
45 changes: 45 additions & 0 deletions tests/test_kernel_insert.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Unit tests for ezmsg.event.kernel_insert module."""

import numpy as np
import pytest
import sparse
from conftest import CHUNK_LEN, FS, N_CH, make_sparse_event_msg
from ezmsg.util.messages.axisarray import AxisArray

from ezmsg.event.kernel import ArrayKernel, MultiKernel
Expand Down Expand Up @@ -280,3 +282,46 @@ def test_zero_length_chunk(self):

result = inserter(message)
assert result.data.shape == (0, 2)


class TestSparseKernelInserterEmptyTime:
"""Test handling of length-0 time dimension inputs."""

@pytest.mark.parametrize("use_kernel", [False, True])
def test_empty_time_after_init(self, use_kernel: bool):
"""Normal → empty → normal: mid-stream empty message."""
from ezmsg.event.kernel import FunctionalKernel, exponential_kernel

kernel = FunctionalKernel(exponential_kernel, sigma=0.001, fs=FS) if use_kernel else None
proc = SparseKernelInserter(SparseKernelInserterSettings(kernel=kernel))

msg1 = make_sparse_event_msg(CHUNK_LEN, offset=0.0)
msg_empty = make_sparse_event_msg(0, offset=CHUNK_LEN / FS)
msg2 = make_sparse_event_msg(CHUNK_LEN, offset=CHUNK_LEN / FS)

out1 = proc(msg1)
assert out1.data.shape == (CHUNK_LEN, N_CH)

out_empty = proc(msg_empty)
assert out_empty.data.shape[0] == 0
assert out_empty.data.shape[1] == N_CH

out2 = proc(msg2)
assert out2.data.shape == (CHUNK_LEN, N_CH)

@pytest.mark.parametrize("use_kernel", [False, True])
def test_empty_time_first(self, use_kernel: bool):
"""Empty → normal: empty first message triggers _reset_state on empty data."""
from ezmsg.event.kernel import FunctionalKernel, exponential_kernel

kernel = FunctionalKernel(exponential_kernel, sigma=0.001, fs=FS) if use_kernel else None
proc = SparseKernelInserter(SparseKernelInserterSettings(kernel=kernel))

msg_empty = make_sparse_event_msg(0, offset=0.0)
msg_normal = make_sparse_event_msg(CHUNK_LEN, offset=0.0)

out_empty = proc(msg_empty)
assert out_empty.data.shape[0] == 0

out_normal = proc(msg_normal)
assert out_normal.data.shape == (CHUNK_LEN, N_CH)
52 changes: 51 additions & 1 deletion tests/test_peak.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
import numpy as np
import pytest
import sparse
from conftest import CHUNK_LEN, FS, N_CH, make_dense_msg
from ezmsg.util.messagecodec import message_log
from ezmsg.util.messagelogger import MessageLogger
from ezmsg.util.messages.chunker import ArrayChunker, array_chunker
from ezmsg.util.terminate import TerminateOnTotal

from ezmsg.event.peak import ThresholdCrossing, ThresholdCrossingTransformer
from ezmsg.event.peak import ThresholdCrossing, ThresholdCrossingTransformer, ThresholdSettings
from ezmsg.event.util.simulate import generate_white_noise_with_events


Expand Down Expand Up @@ -138,3 +139,52 @@ def test_system():
assert isinstance(msg.data, sparse.SparseArray)
assert msg.axes["time"].gain == 1 / fs
assert np.round(msg.axes["time"].offset, 3) == np.round(msg_ix * chunk_dur, 3)


@pytest.mark.parametrize("return_peak_val", [False, True])
@pytest.mark.parametrize("auto_scale_tau", [0.0, 0.5])
def test_threshold_crossing_empty_time_after_init(return_peak_val: bool, auto_scale_tau: float):
"""Normal → empty → normal: mid-stream empty message should not crash or corrupt state."""
proc = ThresholdCrossingTransformer(
ThresholdSettings(
threshold=-3.5,
return_peak_val=return_peak_val,
auto_scale_tau=auto_scale_tau,
)
)
msg1 = make_dense_msg(CHUNK_LEN, offset=0.0)
msg_empty = make_dense_msg(0, offset=CHUNK_LEN / FS)
msg2 = make_dense_msg(CHUNK_LEN, offset=CHUNK_LEN / FS)

out1 = proc(msg1)
assert isinstance(out1.data, sparse.SparseArray)

out_empty = proc(msg_empty)
assert isinstance(out_empty.data, sparse.SparseArray)
assert out_empty.data.shape[1] == N_CH

out2 = proc(msg2)
assert isinstance(out2.data, sparse.SparseArray)
assert out2.data.shape[1] == N_CH


@pytest.mark.parametrize("return_peak_val", [False, True])
@pytest.mark.parametrize("auto_scale_tau", [0.0, 0.5])
def test_threshold_crossing_empty_time_first(return_peak_val: bool, auto_scale_tau: float):
"""Empty → normal: empty first message triggers _reset_state on empty data."""
proc = ThresholdCrossingTransformer(
ThresholdSettings(
threshold=-3.5,
return_peak_val=return_peak_val,
auto_scale_tau=auto_scale_tau,
)
)
msg_empty = make_dense_msg(0, offset=0.0)
msg_normal = make_dense_msg(CHUNK_LEN, offset=0.0)

out_empty = proc(msg_empty)
assert isinstance(out_empty.data, sparse.SparseArray)

out_normal = proc(msg_normal)
assert isinstance(out_normal.data, sparse.SparseArray)
assert out_normal.data.shape[1] == N_CH
33 changes: 33 additions & 0 deletions tests/test_poissonevents.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pytest
import sparse
from conftest import FS, N_CH, make_rate_msg
from ezmsg.util.messages.axisarray import AxisArray, CoordinateAxis, LinearAxis

from ezmsg.event.poissonevents import (
Expand Down Expand Up @@ -348,3 +349,35 @@ def test_statistical_properties(self):
# Allow 50% tolerance due to finite sample size
assert actual_mean_isi > expected_mean_isi * 0.5
assert actual_mean_isi < expected_mean_isi * 2.0

def test_empty_time_after_init(self):
"""Normal → empty → normal: mid-stream empty message."""
proc = PoissonEventTransformer(PoissonEventSettings(output_fs=FS))

msg1 = make_rate_msg(5, offset=0.0) # 5 bins at 50 Hz = 100ms
msg_empty = make_rate_msg(0, offset=0.1)
msg2 = make_rate_msg(5, offset=0.1)

out1 = proc(msg1)
assert isinstance(out1.data, sparse.SparseArray)

out_empty = proc(msg_empty)
assert isinstance(out_empty.data, sparse.SparseArray)

out2 = proc(msg2)
assert isinstance(out2.data, sparse.SparseArray)
assert out2.data.shape[1] == N_CH

def test_empty_time_first(self):
"""Empty → normal: empty first message triggers _reset_state on empty data."""
proc = PoissonEventTransformer(PoissonEventSettings(output_fs=FS))

msg_empty = make_rate_msg(0, offset=0.0)
msg_normal = make_rate_msg(5, offset=0.0)

out_empty = proc(msg_empty)
assert isinstance(out_empty.data, sparse.SparseArray)

out_normal = proc(msg_normal)
assert isinstance(out_normal.data, sparse.SparseArray)
assert out_normal.data.shape[1] == N_CH
Loading