Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion swift/llm/model/model/gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,6 @@ def get_model_tokenizer_gemma3_vision(model_dir: str,
TemplateType.gemma3_vision,
get_model_tokenizer_gemma3_vision,
architectures=['Gemma3ForConditionalGeneration'],
model_arch=ModelArch.gemma3_vision,
model_arch=ModelArch.llava_hf,
requires=['transformers>=4.49'],
))
4 changes: 2 additions & 2 deletions swift/llm/model/model/stepfun.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_model_tokenizer_got_ocr2(*args, **kwargs):

def get_model_tokenizer_got_ocr2_hf(model_dir, *args, **kwargs):
    """Load the HF-native GOT-OCR2 model together with its processor.

    Before delegating to ``get_model_tokenizer_multimodal``, the class-level
    ``_no_split_modules`` attribute of ``GotOcr2ForConditionalGeneration`` is
    set so that ``GotOcr2VisionLayer`` is listed there.

    Args:
        model_dir: Forwarded unchanged to ``get_model_tokenizer_multimodal``.
        *args: Extra positional arguments, forwarded unchanged.
        **kwargs: Extra keyword arguments, forwarded unchanged.

    Returns:
        The ``(model, processor)`` pair produced by
        ``get_model_tokenizer_multimodal``.
    """
    from transformers.models.got_ocr2 import GotOcr2ForConditionalGeneration

    # Assign instead of append-then-assign: the appended value was overwritten
    # on the very next statement, and ``.append`` fails outright when the class
    # attribute is not a list (presumably it can be None on some transformers
    # releases -- TODO confirm). Assignment is also idempotent across calls.
    GotOcr2ForConditionalGeneration._no_split_modules = ['GotOcr2VisionLayer']
    model, processor = get_model_tokenizer_multimodal(model_dir, *args, **kwargs)
    return model, processor

Expand All @@ -49,7 +49,7 @@ def get_model_tokenizer_got_ocr2_hf(model_dir, *args, **kwargs):
],
TemplateType.got_ocr2_hf,
get_model_tokenizer_got_ocr2_hf,
model_arch=ModelArch.got_ocr2_hf,
model_arch=ModelArch.llava_hf,
architectures=['GOTQwenForCausalLM'],
tags=['vision']))

Expand Down
106 changes: 68 additions & 38 deletions swift/llm/model/model_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
from dataclasses import dataclass, field
from typing import List, Optional, Union

import transformers
from packaging import version

# True when the installed transformers release is at least 4.52. The
# registrations further down use this flag to pick between 'model.'-prefixed
# and un-prefixed module paths.
transformers_ge_4_52 = version.parse(transformers.__version__) >= version.parse('4.52')


class LLMModelArch:
qwen = 'qwen'
Expand Down Expand Up @@ -33,7 +38,6 @@ class MLLMModelArch:

llama3_1_omni = 'llama3_1_omni'
llama3_2_vision = 'llama3_2_vision'
llama4 = 'llama4'

llava_hf = 'llava_hf'
llava_next_video_hf = 'llava_next_video_hf'
Expand All @@ -59,14 +63,12 @@ class MLLMModelArch:
idefics3 = 'idefics3'

got_ocr2 = 'got_ocr2'
got_ocr2_hf = 'got_ocr2_hf'

ovis1_6 = 'ovis1_6'
molmo = 'molmo'
emu3_chat = 'emu3_chat'
megrez_omni = 'megrez_omni'
valley = 'valley'
gemma3_vision = 'gemma3_vision'
mistral_2503 = 'mistral_2503'


Expand Down Expand Up @@ -308,13 +310,22 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
lm_head='lm_head',
))

register_model_arch(
MultiModelKeys(
MLLMModelArch.llava_hf,
language_model='language_model',
aligner='multi_modal_projector',
vision_tower='vision_tower',
))
# llava_hf: with transformers >= 4.52 every sub-module path carries a 'model.'
# prefix; older releases use the bare attribute names.
register_model_arch(
    MultiModelKeys(
        MLLMModelArch.llava_hf,
        language_model='model.language_model' if transformers_ge_4_52 else 'language_model',
        aligner='model.multi_modal_projector' if transformers_ge_4_52 else 'multi_modal_projector',
        vision_tower='model.vision_tower' if transformers_ge_4_52 else 'vision_tower',
    ))

register_model_arch(
MultiModelKeys(
Expand All @@ -324,12 +335,20 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
vision_tower='model.vision_tower',
))

register_model_arch(
MultiModelKeys(
MLLMModelArch.llava_next_video_hf,
language_model='language_model',
aligner=['multi_modal_projector'],
vision_tower='vision_tower'))
# llava_next_video_hf: same version gate as llava_hf -- 'model.'-prefixed paths
# for transformers >= 4.52, bare names otherwise. The aligner stays a
# single-element list in both cases.
register_model_arch(
    MultiModelKeys(
        MLLMModelArch.llava_next_video_hf,
        language_model='model.language_model' if transformers_ge_4_52 else 'language_model',
        aligner=['model.multi_modal_projector'] if transformers_ge_4_52 else ['multi_modal_projector'],
        vision_tower='model.vision_tower' if transformers_ge_4_52 else 'vision_tower'))

register_model_arch(
MultiModelKeys(
Expand Down Expand Up @@ -459,13 +478,23 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
vision_tower='audio_tower',
))

register_model_arch(
MultiModelKeys(
MLLMModelArch.qwen2_vl,
language_model='model',
aligner='visual.merger',
vision_tower='visual',
))
# qwen2_vl: for transformers >= 4.52 the language model is addressed as
# 'model.language_model' and the vision parts as 'model.visual[...]'; older
# releases use 'model' and 'visual[...]' directly.
register_model_arch(
    MultiModelKeys(
        MLLMModelArch.qwen2_vl,
        language_model='model.language_model' if transformers_ge_4_52 else 'model',
        aligner='model.visual.merger' if transformers_ge_4_52 else 'visual.merger',
        vision_tower='model.visual' if transformers_ge_4_52 else 'visual',
    ))

register_model_arch(
MultiModelKeys(
MLLMModelArch.qwen2_5_omni,
Expand Down Expand Up @@ -507,13 +536,22 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
vision_tower='model.vision_tower_high',
))

register_model_arch(
MultiModelKeys(
MLLMModelArch.llama3_2_vision,
language_model='language_model',
aligner='multi_modal_projector',
vision_tower='vision_model',
))
# llama3_2_vision: 'model.'-prefixed sub-module paths for transformers >= 4.52,
# bare attribute names for earlier releases.
register_model_arch(
    MultiModelKeys(
        MLLMModelArch.llama3_2_vision,
        language_model='model.language_model' if transformers_ge_4_52 else 'language_model',
        aligner='model.multi_modal_projector' if transformers_ge_4_52 else 'multi_modal_projector',
        vision_tower='model.vision_model' if transformers_ge_4_52 else 'vision_model',
    ))

register_model_arch(MultiModelKeys(
MLLMModelArch.ovis1_6,
Expand Down Expand Up @@ -547,14 +585,6 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
vision_tower=['model.vision_tower', 'model.qwen2vl_vision_tower'],
))

register_model_arch(
MultiModelKeys(
MLLMModelArch.gemma3_vision,
language_model='language_model',
aligner='multi_modal_projector',
vision_tower='vision_tower',
))


def get_model_arch(arch_name: Optional[str]) -> Optional[MultiModelKeys]:
    """Return the entry stored under ``arch_name`` in ``MODEL_ARCH_MAPPING``.

    The lookup uses ``.get``, so a name with no entry does not raise; the
    result is then whatever ``.get`` falls back to (``None`` for a standard
    mapping).
    """
    registered = MODEL_ARCH_MAPPING.get(arch_name)
    return registered
4 changes: 2 additions & 2 deletions swift/llm/model/patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from transformers.modeling_outputs import SequenceClassifierOutputWithPast

from swift.llm import deep_getattr, to_device, to_float_dtype
from swift.utils import get_dist_setting, get_logger, is_mp_ddp, safe_ddp_context, use_torchacc
from swift.utils import get_dist_setting, get_logger, is_mp, is_mp_ddp, safe_ddp_context, use_torchacc
from swift.utils.torch_utils import _get_max_memory, _sync_max_memory, get_device_count
from .utils import HfConfigFactory, get_llm_model

Expand Down Expand Up @@ -349,7 +349,7 @@ def new_get_cached_module_file(pretrained_model_name_or_path, *args, **kwargs):

@contextmanager
def patch_tp_plan(load_model: bool):
if not load_model or not is_mp_ddp() or version.parse(
if not load_model or not is_mp() or version.parse(
transformers.__version__) < version.parse('4.50') or 'WORLD_SIZE' not in os.environ:
yield
return
Expand Down
Loading