🎯 Goal (What & Why)
Add activation-level distillation, which often improves student performance compared with logit-only distillation.
🚀 Execution Plan
Step 1: What is the smallest working version?
- Distill on the mixer outputs of all layers.
- Support only the case where the student has the same number of layers as the teacher.
- Use an MSE loss.
- Add a single coefficient $\lambda$ that balances activation-level (feature) distillation against logit-level distillation.
As a first version:
$$\mathcal{L} = \mathcal{L}_{\text{logit}} + \lambda \, \mathcal{L}_{\text{activation}}$$
with:
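$$\mathcal{L}_{\text{activation}} = \frac{1}{N} \sum_{i=1}^{N} \lVert T_i(\mathbf{x}) - S_i(\mathbf{x}) \rVert_2^2$$
where $N$ is the number of layers, and $T_i(\mathbf{x})$ and $S_i(\mathbf{x})$ denote the outputs of the $i$-th layer's mixer in the teacher and student models, respectively. For the MSE loss, the squared norm is additionally averaged over the elements of each activation.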
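A minimal, framework-free sketch of the first-version loss, assuming each mixer output has been flattened into a list of floats and that teacher and student have the same number of layers. Each layer's term is an MSE (squared error averaged over elements). The function names are hypothetical; a real implementation would operate on framework tensors.

```python
def mse(teacher: list[float], student: list[float]) -> float:
    """Mean squared error between two flattened activation tensors."""
    if len(teacher) != len(student):
        raise ValueError("activation shapes must match")
    return sum((t - s) ** 2 for t, s in zip(teacher, student)) / len(teacher)


def activation_loss(
    teacher_acts: list[list[float]], student_acts: list[list[float]]
) -> float:
    """L_activation: per-layer MSE averaged over the N distilled layers."""
    if len(teacher_acts) != len(student_acts):
        # First version: student and teacher must have the same depth.
        raise ValueError("student and teacher must have the same number of layers")
    n = len(teacher_acts)
    return sum(mse(t, s) for t, s in zip(teacher_acts, student_acts)) / n


def total_loss(
    logit_loss: float,
    teacher_acts: list[list[float]],
    student_acts: list[list[float]],
    lam: float,
) -> float:
    """L = L_logit + lambda * L_activation."""
    return logit_loss + lam * activation_loss(teacher_acts, student_acts)


teacher = [[1.0, 2.0], [0.0, 1.0]]  # mixer outputs of 2 teacher layers
student = [[1.0, 0.0], [0.0, 0.0]]  # mixer outputs of 2 student layers
print(activation_loss(teacher, student))               # 1.25
print(total_loss(0.5, teacher, student, lam=0.1))      # 0.625
```

Here the per-layer MSEs are 2.0 and 0.5, so $\mathcal{L}_{\text{activation}} = 1.25$; with $\lambda = 0.1$ and a logit loss of 0.5 the total is 0.625.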
Details:
- The teacher stores its intermediate activations in `kwargs`.
- The student reads them from `kwargs` to compute the activation-distillation losses, which it stores in `losses`.
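The teacher/student hand-off described above can be sketched with toy layers over plain lists, where each toy layer stands in for a mixer. The key `"teacher_activations"`, the loss name `"activation_distillation"`, and the function names are hypothetical placeholders, not existing identifiers; the sketch only illustrates how data flows through `kwargs` and `losses`.

```python
from typing import Callable

# A toy "layer" (standing in for a mixer) maps a flat activation vector to another.
Layer = Callable[[list[float]], list[float]]

TEACHER_ACTIVATIONS = "teacher_activations"  # hypothetical kwargs key
ACTIVATION_LOSS = "activation_distillation"  # hypothetical losses key


def mse(a: list[float], b: list[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def teacher_forward(layers: list[Layer], x: list[float], kwargs: dict) -> list[float]:
    """Run the teacher and store each layer's output in kwargs."""
    activations = []
    for layer in layers:
        x = layer(x)
        activations.append(x)
    kwargs[TEACHER_ACTIVATIONS] = activations
    return x


def student_forward(
    layers: list[Layer], x: list[float], kwargs: dict, losses: dict, coefficient: float
) -> list[float]:
    """Run the student, compare each layer's output with the stored teacher
    activation, and record the weighted activation loss in `losses`."""
    teacher_acts = kwargs[TEACHER_ACTIVATIONS]
    if len(teacher_acts) != len(layers):
        raise ValueError("first version requires equal layer counts")
    layer_losses = []
    for layer, target in zip(layers, teacher_acts):
        x = layer(x)
        layer_losses.append(mse(x, target))
    losses[ACTIVATION_LOSS] = coefficient * sum(layer_losses) / len(layer_losses)
    return x


kwargs: dict = {}
losses: dict = {}
teacher_layers = [lambda v: [2 * e for e in v], lambda v: [e + 1 for e in v]]
student_layers = [lambda v: [2 * e for e in v], lambda v: list(v)]
teacher_forward(teacher_layers, [1.0, 2.0], kwargs)
student_forward(student_layers, [1.0, 2.0], kwargs, losses, coefficient=0.5)
print(losses)  # {'activation_distillation': 0.25}
```

In a real framework, the stored teacher activations would be detached so that no gradient flows back into the teacher.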
Step 2: What additional optimizations are possible (but optional)?
- Support tensor parallelism (TP) with sequence parallelism (this is not actually optional, but can be done in a second step).
- Make the layers used for distillation configurable, e.g. distill only on mixer outputs, or also on MLP outputs. Pass a `{student -> teacher}` mapping of layer names to use for distillation.
- Make the loss configurable: MSE, cosine, others?
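The optional items above could be sketched as follows on plain lists: a `{student -> teacher}` mapping of layer names selects which activations are compared; a small registry makes the loss configurable, with cosine distance shown as one alternative to MSE; and, for sequence parallelism, partial squared errors and element counts are combined across ranks before dividing, so each rank's slice contributes to one global MSE. Layer names such as `"layers.0.mixer"` and all function names are hypothetical, and the cross-rank all-reduce is simulated with a plain `sum`.

```python
import math
from typing import Callable

LossFn = Callable[[list[float], list[float]], float]


def mse(s: list[float], t: list[float]) -> float:
    return sum((a - b) ** 2 for a, b in zip(s, t)) / len(s)


def cosine_distance(s: list[float], t: list[float]) -> float:
    """1 - cosine similarity: 0 when the vectors point the same way, ignores scale."""
    dot = sum(a * b for a, b in zip(s, t))
    norms = math.sqrt(sum(a * a for a in s)) * math.sqrt(sum(b * b for b in t))
    return 1.0 - dot / norms


# Hypothetical registry behind a configurable loss name.
LOSSES: dict[str, LossFn] = {"mse": mse, "cosine": cosine_distance}


def mapped_activation_loss(
    student_acts: dict[str, list[float]],
    teacher_acts: dict[str, list[float]],
    layer_map: dict[str, str],  # {student layer name -> teacher layer name}
    loss_name: str = "mse",
) -> float:
    """Average the configured loss over the mapped layer pairs."""
    loss_fn = LOSSES[loss_name]
    values = [loss_fn(student_acts[s], teacher_acts[t]) for s, t in layer_map.items()]
    return sum(values) / len(values)


def sequence_parallel_mse(shards: list[tuple[list[float], list[float]]]) -> float:
    """MSE when each rank holds one (student, teacher) slice of the sequence.

    Squared errors and element counts are summed over ranks (a stand-in for an
    all-reduce) before dividing, which reproduces the unsharded MSE.
    """
    total_sq = sum(sum((a - b) ** 2 for a, b in zip(s, t)) for s, t in shards)
    total_n = sum(len(s) for s, _ in shards)
    return total_sq / total_n


# A 2-layer student distilled from a deeper teacher via an explicit mapping.
student_acts = {"layers.0.mixer": [1.0, 0.0], "layers.1.mixer": [0.0, 2.0]}
teacher_acts = {"layers.1.mixer": [1.0, 0.0], "layers.3.mixer": [0.0, 1.0]}
layer_map = {"layers.0.mixer": "layers.1.mixer", "layers.1.mixer": "layers.3.mixer"}
print(mapped_activation_loss(student_acts, teacher_acts, layer_map, "mse"))     # 0.25
print(mapped_activation_loss(student_acts, teacher_acts, layer_map, "cosine"))  # 0.0
```

The second pair differs only in scale, so MSE penalizes it while cosine distance does not; which behavior is preferable is one of the open questions in the loss bullet.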
📌 Acceptance Criteria (Must-Haves for Completion)
- The feature must be functional and tested.
- The implementation must be documented in practical terms.
- The PR must include a performance/impact summary.
- No refactors unless directly necessary for feature completion.
🛠️ Project Management
- Set the `Estimate` field (in days) in the GitHub project.
- Use the `Size` field to categorize the PR size (Small/Medium/Large).