diff --git a/flux_train_network.py b/flux_train_network.py
index db61f15d9..98a7eeb3d 100644
--- a/flux_train_network.py
+++ b/flux_train_network.py
@@ -154,7 +154,7 @@ def load_target_model(self, args, weight_dtype, accelerator):
             elif t5xxl.dtype == torch.float8_e4m3fn:
                 logger.info("Loaded fp8 T5XXL model")

-        ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
+        ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, fp8_scaled=args.fp8_scaled_ae)

         model_version = flux_utils.MODEL_VERSION_FLUX_V1 if self.model_type != "chroma" else flux_utils.MODEL_VERSION_CHROMA
         return model_version, [clip_l, t5xxl], ae, model
@@ -529,6 +529,7 @@ def setup_parser() -> argparse.ArgumentParser:
     flux_train_utils.add_flux_train_arguments(parser)

     parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
+    parser.add_argument("--fp8_scaled_ae", action="store_true", help="Use scaled fp8 for AutoEncoder / AutoEncoderにスケーリングされたfp8を使う")
     parser.add_argument(
         "--split_mode",
         action="store_true",
diff --git a/library/flux_models.py b/library/flux_models.py
index 034543f07..89df90cc1 100644
--- a/library/flux_models.py
+++ b/library/flux_models.py
@@ -346,14 +346,74 @@ def dtype(self) -> torch.dtype:
         return next(self.parameters()).dtype

     def encode(self, x: Tensor) -> Tensor:
-        z = self.reg(self.encoder(x))
+        if hasattr(self, 'use_gradient_checkpointing') and self.use_gradient_checkpointing:
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    return module(*inputs)
+                return custom_forward
+
+            z = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.encoder),
+                x,
+                use_reentrant=False
+            )
+        else:
+            z = self.encoder(x)
+
+        z = self.reg(z)
         z = self.scale_factor * (z - self.shift_factor)
         return z

     def decode(self, z: Tensor) -> Tensor:
         z = z / self.scale_factor + self.shift_factor
-        return self.decoder(z)
+
+        if hasattr(self, 'use_gradient_checkpointing') and self.use_gradient_checkpointing:
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    return module(*inputs)
+                return custom_forward
+
+            return torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.decoder),
+                z,
+                use_reentrant=False
+            )
+        else:
+            return self.decoder(z)

+    def enable_gradient_checkpointing(self):
+        """Enable gradient checkpointing at the block level for memory efficiency"""
+        self.use_gradient_checkpointing = True
+
+        # Checkpoint each ResNet block individually in decoder
+        for up_block in self.decoder.up:
+            for i, resnet_block in enumerate(up_block.block):
+                original_forward = resnet_block.forward
+
+                def make_checkpointed_forward(orig_fwd):
+                    def checkpointed_forward(x):
+                        return torch.utils.checkpoint.checkpoint(
+                            orig_fwd,
+                            x,
+                            use_reentrant=False
+                        )
+                    return checkpointed_forward
+
+                resnet_block.forward = make_checkpointed_forward(original_forward)
+
+        # Checkpoint decoder middle blocks
+        self.decoder.mid.block_1.forward = self._make_checkpointed(self.decoder.mid.block_1.forward)
+        self.decoder.mid.block_2.forward = self._make_checkpointed(self.decoder.mid.block_2.forward)
+
+    def _make_checkpointed(self, original_forward):
+        def checkpointed_forward(x):
+            return torch.utils.checkpoint.checkpoint(
+                original_forward,
+                x,
+                use_reentrant=False
+            )
+        return checkpointed_forward
+
     def forward(self, x: Tensor) -> Tensor:
         return self.decode(self.encode(x))

diff --git a/library/flux_utils.py b/library/flux_utils.py
index c4bc6712f..49a355a98 100644
--- a/library/flux_utils.py
+++ b/library/flux_utils.py
@@ -11,7 +11,7 @@
 from transformers import CLIPConfig, CLIPTextModel, T5Config, T5EncoderModel

 from library.fp8_optimization_utils import apply_fp8_monkey_patch
-from library.lora_utils import load_safetensors_with_lora_and_fp8
+from library.lora_utils import load_safetensors_with_lora_and_fp8, load_safetensors_with_fp8_optimization_and_hook
 from library.utils import setup_logging

 setup_logging()
@@ -30,6 +30,9 @@
 FP8_OPTIMIZATION_TARGET_KEYS = ["double_blocks", "single_blocks"]
 FP8_OPTIMIZATION_EXCLUDE_KEYS = ["_mod", "norm", "modulation"]

+AE_FP8_OPTIMIZATION_TARGET_KEYS = ["encoder", "decoder"]
+AE_FP8_OPTIMIZATION_EXCLUDE_KEYS = ["norm"]
+

 def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
     """
@@ -193,7 +196,7 @@ def load_flow_model(


 def load_ae(
-    ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
+    ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False, fp8_scaled=False
 ) -> flux_models.AutoEncoder:
     logger.info("Building AutoEncoder")
     with torch.device("meta"):
@@ -201,7 +204,18 @@
         ae = flux_models.AutoEncoder(flux_models.configs[MODEL_NAME_DEV].ae_params).to(dtype)

     logger.info(f"Loading state dict from {ckpt_path}")
-    sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
+    if fp8_scaled:
+        sd = load_safetensors_with_fp8_optimization_and_hook(
+            [ckpt_path],
+            fp8_optimization=True,
+            calc_device=torch.device(device),
+            target_keys=AE_FP8_OPTIMIZATION_TARGET_KEYS,
+            exclude_keys=AE_FP8_OPTIMIZATION_EXCLUDE_KEYS,
+        )
+
+        apply_fp8_monkey_patch(ae, sd, use_scaled_mm=False)
+    else:
+        sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
     info = ae.load_state_dict(sd, strict=False, assign=True)
     logger.info(f"Loaded AE: {info}")
     return ae
@@ -456,7 +470,6 @@ def pack_latents(x: torch.Tensor) -> torch.Tensor:
     x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
     return x

-
 # region Diffusers

 NUM_DOUBLE_BLOCKS = 19
diff --git a/library/fp8_optimization_utils.py b/library/fp8_optimization_utils.py
index 02f99ab6d..5a993a6a9 100644
--- a/library/fp8_optimization_utils.py
+++ b/library/fp8_optimization_utils.py
@@ -85,6 +85,99 @@ def quantize_fp8(tensor, scale, fp8_dtype, max_value, min_value):
     return tensor


+def quantize_conv_weight(
+    key: str,
+    tensor: torch.Tensor,
+    fp8_dtype: torch.dtype,
+    max_value: float,
+    min_value: float,
+    quantization_mode: str = "channel",
+    block_size: int = 64,
+):
+    """
+    Quantize convolution weights to FP8 format.
+
+    Args:
+        key (str): Layer key for logging
+        tensor (torch.Tensor): Convolution weight tensor (out_channels, in_channels, *kernel_size)
+        fp8_dtype (torch.dtype): Target FP8 dtype
+        max_value (float): Maximum representable value in FP8
+        min_value (float): Minimum representable value in FP8
+        quantization_mode (str): "tensor", "channel", or "block"
+        block_size (int): Block size for block-wise quantization
+
+    Returns:
+        tuple: (quantized_weight, scale_tensor)
+    """
+    original_shape = tensor.shape
+
+    # Convolution weights have shape: (out_channels, in_channels, *kernel_size)
+    # We'll quantize per output channel by default
+    if tensor.ndim < 3:
+        # Not a convolution weight, fallback to tensor-wise
+        quantization_mode = "tensor"
+
+    if quantization_mode == "tensor":
+        # Per-tensor quantization
+        scale = torch.max(torch.abs(tensor)) / max_value
+        scale = scale.clamp(min=1e-12)  # avoid division by zero
+
+        quantized_weight = quantize_fp8(tensor, scale, fp8_dtype, max_value, min_value)
+        scale_tensor = scale.reshape(1)
+
+    elif quantization_mode == "channel":
+        # Per-channel quantization (per output channel for Conv)
+        out_channels = tensor.shape[0]
+
+        # Flatten spatial dimensions to compute scale per output channel
+        tensor_flat = tensor.reshape(out_channels, -1)
+
+        # Calculate scale per output channel
+        scale = torch.max(torch.abs(tensor_flat), dim=1, keepdim=True)[0] / max_value
+        scale = scale.clamp(min=1e-12)
+
+        # Reshape scale for broadcasting during quantization
+        scale_broadcast = scale.reshape(out_channels, *([1] * (tensor.ndim - 1)))
+
+        quantized_weight = quantize_fp8(tensor, scale_broadcast, fp8_dtype, max_value, min_value)
+        scale_tensor = scale  # shape: (out_channels, 1)
+
+    elif quantization_mode == "block":
+        # Block-wise quantization along the flattened spatial dimension
+        out_channels = tensor.shape[0]
+        spatial_size = tensor[0].numel()  # in_channels * kernel spatial dimensions
+
+        if spatial_size % block_size != 0:
+            # Fallback to per-channel if not divisible
+            logger.warning(
+                f"Layer {key} with shape {tensor.shape} has spatial size {spatial_size} "
+                f"not divisible by block_size {block_size}, fallback to per-channel quantization."
+            )
+            return quantize_conv_weight(key, tensor, fp8_dtype, max_value, min_value, "channel", block_size)
+
+        num_blocks = spatial_size // block_size
+
+        # Reshape to (out_channels, num_blocks, block_size)
+        tensor_blocked = tensor.reshape(out_channels, num_blocks, block_size)
+
+        # Calculate scale per block
+        scale = torch.max(torch.abs(tensor_blocked), dim=2, keepdim=True)[0] / max_value
+        scale = scale.clamp(min=1e-12)
+
+        # Broadcast scale for quantization
+        quantized_weight = quantize_fp8(tensor_blocked, scale, fp8_dtype, max_value, min_value)
+
+        # Reshape back to original shape
+        quantized_weight = quantized_weight.reshape(original_shape)
+        # scale shape: (out_channels, num_blocks, 1)
+
+        scale_tensor = scale
+    else:
+        raise ValueError(f"Unsupported quantization mode: {quantization_mode}")
+
+    return quantized_weight, scale_tensor
+
+
 def optimize_state_dict_with_fp8(
     state_dict: dict,
     calc_device: Union[str, torch.device],
@@ -149,7 +242,23 @@
         if calc_device is not None:
             value = value.to(calc_device)

-        quantized_weight, scale_tensor = quantize_weight(key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size)
+        # Determine if this is a convolution weight based on tensor dimensionality
+        # Linear: 2D (out_features, in_features)
+        # Conv1d: 3D (out_channels, in_channels, kernel_size)
+        # Conv2d: 4D (out_channels, in_channels, kernel_h, kernel_w)
+        # Conv3d: 5D (out_channels, in_channels, kernel_d, kernel_h, kernel_w)
+        is_conv = value.ndim > 2
+
+        if is_conv:
+            logger.info(f"Quantizing CONV layer: {key} with shape {value.shape}")
+            quantized_weight, scale_tensor = quantize_conv_weight(
+                key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size
+            )
+        else:
+            logger.info(f"Quantizing LINEAR layer: {key} with shape {value.shape}")
+            quantized_weight, scale_tensor = quantize_weight(
+                key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size
+            )

         # Add to state dict using original key for weight and new key for scale
         fp8_key = key  # Maintain original key
@@ -309,11 +418,20 @@ def is_target_key(key):
             # Move to calculation device
             if calc_device is not None:
                 value = value.to(calc_device)

+            # Determine if this is a convolution weight based on tensor dimensionality
+            is_conv = value.ndim > 2
             original_dtype = value.dtype
-            quantized_weight, scale_tensor = quantize_weight(
-                key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size
-            )
+            if is_conv:
+                logger.info(f"Quantizing CONV layer: {key} with shape {value.shape}")
+                quantized_weight, scale_tensor = quantize_conv_weight(
+                    key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size
+                )
+            else:
+                logger.info(f"Quantizing LINEAR layer: {key} with shape {value.shape}")
+                quantized_weight, scale_tensor = quantize_weight(
+                    key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size
+                )

             # Add to state dict using original key for weight and new key for scale
             fp8_key = key  # Maintain original key
@@ -449,21 +567,146 @@ def apply_fp8_monkey_patch(model, optimized_state_dict, use_scaled_mm=False):
         # Check if this module has a corresponding scale_weight
         has_scale = name in patched_module_paths

-        # Apply patch if it's a Linear layer with FP8 scale
-        if isinstance(module, nn.Linear) and has_scale:
-            # register the scale_weight as a buffer to load the state_dict
-            # module.register_buffer("scale_weight", torch.tensor(1.0, dtype=module.weight.dtype))
+        # Apply patch if it's a Linear or Conv layer with FP8 scale
+        is_linear = isinstance(module, nn.Linear)
+        is_conv = isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d))
+
+        if (is_linear or is_conv) and has_scale:
+            # Register the scale_weight as a buffer to load the state_dict
             scale_shape = scale_shape_info[name]
             module.register_buffer("scale_weight", torch.ones(scale_shape, dtype=module.weight.dtype))

-            # Create a new forward method with the patched version.
-            def new_forward(self, x):
-                return fp8_linear_forward_patch(self, x, use_scaled_mm, max_value)
+            # Create a new forward method with the patched version
+            if is_linear:
+                def new_forward(self, x):
+                    return fp8_linear_forward_patch(self, x, use_scaled_mm, max_value)
+            else:  # is_conv
+                def new_forward(self, x):
+                    return fp8_conv_forward_patch(self, x, use_scaled_mm, max_value)

             # Bind method to module
             module.forward = new_forward.__get__(module, type(module))

             patched_count += 1

-    logger.info(f"Number of monkey-patched Linear layers: {patched_count}")
+    logger.info(f"Number of monkey-patched Linear and Conv layers: {patched_count}")
     return model
+
+
+def fp8_conv_forward_patch(self, x, use_scaled_mm=False, max_value=None):
+    """
+    Patched forward method for Conv2d/Conv3d layers with FP8 weights.
+
+    Args:
+        self: Conv layer instance (nn.Conv1d, nn.Conv2d, or nn.Conv3d)
+        x (torch.Tensor): Input tensor
+        use_scaled_mm (bool): Not applicable for convolutions (ignored)
+        max_value (float): Maximum value for FP8 quantization of input (not implemented)
+
+    Returns:
+        torch.Tensor: Result of convolution transformation
+    """
+    # Note: scaled_mm is not applicable for convolutions, so we always dequantize
+
+    # Dequantize the weight
+    original_dtype = self.scale_weight.dtype
+    weight_shape = self.weight.shape  # (out_channels, in_channels, *kernel_size)
+
+    if self.scale_weight.ndim == 1:
+        # Per-tensor quantization: scale shape is (1,)
+        dequantized_weight = self.weight.to(original_dtype) * self.scale_weight
+
+    elif self.scale_weight.ndim == 2:
+        # Per-channel quantization: scale shape is (out_channels, 1)
+        # Need to reshape scale to broadcast correctly with weight
+        # Weight shape: (out_channels, in_channels, *kernel_size)
+        # Scale needs to be: (out_channels, 1, 1, ...) to broadcast
+        out_channels = weight_shape[0]
+        scale_broadcast_shape = [out_channels] + [1] * (len(weight_shape) - 1)
+        scale_broadcast = self.scale_weight.reshape(scale_broadcast_shape)
+        dequantized_weight = self.weight.to(original_dtype) * scale_broadcast
+
+    else:
+        # Block-wise quantization: scale shape is (out_channels, num_blocks, 1)
+        out_channels, num_blocks, _ = self.scale_weight.shape
+        spatial_size = self.weight[0].numel()  # in_channels * kernel spatial dimensions
+        block_size = spatial_size // num_blocks
+
+        # Reshape weight to (out_channels, num_blocks, block_size)
+        dequantized_weight = self.weight.to(original_dtype).contiguous().view(out_channels, num_blocks, block_size)
+        dequantized_weight = dequantized_weight * self.scale_weight
+        dequantized_weight = dequantized_weight.view(weight_shape)
+
+    # Perform convolution based on layer type
+    if isinstance(self, nn.Conv1d):
+        output = F.conv1d(
+            x, dequantized_weight, self.bias,
+            self.stride, self.padding, self.dilation, self.groups
+        )
+    elif isinstance(self, nn.Conv2d):
+        output = F.conv2d(
+            x, dequantized_weight, self.bias,
+            self.stride, self.padding, self.dilation, self.groups
+        )
+    elif isinstance(self, nn.Conv3d):
+        output = F.conv3d(
+            x, dequantized_weight, self.bias,
+            self.stride, self.padding, self.dilation, self.groups
+        )
+    else:
+        raise ValueError(f"Unsupported convolution type: {type(self)}")
+
+    return output
+
+
+def apply_fp8_conv_monkey_patch(model, optimized_state_dict, use_scaled_mm=False):
+    """
+    Apply monkey patching to convolution layers in a model using FP8 optimized state dict.
+
+    Args:
+        model (nn.Module): Model instance to patch
+        optimized_state_dict (dict): FP8 optimized state dict
+        use_scaled_mm (bool): Not applicable for convolutions (ignored)
+
+    Returns:
+        int: Number of patched convolution layers
+    """
+    max_value = None  # do not quantize input tensor
+
+    # Find all scale keys to identify FP8-optimized layers
+    scale_keys = [k for k in optimized_state_dict.keys() if k.endswith(".scale_weight")]
+
+    # Enumerate patched layers
+    patched_module_paths = set()
+    scale_shape_info = {}
+    for scale_key in scale_keys:
+        # Extract module path from scale key (remove .scale_weight)
+        module_path = scale_key.rsplit(".scale_weight", 1)[0]
+        patched_module_paths.add(module_path)
+
+        # Store scale shape information
+        scale_shape_info[module_path] = optimized_state_dict[scale_key].shape
+
+    patched_count = 0
+
+    # Apply monkey patch to each convolution layer with FP8 weights
+    for name, module in model.named_modules():
+        # Check if this module has a corresponding scale_weight
+        has_scale = name in patched_module_paths
+
+        # Apply patch if it's a Conv layer with FP8 scale
+        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)) and has_scale:
+            # Register the scale_weight as a buffer to load the state_dict
+            scale_shape = scale_shape_info[name]
+            module.register_buffer("scale_weight", torch.ones(scale_shape, dtype=module.weight.dtype))
+
+            # Create a new forward method with the patched version
+            def new_forward(self, x):
+                return fp8_conv_forward_patch(self, x, use_scaled_mm, max_value)
+
+            # Bind method to module
+            module.forward = new_forward.__get__(module, type(module))
+
+            patched_count += 1
+
+    logger.info(f"Number of monkey-patched Convolution layers: {patched_count}")
+    return patched_count
diff --git a/tests/library/test_fp8_optimization_utils_quantize_conv.py b/tests/library/test_fp8_optimization_utils_quantize_conv.py
new file mode 100644
index 000000000..b1efd2688
--- /dev/null
+++ b/tests/library/test_fp8_optimization_utils_quantize_conv.py
@@ -0,0 +1,94 @@
+import pytest
+import torch
+
+from library.fp8_optimization_utils import quantize_conv_weight
+
+def test_quantize_conv_weight_tensor_mode():
+    """Test tensor-wise quantization for conv weights."""
+    weight = torch.randn(16, 3, 3, 3)  # out_channels, in_channels, kh, kw
+    fp8_dtype = torch.float8_e4m3fn
+    max_value = 448.0
+    min_value = -448.0
+
+    quantized, scale = quantize_conv_weight(
+        "test_layer", weight, fp8_dtype, max_value, min_value, "tensor"
+    )
+
+    assert quantized.shape == weight.shape
+    assert quantized.dtype == fp8_dtype
+    assert scale.shape == (1,)
+
+
+def test_quantize_conv_weight_channel_mode():
+    """Test per-channel quantization for conv weights."""
+    weight = torch.randn(16, 3, 3, 3)
+    fp8_dtype = torch.float8_e4m3fn
+    max_value = 448.0
+    min_value = -448.0
+
+    quantized, scale = quantize_conv_weight(
+        "test_layer", weight, fp8_dtype, max_value, min_value, "channel"
+    )
+
+    assert quantized.shape == weight.shape
+    assert quantized.dtype == fp8_dtype
+    assert scale.shape == (16, 1)  # one scale per output channel
+
+
+def test_quantize_conv_weight_block_mode():
+    """Test block-wise quantization for conv weights."""
+    weight = torch.randn(16, 8, 4, 4)  # spatial size = 8*4*4 = 128
+    fp8_dtype = torch.float8_e4m3fn
+    max_value = 448.0
+    min_value = -448.0
+    block_size = 64
+
+    quantized, scale = quantize_conv_weight(
+        "test_layer", weight, fp8_dtype, max_value, min_value, "block", block_size
+    )
+
+    assert quantized.shape == weight.shape
+    assert quantized.dtype == fp8_dtype
+    assert scale.shape == (16, 2, 1)  # 128 / 64 = 2 blocks per channel
+
+
+def test_quantize_conv_weight_block_fallback():
+    """Test block-wise fallback to channel mode when not divisible."""
+    weight = torch.randn(8, 3, 3, 3)  # spatial size = 3*3*3 = 27, not divisible by 64
+    fp8_dtype = torch.float8_e4m3fn
+    max_value = 448.0
+    min_value = -448.0
+    block_size = 64
+
+    quantized, scale = quantize_conv_weight(
+        "test_layer", weight, fp8_dtype, max_value, min_value, "block", block_size
+    )
+
+    assert quantized.shape == weight.shape
+    assert scale.shape == (8, 1)  # fallback to channel mode
+
+
+def test_quantize_conv_weight_non_conv_tensor():
+    """Test fallback for non-convolution tensors."""
+    weight = torch.randn(128, 64)  # 2D tensor (e.g., linear layer)
+    fp8_dtype = torch.float8_e4m3fn
+    max_value = 448.0
+    min_value = -448.0
+
+    quantized, scale = quantize_conv_weight(
+        "test_layer", weight, fp8_dtype, max_value, min_value, "channel"
+    )
+
+    assert quantized.shape == weight.shape
+    assert scale.shape == (1,)  # should fallback to tensor mode
+
+
+def test_quantize_conv_weight_invalid_mode():
+    """Test that invalid quantization mode raises error."""
+    weight = torch.randn(16, 3, 3, 3)
+    fp8_dtype = torch.float8_e4m3fn
+
+    with pytest.raises(ValueError, match="Unsupported quantization mode"):
+        quantize_conv_weight(
+            "test_layer", weight, fp8_dtype, 448.0, -448.0, "invalid_mode"
+        )
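
Note (not part of the patch): with `--fp8_scaled_ae` set, `load_ae` routes the AutoEncoder weights through `load_safetensors_with_fp8_optimization_and_hook` and `apply_fp8_monkey_patch`, so encoder/decoder conv weights are stored in fp8 with their scales kept as `scale_weight` buffers. Below is a minimal round-trip sketch of the new `quantize_conv_weight` helper in its default per-channel mode, mirroring the dequantization path in `fp8_conv_forward_patch`; the layer key string is a hypothetical example used only for log messages.

```python
import torch

from library.fp8_optimization_utils import quantize_conv_weight

# Per-channel quantization of a Conv2d weight, as the optimized state dict would store it
weight = torch.randn(16, 8, 3, 3)  # (out_channels, in_channels, kh, kw)
quantized, scale = quantize_conv_weight(
    "decoder.example.conv.weight",  # hypothetical key, only used for logging
    weight,
    torch.float8_e4m3fn,
    448.0,   # max representable value of float8_e4m3fn
    -448.0,
    "channel",
)

# Dequantize the same way fp8_conv_forward_patch does for 2-D scales:
# scale has shape (out_channels, 1) and must broadcast over (out, in, kh, kw)
scale_broadcast = scale.reshape(16, 1, 1, 1)
dequantized = quantized.to(weight.dtype) * scale_broadcast

print((dequantized - weight).abs().max())  # small fp8 rounding error expected
```

In "channel" mode the returned scale is `(out_channels, 1)`, so it has to be reshaped to `(out_channels, 1, 1, 1)` before broadcasting over a Conv2d weight, which is exactly what the patched forward does when `scale_weight.ndim == 2`.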