Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions colossalai/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from packaging.version import Version
from torch.distributed import ReduceOp

SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.3.0")


def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Expand Down Expand Up @@ -661,13 +663,14 @@ def backward(ctx: Any, out_grad) -> Any:
return x_grad.reshape(ctx.x_shape), w_grad, bias_grad


if Version(torch.__version__) >= Version("2.3.0"): # TODO failed on torch < 2.3.0

@torch.compile(mode="reduce-overhead", fullgraph=True)
def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _LinearFp8.apply(input, weight, bias)
@torch.compile(mode="reduce-overhead", disable=not SUPPORT_TORCH_COMPILE)
def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _LinearFp8.apply(input, weight, bias)

else:

def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _LinearFp8.apply(input, weight, bias)
def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
out = _linear_fp8(input, weight, bias)
if SUPPORT_TORCH_COMPILE:
# avoid modifying the tensor created from cuda graph
out = out.clone()
return out