Merged
34 changes: 34 additions & 0 deletions nemo/collections/llm/__init__.py
@@ -13,8 +13,25 @@
SquadDataModule,
)
from nemo.collections.llm.gpt.model import (
CodeGemmaConfig2B,
CodeGemmaConfig7B,
CodeLlamaConfig7B,
CodeLlamaConfig13B,
CodeLlamaConfig34B,
CodeLlamaConfig70B,
GemmaConfig,
GemmaConfig2B,
GemmaConfig7B,
GemmaModel,
GPTConfig,
GPTModel,
Llama2Config7B,
Llama2Config13B,
Llama2Config70B,
Llama3Config8B,
Llama3Config70B,
LlamaConfig,
LlamaModel,
MaskedTokenLossReduction,
Mistral7BConfig,
Mistral7BModel,
@@ -35,6 +52,23 @@
"Mistral7BModel",
"MixtralConfig",
"MixtralModel",
"LlamaConfig",
"Llama2Config7B",
"Llama2Config13B",
"Llama2Config70B",
"Llama3Config8B",
"Llama3Config70B",
"CodeLlamaConfig7B",
"CodeLlamaConfig13B",
"CodeLlamaConfig34B",
"CodeLlamaConfig70B",
"LlamaModel",
"GemmaConfig",
"GemmaConfig2B",
"GemmaConfig7B",
"CodeGemmaConfig2B",
"CodeGemmaConfig7B",
"GemmaModel",
"PreTrainingDataModule",
"FineTuningDataModule",
"SquadDataModule",
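With these additions to nemo/collections/llm/__init__.py, the new Gemma and Llama symbols are re-exported at the top level of the collection, so they can be imported directly. A minimal sketch of what that enables (names taken from the __all__ entries above):

    from nemo.collections.llm import GemmaConfig2B, GemmaModel, Llama3Config8B

The same names are also importable from nemo.collections.llm.gpt.model, as the next file shows.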
19 changes: 19 additions & 0 deletions nemo/collections/llm/gpt/model/__init__.py
@@ -5,6 +5,8 @@
gpt_data_step,
gpt_forward_step,
)
from nemo.collections.llm.gpt.model.gemma import *
from nemo.collections.llm.gpt.model.llama import *
from nemo.collections.llm.gpt.model.mistral_7b import Mistral7BConfig, Mistral7BModel
from nemo.collections.llm.gpt.model.mixtral import MixtralConfig, MixtralModel

@@ -15,6 +17,23 @@
"Mistral7BModel",
"MixtralConfig",
"MixtralModel",
"LlamaConfig",
"Llama2Config7B",
"Llama2Config13B",
"Llama2Config70B",
"Llama3Config8B",
"Llama3Config70B",
"CodeLlamaConfig7B",
"CodeLlamaConfig13B",
"CodeLlamaConfig34B",
"CodeLlamaConfig70B",
"GemmaConfig",
"GemmaConfig2B",
"GemmaConfig7B",
"CodeGemmaConfig2B",
"CodeGemmaConfig7B",
"GemmaModel",
"LlamaModel",
"MaskedTokenLossReduction",
"gpt_data_step",
"gpt_forward_step",
299 changes: 299 additions & 0 deletions nemo/collections/llm/gpt/model/gemma.py
@@ -0,0 +1,299 @@
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Callable, Optional

import torch

from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
from nemo.collections.llm.utils import Config
from nemo.collections.nlp.modules.common.megatron.utils import openai_gelu
from nemo.lightning import OptimizerModule, io, teardown

if TYPE_CHECKING:
from transformers import GemmaForCausalLM

from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec


# Note: Gemma requires huggingface transformers >= 4.38
# Note: these Gemma configs are copied from the corresponding HF model. You may need to modify the parameters
# for your own needs, in particular seq_length and rotary_base.
@dataclass
class GemmaConfig(GPTConfig):
# configs that are common across model sizes
normalization: str = "RMSNorm"
activation_func: Callable = openai_gelu
gated_linear_unit: bool = True
position_embedding_type: str = "rope"
add_bias_linear: bool = False
seq_length: int = 8192
kv_channels: int = 256
share_embeddings_and_output_weights: bool = True
# Note: different behavior compared to Legacy NeMo
# Legacy NeMo does not set layernorm_zero_centered_gamma and instead adds 1 in the HF -> NeMo conversion script
# The present implementation is more in line with the official implementation
layernorm_zero_centered_gamma: bool = True


@dataclass
class GemmaConfig2B(GemmaConfig):
num_layers: int = 18
hidden_size: int = 2048
num_attention_heads: int = 8
num_query_groups: int = 1
ffn_hidden_size: int = 16384


@dataclass
class GemmaConfig7B(GemmaConfig):
num_layers: int = 28
hidden_size: int = 3072
num_attention_heads: int = 16
num_query_groups: int = 16
ffn_hidden_size: int = 24576


class CodeGemmaConfig2B(GemmaConfig2B):
pass


class CodeGemmaConfig7B(GemmaConfig7B):
pass


class GemmaModel(GPTModel):
def __init__(
self,
config: Annotated[Optional[GemmaConfig], Config[GemmaConfig]] = None,
optim: Optional[OptimizerModule] = None,
tokenizer: Optional["TokenizerSpec"] = None,
):
super().__init__(config or GemmaConfig(), optim=optim, tokenizer=tokenizer)

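# Illustrative construction sketch (not exercised in this module; the seq_length override is an
# arbitrary example value, and optim/tokenizer are simply left at their optional defaults):
#
#     config = GemmaConfig2B(seq_length=4096)   # any GemmaConfig field can be overridden this way
#     model = GemmaModel(config)                # optim and tokenizer are optional, as in __init__ above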

@io.model_importer(GemmaModel, "hf")
class HFGemmaImporter(io.ModelConnector["GemmaForCausalLM", GemmaModel]):
def init(self) -> GemmaModel:
return GemmaModel(self.config, tokenizer=self.tokenizer)

def apply(self, output_path: Path) -> Path:
from transformers import GemmaForCausalLM

source = GemmaForCausalLM.from_pretrained(str(self))
target = self.init()
trainer = self.nemo_setup(target)
self.convert_state(source, target)
self.nemo_save(output_path, trainer)

print(f"Converted Gemma model to Nemo, model saved to {output_path}")

teardown(trainer, target)
del trainer, target

return output_path

def convert_state(self, source, target):
mapping = {
"model.embed_tokens.weight": "embedding.word_embeddings.weight",
"model.layers.*.self_attn.o_proj.weight": "decoder.layers.*.self_attention.linear_proj.weight",
"model.layers.*.mlp.down_proj.weight": "decoder.layers.*.mlp.linear_fc2.weight",
"model.layers.*.input_layernorm.weight": "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight",
"model.layers.*.post_attention_layernorm.weight": "decoder.layers.*.mlp.linear_fc1.layer_norm_weight",
"model.norm.weight": "decoder.final_layernorm.weight",
}

return io.apply_transforms(source, target, mapping=mapping, transforms=[_import_qkv, _import_linear_fc1])

@property
def tokenizer(self) -> "AutoTokenizer":
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer

return AutoTokenizer(str(self))

@property
def config(self) -> GemmaConfig:
from transformers import GemmaConfig as HFGemmaConfig

source = HFGemmaConfig.from_pretrained(str(self))

def make_vocab_size_divisible_by(vocab_size):
base = 128
while vocab_size % base != 0:
base //= 2
return base

output = GemmaConfig(
num_layers=source.num_hidden_layers,
hidden_size=source.hidden_size,
ffn_hidden_size=source.intermediate_size,
num_attention_heads=source.num_attention_heads,
init_method_std=source.initializer_range,
layernorm_epsilon=source.rms_norm_eps,
num_query_groups=source.num_key_value_heads,
rotary_base=source.rope_theta,
gated_linear_unit=True,
make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size),
share_embeddings_and_output_weights=False,
)

return output

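# Doctest-style sanity check of make_vocab_size_divisible_by above (illustrative only; 256000 is
# the vocab_size published in the HF Gemma config, 32001 is an arbitrary odd value):
#
#     >>> make_vocab_size_divisible_by(256000)   # 256000 % 128 == 0, so the base 128 is kept
#     128
#     >>> make_vocab_size_divisible_by(32001)    # odd sizes halve all the way down to 1
#     1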

@io.model_exporter(GemmaModel, "hf")
class HFGemmaExporter(io.ModelConnector[GemmaModel, "GemmaForCausalLM"]):
def init(self) -> "GemmaForCausalLM":
from transformers import AutoModelForCausalLM

return AutoModelForCausalLM.from_config(self.config)

def apply(self, output_path: Path) -> Path:
target = self.init()
source, _ = self.nemo_load(str(self))
target = self.convert_state(source, target)

target = target.cpu()
target.save_pretrained(output_path)
self.tokenizer.save_pretrained(output_path)

return output_path

def convert_state(self, source, target):
mapping = {
"embedding.word_embeddings.weight": "model.embed_tokens.weight",
"decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight",
"decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight",
"decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight",
"decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight",
"decoder.final_layernorm.weight": "model.norm.weight",
}

return io.apply_transforms(source, target, mapping=mapping, transforms=[_export_qkv, _export_linear_fc1])

@property
def tokenizer(self):
return io.load_ckpt(str(self)).model.tokenizer.tokenizer

@property
def config(self) -> "GemmaConfig":
source: GemmaConfig = io.load_ckpt(str(self)).model.config

from transformers import GemmaConfig as HFGemmaConfig

return HFGemmaConfig(
num_hidden_layers=source.num_layers,
hidden_size=source.hidden_size,
intermediate_size=source.ffn_hidden_size,
num_attention_heads=source.num_attention_heads,
max_position_embeddings=source.seq_length,
initializer_range=source.init_method_std,
rms_norm_eps=source.layernorm_epsilon,
num_key_value_heads=source.num_query_groups,
vocab_size=self.tokenizer.vocab_size,
)


@io.state_transform(
source_key=(
"model.layers.*.self_attn.q_proj.weight",
"model.layers.*.self_attn.k_proj.weight",
"model.layers.*.self_attn.v_proj.weight",
),
target_key="decoder.layers.*.self_attention.linear_qkv.weight",
)
def _import_qkv(ctx: io.TransformCTX, q, k, v):
megatron_config = ctx.target.config

head_num = megatron_config.num_attention_heads
num_query_groups = megatron_config.num_query_groups
heads_per_group = head_num // num_query_groups
hidden_size = megatron_config.hidden_size
head_num = megatron_config.num_attention_heads
head_size = hidden_size // head_num

old_tensor_shape = q.size()
new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:]
new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:]

q = q.view(*new_q_tensor_shape)
k = k.view(*new_kv_tensor_shape)
v = v.view(*new_kv_tensor_shape)

qkv_weights_l = []
for i in range(num_query_groups):
qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :])
qkv_weights_l.append(k[i : i + 1, :, :])
qkv_weights_l.append(v[i : i + 1, :, :])
qkv_weights = torch.cat(qkv_weights_l)
assert qkv_weights.ndim == 3, qkv_weights.shape
assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape
assert qkv_weights.shape[1] == head_size, qkv_weights.shape
assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape

qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size])

return qkv_weights

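# Illustrative shape walk-through for _import_qkv with the GemmaConfig2B values above
# (hidden_size=2048, num_attention_heads=8, num_query_groups=1, so head_size = 2048 // 8 = 256,
# matching kv_channels): q arrives as [2048, 2048] and is viewed as (8, 256, 2048); k and v
# arrive as [256, 2048] and are viewed as (1, 256, 2048). The loop packs the 8 query heads of
# the single group, then its key and value head, giving a (10, 256, 2048) tensor that is
# reshaped to the fused layout [head_size * (head_num + 2 * num_query_groups), hidden_size]
# = [2560, 2048].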

@io.state_transform(
source_key="decoder.layers.*.self_attention.linear_qkv.weight",
target_key=(
"model.layers.*.self_attn.q_proj.weight",
"model.layers.*.self_attn.k_proj.weight",
"model.layers.*.self_attn.v_proj.weight",
),
)
def _export_qkv(ctx: io.TransformCTX, linear_qkv):
megatron_config = ctx.source.config

head_num = megatron_config.num_attention_heads
num_query_groups = megatron_config.num_query_groups
heads_per_group = head_num // num_query_groups
hidden_size = megatron_config.hidden_size
head_num = megatron_config.num_attention_heads
head_size = hidden_size // head_num
qkv_total_dim = head_num + 2 * num_query_groups

linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, hidden_size])
q_slice = torch.cat(
[
torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
for i in range(num_query_groups)
]
)
k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))

q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu()
k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu()
v_proj = linear_qkv[v_slice].reshape(-1, hidden_size).cpu()

return q_proj, k_proj, v_proj

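# Illustrative index pattern for _export_qkv with the same GemmaConfig2B values
# (heads_per_group=8, num_query_groups=1, qkv_total_dim=10): q_slice selects rows [0..7],
# k_slice selects row [8], and v_slice selects row [9] of the fused (10, 256, 2048) tensor,
# reversing the interleaving performed by _import_qkv.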

@io.state_transform(
source_key=("model.layers.*.mlp.gate_proj.weight", "model.layers.*.mlp.up_proj.weight"),
target_key="decoder.layers.*.mlp.linear_fc1.weight",
)
def _import_linear_fc1(down, gate):
return torch.cat((down, gate), axis=0).float()


@io.state_transform(
source_key="decoder.layers.*.mlp.linear_fc1.weight",
target_key=("model.layers.*.mlp.gate_proj.weight", "model.layers.*.mlp.up_proj.weight"),
)
def _export_linear_fc1(linear_fc1):
gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0)

return gate_proj, up_proj

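# Illustrative round trip for the two fc1 transforms above: _import_linear_fc1 receives its
# tensors in source_key order (gate_proj first, then up_proj, despite the parameter names) and
# stacks them along dim 0; _export_linear_fc1 undoes this with torch.chunk, so
# chunk(cat((gate, up), dim=0), 2, dim=0) recovers (gate, up). With the 2B sizes this is a
# [2 * 16384, 2048] fused weight split back into two [16384, 2048] halves.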

__all__ = [
"GemmaConfig",
"GemmaConfig2B",
"GemmaConfig7B",
"CodeGemmaConfig2B",
"CodeGemmaConfig7B",
"GemmaModel",
]
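Because the importer above is registered via @io.model_importer(GemmaModel, "hf"), Hugging Face Gemma checkpoints can be converted through NeMo's connector machinery. A hedged sketch of that flow, assuming the import_ckpt helper exposed elsewhere in the llm collection (it is not part of this diff) and the google/gemma-2b checkpoint id:

    from nemo.collections import llm

    llm.import_ckpt(model=llm.GemmaModel(llm.GemmaConfig2B()), source="hf://google/gemma-2b")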