From 7e8b894c3072875ba75bd980e717febdbec5c8cd Mon Sep 17 00:00:00 2001 From: GuangyaoZhang Date: Mon, 22 Jul 2024 07:38:52 +0000 Subject: [PATCH 1/7] add llama shardformer fp8 --- colossalai/quantization/fp8.py | 1 - colossalai/shardformer/layer/_operation.py | 150 ++++++++++++++++----- colossalai/shardformer/layer/linear.py | 22 ++- colossalai/shardformer/modeling/llama.py | 26 ++-- colossalai/shardformer/policies/llama.py | 14 +- 5 files changed, 157 insertions(+), 56 deletions(-) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index fe5bd5744e69..88107982be1d 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -66,7 +66,6 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None: Returns: None """ - world_size = dist.get_world_size(group=group) input_type = tensor.dtype input_shape = tensor.shape diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index bbf1c862d820..2012ec70e0b0 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -117,11 +117,12 @@ class LinearWithAsyncCommunication(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): + def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_allreduce = async_grad_allreduce + ctx.fp8_communication = fp8_communication if bias is not None: output = F.linear(input_, weight, bias) else: @@ -133,6 +134,7 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): def backward(ctx, grad_output): input, weight, bias = ctx.saved_tensors use_bias = ctx.use_bias + fp8_communication = ctx.fp8_communication # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias. 
if use_bias: @@ -148,7 +150,10 @@ def backward(ctx, grad_output): if ctx.async_grad_allreduce: # Asynchronous all-reduce - handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) + if fp8_communication: + all_reduce_fp8(grad_input, group=ctx.process_group) + else: + handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) # Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py @@ -167,10 +172,10 @@ def backward(ctx, grad_output): grad_bias = grad_output.sum(dim=0) if use_bias else None - if ctx.async_grad_allreduce: + if ctx.async_grad_allreduce and not fp8_communication: handle.wait() - return grad_input, grad_weight, grad_bias, None, None, None + return grad_input, grad_weight, grad_bias, None, None, None, None def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False): @@ -238,16 +243,18 @@ class _GatherForwardReduceScatterBackward(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, process_group, dim): + def forward(ctx, input_, process_group, dim, fp8_communication=False): ctx.process_group = process_group ctx.dim = dim + ctx.fp8_communication = fp8_communication - return _gather(input_, dim, process_group) + return _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3") @staticmethod def backward(ctx, grad_output): dim = ctx.dim process_group = ctx.process_group + fp8_communication = ctx.fp8_communication # do reduce-scatter new_shape = list(grad_output.shape) assert ( @@ -259,9 +266,12 @@ def backward(ctx, grad_output): ] output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device) - dist.reduce_scatter(output, grad_list, group=process_group) + if fp8_communication: + reduce_scatter_fp8(output, grad_list, group=process_group, fp8_format="e5m2") + else: + dist.reduce_scatter(output, grad_list, group=process_group) - return output, None, None + return output, None, None, None class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): @@ -577,12 +587,8 @@ def backward(ctx, grad_output): dim = ctx.dim process_group = ctx.process_group fp8_communication = ctx.fp8_communication - return ( - _gather(grad_output, dim, process_group, fp8_communication=fp8_communication, fp8_format="e5m2"), - None, - None, - None, - ) + + return _gather(grad_output, dim, process_group, fp8_communication, fp8_format="e5m2"), None, None, None class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): @@ -816,26 +822,67 @@ class _AllToAll(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, process_group, scatter_dim, gather_dim): + def forward(ctx, input_, process_group, scatter_dim, gather_dim, fp8_communication=False): ctx.process_group = process_group ctx.scatter_dim = scatter_dim ctx.gather_dim = gather_dim + ctx.fp8_communication = fp8_communication world_size = dist.get_world_size(process_group) bsz, _, _ = input_.shape # using all_to_all_single when batch size is 1 if bsz == 1: - return _all_to_all_single(input_, world_size, process_group, scatter_dim, gather_dim) + return _all_to_all_single( + input_, + world_size, + process_group, + scatter_dim, + gather_dim, + fp8_communication=fp8_communication, + fp8_format="e4m3", + ) else: - return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim) + return _all_to_all( + input_, + world_size, + process_group, + 
scatter_dim, + gather_dim, + fp8_communication=fp8_communication, + fp8_format="e4m3", + ) @staticmethod - def backward(ctx, *grad_output): + def backward(ctx, grad_output): process_group = ctx.process_group scatter_dim = ctx.gather_dim gather_dim = ctx.scatter_dim - return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim) - return (return_grad, None, None, None) + fp8_communication = ctx.fp8_communication + world_size = dist.get_world_size(process_group) + bsz, _, _ = grad_output.shape + + if bsz == 1: + return_grad = _all_to_all_single( + grad_output, + world_size, + process_group, + scatter_dim, + gather_dim, + fp8_communication=fp8_communication, + fp8_format="e5m2", + ) + else: + return_grad = _all_to_all( + grad_output, + world_size, + process_group, + scatter_dim, + gather_dim, + fp8_communication=fp8_communication, + fp8_format="e5m2", + ) + + return (return_grad, None, None, None, None) class HookParameter(torch.autograd.Function): @@ -954,14 +1001,33 @@ def _reduce_scatter(input_, dim=1, process_group=None): return output -def _all_to_all(input_, world_size, group, scatter_dim, gather_dim): - input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)] - output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] - dist.all_to_all(output_list, input_list, group=group) +def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"): + if fp8_communication: + input_type = input_.dtype + ret, scale = cast_to_fp8(input_, fp8_format=fp8_format) + fp8_type = ret.dtype + input_ = ret.view(torch.uint8) + input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)] + output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] + scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(world_size)] + dist.all_to_all(output_list, input_list, group=group) + dist.all_gather(scale_list, scale, group=group) + cast_tensor_list = [] + for output, scale in zip(output_list, scale_list): + output = output.view(fp8_type) + output = cast_from_fp8(output, scale, input_type) + cast_tensor_list.append(output) + output_list = cast_tensor_list + else: + input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)] + output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] + dist.all_to_all(output_list, input_list, group=group) return torch.cat(output_list, dim=gather_dim).contiguous() -def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim): +def _all_to_all_single( + input_, seq_world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2" +): inp_shape = list(input_.shape) inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size if scatter_dim < 2: @@ -973,8 +1039,24 @@ def _all_to_all_single(input_, seq_world_size, group, scatter_dim, gather_dim): .contiguous() ) - output = torch.empty_like(input_t) - dist.all_to_all_single(output, input_t, group=group) + if fp8_communication: + input_type = input_t.dtype + ret, scale = cast_to_fp8(input_t, fp8_format=fp8_format) + fp8_type = ret.dtype + input_t = ret.view(torch.uint8) + output = torch.empty_like(input_t) + scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(seq_world_size)] + dist.all_to_all_single(output, input_t, group=group) + dist.all_gather(scale_list, scale, group=group) + cast_tensor_list = [] + for output_part, scale in 
zip(output, scale_list): + output_part = output_part.view(fp8_type) + output_part = cast_from_fp8(output_part, scale, input_type) + cast_tensor_list.append(output_part) + output = torch.stack(cast_tensor_list, dim=0) + else: + output = torch.empty_like(input_t) + dist.all_to_all_single(output, input_t, group=group) if scatter_dim < 2: output = output.transpose(0, 1).contiguous() @@ -994,8 +1076,10 @@ def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allre ) -def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): - return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) +def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False): + return LinearWithAsyncCommunication.apply( + input_, weight, bias, process_group, async_grad_allreduce, fp8_communication + ) def linear_gather_forward_reducescatter_backward( @@ -1006,8 +1090,8 @@ def linear_gather_forward_reducescatter_backward( ) -def gather_forward_reducescatter_backward(input_, process_group, dim): - return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim) +def gather_forward_reducescatter_backward(input_, process_group, dim, fp8_communication=False): + return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim, fp8_communication) def reducescatter_forward_gather_backward(input_, process_group, dim, fp8_communication=False): @@ -1042,5 +1126,5 @@ def reduce_backward(input_, process_group, fp8_communication=False): return _ReduceBackward.apply(input_, process_group, fp8_communication) -def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1): - return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim) +def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8_communication=False): + return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication) diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 37c7542416f6..38a6ef1a19ae 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -84,6 +84,7 @@ def __init__( bias_: Optional[Parameter] = None, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + fp8_communication: bool = False, **kwargs, ): super().__init__(weight=weight, bias_=bias_, **kwargs) @@ -98,6 +99,7 @@ def __init__( self.skip_bias_add = skip_bias_add self.device = device self.process_group = process_group + self.fp8_communication = fp8_communication if skip_bias_add and not bias: raise ValueError("cannot skip bias addition if bias is None") @@ -201,10 +203,12 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: bias = self.bias if not self.skip_bias_add else None if self.seq_parallel_mode is None: - output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + output_parallel = linear_with_async_comm( + input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication + ) elif self.seq_parallel_mode == "split_gather": input_parallel = gather_forward_reducescatter_backward( - input_parallel, self.process_group, self.seq_parallel_dim + input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication ) output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, 
self.process_group, False) elif self.seq_parallel_mode == "ring": @@ -214,7 +218,9 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: if self.gather_output: # All-gather across the partitions. - output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + output = gather_forward_split_backward( + output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication + ) else: output = output_parallel @@ -264,6 +270,7 @@ def __init__( weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), stream_chunk_num: int = 1, + fp8_communication: bool = False, ): super().__init__() @@ -278,6 +285,7 @@ def __init__( self.seq_parallel_mode = seq_parallel_mode self.seq_parallel_dim = seq_parallel_dim self.num_partitions = dist.get_world_size(self.process_group) + self.fp8_communication = fp8_communication if skip_bias_add and not bias: raise ValueError("cannot skip bias addition if bias is None") @@ -398,7 +406,9 @@ def forward(self, input_: Tensor) -> Tensor: ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format( input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions ) - input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) + input_ = split_forward_gather_backward( + input_, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication + ) if self.stream_chunk_num > 1: if self.training: @@ -418,11 +428,11 @@ def forward(self, input_: Tensor) -> Tensor: else: if self.seq_parallel_mode is None: output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) - output = reduce_forward(output_parallel, self.process_group) + output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication) elif self.seq_parallel_mode == "split_gather": output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False) output = reducescatter_forward_gather_backward( - output_parallel, self.process_group, self.seq_parallel_dim + output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication ) elif self.seq_parallel_mode == "ring": output = linear_reducescatter_forward_gather_backward( diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index bf5ce45a8342..c8d39573e6d4 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -460,7 +460,7 @@ def llama_for_sequence_classification_forward( return {"hidden_states": hidden_states} -def get_llama_flash_attention_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): +def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): def forward( self, hidden_states: torch.Tensor, @@ -510,9 +510,9 @@ def forward( # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": - query_states = all_to_all_comm(query_states, sp_group) - key_states = all_to_all_comm(key_states, sp_group) - value_states = all_to_all_comm(value_states, sp_group) + query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication) + key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication) + value_states = all_to_all_comm(value_states, 
sp_group, fp8_communication=shard_config.fp8_communication) bsz, q_len, _ = query_states.size() query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) @@ -592,7 +592,7 @@ def forward( return forward -def get_llama_flash_attention_model_forward(shard_config, sp_mode=None, sp_size=None, sp_group=None): +def get_llama_flash_attention_model_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): logger = logging.get_logger(__name__) def forward( @@ -659,9 +659,13 @@ def forward( attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) if sp_mode in ["ring", "split_gather"]: - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group) + inputs_embeds = split_forward_gather_backward( + inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication + ) elif sp_mode == "all_to_all": - inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size) + inputs_embeds = split_forward_gather_backward( + inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication + ) hidden_states = inputs_embeds # decoder layers @@ -706,9 +710,13 @@ def forward( hidden_states = self.norm(hidden_states) if sp_mode == "ring" or sp_mode == "split_gather": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group) + hidden_states = gather_forward_split_backward( + hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication + ) elif sp_mode == "all_to_all": - hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size) + hidden_states = gather_forward_split_backward( + hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication + ) # add hidden states from the last decoder layer if output_hidden_states: diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 85ec6717d0d1..be21b33e10a6 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -134,37 +134,37 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: SubModuleReplacementDescription( suffix="self_attn.q_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="self_attn.k_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="self_attn.v_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="self_attn.o_proj", target_module=Linear1D_Row, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="mlp.gate_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="mlp.up_proj", target_module=Linear1D_Col, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, 
fp8_communication=self.shard_config.fp8_communication), ), SubModuleReplacementDescription( suffix="mlp.down_proj", target_module=Linear1D_Row, - kwargs=dict(seq_parallel_mode=sp_mode), + kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication), ), ], ) From 83dfa97d7311640d9ab89acb73302d93a13f2814 Mon Sep 17 00:00:00 2001 From: GuangyaoZhang Date: Mon, 29 Jul 2024 09:06:55 +0000 Subject: [PATCH 2/7] Llama Shardformer Parity --- colossalai/quantization/fp8.py | 6 +-- colossalai/shardformer/layer/_operation.py | 1 + colossalai/shardformer/modeling/llama.py | 4 +- colossalai/shardformer/policies/llama.py | 1 - tests/test_fp8/test_fp8_all_to_all.py | 34 +++++++++++++++ tests/test_fp8/test_fp8_all_to_all_single.py | 34 +++++++++++++++ ...llgather.py => test_fp8_allgather_flat.py} | 4 +- tests/test_fp8/test_fp8_gather.py | 41 +++++++++++++++++++ tests/test_fp8/test_fp8_reduce_scatter.py | 38 +++++++++++++++++ 9 files changed, 156 insertions(+), 7 deletions(-) create mode 100644 tests/test_fp8/test_fp8_all_to_all.py create mode 100644 tests/test_fp8/test_fp8_all_to_all_single.py rename tests/test_fp8/{test_fp8_allgather.py => test_fp8_allgather_flat.py} (96%) create mode 100644 tests/test_fp8/test_fp8_gather.py create mode 100644 tests/test_fp8/test_fp8_reduce_scatter.py diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index 88107982be1d..8887310b43a7 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -5,7 +5,7 @@ import torch.distributed as dist -def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Tensor): +def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> (torch.Tensor, torch.Tensor): r""" casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling. 
Args: @@ -23,7 +23,7 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Te fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 fp8_max = torch.finfo(fp8_type).max - if inp.dim() == 2: + if per_channel_scale: per_channel_max = inp.abs().max(dim=-1).values.float() per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0) scale = fp8_max / per_channel_max[:, None] @@ -49,7 +49,7 @@ def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dt if inp.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: raise TypeError("Only float8_e4m3fn and float8_e5m2 are allowed.") - if inp.dim() == 2: + if scale_inv.dim() >= 1: ret = scale_inv[:, None] * inp.float() else: ret = scale_inv * inp.float() diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 2012ec70e0b0..bd69dc6a08f7 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -962,6 +962,7 @@ def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_for cast_tensor_list = [] for output, scale in zip(tensor_list, scale_list): + scale = torch.tensor(scale[0]) output = output.view(fp8_type) output = cast_from_fp8(output, scale, input_type) cast_tensor_list.append(output) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index c8d39573e6d4..f83735f0520d 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -574,7 +574,9 @@ def forward( # sp: all-to-all comminucation when introducing sequence parallel if sp_mode == "all_to_all": attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) - attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) + attn_output = all_to_all_comm( + attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication + ) else: attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index be21b33e10a6..0f28f6cf49a9 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -65,7 +65,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: norm_cls = FusedRMSNorm else: norm_cls = RMSNorm - if self.pipeline_stage_manager is not None: self.shard_config.enable_sequence_parallelism = False self.shard_config.enable_sequence_overlap = False diff --git a/tests/test_fp8/test_fp8_all_to_all.py b/tests/test_fp8/test_fp8_all_to_all.py new file mode 100644 index 000000000000..eed3583e9866 --- /dev/null +++ b/tests/test_fp8/test_fp8_all_to_all.py @@ -0,0 +1,34 @@ +import torch +import torch.distributed as dist +from torch.distributed.distributed_c10d import _get_default_group +from torch.testing import assert_close + +from colossalai import launch +from colossalai.accelerator import get_accelerator +from colossalai.shardformer.layer._operation import _all_to_all +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +@parameterize("shape", [(16, 8, 4)]) +@parameterize("scatter_dim", [0, 1, 2]) +@parameterize("dtype", [torch.bfloat16, torch.float16]) +def check_4gpu(shape, scatter_dim, dtype): + world_size = dist.get_world_size() + x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) + output_origin = _all_to_all(x, 
world_size, _get_default_group(), scatter_dim, 0, False) + output_fp8 = _all_to_all(x, world_size, _get_default_group(), scatter_dim, 0, True) + assert_close(output_origin, output_fp8, rtol=0.1, atol=0.1) + + +def run_dist(rank, world_size, port): + launch(rank=rank, world_size=world_size, port=port, host="localhost") + check_4gpu() + + +@rerun_if_address_is_in_use() +def test_all_to_all(): + spawn(run_dist, 4) + + +if __name__ == "__main__": + test_all_to_all() diff --git a/tests/test_fp8/test_fp8_all_to_all_single.py b/tests/test_fp8/test_fp8_all_to_all_single.py new file mode 100644 index 000000000000..1393d5cbfc42 --- /dev/null +++ b/tests/test_fp8/test_fp8_all_to_all_single.py @@ -0,0 +1,34 @@ +import torch +import torch.distributed as dist +from torch.distributed.distributed_c10d import _get_default_group +from torch.testing import assert_close + +from colossalai import launch +from colossalai.accelerator import get_accelerator +from colossalai.shardformer.layer._operation import _all_to_all_single +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +@parameterize("shape", [(1, 8, 16)]) +@parameterize("scatter_dim", [1, 2]) +@parameterize("dtype", [torch.bfloat16, torch.float16]) +def check_4gpu(shape, scatter_dim, dtype): + world_size = dist.get_world_size() + x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) + output_origin = _all_to_all_single(x, world_size, _get_default_group(), scatter_dim, 0, False) + output_fp8 = _all_to_all_single(x, world_size, _get_default_group(), scatter_dim, 0, True) + assert_close(output_origin, output_fp8, rtol=0.1, atol=0.1) + + +def run_dist(rank, world_size, port): + launch(rank=rank, world_size=world_size, port=port, host="localhost") + check_4gpu() + + +@rerun_if_address_is_in_use() +def test_all_to_all_single(): + spawn(run_dist, 4) + + +if __name__ == "__main__": + test_all_to_all_single() diff --git a/tests/test_fp8/test_fp8_allgather.py b/tests/test_fp8/test_fp8_allgather_flat.py similarity index 96% rename from tests/test_fp8/test_fp8_allgather.py rename to tests/test_fp8/test_fp8_allgather_flat.py index 1a4c8511a843..35e8796c2882 100644 --- a/tests/test_fp8/test_fp8_allgather.py +++ b/tests/test_fp8/test_fp8_allgather_flat.py @@ -32,9 +32,9 @@ def run_dist(rank, world_size, port): @rerun_if_address_is_in_use() -def test_all_gather(): +def test_all_gather_flat(): spawn(run_dist, 4) if __name__ == "__main__": - test_all_gather() + test_all_gather_flat() diff --git a/tests/test_fp8/test_fp8_gather.py b/tests/test_fp8/test_fp8_gather.py new file mode 100644 index 000000000000..2c94dd82f20d --- /dev/null +++ b/tests/test_fp8/test_fp8_gather.py @@ -0,0 +1,41 @@ +import torch +from torch.distributed.distributed_c10d import _get_default_group +from torch.testing import assert_close + +from colossalai import launch +from colossalai.accelerator import get_accelerator +from colossalai.shardformer.layer._operation import _gather +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +@parameterize( + "shape", + [ + # (3, 7), (2, 1), (1, 2), (2, 2), (4, 2), + (5,), + (4,), + (2,), + ], +) +@parameterize("dtype", [torch.bfloat16, torch.float16]) +def check_4gpu(shape, dtype): + x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) + output_origin = _gather(x, 0, _get_default_group(), False) + output_fp8 = _gather(x, 0, _get_default_group(), True) + print(output_origin.shape) + print(output_fp8.shape) + assert_close(output_origin, 
output_fp8, rtol=0.1, atol=0.1) + + +def run_dist(rank, world_size, port): + launch(rank=rank, world_size=world_size, port=port, host="localhost") + check_4gpu() + + +@rerun_if_address_is_in_use() +def test_all_gather(): + spawn(run_dist, 4) + + +if __name__ == "__main__": + test_all_gather() diff --git a/tests/test_fp8/test_fp8_reduce_scatter.py b/tests/test_fp8/test_fp8_reduce_scatter.py new file mode 100644 index 000000000000..ae2c7d73b22a --- /dev/null +++ b/tests/test_fp8/test_fp8_reduce_scatter.py @@ -0,0 +1,38 @@ +import torch +from torch.distributed import reduce_scatter +from torch.distributed.distributed_c10d import _get_default_group +from torch.testing import assert_close + +from colossalai import launch +from colossalai.accelerator import get_accelerator +from colossalai.quantization.fp8 import reduce_scatter_fp8 +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +@parameterize("shape", [(16, 8, 4)]) +@parameterize("scatter_dim", [0, 1, 2]) +@parameterize("dtype", [torch.bfloat16, torch.float16]) +def check_4gpu(shape, scatter_dim, dtype): + print(shape, scatter_dim, dtype) + x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) + input_list = list(torch.chunk(x, dim=scatter_dim, chunks=4)) + input_list = [t.contiguous() for t in input_list] + output_origin = torch.empty_like(input_list[0]) + output_fp8 = torch.empty_like(input_list[0]) + reduce_scatter(output_origin, input_list, group=_get_default_group()) + reduce_scatter_fp8(output_fp8, input_list, group=_get_default_group()) + assert_close(output_origin, output_fp8, rtol=0.1, atol=0.1) + + +def run_dist(rank, world_size, port): + launch(rank=rank, world_size=world_size, port=port, host="localhost") + check_4gpu() + + +@rerun_if_address_is_in_use() +def test_reduce_scatter(): + spawn(run_dist, 4) + + +if __name__ == "__main__": + test_reduce_scatter() From 5b124498a48c0633da238fc2824a13d935620367 Mon Sep 17 00:00:00 2001 From: GuangyaoZhang Date: Tue, 30 Jul 2024 09:08:50 +0000 Subject: [PATCH 3/7] fix typo --- tests/test_fp8/test_fp8_gather.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_fp8/test_fp8_gather.py b/tests/test_fp8/test_fp8_gather.py index 2c94dd82f20d..5e21254e253e 100644 --- a/tests/test_fp8/test_fp8_gather.py +++ b/tests/test_fp8/test_fp8_gather.py @@ -11,7 +11,11 @@ @parameterize( "shape", [ - # (3, 7), (2, 1), (1, 2), (2, 2), (4, 2), + (3, 7), + (2, 1), + (1, 2), + (2, 2), + (4, 2), (5,), (4,), (2,), From 7728f2be4ad53a4f0e2eb15f5dd70cff5b7c172f Mon Sep 17 00:00:00 2001 From: GuangyaoZhang Date: Wed, 31 Jul 2024 05:37:59 +0000 Subject: [PATCH 4/7] fix all reduce --- colossalai/quantization/fp8.py | 41 +++++++++++++++++--------- tests/test_fp8/test_fp8_allreduce.py | 43 ++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 14 deletions(-) create mode 100644 tests/test_fp8/test_fp8_allreduce.py diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index 8887310b43a7..22b151e569e2 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -3,6 +3,7 @@ import numpy as np import torch import torch.distributed as dist +import torch.nn.functional as F def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> (torch.Tensor, torch.Tensor): @@ -37,7 +38,9 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) - return ret, scale_inv -def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: 
torch.dtype) -> torch.Tensor: +def cast_from_fp8( + inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype, per_channel_scale=False +) -> torch.Tensor: r""" Args: inp: should be a fp8 torch tensor in one of the types: [torch.float8_e4m3fn, torch.float8_e5m2]. @@ -49,40 +52,46 @@ def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dt if inp.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: raise TypeError("Only float8_e4m3fn and float8_e5m2 are allowed.") - if scale_inv.dim() >= 1: + if per_channel_scale: ret = scale_inv[:, None] * inp.float() else: ret = scale_inv * inp.float() return ret.to(ret_type) -def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None: +def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op="sum", group=None) -> None: r""" This is an in-place operation for compressed all_reduce using fp8. It works like dist.all_reduce but during communication the data is cast to fp8 format. + Args: tensor: torch.Tensor in fp32, fp16, bf16 datatype. fp8_format: e4m3 or e5m2 + op: sum or mean + Returns: None """ + world_size = dist.get_world_size(group=group) input_type = tensor.dtype input_shape = tensor.shape input_device = tensor.device input_size = tensor.numel() - tensor = tensor.flatten() + flat_padded_x = tensor.flatten() - fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 + assert op in ["sum", "mean"], "op can only be sum or mean" - ret, scale = cast_to_fp8(tensor, fp8_format=fp8_format) + if flat_padded_x.size(0) % world_size != 0: + pad_size = world_size - flat_padded_x.size(0) % world_size + flat_padded_x = F.pad(flat_padded_x, (0, pad_size)) + + fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 + ret, scale = cast_to_fp8(flat_padded_x, fp8_format=fp8_format) inp = ret.view(torch.uint8) input_chunks = list(torch.chunk(inp, world_size, dim=0)) - if dist.get_rank() == world_size - 1: - output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)] - else: - output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)] + output_chunks = list(torch.chunk(torch.empty_like(inp), world_size, dim=0)) dist.all_to_all(output_chunks, input_chunks, group=group) scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)] dist.all_gather(scale_list, scale, group=group) @@ -91,15 +100,18 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None: out = out.view(fp8_type) summed_out += cast_from_fp8(out, scale, input_type) + if op == "mean": + summed_out.div_(world_size) + summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format) dist.all_gather(scale_list, scale, group=group) - tensor_list = list(torch.chunk(torch.empty(input_size, device=input_device, dtype=torch.uint8), world_size, dim=0)) + tensor_list = [torch.empty_like(summed_out_fp8.view(torch.uint8)) for _ in range(world_size)] dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8), group=group) for i in range(world_size): tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i] - tensor_out = torch.cat(tensor_list, dim=0) - tensor.data = tensor_out.view(input_shape).to(input_type) + out = torch.cat(tensor_list, dim=0) + tensor.copy_(out[:input_size].view(input_shape).to(input_type)) def cast_to_fp8_pipeline(inp: Any) -> None: @@ -199,6 +211,7 @@ def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2 summed_out = 
torch.zeros_like(output_chunks[0]).to(input_type) for scale, out in zip(output_scale_list, output_chunks): + scale = scale[0] out = out.view(fp8_type) summed_out += cast_from_fp8(out, scale, input_type) output.data = summed_out @@ -275,5 +288,5 @@ def all_gather_into_tensor_flat_fp8( dist.all_gather_into_tensor(buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group) numel = np.prod(output_shape) valid_buffer = buffer[:numel].reshape(output_shape) - valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type) + valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type, per_channel_scale=(len(output_shape) == 2)) output_tensor[:numel].copy_(valid_buffer.view(-1)) diff --git a/tests/test_fp8/test_fp8_allreduce.py b/tests/test_fp8/test_fp8_allreduce.py new file mode 100644 index 000000000000..c273a5a64a6a --- /dev/null +++ b/tests/test_fp8/test_fp8_allreduce.py @@ -0,0 +1,43 @@ +import torch +import torch.distributed +from torch.testing import assert_close + +from colossalai import launch +from colossalai.accelerator import get_accelerator +from colossalai.quantization.fp8 import all_reduce_fp8 +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +@parameterize( + "shape", + [ + (3, 7), + (4, 7), + (7, 4), + (8, 9), + (3), + (7,), + (8,), + ], +) +@parameterize("dtype", [torch.float16]) +def check_4gpu(shape, dtype): + x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) + x_fp8 = x.clone() + torch.distributed.all_reduce(x) + all_reduce_fp8(x_fp8) + assert_close(x, x_fp8, rtol=0.1, atol=0.1) + + +def run_dist(rank, world_size, port): + launch(rank=rank, world_size=world_size, port=port, host="localhost") + check_4gpu() + + +@rerun_if_address_is_in_use() +def test_all_reduce(): + spawn(run_dist, 4) + + +if __name__ == "__main__": + test_all_reduce() From 943ea8c170e73d3fda467d579b4f7e825fd0fb4d Mon Sep 17 00:00:00 2001 From: GuangyaoZhang Date: Wed, 31 Jul 2024 08:39:56 +0000 Subject: [PATCH 5/7] fix pytest failure --- colossalai/quantization/fp8.py | 1 - tests/test_fp8/test_fp8_gather.py | 2 -- tests/test_fp8/test_fp8_reduce_scatter.py | 1 - 3 files changed, 4 deletions(-) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index 22b151e569e2..92147783d0a0 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -211,7 +211,6 @@ def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2 summed_out = torch.zeros_like(output_chunks[0]).to(input_type) for scale, out in zip(output_scale_list, output_chunks): - scale = scale[0] out = out.view(fp8_type) summed_out += cast_from_fp8(out, scale, input_type) output.data = summed_out diff --git a/tests/test_fp8/test_fp8_gather.py b/tests/test_fp8/test_fp8_gather.py index 5e21254e253e..b0d512b3c4fc 100644 --- a/tests/test_fp8/test_fp8_gather.py +++ b/tests/test_fp8/test_fp8_gather.py @@ -26,8 +26,6 @@ def check_4gpu(shape, dtype): x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) output_origin = _gather(x, 0, _get_default_group(), False) output_fp8 = _gather(x, 0, _get_default_group(), True) - print(output_origin.shape) - print(output_fp8.shape) assert_close(output_origin, output_fp8, rtol=0.1, atol=0.1) diff --git a/tests/test_fp8/test_fp8_reduce_scatter.py b/tests/test_fp8/test_fp8_reduce_scatter.py index ae2c7d73b22a..1b4181a3141d 100644 --- a/tests/test_fp8/test_fp8_reduce_scatter.py +++ b/tests/test_fp8/test_fp8_reduce_scatter.py @@ -13,7 +13,6 @@ 
@parameterize("scatter_dim", [0, 1, 2]) @parameterize("dtype", [torch.bfloat16, torch.float16]) def check_4gpu(shape, scatter_dim, dtype): - print(shape, scatter_dim, dtype) x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) input_list = list(torch.chunk(x, dim=scatter_dim, chunks=4)) input_list = [t.contiguous() for t in input_list] From e5c0a8ebbc18c9f18bbb5053f20cf1ac0153f300 Mon Sep 17 00:00:00 2001 From: GuangyaoZhang Date: Fri, 2 Aug 2024 04:32:50 +0000 Subject: [PATCH 6/7] fix reduce op and move function to fp8.py --- colossalai/quantization/fp8.py | 76 +++++++++++++++++++- colossalai/shardformer/layer/_operation.py | 76 +++++--------------- tests/test_fp8/test_fp8_all_to_all.py | 17 +++-- tests/test_fp8/test_fp8_all_to_all_single.py | 19 ++--- tests/test_fp8/test_fp8_allreduce.py | 15 ++-- tests/test_fp8/test_fp8_gather.py | 15 ++-- tests/test_fp8/test_fp8_reduce_scatter.py | 5 +- 7 files changed, 136 insertions(+), 87 deletions(-) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index 92147783d0a0..493d4e5c92fb 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -4,6 +4,7 @@ import torch import torch.distributed as dist import torch.nn.functional as F +from torch.distributed import ReduceOp def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> (torch.Tensor, torch.Tensor): @@ -59,7 +60,7 @@ def cast_from_fp8( return ret.to(ret_type) -def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op="sum", group=None) -> None: +def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None) -> None: r""" This is an in-place operation for compressed all_reduce using fp8. It works like dist.all_reduce but during communication the data is cast to fp8 format. 
@@ -80,7 +81,7 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op="sum", group=None input_size = tensor.numel() flat_padded_x = tensor.flatten() - assert op in ["sum", "mean"], "op can only be sum or mean" + assert op in [ReduceOp.SUM, ReduceOp.AVG], "op can only be ReduceOp.SUM or ReduceOp.AVG" if flat_padded_x.size(0) % world_size != 0: pad_size = world_size - flat_padded_x.size(0) % world_size @@ -100,7 +101,7 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op="sum", group=None out = out.view(fp8_type) summed_out += cast_from_fp8(out, scale, input_type) - if op == "mean": + if op == ReduceOp.AVG: summed_out.div_(world_size) summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format) @@ -289,3 +290,72 @@ def all_gather_into_tensor_flat_fp8( valid_buffer = buffer[:numel].reshape(output_shape) valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type, per_channel_scale=(len(output_shape) == 2)) output_tensor[:numel].copy_(valid_buffer.view(-1)) + + +def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2"): + + world_size = dist.get_world_size(group) + + input_type = input_list[0].dtype + fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 + scale_list = [] + tensor_list = [] + + for i in range(world_size): + input_tensor = input_list[i] + ret, scale = cast_to_fp8(input_tensor, fp8_format=fp8_format) + scale_list.append(scale) + ret = ret.view(torch.uint8) + tensor_list.append(ret) + + output_scale_list = [torch.empty_like(x) for x in scale_list] + output_tensor_list = [torch.empty_like(x) for x in tensor_list] + dist.all_to_all(output_tensor_list, tensor_list, group=group) + dist.all_to_all(output_scale_list, scale_list, group=group) + + for i in range(world_size): + scale = output_scale_list[i] + tensor = output_tensor_list[i] + tensor = tensor.view(fp8_type) + output_list[i].copy_(cast_from_fp8(tensor, scale, input_type)) + + +def all_to_all_single_fp8(output_tensor, input_tensor, group=None, fp8_format="e5m2"): + + world_size = dist.get_world_size(group) + + per_slice_len = input_tensor.size(0) // world_size + input_type = input_tensor.dtype + ret, scale = cast_to_fp8(input_tensor, fp8_format=fp8_format) + fp8_type = ret.dtype + input_tensor = ret.view(torch.uint8) + tensor = torch.empty_like(input_tensor) + scale_list = [torch.empty_like(scale) for _ in range(world_size)] + dist.all_to_all_single(tensor, input_tensor, group=group) + dist.all_gather(scale_list, scale, group=group) + cast_tensor_list = [] + + for i in range(world_size): + output_part = tensor[per_slice_len * i : per_slice_len * (i + 1)].view(fp8_type) + output_part = cast_from_fp8(output_part, scale_list[i], input_type) + cast_tensor_list.append(output_part) + output_tensor.copy_(torch.concatenate(cast_tensor_list, dim=0)) + + +def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"): + + world_size = dist.get_world_size(group) + + input_type = input_.dtype + ret, scale = cast_to_fp8(input_, fp8_format=fp8_format) + fp8_type = ret.dtype + input_ = ret.view(torch.uint8) + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(world_size)] + torch.distributed.all_gather(tensor_list, input_, group=group) + torch.distributed.all_gather(scale_list, scale, group=group) + + for i in range(world_size): + output = tensor_list[i].view(fp8_type) + scale = scale_list[i] + output_list[i].copy_(cast_from_fp8(output, scale, input_type)) 
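[Note: all three collectives above share the same quantize/communicate/dequantize pattern. For reference, a minimal single-process sketch of the round trip they are built on — illustrative only, not part of the patch; it assumes a PyTorch build that ships the float8 dtypes, and the loose tolerances mirror the tests in this series:

import torch

from colossalai.quantization.fp8 import cast_from_fp8, cast_to_fp8

x = torch.rand(16, 8, dtype=torch.bfloat16)
# quantize: returns the fp8 payload plus the inverse scale needed to decode it
x_fp8, scale_inv = cast_to_fp8(x, fp8_format="e4m3")
# the collectives ship x_fp8 (viewed as uint8) together with scale_inv, then decode:
x_restored = cast_from_fp8(x_fp8, scale_inv, torch.bfloat16)
torch.testing.assert_close(x, x_restored, rtol=0.1, atol=0.1)

e4m3 keeps 3 mantissa bits, so the relative error stays within the 0.1 tolerance used throughout the tests.]
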
diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index bd69dc6a08f7..e98506a5f35f 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -14,7 +14,13 @@ except ImportError: _grad_accum_fusion_available = False -from colossalai.quantization.fp8 import all_reduce_fp8, cast_from_fp8, cast_to_fp8, reduce_scatter_fp8 +from colossalai.quantization.fp8 import ( + all_reduce_fp8, + all_to_all_fp8, + all_to_all_single_fp8, + gather_fp8, + reduce_scatter_fp8, +) class FusedLayerNormAffineFunction1D(torch.autograd.Function): @@ -946,34 +952,14 @@ def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_for if world_size == 1: return input_ + input_ = input_.contiguous() + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] if fp8_communication: - input_type = input_.dtype - ret, scale = cast_to_fp8(input_, fp8_format=fp8_format) - fp8_type = ret.dtype - input_ = ret.view(torch.uint8) - input_ = input_.contiguous() - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - scale = torch.tensor(scale, dtype=torch.float32).to(input_.device) - scale_list = [torch.ones(1, dtype=torch.float32, device=input_.device) for _ in range(world_size)] - - scale = torch.tensor(scale).to(input_.device) torch.distributed.all_gather(tensor_list, input_, group=process_group) - torch.distributed.all_gather(scale_list, scale, group=process_group) - - cast_tensor_list = [] - for output, scale in zip(tensor_list, scale_list): - scale = torch.tensor(scale[0]) - output = output.view(fp8_type) - output = cast_from_fp8(output, scale, input_type) - cast_tensor_list.append(output) - - output = torch.cat(cast_tensor_list, dim=dim).contiguous() - else: - input_ = input_.contiguous() - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - torch.distributed.all_gather(tensor_list, input_, group=process_group) - output = torch.cat(tensor_list, dim=dim).contiguous() + gather_fp8(tensor_list, input_, fp8_format=fp8_format, group=process_group) + + output = torch.cat(tensor_list, dim=dim).contiguous() return output @@ -1003,25 +989,11 @@ def _reduce_scatter(input_, dim=1, process_group=None): def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"): + input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)] + output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] if fp8_communication: - input_type = input_.dtype - ret, scale = cast_to_fp8(input_, fp8_format=fp8_format) - fp8_type = ret.dtype - input_ = ret.view(torch.uint8) - input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)] - output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] - scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(world_size)] - dist.all_to_all(output_list, input_list, group=group) - dist.all_gather(scale_list, scale, group=group) - cast_tensor_list = [] - for output, scale in zip(output_list, scale_list): - output = output.view(fp8_type) - output = cast_from_fp8(output, scale, input_type) - cast_tensor_list.append(output) - output_list = cast_tensor_list + all_to_all_fp8(output_list, input_list, group=group, fp8_format=fp8_format) else: - input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)] - output_list = [torch.empty_like(input_list[0]) for _ in 
range(world_size)] dist.all_to_all(output_list, input_list, group=group) return torch.cat(output_list, dim=gather_dim).contiguous() @@ -1040,23 +1012,11 @@ def _all_to_all_single( .contiguous() ) + output = torch.empty_like(input_t) if fp8_communication: - input_type = input_t.dtype - ret, scale = cast_to_fp8(input_t, fp8_format=fp8_format) - fp8_type = ret.dtype - input_t = ret.view(torch.uint8) - output = torch.empty_like(input_t) - scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(seq_world_size)] - dist.all_to_all_single(output, input_t, group=group) - dist.all_gather(scale_list, scale, group=group) - cast_tensor_list = [] - for output_part, scale in zip(output, scale_list): - output_part = output_part.view(fp8_type) - output_part = cast_from_fp8(output_part, scale, input_type) - cast_tensor_list.append(output_part) - output = torch.stack(cast_tensor_list, dim=0) + all_to_all_single_fp8(output, input_t, group=group, fp8_format=fp8_format) else: - output = torch.empty_like(input_t) + dist.all_to_all_single(output, input_t, group=group) if scatter_dim < 2: diff --git a/tests/test_fp8/test_fp8_all_to_all.py b/tests/test_fp8/test_fp8_all_to_all.py index eed3583e9866..884aab744ba0 100644 --- a/tests/test_fp8/test_fp8_all_to_all.py +++ b/tests/test_fp8/test_fp8_all_to_all.py @@ -5,19 +5,24 @@ from colossalai import launch from colossalai.accelerator import get_accelerator -from colossalai.shardformer.layer._operation import _all_to_all +from colossalai.quantization.fp8 import all_to_all_fp8 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn @parameterize("shape", [(16, 8, 4)]) @parameterize("scatter_dim", [0, 1, 2]) @parameterize("dtype", [torch.bfloat16, torch.float16]) -def check_4gpu(shape, scatter_dim, dtype): +@parameterize("fp8_format", ["e4m3", "e5m2"]) +def check_4gpu(shape, scatter_dim, dtype, fp8_format): world_size = dist.get_world_size() - x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) - output_origin = _all_to_all(x, world_size, _get_default_group(), scatter_dim, 0, False) - output_fp8 = _all_to_all(x, world_size, _get_default_group(), scatter_dim, 0, True) - assert_close(output_origin, output_fp8, rtol=0.1, atol=0.1) + input_tensor = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) + input_tensor_list = list(torch.chunk(input_tensor, world_size, scatter_dim)) + input_tensor_list = [x.contiguous() for x in input_tensor_list] + output_tensor_list_fp8 = [torch.empty_like(x) for x in input_tensor_list] + output_tensor_list = [torch.empty_like(x) for x in input_tensor_list] + all_to_all_fp8(output_tensor_list_fp8, input_tensor_list, group=_get_default_group(), fp8_format=fp8_format) + dist.all_to_all(output_tensor_list, input_tensor_list, group=_get_default_group()) + assert_close(output_tensor_list_fp8, output_tensor_list, rtol=0.1, atol=0.1) def run_dist(rank, world_size, port): diff --git a/tests/test_fp8/test_fp8_all_to_all_single.py b/tests/test_fp8/test_fp8_all_to_all_single.py index 1393d5cbfc42..70765f2d48de 100644 --- a/tests/test_fp8/test_fp8_all_to_all_single.py +++ b/tests/test_fp8/test_fp8_all_to_all_single.py @@ -5,19 +5,22 @@ from colossalai import launch from colossalai.accelerator import get_accelerator -from colossalai.shardformer.layer._operation import _all_to_all_single +from colossalai.quantization.fp8 import all_to_all_single_fp8 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +dist.all_to_all_single 
-@parameterize("shape", [(1, 8, 16)]) -@parameterize("scatter_dim", [1, 2]) + +@parameterize("shape", [(4), (8, 7), (4, 8, 16)]) @parameterize("dtype", [torch.bfloat16, torch.float16]) -def check_4gpu(shape, scatter_dim, dtype): - world_size = dist.get_world_size() +@parameterize("fp8_format", ["e4m3", "e5m2"]) +def check_4gpu(shape, dtype, fp8_format): x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) - output_origin = _all_to_all_single(x, world_size, _get_default_group(), scatter_dim, 0, False) - output_fp8 = _all_to_all_single(x, world_size, _get_default_group(), scatter_dim, 0, True) - assert_close(output_origin, output_fp8, rtol=0.1, atol=0.1) + output = torch.empty_like(x) + output_fp8 = torch.empty_like(x) + all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), fp8_format=fp8_format) + dist.all_to_all_single(output, x, group=_get_default_group()) + assert_close(output, output_fp8, rtol=0.1, atol=0.1) def run_dist(rank, world_size, port): diff --git a/tests/test_fp8/test_fp8_allreduce.py b/tests/test_fp8/test_fp8_allreduce.py index c273a5a64a6a..c23959b5d0da 100644 --- a/tests/test_fp8/test_fp8_allreduce.py +++ b/tests/test_fp8/test_fp8_allreduce.py @@ -1,5 +1,5 @@ import torch -import torch.distributed +import torch.distributed as dist from torch.testing import assert_close from colossalai import launch @@ -20,12 +20,17 @@ (8,), ], ) -@parameterize("dtype", [torch.float16]) -def check_4gpu(shape, dtype): +@parameterize("dtype", [torch.float16, torch.bfloat16]) +@parameterize("fp8_format", ["e4m3", "e5m2"]) +def check_4gpu(shape, dtype, fp8_format): x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) x_fp8 = x.clone() - torch.distributed.all_reduce(x) - all_reduce_fp8(x_fp8) + dist.all_reduce(x) + all_reduce_fp8(x_fp8, fp8_format=fp8_format) + assert_close(x, x_fp8, rtol=0.1, atol=0.1) + + dist.all_reduce(x, op=dist.ReduceOp.AVG) + all_reduce_fp8(x_fp8, op=dist.ReduceOp.AVG, fp8_format=fp8_format) assert_close(x, x_fp8, rtol=0.1, atol=0.1) diff --git a/tests/test_fp8/test_fp8_gather.py b/tests/test_fp8/test_fp8_gather.py index b0d512b3c4fc..79d1d4ea49e6 100644 --- a/tests/test_fp8/test_fp8_gather.py +++ b/tests/test_fp8/test_fp8_gather.py @@ -1,10 +1,11 @@ import torch +import torch.distributed as dist from torch.distributed.distributed_c10d import _get_default_group from torch.testing import assert_close from colossalai import launch from colossalai.accelerator import get_accelerator -from colossalai.shardformer.layer._operation import _gather +from colossalai.quantization.fp8 import gather_fp8 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn @@ -22,11 +23,15 @@ ], ) @parameterize("dtype", [torch.bfloat16, torch.float16]) -def check_4gpu(shape, dtype): +@parameterize("fp8_format", ["e4m3", "e5m2"]) +def check_4gpu(shape, dtype, fp8_format): + world_size = dist.get_world_size() x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) - output_origin = _gather(x, 0, _get_default_group(), False) - output_fp8 = _gather(x, 0, _get_default_group(), True) - assert_close(output_origin, output_fp8, rtol=0.1, atol=0.1) + output_list = [torch.empty_like(x) for _ in range(world_size)] + output_list_fp8 = [torch.empty_like(x) for _ in range(world_size)] + gather_fp8(output_list_fp8, x, group=_get_default_group(), fp8_format=fp8_format) + dist.all_gather(output_list, x, group=_get_default_group()) + assert_close(output_list, output_list_fp8, rtol=0.1, atol=0.1) def 
run_dist(rank, world_size, port): diff --git a/tests/test_fp8/test_fp8_reduce_scatter.py b/tests/test_fp8/test_fp8_reduce_scatter.py index 1b4181a3141d..c18446e39ea0 100644 --- a/tests/test_fp8/test_fp8_reduce_scatter.py +++ b/tests/test_fp8/test_fp8_reduce_scatter.py @@ -12,14 +12,15 @@ @parameterize("shape", [(16, 8, 4)]) @parameterize("scatter_dim", [0, 1, 2]) @parameterize("dtype", [torch.bfloat16, torch.float16]) -def check_4gpu(shape, scatter_dim, dtype): +@parameterize("fp8_format", ["e4m3", "e5m2"]) +def check_4gpu(shape, scatter_dim, dtype, fp8_format): x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) input_list = list(torch.chunk(x, dim=scatter_dim, chunks=4)) input_list = [t.contiguous() for t in input_list] output_origin = torch.empty_like(input_list[0]) output_fp8 = torch.empty_like(input_list[0]) reduce_scatter(output_origin, input_list, group=_get_default_group()) - reduce_scatter_fp8(output_fp8, input_list, group=_get_default_group()) + reduce_scatter_fp8(output_fp8, input_list, group=_get_default_group(), fp8_format=fp8_format) assert_close(output_origin, output_fp8, rtol=0.1, atol=0.1) From b6b528ce297b4ca6485517a3f2e43ac71a320f6e Mon Sep 17 00:00:00 2001 From: GuangyaoZhang Date: Fri, 2 Aug 2024 04:38:16 +0000 Subject: [PATCH 7/7] fix typo --- colossalai/quantization/fp8.py | 6 +++--- colossalai/shardformer/layer/_operation.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index 493d4e5c92fb..b003e90e89f8 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -68,7 +68,7 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, gro Args: tensor: torch.Tensor in fp32, fp16, bf16 datatype. fp8_format: e4m3 or e5m2 - op: sum or mean + op: ReduceOp.SUM or ReduceOp.AVG Returns: None @@ -352,8 +352,8 @@ def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"): input_ = ret.view(torch.uint8) tensor_list = [torch.empty_like(input_) for _ in range(world_size)] scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(world_size)] - torch.distributed.all_gather(tensor_list, input_, group=group) - torch.distributed.all_gather(scale_list, scale, group=group) + dist.all_gather(tensor_list, input_, group=group) + dist.all_gather(scale_list, scale, group=group) for i in range(world_size): output = tensor_list[i].view(fp8_type) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index e98506a5f35f..a27fd35c192d 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -955,9 +955,9 @@ def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_for input_ = input_.contiguous() tensor_list = [torch.empty_like(input_) for _ in range(world_size)] if fp8_communication: - torch.distributed.all_gather(tensor_list, input_, group=process_group) - else: gather_fp8(tensor_list, input_, fp8_format=fp8_format, group=process_group) + else: + dist.all_gather(tensor_list, input_, group=process_group) output = torch.cat(tensor_list, dim=dim).contiguous()
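
[Note: taken together, the series routes `fp8_communication` through every collective on Llama's tensor- and sequence-parallel paths and verifies each one against its full-precision counterpart. A quick end-to-end check of the in-place fp8 all-reduce, written in the same style as the tests above — an illustrative sketch that assumes four visible devices:

import torch
import torch.distributed as dist
from torch.distributed import ReduceOp

from colossalai import launch
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import all_reduce_fp8
from colossalai.testing import spawn


def run_dist(rank, world_size, port):
    launch(rank=rank, world_size=world_size, port=port, host="localhost")
    x = torch.rand((8, 7), dtype=torch.float16, device=get_accelerator().get_current_device())
    x_fp8 = x.clone()
    dist.all_reduce(x, op=ReduceOp.AVG)  # full-precision reference
    all_reduce_fp8(x_fp8, op=ReduceOp.AVG, fp8_format="e4m3")  # fp8 on the wire, in place
    torch.testing.assert_close(x, x_fp8, rtol=0.1, atol=0.1)


if __name__ == "__main__":
    spawn(run_dist, 4)

The rtol/atol of 0.1 match the tolerances the series' own tests use for fp8 round trips.]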