diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 1f3bb294a7ca..784204528d65 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -1,10 +1,15 @@
+import random
 from typing import Optional
 
+import numpy as np
 import torch
 import torch.distributed as dist
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
 
 from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelPlugin
 from colossalai.cluster import ProcessGroupMesh
+from colossalai.moe import MoeCheckpintIO
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig
@@ -174,3 +179,59 @@ def __init__(self,
                                               partition_grad=(self.zero_stage == 2))
         self.max_norm = max_norm
+
+    def prepare_dataloader(self,
+                           dataset,
+                           batch_size,
+                           shuffle=False,
+                           seed=1024,
+                           drop_last=False,
+                           pin_memory=False,
+                           num_workers=0,
+                           **kwargs):
+        r"""
+        Prepare a dataloader for distributed training. The dataloader will be wrapped by
+        `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
+
+
+        Args:
+            dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
+            shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
+            seed (int, optional): Random worker seed for sampling, defaults to 1024.
+            add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
+            drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
+                is not divisible by the batch size. If False and the size of dataset is not divisible by
+                the batch size, then the last batch will be smaller, defaults to False.
+            pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
+            num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
+            kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
+                `DataLoader `_.
+
+        Returns:
+            :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
+ """ + _kwargs = kwargs.copy() + sampler = DistributedSampler(dataset, + num_replicas=self.pg_mesh.size(DP_AXIS), + rank=self.pg_mesh.coordinate(DP_AXIS), + shuffle=shuffle) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader(dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs) + + def get_checkpoint_io(self) -> MoeCheckpintIO: + self.checkpoint_io = MoeCheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) + return self.checkpoint_io diff --git a/colossalai/moe/checkpoint.py b/colossalai/moe/checkpoint.py index 3cda5a7f044c..99e0ae811bbd 100644 --- a/colossalai/moe/checkpoint.py +++ b/colossalai/moe/checkpoint.py @@ -1,25 +1,53 @@ +import logging +import os from copy import deepcopy from pathlib import Path -from typing import Optional +from typing import Iterator, Optional, OrderedDict, Tuple import torch import torch.distributed as dist import torch.nn as nn +from torch.distributed import ProcessGroup from torch.optim import Optimizer -from colossalai.checkpoint_io import GeneralCheckpointIO +from colossalai.checkpoint_io import CheckpointIndexFile, HybridParallelCheckpointIO +from colossalai.checkpoint_io.utils import ( + StateDictSharder, + gather_distributed_param, + get_model_base_filenames, + is_safetensors_available, + load_shard_state_dict, + load_state_dict_into_model, + save_config_file, + save_state_dict_shards, +) +from colossalai.moe.manager import MOE_MANAGER from colossalai.tensor.moe_tensor.api import get_dp_rank, get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor -class MoeCheckpintIO(GeneralCheckpointIO): +class MoeCheckpintIO(HybridParallelCheckpointIO): - def __init__(self) -> None: - super().__init__() + def __init__( + self, + dp_group: ProcessGroup, + pp_group: ProcessGroup, + tp_group: ProcessGroup, + zero_stage: int, + ) -> None: + assert zero_stage in [ + 0, + 1, + 2, + ], f"zero_stage should be 0 or 1 or 2, got {zero_stage}" + super().__init__(dp_group, pp_group, tp_group, zero_stage) + self.parallel = MOE_MANAGER.parallel - def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool): - state_dict = torch.load(checkpoint) + def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict: + """ + Preprocess state_dict before loading and slice the state_dict of MOE tensors. + """ for name, param in state_dict.items(): - if '.experts.' in name: + if ".experts." in name: model_param = dict(model.named_parameters())[name] if is_moe_tensor(model_param): ep_rank = get_ep_rank(model_param) @@ -28,13 +56,99 @@ def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool): assert param.shape[0] % ep_size == 0 param = param[ep_rank * expert_num:(ep_rank + 1) * expert_num] state_dict[name] = param + dist.barrier() + return state_dict + + def _model_sharder( + self, + state_dict: nn.Module, + prefix: str = "", + keep_vars: bool = False, + size_per_shard: int = 1024, + ) -> Iterator[Tuple[OrderedDict, int]]: + # An internel method that breaks state_dict of model into shards within limited size. + state_dict_sharder = StateDictSharder(size_per_shard) + + for name, param in state_dict.items(): + if param is None: + continue + # Gather tensor pieces when using tensor parallel. 
+            param_ = gather_distributed_param(param, keep_vars=False)
+            block, block_size = state_dict_sharder.append_param(prefix + name, param_)
+            if block is not None:
+                yield block, block_size
+
+        # Return the last block in sharder.
+        yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
+
+    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool) -> None:
+        state_dict = torch.load(checkpoint)
+        state_dict = self.pre_load_model(model, state_dict)
+        model.load_state_dict(state_dict, strict=strict if self.pp_size == 1 else False)
+
+    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
+        """
+        Load a sharded model with the given path to the index file of the checkpoint folder.
+
+        Args:
+            model (nn.Module): The model to be loaded.
+            checkpoint_index_file (str): Path to the index file of the checkpointing folder.
+            strict (bool, optional): For name matching during loading state_dict. Defaults to False.
+                This argument should be manually set to False since params on the same device might be stored in different files.
+        """
+
+        # Check whether the checkpoint uses safetensors.
+        use_safetensors = False
+        if "safetensors" in checkpoint_index_file.name:
+            use_safetensors = True
+
+        if use_safetensors and not is_safetensors_available():
+            raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
+
+        # Read checkpoint index file.
+        ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
+        ckpt_root_path = ckpt_index_file.root_path
+        weight_map = ckpt_index_file.weight_map
+        strict = False
+
+        # Load params & buffers to model.
+        # Keep a record of loaded files so that a file will not be repeatedly loaded.
+        loaded_file = set()
+
+        def _load(name: str):
+            if name not in weight_map:
+                raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
+            filename = weight_map[name]
+
+            # If this param/buffer has been loaded before, directly return.
+            if filename in loaded_file:
+                return
 
-        model.load_state_dict(state_dict, strict=strict)
+            file_path = os.path.join(ckpt_root_path, filename)
+            state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
+            state_dict = self.pre_load_model(model, state_dict)
+            missing_keys = []
 
-    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+            load_state_dict_into_model(
+                model,
+                state_dict,
+                missing_keys=missing_keys,
+                strict=strict,
+                load_sub_module=True,
+            )
+            loaded_file.add(filename)
+
+        # Load parameters.
+        for name, _ in model.named_parameters():
+            _load(name)
+
+        if self.verbose:
+            logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
+
+    def pre_save_model(self, model: nn.Module) -> dict:
         state_dict = model.state_dict()
         for name, param in model.named_parameters():
-            if '.experts.' in name and is_moe_tensor(param):
+            if ".experts." in name and is_moe_tensor(param):
                 ep_group = get_ep_group(param)
                 ep_rank = get_ep_rank(param)
                 ep_size = get_ep_size(param)
@@ -45,19 +159,95 @@ def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor
                     # gather param from every ep rank
                     dist.all_gather(all_param, param, group=ep_group)
                     if ep_rank == 0:
-                        assert dist.get_rank() == 0
                         all_param = torch.cat(all_param, dim=0)
                         state_dict[name] = all_param.cpu()
+        if self.pp_size > 1:
+            if self.dp_rank == 0:
+                out = [None for _ in range(self.pp_size)]
+                dist.all_gather_object(out, state_dict, group=self.pp_group)
+                if self.pp_rank == 0:
+                    new_state_dict = {}
+                    for o in out:
+                        new_state_dict.update(o)
+                    state_dict = new_state_dict
+        dist.barrier()
+        return state_dict
+
+    def save_unsharded_model(
+        self,
+        model: nn.Module,
+        checkpoint: str,
+        gather_dtensor: bool,
+        use_safetensors: bool,
+    ):
+        state_dict = self.pre_save_model(model)
         if dist.get_rank() == 0:
             torch.save(state_dict, checkpoint)
         dist.barrier()
 
-    def load_sharded_model(self, model: nn.Module, index_file_path: str, strict: bool):
-        raise NotImplementedError()
+    def save_sharded_model(
+        self,
+        model: nn.Module,
+        checkpoint: str,
+        gather_dtensor: bool = True,
+        prefix: Optional[str] = None,
+        size_per_shard: int = 1024,
+        use_safetensors: bool = False,
+    ) -> None:
+        """
+        Save sharded model checkpoint under the given checkpointing path.
+        The following files will be created under the path:
+        - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
+        - Multiple files that store state tensors of models.
+          The filenames are in the form of "pytorch_model.<prefix>-000XX.bin"
 
-    def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
-                           size_per_shard: int, use_safetensors: bool):
-        raise NotImplementedError()
+        Args:
+            model (nn.Module): Model on local device to be saved.
+            checkpoint (str): Checkpointing path which should be a directory path.
+            gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
+            prefix (str, optional): Prefix of file to save. Defaults to None.
+            size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
+            use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
+        """
+        if os.path.isfile(checkpoint):
+            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            return
+
+        Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+        # Then collect the sharded parameters & buffers along tp_group.
+        # Only devices with tp_rank == 0 are responsible for model saving.
+        state_dict = self.pre_save_model(model)
+
+        if dist.get_rank() == 0:
+            state_dict_shard = self._model_sharder(state_dict, size_per_shard=size_per_shard)
+
+            # Devices along the same dp_group share the same copies of model.
+            # So only let the device with dp_rank == 0 save the model.
+            if self.dp_rank != 0:
+                return
+
+            weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
+            index_file = CheckpointIndexFile(checkpoint)
+            control_saving = self.tp_rank == 0
+
+            total_size = save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint,
+                index_file=index_file,
+                base_filename=weights_name,
+                is_master=control_saving,
+                use_safetensors=use_safetensors,
+            )
+            if control_saving:
+                index_file.append_meta_data("total_size", total_size)
+                index_file.write_index_file(save_index_file)
+                save_config_file(model, checkpoint)
+                if self.verbose:
+                    logging.info(f"The model is split into checkpoint shards. "
+                                 f"You can find where each parameters has been saved in the "
+                                 f"index located at {save_index_file}.")
+        dist.barrier()
 
     # ========================================================
     # Abstract methods for optimizer loading/saving implementation
@@ -69,8 +259,14 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, pre
     def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
         raise NotImplementedError()
 
-    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str,
-                               size_per_shard: int):
+    def save_sharded_optimizer(
+        self,
+        optimizer: Optimizer,
+        checkpoint: Path,
+        gather_dtensor: bool,
+        prefix: str,
+        size_per_shard: int,
+    ):
         raise NotImplementedError()
 
     def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py
index 7f36f8a88925..d7dbd58ed0ca 100644
--- a/examples/language/openmoe/benchmark/benchmark_cai.py
+++ b/examples/language/openmoe/benchmark/benchmark_cai.py
@@ -72,7 +72,7 @@ def parse_args():
         type=str,
         default="hybrid",
         help="parallel plugin",
-        choices=["zero1", "zero2", "hybrid"],
+        choices=["zero2", "zero2_ep", "hybrid"],
     )
     # hybrid plugin
     parser.add_argument("--pp_size", type=int, default=2, help="pp size")
@@ -112,17 +112,24 @@ def main():
 
     # Set plugin
     booster_kwargs = {}
-    if args.plugin == "zero1":
+    if args.plugin == "zero2":
         dp_size = dist.get_world_size()
-        plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=1)
+        plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2)
         MOE_MANAGER.setup(
             seed=42,
-            parallel="EP",
+            parallel=None,
             use_kernel_optim=args.use_kernel,
         )
-    elif args.plugin == "zero2":
+    elif args.plugin == "zero2_ep":
         dp_size = dist.get_world_size()
-        plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2)
+        plugin = MoeHybridParallelPlugin(
+            tp_size=1,
+            pp_size=1,
+            zero_stage=2,
+            custom_policy=OpenMoeForCausalLMPolicy(),
+            enable_fused_normalization=args.use_kernel,
+            enable_jit_fused=args.use_kernel,
+        )
         MOE_MANAGER.setup(
             seed=42,
             parallel="EP",
@@ -215,6 +222,7 @@ def main():
                 pbar.set_postfix({"loss": loss.item()})
             else:
                 # Forward pass
+                data = next(train_dataloader_iter)
                 data = move_to_cuda(data, torch.cuda.current_device())
                 outputs = model(**data)
                 loss = outputs["loss"]
diff --git a/examples/language/openmoe/benchmark/benchmark_cai.sh b/examples/language/openmoe/benchmark/benchmark_cai.sh
index 24d0c1b23ab2..620bd4901ccd 100755
--- a/examples/language/openmoe/benchmark/benchmark_cai.sh
+++ b/examples/language/openmoe/benchmark/benchmark_cai.sh
@@ -2,12 +2,11 @@
 
 set -xue
 
-NUM_GPU=4
-MODEL="base"
-BATCH_SIZE=1
+NUM_GPU=8
+MODEL="8b"
 SEQ_LENGTH=2048
-WARMUP=10
-ACTIVE=10
+WARMUP=5
+ACTIVE=5
 
 # HACK: make model importable
 example_dir=$(dirname $(realpath $(dirname $0)))
@@ -21,7 +20,7 @@ fi
 torchrun --standalone --nproc_per_node $NUM_GPU \
     $example_dir/benchmark/benchmark_cai.py \
     --model_name $MODEL \
-    --batch_size $BATCH_SIZE \
+    --batch_size 512 \
     --seq_length $SEQ_LENGTH \
     --warmup $WARMUP \
     --active $ACTIVE \
@@ -29,28 +28,28 @@ torchrun --standalone --nproc_per_node $NUM_GPU \
     --plugin hybrid \
     --pp_size 2 \
     --dp_size 1 \
-    --ep_size 2 \
+    --ep_size 4 \
     --zero_stage 1 \
-    --microbatch_size 1
+    --microbatch_size 32
 
-# zero1
+# zero2
 torchrun --standalone --nproc_per_node $NUM_GPU \
     $example_dir/benchmark/benchmark_cai.py \
     --model_name $MODEL \
-    --batch_size $BATCH_SIZE \
+    --batch_size 8 \
     --seq_length $SEQ_LENGTH \
     --warmup $WARMUP \
     --active $ACTIVE \
-    --plugin zero1 \
+    --plugin zero2 \
    --use_kernel
 
-# zero2
+# zero2_ep
 torchrun --standalone --nproc_per_node $NUM_GPU \
     $example_dir/benchmark/benchmark_cai.py \
     --model_name $MODEL \
-    --batch_size $BATCH_SIZE \
+    --batch_size 16 \
     --seq_length $SEQ_LENGTH \
     --warmup $WARMUP \
     --active $ACTIVE \
-    --plugin zero2 \
+    --plugin zero2_ep \
     --use_kernel
diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.py b/examples/language/openmoe/benchmark/benchmark_fsdp.py
index c7357c06e5c7..1b69c8d4abeb 100644
--- a/examples/language/openmoe/benchmark/benchmark_fsdp.py
+++ b/examples/language/openmoe/benchmark/benchmark_fsdp.py
@@ -1,13 +1,15 @@
 import argparse
+import functools
 import os
 
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import tqdm
-from model.modeling_openmoe import LlamaConfig, OpenMoeForCausalLM
+from model.modeling_openmoe import LlamaConfig, OpenMoeDecoderLayer, OpenMoeForCausalLM
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
+from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 from torch.utils.data import Dataset
 from torch.utils.data.distributed import DistributedSampler
 from transformers import Adafactor
@@ -18,8 +20,9 @@
 
 
 class RandomDataset(Dataset):
-
-    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
+    def __init__(
+        self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000
+    ):
         self.num_samples = num_samples
         self.max_length = max_length
         self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length))
@@ -45,9 +48,13 @@ def fsdp_main(rank, world_size, args):
     MOE_MANAGER.setup(seed=42, parallel=None, use_kernel_optim=False)
 
     dp_size = dist.get_world_size()
-    dataset = RandomDataset(max_length=args.seq_length,
-                            num_samples=args.batch_size * (args.warmup + args.active) * dp_size)
-    sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=False)
+    dataset = RandomDataset(
+        max_length=args.seq_length,
+        num_samples=args.batch_size * (args.warmup + args.active) * dp_size,
+    )
+    sampler = DistributedSampler(
+        dataset, rank=rank, num_replicas=world_size, shuffle=False
+    )
     train_kwargs = {"batch_size": args.batch_size, "sampler": sampler}
     train_loader = torch.utils.data.DataLoader(dataset, **train_kwargs)
     torch.cuda.set_device(rank)
@@ -57,7 +64,13 @@ def fsdp_main(rank, world_size, args):
     setattr(config, "router_z_loss_factor", 0.1)
     setattr(config, "label_smoothing", 0.1)
     setattr(config, "z_loss_factor", 0.1)
-    model = OpenMoeForCausalLM(config).to(rank)
+    model = OpenMoeForCausalLM(config)
+    auto_wrap_policy = functools.partial(
+        transformer_auto_wrap_policy,
+        transformer_layer_cls={
+            OpenMoeDecoderLayer,
+        },
+    )
     model = FSDP(
         model,
         mixed_precision=MixedPrecision(
@@ -65,6 +78,8 @@ def fsdp_main(rank, world_size, args):
             reduce_dtype=torch.float16,
             buffer_dtype=torch.float16,
         ),
+        auto_wrap_policy=auto_wrap_policy,
+        device_id=torch.cuda.current_device(),
     )
     optimizer = Adafactor(model.parameters())
     model.train()
@@ -99,7 +114,9 @@ def fsdp_main(rank, world_size, args):
     performance_evaluator.on_fit_end()
 
     if dist.get_rank() == 0:
-        print(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
+        print(
+            f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB"
+        )
 
 
 if __name__ == "__main__":
diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.sh b/examples/language/openmoe/benchmark/benchmark_fsdp.sh
index a4cb32019431..41ffcd882a3b 100755
--- a/examples/language/openmoe/benchmark/benchmark_fsdp.sh
+++ b/examples/language/openmoe/benchmark/benchmark_fsdp.sh
@@ -2,12 +2,12 @@
 
 set -xue
 
-NUM_GPU=4
-MODEL="base"
+NUM_GPU=8
+MODEL="8b"
 BATCH_SIZE=1
 SEQ_LENGTH=2048
-WARMUP=10
-ACTIVE=10
+WARMUP=5
+ACTIVE=5
 
 # HACK: make model importable
 example_dir=$(dirname $(realpath $(dirname $0)))
diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py
index 6933f108a09e..f8c79320fa57 100644
--- a/examples/language/openmoe/model/modeling_openmoe.py
+++ b/examples/language/openmoe/model/modeling_openmoe.py
@@ -997,10 +997,11 @@ def forward(ctx, logits, targets, z_loss):
         shifted = logits - max_logit
         exp_shifted = torch.exp(shifted)
         sum_exp = torch.sum(exp_shifted, axis=-1, keepdims=True)
-        log_softmax = shifted - torch.log(sum_exp)
+        sum_exp_log = torch.log(sum_exp)
+        log_softmax = shifted - sum_exp_log
         loss = -torch.sum(targets * log_softmax, axis=-1)
         # Add auxilliary z-loss term.
-        log_z = torch.squeeze(torch.log(sum_exp) + max_logit, axis=-1)
+        log_z = torch.squeeze(sum_exp_log + max_logit, axis=-1)
         total_z_loss = z_loss * torch.square(log_z)
         loss += total_z_loss
         ctx.z_loss = z_loss
diff --git a/examples/language/openmoe/model/openmoe_policy.py b/examples/language/openmoe/model/openmoe_policy.py
index cc82683cd319..f354bbea990e 100644
--- a/examples/language/openmoe/model/openmoe_policy.py
+++ b/examples/language/openmoe/model/openmoe_policy.py
@@ -97,7 +97,7 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
         else:
             module = self.model.model
 
-        layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages)
+        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
         stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
         method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
         self.append_or_create_method_replacement(description=method_replacement,
@@ -110,7 +110,7 @@ def get_held_layers(self) -> List[Module]:
         """Get pipeline layers for current stage."""
         assert self.pipeline_stage_manager is not None
 
-        if self.model.__class__.__name__ == "LlamaModel":
+        if self.model.__class__.__name__ == "OpenMoeModel":
             module = self.model
         else:
             module = self.model.model
@@ -126,6 +126,23 @@ def get_held_layers(self) -> List[Module]:
             held_layers.append(module.norm)
 
         return held_layers
+
+    @staticmethod
+    def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
+        """Divide layers into stages."""
+        if num_layers == 24 and num_stages == 4:
+            return [7, 7, 7, 3]
+        elif num_layers == 24 and num_stages == 2:
+            return [15, 9]
+        elif num_layers == 12 and num_stages == 4:
+            return [5, 5, 5, 1]
+        elif num_layers == 12 and num_stages == 2:
+            return [8, 4]
+        else:
+            print(f"num_layers: {num_layers}, num_stages: {num_stages} not optimized, use origin pp policy")
+            return Policy.distribute_layers(num_layers, num_stages)
 
 
 class OpenMoeModelPolicy(OpenMoePolicy):
@@ -401,7 +418,7 @@ def llama_for_causal_lm_forward(
     stage_manager: Optional[PipelineStageManager] = None,
     hidden_states: Optional[torch.FloatTensor] = None,
     stage_index: Optional[List[int]] = None,
-    chunk_head: Optional[bool] = None,
+    chunk_head: Optional[bool] = True,
     past_router_aux_loss: Optional[torch.FloatTensor] = None,
     past_router_z_loss: Optional[torch.FloatTensor] = None,
 ):
diff --git a/examples/language/openmoe/test_ci.sh b/examples/language/openmoe/test_ci.sh
index 86742e088f71..0f68db4275f7 100644
--- a/examples/language/openmoe/test_ci.sh
+++ b/examples/language/openmoe/test_ci.sh
@@ -7,5 +7,15 @@ python infer.py --model "test"
 torchrun --standalone --nproc_per_node 4 train.py \
     --num_epoch 1 \
     --model_name "test" \
-    --plugin zero2 \
+    --plugin zero2_ep \
+    --batch_size 1
+
+torchrun --standalone --nproc_per_node 4 train.py \
+    --model_name "test" \
+    --plugin "hybrid" \
+    --num_epoch 1 \
+    --pp_size 2 \
+    --dp_size 1 \
+    --ep_size 2 \
+    --zero_stage 1 \
     --batch_size 1
diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py
index a7f46f2f693b..6f239104328c 100644
--- a/examples/language/openmoe/train.py
+++ b/examples/language/openmoe/train.py
@@ -14,7 +14,6 @@
 import colossalai
 from colossalai import get_default_parser
 from colossalai.booster import Booster
-from colossalai.booster.plugin import LowLevelZeroPlugin
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.logging import disable_existing_loggers, get_dist_logger
@@ -28,7 +27,7 @@ def move_to_cuda(batch, device):
     return {k: v.to(device) for k, v in batch.items()}
 
 
-def load_ckpt(repo_name: str, model: OpenMoeForCausalLM):
+def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster):
     ckpt_path = snapshot_download(repo_name)
     # single ckpt
     if os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin")):
@@ -38,7 +37,7 @@ def load_ckpt(repo_name: str, model: OpenMoeForCausalLM):
         ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
     else:
         raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
-    MoeCheckpintIO().load_model(model, ckpt_path)
+    booster.load_model(model, ckpt_path)
 
 
 class RandomDataset(Dataset):
@@ -89,7 +88,7 @@ def parse_args():
         type=str,
         default="hybrid",
         help="parallel plugin",
-        choices=["zero1", "zero2", "hybrid"],
+        choices=["zero1_ep", "zero2_ep", "hybrid"],
     )
     # hybrid plugin
     parser.add_argument("--pp_size", type=int, default=2, help="pp size")
@@ -146,15 +145,29 @@ def main():
 
     # Set plugin
     booster_kwargs = {}
-    if args.plugin == "zero1":
-        plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=1)
+    if args.plugin == "zero1_ep":
+        plugin = MoeHybridParallelPlugin(
+            tp_size=1,
+            pp_size=1,
+            zero_stage=1,
+            custom_policy=OpenMoeForCausalLMPolicy(),
+            enable_fused_normalization=args.use_kernel,
+            enable_jit_fused=args.use_kernel,
+        )
         MOE_MANAGER.setup(
             seed=42,
             parallel="EP",
             use_kernel_optim=args.use_kernel if not test_mode else False,
         )
-    elif args.plugin == "zero2":
-        plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2)
+    elif args.plugin == "zero2_ep":
+        plugin = MoeHybridParallelPlugin(
+            tp_size=1,
+            pp_size=1,
+            zero_stage=2,
+            custom_policy=OpenMoeForCausalLMPolicy(),
+            enable_fused_normalization=args.use_kernel,
+            enable_jit_fused=args.use_kernel,
+        )
         MOE_MANAGER.setup(
             seed=42,
             parallel="EP",
@@ -198,8 +211,6 @@ def main():
         setattr(config, "z_loss_factor", args.z_loss_factor)
     with skip_init():
         model = OpenMoeForCausalLM(config)
-    if not test_mode:
-        load_ckpt(repo_name, model)
     logger.info(f"Finish init model with config:\n{config}", ranks=[0])
 
     # Enable gradient checkpointing
@@ -216,6 +227,8 @@ def main():
     # Set booster
     booster = Booster(plugin=plugin, **booster_kwargs)
     model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
+    if not test_mode:
+        load_ckpt(repo_name, model, booster)
     use_pipeline = (isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1)
     is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
     logger.info(f"Finish init booster", ranks=[0])
diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py
index 20eb0969ca24..489f5ebdacfc 100644
--- a/tests/test_moe/test_moe_checkpoint.py
+++ b/tests/test_moe/test_moe_checkpoint.py
@@ -1,47 +1,152 @@
 import os
+import shutil
 
 import pytest
 import torch
 import torch.distributed as dist
+from transformers.models.llama import LlamaConfig
 
 import colossalai
-from colossalai.moe import MoeCheckpintIO
+from colossalai.booster import Booster
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device
-from tests.test_moe.moe_utils import MoeModel
+from examples.language.openmoe.model.modeling_openmoe import OpenMoeForCausalLM
+from examples.language.openmoe.model.openmoe_policy import OpenMoeForCausalLMPolicy
 
 
-def exam_moe_checkpoint():
-    ckpt = MoeCheckpintIO()
-    model = MoeModel(checkpoint=True).to(get_current_device())
-    ckpt.save_model(model, 'temp_path.pth')
+def get_config():
+    config = LlamaConfig(
+        vocab_size=300,
+        hidden_size=32,
+        intermediate_size=64,
+        num_hidden_layers=2,
+        num_attention_heads=2,
+    )
+    settings = {
+        "vocab_size": 300,
+        "intermediate_size": 32,
+        "hidden_size": 16,
+        "num_hidden_layers": 2,
+        "head_dim": 4,
+        "num_attention_heads": 4,
+        "dropout_rate": 0.0,
+        "layer_norm_epsilon": 1e-06,
+        "hidden_act": "swiglu",
+        "num_experts": 16,
+        "topk": 2,
+        "capacity_factor_train": 1.25,
+        "capacity_factor_eval": 2.0,
+        "min_capacity": 4,
+        "noisy_policy": None,
+        "drop_tks": True,
+        "expert_parallel": None,
+        "gated": True,
+        "moe_layer_interval": 4,
+        "router_aux_loss_factor": 0.1,
+        "router_z_loss_factor": 0.1,
+        "label_smoothing": 0.1,
+        "z_loss_factor": 0.1,
+    }
+    for key, value in settings.items():
+        setattr(config, key, value)
+    return config
 
-    other_model = MoeModel(checkpoint=True).to(get_current_device())
-    ckpt.load_model(other_model, 'temp_path.pth')
-    state_0 = model.state_dict()
-    state_1 = other_model.state_dict()
-    for k, v in state_0.items():
-        u = state_1.get(k)
+
+def get_model(parallel):
+    config = get_config()
+    model = OpenMoeForCausalLM(config)
+
+    if parallel == None:
+        plugin = MoeHybridParallelPlugin(
+            tp_size=1,
+            pp_size=1,
+            zero_stage=0,
+            custom_policy=OpenMoeForCausalLMPolicy(),
+        )
+    elif parallel == "zero_ep":
+        plugin = MoeHybridParallelPlugin(
+            tp_size=1,
+            pp_size=1,
+            zero_stage=2,
+            custom_policy=OpenMoeForCausalLMPolicy(),
+        )
+    elif parallel == "hybrid":
+        plugin = MoeHybridParallelPlugin(
+            tp_size=1,
+            pp_size=2,
+            zero_stage=1,
+            microbatch_size=1,
+            custom_policy=OpenMoeForCausalLMPolicy(),
+        )
+    booster = Booster(plugin=plugin)
+    model, _, _, _, _ = booster.boost(model=model)
+    return model, booster
+
+
+def _test_moe_checkpoint(parallel, shard):
+    if parallel == None:
+        MOE_MANAGER.setup(
+            seed=42,
+            parallel=None,
+        )
+    elif parallel == "zero2_ep":
+        MOE_MANAGER.setup(
+            seed=42,
+            parallel="EP",
+        )
+    elif parallel == "hybrid":
+        MOE_MANAGER.setup(
+            seed=42,
+            parallel="EP",
+            mode="fixed",
+            fixed_dp_size=1,
+            fixed_ep_size=2,
+            fixed_pp_size=2,
+        )
+    model1, booster1 = get_model(parallel)
+    model2, booster2 = get_model(parallel)
+
+    if shard:
+        booster1.save_model(model1, "./tmp_ckpt", shard=True, size_per_shard=1)
+        booster2.load_model(model2, "./tmp_ckpt")
+    else:
+        booster1.save_model(model1, "tmp_ckpt.pth")
+        booster2.load_model(model2, "tmp_ckpt.pth")
+
+    state1 = model1.state_dict()
+    state2 = model2.state_dict()
+    for k, v in state1.items():
+        u = state2.get(k)
         assert torch.equal(u.data, v.data)
 
     if dist.get_rank() == 0:
-        os.remove('temp_path.pth')
+        if shard:
+            shutil.rmtree("./tmp_ckpt")
+        else:
+            os.remove("tmp_ckpt.pth")
 
 
-def _run_dist(rank, world_size, port):
-    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    MOE_MANAGER.setup(seed=42)
-    exam_moe_checkpoint()
+def _run_dist(rank, world_size, port, parallel, shard):
+    colossalai.launch(
+        config=dict(),
+        rank=rank,
+        world_size=world_size,
+        host="localhost",
+        port=port,
+        backend="nccl",
+    )
+    _test_moe_checkpoint(parallel, shard)
 
 
 @pytest.mark.dist
-@pytest.mark.parametrize("world_size", [2, 4])
+@pytest.mark.parametrize("world_size", [4])
+@pytest.mark.parametrize("parallel", [None, "zero_ep", "hybrid"])
+@pytest.mark.parametrize("shard", [True, False])
 @rerun_if_address_is_in_use()
-def test_moe_checkpoint(world_size):
-    spawn(_run_dist, world_size)
+def test_moe_checkpoint(world_size, parallel, shard):
+    spawn(_run_dist, world_size, parallel=parallel, shard=shard)
 
 
-if __name__ == '__main__':
-    test_moe_checkpoint(world_size=4)
+if __name__ == "__main__":
+    test_moe_checkpoint(world_size=4, parallel="hybrid", shard=True)