3 changes: 1 addition & 2 deletions examples/images/dreambooth/test_ci.sh
@@ -20,6 +20,5 @@ for plugin in "gemini"; do
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--test_run=True \
--num_class_images=200 \
--placement="auto" # "cuda"
--num_class_images=200
done
53 changes: 27 additions & 26 deletions examples/images/dreambooth/train_dreambooth_colossalai.py
Expand Up @@ -2,9 +2,9 @@
import hashlib
import math
import os
import shutil
from pathlib import Path
from typing import Optional
import shutil

import torch
import torch.nn.functional as F
@@ -19,15 +19,15 @@
from transformers import AutoTokenizer, PretrainedConfig

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
from colossalai.zero.gemini import get_static_torch_model
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin

disable_existing_loggers()
logger = get_dist_logger()
@@ -138,10 +138,10 @@ def parse_args(input_args=None):
" resolution"),
)
parser.add_argument(
"--placement",
type=str,
default="cpu",
help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
"--offload_optim_frac",
type=float,
default=1.0,
help="Fraction of optimizer states to be offloaded. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--center_crop",
@@ -461,18 +461,17 @@ def main(args):
revision=args.revision,
)


if args.externel_unet_path is None:
logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
subfolder="unet",
revision=args.revision,
low_cpu_mem_usage=False)
subfolder="unet",
revision=args.revision,
low_cpu_mem_usage=False)
else:
logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0])
unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path,
revision=args.revision,
low_cpu_mem_usage=False)
revision=args.revision,
low_cpu_mem_usage=False)

vae.requires_grad_(False)
text_encoder.requires_grad_(False)
@@ -491,30 +490,31 @@ def main(args):
if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin()
elif args.plugin == 'gemini':
plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5)
plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, strict_ddp_mode=True, initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
plugin = LowLevelZeroPlugin(initial_scale=2**5)

booster = Booster(plugin=plugin, **booster_kwargs)

# config optimizer for colossalai zero
optimizer = HybridAdam(unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
optimizer = HybridAdam(unet.parameters(),
lr=args.learning_rate,
initial_scale=2**5,
clipping_norm=args.max_grad_norm)

# load noise_scheduler
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

# prepare dataset
logger.info(f"Prepare dataset from {args.instance_data_dir}", ranks=[0])
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
class_prompt=args.class_prompt,
tokenizer=tokenizer,
size=args.resolution,
center_crop=args.center_crop,
test=args.test_run
)
train_dataset = DreamBoothDataset(instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
class_prompt=args.class_prompt,
tokenizer=tokenizer,
size=args.resolution,
center_crop=args.center_crop,
test=args.test_run)

def collate_fn(examples):
input_ids = [example["instance_prompt_ids"] for example in examples]
@@ -690,6 +690,7 @@ def collate_fn(examples):
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)


if __name__ == "__main__":
args = parse_args()
main(args)
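The substantive change in this file is replacing Gemini's `placement_policy` knob with `offload_optim_frac`, which sets what fraction of the optimizer states is offloaded to CPU. Below is a minimal sketch of the new plugin and optimizer wiring, assuming a torchrun-style launch and a ColossalAI build whose `GeminiPlugin` accepts `offload_optim_frac`; the stand-in `nn.Linear` model and the hyperparameter values are illustrative, not taken from the script.

```python
import colossalai
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})   # distributed init; required before using the plugin

model = nn.Linear(16, 16).cuda()          # stand-in for the UNet in the real script
plugin = GeminiPlugin(offload_optim_frac=1.0,   # 1.0 = all optimizer states on CPU, 0.0 = all on GPU
                      strict_ddp_mode=True,
                      initial_scale=2**5)
booster = Booster(plugin=plugin)

optimizer = HybridAdam(model.parameters(), lr=5e-6,
                       initial_scale=2**5, clipping_norm=1.0)
# boost() returns (model, optimizer, criterion, dataloader, lr_scheduler)
model, optimizer, _, _, _ = booster.boost(model, optimizer)
```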
30 changes: 16 additions & 14 deletions examples/images/dreambooth/train_dreambooth_colossalai_lora.py
Expand Up @@ -2,9 +2,9 @@
import hashlib
import math
import os
import shutil
from pathlib import Path
from typing import Optional
import shutil

import torch
import torch.nn.functional as F
@@ -21,15 +21,15 @@
from transformers import AutoTokenizer, PretrainedConfig

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer
from colossalai.zero.gemini import get_static_torch_model
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin

disable_existing_loggers()
logger = get_dist_logger()
@@ -459,18 +459,17 @@ def main(args):
revision=args.revision,
)


if args.externel_unet_path is None:
logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
subfolder="unet",
revision=args.revision,
low_cpu_mem_usage=False)
subfolder="unet",
revision=args.revision,
low_cpu_mem_usage=False)
else:
logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0])
unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path,
revision=args.revision,
low_cpu_mem_usage=False)
revision=args.revision,
low_cpu_mem_usage=False)
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
subfolder="unet",
revision=args.revision,
@@ -490,8 +489,7 @@ def main(args):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]

lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim)
lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)
@@ -513,14 +511,17 @@ def main(args):
if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin()
elif args.plugin == 'gemini':
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2 ** 5)
plugin = GeminiPlugin(strict_ddp_mode=True, initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
plugin = LowLevelZeroPlugin(initial_scale=2**5)

booster = Booster(plugin=plugin, **booster_kwargs)

# config optimizer for colossalai zero
optimizer = HybridAdam(unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
optimizer = HybridAdam(unet.parameters(),
lr=args.learning_rate,
initial_scale=2**5,
clipping_norm=args.max_grad_norm)

# load noise_scheduler
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
@@ -711,6 +712,7 @@ def collate_fn(examples):
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)


if __name__ == "__main__":
args = parse_args()
main(args)
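In the LoRA variant the diff only reflows the `LoRACrossAttnProcessor` call; the loop around it (partially visible above) follows the standard diffusers DreamBooth-LoRA recipe. Here is a hedged reconstruction of that loop, assuming an older diffusers release that still exposes `LoRACrossAttnProcessor` (import paths differ between diffusers versions, and `unet` is the `UNet2DConditionModel` loaded earlier in the script).

```python
# Hedged sketch of the attention-processor loop; not copied verbatim from the PR.
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRACrossAttnProcessor

lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # self-attention layers (attn1) have no cross-attention dimension
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size,
                                                   cross_attention_dim=cross_attention_dim)

unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)   # the only trainable parameters
```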
52 changes: 26 additions & 26 deletions examples/images/vit/vit_benchmark.py
@@ -1,19 +1,18 @@
import time

import torch
import tqdm
import transformers
from args import parse_benchmark_args
from transformers import ViTConfig, ViTForImageClassification
import tqdm

import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.utils import get_current_device
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam

from args import parse_benchmark_args

def format_num(num: int, bytes=False):
"""Scale bytes to its proper format, e.g. 1253656 => '1.20MB'"""
@@ -26,8 +25,13 @@ def format_num(num: int, bytes=False):


def get_data(batch_size, num_labels, num_channels=3, height=224, width=224):
pixel_values = torch.randn(batch_size, num_channels, height, width, device=torch.cuda.current_device(), dtype=torch.float)
labels = torch.randint(0, num_labels, (batch_size, ), device=torch.cuda.current_device(), dtype=torch.int64)
pixel_values = torch.randn(batch_size,
num_channels,
height,
width,
device=torch.cuda.current_device(),
dtype=torch.float)
labels = torch.randint(0, num_labels, (batch_size,), device=torch.cuda.current_device(), dtype=torch.int64)
return pixel_values, labels
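`get_data` fabricates a synthetic ImageNet-style batch directly on the GPU, so the benchmark never reads a real dataset. A quick illustrative check of the shapes and dtypes it returns (batch size and label count are hypothetical; a CUDA device is required):

```python
pixel_values, labels = get_data(batch_size=8, num_labels=1000)
assert pixel_values.shape == (8, 3, 224, 224) and pixel_values.dtype == torch.float32
assert labels.shape == (8,) and labels.dtype == torch.int64   # class ids in [0, 1000)
```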


@@ -55,11 +59,11 @@ def main():
transformers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()

# Whether to set limit on memory capacity
if args.mem_cap > 0:
colo_memory_cap(args.mem_cap)

# Build ViT model
config = ViTConfig.from_pretrained(args.model_name_or_path)
model = ViTForImageClassification(config)
@@ -75,11 +79,7 @@
if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin()
elif args.plugin == 'gemini':
plugin = GeminiPlugin(device=get_current_device(),
placement_policy='cpu',
pin_memory=True,
strict_ddp_mode=True,
initial_scale=2**5)
plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5)
logger.info(f"Set plugin as {args.plugin}", ranks=[0])
@@ -90,16 +90,15 @@
# Set booster
booster = Booster(plugin=plugin, **booster_kwargs)
model, optimizer, _, _, _ = booster.boost(model, optimizer)


# Start training.
logger.info(f"Start testing", ranks=[0])
progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master())

torch.cuda.synchronize()
model.train()
start_time = time.time()

for _ in range(args.max_train_steps):

pixel_values, labels = get_data(args.batch_size, args.num_labels, 3, 224, 224)
@@ -111,18 +110,19 @@

torch.cuda.synchronize()
progress_bar.update(1)
# Compute Statistics

# Compute Statistics
end_time = time.time()
throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time))
max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True)

logger.info(f"Testing finished, "
f"batch size per gpu: {args.batch_size}, "
f"plugin: {args.plugin}, "
f"throughput: {throughput}, "
f"maximum memory usage per gpu: {max_mem}.",
ranks=[0])

logger.info(
f"Testing finished, "
f"batch size per gpu: {args.batch_size}, "
f"plugin: {args.plugin}, "
f"throughput: {throughput}, "
f"maximum memory usage per gpu: {max_mem}.",
ranks=[0])


if __name__ == "__main__":
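The throughput the benchmark reports is simply samples processed per wall-clock second summed over all ranks: `world_size * max_train_steps * batch_size / (end_time - start_time)`. A tiny worked example with hypothetical numbers:

```python
# Hypothetical values; the script measures the elapsed time around its training loop.
world_size, max_train_steps, batch_size = 8, 100, 32
elapsed = 50.0                                              # seconds
throughput = world_size * max_train_steps * batch_size / elapsed
print("{:.4f}".format(throughput))                          # 512.0000 images/s
```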