@sayakpaul
Member

@sayakpaul sayakpaul commented Jan 31, 2023

Context: #2163

Potentially closes #2163

Command to fire training:

export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export DATASET_NAME="lambdalabs/pokemon-blip-captions"

accelerate launch --mixed_precision="fp16"  train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$DATASET_NAME \
  --use_ema \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --max_train_steps=250 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --validation_prompt="cute dragon creature" \
  --seed=666 \
  --report_to="wandb" \
  --output_dir="sd-pokemon-model" 

Leads to:

Traceback (most recent call last):
  File "train_text_to_image.py", line 813, in <module>
    main()
  File "train_text_to_image.py", line 791, in main
    pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 636, in __call__
    image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 361, in run_safety_checker
    image, has_nsfw_concept = self.safety_checker(
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/safety_checker.py", line 52, in forward
    pooled_output = self.vision_model(clip_input)[1]  # pooled_output
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/transformers/models/clip/modeling_clip.py", line 934, in forward
    return self.vision_model(
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/transformers/models/clip/modeling_clip.py", line 859, in forward
    hidden_states = self.embeddings(pixel_values)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/transformers/models/clip/modeling_clip.py", line 195, in forward
    patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.FloatTensor) should be the same

The PR also follows this comment: #2163 (comment)

It's the safety checker that causes the problem.
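
The mismatch in the traceback can be reproduced in isolation. The sketch below is only an illustration and is not code from this PR: a float32 `nn.Conv2d` with made-up sizes stands in for the CLIP patch embedding inside the safety checker, and a half-precision tensor stands in for the image produced by the fp16 pipeline. It runs on CPU, so the exact error text may differ from the CUDA message above.

```python
import torch
import torch.nn as nn

# Float32 convolution standing in for the safety checker's patch embedding
# (hypothetical sizes, chosen only to keep the example small).
patch_embedding = nn.Conv2d(3, 8, kernel_size=4, stride=4, bias=False)

# Half-precision image batch, as a pipeline running in fp16 would produce.
image = torch.randn(1, 3, 16, 16).half()

try:
    patch_embedding(image)
    error = None
except RuntimeError as exc:
    # Input dtype (half) does not match the weight dtype (float32).
    error = exc

# Casting the input to the weight dtype resolves the mismatch; casting the
# module to the input dtype would be the other option.
out = patch_embedding(image.to(patch_embedding.weight.dtype))
```

Either the input or the module has to be cast so that both share one dtype, which is why the discussion below turns to autocast.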

@sayakpaul sayakpaul requested review from patil-suraj and pcuenca and removed request for patil-suraj January 31, 2023 10:48
@sayakpaul sayakpaul self-assigned this Jan 31, 2023
@sayakpaul sayakpaul marked this pull request as ready for review January 31, 2023 11:29
@sayakpaul
Member Author

Adding the safety checker explicitly (f2a143f) also didn't help.

Contributor

@patil-suraj patil-suraj left a comment

Thanks a lot for the PR! Looks good to me. I left a comment about loading text_encoder and vae again for inference; I'm not really in favour of that.

Also, when doing mixed-precision training, I would expect generation to run in mixed precision as well. For that we'll probably need to use torch.autocast, for the reasons explained here: #2163 (comment)

We don't really want to promote autocast for inference, but I don't see any other clean way of handling it here. We could add a comment in the script explaining why it's used there and why it's not needed for general inference.

Also, autocast will be enabled only when using mixed precision; otherwise everything will default to fp32. This can be achieved using:

# torch.autocast expects the device type as a string (e.g. "cuda"), hence .type
with torch.autocast(accelerator.device.type, enabled=args.mixed_precision == "fp16"):
    ...

Happy to hear suggestions :)
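
A minimal sketch of the conditional autocast idea from this comment, runnable on CPU. The helper name `validation_autocast` is hypothetical; in the training script the arguments would presumably be `accelerator.device` and `args.mixed_precision`. Note that `torch.autocast` takes the device type as a string, so the sketch passes `device.type`.

```python
import torch


def validation_autocast(device: torch.device, mixed_precision: str):
    """Return an autocast context that is active only for fp16 runs.

    Hypothetical helper: when mixed precision is off, the context is a
    no-op and everything stays in fp32.
    """
    return torch.autocast(device.type, enabled=(mixed_precision == "fp16"))


# With mixed precision disabled, computation stays in float32.
layer = torch.nn.Linear(4, 2)
x = torch.randn(1, 4)
with validation_autocast(torch.device("cpu"), "no"):
    y = layer(x)
```

Keeping the `enabled` flag on the context manager avoids branching the validation code into separate fp16 and fp32 paths.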


if args.report_to == "wandb":
    if not is_wandb_available():
        raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
Contributor

Suggested change
- raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training. You can do so by doing `pip install wandb`")
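
For illustration, a standalone version of this guard could look like the sketch below. It does not use the script's `is_wandb_available` helper; the function name `ensure_tracker_installed` and its `package` parameter are hypothetical, and the availability check relies on the standard library's `importlib.util.find_spec`.

```python
import importlib.util


def ensure_tracker_installed(report_to: str, package: str = "wandb") -> None:
    """Raise a helpful ImportError if wandb logging is requested but missing.

    Hypothetical helper; `package` is a parameter only so the check can be
    exercised with any module name.
    """
    if report_to != "wandb":
        return  # Other trackers are not checked here.
    if importlib.util.find_spec(package) is None:
        raise ImportError(
            "Make sure to install wandb if you want to use it for logging "
            "during training. You can do so by doing `pip install wandb`"
        )


# No check is performed when another tracker is selected.
ensure_tracker_installed("tensorboard")
```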

Comment on lines +731 to +733
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="safety_checker", revision=args.non_ema_revision
)
Contributor

StableDiffusionPipeline.from_pretrained should automatically load safety_checker when available, is there any reason we need to load it here explicitly?

Contributor

Also, I'm not sure what difference it makes when we load safety_checker separately like this; StableDiffusionPipeline.from_pretrained does pretty much the same thing.

Comment on lines +738 to +739
safety_checker=safety_checker,
revision=args.revision,
Contributor

@patil-suraj patil-suraj Jan 31, 2023

We could also pass vae and text_encoder here directly. I'm not really in favour of loading them again, as that would take more memory and time and might also lead to OOM (depending on the GPU).
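
A rough sketch of what reusing the already-loaded modules could look like. The helper `build_validation_pipeline` is hypothetical; it relies only on `from_pretrained` accepting component overrides as keyword arguments, and it takes the pipeline class as a parameter. The `_RecordingPipeline` stand-in exists solely so the sketch runs without downloading a model; in the script the class would presumably be `StableDiffusionPipeline`.

```python
def build_validation_pipeline(pipeline_cls, model_path, *, text_encoder, vae, unet, revision=None):
    """Build an inference pipeline from modules the training loop already holds.

    Passing the modules as keyword arguments avoids loading a second copy of
    each component from disk.
    """
    return pipeline_cls.from_pretrained(
        model_path,
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        revision=revision,
    )


class _RecordingPipeline:
    """Stand-in for a real pipeline class; it only records what it was given."""

    @classmethod
    def from_pretrained(cls, path, **components):
        return path, components


# Placeholder strings stand in for the real modules.
path, components = build_validation_pipeline(
    _RecordingPipeline,
    "some/model",
    text_encoder="text_encoder",
    vae="vae",
    unet="unet",
    revision="main",
)
```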

# safety_checker.to(accelerator.device, dtype=weight_dtype)
pipeline = StableDiffusionPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    unet=accelerator.unwrap_model(unet),
Contributor

When EMA is enabled, we should use the EMA weights for inference.

Contributor

To do that, we'll need to:

  • temporarily store the non-EMA weights
  • copy the EMA weights to the unet
  • restore the non-EMA weights in the unet once inference is done.

For that, we'll need to add store and restore methods to EMAModel, as defined in https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L139

Happy to take care of this if you want :)
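
The store / copy / restore steps listed above can be sketched as follows. This is a toy class written for illustration, loosely modelled on the linked torch_ema code; it is not the diffusers `EMAModel`, and the name `TinyEMA` and its method names are assumptions.

```python
import torch


class TinyEMA:
    """Toy EMA tracker with store/copy_to/restore (hypothetical, for illustration)."""

    def __init__(self, parameters, decay=0.999):
        self.decay = decay
        self.shadow_params = [p.detach().clone() for p in parameters]
        self.collected_params = None

    @torch.no_grad()
    def step(self, parameters):
        # shadow = decay * shadow + (1 - decay) * param
        for shadow, param in zip(self.shadow_params, parameters):
            shadow.mul_(self.decay).add_(param, alpha=1 - self.decay)

    def store(self, parameters):
        # Keep a copy of the current (non-EMA) weights.
        self.collected_params = [p.detach().clone() for p in parameters]

    @torch.no_grad()
    def copy_to(self, parameters):
        # Overwrite the model weights with the EMA weights.
        for shadow, param in zip(self.shadow_params, parameters):
            param.copy_(shadow)

    @torch.no_grad()
    def restore(self, parameters):
        # Put the stored non-EMA weights back so training can continue.
        if self.collected_params is None:
            raise RuntimeError("store() must be called before restore()")
        for stored, param in zip(self.collected_params, parameters):
            param.copy_(stored)
        self.collected_params = None


model = torch.nn.Linear(2, 2)
ema = TinyEMA(model.parameters(), decay=0.5)

with torch.no_grad():
    for p in model.parameters():
        p.add_(1.0)  # Pretend one optimizer step changed the weights.
ema.step(model.parameters())

trained = [p.detach().clone() for p in model.parameters()]
ema.store(model.parameters())    # 1. stash the non-EMA weights
ema.copy_to(model.parameters())  # 2. run validation with the EMA weights
validation = [p.detach().clone() for p in model.parameters()]
ema.restore(model.parameters())  # 3. restore the non-EMA weights
```

In the training script, the validation images would be generated between `copy_to` and `restore`, so the logged samples reflect the EMA weights while training continues from the non-EMA ones.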

@sayakpaul
Member Author

@patil-suraj thanks a lot for your comments. I'll address them a bit later (possibly around midday tomorrow), as I have another commitment and will be travelling for a while. I hope it's not a blocker.

@patil-suraj
Contributor

No worries at all, not super urgent anyway : )

@patil-suraj
Contributor

Actually, there's already a similar PR #2157

@sayakpaul
Member Author

> Actually, there's already a similar PR #2157

Sure, I can close this one then.

@sayakpaul
Member Author

Closing in favor of #2157

@sayakpaul sayakpaul closed this Feb 1, 2023
@patil-suraj
Contributor

Thanks a lot!

@sayakpaul sayakpaul deleted the feat/text2image-logging branch February 2, 2023 03:15
Development

Successfully merging this pull request may close these issues.

[examples] Error due to mismatches in the dtype of UNet while running train_text_to_image.py

4 participants