Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions tests/pytorch/test_fusible_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

from collections.abc import Iterable
import io
import math
from typing import Optional

Expand Down Expand Up @@ -1882,3 +1883,118 @@ def test_backward_linear_add(
torch.testing.assert_close(y2_test, y2_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)


class TestCheckpointing:
    """Tests for checkpointing

    Verifies that a model built from fusible ops can be saved with
    ``torch.save`` and restored with ``torch.load`` such that the
    restored model reproduces the original model bitwise.
    """

    # Fix: was `@staticmethod` with a `cls` parameter, which only worked
    # because pytest passes the class positionally. `@classmethod` is the
    # correct and conventional form for pytest's setup_class hook.
    @classmethod
    def setup_class(cls) -> None:
        # Configure RNG so tensor initialization is reproducible
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
    @pytest.mark.parametrize("quantized_weight", (False, True))
    def test_linear(
        self,
        *,
        pre_checkpoint_steps: int = 2,
        post_checkpoint_steps: int = 2,
        weight_shape: tuple[int, int] = (32, 32),
        in_shape: Iterable[int] = (32, -1),
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
        quantization: Optional[str],
        quantized_weight: bool,
    ) -> None:
        """Check checkpointing with linear op

        Trains a single-linear-op model for a few steps, saves a
        checkpoint, continues training, then loads the checkpoint into a
        fresh model and repeats the post-checkpoint steps. Parameters,
        gradients, outputs, and input gradients must match exactly.
        """

        # Make input and weight shapes consistent: the trailing input dim
        # must equal the weight's in_features
        out_features, in_features = weight_shape
        in_shape = list(in_shape)[:-1] + [in_features]
        out_shape = in_shape[:-1] + [out_features]

        # Skip invalid configurations (e.g. unsupported quantization on
        # this device or with these dims)
        quantized_compute = quantization is not None
        maybe_skip_quantization(quantization, dims=in_shape, device=device)
        maybe_skip_quantization(quantization, dims=out_shape)

        # Construct model
        recipe = make_recipe(quantization)
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
            model_save = te_ops.Sequential(
                te_ops.Linear(in_features, out_features, device=device, dtype=dtype)
            )
        optim_save = torch.optim.SGD(model_save.parameters(), lr=0.25)

        # Warmup training steps so quantizer state (e.g. amax history)
        # is non-trivial before the checkpoint is taken
        for _ in range(pre_checkpoint_steps):
            x = torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True)
            dy = torch.randn(out_shape, dtype=dtype, device=device)
            optim_save.zero_grad()
            with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
                y = model_save(x)
            y.backward(dy)
            optim_save.step()

        # Save checkpoint to an in-memory buffer
        byte_stream = io.BytesIO()
        torch.save(
            {"model": model_save.state_dict(), "optim": optim_save.state_dict()},
            byte_stream,
        )
        checkpoint_bytes = byte_stream.getvalue()
        del byte_stream

        # Synthetic data for evaluation. Inputs are cloned so the
        # original and loaded models see identical data but accumulate
        # independent grads.
        xs_save = [
            torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True)
            for _ in range(post_checkpoint_steps)
        ]
        with torch.no_grad():
            xs_load = [x.clone().requires_grad_() for x in xs_save]
        dys = [
            torch.randn(out_shape, dtype=dtype, device=device) for _ in range(post_checkpoint_steps)
        ]

        # Training steps with original model
        ys_save = []
        for i in range(post_checkpoint_steps):
            optim_save.zero_grad()
            with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
                y = model_save(xs_save[i])
            y.backward(dys[i])
            optim_save.step()
            ys_save.append(y)

        # Load checkpoint into a freshly constructed model/optimizer
        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
            model_load = te_ops.Sequential(
                te_ops.Linear(in_features, out_features, device=device, dtype=dtype)
            )
        optim_load = torch.optim.SGD(model_load.parameters(), lr=0.25)
        # NOTE: weights_only=False is required because the checkpoint
        # contains non-tensor recipe/quantizer state; acceptable here
        # since the bytes come from this same test.
        state_dict = torch.load(io.BytesIO(checkpoint_bytes), weights_only=False)
        model_load.load_state_dict(state_dict["model"])
        optim_load.load_state_dict(state_dict["optim"])

        # Training steps with loaded model
        ys_load = []
        for i in range(post_checkpoint_steps):
            optim_load.zero_grad()
            with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
                y = model_load(xs_load[i])
            y.backward(dys[i])
            optim_load.step()
            ys_load.append(y)

        # Check that original and loaded model match exactly (bitwise)
        tols = {"rtol": 0, "atol": 0}
        for param_load, param_save in zip(model_load.parameters(), model_save.parameters()):
            torch.testing.assert_close(param_load, param_save, **tols)
            torch.testing.assert_close(param_load.grad, param_save.grad, **tols)
        for y_load, y_save in zip(ys_load, ys_save):
            torch.testing.assert_close(y_load, y_save, **tols)
        for x_load, x_save in zip(xs_load, xs_save):
            torch.testing.assert_close(x_load.grad, x_save.grad, **tols)
71 changes: 38 additions & 33 deletions transformer_engine/pytorch/ops/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
DelayedScalingRecipeState,
FP8GlobalStateManager,
RecipeState,
fp8_autocast,
)
from ..tensor import Quantizer

Expand Down Expand Up @@ -508,7 +509,7 @@ def forward(
def get_extra_state(self) -> torch.Tensor:
"""Serialize extra state

Contains metadata for FP8 casting.
Contains metadata for quantization recipe.

"""

Expand Down Expand Up @@ -540,23 +541,27 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor:
dst.copy_(src, non_blocking=True)
return dst

# Store FP8 state
# Store quantizer state if needed
state = {}
for mode in ("forward", "backward"):

# Get state for a given FP8 tensor
if self.num_quantizers(mode) == 0:
# Skip if op has no quantizer state
if self._fp8_metas is None or self._fp8_metas[mode] is None:
continue
fp8_meta = self.get_fp8_meta(mode)

# Quantizer state
fp8_meta = self._fp8_metas[mode]
state[mode] = {}
state[mode]["recipe"] = fp8_meta["recipe"]

# Store tensors
if "scaling_fwd" in fp8_meta:
state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale)
state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history)
if "scaling_bwd" in fp8_meta:
state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale)
state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history)
# Copy tensors to CPU and store
if state[mode]["recipe"].delayed():
if mode == "forward":
state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale)
state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history)
if mode == "backward":
state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale)
state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history)

# Store other picklable items
extra = {}
Expand Down Expand Up @@ -595,37 +600,37 @@ def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None:
dst.data = torch.empty(src.size(), dtype=dst.dtype, device=dst.device)
dst.copy_(src, non_blocking=True)

# Load FP8 state
# Load quantizer state if needed
for mode in ("forward", "backward"):

# Get state for a given FP8 tensor
# Skip if checkpoint has no quantizer state
if mode not in state:
continue
if self.num_quantizers(mode) == 0:
continue
fp8_meta = self.get_fp8_meta(mode)
if fp8_meta is None:
continue

# Load extra state
# Get op's quantizer state, initializing if needed
if self._fp8_metas is None or self._fp8_metas[mode] is None:
with fp8_autocast(fp8_recipe=state[mode]["recipe"]):
self._reset_quantization_recipe_state()
fp8_meta = self._fp8_metas[mode]

# Load extra items
fp8_meta["recipe"] = state[mode]["recipe"]
fp8_meta.update(state[mode]["extra_fp8_variables"])
if "amax_history_fwd" in state[mode]:
fp8_meta["recipe"].amax_history_len = state[mode]["amax_history_fwd"].size(0)
elif "amax_history_bwd" in state[mode]:
fp8_meta["recipe"].amax_history_len = state[mode]["amax_history_bwd"].size(0)
if "global_fp8_buffer_pos_fwd_recompute" in fp8_meta:
del fp8_meta["global_fp8_buffer_pos_fwd_recompute"]

# Load tensors
fp8_meta = self.get_fp8_meta(mode)
if "scaling_fwd" in fp8_meta:
fp8_meta_fwd = fp8_meta["scaling_fwd"]
copy_tensor(state[mode]["scale_fwd"], fp8_meta_fwd.scale)
copy_tensor(state[mode]["amax_history_fwd"], fp8_meta_fwd.amax_history)
if "scaling_bwd" in fp8_meta:
fp8_meta_bwd = fp8_meta["scaling_bwd"]
copy_tensor(state[mode]["scale_bwd"], fp8_meta_bwd.scale)
copy_tensor(state[mode]["amax_history_bwd"], fp8_meta_bwd.amax_history)
if state[mode]["recipe"].delayed():
if mode == "forward":
copy_tensor(state[mode]["scale_fwd"], fp8_meta["scaling_fwd"].scale)
copy_tensor(
state[mode]["amax_history_fwd"], fp8_meta["scaling_fwd"].amax_history
)
if mode == "backward":
copy_tensor(state[mode]["scale_bwd"], fp8_meta["scaling_bwd"].scale)
copy_tensor(
state[mode]["amax_history_bwd"], fp8_meta["scaling_bwd"].amax_history
)

# Finish CPU-GPU memory transfers
torch.cuda.synchronize()
Expand Down
5 changes: 4 additions & 1 deletion transformer_engine/pytorch/tensor/mxfp8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ def _make_in_reduce_ex(
columnwise_scale_inv: torch.Tensor,
fp8_dtype: TE_DType,
dtype: torch.dtype,
shape: torch.shape,
) -> MXFP8Tensor:
"""Build MXFP8Tensor, for use in __reduce__

Expand All @@ -361,10 +362,11 @@ def _make_in_reduce_ex(
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
dtype=dtype,
shape=shape,
)

def __reduce_ex__(self, protocol: int) -> tuple:
"""Custom pickling to remove references to FP8 metadata objects"""
"""Custom pickling"""
return (
MXFP8Tensor._make_in_reduce_ex,
(
Expand All @@ -374,6 +376,7 @@ def __reduce_ex__(self, protocol: int) -> tuple:
self._columnwise_scale_inv,
self._fp8_dtype,
self.dtype,
self.shape,
),
)

Expand Down