Repository used to train the CNN models of the marmoset retina in the paper "Learning to cluster neuronal function".
The paper can be found here.
The data used for training the models can be found here.
The cell_type_dict.pkl file in the root of this repository is used to assign cell types.
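A minimal sketch of loading the cell-type dictionary with the standard library. The README does not document the dictionary's layout (what the keys and values represent), so the helper `load_cell_type_dict` below is hypothetical. It only checks that the file unpickles to a dict and leaves interpretation to the caller.

```python
import pickle
from pathlib import Path


def load_cell_type_dict(path="cell_type_dict.pkl"):
    """Load the pickled cell-type dictionary (hypothetical helper).

    The key/value layout is not documented; this only verifies that the
    file contains a dict.
    """
    with Path(path).open("rb") as f:
        cell_types = pickle.load(f)
    if not isinstance(cell_types, dict):
        raise TypeError(f"expected a dict, got {type(cell_types).__name__}")
    return cell_types
```

For example, `cell_types = load_cell_type_dict()` followed by `list(cell_types)[:5]` shows the first few keys. Since unpickling can execute arbitrary code, only load pickle files from a trusted source such as this repository.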
To set up the environment, create a Python 3.9 virtual environment and install the dependencies listed in the provided requirements.txt file.
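The same setup can be sketched with Python's built-in `venv` module, the programmatic equivalent of `python -m venv`. The helper name `create_env` and the `.venv` directory name are assumptions. The environment is built from whichever interpreter runs the function, so it has to be run with Python 3.9 to get the version the README specifies.

```python
import os
import subprocess
import sys
import venv
from pathlib import Path


def create_env(env_dir=".venv", requirements="requirements.txt", install=True):
    """Create a virtual environment for the running interpreter (hypothetical helper).

    Returns the path to the environment's Python executable.
    """
    if sys.version_info[:2] != (3, 9):
        print(
            f"warning: running Python {sys.version_info.major}."
            f"{sys.version_info.minor}, but the README specifies 3.9"
        )
    env_path = Path(env_dir)
    # with_pip bootstraps pip into the environment; symlinks mirrors the
    # `python -m venv` default on non-Windows platforms.
    venv.EnvBuilder(with_pip=install, symlinks=(os.name != "nt")).create(env_path)
    bin_dir = env_path / ("Scripts" if os.name == "nt" else "bin")
    env_python = bin_dir / ("python.exe" if os.name == "nt" else "python")
    if install:
        # Install the dependencies into the new environment, not the current one.
        subprocess.run(
            [str(env_python), "-m", "pip", "install", "-r", requirements],
            check=True,
        )
    return env_python
```

Calling `create_env()` from the repository root creates `.venv` and installs requirements.txt into it.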
The models were trained with the run_multi_retinal_marmoset_cnn.py script in this repository.
The default arguments correspond to the settings used to train the base models.
The shell script used with sbatch to launch the training jobs is cnn_nm.sh.
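A minimal launcher sketch, assuming it is run from the repository root. The helper names `training_command` and `launch` are hypothetical. The two commands either run the training script directly with its default arguments (the base-model settings) or submit cnn_nm.sh with sbatch. Any options set inside cnn_nm.sh are not reproduced here.

```python
import shutil
import subprocess
import sys


def training_command(use_slurm=False):
    """Build the command that starts a training run (hypothetical helper)."""
    if use_slurm:
        # Submit the job script to SLURM.
        return ["sbatch", "cnn_nm.sh"]
    # Run the training script locally with its default arguments.
    return [sys.executable, "run_multi_retinal_marmoset_cnn.py"]


def launch(use_slurm=False):
    """Run or submit a training job (hypothetical helper)."""
    cmd = training_command(use_slurm)
    if use_slurm and shutil.which("sbatch") is None:
        raise RuntimeError("sbatch not found; is SLURM available on this machine?")
    # check=True raises CalledProcessError if the command exits with an error.
    subprocess.run(cmd, check=True)
```

Use `launch()` for a local run and `launch(use_slurm=True)` on a SLURM cluster.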
The performance of the trained models can be evaluated using the evaluations/multiretinal_model_avaluation.py script.
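The README does not state which metric the evaluation script reports. A common choice for neural response models is the Pearson correlation between predicted and recorded responses, computed per neuron and then averaged. The sketch below assumes that metric; `pearson_r` and `mean_correlation` are hypothetical helpers, not functions from the evaluation script.

```python
import math


def pearson_r(predicted, observed):
    """Pearson correlation between predicted and observed responses of one neuron."""
    n = len(predicted)
    if n != len(observed) or n < 2:
        raise ValueError("need two equal-length sequences with at least 2 samples")
    mean_p = sum(predicted) / n
    mean_o = sum(observed) / n
    cov = sum((p - mean_p) * (o - mean_o) for p, o in zip(predicted, observed))
    var_p = sum((p - mean_p) ** 2 for p in predicted)
    var_o = sum((o - mean_o) ** 2 for o in observed)
    if var_p == 0 or var_o == 0:
        return float("nan")  # correlation is undefined for a constant trace
    return cov / math.sqrt(var_p * var_o)


def mean_correlation(predictions, responses):
    """Average per-neuron correlation, skipping neurons where it is undefined.

    predictions, responses: lists of per-neuron response traces.
    """
    scores = [pearson_r(p, r) for p, r in zip(predictions, responses)]
    valid = [s for s in scores if not math.isnan(s)]
    return sum(valid) / len(valid) if valid else float("nan")
```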
The continue_training.py script takes the trained models and fine-tunes them with the clustering loss included in the objective.
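The README does not give the form of the clustering loss. The sketch below only illustrates the general idea of fine-tuning with an added clustering term. It adds two pieces:

- a Poisson negative log-likelihood on the predicted responses;
- a k-means-style penalty: the squared distance from each neuron's readout embedding to its nearest cluster centre, weighted by a coefficient `beta`.

The k-means penalty is a stand-in chosen for illustration, not necessarily the loss used in the paper. All function names and `beta` are hypothetical.

```python
import math


def poisson_loss(predicted, observed, eps=1e-8):
    """Poisson negative log-likelihood (dropping the log(o!) constant), averaged."""
    return sum(p - o * math.log(p + eps) for p, o in zip(predicted, observed)) / len(predicted)


def kmeans_penalty(embeddings, centroids):
    """Mean squared distance from each embedding to its nearest centroid."""
    total = 0.0
    for e in embeddings:
        total += min(sum((x - c) ** 2 for x, c in zip(e, cen)) for cen in centroids)
    return total / len(embeddings)


def finetune_objective(predicted, observed, embeddings, centroids, beta=1.0):
    """Prediction loss plus a weighted clustering term (illustrative only)."""
    return poisson_loss(predicted, observed) + beta * kmeans_penalty(embeddings, centroids)
```

In an actual training loop these terms would be computed with an autodiff framework so that gradients of the clustering term flow back into the readout embeddings. `beta` then trades prediction accuracy against how tightly neurons group around the cluster centres.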