45 changes: 44 additions & 1 deletion tests/pytorch/test_float8tensor.py
@@ -3,6 +3,7 @@
# See LICENSE for license information.

from collections.abc import Iterable
import io
from typing import Any, Dict, List, Tuple, Union

import pytest
@@ -263,7 +264,7 @@ def test_transpose(
dims: DimsType,
transpose_dims: Tuple[int, int],
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 1,
scale: float = 0.5,
dtype: torch.dtype = torch.float32,
) -> None:
"""Test transpose"""
@@ -316,3 +317,45 @@ def test_transpose(
x_ref.transpose(*transpose_dims),
**tols,
)

def test_serialization(
self,
dims: DimsType = [2,3,5],
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 0.5,
dtype: torch.dtype = torch.float32,
):

# Initialize random data
dims = _to_list(dims)
x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
x_fp8 = Float8Tensor.to_float8(
x_ref,
fp8_dtype=fp8_dtype,
scale=torch.full([1], scale),
)
x_ref = x_fp8.from_float8()

# Serialize tensor
byte_stream = io.BytesIO()
torch.save(x_fp8, byte_stream)
x_bytes = byte_stream.getvalue()

# Mess up and delete old tensor
x_fp8._data.zero_()
x_fp8._scale_inv.zero_()
del x_fp8, byte_stream

# Deserialize tensor
x_fp8 = torch.load(io.BytesIO(x_bytes))
del x_bytes

# Check results
tols = dict(rtol=0, atol=0)
torch.testing.assert_close(x_fp8, x_ref, **tols)

# Make sure we are not trivially passing tests
x_fp8._data.zero_()
x_fp8._scale_inv.zero_()
with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8, x_ref, **tols)
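For reference, the new test boils down to a plain torch.save/torch.load round trip on a Float8Tensor through an in-memory buffer. Below is a minimal standalone sketch of that pattern, mirroring the arguments used in the test above; the Float8Tensor import path is an assumption, not something shown in this diff.

import io

import torch
import transformer_engine_extensions as tex
from transformer_engine.pytorch.float8_tensor import Float8Tensor  # import path assumed

# Cast a random tensor to FP8, mirroring the arguments used in the test above
x = 2 * torch.rand([2, 3, 5], dtype=torch.float32, device="cpu") - 1
x_fp8 = Float8Tensor.to_float8(
    x,
    fp8_dtype=tex.DType.kFloat8E4M3,
    scale=torch.full([1], 0.5),
)

# Serialize to an in-memory buffer and load it back
buffer = io.BytesIO()
torch.save(x_fp8, buffer)
x_loaded = torch.load(io.BytesIO(buffer.getvalue()))

# The FP8 payload and scale should round-trip exactly
torch.testing.assert_close(x_loaded.from_float8(), x_fp8.from_float8(), rtol=0, atol=0)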
122 changes: 118 additions & 4 deletions tests/pytorch/test_torch_save_load.py
@@ -11,15 +11,22 @@
are identical to the original values.
"""

import io
import tempfile
from typing import Iterable, Union

import pytest
import torch
import transformer_engine.pytorch as te
import transformer_engine_extensions as tex
from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.module.base import get_workspace
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule

# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()


def init_meta(size: int=1):
meta = tex.FP8TensorMeta()
@@ -29,16 +36,13 @@ def init_meta(size: int=1):
return meta


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("scale_fwd", [224, 112, 66])
@pytest.mark.parametrize("scale_bwd", [448, 33])
@pytest.mark.parametrize("history_fwd", [1.23, 4.56])
@pytest.mark.parametrize("history_bwd", [2.34, 5.67])
def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd):

# Skip FP8 tests on non-hopper devices
if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
pytest.skip("Device compute capability 9.x required for FP8 execution.")

tmp_filename = tempfile.NamedTemporaryFile().name

precision = torch.float32
@@ -118,3 +122,113 @@ def forward(self, inp, weight):
assert torch.allclose(model_in.fp8_meta["scaling_bwd"].scale_inv, model_out.fp8_meta["scaling_bwd"].scale_inv)
assert torch.allclose(model_in.fp8_meta["scaling_bwd"].amax_history, model_out.fp8_meta["scaling_bwd"].amax_history)


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("save_fp8_model", [True, False])
@pytest.mark.parametrize("load_fp8_model", [True, False])
def test_fp8_model_checkpoint(
save_fp8_model: bool,
load_fp8_model: bool,
dims: Iterable[int] = [32,32],
dtype: torch.dtype = torch.float32,
device: Union[torch.device, str] = "cuda",
):

# Construct model
dims = list(dims)
hidden_dim = dims[-1]
with te.fp8_model_init(enabled=save_fp8_model):
model = te.Linear(
hidden_dim,
hidden_dim,
bias=False,
params_dtype=dtype,
device=device,
)

# Keep track of model output
x = torch.randn(dims, dtype=dtype, device=device)
with te.fp8_autocast():
y_ref = model(x.detach().clone()).detach().clone()

# Keep track of weights and FP8 scaling factors
weight_ref = model.weight.float().detach().clone()
fp8_meta_ref = { "scaling_fwd": {}, "scaling_bwd": {} }
with te.fp8_autocast(), torch.no_grad():
fp8_meta_fwd = model.fp8_meta["scaling_fwd"]
fp8_meta_bwd = model.fp8_meta["scaling_bwd"]
fp8_meta_fwd_ref = fp8_meta_ref["scaling_fwd"]
fp8_meta_bwd_ref = fp8_meta_ref["scaling_bwd"]
fp8_meta_fwd_ref["scale"] = torch.rand_like(fp8_meta_fwd.scale) + 0.5
fp8_meta_fwd_ref["scale_inv"] = fp8_meta_fwd_ref["scale"].reciprocal()
fp8_meta_bwd_ref["scale"] = torch.rand_like(fp8_meta_bwd.scale) + 0.5
fp8_meta_bwd_ref["scale_inv"] = fp8_meta_bwd_ref["scale"].reciprocal()
fp8_meta_fwd.scale.copy_(fp8_meta_fwd_ref["scale"])
fp8_meta_fwd.scale_inv.copy_(fp8_meta_fwd_ref["scale_inv"])
fp8_meta_bwd.scale.copy_(fp8_meta_bwd_ref["scale"])
fp8_meta_bwd.scale_inv.copy_(fp8_meta_bwd_ref["scale_inv"])
del fp8_meta_fwd, fp8_meta_bwd

# Save checkpoint
byte_stream = io.BytesIO()
torch.save(model.state_dict(), byte_stream)
model_bytes = byte_stream.getvalue()
del byte_stream

# Disturb and destroy model
with torch.no_grad():
model.weight.zero_()
model.fp8_meta = {"This": "is", "filled": "with", "nonsense": 1234}
del model

# Construct new model
with te.fp8_model_init(enabled=load_fp8_model):
model = te.Linear(
hidden_dim,
hidden_dim,
bias=False,
params_dtype=dtype,
device=device,
)

# Make sure new model does not match saved model
    tols = dict(rtol=0.125, atol=0.0675)  # fp8e4m3 epsilon = 0.0625
with pytest.raises(AssertionError):
torch.testing.assert_close(model.weight, weight_ref, **tols)
with te.fp8_autocast():
model.init_fp8_metadata()
fp8_meta_fwd = model.fp8_meta["scaling_fwd"]
fp8_meta_bwd = model.fp8_meta["scaling_bwd"]
fp8_meta_fwd_ref = fp8_meta_ref["scaling_fwd"]
fp8_meta_bwd_ref = fp8_meta_ref["scaling_bwd"]
with pytest.raises(AssertionError):
torch.testing.assert_close(fp8_meta_fwd.scale, fp8_meta_fwd_ref["scale"])
with pytest.raises(AssertionError):
torch.testing.assert_close(fp8_meta_fwd.scale_inv, fp8_meta_fwd_ref["scale_inv"])
with pytest.raises(AssertionError):
torch.testing.assert_close(fp8_meta_bwd.scale, fp8_meta_bwd_ref["scale"])
with pytest.raises(AssertionError):
torch.testing.assert_close(fp8_meta_bwd.scale_inv, fp8_meta_bwd_ref["scale_inv"])
with te.fp8_autocast():
y = model(x.detach().clone())
with pytest.raises(AssertionError):
torch.testing.assert_close(y, y_ref, **tols)

# Load checkpoint
model.load_state_dict(torch.load(io.BytesIO(model_bytes)))
del model_bytes

# Check that loaded model matches saved model
torch.testing.assert_close(model.weight, weight_ref, **tols)
with te.fp8_autocast():
fp8_meta_fwd = model.fp8_meta["scaling_fwd"]
fp8_meta_bwd = model.fp8_meta["scaling_bwd"]
fp8_meta_fwd_ref = fp8_meta_ref["scaling_fwd"]
fp8_meta_bwd_ref = fp8_meta_ref["scaling_bwd"]
torch.testing.assert_close(fp8_meta_fwd.scale, fp8_meta_fwd_ref["scale"])
torch.testing.assert_close(fp8_meta_fwd.scale_inv, fp8_meta_fwd_ref["scale_inv"])
torch.testing.assert_close(fp8_meta_bwd.scale, fp8_meta_bwd_ref["scale"])
torch.testing.assert_close(fp8_meta_bwd.scale_inv, fp8_meta_bwd_ref["scale_inv"])
with te.fp8_autocast():
y = model(x.detach().clone())
torch.testing.assert_close(y, y_ref, **tols)
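
The checkpoint test above reduces to the usual state_dict save/load flow for a Transformer Engine module whose parameters may live in FP8. Below is a minimal sketch of that user-facing pattern, assuming a CUDA device and reusing the same te.Linear arguments as the test; it is an illustration, not the test itself.

import io

import torch
import transformer_engine.pytorch as te

# Build a Linear layer whose weights are stored in FP8 and checkpoint it
with te.fp8_model_init(enabled=True):
    model = te.Linear(32, 32, bias=False, params_dtype=torch.float32, device="cuda")
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

# Restore the checkpoint into a freshly constructed model
with te.fp8_model_init(enabled=True):
    model_loaded = te.Linear(32, 32, bias=False, params_dtype=torch.float32, device="cuda")
model_loaded.load_state_dict(torch.load(io.BytesIO(buffer.getvalue())))

# Both models should now produce matching outputs under FP8 autocast
x = torch.randn([32, 32], dtype=torch.float32, device="cuda")
with te.fp8_autocast():
    y = model(x)
    y_loaded = model_loaded(x)
torch.testing.assert_close(y_loaded, y, rtol=0.125, atol=0.0675)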