Handle the scaling factor when amax is so tiny that it leads to an infinite scale (#786)
Conversation
@timmoon10 Thanks Tim for your review and suggestions. All review suggestions were applied.
@ksivaman Currently …
@ksivaman Thanks for your reply! Do you mean removing the block at TransformerEngine/tests/pytorch/test_recipe.py, lines 31 to 39 (in f85553e)? For example, changing it to:

```python
@pytest.mark.parametrize("amax_history_len", [1, 31, 1024])
@pytest.mark.parametrize("amax_compute_algo", ["max", "most_recent"])
@pytest.mark.parametrize("is_first_microbatch", [None, True])
def test_amax_and_scale_update(
    self,
    amax_history_len: int,
    amax_compute_algo: str,
    is_first_microbatch: Optional[bool],
    margin: int = 2,
):
```
I meant that in this line, we should set …
Thanks for the suggestion! That fixed the test. The changes have been committed.
I don't think we should change test_recipe.py. The test failure exposes a valid bug introduced by the recipe change in #575. In particular, by updating the weight scales in every microbatch step, we might change the FP8 scale in a step where we do not change the FP8 data.
This bug is beyond the scope of this PR though, and otherwise this PR looks good. I think a better solution is to merge this PR without making any changes to the tests and we will fix that bug in an upcoming PR.
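To make the failure mode described above concrete, here is a minimal sketch in plain PyTorch. This is not Transformer Engine's actual code; `quantize`, `dequantize`, and the constants are illustrative stand-ins. It only shows that once the stored `scale_inv` is updated on a step where the quantized data is not recast, dequantization no longer round-trips.

```python
import torch

FP8_MAX = 448.0  # E4M3 max representable value (stand-in constant)

def quantize(x, scale):
    # Stand-in for an FP8 cast: scale, then saturate to the FP8 range.
    return torch.clamp(x * scale, -FP8_MAX, FP8_MAX)

def dequantize(q, scale_inv):
    return q * scale_inv

w = torch.randn(4)
scale = FP8_MAX / w.abs().max()   # scale chosen from the current amax
q = quantize(w, scale)            # data is cast on the first microbatch
scale_inv = 1.0 / scale           # stored alongside the quantized data

# On a later microbatch the scale is updated, but q is NOT recast.
new_scale_inv = 1.0 / (2.0 * scale)

print(torch.allclose(dequantize(q, scale_inv), w))      # True: scale matches data
print(torch.allclose(dequantize(q, new_scale_inv), w))  # False: scale/data mismatch
```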
/te-ci pytorch |
@timmoon10 Thanks Tim for the comment! The changes to …
LGTM. This will expose some test failures from an existing bug (see #786 (review)) and I see some unrelated linter failures (see #816).
/te-ci pytorch |
Handle the scaling factor when amax is too tiny that leads to an infinite scale (NVIDIA#786)

* Handle the scaling factor when amax is too tiny that leads to an infinite scale
  Signed-off-by: Jinze Xue <jinzex@nvidia.com>
* revert formatting changes
  Signed-off-by: Jinze Xue <jinzex@nvidia.com>
* fix comments
  Signed-off-by: Jinze Xue <jinzex@nvidia.com>
* Apply review suggestion
  Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
  Signed-off-by: Jinze Xue <155670984+jinzex@users.noreply.github.com>
* Apply review suggestion
  Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
  Signed-off-by: Jinze Xue <155670984+jinzex@users.noreply.github.com>
* Apply review suggestion
  Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
  Signed-off-by: Jinze Xue <155670984+jinzex@users.noreply.github.com>
* apply review suggestion
  Signed-off-by: Jinze Xue <jinzex@nvidia.com>
* add test_recipe.py to qa/L0_pytorch_unittest/test.sh; fix unittest for is_first_microbatch=False
  Signed-off-by: Jinze Xue <jinzex@nvidia.com>
* revert changes to update_weight_scale_inv
  Signed-off-by: Jinze Xue <jinzex@nvidia.com>
* Debug test failures
  Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Jinze Xue <jinzex@nvidia.com>
Signed-off-by: Jinze Xue <155670984+jinzex@users.noreply.github.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Jinze Xue <jinzex@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
When the amax is so tiny that the scale becomes infinite in FP32, we set the scale to the maximum finite FP32 value. In this case the tensor's amax won't be mapped to the FP8 max representable value, but to something below it; still, this is the best we can do.
cc @Oleg-Goncharov
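As a rough illustration of the behavior described above, here is a hedged sketch with assumed names (`compute_scale` and its signature are not the exact Transformer Engine kernel): compute the scale as usual, and wherever the division by a tiny or zero amax overflows FP32 to infinity, fall back to the largest finite FP32 value.

```python
import torch

def compute_scale(amax: torch.Tensor, fp8_max: float, margin: int = 0) -> torch.Tensor:
    """Illustrative scale update that clamps non-finite scales to the FP32 max."""
    scale = (fp8_max / amax) / (2.0 ** margin)
    fp32_max = torch.finfo(torch.float32).max  # ~3.4e38
    # amax == 0 or a denormal amax makes fp8_max / amax overflow to inf in FP32;
    # use the largest finite FP32 value instead of an infinite scale.
    return torch.where(torch.isfinite(scale), scale, torch.full_like(scale, fp32_max))

amax = torch.tensor([0.0, 1e-45, 2.0], dtype=torch.float32)
print(compute_scale(amax, fp8_max=448.0))
# -> roughly tensor([3.4028e+38, 3.4028e+38, 2.2400e+02])
```

With such a clamped scale, a tensor whose amax is that tiny no longer maps its amax exactly onto the FP8 max representable value, which is the trade-off the description above accepts.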