Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions examples/images/dreambooth/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,29 @@ torchrun --nproc_per_node 2 train_dreambooth_colossalai.py \
--placement="cuda"
```

## New API
We have modified our previous implementation of Dreambooth with our new Booster API, which offers a more flexible and efficient way to train your model. The new API is more user-friendly and easy to use. You can find the new API in `train_dreambooth_colossalai.py`.
We also offer a shell script `test_ci.sh` for you to go through all our plugins for the booster.
For more information about the booster API you can refer to https://colossalai.org/docs/basics/booster_api/.

## Performance

| Strategy | #GPU | Batch Size | GPU RAM(GB) | speedup |
|:--------------:|:----:|:----------:|:-----------:|:-------:|
| Traditional | 1 | 16 | oom | \ |
| Traditional | 1 | 8 | 61.81 | 1 |
| torch_ddp | 4 | 16 | oom | \ |
| torch_ddp | 4 | 8 | 41.97 | 0.97 |
| gemini | 4 | 16 | 53.29 | \ |
| gemini | 4 | 8 | 29.36 | 2.00 |
| low_level_zero | 4 | 16 | 52.80 | \ |
| low_level_zero | 4 | 8 | 28.87 | 2.02 |

The evaluation is performed on 4 Nvidia A100 GPUs with 80GB memory each, with GPU 0 & 1, 2 & 3 connected with NVLink.
We finetuned the [stable-diffusion-v1-4](https://huggingface.co/stabilityai/stable-diffusion-v1-4) model with 512x512 resolution on the [Teyvat](https://huggingface.co/datasets/Fazzie/Teyvat) dataset and compared
the memory cost and the throughput for the plugins.


## Inference

Once you have trained a model using the above command, inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier` (e.g. `--instance_prompt="a photo of sks dog"` in the above example) in your prompt.
Expand Down
24 changes: 10 additions & 14 deletions examples/images/dreambooth/colossalai.sh
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@
export MODEL_NAME= <Your Pretrained Model Path>
export INSTANCE_DIR= <Your Input Pics Path>
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"

HF_DATASETS_OFFLINE=1
TRANSFORMERS_OFFLINE=1
HF_DATASETS_OFFLINE=1
TRANSFORMERS_OFFLINE=1
DIFFUSERS_OFFLINE=1

torchrun --nproc_per_node 2 --master_port=25641 train_dreambooth_colossalai.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--instance_prompt="a photo of a dog" \
torchrun --nproc_per_node 4 --standalone train_dreambooth_colossalai.py \
--pretrained_model_name_or_path="/data/dreambooth/diffuser/stable-diffusion-v1-4" \
--instance_data_dir="/data/dreambooth/Teyvat/data" \
--output_dir="./weight_output" \
--instance_prompt="a picture of a dog" \
--resolution=512 \
--plugin="gemini" \
--train_batch_size=1 \
--gradient_accumulation_steps=1 \
--learning_rate=5e-6 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=200 \
--placement="cuda" \
--test_run=True \
--placement="auto" \
6 changes: 3 additions & 3 deletions examples/images/dreambooth/dreambooth.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
python train_dreambooth.py \
--pretrained_model_name_or_path= ## Your Model Path \
--instance_data_dir= ## Your Training Input Pics Path \
--output_dir="path-to-save-model" \
--pretrained_model_name_or_path="/data/dreambooth/diffuser/stable-diffusion-v1-4" \
--instance_data_dir="/data/dreambooth/Teyvat/data" \
--output_dir="./weight_output" \
--instance_prompt="a photo of a dog" \
--resolution=512 \
--train_batch_size=1 \
Expand Down
25 changes: 25 additions & 0 deletions examples/images/dreambooth/test_ci.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/bin/bash
# CI smoke test: installs the example's requirements, then launches
# train_dreambooth_colossalai.py once per booster plugin listed in the loop.
# -x echoes every command, -e aborts on the first failing command.
set -xe
pip install -r requirements.txt

# These must be exported: a bare `VAR=1` only sets a shell-local variable,
# which the torchrun child processes would never see.
export HF_DATASETS_OFFLINE=1
export TRANSFORMERS_OFFLINE=1
export DIFFUSERS_OFFLINE=1

# Other plugins accepted by --plugin: "torch_ddp" "torch_ddp_fp16" "low_level_zero"
for plugin in "gemini"; do
    torchrun --nproc_per_node 4 --standalone train_dreambooth_colossalai.py \
        --pretrained_model_name_or_path="/data/dreambooth/diffuser/stable-diffusion-v1-4" \
        --instance_data_dir="/data/dreambooth/Teyvat/data" \
        --output_dir="./weight_output" \
        --instance_prompt="a picture of a dog" \
        --resolution=512 \
        --plugin="$plugin" \
        --train_batch_size=1 \
        --learning_rate=5e-6 \
        --lr_scheduler="constant" \
        --lr_warmup_steps=0 \
        --test_run=True \
        --num_class_images=200 \
        --placement="auto" # "cuda"
done
100 changes: 59 additions & 41 deletions examples/images/dreambooth/train_dreambooth_colossalai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
from pathlib import Path
from typing import Optional
import shutil

import torch
import torch.nn.functional as F
Expand All @@ -21,9 +22,12 @@
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer
from colossalai.zero import ColoInitContext
from colossalai.zero.gemini import get_static_torch_model
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin

disable_existing_loggers()
logger = get_dist_logger()
Expand Down Expand Up @@ -58,6 +62,13 @@ def parse_args(input_args=None):
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--externel_unet_path",
type=str,
default=None,
required=False,
help="Path to the externel unet model.",
)
parser.add_argument(
"--revision",
type=str,
Expand Down Expand Up @@ -187,12 +198,19 @@ def parse_args(input_args=None):
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
parser.add_argument("--test_run", default=False, help="Whether to use a smaller dataset for test run.")
parser.add_argument(
"--hub_model_id",
type=str,
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument('-p',
'--plugin',
type=str,
default='torch_ddp',
choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
help="plugin to use")
parser.add_argument(
"--logging_dir",
type=str,
Expand Down Expand Up @@ -250,6 +268,7 @@ def __init__(
class_prompt=None,
size=512,
center_crop=False,
test=False,
):
self.size = size
self.center_crop = center_crop
Expand All @@ -260,6 +279,8 @@ def __init__(
raise ValueError("Instance images root doesn't exists.")

self.instance_images_path = list(Path(instance_data_root).iterdir())
if test:
self.instance_images_path = self.instance_images_path[:10]
self.num_instance_images = len(self.instance_images_path)
self.instance_prompt = instance_prompt
self._length = self.num_instance_images
Expand Down Expand Up @@ -339,18 +360,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
return f"{organization}/{model_id}"


# Gemini + ZeRO DDP
# Wraps `model` in colossalai's GeminiDDP, passing the current device, the
# requested `placement_policy`, pin_memory=True and search_range_mb=64, and
# returns the wrapped model.
# NOTE(review): the name reads "dpp" rather than "ddp" -- presumably a typo; it is
# left as-is because the call site in main() uses this spelling.
def gemini_zero_dpp(model: torch.nn.Module, placement_policy: str = "auto"):
# Imported inside the function rather than at module top.
from colossalai.nn.parallel import GeminiDDP

model = GeminiDDP(model,
device=get_current_device(),
placement_policy=placement_policy,
pin_memory=True,
search_range_mb=64)
return model


def main(args):
if args.seed is None:
colossalai.launch_from_torch(config={})
Expand Down Expand Up @@ -392,7 +401,7 @@ def main(args):
images = pipeline(example["prompt"]).images

for i, image in enumerate(images):
hash_image = hashlib.sha1(image.tobytes()).hexdigest()
hash_image = hashlib.sha256(image.tobytes()).hexdigest()
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)

Expand Down Expand Up @@ -452,12 +461,18 @@ def main(args):
revision=args.revision,
)

logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
with ColoInitContext(device=get_current_device()):

if args.externel_unet_path is None:
logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
subfolder="unet",
revision=args.revision,
low_cpu_mem_usage=False)
subfolder="unet",
revision=args.revision,
low_cpu_mem_usage=False)
else:
logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0])
unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path,
revision=args.revision,
low_cpu_mem_usage=False)

vae.requires_grad_(False)
text_encoder.requires_grad_(False)
Expand All @@ -468,10 +483,22 @@ def main(args):
if args.scale_lr:
args.learning_rate = args.learning_rate * args.train_batch_size * world_size

unet = gemini_zero_dpp(unet, args.placement)
# Use Booster API to use Gemini/Zero with ColossalAI

booster_kwargs = {}
if args.plugin == 'torch_ddp_fp16':
booster_kwargs['mixed_precision'] = 'fp16'
if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin()
elif args.plugin == 'gemini':
plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)

booster = Booster(plugin=plugin, **booster_kwargs)

# config optimizer for colossalai zero
optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
optimizer = HybridAdam(unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)

# load noise_scheduler
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
Expand All @@ -486,6 +513,7 @@ def main(args):
tokenizer=tokenizer,
size=args.resolution,
center_crop=args.center_crop,
test=args.test_run
)

def collate_fn(examples):
Expand Down Expand Up @@ -554,6 +582,8 @@ def collate_fn(examples):
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

unet, optimizer, _, _, lr_scheduler = booster.boost(unet, optimizer, lr_scheduler=lr_scheduler)

# Train!
total_batch_size = args.train_batch_size * world_size

Expand Down Expand Up @@ -642,36 +672,24 @@ def collate_fn(examples):

if global_step % args.save_steps == 0:
torch.cuda.synchronize()
torch_unet = get_static_torch_model(unet)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
if local_rank == 0:
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=torch_unet,
revision=args.revision,
)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
pipeline.save_pretrained(save_path)
if not os.path.exists(os.path.join(save_path, "config.json")):
shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
if global_step >= args.max_train_steps:
break

torch.cuda.synchronize()
unet = get_static_torch_model(unet)

booster.save_model(unet, os.path.join(args.output_dir, "diffusion_pytorch_model.bin"))
logger.info(f"Saving model checkpoint to {args.output_dir} on rank {local_rank}")
if local_rank == 0:
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=unet,
revision=args.revision,
)

pipeline.save_pretrained(args.output_dir)
logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])

if not os.path.exists(os.path.join(args.output_dir, "config.json")):
shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), args.output_dir)
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)


# Script entry point: parse the command-line arguments and hand them to main().
if __name__ == "__main__":
args = parse_args()
main(args)
Loading