diff --git a/doc/changes/devel/12445.newfeature.rst b/doc/changes/devel/12445.newfeature.rst new file mode 100644 index 00000000000..ccaef2c2c07 --- /dev/null +++ b/doc/changes/devel/12445.newfeature.rst @@ -0,0 +1 @@ +Add support for multiple raw instances in :func:`mne.preprocessing.compute_average_dev_head_t` by `Eric Larson`_. diff --git a/mne/preprocessing/artifact_detection.py b/mne/preprocessing/artifact_detection.py index 1f3ee7b4946..514eadb00a9 100644 --- a/mne/preprocessing/artifact_detection.py +++ b/mne/preprocessing/artifact_detection.py @@ -32,6 +32,7 @@ _validate_type, logger, verbose, + warn, ) @@ -293,7 +294,8 @@ def annotate_movement( return annot, disp -def compute_average_dev_head_t(raw, pos): +@verbose +def compute_average_dev_head_t(raw, pos, *, verbose=None): """Get new device to head transform based on good segments. Segments starting with "BAD" annotations are not included for calculating @@ -301,19 +303,59 @@ def compute_average_dev_head_t(raw, pos): Parameters ---------- - raw : instance of Raw - Data to compute head position. - pos : array, shape (N, 10) - The position and quaternion parameters from cHPI fitting. + raw : instance of Raw | list of Raw + Data to compute head position. Can be a list containing multiple raw + instances. + pos : array, shape (N, 10) | list of ndarray + The position and quaternion parameters from cHPI fitting. Can be + a list containing multiple position arrays, one per raw instance passed. + %(verbose)s Returns ------- dev_head_t : instance of Transform New ``dev_head_t`` transformation using the averaged good head positions. + + Notes + ----- + .. versionchanged:: 1.7 + Support for multiple raw instances and position arrays was added. """ + # Get weighted head pos trans and rot + if not isinstance(raw, (list, tuple)): + raw = [raw] + if not isinstance(pos, (list, tuple)): + pos = [pos] + if len(pos) != len(raw): + raise ValueError( + f"Number of head positions ({len(pos)}) must match the number of raw " + f"instances ({len(raw)})" + ) + hp = list() + dt = list() + for ri, (r, p) in enumerate(zip(raw, pos)): + _validate_type(r, BaseRaw, f"raw[{ri}]") + _validate_type(p, np.ndarray, f"pos[{ri}]") + hp_, dt_ = _raw_hp_weights(r, p) + hp.append(hp_) + dt.append(dt_) + hp = np.concatenate(hp, axis=0) + dt = np.concatenate(dt, axis=0) + dt /= dt.sum() + best_q = _average_quats(hp[:, 1:4], weights=dt) + trans = np.eye(4) + trans[:3, :3] = quat_to_rot(best_q) + trans[:3, 3] = dt @ hp[:, 4:7] + dist = np.linalg.norm(trans[:3, 3]) + if dist > 1: # less than 1 meter is sane + warn(f"Implausible head position detected: {dist} meters from device origin") + dev_head_t = Transform("meg", "head", trans) + return dev_head_t + + +def _raw_hp_weights(raw, pos): sfreq = raw.info["sfreq"] seg_good = np.ones(len(raw.times)) - trans_pos = np.zeros(3) hp = pos.copy() hp_ts = hp[:, 0] - raw._first_time @@ -353,19 +395,7 @@ def compute_average_dev_head_t(raw, pos): assert (dt >= 0).all() dt = dt / sfreq del seg_good, idx - - # Get weighted head pos trans and rot - trans_pos += np.dot(dt, hp[:, 4:7]) - - rot_qs = hp[:, 1:4] - best_q = _average_quats(rot_qs, weights=dt) - - trans = np.eye(4) - trans[:3, :3] = quat_to_rot(best_q) - trans[:3, 3] = trans_pos / dt.sum() - assert np.linalg.norm(trans[:3, 3]) < 1 # less than 1 meter is sane - dev_head_t = Transform("meg", "head", trans) - return dev_head_t + return hp, dt def _annotations_from_mask(times, mask, annot_name, orig_time=None): diff --git a/mne/preprocessing/tests/test_artifact_detection.py b/mne/preprocessing/tests/test_artifact_detection.py index af01fa4416d..6aa386d0b05 100644 --- a/mne/preprocessing/tests/test_artifact_detection.py +++ b/mne/preprocessing/tests/test_artifact_detection.py @@ -18,6 +18,7 @@ compute_average_dev_head_t, ) from mne.tests.test_annotations import _assert_annotations_equal +from mne.transforms import _angle_dist_between_rigid, quat_to_rot, rot_to_quat data_path = testing.data_path(download=False) sss_path = data_path / "SSS" @@ -35,6 +36,7 @@ def test_movement_annotation_head_correction(meas_date): raw.set_meas_date(None) else: assert meas_date == "orig" + raw_unannot = raw.copy() # Check 5 rotation segments are detected annot_rot, [] = annotate_movement(raw, pos, rotation_velocity_limit=5) @@ -67,7 +69,7 @@ def test_movement_annotation_head_correction(meas_date): _assert_annotations_equal(annot_all_2, annot_all) assert annot_all.orig_time == raw.info["meas_date"] raw.set_annotations(annot_all) - dev_head_t = compute_average_dev_head_t(raw, pos) + dev_head_t = compute_average_dev_head_t(raw, pos)["trans"] dev_head_t_ori = np.array( [ @@ -78,13 +80,83 @@ def test_movement_annotation_head_correction(meas_date): ] ) - assert_allclose(dev_head_t_ori, dev_head_t["trans"], rtol=1e-5, atol=0) + assert_allclose(dev_head_t_ori, dev_head_t, rtol=1e-5, atol=0) + + with pytest.raises(ValueError, match="Number of .* must match .*"): + compute_average_dev_head_t([raw], [pos] * 2) + # Using two identical ones should be identical ... + dev_head_t_double = compute_average_dev_head_t([raw] * 2, [pos] * 2)["trans"] + assert_allclose(dev_head_t, dev_head_t_double) + # ... unannotated and annotated versions differ ... + dev_head_t_unannot = compute_average_dev_head_t(raw_unannot, pos)["trans"] + rot_tol = 1.5e-3 + mov_tol = 1e-3 + assert not np.allclose( + dev_head_t_unannot[:3, :3], + dev_head_t[:3, :3], + atol=rot_tol, + rtol=0, + ) + assert not np.allclose( + dev_head_t_unannot[:3, 3], + dev_head_t[:3, 3], + atol=mov_tol, + rtol=0, + ) + # ... and Averaging the two is close to (but not identical!) to operating on the two + # files. Note they shouldn't be identical because there are more time points + # included in the unannotated version! + dev_head_t_naive = np.eye(4) + dev_head_t_naive[:3, :3] = quat_to_rot( + np.mean( + rot_to_quat(np.array([dev_head_t[:3, :3], dev_head_t_unannot[:3, :3]])), + axis=0, + ) + ) + dev_head_t_naive[:3, 3] = np.mean( + [dev_head_t[:3, 3], dev_head_t_unannot[:3, 3]], axis=0 + ) + dev_head_t_combo = compute_average_dev_head_t([raw, raw_unannot], [pos] * 2)[ + "trans" + ] + unit_kw = dict(distance_units="mm", angle_units="deg") + deg_annot_combo, mm_annot_combo = _angle_dist_between_rigid( + dev_head_t, + dev_head_t_combo, + **unit_kw, + ) + deg_unannot_combo, mm_unannot_combo = _angle_dist_between_rigid( + dev_head_t_unannot, + dev_head_t_combo, + **unit_kw, + ) + deg_annot_unannot, mm_annot_unannot = _angle_dist_between_rigid( + dev_head_t, + dev_head_t_unannot, + **unit_kw, + ) + deg_combo_naive, mm_combo_naive = _angle_dist_between_rigid( + dev_head_t_combo, + dev_head_t_naive, + **unit_kw, + ) + # combo<->naive closer than combo<->annotated closer than annotated<->unannotated + assert 0.05 < deg_combo_naive < deg_annot_combo < deg_annot_unannot < 1.5 + assert 0.1 < mm_combo_naive < mm_annot_combo < mm_annot_unannot < 2 + # combo<->naive closer than combo<->unannotated closer than annotated<->unannotated + assert 0.05 < deg_combo_naive < deg_unannot_combo < deg_annot_unannot < 1.5 + assert 0.12 < mm_combo_naive < mm_unannot_combo < mm_annot_unannot < 2.0 # Smoke test skipping time due to previous annotations. raw.set_annotations(Annotations([raw.times[0]], 0.1, "bad")) annot_dis, _ = annotate_movement(raw, pos, mean_distance_limit=0.02) assert annot_dis.duration.size == 1 + # really far should warn + pos[:, 4] += 5 + with pytest.warns(RuntimeWarning, match="Implausible head position"): + compute_average_dev_head_t(raw, pos) + @testing.requires_testing_data @pytest.mark.parametrize("meas_date", (None, "orig")) diff --git a/mne/preprocessing/tests/test_fine_cal.py b/mne/preprocessing/tests/test_fine_cal.py index 95c9e7d63ba..2b3d4df0e3f 100644 --- a/mne/preprocessing/tests/test_fine_cal.py +++ b/mne/preprocessing/tests/test_fine_cal.py @@ -18,7 +18,7 @@ write_fine_calibration, ) from mne.preprocessing.tests.test_maxwell import _assert_shielding -from mne.transforms import _angle_between_quats, rot_to_quat +from mne.transforms import _angle_dist_between_rigid from mne.utils import object_diff # Define fine calibration filepaths @@ -75,16 +75,17 @@ def test_compute_fine_cal(): orig_trans = _loc_to_coil_trans(orig_locs) want_trans = _loc_to_coil_trans(want_locs) got_trans = _loc_to_coil_trans(got_locs) - dist = np.linalg.norm(got_trans[:, :3, 3] - want_trans[:, :3, 3], axis=1) - assert_allclose(dist, 0.0, atol=1e-6) - dist = np.linalg.norm(got_trans[:, :3, 3] - orig_trans[:, :3, 3], axis=1) - assert_allclose(dist, 0.0, atol=1e-6) - orig_quat = rot_to_quat(orig_trans[:, :3, :3]) - want_quat = rot_to_quat(want_trans[:, :3, :3]) - got_quat = rot_to_quat(got_trans[:, :3, :3]) - want_orig_angles = np.rad2deg(_angle_between_quats(want_quat, orig_quat)) - got_want_angles = np.rad2deg(_angle_between_quats(got_quat, want_quat)) - got_orig_angles = np.rad2deg(_angle_between_quats(got_quat, orig_quat)) + want_orig_angles, want_orig_dist = _angle_dist_between_rigid( + want_trans, orig_trans, angle_units="deg" + ) + got_want_angles, got_want_dist = _angle_dist_between_rigid( + got_trans, want_trans, angle_units="deg" + ) + got_orig_angles, got_orig_dist = _angle_dist_between_rigid( + got_trans, orig_trans, angle_units="deg" + ) + assert_allclose(got_want_dist, 0.0, atol=1e-6) + assert_allclose(got_orig_dist, 0.0, atol=1e-6) for key in ("mag", "grad"): # imb_cals value p = pick_types(raw.info, meg=key, exclude=()) diff --git a/mne/transforms.py b/mne/transforms.py index cb387582ef8..975a4818910 100644 --- a/mne/transforms.py +++ b/mne/transforms.py @@ -1351,6 +1351,28 @@ def _quat_to_affine(quat): return affine +def _affine_to_quat(affine): + assert affine.shape[-2:] == (4, 4) + return np.concatenate( + [rot_to_quat(affine[..., :3, :3]), affine[..., :3, 3]], + axis=-1, + ) + + +def _angle_dist_between_rigid(a, b=None, *, angle_units="rad", distance_units="m"): + a = _affine_to_quat(a) + b = np.zeros(6) if b is None else _affine_to_quat(b) + ang = _angle_between_quats(a[..., :3], b[..., :3]) + dist = np.linalg.norm(a[..., 3:] - b[..., 3:], axis=-1) + assert isinstance(angle_units, str) and angle_units in ("rad", "deg") + if angle_units == "deg": + ang = np.rad2deg(ang) + assert isinstance(distance_units, str) and distance_units in ("m", "mm") + if distance_units == "mm": + dist *= 1e3 + return ang, dist + + def _angle_between_quats(x, y=None): """Compute the ang between two quaternions w/3-element representations.""" # z = conj(x) * y @@ -1839,10 +1861,7 @@ def _compute_volume_registration( # report some useful information if step in ("translation", "rigid"): - dist = np.linalg.norm(reg_affine[:3, 3]) - angle = np.rad2deg( - _angle_between_quats(np.zeros(3), rot_to_quat(reg_affine[:3, :3])) - ) + angle, dist = _angle_dist_between_rigid(reg_affine, angle_units="deg") logger.info(f" Translation: {dist:6.1f} mm") if step == "rigid": logger.info(f" Rotation: {angle:6.1f}°")