Merged

Changes from all commits (24 commits):
decf2e8
support fp8_communication in the Torch DDP grad comm, FSDP grad comm,…
BurkeHulk Jul 19, 2024
84a6be0
Merge remote-tracking branch 'origin/feature/fp8_comm' into feature/f…
BurkeHulk Jul 19, 2024
de811e7
Merge branch 'feature/fp8_comm' into feature/fp8_comm
BurkeHulk Jul 26, 2024
70e3f8b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 26, 2024
5a5f1b0
implement communication hook for FSDP params all-gather
BurkeHulk Jul 26, 2024
6264da8
added unit test for fp8 operators
BurkeHulk Jul 26, 2024
1d5e5be
support fp8 communication in GeminiPlugin
BurkeHulk Jul 26, 2024
ef68dbf
update training scripts to support fsdp and fp8 communication
BurkeHulk Jul 26, 2024
694d125
fixed some minor bugs observed in unit test
BurkeHulk Jul 26, 2024
a92da21
add all_gather_into_tensor_flat_fp8
BurkeHulk Jul 26, 2024
62c99f4
Merge remote-tracking branch 'origin/feature/fp8_comm' into feature/f…
BurkeHulk Jul 26, 2024
c0c1a30
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 26, 2024
ea295b9
Merge branch 'feature/fp8_comm' into feature/fp8_comm
BurkeHulk Aug 5, 2024
e3d5455
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 5, 2024
1ed1355
add skip the test if torch < 2.2.0
BurkeHulk Aug 7, 2024
2080b73
Merge remote-tracking branch 'origin/feature/fp8_comm' into feature/f…
BurkeHulk Aug 7, 2024
edd0d64
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 7, 2024
53ea42b
add skip the test if torch < 2.2.0
BurkeHulk Aug 7, 2024
0314698
add skip the test if torch < 2.2.0
BurkeHulk Aug 7, 2024
65a8653
Merge branch 'hpcaitech:feature/fp8_comm' into feature/fp8_comm
BurkeHulk Aug 7, 2024
b4aa1e4
add fp8_comm flag
BurkeHulk Aug 7, 2024
90c4280
rebase latest fp8 operators
BurkeHulk Aug 7, 2024
6079b13
rebase latest fp8 operators
BurkeHulk Aug 7, 2024
6a93e23
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 7, 2024
2 changes: 2 additions & 0 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -364,6 +364,7 @@ def __init__(
enable_sequence_overlap: bool = False,
enable_async_reduce: bool = True,
verbose: bool = False,
fp8_communication: bool = False,
) -> None:
super().__init__()
assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
@@ -395,6 +396,7 @@ def __init__(
master_weights=master_weights,
max_prefetch=max_prefetch,
enable_async_reduce=enable_async_reduce,
fp8_communication=fp8_communication,
)
self.zero_optim_config = dict(
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
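Usage sketch (not part of the diff): enabling the new fp8_communication flag through GeminiPlugin, assuming the standard ColossalAI launch/boost workflow; the model and optimizer below are placeholders.

import torch.nn as nn
from torch.optim import Adam

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch()

model = nn.Linear(1024, 1024)  # placeholder model
optimizer = Adam(model.parameters())

# fp8_communication is the flag added in this PR; all other arguments keep their defaults
plugin = GeminiPlugin(precision="bf16", fp8_communication=True)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)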
7 changes: 7 additions & 0 deletions colossalai/booster/plugin/torch_ddp_plugin.py
@@ -177,6 +177,7 @@ def __init__(
check_reduction: bool = False,
gradient_as_bucket_view: bool = False,
static_graph: bool = False,
fp8_communication: bool = False,
) -> None:
super().__init__()
self.ddp_kwargs = dict(
@@ -187,6 +188,7 @@ def __init__(
gradient_as_bucket_view=gradient_as_bucket_view,
static_graph=static_graph,
)
self.fp8_communication = fp8_communication

def support_no_sync(self) -> bool:
return True
@@ -226,6 +228,11 @@ def configure(
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer = OptimizerWrapper(optimizer)

if self.fp8_communication:
from colossalai.quantization.fp8 import fp8_compress_ddp_grad_comm_hook_async

model.module.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_async)

return model, optimizer, criterion, dataloader, lr_scheduler

def control_checkpoint_io(self) -> bool:
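For reference, when the flag is set, configure() registers the FP8 compression hook exactly as one would on a plain DistributedDataParallel model; a minimal sketch (process-group setup and model are placeholders):

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from colossalai.quantization.fp8 import fp8_compress_ddp_grad_comm_hook_async

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(1024, 1024).cuda()  # placeholder model
ddp_model = DDP(model, device_ids=[torch.cuda.current_device()])

# None selects the default process group; gradients are cast to FP8 before being exchanged
ddp_model.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_async)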
15 changes: 15 additions & 0 deletions colossalai/booster/plugin/torch_fsdp_plugin.py
@@ -298,6 +298,7 @@ def __init__(
ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
param_init_fn: Optional[Callable[[nn.Module], None]] = None,
sync_module_states: bool = False,
fp8_communication: bool = False,
):
super().__init__()
self.fsdp_kwargs = dict(
@@ -311,6 +312,7 @@ def __init__(
param_init_fn=param_init_fn,
sync_module_states=sync_module_states,
)
self.fp8_communication = fp8_communication

else:
raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
@@ -347,6 +349,19 @@ def configure(
# wrap the model with PyTorch FSDP
fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)

if self.fp8_communication:
from colossalai.quantization.utils import patch_fsdp_params_comm_hook

patch_fsdp_params_comm_hook()

from colossalai.quantization.fp8 import fp8_compress_fsdp_params_comm_hook

fsdp_model.module.register_params_comm_hook(None, fp8_compress_fsdp_params_comm_hook)

from colossalai.quantization.fp8 import fp8_compress_fsdp_grad_comm_hook

fsdp_model.module.register_comm_hook(None, fp8_compress_fsdp_grad_comm_hook)

if optimizer is not None:
if len(optimizer.param_groups) > 1:
warnings.warn(
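The FSDP path needs the extra patch_fsdp_params_comm_hook() call because register_params_comm_hook is not an official FSDP API; a sketch of the wiring that configure() performs, assuming an already-initialized process group and a placeholder model:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from colossalai.quantization.fp8 import (
    fp8_compress_fsdp_grad_comm_hook,
    fp8_compress_fsdp_params_comm_hook,
)
from colossalai.quantization.utils import patch_fsdp_params_comm_hook

fsdp_model = FSDP(model, device_id=torch.cuda.current_device())  # `model` is a placeholder nn.Module

patch_fsdp_params_comm_hook()  # monkey-patches FSDP to expose register_params_comm_hook
fsdp_model.register_params_comm_hook(None, fp8_compress_fsdp_params_comm_hook)  # FP8 params all-gather
fsdp_model.register_comm_hook(None, fp8_compress_fsdp_grad_comm_hook)           # FP8 grad reduce-scatter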
229 changes: 204 additions & 25 deletions colossalai/quantization/fp8.py
@@ -15,6 +15,7 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -
scale: scaling factor for fp8 casting. If it is None, then it is computed automatically. Per-channel scaling
is applied if the input tensor is 2-dimensional; otherwise, per-tensor scaling is applied.
fp8_format: e4m3 or e5m2

Returns:
Tuples: A tuple (fp8_tensor, scale)
"""
@@ -29,12 +30,13 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -
per_channel_max = inp.abs().max(dim=-1).values.float()
per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
scale = fp8_max / per_channel_max[:, None]
scale_inv = per_channel_max / fp8_max
else:
per_tensor_max = inp.abs().max().float()
per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
scale = fp8_max / per_tensor_max
scale_inv = 1.0 / scale

scale_inv = 1.0 / scale
ret = (scale * inp.float()).to(fp8_type)
return ret, scale_inv
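A quick round-trip sketch for the casting helpers above (assumes a CUDA device with FP8 support; cast_from_fp8 is the matching decompression helper used throughout this file):

import torch
from colossalai.quantization.fp8 import cast_to_fp8, cast_from_fp8

x = torch.randn(64, 128, dtype=torch.bfloat16, device="cuda")

# per-tensor scaling
x_fp8, scale_inv = cast_to_fp8(x, fp8_format="e4m3")
x_back = cast_from_fp8(x_fp8, scale_inv, torch.bfloat16)

# per-channel scaling for 2-D tensors
x_fp8_pc, scale_inv_pc = cast_to_fp8(x, fp8_format="e4m3", per_channel_scale=True)
x_back_pc = cast_from_fp8(x_fp8_pc, scale_inv_pc, torch.bfloat16, per_channel_scale=True)

print((x - x_back).abs().max(), (x - x_back_pc).abs().max())  # quantization error; per-channel is usually tighter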

@@ -185,7 +187,11 @@ def cast_to_fp8_pipeline(inp: Any) -> None:
return

assert "hidden_states" in inp, "required by pipeline parallelism."
assert (
inp["hidden_states"].size(-1) % 2 == 0
), "tensor size(-1) must be divisible by 2 to view Float8_e4m3fn as BFloat16 or Float16"
inp_tensor = inp["hidden_states"]
inp_dtype = inp_tensor.dtype

min_val, max_val = inp_tensor.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs())
@@ -206,6 +212,7 @@ def cast_to_fp8_pipeline(inp: Any) -> None:
inp_tensor.data = q_tensor.to(fp8_type).view(fp8_view_type)

inp["fp8_scale"] = scale.float().reciprocal()
inp["dtype"] = torch.zeros_like(scale).to(inp_dtype)


def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
@@ -230,10 +237,11 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
else:
raise TypeError("Only float16, bfloat16 are implemented.")

inp_tensor.data = inp_tensor.data.view(fp8_type).to(torch.float16) * scale
inp_tensor.data = inp_tensor.data.view(fp8_type).to(inp["dtype"]) * scale

if del_metadata:
del inp["fp8_scale"]
del inp["dtype"]


def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2") -> None:
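A round-trip sketch for the pipeline casting pair above (assumes a CUDA device; the fp8_scale and dtype metadata keys are added and removed exactly as in the hunks above):

import torch
from colossalai.quantization.fp8 import cast_to_fp8_pipeline, cast_from_fp8_pipeline

inp = {"hidden_states": torch.randn(4, 128, dtype=torch.bfloat16, device="cuda")}

cast_to_fp8_pipeline(inp)    # hidden_states now holds FP8 bits viewed as bf16, plus fp8_scale/dtype metadata
# ... the dict would be sent to the next pipeline stage here ...
cast_from_fp8_pipeline(inp)  # restores hidden_states to bf16 and deletes the metadata keys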
@@ -273,6 +281,199 @@ def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2"
output.data = summed_out


def fp8_compress_ddp_grad_comm_hook_async(
process_group: dist.ProcessGroup,
bucket: dist.GradBucket,
fp8_format: str = "e5m2",
) -> torch.futures.Future[torch.Tensor]:
"""
Compress gradients by casting the ``GradBucket`` to an FP8 floating-point format and dividing by the process group size.

This DDP communication hook implements a simple gradient compression approach that casts the ``GradBucket`` tensor
to an FP8 floating-point format (``torch.float8_e5m2`` or ``torch.float8_e4m3fn``) and then divides it
by the process group size.
Once the compressed gradient tensors are all-reduced, the chained callback ``decompress`` casts them back
to the input data type (such as ``float32``).

Example::
>>> ddp_model.register_comm_hook(process_group, fp8_compress_ddp_grad_comm_hook_async)
"""
group_to_use = process_group if process_group is not None else dist.group.WORLD

input_tensor = bucket.buffer()
world_size = dist.get_world_size()
input_type = input_tensor.dtype
input_device = input_tensor.device
flat_padded_x = input_tensor.flatten()

if flat_padded_x.size(0) % world_size != 0:
pad_size = world_size - flat_padded_x.size(0) % world_size
flat_padded_x = F.pad(flat_padded_x, (0, pad_size))

fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
ret, scale = cast_to_fp8(flat_padded_x, fp8_format=fp8_format)

inp = ret.view(torch.uint8)
output_chunks_single = torch.empty_like(inp)
split_sizes = [inp.numel() // world_size for _ in range(world_size)]
fut0 = dist.all_to_all_single(
output_chunks_single,
inp,
output_split_sizes=split_sizes,
input_split_sizes=split_sizes,
group=group_to_use,
async_op=True,
).get_future()

scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
fut1 = dist.all_gather_into_tensor(
torch.cat(scale_list, dim=0), scale, group=group_to_use, async_op=True
).get_future()
all_to_all_fut = torch.futures.collect_all([fut0, fut1])

def sum_and_allgather(fut):
output_chunks_single = fut.value()[0].wait()[0]
scale_list_single = fut.value()[1].wait()[0]

output_chunks = list(torch.chunk(output_chunks_single, world_size, dim=0))
scale_list = scale_list_single.chunk(world_size, dim=0)

summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
for scale, out in zip(scale_list, output_chunks):
out = out.view(fp8_type)
summed_out += cast_from_fp8(out, scale, input_type)
summed_out.div_(world_size)

summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)

tensor_list_single = torch.empty(summed_out_fp8.size(0) * world_size, device=input_device, dtype=torch.uint8)
fut2 = dist.all_gather_into_tensor(
tensor_list_single, summed_out_fp8.view(torch.uint8), group=group_to_use, async_op=True
).get_future()

scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
fut3 = dist.all_gather_into_tensor(
torch.cat(scale_list, dim=0), scale, group=group_to_use, async_op=True
).get_future()
fut_combined2 = torch.futures.collect_all([fut2, fut3])
return fut_combined2

def decompress(fut):
tensor_list_single = fut.value().wait()[0].value()[0]
scale_list_single = fut.value().wait()[1].value()[0]

tensor_list = list(torch.chunk(tensor_list_single, world_size, dim=0))
scale_list = scale_list_single.chunk(world_size, dim=0)

for i in range(world_size):
tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
out = torch.cat(tensor_list, dim=0)

input_tensor_size = input_tensor.numel()
input_shape = input_tensor.shape
out = out[:input_tensor_size]

input_tensor.copy_(out.view(input_shape).to(input_type))
return input_tensor

return all_to_all_fut.then(sum_and_allgather).then(decompress)


def fp8_compress_ddp_grad_comm_hook_sync(
process_group: dist.ProcessGroup,
bucket: dist.GradBucket,
fp8_format="e5m2",
) -> torch.futures.Future[torch.Tensor]:
"""
Return a future that wraps the input after it has been all-reduced. However, the all-reduce communication is synchronous,
which breaks the overlap between all-reduce communication and backward computation.

This hook should **only** be used for debugging purposes, not for normal gradient synchronization.
For an asynchronous implementation, use fp8_compress_ddp_grad_comm_hook_async instead.

Example::
>>> # xdoctest: +SKIP
>>> ddp_model.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_sync)
"""

buffer = bucket.buffer()
all_reduce_fp8(buffer, fp8_format=fp8_format)

fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
fut.set_result(bucket.buffer())

return fut


def fp8_compress_fsdp_grad_comm_hook(
state: object,
unsharded_gradient_flattened: torch.Tensor,
sharded_gradient: torch.Tensor,
group=None,
fp8_format="e5m2",
) -> None:
"""
This communication hook implements a simple gradient compression approach that casts the ``unsharded_gradient_flattened`` tensor
to an FP8 floating-point format (``torch.float8_e5m2`` or ``torch.float8_e4m3fn``) and then performs reduce-scatter logic
by using all_to_all and all_gather among the process group.

Example::
>>> fsdp_model.register_comm_hook(None, fp8_compress_fsdp_grad_comm_hook)
"""
grad = unsharded_gradient_flattened
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
input_type = grad.dtype
input_device = grad.device
world_size = dist.get_world_size(group=group)

grad_fp8, scale = cast_to_fp8(grad, fp8_format=fp8_format)
uint8_buffer = torch.empty_like(grad_fp8).view(torch.uint8)
dist.all_to_all_single(uint8_buffer, grad_fp8.view(torch.uint8), group=group)

scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
dist.all_gather(scale_list, scale, group=group)

buffer_list = list(torch.chunk(uint8_buffer.view(fp8_type), world_size, dim=0))
sharded_gradient.zero_()
for tensor, scale in zip(buffer_list, scale_list):
sharded_gradient += cast_from_fp8(tensor, scale, input_type)


def fp8_compress_fsdp_params_comm_hook(
state: object,
padded_unsharded_flat_param: torch.Tensor,
sharded_flat_param: torch.Tensor,
group=None,
fp8_format="e5m2",
) -> None:
"""
This hook relies on the pending official support for a parameter communication hook in FSDP (e.g. ``register_params_comm_hook``).

Example::
>>> fsdp_model.register_params_comm_hook(None, fp8_compress_fsdp_params_comm_hook)
"""

fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
fp8_max = torch.finfo(fp8_type).max
inp = sharded_flat_param
out = padded_unsharded_flat_param

per_tensor_max = inp.abs().max().float()
per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
dist.all_reduce(per_tensor_max, op=torch.distributed.ReduceOp.MAX, group=group)

scale = fp8_max / per_tensor_max
fp8_sharded_flat_param = (scale * inp.float()).to(fp8_type).view(torch.uint8)

fp8_out = torch.empty(out.shape, dtype=torch.uint8, device=out.device)
dist.all_gather_into_tensor(
fp8_out,
fp8_sharded_flat_param,
group=group,
)
padded_unsharded_flat_param.copy_((fp8_out.view(fp8_type).float() / scale).to(out.dtype))


def split_chunk_by_channel(
chunk: torch.Tensor, channel_size: int, num_channels: int, rank: int = 0, world_size: int = 1
):
@@ -342,7 +543,7 @@ def all_gather_into_tensor_flat_fp8(
scale_inv = 1.0 / scale
buffer = torch.empty_like(output_tensor, dtype=fp8_type)
dist.all_gather_into_tensor(buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group)
numel = np.prod(output_shape)
numel = output_shape.numel()
valid_buffer = buffer[:numel].reshape(output_shape)
valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type, per_channel_scale=(len(output_shape) == 2))
output_tensor[:numel].copy_(valid_buffer.view(-1))
@@ -376,28 +577,6 @@ def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2"):
output_list[i].copy_(cast_from_fp8(tensor, scale, input_type))


def all_to_all_single_fp8(output_tensor, input_tensor, group=None, fp8_format="e5m2"):

world_size = dist.get_world_size(group)

per_slice_len = input_tensor.size(0) // world_size
input_type = input_tensor.dtype
ret, scale = cast_to_fp8(input_tensor, fp8_format=fp8_format)
fp8_type = ret.dtype
input_tensor = ret.view(torch.uint8)
tensor = torch.empty_like(input_tensor)
scale_list = [torch.empty_like(scale) for _ in range(world_size)]
dist.all_to_all_single(tensor, input_tensor, group=group)
dist.all_gather(scale_list, scale, group=group)
cast_tensor_list = []

for i in range(world_size):
output_part = tensor[per_slice_len * i : per_slice_len * (i + 1)].view(fp8_type)
output_part = cast_from_fp8(output_part, scale_list[i], input_type)
cast_tensor_list.append(output_part)
output_tensor.copy_(torch.concatenate(cast_tensor_list, dim=0))


def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"):

world_size = dist.get_world_size(group)