Merged

29 commits
b6bfddb
Experimental FP8 tensor
ksivaman Oct 19, 2023
36093a5
Add fp8 tensor to ci test
ksivaman Oct 20, 2023
b50423b
Merge branch 'main' into float8tensor_experiments
ksivaman Oct 20, 2023
78239da
Merge branch 'main' into float8tensor_experiments
ksivaman Oct 23, 2023
dfcbcf1
Merge branch 'main' into float8tensor_experiments
ksivaman Oct 24, 2023
8814ead
review comments and tests
ksivaman Oct 24, 2023
0bf9029
Minor changes
ksivaman Oct 24, 2023
31d1eeb
Default to FP8 usage
ksivaman Oct 24, 2023
287fce7
Fix docs
ksivaman Oct 24, 2023
489d208
Naming changes
ksivaman Oct 24, 2023
c2b9aad
minor fix
ksivaman Oct 24, 2023
9220752
Fix transpose caching
ksivaman Oct 25, 2023
c3e0078
Debug transpose caching
timmoon10 Oct 25, 2023
202afcb
Rename FP8GlobalStateManager.with_fp8_parameters
timmoon10 Oct 25, 2023
1d0b1fe
remove Float8Tensor from import API
ksivaman Oct 25, 2023
39add1a
Avoid caching FP8 transposes if not required
timmoon10 Oct 25, 2023
b845d32
Fix import error in FP8 tensor tests
timmoon10 Oct 25, 2023
a5351b3
Fix transpose caching and checkpointing bug
ksivaman Oct 26, 2023
79932f1
Merge branch 'main' into float8tensor_experiments
ksivaman Oct 26, 2023
7d95a91
Improve caching and fix distopt case
ksivaman Oct 27, 2023
20fc9a9
Update transformer_engine/pytorch/float8_tensor.py
timmoon10 Oct 27, 2023
9f08be7
Remove recursive logic
ksivaman Oct 27, 2023
00b9c31
Fix cache reset bug
ksivaman Oct 27, 2023
4cf27a1
Store FP8 attributes in dict
timmoon10 Oct 30, 2023
718d284
Make sure scale_inv is 1D tensor
timmoon10 Oct 31, 2023
94848da
Make sure scale_inv is 1D tensor
timmoon10 Oct 31, 2023
ac192d8
Fixes and detach recipe
ksivaman Oct 31, 2023
918bda3
Merge branch 'float8tensor_experiments' of https://github.com/timmoon…
ksivaman Oct 31, 2023
7b3f5cd
Set default fp8 data type
ksivaman Oct 31, 2023
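
The commits above add an experimental Float8Tensor class that stores FP8 data together with its scaling metadata. As a quick orientation before the diff, here is a minimal quantize/dequantize sketch mirroring the new tests below; the to_float8/from_float8 helpers and argument names are taken from the added test file, so treat this as illustrative rather than the definitive API:

import torch
import transformer_engine_extensions as tex
from transformer_engine.pytorch.float8_tensor import Float8Tensor

# Cast a float32 tensor to FP8 (E4M3) with an explicit scale, then recover
# an approximation in the original dtype. Argument names follow the tests below.
x_ref = 2 * torch.rand(23) - 1
x_fp8 = Float8Tensor.to_float8(
    x_ref,
    fp8_dtype=tex.DType.kFloat8E4M3,
    scale=torch.full([1], 3.5),
)
x_approx = x_fp8.from_float8().cpu()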
2 changes: 2 additions & 0 deletions docs/api/pytorch.rst
@@ -35,6 +35,8 @@ pyTorch

.. autoapifunction:: transformer_engine.pytorch.fp8_autocast

.. autoapifunction:: transformer_engine.pytorch.fp8_model_init

.. autoapifunction:: transformer_engine.pytorch.checkpoint

.. autoapifunction:: transformer_engine.pytorch.onnx_export
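
The new autoapi entry documents fp8_model_init, the user-facing hook for creating modules whose parameters are held directly as Float8Tensor objects (see the "Default to FP8 usage" and "Rename FP8GlobalStateManager.with_fp8_parameters" commits above). A minimal usage sketch, assuming the context-manager form and enabled flag of the documented API rather than anything shown verbatim in this diff:

import torch
import transformer_engine.pytorch as te

# Parameters created inside fp8_model_init are expected to be held as FP8
# (Float8Tensor) weights instead of the usual higher-precision parameters.
# The enabled flag and context-manager form are assumptions based on the
# documented API, not taken verbatim from this diff.
with te.fp8_model_init(enabled=True):
    linear = te.Linear(32, 32)

# Forward passes still run under fp8_autocast, as in the tests below.
with te.fp8_autocast(enabled=True):
    out = linear(torch.zeros([8, 32], device="cuda"))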
1 change: 1 addition & 0 deletions qa/L0_pytorch_unittest/test.sh
@@ -12,3 +12,4 @@ PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pyt
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_attn.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
318 changes: 318 additions & 0 deletions tests/pytorch/test_float8tensor.py
@@ -0,0 +1,318 @@
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from collections.abc import Iterable
from typing import Any, Dict, List, Tuple, Union

import pytest
import torch

import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine_extensions as tex

# PyTorch tensor dtypes
_dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16]
# TE FP8 dtypes
_fp8_dtypes: List[tex.DType] = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]

# Numerical tolerances with FP8 types
_tols: Dict[tex.DType, Dict[str, float]] = {
    tex.DType.kFloat8E4M3: dict(rtol=0.125, atol=0.0675), # epsilon = 0.0625
    tex.DType.kFloat8E5M2: dict(rtol=0.25, atol=0.125), # epsilon = 0.125
}

def _to_list(x: Union[Iterable, Any]) -> List:
"""Convert to list if iterable, otherwise put in singleton list"""
if isinstance(x, Iterable):
return list(x)
else:
return [x]

# Types that can be interpreted as tensor dims
DimsType = Union[Iterable[int], int]

# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFloat8Tensor:

    @staticmethod
    def setup_class(cls) -> None:
        # Configure RNG
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    def test_constructor(
        self,
        dims: DimsType = 1,
        fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
        scale_inv: float = 0.375,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Call constructor and perform sanity checks"""
        dims = _to_list(dims)
        tensor = Float8Tensor(
            data=torch.zeros(dims, device="cuda", dtype=torch.uint8),
            fp8_dtype=fp8_dtype,
            fp8_scale_inv=torch.full([1], scale_inv),
            dtype=dtype,
        )
        assert list(tensor.size()) == dims, "Incorrect dims"
        assert tensor.dtype == dtype, "Incorrect nominal dtype"
        assert tensor.is_cuda, "Incorrect device"

    def _test_quantize_dequantize(
        self,
        fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
        scale: float = 3.5,
        dtype: torch.dtype = torch.float32,
        dims: DimsType = 23,
    ) -> None:
        """Check numerical error when casting to FP8 and back"""

        # Initialize random data
        x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cpu") - 1

        # Cast to FP8 and back
        x_fp8 = Float8Tensor.to_float8(
            x_ref,
            fp8_dtype=fp8_dtype,
            scale=torch.full([1], scale),
        )
        x_fp8 = x_fp8.from_float8().cpu()

        # Check results
        torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype])

        # Make sure we are not trivially passing the test
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_fp8, -x_ref, **_tols[fp8_dtype])

@pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
@pytest.mark.parametrize("dtype", _dtypes)
def test_quantize_dequantize_dtypes(
self,
fp8_dtype: tex.DType,
dtype: torch.dtype,
) -> None:
self._test_quantize_dequantize(fp8_dtype=fp8_dtype, dtype=dtype)

@pytest.mark.parametrize("scale", [0.375, 1, 3.5])
def test_quantize_dequantize_scales(self, scale: float) -> None:
self._test_quantize_dequantize(scale=scale)

@pytest.mark.parametrize("dims", [[], 1, 311, [7,11], [7,5,3], [2,3,5,3]])
def test_quantize_dequantize_dims(self, dims: DimsType) -> None:
self._test_quantize_dequantize(dims=dims)

    def test_fp8_meta(
        self,
        dtype: torch.dtype = torch.float32,
        dims: DimsType = 23,
    ) -> None:
        """Construct Float8Tensor using FP8 metadata and perform basic checks"""

        # Get FP8 metadata from linear module
        fp8_dtype = tex.DType.kFloat8E4M3
        recipe = transformer_engine.common.recipe.DelayedScaling(
            fp8_format=transformer_engine.common.recipe.Format.E4M3,
        )
        with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
            module = te.Linear(32, 32)
            _ = module(torch.zeros([8, 32], device="cuda"))
        fp8_meta = module.fp8_meta
        fp8_meta_index = tex.FP8FwdTensors.GEMM1_WEIGHT
        fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)

        # Initialize random data
        dims = _to_list(dims)
        x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1

        # Make Float8Tensor
        x_fp8 = Float8Tensor.to_float8(
            x_ref,
            fp8_meta=fp8_meta,
            fp8_meta_index=fp8_meta_index,
        )
        x_ref = x_fp8.from_float8()
        assert list(x_fp8.size()) == dims, "Incorrect dims"
        assert x_fp8.dtype == dtype, "Incorrect nominal dtype"
        assert x_fp8.is_cuda, "Incorrect device"
        assert x_fp8._fp8_dtype == fp8_dtype, "Incorrect FP8 dtype"

        # Change FP8 metadata scale
        fp8_meta[fp8_meta_key].scale[fp8_meta_index] = 2
        fp8_meta[fp8_meta_key].scale_inv.fill_(123)

        # Check results
        torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype])
        with pytest.raises(AssertionError):
            # Make sure we are not trivially passing the test
            torch.testing.assert_close(x_fp8, -x_ref, **_tols[fp8_dtype])

        # Check if scaling factor is updated after in-place ops
        x_fp8 += 0
        fp8_meta[fp8_meta_key].scale[fp8_meta_index] = 4
        fp8_meta[fp8_meta_key].scale_inv.fill_(321)
        assert x_fp8._scale_inv.item() == 0.5, "Incorrect FP8 scale_inv"
        torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype])
        y = x_fp8.detach()
        y += 0
        assert x_fp8._scale_inv.item() == 0.25, "Incorrect FP8 scale_inv"
        torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype])

    def test_basic_ops(
        self,
        dims: DimsType = 23,
        fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
        scale: float = 3.5,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Test basic out-of-place ops"""

        # Initialize random data
        dims = _to_list(dims)
        x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
        y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
        x_fp8 = Float8Tensor.to_float8(
            x_ref,
            fp8_dtype=fp8_dtype,
            scale=torch.full([1], scale),
        )
        y_fp8 = Float8Tensor.to_float8(
            y_ref,
            fp8_dtype=fp8_dtype,
            scale=torch.full([1], scale),
        )
        x_ref = x_fp8.from_float8()
        y_ref = y_fp8.from_float8()

        # Exact operations
        torch.testing.assert_close(-x_fp8, -x_ref, rtol=0, atol=0)
        torch.testing.assert_close(x_fp8.abs(), x_ref.abs(), rtol=0, atol=0)

        # Operations with numerical error
        tols = _tols[fp8_dtype]
        torch.testing.assert_close(x_fp8 + y_fp8, x_ref + y_ref, **tols)
        torch.testing.assert_close(x_fp8 - y_fp8, x_ref - y_ref, **tols)
        torch.testing.assert_close(x_fp8 * y_fp8, x_ref * y_ref, **tols)
        torch.testing.assert_close(x_fp8 + y_ref, x_ref + y_ref, **tols)
        torch.testing.assert_close(x_ref + y_fp8, x_ref + y_ref, **tols)
        torch.testing.assert_close(torch.sin(x_fp8), torch.sin(x_ref), **tols)

        # Make sure we are not trivially passing tests
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_fp8 + y_fp8, x_ref - y_fp8, **tols)

    def test_inplace_ops(
        self,
        dims: DimsType = 23,
        fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
        scale: float = 3.5,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Test in-place ops"""

        # Initialize random data
        dims = _to_list(dims)
        x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
        y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
        x_fp8 = Float8Tensor.to_float8(
            x_ref,
            fp8_dtype=fp8_dtype,
            scale=torch.full([1], scale),
        )
        y_fp8 = Float8Tensor.to_float8(
            y_ref,
            fp8_dtype=fp8_dtype,
            scale=torch.full([1], scale),
        )
        x_ref = x_fp8.from_float8()
        y_ref = y_fp8.from_float8()

        # In-place operations
        tols = _tols[fp8_dtype]
        x_fp8 += y_ref
        x_ref += y_ref
        torch.testing.assert_close(x_fp8, x_ref, **tols)
        x_ref = x_fp8.from_float8()
        x_fp8 -= y_fp8
        x_ref -= y_fp8
        torch.testing.assert_close(x_fp8, x_ref, **tols)
        x_ref = x_fp8.from_float8()
        x_fp8 *= 2
        x_ref *= 2
        torch.testing.assert_close(x_fp8, x_ref, **tols)
        x_ref = x_fp8.from_float8()

        # Make sure we are not trivially passing tests
        x_ref += 123
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_fp8, x_ref, **tols)

@pytest.mark.parametrize("dims", [[33, 41], [5, 7, 11]])
@pytest.mark.parametrize("transpose_dims", [(0, 1), (-2, -1), (0, 0)])
def test_transpose(
self,
dims: DimsType,
transpose_dims: Tuple[int, int],
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 1,
dtype: torch.dtype = torch.float32,
) -> None:
"""Test transpose"""

# Initialize random data
dims = _to_list(dims)
x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
x_fp8 = Float8Tensor.to_float8(
x_ref,
fp8_dtype=fp8_dtype,
scale=torch.full([1], scale),
)
x_ref = x_fp8.from_float8()

# Perform transpose
y_fp8 = x_fp8.transpose(*transpose_dims)
y_ref = x_ref.transpose(*transpose_dims)

# Check results
tols = dict(rtol=0, atol=0)
torch.testing.assert_close(y_fp8, y_ref, **tols)

# Make sure we are not trivially passing the test
if transpose_dims[0] != transpose_dims[1]:
with pytest.raises(AssertionError):
torch.testing.assert_close(
y_fp8,
x_ref,
**tols,
)

# Check transpose caching
if x_fp8.dim() == 2 and transpose_dims[0] != transpose_dims[1]:
x_fp8 += 0.5
x_ref = x_fp8.from_float8()
torch.testing.assert_close(
x_fp8.transpose(*transpose_dims, update_cache=True),
x_ref.transpose(*transpose_dims),
**tols,
)
torch.testing.assert_close(
x_fp8.transpose(*transpose_dims, update_cache=True),
x_ref.transpose(*transpose_dims),
**tols,
)
x_fp8 += 0.5
x_ref = x_fp8.from_float8()
torch.testing.assert_close(
x_fp8.transpose(*transpose_dims, update_cache=True),
x_ref.transpose(*transpose_dims),
**tols,
)