diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d35fe3925..12d2cfcc0 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -43,7 +43,7 @@ jobs: - name: Install dependencies run: | # Pre-install torch to pin version (requirements.txt has dependencies like transformers which requires pytorch) - pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4 + pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4 faiss-cpu==1.12.0 pip install -r requirements.txt - name: Test with pytest diff --git a/.gitignore b/.gitignore index cfdc02685..a3272cc45 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ GEMINI.md .claude .gemini MagicMock +benchmark_*.py diff --git a/flux_train_network.py b/flux_train_network.py index cfc617088..001f71763 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -1,7 +1,5 @@ import argparse import copy -import math -import random from typing import Any, Optional, Union import torch @@ -36,6 +34,7 @@ def __init__(self): self.is_schnell: Optional[bool] = None self.is_swapping_blocks: bool = False self.model_type: Optional[str] = None + self.gamma_b_dataset = None # CDC-FM Γ_b dataset def assert_extra_args( self, @@ -327,9 +326,15 @@ def get_noise_pred_and_target( noise = torch.randn_like(latents) bsz = latents.shape[0] - # get noisy model input and timesteps - noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents, noise, accelerator.device, weight_dtype + # Get CDC parameters if enabled + gamma_b_dataset = self.gamma_b_dataset if (self.gamma_b_dataset is not None and "latents_npz" in batch) else None + latents_npz_paths = batch.get("latents_npz") if gamma_b_dataset is not None else None + + # Get noisy model input and timesteps + # If CDC is enabled, this will transform the noise with geometry-aware covariance + noisy_model_input, 
timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler, latents, noise, accelerator.device, weight_dtype, + gamma_b_dataset=gamma_b_dataset, latents_npz_paths=latents_npz_paths ) # pack latents and get img_ids @@ -456,6 +461,15 @@ def update_metadata(self, metadata, args): metadata["ss_model_prediction_type"] = args.model_prediction_type metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift + # CDC-FM metadata + metadata["ss_use_cdc_fm"] = getattr(args, "use_cdc_fm", False) + metadata["ss_cdc_k_neighbors"] = getattr(args, "cdc_k_neighbors", None) + metadata["ss_cdc_k_bandwidth"] = getattr(args, "cdc_k_bandwidth", None) + metadata["ss_cdc_d_cdc"] = getattr(args, "cdc_d_cdc", None) + metadata["ss_cdc_gamma"] = getattr(args, "cdc_gamma", None) + metadata["ss_cdc_adaptive_k"] = getattr(args, "cdc_adaptive_k", None) + metadata["ss_cdc_min_bucket_size"] = getattr(args, "cdc_min_bucket_size", None) + def is_text_encoder_not_needed_for_training(self, args): return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) @@ -494,7 +508,7 @@ def forward(hidden_states): module.forward = forward_hook(module) if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype: - logger.info(f"T5XXL already prepared for fp8") + logger.info("T5XXL already prepared for fp8") else: logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks") text_encoder.to(te_weight_dtype) # fp8 @@ -533,6 +547,72 @@ def setup_parser() -> argparse.ArgumentParser: help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead."
" / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。", ) + + # CDC-FM arguments + parser.add_argument( + "--use_cdc_fm", + action="store_true", + help="Enable CDC-FM (Carré du Champ Flow Matching) for geometry-aware noise during training" + " / CDC-FM(Carré du Champ Flow Matching)を有効にして幾何学的ノイズを使用", + ) + parser.add_argument( + "--cdc_k_neighbors", + type=int, + default=256, + help="Number of neighbors for k-NN graph in CDC-FM (default: 256)" + " / CDC-FMのk-NNグラフの近傍数(デフォルト: 256)", + ) + parser.add_argument( + "--cdc_k_bandwidth", + type=int, + default=8, + help="Number of neighbors for bandwidth estimation in CDC-FM (default: 8)" + " / CDC-FMの帯域幅推定の近傍数(デフォルト: 8)", + ) + parser.add_argument( + "--cdc_d_cdc", + type=int, + default=8, + help="Dimension of CDC subspace (default: 8)" + " / CDCサブ空間の次元(デフォルト: 8)", + ) + parser.add_argument( + "--cdc_gamma", + type=float, + default=1.0, + help="CDC strength parameter (default: 1.0)" + " / CDC強度パラメータ(デフォルト: 1.0)", + ) + parser.add_argument( + "--force_recache_cdc", + action="store_true", + help="Force recompute CDC cache even if valid cache exists" + " / 有効なCDCキャッシュが存在してもCDCキャッシュを再計算", + ) + parser.add_argument( + "--cdc_debug", + action="store_true", + help="Enable verbose CDC debug output showing bucket details" + " / CDCの詳細デバッグ出力を有効化(バケット詳細表示)", + ) + parser.add_argument( + "--cdc_adaptive_k", + action="store_true", + help="Use adaptive k_neighbors based on bucket size. If enabled, buckets smaller than k_neighbors will use " + "k=bucket_size-1 instead of skipping CDC entirely. Buckets smaller than cdc_min_bucket_size are still skipped." + " / バケットサイズに基づいてk_neighborsを適応的に調整。有効にすると、k_neighbors未満のバケットは" + "CDCをスキップせずk=バケットサイズ-1を使用。cdc_min_bucket_size未満のバケットは引き続きスキップ。", + ) + parser.add_argument( + "--cdc_min_bucket_size", + type=int, + default=16, + help="Minimum bucket size for CDC computation. Buckets with fewer samples will use standard Gaussian noise. 
" + "Only relevant when --cdc_adaptive_k is enabled (default: 16)" + " / CDC計算の最小バケットサイズ。これより少ないサンプルのバケットは標準ガウスノイズを使用。" + "--cdc_adaptive_k有効時のみ関連(デフォルト: 16)", + ) + return parser diff --git a/library/cdc_fm.py b/library/cdc_fm.py new file mode 100644 index 000000000..4a5772ad6 --- /dev/null +++ b/library/cdc_fm.py @@ -0,0 +1,905 @@ +import logging +import torch +import numpy as np +from pathlib import Path +from tqdm import tqdm +from safetensors.torch import save_file +from typing import List, Dict, Optional, Union, Tuple +from dataclasses import dataclass + +logger = logging.getLogger(__name__) + + +@dataclass +class LatentSample: + """ + Container for a single latent with metadata + """ + latent: np.ndarray # (d,) flattened latent vector + global_idx: int # Global index in dataset + shape: Tuple[int, ...] # Original shape before flattening (C, H, W) + latents_npz_path: str # Path to the latent cache file + metadata: Optional[Dict] = None # Any extra info (prompt, filename, etc.) + + +class CarreDuChampComputer: + """ + Core CDC-FM computation - agnostic to data source + Just handles the math for a batch of same-size latents + """ + + def __init__( + self, + k_neighbors: int = 256, + k_bandwidth: int = 8, + d_cdc: int = 8, + gamma: float = 1.0, + device: str = 'cuda' + ): + self.k = k_neighbors + self.k_bw = k_bandwidth + self.d_cdc = d_cdc + self.gamma = gamma + self.device = torch.device(device if torch.cuda.is_available() else 'cpu') + + def compute_knn_graph(self, latents_np: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """ + Build k-NN graph using pure PyTorch + + Args: + latents_np: (N, d) numpy array of same-dimensional latents + + Returns: + distances: (N, k_actual+1) distances (k_actual may be less than k if N is small) + indices: (N, k_actual+1) neighbor indices + """ + N, d = latents_np.shape + + # Clamp k to available neighbors (can't have more neighbors than samples) + k_actual = min(self.k, N - 1) + + # Convert to torch tensor + latents_tensor = 
torch.from_numpy(latents_np).to(self.device) + + # Compute pairwise L2 distances efficiently + # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 + # This is more memory efficient than computing all pairwise differences + # For large batches, we'll chunk the computation + chunk_size = 1000 # Process 1000 queries at a time to manage memory + + if N <= chunk_size: + # Small batch: compute all at once + distances_sq = torch.cdist(latents_tensor, latents_tensor, p=2) ** 2 + distances_k_sq, indices_k = torch.topk( + distances_sq, k=k_actual + 1, dim=1, largest=False + ) + distances = torch.sqrt(distances_k_sq).cpu().numpy() + indices = indices_k.cpu().numpy() + else: + # Large batch: chunk to avoid OOM + distances_list = [] + indices_list = [] + + for i in range(0, N, chunk_size): + end_i = min(i + chunk_size, N) + chunk = latents_tensor[i:end_i] + + # Compute distances for this chunk + distances_sq = torch.cdist(chunk, latents_tensor, p=2) ** 2 + distances_k_sq, indices_k = torch.topk( + distances_sq, k=k_actual + 1, dim=1, largest=False + ) + + distances_list.append(torch.sqrt(distances_k_sq).cpu().numpy()) + indices_list.append(indices_k.cpu().numpy()) + + # Free memory + del distances_sq, distances_k_sq, indices_k + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + distances = np.concatenate(distances_list, axis=0) + indices = np.concatenate(indices_list, axis=0) + + return distances, indices + + @torch.no_grad() + def compute_gamma_b_single( + self, + point_idx: int, + latents_np: np.ndarray, + distances: np.ndarray, + indices: np.ndarray, + epsilon: np.ndarray + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute Γ_b for a single point + + Args: + point_idx: Index of point to process + latents_np: (N, d) all latents in this batch + distances: (N, k+1) precomputed distances + indices: (N, k+1) precomputed neighbor indices + epsilon: (N,) bandwidth per point + + Returns: + eigenvectors: (d_cdc, d) as half precision tensor + eigenvalues: (d_cdc,) as half precision 
tensor + """ + d = latents_np.shape[1] + + # Get neighbors (exclude self) + neighbor_idx = indices[point_idx, 1:] # (k,) + neighbor_points = latents_np[neighbor_idx] # (k, d) + + # Clamp distances to prevent overflow (max realistic L2 distance) + MAX_DISTANCE = 1e10 + neighbor_dists = np.clip(distances[point_idx, 1:], 0, MAX_DISTANCE) + neighbor_dists_sq = neighbor_dists ** 2 # (k,) + + # Compute Gaussian kernel weights with numerical guards + eps_i = max(epsilon[point_idx], 1e-10) # Prevent division by zero + eps_neighbors = np.maximum(epsilon[neighbor_idx], 1e-10) + + # Compute denominator with guard against overflow + denom = eps_i * eps_neighbors + denom = np.maximum(denom, 1e-20) # Additional guard + + # Compute weights with safe exponential + exp_arg = -neighbor_dists_sq / denom + exp_arg = np.clip(exp_arg, -50, 0) # Prevent exp overflow/underflow + weights = np.exp(exp_arg) + + # Normalize weights, handle edge case of all zeros + weight_sum = weights.sum() + if weight_sum < 1e-20 or not np.isfinite(weight_sum): + # Fallback to uniform weights + weights = np.ones_like(weights) / len(weights) + else: + weights = weights / weight_sum + + # Compute local mean + m_star = np.sum(weights[:, None] * neighbor_points, axis=0) + + # Center and weight for SVD + centered = neighbor_points - m_star + weighted_centered = np.sqrt(weights)[:, None] * centered # (k, d) + + # Validate input is finite before SVD + if not np.all(np.isfinite(weighted_centered)): + logger.warning(f"Non-finite values detected in weighted_centered for point {point_idx}, using fallback") + # Fallback: use uniform weights and simple centering + weights_uniform = np.ones(len(neighbor_points)) / len(neighbor_points) + m_star = np.mean(neighbor_points, axis=0) + centered = neighbor_points - m_star + weighted_centered = np.sqrt(weights_uniform)[:, None] * centered + + # Move to GPU for SVD + weighted_centered_torch = torch.from_numpy(weighted_centered).to( + self.device, dtype=torch.float32 + ) + + try: + 
U, S, Vh = torch.linalg.svd(weighted_centered_torch, full_matrices=False) + except RuntimeError as e: + logger.debug(f"GPU SVD failed for point {point_idx}, using CPU: {e}") + try: + U, S, Vh = np.linalg.svd(weighted_centered, full_matrices=False) + U = torch.from_numpy(U).to(self.device) + S = torch.from_numpy(S).to(self.device) + Vh = torch.from_numpy(Vh).to(self.device) + except np.linalg.LinAlgError as e2: + logger.error(f"CPU SVD also failed for point {point_idx}: {e2}, returning zero matrix") + # Return zero eigenvalues/vectors as fallback + return ( + torch.zeros(self.d_cdc, d, dtype=torch.float16), + torch.zeros(self.d_cdc, dtype=torch.float16) + ) + + # Eigenvalues of Γ_b + eigenvalues_full = S ** 2 + + # Keep top d_cdc + if len(eigenvalues_full) >= self.d_cdc: + top_eigenvalues, top_idx = torch.topk(eigenvalues_full, self.d_cdc) + top_eigenvectors = Vh[top_idx, :] # (d_cdc, d) + else: + # Pad if k < d_cdc + top_eigenvalues = eigenvalues_full + top_eigenvectors = Vh + if len(eigenvalues_full) < self.d_cdc: + pad_size = self.d_cdc - len(eigenvalues_full) + top_eigenvalues = torch.cat([ + top_eigenvalues, + torch.zeros(pad_size, device=self.device) + ]) + top_eigenvectors = torch.cat([ + top_eigenvectors, + torch.zeros(pad_size, d, device=self.device) + ]) + + # Eigenvalue Rescaling (per CDC-FM paper Appendix E, Equation 33) + # Paper formula: c_i = (1/λ_1^i) × min(neighbor_distance²/9, c²_max) + # Then apply gamma: γc_i Γ̂(x^(i)) + # + # Our implementation: + # 1. Normalize by max eigenvalue (λ_1^i) - aligns with paper's 1/λ_1^i factor + # 2. Apply gamma hyperparameter - aligns with paper's γ global scaling + # 3. 
Clamp for numerical stability + # + # Raw eigenvalues from SVD can be very large (100-5000 for 65k-dimensional FLUX latents) + # Without normalization, clamping to [1e-3, 1.0] would saturate all values at upper bound + + # Step 1: Normalize by the maximum eigenvalue to get relative scales + # This is the paper's 1/λ_1^i normalization factor + max_eigenval = top_eigenvalues[0].item() if len(top_eigenvalues) > 0 else 1.0 + + if max_eigenval > 1e-10: + # Scale so max eigenvalue = 1.0, preserving relative ratios + top_eigenvalues = top_eigenvalues / max_eigenval + + # Step 2: Apply gamma and clamp to safe range + # Gamma is the paper's tuneable hyperparameter (defaults to 1.0) + # Clamping ensures numerical stability and prevents extreme values + top_eigenvalues = torch.clamp(top_eigenvalues * self.gamma, 1e-3, self.gamma * 1.0) + + # Convert to fp16 for storage - now safe since eigenvalues are ~0.01-1.0 + # fp16 range: 6e-5 to 65,504, our values are well within this + eigenvectors_fp16 = top_eigenvectors.cpu().half() + eigenvalues_fp16 = top_eigenvalues.cpu().half() + + # Cleanup + del weighted_centered_torch, U, S, Vh, top_eigenvectors, top_eigenvalues + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return eigenvectors_fp16, eigenvalues_fp16 + + def compute_for_batch( + self, + latents_np: np.ndarray, + global_indices: List[int] + ) -> Dict[int, Tuple[torch.Tensor, torch.Tensor]]: + """ + Compute Γ_b for all points in a batch of same-size latents + + Args: + latents_np: (N, d) numpy array + global_indices: List of global dataset indices for each latent + + Returns: + Dict mapping global_idx -> (eigenvectors, eigenvalues) + """ + N, d = latents_np.shape + + # Validate inputs + if len(global_indices) != N: + raise ValueError(f"Length mismatch: latents has {N} samples but got {len(global_indices)} indices") + + print(f"Computing CDC for batch: {N} samples, dim={d}") + + # Handle small sample cases - require minimum samples for meaningful k-NN + 
MIN_SAMPLES_FOR_CDC = 5 # Need at least 5 samples for reasonable geometry estimation + + if N < MIN_SAMPLES_FOR_CDC: + print(f" Only {N} samples (< {MIN_SAMPLES_FOR_CDC}) - using identity matrix (no CDC correction)") + results = {} + for local_idx in range(N): + global_idx = global_indices[local_idx] + # Return zero eigenvectors/eigenvalues (will result in identity in compute_sigma_t_x) + eigvecs = np.zeros((self.d_cdc, d), dtype=np.float16) + eigvals = np.zeros(self.d_cdc, dtype=np.float16) + results[global_idx] = (eigvecs, eigvals) + return results + + # Step 1: Build k-NN graph + print(" Building k-NN graph...") + distances, indices = self.compute_knn_graph(latents_np) + + # Step 2: Compute bandwidth + # Use min to handle case where k_bw >= actual neighbors returned + k_bw_actual = min(self.k_bw, distances.shape[1] - 1) + epsilon = distances[:, k_bw_actual] + + # Step 3: Compute Γ_b for each point + results = {} + print(" Computing Γ_b for each point...") + for local_idx in tqdm(range(N), desc=" Processing", leave=False): + global_idx = global_indices[local_idx] + eigvecs, eigvals = self.compute_gamma_b_single( + local_idx, latents_np, distances, indices, epsilon + ) + results[global_idx] = (eigvecs, eigvals) + + return results + + +class LatentBatcher: + """ + Collects variable-size latents and batches them by size + """ + + def __init__(self, size_tolerance: float = 0.0): + """ + Args: + size_tolerance: If > 0, group latents within tolerance % of size + If 0, only exact size matches are batched + """ + self.size_tolerance = size_tolerance + self.samples: List[LatentSample] = [] + + def add_sample(self, sample: LatentSample): + """Add a single latent sample""" + self.samples.append(sample) + + def add_latent( + self, + latent: Union[np.ndarray, torch.Tensor], + global_idx: int, + latents_npz_path: str, + shape: Optional[Tuple[int, ...]] = None, + metadata: Optional[Dict] = None + ): + """ + Add a latent vector with automatic shape tracking + + Args: + latent: 
Latent vector (any shape, will be flattened) + global_idx: Global index in dataset + latents_npz_path: Path to the latent cache file (e.g., "image_0512x0768_flux.npz") + shape: Original shape (if None, uses latent.shape) + metadata: Optional metadata dict + """ + # Convert to numpy and flatten + if isinstance(latent, torch.Tensor): + latent_np = latent.cpu().numpy() + else: + latent_np = latent + + original_shape = shape if shape is not None else latent_np.shape + latent_flat = latent_np.flatten() + + sample = LatentSample( + latent=latent_flat, + global_idx=global_idx, + shape=original_shape, + latents_npz_path=latents_npz_path, + metadata=metadata + ) + + self.add_sample(sample) + + def get_batches(self) -> Dict[Tuple[int, ...], List[LatentSample]]: + """ + Group samples by exact shape to avoid resizing distortion. + + Each bucket contains only samples with identical latent dimensions. + Buckets with fewer than k_neighbors samples will be skipped during CDC + computation and fall back to standard Gaussian noise. + + Returns: + Dict mapping exact_shape -> list of samples with that shape + """ + batches = {} + shapes = set() + + for sample in self.samples: + shape_key = sample.shape + shapes.add(shape_key) + + # Group by exact shape only - no aspect ratio grouping or resizing + if shape_key not in batches: + batches[shape_key] = [] + + batches[shape_key].append(sample) + + # If more than one unique shape, log a warning + if len(shapes) > 1: + logger.warning( + "Dimension mismatch: %d unique shapes detected. " + "Shapes: %s. Using Gaussian fallback for these samples.", + len(shapes), + shapes + ) + + return batches + + def _get_aspect_ratio_key(self, shape: Tuple[int, ...]) -> str: + """ + Get aspect ratio category for grouping. + Groups images by aspect ratio bins to ensure sufficient samples. + + For shape (C, H, W), computes aspect ratio H/W and bins it. 
+ """ + if len(shape) < 3: + return "unknown" + + # Extract spatial dimensions (H, W) + h, w = shape[-2], shape[-1] + aspect_ratio = h / w + + # Define aspect ratio bins (±15% tolerance) + # Common ratios: 1.0 (square), 1.33 (4:3), 0.75 (3:4), 1.78 (16:9), 0.56 (9:16) + bins = [ + (0.5, 0.65, "9:16"), # Portrait tall + (0.65, 0.85, "3:4"), # Portrait + (0.85, 1.15, "1:1"), # Square + (1.15, 1.50, "4:3"), # Landscape + (1.50, 2.0, "16:9"), # Landscape wide + (2.0, 3.0, "21:9"), # Ultra wide + ] + + for min_ratio, max_ratio, label in bins: + if min_ratio <= aspect_ratio < max_ratio: + return label + + # Fallback for extreme ratios + if aspect_ratio < 0.5: + return "ultra_tall" + else: + return "ultra_wide" + + def _shapes_similar(self, shape1: Tuple[int, ...], shape2: Tuple[int, ...]) -> bool: + """Check if two shapes are within tolerance""" + if len(shape1) != len(shape2): + return False + + size1 = np.prod(shape1) + size2 = np.prod(shape2) + + ratio = abs(size1 - size2) / max(size1, size2) + return ratio <= self.size_tolerance + + def __len__(self): + return len(self.samples) + + +class CDCPreprocessor: + """ + High-level CDC preprocessing coordinator + Handles variable-size latents by batching and delegating to CarreDuChampComputer + """ + + def __init__( + self, + k_neighbors: int = 256, + k_bandwidth: int = 8, + d_cdc: int = 8, + gamma: float = 1.0, + device: str = 'cuda', + size_tolerance: float = 0.0, + debug: bool = False, + adaptive_k: bool = False, + min_bucket_size: int = 16, + dataset_dirs: Optional[List[str]] = None + ): + self.computer = CarreDuChampComputer( + k_neighbors=k_neighbors, + k_bandwidth=k_bandwidth, + d_cdc=d_cdc, + gamma=gamma, + device=device + ) + self.batcher = LatentBatcher(size_tolerance=size_tolerance) + self.debug = debug + self.adaptive_k = adaptive_k + self.min_bucket_size = min_bucket_size + self.dataset_dirs = dataset_dirs or [] + self.config_hash = self._compute_config_hash() + + def _compute_config_hash(self) -> str: + """ + 
Compute a short hash of CDC configuration for filename uniqueness. + + Hash includes: + - Sorted dataset/subset directory paths + - CDC parameters (k_neighbors, d_cdc, gamma) + + This ensures CDC files are invalidated when: + - Dataset composition changes (different dirs) + - CDC parameters change + + Returns: + 8-character hex hash + """ + import hashlib + + # Sort dataset dirs for consistent hashing + dirs_str = "|".join(sorted(self.dataset_dirs)) + + # Include CDC parameters + config_str = f"{dirs_str}|k={self.computer.k}|d={self.computer.d_cdc}|gamma={self.computer.gamma}" + + # Create short hash (8 chars is enough for uniqueness in this context) + hash_obj = hashlib.sha256(config_str.encode()) + return hash_obj.hexdigest()[:8] + + def add_latent( + self, + latent: Union[np.ndarray, torch.Tensor], + global_idx: int, + latents_npz_path: str, + shape: Optional[Tuple[int, ...]] = None, + metadata: Optional[Dict] = None + ): + """ + Add a single latent to the preprocessing queue + + Args: + latent: Latent vector (will be flattened) + global_idx: Global dataset index + latents_npz_path: Path to the latent cache file + shape: Original shape (C, H, W) + metadata: Optional metadata + """ + self.batcher.add_latent(latent, global_idx, latents_npz_path, shape, metadata) + + @staticmethod + def get_cdc_npz_path( + latents_npz_path: str, + config_hash: Optional[str] = None, + latent_shape: Optional[Tuple[int, ...]] = None + ) -> str: + """ + Get CDC cache path from latents cache path + + Includes optional config_hash to ensure CDC files are unique to dataset/subset + configuration and CDC parameters. This prevents using stale CDC files when + the dataset composition or CDC settings change. + + IMPORTANT: When using multi-resolution training, you MUST pass latent_shape to ensure + CDC files are unique per resolution. Without it, different resolutions will overwrite + each other's CDC caches, causing dimension mismatch errors. 
+ + Args: + latents_npz_path: Path to latent cache (e.g., "image_0512x0768_flux.npz") + config_hash: Optional 8-char hash of (dataset_dirs + CDC params) + If None, returns path without hash (for backward compatibility) + latent_shape: Optional latent shape tuple (C, H, W) to make CDC resolution-specific + For multi-resolution training, this MUST be provided + + Returns: + CDC cache path examples: + - With shape + hash: "image_0512x0768_flux_cdc_104x80_a1b2c3d4.npz" + - With hash only: "image_0512x0768_flux_cdc_a1b2c3d4.npz" + - Without hash: "image_0512x0768_flux_cdc.npz" + + Example multi-resolution scenario: + resolution=512 → latent_shape=(16,64,48) → "image_flux_cdc_64x48_hash.npz" + resolution=768 → latent_shape=(16,104,80) → "image_flux_cdc_104x80_hash.npz" + """ + path = Path(latents_npz_path) + + # Build filename components + components = [path.stem, "cdc"] + + # Add latent resolution if provided (for multi-resolution training) + if latent_shape is not None: + if len(latent_shape) >= 3: + # Format: HxW (e.g., "104x80" from shape (16, 104, 80)) + h, w = latent_shape[-2], latent_shape[-1] + components.append(f"{h}x{w}") + else: + raise ValueError(f"latent_shape must have at least 3 dimensions (C, H, W), got {latent_shape}") + + # Add config hash if provided + if config_hash: + components.append(config_hash) + + # Build final filename + new_stem = "_".join(components) + return str(path.with_stem(new_stem)) + + def compute_all(self) -> int: + """ + Compute Γ_b for all added latents and save individual CDC files next to each latent cache + + Returns: + Number of CDC files saved + """ + + # Get batches by exact size (no resizing) + batches = self.batcher.get_batches() + + # Count samples that will get CDC vs fallback + k_neighbors = self.computer.k + min_threshold = self.min_bucket_size if self.adaptive_k else k_neighbors + + if self.adaptive_k: + samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= min_threshold) + else: + 
samples_with_cdc = sum(len(samples) for samples in batches.values() if len(samples) >= k_neighbors) + samples_fallback = len(self.batcher) - samples_with_cdc + + if self.debug: + print(f"\nProcessing {len(self.batcher)} samples in {len(batches)} exact size buckets") + if self.adaptive_k: + print(f" Adaptive k enabled: k_max={k_neighbors}, min_bucket_size={min_threshold}") + print(f" Samples with CDC (≥{min_threshold} per bucket): {samples_with_cdc}/{len(self.batcher)} ({samples_with_cdc/len(self.batcher)*100:.1f}%)") + print(f" Samples using Gaussian fallback: {samples_fallback}/{len(self.batcher)} ({samples_fallback/len(self.batcher)*100:.1f}%)") + else: + mode = "adaptive" if self.adaptive_k else "fixed" + logger.info(f"Processing {len(self.batcher)} samples in {len(batches)} buckets ({mode} k): {samples_with_cdc} with CDC, {samples_fallback} fallback") + + # Storage for results + all_results = {} + + # Process each bucket with progress bar + bucket_iter = tqdm(batches.items(), desc="Computing CDC", unit="bucket", disable=self.debug) if not self.debug else batches.items() + + for shape, samples in bucket_iter: + num_samples = len(samples) + + if self.debug: + print(f"\n{'='*60}") + print(f"Bucket: {shape} ({num_samples} samples)") + print(f"{'='*60}") + + # Determine effective k for this bucket + if self.adaptive_k: + # Adaptive mode: skip if below minimum, otherwise use best available k + if num_samples < min_threshold: + if self.debug: + print(f" ⚠️ Skipping CDC: {num_samples} samples < min_bucket_size={min_threshold}") + print(" → These samples will use standard Gaussian noise (no CDC)") + + # Store zero eigenvectors/eigenvalues (Gaussian fallback) + C, H, W = shape + d = C * H * W + + for sample in samples: + eigvecs = np.zeros((self.computer.d_cdc, d), dtype=np.float16) + eigvals = np.zeros(self.computer.d_cdc, dtype=np.float16) + all_results[sample.global_idx] = (eigvecs, eigvals) + + continue + + # Use adaptive k for this bucket + k_effective = 
min(k_neighbors, num_samples - 1) + else: + # Fixed mode: skip if below k_neighbors + if num_samples < k_neighbors: + if self.debug: + print(f" ⚠️ Skipping CDC: {num_samples} samples < k={k_neighbors}") + print(" → These samples will use standard Gaussian noise (no CDC)") + + # Store zero eigenvectors/eigenvalues (Gaussian fallback) + C, H, W = shape + d = C * H * W + + for sample in samples: + eigvecs = np.zeros((self.computer.d_cdc, d), dtype=np.float16) + eigvals = np.zeros(self.computer.d_cdc, dtype=np.float16) + all_results[sample.global_idx] = (eigvecs, eigvals) + + continue + + k_effective = k_neighbors + + # Collect latents (no resizing needed - all same shape) + latents_list = [] + global_indices = [] + + for sample in samples: + global_indices.append(sample.global_idx) + latents_list.append(sample.latent) # Already flattened + + latents_np = np.stack(latents_list, axis=0) # (N, C*H*W) + + # Compute CDC for this batch with effective k + if self.debug: + if self.adaptive_k and k_effective < k_neighbors: + print(f" Computing CDC with adaptive k={k_effective} (max_k={k_neighbors}), d_cdc={self.computer.d_cdc}") + else: + print(f" Computing CDC with k={k_effective} neighbors, d_cdc={self.computer.d_cdc}") + + # Temporarily override k for this bucket + original_k = self.computer.k + self.computer.k = k_effective + batch_results = self.computer.compute_for_batch(latents_np, global_indices) + self.computer.k = original_k + + # No resizing needed - eigenvectors are already correct size + if self.debug: + print(f" ✓ CDC computed for {len(batch_results)} samples (no resizing)") + + # Merge into overall results + all_results.update(batch_results) + + # Save individual CDC files next to each latent cache + if self.debug: + print(f"\n{'='*60}") + print("Saving individual CDC files...") + print(f"{'='*60}") + + files_saved = 0 + total_size = 0 + + save_iter = tqdm(self.batcher.samples, desc="Saving CDC files", disable=self.debug) if not self.debug else 
class GammaBDataset:
    """
    Efficient loader for Γ_b matrices during training.

    Loads per-sample eigendecompositions from individual CDC cache files
    stored next to the latent cache files, and applies the low-rank Σ_t
    transform to batches of (flattened) latents/noise.
    """

    def __init__(self, device: str = 'cuda', config_hash: Optional[str] = None):
        """
        Initialize CDC dataset loader

        Args:
            device: Device for loading tensors (falls back to CPU when CUDA
                is unavailable)
            config_hash: Optional config hash to use for CDC file lookup.
                If None, uses default naming without hash.
        """
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.config_hash = config_hash
        if config_hash:
            logger.info(f"CDC loader initialized (hash: {config_hash})")
        else:
            logger.info("CDC loader initialized (no hash, backward compatibility mode)")

    @torch.no_grad()
    def get_gamma_b_sqrt(
        self,
        latents_npz_paths: List[str],
        device: Optional[str] = None,
        latent_shape: Optional[Tuple[int, ...]] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get Γ_b^(1/2) components for a batch of latents

        Args:
            latents_npz_paths: List of latent cache paths (e.g., ["image_0512x0768_flux.npz", ...])
            device: Device to load to (defaults to self.device)
            latent_shape: Latent shape (C, H, W) identifying which CDC file to load.
                Required for multi-resolution training; without it the wrong CDC
                file may be loaded, causing a dimension mismatch.

        Returns:
            eigenvectors: (B, d_cdc, d) - NOTE: d may vary per sample!
            eigenvalues: (B, d_cdc)

        Raises:
            FileNotFoundError: If a CDC cache file is missing.
            RuntimeError: If the batch mixes latents of different flattened dimensions.
        """
        if device is None:
            device = self.device

        eigenvectors_list = []
        eigenvalues_list = []

        for latents_npz_path in latents_npz_paths:
            # CDC cache path includes the config hash and (for multi-resolution
            # support) the latent shape
            cdc_path = CDCPreprocessor.get_cdc_npz_path(latents_npz_path, self.config_hash, latent_shape)

            if not Path(cdc_path).exists():
                raise FileNotFoundError(
                    f"CDC cache file not found: {cdc_path}. "
                    f"Make sure to run CDC preprocessing before training."
                )

            # FIX: use a context manager — np.load on an .npz archive keeps the
            # file handle open lazily; without `with`, handles accumulate over
            # many training batches. Also matches the `with np.load(...)` usage
            # elsewhere in the project.
            with np.load(cdc_path) as data:
                eigvecs = torch.from_numpy(data['eigenvectors']).to(device).float()
                eigvals = torch.from_numpy(data['eigenvalues']).to(device).float()

            eigenvectors_list.append(eigvecs)
            eigenvalues_list.append(eigvals)

        # All samples in a batch must share the same d_cdc and d (enforced by
        # bucketing); verify before stacking so the error is actionable.
        dims = [ev.shape[1] for ev in eigenvectors_list]
        if len(set(dims)) > 1:
            # Dimension mismatch — shouldn't happen with proper bucketing,
            # but can occur if the batch contains mixed sizes.
            raise RuntimeError(
                f"CDC eigenvector dimension mismatch in batch: {set(dims)}. "
                f"Latent paths: {latents_npz_paths}. "
                f"This means the training batch contains images of different sizes, "
                f"which violates CDC's requirement for uniform latent dimensions per batch. "
                f"Check that your dataloader buckets are configured correctly."
            )

        eigenvectors = torch.stack(eigenvectors_list, dim=0)
        eigenvalues = torch.stack(eigenvalues_list, dim=0)

        return eigenvectors, eigenvalues

    def compute_sigma_t_x(
        self,
        eigenvectors: torch.Tensor,
        eigenvalues: torch.Tensor,
        x: torch.Tensor,
        t: Union[float, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute Σ_t @ x where Σ_t ≈ (1-t) I + t Γ_b^(1/2)

        Args:
            eigenvectors: (B, d_cdc, d)
            eigenvalues: (B, d_cdc)
            x: (B, d) or (B, C, H, W) - will be flattened if needed
            t: (B,) or scalar time

        Returns:
            result: Same shape as input x

        Note:
            Gradients flow through this function for backprop during training.
        """
        # Remember the original shape so 4D inputs are restored on return
        orig_shape = x.shape

        if x.dim() == 4:
            B, C, H, W = x.shape
            x = x.reshape(B, -1)  # (B, C*H*W)

        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t, device=x.device, dtype=x.dtype)

        if t.dim() == 0:
            t = t.expand(x.shape[0])

        t = t.view(-1, 1)

        # Early return for t=0 to avoid numerical errors; skipped when t is
        # differentiable so gradient flow through t is preserved
        if not t.requires_grad and torch.allclose(t, torch.zeros_like(t), atol=1e-8):
            return x.reshape(orig_shape)

        # All-zero eigenvalues mean CDC is disabled for this bucket
        # (happens for buckets with < k_neighbors samples): fall back to
        # standard Gaussian noise (no CDC correction)
        if torch.allclose(eigenvalues, torch.zeros_like(eigenvalues), atol=1e-8):
            return x.reshape(orig_shape)

        # Γ_b^(1/2) @ x via the low-rank representation V diag(sqrt(λ)) V^T x
        Vt_x = torch.einsum('bkd,bd->bk', eigenvectors, x)
        sqrt_eigenvalues = torch.sqrt(eigenvalues.clamp(min=1e-10))
        sqrt_lambda_Vt_x = sqrt_eigenvalues * Vt_x
        gamma_sqrt_x = torch.einsum('bkd,bk->bd', eigenvectors, sqrt_lambda_Vt_x)

        # Σ_t @ x
        result = (1 - t) * x + t * gamma_sqrt_x

        return result.reshape(orig_shape)
# Module-level registry of samples already warned about shape mismatches.
# Warning once per sample is sufficient; this prevents log spam during training.
_cdc_warned_samples = set()


def apply_cdc_noise_transformation(
    noise: torch.Tensor,
    timesteps: torch.Tensor,
    num_timesteps: int,
    gamma_b_dataset,
    latents_npz_paths,
    device
) -> torch.Tensor:
    """
    Apply the CDC-FM geometry-aware transformation to a batch of noise.

    Args:
        noise: (B, C, H, W) standard Gaussian noise
        timesteps: (B,) timesteps for this batch
        num_timesteps: Total number of timesteps in the scheduler
        gamma_b_dataset: GammaBDataset with cached CDC matrices
        latents_npz_paths: Latent cache paths for this batch
        device: Device to load CDC matrices to

    Returns:
        Noise tensor of the same shape, with geometry-aware covariance.
    """
    # Normalize the requested device first: "cuda" and "cuda:0" name the same
    # physical device and must not trigger a transfer or a warning.
    target_device = device if isinstance(device, torch.device) else torch.device(device)
    src = noise.device

    same_device = src == target_device
    both_cuda = src.type == "cuda" and target_device.type == "cuda"
    both_cpu = src.type == "cpu" and target_device.type == "cpu"

    if not (same_device or both_cuda or both_cpu):
        logger.warning(
            f"CDC device mismatch: noise on {src} but CDC loading to {target_device}. "
            f"Transferring noise to {target_device} to avoid errors."
        )
        noise = noise.to(target_device)
        device = target_device

    # CDC-FM works with time normalized to [0, 1]
    t_normalized = timesteps.to(device) / num_timesteps

    B, C, H, W = noise.shape

    # Fetch the eigendecompositions for the whole batch in one call; the
    # latent shape disambiguates CDC files for multi-resolution training.
    eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(
        latents_npz_paths, device=device, latent_shape=(C, H, W)
    )
    flat_noise = noise.reshape(B, -1)
    transformed = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, flat_noise, t_normalized)
    return transformed.reshape(B, C, H, W)
+ + Args: + args: Configuration with sampling parameters + noise_scheduler: Scheduler for noise/timestep management + latents: Clean latent representations + noise: Random noise tensor + device: Target device + dtype: Target dtype + gamma_b_dataset: Optional CDC-FM gamma_b dataset for geometry-aware noise + latents_npz_paths: Optional list of latent cache file paths for CDC-FM (required if gamma_b_dataset provided) + """ bsz, _, h, w = latents.shape assert bsz > 0, "Batch size not large enough" num_timesteps = noise_scheduler.config.num_train_timesteps @@ -514,10 +594,20 @@ def get_noisy_model_input_and_timesteps( # Broadcast sigmas to latent shape sigmas = sigmas.view(-1, 1, 1, 1) - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) + # Apply CDC-FM geometry-aware noise transformation if enabled + if gamma_b_dataset is not None and latents_npz_paths is not None: + noise = apply_cdc_noise_transformation( + noise=noise, + timesteps=timesteps, + num_timesteps=num_timesteps, + gamma_b_dataset=gamma_b_dataset, + latents_npz_paths=latents_npz_paths, + device=device + ) + if args.ip_noise_gamma: xi = torch.randn_like(latents, device=latents.device, dtype=dtype) + if args.ip_noise_gamma_random_strength: ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma else: diff --git a/library/train_util.py b/library/train_util.py index 756d88b1c..36ded89d4 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -40,6 +40,8 @@ from torchvision import transforms from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection import transformers + +from library.cdc_fm import CDCPreprocessor from diffusers.optimization import ( SchedulerType as DiffusersSchedulerType, TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION, @@ -1569,11 +1571,17 @@ def __getitem__(self, index): flippeds = [] # 変数名が微妙 text_encoder_outputs_list = [] 
custom_attributes = [] + image_keys = [] # CDC-FM: track image keys for CDC lookup + latents_npz_paths = [] # CDC-FM: track latents_npz paths for CDC lookup for image_key in bucket[image_index : image_index + bucket_batch_size]: image_info = self.image_data[image_key] subset = self.image_to_subset[image_key] + # CDC-FM: Store image_key and latents_npz path for CDC lookup + image_keys.append(image_key) + latents_npz_paths.append(image_info.latents_npz) + custom_attributes.append(subset.custom_attributes) # in case of fine tuning, is_reg is always False @@ -1819,6 +1827,9 @@ def none_or_stack_elements(tensors_list, converter): example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions)) + # CDC-FM: Add latents_npz paths to batch for CDC lookup + example["latents_npz"] = latents_npz_paths + if self.debug_dataset: example["image_keys"] = bucket[image_index : image_index + self.batch_size] return example @@ -2690,6 +2701,220 @@ def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Acceler dataset.new_cache_text_encoder_outputs(models, accelerator) accelerator.wait_for_everyone() + def cache_cdc_gamma_b( + self, + k_neighbors: int = 256, + k_bandwidth: int = 8, + d_cdc: int = 8, + gamma: float = 1.0, + force_recache: bool = False, + accelerator: Optional["Accelerator"] = None, + debug: bool = False, + adaptive_k: bool = False, + min_bucket_size: int = 16, + ) -> Optional[str]: + """ + Cache CDC Γ_b matrices for all latents in the dataset + + CDC files are saved as individual .npz files next to each latent cache file. + For example: image_0512x0768_flux.npz → image_0512x0768_flux_cdc_a1b2c3d4.npz + where 'a1b2c3d4' is the config hash (dataset dirs + CDC params). 
+ + Args: + k_neighbors: k-NN neighbors + k_bandwidth: Bandwidth estimation neighbors + d_cdc: CDC subspace dimension + gamma: CDC strength + force_recache: Force recompute even if cache exists + accelerator: For multi-GPU support + debug: Enable debug logging + adaptive_k: Enable adaptive k selection for small buckets + min_bucket_size: Minimum bucket size for CDC computation + + Returns: + Config hash string for this CDC configuration, or None on error + """ + from pathlib import Path + + # Validate that latent caching is enabled + # CDC requires latents to be cached (either to disk or in memory) because: + # 1. CDC files are named based on latent cache filenames + # 2. CDC files are saved next to latent cache files + # 3. Training needs latent paths to load corresponding CDC files + has_cached_latents = False + for dataset in self.datasets: + for info in dataset.image_data.values(): + if info.latents is not None or info.latents_npz is not None: + has_cached_latents = True + break + if has_cached_latents: + break + + if not has_cached_latents: + raise ValueError( + "CDC-FM requires latent caching to be enabled. " + "Please enable latent caching by setting one of:\n" + " - cache_latents = true (cache in memory)\n" + " - cache_latents_to_disk = true (cache to disk)\n" + "in your training config or command line arguments." 
+ ) + + # Collect dataset/subset directories for config hash + dataset_dirs = [] + for dataset in self.datasets: + # Get the directory containing the images + if hasattr(dataset, 'image_dir'): + dataset_dirs.append(str(dataset.image_dir)) + # Fallback: use first image's parent directory + elif dataset.image_data: + first_image = next(iter(dataset.image_data.values())) + dataset_dirs.append(str(Path(first_image.absolute_path).parent)) + + # Create preprocessor to get config hash + preprocessor = CDCPreprocessor( + k_neighbors=k_neighbors, + k_bandwidth=k_bandwidth, + d_cdc=d_cdc, + gamma=gamma, + device="cuda" if torch.cuda.is_available() else "cpu", + debug=debug, + adaptive_k=adaptive_k, + min_bucket_size=min_bucket_size, + dataset_dirs=dataset_dirs + ) + + logger.info(f"CDC config hash: {preprocessor.config_hash}") + + # Check if CDC caches already exist (unless force_recache) + if not force_recache: + all_cached = self._check_cdc_caches_exist(preprocessor.config_hash) + if all_cached: + logger.info("All CDC cache files found, skipping preprocessing") + return preprocessor.config_hash + else: + logger.info("Some CDC cache files missing, will compute") + + # Only main process computes CDC + is_main = accelerator is None or accelerator.is_main_process + if not is_main: + if accelerator is not None: + accelerator.wait_for_everyone() + return preprocessor.config_hash + + logger.info("Starting CDC-FM preprocessing") + logger.info(f"Parameters: k={k_neighbors}, k_bw={k_bandwidth}, d_cdc={d_cdc}, gamma={gamma}") + + # Get caching strategy for loading latents + from library.strategy_base import LatentsCachingStrategy + + caching_strategy = LatentsCachingStrategy.get_strategy() + + # Collect all latents from all datasets + for dataset_idx, dataset in enumerate(self.datasets): + logger.info(f"Loading latents from dataset {dataset_idx}...") + image_infos = list(dataset.image_data.values()) + + for local_idx, info in enumerate(tqdm(image_infos, desc=f"Dataset 
{dataset_idx}")): + # Load latent from disk or memory + if info.latents is not None: + latent = info.latents + elif info.latents_npz is not None: + # Load from disk + latent, _, _, _, _ = caching_strategy.load_latents_from_disk(info.latents_npz, info.bucket_reso) + if latent is None: + logger.warning(f"Failed to load latent from {info.latents_npz}, skipping") + continue + else: + logger.warning(f"No latent found for {info.absolute_path}, skipping") + continue + + # Add to preprocessor (with unique global index across all datasets) + actual_global_idx = sum(len(d.image_data) for d in self.datasets[:dataset_idx]) + local_idx + + # Get latents_npz_path - will be set whether caching to disk or memory + if info.latents_npz is None: + # If not set, generate the path from the caching strategy + info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.bucket_reso) + + preprocessor.add_latent( + latent=latent, + global_idx=actual_global_idx, + latents_npz_path=info.latents_npz, + shape=latent.shape, + metadata={"image_key": info.image_key} + ) + + # Compute and save individual CDC files + logger.info(f"\nComputing CDC Γ_b matrices for {len(preprocessor.batcher)} samples...") + files_saved = preprocessor.compute_all() + logger.info(f"Saved {files_saved} CDC cache files") + + if accelerator is not None: + accelerator.wait_for_everyone() + + # Return config hash so training can initialize GammaBDataset with it + return preprocessor.config_hash + + def _check_cdc_caches_exist(self, config_hash: str) -> bool: + """ + Check if CDC cache files exist for all latents in the dataset + + Args: + config_hash: The config hash to use for CDC filename lookup + """ + from pathlib import Path + + missing_count = 0 + total_count = 0 + + for dataset in self.datasets: + for info in dataset.image_data.values(): + total_count += 1 + if info.latents_npz is None: + # If latents_npz not set, we can't check for CDC cache + continue + + # Compute expected latent shape from 
bucket_reso + # For multi-resolution CDC, we need to pass latent_shape to get the correct filename + latent_shape = None + if info.bucket_reso is not None: + # Get latent shape efficiently without loading full data + # First check if latent is already in memory + if info.latents is not None: + latent_shape = info.latents.shape + else: + # Load latent shape from npz file metadata + # This is faster than loading the full latent data + try: + import numpy as np + with np.load(info.latents_npz) as data: + # Find the key for this bucket resolution + # Multi-resolution format uses keys like "latents_104x80" + h, w = info.bucket_reso[1] // 8, info.bucket_reso[0] // 8 + key = f"latents_{h}x{w}" + if key in data: + latent_shape = data[key].shape + elif 'latents' in data: + # Fallback for single-resolution cache + latent_shape = data['latents'].shape + except Exception as e: + logger.debug(f"Failed to read latent shape from {info.latents_npz}: {e}") + # Fall back to checking without shape (backward compatibility) + latent_shape = None + + cdc_path = CDCPreprocessor.get_cdc_npz_path(info.latents_npz, config_hash, latent_shape) + if not Path(cdc_path).exists(): + missing_count += 1 + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"Missing CDC cache: {cdc_path}") + + if missing_count > 0: + logger.info(f"Found {missing_count}/{total_count} missing CDC cache files") + return False + + logger.debug(f"All {total_count} CDC cache files exist") + return True + def set_caching_mode(self, caching_mode): for dataset in self.datasets: dataset.set_caching_mode(caching_mode) @@ -6107,8 +6332,19 @@ def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: tor def get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents: torch.FloatTensor + args, noise_scheduler, latents: torch.FloatTensor, ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]: + """ + Sample noise and create noisy latents. 
import torch
from typing import Union


class MockGammaBDataset:
    """
    Mock implementation of GammaBDataset for testing gradient flow.

    Re-implements compute_sigma_t_x without any file loading so the math and
    autograd behavior can be exercised in isolation.
    """
    def __init__(self, *args, **kwargs):
        # No file loading: only a device is needed
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def compute_sigma_t_x(
        self,
        eigenvectors: torch.Tensor,
        eigenvalues: torch.Tensor,
        x: torch.Tensor,
        t: Union[float, torch.Tensor]
    ) -> torch.Tensor:
        """
        Simplified implementation of compute_sigma_t_x for testing.

        Computes Σ_t @ x = (1 - t) x + t Γ_b^(1/2) x using the low-rank
        eigendecomposition Γ_b^(1/2) = V diag(sqrt(λ)) V^T.
        """
        # Store original shape to restore later
        orig_shape = x.shape

        # Flatten x if it's 4D
        if x.dim() == 4:
            B, C, H, W = x.shape
            x = x.reshape(B, -1)  # (B, C*H*W)

        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t, device=x.device, dtype=x.dtype)

        # Validate dimensions
        assert eigenvectors.shape[0] == x.shape[0], "Batch size mismatch"
        assert eigenvectors.shape[2] == x.shape[1], "Dimension mismatch"

        # Early return for t=0 with gradient preservation
        if torch.allclose(t, torch.zeros_like(t), atol=1e-8) and not t.requires_grad:
            return x.reshape(orig_shape)

        # V^T x
        Vt_x = torch.einsum('bkd,bd->bk', eigenvectors, x)

        # sqrt(λ) * V^T x
        sqrt_eigenvalues = torch.sqrt(eigenvalues.clamp(min=1e-10))
        sqrt_lambda_Vt_x = sqrt_eigenvalues * Vt_x

        # V @ (sqrt(λ) * V^T x)
        gamma_sqrt_x = torch.einsum('bkd,bk->bd', eigenvectors, sqrt_lambda_Vt_x)

        # Σ_t @ x: blend of the identity and Γ_b^(1/2) contributions.
        # (FIX: the previous comment called this an interpolation between the
        # original and noisy latent — x here is typically noise, not a latent.)
        result = (1 - t) * x + t * gamma_sqrt_x

        # Restore original shape
        result = result.reshape(orig_shape)

        return result


class TestCDCAdvanced:
    def setup_method(self):
        """Prepare consistent test environment"""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def test_gradient_flow_preservation(self):
        """
        Verify that gradient flow is preserved even for near-zero time steps
        with learnable time embeddings
        """
        # Set random seed for reproducibility
        torch.manual_seed(42)

        # Create a learnable time embedding with small initial value
        t = torch.tensor(0.001, requires_grad=True, device=self.device, dtype=torch.float32)

        # Generate mock latent and CDC components
        batch_size, latent_dim = 4, 64
        latent = torch.randn(batch_size, latent_dim, device=self.device, requires_grad=True)

        # Create mock eigenvectors and eigenvalues
        eigenvectors = torch.randn(batch_size, 8, latent_dim, device=self.device)
        eigenvalues = torch.rand(batch_size, 8, device=self.device)

        # Normalize directions and keep eigenvalues strictly positive
        eigenvectors /= torch.norm(eigenvectors, dim=-1, keepdim=True)
        eigenvalues = torch.clamp(eigenvalues, min=1e-4, max=1.0)

        mock_dataset = MockGammaBDataset()

        # Compute noisy latent with gradient tracking
        noisy_latent = mock_dataset.compute_sigma_t_x(eigenvectors, eigenvalues, latent, t)

        # Dummy loss to exercise backprop
        loss = noisy_latent.sum()
        loss.backward()

        assert t.grad is not None, "Time embedding gradient should be computed"
        assert latent.grad is not None, "Input latent gradient should be computed"

        # Check gradient magnitudes are non-zero
        t_grad_magnitude = torch.abs(t.grad).sum()
        latent_grad_magnitude = torch.abs(latent.grad).sum()

        assert t_grad_magnitude > 0, f"Time embedding gradient is zero: {t_grad_magnitude}"
        assert latent_grad_magnitude > 0, f"Input latent gradient is zero: {latent_grad_magnitude}"

        # Optional: Print gradient details for debugging
        print(f"Time embedding gradient magnitude: {t_grad_magnitude}")
        print(f"Latent gradient magnitude: {latent_grad_magnitude}")

    def test_gradient_flow_with_different_time_steps(self):
        """
        Verify gradient flow across different time step values
        """
        # FIX: seed this test too; it previously used unseeded randn/rand and
        # could in principle flake on a pathological draw
        torch.manual_seed(42)

        # Test time steps
        time_steps = [0.0001, 0.001, 0.01, 0.1, 0.5, 1.0]

        for time_val in time_steps:
            # Create a learnable time embedding
            t = torch.tensor(time_val, requires_grad=True, device=self.device, dtype=torch.float32)

            # Generate mock latent and CDC components
            batch_size, latent_dim = 4, 64
            latent = torch.randn(batch_size, latent_dim, device=self.device, requires_grad=True)

            eigenvectors = torch.randn(batch_size, 8, latent_dim, device=self.device)
            eigenvalues = torch.rand(batch_size, 8, device=self.device)

            eigenvectors /= torch.norm(eigenvectors, dim=-1, keepdim=True)
            eigenvalues = torch.clamp(eigenvalues, min=1e-4, max=1.0)

            mock_dataset = MockGammaBDataset()
            noisy_latent = mock_dataset.compute_sigma_t_x(eigenvectors, eigenvalues, latent, t)

            loss = noisy_latent.sum()
            loss.backward()

            t_grad_magnitude = torch.abs(t.grad).sum()
            latent_grad_magnitude = torch.abs(latent.grad).sum()

            assert t_grad_magnitude > 0, f"Time embedding gradient is zero for t={time_val}"
            assert latent_grad_magnitude > 0, f"Input latent gradient is zero for t={time_val}"

            # Reset gradients for next iteration
            if t.grad is not None:
                t.grad.zero_()
            if latent.grad is not None:
                latent.grad.zero_()


# NOTE(review): pytest only collects pytest_configure hooks from conftest.py
# and plugins — consider moving this to tests/conftest.py so the marker is
# actually registered.
def pytest_configure(config):
    """
    Add custom markers for CDC-FM tests
    """
    config.addinivalue_line(
        "markers",
        "gradient_flow: mark test to verify gradient preservation in CDC Flow Matching"
    )
+ + This reproduces the bug where: + - CDC files are created with resolution: image_flux_cdc_104x80_hash.npz + - But check looked for: image_flux_cdc_hash.npz + - Result: Files not detected, unnecessary regeneration + """ + + with tempfile.TemporaryDirectory() as tmpdir: + # Setup: Create a mock latent cache file and corresponding CDC cache + config_hash = "test1234" + + # Create latent cache file with multi-resolution format + latent_path = Path(tmpdir) / "image_0832x0640_flux.npz" + latent_shape = (16, 104, 80) # C, H, W for resolution 832x640 (832/8=104, 640/8=80) + + # Save a mock latent file + np.savez( + latent_path, + **{f"latents_{latent_shape[1]}x{latent_shape[2]}": np.random.randn(*latent_shape).astype(np.float32)} + ) + + # Create the CDC cache file with resolution in filename (as it's actually created) + cdc_path = CDCPreprocessor.get_cdc_npz_path( + str(latent_path), + config_hash, + latent_shape + ) + + # Verify the CDC path includes resolution + assert "104x80" in cdc_path, f"CDC path should include resolution: {cdc_path}" + + # Create a mock CDC file + np.savez( + cdc_path, + eigenvectors=np.random.randn(8, 16*104*80).astype(np.float16), + eigenvalues=np.random.randn(8).astype(np.float16), + shape=np.array(latent_shape), + k_neighbors=256, + d_cdc=8, + gamma=1.0 + ) + + # Setup mock dataset + image_info = ImageInfo( + image_key="test_image", + num_repeats=1, + caption="test", + is_reg=False, + absolute_path=str(Path(tmpdir) / "image.png") + ) + image_info.latents_npz = str(latent_path) + image_info.bucket_reso = (640, 832) # W, H (note: reversed from latent shape H,W) + image_info.latents = None # Not in memory + + mock_dataset = MockDataset({"test_image": image_info}) + dataset_group = DatasetGroup([mock_dataset]) + + # Test: Check if CDC cache is detected + result = dataset_group._check_cdc_caches_exist(config_hash) + + # Verify: Should return True since the CDC file exists + assert result is True, "CDC cache file should be detected when it exists 
def test_cdc_cache_detection_missing_file():
    """
    Test that missing CDC cache files are correctly identified as missing.
    """

    with tempfile.TemporaryDirectory() as tmpdir:
        config_hash = "test5678"

        # Create latent cache file but NO CDC cache
        latent_path = Path(tmpdir) / "image_0768x0512_flux.npz"
        latent_shape = (16, 96, 64)  # C, H, W

        np.savez(
            latent_path,
            **{f"latents_{latent_shape[1]}x{latent_shape[2]}": np.random.randn(*latent_shape).astype(np.float32)}
        )

        # Setup mock dataset (CDC file does NOT exist)
        image_info = ImageInfo(
            image_key="test_image",
            num_repeats=1,
            caption="test",
            is_reg=False,
            absolute_path=str(Path(tmpdir) / "image.png")
        )
        image_info.latents_npz = str(latent_path)
        image_info.bucket_reso = (512, 768)  # W, H
        image_info.latents = None

        mock_dataset = MockDataset({"test_image": image_info})
        dataset_group = DatasetGroup([mock_dataset])

        # Test: Check if CDC cache is detected
        result = dataset_group._check_cdc_caches_exist(config_hash)

        # Verify: Should return False since CDC file doesn't exist
        assert result is False, "Should detect that CDC cache file is missing"


def test_cdc_cache_detection_with_in_memory_latent():
    """
    Test CDC cache detection when latent is already in memory (faster path).
    """

    with tempfile.TemporaryDirectory() as tmpdir:
        config_hash = "test_mem1"

        # Create latent cache file path (file may or may not exist)
        latent_path = Path(tmpdir) / "image_1024x1024_flux.npz"
        latent_shape = (16, 128, 128)  # C, H, W

        # Create the CDC cache file
        cdc_path = CDCPreprocessor.get_cdc_npz_path(
            str(latent_path),
            config_hash,
            latent_shape
        )

        # FIX: eigenvalues of a PSD Γ_b are non-negative — use rand, not randn,
        # so the fixture is physically plausible
        np.savez(
            cdc_path,
            eigenvectors=np.random.randn(8, 16*128*128).astype(np.float16),
            eigenvalues=np.random.rand(8).astype(np.float16),
            shape=np.array(latent_shape),
            k_neighbors=256,
            d_cdc=8,
            gamma=1.0
        )

        # Setup mock dataset with latent in memory
        import torch
        image_info = ImageInfo(
            image_key="test_image",
            num_repeats=1,
            caption="test",
            is_reg=False,
            absolute_path=str(Path(tmpdir) / "image.png")
        )
        image_info.latents_npz = str(latent_path)
        image_info.bucket_reso = (1024, 1024)  # W, H
        image_info.latents = torch.randn(latent_shape)  # In memory!

        mock_dataset = MockDataset({"test_image": image_info})
        dataset_group = DatasetGroup([mock_dataset])

        # Test: Check if CDC cache is detected (should use faster in-memory path)
        result = dataset_group._check_cdc_caches_exist(config_hash)

        # Verify: Should return True
        assert result is True, "CDC cache should be detected using in-memory latent shape"


def test_cdc_cache_detection_partial_cache():
    """
    Test that partial cache (some files exist, some don't) is correctly identified.
    """

    with tempfile.TemporaryDirectory() as tmpdir:
        config_hash = "testpart"

        # Create two latent files
        latent_path1 = Path(tmpdir) / "image1_0640x0512_flux.npz"
        latent_path2 = Path(tmpdir) / "image2_0640x0512_flux.npz"
        latent_shape = (16, 80, 64)

        for latent_path in [latent_path1, latent_path2]:
            np.savez(
                latent_path,
                **{f"latents_{latent_shape[1]}x{latent_shape[2]}": np.random.randn(*latent_shape).astype(np.float32)}
            )

        # Create CDC cache for ONLY the first image
        cdc_path1 = CDCPreprocessor.get_cdc_npz_path(str(latent_path1), config_hash, latent_shape)
        np.savez(
            cdc_path1,
            eigenvectors=np.random.randn(8, 16*80*64).astype(np.float16),
            eigenvalues=np.random.rand(8).astype(np.float16),  # non-negative eigenvalues
            shape=np.array(latent_shape),
            k_neighbors=256,
            d_cdc=8,
            gamma=1.0
        )

        # CDC cache for second image does NOT exist

        # Setup mock dataset with both images
        info1 = ImageInfo("img1", 1, "test", False, str(Path(tmpdir) / "img1.png"))
        info1.latents_npz = str(latent_path1)
        info1.bucket_reso = (512, 640)
        info1.latents = None

        info2 = ImageInfo("img2", 1, "test", False, str(Path(tmpdir) / "img2.png"))
        info2.latents_npz = str(latent_path2)
        info2.bucket_reso = (512, 640)
        info2.latents = None

        mock_dataset = MockDataset({"img1": info1, "img2": info2})
        dataset_group = DatasetGroup([mock_dataset])

        # Test: Check if all CDC caches exist
        result = dataset_group._check_cdc_caches_exist(config_hash)

        # Verify: Should return False since not all files exist
        assert result is False, "Should detect that some CDC cache files are missing"


def test_cdc_requires_latent_caching():
    """
    Test that CDC-FM gives a clear error when latent caching is not enabled.
    """

    with tempfile.TemporaryDirectory() as tmpdir:
        # Setup mock dataset with NO latent caching (both latents and latents_npz are None)
        image_info = ImageInfo(
            image_key="test_image",
            num_repeats=1,
            caption="test",
            is_reg=False,
            absolute_path=str(Path(tmpdir) / "image.png")
        )
        image_info.latents_npz = None  # No disk cache
        image_info.latents = None  # No memory cache
        image_info.bucket_reso = (512, 512)

        mock_dataset = MockDataset({"test_image": image_info})
        dataset_group = DatasetGroup([mock_dataset])

        # Test: Attempt to cache CDC without latent caching enabled
        with pytest.raises(ValueError) as exc_info:
            dataset_group.cache_cdc_gamma_b(
                k_neighbors=256,
                k_bandwidth=8,
                d_cdc=8,
                gamma=1.0
            )

        # Verify: Error message should mention latent caching requirement
        error_message = str(exc_info.value)
        assert "CDC-FM requires latent caching" in error_message
        assert "cache_latents" in error_message
        assert "cache_latents_to_disk" in error_message


if __name__ == "__main__":
    # Run tests with verbose output
    pytest.main([__file__, "-v"])
+ device="cpu", dataset_dirs=[str(tmp_path / "dataset1")] + ) + + assert preprocessor1.config_hash == preprocessor2.config_hash + + def test_different_dataset_dirs_produce_different_hash(self, tmp_path): + """ + Test that different dataset directories produce different hashes + """ + preprocessor1 = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, + device="cpu", dataset_dirs=[str(tmp_path / "dataset1")] + ) + + preprocessor2 = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, + device="cpu", dataset_dirs=[str(tmp_path / "dataset2")] + ) + + assert preprocessor1.config_hash != preprocessor2.config_hash + + def test_different_k_neighbors_produces_different_hash(self, tmp_path): + """ + Test that different k_neighbors values produce different hashes + """ + preprocessor1 = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, + device="cpu", dataset_dirs=[str(tmp_path / "dataset1")] + ) + + preprocessor2 = CDCPreprocessor( + k_neighbors=10, k_bandwidth=3, d_cdc=4, gamma=1.0, + device="cpu", dataset_dirs=[str(tmp_path / "dataset1")] + ) + + assert preprocessor1.config_hash != preprocessor2.config_hash + + def test_different_d_cdc_produces_different_hash(self, tmp_path): + """ + Test that different d_cdc values produce different hashes + """ + preprocessor1 = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, + device="cpu", dataset_dirs=[str(tmp_path / "dataset1")] + ) + + preprocessor2 = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=8, gamma=1.0, + device="cpu", dataset_dirs=[str(tmp_path / "dataset1")] + ) + + assert preprocessor1.config_hash != preprocessor2.config_hash + + def test_different_gamma_produces_different_hash(self, tmp_path): + """ + Test that different gamma values produce different hashes + """ + preprocessor1 = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, + device="cpu", dataset_dirs=[str(tmp_path / "dataset1")] + ) + + preprocessor2 = 
CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=2.0, + device="cpu", dataset_dirs=[str(tmp_path / "dataset1")] + ) + + assert preprocessor1.config_hash != preprocessor2.config_hash + + def test_multiple_dataset_dirs_order_independent(self, tmp_path): + """ + Test that dataset directory order doesn't affect hash (they are sorted) + """ + preprocessor1 = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, + device="cpu", + dataset_dirs=[str(tmp_path / "dataset1"), str(tmp_path / "dataset2")] + ) + + preprocessor2 = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, + device="cpu", + dataset_dirs=[str(tmp_path / "dataset2"), str(tmp_path / "dataset1")] + ) + + assert preprocessor1.config_hash == preprocessor2.config_hash + + def test_hash_length_is_8_chars(self, tmp_path): + """ + Test that hash is exactly 8 characters (hex) + """ + preprocessor = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, + device="cpu", dataset_dirs=[str(tmp_path / "dataset1")] + ) + + assert len(preprocessor.config_hash) == 8 + # Verify it's hex + int(preprocessor.config_hash, 16) # Should not raise + + def test_filename_includes_hash(self, tmp_path): + """ + Test that CDC filenames include the config hash + """ + preprocessor = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, + device="cpu", dataset_dirs=[str(tmp_path / "dataset1")] + ) + + latents_path = str(tmp_path / "image_0512x0768_flux.npz") + cdc_path = CDCPreprocessor.get_cdc_npz_path(latents_path, preprocessor.config_hash) + + # Should be: image_0512x0768_flux_cdc_<config_hash>.npz + expected = str(tmp_path / f"image_0512x0768_flux_cdc_{preprocessor.config_hash}.npz") + assert cdc_path == expected + + def test_backward_compatibility_no_hash(self, tmp_path): + """ + Test that get_cdc_npz_path works without hash (backward compatibility) + """ + latents_path = str(tmp_path / "image_0512x0768_flux.npz") + cdc_path = 
CDCPreprocessor.get_cdc_npz_path(latents_path, config_hash=None) + + # Should be: image_0512x0768_flux_cdc.npz (no hash suffix) + expected = str(tmp_path / "image_0512x0768_flux_cdc.npz") + assert cdc_path == expected + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/library/test_cdc_multiresolution.py b/tests/library/test_cdc_multiresolution.py new file mode 100644 index 000000000..4a67feac7 --- /dev/null +++ b/tests/library/test_cdc_multiresolution.py @@ -0,0 +1,234 @@ +""" +Test CDC-FM multi-resolution support + +This test verifies that CDC files are correctly created and loaded for different +resolutions, preventing dimension mismatch errors in multi-resolution training. +""" + +import torch +import numpy as np +from pathlib import Path +import pytest + +from library.cdc_fm import CDCPreprocessor, GammaBDataset + + +class TestCDCMultiResolution: + """Test CDC multi-resolution caching and loading""" + + def test_different_resolutions_create_separate_cdc_files(self, tmp_path): + """ + Test that the same image with different latent resolutions creates + separate CDC cache files. 
+ """ + # Create preprocessor + preprocessor = CDCPreprocessor( + k_neighbors=5, + k_bandwidth=3, + d_cdc=4, + gamma=1.0, + device="cpu", + dataset_dirs=[str(tmp_path)] + ) + + # Same image, two different resolutions + image_base_path = str(tmp_path / "test_image_1200x1500_flux.npz") + + # Resolution 1: 64x48 (simulating resolution=512 training) + latent_64x48 = torch.randn(16, 64, 48, dtype=torch.float32) + for i in range(10): # Need multiple samples for CDC + preprocessor.add_latent( + latent=latent_64x48, + global_idx=i, + latents_npz_path=image_base_path, + shape=latent_64x48.shape, + metadata={'image_key': f'test_image_{i}'} + ) + + # Compute and save + files_saved = preprocessor.compute_all() + assert files_saved == 10 + + # Verify CDC file for 64x48 exists with shape in filename + cdc_path_64x48 = CDCPreprocessor.get_cdc_npz_path( + image_base_path, + preprocessor.config_hash, + latent_shape=(16, 64, 48) + ) + assert Path(cdc_path_64x48).exists() + assert "64x48" in cdc_path_64x48 + + # Create new preprocessor for resolution 2 + preprocessor2 = CDCPreprocessor( + k_neighbors=5, + k_bandwidth=3, + d_cdc=4, + gamma=1.0, + device="cpu", + dataset_dirs=[str(tmp_path)] + ) + + # Resolution 2: 104x80 (simulating resolution=768 training) + latent_104x80 = torch.randn(16, 104, 80, dtype=torch.float32) + for i in range(10): + preprocessor2.add_latent( + latent=latent_104x80, + global_idx=i, + latents_npz_path=image_base_path, + shape=latent_104x80.shape, + metadata={'image_key': f'test_image_{i}'} + ) + + files_saved2 = preprocessor2.compute_all() + assert files_saved2 == 10 + + # Verify CDC file for 104x80 exists with different shape in filename + cdc_path_104x80 = CDCPreprocessor.get_cdc_npz_path( + image_base_path, + preprocessor2.config_hash, + latent_shape=(16, 104, 80) + ) + assert Path(cdc_path_104x80).exists() + assert "104x80" in cdc_path_104x80 + + # Verify both files exist and are different + assert cdc_path_64x48 != cdc_path_104x80 + assert 
Path(cdc_path_64x48).exists() + assert Path(cdc_path_104x80).exists() + + # Verify the CDC files have different dimensions + data_64x48 = np.load(cdc_path_64x48) + data_104x80 = np.load(cdc_path_104x80) + + # 64x48 -> flattened dim = 16 * 64 * 48 = 49152 + # 104x80 -> flattened dim = 16 * 104 * 80 = 133120 + assert data_64x48['eigenvectors'].shape[1] == 16 * 64 * 48 + assert data_104x80['eigenvectors'].shape[1] == 16 * 104 * 80 + + def test_loading_correct_cdc_for_resolution(self, tmp_path): + """ + Test that GammaBDataset loads the correct CDC file based on latent_shape + """ + # Create and save CDC files for two resolutions + config_hash = "testHash" + + image_path = str(tmp_path / "test_image_flux.npz") + + # Create CDC file for 64x48 + cdc_path_64x48 = CDCPreprocessor.get_cdc_npz_path( + image_path, + config_hash, + latent_shape=(16, 64, 48) + ) + eigvecs_64x48 = np.random.randn(4, 16 * 64 * 48).astype(np.float16) + eigvals_64x48 = np.random.randn(4).astype(np.float16) + np.savez( + cdc_path_64x48, + eigenvectors=eigvecs_64x48, + eigenvalues=eigvals_64x48, + shape=np.array([16, 64, 48]) + ) + + # Create CDC file for 104x80 + cdc_path_104x80 = CDCPreprocessor.get_cdc_npz_path( + image_path, + config_hash, + latent_shape=(16, 104, 80) + ) + eigvecs_104x80 = np.random.randn(4, 16 * 104 * 80).astype(np.float16) + eigvals_104x80 = np.random.randn(4).astype(np.float16) + np.savez( + cdc_path_104x80, + eigenvectors=eigvecs_104x80, + eigenvalues=eigvals_104x80, + shape=np.array([16, 104, 80]) + ) + + # Create GammaBDataset + gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash) + + # Load with 64x48 shape + eigvecs_loaded, eigvals_loaded = gamma_b_dataset.get_gamma_b_sqrt( + [image_path], + device="cpu", + latent_shape=(16, 64, 48) + ) + assert eigvecs_loaded.shape == (1, 4, 16 * 64 * 48) + + # Load with 104x80 shape + eigvecs_loaded2, eigvals_loaded2 = gamma_b_dataset.get_gamma_b_sqrt( + [image_path], + device="cpu", + latent_shape=(16, 104, 80) + ) + 
assert eigvecs_loaded2.shape == (1, 4, 16 * 104 * 80) + + # Verify different dimensions were loaded + assert eigvecs_loaded.shape[2] != eigvecs_loaded2.shape[2] + + def test_error_when_latent_shape_not_provided_for_multireso(self, tmp_path): + """ + Test that loading without latent_shape still works for backward compatibility + but will use old filename format without resolution + """ + config_hash = "testHash" + image_path = str(tmp_path / "test_image_flux.npz") + + # Create CDC file with old naming (no latent shape) + cdc_path_old = CDCPreprocessor.get_cdc_npz_path( + image_path, + config_hash, + latent_shape=None # Old format + ) + eigvecs = np.random.randn(4, 16 * 64 * 48).astype(np.float16) + eigvals = np.random.randn(4).astype(np.float16) + np.savez( + cdc_path_old, + eigenvectors=eigvecs, + eigenvalues=eigvals, + shape=np.array([16, 64, 48]) + ) + + # Load without latent_shape (backward compatibility) + gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash) + eigvecs_loaded, eigvals_loaded = gamma_b_dataset.get_gamma_b_sqrt( + [image_path], + device="cpu", + latent_shape=None + ) + assert eigvecs_loaded.shape == (1, 4, 16 * 64 * 48) + + def test_filename_format_with_latent_shape(self): + """Test that CDC filenames include latent dimensions correctly""" + base_path = "/path/to/image_1200x1500_flux.npz" + config_hash = "abc123de" + + # With latent shape + cdc_path = CDCPreprocessor.get_cdc_npz_path( + base_path, + config_hash, + latent_shape=(16, 104, 80) + ) + + # Should include latent H×W in filename + assert "104x80" in cdc_path + assert config_hash in cdc_path + assert cdc_path.endswith("_flux_cdc_104x80_abc123de.npz") + + def test_filename_format_without_latent_shape(self): + """Test backward compatible filename without latent shape""" + base_path = "/path/to/image_1200x1500_flux.npz" + config_hash = "abc123de" + + # Without latent shape (old format) + cdc_path = CDCPreprocessor.get_cdc_npz_path( + base_path, + config_hash, + 
latent_shape=None + ) + + # Should NOT include latent dimensions + assert "104x80" not in cdc_path + assert "64x48" not in cdc_path + assert config_hash in cdc_path + assert cdc_path.endswith("_flux_cdc_abc123de.npz") diff --git a/tests/library/test_cdc_preprocessor.py b/tests/library/test_cdc_preprocessor.py new file mode 100644 index 000000000..d8c925735 --- /dev/null +++ b/tests/library/test_cdc_preprocessor.py @@ -0,0 +1,322 @@ +""" +CDC Preprocessor and Device Consistency Tests + +This module provides testing of: +1. CDC Preprocessor functionality +2. Device consistency handling +3. GammaBDataset loading and usage +4. End-to-end CDC workflow verification +""" + +import pytest +import logging +import torch +from pathlib import Path +from safetensors.torch import save_file +from safetensors import safe_open + +from library.cdc_fm import CDCPreprocessor, GammaBDataset +from library.flux_train_utils import apply_cdc_noise_transformation + + +class TestCDCPreprocessorIntegration: + """ + Comprehensive testing of CDC preprocessing and device handling + """ + + def test_basic_preprocessor_workflow(self, tmp_path): + """ + Test basic CDC preprocessing with small dataset + """ + preprocessor = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu", + dataset_dirs=[str(tmp_path)] # Add dataset_dirs for hash + ) + + # Add 10 small latents + for i in range(10): + latent = torch.randn(16, 4, 4, dtype=torch.float32) # C, H, W + latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz") + metadata = {'image_key': f'test_image_{i}'} + preprocessor.add_latent( + latent=latent, + global_idx=i, + latents_npz_path=latents_npz_path, + shape=latent.shape, + metadata=metadata + ) + + # Compute and save + files_saved = preprocessor.compute_all() + + # Verify files were created + assert files_saved == 10 + + # Verify first CDC file structure (with config hash and latent shape) + latents_npz_path = str(tmp_path / 
"test_image_0_0004x0004_flux.npz") + latent_shape = (16, 4, 4) + cdc_path = Path(CDCPreprocessor.get_cdc_npz_path(latents_npz_path, preprocessor.config_hash, latent_shape)) + assert cdc_path.exists() + + import numpy as np + data = np.load(cdc_path) + + assert data['k_neighbors'] == 5 + assert data['d_cdc'] == 4 + + # Check eigenvectors and eigenvalues + eigvecs = data['eigenvectors'] + eigvals = data['eigenvalues'] + + assert eigvecs.shape[0] == 4 # d_cdc + assert eigvals.shape[0] == 4 # d_cdc + + def test_preprocessor_with_different_shapes(self, tmp_path): + """ + Test CDC preprocessing with variable-size latents (bucketing) + """ + preprocessor = CDCPreprocessor( + k_neighbors=3, k_bandwidth=2, d_cdc=2, gamma=1.0, device="cpu", + dataset_dirs=[str(tmp_path)] # Add dataset_dirs for hash + ) + + # Add 5 latents of shape (16, 4, 4) + for i in range(5): + latent = torch.randn(16, 4, 4, dtype=torch.float32) + latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz") + metadata = {'image_key': f'test_image_{i}'} + preprocessor.add_latent( + latent=latent, + global_idx=i, + latents_npz_path=latents_npz_path, + shape=latent.shape, + metadata=metadata + ) + + # Add 5 latents of different shape (16, 8, 8) + for i in range(5, 10): + latent = torch.randn(16, 8, 8, dtype=torch.float32) + latents_npz_path = str(tmp_path / f"test_image_{i}_0008x0008_flux.npz") + metadata = {'image_key': f'test_image_{i}'} + preprocessor.add_latent( + latent=latent, + global_idx=i, + latents_npz_path=latents_npz_path, + shape=latent.shape, + metadata=metadata + ) + + # Compute and save + files_saved = preprocessor.compute_all() + + # Verify both shape groups were processed + assert files_saved == 10 + + import numpy as np + # Check shapes are stored in individual files (with config hash and latent shape) + cdc_path_0 = CDCPreprocessor.get_cdc_npz_path( + str(tmp_path / "test_image_0_0004x0004_flux.npz"), preprocessor.config_hash, latent_shape=(16, 4, 4) + ) + cdc_path_5 = 
CDCPreprocessor.get_cdc_npz_path( + str(tmp_path / "test_image_5_0008x0008_flux.npz"), preprocessor.config_hash, latent_shape=(16, 8, 8) + ) + data_0 = np.load(cdc_path_0) + data_5 = np.load(cdc_path_5) + + assert tuple(data_0['shape']) == (16, 4, 4) + assert tuple(data_5['shape']) == (16, 8, 8) + + +class TestDeviceConsistency: + """ + Test device handling and consistency for CDC transformations + """ + + def test_matching_devices_no_warning(self, tmp_path, caplog): + """ + Test that no warnings are emitted when devices match. + """ + # Create CDC cache on CPU + preprocessor = CDCPreprocessor( + k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu", + dataset_dirs=[str(tmp_path)] # Add dataset_dirs for hash + ) + + shape = (16, 32, 32) + latents_npz_paths = [] + for i in range(10): + latent = torch.randn(*shape, dtype=torch.float32) + latents_npz_path = str(tmp_path / f"test_image_{i}_0032x0032_flux.npz") + latents_npz_paths.append(latents_npz_path) + metadata = {'image_key': f'test_image_{i}'} + preprocessor.add_latent( + latent=latent, + global_idx=i, + latents_npz_path=latents_npz_path, + shape=shape, + metadata=metadata + ) + + preprocessor.compute_all() + + dataset = GammaBDataset(device="cpu", config_hash=preprocessor.config_hash) + + noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu") + timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu") + latents_npz_paths_batch = latents_npz_paths[:2] + + with caplog.at_level(logging.WARNING): + caplog.clear() + _ = apply_cdc_noise_transformation( + noise=noise, + timesteps=timesteps, + num_timesteps=1000, + gamma_b_dataset=dataset, + latents_npz_paths=latents_npz_paths_batch, + device="cpu" + ) + + # No device mismatch warnings + device_warnings = [rec for rec in caplog.records if "device mismatch" in rec.message.lower()] + assert len(device_warnings) == 0, "Should not warn when devices match" + + def test_device_mismatch_handling(self, tmp_path): + """ + Test that CDC 
transformation handles device mismatch gracefully + """ + # Create CDC cache on CPU + preprocessor = CDCPreprocessor( + k_neighbors=8, k_bandwidth=3, d_cdc=8, gamma=1.0, device="cpu", + dataset_dirs=[str(tmp_path)] # Add dataset_dirs for hash + ) + + shape = (16, 32, 32) + latents_npz_paths = [] + for i in range(10): + latent = torch.randn(*shape, dtype=torch.float32) + latents_npz_path = str(tmp_path / f"test_image_{i}_0032x0032_flux.npz") + latents_npz_paths.append(latents_npz_path) + metadata = {'image_key': f'test_image_{i}'} + preprocessor.add_latent( + latent=latent, + global_idx=i, + latents_npz_path=latents_npz_path, + shape=shape, + metadata=metadata + ) + + preprocessor.compute_all() + + dataset = GammaBDataset(device="cpu", config_hash=preprocessor.config_hash) + + # Create noise and timesteps + noise = torch.randn(2, *shape, dtype=torch.float32, device="cpu", requires_grad=True) + timesteps = torch.tensor([100.0, 200.0], dtype=torch.float32, device="cpu") + latents_npz_paths_batch = latents_npz_paths[:2] + + # Perform CDC transformation + result = apply_cdc_noise_transformation( + noise=noise, + timesteps=timesteps, + num_timesteps=1000, + gamma_b_dataset=dataset, + latents_npz_paths=latents_npz_paths_batch, + device="cpu" + ) + + # Verify output characteristics + assert result.shape == noise.shape + assert result.device == noise.device + assert result.requires_grad # Gradients should still work + assert not torch.isnan(result).any() + assert not torch.isinf(result).any() + + # Verify gradients flow + loss = result.sum() + loss.backward() + assert noise.grad is not None + + +class TestCDCEndToEnd: + """ + End-to-end CDC workflow tests + """ + + def test_full_preprocessing_usage_workflow(self, tmp_path): + """ + Test complete workflow: preprocess -> save -> load -> use + """ + # Step 1: Preprocess latents + preprocessor = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu", + dataset_dirs=[str(tmp_path)] # Add dataset_dirs 
for hash + ) + + num_samples = 10 + latents_npz_paths = [] + for i in range(num_samples): + latent = torch.randn(16, 4, 4, dtype=torch.float32) + latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz") + latents_npz_paths.append(latents_npz_path) + metadata = {'image_key': f'test_image_{i}'} + preprocessor.add_latent( + latent=latent, + global_idx=i, + latents_npz_path=latents_npz_path, + shape=latent.shape, + metadata=metadata + ) + + files_saved = preprocessor.compute_all() + assert files_saved == num_samples + + # Step 2: Load with GammaBDataset (use config hash) + gamma_b_dataset = GammaBDataset(device="cpu", config_hash=preprocessor.config_hash) + + # Step 3: Use in mock training scenario + batch_size = 3 + batch_latents_flat = torch.randn(batch_size, 256) # B, d (flattened 16*4*4=256) + batch_t = torch.rand(batch_size) + latents_npz_paths_batch = [latents_npz_paths[0], latents_npz_paths[5], latents_npz_paths[9]] + + # Get Γ_b components (pass latent_shape for multi-resolution support) + latent_shape = (16, 4, 4) + eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(latents_npz_paths_batch, device="cpu", latent_shape=latent_shape) + + # Compute geometry-aware noise + sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, batch_t) + + # Verify output is reasonable + assert sigma_t_x.shape == batch_latents_flat.shape + assert not torch.isnan(sigma_t_x).any() + assert torch.isfinite(sigma_t_x).all() + + # Verify that noise changes with different timesteps + sigma_t0 = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, torch.zeros(batch_size)) + sigma_t1 = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, torch.ones(batch_size)) + + # At t=0, should be close to x; at t=1, should be different + assert torch.allclose(sigma_t0, batch_latents_flat, atol=1e-6) + assert not torch.allclose(sigma_t1, batch_latents_flat, atol=0.1) + + +def pytest_configure(config): + """ + Configure 
custom markers for CDC tests + """ + config.addinivalue_line( + "markers", + "device_consistency: mark test to verify device handling in CDC transformations" + ) + config.addinivalue_line( + "markers", + "preprocessor: mark test to verify CDC preprocessing workflow" + ) + config.addinivalue_line( + "markers", + "end_to_end: mark test to verify full CDC workflow" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) \ No newline at end of file diff --git a/tests/library/test_cdc_standalone.py b/tests/library/test_cdc_standalone.py new file mode 100644 index 000000000..c5a6914a1 --- /dev/null +++ b/tests/library/test_cdc_standalone.py @@ -0,0 +1,299 @@ +""" +Standalone tests for CDC-FM per-file caching. + +These tests focus on the current CDC-FM per-file caching implementation +with hash-based cache validation. +""" + +from pathlib import Path + +import pytest +import torch +import numpy as np + +from library.cdc_fm import CDCPreprocessor, GammaBDataset + + +class TestCDCPreprocessor: + """Test CDC preprocessing functionality with per-file caching""" + + def test_cdc_preprocessor_basic_workflow(self, tmp_path): + """Test basic CDC preprocessing with small dataset""" + preprocessor = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu", + dataset_dirs=[str(tmp_path)] + ) + + # Add 10 small latents + for i in range(10): + latent = torch.randn(16, 4, 4, dtype=torch.float32) # C, H, W + latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz") + metadata = {'image_key': f'test_image_{i}'} + preprocessor.add_latent( + latent=latent, + global_idx=i, + latents_npz_path=latents_npz_path, + shape=latent.shape, + metadata=metadata + ) + + # Compute and save (creates per-file CDC caches) + files_saved = preprocessor.compute_all() + + # Verify files were created + assert files_saved == 10 + + # Verify first CDC file structure + latents_npz_path = str(tmp_path / "test_image_0_0004x0004_flux.npz") + latent_shape = (16, 
4, 4) + cdc_path = Path(CDCPreprocessor.get_cdc_npz_path(latents_npz_path, preprocessor.config_hash, latent_shape)) + assert cdc_path.exists() + + data = np.load(cdc_path) + assert data['k_neighbors'] == 5 + assert data['d_cdc'] == 4 + + # Check eigenvectors and eigenvalues + eigvecs = data['eigenvectors'] + eigvals = data['eigenvalues'] + + assert eigvecs.shape[0] == 4 # d_cdc + assert eigvals.shape[0] == 4 # d_cdc + + def test_cdc_preprocessor_different_shapes(self, tmp_path): + """Test CDC preprocessing with variable-size latents (bucketing)""" + preprocessor = CDCPreprocessor( + k_neighbors=3, k_bandwidth=2, d_cdc=2, gamma=1.0, device="cpu", + dataset_dirs=[str(tmp_path)] + ) + + # Add 5 latents of shape (16, 4, 4) + for i in range(5): + latent = torch.randn(16, 4, 4, dtype=torch.float32) + latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz") + metadata = {'image_key': f'test_image_{i}'} + preprocessor.add_latent( + latent=latent, + global_idx=i, + latents_npz_path=latents_npz_path, + shape=latent.shape, + metadata=metadata + ) + + # Add 5 latents of different shape (16, 8, 8) + for i in range(5, 10): + latent = torch.randn(16, 8, 8, dtype=torch.float32) + latents_npz_path = str(tmp_path / f"test_image_{i}_0008x0008_flux.npz") + metadata = {'image_key': f'test_image_{i}'} + preprocessor.add_latent( + latent=latent, + global_idx=i, + latents_npz_path=latents_npz_path, + shape=latent.shape, + metadata=metadata + ) + + # Compute and save + files_saved = preprocessor.compute_all() + + # Verify both shape groups were processed + assert files_saved == 10 + + # Check shapes are stored in individual files + cdc_path_0 = CDCPreprocessor.get_cdc_npz_path( + str(tmp_path / "test_image_0_0004x0004_flux.npz"), preprocessor.config_hash, latent_shape=(16, 4, 4) + ) + cdc_path_5 = CDCPreprocessor.get_cdc_npz_path( + str(tmp_path / "test_image_5_0008x0008_flux.npz"), preprocessor.config_hash, latent_shape=(16, 8, 8) + ) + + data_0 = np.load(cdc_path_0) + 
data_5 = np.load(cdc_path_5) + + assert tuple(data_0['shape']) == (16, 4, 4) + assert tuple(data_5['shape']) == (16, 8, 8) + + +class TestGammaBDataset: + """Test GammaBDataset loading and retrieval with per-file caching""" + + @pytest.fixture + def sample_cdc_cache(self, tmp_path): + """Create sample CDC cache files for testing""" + # Use 20 samples to ensure proper k-NN computation + # (minimum 256 neighbors recommended, but 20 samples with k=5 is sufficient for testing) + preprocessor = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu", + dataset_dirs=[str(tmp_path)], + adaptive_k=True, # Enable adaptive k for small dataset + min_bucket_size=5 + ) + + # Create 20 samples + latents_npz_paths = [] + for i in range(20): + latent = torch.randn(16, 8, 8, dtype=torch.float32) # C=16, d=1024 when flattened + latents_npz_path = str(tmp_path / f"test_{i}_0008x0008_flux.npz") + latents_npz_paths.append(latents_npz_path) + metadata = {'image_key': f'test_{i}'} + preprocessor.add_latent( + latent=latent, + global_idx=i, + latents_npz_path=latents_npz_path, + shape=latent.shape, + metadata=metadata + ) + + preprocessor.compute_all() + return tmp_path, latents_npz_paths, preprocessor.config_hash + + def test_gamma_b_dataset_loads_metadata(self, sample_cdc_cache): + """Test that GammaBDataset loads CDC files correctly""" + tmp_path, latents_npz_paths, config_hash = sample_cdc_cache + gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash) + + # Get components for first sample + latent_shape = (16, 8, 8) + eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([latents_npz_paths[0]], device="cpu", latent_shape=latent_shape) + + # Check shapes + assert eigvecs.shape[0] == 1 # batch size + assert eigvecs.shape[1] == 4 # d_cdc + assert eigvals.shape == (1, 4) # batch, d_cdc + + def test_gamma_b_dataset_get_gamma_b_sqrt(self, sample_cdc_cache): + """Test retrieving Γ_b^(1/2) components""" + tmp_path, latents_npz_paths, config_hash = 
sample_cdc_cache + gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash) + + # Get Γ_b for paths [0, 2, 4] + paths = [latents_npz_paths[0], latents_npz_paths[2], latents_npz_paths[4]] + latent_shape = (16, 8, 8) + eigenvectors, eigenvalues = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu", latent_shape=latent_shape) + + # Check shapes + assert eigenvectors.shape[0] == 3 # batch + assert eigenvectors.shape[1] == 4 # d_cdc + assert eigenvalues.shape == (3, 4) # (batch, d_cdc) + + # Check values are positive + assert torch.all(eigenvalues > 0) + + def test_gamma_b_dataset_compute_sigma_t_x_at_t0(self, sample_cdc_cache): + """Test compute_sigma_t_x returns x unchanged at t=0""" + tmp_path, latents_npz_paths, config_hash = sample_cdc_cache + gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash) + + # Create test latents (batch of 3, matching d=1024 flattened) + x = torch.randn(3, 1024) # B, d (flattened) + t = torch.zeros(3) # t = 0 for all samples + + # Get Γ_b components + paths = [latents_npz_paths[0], latents_npz_paths[1], latents_npz_paths[2]] + latent_shape = (16, 8, 8) + eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu", latent_shape=latent_shape) + + sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t) + + # At t=0, should return x unchanged + assert torch.allclose(sigma_t_x, x, atol=1e-6) + + def test_gamma_b_dataset_compute_sigma_t_x_shape(self, sample_cdc_cache): + """Test compute_sigma_t_x returns correct shape""" + tmp_path, latents_npz_paths, config_hash = sample_cdc_cache + gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash) + + x = torch.randn(2, 1024) # B, d (flattened) + t = torch.tensor([0.3, 0.7]) + + # Get Γ_b components + paths = [latents_npz_paths[1], latents_npz_paths[3]] + latent_shape = (16, 8, 8) + eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu", latent_shape=latent_shape) + + sigma_t_x = 
gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t) + + # Should return same shape as input + assert sigma_t_x.shape == x.shape + + def test_gamma_b_dataset_compute_sigma_t_x_no_nans(self, sample_cdc_cache): + """Test compute_sigma_t_x produces finite values""" + tmp_path, latents_npz_paths, config_hash = sample_cdc_cache + gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash) + + x = torch.randn(3, 1024) # B, d (flattened) + t = torch.rand(3) # Random timesteps in [0, 1] + + # Get Γ_b components + paths = [latents_npz_paths[0], latents_npz_paths[2], latents_npz_paths[4]] + latent_shape = (16, 8, 8) + eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu", latent_shape=latent_shape) + + sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t) + + # Should not contain NaNs or Infs + assert not torch.isnan(sigma_t_x).any() + assert torch.isfinite(sigma_t_x).all() + + +class TestCDCEndToEnd: + """End-to-end CDC workflow tests""" + + def test_full_preprocessing_and_usage_workflow(self, tmp_path): + """Test complete workflow: preprocess -> save -> load -> use""" + # Step 1: Preprocess latents + preprocessor = CDCPreprocessor( + k_neighbors=5, k_bandwidth=3, d_cdc=4, gamma=1.0, device="cpu", + dataset_dirs=[str(tmp_path)] + ) + + num_samples = 10 + latents_npz_paths = [] + for i in range(num_samples): + latent = torch.randn(16, 4, 4, dtype=torch.float32) + latents_npz_path = str(tmp_path / f"test_image_{i}_0004x0004_flux.npz") + latents_npz_paths.append(latents_npz_path) + metadata = {'image_key': f'test_image_{i}'} + preprocessor.add_latent( + latent=latent, + global_idx=i, + latents_npz_path=latents_npz_path, + shape=latent.shape, + metadata=metadata + ) + + files_saved = preprocessor.compute_all() + assert files_saved == num_samples + + # Step 2: Load with GammaBDataset + gamma_b_dataset = GammaBDataset(device="cpu", config_hash=preprocessor.config_hash) + + # Step 3: Use in mock training scenario + batch_size = 3 + 
batch_latents_flat = torch.randn(batch_size, 256) # B, d (flattened 16*4*4=256) + batch_t = torch.rand(batch_size) + paths_batch = [latents_npz_paths[0], latents_npz_paths[5], latents_npz_paths[9]] + + # Get Γ_b components + latent_shape = (16, 4, 4) + eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths_batch, device="cpu", latent_shape=latent_shape) + + # Compute geometry-aware noise + sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, batch_t) + + # Verify output is reasonable + assert sigma_t_x.shape == batch_latents_flat.shape + assert not torch.isnan(sigma_t_x).any() + assert torch.isfinite(sigma_t_x).all() + + # Verify that noise changes with different timesteps + sigma_t0 = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, torch.zeros(batch_size)) + sigma_t1 = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, torch.ones(batch_size)) + + # At t=0, should be close to x; at t=1, should be different + assert torch.allclose(sigma_t0, batch_latents_flat, atol=1e-6) + assert not torch.allclose(sigma_t1, batch_latents_flat, atol=0.1) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/library/test_flux_train_utils.py b/tests/library/test_flux_train_utils.py index 2ad7ce4ee..bc9a5fdb8 100644 --- a/tests/library/test_flux_train_utils.py +++ b/tests/library/test_flux_train_utils.py @@ -2,7 +2,7 @@ import torch from unittest.mock import MagicMock, patch from library.flux_train_utils import ( - get_noisy_model_input_and_timesteps, + get_noisy_model_input_and_timestep, ) # Mock classes and functions @@ -66,7 +66,7 @@ def test_uniform_sampling(args, noise_scheduler, latents, noise, device): args.timestep_sampling = "uniform" dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, 
latents, noise, device, dtype) assert noisy_input.shape == latents.shape assert timesteps.shape == (latents.shape[0],) @@ -80,7 +80,7 @@ def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device): args.sigmoid_scale = 1.0 dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape assert timesteps.shape == (latents.shape[0],) @@ -93,7 +93,7 @@ def test_shift_sampling(args, noise_scheduler, latents, noise, device): args.discrete_flow_shift = 3.1582 dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape assert timesteps.shape == (latents.shape[0],) @@ -105,7 +105,7 @@ def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device): args.sigmoid_scale = 1.0 dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape assert timesteps.shape == (latents.shape[0],) @@ -126,7 +126,7 @@ def test_weighting_scheme(args, noise_scheduler, latents, noise, device): args.mode_scale = 1.0 dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps( + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep( args, noise_scheduler, latents, noise, device, dtype ) @@ -141,7 +141,7 @@ def test_with_ip_noise(args, noise_scheduler, latents, noise, device): 
args.ip_noise_gamma_random_strength = False dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape assert timesteps.shape == (latents.shape[0],) @@ -153,7 +153,7 @@ def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device): args.ip_noise_gamma_random_strength = True dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape assert timesteps.shape == (latents.shape[0],) @@ -164,7 +164,7 @@ def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device): def test_float16_dtype(args, noise_scheduler, latents, noise, device): dtype = torch.float16 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.dtype == dtype assert timesteps.dtype == dtype @@ -176,7 +176,7 @@ def test_different_batch_size(args, noise_scheduler, device): noise = torch.randn(5, 4, 8, 8) dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape assert timesteps.shape == (5,) @@ -189,7 +189,7 @@ def test_different_image_size(args, noise_scheduler, device): noise = torch.randn(2, 4, 16, 16) dtype = 
torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape assert timesteps.shape == (2,) @@ -203,7 +203,7 @@ def test_zero_batch_size(args, noise_scheduler, device): noise = torch.randn(0, 4, 8, 8) dtype = torch.float32 - get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) def test_different_timestep_count(args, device): @@ -212,7 +212,7 @@ def test_different_timestep_count(args, device): noise = torch.randn(2, 4, 8, 8) dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape assert timesteps.shape == (2,) diff --git a/train_network.py b/train_network.py index 6cebf5fc7..1866045b1 100644 --- a/train_network.py +++ b/train_network.py @@ -622,6 +622,27 @@ def train(self, args): accelerator.wait_for_everyone() + # CDC-FM preprocessing + if hasattr(args, "use_cdc_fm") and args.use_cdc_fm: + logger.info("CDC-FM enabled, preprocessing Γ_b matrices...") + + self.cdc_config_hash = train_dataset_group.cache_cdc_gamma_b( + k_neighbors=args.cdc_k_neighbors, + k_bandwidth=args.cdc_k_bandwidth, + d_cdc=args.cdc_d_cdc, + gamma=args.cdc_gamma, + force_recache=args.force_recache_cdc, + accelerator=accelerator, + debug=getattr(args, 'cdc_debug', False), + adaptive_k=getattr(args, 'cdc_adaptive_k', False), + min_bucket_size=getattr(args, 'cdc_min_bucket_size', 16), + ) + + if self.cdc_config_hash is None: + logger.warning("CDC-FM preprocessing failed (likely 
missing FAISS). Training will continue without CDC-FM.") + else: + self.cdc_config_hash = None + # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される # cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu text_encoding_strategy = self.get_text_encoding_strategy(args) @@ -660,6 +681,19 @@ def train(self, args): accelerator.print(f"all weights merged: {', '.join(args.base_weights)}") + # Load CDC-FM Γ_b dataset if enabled + if hasattr(args, "use_cdc_fm") and args.use_cdc_fm and self.cdc_config_hash is not None: + from library.cdc_fm import GammaBDataset + + logger.info(f"CDC Γ_b dataset ready (hash: {self.cdc_config_hash})") + + self.gamma_b_dataset = GammaBDataset( + device="cuda" if torch.cuda.is_available() else "cpu", + config_hash=self.cdc_config_hash + ) + else: + self.gamma_b_dataset = None + # prepare network net_kwargs = {} if args.network_args is not None: