CDC-FM (Carré du Champ Flow Matching) Implementation#2221
CDC-FM (Carré du Champ Flow Matching) Implementation#2221rockerBOO wants to merge 27 commits intokohya-ss:sd3from
Conversation
Implements geometry-aware noise generation for FLUX training based on arXiv:2510.05930v1.
- Cache all shapes during GammaBDataset initialization - Eliminates file I/O on every training step (9.5M accesses/sec) - Reduces get_shape() from file operation to dict lookup - Memory overhead: ~126 bytes/sample (~12.6 MB per 100k images)
- Create apply_cdc_noise_transformation() for better modularity - Implement fast path for batch processing when all shapes match - Implement slow path for per-sample processing on shape mismatch - Clone noise tensors in fallback path for gradient consistency
- Remove @torch.no_grad() decorator from compute_sigma_t_x() - Gradients now properly flow through CDC transformation during training - Add comprehensive gradient flow tests for fast/slow paths and fallback - All 25 CDC tests passing
- Track warned samples in global set to prevent log spam - Each sample only warned once per training session - Prevents thousands of duplicate warnings during training - Add tests to verify throttling behavior
- Check that noise and CDC matrices are on same device - Automatically transfer noise if device mismatch detected - Warn user when device transfer occurs - Add tests to verify device handling
- Treat cuda and cuda:0 as compatible devices - Only warn on actual device mismatches (cuda vs cpu) - Eliminates warning spam during multi-subset training
Fixes shape mismatch bug in multi-subset training where CDC preprocessing and training used different index calculations, causing wrong CDC data to be loaded for samples. Changes: - CDC cache now stores/loads data using image_key strings instead of integer indices - Training passes image_key list instead of computed integer indices - All CDC lookups use stable image_key identifiers - Improved device compatibility check (handles "cuda" vs "cuda:0") - Updated all 30 CDC tests to use image_key-based access Root cause: Preprocessing used cumulative dataset indices while training used sorted keys, resulting in mismatched lookups during shuffled multi-subset training.
- Add --cdc_debug flag to enable verbose bucket-by-bucket output - When debug=False (default): Show tqdm progress bar, concise logging - When debug=True: Show detailed bucket information, no progress bar - Improves user experience during CDC cache generation
|
Thank you for this! This seems to be effective when the data set is limited, so it looks very good. I plan to merge the sd3 branch into main soon, so I'd like to merge this (and a few other PRs) before then. |
- Add ss_use_cdc_fm, ss_cdc_k_neighbors, ss_cdc_k_bandwidth, ss_cdc_d_cdc, ss_cdc_gamma - Ensures CDC-FM training parameters are tracked in model metadata - Enables reproducibility and model provenance tracking
- Add --cdc_adaptive_k flag to enable adaptive k based on bucket size - Add --cdc_min_bucket_size to set minimum bucket threshold (default: 16) - Fixed mode (default): Skip buckets with < k_neighbors samples - Adaptive mode: Use k=min(k_neighbors, bucket_size-1) for buckets >= min_bucket_size - Update CDCPreprocessor to support adaptive k per bucket - Add metadata tracking for adaptive_k and min_bucket_size - Add comprehensive pytest tests for adaptive k behavior This allows CDC-FM to work effectively with multi-resolution bucketing where bucket sizes may vary widely. Users can choose between strict paper methodology (fixed k) or pragmatic approach (adaptive k).
|
Issue right now is we are caching the neighbors into a file but saving it into the output_dir. this means each run we make a new file. We could:
I'd usually set it with the dataset but if multiple subsets are set it isn't one place. |
- Make FAISS import optional with try/except - CDCPreprocessor raises helpful ImportError if FAISS unavailable - train_util.py catches ImportError and returns None - train_network.py checks for None and warns user - Training continues without CDC-FM if FAISS not installed - Remove benchmark file (not needed in repo) This allows users to run training without FAISS dependency. CDC-FM will be automatically disabled with a warning if FAISS is missing.
|
@rockerBOO i plan to test this this is only for flux lora? |
- Add explicit warning and tracking for multiple unique latent shapes - Simplify test imports by removing unused modules - Minor formatting improvements in print statements - Ensure log messages provide clear context about dimension mismatches
- Merged redundant test files - Removed 'comprehensive' from file and docstring names - Improved test organization and clarity - Ensured all tests continue to pass - Simplified test documentation
Yes only Flux LoRA for the moment |
flux_train_network.py
Outdated
| # If CDC is enabled, this will transform the noise with geometry-aware covariance | ||
| noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timestep( | ||
| args, noise_scheduler, latents, noise, accelerator.device, weight_dtype, | ||
| gamma_b_dataset=gamma_b_dataset, latents_npz_paths=latents_npz_paths, timestep_index=timestep_index |
There was a problem hiding this comment.
timestep_index doesn't seem to be defined?
There was a problem hiding this comment.
Thank you @recris. Removed. Was from another PR.
Add support for CDC-FM, a geometry-aware noise generation method that improves diffusion model training by adapting noise to the local geometry of the latent space. CDC-FM replaces standard Gaussian noise with geometry-informed noise that better preserves the structure of the data manifold.
Note: Only implemented for Flux network training so far. Can be expanded to other flow matching models SD3, Lumina Image 2.
https://arxiv.org/abs/2510.05930
Note: Written with AI but I guided how it was implemented.
Recommended Configurations:
Single Resolution (e.g., all 512×512):
Multi-Resolution with Bucketing (FLUX/SDXL):
Small Dataset (<1000 images):
Parameter Guide:
--cdc_k_neighbors--cdc_adaptive_kk = min(k_neighbors, bucket_size - 1)forbuckets ≥ min_bucket_size--cdc_min_bucket_size--cdc_adaptive_kis enabled--cdc_k_bandwidth--cdc_gamma--cdc_d_cdc