From 4e86285481a7fd542ffc995b36969ad220611249 Mon Sep 17 00:00:00 2001
From: botbw
Date: Thu, 8 Aug 2024 05:11:03 +0000
Subject: [PATCH 1/4] [fp8] use torch compile (torch >= 2.4.0)

---
 colossalai/quantization/fp8.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index 53febd16c8f6..bf77e3c4b403 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional
+from typing import Any, Optional, Tuple
 
 import numpy as np
 import torch
@@ -7,7 +7,7 @@
 from torch.distributed import ReduceOp
 
 
-def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> (torch.Tensor, torch.Tensor):
+def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
     Args:
@@ -652,5 +652,13 @@ def backward(ctx: Any, out_grad) -> Any:
         return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
 
 
-def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-    return _LinearFp8.apply(input, weight, bias)
+if torch.__version__ >= (2, 4):  # TODO failed on torch < 2.4
+
+    @torch.compile(mode="reduce-overhead", fullgraph=True)
+    def linear_fp8(x: torch.Tensor, w: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        return _LinearFp8.apply(x, w, bias)
+
+else:
+
+    def linear_fp8(x: torch.Tensor, w: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        return _LinearFp8.apply(x, w, bias)

From f33b00a956af9de16209fda607d998289f2c391b Mon Sep 17 00:00:00 2001
From: botbw
Date: Thu, 8 Aug 2024 08:06:28 +0000
Subject: [PATCH 2/4] [fp8] set use_fast_accum in linear

---
 colossalai/quantization/fp8.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index bf77e3c4b403..f92b2411b35e 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -624,7 +624,13 @@ def forward(
         ctx.inv_scale_x = inv_scale_x
         ctx.inv_scale_w = inv_scale_w
         out = torch._scaled_mm(
-            x_fp8, ctx.w_fp8_t, bias=bias, out_dtype=ctx.out_dtype, scale_a=inv_scale_x, scale_b=inv_scale_w
+            x_fp8,
+            ctx.w_fp8_t,
+            bias=bias,
+            out_dtype=ctx.out_dtype,
+            scale_a=inv_scale_x,
+            scale_b=inv_scale_w,
+            use_fast_accum=True,
         )[0]
         return out.reshape(*ctx.x_shape[:-1], w.shape[0])
 
@@ -638,6 +644,7 @@ def backward(ctx: Any, out_grad) -> Any:
             out_dtype=ctx.out_dtype,
             scale_a=out_grad_scale,
             scale_b=ctx.inv_scale_w,
+            use_fast_accum=True,
         )[0]
         w_grad = torch._scaled_mm(
             out_grad_fp8.t().contiguous(),
@@ -645,6 +652,7 @@ def backward(ctx: Any, out_grad) -> Any:
             out_dtype=ctx.out_dtype,
             scale_a=out_grad_scale,
             scale_b=ctx.inv_scale_x,
+            use_fast_accum=True,
         )[0]
         bias_grad = None
         if ctx.has_bias:

From 67f538e1df4f9bb01c08b8af8b233d1cb64246e3 Mon Sep 17 00:00:00 2001
From: botbw
Date: Fri, 9 Aug 2024 06:52:49 +0000
Subject: [PATCH 3/4] [chore] formal version check

---
 colossalai/quantization/fp8.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index f92b2411b35e..606d898d5e0f 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -4,6 +4,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
+from packaging.version import Version
 from torch.distributed import ReduceOp
 
 
@@ -660,7 +661,7 @@ def backward(ctx: Any, out_grad) -> Any:
         return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
 
 
-if torch.__version__ >= (2, 4):  # TODO failed on torch < 2.4
+if Version(torch.__version__) >= Version("2.3.0"):  # TODO failed on torch < 2.3.0
 
     @torch.compile(mode="reduce-overhead", fullgraph=True)
     def linear_fp8(x: torch.Tensor, w: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:

From 8275f6a29178546c01ae9a8d523677cf91655fb0 Mon Sep 17 00:00:00 2001
From: botbw
Date: Fri, 9 Aug 2024 06:57:36 +0000
Subject: [PATCH 4/4] [chore] fix sig

---
 colossalai/quantization/fp8.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index 606d898d5e0f..cfbf1fcf7e40 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -664,10 +664,10 @@ def backward(ctx: Any, out_grad) -> Any:
 if Version(torch.__version__) >= Version("2.3.0"):  # TODO failed on torch < 2.3.0
 
     @torch.compile(mode="reduce-overhead", fullgraph=True)
-    def linear_fp8(x: torch.Tensor, w: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        return _LinearFp8.apply(x, w, bias)
+    def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        return _LinearFp8.apply(input, weight, bias)
 
 else:
 
-    def linear_fp8(x: torch.Tensor, w: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        return _LinearFp8.apply(x, w, bias)
+    def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        return _LinearFp8.apply(input, weight, bias)
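
A note on patch 2: use_fast_accum=True switches torch._scaled_mm to the fp8 tensor cores' fast-accumulation mode, trading a little accumulation accuracy for matmul throughput. Below is a minimal standalone sketch of the call shape against the pre-2.4 private API these patches target, where _scaled_mm returns an (out, amax) pair (hence the [0] indexing in the diff); the shapes and unit scales are hypothetical, for illustration only:

    import torch

    # Matrix dims should be multiples of 16 for _scaled_mm, and mat2 must be
    # column-major, obtained here by transposing a contiguous tensor.
    # Unit scales stand in for real amax-derived scale factors.
    a = torch.randn(16, 32, device="cuda").to(torch.float8_e4m3fn)
    b = torch.randn(64, 32, device="cuda").to(torch.float8_e4m3fn).t()
    one = torch.tensor(1.0, device="cuda")

    out, _ = torch._scaled_mm(
        a, b, out_dtype=torch.bfloat16, scale_a=one, scale_b=one, use_fast_accum=True
    )
    # out: (16, 64) bfloat16. On torch >= 2.4 the signature changed: the scales
    # became required positional arguments and only the output tensor is returned.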
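
A note on patch 3's "formal version check": comparing version strings lexicographically misorders releases once any component reaches two digits, which is why the gate moves from a raw torch.__version__ comparison to packaging.version.Version. A self-contained illustration:

    from packaging.version import Version

    # Lexicographic: "1" < "4" as characters, so "2.10.0" sorts below "2.4.0".
    assert ("2.10.0" >= "2.4.0") is False
    # packaging parses each component numerically, giving a semantic comparison.
    assert Version("2.10.0") >= Version("2.4.0")
    # Local build tags such as torch's "+cu121" are parsed too, not compared as text.
    assert Version("2.4.0+cu121") >= Version("2.4.0")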
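
Finally, a usage sketch: linear_fp8 keeps the F.linear contract (out = input @ weight.t() + bias, inferred from the reshape to w.shape[0] in forward), and patch 4's parameter rename means keyword calls written against F.linear work in either branch of the version gate. The example below is hypothetical and assumes a CUDA device with fp8-capable hardware and bf16 inputs:

    import torch
    from colossalai.quantization.fp8 import linear_fp8

    x = torch.randn(16, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    w = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    out = linear_fp8(input=x, weight=w)  # keyword names match F.linear after patch 4
    assert out.shape == (16, 128)
    out.sum().backward()  # backward also runs scaled fp8 matmuls

With mode="reduce-overhead", the first calls pay compilation cost and later calls replay a CUDA graph, which is where the torch.compile path introduced in patch 1 pays off.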