[PyTorch] Miscellanous fixes for FP8 DPA module#804
Conversation
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
…old checkpoints Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
|
I looked at this further post our sync, looks like |
We do use |
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
…with fp8_group Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
|
/te-ci pytorch |
|
With #575, the amax reduction is handled in the |
|
@cyanguwa Regarding checkpoints compatibility: not requiring |
…xtra_state Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
|
@mikolajblaz I've moved |
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
…re_attention; keep the test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
|
@ksivaman could you please help take another look? |
|
/te-ci pytorch |
|
I had some discussion with mikolajblaz offline and we decided to not pursue the move from |
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
|
/te-ci pytorch |
|
/te-ci pytorch |
* initialize tp_group for FP8 DPA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix cuDNN version in unit tests for cuDNN v9 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add hook to ignore missing fused_attn._extra_states if training from old checkpoints Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove test and redundant implementation from last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove warning message and replace with docstring Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove tp_size/tp_group in FusedAttention; amax reduction is handled with fp8_group Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * move core_attention.fused_attention._extra_state to core_attention._extra_state Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * simplify post_state_dict_hooks between FU and DPA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add temporary test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove previous attempts to move core_attention.fused_attention to core_attention; keep the test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove the test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * disable pylint self arg for hook which is required by hook Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
* initialize tp_group for FP8 DPA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix cuDNN version in unit tests for cuDNN v9 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add hook to ignore missing fused_attn._extra_states if training from old checkpoints Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove test and redundant implementation from last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove warning message and replace with docstring Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove tp_size/tp_group in FusedAttention; amax reduction is handled with fp8_group Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * move core_attention.fused_attention._extra_state to core_attention._extra_state Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * simplify post_state_dict_hooks between FU and DPA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add temporary test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove previous attempts to move core_attention.fused_attention to core_attention; keep the test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove the test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * disable pylint self arg for hook which is required by hook Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
This PR
FusedAttentionhas been subclassed withTEBaseModule, and an_extra_statehas been added to the module'sstate_dict._extra_statecontains FP8 meta data, but due to the subclassing, the addition of_extra_statetostate_dicthappens regardless of FP8 training or F16 training. This PR allows users to load older checkpoints (which do not have_extra_stateforFusedAttention), as well as save and load new checkpoints as usual (which will contain_extra_stateforFusedAttention).