From d90ee8efebe97c02f0f885e2e8fed2e0c98317bc Mon Sep 17 00:00:00 2001
From: Tim Moon
Date: Fri, 3 May 2024 23:47:15 +0000
Subject: [PATCH] Update FP8 recipe test to handle recipe changes

Signed-off-by: Tim Moon
---
 tests/pytorch/test_recipe.py | 128 +++++++++++++++++------------------
 1 file changed, 63 insertions(+), 65 deletions(-)

diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py
index 92c7f26f59..2de849fdf2 100644
--- a/tests/pytorch/test_recipe.py
+++ b/tests/pytorch/test_recipe.py
@@ -29,7 +29,7 @@ def setup_class(cls) -> None:
         torch.manual_seed(seed)
         torch.cuda.manual_seed(seed)
 
-    @pytest.mark.parametrize("amax_history_len", [1, 31, 1024])
+    @pytest.mark.parametrize("amax_history_len", [31, 1024])
     @pytest.mark.parametrize("amax_compute_algo", ["max", "most_recent"])
     @pytest.mark.parametrize("is_first_microbatch", [None, True, False])
     def test_amax_and_scale_update(
@@ -51,7 +51,10 @@ def test_amax_and_scale_update(
         )
         with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
             module = te.Linear(16, 16)
-            y = module(torch.zeros([16, 16], device="cuda"))
+            y = module(
+                torch.randn([16, 16], device="cuda"),
+                is_first_microbatch=True,
+            )
         y.backward(torch.zeros_like(y))
 
         # Get amax history and scaling factors
@@ -67,101 +70,96 @@
 
         # Tweak amax history and scaling factors
         amax_history_forward.copy_(2 * torch.rand_like(amax_history_forward) + 0.5)
-        if amax_history_len > 1:
-            amax_history_forward[1, 0].fill_(3)
+        amax_history_forward[0, :].zero_()
         scale_forward.copy_(2 * torch.rand_like(scale_forward) + 0.5)
         scale_inv_forward.copy_(torch.reciprocal(scale_forward))
-        amax_history_backward.copy_(2 * torch.rand_like(amax_history_backward) + 0.5)
-        scale_backward.copy_(2 * torch.rand_like(scale_backward) + 0.5)
-        scale_inv_backward.copy_(torch.reciprocal(scale_backward))
+        amax_history_backward[0, :].zero_()
 
         # Expected amax history after update
-        ref_amax_history_forward = torch.roll(amax_history_forward, -1, dims=0)
-        ref_amax_history_forward[0].zero_()
-        ref_amax_history_backward = torch.roll(amax_history_backward, -1, dims=0)
-        ref_amax_history_backward[0].zero_()
+        # Note: amax history is only updated when amax is updated
+        update_weight_amax = is_first_microbatch is None or is_first_microbatch
+        ref_amax_history_forward = amax_history_forward.clone()
+        ref_amax_history_forward[:, 0].copy_(torch.roll(amax_history_forward[:, 0], -1))
+        if update_weight_amax:
+            ref_amax_history_forward[:, 1].copy_(torch.roll(amax_history_forward[:, 1], -1))
+        ref_amax_history_forward[0, :].zero_()
+        ref_amax_history_backward = amax_history_backward.clone()
+        ref_amax_history_backward[:, 0].copy_(torch.roll(amax_history_backward[:, 0], -1))
+        ref_amax_history_backward[0, :].zero_()
 
         # Expected scale and scale inverse
         if amax_compute_algo == "max":
             ref_amax_forward = amax_history_forward.max(dim=0).values
             ref_amax_backward = amax_history_backward.max(dim=0).values
         elif amax_compute_algo == "most_recent":
-            ref_amax_forward = amax_history_forward[0]
-            ref_amax_backward = amax_history_backward[0]
+            ref_amax_forward = amax_history_forward[-1]
+            ref_amax_backward = amax_history_backward[-1]
         else:
             raise ValueError(f"{amax_compute_algo=} is not supported")
         ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2 ** margin)
         ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2 ** margin)
         ref_scale_inv_forward = torch.reciprocal(ref_scale_forward)
-        update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
-        if not update_weight_scale_inv:
+        update_weight_amax = is_first_microbatch is None or is_first_microbatch
+        if not update_weight_amax:
             ref_scale_inv_forward[1].copy_(scale_inv_forward[1])
         ref_scale_inv_backward = torch.reciprocal(ref_scale_backward)
 
-        # Make sure we are not trivially passing tests
-        if amax_history_len > 1:
-            with pytest.raises(AssertionError):
-                torch.testing.assert_close(
-                    amax_history_forward[1:],
-                    ref_amax_history_forward[1:],
-                )
-        with pytest.raises(AssertionError):
-            torch.testing.assert_close(
-                scale_forward,
-                ref_scale_forward,
-            )
-        with pytest.raises(AssertionError):
-            torch.testing.assert_close(
-                scale_inv_forward,
-                ref_scale_inv_forward,
-            )
-        if amax_history_len > 1:
-            with pytest.raises(AssertionError):
-                torch.testing.assert_close(
-                    fp8_meta[backward_key].amax_history[1:],
-                    ref_amax_history_backward[1:],
-                )
-        with pytest.raises(AssertionError):
-            torch.testing.assert_close(
-                fp8_meta[backward_key].scale,
-                ref_scale_backward,
-            )
-        with pytest.raises(AssertionError):
-            torch.testing.assert_close(
-                fp8_meta[backward_key].scale_inv,
-                ref_scale_inv_backward,
-            )
-
-        # Perform forward and backward pass to update fp8_meta
+        # Perform forward, backward, and optimizer steps to update fp8_meta
         with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
-            x = torch.zeros([16, 16], device="cuda")
+            x = torch.randn([16, 16], device="cuda")
             y = module(x, is_first_microbatch=is_first_microbatch)
-        y.backward(torch.zeros_like(y))
+        y.backward(torch.randn_like(y))
 
-        # Check that fp8_meta matches expected values
+        # Check that amax history matches expected values
         torch.testing.assert_close(
-            fp8_meta[forward_key].amax_history[1:],
-            ref_amax_history_forward[1:],
+            amax_history_forward[:-1],
+            ref_amax_history_forward[:-1],
         )
         torch.testing.assert_close(
-            fp8_meta[forward_key].scale,
-            ref_scale_forward,
+            amax_history_backward[:-1],
+            ref_amax_history_backward[:-1],
         )
+
+        # Expected scale and scale inverse
+        if amax_compute_algo == "max":
+            ref_amax_forward = amax_history_forward.max(dim=0).values
+            ref_amax_backward = amax_history_backward.max(dim=0).values
+        elif amax_compute_algo == "most_recent":
+            ref_amax_forward = amax_history_forward[-1]
+            ref_amax_backward = amax_history_backward[-1]
+        else:
+            raise ValueError(f"{amax_compute_algo=} is not supported")
+        ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2 ** margin)
+        ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2 ** margin)
+        ref_scale_inv_forward = torch.reciprocal(ref_scale_forward)
+        ref_scale_inv_backward = torch.reciprocal(ref_scale_backward)
+
+        # Check that scale and scale inverse match expected values
+        # Note: scale and scale inverse are only updated when amax is updated
         torch.testing.assert_close(
-            fp8_meta[forward_key].scale_inv,
-            ref_scale_inv_forward,
+            scale_forward[0],
+            ref_scale_forward[0],
         )
         torch.testing.assert_close(
-            fp8_meta[backward_key].amax_history[1:],
-            ref_amax_history_backward[1:],
+            scale_inv_forward[0],
+            ref_scale_inv_forward[0],
         )
+        if update_weight_amax:
+            torch.testing.assert_close(
+                scale_forward[1],
+                ref_scale_forward[1],
+            )
+            torch.testing.assert_close(
+                scale_inv_forward[1],
+                ref_scale_inv_forward[1],
+            )
         torch.testing.assert_close(
-            fp8_meta[backward_key].scale,
-            ref_scale_backward,
+            scale_backward[0],
+            ref_scale_backward[0],
         )
         torch.testing.assert_close(
-            fp8_meta[backward_key].scale_inv,
-            ref_scale_inv_backward,
+            scale_inv_backward[0],
+            ref_scale_inv_backward[0],
         )
 
     @pytest.mark.parametrize("amax_case", ["zero", "tiny", "normal", "inf", "nan"])