From 75fa0b623a3b045f7b052022ff79dd4848b29f5b Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com> Date: Mon, 14 Aug 2023 15:25:58 +0800 Subject: [PATCH 01/46] [moe] support moe fwd and bwd with low level zero (#4421) * fix test files * new file * add new * fix zero * update moe tests for forward and backward * remove useless test * remove print * moe * code style * code style * rename * rename * remove useless func * update param check * update utils and config --- colossalai/zero/low_level/low_level_optim.py | 5 +- tests/test_moe/moe_utils.py | 41 +++++++ tests/test_moe/test_grad_handler.py | 3 +- tests/test_moe/test_kernel.py | 3 +- tests/test_moe/test_moe_checkpoint.py | 5 +- tests/test_moe/test_moe_colo_init.py | 55 ---------- tests/test_moe/test_moe_group.py | 3 +- tests/test_moe/test_moe_zero_fwd_bwd.py | 106 +++++++++++++++++++ tests/test_moe/test_moe_zero_init.py | 106 ------------------- tests/test_moe/test_moe_zero_optim.py | 30 +++++- 10 files changed, 183 insertions(+), 174 deletions(-) create mode 100644 tests/test_moe/moe_utils.py delete mode 100644 tests/test_moe/test_moe_colo_init.py create mode 100644 tests/test_moe/test_moe_zero_fwd_bwd.py delete mode 100644 tests/test_moe/test_moe_zero_init.py diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index e6974a6760ce..cc7c291e507d 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -131,7 +131,10 @@ def __init__( # and add buffers to parameter store for future access for group_id, param_group in enumerate(self.optim.param_groups): group_params = list() - for param in param_group["params"]: + for param in param_group['params']: + # skip moe param + if hasattr(param, "moe_info"): + continue if param.requires_grad: group_params.append(param) diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py new file mode 100644 index 000000000000..4b067c1ceea9 --- /dev/null +++ b/tests/test_moe/moe_utils.py @@ -0,0 +1,41 @@ +import torch.nn as nn + +from colossalai.context import MOE_CONTEXT +from colossalai.nn import CheckpointModule +from colossalai.nn.layer import MoeModule + + +class MoeModel(nn.Module): + + def __init__(self, checkpoint: bool = False): + + class TestSubModule(CheckpointModule): + + def __init__(self): + super().__init__(checkpoint) + expert_cls = nn.Linear + expert_args_dict = dict(in_features=16, out_features=16) + self.moe = MoeModule(dim_model=16, + num_experts=8, + use_residual=True, + expert_cls=expert_cls, + **expert_args_dict) + self.proj = nn.Linear(16, 4) + + def _forward(self, x): + x, y = self.moe(x) + x = self.proj(x) + return x, y + + super().__init__() + self.test_embed = nn.Linear(4, 16) + self.test_transform = TestSubModule() + + def forward(self, x): + MOE_CONTEXT.reset_loss() + + x = self.test_embed(x) + x, y = self.test_transform(x) + + MOE_CONTEXT.add_loss(y) + return x diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py index 8742e5f41136..2c42d55fa9e3 100644 --- a/tests/test_moe/test_grad_handler.py +++ b/tests/test_moe/test_grad_handler.py @@ -13,11 +13,10 @@ BATCH_SIZE = 4 DIM = 16 -CONFIG = dict() def run_test(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') expert_module = nn.Linear expert_factor 
= dict(in_features=DIM, out_features=DIM, device=get_current_device()) diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index 7a9c551d679d..bd0af109fde6 100644 --- a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -12,7 +12,6 @@ BATCH_SIZE = 16 NUM_EXPERTS = 4 -CONFIG = dict() def check_equal(tensor_a, tensor_b, atol=1e-06): @@ -23,7 +22,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f # Here we do not need TF32, since it brings absolute error on results torch.backends.cuda.matmul.allow_tf32 = False - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') local_rank = gpc.get_local_rank(ParallelMode.GLOBAL) MOE_CONTEXT.setup(42) # MOE environment initialization diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index b7024f32b1cf..f108dc3cd5b1 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -10,8 +10,7 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device from colossalai.zero import ColoInitContext -from tests.test_moe.test_moe_zero_init import MoeModel -from tests.test_zero.test_legacy.common import CONFIG +from tests.test_moe.moe_utils import MoeModel def exam_moe_checkpoint(): @@ -34,7 +33,7 @@ def exam_moe_checkpoint(): def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') MOE_CONTEXT.setup(seed=42) exam_moe_checkpoint() diff --git a/tests/test_moe/test_moe_colo_init.py b/tests/test_moe/test_moe_colo_init.py deleted file mode 100644 index 488573b733b1..000000000000 --- a/tests/test_moe/test_moe_colo_init.py +++ /dev/null @@ -1,55 +0,0 @@ -import pytest -import torch -import torch.distributed as dist - -import colossalai -from colossalai.context import MOE_CONTEXT -from colossalai.tensor import ColoParameter -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device -from colossalai.zero import ColoInitContext -from tests.test_moe.test_moe_zero_init import MoeModel -from tests.test_zero.test_legacy.common import CONFIG - - -@parameterize("init_device_type", ["cpu", "cuda"]) -def exam_moe_colo_init(init_device_type): - world_size = dist.get_world_size() - - if init_device_type == "cuda": - init_device = get_current_device() - elif init_device_type == "cpu": - init_device = torch.device("cpu") - else: - raise NotImplementedError("Unknown device found.") - - with ColoInitContext(device=init_device): - model = MoeModel(checkpoint=True) - - for name, param in model.named_parameters(): - assert isinstance(param, ColoParameter), "parameter `{}` has an init problem".format(name) - - if hasattr(param, "moe_info"): - param.set_process_group(param.moe_info.pg) - - if hasattr(param, "moe_info"): - assert param.process_group.dp_world_size() == param.moe_info.dp_size - else: - assert param.process_group.dp_world_size() == world_size - - -def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - MOE_CONTEXT.setup(seed=42) - 
exam_moe_colo_init() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [4]) -@rerun_if_address_is_in_use() -def test_moe_colo_init(world_size): - spawn(_run_dist, world_size) - - -if __name__ == "__main__": - test_moe_colo_init(world_size=4) diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py index 300fb6c99b7b..54005f04fa16 100644 --- a/tests/test_moe/test_moe_group.py +++ b/tests/test_moe/test_moe_group.py @@ -11,12 +11,11 @@ D_MODEL = 4 D_FF = 8 -CONFIG = dict() def run_test(rank, world_size, port): world_size = 4 - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') expert_module = nn.Linear expert_factor = dict(in_features=D_MODEL, out_features=D_FF, device=get_current_device()) diff --git a/tests/test_moe/test_moe_zero_fwd_bwd.py b/tests/test_moe/test_moe_zero_fwd_bwd.py new file mode 100644 index 000000000000..83ec884b1515 --- /dev/null +++ b/tests/test_moe/test_moe_zero_fwd_bwd.py @@ -0,0 +1,106 @@ +import pytest +import torch + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel +from colossalai.context import MOE_CONTEXT +from colossalai.engine.gradient_handler import MoeGradientHandler +from colossalai.nn import MoeLoss +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing.random import seed_all +from tests.test_moe.moe_utils import MoeModel + + +def split_ddp_grad(grad, world_size): + with torch.no_grad(): + grad = grad.clone().detach().flatten() + padding_size = (world_size - grad.numel() % world_size) % world_size + if padding_size > 0: + grad = torch.nn.functional.pad(grad, [0, padding_size]) + splited_grad = grad.split(grad.numel() // world_size) + return splited_grad + + +def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False): + model.train() + with torch.cuda.amp.autocast(enabled=enable_autocast): + if criterion: + y = model(data) + loss = criterion(y, label) + else: + loss = model(data, label) + loss = loss.float() + + if isinstance(model, LowLevelZeroModel): + optimizer.backward(loss) + else: + loss.backward() + return y + + +def run_zero_test(local_rank, world_size, stage=1): + criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss) + + zero_model = MoeModel(checkpoint=True) + optimizer = torch.optim.Adam(zero_model.parameters()) + plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") + booster = Booster(plugin=plugin) + zero_model, optimizer, _, _, _ = booster.boost(zero_model, optimizer) + + torch_model = MoeModel(checkpoint=True) + for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): + torch_param.data.copy_(zero_param.data) + torch_model = torch_model.cuda() + grad_handler = MoeGradientHandler(torch_model) + + # assert zero model + assert len(zero_model.module.test_transform.moe.moe_layer.experts.experts) == 8 // MOE_CONTEXT.world_size + for (torch_name, torch_param), (zero_name, zero_param) in zip(torch_model.named_parameters(), + zero_model.module.named_parameters()): + assert zero_name == torch_name + assert torch.allclose(zero_param.data, torch_param.data) + + data = torch.randn(16, 4).cuda() + label = torch.randint(0, 4, (16,)).cuda() + + torch_out = run_fwd_bwd(torch_model, data, label, 
criterion, None) + zero_out = run_fwd_bwd(zero_model, data, label, criterion, optimizer) + assert torch.allclose(torch_out, zero_out) + grad_handler.handle_gradient() + + for (zero_name, zero_param), (torch_name, torch_param) in zip(zero_model.module.named_parameters(), + torch_model.named_parameters()): + assert zero_name == torch_name + zero_grad_list = optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param)) + if hasattr(zero_param, "moe_info"): + assert len(zero_grad_list) == 0 + assert torch.allclose(zero_param.grad, torch_param.grad) + else: + assert len(zero_grad_list) > 0 + torch_grad_list = split_ddp_grad(torch_param.grad, world_size) + if stage == 2: + torch_grad_list = torch_grad_list[local_rank:local_rank + 1] + assert len(zero_grad_list) == len(torch_grad_list) + for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list): + assert torch.allclose(zero_grad, torch_grad) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + MOE_CONTEXT.setup(seed=42) + seed_all(42 + rank) + run_zero_test(rank, world_size, stage=1) + run_zero_test(rank, world_size, stage=2) + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [2]) +@rerun_if_address_is_in_use() +def test_moe_zero_model(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_moe_zero_model(world_size=2) diff --git a/tests/test_moe/test_moe_zero_init.py b/tests/test_moe/test_moe_zero_init.py deleted file mode 100644 index c48f9a3557ce..000000000000 --- a/tests/test_moe/test_moe_zero_init.py +++ /dev/null @@ -1,106 +0,0 @@ -import pytest -import torch -import torch.nn as nn - -import colossalai -from colossalai.context import MOE_CONTEXT -from colossalai.logging import get_dist_logger -from colossalai.nn import CheckpointModule -from colossalai.nn.layer import MoeModule -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy -from tests.test_zero.test_legacy.common import CONFIG - - -class MoeModel(nn.Module): - def __init__(self, checkpoint: bool = False): - class TestSubModule(CheckpointModule): - def __init__(self): - super().__init__(checkpoint) - expert_cls = nn.Linear - expert_args_dict = dict(in_features=16, out_features=16) - self.moe = MoeModule( - dim_model=16, num_experts=8, use_residual=True, expert_cls=expert_cls, **expert_args_dict - ) - self.proj = nn.Linear(16, 4) - - def _forward(self, x): - x, y = self.moe(x) - x = self.proj(x) - return x, y - - super().__init__() - self.test_embed = nn.Linear(4, 16) - self.test_transform = TestSubModule() - - def forward(self, x): - MOE_CONTEXT.reset_loss() - - x = self.test_embed(x) - x, y = self.test_transform(x) - - MOE_CONTEXT.add_loss(y) - return x - - -@parameterize("init_device_type", ["cpu", "cuda"]) -@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) -def run_moe_zero_init(init_device_type, shard_strategy_class): - get_dist_logger("test_moe_zero_init") - - if init_device_type == "cuda": - init_device = get_current_device() - elif init_device_type == "cpu": - init_device = torch.device("cpu") - else: - raise NotImplementedError("Unknown device found.") - - model_numel_tensor = torch.zeros(1, dtype=torch.int) - with ZeroInitContext( - 
target_device=init_device, - shard_strategy=shard_strategy_class(), - shard_param=True, - model_numel_tensor=model_numel_tensor, - ): - model = MoeModel(checkpoint=True) - - for name, param in model.named_parameters(): - assert hasattr(param, "colo_attr") - - # the parameters in moe experts and its gate should not be sharded - if ("experts" in name) or ("gate" in name) or ("residual_combine" in name): - assert not param.colo_attr.sharded_data_tensor.is_sharded, "`{}` parameter has problem".format(name) - else: - assert param.colo_attr.sharded_data_tensor.is_sharded - - # the parameters in moe experts is not replicated - if "experts" in name: - assert not param.colo_attr.is_replicated - else: - assert param.colo_attr.is_replicated - - if param.colo_attr.param_is_sharded: - assert ( - param.colo_attr.data_payload.device.type == init_device.type - ), f"{param.colo_attr.data_payload.device.type} vs. {init_device.type}" - else: - assert param.colo_attr.data_payload.device.type == "cuda" - - -def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - MOE_CONTEXT.setup(seed=42) - run_moe_zero_init() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [2, 4]) -@rerun_if_address_is_in_use() -def test_moe_zero_init(world_size): - spawn(_run_dist, world_size) - - -if __name__ == "__main__": - test_moe_zero_init(world_size=2) diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py index bb9822daee05..bbb43b4b9871 100644 --- a/tests/test_moe/test_moe_zero_optim.py +++ b/tests/test_moe/test_moe_zero_optim.py @@ -1,5 +1,6 @@ import pytest import torch +import torch.distributed as dist import colossalai from colossalai.context import MOE_CONTEXT @@ -16,8 +17,31 @@ from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2 from colossalai.zero.low_level._utils import has_inf_or_nan from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_moe.test_moe_zero_init import MoeModel -from tests.test_zero.test_legacy.common import CONFIG, check_sharded_model_params +from tests.test_moe.moe_utils import MoeModel + + +def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool: + if loose: + return torch.allclose(tensor_a, tensor_b, atol=1e-2, rtol=1e-3) + return torch.allclose(tensor_a, tensor_b) + + +def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=False): + rank = dist.get_rank() + for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()): + if zero_p.colo_attr.param_is_sharded: + zero_p = zero_p.colo_attr.data_payload.to(p.device).float() + chunks = torch.flatten(p).chunk(dist.get_world_size()) + if rank >= len(chunks): + continue + p = chunks[rank].float() + if zero_p.size(0) > p.size(0): + zero_p = zero_p[:p.size(0)] + else: + zero_p = zero_p.colo_attr.data_payload.to(p.device) + + assert p.dtype == zero_p.dtype, "Parameter `{}`:\n{} vs {}".format(name, p.dtype, zero_p.dtype) + assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}' def _run_step(model, optimizer, data, label, criterion, grad_handler): @@ -104,7 +128,7 @@ def _run_test_sharded_optim_v2( def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') 
MOE_CONTEXT.setup(seed=42) _run_test_sharded_optim_v2() From 4373d0644570846c722cba8795578d333029405e Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com> Date: Mon, 14 Aug 2023 17:41:53 +0800 Subject: [PATCH 02/46] [moe] support low level zero optim (#4429) * update optim * update grad handler * update moe param interface * update doc * move moe tensor --- .../engine/gradient_handler/__init__.py | 9 +- colossalai/nn/layer/moe/experts.py | 5 +- colossalai/tensor/moe_tensor/api.py | 26 +++ colossalai/zero/low_level/low_level_optim.py | 29 ++- tests/test_moe/moe_utils.py | 45 +++++ tests/test_moe/test_grad_handler.py | 2 +- tests/test_moe/test_moe_zero_fwd_bwd.py | 3 +- tests/test_moe/test_moe_zero_optim.py | 181 +++++++----------- 8 files changed, 166 insertions(+), 134 deletions(-) create mode 100644 colossalai/tensor/moe_tensor/api.py diff --git a/colossalai/legacy/engine/gradient_handler/__init__.py b/colossalai/legacy/engine/gradient_handler/__init__.py index 78928b138842..633e9f885918 100644 --- a/colossalai/legacy/engine/gradient_handler/__init__.py +++ b/colossalai/legacy/engine/gradient_handler/__init__.py @@ -1,15 +1,10 @@ from ._base_gradient_handler import BaseGradientHandler from ._data_parallel_gradient_handler import DataParallelGradientHandler -from ._moe_gradient_handler import MoeGradientHandler from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler from ._zero_gradient_handler import ZeROGradientHandler __all__ = [ - "BaseGradientHandler", - "DataParallelGradientHandler", - "ZeROGradientHandler", - "PipelineSharedModuleGradientHandler", - "MoeGradientHandler", - "SequenceParallelGradientHandler", + 'BaseGradientHandler', 'DataParallelGradientHandler', 'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler', + 'SequenceParallelGradientHandler' ] diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py index 4b2ecb241702..6e5048463d50 100644 --- a/colossalai/nn/layer/moe/experts.py +++ b/colossalai/nn/layer/moe/experts.py @@ -7,8 +7,7 @@ import torch.nn as nn from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.legacy.context import ParallelMode, seed -from colossalai.legacy.zero.init_ctx import no_shard_zero_decrator +from colossalai.tensor.moe_tensor.api import set_moe_param_info from colossalai.utils import get_current_device @@ -52,7 +51,7 @@ def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args) # Attach parallel information for all parameters in Experts for exp in self.experts: for param in exp.parameters(): - param.__setattr__("moe_info", self.dist_info) + set_moe_param_info(param, self.dist_info) def forward(self, inputs: torch.Tensor): # Split inputs for each expert diff --git a/colossalai/tensor/moe_tensor/api.py b/colossalai/tensor/moe_tensor/api.py new file mode 100644 index 000000000000..11d07ef8c804 --- /dev/null +++ b/colossalai/tensor/moe_tensor/api.py @@ -0,0 +1,26 @@ +import torch + + +def is_moe_param(tensor: torch.Tensor) -> bool: + """ + Check whether the given tensor is a moe param. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + bool: Whether the given tensor is a moe param. + """ + return hasattr(tensor, "moe_info") + + +def set_moe_param_info(tensor: torch.Tensor, moe_info: dict) -> None: + """ + Set moe info for the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be set. 
+        moe_info (dict): The moe info to be set.
+
+    """
+    tensor.__setattr__('moe_info', moe_info)
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index cc7c291e507d..e6b473adcee6 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -18,7 +18,7 @@
 )
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
-
+from colossalai.tensor.moe_tensor.api import is_moe_param
 # from colossalai.tensor import ColoParameter, ProcessGroup
 from colossalai.utils.cuda import get_current_device
 
@@ -126,16 +126,23 @@ def __init__(
         self._grad_store = GradientStore(self.dp_pg, partition_grad=partition_grad)
         self._bucket_store = BucketStore(self.dp_pg)
 
+        # moe params should not be stored in working_groups
+        # because they have a different parallel strategy,
+        # so we need to store them separately in param_groups
+        # instead of working_groups
+        moe_params = list()
+
         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
         # and add buffers to parameter store for future access
         for group_id, param_group in enumerate(self.optim.param_groups):
             group_params = list()
             for param in param_group['params']:
-                # skip moe param
-                if hasattr(param, "moe_info"):
-                    continue
                 if param.requires_grad:
+                    # skip moe param
+                    if is_moe_param(param):
+                        moe_params.append(param)
+                        continue
                     group_params.append(param)
 
             # add the working params to working_param_groups for bookkeeping
@@ -149,6 +156,15 @@ def __init__(
             # managed by this data parallel rank
             param_group["params"] = master_param_current_rank
 
+        # if there are moe params, store them in an additional group in optim
+        if len(moe_params) > 0:
+            param_group = dict()
+            for key, value in self.optim.param_groups[0].items():
+                if key != 'params':
+                    param_group[key] = value
+            param_group['params'] = moe_params
+            self.optim.param_groups.append(param_group)
+
         # initialize communication stream for
         # communication-computation overlapping
         if self._overlap_communication:
@@ -455,6 +471,11 @@ def step(self, closure=None):
         # update the parameters
         self.optim.step()
 
+        # release the moe grad
+        if len(self.param_groups) > len(self._working_param_groups):
+            for param in self.param_groups[-1]['params']:
+                param.grad = None
+
         # release the grad
         grad_partition_groups = []
         for group_id in range(self.num_param_groups):
diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py
index 4b067c1ceea9..d86d78886e23 100644
--- a/tests/test_moe/moe_utils.py
+++ b/tests/test_moe/moe_utils.py
@@ -1,8 +1,15 @@
 import torch.nn as nn
 
 from colossalai.context import MOE_CONTEXT
+from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
+from colossalai.engine.gradient_handler.utils import bucket_allreduce
 from colossalai.nn import CheckpointModule
 from colossalai.nn.layer import MoeModule
+from colossalai.registry import GRADIENT_HANDLER
+from colossalai.utils.moe import get_moe_epsize_param_dict
 
 
 class MoeModel(nn.Module):
@@ -39,3 +46,41 @@ def forward(self, x):
 
         MOE_CONTEXT.add_loss(y)
         return x
+
+
+@GRADIENT_HANDLER.register_module
+class MoeGradientHandler(BaseGradientHandler):
+    """A helper class to handle all-reduce operations in a data parallel group and
+    moe model parallel. An all-reduce collective communication will be operated in
+    :func:`handle_gradient` among a data parallel group.
+    For better performance, it bucketizes the gradients of all parameters that are
+    of the same type to improve the efficiency of communication.
+
+    Args:
+        model (Module): Model where the gradients accumulate.
+        optimizer (Optimizer): Optimizer for updating the parameters.
+    """
+
+    def __init__(self, model, optimizer=None):
+        super().__init__(model, optimizer)
+
+    def handle_gradient(self):
+        """A method running an all-reduce operation in a data parallel group.
+        Then running an all-reduce operation for all parameters in experts
+        across the moe model parallel group.
+        """
+        global_data = gpc.data_parallel_size
+
+        if global_data > 1:
+            epsize_param_dict = get_moe_epsize_param_dict(self._model)
+
+            # epsize is 1, indicating the params are replicated among processes in data parallelism
+            # use the ParallelMode.DATA to get data parallel group
+            # reduce gradients for all parameters in data parallelism
+            if 1 in epsize_param_dict:
+                bucket_allreduce(param_list=epsize_param_dict[1], group=gpc.get_group(ParallelMode.DATA))
+
+            for ep_size in epsize_param_dict:
+                if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
+                    bucket_allreduce(param_list=epsize_param_dict[ep_size],
+                                     group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group)
diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py
index 2c42d55fa9e3..cff7c116696f 100644
--- a/tests/test_moe/test_grad_handler.py
+++ b/tests/test_moe/test_grad_handler.py
@@ -5,11 +5,11 @@
 import colossalai
 from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.legacy.engine.gradient_handler import MoeGradientHandler
 from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, UniformNoiseGenerator
 from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device
 from colossalai.utils.moe import sync_moe_model_param
+from tests.test_moe.moe_utils import MoeGradientHandler
 
 BATCH_SIZE = 4
 DIM = 16
diff --git a/tests/test_moe/test_moe_zero_fwd_bwd.py b/tests/test_moe/test_moe_zero_fwd_bwd.py
index 83ec884b1515..e2acb0702f1c 100644
--- a/tests/test_moe/test_moe_zero_fwd_bwd.py
+++ b/tests/test_moe/test_moe_zero_fwd_bwd.py
@@ -6,11 +6,10 @@
 from colossalai.booster.plugin import LowLevelZeroPlugin
 from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
 from colossalai.context import MOE_CONTEXT
-from colossalai.engine.gradient_handler import MoeGradientHandler
 from colossalai.nn import MoeLoss
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
-from tests.test_moe.moe_utils import MoeModel
+from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
 
 
 def split_ddp_grad(grad, world_size):
diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py
index bbb43b4b9871..fcb6f95d1319 100644
--- a/tests/test_moe/test_moe_zero_optim.py
+++ b/tests/test_moe/test_moe_zero_optim.py
@@ -1,5 +1,6 @@
 import pytest
 import torch
-import torch.distributed as dist
 
 import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import LowLevelZeroPlugin
+from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
 from colossalai.context import MOE_CONTEXT
-from colossalai.legacy.amp import convert_to_apex_amp
-from colossalai.legacy.engine.gradient_handler import MoeGradientHandler
 from
colossalai.nn import MoeLoss -from colossalai.nn.optimizer import CPUAdam -from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy -from colossalai.zero.legacy.sharded_model import ShardedModelV2 -from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy -from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2 -from colossalai.zero.low_level._utils import has_inf_or_nan -from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_moe.moe_utils import MoeModel - - -def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool: - if loose: - return torch.allclose(tensor_a, tensor_b, atol=1e-2, rtol=1e-3) - return torch.allclose(tensor_a, tensor_b) - - -def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=False): - rank = dist.get_rank() - for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()): - if zero_p.colo_attr.param_is_sharded: - zero_p = zero_p.colo_attr.data_payload.to(p.device).float() - chunks = torch.flatten(p).chunk(dist.get_world_size()) - if rank >= len(chunks): - continue - p = chunks[rank].float() - if zero_p.size(0) > p.size(0): - zero_p = zero_p[:p.size(0)] - else: - zero_p = zero_p.colo_attr.data_payload.to(p.device) +from colossalai.testing import rerun_if_address_is_in_use, spawn +from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel - assert p.dtype == zero_p.dtype, "Parameter `{}`:\n{} vs {}".format(name, p.dtype, zero_p.dtype) - assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}' +def split_ddp_grad(grad, world_size): + with torch.no_grad(): + grad = grad.clone().detach().flatten() + padding_size = (world_size - grad.numel() % world_size) % world_size + if padding_size > 0: + grad = torch.nn.functional.pad(grad, [0, padding_size]) + splited_grad = grad.split(grad.numel() // world_size) + return splited_grad -def _run_step(model, optimizer, data, label, criterion, grad_handler): - model.train() - optimizer.zero_grad() - if criterion: - y = model(data) - loss = criterion(y, label) - else: - loss = model(data, label) +def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False): + model.train() + with torch.cuda.amp.autocast(enabled=enable_autocast): + if criterion: + y = model(data) + loss = criterion(y, label) + else: + loss = model(data, label) + loss = loss.float() - loss = loss.float() - if isinstance(model, ShardedModelV2): + if isinstance(model, LowLevelZeroModel): optimizer.backward(loss) else: loss.backward() + return y - if grad_handler is not None: - grad_handler.handle_gradient() - optimizer.step() - - -@parameterize("cpu_offload", [True]) -@parameterize("use_cpuadam", [True]) # We do not use Hybrid Adam right now, since it has a little bug -@parameterize("reuse_fp16_shard", [True, False]) -@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) -def _run_test_sharded_optim_v2( - cpu_offload, shard_strategy_class, use_cpuadam, reuse_fp16_shard, gpu_margin_mem_ratio=0.0 -): - shard_strategy = shard_strategy_class() - if use_cpuadam and cpu_offload is False: - return - MOE_CONTEXT.reset_loss() - get_components_func = non_distributed_component_funcs.get_callable("hanging_param_model") - _, 
train_dataloader, _, optimizer_class, _ = get_components_func() +def run_zero_optim_test(local_rank, world_size, stage=1): criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss) - with ZeroInitContext( - target_device=torch.device("cpu") if cpu_offload else get_current_device(), - shard_strategy=shard_strategy, - shard_param=True, - ): - zero_model = MoeModel(checkpoint=True) - - zero_model = ShardedModelV2( - zero_model, - shard_strategy, - tensor_placement_policy="cpu" if cpu_offload else "cuda", - reuse_fp16_shard=reuse_fp16_shard, - ) - - # check whether parameters are identical in ddp - for name, p in zero_model.named_parameters(): - if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated: - assert_equal_in_group(p.colo_attr.data_payload.to(get_current_device())) - - model = MoeModel(checkpoint=True).half() - col_model_deepcopy(zero_model, model) - model = model.cuda().float() - - if use_cpuadam: - optimizer_class = CPUAdam - optim = optimizer_class(model.parameters(), lr=1e-3) - sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3) - sharded_optim = ShardedOptimizerV2( - zero_model, sharded_optim, initial_scale=2**5, gpu_margin_mem_ratio=gpu_margin_mem_ratio - ) - - amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False) - apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config) - apex_grad_handler = MoeGradientHandler(model) - - for i, (data, label) in enumerate(train_dataloader): - if i > 5: - break - data, label = data.cuda(), label.cuda() - _run_step(apex_model, apex_optimizer, data, label, criterion, apex_grad_handler) - _run_step(zero_model, sharded_optim, data, label, criterion, None) - check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam) - for param in model.parameters(): - assert not has_inf_or_nan(param) - - -def _run_dist(rank, world_size, port): + zero_model = MoeModel(checkpoint=True) + zero_optimizer = torch.optim.Adam(zero_model.parameters()) + plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") + booster = Booster(plugin=plugin) + zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer) + + torch_model = MoeModel(checkpoint=True) + for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): + torch_param.data.copy_(zero_param.data) + torch_optimizer = torch.optim.Adam(torch_model.parameters()) + torch_model = torch_model.cuda() + grad_handler = MoeGradientHandler(torch_model) + + for _ in range(2): + data = torch.randn(16, 4).cuda() / (local_rank + 1) + label = torch.randint(0, 4, (16,)).cuda() + run_fwd_bwd(torch_model, data, label, criterion, None) + run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + grad_handler.handle_gradient() + + torch_optimizer.step() + zero_optimizer.step() + + for (torch_name, torch_param), (zero_name, zero_param) in zip(torch_model.named_parameters(), + zero_model.named_parameters()): + assert torch.allclose( + torch_param.data, + zero_param.data), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}" + + torch_optimizer.zero_grad() + zero_optimizer.zero_grad() + + +def run_dist(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') MOE_CONTEXT.setup(seed=42) - _run_test_sharded_optim_v2() + run_zero_optim_test(rank, world_size, stage=1) + run_zero_optim_test(rank, world_size, stage=2) -# use_cpuadam = True can be used with cpu_offload = False @pytest.mark.dist 
@pytest.mark.parametrize("world_size", [2]) @rerun_if_address_is_in_use() def test_moe_zero_optim(world_size): - spawn(_run_dist, world_size) + spawn(run_dist, world_size) -if __name__ == "__main__": - test_moe_zero_optim(world_size=4) +if __name__ == '__main__': + test_moe_zero_optim(world_size=2) From 8240463f293e9cc44328e17a84f2b545e241952e Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com> Date: Fri, 25 Aug 2023 14:09:11 +0800 Subject: [PATCH 03/46] [moe] refactor code to better adapt to llm (#4469) * polish code * rename * refactor code * fix test * refactor code * update flash attention version * Support TP (#6) * add tp test * update tp test * update * remove fa dependency * update dependency * update softmax * update checkpointio * update processgroupmesh * update name * update param * add keep vars --- colossalai/context/moe_context.py | 48 +--- colossalai/nn/layer/moe/__init__.py | 21 +- colossalai/nn/layer/moe/checkpoint.py | 77 ++++-- colossalai/nn/layer/moe/experts.py | 264 +++++++------------ colossalai/nn/layer/moe/layers.py | 203 ++++++++------ colossalai/nn/layer/moe/routers.py | 100 +++---- colossalai/nn/layer/moe/utils.py | 39 ++- colossalai/tensor/moe_tensor/api.py | 91 ++++++- colossalai/tensor/moe_tensor/moe_info.py | 15 ++ colossalai/zero/low_level/low_level_optim.py | 4 +- tests/test_moe/moe_utils.py | 61 ++++- tests/test_moe/test_grad_handler.py | 22 +- tests/test_moe/test_kernel.py | 11 +- tests/test_moe/test_moe_checkpoint.py | 16 +- tests/test_moe/test_moe_ep_tp.py | 63 +++++ tests/test_moe/test_moe_group.py | 64 +++-- tests/test_moe/test_moe_zero_fwd_bwd.py | 1 - 17 files changed, 643 insertions(+), 457 deletions(-) create mode 100644 colossalai/tensor/moe_tensor/moe_info.py create mode 100644 tests/test_moe/test_moe_ep_tp.py diff --git a/colossalai/context/moe_context.py b/colossalai/context/moe_context.py index 066dfc7222e1..ea74d2c60dd6 100644 --- a/colossalai/context/moe_context.py +++ b/colossalai/context/moe_context.py @@ -4,26 +4,8 @@ import torch.distributed as dist from colossalai.context.singleton_meta import SingletonMeta -from colossalai.legacy.tensor import ProcessGroup - - -def _check_sanity(): - from colossalai.legacy.core import global_context as gpc - - if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1: - raise NotImplementedError("Moe is not compatible with tensor or " "pipeline parallel at present.") - - -class MoeParallelInfo: - """Moe parallelism information, storing parallel sizes and groups.""" - - def __init__(self, ep_size: int, dp_size: int): - _check_sanity() - self.ep_size = ep_size - self.dp_size = dp_size - self.pg = ProcessGroup(tp_degree=ep_size, dp_degree=dp_size) - self.ep_group = self.pg.tp_process_group() - self.dp_group = self.pg.dp_process_group() +from colossalai.tensor.moe_tensor.api import get_moe_info +from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo class MoeContext(metaclass=SingletonMeta): @@ -32,12 +14,12 @@ class MoeContext(metaclass=SingletonMeta): """ def __init__(self): - self.world_size = 1 + self.world_size = None # Users may want to set maximum expert parallel size smaller than the world size # since very low bandwidth across nodes may constrain the performance of MoE # When we have a maximum expert parallel size, we have a minimum data parallel size naturally - self.max_ep_size = 1 - self.min_dp_size = 1 + self.max_ep_size = None + self.min_dp_size = None self.aux_loss = None self.use_kernel_optim = True @@ -52,19 +34,12 @@ def 
parallel_info_dict(self): def is_initialized(self): return self.has_setup - def setup(self, seed: int, use_kernel_optim: bool = True): + def setup(self, seed: int, use_kernel_optim: bool = True, max_ep_size: int = 8): assert not self.is_initialized, "MoE distributed context shouldn't be set up again" - _check_sanity() assert torch.cuda.is_available(), "MoE requires to enable CUDA first" self.world_size = dist.get_world_size() - - from colossalai.legacy.core import global_context as gpc - - self.max_ep_size = gpc.config.get("max_ep_size", self.world_size) - assert ( - self.world_size % self.max_ep_size == 0 - ), "Maximum expert parallel size must be a factor of the number of GPUs" + self.max_ep_size = min(max_ep_size, dist.get_world_size()) self.min_dp_size = self.world_size // self.max_ep_size # Enabling kernel optimization may raise error in some cases @@ -76,7 +51,7 @@ def setup(self, seed: int, use_kernel_optim: bool = True): moe_set_seed(seed) self.has_setup = True - def get_info(self, num_experts: int) -> Tuple[int, MoeParallelInfo]: + def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoeParallelInfo]: """Calculate the Data Parallel Group and Expert Parallel Group. Parameters @@ -107,12 +82,15 @@ def get_info(self, num_experts: int) -> Tuple[int, MoeParallelInfo]: ep_size = self.max_ep_size // dp_size # Calculate the number of experts for each GPU - num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size + if use_tp: + num_local_experts = num_experts + else: + num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size # Don't forget to multiply minimum data parallel size dp_size *= self.min_dp_size if not (ep_size in self.parallel_info_dict): - self.parallel_info_dict[ep_size] = MoeParallelInfo(ep_size, dp_size) + self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size) return num_local_experts, self.parallel_info_dict[ep_size] diff --git a/colossalai/nn/layer/moe/__init__.py b/colossalai/nn/layer/moe/__init__.py index 6a5ccff510be..ffeeac796441 100644 --- a/colossalai/nn/layer/moe/__init__.py +++ b/colossalai/nn/layer/moe/__init__.py @@ -1,21 +1,10 @@ -from .checkpoint import load_moe_model, save_moe_model -from .experts import Experts, FFNExperts, TPExperts -from .layers import MoeLayer, MoeModule +from .checkpoint import MoeCheckpintIO +from .experts import EPMLPExperts, TPMLPExperts +from .layers import MoeLayer, MoeModule, SparseMLP from .routers import MoeRouter, Top1Router, Top2Router from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts __all__ = [ - "Experts", - "FFNExperts", - "TPExperts", - "Top1Router", - "Top2Router", - "MoeLayer", - "NormalNoiseGenerator", - "UniformNoiseGenerator", - "build_ffn_experts", - "MoeModule", - "MoeRouter", - "save_moe_model", - "load_moe_model", + 'EPMLPExperts', 'TPMLPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'MoeModule', 'NormalNoiseGenerator', + 'UniformNoiseGenerator', 'build_ffn_experts', 'SparseMLP', 'MoeRouter', 'MoeCheckpintIO' ] diff --git a/colossalai/nn/layer/moe/checkpoint.py b/colossalai/nn/layer/moe/checkpoint.py index adad19d581ef..34af87bd9d47 100644 --- a/colossalai/nn/layer/moe/checkpoint.py +++ b/colossalai/nn/layer/moe/checkpoint.py @@ -1,40 +1,61 @@ +from pathlib import Path +from typing import Optional + import torch import torch.distributed as dist import torch.nn as nn +from torch.optim import Optimizer + +from colossalai.checkpoint_io import CheckpointIO +from colossalai.tensor.moe_tensor.api import get_ep_group + + +class 
MoeCheckpintIO(CheckpointIO): + + def __init__(self) -> None: + super().__init__() + + def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool): + state_dict = torch.load(checkpoint) + for name, param in model.named_parameters(): + if '.experts.' in name: + ep_rank = dist.get_rank(get_ep_group(param)) + ep_size = dist.get_world_size(get_ep_group(param)) + for rank in range(ep_size): + new_name = name.replace('.experts.', f'.experts.{rank}.') + if rank == ep_rank: + state_dict[name] = state_dict.pop(new_name) + else: + state_dict.pop(new_name) -from .experts import MoeExperts + model.load_state_dict(state_dict, strict=strict) + def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + state_dict = model.state_dict() + if dist.get_rank() == 0: + torch.save(state_dict, checkpoint) + dist.barrier() -def save_moe_model(model: nn.Module, save_path: str): - state_dict = model.state_dict() - if dist.get_rank() == 0: - torch.save(state_dict, save_path) - dist.barrier() + def load_sharded_model(self, model: nn.Module, index_file_path: str, strict: bool): + raise NotImplementedError() + def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str], + size_per_shard: int, use_safetensors: bool): + raise NotImplementedError() -def load_moe_model(model: nn.Module, load_path: str): - state_dict = torch.load(load_path) + # ======================================================== + # Abstract methods for optimizer loading/saving implementation + # ======================================================== - for prefix, module in model.named_modules(): - if prefix.endswith(".moe_layer.experts"): - # this module should be an Experts instance - assert isinstance(module, MoeExperts) + def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str): + raise NotImplementedError() - ep_rank = dist.get_rank(module.dist_info.ep_group) - num_local = module.num_local_experts - for i in range(num_local): - expert_id = ep_rank * num_local + i - for name, _ in module.experts[i].named_parameters(): - cur_key = f"{prefix}.experts.{i}.{name}" - param_key = f"{prefix}.experts.{expert_id}.{name}" - load_param = state_dict[param_key] - state_dict[cur_key] = load_param + def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): + raise NotImplementedError() - for name, _ in module.experts[0].named_parameters(): - pop_pre = f"{prefix}.experts." 
-            pop_suf = f".{name}"
-            for i in range(num_local, module.num_total_experts):
-                pop_key = f"{pop_pre}{i}{pop_suf}"
-                state_dict.pop(pop_key)
+    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str,
+                               size_per_shard: int):
+        raise NotImplementedError()
 
-    model.load_state_dict(state_dict)
+    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
+        raise NotImplementedError()
diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py
index 6e5048463d50..7c743b025945 100644
--- a/colossalai/nn/layer/moe/experts.py
+++ b/colossalai/nn/layer/moe/experts.py
@@ -1,200 +1,138 @@
 import math
 from copy import deepcopy
-from typing import Type
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 
 from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.tensor.moe_tensor.api import set_moe_param_info
-from colossalai.utils import get_current_device
+from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size, set_moe_tensor_info
 
 
-class MoeExperts(nn.Module):
-    """Basic class for experts in MoE. It stores what kind of communication experts use
-    to exchange tokens, how many experts in a single GPU and parallel information such as
-    expert parallel size, data parallel size and their distributed communication groups.
+class BaseMLPExperts(nn.Module):
+    """
+    Base class for MLP experts: the multi-layer perceptrons used as the sparse, expert-parallel layers of a MoE model.
""" - def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args): - super().__init__("all_to_all", num_experts) - - # Use seed to make every expert different from others - with seed(ParallelMode.TENSOR): - self.experts = nn.ModuleList([expert_cls(**expert_args) for _ in range(self.num_local_experts)]) - - # Attach parallel information for all parameters in Experts - for exp in self.experts: - for param in exp.parameters(): - set_moe_param_info(param, self.dist_info) - - def forward(self, inputs: torch.Tensor): - # Split inputs for each expert - expert_input = torch.chunk(inputs, self.num_local_experts, dim=1) - expert_output = [] - - # Get outputs from each expert - for i in range(self.num_local_experts): - expert_output.append(self.experts[i](expert_input[i])) - - # Concatenate all outputs together - output = torch.cat(expert_output, dim=1).contiguous() - return output - - def state_dict(self, destination=None, prefix="", keep_vars=False): - assert keep_vars == False, "Only support keep_vars=False now" - dp_rank = dist.get_rank(self.dist_info.dp_group) - ep_rank = dist.get_rank(self.dist_info.ep_group) - submodule_dict = dict() - example_submodule = None - for name, subm in self.experts.named_modules(): - if subm is self.experts: - continue - module_number = self.num_local_experts * ep_rank + int(name) - submodule_dict[module_number] = subm - example_submodule = subm - - if dp_rank == 0: - local_prefix = prefix + "experts." - buffer_module = deepcopy(example_submodule) - for i in range(self.num_total_experts): - source_rank = i // self.num_local_experts - current_prefix = local_prefix + str(i) + "." - comm_module = submodule_dict.get(i, buffer_module) - for name, param in comm_module.named_parameters(): - dist.broadcast(param.data, src=source_rank, group=self.dist_info.ep_group) - if ep_rank == 0: - destination[current_prefix + name] = param.data.cpu() - - dist.barrier() - - -class FFNExperts(MoeExperts): - """Use torch.bmm to speed up for multiple experts.""" - - def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): - super().__init__("all_to_all", num_experts) + def __init__( + self, + num_experts: int, + hidden_size: int, + intermediate_size: int, + expert_parallel: str, + activation: str = None, + drop_rate: float = 0, + ): + super().__init__() + assert expert_parallel in ["EP", "TP"] + self.expert_parallel = expert_parallel - self.w1 = nn.Parameter(torch.empty(self.num_local_experts, d_model, d_ff, device=get_current_device())) - self.b1 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_ff, device=get_current_device())) + # get local and total experts + self.num_total_experts = num_experts + self.num_local_experts, self.moe_info = MOE_CONTEXT.get_info(num_experts, + use_tp=True if expert_parallel == "TP" else False) - self.w2 = nn.Parameter(torch.empty(self.num_local_experts, d_ff, d_model, device=get_current_device())) - self.b2 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_model, device=get_current_device())) + # get settings for different parallel + if expert_parallel == "TP": + assert intermediate_size % MOE_CONTEXT.max_ep_size == 0, \ + "intermediate_size should be divide by maximum expert parallel size" + intermediate_size = intermediate_size // MOE_CONTEXT.max_ep_size + num_experts = self.num_total_experts + else: + num_experts = self.num_local_experts - s1 = math.sqrt(0.1 / d_model) - s2 = math.sqrt(0.1 / d_ff) + self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, 
intermediate_size)) + self.wo = nn.Parameter(torch.empty(num_experts, intermediate_size, hidden_size)) with seed(ParallelMode.TENSOR): - nn.init.trunc_normal_(self.w1, std=s1) - nn.init.trunc_normal_(self.b1, std=s1) - nn.init.trunc_normal_(self.w2, std=s2) - nn.init.trunc_normal_(self.b2, std=s2) + nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size)) + nn.init.trunc_normal_(self.wo, std=math.sqrt(0.1 / intermediate_size)) self.act = nn.GELU() if activation is None else activation self.drop = nn.Dropout(p=drop_rate) for param in self.parameters(): - param.__setattr__("moe_info", self.dist_info) + set_moe_tensor_info(param, self.moe_info) - def forward(self, inputs): # inputs [g, el, c, h] - el = inputs.size(1) - h = inputs.size(-1) + def forward(self, x: torch.Tensor) -> torch.Tensor: # inputs [g, e, c, h] - inputs = inputs.transpose(0, 1) - inshape = inputs.shape - inputs = inputs.reshape(el, -1, h) + e = x.size(1) + h = x.size(-1) - out_ff = torch.baddbmm(self.b1, inputs, self.w1) - out_act = self.act(out_ff) - with seed(ParallelMode.TENSOR): - out_inter = self.drop(out_act) + x = x.transpose(0, 1) + inshape = x.shape + x = x.reshape(e, -1, h) - out_model = torch.baddbmm(self.b2, out_inter, self.w2) + x = torch.bmm(x, self.wi) + x = self.act(x) with seed(ParallelMode.TENSOR): - outputs = self.drop(out_model) # outputs [el, gc, h] + x = self.drop(x) + x = torch.bmm(x, self.wo) - outputs = outputs.reshape(inshape) - outputs = outputs.transpose(0, 1).contiguous() - return outputs + x = x.reshape(inshape) + x = x.transpose(0, 1).contiguous() + return x # outputs [g, e, c, h] -class TPExperts(MoeExperts): - """Use tensor parallelism to split each expert evenly, which can deploy experts in - case that the number of experts can't be divide by maximum expert parallel size or - maximum expert parallel size can't be divide by the number of experts. 
+class EPMLPExperts(BaseMLPExperts): + """ + Use expert parallelism to split each expert evenly, which can deploy experts in """ - def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): - super().__init__("all_gather", MOE_CONTEXT.max_ep_size) - - assert d_ff % MOE_CONTEXT.max_ep_size == 0, "d_ff should be divide by maximum expert parallel size" - - p_ff = d_ff // MOE_CONTEXT.max_ep_size - - self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device())) - self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device())) - - self.w2 = nn.Parameter(torch.empty(num_experts, p_ff, d_model, device=get_current_device())) - self.b2 = nn.Parameter(torch.empty(num_experts, 1, d_model, device=get_current_device())) - - s1 = math.sqrt(0.1 / d_model) - s2 = math.sqrt(0.1 / d_ff) - - with seed(ParallelMode.TENSOR): - nn.init.trunc_normal_(self.w1, std=s1) - nn.init.trunc_normal_(self.b1, std=s1) - nn.init.trunc_normal_(self.w2, std=s2) - - nn.init.trunc_normal_(self.b2, std=s2) - - self.act = nn.GELU() if activation is None else activation - self.drop = nn.Dropout(p=drop_rate) - - self.w1.__setattr__("moe_info", self.dist_info) - self.w2.__setattr__("moe_info", self.dist_info) - self.b1.__setattr__("moe_info", self.dist_info) + def __init__(self, + num_experts: int, + hidden_size: int, + intermediate_size: int, + activation=None, + drop_rate: float = 0): + super().__init__(num_experts, hidden_size, intermediate_size, "EP", activation, drop_rate) + + def state_dict(self, destination=None, prefix='', keep_vars=False): + dp_rank = dist.get_rank(get_dp_group(self)) + ep_rank = dist.get_rank(get_ep_group(self)) + ep_size = get_ep_size(self) + # dp rank 0 will save the state dict + if dp_rank == 0: + for name, param in self.named_parameters(): + if param is self: + continue + # create buffer + buffer_module = deepcopy(param) + # gather param from every ep rank + for source_rank in range(ep_size): + current_prefix = f"{prefix}{source_rank}." + if ep_rank == source_rank: + dist.broadcast(param.data, src=source_rank, group=self.moe_info.ep_group) + else: + dist.broadcast(buffer_module.data, src=source_rank, group=self.moe_info.ep_group) + if ep_rank == 0: + if keep_vars: + destination[current_prefix + name] = buffer_module.cpu() + else: + destination[current_prefix + name] = buffer_module.data.cpu() - def forward(self, inputs): # inputs [g, e, c, h] - e = inputs.size(1) - h = inputs.size(-1) + dist.barrier() - inputs = inputs.transpose(0, 1) - inshape = inputs.shape - inputs = inputs.reshape(e, -1, h) - out_ff = torch.baddbmm(self.b1, inputs, self.w1) - out_act = self.act(out_ff) - with seed(ParallelMode.TENSOR): - out_inter = self.drop(out_act) - - out_model = torch.baddbmm(self.b2, out_inter, self.w2) - outputs = self.drop(out_model) # outputs [e, gc, h] +class TPMLPExperts(BaseMLPExperts): + """Use tensor parallelism to split each expert evenly, which can deploy experts in + case that the number of experts can't be divide by maximum expert parallel size or + maximum expert parallel size can't be divide by the number of experts. 
+ """ - outputs = outputs.reshape(inshape) - outputs = outputs.transpose(0, 1).contiguous() - return outputs # outputs [g, e, c, h] + def __init__(self, + num_experts: int, + hidden_size: int, + intermediate_size: int, + activation: str = None, + drop_rate: float = 0): + super().__init__(num_experts, hidden_size, intermediate_size, "TP", activation, drop_rate) + + +def get_expert_class(name: str) -> BaseMLPExperts: + if name == "TP": + return TPMLPExperts + elif name == "EP": + return EPMLPExperts + else: + raise ValueError(f"Unknown expert class name: {name}") diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py index 23d483e6a17a..254e55eb3316 100644 --- a/colossalai/nn/layer/moe/layers.py +++ b/colossalai/nn/layer/moe/layers.py @@ -1,5 +1,5 @@ import math -from typing import Optional, Tuple, Type +from typing import Optional, Tuple import torch import torch.nn as nn @@ -15,13 +15,12 @@ MoeDispatch, ReduceScatter, ) -from colossalai.nn.layer.moe.experts import Experts, MoeExperts -from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router -from colossalai.nn.layer.moe.utils import NormalNoiseGenerator, UniformNoiseGenerator -from colossalai.utils import get_current_device +from colossalai.nn.layer.moe.experts import BaseMLPExperts, get_expert_class +from colossalai.nn.layer.moe.routers import MoeRouter, get_router_cls +from colossalai.nn.layer.moe.utils import get_noise_generator +from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size -@no_shard_zero_decrator(is_replicated=True) class MoeLayer(nn.Module): """A MoE layer, that puts its input tensor to its gate and uses the output logits to router all tokens, is mainly used to exchange all tokens for every expert across @@ -35,21 +34,21 @@ class MoeLayer(nn.Module): experts (MoeExperts): Instance of experts generated by Expert. 
""" - def __init__(self, dim_model: int, num_experts: int, router: MoeRouter, experts: MoeExperts): + def __init__(self, dim_model: int, num_experts: int, router: MoeRouter, experts: BaseMLPExperts): super().__init__() self.d_model = dim_model self.num_experts = num_experts self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model)) self.router: MoeRouter = router - self.experts: MoeExperts = experts + self.experts: BaseMLPExperts = experts self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False - self.ep_group = experts.dist_info.ep_group - self.ep_size = experts.dist_info.ep_size + self.ep_group = get_ep_group(experts) + self.ep_size = get_ep_size(experts) self.num_local_experts = experts.num_local_experts nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model)) - def a2a_process(self, dispatch_data: torch.Tensor): + def ep_process(self, dispatch_data: torch.Tensor): expert_input = AllToAll.apply(dispatch_data, self.ep_group) input_shape = expert_input.shape expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model) @@ -84,9 +83,9 @@ def forward(self, inputs: torch.Tensor) -> Tuple: dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens) # dispatch_data [e, c, h] - if self.experts.comm_name == "all_to_all": - expert_output = self.a2a_process(dispatch_data) - elif self.experts.comm_name == "all_gather": + if self.experts.expert_parallel == "EP": + expert_output = self.ep_process(dispatch_data) + elif self.experts.expert_parallel == "TP": expert_output = self.tp_process(dispatch_data) else: raise NotImplementedError( @@ -107,7 +106,7 @@ def forward(self, inputs: torch.Tensor) -> Tuple: return ans, l_aux -class MoeModule(nn.Module): +class SparseMLP(nn.Module): """A class for users to create MoE modules in their models. 
Args: @@ -136,77 +135,119 @@ class MoeModule(nn.Module): https://arxiv.org/abs/2201.05596 """ - def __init__( - self, - dim_model: int, - num_experts: int, - top_k: int = 1, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_policy: Optional[str] = None, - drop_tks: bool = True, - use_residual: bool = False, - residual_instance: Optional[nn.Module] = None, - expert_instance: Optional[MoeExperts] = None, - expert_cls: Optional[Type[nn.Module]] = None, - **expert_args, - ): + def __init__(self, + num_experts: int, + top_k: int = 1, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_policy: Optional[str] = None, + drop_tks: bool = True, + expert_parallel: str = "EP", + hidden_size: int = 2048, + intermediate_size: int = 2048, + activation: str = None): super().__init__() + self.hidden_size = hidden_size + self.num_experts = num_experts + self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False + assert expert_parallel in ["EP", "TP"], f"Unsupported expert parallel type {expert_parallel}" + + # moe router + noisy_func = get_noise_generator(noisy_policy, num_experts) + router_cls = get_router_cls(top_k) + self.router: MoeRouter = router_cls(capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks) + + # moe experts + expert_cls = get_expert_class(expert_parallel) + self.experts: BaseMLPExperts = expert_cls(num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + activation=activation) + self.ep_group = get_ep_group(self.experts) + self.ep_size = get_ep_size(self.experts) + self.num_local_experts = self.experts.num_local_experts + + self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size)) + nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / self.hidden_size)) + + def forward(self, inputs: torch.Tensor) -> Tuple: + # reshape the input tokens + tokens = inputs.reshape(-1, self.hidden_size) + + # the data type of the inputs in the gating should be fp32 + fp32_input = tokens.to(torch.float) + fp32_weight = self.gate_weight.to(torch.float) + gate_output = F.linear(fp32_input, fp32_weight) + + # the result from the router + route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group) + + if self.use_kernel: + dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:]) + dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.hidden_size) + else: + sec_mask_f = route_result_list[1].type_as(inputs) + dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens) - noisy_func = None - if noisy_policy is not None: - if noisy_policy == "Jitter": - noisy_func = UniformNoiseGenerator() - elif noisy_policy == "Gaussian": - noisy_func = NormalNoiseGenerator(num_experts) - else: - raise NotImplementedError("Unsupported input noisy policy") - - if top_k == 1: - moe_router_cls = Top1Router - elif top_k == 2: - moe_router_cls = Top2Router + # dispatch_data [e, c, h] + if self.experts.expert_parallel == "EP": + expert_output = self.ep_process(dispatch_data) + elif self.experts.expert_parallel == "TP": + expert_output = self.tp_process(dispatch_data) else: - raise NotImplementedError("top_k > 2 is not supported yet") - - self.moe_router = moe_router_cls( - capacity_factor_train=capacity_factor_train, - 
capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks, - ) - self.use_residual = use_residual - if use_residual: - if residual_instance is not None: - self.residual_module = residual_instance - else: - assert expert_cls is not None, "Expert class can't be None when residual instance is not given" - self.residual_module = expert_cls(**expert_args) - - with no_shard_zero_context(): - self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device()) - - if expert_instance is not None: - my_experts = expert_instance + raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts " + "build function.") + # expert_output [e, c, h] + if self.use_kernel: + expert_output = expert_output.reshape(-1, self.hidden_size) + ans = MoeCombine.apply(expert_output, *route_result_list) else: - assert expert_cls is not None, "Expert class can't be None when experts instance is not given" - my_experts = Experts(expert_cls, num_experts, **expert_args) + combine_weights = route_result_list[0].type_as(inputs) + combine_weights = combine_weights.view(combine_weights.shape[0], -1) + expert_output = expert_output.view(-1, expert_output.shape[-1]) + ans = torch.matmul(combine_weights, expert_output) + + ans = ans.reshape(inputs.shape) + l_aux = self.router.pop_routing_loss() + return ans, l_aux + + def ep_process(self, dispatch_data: torch.Tensor) -> torch.Tensor: + expert_input = AllToAll.apply(dispatch_data, self.ep_group) + input_shape = expert_input.shape + expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size) + expert_output = self.experts(expert_input) + expert_output = expert_output.reshape(input_shape) + expert_output = AllToAll.apply(expert_output, self.ep_group) + return expert_output - self.moe_layer = MoeLayer( - dim_model=dim_model, num_experts=num_experts, router=self.moe_router, experts=my_experts - ) + def tp_process(self, dispatch_data: torch.Tensor) -> torch.Tensor: + expert_in = AllGather.apply(dispatch_data, self.ep_group) + expert_out = self.experts(expert_in) + expert_out = ReduceScatter.apply(expert_out, self.ep_group) + return expert_out - def forward(self, inputs: torch.Tensor): - moe_output, l_aux = self.moe_layer(inputs) - if self.use_residual: - residual_output = self.residual_module(inputs) - combine_coef = self.residual_combine(inputs) - combine_coef = F.softmax(combine_coef, dim=-1) - output = moe_output * combine_coef[..., 0:1] + residual_output * combine_coef[..., 1:] - else: - output = moe_output +class MoeModule(nn.Module): + """ + For other dependency + """ - return output, l_aux + def __init__(self, + num_experts: int, + top_k: int = 1, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_policy: Optional[str] = None, + drop_tks: bool = True, + expert_parallel: str = "EP", + hidden_size: int = 2048, + intermediate_size: int = 2048, + activation: str = None): + super().__init__(num_experts, top_k, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_policy, + drop_tks, expert_parallel, hidden_size, intermediate_size, activation) diff --git a/colossalai/nn/layer/moe/routers.py b/colossalai/nn/layer/moe/routers.py index 7ba83b2787a0..53fd8fd43e91 100644 --- a/colossalai/nn/layer/moe/routers.py +++ b/colossalai/nn/layer/moe/routers.py @@ -8,6 +8,7 @@ import torch.nn.functional as F from torch.distributed import ProcessGroup +from colossalai.context import 
MOE_CONTEXT from colossalai.nn.layer.moe._operation import moe_cumsum from colossalai.utils import get_current_device @@ -23,15 +24,13 @@ class MoeRouter(nn.Module, ABC): drop_tks (bool, optional): Whether drops tokens in evaluation """ - def __init__( - self, - k_value: int, - capacity_factor_train: float, - capacity_factor_eval: float, - min_capacity: int, - noisy_func: Callable = None, - drop_tks: bool = True, - ): + def __init__(self, + k_value: int, + capacity_factor_train: float, + capacity_factor_eval: float, + min_capacity: int, + noisy_func: Callable = None, + drop_tks: bool = True): super().__init__() self.k_value = k_value self.capacity_factor_train = capacity_factor_train @@ -73,31 +72,28 @@ class Top1Router(MoeRouter): drop_tks (bool, optional): Whether drops tokens in evaluation """ - def __init__( - self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - select_policy: str = "first", - noisy_func: Callable = None, - drop_tks: bool = True, - ): - super().__init__( - k_value=1, - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks, - ) + def __init__(self, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + select_policy: str = "first", + noisy_func: Callable = None, + drop_tks: bool = True): + super().__init__(k_value=1, + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks) self.select_policy = select_policy assert select_policy in {"first", "random"} if select_policy == "random": - self.uniform = torch.distributions.uniform.Uniform( - low=torch.tensor(0.0, device=get_current_device()), high=torch.tensor(1.0, device=get_current_device()) - ).rsample + self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()), + high=torch.tensor(1.0, + device=get_current_device())).rsample def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): + if self.noisy_func is not None and self.training: inputs = self.noisy_func(inputs) @@ -157,22 +153,18 @@ class Top2Router(MoeRouter): drop_tks (bool, optional): Whether drops tokens in evaluation. 
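 
     A worked capacity example (editorial illustration, not part of the original
     commit): with 1024 tokens, 8 experts, k=2 and capacity_factor_train=1.25,
     each expert keeps on the order of 2 * 1.25 * 1024 / 8 = 320 tokens during
     training; overflow tokens are dropped when drop_tks is enabled.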
""" - def __init__( - self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_func: Callable = None, - drop_tks: bool = True, - ): - super().__init__( - k_value=2, - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks, - ) + def __init__(self, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_func: Callable = None, + drop_tks: bool = True): + super().__init__(k_value=2, + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks) def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): # inputs: [s, h] @@ -180,7 +172,7 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti inputs = self.noisy_func(inputs) assert inputs.dtype == torch.float - logits = F.softmax(inputs, dim=-1) # logits: [s, e] + logits = F.softmax(inputs, dim=-1) # logits: [s, e] num_experts = logits.size(-1) capacity = self.get_capacity(logits.shape) @@ -190,12 +182,12 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti top2_idx = torch.argmax(logits_except1, dim=-1) mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) - cmask = mask1 + mask2 # loss: [s, e] + cmask = (mask1 + mask2) # loss: [s, e] # caculate the auxiliary loss me = torch.mean(logits, dim=0) ce = torch.mean(cmask.float(), dim=0) - l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1 + l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1 self.set_routing_loss(l_aux) if not self.training and not self.drop_tks: @@ -203,7 +195,7 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) capacity = max_num.item() - rank1 = moe_cumsum(mask1) # rank1: [s, e] + rank1 = moe_cumsum(mask1) # rank1: [s, e] rank2 = moe_cumsum(mask2) rank2 += torch.sum(mask1, dim=-2, keepdim=True) @@ -233,3 +225,13 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti sec_mask = cb_weight.bool() return cb_weight, sec_mask + + +def get_router_cls(top_k: int) -> MoeRouter: + if top_k == 1: + router_cls = Top1Router + elif top_k == 2: + router_cls = Top2Router + else: + raise NotImplementedError("top_k > 2 is not supported yet") + return router_cls diff --git a/colossalai/nn/layer/moe/utils.py b/colossalai/nn/layer/moe/utils.py index 4f31dd5579dc..eb3bef70998d 100644 --- a/colossalai/nn/layer/moe/utils.py +++ b/colossalai/nn/layer/moe/utils.py @@ -1,13 +1,16 @@ +from typing import Callable + import torch import torch.nn.functional as F from colossalai.context.moe_context import MOE_CONTEXT from colossalai.utils import get_current_device -from .experts import FFNExperts, TPExperts +from .experts import EPMLPExperts, TPMLPExperts class ForceFP32Parameter(torch.nn.Parameter): + def half(self, memory_format=None): return self.data.clone() @@ -23,10 +26,9 @@ class NormalNoiseGenerator: """ def __init__(self, num_experts: int): - self.normal = torch.distributions.normal.Normal( - loc=torch.tensor(0.0, device=get_current_device()), - scale=torch.tensor(1.0 / num_experts**2, device=get_current_device()), - ).rsample + self.normal = 
torch.distributions.normal.Normal(loc=torch.tensor(0.0, device=get_current_device()),
+                                                      scale=torch.tensor(1.0 / num_experts**2,
+                                                                         device=get_current_device())).rsample
 
     def __call__(self, inputs: torch.Tensor):
         noisy = self.normal(inputs.shape)
@@ -45,10 +47,9 @@ class UniformNoiseGenerator:
     """
 
     def __init__(self, eps: float = 1e-2):
-        self.uniform = torch.distributions.uniform.Uniform(
-            low=torch.tensor(1.0 - eps, device=get_current_device()),
-            high=torch.tensor(1.0 + eps, device=get_current_device()),
-        ).rsample
+        self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - eps, device=get_current_device()),
+                                                           high=torch.tensor(1.0 + eps,
+                                                                             device=get_current_device())).rsample
 
     def __call__(self, inputs: torch.Tensor):
         noisy = self.uniform(inputs.shape)
@@ -56,16 +57,26 @@ def __call__(self, inputs: torch.Tensor):
 
 
 def autocast_softmax(logit: torch.Tensor, dim: int):
-    if logit.dtype != torch.float32:
-        logit = logit.float()
-    return F.softmax(logit, dim=dim)
+    return F.softmax(logit, dim=dim, dtype=torch.float32)
 
 
 def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
     mep_size = MOE_CONTEXT.max_ep_size
     if num_experts % mep_size == 0 or mep_size % num_experts == 0:
-        return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate)
+        return EPMLPExperts(num_experts, d_model, d_ff, activation, drop_rate)
     elif d_ff % mep_size == 0:
-        return TPExperts(num_experts, d_model, d_ff, activation, drop_rate)
+        return TPMLPExperts(num_experts, d_model, d_ff, activation, drop_rate)
     else:
        raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.")
+
+
+def get_noise_generator(noise_type: str, num_experts: int) -> Callable:
+    if noise_type is None:
+        return None
+    elif noise_type == 'Jitter':
+        noisy_func = UniformNoiseGenerator()
+    elif noise_type == 'Gaussian':
+        noisy_func = NormalNoiseGenerator(num_experts)
+    else:
+        raise NotImplementedError("Unsupported input noisy policy")
+    return noisy_func
diff --git a/colossalai/tensor/moe_tensor/api.py b/colossalai/tensor/moe_tensor/api.py
index 11d07ef8c804..b9b6d338438e 100644
--- a/colossalai/tensor/moe_tensor/api.py
+++ b/colossalai/tensor/moe_tensor/api.py
@@ -1,20 +1,25 @@
 import torch
+import torch.distributed as dist
 
+from colossalai.tensor import ProcessGroup
 
-def is_moe_param(tensor: torch.Tensor) -> bool:
+from .moe_info import MoeParallelInfo
+
+
+def is_moe_tensor(tensor: torch.Tensor) -> bool:
     """
-    Check whether the given tensor is a moe param.
+    Check whether the given tensor is a moe tensor.
 
     Args:
         tensor (torch.Tensor): The tensor to be checked.
 
     Returns:
-        bool: Whether the given tensor is a moe param.
+        bool: Whether the given tensor is a moe tensor.
     """
     return hasattr(tensor, "moe_info")
 
 
-def set_moe_param_info(tensor: torch.Tensor, moe_info: dict) -> None:
+def set_moe_tensor_info(tensor: torch.Tensor, moe_info: MoeParallelInfo) -> None:
     """
     Set moe info for the given tensor.
 
@@ -24,3 +29,81 @@ def set_moe_param_info(tensor: torch.Tensor, moe_info: dict) -> None:
     """
     tensor.__setattr__('moe_info', moe_info)
+
+
+def get_moe_info(ep_size: int, dp_size: int) -> MoeParallelInfo:
+    """
+    Build the moe parallel info for the given parallel sizes.
+
+    Args:
+        ep_size (int): The expert parallel size.
+        dp_size (int): The data parallel size.
+
+    Returns:
+        MoeParallelInfo: The moe parallel info built from the given sizes.
+    """
+    return MoeParallelInfo(ep_size, dp_size)
+
+
+def get_ep_group(tensor: torch.Tensor) -> ProcessGroup:
+    """
+    Get the expert parallel group of the given tensor.
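+
+    A minimal usage sketch (editorial illustration, not part of the original
+    commit; it assumes torch.distributed is already initialized with 4 ranks):
+
+        info = get_moe_info(ep_size=2, dp_size=2)
+        set_moe_tensor_info(weight, info)
+        ep_group = get_ep_group(weight)    # group spanning the 2 EP ranks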
+ + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + torch.distributed.ProcessGroup: The expert parallel group of the given tensor. + """ + return tensor.moe_info.ep_group + + +def get_ep_size(tensor: torch.Tensor) -> int: + """ + Get the expert parallel size of the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + int: The expert parallel size of the given tensor. + """ + return tensor.moe_info.ep_size + + +def get_dp_group(tensor: torch.Tensor) -> ProcessGroup: + """ + Get the data parallel group of the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + torch.distributed.ProcessGroup: The data parallel group of the given tensor. + """ + return tensor.moe_info.dp_group + + +def get_ep_rank(tensor: torch.Tensor) -> int: + """ + Get the expert parallel rank of the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + int: The expert parallel rank of the given tensor. + """ + return dist.get_rank(get_ep_group(tensor)) + + +def get_dp_rank(tensor: torch.Tensor) -> int: + """ + Get the data parallel rank of the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + int: The data parallel rank of the given tensor. + """ + return dist.get_rank(get_dp_group(tensor)) diff --git a/colossalai/tensor/moe_tensor/moe_info.py b/colossalai/tensor/moe_tensor/moe_info.py new file mode 100644 index 000000000000..89f79f162b5b --- /dev/null +++ b/colossalai/tensor/moe_tensor/moe_info.py @@ -0,0 +1,15 @@ +from colossalai.cluster import ProcessGroupMesh + + +class MoeParallelInfo: + """Moe parallelism information, storing parallel sizes and groups. + """ + + def __init__(self, ep_size: int, dp_size: int): + self.dp_axis = 0 + self.dp_size = dp_size + self.ep_axis = 1 + self.ep_size = ep_size + self.pg = ProcessGroupMesh(self.dp_size, self.ep_size) + self.ep_group = self.pg.get_group_along_axis(self.ep_axis) + self.dp_group = self.pg.get_group_along_axis(self.dp_axis) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index e6b473adcee6..b037274f922b 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -18,7 +18,7 @@ ) from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger -from colossalai.tensor.moe_tensor.api import is_moe_param +from colossalai.tensor.moe_tensor.api import is_moe_tensor # from colossalai.tensor import ColoParameter, ProcessGroup from colossalai.utils.cuda import get_current_device @@ -140,7 +140,7 @@ def __init__( for param in param_group['params']: if param.requires_grad: # skip moe param - if is_moe_param(param): + if is_moe_tensor(param): moe_params.append(param) continue group_params.append(param) diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py index d86d78886e23..b57567e74be3 100644 --- a/tests/test_moe/moe_utils.py +++ b/tests/test_moe/moe_utils.py @@ -1,3 +1,5 @@ +import torch +import torch.distributed as dist import torch.nn as nn from colossalai.context import MOE_CONTEXT @@ -7,26 +9,24 @@ from colossalai.engine.gradient_handler._base_gradient_handler import BaseGradientHandler from colossalai.engine.gradient_handler.utils import bucket_allreduce from colossalai.nn import CheckpointModule -from colossalai.nn.layer import MoeModule +from colossalai.nn.layer import SparseMLP from colossalai.registry import GRADIENT_HANDLER +from 
colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor from colossalai.utils.moe import get_moe_epsize_param_dict class MoeModel(nn.Module): - def __init__(self, checkpoint: bool = False): + def __init__(self, checkpoint: bool = False, expert_parallel: str = "EP"): class TestSubModule(CheckpointModule): def __init__(self): super().__init__(checkpoint) - expert_cls = nn.Linear - expert_args_dict = dict(in_features=16, out_features=16) - self.moe = MoeModule(dim_model=16, - num_experts=8, - use_residual=True, - expert_cls=expert_cls, - **expert_args_dict) + self.moe = SparseMLP(num_experts=8, + expert_parallel=expert_parallel, + hidden_size=16, + intermediate_size=32) self.proj = nn.Linear(16, 4) def _forward(self, x): @@ -84,3 +84,46 @@ def handle_gradient(self): if ep_size != 1 and ep_size != MOE_CONTEXT.world_size: bucket_allreduce(param_list=epsize_param_dict[ep_size], group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group) + + +def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None: + """Sync the parameters of tp model from ep model + + Args: + tp_model (MoeModule) + ep_model (MoeModule) + """ + for (tp_name, tp_param), (ep_name, ep_param) in zip(tp_model.named_parameters(), ep_model.named_parameters()): + assert tp_name == ep_name + if not is_moe_tensor(tp_param): + if assert_grad_flag: + assert torch.allclose(tp_param, ep_param) + assert torch.allclose(tp_param.grad, ep_param.grad) + else: + tp_param.data.copy_(ep_param.data) + continue + + # gather param from ep model + param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] + dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param)) + all_param = torch.cat(param_list, dim=0) + if assert_grad_flag: + grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] + dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param)) + all_grad = torch.cat(grad_list, dim=0) + + # get tp param + tp_dim = [i for i, (d1, d2) in enumerate(zip(tp_param.shape[1:], all_param.shape[1:])) if d1 != d2] + tp_rank = get_ep_rank(tp_param) + tp_dim = tp_dim[0] + 1 + tp_slice = [slice(None)] * tp_dim + [ + slice(tp_param.shape[tp_dim] * tp_rank, tp_param.shape[tp_dim] * (tp_rank + 1)) + ] + new_tp_param = all_param[tuple(tp_slice)] + if assert_grad_flag: + new_grad = all_grad[tuple(tp_slice)] + if assert_grad_flag: + assert torch.allclose(tp_param, new_tp_param) + assert torch.allclose(tp_param.grad, new_grad) + else: + tp_param.data.copy_(new_tp_param.data) diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py index cff7c116696f..b004a016404d 100644 --- a/tests/test_moe/test_grad_handler.py +++ b/tests/test_moe/test_grad_handler.py @@ -5,7 +5,7 @@ import colossalai from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, UniformNoiseGenerator +from colossalai.nn.layer.moe import EPMLPExperts, MoeLayer, Top1Router, UniformNoiseGenerator from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device from colossalai.utils.moe import sync_moe_model_param @@ -17,8 +17,7 @@ def run_test(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - expert_module = nn.Linear - expert_factor = dict(in_features=DIM, out_features=DIM, device=get_current_device()) + expert_factor 
= dict(hidden_size=DIM, intermediate_size=DIM * 2) MOE_CONTEXT.setup(42) # MOE initialization noisy_func = UniformNoiseGenerator() @@ -26,7 +25,7 @@ def run_test(rank, world_size, port): num_experts_list = [1, 2, 4] layer_list = [] for num_experts in num_experts_list: - exp = Experts(expert_module, num_experts, **expert_factor) + exp = EPMLPExperts(num_experts, **expert_factor) moe_layer = MoeLayer(DIM, num_experts, router, exp) layer_list.append(moe_layer) @@ -35,8 +34,10 @@ def run_test(rank, world_size, port): sync_moe_model_param(model) dist_dict = MOE_CONTEXT.parallel_info_dict - assert_equal_in_group(layer_list[0].experts.experts[0].weight.data, dist_dict[1].dp_group) - assert_equal_in_group(layer_list[1].experts.experts[0].weight.data, dist_dict[2].dp_group) + assert_equal_in_group(layer_list[0].experts.wi.data, dist_dict[1].dp_group) + assert_equal_in_group(layer_list[0].experts.wo.data, dist_dict[1].dp_group) + assert_equal_in_group(layer_list[1].experts.wi.data, dist_dict[2].dp_group) + assert_equal_in_group(layer_list[1].experts.wo.data, dist_dict[2].dp_group) # MoE model synchronization passed grad_handler = MoeGradientHandler(model, 0) @@ -52,11 +53,10 @@ def run_test(rank, world_size, port): data.backward(grad) grad_handler.handle_gradient() - assert_equal_in_group(layer_list[0].experts.experts[0].weight.grad, dist_dict[1].dp_group) - assert_equal_in_group(layer_list[0].experts.experts[0].bias.grad, dist_dict[1].dp_group) - - assert_equal_in_group(layer_list[1].experts.experts[0].weight.grad, dist_dict[2].dp_group) - assert_equal_in_group(layer_list[1].experts.experts[0].bias.grad, dist_dict[2].dp_group) + assert_equal_in_group(layer_list[0].experts.wi.grad, dist_dict[1].dp_group) + assert_equal_in_group(layer_list[0].experts.wo.grad, dist_dict[1].dp_group) + assert_equal_in_group(layer_list[1].experts.wi.grad, dist_dict[2].dp_group) + assert_equal_in_group(layer_list[1].experts.wo.grad, dist_dict[2].dp_group) # MoE grad handler test passed diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index bd0af109fde6..5265e9a1320c 100644 --- a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -1,12 +1,10 @@ import pytest import torch -import torch.nn as nn import colossalai from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.legacy.context import ParallelMode -from colossalai.legacy.core import global_context as gpc -from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, Top2Router +from colossalai.core import global_context as gpc +from colossalai.nn.layer.moe import EPMLPExperts, MoeLayer, Top1Router, Top2Router from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device @@ -32,9 +30,8 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f # get randomized data tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True) - expert_module = nn.Linear - expert_factor = dict(in_features=hidden_size, out_features=hidden_size, device=get_current_device()) - expert = Experts(expert_module, NUM_EXPERTS, **expert_factor) + expert_factor = dict(hidden_size=hidden_size, intermediate_size=hidden_size * 2) + expert = EPMLPExperts(NUM_EXPERTS, **expert_factor) layer = MoeLayer(hidden_size, NUM_EXPERTS, router(capacity_factor_train=1.0), expert) layer = layer.to(get_current_device()) if data_type == torch.float16: diff --git a/tests/test_moe/test_moe_checkpoint.py 
b/tests/test_moe/test_moe_checkpoint.py index f108dc3cd5b1..402346527530 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -6,21 +6,19 @@ import colossalai from colossalai.context import MOE_CONTEXT -from colossalai.nn.layer.moe import load_moe_model, save_moe_model +from colossalai.nn.layer.moe import MoeCheckpintIO from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device -from colossalai.zero import ColoInitContext from tests.test_moe.moe_utils import MoeModel def exam_moe_checkpoint(): - with ColoInitContext(device=get_current_device()): - model = MoeModel(checkpoint=True) - save_moe_model(model, "temp_path.pth") + ckpt = MoeCheckpintIO() + model = MoeModel(checkpoint=True).to(get_current_device()) + ckpt.save_model(model, 'temp_path.pth') - with ColoInitContext(device=get_current_device()): - other_model = MoeModel(checkpoint=True) - load_moe_model(other_model, "temp_path.pth") + other_model = MoeModel(checkpoint=True).to(get_current_device()) + ckpt.load_model(other_model, 'temp_path.pth') state_0 = model.state_dict() state_1 = other_model.state_dict() @@ -42,7 +40,7 @@ def _run_dist(rank, world_size, port): @pytest.mark.parametrize("world_size", [2, 4]) @rerun_if_address_is_in_use() def test_moe_checkpoint(world_size): - spawn(_run_dist) + spawn(_run_dist, world_size) if __name__ == "__main__": diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py new file mode 100644 index 000000000000..13c66cf73e4d --- /dev/null +++ b/tests/test_moe/test_moe_ep_tp.py @@ -0,0 +1,63 @@ +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.nn.layer.moe import SparseMLP +from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device +from colossalai.utils.moe import sync_moe_model_param +from tests.test_moe.moe_utils import MoeGradientHandler, sync_tp_from_ep + +BATCH_SIZE = 4 +DIM = 4 + + +def run_test(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + MOE_CONTEXT.setup(42) # MOE initialization + + ep_model = SparseMLP(num_experts=4, expert_parallel="EP", hidden_size=DIM, intermediate_size=DIM) + tp_model = SparseMLP(num_experts=4, expert_parallel="TP", hidden_size=DIM, intermediate_size=DIM) + ep_model = ep_model.to(get_current_device()) + tp_model = tp_model.to(get_current_device()) + + # sync ep param + sync_moe_model_param(ep_model) + dist_dict = MOE_CONTEXT.parallel_info_dict + assert_equal_in_group(ep_model.experts.wi.data, dist_dict[2].dp_group) + assert_equal_in_group(ep_model.experts.wo.data, dist_dict[2].dp_group) + grad_handler = MoeGradientHandler(ep_model) + # sync tp param + sync_tp_from_ep(tp_model, ep_model) + + rank = dist.get_rank() + torch.cuda.manual_seed(78) + tp_data = torch.randn(BATCH_SIZE, DIM, device=get_current_device()) + ep_data = tp_data.detach()[2 * rank:2 * (rank + 1)] + + out_tp = tp_model(tp_data)[0] + MOE_CONTEXT.reset_loss() + out_ep = ep_model(ep_data)[0] + MOE_CONTEXT.reset_loss() + assert torch.allclose(out_ep, out_tp[2 * rank:2 * (rank + 1)]) + + out_tp.mean().backward() + out_ep.mean().backward() + grad_handler.handle_gradient() + + assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[2].dp_group) + assert_equal_in_group(ep_model.experts.wo.grad, 
dist_dict[2].dp_group) + + sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_moe_ep_tp(): + spawn(run_test, 2) + + +if __name__ == '__main__': + test_moe_ep_tp() diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py index 54005f04fa16..fd87a9a3135d 100644 --- a/tests/test_moe/test_moe_group.py +++ b/tests/test_moe/test_moe_group.py @@ -4,36 +4,37 @@ import colossalai from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.nn.layer.moe import Experts +from colossalai.nn.layer.moe import EPMLPExperts, TPMLPExperts from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device from colossalai.utils.moe import sync_moe_model_param -D_MODEL = 4 -D_FF = 8 +HIDDEN_SIZE = 4 +INTERMEDIATE_SIZE = 8 -def run_test(rank, world_size, port): - world_size = 4 - colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - expert_module = nn.Linear - expert_factor = dict(in_features=D_MODEL, out_features=D_FF, device=get_current_device()) - - MOE_CONTEXT.setup(42) # MOE environment initialization - exp0 = Experts(expert_module, 1, **expert_factor) - exp1 = Experts(expert_module, 2, **expert_factor) - exp2 = Experts(expert_module, 4, **expert_factor) - exp3 = Experts(expert_module, 8, **expert_factor) +def run_moe_init(expert_cls): + expert_args = dict(hidden_size=HIDDEN_SIZE, intermediate_size=INTERMEDIATE_SIZE) + exp0 = expert_cls(1, **expert_args) + exp1 = expert_cls(2, **expert_args) + exp2 = expert_cls(4, **expert_args) + exp3 = expert_cls(8, **expert_args) - assert exp0.num_local_experts == 1 - assert exp1.num_local_experts == 1 - assert exp2.num_local_experts == 1 - assert exp3.num_local_experts == 2 - # experts deployment passed + if expert_cls is EPMLPExperts: + assert exp0.num_local_experts == 1 + assert exp1.num_local_experts == 1 + assert exp2.num_local_experts == 1 + assert exp3.num_local_experts == 2 + else: + assert exp0.num_local_experts == 1 + assert exp1.num_local_experts == 2 + assert exp2.num_local_experts == 4 + assert exp3.num_local_experts == 8 parallel_info_dict = MOE_CONTEXT.parallel_info_dict rank = dist.get_rank() + # group creation assert assert len(parallel_info_dict) == 3 assert dist.get_rank(parallel_info_dict[4].ep_group) == rank assert dist.get_rank(parallel_info_dict[2].ep_group) == rank % 2 @@ -42,26 +43,33 @@ def run_test(rank, world_size, port): assert dist.get_rank(parallel_info_dict[4].dp_group) == 0 assert dist.get_rank(parallel_info_dict[2].dp_group) == rank // 2 assert dist.get_rank(parallel_info_dict[1].dp_group) == rank - # group creation passed model = nn.ModuleList([exp0, exp1, exp2, exp3]) model = model.to(get_current_device()) sync_moe_model_param(model) - assert_equal_in_group(exp0.experts[0].weight.data, parallel_info_dict[1].dp_group) - assert_equal_in_group(exp0.experts[0].bias.data, parallel_info_dict[1].dp_group) # MOE experts layout success when ep_size = 1 + assert_equal_in_group(exp0.wi.data, parallel_info_dict[1].dp_group) + assert_equal_in_group(exp0.wo.data, parallel_info_dict[1].dp_group) - assert_equal_in_group(exp1.experts[0].weight.data, parallel_info_dict[2].dp_group) - assert_equal_in_group(exp1.experts[0].bias.data, parallel_info_dict[2].dp_group) # MOE experts layout success when ep_size = 2 + assert_equal_in_group(exp1.wi.data, parallel_info_dict[2].dp_group) + assert_equal_in_group(exp1.wo.data, 
parallel_info_dict[2].dp_group) + + +def _run_test(rank, world_size, port, expert_cls): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + MOE_CONTEXT.setup(seed=42) + run_moe_init(expert_cls) @pytest.mark.dist +@pytest.mark.parametrize("expert_cls", [EPMLPExperts, TPMLPExperts]) @rerun_if_address_is_in_use() -def test_moe_initialization(): - spawn(run_test, 4) +def test_moe_initialization(expert_cls): + spawn(_run_test, 4, expert_cls=expert_cls) -if __name__ == "__main__": - test_moe_initialization() +if __name__ == '__main__': + test_moe_initialization(EPMLPExperts) + test_moe_initialization(TPMLPExperts) diff --git a/tests/test_moe/test_moe_zero_fwd_bwd.py b/tests/test_moe/test_moe_zero_fwd_bwd.py index e2acb0702f1c..9d19ee830f77 100644 --- a/tests/test_moe/test_moe_zero_fwd_bwd.py +++ b/tests/test_moe/test_moe_zero_fwd_bwd.py @@ -55,7 +55,6 @@ def run_zero_test(local_rank, world_size, stage=1): grad_handler = MoeGradientHandler(torch_model) # assert zero model - assert len(zero_model.module.test_transform.moe.moe_layer.experts.experts) == 8 // MOE_CONTEXT.world_size for (torch_name, torch_param), (zero_name, zero_param) in zip(torch_model.named_parameters(), zero_model.module.named_parameters()): assert zero_name == torch_name From 75fdcc22c718215cf8b7d8e829a4b0326229c366 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com> Date: Mon, 4 Sep 2023 00:18:02 +0800 Subject: [PATCH 04/46] [moe] support local moe and fix bugs (#4574) * add local moe * update moe layer --- colossalai/nn/layer/moe/__init__.py | 2 +- colossalai/nn/layer/moe/_operation.py | 41 ++++++++ colossalai/nn/layer/moe/experts.py | 42 +++++--- colossalai/nn/layer/moe/layers.py | 137 ++++++++------------------ colossalai/nn/layer/moe/routers.py | 4 +- tests/test_moe/moe_utils.py | 34 +++++++ tests/test_moe/test_grad_handler.py | 15 +-- tests/test_moe/test_kernel.py | 23 +++-- tests/test_moe/test_moe_local.py | 63 ++++++++++++ 9 files changed, 230 insertions(+), 131 deletions(-) create mode 100644 tests/test_moe/test_moe_local.py diff --git a/colossalai/nn/layer/moe/__init__.py b/colossalai/nn/layer/moe/__init__.py index ffeeac796441..c20d16181909 100644 --- a/colossalai/nn/layer/moe/__init__.py +++ b/colossalai/nn/layer/moe/__init__.py @@ -5,6 +5,6 @@ from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts __all__ = [ - 'EPMLPExperts', 'TPMLPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'MoeModule', 'NormalNoiseGenerator', + 'EPMLPExperts', 'TPMLPExperts', 'Top1Router', 'Top2Router', 'MoeModule', 'MoeLayer', 'NormalNoiseGenerator', 'UniformNoiseGenerator', 'build_ffn_experts', 'SparseMLP', 'MoeRouter', 'MoeCheckpintIO' ] diff --git a/colossalai/nn/layer/moe/_operation.py b/colossalai/nn/layer/moe/_operation.py index 2f0b7e43673a..01530bb55c20 100644 --- a/colossalai/nn/layer/moe/_operation.py +++ b/colossalai/nn/layer/moe/_operation.py @@ -169,3 +169,44 @@ def moe_cumsum(inputs: Tensor): return moe.cumsum_sub_one(inputs) else: return torch.cumsum(inputs, dim=0) - 1 + + +class MoeInGradScaler(torch.autograd.Function): + """ + Scale the gradient back by the number of experts + because the batch size increases in the moe stage + """ + + @staticmethod + def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor: + if ctx is not None: + ctx.ep_size = ep_size + return inputs + + @staticmethod + def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]: + assert len(grad_outputs) == 1 + 
grad = grad_outputs[0] + if ctx.ep_size != 1: + grad = grad * ctx.ep_size + return grad, None + + +class MoeOutGradScaler(torch.autograd.Function): + """ + Scale the gradient by the number of experts + because the batch size increases in the moe stage + """ + + @staticmethod + def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor: + ctx.ep_size = ep_size + return inputs + + @staticmethod + def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]: + assert len(grad_outputs) == 1 + grad = grad_outputs[0] + if ctx.ep_size != 1: + grad = grad / ctx.ep_size + return grad, None diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py index 7c743b025945..dadc1c606457 100644 --- a/colossalai/nn/layer/moe/experts.py +++ b/colossalai/nn/layer/moe/experts.py @@ -6,6 +6,7 @@ import torch.nn as nn from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.nn.layer.moe._operation import MoeInGradScaler, MoeOutGradScaler from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size, set_moe_tensor_info @@ -19,27 +20,31 @@ def __init__( num_experts: int, hidden_size: int, intermediate_size: int, - expert_parallel: str, + expert_parallel: str = None, activation: str = None, drop_rate: float = 0, ): super().__init__() - assert expert_parallel in ["EP", "TP"] + assert expert_parallel in ["EP", "TP", None] self.expert_parallel = expert_parallel - - # get local and total experts self.num_total_experts = num_experts - self.num_local_experts, self.moe_info = MOE_CONTEXT.get_info(num_experts, - use_tp=True if expert_parallel == "TP" else False) - - # get settings for different parallel - if expert_parallel == "TP": - assert intermediate_size % MOE_CONTEXT.max_ep_size == 0, \ - "intermediate_size should be divide by maximum expert parallel size" - intermediate_size = intermediate_size // MOE_CONTEXT.max_ep_size - num_experts = self.num_total_experts + + # get expert parallel info + if expert_parallel is not None: + self.num_local_experts, self.moe_info = MOE_CONTEXT.get_info( + num_experts, use_tp=True if expert_parallel == "TP" else False) + # get settings for different parallel + if expert_parallel == "TP": + assert intermediate_size % MOE_CONTEXT.max_ep_size == 0, \ + "intermediate_size should be divide by maximum expert parallel size" + intermediate_size = intermediate_size // MOE_CONTEXT.max_ep_size + num_experts = self.num_total_experts + else: + num_experts = self.num_local_experts + self.ep_size = get_ep_size(self) else: - num_experts = self.num_local_experts + self.num_local_experts = self.num_total_experts + self.ep_size = 1 self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) self.wo = nn.Parameter(torch.empty(num_experts, intermediate_size, hidden_size)) @@ -51,10 +56,12 @@ def __init__( self.act = nn.GELU() if activation is None else activation self.drop = nn.Dropout(p=drop_rate) - for param in self.parameters(): - set_moe_tensor_info(param, self.moe_info) + if expert_parallel is not None: + for param in self.parameters(): + set_moe_tensor_info(param, self.moe_info) def forward(self, x: torch.Tensor) -> torch.Tensor: # inputs [g, e, c, h] + x = MoeInGradScaler.apply(x, self.ep_size) e = x.size(1) h = x.size(-1) @@ -71,6 +78,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # inputs [g, e, c, h] x = x.reshape(inshape) x = x.transpose(0, 1).contiguous() + x = MoeOutGradScaler.apply(x, self.ep_size) return x # outputs [g, e, c, h] @@ -134,5 +142,7 @@ def get_expert_class(name: str) -> 
BaseMLPExperts: return TPMLPExperts elif name == "EP": return EPMLPExperts + elif name is None: + return BaseMLPExperts else: raise ValueError(f"Unknown expert class name: {name}") diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py index 254e55eb3316..e7940a9082d2 100644 --- a/colossalai/nn/layer/moe/layers.py +++ b/colossalai/nn/layer/moe/layers.py @@ -21,91 +21,6 @@ from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size -class MoeLayer(nn.Module): - """A MoE layer, that puts its input tensor to its gate and uses the output logits - to router all tokens, is mainly used to exchange all tokens for every expert across - the moe tensor group by all to all communication. Then it will get the output of all - experts and exchange the output. At last returns the output of the moe system. - - Args: - dim_model (int): Dimension of model. - num_experts (int): The number of experts. - router (MoeRouter): Instance of router used in routing. - experts (MoeExperts): Instance of experts generated by Expert. - """ - - def __init__(self, dim_model: int, num_experts: int, router: MoeRouter, experts: BaseMLPExperts): - super().__init__() - self.d_model = dim_model - self.num_experts = num_experts - self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model)) - self.router: MoeRouter = router - self.experts: BaseMLPExperts = experts - self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False - self.ep_group = get_ep_group(experts) - self.ep_size = get_ep_size(experts) - self.num_local_experts = experts.num_local_experts - - nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model)) - - def ep_process(self, dispatch_data: torch.Tensor): - expert_input = AllToAll.apply(dispatch_data, self.ep_group) - input_shape = expert_input.shape - expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model) - expert_output = self.experts(expert_input) - expert_output = expert_output.reshape(input_shape) - expert_output = AllToAll.apply(expert_output, self.ep_group) - return expert_output - - def tp_process(self, dispatch_data: torch.Tensor): - expert_in = AllGather.apply(dispatch_data, self.ep_group) - expert_out = self.experts(expert_in) - expert_out = ReduceScatter.apply(expert_out, self.ep_group) - return expert_out - - def forward(self, inputs: torch.Tensor) -> Tuple: - # reshape the input tokens - tokens = inputs.reshape(-1, self.d_model) - - # the data type of the inputs in the gating should be fp32 - fp32_input = tokens.to(torch.float) - fp32_weight = self.gate_weight.to(torch.float) - gate_output = F.linear(fp32_input, fp32_weight) - - # the result from the router - route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group) - - if self.use_kernel: - dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:]) - dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model) - else: - sec_mask_f = route_result_list[1].type_as(inputs) - dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens) - - # dispatch_data [e, c, h] - if self.experts.expert_parallel == "EP": - expert_output = self.ep_process(dispatch_data) - elif self.experts.expert_parallel == "TP": - expert_output = self.tp_process(dispatch_data) - else: - raise NotImplementedError( - "This kind of communication has not been implemented yet.\n Please use Experts " "build function." 
- ) - # expert_output [e, c, h] - if self.use_kernel: - expert_output = expert_output.reshape(-1, self.d_model) - ans = MoeCombine.apply(expert_output, *route_result_list) - else: - combine_weights = route_result_list[0].type_as(inputs) - combine_weights = combine_weights.view(combine_weights.shape[0], -1) - expert_output = expert_output.view(-1, expert_output.shape[-1]) - ans = torch.matmul(combine_weights, expert_output) - - ans = ans.reshape(inputs.shape) - l_aux = self.router.pop_routing_loss() - return ans, l_aux - - class SparseMLP(nn.Module): """A class for users to create MoE modules in their models. @@ -151,7 +66,8 @@ def __init__(self, self.hidden_size = hidden_size self.num_experts = num_experts self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False - assert expert_parallel in ["EP", "TP"], f"Unsupported expert parallel type {expert_parallel}" + self.expert_parallel = expert_parallel + assert expert_parallel in ["EP", "TP", None], f"Unsupported expert parallel type {expert_parallel}" # moe router noisy_func = get_noise_generator(noisy_policy, num_experts) @@ -168,8 +84,11 @@ def __init__(self, hidden_size=hidden_size, intermediate_size=intermediate_size, activation=activation) - self.ep_group = get_ep_group(self.experts) - self.ep_size = get_ep_size(self.experts) + if expert_parallel is not None: + self.ep_group = get_ep_group(self.experts) + self.ep_size = get_ep_size(self.experts) + else: + self.ep_group = None self.num_local_experts = self.experts.num_local_experts self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size)) @@ -195,10 +114,12 @@ def forward(self, inputs: torch.Tensor) -> Tuple: dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens) # dispatch_data [e, c, h] - if self.experts.expert_parallel == "EP": - expert_output = self.ep_process(dispatch_data) - elif self.experts.expert_parallel == "TP": - expert_output = self.tp_process(dispatch_data) + if self.expert_parallel == "EP": + expert_output = self._ep_process(dispatch_data) + elif self.expert_parallel == "TP": + expert_output = self._tp_process(dispatch_data) + elif self.expert_parallel is None: + expert_output = self._local_process(dispatch_data) else: raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts " "build function.") @@ -216,7 +137,12 @@ def forward(self, inputs: torch.Tensor) -> Tuple: l_aux = self.router.pop_routing_loss() return ans, l_aux - def ep_process(self, dispatch_data: torch.Tensor) -> torch.Tensor: + def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor: + expert_in = expert_in.unsqueeze(0) + expert_out = self.experts(expert_in) + return expert_out + + def _ep_process(self, dispatch_data: torch.Tensor) -> torch.Tensor: expert_input = AllToAll.apply(dispatch_data, self.ep_group) input_shape = expert_input.shape expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size) @@ -225,14 +151,35 @@ def ep_process(self, dispatch_data: torch.Tensor) -> torch.Tensor: expert_output = AllToAll.apply(expert_output, self.ep_group) return expert_output - def tp_process(self, dispatch_data: torch.Tensor) -> torch.Tensor: + def _tp_process(self, dispatch_data: torch.Tensor) -> torch.Tensor: expert_in = AllGather.apply(dispatch_data, self.ep_group) expert_out = self.experts(expert_in) expert_out = ReduceScatter.apply(expert_out, self.ep_group) return expert_out -class MoeModule(nn.Module): +class MoeModule(SparseMLP): + """ + For other 
dependency
+    """
+
+    def __init__(self,
+                 num_experts: int,
+                 top_k: int = 1,
+                 capacity_factor_train: float = 1.25,
+                 capacity_factor_eval: float = 2.0,
+                 min_capacity: int = 4,
+                 noisy_policy: Optional[str] = None,
+                 drop_tks: bool = True,
+                 expert_parallel: str = "EP",
+                 hidden_size: int = 2048,
+                 intermediate_size: int = 2048,
+                 activation: str = None):
+        super().__init__(num_experts, top_k, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_policy,
+                         drop_tks, expert_parallel, hidden_size, intermediate_size, activation)
+
+
+class MoeLayer(SparseMLP):
     """
     For other dependency
     """
diff --git a/colossalai/nn/layer/moe/routers.py b/colossalai/nn/layer/moe/routers.py
index 53fd8fd43e91..962aec9cf1e7 100644
--- a/colossalai/nn/layer/moe/routers.py
+++ b/colossalai/nn/layer/moe/routers.py
@@ -111,7 +111,7 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti
         l_aux = num_experts * torch.sum(me * ce)
         self.set_routing_loss(l_aux)
 
-        if not self.training and not self.drop_tks:
+        if not self.training and not self.drop_tks and ep_group is not None:
             max_num = torch.max(torch.sum(mask, dim=0))
             dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
             capacity = max_num.item()
@@ -190,7 +190,7 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti
         l_aux = num_experts * torch.sum(me * ce) / 2.0    # div 2 to normalize it to 1
         self.set_routing_loss(l_aux)
 
-        if not self.training and not self.drop_tks:
+        if not self.training and not self.drop_tks and ep_group is not None:
             max_num = torch.max(torch.sum(cmask, dim=0))
             dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
             capacity = max_num.item()
diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py
index b57567e74be3..a3f6f86e6fe9 100644
--- a/tests/test_moe/moe_utils.py
+++ b/tests/test_moe/moe_utils.py
@@ -127,3 +127,37 @@ def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag:
             assert torch.allclose(tp_param.grad, new_grad)
         else:
             tp_param.data.copy_(new_tp_param.data)
+
+
+def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
+    """Sync the parameters of the local model from the ep model
+
+    Args:
+        local_model (MoeModule)
+        ep_model (MoeModule)
+    """
+    for (local_name, local_param), (ep_name, ep_param) in zip(local_model.named_parameters(),
+                                                              ep_model.named_parameters()):
+        assert local_name == ep_name
+        if "experts" not in local_name:
+            if assert_grad_flag:
+                assert torch.allclose(local_param, ep_param)
+                assert torch.allclose(local_param.grad, ep_param.grad)
+            else:
+                local_param.data.copy_(ep_param.data)
+            continue
+
+        # gather param from ep model
+        param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
+        dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
+        all_param = torch.cat(param_list, dim=0)
+        if assert_grad_flag:
+            grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
+            dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
+            all_grad = torch.cat(grad_list, dim=0)
+
+        if assert_grad_flag:
+            assert torch.allclose(local_param, all_param)
+            assert torch.allclose(local_param.grad, all_grad)
+        else:
+            local_param.data.copy_(all_param.data)
diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py
index b004a016404d..f09a845afe3d 100644
--- a/tests/test_moe/test_grad_handler.py
+++ b/tests/test_moe/test_grad_handler.py
@@ -5,7 +5,7 @@
 
 import colossalai
 from
colossalai.context.moe_context import MOE_CONTEXT -from colossalai.nn.layer.moe import EPMLPExperts, MoeLayer, Top1Router, UniformNoiseGenerator +from colossalai.nn.layer.moe import SparseMLP from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device from colossalai.utils.moe import sync_moe_model_param @@ -17,16 +17,17 @@ def run_test(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - expert_factor = dict(hidden_size=DIM, intermediate_size=DIM * 2) - MOE_CONTEXT.setup(42) # MOE initialization - noisy_func = UniformNoiseGenerator() - router = Top1Router(noisy_func=noisy_func) + MOE_CONTEXT.setup(42) # MOE initialization num_experts_list = [1, 2, 4] layer_list = [] for num_experts in num_experts_list: - exp = EPMLPExperts(num_experts, **expert_factor) - moe_layer = MoeLayer(DIM, num_experts, router, exp) + moe_layer = SparseMLP(hidden_size=DIM, + intermediate_size=DIM * 4, + num_experts=num_experts, + top_k=1, + expert_parallel="EP", + noisy_policy="Jitter") layer_list.append(moe_layer) model = nn.ModuleList(layer_list) diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index 5265e9a1320c..9a0675bc7b20 100644 --- a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -4,7 +4,7 @@ import colossalai from colossalai.context.moe_context import MOE_CONTEXT from colossalai.core import global_context as gpc -from colossalai.nn.layer.moe import EPMLPExperts, MoeLayer, Top1Router, Top2Router +from colossalai.nn.layer.moe import SparseMLP from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device @@ -16,7 +16,7 @@ def check_equal(tensor_a, tensor_b, atol=1e-06): assert torch.allclose(tensor_a, tensor_b, rtol=0, atol=atol) is True -def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32, router=Top2Router): +def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32, topk=1): # Here we do not need TF32, since it brings absolute error on results torch.backends.cuda.matmul.allow_tf32 = False @@ -30,9 +30,12 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f # get randomized data tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True) - expert_factor = dict(hidden_size=hidden_size, intermediate_size=hidden_size * 2) - expert = EPMLPExperts(NUM_EXPERTS, **expert_factor) - layer = MoeLayer(hidden_size, NUM_EXPERTS, router(capacity_factor_train=1.0), expert) + layer = SparseMLP(hidden_size=hidden_size, + intermediate_size=hidden_size * 2, + num_experts=NUM_EXPERTS, + top_k=topk, + expert_parallel="EP", + capacity_factor_train=1.0) layer = layer.to(get_current_device()) if data_type == torch.float16: layer = layer.half() @@ -82,11 +85,11 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f @pytest.mark.parametrize("rs", [131]) @pytest.mark.parametrize("hidden_size", [32, 144]) @pytest.mark.parametrize("data_type", [torch.float32, torch.float16]) -@pytest.mark.parametrize("router", [Top1Router, Top2Router]) +@pytest.mark.parametrize("topk", [1, 2]) @rerun_if_address_is_in_use() -def test_moe_kernel(rs, hidden_size, data_type, router): - spawn(run_routing, 4, rs=rs, hidden_size=hidden_size, data_type=data_type, router=router) +def test_moe_kernel(rs, hidden_size, data_type, topk): 
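+    # Editorial note (not in the original commit): topk is forwarded to
+    # SparseMLP inside run_routing, which selects Top1Router or Top2Router
+    # via get_router_cls(top_k).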
+ spawn(run_routing, 4, rs=rs, hidden_size=hidden_size, data_type=data_type, topk=topk) -if __name__ == "__main__": - test_moe_kernel(2, 256, torch.float16, Top2Router) +if __name__ == '__main__': + test_moe_kernel(2, 256, torch.float16, 2) diff --git a/tests/test_moe/test_moe_local.py b/tests/test_moe/test_moe_local.py new file mode 100644 index 000000000000..d240ad46ce71 --- /dev/null +++ b/tests/test_moe/test_moe_local.py @@ -0,0 +1,63 @@ +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.nn.layer.moe import SparseMLP +from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device +from colossalai.utils.moe import sync_moe_model_param +from tests.test_moe.moe_utils import MoeGradientHandler, sync_local_from_ep + +BATCH_SIZE = 4 +DIM = 4 + + +def run_test(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + MOE_CONTEXT.setup(42) # MOE initialization + + ep_model = SparseMLP(num_experts=4, expert_parallel="EP", hidden_size=DIM, intermediate_size=DIM) + local_model = SparseMLP(num_experts=4, expert_parallel=None, hidden_size=DIM, intermediate_size=DIM) + ep_model = ep_model.to(get_current_device()) + local_model = local_model.to(get_current_device()) + + # sync ep param + sync_moe_model_param(ep_model) + dist_dict = MOE_CONTEXT.parallel_info_dict + assert_equal_in_group(ep_model.experts.wi.data, dist_dict[2].dp_group) + assert_equal_in_group(ep_model.experts.wo.data, dist_dict[2].dp_group) + grad_handler = MoeGradientHandler(ep_model) + # sync tp param + sync_local_from_ep(local_model, ep_model) + + rank = dist.get_rank() + torch.cuda.manual_seed(78) + tp_data = torch.randn(BATCH_SIZE, DIM, device=get_current_device()) + ep_data = tp_data.detach()[2 * rank:2 * (rank + 1)] + + out_tp = local_model(tp_data)[0] + MOE_CONTEXT.reset_loss() + out_ep = ep_model(ep_data)[0] + MOE_CONTEXT.reset_loss() + assert torch.allclose(out_ep, out_tp[2 * rank:2 * (rank + 1)]) + + out_tp.mean().backward() + out_ep.mean().backward() + grad_handler.handle_gradient() + + assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[2].dp_group) + assert_equal_in_group(ep_model.experts.wo.grad, dist_dict[2].dp_group) + + sync_local_from_ep(local_model, ep_model, assert_grad_flag=True) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_moe_ep_tp(): + spawn(run_test, 2) + + +if __name__ == '__main__': + test_moe_ep_tp() From 61995f86eb71ec5fe3862d526b1147385dacc6b6 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com> Date: Wed, 6 Sep 2023 16:03:31 +0800 Subject: [PATCH 05/46] [moe] support openmoe inference (#4616) * init * update moe ckpt * update config * support openmoe infernece * update config * remove pdb * update ci * update requirement * add build ffn experts * update requirement * update ci * update ci * update require * update ci --- colossalai/nn/layer/moe/__init__.py | 6 +- colossalai/nn/layer/moe/experts.py | 51 +- colossalai/nn/layer/moe/layers.py | 6 +- colossalai/nn/layer/moe/utils.py | 35 +- colossalai/tensor/moe_tensor/__init__.py | 0 examples/language/openmoe/README.md | 17 + examples/language/openmoe/infer.py | 49 + examples/language/openmoe/infer.sh | 1 + .../openmoe/model/convert_openmoe_ckpt.py | 224 ++++ .../openmoe/model/convert_openmoe_ckpt.sh | 1 + 
.../openmoe/model/modeling_openmoe.py | 979 ++++++++++++++++++ .../openmoe/model/openmoe_8b_config.json | 24 + .../openmoe/model/openmoe_base_config.json | 24 + examples/language/openmoe/requirements.txt | 4 + examples/language/openmoe/test_ci.sh | 4 + 15 files changed, 1395 insertions(+), 30 deletions(-) create mode 100644 colossalai/tensor/moe_tensor/__init__.py create mode 100644 examples/language/openmoe/README.md create mode 100644 examples/language/openmoe/infer.py create mode 100644 examples/language/openmoe/infer.sh create mode 100644 examples/language/openmoe/model/convert_openmoe_ckpt.py create mode 100644 examples/language/openmoe/model/convert_openmoe_ckpt.sh create mode 100644 examples/language/openmoe/model/modeling_openmoe.py create mode 100644 examples/language/openmoe/model/openmoe_8b_config.json create mode 100644 examples/language/openmoe/model/openmoe_base_config.json create mode 100644 examples/language/openmoe/requirements.txt create mode 100644 examples/language/openmoe/test_ci.sh diff --git a/colossalai/nn/layer/moe/__init__.py b/colossalai/nn/layer/moe/__init__.py index c20d16181909..52f529814eba 100644 --- a/colossalai/nn/layer/moe/__init__.py +++ b/colossalai/nn/layer/moe/__init__.py @@ -1,10 +1,10 @@ from .checkpoint import MoeCheckpintIO -from .experts import EPMLPExperts, TPMLPExperts +from .experts import EPMLPExperts, TPMLPExperts, build_ffn_experts from .layers import MoeLayer, MoeModule, SparseMLP from .routers import MoeRouter, Top1Router, Top2Router -from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts +from .utils import NormalNoiseGenerator, UniformNoiseGenerator __all__ = [ 'EPMLPExperts', 'TPMLPExperts', 'Top1Router', 'Top2Router', 'MoeModule', 'MoeLayer', 'NormalNoiseGenerator', - 'UniformNoiseGenerator', 'build_ffn_experts', 'SparseMLP', 'MoeRouter', 'MoeCheckpintIO' + 'UniformNoiseGenerator', 'SparseMLP', 'MoeRouter', 'MoeCheckpintIO', 'build_ffn_experts' ] diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py index dadc1c606457..52d5ad72ad7d 100644 --- a/colossalai/nn/layer/moe/experts.py +++ b/colossalai/nn/layer/moe/experts.py @@ -7,6 +7,7 @@ from colossalai.context.moe_context import MOE_CONTEXT from colossalai.nn.layer.moe._operation import MoeInGradScaler, MoeOutGradScaler +from colossalai.nn.layer.moe.utils import get_activation from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size, set_moe_tensor_info @@ -23,11 +24,13 @@ def __init__( expert_parallel: str = None, activation: str = None, drop_rate: float = 0, + gated: bool = False, ): super().__init__() assert expert_parallel in ["EP", "TP", None] self.expert_parallel = expert_parallel self.num_total_experts = num_experts + self.gated = gated # get expert parallel info if expert_parallel is not None: @@ -46,14 +49,19 @@ def __init__( self.num_local_experts = self.num_total_experts self.ep_size = 1 - self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) + if gated: + self.wi_gate = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size * 2)) + self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) + else: + self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) self.wo = nn.Parameter(torch.empty(num_experts, intermediate_size, hidden_size)) - with seed(ParallelMode.TENSOR): - nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size)) - nn.init.trunc_normal_(self.wo, std=math.sqrt(0.1 / 
intermediate_size))
+        if expert_parallel is not None:
+            with seed(ParallelMode.TENSOR):
+                # gated experts keep separate gate/up projections, so initialize whichever set exists
+                if gated:
+                    nn.init.trunc_normal_(self.wi_gate, std=math.sqrt(0.1 / hidden_size))
+                    nn.init.trunc_normal_(self.wi_up, std=math.sqrt(0.1 / hidden_size))
+                else:
+                    nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size))
+                nn.init.trunc_normal_(self.wo, std=math.sqrt(0.1 / intermediate_size))

-        self.act = nn.GELU() if activation is None else activation
+        self.act = get_activation(activation)
         self.drop = nn.Dropout(p=drop_rate)

         if expert_parallel is not None:
@@ -70,10 +78,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # inputs [g, e, c, h]
         inshape = x.shape
         x = x.reshape(e, -1, h)
-        x = torch.bmm(x, self.wi)
-        x = self.act(x)
-        with seed(ParallelMode.TENSOR):
-            x = self.drop(x)
+        if self.gated:
+            x = self.act(torch.bmm(x, self.wi_gate)) * torch.bmm(x, self.wi_up)
+        else:
+            x = torch.bmm(x, self.wi)
+            x = self.act(x)
+
+        if self.expert_parallel is not None:
+            with seed(ParallelMode.TENSOR):
+                x = self.drop(x)

         x = torch.bmm(x, self.wo)
         x = x.reshape(inshape)
@@ -92,8 +105,9 @@ def __init__(self,
                  hidden_size: int,
                  intermediate_size: int,
                  activation=None,
-                 drop_rate: float = 0):
-        super().__init__(num_experts, hidden_size, intermediate_size, "EP", activation, drop_rate)
+                 drop_rate: float = 0,
+                 gated: bool = False):
+        super().__init__(num_experts, hidden_size, intermediate_size, "EP", activation, drop_rate, gated)

     def state_dict(self, destination=None, prefix='', keep_vars=False):
         dp_rank = dist.get_rank(get_dp_group(self))
@@ -133,8 +147,9 @@ def __init__(self,
                  hidden_size: int,
                  intermediate_size: int,
                  activation: str = None,
-                 drop_rate: float = 0):
-        super().__init__(num_experts, hidden_size, intermediate_size, "TP", activation, drop_rate)
+                 drop_rate: float = 0,
+                 gated: bool = False):
+        super().__init__(num_experts, hidden_size, intermediate_size, "TP", activation, drop_rate, gated)


 def get_expert_class(name: str) -> BaseMLPExperts:
@@ -146,3 +161,13 @@ def get_expert_class(name: str) -> BaseMLPExperts:
         return BaseMLPExperts
     else:
         raise ValueError(f"Unknown expert class name: {name}")
+
+
+def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
+    mep_size = MOE_CONTEXT.max_ep_size
+    if num_experts % mep_size == 0 or mep_size % num_experts == 0:
+        return EPMLPExperts(num_experts, d_model, d_ff, activation, drop_rate)
+    elif d_ff % mep_size == 0:
+        return TPMLPExperts(num_experts, d_model, d_ff, activation, drop_rate)
+    else:
+        raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.")
diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py
index e7940a9082d2..8104e33a8bab 100644
--- a/colossalai/nn/layer/moe/layers.py
+++ b/colossalai/nn/layer/moe/layers.py
@@ -61,7 +61,8 @@ def __init__(self,
                  expert_parallel: str = "EP",
                  hidden_size: int = 2048,
                  intermediate_size: int = 2048,
-                 activation: str = None):
+                 activation: str = None,
+                 gated: bool = False):
         super().__init__()
         self.hidden_size = hidden_size
         self.num_experts = num_experts
@@ -83,7 +84,8 @@ def __init__(self,
         self.experts: BaseMLPExperts = expert_cls(num_experts=num_experts,
                                                   hidden_size=hidden_size,
                                                   intermediate_size=intermediate_size,
-                                                  activation=activation)
+                                                  activation=activation,
+                                                  gated=gated)
         if expert_parallel is not None:
             self.ep_group = get_ep_group(self.experts)
             self.ep_size = get_ep_size(self.experts)
diff --git a/colossalai/nn/layer/moe/utils.py b/colossalai/nn/layer/moe/utils.py
index eb3bef70998d..369f6c0752ac 100644
--- a/colossalai/nn/layer/moe/utils.py
+++ b/colossalai/nn/layer/moe/utils.py
@@ -6,8 +6,6 @@
 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils import get_current_device

-from .experts import EPMLPExperts, TPMLPExperts
-

 class ForceFP32Parameter(torch.nn.Parameter):

@@ -60,16 +58,6 @@ def autocast_softmax(logit: torch.Tensor, dim: int):
     return F.softmax(logit, dim=dim, dtype=torch.float32)


-def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
-    mep_size = MOE_CONTEXT.max_ep_size
-    if num_experts % mep_size == 0 or mep_size % num_experts == 0:
-        return EPMLPExperts(num_experts, d_model, d_ff, activation, drop_rate)
-    elif d_ff % mep_size == 0:
-        return TPMLPExperts(num_experts, d_model, d_ff, activation, drop_rate)
-    else:
-        raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.")
-
-
 def get_noise_generator(noise_type: str, num_experts: int) -> Callable:
     if noise_type is None:
         return None
@@ -80,3 +68,26 @@ def get_noise_generator(noise_type: str, num_experts: int) -> Callable:
     else:
         raise NotImplementedError("Unsupported input noisy policy")
     return noisy_func
+
+
+def get_activation(act: str) -> Callable:
+    if act is None or act == 'relu':
+        return torch.nn.ReLU()
+    elif act == 'gelu':
+        return torch.nn.GELU()
+    elif act == 'swiglu':
+        return SwiGLU
+    else:
+        raise NotImplementedError("Unsupported activation function")
+
+
+def SwiGLU(x):
+    """Gated linear unit activation function.
+
+    Args:
+        x: input tensor; its last dimension is split into a value half and a gate half
+    """
+    size = x.shape[-1]
+    assert size % 2 == 0, "the last dimension size must be divisible by 2"
+    x1, x2 = torch.split(x, size // 2, -1)
+    return x1 * (x2 * torch.sigmoid(x2))
diff --git a/colossalai/tensor/moe_tensor/__init__.py b/colossalai/tensor/moe_tensor/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/examples/language/openmoe/README.md b/examples/language/openmoe/README.md
new file mode 100644
index 000000000000..26b5ee73b054
--- /dev/null
+++ b/examples/language/openmoe/README.md
@@ -0,0 +1,17 @@
+## OpenMoE
+[OpenMoE](https://github.com/XueFuzhao/OpenMoE) is a project aimed at igniting the open-source MoE community!
+
+The following example from [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates how to fine-tune and run inference with OpenMoE.
+
+
+## Our Modifications
+
+We reimplement OpenMoE in PyTorch with GPU support.
+
+## Run Inference
+
+By running the following script:
+```bash
+bash infer.sh
+```
+you will run inference with an [OpenMoE](https://github.com/XueFuzhao/OpenMoE) model (the base variant by default; pass `--model 8b` to `infer.py` for OpenMoE-8B/32E).
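For orientation, the sketch below condenses the inference path that `infer.sh` triggers: load the umT5 tokenizer and an OpenMoE checkpoint through `OpenMoeForCausalLM`, then generate from a prompt. It mirrors the calls made in `infer.py` (shown in full in the next file) rather than defining a new entrypoint; the prompt and `max_new_tokens` value here are illustrative placeholders.

```python
# Minimal sketch of the inference flow; mirrors infer.py below.
# The prompt and generation length are illustrative placeholders.
import torch
from model.modeling_openmoe import OpenMoeForCausalLM
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
model = OpenMoeForCausalLM.from_pretrained("hpcaitech/openmoe-base")
model = model.eval().bfloat16().to(torch.cuda.current_device())

input_ids = tokenizer("Hello, OpenMoE!", return_tensors="pt").input_ids
input_ids = input_ids.to(torch.cuda.current_device())
output = model.generate(input_ids, use_cache=True, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```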
diff --git a/examples/language/openmoe/infer.py b/examples/language/openmoe/infer.py
new file mode 100644
index 000000000000..b41fa2f2e4f1
--- /dev/null
+++ b/examples/language/openmoe/infer.py
@@ -0,0 +1,49 @@
+from argparse import ArgumentParser
+
+import torch
+from model.modeling_openmoe import OpenMoeForCausalLM
+from transformers import T5Tokenizer
+from transformers.models.llama import LlamaConfig
+
+
+def parse_args():
+    parser = ArgumentParser()
+    parser.add_argument("--model", default="base", type=str, help="model name", choices=["base", "8b", "test"])
+    return parser.parse_args()
+
+
+def inference(args):
+
+    tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
+    if args.model == "test":
+        config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base")
+        model = OpenMoeForCausalLM(config)
+    else:
+        model = OpenMoeForCausalLM.from_pretrained(f"hpcaitech/openmoe-{args.model}")
+    model = model.eval().bfloat16()
+    model = model.to(torch.cuda.current_device())
+
+    input_str = """```
+y = list(map(int, ['1', 'hello', '2']))
+```
+What error does this program produce?
+ValueError: invalid literal for int() with base 10: 'hello'
+
+```
+sum = 0
+for i in range(100):
+    sum += i
+```
+What is the value of sum immediately after the 10th time line 3 is executed?"""
+
+    # print("model config: ", model.config)
+    input_ids = tokenizer("<pad>" + input_str, return_tensors="pt", add_special_tokens=True)
+    input_ids = input_ids.input_ids.to(torch.cuda.current_device())
+    generation_output = model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=128)
+    out = tokenizer.decode(generation_output[0], skip_special_tokens=False)
+    print(f"output: \n{out}\n")
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    inference(args)
diff --git a/examples/language/openmoe/infer.sh b/examples/language/openmoe/infer.sh
new file mode 100644
index 000000000000..a578203eba84
--- /dev/null
+++ b/examples/language/openmoe/infer.sh
@@ -0,0 +1 @@
+python infer.py --model "base"
diff --git a/examples/language/openmoe/model/convert_openmoe_ckpt.py b/examples/language/openmoe/model/convert_openmoe_ckpt.py
new file mode 100644
index 000000000000..20b1e780d8b3
--- /dev/null
+++ b/examples/language/openmoe/model/convert_openmoe_ckpt.py
@@ -0,0 +1,224 @@
+# coding=utf-8
+# Copyright 2022 Google LLC and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Convert a T5X checkpoint to PyTorch
+
+Steps:
+- Install gsutil according to https://cloud.google.com/storage/docs/gsutil_install
+- Get a T5X checkpoint at https://github.com/google-research/t5x/blob/main/docs/models.md#t5-11-checkpoints Example:
+  `gsutil -m cp -r gs://t5-data/pretrained_models/t5x/t5_1_1_small $HOME/`
+- Create or download a corresponding config for the downloaded model. E.g.
for T5 v1.1 small, you can use
+  https://huggingface.co/google/t5-v1_1-small/blob/main/config.json
+- Convert:
+  ```
+  python3 convert_t5x_checkpoint_to_pytorch.py --t5x_checkpoint_path=$HOME/t5_1_1_small --config_file=config.json\
+      --pytorch_dump_path=$HOME/t5_1_1_small_pt
+  ```
+"""
+
+import argparse
+import collections
+
+import torch
+from flax import traverse_util
+from modeling_openmoe import OpenMoeForCausalLM
+from t5x import checkpoints
+from transformers import LlamaConfig
+from transformers.utils import logging
+
+logging.set_verbosity_info()
+
+
+def t5x_attention_lookup(params, i, prefix, layer_name="attention"):
+    """Returns the KOQV parameters of (self-)attention. Does not transpose."""
+    k = params[f"{prefix}/layers_{i}/{layer_name}/key/kernel"]
+    o = params[f"{prefix}/layers_{i}/{layer_name}/out/kernel"]
+    q = params[f"{prefix}/layers_{i}/{layer_name}/query/kernel"]
+    v = params[f"{prefix}/layers_{i}/{layer_name}/value/kernel"]
+    return k, o, q, v
+
+
+def t5x_mlp_lookup(params, i, prefix, split_mlp_wi=False):
+    """Returns the MLP parameters of a layer. Does not transpose."""
+    if split_mlp_wi:
+        wi_0 = params[f"{prefix}/layers_{i}/mlp/wi_0/kernel"]
+        wi_1 = params[f"{prefix}/layers_{i}/mlp/wi_1/kernel"]
+        wi = (wi_0, wi_1)
+    else:
+        wi = params[f"{prefix}/layers_{i}/mlp/wi/kernel"]
+
+    wo = params[f"{prefix}/layers_{i}/mlp/wo/kernel"]
+    return wi, wo
+
+
+def t5x_extra_mlp_lookup(params, i, prefix, split_mlp_wi=False):
+    """Returns the extra-MLP parameters of a layer. Does not transpose."""
+    if split_mlp_wi:
+        wi_0 = params[f"{prefix}/layers_{i}/extra_mlp/wi_0/kernel"]
+        wi_1 = params[f"{prefix}/layers_{i}/extra_mlp/wi_1/kernel"]
+        wi = (wi_0, wi_1)
+    else:
+        wi = params[f"{prefix}/layers_{i}/extra_mlp/wi/kernel"]
+
+    wo = params[f"{prefix}/layers_{i}/extra_mlp/wo/kernel"]
+    return wi, wo
+
+
+def t5x_experts_lookup(params, i, prefix, split_mlp_wi=False):
+    """Returns the expert MLP parameters of a layer. Does not transpose."""
+    if split_mlp_wi:
+        wi_0 = params[f"{prefix}/layers_{i}/mlp/expert/wi_0/kernel"]
+        wi_1 = params[f"{prefix}/layers_{i}/mlp/expert/wi_1/kernel"]
+        wi = (wi_0, wi_1)
+    else:
+        wi = params[f"{prefix}/layers_{i}/mlp/expert/wi/kernel"]
+
+    wo = params[f"{prefix}/layers_{i}/mlp/expert/wo/kernel"]
+    return wi, wo
+
+
+def t5x_gate_lookup(params, i, prefix, split_mlp_wi=False):
+    """Returns the router gate weight of a layer. Does not transpose."""
+    return params[f"{prefix}/layers_{i}/mlp/router/router_weights/w/kernel"]
+
+
+def t5x_layer_norm_lookup(params, i, prefix, layer_name):
+    """Returns the layer norm param of a layer."""
+    return params[f"{prefix}/layers_{i}/{layer_name}/scale"]
+
+
+def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, moe_interval: int):
+    """Converts the parameters from T5X-Flax to Transformers-PyTorch."""
+    old = traverse_util.flatten_dict(variables["target"])
+    old = {"/".join(k): v for k, v in old.items()}
+
+    # v1.1 models have a gated GeLU with wi_0 and wi_1 instead of wi
+    split_mlp_wi = True
+    print("Split MLP:", split_mlp_wi)
+
+    new = collections.OrderedDict()
+    print(old.keys())
+    for key, value in old.items():
+        print(f"{key}: {value.shape}")
+
+    # Shared embeddings.
+    new["model.embed_tokens.weight"] = old["token_embedder/embedding"]
+
+    # Decoder.
+    for i in range(num_layers):
+        # Block i, layer 0 (Self Attention).
+        layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm")
+        k, o, q, v = t5x_attention_lookup(old, i, "decoder", "self_attention")
+        new[f"model.layers.{i}.input_layernorm.weight"] = layer_norm
+        new[f"model.layers.{i}.self_attn.k_proj.weight"] = k.T
+        new[f"model.layers.{i}.self_attn.o_proj.weight"] = o.T
+        new[f"model.layers.{i}.self_attn.q_proj.weight"] = q.T
+        new[f"model.layers.{i}.self_attn.v_proj.weight"] = v.T
+
+        # Block i, layer 2 (MLP).
+        layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_mlp_layer_norm")
+        new[f"model.layers.{i}.post_attention_layernorm.weight"] = layer_norm
+
+        if (i + 1) % moe_interval == 0:
+            # moe
+            gate = t5x_gate_lookup(old, i, "decoder", split_mlp_wi)
+            new[f"model.layers.{i}.mlp.gate_weight"] = gate.T
+            wi, wo = t5x_experts_lookup(old, i, "decoder", split_mlp_wi)
+            new[f"model.layers.{i}.mlp.experts.wi_gate"] = wi[0]
+            new[f"model.layers.{i}.mlp.experts.wi_up"] = wi[1]
+            new[f"model.layers.{i}.mlp.experts.wo"] = wo
+            # extra
+            layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_extra_mlp_layer_norm")
+            new[f"model.layers.{i}.pre_extra_mlp_layernorm.weight"] = layer_norm
+            wi, wo = t5x_extra_mlp_lookup(old, i, "decoder", split_mlp_wi)
+            new[f"model.layers.{i}.extra_mlp.gate_proj.weight"] = wi[0].T
+            new[f"model.layers.{i}.extra_mlp.up_proj.weight"] = wi[1].T
+            new[f"model.layers.{i}.extra_mlp.down_proj.weight"] = wo.T
+        else:
+            wi, wo = t5x_mlp_lookup(old, i, "decoder", split_mlp_wi)
+            new[f"model.layers.{i}.mlp.gate_proj.weight"] = wi[0].T
+            new[f"model.layers.{i}.mlp.up_proj.weight"] = wi[1].T
+            new[f"model.layers.{i}.mlp.down_proj.weight"] = wo.T
+
+    new["model.norm.weight"] = old["decoder/decoder_norm/scale"]
+
+    # LM Head (only in v1.1 checkpoints, in v1.0 embeddings are used instead)
+    if "decoder/logits_dense/kernel" in old:
+        new["lm_head.weight"] = old["decoder/logits_dense/kernel"].T
+
+    return new
+
+
+def make_state_dict(converted_params):
+    """Prepares a state dict for the PyTorch model."""
+    # Make a state dict with torch tensors.
+    state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()])
+
+    return state_dict
+
+
+def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path):
+    """Replaces the params in model with the T5X converted params."""
+    variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
+    converted = convert_t5x_to_pytorch(variables,
+                                       num_layers=config.num_hidden_layers,
+                                       moe_interval=config.moe_layer_interval)
+    state_dict = make_state_dict(converted)
+    model.load_state_dict(state_dict, strict=True)
+
+
+def convert_t5x_checkpoint_to_pytorch(t5x_checkpoint_path, config_file, pytorch_dump_path):
+    """Loads the config and model, converts the T5X checkpoint, and saves a PyTorch checkpoint."""
+    # Initialise PyTorch model
+    config = LlamaConfig.from_json_file(config_file)
+    print(f"Building PyTorch model from configuration: {config}")
+    # Non-v1.1 checkpoints could also use T5Model, but this works for all.
+    # The v1.0 checkpoints will simply have an LM head that is the word embeddings.
+    model = OpenMoeForCausalLM(config)
+
+    # Load weights from the T5X checkpoint
+    load_t5x_weights_in_t5(model, config, t5x_checkpoint_path)
+
+    # Save the PyTorch model
+    print(f"Save PyTorch model to {pytorch_dump_path}")
+    model.save_pretrained(pytorch_dump_path)
+
+    # Verify that we can load the checkpoint.
+ model.from_pretrained(pytorch_dump_path) + print("Done") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Converts a native T5X checkpoint into a PyTorch checkpoint.") + # Required parameters + parser.add_argument("--t5x_checkpoint_path", + default=None, + type=str, + required=True, + help="Path to the T5X checkpoint.") + parser.add_argument( + "--config_file", + default=None, + type=str, + required=True, + help="The config json file corresponding to the pre-trained T5 model.\nThis specifies the model architecture.", + ) + parser.add_argument("--pytorch_dump_path", + default=None, + type=str, + required=True, + help="Path to the output PyTorch model.") + args = parser.parse_args() + convert_t5x_checkpoint_to_pytorch(args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path) diff --git a/examples/language/openmoe/model/convert_openmoe_ckpt.sh b/examples/language/openmoe/model/convert_openmoe_ckpt.sh new file mode 100644 index 000000000000..c0d53f562e40 --- /dev/null +++ b/examples/language/openmoe/model/convert_openmoe_ckpt.sh @@ -0,0 +1 @@ +python convert_openmoe_ckpt.py --t5x_checkpoint_path /path/to/t5x --config_file /path/to/config --pytorch_dump_path /path/to/save diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py new file mode 100644 index 000000000000..7fdd4cc32c23 --- /dev/null +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -0,0 +1,979 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch OpenMoE model.""" +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.models.llama import LlamaConfig +from transformers.models.t5.modeling_t5 import T5LayerNorm +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) + +from colossalai.nn.layer.moe.layers import SparseMLP + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask(input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0): + """ + Make causal mask used for bi-directional self-attention. 
+ """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +def generate_fixed_pos_embedding(features, length, min_timescale=1.0, max_timescale=10000.0): + """Generate Sin/Cos for Rotary Embeddings. + + Args: + features: an integer + length: an integer + min_timescale: an optional float + max_timescale: an optional float + + Returns: + output_sin: a float32 Tensor with shape [length, features] + output_cos: a float32 Tensor with shape [length, features] + """ + fraction = torch.arange(0, features, 2, dtype=torch.float64).cuda() / features + timescale = min_timescale * (max_timescale / min_timescale)**fraction + rotational_frequency = 1. / timescale + + sinusoid_inp = torch.einsum('i,j->ij', torch.arange(length, dtype=torch.float64).cuda(), rotational_frequency) + + sinusoid_inp = torch.cat([sinusoid_inp, sinusoid_inp], dim=-1) + + return torch.sin(sinusoid_inp).to(torch.bfloat16), torch.cos(sinusoid_inp).to(torch.bfloat16) + + +def apply_rotary_embedding(q, k, cos, sin, decode=False, rotary_index=None): + """Helper function to apply Rotary Embeddings.""" + if len(k.shape) == 3: + # for multi query attention + k = k.unsqueeze(2) + multiquery = True + else: + multiquery = False + + batch, qlen, qheads, d = q.shape + kbatch, klen, kheads, kd = k.shape + assert batch == kbatch, f'{batch} != {kbatch}' + assert d == kd, f'{d} != {kd}' + if decode and qlen == 1 and rotary_index is not None: + qcos = cos[rotary_index + 1, :] + qsin = sin[rotary_index + 1, :] + qcos = qcos.unsqueeze(2).expand(batch, qlen, qheads, d) + qsin = qsin.unsqueeze(2).expand(batch, qlen, qheads, d) + else: + qcos, qsin = cos[:qlen, :], sin[:qlen, :] + qcos = qcos.unsqueeze(0).unsqueeze(2).expand(batch, qlen, qheads, d) + qsin = qsin.unsqueeze(0).unsqueeze(2).expand(batch, qlen, qheads, d) + + kcos, ksin = cos[:klen, :], sin[:klen, :] + kcos = kcos.unsqueeze(0).unsqueeze(2).expand(batch, klen, kheads, d) + ksin = ksin.unsqueeze(0).unsqueeze(2).expand(batch, klen, kheads, d) + + out_q = (q * qcos) + (rotate_half(q) * qsin) + out_k = (k * kcos) + (rotate_half(k) * ksin) + + if multiquery: + out_k = out_k.squeeze(2) + + return out_q, out_k + + +class LlamaRotaryEmbedding(torch.nn.Module): + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.inv_freq = inv_freq + + # Build here to make 
`torch.jit.trace` work. + self._set_cos_sin_cache(seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype()) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) - + (self.scaling_factor - 1))**(self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def SwiGLU(x): + """Gated linear unit activation function. 
+
+    Args:
+        x: input tensor; its last dimension is split into a value half and a gate half
+    """
+    size = x.shape[-1]
+    assert size % 2 == 0, "the last dimension size must be divisible by 2"
+    x1, x2 = torch.split(x, size // 2, -1)
+    return x1 * (x2 * torch.sigmoid(x2))
+
+
+class LlamaMLP(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        self.pretraining_tp = config.pretraining_tp
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = SwiGLU
+
+    def forward(self, x):
+        if self.pretraining_tp > 1:
+            slice = self.intermediate_size // self.pretraining_tp
+            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
+            up_proj_slices = self.up_proj.weight.split(slice, dim=0)
+            down_proj_slices = self.down_proj.weight.split(slice, dim=1)
+
+            gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1)
+            up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1)
+
+            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
+            down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)]
+            down_proj = sum(down_proj)
+        else:
+            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+        return down_proj
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+class LlamaAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: LlamaConfig):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = config.head_dim
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.pretraining_tp = config.pretraining_tp
+        self.max_position_embeddings = config.max_position_embeddings
+
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+        self._init_rope()
+
+    def _init_rope(self):
+        if self.config.rope_scaling is None:
+            self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
+        else:
+            scaling_type = self.config.rope_scaling["type"]
+            scaling_factor = self.config.rope_scaling["factor"]
+            if scaling_type == "linear":
+                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
+                    self.head_dim,
max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp + query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + dim = query_states.shape[-1] + max_length = max(query_states.shape[1], key_states.shape[1]) + sin, cos = generate_fixed_pos_embedding(dim, max_length, max_timescale=1e4) + query_states, key_states = apply_rotary_embedding(query_states, + key_states, + cos, + sin, + decode=True if q_len == 1 else False, + rotary_index=position_ids) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise 
ValueError(f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}") + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}") + attention_mask[:, :, :, 0] = 0 + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) + + if self.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaDecoderLayer(nn.Module): + + def __init__(self, config: LlamaConfig, moe: bool): + super().__init__() + self.hidden_size = config.hidden_size + self.moe = moe + self.self_attn = LlamaAttention(config=config) + self.input_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) + if self.moe: + self.mlp = SparseMLP(num_experts=config.num_experts, + top_k=config.topk, + capacity_factor_train=config.capacity_factor_train, + capacity_factor_eval=config.capacity_factor_eval, + min_capacity=config.min_capacity, + noisy_policy=config.noisy_policy, + drop_tks=config.drop_tks, + expert_parallel=config.expert_parallel, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + activation=config.hidden_act, + gated=config.gated) + self.pre_extra_mlp_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) + self.extra_mlp = LlamaMLP(config) + else: + self.mlp = LlamaMLP(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + if self.moe: + hidden_states = hidden_states[0] + hidden_states = residual + hidden_states + + if self.moe: + residual = hidden_states + hidden_states = self.pre_extra_mlp_layernorm(hidden_states) + hidden_states = self.extra_mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, LlamaModel): + module.gradient_checkpointing = value + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. 
Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([ + LlamaDecoderLayer(config, moe=True if (i + 1) % config.moe_layer_interval == 0 else False) + for i in range(config.num_hidden_layers) + ]) + self.norm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, + tgt_len=input_shape[-1]).to(inputs_embeds.device) + combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + + combined_attention_mask) + + return combined_attention_mask + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange(past_key_values_length, + seq_length + 
past_key_values_length, + dtype=torch.long, + device=device) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # import pdb; pdb.set_trace() + # embed positions + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), + dtype=torch.bool, + device=inputs_embeds.device) + attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds, + past_key_values_length) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class OpenMoeForCausalLM(LlamaPreTrainedModel): + # _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.pretraining_tp = config.pretraining_tp + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + 
self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation(self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + 
**kwargs): + if past_key_values: + input_ids = input_ids[:, -1:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update({ + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + }) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += (tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),) + return reordered_past diff --git a/examples/language/openmoe/model/openmoe_8b_config.json b/examples/language/openmoe/model/openmoe_8b_config.json new file mode 100644 index 000000000000..248697c37d3c --- /dev/null +++ b/examples/language/openmoe/model/openmoe_8b_config.json @@ -0,0 +1,24 @@ +{ + "architectures": [ + "OpenMoeForCausalLM" + ], + "intermediate_size": 8192, + "hidden_size": 2048, + "num_hidden_layers": 24, + "head_dim": 128, + "num_attention_heads": 24, + "dropout_rate": 0.0, + "layer_norm_epsilon": 1e-06, + "vocab_size": 256384, + "hidden_act": "swiglu", + "num_experts": 32, + "topk": 2, + "capacity_factor_train": 1.25, + "capacity_factor_eval": 2.0, + "min_capacity": 4, + "noisy_policy": null, + "drop_tks": true, + "expert_parallel": null, + "gated": true, + "moe_layer_interval": 6 +} diff --git a/examples/language/openmoe/model/openmoe_base_config.json b/examples/language/openmoe/model/openmoe_base_config.json new file mode 100644 index 000000000000..5a7c97bd1916 --- /dev/null +++ b/examples/language/openmoe/model/openmoe_base_config.json @@ -0,0 +1,24 @@ +{ + "architectures": [ + "OpenMoeForCausalLM" + ], + "intermediate_size": 2048, + "hidden_size": 768, + "num_hidden_layers": 12, + "head_dim": 64, + "num_attention_heads": 12, + "dropout_rate": 0.0, + "layer_norm_epsilon": 1e-06, + "vocab_size": 256384, + "hidden_act": "swiglu", + "num_experts": 16, + "topk": 2, + "capacity_factor_train": 1.25, + "capacity_factor_eval": 2.0, + "min_capacity": 4, + "noisy_policy": null, + "drop_tks": true, + "expert_parallel": null, + "gated": true, + "moe_layer_interval": 4 +} diff --git a/examples/language/openmoe/requirements.txt b/examples/language/openmoe/requirements.txt new file mode 100644 index 000000000000..2fb95d9c71d3 --- /dev/null +++ b/examples/language/openmoe/requirements.txt @@ -0,0 +1,4 @@ +colossalai >= 0.1.12 +torch >= 1.8.1 +transformers >= 4.20.0 +sentencepiece diff --git a/examples/language/openmoe/test_ci.sh b/examples/language/openmoe/test_ci.sh new file mode 100644 index 000000000000..349b2eaccd79 --- /dev/null +++ b/examples/language/openmoe/test_ci.sh @@ -0,0 +1,4 @@ +set -xe +pip install -r requirements.txt + +python infer.py --model "test" From bf5348783a175394c47e3889f164aeec840aa9cb Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com> Date: Thu, 7 Sep 2023 16:28:26 +0800 Subject: [PATCH 06/46] [moe] support openmoe train (#4637) * init * 
update moe ckpt * update config * support openmoe infernece * update config * remove pdb * support train * add ckpt download * update ckpt loading * use general ckpt --- colossalai/context/moe_context.py | 7 +- colossalai/nn/layer/moe/checkpoint.py | 40 ++-- colossalai/nn/layer/moe/experts.py | 32 +--- colossalai/nn/layer/moe/utils.py | 28 +++ .../openmoe/model/modeling_openmoe.py | 99 +++++++--- examples/language/openmoe/train.py | 180 ++++++++++++++++++ 6 files changed, 319 insertions(+), 67 deletions(-) create mode 100644 examples/language/openmoe/train.py diff --git a/colossalai/context/moe_context.py b/colossalai/context/moe_context.py index ea74d2c60dd6..f90b4071dee8 100644 --- a/colossalai/context/moe_context.py +++ b/colossalai/context/moe_context.py @@ -21,6 +21,7 @@ def __init__(self): self.max_ep_size = None self.min_dp_size = None self.aux_loss = None + self.parallel = None self.use_kernel_optim = True self.has_setup = False @@ -34,13 +35,14 @@ def parallel_info_dict(self): def is_initialized(self): return self.has_setup - def setup(self, seed: int, use_kernel_optim: bool = True, max_ep_size: int = 8): + def setup(self, seed: int, use_kernel_optim: bool = True, max_ep_size: int = 8, parallel: bool = None): assert not self.is_initialized, "MoE distributed context shouldn't be set up again" assert torch.cuda.is_available(), "MoE requires to enable CUDA first" self.world_size = dist.get_world_size() self.max_ep_size = min(max_ep_size, dist.get_world_size()) self.min_dp_size = self.world_size // self.max_ep_size + self.parallel = parallel # Enabling kernel optimization may raise error in some cases # Users can close kernel optimization manually @@ -106,5 +108,8 @@ def add_loss(self, loss): def get_loss(self): return self.aux_loss + def get_parallel(self): + return self.parallel + MOE_CONTEXT = MoeContext() diff --git a/colossalai/nn/layer/moe/checkpoint.py b/colossalai/nn/layer/moe/checkpoint.py index 34af87bd9d47..3cda5a7f044c 100644 --- a/colossalai/nn/layer/moe/checkpoint.py +++ b/colossalai/nn/layer/moe/checkpoint.py @@ -1,3 +1,4 @@ +from copy import deepcopy from pathlib import Path from typing import Optional @@ -6,32 +7,47 @@ import torch.nn as nn from torch.optim import Optimizer -from colossalai.checkpoint_io import CheckpointIO -from colossalai.tensor.moe_tensor.api import get_ep_group +from colossalai.checkpoint_io import GeneralCheckpointIO +from colossalai.tensor.moe_tensor.api import get_dp_rank, get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor -class MoeCheckpintIO(CheckpointIO): +class MoeCheckpintIO(GeneralCheckpointIO): def __init__(self) -> None: super().__init__() def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool): state_dict = torch.load(checkpoint) - for name, param in model.named_parameters(): + for name, param in state_dict.items(): if '.experts.' 
in name: - ep_rank = dist.get_rank(get_ep_group(param)) - ep_size = dist.get_world_size(get_ep_group(param)) - for rank in range(ep_size): - new_name = name.replace('.experts.', f'.experts.{rank}.') - if rank == ep_rank: - state_dict[name] = state_dict.pop(new_name) - else: - state_dict.pop(new_name) + model_param = dict(model.named_parameters())[name] + if is_moe_tensor(model_param): + ep_rank = get_ep_rank(model_param) + ep_size = get_ep_size(model_param) + expert_num = param.shape[0] // ep_size + assert param.shape[0] % ep_size == 0 + param = param[ep_rank * expert_num:(ep_rank + 1) * expert_num] + state_dict[name] = param model.load_state_dict(state_dict, strict=strict) def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): state_dict = model.state_dict() + for name, param in model.named_parameters(): + if '.experts.' in name and is_moe_tensor(param): + ep_group = get_ep_group(param) + ep_rank = get_ep_rank(param) + ep_size = get_ep_size(param) + dp_rank = get_dp_rank(param) + if dp_rank == 0: + param = param.data.cuda() + all_param = [deepcopy(param) for _ in range(ep_size)] + # gather param from every ep rank + dist.all_gather(all_param, param, group=ep_group) + if ep_rank == 0: + assert dist.get_rank() == 0 + all_param = torch.cat(all_param, dim=0) + state_dict[name] = all_param.cpu() if dist.get_rank() == 0: torch.save(state_dict, checkpoint) dist.barrier() diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py index 52d5ad72ad7d..f9289749d3a1 100644 --- a/colossalai/nn/layer/moe/experts.py +++ b/colossalai/nn/layer/moe/experts.py @@ -58,7 +58,11 @@ def __init__( if expert_parallel is not None: with seed(ParallelMode.TENSOR): - nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size)) + if gated: + nn.init.trunc_normal_(self.wi_gate, std=math.sqrt(0.1 / hidden_size)) + nn.init.trunc_normal_(self.wi_up, std=math.sqrt(0.1 / hidden_size)) + else: + nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size)) nn.init.trunc_normal_(self.wo, std=math.sqrt(0.1 / intermediate_size)) self.act = get_activation(activation) @@ -109,32 +113,6 @@ def __init__(self, gated: bool = False): super().__init__(num_experts, hidden_size, intermediate_size, "EP", activation, drop_rate, gated) - def state_dict(self, destination=None, prefix='', keep_vars=False): - dp_rank = dist.get_rank(get_dp_group(self)) - ep_rank = dist.get_rank(get_ep_group(self)) - ep_size = get_ep_size(self) - # dp rank 0 will save the state dict - if dp_rank == 0: - for name, param in self.named_parameters(): - if param is self: - continue - # create buffer - buffer_module = deepcopy(param) - # gather param from every ep rank - for source_rank in range(ep_size): - current_prefix = f"{prefix}{source_rank}." 
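+    # NOTE: expert weights are now gathered for checkpointing by
+    # MoeCheckpintIO.save_unsharded_model instead of a bespoke state_dict here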
-                if ep_rank == source_rank:
-                    dist.broadcast(param.data, src=source_rank, group=self.moe_info.ep_group)
-                else:
-                    dist.broadcast(buffer_module.data, src=source_rank, group=self.moe_info.ep_group)
-                if ep_rank == 0:
-                    if keep_vars:
-                        destination[current_prefix + name] = buffer_module.cpu()
-                    else:
-                        destination[current_prefix + name] = buffer_module.data.cpu()
-
-        dist.barrier()
-
 
 class TPMLPExperts(BaseMLPExperts):
     """Use tensor parallelism to split each expert evenly, which can deploy experts in
diff --git a/colossalai/nn/layer/moe/utils.py b/colossalai/nn/layer/moe/utils.py
index 369f6c0752ac..5b3542c80595 100644
--- a/colossalai/nn/layer/moe/utils.py
+++ b/colossalai/nn/layer/moe/utils.py
@@ -1,3 +1,4 @@
+import contextlib
 from typing import Callable
 
 import torch
@@ -91,3 +92,30 @@ def SwiGLU(x):
     assert size % 2 == 0, "axis size must be divisible by 2"
     x1, x2 = torch.split(x, size // 2, -1)
     return x1 * (x2 * torch.sigmoid(x2))
+
+
+@contextlib.contextmanager
+def skip_init():
+    """
+    Skip parameter random init by temporarily replacing the in-place
+    initializers in ``torch.nn.init`` with no-ops.
+    """
+
+    def _skip_init(x, *args, **kwargs):
+        return x
+
+    init_fn_names = [
+        "constant_", "uniform_", "normal_", "xavier_uniform_",
+        "xavier_normal_", "kaiming_uniform_", "kaiming_normal_"
+    ]
+    # __enter__: save the original initializers, then patch them out
+    fn_saved = {name: getattr(torch.nn.init, name) for name in init_fn_names}
+    for name in init_fn_names:
+        setattr(torch.nn.init, name, _skip_init)
+
+    try:
+        yield
+    finally:
+        # __exit__: restore the original initializers
+        for name, fn in fn_saved.items():
+            setattr(torch.nn.init, name, fn)
diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py
index 7fdd4cc32c23..a1e028ae6308 100644
--- a/examples/language/openmoe/model/modeling_openmoe.py
+++ b/examples/language/openmoe/model/modeling_openmoe.py
@@ -37,6 +37,7 @@
     replace_return_docstrings,
 )
 
+from colossalai.context import MOE_CONTEXT
 from colossalai.nn.layer.moe.layers import SparseMLP
 
 logger = logging.get_logger(__name__)
@@ -99,11 +100,14 @@ def generate_fixed_pos_embedding(features, length, min_timescale=1.0, max_timesc
 
     sinusoid_inp = torch.cat([sinusoid_inp, sinusoid_inp], dim=-1)
 
-    return torch.sin(sinusoid_inp).to(torch.bfloat16), torch.cos(sinusoid_inp).to(torch.bfloat16)
+    return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
 
 
 def apply_rotary_embedding(q, k, cos, sin, decode=False, rotary_index=None):
     """Helper function to apply Rotary Embeddings."""
+    cos = cos.to(q.dtype)
+    sin = sin.to(q.dtype)
+
     if len(k.shape) == 3:    # for multi query attention
         k = k.unsqueeze(2)
@@ -405,6 +409,8 @@ def forward(
         if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
             raise ValueError(
                 f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}")
+        if self.training:
+            attention_mask = attention_mask.clone().detach()
         attention_mask[:, :, :, 0] = 0
         attn_weights = attn_weights + attention_mask
@@ -442,18 +448,19 @@ def __init__(self, config: LlamaConfig, moe: bool):
         self.input_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
         if self.moe:
-            self.mlp = SparseMLP(num_experts=config.num_experts,
-                                 top_k=config.topk,
-                                 capacity_factor_train=config.capacity_factor_train,
-                                 capacity_factor_eval=config.capacity_factor_eval,
-                                 min_capacity=config.min_capacity,
-                                 noisy_policy=config.noisy_policy,
-                                 drop_tks=config.drop_tks,
-                                 expert_parallel=config.expert_parallel,
-                                 hidden_size=config.hidden_size,
-                                 intermediate_size=config.intermediate_size,
-                                 activation=config.hidden_act,
-                                 gated=config.gated)
+            self.mlp = SparseMLP(
+                num_experts=config.num_experts,
+                top_k=config.topk,
+                capacity_factor_train=config.capacity_factor_train,
+                capacity_factor_eval=config.capacity_factor_eval,
+                min_capacity=config.min_capacity,
+                noisy_policy=config.noisy_policy,
+                drop_tks=config.drop_tks,
+                expert_parallel=MOE_CONTEXT.get_parallel() if MOE_CONTEXT.is_initialized else config.expert_parallel,
+                hidden_size=config.hidden_size,
+                intermediate_size=config.intermediate_size,
+                activation=config.hidden_act,
+                gated=config.gated)
             self.pre_extra_mlp_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
             self.extra_mlp = LlamaMLP(config)
         else:
@@ -860,6 +867,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        chunk_head: Optional[bool] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -910,22 +918,59 @@ def forward(
             lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0)
             logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)]
             logits = torch.cat(logits, dim=-1)
 
         loss = None
-        else:
-            logits = self.lm_head(hidden_states)
-        logits = logits.float()
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
+        # if no labels are given, just run the lm_head forward
+        if labels is None:
+            logits = self.lm_head(hidden_states)
+            logits = logits.float()
+        # the vocab size for openmoe is over 300k, which causes a large activation
+        # memory during training (up to 20 GB for one sequence), so we chunk the
+        # lm_head and checkpoint it to reduce memory
+        else:
+            if chunk_head:
+
+                def create_custom_forward(module):
+
+                    def custom_forward(*inputs):
+                        logits = module(inputs[0])
+                        logits = logits.float()
+                        # Shift so that tokens < n predict n
+                        shift_logits = logits[..., :-1, :].contiguous().float()
+                        shift_labels = inputs[1][..., 1:].contiguous()
+                        # Flatten the tokens
+                        loss_fct = CrossEntropyLoss()
+                        shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                        shift_labels = shift_labels.view(-1)
+                        # Enable model parallelism
+                        shift_labels = shift_labels.to(shift_logits.device)
+                        loss = loss_fct(shift_logits, shift_labels)
+                        return loss
+
+                    return custom_forward
+
+                loss = 0.
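+                # accumulate one sample's loss at a time; the checkpointed lm_head
+                # recomputes that sample's full-vocab logits during backward instead
+                # of keeping logits for the whole batch alive at once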
+ for batch_idx in range(hidden_states.shape[0]): + loss = loss + torch.utils.checkpoint.checkpoint( + create_custom_forward(self.lm_head), + hidden_states[batch_idx, :], + labels[batch_idx, :], + ) + loss = loss / hidden_states.shape[0] + logits = None + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:] diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py new file mode 100644 index 000000000000..407809702436 --- /dev/null +++ b/examples/language/openmoe/train.py @@ -0,0 +1,180 @@ +import os + +import datasets +import torch +import transformers +from huggingface_hub import snapshot_download +from model.modeling_openmoe import OpenMoeForCausalLM +from torch.utils.data import Dataset +from tqdm import tqdm +from transformers import T5Tokenizer, get_linear_schedule_with_warmup +from transformers.models.llama import LlamaConfig + +import colossalai +from colossalai import get_default_parser +from colossalai.booster import Booster +from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.cluster import DistCoordinator +from colossalai.context import MOE_CONTEXT +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.layer.moe import MoeCheckpintIO +from colossalai.nn.layer.moe.utils import skip_init +from colossalai.utils import get_current_device + + +def move_to_cuda(batch, device): + return {k: v.to(device) for k, v in batch.items()} + + +def load_ckpt(repo_name: str, model: OpenMoeForCausalLM): + ckpt_path = snapshot_download(repo_name) + # single ckpt + if os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin")): + ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin") + # shard ckpt + elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")): + ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json") + else: + raise ValueError(f"Invalid checkpoint path: {ckpt_path}") + MoeCheckpintIO().load_model(model, ckpt_path) + + +class RandomDataset(Dataset): + + def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000): + self.num_samples = num_samples + self.max_length = max_length + self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) + self.attention_mask = torch.ones_like(self.input_ids) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + return { + 'input_ids': self.input_ids[idx], + 'attention_mask': self.attention_mask[idx], + 'labels': self.input_ids[idx] + } + + +def parse_args(): + parser = get_default_parser() + parser.add_argument("--model_name_or_path", + type=str, + default="base", + help="Path to pretrained model or model identifier from huggingface.co/models.") + parser.add_argument("--output_path", + type=str, + default="./output_model.bin", + help="The path of your saved model after finetuning.") + parser.add_argument("--num_epoch", type=int, default=10, help="Number of epochs.") + parser.add_argument("--batch_size", + type=int, + default=4, + help="Batch size 
(per dp group) for the training dataloader.") + parser.add_argument("--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.") + parser.add_argument("--warmup_ratio", + type=float, + default=0.1, + help="Ratio of warmup steps against total training steps.") + parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to use.") + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + # Launch ColossalAI + colossalai.launch_from_torch(config={}, seed=args.seed) + coordinator = DistCoordinator() + world_size = coordinator.world_size + + # Set up moe + MOE_CONTEXT.setup(seed=42, parallel="EP") + + # Manage loggers + disable_existing_loggers() + logger = get_dist_logger() + if coordinator.is_master(): + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # Build OpenMoe model + repo_name = "hpcaitech/openmoe-" + args.model_name_or_path + config = LlamaConfig.from_pretrained(repo_name) + with skip_init(): + model = OpenMoeForCausalLM(config) + load_ckpt(repo_name, model) + logger.info(f"Finish init model with config:\n{config}", ranks=[0]) + + # Enable gradient checkpointing + model.gradient_checkpointing_enable() + + # Set plugin + booster_kwargs = {} + plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2) + logger.info(f"Set plugin as {plugin}", ranks=[0]) + + # Prepare tokenizer and dataloader + tokenizer = T5Tokenizer.from_pretrained("google/umt5-small") + dataset = RandomDataset() + dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True) + + # Set optimizer + optimizer = torch.optim.Adam(model.parameters(), + lr=(args.learning_rate * world_size), + weight_decay=args.weight_decay) + + # Set lr scheduler + total_steps = len(dataloader) * args.num_epoch + num_warmup_steps = int(args.warmup_ratio * total_steps) + lr_scheduler = get_linear_schedule_with_warmup(optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=len(dataloader) * args.num_epoch) + + # Set booster + booster = Booster(plugin=plugin, **booster_kwargs) + model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model, + optimizer=optimizer, + dataloader=dataloader, + lr_scheduler=lr_scheduler) + logger.info(f"Finish init booster", ranks=[0]) + + # Start finetuning + logger.info(f"Start finetuning", ranks=[0]) + for epoch in range(args.num_epoch): + model.train() + with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar: + for batch in pbar: + # Forward + optimizer.zero_grad() + batch = move_to_cuda(batch, torch.cuda.current_device()) + + outputs = model(use_cache=False, chunk_head=True, **batch) + loss = outputs['loss'] + + # Backward + booster.backward(loss, optimizer) + optimizer.step() + lr_scheduler.step() + + # Print batch loss + pbar.set_postfix({'loss': loss.item()}) + + # Finish training and evaluate + logger.info(f"Finish finetuning", ranks=[0]) + booster.save_model(model, args.output_path) + logger.info(f"Saving model checkpoint to {args.output_path}", ranks=[0]) + + +if __name__ == "__main__": + main() From 55a81a64c8f0a909f9ee5ebcbedd8fd7b45767bc Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com> Date: 
Fri, 8 Sep 2023 15:32:48 +0800 Subject: [PATCH 07/46] [moe] align train settings and losses (#4655) * init * update moe ckpt * update config * support openmoe infernece * update config * remove pdb * support train * add ckpt download * update ckpt loading * use general ckpt * add loss and optim * update ci * update require --- colossalai/context/moe_context.py | 13 ++- colossalai/nn/layer/moe/layers.py | 3 +- colossalai/nn/layer/moe/routers.py | 49 ++++++---- .../openmoe/model/modeling_openmoe.py | 98 +++++++++++++++---- examples/language/openmoe/requirements.txt | 1 + examples/language/openmoe/test_ci.sh | 1 + examples/language/openmoe/train.py | 57 +++++------ examples/language/openmoe/train.sh | 3 + tests/test_moe/moe_utils.py | 7 +- tests/test_moe/test_grad_handler.py | 2 +- tests/test_moe/test_kernel.py | 4 +- tests/test_moe/test_moe_ep_tp.py | 4 +- tests/test_moe/test_moe_local.py | 4 +- tests/test_moe/test_moe_zero_fwd_bwd.py | 2 +- tests/test_moe/test_moe_zero_optim.py | 2 +- 15 files changed, 159 insertions(+), 91 deletions(-) create mode 100644 examples/language/openmoe/train.sh diff --git a/colossalai/context/moe_context.py b/colossalai/context/moe_context.py index f90b4071dee8..a21eda309f84 100644 --- a/colossalai/context/moe_context.py +++ b/colossalai/context/moe_context.py @@ -20,7 +20,8 @@ def __init__(self): # When we have a maximum expert parallel size, we have a minimum data parallel size naturally self.max_ep_size = None self.min_dp_size = None - self.aux_loss = None + self.router_aux_loss = [] + self.router_z_loss = [] self.parallel = None self.use_kernel_optim = True @@ -100,13 +101,15 @@ def set_kernel_not_use(self): self.use_kernel_optim = False def reset_loss(self): - self.aux_loss = 0 + self.router_aux_loss, self.router_z_loss = [], [] - def add_loss(self, loss): - self.aux_loss += loss + def add_loss(self, aux_loss: float = 0., z_loss: float = 0.): + self.router_aux_loss.append(aux_loss) + self.router_z_loss.append(z_loss) def get_loss(self): - return self.aux_loss + cur_loss = self.router_aux_loss, self.router_z_loss + return cur_loss def get_parallel(self): return self.parallel diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py index 8104e33a8bab..1ea357fa2749 100644 --- a/colossalai/nn/layer/moe/layers.py +++ b/colossalai/nn/layer/moe/layers.py @@ -136,8 +136,7 @@ def forward(self, inputs: torch.Tensor) -> Tuple: ans = torch.matmul(combine_weights, expert_output) ans = ans.reshape(inputs.shape) - l_aux = self.router.pop_routing_loss() - return ans, l_aux + return ans def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor: expert_in = expert_in.unsqueeze(0) diff --git a/colossalai/nn/layer/moe/routers.py b/colossalai/nn/layer/moe/routers.py index 962aec9cf1e7..9332302a096a 100644 --- a/colossalai/nn/layer/moe/routers.py +++ b/colossalai/nn/layer/moe/routers.py @@ -38,7 +38,8 @@ def __init__(self, self.min_capacity = min_capacity self.noisy_func = noisy_func self.drop_tks = drop_tks - self._routing_loss = None + self._aux_loss = None + self._z_loss = None def get_capacity(self, logits_shape): capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval @@ -48,15 +49,26 @@ def get_capacity(self, logits_shape): assert capacity > 0 return capacity - def set_routing_loss(self, aux_loss: torch.Tensor) -> None: - assert self._routing_loss is None - self._routing_loss = aux_loss + def set_aux_loss(self, logits: torch.Tensor, cmask: torch.Tensor, num_experts: int) -> None: + assert self._aux_loss 
is None
+        me = torch.mean(logits, dim=0)
+        ce = torch.mean(cmask.float(), dim=0)
+        aux_loss = num_experts * torch.sum(me * ce)
+        self._aux_loss = aux_loss
+
+    def set_z_loss(self, router_logits: torch.Tensor):
+        assert self._z_loss is None
+        n, _ = router_logits.shape
+        log_z = torch.logsumexp(router_logits, axis=-1)
+        z_loss = log_z**2
+        z_loss = torch.sum(z_loss, dtype=torch.float32) / n
+        self._z_loss = z_loss
 
-    def pop_routing_loss(self) -> torch.Tensor:
-        assert self._routing_loss is not None
-        reservation = self._routing_loss
-        self._routing_loss = None
-        return reservation
+    def pop_router_loss(self) -> None:
+        assert self._aux_loss is not None
+        MOE_CONTEXT.add_loss(self._aux_loss, self._z_loss)
+        self._aux_loss = None
+        self._z_loss = None
 
 
 class Top1Router(MoeRouter):
@@ -105,11 +117,10 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti
         top1_idx = torch.argmax(inputs, dim=-1)
         mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
 
-        # caculate the auxiliary loss
-        me = torch.mean(logits, dim=0)
-        ce = torch.mean(mask.float(), dim=0)
-        l_aux = num_experts * torch.sum(me * ce)
-        self.set_routing_loss(l_aux)
+        # calculate router loss
+        self.set_aux_loss(logits, mask, num_experts)
+        self.set_z_loss(inputs)
+        self.pop_router_loss()
 
         if not self.training and not self.drop_tks and ep_group is not None:
             max_num = torch.max(torch.sum(mask, dim=0))
@@ -183,12 +194,12 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti
         mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
         cmask = (mask1 + mask2)    # loss: [s, e]
+        cmask = cmask.float() / 2.0    # div 2 to normalize it to 1
 
-        # caculate the auxiliary loss
-        me = torch.mean(logits, dim=0)
-        ce = torch.mean(cmask.float(), dim=0)
-        l_aux = num_experts * torch.sum(me * ce) / 2.0    # div 2 to normalize it to 1
-        self.set_routing_loss(l_aux)
+        # calculate router loss
+        self.set_aux_loss(logits, cmask, num_experts)
+        self.set_z_loss(inputs)
+        self.pop_router_loss()
 
         if not self.training and not self.drop_tks and ep_group is not None:
             max_num = torch.max(torch.sum(cmask, dim=0))
diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py
index a1e028ae6308..1ea9d48523c3 100644
--- a/examples/language/openmoe/model/modeling_openmoe.py
+++ b/examples/language/openmoe/model/modeling_openmoe.py
@@ -18,14 +18,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
""" PyTorch OpenMoE model.""" +import math from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss -from transformers.activations import ACT2FN from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.modeling_utils import PreTrainedModel from transformers.models.llama import LlamaConfig @@ -508,8 +507,6 @@ def forward( residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) - if self.moe: - hidden_states = hidden_states[0] hidden_states = residual + hidden_states if self.moe: @@ -742,7 +739,6 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # import pdb; pdb.set_trace() # embed positions if attention_mask is None: attention_mask = torch.ones((batch_size, seq_length_with_past), @@ -894,6 +890,8 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" + # reset moe loss + MOE_CONTEXT.reset_loss() output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = (output_hidden_states @@ -939,24 +937,19 @@ def custom_forward(*inputs): shift_logits = logits[..., :-1, :].contiguous().float() shift_labels = inputs[1][..., 1:].contiguous() # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) + loss = self._calculate_loss(shift_logits, shift_labels) return loss return custom_forward - loss = 0. 
+                aux_loss, z_loss = self._calculate_router_loss()
+                loss = aux_loss + z_loss
                 for batch_idx in range(hidden_states.shape[0]):
                     loss = loss + torch.utils.checkpoint.checkpoint(
                         create_custom_forward(self.lm_head),
-                        hidden_states[batch_idx, :],
-                        labels[batch_idx, :],
+                        hidden_states[batch_idx:batch_idx + 1, :],
+                        labels[batch_idx:batch_idx + 1, :],
                     )
-                loss = loss / hidden_states.shape[0]
                 logits = None
             else:
                 logits = self.lm_head(hidden_states)
@@ -965,12 +958,9 @@ def custom_forward(*inputs):
                 shift_logits = logits[..., :-1, :].contiguous()
                 shift_labels = labels[..., 1:].contiguous()
                 # Flatten the tokens
-                loss_fct = CrossEntropyLoss()
-                shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                shift_labels = shift_labels.view(-1)
-                # Enable model parallelism
-                shift_labels = shift_labels.to(shift_logits.device)
-                loss = loss_fct(shift_logits, shift_labels)
+                aux_loss, z_loss = self._calculate_router_loss()
+                loss = aux_loss + z_loss
+                loss = loss + self._calculate_loss(shift_logits, shift_labels)
 
     if not return_dict:
         output = (logits,) + outputs[1:]
@@ -1022,3 +1012,69 @@ def _reorder_cache(past_key_values, beam_idx):
         reordered_past += (tuple(
             past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),)
         return reordered_past
+
+    def _calculate_router_loss(self):
+        aux_loss, z_loss = MOE_CONTEXT.get_loss()
+        assert len(aux_loss) == len(z_loss) == self.config.num_hidden_layers // self.config.moe_layer_interval
+        aux_loss = self.config.router_aux_loss_factor * sum(aux_loss) / len(aux_loss)
+        z_loss = self.config.router_z_loss_factor * sum(z_loss) / len(z_loss)
+        return aux_loss, z_loss
+
+    def _calculate_loss(self, logits, targets):
+        if len(logits.shape) != len(targets.shape) + 1:
+            raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
+                             (str(logits.shape), str(targets.shape)))
+        vocab_size = logits.shape[-1]
+        confidence = 1.0 - self.config.label_smoothing
+        low_confidence = (1.0 - confidence) / (vocab_size - 1)
+        normalizing_constant = -(confidence * math.log(confidence) +
+                                 (vocab_size - 1) * low_confidence * math.log(low_confidence + 1e-20))
+
+        # build smoothed one-hot targets; full_like must produce floats here,
+        # otherwise the smoothing values would be cast to bool
+        soft_targets = targets[..., None] == \
+            torch.arange(vocab_size, device=targets.device).reshape((1,) * len(targets.shape) + (-1,))
+        soft_targets = torch.where(soft_targets, torch.full_like(soft_targets, confidence, dtype=torch.float32),
+                                   torch.full_like(soft_targets, low_confidence, dtype=torch.float32))
+        soft_targets = soft_targets.to(torch.float32)
+
+        # cross entropy
+        total_loss = ZLossCrossEntropy.apply(logits, soft_targets, self.config.z_loss_factor)
+        total_loss = total_loss - normalizing_constant
+        total_loss = torch.mean(torch.sum(total_loss, dim=-1), dim=0)
+        return total_loss
+
+
+class ZLossCrossEntropy(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, logits, targets, z_loss):
+        max_logit = torch.max(logits, dim=-1, keepdim=True)[0]
+        shifted = logits - max_logit
+        exp_shifted = torch.exp(shifted)
+        sum_exp = torch.sum(exp_shifted, axis=-1, keepdims=True)
+        log_softmax = shifted - torch.log(sum_exp)
+        loss = -torch.sum(targets * log_softmax, axis=-1)
+        # Add auxiliary z-loss term.
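+        # log_z is the log-partition function of the softmax; penalizing its
+        # square keeps the logits small and the softmax numerically stable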
+ log_z = torch.squeeze(torch.log(sum_exp) + max_logit, axis=-1) + total_z_loss = z_loss * torch.square(log_z) + loss += total_z_loss + ctx.z_loss = z_loss + ctx.save_for_backward(logits, targets, exp_shifted, sum_exp, log_softmax, log_z) + return loss + + @staticmethod + def backward(ctx, *grad_outputs): + assert len(grad_outputs) == 1 + g = grad_outputs[0] + z_loss = ctx.z_loss + logits, targets, exp_shifted, sum_exp, log_softmax, log_z = ctx.saved_tensors + # z-loss term adds the (2 * z_loss * log_z) factor. + deriv = ((1 + 2 * z_loss * log_z).unsqueeze(-1) * exp_shifted / sum_exp - targets) + g_logits = g.unsqueeze(-1) * deriv + g_targets = -g.unsqueeze(-1) * log_softmax + + return ( + g_logits.to(logits.dtype), + g_targets.to(targets.dtype), + None, + ) diff --git a/examples/language/openmoe/requirements.txt b/examples/language/openmoe/requirements.txt index 2fb95d9c71d3..935a3f1e4ce0 100644 --- a/examples/language/openmoe/requirements.txt +++ b/examples/language/openmoe/requirements.txt @@ -2,3 +2,4 @@ colossalai >= 0.1.12 torch >= 1.8.1 transformers >= 4.20.0 sentencepiece +datasets diff --git a/examples/language/openmoe/test_ci.sh b/examples/language/openmoe/test_ci.sh index 349b2eaccd79..75eee902c747 100644 --- a/examples/language/openmoe/test_ci.sh +++ b/examples/language/openmoe/test_ci.sh @@ -2,3 +2,4 @@ set -xe pip install -r requirements.txt python infer.py --model "test" +torchrun --standalone --nproc_per_node 2 train.py --model_name "test" --batch_size 1 --num_epoch 20 diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index 407809702436..67dd387a3950 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -7,7 +7,7 @@ from model.modeling_openmoe import OpenMoeForCausalLM from torch.utils.data import Dataset from tqdm import tqdm -from transformers import T5Tokenizer, get_linear_schedule_with_warmup +from transformers import Adafactor, T5Tokenizer from transformers.models.llama import LlamaConfig import colossalai @@ -60,7 +60,7 @@ def __getitem__(self, idx): def parse_args(): parser = get_default_parser() - parser.add_argument("--model_name_or_path", + parser.add_argument("--model_name", type=str, default="base", help="Path to pretrained model or model identifier from huggingface.co/models.") @@ -73,16 +73,16 @@ def parse_args(): type=int, default=4, help="Batch size (per dp group) for the training dataloader.") - parser.add_argument("--learning_rate", - type=float, - default=5e-5, - help="Initial learning rate (after the potential warmup period) to use.") - parser.add_argument("--warmup_ratio", - type=float, - default=0.1, - help="Ratio of warmup steps against total training steps.") - parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to use.") parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + # loss + parser.add_argument("--router_aux_loss_factor", type=float, default=0.01, help="router_aux_loss_factor.") + parser.add_argument("--router_z_loss_factor", type=float, default=0.0001, help="router_z_loss_factor.") + parser.add_argument("--label_smoothing", type=float, default=0.0, help="label_smoothing.") + parser.add_argument("--z_loss_factor", type=float, default=0.0001, help="z_loss_factor.") + # optim + parser.add_argument("--decay_rate", type=float, default=-0.8, help="adafactor optim decay rate.") + parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to use.") + args = parser.parse_args() 
return args @@ -93,7 +93,6 @@ def main(): # Launch ColossalAI colossalai.launch_from_torch(config={}, seed=args.seed) coordinator = DistCoordinator() - world_size = coordinator.world_size # Set up moe MOE_CONTEXT.setup(seed=42, parallel="EP") @@ -109,11 +108,20 @@ def main(): transformers.utils.logging.set_verbosity_error() # Build OpenMoe model - repo_name = "hpcaitech/openmoe-" + args.model_name_or_path - config = LlamaConfig.from_pretrained(repo_name) + repo_name = "hpcaitech/openmoe-" + args.model_name + if args.model_name == "test": + config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base") + config.vocab_size = 32000 + else: + config = LlamaConfig.from_pretrained(repo_name) + setattr(config, "router_aux_loss_factor", args.router_aux_loss_factor) + setattr(config, "router_z_loss_factor", args.router_z_loss_factor) + setattr(config, "label_smoothing", args.label_smoothing) + setattr(config, "z_loss_factor", args.z_loss_factor) with skip_init(): model = OpenMoeForCausalLM(config) - load_ckpt(repo_name, model) + if args.model_name != "test": + load_ckpt(repo_name, model) logger.info(f"Finish init model with config:\n{config}", ranks=[0]) # Enable gradient checkpointing @@ -126,27 +134,15 @@ def main(): # Prepare tokenizer and dataloader tokenizer = T5Tokenizer.from_pretrained("google/umt5-small") - dataset = RandomDataset() + dataset = RandomDataset(num_samples=1000 if args.model_name != "test" else 1) dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True) # Set optimizer - optimizer = torch.optim.Adam(model.parameters(), - lr=(args.learning_rate * world_size), - weight_decay=args.weight_decay) - - # Set lr scheduler - total_steps = len(dataloader) * args.num_epoch - num_warmup_steps = int(args.warmup_ratio * total_steps) - lr_scheduler = get_linear_schedule_with_warmup(optimizer, - num_warmup_steps=num_warmup_steps, - num_training_steps=len(dataloader) * args.num_epoch) + optimizer = Adafactor(model.parameters(), decay_rate=args.decay_rate, weight_decay=args.weight_decay) # Set booster booster = Booster(plugin=plugin, **booster_kwargs) - model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model, - optimizer=optimizer, - dataloader=dataloader, - lr_scheduler=lr_scheduler) + model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader) logger.info(f"Finish init booster", ranks=[0]) # Start finetuning @@ -165,7 +161,6 @@ def main(): # Backward booster.backward(loss, optimizer) optimizer.step() - lr_scheduler.step() # Print batch loss pbar.set_postfix({'loss': loss.item()}) diff --git a/examples/language/openmoe/train.sh b/examples/language/openmoe/train.sh new file mode 100644 index 000000000000..9a55779ca5ef --- /dev/null +++ b/examples/language/openmoe/train.sh @@ -0,0 +1,3 @@ +torchrun --standalone --nproc_per_node 2 train.py \ + --model_name "base" \ + --batch_size 4 diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py index a3f6f86e6fe9..7f9bfe376632 100644 --- a/tests/test_moe/moe_utils.py +++ b/tests/test_moe/moe_utils.py @@ -30,9 +30,9 @@ def __init__(self): self.proj = nn.Linear(16, 4) def _forward(self, x): - x, y = self.moe(x) + x = self.moe(x) x = self.proj(x) - return x, y + return x super().__init__() self.test_embed = nn.Linear(4, 16) @@ -42,9 +42,8 @@ def forward(self, x): MOE_CONTEXT.reset_loss() x = self.test_embed(x) - x, y = self.test_transform(x) + x = self.test_transform(x) - MOE_CONTEXT.add_loss(y) return x diff --git 
a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py index f09a845afe3d..e36555fa2eb9 100644 --- a/tests/test_moe/test_grad_handler.py +++ b/tests/test_moe/test_grad_handler.py @@ -50,7 +50,7 @@ def run_test(rank, world_size, port): MOE_CONTEXT.reset_loss() for layer in layer_list: - data, _ = layer(data) + data = layer(data) data.backward(grad) grad_handler.handle_gradient() diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index 9a0675bc7b20..ce663136600e 100644 --- a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -42,7 +42,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f # use matrix multiplication instead of COL_MOE_KERNEL in MOE dispatch and combine layer.use_kernel = False - old_out, _ = layer(tokens) + old_out = layer(tokens) ech = old_out.shape grad = torch.randn(ech, device=get_current_device()) old_out.backward(grad) # get gradient @@ -56,7 +56,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f layer.gate_weight.grad.zero_() layer.use_kernel = True - new_out, _ = layer(tokens) # get outputs through colossal kernel + new_out = layer(tokens) # get outputs through colossal kernel if data_type == torch.float32: check_equal(old_out, new_out) diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py index 13c66cf73e4d..cb261912e0f6 100644 --- a/tests/test_moe/test_moe_ep_tp.py +++ b/tests/test_moe/test_moe_ep_tp.py @@ -37,9 +37,9 @@ def run_test(rank, world_size, port): tp_data = torch.randn(BATCH_SIZE, DIM, device=get_current_device()) ep_data = tp_data.detach()[2 * rank:2 * (rank + 1)] - out_tp = tp_model(tp_data)[0] + out_tp = tp_model(tp_data) MOE_CONTEXT.reset_loss() - out_ep = ep_model(ep_data)[0] + out_ep = ep_model(ep_data) MOE_CONTEXT.reset_loss() assert torch.allclose(out_ep, out_tp[2 * rank:2 * (rank + 1)]) diff --git a/tests/test_moe/test_moe_local.py b/tests/test_moe/test_moe_local.py index d240ad46ce71..e41a0d821a10 100644 --- a/tests/test_moe/test_moe_local.py +++ b/tests/test_moe/test_moe_local.py @@ -37,9 +37,9 @@ def run_test(rank, world_size, port): tp_data = torch.randn(BATCH_SIZE, DIM, device=get_current_device()) ep_data = tp_data.detach()[2 * rank:2 * (rank + 1)] - out_tp = local_model(tp_data)[0] + out_tp = local_model(tp_data) MOE_CONTEXT.reset_loss() - out_ep = ep_model(ep_data)[0] + out_ep = ep_model(ep_data) MOE_CONTEXT.reset_loss() assert torch.allclose(out_ep, out_tp[2 * rank:2 * (rank + 1)]) diff --git a/tests/test_moe/test_moe_zero_fwd_bwd.py b/tests/test_moe/test_moe_zero_fwd_bwd.py index 9d19ee830f77..f1f888203746 100644 --- a/tests/test_moe/test_moe_zero_fwd_bwd.py +++ b/tests/test_moe/test_moe_zero_fwd_bwd.py @@ -40,7 +40,7 @@ def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False) def run_zero_test(local_rank, world_size, stage=1): - criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss) + criterion = torch.nn.CrossEntropyLoss() zero_model = MoeModel(checkpoint=True) optimizer = torch.optim.Adam(zero_model.parameters()) diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py index fcb6f95d1319..229ee528b4fc 100644 --- a/tests/test_moe/test_moe_zero_optim.py +++ b/tests/test_moe/test_moe_zero_optim.py @@ -39,7 +39,7 @@ def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False) def run_zero_optim_test(local_rank, world_size, stage=1): - criterion = MoeLoss(aux_weight=0.01, 
loss_fn=torch.nn.CrossEntropyLoss) + criterion = torch.nn.CrossEntropyLoss() zero_model = MoeModel(checkpoint=True) zero_optimizer = torch.optim.Adam(zero_model.parameters()) From 84f05b1c65a96f98606fe68ee3155ce54e63bd64 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com> Date: Mon, 11 Sep 2023 13:47:33 +0800 Subject: [PATCH 08/46] [moe] move to moe and remove legacy (#4672) * init * update moe ckpt * update config * support openmoe infernece * update config * remove pdb * support train * add ckpt download * update ckpt loading * use general ckpt * add loss and optim * update ci * update require * move * move * remove legacy * update file name and restore moe context * update module * update build_ffn_experts * update init * add ctx --- colossalai/context/__init__.py | 2 - colossalai/context/moe_context.py | 66 ++++++---- colossalai/legacy/context/random/__init__.py | 14 +-- colossalai/legacy/context/random/_helper.py | 10 -- colossalai/moe/__init__.py | 10 ++ colossalai/{nn/layer => }/moe/_operation.py | 0 colossalai/{nn/layer => }/moe/checkpoint.py | 0 colossalai/{nn/layer => }/moe/experts.py | 43 +++---- colossalai/{nn/layer => }/moe/layers.py | 62 +--------- .../{nn/loss/loss_moe.py => moe/loss.py} | 9 +- colossalai/moe/manager.py | 115 ++++++++++++++++++ colossalai/{nn/layer => }/moe/routers.py | 6 +- colossalai/{nn/layer => }/moe/utils.py | 49 +++++++- colossalai/nn/layer/moe/__init__.py | 22 ++-- colossalai/nn/loss/__init__.py | 1 - colossalai/tensor/moe_tensor/api.py | 26 ++++ colossalai/tensor/moe_tensor/moe_info.py | 2 + colossalai/utils/moe.py | 53 -------- .../openmoe/model/modeling_openmoe.py | 10 +- examples/language/openmoe/train.py | 8 +- tests/test_moe/moe_utils.py | 35 ++++-- tests/test_moe/test_grad_handler.py | 27 ++-- tests/test_moe/test_kernel.py | 16 +-- tests/test_moe/test_moe_checkpoint.py | 6 +- tests/test_moe/test_moe_ep_tp.py | 14 +-- tests/test_moe/test_moe_group.py | 10 +- tests/test_moe/test_moe_local.py | 21 ++-- tests/test_moe/test_moe_zero_fwd_bwd.py | 5 +- tests/test_moe/test_moe_zero_optim.py | 5 +- 29 files changed, 376 insertions(+), 271 deletions(-) create mode 100644 colossalai/moe/__init__.py rename colossalai/{nn/layer => }/moe/_operation.py (100%) rename colossalai/{nn/layer => }/moe/checkpoint.py (100%) rename colossalai/{nn/layer => }/moe/experts.py (79%) rename colossalai/{nn/layer => }/moe/layers.py (75%) rename colossalai/{nn/loss/loss_moe.py => moe/loss.py} (92%) create mode 100644 colossalai/moe/manager.py rename colossalai/{nn/layer => }/moe/routers.py (98%) rename colossalai/{nn/layer => }/moe/utils.py (66%) delete mode 100644 colossalai/utils/moe.py diff --git a/colossalai/context/__init__.py b/colossalai/context/__init__.py index ab57301bb910..3e94b7cfe993 100644 --- a/colossalai/context/__init__.py +++ b/colossalai/context/__init__.py @@ -1,7 +1,5 @@ from .config import Config, ConfigException -# from .moe_context import MOE_CONTEXT - __all__ = [ "Config", "ConfigException", diff --git a/colossalai/context/moe_context.py b/colossalai/context/moe_context.py index a21eda309f84..510b05278c56 100644 --- a/colossalai/context/moe_context.py +++ b/colossalai/context/moe_context.py @@ -3,9 +3,29 @@ import torch import torch.distributed as dist +from colossalai.context.parallel_mode import ParallelMode from colossalai.context.singleton_meta import SingletonMeta -from colossalai.tensor.moe_tensor.api import get_moe_info -from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo +from colossalai.tensor 
import ProcessGroup + + +def _check_sanity(): + from colossalai.core import global_context as gpc + if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1: + raise NotImplementedError("Moe is not compatible with tensor or " + "pipeline parallel at present.") + + +class MoeParallelInfo: + """Moe parallelism information, storing parallel sizes and groups. + """ + + def __init__(self, ep_size: int, dp_size: int): + _check_sanity() + self.ep_size = ep_size + self.dp_size = dp_size + self.pg = ProcessGroup(tp_degree=ep_size, dp_degree=dp_size) + self.ep_group = self.pg.tp_process_group() + self.dp_group = self.pg.dp_process_group() class MoeContext(metaclass=SingletonMeta): @@ -14,15 +34,13 @@ class MoeContext(metaclass=SingletonMeta): """ def __init__(self): - self.world_size = None + self.world_size = 1 # Users may want to set maximum expert parallel size smaller than the world size # since very low bandwidth across nodes may constrain the performance of MoE # When we have a maximum expert parallel size, we have a minimum data parallel size naturally - self.max_ep_size = None - self.min_dp_size = None - self.router_aux_loss = [] - self.router_z_loss = [] - self.parallel = None + self.max_ep_size = 1 + self.min_dp_size = 1 + self.aux_loss = None self.use_kernel_optim = True self.has_setup = False @@ -36,14 +54,18 @@ def parallel_info_dict(self): def is_initialized(self): return self.has_setup - def setup(self, seed: int, use_kernel_optim: bool = True, max_ep_size: int = 8, parallel: bool = None): + def setup(self, seed: int, use_kernel_optim: bool = True): assert not self.is_initialized, "MoE distributed context shouldn't be set up again" + _check_sanity() assert torch.cuda.is_available(), "MoE requires to enable CUDA first" self.world_size = dist.get_world_size() - self.max_ep_size = min(max_ep_size, dist.get_world_size()) + + from colossalai.core import global_context as gpc + self.max_ep_size = gpc.config.get('max_ep_size', self.world_size) + assert self.world_size % self.max_ep_size == 0, \ + "Maximum expert parallel size must be a factor of the number of GPUs" self.min_dp_size = self.world_size // self.max_ep_size - self.parallel = parallel # Enabling kernel optimization may raise error in some cases # Users can close kernel optimization manually @@ -54,7 +76,7 @@ def setup(self, seed: int, use_kernel_optim: bool = True, max_ep_size: int = 8, moe_set_seed(seed) self.has_setup = True - def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoeParallelInfo]: + def get_info(self, num_experts: int) -> Tuple[int, MoeParallelInfo]: """Calculate the Data Parallel Group and Expert Parallel Group. 
Parameters @@ -85,15 +107,12 @@ def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoePara ep_size = self.max_ep_size // dp_size # Calculate the number of experts for each GPU - if use_tp: - num_local_experts = num_experts - else: - num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size + num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size # Don't forget to multiply minimum data parallel size dp_size *= self.min_dp_size if not (ep_size in self.parallel_info_dict): - self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size) + self.parallel_info_dict[ep_size] = MoeParallelInfo(ep_size, dp_size) return num_local_experts, self.parallel_info_dict[ep_size] @@ -101,18 +120,13 @@ def set_kernel_not_use(self): self.use_kernel_optim = False def reset_loss(self): - self.router_aux_loss, self.router_z_loss = [], [] + self.aux_loss = 0 - def add_loss(self, aux_loss: float = 0., z_loss: float = 0.): - self.router_aux_loss.append(aux_loss) - self.router_z_loss.append(z_loss) + def add_loss(self, loss): + self.aux_loss += loss def get_loss(self): - cur_loss = self.router_aux_loss, self.router_z_loss - return cur_loss - - def get_parallel(self): - return self.parallel + return self.aux_loss MOE_CONTEXT = MoeContext() diff --git a/colossalai/legacy/context/random/__init__.py b/colossalai/legacy/context/random/__init__.py index 5e8d82922ddc..e2314f859d3f 100644 --- a/colossalai/legacy/context/random/__init__.py +++ b/colossalai/legacy/context/random/__init__.py @@ -3,7 +3,6 @@ get_current_mode, get_seeds, get_states, - moe_set_seed, reset_seeds, seed, set_mode, @@ -13,15 +12,6 @@ ) __all__ = [ - "seed", - "set_mode", - "with_seed", - "add_seed", - "get_seeds", - "get_states", - "get_current_mode", - "set_seed_states", - "sync_states", - "moe_set_seed", - "reset_seeds", + 'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds', 'get_states', 'get_current_mode', 'set_seed_states', + 'sync_states', 'reset_seeds' ] diff --git a/colossalai/legacy/context/random/_helper.py b/colossalai/legacy/context/random/_helper.py index be1d951d1229..7d27b3f85db9 100644 --- a/colossalai/legacy/context/random/_helper.py +++ b/colossalai/legacy/context/random/_helper.py @@ -159,15 +159,5 @@ def wrapper(*args, **kwargs): return wrapper -def moe_set_seed(seed): - if torch.cuda.is_available(): - from colossalai.legacy.core import global_context as gpc - - global_rank = gpc.get_global_rank() - diff_seed = seed + global_rank - add_seed(ParallelMode.TENSOR, diff_seed, True) - print(f"moe seed condition: {global_rank} with tensor seed {diff_seed}", flush=True) - - def reset_seeds(): _SEED_MANAGER.reset() diff --git a/colossalai/moe/__init__.py b/colossalai/moe/__init__.py new file mode 100644 index 000000000000..492cdaf13d1d --- /dev/null +++ b/colossalai/moe/__init__.py @@ -0,0 +1,10 @@ +from .checkpoint import MoeCheckpintIO +from .experts import EPMLPExperts, TPMLPExperts, build_ffn_experts +from .layers import SparseMLP +from .routers import MoeRouter, Top1Router, Top2Router +from .utils import NormalNoiseGenerator, UniformNoiseGenerator + +__all__ = [ + 'EPMLPExperts', 'TPMLPExperts', 'Top1Router', 'Top2Router', 'NormalNoiseGenerator', 'UniformNoiseGenerator', + 'SparseMLP', 'MoeRouter', 'MoeCheckpintIO', 'build_ffn_experts' +] diff --git a/colossalai/nn/layer/moe/_operation.py b/colossalai/moe/_operation.py similarity index 100% rename from colossalai/nn/layer/moe/_operation.py rename to colossalai/moe/_operation.py diff --git a/colossalai/nn/layer/moe/checkpoint.py 
b/colossalai/moe/checkpoint.py
similarity index 100%
rename from colossalai/nn/layer/moe/checkpoint.py
rename to colossalai/moe/checkpoint.py
diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/moe/experts.py
similarity index 79%
rename from colossalai/nn/layer/moe/experts.py
rename to colossalai/moe/experts.py
index f9289749d3a1..da4fe58977e8 100644
--- a/colossalai/nn/layer/moe/experts.py
+++ b/colossalai/moe/experts.py
@@ -1,14 +1,14 @@
 import math
-from copy import deepcopy
+from contextlib import nullcontext

 import torch
-import torch.distributed as dist
 import torch.nn as nn

-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.nn.layer.moe._operation import MoeInGradScaler, MoeOutGradScaler
-from colossalai.nn.layer.moe.utils import get_activation
-from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size, set_moe_tensor_info
+from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.moe.utils import get_activation
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.moe_tensor.api import get_ep_size, set_moe_tensor_info


 class BaseMLPExperts(nn.Module):
@@ -34,13 +34,13 @@ def __init__(
         # get expert parallel info
         if expert_parallel is not None:
-            self.num_local_experts, self.moe_info = MOE_CONTEXT.get_info(
+            self.num_local_experts, self.moe_info = MOE_MANAGER.get_info(
                 num_experts, use_tp=True if expert_parallel == "TP" else False)
             # get settings for different parallel
             if expert_parallel == "TP":
-                assert intermediate_size % MOE_CONTEXT.max_ep_size == 0, \
+                assert intermediate_size % MOE_MANAGER.max_ep_size == 0, \
                     "intermediate_size should be divisible by the maximum expert parallel size"
-                intermediate_size = intermediate_size // MOE_CONTEXT.max_ep_size
+                intermediate_size = intermediate_size // MOE_MANAGER.max_ep_size
                 num_experts = self.num_total_experts
             else:
                 num_experts = self.num_local_experts
@@ -56,14 +56,18 @@ def __init__(
             self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
             self.wo = nn.Parameter(torch.empty(num_experts, intermediate_size, hidden_size))

+        # expert params should be initialized differently on each expert-parallel rank
         if expert_parallel is not None:
-            with seed(ParallelMode.TENSOR):
-                if gated:
-                    nn.init.trunc_normal_(self.wi_gate, std=math.sqrt(0.1 / hidden_size))
-                    nn.init.trunc_normal_(self.wi_up, std=math.sqrt(0.1 / hidden_size))
-                else:
-                    nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size))
-                nn.init.trunc_normal_(self.wo, std=math.sqrt(0.1 / intermediate_size))
+            seed_ctx = Randomizer(MOE_MANAGER.seed).fork_rng(enable_cpu=True)
+        else:
+            seed_ctx = nullcontext()
+        with seed_ctx:
+            if gated:
+                nn.init.trunc_normal_(self.wi_gate, std=math.sqrt(0.1 / hidden_size))
+                nn.init.trunc_normal_(self.wi_up, std=math.sqrt(0.1 / hidden_size))
+            else:
+                nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size))
+            nn.init.trunc_normal_(self.wo, std=math.sqrt(0.1 / intermediate_size))

         self.act = get_activation(activation)
         self.drop = nn.Dropout(p=drop_rate)
@@ -87,10 +91,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:    # inputs [g, e, c, h]
         else:
             x = torch.bmm(x, self.wi)
         x = self.act(x)
-
-        if self.expert_parallel is not None:
-            with seed(ParallelMode.TENSOR):
-                x = self.drop(x)
+        x = self.drop(x)

         x = torch.bmm(x, self.wo)
         x = x.reshape(inshape)
@@ -142,7 +143,7 @@ def get_expert_class(name: str) -> BaseMLPExperts:


 def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
-    mep_size = MOE_CONTEXT.max_ep_size
+    mep_size = MOE_MANAGER.max_ep_size
     if num_experts % mep_size == 0 or mep_size % num_experts == 0:
         return EPMLPExperts(num_experts, d_model, d_ff, activation, drop_rate)
     elif d_ff % mep_size == 0:
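The rewritten initialization above is the interesting part of this hunk: when experts are parallelized, the RNG is forked under a rank-dependent seed (MOE_MANAGER.seed is seed + rank, set in manager.py further down) so every rank draws different expert weights, while the non-parallel path uses a nullcontext and keeps all ranks identical. A minimal sketch of that control flow, using torch.random.fork_rng in place of the Randomizer helper; the function name and arguments here are illustrative, not ColossalAI API:

    import math
    from contextlib import nullcontext
    from typing import Optional

    import torch
    import torch.nn as nn

    def init_expert_weight(w: torch.Tensor, hidden_size: int,
                           rank_seed: Optional[int] = None) -> None:
        # fork the RNG only when experts are parallelized across ranks,
        # so the surrounding RNG state is left untouched either way
        ctx = torch.random.fork_rng() if rank_seed is not None else nullcontext()
        with ctx:
            if rank_seed is not None:
                torch.manual_seed(rank_seed)    # rank-dependent seed -> diverged init
            nn.init.trunc_normal_(w, std=math.sqrt(0.1 / hidden_size))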
diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/moe/layers.py
similarity index 75%
rename from colossalai/nn/layer/moe/layers.py
rename to colossalai/moe/layers.py
index 1ea357fa2749..ace81b543273 100644
--- a/colossalai/nn/layer/moe/layers.py
+++ b/colossalai/moe/layers.py
@@ -5,19 +5,11 @@
 import torch.nn as nn
 import torch.nn.functional as F

-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.legacy.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator
-from colossalai.nn.layer.moe._operation import (
-    COL_MOE_KERNEL_FLAG,
-    AllGather,
-    AllToAll,
-    MoeCombine,
-    MoeDispatch,
-    ReduceScatter,
-)
-from colossalai.nn.layer.moe.experts import BaseMLPExperts, get_expert_class
-from colossalai.nn.layer.moe.routers import MoeRouter, get_router_cls
-from colossalai.nn.layer.moe.utils import get_noise_generator
+from colossalai.moe._operation import COL_MOE_KERNEL_FLAG, AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter
+from colossalai.moe.experts import BaseMLPExperts, get_expert_class
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.moe.routers import MoeRouter, get_router_cls
+from colossalai.moe.utils import get_noise_generator
 from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size

@@ -66,7 +58,7 @@ def __init__(self,
         super().__init__()
         self.hidden_size = hidden_size
         self.num_experts = num_experts
-        self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
+        self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_MANAGER.use_kernel_optim else False
         self.expert_parallel = expert_parallel
         assert expert_parallel in ["EP", "TP", None], f"Unsupported expert parallel type {expert_parallel}"

@@ -157,45 +149,3 @@ def _tp_process(self, dispatch_data: torch.Tensor) -> torch.Tensor:
         expert_out = self.experts(expert_in)
         expert_out = ReduceScatter.apply(expert_out, self.ep_group)
         return expert_out
-
-
-class MoeModule(SparseMLP):
-    """
-    For other dependency
-    """
-
-    def __init__(self,
-                 num_experts: int,
-                 top_k: int = 1,
-                 capacity_factor_train: float = 1.25,
-                 capacity_factor_eval: float = 2.0,
-                 min_capacity: int = 4,
-                 noisy_policy: Optional[str] = None,
-                 drop_tks: bool = True,
-                 expert_parallel: str = "EP",
-                 hidden_size: int = 2048,
-                 intermediate_size: int = 2048,
-                 activation: str = None):
-        super().__init__(num_experts, top_k, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_policy,
-                         drop_tks, expert_parallel, hidden_size, intermediate_size, activation)
-
-
-class MoeLayer(SparseMLP):
-    """
-    For other dependency
-    """
-
-    def __init__(self,
-                 num_experts: int,
-                 top_k: int = 1,
-                 capacity_factor_train: float = 1.25,
-                 capacity_factor_eval: float = 2.0,
-                 min_capacity: int = 4,
-                 noisy_policy: Optional[str] = None,
-                 drop_tks: bool = True,
-                 expert_parallel: str = "EP",
-                 hidden_size: int = 2048,
-                 intermediate_size: int = 2048,
-                 activation: str = None):
-        super().__init__(num_experts, top_k, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_policy,
-                         drop_tks, expert_parallel, hidden_size, intermediate_size, activation)
diff --git a/colossalai/nn/loss/loss_moe.py b/colossalai/moe/loss.py
similarity index 92%
rename from colossalai/nn/loss/loss_moe.py
rename to colossalai/moe/loss.py
index
40cea788c3c3..75624510b452 100644
--- a/colossalai/nn/loss/loss_moe.py
+++ b/colossalai/moe/loss.py
@@ -1,11 +1,9 @@
 import torch.nn as nn
 from torch.nn.modules.loss import _Loss

-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.legacy.registry import LOSSES
+from colossalai.moe.manager import MOE_MANAGER


-@LOSSES.register_module
 class MoeCrossEntropyLoss(_Loss):
     r"""torch.nn.CrossEntropyLoss added with auxiliary loss.

@@ -45,11 +43,10 @@ def forward(self, *args):
         `Cross_entropy `_.
         """
         main_loss = self.loss(*args)
-        aux_loss = MOE_CONTEXT.get_loss()
+        aux_loss = MOE_MANAGER.get_loss()
         return main_loss + self.aux_weight * aux_loss


-@LOSSES.register_module
 class MoeLoss(_Loss):
     """A wrapper class for any loss module to add with auxiliary loss.

@@ -77,5 +74,5 @@ def forward(self, *args, **kwargs):
         The ``args`` and ``kwargs`` may include different parameters varying with different loss function.
         """
         main_loss = self.loss_fn(*args, **kwargs)
-        aux_loss = MOE_CONTEXT.get_loss()
+        aux_loss = MOE_MANAGER.get_loss()
         return main_loss + self.aux_weight * aux_loss
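Both wrappers above follow the same recipe: the routers accumulate their balancing losses in the global manager during the forward pass, and the loss module adds the accumulated value, scaled by aux_weight, to the task loss. A stripped-down sketch of the pattern, with a toy accumulator standing in for MOE_MANAGER (all names here are illustrative only):

    import torch.nn as nn

    class ToyMoeManager:
        def __init__(self):
            self.aux_loss = 0.0

        def add_loss(self, loss):
            self.aux_loss += loss      # called by each router during forward

        def get_loss(self):
            return self.aux_loss

        def reset_loss(self):
            self.aux_loss = 0.0        # called once per training step

    TOY_MANAGER = ToyMoeManager()

    class ToyMoeLoss(nn.Module):
        def __init__(self, aux_weight: float, loss_fn: nn.Module):
            super().__init__()
            self.aux_weight = aux_weight
            self.loss_fn = loss_fn

        def forward(self, logits, target):
            main_loss = self.loss_fn(logits, target)
            # routers are expected to have pushed their balancing loss already
            return main_loss + self.aux_weight * TOY_MANAGER.get_loss()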
diff --git a/colossalai/moe/manager.py b/colossalai/moe/manager.py
new file mode 100644
index 000000000000..3dc27c6cb0f0
--- /dev/null
+++ b/colossalai/moe/manager.py
@@ -0,0 +1,115 @@
+from typing import Optional, Tuple
+
+import torch
+import torch.distributed as dist
+
+from colossalai.context.singleton_meta import SingletonMeta
+from colossalai.tensor.moe_tensor.api import get_moe_info
+from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
+
+
+class MoeManager(metaclass=SingletonMeta):
+    """MoE manager. This class manages the different
+    parallel groups in the MoE context and the MoE loss during training.
+    """
+
+    def __init__(self):
+        self.world_size = None
+        # Users may want to set the maximum expert parallel size smaller than the world size,
+        # since very low bandwidth across nodes may constrain the performance of MoE.
+        # Fixing a maximum expert parallel size naturally yields a minimum data parallel size.
+        self.max_ep_size = None
+        self.min_dp_size = None
+        self.router_aux_loss = []
+        self.router_z_loss = []
+        self.parallel = None
+        self.seed = None
+        self.use_kernel_optim = True
+
+        self.has_setup = False
+        self._parallel_info_dict = dict()
+
+    @property
+    def parallel_info_dict(self):
+        return self._parallel_info_dict
+
+    @property
+    def is_initialized(self):
+        return self.has_setup
+
+    def setup(self, seed: int, use_kernel_optim: bool = True, max_ep_size: int = 8, parallel: Optional[str] = None):
+        assert not self.is_initialized, "MoE distributed context shouldn't be set up again"
+        assert torch.cuda.is_available(), "MoE requires CUDA to be enabled first"
+
+        self.world_size = dist.get_world_size()
+        self.seed = seed + dist.get_rank()
+        self.max_ep_size = min(max_ep_size, dist.get_world_size())
+        self.min_dp_size = self.world_size // self.max_ep_size
+        self.parallel = parallel
+
+        # Enabling kernel optimization may raise an error in some cases;
+        # users can disable kernel optimization manually
+        self.use_kernel_optim = use_kernel_optim
+
+        self.has_setup = True
+
+    def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoeParallelInfo]:
+        """Calculate the Data Parallel Group and Expert Parallel Group.
+
+        Parameters
+        ----------
+        num_experts : int
+            The number of experts
+
+        Returns
+        -------
+        int, MoeParallelInfo
+            number of local experts, the MoeParallelInfo of the current ep_size
+        """
+
+        gt_flag = num_experts % self.max_ep_size == 0    # whether num_experts is a multiple of max_ep_size
+        lt_flag = self.max_ep_size % num_experts == 0    # whether max_ep_size is a multiple of num_experts
+
+        assert gt_flag or lt_flag, "Automatic experts placement does not support an expert number" \
+            " that is not a multiple of the ep size or vice versa."
+
+        # If the number of experts is greater than the maximum expert parallel size, a.k.a. ep_size,
+        # there are multiple experts on each GPU and each GPU holds different experts,
+        # so its data parallel size is 1.
+        # Otherwise, there is only one expert on each GPU
+        # and the data parallel size has to be calculated.
+        dp_size = 1 if gt_flag else self.max_ep_size // num_experts
+        ep_size = self.max_ep_size // dp_size

+        # Calculate the number of experts for each GPU
+        if use_tp:
+            num_local_experts = num_experts
+        else:
+            num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size
+
+        # Don't forget to multiply by the minimum data parallel size
+        dp_size *= self.min_dp_size
+        if not (ep_size in self.parallel_info_dict):
+            self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size)
+
+        return num_local_experts, self.parallel_info_dict[ep_size]
+
+    def set_kernel_not_use(self):
+        self.use_kernel_optim = False
+
+    def reset_loss(self):
+        self.router_aux_loss, self.router_z_loss = [], []
+
+    def add_loss(self, aux_loss: float = 0., z_loss: float = 0.):
+        self.router_aux_loss.append(aux_loss)
+        self.router_z_loss.append(z_loss)
+
+    def get_loss(self):
+        cur_loss = self.router_aux_loss, self.router_z_loss
+        return cur_loss
+
+    def get_parallel(self):
+        return self.parallel
+
+
+MOE_MANAGER = MoeManager()
diff --git a/colossalai/nn/layer/moe/routers.py b/colossalai/moe/routers.py
similarity index 98%
rename from colossalai/nn/layer/moe/routers.py
rename to colossalai/moe/routers.py
index 9332302a096a..dd9243421667 100644
--- a/colossalai/nn/layer/moe/routers.py
+++ b/colossalai/moe/routers.py
@@ -8,8 +8,8 @@
 import torch.nn.functional as F
 from torch.distributed import ProcessGroup

-from colossalai.context import MOE_CONTEXT
-from colossalai.nn.layer.moe._operation import moe_cumsum
+from colossalai.moe._operation import moe_cumsum
+from colossalai.moe.manager import MOE_MANAGER
 from colossalai.utils import get_current_device

@@ -66,7 +66,7 @@ def set_z_loss(self, router_logits: torch.Tensor):
     def pop_router_loss(self) -> torch.Tensor:
         assert self._aux_loss is not None
-        MOE_CONTEXT.add_loss(self._aux_loss, self._z_loss)
+        MOE_MANAGER.add_loss(self._aux_loss, self._z_loss)
         self._aux_loss = None
         self._z_loss = None
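The placement logic in MoeManager.get_info above reduces to three integer divisions. A condensed, standalone restatement; the helper name place_experts and the example values are mine, not part of the patch:

    def place_experts(num_experts: int, max_ep_size: int, min_dp_size: int = 1):
        gt_flag = num_experts % max_ep_size == 0    # at least one expert per rank
        lt_flag = max_ep_size % num_experts == 0    # several ranks share one expert
        assert gt_flag or lt_flag
        dp_size = 1 if gt_flag else max_ep_size // num_experts
        ep_size = max_ep_size // dp_size
        num_local_experts = 1 if lt_flag else num_experts // max_ep_size
        return num_local_experts, ep_size, dp_size * min_dp_size

    # 8 experts, max_ep_size 4: two experts per rank, one ep group of 4 ranks
    assert place_experts(8, 4) == (2, 4, 1)
    # 2 experts, max_ep_size 4: one expert per rank, each replicated on 2 ranks
    assert place_experts(2, 4) == (1, 2, 2)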
diff --git a/colossalai/nn/layer/moe/utils.py b/colossalai/moe/utils.py
similarity index 66%
rename from colossalai/nn/layer/moe/utils.py
rename to colossalai/moe/utils.py
index 5b3542c80595..58c1665a4d63 100644
--- a/colossalai/nn/layer/moe/utils.py
+++ b/colossalai/moe/utils.py
@@ -1,10 +1,13 @@
 import contextlib
-from typing import Callable
+from typing import Callable, Dict, List

 import torch
+import torch.distributed as dist
+import torch.nn as nn
 import torch.nn.functional as F

-from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.tensor.moe_tensor.api import get_dp_group, get_dp_group_ranks, get_ep_size, is_moe_tensor
 from colossalai.utils import get_current_device

@@ -119,3 +122,45 @@ def _skip_init(x, *args, **kwargs):
     for fn, fn_saved in zip(init_fn_list, fn_saved):
         fn = fn_saved
     return
+
+
+def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]:
+    """Returns a parameter dictionary, the key of which is the expert parallel
+    size of every parameter. Since the parameters in data parallelism are replicated
+    on each GPU, we set their ep_size to 1.
+
+    Args:
+        model (:class:`torch.nn.Module`): A PyTorch `nn.Module` from which we get the dict.
+    """
+    epsize_param_dict = dict()
+    for param in model.parameters():
+        if not is_moe_tensor(param):
+            ep_size = 1    # set ep_size to 1 for dp parameters
+        else:
+            ep_size = get_ep_size(param)
+        if ep_size not in epsize_param_dict:
+            epsize_param_dict[ep_size] = []
+        epsize_param_dict[ep_size].append(param)
+
+    return epsize_param_dict
+
+
+def sync_moe_model_param(model: nn.Module):
+    """Make sure model parameters are consistent in the MoE parallel context.
+
+    Args:
+        model (:class:`torch.nn.Module`): A PyTorch model whose parameters will be checked for consistency.
+    """
+    param_dict = get_moe_epsize_param_dict(model)
+
+    # synchronize the parameters whose dp_group is the whole world
+    if 1 in param_dict:
+        for param in param_dict[1]:
+            dist.broadcast(param, src=0)
+
+    for ep_size in param_dict:
+        # When ep_size = world_size, communication is not needed
+        if ep_size != 1 and ep_size != MOE_MANAGER.world_size:
+            for param in param_dict[ep_size]:
+                src_rank = get_dp_group_ranks(param)[0]
+                dist.broadcast(param, src=src_rank, group=get_dp_group(param))
diff --git a/colossalai/nn/layer/moe/__init__.py b/colossalai/nn/layer/moe/__init__.py
index 52f529814eba..5280acf8dee7 100644
--- a/colossalai/nn/layer/moe/__init__.py
+++ b/colossalai/nn/layer/moe/__init__.py
@@ -1,10 +1,12 @@
-from .checkpoint import MoeCheckpintIO
-from .experts import EPMLPExperts, TPMLPExperts, build_ffn_experts
-from .layers import MoeLayer, MoeModule, SparseMLP
-from .routers import MoeRouter, Top1Router, Top2Router
-from .utils import NormalNoiseGenerator, UniformNoiseGenerator
-
-__all__ = [
-    'EPMLPExperts', 'TPMLPExperts', 'Top1Router', 'Top2Router', 'MoeModule', 'MoeLayer', 'NormalNoiseGenerator',
-    'UniformNoiseGenerator', 'SparseMLP', 'MoeRouter', 'MoeCheckpintIO', 'build_ffn_experts'
-]
+MoeModule = None
+MoeLayer = None
+build_ffn_experts = None
+EPMLPExperts = None
+TPMLPExperts = None
+Top1Router = None
+Top2Router = None
+NormalNoiseGenerator = None
+UniformNoiseGenerator = None
+SparseMLP = None
+MoeRouter = None
+MoeCheckpintIO = None
diff --git a/colossalai/nn/loss/__init__.py b/colossalai/nn/loss/__init__.py
index 7c6fb099d272..e69de29bb2d1 100644
--- a/colossalai/nn/loss/__init__.py
+++ b/colossalai/nn/loss/__init__.py
@@ -1 +0,0 @@
-# from .loss_moe import MoeCrossEntropyLoss, MoeLoss
diff --git a/colossalai/tensor/moe_tensor/api.py b/colossalai/tensor/moe_tensor/api.py
index b9b6d338438e..442b3c0f4958 100644
--- a/colossalai/tensor/moe_tensor/api.py
+++ b/colossalai/tensor/moe_tensor/api.py
@@ -107,3 +107,29 @@ def get_dp_rank(tensor: torch.Tensor) -> int:
         int: The data parallel rank of the given tensor.
     """
     return dist.get_rank(get_dp_group(tensor))
+
+
+def get_ep_group_ranks(tensor: torch.Tensor) -> int:
+    """
+    Get the expert parallel group ranks of the given tensor.
+
+    Args:
+        tensor (torch.Tensor): The tensor to be checked.
+
+    Returns:
+        int: The expert parallel group ranks of the given tensor.
+ """ + return tensor.moe_info.ep_group_ranks + + +def get_dp_group_ranks(tensor: torch.Tensor) -> int: + """ + Get the data parallel group ranks of the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + int: The data parallel group ranks of the given tensor. + """ + return tensor.moe_info.dp_group_ranks diff --git a/colossalai/tensor/moe_tensor/moe_info.py b/colossalai/tensor/moe_tensor/moe_info.py index 89f79f162b5b..ca7f163b9c24 100644 --- a/colossalai/tensor/moe_tensor/moe_info.py +++ b/colossalai/tensor/moe_tensor/moe_info.py @@ -12,4 +12,6 @@ def __init__(self, ep_size: int, dp_size: int): self.ep_size = ep_size self.pg = ProcessGroupMesh(self.dp_size, self.ep_size) self.ep_group = self.pg.get_group_along_axis(self.ep_axis) + self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group) self.dp_group = self.pg.get_group_along_axis(self.dp_axis) + self.dp_group_ranks = self.pg.get_ranks_in_group(self.dp_group) diff --git a/colossalai/utils/moe.py b/colossalai/utils/moe.py deleted file mode 100644 index 1b75448bdd3c..000000000000 --- a/colossalai/utils/moe.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import Dict, List - -import torch.distributed as dist -import torch.nn as nn - -from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.legacy.context import ParallelMode -from colossalai.legacy.core import global_context as gpc -from colossalai.legacy.utils import is_using_ddp - - -def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]: - """Returns a parameter dictionary, the key of which is the expert parallel - size of every parameter. Since the parameters in data parallelism is replicated - in each GPU, we set their ep_size to 1. - - Args: - model (:class:`torch.nn.Module`): A pyTorch `nn.Module` from which we get dict. - """ - epsize_param_dict = dict() - for param in model.parameters(): - if not hasattr(param, "moe_info"): - ep_size = 1 # set ep_size to 1 for dp parameters - else: - ep_size = param.moe_info.ep_size - if ep_size not in epsize_param_dict: - epsize_param_dict[ep_size] = [] - epsize_param_dict[ep_size].append(param) - - return epsize_param_dict - - -def sync_moe_model_param(model: nn.Module): - """Make sure model parameters are consistent in MoE parallel context. - - Args: - model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency. 
- """ - if is_using_ddp(): - param_dict = get_moe_epsize_param_dict(model) - - # synchronize the parameters whose dp_group is the whole world - if 1 in param_dict: - src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0] - for param in param_dict[1]: - dist.broadcast(param, src=src_rank, group=gpc.get_group(ParallelMode.DATA)) - - for ep_size in param_dict: - # When ep_size = world_size, communication is not needed - if ep_size != 1 and ep_size != MOE_CONTEXT.world_size: - src_rank = dist.get_rank(MOE_CONTEXT.parallel_info_dict[ep_size].ep_group) - for param in param_dict[ep_size]: - dist.broadcast(param, src=src_rank, group=param.moe_info.dp_group) diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index 1ea9d48523c3..ec7e1e8941f7 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -36,8 +36,8 @@ replace_return_docstrings, ) -from colossalai.context import MOE_CONTEXT -from colossalai.nn.layer.moe.layers import SparseMLP +from colossalai.moe.layers import SparseMLP +from colossalai.moe.manager import MOE_MANAGER logger = logging.get_logger(__name__) @@ -455,7 +455,7 @@ def __init__(self, config: LlamaConfig, moe: bool): min_capacity=config.min_capacity, noisy_policy=config.noisy_policy, drop_tks=config.drop_tks, - expert_parallel=MOE_CONTEXT.get_parallel() if MOE_CONTEXT.is_initialized else config.expert_parallel, + expert_parallel=MOE_MANAGER.get_parallel() if MOE_MANAGER.is_initialized else config.expert_parallel, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, activation=config.hidden_act, @@ -891,7 +891,7 @@ def forward( "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
```""" # reset moe loss - MOE_CONTEXT.reset_loss() + MOE_MANAGER.reset_loss() output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = (output_hidden_states @@ -1014,7 +1014,7 @@ def _reorder_cache(past_key_values, beam_idx): return reordered_past def _calculate_router_loss(self): - aux_loss, z_loss = MOE_CONTEXT.get_loss() + aux_loss, z_loss = MOE_MANAGER.get_loss() assert len(aux_loss) == len(z_loss) == self.config.num_hidden_layers // self.config.moe_layer_interval aux_loss = self.config.router_aux_loss_factor * sum(aux_loss) / len(aux_loss) z_loss = self.config.router_z_loss_factor * sum(z_loss) / len(z_loss) diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index 67dd387a3950..132f17a9ba0f 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -15,10 +15,10 @@ from colossalai.booster import Booster from colossalai.booster.plugin import LowLevelZeroPlugin from colossalai.cluster import DistCoordinator -from colossalai.context import MOE_CONTEXT from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.nn.layer.moe import MoeCheckpintIO -from colossalai.nn.layer.moe.utils import skip_init +from colossalai.moe import MoeCheckpintIO +from colossalai.moe.manager import MOE_MANAGER +from colossalai.moe.utils import skip_init from colossalai.utils import get_current_device @@ -95,7 +95,7 @@ def main(): coordinator = DistCoordinator() # Set up moe - MOE_CONTEXT.setup(seed=42, parallel="EP") + MOE_MANAGER.setup(seed=42, parallel="EP") # Manage loggers disable_existing_loggers() diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py index 7f9bfe376632..3371c35fd295 100644 --- a/tests/test_moe/moe_utils.py +++ b/tests/test_moe/moe_utils.py @@ -2,17 +2,14 @@ import torch.distributed as dist import torch.nn as nn -from colossalai.context import MOE_CONTEXT -from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc from colossalai.engine.gradient_handler._base_gradient_handler import BaseGradientHandler from colossalai.engine.gradient_handler.utils import bucket_allreduce +from colossalai.moe import SparseMLP +from colossalai.moe.manager import MOE_MANAGER +from colossalai.moe.utils import get_moe_epsize_param_dict from colossalai.nn import CheckpointModule -from colossalai.nn.layer import SparseMLP from colossalai.registry import GRADIENT_HANDLER from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor -from colossalai.utils.moe import get_moe_epsize_param_dict class MoeModel(nn.Module): @@ -39,7 +36,7 @@ def _forward(self, x): self.test_transform = TestSubModule() def forward(self, x): - MOE_CONTEXT.reset_loss() + MOE_MANAGER.reset_loss() x = self.test_embed(x) x = self.test_transform(x) @@ -68,21 +65,19 @@ def handle_gradient(self): Then running an all-reduce operation for all parameters in experts across moe model parallel group """ - global_data = gpc.data_parallel_size - - if global_data > 1: + if dist.get_world_size() > 1: epsize_param_dict = get_moe_epsize_param_dict(self._model) # epsize is 1, indicating the params are replicated among processes in data parallelism # use the ParallelMode.DATA to get data parallel group # reduce gradients for all parameters in data parallelism if 1 in epsize_param_dict: - bucket_allreduce(param_list=epsize_param_dict[1], 
group=gpc.get_group(ParallelMode.DATA))
+            bucket_allreduce(param_list=epsize_param_dict[1])

         for ep_size in epsize_param_dict:
-            if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
+            if ep_size != 1 and ep_size != MOE_MANAGER.world_size:
                 bucket_allreduce(param_list=epsize_param_dict[ep_size],
-                                 group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group)
+                                 group=MOE_MANAGER.parallel_info_dict[ep_size].dp_group)


 def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
@@ -160,3 +155,17 @@ def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_
             assert torch.allclose(local_param.grad, all_grad)
         else:
             local_param.data.copy_(all_param.data)
+
+
+def assert_not_equal_in_group(tensor, process_group=None):
+    # all gather tensors from different ranks
+    world_size = dist.get_world_size(process_group)
+    tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
+    dist.all_gather(tensor_list, tensor, group=process_group)
+
+    # check that neighbouring entries differ, one pair at a time
+    for i in range(world_size - 1):
+        a = tensor_list[i]
+        b = tensor_list[i + 1]
+        assert not torch.allclose(
+            a, b), f'expected tensors on rank {i} and {i + 1} not to be equal but they are, {a} vs {b}'
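assert_not_equal_in_group above is the negative counterpart of colossalai.testing.assert_equal_in_group: it all-gathers one copy of the tensor from every rank and asserts that each adjacent pair differs, which is what freshly (and differently) initialized expert weights should look like before synchronization. A hedged usage sketch; the model, layer and group handles are placeholders:

    # run on every rank of a torch.distributed job
    w = layer.experts.wi.data                  # placeholder expert parameter
    assert_not_equal_in_group(w, dp_group)     # pre-sync: replicas must differ
    sync_moe_model_param(model)
    assert_equal_in_group(w, dp_group)         # post-sync: replicas must match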
diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py
index e36555fa2eb9..cbfbcae6ce33 100644
--- a/tests/test_moe/test_grad_handler.py
+++ b/tests/test_moe/test_grad_handler.py
@@ -4,21 +4,21 @@
 import torch.nn as nn

 import colossalai
-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.nn.layer.moe import SparseMLP
+from colossalai.moe import SparseMLP
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.moe.utils import sync_moe_model_param
 from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device
-from colossalai.utils.moe import sync_moe_model_param
-from tests.test_moe.moe_utils import MoeGradientHandler
+from tests.test_moe.moe_utils import MoeGradientHandler, assert_not_equal_in_group

 BATCH_SIZE = 4
-DIM = 16
+DIM = 4


 def run_test(rank, world_size, port):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    MOE_CONTEXT.setup(42)    # MOE initialization
+    MOE_MANAGER.setup(42)    # MOE initialization
     num_experts_list = [1, 2, 4]
     layer_list = []
     for num_experts in num_experts_list:
@@ -32,13 +32,22 @@ def run_test(rank, world_size, port):
     model = nn.ModuleList(layer_list)
     model = model.to(get_current_device())

+    dist_dict = MOE_MANAGER.parallel_info_dict
+    assert_not_equal_in_group(layer_list[0].experts.wi.data, dist_dict[1].dp_group)
+    assert_not_equal_in_group(layer_list[0].experts.wo.data, dist_dict[1].dp_group)
+    assert_not_equal_in_group(layer_list[1].experts.wi.data, dist_dict[2].dp_group)
+    assert_not_equal_in_group(layer_list[1].experts.wo.data, dist_dict[2].dp_group)
+    assert_not_equal_in_group(layer_list[2].experts.wi.data, dist_dict[4].dp_group)
+    assert_not_equal_in_group(layer_list[2].experts.wo.data, dist_dict[4].dp_group)
+
     sync_moe_model_param(model)

-    dist_dict = MOE_CONTEXT.parallel_info_dict
     assert_equal_in_group(layer_list[0].experts.wi.data, dist_dict[1].dp_group)
     assert_equal_in_group(layer_list[0].experts.wo.data, dist_dict[1].dp_group)
     assert_equal_in_group(layer_list[1].experts.wi.data, dist_dict[2].dp_group)
     assert_equal_in_group(layer_list[1].experts.wo.data, dist_dict[2].dp_group)
+    assert_equal_in_group(layer_list[2].experts.wi.data, dist_dict[4].dp_group)
+    assert_equal_in_group(layer_list[2].experts.wo.data, dist_dict[4].dp_group)
     # MoE model synchronization passed

     grad_handler = MoeGradientHandler(model, 0)
@@ -48,7 +57,7 @@ def run_test(rank, world_size, port):
     data = torch.randn(BATCH_SIZE, DIM, device=get_current_device())
     grad = torch.randn_like(data)

-    MOE_CONTEXT.reset_loss()
+    MOE_MANAGER.reset_loss()
     for layer in layer_list:
         data = layer(data)
     data.backward(grad)
@@ -58,6 +67,8 @@ def run_test(rank, world_size, port):
     assert_equal_in_group(layer_list[0].experts.wo.grad, dist_dict[1].dp_group)
     assert_equal_in_group(layer_list[1].experts.wi.grad, dist_dict[2].dp_group)
     assert_equal_in_group(layer_list[1].experts.wo.grad, dist_dict[2].dp_group)
+    assert_equal_in_group(layer_list[2].experts.wi.grad, dist_dict[4].dp_group)
+    assert_equal_in_group(layer_list[2].experts.wo.grad, dist_dict[4].dp_group)
     # MoE grad handler test passed
diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py
index ce663136600e..1927c9553683 100644
--- a/tests/test_moe/test_kernel.py
+++ b/tests/test_moe/test_kernel.py
@@ -1,14 +1,14 @@
 import pytest
 import torch
+import torch.distributed as dist

 import colossalai
-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.core import global_context as gpc
-from colossalai.nn.layer.moe import SparseMLP
+from colossalai.moe import SparseMLP
+from colossalai.moe.manager import MOE_MANAGER
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device

-BATCH_SIZE = 16
+BATCH_SIZE = 4
 NUM_EXPERTS = 4

@@ -21,11 +21,11 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     torch.backends.cuda.matmul.allow_tf32 = False
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
+    local_rank = dist.get_rank()

-    MOE_CONTEXT.setup(42)    # MOE environment initialization
-    MOE_CONTEXT.reset_loss()
-    torch.manual_seed(rs + local_rank)    # set each process has different random seed
+    MOE_MANAGER.setup(42)    # MOE environment initialization
+    MOE_MANAGER.reset_loss()
+    torch.manual_seed(rs + local_rank)    # give each process a different random seed

     # get randomized data
     tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py
index 402346527530..1c70c5d43dbd 100644
--- a/tests/test_moe/test_moe_checkpoint.py
+++ b/tests/test_moe/test_moe_checkpoint.py
@@ -5,8 +5,8 @@
 import torch.distributed as dist

 import colossalai
-from colossalai.context import MOE_CONTEXT
-from colossalai.nn.layer.moe import MoeCheckpintIO
+from colossalai.moe import MoeCheckpintIO
+from colossalai.moe.manager import MOE_MANAGER
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device
 from tests.test_moe.moe_utils import MoeModel
@@ -32,7 +32,7 @@ def exam_moe_checkpoint():

 def _run_dist(rank, world_size, port):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    MOE_CONTEXT.setup(seed=42)
+    MOE_MANAGER.setup(seed=42)
     exam_moe_checkpoint()
diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py
index cb261912e0f6..253fe6a7c094 100644
--- a/tests/test_moe/test_moe_ep_tp.py
+++ b/tests/test_moe/test_moe_ep_tp.py
@@ -3,11 +3,11 @@ import
torch.distributed as dist import colossalai -from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.nn.layer.moe import SparseMLP +from colossalai.moe import SparseMLP +from colossalai.moe.manager import MOE_MANAGER +from colossalai.moe.utils import sync_moe_model_param from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device -from colossalai.utils.moe import sync_moe_model_param from tests.test_moe.moe_utils import MoeGradientHandler, sync_tp_from_ep BATCH_SIZE = 4 @@ -16,7 +16,7 @@ def run_test(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_CONTEXT.setup(42) # MOE initialization + MOE_MANAGER.setup(42) # MOE initialization ep_model = SparseMLP(num_experts=4, expert_parallel="EP", hidden_size=DIM, intermediate_size=DIM) tp_model = SparseMLP(num_experts=4, expert_parallel="TP", hidden_size=DIM, intermediate_size=DIM) @@ -25,7 +25,7 @@ def run_test(rank, world_size, port): # sync ep param sync_moe_model_param(ep_model) - dist_dict = MOE_CONTEXT.parallel_info_dict + dist_dict = MOE_MANAGER.parallel_info_dict assert_equal_in_group(ep_model.experts.wi.data, dist_dict[2].dp_group) assert_equal_in_group(ep_model.experts.wo.data, dist_dict[2].dp_group) grad_handler = MoeGradientHandler(ep_model) @@ -38,9 +38,9 @@ def run_test(rank, world_size, port): ep_data = tp_data.detach()[2 * rank:2 * (rank + 1)] out_tp = tp_model(tp_data) - MOE_CONTEXT.reset_loss() + MOE_MANAGER.reset_loss() out_ep = ep_model(ep_data) - MOE_CONTEXT.reset_loss() + MOE_MANAGER.reset_loss() assert torch.allclose(out_ep, out_tp[2 * rank:2 * (rank + 1)]) out_tp.mean().backward() diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py index fd87a9a3135d..f5d54ba290aa 100644 --- a/tests/test_moe/test_moe_group.py +++ b/tests/test_moe/test_moe_group.py @@ -3,11 +3,11 @@ import torch.nn as nn import colossalai -from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.nn.layer.moe import EPMLPExperts, TPMLPExperts +from colossalai.moe import EPMLPExperts, TPMLPExperts +from colossalai.moe.manager import MOE_MANAGER +from colossalai.moe.utils import sync_moe_model_param from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device -from colossalai.utils.moe import sync_moe_model_param HIDDEN_SIZE = 4 INTERMEDIATE_SIZE = 8 @@ -31,7 +31,7 @@ def run_moe_init(expert_cls): assert exp2.num_local_experts == 4 assert exp3.num_local_experts == 8 - parallel_info_dict = MOE_CONTEXT.parallel_info_dict + parallel_info_dict = MOE_MANAGER.parallel_info_dict rank = dist.get_rank() # group creation assert @@ -59,7 +59,7 @@ def run_moe_init(expert_cls): def _run_test(rank, world_size, port, expert_cls): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_CONTEXT.setup(seed=42) + MOE_MANAGER.setup(seed=42) run_moe_init(expert_cls) diff --git a/tests/test_moe/test_moe_local.py b/tests/test_moe/test_moe_local.py index e41a0d821a10..872b65c2d1f1 100644 --- a/tests/test_moe/test_moe_local.py +++ b/tests/test_moe/test_moe_local.py @@ -3,11 +3,11 @@ import torch.distributed as dist import colossalai -from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.nn.layer.moe import SparseMLP +from colossalai.moe import SparseMLP +from colossalai.moe.manager import MOE_MANAGER +from 
colossalai.moe.utils import sync_moe_model_param from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device -from colossalai.utils.moe import sync_moe_model_param from tests.test_moe.moe_utils import MoeGradientHandler, sync_local_from_ep BATCH_SIZE = 4 @@ -16,7 +16,7 @@ def run_test(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_CONTEXT.setup(42) # MOE initialization + MOE_MANAGER.setup(42) # MOE initialization ep_model = SparseMLP(num_experts=4, expert_parallel="EP", hidden_size=DIM, intermediate_size=DIM) local_model = SparseMLP(num_experts=4, expert_parallel=None, hidden_size=DIM, intermediate_size=DIM) @@ -25,7 +25,7 @@ def run_test(rank, world_size, port): # sync ep param sync_moe_model_param(ep_model) - dist_dict = MOE_CONTEXT.parallel_info_dict + dist_dict = MOE_MANAGER.parallel_info_dict assert_equal_in_group(ep_model.experts.wi.data, dist_dict[2].dp_group) assert_equal_in_group(ep_model.experts.wo.data, dist_dict[2].dp_group) grad_handler = MoeGradientHandler(ep_model) @@ -38,9 +38,9 @@ def run_test(rank, world_size, port): ep_data = tp_data.detach()[2 * rank:2 * (rank + 1)] out_tp = local_model(tp_data) - MOE_CONTEXT.reset_loss() + MOE_MANAGER.reset_loss() out_ep = ep_model(ep_data) - MOE_CONTEXT.reset_loss() + MOE_MANAGER.reset_loss() assert torch.allclose(out_ep, out_tp[2 * rank:2 * (rank + 1)]) out_tp.mean().backward() @@ -54,10 +54,11 @@ def run_test(rank, world_size, port): @pytest.mark.dist +@pytest.mark.parametrize("world_size", [2]) @rerun_if_address_is_in_use() -def test_moe_ep_tp(): - spawn(run_test, 2) +def test_moe_local(world_size): + spawn(run_test, world_size) if __name__ == '__main__': - test_moe_ep_tp() + test_moe_local() diff --git a/tests/test_moe/test_moe_zero_fwd_bwd.py b/tests/test_moe/test_moe_zero_fwd_bwd.py index f1f888203746..2b2afa4623b5 100644 --- a/tests/test_moe/test_moe_zero_fwd_bwd.py +++ b/tests/test_moe/test_moe_zero_fwd_bwd.py @@ -5,8 +5,7 @@ from colossalai.booster import Booster from colossalai.booster.plugin import LowLevelZeroPlugin from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel -from colossalai.context import MOE_CONTEXT -from colossalai.nn import MoeLoss +from colossalai.moe.manager import MOE_MANAGER from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel @@ -87,7 +86,7 @@ def run_zero_test(local_rank, world_size, stage=1): def run_dist(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_CONTEXT.setup(seed=42) + MOE_MANAGER.setup(seed=42) seed_all(42 + rank) run_zero_test(rank, world_size, stage=1) run_zero_test(rank, world_size, stage=2) diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py index 229ee528b4fc..38a5cfbfd66e 100644 --- a/tests/test_moe/test_moe_zero_optim.py +++ b/tests/test_moe/test_moe_zero_optim.py @@ -5,8 +5,7 @@ from colossalai.booster import Booster from colossalai.booster.plugin import LowLevelZeroPlugin from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel -from colossalai.context import MOE_CONTEXT -from colossalai.nn import MoeLoss +from colossalai.moe.manager import MOE_MANAGER from colossalai.testing import rerun_if_address_is_in_use, spawn 
from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel @@ -76,7 +75,7 @@ def run_zero_optim_test(local_rank, world_size, stage=1): def run_dist(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_CONTEXT.setup(seed=42) + MOE_MANAGER.setup(seed=42) run_zero_optim_test(rank, world_size, stage=1) run_zero_optim_test(rank, world_size, stage=2) From d1d0de8988114b0a884b46982f2b1284be254dde Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Tue, 12 Sep 2023 09:53:55 +0800 Subject: [PATCH 09/46] [moe]: add top k router (#4597) * docs: add shape spec * docs: add doc * feat: add top_k router * feat: update init * test: add moe router tests * fix: reorder return values --- colossalai/moe/__init__.py | 8 +- colossalai/moe/experts.py | 26 ++++- colossalai/moe/layers.py | 20 +++- colossalai/moe/routers.py | 185 ++++++++++++++++++++++++++---- tests/test_moe/test_moe_router.py | 48 ++++++++ 5 files changed, 250 insertions(+), 37 deletions(-) create mode 100644 tests/test_moe/test_moe_router.py diff --git a/colossalai/moe/__init__.py b/colossalai/moe/__init__.py index 492cdaf13d1d..1614987538c1 100644 --- a/colossalai/moe/__init__.py +++ b/colossalai/moe/__init__.py @@ -1,10 +1,12 @@ from .checkpoint import MoeCheckpintIO from .experts import EPMLPExperts, TPMLPExperts, build_ffn_experts from .layers import SparseMLP -from .routers import MoeRouter, Top1Router, Top2Router +from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter from .utils import NormalNoiseGenerator, UniformNoiseGenerator __all__ = [ - 'EPMLPExperts', 'TPMLPExperts', 'Top1Router', 'Top2Router', 'NormalNoiseGenerator', 'UniformNoiseGenerator', - 'SparseMLP', 'MoeRouter', 'MoeCheckpintIO', 'build_ffn_experts' + 'EPMLPExperts', 'TPMLPExperts', 'build_ffn_experts', + 'MoeRouter', 'Top1Router', 'Top2Router', 'TopKRouter', + 'NormalNoiseGenerator', 'UniformNoiseGenerator', + 'SparseMLP', 'MoeCheckpintIO' ] diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py index da4fe58977e8..9715f4dc37b3 100644 --- a/colossalai/moe/experts.py +++ b/colossalai/moe/experts.py @@ -1,9 +1,9 @@ import math from contextlib import nullcontext +from typing import Callable, Optional import torch import torch.nn as nn - from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import get_activation @@ -14,6 +14,15 @@ class BaseMLPExperts(nn.Module): """ SparseMLP is a multi-layer perceptron with sparse expert parallel layers. + + Args: + num_experts (int): The number of experts + forward: hidden_size --> intermediate_size --> hidden_size + hidden_size (int): The hidden size of MLP + intermediate_size (int): The intermediate size of MLP + expert_parallel (str, optional): The parallelism of experts. Now we have 'EP' and 'TP'. 
+ activation (optional): The activation function of MLP + drop_rate (float, optional): The drop rate of MLP """ def __init__( @@ -21,8 +30,8 @@ def __init__( num_experts: int, hidden_size: int, intermediate_size: int, - expert_parallel: str = None, - activation: str = None, + expert_parallel: Optional[str] = None, + activation: Optional[Callable] = None, drop_rate: float = 0, gated: bool = False, ): @@ -76,7 +85,14 @@ def __init__( for param in self.parameters(): set_moe_tensor_info(param, self.moe_info) - def forward(self, x: torch.Tensor) -> torch.Tensor: # inputs [g, e, c, h] + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): The input tensor of shape (num_groups, num_experts, capacity, hidden_size) + + Returns: + torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size) + """ x = MoeInGradScaler.apply(x, self.ep_size) e = x.size(1) @@ -97,7 +113,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # inputs [g, e, c, h] x = x.reshape(inshape) x = x.transpose(0, 1).contiguous() x = MoeOutGradScaler.apply(x, self.ep_size) - return x # outputs [g, e, c, h] + return x class EPMLPExperts(BaseMLPExperts): diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index ace81b543273..a3f68cf7a6f1 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -88,7 +88,16 @@ def __init__(self, self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size)) nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / self.hidden_size)) - def forward(self, inputs: torch.Tensor) -> Tuple: + def forward(self, + inputs: torch.Tensor) \ + -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + inputs (torch.Tensor): The input tensor of shape (batch_size, seq_len, hidden_size) + + Returns: + torch.Tensor: The output tensor of shape (batch_size, seq_len, hidden_size) + """ # reshape the input tokens tokens = inputs.reshape(-1, self.hidden_size) @@ -100,6 +109,7 @@ def forward(self, inputs: torch.Tensor) -> Tuple: # the result from the router route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group) + # dispatch_data: (num_experts, capacity, hidden_size) if self.use_kernel: dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:]) dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.hidden_size) @@ -107,7 +117,7 @@ def forward(self, inputs: torch.Tensor) -> Tuple: sec_mask_f = route_result_list[1].type_as(inputs) dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens) - # dispatch_data [e, c, h] + # expert_output: (num_groups, num_experts, capacity, hidden_size) if self.expert_parallel == "EP": expert_output = self._ep_process(dispatch_data) elif self.expert_parallel == "TP": @@ -115,9 +125,9 @@ def forward(self, inputs: torch.Tensor) -> Tuple: elif self.expert_parallel is None: expert_output = self._local_process(dispatch_data) else: - raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts " - "build function.") - # expert_output [e, c, h] + raise NotImplementedError("This kind of communication has not been implemented yet.\n" + "Please use Experts build function.") + if self.use_kernel: expert_output = expert_output.reshape(-1, self.hidden_size) ans = MoeCombine.apply(expert_output, *route_result_list) diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py index dd9243421667..688471530758 100644 --- a/colossalai/moe/routers.py +++ b/colossalai/moe/routers.py @@ 
-1,6 +1,6 @@ import math from abc import ABC -from typing import Callable, Optional +from typing import Callable, Optional, Tuple import torch import torch.distributed as dist @@ -29,7 +29,7 @@ def __init__(self, capacity_factor_train: float, capacity_factor_eval: float, min_capacity: int, - noisy_func: Callable = None, + noisy_func: Optional[Callable] = None, drop_tks: bool = True): super().__init__() self.k_value = k_value @@ -72,9 +72,10 @@ def pop_router_loss(self) -> torch.Tensor: class Top1Router(MoeRouter): - """Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c] - for routing usage. More detailed function can be found in the paper about Switch Transformer - of Google. + """Top1 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity) + and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed + function can be found in the paper about Switch Transformer of Google. + Args: capacity_factor_train (float, optional): Capacity factor in routing of training. capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. @@ -89,7 +90,7 @@ def __init__(self, capacity_factor_eval: float = 2.0, min_capacity: int = 4, select_policy: str = "first", - noisy_func: Callable = None, + noisy_func: Optional[Callable] = None, drop_tks: bool = True): super().__init__(k_value=1, capacity_factor_train=capacity_factor_train, @@ -100,12 +101,27 @@ def __init__(self, self.select_policy = select_policy assert select_policy in {"first", "random"} if select_policy == "random": - self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()), - high=torch.tensor(1.0, - device=get_current_device())).rsample - - def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): - + self.uniform = torch.distributions.uniform.Uniform( + low=torch.tensor(0.0, device=get_current_device()), + high=torch.tensor(1.0, device=get_current_device()) + ).rsample + + def forward(self, + inputs: torch.Tensor, + use_kernel: bool = False, + ep_group: Optional[ProcessGroup] = None + ) -> Tuple: + """ + Args: + inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). + + Returns: + 1. use_kernel is False: + The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity). + The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity). + 2. use_kernel is True: + ... + """ if self.noisy_func is not None and self.training: inputs = self.noisy_func(inputs) @@ -154,8 +170,10 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti class Top2Router(MoeRouter): - """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c] - for routing usage. More detailed function can be found in the paper about ViT-MoE. + """Top2 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity) + and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed + function can be found in the paper about ViT-MoE. + Args: capacity_factor_train (float, optional): Capacity factor in routing of training. capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. 
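Both routers size their per-expert buffer through MoeRouter.get_capacity, whose body is not shown in this diff. The sketch below assumes the common Switch-Transformer-style definition, k * capacity_factor * num_tokens / num_experts, rounded to an even integer and floored at min_capacity; treat it as an illustration rather than the exact ColossalAI implementation:

    import math

    def get_capacity(k: int, num_tokens: int, num_experts: int,
                     capacity_factor: float, min_capacity: int = 4) -> int:
        capacity = math.floor(k * capacity_factor * num_tokens / num_experts)
        capacity += capacity % 2               # keep the buffer size even
        return max(capacity, min_capacity)     # every expert gets a few slots

    # top-1 routing of 4096 tokens over 8 experts at the train factor 1.25
    assert get_capacity(1, 4096, 8, 1.25) == 640

Tokens that overflow an expert's capacity slots are dropped when drop_tks is set, which is why the evaluation factor (2.0) is chosen larger than the training factor (1.25).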
@@ -168,7 +186,7 @@ def __init__(self, capacity_factor_train: float = 1.25, capacity_factor_eval: float = 2.0, min_capacity: int = 4, - noisy_func: Callable = None, + noisy_func: Optional[Callable] = None, drop_tks: bool = True): super().__init__(k_value=2, capacity_factor_train=capacity_factor_train, @@ -177,8 +195,22 @@ def __init__(self, noisy_func=noisy_func, drop_tks=drop_tks) - def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): - # inputs: [s, h] + def forward(self, + inputs: torch.Tensor, + use_kernel: bool = False, + ep_group: Optional[ProcessGroup] = None + ) -> Tuple: + """ + Args: + inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). + + Returns: + 1. use_kernel is False: + The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity). + The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity). + 2. use_kernel is True: + ... + """ if self.noisy_func is not None and self.training: inputs = self.noisy_func(inputs) @@ -238,11 +270,116 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti return cb_weight, sec_mask -def get_router_cls(top_k: int) -> MoeRouter: - if top_k == 1: - router_cls = Top1Router - elif top_k == 2: - router_cls = Top2Router +class TopKRouter(MoeRouter): + """Masked matmul router using tokens choose top-k experts assignment. + + NOTE: this is modified from flaxformer. + This router uses the same mechanism as in Switch Transformer + (https://arxiv.org/abs/2101.03961) and V-MoE + (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are + sorted by router_probs and then routed to their choice of expert until the + expert's expert_capacity is reached. There is no guarantee that each token is + processed by an expert, or that each expert receives at least one token. + + Attributes: + num_selected_experts: Maximum number of experts to which each token is + routed. Tokens may be routed to fewer experts if particular experts are + oversubscribed / reach capacity. + """ + + def __init__(self, + num_selected_experts: int, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_func: Optional[Callable] = None, + drop_tks: bool = True): + super().__init__(num_selected_experts, + capacity_factor_train, + capacity_factor_eval, + min_capacity, + noisy_func, + drop_tks) + + def forward(self, + router_probs: torch.Tensor, + expert_capacity: int, + ) -> Tuple: + """Computes masks for the top-k experts per token. + + Args: + router_probs: [num_groups, tokens_per_group, num_experts] + probabilities used to determine the routing of tokens to the experts. + + Returns: + Dispatch and combine arrays for routing with masked matmuls. + """ + num_groups, _, num_experts = router_probs.shape + + # Top-k router probability and corresponding expert indices for each token. + # Shape: [num_groups, tokens_per_group, num_selected_experts]. + expert_gate, expert_index = torch.topk(router_probs, self.k_value) + + # TODO + # auxiliary_loss = _load_balancing_loss(router_probs, expert_index) + + # Make num_selected_experts the leading axis to ensure that top-1 choices + # have priority over top-2 choices, which have priority over top-3 choices, + # etc. + expert_index = torch.transpose(expert_index, 1, 2) + # Shape: [num_groups, num_selected_experts * tokens_per_group] + expert_index = expert_index.reshape(num_groups, -1) + + # Create mask out of indices. 
+ # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. + expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32) + + # Experts have a fixed capacity that we cannot exceed. A token's priority + # within the expert's buffer is given by the masked, cumulative capacity of + # its target expert. + # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. + token_priority = torch.cumsum(expert_mask, dim=1) * expert_mask - 1 + # Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts]. + token_priority = token_priority.reshape((num_groups, self.k_value, -1, num_experts)) + # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. + token_priority = torch.transpose(token_priority, 1, 2) + # For each token, across all selected experts, select the only non-negative + # (unmasked) priority. Now, for group G routing to expert E, token T has + # non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E + # is its targeted expert. + # Shape: [num_groups, tokens_per_group, num_experts]. + token_priority = torch.max(token_priority, dim=2)[0] + + # Token T can only be routed to expert E if its priority is positive and + # less than the expert capacity. One-hot matrix will ignore indices outside + # the range [0, expert_capacity). + # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity]. + valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity) + token_priority = torch.masked_fill(token_priority, ~valid_mask, 0) + dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool) + valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, -1, expert_capacity) + dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0) + + # The combine array will be used for combining expert outputs, scaled by the + # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, + # expert_capacity]. 
+ combine_array = torch.einsum( + '...te,...tec->...tec', + router_probs, + dispatch_mask) + + return combine_array, dispatch_mask + + +def get_router_cls(top_k: int, + grouped: bool = False + ) -> MoeRouter: + if not grouped: + if top_k == 1: + return Top1Router + elif top_k == 2: + return Top2Router + else: + raise NotImplementedError("top_k > 2 is not supported yet") else: - raise NotImplementedError("top_k > 2 is not supported yet") - return router_cls + return TopKRouter diff --git a/tests/test_moe/test_moe_router.py b/tests/test_moe/test_moe_router.py new file mode 100644 index 000000000000..94c263baa5a3 --- /dev/null +++ b/tests/test_moe/test_moe_router.py @@ -0,0 +1,48 @@ +import pytest +import torch +from colossalai.moe.routers import (MoeRouter, Top1Router, Top2Router, + TopKRouter, get_router_cls) + + +@pytest.mark.parametrize(["router", "num_groups"], [ + (Top1Router(), 1), + (Top2Router(), 1), + (TopKRouter(num_selected_experts=3), 4), +]) +@pytest.mark.parametrize( + ["batch_size", "seq_len", "num_experts"], + [ + (4, 5, 8), + (3, 4, 4), + ] +) +def test_router_forward(router: MoeRouter, + batch_size: int, + seq_len: int, + num_experts: int, + num_groups: int): + x = torch.randn((batch_size * seq_len, num_experts)) + if num_groups > 1: + x = x.expand(num_groups, -1, -1) + + router.train() + if isinstance(router, TopKRouter): + combine_array, dispatch_mask = router(x, expert_capacity=2) + else: + combine_array, dispatch_mask = router(x) + assert combine_array.shape[:-1] == x.shape + assert dispatch_mask.shape[:-1] == x.shape + assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value) + + router.eval() + if isinstance(router, TopKRouter): + combine_array, dispatch_mask = router(x, expert_capacity=2) + else: + combine_array, dispatch_mask = router(x) + assert combine_array.shape[:-1] == x.shape + assert dispatch_mask.shape[:-1] == x.shape + assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value) + + +if __name__ == "__main__": + test_router_forward(Top1Router(), 4, 4, 4, 1) From 708bf6f5de938cbab6ba7ebad9ad31ac0a1cd9e4 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Wed, 13 Sep 2023 09:52:04 +0800 Subject: [PATCH 10/46] [moe]: modify router loss, polish code (#4693) * feat: check z_loss and add doc * style: rename misleading variable * feat: modify auxiliary loss * feat: add aux_loss in topk router and modify doc * docs: add fn doc --- colossalai/moe/routers.py | 99 +++++++++++++------ .../openmoe/model/modeling_openmoe.py | 36 ++++++- 2 files changed, 106 insertions(+), 29 deletions(-) diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py index 688471530758..6fa89a416203 100644 --- a/colossalai/moe/routers.py +++ b/colossalai/moe/routers.py @@ -49,19 +49,60 @@ def get_capacity(self, logits_shape): assert capacity > 0 return capacity - def set_aux_loss(self, logits: torch.Tensor, cmask: torch.Tensor, num_experts: int) -> None: + def set_aux_loss(self, + router_probs: torch.Tensor, + expert_indices: torch.Tensor, + num_experts: int + ) -> None: + """Computes auxiliary load balancing loss as in Switch Transformer. + + See Switch Transformer (https://arxiv.org/abs/2101.03961). This function + implements the loss function presented in equations (4) - (6). It aims to + penalize those cases where the routing between experts is unbalanced. + + Args: + router_probs: Probability assigned to each expert per token. Shape: + [num_groups, tokens_per_group, num_experts]. 
+            expert_indices: [num_groups, tokens_per_group, num_selected_experts]
+                indices identifying the top num_selected_experts for a given token.
+        """
         assert self._aux_loss is None
-        me = torch.mean(logits, dim=0)
-        ce = torch.mean(cmask.float(), dim=0)
-        aux_loss = num_experts * torch.sum(me * ce)
+        if router_probs.dim() == expert_indices.dim() == 2:
+            router_probs = router_probs.unsqueeze(0)
+            expert_indices = expert_indices.unsqueeze(0)
+        assert router_probs.dim() == expert_indices.dim() == 3, \
+            "router_probs and expert_indices must be 3D tensors"
+
+        # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
+        expert_mask = F.one_hot(expert_indices, num_experts)
+        # For a given token, determine if it was routed to a given expert.
+        # Shape: [num_groups, tokens_per_group, num_experts]
+        expert_mask = expert_mask.max(dim=-2)[0]
+
+        tokens_per_group_and_expert = torch.mean(expert_mask.float(), dim=-2)
+        router_prob_per_group_and_expert = torch.mean(router_probs.float(), dim=-2)
+        aux_loss = num_experts**2 * torch.mean(
+            tokens_per_group_and_expert * router_prob_per_group_and_expert)
         self._aux_loss = aux_loss

     def set_z_loss(self, router_logits: torch.Tensor):
+        """Compute router z-loss.
+
+        The router z-loss was introduced in Designing Effective Sparse Expert Models
+        (https://arxiv.org/abs/2202.08906). It encourages router logits to remain
+        small in an effort to improve stability.
+
+        Args:
+            router_logits: [num_groups, tokens_per_group, num_experts] router logits.
+        """
         assert self._z_loss is None
-        n, _ = router_logits.shape
-        log_z = torch.logsumexp(router_logits, axis=-1)
-        z_loss = log_z**2
-        z_loss = torch.sum(z_loss, dtype=torch.float32) / n
+        if router_logits.dim() == 2:
+            router_logits = router_logits.unsqueeze(0)
+        assert router_logits.dim() == 3, "router_logits must be a 3D tensor"
+        num_groups, tokens_per_group, _ = router_logits.shape
+        log_z = torch.logsumexp(router_logits, dim=-1)
+        z_loss = torch.sum(log_z**2, dtype=torch.float32
+                           ) / (num_groups * tokens_per_group)
         self._z_loss = z_loss

     def pop_router_loss(self) -> torch.Tensor:
@@ -126,15 +167,15 @@ def forward(self,
             inputs = self.noisy_func(inputs)

         assert inputs.dtype == torch.float
-        logits = F.softmax(inputs, dim=-1)
-        num_experts = logits.size(-1)
-        capacity = self.get_capacity(logits.shape)
+        probs = F.softmax(inputs, dim=-1)
+        num_experts = probs.size(-1)
+        capacity = self.get_capacity(inputs.shape)

         top1_idx = torch.argmax(inputs, dim=-1)
         mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)

         # calculate router loss
-        self.set_aux_loss(logits, mask, num_experts)
+        self.set_aux_loss(probs, top1_idx.unsqueeze(-1), num_experts)
         self.set_z_loss(inputs)
         self.pop_router_loss()

@@ -160,10 +201,10 @@ def forward(self,
             mask = torch.sum(mask, dim=-1)
             mask = torch.stack([mask], dim=0).to(torch.int32)
             dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
-            return logits, mask, dest_idx, num_experts * capacity
+            return probs, mask, dest_idx, num_experts * capacity
         else:
             ranks = F.one_hot(ranks, num_classes=capacity)
-            weight = mask * logits.type_as(inputs)
+            weight = mask * probs.type_as(inputs)
             combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
             sec_mask = combine_weights.bool()
             return combine_weights, sec_mask
@@ -215,13 +256,13 @@ def forward(self,
             inputs = self.noisy_func(inputs)

         assert inputs.dtype == torch.float
-        logits = F.softmax(inputs, dim=-1)    # logits: [s, e]
-        num_experts = logits.size(-1)
-        capacity = self.get_capacity(logits.shape)
+        probs
= F.softmax(inputs, dim=-1) + num_experts = probs.size(-1) + capacity = self.get_capacity(inputs.shape) - top1_idx = torch.argmax(logits, dim=-1) + top1_idx = torch.argmax(probs, dim=-1) mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) - logits_except1 = logits.masked_fill(mask1.bool(), float("-inf")) + logits_except1 = probs.masked_fill(mask1.bool(), float("-inf")) top2_idx = torch.argmax(logits_except1, dim=-1) mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) @@ -229,7 +270,8 @@ def forward(self, cmask = cmask.float() / 2.0 # div 2 to normalize it to 1 # caculate loss - self.set_aux_loss(logits, cmask, num_experts) + expert_indices = torch.stack([top1_idx, top2_idx], dim=-1) + self.set_aux_loss(probs, expert_indices, num_experts) self.set_z_loss(inputs) self.pop_router_loss() @@ -255,10 +297,10 @@ def forward(self, mask = torch.stack([mask1, mask2], dim=0).to(torch.int32) dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32) - return logits, mask, dest_idx, num_experts * capacity + return probs, mask, dest_idx, num_experts * capacity else: - weight1 = mask1 * logits.type_as(inputs) - weight2 = mask2 * logits.type_as(inputs) + weight1 = mask1 * probs.type_as(inputs) + weight2 = mask2 * probs.type_as(inputs) rank1_sc = F.one_hot(rank1, num_classes=capacity) rank2_sc = F.one_hot(rank2, num_classes=capacity) @@ -282,9 +324,9 @@ class TopKRouter(MoeRouter): processed by an expert, or that each expert receives at least one token. Attributes: - num_selected_experts: Maximum number of experts to which each token is - routed. Tokens may be routed to fewer experts if particular experts are - oversubscribed / reach capacity. + num_selected_experts: Maximum number of experts to which each token is + routed. Tokens may be routed to fewer experts if particular experts are + oversubscribed / reach capacity. """ def __init__(self, @@ -314,14 +356,15 @@ def forward(self, Returns: Dispatch and combine arrays for routing with masked matmuls. """ + # TODO: add parallel group num_groups, _, num_experts = router_probs.shape # Top-k router probability and corresponding expert indices for each token. # Shape: [num_groups, tokens_per_group, num_selected_experts]. expert_gate, expert_index = torch.topk(router_probs, self.k_value) - # TODO - # auxiliary_loss = _load_balancing_loss(router_probs, expert_index) + self.set_aux_loss(router_probs, expert_index, num_experts) + self.pop_router_loss() # Make num_selected_experts the leading axis to ensure that top-1 choices # have priority over top-2 choices, which have priority over top-3 choices, diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index ec7e1e8941f7..cf9c5013cc29 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -1020,7 +1020,19 @@ def _calculate_router_loss(self): z_loss = self.config.router_z_loss_factor * sum(z_loss) / len(z_loss) return aux_loss, z_loss - def _calculate_loss(self, logits, targets): + def _calculate_loss(self, + logits: torch.Tensor, + targets: torch.Tensor + ) -> torch.Tensor: + """Compute cross entropy and entropy for log probs and targets. + + Args: + logits: [batch, length, num_classes] float array. + targets: categorical targets [batch, length] int array. + + Returns: + Tuple of scalar loss. + """ if len(logits.shape) != len(targets.shape) + 1: raise ValueError('Incorrect shapes. 
Got shape %s logits and %s targets' % (str(logits.shape), str(targets.shape))) @@ -1045,6 +1057,28 @@ def _calculate_loss(self, logits, targets): class ZLossCrossEntropy(torch.autograd.Function): + """Computes cross entropy loss with stable custom gradient. + + Computes a stabilized-gradient version of: + -jnp.sum(targets * nn.log_softmax(logits), axis=-1) + + If z_loss > 0, then an auxiliary loss equal to z_loss*log(z)^2 + will be added to the cross entropy loss (z = softmax normalization constant). + The two uses of z_loss are: + 1. To keep the logits from drifting too far from zero, which can cause + unacceptable roundoff errors in bfloat16. + 2. To encourage the logits to be normalized log-probabilities. + + Args: + logits: [batch, length, num_classes] float array. + targets: categorical one-hot targets [batch, length, num_classes] float + array. + z_loss: coefficient for auxilliary z-loss loss term. + + Returns: + tuple with the total loss and the z_loss, both + float arrays with shape [batch, length]. + """ @staticmethod def forward(ctx, logits, targets, z_loss): From fde57bf17fd4c1fc5c1d54df18c1659d6937d97d Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com> Date: Thu, 14 Sep 2023 11:58:36 +0800 Subject: [PATCH 11/46] [moe] speed up embed and mlp (#4701) * update triton * update kernel * add init * add version check * update precision * update precision * update kernel in experts * update test arg * update settings --- .../kernel/triton/llama_act_combine_kernel.py | 185 ++++++++++++++++++ colossalai/moe/_operation.py | 32 +-- colossalai/moe/experts.py | 11 +- colossalai/moe/layers.py | 4 +- examples/language/openmoe/infer.py | 6 +- .../openmoe/model/modeling_openmoe.py | 33 ++-- .../triton/test_llama_act_combine.py | 56 ++++++ 7 files changed, 281 insertions(+), 46 deletions(-) create mode 100644 colossalai/kernel/triton/llama_act_combine_kernel.py create mode 100644 tests/test_infer_ops/triton/test_llama_act_combine.py diff --git a/colossalai/kernel/triton/llama_act_combine_kernel.py b/colossalai/kernel/triton/llama_act_combine_kernel.py new file mode 100644 index 000000000000..45996c0dca53 --- /dev/null +++ b/colossalai/kernel/triton/llama_act_combine_kernel.py @@ -0,0 +1,185 @@ +from functools import reduce +from typing import Any, Tuple + +import torch +from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +if HAS_TRITON: + PRECISION_MAP = { + "fp32": (0, torch.float32), + "fp16": (1, torch.float16), + "bf16": (2, torch.bfloat16), + } + + @triton.jit + def _llama_act_combine_forward( + X_GATE1, + X_GATE2, + X_UP, + Y, + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + BLOCK_SIZE: tl.constexpr, + ): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + X_GATE1 += row * stride + X_GATE2 += row * stride + X_UP += row * stride + Y += row * stride + + # do activation and combine, and store in y + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.) + x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.) + x_up = tl.load(X_UP + cols, mask=mask, other=0.) 
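The arithmetic that follows fuses the SwiGLU activation and the gating product into a single pass over each tile. As an unfused reference for the same per-element math, here is a plain-PyTorch sketch; it assumes the (b, l, 2d) gate and (b, l, d) up layout this kernel expects, and the function name is ours, not the library's:

    import torch

    def llama_act_combine_ref(x_gate: torch.Tensor, x_up: torch.Tensor) -> torch.Tensor:
        # Unfused equivalent of the kernel: y = x_gate1 * x_gate2 * sigmoid(x_gate2) * x_up
        x_gate1, x_gate2 = torch.chunk(x_gate, 2, dim=-1)
        x_gate2_sigmoid = torch.sigmoid(x_gate2.to(torch.float32)).to(x_gate2.dtype)
        return x_gate1 * x_gate2 * x_gate2_sigmoid * x_up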
+ x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype) + y = x_gate1 * x_gate2 * x_gate2_sigmoid * x_up + # Write output + tl.store(Y + cols, y, mask=mask) + + @triton.jit + def _llama_act_combine_backward( + X_GATE1, + X_GATE2, + X_UP, + X_GATE1_GRAD, + X_GATE2_GRAD, + X_UP_GRAD, + Y_GRAD, + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + BLOCK_SIZE: tl.constexpr, + ): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + X_GATE1 += row * stride + X_GATE2 += row * stride + X_UP += row * stride + X_GATE1_GRAD += row * stride + X_GATE2_GRAD += row * stride + X_UP_GRAD += row * stride + Y_GRAD += row * stride + + # do activation and combine, and store in y + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.) + x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.) + x_up = tl.load(X_UP + cols, mask=mask, other=0.) + y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.) + + # forward: y = x_gate1 * x_gate2 * tl.sigmoid(x_gate2) * x_up + x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype) + x_gate2_act = y_grad * x_gate2 * x_gate2_sigmoid + x_up_grad = x_gate2_act * x_gate1 + x_gate1_grad = x_gate2_act * x_up + # grad(x*sigmoid(x)) = sigmoid(x) + x * sigmoid(x) * [1 − sigmoid(x)] + # = sigmoid(x) * {1 + x * [(1 − sigmoid(x)]} + x_gate2_grad = (y_grad * x_gate1 * x_up) * x_gate2_sigmoid * (1 + x_gate2 * (1 - x_gate2_sigmoid)) + + # Write output + tl.store(X_GATE1_GRAD + cols, x_gate1_grad, mask=mask) + tl.store(X_GATE2_GRAD + cols, x_gate2_grad, mask=mask) + tl.store(X_UP_GRAD + cols, x_up_grad, mask=mask) + + class LlamaActCombine(torch.autograd.Function): + """ + act(x_gate) * x_up + + Args: + x_gate (torch.Tensor): (b, l, 2d) x_gate + x_up (torch.Tensor): (b, l, d) x_up + activation (str): only support swiglu + precision (str): fp32, fp16, bf16 + """ + + @staticmethod + @custom_fwd + def forward(ctx: Any, x_gate: torch.Tensor, x_up: torch.Tensor, activation: str = "swiglu") -> torch.Tensor: + """ + act(x_gate) * x_up + + Args: + x_gate (torch.Tensor): (b, l, 2d) x gate + x_up (torch.Tensor): (b, l, d) x up + activation (str): only support swiglu + """ + assert activation == "swiglu", "Only swiglu is supported" + + # split x gate + assert x_gate.shape[-1] % 2 == 0, "axis size must be divisible by 2" + x_gate1, x_gate2 = torch.split(x_gate, x_gate.shape[-1] // 2, -1) + x_gate1 = x_gate1.contiguous() + x_gate2 = x_gate2.contiguous() + if not x_up.is_contiguous(): + x_up = x_up.contiguous() + # assert shape + assert x_gate1.shape == x_gate2.shape == x_up.shape + + # add ctx for backward + if x_gate.requires_grad: + ctx.save_for_backward(x_gate1, x_gate2, x_up) + + # allocate output + y = torch.empty_like(x_up) + M, N = reduce(lambda x, y: x * y, x_up.shape[:-1]), x_up.shape[-1] + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x_gate.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_SIZE: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + # restore setting + ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps = M, N, BLOCK_SIZE, num_warps + # enqueue kernel + _llama_act_combine_forward[(M,)](x_gate1, + x_gate2, + x_up, + y, + x_up.stride(-2), + N, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps) + return 
y + + @staticmethod + @custom_bwd + def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, Tensor, None, None]: + # restore from ctx + (x_gate1, x_gate2, x_up) = ctx.saved_tensors + M, N, BLOCK_SIZE, num_warps = ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps + + # init grad + y_grad = grad_outputs[0] + x_gate1_grad, x_gate2_grad, x_up_grad = torch.empty_like(x_gate1), torch.empty_like( + x_gate2), torch.empty_like(x_up) + + # enqueue kernel + _llama_act_combine_backward[(M,)](x_gate1, + x_gate2, + x_up, + x_gate1_grad, + x_gate2_grad, + x_up_grad, + y_grad, + x_up.stride(-2), + N, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps) + x_gate_grad = torch.cat([x_gate1_grad, x_gate2_grad], dim=-1) + return x_gate_grad, x_up_grad, None, None diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index 01530bb55c20..a67feaefbfb8 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -5,32 +5,16 @@ from torch import Tensor from torch.distributed import ProcessGroup -COL_MOE_KERNEL_FLAG = False - try: from colossalai._C import moe except: - moe = None - - -def build_moe_if_not_prebuilt(): - # load moe kernel during runtime if not pre-built - global moe - if moe is None: - from colossalai.kernel.op_builder import MOEBuilder - - moe = MOEBuilder().load() + from colossalai.kernel.op_builder import MOEBuilder + moe = MOEBuilder().load() class AllGather(torch.autograd.Function): @staticmethod def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: - global moe - - if moe is None: - from colossalai.kernel.op_builder import MOEBuilder - - moe = MOEBuilder().load() if ctx is not None: ctx.comm_grp = group @@ -102,9 +86,6 @@ def forward(ctx, tokens, mask, dest_idx, ec): s = tokens.size(0) h = tokens.size(1) - # load moe kernel during runtime if not pre-built - build_moe_if_not_prebuilt() - expert_input = moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx) ctx.save_for_backward(mask, dest_idx) @@ -131,10 +112,7 @@ def forward(ctx, expert_tokens, logits, mask, dest_idx, ec): c = ec // e h = expert_tokens.size(-1) - # load moe kernel during runtime if not pre-built - build_moe_if_not_prebuilt() - - fp16_flag = expert_tokens.dtype == torch.float16 + fp16_flag = (expert_tokens.dtype == torch.float16) cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens ctokens = moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx) output = ctokens.to(torch.float16) if fp16_flag else ctokens @@ -163,9 +141,7 @@ def backward(ctx, tokens_grad): def moe_cumsum(inputs: Tensor): dim0 = inputs.size(0) flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0) - if flag and COL_MOE_KERNEL_FLAG: - # load moe kernel during runtime if not pre-built - build_moe_if_not_prebuilt() + if flag: return moe.cumsum_sub_one(inputs) else: return torch.cumsum(inputs, dim=0) - 1 diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py index 9715f4dc37b3..4535d8ab9a85 100644 --- a/colossalai/moe/experts.py +++ b/colossalai/moe/experts.py @@ -4,12 +4,17 @@ import torch import torch.nn as nn + +from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import get_activation from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.moe_tensor.api import get_ep_size, set_moe_tensor_info +if HAS_TRITON: + from 
colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine + class BaseMLPExperts(nn.Module): """ @@ -78,6 +83,7 @@ def __init__( nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size)) nn.init.trunc_normal_(self.wo, std=math.sqrt(0.1 / intermediate_size)) + self.act_name = activation self.act = get_activation(activation) self.drop = nn.Dropout(p=drop_rate) @@ -103,7 +109,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.reshape(e, -1, h) if self.gated: - x = self.act(torch.bmm(x, self.wi_gate)) * torch.bmm(x, self.wi_up) + if HAS_TRITON and self.act_name == "swiglu": + x = LlamaActCombine.apply(torch.bmm(x, self.wi_gate), torch.bmm(x, self.wi_up)) + else: + x = self.act(torch.bmm(x, self.wi_gate)) * torch.bmm(x, self.wi_up) else: x = torch.bmm(x, self.wi) x = self.act(x) diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index a3f68cf7a6f1..1255a4816041 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -5,7 +5,7 @@ import torch.nn as nn import torch.nn.functional as F -from colossalai.moe._operation import COL_MOE_KERNEL_FLAG, AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter +from colossalai.moe._operation import AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter from colossalai.moe.experts import BaseMLPExperts, get_expert_class from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.routers import MoeRouter, get_router_cls @@ -58,7 +58,7 @@ def __init__(self, super().__init__() self.hidden_size = hidden_size self.num_experts = num_experts - self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_MANAGER.use_kernel_optim else False + self.use_kernel = True if MOE_MANAGER.use_kernel_optim else False self.expert_parallel = expert_parallel assert expert_parallel in ["EP", "TP", None], f"Unsupported expert parallel type {expert_parallel}" diff --git a/examples/language/openmoe/infer.py b/examples/language/openmoe/infer.py index b41fa2f2e4f1..f59772189827 100644 --- a/examples/language/openmoe/infer.py +++ b/examples/language/openmoe/infer.py @@ -20,7 +20,7 @@ def inference(args): model = OpenMoeForCausalLM(config) else: model = OpenMoeForCausalLM.from_pretrained(f"hpcaitech/openmoe-{args.model}") - model = model.eval().bfloat16() + model = model.eval().half() model = model.to(torch.cuda.current_device()) input_str = """``` @@ -37,9 +37,9 @@ def inference(args): What is the value of sum immediately after the 10th time line 3 is executed?""" # print("model config: ", model.config) - input_ids = tokenizer("" + input_str, return_tensors="pt", add_special_tokens=True) + input_ids = tokenizer("" + input_str, return_tensors="pt", add_special_tokens=False) input_ids = input_ids.input_ids.to(torch.cuda.current_device()) - generation_output = model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=128) + generation_output = model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=16) out = tokenizer.decode(generation_output[0], skip_special_tokens=False) print(f"output: \n{out}\n") diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index cf9c5013cc29..6ccbf64a60e4 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -36,9 +36,13 @@ replace_return_docstrings, ) +from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON from colossalai.moe.layers import SparseMLP from colossalai.moe.manager import MOE_MANAGER +if 
HAS_TRITON: + from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "LlamaConfig" @@ -95,7 +99,7 @@ def generate_fixed_pos_embedding(features, length, min_timescale=1.0, max_timesc timescale = min_timescale * (max_timescale / min_timescale)**fraction rotational_frequency = 1. / timescale - sinusoid_inp = torch.einsum('i,j->ij', torch.arange(length, dtype=torch.float64).cuda(), rotational_frequency) + sinusoid_inp = torch.einsum('i,j->ij', torch.arange(length, dtype=torch.float32).cuda(), rotational_frequency) sinusoid_inp = torch.cat([sinusoid_inp, sinusoid_inp], dim=-1) @@ -121,16 +125,16 @@ def apply_rotary_embedding(q, k, cos, sin, decode=False, rotary_index=None): if decode and qlen == 1 and rotary_index is not None: qcos = cos[rotary_index + 1, :] qsin = sin[rotary_index + 1, :] - qcos = qcos.unsqueeze(2).expand(batch, qlen, qheads, d) - qsin = qsin.unsqueeze(2).expand(batch, qlen, qheads, d) + qcos = qcos.unsqueeze(2) + qsin = qsin.unsqueeze(2) + kcos, ksin = cos[:klen, :], sin[:klen, :] + kcos = kcos.unsqueeze(0).unsqueeze(2) + ksin = ksin.unsqueeze(0).unsqueeze(2) else: qcos, qsin = cos[:qlen, :], sin[:qlen, :] - qcos = qcos.unsqueeze(0).unsqueeze(2).expand(batch, qlen, qheads, d) - qsin = qsin.unsqueeze(0).unsqueeze(2).expand(batch, qlen, qheads, d) - - kcos, ksin = cos[:klen, :], sin[:klen, :] - kcos = kcos.unsqueeze(0).unsqueeze(2).expand(batch, klen, kheads, d) - ksin = ksin.unsqueeze(0).unsqueeze(2).expand(batch, klen, kheads, d) + qcos = qcos.unsqueeze(0).unsqueeze(2) + qsin = qsin.unsqueeze(0).unsqueeze(2) + kcos, ksin = qcos, qsin out_q = (q * qcos) + (rotate_half(q) * qsin) out_k = (k * kcos) + (rotate_half(k) * ksin) @@ -278,7 +282,10 @@ def forward(self, x): down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)] down_proj = sum(down_proj) else: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + if HAS_TRITON: + down_proj = self.down_proj(LlamaActCombine.apply(self.gate_proj(x), self.up_proj(x))) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj @@ -313,6 +320,7 @@ def __init__(self, config: LlamaConfig): self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.sin, self.cos = generate_fixed_pos_embedding(self.head_dim, self.max_position_embeddings, 1e4) self._init_rope() def _init_rope(self): @@ -382,9 +390,10 @@ def forward( query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) - dim = query_states.shape[-1] max_length = max(query_states.shape[1], key_states.shape[1]) - sin, cos = generate_fixed_pos_embedding(dim, max_length, max_timescale=1e4) + assert max_length <= self.sin.shape[0] + sin, cos = self.sin[:max_length], self.cos[:max_length] + # TODO: for inference, we can add emb kv into cache to avoid computation query_states, key_states = apply_rotary_embedding(query_states, key_states, cos, diff --git a/tests/test_infer_ops/triton/test_llama_act_combine.py b/tests/test_infer_ops/triton/test_llama_act_combine.py new file mode 100644 index 000000000000..5341aa35ab90 --- /dev/null +++ b/tests/test_infer_ops/triton/test_llama_act_combine.py @@ -0,0 +1,56 @@ +import pytest +import torch +from packaging 
import version +from torch import nn + +from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine + +try: + import triton + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') + +BATCH_SIZE = 4 +SEQ_LEN = 16 +HIDDEN_SIZE = 32 + + +def SwiGLU(x): + """Gated linear unit activation function. + Args: + x : input array + axis: the axis along which the split should be computed (default: -1) + """ + size = x.shape[-1] + assert size % 2 == 0, "axis size must be divisible by 2" + x1, x2 = torch.split(x, size // 2, -1) + return x1 * (x2 * torch.sigmoid(x2.to(torch.float32)).to(x.dtype)) + + +@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton") +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) +def test_llama_act_combine(dtype: str): + x_gate = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE * 2, dtype=dtype).cuda() + x_gate_torch = nn.Parameter(x_gate.detach().clone()) + x_gate_kernel = nn.Parameter(x_gate.detach().clone()) + x_up = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, dtype=dtype).cuda() + x_up_torch = nn.Parameter(x_up.detach().clone()) + x_up_kernel = nn.Parameter(x_up.detach().clone()) + + torch_out = SwiGLU(x_gate_torch) * x_up_torch + kernel_out = LlamaActCombine.apply(x_gate_kernel, x_up_kernel) + atol = 1e-5 if dtype == torch.float32 else 5e-2 + assert torch.allclose(torch_out, kernel_out, atol=atol) + + torch_out.mean().backward() + kernel_out.mean().backward() + assert all(grad is not None for grad in [x_gate_torch.grad, x_up_torch.grad, x_gate_kernel.grad, x_up_kernel.grad]) + assert torch.allclose(x_gate_torch.grad, x_gate_kernel.grad, atol=atol) + assert torch.allclose(x_up_torch.grad, x_up_kernel.grad, atol=atol) + + +if __name__ == '__main__': + test_llama_act_combine(torch.float16) From adb8ebee28eed5191ebae2e90341bac6250bf0fc Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 15 Sep 2023 16:53:51 +0800 Subject: [PATCH 12/46] [moe] adapt to main modifications --- .../gradient_handler/_moe_gradient_handler.py | 2 +- tests/test_moe/moe_utils.py | 8 +-- tests/test_moe/test_grad_handler.py | 2 +- tests/test_moe/test_moe_ep_tp.py | 2 +- tests/test_moe/test_moe_local.py | 2 +- tests/test_moe/test_moe_router.py | 23 +++--- tests/test_moe/test_moe_zero_model.py | 70 ------------------- 7 files changed, 16 insertions(+), 93 deletions(-) delete mode 100644 tests/test_moe/test_moe_zero_model.py diff --git a/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py index 6a7224cff7bd..2c999ca77be7 100644 --- a/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py +++ b/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py @@ -2,7 +2,7 @@ from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.registry import GRADIENT_HANDLER -from colossalai.utils.moe import get_moe_epsize_param_dict +from colossalai.moe.utils import get_moe_epsize_param_dict from ._base_gradient_handler import BaseGradientHandler from .utils import bucket_allreduce diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py index 3371c35fd295..53266beb1877 100644 --- a/tests/test_moe/moe_utils.py +++ b/tests/test_moe/moe_utils.py @@ -2,13 +2,13 @@ import torch.distributed as dist import torch.nn 
as nn -from colossalai.engine.gradient_handler._base_gradient_handler import BaseGradientHandler -from colossalai.engine.gradient_handler.utils import bucket_allreduce +from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler +from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce +from colossalai.legacy.nn import CheckpointModule +from colossalai.legacy.registry import GRADIENT_HANDLER from colossalai.moe import SparseMLP from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import get_moe_epsize_param_dict -from colossalai.nn import CheckpointModule -from colossalai.registry import GRADIENT_HANDLER from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py index cbfbcae6ce33..e3de8f101a74 100644 --- a/tests/test_moe/test_grad_handler.py +++ b/tests/test_moe/test_grad_handler.py @@ -12,7 +12,7 @@ from tests.test_moe.moe_utils import MoeGradientHandler, assert_not_equal_in_group BATCH_SIZE = 4 -DIM = 4 +DIM = 16 def run_test(rank, world_size, port): diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py index 253fe6a7c094..72b639c8b43a 100644 --- a/tests/test_moe/test_moe_ep_tp.py +++ b/tests/test_moe/test_moe_ep_tp.py @@ -11,7 +11,7 @@ from tests.test_moe.moe_utils import MoeGradientHandler, sync_tp_from_ep BATCH_SIZE = 4 -DIM = 4 +DIM = 16 def run_test(rank, world_size, port): diff --git a/tests/test_moe/test_moe_local.py b/tests/test_moe/test_moe_local.py index 872b65c2d1f1..09cc0cc6a4ef 100644 --- a/tests/test_moe/test_moe_local.py +++ b/tests/test_moe/test_moe_local.py @@ -11,7 +11,7 @@ from tests.test_moe.moe_utils import MoeGradientHandler, sync_local_from_ep BATCH_SIZE = 4 -DIM = 4 +DIM = 16 def run_test(rank, world_size, port): diff --git a/tests/test_moe/test_moe_router.py b/tests/test_moe/test_moe_router.py index 94c263baa5a3..fce0d1064950 100644 --- a/tests/test_moe/test_moe_router.py +++ b/tests/test_moe/test_moe_router.py @@ -1,7 +1,7 @@ import pytest import torch -from colossalai.moe.routers import (MoeRouter, Top1Router, Top2Router, - TopKRouter, get_router_cls) + +from colossalai.moe.routers import MoeRouter, Top1Router, Top2Router, TopKRouter @pytest.mark.parametrize(["router", "num_groups"], [ @@ -9,19 +9,12 @@ (Top2Router(), 1), (TopKRouter(num_selected_experts=3), 4), ]) -@pytest.mark.parametrize( - ["batch_size", "seq_len", "num_experts"], - [ - (4, 5, 8), - (3, 4, 4), - ] -) -def test_router_forward(router: MoeRouter, - batch_size: int, - seq_len: int, - num_experts: int, - num_groups: int): - x = torch.randn((batch_size * seq_len, num_experts)) +@pytest.mark.parametrize(["batch_size", "seq_len", "num_experts"], [ + (4, 5, 8), + (3, 4, 4), +]) +def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_experts: int, num_groups: int): + x = torch.randn((batch_size * seq_len, num_experts)).cuda() if num_groups > 1: x = x.expand(num_groups, -1, -1) diff --git a/tests/test_moe/test_moe_zero_model.py b/tests/test_moe/test_moe_zero_model.py deleted file mode 100644 index 724d70d77bc6..000000000000 --- a/tests/test_moe/test_moe_zero_model.py +++ /dev/null @@ -1,70 +0,0 @@ -import pytest -import torch - -import colossalai -from colossalai.context import MOE_CONTEXT -from colossalai.legacy.engine.gradient_handler import MoeGradientHandler -from colossalai.nn import MoeLoss -from colossalai.testing import assert_equal_in_group, 
parameterize, rerun_if_address_is_in_use, spawn -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy -from colossalai.zero.legacy.sharded_model import ShardedModelV2 -from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16 -from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy -from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_moe.test_moe_zero_init import MoeModel -from tests.test_zero.test_legacy.common import CONFIG, check_grads_padding, run_fwd_bwd - - -@parameterize("enable_autocast", [False]) -@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) -def run_model_test(enable_autocast, shard_strategy_class): - shard_strategy = shard_strategy_class() - - get_components_func = non_distributed_component_funcs.get_callable("hanging_param_model") - _, train_dataloader, _, optimizer_class, _ = get_components_func() - criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss) - - with ZeroInitContext( - target_device=torch.device("cuda", torch.cuda.current_device()), shard_strategy=shard_strategy, shard_param=True - ): - zero_model = MoeModel(checkpoint=True) - zero_model = ShardedModelV2(zero_model, shard_strategy) - - # check whether parameters are identical in ddp - for name, p in zero_model.named_parameters(): - if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated: - assert_equal_in_group(p.colo_attr.data_payload) - - model = MoeModel(checkpoint=True).half() - col_model_deepcopy(zero_model, model) - model = model.cuda() - grad_handler = MoeGradientHandler(model) - - for i, (data, label) in enumerate(train_dataloader): - if i > 5: - break - - data, label = cast_tensor_to_fp16(data).cuda(), label.cuda() - run_fwd_bwd(model, data, label, criterion, enable_autocast) - run_fwd_bwd(zero_model, data, label, criterion, enable_autocast) - grad_handler.handle_gradient() - - check_grads_padding(model, zero_model, loose=True) - - -def run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - MOE_CONTEXT.setup(seed=42) - run_model_test() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [2]) -@rerun_if_address_is_in_use() -def test_moe_zero_model(world_size): - spawn(run_dist, world_size) - - -if __name__ == "__main__": - test_moe_zero_model(world_size=2) From 3f02e57dfa7fa9c7f191a7446d929fb7134a2b5a Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Mon, 18 Sep 2023 10:39:43 +0800 Subject: [PATCH 13/46] [moe]: add flash attention & optimize top2 router (#4712) * feat: add benchmark train * perf: use flash_attn * fix: modify benchmark config * fix: check flash attn installation * fix: update config with args * perf: optimize top2 router --- colossalai/moe/routers.py | 26 ++- .../openmoe/benchmark/benchmark_train.py | 196 ++++++++++++++++++ .../openmoe/benchmark/benchmark_train.sh | 34 +++ examples/language/openmoe/benchmark/utils.py | 61 ++++++ .../openmoe/model/modeling_openmoe.py | 72 ++++--- 5 files changed, 352 insertions(+), 37 deletions(-) create mode 100644 examples/language/openmoe/benchmark/benchmark_train.py create mode 100755 examples/language/openmoe/benchmark/benchmark_train.sh create mode 100644 examples/language/openmoe/benchmark/utils.py diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py index 6fa89a416203..1ac66f7bb78f 100644 
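The router hunks that follow replace the one-hot outer products previously used to build the Top-2 combine weights with direct indexed writes, which avoids materializing two dense [s, e, c] intermediates. A minimal single-choice sketch of why the two formulations agree; all sizes and names here are illustrative, not the library's:

    import torch
    import torch.nn.functional as F

    s, e, c = 4, 2, 4                        # tokens, experts, capacity
    top_idx = torch.randint(0, e, (s,))      # each token's chosen expert
    mask = F.one_hot(top_idx, e)             # [s, e] routing mask
    idx = torch.arange(s)
    rank = (torch.cumsum(mask, dim=0) - 1)[idx, top_idx]    # slot in the expert's buffer
    weight = mask.float() * torch.rand(s, e)                # probs masked to the chosen expert

    # One-hot formulation: builds a dense [s, e, c] outer product.
    rank_sc = F.one_hot(rank, num_classes=c)
    cb_onehot = weight.unsqueeze(2) * rank_sc.unsqueeze(1)

    # Indexed formulation: writes only the s nonzero entries directly.
    cb_index = torch.zeros(s, e, c)
    cb_index[idx, top_idx, rank] = weight[idx, top_idx]

    assert torch.allclose(cb_onehot, cb_index)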
--- a/colossalai/moe/routers.py +++ b/colossalai/moe/routers.py @@ -47,7 +47,7 @@ def get_capacity(self, logits_shape): capacity += capacity % 2 capacity = max(capacity, self.min_capacity) assert capacity > 0 - return capacity + return int(capacity) def set_aux_loss(self, router_probs: torch.Tensor, @@ -299,15 +299,27 @@ def forward(self, return probs, mask, dest_idx, num_experts * capacity else: + # >>> original code + # weight1 = mask1 * probs.type_as(inputs) + # weight2 = mask2 * probs.type_as(inputs) + # rank1_sc = F.one_hot(rank1, num_classes=capacity) + # rank2_sc = F.one_hot(rank2, num_classes=capacity) + + # cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1) + # cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1) + # cb_weight = cb_weight1 + cb_weight2 + # sec_mask = cb_weight.bool() + weight1 = mask1 * probs.type_as(inputs) weight2 = mask2 * probs.type_as(inputs) - rank1_sc = F.one_hot(rank1, num_classes=capacity) - rank2_sc = F.one_hot(rank2, num_classes=capacity) - cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1) - cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1) - cb_weight = cb_weight1 + cb_weight2 - sec_mask = cb_weight.bool() + cb_weight = torch.zeros(inputs.shape + (capacity, ), device=inputs.device) + sec_mask = torch.zeros_like(cb_weight, dtype=torch.bool) + indices = torch.arange(0, inputs.shape[0], device=inputs.device) + cb_weight[indices, top1_idx[indices], rank1[indices]] += weight1[indices, top1_idx[indices]] + cb_weight[indices, top2_idx[indices], rank2[indices]] += weight2[indices, top2_idx[indices]] + sec_mask[indices, top1_idx[indices], rank1[indices]] |= mask1.bool()[indices, top1_idx[indices]] + sec_mask[indices, top2_idx[indices], rank2[indices]] |= mask2.bool()[indices, top2_idx[indices]] return cb_weight, sec_mask diff --git a/examples/language/openmoe/benchmark/benchmark_train.py b/examples/language/openmoe/benchmark/benchmark_train.py new file mode 100644 index 000000000000..373516c56f84 --- /dev/null +++ b/examples/language/openmoe/benchmark/benchmark_train.py @@ -0,0 +1,196 @@ +import colossalai +import datasets +import torch +import transformers +from colossalai import get_default_parser +from colossalai.booster import Booster +from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.cluster import DistCoordinator +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.moe.manager import MOE_MANAGER +from colossalai.moe.utils import skip_init +from colossalai.utils import get_current_device +from model.modeling_openmoe import OpenMoeForCausalLM +from torch.utils.data import Dataset +from tqdm import tqdm +from transformers import Adafactor +from transformers.models.llama import LlamaConfig +from utils import SimpleTimer, print_model_numel + + +class RandomDataset(Dataset): + + def __init__(self, + num_samples: int = 1000, + max_length: int = 2048, + vocab_size: int = 32000): + self.num_samples = num_samples + self.max_length = max_length + self.input_ids = torch.randint(0, vocab_size, + (num_samples, max_length), + device=get_current_device()) + self.attention_mask = torch.ones_like(self.input_ids, + device=get_current_device()) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + return { + 'input_ids': self.input_ids[idx], + 'attention_mask': self.attention_mask[idx], + 'labels': self.input_ids[idx] + } + + +def parse_args(): + parser = get_default_parser() + # TODO: add model_name + # parser.add_argument("--model_name", type=str, 
default="base", choices=["base", "8b"], + # help="Path to pretrained model or model identifier from huggingface.co/models.") + parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.") + parser.add_argument("--batch_size", type=int, default=4, help="Batch size (per dp group) for the training dataloader.") + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + parser.add_argument("--num_samples", type=int, default=1000, help="Number of samples in the dataset.") + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + MDOEL_CONFIG = { + "architectures": [ + "OpenMoeForCausalLM" + ], + "capacity_factor_eval": 2.0, + "capacity_factor_train": 1.25, + "drop_tks": True, + "dropout_rate": 0.0, + "expert_parallel": None, + "gated": True, + "head_dim": 64, + "hidden_act": "swiglu", + "hidden_size": 768, + "intermediate_size": 2048, + "label_smoothing": 0.0, + "layer_norm_epsilon": 1e-06, + "min_capacity": 4, + "moe_layer_interval": 4, + "noisy_policy": None, + "num_attention_heads": 12, + "num_experts": 16, + "num_hidden_layers": 12, + "num_key_value_heads": 12, + "pretraining_tp": 1, + "rope_scaling": None, + "router_aux_loss_factor": 0.01, + "router_z_loss_factor": 0.0001, + "topk": 2, + "torch_dtype": "float32", + "vocab_size": 256384, + "z_loss_factor": 0.0001 + } + OPTIM_CONFIG = { + "decay_rate": -0.8, + "weight_decay": 0.01, + } + + # update config from args + for k in MDOEL_CONFIG: + if hasattr(args, k): + MDOEL_CONFIG[k] = getattr(args, k) + + # Launch ColossalAI + colossalai.launch_from_torch(config={}, seed=args.seed) + coordinator = DistCoordinator() + + # Set up moe + MOE_MANAGER.setup(seed=42, parallel="EP") + + # Manage loggers + disable_existing_loggers() + logger = get_dist_logger() + if coordinator.is_master(): + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # Build OpenMoe model + config = LlamaConfig() + for k, v in MDOEL_CONFIG.items(): + setattr(config, k, v) + + with skip_init(): + model = OpenMoeForCausalLM(config) + + logger.info(f"Finish init model with config:\n{config}", ranks=[0]) + model_param = sum(p.numel() for p in model.parameters() if p.requires_grad) + logger.info(f"Model param count: {model_param/1e6:.2f}M", ranks=[0]) + + # Enable gradient checkpointing + model.gradient_checkpointing_enable() + + # Set plugin + plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2) + logger.info(f"Set plugin as {plugin}", ranks=[0]) + + # Prepare tokenizer and dataloader + dataset = RandomDataset(num_samples=args.num_samples) + dataloader = plugin.prepare_dataloader(dataset, + batch_size=args.batch_size, + shuffle=True, + drop_last=True) + + # Set optimizer + optimizer = Adafactor(model.parameters(), + decay_rate=OPTIM_CONFIG["decay_rate"], + weight_decay=OPTIM_CONFIG["weight_decay"]) + + # Set booster + booster = Booster(plugin=plugin) + model, optimizer, _, dataloader, _ = booster.boost(model=model, + optimizer=optimizer, + dataloader=dataloader) + + # Start benchmark + model.train() + logger.info(f"Start benchmark", ranks=[0]) + + timer = SimpleTimer() + for epoch in range(args.num_epoch): + for batch in tqdm(dataloader, + desc=f'Epoch [{epoch + 1}]', + disable=not coordinator.is_master()): + timer.start("train_step") + + # Forward + timer.start("forward") + outputs = model(use_cache=False, chunk_head=True, **batch) + 
loss = outputs['loss'] + torch.cuda.synchronize() + timer.stop("forward") + + # Backward + timer.start("backward") + booster.backward(loss, optimizer) + torch.cuda.synchronize() + timer.stop("backward") + + # Optimizer step + timer.start("optimizer_step") + optimizer.step() + optimizer.zero_grad() + torch.cuda.synchronize() + timer.stop("optimizer_step") + + timer.stop("train_step") + + logger.info(f"Benchmark result:\n{repr(timer)}", ranks=[0]) + + +if __name__ == "__main__": + main() diff --git a/examples/language/openmoe/benchmark/benchmark_train.sh b/examples/language/openmoe/benchmark/benchmark_train.sh new file mode 100755 index 000000000000..0496a31a7479 --- /dev/null +++ b/examples/language/openmoe/benchmark/benchmark_train.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +set -xue + +BENCHMARK_DIR=benchmark +NUM_GPU=2 + +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | + tail -n +2 | + nl -v 0 | + tee /dev/tty | + sort -g -k 2 | + awk '{print $1}' | + head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + +set_n_least_used_CUDA_VISIBLE_DEVICES $NUM_GPU + +# HACK: make model importable +example_dir=$(dirname $(realpath $(dirname $0))) +if [ -z ${PYTHONPATH+x} ]; then + export PYTHONPATH=$example_dir +else + export PYTHONPATH=$example_dir:$PYTHONPATH +fi + +torchrun --standalone --nproc_per_node $NUM_GPU \ + $example_dir/$BENCHMARK_DIR/benchmark_train.py diff --git a/examples/language/openmoe/benchmark/utils.py b/examples/language/openmoe/benchmark/utils.py new file mode 100644 index 000000000000..d2edee64451c --- /dev/null +++ b/examples/language/openmoe/benchmark/utils.py @@ -0,0 +1,61 @@ +import dataclasses +import time +from typing import Dict + +import torch.distributed as dist +import torch.nn as nn +from colossalai.logging import DistributedLogger + + +def print_model_numel(logger: DistributedLogger, + model: nn.Module) -> None: + B = 1024**3 + M = 1024**2 + K = 1024 + outputs = "Model param count: " + model_param = sum(p.numel() for p in model.parameters() if p.requires_grad) + if model_param >= B: + outputs += f'{model_param / B:.2f} B\n' + elif model_param >= M: + outputs += f'{model_param / M:.2f} M\n' + elif model_param >= K: + outputs += f'{model_param / K:.2f} K\n' + else: + outputs += f'{model_param}\n' + logger.info(outputs, ranks=[0]) + + +@dataclasses.dataclass +class TimingItem(): + last_time: float = 0.0 + total_time: float = 0.0 + count: float = 0 + + def __str__(self) -> str: + return f"average time: {self.total_time/self.count * 1000:.2f} ms" + + +class SimpleTimer(): + def __init__(self, warmup: int = 20) -> None: + self.timing_items: Dict[str, TimingItem] = {} + self.warmup = warmup + + def start(self, name: str): + if name not in self.timing_items: + self.timing_items[name] = TimingItem() + self.timing_items[name].last_time = time.time() + + def stop(self, name: str): + assert name in self.timing_items + timing_item = self.timing_items[name] + timing_item.total_time += time.time() - timing_item.last_time + timing_item.count += 1 + if timing_item.count > self.warmup: + timing_item.count = 0 + timing_item.total_time = 0.0 + + def __repr__(self) -> str: + result = "[Timer]:\n" + for name, timing_item in self.timing_items.items(): + result += f" {name}: {timing_item}\n" + return result diff --git 
a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index 6ccbf64a60e4..4775a3ebea0d 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -24,24 +24,23 @@ import torch import torch.nn.functional as F import torch.utils.checkpoint +from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN +from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON +from colossalai.moe.layers import SparseMLP +from colossalai.moe.manager import MOE_MANAGER from torch import nn -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) from transformers.modeling_utils import PreTrainedModel from transformers.models.llama import LlamaConfig from transformers.models.t5.modeling_t5 import T5LayerNorm -from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) - -from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON -from colossalai.moe.layers import SparseMLP -from colossalai.moe.manager import MOE_MANAGER +from transformers.utils import (add_start_docstrings, + add_start_docstrings_to_model_forward, logging, + replace_return_docstrings) if HAS_TRITON: - from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine + from colossalai.kernel.triton.llama_act_combine_kernel import \ + LlamaActCombine logger = logging.get_logger(__name__) @@ -349,6 +348,7 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, + use_kernel: bool = True, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -407,24 +407,36 @@ def forward( key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError(f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}") - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}") - if self.training: - attention_mask = attention_mask.clone().detach() - attention_mask[:, :, :, 0] = 0 - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) + if HAS_FLASH_ATTN and use_kernel: + from flash_attn import flash_attn_func + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + attn_output = flash_attn_func(query_states, + key_states, + value_states, + softmax_scale=1.0, + causal=True) + attn_output = attn_output.transpose(1, 2).contiguous() + else: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError(f"Attention weights should be of size {(bsz, self.num_heads, q_len, 
kv_seq_len)}, but is" + f" {attn_weights.size()}") + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}") + if self.training: + attention_mask = attention_mask.clone().detach() + attention_mask[:, :, :, 0] = 0 + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" From d12bbe7847edcf33025c34230ea3fe20baea520a Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com> Date: Thu, 21 Sep 2023 11:20:15 +0800 Subject: [PATCH 14/46] [moe] support hybrid parallel (#4748) * init policy * renam,e * update pp * finish pp * update script * update plugin * finish pp * update setup for different plugin * update ci * update ci * update ci * support ep inside or dp inside * update arg for kernel * disable ci * update train script * update plugin --- .../plugin/moe_hybrid_parallel_plugin.py | 173 ++++++ colossalai/moe/manager.py | 91 ++- colossalai/tensor/moe_tensor/api.py | 11 +- colossalai/tensor/moe_tensor/moe_info.py | 27 +- examples/language/openmoe/model/__init__.py | 0 .../openmoe/model/modeling_openmoe.py | 145 +---- .../language/openmoe/model/openmoe_policy.py | 545 ++++++++++++++++++ examples/language/openmoe/test_ci.sh | 5 - examples/language/openmoe/train.py | 156 +++-- examples/language/openmoe/train.sh | 8 +- 10 files changed, 962 insertions(+), 199 deletions(-) create mode 100644 colossalai/booster/plugin/moe_hybrid_parallel_plugin.py create mode 100644 examples/language/openmoe/model/__init__.py create mode 100644 examples/language/openmoe/model/openmoe_policy.py diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py new file mode 100644 index 000000000000..fab6c2f0cb7b --- /dev/null +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -0,0 +1,173 @@ +from typing import Optional + +import torch +import torch.distributed as dist + +from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelPlugin +from colossalai.cluster import ProcessGroupMesh +from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig +from colossalai.shardformer.policies.base_policy import Policy + +PP_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2 + + +class MoeHybridParallelPlugin(HybridParallelPlugin): + """ + Plugin for Moe Hybrid Parallel Training. + Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin. + The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size). + + Example: + >>> from colossalai.booster import Booster + >>> from colossalai.booster.plugin import HybridParallelPlugin + + >>> model, train_dataset, optimizer, criterion = ... 
+    >>> plugin = HybridParallelPlugin(tp_size=2, pp_size=2)
+
+    >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+    >>> booster = Booster(plugin=plugin)
+    >>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
+
+    Args:
+        tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
+        pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
+        precision (str, optional): Specifies the precision of parameters during training.
+            Auto-mixed precision will be used when this argument is set to 'fp16' or 'bf16', otherwise the model is trained with 'fp32'.
+            Defaults to 'fp16'.
+        zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be chosen from [0, 1, 2].
+            When set to 0, ZeRO will not be used. Defaults to 0.
+        enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
+            Currently the optimization methods include fused normalization, flash attention and JIT.
+            Defaults to False.
+        enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False.
+        enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
+        enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Defaults to False.
+        enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
+        enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
+        num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
+        microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
+            Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
+            If ``num_microbatches`` is provided, this will be ignored. Defaults to None.
+        initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
+        min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
+        growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
+        backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5.
+        growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000.
+        hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
+        max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
+        max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
+        broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True.
+        ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25.
+        find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False.
+        check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False.
+        gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False.
+        static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False.
+ zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12. + cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False. + communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None. + overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True. + """ + + def __init__(self, + tp_size: int, + pp_size: int, + precision: str = 'fp16', + zero_stage: int = 0, + enable_all_optimization: bool = False, + enable_fused_normalization: bool = False, + enable_flash_attention: bool = False, + enable_jit_fused: bool = False, + enable_sequence_parallelism: bool = False, + enable_sequence_overlap: bool = False, + num_microbatches: Optional[int] = None, + microbatch_size: Optional[int] = None, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0, + broadcast_buffers: bool = True, + ddp_bucket_cap_mb: int = 25, + find_unused_parameters: bool = False, + check_reduction: bool = False, + gradient_as_bucket_view: bool = False, + static_graph: bool = False, + zero_bucket_size_in_m: int = 12, + cpu_offload: bool = False, + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + custom_policy: Policy = None) -> None: + + super().__init__() + assert dist.get_world_size() % ( + tp_size * pp_size + ) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}' + + if enable_sequence_parallelism: + assert tp_size > 1, 'Sequence parallelism must be enabled when using tensor parallelism' + + self.tp_size = tp_size + self.pp_size = pp_size + self.dp_size = dist.get_world_size() // (tp_size * pp_size) + self.precision = precision + self.zero_stage = zero_stage + self.cpu_offload = cpu_offload + self.enable_all_optimization = enable_all_optimization + self.enable_fused_normalization = enable_fused_normalization + self.enable_flash_attention = enable_flash_attention + self.enable_jit_fused = enable_jit_fused + self.enable_sequence_parallelism = enable_sequence_parallelism + # we change pg mesh to (pp, dp, tp) for better moe performance + self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size) + self.stage_manager = None + self.schedule = None + self.custom_policy = custom_policy + assert zero_stage in (0, 1, 2) + if self.pp_size > 1: + assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism' + assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism' + self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS) + self.schedule = OneForwardOneBackwardSchedule(self.stage_manager, + num_microbatches=num_microbatches, + microbatch_size=microbatch_size) + self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) + self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) + self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) + self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group, + pipeline_stage_manager=self.stage_manager, + enable_tensor_parallelism=self.tp_size > 1, + enable_all_optimization=self.enable_all_optimization, + 
+                                        enable_fused_normalization=self.enable_fused_normalization,
+                                        enable_flash_attention=self.enable_flash_attention,
+                                        enable_jit_fused=self.enable_jit_fused,
+                                        enable_sequence_parallelism=enable_sequence_parallelism,
+                                        enable_sequence_overlap=enable_sequence_overlap)
+        self.amp_config = dict(
+            initial_scale=initial_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+            hysteresis=hysteresis,
+            min_scale=min_scale,
+            max_scale=max_scale,
+        )
+
+        self.ddp_config = dict(broadcast_buffers=broadcast_buffers,
+                               bucket_cap_mb=ddp_bucket_cap_mb,
+                               find_unused_parameters=find_unused_parameters,
+                               check_reduction=check_reduction,
+                               gradient_as_bucket_view=gradient_as_bucket_view,
+                               static_graph=static_graph)
+
+        self.zero_config = dict(reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
+                                communication_dtype=communication_dtype,
+                                overlap_communication=overlap_communication,
+                                cpu_offload=cpu_offload,
+                                partition_grad=(self.zero_stage == 2))
+
+        self.max_norm = max_norm
diff --git a/colossalai/moe/manager.py b/colossalai/moe/manager.py
index 3dc27c6cb0f0..e61fb0bf9582 100644
--- a/colossalai/moe/manager.py
+++ b/colossalai/moe/manager.py
@@ -24,7 +24,9 @@ def __init__(self):
         self.router_z_loss = []
         self.parallel = None
         self.seed = None
-        self.use_kernel_optim = True
+        self.mode = None
+        self.use_kernel_optim = False
+        self.use_ep_inside = None
         self.has_setup = False
         self._parallel_info_dict = dict()

@@ -37,15 +39,53 @@ def parallel_info_dict(self):
     def is_initialized(self):
         return self.has_setup

-    def setup(self, seed: int, use_kernel_optim: bool = True, max_ep_size: int = 8, parallel: bool = None):
+    def setup(self,
+              seed: int,
+              use_kernel_optim: bool = True,
+              parallel: bool = None,
+              mode: str = "dynamic",
+              max_ep_size: int = 8,
+              fixed_dp_size: int = 0,
+              fixed_ep_size: int = 0,
+              fixed_pp_size: int = 0,
+              use_ep_inside: bool = True) -> None:
+        """
+        Set up the MoE distributed context.
+
+        Args:
+            seed (int): Random seed.
+            use_kernel_optim (bool, optional): Use the CUDA kernel. Defaults to True.
+            parallel (str, optional): Parallel mode, should be "EP", "TP" or None. Defaults to None.
+            mode (str, optional): Should be "fixed" or "dynamic". Defaults to "dynamic".
+                In fixed mode, the ep size and dp size are fixed.
+                In dynamic mode, the ep size and dp size are derived from the number of experts.
+            max_ep_size (int, optional): Max ep size in dynamic mode. Defaults to 8.
+            fixed_dp_size (int, optional): Fixed dp size in fixed mode. Defaults to 0.
+            fixed_ep_size (int, optional): Fixed ep size in fixed mode. Defaults to 0.
+            fixed_pp_size (int, optional): Fixed pp size in fixed mode. Defaults to 0.
+            use_ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if False. Defaults to True.
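+
+        Example (illustrative sketch; assumes a 4-GPU world size, and note that
+        setup may only be called once per process, so pick one mode):
+            >>> # dynamic mode: the ep size is derived from the number of experts,
+            >>> # capped by max_ep_size
+            >>> MOE_MANAGER.setup(seed=42, parallel="EP", mode="dynamic", max_ep_size=4)
+            >>> # fixed mode: dp/ep/pp sizes are pinned explicitly and must multiply
+            >>> # to the world size
+            >>> MOE_MANAGER.setup(seed=42, parallel="EP", mode="fixed",
+            ...                   fixed_dp_size=1, fixed_ep_size=2, fixed_pp_size=2)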
+ """ assert not self.is_initialized, "MoE distributed context shouldn't be set up again" assert torch.cuda.is_available(), "MoE requires to enable CUDA first" self.world_size = dist.get_world_size() self.seed = seed + dist.get_rank() - self.max_ep_size = min(max_ep_size, dist.get_world_size()) - self.min_dp_size = self.world_size // self.max_ep_size self.parallel = parallel + self.use_ep_inside = use_ep_inside + + # init by mode + self.mode = mode + assert self.mode in ["fixed", "dynamic"], "mode should be fixed or dynamic" + if self.mode == "dynamic": + self.max_ep_size = min(max_ep_size, dist.get_world_size()) + self.min_dp_size = self.world_size // self.max_ep_size + else: + assert fixed_dp_size > 0 and fixed_ep_size > 0 and fixed_pp_size > 0, "dp_size, ep_size and pp_size should be greater than 0" + assert isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int) and isinstance( + fixed_pp_size, int), "dp_size, ep_size and pp_size should be int" + self.ep_size = fixed_ep_size + self.dp_size = fixed_dp_size + self.pp_size = fixed_pp_size # Enabling kernel optimization may raise error in some cases # Users can close kernel optimization manually @@ -67,30 +107,39 @@ def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoePara number of local experts, the MoeParallelInfo of the current ep_size """ - gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater - lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less - - assert gt_flag or lt_flag, "Automatic experts placement dose not not support expert number" \ - " is not a multiple of ep size or vice versa." - - # If the number of experts is greater than maximum expert parallel size. a.k.a ep_size, - # there are multiple experts in each GPU and each GPU has different experts - # So it's data parallel size is 1 - # Otherwise, there is only one expert in each GPU - # The data parallel size should be calculated - dp_size = 1 if gt_flag else self.max_ep_size // num_experts - ep_size = self.max_ep_size // dp_size + if self.mode == "dynamic": + gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater + lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less + + assert gt_flag or lt_flag, "Automatic experts placement dose not not support expert number" \ + " is not a multiple of ep size or vice versa." + + # If the number of experts is greater than maximum expert parallel size. 
+            # there are multiple experts on each GPU and each GPU holds different experts,
+            # so its data parallel size is 1.
+            # Otherwise, there is only one expert on each GPU
+            # and the data parallel size should be calculated.
+            dp_size = 1 if gt_flag else self.max_ep_size // num_experts
+            ep_size = self.max_ep_size // dp_size
+            # Don't forget to multiply minimum data parallel size
+            dp_size *= self.min_dp_size
+            pp_size = 1
+        else:
+            dp_size = self.dp_size
+            ep_size = self.ep_size
+            pp_size = self.pp_size

         # Calculate the number of experts for each GPU
         if use_tp:
             num_local_experts = num_experts
         else:
-            num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size
+            if self.mode == "dynamic":
+                num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size
+            else:
+                num_local_experts = num_experts // ep_size

-        # Don't forget to multiply minimum data parallel size
-        dp_size *= self.min_dp_size
         if not (ep_size in self.parallel_info_dict):
-            self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size)
+            self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size, pp_size, ep_inside=self.use_ep_inside)

         return num_local_experts, self.parallel_info_dict[ep_size]
diff --git a/colossalai/tensor/moe_tensor/api.py b/colossalai/tensor/moe_tensor/api.py
index 442b3c0f4958..9120a40b8533 100644
--- a/colossalai/tensor/moe_tensor/api.py
+++ b/colossalai/tensor/moe_tensor/api.py
@@ -28,20 +28,23 @@ def set_moe_tensor_info(tensor: torch.Tensor, moe_info: MoeParallelInfo) -> None
         moe_info (dict): The moe info to be set.
     """
-    tensor.__setattr__('moe_info', moe_info)
+    tensor.__setattr__("moe_info", moe_info)


-def get_moe_info(ep_size: int, dp_size: int) -> MoeParallelInfo:
+def get_moe_info(ep_size: int, dp_size: int, pp_size: int, ep_inside: bool) -> MoeParallelInfo:
     """
-    Get moe info for the given tensor.
+    Get the moe parallel info for the given parallel sizes.

     Args:
-        tensor (torch.Tensor): The tensor to be checked.
+        ep_size (int): The expert parallel size.
+        dp_size (int): The data parallel size.
+        pp_size (int): The pipeline parallel size.
+        ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if False.

     Returns:
-        dict: The moe info of the given tensor.
+        MoeParallelInfo: The moe parallel info of the given parallel sizes.
     """
-    return MoeParallelInfo(ep_size, dp_size)
+    return MoeParallelInfo(ep_inside, ep_size, dp_size, pp_size)


 def get_ep_group(tensor: torch.Tensor) -> ProcessGroup:
diff --git a/colossalai/tensor/moe_tensor/moe_info.py b/colossalai/tensor/moe_tensor/moe_info.py
index ca7f163b9c24..5097ac1044e7 100644
--- a/colossalai/tensor/moe_tensor/moe_info.py
+++ b/colossalai/tensor/moe_tensor/moe_info.py
@@ -2,15 +2,26 @@


 class MoeParallelInfo:
-    """Moe parallelism information, storing parallel sizes and groups.
-    """
+    """Moe parallelism information, storing parallel sizes and groups."""
+
+    def __init__(self, ep_inside: bool, ep_size: int, dp_size: int, pp_size: int = 1):
+        """
+        Initialize MoeParallelInfo with ep_size, dp_size and pp_size.
+
+        Args:
+            ep_size (int): expert parallel size
+            dp_size (int): data parallel (zero) size
+            pp_size (int, optional): pipeline parallel size. Defaults to 1.
+            ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if False. Defaults to True.
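+
+        Layout sketch (illustrative; assumes 4 ranks with pp_size=1, dp_size=2,
+        ep_size=2 and a row-major rank layout in ProcessGroupMesh):
+            ep_inside=True  -> mesh (pp, dp, ep): ep groups {0,1}, {2,3}; dp groups {0,2}, {1,3}
+            ep_inside=False -> mesh (pp, ep, dp): ep groups {0,2}, {1,3}; dp groups {0,1}, {2,3}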
+ """ + self.pp_size, self.dp_size, self.ep_size = pp_size, dp_size, ep_size + if ep_inside: + self.pp_axis, self.dp_axis, self.ep_axis = 0, 1, 2 + self.pg = ProcessGroupMesh(self.pp_size, self.dp_size, self.ep_size) + else: + self.pp_axis, self.ep_axis, self.dp_axis = 0, 1, 2 + self.pg = ProcessGroupMesh(self.pp_size, self.ep_size, self.dp_size) - def __init__(self, ep_size: int, dp_size: int): - self.dp_axis = 0 - self.dp_size = dp_size - self.ep_axis = 1 - self.ep_size = ep_size - self.pg = ProcessGroupMesh(self.dp_size, self.ep_size) self.ep_group = self.pg.get_group_along_axis(self.ep_axis) self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group) self.dp_group = self.pg.get_group_along_axis(self.dp_axis) diff --git a/examples/language/openmoe/model/__init__.py b/examples/language/openmoe/model/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index 4775a3ebea0d..a774d4e9fd55 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -144,87 +144,6 @@ def apply_rotary_embedding(q, k, cos, sin, decode=False, rotary_index=None): return out_q, out_k -class LlamaRotaryEmbedding(torch.nn.Module): - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.inv_freq = inv_freq - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache(seq_len=max_position_embeddings, - device=self.inv_freq.device, - dtype=torch.get_default_dtype()) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - ) - - -class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - t = t / self.scaling_factor - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - -class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - - if seq_len > self.max_position_embeddings: - base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) - - (self.scaling_factor - 1))**(self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq) - - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., :x.shape[-1] // 2] @@ -232,17 +151,6 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. - cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - def SwiGLU(x): """Gated linear unit activation function. 
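
    The input is split in half along the last dimension into (x1, x2); a sketch of
    the formula implemented in the body below:

        SwiGLU(x) = x1 * (x2 * sigmoid(x2)) = x1 * SiLU(x2)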
Args: @@ -255,7 +163,7 @@ def SwiGLU(x): return x1 * (x2 * torch.sigmoid(x2)) -class LlamaMLP(nn.Module): +class OpenMoeMLP(nn.Module): def __init__(self, config): super().__init__() @@ -266,6 +174,7 @@ def __init__(self, config): self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = SwiGLU + self.use_kernel = True if MOE_MANAGER.use_kernel_optim else False def forward(self, x): if self.pretraining_tp > 1: @@ -281,7 +190,7 @@ def forward(self, x): down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)] down_proj = sum(down_proj) else: - if HAS_TRITON: + if HAS_TRITON and self.use_kernel: down_proj = self.down_proj(LlamaActCombine.apply(self.gate_proj(x), self.up_proj(x))) else: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) @@ -301,7 +210,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -class LlamaAttention(nn.Module): +class OpenMoeAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: LlamaConfig): @@ -320,22 +229,6 @@ def __init__(self, config: LlamaConfig): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) self.sin, self.cos = generate_fixed_pos_embedding(self.head_dim, self.max_position_embeddings, 1e4) - self._init_rope() - - def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) - else: - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = LlamaLinearScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor) - elif scaling_type == "dynamic": - self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -458,13 +351,13 @@ def forward( return attn_output, attn_weights, past_key_value -class LlamaDecoderLayer(nn.Module): +class OpenMoeDecoderLayer(nn.Module): def __init__(self, config: LlamaConfig, moe: bool): super().__init__() self.hidden_size = config.hidden_size self.moe = moe - self.self_attn = LlamaAttention(config=config) + self.self_attn = OpenMoeAttention(config=config) self.input_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) if self.moe: @@ -482,9 +375,9 @@ def __init__(self, config: LlamaConfig, moe: bool): activation=config.hidden_act, gated=config.gated) self.pre_extra_mlp_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) - self.extra_mlp = LlamaMLP(config) + self.extra_mlp = OpenMoeMLP(config) else: - self.mlp = LlamaMLP(config) + self.mlp = OpenMoeMLP(config) def forward( self, @@ -568,7 +461,7 @@ def forward( "The bare LLaMA Model outputting raw 
hidden-states without any specific head on top.", LLAMA_START_DOCSTRING, ) -class LlamaPreTrainedModel(PreTrainedModel): +class OpenMoePreTrainedModel(PreTrainedModel): config_class = LlamaConfig base_model_prefix = "model" supports_gradient_checkpointing = True @@ -587,7 +480,7 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, LlamaModel): + if isinstance(module, OpenMoeModel): module.gradient_checkpointing = value @@ -659,7 +552,7 @@ def _set_gradient_checkpointing(self, module, value=False): "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", LLAMA_START_DOCSTRING, ) -class LlamaModel(LlamaPreTrainedModel): +class OpenMoeModel(OpenMoePreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] @@ -674,7 +567,7 @@ def __init__(self, config: LlamaConfig): self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, moe=True if (i + 1) % config.moe_layer_interval == 0 else False) + OpenMoeDecoderLayer(config, moe=True if (i + 1) % config.moe_layer_interval == 0 else False) for i in range(config.num_hidden_layers) ]) self.norm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -839,12 +732,12 @@ def custom_forward(*inputs): ) -class OpenMoeForCausalLM(LlamaPreTrainedModel): +class OpenMoeForCausalLM(OpenMoePreTrainedModel): # _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): super().__init__(config) - self.model = LlamaModel(config) + self.model = OpenMoeModel(config) self.pretraining_tp = config.pretraining_tp self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) @@ -1034,17 +927,15 @@ def _reorder_cache(past_key_values, beam_idx): past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),) return reordered_past - def _calculate_router_loss(self): - aux_loss, z_loss = MOE_MANAGER.get_loss() + def _calculate_router_loss(self, aux_loss: list = None, z_loss: list = None): + if aux_loss is None or z_loss is None: + aux_loss, z_loss = MOE_MANAGER.get_loss() assert len(aux_loss) == len(z_loss) == self.config.num_hidden_layers // self.config.moe_layer_interval aux_loss = self.config.router_aux_loss_factor * sum(aux_loss) / len(aux_loss) z_loss = self.config.router_z_loss_factor * sum(z_loss) / len(z_loss) return aux_loss, z_loss - def _calculate_loss(self, - logits: torch.Tensor, - targets: torch.Tensor - ) -> torch.Tensor: + def _calculate_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: """Compute cross entropy and entropy for log probs and targets. 
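
    Assumed form (a sketch following the t5x-style cross entropy with z-loss; the
    config carries label_smoothing and z_loss_factor): with z = logsumexp(logits, -1),

        loss = cross_entropy(logits, targets, label_smoothing) + z_loss_factor * mean(z ** 2)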
    Args:
diff --git a/examples/language/openmoe/model/openmoe_policy.py b/examples/language/openmoe/model/openmoe_policy.py
new file mode 100644
index 000000000000..cc82683cd319
--- /dev/null
+++ b/examples/language/openmoe/model/openmoe_policy.py
@@ -0,0 +1,545 @@
+import warnings
+from functools import partial
+from typing import Callable, Dict, List, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Module
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.utils import logging
+
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+
+from .modeling_openmoe import OpenMoeDecoderLayer, OpenMoeForCausalLM, OpenMoeModel
+
+__all__ = ["OpenMoePolicy", "OpenMoeForCausalLMPolicy"]
+
+
+class OpenMoePolicy(Policy):
+
+    def config_sanity_check(self):
+        pass
+
+    def preprocess(self):
+        if self.shard_config.enable_tensor_parallelism:
+            # Resize embedding
+            vocab_size = self.model.config.vocab_size
+            world_size = self.shard_config.tensor_parallel_size
+
+            if vocab_size % world_size != 0:
+                new_vocab_size = vocab_size + world_size - vocab_size % world_size
+                self.model.resize_token_embeddings(new_vocab_size)
+
+        return self.model
+
+    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+        policy = {}
+
+        if self.shard_config.enable_sequence_parallelism:
+            self.shard_config.enable_sequence_parallelism = False
+            raise NotImplementedError(
+                "openmoe doesn't support sequence parallelism now; the sequence parallelism flag will be ignored.")
+
+        if self.shard_config.enable_tensor_parallelism:
+            raise NotImplementedError("Tensor parallelism is not supported for openmoe model now.")
+
+        # optimization configuration
+        if self.shard_config.enable_fused_normalization:
+            self.append_or_create_submodule_replacement(
+                description=[
+                    SubModuleReplacementDescription(
+                        suffix="input_layernorm",
+                        target_module=FusedRMSNorm,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="post_attention_layernorm",
+                        target_module=FusedRMSNorm,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="pre_extra_mlp_layernorm",
+                        target_module=FusedRMSNorm,
+                        ignore_if_not_exist=True,
+                    ),
+                ],
+                policy=policy,
+                target_key=OpenMoeDecoderLayer,
+            )
+
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="norm",
+                    target_module=FusedRMSNorm,
+                ),
+                policy=policy,
+                target_key=OpenMoeModel,
+            )
+
+        if self.shard_config.enable_flash_attention:
+            raise NotImplementedError("Flash attention has already been replaced in openmoe.")
+
+        return policy
+
+    def postprocess(self):
+        return self.model
+
+    def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
+        """Under a pipeline parallel setting, replace the original forward method of the
+        huggingface model with a customized one, and record this change in the policy."""
+        if self.pipeline_stage_manager:
+            stage_manager = self.pipeline_stage_manager
+            if self.model.__class__.__name__ == "OpenMoeModel":
+                module = self.model
+            else:
+                module = self.model.model
+
+            layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages)
+            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+            method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+            self.append_or_create_method_replacement(description=method_replacement,
+                                                     policy=policy,
+                                                     target_key=model_cls)
+
+        return
+
+    def get_held_layers(self) -> List[Module]:
+        """Get pipeline layers for current stage."""
+        assert self.pipeline_stage_manager is not None
+
+        if self.model.__class__.__name__ == "OpenMoeModel":
+            module = self.model
+        else:
+            module = self.model.model
+        stage_manager = self.pipeline_stage_manager
+
+        held_layers = []
+        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+        if stage_manager.is_first_stage():
+            held_layers.append(module.embed_tokens)
+        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        held_layers.extend(module.layers[start_idx:end_idx])
+        if stage_manager.is_last_stage():
+            held_layers.append(module.norm)
+
+        return held_layers
+
+
+class OpenMoeModelPolicy(OpenMoePolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        policy = super().module_policy()
+        if self.pipeline_stage_manager:
+            # set None as default
+            self.set_pipeline_forward(
+                model_cls=OpenMoeModel,
+                new_forward=OpenMoePipelineForwards.openmoe_model_forward,
+                policy=policy,
+            )
+        return policy
+
+    def get_held_layers(self) -> List[Module]:
+        """Get pipeline layers for current stage."""
+        held_layers = super().get_held_layers()
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """No shared params in llama model"""
+        return []
+
+
+class OpenMoeForCausalLMPolicy(OpenMoePolicy):
+
+    def module_policy(self):
+        policy = super().module_policy()
+
+        if self.shard_config.enable_tensor_parallelism:
+            # add a new item for causal lm
+            new_item = {
+                OpenMoeForCausalLM:
+                    ModulePolicyDescription(sub_module_replacement=[
+                        SubModuleReplacementDescription(
+                            suffix="lm_head",
+                            target_module=Linear1D_Col,
+                            kwargs=dict(gather_output=True),
+                        )
+                    ])
+            }
+            policy.update(new_item)
+
+        if self.pipeline_stage_manager:
+            # set None as default
+            self.set_pipeline_forward(
+                model_cls=OpenMoeForCausalLM,
+                new_forward=OpenMoePipelineForwards.llama_for_causal_lm_forward,
+                policy=policy,
+            )
+
+        return policy
+
+    def get_held_layers(self) -> List[Module]:
+        """Get pipeline layers for current stage."""
+        stage_manager = self.pipeline_stage_manager
+        held_layers = super().get_held_layers()
+        if stage_manager.is_last_stage():
+            held_layers.append(self.model.lm_head)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        llama_model = self.model.model
+        if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
+            if (id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
+                    and self.pipeline_stage_manager.num_stages > 1):
+                # tie weights
+                return [{
+                    0: llama_model.embed_tokens.weight,
+                    self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
+                }]
+        return []
+
+
+class OpenMoePipelineForwards:
+    """
+    This class serves as a micro library for forward function substitution of Llama models
+    under pipeline setting.
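+
+    Data-flow sketch (as implemented in the forwards below): the first stage
+    consumes input_ids/inputs_embeds; intermediate stages receive and return a
+    dict carrying "hidden_states" plus the accumulated "router_aux_loss" and
+    "router_z_loss"; the last stage materializes the logits and the final loss.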
+ """ + + @staticmethod + def openmoe_model_forward( + self: OpenMoeModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + past_router_aux_loss: Optional[torch.FloatTensor] = None, + past_router_z_loss: Optional[torch.FloatTensor] = None, + ): + # reset moe loss for different data + MOE_MANAGER.reset_loss() + + logger = logging.get_logger(__name__) + + output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions) + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + + # retrieve input_ids and inputs_embeds + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + seq_length_with_past = seq_length + past_key_values_length = 0 + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. 
+ if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + if use_cache: + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") + use_cache = False + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions, for the first stage, hidden_states is the input embeddings, + # for the other stages, hidden_states is the output of the previous stage + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), + dtype=torch.bool, + device=hidden_states.device, + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, + (batch_size, seq_length), + hidden_states, + past_key_values_length, + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + start_idx, end_idx = stage_index[0], stage_index[1] + for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = (past_key_values[idx] if past_key_values is not None else None) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + + # concat past losses with current ones + router_aux_loss, router_z_loss = MOE_MANAGER.get_loss() + if past_router_aux_loss is not None and past_router_z_loss is not None: + router_aux_loss = past_router_aux_loss + router_aux_loss + router_z_loss = past_router_z_loss + router_z_loss + + if stage_manager.is_last_stage(): + return tuple([ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + router_aux_loss, + router_z_loss, 
+            ])
+        # always return dict for intermediate stage
+        return {
+            "hidden_states": hidden_states,
+            "router_aux_loss": router_aux_loss,
+            "router_z_loss": router_z_loss,
+        }
+
+    @staticmethod
+    def llama_for_causal_lm_forward(
+        self: OpenMoeForCausalLM,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        stage_manager: Optional[PipelineStageManager] = None,
+        hidden_states: Optional[torch.FloatTensor] = None,
+        stage_index: Optional[List[int]] = None,
+        chunk_head: Optional[bool] = None,
+        past_router_aux_loss: Optional[torch.FloatTensor] = None,
+        past_router_z_loss: Optional[torch.FloatTensor] = None,
+    ):
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+        logger = logging.get_logger(__name__)
+        output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
+        output_hidden_states = (output_hidden_states
+                                if output_hidden_states is not None else self.config.output_hidden_states)
+        return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
+
+        # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
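+        # Note on chunk_head (illustrative arithmetic): with a ~256k vocab, fp32
+        # logits for a single 2048-token sequence already take about
+        # 2048 * 256384 * 4 B ~= 2.1 GB, which is why the last stage below can
+        # checkpoint the lm_head per sample instead of materializing full-batch logits.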
+        if output_attentions:
+            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
+            output_attentions = False
+        if output_hidden_states:
+            logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
+            output_hidden_states = False
+
+        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = OpenMoePipelineForwards.openmoe_model_forward(
+            self.model,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            stage_manager=stage_manager,
+            hidden_states=hidden_states,
+            stage_index=stage_index,
+            past_router_aux_loss=past_router_aux_loss,
+            past_router_z_loss=past_router_z_loss,
+        )
+
+        if stage_manager.is_last_stage():
+            (
+                hidden_states,
+                past_key_values,
+                all_hidden_states,
+                attentions,
+                router_aux_loss,
+                router_z_loss,
+            ) = outputs
+
+            if self.pretraining_tp > 1:
+                lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0)
+                logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)]
+                logits = torch.cat(logits, dim=-1)
+
+            loss = None
+            # if no training, just do forward
+            if labels is None:
+                logits = self.lm_head(hidden_states)
+                logits = logits.float()
+            # the vocab size for openmoe is large (~256k),
+            # which causes large activation memory in training, up to 20 GB for one sequence,
+            # so we use chunking and checkpointing to reduce memory
+            else:
+                if chunk_head:
+
+                    def create_custom_forward(module):
+
+                        def custom_forward(*inputs):
+                            logits = module(inputs[0])
+                            logits = logits.float()
+                            # Shift so that tokens < n predict n
+                            shift_logits = logits[..., :-1, :].contiguous().float()
+                            shift_labels = inputs[1][..., 1:].contiguous()
+                            # Flatten the tokens
+                            loss = self._calculate_loss(shift_logits, shift_labels)
+                            return loss

+                        return custom_forward
+
+                    aux_loss, z_loss = self._calculate_router_loss(router_aux_loss, router_z_loss)
+                    loss = aux_loss + z_loss
+                    for batch_idx in range(hidden_states.shape[0]):
+                        loss = loss + torch.utils.checkpoint.checkpoint(
+                            create_custom_forward(self.lm_head),
+                            hidden_states[batch_idx:batch_idx + 1, :],
+                            labels[batch_idx:batch_idx + 1, :],
+                        )
+                    logits = None
+                else:
+                    logits = self.lm_head(hidden_states)
+                    logits = logits.float()
+                    # Shift so that tokens < n predict n
+                    shift_logits = logits[..., :-1, :].contiguous()
+                    shift_labels = labels[..., 1:].contiguous()
+                    # Flatten the tokens
+                    aux_loss, z_loss = self._calculate_router_loss(router_aux_loss, router_z_loss)
+                    loss = aux_loss + z_loss
+                    loss = loss + self._calculate_loss(shift_logits, shift_labels)
+
+            if not return_dict:
+                output = (logits,) + outputs[1:]
+                return (loss,) + output if loss is not None else output
+
+            return CausalLMOutputWithPast(
+                loss=loss,
+                logits=logits,
+                past_key_values=past_key_values,
+                hidden_states=all_hidden_states,
+                attentions=attentions,
+            )
+        else:
+            hidden_states = outputs["hidden_states"]
+            router_aux_loss = outputs["router_aux_loss"]
+            router_z_loss = outputs["router_z_loss"]
+            return {
+                "hidden_states": hidden_states,
+                "past_router_aux_loss": router_aux_loss,
+                "past_router_z_loss": router_z_loss,
+            }
diff --git a/examples/language/openmoe/test_ci.sh b/examples/language/openmoe/test_ci.sh
index 75eee902c747..e69de29bb2d1 100644
---
a/examples/language/openmoe/test_ci.sh +++ b/examples/language/openmoe/test_ci.sh @@ -1,5 +0,0 @@ -set -xe -pip install -r requirements.txt - -python infer.py --model "test" -torchrun --standalone --nproc_per_node 2 train.py --model_name "test" --batch_size 1 --num_epoch 20 diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index 132f17a9ba0f..2099bbde91f5 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -5,6 +5,7 @@ import transformers from huggingface_hub import snapshot_download from model.modeling_openmoe import OpenMoeForCausalLM +from model.openmoe_policy import OpenMoeForCausalLMPolicy from torch.utils.data import Dataset from tqdm import tqdm from transformers import Adafactor, T5Tokenizer @@ -14,6 +15,7 @@ from colossalai import get_default_parser from colossalai.booster import Booster from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.cluster import DistCoordinator from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.moe import MoeCheckpintIO @@ -52,31 +54,67 @@ def __len__(self): def __getitem__(self, idx): return { - 'input_ids': self.input_ids[idx], - 'attention_mask': self.attention_mask[idx], - 'labels': self.input_ids[idx] + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "labels": self.input_ids[idx], } def parse_args(): + # basic settings parser = get_default_parser() - parser.add_argument("--model_name", - type=str, - default="base", - help="Path to pretrained model or model identifier from huggingface.co/models.") - parser.add_argument("--output_path", - type=str, - default="./output_model.bin", - help="The path of your saved model after finetuning.") + parser.add_argument( + "--model_name", + type=str, + default="base", + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--output_path", + type=str, + default="./output_model.bin", + help="The path of your saved model after finetuning.", + ) parser.add_argument("--num_epoch", type=int, default=10, help="Number of epochs.") - parser.add_argument("--batch_size", - type=int, - default=4, - help="Batch size (per dp group) for the training dataloader.") + parser.add_argument( + "--batch_size", + type=int, + default=4, + help="Batch size (per dp group) for the training dataloader.", + ) parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + parser.add_argument( + "--plugin", + type=str, + default="hybrid", + help="parallel plugin", + choices=["zero1", "zero2", "hybrid"], + ) + # hybrid plugin + parser.add_argument("--pp_size", type=int, default=2, help="pp size") + parser.add_argument("--dp_size", type=int, default=1, help="dp size") + parser.add_argument("--ep_size", type=int, default=2, help="ep size") + parser.add_argument("--zero_stage", type=int, default=1, help="zero stage in hybrid plugin") + parser.add_argument("--microbatch_size", type=int, default=1, help="microbatch size") + # kernel + parser.add_argument( + "--use_kernel", + action="store_true", + help="Use kernel optim. 
Need to install flash attention, apex, triton to enable all kernel optimizations.", + ) # loss - parser.add_argument("--router_aux_loss_factor", type=float, default=0.01, help="router_aux_loss_factor.") - parser.add_argument("--router_z_loss_factor", type=float, default=0.0001, help="router_z_loss_factor.") + parser.add_argument( + "--router_aux_loss_factor", + type=float, + default=0.01, + help="router_aux_loss_factor.", + ) + parser.add_argument( + "--router_z_loss_factor", + type=float, + default=0.0001, + help="router_z_loss_factor.", + ) parser.add_argument("--label_smoothing", type=float, default=0.0, help="label_smoothing.") parser.add_argument("--z_loss_factor", type=float, default=0.0001, help="z_loss_factor.") # optim @@ -95,7 +133,24 @@ def main(): coordinator = DistCoordinator() # Set up moe - MOE_MANAGER.setup(seed=42, parallel="EP") + if args.plugin in ["zero1", "zero2"]: + MOE_MANAGER.setup( + seed=42, + parallel="EP", + use_kernel_optim=False if args.model_name == "test" else args.use_kernel, + ) + elif args.plugin == "hybrid": + assert (args.dp_size * args.ep_size * + args.pp_size == coordinator.world_size), "dp_size * ep_size * pp_size must equal to world_size" + MOE_MANAGER.setup( + seed=42, + parallel="EP", + mode="fixed", + fixed_dp_size=args.dp_size, + fixed_ep_size=args.ep_size, + fixed_pp_size=args.pp_size, + use_kernel_optim=False if args.model_name == "test" else args.use_kernel, + ) # Manage loggers disable_existing_loggers() @@ -129,12 +184,27 @@ def main(): # Set plugin booster_kwargs = {} - plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2) + if args.plugin == "zero1": + plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=1) + elif args.plugin == "zero2": + plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2) + elif args.plugin == "hybrid": + plugin = MoeHybridParallelPlugin( + tp_size=1, + pp_size=args.pp_size, + zero_stage=args.zero_stage, + microbatch_size=args.microbatch_size, + custom_policy=OpenMoeForCausalLMPolicy(), + enable_fused_normalization=args.use_kernel, + enable_jit_fused=args.use_kernel, + ) + else: + raise ValueError(f"Invalid plugin {args.plugin}") logger.info(f"Set plugin as {plugin}", ranks=[0]) # Prepare tokenizer and dataloader tokenizer = T5Tokenizer.from_pretrained("google/umt5-small") - dataset = RandomDataset(num_samples=1000 if args.model_name != "test" else 1) + dataset = RandomDataset(num_samples=1000 if args.model_name != "test" else 50) dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True) # Set optimizer @@ -143,27 +213,47 @@ def main(): # Set booster booster = Booster(plugin=plugin, **booster_kwargs) model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader) + use_pipeline = (isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1) + is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() logger.info(f"Finish init booster", ranks=[0]) # Start finetuning logger.info(f"Start finetuning", ranks=[0]) for epoch in range(args.num_epoch): model.train() - with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar: - for batch in pbar: - # Forward - optimizer.zero_grad() - batch = move_to_cuda(batch, torch.cuda.current_device()) - - outputs = model(use_cache=False, chunk_head=True, **batch) - loss = outputs['loss'] + train_dataloader_iter = iter(dataloader) + total_len = len(train_dataloader_iter) + with tqdm( + range(total_len), + 
desc=f"Epoch [{epoch + 1}/{args.num_epoch}]", + disable=not coordinator.is_master(), + ) as pbar: + # Forward pass + for _ in pbar: + if use_pipeline: + outputs = booster.execute_pipeline( + train_dataloader_iter, + model, + lambda x, y: x.loss, + optimizer, + return_loss=True, + return_outputs=True, + ) + # Backward and optimize + if is_pp_last_stage: + loss = outputs["loss"] + pbar.set_postfix({"loss": loss.item()}) + else: + data = next(train_dataloader_iter) + data = move_to_cuda(data, torch.cuda.current_device()) + outputs = model(**data) + loss = outputs["loss"] + # Backward + booster.backward(loss, optimizer) + pbar.set_postfix({"loss": loss.item()}) - # Backward - booster.backward(loss, optimizer) optimizer.step() - - # Print batch loss - pbar.set_postfix({'loss': loss.item()}) + optimizer.zero_grad() # Finish training and evaluate logger.info(f"Finish finetuning", ranks=[0]) diff --git a/examples/language/openmoe/train.sh b/examples/language/openmoe/train.sh index 9a55779ca5ef..6712aa10a88b 100644 --- a/examples/language/openmoe/train.sh +++ b/examples/language/openmoe/train.sh @@ -1,3 +1,9 @@ -torchrun --standalone --nproc_per_node 2 train.py \ +torchrun --standalone --nproc_per_node 4 train.py \ --model_name "base" \ + --plugin "hybrid" \ + --pp_size 2 \ + --dp_size 1 \ + --ep_size 2 \ + --use_kernel \ + --zero_stage 1 \ --batch_size 4 From b72fa37b359c6618698fdd1822ab4d74eec5b560 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com> Date: Thu, 21 Sep 2023 17:12:48 +0800 Subject: [PATCH 15/46] [moe] update benchmark (#4770) * init policy * renam,e * update pp * finish pp * update script * update plugin * finish pp * update setup for different plugin * update ci * update ci * update ci * support ep inside or dp inside * update arg for kernel * disable ci * update train script * fsdp * update train * update train * fsdp benchmark * rename * update fsdp bench * fix plugin * update benchmark --- .../plugin/moe_hybrid_parallel_plugin.py | 5 +- .../openmoe/benchmark/benchmark_cai.py | 232 ++++++++++++++++++ .../openmoe/benchmark/benchmark_cai.sh | 56 +++++ .../openmoe/benchmark/benchmark_fsdp.py | 124 ++++++++++ .../openmoe/benchmark/benchmark_fsdp.sh | 25 ++ .../openmoe/benchmark/benchmark_train.py | 196 --------------- .../openmoe/benchmark/benchmark_train.sh | 34 --- examples/language/openmoe/benchmark/utils.py | 151 ++++++++---- .../openmoe/model/modeling_openmoe.py | 45 ++-- examples/language/openmoe/train.py | 80 +++--- 10 files changed, 609 insertions(+), 339 deletions(-) create mode 100644 examples/language/openmoe/benchmark/benchmark_cai.py create mode 100755 examples/language/openmoe/benchmark/benchmark_cai.sh create mode 100644 examples/language/openmoe/benchmark/benchmark_fsdp.py create mode 100755 examples/language/openmoe/benchmark/benchmark_fsdp.sh delete mode 100644 examples/language/openmoe/benchmark/benchmark_train.py delete mode 100755 examples/language/openmoe/benchmark/benchmark_train.sh diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index fab6c2f0cb7b..1f3bb294a7ca 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -103,7 +103,10 @@ def __init__(self, overlap_communication: bool = True, custom_policy: Policy = None) -> None: - super().__init__() + super().__init__(tp_size=tp_size, + pp_size=pp_size, + num_microbatches=num_microbatches, + microbatch_size=microbatch_size) 
assert dist.get_world_size() % ( tp_size * pp_size ) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}' diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py new file mode 100644 index 000000000000..7f36f8a88925 --- /dev/null +++ b/examples/language/openmoe/benchmark/benchmark_cai.py @@ -0,0 +1,232 @@ +import datasets +import torch +import torch.distributed as dist +import transformers +from model.modeling_openmoe import OpenMoeForCausalLM +from model.openmoe_policy import OpenMoeForCausalLMPolicy +from torch.utils.data import Dataset +from tqdm import tqdm +from transformers import Adafactor +from transformers.models.llama import LlamaConfig +from utils import PerformanceEvaluator, get_model_numel + +import colossalai +from colossalai import get_default_parser +from colossalai.booster import Booster +from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.cluster import DistCoordinator +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.moe.manager import MOE_MANAGER +from colossalai.utils import get_current_device + + +def move_to_cuda(batch, device): + return {k: v.to(device) for k, v in batch.items()} + + +class RandomDataset(Dataset): + + def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000): + self.num_samples = num_samples + self.max_length = max_length + self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) + self.attention_mask = torch.ones_like(self.input_ids) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + return { + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "labels": self.input_ids[idx], + } + + +def parse_args(): + # basic settings + parser = get_default_parser() + parser.add_argument( + "--model_name", + type=str, + default="base", + choices=["base", "8b"], + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--batch_size", + type=int, + default=4, + help="Batch size (per dp group) for the training dataloader.", + ) + parser.add_argument( + "--seq_length", + type=int, + default=2048, + help="sequence length for the training dataloader.", + ) + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + parser.add_argument( + "--plugin", + type=str, + default="hybrid", + help="parallel plugin", + choices=["zero1", "zero2", "hybrid"], + ) + # hybrid plugin + parser.add_argument("--pp_size", type=int, default=2, help="pp size") + parser.add_argument("--dp_size", type=int, default=1, help="dp size") + parser.add_argument("--ep_size", type=int, default=2, help="ep size") + parser.add_argument("--zero_stage", type=int, default=1, help="zero stage in hybrid plugin") + parser.add_argument("--microbatch_size", type=int, default=1, help="microbatch size") + # kernel + parser.add_argument( + "--use_kernel", + action="store_true", + help="Use kernel optim. 
Need to install flash attention, apex, triton to enable all kernel optimizations.",
+    )
+    # bench
+    parser.add_argument("--warmup", type=int, default=20)
+    parser.add_argument("--active", type=int, default=20)
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+
+    # Launch ColossalAI
+    colossalai.launch_from_torch(config={}, seed=args.seed)
+    coordinator = DistCoordinator()
+
+    # Manage loggers
+    disable_existing_loggers()
+    logger = get_dist_logger()
+    if coordinator.is_master():
+        datasets.utils.logging.set_verbosity_warning()
+        transformers.utils.logging.set_verbosity_info()
+    else:
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+
+    # Set plugin
+    booster_kwargs = {}
+    if args.plugin == "zero1":
+        dp_size = dist.get_world_size()
+        plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=1)
+        MOE_MANAGER.setup(
+            seed=42,
+            parallel="EP",
+            use_kernel_optim=args.use_kernel,
+        )
+    elif args.plugin == "zero2":
+        dp_size = dist.get_world_size()
+        plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2)
+        MOE_MANAGER.setup(
+            seed=42,
+            parallel="EP",
+            use_kernel_optim=args.use_kernel,
+        )
+    elif args.plugin == "hybrid":
+        dp_size = dist.get_world_size() // args.pp_size
+        plugin = MoeHybridParallelPlugin(
+            tp_size=1,
+            pp_size=args.pp_size,
+            zero_stage=args.zero_stage,
+            microbatch_size=args.microbatch_size,
+            custom_policy=OpenMoeForCausalLMPolicy(),
+            enable_fused_normalization=args.use_kernel,
+            enable_jit_fused=args.use_kernel,
+        )
+        MOE_MANAGER.setup(
+            seed=42,
+            parallel="EP",
+            mode="fixed",
+            fixed_dp_size=args.dp_size,
+            fixed_ep_size=args.ep_size,
+            fixed_pp_size=args.pp_size,
+            use_kernel_optim=args.use_kernel,
+        )
+    else:
+        raise ValueError(f"Invalid plugin {args.plugin}")
+    logger.info(f"Set plugin as {plugin}", ranks=[0])
+
+    # Build OpenMoe model
+    repo_name = "hpcaitech/openmoe-" + args.model_name
+    config = LlamaConfig.from_pretrained(repo_name)
+    setattr(config, "router_aux_loss_factor", 0.1)
+    setattr(config, "router_z_loss_factor", 0.1)
+    setattr(config, "label_smoothing", 0.1)
+    setattr(config, "z_loss_factor", 0.1)
+    model = OpenMoeForCausalLM(config)
+    logger.info(f"Finish init model with config:\n{config}", ranks=[0])
+
+    # Enable gradient checkpointing
+    model.gradient_checkpointing_enable()
+
+    # Prepare tokenizer and dataloader
+    dataset = RandomDataset(
+        num_samples=args.batch_size * (args.warmup + args.active + 1) * dp_size,
+        max_length=args.seq_length,
+    )
+    dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size)
+
+    # Set optimizer
+    optimizer = Adafactor(model.parameters(), weight_decay=0.01)
+
+    model_numel = get_model_numel(model)
+    performance_evaluator = PerformanceEvaluator(
+        model_numel,
+        enable_grad_checkpoint=True,
+        ignore_steps=args.warmup,
+        dp_world_size=dp_size,
+    )
+
+    # Set booster
+    booster = Booster(plugin=plugin, **booster_kwargs)
+    model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
+    use_pipeline = (isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1)
+    is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
+    logger.info(f"Finish init booster", ranks=[0])
+
+    # Start finetuning
+    logger.info(f"Start finetuning", ranks=[0])
+    model.train()
+    train_dataloader_iter = iter(dataloader)
+    total_len = len(train_dataloader_iter) - 1
+    example_data = next(train_dataloader_iter)
+    with tqdm(range(total_len), disable=not coordinator.is_master()) as pbar:
+        for step in pbar:
+            performance_evaluator.on_step_start(step)
+            if use_pipeline:
+                # Forward pass
+                outputs = booster.execute_pipeline(
+                    train_dataloader_iter,
+                    model,
+                    lambda x, y: x.loss,
+                    optimizer,
+                    return_loss=True,
+                    return_outputs=True,
+                )
+                # Backward and optimize
+                if is_pp_last_stage:
+                    loss = outputs["loss"]
+                    pbar.set_postfix({"loss": loss.item()})
+            else:
+                # Forward pass
+                data = next(train_dataloader_iter)
+                data = move_to_cuda(data, torch.cuda.current_device())
+                outputs = model(**data)
+                loss = outputs["loss"]
+                # Backward
+                booster.backward(loss, optimizer)
+                pbar.set_postfix({"loss": loss.item()})
+
+            optimizer.step()
+            optimizer.zero_grad()
+            performance_evaluator.on_step_end(example_data["input_ids"])
+    performance_evaluator.on_fit_end()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/language/openmoe/benchmark/benchmark_cai.sh b/examples/language/openmoe/benchmark/benchmark_cai.sh
new file mode 100755
index 000000000000..24d0c1b23ab2
--- /dev/null
+++ b/examples/language/openmoe/benchmark/benchmark_cai.sh
@@ -0,0 +1,56 @@
+#!/bin/bash
+
+set -xue
+
+NUM_GPU=4
+MODEL="base"
+BATCH_SIZE=1
+SEQ_LENGTH=2048
+WARMUP=10
+ACTIVE=10
+
+# HACK: make model importable
+example_dir=$(dirname $(realpath $(dirname $0)))
+if [ -z ${PYTHONPATH+x} ]; then
+    export PYTHONPATH=$example_dir
+else
+    export PYTHONPATH=$example_dir:$PYTHONPATH
+fi
+
+# hybrid
+torchrun --standalone --nproc_per_node $NUM_GPU \
+    $example_dir/benchmark/benchmark_cai.py \
+    --model_name $MODEL \
+    --batch_size $BATCH_SIZE \
+    --seq_length $SEQ_LENGTH \
+    --warmup $WARMUP \
+    --active $ACTIVE \
+    --use_kernel \
+    --plugin hybrid \
+    --pp_size 2 \
+    --dp_size 1 \
+    --ep_size 2 \
+    --zero_stage 1 \
+    --microbatch_size 1
+
+# zero1
+torchrun --standalone --nproc_per_node $NUM_GPU \
+    $example_dir/benchmark/benchmark_cai.py \
+    --model_name $MODEL \
+    --batch_size $BATCH_SIZE \
+    --seq_length $SEQ_LENGTH \
+    --warmup $WARMUP \
+    --active $ACTIVE \
+    --plugin zero1 \
+    --use_kernel
+
+# zero2
+torchrun --standalone --nproc_per_node $NUM_GPU \
+    $example_dir/benchmark/benchmark_cai.py \
+    --model_name $MODEL \
+    --batch_size $BATCH_SIZE \
+    --seq_length $SEQ_LENGTH \
+    --warmup $WARMUP \
+    --active $ACTIVE \
+    --plugin zero2 \
+    --use_kernel
diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.py b/examples/language/openmoe/benchmark/benchmark_fsdp.py
new file mode 100644
index 000000000000..cb231687ef39
--- /dev/null
+++ b/examples/language/openmoe/benchmark/benchmark_fsdp.py
@@ -0,0 +1,124 @@
+import argparse
+import os
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import tqdm
+from model.modeling_openmoe import LlamaConfig, OpenMoeForCausalLM
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
+from torch.utils.data import Dataset
+from torch.utils.data.distributed import DistributedSampler
+from transformers import Adafactor
+from transformers.models.llama import LlamaConfig
+from utils import PerformanceEvaluator, get_model_numel
+
+from colossalai.moe.manager import MOE_MANAGER
+
+
+class RandomDataset(Dataset):
+
+    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
+        self.num_samples = num_samples
+        self.max_length = max_length
+        self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length))
+        self.attention_mask = torch.ones_like(self.input_ids)
+
+    def __len__(self):
+        return self.num_samples
diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.py b/examples/language/openmoe/benchmark/benchmark_fsdp.py
new file mode 100644
index 000000000000..cb231687ef39
--- /dev/null
+++ b/examples/language/openmoe/benchmark/benchmark_fsdp.py
@@ -0,0 +1,124 @@
+import argparse
+import os
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import tqdm
+from model.modeling_openmoe import LlamaConfig, OpenMoeForCausalLM
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
+from torch.utils.data import Dataset
+from torch.utils.data.distributed import DistributedSampler
+from transformers import Adafactor
+from transformers.models.llama import LlamaConfig
+from utils import PerformanceEvaluator, get_model_numel
+
+from colossalai.moe.manager import MOE_MANAGER
+
+
+class RandomDataset(Dataset):
+
+    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
+        self.num_samples = num_samples
+        self.max_length = max_length
+        self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length))
+        self.attention_mask = torch.ones_like(self.input_ids)
+
+    def __len__(self):
+        return self.num_samples
+
+    def __getitem__(self, idx):
+        return {
+            "input_ids": self.input_ids[idx],
+            "attention_mask": self.attention_mask[idx],
+            "labels": self.input_ids[idx],
+        }
+
+
+def fsdp_main(rank, world_size, args):
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "14501"
+    # initialize the process group
+    dist.init_process_group("nccl", rank=rank, world_size=world_size)
+
+    MOE_MANAGER.setup(seed=42, parallel=None, use_kernel_optim=False)
+
+    dp_size = dist.get_world_size()
+    dataset = RandomDataset(max_length=args.seq_length,
+                            num_samples=args.batch_size * (args.warmup + args.active) * dp_size)
+    sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=False)
+    train_kwargs = {"batch_size": args.batch_size, "sampler": sampler}
+    train_loader = torch.utils.data.DataLoader(dataset, **train_kwargs)
+    torch.cuda.set_device(rank)
+
+    config = LlamaConfig.from_pretrained("hpcaitech/openmoe-%s" % args.model_name)
+    setattr(config, "router_aux_loss_factor", 0.1)
+    setattr(config, "router_z_loss_factor", 0.1)
+    setattr(config, "label_smoothing", 0.1)
+    setattr(config, "z_loss_factor", 0.1)
+    model = OpenMoeForCausalLM(config).to(rank)
+    # wrap the model with FSDP
+    model = FSDP(
+        model,
+        mixed_precision=MixedPrecision(
+            param_dtype=torch.float16,
+            reduce_dtype=torch.float16,
+            buffer_dtype=torch.float16,
+        ),
+    )
+    optimizer = Adafactor(model.parameters())
+    model.train()
+
+    model_numel = get_model_numel(model)
+    performance_evaluator = PerformanceEvaluator(
+        model_numel,
+        enable_grad_checkpoint=True,
+        ignore_steps=args.warmup,
+        dp_world_size=dist.get_world_size(),
+    )
+
+    for step, data in tqdm.tqdm(enumerate(train_loader), total=len(train_loader)):
+        performance_evaluator.on_step_start(step)
+        input_ids, attention_mask, labels = (
+            data["input_ids"].cuda(),
+            data["attention_mask"].cuda(),
+            data["labels"].cuda(),
+        )
+
+        optimizer.zero_grad()
+        output = model(
+            input_ids=input_ids,
+            labels=labels,
+            attention_mask=attention_mask,
+            chunk_head=False,
+        )
+        loss = output["loss"]
+        loss.backward()
+        optimizer.step()
+        performance_evaluator.on_step_end(input_ids)
+
+    performance_evaluator.on_fit_end()
+    if dist.get_rank() == 0:
+        print(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        default="base",
+        choices=["base", "8b"],
+        help="base or 8b",
+    )
+    parser.add_argument("--batch_size", type=int, default=1)
+    parser.add_argument("--seq_length", type=int, default=2048)
+    parser.add_argument("--warmup", type=int, default=20)
+    parser.add_argument("--active", type=int, default=20)
+    args = parser.parse_args()
+
+    torch.manual_seed(42)
+
+    WORLD_SIZE = torch.cuda.device_count()
+    mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)
diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.sh b/examples/language/openmoe/benchmark/benchmark_fsdp.sh
new file mode 100755
index 000000000000..a4cb32019431
--- /dev/null
+++ b/examples/language/openmoe/benchmark/benchmark_fsdp.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+
+set -xue
+
+NUM_GPU=4
+MODEL="base"
+BATCH_SIZE=1
+SEQ_LENGTH=2048
+WARMUP=10
+ACTIVE=10
+
+# HACK: make model importable
+example_dir=$(dirname $(realpath $(dirname $0)))
+if [ -z ${PYTHONPATH+x} ]; then
+    export PYTHONPATH=$example_dir
+else
+    export PYTHONPATH=$example_dir:$PYTHONPATH
+fi
+
+python $example_dir/benchmark/benchmark_fsdp.py \
+    --model_name 
$MODEL \ + --batch_size $BATCH_SIZE \ + --seq_length $SEQ_LENGTH \ + --warmup $WARMUP \ + --active $ACTIVE diff --git a/examples/language/openmoe/benchmark/benchmark_train.py b/examples/language/openmoe/benchmark/benchmark_train.py deleted file mode 100644 index 373516c56f84..000000000000 --- a/examples/language/openmoe/benchmark/benchmark_train.py +++ /dev/null @@ -1,196 +0,0 @@ -import colossalai -import datasets -import torch -import transformers -from colossalai import get_default_parser -from colossalai.booster import Booster -from colossalai.booster.plugin import LowLevelZeroPlugin -from colossalai.cluster import DistCoordinator -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.moe.manager import MOE_MANAGER -from colossalai.moe.utils import skip_init -from colossalai.utils import get_current_device -from model.modeling_openmoe import OpenMoeForCausalLM -from torch.utils.data import Dataset -from tqdm import tqdm -from transformers import Adafactor -from transformers.models.llama import LlamaConfig -from utils import SimpleTimer, print_model_numel - - -class RandomDataset(Dataset): - - def __init__(self, - num_samples: int = 1000, - max_length: int = 2048, - vocab_size: int = 32000): - self.num_samples = num_samples - self.max_length = max_length - self.input_ids = torch.randint(0, vocab_size, - (num_samples, max_length), - device=get_current_device()) - self.attention_mask = torch.ones_like(self.input_ids, - device=get_current_device()) - - def __len__(self): - return self.num_samples - - def __getitem__(self, idx): - return { - 'input_ids': self.input_ids[idx], - 'attention_mask': self.attention_mask[idx], - 'labels': self.input_ids[idx] - } - - -def parse_args(): - parser = get_default_parser() - # TODO: add model_name - # parser.add_argument("--model_name", type=str, default="base", choices=["base", "8b"], - # help="Path to pretrained model or model identifier from huggingface.co/models.") - parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.") - parser.add_argument("--batch_size", type=int, default=4, help="Batch size (per dp group) for the training dataloader.") - parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") - parser.add_argument("--num_samples", type=int, default=1000, help="Number of samples in the dataset.") - - args = parser.parse_args() - return args - - -def main(): - args = parse_args() - - MDOEL_CONFIG = { - "architectures": [ - "OpenMoeForCausalLM" - ], - "capacity_factor_eval": 2.0, - "capacity_factor_train": 1.25, - "drop_tks": True, - "dropout_rate": 0.0, - "expert_parallel": None, - "gated": True, - "head_dim": 64, - "hidden_act": "swiglu", - "hidden_size": 768, - "intermediate_size": 2048, - "label_smoothing": 0.0, - "layer_norm_epsilon": 1e-06, - "min_capacity": 4, - "moe_layer_interval": 4, - "noisy_policy": None, - "num_attention_heads": 12, - "num_experts": 16, - "num_hidden_layers": 12, - "num_key_value_heads": 12, - "pretraining_tp": 1, - "rope_scaling": None, - "router_aux_loss_factor": 0.01, - "router_z_loss_factor": 0.0001, - "topk": 2, - "torch_dtype": "float32", - "vocab_size": 256384, - "z_loss_factor": 0.0001 - } - OPTIM_CONFIG = { - "decay_rate": -0.8, - "weight_decay": 0.01, - } - - # update config from args - for k in MDOEL_CONFIG: - if hasattr(args, k): - MDOEL_CONFIG[k] = getattr(args, k) - - # Launch ColossalAI - colossalai.launch_from_torch(config={}, seed=args.seed) - coordinator = DistCoordinator() - - # Set up moe - 
MOE_MANAGER.setup(seed=42, parallel="EP") - - # Manage loggers - disable_existing_loggers() - logger = get_dist_logger() - if coordinator.is_master(): - datasets.utils.logging.set_verbosity_warning() - transformers.utils.logging.set_verbosity_info() - else: - datasets.utils.logging.set_verbosity_error() - transformers.utils.logging.set_verbosity_error() - - # Build OpenMoe model - config = LlamaConfig() - for k, v in MDOEL_CONFIG.items(): - setattr(config, k, v) - - with skip_init(): - model = OpenMoeForCausalLM(config) - - logger.info(f"Finish init model with config:\n{config}", ranks=[0]) - model_param = sum(p.numel() for p in model.parameters() if p.requires_grad) - logger.info(f"Model param count: {model_param/1e6:.2f}M", ranks=[0]) - - # Enable gradient checkpointing - model.gradient_checkpointing_enable() - - # Set plugin - plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2) - logger.info(f"Set plugin as {plugin}", ranks=[0]) - - # Prepare tokenizer and dataloader - dataset = RandomDataset(num_samples=args.num_samples) - dataloader = plugin.prepare_dataloader(dataset, - batch_size=args.batch_size, - shuffle=True, - drop_last=True) - - # Set optimizer - optimizer = Adafactor(model.parameters(), - decay_rate=OPTIM_CONFIG["decay_rate"], - weight_decay=OPTIM_CONFIG["weight_decay"]) - - # Set booster - booster = Booster(plugin=plugin) - model, optimizer, _, dataloader, _ = booster.boost(model=model, - optimizer=optimizer, - dataloader=dataloader) - - # Start benchmark - model.train() - logger.info(f"Start benchmark", ranks=[0]) - - timer = SimpleTimer() - for epoch in range(args.num_epoch): - for batch in tqdm(dataloader, - desc=f'Epoch [{epoch + 1}]', - disable=not coordinator.is_master()): - timer.start("train_step") - - # Forward - timer.start("forward") - outputs = model(use_cache=False, chunk_head=True, **batch) - loss = outputs['loss'] - torch.cuda.synchronize() - timer.stop("forward") - - # Backward - timer.start("backward") - booster.backward(loss, optimizer) - torch.cuda.synchronize() - timer.stop("backward") - - # Optimizer step - timer.start("optimizer_step") - optimizer.step() - optimizer.zero_grad() - torch.cuda.synchronize() - timer.stop("optimizer_step") - - timer.stop("train_step") - - logger.info(f"Benchmark result:\n{repr(timer)}", ranks=[0]) - - -if __name__ == "__main__": - main() diff --git a/examples/language/openmoe/benchmark/benchmark_train.sh b/examples/language/openmoe/benchmark/benchmark_train.sh deleted file mode 100755 index 0496a31a7479..000000000000 --- a/examples/language/openmoe/benchmark/benchmark_train.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/bin/bash - -set -xue - -BENCHMARK_DIR=benchmark -NUM_GPU=2 - -set_n_least_used_CUDA_VISIBLE_DEVICES() { - local n=${1:-"9999"} - echo "GPU Memory Usage:" - local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | - tail -n +2 | - nl -v 0 | - tee /dev/tty | - sort -g -k 2 | - awk '{print $1}' | - head -n $n) - export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') - echo "Now CUDA_VISIBLE_DEVICES is set to:" - echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" -} - -set_n_least_used_CUDA_VISIBLE_DEVICES $NUM_GPU - -# HACK: make model importable -example_dir=$(dirname $(realpath $(dirname $0))) -if [ -z ${PYTHONPATH+x} ]; then - export PYTHONPATH=$example_dir -else - export PYTHONPATH=$example_dir:$PYTHONPATH -fi - -torchrun --standalone --nproc_per_node $NUM_GPU \ - $example_dir/$BENCHMARK_DIR/benchmark_train.py diff --git a/examples/language/openmoe/benchmark/utils.py 
b/examples/language/openmoe/benchmark/utils.py
index d2edee64451c..7a0955bb028a 100644
--- a/examples/language/openmoe/benchmark/utils.py
+++ b/examples/language/openmoe/benchmark/utils.py
@@ -1,61 +1,126 @@
-import dataclasses
-import time
-from typing import Dict
+from time import time
+from typing import Optional
+
+import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch import Tensor
+
 from colossalai.logging import DistributedLogger


-def print_model_numel(logger: DistributedLogger,
-                      model: nn.Module) -> None:
+def print_model_numel(logger: DistributedLogger, model: nn.Module) -> None:
     B = 1024**3
     M = 1024**2
     K = 1024
     outputs = "Model param count: "
     model_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
     if model_param >= B:
-        outputs += f'{model_param / B:.2f} B\n'
+        outputs += f"{model_param / B:.2f} B\n"
     elif model_param >= M:
-        outputs += f'{model_param / M:.2f} M\n'
+        outputs += f"{model_param / M:.2f} M\n"
     elif model_param >= K:
-        outputs += f'{model_param / K:.2f} K\n'
+        outputs += f"{model_param / K:.2f} K\n"
     else:
-        outputs += f'{model_param}\n'
+        outputs += f"{model_param}\n"
     logger.info(outputs, ranks=[0])


-@dataclasses.dataclass
-class TimingItem():
-    last_time: float = 0.0
-    total_time: float = 0.0
-    count: float = 0
-
-    def __str__(self) -> str:
-        return f"average time: {self.total_time/self.count * 1000:.2f} ms"
-
-
-class SimpleTimer():
-    def __init__(self, warmup: int = 20) -> None:
-        self.timing_items: Dict[str, TimingItem] = {}
-        self.warmup = warmup
-
-    def start(self, name: str):
-        if name not in self.timing_items:
-            self.timing_items[name] = TimingItem()
-        self.timing_items[name].last_time = time.time()
-
-    def stop(self, name: str):
-        assert name in self.timing_items
-        timing_item = self.timing_items[name]
-        timing_item.total_time += time.time() - timing_item.last_time
-        timing_item.count += 1
-        if timing_item.count > self.warmup:
-            timing_item.count = 0
-            timing_item.total_time = 0.0
-
-    def __repr__(self) -> str:
-        result = "[Timer]:\n"
-        for name, timing_item in self.timing_items.items():
-            result += f"    {name}: {timing_item}\n"
-        return result
+def get_model_numel(model: nn.Module) -> int:
+    model_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    return model_param
+
+
+def divide(x: float, y: float) -> float:
+    if y == 0:
+        return float("inf")
+    elif y == float("inf"):
+        return float("nan")
+    return x / y
+
+
+@torch.no_grad()
+def all_reduce_mean(x: float, world_size: int) -> float:
+    if world_size == 1:
+        return x
+    tensor = torch.tensor([x], device=torch.cuda.current_device())
+    dist.all_reduce(tensor)
+    tensor = tensor / world_size
+    return tensor.item()
+
+
+class Timer:
+
+    def __init__(self) -> None:
+        self.start_time: Optional[float] = None
+        self.duration: float = 0.0
+
+    def start(self) -> None:
+        self.start_time = time()
+
+    def end(self) -> None:
+        assert self.start_time is not None
+        self.duration += time() - self.start_time
+        self.start_time = None
+
+    def reset(self) -> None:
+        self.duration = 0.0
+
+
+class PerformanceEvaluator:
+    """
+    Callback for evaluating the performance of the model.
+    Args:
+        model_numel: The number of parameters of the model.
+        enable_grad_checkpoint: Whether gradient checkpointing is enabled.
+        ignore_steps: The number of warmup steps to ignore when measuring performance.
+        dp_world_size: The data-parallel world size.
+    """
+
+    def __init__(
+        self,
+        model_numel: int,
+        enable_grad_checkpoint: bool = False,
+        ignore_steps: int = 0,
+        dp_world_size: Optional[int] = None,
+    ) -> None:
+        self.model_numel = model_numel
+        self.enable_grad_checkpoint = enable_grad_checkpoint
+        self.ignore_steps = ignore_steps
+        self.dp_world_size = dp_world_size
+        self.world_size = dist.get_world_size()
+        self.disable: bool = False
+        self.timer = Timer()
+        self.num_samples: int = 0
+        self.flop: int = 0
+
+    def on_step_start(self, step: int) -> None:
+        self.disable = self.ignore_steps > 0 and step < self.ignore_steps
+        if self.disable:
+            return
+        torch.cuda.synchronize()
+        self.timer.start()
+
+    def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
+        if self.disable:
+            return
+        torch.cuda.synchronize()
+        self.timer.end()
+
+        batch_size, seq_len = input_ids.shape
+
+        self.num_samples += batch_size
+        self.flop += (batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint)))
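The flop counter above uses the standard decoder-LM estimate: roughly 2 * model_numel FLOPs per token for one forward pass, a factor of 3 for forward plus backward (backward costs about twice a forward), and one extra forward when gradient checkpointing recomputes activations. A worked example with assumed, illustrative numbers (8B parameters, batch 16, sequence 2048, checkpointing on):

    flop_per_step = 16 * 2048 * 8e9 * 2 * (3 + 1)    # ~2.1e15 FLOPs per step
    # mirroring on_fit_end: per-GPU TFLOPS, assuming a 10 s step and mp_world_size == 2
    tflops_per_gpu = flop_per_step / 1e12 / 10.0 / 2  # ~105 TFLOPS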
+    def on_fit_end(self) -> None:
+        avg_duration = all_reduce_mean(self.timer.duration, self.world_size)
+        avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12)
+        mp_world_size = self.world_size // self.dp_world_size
+        avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size
+        if dist.get_rank() == 0:
+            print(
+                f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, "
+                f"avg_throughput: {avg_throughput}")
+            print(f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}")
diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py
index a774d4e9fd55..4d5ff19936b6 100644
--- a/examples/language/openmoe/model/modeling_openmoe.py
+++ b/examples/language/openmoe/model/modeling_openmoe.py
@@ -24,23 +24,24 @@
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
+from torch import nn
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRMSNorm
+from transformers.utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
+
 from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN
 from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
 from colossalai.moe.layers import SparseMLP
 from colossalai.moe.manager import MOE_MANAGER
-from torch import nn
-from transformers.modeling_outputs import (BaseModelOutputWithPast,
-                                           CausalLMOutputWithPast)
-from transformers.modeling_utils import PreTrainedModel
-from transformers.models.llama import LlamaConfig
-from transformers.models.t5.modeling_t5 import T5LayerNorm
-from transformers.utils import (add_start_docstrings,
-                                add_start_docstrings_to_model_forward, logging,
-                                replace_return_docstrings)

 if HAS_TRITON:
-    from colossalai.kernel.triton.llama_act_combine_kernel import \
-        LlamaActCombine
+    from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine

 logger = logging.get_logger(__name__)

@@ -305,23 +306,21 @@ def forward(
         query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
-        attn_output = flash_attn_func(query_states,
-                                      key_states,
-
value_states, - softmax_scale=1.0, - causal=True) + attn_output = flash_attn_func(query_states, key_states, value_states, softmax_scale=1.0, causal=True) attn_output = attn_output.transpose(1, 2).contiguous() else: attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError(f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}") + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}") if attention_mask is not None: if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}") + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) if self.training: attention_mask = attention_mask.clone().detach() attention_mask[:, :, :, 0] = 0 @@ -358,8 +357,8 @@ def __init__(self, config: LlamaConfig, moe: bool): self.hidden_size = config.hidden_size self.moe = moe self.self_attn = OpenMoeAttention(config=config) - self.input_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) if self.moe: self.mlp = SparseMLP( num_experts=config.num_experts, @@ -374,7 +373,7 @@ def __init__(self, config: LlamaConfig, moe: bool): intermediate_size=config.intermediate_size, activation=config.hidden_act, gated=config.gated) - self.pre_extra_mlp_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_extra_mlp_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.extra_mlp = OpenMoeMLP(config) else: self.mlp = OpenMoeMLP(config) @@ -570,7 +569,7 @@ def __init__(self, config: LlamaConfig): OpenMoeDecoderLayer(config, moe=True if (i + 1) % config.moe_layer_interval == 0 else False) for i in range(config.num_hidden_layers) ]) - self.norm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False # Initialize weights and apply final processing diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index 2099bbde91f5..e276759043a9 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -67,6 +67,7 @@ def parse_args(): "--model_name", type=str, default="base", + choices=["base", "8b"], help="Path to pretrained model or model identifier from huggingface.co/models.", ) parser.add_argument( @@ -132,26 +133,6 @@ def main(): colossalai.launch_from_torch(config={}, seed=args.seed) coordinator = DistCoordinator() - # Set up moe - if args.plugin in ["zero1", "zero2"]: - MOE_MANAGER.setup( - seed=42, - parallel="EP", - use_kernel_optim=False if args.model_name == "test" else args.use_kernel, - ) - elif args.plugin == "hybrid": - assert (args.dp_size * args.ep_size * - args.pp_size == coordinator.world_size), "dp_size * ep_size * pp_size must equal to world_size" - MOE_MANAGER.setup( - seed=42, - parallel="EP", - mode="fixed", - fixed_dp_size=args.dp_size, - fixed_ep_size=args.ep_size, - fixed_pp_size=args.pp_size, - use_kernel_optim=False if 
args.model_name == "test" else args.use_kernel, - ) - # Manage loggers disable_existing_loggers() logger = get_dist_logger() @@ -162,32 +143,22 @@ def main(): datasets.utils.logging.set_verbosity_error() transformers.utils.logging.set_verbosity_error() - # Build OpenMoe model - repo_name = "hpcaitech/openmoe-" + args.model_name - if args.model_name == "test": - config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base") - config.vocab_size = 32000 - else: - config = LlamaConfig.from_pretrained(repo_name) - setattr(config, "router_aux_loss_factor", args.router_aux_loss_factor) - setattr(config, "router_z_loss_factor", args.router_z_loss_factor) - setattr(config, "label_smoothing", args.label_smoothing) - setattr(config, "z_loss_factor", args.z_loss_factor) - with skip_init(): - model = OpenMoeForCausalLM(config) - if args.model_name != "test": - load_ckpt(repo_name, model) - logger.info(f"Finish init model with config:\n{config}", ranks=[0]) - - # Enable gradient checkpointing - model.gradient_checkpointing_enable() - # Set plugin booster_kwargs = {} if args.plugin == "zero1": plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=1) + MOE_MANAGER.setup( + seed=42, + parallel="EP", + use_kernel_optim=args.use_kernel, + ) elif args.plugin == "zero2": plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2) + MOE_MANAGER.setup( + seed=42, + parallel="EP", + use_kernel_optim=args.use_kernel, + ) elif args.plugin == "hybrid": plugin = MoeHybridParallelPlugin( tp_size=1, @@ -198,13 +169,37 @@ def main(): enable_fused_normalization=args.use_kernel, enable_jit_fused=args.use_kernel, ) + MOE_MANAGER.setup( + seed=42, + parallel="EP", + mode="fixed", + fixed_dp_size=args.dp_size, + fixed_ep_size=args.ep_size, + fixed_pp_size=args.pp_size, + use_kernel_optim=args.use_kernel, + ) else: raise ValueError(f"Invalid plugin {args.plugin}") logger.info(f"Set plugin as {plugin}", ranks=[0]) + # Build OpenMoe model + repo_name = "hpcaitech/openmoe-" + args.model_name + config = LlamaConfig.from_pretrained(repo_name) + setattr(config, "router_aux_loss_factor", args.router_aux_loss_factor) + setattr(config, "router_z_loss_factor", args.router_z_loss_factor) + setattr(config, "label_smoothing", args.label_smoothing) + setattr(config, "z_loss_factor", args.z_loss_factor) + with skip_init(): + model = OpenMoeForCausalLM(config) + load_ckpt(repo_name, model) + logger.info(f"Finish init model with config:\n{config}", ranks=[0]) + + # Enable gradient checkpointing + model.gradient_checkpointing_enable() + # Prepare tokenizer and dataloader tokenizer = T5Tokenizer.from_pretrained("google/umt5-small") - dataset = RandomDataset(num_samples=1000 if args.model_name != "test" else 50) + dataset = RandomDataset(num_samples=1000) dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True) # Set optimizer @@ -228,9 +223,9 @@ def main(): desc=f"Epoch [{epoch + 1}/{args.num_epoch}]", disable=not coordinator.is_master(), ) as pbar: - # Forward pass for _ in pbar: if use_pipeline: + # Forward pass outputs = booster.execute_pipeline( train_dataloader_iter, model, @@ -244,6 +239,7 @@ def main(): loss = outputs["loss"] pbar.set_postfix({"loss": loss.item()}) else: + # Forward pass data = next(train_dataloader_iter) data = move_to_cuda(data, torch.cuda.current_device()) outputs = model(**data) From 5c97a96ca30cdc232390b7b97f009f26923eae66 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com> Date: Fri, 22 Sep 2023 10:42:49 +0800 Subject: [PATCH 16/46] 
[moe] fix ci (#4772) * init policy * renam,e * update pp * finish pp * update script * update plugin * finish pp * update setup for different plugin * update ci * update ci * update ci * support ep inside or dp inside * update arg for kernel * disable ci * update train script * fsdp * update train * update train * fsdp benchmark * rename * update fsdp bench * fix plugin * update benchmark * fix ci * fix ci * rename * update ci * update test * update vocab * update chunk head --- colossalai/moe/_operation.py | 38 ++++++++++++------- colossalai/moe/layers.py | 2 +- .../openmoe/benchmark/benchmark_fsdp.py | 1 - .../openmoe/model/modeling_openmoe.py | 2 +- examples/language/openmoe/test_ci.sh | 11 ++++++ examples/language/openmoe/train.py | 26 ++++++++----- 6 files changed, 55 insertions(+), 25 deletions(-) diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index a67feaefbfb8..c47d8df296b4 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -5,17 +5,21 @@ from torch import Tensor from torch.distributed import ProcessGroup -try: - from colossalai._C import moe -except: +from colossalai.moe.manager import MOE_MANAGER + +MOE_KERNEL = None + + +def load_moe(): + global MOE_KERNEL from colossalai.kernel.op_builder import MOEBuilder - moe = MOEBuilder().load() + + MOE_KERNEL = MOEBuilder().load() class AllGather(torch.autograd.Function): @staticmethod def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: - if ctx is not None: ctx.comm_grp = group @@ -86,7 +90,10 @@ def forward(ctx, tokens, mask, dest_idx, ec): s = tokens.size(0) h = tokens.size(1) - expert_input = moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx) + if MOE_KERNEL is None: + load_moe() + + expert_input = MOE_KERNEL.dispatch_forward(s, ec, h, tokens, mask, dest_idx) ctx.save_for_backward(mask, dest_idx) ctx.s = s @@ -98,7 +105,7 @@ def forward(ctx, tokens, mask, dest_idx, ec): @staticmethod def backward(ctx, output_grad): mask, dest_idx = ctx.saved_tensors - d_tokens = moe.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx) + d_tokens = MOE_KERNEL.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx) return d_tokens, None, None, None @@ -112,9 +119,11 @@ def forward(ctx, expert_tokens, logits, mask, dest_idx, ec): c = ec // e h = expert_tokens.size(-1) - fp16_flag = (expert_tokens.dtype == torch.float16) + fp16_flag = expert_tokens.dtype == torch.float16 cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens - ctokens = moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx) + if MOE_KERNEL is None: + load_moe() + ctokens = MOE_KERNEL.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx) output = ctokens.to(torch.float16) if fp16_flag else ctokens ctx.save_for_backward(expert_tokens, logits, mask, dest_idx) @@ -130,9 +139,10 @@ def forward(ctx, expert_tokens, logits, mask, dest_idx, ec): def backward(ctx, tokens_grad): expert_tokens, logits, mask, dest_idx = ctx.saved_tensors - cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 else tokens_grad + cb_grad = (tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 else tokens_grad) cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens - d_expert, d_logits = moe.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, mask, dest_idx) + d_expert, d_logits = MOE_KERNEL.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, mask, + dest_idx) 
d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert
         return d_expert, d_logits, None, None, None
@@ -141,8 +151,10 @@ def backward(ctx, tokens_grad):

 def moe_cumsum(inputs: Tensor):
     dim0 = inputs.size(0)
     flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
-    if flag:
-        return moe.cumsum_sub_one(inputs)
+    if flag and MOE_MANAGER.use_kernel_optim:
+        if MOE_KERNEL is None:
+            load_moe()
+        return MOE_KERNEL.cumsum_sub_one(inputs)
     else:
         return torch.cumsum(inputs, dim=0) - 1
diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py
index 1255a4816041..a78bfe0a3d74 100644
--- a/colossalai/moe/layers.py
+++ b/colossalai/moe/layers.py
@@ -58,7 +58,7 @@ def __init__(self,
         super().__init__()
         self.hidden_size = hidden_size
         self.num_experts = num_experts
-        self.use_kernel = True if MOE_MANAGER.use_kernel_optim else False
+        self.use_kernel = MOE_MANAGER.use_kernel_optim
         self.expert_parallel = expert_parallel
         assert expert_parallel in ["EP", "TP", None], f"Unsupported expert parallel type {expert_parallel}"
diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.py b/examples/language/openmoe/benchmark/benchmark_fsdp.py
index cb231687ef39..c7357c06e5c7 100644
--- a/examples/language/openmoe/benchmark/benchmark_fsdp.py
+++ b/examples/language/openmoe/benchmark/benchmark_fsdp.py
@@ -58,7 +58,6 @@ def fsdp_main(rank, world_size, args):
     setattr(config, "label_smoothing", 0.1)
     setattr(config, "z_loss_factor", 0.1)
     model = OpenMoeForCausalLM(config).to(rank)
-    # wrap the model with FSDP
     model = FSDP(
         model,
         mixed_precision=MixedPrecision(
diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py
index 4d5ff19936b6..6933f108a09e 100644
--- a/examples/language/openmoe/model/modeling_openmoe.py
+++ b/examples/language/openmoe/model/modeling_openmoe.py
@@ -776,7 +776,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        chunk_head: Optional[bool] = None,
+        chunk_head: Optional[bool] = True,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
diff --git a/examples/language/openmoe/test_ci.sh b/examples/language/openmoe/test_ci.sh
index e69de29bb2d1..86742e088f71 100644
--- a/examples/language/openmoe/test_ci.sh
+++ b/examples/language/openmoe/test_ci.sh
@@ -0,0 +1,11 @@
+pip install -r requirements.txt
+
+# inference
+python infer.py --model "test"
+
+# train
+torchrun --standalone --nproc_per_node 4 train.py \
+    --num_epoch 1 \
+    --model_name "test" \
+    --plugin zero2 \
+    --batch_size 1
diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py
index e276759043a9..a7f46f2f693b 100644
--- a/examples/language/openmoe/train.py
+++ b/examples/language/openmoe/train.py
@@ -67,7 +67,7 @@ def parse_args():
         "--model_name",
         type=str,
         default="base",
-        choices=["base", "8b"],
+        choices=["base", "8b", "test"],
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
     parser.add_argument(
@@ -132,6 +132,7 @@ def main():
     # Launch ColossalAI
     colossalai.launch_from_torch(config={}, seed=args.seed)
     coordinator = DistCoordinator()
+    test_mode = args.model_name == "test"

     # Manage loggers
     disable_existing_loggers()
@@ -150,14 +151,14 @@ def main():
         MOE_MANAGER.setup(
             seed=42,
             parallel="EP",
-            use_kernel_optim=args.use_kernel,
+            use_kernel_optim=args.use_kernel if not test_mode else False,
         )
     elif args.plugin == "zero2":
         plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2)
MOE_MANAGER.setup( seed=42, parallel="EP", - use_kernel_optim=args.use_kernel, + use_kernel_optim=args.use_kernel if not test_mode else False, ) elif args.plugin == "hybrid": plugin = MoeHybridParallelPlugin( @@ -166,8 +167,8 @@ def main(): zero_stage=args.zero_stage, microbatch_size=args.microbatch_size, custom_policy=OpenMoeForCausalLMPolicy(), - enable_fused_normalization=args.use_kernel, - enable_jit_fused=args.use_kernel, + enable_fused_normalization=args.use_kernel if not test_mode else False, + enable_jit_fused=args.use_kernel if not test_mode else False, ) MOE_MANAGER.setup( seed=42, @@ -183,15 +184,22 @@ def main(): logger.info(f"Set plugin as {plugin}", ranks=[0]) # Build OpenMoe model - repo_name = "hpcaitech/openmoe-" + args.model_name - config = LlamaConfig.from_pretrained(repo_name) + if test_mode: + config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base") + config.hidden_size = 64 + config.intermediate_size = 128 + config.vocab_size = 32000 + else: + repo_name = "hpcaitech/openmoe-" + args.model_name + config = LlamaConfig.from_pretrained(repo_name) setattr(config, "router_aux_loss_factor", args.router_aux_loss_factor) setattr(config, "router_z_loss_factor", args.router_z_loss_factor) setattr(config, "label_smoothing", args.label_smoothing) setattr(config, "z_loss_factor", args.z_loss_factor) with skip_init(): model = OpenMoeForCausalLM(config) - load_ckpt(repo_name, model) + if not test_mode: + load_ckpt(repo_name, model) logger.info(f"Finish init model with config:\n{config}", ranks=[0]) # Enable gradient checkpointing @@ -199,7 +207,7 @@ def main(): # Prepare tokenizer and dataloader tokenizer = T5Tokenizer.from_pretrained("google/umt5-small") - dataset = RandomDataset(num_samples=1000) + dataset = RandomDataset(num_samples=1000 if not test_mode else 20) dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True) # Set optimizer From c68303b80c49638dfb1af5eb025ed2a64b9fa3eb Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com> Date: Fri, 29 Sep 2023 10:23:02 +0800 Subject: [PATCH 17/46] [moe] update benchmark scripts and ckpt io (#4804) * update benchmark script * update pp strategy * update plugin * update bench script * optimize * update pp layers * update zero ep * ep * update ckpt * update test --- .../plugin/moe_hybrid_parallel_plugin.py | 61 +++++ colossalai/moe/checkpoint.py | 234 ++++++++++++++++-- .../openmoe/benchmark/benchmark_cai.py | 20 +- .../openmoe/benchmark/benchmark_cai.sh | 27 +- .../openmoe/benchmark/benchmark_fsdp.py | 33 ++- .../openmoe/benchmark/benchmark_fsdp.sh | 8 +- .../openmoe/model/modeling_openmoe.py | 5 +- .../language/openmoe/model/openmoe_policy.py | 23 +- examples/language/openmoe/test_ci.sh | 12 +- examples/language/openmoe/train.py | 33 ++- tests/test_moe/test_moe_checkpoint.py | 149 +++++++++-- 11 files changed, 516 insertions(+), 89 deletions(-) diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 1f3bb294a7ca..784204528d65 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -1,10 +1,15 @@ +import random from typing import Optional +import numpy as np import torch import torch.distributed as dist +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelPlugin from 
colossalai.cluster import ProcessGroupMesh
+from colossalai.moe import MoeCheckpintIO
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig
@@ -174,3 +179,59 @@ def __init__(self,
                                                partition_grad=(self.zero_stage == 2))

         self.max_norm = max_norm
+
+    def prepare_dataloader(self,
+                           dataset,
+                           batch_size,
+                           shuffle=False,
+                           seed=1024,
+                           drop_last=False,
+                           pin_memory=False,
+                           num_workers=0,
+                           **kwargs):
+        r"""
+        Prepare a dataloader for distributed training. The dataloader will be wrapped by
+        `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
+
+
+        Args:
+            dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
+            batch_size (int): The batch size of the dataloader.
+            shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
+            seed (int, optional): Random worker seed for sampling, defaults to 1024.
+            drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
+                is not divisible by the batch size. If False and the size of dataset is not divisible by
+                the batch size, then the last batch will be smaller, defaults to False.
+            pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
+            num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
+            kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
+                the PyTorch documentation of ``torch.utils.data.DataLoader``.
+
+        Returns:
+            :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
+        """
+        _kwargs = kwargs.copy()
+        sampler = DistributedSampler(dataset,
+                                     num_replicas=self.pg_mesh.size(DP_AXIS),
+                                     rank=self.pg_mesh.coordinate(DP_AXIS),
+                                     shuffle=shuffle)
+
+        # Deterministic dataloader
+        def seed_worker(worker_id):
+            worker_seed = seed
+            np.random.seed(worker_seed)
+            torch.manual_seed(worker_seed)
+            random.seed(worker_seed)
+
+        return DataLoader(dataset,
+                          batch_size=batch_size,
+                          sampler=sampler,
+                          worker_init_fn=seed_worker,
+                          drop_last=drop_last,
+                          pin_memory=pin_memory,
+                          num_workers=num_workers,
+                          **_kwargs)
+
+    def get_checkpoint_io(self) -> MoeCheckpintIO:
+        self.checkpoint_io = MoeCheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+        return self.checkpoint_io
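A note on the deterministic dataloader above: `seed_worker` applies the same `seed` to every worker and ignores `worker_id`, so all workers draw identical random streams; that is the plugin's reproducibility choice. A hypothetical per-worker variant (a sketch, not what this plugin does) would mix the id in:

    def seed_worker(worker_id):
        worker_seed = seed + worker_id  # decorrelates workers, still deterministic per run
        np.random.seed(worker_seed)
        torch.manual_seed(worker_seed)
        random.seed(worker_seed)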
diff --git a/colossalai/moe/checkpoint.py b/colossalai/moe/checkpoint.py
index 3cda5a7f044c..99e0ae811bbd 100644
--- a/colossalai/moe/checkpoint.py
+++ b/colossalai/moe/checkpoint.py
@@ -1,25 +1,53 @@
+import logging
+import os
 from copy import deepcopy
 from pathlib import Path
-from typing import Optional
+from typing import Iterator, Optional, OrderedDict, Tuple

 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.distributed import ProcessGroup
 from torch.optim import Optimizer

-from colossalai.checkpoint_io import GeneralCheckpointIO
+from colossalai.checkpoint_io import CheckpointIndexFile, HybridParallelCheckpointIO
+from colossalai.checkpoint_io.utils import (
+    StateDictSharder,
+    gather_distributed_param,
+    get_model_base_filenames,
+    is_safetensors_available,
+    load_shard_state_dict,
+    load_state_dict_into_model,
+    save_config_file,
+    save_state_dict_shards,
+)
+from colossalai.moe.manager import MOE_MANAGER
 from colossalai.tensor.moe_tensor.api import get_dp_rank, get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor


-class MoeCheckpintIO(GeneralCheckpointIO):
+class MoeCheckpintIO(HybridParallelCheckpointIO):

-    def __init__(self) -> None:
-        super().__init__()
+    def __init__(
+        self,
+        dp_group: ProcessGroup,
+        pp_group: ProcessGroup,
+        tp_group: ProcessGroup,
+        zero_stage: int,
+    ) -> None:
+        assert zero_stage in [
+            0,
+            1,
+            2,
+        ], f"zero_stage should be 0 or 1 or 2, got {zero_stage}"
+        super().__init__(dp_group, pp_group, tp_group, zero_stage)
+        self.parallel = MOE_MANAGER.parallel

-    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
-        state_dict = torch.load(checkpoint)
+    def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict:
+        """
+        Preprocess state_dict before loading and slice the state_dict of MOE tensors.
+        """
         for name, param in state_dict.items():
-            if '.experts.' in name:
+            if ".experts." in name:
                 model_param = dict(model.named_parameters())[name]
                 if is_moe_tensor(model_param):
                     ep_rank = get_ep_rank(model_param)
@@ -28,13 +56,99 @@ def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
                     assert param.shape[0] % ep_size == 0
                     param = param[ep_rank * expert_num:(ep_rank + 1) * expert_num]
                 state_dict[name] = param
+        dist.barrier()
+        return state_dict
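`pre_load_model` above slices each stacked expert tensor along dim 0 so that every expert-parallel rank loads only its own experts from a full checkpoint. A self-contained sketch of that slicing rule, with made-up sizes:

    import torch

    full = torch.randn(8, 16, 16)          # 8 experts stacked on dim 0, as stored in the checkpoint
    ep_size, ep_rank = 4, 1                # 4-way expert parallelism, viewed from rank 1
    assert full.shape[0] % ep_size == 0
    expert_num = full.shape[0] // ep_size  # 2 experts per rank
    local = full[ep_rank * expert_num:(ep_rank + 1) * expert_num]
    assert local.shape == (2, 16, 16)      # rank 1 keeps experts 2 and 3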
+    def _model_sharder(
+        self,
+        state_dict: nn.Module,
+        prefix: str = "",
+        keep_vars: bool = False,
+        size_per_shard: int = 1024,
+    ) -> Iterator[Tuple[OrderedDict, int]]:
+        # An internal method that breaks state_dict of model into shards within limited size.
+        state_dict_sharder = StateDictSharder(size_per_shard)
+
+        for name, param in state_dict.items():
+            if param is None:
+                continue
+            # Gather tensor pieces when using tensor parallel.
+            param_ = gather_distributed_param(param, keep_vars=False)
+            block, block_size = state_dict_sharder.append_param(prefix + name, param_)
+            if block is not None:
+                yield block, block_size
+
+        # Return the last block in sharder.
+        yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
+
+    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool) -> None:
+        state_dict = torch.load(checkpoint)
+        state_dict = self.pre_load_model(model, state_dict)
+        model.load_state_dict(state_dict, strict=strict if self.pp_size == 1 else False)
+
+    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
+        """
+        Load sharded model with the given path to index file of checkpoint folder.
+
+        Args:
+            model (nn.Module): The model to be loaded.
+            checkpoint_index_file (str): Path to the index file of checkpointing folder.
+            strict (bool, optional): For name matching during loading state_dict. Defaults to False.
+                This argument should be manually set to False since params on same device might be stored in different files.
+        """
+
+        # Check whether the checkpoint uses safetensors.
+        use_safetensors = False
+        if "safetensors" in checkpoint_index_file.name:
+            use_safetensors = True
+
+        if use_safetensors and not is_safetensors_available():
+            raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
+
+        # Read checkpoint index file.
+        ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
+        ckpt_root_path = ckpt_index_file.root_path
+        weight_map = ckpt_index_file.weight_map
+        strict = False
+
+        # Load params & buffers to model.
+        # Keep a record of loaded files so that file will not be repeatedly loaded.
+        loaded_file = set()
+
+        def _load(name: str):
+            if name not in weight_map:
+                raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
+            filename = weight_map[name]
+
+            # If this param/buffer has been loaded before, directly return.
+            if filename in loaded_file:
+                return
-            model.load_state_dict(state_dict, strict=strict)
+            file_path = os.path.join(ckpt_root_path, filename)
+            state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
+            state_dict = self.pre_load_model(model, state_dict)
+            missing_keys = []

-    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+            load_state_dict_into_model(
+                model,
+                state_dict,
+                missing_keys=missing_keys,
+                strict=strict,
+                load_sub_module=True,
+            )
+            loaded_file.add(filename)
+
+        # Load parameters.
+        for name, _ in model.named_parameters():
+            _load(name)
+
+        if self.verbose:
+            logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
+
+    def pre_save_model(self, model: nn.Module) -> dict:
         state_dict = model.state_dict()
         for name, param in model.named_parameters():
-            if '.experts.' in name and is_moe_tensor(param):
+            if ".experts." in name and is_moe_tensor(param):
                 ep_group = get_ep_group(param)
                 ep_rank = get_ep_rank(param)
                 ep_size = get_ep_size(param)
@@ -45,19 +159,95 @@ def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor
                 # gather param from every ep rank
                 dist.all_gather(all_param, param, group=ep_group)
                 if ep_rank == 0:
-                    assert dist.get_rank() == 0
                     all_param = torch.cat(all_param, dim=0)
                     state_dict[name] = all_param.cpu()
+        if self.pp_size > 1:
+            if self.dp_rank == 0:
+                out = [None for _ in range(self.pp_size)]
+                dist.all_gather_object(out, state_dict, group=self.pp_group)
+                if self.pp_rank == 0:
+                    new_state_dict = {}
+                    for o in out:
+                        new_state_dict.update(o)
+                    state_dict = new_state_dict
+        dist.barrier()
+        return state_dict
+
+    def save_unsharded_model(
+        self,
+        model: nn.Module,
+        checkpoint: str,
+        gather_dtensor: bool,
+        use_safetensors: bool,
+    ):
+        state_dict = self.pre_save_model(model)
         if dist.get_rank() == 0:
             torch.save(state_dict, checkpoint)
         dist.barrier()

-    def load_sharded_model(self, model: nn.Module, index_file_path: str, strict: bool):
-        raise NotImplementedError()
+    def save_sharded_model(
+        self,
+        model: nn.Module,
+        checkpoint: str,
+        gather_dtensor: bool = True,
+        prefix: Optional[str] = None,
+        size_per_shard: int = 1024,
+        use_safetensors: bool = False,
+    ) -> None:
+        """
+        Save sharded model checkpoint under the given checkpointing path.
+        The following files will be created under the path:
+        - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
+        - Multiple files that store state tensors of models.
+          The filenames are in the form of "pytorch_model.-000XX.bin"

-    def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
-                           size_per_shard: int, use_safetensors: bool):
-        raise NotImplementedError()
+        Args:
+            model (nn.Module): Model on local device to be saved.
+            checkpoint (str): Checkpointing path which should be a directory path.
+            gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
+            prefix (str, optional): Prefix of the file to save. Defaults to None.
+            size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
+            use_safetensors (bool, optional): Whether to use safe tensors.
Defaults to False. + """ + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + # Then collect the sharded parameters & buffers along tp_group. + # Only devices with tp_rank == 0 are responsible for model saving. + state_dict = self.pre_save_model(model) + + if dist.get_rank() == 0: + state_dict_shard = self._model_sharder(state_dict, size_per_shard=size_per_shard) + + # Devices along the same dp_group share the same copies of model. + # So only let the device with dp_rank == 0 save the model. + if self.dp_rank != 0: + return + + weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) + index_file = CheckpointIndexFile(checkpoint) + control_saving = self.tp_rank == 0 + + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors, + ) + if control_saving: + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + save_config_file(model, checkpoint) + if self.verbose: + logging.info(f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}.") + dist.barrier() # ======================================================== # Abstract methods for optimizer loading/saving implementation @@ -69,8 +259,14 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, pre def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): raise NotImplementedError() - def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, - size_per_shard: int): + def save_sharded_optimizer( + self, + optimizer: Optimizer, + checkpoint: Path, + gather_dtensor: bool, + prefix: str, + size_per_shard: int, + ): raise NotImplementedError() def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool): diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py index 7f36f8a88925..d7dbd58ed0ca 100644 --- a/examples/language/openmoe/benchmark/benchmark_cai.py +++ b/examples/language/openmoe/benchmark/benchmark_cai.py @@ -72,7 +72,7 @@ def parse_args(): type=str, default="hybrid", help="parallel plugin", - choices=["zero1", "zero2", "hybrid"], + choices=["zero2", "zero2_ep", "hybrid"], ) # hybrid plugin parser.add_argument("--pp_size", type=int, default=2, help="pp size") @@ -112,17 +112,24 @@ def main(): # Set plugin booster_kwargs = {} - if args.plugin == "zero1": + if args.plugin == "zero2": dp_size = dist.get_world_size() - plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=1) + plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2) MOE_MANAGER.setup( seed=42, - parallel="EP", + parallel=None, use_kernel_optim=args.use_kernel, ) - elif args.plugin == "zero2": + elif args.plugin == "zero2_ep": dp_size = dist.get_world_size() - plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2) + plugin = MoeHybridParallelPlugin( + tp_size=1, + pp_size=1, + zero_stage=2, + custom_policy=OpenMoeForCausalLMPolicy(), + enable_fused_normalization=args.use_kernel, + enable_jit_fused=args.use_kernel, + ) MOE_MANAGER.setup( seed=42, parallel="EP", @@ -215,6 +222,7 @@ def main(): pbar.set_postfix({"loss": loss.item()}) 
else: # Forward pass + data = next(train_dataloader_iter) data = move_to_cuda(data, torch.cuda.current_device()) outputs = model(**data) loss = outputs["loss"] diff --git a/examples/language/openmoe/benchmark/benchmark_cai.sh b/examples/language/openmoe/benchmark/benchmark_cai.sh index 24d0c1b23ab2..620bd4901ccd 100755 --- a/examples/language/openmoe/benchmark/benchmark_cai.sh +++ b/examples/language/openmoe/benchmark/benchmark_cai.sh @@ -2,12 +2,11 @@ set -xue -NUM_GPU=4 -MODEL="base" -BATCH_SIZE=1 +NUM_GPU=8 +MODEL="8b" SEQ_LENGTH=2048 -WARMUP=10 -ACTIVE=10 +WARMUP=5 +ACTIVE=5 # HACK: make model importable example_dir=$(dirname $(realpath $(dirname $0))) @@ -21,7 +20,7 @@ fi torchrun --standalone --nproc_per_node $NUM_GPU \ $example_dir/benchmark/benchmark_cai.py \ --model_name $MODEL \ - --batch_size $BATCH_SIZE \ + --batch_size 512 \ --seq_length $SEQ_LENGTH \ --warmup $WARMUP \ --active $ACTIVE \ @@ -29,28 +28,28 @@ torchrun --standalone --nproc_per_node $NUM_GPU \ --plugin hybrid \ --pp_size 2 \ --dp_size 1 \ - --ep_size 2 \ + --ep_size 4 \ --zero_stage 1 \ - --microbatch_size 1 + --microbatch_size 32 -# zero1 +# zero2 torchrun --standalone --nproc_per_node $NUM_GPU \ $example_dir/benchmark/benchmark_cai.py \ --model_name $MODEL \ - --batch_size $BATCH_SIZE \ + --batch_size 8 \ --seq_length $SEQ_LENGTH \ --warmup $WARMUP \ --active $ACTIVE \ - --plugin zero1 \ + --plugin zero2 \ --use_kernel -# zero2 +# zero2_ep torchrun --standalone --nproc_per_node $NUM_GPU \ $example_dir/benchmark/benchmark_cai.py \ --model_name $MODEL \ - --batch_size $BATCH_SIZE \ + --batch_size 16 \ --seq_length $SEQ_LENGTH \ --warmup $WARMUP \ --active $ACTIVE \ - --plugin zero2 \ + --plugin zero2_ep \ --use_kernel diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.py b/examples/language/openmoe/benchmark/benchmark_fsdp.py index c7357c06e5c7..1b69c8d4abeb 100644 --- a/examples/language/openmoe/benchmark/benchmark_fsdp.py +++ b/examples/language/openmoe/benchmark/benchmark_fsdp.py @@ -1,13 +1,15 @@ import argparse +import functools import os import torch import torch.distributed as dist import torch.multiprocessing as mp import tqdm -from model.modeling_openmoe import LlamaConfig, OpenMoeForCausalLM +from model.modeling_openmoe import LlamaConfig, OpenMoeDecoderLayer, OpenMoeForCausalLM from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from torch.utils.data import Dataset from torch.utils.data.distributed import DistributedSampler from transformers import Adafactor @@ -18,8 +20,9 @@ class RandomDataset(Dataset): - - def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000): + def __init__( + self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000 + ): self.num_samples = num_samples self.max_length = max_length self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length)) @@ -45,9 +48,13 @@ def fsdp_main(rank, world_size, args): MOE_MANAGER.setup(seed=42, parallel=None, use_kernel_optim=False) dp_size = dist.get_world_size() - dataset = RandomDataset(max_length=args.seq_length, - num_samples=args.batch_size * (args.warmup + args.active) * dp_size) - sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=False) + dataset = RandomDataset( + max_length=args.seq_length, + num_samples=args.batch_size * (args.warmup + args.active) * dp_size, + ) + 
sampler = DistributedSampler( + dataset, rank=rank, num_replicas=world_size, shuffle=False + ) train_kwargs = {"batch_size": args.batch_size, "sampler": sampler} train_loader = torch.utils.data.DataLoader(dataset, **train_kwargs) torch.cuda.set_device(rank) @@ -57,7 +64,13 @@ def fsdp_main(rank, world_size, args): setattr(config, "router_z_loss_factor", 0.1) setattr(config, "label_smoothing", 0.1) setattr(config, "z_loss_factor", 0.1) - model = OpenMoeForCausalLM(config).to(rank) + model = OpenMoeForCausalLM(config) + auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls={ + OpenMoeDecoderLayer, + }, + ) model = FSDP( model, mixed_precision=MixedPrecision( @@ -65,6 +78,8 @@ def fsdp_main(rank, world_size, args): reduce_dtype=torch.float16, buffer_dtype=torch.float16, ), + auto_wrap_policy=auto_wrap_policy, + device_id=torch.cuda.current_device(), ) optimizer = Adafactor(model.parameters()) model.train() @@ -99,7 +114,9 @@ def fsdp_main(rank, world_size, args): performance_evaluator.on_fit_end() if dist.get_rank() == 0: - print(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + print( + f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB" + ) if __name__ == "__main__": diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.sh b/examples/language/openmoe/benchmark/benchmark_fsdp.sh index a4cb32019431..41ffcd882a3b 100755 --- a/examples/language/openmoe/benchmark/benchmark_fsdp.sh +++ b/examples/language/openmoe/benchmark/benchmark_fsdp.sh @@ -2,12 +2,12 @@ set -xue -NUM_GPU=4 -MODEL="base" +NUM_GPU=8 +MODEL="8b" BATCH_SIZE=1 SEQ_LENGTH=2048 -WARMUP=10 -ACTIVE=10 +WARMUP=5 +ACTIVE=5 # HACK: make model importable example_dir=$(dirname $(realpath $(dirname $0))) diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index 6933f108a09e..f8c79320fa57 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -997,10 +997,11 @@ def forward(ctx, logits, targets, z_loss): shifted = logits - max_logit exp_shifted = torch.exp(shifted) sum_exp = torch.sum(exp_shifted, axis=-1, keepdims=True) - log_softmax = shifted - torch.log(sum_exp) + sum_exp_log = torch.log(sum_exp) + log_softmax = shifted - sum_exp_log loss = -torch.sum(targets * log_softmax, axis=-1) # Add auxilliary z-loss term. 
- log_z = torch.squeeze(torch.log(sum_exp) + max_logit, axis=-1) + log_z = torch.squeeze(sum_exp_log + max_logit, axis=-1) total_z_loss = z_loss * torch.square(log_z) loss += total_z_loss ctx.z_loss = z_loss diff --git a/examples/language/openmoe/model/openmoe_policy.py b/examples/language/openmoe/model/openmoe_policy.py index cc82683cd319..f354bbea990e 100644 --- a/examples/language/openmoe/model/openmoe_policy.py +++ b/examples/language/openmoe/model/openmoe_policy.py @@ -97,7 +97,7 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli else: module = self.model.model - layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages) + layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} self.append_or_create_method_replacement(description=method_replacement, @@ -110,7 +110,7 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" assert self.pipeline_stage_manager is not None - if self.model.__class__.__name__ == "LlamaModel": + if self.model.__class__.__name__ == "OpenMoeModel": module = self.model else: module = self.model.model @@ -126,6 +126,23 @@ def get_held_layers(self) -> List[Module]: held_layers.append(module.norm) return held_layers + + @staticmethod + def distribute_layers(num_layers: int, num_stages: int) -> List[int]: + """Divide layers into stages + + """ + if num_layers == 24 and num_stages == 4: + return [7, 7, 7, 3] + elif num_layers == 24 and num_stages == 2: + return [15, 9] + elif num_layers == 12 and num_stages == 4: + return [5, 5, 5, 1] + elif num_layers == 12 and num_stages == 2: + return [8, 4] + else: + print(f"num_layers: {num_layers}, num_stages: {num_stages} not optimized, use origin pp policy") + return Policy.distribute_layers(num_layers, num_stages) class OpenMoeModelPolicy(OpenMoePolicy): @@ -401,7 +418,7 @@ def llama_for_causal_lm_forward( stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, - chunk_head: Optional[bool] = None, + chunk_head: Optional[bool] = True, past_router_aux_loss: Optional[torch.FloatTensor] = None, past_router_z_loss: Optional[torch.FloatTensor] = None, ): diff --git a/examples/language/openmoe/test_ci.sh b/examples/language/openmoe/test_ci.sh index 86742e088f71..0f68db4275f7 100644 --- a/examples/language/openmoe/test_ci.sh +++ b/examples/language/openmoe/test_ci.sh @@ -7,5 +7,15 @@ python infer.py --model "test" torchrun --standalone --nproc_per_node 4 train.py \ --num_epoch 1 \ --model_name "test" \ - --plugin zero2 \ + --plugin zero2_ep \ + --batch_size 1 + +torchrun --standalone --nproc_per_node 4 train.py \ + --model_name "test" \ + --plugin "hybrid" \ + --num_epoch 1 \ + --pp_size 2 \ + --dp_size 1 \ + --ep_size 2 \ + --zero_stage 1 \ --batch_size 1 diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index a7f46f2f693b..6f239104328c 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -14,7 +14,6 @@ import colossalai from colossalai import get_default_parser from colossalai.booster import Booster -from colossalai.booster.plugin import LowLevelZeroPlugin from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from 
colossalai.cluster import DistCoordinator from colossalai.logging import disable_existing_loggers, get_dist_logger @@ -28,7 +27,7 @@ def move_to_cuda(batch, device): return {k: v.to(device) for k, v in batch.items()} -def load_ckpt(repo_name: str, model: OpenMoeForCausalLM): +def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster): ckpt_path = snapshot_download(repo_name) # single ckpt if os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin")): @@ -38,7 +37,7 @@ def load_ckpt(repo_name: str, model: OpenMoeForCausalLM): ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json") else: raise ValueError(f"Invalid checkpoint path: {ckpt_path}") - MoeCheckpintIO().load_model(model, ckpt_path) + booster.load_model(model, ckpt_path) class RandomDataset(Dataset): @@ -89,7 +88,7 @@ def parse_args(): type=str, default="hybrid", help="parallel plugin", - choices=["zero1", "zero2", "hybrid"], + choices=["zero1_ep", "zero2_ep", "hybrid"], ) # hybrid plugin parser.add_argument("--pp_size", type=int, default=2, help="pp size") @@ -146,15 +145,29 @@ def main(): # Set plugin booster_kwargs = {} - if args.plugin == "zero1": - plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=1) + if args.plugin == "zero1_ep": + plugin = MoeHybridParallelPlugin( + tp_size=1, + pp_size=1, + zero_stage=1, + custom_policy=OpenMoeForCausalLMPolicy(), + enable_fused_normalization=args.use_kernel, + enable_jit_fused=args.use_kernel, + ) MOE_MANAGER.setup( seed=42, parallel="EP", use_kernel_optim=args.use_kernel if not test_mode else False, ) - elif args.plugin == "zero2": - plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2) + elif args.plugin == "zero2_ep": + plugin = MoeHybridParallelPlugin( + tp_size=1, + pp_size=1, + zero_stage=2, + custom_policy=OpenMoeForCausalLMPolicy(), + enable_fused_normalization=args.use_kernel, + enable_jit_fused=args.use_kernel, + ) MOE_MANAGER.setup( seed=42, parallel="EP", @@ -198,8 +211,6 @@ def main(): setattr(config, "z_loss_factor", args.z_loss_factor) with skip_init(): model = OpenMoeForCausalLM(config) - if not test_mode: - load_ckpt(repo_name, model) logger.info(f"Finish init model with config:\n{config}", ranks=[0]) # Enable gradient checkpointing @@ -216,6 +227,8 @@ def main(): # Set booster booster = Booster(plugin=plugin, **booster_kwargs) model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader) + if not test_mode: + load_ckpt(repo_name, model, booster) use_pipeline = (isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1) is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() logger.info(f"Finish init booster", ranks=[0]) diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index 1c70c5d43dbd..489f5ebdacfc 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -1,47 +1,152 @@ import os +import shutil import pytest import torch import torch.distributed as dist +from transformers.models.llama import LlamaConfig import colossalai -from colossalai.moe import MoeCheckpintIO +from colossalai.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.moe.manager import MOE_MANAGER from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device -from tests.test_moe.moe_utils import MoeModel +from examples.language.openmoe.model.modeling_openmoe import 
OpenMoeForCausalLM +from examples.language.openmoe.model.openmoe_policy import OpenMoeForCausalLMPolicy -def exam_moe_checkpoint(): - ckpt = MoeCheckpintIO() - model = MoeModel(checkpoint=True).to(get_current_device()) - ckpt.save_model(model, 'temp_path.pth') +def get_config(): + config = LlamaConfig( + vocab_size=300, + hidden_size=32, + intermediate_size=64, + num_hidden_layers=2, + num_attention_heads=2, + ) + settings = { + "vocab_size": 300, + "intermediate_size": 32, + "hidden_size": 16, + "num_hidden_layers": 2, + "head_dim": 4, + "num_attention_heads": 4, + "dropout_rate": 0.0, + "layer_norm_epsilon": 1e-06, + "hidden_act": "swiglu", + "num_experts": 16, + "topk": 2, + "capacity_factor_train": 1.25, + "capacity_factor_eval": 2.0, + "min_capacity": 4, + "noisy_policy": None, + "drop_tks": True, + "expert_parallel": None, + "gated": True, + "moe_layer_interval": 4, + "router_aux_loss_factor": 0.1, + "router_z_loss_factor": 0.1, + "label_smoothing": 0.1, + "z_loss_factor": 0.1, + } + for key, value in settings.items(): + setattr(config, key, value) + return config - other_model = MoeModel(checkpoint=True).to(get_current_device()) - ckpt.load_model(other_model, 'temp_path.pth') - state_0 = model.state_dict() - state_1 = other_model.state_dict() - for k, v in state_0.items(): - u = state_1.get(k) +def get_model(parallel): + config = get_config() + model = OpenMoeForCausalLM(config) + + if parallel == None: + plugin = MoeHybridParallelPlugin( + tp_size=1, + pp_size=1, + zero_stage=0, + custom_policy=OpenMoeForCausalLMPolicy(), + ) + elif parallel == "zero_ep": + plugin = MoeHybridParallelPlugin( + tp_size=1, + pp_size=1, + zero_stage=2, + custom_policy=OpenMoeForCausalLMPolicy(), + ) + elif parallel == "hybrid": + plugin = MoeHybridParallelPlugin( + tp_size=1, + pp_size=2, + zero_stage=1, + microbatch_size=1, + custom_policy=OpenMoeForCausalLMPolicy(), + ) + booster = Booster(plugin=plugin) + model, _, _, _, _ = booster.boost(model=model) + return model, booster + + +def _test_moe_checkpoint(parallel, shard): + if parallel == None: + MOE_MANAGER.setup( + seed=42, + parallel=None, + ) + elif parallel == "zero2_ep": + MOE_MANAGER.setup( + seed=42, + parallel="EP", + ) + elif parallel == "hybrid": + MOE_MANAGER.setup( + seed=42, + parallel="EP", + mode="fixed", + fixed_dp_size=1, + fixed_ep_size=2, + fixed_pp_size=2, + ) + model1, booster1 = get_model(parallel) + model2, booster2 = get_model(parallel) + + if shard: + booster1.save_model(model1, "./tmp_ckpt", shard=True, size_per_shard=1) + booster2.load_model(model2, "./tmp_ckpt") + else: + booster1.save_model(model1, "tmp_ckpt.pth") + booster2.load_model(model2, "tmp_ckpt.pth") + + state1 = model1.state_dict() + state2 = model2.state_dict() + for k, v in state1.items(): + u = state2.get(k) assert torch.equal(u.data, v.data) if dist.get_rank() == 0: - os.remove("temp_path.pth") + if shard: + shutil.rmtree("./tmp_ckpt") + else: + os.remove("tmp_ckpt.pth") -def _run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_MANAGER.setup(seed=42) - exam_moe_checkpoint() +def _run_dist(rank, world_size, port, parallel, shard): + colossalai.launch( + config=dict(), + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) + _test_moe_checkpoint(parallel, shard) @pytest.mark.dist -@pytest.mark.parametrize("world_size", [2, 4]) +@pytest.mark.parametrize("world_size", [4]) +@pytest.mark.parametrize("parallel", [None, 
"zero_ep", "hybrid"]) +@pytest.mark.parametrize("shard", [True, False]) @rerun_if_address_is_in_use() -def test_moe_checkpoint(world_size): - spawn(_run_dist, world_size) +def test_moe_checkpoint(world_size, parallel, shard): + spawn(_run_dist, world_size, parallel=parallel, shard=shard) if __name__ == "__main__": - test_moe_checkpoint(world_size=4) + test_moe_checkpoint(world_size=4, parallel="hybrid", shard=True) From 4d74f83a913127170e6b3f4edecee22215d0cb12 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com> Date: Wed, 4 Oct 2023 14:22:48 +0800 Subject: [PATCH 18/46] [moe] support overlap for expert tp (#4851) * overlap comm * fix typo * update bench script * add option * update script * update bench --- colossalai/moe/_operation.py | 64 ++++++-- colossalai/moe/experts.py | 80 ++++++---- colossalai/moe/layers.py | 146 ++++++++++++++---- colossalai/moe/utils.py | 56 ++++--- .../openmoe/benchmark/benchmark_cai.py | 59 ++++--- .../openmoe/benchmark/benchmark_cai.sh | 41 +++-- 6 files changed, 318 insertions(+), 128 deletions(-) diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index c47d8df296b4..740d17b5698f 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -8,6 +8,8 @@ from colossalai.moe.manager import MOE_MANAGER MOE_KERNEL = None +WORLD_HANDLE_ALLGATHER = None +WORLD_HANDLE_REDUCESCATTER = None def load_moe(): @@ -19,9 +21,15 @@ def load_moe(): class AllGather(torch.autograd.Function): @staticmethod - def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: + def forward( + ctx: Any, + inputs: Tensor, + group: Optional[ProcessGroup] = None, + overlap: bool = False, + ) -> Tensor: if ctx is not None: ctx.comm_grp = group + ctx.overlap = overlap comm_size = dist.get_world_size(group) if comm_size == 1: @@ -30,19 +38,40 @@ def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> T buffer_shape = (comm_size,) + inputs.shape outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device) buffer_list = list(torch.chunk(outputs, comm_size, dim=0)) - dist.all_gather(buffer_list, inputs, group=group) - return outputs + if not overlap: + dist.all_gather(buffer_list, inputs, group=group) + return outputs, None + else: + handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True) + if ctx is None and overlap: + global WORLD_HANDLE_ALLGATHER + WORLD_HANDLE_ALLGATHER = handle + return outputs, handle @staticmethod - def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]: - return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None + def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None]: + global WORLD_HANDLE_REDUCESCATTER + if WORLD_HANDLE_REDUCESCATTER is not None: + WORLD_HANDLE_REDUCESCATTER.wait() + WORLD_HANDLE_REDUCESCATTER = None + return ( + ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, ctx.overlap)[0], + None, + None, + ) class ReduceScatter(torch.autograd.Function): @staticmethod - def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: + def forward( + ctx: Any, + inputs: Tensor, + group: Optional[ProcessGroup] = None, + overlap: bool = False, + ) -> Tensor: if ctx is not None: ctx.comm_grp = group + ctx.overlap = overlap comm_size = dist.get_world_size(group) if comm_size == 1: @@ -54,12 +83,27 @@ def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> T output_shape = inputs.shape[1:] outputs = torch.empty(output_shape, 
dtype=inputs.dtype, device=inputs.device) buffer_list = list(torch.chunk(inputs, comm_size, dim=0)) - dist.reduce_scatter(outputs, buffer_list, group=group) - return outputs + if not overlap: + dist.reduce_scatter(outputs, buffer_list, group=group) + return outputs, None + else: + handle = dist.reduce_scatter(outputs, buffer_list, group=group, async_op=True) + if ctx is None and overlap: + global WORLD_HANDLE_REDUCESCATTER + WORLD_HANDLE_REDUCESCATTER = handle + return outputs, handle @staticmethod - def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]: - return AllGather.forward(None, grad_outputs, ctx.comm_grp), None + def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None]: + global WORLD_HANDLE_ALLGATHER + if WORLD_HANDLE_ALLGATHER is not None: + WORLD_HANDLE_ALLGATHER.wait() + WORLD_HANDLE_ALLGATHER = None + return ( + AllGather.forward(None, grad_outputs[0], ctx.comm_grp, ctx.overlap)[0], + None, + None, + ) class AllToAll(torch.autograd.Function): diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py index 4535d8ab9a85..e05ea59b3d28 100644 --- a/colossalai/moe/experts.py +++ b/colossalai/moe/experts.py @@ -1,6 +1,6 @@ import math from contextlib import nullcontext -from typing import Callable, Optional +from typing import Callable, Optional, Tuple import torch import torch.nn as nn @@ -52,8 +52,9 @@ def __init__( num_experts, use_tp=True if expert_parallel == "TP" else False) # get settings for different parallel if expert_parallel == "TP": - assert intermediate_size % MOE_MANAGER.max_ep_size == 0, \ - "intermediate_size should be divide by maximum expert parallel size" + assert ( + intermediate_size % + MOE_MANAGER.max_ep_size == 0), "intermediate_size should be divide by maximum expert parallel size" intermediate_size = intermediate_size // MOE_MANAGER.max_ep_size num_experts = self.num_total_experts else: @@ -77,11 +78,11 @@ def __init__( seed_ctx = nullcontext() with seed_ctx: if gated: - nn.init.trunc_normal_(self.wi_gate, std=math.sqrt(0.1 / hidden_size)) - nn.init.trunc_normal_(self.wi_up, std=math.sqrt(0.1 / hidden_size)) + torch.nn.init.trunc_normal_(self.wi_gate, std=math.sqrt(0.1 / hidden_size)) + torch.nn.init.trunc_normal_(self.wi_up, std=math.sqrt(0.1 / hidden_size)) else: - nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size)) - nn.init.trunc_normal_(self.wo, std=math.sqrt(0.1 / intermediate_size)) + torch.nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size)) + torch.nn.init.trunc_normal_(self.wo, std=math.sqrt(0.1 / intermediate_size)) self.act_name = activation self.act = get_activation(activation) @@ -91,7 +92,7 @@ def __init__( for param in self.parameters(): set_moe_tensor_info(param, self.moe_info) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, param_slice: Tuple[slice] = (slice(None),)) -> torch.Tensor: """ Args: x (torch.Tensor): The input tensor of shape (num_groups, num_experts, capacity, hidden_size) @@ -110,14 +111,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.gated: if HAS_TRITON and self.act_name == "swiglu": - x = LlamaActCombine.apply(torch.bmm(x, self.wi_gate), torch.bmm(x, self.wi_up)) + x = LlamaActCombine.apply( + torch.bmm(x, self.wi_gate[param_slice]), + torch.bmm(x, self.wi_up[param_slice]), + ) else: - x = self.act(torch.bmm(x, self.wi_gate)) * torch.bmm(x, self.wi_up) + x = self.act(torch.bmm(x, self.wi_gate[param_slice])) * torch.bmm(x, self.wi_up[param_slice]) else: - x = torch.bmm(x, self.wi) + x = torch.bmm(x, 
self.wi[param_slice]) x = self.act(x) x = self.drop(x) - x = torch.bmm(x, self.wo) + x = torch.bmm(x, self.wo[param_slice]) x = x.reshape(inshape) x = x.transpose(0, 1).contiguous() @@ -130,14 +134,24 @@ class EPMLPExperts(BaseMLPExperts): Use expert parallelism to split each expert evenly, which can deploy experts in """ - def __init__(self, - num_experts: int, - hidden_size: int, - intermediate_size: int, - activation=None, - drop_rate: float = 0, - gated: bool = False): - super().__init__(num_experts, hidden_size, intermediate_size, "EP", activation, drop_rate, gated) + def __init__( + self, + num_experts: int, + hidden_size: int, + intermediate_size: int, + activation=None, + drop_rate: float = 0, + gated: bool = False, + ): + super().__init__( + num_experts, + hidden_size, + intermediate_size, + "EP", + activation, + drop_rate, + gated, + ) class TPMLPExperts(BaseMLPExperts): @@ -146,14 +160,24 @@ class TPMLPExperts(BaseMLPExperts): maximum expert parallel size can't be divide by the number of experts. """ - def __init__(self, - num_experts: int, - hidden_size: int, - intermediate_size: int, - activation: str = None, - drop_rate: float = 0, - gated: bool = False): - super().__init__(num_experts, hidden_size, intermediate_size, "TP", activation, drop_rate, gated) + def __init__( + self, + num_experts: int, + hidden_size: int, + intermediate_size: int, + activation: str = None, + drop_rate: float = 0, + gated: bool = False, + ): + super().__init__( + num_experts, + hidden_size, + intermediate_size, + "TP", + activation, + drop_rate, + gated, + ) def get_expert_class(name: str) -> BaseMLPExperts: diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index a78bfe0a3d74..c2cf627aceae 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -42,42 +42,52 @@ class SparseMLP(nn.Module): https://arxiv.org/abs/2201.05596 """ - def __init__(self, - num_experts: int, - top_k: int = 1, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_policy: Optional[str] = None, - drop_tks: bool = True, - expert_parallel: str = "EP", - hidden_size: int = 2048, - intermediate_size: int = 2048, - activation: str = None, - gated: bool = False): + def __init__( + self, + num_experts: int, + top_k: int = 1, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_policy: Optional[str] = None, + drop_tks: bool = True, + expert_parallel: str = "EP", + hidden_size: int = 2048, + intermediate_size: int = 2048, + activation: str = None, + gated: bool = False, + ): super().__init__() self.hidden_size = hidden_size self.num_experts = num_experts self.use_kernel = MOE_MANAGER.use_kernel_optim self.expert_parallel = expert_parallel - assert expert_parallel in ["EP", "TP", None], f"Unsupported expert parallel type {expert_parallel}" + assert expert_parallel in [ + "EP", + "TP", + None, + ], f"Unsupported expert parallel type {expert_parallel}" # moe router noisy_func = get_noise_generator(noisy_policy, num_experts) router_cls = get_router_cls(top_k) - self.router: MoeRouter = router_cls(capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks) + self.router: MoeRouter = router_cls( + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks, + ) # moe experts expert_cls = 
get_expert_class(expert_parallel) - self.experts: BaseMLPExperts = expert_cls(num_experts=num_experts, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - activation=activation, - gated=gated) + self.experts: BaseMLPExperts = expert_cls( + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + activation=activation, + gated=gated, + ) if expert_parallel is not None: self.ep_group = get_ep_group(self.experts) self.ep_size = get_ep_size(self.experts) @@ -88,9 +98,7 @@ def __init__(self, self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size)) nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / self.hidden_size)) - def forward(self, - inputs: torch.Tensor) \ - -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: inputs (torch.Tensor): The input tensor of shape (batch_size, seq_len, hidden_size) @@ -146,6 +154,15 @@ def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor: return expert_out def _ep_process(self, dispatch_data: torch.Tensor) -> torch.Tensor: + """ + Expert Parallel + + Args: + dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size) + + Returns: + torch.Tensor: (num_experts, capacity, hidden_size) + """ expert_input = AllToAll.apply(dispatch_data, self.ep_group) input_shape = expert_input.shape expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size) @@ -155,7 +172,74 @@ def _ep_process(self, dispatch_data: torch.Tensor) -> torch.Tensor: return expert_output def _tp_process(self, dispatch_data: torch.Tensor) -> torch.Tensor: - expert_in = AllGather.apply(dispatch_data, self.ep_group) - expert_out = self.experts(expert_in) - expert_out = ReduceScatter.apply(expert_out, self.ep_group) - return expert_out + """ + TP with overlap. + + origin: + | C | + | A | | R | + + overlap: + | C1 || C2 || C3 || C4 | + | A1 || A2 | | R1 | A3 || R2 | A4 || R3 | | R4 | + + C is computation, A is all gather, R is reduce scatter. 
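A distilled sketch of this schedule (hypothetical helper names, not the implementation that follows): while chunk i is being computed, the all-gather of chunk i+1 and the reduce-scatter of output i-1 are in flight:

    cur = all_gather(chunks[0])                                # A1 (blocking)
    for i in range(chunk_num):
        if i > 0:
            rs_out, rs_handle = reduce_scatter_async(prev_out) # R(i-1)
        if i < chunk_num - 1:
            nxt, ag_handle = all_gather_async(chunks[i + 1])   # A(i+1)
        prev_out = compute(cur)                                # C(i)
        if i > 0:
            rs_handle.wait(); out[i - 1] = rs_out
        if i < chunk_num - 1:
            ag_handle.wait(); cur = nxt
    out[chunk_num - 1] = reduce_scatter(prev_out)              # R(last), blocking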
+ + Args: + dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size) + + Returns: + torch.Tensor: (num_experts, capacity, hidden_size) + """ + chunk_num = 4 + chunk_size = dispatch_data.shape[0] // chunk_num + out = torch.empty_like(dispatch_data) + in_data = None + in_handle = None + out_data = None + out_handle = None + + # backward compatibility for async op + torch.cuda.synchronize() + + def get_chunk_slice(idx: int, gap: int) -> Tuple[slice]: + return (slice(idx * gap, (idx + 1) * gap),) + + for i in range(chunk_num): + cur_chunk_slice = get_chunk_slice(i, chunk_size) + + # if first, all gather + if i == 0: + d = dispatch_data[cur_chunk_slice].contiguous() + expert_in, _ = AllGather.apply(d, self.ep_group) + else: + expert_in = in_data + + # async communication while compute + if i != 0: + # reduce scatter last out + out_data, out_handle = ReduceScatter.apply(out_data, self.ep_group, True) + if i != chunk_num - 1: + # all gather next in + next_d = dispatch_data[get_chunk_slice(i + 1, chunk_size)].contiguous() + in_data, in_handle = AllGather.apply(next_d, self.ep_group, True) + + # compute + expert_out = self.experts(expert_in, cur_chunk_slice) + + # sync handle + if i != 0: + out_handle.wait() + out[get_chunk_slice(i - 1, chunk_size)] = out_data + if i != chunk_num - 1: + in_handle.wait() + out_data = expert_out + + # store out for last loop + if i == chunk_num - 1: + out_data, _ = ReduceScatter.apply(out_data, self.ep_group) + out[cur_chunk_slice] = out_data + + # sync for async op + torch.cuda.synchronize() + return out diff --git a/colossalai/moe/utils.py b/colossalai/moe/utils.py index 58c1665a4d63..e3bc6d3cac9a 100644 --- a/colossalai/moe/utils.py +++ b/colossalai/moe/utils.py @@ -28,9 +28,10 @@ class NormalNoiseGenerator: """ def __init__(self, num_experts: int): - self.normal = torch.distributions.normal.Normal(loc=torch.tensor(0.0, device=get_current_device()), - scale=torch.tensor(1.0 / num_experts**2, - device=get_current_device())).rsample + self.normal = torch.distributions.normal.Normal( + loc=torch.tensor(0.0, device=get_current_device()), + scale=torch.tensor(1.0 / num_experts**2, device=get_current_device()), + ).rsample def __call__(self, inputs: torch.Tensor): noisy = self.normal(inputs.shape) @@ -49,9 +50,10 @@ class UniformNoiseGenerator: """ def __init__(self, eps: float = 1e-2): - self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - eps, device=get_current_device()), - high=torch.tensor(1.0 + eps, - device=get_current_device())).rsample + self.uniform = torch.distributions.uniform.Uniform( + low=torch.tensor(1.0 - eps, device=get_current_device()), + high=torch.tensor(1.0 + eps, device=get_current_device()), + ).rsample def __call__(self, inputs: torch.Tensor): noisy = self.uniform(inputs.shape) @@ -65,9 +67,9 @@ def autocast_softmax(logit: torch.Tensor, dim: int): def get_noise_generator(noise_type: str, num_experts: int) -> Callable: if noise_type is None: return None - elif noise_type == 'Jitter': + elif noise_type == "Jitter": noisy_func = UniformNoiseGenerator() - elif noise_type == 'Gaussian': + elif noise_type == "Gaussian": noisy_func = NormalNoiseGenerator(num_experts) else: raise NotImplementedError("Unsupported input noisy policy") @@ -75,11 +77,11 @@ def get_noise_generator(noise_type: str, num_experts: int) -> Callable: def get_activation(act: str) -> Callable: - if act is None or act == 'relu': + if act is None or act == "relu": return torch.nn.ReLU() - elif act == 'gelu': + elif act == "gelu": return torch.nn.GELU() 
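The noise generators above perturb router inputs with factors drawn near 1.0. A standalone sketch of the "Jitter" policy, assuming the standard multiplicative formulation and the default eps of 1e-2:

import torch

def jitter_noise(inputs: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
    # Scale every element by a draw from Uniform(1 - eps, 1 + eps),
    # mirroring UniformNoiseGenerator; rsample matches the generator above.
    uniform = torch.distributions.uniform.Uniform(
        low=torch.tensor(1.0 - eps, device=inputs.device),
        high=torch.tensor(1.0 + eps, device=inputs.device),
    )
    return inputs * uniform.rsample(inputs.shape)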
- elif act == 'swiglu': + elif act == "swiglu": return SwiGLU else: raise NotImplementedError("Unsupported activation function") @@ -103,24 +105,28 @@ def skip_init(): skip param random init """ - def _skip_init(x, *args, **kwargs): - return x + def _skip_init(*args, **kwargs): + pass - # __enter__ - fn_saved = [] - init_fn_list = [ - torch.nn.init.constant_, torch.nn.init.uniform_, torch.nn.init.normal_, torch.nn.init.xavier_uniform_, - torch.nn.init.xavier_normal_, torch.nn.init.kaiming_uniform_, torch.nn.init.kaiming_normal_ - ] - for fn in init_fn_list: - fn_saved.append(fn) - fn = _skip_init + init_func = { + "constant_": torch.nn.init.constant_, + "uniform_": torch.nn.init.uniform_, + "normal_": torch.nn.init.normal_, + "kaiming_uniform_": torch.nn.init.kaiming_uniform_, + "kaiming_normal_": torch.nn.init.kaiming_normal_, + "xavier_normal_": torch.nn.init.xavier_normal_, + "xavier_uniform_": torch.nn.init.xavier_uniform_, + "trunc_normal_": torch.nn.init.trunc_normal_, + } + + for method_name, original_init in init_func.items(): + setattr(torch.nn.init, method_name, _skip_init) yield - # __exit__ - for fn, fn_saved in zip(init_fn_list, fn_saved): - fn = fn_saved + for method_name, original_init in init_func.items(): + setattr(torch.nn.init, method_name, original_init) + return diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py index d7dbd58ed0ca..5ff0843caaea 100644 --- a/examples/language/openmoe/benchmark/benchmark_cai.py +++ b/examples/language/openmoe/benchmark/benchmark_cai.py @@ -1,7 +1,8 @@ -import datasets +import os + import torch import torch.distributed as dist -import transformers +from huggingface_hub import snapshot_download from model.modeling_openmoe import OpenMoeForCausalLM from model.openmoe_policy import OpenMoeForCausalLMPolicy from torch.utils.data import Dataset @@ -16,8 +17,8 @@ from colossalai.booster.plugin import LowLevelZeroPlugin from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.cluster import DistCoordinator -from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.moe.manager import MOE_MANAGER +from colossalai.moe.utils import skip_init from colossalai.utils import get_current_device @@ -25,6 +26,19 @@ def move_to_cuda(batch, device): return {k: v.to(device) for k, v in batch.items()} +def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster): + ckpt_path = snapshot_download(repo_name) + # single ckpt + if os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin")): + ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin") + # shard ckpt + elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")): + ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json") + else: + raise ValueError(f"Invalid checkpoint path: {ckpt_path}") + booster.load_model(model, ckpt_path) + + class RandomDataset(Dataset): def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000): @@ -72,7 +86,7 @@ def parse_args(): type=str, default="hybrid", help="parallel plugin", - choices=["zero2", "zero2_ep", "hybrid"], + choices=["zero2", "zero2_ep", "hybrid", "zero2_tp"], ) # hybrid plugin parser.add_argument("--pp_size", type=int, default=2, help="pp size") @@ -100,16 +114,6 @@ def main(): colossalai.launch_from_torch(config={}, seed=args.seed) coordinator = DistCoordinator() - # Manage loggers - disable_existing_loggers() - logger = get_dist_logger() - 
if coordinator.is_master(): - datasets.utils.logging.set_verbosity_warning() - transformers.utils.logging.set_verbosity_info() - else: - datasets.utils.logging.set_verbosity_error() - transformers.utils.logging.set_verbosity_error() - # Set plugin booster_kwargs = {} if args.plugin == "zero2": @@ -135,6 +139,21 @@ def main(): parallel="EP", use_kernel_optim=args.use_kernel, ) + elif args.plugin == "zero2_tp": + dp_size = dist.get_world_size() + plugin = MoeHybridParallelPlugin( + tp_size=1, + pp_size=1, + zero_stage=2, + custom_policy=OpenMoeForCausalLMPolicy(), + enable_fused_normalization=args.use_kernel, + enable_jit_fused=args.use_kernel, + ) + MOE_MANAGER.setup( + seed=42, + parallel="TP", + use_kernel_optim=args.use_kernel, + ) elif args.plugin == "hybrid": dp_size = dist.get_world_size() // args.pp_size plugin = MoeHybridParallelPlugin( @@ -157,7 +176,7 @@ def main(): ) else: raise ValueError(f"Invalid plugin {args.plugin}") - logger.info(f"Set plugin as {plugin}", ranks=[0]) + coordinator.print_on_master(f"Set plugin as {plugin}") # Build OpenMoe model repo_name = "hpcaitech/openmoe-" + args.model_name @@ -166,8 +185,9 @@ def main(): setattr(config, "router_z_loss_factor", 0.1) setattr(config, "label_smoothing", 0.1) setattr(config, "z_loss_factor", 0.1) - model = OpenMoeForCausalLM(config) - logger.info(f"Finish init model with config:\n{config}", ranks=[0]) + with skip_init(): + model = OpenMoeForCausalLM(config) + coordinator.print_on_master(f"Finish init model with config:\n{config}") # Enable gradient checkpointing model.gradient_checkpointing_enable() @@ -193,12 +213,13 @@ def main(): # Set booster booster = Booster(plugin=plugin, **booster_kwargs) model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader) + load_ckpt(repo_name, model, booster) use_pipeline = (isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1) is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() - logger.info(f"Finish init booster", ranks=[0]) + coordinator.print_on_master(f"Finish init booster") # Start finetuning - logger.info(f"Start finetuning", ranks=[0]) + coordinator.print_on_master(f"Start finetuning") model.train() train_dataloader_iter = iter(dataloader) total_len = len(train_dataloader_iter) - 1 diff --git a/examples/language/openmoe/benchmark/benchmark_cai.sh b/examples/language/openmoe/benchmark/benchmark_cai.sh index 620bd4901ccd..5db65a216461 100755 --- a/examples/language/openmoe/benchmark/benchmark_cai.sh +++ b/examples/language/openmoe/benchmark/benchmark_cai.sh @@ -5,8 +5,8 @@ set -xue NUM_GPU=8 MODEL="8b" SEQ_LENGTH=2048 -WARMUP=5 -ACTIVE=5 +WARMUP=8 +ACTIVE=4 # HACK: make model importable example_dir=$(dirname $(realpath $(dirname $0))) @@ -16,40 +16,51 @@ else export PYTHONPATH=$example_dir:$PYTHONPATH fi -# hybrid +# zero2 torchrun --standalone --nproc_per_node $NUM_GPU \ $example_dir/benchmark/benchmark_cai.py \ --model_name $MODEL \ - --batch_size 512 \ + --batch_size 4 \ --seq_length $SEQ_LENGTH \ --warmup $WARMUP \ --active $ACTIVE \ - --use_kernel \ - --plugin hybrid \ - --pp_size 2 \ - --dp_size 1 \ - --ep_size 4 \ - --zero_stage 1 \ - --microbatch_size 32 + --plugin zero2 \ + --use_kernel -# zero2 +# zero2_tp torchrun --standalone --nproc_per_node $NUM_GPU \ $example_dir/benchmark/benchmark_cai.py \ --model_name $MODEL \ - --batch_size 8 \ + --batch_size 12 \ --seq_length $SEQ_LENGTH \ --warmup $WARMUP \ --active $ACTIVE \ - --plugin zero2 \ + --plugin zero2_tp \ 
    --use_kernel

# zero2_ep
torchrun --standalone --nproc_per_node $NUM_GPU \
    $example_dir/benchmark/benchmark_cai.py \
    --model_name $MODEL \
-    --batch_size 16 \
+    --batch_size 12 \
    --seq_length $SEQ_LENGTH \
    --warmup $WARMUP \
    --active $ACTIVE \
    --plugin zero2_ep \
    --use_kernel
+
+# hybrid
+torchrun --standalone --nproc_per_node $NUM_GPU \
+    $example_dir/benchmark/benchmark_cai.py \
+    --model_name $MODEL \
+    --batch_size 512 \
+    --seq_length $SEQ_LENGTH \
+    --warmup $WARMUP \
+    --active $ACTIVE \
+    --use_kernel \
+    --plugin hybrid \
+    --pp_size 2 \
+    --dp_size 1 \
+    --ep_size 4 \
+    --zero_stage 1 \
+    --microbatch_size 32

From 2481b8372699f1c125bb2b33d5e03750937a0d3f Mon Sep 17 00:00:00 2001
From: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com>
Date: Wed, 11 Oct 2023 19:30:54 +0800
Subject: [PATCH 19/46] [moe] support hybrid zero strategy. (#4877)

* overlap comm
* fix typo
* update bench script
* add option
* update script
* update bench
* param init
* support dp zero
* fix zero dp
* fix bug
* update pg bug
* update experts
* fix optim bug
* update config
* Kaishen is awesome
* fix bug
* embed
* Merge branch 'feature/MoE' of https://github.com/hpcaitech/ColossalAI into bench
* update bench
* update optim
* update doc
* update sync
* fix test
* fix arg
* update ckpt
* update test
* fix
* remove print
* polish code
* update hybrid zero optim
* update print
---
 .../plugin/moe_hybrid_parallel_plugin.py      | 120 ++++++++-
 colossalai/moe/experts.py                     |  50 ++--
 colossalai/moe/layers.py                      |  44 ++--
 colossalai/moe/manager.py                     |   9 +-
 colossalai/zero/low_level/low_level_optim.py  | 236 +++++++++++++-----
 .../openmoe/benchmark/benchmark_cai.py        |  66 ++---
 .../openmoe/benchmark/benchmark_cai.sh        |  31 ++-
 .../openmoe/benchmark/benchmark_fsdp.py       |  15 +-
 .../openmoe/benchmark/benchmark_fsdp.sh       |   4 +-
 .../openmoe/model/modeling_openmoe.py         |  24 +-
 pytest.ini                                    |   2 +-
 tests/test_moe/moe_utils.py                   |   9 +-
 tests/test_moe/test_grad_handler.py           |   3 +-
 tests/test_moe/test_kernel.py                 |   3 +-
 tests/test_moe/test_moe_checkpoint.py         |  10 +-
 tests/test_moe/test_moe_ep_tp.py              |   9 +-
 tests/test_moe/test_moe_group.py              |   2 +-
 tests/test_moe/test_moe_hybrid_zero.py        |  89 +++++++
 tests/test_moe/test_moe_local.py              |   9 +-
 tests/test_moe/test_moe_zero_fwd_bwd.py       |   2 +-
 tests/test_moe/test_moe_zero_optim.py         |   2 +-
 21 files changed, 556 insertions(+), 183 deletions(-)
 create mode 100644 tests/test_moe/test_moe_hybrid_zero.py

diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 784204528d65..5171780da347 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -1,23 +1,71 @@
 import random
-from typing import Optional
+from typing import Callable, Optional, OrderedDict, Tuple

 import numpy as np
 import torch
 import torch.distributed as dist
+from torch.distributed import ProcessGroup
+from torch.nn import Module
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler

-from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelPlugin
+from colossalai.booster.plugin.hybrid_parallel_plugin import (
+    HybridParallelAMPOptimizer,
+    HybridParallelModule,
+    HybridParallelNaiveOptimizer,
+    HybridParallelPlugin,
+    get_param_info,
+    init_pipeline_optimizer,
+)
 from colossalai.cluster import ProcessGroupMesh
+from colossalai.interface import
ModelWrapper, OptimizerWrapper from colossalai.moe import MoeCheckpintIO from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig from colossalai.shardformer.policies.base_policy import Policy +from colossalai.zero.low_level import LowLevelZeroOptimizer PP_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2 +class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): + + def __init__( + self, + optimizer: Optimizer, + model: Module, + use_pipeline: bool, + param_info: OrderedDict, + initial_scale: int = 2**16, # grad scaler config + min_scale: int = 1, + growth_factor: float = 2., + backoff_factor: float = .5, + growth_interval: int = 2000, + hysteresis: int = 2, + max_scale: int = 2**24, + clip_grad_norm: float = 0.0, # grad clipping + verbose: bool = False, + reduce_bucket_size: int = 1024 * 1024, # communication + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + partition_grad: bool = False, # stage 2 flag + cpu_offload: bool = False, # cpu offload + dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm + tp_process_group: Optional[ProcessGroup] = None, # if using tp + forced_dtype: Optional[torch.dtype] = None, + extra_dp_process_group: Optional[ProcessGroup] = None): + self.param_info = param_info + if use_pipeline: + init_pipeline_optimizer(optimizer, model) + super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, + hysteresis, max_scale, clip_grad_norm, verbose, reduce_bucket_size, communication_dtype, + overlap_communication, partition_grad, cpu_offload, dp_process_group, tp_process_group, + forced_dtype, extra_dp_process_group) + + class MoeHybridParallelPlugin(HybridParallelPlugin): """ Plugin for Moe Hybrid Parallel Training. 
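The scaler arguments above (initial_scale, growth_factor, backoff_factor, growth_interval, hysteresis, min/max_scale) follow the usual dynamic loss-scaling recipe. A minimal sketch of that recipe under its conventional semantics, not ColossalAI's exact implementation:

class DynamicLossScale:
    """Grow the scale after a run of clean steps; back off on overflow."""

    def __init__(self, initial_scale=2**16, min_scale=1, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000, hysteresis=2, max_scale=2**24):
        self.scale, self.min_scale, self.max_scale = initial_scale, min_scale, max_scale
        self.growth_factor, self.backoff_factor = growth_factor, backoff_factor
        self.growth_interval, self.hysteresis = growth_interval, hysteresis
        self._clean_steps, self._overflows_left = 0, hysteresis

    def update(self, found_overflow: bool) -> None:
        if found_overflow:
            self._clean_steps = 0
            self._overflows_left -= 1
            if self._overflows_left <= 0:    # tolerated overflows exhausted: back off
                self.scale = max(self.scale * self.backoff_factor, self.min_scale)
                self._overflows_left = self.hysteresis
        else:
            self._clean_steps += 1
            if self._clean_steps == self.growth_interval:    # long clean run: grow
                self.scale = min(self.scale * self.growth_factor, self.max_scale)
                self._clean_steps = 0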
@@ -78,6 +126,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): def __init__(self, tp_size: int, pp_size: int, + extra_dp_size: int = 1, precision: str = 'fp16', zero_stage: int = 0, enable_all_optimization: bool = False, @@ -106,6 +155,7 @@ def __init__(self, cpu_offload: bool = False, communication_dtype: Optional[torch.dtype] = None, overlap_communication: bool = True, + use_ep_inside: bool = True, custom_policy: Policy = None) -> None: super().__init__(tp_size=tp_size, @@ -132,6 +182,23 @@ def __init__(self, self.enable_sequence_parallelism = enable_sequence_parallelism # we change pg mesh to (pp, dp, tp) for better moe performance self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size) + + # sync moe in outer dp group, and sync other param in global dp group + if extra_dp_size > 1: + ep_size = self.dp_size // extra_dp_size + if use_ep_inside: + self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, extra_dp_size, ep_size) + self.extra_dp_group = self.pg_mesh_moe.get_group_along_axis(1) + if dist.get_rank() == 0: + print(f"Zero Parallel: pp {self.pp_size}, outer_dp {extra_dp_size}, inner_dp {ep_size}") + else: + self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, ep_size, extra_dp_size) + self.extra_dp_group = self.pg_mesh_moe.get_group_along_axis(2) + if dist.get_rank() == 0: + print(f"Zero Parallel: pp {self.pp_size}, outer_dp {ep_size}, inner_dp {extra_dp_size}") + else: + self.extra_dp_group = None + self.stage_manager = None self.schedule = None self.custom_policy = custom_policy @@ -235,3 +302,52 @@ def seed_worker(worker_id): def get_checkpoint_io(self) -> MoeCheckpintIO: self.checkpoint_io = MoeCheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) return self.checkpoint_io + + def configure( + self, + model: Module, + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + dataloader: Optional[DataLoader] = None, + lr_scheduler: Optional[LRScheduler] = None, + ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: + param_info = get_param_info(optimizer) + if not isinstance(model, ModelWrapper): + use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 + model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp, + self.ddp_config, self.custom_policy) + if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): + if self.zero_stage == 0: + if self.precision in ['fp16', 'bf16']: + optimizer = HybridParallelAMPOptimizer(optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + precision=self.precision, + max_norm=self.max_norm, + **self.amp_config) + self.checkpoint_io.link_master_and_working_param(optimizer.working_to_master_map, + optimizer.master_to_working_map) + else: + optimizer = HybridParallelNaiveOptimizer(optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info) + else: + assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1." + assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO." 
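The use_ep_inside switch above only decides which mesh axis becomes the EP group and which becomes the extra-DP group. A standalone sketch of the rank arithmetic, assuming ranks are laid out row-major over the (pp, outer, inner) mesh as ProcessGroupMesh does:

import numpy as np

def groups_along_axis(pp: int, outer: int, inner: int, axis: int):
    # Ranks 0..N-1 arranged row-major; a "group along axis" holds the ranks
    # that differ only in that axis's coordinate.
    mesh = np.arange(pp * outer * inner).reshape(pp, outer, inner)
    moved = np.moveaxis(mesh, axis, -1)
    return [row.tolist() for row in moved.reshape(-1, mesh.shape[axis])]

# 8 GPUs, pp=1, extra_dp=2, ep=4 with use_ep_inside=True:
print(groups_along_axis(1, 2, 4, axis=2))  # EP groups: [[0, 1, 2, 3], [4, 5, 6, 7]]
print(groups_along_axis(1, 2, 4, axis=1))  # extra-DP groups: [[0, 4], [1, 5], [2, 6], [3, 7]]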
+ optimizer = HybridParallelZeroOptimizer(optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + dp_process_group=self.dp_group, + tp_process_group=self.tp_group, + extra_dp_process_group=self.extra_dp_group, + verbose=True, + clip_grad_norm=self.max_norm, + **self.zero_config, + **self.amp_config) + self.checkpoint_io.link_master_and_working_param(optimizer._param_store.working_to_master_param, + optimizer._param_store.master_to_working_param) + + return model, optimizer, criterion, dataloader, lr_scheduler diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py index e05ea59b3d28..81a7b21544e4 100644 --- a/colossalai/moe/experts.py +++ b/colossalai/moe/experts.py @@ -39,27 +39,28 @@ def __init__( activation: Optional[Callable] = None, drop_rate: float = 0, gated: bool = False, + use_kernel: bool = False, ): super().__init__() assert expert_parallel in ["EP", "TP", None] self.expert_parallel = expert_parallel self.num_total_experts = num_experts self.gated = gated + self.use_kernel = use_kernel + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size # get expert parallel info if expert_parallel is not None: self.num_local_experts, self.moe_info = MOE_MANAGER.get_info( num_experts, use_tp=True if expert_parallel == "TP" else False) # get settings for different parallel + self.ep_size = get_ep_size(self) if expert_parallel == "TP": - assert ( - intermediate_size % - MOE_MANAGER.max_ep_size == 0), "intermediate_size should be divide by maximum expert parallel size" - intermediate_size = intermediate_size // MOE_MANAGER.max_ep_size + intermediate_size = intermediate_size // self.ep_size num_experts = self.num_total_experts else: num_experts = self.num_local_experts - self.ep_size = get_ep_size(self) else: self.num_local_experts = self.num_total_experts self.ep_size = 1 @@ -71,19 +72,6 @@ def __init__( self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) self.wo = nn.Parameter(torch.empty(num_experts, intermediate_size, hidden_size)) - # expert param should be different - if expert_parallel is not None: - seed_ctx = Randomizer(MOE_MANAGER.seed).fork_rng(enable_cpu=True) - else: - seed_ctx = nullcontext() - with seed_ctx: - if gated: - torch.nn.init.trunc_normal_(self.wi_gate, std=math.sqrt(0.1 / hidden_size)) - torch.nn.init.trunc_normal_(self.wi_up, std=math.sqrt(0.1 / hidden_size)) - else: - torch.nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size)) - torch.nn.init.trunc_normal_(self.wo, std=math.sqrt(0.1 / intermediate_size)) - self.act_name = activation self.act = get_activation(activation) self.drop = nn.Dropout(p=drop_rate) @@ -92,6 +80,24 @@ def __init__( for param in self.parameters(): set_moe_tensor_info(param, self.moe_info) + # init param + self.reset_parameters() + + @torch.no_grad() + def reset_parameters(self): + # expert param should be different + if self.expert_parallel is not None: + seed_ctx = Randomizer(MOE_MANAGER.seed).fork_rng(enable_cpu=True) + else: + seed_ctx = Randomizer(42).fork_rng(enable_cpu=True) + with seed_ctx: + if self.gated: + torch.nn.init.normal_(self.wi_gate, std=math.sqrt(0.1 / self.hidden_size)) + torch.nn.init.normal_(self.wi_up, std=math.sqrt(0.1 / self.hidden_size)) + else: + torch.nn.init.normal_(self.wi, std=math.sqrt(0.1 / self.hidden_size)) + torch.nn.init.normal_(self.wo, std=math.sqrt(0.1 / self.intermediate_size)) + def forward(self, x: torch.Tensor, param_slice: Tuple[slice] = (slice(None),)) -> torch.Tensor: """ Args: @@ -110,7 
+116,7 @@ def forward(self, x: torch.Tensor, param_slice: Tuple[slice] = (slice(None),)) - x = x.reshape(e, -1, h) if self.gated: - if HAS_TRITON and self.act_name == "swiglu": + if self.use_kernel and HAS_TRITON and self.act_name == "swiglu": x = LlamaActCombine.apply( torch.bmm(x, self.wi_gate[param_slice]), torch.bmm(x, self.wi_up[param_slice]), @@ -142,7 +148,9 @@ def __init__( activation=None, drop_rate: float = 0, gated: bool = False, + use_kernel: bool = False, ): + # TODO: This class can be aborted super().__init__( num_experts, hidden_size, @@ -151,6 +159,7 @@ def __init__( activation, drop_rate, gated, + use_kernel, ) @@ -168,7 +177,9 @@ def __init__( activation: str = None, drop_rate: float = 0, gated: bool = False, + use_kernel: bool = False, ): + # TODO: This class can be aborted super().__init__( num_experts, hidden_size, @@ -177,6 +188,7 @@ def __init__( activation, drop_rate, gated, + use_kernel, ) diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index c2cf627aceae..036bd32ae7c0 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -51,7 +51,6 @@ def __init__( min_capacity: int = 4, noisy_policy: Optional[str] = None, drop_tks: bool = True, - expert_parallel: str = "EP", hidden_size: int = 2048, intermediate_size: int = 2048, activation: str = None, @@ -59,14 +58,16 @@ def __init__( ): super().__init__() self.hidden_size = hidden_size + self.intermediate_size = intermediate_size self.num_experts = num_experts self.use_kernel = MOE_MANAGER.use_kernel_optim - self.expert_parallel = expert_parallel - assert expert_parallel in [ + self.expert_parallel = MOE_MANAGER.get_parallel() + self.gated = gated + assert self.expert_parallel in [ "EP", "TP", None, - ], f"Unsupported expert parallel type {expert_parallel}" + ], f"Unsupported expert parallel type {self.expert_parallel}" # moe router noisy_func = get_noise_generator(noisy_policy, num_experts) @@ -80,23 +81,29 @@ def __init__( ) # moe experts - expert_cls = get_expert_class(expert_parallel) - self.experts: BaseMLPExperts = expert_cls( - num_experts=num_experts, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - activation=activation, - gated=gated, - ) - if expert_parallel is not None: + expert_cls = get_expert_class(self.expert_parallel) + self.experts: BaseMLPExperts = expert_cls(num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + activation=activation, + gated=gated, + use_kernel=self.use_kernel) + if self.expert_parallel is not None: self.ep_group = get_ep_group(self.experts) self.ep_size = get_ep_size(self.experts) else: self.ep_group = None self.num_local_experts = self.experts.num_local_experts + # gate self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size)) - nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / self.hidden_size)) + + # init param + self.reset_parameters() + + @torch.no_grad() + def reset_parameters(self): + torch.nn.init.normal_(self.gate_weight, std=math.sqrt(0.1 / self.hidden_size)) def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -171,7 +178,7 @@ def _ep_process(self, dispatch_data: torch.Tensor) -> torch.Tensor: expert_output = AllToAll.apply(expert_output, self.ep_group) return expert_output - def _tp_process(self, dispatch_data: torch.Tensor) -> torch.Tensor: + def _tp_process(self, dispatch_data: torch.Tensor, use_overlap: bool = False) -> torch.Tensor: """ TP with overlap. 
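The reset_parameters methods above re-initialize weights once the parallel context exists: gate weights come from the shared RNG, while expert weights are drawn under a forked, context-specific seed so different EP ranks hold different experts. A sketch of the same idea with plain torch (the Randomizer semantics are assumed, not quoted):

import torch

@torch.no_grad()
def init_expert_weight(w: torch.Tensor, base_seed: int, ep_rank: int) -> None:
    # A private, seeded generator keeps per-rank expert init reproducible
    # without touching the global RNG stream that keeps replicated
    # (non-expert) parameters identical across ranks.
    gen = torch.Generator(device=w.device).manual_seed(base_seed + ep_rank)
    # std = sqrt(0.1 / fan_in), matching the normal_ init above
    w.normal_(mean=0.0, std=(0.1 / w.shape[-2]) ** 0.5, generator=gen)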
@@ -191,6 +198,13 @@ def _tp_process(self, dispatch_data: torch.Tensor) -> torch.Tensor:
         Returns:
             torch.Tensor: (num_experts, capacity, hidden_size)
         """
+        if use_overlap == False:
+            expert_in, _ = AllGather.apply(dispatch_data, self.ep_group)
+            expert_out = self.experts(expert_in)
+            expert_out, _ = ReduceScatter.apply(expert_out, self.ep_group)
+            return expert_out
+
+        # TODO: there is an accuracy problem in overlap
         chunk_num = 4
         chunk_size = dispatch_data.shape[0] // chunk_num
         out = torch.empty_like(dispatch_data)
diff --git a/colossalai/moe/manager.py b/colossalai/moe/manager.py
index e61fb0bf9582..1e949bb9a6dd 100644
--- a/colossalai/moe/manager.py
+++ b/colossalai/moe/manager.py
@@ -41,8 +41,8 @@ def is_initialized(self):
     def setup(self,
               seed: int,
-              use_kernel_optim: bool = True,
-              parallel: bool = None,
+              use_kernel_optim: bool = False,
+              parallel: str = None,
               mode: str = "dynamic",
               max_ep_size: int = 8,
               fixed_dp_size: int = 0,
@@ -140,6 +140,11 @@ def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoePara
         if not (ep_size in self.parallel_info_dict):
             self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size, pp_size, ep_inside=self.use_ep_inside)
+            if dist.get_rank() == 0:
+                if self.use_ep_inside:
+                    print(f"MoE Parallel: pp {pp_size}, dp {dp_size}, ep {ep_size}")
+                else:
+                    print(f"MoE Parallel: pp {pp_size}, ep {ep_size}, dp {dp_size}")

         return num_local_experts, self.parallel_info_dict[ep_size]
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index b037274f922b..01776a8352fc 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -8,6 +8,7 @@
 import torch.distributed as dist
 import torch.nn as nn
 from torch import Tensor, inf
+from torch._utils import _flatten_dense_tensors
 from torch.distributed import ProcessGroup
 from torch.optim import Optimizer
@@ -23,6 +24,15 @@
 from colossalai.utils.cuda import get_current_device

 from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
+from ._utils import (
+    calculate_global_norm_from_list,
+    compute_norm,
+    flatten,
+    has_inf_or_nan,
+    release_param_grad,
+    sync_tensor,
+    unflatten,
+)
 from .bookkeeping import BucketStore, GradientStore, ParameterStore

@@ -74,7 +84,9 @@ def __init__(
         partition_grad: bool = False,  # stage 2 flag
         cpu_offload: bool = False,  # cpu offload
         dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
+        tp_process_group: Optional[ProcessGroup] = None,  # if using tp
         forced_dtype: Optional[torch.dtype] = None,
+        extra_dp_process_group: Optional[ProcessGroup] = None,
         master_weights: bool = True,  # master weights
     ):
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
@@ -95,6 +107,18 @@ def __init__(
         self._local_rank = dist.get_rank(group=self.dp_pg)
         self._world_size = dist.get_world_size(group=self.dp_pg)

+        # extra dp
+        # This group is used to sync moe params; dp_world_size = moe_duplicates * extra_dp_size.
+        # Non-moe params are synced by the global dp pg, while moe params are synced by the extra dp pg.
+        # Moe param grads are split like non-moe param grads by the global dp pg, and the grads are merged in step.
+        # Moe working and master params are split by the extra dp pg.
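A per-parameter sketch of the rule those comments describe, assuming the moe_info marker used elsewhere in this series; the real code buckets and flattens grads rather than reducing one tensor at a time, and handles scaling inside the bucket path:

import torch.distributed as dist

def reduce_gradients(params, global_dp_pg, extra_dp_pg):
    for p in params:
        # Dense params are replicated on every rank: average over the global dp group.
        # Expert params are replicated only across extra-dp replicas, since each
        # EP group holds different experts: reduce over the extra dp group instead.
        group = extra_dp_pg if hasattr(p, "moe_info") else global_dp_pg
        p.grad /= dist.get_world_size(group)
        dist.all_reduce(p.grad, group=group)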
+ self.extra_dp_pg = extra_dp_process_group + if self.extra_dp_pg is not None: + self.extra_dp_pg_size = dist.get_world_size(group=self.extra_dp_pg) + self.extra_dp_pg_rank = dist.get_rank(group=self.extra_dp_pg) + + self.tp_pg = tp_process_group + # working and master params for mixed precision training self._working_param_groups = dict() self._master_param_groups_of_current_rank = dict() @@ -139,10 +163,11 @@ def __init__( group_params = list() for param in param_group['params']: if param.requires_grad: - # skip moe param - if is_moe_tensor(param): - moe_params.append(param) - continue + if self.extra_dp_pg is None: + # skip moe param + if is_moe_tensor(param): + moe_params.append(param) + continue group_params.append(param) # add the working params to working_param_groups for bookkeeping @@ -227,13 +252,18 @@ def _create_master_param_current_rank(self, param_list): param.data = padding_param[: param.numel()].view(param.shape) else: padding_param = param.data.view(-1) - splited_params = padding_param.split(padding_param.numel() // self._world_size) + + if self.extra_dp_pg is not None and is_moe_tensor(param): + splited_params = padding_param.split(padding_param.numel() // self.extra_dp_pg_size) + splited_params = splited_params[self.extra_dp_pg_rank] + else: + splited_params = padding_param.split(padding_param.numel() // self._world_size) + splited_params = splited_params[self._local_rank] # use fp32 when master_weights is True if self._master_weights is True: - splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device) - else: - splited_param_current_rank = splited_params[self._local_rank] + splited_param_current_rank = splited_params.detach().float().to(device) + params_current_rank.append(splited_param_current_rank) self._param_store.link_master_and_working_param(splited_param_current_rank, param) @@ -266,8 +296,9 @@ def _run_reduction(self): if self._bucket_store.num_elements_in_bucket() > 0: self._bucket_store.build_grad_in_bucket() - flat_grads = self._bucket_store.get_flatten_grad() - flat_grads /= self._world_size + if self.extra_dp_pg is None: + flat_grads = self._bucket_store.get_flatten_grad() + flat_grads /= self._world_size # ready to add other tensors to bucket self._bucket_store.reset_num_elements_in_bucket() @@ -275,7 +306,8 @@ def _run_reduction(self): if self._overlap_communication: stream = self._comm_stream # in case of the memory being reused in the default stream - flat_grads.record_stream(stream) + if self.extra_dp_pg is None: + flat_grads.record_stream(stream) # waiting for ops in the default stream finishing stream.wait_stream(torch.cuda.current_stream()) else: @@ -284,29 +316,73 @@ def _run_reduction(self): with torch.cuda.stream(stream): group_id = self._bucket_store.current_group_id - grad_dtype = flat_grads.dtype - if self._communication_dtype is not None: - flat_grads = flat_grads.to(self._communication_dtype) + if self.extra_dp_pg is None: + grad_dtype = flat_grads.dtype + if self._communication_dtype is not None: + flat_grads = flat_grads.to(self._communication_dtype) if not self._partition_grads: - dist.all_reduce(flat_grads, group=self.dp_pg) - if flat_grads.dtype != grad_dtype: - flat_grads = flat_grads.to(grad_dtype) - - flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size) - grad_in_bucket = self._bucket_store.get_grad() - - for rank, grad_list in grad_in_bucket.items(): - sync_tensor(flat_grads_per_rank[rank], grad_list) - for grad in grad_list: - param_id = 
self._bucket_store.get_param_id_of_grad(grad) - if ( - len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) - < self._world_size - ): - self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) - else: - self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id) + if self.extra_dp_pg is None: + dist.all_reduce(flat_grads, group=self.dp_pg) + if flat_grads.dtype != grad_dtype: + flat_grads = flat_grads.to(grad_dtype) + + flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size) + grad_in_bucket = self._bucket_store.get_grad() + + for rank, grad_list in grad_in_bucket.items(): + sync_tensor(flat_grads_per_rank[rank], grad_list) + for grad in grad_list: + param_id = self._bucket_store.get_param_id_of_grad(grad) + if len(self._grad_store.get_partitioned_gradients_by_param_id( + group_id, param_id)) < self._world_size: + self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) + else: + self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id) + + # sync extra zero group + else: + # record moe and non moe param + moe_list = [] + for param in self._bucket_store._param_list: + moe_list.append(is_moe_tensor(param)) + + # divide them into different groups + moe_grad_list = [] + non_moe_grad_list = [] + for grad_list in self._bucket_store._grad_in_bucket.values(): + non_moe_cur_grad = [] + moe_cur_grad = [] + for i in range(len(grad_list)): + if moe_list[i] == True: + moe_cur_grad.append(grad_list[i]) + else: + non_moe_cur_grad.append(grad_list[i]) + if len(moe_cur_grad) > 0: + moe_grad_list.append(moe_cur_grad) + if len(non_moe_cur_grad) > 0: + non_moe_grad_list.append(non_moe_cur_grad) + + # sync non moe param in global dp group + if len(non_moe_grad_list) > 0: + flat_grads = [] + for grad_list in non_moe_grad_list: + flat_grads.append(_flatten_dense_tensors(grad_list)) + flat_grads = _flatten_dense_tensors(flat_grads) + flat_grads /= self._world_size + dist.all_reduce(flat_grads, group=self.dp_pg) + flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size) + self._sync_unpartitioned_grad(non_moe_grad_list, flat_grads_per_rank, group_id) + + # sync moe param only in zero group + if len(moe_grad_list) > 0: + flat_grads = [] + for grad_list in moe_grad_list: + flat_grads.append(_flatten_dense_tensors(grad_list)) + flat_grads = _flatten_dense_tensors(flat_grads) + dist.all_reduce(flat_grads, group=self.extra_dp_pg) + flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size) + self._sync_unpartitioned_grad(moe_grad_list, flat_grads_per_rank, group_id) else: flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size)) @@ -327,6 +403,16 @@ def _run_reduction(self): self._bucket_store.reset() + def _sync_unpartitioned_grad(self, origin_grad_list, flat_grad_list, group_id): + for rank, grad_list in enumerate(origin_grad_list): + sync_tensor(flat_grad_list[rank], grad_list) + for grad in grad_list: + param_id = self._bucket_store.get_param_id_of_grad(grad) + if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < self._world_size: + self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) + else: + self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id) + def _add_to_bucket(self, param, group_id): param_size = param.numel() @@ -443,13 +529,18 @@ def step(self, closure=None): # else the splited grad should be attached to the splited param grads = 
self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param)) if len(grads) > 0: - real_working_params[group_id].append(working_param) + # moe hybrid zero + if self.extra_dp_pg is not None and is_moe_tensor(working_param): + real_working_params[group_id].append(working_param) + param_slice = self._world_size // self.extra_dp_pg_size + grad = grads[self.extra_dp_pg_rank * param_slice:(self.extra_dp_pg_rank + 1) * param_slice] + grad = flatten(grad) + else: + real_working_params[group_id].append(working_param) + grad = grads[grad_index] # no need to copy fp32 grad if master_weights is False - grad = ( - grads[grad_index].to(splited_param.dtype).to(splited_param.device) - if self._master_weights - else grads[grad_index] - ) + if self._master_weights: + grad = grad.to(splited_param.dtype).to(splited_param.device) splited_param.grad = grad grad_partition_groups.append(grad) real_master_params[group_id].append(splited_param) @@ -487,11 +578,18 @@ def step(self, closure=None): master_working_param = self.optim.param_groups[group_id]["params"] for idx, splited_param in enumerate(master_working_param): working_param = real_working_params[group_id][idx] - all_splited_param = [ - torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype) for _ in range(self._world_size) - ] - dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.dp_pg) - working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param)) + if self.extra_dp_pg is not None and is_moe_tensor(working_param): + all_splited_param = [ + torch.zeros(splited_param.shape, device="cuda", dtype=dtype) + for _ in range(self.extra_dp_pg_size) + ] + dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.extra_dp_pg) + else: + all_splited_param = [ + torch.zeros(splited_param.shape, device="cuda", dtype=dtype) for _ in range(self._world_size) + ] + dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.dp_pg) + working_param.data.copy_(flatten(all_splited_param)[:working_param.numel()].reshape_as(working_param)) self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id] def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float: @@ -512,7 +610,6 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo norm_type = float(norm_type) if norm_type == inf: total_norm = max(grad.data.abs().max() for grad in gradients) - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg) total_norm = total_norm_cuda.item() @@ -620,13 +717,18 @@ def state_dict(self) -> Dict: for k, v in state.items(): if isinstance(v, torch.Tensor) and k != "step": working_param = self._param_store.master_to_working_param[id(param)] - gather_tensor = [ - torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size) - ] - dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg) - param_state = ( - torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() - ) + if self.extra_dp_pg is not None and is_moe_tensor(v): + gather_tensor = [ + torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self.extra_dp_pg_size) + ] + dist.all_gather(gather_tensor, v.cuda(), group=self.extra_dp_pg) + else: + gather_tensor = [ + torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size) + ] + 
dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg) + param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as( + working_param).cpu() zero_state[param][k] = param_state states_dict = self._pack_state(zero_state) @@ -648,8 +750,12 @@ def load_state_dict(self, state_dict: Dict): v = v.flatten() if padding_size > 0: v = torch.nn.functional.pad(v, [0, padding_size]) - v_list = v.split(v.numel() // self._world_size) - zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone() + if self.extra_dp_pg is not None and is_moe_tensor(v): + v_list = v.split(v.numel() // self.extra_dp_pg_size) + zero_state_dict['state'][param_idx][k] = v_list[self.extra_dp_pg_rank].detach().clone() + else: + v_list = v.split(v.numel() // self._world_size) + zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach().clone() self.optim.load_state_dict(zero_state_dict) @@ -679,12 +785,19 @@ def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, i working_param = self._param_store.master_to_working_param[id(master_param)] for k, v in states.items(): - if isinstance(v, torch.Tensor) and k != "step": - state_tensor = [torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)] - dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg) - state_tensor = ( - torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() - ) + if isinstance(v, torch.Tensor) and k != 'step': + if self.extra_dp_pg is not None and is_moe_tensor(v): + state_tensor = [ + torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self.extra_dp_pg_size) + ] + dist.all_gather(state_tensor, v.cuda(), group=self.extra_dp_pg) + else: + state_tensor = [ + torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size) + ] + dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg) + state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as( + working_param).cpu() current_block_size += state_tensor.numel() current_block[k] = state_tensor @@ -712,7 +825,10 @@ def update_master_params(self, model: nn.Module) -> None: working_param = p.data.view(-1) if padding_size > 0: working_param = torch.nn.functional.pad(working_param, [0, padding_size]) - master_param.copy_(working_param.chunk(self._world_size)[self._local_rank]) + if self.extra_dp_pg is not None and is_moe_tensor(p): + master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank]) + else: + master_param.copy_(working_param.chunk(self._world_size)[self._local_rank]) def get_working_to_master_map(self) -> Dict[int, torch.Tensor]: return self._param_store.working_to_master_param diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py index 5ff0843caaea..830ff9df0ec6 100644 --- a/examples/language/openmoe/benchmark/benchmark_cai.py +++ b/examples/language/openmoe/benchmark/benchmark_cai.py @@ -41,7 +41,7 @@ def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster): class RandomDataset(Dataset): - def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000): + def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 256384): self.num_samples = num_samples self.max_length = max_length self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) @@ -86,7 +86,6 @@ def parse_args(): type=str, 
default="hybrid", help="parallel plugin", - choices=["zero2", "zero2_ep", "hybrid", "zero2_tp"], ) # hybrid plugin parser.add_argument("--pp_size", type=int, default=2, help="pp size") @@ -94,6 +93,7 @@ def parse_args(): parser.add_argument("--ep_size", type=int, default=2, help="ep size") parser.add_argument("--zero_stage", type=int, default=1, help="zero stage in hybrid plugin") parser.add_argument("--microbatch_size", type=int, default=1, help="microbatch size") + parser.add_argument("--extra_dp_size", type=int, default=1) # kernel parser.add_argument( "--use_kernel", @@ -116,63 +116,73 @@ def main(): # Set plugin booster_kwargs = {} - if args.plugin == "zero2": + hybrid_dict = {"tp_size": 1, "custom_policy": OpenMoeForCausalLMPolicy(), "enable_fused_normalization": args.use_kernel, "enable_jit_fused": args.use_kernel} + mgr_dict = {"seed": 42, "use_kernel_optim": args.use_kernel} + if args.plugin == "zero": dp_size = dist.get_world_size() plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2) MOE_MANAGER.setup( - seed=42, parallel=None, - use_kernel_optim=args.use_kernel, + **mgr_dict, ) - elif args.plugin == "zero2_ep": + elif args.plugin == "ep": dp_size = dist.get_world_size() plugin = MoeHybridParallelPlugin( - tp_size=1, pp_size=1, zero_stage=2, - custom_policy=OpenMoeForCausalLMPolicy(), - enable_fused_normalization=args.use_kernel, - enable_jit_fused=args.use_kernel, + **hybrid_dict, ) MOE_MANAGER.setup( - seed=42, parallel="EP", - use_kernel_optim=args.use_kernel, + **mgr_dict, ) - elif args.plugin == "zero2_tp": + elif args.plugin == "ep_zero": dp_size = dist.get_world_size() + use_ep_inside = False plugin = MoeHybridParallelPlugin( - tp_size=1, pp_size=1, - zero_stage=2, - custom_policy=OpenMoeForCausalLMPolicy(), - enable_fused_normalization=args.use_kernel, - enable_jit_fused=args.use_kernel, + zero_stage=1, + extra_dp_size=args.extra_dp_size, + use_ep_inside=use_ep_inside, + **hybrid_dict, ) MOE_MANAGER.setup( - seed=42, - parallel="TP", - use_kernel_optim=args.use_kernel, + parallel="EP", + max_ep_size=dp_size // args.extra_dp_size, + use_ep_inside=use_ep_inside, + **mgr_dict, + ) + elif args.plugin == "zero_ep": + dp_size = dist.get_world_size() + use_ep_inside = True + plugin = MoeHybridParallelPlugin( + pp_size=1, + zero_stage=1, + extra_dp_size=args.extra_dp_size, + use_ep_inside=use_ep_inside, + **hybrid_dict, + ) + MOE_MANAGER.setup( + parallel="EP", + max_ep_size=dp_size // args.extra_dp_size, + use_ep_inside=use_ep_inside, + **mgr_dict, ) elif args.plugin == "hybrid": dp_size = dist.get_world_size() // args.pp_size plugin = MoeHybridParallelPlugin( - tp_size=1, pp_size=args.pp_size, zero_stage=args.zero_stage, microbatch_size=args.microbatch_size, - custom_policy=OpenMoeForCausalLMPolicy(), - enable_fused_normalization=args.use_kernel, - enable_jit_fused=args.use_kernel, + **hybrid_dict, ) MOE_MANAGER.setup( - seed=42, parallel="EP", mode="fixed", fixed_dp_size=args.dp_size, fixed_ep_size=args.ep_size, fixed_pp_size=args.pp_size, - use_kernel_optim=args.use_kernel, + **mgr_dict, ) else: raise ValueError(f"Invalid plugin {args.plugin}") @@ -219,7 +229,7 @@ def main(): coordinator.print_on_master(f"Finish init booster") # Start finetuning - coordinator.print_on_master(f"Start finetuning") + coordinator.print_on_master(f"Start training") model.train() train_dataloader_iter = iter(dataloader) total_len = len(train_dataloader_iter) - 1 diff --git a/examples/language/openmoe/benchmark/benchmark_cai.sh b/examples/language/openmoe/benchmark/benchmark_cai.sh index 
5db65a216461..ec4490faa55d 100755 --- a/examples/language/openmoe/benchmark/benchmark_cai.sh +++ b/examples/language/openmoe/benchmark/benchmark_cai.sh @@ -2,7 +2,7 @@ set -xue -NUM_GPU=8 +NUM_GPU=4 MODEL="8b" SEQ_LENGTH=2048 WARMUP=8 @@ -16,7 +16,7 @@ else export PYTHONPATH=$example_dir:$PYTHONPATH fi -# zero2 +# zero torchrun --standalone --nproc_per_node $NUM_GPU \ $example_dir/benchmark/benchmark_cai.py \ --model_name $MODEL \ @@ -24,10 +24,10 @@ torchrun --standalone --nproc_per_node $NUM_GPU \ --seq_length $SEQ_LENGTH \ --warmup $WARMUP \ --active $ACTIVE \ - --plugin zero2 \ + --plugin zero \ --use_kernel -# zero2_tp +# ep torchrun --standalone --nproc_per_node $NUM_GPU \ $example_dir/benchmark/benchmark_cai.py \ --model_name $MODEL \ @@ -35,10 +35,10 @@ torchrun --standalone --nproc_per_node $NUM_GPU \ --seq_length $SEQ_LENGTH \ --warmup $WARMUP \ --active $ACTIVE \ - --plugin zero2_tp \ + --plugin ep \ --use_kernel -# zero2_ep +# ep_zero torchrun --standalone --nproc_per_node $NUM_GPU \ $example_dir/benchmark/benchmark_cai.py \ --model_name $MODEL \ @@ -46,14 +46,27 @@ torchrun --standalone --nproc_per_node $NUM_GPU \ --seq_length $SEQ_LENGTH \ --warmup $WARMUP \ --active $ACTIVE \ - --plugin zero2_ep \ - --use_kernel + --plugin ep_zero \ + --use_kernel \ + --extra_dp_size 2 + +# zero_ep +torchrun --standalone --nproc_per_node $NUM_GPU \ + $example_dir/benchmark/benchmark_cai.py \ + --model_name $MODEL \ + --batch_size 12 \ + --seq_length $SEQ_LENGTH \ + --warmup $WARMUP \ + --active $ACTIVE \ + --plugin zero_ep \ + --use_kernel \ + --extra_dp_size 2 # hybrid torchrun --standalone --nproc_per_node $NUM_GPU \ $example_dir/benchmark/benchmark_cai.py \ --model_name $MODEL \ - --batch_size 512 \ + --batch_size 128 \ --seq_length $SEQ_LENGTH \ --warmup $WARMUP \ --active $ACTIVE \ diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.py b/examples/language/openmoe/benchmark/benchmark_fsdp.py index 1b69c8d4abeb..0edf102d640c 100644 --- a/examples/language/openmoe/benchmark/benchmark_fsdp.py +++ b/examples/language/openmoe/benchmark/benchmark_fsdp.py @@ -20,9 +20,8 @@ class RandomDataset(Dataset): - def __init__( - self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000 - ): + + def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000): self.num_samples = num_samples self.max_length = max_length self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length)) @@ -52,9 +51,7 @@ def fsdp_main(rank, world_size, args): max_length=args.seq_length, num_samples=args.batch_size * (args.warmup + args.active) * dp_size, ) - sampler = DistributedSampler( - dataset, rank=rank, num_replicas=world_size, shuffle=False - ) + sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=False) train_kwargs = {"batch_size": args.batch_size, "sampler": sampler} train_loader = torch.utils.data.DataLoader(dataset, **train_kwargs) torch.cuda.set_device(rank) @@ -64,7 +61,9 @@ def fsdp_main(rank, world_size, args): setattr(config, "router_z_loss_factor", 0.1) setattr(config, "label_smoothing", 0.1) setattr(config, "z_loss_factor", 0.1) + torch.set_default_dtype(torch.float16) model = OpenMoeForCausalLM(config) + torch.set_default_dtype(torch.float32) auto_wrap_policy = functools.partial( transformer_auto_wrap_policy, transformer_layer_cls={ @@ -114,9 +113,7 @@ def fsdp_main(rank, world_size, args): performance_evaluator.on_fit_end() if dist.get_rank() == 0: - print( - f"Max CUDA memory usage: 
{torch.cuda.max_memory_allocated()/1024**2:.2f} MB" - ) + print(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") if __name__ == "__main__": diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.sh b/examples/language/openmoe/benchmark/benchmark_fsdp.sh index 41ffcd882a3b..e1eb2a9c6053 100755 --- a/examples/language/openmoe/benchmark/benchmark_fsdp.sh +++ b/examples/language/openmoe/benchmark/benchmark_fsdp.sh @@ -6,8 +6,8 @@ NUM_GPU=8 MODEL="8b" BATCH_SIZE=1 SEQ_LENGTH=2048 -WARMUP=5 -ACTIVE=5 +WARMUP=6 +ACTIVE=3 # HACK: make model importable example_dir=$(dirname $(realpath $(dirname $0))) diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index f8c79320fa57..357c0f22a783 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -360,19 +360,17 @@ def __init__(self, config: LlamaConfig, moe: bool): self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) if self.moe: - self.mlp = SparseMLP( - num_experts=config.num_experts, - top_k=config.topk, - capacity_factor_train=config.capacity_factor_train, - capacity_factor_eval=config.capacity_factor_eval, - min_capacity=config.min_capacity, - noisy_policy=config.noisy_policy, - drop_tks=config.drop_tks, - expert_parallel=MOE_MANAGER.get_parallel() if MOE_MANAGER.is_initialized else config.expert_parallel, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - activation=config.hidden_act, - gated=config.gated) + self.mlp = SparseMLP(num_experts=config.num_experts, + top_k=config.topk, + capacity_factor_train=config.capacity_factor_train, + capacity_factor_eval=config.capacity_factor_eval, + min_capacity=config.min_capacity, + noisy_policy=config.noisy_policy, + drop_tks=config.drop_tks, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + activation=config.hidden_act, + gated=config.gated) self.pre_extra_mlp_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.extra_mlp = OpenMoeMLP(config) else: diff --git a/pytest.ini b/pytest.ini index 38ad7d76de50..598e0a74e71c 100644 --- a/pytest.ini +++ b/pytest.ini @@ -2,4 +2,4 @@ markers = dist: tests which are run in a multi-GPU or multi-machine environment (at least 4 GPUs) largedist: tests which are run in a multi-GPU or multi-machine environment (at least 8 GPUs) -addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_moe --ignore=tests/test_fx --ignore=tests/test_legacy +addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_fx --ignore=tests/test_legacy diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py index 53266beb1877..934061ae4417 100644 --- a/tests/test_moe/moe_utils.py +++ b/tests/test_moe/moe_utils.py @@ -14,16 +14,13 @@ class MoeModel(nn.Module): - def __init__(self, checkpoint: bool = False, expert_parallel: str = "EP"): + def __init__(self, checkpoint: bool = False): class TestSubModule(CheckpointModule): def __init__(self): super().__init__(checkpoint) - self.moe = SparseMLP(num_experts=8, - expert_parallel=expert_parallel, - hidden_size=16, - intermediate_size=32) + self.moe = SparseMLP(num_experts=8, hidden_size=16, intermediate_size=32) self.proj = nn.Linear(16, 4) def 
_forward(self, x): @@ -127,7 +124,7 @@ def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_ """Sync the parameters of tp model from ep model Args: - tp_model (MoeModule) + local_model (MoeModule) ep_model (MoeModule) """ for (local_name, local_param), (ep_name, ep_param) in zip(local_model.named_parameters(), diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py index e3de8f101a74..d935be2a9628 100644 --- a/tests/test_moe/test_grad_handler.py +++ b/tests/test_moe/test_grad_handler.py @@ -18,7 +18,7 @@ def run_test(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_MANAGER.setup(42) # MOE initialization + MOE_MANAGER.setup(42, parallel="EP") # MOE initialization num_experts_list = [1, 2, 4] layer_list = [] for num_experts in num_experts_list: @@ -26,7 +26,6 @@ def run_test(rank, world_size, port): intermediate_size=DIM * 4, num_experts=num_experts, top_k=1, - expert_parallel="EP", noisy_policy="Jitter") layer_list.append(moe_layer) diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index 1927c9553683..ef5177289aad 100644 --- a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -23,7 +23,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') local_rank = dist.get_rank() - MOE_MANAGER.setup(42) # MOE environment initialization + MOE_MANAGER.setup(42, parallel="EP") # MOE environment initialization MOE_MANAGER.reset_loss() torch.manual_seed(rs + local_rank) # set each process has different random seed @@ -34,7 +34,6 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f intermediate_size=hidden_size * 2, num_experts=NUM_EXPERTS, top_k=topk, - expert_parallel="EP", capacity_factor_train=1.0) layer = layer.to(get_current_device()) if data_type == torch.float16: diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index 489f5ebdacfc..09af499185db 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -1,5 +1,7 @@ +import importlib import os import shutil +import sys import pytest import torch @@ -11,8 +13,12 @@ from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.moe.manager import MOE_MANAGER from colossalai.testing import rerun_if_address_is_in_use, spawn -from examples.language.openmoe.model.modeling_openmoe import OpenMoeForCausalLM -from examples.language.openmoe.model.openmoe_policy import OpenMoeForCausalLMPolicy + +sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "examples/language/openmoe")) + +# TODO: better way to import them +OpenMoeForCausalLM = importlib.import_module("model.modeling_openmoe").OpenMoeForCausalLM +OpenMoeForCausalLMPolicy = importlib.import_module("model.openmoe_policy").OpenMoeForCausalLMPolicy def get_config(): diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py index 72b639c8b43a..2bbf739ebbd4 100644 --- a/tests/test_moe/test_moe_ep_tp.py +++ b/tests/test_moe/test_moe_ep_tp.py @@ -16,10 +16,11 @@ def run_test(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_MANAGER.setup(42) # MOE initialization - - ep_model = 
SparseMLP(num_experts=4, expert_parallel="EP", hidden_size=DIM, intermediate_size=DIM) - tp_model = SparseMLP(num_experts=4, expert_parallel="TP", hidden_size=DIM, intermediate_size=DIM) + MOE_MANAGER.setup(42, parallel="EP") # MOE initialization + ep_model = SparseMLP(num_experts=4, hidden_size=DIM, intermediate_size=DIM) + MOE_MANAGER.__init__() + MOE_MANAGER.setup(42, parallel="TP") + tp_model = SparseMLP(num_experts=4, hidden_size=DIM, intermediate_size=DIM) ep_model = ep_model.to(get_current_device()) tp_model = tp_model.to(get_current_device()) diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py index f5d54ba290aa..e111ea6bb18d 100644 --- a/tests/test_moe/test_moe_group.py +++ b/tests/test_moe/test_moe_group.py @@ -59,7 +59,7 @@ def run_moe_init(expert_cls): def _run_test(rank, world_size, port, expert_cls): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_MANAGER.setup(seed=42) + MOE_MANAGER.setup(seed=42, parallel="EP") run_moe_init(expert_cls) diff --git a/tests/test_moe/test_moe_hybrid_zero.py b/tests/test_moe/test_moe_hybrid_zero.py new file mode 100644 index 000000000000..a2b8efb0e2dc --- /dev/null +++ b/tests/test_moe/test_moe_hybrid_zero.py @@ -0,0 +1,89 @@ +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel +from colossalai.moe.manager import MOE_MANAGER +from colossalai.tensor.moe_tensor.api import is_moe_tensor +from colossalai.testing import rerun_if_address_is_in_use, spawn +from tests.test_moe.moe_utils import MoeModel + + +def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False): + model.train() + with torch.cuda.amp.autocast(enabled=enable_autocast): + if criterion: + y = model(data) + loss = criterion(y, label) + else: + loss = model(data, label) + loss = loss.float() + + if isinstance(model, LowLevelZeroModel): + optimizer.backward(loss / 2) + else: + loss.backward() + return y + + +def run_zero_optim_test(local_rank, world_size, stage=1): + criterion = torch.nn.CrossEntropyLoss() + data = torch.randn(16, 4).cuda() + label = torch.randint(0, 4, (16,)).cuda() + + MOE_MANAGER.setup(seed=42, parallel=None) + torch_model = MoeModel(checkpoint=True) + torch_optimizer = torch.optim.Adam(torch_model.parameters()) + torch_model = torch_model.cuda() + + MOE_MANAGER.__init__() + MOE_MANAGER.setup(seed=42, max_ep_size=2, use_ep_inside=False, parallel="EP") + zero_model = MoeModel(checkpoint=True) + extra_dp_group = MOE_MANAGER.parallel_info_dict[2].dp_group + ep_rank = dist.get_rank(MOE_MANAGER.parallel_info_dict[2].ep_group) + ep_size = MOE_MANAGER.parallel_info_dict[2].ep_size + for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): + if is_moe_tensor(zero_param): + num_expert = torch_param.data.shape[0] + zero_param.data.copy_(torch_param.data[ep_rank * (num_expert // ep_size):(ep_rank + 1) * + (num_expert // ep_size)].detach().clone()) + else: + zero_param.data.copy_(torch_param.data.detach().clone()) + zero_optimizer = torch.optim.Adam(zero_model.parameters()) + plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") + plugin.zero_optim_kwargs["extra_dp_process_group"] = extra_dp_group + booster = Booster(plugin=plugin) + zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer) + + 
run_fwd_bwd(torch_model, data, label, criterion, None) + torch_optimizer.step() + run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + zero_optimizer.step() + + for (torch_name, torch_param), (zero_name, zero_param) in zip(torch_model.named_parameters(), + zero_model.named_parameters()): + if is_moe_tensor(zero_param): + num_expert = torch_param.data.shape[0] + torch_param.data = torch_param.data[ep_rank * (num_expert // ep_size):(ep_rank + 1) * + (num_expert // ep_size)] + assert torch.allclose(torch_param.data, zero_param.data, + atol=1e-4), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}" + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_zero_optim_test(rank, world_size, stage=1) + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [4]) +@rerun_if_address_is_in_use() +def test_moe_zero_optim(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_moe_zero_optim(world_size=4) diff --git a/tests/test_moe/test_moe_local.py b/tests/test_moe/test_moe_local.py index 09cc0cc6a4ef..1211a0d2d7f1 100644 --- a/tests/test_moe/test_moe_local.py +++ b/tests/test_moe/test_moe_local.py @@ -16,10 +16,11 @@ def run_test(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_MANAGER.setup(42) # MOE initialization - - ep_model = SparseMLP(num_experts=4, expert_parallel="EP", hidden_size=DIM, intermediate_size=DIM) - local_model = SparseMLP(num_experts=4, expert_parallel=None, hidden_size=DIM, intermediate_size=DIM) + MOE_MANAGER.setup(42, parallel=None) + local_model = SparseMLP(num_experts=4, hidden_size=DIM, intermediate_size=DIM) + MOE_MANAGER.__init__() + MOE_MANAGER.setup(42, parallel="EP") # MOE initialization + ep_model = SparseMLP(num_experts=4, hidden_size=DIM, intermediate_size=DIM) ep_model = ep_model.to(get_current_device()) local_model = local_model.to(get_current_device()) diff --git a/tests/test_moe/test_moe_zero_fwd_bwd.py b/tests/test_moe/test_moe_zero_fwd_bwd.py index 2b2afa4623b5..499d65f0072a 100644 --- a/tests/test_moe/test_moe_zero_fwd_bwd.py +++ b/tests/test_moe/test_moe_zero_fwd_bwd.py @@ -86,7 +86,7 @@ def run_zero_test(local_rank, world_size, stage=1): def run_dist(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_MANAGER.setup(seed=42) + MOE_MANAGER.setup(seed=42, parallel="EP") seed_all(42 + rank) run_zero_test(rank, world_size, stage=1) run_zero_test(rank, world_size, stage=2) diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py index 38a5cfbfd66e..8f4d89f17330 100644 --- a/tests/test_moe/test_moe_zero_optim.py +++ b/tests/test_moe/test_moe_zero_optim.py @@ -75,7 +75,7 @@ def run_zero_optim_test(local_rank, world_size, stage=1): def run_dist(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_MANAGER.setup(seed=42) + MOE_MANAGER.setup(seed=42, parallel="EP") run_zero_optim_test(rank, world_size, stage=1) run_zero_optim_test(rank, world_size, stage=2) From 7441a1fbc627f4f662076decc684398326e2b678 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com> Date: Thu, 12 Oct 2023 17:03:18 +0800 Subject: [PATCH 20/46] update mm (#4893) --- 
colossalai/moe/experts.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py index 81a7b21544e4..47dceeae8edb 100644 --- a/colossalai/moe/experts.py +++ b/colossalai/moe/experts.py @@ -116,19 +116,19 @@ def forward(self, x: torch.Tensor, param_slice: Tuple[slice] = (slice(None),)) - x = x.reshape(e, -1, h) if self.gated: + x_gate = [torch.mm(x[i], self.wi_gate[param_slice][i]) for i in range(e)] + x_up = [torch.mm(x[i], self.wi_up[param_slice][i]) for i in range(e)] if self.use_kernel and HAS_TRITON and self.act_name == "swiglu": - x = LlamaActCombine.apply( - torch.bmm(x, self.wi_gate[param_slice]), - torch.bmm(x, self.wi_up[param_slice]), - ) + x = [LlamaActCombine.apply(x_gate[i], x_up[i]) for i in range(e)] else: - x = self.act(torch.bmm(x, self.wi_gate[param_slice])) * torch.bmm(x, self.wi_up[param_slice]) + x = [self.act(x_gate[i]) * x_up[i] for i in range(e)] else: - x = torch.bmm(x, self.wi[param_slice]) - x = self.act(x) - x = self.drop(x) - x = torch.bmm(x, self.wo[param_slice]) + x = [torch.mm(x[i], self.wi[param_slice][i]) for i in range(e)] + x = [self.act(x[i]) for i in range(e)] + x = [self.drop(x[i]) for i in range(e)] + x = [torch.mm(x[i], self.wo[param_slice][i]) for i in range(e)] + x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0) x = x.reshape(inshape) x = x.transpose(0, 1).contiguous() x = MoeOutGradScaler.apply(x, self.ep_size) From 5844f347251577ce1bd64ed0c28f11f851324d00 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com> Date: Mon, 16 Oct 2023 18:13:15 +0800 Subject: [PATCH 21/46] [moe] support load balance (#4914) * add load balance * update test * update param exchange * pass test * update test * update test * update test * update test * fix ranks * update --- colossalai/moe/experts.py | 1 - colossalai/moe/layers.py | 47 ++- colossalai/moe/load_balance.py | 429 ++++++++++++++++++++++++ colossalai/moe/manager.py | 39 ++- tests/test_moe/test_moe_load_balance.py | 193 +++++++++++ 5 files changed, 696 insertions(+), 13 deletions(-) create mode 100644 colossalai/moe/load_balance.py create mode 100644 tests/test_moe/test_moe_load_balance.py diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py index 47dceeae8edb..076f160adb79 100644 --- a/colossalai/moe/experts.py +++ b/colossalai/moe/experts.py @@ -1,5 +1,4 @@ import math -from contextlib import nullcontext from typing import Callable, Optional, Tuple import torch diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index 036bd32ae7c0..3f82a0fa23fd 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -1,16 +1,18 @@ import math -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import torch +import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F from colossalai.moe._operation import AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter from colossalai.moe.experts import BaseMLPExperts, get_expert_class +from colossalai.moe.load_balance import LoadBalancer from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.routers import MoeRouter, get_router_cls from colossalai.moe.utils import get_noise_generator -from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size +from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size class SparseMLP(nn.Module): @@ -72,6 +74,7 @@ def __init__( # moe router noisy_func = 
get_noise_generator(noisy_policy, num_experts) router_cls = get_router_cls(top_k) + self.topk = top_k self.router: MoeRouter = router_cls( capacity_factor_train=capacity_factor_train, capacity_factor_eval=capacity_factor_eval, @@ -91,13 +94,30 @@ def __init__( if self.expert_parallel is not None: self.ep_group = get_ep_group(self.experts) self.ep_size = get_ep_size(self.experts) + self.dp_group = get_dp_group(self.experts) else: self.ep_group = None + self.dp_group = None self.num_local_experts = self.experts.num_local_experts # gate self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size)) + # load balance + self.enable_load_balance = MOE_MANAGER.load_balance + if self.enable_load_balance == True: + self.load_balancer = LoadBalancer( + experts=self.experts, + gate=self.gate_weight, + local_expert_num=self.num_local_experts, + expert_num=self.num_experts, + ep_group=self.ep_group, + dp_group=self.dp_group, + tolerance=MOE_MANAGER.tolerance, + beam_width=MOE_MANAGER.beam_width, + group_swap_factor=MOE_MANAGER.group_swap_factor, + ) + # init param self.reset_parameters() @@ -121,6 +141,14 @@ def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: fp32_weight = self.gate_weight.to(torch.float) gate_output = F.linear(fp32_input, fp32_weight) + # update expert load + if self.enable_load_balance == True: + with torch.no_grad(): + # TODO: optimize computation + expert_load = torch.topk(gate_output, k=self.topk, dim=-1)[1] + expert_load = torch.bincount(expert_load.view(-1)) + self.load_balancer.update_load(expert_load) + # the result from the router route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group) @@ -257,3 +285,18 @@ def get_chunk_slice(idx: int, gap: int) -> Tuple[slice]: # sync for async op torch.cuda.synchronize() return out + + +def apply_load_balance(model: nn.Module, optim: Any) -> None: + """ + apply load balance to every experts in the model + """ + + def _apply_recursive(module: nn.Module): + for _, sub_module in module.named_children(): + if isinstance(sub_module, SparseMLP): + if sub_module.enable_load_balance == True: + sub_module.load_balancer.balance_load(optim) + _apply_recursive(sub_module) + + _apply_recursive(model) diff --git a/colossalai/moe/load_balance.py b/colossalai/moe/load_balance.py new file mode 100644 index 000000000000..b2fb672329c2 --- /dev/null +++ b/colossalai/moe/load_balance.py @@ -0,0 +1,429 @@ +from copy import deepcopy +from typing import List, Optional, Tuple + +import torch +import torch.distributed as dist +from torch import Tensor, nn +from torch.distributed import ProcessGroup + +from colossalai.cluster import ProcessGroupMesh +from colossalai.moe.experts import BaseMLPExperts +from colossalai.moe.manager import MOE_MANAGER +from colossalai.zero.low_level import LowLevelZeroOptimizer + + +class LoadBalancer: + + def __init__( + self, + experts: BaseMLPExperts, + gate: nn.Parameter, + local_expert_num: int, + expert_num: int, + ep_group: ProcessGroup, + dp_group: ProcessGroup, + tolerance: Optional[float] = 0.1, + beam_width: Optional[int] = 8, + group_swap_factor: Optional[float] = 0.4, + ) -> None: + self.experts: BaseMLPExperts = experts + self.gate: nn.Parameter = gate + self.moe_ep_group: ProcessGroup = ep_group + self.moe_ep_ranks = MOE_MANAGER.parallel_info_dict[dist.get_world_size(self.moe_ep_group)].ep_group_ranks + self.moe_dp_group: ProcessGroup = dp_group + self.tolerance = tolerance + self.beam_width = beam_width + self.group_swap_factor = 
group_swap_factor + self.local_expert_num = local_expert_num + self.expert_num = expert_num + self.local_load = None + # TODO: use a global process group mesh + pp_size = 1 if MOE_MANAGER.pp_size is None else MOE_MANAGER.pp_size + global_dp_group = ProcessGroupMesh(pp_size, dist.get_world_size() // pp_size) + self.global_dp_group = global_dp_group.get_group_along_axis(1) + + def _clear_load(self) -> None: + self.local_load = None + + def _sync_load(self) -> Tensor: + new_load = self.local_load.clone().detach() + # all reduce load between ep group + dist.all_reduce(new_load, group=self.moe_ep_group) + # all reduce load between dp group + dist.all_reduce(new_load, group=self.moe_dp_group) + return new_load + + @staticmethod + def _get_diff_from_avg(data: List, group: int, avg: float) -> float: + return abs(sum(data[group]) / len(data[group]) - avg) + + @staticmethod + def _swap_data(data: List, group_i: int, index_i: int, group_j: int, index_j: int) -> None: + data[group_i][index_i], data[group_j][index_j] = ( + data[group_j][index_j], + data[group_i][index_i], + ) + + @staticmethod + def _normalize_data(data: List) -> List: + max_value = max(max(sublist) for sublist in data) + data = [[i / max_value for i in sublist] for sublist in data] + return data + + @staticmethod + def _get_swap_loss( + group_swap_factor: float, + swap_list: List, + group_i: int, + index_i: int, + group_j: int, + index_j: int, + ) -> float: + """ + Get swap loss. The swap loss is used to avoid the situation that + the same index is swapped twice and the same group is swapped for multiple times. + """ + swap_loss = 0 + for swap in swap_list: + for group_id, index_id in zip([group_i, group_j], [index_i, index_j]): + # the group has been swapped + if group_id in [swap[0], swap[2]]: + # the index has been swapped + # we want to avoid the situation that the same index is swapped twice + if index_id in [swap[1], swap[3]]: + swap_loss += 1e5 + # the index has not been swapped + # this is acceptable but as less as possible + else: + swap_loss += group_swap_factor + return swap_loss + + @staticmethod + def _check_convergence(data: List, avg: float, tolerance: float): + """ + Check whether the data is converged after swap. + """ + for sublist in data: + if abs(sum(sublist) / len(sublist) - avg) > tolerance * avg: + return False + return True + + def _beam_search( + self, + inputs: Tuple[List, float, List], + beam_width: int, + avg: float, + group_swap_factor: float, + ) -> List: + """ + Beam search for the best swap combination. + Specifically, we swap two elements from two groups and calculate the score. + The score is the difference between the origin group sum and the new group sum. + The larger the score, the better the swap combination. 
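+        For example, with avg = 6 and group loads [9, 9] and [3, 3]
+        (per-group diffs 3 and 3), swapping one 9 with one 3 gives
+        [3, 9] and [9, 3] (per-group diffs 0 and 0), so this swap
+        scores (3 + 3) - (0 + 0) = 6 before the swap loss is subtracted.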
+ + Args: + inputs (Tuple): (data, origin_score, swap_list) + beam_width (int): beam width for beam search + avg (float): average value of the data + group_swap_factor (float): group loss for group swap loss + + Returns: + List: results list + """ + data, origin_score, swap_list = inputs + results = [] + group_num = len(data) + group_size = len(data[0]) + origin_diff_list = [self._get_diff_from_avg(data, i, avg) for i in range(group_num)] + + for group_num_i in range(group_num): + for group_size_i in range(group_size): + for group_num_j in range(group_num_i + 1, group_num): + for group_size_j in range(group_size): + new_data = deepcopy(data) + # calculate origin group sum + origin_diff = (origin_diff_list[group_num_i] + origin_diff_list[group_num_j]) + # swap data + self._swap_data( + new_data, + group_num_i, + group_size_i, + group_num_j, + group_size_j, + ) + # calculate new group sum + new_diff = self._get_diff_from_avg(new_data, group_num_i, avg) + self._get_diff_from_avg( + new_data, group_num_j, avg) + # caculate score + new_score = origin_diff - new_diff + if new_score > 0: + new_score = origin_score + new_score + # get swap loss + swap_loss = self._get_swap_loss( + group_swap_factor, + swap_list, + group_num_i, + group_size_i, + group_num_j, + group_size_j, + ) + new_score = new_score - swap_loss + # update swap list + new_swap_list = swap_list + [(group_num_i, group_size_i, group_num_j, group_size_j)] + results.append((new_data, new_score, new_swap_list)) + # sort results + results.sort(key=lambda x: x[1], reverse=True) + # select top k results + results = results[:beam_width] + return results + + def _load_to_list(self, load: Tensor) -> List: + load_len = len(load) + assert load_len % self.local_expert_num == 0 + load_list = [] + tmp_list = [] + for i in range(len(load)): + tmp_list.append(float(load[i])) + if (i + 1) % self.local_expert_num == 0: + load_list.append(tmp_list) + tmp_list = [] + return load_list + + def _search_balance( + self, + data: List, + tolerance: Optional[float] = 0.1, + beam_width: Optional[int] = 8, + group_swap_factor: Optional[float] = 0.4, + return_swapped_data: Optional[bool] = False, + ) -> Tuple[List, List]: + """ + Search for the best swap combination to balance the data within the specified tolerance. + And return the balanced data and the swap list. The swap list is used to record the swap. + The swap list is a list of tuples. Each tuple is a swap operation. + + Args: + data (List): expert load list. + E.g. [[9.2, 8.3], [2.3, 10.0], [6.1, 7.2], [5.3, 3.2]] + This means there are 4 devices and each devices has 2 experts. + The value is the load of the expert. + tolerance (float): tolerance for balance. + beam_width (int): beam width for beam search. + group_swap_factor (float): group swap factor for group swap loss. + The bigger it is, the less times a group will be swapped. + return_swapped_data (bool): whether to return the swapped data. + + Returns: + Tuple: (balanced data, swap list). + The swap list is a list of tuples. Each tuple is a swap operation. + E.g. [(0, 0, 1, 0), (...), (...)]. The first tuple means + the first expert of the first device is swapped with the first expert + of the second device. 
+ """ + norm_data = self._normalize_data(data) + avg = sum(sum(sublist) / len(sublist) for sublist in norm_data) / len(norm_data) + results = [(norm_data, 0, [])] + stop_flag = False + + while stop_flag == False: + new_results = [] + best_score = results[0][1] + for i in range(len(results)): + new_results.extend(self._beam_search(results[i], beam_width, avg, group_swap_factor)) + if len(new_results) == 0: + stop_flag = True + break + new_results.sort(key=lambda x: x[1], reverse=True) + new_best_score = new_results[0][1] + if new_best_score == best_score: + stop_flag = True + break + new_results = new_results[:beam_width] + results = new_results + for i in results: + if self._check_convergence(results[0][0], avg, tolerance): + stop_flag = True + break + + swap_list = results[0][2] + if return_swapped_data: + out = deepcopy(data) + for swap in swap_list: + self._swap_data(out, *swap) + return out, swap_list + else: + return swap_list + + @staticmethod + def _swap_expert_single_tensor( + weight: nn.Parameter, + expert_idx: int, + comm_group: ProcessGroup, + send_first: bool, + comm_rank: int, + ): + # exchange weight + local_weight = weight.data[expert_idx] + new_weight = torch.empty_like(local_weight) + if send_first: + dist.send(local_weight, dst=comm_rank, group=comm_group) + dist.recv(new_weight, src=comm_rank, group=comm_group) + else: + dist.recv(new_weight, src=comm_rank, group=comm_group) + dist.send(local_weight, dst=comm_rank, group=comm_group) + weight.data[expert_idx] = new_weight + + def _swap_expert_param_and_optim( + self, + weight: nn.Parameter, + expert_idx: int, + comm_group: ProcessGroup, + send_first: bool, + comm_rank: int, + optim: LowLevelZeroOptimizer, + ): + # need to update master and working param if master param exists + # else just update working param + if weight in optim.optim.state: + master_weight_ptr = None + working_weight_ptr = weight + exp_avg_ptr = optim.optim.state[working_weight_ptr]["exp_avg"] + exp_avg_sq_ptr = optim.optim.state[working_weight_ptr]["exp_avg_sq"] + else: + master_weight_ptr = optim._param_store.working_to_master_param[id(weight)] + working_weight_ptr = weight + exp_avg_ptr = optim.optim.state[master_weight_ptr]["exp_avg"] + exp_avg_sq_ptr = optim.optim.state[master_weight_ptr]["exp_avg_sq"] + + # exchange weight + self._swap_expert_single_tensor( + working_weight_ptr, + expert_idx, + comm_group, + send_first, + comm_rank, + ) + if master_weight_ptr is not None: + # TODO: exchange master weight, skip for now + # master weight is shared by dp group + tmp = working_weight_ptr.view(-1).split( + working_weight_ptr.numel() // dist.get_world_size(self.moe_dp_group))[dist.get_rank(self.moe_dp_group)] + master_weight_ptr.data.copy_(tmp.clone().detach().to(master_weight_ptr.device).to(master_weight_ptr.dtype)) + # exchange optim + self._swap_expert_single_tensor(exp_avg_ptr, expert_idx, comm_group, send_first, comm_rank) + self._swap_expert_single_tensor(exp_avg_sq_ptr, expert_idx, comm_group, send_first, comm_rank) + + def _gather_global_dp_group(self, data: Tensor) -> Tensor: + data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size(self.global_dp_group))] + dist.all_gather(data_list, data, group=self.global_dp_group) + data_list = torch.cat(data_list, dim=0) + return data_list + + def _swap_moe_param(self, swap_list: List, optim: LowLevelZeroOptimizer) -> None: + """ + Swap moe param and optim. + We use different strategies to swap expert and gate. + For expert, we exchange the param and optim of the expert by p2p. 
+ For gate, we all gather the gate choose the part we want. + + Args: + swap_list (List) + optim (LowLevelZeroOptimizer) + """ + # get all experts weights + local_rank = dist.get_rank(self.moe_ep_group) + if self.experts.gated: + weight_list = [self.experts.wi_up, self.experts.wi_gate] + else: + weight_list = [self.experts.wi] + weight_list.append(self.experts.wo) + + # gate optim should be obtained first + gate_shape = self.gate.shape + # get master weight and optim + master_gate_weight = optim._param_store.working_to_master_param[id(self.gate)] + gate_exp_avg = optim.optim.state[master_gate_weight]["exp_avg"] + gate_exp_avg_sq = optim.optim.state[master_gate_weight]["exp_avg_sq"] + # gather + global_master_gate_weight = self._gather_global_dp_group(master_gate_weight).view(gate_shape) + global_gate_exp_avg = self._gather_global_dp_group(gate_exp_avg).view(gate_shape) + global_gate_exp_avg_sq = self._gather_global_dp_group(gate_exp_avg_sq).view(gate_shape) + assert (self.gate.shape == global_master_gate_weight.shape == global_gate_exp_avg.shape == + global_gate_exp_avg_sq.shape) + + for swap in swap_list: + source_group, source_idx, target_group, target_idx = swap + source_rank = self.moe_ep_ranks[source_group] + target_rank = self.moe_ep_ranks[target_group] + # exchange expert + if local_rank in [source_group, target_group]: + for weight in weight_list: + if local_rank == source_group: + self._swap_expert_param_and_optim( + weight, + source_idx, + self.moe_ep_group, + True, + target_rank, + optim, + ) + elif local_rank == target_group: + self._swap_expert_param_and_optim( + weight, + target_idx, + self.moe_ep_group, + False, + source_rank, + optim, + ) + # exchange gate + source_expert_pos = source_group * self.local_expert_num + source_idx + target_expert_pos = target_group * self.local_expert_num + target_idx + for gate in [ + self.gate, + global_master_gate_weight, + global_gate_exp_avg, + global_gate_exp_avg_sq, + ]: + origin_source = gate.data[source_expert_pos].clone().detach() + origin_target = gate.data[target_expert_pos].clone().detach() + gate.data[source_expert_pos], gate.data[target_expert_pos] = ( + origin_target, + origin_source, + ) + + # update gate + dp_group_rank = dist.get_rank(self.global_dp_group) + dp_group_size = dist.get_world_size(self.global_dp_group) + global_master_gate_weight = global_master_gate_weight.view(-1).split(global_master_gate_weight.numel() // + dp_group_size)[dp_group_rank] + master_gate_weight.data.copy_(global_master_gate_weight) + global_gate_exp_avg = global_gate_exp_avg.view(-1).split(global_gate_exp_avg.numel() // + dp_group_size)[dp_group_rank] + gate_exp_avg.data.copy_(global_gate_exp_avg) + global_gate_exp_avg_sq = global_gate_exp_avg_sq.view(-1).split(global_gate_exp_avg_sq.numel() // + dp_group_size)[dp_group_rank] + gate_exp_avg_sq.data.copy_(global_gate_exp_avg_sq) + + @torch.no_grad() + def update_load(self, load: Tensor) -> None: + if len(load) != self.expert_num: + padding_size = self.expert_num - len(load) + padding = torch.zeros(padding_size, dtype=load.dtype, device=load.device) + load = torch.cat((load, padding), dim=0) + if self.local_load is None: + self.local_load = load + else: + self.local_load += load + + @torch.no_grad() + def balance_load(self, optim: LowLevelZeroOptimizer) -> None: + # prepare load + load = self._sync_load() + load = self._load_to_list(load) + # search balance + swap_list = self._search_balance(load) + # swap expert and gate + self._swap_moe_param(swap_list, optim) + # clear load + self._clear_load() 
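
Taken in isolation, the search in _search_balance is easier to follow as a toy: per-device expert loads form a list of lists, a candidate swap exchanges one expert between two devices, and a swap is kept when it shrinks both devices' deviation from the global mean. The sketch below is a greedy simplification of that routine, not the library API: the name balance_by_swapping and the toy loads are invented here for illustration, and the beam of candidates and the swap-loss penalty of _beam_search are deliberately left out (for equal group sizes the global expert mean used here matches the mean-of-group-means in the real code).

    from copy import deepcopy

    def balance_by_swapping(load, tolerance=0.1, max_steps=100):
        """Greedy single-swap balancing over per-device expert loads."""
        data = deepcopy(load)
        flat = [x for group in data for x in group]
        avg = sum(flat) / len(flat)

        def diff(g):
            # deviation of one device's mean load from the global mean
            return abs(sum(data[g]) / len(data[g]) - avg)

        swaps = []
        for _ in range(max_steps):
            if all(diff(g) <= tolerance * avg for g in range(len(data))):
                break
            best, best_gain = None, 0.0
            for gi in range(len(data)):
                for ii in range(len(data[gi])):
                    for gj in range(gi + 1, len(data)):
                        for jj in range(len(data[gj])):
                            before = diff(gi) + diff(gj)
                            # trial swap, score it, undo
                            data[gi][ii], data[gj][jj] = data[gj][jj], data[gi][ii]
                            gain = before - (diff(gi) + diff(gj))
                            data[gi][ii], data[gj][jj] = data[gj][jj], data[gi][ii]
                            if gain > best_gain:
                                best, best_gain = (gi, ii, gj, jj), gain
            if best is None:
                break
            gi, ii, gj, jj = best
            data[gi][ii], data[gj][jj] = data[gj][jj], data[gi][ii]
            swaps.append(best)
        return data, swaps

    # same layout as the docstring example: 4 devices, 2 local experts each
    balanced, swaps = balance_by_swapping([[9.2, 8.3], [2.3, 10.0], [6.1, 7.2], [5.3, 3.2]])

Each tuple in swaps has the (group_i, index_i, group_j, index_j) shape that _swap_moe_param consumes, which is what turns the search result into the actual p2p exchange of expert weights and their exp_avg/exp_avg_sq optimizer states.
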
diff --git a/colossalai/moe/manager.py b/colossalai/moe/manager.py index 1e949bb9a6dd..e3659ef43fbd 100644 --- a/colossalai/moe/manager.py +++ b/colossalai/moe/manager.py @@ -27,6 +27,13 @@ def __init__(self): self.mode = None self.use_kernel_optim = False self.use_ep_inside = None + self.pp_size = None + + # load balance param + self.load_balance = None + self.tolerance = None + self.beam_width = None + self.group_swap_factor = None self.has_setup = False self._parallel_info_dict = dict() @@ -39,16 +46,22 @@ def parallel_info_dict(self): def is_initialized(self): return self.has_setup - def setup(self, - seed: int, - use_kernel_optim: bool = False, - parallel: str = None, - mode: str = "dynamic", - max_ep_size: int = 8, - fixed_dp_size: int = 0, - fixed_ep_size: int = 0, - fixed_pp_size: int = 0, - use_ep_inside: bool = True) -> None: + def setup( + self, + seed: int, + use_kernel_optim: bool = False, + parallel: str = None, + mode: str = "dynamic", + max_ep_size: int = 8, + fixed_dp_size: int = 0, + fixed_ep_size: int = 0, + fixed_pp_size: int = 0, + use_ep_inside: bool = True, + enable_load_balance: bool = False, + tolerance: float = 0.1, + beam_width: int = 8, + group_swap_factor: float = 0.4, + ) -> None: """ Setup MoE distributed context. @@ -91,6 +104,12 @@ def setup(self, # Users can close kernel optimization manually self.use_kernel_optim = use_kernel_optim + # update load balance + self.load_balance = enable_load_balance + self.tolerance = tolerance + self.beam_width = beam_width + self.group_swap_factor = group_swap_factor + self.has_setup = True def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoeParallelInfo]: diff --git a/tests/test_moe/test_moe_load_balance.py b/tests/test_moe/test_moe_load_balance.py new file mode 100644 index 000000000000..b4eea04bc85a --- /dev/null +++ b/tests/test_moe/test_moe_load_balance.py @@ -0,0 +1,193 @@ +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel +from colossalai.moe.layers import apply_load_balance +from colossalai.moe.manager import MOE_MANAGER +from colossalai.tensor.moe_tensor.api import is_moe_tensor +from colossalai.testing import rerun_if_address_is_in_use, spawn +from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel + + +def split_ddp_grad(grad, world_size): + with torch.no_grad(): + grad = grad.clone().detach().flatten() + padding_size = (world_size - grad.numel() % world_size) % world_size + if padding_size > 0: + grad = torch.nn.functional.pad(grad, [0, padding_size]) + splited_grad = grad.split(grad.numel() // world_size) + return splited_grad + + +def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False): + model.train() + with torch.cuda.amp.autocast(enabled=enable_autocast): + if criterion: + y = model(data) + loss = criterion(y, label) + else: + loss = model(data, label) + loss = loss.float() + + if isinstance(model, LowLevelZeroModel): + optimizer.backward(loss) + else: + loss.backward() + return y + + +def run_zero_optim_test(local_rank, world_size, stage=1): + criterion = torch.nn.CrossEntropyLoss() + + MOE_MANAGER.__init__() + MOE_MANAGER.setup( + seed=42, + parallel="EP", + enable_load_balance=True, + tolerance=0.1, + beam_width=8, + group_swap_factor=0.4, + ) + zero_model = MoeModel(checkpoint=True) + zero_optimizer = 
torch.optim.Adam(zero_model.parameters()) + plugin = LowLevelZeroPlugin(stage=stage, precision="bf16", verbose=True) + booster = Booster(plugin=plugin) + zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer) + + MOE_MANAGER.__init__() + MOE_MANAGER.setup(seed=42, parallel="EP") + torch_model = MoeModel(checkpoint=True) + for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): + torch_param.data.copy_(zero_param.data) + torch_optimizer = torch.optim.Adam(torch_model.parameters()) + torch_model = torch_model.cuda().bfloat16() + grad_handler = MoeGradientHandler(torch_model) + + # run to update expert load + data = torch.randn(16, 4).cuda().bfloat16() / 1000 / (local_rank + 1) + label = torch.randint(0, 4, (16,)).cuda() + + # run torch model twice + run_fwd_bwd(torch_model, data, label, criterion, None) + grad_handler.handle_gradient() + torch_optimizer.step() + torch_optimizer.zero_grad() + run_fwd_bwd(torch_model, data, label, criterion, None) + grad_handler.handle_gradient() + + # get optim and load status in zero model + run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + zero_optimizer.step() + zero_optimizer.zero_grad() + with torch.no_grad(): + origin_out = zero_model(data) + + # load balance + apply_load_balance(zero_model, zero_optimizer) + + # run again to test + zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + torch.allclose(origin_out, zero_out) + + # assert optim + torch_optimizer.step() + torch_out = run_fwd_bwd(torch_model, data, label, criterion, None) + zero_optimizer.step() + zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + assert torch.allclose(zero_out, torch_out), f"zero_out:{zero_out}\ntorch_out{torch_out}" + + +def run_hybrid_zero_optim_test(local_rank, world_size, stage=1): + criterion = torch.nn.CrossEntropyLoss() + data = torch.randn(16, 4).cuda() + label = torch.randint(0, 4, (16,)).cuda() + + MOE_MANAGER.__init__() + MOE_MANAGER.setup(seed=42, parallel=None) + torch_model = MoeModel(checkpoint=True) + torch_optimizer = torch.optim.Adam(torch_model.parameters()) + torch_model = torch_model.cuda() + + MOE_MANAGER.__init__() + MOE_MANAGER.setup( + seed=42, + max_ep_size=2, + use_ep_inside=False, + parallel="EP", + enable_load_balance=True, + tolerance=0.1, + beam_width=8, + group_swap_factor=0.4, + ) + zero_model = MoeModel(checkpoint=True) + extra_dp_group = MOE_MANAGER.parallel_info_dict[2].dp_group + ep_rank = dist.get_rank(MOE_MANAGER.parallel_info_dict[2].ep_group) + ep_size = MOE_MANAGER.parallel_info_dict[2].ep_size + for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): + if is_moe_tensor(zero_param): + num_expert = torch_param.data.shape[0] + zero_param.data.copy_(torch_param.data[ep_rank * (num_expert // ep_size):(ep_rank + 1) * + (num_expert // ep_size)].detach().clone()) + else: + zero_param.data.copy_(torch_param.data.detach().clone()) + zero_optimizer = torch.optim.Adam(zero_model.parameters()) + plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") + plugin.zero_optim_kwargs["extra_dp_process_group"] = extra_dp_group + booster = Booster(plugin=plugin) + zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer) + + # run torch for twice + run_fwd_bwd(torch_model, data, label, criterion, None) + torch_optimizer.step() + torch_optimizer.zero_grad() + run_fwd_bwd(torch_model, data, label, criterion, None) + torch_optimizer.step() + + # run zero + run_fwd_bwd(zero_model, 
data, label, criterion, zero_optimizer) + zero_optimizer.step() + zero_optimizer.zero_grad() + with torch.no_grad(): + origin_out = zero_model(data) + + # load balance + apply_load_balance(zero_model, zero_optimizer) + + # assert out + zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + torch.allclose(origin_out, zero_out) + + # assert optim + zero_optimizer.step() + zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + torch_out = run_fwd_bwd(torch_model, data, label, criterion, None) + assert torch.allclose(zero_out, torch_out, atol=8e-4), f"zero_out:{zero_out}\ntorch_out{torch_out}" + + +def run_dist(rank, world_size, port): + colossalai.launch( + config=dict(), + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) + run_zero_optim_test(rank, world_size, stage=1) + run_zero_optim_test(rank, world_size, stage=2) + run_hybrid_zero_optim_test(rank, world_size, stage=1) + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [4]) +@rerun_if_address_is_in_use() +def test_moe_load_balance(world_size): + spawn(run_dist, world_size) + + +if __name__ == "__main__": + test_moe_load_balance(world_size=4) From 5f20878ad6623d52527ab2bbc1051a32ce2b8ef4 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com> Date: Tue, 17 Oct 2023 10:44:01 +0800 Subject: [PATCH 22/46] update bench (#4923) --- colossalai/moe/_operation.py | 40 +++++++++----- .../openmoe/benchmark/benchmark_cai.py | 54 ++++++++++++++++--- 2 files changed, 75 insertions(+), 19 deletions(-) diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index 740d17b5698f..fb3885707014 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -3,6 +3,7 @@ import torch import torch.distributed as dist from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd from torch.distributed import ProcessGroup from colossalai.moe.manager import MOE_MANAGER @@ -130,31 +131,42 @@ def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]: class MoeDispatch(torch.autograd.Function): @staticmethod + @custom_fwd def forward(ctx, tokens, mask, dest_idx, ec): s = tokens.size(0) h = tokens.size(1) + dtype = tokens.dtype if MOE_KERNEL is None: load_moe() - + if tokens.dtype != torch.float32: + tokens = tokens.to(torch.float32) expert_input = MOE_KERNEL.dispatch_forward(s, ec, h, tokens, mask, dest_idx) - + if expert_input.dtype != dtype: + expert_input = expert_input.to(dtype) ctx.save_for_backward(mask, dest_idx) ctx.s = s ctx.h = h ctx.ec = ec + ctx.dtype = dtype return expert_input @staticmethod + @custom_bwd def backward(ctx, output_grad): mask, dest_idx = ctx.saved_tensors + if output_grad.dtype != torch.float32: + output_grad = output_grad.to(torch.float32) d_tokens = MOE_KERNEL.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx) + if d_tokens.dtype != ctx.dtype: + d_tokens = d_tokens.to(ctx.dtype) return d_tokens, None, None, None class MoeCombine(torch.autograd.Function): @staticmethod + @custom_fwd def forward(ctx, expert_tokens, logits, mask, dest_idx, ec): assert logits.dtype == torch.float32 @@ -162,32 +174,36 @@ def forward(ctx, expert_tokens, logits, mask, dest_idx, ec): e = logits.size(1) c = ec // e h = expert_tokens.size(-1) + dtype = expert_tokens.dtype - fp16_flag = expert_tokens.dtype == torch.float16 - cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens + if expert_tokens.dtype != torch.float32: + expert_tokens = 
expert_tokens.to(torch.float32) if MOE_KERNEL is None: load_moe() - ctokens = MOE_KERNEL.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx) - output = ctokens.to(torch.float16) if fp16_flag else ctokens + output = MOE_KERNEL.combine_forward(s, e, c, h, expert_tokens, logits, mask, dest_idx) + if output.dtype != dtype: + output = output.to(dtype) ctx.save_for_backward(expert_tokens, logits, mask, dest_idx) ctx.s = s ctx.e = e ctx.c = c ctx.h = h - ctx.fp16_flag = fp16_flag + ctx.dtype = dtype return output @staticmethod + @custom_bwd def backward(ctx, tokens_grad): expert_tokens, logits, mask, dest_idx = ctx.saved_tensors + if tokens_grad.dtype != torch.float32: + tokens_grad = tokens_grad.to(torch.float32) - cb_grad = (tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 else tokens_grad) - cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens - d_expert, d_logits = MOE_KERNEL.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, mask, - dest_idx) - d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert + d_expert, d_logits = MOE_KERNEL.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits, + mask, dest_idx) + if d_expert.dtype != ctx.dtype: + d_expert = d_expert.to(ctx.dtype) return d_expert, d_logits, None, None, None diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py index 830ff9df0ec6..e1acba5c88b0 100644 --- a/examples/language/openmoe/benchmark/benchmark_cai.py +++ b/examples/language/openmoe/benchmark/benchmark_cai.py @@ -1,3 +1,4 @@ +import json import os import torch @@ -7,7 +8,7 @@ from model.openmoe_policy import OpenMoeForCausalLMPolicy from torch.utils.data import Dataset from tqdm import tqdm -from transformers import Adafactor +from transformers import T5Tokenizer from transformers.models.llama import LlamaConfig from utils import PerformanceEvaluator, get_model_numel @@ -17,6 +18,7 @@ from colossalai.booster.plugin import LowLevelZeroPlugin from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.cluster import DistCoordinator +from colossalai.moe.layers import apply_load_balance from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import skip_init from colossalai.utils import get_current_device @@ -41,11 +43,36 @@ def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster): class RandomDataset(Dataset): - def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 256384): + def __init__(self, + num_samples: int = 1000, + max_length: int = 2048, + vocab_size: int = 256384, + tokenizer: T5Tokenizer = None): self.num_samples = num_samples self.max_length = max_length - self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) - self.attention_mask = torch.ones_like(self.input_ids) + if os.path.exists("./mock_data.json"): + self.input_ids = [] + self.attention_mask = [] + with open("./mock_data.json", 'r') as f: + data = json.load(f) + for v in data.values(): + d = v["text"] + encode = tokenizer("" + d, + return_tensors="pt", + add_special_tokens=False, + max_length=max_length, + truncation=True, + padding="max_length") + self.input_ids.append(encode["input_ids"]) + self.attention_mask.append(encode["attention_mask"]) + self.input_ids = torch.cat(self.input_ids, dim=0).to(get_current_device()) + self.attention_mask = torch.cat(self.attention_mask, 
dim=0).to(get_current_device()) + repeat_times = num_samples // self.input_ids.shape[0] + 1 + self.input_ids = self.input_ids.repeat(repeat_times, 1)[:num_samples] + self.attention_mask = self.attention_mask.repeat(repeat_times, 1)[:num_samples] + else: + self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) + self.attention_mask = torch.ones_like(self.input_ids) def __len__(self): return self.num_samples @@ -103,6 +130,8 @@ def parse_args(): # bench parser.add_argument("--warmup", type=int, default=20) parser.add_argument("--active", type=int, default=20) + # load balance + parser.add_argument("--load_balance", action="store_true") args = parser.parse_args() return args @@ -116,8 +145,14 @@ def main(): # Set plugin booster_kwargs = {} - hybrid_dict = {"tp_size": 1, "custom_policy": OpenMoeForCausalLMPolicy(), "enable_fused_normalization": args.use_kernel, "enable_jit_fused": args.use_kernel} - mgr_dict = {"seed": 42, "use_kernel_optim": args.use_kernel} + hybrid_dict = { + "tp_size": 1, + "custom_policy": OpenMoeForCausalLMPolicy(), + "enable_fused_normalization": args.use_kernel, + "enable_jit_fused": args.use_kernel, + "precision": "bf16" + } + mgr_dict = {"seed": 42, "use_kernel_optim": args.use_kernel, "enable_load_balance": args.load_balance} if args.plugin == "zero": dp_size = dist.get_world_size() plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2) @@ -203,14 +238,16 @@ def main(): model.gradient_checkpointing_enable() # Prepare tokenizer and dataloader + tokenizer = T5Tokenizer.from_pretrained("google/umt5-small") dataset = RandomDataset( num_samples=args.batch_size * (args.warmup + args.active + 1) * dp_size, max_length=args.seq_length, + tokenizer=tokenizer, ) dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size) # Set optimizer - optimizer = Adafactor(model.parameters(), weight_decay=0.01) + optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.01, lr=1e-5) model_numel = get_model_numel(model) performance_evaluator = PerformanceEvaluator( @@ -264,6 +301,9 @@ def main(): optimizer.step() optimizer.zero_grad() performance_evaluator.on_step_end(exmaple_data["input_ids"]) + if (step == args.warmup // 2) and args.load_balance: + apply_load_balance(model, optimizer) + coordinator.print_on_master(f"Apply load balance") performance_evaluator.on_fit_end() From b0e277b4e05daba70895a7a127d6351cad0b34fb Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Wed, 18 Oct 2023 10:03:50 +0800 Subject: [PATCH 23/46] [moe]: add overlap ep, and fix overlap tp (#4925) * test: add more ep/tp test case * to: add TPOverlap fn * fix: fix tp overlap * fix: remove useless variables * feat: add async all to all * feat: add overlap ep * fix: fix import error * fix: fix ep/tp tests * perf: optimize overlap * fix: add world_size check --- colossalai/moe/_operation.py | 79 +++++++----- colossalai/moe/layers.py | 205 ++++++++++++++++++++----------- tests/test_moe/test_moe_ep_tp.py | 70 ++++++++--- tests/test_moe/test_moe_local.py | 65 ---------- 4 files changed, 231 insertions(+), 188 deletions(-) delete mode 100644 tests/test_moe/test_moe_local.py diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index fb3885707014..9f5f400922ea 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -9,8 +9,6 @@ from colossalai.moe.manager import MOE_MANAGER MOE_KERNEL = None -WORLD_HANDLE_ALLGATHER = None -WORLD_HANDLE_REDUCESCATTER = None def load_moe(): @@ -27,14 +25,20 @@ def forward( inputs: 
Tensor, group: Optional[ProcessGroup] = None, overlap: bool = False, - ) -> Tensor: + ) -> Tuple[Tensor, Any]: + """ + Returns: + outputs: Tensor + handle: Optional[Work], if overlap is True + """ + assert ctx is not None or not overlap + if ctx is not None: ctx.comm_grp = group - ctx.overlap = overlap comm_size = dist.get_world_size(group) if comm_size == 1: - return inputs.unsqueeze(0) + return inputs.unsqueeze(0), None buffer_shape = (comm_size,) + inputs.shape outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device) @@ -44,19 +48,12 @@ def forward( return outputs, None else: handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True) - if ctx is None and overlap: - global WORLD_HANDLE_ALLGATHER - WORLD_HANDLE_ALLGATHER = handle return outputs, handle @staticmethod - def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None]: - global WORLD_HANDLE_REDUCESCATTER - if WORLD_HANDLE_REDUCESCATTER is not None: - WORLD_HANDLE_REDUCESCATTER.wait() - WORLD_HANDLE_REDUCESCATTER = None + def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]: return ( - ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, ctx.overlap)[0], + ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)[0], None, None, ) @@ -69,14 +66,20 @@ def forward( inputs: Tensor, group: Optional[ProcessGroup] = None, overlap: bool = False, - ) -> Tensor: + ) -> Tuple[Tensor, Any]: + """ + Returns: + outputs: Tensor + handle: Optional[Work], if overlap is True + """ + assert ctx is not None or not overlap + if ctx is not None: ctx.comm_grp = group - ctx.overlap = overlap comm_size = dist.get_world_size(group) if comm_size == 1: - return inputs.squeeze(0) + return inputs.squeeze(0), None if not inputs.is_contiguous(): inputs = inputs.contiguous() @@ -89,19 +92,13 @@ def forward( return outputs, None else: handle = dist.reduce_scatter(outputs, buffer_list, group=group, async_op=True) - if ctx is None and overlap: - global WORLD_HANDLE_REDUCESCATTER - WORLD_HANDLE_REDUCESCATTER = handle return outputs, handle @staticmethod - def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None]: - global WORLD_HANDLE_ALLGATHER - if WORLD_HANDLE_ALLGATHER is not None: - WORLD_HANDLE_ALLGATHER.wait() - WORLD_HANDLE_ALLGATHER = None + def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]: + # TODO: support async backward return ( - AllGather.forward(None, grad_outputs[0], ctx.comm_grp, ctx.overlap)[0], + AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)[0], None, None, ) @@ -113,20 +110,38 @@ class AllToAll(torch.autograd.Function): """ @staticmethod - def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: + def forward( + ctx: Any, + inputs: Tensor, + group: Optional[ProcessGroup] = None, + overlap: bool = False, + ) -> Tuple[Tensor, Any]: + """ + Returns: + outputs: Tensor + handle: Optional[Work], if overlap is True + """ if ctx is not None: ctx.comm_grp = group if not inputs.is_contiguous(): inputs = inputs.contiguous() if dist.get_world_size(group) == 1: - return inputs + return inputs, None output = torch.empty_like(inputs) - dist.all_to_all_single(output, inputs, group=group) - return output + if not overlap: + dist.all_to_all_single(output, inputs, group=group) + return output, None + else: + handle = dist.all_to_all_single(output, inputs, group=group, async_op=True) + return output, handle @staticmethod - def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]: - return AllToAll.forward(None, *grad_outputs, 
ctx.comm_grp), None + def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]: + return ( + AllToAll.forward(None, grad_outputs[0], ctx.comm_grp)[0], + None, + None, + ) class MoeDispatch(torch.autograd.Function): diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index 3f82a0fa23fd..9846cd432b53 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -1,3 +1,4 @@ +import dataclasses import math from typing import Any, Optional, Tuple @@ -188,7 +189,10 @@ def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor: expert_out = self.experts(expert_in) return expert_out - def _ep_process(self, dispatch_data: torch.Tensor) -> torch.Tensor: + def _ep_process(self, + dispatch_data: torch.Tensor, + overlap: bool = True + ) -> torch.Tensor: """ Expert Parallel @@ -198,27 +202,82 @@ def _ep_process(self, dispatch_data: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: (num_experts, capacity, hidden_size) """ - expert_input = AllToAll.apply(dispatch_data, self.ep_group) - input_shape = expert_input.shape - expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size) - expert_output = self.experts(expert_input) - expert_output = expert_output.reshape(input_shape) - expert_output = AllToAll.apply(expert_output, self.ep_group) - return expert_output - - def _tp_process(self, dispatch_data: torch.Tensor, use_overlap: bool = False) -> torch.Tensor: - """ - TP with overlap. + if not overlap or dist.get_world_size(self.ep_group) == 1: + expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0] + expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size) + expert_output = self.experts(expert_input) + expert_output = AllToAll.apply(expert_output, self.ep_group, False)[0] + return expert_output - origin: + else: + @dataclasses.dataclass + class Capsule(): + data: torch.Tensor + handle: Any = None + + NUM_CHUNK = 2 + NUM_STAGES = 4 + + assert dispatch_data.shape[1] % NUM_CHUNK == 0, \ + "arbitrary chunk num is not supported yet" + chunk_size = dispatch_data.shape[1] // NUM_CHUNK + input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size) + dispatch_data = dispatch_data.reshape(*input_shape) + chunk_data = torch.split(dispatch_data, chunk_size, dim=2) + output = torch.empty_like(dispatch_data) + + offset = 0 + _expert_in, expert_in, _expert_out, expert_out = None, None, None, None + + for i in range(NUM_CHUNK + NUM_STAGES - 1): + if expert_out is not None: + expert_out.handle.wait() + output[:, :, offset:offset + chunk_size, :] = expert_out.data + offset += chunk_size + expert_out = None + + # all2all last output + if _expert_out is not None: + expert_out = Capsule( + *AllToAll.apply(_expert_out.data, self.ep_group, True), + ) + _expert_out = None + + # all2all next input + if 0 <= i < NUM_CHUNK: + _expert_in = Capsule( + *AllToAll.apply(chunk_data[i].contiguous(), self.ep_group, True) + ) + + # compute + if expert_in is not None: + expert_in.handle.wait() + _expert_out = Capsule( + data=self.experts(expert_in.data), + handle=None + ) + expert_in = None + + if _expert_in is not None: + expert_in = _expert_in + _expert_in = None + + return output + + def _tp_process(self, + dispatch_data: torch.Tensor, + overlap: bool = True + ) -> torch.Tensor: + """ + without overlap: | C | | A | | R | - overlap: + with overlap: | C1 || C2 || C3 || C4 | | A1 || A2 | | R1 | A3 || R2 | A4 || R3 | | R4 | - C is computation, A is all gather, R is reduce scatter. 
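For reference, a minimal standalone sketch of the A/C/R pipeline diagrammed above, assuming an already-initialized process group; `experts` stands in for any shape-preserving callable, and `tp_overlap_sketch`/`num_chunks` are illustrative names rather than the patch's API. Unlike the bounded pipeline in the patch, it issues every all-gather up front for clarity:

    import torch
    import torch.distributed as dist

    def tp_overlap_sketch(dispatch_data: torch.Tensor, experts, group=None, num_chunks: int = 4):
        # C(i) for chunk i runs while all-gathers A(j > i) and
        # reduce-scatters R(j < i) are still in flight.
        world = dist.get_world_size(group)
        chunks = torch.chunk(dispatch_data, num_chunks, dim=0)
        gathers = []
        for c in chunks:
            buf = torch.empty((world,) + c.shape, dtype=c.dtype, device=c.device)
            handle = dist.all_gather([buf[r] for r in range(world)], c.contiguous(),
                                     group=group, async_op=True)
            gathers.append((buf, handle))
        outs, scatters = [], []
        for (buf, handle), c in zip(gathers, chunks):
            handle.wait()                      # A(i) done
            y = experts(buf)                   # C(i), overlapping pending comms
            out = torch.empty_like(c)
            scatters.append(dist.reduce_scatter(out, [y[r].contiguous() for r in range(world)],
                                                group=group, async_op=True))
            outs.append(out)
        for handle in scatters:
            handle.wait()                      # R(i) done
        return torch.cat(outs, dim=0)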
+ where C is computation, A is all gather, R is reduce scatter. Args: dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size) @@ -226,65 +285,67 @@ def _tp_process(self, dispatch_data: torch.Tensor, use_overlap: bool = False) -> Returns: torch.Tensor: (num_experts, capacity, hidden_size) """ - if use_overlap == False: - expert_in, _ = AllGather.apply(dispatch_data, self.ep_group) + if not overlap or dist.get_world_size(self.ep_group) == 1: + expert_in = AllGather.apply(dispatch_data, self.ep_group, False)[0] expert_out = self.experts(expert_in) - expert_out, _ = ReduceScatter.apply(expert_out, self.ep_group) + expert_out = ReduceScatter.apply(expert_out, self.ep_group, False)[0] return expert_out - - # TODO: there is accuracy problem in overlap - chunk_num = 4 - chunk_size = dispatch_data.shape[0] // chunk_num - out = torch.empty_like(dispatch_data) - in_data = None - in_handle = None - out_data = None - out_handle = None - - # backward compatibility for async op - torch.cuda.synchronize() - - def get_chunk_slice(idx: int, gap: int) -> Tuple[slice]: - return (slice(idx * gap, (idx + 1) * gap),) - - for i in range(chunk_num): - cur_chunk_slice = get_chunk_slice(i, chunk_size) - - # if first, all gather - if i == 0: - d = dispatch_data[cur_chunk_slice].contiguous() - expert_in, _ = AllGather.apply(d, self.ep_group) - else: - expert_in = in_data - - # async communication while compute - if i != 0: - # reduce scatter last out - out_data, out_handle = ReduceScatter.apply(out_data, self.ep_group, True) - if i != chunk_num - 1: - # all gather next in - next_d = dispatch_data[get_chunk_slice(i + 1, chunk_size)].contiguous() - in_data, in_handle = AllGather.apply(next_d, self.ep_group, True) - - # compute - expert_out = self.experts(expert_in, cur_chunk_slice) - - # sync handle - if i != 0: - out_handle.wait() - out[get_chunk_slice(i - 1, chunk_size)] = out_data - if i != chunk_num - 1: - in_handle.wait() - out_data = expert_out - - # store out for last loop - if i == chunk_num - 1: - out_data, _ = ReduceScatter.apply(out_data, self.ep_group) - out[cur_chunk_slice] = out_data - - # sync for async op - torch.cuda.synchronize() - return out + else: + @dataclasses.dataclass + class Capsule(): + data: torch.Tensor + handle: Any + indices: Tuple + + NUM_CHUNK = 2 + NUM_STAGES = 4 + + assert dispatch_data.shape[0] % NUM_CHUNK == 0, \ + "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts" + chunk_size = dispatch_data.shape[0] // NUM_CHUNK + chunk_data = torch.split(dispatch_data, chunk_size, dim=0) + output = torch.empty_like(dispatch_data) + + def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]: + return (slice(idx * chunk_size, (idx + 1) * chunk_size), ) + + _expert_in, expert_in, _expert_out, expert_out = None, None, None, None + + for i in range(NUM_CHUNK + NUM_STAGES - 1): + if expert_out is not None: + expert_out.handle.wait() + output[expert_out.indices] = expert_out.data + expert_out = None + + # reduce scatter last output + if _expert_out is not None: + expert_out = Capsule( + *ReduceScatter.apply(_expert_out.data, self.ep_group, True), + indices=_expert_out.indices + ) + _expert_out = None + + # all gather next input + if 0 <= i < NUM_CHUNK: + _expert_in = Capsule( + *AllGather.apply(chunk_data[i].contiguous(), self.ep_group, True), + indices=get_chunk_slice(i, chunk_size) + ) + + # compute + if expert_in is not None: + expert_in.handle.wait() + _expert_out = Capsule( + self.experts(expert_in.data, expert_in.indices), + handle=None, 
indices=expert_in.indices + ) + expert_in = None + + if _expert_in is not None: + expert_in = _expert_in + _expert_in = None + + return output def apply_load_balance(model: nn.Module, optim: Any) -> None: diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py index 2bbf739ebbd4..51fd135483b6 100644 --- a/tests/test_moe/test_moe_ep_tp.py +++ b/tests/test_moe/test_moe_ep_tp.py @@ -8,57 +8,89 @@ from colossalai.moe.utils import sync_moe_model_param from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device -from tests.test_moe.moe_utils import MoeGradientHandler, sync_tp_from_ep +from tests.test_moe.moe_utils import MoeGradientHandler, sync_local_from_ep, sync_tp_from_ep -BATCH_SIZE = 4 -DIM = 16 +def run_test(rank: int, + world_size: int, + port: int, + num_experts: int, + batch_size: int, + dim: int, + seed: int): + assert batch_size % world_size == 0 -def run_test(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_MANAGER.setup(42, parallel="EP") # MOE initialization - ep_model = SparseMLP(num_experts=4, hidden_size=DIM, intermediate_size=DIM) + + MOE_MANAGER.__init__() + MOE_MANAGER.setup(seed, parallel=None) + local_model = SparseMLP(num_experts=num_experts, + hidden_size=dim, + intermediate_size=dim * 2) MOE_MANAGER.__init__() - MOE_MANAGER.setup(42, parallel="TP") - tp_model = SparseMLP(num_experts=4, hidden_size=DIM, intermediate_size=DIM) + MOE_MANAGER.setup(seed, parallel="EP") + ep_model = SparseMLP(num_experts=num_experts, + hidden_size=dim, + intermediate_size=dim * 2) + MOE_MANAGER.__init__() + MOE_MANAGER.setup(seed, parallel="TP") + tp_model = SparseMLP(num_experts=num_experts, + hidden_size=dim, + intermediate_size=dim * 2) ep_model = ep_model.to(get_current_device()) tp_model = tp_model.to(get_current_device()) + local_model = local_model.to(get_current_device()) # sync ep param sync_moe_model_param(ep_model) dist_dict = MOE_MANAGER.parallel_info_dict - assert_equal_in_group(ep_model.experts.wi.data, dist_dict[2].dp_group) - assert_equal_in_group(ep_model.experts.wo.data, dist_dict[2].dp_group) + assert_equal_in_group(ep_model.experts.wi.data, dist_dict[world_size].dp_group) + assert_equal_in_group(ep_model.experts.wo.data, dist_dict[world_size].dp_group) grad_handler = MoeGradientHandler(ep_model) # sync tp param sync_tp_from_ep(tp_model, ep_model) + # sync local param + sync_local_from_ep(local_model, ep_model) rank = dist.get_rank() - torch.cuda.manual_seed(78) - tp_data = torch.randn(BATCH_SIZE, DIM, device=get_current_device()) - ep_data = tp_data.detach()[2 * rank:2 * (rank + 1)] + torch.cuda.manual_seed(seed) + tp_data = torch.randn(batch_size, dim, device=get_current_device()) + micro_batch_size = batch_size // world_size + ep_data = tp_data.detach()[micro_batch_size * rank:micro_batch_size * (rank + 1)] + out_local = local_model(tp_data) + MOE_MANAGER.reset_loss() out_tp = tp_model(tp_data) MOE_MANAGER.reset_loss() out_ep = ep_model(ep_data) MOE_MANAGER.reset_loss() - assert torch.allclose(out_ep, out_tp[2 * rank:2 * (rank + 1)]) + assert torch.allclose(out_ep, out_tp[micro_batch_size * rank:micro_batch_size * (rank + 1)]) + assert torch.allclose(out_ep, out_local[micro_batch_size * rank:micro_batch_size * (rank + 1)]) + out_local.mean().backward() out_tp.mean().backward() out_ep.mean().backward() grad_handler.handle_gradient() - 
assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[2].dp_group) - assert_equal_in_group(ep_model.experts.wo.grad, dist_dict[2].dp_group) + assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[world_size].dp_group) + assert_equal_in_group(ep_model.experts.wo.grad, dist_dict[world_size].dp_group) + sync_local_from_ep(local_model, ep_model, assert_grad_flag=True) sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True) @pytest.mark.dist +@pytest.mark.parametrize("num_experts", [4, 8]) +@pytest.mark.parametrize("batch_size", [4, 8]) +@pytest.mark.parametrize("dim", [16, 256]) +@pytest.mark.parametrize("seed", [42, 78]) @rerun_if_address_is_in_use() -def test_moe_ep_tp(): - spawn(run_test, 2) +def test_moe_ep_tp(num_experts: int, + batch_size: int, + dim: int, + seed: int): + spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, seed=seed) if __name__ == '__main__': - test_moe_ep_tp() + test_moe_ep_tp(num_experts=8, batch_size=8, dim=256, seed=42) diff --git a/tests/test_moe/test_moe_local.py b/tests/test_moe/test_moe_local.py deleted file mode 100644 index 1211a0d2d7f1..000000000000 --- a/tests/test_moe/test_moe_local.py +++ /dev/null @@ -1,65 +0,0 @@ -import pytest -import torch -import torch.distributed as dist - -import colossalai -from colossalai.moe import SparseMLP -from colossalai.moe.manager import MOE_MANAGER -from colossalai.moe.utils import sync_moe_model_param -from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device -from tests.test_moe.moe_utils import MoeGradientHandler, sync_local_from_ep - -BATCH_SIZE = 4 -DIM = 16 - - -def run_test(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_MANAGER.setup(42, parallel=None) - local_model = SparseMLP(num_experts=4, hidden_size=DIM, intermediate_size=DIM) - MOE_MANAGER.__init__() - MOE_MANAGER.setup(42, parallel="EP") # MOE initialization - ep_model = SparseMLP(num_experts=4, hidden_size=DIM, intermediate_size=DIM) - ep_model = ep_model.to(get_current_device()) - local_model = local_model.to(get_current_device()) - - # sync ep param - sync_moe_model_param(ep_model) - dist_dict = MOE_MANAGER.parallel_info_dict - assert_equal_in_group(ep_model.experts.wi.data, dist_dict[2].dp_group) - assert_equal_in_group(ep_model.experts.wo.data, dist_dict[2].dp_group) - grad_handler = MoeGradientHandler(ep_model) - # sync tp param - sync_local_from_ep(local_model, ep_model) - - rank = dist.get_rank() - torch.cuda.manual_seed(78) - tp_data = torch.randn(BATCH_SIZE, DIM, device=get_current_device()) - ep_data = tp_data.detach()[2 * rank:2 * (rank + 1)] - - out_tp = local_model(tp_data) - MOE_MANAGER.reset_loss() - out_ep = ep_model(ep_data) - MOE_MANAGER.reset_loss() - assert torch.allclose(out_ep, out_tp[2 * rank:2 * (rank + 1)]) - - out_tp.mean().backward() - out_ep.mean().backward() - grad_handler.handle_gradient() - - assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[2].dp_group) - assert_equal_in_group(ep_model.experts.wo.grad, dist_dict[2].dp_group) - - sync_local_from_ep(local_model, ep_model, assert_grad_flag=True) - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [2]) -@rerun_if_address_is_in_use() -def test_moe_local(world_size): - spawn(run_test, world_size) - - -if __name__ == '__main__': - test_moe_local() From 4a7bf291198a61f2ec3a2f946d0ef72bb1ff8c2a Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao 
<43881818+oahzxl@users.noreply.github.com> Date: Fri, 20 Oct 2023 16:14:54 +0800 Subject: [PATCH 24/46] [moe] polish code (#4952) * doc * update script * update experts * update optim in fsdp * update kernel in sparse * empty cache * update script * update bench * update script * remove epzero2 * fix * update print * update test script * update script * update manager * update host * update script --- colossalai/moe/__init__.py | 15 +- colossalai/moe/_operation.py | 4 +- colossalai/moe/experts.py | 118 ++++--------- colossalai/moe/layers.py | 146 ++++++++-------- colossalai/moe/load_balance.py | 11 +- colossalai/moe/manager.py | 86 ++++------ colossalai/moe/routers.py | 83 ++++----- colossalai/moe/utils.py | 7 +- colossalai/zero/low_level/low_level_optim.py | 162 ++++++++++++------ .../openmoe/benchmark/benchmark_cai.py | 72 ++++---- .../openmoe/benchmark/benchmark_cai.sh | 39 ++--- .../openmoe/benchmark/benchmark_fsdp.py | 3 +- .../openmoe/benchmark/benchmark_fsdp.sh | 4 +- .../language/openmoe/benchmark/hostfile.txt | 2 + examples/language/openmoe/infer.py | 49 +++++- .../openmoe/model/modeling_openmoe.py | 32 ++-- examples/language/openmoe/train.py | 31 +++- tests/test_moe/moe_utils.py | 7 +- tests/test_moe/test_grad_handler.py | 21 ++- tests/test_moe/test_kernel.py | 8 +- tests/test_moe/test_moe_checkpoint.py | 14 +- tests/test_moe/test_moe_ep_tp.py | 31 +--- tests/test_moe/test_moe_group.py | 57 +++--- tests/test_moe/test_moe_load_balance.py | 12 +- 24 files changed, 530 insertions(+), 484 deletions(-) create mode 100644 examples/language/openmoe/benchmark/hostfile.txt diff --git a/colossalai/moe/__init__.py b/colossalai/moe/__init__.py index 1614987538c1..f32e89dfad3f 100644 --- a/colossalai/moe/__init__.py +++ b/colossalai/moe/__init__.py @@ -1,12 +1,17 @@ from .checkpoint import MoeCheckpintIO -from .experts import EPMLPExperts, TPMLPExperts, build_ffn_experts +from .experts import MLPExperts from .layers import SparseMLP from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter from .utils import NormalNoiseGenerator, UniformNoiseGenerator __all__ = [ - 'EPMLPExperts', 'TPMLPExperts', 'build_ffn_experts', - 'MoeRouter', 'Top1Router', 'Top2Router', 'TopKRouter', - 'NormalNoiseGenerator', 'UniformNoiseGenerator', - 'SparseMLP', 'MoeCheckpintIO' + "MLPExperts", + "MoeRouter", + "Top1Router", + "Top2Router", + "TopKRouter", + "NormalNoiseGenerator", + "UniformNoiseGenerator", + "SparseMLP", + "MoeCheckpintIO", ] diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index 9f5f400922ea..542c6372790f 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -223,10 +223,10 @@ def backward(ctx, tokens_grad): return d_expert, d_logits, None, None, None -def moe_cumsum(inputs: Tensor): +def moe_cumsum(inputs: Tensor, use_kernel: bool = False): dim0 = inputs.size(0) flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0) - if flag and MOE_MANAGER.use_kernel_optim: + if flag and use_kernel: if MOE_KERNEL is None: load_moe() return MOE_KERNEL.cumsum_sub_one(inputs) diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py index 076f160adb79..3471b2876e9b 100644 --- a/colossalai/moe/experts.py +++ b/colossalai/moe/experts.py @@ -15,18 +15,19 @@ from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine -class BaseMLPExperts(nn.Module): +class MLPExperts(nn.Module): """ SparseMLP is a multi-layer perceptron with sparse expert parallel layers. 
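For orientation, a hypothetical instantiation under the signature documented below (all values illustrative; `expert_parallel` is one of None, "EP", "TP"):

    experts = MLPExperts(
        num_experts=8,
        hidden_size=2048,
        intermediate_size=8192,
        expert_parallel="EP",
        activation=None,
        drop_rate=0.0,
        gated=False,
        use_kernel=False,
    )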
Args: num_experts (int): The number of experts - forward: hidden_size --> intermediate_size --> hidden_size - hidden_size (int): The hidden size of MLP - intermediate_size (int): The intermediate size of MLP - expert_parallel (str, optional): The parallelism of experts. Now we have 'EP' and 'TP'. + hidden_size (int): The hidden size of MLP + intermediate_size (int): The intermediate size of MLP + expert_parallel (str, optional): The parallelism of experts. Now we have None, EP and TP. activation (optional): The activation function of MLP drop_rate (float, optional): The drop rate of MLP + gated (bool, optional): Whether to use gated MLP + use_kernel (bool, optional): Whether to use kernel optimization """ def __init__( @@ -36,9 +37,9 @@ def __init__( intermediate_size: int, expert_parallel: Optional[str] = None, activation: Optional[Callable] = None, - drop_rate: float = 0, - gated: bool = False, - use_kernel: bool = False, + drop_rate: Optional[float] = 0, + gated: Optional[bool] = False, + use_kernel: Optional[bool] = False, ): super().__init__() assert expert_parallel in ["EP", "TP", None] @@ -97,8 +98,15 @@ def reset_parameters(self): torch.nn.init.normal_(self.wi, std=math.sqrt(0.1 / self.hidden_size)) torch.nn.init.normal_(self.wo, std=math.sqrt(0.1 / self.intermediate_size)) - def forward(self, x: torch.Tensor, param_slice: Tuple[slice] = (slice(None),)) -> torch.Tensor: + def forward( + self, + x: torch.Tensor, + param_slice: Tuple[slice] = (slice(None),), + use_sparse: bool = True, + ) -> torch.Tensor: """ + forward: hidden_size --> intermediate_size --> hidden_size + Args: x (torch.Tensor): The input tensor of shape (num_groups, num_experts, capacity, hidden_size) @@ -114,6 +122,16 @@ def forward(self, x: torch.Tensor, param_slice: Tuple[slice] = (slice(None),)) - inshape = x.shape x = x.reshape(e, -1, h) + if self.use_kernel and use_sparse: + seq_len = x.shape[1] + with torch.no_grad(): + mask = x[:, :, 0] != 0.0 + mask = torch.sum(mask, dim=-1) + x_list = [] + for i in range(e): + x_list.append(x[i, :mask[i]]) + x = x_list + if self.gated: x_gate = [torch.mm(x[i], self.wi_gate[param_slice][i]) for i in range(e)] x_up = [torch.mm(x[i], self.wi_up[param_slice][i]) for i in range(e)] @@ -127,86 +145,12 @@ def forward(self, x: torch.Tensor, param_slice: Tuple[slice] = (slice(None),)) - x = [self.drop(x[i]) for i in range(e)] x = [torch.mm(x[i], self.wo[param_slice][i]) for i in range(e)] + if self.use_kernel and use_sparse: + for i in range(e): + x[i] = torch.nn.functional.pad(x[i], (0, 0, 0, seq_len - x[i].shape[0]), mode="constant", value=0) + x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0) x = x.reshape(inshape) x = x.transpose(0, 1).contiguous() x = MoeOutGradScaler.apply(x, self.ep_size) return x - - -class EPMLPExperts(BaseMLPExperts): - """ - Use expert parallelism to split each expert evenly, which can deploy experts in - """ - - def __init__( - self, - num_experts: int, - hidden_size: int, - intermediate_size: int, - activation=None, - drop_rate: float = 0, - gated: bool = False, - use_kernel: bool = False, - ): - # TODO: This class can be aborted - super().__init__( - num_experts, - hidden_size, - intermediate_size, - "EP", - activation, - drop_rate, - gated, - use_kernel, - ) - - -class TPMLPExperts(BaseMLPExperts): - """Use tensor parallelism to split each expert evenly, which can deploy experts in - case that the number of experts can't be divide by maximum expert parallel size or - maximum expert parallel size can't be divide by the number of experts. 
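The divisibility rule this removed docstring describes, together with the build_ffn_experts dispatch deleted below, amounts to the following selection logic, restated here as a standalone sketch with illustrative sizes:

    def pick_expert_parallel(num_experts: int, max_ep_size: int, d_ff: int) -> str:
        # EP when the expert count and the ep size divide one another,
        # TP when only the FFN dimension is divisible.
        if num_experts % max_ep_size == 0 or max_ep_size % num_experts == 0:
            return "EP"
        if d_ff % max_ep_size == 0:
            return "TP"
        raise NotImplementedError(f"Can not build {num_experts} experts in {max_ep_size} GPUS.")

    assert pick_expert_parallel(num_experts=16, max_ep_size=8, d_ff=4096) == "EP"
    assert pick_expert_parallel(num_experts=6, max_ep_size=8, d_ff=4096) == "TP"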
- """ - - def __init__( - self, - num_experts: int, - hidden_size: int, - intermediate_size: int, - activation: str = None, - drop_rate: float = 0, - gated: bool = False, - use_kernel: bool = False, - ): - # TODO: This class can be aborted - super().__init__( - num_experts, - hidden_size, - intermediate_size, - "TP", - activation, - drop_rate, - gated, - use_kernel, - ) - - -def get_expert_class(name: str) -> BaseMLPExperts: - if name == "TP": - return TPMLPExperts - elif name == "EP": - return EPMLPExperts - elif name is None: - return BaseMLPExperts - else: - raise ValueError(f"Unknown expert class name: {name}") - - -def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): - mep_size = MOE_MANAGER.max_ep_size - if num_experts % mep_size == 0 or mep_size % num_experts == 0: - return EPMLPExperts(num_experts, d_model, d_ff, activation, drop_rate) - elif d_ff % mep_size == 0: - return TPMLPExperts(num_experts, d_model, d_ff, activation, drop_rate) - else: - raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.") diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index 9846cd432b53..bd2cefbe9ab8 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from colossalai.moe._operation import AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter -from colossalai.moe.experts import BaseMLPExperts, get_expert_class +from colossalai.moe.experts import MLPExperts from colossalai.moe.load_balance import LoadBalancer from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.routers import MoeRouter, get_router_cls @@ -48,50 +48,59 @@ class SparseMLP(nn.Module): def __init__( self, num_experts: int, - top_k: int = 1, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_policy: Optional[str] = None, - drop_tks: bool = True, - hidden_size: int = 2048, - intermediate_size: int = 2048, - activation: str = None, - gated: bool = False, + hidden_size: int, + intermediate_size: int, + router_top_k: int = 1, + router_capacity_factor_train: Optional[float] = 1.25, + router_capacity_factor_eval: Optional[float] = 2.0, + router_min_capacity: Optional[int] = 4, + router_noisy_policy: Optional[str] = None, + router_drop_tks: Optional[bool] = True, + mlp_activation: Optional[str] = None, + mlp_gated: Optional[bool] = False, + enable_load_balance: Optional[bool] = False, + load_balance_tolerance: Optional[float] = 0.1, + load_balance_beam_width: Optional[int] = 8, + load_balance_group_swap_factor: Optional[float] = 0.4, + enable_kernel: Optional[bool] = False, + enable_comm_overlap: Optional[bool] = False, ): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_experts = num_experts - self.use_kernel = MOE_MANAGER.use_kernel_optim + self.gated = mlp_gated + self.enable_kernel = enable_kernel + self.enable_comm_overlap = enable_comm_overlap self.expert_parallel = MOE_MANAGER.get_parallel() - self.gated = gated - assert self.expert_parallel in [ - "EP", - "TP", - None, - ], f"Unsupported expert parallel type {self.expert_parallel}" # moe router - noisy_func = get_noise_generator(noisy_policy, num_experts) - router_cls = get_router_cls(top_k) - self.topk = top_k + noisy_func = get_noise_generator(router_noisy_policy, num_experts) + router_cls = get_router_cls(router_top_k) + self.topk = router_top_k self.router: MoeRouter = router_cls( - 
capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, + capacity_factor_train=router_capacity_factor_train, + capacity_factor_eval=router_capacity_factor_eval, + min_capacity=router_min_capacity, noisy_func=noisy_func, - drop_tks=drop_tks, + drop_tks=router_drop_tks, ) + # gate + self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size)) + # moe experts - expert_cls = get_expert_class(self.expert_parallel) - self.experts: BaseMLPExperts = expert_cls(num_experts=num_experts, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - activation=activation, - gated=gated, - use_kernel=self.use_kernel) + self.experts = MLPExperts( + num_experts=self.num_experts, + expert_parallel=self.expert_parallel, + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + activation=mlp_activation, + gated=mlp_gated, + use_kernel=self.enable_kernel, + ) + + # get parallel settings if self.expert_parallel is not None: self.ep_group = get_ep_group(self.experts) self.ep_size = get_ep_size(self.experts) @@ -101,11 +110,8 @@ def __init__( self.dp_group = None self.num_local_experts = self.experts.num_local_experts - # gate - self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size)) - # load balance - self.enable_load_balance = MOE_MANAGER.load_balance + self.enable_load_balance = enable_load_balance if self.enable_load_balance == True: self.load_balancer = LoadBalancer( experts=self.experts, @@ -114,9 +120,9 @@ def __init__( expert_num=self.num_experts, ep_group=self.ep_group, dp_group=self.dp_group, - tolerance=MOE_MANAGER.tolerance, - beam_width=MOE_MANAGER.beam_width, - group_swap_factor=MOE_MANAGER.group_swap_factor, + tolerance=load_balance_tolerance, + beam_width=load_balance_beam_width, + group_swap_factor=load_balance_group_swap_factor, ) # init param @@ -147,14 +153,15 @@ def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: with torch.no_grad(): # TODO: optimize computation expert_load = torch.topk(gate_output, k=self.topk, dim=-1)[1] + # TODO: bincount introduces synchronize, fix it expert_load = torch.bincount(expert_load.view(-1)) self.load_balancer.update_load(expert_load) # the result from the router - route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group) + route_result_list = self.router(inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group) # dispatch_data: (num_experts, capacity, hidden_size) - if self.use_kernel: + if self.enable_kernel: dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:]) dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.hidden_size) else: @@ -163,16 +170,16 @@ def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # expert_output: (num_groups, num_experts, capacity, hidden_size) if self.expert_parallel == "EP": - expert_output = self._ep_process(dispatch_data) + expert_output = self._ep_process(dispatch_data, overlap=self.enable_comm_overlap) elif self.expert_parallel == "TP": - expert_output = self._tp_process(dispatch_data) + expert_output = self._tp_process(dispatch_data, overlap=self.enable_comm_overlap) elif self.expert_parallel is None: expert_output = self._local_process(dispatch_data) else: raise NotImplementedError("This kind of communication has not been implemented yet.\n" "Please use Experts build function.") - if self.use_kernel: + if self.enable_kernel: expert_output = 
expert_output.reshape(-1, self.hidden_size) ans = MoeCombine.apply(expert_output, *route_result_list) else: @@ -189,10 +196,7 @@ def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor: expert_out = self.experts(expert_in) return expert_out - def _ep_process(self, - dispatch_data: torch.Tensor, - overlap: bool = True - ) -> torch.Tensor: + def _ep_process(self, dispatch_data: torch.Tensor, overlap: bool = False) -> torch.Tensor: """ Expert Parallel @@ -210,16 +214,16 @@ def _ep_process(self, return expert_output else: + @dataclasses.dataclass - class Capsule(): + class Capsule: data: torch.Tensor handle: Any = None - NUM_CHUNK = 2 + NUM_CHUNK = 4 NUM_STAGES = 4 - assert dispatch_data.shape[1] % NUM_CHUNK == 0, \ - "arbitrary chunk num is not supported yet" + assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet" chunk_size = dispatch_data.shape[1] // NUM_CHUNK input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size) dispatch_data = dispatch_data.reshape(*input_shape) @@ -238,24 +242,17 @@ class Capsule(): # all2all last output if _expert_out is not None: - expert_out = Capsule( - *AllToAll.apply(_expert_out.data, self.ep_group, True), - ) + expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True),) _expert_out = None # all2all next input if 0 <= i < NUM_CHUNK: - _expert_in = Capsule( - *AllToAll.apply(chunk_data[i].contiguous(), self.ep_group, True) - ) + _expert_in = Capsule(*AllToAll.apply(chunk_data[i].contiguous(), self.ep_group, True)) # compute if expert_in is not None: expert_in.handle.wait() - _expert_out = Capsule( - data=self.experts(expert_in.data), - handle=None - ) + _expert_out = Capsule(data=self.experts(expert_in.data), handle=None) expert_in = None if _expert_in is not None: @@ -264,10 +261,7 @@ class Capsule(): return output - def _tp_process(self, - dispatch_data: torch.Tensor, - overlap: bool = True - ) -> torch.Tensor: + def _tp_process(self, dispatch_data: torch.Tensor, overlap: bool = False) -> torch.Tensor: """ without overlap: | C | @@ -291,23 +285,24 @@ def _tp_process(self, expert_out = ReduceScatter.apply(expert_out, self.ep_group, False)[0] return expert_out else: + @dataclasses.dataclass - class Capsule(): + class Capsule: data: torch.Tensor handle: Any indices: Tuple - NUM_CHUNK = 2 + NUM_CHUNK = 4 NUM_STAGES = 4 - assert dispatch_data.shape[0] % NUM_CHUNK == 0, \ - "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts" + assert (dispatch_data.shape[0] % NUM_CHUNK == 0 + ), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts" chunk_size = dispatch_data.shape[0] // NUM_CHUNK chunk_data = torch.split(dispatch_data, chunk_size, dim=0) output = torch.empty_like(dispatch_data) def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]: - return (slice(idx * chunk_size, (idx + 1) * chunk_size), ) + return (slice(idx * chunk_size, (idx + 1) * chunk_size),) _expert_in, expert_in, _expert_out, expert_out = None, None, None, None @@ -321,7 +316,7 @@ def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]: if _expert_out is not None: expert_out = Capsule( *ReduceScatter.apply(_expert_out.data, self.ep_group, True), - indices=_expert_out.indices + indices=_expert_out.indices, ) _expert_out = None @@ -329,7 +324,7 @@ def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]: if 0 <= i < NUM_CHUNK: _expert_in = Capsule( *AllGather.apply(chunk_data[i].contiguous(), self.ep_group, True), - 
indices=get_chunk_slice(i, chunk_size) + indices=get_chunk_slice(i, chunk_size), ) # compute @@ -337,7 +332,8 @@ def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]: expert_in.handle.wait() _expert_out = Capsule( self.experts(expert_in.data, expert_in.indices), - handle=None, indices=expert_in.indices + handle=None, + indices=expert_in.indices, ) expert_in = None @@ -360,4 +356,6 @@ def _apply_recursive(module: nn.Module): sub_module.load_balancer.balance_load(optim) _apply_recursive(sub_module) + torch.cuda.empty_cache() _apply_recursive(model) + torch.cuda.empty_cache() diff --git a/colossalai/moe/load_balance.py b/colossalai/moe/load_balance.py index b2fb672329c2..4a3d0fe4d096 100644 --- a/colossalai/moe/load_balance.py +++ b/colossalai/moe/load_balance.py @@ -7,7 +7,7 @@ from torch.distributed import ProcessGroup from colossalai.cluster import ProcessGroupMesh -from colossalai.moe.experts import BaseMLPExperts +from colossalai.moe.experts import MLPExperts from colossalai.moe.manager import MOE_MANAGER from colossalai.zero.low_level import LowLevelZeroOptimizer @@ -16,7 +16,7 @@ class LoadBalancer: def __init__( self, - experts: BaseMLPExperts, + experts: MLPExperts, gate: nn.Parameter, local_expert_num: int, expert_num: int, @@ -26,7 +26,7 @@ def __init__( beam_width: Optional[int] = 8, group_swap_factor: Optional[float] = 0.4, ) -> None: - self.experts: BaseMLPExperts = experts + self.experts: MLPExperts = experts self.gate: nn.Parameter = gate self.moe_ep_group: ProcessGroup = ep_group self.moe_ep_ranks = MOE_MANAGER.parallel_info_dict[dist.get_world_size(self.moe_ep_group)].ep_group_ranks @@ -423,6 +423,11 @@ def balance_load(self, optim: LowLevelZeroOptimizer) -> None: load = self._load_to_list(load) # search balance swap_list = self._search_balance(load) + if dist.get_rank() == 0: + if len(swap_list) > 0: + print(f"[Load Balance] Applying expert swap...") + else: + print(f"[Load Balance] Invalid swap, skip...") # swap expert and gate self._swap_moe_param(swap_list, optim) # clear load diff --git a/colossalai/moe/manager.py b/colossalai/moe/manager.py index e3659ef43fbd..f237ea134638 100644 --- a/colossalai/moe/manager.py +++ b/colossalai/moe/manager.py @@ -14,29 +14,29 @@ class MoeManager(metaclass=SingletonMeta): """ def __init__(self): - self.world_size = None - # Users may want to set maximum expert parallel size smaller than the world size - # since very low bandwidth across nodes may constrain the performance of MoE - # When we have a maximum expert parallel size, we have a minimum data parallel size naturally - self.max_ep_size = None - self.min_dp_size = None - self.router_aux_loss = [] - self.router_z_loss = [] self.parallel = None self.seed = None self.mode = None - self.use_kernel_optim = False self.use_ep_inside = None + self.world_size = None + self._parallel_info_dict = dict() + + # router + self.router_aux_loss = [] + self.router_z_loss = [] + + # fixed mode self.pp_size = None + self.dp_size = None + self.ep_size = None - # load balance param - self.load_balance = None - self.tolerance = None - self.beam_width = None - self.group_swap_factor = None + # dynamic mode + # Users may want to set maximum expert parallel size smaller than the world size + # since very low bandwidth across nodes may constrain the performance of MoE + # When we have a maximum expert parallel size, we have a minimum data parallel size naturally + self.max_ep_size = None self.has_setup = False - self._parallel_info_dict = dict() @property def parallel_info_dict(self): @@ -49,7 
+49,6 @@ def is_initialized(self): def setup( self, seed: int, - use_kernel_optim: bool = False, parallel: str = None, mode: str = "dynamic", max_ep_size: int = 8, @@ -57,10 +56,6 @@ def setup( fixed_ep_size: int = 0, fixed_pp_size: int = 0, use_ep_inside: bool = True, - enable_load_balance: bool = False, - tolerance: float = 0.1, - beam_width: int = 8, - group_swap_factor: float = 0.4, ) -> None: """ Setup MoE distributed context. @@ -78,38 +73,28 @@ def setup( fixed_pp_size (int, optional): Fixed pp size in fixed mode. Defaults to 0. use_ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if Fasle. Defaults to True. """ - assert not self.is_initialized, "MoE distributed context shouldn't be set up again" + assert (not self.is_initialized), "MoE distributed context shouldn't be set up again" assert torch.cuda.is_available(), "MoE requires to enable CUDA first" - self.world_size = dist.get_world_size() self.seed = seed + dist.get_rank() self.parallel = parallel self.use_ep_inside = use_ep_inside + self.world_size = dist.get_world_size() # init by mode self.mode = mode assert self.mode in ["fixed", "dynamic"], "mode should be fixed or dynamic" if self.mode == "dynamic": - self.max_ep_size = min(max_ep_size, dist.get_world_size()) - self.min_dp_size = self.world_size // self.max_ep_size + self.max_ep_size = min(max_ep_size, self.world_size) else: - assert fixed_dp_size > 0 and fixed_ep_size > 0 and fixed_pp_size > 0, "dp_size, ep_size and pp_size should be greater than 0" - assert isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int) and isinstance( - fixed_pp_size, int), "dp_size, ep_size and pp_size should be int" + assert (fixed_dp_size > 0 and fixed_ep_size > 0 + and fixed_pp_size > 0), "dp_size, ep_size and pp_size should be greater than 0" + assert (isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int) + and isinstance(fixed_pp_size, int)), "dp_size, ep_size and pp_size should be int" self.ep_size = fixed_ep_size self.dp_size = fixed_dp_size self.pp_size = fixed_pp_size - # Enabling kernel optimization may raise error in some cases - # Users can close kernel optimization manually - self.use_kernel_optim = use_kernel_optim - - # update load balance - self.load_balance = enable_load_balance - self.tolerance = tolerance - self.beam_width = beam_width - self.group_swap_factor = group_swap_factor - self.has_setup = True def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoeParallelInfo]: @@ -127,21 +112,13 @@ def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoePara """ if self.mode == "dynamic": - gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater - lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less - - assert gt_flag or lt_flag, "Automatic experts placement dose not not support expert number" \ - " is not a multiple of ep size or vice versa." - - # If the number of experts is greater than maximum expert parallel size. 
a.k.a ep_size,
-            # there are multiple experts in each GPU and each GPU has different experts
-            # So it's data parallel size is 1
-            # Otherwise, there is only one expert in each GPU
-            # The data parallel size should be calculated
-            dp_size = 1 if gt_flag else self.max_ep_size // num_experts
-            ep_size = self.max_ep_size // dp_size
-            # Don't forget to multiply minimum data parallel size
-            dp_size *= self.min_dp_size
+            gt_flag = (num_experts % self.max_ep_size == 0)    # check whether num_experts is greater
+            lt_flag = (self.max_ep_size % num_experts == 0)    # check whether num_experts is less
+            assert gt_flag or lt_flag, ("Automatic experts placement does not support an expert number"
+                                        " that is not a multiple of ep size or vice versa.")
+            dp_size = 1 if gt_flag else self.world_size // num_experts
+            ep_size = min(self.world_size // dp_size, self.max_ep_size)
+            dp_size = self.world_size // ep_size
             pp_size = 1
         else:
             dp_size = self.dp_size
@@ -167,13 +144,10 @@ def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoePara
 
         return num_local_experts, self.parallel_info_dict[ep_size]
 
-    def set_kernel_not_use(self):
-        self.use_kernel_optim = False
-
     def reset_loss(self):
         self.router_aux_loss, self.router_z_loss = [], []
 
-    def add_loss(self, aux_loss: float = 0., z_loss: float = 0.):
+    def add_loss(self, aux_loss: float = 0.0, z_loss: float = 0.0):
         self.router_aux_loss.append(aux_loss)
         self.router_z_loss.append(z_loss)
 
diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py
index 1ac66f7bb78f..7960a74d4539 100644
--- a/colossalai/moe/routers.py
+++ b/colossalai/moe/routers.py
@@ -30,7 +31,8 @@ def __init__(self,
                  capacity_factor_eval: float,
                  min_capacity: int,
                  noisy_func: Optional[Callable] = None,
-                 drop_tks: bool = True):
+                 drop_tks: bool = True,
+                 use_kernel: bool = False):
         super().__init__()
         self.k_value = k_value
         self.capacity_factor_train = capacity_factor_train
@@ -40,6 +41,7 @@ def __init__(self,
         self.drop_tks = drop_tks
         self._aux_loss = None
         self._z_loss = None
+        self.use_kernel = use_kernel
 
     def get_capacity(self, logits_shape):
         capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
@@ -49,11 +51,7 @@ def get_capacity(self, logits_shape):
         assert capacity > 0
         return int(capacity)
 
-    def set_aux_loss(self,
-                     router_probs: torch.Tensor,
-                     expert_indices: torch.Tensor,
-                     num_experts: int
-                     ) -> None:
+    def set_aux_loss(self, router_probs: torch.Tensor, expert_indices: torch.Tensor, num_experts: int) -> None:
         """Computes auxiliary load balancing loss as in Switch Transformer.
 
         See Switch Transformer (https://arxiv.org/abs/2101.03961). 
This function @@ -81,8 +79,7 @@ def set_aux_loss(self, tokens_per_group_and_expert = torch.mean(expert_mask.float(), dim=-2) router_prob_per_group_and_expert = torch.mean(router_probs.float(), dim=-2) - aux_loss = num_experts**2 * torch.mean( - tokens_per_group_and_expert * router_prob_per_group_and_expert) + aux_loss = num_experts**2 * torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) self._aux_loss = aux_loss def set_z_loss(self, router_logits: torch.Tensor): @@ -101,8 +98,7 @@ def set_z_loss(self, router_logits: torch.Tensor): assert router_logits.dim() == 3, "router_logits must be 3D tensor" num_groups, tokens_per_group, _ = router_logits.shape log_z = torch.logsumexp(router_logits, dim=-1) - z_loss = torch.sum(log_z**2, dtype=torch.float32 - ) / (num_groups * tokens_per_group) + z_loss = torch.sum(log_z**2, dtype=torch.float32) / (num_groups * tokens_per_group) self._z_loss = z_loss def pop_router_loss(self) -> torch.Tensor: @@ -113,8 +109,8 @@ def pop_router_loss(self) -> torch.Tensor: class Top1Router(MoeRouter): - """Top1 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity) - and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed + """Top1 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity) + and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed function can be found in the paper about Switch Transformer of Google. Args: @@ -142,22 +138,17 @@ def __init__(self, self.select_policy = select_policy assert select_policy in {"first", "random"} if select_policy == "random": - self.uniform = torch.distributions.uniform.Uniform( - low=torch.tensor(0.0, device=get_current_device()), - high=torch.tensor(1.0, device=get_current_device()) - ).rsample - - def forward(self, - inputs: torch.Tensor, - use_kernel: bool = False, - ep_group: Optional[ProcessGroup] = None - ) -> Tuple: + self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()), + high=torch.tensor(1.0, + device=get_current_device())).rsample + + def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: """ Args: inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). Returns: - 1. use_kernel is False: + 1. use_kernel is False: The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity). The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity). 2. use_kernel is True: @@ -188,9 +179,9 @@ def forward(self, rand_mask = mask * self.uniform(mask.shape) _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0) mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1) - ranks = moe_cumsum(mask) + ranks = moe_cumsum(mask, use_kernel=self.use_kernel) elif self.select_policy == "first": - ranks = moe_cumsum(mask) + ranks = moe_cumsum(mask, use_kernel=self.use_kernel) mask = mask * torch.lt(ranks, capacity) else: raise NotImplementedError("Not support such select policy yet.") @@ -211,8 +202,8 @@ def forward(self, class Top2Router(MoeRouter): - """Top2 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity) - and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. 
More detailed + """Top2 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity) + and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed function can be found in the paper about ViT-MoE. Args: @@ -236,17 +227,13 @@ def __init__(self, noisy_func=noisy_func, drop_tks=drop_tks) - def forward(self, - inputs: torch.Tensor, - use_kernel: bool = False, - ep_group: Optional[ProcessGroup] = None - ) -> Tuple: + def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: """ Args: inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). Returns: - 1. use_kernel is False: + 1. use_kernel is False: The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity). The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity). 2. use_kernel is True: @@ -280,8 +267,8 @@ def forward(self, dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) capacity = max_num.item() - rank1 = moe_cumsum(mask1) # rank1: [s, e] - rank2 = moe_cumsum(mask2) + rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e] + rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel) rank2 += torch.sum(mask1, dim=-2, keepdim=True) mask1 *= torch.lt(rank1, capacity) @@ -313,7 +300,7 @@ def forward(self, weight1 = mask1 * probs.type_as(inputs) weight2 = mask2 * probs.type_as(inputs) - cb_weight = torch.zeros(inputs.shape + (capacity, ), device=inputs.device) + cb_weight = torch.zeros(inputs.shape + (capacity,), device=inputs.device) sec_mask = torch.zeros_like(cb_weight, dtype=torch.bool) indices = torch.arange(0, inputs.shape[0], device=inputs.device) cb_weight[indices, top1_idx[indices], rank1[indices]] += weight1[indices, top1_idx[indices]] @@ -348,17 +335,14 @@ def __init__(self, min_capacity: int = 4, noisy_func: Optional[Callable] = None, drop_tks: bool = True): - super().__init__(num_selected_experts, - capacity_factor_train, - capacity_factor_eval, - min_capacity, - noisy_func, + super().__init__(num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, drop_tks) - def forward(self, - router_probs: torch.Tensor, - expert_capacity: int, - ) -> Tuple: + def forward( + self, + router_probs: torch.Tensor, + expert_capacity: int, + ) -> Tuple: """Computes masks for the top-k experts per token. Args: @@ -418,17 +402,12 @@ def forward(self, # The combine array will be used for combining expert outputs, scaled by the # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, # expert_capacity]. 
- combine_array = torch.einsum( - '...te,...tec->...tec', - router_probs, - dispatch_mask) + combine_array = torch.einsum('...te,...tec->...tec', router_probs, dispatch_mask) return combine_array, dispatch_mask -def get_router_cls(top_k: int, - grouped: bool = False - ) -> MoeRouter: +def get_router_cls(top_k: int, grouped: bool = False) -> MoeRouter: if not grouped: if top_k == 1: return Top1Router diff --git a/colossalai/moe/utils.py b/colossalai/moe/utils.py index e3bc6d3cac9a..0938e4206fda 100644 --- a/colossalai/moe/utils.py +++ b/colossalai/moe/utils.py @@ -1,5 +1,5 @@ import contextlib -from typing import Callable, Dict, List +from typing import Any, Callable, Dict, List import torch import torch.distributed as dist @@ -170,3 +170,8 @@ def sync_moe_model_param(model: nn.Module): for param in param_dict[ep_size]: src_rank = get_dp_group_ranks(param)[0] dist.broadcast(param, src=src_rank, group=get_dp_group(param)) + + +def set_moe_args(config: Any, args: dict): + for k, v in args.items(): + setattr(config, k, v) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 01776a8352fc..b256da222a93 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -8,7 +8,7 @@ import torch.distributed as dist import torch.nn as nn from torch import Tensor, inf -from torch._utils import _flatten_dense_tensors +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch.distributed import ProcessGroup from torch.optim import Optimizer @@ -299,6 +299,40 @@ def _run_reduction(self): if self.extra_dp_pg is None: flat_grads = self._bucket_store.get_flatten_grad() flat_grads /= self._world_size + else: + # record moe and non moe param + moe_list = [] + for param in self._bucket_store._param_list: + moe_list.append(is_moe_tensor(param)) + + # divide them into different groups + moe_grad_list = [] + non_moe_grad_list = [] + for grad_list in self._bucket_store._grad_in_bucket.values(): + non_moe_cur_grad = [] + moe_cur_grad = [] + for i in range(len(grad_list)): + if moe_list[i] == True: + moe_cur_grad.append(grad_list[i]) + else: + non_moe_cur_grad.append(grad_list[i]) + if len(moe_cur_grad) > 0: + moe_grad_list.append(moe_cur_grad) + if len(non_moe_cur_grad) > 0: + non_moe_grad_list.append(non_moe_cur_grad) + + if len(non_moe_grad_list) > 0: + non_moe_flat_grads = [] + for grad_list in non_moe_grad_list: + non_moe_flat_grads.append(_flatten_dense_tensors(grad_list)) + non_moe_flat_grads = _flatten_dense_tensors(non_moe_flat_grads) + non_moe_flat_grads /= self._world_size + + if len(moe_grad_list) > 0: + moe_flat_grads = [] + for grad_list in moe_grad_list: + moe_flat_grads.append(_flatten_dense_tensors(grad_list)) + moe_flat_grads = _flatten_dense_tensors(moe_flat_grads) # ready to add other tensors to bucket self._bucket_store.reset_num_elements_in_bucket() @@ -308,6 +342,11 @@ def _run_reduction(self): # in case of the memory being reused in the default stream if self.extra_dp_pg is None: flat_grads.record_stream(stream) + else: + if len(non_moe_grad_list) > 0: + non_moe_flat_grads.record_stream(stream) + if len(moe_grad_list) > 0: + moe_flat_grads.record_stream(stream) # waiting for ops in the default stream finishing stream.wait_stream(torch.cuda.current_stream()) else: @@ -342,64 +381,70 @@ def _run_reduction(self): # sync extra zero group else: - # record moe and non moe param - moe_list = [] - for param in self._bucket_store._param_list: - 
moe_list.append(is_moe_tensor(param)) - - # divide them into different groups - moe_grad_list = [] - non_moe_grad_list = [] - for grad_list in self._bucket_store._grad_in_bucket.values(): - non_moe_cur_grad = [] - moe_cur_grad = [] - for i in range(len(grad_list)): - if moe_list[i] == True: - moe_cur_grad.append(grad_list[i]) - else: - non_moe_cur_grad.append(grad_list[i]) - if len(moe_cur_grad) > 0: - moe_grad_list.append(moe_cur_grad) - if len(non_moe_cur_grad) > 0: - non_moe_grad_list.append(non_moe_cur_grad) - # sync non moe param in global dp group if len(non_moe_grad_list) > 0: - flat_grads = [] - for grad_list in non_moe_grad_list: - flat_grads.append(_flatten_dense_tensors(grad_list)) - flat_grads = _flatten_dense_tensors(flat_grads) - flat_grads /= self._world_size - dist.all_reduce(flat_grads, group=self.dp_pg) - flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size) + dist.all_reduce(non_moe_flat_grads, group=self.dp_pg) + flat_grads_per_rank = non_moe_flat_grads.split(non_moe_flat_grads.numel() // + self._world_size) self._sync_unpartitioned_grad(non_moe_grad_list, flat_grads_per_rank, group_id) # sync moe param only in zero group if len(moe_grad_list) > 0: - flat_grads = [] - for grad_list in moe_grad_list: - flat_grads.append(_flatten_dense_tensors(grad_list)) - flat_grads = _flatten_dense_tensors(flat_grads) - dist.all_reduce(flat_grads, group=self.extra_dp_pg) - flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size) + dist.all_reduce(moe_flat_grads, group=self.extra_dp_pg) + flat_grads_per_rank = moe_flat_grads.split(moe_flat_grads.numel() // self._world_size) self._sync_unpartitioned_grad(moe_grad_list, flat_grads_per_rank, group_id) else: - flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size)) - recieved_grad = torch.zeros_like(flat_grads_list[0]) - dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg) - - if recieved_grad.dtype != grad_dtype: - recieved_grad = recieved_grad.to(grad_dtype) - - grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank] - sync_tensor(recieved_grad, grad_in_bucket_current_rank) - for grad in grad_in_bucket_current_rank: - param_id = self._bucket_store.get_param_id_of_grad(grad) - if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < 1: - self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) - else: - self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id) + if self.extra_dp_pg is None: + flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size)) + recieved_grad = torch.zeros_like(flat_grads_list[0]) + dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg) + + if recieved_grad.dtype != grad_dtype: + recieved_grad = recieved_grad.to(grad_dtype) + + grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank] + sync_tensor(recieved_grad, grad_in_bucket_current_rank) + for grad in grad_in_bucket_current_rank: + param_id = self._bucket_store.get_param_id_of_grad(grad) + if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < 1: + self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) + else: + self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id) + else: + if len(non_moe_grad_list) > 0: + flat_grads_list = list(non_moe_flat_grads.split( + len(non_moe_flat_grads) // self._world_size)) + recieved_grad = torch.zeros_like(flat_grads_list[0]) + 
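# As a hedged aside, the split/reduce-scatter bookkeeping used throughout this
# hunk can be checked without a distributed launch; the cross-rank summation
# itself is what dist.reduce_scatter performs at runtime (toy world_size below):
import torch

world_size = 4
flat_grads = torch.arange(16, dtype=torch.float32)

# Every rank contributes the full flat buffer and receives the sum of one shard.
shards = list(flat_grads.split(flat_grads.numel() // world_size))
assert len(shards) == world_size
assert all(s.numel() == flat_grads.numel() // world_size for s in shards)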
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg) + + grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank] + sync_tensor(recieved_grad, grad_in_bucket_current_rank) + for grad in grad_in_bucket_current_rank: + param_id = self._bucket_store.get_param_id_of_grad(grad) + if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < 1: + self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) + else: + self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id) + + if len(moe_grad_list) > 0: + flat_grads_list = list(moe_flat_grads.split(len(moe_flat_grads) // self.extra_dp_pg_size)) + recieved_grad = torch.zeros_like(flat_grads_list[0]) + dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.extra_dp_pg) + + param_slice = self._world_size // self.extra_dp_pg_size + recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice)) + grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank] + for split_recieved_grad in recieved_grad: + split_recieved_grad = _unflatten_dense_tensors(split_recieved_grad, + grad_in_bucket_current_rank) + for grad in grad_in_bucket_current_rank: + param_id = self._bucket_store.get_param_id_of_grad(grad) + if len(self._grad_store.get_partitioned_gradients_by_param_id( + group_id, param_id)) < param_slice: + self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) + else: + self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id) self._bucket_store.reset() @@ -532,8 +577,11 @@ def step(self, closure=None): # moe hybrid zero if self.extra_dp_pg is not None and is_moe_tensor(working_param): real_working_params[group_id].append(working_param) - param_slice = self._world_size // self.extra_dp_pg_size - grad = grads[self.extra_dp_pg_rank * param_slice:(self.extra_dp_pg_rank + 1) * param_slice] + if self._partition_grads: + grad = grads + else: + param_slice = self._world_size // self.extra_dp_pg_size + grad = grads[self.extra_dp_pg_rank * param_slice:(self.extra_dp_pg_rank + 1) * param_slice] grad = flatten(grad) else: real_working_params[group_id].append(working_param) @@ -559,13 +607,21 @@ def step(self, closure=None): global_norm = calculate_global_norm_from_list(norm_list=norm_groups) self._unscale_and_clip_grads(grad_partition_groups, global_norm) + # TODO: we should store master param for ep + if len(self.param_groups) > len(self._working_param_groups): + for param in self.param_groups[-1]['params']: + param.data = param.data.to(torch.float32) + param.grad = param.grad.to(torch.float32) + # update the parameters self.optim.step() - # release the moe grad + # TODO: release the moe grad. 
we should store master param if len(self.param_groups) > len(self._working_param_groups): + dtype = real_working_params[0][0].dtype for param in self.param_groups[-1]['params']: param.grad = None + param.data = param.data.to(dtype) # release the grad grad_partition_groups = [] diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py index e1acba5c88b0..2f6bfa0f89a2 100644 --- a/examples/language/openmoe/benchmark/benchmark_cai.py +++ b/examples/language/openmoe/benchmark/benchmark_cai.py @@ -15,12 +15,12 @@ import colossalai from colossalai import get_default_parser from colossalai.booster import Booster -from colossalai.booster.plugin import LowLevelZeroPlugin from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.cluster import DistCoordinator from colossalai.moe.layers import apply_load_balance from colossalai.moe.manager import MOE_MANAGER -from colossalai.moe.utils import skip_init +from colossalai.moe.utils import set_moe_args, skip_init +from colossalai.nn.optimizer import HybridAdam from colossalai.utils import get_current_device @@ -118,7 +118,7 @@ def parse_args(): parser.add_argument("--pp_size", type=int, default=2, help="pp size") parser.add_argument("--dp_size", type=int, default=1, help="dp size") parser.add_argument("--ep_size", type=int, default=2, help="ep size") - parser.add_argument("--zero_stage", type=int, default=1, help="zero stage in hybrid plugin") + parser.add_argument("--zero_stage", type=int, default=2, help="zero stage in hybrid plugin") parser.add_argument("--microbatch_size", type=int, default=1, help="microbatch size") parser.add_argument("--extra_dp_size", type=int, default=1) # kernel @@ -132,6 +132,9 @@ def parse_args(): parser.add_argument("--active", type=int, default=20) # load balance parser.add_argument("--load_balance", action="store_true") + + # overlap + parser.add_argument("--overlap_alltoall", action="store_true") args = parser.parse_args() return args @@ -150,49 +153,38 @@ def main(): "custom_policy": OpenMoeForCausalLMPolicy(), "enable_fused_normalization": args.use_kernel, "enable_jit_fused": args.use_kernel, - "precision": "bf16" + "precision": "bf16", + "zero_stage": args.zero_stage, + } + mgr_dict = { + "seed": 42, } - mgr_dict = {"seed": 42, "use_kernel_optim": args.use_kernel, "enable_load_balance": args.load_balance} if args.plugin == "zero": - dp_size = dist.get_world_size() - plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2) - MOE_MANAGER.setup( - parallel=None, - **mgr_dict, - ) - elif args.plugin == "ep": dp_size = dist.get_world_size() plugin = MoeHybridParallelPlugin( pp_size=1, - zero_stage=2, **hybrid_dict, ) MOE_MANAGER.setup( - parallel="EP", + parallel=None, **mgr_dict, ) - elif args.plugin == "ep_zero": + elif args.plugin == "ep": dp_size = dist.get_world_size() - use_ep_inside = False plugin = MoeHybridParallelPlugin( pp_size=1, - zero_stage=1, - extra_dp_size=args.extra_dp_size, - use_ep_inside=use_ep_inside, **hybrid_dict, ) MOE_MANAGER.setup( parallel="EP", - max_ep_size=dp_size // args.extra_dp_size, - use_ep_inside=use_ep_inside, + max_ep_size=dp_size, **mgr_dict, ) - elif args.plugin == "zero_ep": + elif args.plugin == "ep_zero": dp_size = dist.get_world_size() - use_ep_inside = True + use_ep_inside = False plugin = MoeHybridParallelPlugin( pp_size=1, - zero_stage=1, extra_dp_size=args.extra_dp_size, use_ep_inside=use_ep_inside, **hybrid_dict, @@ -226,10 +218,28 @@ def main(): # Build OpenMoe model 
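# The config block below relies on the set_moe_args helper added to
# colossalai/moe/utils.py above; it is just a setattr loop, as this minimal,
# self-contained imitation shows (SimpleNamespace stands in for a LlamaConfig):
from types import SimpleNamespace

def _set_moe_args(config, args: dict):
    # same behavior as the real helper: copy every entry onto the config
    for k, v in args.items():
        setattr(config, k, v)

cfg = SimpleNamespace()    # stand-in for a transformers LlamaConfig
_set_moe_args(cfg, {"router_topk": 2, "enable_kernel": False})
assert cfg.router_topk == 2 and cfg.enable_kernel is False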
repo_name = "hpcaitech/openmoe-" + args.model_name config = LlamaConfig.from_pretrained(repo_name) - setattr(config, "router_aux_loss_factor", 0.1) - setattr(config, "router_z_loss_factor", 0.1) - setattr(config, "label_smoothing", 0.1) - setattr(config, "z_loss_factor", 0.1) + moe_args = { + "num_experts": config.num_experts, + "moe_layer_interval": config.moe_layer_interval, + "router_topk": 2, + "router_capacity_factor_train": 1.25, + "router_capacity_factor_eval": 2.0, + "router_min_capacity": 4, + "router_noisy_policy": None, + "router_drop_tks": True, + "router_aux_loss_factor": 0.01, + "router_z_loss_factor": 0.01, + "mlp_gated": True, + "label_smoothing": 0.001, + "z_loss_factor": 0.01, + "enable_load_balance": args.load_balance, + "load_balance_tolerance": 0.1, + "load_balance_beam_width": 8, + "load_balance_group_swap_factor": 0.4, + "enable_kernel": args.use_kernel, + "enable_comm_overlap": args.overlap_alltoall, + } + set_moe_args(config, moe_args) with skip_init(): model = OpenMoeForCausalLM(config) coordinator.print_on_master(f"Finish init model with config:\n{config}") @@ -247,7 +257,7 @@ def main(): dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size) # Set optimizer - optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.01, lr=1e-5) + optimizer = HybridAdam(model.parameters(), weight_decay=0.01, lr=1e-5) model_numel = get_model_numel(model) performance_evaluator = PerformanceEvaluator( @@ -259,8 +269,8 @@ def main(): # Set booster booster = Booster(plugin=plugin, **booster_kwargs) - model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader) load_ckpt(repo_name, model, booster) + model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader) use_pipeline = (isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1) is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() coordinator.print_on_master(f"Finish init booster") @@ -302,8 +312,8 @@ def main(): optimizer.zero_grad() performance_evaluator.on_step_end(exmaple_data["input_ids"]) if (step == args.warmup // 2) and args.load_balance: - apply_load_balance(model, optimizer) coordinator.print_on_master(f"Apply load balance") + apply_load_balance(model, optimizer) performance_evaluator.on_fit_end() diff --git a/examples/language/openmoe/benchmark/benchmark_cai.sh b/examples/language/openmoe/benchmark/benchmark_cai.sh index ec4490faa55d..f269e260d8db 100755 --- a/examples/language/openmoe/benchmark/benchmark_cai.sh +++ b/examples/language/openmoe/benchmark/benchmark_cai.sh @@ -2,10 +2,10 @@ set -xue -NUM_GPU=4 +NUM_GPU=8 MODEL="8b" SEQ_LENGTH=2048 -WARMUP=8 +WARMUP=20 ACTIVE=4 # HACK: make model importable @@ -16,51 +16,50 @@ else export PYTHONPATH=$example_dir:$PYTHONPATH fi -# zero -torchrun --standalone --nproc_per_node $NUM_GPU \ - $example_dir/benchmark/benchmark_cai.py \ - --model_name $MODEL \ - --batch_size 4 \ - --seq_length $SEQ_LENGTH \ - --warmup $WARMUP \ - --active $ACTIVE \ - --plugin zero \ - --use_kernel # ep +echo -e "\n\n Naive EP \n\n" torchrun --standalone --nproc_per_node $NUM_GPU \ $example_dir/benchmark/benchmark_cai.py \ --model_name $MODEL \ - --batch_size 12 \ + --batch_size 8 \ --seq_length $SEQ_LENGTH \ --warmup $WARMUP \ --active $ACTIVE \ --plugin ep \ - --use_kernel + --zero_stage 2 + # ep_zero +echo -e "\n\n EP-ZERO \n\n" torchrun --standalone --nproc_per_node $NUM_GPU \ $example_dir/benchmark/benchmark_cai.py \ 
--model_name $MODEL \ - --batch_size 12 \ + --batch_size 16 \ --seq_length $SEQ_LENGTH \ --warmup $WARMUP \ --active $ACTIVE \ --plugin ep_zero \ --use_kernel \ - --extra_dp_size 2 + --extra_dp_size 2 \ + --zero_stage 1 \ + --load_balance -# zero_ep +echo -e "\n\n EP-ZERO + Overlap \n\n" torchrun --standalone --nproc_per_node $NUM_GPU \ $example_dir/benchmark/benchmark_cai.py \ --model_name $MODEL \ - --batch_size 12 \ + --batch_size 16 \ --seq_length $SEQ_LENGTH \ --warmup $WARMUP \ --active $ACTIVE \ - --plugin zero_ep \ + --plugin ep_zero \ --use_kernel \ - --extra_dp_size 2 + --extra_dp_size 2 \ + --zero_stage 1 \ + --load_balance \ + --overlap_alltoall + # hybrid torchrun --standalone --nproc_per_node $NUM_GPU \ diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.py b/examples/language/openmoe/benchmark/benchmark_fsdp.py index 0edf102d640c..531e18313798 100644 --- a/examples/language/openmoe/benchmark/benchmark_fsdp.py +++ b/examples/language/openmoe/benchmark/benchmark_fsdp.py @@ -12,7 +12,6 @@ from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from torch.utils.data import Dataset from torch.utils.data.distributed import DistributedSampler -from transformers import Adafactor from transformers.models.llama import LlamaConfig from utils import PerformanceEvaluator, get_model_numel @@ -80,7 +79,7 @@ def fsdp_main(rank, world_size, args): auto_wrap_policy=auto_wrap_policy, device_id=torch.cuda.current_device(), ) - optimizer = Adafactor(model.parameters()) + optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.01, lr=1e-5) model.train() model_numel = get_model_numel(model) diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.sh b/examples/language/openmoe/benchmark/benchmark_fsdp.sh index e1eb2a9c6053..0380ee1ade20 100755 --- a/examples/language/openmoe/benchmark/benchmark_fsdp.sh +++ b/examples/language/openmoe/benchmark/benchmark_fsdp.sh @@ -6,8 +6,8 @@ NUM_GPU=8 MODEL="8b" BATCH_SIZE=1 SEQ_LENGTH=2048 -WARMUP=6 -ACTIVE=3 +WARMUP=8 +ACTIVE=4 # HACK: make model importable example_dir=$(dirname $(realpath $(dirname $0))) diff --git a/examples/language/openmoe/benchmark/hostfile.txt b/examples/language/openmoe/benchmark/hostfile.txt new file mode 100644 index 000000000000..994b3e2cfc4f --- /dev/null +++ b/examples/language/openmoe/benchmark/hostfile.txt @@ -0,0 +1,2 @@ +host1 +host2 diff --git a/examples/language/openmoe/infer.py b/examples/language/openmoe/infer.py index f59772189827..1ad1456b9c56 100644 --- a/examples/language/openmoe/infer.py +++ b/examples/language/openmoe/infer.py @@ -5,6 +5,8 @@ from transformers import T5Tokenizer from transformers.models.llama import LlamaConfig +from colossalai.moe.utils import set_moe_args + def parse_args(): parser = ArgumentParser() @@ -17,9 +19,54 @@ def inference(args): tokenizer = T5Tokenizer.from_pretrained("google/umt5-small") if args.model == "test": config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base") + moe_args = { + "num_experts": config.num_experts, + "moe_layer_interval": config.moe_layer_interval, + "router_topk": 2, + "router_capacity_factor_train": 1.25, + "router_capacity_factor_eval": 2.0, + "router_min_capacity": 4, + "router_noisy_policy": None, + "router_drop_tks": True, + "router_aux_loss_factor": 0.01, + "router_z_loss_factor": 0.01, + "mlp_gated": True, + "label_smoothing": 0.001, + "z_loss_factor": 0.01, + "enable_load_balance": False, + "load_balance_tolerance": 0.1, + "load_balance_beam_width": 8, + "load_balance_group_swap_factor": 0.4, + "enable_kernel": 
False, + "enable_comm_overlap": False, + } + set_moe_args(config, moe_args) model = OpenMoeForCausalLM(config) else: - model = OpenMoeForCausalLM.from_pretrained(f"hpcaitech/openmoe-{args.model}") + config = LlamaConfig.from_pretrained(f"hpcaitech/openmoe-{args.model}") + moe_args = { + "num_experts": config.num_experts, + "moe_layer_interval": config.moe_layer_interval, + "router_topk": 2, + "router_capacity_factor_train": 1.25, + "router_capacity_factor_eval": 2.0, + "router_min_capacity": 4, + "router_noisy_policy": None, + "router_drop_tks": True, + "router_aux_loss_factor": 0.01, + "router_z_loss_factor": 0.01, + "mlp_gated": True, + "label_smoothing": 0.001, + "z_loss_factor": 0.01, + "enable_load_balance": False, + "load_balance_tolerance": 0.1, + "load_balance_beam_width": 8, + "load_balance_group_swap_factor": 0.4, + "enable_kernel": False, + "enable_comm_overlap": False, + } + set_moe_args(config, moe_args) + model = OpenMoeForCausalLM.from_pretrained(f"hpcaitech/openmoe-{args.model}", config=config) model = model.eval().half() model = model.to(torch.cuda.current_device()) diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index 357c0f22a783..6f9b668e4597 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -39,6 +39,7 @@ from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON from colossalai.moe.layers import SparseMLP from colossalai.moe.manager import MOE_MANAGER +from colossalai.moe.utils import get_activation if HAS_TRITON: from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine @@ -166,7 +167,7 @@ def SwiGLU(x): class OpenMoeMLP(nn.Module): - def __init__(self, config): + def __init__(self, config: LlamaConfig): super().__init__() self.pretraining_tp = config.pretraining_tp self.hidden_size = config.hidden_size @@ -174,8 +175,9 @@ def __init__(self, config): self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = SwiGLU - self.use_kernel = True if MOE_MANAGER.use_kernel_optim else False + self.hidden_act = config.hidden_act + self.act_fn = get_activation(self.hidden_act) + self.use_kernel = config.enable_kernel def forward(self, x): if self.pretraining_tp > 1: @@ -191,7 +193,7 @@ def forward(self, x): down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)] down_proj = sum(down_proj) else: - if HAS_TRITON and self.use_kernel: + if HAS_TRITON and self.use_kernel and self.hidden_act == "swiglu": down_proj = self.down_proj(LlamaActCombine.apply(self.gate_proj(x), self.up_proj(x))) else: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) @@ -361,16 +363,22 @@ def __init__(self, config: LlamaConfig, moe: bool): self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) if self.moe: self.mlp = SparseMLP(num_experts=config.num_experts, - top_k=config.topk, - capacity_factor_train=config.capacity_factor_train, - capacity_factor_eval=config.capacity_factor_eval, - min_capacity=config.min_capacity, - noisy_policy=config.noisy_policy, - drop_tks=config.drop_tks, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, - activation=config.hidden_act, - gated=config.gated) + 
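# router_* kwargs below control token routing (top-k, capacity factors, noise,
# token dropping); mlp_* kwargs configure the expert FFN; load_balance_* and
# enable_* toggle the optional balancing search and kernel/communication overlap.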
router_top_k=config.router_topk, + router_capacity_factor_train=config.router_capacity_factor_train, + router_capacity_factor_eval=config.router_capacity_factor_eval, + router_min_capacity=config.router_min_capacity, + router_noisy_policy=config.router_noisy_policy, + router_drop_tks=config.router_drop_tks, + mlp_activation=config.hidden_act, + mlp_gated=config.mlp_gated, + enable_load_balance=config.enable_load_balance, + load_balance_tolerance=config.load_balance_tolerance, + load_balance_beam_width=config.load_balance_beam_width, + load_balance_group_swap_factor=config.load_balance_group_swap_factor, + enable_kernel=config.enable_kernel, + enable_comm_overlap=config.enable_comm_overlap) self.pre_extra_mlp_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.extra_mlp = OpenMoeMLP(config) else: diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index 6f239104328c..ec9ec21b55dc 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -19,7 +19,7 @@ from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.moe import MoeCheckpintIO from colossalai.moe.manager import MOE_MANAGER -from colossalai.moe.utils import skip_init +from colossalai.moe.utils import set_moe_args, skip_init from colossalai.utils import get_current_device @@ -157,7 +157,6 @@ def main(): MOE_MANAGER.setup( seed=42, parallel="EP", - use_kernel_optim=args.use_kernel if not test_mode else False, ) elif args.plugin == "zero2_ep": plugin = MoeHybridParallelPlugin( @@ -171,7 +170,6 @@ def main(): MOE_MANAGER.setup( seed=42, parallel="EP", - use_kernel_optim=args.use_kernel if not test_mode else False, ) elif args.plugin == "hybrid": plugin = MoeHybridParallelPlugin( @@ -190,7 +188,6 @@ def main(): fixed_dp_size=args.dp_size, fixed_ep_size=args.ep_size, fixed_pp_size=args.pp_size, - use_kernel_optim=args.use_kernel, ) else: raise ValueError(f"Invalid plugin {args.plugin}") @@ -205,10 +202,28 @@ def main(): else: repo_name = "hpcaitech/openmoe-" + args.model_name config = LlamaConfig.from_pretrained(repo_name) - setattr(config, "router_aux_loss_factor", args.router_aux_loss_factor) - setattr(config, "router_z_loss_factor", args.router_z_loss_factor) - setattr(config, "label_smoothing", args.label_smoothing) - setattr(config, "z_loss_factor", args.z_loss_factor) + moe_args = { + "num_experts": config.num_experts, + "moe_layer_interval": config.moe_layer_interval, + "router_topk": 2, + "router_capacity_factor_train": 1.25, + "router_capacity_factor_eval": 2.0, + "router_min_capacity": 4, + "router_noisy_policy": None, + "router_drop_tks": True, + "router_aux_loss_factor": 0.01, + "router_z_loss_factor": 0.01, + "mlp_gated": True, + "label_smoothing": 0.001, + "z_loss_factor": 0.01, + "enable_load_balance": False, + "load_balance_tolerance": 0.1, + "load_balance_beam_width": 8, + "load_balance_group_swap_factor": 0.4, + "enable_kernel": False, + "enable_comm_overlap": False, + } + set_moe_args(config, moe_args) with skip_init(): model = OpenMoeForCausalLM(config) logger.info(f"Finish init model with config:\n{config}", ranks=[0]) diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py index 934061ae4417..2e116de2db7d 100644 --- a/tests/test_moe/moe_utils.py +++ b/tests/test_moe/moe_utils.py @@ -14,13 +14,16 @@ class MoeModel(nn.Module): - def __init__(self, checkpoint: bool = False): + def __init__(self, checkpoint: bool = False, enable_load_balance: bool = False): class 
TestSubModule(CheckpointModule): def __init__(self): super().__init__(checkpoint) - self.moe = SparseMLP(num_experts=8, hidden_size=16, intermediate_size=32) + self.moe = SparseMLP(num_experts=8, + hidden_size=16, + intermediate_size=32, + enable_load_balance=enable_load_balance) self.proj = nn.Linear(16, 4) def _forward(self, x): diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py index d935be2a9628..28ee618e1ba7 100644 --- a/tests/test_moe/test_grad_handler.py +++ b/tests/test_moe/test_grad_handler.py @@ -16,17 +16,26 @@ def run_test(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch( + config=dict(), + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) MOE_MANAGER.setup(42, parallel="EP") # MOE initialization num_experts_list = [1, 2, 4] layer_list = [] for num_experts in num_experts_list: - moe_layer = SparseMLP(hidden_size=DIM, - intermediate_size=DIM * 4, - num_experts=num_experts, - top_k=1, - noisy_policy="Jitter") + moe_layer = SparseMLP( + hidden_size=DIM, + intermediate_size=DIM * 4, + num_experts=num_experts, + router_top_k=1, + router_noisy_policy="Jitter", + ) layer_list.append(moe_layer) model = nn.ModuleList(layer_list) diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index ef5177289aad..c710c7bf713d 100644 --- a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -33,14 +33,14 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f layer = SparseMLP(hidden_size=hidden_size, intermediate_size=hidden_size * 2, num_experts=NUM_EXPERTS, - top_k=topk, - capacity_factor_train=1.0) + router_top_k=topk, + router_capacity_factor_train=1.0) layer = layer.to(get_current_device()) if data_type == torch.float16: layer = layer.half() # use matrix multiplication instead of COL_MOE_KERNEL in MOE dispatch and combine - layer.use_kernel = False + layer.enable_kernel = False old_out = layer(tokens) ech = old_out.shape grad = torch.randn(ech, device=get_current_device()) @@ -54,7 +54,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f tokens.grad.zero_() layer.gate_weight.grad.zero_() - layer.use_kernel = True + layer.enable_kernel = True new_out = layer(tokens) # get outputs through colossal kernel if data_type == torch.float32: diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index 09af499185db..40aae12f016a 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -16,7 +16,6 @@ sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "examples/language/openmoe")) -# TODO: better way to import them OpenMoeForCausalLM = importlib.import_module("model.modeling_openmoe").OpenMoeForCausalLM OpenMoeForCausalLMPolicy = importlib.import_module("model.openmoe_policy").OpenMoeForCausalLMPolicy @@ -37,22 +36,27 @@ def get_config(): "head_dim": 4, "num_attention_heads": 4, "dropout_rate": 0.0, - "layer_norm_epsilon": 1e-06, "hidden_act": "swiglu", "num_experts": 16, - "topk": 2, "capacity_factor_train": 1.25, "capacity_factor_eval": 2.0, "min_capacity": 4, "noisy_policy": None, "drop_tks": True, - "expert_parallel": None, - "gated": True, "moe_layer_interval": 4, "router_aux_loss_factor": 0.1, "router_z_loss_factor": 0.1, "label_smoothing": 0.1, "z_loss_factor": 0.1, + "mlp_gated": True, + 
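# Note: a Python dict literal keeps the last occurrence of a duplicated key,
# so the label_smoothing and z_loss_factor entries below override the 0.1
# values assigned above.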
"label_smoothing": 0.001, + "z_loss_factor": 0.01, + "enable_load_balance": False, + "load_balance_tolerance": 0.1, + "load_balance_beam_width": 8, + "load_balance_group_swap_factor": 0.4, + "enable_kernel": False, + "enable_comm_overlap": False, } for key, value in settings.items(): setattr(config, key, value) diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py index 51fd135483b6..11d0664fd580 100644 --- a/tests/test_moe/test_moe_ep_tp.py +++ b/tests/test_moe/test_moe_ep_tp.py @@ -11,32 +11,20 @@ from tests.test_moe.moe_utils import MoeGradientHandler, sync_local_from_ep, sync_tp_from_ep -def run_test(rank: int, - world_size: int, - port: int, - num_experts: int, - batch_size: int, - dim: int, - seed: int): +def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, seed: int): assert batch_size % world_size == 0 colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') MOE_MANAGER.__init__() MOE_MANAGER.setup(seed, parallel=None) - local_model = SparseMLP(num_experts=num_experts, - hidden_size=dim, - intermediate_size=dim * 2) + local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2) MOE_MANAGER.__init__() MOE_MANAGER.setup(seed, parallel="EP") - ep_model = SparseMLP(num_experts=num_experts, - hidden_size=dim, - intermediate_size=dim * 2) + ep_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2) MOE_MANAGER.__init__() MOE_MANAGER.setup(seed, parallel="TP") - tp_model = SparseMLP(num_experts=num_experts, - hidden_size=dim, - intermediate_size=dim * 2) + tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2) ep_model = ep_model.to(get_current_device()) tp_model = tp_model.to(get_current_device()) local_model = local_model.to(get_current_device()) @@ -81,14 +69,11 @@ def run_test(rank: int, @pytest.mark.dist @pytest.mark.parametrize("num_experts", [4, 8]) -@pytest.mark.parametrize("batch_size", [4, 8]) -@pytest.mark.parametrize("dim", [16, 256]) -@pytest.mark.parametrize("seed", [42, 78]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("dim", [32]) +@pytest.mark.parametrize("seed", [42]) @rerun_if_address_is_in_use() -def test_moe_ep_tp(num_experts: int, - batch_size: int, - dim: int, - seed: int): +def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, seed: int): spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, seed=seed) diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py index e111ea6bb18d..3cd5acc0d953 100644 --- a/tests/test_moe/test_moe_group.py +++ b/tests/test_moe/test_moe_group.py @@ -3,7 +3,7 @@ import torch.nn as nn import colossalai -from colossalai.moe import EPMLPExperts, TPMLPExperts +from colossalai.moe.experts import MLPExperts from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import sync_moe_model_param from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn @@ -13,38 +13,39 @@ INTERMEDIATE_SIZE = 8 -def run_moe_init(expert_cls): - expert_args = dict(hidden_size=HIDDEN_SIZE, intermediate_size=INTERMEDIATE_SIZE) - exp0 = expert_cls(1, **expert_args) - exp1 = expert_cls(2, **expert_args) - exp2 = expert_cls(4, **expert_args) - exp3 = expert_cls(8, **expert_args) +def run_moe_init(expert_parallel): + MOE_MANAGER.__init__() + MOE_MANAGER.setup(seed=42, parallel=expert_parallel) + expert_args = dict( + 
hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + expert_parallel=expert_parallel, + ) + exp0 = MLPExperts(1, **expert_args) + exp1 = MLPExperts(2, **expert_args) + exp2 = MLPExperts(4, **expert_args) - if expert_cls is EPMLPExperts: + if expert_parallel == "EP": assert exp0.num_local_experts == 1 assert exp1.num_local_experts == 1 - assert exp2.num_local_experts == 1 - assert exp3.num_local_experts == 2 + assert exp2.num_local_experts == 2 else: assert exp0.num_local_experts == 1 assert exp1.num_local_experts == 2 assert exp2.num_local_experts == 4 - assert exp3.num_local_experts == 8 parallel_info_dict = MOE_MANAGER.parallel_info_dict rank = dist.get_rank() # group creation assert - assert len(parallel_info_dict) == 3 - assert dist.get_rank(parallel_info_dict[4].ep_group) == rank + assert len(parallel_info_dict) == 2 assert dist.get_rank(parallel_info_dict[2].ep_group) == rank % 2 assert dist.get_rank(parallel_info_dict[1].ep_group) == 0 - assert dist.get_rank(parallel_info_dict[4].dp_group) == 0 assert dist.get_rank(parallel_info_dict[2].dp_group) == rank // 2 assert dist.get_rank(parallel_info_dict[1].dp_group) == rank - model = nn.ModuleList([exp0, exp1, exp2, exp3]) + model = nn.ModuleList([exp0, exp1, exp2]) model = model.to(get_current_device()) sync_moe_model_param(model) @@ -57,19 +58,25 @@ def run_moe_init(expert_cls): assert_equal_in_group(exp1.wo.data, parallel_info_dict[2].dp_group) -def _run_test(rank, world_size, port, expert_cls): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_MANAGER.setup(seed=42, parallel="EP") - run_moe_init(expert_cls) +def _run_test(rank, world_size, port, expert_parallel): + colossalai.launch( + config=dict(), + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) + run_moe_init(expert_parallel) @pytest.mark.dist -@pytest.mark.parametrize("expert_cls", [EPMLPExperts, TPMLPExperts]) +@pytest.mark.parametrize("expert_parallel", ["EP", "TP"]) @rerun_if_address_is_in_use() -def test_moe_initialization(expert_cls): - spawn(_run_test, 4, expert_cls=expert_cls) +def test_moe_initialization(expert_parallel): + spawn(_run_test, 2, expert_parallel=expert_parallel) -if __name__ == '__main__': - test_moe_initialization(EPMLPExperts) - test_moe_initialization(TPMLPExperts) +if __name__ == "__main__": + test_moe_initialization("EP") + test_moe_initialization("TP") diff --git a/tests/test_moe/test_moe_load_balance.py b/tests/test_moe/test_moe_load_balance.py index b4eea04bc85a..5126c61ae92f 100644 --- a/tests/test_moe/test_moe_load_balance.py +++ b/tests/test_moe/test_moe_load_balance.py @@ -47,12 +47,8 @@ def run_zero_optim_test(local_rank, world_size, stage=1): MOE_MANAGER.setup( seed=42, parallel="EP", - enable_load_balance=True, - tolerance=0.1, - beam_width=8, - group_swap_factor=0.4, ) - zero_model = MoeModel(checkpoint=True) + zero_model = MoeModel(checkpoint=True, enable_load_balance=True) zero_optimizer = torch.optim.Adam(zero_model.parameters()) plugin = LowLevelZeroPlugin(stage=stage, precision="bf16", verbose=True) booster = Booster(plugin=plugin) @@ -118,12 +114,8 @@ def run_hybrid_zero_optim_test(local_rank, world_size, stage=1): max_ep_size=2, use_ep_inside=False, parallel="EP", - enable_load_balance=True, - tolerance=0.1, - beam_width=8, - group_swap_factor=0.4, ) - zero_model = MoeModel(checkpoint=True) + zero_model = MoeModel(checkpoint=True, enable_load_balance=True) extra_dp_group = 
MOE_MANAGER.parallel_info_dict[2].dp_group ep_rank = dist.get_rank(MOE_MANAGER.parallel_info_dict[2].ep_group) ep_size = MOE_MANAGER.parallel_info_dict[2].ep_size From c644b474145c3353ac34468def705de664df8298 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com> Date: Thu, 26 Oct 2023 10:50:07 +0800 Subject: [PATCH 25/46] [moe] update train script (#4959) * update * update ckpt * update train * update train --- colossalai/moe/checkpoint.py | 17 +- .../openmoe/benchmark/benchmark_cai.py | 32 +--- examples/language/openmoe/infer.py | 61 ++----- .../openmoe/model/modeling_openmoe.py | 78 ++++++++- examples/language/openmoe/test_ci.sh | 8 +- examples/language/openmoe/train.py | 157 ++++++++++++------ tests/test_moe/test_moe_checkpoint.py | 48 ++---- 7 files changed, 225 insertions(+), 176 deletions(-) diff --git a/colossalai/moe/checkpoint.py b/colossalai/moe/checkpoint.py index 99e0ae811bbd..386fc2010805 100644 --- a/colossalai/moe/checkpoint.py +++ b/colossalai/moe/checkpoint.py @@ -48,14 +48,15 @@ def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict: """ for name, param in state_dict.items(): if ".experts." in name: - model_param = dict(model.named_parameters())[name] - if is_moe_tensor(model_param): - ep_rank = get_ep_rank(model_param) - ep_size = get_ep_size(model_param) - expert_num = param.shape[0] // ep_size - assert param.shape[0] % ep_size == 0 - param = param[ep_rank * expert_num:(ep_rank + 1) * expert_num] - state_dict[name] = param + if name in dict(model.named_parameters()): + model_param = dict(model.named_parameters())[name] + if is_moe_tensor(model_param): + ep_rank = get_ep_rank(model_param) + ep_size = get_ep_size(model_param) + expert_num = param.shape[0] // ep_size + assert param.shape[0] % ep_size == 0 + param = param[ep_rank * expert_num:(ep_rank + 1) * expert_num] + state_dict[name] = param dist.barrier() return state_dict diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py index 2f6bfa0f89a2..1a158eabc151 100644 --- a/examples/language/openmoe/benchmark/benchmark_cai.py +++ b/examples/language/openmoe/benchmark/benchmark_cai.py @@ -4,7 +4,7 @@ import torch import torch.distributed as dist from huggingface_hub import snapshot_download -from model.modeling_openmoe import OpenMoeForCausalLM +from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args from model.openmoe_policy import OpenMoeForCausalLMPolicy from torch.utils.data import Dataset from tqdm import tqdm @@ -19,7 +19,7 @@ from colossalai.cluster import DistCoordinator from colossalai.moe.layers import apply_load_balance from colossalai.moe.manager import MOE_MANAGER -from colossalai.moe.utils import set_moe_args, skip_init +from colossalai.moe.utils import skip_init from colossalai.nn.optimizer import HybridAdam from colossalai.utils import get_current_device @@ -218,28 +218,12 @@ def main(): # Build OpenMoe model repo_name = "hpcaitech/openmoe-" + args.model_name config = LlamaConfig.from_pretrained(repo_name) - moe_args = { - "num_experts": config.num_experts, - "moe_layer_interval": config.moe_layer_interval, - "router_topk": 2, - "router_capacity_factor_train": 1.25, - "router_capacity_factor_eval": 2.0, - "router_min_capacity": 4, - "router_noisy_policy": None, - "router_drop_tks": True, - "router_aux_loss_factor": 0.01, - "router_z_loss_factor": 0.01, - "mlp_gated": True, - "label_smoothing": 0.001, - "z_loss_factor": 0.01, - "enable_load_balance": args.load_balance, - 
"load_balance_tolerance": 0.1, - "load_balance_beam_width": 8, - "load_balance_group_swap_factor": 0.4, - "enable_kernel": args.use_kernel, - "enable_comm_overlap": args.overlap_alltoall, - } - set_moe_args(config, moe_args) + set_openmoe_args(config, + num_experts=config.num_experts, + moe_layer_interval=config.moe_layer_interval, + enable_load_balance=args.load_balance, + enable_kernel=args.use_kernel, + enable_comm_overlap=args.overlap_alltoall) with skip_init(): model = OpenMoeForCausalLM(config) coordinator.print_on_master(f"Finish init model with config:\n{config}") diff --git a/examples/language/openmoe/infer.py b/examples/language/openmoe/infer.py index 1ad1456b9c56..db90c6e34507 100644 --- a/examples/language/openmoe/infer.py +++ b/examples/language/openmoe/infer.py @@ -1,12 +1,10 @@ from argparse import ArgumentParser import torch -from model.modeling_openmoe import OpenMoeForCausalLM +from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args from transformers import T5Tokenizer from transformers.models.llama import LlamaConfig -from colossalai.moe.utils import set_moe_args - def parse_args(): parser = ArgumentParser() @@ -15,59 +13,22 @@ def parse_args(): def inference(args): - tokenizer = T5Tokenizer.from_pretrained("google/umt5-small") if args.model == "test": config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base") - moe_args = { - "num_experts": config.num_experts, - "moe_layer_interval": config.moe_layer_interval, - "router_topk": 2, - "router_capacity_factor_train": 1.25, - "router_capacity_factor_eval": 2.0, - "router_min_capacity": 4, - "router_noisy_policy": None, - "router_drop_tks": True, - "router_aux_loss_factor": 0.01, - "router_z_loss_factor": 0.01, - "mlp_gated": True, - "label_smoothing": 0.001, - "z_loss_factor": 0.01, - "enable_load_balance": False, - "load_balance_tolerance": 0.1, - "load_balance_beam_width": 8, - "load_balance_group_swap_factor": 0.4, - "enable_kernel": False, - "enable_comm_overlap": False, - } - set_moe_args(config, moe_args) + set_openmoe_args(config, + num_experts=config.num_experts, + moe_layer_interval=config.moe_layer_interval, + enable_kernel=True) model = OpenMoeForCausalLM(config) else: config = LlamaConfig.from_pretrained(f"hpcaitech/openmoe-{args.model}") - moe_args = { - "num_experts": config.num_experts, - "moe_layer_interval": config.moe_layer_interval, - "router_topk": 2, - "router_capacity_factor_train": 1.25, - "router_capacity_factor_eval": 2.0, - "router_min_capacity": 4, - "router_noisy_policy": None, - "router_drop_tks": True, - "router_aux_loss_factor": 0.01, - "router_z_loss_factor": 0.01, - "mlp_gated": True, - "label_smoothing": 0.001, - "z_loss_factor": 0.01, - "enable_load_balance": False, - "load_balance_tolerance": 0.1, - "load_balance_beam_width": 8, - "load_balance_group_swap_factor": 0.4, - "enable_kernel": False, - "enable_comm_overlap": False, - } - set_moe_args(config, moe_args) + set_openmoe_args(config, + num_experts=config.num_experts, + moe_layer_interval=config.moe_layer_interval, + enable_kernel=False) model = OpenMoeForCausalLM.from_pretrained(f"hpcaitech/openmoe-{args.model}", config=config) - model = model.eval().half() + model = model.eval().bfloat16() model = model.to(torch.cuda.current_device()) input_str = """``` @@ -86,7 +47,7 @@ def inference(args): # print("model config: ", model.config) input_ids = tokenizer("" + input_str, return_tensors="pt", add_special_tokens=False) input_ids = input_ids.input_ids.to(torch.cuda.current_device()) - generation_output = 
model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=16)
+    generation_output = model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=64)
     out = tokenizer.decode(generation_output[0], skip_special_tokens=False)
     print(f"output: \n{out}\n")

diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py
index 6f9b668e4597..7d28de731407 100644
--- a/examples/language/openmoe/model/modeling_openmoe.py
+++ b/examples/language/openmoe/model/modeling_openmoe.py
@@ -39,7 +39,7 @@
 from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
 from colossalai.moe.layers import SparseMLP
 from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import get_activation
+from colossalai.moe.utils import get_activation, set_moe_args

 if HAS_TRITON:
     from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
@@ -49,6 +49,78 @@
 _CONFIG_FOR_DOC = "LlamaConfig"

+def set_openmoe_args(
+    config: LlamaConfig,
+    num_experts: int,
+    moe_layer_interval: int,
+    router_topk: int = 2,
+    router_capacity_factor_train: float = 1.25,
+    router_capacity_factor_eval: float = 2.0,
+    router_min_capacity: int = 4,
+    router_noisy_policy: str = None,
+    router_drop_tks: bool = True,
+    router_aux_loss_factor: float = 0.01,
+    router_z_loss_factor: float = 0.01,
+    mlp_gated: bool = True,
+    label_smoothing: float = 0.001,
+    z_loss_factor: float = 0.01,
+    enable_load_balance: bool = False,
+    load_balance_tolerance: float = 0.1,
+    load_balance_beam_width: int = 8,
+    load_balance_group_swap_factor: float = 0.4,
+    enable_kernel: bool = False,
+    enable_comm_overlap: bool = False,
+) -> None:
+    """
+    MoE-related arguments.
+    This helper inserts the MoE arguments into the Llama config.
+
+    Args:
+        config (LlamaConfig): Transformers Llama config.
+        num_experts (int): Number of experts.
+        moe_layer_interval (int): Use one MoE layer every `moe_layer_interval` layers.
+        router_topk (int, optional): MoE router top-k. Defaults to 2.
+        router_capacity_factor_train (float, optional): MoE router capacity factor during training. Defaults to 1.25.
+        router_capacity_factor_eval (float, optional): MoE router capacity factor during evaluation. Defaults to 2.0.
+        router_min_capacity (int, optional): MoE router minimum capacity. Defaults to 4.
+        router_noisy_policy (str, optional): MoE router noisy policy; one of [Jitter, Gaussian, None]. Defaults to None.
+        router_drop_tks (bool, optional): Whether the MoE router drops tokens that exceed the max capacity. Defaults to True.
+        router_aux_loss_factor (float, optional): MoE router auxiliary loss factor; see the ST-MoE paper for details. Defaults to 0.01.
+        router_z_loss_factor (float, optional): MoE router z-loss factor; see the ST-MoE paper for details. Defaults to 0.01.
+        mlp_gated (bool, optional): Use a gated MLP in the experts. Defaults to True.
+        label_smoothing (float, optional): Label smoothing. Defaults to 0.001.
+        z_loss_factor (float, optional): Z-loss factor applied to the final classification outputs. Defaults to 0.01.
+        enable_load_balance (bool, optional): Enable expert load balancing. Defaults to False.
+        load_balance_tolerance (float, optional): Difference tolerance of the expert load-balance search. Defaults to 0.1.
+        load_balance_beam_width (int, optional): Beam width of the expert load-balance search. Defaults to 8.
+        load_balance_group_swap_factor (float, optional): Group swap factor of the expert load balance; larger values discourage swapping. Defaults to 0.4.
+        enable_kernel (bool, optional): Use kernel optimization. Defaults to False.
+        enable_comm_overlap (bool, optional): Use communication overlap for MoE. Recommended for multi-node training. Defaults to False.
+    """
+    moe_args = dict(
+        num_experts=num_experts,
+        moe_layer_interval=moe_layer_interval,
+        router_topk=router_topk,
+        router_capacity_factor_train=router_capacity_factor_train,
+        router_capacity_factor_eval=router_capacity_factor_eval,
+        router_min_capacity=router_min_capacity,
+        router_noisy_policy=router_noisy_policy,
+        router_drop_tks=router_drop_tks,
+        router_aux_loss_factor=router_aux_loss_factor,
+        router_z_loss_factor=router_z_loss_factor,
+        mlp_gated=mlp_gated,
+        label_smoothing=label_smoothing,
+        z_loss_factor=z_loss_factor,
+        enable_load_balance=enable_load_balance,
+        load_balance_tolerance=load_balance_tolerance,
+        load_balance_beam_width=load_balance_beam_width,
+        load_balance_group_swap_factor=load_balance_group_swap_factor,
+        enable_kernel=enable_kernel,
+        enable_comm_overlap=enable_comm_overlap,
+    )
+    set_moe_args(config, moe_args)
+
+
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
 def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype,
@@ -96,7 +168,7 @@ def generate_fixed_pos_embedding(features, length, min_timescale=1.0, max_timesc
     output_sin: a float32 Tensor with shape [length, features]
     output_cos: a float32 Tensor with shape [length, features]
     """
-    fraction = torch.arange(0, features, 2, dtype=torch.float64).cuda() / features
+    fraction = torch.arange(0, features, 2, dtype=torch.float32).cuda() / features
     timescale = min_timescale * (max_timescale / min_timescale)**fraction
     rotational_frequency = 1. / timescale
@@ -231,7 +303,7 @@ def __init__(self, config: LlamaConfig):
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
-        self.sin, self.cos = generate_fixed_pos_embedding(self.head_dim, self.max_position_embeddings, 1e4)
+        self.sin, self.cos = generate_fixed_pos_embedding(self.head_dim, self.max_position_embeddings, 1.0, 1e4)

     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
diff --git a/examples/language/openmoe/test_ci.sh b/examples/language/openmoe/test_ci.sh
index 0f68db4275f7..71198d8756d0 100644
--- a/examples/language/openmoe/test_ci.sh
+++ b/examples/language/openmoe/test_ci.sh
@@ -7,7 +7,13 @@ python infer.py --model "test"

 torchrun --standalone --nproc_per_node 4 train.py \
     --num_epoch 1 \
     --model_name "test" \
-    --plugin zero2_ep \
+    --plugin "ep" \
     --batch_size 1
+
+torchrun --standalone --nproc_per_node 4 train.py \
+    --num_epoch 1 \
+    --model_name "test" \
+    --plugin "ep_zero" \
+    --batch_size 1

 torchrun --standalone --nproc_per_node 4 train.py \
     --num_epoch 1 \
diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py
index ec9ec21b55dc..19bc70e1c4f5 100644
--- a/examples/language/openmoe/train.py
+++ b/examples/language/openmoe/train.py
@@ -2,9 +2,10 @@

 import datasets
 import torch
+import torch.distributed as dist
 import transformers
 from huggingface_hub import snapshot_download
-from model.modeling_openmoe import OpenMoeForCausalLM
+from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
 from model.openmoe_policy import OpenMoeForCausalLMPolicy
 from torch.utils.data import Dataset
 from tqdm import tqdm
@@ -17,9
+18,9 @@
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.moe import MoeCheckpintIO
+from colossalai.moe.layers import apply_load_balance
 from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import set_moe_args, skip_init
+from colossalai.moe.utils import skip_init
 from colossalai.utils import get_current_device

@@ -42,7 +43,31 @@ def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster):

 class RandomDataset(Dataset):

-    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
+    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000, tokenizer=None):
+        """
+        A dataset of randomly generated token ids.
+
+        You can use the tokenizer to process your own data instead.
+        Example:
+            self.input_ids = []
+            self.attention_mask = []
+            data = your_data()
+            data = shuffle(data)
+            for text in data:
+                # text is a str
+                encode = tokenizer(
+                    "" + text,
+                    return_tensors="pt",
+                    add_special_tokens=False,
+                    max_length=max_length,
+                    truncation=True,
+                    padding="max_length")
+                self.input_ids.append(encode["input_ids"])
+                self.attention_mask.append(encode["attention_mask"])
+            self.input_ids = torch.cat(self.input_ids, dim=0).to(get_current_device())
+            self.attention_mask = torch.cat(self.attention_mask, dim=0).to(get_current_device())
+        """
+        # TODO: use distributed sampler
         self.num_samples = num_samples
         self.max_length = max_length
         self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
@@ -88,20 +113,38 @@ def parse_args():
         type=str,
         default="hybrid",
         help="parallel plugin",
-        choices=["zero1_ep", "zero2_ep", "hybrid"],
+        choices=["ep", "ep_zero", "hybrid"],
     )
+
+    # optim
+    parser.add_argument("--decay_rate", type=float, default=-0.8, help="Adafactor optimizer decay rate.")
+    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to use.")
+    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
+
+    # zero stage for all plugins
+    parser.add_argument("--zero_stage", type=int, default=2, help="ZeRO stage used by all plugins")
+
+    # ep zero plugin
+    parser.add_argument("--extra_dp_size", type=int, default=1, help="MoE DP size for the ep_zero plugin")
+
     # hybrid plugin
     parser.add_argument("--pp_size", type=int, default=2, help="pp size")
     parser.add_argument("--dp_size", type=int, default=1, help="dp size")
     parser.add_argument("--ep_size", type=int, default=2, help="ep size")
-    parser.add_argument("--zero_stage", type=int, default=1, help="zero stage in hybrid plugin")
     parser.add_argument("--microbatch_size", type=int, default=1, help="microbatch size")
+
     # kernel
     parser.add_argument(
         "--use_kernel",
         action="store_true",
-        help="Use kernel optim. Need to install flash attention, apex, triton to enable all kernel optimizations.",
+        help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations.",
     )
+    parser.add_argument(
+        "--use_layernorm_kernel",
+        action="store_true",
+        help="Use layernorm kernel.
Need to install apex.", ) + # loss parser.add_argument( "--router_aux_loss_factor", @@ -117,9 +160,13 @@ def parse_args(): ) parser.add_argument("--label_smoothing", type=float, default=0.0, help="label_smoothing.") parser.add_argument("--z_loss_factor", type=float, default=0.0001, help="z_loss_factor.") - # optim - parser.add_argument("--decay_rate", type=float, default=-0.8, help="adafactor optim decay rate.") - parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to use.") + + # load balance + parser.add_argument("--load_balance", action="store_true", help="moe load balance") + parser.add_argument("--load_balance_interval", type=int, default=1000, help="moe load balance interval") + + # overlap + parser.add_argument("--comm_overlap", action="store_true", help="moe comm overlap") args = parser.parse_args() return args @@ -145,49 +192,57 @@ def main(): # Set plugin booster_kwargs = {} - if args.plugin == "zero1_ep": + hybrid_dict = { + "tp_size": 1, + "custom_policy": OpenMoeForCausalLMPolicy(), + "enable_fused_normalization": args.use_layernorm_kernel, + "enable_jit_fused": args.use_kernel, + "precision": "bf16", + "zero_stage": args.zero_stage, + } + mgr_dict = { + "seed": 42, + } + if args.plugin == "ep": + dp_size = dist.get_world_size() plugin = MoeHybridParallelPlugin( - tp_size=1, pp_size=1, - zero_stage=1, - custom_policy=OpenMoeForCausalLMPolicy(), - enable_fused_normalization=args.use_kernel, - enable_jit_fused=args.use_kernel, + **hybrid_dict, ) MOE_MANAGER.setup( - seed=42, parallel="EP", + max_ep_size=dp_size, + **mgr_dict, ) - elif args.plugin == "zero2_ep": + elif args.plugin == "ep_zero": + dp_size = dist.get_world_size() + use_ep_inside = False plugin = MoeHybridParallelPlugin( - tp_size=1, pp_size=1, - zero_stage=2, - custom_policy=OpenMoeForCausalLMPolicy(), - enable_fused_normalization=args.use_kernel, - enable_jit_fused=args.use_kernel, + extra_dp_size=args.extra_dp_size, + use_ep_inside=use_ep_inside, + **hybrid_dict, ) MOE_MANAGER.setup( - seed=42, parallel="EP", + max_ep_size=dp_size // args.extra_dp_size, + use_ep_inside=use_ep_inside, + **mgr_dict, ) elif args.plugin == "hybrid": + dp_size = dist.get_world_size() // args.pp_size plugin = MoeHybridParallelPlugin( - tp_size=1, pp_size=args.pp_size, - zero_stage=args.zero_stage, microbatch_size=args.microbatch_size, - custom_policy=OpenMoeForCausalLMPolicy(), - enable_fused_normalization=args.use_kernel if not test_mode else False, - enable_jit_fused=args.use_kernel if not test_mode else False, + **hybrid_dict, ) MOE_MANAGER.setup( - seed=42, parallel="EP", mode="fixed", fixed_dp_size=args.dp_size, fixed_ep_size=args.ep_size, fixed_pp_size=args.pp_size, + **mgr_dict, ) else: raise ValueError(f"Invalid plugin {args.plugin}") @@ -202,28 +257,17 @@ def main(): else: repo_name = "hpcaitech/openmoe-" + args.model_name config = LlamaConfig.from_pretrained(repo_name) - moe_args = { - "num_experts": config.num_experts, - "moe_layer_interval": config.moe_layer_interval, - "router_topk": 2, - "router_capacity_factor_train": 1.25, - "router_capacity_factor_eval": 2.0, - "router_min_capacity": 4, - "router_noisy_policy": None, - "router_drop_tks": True, - "router_aux_loss_factor": 0.01, - "router_z_loss_factor": 0.01, - "mlp_gated": True, - "label_smoothing": 0.001, - "z_loss_factor": 0.01, - "enable_load_balance": False, - "load_balance_tolerance": 0.1, - "load_balance_beam_width": 8, - "load_balance_group_swap_factor": 0.4, - "enable_kernel": False, - "enable_comm_overlap": False, - } - 
set_moe_args(config, moe_args) + set_openmoe_args( + config, + num_experts=config.num_experts, + moe_layer_interval=config.moe_layer_interval, + router_aux_loss_factor=args.router_aux_loss_factor, + router_z_loss_factor=args.router_z_loss_factor, + z_loss_factor=args.z_loss_factor, + enable_load_balance=args.load_balance, + enable_comm_overlap=args.comm_overlap, + enable_kernel=args.use_kernel, + ) with skip_init(): model = OpenMoeForCausalLM(config) logger.info(f"Finish init model with config:\n{config}", ranks=[0]) @@ -233,7 +277,7 @@ def main(): # Prepare tokenizer and dataloader tokenizer = T5Tokenizer.from_pretrained("google/umt5-small") - dataset = RandomDataset(num_samples=1000 if not test_mode else 20) + dataset = RandomDataset(num_samples=1000 if not test_mode else 20, tokenizer=tokenizer) dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True) # Set optimizer @@ -259,7 +303,7 @@ def main(): desc=f"Epoch [{epoch + 1}/{args.num_epoch}]", disable=not coordinator.is_master(), ) as pbar: - for _ in pbar: + for step in pbar: if use_pipeline: # Forward pass outputs = booster.execute_pipeline( @@ -287,6 +331,11 @@ def main(): optimizer.step() optimizer.zero_grad() + # Apply load balance + if args.load_balance and args.load_balance_interval > 0 and step % args.load_balance_interval == 0: + coordinator.print_on_master(f"Apply load balance") + apply_load_balance(model, optimizer) + # Finish training and evaluate logger.info(f"Finish finetuning", ranks=[0]) booster.save_model(model, args.output_path) diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index 40aae12f016a..b68eaec50fea 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -14,52 +14,28 @@ from colossalai.moe.manager import MOE_MANAGER from colossalai.testing import rerun_if_address_is_in_use, spawn -sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "examples/language/openmoe")) +sys.path.append(os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))), + "examples/language/openmoe", +)) OpenMoeForCausalLM = importlib.import_module("model.modeling_openmoe").OpenMoeForCausalLM +set_openmoe_args = importlib.import_module("model.modeling_openmoe").set_openmoe_args OpenMoeForCausalLMPolicy = importlib.import_module("model.openmoe_policy").OpenMoeForCausalLMPolicy def get_config(): config = LlamaConfig( vocab_size=300, - hidden_size=32, - intermediate_size=64, - num_hidden_layers=2, + hidden_size=16, + intermediate_size=32, + num_hidden_layers=4, num_attention_heads=2, + head_dim=4, + dropout_rate=0.0, + hidden_act="swiglu", ) - settings = { - "vocab_size": 300, - "intermediate_size": 32, - "hidden_size": 16, - "num_hidden_layers": 2, - "head_dim": 4, - "num_attention_heads": 4, - "dropout_rate": 0.0, - "hidden_act": "swiglu", - "num_experts": 16, - "capacity_factor_train": 1.25, - "capacity_factor_eval": 2.0, - "min_capacity": 4, - "noisy_policy": None, - "drop_tks": True, - "moe_layer_interval": 4, - "router_aux_loss_factor": 0.1, - "router_z_loss_factor": 0.1, - "label_smoothing": 0.1, - "z_loss_factor": 0.1, - "mlp_gated": True, - "label_smoothing": 0.001, - "z_loss_factor": 0.01, - "enable_load_balance": False, - "load_balance_tolerance": 0.1, - "load_balance_beam_width": 8, - "load_balance_group_swap_factor": 0.4, - "enable_kernel": False, - "enable_comm_overlap": False, - } - for key, value in settings.items(): - setattr(config, key, 
value) + set_openmoe_args(config, num_experts=16, moe_layer_interval=1) return config From 5cc3ad0007114f2742fed2f1f5581e589f89fa9f Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Thu, 26 Oct 2023 13:47:29 +0800 Subject: [PATCH 26/46] update --- colossalai/legacy/initialize.py | 12 --- colossalai/tensor/moe_tensor/api.py | 3 +- colossalai/zero/low_level/low_level_optim.py | 92 ++++++++++---------- tests/test_moe/moe_utils.py | 32 ++++--- tests/test_moe/test_moe_hybrid_zero.py | 30 ++++--- tests/test_moe/test_moe_load_balance.py | 15 ++-- tests/test_moe/test_moe_zero_fwd_bwd.py | 20 +++-- tests/test_moe/test_moe_zero_optim.py | 17 ++-- 8 files changed, 110 insertions(+), 111 deletions(-) diff --git a/colossalai/legacy/initialize.py b/colossalai/legacy/initialize.py index ce9c626553bf..4035bd6b54ef 100644 --- a/colossalai/legacy/initialize.py +++ b/colossalai/legacy/initialize.py @@ -16,7 +16,6 @@ from torch.utils.data import DataLoader from colossalai.context import Config, ConfigException -from colossalai.context.moe_context import MOE_CONTEXT from colossalai.interface import OptimizerWrapper from colossalai.legacy.amp import AMP_TYPE, convert_to_amp from colossalai.legacy.amp.naive_amp import NaiveAMPModel @@ -36,7 +35,6 @@ from colossalai.legacy.zero.gemini.ophooks import BaseOpHook from colossalai.logging import get_dist_logger from colossalai.utils import get_current_device -from colossalai.utils.moe import sync_moe_model_param def get_default_parser(): @@ -323,8 +321,6 @@ def initialize( if not use_zero: if is_using_sequence(): sync_model_param(model, ParallelMode.SEQUENCE_DP) - elif MOE_CONTEXT.is_initialized: - sync_moe_model_param(model) elif is_using_ddp(): sync_model_param(model, ParallelMode.DATA) else: @@ -377,14 +373,6 @@ def initialize( "added even though not specified in the configuration", ranks=[0], ) - elif is_using_ddp() and MOE_CONTEXT.is_initialized: - gradient_handler_cfg = [dict(type="MoeGradientHandler")] - if verbose: - logger.info( - "Data parallel training is detected with moe parallel, MoeGradientHandler is automatically " - "added even though not specified in the configuration", - ranks=[0], - ) elif is_using_sequence(): model = DDP( model, diff --git a/colossalai/tensor/moe_tensor/api.py b/colossalai/tensor/moe_tensor/api.py index 9120a40b8533..c9efec63feb3 100644 --- a/colossalai/tensor/moe_tensor/api.py +++ b/colossalai/tensor/moe_tensor/api.py @@ -1,7 +1,6 @@ import torch import torch.distributed as dist - -from colossalai.tensor import ProcessGroup +from torch.distributed import ProcessGroup from .moe_info import MoeParallelInfo diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index b256da222a93..2dfe92d517c5 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -20,19 +20,11 @@ from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger from colossalai.tensor.moe_tensor.api import is_moe_tensor + # from colossalai.tensor import ColoParameter, ProcessGroup from colossalai.utils.cuda import get_current_device from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor -from ._utils import ( - calculate_global_norm_from_list, - compute_norm, - flatten, - has_inf_or_nan, - release_param_grad, - sync_tensor, - unflatten, -) from .bookkeeping import BucketStore, GradientStore, ParameterStore @@ -84,7 +76,7 @@ def __init__( partition_grad: bool = False, # stage 2 
flag cpu_offload: bool = False, # cpu offload dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm - tp_process_group: Optional[ProcessGroup] = None, # if using tp + tp_process_group: Optional[ProcessGroup] = None, # if using tp forced_dtype: Optional[torch.dtype] = None, extra_dp_process_group: Optional[ProcessGroup] = None, master_weights: bool = True, # master weights @@ -161,7 +153,7 @@ def __init__( # and add buffers to parameter store for future access for group_id, param_group in enumerate(self.optim.param_groups): group_params = list() - for param in param_group['params']: + for param in param_group["params"]: if param.requires_grad: if self.extra_dp_pg is None: # skip moe param @@ -185,9 +177,9 @@ def __init__( if len(moe_params) > 0: param_group = dict() for key, value in self.optim.param_groups[0].items(): - if key != 'params': + if key != "params": param_group[key] = value - param_group['params'] = moe_params + param_group["params"] = moe_params self.optim.param_groups.append(param_group) # intialize communication stream for @@ -373,8 +365,10 @@ def _run_reduction(self): sync_tensor(flat_grads_per_rank[rank], grad_list) for grad in grad_list: param_id = self._bucket_store.get_param_id_of_grad(grad) - if len(self._grad_store.get_partitioned_gradients_by_param_id( - group_id, param_id)) < self._world_size: + if ( + len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) + < self._world_size + ): self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) else: self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id) @@ -384,8 +378,9 @@ def _run_reduction(self): # sync non moe param in global dp group if len(non_moe_grad_list) > 0: dist.all_reduce(non_moe_flat_grads, group=self.dp_pg) - flat_grads_per_rank = non_moe_flat_grads.split(non_moe_flat_grads.numel() // - self._world_size) + flat_grads_per_rank = non_moe_flat_grads.split( + non_moe_flat_grads.numel() // self._world_size + ) self._sync_unpartitioned_grad(non_moe_grad_list, flat_grads_per_rank, group_id) # sync moe param only in zero group @@ -413,8 +408,9 @@ def _run_reduction(self): self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id) else: if len(non_moe_grad_list) > 0: - flat_grads_list = list(non_moe_flat_grads.split( - len(non_moe_flat_grads) // self._world_size)) + flat_grads_list = list( + non_moe_flat_grads.split(len(non_moe_flat_grads) // self._world_size) + ) recieved_grad = torch.zeros_like(flat_grads_list[0]) dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg) @@ -436,12 +432,15 @@ def _run_reduction(self): recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice)) grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank] for split_recieved_grad in recieved_grad: - split_recieved_grad = _unflatten_dense_tensors(split_recieved_grad, - grad_in_bucket_current_rank) + split_recieved_grad = _unflatten_dense_tensors( + split_recieved_grad, grad_in_bucket_current_rank + ) for grad in grad_in_bucket_current_rank: param_id = self._bucket_store.get_param_id_of_grad(grad) - if len(self._grad_store.get_partitioned_gradients_by_param_id( - group_id, param_id)) < param_slice: + if ( + len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) + < param_slice + ): self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) else: self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id) @@ -581,7 +580,9 @@ def step(self, closure=None): 
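# The reduction path reformatted above splits each flattened bucket into
# world-size slices and, per parameter, either appends a fresh slice or
# accumulates into the one already stored. A self-contained sketch of that
# bookkeeping with a hypothetical store (not the library's GradientStore):
import torch

class _SliceStore:
    def __init__(self, world_size: int):
        self.world_size = world_size
        self.slices = {}  # param_id -> [tensor, ...], at most world_size entries

    def put(self, param_id: int, rank: int, grad: torch.Tensor) -> None:
        held = self.slices.setdefault(param_id, [])
        if len(held) < self.world_size:
            held.append(grad.clone())  # first bucket: keep the slice
        else:
            held[rank].add_(grad)      # later buckets: accumulate in place

store = _SliceStore(world_size=2)
store.put(param_id=0, rank=0, grad=torch.ones(4))
store.put(param_id=0, rank=1, grad=torch.ones(4))
store.put(param_id=0, rank=0, grad=torch.ones(4))  # accumulates into slot 0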
grad = grads else: param_slice = self._world_size // self.extra_dp_pg_size - grad = grads[self.extra_dp_pg_rank * param_slice:(self.extra_dp_pg_rank + 1) * param_slice] + grad = grads[ + self.extra_dp_pg_rank * param_slice : (self.extra_dp_pg_rank + 1) * param_slice + ] grad = flatten(grad) else: real_working_params[group_id].append(working_param) @@ -609,7 +610,7 @@ def step(self, closure=None): # TODO: we should store master param for ep if len(self.param_groups) > len(self._working_param_groups): - for param in self.param_groups[-1]['params']: + for param in self.param_groups[-1]["params"]: param.data = param.data.to(torch.float32) param.grad = param.grad.to(torch.float32) @@ -618,10 +619,9 @@ def step(self, closure=None): # TODO: release the moe grad. we should store master param if len(self.param_groups) > len(self._working_param_groups): - dtype = real_working_params[0][0].dtype - for param in self.param_groups[-1]['params']: + for param in self.param_groups[-1]["params"]: param.grad = None - param.data = param.data.to(dtype) + param.data = param.data.to(self._dtype) # release the grad grad_partition_groups = [] @@ -629,23 +629,23 @@ def step(self, closure=None): release_param_grad(self._master_param_groups_of_current_rank[group_id]) # update working partition updated by the current rank - # dtype = real_working_params[0][0].dtype for group_id in range(self.num_param_groups): master_working_param = self.optim.param_groups[group_id]["params"] for idx, splited_param in enumerate(master_working_param): working_param = real_working_params[group_id][idx] if self.extra_dp_pg is not None and is_moe_tensor(working_param): all_splited_param = [ - torch.zeros(splited_param.shape, device="cuda", dtype=dtype) + torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype) for _ in range(self.extra_dp_pg_size) ] - dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.extra_dp_pg) + dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.extra_dp_pg) else: all_splited_param = [ - torch.zeros(splited_param.shape, device="cuda", dtype=dtype) for _ in range(self._world_size) + torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype) + for _ in range(self._world_size) ] - dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.dp_pg) - working_param.data.copy_(flatten(all_splited_param)[:working_param.numel()].reshape_as(working_param)) + dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.dp_pg) + working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param)) self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id] def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float: @@ -775,16 +775,17 @@ def state_dict(self) -> Dict: working_param = self._param_store.master_to_working_param[id(param)] if self.extra_dp_pg is not None and is_moe_tensor(v): gather_tensor = [ - torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self.extra_dp_pg_size) + torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self.extra_dp_pg_size) ] dist.all_gather(gather_tensor, v.cuda(), group=self.extra_dp_pg) else: gather_tensor = [ - torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size) + torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size) ] dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg) - param_state = 
torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as( - working_param).cpu() + param_state = ( + torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() + ) zero_state[param][k] = param_state states_dict = self._pack_state(zero_state) @@ -808,10 +809,10 @@ def load_state_dict(self, state_dict: Dict): v = torch.nn.functional.pad(v, [0, padding_size]) if self.extra_dp_pg is not None and is_moe_tensor(v): v_list = v.split(v.numel() // self.extra_dp_pg_size) - zero_state_dict['state'][param_idx][k] = v_list[self.extra_dp_pg_rank].detach().clone() + zero_state_dict["state"][param_idx][k] = v_list[self.extra_dp_pg_rank].detach().clone() else: v_list = v.split(v.numel() // self._world_size) - zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach().clone() + zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone() self.optim.load_state_dict(zero_state_dict) @@ -841,19 +842,20 @@ def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, i working_param = self._param_store.master_to_working_param[id(master_param)] for k, v in states.items(): - if isinstance(v, torch.Tensor) and k != 'step': + if isinstance(v, torch.Tensor) and k != "step": if self.extra_dp_pg is not None and is_moe_tensor(v): state_tensor = [ - torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self.extra_dp_pg_size) + torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self.extra_dp_pg_size) ] dist.all_gather(state_tensor, v.cuda(), group=self.extra_dp_pg) else: state_tensor = [ - torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size) + torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size) ] dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg) - state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as( - working_param).cpu() + state_tensor = ( + torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() + ) current_block_size += state_tensor.numel() current_block[k] = state_tensor diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py index 2e116de2db7d..40adeab717de 100644 --- a/tests/test_moe/moe_utils.py +++ b/tests/test_moe/moe_utils.py @@ -4,7 +4,6 @@ from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce -from colossalai.legacy.nn import CheckpointModule from colossalai.legacy.registry import GRADIENT_HANDLER from colossalai.moe import SparseMLP from colossalai.moe.manager import MOE_MANAGER @@ -13,20 +12,16 @@ class MoeModel(nn.Module): - - def __init__(self, checkpoint: bool = False, enable_load_balance: bool = False): - - class TestSubModule(CheckpointModule): - + def __init__(self, enable_load_balance: bool = False): + class TestSubModule(nn.Module): def __init__(self): - super().__init__(checkpoint) - self.moe = SparseMLP(num_experts=8, - hidden_size=16, - intermediate_size=32, - enable_load_balance=enable_load_balance) + super().__init__() + self.moe = SparseMLP( + num_experts=8, hidden_size=16, intermediate_size=32, enable_load_balance=enable_load_balance + ) self.proj = nn.Linear(16, 4) - def _forward(self, x): + def forward(self, x): x = self.moe(x) x = self.proj(x) return x @@ -76,8 +71,9 @@ def handle_gradient(self): for ep_size in epsize_param_dict: if ep_size != 1 and ep_size != MOE_MANAGER.world_size: - 
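# The `torch.stack(...).view(-1)[: numel].reshape_as(...)` pattern above is the
# inverse of the zero-padding applied when sharding. A standalone sketch of the
# round trip, without the distributed collectives (helper names are illustrative):
import torch

def shard(param: torch.Tensor, world_size: int):
    flat = param.reshape(-1)
    pad = (-flat.numel()) % world_size  # pad to a multiple of world_size
    flat = torch.nn.functional.pad(flat, [0, pad])
    return list(flat.split(flat.numel() // world_size))

def unshard(shards, like: torch.Tensor) -> torch.Tensor:
    return torch.cat(shards).view(-1)[: like.numel()].reshape_as(like)

p = torch.arange(10.0).reshape(2, 5)
assert torch.equal(unshard(shard(p, 4), p), p)  # padding is dropped on the way back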
bucket_allreduce(param_list=epsize_param_dict[ep_size],
-                                 group=MOE_MANAGER.parallel_info_dict[ep_size].dp_group)
+                bucket_allreduce(
+                    param_list=epsize_param_dict[ep_size], group=MOE_MANAGER.parallel_info_dict[ep_size].dp_group
+                )
 
 
 def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
@@ -130,8 +126,9 @@ def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_
         local_model (MoeModule)
         ep_model (MoeModule)
     """
-    for (local_name, local_param), (ep_name, ep_param) in zip(local_model.named_parameters(),
-                                                              ep_model.named_parameters()):
+    for (local_name, local_param), (ep_name, ep_param) in zip(
+        local_model.named_parameters(), ep_model.named_parameters()
+    ):
         assert local_name == ep_name
         if "experts" not in local_name:
             if assert_grad_flag:
@@ -168,4 +165,5 @@ def assert_not_equal_in_group(tensor, process_group=None):
         a = tensor_list[i]
         b = tensor_list[i + 1]
         assert not torch.allclose(
-            a, b), f'expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}'
+            a, b
+        ), f"expected tensors on rank {i} and {i + 1} to differ, but they are equal: {a} vs {b}"
diff --git a/tests/test_moe/test_moe_hybrid_zero.py b/tests/test_moe/test_moe_hybrid_zero.py
index a2b8efb0e2dc..e3f093f7461e 100644
--- a/tests/test_moe/test_moe_hybrid_zero.py
+++ b/tests/test_moe/test_moe_hybrid_zero.py
@@ -35,21 +35,24 @@ def run_zero_optim_test(local_rank, world_size, stage=1):
     label = torch.randint(0, 4, (16,)).cuda()
 
     MOE_MANAGER.setup(seed=42, parallel=None)
-    torch_model = MoeModel(checkpoint=True)
+    torch_model = MoeModel()
     torch_optimizer = torch.optim.Adam(torch_model.parameters())
     torch_model = torch_model.cuda()
 
     MOE_MANAGER.__init__()
     MOE_MANAGER.setup(seed=42, max_ep_size=2, use_ep_inside=False, parallel="EP")
-    zero_model = MoeModel(checkpoint=True)
+    zero_model = MoeModel()
     extra_dp_group = MOE_MANAGER.parallel_info_dict[2].dp_group
     ep_rank = dist.get_rank(MOE_MANAGER.parallel_info_dict[2].ep_group)
     ep_size = MOE_MANAGER.parallel_info_dict[2].ep_size
     for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
         if is_moe_tensor(zero_param):
             num_expert = torch_param.data.shape[0]
-            zero_param.data.copy_(torch_param.data[ep_rank * (num_expert // ep_size):(ep_rank + 1) *
-                                                   (num_expert // ep_size)].detach().clone())
+            zero_param.data.copy_(
+                torch_param.data[ep_rank * (num_expert // ep_size) : (ep_rank + 1) * (num_expert // ep_size)]
+                .detach()
+                .clone()
+            )
         else:
             zero_param.data.copy_(torch_param.data.detach().clone())
     zero_optimizer = torch.optim.Adam(zero_model.parameters())
@@ -63,18 +66,21 @@ def run_zero_optim_test(local_rank, world_size, stage=1):
         run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
         zero_optimizer.step()
 
-        for (torch_name, torch_param), (zero_name, zero_param) in zip(torch_model.named_parameters(),
-                                                                      zero_model.named_parameters()):
+        for (torch_name, torch_param), (zero_name, zero_param) in zip(
+            torch_model.named_parameters(), zero_model.named_parameters()
+        ):
             if is_moe_tensor(zero_param):
                 num_expert = torch_param.data.shape[0]
-                torch_param.data = torch_param.data[ep_rank * (num_expert // ep_size):(ep_rank + 1) *
-                                                    (num_expert // ep_size)]
-            assert torch.allclose(torch_param.data, zero_param.data,
-                                  atol=1e-4), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}"
+                torch_param.data = torch_param.data[
+                    ep_rank * (num_expert // ep_size) : (ep_rank + 1) * (num_expert // ep_size)
+                ]
+            assert torch.allclose(
+                torch_param.data,
zero_param.data, atol=1e-4 + ), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}" def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_zero_optim_test(rank, world_size, stage=1) @@ -85,5 +91,5 @@ def test_moe_zero_optim(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_zero_optim(world_size=4) diff --git a/tests/test_moe/test_moe_load_balance.py b/tests/test_moe/test_moe_load_balance.py index 5126c61ae92f..4daad7949a87 100644 --- a/tests/test_moe/test_moe_load_balance.py +++ b/tests/test_moe/test_moe_load_balance.py @@ -48,7 +48,7 @@ def run_zero_optim_test(local_rank, world_size, stage=1): seed=42, parallel="EP", ) - zero_model = MoeModel(checkpoint=True, enable_load_balance=True) + zero_model = MoeModel(enable_load_balance=True) zero_optimizer = torch.optim.Adam(zero_model.parameters()) plugin = LowLevelZeroPlugin(stage=stage, precision="bf16", verbose=True) booster = Booster(plugin=plugin) @@ -56,7 +56,7 @@ def run_zero_optim_test(local_rank, world_size, stage=1): MOE_MANAGER.__init__() MOE_MANAGER.setup(seed=42, parallel="EP") - torch_model = MoeModel(checkpoint=True) + torch_model = MoeModel() for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): torch_param.data.copy_(zero_param.data) torch_optimizer = torch.optim.Adam(torch_model.parameters()) @@ -104,7 +104,7 @@ def run_hybrid_zero_optim_test(local_rank, world_size, stage=1): MOE_MANAGER.__init__() MOE_MANAGER.setup(seed=42, parallel=None) - torch_model = MoeModel(checkpoint=True) + torch_model = MoeModel() torch_optimizer = torch.optim.Adam(torch_model.parameters()) torch_model = torch_model.cuda() @@ -115,15 +115,18 @@ def run_hybrid_zero_optim_test(local_rank, world_size, stage=1): use_ep_inside=False, parallel="EP", ) - zero_model = MoeModel(checkpoint=True, enable_load_balance=True) + zero_model = MoeModel(enable_load_balance=True) extra_dp_group = MOE_MANAGER.parallel_info_dict[2].dp_group ep_rank = dist.get_rank(MOE_MANAGER.parallel_info_dict[2].ep_group) ep_size = MOE_MANAGER.parallel_info_dict[2].ep_size for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): if is_moe_tensor(zero_param): num_expert = torch_param.data.shape[0] - zero_param.data.copy_(torch_param.data[ep_rank * (num_expert // ep_size):(ep_rank + 1) * - (num_expert // ep_size)].detach().clone()) + zero_param.data.copy_( + torch_param.data[ep_rank * (num_expert // ep_size) : (ep_rank + 1) * (num_expert // ep_size)] + .detach() + .clone() + ) else: zero_param.data.copy_(torch_param.data.detach().clone()) zero_optimizer = torch.optim.Adam(zero_model.parameters()) diff --git a/tests/test_moe/test_moe_zero_fwd_bwd.py b/tests/test_moe/test_moe_zero_fwd_bwd.py index 499d65f0072a..8f046ab00d59 100644 --- a/tests/test_moe/test_moe_zero_fwd_bwd.py +++ b/tests/test_moe/test_moe_zero_fwd_bwd.py @@ -41,21 +41,22 @@ def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False) def run_zero_test(local_rank, world_size, stage=1): criterion = torch.nn.CrossEntropyLoss() - zero_model = MoeModel(checkpoint=True) + zero_model = MoeModel() optimizer = torch.optim.Adam(zero_model.parameters()) plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") booster = Booster(plugin=plugin) zero_model, 
optimizer, _, _, _ = booster.boost(zero_model, optimizer) - torch_model = MoeModel(checkpoint=True) + torch_model = MoeModel() for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): torch_param.data.copy_(zero_param.data) torch_model = torch_model.cuda() grad_handler = MoeGradientHandler(torch_model) # assert zero model - for (torch_name, torch_param), (zero_name, zero_param) in zip(torch_model.named_parameters(), - zero_model.module.named_parameters()): + for (torch_name, torch_param), (zero_name, zero_param) in zip( + torch_model.named_parameters(), zero_model.module.named_parameters() + ): assert zero_name == torch_name assert torch.allclose(zero_param.data, torch_param.data) @@ -67,8 +68,9 @@ def run_zero_test(local_rank, world_size, stage=1): assert torch.allclose(torch_out, zero_out) grad_handler.handle_gradient() - for (zero_name, zero_param), (torch_name, torch_param) in zip(zero_model.module.named_parameters(), - torch_model.named_parameters()): + for (zero_name, zero_param), (torch_name, torch_param) in zip( + zero_model.module.named_parameters(), torch_model.named_parameters() + ): assert zero_name == torch_name zero_grad_list = optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param)) if hasattr(zero_param, "moe_info"): @@ -78,14 +80,14 @@ def run_zero_test(local_rank, world_size, stage=1): assert len(zero_grad_list) > 0 torch_grad_list = split_ddp_grad(torch_param.grad, world_size) if stage == 2: - torch_grad_list = torch_grad_list[local_rank:local_rank + 1] + torch_grad_list = torch_grad_list[local_rank : local_rank + 1] assert len(zero_grad_list) == len(torch_grad_list) for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list): assert torch.allclose(zero_grad, torch_grad) def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") MOE_MANAGER.setup(seed=42, parallel="EP") seed_all(42 + rank) run_zero_test(rank, world_size, stage=1) @@ -99,5 +101,5 @@ def test_moe_zero_model(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_zero_model(world_size=2) diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py index 8f4d89f17330..ebea7509f6dc 100644 --- a/tests/test_moe/test_moe_zero_optim.py +++ b/tests/test_moe/test_moe_zero_optim.py @@ -40,13 +40,13 @@ def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False) def run_zero_optim_test(local_rank, world_size, stage=1): criterion = torch.nn.CrossEntropyLoss() - zero_model = MoeModel(checkpoint=True) + zero_model = MoeModel() zero_optimizer = torch.optim.Adam(zero_model.parameters()) plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") booster = Booster(plugin=plugin) zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer) - torch_model = MoeModel(checkpoint=True) + torch_model = MoeModel() for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): torch_param.data.copy_(zero_param.data) torch_optimizer = torch.optim.Adam(torch_model.parameters()) @@ -63,18 +63,19 @@ def run_zero_optim_test(local_rank, world_size, stage=1): torch_optimizer.step() zero_optimizer.step() - for (torch_name, torch_param), (zero_name, zero_param) in zip(torch_model.named_parameters(), - zero_model.named_parameters()): + for 
(torch_name, torch_param), (zero_name, zero_param) in zip( + torch_model.named_parameters(), zero_model.named_parameters() + ): assert torch.allclose( - torch_param.data, - zero_param.data), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}" + torch_param.data, zero_param.data + ), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}" torch_optimizer.zero_grad() zero_optimizer.zero_grad() def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") MOE_MANAGER.setup(seed=42, parallel="EP") run_zero_optim_test(rank, world_size, stage=1) run_zero_optim_test(rank, world_size, stage=2) @@ -87,5 +88,5 @@ def test_moe_zero_optim(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_zero_optim(world_size=2) From 713446bd8b47e919dead0df74ef6733c04b7d3fd Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Thu, 26 Oct 2023 13:55:40 +0800 Subject: [PATCH 27/46] delete context --- colossalai/context/moe_context.py | 132 ------------------------------ 1 file changed, 132 deletions(-) delete mode 100644 colossalai/context/moe_context.py diff --git a/colossalai/context/moe_context.py b/colossalai/context/moe_context.py deleted file mode 100644 index 510b05278c56..000000000000 --- a/colossalai/context/moe_context.py +++ /dev/null @@ -1,132 +0,0 @@ -from typing import Tuple - -import torch -import torch.distributed as dist - -from colossalai.context.parallel_mode import ParallelMode -from colossalai.context.singleton_meta import SingletonMeta -from colossalai.tensor import ProcessGroup - - -def _check_sanity(): - from colossalai.core import global_context as gpc - if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1: - raise NotImplementedError("Moe is not compatible with tensor or " - "pipeline parallel at present.") - - -class MoeParallelInfo: - """Moe parallelism information, storing parallel sizes and groups. - """ - - def __init__(self, ep_size: int, dp_size: int): - _check_sanity() - self.ep_size = ep_size - self.dp_size = dp_size - self.pg = ProcessGroup(tp_degree=ep_size, dp_degree=dp_size) - self.ep_group = self.pg.tp_process_group() - self.dp_group = self.pg.dp_process_group() - - -class MoeContext(metaclass=SingletonMeta): - """MoE parallel context manager. This class manages different - parallel groups in MoE context and MoE loss in training. 
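# A sketch of the factorization the MoeContext removed below performed; the
# same rule lives on in MOE_MANAGER. This simplification ignores the
# min_dp_size multiplier, so names and return shape are illustrative only:
def factor_experts(num_experts: int, max_ep_size: int):
    """Return (num_local_experts, ep_size, dp_size)."""
    if num_experts % max_ep_size == 0:   # more experts than the EP limit
        return num_experts // max_ep_size, max_ep_size, 1
    if max_ep_size % num_experts == 0:   # fewer experts: replicate across DP
        dp_size = max_ep_size // num_experts
        return 1, max_ep_size // dp_size, dp_size
    raise ValueError("num_experts and max_ep_size must divide one another")

assert factor_experts(16, 4) == (4, 4, 1)  # 4 local experts per rank
assert factor_experts(2, 4) == (1, 2, 2)   # each expert replicated on 2 ranks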
- """ - - def __init__(self): - self.world_size = 1 - # Users may want to set maximum expert parallel size smaller than the world size - # since very low bandwidth across nodes may constrain the performance of MoE - # When we have a maximum expert parallel size, we have a minimum data parallel size naturally - self.max_ep_size = 1 - self.min_dp_size = 1 - self.aux_loss = None - self.use_kernel_optim = True - - self.has_setup = False - self._parallel_info_dict = dict() - - @property - def parallel_info_dict(self): - return self._parallel_info_dict - - @property - def is_initialized(self): - return self.has_setup - - def setup(self, seed: int, use_kernel_optim: bool = True): - assert not self.is_initialized, "MoE distributed context shouldn't be set up again" - _check_sanity() - assert torch.cuda.is_available(), "MoE requires to enable CUDA first" - - self.world_size = dist.get_world_size() - - from colossalai.core import global_context as gpc - self.max_ep_size = gpc.config.get('max_ep_size', self.world_size) - assert self.world_size % self.max_ep_size == 0, \ - "Maximum expert parallel size must be a factor of the number of GPUs" - self.min_dp_size = self.world_size // self.max_ep_size - - # Enabling kernel optimization may raise error in some cases - # Users can close kernel optimization manually - self.use_kernel_optim = use_kernel_optim - - from .random import moe_set_seed - - moe_set_seed(seed) - self.has_setup = True - - def get_info(self, num_experts: int) -> Tuple[int, MoeParallelInfo]: - """Calculate the Data Parallel Group and Expert Parallel Group. - - Parameters - ---------- - num_experts : int - The number experts - - Returns - ------- - int, MoeParallelInfo - number of local experts, the MoeParallelInfo of the current ep_size - """ - - gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater - lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less - - assert gt_flag or lt_flag, ( - "Automatic experts placement dose not not support expert number" - " is not a multiple of ep size or vice versa." - ) - - # If the number of experts is greater than maximum expert parallel size. 
a.k.a ep_size, - # there are multiple experts in each GPU and each GPU has different experts - # So it's data parallel size is 1 - # Otherwise, there is only one expert in each GPU - # The data parallel size should be calculated - dp_size = 1 if gt_flag else self.max_ep_size // num_experts - ep_size = self.max_ep_size // dp_size - - # Calculate the number of experts for each GPU - num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size - - # Don't forget to multiply minimum data parallel size - dp_size *= self.min_dp_size - if not (ep_size in self.parallel_info_dict): - self.parallel_info_dict[ep_size] = MoeParallelInfo(ep_size, dp_size) - - return num_local_experts, self.parallel_info_dict[ep_size] - - def set_kernel_not_use(self): - self.use_kernel_optim = False - - def reset_loss(self): - self.aux_loss = 0 - - def add_loss(self, loss): - self.aux_loss += loss - - def get_loss(self): - return self.aux_loss - - -MOE_CONTEXT = MoeContext() From 1b19a5fb7d90b2fa6cf6c056c985c9a7b85e4eee Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Thu, 26 Oct 2023 13:57:41 +0800 Subject: [PATCH 28/46] remove moe --- colossalai/nn/layer/__init__.py | 1 - colossalai/nn/layer/moe/__init__.py | 12 ------------ 2 files changed, 13 deletions(-) delete mode 100644 colossalai/nn/layer/moe/__init__.py diff --git a/colossalai/nn/layer/__init__.py b/colossalai/nn/layer/__init__.py index 9aeab9f44a6d..16281fe0b66d 100644 --- a/colossalai/nn/layer/__init__.py +++ b/colossalai/nn/layer/__init__.py @@ -1,2 +1 @@ -# from .moe import * from .utils import * diff --git a/colossalai/nn/layer/moe/__init__.py b/colossalai/nn/layer/moe/__init__.py deleted file mode 100644 index 5280acf8dee7..000000000000 --- a/colossalai/nn/layer/moe/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -MoeModule = None -MoeLayer = None -build_ffn_experts = None -EPMLPExperts = None -TPMLPExperts = None -Top1Router = None -Top2Router = None -NormalNoiseGenerator = None -UniformNoiseGenerator = None -SparseMLP = None -MoeRouter = None -MoeCheckpintIO = None From ca42bf421d99556e21814cd4f0d2a800bbb66b69 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Thu, 26 Oct 2023 14:15:22 +0800 Subject: [PATCH 29/46] fix bugs --- .../plugin/moe_hybrid_parallel_plugin.py | 320 ++++++++++-------- examples/language/openmoe/train.py | 13 +- 2 files changed, 179 insertions(+), 154 deletions(-) diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 5171780da347..5c3aa12c0c2c 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -1,4 +1,5 @@ import random +from types import MethodType from typing import Callable, Optional, OrderedDict, Tuple import numpy as np @@ -32,38 +33,55 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): - def __init__( - self, - optimizer: Optimizer, - model: Module, - use_pipeline: bool, - param_info: OrderedDict, - initial_scale: int = 2**16, # grad scaler config - min_scale: int = 1, - growth_factor: float = 2., - backoff_factor: float = .5, - growth_interval: int = 2000, - hysteresis: int = 2, - max_scale: int = 2**24, - clip_grad_norm: float = 0.0, # grad clipping - verbose: bool = False, - reduce_bucket_size: int = 1024 * 1024, # communication - communication_dtype: Optional[torch.dtype] = None, - overlap_communication: bool = True, - partition_grad: bool = False, # stage 2 flag - cpu_offload: bool = False, # cpu offload - dp_process_group: Optional[ProcessGroup] 
= None, # the dp pg for comm - tp_process_group: Optional[ProcessGroup] = None, # if using tp - forced_dtype: Optional[torch.dtype] = None, - extra_dp_process_group: Optional[ProcessGroup] = None): + self, + optimizer: Optimizer, + model: Module, + use_pipeline: bool, + param_info: OrderedDict, + initial_scale: int = 2**16, # grad scaler config + min_scale: int = 1, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000, + hysteresis: int = 2, + max_scale: int = 2**24, + clip_grad_norm: float = 0.0, # grad clipping + verbose: bool = False, + reduce_bucket_size: int = 1024 * 1024, # communication + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + partition_grad: bool = False, # stage 2 flag + cpu_offload: bool = False, # cpu offload + dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm + tp_process_group: Optional[ProcessGroup] = None, # if using tp + forced_dtype: Optional[torch.dtype] = None, + extra_dp_process_group: Optional[ProcessGroup] = None, + ): self.param_info = param_info if use_pipeline: init_pipeline_optimizer(optimizer, model) - super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, - hysteresis, max_scale, clip_grad_norm, verbose, reduce_bucket_size, communication_dtype, - overlap_communication, partition_grad, cpu_offload, dp_process_group, tp_process_group, - forced_dtype, extra_dp_process_group) + super().__init__( + optimizer, + initial_scale, + min_scale, + growth_factor, + backoff_factor, + growth_interval, + hysteresis, + max_scale, + clip_grad_norm, + verbose, + reduce_bucket_size, + communication_dtype, + overlap_communication, + partition_grad, + cpu_offload, + dp_process_group, + tp_process_group, + forced_dtype, + extra_dp_process_group, + ) class MoeHybridParallelPlugin(HybridParallelPlugin): @@ -123,51 +141,51 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True. 
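    Example (a sketch with illustrative argument values, mirroring the ``ep_zero``
    branch of ``examples/language/openmoe/train.py`` rather than a tested recipe):

        >>> from colossalai.booster import Booster
        >>> from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
        >>> plugin = MoeHybridParallelPlugin(tp_size=1, pp_size=1, zero_stage=1, precision="bf16")
        >>> booster = Booster(plugin=plugin)
        >>> model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)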
""" - def __init__(self, - tp_size: int, - pp_size: int, - extra_dp_size: int = 1, - precision: str = 'fp16', - zero_stage: int = 0, - enable_all_optimization: bool = False, - enable_fused_normalization: bool = False, - enable_flash_attention: bool = False, - enable_jit_fused: bool = False, - enable_sequence_parallelism: bool = False, - enable_sequence_overlap: bool = False, - num_microbatches: Optional[int] = None, - microbatch_size: Optional[int] = None, - initial_scale: float = 2**16, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32, - max_norm: float = 0, - broadcast_buffers: bool = True, - ddp_bucket_cap_mb: int = 25, - find_unused_parameters: bool = False, - check_reduction: bool = False, - gradient_as_bucket_view: bool = False, - static_graph: bool = False, - zero_bucket_size_in_m: int = 12, - cpu_offload: bool = False, - communication_dtype: Optional[torch.dtype] = None, - overlap_communication: bool = True, - use_ep_inside: bool = True, - custom_policy: Policy = None) -> None: - - super().__init__(tp_size=tp_size, - pp_size=pp_size, - num_microbatches=num_microbatches, - microbatch_size=microbatch_size) - assert dist.get_world_size() % ( - tp_size * pp_size - ) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}' + def __init__( + self, + tp_size: int, + pp_size: int, + extra_dp_size: int = 1, + precision: str = "fp16", + zero_stage: int = 0, + enable_all_optimization: bool = False, + enable_fused_normalization: bool = False, + enable_flash_attention: bool = False, + enable_jit_fused: bool = False, + enable_sequence_parallelism: bool = False, + enable_sequence_overlap: bool = False, + num_microbatches: Optional[int] = None, + microbatch_size: Optional[int] = None, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0, + broadcast_buffers: bool = True, + ddp_bucket_cap_mb: int = 25, + find_unused_parameters: bool = False, + check_reduction: bool = False, + gradient_as_bucket_view: bool = False, + static_graph: bool = False, + zero_bucket_size_in_m: int = 12, + cpu_offload: bool = False, + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + use_ep_inside: bool = True, + custom_policy: Policy = None, + ) -> None: + super().__init__( + tp_size=tp_size, pp_size=pp_size, num_microbatches=num_microbatches, microbatch_size=microbatch_size + ) + assert ( + dist.get_world_size() % (tp_size * pp_size) == 0 + ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" if enable_sequence_parallelism: - assert tp_size > 1, 'Sequence parallelism must be enabled when using tensor parallelism' + assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism" self.tp_size = tp_size self.pp_size = pp_size @@ -204,24 +222,28 @@ def __init__(self, self.custom_policy = custom_policy assert zero_stage in (0, 1, 2) if self.pp_size > 1: - assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism' - assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism' + assert ( + num_microbatches is not None or microbatch_size is not None + ), "num_microbatches or 
microbatch_size must be specified when using pipeline parallelism" + assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism" self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS) - self.schedule = OneForwardOneBackwardSchedule(self.stage_manager, - num_microbatches=num_microbatches, - microbatch_size=microbatch_size) + self.schedule = OneForwardOneBackwardSchedule( + self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size + ) self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) - self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group, - pipeline_stage_manager=self.stage_manager, - enable_tensor_parallelism=self.tp_size > 1, - enable_all_optimization=self.enable_all_optimization, - enable_fused_normalization=self.enable_fused_normalization, - enable_flash_attention=self.enable_flash_attention, - enable_jit_fused=self.enable_jit_fused, - enable_sequence_parallelism=enable_sequence_parallelism, - enable_sequence_overlap=enable_sequence_overlap) + self.shard_config = ShardConfig( + tensor_parallel_process_group=self.tp_group, + pipeline_stage_manager=self.stage_manager, + enable_tensor_parallelism=self.tp_size > 1, + enable_all_optimization=self.enable_all_optimization, + enable_fused_normalization=self.enable_fused_normalization, + enable_flash_attention=self.enable_flash_attention, + enable_jit_fused=self.enable_jit_fused, + enable_sequence_parallelism=enable_sequence_parallelism, + enable_sequence_overlap=enable_sequence_overlap, + ) self.amp_config = dict( initial_scale=initial_scale, growth_factor=growth_factor, @@ -232,30 +254,28 @@ def __init__(self, max_scale=max_scale, ) - self.ddp_config = dict(broadcast_buffers=broadcast_buffers, - bucket_cap_mb=ddp_bucket_cap_mb, - find_unused_parameters=find_unused_parameters, - check_reduction=check_reduction, - gradient_as_bucket_view=gradient_as_bucket_view, - static_graph=static_graph) + self.ddp_config = dict( + broadcast_buffers=broadcast_buffers, + bucket_cap_mb=ddp_bucket_cap_mb, + find_unused_parameters=find_unused_parameters, + check_reduction=check_reduction, + gradient_as_bucket_view=gradient_as_bucket_view, + static_graph=static_graph, + ) - self.zero_config = dict(reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024, - communication_dtype=communication_dtype, - overlap_communication=overlap_communication, - cpu_offload=cpu_offload, - partition_grad=(self.zero_stage == 2)) + self.zero_config = dict( + reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + cpu_offload=cpu_offload, + partition_grad=(self.zero_stage == 2), + ) self.max_norm = max_norm - def prepare_dataloader(self, - dataset, - batch_size, - shuffle=False, - seed=1024, - drop_last=False, - pin_memory=False, - num_workers=0, - **kwargs): + def prepare_dataloader( + self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs + ): r""" Prepare a dataloader for distributed training. The dataloader will be wrapped by `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. @@ -278,10 +298,9 @@ def prepare_dataloader(self, :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. 
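        Example (a sketch; ``dataset`` is any map-style torch dataset, and the
        per-epoch ``set_epoch`` call is assumed to be the caller's responsibility):

            >>> dataloader = plugin.prepare_dataloader(dataset, batch_size=8, shuffle=True, seed=1024)
            >>> for epoch in range(num_epochs):
            ...     dataloader.sampler.set_epoch(epoch)  # reshuffle identically on every rank
            ...     for batch in dataloader:
            ...         ...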
""" _kwargs = kwargs.copy() - sampler = DistributedSampler(dataset, - num_replicas=self.pg_mesh.size(DP_AXIS), - rank=self.pg_mesh.coordinate(DP_AXIS), - shuffle=shuffle) + sampler = DistributedSampler( + dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle + ) # Deterministic dataloader def seed_worker(worker_id): @@ -290,14 +309,16 @@ def seed_worker(worker_id): torch.manual_seed(worker_seed) random.seed(worker_seed) - return DataLoader(dataset, - batch_size=batch_size, - sampler=sampler, - worker_init_fn=seed_worker, - drop_last=drop_last, - pin_memory=pin_memory, - num_workers=num_workers, - **_kwargs) + return DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) def get_checkpoint_io(self) -> MoeCheckpintIO: self.checkpoint_io = MoeCheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) @@ -314,40 +335,45 @@ def configure( param_info = get_param_info(optimizer) if not isinstance(model, ModelWrapper): use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 - model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp, - self.ddp_config, self.custom_policy) + model = HybridParallelModule( + model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy + ) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if self.zero_stage == 0: - if self.precision in ['fp16', 'bf16']: - optimizer = HybridParallelAMPOptimizer(optimizer, - model, - use_pipeline=self.enable_pipeline_parallelism, - param_info=param_info, - precision=self.precision, - max_norm=self.max_norm, - **self.amp_config) - self.checkpoint_io.link_master_and_working_param(optimizer.working_to_master_map, - optimizer.master_to_working_map) + if self.precision in ["fp16", "bf16"]: + optimizer = HybridParallelAMPOptimizer( + optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + precision=self.precision, + max_norm=self.max_norm, + **self.amp_config, + ) + self.checkpoint_io.link_master_and_working_param( + optimizer.working_to_master_map, optimizer.master_to_working_map + ) else: - optimizer = HybridParallelNaiveOptimizer(optimizer, - model, - use_pipeline=self.enable_pipeline_parallelism, - param_info=param_info) + optimizer = HybridParallelNaiveOptimizer( + optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info + ) else: assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1." - assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO." - optimizer = HybridParallelZeroOptimizer(optimizer, - model, - use_pipeline=self.enable_pipeline_parallelism, - param_info=param_info, - dp_process_group=self.dp_group, - tp_process_group=self.tp_group, - extra_dp_process_group=self.extra_dp_group, - verbose=True, - clip_grad_norm=self.max_norm, - **self.zero_config, - **self.amp_config) - self.checkpoint_io.link_master_and_working_param(optimizer._param_store.working_to_master_param, - optimizer._param_store.master_to_working_param) + assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO." 
+ optimizer = HybridParallelZeroOptimizer( + optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + dp_process_group=self.dp_group, + tp_process_group=self.tp_group, + extra_dp_process_group=self.extra_dp_group, + verbose=True, + clip_grad_norm=self.max_norm, + **self.zero_config, + **self.amp_config, + ) + # inject update_master_params + model.update_master_params = MethodType(optimizer.update_master_params, model) return model, optimizer, criterion, dataloader, lr_scheduler diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index 19bc70e1c4f5..8b2e0e833b58 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -1,3 +1,4 @@ +import argparse import os import datasets @@ -13,7 +14,6 @@ from transformers.models.llama import LlamaConfig import colossalai -from colossalai import get_default_parser from colossalai.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.cluster import DistCoordinator @@ -42,7 +42,6 @@ def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster): class RandomDataset(Dataset): - def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000, tokenizer=None): """ A random dataset @@ -86,7 +85,7 @@ def __getitem__(self, idx): def parse_args(): # basic settings - parser = get_default_parser() + parser = argparse.ArgumentParser() parser.add_argument( "--model_name", type=str, @@ -288,7 +287,7 @@ def main(): model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader) if not test_mode: load_ckpt(repo_name, model, booster) - use_pipeline = (isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1) + use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1 is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() logger.info(f"Finish init booster", ranks=[0]) @@ -299,9 +298,9 @@ def main(): train_dataloader_iter = iter(dataloader) total_len = len(train_dataloader_iter) with tqdm( - range(total_len), - desc=f"Epoch [{epoch + 1}/{args.num_epoch}]", - disable=not coordinator.is_master(), + range(total_len), + desc=f"Epoch [{epoch + 1}/{args.num_epoch}]", + disable=not coordinator.is_master(), ) as pbar: for step in pbar: if use_pipeline: From c381e4c481df5491a3b8c8e291b66cd870765a88 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Thu, 26 Oct 2023 15:55:32 +0800 Subject: [PATCH 30/46] update timeout temporarily --- .github/workflows/build_on_pr.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index e2114d43bcd0..08a351fc3984 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -142,7 +142,7 @@ jobs: container: image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny - timeout-minutes: 60 + timeout-minutes: 120 defaults: run: shell: bash From b19fb917b6f0718cc8845da09a93cf8071089b41 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Thu, 26 Oct 2023 17:50:13 +0800 Subject: [PATCH 31/46] resume time --- .github/workflows/build_on_pr.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index 
08a351fc3984..e2114d43bcd0 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -142,7 +142,7 @@ jobs: container: image: hpcaitech/pytorch-cuda:1.12.0-11.3.0 options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny - timeout-minutes: 120 + timeout-minutes: 60 defaults: run: shell: bash From 61df786e358673b38d2496734affd519b98c5df2 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Sun, 29 Oct 2023 00:33:19 +0800 Subject: [PATCH 32/46] fix bug --- colossalai/zero/low_level/low_level_optim.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 2dfe92d517c5..a7b715128d49 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -255,6 +255,8 @@ def _create_master_param_current_rank(self, param_list): # use fp32 when master_weights is True if self._master_weights is True: splited_param_current_rank = splited_params.detach().float().to(device) + else: + splited_param_current_rank = splited_params params_current_rank.append(splited_param_current_rank) self._param_store.link_master_and_working_param(splited_param_current_rank, param) From 685c80a002dde5aeafcdc60b4da66fc687bfc4b0 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Sun, 29 Oct 2023 01:08:44 +0800 Subject: [PATCH 33/46] remove tp --- colossalai/zero/low_level/low_level_optim.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index a7b715128d49..f27bd88fa8de 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -76,7 +76,6 @@ def __init__( partition_grad: bool = False, # stage 2 flag cpu_offload: bool = False, # cpu offload dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm - tp_process_group: Optional[ProcessGroup] = None, # if using tp forced_dtype: Optional[torch.dtype] = None, extra_dp_process_group: Optional[ProcessGroup] = None, master_weights: bool = True, # master weights @@ -109,8 +108,6 @@ def __init__( self.extra_dp_pg_size = dist.get_world_size(group=self.extra_dp_pg) self.extra_dp_pg_rank = dist.get_rank(group=self.extra_dp_pg) - self.tp_pg = tp_process_group - # working and master params for mixed precision training self._working_param_groups = dict() self._master_param_groups_of_current_rank = dict() From 9586f6196507a7f1e0331f452dd15179f9068ecd Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Sun, 29 Oct 2023 01:30:55 +0800 Subject: [PATCH 34/46] use kwargs --- .../plugin/moe_hybrid_parallel_plugin.py | 37 +++++++++---------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 5c3aa12c0c2c..110456f98167 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -62,25 +62,24 @@ def __init__( if use_pipeline: init_pipeline_optimizer(optimizer, model) super().__init__( - optimizer, - initial_scale, - min_scale, - growth_factor, - backoff_factor, - growth_interval, - hysteresis, - max_scale, - clip_grad_norm, - verbose, - reduce_bucket_size, - communication_dtype, - overlap_communication, - partition_grad, - cpu_offload, - dp_process_group, - tp_process_group, - forced_dtype, - 
extra_dp_process_group, + optimizer=optimizer, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + clip_grad_norm=clip_grad_norm, + verbose=verbose, + reduce_bucket_size=reduce_bucket_size, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + partition_grad=partition_grad, + cpu_offload=cpu_offload, + dp_process_group=dp_process_group, + forced_dtype=forced_dtype, + extra_dp_process_group=extra_dp_process_group, ) From 6c0094ccb1b4c32349236ca7b69fcf78703aff86 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Mon, 30 Oct 2023 10:41:28 +0800 Subject: [PATCH 35/46] polish and align with main --- .../plugin/moe_hybrid_parallel_plugin.py | 22 +++--- colossalai/zero/low_level/low_level_optim.py | 72 ++++++++++--------- tests/test_moe/test_moe_hybrid_zero.py | 2 +- 3 files changed, 51 insertions(+), 45 deletions(-) diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 110456f98167..b67642b0d2e3 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -55,10 +55,16 @@ def __init__( cpu_offload: bool = False, # cpu offload dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm tp_process_group: Optional[ProcessGroup] = None, # if using tp + pp_process_group: Optional[ProcessGroup] = None, forced_dtype: Optional[torch.dtype] = None, - extra_dp_process_group: Optional[ProcessGroup] = None, + moe_extra_dp_process_group: Optional[ProcessGroup] = None, ): self.param_info = param_info + self.stage_manager = model.stage_manager + self.shared_params = model.shared_params + self.dp_pg = dp_process_group + self.tp_pg = tp_process_group + self.pp_pg = pp_process_group if use_pipeline: init_pipeline_optimizer(optimizer, model) super().__init__( @@ -79,7 +85,7 @@ def __init__( cpu_offload=cpu_offload, dp_process_group=dp_process_group, forced_dtype=forced_dtype, - extra_dp_process_group=extra_dp_process_group, + moe_extra_dp_process_group=moe_extra_dp_process_group, ) @@ -176,9 +182,6 @@ def __init__( use_ep_inside: bool = True, custom_policy: Policy = None, ) -> None: - super().__init__( - tp_size=tp_size, pp_size=pp_size, num_microbatches=num_microbatches, microbatch_size=microbatch_size - ) assert ( dist.get_world_size() % (tp_size * pp_size) == 0 ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" @@ -205,16 +208,16 @@ def __init__( ep_size = self.dp_size // extra_dp_size if use_ep_inside: self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, extra_dp_size, ep_size) - self.extra_dp_group = self.pg_mesh_moe.get_group_along_axis(1) + self.moe_extra_dp_group = self.pg_mesh_moe.get_group_along_axis(1) if dist.get_rank() == 0: print(f"Zero Parallel: pp {self.pp_size}, outer_dp {extra_dp_size}, inner_dp {ep_size}") else: self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, ep_size, extra_dp_size) - self.extra_dp_group = self.pg_mesh_moe.get_group_along_axis(2) + self.moe_extra_dp_group = self.pg_mesh_moe.get_group_along_axis(2) if dist.get_rank() == 0: print(f"Zero Parallel: pp {self.pp_size}, outer_dp {ep_size}, inner_dp {extra_dp_size}") else: - self.extra_dp_group = None + self.moe_extra_dp_group = None self.stage_manager = None self.schedule = None @@ -366,7 +369,8 @@ def configure( param_info=param_info, 
dp_process_group=self.dp_group, tp_process_group=self.tp_group, - extra_dp_process_group=self.extra_dp_group, + pp_process_group=self.pp_group, + moe_extra_dp_process_group=self.moe_extra_dp_group, verbose=True, clip_grad_norm=self.max_norm, **self.zero_config, diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index f27bd88fa8de..5e06dbfd947a 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -77,7 +77,7 @@ def __init__( cpu_offload: bool = False, # cpu offload dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm forced_dtype: Optional[torch.dtype] = None, - extra_dp_process_group: Optional[ProcessGroup] = None, + moe_extra_dp_process_group: Optional[ProcessGroup] = None, master_weights: bool = True, # master weights ): super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) @@ -103,10 +103,10 @@ def __init__( # Non moe param will be sync by global dp pg, moe param will be sync by extra dp pg. # Moe param grad is be split as non moe param by global dp pg, and grad will be merged in step. # And moe working and master param are split by extra dp pg. - self.extra_dp_pg = extra_dp_process_group - if self.extra_dp_pg is not None: - self.extra_dp_pg_size = dist.get_world_size(group=self.extra_dp_pg) - self.extra_dp_pg_rank = dist.get_rank(group=self.extra_dp_pg) + self.moe_extra_dp_pg = moe_extra_dp_process_group + if self.moe_extra_dp_pg is not None: + self.moe_extra_dp_pg_size = dist.get_world_size(group=self.moe_extra_dp_pg) + self.moe_extra_dp_pg_rank = dist.get_rank(group=self.moe_extra_dp_pg) # working and master params for mixed precision training self._working_param_groups = dict() @@ -152,7 +152,7 @@ def __init__( group_params = list() for param in param_group["params"]: if param.requires_grad: - if self.extra_dp_pg is None: + if self.moe_extra_dp_pg is None: # skip moe param if is_moe_tensor(param): moe_params.append(param) @@ -242,9 +242,9 @@ def _create_master_param_current_rank(self, param_list): else: padding_param = param.data.view(-1) - if self.extra_dp_pg is not None and is_moe_tensor(param): - splited_params = padding_param.split(padding_param.numel() // self.extra_dp_pg_size) - splited_params = splited_params[self.extra_dp_pg_rank] + if self.moe_extra_dp_pg is not None and is_moe_tensor(param): + splited_params = padding_param.split(padding_param.numel() // self.moe_extra_dp_pg_size) + splited_params = splited_params[self.moe_extra_dp_pg_rank] else: splited_params = padding_param.split(padding_param.numel() // self._world_size) splited_params = splited_params[self._local_rank] @@ -287,7 +287,7 @@ def _run_reduction(self): if self._bucket_store.num_elements_in_bucket() > 0: self._bucket_store.build_grad_in_bucket() - if self.extra_dp_pg is None: + if self.moe_extra_dp_pg is None: flat_grads = self._bucket_store.get_flatten_grad() flat_grads /= self._world_size else: @@ -331,7 +331,7 @@ def _run_reduction(self): if self._overlap_communication: stream = self._comm_stream # in case of the memory being reused in the default stream - if self.extra_dp_pg is None: + if self.moe_extra_dp_pg is None: flat_grads.record_stream(stream) else: if len(non_moe_grad_list) > 0: @@ -346,13 +346,13 @@ def _run_reduction(self): with torch.cuda.stream(stream): group_id = self._bucket_store.current_group_id - if self.extra_dp_pg is None: + if self.moe_extra_dp_pg is None: grad_dtype = flat_grads.dtype if self._communication_dtype is not None: flat_grads = 
flat_grads.to(self._communication_dtype) if not self._partition_grads: - if self.extra_dp_pg is None: + if self.moe_extra_dp_pg is None: dist.all_reduce(flat_grads, group=self.dp_pg) if flat_grads.dtype != grad_dtype: flat_grads = flat_grads.to(grad_dtype) @@ -384,12 +384,12 @@ def _run_reduction(self): # sync moe param only in zero group if len(moe_grad_list) > 0: - dist.all_reduce(moe_flat_grads, group=self.extra_dp_pg) + dist.all_reduce(moe_flat_grads, group=self.moe_extra_dp_pg) flat_grads_per_rank = moe_flat_grads.split(moe_flat_grads.numel() // self._world_size) self._sync_unpartitioned_grad(moe_grad_list, flat_grads_per_rank, group_id) else: - if self.extra_dp_pg is None: + if self.moe_extra_dp_pg is None: flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size)) recieved_grad = torch.zeros_like(flat_grads_list[0]) dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg) @@ -423,11 +423,13 @@ def _run_reduction(self): self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id) if len(moe_grad_list) > 0: - flat_grads_list = list(moe_flat_grads.split(len(moe_flat_grads) // self.extra_dp_pg_size)) + flat_grads_list = list( + moe_flat_grads.split(len(moe_flat_grads) // self.moe_extra_dp_pg_size) + ) recieved_grad = torch.zeros_like(flat_grads_list[0]) - dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.extra_dp_pg) + dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.moe_extra_dp_pg) - param_slice = self._world_size // self.extra_dp_pg_size + param_slice = self._world_size // self.moe_extra_dp_pg_size recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice)) grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank] for split_recieved_grad in recieved_grad: @@ -573,14 +575,14 @@ def step(self, closure=None): grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param)) if len(grads) > 0: # moe hybrid zero - if self.extra_dp_pg is not None and is_moe_tensor(working_param): + if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param): real_working_params[group_id].append(working_param) if self._partition_grads: grad = grads else: - param_slice = self._world_size // self.extra_dp_pg_size + param_slice = self._world_size // self.moe_extra_dp_pg_size grad = grads[ - self.extra_dp_pg_rank * param_slice : (self.extra_dp_pg_rank + 1) * param_slice + self.moe_extra_dp_pg_rank * param_slice : (self.moe_extra_dp_pg_rank + 1) * param_slice ] grad = flatten(grad) else: @@ -616,7 +618,7 @@ def step(self, closure=None): # update the parameters self.optim.step() - # TODO: release the moe grad. 
we should store master param
+        # release the moe grad
         if len(self.param_groups) > len(self._working_param_groups):
             for param in self.param_groups[-1]["params"]:
                 param.grad = None
@@ -632,12 +634,12 @@ def step(self, closure=None):
             master_working_param = self.optim.param_groups[group_id]["params"]
             for idx, splited_param in enumerate(master_working_param):
                 working_param = real_working_params[group_id][idx]
-                if self.extra_dp_pg is not None and is_moe_tensor(working_param):
+                if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
                     all_splited_param = [
                         torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype)
-                        for _ in range(self.extra_dp_pg_size)
+                        for _ in range(self.moe_extra_dp_pg_size)
                     ]
-                    dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.extra_dp_pg)
+                    dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.moe_extra_dp_pg)
                 else:
                     all_splited_param = [
                         torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype)
@@ -772,11 +774,11 @@ def state_dict(self) -> Dict:
                 for k, v in state.items():
                     if isinstance(v, torch.Tensor) and k != "step":
                         working_param = self._param_store.master_to_working_param[id(param)]
-                        if self.extra_dp_pg is not None and is_moe_tensor(v):
+                        if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
                             gather_tensor = [
-                                torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self.extra_dp_pg_size)
+                                torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
                             ]
-                            dist.all_gather(gather_tensor, v.cuda(), group=self.extra_dp_pg)
+                            dist.all_gather(gather_tensor, v.cuda(), group=self.moe_extra_dp_pg)
                         else:
                             gather_tensor = [
                                 torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)
@@ -806,9 +808,9 @@ def load_state_dict(self, state_dict: Dict):
                     v = v.flatten()
                     if padding_size > 0:
                         v = torch.nn.functional.pad(v, [0, padding_size])
-                    if self.extra_dp_pg is not None and is_moe_tensor(v):
-                        v_list = v.split(v.numel() // self.extra_dp_pg_size)
-                        zero_state_dict["state"][param_idx][k] = v_list[self.extra_dp_pg_rank].detach().clone()
+                    if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
+                        v_list = v.split(v.numel() // self.moe_extra_dp_pg_size)
+                        zero_state_dict["state"][param_idx][k] = v_list[self.moe_extra_dp_pg_rank].detach().clone()
                     else:
                         v_list = v.split(v.numel() // self._world_size)
                         zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone()
@@ -842,11 +844,11 @@ def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, i
             for k, v in states.items():
                 if isinstance(v, torch.Tensor) and k != "step":
-                    if self.extra_dp_pg is not None and is_moe_tensor(v):
+                    if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
                         state_tensor = [
-                            torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self.extra_dp_pg_size)
+                            torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
                         ]
-                        dist.all_gather(state_tensor, v.cuda(), group=self.extra_dp_pg)
+                        dist.all_gather(state_tensor, v.cuda(), group=self.moe_extra_dp_pg)
                     else:
                         state_tensor = [
                             torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)
@@ -882,7 +884,7 @@ def update_master_params(self, model: nn.Module) -> None:
                 working_param = p.data.view(-1)
                 if padding_size > 0:
                     working_param = torch.nn.functional.pad(working_param, [0, padding_size])
-                if self.extra_dp_pg is not None and is_moe_tensor(p):
+                if self.moe_extra_dp_pg is not None and is_moe_tensor(p):
                     master_param.copy_(working_param.chunk(self.moe_extra_dp_pg_size)[self.moe_extra_dp_pg_rank])
                 else:
                     master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
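In isolation, the master-parameter sharding pattern used in `update_master_params` above (pad the flattened working parameter so it divides evenly, then keep only this rank's chunk of the relevant dp group) can be sketched as follows. This is an illustrative standalone sketch, not code from the patch; the helper name, group size, and toy tensor are assumptions:

```python
import torch


def shard_param(working_param: torch.Tensor, group_size: int, group_rank: int) -> torch.Tensor:
    """Pad a flattened parameter so it splits evenly, then keep this rank's shard."""
    flat = working_param.view(-1)
    padding = (group_size - flat.numel() % group_size) % group_size
    if padding > 0:
        flat = torch.nn.functional.pad(flat, [0, padding])
    # every rank of the (moe extra) dp group stores exactly one contiguous chunk
    return flat.chunk(group_size)[group_rank].clone()


# toy check: 10 elements sharded across a group of 4 -> padded to 12, 3 per rank
assert shard_param(torch.arange(10.0), group_size=4, group_rank=1).numel() == 3
```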
diff --git a/tests/test_moe/test_moe_hybrid_zero.py b/tests/test_moe/test_moe_hybrid_zero.py
index e3f093f7461e..142af5de98a9 100644
--- a/tests/test_moe/test_moe_hybrid_zero.py
+++ b/tests/test_moe/test_moe_hybrid_zero.py
@@ -57,7 +57,7 @@ def run_zero_optim_test(local_rank, world_size, stage=1):
         zero_param.data.copy_(torch_param.data.detach().clone())
     zero_optimizer = torch.optim.Adam(zero_model.parameters())
     plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
-    plugin.zero_optim_kwargs["extra_dp_process_group"] = extra_dp_group
+    plugin.zero_optim_kwargs["moe_extra_dp_process_group"] = extra_dp_group
     booster = Booster(plugin=plugin)
     zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)

From b732ab03b56d4e5e696b7dce56cb79b5872426d2 Mon Sep 17 00:00:00 2001
From: Xuanlei Zhao
Date: Mon, 30 Oct 2023 14:03:56 +0800
Subject: [PATCH 36/46] fix test

---
 tests/test_moe/test_moe_load_balance.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_moe/test_moe_load_balance.py b/tests/test_moe/test_moe_load_balance.py
index 4daad7949a87..048e85311f8b 100644
--- a/tests/test_moe/test_moe_load_balance.py
+++ b/tests/test_moe/test_moe_load_balance.py
@@ -131,7 +131,7 @@ def run_hybrid_zero_optim_test(local_rank, world_size, stage=1):
         zero_param.data.copy_(torch_param.data.detach().clone())
     zero_optimizer = torch.optim.Adam(zero_model.parameters())
     plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
-    plugin.zero_optim_kwargs["extra_dp_process_group"] = extra_dp_group
+    plugin.zero_optim_kwargs["moe_extra_dp_process_group"] = extra_dp_group
     booster = Booster(plugin=plugin)
     zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)

From e85122b466261593986d1b45aa477c0520129cf8 Mon Sep 17 00:00:00 2001
From: Xuanlei Zhao
Date: Tue, 31 Oct 2023 14:47:33 +0800
Subject: [PATCH 37/46] update doc

---
 colossalai/zero/low_level/low_level_optim.py |  20 +-
 examples/language/openmoe/README.md          | 130 +++++++++++-
 .../openmoe/benchmark/benchmark_cai.py       |  57 +++--
 .../openmoe/model/modeling_openmoe.py        | 187 +++++++++--------
 examples/language/openmoe/test_ci.sh         |  14 +-
 examples/language/openmoe/train.py           | 198 ++++++++++--------
 examples/language/openmoe/train.sh           |  47 ++++-
 7 files changed, 423 insertions(+), 230 deletions(-)

diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 5e06dbfd947a..a002d2087257 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -406,16 +406,23 @@ def _run_reduction(self):
                     else:
                         self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id)
             else:
+                # categorize moe and non moe param
+                grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
+                moe_grad_in_bucket_current_rank = []
+                non_moe_grad_in_bucket_current_rank = []
+                for idx, grad in enumerate(grad_in_bucket_current_rank):
+                    if moe_list[idx] == True:
+                        moe_grad_in_bucket_current_rank.append(grad)
+                    else:
+                        non_moe_grad_in_bucket_current_rank.append(grad)
+
                 if len(non_moe_grad_list) > 0:
                     flat_grads_list = list(
                         non_moe_flat_grads.split(len(non_moe_flat_grads) // self._world_size)
                     )
                     recieved_grad = torch.zeros_like(flat_grads_list[0])
                     dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
-
-                    grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
-                    sync_tensor(recieved_grad, grad_in_bucket_current_rank)
-                    for grad in grad_in_bucket_current_rank:
+                    sync_tensor(recieved_grad, non_moe_grad_in_bucket_current_rank)
+                    for grad in non_moe_grad_in_bucket_current_rank:
                         param_id = self._bucket_store.get_param_id_of_grad(grad)
                         if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < 1:
                             self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
@@ -431,12 +438,11 @@ def _run_reduction(self):
                         param_slice = self._world_size // self.moe_extra_dp_pg_size
                         recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice))
-                        grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
                         for split_recieved_grad in recieved_grad:
                             split_recieved_grad = _unflatten_dense_tensors(
-                                split_recieved_grad, grad_in_bucket_current_rank
+                                split_recieved_grad, moe_grad_in_bucket_current_rank
                             )
-                            for grad in grad_in_bucket_current_rank:
+                            for grad in moe_grad_in_bucket_current_rank:
                                 param_id = self._bucket_store.get_param_id_of_grad(grad)
                                 if (
                                     len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id))
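Stripped of the bucket and store bookkeeping, the reduce-scatter path above follows one pattern: flatten a bucket of gradients, split it into equal shards, and reduce-scatter so each rank keeps only the reduced sum of its own shard. The sketch below is illustrative only (it assumes an already-initialized process group and an evenly divisible bucket) and is not the optimizer's actual API:

```python
import torch
import torch.distributed as dist


def reduce_scatter_flat(flat_grads: torch.Tensor, group=None) -> torch.Tensor:
    """Split a flattened gradient bucket into world-size chunks and reduce-scatter,
    so each rank ends up holding only the reduced sum of its own shard."""
    world_size = dist.get_world_size(group=group)
    shards = list(flat_grads.split(flat_grads.numel() // world_size))
    received = torch.zeros_like(shards[0])
    dist.reduce_scatter(received, shards, group=group)
    return received
```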
diff --git a/examples/language/openmoe/README.md b/examples/language/openmoe/README.md
index 26b5ee73b054..3873232c5952 100644
--- a/examples/language/openmoe/README.md
+++ b/examples/language/openmoe/README.md
@@ -1,17 +1,131 @@
 ## OpenMoE
-[OpenMoE](https://github.com/XueFuzhao/OpenMoE) is a project aimed at Igniting the Open-Source MoE Community!
-
-The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates finetune and inference methods.
-
-
-## Our Modifications
-
-We reimplement OpenMoE with PyTorch + GPU.
+[OpenMoE](https://github.com/XueFuzhao/OpenMoE) is the open-source community's first decoder-only MoE transformer. OpenMoE is implemented in JAX, and [Colossal-AI](https://github.com/hpcaitech/ColossalAI) has pioneered efficient open-source support for this model in PyTorch, enabling a broader range of users to work with it. The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates finetuning and inference methods.
+
+## Usage
+
+### 1. Installation
+
+Please install the latest ColossalAI from source.
+
+```bash
+CUDA_EXT=1 pip install -U git+https://github.com/hpcaitech/ColossalAI
+```
+
+Then install dependencies.
+
+```bash
+cd ColossalAI
+pip install -r requirements.txt
+cd examples/language/openmoe
+pip install -r requirements.txt
+```
+
+Additionally, we recommend torch 1.13.1: we have tested our code on it and found it compatible with both our code and flash attention.
+
+### 2. Install kernels (Optional)
+
+We use the `Triton`, `FlashAttention` and `Apex` kernels for better performance. They are not required, but we recommend installing them to fully utilize your hardware.
+```bash
+# install triton via pip
+pip install triton
+
+# install flash attention via pip
+pip install flash-attn
+
+# install apex from source
+git clone https://github.com/NVIDIA/apex.git
+cd apex
+git checkout 741bdf50825a97664db08574981962d66436d16a
+pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ --global-option="--cuda_ext"
+```
+
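+The optional kernels are picked up automatically once they are importable. A quick way to check what your environment provides (an illustrative snippet, not part of this example):
+
+```python
+# probe the optional acceleration packages used by this example
+for pkg in ("triton", "flash_attn", "apex"):
+    try:
+        __import__(pkg)
+        print(f"{pkg}: available")
+    except ImportError:
+        print(f"{pkg}: not installed")
+```
+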
+### 3. Train
+
+You can use `colossalai run` to launch single-node training:
+```bash
+colossalai run --standalone --nproc_per_node YOUR_GPU_PER_NODE train.py --OTHER_CONFIGURATIONS
+```
+
+You can also use `colossalai run` to launch multi-node training:
+```bash
+colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE train.py --OTHER_CONFIGURATIONS
+```
+
+Here is a sample hostfile:
+
+```text
+hostname1
+hostname2
+hostname3
+hostname4
+```
+
+Each hostname is the IP address of one of your nodes. Make sure the master node can reach every node (including itself) via passwordless SSH.
+
+Here are the details of the CLI arguments:
+
+- Model configuration: `--model_name`. `base` and `8b` are supported for OpenMoE.
+- Booster plugin: `--plugin`. `ep`, `ep_zero` and `hybrid` are supported. `ep_zero` is recommended for general cases, `ep` provides the lowest memory consumption, and `hybrid` suits large-scale training.
+- Output path: `--output_path`. The path to save your model. The default value is `./outputs`.
+- Number of epochs: `--num_epoch`. The default value is 1.
+- Local batch size: `--batch_size`. Batch size per GPU. The default value is 1.
+- Save interval: `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000.
+- Mixed precision: `--precision`. The default value is "bf16". "fp16", "bf16" and "fp32" are supported.
+- Max length: `--max_length`. Max sequence length. Defaults to 2048.
+- Dataset: `--dataset`. The default dataset is `yizhongw/self_instruct`. It supports any dataset from `datasets` with the same data format.
+- Task Name: `--task_name`. Task of the corresponding dataset. Defaults to `super_natural_instructions`.
+- Learning rate: `--lr`. The default value is 1e-5.
+- Weight decay: `--weight_decay`. The default value is 0.
+- Zero stage: `--zero_stage`. ZeRO stage. 2 is recommended for `ep` and 1 for `ep_zero`.
+- Extra dp size: `--extra_dp_size`. Extra MoE-parameter dp size for the `ep_zero` plugin. Recommended to be 2 or 4.
+- Use kernel: `--use_kernel`. Use kernel optimizations. Requires flash attention and triton to enable all of them; skipped if they are not installed.
+- Use layernorm kernel: `--use_layernorm_kernel`. Use the layernorm kernel. Requires apex; raises an error if it is not installed.
+- Router aux loss factor: `--router_aux_loss_factor`. MoE router auxiliary loss factor. You can refer to ST-MoE for details.
+- Router z loss factor: `--router_z_loss_factor`. MoE router z-loss factor. You can refer to ST-MoE for details.
+- Label smoothing: `--label_smoothing`. Label smoothing factor.
+- Z loss factor: `--z_loss_factor`. The z-loss factor on the final classification outputs.
+- Load balance: `--load_balance`. Expert load balancing. Defaults to False. Recommended to enable.
+- Load balance interval: `--load_balance_interval`. Expert load balancing interval.
+- Communication overlap: `--comm_overlap`. Use communication overlap for MoE. Recommended to enable for multi-node training.
+
+### 4. Shell Script Examples
+
+For your convenience, we provide some shell scripts to train with various configurations. Here we show an example of how to train OpenMoE.
+
+#### a. Running environment
+This experiment was performed on a single computing node with 8 A800 80GB GPUs for OpenMoE-8B. The GPUs are fully connected with NVLink.
+
+#### b. Running command
+We demonstrate how to run the three plugins in `train.sh`. You can choose any one of them and use your own args.
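+If you prefer to call `train.py` directly instead of going through the script, a minimal single-node `ep_zero` launch with the arguments documented above might look like this (all values are illustrative):
+
+```bash
+colossalai run --standalone --nproc_per_node 8 train.py \
+    --model_name base \
+    --plugin ep_zero \
+    --zero_stage 1 \
+    --extra_dp_size 2 \
+    --batch_size 1 \
+    --lr 1e-5
+```
+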
-By running the following script:
 ```bash
-bash infer.sh
+bash train.sh
+```
+
+#### c. Multi-Node Training
+
+To run on multiple nodes, you can modify the script as follows:
+```bash
+colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
+train.py --OTHER_CONFIGURATIONS
+```
+
+## Reference
+```bibtex
+@article{bian2021colossal,
+  title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training},
+  author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang},
+  journal={arXiv preprint arXiv:2110.14883},
+  year={2021}
+}
+```
+
+```bibtex
+@misc{openmoe2023,
+  author = {Fuzhao Xue, Zian Zheng, Yao Fu, Jinjie Ni, Zangwei Zheng, Wangchunshu Zhou and Yang You},
+  title = {OpenMoE: Open Mixture-of-Experts Language Models},
+  year = {2023},
+  publisher = {GitHub},
+  journal = {GitHub repository},
+  howpublished = {\url{https://github.com/XueFuzhao/OpenMoE}},
+}
 ```
-You will infer a [OpenMoE-8B/32E](https://github.com/XueFuzhao/OpenMoE) model.
diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py
index 1a158eabc151..112a12cb6b17 100644
--- a/examples/language/openmoe/benchmark/benchmark_cai.py
+++ b/examples/language/openmoe/benchmark/benchmark_cai.py
@@ -1,3 +1,4 @@
+import argparse
 import json
 import os

@@ -13,7 +14,6 @@
 from utils import PerformanceEvaluator, get_model_numel

 import colossalai
-from colossalai import get_default_parser
 from colossalai.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.cluster import DistCoordinator
@@ -42,27 +42,26 @@ def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster):


 class RandomDataset(Dataset):
-
-    def __init__(self,
-                 num_samples: int = 1000,
-                 max_length: int = 2048,
-                 vocab_size: int = 256384,
-                 tokenizer: T5Tokenizer = None):
+    def __init__(
+        self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 256384, tokenizer: T5Tokenizer = None
+    ):
         self.num_samples = num_samples
         self.max_length = max_length
         if os.path.exists("./mock_data.json"):
             self.input_ids = []
             self.attention_mask = []
-            with open("./mock_data.json", 'r') as f:
+            with open("./mock_data.json", "r") as f:
                 data = json.load(f)
             for v in data.values():
                 d = v["text"]
-                encode = tokenizer("" + d,
-                                   return_tensors="pt",
-                                   add_special_tokens=False,
-                                   max_length=max_length,
-                                   truncation=True,
-                                   padding="max_length")
+                encode = tokenizer(
+                    "" + d,
+                    return_tensors="pt",
+                    add_special_tokens=False,
+                    max_length=max_length,
+                    truncation=True,
+                    padding="max_length",
+                )
                 self.input_ids.append(encode["input_ids"])
                 self.attention_mask.append(encode["attention_mask"])
             self.input_ids = torch.cat(self.input_ids, dim=0).to(get_current_device())
@@ -87,7 +86,7 @@ def __getitem__(self, idx):

 def parse_args():
     # basic settings
-    parser = get_default_parser()
+    parser = argparse.ArgumentParser()
     parser.add_argument(
         "--model_name",
         type=str,
@@ -159,17 +158,7 @@ def main():
     mgr_dict = {
         "seed": 42,
     }
-    if args.plugin == "zero":
-        dp_size = dist.get_world_size()
-        plugin = MoeHybridParallelPlugin(
-            pp_size=1,
-            **hybrid_dict,
-        )
-        MOE_MANAGER.setup(
-            parallel=None,
-            **mgr_dict,
-        )
-    elif args.plugin == "ep":
+    if args.plugin == "ep":
         dp_size = dist.get_world_size()
         plugin = MoeHybridParallelPlugin(
             pp_size=1,
@@ -218,12 +207,14 @@ def main():
     # Build OpenMoe model
     repo_name = "hpcaitech/openmoe-" + args.model_name
     config = 
LlamaConfig.from_pretrained(repo_name) - set_openmoe_args(config, - num_experts=config.num_experts, - moe_layer_interval=config.moe_layer_interval, - enable_load_balance=args.load_balance, - enable_kernel=args.use_kernel, - enable_comm_overlap=args.overlap_alltoall) + set_openmoe_args( + config, + num_experts=config.num_experts, + moe_layer_interval=config.moe_layer_interval, + enable_load_balance=args.load_balance, + enable_kernel=args.use_kernel, + enable_comm_overlap=args.overlap_alltoall, + ) with skip_init(): model = OpenMoeForCausalLM(config) coordinator.print_on_master(f"Finish init model with config:\n{config}") @@ -255,7 +246,7 @@ def main(): booster = Booster(plugin=plugin, **booster_kwargs) load_ckpt(repo_name, model, booster) model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader) - use_pipeline = (isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1) + use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1 is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() coordinator.print_on_master(f"Finish init booster") diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index 7d28de731407..7e3e6b3ed364 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -60,7 +60,7 @@ def set_openmoe_args( router_noisy_policy: str = None, router_drop_tks: bool = True, router_aux_loss_factor: float = 0.01, - router_z_loss_factor: float = 0.01, + router_z_loss_factor: float = 0.0001, mlp_gated: bool = True, label_smoothing: float = 0.001, z_loss_factor: float = 0.01, @@ -122,10 +122,9 @@ def set_openmoe_args( # Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask(input_ids_shape: torch.Size, - dtype: torch.dtype, - device: torch.device, - past_key_values_length: int = 0): +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): """ Make causal mask used for bi-directional self-attention. """ @@ -169,10 +168,10 @@ def generate_fixed_pos_embedding(features, length, min_timescale=1.0, max_timesc output_cos: a float32 Tensor with shape [length, features] """ fraction = torch.arange(0, features, 2, dtype=torch.float32).cuda() / features - timescale = min_timescale * (max_timescale / min_timescale)**fraction - rotational_frequency = 1. 
/ timescale + timescale = min_timescale * (max_timescale / min_timescale) ** fraction + rotational_frequency = 1.0 / timescale - sinusoid_inp = torch.einsum('i,j->ij', torch.arange(length, dtype=torch.float32).cuda(), rotational_frequency) + sinusoid_inp = torch.einsum("i,j->ij", torch.arange(length, dtype=torch.float32).cuda(), rotational_frequency) sinusoid_inp = torch.cat([sinusoid_inp, sinusoid_inp], dim=-1) @@ -193,8 +192,8 @@ def apply_rotary_embedding(q, k, cos, sin, decode=False, rotary_index=None): batch, qlen, qheads, d = q.shape kbatch, klen, kheads, kd = k.shape - assert batch == kbatch, f'{batch} != {kbatch}' - assert d == kd, f'{d} != {kd}' + assert batch == kbatch, f"{batch} != {kbatch}" + assert d == kd, f"{d} != {kd}" if decode and qlen == 1 and rotary_index is not None: qcos = cos[rotary_index + 1, :] qsin = sin[rotary_index + 1, :] @@ -220,8 +219,8 @@ def apply_rotary_embedding(q, k, cos, sin, decode=False, rotary_index=None): def rotate_half(x): """Rotates half the hidden dims of the input.""" - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) @@ -238,7 +237,6 @@ def SwiGLU(x): class OpenMoeMLP(nn.Module): - def __init__(self, config: LlamaConfig): super().__init__() self.pretraining_tp = config.pretraining_tp @@ -362,12 +360,9 @@ def forward( assert max_length <= self.sin.shape[0] sin, cos = self.sin[:max_length], self.cos[:max_length] # TODO: for inference, we can add emb kv into cache to avoid computation - query_states, key_states = apply_rotary_embedding(query_states, - key_states, - cos, - sin, - decode=True if q_len == 1 else False, - rotary_index=position_ids) + query_states, key_states = apply_rotary_embedding( + query_states, key_states, cos, sin, decode=True if q_len == 1 else False, rotary_index=position_ids + ) query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) @@ -377,6 +372,7 @@ def forward( if HAS_FLASH_ATTN and use_kernel: from flash_attn import flash_attn_func + query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) @@ -388,7 +384,8 @@ def forward( if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}") + f" {attn_weights.size()}" + ) if attention_mask is not None: if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): @@ -405,8 +402,10 @@ def forward( attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}") + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) @@ -425,7 +424,6 @@ def forward( class OpenMoeDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig, moe: bool): super().__init__() self.hidden_size = config.hidden_size @@ -434,23 +432,25 @@ def __init__(self, config: LlamaConfig, moe: bool): self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, 
eps=config.rms_norm_eps) if self.moe: - self.mlp = SparseMLP(num_experts=config.num_experts, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - router_top_k=config.router_topk, - router_capacity_factor_train=config.router_capacity_factor_train, - router_capacity_factor_eval=config.router_capacity_factor_eval, - router_min_capacity=config.router_min_capacity, - router_noisy_policy=config.router_noisy_policy, - router_drop_tks=config.router_drop_tks, - mlp_activation=config.hidden_act, - mlp_gated=config.mlp_gated, - enable_load_balance=config.enable_load_balance, - load_balance_tolerance=config.load_balance_tolerance, - load_balance_beam_width=config.load_balance_beam_width, - load_balance_group_swap_factor=config.load_balance_group_swap_factor, - enable_kernel=config.enable_kernel, - enable_comm_overlap=config.enable_comm_overlap) + self.mlp = SparseMLP( + num_experts=config.num_experts, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + router_top_k=config.router_topk, + router_capacity_factor_train=config.router_capacity_factor_train, + router_capacity_factor_eval=config.router_capacity_factor_eval, + router_min_capacity=config.router_min_capacity, + router_noisy_policy=config.router_noisy_policy, + router_drop_tks=config.router_drop_tks, + mlp_activation=config.hidden_act, + mlp_gated=config.mlp_gated, + enable_load_balance=config.enable_load_balance, + load_balance_tolerance=config.load_balance_tolerance, + load_balance_beam_width=config.load_balance_beam_width, + load_balance_group_swap_factor=config.load_balance_group_swap_factor, + enable_kernel=config.enable_kernel, + enable_comm_overlap=config.enable_comm_overlap, + ) self.pre_extra_mlp_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.extra_mlp = OpenMoeMLP(config) else: @@ -643,10 +643,12 @@ def __init__(self, config: LlamaConfig): self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList([ - OpenMoeDecoderLayer(config, moe=True if (i + 1) % config.moe_layer_interval == 0 else False) - for i in range(config.num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + OpenMoeDecoderLayer(config, moe=True if (i + 1) % config.moe_layer_interval == 0 else False) + for i in range(config.num_hidden_layers) + ] + ) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -674,10 +676,12 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, - tgt_len=input_shape[-1]).to(inputs_embeds.device) - combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + - combined_attention_mask) + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) return combined_attention_mask @@ -695,8 +699,9 @@ def forward( return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None 
else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -720,10 +725,9 @@ def forward( if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange(past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device) + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() @@ -732,18 +736,20 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) # embed positions if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), - dtype=torch.bool, - device=inputs_embeds.device) - attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds, - past_key_values_length) + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) hidden_states = inputs_embeds if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) use_cache = False # decoder layers @@ -760,7 +766,6 @@ def forward( if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): # None for past_key_value return module(*inputs, output_attentions, None) @@ -885,8 +890,9 @@ def forward( MOE_MANAGER.reset_loss() output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) @@ -920,7 +926,6 @@ def forward( if chunk_head == True: def create_custom_forward(module): - def custom_forward(*inputs): logits = module(inputs[0]) logits = logits.float() @@ -938,8 +943,8 @@ def custom_forward(*inputs): for batch_idx in range(hidden_states.shape[0]): loss = loss + torch.utils.checkpoint.checkpoint( create_custom_forward(self.lm_head), - hidden_states[batch_idx:batch_idx + 1, :], - labels[batch_idx:batch_idx + 1, :], + hidden_states[batch_idx : batch_idx + 1, :], + labels[batch_idx : batch_idx + 1, :], ) logits = None else: @@ -965,12 +970,9 @@ def custom_forward(*inputs): attentions=outputs.attentions, ) - def prepare_inputs_for_generation(self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - **kwargs): + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): if past_key_values: input_ids = input_ids[:, -1:] @@ -988,20 +990,23 @@ def prepare_inputs_for_generation(self, else: model_inputs = {"input_ids": input_ids} - model_inputs.update({ - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - }) + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) return model_inputs @staticmethod def _reorder_cache(past_key_values, beam_idx): reordered_past = () for layer_past in past_key_values: - reordered_past += (tuple( - past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),) + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) return reordered_past def _calculate_router_loss(self, aux_loss: list = None, z_loss: list = None): @@ -1023,19 +1028,23 @@ def _calculate_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch. Tuple of scalar loss. """ if len(logits.shape) != len(targets.shape) + 1: - raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' % - (str(logits.shape), str(targets.shape))) + raise ValueError( + "Incorrect shapes. 
Got shape %s logits and %s targets" % (str(logits.shape), str(targets.shape)) + ) vocab_size = logits.shape[-1] confidence = 1.0 - self.config.label_smoothing low_confidence = (1.0 - confidence) / (vocab_size - 1) - normalizing_constant = -(confidence * math.log(confidence) + - (vocab_size - 1) * low_confidence * math.log(low_confidence + 1e-20)) + normalizing_constant = -( + confidence * math.log(confidence) + (vocab_size - 1) * low_confidence * math.log(low_confidence + 1e-20) + ) # one hot - soft_targets = targets[..., None] == \ - torch.arange(vocab_size, device=targets.device).reshape((1,) * len(targets.shape) + (-1,)) - soft_targets = torch.where(soft_targets, torch.full_like(soft_targets, confidence), - torch.full_like(soft_targets, low_confidence)) + soft_targets = targets[..., None] == torch.arange(vocab_size, device=targets.device).reshape( + (1,) * len(targets.shape) + (-1,) + ) + soft_targets = torch.where( + soft_targets, torch.full_like(soft_targets, confidence), torch.full_like(soft_targets, low_confidence) + ) soft_targets = soft_targets.to(torch.float32) # cross entropy @@ -1093,7 +1102,7 @@ def backward(ctx, *grad_outputs): z_loss = ctx.z_loss logits, targets, exp_shifted, sum_exp, log_softmax, log_z = ctx.saved_tensors # z-loss term adds the (2 * z_loss * log_z) factor. - deriv = ((1 + 2 * z_loss * log_z).unsqueeze(-1) * exp_shifted / sum_exp - targets) + deriv = (1 + 2 * z_loss * log_z).unsqueeze(-1) * exp_shifted / sum_exp - targets g_logits = g.unsqueeze(-1) * deriv g_targets = -g.unsqueeze(-1) * log_softmax diff --git a/examples/language/openmoe/test_ci.sh b/examples/language/openmoe/test_ci.sh index 71198d8756d0..960c83adb489 100644 --- a/examples/language/openmoe/test_ci.sh +++ b/examples/language/openmoe/test_ci.sh @@ -14,12 +14,22 @@ torchrun --standalone --nproc_per_node 4 train.py \ --num_epoch 1 \ --model_name "test" \ --plugin "ep_zero" \ - --batch_size 1 + --batch_size 1 \ + --zero_stage 1 \ + --extra_dp_size 2 \ + +torchrun --standalone --nproc_per_node 4 train.py \ + --num_epoch 1 \ + --model_name "test" \ + --plugin "ep_zero" \ + --batch_size 1 \ + --zero_stage 2 \ + --extra_dp_size 2 \ torchrun --standalone --nproc_per_node 4 train.py \ --model_name "test" \ --plugin "hybrid" \ - --num_epoch 1 \ + --num_epoch 1 \ --pp_size 2 \ --dp_size 1 \ --ep_size 2 \ diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index 8b2e0e833b58..e8c2f6aaa447 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -1,26 +1,27 @@ import argparse import os +from functools import partial +from typing import Dict -import datasets import torch import torch.distributed as dist -import transformers +from datasets import load_dataset from huggingface_hub import snapshot_download from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args from model.openmoe_policy import OpenMoeForCausalLMPolicy from torch.utils.data import Dataset from tqdm import tqdm -from transformers import Adafactor, T5Tokenizer +from transformers import T5Tokenizer from transformers.models.llama import LlamaConfig import colossalai from colossalai.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.cluster import DistCoordinator -from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.moe.layers import apply_load_balance from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import skip_init +from 
colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device


@@ -41,32 +42,23 @@ def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster):
         booster.load_model(model, ckpt_path)


+def tokenize_data(batch, tokenizer: T5Tokenizer, max_length: int) -> Dict:
+    texts = ["" + sample["prompt"] + sample["completion"] for sample in batch]
+    data = tokenizer(
+        texts,
+        return_tensors="pt",
+        padding="max_length",
+        truncation=True,
+        max_length=max_length,
+        add_special_tokens=False,
+    )
+    data = {k: v.cuda() for k, v in data.items()}
+    data["labels"] = data["input_ids"].clone()
+    return data
+
+
 class RandomDataset(Dataset):
     def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000, tokenizer=None):
-        """
-        A random dataset
-
-        You can use tokenizer to process your own data
-        Example:
-            self.input_ids = []
-            self.attention_mask = []
-            data = your_data()
-            data = shuffle(data)
-            for text in data:
-                # text is a str
-                encode = tokenizer(
-                    "" + text,
-                    return_tensors="pt",
-                    add_special_tokens=False,
-                    max_length=max_length,
-                    truncation=True,
-                    padding="max_length")
-                self.input_ids.append(encode["input_ids"])
-                self.attention_mask.append(encode["attention_mask"])
-            self.input_ids = torch.cat(self.input_ids, dim=0).to(get_current_device())
-            self.attention_mask = torch.cat(self.attention_mask, dim=0).to(get_current_device())
-        """
-        # TODO: use distributed sampler
         self.num_samples = num_samples
         self.max_length = max_length
         self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
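The `tokenize_data` helper added above is designed to act as a `DataLoader` collate function. A minimal usage sketch (illustrative only: the sample texts are made up, the helper moves tensors to the current GPU so CUDA is required, and the real wiring happens later in `main()` via `plugin.prepare_dataloader`):

```python
from functools import partial

from torch.utils.data import DataLoader
from transformers import T5Tokenizer

# tokenize_data is the helper defined in train.py above
tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
samples = [{"prompt": "Translate to French: hello", "completion": " bonjour"}]
loader = DataLoader(samples, batch_size=1, collate_fn=partial(tokenize_data, tokenizer=tokenizer, max_length=32))
batch = next(iter(loader))  # dict with input_ids, attention_mask and labels, all on GPU
```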
@@ -93,55 +85,80 @@ def parse_args():
         choices=["base", "8b", "test"],
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--plugin",
+        type=str,
+        default="hybrid",
+        choices=["ep", "ep_zero", "hybrid"],
+        help="Parallel methods. ep_zero is recommended for general cases. ep provides the least memory consumption and hybrid suits large-scale training.",
+    )
     parser.add_argument(
         "--output_path",
         type=str,
-        default="./output_model.bin",
+        default="./outputs",
         help="The path of your saved model after finetuning.",
     )
-    parser.add_argument("--num_epoch", type=int, default=10, help="Number of epochs.")
+    parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.")
     parser.add_argument(
         "--batch_size",
        type=int,
-        default=4,
+        default=1,
         help="Batch size (per dp group) for the training dataloader.",
     )
+    parser.add_argument(
+        "--save_interval",
+        type=int,
+        default=1000,
+        help="The interval (steps) of saving checkpoints.",
+    )
+    parser.add_argument(
+        "--precision",
+        type=str,
+        default="bf16",
+        choices=["fp32", "bf16", "fp16"],
+        help="Mixed precision for training.",
+    )
+    parser.add_argument("--max_length", type=int, default=2048, help="Max sequence length.")
     parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
     parser.add_argument(
-        "--plugin",
+        "--dataset",
         type=str,
-        default="hybrid",
-        help="parallel plugin",
-        choices=["ep", "ep_zero", "hybrid"],
+        default="yizhongw/self_instruct",
+        help="Dataset name from the `datasets` repo.",
+    )
+    parser.add_argument(
+        "--task_name",
+        type=str,
+        default="super_natural_instructions",
+        help="Task of the corresponding dataset.",
     )

     # optim
-    parser.add_argument("--decay_rate", type=float, default=-0.8, help="adafactor optim decay rate.")
-    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to use.")
     parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
+    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")

     # zero stage for all plugins
-    parser.add_argument("--zero_stage", type=int, default=2, help="zero stage in hybrid plugin")
-
-    # ep zero plugin
-    parser.add_argument("--extra_dp_size", type=int, default=1, help="ep zero's moe dp size")
-
+    parser.add_argument("--zero_stage", type=int, default=2, help="Zero stage.")
+    # ep_zero plugin
+    parser.add_argument(
+        "--extra_dp_size", type=int, default=1, help="ep_zero plugin's moe dp size. Recommended to be 2 or 4."
+    )
     # hybrid plugin
-    parser.add_argument("--pp_size", type=int, default=2, help="pp size")
-    parser.add_argument("--dp_size", type=int, default=1, help="dp size")
-    parser.add_argument("--ep_size", type=int, default=2, help="ep size")
-    parser.add_argument("--microbatch_size", type=int, default=1, help="microbatch size")
+    parser.add_argument("--pp_size", type=int, default=2, help="pp size for hybrid plugin")
+    parser.add_argument("--dp_size", type=int, default=1, help="dp size for hybrid plugin")
+    parser.add_argument("--ep_size", type=int, default=2, help="ep size for hybrid plugin")
+    parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline for hybrid plugin")
     # kernel
     parser.add_argument(
         "--use_kernel",
         action="store_true",
-        help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations.",
+        help="Use kernel optim. Needs flash attention and triton installed to enable all kernel optimizations. Skipped if not installed.",
     )
     parser.add_argument(
         "--use_layernorm_kernel",
         action="store_true",
-        help="Use layernorm kernel. Need to install apex.",
+        help="Use layernorm kernel. Needs apex installed. Raises an error if not installed.",
     )

     # loss
@@ -149,23 +166,30 @@ def parse_args():
         "--router_aux_loss_factor",
         type=float,
         default=0.01,
-        help="router_aux_loss_factor.",
+        help="Moe router aux loss factor. You can refer to ST-MoE for details.",
     )
     parser.add_argument(
         "--router_z_loss_factor",
         type=float,
         default=0.0001,
-        help="router_z_loss_factor.",
+        help="Moe router z loss factor. You can refer to ST-MoE for details.",
+    )
+    parser.add_argument("--label_smoothing", type=float, default=0.0, help="Label smoothing.")
+    parser.add_argument(
+        "--z_loss_factor", type=float, default=0.0001, help="The final outputs' classification z loss factor."
     )
-    parser.add_argument("--label_smoothing", type=float, default=0.0, help="label_smoothing.")
-    parser.add_argument("--z_loss_factor", type=float, default=0.0001, help="z_loss_factor.")

     # load balance
-    parser.add_argument("--load_balance", action="store_true", help="moe load balance")
-    parser.add_argument("--load_balance_interval", type=int, default=1000, help="moe load balance interval")
-
-    # overlap
-    parser.add_argument("--comm_overlap", action="store_true", help="moe comm overlap")
+    parser.add_argument(
+        "--load_balance", action="store_true", help="Expert load balance. Defaults to False. Recommended to enable."
+    )
+    parser.add_argument("--load_balance_interval", type=int, default=1000, help="Expert load balance interval.")
+    # communication overlap
+    parser.add_argument(
+        "--comm_overlap",
+        action="store_true",
+        help="Use communication overlap for MoE. Recommended to enable for multi-node training.",
+    )

     args = parser.parse_args()
     return args
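Before moving on to `main()`, it is worth seeing how these parallel sizes compose for the `ep_zero` plugin. A back-of-the-envelope sketch (illustrative numbers only; the real group construction lives in `MoeHybridParallelPlugin`):

```python
# for ep_zero: world = pp * dp, and the MoE parameters further split their
# dp replicas into extra_dp (outer zero) * ep (inner expert parallel) groups
world_size = 8        # e.g. a single node with 8 GPUs (assumption)
pp_size = 1           # ep_zero keeps the pipeline degree at 1
extra_dp_size = 2     # --extra_dp_size
dp_size = world_size // pp_size
assert dp_size % extra_dp_size == 0
ep_size = dp_size // extra_dp_size
print(f"dp={dp_size}, ep={ep_size}, moe extra dp={extra_dp_size}")  # dp=8, ep=4, moe extra dp=2
```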
@@ -179,16 +203,6 @@ def main():
     coordinator = DistCoordinator()
     test_mode = args.model_name == "test"

-    # Manage loggers
-    disable_existing_loggers()
-    logger = get_dist_logger()
-    if coordinator.is_master():
-        datasets.utils.logging.set_verbosity_warning()
-        transformers.utils.logging.set_verbosity_info()
-    else:
-        datasets.utils.logging.set_verbosity_error()
-        transformers.utils.logging.set_verbosity_error()
-
     # Set plugin
     booster_kwargs = {}
     hybrid_dict = {
@@ -196,7 +210,7 @@ def main():
         "custom_policy": OpenMoeForCausalLMPolicy(),
         "enable_fused_normalization": args.use_layernorm_kernel,
         "enable_jit_fused": args.use_kernel,
-        "precision": "bf16",
+        "precision": args.precision,
         "zero_stage": args.zero_stage,
     }
     mgr_dict = {
@@ -245,13 +259,13 @@ def main():
         )
     else:
         raise ValueError(f"Invalid plugin {args.plugin}")
-    logger.info(f"Set plugin as {plugin}", ranks=[0])
+    coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")

     # Build OpenMoe model
     if test_mode:
         config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base")
-        config.hidden_size = 64
-        config.intermediate_size = 128
+        config.hidden_size = 128
+        config.intermediate_size = 256
         config.vocab_size = 32000
     else:
         repo_name = "hpcaitech/openmoe-" + args.model_name
@@ -269,30 +283,38 @@ def main():
     )
     with skip_init():
         model = OpenMoeForCausalLM(config)
-    logger.info(f"Finish init model with config:\n{config}", ranks=[0])
+    coordinator.print_on_master(f"Finish init model with config:\n{config}")

     # Enable gradient checkpointing
     model.gradient_checkpointing_enable()

     # Prepare tokenizer and dataloader
     tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
-    dataset = RandomDataset(num_samples=1000 if not test_mode else 20, tokenizer=tokenizer)
-    dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
+    if test_mode:
+        dataset = RandomDataset(num_samples=20, tokenizer=tokenizer)
+        collate_fn = None
+    else:
+        dataset = load_dataset(args.dataset, args.task_name)
+        dataset = dataset["train"]
+        collate_fn = partial(tokenize_data, tokenizer=tokenizer, max_length=args.max_length)
+    dataloader = plugin.prepare_dataloader(
+        dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
+    )

     # Set optimizer
-    optimizer = Adafactor(model.parameters(), decay_rate=args.decay_rate, weight_decay=args.weight_decay)
+    optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

     # Set booster
     booster = Booster(plugin=plugin, **booster_kwargs)
-    model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
     if not test_mode:
         load_ckpt(repo_name, model, booster)
+    model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
     use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
     is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
-    logger.info(f"Finish init booster", ranks=[0])
+    coordinator.print_on_master(f"Finish init booster")

     # Start finetuning
-    logger.info(f"Start finetuning", ranks=[0])
+    coordinator.print_on_master(f"Start finetuning")
     for epoch in range(args.num_epoch):
         model.train()
         train_dataloader_iter = iter(dataloader)
@@ -331,14 +353,24 @@ def main():
                 optimizer.zero_grad()

                 # Apply load balance
-                if args.load_balance and args.load_balance_interval > 0 and step % args.load_balance_interval == 0:
+                if (
+                    args.load_balance
+                    and args.load_balance_interval > 0
+                    and (step + 1) % args.load_balance_interval == 0
+                ):
                     coordinator.print_on_master(f"Apply load balance")
                     apply_load_balance(model, optimizer)
+                # save checkpoint
+                if (step + 1) % args.save_interval == 0:
+                    coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
+                    booster.save_model(model, args.output_path, shard=True)
+
+        # save checkpoint at the end of each epoch
+        booster.save_model(model, args.output_path, shard=True)
+        coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")

-    # Finish training and evaluate
-    logger.info(f"Finish finetuning", ranks=[0])
-    booster.save_model(model, args.output_path)
-    logger.info(f"Saving model checkpoint to {args.output_path}", ranks=[0])
+    # Finish training
+    coordinator.print_on_master(f"Finish training")


 if __name__ == "__main__":
diff --git a/examples/language/openmoe/train.sh b/examples/language/openmoe/train.sh
index 6712aa10a88b..91cd3db8d7ba 100644
--- a/examples/language/openmoe/train.sh
+++ b/examples/language/openmoe/train.sh
@@ -1,9 +1,40 @@
-torchrun --standalone --nproc_per_node 4 train.py \
-    --model_name "base" \
-    --plugin "hybrid" \
-    --pp_size 2 \
-    --dp_size 1 \
-    --ep_size 2 \
-    --use_kernel \
+#!/bin/bash
+
+set -xue
+
+NUM_GPU=8
+MODEL="8b"
+SEQ_LENGTH=2048
+BATCH_SIZE=1
+LR=0.00001
+
+# ep zero
+torchrun --standalone --nproc_per_node $NUM_GPU train.py \
+    --num_epoch 1 \
+    --model_name $MODEL \
+    --plugin "ep_zero" \
+    --batch_size $BATCH_SIZE \
+    --lr $LR \
     --zero_stage 1 \
-    --batch_size 4
+    --extra_dp_size 2
+
+# ep
+# torchrun --standalone --nproc_per_node $NUM_GPU train.py \
+#     --num_epoch 1 \
+#     --model_name $MODEL \
+#     --plugin "ep" \
+#     --batch_size $BATCH_SIZE \
+#     --lr $LR \
+#     --zero_stage 1
+
+# hybrid
+# torchrun --standalone --nproc_per_node $NUM_GPU train.py \
+#     --num_epoch 1 \
+#     --model_name $MODEL \
+#     --plugin "hybrid" \
+#     --batch_size $BATCH_SIZE \
+#     --lr $LR \
+#     --zero_stage 1 \
+# 
--pp_size 2 \ +# --dp_size 1 \ +# --ep_size 2 \ From 25c329f22f16e882a1c0d0a537e403db35433219 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com> Date: Tue, 31 Oct 2023 14:53:16 +0800 Subject: [PATCH 38/46] Dist (#7) * dist bench * update fsdp --- .../openmoe/benchmark/benchmark_cai_dist.sh | 57 +++++++++++++++++++ .../openmoe/benchmark/benchmark_fsdp.py | 35 ++++++------ .../openmoe/benchmark/benchmark_fsdp.sh | 23 +++++++- 3 files changed, 97 insertions(+), 18 deletions(-) create mode 100755 examples/language/openmoe/benchmark/benchmark_cai_dist.sh diff --git a/examples/language/openmoe/benchmark/benchmark_cai_dist.sh b/examples/language/openmoe/benchmark/benchmark_cai_dist.sh new file mode 100755 index 000000000000..469d17e2934a --- /dev/null +++ b/examples/language/openmoe/benchmark/benchmark_cai_dist.sh @@ -0,0 +1,57 @@ +#!/bin/bash + +set -xue + +export NCCL_IB_HCA=mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1 +export NCCL_IB_DISABLE=0 +export NCCL_SOCKET_IFNAME=eth0 +export NCCL_IB_GID_INDEX=3 +export NCCL_IB_TIMEOUT=23 +export NCCL_IB_RETRY_CNT=7 +export TORCH_DISTRIBUTED_DEBUG=INFO +export TORCH_DISTRIBUTED_DETAIL=DEBUG +export GLOO_SOCKET_IFNAME=eth0 + +NUM_GPU=8 +MODEL="8b" +SEQ_LENGTH=2048 +WARMUP=20 +ACTIVE=4 + +# HACK: make model importable +example_dir=$(dirname $(realpath $(dirname $0))) +if [ -z ${PYTHONPATH+x} ]; then + export PYTHONPATH=$example_dir +else + export PYTHONPATH=$example_dir:$PYTHONPATH +fi + + +# ep +echo -e "\n\n Naive EP \n\n" +colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile.txt" \ + $example_dir/benchmark/benchmark_cai.py \ + --model_name $MODEL \ + --batch_size 12 \ + --seq_length $SEQ_LENGTH \ + --warmup $WARMUP \ + --active $ACTIVE \ + --plugin ep \ + --zero_stage 2 + + +# ep_zero +echo -e "\n\n EP-ZERO \n\n" +colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile.txt" \ + $example_dir/benchmark/benchmark_cai.py \ + --model_name $MODEL \ + --batch_size 20 \ + --seq_length $SEQ_LENGTH \ + --warmup $WARMUP \ + --active $ACTIVE \ + --plugin ep_zero \ + --use_kernel \ + --extra_dp_size 2 \ + --zero_stage 1 \ + --load_balance \ + --overlap_alltoall diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.py b/examples/language/openmoe/benchmark/benchmark_fsdp.py index 531e18313798..45a11ad636b2 100644 --- a/examples/language/openmoe/benchmark/benchmark_fsdp.py +++ b/examples/language/openmoe/benchmark/benchmark_fsdp.py @@ -4,9 +4,8 @@ import torch import torch.distributed as dist -import torch.multiprocessing as mp import tqdm -from model.modeling_openmoe import LlamaConfig, OpenMoeDecoderLayer, OpenMoeForCausalLM +from model.modeling_openmoe import LlamaConfig, OpenMoeDecoderLayer, OpenMoeForCausalLM, set_openmoe_args from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy @@ -19,7 +18,6 @@ class RandomDataset(Dataset): - def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000): self.num_samples = num_samples self.max_length = max_length @@ -38,12 +36,12 @@ def __getitem__(self, idx): def fsdp_main(rank, world_size, args): - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "14501" # initialize the process group - dist.init_process_group("nccl", rank=rank, world_size=world_size) - MOE_MANAGER.setup(seed=42, parallel=None, use_kernel_optim=False) + # initialize the process group + 
dist.init_process_group("nccl") + + MOE_MANAGER.setup(seed=42, parallel=None) dp_size = dist.get_world_size() dataset = RandomDataset( @@ -56,10 +54,14 @@ def fsdp_main(rank, world_size, args): torch.cuda.set_device(rank) config = LlamaConfig.from_pretrained("hpcaitech/openmoe-%s" % args.model_name) - setattr(config, "router_aux_loss_factor", 0.1) - setattr(config, "router_z_loss_factor", 0.1) - setattr(config, "label_smoothing", 0.1) - setattr(config, "z_loss_factor", 0.1) + set_openmoe_args( + config, + num_experts=config.num_experts, + moe_layer_interval=config.moe_layer_interval, + enable_load_balance=False, + enable_kernel=False, + enable_comm_overlap=False, + ) torch.set_default_dtype(torch.float16) model = OpenMoeForCausalLM(config) torch.set_default_dtype(torch.float32) @@ -72,9 +74,9 @@ def fsdp_main(rank, world_size, args): model = FSDP( model, mixed_precision=MixedPrecision( - param_dtype=torch.float16, - reduce_dtype=torch.float16, - buffer_dtype=torch.float16, + param_dtype=torch.bfloat16, + reduce_dtype=torch.bfloat16, + buffer_dtype=torch.bfloat16, ), auto_wrap_policy=auto_wrap_policy, device_id=torch.cuda.current_device(), @@ -132,5 +134,6 @@ def fsdp_main(rank, world_size, args): torch.manual_seed(42) - WORLD_SIZE = torch.cuda.device_count() - mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ["LOCAL_RANK"]) + fsdp_main(local_rank, world_size, args) diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.sh b/examples/language/openmoe/benchmark/benchmark_fsdp.sh index 0380ee1ade20..18b182dd832d 100755 --- a/examples/language/openmoe/benchmark/benchmark_fsdp.sh +++ b/examples/language/openmoe/benchmark/benchmark_fsdp.sh @@ -2,7 +2,16 @@ set -xue -NUM_GPU=8 +export NCCL_IB_HCA=mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1 +export NCCL_IB_DISABLE=0 +export NCCL_SOCKET_IFNAME=eth0 +export NCCL_IB_GID_INDEX=3 +export NCCL_IB_TIMEOUT=23 +export NCCL_IB_RETRY_CNT=7 +export TORCH_DISTRIBUTED_DEBUG=INFO +export TORCH_DISTRIBUTED_DETAIL=DEBUG +export GLOO_SOCKET_IFNAME=eth0 + MODEL="8b" BATCH_SIZE=1 SEQ_LENGTH=2048 @@ -17,7 +26,17 @@ else export PYTHONPATH=$example_dir:$PYTHONPATH fi -python $example_dir/benchmark/benchmark_fsdp.py \ +# single node +torchrun --standalone $example_dir/benchmark/benchmark_fsdp.py \ + --model_name $MODEL \ + --batch_size $BATCH_SIZE \ + --seq_length $SEQ_LENGTH \ + --warmup $WARMUP \ + --active $ACTIVE + +# multi node +torchrun --nproc_per_node=8 --nnodes=2 --node_rank=node_rank --master_addr=master_addr --master_port=master_port \ + $example_dir/benchmark/benchmark_fsdp.py \ --model_name $MODEL \ --batch_size $BATCH_SIZE \ --seq_length $SEQ_LENGTH \ From 6b03bd448df972d1cf8e050585372244a57e801b Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Tue, 31 Oct 2023 18:03:59 +0800 Subject: [PATCH 39/46] update dist script --- .../language/openmoe/benchmark/benchmark_cai_dist.sh | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/examples/language/openmoe/benchmark/benchmark_cai_dist.sh b/examples/language/openmoe/benchmark/benchmark_cai_dist.sh index 469d17e2934a..06d57e4f0574 100755 --- a/examples/language/openmoe/benchmark/benchmark_cai_dist.sh +++ b/examples/language/openmoe/benchmark/benchmark_cai_dist.sh @@ -2,16 +2,6 @@ set -xue -export NCCL_IB_HCA=mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1 -export NCCL_IB_DISABLE=0 -export NCCL_SOCKET_IFNAME=eth0 -export NCCL_IB_GID_INDEX=3 -export NCCL_IB_TIMEOUT=23 -export NCCL_IB_RETRY_CNT=7 -export 
TORCH_DISTRIBUTED_DEBUG=INFO -export TORCH_DISTRIBUTED_DETAIL=DEBUG -export GLOO_SOCKET_IFNAME=eth0 - NUM_GPU=8 MODEL="8b" SEQ_LENGTH=2048 From 659c9b17383e4dd64000f0096d96dc2263c8822f Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Tue, 31 Oct 2023 18:04:37 +0800 Subject: [PATCH 40/46] update cai version --- examples/language/openmoe/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/language/openmoe/requirements.txt b/examples/language/openmoe/requirements.txt index 935a3f1e4ce0..ccf02ba1d0d6 100644 --- a/examples/language/openmoe/requirements.txt +++ b/examples/language/openmoe/requirements.txt @@ -1,4 +1,4 @@ -colossalai >= 0.1.12 +colossalai >= 0.3.3 torch >= 1.8.1 transformers >= 4.20.0 sentencepiece From caece568d7d5871a927fa0501f4a2adb2c77ce8d Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Tue, 31 Oct 2023 18:22:49 +0800 Subject: [PATCH 41/46] update fsdp --- examples/language/openmoe/benchmark/benchmark_fsdp.sh | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.sh b/examples/language/openmoe/benchmark/benchmark_fsdp.sh index 18b182dd832d..c6f5624dd746 100755 --- a/examples/language/openmoe/benchmark/benchmark_fsdp.sh +++ b/examples/language/openmoe/benchmark/benchmark_fsdp.sh @@ -2,16 +2,6 @@ set -xue -export NCCL_IB_HCA=mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1 -export NCCL_IB_DISABLE=0 -export NCCL_SOCKET_IFNAME=eth0 -export NCCL_IB_GID_INDEX=3 -export NCCL_IB_TIMEOUT=23 -export NCCL_IB_RETRY_CNT=7 -export TORCH_DISTRIBUTED_DEBUG=INFO -export TORCH_DISTRIBUTED_DETAIL=DEBUG -export GLOO_SOCKET_IFNAME=eth0 - MODEL="8b" BATCH_SIZE=1 SEQ_LENGTH=2048 From 9fe76802712cacdb1be656787be31423516526d8 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Tue, 31 Oct 2023 21:04:04 +0800 Subject: [PATCH 42/46] update zero --- colossalai/zero/low_level/low_level_optim.py | 68 ++++++++------------ tests/test_moe/test_moe_hybrid_zero.py | 2 + tests/test_moe/test_moe_load_balance.py | 1 + 3 files changed, 29 insertions(+), 42 deletions(-) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index a002d2087257..8cc05704e65a 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -359,18 +359,7 @@ def _run_reduction(self): flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size) grad_in_bucket = self._bucket_store.get_grad() - - for rank, grad_list in grad_in_bucket.items(): - sync_tensor(flat_grads_per_rank[rank], grad_list) - for grad in grad_list: - param_id = self._bucket_store.get_param_id_of_grad(grad) - if ( - len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) - < self._world_size - ): - self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) - else: - self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id) + self._update_unpartitoned_grad(grad_in_bucket.values(), flat_grads_per_rank, group_id) # sync extra zero group else: @@ -380,13 +369,13 @@ def _run_reduction(self): flat_grads_per_rank = non_moe_flat_grads.split( non_moe_flat_grads.numel() // self._world_size ) - self._sync_unpartitioned_grad(non_moe_grad_list, flat_grads_per_rank, group_id) + self._update_unpartitoned_grad(non_moe_grad_list, flat_grads_per_rank, group_id) # sync moe param only in zero group if len(moe_grad_list) > 0: dist.all_reduce(moe_flat_grads, group=self.moe_extra_dp_pg) flat_grads_per_rank = 
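
For the unpartitioned (stage-1 style) path above, every rank keeps a full gradient copy: the bucket is flattened, all-reduced over the dp group, split into one view per rank, and each view is copied back into that rank's original grad tensors. A condensed sketch of that round trip, where sync_back and all_reduce_bucket are illustrative names (sync_back plays the role of sync_tensor) and the per-rank lists are assumed to flatten to equal sizes, as the real bucket store guarantees by padding:

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def sync_back(flat: torch.Tensor, tensors: list) -> None:
    # Copy slices of the reduced flat buffer back into the original tensors.
    for src, dst in zip(_unflatten_dense_tensors(flat, tensors), tensors):
        dst.copy_(src)


def all_reduce_bucket(grads_per_rank: dict, group=None) -> None:
    # grads_per_rank: rank -> list of grad tensors, as in the bucket store.
    world_size = dist.get_world_size(group)
    flat = _flatten_dense_tensors([g for gs in grads_per_rank.values() for g in gs])
    flat.div_(world_size)  # average rather than sum
    dist.all_reduce(flat, group=group)
    shards = flat.split(flat.numel() // world_size)
    for shard, grads in zip(shards, grads_per_rank.values()):
        sync_back(shard, grads)
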
moe_flat_grads.split(moe_flat_grads.numel() // self._world_size) - self._sync_unpartitioned_grad(moe_grad_list, flat_grads_per_rank, group_id) + self._update_unpartitoned_grad(moe_grad_list, flat_grads_per_rank, group_id) else: if self.moe_extra_dp_pg is None: @@ -398,14 +387,9 @@ def _run_reduction(self): recieved_grad = recieved_grad.to(grad_dtype) grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank] - sync_tensor(recieved_grad, grad_in_bucket_current_rank) - for grad in grad_in_bucket_current_rank: - param_id = self._bucket_store.get_param_id_of_grad(grad) - if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < 1: - self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) - else: - self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id) + self._update_partitoned_grad(grad_in_bucket_current_rank, recieved_grad, group_id, 1) else: + # categorize moe and non moe param grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank] moe_grad_in_bucket_current_rank = [] non_moe_grad_in_bucket_current_rank = [] @@ -421,13 +405,9 @@ def _run_reduction(self): ) recieved_grad = torch.zeros_like(flat_grads_list[0]) dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg) - sync_tensor(recieved_grad, non_moe_grad_in_bucket_current_rank) - for grad in non_moe_grad_in_bucket_current_rank: - param_id = self._bucket_store.get_param_id_of_grad(grad) - if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < 1: - self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) - else: - self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id) + self._update_partitoned_grad( + non_moe_grad_in_bucket_current_rank, recieved_grad, group_id, 1 + ) if len(moe_grad_list) > 0: flat_grads_list = list( @@ -435,34 +415,38 @@ def _run_reduction(self): ) recieved_grad = torch.zeros_like(flat_grads_list[0]) dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.moe_extra_dp_pg) - param_slice = self._world_size // self.moe_extra_dp_pg_size recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice)) for split_recieved_grad in recieved_grad: split_recieved_grad = _unflatten_dense_tensors( split_recieved_grad, moe_grad_in_bucket_current_rank ) - for grad in moe_grad_in_bucket_current_rank: + for real_grad, grad in zip(split_recieved_grad, moe_grad_in_bucket_current_rank): param_id = self._bucket_store.get_param_id_of_grad(grad) - if ( - len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) - < param_slice - ): - self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) - else: - self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id) + self._add_grad(real_grad, param_slice, group_id, param_id) self._bucket_store.reset() - def _sync_unpartitioned_grad(self, origin_grad_list, flat_grad_list, group_id): + def _update_unpartitoned_grad(self, origin_grad_list: List, flat_grad_list: List, group_id: int) -> None: for rank, grad_list in enumerate(origin_grad_list): sync_tensor(flat_grad_list[rank], grad_list) for grad in grad_list: param_id = self._bucket_store.get_param_id_of_grad(grad) - if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < self._world_size: - self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) - else: - self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id) + self._add_grad(grad, self._world_size, 
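
The partitioned (stage-2 style) path never materializes full gradients: each rank contributes one flat shard per destination rank to a reduce-scatter and receives back only the reduced shard it owns, which is what _update_partitoned_grad then writes into the grad store. A minimal sketch under the assumption of one equal-sized flat shard per rank (which the bucket store provides); reduce_scatter_bucket is an illustrative name:

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors


def reduce_scatter_bucket(grads_per_rank: dict, group=None) -> torch.Tensor:
    # One flat tensor per destination rank; this rank keeps only its shard.
    shard_list = [_flatten_dense_tensors(gs) for gs in grads_per_rank.values()]
    received = torch.zeros_like(shard_list[0])
    dist.reduce_scatter(received, shard_list, group=group)
    return received
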
group_id, param_id) + + def _update_partitoned_grad( + self, origin_grad_list: List, flat_grad: torch.Tensor, group_id: int, partition_num: int + ) -> None: + sync_tensor(flat_grad, origin_grad_list) + for grad in origin_grad_list: + param_id = self._bucket_store.get_param_id_of_grad(grad) + self._add_grad(grad, partition_num, group_id, param_id) + + def _add_grad(self, grad: torch.Tensor, partition_num: int, group_id: int, param_id: int) -> None: + if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num: + self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) + else: + self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id) def _add_to_bucket(self, param, group_id): param_size = param.numel() diff --git a/tests/test_moe/test_moe_hybrid_zero.py b/tests/test_moe/test_moe_hybrid_zero.py index 142af5de98a9..e9f71d5ca635 100644 --- a/tests/test_moe/test_moe_hybrid_zero.py +++ b/tests/test_moe/test_moe_hybrid_zero.py @@ -34,6 +34,7 @@ def run_zero_optim_test(local_rank, world_size, stage=1): data = torch.randn(16, 4).cuda() label = torch.randint(0, 4, (16,)).cuda() + MOE_MANAGER.__init__() MOE_MANAGER.setup(seed=42, parallel=None) torch_model = MoeModel() torch_optimizer = torch.optim.Adam(torch_model.parameters()) @@ -82,6 +83,7 @@ def run_zero_optim_test(local_rank, world_size, stage=1): def run_dist(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_zero_optim_test(rank, world_size, stage=1) + run_zero_optim_test(rank, world_size, stage=2) @pytest.mark.dist diff --git a/tests/test_moe/test_moe_load_balance.py b/tests/test_moe/test_moe_load_balance.py index 048e85311f8b..52c21da5b4a5 100644 --- a/tests/test_moe/test_moe_load_balance.py +++ b/tests/test_moe/test_moe_load_balance.py @@ -175,6 +175,7 @@ def run_dist(rank, world_size, port): run_zero_optim_test(rank, world_size, stage=1) run_zero_optim_test(rank, world_size, stage=2) run_hybrid_zero_optim_test(rank, world_size, stage=1) + run_hybrid_zero_optim_test(rank, world_size, stage=2) @pytest.mark.dist From 0eb56231ba1bed88089e4efa2e6e49667ed25415 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Wed, 1 Nov 2023 01:03:16 +0800 Subject: [PATCH 43/46] fix bug --- colossalai/moe/load_balance.py | 46 ++++++++++++-------- colossalai/zero/low_level/low_level_optim.py | 6 +-- tests/test_moe/test_moe_load_balance.py | 1 + 3 files changed, 31 insertions(+), 22 deletions(-) diff --git a/colossalai/moe/load_balance.py b/colossalai/moe/load_balance.py index 4a3d0fe4d096..85c12d73fa52 100644 --- a/colossalai/moe/load_balance.py +++ b/colossalai/moe/load_balance.py @@ -13,7 +13,6 @@ class LoadBalancer: - def __init__( self, experts: MLPExperts, @@ -41,6 +40,8 @@ def __init__( pp_size = 1 if MOE_MANAGER.pp_size is None else MOE_MANAGER.pp_size global_dp_group = ProcessGroupMesh(pp_size, dist.get_world_size() // pp_size) self.global_dp_group = global_dp_group.get_group_along_axis(1) + self.global_dp_rank = dist.get_rank(self.global_dp_group) + self.global_dp_size = dist.get_world_size(self.global_dp_group) def _clear_load(self) -> None: self.local_load = None @@ -142,7 +143,7 @@ def _beam_search( for group_size_j in range(group_size): new_data = deepcopy(data) # calculate origin group sum - origin_diff = (origin_diff_list[group_num_i] + origin_diff_list[group_num_j]) + origin_diff = origin_diff_list[group_num_i] + origin_diff_list[group_num_j] # swap data self._swap_data( new_data, @@ -153,7 
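
_add_grad above collapses the two previously duplicated branches into one rule: while a parameter has fewer than partition_num stored shards, each freshly reduced gradient is appended; once the quota is reached, further reductions (e.g. gradient accumulation across micro-batches) are added into the shard at the matching rank index. A toy version of that rule, with TinyGradStore as an illustrative stand-in for the real grad store:

from collections import defaultdict

import torch


class TinyGradStore:
    def __init__(self):
        # param_id -> list of gradient shards, mirroring the grad store layout.
        self._shards = defaultdict(list)

    def add_grad(self, grad: torch.Tensor, partition_num: int, param_id: int, rank: int = 0) -> None:
        shards = self._shards[param_id]
        if len(shards) < partition_num:
            shards.append(grad.clone())  # first partition_num sightings: append
        else:
            shards[rank].add_(grad)      # afterwards: accumulate in place
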
+154,8 @@ def _beam_search( ) # calculate new group sum new_diff = self._get_diff_from_avg(new_data, group_num_i, avg) + self._get_diff_from_avg( - new_data, group_num_j, avg) + new_data, group_num_j, avg + ) # caculate score new_score = origin_diff - new_diff if new_score > 0: @@ -307,14 +309,15 @@ def _swap_expert_param_and_optim( # TODO: exchange master weight, skip for now # master weight is shared by dp group tmp = working_weight_ptr.view(-1).split( - working_weight_ptr.numel() // dist.get_world_size(self.moe_dp_group))[dist.get_rank(self.moe_dp_group)] + working_weight_ptr.numel() // dist.get_world_size(self.moe_dp_group) + )[dist.get_rank(self.moe_dp_group)] master_weight_ptr.data.copy_(tmp.clone().detach().to(master_weight_ptr.device).to(master_weight_ptr.dtype)) # exchange optim self._swap_expert_single_tensor(exp_avg_ptr, expert_idx, comm_group, send_first, comm_rank) self._swap_expert_single_tensor(exp_avg_sq_ptr, expert_idx, comm_group, send_first, comm_rank) def _gather_global_dp_group(self, data: Tensor) -> Tensor: - data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size(self.global_dp_group))] + data_list = [torch.zeros_like(data) for _ in range(self.global_dp_size)] dist.all_gather(data_list, data, group=self.global_dp_group) data_list = torch.cat(data_list, dim=0) return data_list @@ -348,8 +351,12 @@ def _swap_moe_param(self, swap_list: List, optim: LowLevelZeroOptimizer) -> None global_master_gate_weight = self._gather_global_dp_group(master_gate_weight).view(gate_shape) global_gate_exp_avg = self._gather_global_dp_group(gate_exp_avg).view(gate_shape) global_gate_exp_avg_sq = self._gather_global_dp_group(gate_exp_avg_sq).view(gate_shape) - assert (self.gate.shape == global_master_gate_weight.shape == global_gate_exp_avg.shape == - global_gate_exp_avg_sq.shape) + assert ( + self.gate.shape + == global_master_gate_weight.shape + == global_gate_exp_avg.shape + == global_gate_exp_avg_sq.shape + ) for swap in swap_list: source_group, source_idx, target_group, target_idx = swap @@ -380,10 +387,10 @@ def _swap_moe_param(self, swap_list: List, optim: LowLevelZeroOptimizer) -> None source_expert_pos = source_group * self.local_expert_num + source_idx target_expert_pos = target_group * self.local_expert_num + target_idx for gate in [ - self.gate, - global_master_gate_weight, - global_gate_exp_avg, - global_gate_exp_avg_sq, + self.gate, + global_master_gate_weight, + global_gate_exp_avg, + global_gate_exp_avg_sq, ]: origin_source = gate.data[source_expert_pos].clone().detach() origin_target = gate.data[target_expert_pos].clone().detach() @@ -393,16 +400,17 @@ def _swap_moe_param(self, swap_list: List, optim: LowLevelZeroOptimizer) -> None ) # update gate - dp_group_rank = dist.get_rank(self.global_dp_group) - dp_group_size = dist.get_world_size(self.global_dp_group) - global_master_gate_weight = global_master_gate_weight.view(-1).split(global_master_gate_weight.numel() // - dp_group_size)[dp_group_rank] + global_master_gate_weight = global_master_gate_weight.view(-1).split( + global_master_gate_weight.numel() // self.global_dp_size + )[self.global_dp_rank] master_gate_weight.data.copy_(global_master_gate_weight) - global_gate_exp_avg = global_gate_exp_avg.view(-1).split(global_gate_exp_avg.numel() // - dp_group_size)[dp_group_rank] + global_gate_exp_avg = global_gate_exp_avg.view(-1).split(global_gate_exp_avg.numel() // self.global_dp_size)[ + self.global_dp_rank + ] gate_exp_avg.data.copy_(global_gate_exp_avg) - global_gate_exp_avg_sq = 
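
The swaps above must also patch the ZeRO master copies, which live flattened and evenly sharded across the dp group; the recurring idiom is to flatten the gathered full tensor, split it into dp_size equal pieces, and keep only the piece at this rank's index. As a sketch (local_master_shard is an illustrative name; even divisibility is assumed, as the padded ZeRO buffers provide):

import torch


def local_master_shard(full_tensor: torch.Tensor, dp_rank: int, dp_size: int) -> torch.Tensor:
    flat = full_tensor.view(-1)
    return flat.split(flat.numel() // dp_size)[dp_rank]

Writing master_weight.data.copy_(local_master_shard(global_weight, rank, size)) then updates exactly the slice this rank owns, which is the pattern the gate and optimizer-state updates above follow.
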
global_gate_exp_avg_sq.view(-1).split(global_gate_exp_avg_sq.numel() // - dp_group_size)[dp_group_rank] + global_gate_exp_avg_sq = global_gate_exp_avg_sq.view(-1).split( + global_gate_exp_avg_sq.numel() // self.global_dp_size + )[self.global_dp_rank] gate_exp_avg_sq.data.copy_(global_gate_exp_avg_sq) @torch.no_grad() diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 8cc05704e65a..932053dd1294 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -432,7 +432,7 @@ def _update_unpartitoned_grad(self, origin_grad_list: List, flat_grad_list: List sync_tensor(flat_grad_list[rank], grad_list) for grad in grad_list: param_id = self._bucket_store.get_param_id_of_grad(grad) - self._add_grad(grad, self._world_size, group_id, param_id) + self._add_grad(grad, self._world_size, group_id, param_id, rank) def _update_partitoned_grad( self, origin_grad_list: List, flat_grad: torch.Tensor, group_id: int, partition_num: int @@ -442,11 +442,11 @@ def _update_partitoned_grad( param_id = self._bucket_store.get_param_id_of_grad(grad) self._add_grad(grad, partition_num, group_id, param_id) - def _add_grad(self, grad: torch.Tensor, partition_num: int, group_id: int, param_id: int) -> None: + def _add_grad(self, grad: torch.Tensor, partition_num: int, group_id: int, param_id: int, rank: int = 0) -> None: if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num: self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) else: - self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id) + self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id) def _add_to_bucket(self, param, group_id): param_size = param.numel() diff --git a/tests/test_moe/test_moe_load_balance.py b/tests/test_moe/test_moe_load_balance.py index 52c21da5b4a5..173a7a356555 100644 --- a/tests/test_moe/test_moe_load_balance.py +++ b/tests/test_moe/test_moe_load_balance.py @@ -160,6 +160,7 @@ def run_hybrid_zero_optim_test(local_rank, world_size, stage=1): zero_optimizer.step() zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) torch_out = run_fwd_bwd(torch_model, data, label, criterion, None) + # TODO: high atol, check if bug exists assert torch.allclose(zero_out, torch_out, atol=8e-4), f"zero_out:{zero_out}\ntorch_out{torch_out}" From 4be194a5b28179522b068de71b2797a1b21f919d Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Wed, 1 Nov 2023 10:06:21 +0800 Subject: [PATCH 44/46] reverse legacy --- colossalai/legacy/context/random/__init__.py | 14 ++++++++++++-- colossalai/legacy/context/random/_helper.py | 10 ++++++++++ .../legacy/engine/gradient_handler/__init__.py | 9 +++++++-- .../gradient_handler/_moe_gradient_handler.py | 2 +- 4 files changed, 30 insertions(+), 5 deletions(-) diff --git a/colossalai/legacy/context/random/__init__.py b/colossalai/legacy/context/random/__init__.py index e2314f859d3f..5e8d82922ddc 100644 --- a/colossalai/legacy/context/random/__init__.py +++ b/colossalai/legacy/context/random/__init__.py @@ -3,6 +3,7 @@ get_current_mode, get_seeds, get_states, + moe_set_seed, reset_seeds, seed, set_mode, @@ -12,6 +13,15 @@ ) __all__ = [ - 'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds', 'get_states', 'get_current_mode', 'set_seed_states', - 'sync_states', 'reset_seeds' + "seed", + "set_mode", + "with_seed", + "add_seed", + "get_seeds", + "get_states", + "get_current_mode", + "set_seed_states", + 
"sync_states", + "moe_set_seed", + "reset_seeds", ] diff --git a/colossalai/legacy/context/random/_helper.py b/colossalai/legacy/context/random/_helper.py index 7d27b3f85db9..be1d951d1229 100644 --- a/colossalai/legacy/context/random/_helper.py +++ b/colossalai/legacy/context/random/_helper.py @@ -159,5 +159,15 @@ def wrapper(*args, **kwargs): return wrapper +def moe_set_seed(seed): + if torch.cuda.is_available(): + from colossalai.legacy.core import global_context as gpc + + global_rank = gpc.get_global_rank() + diff_seed = seed + global_rank + add_seed(ParallelMode.TENSOR, diff_seed, True) + print(f"moe seed condition: {global_rank} with tensor seed {diff_seed}", flush=True) + + def reset_seeds(): _SEED_MANAGER.reset() diff --git a/colossalai/legacy/engine/gradient_handler/__init__.py b/colossalai/legacy/engine/gradient_handler/__init__.py index 633e9f885918..78928b138842 100644 --- a/colossalai/legacy/engine/gradient_handler/__init__.py +++ b/colossalai/legacy/engine/gradient_handler/__init__.py @@ -1,10 +1,15 @@ from ._base_gradient_handler import BaseGradientHandler from ._data_parallel_gradient_handler import DataParallelGradientHandler +from ._moe_gradient_handler import MoeGradientHandler from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler from ._zero_gradient_handler import ZeROGradientHandler __all__ = [ - 'BaseGradientHandler', 'DataParallelGradientHandler', 'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler', - 'SequenceParallelGradientHandler' + "BaseGradientHandler", + "DataParallelGradientHandler", + "ZeROGradientHandler", + "PipelineSharedModuleGradientHandler", + "MoeGradientHandler", + "SequenceParallelGradientHandler", ] diff --git a/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py index 2c999ca77be7..6a7224cff7bd 100644 --- a/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py +++ b/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py @@ -2,7 +2,7 @@ from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.registry import GRADIENT_HANDLER -from colossalai.moe.utils import get_moe_epsize_param_dict +from colossalai.utils.moe import get_moe_epsize_param_dict from ._base_gradient_handler import BaseGradientHandler from .utils import bucket_allreduce From 7e92e7b3b5e1bb96d69211367e04090853b90dcb Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Wed, 1 Nov 2023 10:08:11 +0800 Subject: [PATCH 45/46] update --- colossalai/legacy/engine/gradient_handler/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/colossalai/legacy/engine/gradient_handler/__init__.py b/colossalai/legacy/engine/gradient_handler/__init__.py index 78928b138842..713df5a64783 100644 --- a/colossalai/legacy/engine/gradient_handler/__init__.py +++ b/colossalai/legacy/engine/gradient_handler/__init__.py @@ -1,6 +1,5 @@ from ._base_gradient_handler import BaseGradientHandler from ._data_parallel_gradient_handler import DataParallelGradientHandler -from ._moe_gradient_handler import MoeGradientHandler from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler from ._zero_gradient_handler import ZeROGradientHandler @@ -10,6 +9,5 @@ "DataParallelGradientHandler", 
"ZeROGradientHandler", "PipelineSharedModuleGradientHandler", - "MoeGradientHandler", "SequenceParallelGradientHandler", ] From da6392f6988f294be0cc22259bddd5ff1cfe807f Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Wed, 1 Nov 2023 13:54:58 +0800 Subject: [PATCH 46/46] update readme --- examples/language/openmoe/README.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/language/openmoe/README.md b/examples/language/openmoe/README.md index 3873232c5952..a0821a5330a4 100644 --- a/examples/language/openmoe/README.md +++ b/examples/language/openmoe/README.md @@ -14,9 +14,7 @@ CUDA_EXT=1 pip install -U git+https://github.com/hpcaitech/ColossalAI Then install dependencies. ```bash -cd ColossalAI -pip install -r requirements.txt -cd examples/language/openmoe +cd ColossalAI/examples/language/openmoe pip install -r requirements.txt ``` @@ -30,7 +28,7 @@ We have utilized `Triton`, `FlashAttention` and `Apex` kernel for better perform pip install triton # install flash attention via pip -pip install flash-attn +pip install flash-attn==2.0.5 # install apex from source git clone https://github.com/NVIDIA/apex.git