Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/devel/13056.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug with saving of anonymized data when helium info is present in measurement info, by `Eric Larson`_.
15 changes: 11 additions & 4 deletions mne/_fiff/meas_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -2493,6 +2493,8 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None):
hi["meas_date"] = _ensure_meas_date_none_or_dt(
tuple(int(t) for t in tag.data),
)
if "meas_date" not in hi:
hi["meas_date"] = None
info["helium_info"] = hi
del hi

Expand Down Expand Up @@ -2879,7 +2881,8 @@ def write_meas_info(fid, info, data_type=None, reset_range=True):
write_float(fid, FIFF.FIFF_HELIUM_LEVEL, hi["helium_level"])
if hi.get("orig_file_guid") is not None:
write_string(fid, FIFF.FIFF_ORIG_FILE_GUID, hi["orig_file_guid"])
write_int(fid, FIFF.FIFF_MEAS_DATE, _dt_to_stamp(hi["meas_date"]))
if hi["meas_date"] is not None:
write_int(fid, FIFF.FIFF_MEAS_DATE, _dt_to_stamp(hi["meas_date"]))
end_block(fid, FIFF.FIFFB_HELIUM)
del hi

Expand Down Expand Up @@ -2916,8 +2919,10 @@ def write_meas_info(fid, info, data_type=None, reset_range=True):
_write_proc_history(fid, info)


@fill_doc
def write_info(fname, info, data_type=None, reset_range=True):
@verbose
def write_info(
fname, info, *, data_type=None, reset_range=True, overwrite=False, verbose=None
):
"""Write measurement info in fif file.

Parameters
Expand All @@ -2931,8 +2936,10 @@ def write_info(fname, info, data_type=None, reset_range=True):
raw data.
reset_range : bool
If True, info['chs'][k]['range'] will be set to unity.
%(overwrite)s
%(verbose)s
"""
with start_and_end_file(fname) as fid:
with start_and_end_file(fname, overwrite=overwrite) as fid:
start_block(fid, FIFF.FIFFB_MEAS)
write_meas_info(fid, info, data_type, reset_range)
end_block(fid, FIFF.FIFFB_MEAS)
Expand Down
48 changes: 30 additions & 18 deletions mne/_fiff/tests/test_meas_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,9 @@ def test_read_write_info(tmp_path):
gantry_angle = info["gantry_angle"]

meas_id = info["meas_id"]
write_info(temp_file, info)
with pytest.raises(FileExistsError, match="Destination file exists"):
write_info(temp_file, info)
write_info(temp_file, info, overwrite=True)
info = read_info(temp_file)
assert info["proc_history"][0]["creator"] == creator
assert info["hpi_meas"][0]["creator"] == creator
Expand Down Expand Up @@ -348,7 +350,7 @@ def test_read_write_info(tmp_path):
info["meas_date"] = datetime(1800, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
fname = tmp_path / "test.fif"
with pytest.raises(RuntimeError, match="must be between "):
write_info(fname, info)
write_info(fname, info, overwrite=True)


@testing.requires_testing_data
Expand Down Expand Up @@ -377,7 +379,7 @@ def test_io_coord_frame(tmp_path):
for ch_type in ("eeg", "seeg", "ecog", "dbs", "hbo", "hbr"):
info = create_info(ch_names=["Test Ch"], sfreq=1000.0, ch_types=[ch_type])
info["chs"][0]["loc"][:3] = [0.05, 0.01, -0.03]
write_info(fname, info)
write_info(fname, info, overwrite=True)
info2 = read_info(fname)
assert info2["chs"][0]["coord_frame"] == FIFF.FIFFV_COORD_HEAD

Expand Down Expand Up @@ -585,7 +587,7 @@ def test_check_consistency():
info2["subject_info"] = {"height": "bad"}


def _test_anonymize_info(base_info):
def _test_anonymize_info(base_info, tmp_path):
"""Test that sensitive information can be anonymized."""
pytest.raises(TypeError, anonymize_info, "foo")
assert isinstance(base_info, Info)
Expand Down Expand Up @@ -692,14 +694,25 @@ def _adjust_back(e_i, dt):
# exp 4 tests is a supplied daysback
delta_t_3 = timedelta(days=223 + 364 * 500)

def _check_equiv(got, want, err_msg):
__tracebackhide__ = True
fname_temp = tmp_path / "test.fif"
assert_object_equal(got, want, err_msg=err_msg)
write_info(fname_temp, got, reset_range=False, overwrite=True)
got = read_info(fname_temp)
# this gets changed on write but that's expected
with got._unlock():
got["file_id"] = want["file_id"]
assert_object_equal(got, want, err_msg=f"{err_msg} (on I/O round trip)")

new_info = anonymize_info(base_info.copy())
assert_object_equal(new_info, exp_info, err_msg="anon mismatch")
_check_equiv(new_info, exp_info, err_msg="anon mismatch")

new_info = anonymize_info(base_info.copy(), keep_his=True)
assert_object_equal(new_info, exp_info_2, err_msg="anon keep_his mismatch")
_check_equiv(new_info, exp_info_2, err_msg="anon keep_his mismatch")

new_info = anonymize_info(base_info.copy(), daysback=delta_t_2.days)
assert_object_equal(new_info, exp_info_3, err_msg="anon daysback mismatch")
_check_equiv(new_info, exp_info_3, err_msg="anon daysback mismatch")

with pytest.raises(RuntimeError, match="anonymize_info generated"):
anonymize_info(base_info.copy(), daysback=delta_t_3.days)
Expand All @@ -726,15 +739,15 @@ def _adjust_back(e_i, dt):
new_info = anonymize_info(base_info.copy(), daysback=delta_t_2.days)
else:
new_info = anonymize_info(base_info.copy(), daysback=delta_t_2.days)
assert_object_equal(
_check_equiv(
new_info,
exp_info_3,
err_msg="meas_date=None daysback mismatch",
)

with _record_warnings(): # meas_date is None
new_info = anonymize_info(base_info.copy())
assert_object_equal(new_info, exp_info_3, err_msg="meas_date=None mismatch")
_check_equiv(new_info, exp_info_3, err_msg="meas_date=None mismatch")


@pytest.mark.parametrize(
Expand Down Expand Up @@ -777,8 +790,8 @@ def _complete_info(info):
height=2.0,
)
info["helium_info"] = dict(
he_level_raw=12.34,
helium_level=45.67,
he_level_raw=np.float32(12.34),
helium_level=np.float32(45.67),
meas_date=datetime(2024, 11, 14, 14, 8, 2, tzinfo=timezone.utc),
orig_file_guid="e",
)
Expand All @@ -796,14 +809,13 @@ def _complete_info(info):
machid=np.ones(2, int),
secs=d[0],
usecs=d[1],
date=d,
),
experimenter="j",
max_info=dict(
max_st=[],
sss_ctc=[],
sss_cal=[],
sss_info=dict(head_pos=None, in_order=8),
max_st=dict(),
sss_ctc=dict(),
sss_cal=dict(),
sss_info=dict(in_order=8),
),
date=d,
),
Expand All @@ -830,8 +842,8 @@ def test_anonymize(tmp_path):
# test mne.anonymize_info()
events = read_events(event_name)
epochs = Epochs(raw, events[:1], 2, 0.0, 0.1, baseline=None)
_test_anonymize_info(raw.info)
_test_anonymize_info(epochs.info)
_test_anonymize_info(raw.info, tmp_path)
_test_anonymize_info(epochs.info, tmp_path)

# test instance methods & I/O roundtrip
for inst, keep_his in zip((raw, epochs), (True, False)):
Expand Down
9 changes: 5 additions & 4 deletions mne/_fiff/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy as np
from scipy.sparse import csc_array, csr_array

from ..utils import _file_like, _validate_type, logger
from ..utils import _check_fname, _file_like, _validate_type, logger
from ..utils.numerics import _date_to_julian
from .constants import FIFF

Expand Down Expand Up @@ -277,7 +277,7 @@ def end_block(fid, kind):
write_int(fid, FIFF.FIFF_BLOCK_END, kind)


def start_file(fname, id_=None):
def start_file(fname, id_=None, *, overwrite=True):
"""Open a fif file for writing and writes the compulsory header tags.

Parameters
Expand All @@ -294,6 +294,7 @@ def start_file(fname, id_=None):
fid = fname
fid.seek(0)
else:
fname = _check_fname(fname, overwrite=overwrite)
fname = str(fname)
if op.splitext(fname)[1].lower() == ".gz":
logger.debug("Writing using gzip")
Expand All @@ -311,9 +312,9 @@ def start_file(fname, id_=None):


@contextmanager
def start_and_end_file(fname, id_=None):
def start_and_end_file(fname, id_=None, *, overwrite=True):
"""Start and (if successfully written) close the file."""
with start_file(fname, id_=id_) as fid:
with start_file(fname, id_=id_, overwrite=overwrite) as fid:
yield fid
end_file(fid) # we only hit this line if the yield does not err

Expand Down
4 changes: 2 additions & 2 deletions mne/utils/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,9 @@ def assert_and_remove_boundary_annot(annotations, n=1):
annotations.delete(idx)


def assert_object_equal(a, b, *, err_msg="Object mismatch"):
def assert_object_equal(a, b, *, err_msg="Object mismatch", allclose=False):
"""Assert two objects are equal."""
d = object_diff(a, b)
d = object_diff(a, b, allclose=allclose)
assert d == "", f"{err_msg}\n{d}"


Expand Down
Loading