From c3e1ae7797f3061a472907c84df725382fe22b7f Mon Sep 17 00:00:00 2001
From: ver217
Date: Mon, 12 Aug 2024 15:31:46 +0800
Subject: [PATCH 1/3] [fp8] refactor fp8 linear with compile

---
 colossalai/quantization/fp8.py | 15 +++++----------
 1 file changed, 5 insertions(+), 10 deletions(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index cfbf1fcf7e40..429568d932f7 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -7,6 +7,8 @@
 from packaging.version import Version
 from torch.distributed import ReduceOp
 
+SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.3.0")
+
 
 def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
@@ -635,6 +637,7 @@ def forward(
         )[0]
         return out.reshape(*ctx.x_shape[:-1], w.shape[0])
 
+    @torch.compile(mode="reduce-overhead", disable=not SUPPORT_TORCH_COMPILE)
     @staticmethod
     def backward(ctx: Any, out_grad) -> Any:
         out_grad = out_grad.reshape(-1, out_grad.shape[-1])
@@ -661,13 +664,5 @@ def backward(ctx: Any, out_grad) -> Any:
         return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
 
 
-if Version(torch.__version__) >= Version("2.3.0"):  # TODO failed on torch < 2.3.0
-
-    @torch.compile(mode="reduce-overhead", fullgraph=True)
-    def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        return _LinearFp8.apply(input, weight, bias)
-
-else:
-
-    def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        return _LinearFp8.apply(input, weight, bias)
+def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    return _LinearFp8.apply(input, weight, bias)

From 3529256cd3a504d24d1d235f13e705881d9549f8 Mon Sep 17 00:00:00 2001
From: ver217
Date: Mon, 12 Aug 2024 17:49:11 +0800
Subject: [PATCH 2/3] [fp8] fix linear test

---
 tests/test_fp8/test_fp8_linear.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/tests/test_fp8/test_fp8_linear.py b/tests/test_fp8/test_fp8_linear.py
index d035957f2a31..f7a534eb03ea 100644
--- a/tests/test_fp8/test_fp8_linear.py
+++ b/tests/test_fp8/test_fp8_linear.py
@@ -5,6 +5,7 @@
 
 from colossalai.accelerator import get_accelerator
 from colossalai.quantization.fp8 import linear_fp8
+from colossalai.testing import spawn
 from colossalai.utils import get_current_device
 
 D_IN, D_OUT = 16, 32
@@ -12,10 +13,7 @@
 DTYPE = torch.bfloat16
 
 
-@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
-@pytest.mark.parametrize("use_bias", [True, False])
-@pytest.mark.parametrize("use_batch", [True, False])
-def test_fp8_linear(use_bias: bool, use_batch: bool):
+def run_test(rank, world_size=None, port=None, use_bias: bool = False, use_batch: bool = False):
     # create tensors
     w = torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True)
     ref_w = w.clone().detach().requires_grad_()
@@ -43,3 +41,10 @@
     assert_close(w.grad, ref_w.grad, rtol=0.2, atol=0.1)
     if use_bias:
         assert_close(bias.grad, ref_bias.grad, rtol=0.2, atol=0.1)
+
+
+@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
+@pytest.mark.parametrize("use_bias", [True, False])
+@pytest.mark.parametrize("use_batch", [True, False])
+def test_fp8_linear(use_bias: bool, use_batch: bool):
+    spawn(run_test, nprocs=1, use_bias=use_bias, use_batch=use_batch)
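Note on patches 1-2: patch 1 collapses the two version-gated definitions of linear_fp8 into one by using the `disable` argument of torch.compile (which turns the compile into a no-op on torch < 2.3.0), and patch 2 moves the test body into run_test so that colossalai.testing.spawn can run it in a fresh subprocess, isolating compiled/CUDA-graph state per parametrization. Below is a minimal sketch of the gating pattern only, assuming nothing beyond a stock torch install; toy_op is a hypothetical stand-in for linear_fp8, not ColossalAI code.

    # Sketch of the version-gating pattern adopted in patch 1.
    # toy_op is hypothetical; linear_fp8 is the real target.
    from packaging.version import Version

    import torch

    SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.3.0")


    # With disable=True, torch.compile returns the function unchanged,
    # so a single definition serves both old and new torch versions.
    @torch.compile(mode="reduce-overhead", disable=not SUPPORT_TORCH_COMPILE)
    def toy_op(x: torch.Tensor) -> torch.Tensor:
        return x * 2

The payoff of this pattern is that callers and tests exercise one code path; the version check degrades the decorator rather than forking the API.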
From bf4aa02e9a1f34f75a978e14ebdf8558f76444ac Mon Sep 17 00:00:00 2001
From: ver217
Date: Tue, 13 Aug 2024 14:03:21 +0800
Subject: [PATCH 3/3] [fp8] fix linear test

---
 colossalai/quantization/fp8.py    | 12 ++++++++++--
 tests/test_fp8/test_fp8_linear.py | 13 ++++---------
 2 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index 429568d932f7..7b74cc673fa9 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -637,7 +637,6 @@ def forward(
         )[0]
         return out.reshape(*ctx.x_shape[:-1], w.shape[0])
 
-    @torch.compile(mode="reduce-overhead", disable=not SUPPORT_TORCH_COMPILE)
     @staticmethod
     def backward(ctx: Any, out_grad) -> Any:
         out_grad = out_grad.reshape(-1, out_grad.shape[-1])
@@ -664,5 +663,14 @@ def backward(ctx: Any, out_grad) -> Any:
         return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
 
 
-def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+@torch.compile(mode="reduce-overhead", disable=not SUPPORT_TORCH_COMPILE)
+def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     return _LinearFp8.apply(input, weight, bias)
+
+
+def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    out = _linear_fp8(input, weight, bias)
+    if SUPPORT_TORCH_COMPILE:
+        # avoid modifying the tensor created from cuda graph
+        out = out.clone()
+    return out
diff --git a/tests/test_fp8/test_fp8_linear.py b/tests/test_fp8/test_fp8_linear.py
index f7a534eb03ea..d035957f2a31 100644
--- a/tests/test_fp8/test_fp8_linear.py
+++ b/tests/test_fp8/test_fp8_linear.py
@@ -5,7 +5,6 @@
 
 from colossalai.accelerator import get_accelerator
 from colossalai.quantization.fp8 import linear_fp8
-from colossalai.testing import spawn
 from colossalai.utils import get_current_device
 
 D_IN, D_OUT = 16, 32
@@ -13,7 +12,10 @@
 DTYPE = torch.bfloat16
 
 
-def run_test(rank, world_size=None, port=None, use_bias: bool = False, use_batch: bool = False):
+@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
+@pytest.mark.parametrize("use_bias", [True, False])
+@pytest.mark.parametrize("use_batch", [True, False])
+def test_fp8_linear(use_bias: bool, use_batch: bool):
     # create tensors
     w = torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True)
     ref_w = w.clone().detach().requires_grad_()
@@ -41,10 +43,3 @@
     assert_close(w.grad, ref_w.grad, rtol=0.2, atol=0.1)
     if use_bias:
         assert_close(bias.grad, ref_bias.grad, rtol=0.2, atol=0.1)
-
-
-@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
-@pytest.mark.parametrize("use_bias", [True, False])
-@pytest.mark.parametrize("use_batch", [True, False])
-def test_fp8_linear(use_bias: bool, use_batch: bool):
-    spawn(run_test, nprocs=1, use_bias=use_bias, use_batch=use_batch)
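Note on patch 3: compiling the autograd Function's backward in place is dropped; instead the compile boundary moves to a thin private wrapper (_linear_fp8), and the public linear_fp8 clones the result. The reason is the "avoid modifying the tensor created from cuda graph" comment: with mode="reduce-overhead", torch.compile may capture the kernel into a CUDA graph whose output buffers are reused on every replay, so a caller that holds the returned tensor across calls could see it silently overwritten. Cloning copies the value out of graph-owned memory, which also lets the test revert to plain in-process pytest parametrization. A minimal sketch of the same pattern, assuming torch >= 2.3.0; compiled_op and safe_op are hypothetical stand-ins for _linear_fp8 and linear_fp8, not the library's API.

    # Sketch of the clone-after-CUDA-graph pattern from patch 3.
    import torch


    @torch.compile(mode="reduce-overhead")  # may capture/replay a CUDA graph
    def compiled_op(x: torch.Tensor) -> torch.Tensor:
        return x * 2


    def safe_op(x: torch.Tensor) -> torch.Tensor:
        out = compiled_op(x)
        # A CUDA graph replays into the same output storage on each call;
        # clone so the caller owns a buffer later replays cannot overwrite.
        return out.clone()

The clone costs one extra device copy per call, a deliberate trade for correctness when the output escapes the calling scope; it is skipped entirely when SUPPORT_TORCH_COMPILE is false and no graph capture can occur.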