KeyError with cebra.fit(adapt=True) on multi-session embeddings (only) #108

@FedeClaudi

Description

Is there an existing issue for this?

  • I have searched the existing issues

Bug description

Hey everyone, thanks for releasing this awesome tool.

I get an error when calling cebra_model.fit(..., adapt=True) to fine-tune a model on new data.
This only happens if the model was trained on multiple sessions; adapting a model trained on a single session works.

I use CEBRA within a larger codebase, so I don't have a minimal working example (MWE), but these are the steps I follow:

1: Define the CEBRA model with hybrid=False:

    from cebra import CEBRA

    self.model = CEBRA(
        model_architecture=model_architecture,
        batch_size=batch_size,
        temperature=temperature,
        temperature_mode=temperature_mode,
        learning_rate=learning_rate,
        max_iterations=max_iterations,
        time_offsets=time_offsets,
        output_dimension=embedding_size,
        device="cuda",
        conditional="time_delta",  # use behavioral data
        distance=distance,
        verbose=verbose,
        hybrid=hybrid,  # hybrid=False in this setup
        max_adapt_iterations=500,
    )

2: Fit self.model on multiple sessions:

    self.model.fit(
        Xs,
        Ys,
        adapt=False,
    )

where Xs and Ys are lists of np.ndarray holding the neural and (continuous) behavioral data, one array per experimental session.
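For anyone trying to reproduce this, the input layout described above can be sketched with synthetic data. This is a minimal sketch under stated assumptions, not the real data: the helper make_synthetic_sessions, the session count, and the sample, neuron, and feature counts are all hypothetical. Each session contributes one neural array of shape (time, neurons) and one continuous behavioral array of shape (time, features). Sessions may differ in length and neuron count, but within a session both arrays share the same time axis, and the behavioral feature dimension is assumed to be the same across sessions.

```python
import numpy as np


def make_synthetic_sessions(n_sessions=4, n_features=2, seed=0):
    """Hypothetical helper: build per-session neural/behavioral arrays."""
    rng = np.random.default_rng(seed)
    Xs, Ys = [], []
    for i in range(n_sessions):
        n_samples = 1000 + 100 * i  # sessions may differ in length
        n_neurons = 50 + 10 * i     # ...and in the number of recorded neurons
        Xs.append(rng.standard_normal((n_samples, n_neurons)).astype(np.float32))
        Ys.append(rng.standard_normal((n_samples, n_features)).astype(np.float32))
    return Xs, Ys


Xs, Ys = make_synthetic_sessions()

# One entry per session; neural and behavioral time axes match within a session.
for X, Y in zip(Xs, Ys):
    assert X.shape[0] == Y.shape[0]
```

Four sessions are used because the self.model_ keys in this report carry the prefixes 0. to 3.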

3: Adapt the model to new data:

    self.model.fit(  # single session
        X_new,
        Y_new,
        adapt=True,
    )

where X_new and Y_new are arrays holding the data for a single new session.

I get KeyError: '0.net.0.weight'. This is the CEBRA part of the stack trace:
[image: screenshot of the stack trace]

I did some digging, and the problem is that the adapt_model created here has different keys in its .state_dict() than self.model_ does.
adapt_model keys:

odict_keys(['net.0.weight', 'net.0.bias', 'net.2.module.0.weight', 'net.2.module.0.bias',  
'net.3.module.0.weight', 'net.3.module.0.bias', 'net.4.module.0.weight', 'net.4.module.0.bias',        
'net.5.weight', 'net.5.bias'])

self.model_ keys:

['0.net.0.weight', '0.net.0.bias', '0.net.2.module.0.weight',
'0.net.2.module.0.bias', '0.net.3.module.0.weight', '0.net.3.module.0.bias', '0.net.4.module.0.weight',
'0.net.4.module.0.bias', '0.net.5.weight', '0.net.5.bias', '1.net.0.weight', '1.net.0.bias',
'1.net.2.module.0.weight', '1.net.2.module.0.bias', '1.net.3.module.0.weight', '1.net.3.module.0.bias',
'1.net.4.module.0.weight', '1.net.4.module.0.bias', '1.net.5.weight', '1.net.5.bias', '2.net.0.weight',
'2.net.0.bias', '2.net.2.module.0.weight', '2.net.2.module.0.bias', '2.net.3.module.0.weight',
'2.net.3.module.0.bias', '2.net.4.module.0.weight', '2.net.4.module.0.bias', '2.net.5.weight',
'2.net.5.bias', '3.net.0.weight', '3.net.0.bias', '3.net.2.module.0.weight', '3.net.2.module.0.bias',  
'3.net.3.module.0.weight', '3.net.3.module.0.bias', '3.net.4.module.0.weight', '3.net.4.module.0.bias',
'3.net.5.weight', '3.net.5.bias']

When I train CEBRA on a single session, the keys of self.model_ match those of adapt_model.
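The two listings differ only by a leading session index: every self.model_ key is an adapt_model key prefixed with 0. to 3. That pattern suggests the multi-session model stores one network per session in an indexed container (an inference from the key names, not from the CEBRA source). The minimal pure-Python sketch below uses plain dicts in place of real state dicts and a hypothetical helper, session_state_dict. It shows why looking up '0.net.0.weight' among the single-session keys raises a KeyError, and how one session's entries could be selected by stripping the prefix. Whether copying one session's weights into adapt_model is a sensible way to fine-tune is a separate question. In particular, the input layer (net.0.weight) is presumably sized to that session's neuron count and may not fit the new session's data.

```python
# Parameter names copied from the adapt_model listing in this report.
SINGLE_KEYS = [
    "net.0.weight", "net.0.bias",
    "net.2.module.0.weight", "net.2.module.0.bias",
    "net.3.module.0.weight", "net.3.module.0.bias",
    "net.4.module.0.weight", "net.4.module.0.bias",
    "net.5.weight", "net.5.bias",
]

# Stand-ins for state dicts: strings instead of tensors.
adapt_state = {k: f"adapt:{k}" for k in SINGLE_KEYS}
multi_state = {f"{s}.{k}": f"session{s}:{k}" for s in range(4) for k in SINGLE_KEYS}


def session_state_dict(state_dict, session_id):
    """Hypothetical helper: keep one session's entries and drop the 'i.' prefix."""
    prefix = f"{session_id}."  # trailing dot keeps '1.' from matching '10.'
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}


# Looking up multi-session keys in the single-session dict fails as in the report.
try:
    for key in multi_state:
        adapt_state[key]
except KeyError as err:
    print("KeyError:", err)  # prints: KeyError: '0.net.0.weight'

# After stripping the prefix, session 0's keys line up with adapt_model's keys.
session0 = session_state_dict(multi_state, 0)
print(sorted(session0) == sorted(adapt_state))  # prints: True
```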
Is training on multiple sessions and then fine-tuning on a new one not supported? Or am I using CEBRA incorrectly?

Thanks,
Federico

Operating System

Windows

CEBRA version

cebra version: '0.2.0'

Device type

gpu: NVIDIA GeForce RTX 3080

Steps To Reproduce

No response

Relevant log output

No response

Anything else?

No response

Code of Conduct

Metadata

Labels

bug (Something isn't working), enhancement (New feature or request)