activation-level distillation #388
Great progress! Did you freeze everything except the randomly initialized mixers?
Resetting and distilling only one layer while freezing the rest of the model gives satisfactory results:
Note that some changes were required to allow loading a pretrained model while freezing certain layers (#394).
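For illustration, a minimal PyTorch-style sketch of freezing everything except one layer's mixer; the `layers.{i}.mixer` naming and the helper itself are hypothetical, not Fast-LLM's actual API:

```python
import torch.nn as nn

def freeze_all_but_mixer(model: nn.Module, layer_index: int) -> None:
    # Freeze every parameter except those of the (randomly re-initialized)
    # mixer in the target layer. The ".mixer" substring is an assumed
    # naming scheme, not necessarily the real parameter names.
    target = f"layers.{layer_index}.mixer"
    for name, param in model.named_parameters():
        param.requires_grad = target in name
```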
```python
ModelTestingGroup.convert: ModelTestingGroupAction.unimportant,
ModelTestingGroup.generate: ModelTestingGroupAction.unimportant,
ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
ModelTestingGroup.distributed: ModelTestingGroupAction.broken,  # failing: tp2, stp2, stp2_ce4
```
We'll probably want to leave these as unimportant and run once in a while, because the testing suite can't really support many distributed runs.
Makes sense. Some of the distributed tests are failing currently, so let's leave it as broken for now?
| """ | ||
| Maybe apply activation distillation loss and setup backward hooks | ||
| """ | ||
| mixer_output = hidden_states if bias is None else hidden_states + bias |
This should only be evaluated if needed.
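A minimal sketch of the suggested change, written as a standalone function for clarity; the function name and the MSE criterion are assumptions, not the PR's actual code:

```python
import torch

def maybe_activation_distillation_loss(
    hidden_states: torch.Tensor,
    bias: torch.Tensor | None,
    teacher_output: torch.Tensor | None,
) -> torch.Tensor | None:
    # Skip the bias add entirely when there is no distillation target,
    # so the extra tensor is only materialized when it is actually needed.
    if teacher_output is None:
        return None
    mixer_output = hidden_states if bias is None else hidden_states + bias
    return torch.nn.functional.mse_loss(mixer_output, teacher_output)
```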
```python
self.mlp.preprocess(kwargs)

# TODO: add layer_index
_activation_distillation_loss_name = "activation_distillation_loss"
```
It would be nice to have a layer index in the logging.
Agree! This involves a few more changes because FixedBlockSequence currently assumes that all the blocks have the same loss definitions.
We could add this in a second PR.
We'd also need to be careful not to drown the logs in these activation losses for each layer.
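A minimal sketch of what per-layer loss naming could look like in a follow-up; the helper is hypothetical and not part of this PR:

```python
def activation_distillation_loss_name(layer_index: int | None = None) -> str:
    # Suffix the loss name with the layer index so logged activation losses
    # from different layers can be told apart, e.g. "activation_distillation_loss_3".
    base = "activation_distillation_loss"
    return base if layer_index is None else f"{base}_{layer_index}"
```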
Thank you for the reviews! The comments are addressed; could you have another look? @jlamypoirier
jlamypoirier left a comment
Small comments, otherwise LGTM
```python
# TODO: fix this
if "block" in default:
    removed = default.pop("block")
    log(
```
✨ Description
Closes #385
TODOs:
- 0 and gradients as well.

Sanity checks:
- 0 loss ✔️. But loss then increases to a small value instead of staying at 0.
- 0 loss (orange)

With the caveat that distillation seems to experience memory spikes at specific points in training. The actual usage was lower most of the time:
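For reference, the zero-loss sanity check described above can be expressed as a tiny standalone snippet; the tensor shapes and the MSE criterion are assumptions, not the PR's actual test code:

```python
import torch

# Distilling a layer's activations into themselves should give exactly zero loss.
teacher_activation = torch.randn(2, 16, 512)
student_activation = teacher_activation.clone()
loss = torch.nn.functional.mse_loss(student_activation, teacher_activation)
assert loss.item() == 0.0
```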
🔍 Type of change
Select all that apply:
Testing
Performance Impact
📊 Performance Impact Details
If there is any impact on performance, describe it and provide benchmark results, if applicable: