diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py index cb8db0f5f272..19911b544f43 100644 --- a/nemo/collections/llm/__init__.py +++ b/nemo/collections/llm/__init__.py @@ -13,8 +13,25 @@ SquadDataModule, ) from nemo.collections.llm.gpt.model import ( + CodeGemmaConfig2B, + CodeGemmaConfig7B, + CodeLlamaConfig7B, + CodeLlamaConfig13B, + CodeLlamaConfig34B, + CodeLlamaConfig70B, + GemmaConfig, + GemmaConfig2B, + GemmaConfig7B, + GemmaModel, GPTConfig, GPTModel, + Llama2Config7B, + Llama2Config13B, + Llama2Config70B, + Llama3Config8B, + Llama3Config70B, + LlamaConfig, + LlamaModel, MaskedTokenLossReduction, Mistral7BConfig, Mistral7BModel, @@ -35,6 +52,23 @@ "Mistral7BModel", "MixtralConfig", "MixtralModel", + "LlamaConfig", + "Llama2Config7B", + "Llama2Config13B", + "Llama2Config70B", + "Llama3Config8B", + "Llama3Config70B", + "CodeLlamaConfig7B", + "CodeLlamaConfig13B", + "CodeLlamaConfig34B", + "CodeLlamaConfig70B", + "LlamaModel", + "GemmaConfig", + "GemmaConfig2B", + "GemmaConfig7B", + "CodeGemmaConfig2B", + "CodeGemmaConfig7B", + "GemmaModel", "PreTrainingDataModule", "FineTuningDataModule", "SquadDataModule", diff --git a/nemo/collections/llm/gpt/model/__init__.py b/nemo/collections/llm/gpt/model/__init__.py index 0ddaa61c7a35..2da72539fd15 100644 --- a/nemo/collections/llm/gpt/model/__init__.py +++ b/nemo/collections/llm/gpt/model/__init__.py @@ -5,6 +5,8 @@ gpt_data_step, gpt_forward_step, ) +from nemo.collections.llm.gpt.model.gemma import * +from nemo.collections.llm.gpt.model.llama import * from nemo.collections.llm.gpt.model.mistral_7b import Mistral7BConfig, Mistral7BModel from nemo.collections.llm.gpt.model.mixtral import MixtralConfig, MixtralModel @@ -15,6 +17,23 @@ "Mistral7BModel", "MixtralConfig", "MixtralModel", + "LlamaConfig", + "Llama2Config7B", + "Llama2Config13B", + "Llama2Config70B", + "Llama3Config8B", + "Llama3Config70B", + "CodeLlamaConfig7B", + "CodeLlamaConfig13B", + "CodeLlamaConfig34B", + "CodeLlamaConfig70B", + "GemmaConfig", + "GemmaConfig2B", + "GemmaConfig7B", + "CodeGemmaConfig2B", + "CodeGemmaConfig7B", + "GemmaModel", + "LlamaModel", "MaskedTokenLossReduction", "gpt_data_step", "gpt_forward_step", diff --git a/nemo/collections/llm/gpt/model/gemma.py b/nemo/collections/llm/gpt/model/gemma.py new file mode 100644 index 000000000000..ff9772b1b74c --- /dev/null +++ b/nemo/collections/llm/gpt/model/gemma.py @@ -0,0 +1,299 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Annotated, Callable, Optional + +import torch + +from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel +from nemo.collections.llm.utils import Config +from nemo.collections.nlp.modules.common.megatron.utils import openai_gelu +from nemo.lightning import OptimizerModule, io, teardown + +if TYPE_CHECKING: + from transformers import GemmaForCausalLM + + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec + + +# Note: Gemma requires huggingface transformers >= 4.38 +# Note: these Gemma configs are copied from the corresponding HF model. You may need to modify the parameter for +# your own needs, in particular: seq_length and rotary_base. 
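+# For instance, adapting the 7B config for a shorter context window could look like the
+# sketch below (illustrative only: it overrides dataclass fields declared in this file or
+# inherited from GPTConfig, and assumes rotary_base is a valid GPTConfig field, as used by
+# the importer further down; the values are placeholders):
+#
+#     config = GemmaConfig7B(seq_length=4096, rotary_base=10_000)
+#     model = GemmaModel(config)
+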
+@dataclass +class GemmaConfig(GPTConfig): + # configs that are common across model sizes + normalization: str = "RMSNorm" + activation_func: Callable = openai_gelu + gated_linear_unit: bool = True + position_embedding_type: str = "rope" + add_bias_linear: bool = False + seq_length: int = 8192 + kv_channels: int = 256 + share_embeddings_and_output_weights: bool = True + # Note: different behavior compared to Legacy NeMo + # Legacy NeMo does not set layernorm_zero_centered_gamma and instead adds 1 in the HF -> NeMo conversion script + # The present implementation is more in line with the official implementation + layernorm_zero_centered_gamma: bool = True + + +@dataclass +class GemmaConfig2B(GemmaConfig): + num_layers: int = 18 + hidden_size: int = 2048 + num_attention_heads: int = 8 + num_query_groups: int = 1 + ffn_hidden_size: int = 16384 + + +@dataclass +class GemmaConfig7B(GemmaConfig): + num_layers: int = 28 + hidden_size: int = 3072 + num_attention_heads: int = 16 + num_query_groups: int = 16 + ffn_hidden_size: int = 24576 + + +class CodeGemmaConfig2B(GemmaConfig2B): + pass + + +class CodeGemmaConfig7B(GemmaConfig7B): + pass + + +class GemmaModel(GPTModel): + def __init__( + self, + config: Annotated[Optional[GemmaConfig], Config[GemmaConfig]] = None, + optim: Optional[OptimizerModule] = None, + tokenizer: Optional["TokenizerSpec"] = None, + ): + super().__init__(config or GemmaConfig(), optim=optim, tokenizer=tokenizer) + + +@io.model_importer(GemmaModel, "hf") +class HFGemmaImporter(io.ModelConnector["GemmaForCausalLM", GemmaModel]): + def init(self) -> GemmaModel: + return GemmaModel(self.config, tokenizer=self.tokenizer) + + def apply(self, output_path: Path) -> Path: + from transformers import GemmaForCausalLM + + source = GemmaForCausalLM.from_pretrained(str(self)) + target = self.init() + trainer = self.nemo_setup(target) + self.convert_state(source, target) + self.nemo_save(output_path, trainer) + + print(f"Converted Gemma model to Nemo, model saved to {output_path}") + + teardown(trainer, target) + del trainer, target + + return output_path + + def convert_state(self, source, target): + mapping = { + "model.embed_tokens.weight": "embedding.word_embeddings.weight", + "model.layers.*.self_attn.o_proj.weight": "decoder.layers.*.self_attention.linear_proj.weight", + "model.layers.*.mlp.down_proj.weight": "decoder.layers.*.mlp.linear_fc2.weight", + "model.layers.*.input_layernorm.weight": "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight", + "model.layers.*.post_attention_layernorm.weight": "decoder.layers.*.mlp.linear_fc1.layer_norm_weight", + "model.norm.weight": "decoder.final_layernorm.weight", + } + + return io.apply_transforms(source, target, mapping=mapping, transforms=[_import_qkv, _import_linear_fc1]) + + @property + def tokenizer(self) -> "AutoTokenizer": + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + + return AutoTokenizer(str(self)) + + @property + def config(self) -> GemmaConfig: + from transformers import GemmaConfig as HFGemmaConfig + + source = HFGemmaConfig.from_pretrained(str(self)) + + def make_vocab_size_divisible_by(vocab_size): + base = 128 + while vocab_size % base != 0: + base //= 2 + return base + + output = GemmaConfig( + num_layers=source.num_hidden_layers, + hidden_size=source.hidden_size, + ffn_hidden_size=source.intermediate_size, + num_attention_heads=source.num_attention_heads, + init_method_std=source.initializer_range, + layernorm_epsilon=source.rms_norm_eps, + 
num_query_groups=source.num_key_value_heads, + rotary_base=source.rope_theta, + gated_linear_unit=True, + make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size), + share_embeddings_and_output_weights=False, + ) + + return output + + +@io.model_exporter(GemmaModel, "hf") +class HFGemmaExporter(io.ModelConnector[GemmaModel, "GemmaForCausalLM"]): + def init(self) -> "GemmaForCausalLM": + from transformers import AutoModelForCausalLM + + return AutoModelForCausalLM.from_config(self.config) + + def apply(self, output_path: Path) -> Path: + target = self.init() + source, _ = self.nemo_load(str(self)) + target = self.convert_state(source, target) + + target = target.cpu() + target.save_pretrained(output_path) + self.tokenizer.save_pretrained(output_path) + + return output_path + + def convert_state(self, source, target): + mapping = { + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", + "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight", + "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight", + "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight", + "decoder.final_layernorm.weight": "model.norm.weight", + } + + return io.apply_transforms(source, target, mapping=mapping, transforms=[_export_qkv, _export_linear_fc1]) + + @property + def tokenizer(self): + return io.load_ckpt(str(self)).model.tokenizer.tokenizer + + @property + def config(self) -> "GemmaConfig": + source: GemmaConfig = io.load_ckpt(str(self)).model.config + + from transformers import GemmaConfig as HFGemmaConfig + + return HFGemmaConfig( + num_hidden_layers=source.num_layers, + hidden_size=source.hidden_size, + intermediate_size=source.ffn_hidden_size, + num_attention_heads=source.num_attention_heads, + max_position_embeddings=source.seq_length, + initializer_range=source.init_method_std, + rms_norm_eps=source.layernorm_epsilon, + num_key_value_heads=source.num_query_groups, + vocab_size=self.tokenizer.vocab_size, + ) + + +@io.state_transform( + source_key=( + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ), + target_key="decoder.layers.*.self_attention.linear_qkv.weight", +) +def _import_qkv(ctx: io.TransformCTX, q, k, v): + megatron_config = ctx.target.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = megatron_config.hidden_size + head_num = megatron_config.num_attention_heads + head_size = hidden_size // head_num + + old_tensor_shape = q.size() + new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:] + new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:] + + q = q.view(*new_q_tensor_shape) + k = k.view(*new_kv_tensor_shape) + v = v.view(*new_kv_tensor_shape) + + qkv_weights_l = [] + for i in range(num_query_groups): + qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :]) + qkv_weights_l.append(k[i : i + 1, :, :]) + qkv_weights_l.append(v[i : i + 1, :, :]) + qkv_weights = torch.cat(qkv_weights_l) + assert qkv_weights.ndim == 3, qkv_weights.shape + assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape + assert qkv_weights.shape[1] == head_size, qkv_weights.shape + 
assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape + + qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) + + return qkv_weights + + +@io.state_transform( + source_key="decoder.layers.*.self_attention.linear_qkv.weight", + target_key=( + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ), +) +def _export_qkv(ctx: io.TransformCTX, linear_qkv): + megatron_config = ctx.source.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = megatron_config.hidden_size + head_num = megatron_config.num_attention_heads + head_size = hidden_size // head_num + qkv_total_dim = head_num + 2 * num_query_groups + + linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, hidden_size]) + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu() + k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu() + v_proj = linear_qkv[v_slice].reshape(-1, hidden_size).cpu() + + return q_proj, k_proj, v_proj + + +@io.state_transform( + source_key=("model.layers.*.mlp.gate_proj.weight", "model.layers.*.mlp.up_proj.weight"), + target_key="decoder.layers.*.mlp.linear_fc1.weight", +) +def _import_linear_fc1(down, gate): + return torch.cat((down, gate), axis=0).float() + + +@io.state_transform( + source_key="decoder.layers.*.mlp.linear_fc1.weight", + target_key=("model.layers.*.mlp.gate_proj.weight", "model.layers.*.mlp.up_proj.weight"), +) +def _export_linear_fc1(linear_fc1): + gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0) + + return gate_proj, up_proj + + +__all__ = [ + "GemmaConfig", + "GemmaConfig2B", + "GemmaConfig7B", + "CodeGemmaConfig2B", + "CodeGemmaConfig7B", + "GemmaModel", +] diff --git a/nemo/collections/llm/gpt/model/llama.py b/nemo/collections/llm/gpt/model/llama.py new file mode 100644 index 000000000000..aa089b077041 --- /dev/null +++ b/nemo/collections/llm/gpt/model/llama.py @@ -0,0 +1,342 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Annotated, Callable, Optional + +import torch +import torch.nn.functional as F + +from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel +from nemo.collections.llm.utils import Config +from nemo.lightning import OptimizerModule, io, teardown + +if TYPE_CHECKING: + from transformers import LlamaConfig as HFLlamaConfig + from transformers import LlamaForCausalLM + + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec + + +# Note: these Llama configs are copied from the corresponding HF model. You may need to modify the parameter for +# your own needs, in particular: seq_length and rotary_base. 
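+# For instance, the context length and RoPE base can be overridden per run without
+# subclassing (an illustrative sketch; the field names come from the dataclasses below and
+# from GPTConfig, and the values are placeholders):
+#
+#     config = CodeLlamaConfig7B(seq_length=4096)   # shrink the 16k default for tight memory
+#     model = LlamaModel(config)
+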
+@dataclass +class LlamaConfig(GPTConfig): + # configs that are common across model sizes + normalization: str = "RMSNorm" + activation_func: Callable = F.silu + gated_linear_unit: bool = True + position_embedding_type: str = "rope" + add_bias_linear: bool = False + seq_length: int = 4096 + + +@dataclass +class Llama2Config7B(LlamaConfig): + num_layers: int = 32 + hidden_size: int = 4096 + num_attention_heads: int = 32 + num_query_groups: int = 32 + ffn_hidden_size: int = 11008 + + +@dataclass +class Llama2Config13B(LlamaConfig): + num_layers: int = 40 + hidden_size: int = 5120 + num_attention_heads: int = 40 + num_query_groups: int = 40 + ffn_hidden_size: int = 13824 + + +@dataclass +class Llama2Config70B(LlamaConfig): + num_layers: int = 80 + hidden_size: int = 8192 + num_attention_heads: int = 64 + num_query_groups: int = 8 + ffn_hidden_size: int = 28672 + + +@dataclass +class Llama3Config8B(Llama2Config7B): + seq_length: int = 8192 + num_query_groups: int = 8 + ffn_hidden_size: int = 14336 + + +@dataclass +class Llama3Config70B(Llama2Config70B): + seq_length: int = 8192 + + +@dataclass +class CodeLlamaConfig7B(Llama2Config7B): + rotary_base: int = 1_000_000 + seq_length: int = 16384 + + +@dataclass +class CodeLlamaConfig13B(Llama2Config13B): + rotary_base: int = 1_000_000 + seq_length: int = 16384 + + +@dataclass +class CodeLlamaConfig34B(LlamaConfig): + num_layers: int = 48 + hidden_size: int = 8192 + num_attention_heads: int = 64 + num_query_groups: int = 8 + ffn_hidden_size: int = 22016 + rotary_base: int = 1_000_000 + seq_length: int = 16384 + + +@dataclass +class CodeLlamaConfig70B(Llama2Config70B): + pass + + +class LlamaModel(GPTModel): + def __init__( + self, + config: Annotated[Optional[LlamaConfig], Config[LlamaConfig]] = None, + optim: Optional[OptimizerModule] = None, + tokenizer: Optional["TokenizerSpec"] = None, + ): + super().__init__(config or LlamaConfig(), optim=optim, tokenizer=tokenizer) + + +@io.model_importer(LlamaModel, "hf") +class HFLlamaImporter(io.ModelConnector["LlamaForCausalLM", LlamaModel]): + def init(self) -> LlamaModel: + return LlamaModel(self.config, tokenizer=self.tokenizer) + + def apply(self, output_path: Path) -> Path: + from transformers import LlamaForCausalLM + + source = LlamaForCausalLM.from_pretrained(str(self)) + target = self.init() + trainer = self.nemo_setup(target) + self.convert_state(source, target) + self.nemo_save(output_path, trainer) + + print(f"Converted Llama model to Nemo, model saved to {output_path}") + + teardown(trainer, target) + del trainer, target + + return output_path + + def convert_state(self, source, target): + mapping = { + "model.embed_tokens.weight": "embedding.word_embeddings.weight", + "model.layers.*.self_attn.o_proj.weight": "decoder.layers.*.self_attention.linear_proj.weight", + "model.layers.*.mlp.down_proj.weight": "decoder.layers.*.mlp.linear_fc2.weight", + "model.layers.*.input_layernorm.weight": "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight", + "model.layers.*.post_attention_layernorm.weight": "decoder.layers.*.mlp.linear_fc1.layer_norm_weight", + "model.norm.weight": "decoder.final_layernorm.weight", + "lm_head.weight": "output_layer.weight", + } + + return io.apply_transforms(source, target, mapping=mapping, transforms=[_import_qkv, _import_linear_fc1]) + + @property + def tokenizer(self) -> "AutoTokenizer": + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + + return AutoTokenizer(str(self)) + + @property + def config(self) -> LlamaConfig: + 
from transformers import LlamaConfig as HFLlamaConfig + + source = HFLlamaConfig.from_pretrained(str(self)) + + def make_vocab_size_divisible_by(vocab_size): + base = 128 + while vocab_size % base != 0: + base //= 2 + return base + + output = LlamaConfig( + num_layers=source.num_hidden_layers, + hidden_size=source.hidden_size, + ffn_hidden_size=source.intermediate_size, + num_attention_heads=source.num_attention_heads, + init_method_std=source.initializer_range, + layernorm_epsilon=source.rms_norm_eps, + num_query_groups=source.num_key_value_heads, + rotary_base=source.rope_theta, + gated_linear_unit=True, + make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size), + share_embeddings_and_output_weights=False, + ) + + return output + + +@io.model_exporter(LlamaModel, "hf") +class HFLlamaExporter(io.ModelConnector[LlamaModel, "LlamaForCausalLM"]): + def init(self) -> "LlamaForCausalLM": + from transformers import AutoModelForCausalLM + + return AutoModelForCausalLM.from_config(self.config) + + def apply(self, output_path: Path) -> Path: + target = self.init() + source, _ = self.nemo_load(str(self)) + target = self.convert_state(source, target) + + target = target.cpu() + target.save_pretrained(output_path) + self.tokenizer.save_pretrained(output_path) + + return output_path + + def convert_state(self, source, target): + mapping = { + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", + "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight", + "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight", + "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight", + "decoder.final_layernorm.weight": "model.norm.weight", + "output_layer.weight": "lm_head.weight", + } + + return io.apply_transforms(source, target, mapping=mapping, transforms=[_export_qkv, _export_linear_fc1]) + + @property + def tokenizer(self): + return io.load_ckpt(str(self)).model.tokenizer.tokenizer + + @property + def config(self) -> "HFLlamaConfig": + source: LlamaConfig = io.load_ckpt(str(self)).model.config + + from transformers import LlamaConfig as HFLlamaConfig + + return HFLlamaConfig( + num_hidden_layers=source.num_layers, + hidden_size=source.hidden_size, + intermediate_size=source.ffn_hidden_size, + num_attention_heads=source.num_attention_heads, + max_position_embeddings=source.seq_length, + initializer_range=source.init_method_std, + rms_norm_eps=source.layernorm_epsilon, + num_key_value_heads=source.num_query_groups, + rope_theta=source.rotary_base, + vocab_size=self.tokenizer.vocab_size, + ) + + +@io.state_transform( + source_key=( + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ), + target_key="decoder.layers.*.self_attention.linear_qkv.weight", +) +def _import_qkv(ctx: io.TransformCTX, q, k, v): + megatron_config = ctx.target.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = megatron_config.hidden_size + head_num = megatron_config.num_attention_heads + head_size = hidden_size // head_num + + old_tensor_shape = q.size() + new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:] + new_kv_tensor_shape = (num_query_groups, head_size) + 
old_tensor_shape[1:] + + q = q.view(*new_q_tensor_shape) + k = k.view(*new_kv_tensor_shape) + v = v.view(*new_kv_tensor_shape) + + qkv_weights_l = [] + for i in range(num_query_groups): + qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :]) + qkv_weights_l.append(k[i : i + 1, :, :]) + qkv_weights_l.append(v[i : i + 1, :, :]) + qkv_weights = torch.cat(qkv_weights_l) + assert qkv_weights.ndim == 3, qkv_weights.shape + assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape + assert qkv_weights.shape[1] == head_size, qkv_weights.shape + assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape + + qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) + + return qkv_weights + + +@io.state_transform( + source_key="decoder.layers.*.self_attention.linear_qkv.weight", + target_key=( + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ), +) +def _export_qkv(ctx: io.TransformCTX, linear_qkv): + megatron_config = ctx.source.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = megatron_config.hidden_size + head_num = megatron_config.num_attention_heads + head_size = hidden_size // head_num + qkv_total_dim = head_num + 2 * num_query_groups + + linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, hidden_size]) + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu() + k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu() + v_proj = linear_qkv[v_slice].reshape(-1, hidden_size).cpu() + + return q_proj, k_proj, v_proj + + +@io.state_transform( + source_key=("model.layers.*.mlp.gate_proj.weight", "model.layers.*.mlp.up_proj.weight"), + target_key="decoder.layers.*.mlp.linear_fc1.weight", +) +def _import_linear_fc1(down, gate): + return torch.cat((down, gate), axis=0).float() + + +@io.state_transform( + source_key="decoder.layers.*.mlp.linear_fc1.weight", + target_key=("model.layers.*.mlp.gate_proj.weight", "model.layers.*.mlp.up_proj.weight"), +) +def _export_linear_fc1(linear_fc1): + gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0) + + return gate_proj, up_proj + + +__all__ = [ + "LlamaConfig", + "Llama2Config7B", + "Llama2Config13B", + "Llama2Config70B", + "Llama3Config8B", + "Llama3Config70B", + "CodeLlamaConfig7B", + "CodeLlamaConfig13B", + "CodeLlamaConfig34B", + "CodeLlamaConfig70B", + "LlamaModel", +] diff --git a/nemo/collections/llm/gpt/model/mistral_7b.py b/nemo/collections/llm/gpt/model/mistral_7b.py index ada67c17da25..ff9591581f86 100644 --- a/nemo/collections/llm/gpt/model/mistral_7b.py +++ b/nemo/collections/llm/gpt/model/mistral_7b.py @@ -71,9 +71,6 @@ def apply(self, output_path: Path) -> Path: return output_path - def on_import_ckpt(self, model: pl.LightningModule): - model.tokenizer = self.tokenizer - def convert_state(self, source, target): mapping = { "model.embed_tokens.weight": "embedding.word_embeddings.weight", diff --git a/nemo/lightning/io/connector.py b/nemo/lightning/io/connector.py index a6ab4afd6d1b..41c81582bb63 100644 
--- a/nemo/lightning/io/connector.py +++ b/nemo/lightning/io/connector.py @@ -217,4 +217,5 @@ def local_path(self, base_path: Optional[Path] = None) -> Path: return _base / str(self).replace("://", "/") - def on_import_ckpt(self, model: pl.LightningModule): ... + def on_import_ckpt(self, model: pl.LightningModule): + model.tokenizer = self.tokenizer diff --git a/nemo/lightning/io/mixin.py b/nemo/lightning/io/mixin.py index 62b9a165c542..54b6e7195bc9 100644 --- a/nemo/lightning/io/mixin.py +++ b/nemo/lightning/io/mixin.py @@ -198,7 +198,7 @@ def register_importer(cls, ext: str, default_path: Optional[str] = None) -> Call """ def decorator(connector: Type[ConnT]) -> Type[ConnT]: - cls._IMPORTERS[ext] = connector + cls._IMPORTERS[str(cls) + ext] = connector if default_path: connector.default_path = default_path return connector @@ -221,7 +221,7 @@ def register_exporter(cls, ext: str, default_path: Optional[str] = None) -> Call """ def decorator(connector: Type[ConnT]) -> Type[ConnT]: - cls._EXPORTERS[ext] = connector + cls._EXPORTERS[str(cls) + ext] = connector if default_path: connector.default_path = default_path return connector @@ -310,7 +310,7 @@ def _get_connector(cls, ext, path=None, importer=True) -> ModelConnector: else: _path = path - connector = cls._IMPORTERS.get(ext) if importer else cls._EXPORTERS.get(ext) + connector = cls._IMPORTERS.get(str(cls) + ext) if importer else cls._EXPORTERS.get(str(cls) + ext) if not connector: raise ValueError(f"No connector found for extension '{ext}'")
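
The mixin change above re-keys the importer/exporter registries by model class as well as extension, so the Llama and Gemma "hf" connectors added in this PR can coexist with the existing Mistral one instead of overwriting it. Below is a minimal, self-contained sketch of that behaviour; it is not the NeMo implementation, and the names merely mirror the ones in mixin.py and the model files in this PR:

    # Sketch of the registry-keying change in nemo/lightning/io/mixin.py.
    _IMPORTERS = {}

    def register_importer(cls, ext):
        def decorator(connector):
            # Keyed by str(cls) + ext, as in this PR, so each model class gets its own "hf" slot.
            _IMPORTERS[str(cls) + ext] = connector
            return connector
        return decorator

    class LlamaModel: ...
    class GemmaModel: ...

    @register_importer(LlamaModel, "hf")
    class HFLlamaImporter: ...

    @register_importer(GemmaModel, "hf")
    class HFGemmaImporter: ...

    # With the old key (just "hf"), the second registration would have replaced the first.
    assert _IMPORTERS[str(LlamaModel) + "hf"] is HFLlamaImporter
    assert _IMPORTERS[str(GemmaModel) + "hf"] is HFGemmaImporter

Relatedly, the connector.py change gives on_import_ckpt a default implementation that attaches the connector's tokenizer to the imported model, which is why the Mistral-specific override is removed above.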