From 80486c15aa421948996bfd2cdc32052dce0fd8bd Mon Sep 17 00:00:00 2001 From: Tim Hinderliter Date: Mon, 5 Dec 2022 23:00:12 -0800 Subject: [PATCH 1/4] dreambooth: fix #1566: maintain fp32 wrapper when saving a checkpoint to avoid crash when running fp16 --- examples/dreambooth/train_dreambooth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index c76ff7c632e2..6e92e4f4c70a 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -692,8 +692,8 @@ def main(args): if accelerator.is_main_process: pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, - unet=accelerator.unwrap_model(unet), - text_encoder=accelerator.unwrap_model(text_encoder), + unet=accelerator.unwrap_model(unet, True), + text_encoder=accelerator.unwrap_model(text_encoder, True), revision=args.revision, ) save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") From ebfd8bf24002f20b5c7a92f765905f9397b26b90 Mon Sep 17 00:00:00 2001 From: Tim Hinderliter Date: Wed, 7 Dec 2022 23:50:35 -0800 Subject: [PATCH 2/4] dreambooth: guard against passing keep_fp32_wrapper arg to older versions of accelerate. part of fix for #1566 --- examples/dreambooth/train_dreambooth.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 6e92e4f4c70a..12316c74ab49 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -1,5 +1,6 @@ import argparse import hashlib +import inspect import itertools import math import os @@ -690,10 +691,18 @@ def main(args): if global_step % args.save_steps == 0: if accelerator.is_main_process: + # newer versions of accelerate allow the 'keep_fp32_wrapper' arg. without passing + # it, the models will be unwrapped, and when they are then used for further training, + # we will crash. pass this, but only to newer versions of accelerate. fixes + # https://github.com/huggingface/diffusers/issues/1566 + accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set( + inspect.signature(accelerator.unwrap_model).parameters.keys() + ) + extra_args = {"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {} pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, - unet=accelerator.unwrap_model(unet, True), - text_encoder=accelerator.unwrap_model(text_encoder, True), + unet=accelerator.unwrap_model(unet, **extra_args), + text_encoder=accelerator.unwrap_model(text_encoder, **extra_args), revision=args.revision, ) save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") From 7d039df2b1d5261b29539de8e3a56cbcb0b5e28c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 10 Dec 2022 15:37:23 +0100 Subject: [PATCH 3/4] Apply suggestions from code review Co-authored-by: Pedro Cuenca --- examples/dreambooth/train_dreambooth.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 12316c74ab49..edc6717e3c15 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -691,10 +691,10 @@ def main(args): if global_step % args.save_steps == 0: if accelerator.is_main_process: - # newer versions of accelerate allow the 'keep_fp32_wrapper' arg. without passing - # it, the models will be unwrapped, and when they are then used for further training, - # we will crash. pass this, but only to newer versions of accelerate. fixes - # https://github.com/huggingface/diffusers/issues/1566 + # When 'keep_fp32_wrapper' is `False` (the default), then the models are + # unwrapped and the mixed precision hooks are removed, so training crashes + # when the unwrapped models are used for further training. + # This is only supported in newer versions of `accelerate`. accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set( inspect.signature(accelerator.unwrap_model).parameters.keys() ) From 94bd34d23fde07cd8527c6728c9521f21113a883 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 10 Dec 2022 15:38:41 +0100 Subject: [PATCH 4/4] Update examples/dreambooth/train_dreambooth.py --- examples/dreambooth/train_dreambooth.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index edc6717e3c15..f55e63442f45 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -695,6 +695,7 @@ def main(args): # unwrapped and the mixed precision hooks are removed, so training crashes # when the unwrapped models are used for further training. # This is only supported in newer versions of `accelerate`. + # TODO(Pedro, Suraj): Remove `accepts_keep_fp32_wrapper` when forcing newer accelerate versions accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set( inspect.signature(accelerator.unwrap_model).parameters.keys() )