"""Dense fused threshold-crossing event-rate transformer."""

# ---------------------------------------------------------------------------
# File: src/ezmsg/event/threshold_rate.py
# ---------------------------------------------------------------------------

from typing import Any

import ezmsg.core as ez
import numpy as np
from array_api_compat import get_namespace, is_numpy_array
from ezmsg.baseproc import BaseStatefulTransformer, BaseTransformerUnit, processor_state
from ezmsg.util.messages.axisarray import AxisArray, replace


class ThresholdCrossingRateSettings(ez.Settings):
    threshold: float = -3.5
    """The value the signal must cross to count an event."""

    refrac_dur: float = 0.001
    """Minimum duration between counted threshold crossings in seconds."""

    bin_duration: float = 0.05
    """Output bin duration in seconds."""

    rate_normalize: bool = True
    """If True, divide counts by bin_duration to emit events/second."""

    axis: str = "time"
    """Input sample axis."""

    use_mlx_metal: bool = True
    """If True, MLX inputs use the fused on-device Metal implementation."""


@processor_state
class ThresholdCrossingRateState:
    prev_over: Any = None
    """Whether the previous sample was over threshold for each feature."""

    elapsed: Any = None
    """Samples since the last accepted threshold crossing for each feature.

    The counter saturates one step above ``refrac_width``; only the comparison
    ``elapsed > refrac_width`` is ever evaluated, so larger values carry no
    extra information and saturating keeps the int32 counter from wrapping.
    """

    overflow_counts: Any = None
    """Raw counts in the current partial output bin for each feature."""

    bin_accumulator: float = 0.0
    """Fractional number of input samples in the current partial output bin."""

    refrac_width: int = 0
    samples_per_bin: float = 0.0


def _over_threshold(values, threshold: float):
    """Return the element-wise "over threshold" mask for ``values``.

    A non-negative threshold counts samples at or above it; a negative
    threshold counts samples at or below it. Works on any array type that
    supports ``>=`` / ``<=`` against a Python float.
    """
    return values >= threshold if threshold >= 0 else values <= threshold


class ThresholdCrossingRateTransformer(
    BaseStatefulTransformer[
        ThresholdCrossingRateSettings,
        AxisArray,
        AxisArray,
        ThresholdCrossingRateState,
    ]
):
    """Count threshold crossings directly into dense rate bins.

    This transformer covers the simple threshold-crossing case used by dense
    preprocessing pipelines: crossing-aligned events, no peak-value payload,
    no peak-duration filtering, and no sparse.COO boundary. It preserves exact
    refractory behavior while allowing MLX inputs to remain on device through
    a fused Metal path.
    """

    def _hash_message(self, message: AxisArray) -> int:
        """Hash the message key, the non-time shape and the sample period."""
        ax_idx = message.get_axis_idx(self.settings.axis)
        sample_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]
        return hash((message.key, sample_shape, message.axes[self.settings.axis].gain))

    def _reset_state(self, message: AxisArray) -> None:
        """Derive sample-domain constants from ``message`` and clear all carry-over state.

        Raises:
            ValueError: if ``bin_duration`` spans less than one input sample.
        """
        xp = get_namespace(message.data)
        ax_idx = message.get_axis_idx(self.settings.axis)
        feature_shape = message.data.shape[:ax_idx] + message.data.shape[ax_idx + 1 :]

        fs = 1.0 / message.axes[self.settings.axis].gain
        self._state.refrac_width = int(self.settings.refrac_dur * fs)
        self._state.samples_per_bin = self.settings.bin_duration * fs
        if self._state.samples_per_bin < 1.0:
            raise ValueError(
                f"bin_duration={self.settings.bin_duration} is shorter than one sample at fs={fs:g} Hz"
            )

        # ``prev_over`` is seeded lazily from the first sample that arrives.
        self._state.prev_over = None
        # Start every feature outside its refractory window.
        self._state.elapsed = xp.full(
            feature_shape,
            self._state.refrac_width + 1,
            dtype=xp.int32,
        )
        self._state.overflow_counts = xp.zeros(feature_shape, dtype=xp.float32)
        self._state.bin_accumulator = 0.0

    def _process(self, message: AxisArray) -> AxisArray:
        """Turn one chunk into zero or more complete rate bins.

        The sample axis is moved to the front if needed, the number of bins
        completed by this chunk is computed from the fractional accumulator,
        and the counting is delegated to the Metal or NumPy implementation.
        The returned message has the sample axis first, with its gain set to
        ``bin_duration``.
        """
        xp = get_namespace(message.data)
        ax_idx = message.get_axis_idx(self.settings.axis)

        if ax_idx != 0:
            perm = (ax_idx,) + tuple(i for i in range(message.data.ndim) if i != ax_idx)
            message = replace(
                message,
                data=xp.permute_dims(message.data, perm),
                dims=[self.settings.axis] + message.dims[:ax_idx] + message.dims[ax_idx + 1 :],
            )

        n_samples = message.data.shape[0]
        accumulator_before = self._state.bin_accumulator
        n_total = accumulator_before + n_samples
        n_bins = int(n_total / self._state.samples_per_bin)
        self._state.bin_accumulator = n_total - n_bins * self._state.samples_per_bin

        if n_samples == 0:
            feature_shape = message.data.shape[1:]
            out_data = xp.zeros((n_bins,) + feature_shape, dtype=xp.float32)
        elif (
            self.settings.use_mlx_metal
            and not is_numpy_array(message.data)
            and getattr(xp, "__name__", "") == "mlx.core"
        ):
            out_data = self._process_mlx(message.data, accumulator_before, n_bins)
        else:
            out_data = self._process_numpy(message.data, accumulator_before, n_bins)

        time_axis = message.axes[self.settings.axis]
        # When a bin completes, it started ``accumulator_before`` samples
        # before this chunk, so shift the offset back by that many samples.
        out_offset = time_axis.offset if n_bins == 0 else time_axis.offset - accumulator_before * time_axis.gain
        out_axis = replace(
            time_axis,
            gain=self.settings.bin_duration,
            offset=out_offset,
        )
        return replace(
            message,
            data=out_data,
            axes={**message.axes, self.settings.axis: out_axis},
        )

    def _process_numpy(self, data, accumulator_before: float, n_bins: int):
        """Count accepted crossings per output bin on the CPU.

        Args:
            data: array with the sample axis first; non-NumPy arrays are
                converted with ``np.asarray``.
            accumulator_before: fractional samples already in the partial bin
                before this chunk.
            n_bins: number of bins completed by this chunk.

        Returns:
            float32 NumPy array of shape ``(n_bins, *features)`` holding
            counts, or events/second when ``rate_normalize`` is set.
        """
        data_np = data if is_numpy_array(data) else np.asarray(data)
        n_samples = data_np.shape[0]
        feature_shape = data_np.shape[1:]
        n_features = int(np.prod(feature_shape, dtype=np.int64)) if feature_shape else 1
        flat = data_np.reshape(n_samples, n_features)
        refrac_width = self._state.refrac_width

        # Truncated sample index at which each output bin ends.
        bin_end_samples = (
            (
                self._state.samples_per_bin
                - accumulator_before
                + np.arange(n_bins, dtype=np.float64) * self._state.samples_per_bin
            ).astype(np.int64)
            if n_bins > 0
            else np.empty((0,), dtype=np.int64)
        )

        # Threshold polarity is constant for the chunk, so evaluate the
        # comparison once for all samples instead of once per loop iteration.
        over = _over_threshold(flat, self.settings.threshold)

        if self._state.prev_over is None:
            # No history yet: the first sample is its own reference and can
            # therefore never register as a crossing.
            prev_flat = over[0] if n_samples else np.zeros(n_features, dtype=bool)
        else:
            prev_flat = np.asarray(self._state.prev_over, dtype=bool).reshape(n_features)

        crossing = np.empty_like(over)
        if n_samples:
            crossing[0] = over[0] & ~prev_flat
            crossing[1:] = over[1:] & ~over[:-1]
            # Copy so the state does not keep the whole chunk mask alive.
            prev_flat = over[-1].copy()

        # Work in int64 so adding the chunk length cannot wrap.
        elapsed = np.asarray(self._state.elapsed, dtype=np.int64).reshape(n_features).copy()
        overflow_flat = np.asarray(self._state.overflow_counts, dtype=np.float32).reshape(n_features).copy()

        out = np.zeros((n_bins, n_features), dtype=np.float32)
        if n_bins > 0:
            # Counts carried from earlier chunks belong to the first bin completed now.
            out[0] = overflow_flat
            overflow_flat.fill(0.0)

        # Widths of two samples or fewer disable the refractory test.
        apply_refractory = refrac_width > 2
        last_ix = -1
        # Only samples holding at least one crossing can change the counts, so
        # visit those and advance ``elapsed`` by the gap since the last visit.
        for samp_ix in np.flatnonzero(crossing.any(axis=1)):
            samp_ix = int(samp_ix)
            elapsed += samp_ix - last_ix
            last_ix = samp_ix

            accepted = crossing[samp_ix]
            if apply_refractory:
                accepted = accepted & (elapsed > refrac_width)
                if not accepted.any():
                    continue

            bin_ix = int(np.searchsorted(bin_end_samples, samp_ix, side="right"))
            accepted_f32 = accepted.astype(np.float32)
            if bin_ix < n_bins:
                out[bin_ix] += accepted_f32
            else:
                # Past the last complete bin: carry into the next chunk.
                overflow_flat += accepted_f32
            # Rejected crossings do not restart the refractory window.
            elapsed[accepted] = 0

        elapsed += n_samples - 1 - last_ix
        # Saturate just above the refractory width so the stored int32 counter
        # cannot wrap negative on a long stream without accepted crossings.
        np.minimum(elapsed, max(refrac_width, 0) + 1, out=elapsed)

        if self.settings.rate_normalize:
            out /= self.settings.bin_duration

        self._state.prev_over = prev_flat.reshape(feature_shape)
        self._state.elapsed = elapsed.astype(np.int32).reshape(feature_shape)
        self._state.overflow_counts = overflow_flat.reshape(feature_shape)
        return out.reshape((n_bins,) + feature_shape)

    def _process_mlx(self, data, accumulator_before: float, n_bins: int):
        """Count accepted crossings per output bin with the fused Metal kernels.

        ``data`` is an MLX array with the sample axis first and at least one
        sample; ``_process`` handles empty chunks before dispatching here. The
        carried state is converted to MLX arrays, passed to the kernel wrapper
        and replaced by the arrays it returns.
        """
        # Imported lazily so the module stays importable without MLX installed.
        import mlx.core as mx

        from ezmsg.event.util.threshold_rate_mlx_metal import threshold_crossing_rate_mlx_metal

        if self._state.prev_over is None:
            # Seed the history from the first sample so it cannot count as a crossing.
            self._state.prev_over = _over_threshold(data[0], self.settings.threshold).astype(mx.int8)

        self._state.elapsed = mx.asarray(self._state.elapsed, dtype=mx.int32)
        self._state.overflow_counts = mx.asarray(self._state.overflow_counts, dtype=mx.float32)
        self._state.prev_over = mx.asarray(self._state.prev_over, dtype=mx.int8)

        out, self._state.prev_over, self._state.elapsed, self._state.overflow_counts = (
            threshold_crossing_rate_mlx_metal(
                data,
                self._state.prev_over,
                self._state.elapsed,
                self._state.overflow_counts,
                threshold=self.settings.threshold,
                refrac_width=self._state.refrac_width,
                bin_accumulator=accumulator_before,
                samples_per_bin=self._state.samples_per_bin,
                n_bins=n_bins,
                bin_duration=self.settings.bin_duration,
                rate_normalize=self.settings.rate_normalize,
            )
        )
        return out


def _initial_prev_over(flat: np.ndarray, threshold: float) -> np.ndarray:
    """Return the "previous sample over threshold" mask for a stream with no history.

    The first row of ``flat`` (shape ``(n_samples, n_features)``) is used as
    its own reference; an empty chunk yields an all-False mask.
    """
    n_features = flat.shape[1]
    if flat.shape[0] == 0:
        return np.zeros(n_features, dtype=bool)
    return _over_threshold(flat[0], threshold)


class ThresholdCrossingRate(
    BaseTransformerUnit[
        ThresholdCrossingRateSettings,
        AxisArray,
        AxisArray,
        ThresholdCrossingRateTransformer,
    ]
):
    """Unit for dense threshold-crossing event rates."""

    SETTINGS = ThresholdCrossingRateSettings


# ---------------------------------------------------------------------------
# File: src/ezmsg/event/util/threshold_rate_mlx_metal.py
# ---------------------------------------------------------------------------
"""Fused threshold-crossing rate calculation on Apple Silicon via MLX + Metal."""

import mlx.core as mx


def threshold_crossing_rate_mlx_metal(
    x,
    prev_over,
    elapsed,
    overflow_counts,
    *,
    threshold: float,
    refrac_width: int,
    bin_accumulator: float,
    samples_per_bin: float,
    n_bins: int,
    bin_duration: float,
    rate_normalize: bool,
):
    """Compute threshold-crossing rates for ``x`` with time on axis 0.

    Args:
        x: MLX array with shape ``(n_samples, *features)``.
        prev_over: Int8 MLX array with shape ``(*features,)``; nonzero
            indicates whether the sample before this chunk was over threshold.
        elapsed: Int32 MLX array with shape ``(*features,)`` tracking samples
            since the last accepted crossing.
        overflow_counts: Float32 MLX array with shape ``(*features,)`` holding
            raw counts in the partial output bin carried from previous chunks.
        threshold: Threshold crossing level.
        refrac_width: Refractory duration in samples. A crossing is accepted
            only when the distance from the previous accepted crossing is
            greater than this value.
        bin_accumulator: Fractional number of input samples already accumulated
            in the current partial output bin.
        samples_per_bin: Fractional number of input samples per output bin.
        n_bins: Number of complete output bins produced by this chunk.
        bin_duration: Output bin duration in seconds.
        rate_normalize: If true, output events/second; otherwise raw counts.

    Returns:
        ``(rates, prev_over, elapsed, overflow_counts)``. ``rates`` has shape
        ``(n_bins, *features)`` and the state outputs are shaped like the state
        inputs. The returned ``elapsed`` is saturated one step above
        ``refrac_width`` so the int32 counter cannot overflow across chunks.

    Raises:
        ValueError: if ``x`` is zero-dimensional, ``samples_per_bin`` is below
            one, or ``n_bins`` is negative.
    """
    if x.ndim < 1:
        raise ValueError(f"x must have at least 1 dimension; got {x.ndim}")
    if samples_per_bin < 1.0:
        raise ValueError(f"samples_per_bin must be >= 1.0; got {samples_per_bin}")
    if n_bins < 0:
        raise ValueError(f"n_bins must be >= 0; got {n_bins}")

    x_f32 = x.astype(mx.float32) if x.dtype != mx.float32 else x
    batch_shape = tuple(x_f32.shape[1:])
    n_samples = x_f32.shape[0]
    n_channels = 1
    for dim in batch_shape:
        n_channels *= dim

    # Nothing to scan: return the state unchanged and avoid launching a
    # kernel over an empty grid (no samples, or no feature channels).
    if n_samples == 0 or n_channels == 0:
        rates = mx.zeros((n_bins,) + batch_shape, dtype=mx.float32)
        return (
            rates,
            prev_over.astype(mx.int8).reshape(batch_shape),
            elapsed.astype(mx.int32).reshape(batch_shape),
            overflow_counts.astype(mx.float32).reshape(batch_shape),
        )

    x_flat = x_f32.reshape(n_samples, n_channels)
    prev_flat = prev_over.astype(mx.int8).reshape(n_channels)
    elapsed_flat = elapsed.astype(mx.int32).reshape(n_channels)
    overflow_flat = overflow_counts.astype(mx.float32).reshape(n_channels)
    # Crossings are packed 32 samples per uint32 word.
    n_words = (n_samples + 31) // 32

    # Kernel parameters: [threshold, output scale, end of first bin, samples per bin].
    # NOTE(review): these are float32 while the NumPy path bins in float64;
    # confirm bin-edge agreement for very long chunks.
    params = mx.array(
        [
            float(threshold),
            float(1.0 / bin_duration if rate_normalize else 1.0),
            float(samples_per_bin - bin_accumulator),
            float(samples_per_bin),
        ],
        dtype=mx.float32,
    )

    # Pass 1: one thread per (word, channel) marks crossings as bits.
    crossing_words, final_over = _crossing_words_kernel(
        inputs=[x_flat, prev_flat, params],
        template=[
            ("N_SAMPLES", n_samples),
            ("N_CHANNELS", n_channels),
            ("N_WORDS", n_words),
        ],
        grid=(n_words, n_channels, 1),
        threadgroup=(1, 1, 1),
        output_shapes=[(n_words, n_channels), (n_channels,)],
        output_dtypes=[mx.uint32, mx.int8],
    )

    # Metal kernels cannot emit a zero-size output on all MLX versions. Use a
    # one-bin scratch output and slice it away for chunks with no complete bin.
    n_output_bins = max(n_bins, 1)
    # Pass 2: one thread per channel applies the refractory rule sequentially
    # over the crossing bits and accumulates the per-bin counts.
    rates_flat, prev_out, elapsed_out, overflow_out = _refractory_words_kernel(
        inputs=[crossing_words, final_over, elapsed_flat, overflow_flat, params],
        template=[
            ("N_SAMPLES", n_samples),
            ("N_CHANNELS", n_channels),
            ("N_WORDS", n_words),
            ("N_BINS", n_bins),
            ("N_OUTPUT_BINS", n_output_bins),
            ("REFRAC_WIDTH", refrac_width),
        ],
        grid=(n_channels, 1, 1),
        threadgroup=(1, 1, 1),
        output_shapes=[
            (n_output_bins, n_channels),
            (n_channels,),
            (n_channels,),
            (n_channels,),
        ],
        output_dtypes=[mx.float32, mx.int8, mx.int32, mx.float32],
    )

    rates_flat = rates_flat[:n_bins]
    rates = rates_flat.reshape((n_bins,) + batch_shape)
    return (
        rates,
        prev_out.reshape(batch_shape),
        elapsed_out.reshape(batch_shape),
        overflow_out.reshape(batch_shape),
    )


# Pass 1 kernel body. Each thread handles one 32-sample word of one channel:
# it sets bit ``b`` when sample ``start + b`` is over threshold and the sample
# before it is not. The thread owning the last word also records whether the
# final sample was over threshold.
_CROSSING_WORDS_KERNEL_SOURCE = r"""
    uint word = thread_position_in_grid.x;
    uint ch = thread_position_in_grid.y;
    if (word >= N_WORDS || ch >= N_CHANNELS) {
        return;
    }

    float threshold = params[0];
    uint start = word * 32;
    uint prev = 0;
    if (start == 0) {
        prev = prev_over_in[ch] != 0;
    } else {
        float prev_sample = x_in[(start - 1) * N_CHANNELS + ch];
        prev = threshold >= 0.0f ? (prev_sample >= threshold) : (prev_sample <= threshold);
    }

    uint bits = 0;
    for (uint bit = 0; bit < 32; ++bit) {
        uint t = start + bit;
        if (t >= N_SAMPLES) {
            break;
        }

        float sample = x_in[t * N_CHANNELS + ch];
        uint over = threshold >= 0.0f ? (sample >= threshold) : (sample <= threshold);
        if (over && !prev) {
            bits |= (1u << bit);
        }
        prev = over;
    }
    crossing_words_out[word * N_CHANNELS + ch] = bits;

    if (word == N_WORDS - 1) {
        final_over_out[ch] = prev ? 1 : 0;
    }
"""


# Pass 2 kernel body. Each thread walks one channel's crossing bits in time
# order, applies the refractory rule, adds accepted crossings to their bin (or
# to the carried overflow), scales the bins, and writes back the state.
_REFRACTORY_WORDS_KERNEL_SOURCE = r"""
    uint ch = thread_position_in_grid.x;
    if (ch >= N_CHANNELS) {
        return;
    }

    float scale = params[1];
    float first_bin_end = params[2];
    float samples_per_bin = params[3];

    float overflow = overflow_counts_in[ch];
    for (uint bin = 0; bin < N_OUTPUT_BINS; ++bin) {
        rates_out[bin * N_CHANNELS + ch] = 0.0f;
    }
    if (N_BINS > 0) {
        rates_out[ch] = overflow;
        overflow = 0.0f;
    }

    int elapsed = elapsed_in[ch];
    int last_t = -1;

    for (uint word = 0; word < N_WORDS; ++word) {
        uint bits = crossing_words_in[word * N_CHANNELS + ch];
        while (bits != 0) {
            // Locate and clear the lowest set bit.
            uint bit = 0;
            uint mask = 1u;
            while ((bits & mask) == 0u) {
                bit += 1;
                mask <<= 1;
            }
            bits &= ~mask;

            uint t = word * 32 + bit;
            if (t >= N_SAMPLES) {
                break;
            }

            elapsed += int(t) - last_t;
            last_t = int(t);

            if (REFRAC_WIDTH <= 2 || elapsed > REFRAC_WIDTH) {
                int active_bin = int(ceil((float(t) + 1.0f - first_bin_end) / samples_per_bin));
                if (active_bin >= 0 && active_bin < N_BINS) {
                    rates_out[uint(active_bin) * N_CHANNELS + ch] += 1.0f;
                } else {
                    overflow += 1.0f;
                }
                elapsed = 0;
            }
        }
    }

    elapsed += int(N_SAMPLES) - 1 - last_t;

    // Saturate just above the refractory width: only "elapsed > REFRAC_WIDTH"
    // is ever tested, and an unbounded int32 counter would eventually wrap.
    int elapsed_cap = (REFRAC_WIDTH > 0 ? REFRAC_WIDTH : 0) + 1;
    if (elapsed > elapsed_cap) {
        elapsed = elapsed_cap;
    }

    for (uint bin = 0; bin < N_BINS; ++bin) {
        rates_out[bin * N_CHANNELS + ch] *= scale;
    }

    prev_over_out[ch] = final_over_in[ch];
    elapsed_out[ch] = elapsed;
    overflow_counts_out[ch] = overflow;
"""


_crossing_words_kernel = mx.fast.metal_kernel(
    name="threshold_crossing_words",
    input_names=["x_in", "prev_over_in", "params"],
    output_names=["crossing_words_out", "final_over_out"],
    source=_CROSSING_WORDS_KERNEL_SOURCE,
)


_refractory_words_kernel = mx.fast.metal_kernel(
    name="threshold_refractory_words",
    input_names=[
        "crossing_words_in",
        "final_over_in",
        "elapsed_in",
        "overflow_counts_in",
        "params",
    ],
    output_names=["rates_out", "prev_over_out", "elapsed_out", "overflow_counts_out"],
    source=_REFRACTORY_WORDS_KERNEL_SOURCE,
)
# ---------------------------------------------------------------------------
# File: tests/test_threshold_rate.py
# ---------------------------------------------------------------------------
import numpy as np
import pytest
from ezmsg.util.messages.axisarray import AxisArray

from ezmsg.event.peak import ThresholdCrossingTransformer
from ezmsg.event.rate import EventRateSettings, Rate
from ezmsg.event.threshold_rate import ThresholdCrossingRateSettings, ThresholdCrossingRateTransformer


def _make_msg(data: np.ndarray, fs: float, offset: float, dims: list[str] | None = None) -> AxisArray:
    """Wrap ``data`` in an AxisArray with a time axis and an integer-labelled ``ch`` axis."""
    if not dims:
        dims = ["time", "ch"]
    n_ch = data.shape[dims.index("ch")]
    axes = {
        "time": AxisArray.TimeAxis(fs=fs, offset=offset),
        "ch": AxisArray.CoordinateAxis(data=np.arange(n_ch), dims=["ch"]),
    }
    return AxisArray(data=data, dims=dims, axes=axes)


def _run_sparse_reference(
    chunks: list[np.ndarray],
    *,
    fs: float,
    threshold: float,
    refrac_dur: float,
    bin_duration: float,
    dims: list[str] | None = None,
) -> list[AxisArray]:
    """Feed ``chunks`` through the sparse threshold detector followed by the rate binner.

    Each chunk's time offset is the number of samples already sent divided by ``fs``.
    """
    detector = ThresholdCrossingTransformer(threshold=threshold, refrac_dur=refrac_dur)
    binner = Rate(EventRateSettings(bin_duration=bin_duration))
    results = []
    n_sent = 0
    for block in chunks:
        events = detector(_make_msg(block, fs, n_sent / fs, dims=dims))
        results.append(binner(events))
        n_sent += block.shape[dims.index("time") if dims else 0]
    return results


def _run_dense_fused(
    chunks: list[np.ndarray],
    *,
    fs: float,
    threshold: float,
    refrac_dur: float,
    bin_duration: float,
    dims: list[str] | None = None,
) -> list[AxisArray]:
    """Feed ``chunks`` through the fused transformer with the Metal path disabled.

    Each chunk's time offset is the number of samples already sent divided by ``fs``.
    """
    settings = ThresholdCrossingRateSettings(
        threshold=threshold,
        refrac_dur=refrac_dur,
        bin_duration=bin_duration,
        use_mlx_metal=False,
    )
    fused = ThresholdCrossingRateTransformer(settings)
    results = []
    n_sent = 0
    for block in chunks:
        results.append(fused(_make_msg(block, fs, n_sent / fs, dims=dims)))
        n_sent += block.shape[dims.index("time") if dims else 0]
    return results
def _require_mlx_metal():
    """Return ``mlx.core``, skipping the calling test when MLX is missing or evaluation fails."""
    mx = pytest.importorskip("mlx.core")
    try:
        mx.eval(mx.array([1.0], dtype=mx.float32))
    except RuntimeError as exc:
        pytest.skip(f"MLX Metal device unavailable: {exc}")
    return mx


def _assert_messages_match(actual: list[AxisArray], expected: list[AxisArray]) -> None:
    """Assert two message lists agree on length, dims, shape, time axis and data."""
    assert len(actual) == len(expected)
    for got, want in zip(actual, expected):
        assert got.dims == want.dims
        assert got.data.shape == want.data.shape
        assert got.axes["time"].gain == want.axes["time"].gain
        assert np.isclose(got.axes["time"].offset, want.axes["time"].offset)
        np.testing.assert_allclose(got.data, want.data)


def _event_data(shape, events, value):
    """Return a float32 zero array of ``shape`` with ``value`` written at each ``(sample, ch)`` event."""
    data = np.zeros(shape, dtype=np.float32)
    rows, cols = np.array(events).T
    data[rows, cols] = value
    return data


def _assert_dense_matches_sparse(chunks, **params) -> None:
    """Run the sparse reference, then the fused CPU path, and compare their outputs."""
    expected = _run_sparse_reference(chunks, **params)
    actual = _run_dense_fused(chunks, **params)
    _assert_messages_match(actual, expected)


def _run_dense_mlx(mx, chunks, *, fs, threshold, refrac_dur, bin_duration):
    """Feed ``chunks`` as MLX arrays through the fused transformer with the Metal path enabled."""
    settings = ThresholdCrossingRateSettings(
        threshold=threshold,
        refrac_dur=refrac_dur,
        bin_duration=bin_duration,
        use_mlx_metal=True,
    )
    proc = ThresholdCrossingRateTransformer(settings)
    results = []
    n_sent = 0
    for block in chunks:
        out = proc(_make_msg(mx.array(block), fs, n_sent / fs))
        # Force evaluation of the output and of the carried state after every chunk.
        mx.eval(
            out.data,
            proc._state.prev_over,
            proc._state.elapsed,
            proc._state.overflow_counts,
        )
        results.append(out)
        n_sent += block.shape[0]
    return results


def test_threshold_crossing_rate_matches_sparse_pipeline_with_refractory_and_overflow():
    # Channel 0 exercises greedy refractory: sample 5 is dropped after sample 1,
    # but sample 9 is accepted because dropped crossings do not extend refractory.
    events = [
        (0, 2),  # first sample has no prior reference, so it should not count
        (1, 0),
        (5, 0),
        (9, 0),
        (12, 1),
        (19, 1),
        (20, 1),
        (30, 2),
        (38, 2),
    ]
    data = _event_data((43, 3), events, 2.0)
    _assert_dense_matches_sparse(
        np.split(data, [4, 12, 19]),
        fs=1000.0,
        threshold=1.0,
        refrac_dur=0.006,
        bin_duration=0.010,
    )


def test_threshold_crossing_rate_matches_sparse_pipeline_for_negative_threshold():
    events = [(0, 0), (3, 0), (8, 0), (12, 1), (17, 1), (24, 1)]
    data = _event_data((31, 2), events, -2.0)
    _assert_dense_matches_sparse(
        np.split(data, [6, 16]),
        fs=1000.0,
        threshold=-1.0,
        refrac_dur=0.004,
        bin_duration=0.010,
    )


def test_threshold_crossing_rate_matches_sparse_pipeline_for_fractional_samples_per_bin():
    events = [
        (599, 0),
        (600, 1),
        (1199, 0),
        (1200, 1),
        (2399, 0),
        (2400, 1),
        (3000, 0),
        (3001, 1),
        (3602, 0),
    ]
    data = _event_data((5000, 2), events, 2.0)
    _assert_dense_matches_sparse(
        np.split(data, [777, 1310, 2450, 3800]),
        fs=30012.0048,
        threshold=1.0,
        refrac_dur=0.0,
        bin_duration=0.020,
    )


def test_threshold_crossing_rate_supports_nonzero_time_axis():
    data = np.zeros((2, 25), dtype=np.float32)
    data[0, [2, 8, 14]] = 2.0
    data[1, [5, 11, 21]] = 2.0
    _assert_dense_matches_sparse(
        np.split(data, [9, 17], axis=1),
        fs=1000.0,
        threshold=1.0,
        refrac_dur=0.004,
        bin_duration=0.010,
        dims=["ch", "time"],
    )


def test_threshold_crossing_rate_empty_time_first_and_midstream():
    fs = 1000.0
    proc = ThresholdCrossingRateTransformer(
        threshold=1.0,
        refrac_dur=0.004,
        bin_duration=0.010,
        use_mlx_metal=False,
    )

    # An empty chunk arriving first yields no bins.
    empty = _make_msg(np.zeros((0, 2), dtype=np.float32), fs, 0.0)
    assert proc(empty).data.shape == (0, 2)

    data = np.zeros((20, 2), dtype=np.float32)
    data[0, 0] = 2.0
    data[3, 0] = 2.0
    data[11, 1] = 2.0
    out_data = proc(_make_msg(data, fs, 0.0))

    # The reference pipeline is fed the same empty chunk, then the same data.
    ref_thresh = ThresholdCrossingTransformer(threshold=1.0, refrac_dur=0.004)
    ref_rate = Rate(EventRateSettings(bin_duration=0.010))
    ref_rate(ref_thresh(empty))
    expected = ref_rate(ref_thresh(_make_msg(data, fs, 0.0)))
    np.testing.assert_allclose(out_data.data, expected.data)

    # An empty chunk arriving mid-stream also yields no bins.
    mid_empty = proc(_make_msg(np.zeros((0, 2), dtype=np.float32), fs, 0.020))
    assert mid_empty.data.shape == (0, 2)


def test_threshold_crossing_rate_mlx_metal_matches_cpu_for_adversarial_refractory():
    mx = _require_mlx_metal()
    params = dict(fs=1000.0, threshold=1.0, refrac_dur=0.030, bin_duration=0.100)

    # Channel ``ch`` gets an event every 31 samples starting at sample ``ch + 1``.
    data = np.zeros((1000, 4), dtype=np.float32)
    for ch in range(data.shape[1]):
        data[ch + 1 :: 31, ch] = 2.0

    chunks = np.split(data, [137, 503, 777])
    expected = _run_dense_fused(chunks, **params)
    actual = _run_dense_mlx(mx, chunks, **params)
    _assert_messages_match(actual, expected)


def test_threshold_crossing_rate_mlx_metal_matches_cpu_for_negative_fractional_bins():
    mx = _require_mlx_metal()
    params = dict(fs=30012.0048, threshold=-1.0, refrac_dur=0.001, bin_duration=0.020)
    events = [
        (599, 0),
        (600, 1),
        (631, 1),
        (1199, 0),
        (1200, 2),
        (1230, 2),
        (2399, 0),
        (2400, 1),
        (3000, 0),
        (3001, 2),
        (3602, 1),
    ]
    data = _event_data((5000, 3), events, -2.0)

    chunks = np.split(data, [777, 1310, 2450, 3800])
    expected = _run_dense_fused(chunks, **params)
    actual = _run_dense_mlx(mx, chunks, **params)
    _assert_messages_match(actual, expected)