Skip to content

NJUDeepEngine/OTCA

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

3 Commits
 
 

Repository files navigation

OTCA

Reduce Cache Error via Optimal Transport

Establish

We build this repertory to complete the idea of reduce error caused by previous dit accelerating technology based on cache by introducing optimal transport. we just provide a todo list below,and will rewrite the README.md after we finish our base job.

A. 先明确实验对象与输出(当天就能定)

  1. 选 baseline cache 方法:FORA / L2C / TeaCache / ToCa / 你自己的 cache(确定一个主 baseline)

  2. 选纠偏插入点:纠偏哪个张量

    • ☐ Attention out / MLP out / Block out(先选一个最常被 cache 复用的)
  3. 选 ((t,l)) 子集:先少量

    • ☐ timesteps:后 20%(例如 50/250–250/250)
    • ☐ layers:中后层 2–4 个 block
  4. 定义评测指标与设置

    • ☐ FID / IS / sFID(如果你用)
    • ☐ 速度:latency、GFLOPs、cache hit rate
    • ☐ 质量:人眼对比(后期糊的 case)

B. 离线数据采集(稳定版必做)

  1. 构建离线样本集 (\mathcal{S})

    • ☐ N=2k~10k prompts×seeds(先小后大)
    • ☐ 覆盖你目标评测的语义分布
  2. 对每个 (s_i) 跑两遍并保存特征

    • ☐ Oracle 推理:保存 (F^{oracle}_{t,l}(s_i))
    • ☐ Cache 推理:保存 (F^{cache}_{t,l}(s_i))
    • ☐ 保存格式:float16/float32、按 ((t,l)) 分文件或 mmap
  3. 实现特征抽取 (\phi(\cdot))(pooled feature)

    • ☐ mean pooling(默认)
    • ☐ 备选:cls token / attention pooling
  4. 特征预处理(强烈建议)

    • ☐ LayerNorm 或 ℓ2 normalize(记录你用哪种)
    • ☐ 统一在 ((t,l)) 维度分别处理(避免尺度混)

C. 稳定版 OT 纠偏器训练(离线)

  1. 定义经验分布

    • ☐ (x_i=\phi(F^{cache}_{t,l}(s_i)))
    • ☐ (y_i=\phi(F^{oracle}_{t,l}(s_i)))
    • ☐ (a_i=b_i=1/N)
  2. (推荐)原型压缩:oracle 原型库 (Y_{t,l})

  • ☐ 对 ({y_i}) 做 k-means 得 (K=512/1024) 原型
  • ☐ 设定 (\tilde b_j):按簇大小或均匀
  1. 选择代价 (c(x,y))(稳定默认)
  • ☐ 余弦:(c=1-\cos(x,y))(特征已归一化)
  • ☐ 等价:(|\hat x-\hat y|^2)
  1. 跑 Sinkhorn(熵正则 OT)并保存可在线化参数
  • ☐ 选 (\varepsilon_{t,l})(先固定一个)
  • ☐ 输出保存:((Y_{t,l}, v_{t,l}, \varepsilon_{t,l}))
  • ☐ (可选)记录训练集上的 OT cost / 收敛情况

D. 在线插拔实现(稳定版上线)

  1. 在线权重计算
  • ☐ 对新点 (x):(w_j(x)\propto v_j\exp(-c(x,y_j)/\varepsilon))
  1. barycentric projection
  • ☐ (T(x)=\frac{\sum_j w_j(x)y_j}{\sum_j w_j(x)})
  1. 回写策略(先稳)
  • ☐ 混合:(\tilde f\leftarrow (1-\alpha)\tilde f^{cache}+\alpha,\text{unproj}(T(x)))
  • ☐ 固定 (\alpha)(例如 0.2/0.4/0.6)做网格
  1. 只在选定 ((t,l)) 启用纠偏
  • ☐ 后 20% step + 2–4 layers(先小规模验证)
  1. 跑完整评测
  • ☐ 质量:FID/IS + 关键 case 的视觉对比
  • ☐ 速度:吞吐/延迟开销(OT 投影额外耗时)

E. 诊断与 sanity checks(强烈建议做)

  1. 检查“是否避免拉向中心”
  • ☐ 可视化:PCA/UMAP 下 oracle/cache/corrected 分布
  • ☐ 统计:corrected 的最近原型分布是否更接近 oracle
  1. 检查权重是否集中
  • ☐ top-1 权重占比、熵(是否“软得太糊”)
  1. 检查后期是否更稳
  • ☐ 单独统计后期步数的质量差异(你关心的糊)

改进/论文增强(从这里开始做 ablation)

F. Cluster-conditioned OT(核心改进)

  1. 选择分簇信号(从易到强)
  • ☐ label(如 ImageNet)
  • ☐ prompt embedding 聚类
  • ☐ 两级:label → label 内再聚类
  1. 按簇重复训练
  • ☐ 得到 ((Y_{t,l,c}, v_{t,l,c}, \varepsilon_{t,l,c}))
  1. 在线路由
  • ☐ hard:选最近簇 (c^*)
  • ☐ soft:top-k 簇混合(边界更稳)

G. 更强 cost(当余弦仍错配时)

  1. PCA/白化后 L2²
  • ☐ 每个 ((t,l,c)) 做 PCA 到 64/128 维
  • ☐ (c(x,y)=|x'-y'|^2)

H. 防“软得太糊”的技巧

  1. top-k 截断权重
  • ☐ 只保留最大 k(8/16/32)再归一化
  1. (\varepsilon) annealing
  • ☐ 前期大、后期小(更接近“选峰”)
  1. 置信门控 (\alpha(x))
  • ☐ 用 (d_2-d_1)(最近/次近距离差)控制是否纠偏与纠偏强度

I. token-level 精修(最后做)

  1. token 子集纠偏
  • ☐ 采样 32/64 个重要 token(attention 权重/梯度/随机)
  • ☐ 只在后期步数启用(成本可控)
  1. 对比 pooled vs token-level
  • ☐ 量化“细节/纹理恢复”与计算开销

J. 写作与呈现(proposal/paper 需要)

  1. 主叙事
  • ☐ EOC:单峰/全局趋势 → 后期跨峰混合 → 糊
  • ☐ Ours:分布对齐 + 几何最省(OT)→ near-mode transport
  1. Ablation 表
  • ☐ no OT / OT pooled / + prototypes / + clustering / + top-k / + gating / token-level
  1. 可视化
  • ☐ 2D toy(你已经要了)
  • ☐ 真实特征 PCA/UMAP(oracle vs cache vs corrected)

About

Reduce Cache Error via Optimal Transport

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published