Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion flux_train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def load_target_model(self, args, weight_dtype, accelerator):
elif t5xxl.dtype == torch.float8_e4m3fn:
logger.info("Loaded fp8 T5XXL model")

ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, fp8_scaled=args.fp8_scaled_ae)

model_version = flux_utils.MODEL_VERSION_FLUX_V1 if self.model_type != "chroma" else flux_utils.MODEL_VERSION_CHROMA
return model_version, [clip_l, t5xxl], ae, model
Expand Down Expand Up @@ -529,6 +529,7 @@ def setup_parser() -> argparse.ArgumentParser:
flux_train_utils.add_flux_train_arguments(parser)

parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
parser.add_argument("--fp8_scaled_ae", action="store_true", help="Use scaled fp8 for AutoEncoder / AutoEncoderにスケーリングされたfp8を使う")
parser.add_argument(
"--split_mode",
action="store_true",
Expand Down
64 changes: 62 additions & 2 deletions library/flux_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,14 +346,74 @@ def dtype(self) -> torch.dtype:
return next(self.parameters()).dtype

def encode(self, x: Tensor) -> Tensor:
    """Encode an image batch into scaled latents.

    Runs the encoder (optionally under activation checkpointing to save
    memory during training), applies the regularizer, then shifts and
    scales into the model's latent space.

    Args:
        x: input image tensor (shape as expected by ``self.encoder`` —
            presumably NCHW; TODO confirm against callers).

    Returns:
        Scaled latent tensor.
    """
    # NOTE(review): the pasted diff left a stray pre-diff line here
    # (`z = self.reg(self.encoder(x))`) which would have applied `reg`
    # twice and defeated the checkpointing branch; it has been removed.
    if getattr(self, "use_gradient_checkpointing", False):
        # Recompute encoder activations during backward instead of storing
        # them; non-reentrant checkpointing. `checkpoint` can take the
        # module directly — the extra closure wrapper added nothing.
        z = torch.utils.checkpoint.checkpoint(self.encoder, x, use_reentrant=False)
    else:
        z = self.encoder(x)

    z = self.reg(z)
    # Map to the diffusion model's latent scale.
    z = self.scale_factor * (z - self.shift_factor)
    return z

def decode(self, z: Tensor) -> Tensor:
    """Decode latents back to image space.

    Args:
        z: scaled latent tensor as produced by ``encode``.

    Returns:
        Decoded image tensor.
    """
    # Undo the latent scaling applied in encode().
    z = z / self.scale_factor + self.shift_factor
    # NOTE(review): diff residue — an early `return self.decoder(z)` from the
    # pre-diff version sat here and made the checkpointing branch below
    # unreachable; it has been removed.
    if getattr(self, "use_gradient_checkpointing", False):
        # NOTE(review): enable_gradient_checkpointing() also wraps individual
        # decoder blocks, so this whole-decoder checkpoint nests checkpoints
        # when both are active — confirm that nesting is intended.
        return torch.utils.checkpoint.checkpoint(self.decoder, z, use_reentrant=False)
    return self.decoder(z)

def enable_gradient_checkpointing(self):
    """Enable gradient checkpointing at the block level for memory efficiency.

    Sets ``self.use_gradient_checkpointing`` (read by ``encode``/``decode``)
    and wraps each decoder ResNet block's ``forward`` plus the two decoder
    middle blocks so their activations are recomputed during backward.
    """
    self.use_gradient_checkpointing = True

    def _wrap(orig_fwd):
        # Factory so each wrapper captures its own `orig_fwd` — avoids the
        # late-binding-closure pitfall of defining the wrapper in the loop.
        def checkpointed_forward(x):
            return torch.utils.checkpoint.checkpoint(orig_fwd, x, use_reentrant=False)

        return checkpointed_forward

    # Checkpoint each ResNet block individually in the decoder upsampling path.
    # (The original code duplicated this factory inline and carried an unused
    # enumerate index; both cleaned up here.)
    for up_block in self.decoder.up:
        for resnet_block in up_block.block:
            resnet_block.forward = _wrap(resnet_block.forward)

    # Checkpoint the decoder middle blocks with the same wrapper.
    self.decoder.mid.block_1.forward = _wrap(self.decoder.mid.block_1.forward)
    self.decoder.mid.block_2.forward = _wrap(self.decoder.mid.block_2.forward)
    # NOTE(review): encoder blocks are not wrapped here; encode() checkpoints
    # the whole encoder instead — confirm this asymmetry is intended.

def _make_checkpointed(self, original_forward):
def checkpointed_forward(x):
return torch.utils.checkpoint.checkpoint(
original_forward,
x,
use_reentrant=False
)
return checkpointed_forward

def forward(self, x: Tensor) -> Tensor:
    """Full autoencoder round trip: encode ``x`` to latents, then decode."""
    latents = self.encode(x)
    return self.decode(latents)

Expand Down
21 changes: 17 additions & 4 deletions library/flux_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from transformers import CLIPConfig, CLIPTextModel, T5Config, T5EncoderModel

from library.fp8_optimization_utils import apply_fp8_monkey_patch
from library.lora_utils import load_safetensors_with_lora_and_fp8
from library.lora_utils import load_safetensors_with_lora_and_fp8, load_safetensors_with_fp8_optimization_and_hook
from library.utils import setup_logging

setup_logging()
Expand All @@ -30,6 +30,9 @@
FP8_OPTIMIZATION_TARGET_KEYS = ["double_blocks", "single_blocks"]
FP8_OPTIMIZATION_EXCLUDE_KEYS = ["_mod", "norm", "modulation"]

AE_FP8_OPTIMIZATION_TARGET_KEYS = ["encoder", "decoder"]
AE_FP8_OPTIMIZATION_EXCLUDE_KEYS = ["norm"]


def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
"""
Expand Down Expand Up @@ -193,15 +196,26 @@ def load_flow_model(


def load_ae(
    ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False, fp8_scaled: bool = False
) -> flux_models.AutoEncoder:
    """Build the FLUX AutoEncoder and load its weights from a checkpoint.

    Args:
        ckpt_path: path to the AE safetensors checkpoint.
        dtype: parameter dtype the model is cast to.
        device: device the state dict is loaded onto.
        disable_mmap: load without memory-mapping the safetensors file
            (non-fp8 path only).
        fp8_scaled: if True, load encoder/decoder weights with scaled-fp8
            optimization and monkey-patch the model to consume them.

    Returns:
        The AutoEncoder with weights loaded (``strict=False``).
    """
    # NOTE(review): the pasted diff interleaved the old and new signature and
    # the old `load_safetensors` call; this is the reconstructed post-diff body.
    logger.info("Building AutoEncoder")
    with torch.device("meta"):
        # dev and schnell share the same AE hyperparameters.
        ae = flux_models.AutoEncoder(flux_models.configs[MODEL_NAME_DEV].ae_params).to(dtype)

    logger.info(f"Loading state dict from {ckpt_path}")
    if fp8_scaled:
        sd = load_safetensors_with_fp8_optimization_and_hook(
            [ckpt_path],
            fp8_optimization=True,
            calc_device=torch.device(device),
            target_keys=AE_FP8_OPTIMIZATION_TARGET_KEYS,
            exclude_keys=AE_FP8_OPTIMIZATION_EXCLUDE_KEYS,
        )

        # Patch the model's layers so they can consume scaled-fp8 weights.
        apply_fp8_monkey_patch(ae, sd, use_scaled_mm=False)
    else:
        sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
    info = ae.load_state_dict(sd, strict=False, assign=True)
    logger.info(f"Loaded AE: {info}")
    return ae
Expand Down Expand Up @@ -456,7 +470,6 @@ def pack_latents(x: torch.Tensor) -> torch.Tensor:
x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
return x


# region Diffusers

NUM_DOUBLE_BLOCKS = 19
Expand Down
Loading