Skip to content

httkxy/MyDiffusionModel

Repository files navigation

My Diffusion Model

一个基于 PyTorch 实现的扩散模型 (Diffusion Model) 项目,支持多种数据集的图像生成任务。

🚀 项目特点

  • 完整的扩散模型实现: 包含前向扩散过程和反向去噪过程
  • UNet 架构: 使用 UNet 作为噪声预测网络
  • 多数据集支持: 支持 CIFAR-10, Fashion-MNIST, CelebA, STL-10 等数据集
  • 灵活的配置: 可配置的时间步数、噪声调度等参数
  • 可视化工具: 提供去噪过程的可视化和结果展示

📁 项目结构

MyDiffusionModel/
├── image_test.py                    # 主测试脚本,包含数据集配置和训练逻辑
├── diffusionModels/                 # 扩散模型实现
│   └── simpleDiffusion/
│       ├── simpleDiffusion.py       # 扩散模型核心实现
│       └── varianceSchedule.py      # 噪声调度器
├── noisePredictModels/              # 噪声预测模型
│   └── Unet/
│       └── UNet.py                  # UNet 网络架构
├── utils/                           # 工具函数
│   ├── network_helper.py            # 网络辅助函数
│   └── train_network_helper.py      # 训练辅助函数
├── denoising_visualization/         # 去噪过程可视化结果
└── README.md                        # 项目文档

🛠️ 环境要求

torch>=1.9.0
torchvision>=0.10.0
matplotlib>=3.0.0
numpy>=1.19.0
pillow>=8.0.0
tqdm>=4.0.0

🏃‍♂️ 快速开始

安装依赖

pip install -r requirements.txt

运行测试

python image_test.py

⚙️ 配置说明

项目支持多种数据集配置,在 image_test.py 中可以轻松切换:

支持的数据集

  • Fashion-MNIST: 28x28 灰度图像,时尚物品分类
  • CIFAR-10: 32x32 彩色图像,10个类别
  • CIFAR-100: 32x32 彩色图像,100个类别
  • CelebA: 64x64 彩色人脸图像
  • STL-10: 96x96 彩色图像
  • MNIST: 28x28 灰度手写数字

模型参数

  • 时间步数: 1000 (可配置)
  • 噪声调度: 线性调度 (linear_beta_schedule)
  • UNet 维度倍数: (1, 2, 4) 可调整
  • 图像尺寸: 根据数据集自动配置

训练配置

# 在 image_test.py 中修改以下参数
CURRENT_DATASET = 'fashion_mnist'  # 选择数据集
T = 1000                           # 扩散时间步数
schedule_name = "linear"           # 噪声调度类型
dim_mults = (1, 2, 4)             # UNet 维度倍数

🎨 可视化功能

项目提供了完整的去噪过程可视化:

  • 逐步去噪过程展示: 展示从纯噪声到清晰图像的转变
  • 网格化结果对比: 多个样本的并行去噪过程
  • 最终生成结果: 高质量的生成图像
  • 训练过程监控: 损失曲线和生成质量跟踪

🔬 模型架构

扩散过程

  1. 前向过程: 逐步向图像添加高斯噪声,直到变成纯噪声
  2. 反向过程: 使用训练好的 UNet 预测噪声并逐步去噪
  3. 采样: 从纯噪声开始生成新图像

UNet 网络特点

  • 编码器-解码器架构: 下采样和上采样对称设计
  • 跳跃连接: 保留细节信息
  • 时间嵌入: 将扩散时间步嵌入到网络中
  • 残差块: 提升训练稳定性
  • 组归一化: 改善收敛性能

噪声调度

  • 线性调度: β 值线性增长
  • 余弦调度: 更平滑的噪声添加过程
  • 可扩展: 支持自定义调度策略

📊 训练说明

训练策略

  • 优化器: Adam (lr=1e-3)
  • 损失函数: MSE Loss (均方误差)
  • 批次大小: 可配置
  • 学习率调度: 支持多种调度策略

训练过程

  1. 数据预处理和增强
  2. 随机时间步采样
  3. 噪声添加和预测
  4. 损失计算和反向传播
  5. 定期保存检查点

🎯 使用示例

基础训练

from utils.train_network_helper import SimpleDiffusionTrainer
from noisePredictModels.Unet.UNet import Unet
from diffusionModels.simpleDiffusion.simpleDiffusion import DiffusionModel

# 创建模型
model = Unet(dim=64, channels=3, dim_mults=(1, 2, 4))
diffusion = DiffusionModel(denoise_fn=model, image_size=32, timesteps=1000)

# 训练
trainer = SimpleDiffusionTrainer(diffusion)
trainer.train()

图像生成

# 从训练好的模型生成图像
samples = diffusion.sample(batch_size=4)

📈 性能指标

  • FID Score: 评估生成图像质量
  • IS Score: 评估图像多样性
  • 训练时间: 根据数据集和硬件配置而定
  • 内存占用: 支持梯度累积以适应不同GPU

🔧 故障排除

常见问题

  1. CUDA 内存不足: 减少批次大小或使用梯度累积
  2. 训练不稳定: 调整学习率或使用学习率调度
  3. 生成质量差: 增加训练轮数或调整网络架构

🤝 贡献指南

欢迎提交 Issue 和 Pull Request!

  1. Fork 项目
  2. 创建特性分支
  3. 提交更改
  4. 推送到分支
  5. 创建 Pull Request

📚 参考文献

  • Denoising Diffusion Probabilistic Models (DDPM)
  • Improved Denoising Diffusion Probabilistic Models
  • Diffusion Models Beat GANs on Image Synthesis

📄 许可证

MIT License

👤 作者

httkxy


如果这个项目对您有帮助,请给个 ⭐️ 支持一下!

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages