Description
Is there an existing issue for this?
- I have searched the existing issues
Bug description
Hey everyone, thanks for releasing this awesome tool.
I get an error when calling cebra_model.fit(..., adapt=True) to fine tune a model on new data.
This only happens if the model was trained on multiple sessions, but not when adapting a model trained on a single session.
I use CEBRA inside a larger code base, so I don't have a minimal working example (MWE). These are the steps I follow:
1: Define the CEBRA model with hybrid=False:
```python
from cebra import CEBRA

self.model = CEBRA(
    model_architecture=model_architecture,
    batch_size=batch_size,
    temperature=temperature,
    temperature_mode=temperature_mode,
    learning_rate=learning_rate,
    max_iterations=max_iterations,
    time_offsets=time_offsets,
    output_dimension=embedding_size,
    device="cuda",
    conditional="time_delta",  # use behavioral data
    distance=distance,
    verbose=verbose,
    hybrid=hybrid,  # hybrid=False in this setup
    max_adapt_iterations=500,
)
```
2: Fit self.model on multiple sessions with:
```python
self.model.fit(
    Xs,
    Ys,
    adapt=False,
)
```
where Xs and Ys are lists of np.ndarray with neural and behavioral (continuous) data from multiple experimental sessions.
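For reference, here is a rough sketch of what the multi-session inputs in step 2 look like. It is a minimal NumPy illustration with made-up shapes and random data, not the actual recordings. Each session can have a different number of neurons, but within a session the neural and behavioral arrays share the time axis. The helper `check_multisession_inputs` is hypothetical and not part of CEBRA, and the requirement that behavior has the same number of columns in every session is an assumption labeled in the code.

```python
import numpy as np

def check_multisession_inputs(Xs, Ys):
    """Hypothetical sanity check, not part of CEBRA: returns the number of sessions."""
    if len(Xs) != len(Ys):
        raise ValueError("Xs and Ys must contain the same number of sessions")
    for i, (X, Y) in enumerate(zip(Xs, Ys)):
        # Neural and behavioral arrays of one session share the time axis
        if X.shape[0] != Y.shape[0]:
            raise ValueError(
                f"session {i}: {X.shape[0]} neural vs {Y.shape[0]} behavioral samples"
            )
    # Assumption: the continuous behavior has the same number of columns in every session
    if len({Y.shape[1] for Y in Ys}) != 1:
        raise ValueError("behavioral arrays differ in their number of columns")
    return len(Xs)

rng = np.random.default_rng(0)
# Four made-up sessions (self.model_ below has four prefixes, 0. to 3.) with
# different lengths and neuron counts, plus 2-D continuous behavior
n_samples = [1000, 1200, 900, 1100]
n_neurons = [50, 64, 40, 72]
Xs = [rng.standard_normal((t, n)).astype(np.float32) for t, n in zip(n_samples, n_neurons)]
Ys = [rng.standard_normal((t, 2)).astype(np.float32) for t in n_samples]

print(check_multisession_inputs(Xs, Ys))  # 4
```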
3: Adapt the model to new data:
```python
self.model.fit(  # single session
    X_new,
    Y_new,
    adapt=True,
)
```
with X_new and Y_new being arrays with the data for a single new session.
I get KeyError: '0.net.0.weight'. This is the CEBRA part of the error stack:

I did some digging, and the problem is that the adapt_model created here has different keys in its .state_dict() than self.model_.
adapt_model keys:
```
odict_keys(['net.0.weight', 'net.0.bias', 'net.2.module.0.weight', 'net.2.module.0.bias',
'net.3.module.0.weight', 'net.3.module.0.bias', 'net.4.module.0.weight', 'net.4.module.0.bias',
'net.5.weight', 'net.5.bias'])
```
self.model_ keys:
```
['0.net.0.weight', '0.net.0.bias', '0.net.2.module.0.weight',
'0.net.2.module.0.bias', '0.net.3.module.0.weight', '0.net.3.module.0.bias', '0.net.4.module.0.weight',
'0.net.4.module.0.bias', '0.net.5.weight', '0.net.5.bias', '1.net.0.weight', '1.net.0.bias',
'1.net.2.module.0.weight', '1.net.2.module.0.bias', '1.net.3.module.0.weight', '1.net.3.module.0.bias',
'1.net.4.module.0.weight', '1.net.4.module.0.bias', '1.net.5.weight', '1.net.5.bias', '2.net.0.weight',
'2.net.0.bias', '2.net.2.module.0.weight', '2.net.2.module.0.bias', '2.net.3.module.0.weight',
'2.net.3.module.0.bias', '2.net.4.module.0.weight', '2.net.4.module.0.bias', '2.net.5.weight',
'2.net.5.bias', '3.net.0.weight', '3.net.0.bias', '3.net.2.module.0.weight', '3.net.2.module.0.bias',
'3.net.3.module.0.weight', '3.net.3.module.0.bias', '3.net.4.module.0.weight', '3.net.4.module.0.bias',
'3.net.5.weight', '3.net.5.bias']
```
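The leading 0. to 3. prefixes are consistent with the multi-session model being a container that holds one network per session, such as a torch.nn.ModuleList; I haven't confirmed this in the CEBRA source. The minimal PyTorch sketch below uses a toy `Net` class (a stand-in, not CEBRA's actual architecture) to reproduce the same naming pattern, and the same KeyError when a multi-session key is looked up in a single-model state dict.

```python
import torch.nn as nn

class Net(nn.Module):
    """Toy stand-in for a single-session encoder that stores its layers in `net`."""

    def __init__(self, in_dim: int, out_dim: int = 3):
        super().__init__()
        # Index 1 is a parameter-free activation, so it contributes no state_dict keys
        self.net = nn.Sequential(nn.Linear(in_dim, 16), nn.GELU(), nn.Linear(16, out_dim))

    def forward(self, x):
        return self.net(x)

# One model for a single session vs. one model per session held in a ModuleList
single = Net(in_dim=50)
multi = nn.ModuleList([Net(in_dim=n) for n in (50, 64, 40, 72)])

print(list(single.state_dict().keys()))
# ['net.0.weight', 'net.0.bias', 'net.2.weight', 'net.2.bias']
print(list(multi.state_dict().keys())[:4])
# ['0.net.0.weight', '0.net.0.bias', '0.net.2.weight', '0.net.2.bias']

# Looking up a multi-session key in the single-model state dict fails the same way
try:
    single.state_dict()["0.net.0.weight"]
except KeyError as err:
    print("KeyError:", err)  # KeyError: '0.net.0.weight'
```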
When I train CEBRA on a single session instead, the keys of self.model_ match those of adapt_model.
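Since each session's keys differ from the adapt_model keys only by the leading "<session>." prefix, a quick key-level check makes the mismatch explicit. This is a plain-Python sketch over key names only: `strip_session_prefix` is a hypothetical helper, and `multi_keys` is rebuilt from the adapt_model keys to reproduce the self.model_ listing above (the same ten names under each prefix 0. to 3.). It does not compare tensor shapes; the first layer presumably depends on each session's neuron count, so matching names alone would not make one session's weights loadable, and this is not a confirmed fix.

```python
def strip_session_prefix(state_keys, session):
    """Hypothetical helper: keep one session's keys and drop the '<session>.' prefix."""
    prefix = f"{session}."
    return [key[len(prefix):] for key in state_keys if key.startswith(prefix)]

# adapt_model keys, copied from the output above
adapt_keys = [
    'net.0.weight', 'net.0.bias', 'net.2.module.0.weight', 'net.2.module.0.bias',
    'net.3.module.0.weight', 'net.3.module.0.bias', 'net.4.module.0.weight',
    'net.4.module.0.bias', 'net.5.weight', 'net.5.bias',
]
# self.model_ keys as listed above: the same ten names under each prefix 0. to 3.
multi_keys = [f"{session}.{key}" for session in range(4) for key in adapt_keys]

# Every multi-session key is absent from adapt_model; the first is the one in the KeyError
adapt_set = set(adapt_keys)
missing = [key for key in multi_keys if key not in adapt_set]
print(len(missing), missing[0])  # 40 0.net.0.weight

# After selecting session 0 and stripping its prefix, the key names line up
print(strip_session_prefix(multi_keys, 0) == adapt_keys)  # True
```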
Is training on multiple sessions and then fine-tuning on a new one not supported? Or am I using CEBRA incorrectly?
Thanks,
Federico
Operating System
Windows
CEBRA version
cebra version: '0.2.0'
Device type
gpu: NVIDIA GeForce RTX 3080
Steps To Reproduce
No response
Relevant log output
No response
Anything else?
No response
Code of Conduct
- I agree to follow this project's Code of Conduct