From c3922b5cde73f9886eea267547df9ea8a06d8df8 Mon Sep 17 00:00:00 2001 From: Tim Hinderliter Date: Mon, 5 Dec 2022 23:00:12 -0800 Subject: [PATCH 1/2] dreambooth: fix #1566: maintain fp32 wrapper when saving a checkpoint to avoid crash when running fp16 --- examples/dreambooth/train_dreambooth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index b904920f1cd4..2265c41f6599 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -682,8 +682,8 @@ def main(args): if accelerator.is_main_process: pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, - unet=accelerator.unwrap_model(unet), - text_encoder=accelerator.unwrap_model(text_encoder), + unet=accelerator.unwrap_model(unet, True), + text_encoder=accelerator.unwrap_model(text_encoder, True), revision=args.revision, ) save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") From 053e6b7c1a228fbf456f193a7dc89fa66e6ebc56 Mon Sep 17 00:00:00 2001 From: Tim Hinderliter Date: Wed, 7 Dec 2022 23:50:35 -0800 Subject: [PATCH 2/2] dreambooth: guard against passing keep_fp32_wrapper arg to older versions of accelerate. part of fix for #1566 --- examples/dreambooth/train_dreambooth.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 2265c41f6599..5c35760a9370 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -1,5 +1,6 @@ import argparse import hashlib +import inspect import itertools import math import os @@ -680,10 +681,18 @@ def main(args): if global_step % args.save_steps == 0: if accelerator.is_main_process: + # newer versions of accelerate allow the 'keep_fp32_wrapper' arg. 
without passing + # it, the models lose their fp32 wrapper when unwrapped, and when they are then used for + # further fp16 training, we will crash. pass this, but only to newer versions of accelerate. fixes + # https://github.com/huggingface/diffusers/issues/1566 + accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set( + inspect.signature(accelerator.unwrap_model).parameters.keys() + ) + extra_args = {"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {} pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, - unet=accelerator.unwrap_model(unet, True), - text_encoder=accelerator.unwrap_model(text_encoder, True), + unet=accelerator.unwrap_model(unet, **extra_args), + text_encoder=accelerator.unwrap_model(text_encoder, **extra_args), revision=args.revision, ) save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")