Merged
56 changes: 56 additions & 0 deletions colossalai/quantization/fp8.py
@@ -103,6 +103,62 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None:
    tensor.data = tensor_out.view(input_shape).to(input_type)


def all_to_all_single_fp8(
    output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
) -> None:
    r"""
    This is an in-place operation for compressed all_to_all_single using fp8.
    It works like dist.all_to_all_single, but the data is cast to fp8 format during communication.
    Args:
        output: torch.Tensor that receives the result; written in place.
        input: torch.Tensor in fp32, fp16 or bf16 datatype.
        output_split_sizes, input_split_sizes: optional per-rank split sizes along dim 0.
        fp8_format: e4m3 or e5m2
        async_op: accepted for signature parity with dist.all_to_all_single; the call is currently synchronous.
    Returns:
        None
    """
    world_size = dist.get_world_size(group=group)
    input_type = input.dtype
    input_shape = input.shape
    input_device = input.device
    input = input.flatten()

    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2

    # Quantize the whole local input with a single per-tensor scale.
    ret, scale = cast_to_fp8(input, fp8_format=fp8_format)

    inp = ret.view(torch.uint8)
    if input_split_sizes is not None:
        # Split sizes are given in rows of dim 0, so scale them by the flattened row size.
        input_split_sizes = [input_split_sizes[i] * np.prod(input_shape[1:]) for i in range(world_size)]
        input_chunks = list(torch.split(inp, input_split_sizes))
    else:
        input_chunks = list(torch.chunk(inp, world_size, dim=0))

    if output_split_sizes is not None:
        output_chunks = [
            torch.empty((output_split_sizes[i] * np.prod(input_shape[1:]),), device=input_device, dtype=inp.dtype)
            for i in range(world_size)
        ]
    else:
        # Without explicit split sizes, torch.chunk may make the last chunk smaller; rank world_size - 1
        # receives every sender's last chunk, so its receive buffers are sized from the local last chunk.
        if dist.get_rank() == world_size - 1:
            output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)]
        else:
            output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)]

    dist.all_to_all(output_chunks, input_chunks, group=group)
    # Gather every rank's per-tensor scale so each received chunk is dequantized with its sender's scale.
    scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
    dist.all_gather(scale_list, scale, group=group)
    cast_output_chunk = [
        cast_from_fp8(out.view(fp8_type), scale, input_type) for scale, out in zip(scale_list, output_chunks)
    ]

    tensor_out = torch.cat(cast_output_chunk, dim=0)
    outputs_shape = list(input_shape)
    if output_split_sizes is not None:
        outputs_shape[0] = sum(output_split_sizes)
    else:
        outputs_shape = input_shape
    output.data = tensor_out.view(outputs_shape).to(input_type)


def cast_to_fp8_pipeline(inp: Any) -> None:
    """
    Cast the hidden_states tensor of the inp object to fp8 format before p2p communication in the pipeline.
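Note that cast_to_fp8 and cast_from_fp8 are defined earlier in fp8.py and are not shown in this diff. For orientation only, below is a minimal sketch of the per-tensor scaling scheme that all_to_all_single_fp8 appears to assume; the _sketch helpers are hypothetical stand-ins, not the library's actual implementations, and require a PyTorch build with fp8 dtypes.

import torch

def cast_to_fp8_sketch(x: torch.Tensor, fp8_format: str = "e5m2"):
    # Assumed behavior: pick one scale per tensor so the max magnitude fits the fp8 range,
    # quantize, and return the fp8 tensor together with its dequantization scale.
    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    fp8_max = torch.finfo(fp8_type).max
    scale = x.abs().max().clamp(min=1e-12).float() / fp8_max
    return (x.float() / scale).to(fp8_type), scale

def cast_from_fp8_sketch(x_fp8: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Assumed behavior: dequantize with the sender's scale and restore the original dtype.
    return (x_fp8.float() * scale).to(dtype)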
67 changes: 67 additions & 0 deletions tests/test_fp8/test_all_to_all_single.py
@@ -0,0 +1,67 @@
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group
from torch.testing import assert_close

from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import all_to_all_single_fp8
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


@parameterize("shape", [(4,), (1, 8, 16), (4, 8, 16)])
@parameterize("dtype", [torch.bfloat16])
def check_all2all(shape, dtype):
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
output = torch.empty_like(x)
output_fp8 = torch.empty_like(x)
dist.all_to_all_single(output, x, group=_get_default_group(), async_op=False)
all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), async_op=False)
assert_close(output, output_fp8, rtol=0.1, atol=0.1)


@parameterize("shape", [(8, 8, 16)])
@parameterize("dtype", [torch.bfloat16, torch.float16])
def check_all2all_uneven(shape, dtype):
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
input_split_sizes = [3, 3, 1, 1]
if dist.get_rank() in [0, 1]:
output_split_sizes = [3, 3, 3, 3]
else:
output_split_sizes = [1, 1, 1, 1]
output_shape = list(shape)
output_shape[0] = sum(output_split_sizes)
output = torch.empty(output_shape, device=x.device, dtype=x.dtype)
output_fp8 = torch.empty(output_shape, device=x.device, dtype=x.dtype)
dist.all_to_all_single(
output,
x,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=_get_default_group(),
async_op=False,
)
all_to_all_single_fp8(
output_fp8,
x,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=_get_default_group(),
async_op=False,
)
assert_close(output, output_fp8, rtol=0.1, atol=0.1)


def run_dist(rank, world_size, port):
launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_all2all()
check_all2all_uneven()


@rerun_if_address_is_in_use()
def test_all_to_all_single():
spawn(run_dist, 4)


if __name__ == "__main__":
test_all_to_all_single()
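The loose rtol=0.1 / atol=0.1 passed to assert_close is driven by the fp8 cast itself rather than by the collective. A rough back-of-the-envelope check of that tolerance, assuming the per-tensor scaling sketched earlier:

import torch

# e5m2 keeps only 2 mantissa bits, so rounding to the nearest representable value can
# already move an element by up to eps / 2 = 12.5% of its magnitude. For the [0, 1)
# values produced by torch.rand, that error stays within rtol * |x| + atol = 0.1 * |x| + 0.1.
print(torch.finfo(torch.float8_e5m2).eps)  # 0.25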