From f5a52e1600e59d461a11abbf05585c0e0b19c906 Mon Sep 17 00:00:00 2001 From: HangXu Date: Mon, 1 Jul 2024 13:44:21 +0800 Subject: [PATCH 1/7] fp8 operators for compressed communication cast_to_fp8, cast_from_fp8, all_reduce_fp8 --- colossalai/quantization/fp8.py | 105 +++++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 colossalai/quantization/fp8.py diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py new file mode 100644 index 000000000000..d405de2de3d1 --- /dev/null +++ b/colossalai/quantization/fp8.py @@ -0,0 +1,105 @@ +import torch +import torch.distributed as dist + + +def cast_to_fp8(inp: torch.Tensor, scale=None, fp8_format="e4m3") -> (torch.Tensor, torch.Tensor): + r""" + casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling. + Args: + inp: input torch Tensor, should be in torch.FloatTensor, torch.HalfTensor, torch.BFloat16Tensor. + scale: scaling factor for fp8 casting. If it is None, then it is computed automatically. Per-channel scaling + is applied if input tensor is 2 dimension, otherwise, per-tensor scaling is applied. + fp8_format: e4m3 or e5m2 + + Returns: + Tuples: A tuple (fp8_tensor, scale) + """ + if inp.dtype in [torch.FloatTensor, torch.HalfTensor, torch.BFloat16Tensor]: + return inp + fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 + + if inp.dim() == 2: + if scale is None: + per_channel_max = inp.abs().max(dim=-1).values + scale = per_channel_max + scale_inv = 1.0 / scale + scale_inv = scale_inv[:, None] + ret = (scale_inv * inp).to(fp8_type) + else: + if scale is None: + per_tensor_max = inp.abs().max() + scale = per_tensor_max + scale_inv = 1.0 / scale + ret = (scale_inv * inp).to(fp8_type) + + return ret, scale + + +def cast_from_fp8(inp: torch.Tensor, scale: torch.Tensor, ret_type: torch.dtype) -> torch.Tensor: + r""" + + Args: + inp: should be a fp8 torch tensor in one of the types: [torch.float8_e4m3fn, torch.float8_e5m2]. + scale: scaling factor returned by cast_to_fp8 function. + ret_type: the datatype of the returned tensor. + + Returns: + torch.Tensor + """ + if inp.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: + return inp + if inp.dim() == 2: + ret = scale[:, None] * inp.to(ret_type) + else: + ret = scale * inp.to(ret_type) + return ret + + +def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None: + r""" + This is an in-place operation for compressed all_reduce using fp8. + It works like dist.all_reduce but during communication the data is cast to fp8 format. + + Args: + tensor: torch.Tensor in fp32, fp16, bf16 datatype. 
+ fp8_format: e4m3 or e5m2 + + Returns: + None + """ + + world_size = dist.get_world_size() + rank = dist.get_rank() + input_type = tensor.dtype + input_shape = tensor.shape + input_device = tensor.device + input_size = tensor.numel() + tensor = tensor.flatten() + + fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 + + ret, scale = cast_to_fp8(tensor, fp8_format=fp8_format) + + inp = ret.view(torch.uint8) + input_chunks = list(torch.chunk(inp, world_size, dim=0)) + if dist.get_rank() == world_size - 1: + output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)] + else: + output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)] + dist.all_to_all(output_chunks, input_chunks) + scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)] + dist.all_gather(scale_list, scale) + summed_out = torch.zeros_like(output_chunks[0]).to(input_type) + for scale, out in zip(scale_list, output_chunks): + out = out.view(fp8_type) + summed_out += cast_from_fp8(out, scale, input_type) + + summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format) + dist.all_gather(scale_list, scale) + + tensor_list = list(torch.chunk(torch.empty(input_size, device=input_device, dtype=torch.uint8), world_size, dim=0)) + dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8)) + for i in range(world_size): + tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i] + tensor_out = torch.cat(tensor_list, dim=0) + tensor.data = tensor_out.view(input_shape).to(input_type) \ No newline at end of file From e17f835df7c637e18df708b929b570c2ac459434 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 4 Jul 2024 12:47:16 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/quantization/fp8.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index d405de2de3d1..c880cd4aa44b 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -69,7 +69,7 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None: """ world_size = dist.get_world_size() - rank = dist.get_rank() + dist.get_rank() input_type = tensor.dtype input_shape = tensor.shape input_device = tensor.device @@ -102,4 +102,4 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None: for i in range(world_size): tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i] tensor_out = torch.cat(tensor_list, dim=0) - tensor.data = tensor_out.view(input_shape).to(input_type) \ No newline at end of file + tensor.data = tensor_out.view(input_shape).to(input_type) From dbfa7d39fc06534cf3d44ba8d1a5ae4d147d7133 Mon Sep 17 00:00:00 2001 From: GuangyaoZhang Date: Wed, 10 Jul 2024 08:13:26 +0000 Subject: [PATCH 3/7] fix typo --- colossalai/quantization/fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index c880cd4aa44b..58cedbc9554f 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -14,7 +14,7 @@ def cast_to_fp8(inp: torch.Tensor, scale=None, fp8_format="e4m3") -> (torch.Tens Returns: Tuples: A tuple (fp8_tensor, scale) """ - if inp.dtype in [torch.FloatTensor, torch.HalfTensor, torch.BFloat16Tensor]: + if inp.dtype not in 
[torch.float32, torch.float16, torch.bfloat16]: return inp fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 From 1e1959467e4cbc0c0bc15aff4c19c969e76c9a72 Mon Sep 17 00:00:00 2001 From: BurkeHulk Date: Fri, 12 Jul 2024 15:23:37 +0800 Subject: [PATCH 4/7] fix scaling algorithm in FP8 casting --- colossalai/quantization/fp8.py | 44 ++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index d405de2de3d1..051ecb45a8fc 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -1,8 +1,10 @@ +from typing import Any, Callable, List, Optional, Tuple, Union + import torch import torch.distributed as dist -def cast_to_fp8(inp: torch.Tensor, scale=None, fp8_format="e4m3") -> (torch.Tensor, torch.Tensor): +def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Tensor): r""" casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling. Args: @@ -14,28 +16,28 @@ def cast_to_fp8(inp: torch.Tensor, scale=None, fp8_format="e4m3") -> (torch.Tens Returns: Tuples: A tuple (fp8_tensor, scale) """ - if inp.dtype in [torch.FloatTensor, torch.HalfTensor, torch.BFloat16Tensor]: - return inp + + if inp.dtype not in [torch.float32, torch.float16, torch.bfloat16]: + raise TypeError("Only float16, bfloat16, and float32 are allowed.") + fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 + fp8_max = torch.finfo(fp8_type).max if inp.dim() == 2: - if scale is None: - per_channel_max = inp.abs().max(dim=-1).values - scale = per_channel_max - scale_inv = 1.0 / scale - scale_inv = scale_inv[:, None] - ret = (scale_inv * inp).to(fp8_type) + per_channel_max = inp.abs().max(dim=-1).values.float() + per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0) + scale = fp8_max / per_channel_max[:, None] else: - if scale is None: - per_tensor_max = inp.abs().max() - scale = per_tensor_max - scale_inv = 1.0 / scale - ret = (scale_inv * inp).to(fp8_type) + per_tensor_max = inp.abs().max().float() + per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0) + scale = fp8_max / per_tensor_max - return ret, scale + scale_inv = 1.0 / scale + ret = (scale * inp.float()).to(fp8_type) + return ret, scale_inv -def cast_from_fp8(inp: torch.Tensor, scale: torch.Tensor, ret_type: torch.dtype) -> torch.Tensor: +def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype) -> torch.Tensor: r""" Args: @@ -47,12 +49,13 @@ def cast_from_fp8(inp: torch.Tensor, scale: torch.Tensor, ret_type: torch.dtype) torch.Tensor """ if inp.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: - return inp + raise TypeError("Only float8_e4m3fn and float8_e5m2 are allowed.") + if inp.dim() == 2: - ret = scale[:, None] * inp.to(ret_type) + ret = scale_inv[:, None] * inp.float() else: - ret = scale * inp.to(ret_type) - return ret + ret = scale_inv * inp.float() + return ret.to(ret_type) def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None: @@ -69,7 +72,6 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None: """ world_size = dist.get_world_size() - rank = dist.get_rank() input_type = tensor.dtype input_shape = tensor.shape input_device = tensor.device From e88190184aea4e67fe5432d0501314287fedec62 Mon Sep 17 00:00:00 2001 From: BurkeHulk Date: Fri, 12 Jul 2024 15:25:25 +0800 Subject: [PATCH 5/7] support fp8 communication in pipeline parallelism --- 
.../booster/plugin/hybrid_parallel_plugin.py | 3 + .../pipeline/schedule/interleaved_pp.py | 29 ++++++++ colossalai/pipeline/schedule/one_f_one_b.py | 26 +++++++ colossalai/quantization/fp8.py | 69 ++++++++++++++++++- 4 files changed, 126 insertions(+), 1 deletion(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index a3d6f1e74771..b818209a6668 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -992,6 +992,7 @@ def __init__( make_vocab_size_divisible_by: int = 64, dp_outside: bool = True, overlap_p2p: bool = True, + fp8_communication: bool = False, ) -> None: super().__init__() assert ( @@ -1082,6 +1083,7 @@ def __init__( microbatch_size=microbatch_size, enable_metadata_cache=enable_metadata_cache, overlap_p2p=overlap_p2p, + fp8_communication=fp8_communication, ) elif pp_style == "1f1b": self.schedule = OneForwardOneBackwardSchedule( @@ -1089,6 +1091,7 @@ def __init__( num_microbatches=num_microbatches, microbatch_size=microbatch_size, enable_metadata_cache=enable_metadata_cache, + fp8_communication=fp8_communication, ) else: raise NotImplementedError() diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index a21b45c44a2c..86ce536d0a3c 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -12,6 +12,7 @@ from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.utils import get_current_device +from colossalai.quantization.fp8 import cast_to_fp8_pipeline, cast_from_fp8_pipeline from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device from .base import PipelineSchedule @@ -32,6 +33,7 @@ def __init__( microbatch_size: Optional[int] = None, enable_metadata_cache: bool = True, overlap_p2p: bool = True, + fp8_communication: bool = False, ) -> None: super().__init__(stage_manager) assert ( @@ -56,6 +58,7 @@ def __init__( self.tensor_metadata_recv = None self.grad_metadata_recv = None + self.fp8_communication = fp8_communication def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. 
@@ -191,8 +194,12 @@ def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int = """ with self.stage_manager.switch_model_chunk_id(model_chunk_id): if not self.stage_manager.is_last_stage(): + if self.fp8_communication: + cast_to_fp8_pipeline(output_tensor) send_handles = self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata) self.send_tensor_metadata = not self.enable_metadata_cache + if self.fp8_communication: + cast_from_fp8_pipeline(output_tensor) return send_handles return [] @@ -210,10 +217,14 @@ def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank: """ with self.stage_manager.switch_model_chunk_id(model_chunk_id): if not self.stage_manager.is_first_stage(): + if self.fp8_communication: + cast_to_fp8_pipeline(input_tensor_grad) send_handles = self.comm.send_backward( input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata ) self.send_grad_metadata = not self.enable_metadata_cache + if self.fp8_communication: + cast_from_fp8_pipeline(input_tensor_grad) return send_handles return [] @@ -224,6 +235,8 @@ def send_forward_recv_forward( is_send = not self.stage_manager.is_last_stage() with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv): is_recv = not self.stage_manager.is_first_stage() + if self.fp8_communication: + cast_to_fp8_pipeline(output_tensor) input_tensor, wait_handles = self.comm.send_forward_recv_forward( output_tensor, is_send, @@ -237,6 +250,8 @@ def send_forward_recv_forward( if is_recv and self.enable_metadata_cache and self.tensor_metadata_recv is None: self.tensor_metadata_recv = create_send_metadata(input_tensor) + if self.fp8_communication: + cast_from_fp8_pipeline(output_tensor) return input_tensor, wait_handles def send_backward_recv_backward( @@ -246,6 +261,8 @@ def send_backward_recv_backward( is_send = not self.stage_manager.is_first_stage() with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv): is_recv = not self.stage_manager.is_last_stage() + if self.fp8_communication: + cast_to_fp8_pipeline(input_tensor_grad) output_tensor_grad, wait_handles = self.comm.send_backward_recv_backward( input_tensor_grad, is_send, @@ -258,6 +275,8 @@ def send_backward_recv_backward( self.send_grad_metadata = not self.enable_metadata_cache and is_send if is_recv and self.enable_metadata_cache and self.grad_metadata_recv is None: self.grad_metadata_recv = create_send_metadata(output_tensor_grad) + if self.fp8_communication: + cast_from_fp8_pipeline(input_tensor_grad) return output_tensor_grad, wait_handles def forward_step( @@ -379,6 +398,8 @@ def run_forward_only( # Wait until current input is received _wait_p2p(fwd_wait_handles) + if self.fp8_communication and input_obj is not None: + cast_from_fp8_pipeline(input_obj) output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) if not last_batch: @@ -441,6 +462,8 @@ def run_forward_backward( # Wait for input _wait_p2p(fwd_wait_handles) + if self.fp8_communication and input_obj is not None: + cast_from_fp8_pipeline(input_obj) output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) input_objs[model_chunk_id].append(input_obj) output_objs[model_chunk_id].append(output_obj) @@ -467,6 +490,8 @@ def run_forward_backward( # Wait for input. 
_wait_p2p(fwd_wait_handles) + if self.fp8_communication and input_obj is not None: + cast_from_fp8_pipeline(input_obj) output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) # Add input_obj and output_obj to end of list. input_objs[model_chunk_id].append(input_obj) @@ -511,6 +536,8 @@ def send_backward_recv_backward(): input_obj, fwd_wait_handles = send_forward_recv_forward() # Wait for upstream grad _wait_p2p(bwd_wait_handles) + if self.fp8_communication and output_obj_grad is not None: + cast_from_fp8_pipeline(output_obj_grad) input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad) # NOTE: It's documented by NCCL that running two concurrent communicators (batch_isend_irecv) # risks deadlock (https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2134/user-guide/docs/usage/communicators.html) @@ -532,6 +559,8 @@ def send_backward_recv_backward(): # Wait for upstream grad _wait_p2p(bwd_wait_handles) + if self.fp8_communication and output_obj_grad is not None: + cast_from_fp8_pipeline(output_obj_grad) # backward local grads input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad) if not last_batch: diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index 7f0d0e3493f7..90ebb0534497 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -11,6 +11,7 @@ from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.utils import get_current_device +from colossalai.quantization.fp8 import cast_to_fp8_pipeline, cast_from_fp8_pipeline from ._utils import ( detach, @@ -32,6 +33,7 @@ def __init__( num_microbatches: Optional[int] = None, microbatch_size: Optional[int] = None, enable_metadata_cache: bool = True, + fp8_communication: bool = False, ) -> None: """1F1B pipeline schedule. @@ -61,6 +63,8 @@ def __init__( self.tensor_metadata_recv = None self.grad_metadata_recv = None + self.fp8_communication = fp8_communication + def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. @@ -129,6 +133,8 @@ def recv_forward(self, prev_rank: int = None) -> Any: if self.enable_metadata_cache and self.tensor_metadata_recv is None: self.tensor_metadata_recv = create_send_metadata(input_tensor) + if self.fp8_communication: + cast_from_fp8_pipeline(input_tensor) return input_tensor def recv_backward(self, next_rank: int = None) -> Any: @@ -143,6 +149,8 @@ def recv_backward(self, next_rank: int = None) -> Any: """ if not self.stage_manager.is_last_stage(): output_tensor_grad, _ = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv) + if self.fp8_communication: + cast_from_fp8_pipeline(output_tensor_grad) if self.enable_metadata_cache and self.grad_metadata_recv is None: self.grad_metadata_recv = create_send_metadata(output_tensor_grad) @@ -157,9 +165,13 @@ def send_forward(self, output_tensor: Any, next_rank: int = None) -> None: next_rank (int, optional): The rank of the recipient of the tensor. 
""" if not self.stage_manager.is_last_stage(): + if self.fp8_communication: + cast_to_fp8_pipeline(output_tensor) self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata) self.send_tensor_metadata = not self.enable_metadata_cache + if self.fp8_communication: + cast_from_fp8_pipeline(output_tensor, del_metadata=False) def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None: """Sends the gradient tensor to the previous stage in pipeline. For 1F1B. @@ -169,8 +181,12 @@ def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None: prev_rank (int, optional): The rank of the recipient of the tensor """ if not self.stage_manager.is_first_stage(): + if self.fp8_communication: + cast_to_fp8_pipeline(input_tensor_grad) self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata) self.send_grad_metadata = not self.enable_metadata_cache + if self.fp8_communication: + cast_from_fp8_pipeline(input_tensor_grad, del_metadata=False) def send_forward_recv_backward(self, output_tensor: Any, send_first: Optional[bool] = None) -> Any: """Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline. @@ -183,6 +199,8 @@ def send_forward_recv_backward(self, output_tensor: Any, send_first: Optional[bo if not self.stage_manager.is_last_stage(): if not self.send_tensor_metadata and self.grad_metadata_recv is not None: send_first = None + if self.fp8_communication: + cast_to_fp8_pipeline(output_tensor) output_tensor_grad, _ = self.comm.send_forward_recv_backward( output_tensor, send_metadata=self.send_tensor_metadata, @@ -192,6 +210,9 @@ def send_forward_recv_backward(self, output_tensor: Any, send_first: Optional[bo self.send_tensor_metadata = not self.enable_metadata_cache if self.enable_metadata_cache and self.grad_metadata_recv is None: self.grad_metadata_recv = create_send_metadata(output_tensor_grad) + if self.fp8_communication: + cast_from_fp8_pipeline(output_tensor, del_metadata=False) + cast_from_fp8_pipeline(output_tensor_grad) return output_tensor_grad @@ -206,6 +227,8 @@ def send_backward_recv_forward(self, input_tensor_grad: Any, send_first: Optiona if not self.stage_manager.is_first_stage(): if not self.send_grad_metadata and self.tensor_metadata_recv is not None: send_first = None # must not fallback + if self.fp8_communication: + cast_to_fp8_pipeline(input_tensor_grad) input_tensor, _ = self.comm.send_backward_recv_forward( input_tensor_grad, send_metadata=self.send_grad_metadata, @@ -215,6 +238,9 @@ def send_backward_recv_forward(self, input_tensor_grad: Any, send_first: Optiona self.send_grad_metadata = not self.enable_metadata_cache if self.enable_metadata_cache and self.tensor_metadata_recv is None: self.tensor_metadata_recv = create_send_metadata(input_tensor) + if self.fp8_communication: + cast_from_fp8_pipeline(input_tensor) + cast_from_fp8_pipeline(input_tensor_grad, del_metadata=False) return input_tensor diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index 051ecb45a8fc..c02223331163 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -104,4 +104,71 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None: for i in range(world_size): tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i] tensor_out = torch.cat(tensor_list, dim=0) - tensor.data = tensor_out.view(input_shape).to(input_type) \ No newline at end of file + tensor.data = 
tensor_out.view(input_shape).to(input_type) + + + +def cast_to_fp8_pipeline(inp: Any) -> None: + """ + Cast the hidden_states tensor of inp object to fp8 format before p2p communication in pipeline. + The activations tensor is indexed by 'hidden_states' in the inp dict. + After FP8 casting, the resulting tensor is saved as float16 or bfloat16 format but the size becomes halved. + Metadata such as fp8_scale is saved into inp dict for communication. + """ + if inp is None: + return + # In pipeline parallelism, when inp is torch.Tensor, it only contains one element, thus can be omitted. + if type(inp) == torch.Tensor: + return + + assert 'hidden_states' in inp, 'required by pipeline parallelism.' + inp_tensor = inp["hidden_states"] + + min_val, max_val = inp_tensor.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()) + + finfo = torch.finfo(torch.float8_e4m3fn) + if amax > finfo.max: + fp8_type = torch.float8_e5m2 + fp8_view_type = torch.float16 + else: + fp8_type = torch.float8_e4m3fn + fp8_view_type = torch.bfloat16 + + finfo = torch.finfo(fp8_type) + scale = torch.tensor(1.0).to(inp_tensor.device) if amax == 0.0 else finfo.max / amax.float() + q_tensor = (inp_tensor.data.float() * scale) + # Todo: Currently we use fp8_view_type to indicate which fp8 format is used. This is a temporary workaround due to 'Only support tensor for fast send'. + # inp_tensor needs to be a float datatype to avoid error during gradient placement. + inp_tensor.data = q_tensor.to(fp8_type).view(fp8_view_type) + + inp["fp8_scale"] = scale.float().reciprocal() + + + +def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None: + """ + Cast the FP8 encoded hidden_states tensor back to original dtype after p2p communication in pipeline. + del_metadata = False is useful when this function is called before p2p communication. + """ + if inp is None: + return + if type(inp) == torch.Tensor: + return + + assert 'hidden_states' in inp, 'required by pipeline parallelism.' + inp_tensor = inp["hidden_states"] + scale = inp["fp8_scale"] + + fp8_view_type = inp_tensor.dtype + if fp8_view_type == torch.float16: + fp8_type = torch.float8_e5m2 + elif fp8_view_type == torch.bfloat16: + fp8_type = torch.float8_e4m3fn + else: + raise TypeError("Only float16, bfloat16 are implemented.") + + inp_tensor.data = inp_tensor.data.view(fp8_type).to(torch.float16) * scale + + if del_metadata: + del inp["fp8_scale"] \ No newline at end of file From 66018749f3fd79ff92e36b6fa39262f4c6355872 Mon Sep 17 00:00:00 2001 From: BurkeHulk Date: Fri, 12 Jul 2024 15:26:17 +0800 Subject: [PATCH 6/7] add fp8_communication flag in the script --- examples/language/bert/finetune.py | 2 ++ examples/language/gpt/hybridparallelism/finetune.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index 7e8c07fdce47..8a59ab6838a6 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -190,6 +190,7 @@ def main(): ) parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. 
Raise exception if not reached") parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context") + parser.add_argument("--use_fp8_comm", type=bool, default=False, help="for using fp8 during communication") args = parser.parse_args() if args.model_type == "bert": @@ -232,6 +233,7 @@ def main(): zero_stage=1, precision="fp16", initial_scale=1, + fp8_communication=args.use_fp8_comm, ) booster = Booster(plugin=plugin, **booster_kwargs) diff --git a/examples/language/gpt/hybridparallelism/finetune.py b/examples/language/gpt/hybridparallelism/finetune.py index 777d16cb9ea0..9b3a101609dc 100644 --- a/examples/language/gpt/hybridparallelism/finetune.py +++ b/examples/language/gpt/hybridparallelism/finetune.py @@ -187,6 +187,7 @@ def main(): ) parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached") parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context") + parser.add_argument("--use_fp8_comm", type=bool, default=False, help="for using fp8 during communication") args = parser.parse_args() if args.model_type == "gpt2": @@ -225,6 +226,7 @@ def main(): zero_stage=1, precision="fp16", initial_scale=1, + fp8_communication=args.use_fp8_comm, ) booster = Booster(plugin=plugin, **booster_kwargs) From 51f916b11d87ecdfa3763da7a6b396a030b32b13 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 12 Jul 2024 07:33:44 +0000 Subject: [PATCH 7/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/pipeline/schedule/interleaved_pp.py | 3 ++- colossalai/pipeline/schedule/one_f_one_b.py | 3 ++- colossalai/quantization/fp8.py | 12 +++++------- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 86ce536d0a3c..a7571c73139b 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -11,8 +11,8 @@ from colossalai.interface import OptimizerWrapper from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline from colossalai.utils import get_current_device -from colossalai.quantization.fp8 import cast_to_fp8_pipeline, cast_from_fp8_pipeline from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device from .base import PipelineSchedule @@ -59,6 +59,7 @@ def __init__( self.grad_metadata_recv = None self.fp8_communication = fp8_communication + def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. 
diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index 90ebb0534497..3269d67ba7d4 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -10,8 +10,8 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline from colossalai.utils import get_current_device -from colossalai.quantization.fp8 import cast_to_fp8_pipeline, cast_from_fp8_pipeline from ._utils import ( detach, @@ -172,6 +172,7 @@ def send_forward(self, output_tensor: Any, next_rank: int = None) -> None: if self.fp8_communication: cast_from_fp8_pipeline(output_tensor, del_metadata=False) + def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None: """Sends the gradient tensor to the previous stage in pipeline. For 1F1B. diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index c02223331163..e514f435eaed 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any import torch import torch.distributed as dist @@ -107,7 +107,6 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None: tensor.data = tensor_out.view(input_shape).to(input_type) - def cast_to_fp8_pipeline(inp: Any) -> None: """ Cast the hidden_states tensor of inp object to fp8 format before p2p communication in pipeline. @@ -121,7 +120,7 @@ def cast_to_fp8_pipeline(inp: Any) -> None: if type(inp) == torch.Tensor: return - assert 'hidden_states' in inp, 'required by pipeline parallelism.' + assert "hidden_states" in inp, "required by pipeline parallelism." inp_tensor = inp["hidden_states"] min_val, max_val = inp_tensor.aminmax() @@ -137,7 +136,7 @@ def cast_to_fp8_pipeline(inp: Any) -> None: finfo = torch.finfo(fp8_type) scale = torch.tensor(1.0).to(inp_tensor.device) if amax == 0.0 else finfo.max / amax.float() - q_tensor = (inp_tensor.data.float() * scale) + q_tensor = inp_tensor.data.float() * scale # Todo: Currently we use fp8_view_type to indicate which fp8 format is used. This is a temporary workaround due to 'Only support tensor for fast send'. # inp_tensor needs to be a float datatype to avoid error during gradient placement. inp_tensor.data = q_tensor.to(fp8_type).view(fp8_view_type) @@ -145,7 +144,6 @@ def cast_to_fp8_pipeline(inp: Any) -> None: inp["fp8_scale"] = scale.float().reciprocal() - def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None: """ Cast the FP8 encoded hidden_states tensor back to original dtype after p2p communication in pipeline. @@ -156,7 +154,7 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None: if type(inp) == torch.Tensor: return - assert 'hidden_states' in inp, 'required by pipeline parallelism.' + assert "hidden_states" in inp, "required by pipeline parallelism." inp_tensor = inp["hidden_states"] scale = inp["fp8_scale"] @@ -171,4 +169,4 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None: inp_tensor.data = inp_tensor.data.view(fp8_type).to(torch.float16) * scale if del_metadata: - del inp["fp8_scale"] \ No newline at end of file + del inp["fp8_scale"]
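
A minimal usage sketch for the casting helpers introduced in this series, assuming a PyTorch build with the float8 dtypes (2.1 or later) and that colossalai.quantization.fp8 is importable with these patches applied. It exercises only the per-tensor scaling path, which is the path all_reduce_fp8 uses after flattening its input; the tensor shapes and variable names below are illustrative and do not come from the patches.

import torch

from colossalai.quantization.fp8 import cast_from_fp8, cast_to_fp8

# Round-trip a 1-D float32 tensor through fp8 (e4m3) and back.
# cast_to_fp8 returns the quantized tensor plus the inverse scale needed for
# dequantization; cast_from_fp8 applies that inverse scale and restores the
# requested dtype. Only dtype conversions are involved here, so this part
# does not require an initialized process group.
x = torch.randn(1024, dtype=torch.float32)
x_fp8, scale_inv = cast_to_fp8(x, fp8_format="e4m3")
x_back = cast_from_fp8(x_fp8, scale_inv, torch.float32)
print("max abs round-trip error:", (x - x_back).abs().max().item())

# Inside an initialized process group (e.g. launched with torchrun on an NCCL
# backend), the compressed collective from PATCH 1/7 is used in place of
# dist.all_reduce and operates in place on its input tensor:
#
#   from colossalai.quantization.fp8 import all_reduce_fp8
#   all_reduce_fp8(grad_tensor)   # instead of dist.all_reduce(grad_tensor)
#
# PATCH 5/7 and 6/7 route pipeline-parallel p2p traffic through the same
# machinery via the plugin flag, e.g.
#   HybridParallelPlugin(..., fp8_communication=True)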