Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 91 additions & 1 deletion examples/text_to_image/train_text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@
from datasets import load_dataset
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from torchvision import transforms
Expand Down Expand Up @@ -101,6 +102,24 @@ def parse_args():
default="text",
help="The column of the dataset containing a caption or a list of captions.",
)
parser.add_argument(
"--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
)
parser.add_argument(
"--num_validation_images",
type=int,
default=4,
help="Number of images that should be generated during validation with `validation_prompt`.",
)
parser.add_argument(
"--validation_epochs",
type=int,
default=1,
help=(
"Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
" `args.validation_prompt` multiple times: `args.num_validation_images`."
),
)
parser.add_argument(
"--max_train_samples",
type=int,
Expand Down Expand Up @@ -328,6 +347,11 @@ def main():
logging_dir=logging_dir,
)

if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
raise ImportError("Make sure to install wandb if you want to use it for logging during training. You can do so by doing `pip install wandb`")

import wandb

# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
Expand Down Expand Up @@ -697,6 +721,51 @@ def collate_fn(examples):
if global_step >= args.max_train_steps:
break

if accelerator.is_main_process:
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
# create pipeline
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
args.pretrained_model_name_or_path, subfolder="safety_checker", revision=args.non_ema_revision
)
Comment on lines +731 to +733
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`StableDiffusionPipeline.from_pretrained` should automatically load the `safety_checker` when it is available. Is there any reason we need to load it explicitly here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, I'm not sure what difference it makes to load the `safety_checker` separately like this; `StableDiffusionPipeline.from_pretrained` does pretty much the same thing.

# safety_checker.to(accelerator.device, dtype=weight_dtype)
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When training with EMA, we should use the EMA weights for inference.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To do that, we'll need to:

  • temporarily store the non-EMA weights
  • copy the EMA weights to the unet
  • restore the non-EMA weights back into the unet.

For that, we'll need to add `store` and `restore` methods to `EMAModel`, as defined in https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L139

Happy to take care of this if you want :)

safety_checker=safety_checker,
revision=args.revision,
Comment on lines +738 to +739
Copy link
Contributor

@patil-suraj patil-suraj Jan 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could also pass `vae` and `text_encoder` directly here. I'm not really in favour of loading them again, as this would take more memory and time and might also lead to OOM (depending on the GPU).

)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)

# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
images = []
for _ in range(args.num_validation_images):
images.append(
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
)

for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb":
tracker.log(
{
"validation": [
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
for i, image in enumerate(images)
]
}
)
del safety_checker
del pipeline
torch.cuda.empty_cache()

# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
Expand All @@ -716,6 +785,27 @@ def collate_fn(examples):
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

if args.validation_prompt is not None:
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
images = []
for _ in range(args.num_validation_images):
images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])

for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb":
tracker.log(
{
"test": [
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
for i, image in enumerate(images)
]
}
)

accelerator.end_training()


Expand Down