2 changes: 1 addition & 1 deletion examples/images/diffusion/README.md
````diff
@@ -52,7 +52,7 @@ You can also update an existing [latent diffusion](https://github.com/CompVis/la
 
 ```
 conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
-pip install transformers==4.19.2 diffusers invisible-watermark
+pip install transformers diffusers invisible-watermark
 ```
 
 #### Step 2: install lightning
````
2 changes: 1 addition & 1 deletion examples/images/diffusion/environment.yaml
```diff
@@ -18,7 +18,7 @@ dependencies:
     - test-tube>=0.7.5
     - streamlit==1.12.1
     - einops==0.3.0
-    - transformers==4.19.2
+    - transformers
     - webdataset==0.2.5
     - kornia==0.6
     - open_clip_torch==2.0.2
```
2 changes: 1 addition & 1 deletion examples/images/diffusion/requirements.txt
```diff
@@ -9,7 +9,7 @@ omegaconf==2.1.1
 test-tube>=0.7.5
 streamlit>=0.73.1
 einops==0.3.0
-transformers==4.19.2
+transformers
 webdataset==0.2.5
 open-clip-torch==2.7.0
 gradio==3.11
```
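The same edit lands in all three files above: the `transformers==4.19.2` pin is dropped so the resolver picks a current release. A quick runtime sanity check (not part of the diff) is to log which version the unpinned requirement actually resolved to:

```python
# Minimal sketch: confirm which transformers release the unpinned
# requirement resolved to at install time.
import transformers

print(f"transformers resolved to {transformers.__version__}")
```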
37 changes: 16 additions & 21 deletions examples/images/dreambooth/train_dreambooth_colossalai.py
```diff
@@ -10,7 +10,7 @@
 import torch.utils.checkpoint
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
-from huggingface_hub import HfFolder, Repository, whoami
+from huggingface_hub import HfFolder, Repository, create_repo, whoami
 from PIL import Image
 from torch.utils.data import Dataset
 from torchvision import transforms
```
```diff
@@ -133,9 +133,13 @@ def parse_args(input_args=None):
         default="cpu",
         help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
     )
-    parser.add_argument("--center_crop",
-                        action="store_true",
-                        help="Whether to center crop images before resizing to resolution")
+    parser.add_argument(
+        "--center_crop",
+        default=False,
+        action="store_true",
+        help=("Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+              " cropped. The images will be resized to the resolution first before cropping."),
+    )
     parser.add_argument("--train_batch_size",
                         type=int,
                         default=4,
```
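The rewritten `--center_crop` help text pins down the order of operations: images are resized to the target resolution first, then either center-cropped (flag set) or randomly cropped (flag unset). A torchvision sketch of that order, with illustrative values standing in for the script's `--resolution` and `--center_crop` arguments:

```python
from torchvision import transforms

resolution = 512      # stand-in for the script's --resolution argument
center_crop = False   # i.e. --center_crop not passed, so crops are random

image_transforms = transforms.Compose([
    # Resize first so the crop below always has enough pixels to work with.
    transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
    # Deterministic center crop when requested, random crop otherwise.
    transforms.CenterCrop(resolution) if center_crop else transforms.RandomCrop(resolution),
    transforms.ToTensor(),
])
```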
```diff
@@ -149,13 +153,6 @@ def parse_args(input_args=None):
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
-    parser.add_argument(
-        "--gradient_accumulation_steps",
-        type=int,
-        default=1,
-        help=
-        "Number of updates steps to accumulate before performing a backward/update pass. If using Gemini, it must be 1",
-    )
     parser.add_argument(
         "--gradient_checkpointing",
         action="store_true",
```
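The removed flag's own help text admitted it had only one legal value under Gemini ("it must be 1"), and an assert further down enforced it; deleting the argument removes the dead configuration surface. One side effect worth noting: command lines that still pass the flag now fail at parse time. A reduced argparse illustration (hypothetical two-flag parser, not the script's full one):

```python
import argparse

# Reduced parser after this PR: the accumulation flag no longer exists.
parser = argparse.ArgumentParser()
parser.add_argument("--train_batch_size", type=int, default=4)

# parse_known_args() shows what parse_args() would reject outright.
args, unknown = parser.parse_known_args(
    ["--train_batch_size", "4", "--gradient_accumulation_steps", "2"])
print(unknown)  # ['--gradient_accumulation_steps', '2'] -> "unrecognized arguments" error
```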
```diff
@@ -356,7 +353,6 @@ def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
 
 
 def main(args):
-
     if args.seed is None:
         colossalai.launch_from_torch(config={})
     else:
```
```diff
@@ -410,7 +406,8 @@ def main(args):
             repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
         else:
             repo_name = args.hub_model_id
-        repo = Repository(args.output_dir, clone_from=repo_name)
+        create_repo(repo_name, exist_ok=True, token=args.hub_token)
+        repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
 
         with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
             if "step_*" not in gitignore:
```
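`Repository(..., clone_from=...)` expects the remote repo to exist, so the PR now creates it explicitly first; `exist_ok=True` keeps reruns from failing, and the token is passed to `Repository` as well so the clone can authenticate. A minimal standalone sketch of the pattern (repo id and token are placeholders):

```python
from huggingface_hub import Repository, create_repo

repo_name = "your-user/dreambooth-output"   # placeholder repo id
hub_token = None                            # placeholder; a write token, or None for the cached login

# Create the Hub repo if needed; exist_ok=True makes reruns a no-op.
create_repo(repo_name, exist_ok=True, token=hub_token)

# The clone can now assume the remote exists.
repo = Repository("output_dir", clone_from=repo_name, token=hub_token)
```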
```diff
@@ -469,9 +466,8 @@ def main(args):
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
 
-    assert args.gradient_accumulation_steps == 1, "if using ColossalAI gradient_accumulation_steps must be set to 1."
     if args.scale_lr:
-        args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * world_size
+        args.learning_rate = args.learning_rate * args.train_batch_size * world_size
 
     unet = gemini_zero_dpp(unet, args.placement)
 
```
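With accumulation pinned to 1, the effective global batch is just the per-device batch times the world size, so the `--scale_lr` rule loses its accumulation factor. Worked through with illustrative numbers, a base LR of 5e-6, batch size 4, and 8 processes scale to 1.6e-4:

```python
# Worked example of the simplified --scale_lr rule (illustrative values).
learning_rate = 5e-6
train_batch_size = 4
world_size = 8

scaled_lr = learning_rate * train_batch_size * world_size
print(scaled_lr)  # 0.00016, i.e. 1.6e-4
```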
```diff
@@ -529,16 +525,16 @@ def collate_fn(examples):
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader))
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        num_warmup_steps=args.lr_warmup_steps,
+        num_training_steps=args.max_train_steps,
     )
     weight_dtype = torch.float32
     if args.mixed_precision == "fp16":
```
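One dataloader batch is now exactly one optimizer update, so the warmup and total step counts go to `get_scheduler` unscaled, and updates per epoch reduce to the dataloader length (the remaining `math.ceil` on an int is a harmless no-op). The step accounting, sketched with illustrative sizes:

```python
import math

# Illustrative sizes, not values from the PR.
num_batches = 200        # len(train_dataloader)
num_train_epochs = 2
max_train_steps = None   # --max_train_steps left unset

# One batch == one update, so ceil() of an int changes nothing.
num_update_steps_per_epoch = math.ceil(num_batches)
if max_train_steps is None:
    max_train_steps = num_train_epochs * num_update_steps_per_epoch

print(max_train_steps)  # 400
```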
```diff
@@ -553,22 +549,21 @@ def collate_fn(examples):
     text_encoder.to(get_current_device(), dtype=weight_dtype)
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader))
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     # Train!
-    total_batch_size = args.train_batch_size * world_size * args.gradient_accumulation_steps
+    total_batch_size = args.train_batch_size * world_size
 
     logger.info("***** Running training *****", ranks=[0])
     logger.info(f"  Num examples = {len(train_dataset)}", ranks=[0])
     logger.info(f"  Num batches each epoch = {len(train_dataloader)}", ranks=[0])
     logger.info(f"  Num Epochs = {args.num_train_epochs}", ranks=[0])
     logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}", ranks=[0])
     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
-    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0])
     logger.info(f"  Total optimization steps = {args.max_train_steps}", ranks=[0])
 
     # Only show the progress bar once on each machine.
```