Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/devel/12445.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for multiple raw instances in :func:`mne.preprocessing.compute_average_dev_head_t` by `Eric Larson`_.
68 changes: 49 additions & 19 deletions mne/preprocessing/artifact_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
_validate_type,
logger,
verbose,
warn,
)


Expand Down Expand Up @@ -293,27 +294,68 @@ def annotate_movement(
return annot, disp


def compute_average_dev_head_t(raw, pos):
@verbose
def compute_average_dev_head_t(raw, pos, *, verbose=None):
"""Get new device to head transform based on good segments.

Segments starting with "BAD" annotations are not included for calculating
the mean head position.

Parameters
----------
raw : instance of Raw
Data to compute head position.
pos : array, shape (N, 10)
The position and quaternion parameters from cHPI fitting.
raw : instance of Raw | list of Raw
Data to compute head position. Can be a list containing multiple raw
instances.
pos : array, shape (N, 10) | list of ndarray
The position and quaternion parameters from cHPI fitting. Can be
a list containing multiple position arrays, one per raw instance passed.
%(verbose)s

Returns
-------
dev_head_t : instance of Transform
New ``dev_head_t`` transformation using the averaged good head positions.

Notes
-----
.. versionchanged:: 1.7
Support for multiple raw instances and position arrays was added.
"""
# Get weighted head pos trans and rot
if not isinstance(raw, (list, tuple)):
raw = [raw]
if not isinstance(pos, (list, tuple)):
pos = [pos]
if len(pos) != len(raw):
raise ValueError(
f"Number of head positions ({len(pos)}) must match the number of raw "
f"instances ({len(raw)})"
)
hp = list()
dt = list()
for ri, (r, p) in enumerate(zip(raw, pos)):
_validate_type(r, BaseRaw, f"raw[{ri}]")
_validate_type(p, np.ndarray, f"pos[{ri}]")
Comment on lines +337 to +338
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the error messages here (which will reference raw[0] or pos[0]) might be slightly confusing for the case where a user passed only 1 raw / pos (not a list). But getting it "right" is pretty convoluted:

item_name = "raw" if len(raw) == 1 else "raw[{}]" 
_validate_type(r, BaseRaw, item_name.format(ri))

if you can think of a cleaner/simpler way please go for it, otherwise probably fine as-is.

hp_, dt_ = _raw_hp_weights(r, p)
hp.append(hp_)
dt.append(dt_)
hp = np.concatenate(hp, axis=0)
dt = np.concatenate(dt, axis=0)
dt /= dt.sum()
best_q = _average_quats(hp[:, 1:4], weights=dt)
trans = np.eye(4)
trans[:3, :3] = quat_to_rot(best_q)
trans[:3, 3] = dt @ hp[:, 4:7]
dist = np.linalg.norm(trans[:3, 3])
if dist > 1: # less than 1 meter is sane
warn(f"Implausible head position detected: {dist} meters from device origin")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tempted to do this but it bumps the line over the length limit

Suggested change
warn(f"Implausible head position detected: {dist} meters from device origin")
warn(f"Implausible head position detected: {dist:0.3f} meters from device origin")

dev_head_t = Transform("meg", "head", trans)
return dev_head_t


def _raw_hp_weights(raw, pos):
sfreq = raw.info["sfreq"]
seg_good = np.ones(len(raw.times))
trans_pos = np.zeros(3)
hp = pos.copy()
hp_ts = hp[:, 0] - raw._first_time

Expand Down Expand Up @@ -353,19 +395,7 @@ def compute_average_dev_head_t(raw, pos):
assert (dt >= 0).all()
dt = dt / sfreq
del seg_good, idx

# Get weighted head pos trans and rot
trans_pos += np.dot(dt, hp[:, 4:7])

rot_qs = hp[:, 1:4]
best_q = _average_quats(rot_qs, weights=dt)

trans = np.eye(4)
trans[:3, :3] = quat_to_rot(best_q)
trans[:3, 3] = trans_pos / dt.sum()
assert np.linalg.norm(trans[:3, 3]) < 1 # less than 1 meter is sane
dev_head_t = Transform("meg", "head", trans)
return dev_head_t
return hp, dt


def _annotations_from_mask(times, mask, annot_name, orig_time=None):
Expand Down
76 changes: 74 additions & 2 deletions mne/preprocessing/tests/test_artifact_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
compute_average_dev_head_t,
)
from mne.tests.test_annotations import _assert_annotations_equal
from mne.transforms import _angle_dist_between_rigid, quat_to_rot, rot_to_quat

data_path = testing.data_path(download=False)
sss_path = data_path / "SSS"
Expand All @@ -35,6 +36,7 @@ def test_movement_annotation_head_correction(meas_date):
raw.set_meas_date(None)
else:
assert meas_date == "orig"
raw_unannot = raw.copy()

# Check 5 rotation segments are detected
annot_rot, [] = annotate_movement(raw, pos, rotation_velocity_limit=5)
Expand Down Expand Up @@ -67,7 +69,7 @@ def test_movement_annotation_head_correction(meas_date):
_assert_annotations_equal(annot_all_2, annot_all)
assert annot_all.orig_time == raw.info["meas_date"]
raw.set_annotations(annot_all)
dev_head_t = compute_average_dev_head_t(raw, pos)
dev_head_t = compute_average_dev_head_t(raw, pos)["trans"]

dev_head_t_ori = np.array(
[
Expand All @@ -78,13 +80,83 @@ def test_movement_annotation_head_correction(meas_date):
]
)

assert_allclose(dev_head_t_ori, dev_head_t["trans"], rtol=1e-5, atol=0)
assert_allclose(dev_head_t_ori, dev_head_t, rtol=1e-5, atol=0)

with pytest.raises(ValueError, match="Number of .* must match .*"):
compute_average_dev_head_t([raw], [pos] * 2)
# Using two identical ones should be identical ...
dev_head_t_double = compute_average_dev_head_t([raw] * 2, [pos] * 2)["trans"]
assert_allclose(dev_head_t, dev_head_t_double)
# ... unannotated and annotated versions differ ...
dev_head_t_unannot = compute_average_dev_head_t(raw_unannot, pos)["trans"]
rot_tol = 1.5e-3
mov_tol = 1e-3
assert not np.allclose(
dev_head_t_unannot[:3, :3],
dev_head_t[:3, :3],
atol=rot_tol,
rtol=0,
)
assert not np.allclose(
dev_head_t_unannot[:3, 3],
dev_head_t[:3, 3],
atol=mov_tol,
rtol=0,
)
# ... and Averaging the two is close to (but not identical!) to operating on the two
# files. Note they shouldn't be identical because there are more time points
# included in the unannotated version!
dev_head_t_naive = np.eye(4)
dev_head_t_naive[:3, :3] = quat_to_rot(
np.mean(
rot_to_quat(np.array([dev_head_t[:3, :3], dev_head_t_unannot[:3, :3]])),
axis=0,
)
)
dev_head_t_naive[:3, 3] = np.mean(
[dev_head_t[:3, 3], dev_head_t_unannot[:3, 3]], axis=0
)
dev_head_t_combo = compute_average_dev_head_t([raw, raw_unannot], [pos] * 2)[
"trans"
]
unit_kw = dict(distance_units="mm", angle_units="deg")
deg_annot_combo, mm_annot_combo = _angle_dist_between_rigid(
dev_head_t,
dev_head_t_combo,
**unit_kw,
)
deg_unannot_combo, mm_unannot_combo = _angle_dist_between_rigid(
dev_head_t_unannot,
dev_head_t_combo,
**unit_kw,
)
deg_annot_unannot, mm_annot_unannot = _angle_dist_between_rigid(
dev_head_t,
dev_head_t_unannot,
**unit_kw,
)
deg_combo_naive, mm_combo_naive = _angle_dist_between_rigid(
dev_head_t_combo,
dev_head_t_naive,
**unit_kw,
)
# combo<->naive closer than combo<->annotated closer than annotated<->unannotated
assert 0.05 < deg_combo_naive < deg_annot_combo < deg_annot_unannot < 1.5
assert 0.1 < mm_combo_naive < mm_annot_combo < mm_annot_unannot < 2
# combo<->naive closer than combo<->unannotated closer than annotated<->unannotated
assert 0.05 < deg_combo_naive < deg_unannot_combo < deg_annot_unannot < 1.5
assert 0.12 < mm_combo_naive < mm_unannot_combo < mm_annot_unannot < 2.0

# Smoke test skipping time due to previous annotations.
raw.set_annotations(Annotations([raw.times[0]], 0.1, "bad"))
annot_dis, _ = annotate_movement(raw, pos, mean_distance_limit=0.02)
assert annot_dis.duration.size == 1

# really far should warn
pos[:, 4] += 5
with pytest.warns(RuntimeWarning, match="Implausible head position"):
compute_average_dev_head_t(raw, pos)


@testing.requires_testing_data
@pytest.mark.parametrize("meas_date", (None, "orig"))
Expand Down
23 changes: 12 additions & 11 deletions mne/preprocessing/tests/test_fine_cal.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
write_fine_calibration,
)
from mne.preprocessing.tests.test_maxwell import _assert_shielding
from mne.transforms import _angle_between_quats, rot_to_quat
from mne.transforms import _angle_dist_between_rigid
from mne.utils import object_diff

# Define fine calibration filepaths
Expand Down Expand Up @@ -75,16 +75,17 @@ def test_compute_fine_cal():
orig_trans = _loc_to_coil_trans(orig_locs)
want_trans = _loc_to_coil_trans(want_locs)
got_trans = _loc_to_coil_trans(got_locs)
dist = np.linalg.norm(got_trans[:, :3, 3] - want_trans[:, :3, 3], axis=1)
assert_allclose(dist, 0.0, atol=1e-6)
dist = np.linalg.norm(got_trans[:, :3, 3] - orig_trans[:, :3, 3], axis=1)
assert_allclose(dist, 0.0, atol=1e-6)
orig_quat = rot_to_quat(orig_trans[:, :3, :3])
want_quat = rot_to_quat(want_trans[:, :3, :3])
got_quat = rot_to_quat(got_trans[:, :3, :3])
want_orig_angles = np.rad2deg(_angle_between_quats(want_quat, orig_quat))
got_want_angles = np.rad2deg(_angle_between_quats(got_quat, want_quat))
got_orig_angles = np.rad2deg(_angle_between_quats(got_quat, orig_quat))
want_orig_angles, want_orig_dist = _angle_dist_between_rigid(
want_trans, orig_trans, angle_units="deg"
)
got_want_angles, got_want_dist = _angle_dist_between_rigid(
got_trans, want_trans, angle_units="deg"
)
got_orig_angles, got_orig_dist = _angle_dist_between_rigid(
got_trans, orig_trans, angle_units="deg"
)
assert_allclose(got_want_dist, 0.0, atol=1e-6)
assert_allclose(got_orig_dist, 0.0, atol=1e-6)
for key in ("mag", "grad"):
# imb_cals value
p = pick_types(raw.info, meg=key, exclude=())
Expand Down
27 changes: 23 additions & 4 deletions mne/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1351,6 +1351,28 @@ def _quat_to_affine(quat):
return affine


def _affine_to_quat(affine):
assert affine.shape[-2:] == (4, 4)
return np.concatenate(
[rot_to_quat(affine[..., :3, :3]), affine[..., :3, 3]],
axis=-1,
)


def _angle_dist_between_rigid(a, b=None, *, angle_units="rad", distance_units="m"):
a = _affine_to_quat(a)
b = np.zeros(6) if b is None else _affine_to_quat(b)
ang = _angle_between_quats(a[..., :3], b[..., :3])
dist = np.linalg.norm(a[..., 3:] - b[..., 3:], axis=-1)
assert isinstance(angle_units, str) and angle_units in ("rad", "deg")
if angle_units == "deg":
ang = np.rad2deg(ang)
assert isinstance(distance_units, str) and distance_units in ("m", "mm")
if distance_units == "mm":
dist *= 1e3
return ang, dist


def _angle_between_quats(x, y=None):
"""Compute the ang between two quaternions w/3-element representations."""
# z = conj(x) * y
Expand Down Expand Up @@ -1839,10 +1861,7 @@ def _compute_volume_registration(

# report some useful information
if step in ("translation", "rigid"):
dist = np.linalg.norm(reg_affine[:3, 3])
angle = np.rad2deg(
_angle_between_quats(np.zeros(3), rot_to_quat(reg_affine[:3, :3]))
)
angle, dist = _angle_dist_between_rigid(reg_affine, angle_units="deg")
logger.info(f" Translation: {dist:6.1f} mm")
if step == "rigid":
logger.info(f" Rotation: {angle:6.1f}°")
Expand Down