Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions colossalai/quantization/fp8.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import Any, Optional
from typing import Any, Optional, Tuple

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from packaging.version import Version
from torch.distributed import ReduceOp


def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> (torch.Tensor, torch.Tensor):
def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
Args:
Expand Down Expand Up @@ -624,7 +625,13 @@ def forward(
ctx.inv_scale_x = inv_scale_x
ctx.inv_scale_w = inv_scale_w
out = torch._scaled_mm(
x_fp8, ctx.w_fp8_t, bias=bias, out_dtype=ctx.out_dtype, scale_a=inv_scale_x, scale_b=inv_scale_w
x_fp8,
ctx.w_fp8_t,
bias=bias,
out_dtype=ctx.out_dtype,
scale_a=inv_scale_x,
scale_b=inv_scale_w,
use_fast_accum=True,
)[0]
return out.reshape(*ctx.x_shape[:-1], w.shape[0])

Expand All @@ -638,19 +645,29 @@ def backward(ctx: Any, out_grad) -> Any:
out_dtype=ctx.out_dtype,
scale_a=out_grad_scale,
scale_b=ctx.inv_scale_w,
use_fast_accum=True,
)[0]
w_grad = torch._scaled_mm(
out_grad_fp8.t().contiguous(),
ctx.x_fp8.t().contiguous().t(),
out_dtype=ctx.out_dtype,
scale_a=out_grad_scale,
scale_b=ctx.inv_scale_x,
use_fast_accum=True,
)[0]
bias_grad = None
if ctx.has_bias:
bias_grad = out_grad.sum(0)
return x_grad.reshape(ctx.x_shape), w_grad, bias_grad


def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _LinearFp8.apply(input, weight, bias)
if Version(torch.__version__) >= Version("2.3.0"): # TODO failed on torch < 2.3.0

@torch.compile(mode="reduce-overhead", fullgraph=True)
def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _LinearFp8.apply(input, weight, bias)

else:

def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _LinearFp8.apply(input, weight, bias)