From 7b9aadd1fd7db0353c788251598953910015ed2a Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Mon, 1 Dec 2025 23:42:09 -0500 Subject: [PATCH 01/10] Add EventsFromRates transformer / unit --- src/ezmsg/event/eventsfromrates.py | 249 ++++++++++++++++++++ tests/test_eventsfromrates.py | 354 +++++++++++++++++++++++++++++ 2 files changed, 603 insertions(+) create mode 100644 src/ezmsg/event/eventsfromrates.py create mode 100644 tests/test_eventsfromrates.py diff --git a/src/ezmsg/event/eventsfromrates.py b/src/ezmsg/event/eventsfromrates.py new file mode 100644 index 0000000..b88dfbd --- /dev/null +++ b/src/ezmsg/event/eventsfromrates.py @@ -0,0 +1,249 @@ +import ezmsg.core as ez +import numba +import numpy as np +import numpy.typing as npt +import sparse +from ezmsg.sigproc.base import BaseStatefulTransformer, BaseTransformerUnit, processor_state +from ezmsg.util.messages.axisarray import AxisArray +from ezmsg.util.messages.util import replace + + +@numba.jit(nopython=True, cache=True) +def _generate_events_single_channel( + rates: np.ndarray, # (n_bins,) rates for this channel + accumulated: float, # initial accumulated value + threshold: float, # initial threshold + bin_duration: float, + output_fs: float, + max_events: int, +) -> tuple[np.ndarray, int, float, float]: + """Generate events for a single channel. + + Returns: + event_samples: pre-allocated array of event sample indices + n_events: actual number of events generated + accumulated: updated accumulated value for next chunk + threshold: updated threshold for next chunk + """ + event_samples = np.empty(max_events, dtype=np.int64) + n_events = 0 + n_bins = len(rates) + + for t in range(n_bins): + bin_start = t * bin_duration + rate = rates[t] + time_in_bin = 0.0 + + while True: + time_to_event = (threshold - accumulated) / rate + event_time = time_in_bin + time_to_event + + if event_time >= bin_duration: + # No more events in this bin + accumulated += rate * (bin_duration - time_in_bin) + break + + # Record event + if n_events < max_events: + event_sample = int((event_time + bin_start) * output_fs) + event_samples[n_events] = event_sample + n_events += 1 + + # Update state for next event + time_in_bin = event_time + accumulated = 0.0 + threshold = np.random.exponential(1.0) + + return event_samples, n_events, accumulated, threshold + + +@numba.jit(nopython=True, parallel=True, cache=True) +def _generate_events_all_channels( + rates_array: np.ndarray, # (n_bins, n_channels) + accumulated: np.ndarray, # (n_channels,) + threshold: np.ndarray, # (n_channels,) + bin_duration: float, + output_fs: float, + max_events_per_channel: int, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Generate events for all channels in parallel. + + Returns: + all_event_samples: (n_channels, max_events_per_channel) event sample indices + event_counts: (n_channels,) number of events per channel + accumulated_out: (n_channels,) updated accumulated values + threshold_out: (n_channels,) updated thresholds + """ + n_bins, n_channels = rates_array.shape + + # Pre-allocate output arrays + all_event_samples = np.empty((n_channels, max_events_per_channel), dtype=np.int64) + event_counts = np.empty(n_channels, dtype=np.int64) + accumulated_out = np.empty(n_channels, dtype=np.float64) + threshold_out = np.empty(n_channels, dtype=np.float64) + + # Process each channel in parallel + for ch in numba.prange(n_channels): + rates = rates_array[:, ch] + samples, count, acc, thresh = _generate_events_single_channel( + rates, + accumulated[ch], + threshold[ch], + bin_duration, + output_fs, + max_events_per_channel, + ) + all_event_samples[ch, :] = samples + event_counts[ch] = count + accumulated_out[ch] = acc + threshold_out[ch] = thresh + + return all_event_samples, event_counts, accumulated_out, threshold_out + + +@numba.jit(nopython=True, cache=True) +def _flatten_events_unsorted( + all_event_samples: np.ndarray, # (n_channels, max_events) + event_counts: np.ndarray, # (n_channels,) +) -> tuple[np.ndarray, np.ndarray]: + """Flatten per-channel event arrays into coordinate arrays (unsorted).""" + total_events = np.sum(event_counts) + + if total_events == 0: + return np.zeros(0, dtype=np.int64), np.zeros(0, dtype=np.int64) + + event_samples = np.empty(total_events, dtype=np.int64) + event_channels = np.empty(total_events, dtype=np.int64) + + idx = 0 + for ch in range(len(event_counts)): + count = event_counts[ch] + if count > 0: + for i in range(count): + event_samples[idx + i] = all_event_samples[ch, i] + event_channels[idx + i] = ch + idx += count + + return event_samples, event_channels + + +class EventsFromRatesSettings(ez.Settings): + output_fs: float = 30_000 + """Output sampling rate.""" + + layout: str = "coo" + """Layout of the output event train sparse array. Options are 'coo' or 'gcxs'""" + + compress_dims: list[int] | None = None + """Dimensions to compress. Ignored if layout is 'coo'.""" + + assume_counts: bool = False + """If True, input is event counts per bin. If False, input is firing rate in Hz.""" + + min_rate: float = 1e-6 + """Minimum rate to avoid division by zero.""" + + max_rate: float = 500.0 + """Maximum expected firing rate (Hz). Used to pre-allocate event arrays.""" + + +@processor_state +class EventsFromRatesState: + accumulated: npt.NDArray | None = None + """Integrated rate since last event for each channel.""" + + threshold: npt.NDArray | None = None + """Exp(1) threshold for next event for each channel.""" + + +class EventsFromRatesTransformer( + BaseStatefulTransformer[ + EventsFromRatesSettings, AxisArray, AxisArray, EventsFromRatesState + ] +): + def _reset_state(self, message: AxisArray) -> None: + ch_ax = message.get_axis_idx("ch") + n_channels = message.data.shape[ch_ax] + self.state.accumulated = np.zeros(n_channels) + self.state.threshold = np.random.exponential(1.0, size=n_channels) + + def _process(self, message: AxisArray) -> AxisArray: + time_ax = message.get_axis_idx("time") + n_bins = message.data.shape[time_ax] + bin_duration = message.axes["time"].gain + total_samples = n_bins * int(bin_duration * self.settings.output_fs) + + # Get rates array with shape (n_bins, n_channels), contiguous for numba + rates_array = ( + message.data / bin_duration if self.settings.assume_counts else message.data + ) + if time_ax != 0: + rates_array = np.moveaxis(rates_array, time_ax, 0) + rates_array = np.ascontiguousarray( + np.maximum(rates_array, self.settings.min_rate) + ) + n_channels = rates_array.shape[1] + + # Estimate max events per channel based on actual input rates + total_time = n_bins * bin_duration + max_input_rate = np.max(rates_array) + max_events_per_channel = max(int(max_input_rate * total_time * 3) + 10, 20) + + # Generate events using numba (parallel across channels) + all_event_samples, event_counts, accumulated_out, threshold_out = ( + _generate_events_all_channels( + rates_array, + self.state.accumulated, + self.state.threshold, + bin_duration, + self.settings.output_fs, + max_events_per_channel, + ) + ) + + # Update state for next chunk + self.state.accumulated = accumulated_out + self.state.threshold = threshold_out + + # Flatten per-channel arrays into coordinate arrays + event_samples, event_channels = _flatten_events_unsorted( + all_event_samples, event_counts + ) + + # Build sparse array (COO handles sorting internally) + if len(event_samples) > 0: + event_samples = np.clip(event_samples, 0, total_samples - 1) + event_coords = np.vstack([event_samples, event_channels]) + event_data = np.ones(len(event_samples), dtype=np.int8) + else: + event_coords = np.zeros((2, 0), dtype=np.int64) + event_data = np.zeros(0, dtype=np.int8) + + event_array = sparse.COO( + coords=event_coords, + data=event_data, + shape=(total_samples, n_channels), + ) + + if self.settings.layout == "gcxs": + event_array = sparse.GCXS.from_coo( + event_array, compressed_axes=self.settings.compress_dims + ) + + return replace( + message, + data=event_array, + dims=["time", "ch"], + axes={ + **message.axes, + "time": replace(message.axes["time"], gain=1 / self.settings.output_fs), + }, + ) + + +class EventsFromRatesUnit( + BaseTransformerUnit[ + EventsFromRatesSettings, AxisArray, AxisArray, EventsFromRatesTransformer + ] +): + SETTINGS = EventsFromRatesSettings diff --git a/tests/test_eventsfromrates.py b/tests/test_eventsfromrates.py new file mode 100644 index 0000000..b40fc52 --- /dev/null +++ b/tests/test_eventsfromrates.py @@ -0,0 +1,354 @@ +import numpy as np +import pytest +import sparse +from ezmsg.util.messages.axisarray import AxisArray, LinearAxis, CoordinateAxis + +from ezmsg.event.eventsfromrates import ( + EventsFromRatesSettings, + EventsFromRatesTransformer, +) + + +def make_rate_message( + rates: np.ndarray, + bin_duration: float = 0.02, + time_offset: float = 0.0, +) -> AxisArray: + """Create an AxisArray message with firing rates. + + Args: + rates: Array of shape (n_bins, n_channels) with firing rates in Hz. + bin_duration: Duration of each bin in seconds. + time_offset: Time offset for the first bin. + + Returns: + AxisArray with the rates and appropriate axes. + """ + n_channels = rates.shape[1] + fs = 1.0 / bin_duration # Sampling frequency for bins + return AxisArray( + data=rates, + dims=["time", "ch"], + axes={ + "time": LinearAxis.create_time_axis(fs=fs, offset=time_offset), + "ch": CoordinateAxis( + data=np.arange(n_channels).astype(str), dims=["ch"] + ), + }, + ) + + +class TestEventsFromRatesTransformer: + def test_basic_event_generation(self): + """Test that events are generated at approximately the expected rate.""" + np.random.seed(42) + + settings = EventsFromRatesSettings(output_fs=30_000) + transformer = EventsFromRatesTransformer(settings) + + # 100 Hz rate, 10 bins of 20ms = 200ms total + # Expected events: 100 * 0.2 = 20 events per channel + n_bins = 10 + n_channels = 4 + rate = 100.0 + bin_duration = 0.02 + + rates = np.full((n_bins, n_channels), rate) + msg = make_rate_message(rates, bin_duration=bin_duration) + + result = transformer(msg) + + assert isinstance(result.data, sparse.COO) + total_events = result.data.nnz + + # With 4 channels at 100 Hz for 200ms, expect ~80 events total + # Allow reasonable variance (Poisson process) + expected_total = rate * bin_duration * n_bins * n_channels + assert total_events > expected_total * 0.5 + assert total_events < expected_total * 1.5 + + def test_output_shape_and_sampling_rate(self): + """Test that output has correct shape and time axis gain.""" + settings = EventsFromRatesSettings(output_fs=30_000) + transformer = EventsFromRatesTransformer(settings) + + n_bins = 5 + n_channels = 8 + bin_duration = 0.02 + + rates = np.full((n_bins, n_channels), 50.0) + msg = make_rate_message(rates, bin_duration=bin_duration) + + result = transformer(msg) + + # Expected samples: n_bins * bin_duration * output_fs + expected_samples = int(n_bins * bin_duration * 30_000) + assert result.data.shape == (expected_samples, n_channels) + assert result.axes["time"].gain == pytest.approx(1 / 30_000) + + def test_low_rate_generates_events_over_time(self): + """Test that low rates eventually generate events across multiple chunks.""" + np.random.seed(123) + + settings = EventsFromRatesSettings(output_fs=30_000) + transformer = EventsFromRatesTransformer(settings) + + # 10 Hz rate with 20ms bins: expected 0.2 events per bin per channel + # Over 50 bins (1 second), expect ~10 events per channel + n_bins = 50 + n_channels = 2 + rate = 10.0 + bin_duration = 0.02 + + rates = np.full((n_bins, n_channels), rate) + msg = make_rate_message(rates, bin_duration=bin_duration) + + result = transformer(msg) + + # Should have some events despite low rate + assert result.data.nnz > 0 + + # Check approximately correct rate + total_time = n_bins * bin_duration + expected_events = rate * total_time * n_channels + assert result.data.nnz > expected_events * 0.3 + assert result.data.nnz < expected_events * 2.0 + + def test_state_persistence_across_chunks(self): + """Test that state carries over correctly between processing calls.""" + np.random.seed(456) + + settings = EventsFromRatesSettings(output_fs=30_000) + transformer = EventsFromRatesTransformer(settings) + + n_channels = 4 + rate = 50.0 + bin_duration = 0.02 + n_bins_per_chunk = 5 + + total_events = 0 + for chunk_idx in range(10): + rates = np.full((n_bins_per_chunk, n_channels), rate) + time_offset = chunk_idx * n_bins_per_chunk * bin_duration + msg = make_rate_message(rates, bin_duration=bin_duration, time_offset=time_offset) + + result = transformer(msg) + total_events += result.data.nnz + + # Total time: 10 chunks * 5 bins * 20ms = 1 second + # Expected: 50 Hz * 1s * 4 channels = 200 events + expected_total = 200 + assert total_events > expected_total * 0.6 + assert total_events < expected_total * 1.5 + + def test_rate_change_low_to_high(self): + """Test the key scenario: low rate followed by high rate.""" + np.random.seed(789) + + settings = EventsFromRatesSettings(output_fs=30_000) + transformer = EventsFromRatesTransformer(settings) + + n_channels = 10 + bin_duration = 0.02 + + # First chunk: very low rate (should accumulate progress toward events) + low_rates = np.full((5, n_channels), 1.0) # 1 Hz + msg1 = make_rate_message(low_rates, bin_duration=bin_duration, time_offset=0.0) + result1 = transformer(msg1) + + # Second chunk: high rate (should trigger events quickly) + high_rates = np.full((5, n_channels), 100.0) # 100 Hz + msg2 = make_rate_message(high_rates, bin_duration=bin_duration, time_offset=0.1) + result2 = transformer(msg2) + + # The high-rate chunk should have many events + # Expected for 100 Hz * 0.1s * 10 channels = 100 events + assert result2.data.nnz > 50 # At least half expected + + # The low-rate chunk might have few or no events + # 1 Hz * 0.1s * 10 channels = 1 event expected + assert result1.data.nnz < 20 # Should be very few + + def test_rate_change_high_to_low(self): + """Test high rate followed by low rate - accumulated state should persist.""" + np.random.seed(101) + + settings = EventsFromRatesSettings(output_fs=30_000) + transformer = EventsFromRatesTransformer(settings) + + n_channels = 8 + bin_duration = 0.02 + + # First chunk: high rate + high_rates = np.full((5, n_channels), 100.0) + msg1 = make_rate_message(high_rates, bin_duration=bin_duration, time_offset=0.0) + result1 = transformer(msg1) + + # Store the state after high-rate processing + accumulated_after_high = transformer.state.accumulated.copy() + + # Second chunk: very low rate + low_rates = np.full((50, n_channels), 0.1) # 0.1 Hz - very low + msg2 = make_rate_message(low_rates, bin_duration=bin_duration, time_offset=0.1) + result2 = transformer(msg2) + + # High-rate chunk should have many events + assert result1.data.nnz > 20 + + # Low-rate chunk: some channels might event due to accumulated progress + # from the high-rate chunk, but overall should be sparse + # 0.1 Hz * 1s * 8 channels = 0.8 events expected from rate alone + # But accumulated state might cause a few more + assert result2.data.nnz < 30 + + def test_zero_rate_no_events(self): + """Test that zero rate produces no events.""" + settings = EventsFromRatesSettings(output_fs=30_000, min_rate=1e-10) + transformer = EventsFromRatesTransformer(settings) + + n_bins = 10 + n_channels = 4 + bin_duration = 0.02 + + # Use min_rate (effectively zero) + rates = np.full((n_bins, n_channels), 1e-10) + msg = make_rate_message(rates, bin_duration=bin_duration) + + result = transformer(msg) + + # Should have no events (or extremely few due to numerical precision) + assert result.data.nnz == 0 + + def test_gcxs_layout(self): + """Test that GCXS layout option works correctly.""" + np.random.seed(202) + + settings = EventsFromRatesSettings( + output_fs=30_000, + layout="gcxs", + compress_dims=[0], + ) + transformer = EventsFromRatesTransformer(settings) + + rates = np.full((5, 4), 50.0) + msg = make_rate_message(rates, bin_duration=0.02) + + result = transformer(msg) + + assert isinstance(result.data, sparse.GCXS) + + def test_assume_counts_mode(self): + """Test that assume_counts correctly interprets input as event counts.""" + np.random.seed(303) + + settings = EventsFromRatesSettings(output_fs=30_000, assume_counts=True) + transformer = EventsFromRatesTransformer(settings) + + n_bins = 10 + n_channels = 4 + bin_duration = 0.02 + + # Input as counts: 2 events per bin = 100 Hz rate + counts = np.full((n_bins, n_channels), 2.0) + msg = make_rate_message(counts, bin_duration=bin_duration) + + result = transformer(msg) + + # Expected rate = 2 / 0.02 = 100 Hz + # Expected events = 100 * 0.2 * 4 = 80 + expected = 80 + assert result.data.nnz > expected * 0.5 + assert result.data.nnz < expected * 1.5 + + def test_event_times_within_bounds(self): + """Test that all event times fall within valid sample range.""" + np.random.seed(404) + + settings = EventsFromRatesSettings(output_fs=30_000) + transformer = EventsFromRatesTransformer(settings) + + n_bins = 5 + n_channels = 8 + bin_duration = 0.02 + + rates = np.full((n_bins, n_channels), 200.0) # High rate + msg = make_rate_message(rates, bin_duration=bin_duration) + + result = transformer(msg) + + total_samples = int(n_bins * bin_duration * 30_000) + + # All time coordinates should be within bounds + if result.data.nnz > 0: + time_coords = result.data.coords[0] + assert np.all(time_coords >= 0) + assert np.all(time_coords < total_samples) + + ch_coords = result.data.coords[1] + assert np.all(ch_coords >= 0) + assert np.all(ch_coords < n_channels) + + def test_channel_axis_position(self): + """Test that channel axis in non-standard position is handled correctly.""" + np.random.seed(505) + + settings = EventsFromRatesSettings(output_fs=30_000) + transformer = EventsFromRatesTransformer(settings) + + n_bins = 5 + n_channels = 4 + bin_duration = 0.02 + fs = 1.0 / bin_duration + + # Create message with channel axis first (transposed) + rates = np.full((n_channels, n_bins), 50.0) + msg = AxisArray( + data=rates, + dims=["ch", "time"], + axes={ + "time": LinearAxis.create_time_axis(fs=fs, offset=0.0), + "ch": CoordinateAxis( + data=np.arange(n_channels).astype(str), dims=["ch"] + ), + }, + ) + + result = transformer(msg) + + expected_samples = int(n_bins * bin_duration * 30_000) + assert result.data.shape == (expected_samples, n_channels) + + def test_statistical_properties(self): + """Test that event intervals follow expected exponential distribution.""" + np.random.seed(606) + + settings = EventsFromRatesSettings(output_fs=30_000) + transformer = EventsFromRatesTransformer(settings) + + # Long duration to get good statistics + n_bins = 100 + n_channels = 1 + rate = 50.0 + bin_duration = 0.02 + + rates = np.full((n_bins, n_channels), rate) + msg = make_rate_message(rates, bin_duration=bin_duration) + + result = transformer(msg) + + # Extract event times for the single channel + event_samples = result.data.coords[0] + event_times = event_samples / settings.output_fs + + if len(event_times) > 10: + # Calculate inter-event intervals + sorted_times = np.sort(event_times) + isis = np.diff(sorted_times) + + # Mean ISI should be approximately 1/rate + expected_mean_isi = 1 / rate + actual_mean_isi = np.mean(isis) + + # Allow 50% tolerance due to finite sample size + assert actual_mean_isi > expected_mean_isi * 0.5 + assert actual_mean_isi < expected_mean_isi * 2.0 From 98c8925303e94e8c38db20007dbe43bb69f4c5ba Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Tue, 2 Dec 2025 00:36:08 -0500 Subject: [PATCH 02/10] bump deps --- .gitignore | 1 + pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 3b8b86f..ce62c06 100644 --- a/.gitignore +++ b/.gitignore @@ -163,3 +163,4 @@ cython_debug/ src/ezmsg/event/__version__.py uv.lock +*.local.json diff --git a/pyproject.toml b/pyproject.toml index f4150d6..6b132f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ readme = "README.md" requires-python = ">=3.10.15" dynamic = ["version"] dependencies = [ - "ezmsg-sigproc>=2.2.0", + "ezmsg-sigproc>=2.3.0", "ezmsg>=3.6.1", "sparse>=0.17.0", "numpy>=2.2.6", From d469adf72900235d1ae13de40315a01f5eeeb304 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Tue, 2 Dec 2025 00:36:59 -0500 Subject: [PATCH 03/10] ruff --- profile_rate.py | 48 +++++++++++++++++++----------- src/ezmsg/event/eventsfromrates.py | 6 +++- tests/test_eventsfromrates.py | 10 +++---- 3 files changed, 41 insertions(+), 23 deletions(-) diff --git a/profile_rate.py b/profile_rate.py index 9375762..9faf075 100644 --- a/profile_rate.py +++ b/profile_rate.py @@ -1,17 +1,26 @@ import cProfile -import pstats + +# import pstats import numpy as np import sparse from ezmsg.util.messages.axisarray import AxisArray -from ezmsg.event.rate import Rate, EventRateSettings -import ezmsg.core as ez +from ezmsg.event.rate import EventRate, EventRateSettings + # Simulate input data def generate_sparse_data(num_samples, num_channels, sparsity_factor, rng): - data = sparse.random((num_samples, num_channels), density=sparsity_factor, random_state=rng) > 0 + data = ( + sparse.random( + (num_samples, num_channels), density=sparsity_factor, random_state=rng + ) + > 0 + ) return data -def run_rate_processor(num_samples, num_channels, sparsity_factor, bin_duration, chunk_dur, fs): + +def run_rate_processor( + num_samples, num_channels, sparsity_factor, bin_duration, chunk_dur, fs +): rng = np.random.default_rng() s = generate_sparse_data(num_samples, num_channels, sparsity_factor, rng) @@ -30,7 +39,7 @@ def run_rate_processor(num_samples, num_channels, sparsity_factor, bin_duration, ] settings = EventRateSettings(bin_duration=bin_duration) - rate_processor = Rate(settings) + rate_processor = EventRate(settings) output_messages = [] for in_msg in in_msgs: @@ -38,27 +47,32 @@ def run_rate_processor(num_samples, num_channels, sparsity_factor, bin_duration, return output_messages + if __name__ == "__main__": NUM_SAMPLES = 300_000 # Number of time samples (e.g., 10 seconds at 30kHz) - NUM_CHANNELS = 128 # Number of channels - SPARSITY_FACTOR = 0.0001 # 0.01% sparse, as in test_rate.py - BIN_DURATION = 0.03 # 30ms bin duration, as in test_rate.py - CHUNK_DURATION = 0.1 # 100ms chunk duration, as in test_rate.py - FS = 30000.0 # Sampling frequency, as in test_rate.py + NUM_CHANNELS = 128 # Number of channels + SPARSITY_FACTOR = 0.0001 # 0.01% sparse, as in test_rate.py + BIN_DURATION = 0.03 # 30ms bin duration, as in test_rate.py + CHUNK_DURATION = 0.1 # 100ms chunk duration, as in test_rate.py + FS = 30000.0 # Sampling frequency, as in test_rate.py - print(f"Profiling with: samples={NUM_SAMPLES}, channels={NUM_CHANNELS}, sparsity={SPARSITY_FACTOR}, bin_duration={BIN_DURATION}, chunk_duration={CHUNK_DURATION}, fs={FS}") + print( + f"Profiling with: samples={NUM_SAMPLES}, channels={NUM_CHANNELS}, sparsity={SPARSITY_FACTOR}, bin_duration={BIN_DURATION}, chunk_duration={CHUNK_DURATION}, fs={FS}" + ) # Run with cProfile profiler = cProfile.Profile() profiler.enable() - - run_rate_processor(NUM_SAMPLES, NUM_CHANNELS, SPARSITY_FACTOR, BIN_DURATION, CHUNK_DURATION, FS) - + + run_rate_processor( + NUM_SAMPLES, NUM_CHANNELS, SPARSITY_FACTOR, BIN_DURATION, CHUNK_DURATION, FS + ) + profiler.disable() - + # Save stats to a file in a format snakeviz can read stats_file = "rate_profile.prof" profiler.dump_stats(stats_file) print(f"Profiling results saved to {stats_file}") - print("To visualize, run: uv run snakeviz rate_profile.prof") \ No newline at end of file + print("To visualize, run: uv run snakeviz rate_profile.prof") diff --git a/src/ezmsg/event/eventsfromrates.py b/src/ezmsg/event/eventsfromrates.py index b88dfbd..d2ebe32 100644 --- a/src/ezmsg/event/eventsfromrates.py +++ b/src/ezmsg/event/eventsfromrates.py @@ -3,7 +3,11 @@ import numpy as np import numpy.typing as npt import sparse -from ezmsg.sigproc.base import BaseStatefulTransformer, BaseTransformerUnit, processor_state +from ezmsg.sigproc.base import ( + BaseStatefulTransformer, + BaseTransformerUnit, + processor_state, +) from ezmsg.util.messages.axisarray import AxisArray from ezmsg.util.messages.util import replace diff --git a/tests/test_eventsfromrates.py b/tests/test_eventsfromrates.py index b40fc52..c4d189c 100644 --- a/tests/test_eventsfromrates.py +++ b/tests/test_eventsfromrates.py @@ -31,9 +31,7 @@ def make_rate_message( dims=["time", "ch"], axes={ "time": LinearAxis.create_time_axis(fs=fs, offset=time_offset), - "ch": CoordinateAxis( - data=np.arange(n_channels).astype(str), dims=["ch"] - ), + "ch": CoordinateAxis(data=np.arange(n_channels).astype(str), dims=["ch"]), }, ) @@ -130,7 +128,9 @@ def test_state_persistence_across_chunks(self): for chunk_idx in range(10): rates = np.full((n_bins_per_chunk, n_channels), rate) time_offset = chunk_idx * n_bins_per_chunk * bin_duration - msg = make_rate_message(rates, bin_duration=bin_duration, time_offset=time_offset) + msg = make_rate_message( + rates, bin_duration=bin_duration, time_offset=time_offset + ) result = transformer(msg) total_events += result.data.nnz @@ -185,7 +185,7 @@ def test_rate_change_high_to_low(self): result1 = transformer(msg1) # Store the state after high-rate processing - accumulated_after_high = transformer.state.accumulated.copy() + _ = transformer.state.accumulated.copy() # Second chunk: very low rate low_rates = np.full((50, n_channels), 0.1) # 0.1 Hz - very low From 769b98fb828f77a6bda429ac6b97232d151787b1 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Tue, 2 Dec 2025 00:41:55 -0500 Subject: [PATCH 04/10] Fixup refractory --- src/ezmsg/event/refractory.py | 84 +++++++++++-- tests/test_refractory.py | 223 ++++++++++++++++++++++++++++++++++ 2 files changed, 295 insertions(+), 12 deletions(-) create mode 100644 tests/test_refractory.py diff --git a/src/ezmsg/event/refractory.py b/src/ezmsg/event/refractory.py index ea5db84..c781629 100644 --- a/src/ezmsg/event/refractory.py +++ b/src/ezmsg/event/refractory.py @@ -1,8 +1,8 @@ - import numpy as np import numpy.typing as npt +import sparse import ezmsg.core as ez -from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis +from ezmsg.util.messages.axisarray import AxisArray, replace from ezmsg.sigproc.base import BaseStatefulTransformer, processor_state @@ -29,25 +29,53 @@ def _reset_state(self, message: AxisArray) -> None: fs = 1 / message.axes["time"].gain self._state.width = int(self.settings.dur * fs) ax_idx = message.get_axis_idx("time") - first_samp = slice_along_axis(message.data, slice(None, 1, None), ax_idx) - self._state.elapsed = np.zeros(first_samp.shape, dtype=int) + ( - self._state.width + 1 - ) + # Get the shape of features (all dims except time) + feat_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :] + n_feats = int(np.prod(feat_shape)) + self._state.elapsed = np.zeros((n_feats,), dtype=int) + (self._state.width + 1) def _process(self, message: AxisArray) -> AxisArray: if self._state.width <= 2: return message - # TODO: Get the sparse indices of the message.data + ax_idx = message.get_axis_idx("time") + n_samps = message.data.shape[ax_idx] - if len(samp_idx) <= 0: + # Get the sparse indices of the message.data + # coords is a tuple of arrays, one per dimension + coords = message.data.coords + if coords.shape[1] == 0: + # No events, update elapsed and return + self._state.elapsed += n_samps return message - uq_feats, feat_splits = np.unique(cross_idx[0], return_index=True) + # Separate time indices from feature indices + samp_idx = coords[ax_idx] + feat_dims = list(range(message.data.ndim)) + feat_dims.pop(ax_idx) + feat_coords = tuple(coords[d] for d in feat_dims) + + # Ravel feature indices to 1D for tracking + feat_shape = tuple(message.data.shape[d] for d in feat_dims) + if len(feat_coords) > 0: + ravel_feat_inds = np.ravel_multi_index(feat_coords, feat_shape) + else: + ravel_feat_inds = np.zeros(len(samp_idx), dtype=int) + + # Sort by feature then by time to process events in order + sort_order = np.lexsort((samp_idx, ravel_feat_inds)) + samp_idx = samp_idx[sort_order] + ravel_feat_inds = ravel_feat_inds[sort_order] + feat_coords = tuple(fc[sort_order] for fc in feat_coords) + + # Create cross_idx as list with feature coords first, then time + cross_idx = list(feat_coords) + [samp_idx] + + uq_feats, feat_splits = np.unique(ravel_feat_inds, return_index=True) ieis = np.diff(np.hstack(([samp_idx[0] + 1], samp_idx))) # Reset elapsed time at feature boundaries. ieis[feat_splits] = samp_idx[feat_splits] + self._state.elapsed[uq_feats] - b_drop = ieis <= self._state.refrac_width + b_drop = ieis <= self._state.width drop_idx = np.where(b_drop)[0] final_drop = [] while len(drop_idx) > 0: @@ -62,8 +90,40 @@ def _process(self, message: AxisArray) -> AxisArray: drop_idx = drop_idx[1:] # If the next event is now outside the refractory period then it will not be dropped. - if len(drop_idx) > 0 and ieis[drop_idx[0]] > self._state.refrac_width: + if len(drop_idx) > 0 and ieis[drop_idx[0]] > self._state.width: drop_idx = drop_idx[1:] samp_idx = np.delete(samp_idx, final_drop) - cross_idx = tuple(np.delete(_, final_drop) for _ in cross_idx) + cross_idx = [np.delete(_, final_drop) for _ in cross_idx] + ravel_feat_inds = np.delete(ravel_feat_inds, final_drop) + + # Update elapsed state for all features + self._state.elapsed += n_samps + # For features that had events, set elapsed to time since last event + if len(samp_idx) > 0: + # Get the last event time for each feature that had events + uq_final_feats, last_idx = np.unique( + ravel_feat_inds[::-1], return_index=True + ) + last_idx = len(ravel_feat_inds) - 1 - last_idx + last_samps = samp_idx[last_idx] + self._state.elapsed[uq_final_feats] = n_samps - last_samps + + # Build output coordinates in original dimension order + out_coords = [None] * message.data.ndim + for i, d in enumerate(feat_dims): + out_coords[d] = cross_idx[i] + out_coords[ax_idx] = cross_idx[-1] + + # Get the values for kept events + kept_mask = np.ones(coords.shape[1], dtype=bool) + kept_mask[sort_order[final_drop]] = False + result_data = message.data.data[kept_mask] + + result = sparse.COO( + out_coords, + data=result_data, + shape=message.data.shape, + ) + + return replace(message, data=result) diff --git a/tests/test_refractory.py b/tests/test_refractory.py new file mode 100644 index 0000000..af89703 --- /dev/null +++ b/tests/test_refractory.py @@ -0,0 +1,223 @@ +import numpy as np +import pytest +import sparse +from ezmsg.util.messages.axisarray import AxisArray + +from ezmsg.event.refractory import RefractoryTransformer + + +class TestRefractoryTransformer: + """Test suite for RefractoryTransformer.""" + + def test_no_refractory_passthrough(self): + """When refractory duration is 0, events should pass through unchanged.""" + fs = 1000.0 + n_samples = 100 + n_chans = 4 + + # Create sparse data with events + coords = [[0, 1, 2, 3], [10, 20, 30, 40]] # ch, time + data = np.array([True, True, True, True]) + sparse_data = sparse.COO(coords, data, shape=(n_chans, n_samples)) + + msg = AxisArray( + data=sparse_data, + dims=["ch", "time"], + axes={"time": AxisArray.Axis.TimeAxis(fs=fs, offset=0.0)}, + key="test_passthrough", + ) + + transformer = RefractoryTransformer(dur=0.0) + result = transformer(msg) + + assert result.data.nnz == 4 # All events should pass through + + def test_refractory_filters_close_events(self): + """Events within refractory period should be filtered out.""" + fs = 1000.0 + n_samples = 100 + n_chans = 1 + refrac_dur = 0.010 # 10ms = 10 samples at 1kHz + + # Create sparse data with events: samples 10, 15, 30, 35 + # Events at 10 and 30 should pass, events at 15 and 35 should be filtered + coords = [[0, 0, 0, 0], [10, 15, 30, 35]] # ch, time + data = np.array([True, True, True, True]) + sparse_data = sparse.COO(coords, data, shape=(n_chans, n_samples)) + + msg = AxisArray( + data=sparse_data, + dims=["ch", "time"], + axes={"time": AxisArray.Axis.TimeAxis(fs=fs, offset=0.0)}, + key="test_filter", + ) + + transformer = RefractoryTransformer(dur=refrac_dur) + result = transformer(msg) + + # Events at 15 and 35 are within 10 samples of 10 and 30, so should be filtered + assert result.data.nnz == 2 + result_times = result.data.coords[1] + assert 10 in result_times + assert 30 in result_times + assert 15 not in result_times + assert 35 not in result_times + + def test_refractory_independent_channels(self): + """Refractory period should be enforced independently per channel.""" + fs = 1000.0 + n_samples = 50 + n_chans = 2 + refrac_dur = 0.010 # 10ms = 10 samples + + # Channel 0: events at 10, 15 (15 should be filtered) + # Channel 1: events at 10, 15 (15 should be filtered) + coords = [[0, 0, 1, 1], [10, 15, 10, 15]] + data = np.array([True, True, True, True]) + sparse_data = sparse.COO(coords, data, shape=(n_chans, n_samples)) + + msg = AxisArray( + data=sparse_data, + dims=["ch", "time"], + axes={"time": AxisArray.Axis.TimeAxis(fs=fs, offset=0.0)}, + key="test_channels", + ) + + transformer = RefractoryTransformer(dur=refrac_dur) + result = transformer(msg) + + # Each channel should have only one event (at sample 10) + assert result.data.nnz == 2 + # Both channels should have an event at sample 10 + ch_inds = result.data.coords[0] + time_inds = result.data.coords[1] + assert np.all(time_inds == 10) + assert set(ch_inds) == {0, 1} + + def test_refractory_across_chunks(self): + """Refractory period should be enforced across message boundaries.""" + fs = 1000.0 + n_samples = 50 + n_chans = 1 + refrac_dur = 0.020 # 20ms = 20 samples + + transformer = RefractoryTransformer(dur=refrac_dur) + + # First chunk: event at sample 45 + coords1 = [[0], [45]] + data1 = np.array([True]) + sparse_data1 = sparse.COO(coords1, data1, shape=(n_chans, n_samples)) + msg1 = AxisArray( + data=sparse_data1, + dims=["ch", "time"], + axes={"time": AxisArray.Axis.TimeAxis(fs=fs, offset=0.0)}, + key="test_across_chunks", + ) + + # Second chunk: event at sample 5 (which is 5 samples after chunk boundary) + # Total elapsed since last event would be 5 + (50-45) = 10 samples < 20 refrac + coords2 = [[0], [5]] + data2 = np.array([True]) + sparse_data2 = sparse.COO(coords2, data2, shape=(n_chans, n_samples)) + msg2 = AxisArray( + data=sparse_data2, + dims=["ch", "time"], + axes={"time": AxisArray.Axis.TimeAxis(fs=fs, offset=0.050)}, + key="test_across_chunks", + ) + + result1 = transformer(msg1) + result2 = transformer(msg2) + + # First chunk event should pass + assert result1.data.nnz == 1 + # Second chunk event should be filtered (only 10 samples since last event) + assert result2.data.nnz == 0 + + def test_empty_sparse_message(self): + """Empty sparse messages should pass through without errors.""" + fs = 1000.0 + n_samples = 100 + n_chans = 4 + refrac_dur = 0.010 + + # Create empty sparse data + coords = [[], []] + data = np.array([], dtype=bool) + sparse_data = sparse.COO(coords, data, shape=(n_chans, n_samples)) + + msg = AxisArray( + data=sparse_data, + dims=["ch", "time"], + axes={"time": AxisArray.Axis.TimeAxis(fs=fs, offset=0.0)}, + key="test_empty", + ) + + transformer = RefractoryTransformer(dur=refrac_dur) + result = transformer(msg) + + assert result.data.nnz == 0 + + def test_cascade_filtering(self): + """Test that filtering cascades correctly when multiple events are close.""" + fs = 1000.0 + n_samples = 100 + n_chans = 1 + refrac_dur = 0.010 # 10 samples + + # Events at 10, 15, 18, 30 + # 10 passes, 15 filtered (5 from 10), 18 should be checked against 10 (not 15) + # 18 is 8 from 10, so filtered. 30 passes (20 from 10) + coords = [[0, 0, 0, 0], [10, 15, 18, 30]] + data = np.array([True, True, True, True]) + sparse_data = sparse.COO(coords, data, shape=(n_chans, n_samples)) + + msg = AxisArray( + data=sparse_data, + dims=["ch", "time"], + axes={"time": AxisArray.Axis.TimeAxis(fs=fs, offset=0.0)}, + key="test_cascade", + ) + + transformer = RefractoryTransformer(dur=refrac_dur) + result = transformer(msg) + + assert result.data.nnz == 2 + result_times = result.data.coords[1] + assert 10 in result_times + assert 30 in result_times + + @pytest.mark.parametrize("time_axis_position", [0, 1]) + def test_different_time_axis_positions(self, time_axis_position: int): + """Test that the transformer works with time axis in different positions.""" + fs = 1000.0 + n_samples = 100 + n_chans = 4 + refrac_dur = 0.010 + + if time_axis_position == 0: + # time, ch + coords = [[10, 20], [0, 1]] + shape = (n_samples, n_chans) + dims = ["time", "ch"] + else: + # ch, time + coords = [[0, 1], [10, 20]] + shape = (n_chans, n_samples) + dims = ["ch", "time"] + + data = np.array([True, True]) + sparse_data = sparse.COO(coords, data, shape=shape) + + msg = AxisArray( + data=sparse_data, + dims=dims, + axes={"time": AxisArray.Axis.TimeAxis(fs=fs, offset=0.0)}, + key="test_time_axis", + ) + + transformer = RefractoryTransformer(dur=refrac_dur) + result = transformer(msg) + + # Both events are far enough apart to pass + assert result.data.nnz == 2 From 2dff9845e87ab0d47ec8824b2c842e0c16540d2f Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Fri, 8 Aug 2025 20:08:37 -0400 Subject: [PATCH 05/10] waypoint checkin for rate calculation refactor --- src/ezmsg/event/__init__.py | 3 + src/ezmsg/event/aggregate.py | 36 ++++++++++ src/ezmsg/event/rate.py | 136 +++++++++++++++++------------------ tests/test_rate.py | 46 +++++++----- 4 files changed, 131 insertions(+), 90 deletions(-) create mode 100644 src/ezmsg/event/aggregate.py diff --git a/src/ezmsg/event/__init__.py b/src/ezmsg/event/__init__.py index 7eb56b4..7305478 100644 --- a/src/ezmsg/event/__init__.py +++ b/src/ezmsg/event/__init__.py @@ -1 +1,4 @@ from .__version__ import __version__ as __version__ + +from .rate import Rate, EventRate +from .aggregate import Aggregate \ No newline at end of file diff --git a/src/ezmsg/event/aggregate.py b/src/ezmsg/event/aggregate.py new file mode 100644 index 0000000..584b56e --- /dev/null +++ b/src/ezmsg/event/aggregate.py @@ -0,0 +1,36 @@ +import typing + +import ezmsg.core as ez +from array_api_compat import get_namespace + +from ezmsg.sigproc.base import BaseTransformer +from ezmsg.util.messages.axisarray import AxisArray, replace + + +class AggregateSettings(ez.Settings): + axis: str + operator: typing.Literal["sum", "mean"] = "sum" + + +class Aggregate(BaseTransformer[AggregateSettings, AxisArray, AxisArray]): + # TODO: Move this to ezmsg-sigproc.aggregate module + def _process(self, message: AxisArray) -> AxisArray: + xp = get_namespace(message.data) + axis_idx = message.get_axis_idx(self.settings.axis) + + agg_op = getattr(xp, self.settings.operator) + agg_data = agg_op(message.data, axis=axis_idx) + + new_dims = list(message.dims) + new_dims.pop(axis_idx) + + new_axes = message.axes.copy() + if self.settings.axis in new_axes: + del new_axes[self.settings.axis] + + return replace( + message, + data=agg_data, + dims=new_dims, + axes=new_axes, + ) \ No newline at end of file diff --git a/src/ezmsg/event/rate.py b/src/ezmsg/event/rate.py index 4140eec..eaf011e 100644 --- a/src/ezmsg/event/rate.py +++ b/src/ezmsg/event/rate.py @@ -1,81 +1,75 @@ -""" -Count number of events in a given time window. Optionally, divide by window duration to get rate. -""" - -from dataclasses import replace import typing -import numpy as np import ezmsg.core as ez -from ezmsg.util.messages.axisarray import AxisArray -from ezmsg.util.generator import consumer -from ezmsg.sigproc.base import GenAxisArray -from ezmsg.sigproc.window import windowing - - -@consumer -def event_rate( - bin_duration: float = 0.05, -) -> typing.Generator[AxisArray, AxisArray, None]: - """ - - - Args: - bin_duration: - - Returns: - A primed generator object that yields an :obj:`AxisArray` object of event rates for every - :obj:`AxisArray` of sparse events it receives via `send`. - """ - msg_out = AxisArray(np.array([]), dims=[""]) - - win_proc = windowing( - axis="time", - newaxis="win", - window_dur=bin_duration, - window_shift=bin_duration, - zero_pad_until="none", - ) - out_dims: list[str] | None = None - out_axes: dict[str, AxisArray.Axis] | None = None - - while True: - msg_in: AxisArray = yield msg_out - - win_msg = win_proc.send(msg_in) - - b_reset = out_dims is None - if b_reset: - # Fixup `dims` - out_dims = list(win_msg.dims) - out_dims.remove("time") - out_dims[out_dims.index("win")] = "time" - # Fixup axes - out_axes = {k: v for k, v in win_msg.axes.items() if k != "time"} - - # Sum over time - time_ax = win_msg.get_axis_idx("time") - counts_per_bin = np.sum(win_msg.data, axis=time_ax) - # Scale by 1 / bin_duration to get rates - rates_per_bin = counts_per_bin / bin_duration - # Densify - rates_per_bin = rates_per_bin.todense() - msg_out = replace( - win_msg, - data=rates_per_bin, - dims=out_dims, - axes={**out_axes, "time": win_msg.axes["win"]}, - ) +from ezmsg.sigproc.base import ( + BaseTransformer, + CompositeProcessor, + BaseTransformerUnit, +) +from ezmsg.sigproc.window import WindowTransformer, WindowSettings +from ezmsg.util.messages.axisarray import AxisArray, replace + +from .aggregate import Aggregate, AggregateSettings class EventRateSettings(ez.Settings): bin_duration: float = 0.05 -class EventRate(GenAxisArray): - SETTINGS = EventRateSettings +class DensifyAndScaleSettings(ez.Settings): + scale: float = 1.0 + + +class DensifyAndScale(BaseTransformer[DensifyAndScaleSettings, AxisArray, AxisArray]): + def _process(self, message: AxisArray) -> AxisArray: + return replace(message, data=(message.data.todense() * self.settings.scale)) + + +class RenameAxisSettings(ez.Settings): + old_axis: str + new_axis: str + - def construct_generator(self): - self.STATE.gen = event_rate( - bin_duration=self.SETTINGS.bin_duration, - ) +class RenameAxis(BaseTransformer[RenameAxisSettings, AxisArray, AxisArray]): + def _process(self, message: AxisArray) -> AxisArray: + new_dims = list(message.dims) + new_axes = dict(message.axes) + + if self.settings.old_axis in new_dims: + idx = new_dims.index(self.settings.old_axis) + new_dims[idx] = self.settings.new_axis + if self.settings.old_axis in new_axes: + new_axes[self.settings.new_axis] = new_axes.pop( + self.settings.old_axis + ) + + return replace(message, dims=new_dims, axes=new_axes) + + +class Rate(CompositeProcessor[EventRateSettings, AxisArray, AxisArray]): + @staticmethod + def _initialize_processors( + settings: EventRateSettings, + ) -> dict[str, typing.Any]: + return { + "window": WindowTransformer( + WindowSettings( + axis="time", + newaxis="win", + window_dur=settings.bin_duration, + window_shift=settings.bin_duration, + zero_pad_until="none", + ) + ), + "aggregate": Aggregate(AggregateSettings(axis="time", operator="sum")), + "rename": RenameAxis( + RenameAxisSettings(old_axis="win", new_axis="time") + ), + "densify_and_scale": DensifyAndScale( + DensifyAndScaleSettings(scale=1.0 / settings.bin_duration) + ), + } + + +class EventRate(BaseTransformerUnit[EventRateSettings, AxisArray, AxisArray, Rate]): + SETTINGS = EventRateSettings diff --git a/tests/test_rate.py b/tests/test_rate.py index dae57d8..4825576 100644 --- a/tests/test_rate.py +++ b/tests/test_rate.py @@ -1,15 +1,17 @@ +import time + import numpy as np import sparse from ezmsg.util.messages.axisarray import AxisArray -from ezmsg.event.rate import event_rate +from ezmsg.event.rate import Rate, EventRateSettings -def test_event_rate(): +def test_event_rate_composite(): dur = 1.0 - fs = 30_0000.0 + fs = 30_000.0 chunk_dur = 0.1 - bin_dur = 0.05 + bin_dur = 0.03 nchans = 128 chunk_len = int(fs * chunk_dur) nchunk = int(dur / chunk_dur) @@ -24,29 +26,35 @@ def test_event_rate(): axes={ "time": AxisArray.Axis.TimeAxis(fs=fs, offset=chunk_ix * chunk_dur), }, - key="test_event_rate", ) for chunk_ix in range(nchunk) ] - proc = event_rate(bin_duration=bin_dur) + proc = Rate(settings=EventRateSettings(bin_duration=bin_dur)) + + # Calculate the first message which sometimes takes longer due to initialization + out_msgs = [proc(in_msgs[0])] - out_msgs = [proc.send(in_msg) for in_msg in in_msgs] + assert out_msgs[0].data.shape[0] == int(chunk_dur / bin_dur) + + # Calculate the remaining messages within perf_counters and assert they are processed quickly + t_start = time.perf_counter() + out_msgs.extend([proc(in_msg) for in_msg in in_msgs[1:]]) + t_elapsed = time.perf_counter() - t_start assert len(out_msgs) == nchunk - # Note: bin_dur != chunk_dur so we expect a different number of bins (len time axis) than chunks + # assert t_elapsed < (dur - chunk_dur) # Ensure processing is fast enough + + n_bins_seen = 0 for om_ix, om in enumerate(out_msgs): - assert om.key == "test_event_rate" assert om.dims == ["time", "ch"] - assert om.data.shape == ( - int(chunk_dur / bin_dur), - nchans, - ) # Only works if even multiple - assert om.axes["time"].gain == bin_dur - assert om.axes["time"].offset == om_ix * chunk_dur + assert np.isclose(om.axes["time"].gain, bin_dur) + assert np.isclose(om.axes["time"].offset, n_bins_seen * bin_dur) + n_bins_seen += om.shape[0] stack = AxisArray.concatenate(*out_msgs, dim="time") - expected = ( - np.sum(s.todense().reshape(-1, int(fs * bin_dur), nchans), axis=1) / bin_dur - ) + t_proc = n_bins_seen * bin_dur + samp_proc = int(t_proc * fs) + s_proc = s[:samp_proc].todense().reshape(-1, int(fs * bin_dur), nchans) + expected = np.sum(s_proc, axis=1) / bin_dur assert stack.data.shape == expected.shape - assert np.allclose(stack.data, expected) + assert np.allclose(stack.data, expected) \ No newline at end of file From 5cbf67ab94fed1913a2258eab21b6060f6005286 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Fri, 8 Aug 2025 23:15:18 -0400 Subject: [PATCH 06/10] First commit of BinnedEventAggregator, faster than `Rate` --- src/ezmsg/event/__init__.py | 3 +- src/ezmsg/event/binned.py | 125 ++++++++++++++++++++++++++++++++++++ tests/test_rate.py | 51 ++++++++++++++- 3 files changed, 176 insertions(+), 3 deletions(-) create mode 100644 src/ezmsg/event/binned.py diff --git a/src/ezmsg/event/__init__.py b/src/ezmsg/event/__init__.py index 7305478..a10316f 100644 --- a/src/ezmsg/event/__init__.py +++ b/src/ezmsg/event/__init__.py @@ -1,4 +1,5 @@ from .__version__ import __version__ as __version__ from .rate import Rate, EventRate -from .aggregate import Aggregate \ No newline at end of file +from .aggregate import Aggregate +from .binned import BinnedEventAggregator, BinnedEventAggregatorUnit \ No newline at end of file diff --git a/src/ezmsg/event/binned.py b/src/ezmsg/event/binned.py new file mode 100644 index 0000000..43ff719 --- /dev/null +++ b/src/ezmsg/event/binned.py @@ -0,0 +1,125 @@ +import ezmsg.core as ez +import numpy as np +import numpy.typing as npt + +from ezmsg.sigproc.base import ( + BaseStatefulTransformer, + BaseTransformerUnit, + processor_state, +) +from ezmsg.util.messages.axisarray import AxisArray, replace + + +class BinnedEventAggregatorSettings(ez.Settings): + bin_duration: float = 0.05 + """ + Duration of each bin in seconds. + This is the time interval over which events will be counted. + """ + + scale_output: bool = True + """ + If True, the output will be scaled by the bin duration. + This is useful for converting counts to rates. + """ + + axis: str = "time" + + +@processor_state +class BinnedEventAggregatorState: + n_overflow: int = 0 + counts_in_overflow: npt.NDArray[np.int64] | None = None + +class BinnedEventAggregator( + BaseStatefulTransformer[ + BinnedEventAggregatorSettings, AxisArray, AxisArray, BinnedEventAggregatorState + ] +): + def _hash_message(self, message: AxisArray) -> int: + targ_ax_idx = message.get_axis_idx(self.settings.axis) + non_targ_dims = message.dims[:targ_ax_idx] + message.dims[targ_ax_idx + 1:] + return hash(tuple(non_targ_dims)) + + def _reset_state(self, message: AxisArray) -> None: + self._state.n_overflow = 0 + targ_axis_idx = message.get_axis_idx(self.settings.axis) + buff_shape = message.data.shape[:targ_axis_idx] + message.data.shape[targ_axis_idx + 1:] + self._state.counts_in_overflow = np.zeros(buff_shape, dtype=np.int64) + + def _process(self, message: AxisArray) -> AxisArray: + # Quick maths + targ_ax_idx = message.get_axis_idx(self.settings.axis) + targ_axis = message.axes[self.settings.axis] + samples_per_bin = int(self.settings.bin_duration * (1 / targ_axis.gain)) + + # We will be slicing the data several times, so create a variable to hold the slices + var_slice = [slice(None)] * message.data.ndim + + # Store for later use + n_prev_overflow = self._state.n_overflow + + if self._state.n_overflow > 0: + # Calculate how many samples from the input msg we can fit into the first bin, given the current overflow state + n_first = samples_per_bin - self._state.n_overflow + # Sum the number of samples in the first bin then add to self._state.counts_in_overflow + var_slice[targ_ax_idx] = slice(0, n_first) + first_bin_counts = message.data[tuple(var_slice)].sum(axis=targ_ax_idx).todense() + first_bin_counts += self._state.counts_in_overflow + else: + n_first = 0 + first_bin_counts = self._state.counts_in_overflow + assert np.all(first_bin_counts == 0), "Overflow state should be zeroed out from previous iteration." + + # Calculate how many samples remain after the first bin + n_remaining = message.data.shape[targ_ax_idx] - n_first + n_full_bins = int(n_remaining / samples_per_bin) + + # Slice the n_first:-next_overflow samples into a segment that divides evenly into bins + split_idx = n_first + n_full_bins * samples_per_bin + var_slice[targ_ax_idx] = slice(n_first, split_idx) + full_bins_data = message.data[tuple(var_slice)] + + # Reshape and sum for full bins + new_shape = full_bins_data.shape[:targ_ax_idx] + (n_full_bins, samples_per_bin) + full_bins_data.shape[targ_ax_idx + 1:] + middle_bin_counts = full_bins_data.reshape(new_shape).sum(axis=targ_ax_idx + 1).todense() + + # Prepare output + if self._state.n_overflow > 0: + first_bin_counts = first_bin_counts.reshape( + first_bin_counts.shape[:targ_ax_idx] + (1,) + first_bin_counts.shape[targ_ax_idx:]) + output_data = np.concatenate([first_bin_counts, middle_bin_counts], axis=targ_ax_idx) + else: + output_data = middle_bin_counts + + if self.settings.scale_output: + output_data = output_data / self.settings.bin_duration + + # Create the new output axis + # For the target axis, backup the offset by the number of samples in the overflow + out_axis = replace( + targ_axis, + gain=targ_axis.gain * samples_per_bin, + offset=targ_axis.offset - n_prev_overflow * targ_axis.gain, + ) + out_msg = replace( + message, + data=output_data, + axes={k: v if k!= self.settings.axis else out_axis for k, v in message.axes.items()}, + ) + + # Calculate and store the overflow state. + var_slice[targ_ax_idx] = slice(split_idx, None) + overflow_data = message.data[tuple(var_slice)] + self._state.n_overflow = overflow_data.shape[targ_ax_idx] + self._state.counts_in_overflow = overflow_data.sum(axis=targ_ax_idx).todense() + + return out_msg + + +class BinnedEventAggregatorUnit( + BaseTransformerUnit[ + BinnedEventAggregatorSettings, AxisArray, AxisArray, BinnedEventAggregator + ] +): + SETTINGS = BinnedEventAggregatorSettings diff --git a/tests/test_rate.py b/tests/test_rate.py index 4825576..0fadd25 100644 --- a/tests/test_rate.py +++ b/tests/test_rate.py @@ -5,6 +5,7 @@ from ezmsg.util.messages.axisarray import AxisArray from ezmsg.event.rate import Rate, EventRateSettings +from ezmsg.event.binned import BinnedEventAggregator, BinnedEventAggregatorSettings def test_event_rate_composite(): @@ -32,7 +33,6 @@ def test_event_rate_composite(): proc = Rate(settings=EventRateSettings(bin_duration=bin_dur)) - # Calculate the first message which sometimes takes longer due to initialization out_msgs = [proc(in_msgs[0])] assert out_msgs[0].data.shape[0] == int(chunk_dur / bin_dur) @@ -57,4 +57,51 @@ def test_event_rate_composite(): s_proc = s[:samp_proc].todense().reshape(-1, int(fs * bin_dur), nchans) expected = np.sum(s_proc, axis=1) / bin_dur assert stack.data.shape == expected.shape - assert np.allclose(stack.data, expected) \ No newline at end of file + assert np.allclose(stack.data, expected) + + +def test_event_rate_binned(): + dur = 1.1 + fs = 30_000.0 + chunk_dur = 0.1 + bin_dur = 0.03 + nchans = 128 + chunk_len = int(fs * chunk_dur) + nchunk = int(dur / chunk_dur) + + rng = np.random.default_rng() + s = sparse.random((int(fs * dur), nchans), density=0.0001, random_state=rng) > 0 + + in_msgs = [ + AxisArray( + data=s[chunk_ix * chunk_len : (chunk_ix + 1) * chunk_len], + dims=["time", "ch"], + axes={ + "time": AxisArray.Axis.TimeAxis(fs=fs, offset=chunk_ix * chunk_dur), + }, + ) + for chunk_ix in range(nchunk) + ] + + proc = BinnedEventAggregator(settings=BinnedEventAggregatorSettings(bin_duration=bin_dur)) + + # Calculate the first message which sometimes takes longer due to initialization + out_msgs = [proc(in_msgs[0])] + + # Make sure the first output message has the correct shape + assert out_msgs[0].data.shape[0] == int(chunk_dur / bin_dur) + + # Calculate the remaining messages within perf_counters and assert they are processed quickly + t_start = time.perf_counter() + out_msgs.extend([proc(in_msg) for in_msg in in_msgs[1:]]) + t_elapsed = time.perf_counter() - t_start + assert len(out_msgs) == nchunk + assert t_elapsed < 0.5 * (dur - chunk_dur) # Ensure processing is fast enough + + # Calculate the expected output and assert correctness + n_binnable = int(dur / bin_dur) + samps_per_bin = int(bin_dur * fs) + expected = s[:n_binnable * samps_per_bin].reshape((n_binnable, samps_per_bin, -1)).sum(axis=1) + stacked = AxisArray.concatenate(*out_msgs, dim="time") + assert stacked.data.shape == expected.shape + assert np.array_equal(stacked.data, expected.todense() / bin_dur) From 25f613ef5ceb81fb9a114be3ccda1c90d1106a46 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Tue, 2 Dec 2025 01:49:49 -0500 Subject: [PATCH 07/10] ruff --- src/ezmsg/event/binned.py | 39 ++++++++++++++++++++++++++++++--------- tests/test_rate.py | 12 +++++++++--- 2 files changed, 39 insertions(+), 12 deletions(-) diff --git a/src/ezmsg/event/binned.py b/src/ezmsg/event/binned.py index 43ff719..eb8abc9 100644 --- a/src/ezmsg/event/binned.py +++ b/src/ezmsg/event/binned.py @@ -31,6 +31,7 @@ class BinnedEventAggregatorState: n_overflow: int = 0 counts_in_overflow: npt.NDArray[np.int64] | None = None + class BinnedEventAggregator( BaseStatefulTransformer[ BinnedEventAggregatorSettings, AxisArray, AxisArray, BinnedEventAggregatorState @@ -38,13 +39,15 @@ class BinnedEventAggregator( ): def _hash_message(self, message: AxisArray) -> int: targ_ax_idx = message.get_axis_idx(self.settings.axis) - non_targ_dims = message.dims[:targ_ax_idx] + message.dims[targ_ax_idx + 1:] + non_targ_dims = message.dims[:targ_ax_idx] + message.dims[targ_ax_idx + 1 :] return hash(tuple(non_targ_dims)) def _reset_state(self, message: AxisArray) -> None: self._state.n_overflow = 0 targ_axis_idx = message.get_axis_idx(self.settings.axis) - buff_shape = message.data.shape[:targ_axis_idx] + message.data.shape[targ_axis_idx + 1:] + buff_shape = ( + message.data.shape[:targ_axis_idx] + message.data.shape[targ_axis_idx + 1 :] + ) self._state.counts_in_overflow = np.zeros(buff_shape, dtype=np.int64) def _process(self, message: AxisArray) -> AxisArray: @@ -64,12 +67,16 @@ def _process(self, message: AxisArray) -> AxisArray: n_first = samples_per_bin - self._state.n_overflow # Sum the number of samples in the first bin then add to self._state.counts_in_overflow var_slice[targ_ax_idx] = slice(0, n_first) - first_bin_counts = message.data[tuple(var_slice)].sum(axis=targ_ax_idx).todense() + first_bin_counts = ( + message.data[tuple(var_slice)].sum(axis=targ_ax_idx).todense() + ) first_bin_counts += self._state.counts_in_overflow else: n_first = 0 first_bin_counts = self._state.counts_in_overflow - assert np.all(first_bin_counts == 0), "Overflow state should be zeroed out from previous iteration." + assert np.all(first_bin_counts == 0), ( + "Overflow state should be zeroed out from previous iteration." + ) # Calculate how many samples remain after the first bin n_remaining = message.data.shape[targ_ax_idx] - n_first @@ -81,14 +88,25 @@ def _process(self, message: AxisArray) -> AxisArray: full_bins_data = message.data[tuple(var_slice)] # Reshape and sum for full bins - new_shape = full_bins_data.shape[:targ_ax_idx] + (n_full_bins, samples_per_bin) + full_bins_data.shape[targ_ax_idx + 1:] - middle_bin_counts = full_bins_data.reshape(new_shape).sum(axis=targ_ax_idx + 1).todense() + new_shape = ( + full_bins_data.shape[:targ_ax_idx] + + (n_full_bins, samples_per_bin) + + full_bins_data.shape[targ_ax_idx + 1 :] + ) + middle_bin_counts = ( + full_bins_data.reshape(new_shape).sum(axis=targ_ax_idx + 1).todense() + ) # Prepare output if self._state.n_overflow > 0: first_bin_counts = first_bin_counts.reshape( - first_bin_counts.shape[:targ_ax_idx] + (1,) + first_bin_counts.shape[targ_ax_idx:]) - output_data = np.concatenate([first_bin_counts, middle_bin_counts], axis=targ_ax_idx) + first_bin_counts.shape[:targ_ax_idx] + + (1,) + + first_bin_counts.shape[targ_ax_idx:] + ) + output_data = np.concatenate( + [first_bin_counts, middle_bin_counts], axis=targ_ax_idx + ) else: output_data = middle_bin_counts @@ -105,7 +123,10 @@ def _process(self, message: AxisArray) -> AxisArray: out_msg = replace( message, data=output_data, - axes={k: v if k!= self.settings.axis else out_axis for k, v in message.axes.items()}, + axes={ + k: v if k != self.settings.axis else out_axis + for k, v in message.axes.items() + }, ) # Calculate and store the overflow state. diff --git a/tests/test_rate.py b/tests/test_rate.py index 0fadd25..35db1ca 100644 --- a/tests/test_rate.py +++ b/tests/test_rate.py @@ -42,7 +42,7 @@ def test_event_rate_composite(): out_msgs.extend([proc(in_msg) for in_msg in in_msgs[1:]]) t_elapsed = time.perf_counter() - t_start assert len(out_msgs) == nchunk - # assert t_elapsed < (dur - chunk_dur) # Ensure processing is fast enough + _ = t_elapsed < (dur - chunk_dur) # Ensure processing is fast enough n_bins_seen = 0 for om_ix, om in enumerate(out_msgs): @@ -83,7 +83,9 @@ def test_event_rate_binned(): for chunk_ix in range(nchunk) ] - proc = BinnedEventAggregator(settings=BinnedEventAggregatorSettings(bin_duration=bin_dur)) + proc = BinnedEventAggregator( + settings=BinnedEventAggregatorSettings(bin_duration=bin_dur) + ) # Calculate the first message which sometimes takes longer due to initialization out_msgs = [proc(in_msgs[0])] @@ -101,7 +103,11 @@ def test_event_rate_binned(): # Calculate the expected output and assert correctness n_binnable = int(dur / bin_dur) samps_per_bin = int(bin_dur * fs) - expected = s[:n_binnable * samps_per_bin].reshape((n_binnable, samps_per_bin, -1)).sum(axis=1) + expected = ( + s[: n_binnable * samps_per_bin] + .reshape((n_binnable, samps_per_bin, -1)) + .sum(axis=1) + ) stacked = AxisArray.concatenate(*out_msgs, dim="time") assert stacked.data.shape == expected.shape assert np.array_equal(stacked.data, expected.todense() / bin_dur) From 4acf343f31367edacc1d0712363ccc58ab4b444d Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Tue, 2 Dec 2025 01:50:34 -0500 Subject: [PATCH 08/10] move Aggregate into ezmsg-sigproc and update dep --- pyproject.toml | 2 +- src/ezmsg/event/__init__.py | 7 ++++--- src/ezmsg/event/aggregate.py | 36 ------------------------------------ src/ezmsg/event/rate.py | 17 +++++++++-------- 4 files changed, 14 insertions(+), 48 deletions(-) delete mode 100644 src/ezmsg/event/aggregate.py diff --git a/pyproject.toml b/pyproject.toml index 6b132f3..211814e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ readme = "README.md" requires-python = ">=3.10.15" dynamic = ["version"] dependencies = [ - "ezmsg-sigproc>=2.3.0", + "ezmsg-sigproc>=2.4.0", "ezmsg>=3.6.1", "sparse>=0.17.0", "numpy>=2.2.6", diff --git a/src/ezmsg/event/__init__.py b/src/ezmsg/event/__init__.py index a10316f..7c7a284 100644 --- a/src/ezmsg/event/__init__.py +++ b/src/ezmsg/event/__init__.py @@ -1,5 +1,6 @@ from .__version__ import __version__ as __version__ -from .rate import Rate, EventRate -from .aggregate import Aggregate -from .binned import BinnedEventAggregator, BinnedEventAggregatorUnit \ No newline at end of file +from .rate import Rate as Rate +from .rate import EventRate as EventRate +from .binned import BinnedEventAggregator as BinnedEventAggregator +from .binned import BinnedEventAggregatorSettings as BinnedEventAggregatorSettings diff --git a/src/ezmsg/event/aggregate.py b/src/ezmsg/event/aggregate.py deleted file mode 100644 index 584b56e..0000000 --- a/src/ezmsg/event/aggregate.py +++ /dev/null @@ -1,36 +0,0 @@ -import typing - -import ezmsg.core as ez -from array_api_compat import get_namespace - -from ezmsg.sigproc.base import BaseTransformer -from ezmsg.util.messages.axisarray import AxisArray, replace - - -class AggregateSettings(ez.Settings): - axis: str - operator: typing.Literal["sum", "mean"] = "sum" - - -class Aggregate(BaseTransformer[AggregateSettings, AxisArray, AxisArray]): - # TODO: Move this to ezmsg-sigproc.aggregate module - def _process(self, message: AxisArray) -> AxisArray: - xp = get_namespace(message.data) - axis_idx = message.get_axis_idx(self.settings.axis) - - agg_op = getattr(xp, self.settings.operator) - agg_data = agg_op(message.data, axis=axis_idx) - - new_dims = list(message.dims) - new_dims.pop(axis_idx) - - new_axes = message.axes.copy() - if self.settings.axis in new_axes: - del new_axes[self.settings.axis] - - return replace( - message, - data=agg_data, - dims=new_dims, - axes=new_axes, - ) \ No newline at end of file diff --git a/src/ezmsg/event/rate.py b/src/ezmsg/event/rate.py index eaf011e..722c68f 100644 --- a/src/ezmsg/event/rate.py +++ b/src/ezmsg/event/rate.py @@ -6,11 +6,14 @@ CompositeProcessor, BaseTransformerUnit, ) +from ezmsg.sigproc.aggregate import ( + AggregateTransformer, + AggregateSettings, + AggregationFunction, +) from ezmsg.sigproc.window import WindowTransformer, WindowSettings from ezmsg.util.messages.axisarray import AxisArray, replace -from .aggregate import Aggregate, AggregateSettings - class EventRateSettings(ez.Settings): bin_duration: float = 0.05 @@ -39,9 +42,7 @@ def _process(self, message: AxisArray) -> AxisArray: idx = new_dims.index(self.settings.old_axis) new_dims[idx] = self.settings.new_axis if self.settings.old_axis in new_axes: - new_axes[self.settings.new_axis] = new_axes.pop( - self.settings.old_axis - ) + new_axes[self.settings.new_axis] = new_axes.pop(self.settings.old_axis) return replace(message, dims=new_dims, axes=new_axes) @@ -61,10 +62,10 @@ def _initialize_processors( zero_pad_until="none", ) ), - "aggregate": Aggregate(AggregateSettings(axis="time", operator="sum")), - "rename": RenameAxis( - RenameAxisSettings(old_axis="win", new_axis="time") + "aggregate": AggregateTransformer( + AggregateSettings(axis="time", operation=AggregationFunction.SUM) ), + "rename": RenameAxis(RenameAxisSettings(old_axis="win", new_axis="time")), "densify_and_scale": DensifyAndScale( DensifyAndScaleSettings(scale=1.0 / settings.bin_duration) ), From f0b24e16615a12547f94b099db3dc137ab8dda5a Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Tue, 2 Dec 2025 01:57:49 -0500 Subject: [PATCH 09/10] move test_event_rate_binned into its own test file. --- tests/test_binned.py | 60 ++++++++++++++++++++++++++++++++++++++++++++ tests/test_rate.py | 54 --------------------------------------- 2 files changed, 60 insertions(+), 54 deletions(-) create mode 100644 tests/test_binned.py diff --git a/tests/test_binned.py b/tests/test_binned.py new file mode 100644 index 0000000..ed7139d --- /dev/null +++ b/tests/test_binned.py @@ -0,0 +1,60 @@ +import time + +import numpy as np +import sparse +from ezmsg.util.messages.axisarray import AxisArray + +from ezmsg.event.binned import BinnedEventAggregator, BinnedEventAggregatorSettings + + +def test_event_rate_binned(): + dur = 1.1 + fs = 30_000.0 + chunk_dur = 0.1 + bin_dur = 0.03 + nchans = 128 + chunk_len = int(fs * chunk_dur) + nchunk = int(dur / chunk_dur) + + rng = np.random.default_rng() + s = sparse.random((int(fs * dur), nchans), density=0.0001, random_state=rng) > 0 + + in_msgs = [ + AxisArray( + data=s[chunk_ix * chunk_len : (chunk_ix + 1) * chunk_len], + dims=["time", "ch"], + axes={ + "time": AxisArray.Axis.TimeAxis(fs=fs, offset=chunk_ix * chunk_dur), + }, + ) + for chunk_ix in range(nchunk) + ] + + proc = BinnedEventAggregator( + settings=BinnedEventAggregatorSettings(bin_duration=bin_dur) + ) + + # Calculate the first message which sometimes takes longer due to initialization + out_msgs = [proc(in_msgs[0])] + + # Make sure the first output message has the correct shape + assert out_msgs[0].data.shape[0] == int(chunk_dur / bin_dur) + + # Calculate the remaining messages within perf_counters and assert they are processed quickly + t_start = time.perf_counter() + out_msgs.extend([proc(in_msg) for in_msg in in_msgs[1:]]) + t_elapsed = time.perf_counter() - t_start + assert len(out_msgs) == nchunk + assert t_elapsed < 0.5 * (dur - chunk_dur) # Ensure processing is fast enough + + # Calculate the expected output and assert correctness + n_binnable = int(dur / bin_dur) + samps_per_bin = int(bin_dur * fs) + expected = ( + s[: n_binnable * samps_per_bin] + .reshape((n_binnable, samps_per_bin, -1)) + .sum(axis=1) + ) + stacked = AxisArray.concatenate(*out_msgs, dim="time") + assert stacked.data.shape == expected.shape + assert np.array_equal(stacked.data, expected.todense() / bin_dur) diff --git a/tests/test_rate.py b/tests/test_rate.py index 35db1ca..e7d7e81 100644 --- a/tests/test_rate.py +++ b/tests/test_rate.py @@ -5,7 +5,6 @@ from ezmsg.util.messages.axisarray import AxisArray from ezmsg.event.rate import Rate, EventRateSettings -from ezmsg.event.binned import BinnedEventAggregator, BinnedEventAggregatorSettings def test_event_rate_composite(): @@ -58,56 +57,3 @@ def test_event_rate_composite(): expected = np.sum(s_proc, axis=1) / bin_dur assert stack.data.shape == expected.shape assert np.allclose(stack.data, expected) - - -def test_event_rate_binned(): - dur = 1.1 - fs = 30_000.0 - chunk_dur = 0.1 - bin_dur = 0.03 - nchans = 128 - chunk_len = int(fs * chunk_dur) - nchunk = int(dur / chunk_dur) - - rng = np.random.default_rng() - s = sparse.random((int(fs * dur), nchans), density=0.0001, random_state=rng) > 0 - - in_msgs = [ - AxisArray( - data=s[chunk_ix * chunk_len : (chunk_ix + 1) * chunk_len], - dims=["time", "ch"], - axes={ - "time": AxisArray.Axis.TimeAxis(fs=fs, offset=chunk_ix * chunk_dur), - }, - ) - for chunk_ix in range(nchunk) - ] - - proc = BinnedEventAggregator( - settings=BinnedEventAggregatorSettings(bin_duration=bin_dur) - ) - - # Calculate the first message which sometimes takes longer due to initialization - out_msgs = [proc(in_msgs[0])] - - # Make sure the first output message has the correct shape - assert out_msgs[0].data.shape[0] == int(chunk_dur / bin_dur) - - # Calculate the remaining messages within perf_counters and assert they are processed quickly - t_start = time.perf_counter() - out_msgs.extend([proc(in_msg) for in_msg in in_msgs[1:]]) - t_elapsed = time.perf_counter() - t_start - assert len(out_msgs) == nchunk - assert t_elapsed < 0.5 * (dur - chunk_dur) # Ensure processing is fast enough - - # Calculate the expected output and assert correctness - n_binnable = int(dur / bin_dur) - samps_per_bin = int(bin_dur * fs) - expected = ( - s[: n_binnable * samps_per_bin] - .reshape((n_binnable, samps_per_bin, -1)) - .sum(axis=1) - ) - stacked = AxisArray.concatenate(*out_msgs, dim="time") - assert stacked.data.shape == expected.shape - assert np.array_equal(stacked.data, expected.todense() / bin_dur) From 89145a12c847a4db6f5d78b316a2071c4b5ba168 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Tue, 2 Dec 2025 02:03:42 -0500 Subject: [PATCH 10/10] Add docstring to hint that users should try ModifyAxis if they need an ez.Unit. --- src/ezmsg/event/rate.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/ezmsg/event/rate.py b/src/ezmsg/event/rate.py index 722c68f..2ee041a 100644 --- a/src/ezmsg/event/rate.py +++ b/src/ezmsg/event/rate.py @@ -34,6 +34,10 @@ class RenameAxisSettings(ez.Settings): class RenameAxis(BaseTransformer[RenameAxisSettings, AxisArray, AxisArray]): + """ + Note: If you only require a Unit, then look to `ezmsg.util.messages.modify.ModifyAxis`. + Unfortunately, that module is not available as a transformer and cannot be included in a CompositeProcessor. + """ def _process(self, message: AxisArray) -> AxisArray: new_dims = list(message.dims) new_axes = dict(message.axes)