diff --git a/swift/llm/model/model/gemma.py b/swift/llm/model/model/gemma.py
index 2c9291d077..9ba1060a3d 100644
--- a/swift/llm/model/model/gemma.py
+++ b/swift/llm/model/model/gemma.py
@@ -160,6 +160,6 @@ def get_model_tokenizer_gemma3_vision(model_dir: str,
         TemplateType.gemma3_vision,
         get_model_tokenizer_gemma3_vision,
         architectures=['Gemma3ForConditionalGeneration'],
-        model_arch=ModelArch.gemma3_vision,
+        model_arch=ModelArch.llava_hf,
         requires=['transformers>=4.49'],
     ))
diff --git a/swift/llm/model/model/stepfun.py b/swift/llm/model/model/stepfun.py
index 34d9db7228..7a7073836d 100644
--- a/swift/llm/model/model/stepfun.py
+++ b/swift/llm/model/model/stepfun.py
@@ -35,7 +35,7 @@ def get_model_tokenizer_got_ocr2(*args, **kwargs):
 
 
 def get_model_tokenizer_got_ocr2_hf(model_dir, *args, **kwargs):
     from transformers.models.got_ocr2 import GotOcr2ForConditionalGeneration
-    GotOcr2ForConditionalGeneration._no_split_modules.append('GotOcr2VisionLayer')
+    GotOcr2ForConditionalGeneration._no_split_modules = ['GotOcr2VisionLayer']
     model, processor = get_model_tokenizer_multimodal(model_dir, *args, **kwargs)
     return model, processor
@@ -49,7 +49,7 @@ def get_model_tokenizer_got_ocr2_hf(model_dir, *args, **kwargs):
         ],
         TemplateType.got_ocr2_hf,
         get_model_tokenizer_got_ocr2_hf,
-        model_arch=ModelArch.got_ocr2_hf,
+        model_arch=ModelArch.llava_hf,
         architectures=['GOTQwenForCausalLM'],
         tags=['vision']))
 
diff --git a/swift/llm/model/model_arch.py b/swift/llm/model/model_arch.py
index b32b47c0c6..de54a90b57 100644
--- a/swift/llm/model/model_arch.py
+++ b/swift/llm/model/model_arch.py
@@ -2,6 +2,11 @@
 from dataclasses import dataclass, field
 from typing import List, Optional, Union
 
+import transformers
+from packaging import version
+
+transformers_ge_4_52 = version.parse(transformers.__version__) >= version.parse('4.52')
+
 
 class LLMModelArch:
     qwen = 'qwen'
@@ -33,7 +38,6 @@ class MLLMModelArch:
 
     llama3_1_omni = 'llama3_1_omni'
     llama3_2_vision = 'llama3_2_vision'
-    llama4 = 'llama4'
 
     llava_hf = 'llava_hf'
     llava_next_video_hf = 'llava_next_video_hf'
@@ -59,14 +63,12 @@ class MLLMModelArch:
     idefics3 = 'idefics3'
 
     got_ocr2 = 'got_ocr2'
-    got_ocr2_hf = 'got_ocr2_hf'
 
     ovis1_6 = 'ovis1_6'
     molmo = 'molmo'
     emu3_chat = 'emu3_chat'
     megrez_omni = 'megrez_omni'
     valley = 'valley'
-    gemma3_vision = 'gemma3_vision'
 
     mistral_2503 = 'mistral_2503'
 
@@ -308,13 +310,22 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
         lm_head='lm_head',
     ))
 
-register_model_arch(
-    MultiModelKeys(
-        MLLMModelArch.llava_hf,
-        language_model='language_model',
-        aligner='multi_modal_projector',
-        vision_tower='vision_tower',
-    ))
+if transformers_ge_4_52:
+    register_model_arch(
+        MultiModelKeys(
+            MLLMModelArch.llava_hf,
+            language_model='model.language_model',
+            aligner='model.multi_modal_projector',
+            vision_tower='model.vision_tower',
+        ))
+else:
+    register_model_arch(
+        MultiModelKeys(
+            MLLMModelArch.llava_hf,
+            language_model='language_model',
+            aligner='multi_modal_projector',
+            vision_tower='vision_tower',
+        ))
 
 register_model_arch(
     MultiModelKeys(
@@ -324,12 +335,20 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
         vision_tower='model.vision_tower',
     ))
 
-register_model_arch(
-    MultiModelKeys(
-        MLLMModelArch.llava_next_video_hf,
-        language_model='language_model',
-        aligner=['multi_modal_projector'],
-        vision_tower='vision_tower'))
+if transformers_ge_4_52:
+    register_model_arch(
+        MultiModelKeys(
+            MLLMModelArch.llava_next_video_hf,
+            language_model='model.language_model',
+            aligner=['model.multi_modal_projector'],
+            vision_tower='model.vision_tower'))
+else:
+    register_model_arch(
+        MultiModelKeys(
+            MLLMModelArch.llava_next_video_hf,
+            language_model='language_model',
+            aligner=['multi_modal_projector'],
+            vision_tower='vision_tower'))
 
 register_model_arch(
     MultiModelKeys(
@@ -459,13 +478,23 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
         vision_tower='audio_tower',
     ))
 
-register_model_arch(
-    MultiModelKeys(
-        MLLMModelArch.qwen2_vl,
-        language_model='model',
-        aligner='visual.merger',
-        vision_tower='visual',
-    ))
+if transformers_ge_4_52:
+    register_model_arch(
+        MultiModelKeys(
+            MLLMModelArch.qwen2_vl,
+            language_model='model.language_model',
+            aligner='model.visual.merger',
+            vision_tower='model.visual',
+        ))
+else:
+    register_model_arch(
+        MultiModelKeys(
+            MLLMModelArch.qwen2_vl,
+            language_model='model',
+            aligner='visual.merger',
+            vision_tower='visual',
+        ))
+
 register_model_arch(
     MultiModelKeys(
         MLLMModelArch.qwen2_5_omni,
@@ -507,13 +536,22 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
         vision_tower='model.vision_tower_high',
     ))
 
-register_model_arch(
-    MultiModelKeys(
-        MLLMModelArch.llama3_2_vision,
-        language_model='language_model',
-        aligner='multi_modal_projector',
-        vision_tower='vision_model',
-    ))
+if transformers_ge_4_52:
+    register_model_arch(
+        MultiModelKeys(
+            MLLMModelArch.llama3_2_vision,
+            language_model='model.language_model',
+            aligner='model.multi_modal_projector',
+            vision_tower='model.vision_model',
+        ))
+else:
+    register_model_arch(
+        MultiModelKeys(
+            MLLMModelArch.llama3_2_vision,
+            language_model='language_model',
+            aligner='multi_modal_projector',
+            vision_tower='vision_model',
+        ))
 
 register_model_arch(MultiModelKeys(
     MLLMModelArch.ovis1_6,
@@ -547,14 +585,6 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
         vision_tower=['model.vision_tower', 'model.qwen2vl_vision_tower'],
     ))
 
-register_model_arch(
-    MultiModelKeys(
-        MLLMModelArch.gemma3_vision,
-        language_model='language_model',
-        aligner='multi_modal_projector',
-        vision_tower='vision_tower',
-    ))
-
 
 def get_model_arch(arch_name: Optional[str]) -> Optional[MultiModelKeys]:
     return MODEL_ARCH_MAPPING.get(arch_name)
diff --git a/swift/llm/model/patcher.py b/swift/llm/model/patcher.py
index 9d6c402538..a4bac6c005 100644
--- a/swift/llm/model/patcher.py
+++ b/swift/llm/model/patcher.py
@@ -17,7 +17,7 @@ from transformers.modeling_outputs import SequenceClassifierOutputWithPast
 
 from swift.llm import deep_getattr, to_device, to_float_dtype
-from swift.utils import get_dist_setting, get_logger, is_mp_ddp, safe_ddp_context, use_torchacc
+from swift.utils import get_dist_setting, get_logger, is_mp, is_mp_ddp, safe_ddp_context, use_torchacc
 from swift.utils.torch_utils import _get_max_memory, _sync_max_memory, get_device_count
 from .utils import HfConfigFactory, get_llm_model
 
 
@@ -349,7 +349,7 @@ def new_get_cached_module_file(pretrained_model_name_or_path, *args, **kwargs):
 
 @contextmanager
 def patch_tp_plan(load_model: bool):
-    if not load_model or not is_mp_ddp() or version.parse(
+    if not load_model or not is_mp() or version.parse(
             transformers.__version__) < version.parse('4.50') or 'WORLD_SIZE' not in os.environ:
         yield
         return
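
For context, a minimal sketch (not part of the patch) of the layout change this diff gates on: transformers >= 4.52 nests the language model, projector, and vision tower of HF multimodal wrappers under an extra `model.` prefix, so the dotted key paths registered in model_arch.py must be chosen per version. The `deep_getattr` below is a hypothetical stand-in for the helper patcher.py imports from `swift.llm`.

import transformers
from packaging import version

# Same gate the patch adds to model_arch.py.
transformers_ge_4_52 = version.parse(transformers.__version__) >= version.parse('4.52')


def deep_getattr(obj, attr_path: str):
    """Hypothetical stand-in for swift.llm.deep_getattr: walk a dotted attribute path."""
    for name in attr_path.split('.'):
        obj = getattr(obj, name)
    return obj


# Example taken from the qwen2_vl registration above: the vision tower moved
# from `visual` to `model.visual` in transformers >= 4.52.
vision_tower_key = 'model.visual' if transformers_ge_4_52 else 'visual'
# vision_tower = deep_getattr(hf_model, vision_tower_key)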