From 03ce3c53ff14fe27834fcdbec29f6f70bcc50f19 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 13 Mar 2025 13:24:40 +0800 Subject: [PATCH 01/19] [fix] fix qwen VocabParallelLMHead1D and gather output --- colossalai/shardformer/policies/qwen2.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py index 84d2b2fdbd99..5caeae7060a7 100644 --- a/colossalai/shardformer/policies/qwen2.py +++ b/colossalai/shardformer/policies/qwen2.py @@ -11,8 +11,10 @@ Linear1D_Row, LinearWithGradAccum, PaddingEmbedding, + PaddingLMHead, RMSNorm, VocabParallelEmbedding1D, + VocabParallelLMHead1D, ) from ..modeling.qwen2 import ( @@ -429,8 +431,12 @@ def module_policy(self): sub_module_replacement=[ SubModuleReplacementDescription( suffix="lm_head", - target_module=Linear1D_Col, - kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv), + target_module=VocabParallelLMHead1D, + kwargs=dict( + gather_output=not self.shard_config.parallel_output, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ) ], method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}, @@ -444,7 +450,7 @@ def module_policy(self): sub_module_replacement=[ SubModuleReplacementDescription( suffix="lm_head", - target_module=LinearWithGradAccum, + target_module=PaddingLMHead, kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv), ) ], From b835d1bcd3bfa157c78ef80dfc837f0087a7ae5d Mon Sep 17 00:00:00 2001 From: Tong Li Date: Thu, 13 Mar 2025 14:52:09 +0800 Subject: [PATCH 02/19] fix tp bug --- colossalai/shardformer/policies/qwen2.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py index 5caeae7060a7..fd14029a3a36 100644 --- a/colossalai/shardformer/policies/qwen2.py +++ b/colossalai/shardformer/policies/qwen2.py @@ -452,7 +452,16 @@ def module_policy(self): suffix="lm_head", target_module=PaddingLMHead, kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv), - ) + ), + SubModuleReplacementDescription( + suffix="lm_head", + target_module=VocabParallelLMHead1D, + kwargs={ + "gather_output": not self.shard_config.parallel_output, + "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by, + "fp8_communication": self.shard_config.fp8_communication, + }, + ), ], method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}, ) From 137ec17781193cb4aedd8f3276d2c35cb8fddc89 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Thu, 13 Mar 2025 14:55:26 +0800 Subject: [PATCH 03/19] fix consumer --- applications/ColossalChat/coati/distributed/consumer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 1e85cccb3c5b..380a2ee1b78a 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -73,6 +73,8 @@ def setup(self) -> None: ) if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config: plugin_config["microbatch_size"] = self.microbatch_size + if self.plugin_config.get("tp_size", 1) > 1: + plugin_config["parallel_output"] = False plugin_config.update(self.plugin_config) self.plugin = 
HybridParallelPlugin(**plugin_config) self.booster = Booster(plugin=self.plugin) From ce8a8b30f02264474ddd2a75129b466ccf23e292 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Fri, 14 Mar 2025 18:41:30 +0800 Subject: [PATCH 04/19] [feat] Support Distributed LogProb for GRPO Training --- .../coati/distributed/consumer.py | 2 +- .../coati/distributed/grpo_consumer.py | 16 +- .../ColossalChat/coati/distributed/utils.py | 42 +++- colossalai/shardformer/layer/__init__.py | 4 +- colossalai/shardformer/layer/loss.py | 197 +++++++++++++++++- colossalai/shardformer/modeling/qwen2.py | 6 +- .../test_layer/test_dist_log_prob.py | 67 ++++++ 7 files changed, 319 insertions(+), 15 deletions(-) create mode 100644 tests/test_shardformer/test_layer/test_dist_log_prob.py diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 380a2ee1b78a..6aa63086f87c 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -66,7 +66,7 @@ def setup(self) -> None: launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0) plugin_config = dict( - tp_size=1, + tp_size=2, pp_size=1, precision="bf16", zero_stage=1, diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 55dfd09ab244..bee66c770655 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -8,7 +8,6 @@ from coati.distributed.loss import PolicyLoss from coati.distributed.reward.reward_fn import math_reward_fn from coati.distributed.reward.verifiable_reward import VerifiableReward -from coati.distributed.utils import calc_action_log_probs from coati.trainer.utils import all_reduce_mean from transformers import AutoModelForCausalLM, AutoTokenizer @@ -116,18 +115,23 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer) with ctx: - policy_model_logits = self.policy_model( + policy_model_log_probs = self.policy_model( input_ids=data["input_ids"], attention_mask=data["attention_mask"], + return_dist_log_prob=True, )["logits"] - action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action) - + policy_model_log_probs = policy_model_log_probs.reshape(8, -1) + action_log_probs = policy_model_log_probs[:, -num_action:] with torch.no_grad(): - reference_model_logits = self.reference_model( + reference_model_log_probs = self.reference_model( input_ids=data["input_ids"], attention_mask=data["attention_mask"], + return_dist_log_prob=True, )["logits"] - reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action) + reference_model_log_probs = reference_model_log_probs.reshape(8, -1) + reference_action_log_probs = reference_model_log_probs[:, -num_action:] + # reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action) + # reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action, self.booster.shard_config, reference_model_logits.shape[-1]) per_token_kl = ( torch.exp(reference_action_log_probs - action_log_probs) diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py index 
98b54815b5b4..b9f6e64643bd 100644 --- a/applications/ColossalChat/coati/distributed/utils.py +++ b/applications/ColossalChat/coati/distributed/utils.py @@ -2,6 +2,8 @@ import torch +from colossalai.shardformer.layer.loss import dist_log_prob + def unbind_batch(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]: batches = [] @@ -66,18 +68,52 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T return per_label_logps.squeeze(-1) -def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor: +# def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor: +# """Calculate action log probs. + +# Args: +# output (torch.Tensor): Output tensor of Actor.forward.logits. +# sequences (torch.LongTensor): Input sequences. +# num_actions (int): Number of actions. + +# Returns: +# torch.Tensor: Action log probs. +# """ +# log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) +# return log_probs[:, -num_actions:] + + +def calc_action_log_probs( + logits: torch.Tensor, + sequences: torch.LongTensor, + num_actions: int, + shard_config, + vocab_size: int, +) -> torch.Tensor: """Calculate action log probs. Args: - output (torch.Tensor): Output tensor of Actor.forward.logits. + logits (torch.Tensor): Output tensor of Actor.forward.logits. sequences (torch.LongTensor): Input sequences. num_actions (int): Number of actions. + shard_config + vocab_size + Returns: torch.Tensor: Action log probs. """ - log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) + print(f"sequences {sequences.shape} logits {logits.shape}") + log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype) + + # log_probs = dist_log_prob(sequences[:, 1:], logits[:, :-1, :], shard_config, vocab_size, logits.dtype) + # # labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] + # logits: torch.Tensor, # [B, S, Vocab_size] + # shard_config: ShardConfig, + # vocab_size: int, + # dtype: torch.dtype, + # seq_dim: int = 1, + # log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) return log_probs[:, -num_actions:] diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 0bd1b60923e9..a1b80bf56b63 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -3,7 +3,7 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHead, VocabParallelLMHead1D -from .loss import cross_entropy_1d, dist_cross_entropy +from .loss import cross_entropy_1d, dist_cross_entropy, dist_log_prob, dist_log_prob_1d from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm from .parallel_module import ParallelModule from .qkv_fused_linear import ( @@ -28,6 +28,8 @@ "DropoutForReplicatedInput", "cross_entropy_1d", "dist_cross_entropy", + "dist_log_prob_1d", + "dist_log_prob", "BaseLayerNorm", "LayerNorm", "RMSNorm", diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index 0e2241af9fc9..e19299608114 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -2,14 +2,21 @@ import torch.distributed as dist from torch.autograd import Function from torch.distributed import ProcessGroup -from torch.nn import 
CrossEntropyLoss +from torch.nn import CrossEntropyLoss, LogSoftmax from colossalai.shardformer.layer._operation import reduce_forward from colossalai.shardformer.shard import ShardConfig from .utils import is_share_sp_tp -__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"] +__all__ = [ + "DistCrossEntropy", + "cross_entropy_1d", + "dist_cross_entropy", + "DistLogProb", + "dist_log_prob_1d", + "dist_log_prob", +] _IGNORE_IDX = -100 @@ -137,6 +144,86 @@ def backward(ctx, grad_output): return grad_logits, None, None, None, None, None, None +class DistLogProb(Function): + r""" + Overwrite the forward and backward function to calculate the cross entropy loss before gather + + Args: + Function (:class:`torch.autograd.Function`): default + """ + + @staticmethod + def forward( + ctx, + vocab_logits: torch.Tensor, + target: torch.Tensor, + process_group: ProcessGroup, + vocab_size: int, + dtype=torch.float32, + ): + ###### + # 1.log softmax + ###### + # local max for softmax + logits_max = torch.max(vocab_logits, dim=-1)[0] + handle = dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group, async_op=True) + # get rank and worldsize + rank = dist.get_rank(group=process_group) + world_size = dist.get_world_size(group=process_group) + # cal vocal size + if vocab_size is None: + partition_vocab_size = vocab_logits.size()[-1] + global_vocab_size = partition_vocab_size * world_size + else: + global_vocab_size = vocab_size + partition_vocab_size = global_vocab_size // world_size + # down and up threshold + delta = (global_vocab_size + world_size - 1) // world_size + down_threshold = rank * delta + up_threshold = down_threshold + delta + if up_threshold > global_vocab_size: + up_threshold = global_vocab_size + # mask + mask = (target < down_threshold) | (target >= up_threshold) + masked_target = target.clone() - down_threshold + masked_target[mask] = 0 + masked_target_1d = masked_target.view(-1).contiguous() + # wait for allreduce + handle.wait() + vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1) + # cal exp_logits + exp_logits = torch.exp(vocab_logits) + # cal local sum_exp_logits + sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32) + # all_reduce get global sum_exp_logits + dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group) + # cal log_softmax + log_probs = vocab_logits - torch.log(sum_exp_logits.unsqueeze(dim=-1)) + + ###### + # 2.gather via labels + ###### + log_probs = log_probs.gather(dim=-1, index=masked_target.unsqueeze(-1)) + # set masked val to zero, then all reduce + log_probs[mask.unsqueeze(-1)] = 0 + # allreduce log_probs with ops SUM + dist.all_reduce(log_probs, op=dist.ReduceOp.SUM, group=process_group) + ctx.save_for_backward(exp_logits, mask, masked_target_1d, sum_exp_logits) + ctx.dtype = dtype + return log_probs + + @staticmethod + def backward(ctx, grad_output): + exp_logits, mask, masked_target_1d, sum_exp_logits = ctx.saved_tensors + softmax_logits = exp_logits / sum_exp_logits.unsqueeze(dim=-1) + partion_vocab_size = softmax_logits.shape[-1] + softmax_logits_2d = softmax_logits.view(-1, partion_vocab_size) + update = 1.0 - mask.view(-1).float().to(ctx.dtype) + softmax_logits_2d[torch.arange(0, softmax_logits_2d.shape[0]), masked_target_1d] -= update + grad_logits = -softmax_logits.mul_(grad_output) + return grad_logits, None, None, None, None, None, None + + def cross_entropy_1d( vocab_logits: torch.Tensor, labels: torch.Tensor, @@ -149,6 +236,16 @@ def cross_entropy_1d( return 
DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype, mode) +def dist_log_prob_1d( + vocab_logits: torch.Tensor, + labels: torch.Tensor, + process_group: ProcessGroup = None, + vocab_size: int = None, + dtype: torch.dtype = None, +) -> torch.Tensor: + return DistLogProb.apply(vocab_logits, labels, process_group, vocab_size, dtype) + + def dist_cross_entropy( labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] logits: torch.Tensor, # [B, S, Vocab_size] @@ -243,3 +340,99 @@ def dist_cross_entropy( loss, num_nonzero = loss[0], loss[1].detach() loss = (loss / num_nonzero).squeeze() return loss + + +def dist_log_prob( + labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] + logits: torch.Tensor, # [B, S, Vocab_size] + shard_config: ShardConfig, + vocab_size: int, + dtype: torch.dtype, + seq_dim: int = 1, +) -> torch.Tensor: + """ + Helper to compute cross entropy loss for most shardformer models supporting PP, TP and SP. + """ + # Split labels if not gather output + sp_group = shard_config.sequence_parallel_process_group + dist.get_rank(sp_group) + sp_size = shard_config.sequence_parallel_size + sp_mode = shard_config.sequence_parallelism_mode + parallel_output = shard_config.parallel_output + is_tp = shard_config.enable_tensor_parallelism + is_packed = labels.dim() == 2 + if is_packed: + bs, seq_len = labels.shape + else: + # padded sequence + seq_len = labels.shape[-1] + logits = logits.reshape(-1, *logits.shape[2:]) + seq_dim = 0 + + # Shift labels to predict the next token, and remove the tail logit predicting + sp_size > 1 and (not is_share_sp_tp(sp_mode)) + split_labels_here = seq_len // sp_size == logits.size(seq_dim) # ring attn splits labels before forward + + # if sp_mode == "ring_attn": + # # For Zigzag Ring Attention, labels should've been split and + # # shifted by RingAttention.prepare_varlen_batch() + # if sp_rank == 0: + # logits = logits[..., :-1, :] + # logits = torch.cat([logits, torch.full_like(logits[:, :1, :], _IGNORE_IDX)], dim=seq_dim) + # elif is_sp: + # # Shift only once: either before splitting or in the last rank without splitting + # if split_labels_here or (sp_rank == sp_size - 1): + # labels = labels[..., 1:] + # if split_labels_here: + # labels = labels.split(seq_len // sp_size, dim=-1)[sp_rank] + + # if sp_rank == sp_size - 1: + # logits = logits[..., :-1, :] + # # Pad logits and labels to the same shape across all ranks for TP all_reduce + # if is_tp and parallel_output: + # # If is packed sequence (label dim is 1), then each seq already has the end label token padded. + # # torch.cat is faster than F.pad... 
+ # pad_shape = (logits.shape[0], 1, *logits.shape[2:]) if is_packed else (1, *logits.shape[1:]) + # padding = torch.full(pad_shape, _IGNORE_IDX, dtype=logits.dtype, device=logits.device) + # logits = torch.cat([logits, padding], dim=seq_dim) + # pad_shape = (labels.shape[0], 1) if is_packed else (1,) + # padding = torch.full(pad_shape, _IGNORE_IDX, dtype=labels.dtype, device=labels.device) + # labels = torch.cat([labels, padding], dim=seq_dim) + # else: + # TODO:support sp + labels = labels[..., 1:] + logits = logits[..., :-1, :] + labels = labels.contiguous() + logits = logits.contiguous() + # num_nonzero = (labels != _IGNORE_IDX).sum() + assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}" + + # Flatten the tokens + loss_fct = LogSoftmax() + labels = labels.view(-1) + + if is_tp and parallel_output: + # Cross entropy with all-reduce for TP + new_vocab_size = logits.shape[-1] + logits = logits.view(-1, new_vocab_size) + loss = dist_log_prob_1d( + logits, + labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=vocab_size, + dtype=dtype, + ) + else: + # NOTE if use TP and not parallel_output, the output is gathered in VocabParallelLMHead1D + logits = logits.view(-1, logits.size(-1)) + loss = loss_fct(logits) + + # # Reduce loss instead of gathering logits over seq dim for savings + # if split_labels_here or sp_mode == "ring_attn": + # # Get the global non-zero count + # loss = torch.stack((loss, num_nonzero)) + # # Rescale to offset the grad / (DP * SP) in HybridParallelPlugin + # loss = reduce_forward(loss, sp_group, grad_scale=sp_size) + # loss, num_nonzero = loss[0], loss[1].detach() + # loss = (loss / num_nonzero).squeeze() + return loss diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index 569fc4a459c5..d4d89c7b2fae 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -35,7 +35,7 @@ from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward from colossalai.shardformer.shard import ShardConfig -from ..layer import ColoAttention, dist_cross_entropy +from ..layer import ColoAttention, dist_cross_entropy, dist_log_prob from ..layer._operation import gather_sp_output from ..layer.utils import is_share_sp_tp @@ -779,6 +779,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + return_dist_log_prob: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -832,7 +833,8 @@ def forward( loss = None if labels is not None: loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype) - + if return_dist_log_prob: + logits = dist_log_prob(input_ids, logits, shard_config, self.lm_head.out_features, logits.dtype) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output diff --git a/tests/test_shardformer/test_layer/test_dist_log_prob.py b/tests/test_shardformer/test_layer/test_dist_log_prob.py new file mode 100644 index 000000000000..2cf874bc0ca1 --- /dev/null +++ b/tests/test_shardformer/test_layer/test_dist_log_prob.py @@ -0,0 +1,67 @@ +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer import dist_log_prob_1d +from colossalai.testing import rerun_if_address_is_in_use, spawn + +CONFIG = 
dict( + parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode="1d")), +) + + +def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + """ + Compute the log probabilities from logits for the given labels. + + Args: + logits (torch.Tensor): The input logits. + labels (torch.Tensor): The target labels. + + Returns: + torch.Tensor: The log probabilities corresponding to the labels. + """ + log_probs = torch.log_softmax(logits, dim=-1) + per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)) + return per_label_logps.squeeze(-1) + + +def check_dist_log_prob(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl") + + # prepare data + pred = torch.randn(2, 4, 8, requires_grad=True).cuda() + labels = torch.randint(8, (2, 4)).cuda() + + loss = log_probs_from_logits(pred, labels) + + pred.retain_grad() + loss.mean().backward() + + dist_pred = pred.clone().chunk(world_size, -1)[rank].detach() + dist_pred.requires_grad = True + dist_loss = dist_log_prob_1d(dist_pred, labels) + + dist_pred.retain_grad() + dist_loss.squeeze(-1).mean().backward() + + assert torch.allclose( + loss, dist_loss.squeeze(-1), atol=1e-5 + ), f"dist cross entropy loss is not equal to orgin loss\n{loss}\n{dist_loss.squeeze(-1)}" + + pred_grad_partial = pred.grad.clone().chunk(world_size, -1)[rank].detach() + assert torch.allclose( + pred_grad_partial, dist_pred.grad + ), f"dist grad is not equal to orgin grad\n{pred.grad}\n{dist_pred.grad}" + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_dist_log_prob(): + spawn(check_dist_log_prob, 2) + + +if __name__ == "__main__": + test_dist_log_prob() From a810b209f0a838f75223a50aaae988f113d5388b Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 17 Mar 2025 15:34:12 +0800 Subject: [PATCH 05/19] [fix] fix loss func --- .../coati/distributed/grpo_consumer.py | 37 ++++++++++++++----- .../ColossalChat/coati/distributed/utils.py | 5 ++- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index bee66c770655..26a5dda506d8 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -8,6 +8,7 @@ from coati.distributed.loss import PolicyLoss from coati.distributed.reward.reward_fn import math_reward_fn from coati.distributed.reward.verifiable_reward import VerifiableReward +from coati.distributed.utils import calc_action_log_probs from coati.trainer.utils import all_reduce_mean from transformers import AutoModelForCausalLM, AutoTokenizer @@ -115,23 +116,39 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer) with ctx: - policy_model_log_probs = self.policy_model( + # policy_model_log_probs = self.policy_model( + # input_ids=data["input_ids"], + # attention_mask=data["attention_mask"], + # return_dist_log_prob=True, + # )["logits"] + # policy_model_log_probs = policy_model_log_probs.reshape(8, -1) + # action_log_probs = policy_model_log_probs[:, -num_action:] + # with torch.no_grad(): + # reference_model_log_probs = self.reference_model( + # input_ids=data["input_ids"], + # attention_mask=data["attention_mask"], + # return_dist_log_prob=True, + # )["logits"] + # reference_model_log_probs = 
reference_model_log_probs.reshape(8, -1) + # reference_action_log_probs = reference_model_log_probs[:, -num_action:] + # # reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action) + # # reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action, self.booster.shard_config, reference_model_logits.shape[-1]) + policy_model_logits = self.policy_model( input_ids=data["input_ids"], attention_mask=data["attention_mask"], - return_dist_log_prob=True, )["logits"] - policy_model_log_probs = policy_model_log_probs.reshape(8, -1) - action_log_probs = policy_model_log_probs[:, -num_action:] + action_log_probs = calc_action_log_probs( + policy_model_logits, data["input_ids"], num_action, self.plugin.shard_config + ) + with torch.no_grad(): - reference_model_log_probs = self.reference_model( + reference_model_logits = self.reference_model( input_ids=data["input_ids"], attention_mask=data["attention_mask"], - return_dist_log_prob=True, )["logits"] - reference_model_log_probs = reference_model_log_probs.reshape(8, -1) - reference_action_log_probs = reference_model_log_probs[:, -num_action:] - # reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action) - # reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action, self.booster.shard_config, reference_model_logits.shape[-1]) + reference_action_log_probs = calc_action_log_probs( + reference_model_logits, data["input_ids"], num_action, self.plugin.shard_config + ) per_token_kl = ( torch.exp(reference_action_log_probs - action_log_probs) diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py index b9f6e64643bd..08912458e966 100644 --- a/applications/ColossalChat/coati/distributed/utils.py +++ b/applications/ColossalChat/coati/distributed/utils.py @@ -88,7 +88,7 @@ def calc_action_log_probs( sequences: torch.LongTensor, num_actions: int, shard_config, - vocab_size: int, + vocab_size: int = None, ) -> torch.Tensor: """Calculate action log probs. 
@@ -105,7 +105,8 @@ def calc_action_log_probs( """ print(f"sequences {sequences.shape} logits {logits.shape}") log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype) - + print(f"log_probs {log_probs.shape}") + log_probs = log_probs.squeeze(-1) # log_probs = dist_log_prob(sequences[:, 1:], logits[:, :-1, :], shard_config, vocab_size, logits.dtype) # # labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] # logits: torch.Tensor, # [B, S, Vocab_size] From c247bd8f202b43de9eedf9c3eff9dc7109485c0d Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 17 Mar 2025 16:29:22 +0800 Subject: [PATCH 06/19] [fix] fix log prob plugin --- .../coati/distributed/grpo_consumer.py | 17 ------ .../ColossalChat/coati/distributed/utils.py | 12 +--- colossalai/shardformer/layer/loss.py | 57 ++----------------- colossalai/shardformer/modeling/qwen2.py | 7 +-- 4 files changed, 10 insertions(+), 83 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 26a5dda506d8..b1edb89bb0e5 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -116,23 +116,6 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]: ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer) with ctx: - # policy_model_log_probs = self.policy_model( - # input_ids=data["input_ids"], - # attention_mask=data["attention_mask"], - # return_dist_log_prob=True, - # )["logits"] - # policy_model_log_probs = policy_model_log_probs.reshape(8, -1) - # action_log_probs = policy_model_log_probs[:, -num_action:] - # with torch.no_grad(): - # reference_model_log_probs = self.reference_model( - # input_ids=data["input_ids"], - # attention_mask=data["attention_mask"], - # return_dist_log_prob=True, - # )["logits"] - # reference_model_log_probs = reference_model_log_probs.reshape(8, -1) - # reference_action_log_probs = reference_model_log_probs[:, -num_action:] - # # reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action) - # # reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action, self.booster.shard_config, reference_model_logits.shape[-1]) policy_model_logits = self.policy_model( input_ids=data["input_ids"], attention_mask=data["attention_mask"], diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py index 08912458e966..524a7a336d57 100644 --- a/applications/ColossalChat/coati/distributed/utils.py +++ b/applications/ColossalChat/coati/distributed/utils.py @@ -103,18 +103,10 @@ def calc_action_log_probs( Returns: torch.Tensor: Action log probs. 
""" - print(f"sequences {sequences.shape} logits {logits.shape}") + # labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] + # logits: torch.Tensor, # [B, S, Vocab_size] log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype) - print(f"log_probs {log_probs.shape}") log_probs = log_probs.squeeze(-1) - # log_probs = dist_log_prob(sequences[:, 1:], logits[:, :-1, :], shard_config, vocab_size, logits.dtype) - # # labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] - # logits: torch.Tensor, # [B, S, Vocab_size] - # shard_config: ShardConfig, - # vocab_size: int, - # dtype: torch.dtype, - # seq_dim: int = 1, - # log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) return log_probs[:, -num_actions:] diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index e19299608114..bf5ad2567491 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -204,6 +204,7 @@ def forward( # 2.gather via labels ###### log_probs = log_probs.gather(dim=-1, index=masked_target.unsqueeze(-1)) + print(f"log_probs in ops {log_probs.shape}") # set masked val to zero, then all reduce log_probs[mask.unsqueeze(-1)] = 0 # allreduce log_probs with ops SUM @@ -360,62 +361,22 @@ def dist_log_prob( sp_mode = shard_config.sequence_parallelism_mode parallel_output = shard_config.parallel_output is_tp = shard_config.enable_tensor_parallelism - is_packed = labels.dim() == 2 - if is_packed: - bs, seq_len = labels.shape - else: - # padded sequence - seq_len = labels.shape[-1] - logits = logits.reshape(-1, *logits.shape[2:]) - seq_dim = 0 # Shift labels to predict the next token, and remove the tail logit predicting sp_size > 1 and (not is_share_sp_tp(sp_mode)) - split_labels_here = seq_len // sp_size == logits.size(seq_dim) # ring attn splits labels before forward - # if sp_mode == "ring_attn": - # # For Zigzag Ring Attention, labels should've been split and - # # shifted by RingAttention.prepare_varlen_batch() - # if sp_rank == 0: - # logits = logits[..., :-1, :] - # logits = torch.cat([logits, torch.full_like(logits[:, :1, :], _IGNORE_IDX)], dim=seq_dim) - # elif is_sp: - # # Shift only once: either before splitting or in the last rank without splitting - # if split_labels_here or (sp_rank == sp_size - 1): - # labels = labels[..., 1:] - # if split_labels_here: - # labels = labels.split(seq_len // sp_size, dim=-1)[sp_rank] - - # if sp_rank == sp_size - 1: - # logits = logits[..., :-1, :] - # # Pad logits and labels to the same shape across all ranks for TP all_reduce - # if is_tp and parallel_output: - # # If is packed sequence (label dim is 1), then each seq already has the end label token padded. - # # torch.cat is faster than F.pad... 
- # pad_shape = (logits.shape[0], 1, *logits.shape[2:]) if is_packed else (1, *logits.shape[1:]) - # padding = torch.full(pad_shape, _IGNORE_IDX, dtype=logits.dtype, device=logits.device) - # logits = torch.cat([logits, padding], dim=seq_dim) - # pad_shape = (labels.shape[0], 1) if is_packed else (1,) - # padding = torch.full(pad_shape, _IGNORE_IDX, dtype=labels.dtype, device=labels.device) - # labels = torch.cat([labels, padding], dim=seq_dim) - # else: # TODO:support sp labels = labels[..., 1:] logits = logits[..., :-1, :] labels = labels.contiguous() logits = logits.contiguous() - # num_nonzero = (labels != _IGNORE_IDX).sum() assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}" # Flatten the tokens loss_fct = LogSoftmax() - labels = labels.view(-1) if is_tp and parallel_output: - # Cross entropy with all-reduce for TP - new_vocab_size = logits.shape[-1] - logits = logits.view(-1, new_vocab_size) - loss = dist_log_prob_1d( + log_prob = dist_log_prob_1d( logits, labels, process_group=shard_config.tensor_parallel_process_group, @@ -425,14 +386,6 @@ def dist_log_prob( else: # NOTE if use TP and not parallel_output, the output is gathered in VocabParallelLMHead1D logits = logits.view(-1, logits.size(-1)) - loss = loss_fct(logits) - - # # Reduce loss instead of gathering logits over seq dim for savings - # if split_labels_here or sp_mode == "ring_attn": - # # Get the global non-zero count - # loss = torch.stack((loss, num_nonzero)) - # # Rescale to offset the grad / (DP * SP) in HybridParallelPlugin - # loss = reduce_forward(loss, sp_group, grad_scale=sp_size) - # loss, num_nonzero = loss[0], loss[1].detach() - # loss = (loss / num_nonzero).squeeze() - return loss + log_prob = loss_fct(logits) + + return log_prob diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index d4d89c7b2fae..4d22e5be2c09 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -35,7 +35,7 @@ from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward from colossalai.shardformer.shard import ShardConfig -from ..layer import ColoAttention, dist_cross_entropy, dist_log_prob +from ..layer import ColoAttention, dist_cross_entropy from ..layer._operation import gather_sp_output from ..layer.utils import is_share_sp_tp @@ -779,7 +779,6 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - return_dist_log_prob: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -833,8 +832,8 @@ def forward( loss = None if labels is not None: loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype) - if return_dist_log_prob: - logits = dist_log_prob(input_ids, logits, shard_config, self.lm_head.out_features, logits.dtype) + # if return_dist_log_prob: + # logits = dist_log_prob(input_ids, logits, shard_config, self.lm_head.out_features, logits.dtype) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output From b78ab3a6ffe1eda0082ea9a235a98f79b1b3d568 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 17 Mar 2025 16:31:23 +0800 Subject: [PATCH 07/19] [fix] fix qwen modeling param --- colossalai/shardformer/modeling/qwen2.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/colossalai/shardformer/modeling/qwen2.py 
b/colossalai/shardformer/modeling/qwen2.py index 4d22e5be2c09..71e3557fe214 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -832,8 +832,6 @@ def forward( loss = None if labels is not None: loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype) - # if return_dist_log_prob: - # logits = dist_log_prob(input_ids, logits, shard_config, self.lm_head.out_features, logits.dtype) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output From dddd062cfb903856885fbb046422e950c94f64b3 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 17 Mar 2025 16:32:57 +0800 Subject: [PATCH 08/19] [fix] rm comments --- .../ColossalChat/coati/distributed/utils.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py index 524a7a336d57..919e4434faa6 100644 --- a/applications/ColossalChat/coati/distributed/utils.py +++ b/applications/ColossalChat/coati/distributed/utils.py @@ -68,21 +68,6 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T return per_label_logps.squeeze(-1) -# def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor: -# """Calculate action log probs. - -# Args: -# output (torch.Tensor): Output tensor of Actor.forward.logits. -# sequences (torch.LongTensor): Input sequences. -# num_actions (int): Number of actions. - -# Returns: -# torch.Tensor: Action log probs. -# """ -# log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) -# return log_probs[:, -num_actions:] - - def calc_action_log_probs( logits: torch.Tensor, sequences: torch.LongTensor, From 74de49db802ce07d17f8d30421cfb7961ad93760 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 17 Mar 2025 18:09:53 +0800 Subject: [PATCH 09/19] [fix] rm hard-code;fix non-dist version --- applications/ColossalChat/coati/distributed/consumer.py | 2 +- applications/ColossalChat/coati/distributed/utils.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 6aa63086f87c..380a2ee1b78a 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -66,7 +66,7 @@ def setup(self) -> None: launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0) plugin_config = dict( - tp_size=2, + tp_size=1, pp_size=1, precision="bf16", zero_stage=1, diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py index 919e4434faa6..c3e68763d84c 100644 --- a/applications/ColossalChat/coati/distributed/utils.py +++ b/applications/ColossalChat/coati/distributed/utils.py @@ -90,8 +90,11 @@ def calc_action_log_probs( """ # labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] # logits: torch.Tensor, # [B, S, Vocab_size] - log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype) - log_probs = log_probs.squeeze(-1) + if shard_config.tensor_parallel_size > 1: + log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype) + log_probs = log_probs.squeeze(-1) + else: + log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) return 
log_probs[:, -num_actions:] From 188d69dccf185548af7e1dad53c727ffab9198e6 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 17 Mar 2025 18:36:08 +0800 Subject: [PATCH 10/19] [fix] fix test file param name and benchmark tp gather output=True/False --- .../ColossalChat/coati/distributed/consumer.py | 2 -- applications/ColossalChat/coati/distributed/utils.py | 2 +- colossalai/shardformer/layer/loss.py | 2 -- colossalai/shardformer/policies/qwen2.py | 3 +-- .../test_layer/test_dist_log_prob.py | 12 ++++++------ 5 files changed, 8 insertions(+), 13 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 380a2ee1b78a..1e85cccb3c5b 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -73,8 +73,6 @@ def setup(self) -> None: ) if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config: plugin_config["microbatch_size"] = self.microbatch_size - if self.plugin_config.get("tp_size", 1) > 1: - plugin_config["parallel_output"] = False plugin_config.update(self.plugin_config) self.plugin = HybridParallelPlugin(**plugin_config) self.booster = Booster(plugin=self.plugin) diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py index c3e68763d84c..d1187a4cc24a 100644 --- a/applications/ColossalChat/coati/distributed/utils.py +++ b/applications/ColossalChat/coati/distributed/utils.py @@ -90,7 +90,7 @@ def calc_action_log_probs( """ # labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] # logits: torch.Tensor, # [B, S, Vocab_size] - if shard_config.tensor_parallel_size > 1: + if shard_config.tensor_parallel_size > 1 and shard_config.parallel_output: log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype) log_probs = log_probs.squeeze(-1) else: diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index bf5ad2567491..b1f745f8746d 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -384,8 +384,6 @@ def dist_log_prob( dtype=dtype, ) else: - # NOTE if use TP and not parallel_output, the output is gathered in VocabParallelLMHead1D - logits = logits.view(-1, logits.size(-1)) log_prob = loss_fct(logits) return log_prob diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py index fd14029a3a36..0adcdfdbd553 100644 --- a/colossalai/shardformer/policies/qwen2.py +++ b/colossalai/shardformer/policies/qwen2.py @@ -11,7 +11,6 @@ Linear1D_Row, LinearWithGradAccum, PaddingEmbedding, - PaddingLMHead, RMSNorm, VocabParallelEmbedding1D, VocabParallelLMHead1D, @@ -450,7 +449,7 @@ def module_policy(self): sub_module_replacement=[ SubModuleReplacementDescription( suffix="lm_head", - target_module=PaddingLMHead, + target_module=LinearWithGradAccum, kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv), ), SubModuleReplacementDescription( diff --git a/tests/test_shardformer/test_layer/test_dist_log_prob.py b/tests/test_shardformer/test_layer/test_dist_log_prob.py index 2cf874bc0ca1..f863ee555c22 100644 --- a/tests/test_shardformer/test_layer/test_dist_log_prob.py +++ b/tests/test_shardformer/test_layer/test_dist_log_prob.py @@ -35,21 +35,21 @@ def check_dist_log_prob(rank, world_size, port): pred = torch.randn(2, 4, 8, requires_grad=True).cuda() labels = torch.randint(8, 
(2, 4)).cuda() - loss = log_probs_from_logits(pred, labels) + logprob = log_probs_from_logits(pred, labels) pred.retain_grad() - loss.mean().backward() + logprob.mean().backward() dist_pred = pred.clone().chunk(world_size, -1)[rank].detach() dist_pred.requires_grad = True - dist_loss = dist_log_prob_1d(dist_pred, labels) + dist_logprob = dist_log_prob_1d(dist_pred, labels) dist_pred.retain_grad() - dist_loss.squeeze(-1).mean().backward() + dist_logprob.squeeze(-1).mean().backward() assert torch.allclose( - loss, dist_loss.squeeze(-1), atol=1e-5 - ), f"dist cross entropy loss is not equal to orgin loss\n{loss}\n{dist_loss.squeeze(-1)}" + logprob, dist_logprob.squeeze(-1), atol=1e-5 + ), f"dist cross entropy logprob is not equal to orgin logprob\n{logprob}\n{dist_logprob.squeeze(-1)}" pred_grad_partial = pred.grad.clone().chunk(world_size, -1)[rank].detach() assert torch.allclose( From 01bcacac02012b3fd8e56ce00b71a89d7a1158cd Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 18 Mar 2025 09:34:07 +0800 Subject: [PATCH 11/19] [fix] rm non-dist version in dist log prob --- colossalai/shardformer/layer/loss.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index b1f745f8746d..3da9b6ab6b47 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -2,7 +2,7 @@ import torch.distributed as dist from torch.autograd import Function from torch.distributed import ProcessGroup -from torch.nn import CrossEntropyLoss, LogSoftmax +from torch.nn import CrossEntropyLoss from colossalai.shardformer.layer._operation import reduce_forward from colossalai.shardformer.shard import ShardConfig @@ -359,8 +359,8 @@ def dist_log_prob( dist.get_rank(sp_group) sp_size = shard_config.sequence_parallel_size sp_mode = shard_config.sequence_parallelism_mode - parallel_output = shard_config.parallel_output - is_tp = shard_config.enable_tensor_parallelism + shard_config.parallel_output + shard_config.enable_tensor_parallelism # Shift labels to predict the next token, and remove the tail logit predicting sp_size > 1 and (not is_share_sp_tp(sp_mode)) @@ -372,18 +372,12 @@ def dist_log_prob( logits = logits.contiguous() assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}" - # Flatten the tokens - loss_fct = LogSoftmax() - - if is_tp and parallel_output: - log_prob = dist_log_prob_1d( - logits, - labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=vocab_size, - dtype=dtype, - ) - else: - log_prob = loss_fct(logits) + log_prob = dist_log_prob_1d( + logits, + labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=vocab_size, + dtype=dtype, + ) return log_prob From 027759279b2f553a9988a2c493dad8664708a2e2 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 18 Mar 2025 09:36:17 +0800 Subject: [PATCH 12/19] [fix] fix comments --- colossalai/shardformer/layer/loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index 3da9b6ab6b47..62d7d2300580 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -352,7 +352,7 @@ def dist_log_prob( seq_dim: int = 1, ) -> torch.Tensor: """ - Helper to compute cross entropy loss for most shardformer models supporting PP, TP and SP. 
+ Helper to compute log prob for most shardformer models supporting PP, TP. Will Support SP soon in feature """ # Split labels if not gather output sp_group = shard_config.sequence_parallel_process_group From 3a8a387e71354b156f5a4d50bd374dfe1ec5127f Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 18 Mar 2025 11:34:57 +0800 Subject: [PATCH 13/19] [fix] fix dis log prob plugin --- .../coati/distributed/consumer.py | 2 ++ .../ColossalChat/coati/distributed/utils.py | 7 ++--- colossalai/shardformer/layer/loss.py | 28 +++++++++++-------- 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 1e85cccb3c5b..41bed8047061 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -73,6 +73,8 @@ def setup(self) -> None: ) if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config: plugin_config["microbatch_size"] = self.microbatch_size + if self.plugin_config.get("tp_size", 1) > 1: + plugin_config["parallel_output"] = True plugin_config.update(self.plugin_config) self.plugin = HybridParallelPlugin(**plugin_config) self.booster = Booster(plugin=self.plugin) diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py index d1187a4cc24a..919e4434faa6 100644 --- a/applications/ColossalChat/coati/distributed/utils.py +++ b/applications/ColossalChat/coati/distributed/utils.py @@ -90,11 +90,8 @@ def calc_action_log_probs( """ # labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] # logits: torch.Tensor, # [B, S, Vocab_size] - if shard_config.tensor_parallel_size > 1 and shard_config.parallel_output: - log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype) - log_probs = log_probs.squeeze(-1) - else: - log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) + log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype) + log_probs = log_probs.squeeze(-1) return log_probs[:, -num_actions:] diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index 62d7d2300580..98e8086dae22 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -2,7 +2,7 @@ import torch.distributed as dist from torch.autograd import Function from torch.distributed import ProcessGroup -from torch.nn import CrossEntropyLoss +from torch.nn import CrossEntropyLoss, LogSoftmax from colossalai.shardformer.layer._operation import reduce_forward from colossalai.shardformer.shard import ShardConfig @@ -352,15 +352,15 @@ def dist_log_prob( seq_dim: int = 1, ) -> torch.Tensor: """ - Helper to compute log prob for most shardformer models supporting PP, TP. Will Support SP soon in feature + Helper to compute cross entropy loss for most shardformer models supporting PP, TP and SP. 
""" # Split labels if not gather output sp_group = shard_config.sequence_parallel_process_group dist.get_rank(sp_group) sp_size = shard_config.sequence_parallel_size sp_mode = shard_config.sequence_parallelism_mode - shard_config.parallel_output - shard_config.enable_tensor_parallelism + parallel_output = shard_config.parallel_output + is_tp = shard_config.enable_tensor_parallelism # Shift labels to predict the next token, and remove the tail logit predicting sp_size > 1 and (not is_share_sp_tp(sp_mode)) @@ -372,12 +372,18 @@ def dist_log_prob( logits = logits.contiguous() assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}" - log_prob = dist_log_prob_1d( - logits, - labels, - process_group=shard_config.tensor_parallel_process_group, - vocab_size=vocab_size, - dtype=dtype, - ) + # Flatten the tokens + loss_fct = LogSoftmax() + if is_tp and parallel_output: + log_prob = dist_log_prob_1d( + logits, + labels, + process_group=shard_config.tensor_parallel_process_group, + vocab_size=vocab_size, + dtype=dtype, + ) + else: + log_prob = loss_fct(logits) + log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1)) return log_prob From d29f39d52e24ab9d9a6de1c55a1c595947652dae Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 18 Mar 2025 11:41:10 +0800 Subject: [PATCH 14/19] [fix] fix test case --- .../test_layer/test_dist_log_prob.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/tests/test_shardformer/test_layer/test_dist_log_prob.py b/tests/test_shardformer/test_layer/test_dist_log_prob.py index f863ee555c22..05a6a5d4766f 100644 --- a/tests/test_shardformer/test_layer/test_dist_log_prob.py +++ b/tests/test_shardformer/test_layer/test_dist_log_prob.py @@ -1,5 +1,6 @@ import pytest import torch +from coati.distributed.utils import log_probs_from_logits import colossalai from colossalai.logging import disable_existing_loggers @@ -11,22 +12,6 @@ ) -def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: - """ - Compute the log probabilities from logits for the given labels. - - Args: - logits (torch.Tensor): The input logits. - labels (torch.Tensor): The target labels. - - Returns: - torch.Tensor: The log probabilities corresponding to the labels. 
- """ - log_probs = torch.log_softmax(logits, dim=-1) - per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)) - return per_label_logps.squeeze(-1) - - def check_dist_log_prob(rank, world_size, port): disable_existing_loggers() colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl") From dcf3f9b6e996d402bb98d753eb4eddfd3e3b9fcd Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 13 Mar 2025 13:24:40 +0800 Subject: [PATCH 15/19] [fix] fix qwen VocabParallelLMHead1D and gather output --- colossalai/shardformer/policies/qwen2.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py index 8e150fef1f3d..fd14029a3a36 100644 --- a/colossalai/shardformer/policies/qwen2.py +++ b/colossalai/shardformer/policies/qwen2.py @@ -11,6 +11,7 @@ Linear1D_Row, LinearWithGradAccum, PaddingEmbedding, + PaddingLMHead, RMSNorm, VocabParallelEmbedding1D, VocabParallelLMHead1D, @@ -430,8 +431,12 @@ def module_policy(self): sub_module_replacement=[ SubModuleReplacementDescription( suffix="lm_head", - target_module=Linear1D_Col, - kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv), + target_module=VocabParallelLMHead1D, + kwargs=dict( + gather_output=not self.shard_config.parallel_output, + fp8_communication=self.shard_config.fp8_communication, + use_zbv=use_zbv, + ), ) ], method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}, @@ -445,7 +450,7 @@ def module_policy(self): sub_module_replacement=[ SubModuleReplacementDescription( suffix="lm_head", - target_module=LinearWithGradAccum, + target_module=PaddingLMHead, kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv), ), SubModuleReplacementDescription( From 0ebeebc102deeeae83009a33a8323966ed9f86aa Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 18 Mar 2025 16:16:00 +0800 Subject: [PATCH 16/19] [fix] fix DistLogProb comments --- .../coati/distributed/consumer.py | 4 +- colossalai/shardformer/layer/loss.py | 62 ++++++++++--------- 2 files changed, 34 insertions(+), 32 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 41bed8047061..7dac48c77636 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -66,15 +66,13 @@ def setup(self) -> None: launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0) plugin_config = dict( - tp_size=1, + tp_size=2, pp_size=1, precision="bf16", zero_stage=1, ) if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config: plugin_config["microbatch_size"] = self.microbatch_size - if self.plugin_config.get("tp_size", 1) > 1: - plugin_config["parallel_output"] = True plugin_config.update(self.plugin_config) self.plugin = HybridParallelPlugin(**plugin_config) self.booster = Booster(plugin=self.plugin) diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index 98e8086dae22..1d6b68d01a68 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -161,23 +161,26 @@ def forward( vocab_size: int, dtype=torch.float32, ): - ###### - # 1.log softmax - ###### - # local max for softmax + + ################## + # Step1:Find the global maximum value of 
logits + ################## logits_max = torch.max(vocab_logits, dim=-1)[0] handle = dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group, async_op=True) - # get rank and worldsize + + ################## + # Step2:Find the local mask. local mask will be use to select log_probs value in Step 4. + # For accleration, we overlap Step 2 and Step 3 + ################## rank = dist.get_rank(group=process_group) world_size = dist.get_world_size(group=process_group) - # cal vocal size if vocab_size is None: partition_vocab_size = vocab_logits.size()[-1] global_vocab_size = partition_vocab_size * world_size else: global_vocab_size = vocab_size partition_vocab_size = global_vocab_size // world_size - # down and up threshold + # down and up threshold for local logits delta = (global_vocab_size + world_size - 1) // world_size down_threshold = rank * delta up_threshold = down_threshold + delta @@ -188,27 +191,24 @@ def forward( masked_target = target.clone() - down_threshold masked_target[mask] = 0 masked_target_1d = masked_target.view(-1).contiguous() - # wait for allreduce handle.wait() + + ################## + # Step3:Calculate global summation exp logits + ################## vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1) - # cal exp_logits exp_logits = torch.exp(vocab_logits) - # cal local sum_exp_logits - sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32) - # all_reduce get global sum_exp_logits + sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32) # local summation exp logits dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group) - # cal log_softmax - log_probs = vocab_logits - torch.log(sum_exp_logits.unsqueeze(dim=-1)) - ###### - # 2.gather via labels - ###### + ################## + # Step4:Calculate local prob. We first cal log_softmax, them + ################## + log_probs = vocab_logits - torch.log(sum_exp_logits.unsqueeze(dim=-1)) # cal log_softmax log_probs = log_probs.gather(dim=-1, index=masked_target.unsqueeze(-1)) - print(f"log_probs in ops {log_probs.shape}") - # set masked val to zero, then all reduce - log_probs[mask.unsqueeze(-1)] = 0 - # allreduce log_probs with ops SUM + log_probs[mask.unsqueeze(-1)] = 0 # # set masked val to zero dist.all_reduce(log_probs, op=dist.ReduceOp.SUM, group=process_group) + ctx.save_for_backward(exp_logits, mask, masked_target_1d, sum_exp_logits) ctx.dtype = dtype return log_probs @@ -216,11 +216,22 @@ def forward( @staticmethod def backward(ctx, grad_output): exp_logits, mask, masked_target_1d, sum_exp_logits = ctx.saved_tensors + ################## + # Step1:Find the global sofmax value + ################## softmax_logits = exp_logits / sum_exp_logits.unsqueeze(dim=-1) + + ################## + # Step2:Update softmax value based on local target index + ################## partion_vocab_size = softmax_logits.shape[-1] softmax_logits_2d = softmax_logits.view(-1, partion_vocab_size) update = 1.0 - mask.view(-1).float().to(ctx.dtype) softmax_logits_2d[torch.arange(0, softmax_logits_2d.shape[0]), masked_target_1d] -= update + + ################## + # Step3:Calculate grad_output, which is the gradient of the loss function with respect to the output of logsoftmax + ################## grad_logits = -softmax_logits.mul_(grad_output) return grad_logits, None, None, None, None, None, None @@ -352,19 +363,12 @@ def dist_log_prob( seq_dim: int = 1, ) -> torch.Tensor: """ - Helper to compute cross entropy loss for most shardformer models supporting PP, TP and SP. 
+ Helper to compute log prob for most shardformer models supporting PP, TP. """ # Split labels if not gather output - sp_group = shard_config.sequence_parallel_process_group - dist.get_rank(sp_group) - sp_size = shard_config.sequence_parallel_size - sp_mode = shard_config.sequence_parallelism_mode parallel_output = shard_config.parallel_output is_tp = shard_config.enable_tensor_parallelism - # Shift labels to predict the next token, and remove the tail logit predicting - sp_size > 1 and (not is_share_sp_tp(sp_mode)) - # TODO:support sp labels = labels[..., 1:] logits = logits[..., :-1, :] From 1a7cc25336a731000069e661af4054f3c1cd7488 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 18 Mar 2025 16:17:50 +0800 Subject: [PATCH 17/19] [fix] restore tp size --- applications/ColossalChat/coati/distributed/consumer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 7dac48c77636..1e85cccb3c5b 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -66,7 +66,7 @@ def setup(self) -> None: launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0) plugin_config = dict( - tp_size=2, + tp_size=1, pp_size=1, precision="bf16", zero_stage=1, From 7e2f0585a2a5324df38b4548efe6bbb23a61232c Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 18 Mar 2025 16:24:54 +0800 Subject: [PATCH 18/19] [fix] fix comments --- colossalai/shardformer/layer/loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index 1d6b68d01a68..901c35c27e93 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -146,7 +146,7 @@ def backward(ctx, grad_output): class DistLogProb(Function): r""" - Overwrite the forward and backward function to calculate the cross entropy loss before gather + Overwrite the forward and backward function to calculate the log prob before gather Args: Function (:class:`torch.autograd.Function`): default From f381cea43cceaf2b725dee2c33b545261e574ad5 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 18 Mar 2025 17:23:30 +0800 Subject: [PATCH 19/19] [fix] fix comment; fix LogSoftmax usage --- colossalai/shardformer/layer/loss.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index 901c35c27e93..51419a38a0ed 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -2,7 +2,8 @@ import torch.distributed as dist from torch.autograd import Function from torch.distributed import ProcessGroup -from torch.nn import CrossEntropyLoss, LogSoftmax +from torch.nn import CrossEntropyLoss +from torch.nn.functional import log_softmax from colossalai.shardformer.layer._operation import reduce_forward from colossalai.shardformer.shard import ShardConfig @@ -202,11 +203,11 @@ def forward( dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group) ################## - # Step4:Calculate local prob. We first cal log_softmax, them + # Step4:Calculate local prob. 
We first compute log_softmax, then select log probs via the local mask
         ##################
         log_probs = vocab_logits - torch.log(sum_exp_logits.unsqueeze(dim=-1))  # cal log_softmax
         log_probs = log_probs.gather(dim=-1, index=masked_target.unsqueeze(-1))
-        log_probs[mask.unsqueeze(-1)] = 0  # # set masked val to zero
+        log_probs[mask.unsqueeze(-1)] = 0  # set masked val to zero
         dist.all_reduce(log_probs, op=dist.ReduceOp.SUM, group=process_group)

         ctx.save_for_backward(exp_logits, mask, masked_target_1d, sum_exp_logits)
@@ -377,7 +378,6 @@ def dist_log_prob(
     assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}"

     # Flatten the tokens
-    loss_fct = LogSoftmax()
     if is_tp and parallel_output:
         log_prob = dist_log_prob_1d(
             logits,
@@ -387,7 +387,7 @@ def dist_log_prob(
             dtype=dtype,
         )
     else:
-        log_prob = loss_fct(logits)
+        log_prob = log_softmax(logits, dim=-1)
         log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))

     return log_prob
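
For reference, the reduction scheme that DistLogProb.forward implements can be checked on a single process by treating tensor.chunk pieces as the TP ranks' vocab shards: a MAX reduction for numerical stability, a SUM reduction for the softmax denominator, and a masked gather plus SUM reduction that picks each label's log prob from whichever shard owns it. The sketch below is illustrative only (the helper name sharded_log_prob is not part of the patches) and assumes the vocab splits evenly across shards, as tensor.chunk produces here and as the patched code assumes via delta.

import torch

def sharded_log_prob(shards, target, vocab_size):
    # Each entry of `shards` plays the role of one TP rank's vocab slice.
    world_size = len(shards)
    delta = (vocab_size + world_size - 1) // world_size
    # Step 1: global max over the vocab (all_reduce MAX in DistLogProb)
    logits_max = torch.stack([s.max(dim=-1)[0] for s in shards]).max(dim=0)[0]
    # Step 3: global sum of exp(logits - max) (all_reduce SUM in DistLogProb)
    sum_exp = sum(torch.exp(s - logits_max.unsqueeze(-1)).sum(dim=-1) for s in shards)
    out = torch.zeros(*target.shape, 1)
    for rank, shard in enumerate(shards):
        down, up = rank * delta, min((rank + 1) * delta, vocab_size)
        # Step 2: mask out labels owned by other shards
        mask = (target < down) | (target >= up)
        masked_target = (target - down).masked_fill(mask, 0)
        log_probs = shard - logits_max.unsqueeze(-1) - torch.log(sum_exp.unsqueeze(-1))
        picked = log_probs.gather(dim=-1, index=masked_target.unsqueeze(-1))
        picked[mask.unsqueeze(-1)] = 0  # zero out off-shard picks
        out += picked  # Step 4: all_reduce SUM in DistLogProb
    return out

logits = torch.randn(2, 4, 8)
labels = torch.randint(8, (2, 4))
expected = torch.log_softmax(logits, dim=-1).gather(-1, labels.unsqueeze(-1))
actual = sharded_log_prob(list(logits.chunk(2, dim=-1)), labels, vocab_size=8)
assert torch.allclose(expected, actual, atol=1e-5)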