4 changes: 3 additions & 1 deletion colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -945,7 +945,8 @@ class HybridParallelPlugin(PipelinePluginBase):
gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
make_vocab_size_divisible_by (int, optional): It's used when padding the vocabulary size, to make it choose a faster kernel. Defaults to 64.
overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism
overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism.
fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism.
"""

def __init__(
@@ -1119,6 +1120,7 @@ def __init__(
parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,
fp8_communication=fp8_communication,
)
self.amp_config = dict(
initial_scale=initial_scale,
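For context, here is a minimal usage sketch of the new flag. This is a hedged example: the tp/pp sizes and the launch call are illustrative (older ColossalAI versions may require `colossalai.launch_from_torch(config={})`), and only `fp8_communication` is the parameter introduced by this PR.

```python
# Sketch: enabling fp8 communication through HybridParallelPlugin.
# Only fp8_communication is new here; the other arguments are illustrative.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # assumes a torchrun launch
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    fp8_communication=True,  # cast model-parallel communication to fp8
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, _ = booster.boost(...)
```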
54 changes: 43 additions & 11 deletions colossalai/quantization/fp8.py
@@ -12,7 +12,6 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Te
scale: scaling factor for fp8 casting. If it is None, it is computed automatically. Per-channel scaling
is applied if the input tensor is 2-dimensional; otherwise, per-tensor scaling is applied.
fp8_format: e4m3 or e5m2

Returns:
Tuples: A tuple (fp8_tensor, scale)
"""
@@ -39,12 +38,10 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Te

def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype) -> torch.Tensor:
r"""

Args:
inp: should be an fp8 torch tensor in one of the types: [torch.float8_e4m3fn, torch.float8_e5m2].
scale_inv: scaling factor returned by the cast_to_fp8 function.
ret_type: the datatype of the returned tensor.

Returns:
torch.Tensor
"""
@@ -58,20 +55,18 @@ def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dt
return ret.to(ret_type)
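A quick round-trip sketch of these two helpers (assumes a CUDA device with fp8 support; restored values match the input only approximately, since fp8 casting is lossy):

```python
# Sketch: cast a tensor to fp8 and back.
import torch
from colossalai.quantization.fp8 import cast_to_fp8, cast_from_fp8

x = torch.randn(4, 8, dtype=torch.bfloat16, device="cuda")
x_fp8, scale = cast_to_fp8(x, fp8_format="e4m3")       # per-channel scaling for 2-D input
x_restored = cast_from_fp8(x_fp8, scale, torch.bfloat16)
assert x_restored.shape == x.shape                     # lossy round trip
```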


def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None:
r"""
This is an in-place operation for compressed all_reduce using fp8.
It works like dist.all_reduce but during communication the data is cast to fp8 format.

Args:
tensor: torch.Tensor in fp32, fp16, bf16 datatype.
fp8_format: e4m3 or e5m2.
group: the process group to work on. If None, the default process group is used.

Returns:
None
"""

world_size = dist.get_world_size()
world_size = dist.get_world_size(group=group)
input_type = tensor.dtype
input_shape = tensor.shape
input_device = tensor.device
@@ -88,19 +83,19 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)]
else:
output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)]
dist.all_to_all(output_chunks, input_chunks)
dist.all_to_all(output_chunks, input_chunks, group=group)
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
dist.all_gather(scale_list, scale)
dist.all_gather(scale_list, scale, group=group)
summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
for scale, out in zip(scale_list, output_chunks):
out = out.view(fp8_type)
summed_out += cast_from_fp8(out, scale, input_type)

summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)
dist.all_gather(scale_list, scale)
dist.all_gather(scale_list, scale, group=group)

tensor_list = list(torch.chunk(torch.empty(input_size, device=input_device, dtype=torch.uint8), world_size, dim=0))
dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8))
dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8), group=group)
for i in range(world_size):
tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
tensor_out = torch.cat(tensor_list, dim=0)
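A minimal sketch of calling the compressed all-reduce (assumes a multi-GPU job launched with torchrun and NCCL available; the tensor size is arbitrary):

```python
# Sketch: fp8-compressed in-place all-reduce across all ranks.
import torch
import torch.distributed as dist
from colossalai.quantization.fp8 import all_reduce_fp8

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
t = torch.randn(1024, dtype=torch.float16, device="cuda")
all_reduce_fp8(t, fp8_format="e5m2")  # in place; payload travels over the wire as fp8
```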
@@ -170,3 +165,40 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:

if del_metadata:
del inp["fp8_scale"]


def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2") -> None:
r"""
This is an in-place operation for compressed reduce_scatter using fp8.
It works like dist.reduce_scatter but during communication the data is cast to fp8 format.

Args:
output: torch.Tensor in fp32, fp16, bf16 datatype; receives this rank's reduced chunk.
input_list: list of torch.Tensor chunks to be reduced and scattered, one per rank.
group: the process group to work on.
fp8_format: e4m3 or e5m2.

Returns:
None
"""

input_type = output.dtype

fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
scale_list = []
cast_input_list = []
output_chunks = []
output_scale_list = []
for input in input_list:
ret, scale = cast_to_fp8(input, fp8_format=fp8_format)
scale_list.append(scale)
ret = ret.view(torch.uint8)
cast_input_list.append(ret)
output_chunks.append(torch.empty_like(ret))
output_scale_list.append(torch.empty_like(scale))
dist.all_to_all(output_chunks, cast_input_list, group=group)
dist.all_to_all(output_scale_list, scale_list, group=group)

summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
for scale, out in zip(output_scale_list, output_chunks):
out = out.view(fp8_type)
summed_out += cast_from_fp8(out, scale, input_type)
output.data = summed_out
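And a matching sketch for the new reduce-scatter helper (assumes the same torchrun/NCCL setup; the chunk size of 256 is arbitrary):

```python
# Sketch: fp8-compressed reduce-scatter; each rank receives one reduced chunk.
import torch
import torch.distributed as dist
from colossalai.quantization.fp8 import reduce_scatter_fp8

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
world_size = dist.get_world_size()
full = torch.randn(world_size * 256, dtype=torch.float16, device="cuda")
chunks = list(full.chunk(world_size))  # one input chunk per rank
out = torch.empty_like(chunks[0])      # receives this rank's reduced result
reduce_scatter_fp8(out, chunks, group=dist.group.WORLD, fp8_format="e5m2")
```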