From 0dfb03af98cada5be05f479f7ae93e8cca712f58 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 14 Sep 2023 13:43:13 +0800 Subject: [PATCH 01/16] init policy --- .../language/openmoe/model/llama_policy.py | 468 ++++++++++++++++++ 1 file changed, 468 insertions(+) create mode 100644 examples/language/openmoe/model/llama_policy.py diff --git a/examples/language/openmoe/model/llama_policy.py b/examples/language/openmoe/model/llama_policy.py new file mode 100644 index 000000000000..c4421de6a36c --- /dev/null +++ b/examples/language/openmoe/model/llama_policy.py @@ -0,0 +1,468 @@ +import warnings +from functools import partial +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from modeling_openmoe import ( + OpenMoeAttention, + OpenMoeDecoderLayer, + OpenMoeForCausalLM, + OpenMoeMLP, + OpenMoeModel, + OpenMoePreTrainedModel, +) +from torch import Tensor +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module, MSELoss +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ['OpenMoePolicy', 'OpenMoeForCausalLMPolicy'] + + +class OpenMoePolicy(Policy): + + def config_sanity_check(self): + pass + + def preprocess(self): + if self.shard_config.enable_tensor_parallelism: + # Resize embedding + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + policy = {} + + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + raise NotImplementedError( + "openmoe dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + + if self.shard_config.enable_tensor_parallelism: + raise NotImplementedError("Tensor parallelism is not supported for openmoe model now.") + + # optimization configuration + if self.shard_config.enable_fused_normalization: + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=FusedRMSNorm, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=FusedRMSNorm, + ), + SubModuleReplacementDescription( + suffix="pre_extra_mlp_layernorm", + target_module=FusedRMSNorm, + ) + ], + policy=policy, + target_key=OpenMoeDecoderLayer) + + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="norm", + target_module=FusedRMSNorm, + ), + policy=policy, + target_key=OpenMoeModel) + + if self.shard_config.enable_flash_attention: + raise NotImplementedError("Flash attention has already been replaced in openmoe.") + + return policy + + def postprocess(self): + return self.model + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + 
to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "LlamaModel": + module = self.model + else: + module = self.model.model + + layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=model_cls) + + return + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == 'LlamaModel': + module = self.model + else: + module = self.model.model + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.norm) + + return held_layers + + +class OpenMoeModelPolicy(OpenMoePolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + from transformers.models.llama.modeling_llama import LlamaModel + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward(model_cls=LlamaModel, + new_forward=OpenMoePipelineForwards.llama_model_forward, + policy=policy) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + held_layers = super().get_held_layers() + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in llama model""" + return [] + + +class OpenMoeForCausalLMPolicy(OpenMoePolicy): + + def module_policy(self): + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + # add a new item for casual lm + new_item = { + OpenMoeForCausalLM: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) + ]) + } + policy.update(new_item) + + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward(model_cls=OpenMoeForCausalLM, + new_forward=OpenMoePipelineForwards.llama_for_causal_lm_forward, + policy=policy) + + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + llama_model = self.model.model + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: + if id(llama_model.embed_tokens.weight) == id( + self.model.lm_head.weight) and self.pipeline_stage_manager.num_stages > 1: + # tie weights + return [{ + 0: llama_model.embed_tokens.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight + }] + return [] + + +class OpenMoePipelineForwards: + ''' + This class serves as a micro 
library for forward function substitution of Llama models + under pipeline setting. + ''' + + @staticmethod + def llama_model_forward( + self: OpenMoeModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ): + logger = logging.get_logger(__name__) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + seq_length_with_past = seq_length + past_key_values_length = 0 + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. 
+ if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + if use_cache: + logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + use_cache = False + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + position_ids = torch.arange(past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions, for the first stage, hidden_states is the input embeddings, + # for the other stages, hidden_states is the output of the previous stage + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), + dtype=torch.bool, + device=hidden_states.device) + attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), hidden_states, + past_key_values_length) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + start_idx, end_idx = stage_index[0], stage_index[1] + for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if stage_manager.is_last_stage(): + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + # always return dict for imediate stage + return {'hidden_states': hidden_states} + + @staticmethod + def llama_for_causal_lm_forward( + 
self: OpenMoeForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + ```""" + logger = logging.get_logger(__name__) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. 
+ if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = OpenMoePipelineForwards.llama_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + past_key_values = None + all_hidden_states = None + all_self_attentions = None + all_cross_attentions = None + + if stage_manager.is_last_stage(): + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + hidden_states = outputs.get('hidden_states') + return {'hidden_states': hidden_states} From d8e4b8064f894605e15d31d10fd9854f4a444406 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 14 Sep 2023 13:52:49 +0800 Subject: [PATCH 02/16] renam,e --- examples/language/openmoe/model/__init__.py | 0 .../{llama_policy.py => openmoe_policy.py} | 83 ++++++++++++++----- 2 files changed, 63 insertions(+), 20 deletions(-) create mode 100644 examples/language/openmoe/model/__init__.py rename examples/language/openmoe/model/{llama_policy.py => openmoe_policy.py} (88%) diff --git a/examples/language/openmoe/model/__init__.py b/examples/language/openmoe/model/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/examples/language/openmoe/model/llama_policy.py b/examples/language/openmoe/model/openmoe_policy.py similarity index 88% rename from examples/language/openmoe/model/llama_policy.py rename to examples/language/openmoe/model/openmoe_policy.py index c4421de6a36c..53d4675f14c0 100644 --- a/examples/language/openmoe/model/llama_policy.py +++ b/examples/language/openmoe/model/openmoe_policy.py @@ -4,14 +4,7 @@ import torch import torch.nn as nn -from modeling_openmoe import ( - OpenMoeAttention, - OpenMoeDecoderLayer, - OpenMoeForCausalLM, - OpenMoeMLP, - OpenMoeModel, - OpenMoePreTrainedModel, -) +import torch.nn.functional as F from torch import Tensor from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module, MSELoss from transformers.modeling_outputs import ( @@ -21,10 +14,20 @@ ) from transformers.utils import logging +from colossalai.moe.manager import MOE_MANAGER from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.layer import 
FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from .modeling_openmoe import ( + OpenMoeAttention, + OpenMoeDecoderLayer, + OpenMoeForCausalLM, + OpenMoeMLP, + OpenMoeModel, + OpenMoePreTrainedModel, +) + __all__ = ['OpenMoePolicy', 'OpenMoeForCausalLMPolicy'] @@ -375,6 +378,7 @@ def llama_for_causal_lm_forward( stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + chunk_head: Optional[bool] = None, ): r""" Args: @@ -401,6 +405,9 @@ def llama_for_causal_lm_forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." ```""" + # reset moe loss + MOE_MANAGER.reset_loss() + logger = logging.get_logger(__name__) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = (output_hidden_states @@ -438,19 +445,55 @@ def llama_for_causal_lm_forward( if stage_manager.is_last_stage(): hidden_states = outputs[0] - logits = self.lm_head(hidden_states) + if self.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) + # if no training, just do forward + if labels is None: + logits = self.lm_head(hidden_states) + logits = logits.float() + # the vocab size for openmoe is 30w+ + # which causes great activation memory in training, up to 20G for one sequence + # so we use chunk and checkpoint to reduce memory + else: + if chunk_head == True: + + def create_custom_forward(module): + + def custom_forward(*inputs): + logits = module(inputs[0]) + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous().float() + shift_labels = inputs[1][..., 1:].contiguous() + # Flatten the tokens + loss = self._calculate_loss(shift_logits, shift_labels) + return loss + + return custom_forward + + aux_loss, z_loss = self._calculate_router_loss() + loss = aux_loss + z_loss + for batch_idx in range(hidden_states.shape[0]): + loss = loss + torch.utils.checkpoint.checkpoint( + create_custom_forward(self.lm_head), + hidden_states[batch_idx:batch_idx + 1, :], + labels[batch_idx:batch_idx + 1, :], + ) + logits = None + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + aux_loss, z_loss = self._calculate_router_loss() + loss = aux_loss + z_loss + loss = loss + self._calculate_loss(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:] From 7f68f632c252dede6c5fd72be80e966bd8aeebcd Mon Sep 17 00:00:00 2001 
From: oahzxl Date: Thu, 14 Sep 2023 17:15:15 +0800 Subject: [PATCH 03/16] update pp --- .../booster/plugin/hybrid_parallel_plugin.py | 13 +- .../openmoe/model/modeling_openmoe.py | 137 ++---------------- .../language/openmoe/model/openmoe_policy.py | 21 +-- examples/language/openmoe/train.py | 73 +++++++--- 4 files changed, 83 insertions(+), 161 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 3fbeebcc4110..d65bd437962e 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -22,6 +22,7 @@ from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer.policies.base_policy import Policy from colossalai.zero.low_level import LowLevelZeroOptimizer from .pp_plugin_base import PipelinePluginBase @@ -38,13 +39,15 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16): class HybridParallelModule(ModelWrapper): def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool, - ddp_config: dict) -> None: + ddp_config: dict, custom_policy: Policy) -> None: self.stage_manager = shard_config.pipeline_stage_manager self.dp_group = dp_group shardformer = ShardFormer(shard_config) - module, self.shared_params = shardformer.optimize(module) + if custom_policy is not None: + assert isinstance(custom_policy, object) + module, self.shared_params = shardformer.optimize(module, policy=custom_policy) # setting process groups for shared parameters self.shared_param_process_groups = [] @@ -302,7 +305,8 @@ def __init__(self, zero_bucket_size_in_m: int = 12, cpu_offload: bool = False, communication_dtype: Optional[torch.dtype] = None, - overlap_communication: bool = True) -> None: + overlap_communication: bool = True, + custom_policy: Policy = None) -> None: super().__init__() assert dist.get_world_size() % ( @@ -326,6 +330,7 @@ def __init__(self, self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size) self.stage_manager = None self.schedule = None + self.custom_policy = custom_policy assert zero_stage in (0, 1, 2) if self.pp_size > 1: assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism' @@ -405,7 +410,7 @@ def configure( if not isinstance(model, ModelWrapper): use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp, - self.ddp_config) + self.ddp_config, self.custom_policy) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if self.zero_stage == 0: if self.precision in ['fp16', 'bf16']: diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index 6ccbf64a60e4..7d95fedce26e 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -145,87 +145,6 @@ def apply_rotary_embedding(q, k, cos, sin, decode=False, rotary_index=None): return out_q, out_k -class LlamaRotaryEmbedding(torch.nn.Module): - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = 
max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.inv_freq = inv_freq - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache(seq_len=max_position_embeddings, - device=self.inv_freq.device, - dtype=torch.get_default_dtype()) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - ) - - -class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - t = t / self.scaling_factor - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - -class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - - if seq_len > self.max_position_embeddings: - base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) - - (self.scaling_factor - 1))**(self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq) - - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., :x.shape[-1] // 2] @@ -233,17 +152,6 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. - cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - def SwiGLU(x): """Gated linear unit activation function. 
Args: @@ -256,7 +164,7 @@ def SwiGLU(x): return x1 * (x2 * torch.sigmoid(x2)) -class LlamaMLP(nn.Module): +class OpenMoeMLP(nn.Module): def __init__(self, config): super().__init__() @@ -302,7 +210,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -class LlamaAttention(nn.Module): +class OpenMoeAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: LlamaConfig): @@ -321,22 +229,6 @@ def __init__(self, config: LlamaConfig): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) self.sin, self.cos = generate_fixed_pos_embedding(self.head_dim, self.max_position_embeddings, 1e4) - self._init_rope() - - def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) - else: - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = LlamaLinearScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor) - elif scaling_type == "dynamic": - self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -446,13 +338,13 @@ def forward( return attn_output, attn_weights, past_key_value -class LlamaDecoderLayer(nn.Module): +class OpenMoeDecoderLayer(nn.Module): def __init__(self, config: LlamaConfig, moe: bool): super().__init__() self.hidden_size = config.hidden_size self.moe = moe - self.self_attn = LlamaAttention(config=config) + self.self_attn = OpenMoeAttention(config=config) self.input_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) if self.moe: @@ -470,9 +362,9 @@ def __init__(self, config: LlamaConfig, moe: bool): activation=config.hidden_act, gated=config.gated) self.pre_extra_mlp_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) - self.extra_mlp = LlamaMLP(config) + self.extra_mlp = OpenMoeMLP(config) else: - self.mlp = LlamaMLP(config) + self.mlp = OpenMoeMLP(config) def forward( self, @@ -556,7 +448,7 @@ def forward( "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", LLAMA_START_DOCSTRING, ) -class LlamaPreTrainedModel(PreTrainedModel): +class OpenMoePreTrainedModel(PreTrainedModel): config_class = LlamaConfig base_model_prefix = "model" supports_gradient_checkpointing = True @@ -575,7 +467,7 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, LlamaModel): + if isinstance(module, OpenMoeModel): module.gradient_checkpointing = value @@ -647,7 +539,7 @@ def _set_gradient_checkpointing(self, module, value=False): "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", LLAMA_START_DOCSTRING, ) -class 
LlamaModel(LlamaPreTrainedModel): +class OpenMoeModel(OpenMoePreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] @@ -662,7 +554,7 @@ def __init__(self, config: LlamaConfig): self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, moe=True if (i + 1) % config.moe_layer_interval == 0 else False) + OpenMoeDecoderLayer(config, moe=True if (i + 1) % config.moe_layer_interval == 0 else False) for i in range(config.num_hidden_layers) ]) self.norm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -827,12 +719,12 @@ def custom_forward(*inputs): ) -class OpenMoeForCausalLM(LlamaPreTrainedModel): +class OpenMoeForCausalLM(OpenMoePreTrainedModel): # _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): super().__init__(config) - self.model = LlamaModel(config) + self.model = OpenMoeModel(config) self.pretraining_tp = config.pretraining_tp self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) @@ -1029,10 +921,7 @@ def _calculate_router_loss(self): z_loss = self.config.router_z_loss_factor * sum(z_loss) / len(z_loss) return aux_loss, z_loss - def _calculate_loss(self, - logits: torch.Tensor, - targets: torch.Tensor - ) -> torch.Tensor: + def _calculate_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: """Compute cross entropy and entropy for log probs and targets. Args: diff --git a/examples/language/openmoe/model/openmoe_policy.py b/examples/language/openmoe/model/openmoe_policy.py index 53d4675f14c0..df82e6deb721 100644 --- a/examples/language/openmoe/model/openmoe_policy.py +++ b/examples/language/openmoe/model/openmoe_policy.py @@ -1,32 +1,21 @@ import warnings from functools import partial -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Union import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module, MSELoss -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - SequenceClassifierOutputWithPast, -) +from torch.nn import Module +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.utils import logging from colossalai.moe.manager import MOE_MANAGER from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D +from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -from .modeling_openmoe import ( - OpenMoeAttention, - OpenMoeDecoderLayer, - OpenMoeForCausalLM, - OpenMoeMLP, - OpenMoeModel, - OpenMoePreTrainedModel, -) +from .modeling_openmoe import OpenMoeDecoderLayer, OpenMoeForCausalLM, OpenMoeModel __all__ = ['OpenMoePolicy', 'OpenMoeForCausalLMPolicy'] diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index 132f17a9ba0f..3ce97841730a 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -5,6 +5,7 @@ import transformers from huggingface_hub import snapshot_download from model.modeling_openmoe import OpenMoeForCausalLM +from model.openmoe_policy 
import OpenMoeForCausalLMPolicy from torch.utils.data import Dataset from tqdm import tqdm from transformers import Adafactor, T5Tokenizer @@ -13,7 +14,7 @@ import colossalai from colossalai import get_default_parser from colossalai.booster import Booster -from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin from colossalai.cluster import DistCoordinator from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.moe import MoeCheckpintIO @@ -59,6 +60,7 @@ def __getitem__(self, idx): def parse_args(): + # basic settings parser = get_default_parser() parser.add_argument("--model_name", type=str, @@ -74,6 +76,16 @@ def parse_args(): default=4, help="Batch size (per dp group) for the training dataloader.") parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + parser.add_argument("--plugin", + type=str, + default="zero2", + help="parallel plugin", + choices=["zero1", "zero2", "hybrid"]) + # hybrid plugin + parser.add_argument("--tp_size", type=int, default=1, help="tp size") + parser.add_argument("--pp_size", type=int, default=2, help="pp size") + parser.add_argument("--zero_stage", type=int, default=1, help="zero stage in hybrid plugin") + parser.add_argument("--microbatch_size", type=int, default=1, help="microbatch size") # loss parser.add_argument("--router_aux_loss_factor", type=float, default=0.01, help="router_aux_loss_factor.") parser.add_argument("--router_z_loss_factor", type=float, default=0.0001, help="router_z_loss_factor.") @@ -95,7 +107,7 @@ def main(): coordinator = DistCoordinator() # Set up moe - MOE_MANAGER.setup(seed=42, parallel="EP") + MOE_MANAGER.setup(seed=42, parallel=None) # Manage loggers disable_existing_loggers() @@ -129,12 +141,23 @@ def main(): # Set plugin booster_kwargs = {} - plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2) + if args.plugin == "zero1": + plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=1) + elif args.plugin == "zero2": + plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2) + elif args.plugin == "hybrid": + plugin = HybridParallelPlugin(tp_size=args.tp_size, + pp_size=args.pp_size, + zero_stage=args.zero_stage, + microbatch_size=args.microbatch_size, + custom_policy=OpenMoeForCausalLMPolicy()) + else: + raise ValueError(f"Invalid plugin {args.plugin}") logger.info(f"Set plugin as {plugin}", ranks=[0]) # Prepare tokenizer and dataloader tokenizer = T5Tokenizer.from_pretrained("google/umt5-small") - dataset = RandomDataset(num_samples=1000 if args.model_name != "test" else 1) + dataset = RandomDataset(num_samples=1000 if args.model_name != "test" else 10) dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True) # Set optimizer @@ -143,27 +166,43 @@ def main(): # Set booster booster = Booster(plugin=plugin, **booster_kwargs) model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader) + use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 + is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() logger.info(f"Finish init booster", ranks=[0]) # Start finetuning logger.info(f"Start finetuning", ranks=[0]) for epoch in range(args.num_epoch): model.train() - with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar: - for batch in pbar: - # Forward - optimizer.zero_grad() - batch = 
move_to_cuda(batch, torch.cuda.current_device()) + train_dataloader_iter = iter(dataloader) + total_len = len(train_dataloader_iter) + with tqdm(range(total_len), + desc=f'Epoch [{epoch + 1}/{args.num_epoch}]', + disable=not (coordinator.is_master() or is_pp_last_stage)) as pbar: + # Forward pass + for _ in pbar: + if use_pipeline: + outputs = booster.execute_pipeline(train_dataloader_iter, + model, + lambda x: x, + optimizer, + return_loss=True, + return_outputs=True) + # Backward and optimize + if is_pp_last_stage: + loss = outputs['loss'] + pbar.set_postfix({'loss': loss.item()}) + else: + data = next(train_dataloader_iter) + data = move_to_cuda(data, torch.cuda.current_device()) + outputs = model(**data) + loss = outputs['loss'] + # Backward + booster.backward(loss, optimizer) + pbar.set_postfix({'loss': loss.item()}) - outputs = model(use_cache=False, chunk_head=True, **batch) - loss = outputs['loss'] - - # Backward - booster.backward(loss, optimizer) optimizer.step() - - # Print batch loss - pbar.set_postfix({'loss': loss.item()}) + optimizer.zero_grad() # Finish training and evaluate logger.info(f"Finish finetuning", ranks=[0]) From af224af2731cdd1a6db48a532916e8e3ded2bbdb Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 15 Sep 2023 13:11:22 +0800 Subject: [PATCH 04/16] finish pp --- .../openmoe/model/modeling_openmoe.py | 5 +- .../language/openmoe/model/openmoe_policy.py | 66 +++++++++++-------- examples/language/openmoe/train.py | 11 ++-- 3 files changed, 45 insertions(+), 37 deletions(-) diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index 7d95fedce26e..d8289b791dd5 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -914,8 +914,9 @@ def _reorder_cache(past_key_values, beam_idx): past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),) return reordered_past - def _calculate_router_loss(self): - aux_loss, z_loss = MOE_MANAGER.get_loss() + def _calculate_router_loss(self, aux_loss: list = None, z_loss: list = None): + if aux_loss is None or z_loss is None: + aux_loss, z_loss = MOE_MANAGER.get_loss() assert len(aux_loss) == len(z_loss) == self.config.num_hidden_layers // self.config.moe_layer_interval aux_loss = self.config.router_aux_loss_factor * sum(aux_loss) / len(aux_loss) z_loss = self.config.router_z_loss_factor * sum(z_loss) / len(z_loss) diff --git a/examples/language/openmoe/model/openmoe_policy.py b/examples/language/openmoe/model/openmoe_policy.py index df82e6deb721..21e25bcb73a0 100644 --- a/examples/language/openmoe/model/openmoe_policy.py +++ b/examples/language/openmoe/model/openmoe_policy.py @@ -130,11 +130,10 @@ def __init__(self) -> None: def module_policy(self): policy = super().module_policy() - from transformers.models.llama.modeling_llama import LlamaModel if self.pipeline_stage_manager: # set None as default - self.set_pipeline_forward(model_cls=LlamaModel, - new_forward=OpenMoePipelineForwards.llama_model_forward, + self.set_pipeline_forward(model_cls=OpenMoeModel, + new_forward=OpenMoePipelineForwards.openmoe_model_forward, policy=policy) return policy @@ -201,7 +200,7 @@ class OpenMoePipelineForwards: ''' @staticmethod - def llama_model_forward( + def openmoe_model_forward( self: OpenMoeModel, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, @@ -215,7 +214,12 @@ def llama_model_forward( stage_manager: Optional[PipelineStageManager] = None, 
hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, + past_router_aux_loss: Optional[torch.FloatTensor] = None, + past_router_z_loss: Optional[torch.FloatTensor] = None, ): + # reset moe loss for different data + MOE_MANAGER.reset_loss() + logger = logging.get_logger(__name__) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -339,17 +343,17 @@ def custom_forward(*inputs): if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + + # concat past losses with current ones + router_aux_loss, router_z_loss = MOE_MANAGER.get_loss() + if past_router_aux_loss is not None and past_router_z_loss is not None: + router_aux_loss = past_router_aux_loss + router_aux_loss + router_z_loss = past_router_z_loss + router_z_loss + if stage_manager.is_last_stage(): - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) + return tuple([hidden_states, next_cache, all_hidden_states, all_self_attns, router_aux_loss, router_z_loss]) # always return dict for imediate stage - return {'hidden_states': hidden_states} + return {'hidden_states': hidden_states, 'router_aux_loss': router_aux_loss, 'router_z_loss': router_z_loss} @staticmethod def llama_for_causal_lm_forward( @@ -368,6 +372,8 @@ def llama_for_causal_lm_forward( hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, chunk_head: Optional[bool] = None, + past_router_aux_loss: Optional[torch.FloatTensor] = None, + past_router_z_loss: Optional[torch.FloatTensor] = None, ): r""" Args: @@ -394,9 +400,6 @@ def llama_for_causal_lm_forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." 
```""" - # reset moe loss - MOE_MANAGER.reset_loss() - logger = logging.get_logger(__name__) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = (output_hidden_states @@ -412,7 +415,7 @@ def llama_for_causal_lm_forward( output_hidden_states = False # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = OpenMoePipelineForwards.llama_model_forward( + outputs = OpenMoePipelineForwards.openmoe_model_forward( self.model, input_ids=input_ids, attention_mask=attention_mask, @@ -426,14 +429,13 @@ def llama_for_causal_lm_forward( stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, + past_router_aux_loss=past_router_aux_loss, + past_router_z_loss=past_router_z_loss, ) - past_key_values = None - all_hidden_states = None - all_self_attentions = None - all_cross_attentions = None if stage_manager.is_last_stage(): - hidden_states = outputs[0] + hidden_states, past_key_values, all_hidden_states, attentions, router_aux_loss, router_z_loss = outputs + if self.pretraining_tp > 1: lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0) logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)] @@ -464,7 +466,7 @@ def custom_forward(*inputs): return custom_forward - aux_loss, z_loss = self._calculate_router_loss() + aux_loss, z_loss = self._calculate_router_loss(router_aux_loss, router_z_loss) loss = aux_loss + z_loss for batch_idx in range(hidden_states.shape[0]): loss = loss + torch.utils.checkpoint.checkpoint( @@ -480,7 +482,7 @@ def custom_forward(*inputs): shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens - aux_loss, z_loss = self._calculate_router_loss() + aux_loss, z_loss = self._calculate_router_loss(router_aux_loss, router_z_loss) loss = aux_loss + z_loss loss = loss + self._calculate_loss(shift_logits, shift_labels) @@ -491,10 +493,16 @@ def custom_forward(*inputs): return CausalLMOutputWithPast( loss=loss, logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=attentions, ) else: - hidden_states = outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = outputs['hidden_states'] + router_aux_loss = outputs['router_aux_loss'] + router_z_loss = outputs['router_z_loss'] + return { + 'hidden_states': hidden_states, + 'past_router_aux_loss': router_aux_loss, + 'past_router_z_loss': router_z_loss + } diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index 3ce97841730a..6351d26ca0a1 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -78,7 +78,7 @@ def parse_args(): parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") parser.add_argument("--plugin", type=str, - default="zero2", + default="hybrid", help="parallel plugin", choices=["zero1", "zero2", "hybrid"]) # hybrid plugin @@ -157,7 +157,7 @@ def main(): # Prepare tokenizer and dataloader tokenizer = T5Tokenizer.from_pretrained("google/umt5-small") - dataset = RandomDataset(num_samples=1000 if args.model_name != "test" else 10) + dataset = RandomDataset(num_samples=1000 if args.model_name != "test" else 50) dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, 
shuffle=True, drop_last=True) # Set optimizer @@ -176,15 +176,14 @@ def main(): model.train() train_dataloader_iter = iter(dataloader) total_len = len(train_dataloader_iter) - with tqdm(range(total_len), - desc=f'Epoch [{epoch + 1}/{args.num_epoch}]', - disable=not (coordinator.is_master() or is_pp_last_stage)) as pbar: + with tqdm(range(total_len), desc=f'Epoch [{epoch + 1}/{args.num_epoch}]', + disable=not coordinator.is_master()) as pbar: # Forward pass for _ in pbar: if use_pipeline: outputs = booster.execute_pipeline(train_dataloader_iter, model, - lambda x: x, + lambda x, y: x.loss, optimizer, return_loss=True, return_outputs=True) From e275b0967e6259cec6bdd5f0fb762020893d38d2 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 18 Sep 2023 09:57:32 +0800 Subject: [PATCH 05/16] update script --- examples/language/openmoe/train.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/language/openmoe/train.sh b/examples/language/openmoe/train.sh index 9a55779ca5ef..a2fe425c5805 100644 --- a/examples/language/openmoe/train.sh +++ b/examples/language/openmoe/train.sh @@ -1,3 +1,3 @@ -torchrun --standalone --nproc_per_node 2 train.py \ +torchrun --standalone --nproc_per_node 4 train.py \ --model_name "base" \ --batch_size 4 From dd6da186a62df3aaaae005bab24a372f3d5c88ad Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 18 Sep 2023 10:54:34 +0800 Subject: [PATCH 06/16] update plugin --- .../plugin/moe_hybrid_parallel_plugin.py | 527 ++++++++++++++++++ 1 file changed, 527 insertions(+) create mode 100644 colossalai/booster/plugin/moe_hybrid_parallel_plugin.py diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py new file mode 100644 index 000000000000..d65bd437962e --- /dev/null +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -0,0 +1,527 @@ +import random +from contextlib import nullcontext +from functools import partial +from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union + +import numpy as np +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.nn import Module, SyncBatchNorm +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils._pytree import tree_map +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer +from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO +from colossalai.cluster import ProcessGroupMesh +from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer.policies.base_policy import Policy +from colossalai.zero.low_level import LowLevelZeroOptimizer + +from .pp_plugin_base import PipelinePluginBase + +DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 + + +def _convert_floating_point(x, dtype: torch.dtype = torch.float16): + if isinstance(x, torch.Tensor) and torch.is_floating_point(x): + return x.to(dtype) + return x + + +class HybridParallelModule(ModelWrapper): + + def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool, + 
ddp_config: dict, custom_policy: Policy) -> None: + + self.stage_manager = shard_config.pipeline_stage_manager + self.dp_group = dp_group + + shardformer = ShardFormer(shard_config) + if custom_policy is not None: + assert isinstance(custom_policy, object) + module, self.shared_params = shardformer.optimize(module, policy=custom_policy) + + # setting process groups for shared parameters + self.shared_param_process_groups = [] + for shared_param in self.shared_params: + if len(shared_param) > 0: + self.shared_param_process_groups.append( + self.stage_manager.init_process_group_by_stages(list(shared_param.keys()))) + + # setting mixed_precision + self.mixed_precision = None + if precision == 'fp16': + self.mixed_precision = torch.float16 + elif precision == 'bf16': + self.mixed_precision = torch.bfloat16 + if self.mixed_precision is not None: + module = module.to(self.mixed_precision) + module = module.cuda() + + # setting input type cast when using mixed precision + self.convert_fn = None + if self.mixed_precision is not None: + self.convert_fn = partial(_convert_floating_point, dtype=self.mixed_precision) + + # setting ddp configs + if use_ddp: + # convert model to sync bn + module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group) + # wrap the model with PyTorch DDP + module = DDP(module, process_group=dp_group, **ddp_config) + + super().__init__(module) + + def sync_shared_params(self): + for shared_param, group in zip(self.shared_params, self.shared_param_process_groups): + if self.stage_manager.stage in shared_param: + param = shared_param[self.stage_manager.stage] + dist.all_reduce(param.grad, group=group) + dist.barrier() + + def no_sync(self) -> Iterator[None]: + # no sync grads across data parallel + return nullcontext() + + def sync_grads(self): + # sync grad across data parallel + if self.dp_group.size() == 1: + return + for p in self.module.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, group=self.dp_group) + p.grad.div_(self.dp_group.size()) + + def forward(self, *args, **kwargs): + if self.convert_fn is not None: + args = tree_map(self.convert_fn, args) + kwargs = tree_map(self.convert_fn, kwargs) + return super().forward(*args, **kwargs) + + def unwrap(self): + module = super().unwrap() + if isinstance(module, DDP): + module = module.module + return module + + +def get_param_info(optim: Optimizer): + # Get a backup of necessary information of parameters for future use, which includes: + # 1. A complete param_group, with params in the form of param_id + # 2. A mapping from param address (obtained using id(param)) to integer param_id + # 3. A mapping from integer param_id to param address. + # 4. A mapping from param_address (obtained using id(param)) to the original shape of parameter before sharding. + # When Zero is used, the params here are fp16/bf16 model params rather than fp32 master params in optimizer. 
+ + if optim is None: + return {} + param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}, 'param2shape': {}} + start_index = 0 + for group in optim.param_groups: + + packed_group = {k: v for k, v in group.items() if k != 'params'} + packed_group['params'] = [] + + for param_id, param in enumerate(group['params'], start_index): + original_shape = param.shape if isinstance(param, torch.Tensor) else None + packed_group['params'].append(param_id) + param_info['param2id'][id(param)] = param_id + param_info['id2param'][param_id] = id(param) + param_info['param2shape'][id(param)] = original_shape + + param_info['param_groups'].append(packed_group) + start_index += len(group['params']) + + return param_info + + +def init_pipeline_optimizer(optim: Optimizer, model: Module): + model_params = set(model.parameters()) + new_param_groups = [] + for group in optim.param_groups: + params = [p for p in group['params'] if p in model_params] + new_param_groups.append({**group, 'params': params}) + optim.__setstate__({'param_groups': new_param_groups}) + + +class HybridParallelNaiveOptimizer(OptimizerWrapper): + + def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict): + self.param_info = param_info + if use_pipeline: + init_pipeline_optimizer(optim, model) + super().__init__(optim) + + +class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): + + def __init__(self, + optim: Optimizer, + model: Module, + use_pipeline: bool, + param_info: OrderedDict, + precision: str = 'fp16', + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0): + self.param_info = param_info + if use_pipeline: + init_pipeline_optimizer(optim, model) + super().__init__(optim, precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, + hysteresis, max_scale, max_norm) + + +class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): + + def __init__( + self, + optimizer: Optimizer, + model: Module, + use_pipeline: bool, + param_info: OrderedDict, + initial_scale: int = 2**16, # grad scaler config + min_scale: int = 1, + growth_factor: float = 2., + backoff_factor: float = .5, + growth_interval: int = 2000, + hysteresis: int = 2, + max_scale: int = 2**24, + clip_grad_norm: float = 0.0, # grad clipping + verbose: bool = False, + reduce_bucket_size: int = 1024 * 1024, # communication + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + partition_grad: bool = False, # stage 2 flag + cpu_offload: bool = False, # cpu offload + dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm + tp_process_group: Optional[ProcessGroup] = None, # if using tp + forced_dtype: Optional[torch.dtype] = None): + self.param_info = param_info + if use_pipeline: + init_pipeline_optimizer(optimizer, model) + super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, + hysteresis, max_scale, clip_grad_norm, verbose, reduce_bucket_size, communication_dtype, + overlap_communication, partition_grad, cpu_offload, dp_process_group, tp_process_group, + forced_dtype) + + +class HybridParallelPlugin(PipelinePluginBase): + """ + Plugin for Hybrid Parallel Training. + Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin. 
+    The size of tp and pp should be passed in by the user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size).
+
+    Example:
+        >>> from colossalai.booster import Booster
+        >>> from colossalai.booster.plugin import HybridParallelPlugin
+
+        >>> model, train_dataset, optimizer, criterion = ...
+        >>> plugin = HybridParallelPlugin(tp_size=2, pp_size=2)
+
+        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+        >>> booster = Booster(plugin=plugin)
+        >>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
+
+    Args:
+        tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
+        pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
+        precision (str, optional): Specifies the precision of parameters during training.
+            Auto-mixed precision will be used when this argument is set to 'fp16' or 'bf16', otherwise the model is trained with 'fp32'.
+            Defaults to 'fp16'.
+        zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be chosen from [0, 1, 2].
+            When set to 0, ZeRO will not be used. Defaults to 0.
+        enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
+            Currently all the optimization methods include fused normalization, flash attention and JIT.
+            Defaults to False.
+        enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False.
+        enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
+        enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Defaults to False.
+        enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
+        enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
+        num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
+        microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
+            Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
+            If ``num_microbatches`` is provided, this will be ignored. Defaults to None.
+        initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
+        min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
+        growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
+        backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5.
+        growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000.
+        hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
+        max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
+        max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
+        broadcast_buffers (bool, optional): Whether to broadcast buffers at the beginning of training when using DDP. Defaults to True.
+        ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25.
+        find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False.
+        check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False.
+        gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False.
+        static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False.
+        zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12.
+        cpu_offload (bool, optional): Whether to enable cpu_offload when using ZeRO. Defaults to False.
+        communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
+        overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
+    """
+
+    def __init__(self,
+                 tp_size: int,
+                 pp_size: int,
+                 precision: str = 'fp16',
+                 zero_stage: int = 0,
+                 enable_all_optimization: bool = False,
+                 enable_fused_normalization: bool = False,
+                 enable_flash_attention: bool = False,
+                 enable_jit_fused: bool = False,
+                 enable_sequence_parallelism: bool = False,
+                 enable_sequence_overlap: bool = False,
+                 num_microbatches: Optional[int] = None,
+                 microbatch_size: Optional[int] = None,
+                 initial_scale: float = 2**16,
+                 min_scale: float = 1,
+                 growth_factor: float = 2,
+                 backoff_factor: float = 0.5,
+                 growth_interval: int = 1000,
+                 hysteresis: int = 2,
+                 max_scale: float = 2**32,
+                 max_norm: float = 0,
+                 broadcast_buffers: bool = True,
+                 ddp_bucket_cap_mb: int = 25,
+                 find_unused_parameters: bool = False,
+                 check_reduction: bool = False,
+                 gradient_as_bucket_view: bool = False,
+                 static_graph: bool = False,
+                 zero_bucket_size_in_m: int = 12,
+                 cpu_offload: bool = False,
+                 communication_dtype: Optional[torch.dtype] = None,
+                 overlap_communication: bool = True,
+                 custom_policy: Policy = None) -> None:
+
+        super().__init__()
+        assert dist.get_world_size() % (
+            tp_size * pp_size
+        ) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}'
+
+        if enable_sequence_parallelism:
+            assert tp_size > 1, 'Tensor parallelism must be enabled when using sequence parallelism'
+
+        self.tp_size = tp_size
+        self.pp_size = pp_size
+        self.dp_size = dist.get_world_size() // (tp_size * pp_size)
+        self.precision = precision
+        self.zero_stage = zero_stage
+        self.cpu_offload = cpu_offload
+        self.enable_all_optimization = enable_all_optimization
+        self.enable_fused_normalization = enable_fused_normalization
+        self.enable_flash_attention = enable_flash_attention
+        self.enable_jit_fused = enable_jit_fused
+        self.enable_sequence_parallelism = enable_sequence_parallelism
+        self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
+        self.stage_manager = None
+        self.schedule = None
+        self.custom_policy = custom_policy
+        assert zero_stage in (0, 1, 2)
+        if self.pp_size > 1:
+            assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism'
+            assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism'
+            self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
+            self.schedule = OneForwardOneBackwardSchedule(self.stage_manager,
+                                                          num_microbatches=num_microbatches,
+                                                          microbatch_size=microbatch_size)
+        self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
+        self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
+        self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
+
self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group, + pipeline_stage_manager=self.stage_manager, + enable_tensor_parallelism=self.tp_size > 1, + enable_all_optimization=self.enable_all_optimization, + enable_fused_normalization=self.enable_fused_normalization, + enable_flash_attention=self.enable_flash_attention, + enable_jit_fused=self.enable_jit_fused, + enable_sequence_parallelism=enable_sequence_parallelism, + enable_sequence_overlap=enable_sequence_overlap) + self.amp_config = dict( + initial_scale=initial_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + min_scale=min_scale, + max_scale=max_scale, + ) + + self.ddp_config = dict(broadcast_buffers=broadcast_buffers, + bucket_cap_mb=ddp_bucket_cap_mb, + find_unused_parameters=find_unused_parameters, + check_reduction=check_reduction, + gradient_as_bucket_view=gradient_as_bucket_view, + static_graph=static_graph) + + self.zero_config = dict(reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + cpu_offload=cpu_offload, + partition_grad=(self.zero_stage == 2)) + + self.max_norm = max_norm + + @property + def enable_pipeline_parallelism(self) -> bool: + return self.pp_size > 1 + + def supported_devices(self) -> List[str]: + return ['cuda'] + + def supported_precisions(self) -> List[str]: + return ['fp16', 'bf16', 'fp32'] + + def control_device(self) -> bool: + return True + + def control_precision(self) -> bool: + return True + + def support_no_sync(self) -> bool: + return False + + def control_checkpoint_io(self) -> bool: + return True + + def configure( + self, + model: Module, + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + dataloader: Optional[DataLoader] = None, + lr_scheduler: Optional[LRScheduler] = None, + ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: + param_info = get_param_info(optimizer) + if not isinstance(model, ModelWrapper): + use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 + model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp, + self.ddp_config, self.custom_policy) + if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): + if self.zero_stage == 0: + if self.precision in ['fp16', 'bf16']: + optimizer = HybridParallelAMPOptimizer(optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + precision=self.precision, + max_norm=self.max_norm, + **self.amp_config) + self.checkpoint_io.link_master_and_working_param(optimizer.working_to_master_map, + optimizer.master_to_working_map) + else: + optimizer = HybridParallelNaiveOptimizer(optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info) + else: + assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1." + assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO." 
+ optimizer = HybridParallelZeroOptimizer(optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + dp_process_group=self.dp_group, + tp_process_group=self.tp_group, + verbose=True, + clip_grad_norm=self.max_norm, + **self.zero_config, + **self.amp_config) + self.checkpoint_io.link_master_and_working_param(optimizer._param_store.working_to_master_param, + optimizer._param_store.master_to_working_param) + + return model, optimizer, criterion, dataloader, lr_scheduler + + def execute_pipeline(self, + data_iter: Iterator, + model: HybridParallelModule, + criterion: Callable[[Any, Any], torch.Tensor], + optimizer: Optional[Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer, + HybridParallelZeroOptimizer]] = None, + return_loss: bool = True, + return_outputs: bool = False) -> dict: + assert self.enable_pipeline_parallelism, 'pipeline parallelism is not enabled' + # return loss or outputs if needed + ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() + with ctx: + outputs = self.schedule.forward_backward_step(model, data_iter, criterion, optimizer, return_loss, + return_outputs) + model.sync_shared_params() + if isinstance(optimizer, HybridParallelZeroOptimizer): + optimizer.sync_grad() + else: + model.sync_grads() + return outputs + + def prepare_dataloader(self, + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + **kwargs): + r""" + Prepare a dataloader for distributed training. The dataloader will be wrapped by + `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. + + + Args: + dataset (`torch.utils.data.Dataset`): The dataset to be loaded. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. + seed (int, optional): Random worker seed for sampling, defaults to 1024. + add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. + num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. + kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in + `DataLoader `_. + + Returns: + :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. 
+ """ + _kwargs = kwargs.copy() + sampler = DistributedSampler(dataset, + num_replicas=self.pg_mesh.size(DP_AXIS), + rank=self.pg_mesh.coordinate(DP_AXIS), + shuffle=shuffle) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader(dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs) + + def get_checkpoint_io(self) -> CheckpointIO: + self.checkpoint_io = HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) + return self.checkpoint_io + + def no_sync(self, model: Module) -> Iterator[None]: + raise NotImplementedError From 0c2b3ef2b1b548934c3aebf61e6f01dc77258eee Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 18 Sep 2023 11:21:41 +0800 Subject: [PATCH 07/16] finish pp --- .../plugin/moe_hybrid_parallel_plugin.py | 218 ++---------------- colossalai/moe/manager.py | 84 +++++-- colossalai/tensor/moe_tensor/api.py | 10 +- colossalai/tensor/moe_tensor/moe_info.py | 15 +- examples/language/openmoe/train.py | 27 ++- 5 files changed, 110 insertions(+), 244 deletions(-) diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index d65bd437962e..04ec0ce57cef 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -1,226 +1,38 @@ import random -from contextlib import nullcontext -from functools import partial -from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union +from typing import Any, Callable, Iterator, List, Optional, Tuple, Union import numpy as np import torch import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.nn import Module, SyncBatchNorm -from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn import Module from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler -from torch.utils._pytree import tree_map from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer +from colossalai.booster.plugin.hybrid_parallel_plugin import ( + HybridParallelAMPOptimizer, + HybridParallelModule, + HybridParallelNaiveOptimizer, + HybridParallelZeroOptimizer, + get_param_info, +) from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO from colossalai.cluster import ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer import ShardConfig from colossalai.shardformer.policies.base_policy import Policy -from colossalai.zero.low_level import LowLevelZeroOptimizer from .pp_plugin_base import PipelinePluginBase -DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 - - -def _convert_floating_point(x, dtype: torch.dtype = torch.float16): - if isinstance(x, torch.Tensor) and torch.is_floating_point(x): - return x.to(dtype) - return x - - -class HybridParallelModule(ModelWrapper): - - def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: 
ProcessGroup, use_ddp: bool, - ddp_config: dict, custom_policy: Policy) -> None: - - self.stage_manager = shard_config.pipeline_stage_manager - self.dp_group = dp_group - - shardformer = ShardFormer(shard_config) - if custom_policy is not None: - assert isinstance(custom_policy, object) - module, self.shared_params = shardformer.optimize(module, policy=custom_policy) - - # setting process groups for shared parameters - self.shared_param_process_groups = [] - for shared_param in self.shared_params: - if len(shared_param) > 0: - self.shared_param_process_groups.append( - self.stage_manager.init_process_group_by_stages(list(shared_param.keys()))) - - # setting mixed_precision - self.mixed_precision = None - if precision == 'fp16': - self.mixed_precision = torch.float16 - elif precision == 'bf16': - self.mixed_precision = torch.bfloat16 - if self.mixed_precision is not None: - module = module.to(self.mixed_precision) - module = module.cuda() - - # setting input type cast when using mixed precision - self.convert_fn = None - if self.mixed_precision is not None: - self.convert_fn = partial(_convert_floating_point, dtype=self.mixed_precision) - - # setting ddp configs - if use_ddp: - # convert model to sync bn - module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group) - # wrap the model with PyTorch DDP - module = DDP(module, process_group=dp_group, **ddp_config) - - super().__init__(module) - - def sync_shared_params(self): - for shared_param, group in zip(self.shared_params, self.shared_param_process_groups): - if self.stage_manager.stage in shared_param: - param = shared_param[self.stage_manager.stage] - dist.all_reduce(param.grad, group=group) - dist.barrier() - - def no_sync(self) -> Iterator[None]: - # no sync grads across data parallel - return nullcontext() - - def sync_grads(self): - # sync grad across data parallel - if self.dp_group.size() == 1: - return - for p in self.module.parameters(): - if p.grad is not None: - dist.all_reduce(p.grad, group=self.dp_group) - p.grad.div_(self.dp_group.size()) - - def forward(self, *args, **kwargs): - if self.convert_fn is not None: - args = tree_map(self.convert_fn, args) - kwargs = tree_map(self.convert_fn, kwargs) - return super().forward(*args, **kwargs) - - def unwrap(self): - module = super().unwrap() - if isinstance(module, DDP): - module = module.module - return module - - -def get_param_info(optim: Optimizer): - # Get a backup of necessary information of parameters for future use, which includes: - # 1. A complete param_group, with params in the form of param_id - # 2. A mapping from param address (obtained using id(param)) to integer param_id - # 3. A mapping from integer param_id to param address. - # 4. A mapping from param_address (obtained using id(param)) to the original shape of parameter before sharding. - # When Zero is used, the params here are fp16/bf16 model params rather than fp32 master params in optimizer. 
- - if optim is None: - return {} - param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}, 'param2shape': {}} - start_index = 0 - for group in optim.param_groups: - - packed_group = {k: v for k, v in group.items() if k != 'params'} - packed_group['params'] = [] - - for param_id, param in enumerate(group['params'], start_index): - original_shape = param.shape if isinstance(param, torch.Tensor) else None - packed_group['params'].append(param_id) - param_info['param2id'][id(param)] = param_id - param_info['id2param'][param_id] = id(param) - param_info['param2shape'][id(param)] = original_shape - - param_info['param_groups'].append(packed_group) - start_index += len(group['params']) - - return param_info - - -def init_pipeline_optimizer(optim: Optimizer, model: Module): - model_params = set(model.parameters()) - new_param_groups = [] - for group in optim.param_groups: - params = [p for p in group['params'] if p in model_params] - new_param_groups.append({**group, 'params': params}) - optim.__setstate__({'param_groups': new_param_groups}) - - -class HybridParallelNaiveOptimizer(OptimizerWrapper): - - def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict): - self.param_info = param_info - if use_pipeline: - init_pipeline_optimizer(optim, model) - super().__init__(optim) - - -class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): +PP_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2 - def __init__(self, - optim: Optimizer, - model: Module, - use_pipeline: bool, - param_info: OrderedDict, - precision: str = 'fp16', - initial_scale: float = 2**16, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32, - max_norm: float = 0): - self.param_info = param_info - if use_pipeline: - init_pipeline_optimizer(optim, model) - super().__init__(optim, precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, - hysteresis, max_scale, max_norm) - - -class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): - - def __init__( - self, - optimizer: Optimizer, - model: Module, - use_pipeline: bool, - param_info: OrderedDict, - initial_scale: int = 2**16, # grad scaler config - min_scale: int = 1, - growth_factor: float = 2., - backoff_factor: float = .5, - growth_interval: int = 2000, - hysteresis: int = 2, - max_scale: int = 2**24, - clip_grad_norm: float = 0.0, # grad clipping - verbose: bool = False, - reduce_bucket_size: int = 1024 * 1024, # communication - communication_dtype: Optional[torch.dtype] = None, - overlap_communication: bool = True, - partition_grad: bool = False, # stage 2 flag - cpu_offload: bool = False, # cpu offload - dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm - tp_process_group: Optional[ProcessGroup] = None, # if using tp - forced_dtype: Optional[torch.dtype] = None): - self.param_info = param_info - if use_pipeline: - init_pipeline_optimizer(optimizer, model) - super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, - hysteresis, max_scale, clip_grad_norm, verbose, reduce_bucket_size, communication_dtype, - overlap_communication, partition_grad, cpu_offload, dp_process_group, tp_process_group, - forced_dtype) - - -class HybridParallelPlugin(PipelinePluginBase): + +class MoeHybridParallelPlugin(PipelinePluginBase): """ - Plugin for Hybrid Parallel Training. + Plugin for Moe Hybrid Parallel Training. 
Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin. The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size). @@ -327,7 +139,7 @@ def __init__(self, self.enable_flash_attention = enable_flash_attention self.enable_jit_fused = enable_jit_fused self.enable_sequence_parallelism = enable_sequence_parallelism - self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size) + self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size) self.stage_manager = None self.schedule = None self.custom_policy = custom_policy diff --git a/colossalai/moe/manager.py b/colossalai/moe/manager.py index 3dc27c6cb0f0..30f191a1de91 100644 --- a/colossalai/moe/manager.py +++ b/colossalai/moe/manager.py @@ -24,6 +24,7 @@ def __init__(self): self.router_z_loss = [] self.parallel = None self.seed = None + self.mode = None self.use_kernel_optim = True self.has_setup = False @@ -37,16 +38,50 @@ def parallel_info_dict(self): def is_initialized(self): return self.has_setup - def setup(self, seed: int, use_kernel_optim: bool = True, max_ep_size: int = 8, parallel: bool = None): + def setup(self, + seed: int, + use_kernel_optim: bool = True, + parallel: bool = None, + mode: str = "dynamic", + max_ep_size: int = 8, + fixed_dp_size: int = 0, + fixed_ep_size: int = 0, + fixed_pp_size: int = 0) -> None: + """ + Setup MoE distributed context. + + Args: + seed (int): Random seed. Defaults to 42. + use_kernel_optim (bool, optional): Use cuda kernel. Defaults to True. + parallel (bool, optional): Parallel mode, should be EP, TP or None. Defaults to None. + mode (str, optional): Should be "fixed" or "dynamic". Defaults to "dynamic". + In fixed mode, the ep size and dp size is fixed. + In dynamic mode, the ep size and dp size will be changed according to num experts. + max_ep_size (int, optional): Max ep size in dynamic mode. Defaults to 8. + fixed_dp_size (int, optional): Fixed dp size in fixed mode. Defaults to 0. + fixed_ep_size (int, optional): Fixed ep size in fixed mode. Defaults to 0. + fixed_pp_size (int, optional): Fixed pp size in fixed mode. Defaults to 0. 
+        """
         assert not self.is_initialized, "MoE distributed context shouldn't be set up again"
         assert torch.cuda.is_available(), "MoE requires to enable CUDA first"
 
         self.world_size = dist.get_world_size()
         self.seed = seed + dist.get_rank()
-        self.max_ep_size = min(max_ep_size, dist.get_world_size())
-        self.min_dp_size = self.world_size // self.max_ep_size
         self.parallel = parallel
 
+        # init by mode
+        assert mode in ["fixed", "dynamic"], "mode should be fixed or dynamic"
+        if mode == "dynamic":
+            self.max_ep_size = min(max_ep_size, dist.get_world_size())
+            self.min_dp_size = self.world_size // self.max_ep_size
+        else:
+            assert fixed_dp_size > 0 and fixed_ep_size > 0 and fixed_pp_size > 0, "dp_size, ep_size and pp_size should be greater than 0"
+            assert isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int) and isinstance(
+                fixed_pp_size, int), "dp_size, ep_size and pp_size should be int"
+            self.ep_size = fixed_ep_size
+            self.dp_size = fixed_dp_size
+            self.pp_size = fixed_pp_size
+
         # Enabling kernel optimization may raise error in some cases
         # Users can close kernel optimization manually
         self.use_kernel_optim = use_kernel_optim
@@ -67,30 +102,39 @@ def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoePara
             number of local experts, the MoeParallelInfo of the current ep_size
         """
-        gt_flag = num_experts % self.max_ep_size == 0    # check whether num_experts is greater
-        lt_flag = self.max_ep_size % num_experts == 0    # check whether num_experts is less
-
-        assert gt_flag or lt_flag, "Automatic experts placement dose not not support expert number" \
-            " is not a multiple of ep size or vice versa."
-
-        # If the number of experts is greater than maximum expert parallel size. a.k.a ep_size,
-        # there are multiple experts in each GPU and each GPU has different experts
-        # So it's data parallel size is 1
-        # Otherwise, there is only one expert in each GPU
-        # The data parallel size should be calculated
-        dp_size = 1 if gt_flag else self.max_ep_size // num_experts
-        ep_size = self.max_ep_size // dp_size
+        if self.mode == "dynamic":
+            gt_flag = num_experts % self.max_ep_size == 0    # check whether num_experts is greater
+            lt_flag = self.max_ep_size % num_experts == 0    # check whether num_experts is less
+
+            assert gt_flag or lt_flag, "Automatic experts placement does not support an expert number" \
+                " that is not a multiple of ep size or vice versa."
+
+            # If the number of experts is greater than maximum expert parallel size.
a.k.a ep_size, + # there are multiple experts in each GPU and each GPU has different experts + # So it's data parallel size is 1 + # Otherwise, there is only one expert in each GPU + # The data parallel size should be calculated + dp_size = 1 if gt_flag else self.max_ep_size // num_experts + ep_size = self.max_ep_size // dp_size + # Don't forget to multiply minimum data parallel size + dp_size *= self.min_dp_size + pp_size = None + else: + dp_size = self.dp_size + ep_size = self.ep_size + pp_size = self.pp_size # Calculate the number of experts for each GPU if use_tp: num_local_experts = num_experts else: - num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size + if self.mode == "dynamic": + num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size + else: + num_local_experts = num_experts // ep_size - # Don't forget to multiply minimum data parallel size - dp_size *= self.min_dp_size if not (ep_size in self.parallel_info_dict): - self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size) + self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size, pp_size) return num_local_experts, self.parallel_info_dict[ep_size] diff --git a/colossalai/tensor/moe_tensor/api.py b/colossalai/tensor/moe_tensor/api.py index 442b3c0f4958..fc4ed14e0ef7 100644 --- a/colossalai/tensor/moe_tensor/api.py +++ b/colossalai/tensor/moe_tensor/api.py @@ -28,20 +28,22 @@ def set_moe_tensor_info(tensor: torch.Tensor, moe_info: MoeParallelInfo) -> None moe_info (dict): The moe info to be set. """ - tensor.__setattr__('moe_info', moe_info) + tensor.__setattr__("moe_info", moe_info) -def get_moe_info(ep_size: int, dp_size: int) -> MoeParallelInfo: +def get_moe_info(ep_size: int, dp_size: int, pp_size: int) -> MoeParallelInfo: """ Get moe info for the given tensor. Args: - tensor (torch.Tensor): The tensor to be checked. + ep_size (int): The expert parallel size. + dp_size (int): The data parallel size. + pp_size (int): The pipeline parallel size. Returns: dict: The moe info of the given tensor. """ - return MoeParallelInfo(ep_size, dp_size) + return MoeParallelInfo(ep_size, dp_size, pp_size) def get_ep_group(tensor: torch.Tensor) -> ProcessGroup: diff --git a/colossalai/tensor/moe_tensor/moe_info.py b/colossalai/tensor/moe_tensor/moe_info.py index ca7f163b9c24..2d3c2efbfb31 100644 --- a/colossalai/tensor/moe_tensor/moe_info.py +++ b/colossalai/tensor/moe_tensor/moe_info.py @@ -2,15 +2,14 @@ class MoeParallelInfo: - """Moe parallelism information, storing parallel sizes and groups. 
- """ + """Moe parallelism information, storing parallel sizes and groups.""" + + def __init__(self, ep_size: int, dp_size: int, pp_size: int = 1): + self.pp_axis, self.dp_axis, self.ep_axis = 0, 1, 2 + self.pp_size, self.dp_size, self.ep_size = pp_size, dp_size, ep_size + + self.pg = ProcessGroupMesh(self.pp_size, self.dp_size, self.ep_size) - def __init__(self, ep_size: int, dp_size: int): - self.dp_axis = 0 - self.dp_size = dp_size - self.ep_axis = 1 - self.ep_size = ep_size - self.pg = ProcessGroupMesh(self.dp_size, self.ep_size) self.ep_group = self.pg.get_group_along_axis(self.ep_axis) self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group) self.dp_group = self.pg.get_group_along_axis(self.dp_axis) diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index 6351d26ca0a1..1bc19d3d726b 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -14,7 +14,8 @@ import colossalai from colossalai import get_default_parser from colossalai.booster import Booster -from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin +from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.cluster import DistCoordinator from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.moe import MoeCheckpintIO @@ -82,8 +83,9 @@ def parse_args(): help="parallel plugin", choices=["zero1", "zero2", "hybrid"]) # hybrid plugin - parser.add_argument("--tp_size", type=int, default=1, help="tp size") parser.add_argument("--pp_size", type=int, default=2, help="pp size") + parser.add_argument("--dp_size", type=int, default=1, help="dp size") + parser.add_argument("--ep_size", type=int, default=2, help="ep size") parser.add_argument("--zero_stage", type=int, default=1, help="zero stage in hybrid plugin") parser.add_argument("--microbatch_size", type=int, default=1, help="microbatch size") # loss @@ -107,7 +109,14 @@ def main(): coordinator = DistCoordinator() # Set up moe - MOE_MANAGER.setup(seed=42, parallel=None) + assert args.dp_size * args.ep_size * args.pp_size == coordinator.world_size, "dp_size * ep_size * pp_size must equal to world_size" + # MOE_MANAGER.setup(seed=42, parallel=None) + MOE_MANAGER.setup(seed=42, + parallel="EP", + mode="fixed", + fixed_dp_size=args.dp_size, + fixed_ep_size=args.ep_size, + fixed_pp_size=args.pp_size) # Manage loggers disable_existing_loggers() @@ -146,11 +155,11 @@ def main(): elif args.plugin == "zero2": plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2) elif args.plugin == "hybrid": - plugin = HybridParallelPlugin(tp_size=args.tp_size, - pp_size=args.pp_size, - zero_stage=args.zero_stage, - microbatch_size=args.microbatch_size, - custom_policy=OpenMoeForCausalLMPolicy()) + plugin = MoeHybridParallelPlugin(tp_size=1, + pp_size=args.pp_size, + zero_stage=args.zero_stage, + microbatch_size=args.microbatch_size, + custom_policy=OpenMoeForCausalLMPolicy()) else: raise ValueError(f"Invalid plugin {args.plugin}") logger.info(f"Set plugin as {plugin}", ranks=[0]) @@ -166,7 +175,7 @@ def main(): # Set booster booster = Booster(plugin=plugin, **booster_kwargs) model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader) - use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 + use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1 
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() logger.info(f"Finish init booster", ranks=[0]) From 58d99f08e9882031f94cd0088a68d36a048537b6 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 18 Sep 2023 11:31:27 +0800 Subject: [PATCH 08/16] update setup for different plugin --- examples/language/openmoe/train.py | 126 ++++++++++++++++++----------- 1 file changed, 78 insertions(+), 48 deletions(-) diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index 1bc19d3d726b..ab72659eff27 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -54,34 +54,42 @@ def __len__(self): def __getitem__(self, idx): return { - 'input_ids': self.input_ids[idx], - 'attention_mask': self.attention_mask[idx], - 'labels': self.input_ids[idx] + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "labels": self.input_ids[idx], } def parse_args(): # basic settings parser = get_default_parser() - parser.add_argument("--model_name", - type=str, - default="base", - help="Path to pretrained model or model identifier from huggingface.co/models.") - parser.add_argument("--output_path", - type=str, - default="./output_model.bin", - help="The path of your saved model after finetuning.") + parser.add_argument( + "--model_name", + type=str, + default="base", + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--output_path", + type=str, + default="./output_model.bin", + help="The path of your saved model after finetuning.", + ) parser.add_argument("--num_epoch", type=int, default=10, help="Number of epochs.") - parser.add_argument("--batch_size", - type=int, - default=4, - help="Batch size (per dp group) for the training dataloader.") + parser.add_argument( + "--batch_size", + type=int, + default=4, + help="Batch size (per dp group) for the training dataloader.", + ) parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") - parser.add_argument("--plugin", - type=str, - default="hybrid", - help="parallel plugin", - choices=["zero1", "zero2", "hybrid"]) + parser.add_argument( + "--plugin", + type=str, + default="hybrid", + help="parallel plugin", + choices=["zero1", "zero2", "hybrid"], + ) # hybrid plugin parser.add_argument("--pp_size", type=int, default=2, help="pp size") parser.add_argument("--dp_size", type=int, default=1, help="dp size") @@ -89,8 +97,18 @@ def parse_args(): parser.add_argument("--zero_stage", type=int, default=1, help="zero stage in hybrid plugin") parser.add_argument("--microbatch_size", type=int, default=1, help="microbatch size") # loss - parser.add_argument("--router_aux_loss_factor", type=float, default=0.01, help="router_aux_loss_factor.") - parser.add_argument("--router_z_loss_factor", type=float, default=0.0001, help="router_z_loss_factor.") + parser.add_argument( + "--router_aux_loss_factor", + type=float, + default=0.01, + help="router_aux_loss_factor.", + ) + parser.add_argument( + "--router_z_loss_factor", + type=float, + default=0.0001, + help="router_z_loss_factor.", + ) parser.add_argument("--label_smoothing", type=float, default=0.0, help="label_smoothing.") parser.add_argument("--z_loss_factor", type=float, default=0.0001, help="z_loss_factor.") # optim @@ -109,14 +127,19 @@ def main(): coordinator = DistCoordinator() # Set up moe - assert args.dp_size * args.ep_size * args.pp_size == coordinator.world_size, "dp_size * ep_size * pp_size must equal to world_size" - 
# MOE_MANAGER.setup(seed=42, parallel=None) - MOE_MANAGER.setup(seed=42, - parallel="EP", - mode="fixed", - fixed_dp_size=args.dp_size, - fixed_ep_size=args.ep_size, - fixed_pp_size=args.pp_size) + assert (args.dp_size * args.ep_size * + args.pp_size == coordinator.world_size), "dp_size * ep_size * pp_size must equal to world_size" + if args.plugin in ["zero1", "zero2"]: + MOE_MANAGER.setup(seed=42, parallel="EP") + elif args.plugin == "hybrid": + MOE_MANAGER.setup( + seed=42, + parallel="EP", + mode="fixed", + fixed_dp_size=args.dp_size, + fixed_ep_size=args.ep_size, + fixed_pp_size=args.pp_size, + ) # Manage loggers disable_existing_loggers() @@ -155,11 +178,13 @@ def main(): elif args.plugin == "zero2": plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2) elif args.plugin == "hybrid": - plugin = MoeHybridParallelPlugin(tp_size=1, - pp_size=args.pp_size, - zero_stage=args.zero_stage, - microbatch_size=args.microbatch_size, - custom_policy=OpenMoeForCausalLMPolicy()) + plugin = MoeHybridParallelPlugin( + tp_size=1, + pp_size=args.pp_size, + zero_stage=args.zero_stage, + microbatch_size=args.microbatch_size, + custom_policy=OpenMoeForCausalLMPolicy(), + ) else: raise ValueError(f"Invalid plugin {args.plugin}") logger.info(f"Set plugin as {plugin}", ranks=[0]) @@ -175,7 +200,7 @@ def main(): # Set booster booster = Booster(plugin=plugin, **booster_kwargs) model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader) - use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1 + use_pipeline = (isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1) is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() logger.info(f"Finish init booster", ranks=[0]) @@ -185,29 +210,34 @@ def main(): model.train() train_dataloader_iter = iter(dataloader) total_len = len(train_dataloader_iter) - with tqdm(range(total_len), desc=f'Epoch [{epoch + 1}/{args.num_epoch}]', - disable=not coordinator.is_master()) as pbar: + with tqdm( + range(total_len), + desc=f"Epoch [{epoch + 1}/{args.num_epoch}]", + disable=not coordinator.is_master(), + ) as pbar: # Forward pass for _ in pbar: if use_pipeline: - outputs = booster.execute_pipeline(train_dataloader_iter, - model, - lambda x, y: x.loss, - optimizer, - return_loss=True, - return_outputs=True) + outputs = booster.execute_pipeline( + train_dataloader_iter, + model, + lambda x, y: x.loss, + optimizer, + return_loss=True, + return_outputs=True, + ) # Backward and optimize if is_pp_last_stage: - loss = outputs['loss'] - pbar.set_postfix({'loss': loss.item()}) + loss = outputs["loss"] + pbar.set_postfix({"loss": loss.item()}) else: data = next(train_dataloader_iter) data = move_to_cuda(data, torch.cuda.current_device()) outputs = model(**data) - loss = outputs['loss'] + loss = outputs["loss"] # Backward booster.backward(loss, optimizer) - pbar.set_postfix({'loss': loss.item()}) + pbar.set_postfix({"loss": loss.item()}) optimizer.step() optimizer.zero_grad() From e67b3aec08b2ba87824f3f500ccdbadba44adf64 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 18 Sep 2023 12:33:33 +0800 Subject: [PATCH 09/16] update ci --- colossalai/moe/manager.py | 7 ++++--- examples/language/openmoe/test_ci.sh | 2 +- examples/language/openmoe/train.py | 4 ++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/colossalai/moe/manager.py b/colossalai/moe/manager.py index 30f191a1de91..0f2964cb1076 100644 --- a/colossalai/moe/manager.py +++ 
b/colossalai/moe/manager.py @@ -70,8 +70,9 @@ def setup(self, self.parallel = parallel # init by mode - assert mode in ["fixed", "dynamic"], "mode should be fixed or dynamic" - if mode == "dynamic": + self.mode = mode + assert self.mode in ["fixed", "dynamic"], "mode should be fixed or dynamic" + if self.mode == "dynamic": self.max_ep_size = min(max_ep_size, dist.get_world_size()) self.min_dp_size = self.world_size // self.max_ep_size else: @@ -118,7 +119,7 @@ def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoePara ep_size = self.max_ep_size // dp_size # Don't forget to multiply minimum data parallel size dp_size *= self.min_dp_size - pp_size = None + pp_size = 1 else: dp_size = self.dp_size ep_size = self.ep_size diff --git a/examples/language/openmoe/test_ci.sh b/examples/language/openmoe/test_ci.sh index 75eee902c747..8361b66c50d1 100644 --- a/examples/language/openmoe/test_ci.sh +++ b/examples/language/openmoe/test_ci.sh @@ -2,4 +2,4 @@ set -xe pip install -r requirements.txt python infer.py --model "test" -torchrun --standalone --nproc_per_node 2 train.py --model_name "test" --batch_size 1 --num_epoch 20 +torchrun --standalone --nproc_per_node 2 train.py --model_name "test" --batch_size 1 --num_epoch 1 --plugin zero2 diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index ab72659eff27..ad20920eb24b 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -127,11 +127,11 @@ def main(): coordinator = DistCoordinator() # Set up moe - assert (args.dp_size * args.ep_size * - args.pp_size == coordinator.world_size), "dp_size * ep_size * pp_size must equal to world_size" if args.plugin in ["zero1", "zero2"]: MOE_MANAGER.setup(seed=42, parallel="EP") elif args.plugin == "hybrid": + assert (args.dp_size * args.ep_size * + args.pp_size == coordinator.world_size), "dp_size * ep_size * pp_size must equal to world_size" MOE_MANAGER.setup( seed=42, parallel="EP", From 89e8f99b1acb3a2c48421aaf13956016740113fc Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 18 Sep 2023 13:58:50 +0800 Subject: [PATCH 10/16] update ci --- colossalai/moe/manager.py | 2 +- examples/language/openmoe/test_ci.sh | 1 + examples/language/openmoe/train.py | 7 ++++++- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/colossalai/moe/manager.py b/colossalai/moe/manager.py index 0f2964cb1076..d05ee4c25450 100644 --- a/colossalai/moe/manager.py +++ b/colossalai/moe/manager.py @@ -25,7 +25,7 @@ def __init__(self): self.parallel = None self.seed = None self.mode = None - self.use_kernel_optim = True + self.use_kernel_optim = False self.has_setup = False self._parallel_info_dict = dict() diff --git a/examples/language/openmoe/test_ci.sh b/examples/language/openmoe/test_ci.sh index 8361b66c50d1..c555c3e5b116 100644 --- a/examples/language/openmoe/test_ci.sh +++ b/examples/language/openmoe/test_ci.sh @@ -3,3 +3,4 @@ pip install -r requirements.txt python infer.py --model "test" torchrun --standalone --nproc_per_node 2 train.py --model_name "test" --batch_size 1 --num_epoch 1 --plugin zero2 +torchrun --standalone --nproc_per_node 4 train.py --model_name "test" --batch_size 1 --num_epoch 1 --plugin hybrid --pp_size 2 --dp_size 1 --ep_size 2 diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index ad20920eb24b..efb61d1bb69c 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -128,7 +128,11 @@ def main(): # Set up moe if args.plugin in ["zero1", "zero2"]: - 
MOE_MANAGER.setup(seed=42, parallel="EP") + MOE_MANAGER.setup( + seed=42, + parallel="EP", + use_kernel_optim=True if args.model_name != "test" else False, + ) elif args.plugin == "hybrid": assert (args.dp_size * args.ep_size * args.pp_size == coordinator.world_size), "dp_size * ep_size * pp_size must equal to world_size" @@ -139,6 +143,7 @@ def main(): fixed_dp_size=args.dp_size, fixed_ep_size=args.ep_size, fixed_pp_size=args.pp_size, + use_kernel_optim=True if args.model_name != "test" else False, ) # Manage loggers From fae3c50363e27aace033535bc781e3b357cca712 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 18 Sep 2023 14:00:29 +0800 Subject: [PATCH 11/16] update ci --- examples/language/openmoe/test_ci.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/language/openmoe/test_ci.sh b/examples/language/openmoe/test_ci.sh index c555c3e5b116..8361b66c50d1 100644 --- a/examples/language/openmoe/test_ci.sh +++ b/examples/language/openmoe/test_ci.sh @@ -3,4 +3,3 @@ pip install -r requirements.txt python infer.py --model "test" torchrun --standalone --nproc_per_node 2 train.py --model_name "test" --batch_size 1 --num_epoch 1 --plugin zero2 -torchrun --standalone --nproc_per_node 4 train.py --model_name "test" --batch_size 1 --num_epoch 1 --plugin hybrid --pp_size 2 --dp_size 1 --ep_size 2 From 4abb220f5829b1db0c9948c5de794f324de967bd Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 18 Sep 2023 14:26:43 +0800 Subject: [PATCH 12/16] support ep inside or dp inside --- colossalai/moe/manager.py | 8 ++++++-- colossalai/tensor/moe_tensor/api.py | 5 +++-- colossalai/tensor/moe_tensor/moe_info.py | 20 ++++++++++++++++---- 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/colossalai/moe/manager.py b/colossalai/moe/manager.py index d05ee4c25450..e61fb0bf9582 100644 --- a/colossalai/moe/manager.py +++ b/colossalai/moe/manager.py @@ -26,6 +26,7 @@ def __init__(self): self.seed = None self.mode = None self.use_kernel_optim = False + self.use_ep_inside = None self.has_setup = False self._parallel_info_dict = dict() @@ -46,7 +47,8 @@ def setup(self, max_ep_size: int = 8, fixed_dp_size: int = 0, fixed_ep_size: int = 0, - fixed_pp_size: int = 0) -> None: + fixed_pp_size: int = 0, + use_ep_inside: bool = True) -> None: """ Setup MoE distributed context. @@ -61,6 +63,7 @@ def setup(self, fixed_dp_size (int, optional): Fixed dp size in fixed mode. Defaults to 0. fixed_ep_size (int, optional): Fixed ep size in fixed mode. Defaults to 0. fixed_pp_size (int, optional): Fixed pp size in fixed mode. Defaults to 0. + use_ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if Fasle. Defaults to True. 
""" assert not self.is_initialized, "MoE distributed context shouldn't be set up again" assert torch.cuda.is_available(), "MoE requires to enable CUDA first" @@ -68,6 +71,7 @@ def setup(self, self.world_size = dist.get_world_size() self.seed = seed + dist.get_rank() self.parallel = parallel + self.use_ep_inside = use_ep_inside # init by mode self.mode = mode @@ -135,7 +139,7 @@ def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoePara num_local_experts = num_experts // ep_size if not (ep_size in self.parallel_info_dict): - self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size, pp_size) + self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size, pp_size, ep_inside=self.use_ep_inside) return num_local_experts, self.parallel_info_dict[ep_size] diff --git a/colossalai/tensor/moe_tensor/api.py b/colossalai/tensor/moe_tensor/api.py index fc4ed14e0ef7..9120a40b8533 100644 --- a/colossalai/tensor/moe_tensor/api.py +++ b/colossalai/tensor/moe_tensor/api.py @@ -31,7 +31,7 @@ def set_moe_tensor_info(tensor: torch.Tensor, moe_info: MoeParallelInfo) -> None tensor.__setattr__("moe_info", moe_info) -def get_moe_info(ep_size: int, dp_size: int, pp_size: int) -> MoeParallelInfo: +def get_moe_info(ep_size: int, dp_size: int, pp_size: int, ep_inside: bool) -> MoeParallelInfo: """ Get moe info for the given tensor. @@ -39,11 +39,12 @@ def get_moe_info(ep_size: int, dp_size: int, pp_size: int) -> MoeParallelInfo: ep_size (int): The expert parallel size. dp_size (int): The data parallel size. pp_size (int): The pipeline parallel size. + ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if Fasle. Returns: dict: The moe info of the given tensor. """ - return MoeParallelInfo(ep_size, dp_size, pp_size) + return MoeParallelInfo(ep_inside, ep_size, dp_size, pp_size) def get_ep_group(tensor: torch.Tensor) -> ProcessGroup: diff --git a/colossalai/tensor/moe_tensor/moe_info.py b/colossalai/tensor/moe_tensor/moe_info.py index 2d3c2efbfb31..5097ac1044e7 100644 --- a/colossalai/tensor/moe_tensor/moe_info.py +++ b/colossalai/tensor/moe_tensor/moe_info.py @@ -4,11 +4,23 @@ class MoeParallelInfo: """Moe parallelism information, storing parallel sizes and groups.""" - def __init__(self, ep_size: int, dp_size: int, pp_size: int = 1): - self.pp_axis, self.dp_axis, self.ep_axis = 0, 1, 2 - self.pp_size, self.dp_size, self.ep_size = pp_size, dp_size, ep_size + def __init__(self, ep_inside: bool, ep_size: int, dp_size: int, pp_size: int = 1): + """ + init MoeParallelInfo with ep_size, dp_size and pp_size - self.pg = ProcessGroupMesh(self.pp_size, self.dp_size, self.ep_size) + Args: + ep_size (int): expert parallel size + dp_size (int): data parallel (zero) size + pp_size (int, optional): pipeline parallel size. Defaults to 1. + ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if Fasle. Defaults to True. 
+ """ + self.pp_size, self.dp_size, self.ep_size = pp_size, dp_size, ep_size + if ep_inside: + self.pp_axis, self.dp_axis, self.ep_axis = 0, 1, 2 + self.pg = ProcessGroupMesh(self.pp_size, self.dp_size, self.ep_size) + else: + self.pp_axis, self.ep_axis, self.dp_axis = 0, 1, 2 + self.pg = ProcessGroupMesh(self.pp_size, self.ep_size, self.dp_size) self.ep_group = self.pg.get_group_along_axis(self.ep_axis) self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group) From ac98ee64ffa6850d320f9389b14a3b5b87f1069d Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 18 Sep 2023 15:50:32 +0800 Subject: [PATCH 13/16] update arg for kernel --- .../openmoe/model/modeling_openmoe.py | 3 +- .../language/openmoe/model/openmoe_policy.py | 209 +++++++++++------- examples/language/openmoe/train.py | 12 +- 3 files changed, 135 insertions(+), 89 deletions(-) diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index d8289b791dd5..90d3e0022ce4 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -175,6 +175,7 @@ def __init__(self, config): self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = SwiGLU + self.use_kernel = True if MOE_MANAGER.use_kernel_optim else False def forward(self, x): if self.pretraining_tp > 1: @@ -190,7 +191,7 @@ def forward(self, x): down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)] down_proj = sum(down_proj) else: - if HAS_TRITON: + if HAS_TRITON and self.use_kernel: down_proj = self.down_proj(LlamaActCombine.apply(self.gate_proj(x), self.up_proj(x))) else: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) diff --git a/examples/language/openmoe/model/openmoe_policy.py b/examples/language/openmoe/model/openmoe_policy.py index 21e25bcb73a0..cc82683cd319 100644 --- a/examples/language/openmoe/model/openmoe_policy.py +++ b/examples/language/openmoe/model/openmoe_policy.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from torch import Tensor from torch.nn import Module -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.utils import logging from colossalai.moe.manager import MOE_MANAGER @@ -17,7 +17,7 @@ from .modeling_openmoe import OpenMoeDecoderLayer, OpenMoeForCausalLM, OpenMoeModel -__all__ = ['OpenMoePolicy', 'OpenMoeForCausalLMPolicy'] +__all__ = ["OpenMoePolicy", "OpenMoeForCausalLMPolicy"] class OpenMoePolicy(Policy): @@ -50,29 +50,34 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: # optimization configuration if self.shard_config.enable_fused_normalization: - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="input_layernorm", - target_module=FusedRMSNorm, - ), - SubModuleReplacementDescription( - suffix="post_attention_layernorm", + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=FusedRMSNorm, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=FusedRMSNorm, + ), + SubModuleReplacementDescription( + suffix="pre_extra_mlp_layernorm", + target_module=FusedRMSNorm, + ignore_if_not_exist=True, + ), + ], + 
policy=policy, + target_key=OpenMoeDecoderLayer, + ) + + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="norm", target_module=FusedRMSNorm, ), - SubModuleReplacementDescription( - suffix="pre_extra_mlp_layernorm", - target_module=FusedRMSNorm, - ) - ], - policy=policy, - target_key=OpenMoeDecoderLayer) - - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="norm", - target_module=FusedRMSNorm, - ), - policy=policy, - target_key=OpenMoeModel) + policy=policy, + target_key=OpenMoeModel, + ) if self.shard_config.enable_flash_attention: raise NotImplementedError("Flash attention has already been replaced in openmoe.") @@ -84,17 +89,17 @@ def postprocess(self): def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: """If under pipeline parallel setting, replacing the original forward method of huggingface - to customized forward method, and add this changing to policy.""" + to customized forward method, and add this changing to policy.""" if self.pipeline_stage_manager: stage_manager = self.pipeline_stage_manager - if self.model.__class__.__name__ == "LlamaModel": + if self.model.__class__.__name__ == "OpenMoeModel": module = self.model else: module = self.model.model layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) @@ -105,7 +110,7 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" assert self.pipeline_stage_manager is not None - if self.model.__class__.__name__ == 'LlamaModel': + if self.model.__class__.__name__ == "LlamaModel": module = self.model else: module = self.model.model @@ -132,9 +137,11 @@ def module_policy(self): policy = super().module_policy() if self.pipeline_stage_manager: # set None as default - self.set_pipeline_forward(model_cls=OpenMoeModel, - new_forward=OpenMoePipelineForwards.openmoe_model_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=OpenMoeModel, + new_forward=OpenMoePipelineForwards.openmoe_model_forward, + policy=policy, + ) return policy def get_held_layers(self) -> List[Module]: @@ -150,7 +157,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: class OpenMoeForCausalLMPolicy(OpenMoePolicy): def module_policy(self): - policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: @@ -159,16 +165,21 @@ def module_policy(self): OpenMoeForCausalLM: ModulePolicyDescription(sub_module_replacement=[ SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) + suffix="lm_head", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True), + ) ]) } policy.update(new_item) if self.pipeline_stage_manager: # set None as default - self.set_pipeline_forward(model_cls=OpenMoeForCausalLM, - new_forward=OpenMoePipelineForwards.llama_for_causal_lm_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=OpenMoeForCausalLM, + new_forward=OpenMoePipelineForwards.llama_for_causal_lm_forward, + policy=policy, + ) return policy @@ -183,21 +194,21 @@ 
def get_held_layers(self) -> List[Module]: def get_shared_params(self) -> List[Dict[int, Tensor]]: llama_model = self.model.model if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: - if id(llama_model.embed_tokens.weight) == id( - self.model.lm_head.weight) and self.pipeline_stage_manager.num_stages > 1: + if (id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight) + and self.pipeline_stage_manager.num_stages > 1): # tie weights return [{ 0: llama_model.embed_tokens.weight, - self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, }] return [] class OpenMoePipelineForwards: - ''' + """ This class serves as a micro library for forward function substitution of Llama models under pipeline setting. - ''' + """ @staticmethod def openmoe_model_forward( @@ -222,12 +233,12 @@ def openmoe_model_forward( logger = logging.get_logger(__name__) - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions) output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) # retrieve input_ids and inputs_embeds if stage_manager.is_first_stage(): @@ -253,13 +264,13 @@ def openmoe_model_forward( # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. 
if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False if past_key_values is not None: @@ -267,10 +278,12 @@ def openmoe_model_forward( seq_length_with_past = seq_length_with_past + past_key_values_length if position_ids is None: - position_ids = torch.arange(past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device) + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() @@ -278,11 +291,17 @@ def openmoe_model_forward( # embed positions, for the first stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), - dtype=torch.bool, - device=hidden_states.device) - attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), hidden_states, - past_key_values_length) + attention_mask = torch.ones( + (batch_size, seq_length_with_past), + dtype=torch.bool, + device=hidden_states.device, + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, + (batch_size, seq_length), + hidden_states, + past_key_values_length, + ) if self.gradient_checkpointing and self.training: if use_cache: @@ -300,7 +319,7 @@ def openmoe_model_forward( if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = past_key_values[idx] if past_key_values is not None else None + past_key_value = (past_key_values[idx] if past_key_values is not None else None) if self.gradient_checkpointing and self.training: @@ -351,9 +370,20 @@ def custom_forward(*inputs): router_z_loss = past_router_z_loss + router_z_loss if stage_manager.is_last_stage(): - return tuple([hidden_states, next_cache, all_hidden_states, all_self_attns, router_aux_loss, router_z_loss]) + return tuple([ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + router_aux_loss, + router_z_loss, + ]) # always return dict for imediate stage - return {'hidden_states': hidden_states, 'router_aux_loss': router_aux_loss, 'router_z_loss': router_z_loss} + return { + "hidden_states": hidden_states, + "router_aux_loss": router_aux_loss, + "router_z_loss": router_z_loss, + } @staticmethod def llama_for_causal_lm_forward( @@ -376,42 +406,42 @@ def llama_for_causal_lm_forward( past_router_z_loss: Optional[torch.FloatTensor] = None, ): r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). 
Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - Returns: + Returns: - Example: + Example: - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - >>> prompt = "Hey, are you consciours? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." - ```""" + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + ```""" logger = logging.get_logger(__name__) - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions) output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. 
if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) @@ -434,7 +464,14 @@ def llama_for_causal_lm_forward( ) if stage_manager.is_last_stage(): - hidden_states, past_key_values, all_hidden_states, attentions, router_aux_loss, router_z_loss = outputs + ( + hidden_states, + past_key_values, + all_hidden_states, + attentions, + router_aux_loss, + router_z_loss, + ) = outputs if self.pretraining_tp > 1: lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0) @@ -498,11 +535,11 @@ def custom_forward(*inputs): attentions=attentions, ) else: - hidden_states = outputs['hidden_states'] - router_aux_loss = outputs['router_aux_loss'] - router_z_loss = outputs['router_z_loss'] + hidden_states = outputs["hidden_states"] + router_aux_loss = outputs["router_aux_loss"] + router_z_loss = outputs["router_z_loss"] return { - 'hidden_states': hidden_states, - 'past_router_aux_loss': router_aux_loss, - 'past_router_z_loss': router_z_loss + "hidden_states": hidden_states, + "past_router_aux_loss": router_aux_loss, + "past_router_z_loss": router_z_loss, } diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index efb61d1bb69c..2099bbde91f5 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -96,6 +96,12 @@ def parse_args(): parser.add_argument("--ep_size", type=int, default=2, help="ep size") parser.add_argument("--zero_stage", type=int, default=1, help="zero stage in hybrid plugin") parser.add_argument("--microbatch_size", type=int, default=1, help="microbatch size") + # kernel + parser.add_argument( + "--use_kernel", + action="store_true", + help="Use kernel optim. 
Need to install flash attention, apex, triton to enable all kernel optimizations.", + ) # loss parser.add_argument( "--router_aux_loss_factor", @@ -131,7 +137,7 @@ def main(): MOE_MANAGER.setup( seed=42, parallel="EP", - use_kernel_optim=True if args.model_name != "test" else False, + use_kernel_optim=False if args.model_name == "test" else args.use_kernel, ) elif args.plugin == "hybrid": assert (args.dp_size * args.ep_size * @@ -143,7 +149,7 @@ def main(): fixed_dp_size=args.dp_size, fixed_ep_size=args.ep_size, fixed_pp_size=args.pp_size, - use_kernel_optim=True if args.model_name != "test" else False, + use_kernel_optim=False if args.model_name == "test" else args.use_kernel, ) # Manage loggers @@ -189,6 +195,8 @@ def main(): zero_stage=args.zero_stage, microbatch_size=args.microbatch_size, custom_policy=OpenMoeForCausalLMPolicy(), + enable_fused_normalization=args.use_kernel, + enable_jit_fused=args.use_kernel, ) else: raise ValueError(f"Invalid plugin {args.plugin}") From a97a201e37c73c8bd8c3c7841971a7f023cfeb8d Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 18 Sep 2023 17:40:29 +0800 Subject: [PATCH 14/16] disable ci --- examples/language/openmoe/test_ci.sh | 5 ----- 1 file changed, 5 deletions(-) diff --git a/examples/language/openmoe/test_ci.sh b/examples/language/openmoe/test_ci.sh index 8361b66c50d1..e69de29bb2d1 100644 --- a/examples/language/openmoe/test_ci.sh +++ b/examples/language/openmoe/test_ci.sh @@ -1,5 +0,0 @@ -set -xe -pip install -r requirements.txt - -python infer.py --model "test" -torchrun --standalone --nproc_per_node 2 train.py --model_name "test" --batch_size 1 --num_epoch 1 --plugin zero2 From 2817bc237f7360cec8efb43a5debf00d59e20d6b Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 18 Sep 2023 17:40:47 +0800 Subject: [PATCH 15/16] update train script --- examples/language/openmoe/train.sh | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/language/openmoe/train.sh b/examples/language/openmoe/train.sh index a2fe425c5805..6712aa10a88b 100644 --- a/examples/language/openmoe/train.sh +++ b/examples/language/openmoe/train.sh @@ -1,3 +1,9 @@ torchrun --standalone --nproc_per_node 4 train.py \ --model_name "base" \ + --plugin "hybrid" \ + --pp_size 2 \ + --dp_size 1 \ + --ep_size 2 \ + --use_kernel \ + --zero_stage 1 \ --batch_size 4 From 057067f0b798b56ca3b79fca61339bd08a3f056b Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 21 Sep 2023 09:53:02 +0800 Subject: [PATCH 16/16] update plugin --- .../plugin/moe_hybrid_parallel_plugin.py | 174 +----------------- 1 file changed, 4 insertions(+), 170 deletions(-) diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 04ec0ce57cef..fab6c2f0cb7b 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -1,36 +1,19 @@ -import random -from typing import Any, Callable, Iterator, List, Optional, Tuple, Union +from typing import Optional -import numpy as np import torch import torch.distributed as dist -from torch.nn import Module -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler as LRScheduler -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler -from colossalai.booster.plugin.hybrid_parallel_plugin import ( - HybridParallelAMPOptimizer, - HybridParallelModule, - HybridParallelNaiveOptimizer, - HybridParallelZeroOptimizer, - get_param_info, -) -from colossalai.checkpoint_io import 
CheckpointIO, HybridParallelCheckpointIO +from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelPlugin from colossalai.cluster import ProcessGroupMesh -from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig from colossalai.shardformer.policies.base_policy import Policy -from .pp_plugin_base import PipelinePluginBase - PP_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2 -class MoeHybridParallelPlugin(PipelinePluginBase): +class MoeHybridParallelPlugin(HybridParallelPlugin): """ Plugin for Moe Hybrid Parallel Training. Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin. @@ -139,6 +122,7 @@ def __init__(self, self.enable_flash_attention = enable_flash_attention self.enable_jit_fused = enable_jit_fused self.enable_sequence_parallelism = enable_sequence_parallelism + # we change pg mesh to (pp, dp, tp) for better moe performance self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size) self.stage_manager = None self.schedule = None @@ -187,153 +171,3 @@ def __init__(self, partition_grad=(self.zero_stage == 2)) self.max_norm = max_norm - - @property - def enable_pipeline_parallelism(self) -> bool: - return self.pp_size > 1 - - def supported_devices(self) -> List[str]: - return ['cuda'] - - def supported_precisions(self) -> List[str]: - return ['fp16', 'bf16', 'fp32'] - - def control_device(self) -> bool: - return True - - def control_precision(self) -> bool: - return True - - def support_no_sync(self) -> bool: - return False - - def control_checkpoint_io(self) -> bool: - return True - - def configure( - self, - model: Module, - optimizer: Optional[Optimizer] = None, - criterion: Optional[Callable] = None, - dataloader: Optional[DataLoader] = None, - lr_scheduler: Optional[LRScheduler] = None, - ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: - param_info = get_param_info(optimizer) - if not isinstance(model, ModelWrapper): - use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 - model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp, - self.ddp_config, self.custom_policy) - if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): - if self.zero_stage == 0: - if self.precision in ['fp16', 'bf16']: - optimizer = HybridParallelAMPOptimizer(optimizer, - model, - use_pipeline=self.enable_pipeline_parallelism, - param_info=param_info, - precision=self.precision, - max_norm=self.max_norm, - **self.amp_config) - self.checkpoint_io.link_master_and_working_param(optimizer.working_to_master_map, - optimizer.master_to_working_map) - else: - optimizer = HybridParallelNaiveOptimizer(optimizer, - model, - use_pipeline=self.enable_pipeline_parallelism, - param_info=param_info) - else: - assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1." - assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO." 
- optimizer = HybridParallelZeroOptimizer(optimizer, - model, - use_pipeline=self.enable_pipeline_parallelism, - param_info=param_info, - dp_process_group=self.dp_group, - tp_process_group=self.tp_group, - verbose=True, - clip_grad_norm=self.max_norm, - **self.zero_config, - **self.amp_config) - self.checkpoint_io.link_master_and_working_param(optimizer._param_store.working_to_master_param, - optimizer._param_store.master_to_working_param) - - return model, optimizer, criterion, dataloader, lr_scheduler - - def execute_pipeline(self, - data_iter: Iterator, - model: HybridParallelModule, - criterion: Callable[[Any, Any], torch.Tensor], - optimizer: Optional[Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer, - HybridParallelZeroOptimizer]] = None, - return_loss: bool = True, - return_outputs: bool = False) -> dict: - assert self.enable_pipeline_parallelism, 'pipeline parallelism is not enabled' - # return loss or outputs if needed - ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() - with ctx: - outputs = self.schedule.forward_backward_step(model, data_iter, criterion, optimizer, return_loss, - return_outputs) - model.sync_shared_params() - if isinstance(optimizer, HybridParallelZeroOptimizer): - optimizer.sync_grad() - else: - model.sync_grads() - return outputs - - def prepare_dataloader(self, - dataset, - batch_size, - shuffle=False, - seed=1024, - drop_last=False, - pin_memory=False, - num_workers=0, - **kwargs): - r""" - Prepare a dataloader for distributed training. The dataloader will be wrapped by - `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. - - - Args: - dataset (`torch.utils.data.Dataset`): The dataset to be loaded. - shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. - seed (int, optional): Random worker seed for sampling, defaults to 1024. - add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. - drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size - is not divisible by the batch size. If False and the size of dataset is not divisible by - the batch size, then the last batch will be smaller, defaults to False. - pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. - num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. - kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in - `DataLoader `_. - - Returns: - :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. - """ - _kwargs = kwargs.copy() - sampler = DistributedSampler(dataset, - num_replicas=self.pg_mesh.size(DP_AXIS), - rank=self.pg_mesh.coordinate(DP_AXIS), - shuffle=shuffle) - - # Deterministic dataloader - def seed_worker(worker_id): - worker_seed = seed - np.random.seed(worker_seed) - torch.manual_seed(worker_seed) - random.seed(worker_seed) - - return DataLoader(dataset, - batch_size=batch_size, - sampler=sampler, - worker_init_fn=seed_worker, - drop_last=drop_last, - pin_memory=pin_memory, - num_workers=num_workers, - **_kwargs) - - def get_checkpoint_io(self) -> CheckpointIO: - self.checkpoint_io = HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) - return self.checkpoint_io - - def no_sync(self, model: Module) -> Iterator[None]: - raise NotImplementedError
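
Editor's note on the expert-parallel layout used earlier in this series: the hunk near the top of this excerpt builds the ProcessGroupMesh either as (pp, dp, ep) or as (pp, ep, dp) depending on ep_inside. A minimal stand-alone sketch of that axis mapping follows; the helper name, return shape, and the __main__ example are illustrative only and not part of the patches.

# Sketch (assumption: row-major rank ordering in ProcessGroupMesh, as is standard).
# With ep_inside=True the expert-parallel axis is innermost, so consecutive global
# ranks typically land in the same expert-parallel group; with ep_inside=False the
# data-parallel axis is innermost instead.
from typing import Dict, Tuple


def moe_mesh_layout(pp_size: int, dp_size: int, ep_size: int,
                    ep_inside: bool = True) -> Tuple[Tuple[int, int, int], Dict[str, int]]:
    """Return (mesh_shape, axis_index_per_parallel_mode) for the two layouts."""
    if ep_inside:
        return (pp_size, dp_size, ep_size), {"pp": 0, "dp": 1, "ep": 2}
    return (pp_size, ep_size, dp_size), {"pp": 0, "ep": 1, "dp": 2}


if __name__ == "__main__":
    # 4 GPUs with pp=2, dp=1, ep=2, matching the arguments in the updated train.sh
    shape, axes = moe_mesh_layout(pp_size=2, dp_size=1, ep_size=2, ep_inside=True)
    print(shape, axes)  # (2, 1, 2) {'pp': 0, 'dp': 1, 'ep': 2}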
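
Editor's note on end-to-end usage after patches 13 and 16: the train.py changes wire --use_kernel into MOE_MANAGER.setup and into the plugin's fused-normalization/JIT flags, and the refactored MoeHybridParallelPlugin now inherits from HybridParallelPlugin. The sketch below shows roughly how these pieces are driven together; keyword names other than those visible in the diffs (seed, parallel, fixed_*_size, use_kernel_optim, zero_stage, microbatch_size, custom_policy, enable_fused_normalization, enable_jit_fused) are assumptions, and the import paths are guesses based on the example's file layout.

# Rough usage sketch under stated assumptions; not a verbatim excerpt of train.py.
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.moe.manager import MOE_MANAGER
# from model.openmoe_policy import OpenMoeForCausalLMPolicy  # path assumed from the example layout

# Distributed init must happen first, e.g. colossalai.launch_from_torch(...) in train.py.

# Mirrors the "hybrid" branch of train.py with the train.sh arguments
# (--pp_size 2 --dp_size 1 --ep_size 2 --use_kernel --zero_stage 1).
MOE_MANAGER.setup(
    seed=42,
    parallel="EP",            # assumed to match the "ep" branch; not shown for the hybrid branch
    fixed_dp_size=1,
    fixed_ep_size=2,
    fixed_pp_size=2,
    use_kernel_optim=True,    # tied to --use_kernel unless model_name == "test"
)

plugin = MoeHybridParallelPlugin(
    pp_size=2,                        # assumed kwarg name; matches --pp_size 2
    zero_stage=1,                     # --zero_stage 1
    microbatch_size=1,                # --microbatch_size default
    custom_policy=None,               # train.py passes OpenMoeForCausalLMPolicy()
    enable_fused_normalization=True,  # tied to --use_kernel in train.py
    enable_jit_fused=True,            # tied to --use_kernel in train.py
)
booster = Booster(plugin=plugin)
# model, optimizer, dataloader, etc. are then wrapped in the usual way:
# model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)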