103 changes: 101 additions & 2 deletions examples/text_to_image/train_text_to_image.py
@@ -43,7 +43,7 @@
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.import_utils import is_wandb_available, is_xformers_available


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -297,6 +297,24 @@ def parse_args():
parser.add_argument(
    "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument(
    "--validation_prompt",
    type=str,
    default=None,
    help="A prompt that is used during validation to verify that the model is learning.",
)
parser.add_argument(
    "--num_validation_images",
    type=int,
    default=4,
    help="Number of images that should be generated during validation with `validation_prompt`.",
)
parser.add_argument(
    "--validation_steps",
    type=int,
    default=500,
    help="Sample a validation image every X updates.",
)

args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
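For context, the --enable_xformers_memory_efficient_attention flag above is consumed later in this script; a minimal sketch of the gating pattern used across the diffusers examples (not shown in this hunk), using the is_xformers_available import added earlier:

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")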
@@ -353,7 +371,14 @@ def main():
project_config=accelerator_project_config,
)

if args.report_to == "wandb":
    if not is_wandb_available():
        raise ImportError(
            "Make sure to install wandb if you want to use it for logging during training. You can do so by doing"
            " `pip install wandb`"
        )
    import wandb

# Make one log on every process with the configuration for debugging.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
@@ -756,6 +781,53 @@ def collate_fn(examples):
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")

if global_step % args.validation_steps == 0:
    if accelerator.is_main_process:
        if args.validation_prompt:
Contributor: We should add validation_epochs and generate according to that instead of generating after each loop. You could refer to this script to see how to do that:

    if args.validation_prompt is not None and epoch % args.validation_epochs == 0:

Member: And this should be wrapped under the main-process condition (if accelerator.is_main_process:) to handle multi-GPU training.

Author: I'm currently using args.checkpointing_steps. Do we have a preference for # of epochs vs. # of global steps? I slightly favor # of global steps, since that's how we're controlling checkpointing.

Member: Yeah, for validation images, we prefer epochs since conceptually it's a bit simpler to think of when the inference is going to take place.
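Putting the two suggestions together, a hedged sketch of what the epoch-based gate could look like (a --validation_epochs flag is assumed here, mirroring train_dreambooth_lora.py; it is not part of this diff):

    if accelerator.is_main_process:
        if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
            # Assumption: run_validation is a hypothetical helper wrapping the
            # pipeline construction and tracker logging shown in the block below.
            run_validation(args, accelerator, unet, epoch)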

            if args.use_ema:
                # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
                ema_unet.store(unet.parameters())
                ema_unet.copy_to(unet.parameters())
            pipeline = StableDiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                text_encoder=text_encoder,
                vae=vae,
                unet=unet,
                revision=args.revision,
            )
            pipeline = pipeline.to(accelerator.device)
            pipeline.set_progress_bar_config(disable=True)

            # run inference
            prompt = [args.validation_prompt]
            images = []
            with torch.autocast(str(accelerator.device), enabled=accelerator.mixed_precision == "fp16"):
                for _ in range(args.num_validation_images):
                    images.append(pipeline(prompt).images[0])

            for i, image in enumerate(images):
                image.save(os.path.join(args.output_dir, f"validation-{global_step}-{i}.jpg"))
            for tracker in accelerator.trackers:
                if tracker.name == "tensorboard":
                    np_images = np.stack([np.asarray(img) for img in images])
                    tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
                if tracker.name == "wandb":
                    tracker.log(
                        {
                            "validation": [
                                wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
                                for i, image in enumerate(images)
                            ]
                        }
                    )
            if args.use_ema:
                # Switch back to the original UNet parameters.
                ema_unet.restore(unet.parameters())
            del pipeline
            torch.cuda.empty_cache()
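For readers unfamiliar with the EMA round-trip above: store/copy_to/restore on diffusers' EMAModel lets inference run with the smoothed weights without disturbing training. A minimal annotated sketch of the pattern (same names as in the diff):

    ema_unet.store(unet.parameters())    # stash the current training weights
    ema_unet.copy_to(unet.parameters())  # swap in the smoothed EMA weights for inference
    # ... validation inference runs with the EMA weights here ...
    ema_unet.restore(unet.parameters())  # put the training weights back so optimization is unaffected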

logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
Member: Do we also want to add another final inference logging, as done in the train_dreambooth_lora.py example?

    if args.validation_prompt and args.num_validation_images > 0:

progress_bar.set_postfix(**logs)

@@ -778,6 +850,33 @@ def collate_fn(examples):
)
pipeline.save_pretrained(args.output_dir)

if args.validation_prompt:
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # run inference
    prompt = [args.validation_prompt]
Member: Is this required? We can simply pass the validation prompt, no?
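If that simplification were applied, the list wrapper would go away; a minimal sketch, assuming the pipeline accepts a plain string prompt (StableDiffusionPipeline does for a single prompt):

    for _ in range(args.num_validation_images):
        images.append(pipeline(args.validation_prompt).images[0])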

    images = []
    with torch.autocast(str(accelerator.device), enabled=accelerator.mixed_precision == "fp16"):
        for _ in range(args.num_validation_images):
            images.append(pipeline(prompt).images[0])

    for i, image in enumerate(images):
        image.save(os.path.join(args.output_dir, f"test-{i}.jpg"))
    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            np_images = np.stack([np.asarray(img) for img in images])
            tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
        if tracker.name == "wandb":
            tracker.log(
                {
                    "test": [
                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
                        for i, image in enumerate(images)
                    ]
                }
            )

if args.push_to_hub:
    repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
