
Conversation

@linoytsaban (Collaborator) commented Aug 9, 2024

  • fix text encoder training bugs
  • fix loss calculation
  • fix param preparation to improve memory usage
  • minor updates to the README
  • add requirements
  • add tests

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@linoytsaban linoytsaban marked this pull request as ready for review August 9, 2024 08:44
@linoytsaban linoytsaban changed the title [Flux Dreambooth LoRA] - small fixes & improvements [Flux Dreambooth LoRA] - minor updates Aug 9, 2024
@linoytsaban linoytsaban changed the title [Flux Dreambooth LoRA] - minor updates [Flux Dreambooth LoRA] - te bug fixes & minor updates Aug 9, 2024
@linoytsaban linoytsaban requested a review from sayakpaul August 9, 2024 09:27
@sayakpaul (Member) left a comment

Let's add tests; otherwise there's no way we can detect these. What say?

@linoytsaban linoytsaban requested a review from sayakpaul August 9, 2024 14:49
@linoytsaban linoytsaban requested a review from sayakpaul August 10, 2024 10:09
@Gothos (Contributor) commented Aug 10, 2024

Training doesn't really seem to converge:

[image: generated sample]

Granted, this is batch size 1, but at an LR of 3e-6. This was supposed to be a person, and only the outlines are visible.
Unless this is a LoRA loading issue? I'd think a diffusers-trained LoRA ought to load fine with diffusers?
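
As a quick sanity check on the loading path, a minimal sketch (the checkpoint path is a placeholder for the training run's --output_dir):

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
# Placeholder path: wherever the script saved pytorch_lora_weights.safetensors.
pipe.load_lora_weights("path/to/output_dir")

If this loads without key-mismatch warnings, the loading path is probably not the culprit.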

@sayakpaul (Member)

@Gothos it could be because of some bug in the training; I'm not sure it's the loading, because lora_state_dict() and load_lora_weights() come from SD3, which we know works.

@Gothos (Contributor) commented Aug 10, 2024

> @Gothos it could be because of some bug in the training; I'm not sure it's the loading, because lora_state_dict() and load_lora_weights() come from SD3, which we know works.

Ok, thanks for the info!

@sayakpaul (Member)

That was only meant as a piece of info, though. If you find anything dodgy, holler at me.

@Gothos (Contributor) commented Aug 10, 2024

> That was only meant as a piece of info, though. If you find anything dodgy, holler at me.

Will do!

@linoytsaban linoytsaban changed the title [Flux Dreambooth LoRA] - te bug fixes & minor updates [Flux Dreambooth LoRA] - te bug fixes & updates Aug 10, 2024
@linoytsaban (Collaborator, Author) commented Aug 10, 2024

I think it should be better now, with the loss fix, and memory usage is lower too. @Gothos @arcanite24, give it a try if you like.
I did a 500-step training run on the dog dataset and it converged well.
Prompt: "an sks dog"

[image: generated sample]

@Gothos (Contributor) commented Aug 10, 2024

Yep, I can also confirm that it converges!

vae_scale_factor=vae_scale_factor,
)

model_pred = model_pred * (-sigmas) + noisy_model_input
@sayakpaul (Member) commented:

Where did this go?

@linoytsaban (Collaborator, Author) replied:

This was removed because we discarded precondition_outputs. Originally it was:

# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
# Preconditioning of the model outputs.
if args.precondition_outputs:
    model_pred = model_pred * (-sigmas) + noisy_model_input

and the line was accidentally left in during the previous merge, after we removed precondition_outputs.
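
For context, a minimal sketch of the loss once the preconditioning branch is gone (variable names are assumed from the training script; illustrative, not the exact diff):

import torch

def flow_matching_loss(model_pred, noise, model_input, weighting):
    # With preconditioning removed, model_pred is compared directly against
    # the flow-matching target (noise - clean latents).
    target = noise - model_input
    per_example = (
        (weighting.float() * (model_pred.float() - target.float()) ** 2)
        .reshape(target.shape[0], -1)
        .mean(1)
    )
    return per_example.mean()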

@linoytsaban (Collaborator, Author) added:

See also #9086 (comment) for more context.

@sayakpaul (Member)

@linoytsaban let's maybe also include your additional results in the comments when they're available? I left one question on the changes, and then I think we can merge.

@linoytsaban (Collaborator, Author)

Yarn art LoRA example (overfits a bit, but has a nice vibe to it):

#!/usr/bin/env bash
accelerate launch train_dreambooth_lora_flux.py \
  --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
  --dataset_name="Norod78/Yarn-art-style" \
  --instance_prompt="a Yarn art style tarot card" \
  --caption_column="text" \
  --output_dir="yarn_art_flux_500" \
  --mixed_precision="bf16" \
  --optimizer="prodigy" \
  --weighting_scheme="none" \
  --resolution=512 \
  --train_batch_size=1 \
  --repeats=1 \
  --learning_rate=1.0 \
  --report_to="wandb" \
  --gradient_accumulation_steps=1 \
  --gradient_checkpointing \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --rank=4 \
  --max_train_steps=500 \
  --checkpointing_steps=2000 \
  --seed="0" \
  --push_to_hub

[image: generated samples]
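
To try the result, a minimal inference sketch (the Hub repo id is a placeholder for whatever --push_to_hub created):

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
# Placeholder repo id: replace with the repo created by --push_to_hub.
pipe.load_lora_weights("your-username/yarn_art_flux_500")

image = pipe(
    "a Yarn art style tarot card",
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("yarn_art_sample.png")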

@linoytsaban (Collaborator, Author)

@sayakpaul let's merge once tests pass?

@sayakpaul sayakpaul merged commit 413ca29 into huggingface:main Aug 12, 2024
@sayakpaul (Member)

Thank you!

@linoytsaban linoytsaban deleted the flux-dreambooth-lora branch August 12, 2024 14:08
sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
* add requirements + fix link to bghira's guide

* text encoder training fixes

* text encoder training fixes

* text encoder training fixes

* text encoder training fixes

* style

* add tests

* fix encode_prompt call

* style

* unpack_latents test

* fix lora saving

* remove default val for max_sequence_length in encode_prompt

* remove default val for max_sequence_length in encode_prompt

* style

* testing

* style

* testing

* testing

* style

* fix sizing issue

* style

* revert scaling

* style

* style

* scaling test

* style

* scaling test

* remove model pred operation left from pre-conditioning

* remove model pred operation left from pre-conditioning

* fix trainable params

* remove te2 from casting

* transformer to accelerator

* remove prints

* empty commit