Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/mcore_bridge/config/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import os
import re
import torch.nn.functional as F
from dataclasses import dataclass
from dataclasses import dataclass, field
from megatron.core import mpu
from megatron.core.transformer import TransformerConfig
from transformers import PretrainedConfig
from transformers.utils import is_torch_npu_available
from transformers.utils.versions import require_version
from typing import List, Literal, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union

from mcore_bridge.utils import get_logger, json_parse_to_dict

Expand Down Expand Up @@ -229,6 +229,7 @@ class ModelConfig(TransformerConfig):
task_type: Literal['causal_lm', 'seq_cls', 'embedding', 'generative_reranker'] = 'causal_lm'
num_labels: Optional[int] = None
mlp_padding_free: bool = False
model_kwargs: Dict[str, Any] = field(default_factory=dict)

_mindspeed_defaults_cache = None

Expand Down
1 change: 1 addition & 0 deletions src/mcore_bridge/model/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class MLLMModelType:
glm4v_moe = 'glm4v_moe'
kimi_vl = 'kimi_vl'
llama4 = 'llama4'
gemma4 = 'gemma4'

kimi_k25 = 'kimi_k25'

Expand Down
65 changes: 37 additions & 28 deletions src/mcore_bridge/model/gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,7 @@ def __init__(
for i in range(len(self.decoder.layers)):
if hasattr(self.decoder.layers[i].self_attention, 'rotary_pos_emb'):
del self.decoder.layers[i].self_attention.rotary_pos_emb
self.attention_scaling = 1.
new_inv_freq, self.attention_scaling = get_rope_inv_freq(config)
self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device)
self._set_inv_freq()
if self.config.task_type == 'seq_cls' and self.post_process:
self.output_layer = OutputLayerLinear(
config.hidden_size,
Expand Down Expand Up @@ -222,7 +220,36 @@ def _preprocess(
if decoder_input is not None and self.training and torch.is_grad_enabled() and not decoder_input.requires_grad:
# fix LoRA incompatibility with gradient checkpointing
decoder_input = decoder_input.requires_grad_(True)
rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin = self._get_rotary_pos_emb(
decoder_input, position_ids, packed_seq_params=packed_seq_params)
Comment on lines +223 to +224
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The inference_context is not passed to the _get_rotary_pos_emb method. This will cause the method to skip critical inference-specific logic, such as utilizing the RoPE cache or correctly calculating the rotary sequence length for flash decoding, which can lead to performance degradation or incorrect results during inference.

Suggested change
rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin = self._get_rotary_pos_emb(
decoder_input, position_ids, packed_seq_params=packed_seq_params)
rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin = self._get_rotary_pos_emb(
decoder_input, position_ids, packed_seq_params=packed_seq_params, inference_context=inference_context)


if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration')
or self.config.flash_decode) and rotary_pos_cos is not None
and inference_context.is_static_batching()):
current_batch_size = input_ids.shape[0]
sequence_len_offset = torch.tensor(
[inference_context.sequence_len_offset] * current_batch_size,
dtype=torch.int32,
device=rotary_pos_cos.device, # Co-locate this with the rotary tensors
)
else:
sequence_len_offset = None

# Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
# reference held by this caller function, enabling early garbage collection for
# inference. Skip wrapping if decoder_input is logged after decoder completion.
if in_inference_mode and not has_config_logger_enabled(self.config):
decoder_input = WrappedTensor(decoder_input)

return (decoder_input, rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin,
sequence_len_offset)

def _set_inv_freq(self):
self.attention_scaling = 1.
new_inv_freq, self.attention_scaling = get_rope_inv_freq(self.config)
self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device)

def _get_rotary_pos_emb(self, decoder_input, position_ids, packed_seq_params, inference_context=None):
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
rotary_pos_cos = None
Expand Down Expand Up @@ -257,26 +284,13 @@ def _preprocess(
rotary_seq_len,
packed_seq=packed_seq,
)
decoder_rotary_pos_emb = rotary_pos_emb
packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
if self.position_embedding_type == 'rope' and packed_seq and not self.config.apply_rope_fusion:
assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}'
decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]]

if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration')
or self.config.flash_decode) and rotary_pos_cos is not None
and inference_context.is_static_batching()):
current_batch_size = input_ids.shape[0]
sequence_len_offset = torch.tensor(
[inference_context.sequence_len_offset] * current_batch_size,
dtype=torch.int32,
device=rotary_pos_cos.device, # Co-locate this with the rotary tensors
)
else:
sequence_len_offset = None

# Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
# reference held by this caller function, enabling early garbage collection for
# inference. Skip wrapping if decoder_input is logged after decoder completion.
if in_inference_mode and not has_config_logger_enabled(self.config):
decoder_input = WrappedTensor(decoder_input)

return decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset
return rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin

# Code borrowed from NVIDIA/Megatron-LM
def forward(
Expand Down Expand Up @@ -308,19 +322,14 @@ def forward(

inference_context = deprecate_inference_params(inference_context, inference_params)

decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = (
decoder_input, rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = (
self._preprocess(
input_ids=input_ids,
position_ids=position_ids,
decoder_input=decoder_input,
inference_context=inference_context,
packed_seq_params=packed_seq_params,
))
decoder_rotary_pos_emb = rotary_pos_emb
packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
if self.position_embedding_type == 'rope' and packed_seq and not self.config.apply_rope_fusion:
assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}'
decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]]

mtp_decoder_input = decoder_input
if self.config.is_multimodal and self.config.mtp_num_layers and decoder_input is None:
Expand Down
4 changes: 3 additions & 1 deletion src/mcore_bridge/model/mm_gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@


class MultimodalGPTModel(MegatronModule):
language_model_cls = GPTModel

def __init__(self,
config: ModelConfig,
Expand All @@ -29,7 +30,8 @@ def __init__(self,
super().__init__(config)
self.pre_process = pre_process
self.post_process = post_process
self.language_model = GPTModel(config, transformer_layer_spec, pre_process, post_process, *_args, **kwargs)
self.language_model = self.language_model_cls(config, transformer_layer_spec, pre_process, post_process, *_args,
**kwargs)
self.vp_stage = self.language_model.vp_stage
self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights
self.model_meta = config.model_meta
Expand Down
2 changes: 1 addition & 1 deletion src/mcore_bridge/model/mm_gpts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from . import glm, internvl, kimi_vl, llama4, llava, qwen, qwen3_5, qwen3_5_gdn, qwen3_omni, qwen3_vl
from . import gemma4, glm, internvl, kimi_vl, llama4, llava, qwen, qwen3_5, qwen3_5_gdn, qwen3_omni, qwen3_vl
109 changes: 109 additions & 0 deletions src/mcore_bridge/model/mm_gpts/gemma4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import copy
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from transformers import AutoModel, PretrainedConfig
from typing import Optional

from mcore_bridge.bridge import MultimodalGPTBridge
from mcore_bridge.config import ModelConfig

from ..constant import ModelType
from ..gpt_model import GPTModel
from ..mm_gpt_model import MultimodalGPTModel
from ..register import ModelLoader, ModelMeta, register_model
from ..rope import get_rope_inv_freq
from .utils import HuggingFaceVit


class Gemma4Vit(HuggingFaceVit):
    """HuggingFace-backed vision/audio towers and multimodal embedders for Gemma4."""

    module_mapping = {
        'model.vision_tower': 'vision_tower',
        'model.embed_vision': 'embed_vision',
        'model.audio_tower': 'audio_tower',
        'model.embed_audio': 'embed_audio',
    }
    _vision_tower = ['vision_tower', 'audio_tower']
    _aligner = ['embed_vision', 'embed_audio']
    support_multimodal = False

    def prepare_model(self, hf_config: PretrainedConfig):
        """Instantiate the towers and embedders from `hf_config`.

        The audio branch is optional: both `audio_tower` and `embed_audio`
        stay ``None`` when the config carries no audio section.
        """
        from transformers.models.gemma4.modeling_gemma4 import Gemma4MultimodalEmbedder
        audio_config = hf_config.audio_config
        has_audio = audio_config is not None
        # NOTE: assignment order is kept identical to the original so that
        # nn.Module registration (and thus state-dict ordering) is unchanged.
        self.vision_tower = AutoModel.from_config(hf_config.vision_config)
        self.audio_tower = AutoModel.from_config(audio_config) if has_audio else None
        self.embed_vision = Gemma4MultimodalEmbedder(hf_config.vision_config, hf_config.text_config)
        self.embed_audio = Gemma4MultimodalEmbedder(audio_config, hf_config.text_config) if has_audio else None

    def get_inputs_embeds(self, inputs_embeds, **kwargs):
        """Pass precomputed embeddings through unchanged."""
        return inputs_embeds


class Gemma4SelfAttention(SelfAttention):
    """Self-attention variant used for Gemma4 text layers.

    Currently behaves exactly like ``SelfAttention``; it exists as the
    override point wired in by ``Gemma4Loader.get_transformer_layer_spec``
    for future Gemma4-specific attention behavior.
    """

    def __init__(
        self,
        config: ModelConfig,
        submodules: SelfAttentionSubmodules,
        layer_number: int,
        *args,
        **kwargs,
    ):
        # The original read config.hf_config.text_config into an unused local
        # (dead code); it is removed here.
        super().__init__(config, submodules, layer_number, *args, **kwargs)


class Gemma4Bridge(MultimodalGPTBridge):
    """Weight bridge for Gemma4; the generic multimodal GPT mapping suffices."""
    pass


class Gemma4TextGPTModel(GPTModel):
    """Gemma4 text backbone with dual RoPE tables.

    Gemma4 nests per-attention-type rope settings under the
    ``sliding_attention`` / ``full_attention`` keys of ``rope_scaling``, so
    the base class's single-table ``_set_inv_freq`` is overridden to build
    one table per attention type.
    """

    # The original defined an __init__ override whose only effect was a
    # leftover debug print(); it has been removed (base __init__ is used).

    def _set_inv_freq(self):
        """Build sliding-window and full-attention RoPE inverse frequencies."""
        rope_scaling = self.config.rope_scaling
        # Sliding-window attention table (reuses the base rotary_pos_emb).
        self.config.rope_scaling = rope_scaling['sliding_attention']
        new_inv_freq, attention_scaling = get_rope_inv_freq(self.config)
        assert attention_scaling == 1, (
            f'attention scaling other than 1.0 is not supported for Gemma4 '
            f'sliding attention (got {attention_scaling})')
        self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device)
        # Full-attention table: shallow-copy the embedding module and swap
        # in its own inv_freq buffer.
        self.full_rotary_pos_emb = copy.copy(self.rotary_pos_emb)
        self.config.rope_scaling = rope_scaling['full_attention']
        kwargs = {}
        if self.config.rope_scaling['rope_type'] == 'proportional':
            kwargs['head_dim_key'] = 'global_head_dim'
        new_inv_freq, attention_scaling = get_rope_inv_freq(self.config, **kwargs)
        assert attention_scaling == 1, (
            f'attention scaling other than 1.0 is not supported for Gemma4 '
            f'full attention (got {attention_scaling})')
        # Keep device placement consistent with the sliding-attention table
        # (the original left this tensor on CPU).
        self.full_rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device)
        self.attention_scaling = attention_scaling

        # Restore the original nested dict so the config is left unchanged.
        # NOTE(review): confirm that downstream consumers of
        # config.rope_scaling (e.g. _get_rope_type) tolerate the nested
        # Gemma4 layout without a top-level 'rope_type' key.
        self.config.rope_scaling = rope_scaling
Comment on lines +67 to +84
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The implementation of _set_inv_freq for Gemma4TextGPTModel has several issues:

  1. Potential Runtime Crash: Restoring self.config.rope_scaling to the original nested dictionary at line 62 will cause a KeyError in _get_rope_type (called via dynamic_rope_update during every forward pass) because that function expects a dictionary with a rope_type key at the top level, which the Gemma4 configuration lacks (it uses sliding_attention and full_attention as top-level keys).
  2. Dead Code: self.full_rotary_pos_emb is initialized but never utilized by the base GPTModel forward pass or RoPE application logic.
  3. Poor Error Messages: The assertion messages 'not support' at lines 49 and 58 are not descriptive. They should clearly state that attention scaling other than 1.0 is not supported for this model.



class Gemma4GPTModel(MultimodalGPTModel):
    """Multimodal Gemma4 model using the Gemma4-specific text backbone."""
    # Swap in the text GPT model that handles Gemma4's dual RoPE tables.
    language_model_cls = Gemma4TextGPTModel


class Gemma4Loader(ModelLoader):
    """Loader producing Gemma4 models with the Gemma4 attention override."""
    model_cls = Gemma4GPTModel

    def get_transformer_layer_spec(self, vp_stage: Optional[int] = None):
        """Build the decoder block spec, substituting Gemma4SelfAttention in every layer."""
        layer_specs = get_gpt_decoder_block_spec(
            self.config, use_transformer_engine=True, normalization=self.config.normalization, vp_stage=vp_stage)
        for layer_spec in layer_specs.layer_specs:
            # Replace the default self-attention implementation with the Gemma4 variant.
            layer_spec.submodules.self_attention.module = Gemma4SelfAttention
        return layer_specs


# Register Gemma4 so the model type (and the 'gemma4' HF architecture id)
# resolves to the bridge, visual module, and loader defined above.
register_model(
    ModelMeta(
        ModelType.gemma4,
        ['gemma4'],
        bridge_cls=Gemma4Bridge,
        visual_cls=Gemma4Vit,
        loader=Gemma4Loader,
    ))
4 changes: 2 additions & 2 deletions src/mcore_bridge/model/rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,12 @@ def _get_rope_type(rope_scaling: Optional[Dict[str, Any]]):
return rope_type


def get_rope_inv_freq(config, seq_len=None, **kwargs):
    """Compute RoPE inverse frequencies and attention scaling for ``config``.

    Args:
        config: model config exposing ``rope_scaling`` (and whatever fields
            ``_get_dummy_config`` forwards to the rope init function).
        seq_len: optional sequence length forwarded to the rope init function.
        **kwargs: extra keyword arguments forwarded to the rope init function
            (e.g. ``head_dim_key`` for proportional-scaling variants).

    Returns:
        Tuple of ``(inv_freq, attention_scaling)`` where ``inv_freq`` is a CPU
        tensor and ``attention_scaling`` defaults to 1.0 when the init
        function reports no scaling.
    """
    from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
    # Make project-specific rope types resolvable alongside the HF built-ins.
    ROPE_INIT_FUNCTIONS.update(EXTENDED_ROPE_INIT_FUNCTIONS)
    dummy_config = _get_dummy_config(config)
    rope_init_fn = ROPE_INIT_FUNCTIONS[_get_rope_type(config.rope_scaling)]
    inv_freq, attention_scaling = rope_init_fn(dummy_config, 'cpu', seq_len=seq_len, **kwargs)
    if attention_scaling is None:
        attention_scaling = 1.
    return inv_freq, attention_scaling
Expand Down
39 changes: 3 additions & 36 deletions src/mcore_bridge/tuners/patcher.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import copy
from contextlib import contextmanager
from megatron.core.extensions.transformer_engine import TEGroupedLinear, TELayerNormColumnParallelLinear, TELinear
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe.router import TopKRouter
Expand All @@ -11,6 +9,8 @@
from torch import nn
from typing import Optional

from mcore_bridge.utils import patch_deepcopy

from .lora import LoraParallelLinear


Expand All @@ -37,47 +37,14 @@ def dispatch_megatron(
model.dispatch_megatron = dispatch_megatron


@contextmanager
def _patch_deepcopy():
_origin_deepcopy = copy.deepcopy
copy_keys = ('tp_group', '_tp_group', 'config')

def new_deepcopy(x, *args, **kwargs):
if not isinstance(x, nn.Module):
return _origin_deepcopy(x, *args, **kwargs)

saved_values = {}
for key in copy_keys:
val = getattr(x, key, None)
if val is not None:
saved_values[key] = val
setattr(x, key, None)

try:
res = _origin_deepcopy(x, *args, **kwargs)
finally:
for key, value in saved_values.items():
setattr(x, key, value)

for key, value in saved_values.items():
setattr(res, key, value)
return res

copy.deepcopy = new_deepcopy
try:
yield
finally:
copy.deepcopy = _origin_deepcopy


def _patch_lora_model():
if hasattr(LoraModel, '_mcore_patched'):
return

__origin_init__ = LoraModel.__init__

def __new_init__(self, *args, **kwargs):
with _patch_deepcopy():
with patch_deepcopy():
__origin_init__(self, *args, **kwargs)
if not isinstance(self.model, MegatronModule):
return
Expand Down
2 changes: 1 addition & 1 deletion src/mcore_bridge/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
from .megatron_utils import get_local_layer_specs, set_random_seed, split_cp_inputs, unwrap_model
from .safetensors import SafetensorLazyLoader, StreamingSafetensorSaver
from .torch_utils import gc_collect, get_current_device, safe_ddp_context, to_device
from .utils import deep_getattr, get_env_args, json_parse_to_dict
from .utils import deep_getattr, get_env_args, json_parse_to_dict, patch_deepcopy
36 changes: 36 additions & 0 deletions src/mcore_bridge/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import copy
import json
import os
from contextlib import contextmanager
from torch import nn
from transformers.utils import strtobool
from typing import Callable, Dict, Optional, TypeVar, Union

Expand Down Expand Up @@ -58,3 +61,36 @@ def deep_getattr(obj, attr: str, default=None):
else:
obj = getattr(obj, a, default)
return obj


@contextmanager
def patch_deepcopy():
    """Temporarily make ``copy.deepcopy`` safe for Megatron ``nn.Module``s.

    While active, deep-copying an ``nn.Module`` detaches the attributes named
    in ``skip_attrs`` (process groups and config objects that cannot or
    should not be deep-copied), performs the copy, restores them on the
    source, and shares the same objects with the copy. Non-module values are
    deep-copied normally. The original ``copy.deepcopy`` is restored on exit.
    """
    original_deepcopy = copy.deepcopy
    skip_attrs = ('tp_group', '_tp_group', 'config')

    def _module_safe_deepcopy(obj, *args, **kwargs):
        if not isinstance(obj, nn.Module):
            return original_deepcopy(obj, *args, **kwargs)

        # Stash and null out the non-copyable attributes (only those set).
        stashed = {}
        for name in skip_attrs:
            value = getattr(obj, name, None)
            if value is not None:
                stashed[name] = value
                setattr(obj, name, None)

        try:
            duplicate = original_deepcopy(obj, *args, **kwargs)
        finally:
            # Always restore the source module, even if the copy failed.
            for name, value in stashed.items():
                setattr(obj, name, value)

        # The copy shares the stashed objects (groups/config) with the source.
        for name, value in stashed.items():
            setattr(duplicate, name, value)
        return duplicate

    copy.deepcopy = _module_safe_deepcopy
    try:
        yield
    finally:
        copy.deepcopy = original_deepcopy
Loading