From bbae495ef37f7d7d5aa128c34fedff2d9ee504cd Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 4 Sep 2023 13:25:04 +0800 Subject: [PATCH 01/14] init --- .../language/openmoe/convert_openmoe_ckpt.py | 174 ++++ .../language/openmoe/convert_openmoe_ckpt.sh | 1 + examples/language/openmoe/modeling_openmoe.py | 960 ++++++++++++++++++ examples/language/openmoe/openmoe_config.json | 12 + 4 files changed, 1147 insertions(+) create mode 100644 examples/language/openmoe/convert_openmoe_ckpt.py create mode 100644 examples/language/openmoe/convert_openmoe_ckpt.sh create mode 100644 examples/language/openmoe/modeling_openmoe.py create mode 100644 examples/language/openmoe/openmoe_config.json diff --git a/examples/language/openmoe/convert_openmoe_ckpt.py b/examples/language/openmoe/convert_openmoe_ckpt.py new file mode 100644 index 000000000000..c84a960c2d3a --- /dev/null +++ b/examples/language/openmoe/convert_openmoe_ckpt.py @@ -0,0 +1,174 @@ +# coding=utf-8 +# Copyright 2022 Google LLC and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Convert T5X checkpoint to PyTorch + +Steps: +- Install gsutil according to https://cloud.google.com/storage/docs/gsutil_install +- Get a T5X checkpoint at https://github.com/google-research/t5x/blob/main/docs/models.md#t5-11-checkpoints Example: + `gsutil -m cp -r gs://t5-data/pretrained_models/t5x/t5_1_1_small $HOME/` +- Create or download a corresponding config for the downloaded model. E.g. for T5 v1.1 small, you can use + https://huggingface.co/google/t5-v1_1-small/blob/main/config.json +- Convert: + ``` + python3 convert_t5x_checkpoint_to_pytorch.py --t5x_checkpoint_path=$HOME/t5_1_1_small --config_file=config.json\ + --pytorch_dump_path=$HOME/t5_1_1_small_pt + ``` +""" + +import argparse +import collections + +import torch +from flax import traverse_util +from inference.moe.modeling_openllama import OpenLlamaForCausalLM +from t5x import checkpoints +from transformers import LlamaConfig +from transformers.utils import logging + +logging.set_verbosity_info() + + +def t5x_attention_lookup(params, i, prefix, layer_name="attention"): + """Returns the KOQV parameters of (self-)attention. Does not transpose.""" + k = params[f"{prefix}/layers_{i}/{layer_name}/key/kernel"] + o = params[f"{prefix}/layers_{i}/{layer_name}/out/kernel"] + q = params[f"{prefix}/layers_{i}/{layer_name}/query/kernel"] + v = params[f"{prefix}/layers_{i}/{layer_name}/value/kernel"] + return k, o, q, v + + +def t5x_mlp_lookup(params, i, prefix, split_mlp_wi=False): + """Returns the MLP parameters of a layer. Does not transpose.""" + if split_mlp_wi: + wi_0 = params[f"{prefix}/layers_{i}/mlp/wi_0/kernel"] + wi_1 = params[f"{prefix}/layers_{i}/mlp/wi_1/kernel"] + wi = (wi_0, wi_1) + else: + wi = params[f"{prefix}/layers_{i}/mlp/wi/kernel"] + + wo = params[f"{prefix}/layers_{i}/mlp/wo/kernel"] + return wi, wo + + +def t5x_layer_norm_lookup(params, i, prefix, layer_name): + """Returns the layer norm param of a layer.""" + return params[f"{prefix}/layers_{i}/{layer_name}/scale"] + + +def convert_t5x_to_pytorch(variables: dict, *, num_layers: int): + """Converts the parameters from T5X-Flax to Transformers-PyTorch.""" + old = traverse_util.flatten_dict(variables["target"]) + old = {"/".join(k): v for k, v in old.items()} + + # v1.1 models have a gated GeLU with wi_0 and wi_1 instead of wi + split_mlp_wi = True + print("Split MLP:", split_mlp_wi) + + new = collections.OrderedDict() + print(old.keys()) + for key, value in old.items(): + print(f"{key}: {value.shape}") + + # Shared embeddings. + new["model.embed_tokens.weight"] = old["token_embedder/embedding"] + + # Decoder. + for i in range(num_layers): + # Block i, layer 0 (Self Attention). + layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm") + k, o, q, v = t5x_attention_lookup(old, i, "decoder", "self_attention") + new[f"model.layers.{i}.input_layernorm.weight"] = layer_norm + new[f"model.layers.{i}.self_attn.k_proj.weight"] = k.T + new[f"model.layers.{i}.self_attn.o_proj.weight"] = o.T + new[f"model.layers.{i}.self_attn.q_proj.weight"] = q.T + new[f"model.layers.{i}.self_attn.v_proj.weight"] = v.T + + # Block i, layer 2 (MLP). + layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_mlp_layer_norm") + wi, wo = t5x_mlp_lookup(old, i, "decoder", split_mlp_wi) + new[f"model.layers.{i}.post_attention_layernorm.weight"] = layer_norm + new[f"model.layers.{i}.mlp.gate_proj.weight"] = wi[0].T + new[f"model.layers.{i}.mlp.up_proj.weight"] = wi[1].T + new[f"model.layers.{i}.mlp.down_proj.weight"] = wo.T + + new["model.norm.weight"] = old["decoder/decoder_norm/scale"] + + # LM Head (only in v1.1 checkpoints, in v1.0 embeddings are used instead) + if "decoder/logits_dense/kernel" in old: + new["lm_head.weight"] = old["decoder/logits_dense/kernel"].T + + return new + + +def make_state_dict(converted_params): + """Prepares a state dict for the PyTorch model.""" + # Make a state dict with torch tensors. + state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()]) + + return state_dict + + +def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path): + """Replaces the params in model witht the T5X converted params.""" + variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path) + converted = convert_t5x_to_pytorch(variables, num_layers=config.num_hidden_layers) + state_dict = make_state_dict(converted) + model.load_state_dict(state_dict, strict=True) + + +def convert_t5x_checkpoint_to_pytorch(t5x_checkpoint_path, config_file, pytorch_dump_path): + """Loads the config and model, converts the T5X checkpoint, and saves a PyTorch checkpoint.""" + # Initialise PyTorch model + config = LlamaConfig.from_json_file(config_file) + print(f"Building PyTorch model from configuration: {config}") + # Non-v1.1 checkpoints could also use T5Model, but this works for all. + # The v1.0 checkpoints will simply have an LM head that is the word embeddings. + model = OpenLlamaForCausalLM(config) + + # Load weights from tf checkpoint + load_t5x_weights_in_t5(model, config, t5x_checkpoint_path) + + # Save pytorch-model + print(f"Save PyTorch model to {pytorch_dump_path}") + model.save_pretrained(pytorch_dump_path) + + # Verify that we can load the checkpoint. + model.from_pretrained(pytorch_dump_path) + print("Done") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Converts a native T5X checkpoint into a PyTorch checkpoint.") + # Required parameters + parser.add_argument("--t5x_checkpoint_path", + default=None, + type=str, + required=True, + help="Path to the T5X checkpoint.") + parser.add_argument( + "--config_file", + default=None, + type=str, + required=True, + help="The config json file corresponding to the pre-trained T5 model.\nThis specifies the model architecture.", + ) + parser.add_argument("--pytorch_dump_path", + default=None, + type=str, + required=True, + help="Path to the output PyTorch model.") + args = parser.parse_args() + convert_t5x_checkpoint_to_pytorch(args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path) diff --git a/examples/language/openmoe/convert_openmoe_ckpt.sh b/examples/language/openmoe/convert_openmoe_ckpt.sh new file mode 100644 index 000000000000..1805ca8172ae --- /dev/null +++ b/examples/language/openmoe/convert_openmoe_ckpt.sh @@ -0,0 +1 @@ +python convert_openmoe_ckpt.py --t5x_checkpoint_path /data3/users/lczxl/OpenMoE/openmoe_base --config_file ./openmoe_config.json --pytorch_dump_path /data3/users/lczxl/OpenMoE/openmoe_base_pytorch diff --git a/examples/language/openmoe/modeling_openmoe.py b/examples/language/openmoe/modeling_openmoe.py new file mode 100644 index 000000000000..2fb8bdeb30bd --- /dev/null +++ b/examples/language/openmoe/modeling_openmoe.py @@ -0,0 +1,960 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch LLaMA model.""" +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.models.llama import LlamaConfig +from transformers.models.t5.modeling_t5 import T5LayerNorm +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask(input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +def generate_fixed_pos_embedding(features, length, min_timescale=1.0, max_timescale=10000.0): + """Generate Sin/Cos for Rotary Embeddings. + + Args: + features: an integer + length: an integer + min_timescale: an optional float + max_timescale: an optional float + + Returns: + output_sin: a float32 Tensor with shape [length, features] + output_cos: a float32 Tensor with shape [length, features] + """ + fraction = torch.arange(0, features, 2, dtype=torch.float32).cuda() / features + timescale = min_timescale * (max_timescale / min_timescale)**fraction + rotational_frequency = 1. / timescale + + sinusoid_inp = torch.einsum('i,j->ij', torch.arange(length, dtype=torch.float32).cuda(), rotational_frequency) + + sinusoid_inp = torch.cat([sinusoid_inp, sinusoid_inp], dim=-1) + + return torch.sin(sinusoid_inp).to(torch.bfloat16), torch.cos(sinusoid_inp).to(torch.bfloat16) + + +def apply_rotary_embedding(q, k, cos, sin, decode=False, rotary_index=None): + """Helper function to apply Rotary Embeddings.""" + if len(k.shape) == 3: + # for multi query attention + k = k.unsqueeze(2) + multiquery = True + else: + multiquery = False + + batch, qlen, qheads, d = q.shape + kbatch, klen, kheads, kd = k.shape + assert batch == kbatch, f'{batch} != {kbatch}' + assert d == kd, f'{d} != {kd}' + if decode and qlen == 1 and rotary_index is not None: + qcos = cos[rotary_index, :] + qsin = sin[rotary_index, :] + qcos = qcos.unsqueeze(1).unsqueeze(2).expand(batch, qlen, qheads, d) + qsin = qsin.unsqueeze(1).unsqueeze(2).expand(batch, qlen, qheads, d) + else: + qcos, qsin = cos[:qlen, :], sin[:qlen, :] + qcos = qcos.unsqueeze(0).unsqueeze(2).expand(batch, qlen, qheads, d) + qsin = qsin.unsqueeze(0).unsqueeze(2).expand(batch, qlen, qheads, d) + + kcos, ksin = cos[:klen, :], sin[:klen, :] + kcos = kcos.unsqueeze(0).unsqueeze(2).expand(batch, klen, kheads, d) + ksin = ksin.unsqueeze(0).unsqueeze(2).expand(batch, klen, kheads, d) + + out_q = (q * qcos) + (rotate_half(q) * qsin) + out_k = (k * kcos) + (rotate_half(k) * ksin) + + if multiquery: + out_k = out_k.squeeze(2) + + return out_q, out_k + + +class LlamaRotaryEmbedding(torch.nn.Module): + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.inv_freq = inv_freq + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache(seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype()) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) - + (self.scaling_factor - 1))**(self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def SwiGLU(x): + """Gated linear unit activation function. + Args: + x : input array + axis: the axis along which the split should be computed (default: -1) + """ + size = x.shape[-1] + assert size % 2 == 0, "axis size must be divisible by 2" + x1, x2 = torch.split(x, size // 2, -1) + return x1 * (x2 * torch.sigmoid(x2)) + + +class LlamaMLP(nn.Module): + + def __init__(self, config): + super().__init__() + self.pretraining_tp = config.pretraining_tp + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = SwiGLU + + def forward(self, x): + if self.pretraining_tp > 1: + slice = self.intermediate_size // self.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.pretraining_tp = config.pretraining_tp + self.max_position_embeddings = config.max_position_embeddings + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads}).") + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + # import pdb; pdb.set_trace() + + if self.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp + query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + dim = query_states.shape[-1] + max_length = max(query_states.shape[1], key_states.shape[1]) + sin, cos = generate_fixed_pos_embedding(dim, max_length, max_timescale=1e4) + query_states, key_states = apply_rotary_embedding(query_states, + key_states, + cos, + sin, + decode=True if q_len == 1 else False, + rotary_index=position_ids) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError(f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}") + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}") + attention_mask[:, :, :, 0] = 0 + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaDecoderLayer(nn.Module): + + def __init__(self, config: LlamaConfig, moe: bool): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = LlamaAttention(config=config) + self.mlp = LlamaMLP(config) + self.input_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, LlamaModel): + module.gradient_checkpointing = value + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([ + LlamaDecoderLayer(config, moe=True if (i + 1) % 4 == 0 else False) for i in range(config.num_hidden_layers) + ]) + self.norm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, + tgt_len=input_shape[-1]).to(inputs_embeds.device) + combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + + combined_attention_mask) + + return combined_attention_mask + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange(past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # import pdb; pdb.set_trace() + # embed positions + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), + dtype=torch.bool, + device=inputs_embeds.device) + attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds, + past_key_values_length) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class OpenLlamaForCausalLM(LlamaPreTrainedModel): + # _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.pretraining_tp = config.pretraining_tp + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation(self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs): + if past_key_values: + input_ids = input_ids[:, -1:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update({ + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + }) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += (tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),) + return reordered_past diff --git a/examples/language/openmoe/openmoe_config.json b/examples/language/openmoe/openmoe_config.json new file mode 100644 index 000000000000..71d2b37563aa --- /dev/null +++ b/examples/language/openmoe/openmoe_config.json @@ -0,0 +1,12 @@ +{ + "architectures": [ + "OpenLlamaForCausalLM" + ], + "intermediate_size": 2048, + "hidden_size": 768, + "num_hidden_layers": 12, + "num_attention_heads": 12, + "dropout_rate": 0.0, + "layer_norm_epsilon": 1e-06, + "vocab_size": 256384 +} From 255cdae03c0dda5f78a1ab7681f32da507c11eb9 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 4 Sep 2023 14:30:12 +0800 Subject: [PATCH 02/14] update moe ckpt --- .../language/openmoe/convert_openmoe_ckpt.py | 58 +++++++++++++++++-- 1 file changed, 53 insertions(+), 5 deletions(-) diff --git a/examples/language/openmoe/convert_openmoe_ckpt.py b/examples/language/openmoe/convert_openmoe_ckpt.py index c84a960c2d3a..42a3f83054d4 100644 --- a/examples/language/openmoe/convert_openmoe_ckpt.py +++ b/examples/language/openmoe/convert_openmoe_ckpt.py @@ -33,7 +33,7 @@ import torch from flax import traverse_util -from inference.moe.modeling_openllama import OpenLlamaForCausalLM +from modeling_openmoe import OpenLlamaForCausalLM from t5x import checkpoints from transformers import LlamaConfig from transformers.utils import logging @@ -63,6 +63,37 @@ def t5x_mlp_lookup(params, i, prefix, split_mlp_wi=False): return wi, wo +def t5x_extra_mlp_lookup(params, i, prefix, split_mlp_wi=False): + """Returns the MLP parameters of a layer. Does not transpose.""" + if split_mlp_wi: + wi_0 = params[f"{prefix}/layers_{i}/extra_mlp/wi_0/kernel"] + wi_1 = params[f"{prefix}/layers_{i}/extra_mlp/wi_1/kernel"] + wi = (wi_0, wi_1) + else: + wi = params[f"{prefix}/layers_{i}/extra_mlp/wi/kernel"] + + wo = params[f"{prefix}/layers_{i}/extra_mlp/wo/kernel"] + return wi, wo + + +def t5x_experts_lookup(params, i, prefix, split_mlp_wi=False): + """Returns the MLP parameters of a layer. Does not transpose.""" + if split_mlp_wi: + wi_0 = params[f"{prefix}/layers_{i}/mlp/expert/wi_0/kernel"] + wi_1 = params[f"{prefix}/layers_{i}/mlp/expert/wi_1/kernel"] + wi = (wi_0, wi_1) + else: + wi = params[f"{prefix}/layers_{i}/mlp/expert/wi/kernel"] + + wo = params[f"{prefix}/layers_{i}/mlp/expert/wo/kernel"] + return wi, wo + + +def t5x_gate_lookup(params, i, prefix, split_mlp_wi=False): + """Returns the MLP parameters of a layer. Does not transpose.""" + return params[f"{prefix}/layers_{i}/mlp/router/router_weights/w/kernel"] + + def t5x_layer_norm_lookup(params, i, prefix, layer_name): """Returns the layer norm param of a layer.""" return params[f"{prefix}/layers_{i}/{layer_name}/scale"] @@ -98,11 +129,28 @@ def convert_t5x_to_pytorch(variables: dict, *, num_layers: int): # Block i, layer 2 (MLP). layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_mlp_layer_norm") - wi, wo = t5x_mlp_lookup(old, i, "decoder", split_mlp_wi) new[f"model.layers.{i}.post_attention_layernorm.weight"] = layer_norm - new[f"model.layers.{i}.mlp.gate_proj.weight"] = wi[0].T - new[f"model.layers.{i}.mlp.up_proj.weight"] = wi[1].T - new[f"model.layers.{i}.mlp.down_proj.weight"] = wo.T + + if (i + 1) % 4 == 0: + # moe + gate = t5x_gate_lookup(old, i, "decoder", split_mlp_wi) + new[f"model.layers.{i}.mlp.gate_weight"] = gate.T + wi, wo = t5x_experts_lookup(old, i, "decoder", split_mlp_wi) + new[f"model.layers.{i}.mlp.experts.wi_gate"] = wi[0] + new[f"model.layers.{i}.mlp.experts.wi_up"] = wi[1] + new[f"model.layers.{i}.mlp.experts.wo"] = wo + # extra + layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_extra_mlp_layer_norm") + new[f"model.layers.{i}.pre_extra_mlp_layernorm.weight"] = layer_norm + wi, wo = t5x_extra_mlp_lookup(old, i, "decoder", split_mlp_wi) + new[f"model.layers.{i}.extra_mlp.gate_proj.weight"] = wi[0].T + new[f"model.layers.{i}.extra_mlp.up_proj.weight"] = wi[1].T + new[f"model.layers.{i}.extra_mlp.down_proj.weight"] = wo.T + else: + wi, wo = t5x_mlp_lookup(old, i, "decoder", split_mlp_wi) + new[f"model.layers.{i}.mlp.gate_proj.weight"] = wi[0].T + new[f"model.layers.{i}.mlp.up_proj.weight"] = wi[1].T + new[f"model.layers.{i}.mlp.down_proj.weight"] = wo.T new["model.norm.weight"] = old["decoder/decoder_norm/scale"] From dbf634d6eab54e48acf7cbf98f4d02f635d94741 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 4 Sep 2023 14:30:51 +0800 Subject: [PATCH 03/14] update config --- examples/language/openmoe/openmoe_config.json | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/language/openmoe/openmoe_config.json b/examples/language/openmoe/openmoe_config.json index 71d2b37563aa..6401ebcb7aea 100644 --- a/examples/language/openmoe/openmoe_config.json +++ b/examples/language/openmoe/openmoe_config.json @@ -8,5 +8,6 @@ "num_attention_heads": 12, "dropout_rate": 0.0, "layer_norm_epsilon": 1e-06, - "vocab_size": 256384 + "vocab_size": 256384, + "hidden_act": "swiglu" } From a4d3c0ae5bf5f4a1e2e30b1df53847cf0c18e8e6 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 4 Sep 2023 17:32:51 +0800 Subject: [PATCH 04/14] support openmoe infernece --- colossalai/nn/layer/moe/__init__.py | 4 +- colossalai/nn/layer/moe/experts.py | 41 +++++++++----- colossalai/nn/layer/moe/layers.py | 6 ++- colossalai/nn/layer/moe/utils.py | 35 +++++++----- examples/language/openmoe/README.md | 17 ++++++ .../language/openmoe/convert_openmoe_ckpt.sh | 1 - examples/language/openmoe/infer.py | 44 +++++++++++++++ examples/language/openmoe/infer.sh | 1 + .../{ => model}/convert_openmoe_ckpt.py | 4 +- .../openmoe/model/convert_openmoe_ckpt.sh | 1 + .../openmoe/{ => model}/modeling_openmoe.py | 54 +++++++++++++------ .../openmoe_base_config.json} | 2 +- 12 files changed, 161 insertions(+), 49 deletions(-) create mode 100644 examples/language/openmoe/README.md delete mode 100644 examples/language/openmoe/convert_openmoe_ckpt.sh create mode 100644 examples/language/openmoe/infer.py create mode 100644 examples/language/openmoe/infer.sh rename examples/language/openmoe/{ => model}/convert_openmoe_ckpt.py (99%) create mode 100644 examples/language/openmoe/model/convert_openmoe_ckpt.sh rename examples/language/openmoe/{ => model}/modeling_openmoe.py (95%) rename examples/language/openmoe/{openmoe_config.json => model/openmoe_base_config.json} (89%) diff --git a/colossalai/nn/layer/moe/__init__.py b/colossalai/nn/layer/moe/__init__.py index c20d16181909..f99353d0e0dd 100644 --- a/colossalai/nn/layer/moe/__init__.py +++ b/colossalai/nn/layer/moe/__init__.py @@ -2,9 +2,9 @@ from .experts import EPMLPExperts, TPMLPExperts from .layers import MoeLayer, MoeModule, SparseMLP from .routers import MoeRouter, Top1Router, Top2Router -from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts +from .utils import NormalNoiseGenerator, UniformNoiseGenerator __all__ = [ 'EPMLPExperts', 'TPMLPExperts', 'Top1Router', 'Top2Router', 'MoeModule', 'MoeLayer', 'NormalNoiseGenerator', - 'UniformNoiseGenerator', 'build_ffn_experts', 'SparseMLP', 'MoeRouter', 'MoeCheckpintIO' + 'UniformNoiseGenerator', 'SparseMLP', 'MoeRouter', 'MoeCheckpintIO' ] diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py index 608eca05435e..a64fcf68fc66 100644 --- a/colossalai/nn/layer/moe/experts.py +++ b/colossalai/nn/layer/moe/experts.py @@ -8,6 +8,7 @@ from colossalai.context import ParallelMode, seed from colossalai.context.moe_context import MOE_CONTEXT from colossalai.nn.layer.moe._operation import MoeInGradScaler, MoeOutGradScaler +from colossalai.nn.layer.moe.utils import get_activation from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size, set_moe_tensor_info @@ -24,11 +25,13 @@ def __init__( expert_parallel: str = None, activation: str = None, drop_rate: float = 0, + gated: bool = False, ): super().__init__() assert expert_parallel in ["EP", "TP", None] self.expert_parallel = expert_parallel self.num_total_experts = num_experts + self.gated = gated # get expert parallel info if expert_parallel is not None: @@ -47,14 +50,19 @@ def __init__( self.num_local_experts = self.num_total_experts self.ep_size = 1 - self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) + if gated: + self.wi_gate = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size * 2)) + self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) + else: + self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) self.wo = nn.Parameter(torch.empty(num_experts, intermediate_size, hidden_size)) - with seed(ParallelMode.TENSOR): - nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size)) - nn.init.trunc_normal_(self.wo, std=math.sqrt(0.1 / intermediate_size)) + if expert_parallel is not None: + with seed(ParallelMode.TENSOR): + nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size)) + nn.init.trunc_normal_(self.wo, std=math.sqrt(0.1 / intermediate_size)) - self.act = nn.GELU() if activation is None else activation + self.act = get_activation(activation) self.drop = nn.Dropout(p=drop_rate) if expert_parallel is not None: @@ -71,10 +79,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # inputs [g, e, c, h] inshape = x.shape x = x.reshape(e, -1, h) - x = torch.bmm(x, self.wi) - x = self.act(x) - with seed(ParallelMode.TENSOR): - x = self.drop(x) + if self.gated: + x = self.act(torch.bmm(x, self.wi_gate)) * torch.bmm(x, self.wi_up) + else: + x = torch.bmm(x, self.wi) + x = self.act(x) + + if self.expert_parallel is not None: + with seed(ParallelMode.TENSOR): + x = self.drop(x) x = torch.bmm(x, self.wo) x = x.reshape(inshape) @@ -93,8 +106,9 @@ def __init__(self, hidden_size: int, intermediate_size: int, activation=None, - drop_rate: float = 0): - super().__init__(num_experts, hidden_size, intermediate_size, "EP", activation, drop_rate) + drop_rate: float = 0, + gated: bool = False): + super().__init__(num_experts, hidden_size, intermediate_size, "EP", activation, drop_rate, gated) def state_dict(self, destination=None, prefix='', keep_vars=False): dp_rank = dist.get_rank(get_dp_group(self)) @@ -134,8 +148,9 @@ def __init__(self, hidden_size: int, intermediate_size: int, activation: str = None, - drop_rate: float = 0): - super().__init__(num_experts, hidden_size, intermediate_size, "TP", activation, drop_rate) + drop_rate: float = 0, + gated: bool = False): + super().__init__(num_experts, hidden_size, intermediate_size, "TP", activation, drop_rate, gated) def get_expert_class(name: str) -> BaseMLPExperts: diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py index f39eab40d28b..6b7be9eb57c0 100644 --- a/colossalai/nn/layer/moe/layers.py +++ b/colossalai/nn/layer/moe/layers.py @@ -60,7 +60,8 @@ def __init__(self, expert_parallel: str = "EP", hidden_size: int = 2048, intermediate_size: int = 2048, - activation: str = None): + activation: str = None, + gated: bool = False): super().__init__() self.hidden_size = hidden_size self.num_experts = num_experts @@ -82,7 +83,8 @@ def __init__(self, self.experts: BaseMLPExperts = expert_cls(num_experts=num_experts, hidden_size=hidden_size, intermediate_size=intermediate_size, - activation=activation) + activation=activation, + gated=gated) if expert_parallel is not None: self.ep_group = get_ep_group(self.experts) self.ep_size = get_ep_size(self.experts) diff --git a/colossalai/nn/layer/moe/utils.py b/colossalai/nn/layer/moe/utils.py index eb3bef70998d..369f6c0752ac 100644 --- a/colossalai/nn/layer/moe/utils.py +++ b/colossalai/nn/layer/moe/utils.py @@ -6,8 +6,6 @@ from colossalai.context.moe_context import MOE_CONTEXT from colossalai.utils import get_current_device -from .experts import EPMLPExperts, TPMLPExperts - class ForceFP32Parameter(torch.nn.Parameter): @@ -60,16 +58,6 @@ def autocast_softmax(logit: torch.Tensor, dim: int): return F.softmax(logit, dim=dim, detype=torch.float32) -def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): - mep_size = MOE_CONTEXT.max_ep_size - if num_experts % mep_size == 0 or mep_size % num_experts == 0: - return EPMLPExperts(num_experts, d_model, d_ff, activation, drop_rate) - elif d_ff % mep_size == 0: - return TPMLPExperts(num_experts, d_model, d_ff, activation, drop_rate) - else: - raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.") - - def get_noise_generator(noise_type: str, num_experts: int) -> Callable: if noise_type is None: return None @@ -80,3 +68,26 @@ def get_noise_generator(noise_type: str, num_experts: int) -> Callable: else: raise NotImplementedError("Unsupported input noisy policy") return noisy_func + + +def get_activation(act: str) -> Callable: + if act is None or act == 'relu': + return torch.nn.ReLU() + elif act == 'gelu': + return torch.nn.GELU() + elif act == 'swiglu': + return SwiGLU + else: + raise NotImplementedError("Unsupported activation function") + + +def SwiGLU(x): + """Gated linear unit activation function. + Args: + x : input array + axis: the axis along which the split should be computed (default: -1) + """ + size = x.shape[-1] + assert size % 2 == 0, "axis size must be divisible by 2" + x1, x2 = torch.split(x, size // 2, -1) + return x1 * (x2 * torch.sigmoid(x2)) diff --git a/examples/language/openmoe/README.md b/examples/language/openmoe/README.md new file mode 100644 index 000000000000..26b5ee73b054 --- /dev/null +++ b/examples/language/openmoe/README.md @@ -0,0 +1,17 @@ +## OpenMoE +[OpenMoE](https://github.com/XueFuzhao/OpenMoE) is a project aimed at Igniting the Open-Source MoE Community! + +The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates finetune and inference methods. + + +## Our Modifications + +We reimplement OpenMoE with PyTorch + GPU. + +## Run Inference + +By running the following script: +```bash +bash infer.sh +``` +You will infer a [OpenMoE-8B/32E](https://github.com/XueFuzhao/OpenMoE) model. diff --git a/examples/language/openmoe/convert_openmoe_ckpt.sh b/examples/language/openmoe/convert_openmoe_ckpt.sh deleted file mode 100644 index 1805ca8172ae..000000000000 --- a/examples/language/openmoe/convert_openmoe_ckpt.sh +++ /dev/null @@ -1 +0,0 @@ -python convert_openmoe_ckpt.py --t5x_checkpoint_path /data3/users/lczxl/OpenMoE/openmoe_base --config_file ./openmoe_config.json --pytorch_dump_path /data3/users/lczxl/OpenMoE/openmoe_base_pytorch diff --git a/examples/language/openmoe/infer.py b/examples/language/openmoe/infer.py new file mode 100644 index 000000000000..8285af3f730a --- /dev/null +++ b/examples/language/openmoe/infer.py @@ -0,0 +1,44 @@ +from argparse import ArgumentParser + +import torch +from model.modeling_openmoe import OpenMoeForCausalLM +from transformers import T5Tokenizer + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument("--path", default="/path/to/openmoe", type=str, help="model path") + return parser.parse_args() + + +def inference(args): + + tokenizer = T5Tokenizer.from_pretrained("google/umt5-small") + model = OpenMoeForCausalLM.from_pretrained(args.path) + model = model.eval().bfloat16() + model = model.to(torch.cuda.current_device()) + + input_str = """``` +y = list(map(int, ['1', 'hello', '2'])) +``` +What error does this program produce? +ValueError: invalid literal for int() with base 10: 'hello' + +``` +sum = 0 +for i in range(100): + sum += i +``` +What is the value of sum immediately after the 10th time line 3 is executed?""" + + # print("model config: ", model.config) + input_ids = tokenizer("" + input_str, return_tensors="pt", add_special_tokens=True) + input_ids = input_ids.input_ids.to(torch.cuda.current_device()) + generation_output = model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=128) + out = tokenizer.decode(generation_output[0], skip_special_tokens=False) + print(f"output: \n{out}\n") + + +if __name__ == "__main__": + args = parse_args() + inference(args) diff --git a/examples/language/openmoe/infer.sh b/examples/language/openmoe/infer.sh new file mode 100644 index 000000000000..78787f48fbb8 --- /dev/null +++ b/examples/language/openmoe/infer.sh @@ -0,0 +1 @@ +python infer.py --path /path/to/openmoe diff --git a/examples/language/openmoe/convert_openmoe_ckpt.py b/examples/language/openmoe/model/convert_openmoe_ckpt.py similarity index 99% rename from examples/language/openmoe/convert_openmoe_ckpt.py rename to examples/language/openmoe/model/convert_openmoe_ckpt.py index 42a3f83054d4..d78729f44182 100644 --- a/examples/language/openmoe/convert_openmoe_ckpt.py +++ b/examples/language/openmoe/model/convert_openmoe_ckpt.py @@ -33,7 +33,7 @@ import torch from flax import traverse_util -from modeling_openmoe import OpenLlamaForCausalLM +from modeling_openmoe import OpenMoeForCausalLM from t5x import checkpoints from transformers import LlamaConfig from transformers.utils import logging @@ -184,7 +184,7 @@ def convert_t5x_checkpoint_to_pytorch(t5x_checkpoint_path, config_file, pytorch_ print(f"Building PyTorch model from configuration: {config}") # Non-v1.1 checkpoints could also use T5Model, but this works for all. # The v1.0 checkpoints will simply have an LM head that is the word embeddings. - model = OpenLlamaForCausalLM(config) + model = OpenMoeForCausalLM(config) # Load weights from tf checkpoint load_t5x_weights_in_t5(model, config, t5x_checkpoint_path) diff --git a/examples/language/openmoe/model/convert_openmoe_ckpt.sh b/examples/language/openmoe/model/convert_openmoe_ckpt.sh new file mode 100644 index 000000000000..c0d53f562e40 --- /dev/null +++ b/examples/language/openmoe/model/convert_openmoe_ckpt.sh @@ -0,0 +1 @@ +python convert_openmoe_ckpt.py --t5x_checkpoint_path /path/to/t5x --config_file /path/to/config --pytorch_dump_path /path/to/save diff --git a/examples/language/openmoe/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py similarity index 95% rename from examples/language/openmoe/modeling_openmoe.py rename to examples/language/openmoe/model/modeling_openmoe.py index 2fb8bdeb30bd..ff9d51403005 100644 --- a/examples/language/openmoe/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -17,21 +17,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch LLaMA model.""" -import math +""" PyTorch OpenMoE model.""" from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F import torch.utils.checkpoint from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.nn import CrossEntropyLoss from transformers.activations import ACT2FN -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - SequenceClassifierOutputWithPast, -) +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.modeling_utils import PreTrainedModel from transformers.models.llama import LlamaConfig from transformers.models.t5.modeling_t5 import T5LayerNorm @@ -42,6 +37,8 @@ replace_return_docstrings, ) +from colossalai.nn.layer.moe.layers import SparseMLP + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "LlamaConfig" @@ -94,11 +91,11 @@ def generate_fixed_pos_embedding(features, length, min_timescale=1.0, max_timesc output_sin: a float32 Tensor with shape [length, features] output_cos: a float32 Tensor with shape [length, features] """ - fraction = torch.arange(0, features, 2, dtype=torch.float32).cuda() / features + fraction = torch.arange(0, features, 2, dtype=torch.float64).cuda() / features timescale = min_timescale * (max_timescale / min_timescale)**fraction rotational_frequency = 1. / timescale - sinusoid_inp = torch.einsum('i,j->ij', torch.arange(length, dtype=torch.float32).cuda(), rotational_frequency) + sinusoid_inp = torch.einsum('i,j->ij', torch.arange(length, dtype=torch.float64).cuda(), rotational_frequency) sinusoid_inp = torch.cat([sinusoid_inp, sinusoid_inp], dim=-1) @@ -119,10 +116,10 @@ def apply_rotary_embedding(q, k, cos, sin, decode=False, rotary_index=None): assert batch == kbatch, f'{batch} != {kbatch}' assert d == kd, f'{d} != {kd}' if decode and qlen == 1 and rotary_index is not None: - qcos = cos[rotary_index, :] - qsin = sin[rotary_index, :] - qcos = qcos.unsqueeze(1).unsqueeze(2).expand(batch, qlen, qheads, d) - qsin = qsin.unsqueeze(1).unsqueeze(2).expand(batch, qlen, qheads, d) + qcos = cos[rotary_index + 1, :] + qsin = sin[rotary_index + 1, :] + qcos = qcos.unsqueeze(2).expand(batch, qlen, qheads, d) + qsin = qsin.unsqueeze(2).expand(batch, qlen, qheads, d) else: qcos, qsin = cos[:qlen, :], sin[:qlen, :] qcos = qcos.unsqueeze(0).unsqueeze(2).expand(batch, qlen, qheads, d) @@ -444,10 +441,27 @@ class LlamaDecoderLayer(nn.Module): def __init__(self, config: LlamaConfig, moe: bool): super().__init__() self.hidden_size = config.hidden_size + self.moe = moe self.self_attn = LlamaAttention(config=config) - self.mlp = LlamaMLP(config) self.input_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) + if self.moe: + self.mlp = SparseMLP(num_experts=16, + top_k=2, + capacity_factor_train=1.25, + capacity_factor_eval=2., + min_capacity=4, + noisy_policy=None, + drop_tks=True, + expert_parallel=None, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + activation=config.hidden_act, + gated=True) + self.pre_extra_mlp_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) + self.extra_mlp = LlamaMLP(config) + else: + self.mlp = LlamaMLP(config) def forward( self, @@ -491,8 +505,16 @@ def forward( residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) + if self.moe: + hidden_states = hidden_states[0] hidden_states = residual + hidden_states + if self.moe: + residual = hidden_states + hidden_states = self.pre_extra_mlp_layernorm(hidden_states) + hidden_states = self.extra_mlp(hidden_states) + hidden_states = residual + hidden_states + outputs = (hidden_states,) if output_attentions: @@ -796,7 +818,7 @@ def custom_forward(*inputs): ) -class OpenLlamaForCausalLM(LlamaPreTrainedModel): +class OpenMoeForCausalLM(LlamaPreTrainedModel): # _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/examples/language/openmoe/openmoe_config.json b/examples/language/openmoe/model/openmoe_base_config.json similarity index 89% rename from examples/language/openmoe/openmoe_config.json rename to examples/language/openmoe/model/openmoe_base_config.json index 6401ebcb7aea..48f8d197cf31 100644 --- a/examples/language/openmoe/openmoe_config.json +++ b/examples/language/openmoe/model/openmoe_base_config.json @@ -1,6 +1,6 @@ { "architectures": [ - "OpenLlamaForCausalLM" + "OpenMoeForCausalLM" ], "intermediate_size": 2048, "hidden_size": 768, From bb6eda01639775d3caca773c2e3cff8327d0b9b3 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Tue, 5 Sep 2023 10:43:54 +0800 Subject: [PATCH 05/14] update config --- .../openmoe/model/convert_openmoe_ckpt.py | 8 ++++-- .../openmoe/model/modeling_openmoe.py | 28 +++++++++---------- .../openmoe/model/openmoe_8b_config.json | 24 ++++++++++++++++ .../openmoe/model/openmoe_base_config.json | 13 ++++++++- 4 files changed, 54 insertions(+), 19 deletions(-) create mode 100644 examples/language/openmoe/model/openmoe_8b_config.json diff --git a/examples/language/openmoe/model/convert_openmoe_ckpt.py b/examples/language/openmoe/model/convert_openmoe_ckpt.py index d78729f44182..20b1e780d8b3 100644 --- a/examples/language/openmoe/model/convert_openmoe_ckpt.py +++ b/examples/language/openmoe/model/convert_openmoe_ckpt.py @@ -99,7 +99,7 @@ def t5x_layer_norm_lookup(params, i, prefix, layer_name): return params[f"{prefix}/layers_{i}/{layer_name}/scale"] -def convert_t5x_to_pytorch(variables: dict, *, num_layers: int): +def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, moe_interval: int): """Converts the parameters from T5X-Flax to Transformers-PyTorch.""" old = traverse_util.flatten_dict(variables["target"]) old = {"/".join(k): v for k, v in old.items()} @@ -131,7 +131,7 @@ def convert_t5x_to_pytorch(variables: dict, *, num_layers: int): layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_mlp_layer_norm") new[f"model.layers.{i}.post_attention_layernorm.weight"] = layer_norm - if (i + 1) % 4 == 0: + if (i + 1) % moe_interval == 0: # moe gate = t5x_gate_lookup(old, i, "decoder", split_mlp_wi) new[f"model.layers.{i}.mlp.gate_weight"] = gate.T @@ -172,7 +172,9 @@ def make_state_dict(converted_params): def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path): """Replaces the params in model witht the T5X converted params.""" variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path) - converted = convert_t5x_to_pytorch(variables, num_layers=config.num_hidden_layers) + converted = convert_t5x_to_pytorch(variables, + num_layers=config.num_hidden_layers, + moe_interval=config.moe_layer_interval) state_dict = make_state_dict(converted) model.load_state_dict(state_dict, strict=True) diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index ff9d51403005..cd006b03ab8c 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -300,15 +300,12 @@ def __init__(self, config: LlamaConfig): self.config = config self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads + self.head_dim = config.head_dim self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.pretraining_tp = config.pretraining_tp self.max_position_embeddings = config.max_position_embeddings - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads}).") self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) @@ -421,7 +418,7 @@ def forward( f" {attn_output.size()}") attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) if self.pretraining_tp > 1: attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) @@ -446,18 +443,18 @@ def __init__(self, config: LlamaConfig, moe: bool): self.input_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) if self.moe: - self.mlp = SparseMLP(num_experts=16, - top_k=2, - capacity_factor_train=1.25, - capacity_factor_eval=2., - min_capacity=4, - noisy_policy=None, - drop_tks=True, - expert_parallel=None, + self.mlp = SparseMLP(num_experts=config.num_experts, + top_k=config.topk, + capacity_factor_train=config.capacity_factor_train, + capacity_factor_eval=config.capacity_factor_eval, + min_capacity=config.min_capacity, + noisy_policy=config.noisy_policy, + drop_tks=config.drop_tks, + expert_parallel=config.expert_parallel, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, activation=config.hidden_act, - gated=True) + gated=config.gated) self.pre_extra_mlp_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) self.extra_mlp = LlamaMLP(config) else: @@ -653,7 +650,8 @@ def __init__(self, config: LlamaConfig): self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, moe=True if (i + 1) % 4 == 0 else False) for i in range(config.num_hidden_layers) + LlamaDecoderLayer(config, moe=True if (i + 1) % config.moe_layer_interval == 0 else False) + for i in range(config.num_hidden_layers) ]) self.norm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/examples/language/openmoe/model/openmoe_8b_config.json b/examples/language/openmoe/model/openmoe_8b_config.json new file mode 100644 index 000000000000..248697c37d3c --- /dev/null +++ b/examples/language/openmoe/model/openmoe_8b_config.json @@ -0,0 +1,24 @@ +{ + "architectures": [ + "OpenMoeForCausalLM" + ], + "intermediate_size": 8192, + "hidden_size": 2048, + "num_hidden_layers": 24, + "head_dim": 128, + "num_attention_heads": 24, + "dropout_rate": 0.0, + "layer_norm_epsilon": 1e-06, + "vocab_size": 256384, + "hidden_act": "swiglu", + "num_experts": 32, + "topk": 2, + "capacity_factor_train": 1.25, + "capacity_factor_eval": 2.0, + "min_capacity": 4, + "noisy_policy": null, + "drop_tks": true, + "expert_parallel": null, + "gated": true, + "moe_layer_interval": 6 +} diff --git a/examples/language/openmoe/model/openmoe_base_config.json b/examples/language/openmoe/model/openmoe_base_config.json index 48f8d197cf31..5a7c97bd1916 100644 --- a/examples/language/openmoe/model/openmoe_base_config.json +++ b/examples/language/openmoe/model/openmoe_base_config.json @@ -5,9 +5,20 @@ "intermediate_size": 2048, "hidden_size": 768, "num_hidden_layers": 12, + "head_dim": 64, "num_attention_heads": 12, "dropout_rate": 0.0, "layer_norm_epsilon": 1e-06, "vocab_size": 256384, - "hidden_act": "swiglu" + "hidden_act": "swiglu", + "num_experts": 16, + "topk": 2, + "capacity_factor_train": 1.25, + "capacity_factor_eval": 2.0, + "min_capacity": 4, + "noisy_policy": null, + "drop_tks": true, + "expert_parallel": null, + "gated": true, + "moe_layer_interval": 4 } From 1b51447d5291ae9df09072228edacfd6b52edc4f Mon Sep 17 00:00:00 2001 From: oahzxl Date: Tue, 5 Sep 2023 15:43:25 +0800 Subject: [PATCH 06/14] remove pdb --- examples/language/openmoe/model/modeling_openmoe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index cd006b03ab8c..7fdd4cc32c23 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -340,7 +340,6 @@ def forward( use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - # import pdb; pdb.set_trace() if self.pretraining_tp > 1: key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp From cd9acea67b91389c5c97b668d8545c78b725764f Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 6 Sep 2023 09:58:20 +0800 Subject: [PATCH 07/14] update ci --- examples/language/openmoe/infer.py | 4 ++-- examples/language/openmoe/infer.sh | 2 +- examples/language/openmoe/test_ci.sh | 1 + 3 files changed, 4 insertions(+), 3 deletions(-) create mode 100644 examples/language/openmoe/test_ci.sh diff --git a/examples/language/openmoe/infer.py b/examples/language/openmoe/infer.py index 8285af3f730a..0ac7fb72508c 100644 --- a/examples/language/openmoe/infer.py +++ b/examples/language/openmoe/infer.py @@ -7,14 +7,14 @@ def parse_args(): parser = ArgumentParser() - parser.add_argument("--path", default="/path/to/openmoe", type=str, help="model path") + parser.add_argument("--model", default="base", type=str, help="model path", choices=["base", "8b"]) return parser.parse_args() def inference(args): tokenizer = T5Tokenizer.from_pretrained("google/umt5-small") - model = OpenMoeForCausalLM.from_pretrained(args.path) + model = OpenMoeForCausalLM.from_pretrained(f"hpcaitech/openmoe-{args.model}") model = model.eval().bfloat16() model = model.to(torch.cuda.current_device()) diff --git a/examples/language/openmoe/infer.sh b/examples/language/openmoe/infer.sh index 78787f48fbb8..a578203eba84 100644 --- a/examples/language/openmoe/infer.sh +++ b/examples/language/openmoe/infer.sh @@ -1 +1 @@ -python infer.py --path /path/to/openmoe +python infer.py --model "base" diff --git a/examples/language/openmoe/test_ci.sh b/examples/language/openmoe/test_ci.sh new file mode 100644 index 000000000000..1c0314261b82 --- /dev/null +++ b/examples/language/openmoe/test_ci.sh @@ -0,0 +1 @@ +bash infer.sh From 8e0b2c7ec1e4f06db7db5aa5b36c22b92de42dcc Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 6 Sep 2023 10:00:53 +0800 Subject: [PATCH 08/14] update requirement --- examples/language/openmoe/requirements.txt | 3 +++ examples/language/openmoe/test_ci.sh | 3 +++ 2 files changed, 6 insertions(+) create mode 100644 examples/language/openmoe/requirements.txt diff --git a/examples/language/openmoe/requirements.txt b/examples/language/openmoe/requirements.txt new file mode 100644 index 000000000000..2b869f7ccf3e --- /dev/null +++ b/examples/language/openmoe/requirements.txt @@ -0,0 +1,3 @@ +colossalai >= 0.1.12 +torch >= 1.8.1 +transformers >= 4.20.0 diff --git a/examples/language/openmoe/test_ci.sh b/examples/language/openmoe/test_ci.sh index 1c0314261b82..7b841e1a4877 100644 --- a/examples/language/openmoe/test_ci.sh +++ b/examples/language/openmoe/test_ci.sh @@ -1 +1,4 @@ +set -xe +pip install -r requirements.txt + bash infer.sh From bb4a742228ab81c2c51285e1c7c3b8db0e2b2eed Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 6 Sep 2023 10:21:13 +0800 Subject: [PATCH 09/14] add build ffn experts --- colossalai/nn/layer/moe/__init__.py | 4 ++-- colossalai/nn/layer/moe/experts.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/colossalai/nn/layer/moe/__init__.py b/colossalai/nn/layer/moe/__init__.py index f99353d0e0dd..52f529814eba 100644 --- a/colossalai/nn/layer/moe/__init__.py +++ b/colossalai/nn/layer/moe/__init__.py @@ -1,10 +1,10 @@ from .checkpoint import MoeCheckpintIO -from .experts import EPMLPExperts, TPMLPExperts +from .experts import EPMLPExperts, TPMLPExperts, build_ffn_experts from .layers import MoeLayer, MoeModule, SparseMLP from .routers import MoeRouter, Top1Router, Top2Router from .utils import NormalNoiseGenerator, UniformNoiseGenerator __all__ = [ 'EPMLPExperts', 'TPMLPExperts', 'Top1Router', 'Top2Router', 'MoeModule', 'MoeLayer', 'NormalNoiseGenerator', - 'UniformNoiseGenerator', 'SparseMLP', 'MoeRouter', 'MoeCheckpintIO' + 'UniformNoiseGenerator', 'SparseMLP', 'MoeRouter', 'MoeCheckpintIO', 'build_ffn_experts' ] diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py index a64fcf68fc66..9a51ec2a5c7e 100644 --- a/colossalai/nn/layer/moe/experts.py +++ b/colossalai/nn/layer/moe/experts.py @@ -162,3 +162,13 @@ def get_expert_class(name: str) -> BaseMLPExperts: return BaseMLPExperts else: raise ValueError(f"Unknown expert class name: {name}") + + +def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): + mep_size = MOE_CONTEXT.max_ep_size + if num_experts % mep_size == 0 or mep_size % num_experts == 0: + return EPMLPExperts(num_experts, d_model, d_ff, activation, drop_rate) + elif d_ff % mep_size == 0: + return TPMLPExperts(num_experts, d_model, d_ff, activation, drop_rate) + else: + raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.") From 716a103a1225d1ef56a43fa7f04ccf02dede9643 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 6 Sep 2023 10:46:13 +0800 Subject: [PATCH 10/14] update requirement --- examples/language/openmoe/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/language/openmoe/requirements.txt b/examples/language/openmoe/requirements.txt index 2b869f7ccf3e..2e7175c59baf 100644 --- a/examples/language/openmoe/requirements.txt +++ b/examples/language/openmoe/requirements.txt @@ -1,3 +1,3 @@ -colossalai >= 0.1.12 +colossalai @ git+https://github.com/hpcaitech/ColossalAI@feature/moe torch >= 1.8.1 transformers >= 4.20.0 From d503058245ca6f8af26c1f0b8c00b070ed60b901 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 6 Sep 2023 11:05:10 +0800 Subject: [PATCH 11/14] update ci --- examples/language/openmoe/test_ci.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/language/openmoe/test_ci.sh b/examples/language/openmoe/test_ci.sh index 7b841e1a4877..66d880743aaf 100644 --- a/examples/language/openmoe/test_ci.sh +++ b/examples/language/openmoe/test_ci.sh @@ -1,4 +1,5 @@ set -xe +pip uninstall colossalai pip install -r requirements.txt bash infer.sh From ef08b35831ad11baa93ec0e23a8c64ee89d0a9b5 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 6 Sep 2023 11:08:20 +0800 Subject: [PATCH 12/14] update ci --- colossalai/tensor/moe_tensor/__init__.py | 0 examples/language/openmoe/requirements.txt | 2 +- examples/language/openmoe/test_ci.sh | 1 - 3 files changed, 1 insertion(+), 2 deletions(-) create mode 100644 colossalai/tensor/moe_tensor/__init__.py diff --git a/colossalai/tensor/moe_tensor/__init__.py b/colossalai/tensor/moe_tensor/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/examples/language/openmoe/requirements.txt b/examples/language/openmoe/requirements.txt index 2e7175c59baf..2b869f7ccf3e 100644 --- a/examples/language/openmoe/requirements.txt +++ b/examples/language/openmoe/requirements.txt @@ -1,3 +1,3 @@ -colossalai @ git+https://github.com/hpcaitech/ColossalAI@feature/moe +colossalai >= 0.1.12 torch >= 1.8.1 transformers >= 4.20.0 diff --git a/examples/language/openmoe/test_ci.sh b/examples/language/openmoe/test_ci.sh index 66d880743aaf..7b841e1a4877 100644 --- a/examples/language/openmoe/test_ci.sh +++ b/examples/language/openmoe/test_ci.sh @@ -1,5 +1,4 @@ set -xe -pip uninstall colossalai pip install -r requirements.txt bash infer.sh From d53e538859e75d64440e76546e43c3db6912be41 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 6 Sep 2023 11:16:50 +0800 Subject: [PATCH 13/14] update require --- examples/language/openmoe/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/language/openmoe/requirements.txt b/examples/language/openmoe/requirements.txt index 2b869f7ccf3e..2fb95d9c71d3 100644 --- a/examples/language/openmoe/requirements.txt +++ b/examples/language/openmoe/requirements.txt @@ -1,3 +1,4 @@ colossalai >= 0.1.12 torch >= 1.8.1 transformers >= 4.20.0 +sentencepiece From c4f30a50c21be4d13b0a5fc1ccbed25a6b04c8b9 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 6 Sep 2023 12:57:08 +0800 Subject: [PATCH 14/14] update ci --- examples/language/openmoe/infer.py | 9 +++++++-- examples/language/openmoe/test_ci.sh | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/examples/language/openmoe/infer.py b/examples/language/openmoe/infer.py index 0ac7fb72508c..b41fa2f2e4f1 100644 --- a/examples/language/openmoe/infer.py +++ b/examples/language/openmoe/infer.py @@ -3,18 +3,23 @@ import torch from model.modeling_openmoe import OpenMoeForCausalLM from transformers import T5Tokenizer +from transformers.models.llama import LlamaConfig def parse_args(): parser = ArgumentParser() - parser.add_argument("--model", default="base", type=str, help="model path", choices=["base", "8b"]) + parser.add_argument("--model", default="base", type=str, help="model path", choices=["base", "8b", "test"]) return parser.parse_args() def inference(args): tokenizer = T5Tokenizer.from_pretrained("google/umt5-small") - model = OpenMoeForCausalLM.from_pretrained(f"hpcaitech/openmoe-{args.model}") + if args.model == "test": + config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base") + model = OpenMoeForCausalLM(config) + else: + model = OpenMoeForCausalLM.from_pretrained(f"hpcaitech/openmoe-{args.model}") model = model.eval().bfloat16() model = model.to(torch.cuda.current_device()) diff --git a/examples/language/openmoe/test_ci.sh b/examples/language/openmoe/test_ci.sh index 7b841e1a4877..349b2eaccd79 100644 --- a/examples/language/openmoe/test_ci.sh +++ b/examples/language/openmoe/test_ci.sh @@ -1,4 +1,4 @@ set -xe pip install -r requirements.txt -bash infer.sh +python infer.py --model "test"