Conversation

@mandeep511

This feature enables training Flux models on systems with limited RAM
(such as the Google Colab free tier) by loading the text encoders first,
caching their embeddings to disk, unloading the text encoders, and only
then loading the transformer. This avoids holding the text encoders and
the transformer in memory at the same time, reducing peak RAM usage.

Changes:
- Add sequential_load config option to ModelConfig
- Defer transformer loading in stable_diffusion_model.py when enabled
- Add load_deferred_transformer() method to StableDiffusion class
- Integrate with SDTrainer to load transformer after TE caching/unload
- Add UI checkbox in SimpleJob.tsx for Flux models
- Add example config: train_lora_flux_low_ram.yaml

Usage:
  model:
    sequential_load: true
  train:
    cache_text_embeddings: true
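
For reference, here is a minimal sketch of what the sequential flow amounts
to, written against plain transformers/diffusers APIs rather than the
toolkit's actual StableDiffusion/SDTrainer code (the real logic lives in
load_deferred_transformer() and the SDTrainer integration listed above),
and assuming access to the gated black-forest-labs/FLUX.1-dev repo:

  import gc
  import torch
  from transformers import T5EncoderModel, T5TokenizerFast
  from diffusers import FluxTransformer2DModel

  repo = "black-forest-labs/FLUX.1-dev"

  # 1. Load only the T5 text encoder and cache prompt embeddings to disk.
  tokenizer = T5TokenizerFast.from_pretrained(repo, subfolder="tokenizer_2")
  text_encoder = T5EncoderModel.from_pretrained(
      repo, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
  )
  with torch.no_grad():
      ids = tokenizer(["a photo of a cat"], return_tensors="pt").input_ids
      torch.save(text_encoder(ids).last_hidden_state, "embeds.pt")

  # 2. Unload the text encoder before the transformer is ever loaded.
  del text_encoder
  gc.collect()

  # 3. Only now load the transformer, so peak RAM holds one large model
  #    at a time instead of both.
  transformer = FluxTransformer2DModel.from_pretrained(
      repo, subfolder="transformer", torch_dtype=torch.bfloat16
  )
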
When using pre-quantized BitsAndBytes models like diffusers/FLUX.1-dev-bnb-8bit,
calling .to(device, dtype=dtype) fails with 'Casting a quantized model to a new
dtype is unsupported'.

This fix:
- Adds is_bnb_quantized() helper to detect BnB-quantized models
- Modifies all .to() calls in Flux loading to skip dtype arg for BnB models
- Applies to transformer, T5 text encoder, CLIP text encoder, and deferred load

The check looks for model.quantization_config.load_in_8bit or load_in_4bit.
The is_bnb_quantized() function now also checks for bnb.nn.Linear8bitLt
and bnb.nn.Linear4bit layers directly in the model as a fallback.

This fixes detection for T5 text encoder in diffusers/FLUX.1-dev-bnb-8bit
where the quantization_config attribute may not be properly set but the
model still contains BnB quantized layers.
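
A sketch of the resulting helper, reconstructed from the description above
(hypothetical; the PR's actual implementation may differ in detail):

  import bitsandbytes as bnb

  def is_bnb_quantized(model) -> bool:
      # Primary check: the HF-style quantization_config attribute.
      qconf = getattr(model, "quantization_config", None)
      if qconf is not None and (
          getattr(qconf, "load_in_8bit", False)
          or getattr(qconf, "load_in_4bit", False)
      ):
          return True
      # Fallback: scan for BnB quantized linear layers directly. This
      # covers models (such as the T5 encoder here) where
      # quantization_config is not set but the weights are BnB-quantized.
      return any(
          isinstance(m, (bnb.nn.Linear8bitLt, bnb.nn.Linear4bit))
          for m in model.modules()
      )
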
BnB 8-bit models cannot have .to() called at all - not even for device
movement. The error message says 'model has already been set to the
correct devices and casted to the correct dtype'.

Changed all BnB checks from:
  if is_bnb_quantized(model):
      model.to(device)  # This still fails!
  else:
      model.to(device, dtype=dtype)

To:
  if not is_bnb_quantized(model):
      model.to(device, dtype=dtype)
  # BnB models are left as-is, already on correct device

This affects transformer, T5, and CLIP in both load_model and
load_deferred_transformer.

BnB 8-bit/4-bit models cannot be moved with .to() (they are already
on device) and cannot be optimized/quantized with optimum.quanto (they
are already quantized).

This fix:
1. Wraps ALL .to() calls for transformer, T5, CLIP in checks
   if not is_bnb_quantized(model):
2. Wraps quantization logic for transformer and T5 in checks
   if conf.quantize and not is_bnb_quantized(model):
3. Applies these fixes to both main load_model and deferred loading.

This prevents crashes when using diffusers/FLUX.1-dev-bnb-8bit and similar.
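
Taken together, the guard pattern looks roughly like the sketch below.
prepare_module is an illustrative name, not a function from the PR, and
the quantization step is assumed to go through optimum.quanto's
quantize/freeze as in the rest of the toolkit's quantize path:

  from optimum.quanto import freeze, qfloat8, quantize

  def prepare_module(module, device, dtype, want_quantize):
      # BnB-quantized modules are left untouched: they are already on the
      # correct device and dtype, reject .to() entirely, and are already
      # quantized, so quanto must not be applied to them either.
      if is_bnb_quantized(module):
          return module
      module.to(device, dtype=dtype)
      if want_quantize:
          quantize(module, weights=qfloat8)
          freeze(module)
      return module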