Merged
1 change: 1 addition & 0 deletions .gitignore
@@ -163,3 +163,4 @@ cython_debug/

 src/ezmsg/event/__version__.py
 uv.lock
+*.local.json
48 changes: 31 additions & 17 deletions profile_rate.py
@@ -1,17 +1,26 @@
 import cProfile
-import pstats
 
+# import pstats
 import numpy as np
 import sparse
 from ezmsg.util.messages.axisarray import AxisArray
-from ezmsg.event.rate import Rate, EventRateSettings
+import ezmsg.core as ez
+from ezmsg.event.rate import EventRate, EventRateSettings
 
 
 # Simulate input data
 def generate_sparse_data(num_samples, num_channels, sparsity_factor, rng):
-    data = sparse.random((num_samples, num_channels), density=sparsity_factor, random_state=rng) > 0
+    data = (
+        sparse.random(
+            (num_samples, num_channels), density=sparsity_factor, random_state=rng
+        )
+        > 0
+    )
     return data
 
-def run_rate_processor(num_samples, num_channels, sparsity_factor, bin_duration, chunk_dur, fs):
+
+def run_rate_processor(
+    num_samples, num_channels, sparsity_factor, bin_duration, chunk_dur, fs
+):
     rng = np.random.default_rng()
     s = generate_sparse_data(num_samples, num_channels, sparsity_factor, rng)

@@ -30,35 +39,40 @@ def run_rate_processor(num_samples, num_channels, sparsity_factor, bin_duration,
     ]
 
     settings = EventRateSettings(bin_duration=bin_duration)
-    rate_processor = Rate(settings)
+    rate_processor = EventRate(settings)
 
     output_messages = []
     for in_msg in in_msgs:
         output_messages.append(rate_processor(in_msg))
 
     return output_messages
 
 
 if __name__ == "__main__":
     NUM_SAMPLES = 300_000  # Number of time samples (e.g., 10 seconds at 30kHz)
-    NUM_CHANNELS = 128 # Number of channels
-    SPARSITY_FACTOR = 0.0001 # 0.01% sparse, as in test_rate.py
-    BIN_DURATION = 0.03 # 30ms bin duration, as in test_rate.py
-    CHUNK_DURATION = 0.1 # 100ms chunk duration, as in test_rate.py
-    FS = 30000.0 # Sampling frequency, as in test_rate.py
+    NUM_CHANNELS = 128  # Number of channels
+    SPARSITY_FACTOR = 0.0001  # 0.01% sparse, as in test_rate.py
+    BIN_DURATION = 0.03  # 30ms bin duration, as in test_rate.py
+    CHUNK_DURATION = 0.1  # 100ms chunk duration, as in test_rate.py
+    FS = 30000.0  # Sampling frequency, as in test_rate.py
 
-    print(f"Profiling with: samples={NUM_SAMPLES}, channels={NUM_CHANNELS}, sparsity={SPARSITY_FACTOR}, bin_duration={BIN_DURATION}, chunk_duration={CHUNK_DURATION}, fs={FS}")
+    print(
+        f"Profiling with: samples={NUM_SAMPLES}, channels={NUM_CHANNELS}, sparsity={SPARSITY_FACTOR}, bin_duration={BIN_DURATION}, chunk_duration={CHUNK_DURATION}, fs={FS}"
+    )
 
     # Run with cProfile
     profiler = cProfile.Profile()
     profiler.enable()
 
-    run_rate_processor(NUM_SAMPLES, NUM_CHANNELS, SPARSITY_FACTOR, BIN_DURATION, CHUNK_DURATION, FS)
-
-
+    run_rate_processor(
+        NUM_SAMPLES, NUM_CHANNELS, SPARSITY_FACTOR, BIN_DURATION, CHUNK_DURATION, FS
+    )
 
     profiler.disable()
 
     # Save stats to a file in a format snakeviz can read
     stats_file = "rate_profile.prof"
     profiler.dump_stats(stats_file)
 
     print(f"Profiling results saved to {stats_file}")
-    print("To visualize, run: uv run snakeviz rate_profile.prof")
\ No newline at end of file
+    print("To visualize, run: uv run snakeviz rate_profile.prof")
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -8,7 +8,7 @@ readme = "README.md"
 requires-python = ">=3.10.15"
 dynamic = ["version"]
 dependencies = [
-    "ezmsg-sigproc>=2.2.0",
+    "ezmsg-sigproc>=2.4.0",
     "ezmsg>=3.6.1",
     "sparse>=0.17.0",
     "numpy>=2.2.6",
5 changes: 5 additions & 0 deletions src/ezmsg/event/__init__.py
@@ -1 +1,6 @@
 from .__version__ import __version__ as __version__
+
+from .rate import Rate as Rate
+from .rate import EventRate as EventRate
+from .binned import BinnedEventAggregator as BinnedEventAggregator
+from .binned import BinnedEventAggregatorSettings as BinnedEventAggregatorSettings
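Note that Rate is still exported alongside EventRate, so pre-rename imports keep working while EventRate is presumably the preferred spelling going forward (profile_rate.py above now uses it). Downstream code can now do, for example:

from ezmsg.event import BinnedEventAggregator, BinnedEventAggregatorSettings, EventRate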
146 changes: 146 additions & 0 deletions src/ezmsg/event/binned.py
@@ -0,0 +1,146 @@
import ezmsg.core as ez
import numpy as np
import numpy.typing as npt

from ezmsg.sigproc.base import (
    BaseStatefulTransformer,
    BaseTransformerUnit,
    processor_state,
)
from ezmsg.util.messages.axisarray import AxisArray, replace


class BinnedEventAggregatorSettings(ez.Settings):
    bin_duration: float = 0.05
    """
    Duration of each bin in seconds.
    This is the time interval over which events will be counted.
    """

    scale_output: bool = True
    """
    If True, the output counts will be divided by the bin duration,
    converting event counts to rates (events per second).
    """

    axis: str = "time"


@processor_state
class BinnedEventAggregatorState:
    n_overflow: int = 0
    counts_in_overflow: npt.NDArray[np.int64] | None = None


class BinnedEventAggregator(
    BaseStatefulTransformer[
        BinnedEventAggregatorSettings, AxisArray, AxisArray, BinnedEventAggregatorState
    ]
):
    def _hash_message(self, message: AxisArray) -> int:
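        # State is keyed on the non-target dims (e.g., channels): if their hash
        # changes between messages, the base class re-runs _reset_state.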
        targ_ax_idx = message.get_axis_idx(self.settings.axis)
        non_targ_dims = message.dims[:targ_ax_idx] + message.dims[targ_ax_idx + 1 :]
        return hash(tuple(non_targ_dims))

    def _reset_state(self, message: AxisArray) -> None:
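        # Fresh state: no carried-over samples, and a zeroed counts buffer shaped
        # like one message with the target (time) axis dropped.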
        self._state.n_overflow = 0
        targ_axis_idx = message.get_axis_idx(self.settings.axis)
        buff_shape = (
            message.data.shape[:targ_axis_idx] + message.data.shape[targ_axis_idx + 1 :]
        )
        self._state.counts_in_overflow = np.zeros(buff_shape, dtype=np.int64)

    def _process(self, message: AxisArray) -> AxisArray:
        # Bin size in samples: the axis gain is seconds per sample, so 1 / gain is the sampling rate.
        targ_ax_idx = message.get_axis_idx(self.settings.axis)
        targ_axis = message.axes[self.settings.axis]
        samples_per_bin = int(self.settings.bin_duration * (1 / targ_axis.gain))

        # We will be slicing the data several times, so create a variable to hold the slices
        var_slice = [slice(None)] * message.data.ndim

        # Store for later use
        n_prev_overflow = self._state.n_overflow

        if self._state.n_overflow > 0:
            # Calculate how many samples from the input msg we can fit into the first bin, given the current overflow state
            n_first = samples_per_bin - self._state.n_overflow
            # Count the events in that first-bin slice, then add the counts carried over in self._state.counts_in_overflow
            var_slice[targ_ax_idx] = slice(0, n_first)
            first_bin_counts = (
                message.data[tuple(var_slice)].sum(axis=targ_ax_idx).todense()
            )
            first_bin_counts += self._state.counts_in_overflow
        else:
            n_first = 0
            first_bin_counts = self._state.counts_in_overflow
            assert np.all(first_bin_counts == 0), (
                "Overflow state should be zeroed out from previous iteration."
            )

        # Calculate how many samples remain after the first bin
        n_remaining = message.data.shape[targ_ax_idx] - n_first
        n_full_bins = int(n_remaining / samples_per_bin)

        # Slice the n_first:-next_overflow samples into a segment that divides evenly into bins
        split_idx = n_first + n_full_bins * samples_per_bin
        var_slice[targ_ax_idx] = slice(n_first, split_idx)
        full_bins_data = message.data[tuple(var_slice)]

        # Reshape and sum for full bins
        new_shape = (
            full_bins_data.shape[:targ_ax_idx]
            + (n_full_bins, samples_per_bin)
            + full_bins_data.shape[targ_ax_idx + 1 :]
        )
        middle_bin_counts = (
            full_bins_data.reshape(new_shape).sum(axis=targ_ax_idx + 1).todense()
        )

        # Prepare output
        if self._state.n_overflow > 0:
            first_bin_counts = first_bin_counts.reshape(
                first_bin_counts.shape[:targ_ax_idx]
                + (1,)
                + first_bin_counts.shape[targ_ax_idx:]
            )
            output_data = np.concatenate(
                [first_bin_counts, middle_bin_counts], axis=targ_ax_idx
            )
        else:
            output_data = middle_bin_counts

        if self.settings.scale_output:
            output_data = output_data / self.settings.bin_duration

        # Create the new output axis.
        # For the target axis, back up the offset by the number of samples carried in from the previous chunk.
        out_axis = replace(
            targ_axis,
            gain=targ_axis.gain * samples_per_bin,
            offset=targ_axis.offset - n_prev_overflow * targ_axis.gain,
        )
        out_msg = replace(
            message,
            data=output_data,
            axes={
                k: v if k != self.settings.axis else out_axis
                for k, v in message.axes.items()
            },
        )

        # Calculate and store the overflow state.
        var_slice[targ_ax_idx] = slice(split_idx, None)
        overflow_data = message.data[tuple(var_slice)]
        self._state.n_overflow = overflow_data.shape[targ_ax_idx]
        self._state.counts_in_overflow = overflow_data.sum(axis=targ_ax_idx).todense()

        return out_msg


class BinnedEventAggregatorUnit(
    BaseTransformerUnit[
        BinnedEventAggregatorSettings, AxisArray, AxisArray, BinnedEventAggregator
    ]
):
    SETTINGS = BinnedEventAggregatorSettings
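
For review context, a minimal sketch of driving the new transformer directly, outside an ezmsg pipeline. This is illustrative only, not part of the PR: the channel count, event density, and the direct AxisArray.Axis(unit="s", gain=1/fs) construction are assumptions (the transformer itself only relies on the axis gain/offset fields it reads and replaces above). With fs = 30 kHz and bin_duration = 0.05 s, each bin spans 1500 samples:

import numpy as np
import sparse

from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.event.binned import BinnedEventAggregator, BinnedEventAggregatorSettings

fs = 30_000.0
rng = np.random.default_rng(0)


def make_msg(n_samples: int, offset: float) -> AxisArray:
    # 8 channels and 0.1% density are arbitrary values for illustration.
    events = sparse.random((n_samples, 8), density=1e-3, random_state=rng) > 0
    return AxisArray(
        data=events,
        dims=["time", "ch"],
        axes={"time": AxisArray.Axis(unit="s", gain=1.0 / fs, offset=offset)},
    )


agg = BinnedEventAggregator(BinnedEventAggregatorSettings(bin_duration=0.05))

# 3_300 samples = 110 ms -> two full 50 ms bins; the last 300 samples (10 ms)
# are held in the overflow state.
out1 = agg(make_msg(3_300, offset=0.0))
print(out1.data.shape)  # (2, 8); values are rates in Hz because scale_output=True

# The first 1_200 samples of the next chunk complete the pending bin, whose
# offset is backed up by 300 samples so bin timestamps stay continuous.
out2 = agg(make_msg(3_000, offset=3_300 / fs))
print(out2.data.shape)  # (2, 8): one stitched bin + one full bin; 300 samples overflow again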