From 81ac95b70e52e9da552bf131786904307e9a626e Mon Sep 17 00:00:00 2001 From: Andreas Steiner Date: Thu, 6 Apr 2023 09:01:49 +0200 Subject: [PATCH 01/26] WIP controlnet training - bugfix --streaming - bugfix running report_to!='wandb' - adds memory profile before validation --- examples/controlnet/train_controlnet_flax.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 224a50bb7fbe..7d02290b5467 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -19,6 +19,7 @@ import os import random from pathlib import Path +import time import jax import jax.numpy as jnp @@ -77,6 +78,7 @@ def image_grid(imgs, rows, cols): def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_dtype): logger.info("Running validation... ") + jax.profiler.save_device_memory_profile(f"memory_{int(time.time())}.prof") pipeline, params = FlaxStableDiffusionControlNetPipeline.from_pretrained( args.pretrained_model_name_or_path, @@ -983,7 +985,7 @@ def cumul_grad_step(grad_idx, loss_grad_rng): logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}") logger.info(f" Total optimization steps = {args.num_train_epochs * num_update_steps_per_epoch}") - if jax.process_index() == 0: + if jax.process_index() == 0 and args.report_to == "wandb": wandb.define_metric("*", step_metric="train/step") wandb.config.update( { @@ -1008,7 +1010,7 @@ def cumul_grad_step(grad_idx, loss_grad_rng): steps_per_epoch = ( args.max_train_samples // total_train_batch_size - if args.streaming + if args.streaming or args.max_train_samples else len(train_dataset) // total_train_batch_size ) train_step_progress_bar = tqdm( From cd7c0d2a984350894cf21b41abc9262522785860 Mon Sep 17 00:00:00 2001 From: Andreas Steiner Date: Fri, 7 Apr 2023 17:09:18 +0000 Subject: [PATCH 02/26] Adds final logging statement. --- examples/controlnet/train_controlnet_flax.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 7d02290b5467..600d12b1e2cc 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -1086,6 +1086,8 @@ def cumul_grad_step(grad_idx, loss_grad_rng): ignore_patterns=["step_*", "epoch_*"], ) + logger.info("Finished training.") + if __name__ == "__main__": main() From c15af708531fd955fe4b3cc5637e5300af0f38d4 Mon Sep 17 00:00:00 2001 From: Andreas Steiner Date: Sat, 8 Apr 2023 11:47:28 +0000 Subject: [PATCH 03/26] Sets train epochs to 11. Looking at a longer ~16ep run, we see only good validation images after ~11ep: https://wandb.ai/andsteing/controlnet_fill50k/runs/3j2hx6n8 --- examples/controlnet/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/controlnet/README.md b/examples/controlnet/README.md index 4b388d92a195..b9229b44c92f 100644 --- a/examples/controlnet/README.md +++ b/examples/controlnet/README.md @@ -363,7 +363,7 @@ python3 train_controlnet_flax.py \ --revision="non-ema" \ --from_pt \ --report_to="wandb" \ - --max_train_steps=10000 \ + --num_train_epochs=11 \ --push_to_hub \ --hub_model_id=$HUB_MODEL_ID ``` From f33869653c2a6a7db7dd4ef0d9023088ff7c75bd Mon Sep 17 00:00:00 2001 From: Andreas Steiner Date: Sat, 8 Apr 2023 11:50:29 +0000 Subject: [PATCH 04/26] Removes --logging_dir (it's not used). --- examples/controlnet/train_controlnet_flax.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 600d12b1e2cc..c186f5b94461 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -319,15 +319,6 @@ def parse_args(): default=None, help="The name of the repository to keep in sync with the local `output_dir`.", ) - parser.add_argument( - "--logging_dir", - type=str, - default="logs", - help=( - "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." - ), - ) parser.add_argument( "--logging_steps", type=int, From 26ceff2791f8a50d800e75da699cc13c23112b54 Mon Sep 17 00:00:00 2001 From: Andreas Steiner Date: Sat, 8 Apr 2023 11:53:34 +0000 Subject: [PATCH 05/26] Adds --profile flags. --- examples/controlnet/train_controlnet_flax.py | 34 ++++++++++++++++++-- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index c186f5b94461..8f4bab114d85 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -78,7 +78,6 @@ def image_grid(imgs, rows, cols): def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_dtype): logger.info("Running validation... ") - jax.profiler.save_device_memory_profile(f"memory_{int(time.time())}.prof") pipeline, params = FlaxStableDiffusionControlNetPipeline.from_pretrained( args.pretrained_model_name_or_path, @@ -216,6 +215,22 @@ def parse_args(): action="store_true", help="Load the pretrained model from a PyTorch checkpoint.", ) + parser.add_argument( + "--profile_steps", + type=int, + default=0, + help="How many training steps to profile in the beginning.", + ) + parser.add_argument( + "--profile_validation", + action="store_true", + help="Whether to profile the (last) validation.", + ) + parser.add_argument( + "--profile_memory", + action="store_true", + help="Whether to dump an initial (before training loop) and a final (at program end) memory profile.", + ) parser.add_argument( "--controlnet_revision", type=str, @@ -994,6 +1009,8 @@ def cumul_grad_step(grad_idx, loss_grad_rng): position=0, disable=jax.process_index() > 0, ) + if args.profile_memory: + jax.profiler.save_device_memory_profile(os.path.join(args.output_dir, "memory_initial.prof")) for epoch in epochs: # ======================== Training ================================ @@ -1012,7 +1029,12 @@ def cumul_grad_step(grad_idx, loss_grad_rng): disable=jax.process_index() > 0, ) # train - for batch in train_dataloader: + for step, batch in enumerate(train_dataloader): + if args.profile_steps and step == 0: + jax.profiler.start_trace(args.output_dir) + elif args.profile_steps == step: + jax.profiler.stop_trace() + batch = shard(batch) state, train_metric, train_rngs = p_train_step( state, unet_params, text_encoder_params, vae_params, batch, train_rngs @@ -1051,10 +1073,14 @@ def cumul_grad_step(grad_idx, loss_grad_rng): train_step_progress_bar.close() epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})") - # Create the pipeline using using the trained modules and save it. + # Final validation & store model. if jax.process_index() == 0: if args.validation_prompt is not None: + if args.profile_validation: + jax.profiler.start_trace(args.output_dir) image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype) + if args.profile_validation: + jax.profiler.stop_trace() else: image_logs = None @@ -1077,6 +1103,8 @@ def cumul_grad_step(grad_idx, loss_grad_rng): ignore_patterns=["step_*", "epoch_*"], ) + if args.profile_memory: + jax.profiler.save_device_memory_profile(os.path.join(args.output_dir, "memory_final.prof")) logger.info("Finished training.") From 56d777cc352b3d3197550e87157649a43324ba28 Mon Sep 17 00:00:00 2001 From: Andreas Steiner Date: Sat, 8 Apr 2023 11:56:11 +0000 Subject: [PATCH 06/26] Updates --output_dir=runs/fill-circle-{timestamp}. --- examples/controlnet/README.md | 4 ++-- examples/controlnet/train_controlnet_flax.py | 9 +++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/controlnet/README.md b/examples/controlnet/README.md index b9229b44c92f..1eacdcccdbae 100644 --- a/examples/controlnet/README.md +++ b/examples/controlnet/README.md @@ -343,8 +343,8 @@ Make sure you have the `MODEL_DIR`,`OUTPUT_DIR` and `HUB_MODEL_ID` environment v ```bash export MODEL_DIR="runwayml/stable-diffusion-v1-5" -export OUTPUT_DIR="control_out" -export HUB_MODEL_ID="fill-circle-controlnet" +export OUTPUT_DIR="runs/fill-circle-{timestamp}" +export HUB_MODEL_ID="controlnet-fill-circle" ``` And finally start the training diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 8f4bab114d85..218ac5a4d966 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -251,8 +251,9 @@ def parse_args(): parser.add_argument( "--output_dir", type=str, - default="controlnet-model", - help="The output directory where the model predictions and checkpoints will be written.", + default="runs/{timestamp}", + help="The output directory where the model predictions and checkpoints will be written. " + "Can contain placeholders: {timestamp}.", ) parser.add_argument( "--cache_dir", @@ -467,6 +468,10 @@ def parse_args(): parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") args = parser.parse_args() + args.output_dir = args.output_dir.replace( + '{timestamp}', time.strftime('%Y%m%d_%H%M%S') + ) + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank From 34a63f521c3e617cd0026da843a68ca0942bbc43 Mon Sep 17 00:00:00 2001 From: Andreas Steiner Date: Sat, 8 Apr 2023 11:59:01 +0000 Subject: [PATCH 07/26] Compute mean of `train_metrics`. Previously `train_metrics[-1]` was logged, resulting in very bumpy train metrics. --- examples/controlnet/train_controlnet_flax.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 218ac5a4d966..e25583e2032d 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -1061,13 +1061,16 @@ def cumul_grad_step(grad_idx, loss_grad_rng): if global_step % args.logging_steps == 0 and jax.process_index() == 0: if args.report_to == "wandb": + train_metrics = jax_utils.unreplicate(train_metrics) + train_metrics = jax.tree_util.tree_map(lambda *m: jnp.array(m).mean(), *train_metrics) wandb.log( { "train/step": global_step, "train/epoch": epoch, - "train/loss": jax_utils.unreplicate(train_metric)["loss"], + **{f"train/{k}": v for k, v in train_metrics.items()}, } ) + train_metrics = [] if global_step % args.checkpointing_steps == 0 and jax.process_index() == 0: controlnet.save_pretrained( f"{args.output_dir}/{global_step}", From f02e482d9ff5506371e34e4d64cdba3dde61c20a Mon Sep 17 00:00:00 2001 From: Andreas Steiner Date: Sat, 8 Apr 2023 12:09:59 +0000 Subject: [PATCH 08/26] Improves logging a bit. - adds l2_grads gradient norm logging - adds steps_per_sec - sets walltime as x coordinate of train/step - logs controlnet_params config --- examples/controlnet/train_controlnet_flax.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index e25583e2032d..d764d9e45a78 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -964,6 +964,9 @@ def cumul_grad_step(grad_idx, loss_grad_rng): metrics = {"loss": loss} metrics = jax.lax.pmean(metrics, axis_name="batch") + l2 = lambda xs: jnp.sqrt(sum([ + jnp.vdot(x, x) for x in jax.tree_util.tree_leaves(xs)])) + metrics["l2_grads"] = l2(jax.tree_util.tree_leaves(grad)) return new_state, metrics, new_train_rng @@ -998,12 +1001,14 @@ def cumul_grad_step(grad_idx, loss_grad_rng): if jax.process_index() == 0 and args.report_to == "wandb": wandb.define_metric("*", step_metric="train/step") + wandb.define_metric("train/step", step_metric="walltime") wandb.config.update( { "num_train_examples": args.max_train_samples if args.streaming else len(train_dataset), "total_train_batch_size": total_train_batch_size, "total_optimization_step": args.num_train_epochs * num_update_steps_per_epoch, "num_devices": jax.device_count(), + "controlnet_params": sum(np.prod(x.shape) for x in jax.tree_util.tree_leaves(state.params)), } ) @@ -1034,6 +1039,7 @@ def cumul_grad_step(grad_idx, loss_grad_rng): disable=jax.process_index() > 0, ) # train + t0, step0 = time.monotonic(), -1 for step, batch in enumerate(train_dataloader): if args.profile_steps and step == 0: jax.profiler.start_trace(args.output_dir) @@ -1067,9 +1073,11 @@ def cumul_grad_step(grad_idx, loss_grad_rng): { "train/step": global_step, "train/epoch": epoch, + "train/steps_per_sec": (step - step0) / (time.monotonic() - t0), **{f"train/{k}": v for k, v in train_metrics.items()}, } ) + t0, step0 = time.monotonic(), step train_metrics = [] if global_step % args.checkpointing_steps == 0 and jax.process_index() == 0: controlnet.save_pretrained( From f7a2f28575d6fb75d8e6aea58664e6383660cefb Mon Sep 17 00:00:00 2001 From: Andreas Steiner Date: Sat, 8 Apr 2023 12:10:14 +0000 Subject: [PATCH 09/26] Adds --ccache (doesn't really help though). --- examples/controlnet/train_controlnet_flax.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index d764d9e45a78..d46d43a0a72f 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -22,6 +22,7 @@ import time import jax +import jax.experimental.compilation_cache.compilation_cache as cc import jax.numpy as jnp import numpy as np import optax @@ -231,6 +232,12 @@ def parse_args(): action="store_true", help="Whether to dump an initial (before training loop) and a final (at program end) memory profile.", ) + parser.add_argument( + "--ccache", + type=str, + default=None, + help="Enables compilation cache.", + ) parser.add_argument( "--controlnet_revision", type=str, @@ -1012,6 +1019,9 @@ def cumul_grad_step(grad_idx, loss_grad_rng): } ) + if args.ccache: + cc.initialize_cache(args.ccache) + global_step = 0 epochs = tqdm( range(args.num_train_epochs), From 95d5b1853f16c12df748e0f1f50ddeb06fc0cacd Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Thu, 6 Apr 2023 10:27:41 -1000 Subject: [PATCH 10/26] minor fix in controlnet flax example (#2986) * fix the error when push_to_hub but not log validation * contronet_from_pt & controlnet_revision * add intermediate checkpointing to the guide --- examples/controlnet/README.md | 1 + examples/controlnet/train_controlnet_flax.py | 11 +++++++++++ 2 files changed, 12 insertions(+) diff --git a/examples/controlnet/README.md b/examples/controlnet/README.md index 1eacdcccdbae..901b4f094a5a 100644 --- a/examples/controlnet/README.md +++ b/examples/controlnet/README.md @@ -326,6 +326,7 @@ If you want to use Weights and Biases logging, you should also install `wandb` n pip install wandb ``` + Now let's downloading two conditioning images that we will use to run validation during the training in order to track our progress ``` diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index d46d43a0a72f..831b75e496d8 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -216,6 +216,17 @@ def parse_args(): action="store_true", help="Load the pretrained model from a PyTorch checkpoint.", ) + parser.add_argument( + "--controlnet_revision", + type=str, + default=None, + help="Revision of controlnet model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--controlnet_from_pt", + action="store_true", + help="Load the controlnet model from a PyTorch checkpoint.", + ) parser.add_argument( "--profile_steps", type=int, From 0874a9d32940581ca019166f9935181836b82e82 Mon Sep 17 00:00:00 2001 From: Andreas Steiner Date: Sat, 8 Apr 2023 12:50:57 +0000 Subject: [PATCH 11/26] Bugfix --profile_steps --- examples/controlnet/train_controlnet_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 831b75e496d8..9b8f334b6c2e 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -1064,7 +1064,7 @@ def cumul_grad_step(grad_idx, loss_grad_rng): for step, batch in enumerate(train_dataloader): if args.profile_steps and step == 0: jax.profiler.start_trace(args.output_dir) - elif args.profile_steps == step: + if args.profile_steps and args.profile_steps == step: jax.profiler.stop_trace() batch = shard(batch) From 1a74f34fb02ff778db29815b77323f945e612860 Mon Sep 17 00:00:00 2001 From: Andreas Steiner Date: Sat, 8 Apr 2023 13:11:42 +0000 Subject: [PATCH 12/26] Sets `RACKER_PROJECT_NAME='controlnet_fill50k'`. --- examples/controlnet/README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/controlnet/README.md b/examples/controlnet/README.md index 901b4f094a5a..fbbdf7f45783 100644 --- a/examples/controlnet/README.md +++ b/examples/controlnet/README.md @@ -340,12 +340,13 @@ We encourage you to store or share your model with the community. To use hugging huggingface-cli login ``` -Make sure you have the `MODEL_DIR`,`OUTPUT_DIR` and `HUB_MODEL_ID` environment variables set. The `OUTPUT_DIR` and `HUB_MODEL_ID` variables specify where to save the model to on the Hub: +Make sure you have the `MODEL_DIR`,`OUTPUT_DIR` and `HUB_MODEL_ID` environment variables set. The `OUTPUT_DIR` and `HUB_MODEL_ID` variables specify where to save the model to on the Hub. The `TRACKER_PROJECT_NAME` variable sets the project name in `wandb`. ```bash export MODEL_DIR="runwayml/stable-diffusion-v1-5" export OUTPUT_DIR="runs/fill-circle-{timestamp}" export HUB_MODEL_ID="controlnet-fill-circle" +export TRACKER_PROJECT_NAME='controlnet_fill50k' ``` And finally start the training @@ -355,6 +356,7 @@ python3 train_controlnet_flax.py \ --pretrained_model_name_or_path=$MODEL_DIR \ --output_dir=$OUTPUT_DIR \ --dataset_name=fusing/fill50k \ + --tracker_project_name="$TRACKER_PROJECT_NAME" \ --resolution=512 \ --learning_rate=1e-5 \ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \ From 7f9a3c3fec5213173dc71d0b4c29fa97eb31db6c Mon Sep 17 00:00:00 2001 From: Andreas Steiner Date: Sat, 8 Apr 2023 13:20:17 +0000 Subject: [PATCH 13/26] Logs fractional epoch. --- examples/controlnet/train_controlnet_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 9b8f334b6c2e..21c5de76c869 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -1093,7 +1093,7 @@ def cumul_grad_step(grad_idx, loss_grad_rng): wandb.log( { "train/step": global_step, - "train/epoch": epoch, + "train/epoch": epoch + (step + 1) / dataset_length, "train/steps_per_sec": (step - step0) / (time.monotonic() - t0), **{f"train/{k}": v for k, v in train_metrics.items()}, } From 4d20a0405aadd9d927b79dda13d98d4c659c760d Mon Sep 17 00:00:00 2001 From: Andreas Steiner Date: Sat, 8 Apr 2023 17:49:49 +0000 Subject: [PATCH 14/26] Adds relative `walltime` metric. --- examples/controlnet/train_controlnet_flax.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 21c5de76c869..87f0925e7231 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -1042,6 +1042,7 @@ def cumul_grad_step(grad_idx, loss_grad_rng): ) if args.profile_memory: jax.profiler.save_device_memory_profile(os.path.join(args.output_dir, "memory_initial.prof")) + t00 = time.monotonic() for epoch in epochs: # ======================== Training ================================ @@ -1092,6 +1093,7 @@ def cumul_grad_step(grad_idx, loss_grad_rng): train_metrics = jax.tree_util.tree_map(lambda *m: jnp.array(m).mean(), *train_metrics) wandb.log( { + "walltime": time.monotonic() - t00, "train/step": global_step, "train/epoch": epoch + (step + 1) / dataset_length, "train/steps_per_sec": (step - step0) / (time.monotonic() - t0), From d1f59942561e377e6c4ca069c4180f77c9cacb52 Mon Sep 17 00:00:00 2001 From: Andreas Steiner Date: Sat, 8 Apr 2023 18:01:18 +0000 Subject: [PATCH 15/26] Adds `StepTraceAnnotation` and uses `global_step` insetad of `step`. --- examples/controlnet/train_controlnet_flax.py | 24 ++++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 87f0925e7231..0d2cdb356b33 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -1033,7 +1033,7 @@ def cumul_grad_step(grad_idx, loss_grad_rng): if args.ccache: cc.initialize_cache(args.ccache) - global_step = 0 + global_step = step0 = 0 epochs = tqdm( range(args.num_train_epochs), desc="Epoch ... ", @@ -1042,7 +1042,7 @@ def cumul_grad_step(grad_idx, loss_grad_rng): ) if args.profile_memory: jax.profiler.save_device_memory_profile(os.path.join(args.output_dir, "memory_initial.prof")) - t00 = time.monotonic() + t00 = t0 = time.monotonic() for epoch in epochs: # ======================== Training ================================ @@ -1061,17 +1061,17 @@ def cumul_grad_step(grad_idx, loss_grad_rng): disable=jax.process_index() > 0, ) # train - t0, step0 = time.monotonic(), -1 - for step, batch in enumerate(train_dataloader): - if args.profile_steps and step == 0: + for batch in train_dataloader: + if args.profile_steps and global_step == 0: jax.profiler.start_trace(args.output_dir) - if args.profile_steps and args.profile_steps == step: + if args.profile_steps and args.profile_steps == global_step: jax.profiler.stop_trace() batch = shard(batch) - state, train_metric, train_rngs = p_train_step( - state, unet_params, text_encoder_params, vae_params, batch, train_rngs - ) + with jax.profiler.StepTraceAnnotation("train", step_num=global_step): + state, train_metric, train_rngs = p_train_step( + state, unet_params, text_encoder_params, vae_params, batch, train_rngs + ) train_metrics.append(train_metric) train_step_progress_bar.update(1) @@ -1095,12 +1095,12 @@ def cumul_grad_step(grad_idx, loss_grad_rng): { "walltime": time.monotonic() - t00, "train/step": global_step, - "train/epoch": epoch + (step + 1) / dataset_length, - "train/steps_per_sec": (step - step0) / (time.monotonic() - t0), + "train/epoch": global_step / dataset_length, + "train/steps_per_sec": (global_step - step0) / (time.monotonic() - t0), **{f"train/{k}": v for k, v in train_metrics.items()}, } ) - t0, step0 = time.monotonic(), step + t0, step0 = time.monotonic(), global_step train_metrics = [] if global_step % args.checkpointing_steps == 0 and jax.process_index() == 0: controlnet.save_pretrained( From 4d83f6fcb467eea5aa444fb22fa59c74aac8f392 Mon Sep 17 00:00:00 2001 From: Andreas Steiner Date: Tue, 11 Apr 2023 05:53:55 +0000 Subject: [PATCH 16/26] Applied `black`. --- examples/controlnet/train_controlnet_flax.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 0d2cdb356b33..90ed973d73de 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -18,8 +18,8 @@ import math import os import random -from pathlib import Path import time +from pathlib import Path import jax import jax.experimental.compilation_cache.compilation_cache as cc @@ -486,9 +486,7 @@ def parse_args(): parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") args = parser.parse_args() - args.output_dir = args.output_dir.replace( - '{timestamp}', time.strftime('%Y%m%d_%H%M%S') - ) + args.output_dir = args.output_dir.replace("{timestamp}", time.strftime("%Y%m%d_%H%M%S")) env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: @@ -982,8 +980,8 @@ def cumul_grad_step(grad_idx, loss_grad_rng): metrics = {"loss": loss} metrics = jax.lax.pmean(metrics, axis_name="batch") - l2 = lambda xs: jnp.sqrt(sum([ - jnp.vdot(x, x) for x in jax.tree_util.tree_leaves(xs)])) + def l2(xs): + return jnp.sqrt(sum([jnp.vdot(x, x) for x in jax.tree_util.tree_leaves(xs)])) metrics["l2_grads"] = l2(jax.tree_util.tree_leaves(grad)) return new_state, metrics, new_train_rng From 4fed793f114b070c1d75302348bf13b1950275f4 Mon Sep 17 00:00:00 2001 From: Andreas Steiner Date: Tue, 11 Apr 2023 06:15:27 +0000 Subject: [PATCH 17/26] Streamlines commands in README a bit. --- examples/controlnet/README.md | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/examples/controlnet/README.md b/examples/controlnet/README.md index fbbdf7f45783..5fc356e94833 100644 --- a/examples/controlnet/README.md +++ b/examples/controlnet/README.md @@ -340,13 +340,12 @@ We encourage you to store or share your model with the community. To use hugging huggingface-cli login ``` -Make sure you have the `MODEL_DIR`,`OUTPUT_DIR` and `HUB_MODEL_ID` environment variables set. The `OUTPUT_DIR` and `HUB_MODEL_ID` variables specify where to save the model to on the Hub. The `TRACKER_PROJECT_NAME` variable sets the project name in `wandb`. +Make sure you have the `MODEL_DIR`,`OUTPUT_DIR` and `HUB_MODEL_ID` environment variables set. The `OUTPUT_DIR` and `HUB_MODEL_ID` variables specify where to save the model to on the Hub: ```bash export MODEL_DIR="runwayml/stable-diffusion-v1-5" export OUTPUT_DIR="runs/fill-circle-{timestamp}" export HUB_MODEL_ID="controlnet-fill-circle" -export TRACKER_PROJECT_NAME='controlnet_fill50k' ``` And finally start the training @@ -356,7 +355,6 @@ python3 train_controlnet_flax.py \ --pretrained_model_name_or_path=$MODEL_DIR \ --output_dir=$OUTPUT_DIR \ --dataset_name=fusing/fill50k \ - --tracker_project_name="$TRACKER_PROJECT_NAME" \ --resolution=512 \ --learning_rate=1e-5 \ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \ @@ -366,6 +364,7 @@ python3 train_controlnet_flax.py \ --revision="non-ema" \ --from_pt \ --report_to="wandb" \ + --tracker_project_name=$HUB_MODEL_ID \ --num_train_epochs=11 \ --push_to_hub \ --hub_model_id=$HUB_MODEL_ID @@ -373,9 +372,13 @@ python3 train_controlnet_flax.py \ Since we passed the `--push_to_hub` flag, it will automatically create a model repo under your huggingface account based on `$HUB_MODEL_ID`. By the end of training, the final checkpoint will be automatically stored on the hub. You can find an example model repo [here](https://huggingface.co/YiYiXu/fill-circle-controlnet). -Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. Here is an example command: +Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. Here is an example command (from [this Blog article](https://huggingface.co/blog/train-your-controlnet)): ```bash +export MODEL_DIR="runwayml/stable-diffusion-v1-5" +export OUTPUT_DIR="runs/uncanny-faces-{timestamp}" +export HUB_MODEL_ID="controlnet-uncanny-faces" + python3 train_controlnet_flax.py \ --pretrained_model_name_or_path=$MODEL_DIR \ --output_dir=$OUTPUT_DIR \ @@ -385,13 +388,12 @@ python3 train_controlnet_flax.py \ --image_column=image \ --caption_column=image_caption \ --resolution=512 \ - --max_train_samples 50 \ - --max_train_steps 5 \ + --max_train_samples 100000 \ --learning_rate=1e-5 \ - --validation_steps=2 \ --train_batch_size=1 \ --revision="flax" \ - --report_to="wandb" + --report_to="wandb" \ + --tracker_project_name=$HUB_MODEL_ID ``` Note, however, that the performance of the TPUs might get bottlenecked as streaming with `datasets` is not optimized for images. For ensuring maximum throughput, we encourage you to explore the following options: From 92b7befe639f1ad460f300dbb4001e04bd2cd7a7 Mon Sep 17 00:00:00 2001 From: Andreas Steiner Date: Tue, 11 Apr 2023 06:28:28 +0000 Subject: [PATCH 18/26] Removes `--ccache`. This makes only a very small difference (~1 min) with this model size, so removing the option introduced in cdb3cc. --- examples/controlnet/train_controlnet_flax.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 90ed973d73de..163d2014629c 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -22,7 +22,6 @@ from pathlib import Path import jax -import jax.experimental.compilation_cache.compilation_cache as cc import jax.numpy as jnp import numpy as np import optax @@ -1028,9 +1027,6 @@ def l2(xs): } ) - if args.ccache: - cc.initialize_cache(args.ccache) - global_step = step0 = 0 epochs = tqdm( range(args.num_train_epochs), From a827b34e408361d05cc85bcf6fba922c14235d6e Mon Sep 17 00:00:00 2001 From: Andreas Steiner Date: Tue, 11 Apr 2023 07:09:03 +0000 Subject: [PATCH 19/26] Re-ran `black`. --- examples/controlnet/train_controlnet_flax.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 163d2014629c..8bf87a098ae5 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -979,8 +979,10 @@ def cumul_grad_step(grad_idx, loss_grad_rng): metrics = {"loss": loss} metrics = jax.lax.pmean(metrics, axis_name="batch") + def l2(xs): return jnp.sqrt(sum([jnp.vdot(x, x) for x in jax.tree_util.tree_leaves(xs)])) + metrics["l2_grads"] = l2(jax.tree_util.tree_leaves(grad)) return new_state, metrics, new_train_rng From cb410f665074ec02c91c3cdbc4937d177128d9aa Mon Sep 17 00:00:00 2001 From: Andreas Steiner Date: Tue, 11 Apr 2023 14:20:33 +0200 Subject: [PATCH 20/26] Update examples/controlnet/README.md Co-authored-by: Sayak Paul --- examples/controlnet/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/controlnet/README.md b/examples/controlnet/README.md index 5fc356e94833..1abdec906bbc 100644 --- a/examples/controlnet/README.md +++ b/examples/controlnet/README.md @@ -372,7 +372,7 @@ python3 train_controlnet_flax.py \ Since we passed the `--push_to_hub` flag, it will automatically create a model repo under your huggingface account based on `$HUB_MODEL_ID`. By the end of training, the final checkpoint will be automatically stored on the hub. You can find an example model repo [here](https://huggingface.co/YiYiXu/fill-circle-controlnet). -Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. Here is an example command (from [this Blog article](https://huggingface.co/blog/train-your-controlnet)): +Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. Here is an example command (from [this blog article](https://huggingface.co/blog/train-your-controlnet)): ```bash export MODEL_DIR="runwayml/stable-diffusion-v1-5" From 00a394eecdb40b735ea4e74b1d5b73e9ef01ea58 Mon Sep 17 00:00:00 2001 From: Andreas Steiner Date: Tue, 11 Apr 2023 12:25:53 +0000 Subject: [PATCH 21/26] Converts spaces to tab. --- examples/controlnet/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/controlnet/README.md b/examples/controlnet/README.md index 1abdec906bbc..3baac71a5a24 100644 --- a/examples/controlnet/README.md +++ b/examples/controlnet/README.md @@ -393,7 +393,7 @@ python3 train_controlnet_flax.py \ --train_batch_size=1 \ --revision="flax" \ --report_to="wandb" \ - --tracker_project_name=$HUB_MODEL_ID + --tracker_project_name=$HUB_MODEL_ID ``` Note, however, that the performance of the TPUs might get bottlenecked as streaming with `datasets` is not optimized for images. For ensuring maximum throughput, we encourage you to explore the following options: From 6dc0c680edc05ef5508d0b680bb7cfd0ffd10ce4 Mon Sep 17 00:00:00 2001 From: Andreas Steiner Date: Wed, 12 Apr 2023 07:01:24 +0000 Subject: [PATCH 22/26] Removes repeated args. --- examples/controlnet/train_controlnet_flax.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 8bf87a098ae5..6fcc9a728194 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -221,11 +221,6 @@ def parse_args(): default=None, help="Revision of controlnet model identifier from huggingface.co/models.", ) - parser.add_argument( - "--controlnet_from_pt", - action="store_true", - help="Load the controlnet model from a PyTorch checkpoint.", - ) parser.add_argument( "--profile_steps", type=int, @@ -248,12 +243,6 @@ def parse_args(): default=None, help="Enables compilation cache.", ) - parser.add_argument( - "--controlnet_revision", - type=str, - default=None, - help="Revision of controlnet model identifier from huggingface.co/models.", - ) parser.add_argument( "--controlnet_from_pt", action="store_true", From f4af068102bca848b048bbd1d6a0bebe5dd6a18b Mon Sep 17 00:00:00 2001 From: Andreas Steiner Date: Wed, 12 Apr 2023 07:03:49 +0000 Subject: [PATCH 23/26] Skips first step (compilation) in profiling --- examples/controlnet/train_controlnet_flax.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 6fcc9a728194..fecc86f8784a 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -1047,9 +1047,12 @@ def l2(xs): ) # train for batch in train_dataloader: - if args.profile_steps and global_step == 0: + + if args.profile_steps and global_step == 1: + train_metric['loss'].block_until_ready() jax.profiler.start_trace(args.output_dir) - if args.profile_steps and args.profile_steps == global_step: + if args.profile_steps and global_step == 1 + args.profile_steps: + train_metric['loss'].block_until_ready() jax.profiler.stop_trace() batch = shard(batch) From cd084f5ea2ae5bb42cb9d2a309d3e305efd8e60c Mon Sep 17 00:00:00 2001 From: Andreas Steiner Date: Wed, 12 Apr 2023 11:54:52 +0000 Subject: [PATCH 24/26] Updates README with profiling instructions. --- examples/controlnet/README.md | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/examples/controlnet/README.md b/examples/controlnet/README.md index 3baac71a5a24..8271f8697f41 100644 --- a/examples/controlnet/README.md +++ b/examples/controlnet/README.md @@ -417,4 +417,23 @@ You can then start your training from this saved checkpoint with We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence by rebalancing the loss. To use it, one needs to set the `--snr_gamma` argument. The recommended value when using it is `5.0`. -We also support gradient accumulation - it is a technique that lets you use a bigger batch size than your machine would normally be able to fit into memory. You can use `gradient_accumulation_steps` argument to set gradient accumulation steps. The ControlNet author recommends using gradient accumulation to achieve better convergence. Read more [here](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md#more-consideration-sudden-converge-phenomenon-and-gradient-accumulation). \ No newline at end of file +We also support gradient accumulation - it is a technique that lets you use a bigger batch size than your machine would normally be able to fit into memory. You can use `gradient_accumulation_steps` argument to set gradient accumulation steps. The ControlNet author recommends using gradient accumulation to achieve better convergence. Read more [here](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md#more-consideration-sudden-converge-phenomenon-and-gradient-accumulation). + +You can **profile your code** with: + +```bash + --profile_steps==5 +``` + +Refer to the [JAX documentation on profiling](https://jax.readthedocs.io/en/latest/profiling.html). To inspect the profile trace, you'll have to install and start Tensorboard with the profile plugin: + +```bash +pip install tensorflow tensorboard-plugin-profile +tensorboard --logdir runs/fill-circle-100steps-20230411_165612/ +``` + +The profile can then be inspected at http://localhost:6006/#profile + +Sometimes you'll get version conflicts (error messages like `Duplicate plugins for name projector`), which means that you have to uninstall and reinstall all versions of Tensorflow/Tensorboard (e.g. with `pip uninstall tensorflow tf-nightly tensorboard tb-nightly tensorboard-plugin-profile && pip install tf-nightly tbp-nightly tensorboard-plugin-profile`). + +Note that the debugging functionality of the Tensorboard `profile` plugin is still under active development. Not all views are fully functional, and for example the `trace_viewer` cuts off events after 1M (which can result in all your device traces getting lost if you for example profile the compilation step by accident). From 5ef3b5c7392c6729a5073357e25dea68a08106c2 Mon Sep 17 00:00:00 2001 From: Andreas Steiner Date: Wed, 12 Apr 2023 12:05:15 +0000 Subject: [PATCH 25/26] Unifies tabs/spaces in README. --- examples/controlnet/README.md | 40 +++++++++++++++++------------------ 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/examples/controlnet/README.md b/examples/controlnet/README.md index 8271f8697f41..387755624729 100644 --- a/examples/controlnet/README.md +++ b/examples/controlnet/README.md @@ -284,9 +284,9 @@ TPU_TYPE=v4-8 VM_NAME=hg_flax gcloud alpha compute tpus tpu-vm create $VM_NAME \ - --zone $ZONE \ - --accelerator-type $TPU_TYPE \ - --version tpu-vm-v4-base + --zone $ZONE \ + --accelerator-type $TPU_TYPE \ + --version tpu-vm-v4-base gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE -- \ ``` @@ -380,20 +380,20 @@ export OUTPUT_DIR="runs/uncanny-faces-{timestamp}" export HUB_MODEL_ID="controlnet-uncanny-faces" python3 train_controlnet_flax.py \ - --pretrained_model_name_or_path=$MODEL_DIR \ - --output_dir=$OUTPUT_DIR \ - --dataset_name=multimodalart/facesyntheticsspigacaptioned \ - --streaming \ - --conditioning_image_column=spiga_seg \ - --image_column=image \ - --caption_column=image_caption \ - --resolution=512 \ - --max_train_samples 100000 \ - --learning_rate=1e-5 \ - --train_batch_size=1 \ - --revision="flax" \ - --report_to="wandb" \ - --tracker_project_name=$HUB_MODEL_ID + --pretrained_model_name_or_path=$MODEL_DIR \ + --output_dir=$OUTPUT_DIR \ + --dataset_name=multimodalart/facesyntheticsspigacaptioned \ + --streaming \ + --conditioning_image_column=spiga_seg \ + --image_column=image \ + --caption_column=image_caption \ + --resolution=512 \ + --max_train_samples 100000 \ + --learning_rate=1e-5 \ + --train_batch_size=1 \ + --revision="flax" \ + --report_to="wandb" \ + --tracker_project_name=$HUB_MODEL_ID ``` Note, however, that the performance of the TPUs might get bottlenecked as streaming with `datasets` is not optimized for images. For ensuring maximum throughput, we encourage you to explore the following options: @@ -405,14 +405,14 @@ Note, however, that the performance of the TPUs might get bottlenecked as stream When work with a larger dataset, you may need to run training process for a long time and it’s useful to save regular checkpoints during the process. You can use the following argument to enable intermediate checkpointing: ```bash - --checkpointing_steps=500 + --checkpointing_steps=500 ``` This will save the trained model in subfolders of your output_dir. Subfolder names is the number of steps performed so far; for example: a checkpoint saved after 500 training steps would be saved in a subfolder named 500 You can then start your training from this saved checkpoint with ```bash - --controlnet_model_name_or_path="./control_out/500" + --controlnet_model_name_or_path="./control_out/500" ``` We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence by rebalancing the loss. To use it, one needs to set the `--snr_gamma` argument. The recommended value when using it is `5.0`. @@ -422,7 +422,7 @@ We also support gradient accumulation - it is a technique that lets you use a bi You can **profile your code** with: ```bash - --profile_steps==5 + --profile_steps==5 ``` Refer to the [JAX documentation on profiling](https://jax.readthedocs.io/en/latest/profiling.html). To inspect the profile trace, you'll have to install and start Tensorboard with the profile plugin: From 94c6faf24f171cc1ab8d4aba727654b0f69ca797 Mon Sep 17 00:00:00 2001 From: Andreas Steiner Date: Wed, 12 Apr 2023 15:29:42 +0000 Subject: [PATCH 26/26] Re-ran style & quality. --- examples/controlnet/train_controlnet_flax.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index fecc86f8784a..bef3e49ed007 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -1032,6 +1032,7 @@ def l2(xs): # ======================== Training ================================ train_metrics = [] + train_metric = None steps_per_epoch = ( args.max_train_samples // total_train_batch_size @@ -1047,12 +1048,11 @@ def l2(xs): ) # train for batch in train_dataloader: - if args.profile_steps and global_step == 1: - train_metric['loss'].block_until_ready() + train_metric["loss"].block_until_ready() jax.profiler.start_trace(args.output_dir) if args.profile_steps and global_step == 1 + args.profile_steps: - train_metric['loss'].block_until_ready() + train_metric["loss"].block_until_ready() jax.profiler.stop_trace() batch = shard(batch)