From a2c3660679c6197ff76e07809a7e17aef25a3204 Mon Sep 17 00:00:00 2001 From: kylmcgr Date: Mon, 27 Apr 2026 17:39:27 -0600 Subject: [PATCH 1/4] new mlx threshold rate node --- src/ezmsg/event/threshold_rate.py | 242 ++++++++++++++++++ .../event/util/threshold_rate_mlx_metal.py | 158 ++++++++++++ tests/test_threshold_rate.py | 206 +++++++++++++++ 3 files changed, 606 insertions(+) create mode 100644 src/ezmsg/event/threshold_rate.py create mode 100644 src/ezmsg/event/util/threshold_rate_mlx_metal.py create mode 100644 tests/test_threshold_rate.py diff --git a/src/ezmsg/event/threshold_rate.py b/src/ezmsg/event/threshold_rate.py new file mode 100644 index 0000000..151200c --- /dev/null +++ b/src/ezmsg/event/threshold_rate.py @@ -0,0 +1,242 @@ +"""Dense fused threshold-crossing event-rate transformer.""" + +from typing import Any + +import ezmsg.core as ez +import numpy as np +from array_api_compat import get_namespace, is_numpy_array +from ezmsg.baseproc import BaseStatefulTransformer, BaseTransformerUnit, processor_state +from ezmsg.util.messages.axisarray import AxisArray, replace + + +class ThresholdCrossingRateSettings(ez.Settings): + threshold: float = -3.5 + """The value the signal must cross to count an event.""" + + refrac_dur: float = 0.001 + """Minimum duration between counted threshold crossings in seconds.""" + + bin_duration: float = 0.05 + """Output bin duration in seconds.""" + + rate_normalize: bool = True + """If True, divide counts by bin_duration to emit events/second.""" + + axis: str = "time" + """Input sample axis.""" + + use_mlx_metal: bool = True + """If True, MLX inputs use the fused on-device Metal implementation.""" + + +@processor_state +class ThresholdCrossingRateState: + prev_over: Any = None + """Whether the previous sample was over threshold for each feature.""" + + elapsed: Any = None + """Samples since the last accepted threshold crossing for each feature.""" + + overflow_counts: Any = None + """Raw counts in the current partial output bin for each feature.""" + + n_overflow: int = 0 + """Number of input samples in the current partial output bin.""" + + refrac_width: int = 0 + samples_per_bin: int = 0 + + +class ThresholdCrossingRateTransformer( + BaseStatefulTransformer[ + ThresholdCrossingRateSettings, + AxisArray, + AxisArray, + ThresholdCrossingRateState, + ] +): + """Count threshold crossings directly into dense rate bins. + + This transformer covers the simple threshold-crossing case used by dense + preprocessing pipelines: crossing-aligned events, no peak-value payload, + no peak-duration filtering, and no sparse.COO boundary. It preserves exact + refractory behavior while allowing MLX inputs to remain on device through + a fused Metal path. + """ + + def _hash_message(self, message: AxisArray) -> int: + ax_idx = message.get_axis_idx(self.settings.axis) + sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :] + return hash((message.key, sample_shape, message.axes[self.settings.axis].gain)) + + def _reset_state(self, message: AxisArray) -> None: + xp = get_namespace(message.data) + ax_idx = message.get_axis_idx(self.settings.axis) + feature_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :] + + fs = 1.0 / message.axes[self.settings.axis].gain + self._state.refrac_width = int(self.settings.refrac_dur * fs) + self._state.samples_per_bin = int(self.settings.bin_duration * fs) + if self._state.samples_per_bin < 1: + raise ValueError( + f"bin_duration={self.settings.bin_duration} is shorter than one sample at fs={fs:g} Hz" + ) + + self._state.prev_over = None + self._state.elapsed = xp.full( + feature_shape, + self._state.refrac_width + 1, + dtype=xp.int32, + ) + self._state.overflow_counts = xp.zeros(feature_shape, dtype=xp.float32) + self._state.n_overflow = 0 + + def _process(self, message: AxisArray) -> AxisArray: + xp = get_namespace(message.data) + ax_idx = message.get_axis_idx(self.settings.axis) + + if ax_idx != 0: + perm = (ax_idx,) + tuple(i for i in range(message.data.ndim) if i != ax_idx) + message = replace( + message, + data=xp.permute_dims(message.data, perm), + dims=[self.settings.axis] + message.dims[:ax_idx] + message.dims[ax_idx + 1 :], + ) + + n_samples = message.data.shape[0] + n_prev_overflow = self._state.n_overflow + n_total = n_prev_overflow + n_samples + n_bins = n_total // self._state.samples_per_bin + self._state.n_overflow = n_total - n_bins * self._state.samples_per_bin + + if n_samples == 0: + feature_shape = message.data.shape[1:] + out_data = xp.zeros((n_bins,) + feature_shape, dtype=xp.float32) + elif ( + self.settings.use_mlx_metal + and not is_numpy_array(message.data) + and getattr(xp, "__name__", "") == "mlx.core" + ): + out_data = self._process_mlx(message.data, n_prev_overflow, n_bins) + else: + out_data = self._process_numpy(message.data, n_prev_overflow, n_bins) + + time_axis = message.axes[self.settings.axis] + out_offset = time_axis.offset if n_bins == 0 else time_axis.offset - n_prev_overflow * time_axis.gain + out_axis = replace( + time_axis, + gain=self.settings.bin_duration, + offset=out_offset, + ) + return replace( + message, + data=out_data, + axes={**message.axes, self.settings.axis: out_axis}, + ) + + def _process_numpy(self, data, n_prev_overflow: int, n_bins: int): + data_np = data if is_numpy_array(data) else np.asarray(data) + n_samples = data_np.shape[0] + feature_shape = data_np.shape[1:] + n_features = int(np.prod(feature_shape, dtype=np.int64)) if feature_shape else 1 + flat = data_np.reshape(n_samples, n_features) + + prev_over = self._state.prev_over + if prev_over is None: + prev_flat = _initial_prev_over(flat, self.settings.threshold) + else: + prev_flat = np.asarray(prev_over, dtype=bool).reshape(n_features) + + elapsed_flat = np.asarray(self._state.elapsed, dtype=np.int32).reshape(n_features).copy() + overflow_flat = np.asarray(self._state.overflow_counts, dtype=np.float32).reshape(n_features).copy() + + out = np.zeros((n_bins, n_features), dtype=np.float32) + if n_bins > 0: + out[0] = overflow_flat + overflow_flat.fill(0.0) + + for samp_ix in range(n_samples): + sample = flat[samp_ix] + if self.settings.threshold >= 0: + over = sample >= self.settings.threshold + else: + over = sample <= self.settings.threshold + crossing = over & ~prev_flat + prev_flat = over + + elapsed_flat += 1 + if self._state.refrac_width > 2: + accepted = crossing & (elapsed_flat > self._state.refrac_width) + else: + accepted = crossing + if np.any(accepted): + bin_ix = (n_prev_overflow + samp_ix) // self._state.samples_per_bin + accepted_f32 = accepted.astype(np.float32) + if bin_ix < n_bins: + out[bin_ix] += accepted_f32 + else: + overflow_flat += accepted_f32 + elapsed_flat[accepted] = 0 + + if self.settings.rate_normalize: + out /= self.settings.bin_duration + + self._state.prev_over = prev_flat.reshape(feature_shape) + self._state.elapsed = elapsed_flat.reshape(feature_shape) + self._state.overflow_counts = overflow_flat.reshape(feature_shape) + return out.reshape((n_bins,) + feature_shape) + + def _process_mlx(self, data, n_prev_overflow: int, n_bins: int): + import mlx.core as mx + + from ezmsg.event.util.threshold_rate_mlx_metal import threshold_crossing_rate_mlx_metal + + if self._state.prev_over is None: + first = data[0] + if self.settings.threshold >= 0: + over = first >= self.settings.threshold + else: + over = first <= self.settings.threshold + self._state.prev_over = over.astype(mx.uint32) + + self._state.elapsed = mx.asarray(self._state.elapsed, dtype=mx.int32) + self._state.overflow_counts = mx.asarray(self._state.overflow_counts, dtype=mx.float32) + self._state.prev_over = mx.asarray(self._state.prev_over, dtype=mx.uint32) + + out, self._state.prev_over, self._state.elapsed, self._state.overflow_counts = ( + threshold_crossing_rate_mlx_metal( + data, + self._state.prev_over, + self._state.elapsed, + self._state.overflow_counts, + threshold=self.settings.threshold, + refrac_width=self._state.refrac_width, + n_overflow=n_prev_overflow, + samples_per_bin=self._state.samples_per_bin, + n_bins=n_bins, + bin_duration=self.settings.bin_duration, + rate_normalize=self.settings.rate_normalize, + ) + ) + return out + + +def _initial_prev_over(flat: np.ndarray, threshold: float) -> np.ndarray: + n_features = flat.shape[1] + if flat.shape[0] == 0: + return np.zeros(n_features, dtype=bool) + first = flat[0] + return first >= threshold if threshold >= 0 else first <= threshold + + +class ThresholdCrossingRate( + BaseTransformerUnit[ + ThresholdCrossingRateSettings, + AxisArray, + AxisArray, + ThresholdCrossingRateTransformer, + ] +): + """Unit for dense threshold-crossing event rates.""" + + SETTINGS = ThresholdCrossingRateSettings diff --git a/src/ezmsg/event/util/threshold_rate_mlx_metal.py b/src/ezmsg/event/util/threshold_rate_mlx_metal.py new file mode 100644 index 0000000..fd14dad --- /dev/null +++ b/src/ezmsg/event/util/threshold_rate_mlx_metal.py @@ -0,0 +1,158 @@ +"""Fused threshold-crossing rate calculation on Apple Silicon via MLX + Metal.""" + +import mlx.core as mx + + +def threshold_crossing_rate_mlx_metal( + x, + prev_over, + elapsed, + overflow_counts, + *, + threshold: float, + refrac_width: int, + n_overflow: int, + samples_per_bin: int, + n_bins: int, + bin_duration: float, + rate_normalize: bool, +): + """Compute threshold-crossing rates for ``x`` with time on axis 0. + + Args: + x: MLX array with shape ``(n_samples, *features)``. + prev_over: UInt32 MLX array with shape ``(*features,)``; nonzero + indicates whether the sample before this chunk was over threshold. + elapsed: Int32 MLX array with shape ``(*features,)`` tracking samples + since the last accepted crossing. + overflow_counts: Float32 MLX array with shape ``(*features,)`` holding + raw counts in the partial output bin carried from previous chunks. + threshold: Threshold crossing level. + refrac_width: Refractory duration in samples. A crossing is accepted + only when the distance from the previous accepted crossing is + greater than this value. + n_overflow: Number of input samples already accumulated in the current + partial output bin. + samples_per_bin: Number of input samples per output bin. + n_bins: Number of complete output bins produced by this chunk. + bin_duration: Output bin duration in seconds. + rate_normalize: If true, output events/second; otherwise raw counts. + + Returns: + ``(rates, prev_over, elapsed, overflow_counts)``. ``rates`` has shape + ``(n_bins, *features)`` and the state outputs are shaped like the state + inputs. + """ + if x.ndim < 1: + raise ValueError(f"x must have at least 1 dimension; got {x.ndim}") + if samples_per_bin < 1: + raise ValueError(f"samples_per_bin must be >= 1; got {samples_per_bin}") + if n_bins < 0: + raise ValueError(f"n_bins must be >= 0; got {n_bins}") + + x_f32 = x.astype(mx.float32) if x.dtype != mx.float32 else x + batch_shape = tuple(x_f32.shape[1:]) + n_samples = x_f32.shape[0] + n_channels = 1 + for dim in batch_shape: + n_channels *= dim + + x_flat = x_f32.reshape(n_samples, n_channels) + prev_flat = prev_over.astype(mx.uint32).reshape(n_channels) + elapsed_flat = elapsed.astype(mx.int32).reshape(n_channels) + overflow_flat = overflow_counts.astype(mx.float32).reshape(n_channels) + params = mx.array( + [float(threshold), float(1.0 / bin_duration if rate_normalize else 1.0)], + dtype=mx.float32, + ) + + # Metal kernels cannot emit a zero-size output on all MLX versions. Use a + # one-bin scratch output and slice it away for chunks with no complete bin. + n_output_bins = max(n_bins, 1) + rates_flat, prev_out, elapsed_out, overflow_out = _kernel( + inputs=[x_flat, prev_flat, elapsed_flat, overflow_flat, params], + template=[ + ("N_SAMPLES", n_samples), + ("N_CHANNELS", n_channels), + ("N_BINS", n_bins), + ("N_OUTPUT_BINS", n_output_bins), + ("N_OVERFLOW", n_overflow), + ("SAMPLES_PER_BIN", samples_per_bin), + ("REFRAC_WIDTH", refrac_width), + ], + grid=(n_channels, 1, 1), + threadgroup=(1, 1, 1), + output_shapes=[ + (n_output_bins, n_channels), + (n_channels,), + (n_channels,), + (n_channels,), + ], + output_dtypes=[mx.float32, mx.uint32, mx.int32, mx.float32], + ) + + rates_flat = rates_flat[:n_bins] + rates = rates_flat.reshape((n_bins,) + batch_shape) + return ( + rates, + prev_out.reshape(batch_shape), + elapsed_out.reshape(batch_shape), + overflow_out.reshape(batch_shape), + ) + + +_KERNEL_SOURCE = r""" + uint ch = thread_position_in_grid.x; + if (ch >= N_CHANNELS) { + return; + } + + uint prev = prev_over_in[ch]; + int elapsed = elapsed_in[ch]; + + for (uint bin = 0; bin < N_OUTPUT_BINS; ++bin) { + rates_out[bin * N_CHANNELS + ch] = 0.0f; + } + + float overflow = overflow_counts_in[ch]; + if (N_BINS > 0) { + rates_out[ch] = overflow; + overflow = 0.0f; + } + + for (uint t = 0; t < N_SAMPLES; ++t) { + float sample = x_in[t * N_CHANNELS + ch]; + float threshold = params[0]; + uint over = threshold >= 0.0f ? (sample >= threshold) : (sample <= threshold); + uint crossing = over && !prev; + prev = over; + + elapsed += 1; + if (crossing && (REFRAC_WIDTH <= 2 || elapsed > REFRAC_WIDTH)) { + uint sample_in_bin_stream = N_OVERFLOW + t; + uint bin = sample_in_bin_stream / SAMPLES_PER_BIN; + if (bin < N_BINS) { + rates_out[bin * N_CHANNELS + ch] += 1.0f; + } else { + overflow += 1.0f; + } + elapsed = 0; + } + } + + for (uint bin = 0; bin < N_BINS; ++bin) { + rates_out[bin * N_CHANNELS + ch] *= params[1]; + } + + prev_over_out[ch] = prev; + elapsed_out[ch] = elapsed; + overflow_counts_out[ch] = overflow; +""" + + +_kernel = mx.fast.metal_kernel( + name="threshold_crossing_rate", + input_names=["x_in", "prev_over_in", "elapsed_in", "overflow_counts_in", "params"], + output_names=["rates_out", "prev_over_out", "elapsed_out", "overflow_counts_out"], + source=_KERNEL_SOURCE, +) diff --git a/tests/test_threshold_rate.py b/tests/test_threshold_rate.py new file mode 100644 index 0000000..2228e4f --- /dev/null +++ b/tests/test_threshold_rate.py @@ -0,0 +1,206 @@ +import numpy as np +from ezmsg.util.messages.axisarray import AxisArray + +from ezmsg.event.peak import ThresholdCrossingTransformer +from ezmsg.event.rate import EventRateSettings, Rate +from ezmsg.event.threshold_rate import ThresholdCrossingRateSettings, ThresholdCrossingRateTransformer + + +def _make_msg(data: np.ndarray, fs: float, offset: float, dims: list[str] | None = None) -> AxisArray: + dims = dims or ["time", "ch"] + return AxisArray( + data=data, + dims=dims, + axes={ + "time": AxisArray.TimeAxis(fs=fs, offset=offset), + "ch": AxisArray.CoordinateAxis( + data=np.arange(data.shape[dims.index("ch")]), + dims=["ch"], + ), + }, + ) + + +def _run_sparse_reference( + chunks: list[np.ndarray], + *, + fs: float, + threshold: float, + refrac_dur: float, + bin_duration: float, + dims: list[str] | None = None, +) -> list[AxisArray]: + thresh = ThresholdCrossingTransformer(threshold=threshold, refrac_dur=refrac_dur) + rate = Rate(EventRateSettings(bin_duration=bin_duration)) + out = [] + samp_offset = 0 + for chunk in chunks: + msg = _make_msg(chunk, fs, samp_offset / fs, dims=dims) + out.append(rate(thresh(msg))) + samp_offset += chunk.shape[dims.index("time") if dims else 0] + return out + + +def _run_dense_fused( + chunks: list[np.ndarray], + *, + fs: float, + threshold: float, + refrac_dur: float, + bin_duration: float, + dims: list[str] | None = None, +) -> list[AxisArray]: + proc = ThresholdCrossingRateTransformer( + ThresholdCrossingRateSettings( + threshold=threshold, + refrac_dur=refrac_dur, + bin_duration=bin_duration, + use_mlx_metal=False, + ) + ) + out = [] + samp_offset = 0 + for chunk in chunks: + msg = _make_msg(chunk, fs, samp_offset / fs, dims=dims) + out.append(proc(msg)) + samp_offset += chunk.shape[dims.index("time") if dims else 0] + return out + + +def _assert_messages_match(actual: list[AxisArray], expected: list[AxisArray]) -> None: + assert len(actual) == len(expected) + for actual_msg, expected_msg in zip(actual, expected): + assert actual_msg.dims == expected_msg.dims + assert actual_msg.data.shape == expected_msg.data.shape + assert actual_msg.axes["time"].gain == expected_msg.axes["time"].gain + assert np.isclose(actual_msg.axes["time"].offset, expected_msg.axes["time"].offset) + np.testing.assert_allclose(actual_msg.data, expected_msg.data) + + +def test_threshold_crossing_rate_matches_sparse_pipeline_with_refractory_and_overflow(): + fs = 1000.0 + threshold = 1.0 + refrac_dur = 0.006 + bin_duration = 0.010 + data = np.zeros((43, 3), dtype=np.float32) + + # Channel 0 exercises greedy refractory: sample 5 is dropped after sample 1, + # but sample 9 is accepted because dropped crossings do not extend refractory. + for samp, ch in [ + (0, 2), # first sample has no prior reference, so it should not count + (1, 0), + (5, 0), + (9, 0), + (12, 1), + (19, 1), + (20, 1), + (30, 2), + (38, 2), + ]: + data[samp, ch] = 2.0 + + chunks = [data[:4], data[4:12], data[12:19], data[19:]] + expected = _run_sparse_reference( + chunks, + fs=fs, + threshold=threshold, + refrac_dur=refrac_dur, + bin_duration=bin_duration, + ) + actual = _run_dense_fused( + chunks, + fs=fs, + threshold=threshold, + refrac_dur=refrac_dur, + bin_duration=bin_duration, + ) + + _assert_messages_match(actual, expected) + + +def test_threshold_crossing_rate_matches_sparse_pipeline_for_negative_threshold(): + fs = 1000.0 + threshold = -1.0 + refrac_dur = 0.004 + bin_duration = 0.010 + data = np.zeros((31, 2), dtype=np.float32) + for samp, ch in [(0, 0), (3, 0), (8, 0), (12, 1), (17, 1), (24, 1)]: + data[samp, ch] = -2.0 + + chunks = [data[:6], data[6:16], data[16:]] + expected = _run_sparse_reference( + chunks, + fs=fs, + threshold=threshold, + refrac_dur=refrac_dur, + bin_duration=bin_duration, + ) + actual = _run_dense_fused( + chunks, + fs=fs, + threshold=threshold, + refrac_dur=refrac_dur, + bin_duration=bin_duration, + ) + + _assert_messages_match(actual, expected) + + +def test_threshold_crossing_rate_supports_nonzero_time_axis(): + fs = 1000.0 + threshold = 1.0 + refrac_dur = 0.004 + bin_duration = 0.010 + data = np.zeros((2, 25), dtype=np.float32) + data[0, [2, 8, 14]] = 2.0 + data[1, [5, 11, 21]] = 2.0 + + chunks = [data[:, :9], data[:, 9:17], data[:, 17:]] + dims = ["ch", "time"] + expected = _run_sparse_reference( + chunks, + fs=fs, + threshold=threshold, + refrac_dur=refrac_dur, + bin_duration=bin_duration, + dims=dims, + ) + actual = _run_dense_fused( + chunks, + fs=fs, + threshold=threshold, + refrac_dur=refrac_dur, + bin_duration=bin_duration, + dims=dims, + ) + + _assert_messages_match(actual, expected) + + +def test_threshold_crossing_rate_empty_time_first_and_midstream(): + fs = 1000.0 + proc = ThresholdCrossingRateTransformer( + threshold=1.0, + refrac_dur=0.004, + bin_duration=0.010, + use_mlx_metal=False, + ) + + empty = _make_msg(np.zeros((0, 2), dtype=np.float32), fs, 0.0) + out_empty = proc(empty) + assert out_empty.data.shape == (0, 2) + + data = np.zeros((20, 2), dtype=np.float32) + data[0, 0] = 2.0 + data[3, 0] = 2.0 + data[11, 1] = 2.0 + out_data = proc(_make_msg(data, fs, 0.0)) + + ref_thresh = ThresholdCrossingTransformer(threshold=1.0, refrac_dur=0.004) + ref_rate = Rate(EventRateSettings(bin_duration=0.010)) + ref_rate(ref_thresh(empty)) + expected = ref_rate(ref_thresh(_make_msg(data, fs, 0.0))) + np.testing.assert_allclose(out_data.data, expected.data) + + mid_empty = proc(_make_msg(np.zeros((0, 2), dtype=np.float32), fs, 0.020)) + assert mid_empty.data.shape == (0, 2) From 8e6a0b74ebef2ed69d6effe9562580e71f83c8b5 Mon Sep 17 00:00:00 2001 From: kylmcgr Date: Mon, 27 Apr 2026 18:57:24 -0600 Subject: [PATCH 2/4] float samples per bin --- src/ezmsg/event/threshold_rate.py | 43 +++++++++++-------- .../event/util/threshold_rate_mlx_metal.py | 39 +++++++++++------ tests/test_threshold_rate.py | 38 ++++++++++++++++ 3 files changed, 89 insertions(+), 31 deletions(-) diff --git a/src/ezmsg/event/threshold_rate.py b/src/ezmsg/event/threshold_rate.py index 151200c..8a50598 100644 --- a/src/ezmsg/event/threshold_rate.py +++ b/src/ezmsg/event/threshold_rate.py @@ -40,11 +40,11 @@ class ThresholdCrossingRateState: overflow_counts: Any = None """Raw counts in the current partial output bin for each feature.""" - n_overflow: int = 0 - """Number of input samples in the current partial output bin.""" + bin_accumulator: float = 0.0 + """Fractional number of input samples in the current partial output bin.""" refrac_width: int = 0 - samples_per_bin: int = 0 + samples_per_bin: float = 0.0 class ThresholdCrossingRateTransformer( @@ -76,8 +76,8 @@ def _reset_state(self, message: AxisArray) -> None: fs = 1.0 / message.axes[self.settings.axis].gain self._state.refrac_width = int(self.settings.refrac_dur * fs) - self._state.samples_per_bin = int(self.settings.bin_duration * fs) - if self._state.samples_per_bin < 1: + self._state.samples_per_bin = self.settings.bin_duration * fs + if self._state.samples_per_bin < 1.0: raise ValueError( f"bin_duration={self.settings.bin_duration} is shorter than one sample at fs={fs:g} Hz" ) @@ -89,7 +89,7 @@ def _reset_state(self, message: AxisArray) -> None: dtype=xp.int32, ) self._state.overflow_counts = xp.zeros(feature_shape, dtype=xp.float32) - self._state.n_overflow = 0 + self._state.bin_accumulator = 0.0 def _process(self, message: AxisArray) -> AxisArray: xp = get_namespace(message.data) @@ -104,10 +104,10 @@ def _process(self, message: AxisArray) -> AxisArray: ) n_samples = message.data.shape[0] - n_prev_overflow = self._state.n_overflow - n_total = n_prev_overflow + n_samples - n_bins = n_total // self._state.samples_per_bin - self._state.n_overflow = n_total - n_bins * self._state.samples_per_bin + accumulator_before = self._state.bin_accumulator + n_total = accumulator_before + n_samples + n_bins = int(n_total / self._state.samples_per_bin) + self._state.bin_accumulator = n_total - n_bins * self._state.samples_per_bin if n_samples == 0: feature_shape = message.data.shape[1:] @@ -117,12 +117,12 @@ def _process(self, message: AxisArray) -> AxisArray: and not is_numpy_array(message.data) and getattr(xp, "__name__", "") == "mlx.core" ): - out_data = self._process_mlx(message.data, n_prev_overflow, n_bins) + out_data = self._process_mlx(message.data, accumulator_before, n_bins) else: - out_data = self._process_numpy(message.data, n_prev_overflow, n_bins) + out_data = self._process_numpy(message.data, accumulator_before, n_bins) time_axis = message.axes[self.settings.axis] - out_offset = time_axis.offset if n_bins == 0 else time_axis.offset - n_prev_overflow * time_axis.gain + out_offset = time_axis.offset if n_bins == 0 else time_axis.offset - accumulator_before * time_axis.gain out_axis = replace( time_axis, gain=self.settings.bin_duration, @@ -134,12 +134,21 @@ def _process(self, message: AxisArray) -> AxisArray: axes={**message.axes, self.settings.axis: out_axis}, ) - def _process_numpy(self, data, n_prev_overflow: int, n_bins: int): + def _process_numpy(self, data, accumulator_before: float, n_bins: int): data_np = data if is_numpy_array(data) else np.asarray(data) n_samples = data_np.shape[0] feature_shape = data_np.shape[1:] n_features = int(np.prod(feature_shape, dtype=np.int64)) if feature_shape else 1 flat = data_np.reshape(n_samples, n_features) + bin_end_samples = ( + ( + self._state.samples_per_bin + - accumulator_before + + np.arange(n_bins, dtype=np.float64) * self._state.samples_per_bin + ).astype(np.int64) + if n_bins > 0 + else np.empty((0,), dtype=np.int64) + ) prev_over = self._state.prev_over if prev_over is None: @@ -170,7 +179,7 @@ def _process_numpy(self, data, n_prev_overflow: int, n_bins: int): else: accepted = crossing if np.any(accepted): - bin_ix = (n_prev_overflow + samp_ix) // self._state.samples_per_bin + bin_ix = np.searchsorted(bin_end_samples, samp_ix, side="right") accepted_f32 = accepted.astype(np.float32) if bin_ix < n_bins: out[bin_ix] += accepted_f32 @@ -186,7 +195,7 @@ def _process_numpy(self, data, n_prev_overflow: int, n_bins: int): self._state.overflow_counts = overflow_flat.reshape(feature_shape) return out.reshape((n_bins,) + feature_shape) - def _process_mlx(self, data, n_prev_overflow: int, n_bins: int): + def _process_mlx(self, data, accumulator_before: float, n_bins: int): import mlx.core as mx from ezmsg.event.util.threshold_rate_mlx_metal import threshold_crossing_rate_mlx_metal @@ -211,7 +220,7 @@ def _process_mlx(self, data, n_prev_overflow: int, n_bins: int): self._state.overflow_counts, threshold=self.settings.threshold, refrac_width=self._state.refrac_width, - n_overflow=n_prev_overflow, + bin_accumulator=accumulator_before, samples_per_bin=self._state.samples_per_bin, n_bins=n_bins, bin_duration=self.settings.bin_duration, diff --git a/src/ezmsg/event/util/threshold_rate_mlx_metal.py b/src/ezmsg/event/util/threshold_rate_mlx_metal.py index fd14dad..6a6a648 100644 --- a/src/ezmsg/event/util/threshold_rate_mlx_metal.py +++ b/src/ezmsg/event/util/threshold_rate_mlx_metal.py @@ -11,8 +11,8 @@ def threshold_crossing_rate_mlx_metal( *, threshold: float, refrac_width: int, - n_overflow: int, - samples_per_bin: int, + bin_accumulator: float, + samples_per_bin: float, n_bins: int, bin_duration: float, rate_normalize: bool, @@ -31,9 +31,9 @@ def threshold_crossing_rate_mlx_metal( refrac_width: Refractory duration in samples. A crossing is accepted only when the distance from the previous accepted crossing is greater than this value. - n_overflow: Number of input samples already accumulated in the current - partial output bin. - samples_per_bin: Number of input samples per output bin. + bin_accumulator: Fractional number of input samples already accumulated + in the current partial output bin. + samples_per_bin: Fractional number of input samples per output bin. n_bins: Number of complete output bins produced by this chunk. bin_duration: Output bin duration in seconds. rate_normalize: If true, output events/second; otherwise raw counts. @@ -45,8 +45,8 @@ def threshold_crossing_rate_mlx_metal( """ if x.ndim < 1: raise ValueError(f"x must have at least 1 dimension; got {x.ndim}") - if samples_per_bin < 1: - raise ValueError(f"samples_per_bin must be >= 1; got {samples_per_bin}") + if samples_per_bin < 1.0: + raise ValueError(f"samples_per_bin must be >= 1.0; got {samples_per_bin}") if n_bins < 0: raise ValueError(f"n_bins must be >= 0; got {n_bins}") @@ -62,7 +62,12 @@ def threshold_crossing_rate_mlx_metal( elapsed_flat = elapsed.astype(mx.int32).reshape(n_channels) overflow_flat = overflow_counts.astype(mx.float32).reshape(n_channels) params = mx.array( - [float(threshold), float(1.0 / bin_duration if rate_normalize else 1.0)], + [ + float(threshold), + float(1.0 / bin_duration if rate_normalize else 1.0), + float(samples_per_bin - bin_accumulator), + float(samples_per_bin), + ], dtype=mx.float32, ) @@ -76,8 +81,6 @@ def threshold_crossing_rate_mlx_metal( ("N_CHANNELS", n_channels), ("N_BINS", n_bins), ("N_OUTPUT_BINS", n_output_bins), - ("N_OVERFLOW", n_overflow), - ("SAMPLES_PER_BIN", samples_per_bin), ("REFRAC_WIDTH", refrac_width), ], grid=(n_channels, 1, 1), @@ -120,7 +123,17 @@ def threshold_crossing_rate_mlx_metal( overflow = 0.0f; } + uint active_bin = 0; + uint active_bin_end = N_BINS > 0 ? uint(params[2]) : 0; + for (uint t = 0; t < N_SAMPLES; ++t) { + while (active_bin < N_BINS && t >= active_bin_end) { + active_bin += 1; + if (active_bin < N_BINS) { + active_bin_end = uint(params[2] + float(active_bin) * params[3]); + } + } + float sample = x_in[t * N_CHANNELS + ch]; float threshold = params[0]; uint over = threshold >= 0.0f ? (sample >= threshold) : (sample <= threshold); @@ -129,10 +142,8 @@ def threshold_crossing_rate_mlx_metal( elapsed += 1; if (crossing && (REFRAC_WIDTH <= 2 || elapsed > REFRAC_WIDTH)) { - uint sample_in_bin_stream = N_OVERFLOW + t; - uint bin = sample_in_bin_stream / SAMPLES_PER_BIN; - if (bin < N_BINS) { - rates_out[bin * N_CHANNELS + ch] += 1.0f; + if (active_bin < N_BINS) { + rates_out[active_bin * N_CHANNELS + ch] += 1.0f; } else { overflow += 1.0f; } diff --git a/tests/test_threshold_rate.py b/tests/test_threshold_rate.py index 2228e4f..3921bb6 100644 --- a/tests/test_threshold_rate.py +++ b/tests/test_threshold_rate.py @@ -146,6 +146,44 @@ def test_threshold_crossing_rate_matches_sparse_pipeline_for_negative_threshold( _assert_messages_match(actual, expected) +def test_threshold_crossing_rate_matches_sparse_pipeline_for_fractional_samples_per_bin(): + fs = 30012.0048 + threshold = 1.0 + refrac_dur = 0.0 + bin_duration = 0.020 + data = np.zeros((5000, 2), dtype=np.float32) + for samp, ch in [ + (599, 0), + (600, 1), + (1199, 0), + (1200, 1), + (2399, 0), + (2400, 1), + (3000, 0), + (3001, 1), + (3602, 0), + ]: + data[samp, ch] = 2.0 + + chunks = [data[:777], data[777:1310], data[1310:2450], data[2450:3800], data[3800:]] + expected = _run_sparse_reference( + chunks, + fs=fs, + threshold=threshold, + refrac_dur=refrac_dur, + bin_duration=bin_duration, + ) + actual = _run_dense_fused( + chunks, + fs=fs, + threshold=threshold, + refrac_dur=refrac_dur, + bin_duration=bin_duration, + ) + + _assert_messages_match(actual, expected) + + def test_threshold_crossing_rate_supports_nonzero_time_axis(): fs = 1000.0 threshold = 1.0 From cc9b7d7f956a1255f8ff6033ea946441ef575184 Mon Sep 17 00:00:00 2001 From: kylmcgr Date: Tue, 28 Apr 2026 10:36:45 -0600 Subject: [PATCH 3/4] mlx update --- src/ezmsg/event/threshold_rate.py | 4 +-- .../event/util/threshold_rate_mlx_metal.py | 28 ++++++++----------- 2 files changed, 13 insertions(+), 19 deletions(-) diff --git a/src/ezmsg/event/threshold_rate.py b/src/ezmsg/event/threshold_rate.py index 8a50598..4bd7d21 100644 --- a/src/ezmsg/event/threshold_rate.py +++ b/src/ezmsg/event/threshold_rate.py @@ -206,11 +206,11 @@ def _process_mlx(self, data, accumulator_before: float, n_bins: int): over = first >= self.settings.threshold else: over = first <= self.settings.threshold - self._state.prev_over = over.astype(mx.uint32) + self._state.prev_over = over.astype(mx.int8) self._state.elapsed = mx.asarray(self._state.elapsed, dtype=mx.int32) self._state.overflow_counts = mx.asarray(self._state.overflow_counts, dtype=mx.float32) - self._state.prev_over = mx.asarray(self._state.prev_over, dtype=mx.uint32) + self._state.prev_over = mx.asarray(self._state.prev_over, dtype=mx.int8) out, self._state.prev_over, self._state.elapsed, self._state.overflow_counts = ( threshold_crossing_rate_mlx_metal( diff --git a/src/ezmsg/event/util/threshold_rate_mlx_metal.py b/src/ezmsg/event/util/threshold_rate_mlx_metal.py index 6a6a648..7a82819 100644 --- a/src/ezmsg/event/util/threshold_rate_mlx_metal.py +++ b/src/ezmsg/event/util/threshold_rate_mlx_metal.py @@ -21,7 +21,7 @@ def threshold_crossing_rate_mlx_metal( Args: x: MLX array with shape ``(n_samples, *features)``. - prev_over: UInt32 MLX array with shape ``(*features,)``; nonzero + prev_over: Int8 MLX array with shape ``(*features,)``; nonzero indicates whether the sample before this chunk was over threshold. elapsed: Int32 MLX array with shape ``(*features,)`` tracking samples since the last accepted crossing. @@ -58,7 +58,7 @@ def threshold_crossing_rate_mlx_metal( n_channels *= dim x_flat = x_f32.reshape(n_samples, n_channels) - prev_flat = prev_over.astype(mx.uint32).reshape(n_channels) + prev_flat = prev_over.astype(mx.int8).reshape(n_channels) elapsed_flat = elapsed.astype(mx.int32).reshape(n_channels) overflow_flat = overflow_counts.astype(mx.float32).reshape(n_channels) params = mx.array( @@ -91,7 +91,7 @@ def threshold_crossing_rate_mlx_metal( (n_channels,), (n_channels,), ], - output_dtypes=[mx.float32, mx.uint32, mx.int32, mx.float32], + output_dtypes=[mx.float32, mx.int8, mx.int32, mx.float32], ) rates_flat = rates_flat[:n_bins] @@ -110,7 +110,7 @@ def threshold_crossing_rate_mlx_metal( return; } - uint prev = prev_over_in[ch]; + uint prev = prev_over_in[ch] != 0; int elapsed = elapsed_in[ch]; for (uint bin = 0; bin < N_OUTPUT_BINS; ++bin) { @@ -123,27 +123,21 @@ def threshold_crossing_rate_mlx_metal( overflow = 0.0f; } - uint active_bin = 0; - uint active_bin_end = N_BINS > 0 ? uint(params[2]) : 0; + float threshold = params[0]; + float first_bin_end = params[2]; + float samples_per_bin = params[3]; for (uint t = 0; t < N_SAMPLES; ++t) { - while (active_bin < N_BINS && t >= active_bin_end) { - active_bin += 1; - if (active_bin < N_BINS) { - active_bin_end = uint(params[2] + float(active_bin) * params[3]); - } - } - float sample = x_in[t * N_CHANNELS + ch]; - float threshold = params[0]; uint over = threshold >= 0.0f ? (sample >= threshold) : (sample <= threshold); uint crossing = over && !prev; prev = over; elapsed += 1; if (crossing && (REFRAC_WIDTH <= 2 || elapsed > REFRAC_WIDTH)) { - if (active_bin < N_BINS) { - rates_out[active_bin * N_CHANNELS + ch] += 1.0f; + int active_bin = int(ceil((float(t) + 1.0f - first_bin_end) / samples_per_bin)); + if (active_bin >= 0 && active_bin < N_BINS) { + rates_out[uint(active_bin) * N_CHANNELS + ch] += 1.0f; } else { overflow += 1.0f; } @@ -155,7 +149,7 @@ def threshold_crossing_rate_mlx_metal( rates_out[bin * N_CHANNELS + ch] *= params[1]; } - prev_over_out[ch] = prev; + prev_over_out[ch] = prev ? 1 : 0; elapsed_out[ch] = elapsed; overflow_counts_out[ch] = overflow; """ From 39e936c811eafb40a83633ee24af7c55f1c43248 Mon Sep 17 00:00:00 2001 From: kylmcgr Date: Tue, 28 Apr 2026 10:47:05 -0600 Subject: [PATCH 4/4] Metal speed up --- .../event/util/threshold_rate_mlx_metal.py | 150 ++++++++++++++---- tests/test_threshold_rate.py | 113 +++++++++++++ 2 files changed, 234 insertions(+), 29 deletions(-) diff --git a/src/ezmsg/event/util/threshold_rate_mlx_metal.py b/src/ezmsg/event/util/threshold_rate_mlx_metal.py index 7a82819..498b9af 100644 --- a/src/ezmsg/event/util/threshold_rate_mlx_metal.py +++ b/src/ezmsg/event/util/threshold_rate_mlx_metal.py @@ -57,10 +57,21 @@ def threshold_crossing_rate_mlx_metal( for dim in batch_shape: n_channels *= dim + if n_samples == 0: + rates = mx.zeros((n_bins,) + batch_shape, dtype=mx.float32) + return ( + rates, + prev_over.astype(mx.int8).reshape(batch_shape), + elapsed.astype(mx.int32).reshape(batch_shape), + overflow_counts.astype(mx.float32).reshape(batch_shape), + ) + x_flat = x_f32.reshape(n_samples, n_channels) prev_flat = prev_over.astype(mx.int8).reshape(n_channels) elapsed_flat = elapsed.astype(mx.int32).reshape(n_channels) overflow_flat = overflow_counts.astype(mx.float32).reshape(n_channels) + n_words = (n_samples + 31) // 32 + params = mx.array( [ float(threshold), @@ -71,14 +82,28 @@ def threshold_crossing_rate_mlx_metal( dtype=mx.float32, ) + crossing_words, final_over = _crossing_words_kernel( + inputs=[x_flat, prev_flat, params], + template=[ + ("N_SAMPLES", n_samples), + ("N_CHANNELS", n_channels), + ("N_WORDS", n_words), + ], + grid=(n_words, n_channels, 1), + threadgroup=(1, 1, 1), + output_shapes=[(n_words, n_channels), (n_channels,)], + output_dtypes=[mx.uint32, mx.int8], + ) + # Metal kernels cannot emit a zero-size output on all MLX versions. Use a # one-bin scratch output and slice it away for chunks with no complete bin. n_output_bins = max(n_bins, 1) - rates_flat, prev_out, elapsed_out, overflow_out = _kernel( - inputs=[x_flat, prev_flat, elapsed_flat, overflow_flat, params], + rates_flat, prev_out, elapsed_out, overflow_out = _refractory_words_kernel( + inputs=[crossing_words, final_over, elapsed_flat, overflow_flat, params], template=[ ("N_SAMPLES", n_samples), ("N_CHANNELS", n_channels), + ("N_WORDS", n_words), ("N_BINS", n_bins), ("N_OUTPUT_BINS", n_output_bins), ("REFRAC_WIDTH", refrac_width), @@ -104,60 +129,127 @@ def threshold_crossing_rate_mlx_metal( ) -_KERNEL_SOURCE = r""" +_CROSSING_WORDS_KERNEL_SOURCE = r""" + uint word = thread_position_in_grid.x; + uint ch = thread_position_in_grid.y; + if (word >= N_WORDS || ch >= N_CHANNELS) { + return; + } + + float threshold = params[0]; + uint start = word * 32; + uint prev = 0; + if (start == 0) { + prev = prev_over_in[ch] != 0; + } else { + float prev_sample = x_in[(start - 1) * N_CHANNELS + ch]; + prev = threshold >= 0.0f ? (prev_sample >= threshold) : (prev_sample <= threshold); + } + + uint bits = 0; + for (uint bit = 0; bit < 32; ++bit) { + uint t = start + bit; + if (t >= N_SAMPLES) { + break; + } + + float sample = x_in[t * N_CHANNELS + ch]; + uint over = threshold >= 0.0f ? (sample >= threshold) : (sample <= threshold); + if (over && !prev) { + bits |= (1u << bit); + } + prev = over; + } + crossing_words_out[word * N_CHANNELS + ch] = bits; + + if (word == N_WORDS - 1) { + final_over_out[ch] = prev ? 1 : 0; + } +""" + + +_REFRACTORY_WORDS_KERNEL_SOURCE = r""" uint ch = thread_position_in_grid.x; if (ch >= N_CHANNELS) { return; } - uint prev = prev_over_in[ch] != 0; - int elapsed = elapsed_in[ch]; + float scale = params[1]; + float first_bin_end = params[2]; + float samples_per_bin = params[3]; + float overflow = overflow_counts_in[ch]; for (uint bin = 0; bin < N_OUTPUT_BINS; ++bin) { rates_out[bin * N_CHANNELS + ch] = 0.0f; } - - float overflow = overflow_counts_in[ch]; if (N_BINS > 0) { rates_out[ch] = overflow; overflow = 0.0f; } - float threshold = params[0]; - float first_bin_end = params[2]; - float samples_per_bin = params[3]; + int elapsed = elapsed_in[ch]; + int last_t = -1; + + for (uint word = 0; word < N_WORDS; ++word) { + uint bits = crossing_words_in[word * N_CHANNELS + ch]; + while (bits != 0) { + uint bit = 0; + uint mask = 1u; + while ((bits & mask) == 0u) { + bit += 1; + mask <<= 1; + } + bits &= ~mask; - for (uint t = 0; t < N_SAMPLES; ++t) { - float sample = x_in[t * N_CHANNELS + ch]; - uint over = threshold >= 0.0f ? (sample >= threshold) : (sample <= threshold); - uint crossing = over && !prev; - prev = over; + uint t = word * 32 + bit; + if (t >= N_SAMPLES) { + break; + } - elapsed += 1; - if (crossing && (REFRAC_WIDTH <= 2 || elapsed > REFRAC_WIDTH)) { - int active_bin = int(ceil((float(t) + 1.0f - first_bin_end) / samples_per_bin)); - if (active_bin >= 0 && active_bin < N_BINS) { - rates_out[uint(active_bin) * N_CHANNELS + ch] += 1.0f; - } else { - overflow += 1.0f; + elapsed += int(t) - last_t; + last_t = int(t); + + if (REFRAC_WIDTH <= 2 || elapsed > REFRAC_WIDTH) { + int active_bin = int(ceil((float(t) + 1.0f - first_bin_end) / samples_per_bin)); + if (active_bin >= 0 && active_bin < N_BINS) { + rates_out[uint(active_bin) * N_CHANNELS + ch] += 1.0f; + } else { + overflow += 1.0f; + } + elapsed = 0; } - elapsed = 0; } } + elapsed += int(N_SAMPLES) - 1 - last_t; + for (uint bin = 0; bin < N_BINS; ++bin) { - rates_out[bin * N_CHANNELS + ch] *= params[1]; + rates_out[bin * N_CHANNELS + ch] *= scale; } - prev_over_out[ch] = prev ? 1 : 0; + prev_over_out[ch] = final_over_in[ch]; elapsed_out[ch] = elapsed; overflow_counts_out[ch] = overflow; """ -_kernel = mx.fast.metal_kernel( - name="threshold_crossing_rate", - input_names=["x_in", "prev_over_in", "elapsed_in", "overflow_counts_in", "params"], +_crossing_words_kernel = mx.fast.metal_kernel( + name="threshold_crossing_words", + input_names=["x_in", "prev_over_in", "params"], + output_names=["crossing_words_out", "final_over_out"], + source=_CROSSING_WORDS_KERNEL_SOURCE, +) + + +_refractory_words_kernel = mx.fast.metal_kernel( + name="threshold_refractory_words", + input_names=[ + "crossing_words_in", + "final_over_in", + "elapsed_in", + "overflow_counts_in", + "params", + ], output_names=["rates_out", "prev_over_out", "elapsed_out", "overflow_counts_out"], - source=_KERNEL_SOURCE, + source=_REFRACTORY_WORDS_KERNEL_SOURCE, ) diff --git a/tests/test_threshold_rate.py b/tests/test_threshold_rate.py index 3921bb6..11dfa77 100644 --- a/tests/test_threshold_rate.py +++ b/tests/test_threshold_rate.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from ezmsg.util.messages.axisarray import AxisArray from ezmsg.event.peak import ThresholdCrossingTransformer @@ -67,6 +68,16 @@ def _run_dense_fused( return out +def _require_mlx_metal(): + mx = pytest.importorskip("mlx.core") + try: + probe = mx.array([1.0], dtype=mx.float32) + mx.eval(probe) + except RuntimeError as exc: + pytest.skip(f"MLX Metal device unavailable: {exc}") + return mx + + def _assert_messages_match(actual: list[AxisArray], expected: list[AxisArray]) -> None: assert len(actual) == len(expected) for actual_msg, expected_msg in zip(actual, expected): @@ -242,3 +253,105 @@ def test_threshold_crossing_rate_empty_time_first_and_midstream(): mid_empty = proc(_make_msg(np.zeros((0, 2), dtype=np.float32), fs, 0.020)) assert mid_empty.data.shape == (0, 2) + + +def test_threshold_crossing_rate_mlx_metal_matches_cpu_for_adversarial_refractory(): + mx = _require_mlx_metal() + fs = 1000.0 + threshold = 1.0 + refrac_dur = 0.030 + bin_duration = 0.100 + data = np.zeros((1000, 4), dtype=np.float32) + + for ch in range(data.shape[1]): + for samp in range(ch + 1, data.shape[0], 31): + data[samp, ch] = 2.0 + + chunks = [data[:137], data[137:503], data[503:777], data[777:]] + expected = _run_dense_fused( + chunks, + fs=fs, + threshold=threshold, + refrac_dur=refrac_dur, + bin_duration=bin_duration, + ) + + proc = ThresholdCrossingRateTransformer( + ThresholdCrossingRateSettings( + threshold=threshold, + refrac_dur=refrac_dur, + bin_duration=bin_duration, + use_mlx_metal=True, + ) + ) + actual = [] + samp_offset = 0 + for chunk in chunks: + msg = _make_msg(mx.array(chunk), fs, samp_offset / fs) + out = proc(msg) + mx.eval( + out.data, + proc._state.prev_over, + proc._state.elapsed, + proc._state.overflow_counts, + ) + actual.append(out) + samp_offset += chunk.shape[0] + + _assert_messages_match(actual, expected) + + +def test_threshold_crossing_rate_mlx_metal_matches_cpu_for_negative_fractional_bins(): + mx = _require_mlx_metal() + fs = 30012.0048 + threshold = -1.0 + refrac_dur = 0.001 + bin_duration = 0.020 + data = np.zeros((5000, 3), dtype=np.float32) + for samp, ch in [ + (599, 0), + (600, 1), + (631, 1), + (1199, 0), + (1200, 2), + (1230, 2), + (2399, 0), + (2400, 1), + (3000, 0), + (3001, 2), + (3602, 1), + ]: + data[samp, ch] = -2.0 + + chunks = [data[:777], data[777:1310], data[1310:2450], data[2450:3800], data[3800:]] + expected = _run_dense_fused( + chunks, + fs=fs, + threshold=threshold, + refrac_dur=refrac_dur, + bin_duration=bin_duration, + ) + + proc = ThresholdCrossingRateTransformer( + ThresholdCrossingRateSettings( + threshold=threshold, + refrac_dur=refrac_dur, + bin_duration=bin_duration, + use_mlx_metal=True, + ) + ) + actual = [] + samp_offset = 0 + for chunk in chunks: + msg = _make_msg(mx.array(chunk), fs, samp_offset / fs) + out = proc(msg) + mx.eval( + out.data, + proc._state.prev_over, + proc._state.elapsed, + proc._state.overflow_counts, + ) + actual.append(out) + samp_offset += chunk.shape[0] + + _assert_messages_match(actual, expected)