In this study, we propose UniVG, a novel generative foundation model for universal few-shot vascular image segmentation based on compositional learning and few-shot generative adaptation. By decomposing and recombining vascular structures with diverse backgrounds, our framework can synthesize highly realistic and diverse vessel images using minimal training data. UniVG has been systematically validated across 11 vascular datasets spanning five different modalities. Experimental results demonstrate its superior performance over existing methods in few-shot scenarios (using as few as five annotated images), confirming the model's excellent generalization capability and cross-domain adaptability.
We have established this repository to support the reproducibility of our work:
"Generative Data-engine Foundation Model for Universal Few-shot 2D Vascular Image Segmentation"
The UniVG-58K dataset presented in this paper comprises both pre-training data and downstream task data, which can be accessed at https://huggingface.co/datasets/xinaloha/UniVG.
| Component | Status | Release Date |
|---|---|---|
| Spatial Colonization Algorithm (SCA) for vascular structure synthesis | ✅ Available | 2024.12.12 |
| Pre-trained Foundation Model & Training Code | ✅ Available | 2025.01.11 |
| Downstream Modality Fine-tuning Code | ✅ Available | 2025.01.26 |
| Downstream Segmentation Training Code | ✅ Available | 2025.01.26 |
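The Spatial Colonization Algorithm listed above grows vessel-like branching structures by repeatedly extending a tree toward attraction points. A minimal 2-D sketch of the classic space-colonization loop (illustrative only; `space_colonization` and its default radii are assumptions, not the repository's R-SCA implementation):

```python
import math
import random

def space_colonization(attractors, root, influence=20.0, kill=4.0, step=3.0, iters=50):
    """Grow a branching tree toward attraction points (2-D illustrative sketch)."""
    nodes = [root]
    parents = {0: None}              # node index -> parent index
    attractors = list(attractors)
    for _ in range(iters):
        # Each remaining attractor pulls its nearest node within the influence radius.
        pull = {}
        for ax, ay in attractors:
            best, best_d = None, influence
            for i, (nx, ny) in enumerate(nodes):
                d = math.hypot(ax - nx, ay - ny)
                if d < best_d:
                    best, best_d = i, d
            if best is not None:
                nx, ny = nodes[best]
                sx, sy = pull.get(best, (0.0, 0.0))
                d = max(best_d, 1e-6)
                pull[best] = (sx + (ax - nx) / d, sy + (ay - ny) / d)
        if not pull:
            break  # no attractor is reachable any more
        # Spawn a new node one step along each averaged pull direction.
        for i, (dx, dy) in pull.items():
            norm = math.hypot(dx, dy) or 1.0
            nx, ny = nodes[i]
            parents[len(nodes)] = i
            nodes.append((nx + step * dx / norm, ny + step * dy / norm))
        # Attractors the tree has reached (within the kill distance) are removed.
        attractors = [a for a in attractors
                      if all(math.hypot(a[0] - n[0], a[1] - n[1]) > kill for n in nodes)]
    return nodes, parents
```

Rendering the resulting parent links as thick-to-thin segments yields the binary vessel masks that later stages consume.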
```bash
# Clone the repository
git clone https://github.com/XinAloha/UniVG.git
cd UniVG/R-SCA

# Install dependencies
pip install numpy opencv-python Pillow matplotlib scipy shapely PyYAML tqdm
```

Generate synthetic vascular structures:

```bash
python Main.py --input-dir ./RealCoronaryArteryMask --output-dir ./output --modality CoronaryArtery
```

```bash
cd UniVG/pretrainedModel

# Install dependencies
pip install torch torchvision pytorch-lightning albumentations opencv-python einops
```

First, pre-train the LDM on your vascular image dataset to learn domain-specific features:
```bash
python train_ldm.py \
    --resume_path /path/to/sd-v1-5.ckpt \
    --data_path /path/to/vascular/images \
    --batch_size 2 \
    --learning_rate 1e-4 \
    --max_epochs 100 \
    --checkpoint_dir ./ldm_checkpoints/
```

Dataset structure for LDM:

```
vascular_images/
├── image_001.png
├── image_002.png
└── ...
```
After pre-training the LDM, use it to initialize ControlNet for conditional generation:
Train from pretrained LDM weights:

```bash
python train_controlnet.py \
    --pretrained ./ldm_checkpoints/ldm-epoch=100.ckpt \
    --data_path /path/to/paired/dataset \
    --batch_size 2 \
    --max_epochs 2000
```

Generate vascular images from condition inputs:

```bash
python inference_controlnet.py \
    --checkpoint /path/to/trained/model.ckpt \
    --input_path /path/to/input/masks \
    --output_path /path/to/output/images \
    --ddim_steps 20
```

For LDM pre-training (unconditional generation):
```
vascular_images/
├── image_001.png
├── image_002.png
├── image_003.png
└── ...
```
For ControlNet training (conditional generation):
```
paired_dataset/
├── source/   # Condition images (masks, edges, etc.)
│   ├── image_001.png
│   └── ...
└── target/   # Target images (ground truth)
    ├── image_001.png
    └── ...
```
For fine-tuning the pre-trained model on specific downstream vascular modalities, we recommend SD-Trainer (lora-scripts), a LoRA & DreamBooth training GUI built on kohya-ss's trainer.
```bash
# Clone and install
git clone --recurse-submodules https://github.com/Akegarasu/lora-scripts
cd lora-scripts

# Windows
./install.ps1   # or install-cn.ps1 for China mainland
./run_gui.ps1

# Linux
bash install.bash
bash run_gui.sh
```

The GUI will open at http://127.0.0.1:28000.
| Parameter | Value | Description |
|---|---|---|
| Base Model | UniVG checkpoint | Pre-trained model path |
| Network Type | LoRA | Low-Rank Adaptation |
| Network Rank | 32 | Higher = more capacity |
| Learning Rate | 1e-4 ~ 5e-5 | Lower for few-shot |
| Resolution | 512 | Match pre-training |
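LoRA keeps each base weight matrix W frozen and learns a rank-r update B·A scaled by alpha/r, which is why the Network Rank above trades adaptation capacity against parameter count. A dependency-free sketch of the forward pass (illustrative; `lora_forward` is not SD-Trainer's implementation):

```python
def matmul(M, N):
    """Naive matrix product for small list-of-lists matrices."""
    return [[sum(m * n for m, n in zip(row, col)) for col in zip(*N)] for row in M]

def lora_forward(x, W, A, B, alpha=32.0):
    """y = W x + (alpha / r) * B (A x): frozen base weight plus rank-r LoRA update."""
    r = len(A)                      # rank = number of rows of A
    base = matmul(W, x)             # frozen pre-trained path
    update = matmul(B, matmul(A, x))  # trainable low-rank path
    scale = alpha / r
    return [[b + scale * u for b, u in zip(brow, urow)]
            for brow, urow in zip(base, update)]
```

Because B is initialized to zero, the adapted model starts out exactly equal to the pre-trained one, which suits few-shot fine-tuning.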
Train a UNet segmentation model jointly on real and generated data.
```bash
cd UniVG/Segmentation

# Install dependencies
pip install -r requirements.txt
```

Prepare your data in the following structure:
Real data:
```
real_data/
├── train_origin/   # Training images
├── train_mask/     # Training masks
├── test_origin/    # Test images
└── test_mask/      # Test masks
```
Generated data:
```
generated_data/
├── images/   # Generated images
└── masks/    # Corresponding masks
```
Train with joint real and generated data:
```bash
python train.py \
    --real_data_path /path/to/real/data \
    --fake_data_path /path/to/generated/data \
    --output_dir ./checkpoints \
    --batch_size 2 \
    --learning_rate 3e-4 \
    --steps 75000 \
    --save_steps 6250
```

Training Parameters:
| Parameter | Default | Description |
|---|---|---|
| `--real_data_path` | required | Path to real training data |
| `--fake_data_path` | required | Path to generated training data |
| `--output_dir` | `./checkpoints` | Directory to save checkpoints |
| `--batch_size` | 2 | Batch size per data source |
| `--learning_rate` | 3e-4 | Learning rate |
| `--steps` | 75000 | Total training steps |
| `--save_steps` | 6250 | Checkpoint save interval |
| `--batch_norm` | False | Use batch normalization |
| `--resume` | None | Resume from checkpoint |
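Note that `--batch_size` counts samples per data source, so each optimization step sees equal numbers of real and generated images. A stdlib sketch of such a joint sampler (`joint_batches` is a hypothetical helper, not the repository's data loader):

```python
import random

def joint_batches(real, generated, batch_size=2, steps=5, seed=0):
    """Yield batches holding batch_size samples from each of the two sources."""
    rng = random.Random(seed)
    for _ in range(steps):
        # Equal draws per source keep real and synthetic gradients balanced.
        yield rng.sample(real, batch_size) + rng.sample(generated, batch_size)
```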
Evaluate trained model on test data:
```bash
python evaluate.py \
    --checkpoint ./checkpoints/unet_final.pth \
    --test_data_path /path/to/test/data \
    --save_predictions \
    --output_dir ./results
```

Evaluation Metrics:
- Dice: Dice coefficient
- IoU: Intersection over Union
- clDice: Centerline Dice (for tubular structures)
- NSD: Normalized Surface Dice
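The overlap metrics above reduce to set operations on foreground pixels; a minimal sketch for Dice and IoU (clDice and NSD additionally require skeletons and surface distances, so they are omitted here):

```python
def dice_iou(pred, target):
    """Dice and IoU for binary masks given as sets of foreground pixel coordinates."""
    inter = len(pred & target)
    union = len(pred | target)
    # Convention: two empty masks count as a perfect match.
    dice = 2.0 * inter / (len(pred) + len(target)) if union else 1.0
    iou = inter / union if union else 1.0
    return dice, iou
```

For thin tubular structures, Dice can look high while branches are missed, which is why the paper also reports clDice on vessel centerlines.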
