Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 97 additions & 16 deletions colossalai/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed import ReduceOp


def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Tensor):
def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> (torch.Tensor, torch.Tensor):
r"""
casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
Args:
Expand All @@ -23,7 +25,7 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Te
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
fp8_max = torch.finfo(fp8_type).max

if inp.dim() == 2:
if per_channel_scale:
per_channel_max = inp.abs().max(dim=-1).values.float()
per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
scale = fp8_max / per_channel_max[:, None]
Expand All @@ -37,7 +39,9 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Te
return ret, scale_inv


def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype) -> torch.Tensor:
def cast_from_fp8(
inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype, per_channel_scale=False
) -> torch.Tensor:
r"""
Args:
inp: should be a fp8 torch tensor in one of the types: [torch.float8_e4m3fn, torch.float8_e5m2].
Expand All @@ -49,20 +53,23 @@ def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dt
if inp.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]:
raise TypeError("Only float8_e4m3fn and float8_e5m2 are allowed.")

if inp.dim() == 2:
if per_channel_scale:
ret = scale_inv[:, None] * inp.float()
else:
ret = scale_inv * inp.float()
return ret.to(ret_type)


def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None:
def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None) -> None:
r"""
This is an in-place operation for compressed all_reduce using fp8.
It works like dist.all_reduce but during communication the data is cast to fp8 format.

Args:
tensor: torch.Tensor in fp32, fp16, bf16 datatype.
fp8_format: e4m3 or e5m2
op: ReduceOp.SUM or ReduceOp.AVG

Returns:
None
"""
Expand All @@ -72,18 +79,20 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None:
input_shape = tensor.shape
input_device = tensor.device
input_size = tensor.numel()
tensor = tensor.flatten()
flat_padded_x = tensor.flatten()

fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
assert op in [ReduceOp.SUM, ReduceOp.AVG], "op can only be ReduceOp.SUM or ReduceOp.AVG"

ret, scale = cast_to_fp8(tensor, fp8_format=fp8_format)
if flat_padded_x.size(0) % world_size != 0:
pad_size = world_size - flat_padded_x.size(0) % world_size
flat_padded_x = F.pad(flat_padded_x, (0, pad_size))

fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
ret, scale = cast_to_fp8(flat_padded_x, fp8_format=fp8_format)

inp = ret.view(torch.uint8)
input_chunks = list(torch.chunk(inp, world_size, dim=0))
if dist.get_rank() == world_size - 1:
output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)]
else:
output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)]
output_chunks = list(torch.chunk(torch.empty_like(inp), world_size, dim=0))
dist.all_to_all(output_chunks, input_chunks, group=group)
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
dist.all_gather(scale_list, scale, group=group)
Expand All @@ -92,15 +101,18 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None:
out = out.view(fp8_type)
summed_out += cast_from_fp8(out, scale, input_type)

if op == ReduceOp.AVG:
summed_out.div_(world_size)

summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)
dist.all_gather(scale_list, scale, group=group)

tensor_list = list(torch.chunk(torch.empty(input_size, device=input_device, dtype=torch.uint8), world_size, dim=0))
tensor_list = [torch.empty_like(summed_out_fp8.view(torch.uint8)) for _ in range(world_size)]
dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8), group=group)
for i in range(world_size):
tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
tensor_out = torch.cat(tensor_list, dim=0)
tensor.data = tensor_out.view(input_shape).to(input_type)
out = torch.cat(tensor_list, dim=0)
tensor.copy_(out[:input_size].view(input_shape).to(input_type))


def cast_to_fp8_pipeline(inp: Any) -> None:
Expand Down Expand Up @@ -276,5 +288,74 @@ def all_gather_into_tensor_flat_fp8(
dist.all_gather_into_tensor(buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group)
numel = np.prod(output_shape)
valid_buffer = buffer[:numel].reshape(output_shape)
valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type)
valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type, per_channel_scale=(len(output_shape) == 2))
output_tensor[:numel].copy_(valid_buffer.view(-1))


def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2"):
    r"""
    All-to-all exchange of a list of tensors, with the payload cast to fp8 for communication.

    Each entry of ``input_list`` is cast to fp8 independently via ``cast_to_fp8``, the
    fp8 data (viewed as uint8) and the matching inverse scales are exchanged with two
    ``dist.all_to_all`` calls, and each received tensor is cast back and copied into the
    corresponding entry of ``output_list``.

    Args:
        output_list: list of pre-allocated tensors that receive the results in place.
        input_list: list of tensors to send, indexed from 0 to world_size - 1.
        group: process group to communicate over; passed through to torch.distributed.
        fp8_format: e4m3 or e5m2

    Returns:
        None
    """
    num_ranks = dist.get_world_size(group)

    # The restored dtype is taken from the first input and applied to every received tensor.
    restore_dtype = input_list[0].dtype
    if fp8_format == "e4m3":
        wire_dtype = torch.float8_e4m3fn
    else:
        wire_dtype = torch.float8_e5m2

    send_scales = []
    send_payloads = []
    for rank in range(num_ranks):
        quantized, scale_inv = cast_to_fp8(input_list[rank], fp8_format=fp8_format)
        send_scales.append(scale_inv)
        # The fp8 data is sent as raw uint8
        # (presumably because the collective does not accept fp8 dtypes -- TODO confirm).
        send_payloads.append(quantized.view(torch.uint8))

    # Receive buffers mirror the send buffers, i.e. the received tensors are expected
    # to have the same shapes as the local inputs.
    recv_scales = [torch.empty_like(s) for s in send_scales]
    recv_payloads = [torch.empty_like(p) for p in send_payloads]
    dist.all_to_all(recv_payloads, send_payloads, group=group)
    dist.all_to_all(recv_scales, send_scales, group=group)

    for rank, (payload, peer_scale) in enumerate(zip(recv_payloads, recv_scales)):
        restored = cast_from_fp8(payload.view(wire_dtype), peer_scale, restore_dtype)
        output_list[rank].copy_(restored)


def all_to_all_single_fp8(output_tensor, input_tensor, group=None, fp8_format="e5m2"):
    r"""
    Single-tensor all-to-all, with the payload cast to fp8 for communication.

    The whole input is cast to fp8 with one scale, exchanged with
    ``dist.all_to_all_single``, and every rank's inverse scale is collected with
    ``dist.all_gather``. The received buffer is then cut along dim 0 into world_size
    slices of equal length, and slice ``i`` is cast back using the scale gathered from rank ``i``.

    Args:
        output_tensor: pre-allocated tensor that receives the result in place.
        input_tensor: tensor to send; dim 0 is divided evenly across ranks
            (assumes ``size(0)`` is divisible by world_size -- TODO confirm with callers).
        group: process group to communicate over; passed through to torch.distributed.
        fp8_format: e4m3 or e5m2

    Returns:
        None
    """
    num_ranks = dist.get_world_size(group)

    slice_len = input_tensor.size(0) // num_ranks
    restore_dtype = input_tensor.dtype

    quantized, local_scale = cast_to_fp8(input_tensor, fp8_format=fp8_format)
    wire_dtype = quantized.dtype

    # The fp8 data is exchanged as raw uint8.
    send_buf = quantized.view(torch.uint8)
    recv_buf = torch.empty_like(send_buf)
    peer_scales = [torch.empty_like(local_scale) for _ in range(num_ranks)]
    dist.all_to_all_single(recv_buf, send_buf, group=group)
    dist.all_gather(peer_scales, local_scale, group=group)

    restored_parts = []
    for rank in range(num_ranks):
        begin = slice_len * rank
        end = slice_len * (rank + 1)
        # Slice ``rank`` is paired with the scale gathered from that same rank.
        part = recv_buf[begin:end].view(wire_dtype)
        restored_parts.append(cast_from_fp8(part, peer_scales[rank], restore_dtype))

    output_tensor.copy_(torch.concatenate(restored_parts, dim=0))


def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"):
    r"""
    Gather a tensor from every rank, with the payload cast to fp8 for communication.

    The local tensor is cast to fp8 via ``cast_to_fp8``, then the fp8 data (viewed as
    uint8) and the inverse scale are each collected with ``dist.all_gather``, so every
    rank in the group receives every rank's data. Each gathered tensor is cast back
    with its sender's scale and copied into the corresponding entry of ``output_list``.

    Args:
        output_list: list of pre-allocated tensors that receive the results in place.
        input_: local tensor to contribute.
        group: process group to communicate over; passed through to torch.distributed.
        fp8_format: e4m3 or e5m2

    Returns:
        None
    """
    num_ranks = dist.get_world_size(group)

    restore_dtype = input_.dtype
    quantized, local_scale = cast_to_fp8(input_, fp8_format=fp8_format)
    wire_dtype = quantized.dtype

    # The fp8 data is gathered as raw uint8.
    payload = quantized.view(torch.uint8)
    gathered_payloads = [torch.empty_like(payload) for _ in range(num_ranks)]
    # One single-element scale slot per rank, on the same device as the payload.
    gathered_scales = [
        torch.ones(1, dtype=local_scale.dtype, device=payload.device) for _ in range(num_ranks)
    ]
    dist.all_gather(gathered_payloads, payload, group=group)
    dist.all_gather(gathered_scales, local_scale, group=group)

    for rank, (peer_payload, peer_scale) in enumerate(zip(gathered_payloads, gathered_scales)):
        restored = cast_from_fp8(peer_payload.view(wire_dtype), peer_scale, restore_dtype)
        output_list[rank].copy_(restored)
Loading