Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions model-zoo/models/mednist_ddpm/bundle/configs/common.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# This file defines common definitions used in training and inference, most importantly the network definition

imports:
- $import os
- $import datetime
- $import torch
- $import scripts
- $import monai
- $import generative
- $import torch.distributed as dist

image: $monai.utils.CommonKeys.IMAGE
label: $monai.utils.CommonKeys.LABEL
pred: $monai.utils.CommonKeys.PRED

is_dist: '$dist.is_initialized()'
rank: '$dist.get_rank() if @is_dist else 0'
is_not_rank0: '$@rank > 0'
device: '$torch.device(f"cuda:{@rank}" if torch.cuda.is_available() else "cpu")'

network_def:
_target_: generative.networks.nets.DiffusionModelUNet
spatial_dims: 2
in_channels: 1
out_channels: 1
num_channels: [64, 128, 128]
attention_levels: [false, true, true]
num_res_blocks: 1
num_head_channels: 128

network: $@network_def.to(@device)

bundle_root: .
ckpt_path: $@bundle_root + '/models/model.pt'
use_amp: true
image_dim: 64
image_size: [1, '@image_dim', '@image_dim']
num_train_timesteps: 1000

base_transforms:
- _target_: LoadImaged
keys: '@image'
image_only: true
- _target_: EnsureChannelFirstd
keys: '@image'
- _target_: ScaleIntensityRanged
keys: '@image'
a_min: 0.0
a_max: 255.0
b_min: 0.0
b_max: 1.0
clip: true

scheduler:
_target_: generative.networks.schedulers.DDPMScheduler
num_train_timesteps: '@num_train_timesteps'

inferer:
_target_: generative.inferers.DiffusionInferer
scheduler: '@scheduler'

38 changes: 38 additions & 0 deletions model-zoo/models/mednist_ddpm/bundle/configs/infer.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# This defines an inference script for generating a random image to a Pytorch file

batch_size: 1
num_workers: 0

noise: $torch.rand(1,1,@image_dim,@image_dim) # create a random image every time this program is run

out_file: "" # where to save the tensor to

# using a lambda this defines a simple sampling function used below
sample: '$lambda x: @inferer.sample(input_noise=x, diffusion_model=@network, scheduler=@scheduler)'

load_state: '$@network.load_state_dict(torch.load(@ckpt_path))' # command to load the saved model weights

save_trans:
_target_: Compose
transforms:
- _target_: ScaleIntensity
minv: 0.0
maxv: 255.0
- _target_: ToTensor
track_meta: false
- _target_: SaveImage
output_ext: "jpg"
resample: false
output_dtype: '$torch.uint8'
separate_folder: false
output_postfix: '@out_file'

# program to load the model weights, run `sample`, and store results to `out_file`
testing:
- '@load_state'
- '$torch.save(@sample(@noise.to(@device)), @out_file)'

#alternative version which saves to a jpg file
testing_jpg:
- '@load_state'
- '$@save_trans(@sample(@noise.to(@device))[0])'
21 changes: 21 additions & 0 deletions model-zoo/models/mednist_ddpm/bundle/configs/logging.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
[loggers]
keys=root

[handlers]
keys=consoleHandler

[formatters]
keys=fullFormatter

[logger_root]
level=INFO
handlers=consoleHandler

[handler_consoleHandler]
class=StreamHandler
level=INFO
formatter=fullFormatter
args=(sys.stdout,)

[formatter_fullFormatter]
format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
59 changes: 59 additions & 0 deletions model-zoo/models/mednist_ddpm/bundle/configs/metadata.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
{
"schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220729.json",
"version": "0.1.0",
"changelog": {
"0.1.0": "Initial version"
},
"monai_version": "1.0.0",
"pytorch_version": "1.10.2",
"numpy_version": "1.21.2",
"optional_packages_version": {"generative":"0.1.0"},
"task": "MedNIST Hand Generation",
"description": "",
"authors": "Walter Hugo Lopez Pinaya, Mark Graham, and Eric Kerfoot",
"copyright": "Copyright (c) KCL",
"references": [],
"intended_use": "This is suitable for research purposes only",
"image_classes": "Single channel magnitude data",
"data_source": "MedNIST",
"network_data_format": {
"inputs": {
"image": {
"type": "image",
"format": "magnitude",
"modality": "xray",
"num_channels": 1,
"spatial_shape": [
1,
64,
64
],
"dtype": "float32",
"value_range": [],
"is_patch_data": false,
"channel_def": {
"0": "image"
}
}
},
"outputs": {
"pred": {
"type": "image",
"format": "magnitude",
"modality": "xray",
"num_channels": 1,
"spatial_shape": [
1,
64,
64
],
"dtype": "float32",
"value_range": [],
"is_patch_data": false,
"channel_def": {
"0": "image"
}
}
}
}
}
157 changes: 157 additions & 0 deletions model-zoo/models/mednist_ddpm/bundle/configs/train.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# This defines the training script for the network

# choose a new directory for every run
output_dir: $datetime.datetime.now().strftime('./results/output_%y%m%d_%H%M%S')
dataset_dir: ./data

train_data:
_target_ : MedNISTDataset
root_dir: '@dataset_dir'
section: training
download: true
progress: false
seed: 0

val_data:
_target_ : MedNISTDataset
root_dir: '@dataset_dir'
section: validation
download: true
progress: false
seed: 0

train_datalist: '$[{"image": item["image"]} for item in @train_data.data if item["class_name"] == "Hand"]'
val_datalist: '$[{"image": item["image"]} for item in @val_data.data if item["class_name"] == "Hand"]'

batch_size: 8
num_substeps: 1
num_workers: 4
use_thread_workers: false

lr: 0.000025
rand_prob: 0.5
num_epochs: 75
val_interval: 5
save_interval: 5

train_transforms:
- _target_: RandAffined
keys: '@image'
rotate_range:
- ['$-np.pi / 36', '$np.pi / 36']
- ['$-np.pi / 36', '$np.pi / 36']
translate_range:
- [-1, 1]
- [-1, 1]
scale_range:
- [-0.05, 0.05]
- [-0.05, 0.05]
spatial_size: [64, 64]
padding_mode: "zeros"
prob: '@rand_prob'

train_ds:
_target_: Dataset
data: $@train_datalist
transform:
_target_: Compose
transforms: '$@base_transforms + @train_transforms'

train_loader:
_target_: ThreadDataLoader
dataset: '@train_ds'
batch_size: '@batch_size'
repeats: '@num_substeps'
num_workers: '@num_workers'
use_thread_workers: '@use_thread_workers'
persistent_workers: '$@num_workers > 0'
shuffle: true

val_ds:
_target_: Dataset
data: $@val_datalist
transform:
_target_: Compose
transforms: '@base_transforms'

val_loader:
_target_: DataLoader
dataset: '@val_ds'
batch_size: '@batch_size'
num_workers: '@num_workers'
persistent_workers: '$@num_workers > 0'
shuffle: false

lossfn:
_target_: torch.nn.MSELoss

optimizer:
_target_: torch.optim.Adam
params: $@network.parameters()
lr: '@lr'

prepare_batch:
_target_: scripts.DiffusionPrepareBatch
num_train_timesteps: '@num_train_timesteps'

val_handlers:
- _target_: StatsHandler
name: train_log
output_transform: '$lambda x: None'
_disabled_: '@is_not_rank0'

evaluator:
_target_: SupervisedEvaluator
device: '@device'
val_data_loader: '@val_loader'
network: '@network'
amp: '@use_amp'
inferer: '@inferer'
prepare_batch: '@prepare_batch'
key_val_metric:
val_mean_abs_error:
_target_: MeanAbsoluteError
output_transform: $monai.handlers.from_engine([@pred, @label])
metric_cmp_fn: '$scripts.inv_metric_cmp_fn'
val_handlers: '$list(filter(bool, @val_handlers))'

handlers:
- _target_: CheckpointLoader
_disabled_: $not os.path.exists(@ckpt_path)
load_path: '@ckpt_path'
load_dict:
model: '@network'
- _target_: ValidationHandler
validator: '@evaluator'
epoch_level: true
interval: '@val_interval'
- _target_: CheckpointSaver
save_dir: '@output_dir'
save_dict:
model: '@network'
save_interval: '@save_interval'
save_final: true
epoch_level: true
_disabled_: '@is_not_rank0'

trainer:
_target_: SupervisedTrainer
max_epochs: '@num_epochs'
device: '@device'
train_data_loader: '@train_loader'
network: '@network'
loss_function: '@lossfn'
optimizer: '@optimizer'
inferer: '@inferer'
prepare_batch: '@prepare_batch'
key_train_metric:
train_acc:
_target_: MeanSquaredError
output_transform: $monai.handlers.from_engine([@pred, @label])
metric_cmp_fn: '$scripts.inv_metric_cmp_fn'
train_handlers: '$list(filter(bool, @handlers))'
amp: '@use_amp'

training:
- '$monai.utils.set_determinism(0)'
- '$@trainer.run()'
30 changes: 30 additions & 0 deletions model-zoo/models/mednist_ddpm/bundle/configs/train_multigpu.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# This can be mixed in with the training script to enable multi-GPU training

network:
_target_: torch.nn.parallel.DistributedDataParallel
module: $@network_def.to(@device)
device_ids: ['@device']
find_unused_parameters: true

tsampler:
_target_: DistributedSampler
dataset: '@train_ds'
even_divisible: true
shuffle: true
train_loader#sampler: '@tsampler'
train_loader#shuffle: false

vsampler:
_target_: DistributedSampler
dataset: '@val_ds'
even_divisible: false
shuffle: false
val_loader#sampler: '@vsampler'

training:
- $import torch.distributed as dist
- $dist.init_process_group(backend='nccl')
- $torch.cuda.set_device(@device)
- $monai.utils.set_determinism(seed=123),
- $@trainer.run()
- $dist.destroy_process_group()
Loading