This code implements a baseline 3D multi-label classification model for intracranial hemorrhage (ICH) detection using PyTorch Lightning and MONAI. Here's a breakdown of its key components:
- Handles loading and preprocessing of 3D NIfTI volumes
- Extracts multi-label information from segmentation masks (6 classes: any_bleed, EDH, IPH, IVH, SAH, SDH)
- Implements data augmentation for training
- Calculates class weights to handle imbalanced data
- Provides basic train/val/test splits with proper transforms
- Uses a 3D ResNet and/or DenseNet backbone from MONAI
- Implements multi-label classification with weighted BCE loss
- Includes comprehensive metrics tracking (AUC, precision, recall, F1, etc.)
- Supports different learning rate schedulers
- Provides detailed test evaluation with ROC curves and performance metrics
- Uses PyTorch Lightning for streamlined training
- Supports top-k ModelCheckpoint and EarlyStopping
- Supports mixed-precision training on multiple GPUs
- Logs metrics to TensorBoard
- Supports both a training/validation mode and a test-only mode
- 3D volumetric processing with proper spatial transforms
- Multi-label classification for hemorrhage subtypes
- Class imbalance handling through weighted loss
- Comprehensive evaluation metrics
- Detailed test results visualization
- Reproducible experiment tracking
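The README does not say how hemorrhage subtypes are encoded in the segmentation masks. Below is a minimal sketch of the mask-to-label step. It assumes each subtype is stored as a distinct integer voxel value (1=EDH, 2=IPH, 3=IVH, 4=SAH, 5=SDH, a hypothetical encoding) and that any_bleed is derived as the logical OR of the five subtypes. `mask_to_multilabel` is a hypothetical name.

```python
import numpy as np

# Class order as listed in the README.
CLASS_NAMES = ["any_bleed", "EDH", "IPH", "IVH", "SAH", "SDH"]
# ASSUMPTION (hypothetical encoding): mask voxel value -> subtype name, 0 = background.
MASK_VALUE_TO_SUBTYPE = {1: "EDH", 2: "IPH", 3: "IVH", 4: "SAH", 5: "SDH"}


def mask_to_multilabel(mask: np.ndarray) -> np.ndarray:
    """Convert an integer segmentation mask into a 6-dim binary label vector."""
    labels = np.zeros(len(CLASS_NAMES), dtype=np.float32)
    present = set(np.unique(mask).tolist())  # voxel values that occur in the volume
    for value, subtype in MASK_VALUE_TO_SUBTYPE.items():
        if value in present:
            labels[CLASS_NAMES.index(subtype)] = 1.0
    # any_bleed is positive whenever at least one subtype is present.
    labels[0] = float(labels[1:].any())
    return labels
```

With real data, the mask would first be loaded from its `.nii.gz` file, for example with `nibabel.load(path).get_fdata()`.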
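The README does not specify how the class weights are derived. One common choice, sketched here as an assumption, is a per-class positive weight equal to the ratio of negative to positive training samples. This is the form `torch.nn.BCEWithLogitsLoss` accepts through its `pos_weight` argument. The NumPy version below reproduces that loss formula so the arithmetic is easy to check. `compute_pos_weights` and `weighted_bce_with_logits` are hypothetical names.

```python
import numpy as np


def compute_pos_weights(labels: np.ndarray) -> np.ndarray:
    """Per-class positive weight = (#negatives) / (#positives).

    labels: (num_samples, num_classes) binary matrix.
    Classes with zero positives are divided by 1 to avoid division by zero.
    """
    positives = labels.sum(axis=0)
    negatives = labels.shape[0] - positives
    return negatives / np.maximum(positives, 1.0)


def weighted_bce_with_logits(logits, targets, pos_weight):
    """Mean weighted BCE, same formula as torch.nn.BCEWithLogitsLoss(pos_weight=...)."""
    # log(sigmoid(x)) and log(1 - sigmoid(x)), computed stably via logaddexp.
    log_p = -np.logaddexp(0.0, -logits)
    log_not_p = -np.logaddexp(0.0, logits)
    # Only the positive term is scaled by pos_weight, which up-weights rare classes.
    loss = -(pos_weight * targets * log_p + (1.0 - targets) * log_not_p)
    return loss.mean()
```

In the Lightning module, the equivalent would be `torch.nn.BCEWithLogitsLoss(pos_weight=torch.as_tensor(weights, dtype=torch.float32))`, with the weights computed once from the training split.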
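For the per-class metrics, here is a dependency-free sketch. ROC AUC is computed as the probability that a randomly chosen positive scores higher than a randomly chosen negative (the Mann-Whitney formulation, with ties counted as one half). Precision, recall, and F1 are computed at a fixed decision threshold. The 0.5 threshold and the function names are assumptions, not taken from the original code. The pairwise comparison scales with positives times negatives; in practice a library routine such as `sklearn.metrics.roc_auc_score` would be the usual choice.

```python
import numpy as np


def roc_auc(y_true: np.ndarray, scores: np.ndarray) -> float:
    """ROC AUC = P(random positive outranks random negative); ties count 0.5."""
    pos = scores[y_true == 1]
    neg = scores[y_true == 0]
    if len(pos) == 0 or len(neg) == 0:
        return float("nan")  # AUC is undefined without both classes present
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return float((greater + 0.5 * ties) / (len(pos) * len(neg)))


def per_class_metrics(y_true, probs, class_names, threshold=0.5):
    """Per-class AUC, precision, recall and F1 for (num_samples, num_classes) arrays."""
    results = {}
    for c, name in enumerate(class_names):
        t, p = y_true[:, c], probs[:, c]
        pred = (p >= threshold).astype(int)
        tp = int(((pred == 1) & (t == 1)).sum())
        fp = int(((pred == 1) & (t == 0)).sum())
        fn = int(((pred == 0) & (t == 1)).sum())
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        results[name] = {"auc": roc_auc(t, p), "precision": precision,
                         "recall": recall, "f1": f1}
    return results
```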
Organize your data in the expected directory structure:

    data_dir/
        train/
            images/*.nii.gz
            masks/*.nii.gz
        val/
            images/*.nii.gz
            masks/*.nii.gz
        test/
            images/*.nii.gz
            masks/*.nii.gz
        negative/          (optional)
            images/*.nii.gz

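Below is a minimal sketch of how this layout might be indexed. It assumes each mask shares its image's filename and that volumes under `negative/` have no mask and therefore all-zero labels. `collect_samples` and `include_negatives` are hypothetical. The original script may pair files differently or assign negatives to a specific split.

```python
from pathlib import Path


def collect_samples(data_dir, split, include_negatives=False):
    """Return a list of {"image": ..., "mask": ...} dicts for one split."""
    root = Path(data_dir)
    samples = []
    for image in sorted((root / split / "images").glob("*.nii.gz")):
        # ASSUMPTION: the mask has the same filename as its image.
        mask = root / split / "masks" / image.name
        if not mask.exists():
            raise FileNotFoundError(f"no mask found for {image}")
        samples.append({"image": str(image), "mask": str(mask)})
    negative_images = root / "negative" / "images"
    if include_negatives and negative_images.is_dir():
        for image in sorted(negative_images.glob("*.nii.gz")):
            # Negatives have no mask; downstream code would assign all-zero labels.
            samples.append({"image": str(image), "mask": None})
    return samples
```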
Run training:

    python ich_classifier.py --data_dir /path/to/data --gpus 1

To test a trained model:

    python ich_classifier.py --data_dir /path/to/data --test_only --checkpoint_path /path/to/checkpoint.ckpt

The script produces:
- Training progress and validation metrics
- Best model checkpoints
- TensorBoard logs for visualization
- Comprehensive test results, including ROC curves and per-class performance metrics

View the TensorBoard logs with:

    tensorboard --logdir=lightning_logs/ich_classification
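The training and test commands imply a command-line interface along these lines. This is a sketch using `argparse`. The flag names come from the commands, while the defaults, help strings, and the check that `--test_only` requires `--checkpoint_path` are assumptions.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="3D multi-label ICH classifier")
    parser.add_argument("--data_dir", required=True,
                        help="Root directory containing train/val/test (and optional negative/)")
    # ASSUMPTION: default of 1 GPU, matching the example training command.
    parser.add_argument("--gpus", type=int, default=1, help="Number of GPUs")
    parser.add_argument("--test_only", action="store_true",
                        help="Skip training and only evaluate a checkpoint")
    parser.add_argument("--checkpoint_path", default=None,
                        help="Path to a .ckpt file to evaluate")
    return parser


def parse_args(argv=None):
    parser = build_parser()
    args = parser.parse_args(argv)
    # ASSUMPTION: test-only mode needs a trained checkpoint to load.
    if args.test_only and args.checkpoint_path is None:
        parser.error("--test_only requires --checkpoint_path")  # exits with status 2
    return args
```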