diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index c76ff7c632e2..f55e63442f45 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -1,5 +1,6 @@
 import argparse
 import hashlib
+import inspect
 import itertools
 import math
 import os
@@ -690,10 +691,19 @@ def main(args):
                 if global_step % args.save_steps == 0:
                     if accelerator.is_main_process:
+                        # When 'keep_fp32_wrapper' is `False` (the default), then the models are
+                        # unwrapped and the mixed precision hooks are removed, so training crashes
+                        # when the unwrapped models are used for further training.
+                        # This is only supported in newer versions of `accelerate`.
+                        # TODO(Pedro, Suraj): Remove `accepts_keep_fp32_wrapper` when forcing newer accelerate versions
+                        accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set(
+                            inspect.signature(accelerator.unwrap_model).parameters.keys()
+                        )
+                        extra_args = {"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {}
                         pipeline = DiffusionPipeline.from_pretrained(
                             args.pretrained_model_name_or_path,
-                            unet=accelerator.unwrap_model(unet),
-                            text_encoder=accelerator.unwrap_model(text_encoder),
+                            unet=accelerator.unwrap_model(unet, **extra_args),
+                            text_encoder=accelerator.unwrap_model(text_encoder, **extra_args),
                             revision=args.revision,
                         )
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         pipeline.save_pretrained(save_path)