128 changes: 63 additions & 65 deletions tests/pytorch/test_recipe.py
@@ -29,7 +29,7 @@ def setup_class(cls) -> None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

@pytest.mark.parametrize("amax_history_len", [1, 31, 1024])
@pytest.mark.parametrize("amax_history_len", [31, 1024])
@pytest.mark.parametrize("amax_compute_algo", ["max", "most_recent"])
@pytest.mark.parametrize("is_first_microbatch", [None, True, False])
def test_amax_and_scale_update(
@@ -51,7 +51,10 @@ def test_amax_and_scale_update(
)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
module = te.Linear(16, 16)
y = module(torch.zeros([16, 16], device="cuda"))
y = module(
torch.randn([16, 16], device="cuda"),
is_first_microbatch=True,
)
y.backward(torch.zeros_like(y))

# Get amax history and scaling factors
@@ -67,101 +70,96 @@

# Tweak amax history and scaling factors
amax_history_forward.copy_(2 * torch.rand_like(amax_history_forward) + 0.5)
if amax_history_len > 1:
amax_history_forward[1, 0].fill_(3)
amax_history_forward[0, :].zero_()
scale_forward.copy_(2 * torch.rand_like(scale_forward) + 0.5)
scale_inv_forward.copy_(torch.reciprocal(scale_forward))
amax_history_backward.copy_(2 * torch.rand_like(amax_history_backward) + 0.5)
scale_backward.copy_(2 * torch.rand_like(scale_backward) + 0.5)
scale_inv_backward.copy_(torch.reciprocal(scale_backward))
amax_history_backward[0, :].zero_()

# Expected amax history after update
ref_amax_history_forward = torch.roll(amax_history_forward, -1, dims=0)
ref_amax_history_forward[0].zero_()
ref_amax_history_backward = torch.roll(amax_history_backward, -1, dims=0)
ref_amax_history_backward[0].zero_()
# Note: amax history is only updated when amax is updated
update_weight_amax = is_first_microbatch is None or is_first_microbatch
ref_amax_history_forward = amax_history_forward.clone()
ref_amax_history_forward[:, 0].copy_(torch.roll(amax_history_forward[:, 0], -1))
if update_weight_amax:
ref_amax_history_forward[:, 1].copy_(torch.roll(amax_history_forward[:, 1], -1))
ref_amax_history_forward[0, :].zero_()
ref_amax_history_backward = amax_history_backward.clone()
ref_amax_history_backward[:, 0].copy_(torch.roll(amax_history_backward[:, 0], -1))
ref_amax_history_backward[0, :].zero_()

# Expected scale and scale inverse
if amax_compute_algo == "max":
ref_amax_forward = amax_history_forward.max(dim=0).values
ref_amax_backward = amax_history_backward.max(dim=0).values
elif amax_compute_algo == "most_recent":
ref_amax_forward = amax_history_forward[0]
ref_amax_backward = amax_history_backward[0]
ref_amax_forward = amax_history_forward[-1]
ref_amax_backward = amax_history_backward[-1]
else:
raise ValueError(f"{amax_compute_algo=} is not supported")
ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2 ** margin)
ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2 ** margin)
ref_scale_inv_forward = torch.reciprocal(ref_scale_forward)
update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
if not update_weight_scale_inv:
update_weight_amax = is_first_microbatch is None or is_first_microbatch
Member:

For our current setup, this should always be True, right?

Collaborator (Author):

The FP8 cast kernel is never called on the weights with is_first_microbatch=False, so the amax is left at zero. #825 changes the FP8 scale update kernel so that it doesn't roll the amax history or update the scales if the amax is zero.
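
For illustration, a minimal sketch of the skip-when-zero update rule described above, written against the same layout the test uses: an `amax_history` of shape `(history_len, num_tensors)`, with row 0 assumed to hold the most recently recorded amax. This is not the actual Transformer Engine kernel, just the per-tensor logic it is described as following: a tensor whose latest amax is zero keeps its history, scale, and scale_inv unchanged; otherwise the scale is recomputed and the history is rolled.

```python
import torch

def delayed_scaling_update(
    amax_history: torch.Tensor,   # (history_len, num_tensors)
    scale: torch.Tensor,          # (num_tensors,)
    scale_inv: torch.Tensor,      # (num_tensors,)
    fp8_max: float,               # e.g. fp8_format.value.max_fwd
    margin: int = 0,
    amax_compute_algo: str = "max",
) -> None:
    """Sketch of a delayed-scaling update that skips tensors with a zero amax.

    Assumes row 0 of amax_history holds the amax recorded in the most
    recent iteration; this is an illustration, not the TE kernel.
    """
    for col in range(amax_history.shape[1]):
        history = amax_history[:, col]
        if history[0] == 0:
            # No FP8 cast ran for this tensor (e.g. a cached weight when
            # is_first_microbatch=False): leave history and scales alone.
            continue
        # Reduce the history to a single amax.
        if amax_compute_algo == "max":
            amax = history.max()
        elif amax_compute_algo == "most_recent":
            amax = history[0]
        else:
            raise ValueError(f"{amax_compute_algo=} is not supported")
        # Recompute the scaling factor and its inverse.
        scale[col] = (fp8_max / amax) / (2 ** margin)
        scale_inv[col] = 1 / scale[col]
        # Roll the history so a fresh amax can be recorded in row 0.
        amax_history[:, col] = torch.roll(history, -1)
        amax_history[0, col] = 0.0
```

Under this rule the weight column is left untouched on microbatches where the cached FP8 weight is reused, which is what the `update_weight_amax` branches in the test account for.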

if not update_weight_amax:
ref_scale_inv_forward[1].copy_(scale_inv_forward[1])
ref_scale_inv_backward = torch.reciprocal(ref_scale_backward)

# Make sure we are not trivially passing tests
if amax_history_len > 1:
with pytest.raises(AssertionError):
torch.testing.assert_close(
amax_history_forward[1:],
ref_amax_history_forward[1:],
)
with pytest.raises(AssertionError):
torch.testing.assert_close(
scale_forward,
ref_scale_forward,
)
with pytest.raises(AssertionError):
torch.testing.assert_close(
scale_inv_forward,
ref_scale_inv_forward,
)
if amax_history_len > 1:
with pytest.raises(AssertionError):
torch.testing.assert_close(
fp8_meta[backward_key].amax_history[1:],
ref_amax_history_backward[1:],
)
with pytest.raises(AssertionError):
torch.testing.assert_close(
fp8_meta[backward_key].scale,
ref_scale_backward,
)
with pytest.raises(AssertionError):
torch.testing.assert_close(
fp8_meta[backward_key].scale_inv,
ref_scale_inv_backward,
)

# Perform forward and backward pass to update fp8_meta
# Perform forward, backward, and optimizer steps to update fp8_meta
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
x = torch.zeros([16, 16], device="cuda")
x = torch.randn([16, 16], device="cuda")
y = module(x, is_first_microbatch=is_first_microbatch)
y.backward(torch.zeros_like(y))
y.backward(torch.randn_like(y))

# Check that fp8_meta matches expected values
# Check that amax history matches expected values
torch.testing.assert_close(
fp8_meta[forward_key].amax_history[1:],
ref_amax_history_forward[1:],
amax_history_forward[:-1],
ref_amax_history_forward[:-1],
)
torch.testing.assert_close(
fp8_meta[forward_key].scale,
ref_scale_forward,
amax_history_backward[:-1],
ref_amax_history_backward[:-1],
)

# Expected scale and scale inverse
if amax_compute_algo == "max":
ref_amax_forward = amax_history_forward.max(dim=0).values
ref_amax_backward = amax_history_backward.max(dim=0).values
elif amax_compute_algo == "most_recent":
ref_amax_forward = amax_history_forward[-1]
ref_amax_backward = amax_history_backward[-1]
else:
raise ValueError(f"{amax_compute_algo=} is not supported")
ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2 ** margin)
ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2 ** margin)
ref_scale_inv_forward = torch.reciprocal(ref_scale_forward)
ref_scale_inv_backward = torch.reciprocal(ref_scale_backward)

# Check that scale and scale inverse match expected values
# Note: scale and scale inverse are only updated when amax is updated
torch.testing.assert_close(
fp8_meta[forward_key].scale_inv,
ref_scale_inv_forward,
scale_forward[0],
ref_scale_forward[0],
)
torch.testing.assert_close(
fp8_meta[backward_key].amax_history[1:],
ref_amax_history_backward[1:],
scale_inv_forward[0],
ref_scale_inv_forward[0],
)
if update_weight_amax:
torch.testing.assert_close(
scale_forward[1],
ref_scale_forward[1],
)
torch.testing.assert_close(
scale_inv_forward[1],
ref_scale_inv_forward[1],
)
torch.testing.assert_close(
fp8_meta[backward_key].scale,
ref_scale_backward,
scale_backward[0],
ref_scale_backward[0],
)
torch.testing.assert_close(
fp8_meta[backward_key].scale_inv,
ref_scale_inv_backward,
scale_inv_backward[0],
ref_scale_inv_backward[0],
)

@pytest.mark.parametrize("amax_case", ["zero", "tiny", "normal", "inf", "nan"])