Code for the paper "Automating MedSAM by Learning Prompts with Weak Few-Shot Supervision" by Mélanie Gaillochet, Christian Desrosiers and Hervé Lombaert, presented at MICCAI-MedAGI (2024)
- Option 1: Create conda environment
$ conda create -n py310 python=3.10 pip
$ conda activate py310
$ conda install pytorch torchvision -c pytorch -c nvidia
$ conda install numpy matplotlib scikit-learn scikit-image h5py
$ pip install nibabel comet_ml flatten-dict pytorch_lightning
$ pip install transformers monai opencv-python
- Option 2: Create virtual environment with pip
$ virtualenv venv
$ source venv/bin/activate
$ pip install torch numpy torchvision matplotlib scikit-image monai h5py nibabel comet_ml flatten-dict pytorch_lightning pytorch-lightning-spells opencv-python
$ git clone https://github.com/bowang-lab/MedSAM
$ cd MedSAM
$ pip install -e .
- Get the package to download Google Drive files:
$ pip install gdown
- Create the target folder:
$ mkdir <path-to-models-folder>/medsam
- Download the MedSAM model checkpoint:
$ gdown --id 1UAmWL88roYR7wKlnApw5Bcuzf2iQgk6_ -O <path-to-models-folder>/medsam/medsam_vit_b.pth
Note: For this project, the model checkpoint is `medsam_vit_b.pth`.
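To sanity-check the download, the checkpoint can be loaded through the `segment_anything` package installed from the MedSAM repo above (a minimal sketch; the path is a placeholder):

```python
# Minimal sketch: sanity-check that the downloaded checkpoint loads.
from segment_anything import sam_model_registry

medsam = sam_model_registry["vit_b"](
    checkpoint="<path-to-models-folder>/medsam/medsam_vit_b.pth"
)
medsam.eval()
print("MedSAM (ViT-B) checkpoint loaded")
```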
- ACDC challenge (link to train & test sets)
- CAMUS challenge (link to data)
- HC Challenge (link to data)
Examples of the preprocessing applied are shown in the notebook src/JupyterNotebook/Preprocessing_ACDC.ipynb
The notebook preprocesses the data and generates the following folders, which can be used by our dataloader: `train_2d_images`, `train_2d_masks`, `val_2d_images`, `val_2d_masks`, `test_2d_images`, `test_2d_masks`.
Note: For this project, the preprocessed output folders are `ACDC/preprocessed_sam_slice_minsize_nobgslice_crop256_256`, `CAMUS_public/preprocessed_sam_slice_nobgslice_crop512_512` and `HC18/preprocessed_sam_slice_foreground_crop640_640`.
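For reference, each preprocessed file is a single 2D slice stored as NIfTI, with matching file names in the image and mask folders. A minimal sketch for inspecting one pair (the folder name comes from the ACDC note above; everything else is an example):

```python
# Minimal sketch: inspect one preprocessed slice and its mask.
import os

import nibabel as nib

preproc = "<data-dir>/ACDC/preprocessed_sam_slice_minsize_nobgslice_crop256_256"
name = sorted(os.listdir(os.path.join(preproc, "train_2d_images")))[0]

image = nib.load(os.path.join(preproc, "train_2d_images", name)).get_fdata()
mask = nib.load(os.path.join(preproc, "train_2d_masks", name)).get_fdata()
print(name, image.shape, mask.shape)  # a slice and its mask share the same file name
```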
To use pre-computed image embeddings:
- Create a new folder `image_embeddings/medsam_vit_b-pth` in the preprocessed data folder.
- Create subfolders `train_2d_images`, `val_2d_images` and `test_2d_images` in `medsam_vit_b-pth`.
- Run MedSAM and save each image embedding under the same name as the original file, in `train_2d_images`, `val_2d_images` and `test_2d_images` (see the sketch below).
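The export script itself is not shown here; below is a minimal sketch of what "run MedSAM and save each image embedding" could look like, assuming the `segment_anything` API from the MedSAM repo, NIfTI inputs, and `.h5` outputs with an `embedding` dataset key. Adapt the file format and key to whatever the dataloader expects.

```python
# Minimal sketch: precompute MedSAM image embeddings for one split.
# Paths, the .h5 extension and the 'embedding' key are assumptions.
import glob
import os

import h5py
import nibabel as nib
import numpy as np
import torch
from segment_anything import sam_model_registry

device = "cuda" if torch.cuda.is_available() else "cpu"
medsam = sam_model_registry["vit_b"](
    checkpoint="<path-to-models-folder>/medsam/medsam_vit_b.pth"
).to(device).eval()

src_dir = "<preprocessed-folder>/train_2d_images"
dst_dir = "<preprocessed-folder>/image_embeddings/medsam_vit_b-pth/train_2d_images"
os.makedirs(dst_dir, exist_ok=True)

for path in glob.glob(os.path.join(src_dir, "*.nii.gz")):
    img = np.squeeze(nib.load(path).get_fdata()).astype(np.float32)
    # Normalize to [0, 1], replicate to 3 channels, and resize to the
    # 1024x1024 input expected by the SAM/MedSAM image encoder.
    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
    t = torch.from_numpy(img)[None, None].repeat(1, 3, 1, 1).to(device)
    t = torch.nn.functional.interpolate(t, size=(1024, 1024), mode="bilinear", align_corners=False)
    with torch.no_grad():
        emb = medsam.image_encoder(t)  # shape (1, 256, 64, 64)
    out_path = os.path.join(dst_dir, os.path.basename(path).replace(".nii.gz", ".h5"))
    with h5py.File(out_path, "w") as f:
        f.create_dataset("embedding", data=emb.cpu().numpy())
```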
To use pre-computed positional encodings:
- Create a folder `image_positional_embeddings/medsam_vit_b-pth` in the preprocessed data folder.
- Create subfolders `train_2d_images`, `val_2d_images` and `test_2d_images` in `medsam_vit_b-pth`.
- Run MedSAM and save each computed positional encoding under the same name as the original file, in `train_2d_images`, `val_2d_images` and `test_2d_images`.
- (For a fixed positional encoding) Copy the positional encoding of the first file into `image_positional_embeddings/medsam_vit_b-pth/fixed` (see the sketch below).
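Similarly, a minimal sketch for exporting the prompt encoder's dense positional encoding; the `.h5` layout and the `positional_encoding` key are assumptions:

```python
# Minimal sketch: export the prompt encoder's dense positional encoding.
import os

import h5py
import torch
from segment_anything import sam_model_registry

medsam = sam_model_registry["vit_b"](
    checkpoint="<path-to-models-folder>/medsam/medsam_vit_b.pth"
).eval()
with torch.no_grad():
    pe = medsam.prompt_encoder.get_dense_pe()  # (1, 256, 64, 64), same grid as the image embeddings

dst_dir = "<preprocessed-folder>/image_positional_embeddings/medsam_vit_b-pth/fixed"
os.makedirs(dst_dir, exist_ok=True)
with h5py.File(os.path.join(dst_dir, "<patient_name_slice_number>.h5"), "w") as f:
    f.create_dataset("positional_encoding", data=pe.cpu().numpy())
```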
The data folder (data_dir) will have the following structure:
<data-dir>
├── <dataset_name>
│   ├── raw
│   │   └── ...
│   └── <preprocessed-folder-name>
│       ├── train_2d_images
│       │   ├── <patient_name_slice_number.nii.gz>
│       │   └── ...
│       ├── train_2d_masks
│       │   ├── <patient_name_slice_number.nii.gz>
│       │   └── ...
│       ├── val_2d_images
│       ├── val_2d_masks
│       ├── test_2d_images
│       ├── test_2d_masks
│       ├── image_embeddings
│       │   └── <model-checkpoint>
│       │       ├── train_2d_images
│       │       │   ├── <patient_name_slice_number.nii.gz>
│       │       │   └── ...
│       │       ├── val_2d_images
│       │       └── test_2d_images
│       └── image_positional_embeddings
│           └── <model-checkpoint>
│               └── fixed
│                   └── <patient_name_slice_number.h5>
The code is located in the `src` folder.
- Change the path to `sam_checkpoint` in `Configs/model_config/medsam_config.yaml`.
- If using the Comet-ml logger, fill in the API key and workspace name in `Configs/logger_config.yaml`.
Experiments can be run with:
$ bash src/Bash_scripts/run_experiments.sh <path-to-dataset-config> <class-id> <num-samples> <gpu-idx> <optional-experiment-log-name>
e.g.: bash src/Bash_scripts/run_experiments.sh data_config/ACDC_256.yaml 1 0 0
Options:
- `<path-to-dataset-config>`: `data_config/ACDC_256.yaml`, `data_config/CAMUS_512.yaml` or `data_config/HC_640.yaml`
- `<class-id>`: 1 or 3
- `<num-samples>`: 0 (all) or 10
- `<gpu-idx>`: depending on the server
The bash script runs `main.py`, which trains the prompt module.
`main.py` takes several input arguments, but most of them are filled in automatically by the bash script.
Input of main.py:
# These are the paths to the data and output folder
--data_dir # path to directory where we saved our (preprocessed) data
--output_dir # path to directory where we want to save our output model
# These are config file names located in src/Config
--data_config
--model_config
--module_config
--train_config
--logger_config
--loss_config # can be several file names
# Additional training configs (seed and gpu)
--seed
--num_gpu # number of GPU devices to use
--gpu_idx # otherwise, gpu index, if we want to use a specific gpu
# Training hyper-parameters that we should change according to the dataset
# Note: arguments start with 'train__' if modifying train_config,
# with 'data__' if modifying data_config, etc.
# The name is built from the hierarchical keywords leading to the value in the config, each separated by '__'
--data__use_precomputed_sam_embeddings # whether to use precomputed embeddings in data folder
--train__sam_positional_encoding # Whether to use fixed positional encoding from data folder
--train__train_indices # indices of labelled training data for the main segmentation task
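For reference, these double-underscore arguments map onto nested entries of the corresponding config file. The helper below is a hypothetical illustration of how such an override could be resolved, not the project's actual code:

```python
# Hypothetical helper (not the project's actual code): route a '__'-separated
# command-line override into the matching nested config.
def apply_override(configs: dict, flat_key: str, value) -> None:
    prefix, *keys = flat_key.split("__")   # e.g. 'train__train_indices' -> 'train', ['train_indices']
    node = configs[f"{prefix}_config"]     # the prefix selects the config file
    for key in keys[:-1]:                  # walk any intermediate hierarchical keys
        node = node.setdefault(key, {})
    node[keys[-1]] = value

configs = {"train_config": {"sam_positional_encoding": False}}
apply_override(configs, "train__sam_positional_encoding", True)
print(configs["train_config"])  # {'sam_positional_encoding': True}
```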
Most hyperparameters default to the values used for the experiments in our paper.
The main file can also be run directly, e.g.:
$ python src/main.py --data_dir <my-data-dir> --output_dir <my-output-dir> --data_config data_config/ACDC_256.yaml --data__class_to_segment 1 --seed 0