From 4d03b6f8131c2efa7bdf06e6db2a5ef1ec7ff5bb Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 1 Apr 2025 01:39:06 +0000 Subject: [PATCH 1/2] Debug checkpointing with te.Sequential Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 117 ++++++++++++++++++ transformer_engine/pytorch/ops/op.py | 67 +++++----- .../pytorch/tensor/mxfp8_tensor.py | 5 +- 3 files changed, 155 insertions(+), 34 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 9c1a842cd8..b41475c78a 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -5,6 +5,7 @@ from __future__ import annotations from collections.abc import Iterable +import io import math from typing import Optional @@ -1882,3 +1883,119 @@ def test_backward_linear_add( torch.testing.assert_close(y2_test, y2_ref, **tols) torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(dw_test, w_ref.grad, **tols) + + +class TestCheckpointing: + """Tests for checkpointing""" + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantized_weight", (False, True)) + def test_linear( + self, + *, + pre_checkpoint_steps: int = 2, + post_checkpoint_steps: int = 2, + weight_shape: tuple[int, int] = (32, 32), + in_shape: Iterable[int] = (32, -1), + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + quantization: Optional[str], + quantized_weight: bool, + ) -> None: + """Check checkpointing with linear op""" + + # Make input and weight shapes consistent + out_features, in_features = weight_shape + in_shape = list(in_shape)[:-1] + [in_features] + out_shape = in_shape[:-1] + [out_features] + + # Skip invalid configurations + quantized_compute = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, 
device=device) + maybe_skip_quantization(quantization, dims=out_shape) + + # Construct model + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + model_save = te_ops.Sequential( + te_ops.Linear(in_features, out_features, device=device, dtype=dtype) + ) + optim_save = torch.optim.SGD(model_save.parameters(), lr=0.25) + + # Warmup training steps + for _ in range(pre_checkpoint_steps): + x = torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True) + dy = torch.randn(out_shape, dtype=dtype, device=device) + optim_save.zero_grad() + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y = model_save(x) + y.backward(dy) + optim_save.step() + + # Save checkpoint + byte_stream = io.BytesIO() + torch.save( + { "model": model_save.state_dict(), "optim": optim_save.state_dict() }, + byte_stream, + ) + checkpoint_bytes = byte_stream.getvalue() + del byte_stream + + # Synthetic data for evaluation + xs_save = [ + torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True) + for _ in range(post_checkpoint_steps) + ] + with torch.no_grad(): + xs_load = [x.clone().requires_grad_() for x in xs_save] + dys = [ + torch.randn(out_shape, dtype=dtype, device=device) + for _ in range(post_checkpoint_steps) + ] + + # Training steps with original model + ys_save = [] + for i in range(post_checkpoint_steps): + optim_save.zero_grad() + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y = model_save(xs_save[i]) + y.backward(dys[i]) + optim_save.step() + ys_save.append(y) + + # Load checkpoint + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + model_load = te_ops.Sequential( + te_ops.Linear(in_features, out_features, device=device, dtype=dtype) + ) + optim_load = torch.optim.SGD(model_load.parameters(), lr=0.25) + state_dict = torch.load(io.BytesIO(checkpoint_bytes), weights_only=False) + model_load.load_state_dict(state_dict["model"]) + 
optim_load.load_state_dict(state_dict["optim"]) + + # Training steps with loaded model + ys_load = [] + for i in range(post_checkpoint_steps): + optim_load.zero_grad() + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y = model_load(xs_load[i]) + y.backward(dys[i]) + optim_load.step() + ys_load.append(y) + + # Check that original and loaded model match exactly + tols = { "rtol": 0, "atol": 0 } + for param_load, param_save in zip(model_load.parameters(), model_save.parameters()): + torch.testing.assert_close(param_load, param_save, **tols) + torch.testing.assert_close(param_load.grad, param_save.grad, **tols) + for y_load, y_save in zip(ys_load, ys_save): + torch.testing.assert_close(y_load, y_save, **tols) + for x_load, x_save in zip(xs_load, xs_save): + torch.testing.assert_close(x_load.grad, x_save.grad, **tols) diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 2e212e15f4..503ead2cc6 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -19,6 +19,7 @@ DelayedScalingRecipeState, FP8GlobalStateManager, RecipeState, + fp8_autocast, ) from ..tensor import Quantizer @@ -508,7 +509,7 @@ def forward( def get_extra_state(self) -> torch.Tensor: """Serialize extra state - Contains metadata for FP8 casting. + Contains metadata for quantization recipe. 
""" @@ -540,23 +541,27 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: dst.copy_(src, non_blocking=True) return dst - # Store FP8 state + # Store quantizer state if needed state = {} for mode in ("forward", "backward"): - # Get state for a given FP8 tensor - if self.num_quantizers(mode) == 0: + # Skip if op has no quantizer state + if self._fp8_metas is None or self._fp8_metas[mode] is None: continue - fp8_meta = self.get_fp8_meta(mode) + + # Quantizer state + fp8_meta = self._fp8_metas[mode] state[mode] = {} + state[mode]["recipe"] = fp8_meta["recipe"] - # Store tensors - if "scaling_fwd" in fp8_meta: - state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale) - state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history) - if "scaling_bwd" in fp8_meta: - state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale) - state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history) + # Copy tensors to CPU and store + if state[mode]["recipe"].delayed(): + if mode == "forward": + state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale) + state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history) + if mode == "backward": + state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale) + state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history) # Store other picklable items extra = {} @@ -595,37 +600,33 @@ def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: dst.data = torch.empty(src.size(), dtype=dst.dtype, device=dst.device) dst.copy_(src, non_blocking=True) - # Load FP8 state + # Load quantizer state if needed for mode in ("forward", "backward"): - # Get state for a given FP8 tensor + # Skip if checkpoint has no quantizer state if mode not in state: continue - if self.num_quantizers(mode) == 0: - continue - fp8_meta = self.get_fp8_meta(mode) - if fp8_meta is None: - continue - # Load extra state + # Get op's quantizer state, initializing if needed + if 
self._fp8_metas is None or self._fp8_metas[mode] is None: + with fp8_autocast(fp8_recipe=state[mode]["recipe"]): + self._reset_quantization_recipe_state() + fp8_meta = self._fp8_metas[mode] + + # Load extra items + fp8_meta["recipe"] = state[mode]["recipe"] fp8_meta.update(state[mode]["extra_fp8_variables"]) - if "amax_history_fwd" in state[mode]: - fp8_meta["recipe"].amax_history_len = state[mode]["amax_history_fwd"].size(0) - elif "amax_history_bwd" in state[mode]: - fp8_meta["recipe"].amax_history_len = state[mode]["amax_history_bwd"].size(0) if "global_fp8_buffer_pos_fwd_recompute" in fp8_meta: del fp8_meta["global_fp8_buffer_pos_fwd_recompute"] # Load tensors - fp8_meta = self.get_fp8_meta(mode) - if "scaling_fwd" in fp8_meta: - fp8_meta_fwd = fp8_meta["scaling_fwd"] - copy_tensor(state[mode]["scale_fwd"], fp8_meta_fwd.scale) - copy_tensor(state[mode]["amax_history_fwd"], fp8_meta_fwd.amax_history) - if "scaling_bwd" in fp8_meta: - fp8_meta_bwd = fp8_meta["scaling_bwd"] - copy_tensor(state[mode]["scale_bwd"], fp8_meta_bwd.scale) - copy_tensor(state[mode]["amax_history_bwd"], fp8_meta_bwd.amax_history) + if state[mode]["recipe"].delayed(): + if mode == "forward": + copy_tensor(state[mode]["scale_fwd"], fp8_meta["scaling_fwd"].scale) + copy_tensor(state[mode]["amax_history_fwd"], fp8_meta["scaling_fwd"].amax_history) + if mode == "backward": + copy_tensor(state[mode]["scale_bwd"], fp8_meta["scaling_bwd"].scale) + copy_tensor(state[mode]["amax_history_bwd"], fp8_meta["scaling_bwd"].amax_history) # Finish CPU-GPU memory transfers torch.cuda.synchronize() diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 843c7936f2..2694319a0f 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -347,6 +347,7 @@ def _make_in_reduce_ex( columnwise_scale_inv: torch.Tensor, fp8_dtype: TE_DType, dtype: torch.dtype, + shape: torch.Size, ) -> 
MXFP8Tensor: """Build MXFP8Tensor, for use in __reduce__ @@ -361,10 +362,11 @@ def _make_in_reduce_ex( columnwise_data=columnwise_data, columnwise_scale_inv=columnwise_scale_inv, dtype=dtype, + shape=shape, ) def __reduce_ex__(self, protocol: int) -> tuple: - """Custom pickling to remove references to FP8 metadata objects""" + """Custom pickling""" return ( MXFP8Tensor._make_in_reduce_ex, ( @@ -374,6 +376,7 @@ def __reduce_ex__(self, protocol: int) -> tuple: self._columnwise_scale_inv, self._fp8_dtype, self.dtype, + self.shape, ), ) From 8edec96dad76a1eeca3f7dce41c9fa7435fb26f5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 1 Apr 2025 01:45:04 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_fusible_ops.py | 7 +++---- transformer_engine/pytorch/ops/op.py | 8 ++++++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index b41475c78a..3773f0ab35 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1942,7 +1942,7 @@ def test_linear( # Save checkpoint byte_stream = io.BytesIO() torch.save( - { "model": model_save.state_dict(), "optim": optim_save.state_dict() }, + {"model": model_save.state_dict(), "optim": optim_save.state_dict()}, byte_stream, ) checkpoint_bytes = byte_stream.getvalue() @@ -1956,8 +1956,7 @@ def test_linear( with torch.no_grad(): xs_load = [x.clone().requires_grad_() for x in xs_save] dys = [ - torch.randn(out_shape, dtype=dtype, device=device) - for _ in range(post_checkpoint_steps) + torch.randn(out_shape, dtype=dtype, device=device) for _ in range(post_checkpoint_steps) ] # Training steps with original model @@ -1991,7 +1990,7 @@ def test_linear( ys_load.append(y) # Check that original and loaded model match exactly - tols = { "rtol": 0, "atol": 0 } + tols = 
{"rtol": 0, "atol": 0} for param_load, param_save in zip(model_load.parameters(), model_save.parameters()): torch.testing.assert_close(param_load, param_save, **tols) torch.testing.assert_close(param_load.grad, param_save.grad, **tols) diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 503ead2cc6..ad32055479 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -623,10 +623,14 @@ def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: if state[mode]["recipe"].delayed(): if mode == "forward": copy_tensor(state[mode]["scale_fwd"], fp8_meta["scaling_fwd"].scale) - copy_tensor(state[mode]["amax_history_fwd"], fp8_meta["scaling_fwd"].amax_history) + copy_tensor( + state[mode]["amax_history_fwd"], fp8_meta["scaling_fwd"].amax_history + ) if mode == "backward": copy_tensor(state[mode]["scale_bwd"], fp8_meta["scaling_bwd"].scale) - copy_tensor(state[mode]["amax_history_bwd"], fp8_meta["scaling_bwd"].amax_history) + copy_tensor( + state[mode]["amax_history_bwd"], fp8_meta["scaling_bwd"].amax_history + ) # Finish CPU-GPU memory transfers torch.cuda.synchronize()