Merged (22 commits)
2 changes: 0 additions & 2 deletions applications/ColossalChat/coati/distributed/consumer.py
@@ -73,8 +73,6 @@ def setup(self) -> None:
)
if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
plugin_config["microbatch_size"] = self.microbatch_size
if self.plugin_config.get("tp_size", 1) > 1:
plugin_config["parallel_output"] = False
plugin_config.update(self.plugin_config)
self.plugin = HybridParallelPlugin(**plugin_config)
self.booster = Booster(plugin=self.plugin)
@@ -120,14 +120,18 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
input_ids=data["input_ids"],
attention_mask=data["attention_mask"],
)["logits"]
action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action)
action_log_probs = calc_action_log_probs(
policy_model_logits, data["input_ids"], num_action, self.plugin.shard_config
)

with torch.no_grad():
reference_model_logits = self.reference_model(
input_ids=data["input_ids"],
attention_mask=data["attention_mask"],
)["logits"]
reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action)
reference_action_log_probs = calc_action_log_probs(
reference_model_logits, data["input_ids"], num_action, self.plugin.shard_config
)

per_token_kl = (
torch.exp(reference_action_log_probs - action_log_probs)
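The per_token_kl expression above is truncated by the diff view. From the visible lines it appears to be the standard low-variance k3 KL estimator, exp(ref - logp) - (ref - logp) - 1, applied per generated token; a minimal sketch under that assumption:

import torch

def per_token_kl(action_log_probs: torch.Tensor, reference_action_log_probs: torch.Tensor) -> torch.Tensor:
    # k3 estimator of KL(pi || pi_ref), evaluated per token
    log_ratio = reference_action_log_probs - action_log_probs
    return torch.exp(log_ratio) - log_ratio - 1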
20 changes: 17 additions & 3 deletions applications/ColossalChat/coati/distributed/utils.py
@@ -2,6 +2,8 @@

import torch

from colossalai.shardformer.layer.loss import dist_log_prob


def unbind_batch(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
batches = []
@@ -66,18 +68,30 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
return per_label_logps.squeeze(-1)


def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
def calc_action_log_probs(
logits: torch.Tensor,
sequences: torch.LongTensor,
num_actions: int,
shard_config,
vocab_size: int = None,
) -> torch.Tensor:
"""Calculate action log probs.

Args:
output (torch.Tensor): Output tensor of Actor.forward.logits.
logits (torch.Tensor): Output tensor of Actor.forward.logits.
sequences (torch.LongTensor): Input sequences.
num_actions (int): Number of actions.
shard_config (ShardConfig): Shard config used for the (tensor-parallel) log-prob computation.
vocab_size (int, optional): Global vocabulary size; inferred from the logits when None.


Returns:
torch.Tensor: Action log probs.
"""
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
# dist_log_prob expects labels of shape [B, S] (or [B, S, vocab_size]) and logits of shape [B, S, vocab_size]
log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype)
log_probs = log_probs.squeeze(-1)
return log_probs[:, -num_actions:]


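A hedged usage sketch of the new signature (the plugin and tensor names below are illustrative, mirroring the consumer diff above): the function now takes the booster plugin's shard_config so it can route to the tensor-parallel path, and it returns the log probs of the last num_actions (generated) tokens with shape [B, num_actions].

# logits: [B, S, vocab_size] (or the local vocab shard when parallel_output is enabled)
# sequences: [B, S] token ids whose last num_actions positions are the generated tokens
action_log_probs = calc_action_log_probs(
    logits, sequences, num_actions, plugin.shard_config
)  # -> [B, num_actions]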
4 changes: 3 additions & 1 deletion colossalai/shardformer/layer/__init__.py
@@ -3,7 +3,7 @@
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHead, VocabParallelLMHead1D
from .loss import cross_entropy_1d, dist_cross_entropy
from .loss import cross_entropy_1d, dist_cross_entropy, dist_log_prob, dist_log_prob_1d
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
from .qkv_fused_linear import (
@@ -28,6 +28,8 @@
"DropoutForReplicatedInput",
"cross_entropy_1d",
"dist_cross_entropy",
"dist_log_prob_1d",
"dist_log_prob",
"BaseLayerNorm",
"LayerNorm",
"RMSNorm",
150 changes: 149 additions & 1 deletion colossalai/shardformer/layer/loss.py
@@ -3,13 +3,21 @@
from torch.autograd import Function
from torch.distributed import ProcessGroup
from torch.nn import CrossEntropyLoss
from torch.nn.functional import log_softmax

from colossalai.shardformer.layer._operation import reduce_forward
from colossalai.shardformer.shard import ShardConfig

from .utils import is_share_sp_tp

__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"]
__all__ = [
"DistCrossEntropy",
"cross_entropy_1d",
"dist_cross_entropy",
"DistLogProb",
"dist_log_prob_1d",
"dist_log_prob",
]

_IGNORE_IDX = -100

@@ -137,6 +145,98 @@ def backward(ctx, grad_output):
return grad_logits, None, None, None, None, None, None


class DistLogProb(Function):
r"""
Overwrite the forward and backward functions to calculate the log prob on vocab-sharded logits, before any gather

Args:
Function (:class:`torch.autograd.Function`): default
"""

@staticmethod
def forward(
ctx,
vocab_logits: torch.Tensor,
target: torch.Tensor,
process_group: ProcessGroup,
vocab_size: int,
dtype=torch.float32,
):

##################
# Step 1: Find the global maximum value of the logits
##################
logits_max = torch.max(vocab_logits, dim=-1)[0]
handle = dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group, async_op=True)

##################
# Step 2: Build the local mask. The mask is used to select log_prob values in Step 4.
# For acceleration, we overlap Step 2 and Step 3
##################
rank = dist.get_rank(group=process_group)
world_size = dist.get_world_size(group=process_group)
if vocab_size is None:
partition_vocab_size = vocab_logits.size()[-1]
global_vocab_size = partition_vocab_size * world_size
else:
global_vocab_size = vocab_size
partition_vocab_size = global_vocab_size // world_size
# down and up threshold for local logits
delta = (global_vocab_size + world_size - 1) // world_size
down_threshold = rank * delta
up_threshold = down_threshold + delta
if up_threshold > global_vocab_size:
up_threshold = global_vocab_size
# mask
mask = (target < down_threshold) | (target >= up_threshold)
masked_target = target.clone() - down_threshold
masked_target[mask] = 0
masked_target_1d = masked_target.view(-1).contiguous()
handle.wait()

##################
# Step 3: Calculate the global sum of exp(logits)
##################
vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
exp_logits = torch.exp(vocab_logits)
sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32) # local sum of exp(logits)
dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)

##################
# Step 4: Calculate local log probs. We first compute log_softmax, then select log probs via the local mask
##################
log_probs = vocab_logits - torch.log(sum_exp_logits.unsqueeze(dim=-1)) # compute log_softmax
log_probs = log_probs.gather(dim=-1, index=masked_target.unsqueeze(-1))
log_probs[mask.unsqueeze(-1)] = 0 # set masked val to zero
dist.all_reduce(log_probs, op=dist.ReduceOp.SUM, group=process_group)

ctx.save_for_backward(exp_logits, mask, masked_target_1d, sum_exp_logits)
ctx.dtype = dtype
return log_probs

@staticmethod
def backward(ctx, grad_output):
exp_logits, mask, masked_target_1d, sum_exp_logits = ctx.saved_tensors
##################
# Step 1: Compute the global softmax value
##################
softmax_logits = exp_logits / sum_exp_logits.unsqueeze(dim=-1)

##################
# Step 2: Update the softmax value at the local target index
##################
partition_vocab_size = softmax_logits.shape[-1]
softmax_logits_2d = softmax_logits.view(-1, partition_vocab_size)
update = 1.0 - mask.view(-1).float().to(ctx.dtype)
softmax_logits_2d[torch.arange(0, softmax_logits_2d.shape[0]), masked_target_1d] -= update

##################
# Step 3: Calculate grad_logits, the gradient of the loss with respect to the input logits
##################
grad_logits = -softmax_logits.mul_(grad_output)
return grad_logits, None, None, None, None, None, None


def cross_entropy_1d(
vocab_logits: torch.Tensor,
labels: torch.Tensor,
Expand All @@ -149,6 +249,16 @@ def cross_entropy_1d(
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype, mode)


def dist_log_prob_1d(
vocab_logits: torch.Tensor,
labels: torch.Tensor,
process_group: ProcessGroup = None,
vocab_size: int = None,
dtype: torch.dtype = None,
) -> torch.Tensor:
return DistLogProb.apply(vocab_logits, labels, process_group, vocab_size, dtype)


def dist_cross_entropy(
labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
logits: torch.Tensor, # [B, S, Vocab_size]
@@ -243,3 +353,41 @@ def dist_cross_entropy(
loss, num_nonzero = loss[0], loss[1].detach()
loss = (loss / num_nonzero).squeeze()
return loss


def dist_log_prob(
labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
logits: torch.Tensor, # [B, S, Vocab_size]
shard_config: ShardConfig,
vocab_size: int,
dtype: torch.dtype,
seq_dim: int = 1,
) -> torch.Tensor:
"""
Helper to compute log prob for most shardformer models supporting PP, TP.
"""
# Whether the logits are still sharded over the vocab dimension (TP with parallel_output)
parallel_output = shard_config.parallel_output
is_tp = shard_config.enable_tensor_parallelism

# TODO: support sequence parallelism
labels = labels[..., 1:]
logits = logits[..., :-1, :]
labels = labels.contiguous()
logits = logits.contiguous()
assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}"

# Flatten the tokens
if is_tp and parallel_output:
log_prob = dist_log_prob_1d(
logits,
labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=vocab_size,
dtype=dtype,
)
else:
log_prob = log_softmax(logits, dim=-1)
log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))

return log_prob
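For reference, a minimal single-process sketch of the quantity dist_log_prob is meant to reproduce, assuming the logits are fully gathered over the vocab dimension (this mirrors the non-TP branch above):

import torch
from torch.nn.functional import log_softmax

def reference_log_prob(labels: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
    # labels: [B, S], logits: [B, S, vocab_size]; shift so position t predicts token t+1
    labels = labels[..., 1:].contiguous()
    logits = logits[..., :-1, :].contiguous()
    log_probs = log_softmax(logits, dim=-1)
    # select the log prob of each target token -> [B, S-1, 1]
    return log_probs.gather(dim=-1, index=labels.unsqueeze(-1))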
1 change: 0 additions & 1 deletion colossalai/shardformer/modeling/qwen2.py
@@ -832,7 +832,6 @@ def forward(
loss = None
if labels is not None:
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype)

if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
8 changes: 6 additions & 2 deletions colossalai/shardformer/policies/qwen2.py
@@ -430,8 +430,12 @@ def module_policy(self):
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv),
target_module=VocabParallelLMHead1D,
kwargs=dict(
gather_output=not self.shard_config.parallel_output,
fp8_communication=self.shard_config.fp8_communication,
use_zbv=use_zbv,
),
)
],
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
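Taken together with the consumer.py change at the top of the diff, which stops forcing parallel_output=False under tensor parallelism, the assumed relationship between the two flags is sketched below; this is an illustrative note, not code from the PR.

# Assumed relationship (illustrative):
#   shard_config.parallel_output == True  -> gather_output == False:
#       each TP rank keeps logits of shape [B, S, vocab_size // tp_size],
#       which dist_log_prob_1d / dist_cross_entropy consume directly.
#   shard_config.parallel_output == False -> gather_output == True:
#       logits are gathered back to the full [B, S, vocab_size] on every rank.
gather_output = not shard_config.parallel_output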
52 changes: 52 additions & 0 deletions tests/test_shardformer/test_layer/test_dist_log_prob.py
@@ -0,0 +1,52 @@
import pytest
import torch
from coati.distributed.utils import log_probs_from_logits

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer import dist_log_prob_1d
from colossalai.testing import rerun_if_address_is_in_use, spawn

CONFIG = dict(
parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode="1d")),
)


def check_dist_log_prob(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl")

# prepare data
pred = torch.randn(2, 4, 8, requires_grad=True).cuda()
labels = torch.randint(8, (2, 4)).cuda()

logprob = log_probs_from_logits(pred, labels)

pred.retain_grad()
logprob.mean().backward()

dist_pred = pred.clone().chunk(world_size, -1)[rank].detach()
dist_pred.requires_grad = True
dist_logprob = dist_log_prob_1d(dist_pred, labels)

dist_pred.retain_grad()
dist_logprob.squeeze(-1).mean().backward()

assert torch.allclose(
logprob, dist_logprob.squeeze(-1), atol=1e-5
), f"dist cross entropy logprob is not equal to orgin logprob\n{logprob}\n{dist_logprob.squeeze(-1)}"

pred_grad_partial = pred.grad.clone().chunk(world_size, -1)[rank].detach()
assert torch.allclose(
pred_grad_partial, dist_pred.grad
), f"dist grad is not equal to orgin grad\n{pred.grad}\n{dist_pred.grad}"


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_dist_log_prob():
spawn(check_dist_log_prob, 2)


if __name__ == "__main__":
test_dist_log_prob()