diff --git a/swift/llm/sft.py b/swift/llm/sft.py
index 8e398f7b6b..edb529a87c 100644
--- a/swift/llm/sft.py
+++ b/swift/llm/sft.py
@@ -22,8 +22,8 @@
 from .accelerator import ta_accelerate
 from .tuner import prepare_model
 from .utils import (TEMPLATE_MAPPING, LazyLLMDataset, PtArguments, RLHFArguments, SftArguments, Template, dataset_map,
-                    get_dataset, get_model_tokenizer, get_template, get_time_info, print_example, set_generation_config,
-                    sort_by_max_length, stat_dataset)
+                    dynamic_vit_gradient_checkpointing, get_dataset, get_model_tokenizer, get_template, get_time_info,
+                    print_example, set_generation_config, sort_by_max_length, stat_dataset)
 
 logger = get_logger()
 
@@ -239,6 +239,8 @@ def prepare_model_template_train(args, msg: Optional[Dict[str, Any]] = None):
 
     model.label_names = label_names
     model.return_loss = return_loss
+    if args.is_multimodal and args.gradient_checkpointing and args.vit_use_gc:
+        dynamic_vit_gradient_checkpointing(model, args.model_type)
 
     # Preparing LoRA
     model, callbacks = prepare_model(model, args)
diff --git a/swift/llm/utils/__init__.py b/swift/llm/utils/__init__.py
index 6c7a42e307..3c3c66ac40 100644
--- a/swift/llm/utils/__init__.py
+++ b/swift/llm/utils/__init__.py
@@ -22,11 +22,12 @@
                        ModelList, UsageInfo, XRequestConfig, random_uuid)
 from .template import (DEFAULT_SYSTEM, TEMPLATE_MAPPING, History, KTOTemplateMixin, Prompt, RLHFTemplateMixin,
                        StopWords, Template, TemplateType, get_env_args, get_template, register_template)
-from .utils import (LazyLLMDataset, LLMDataset, dataset_map, download_dataset, find_all_linears, find_embedding,
-                    find_ln, get_max_model_len, get_time_info, history_to_messages, inference, inference_stream,
-                    is_lmdeploy_available, is_megatron_available, is_quant_model, is_vllm_available,
-                    limit_history_length, messages_join_observation, messages_to_history, print_example,
-                    safe_tokenizer_decode, set_generation_config, sort_by_max_length, stat_dataset, to_device)
+from .utils import (LazyLLMDataset, LLMDataset, dataset_map, download_dataset, dynamic_vit_gradient_checkpointing,
+                    find_all_linears, find_embedding, find_ln, get_max_model_len, get_time_info, history_to_messages,
+                    inference, inference_stream, is_lmdeploy_available, is_megatron_available, is_quant_model,
+                    is_vllm_available, limit_history_length, messages_join_observation, messages_to_history,
+                    print_example, safe_tokenizer_decode, set_generation_config, sort_by_max_length, stat_dataset,
+                    to_device)
 
 logger = get_logger()
 
diff --git a/swift/llm/utils/argument.py b/swift/llm/utils/argument.py
index 2655f54a3f..90feb3b1cd 100644
--- a/swift/llm/utils/argument.py
+++ b/swift/llm/utils/argument.py
@@ -780,6 +780,7 @@ class SftArguments(ArgumentsBase):
     use_liger: bool = False
 
     gradient_checkpointing: Optional[bool] = None
+    vit_use_gc: bool = True  # vit use gradient_checkpointing
     # e.g. 'default-zero3', 'default-zero2', 'ds_config/zero2.json', 'zero2-offload', 'zero3-offload'
     deepspeed: Optional[str] = None
     batch_size: int = 1
diff --git a/swift/llm/utils/utils.py b/swift/llm/utils/utils.py
index db796ac16d..3b7ffd329b 100644
--- a/swift/llm/utils/utils.py
+++ b/swift/llm/utils/utils.py
@@ -10,6 +10,7 @@
 from queue import Empty, Queue
 from tempfile import TemporaryDirectory
 from threading import Thread
+from types import MethodType
 from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, Set, Tuple, Union
 
 import accelerate
@@ -18,6 +19,8 @@
 import requests
 import torch
 import torch.distributed as dist
+import torch.nn as nn
+import torch.utils.checkpoint
 import transformers
 from datasets import Dataset as HfDataset
 from datasets import IterableDataset as HfIterableDataset
@@ -421,6 +424,57 @@ def find_ln(model: Module) -> List[str]:
     return list(module_names)
 
 
+def _find_module_list(vision_tower) -> Optional[nn.ModuleList]:
+    module_lists = []
+    for m in vision_tower.modules():
+        if hasattr(m, 'gradient_checkpointing'):
+            return
+        if isinstance(m, nn.ModuleList) and len(m) >= 10:
+            module_lists.append(m)
+    if module_lists:
+        return max(module_lists, key=lambda x: len(x))
+
+
+def _add_gradient_checkpointing(module_list):
+
+    def _new_forward(self, *args, **kwargs):
+        layer_ret = torch.utils.checkpoint.checkpoint(self.__old_forward, *args, **kwargs)
+        return layer_ret
+
+    for module in module_list:
+        if hasattr(module, '_old_forward'):  # device_map
+            __old_forward = module._old_forward
+            module._old_forward = MethodType(_new_forward, module)
+        else:
+            __old_forward = module.forward
+            module.forward = MethodType(_new_forward, module)
+        module.__old_forward = __old_forward
+
+
+def deep_getattr(model, attr: str):
+    attrs = attr.split('.')
+    for a in attrs:
+        model = getattr(model, a)
+    return model
+
+
+def dynamic_vit_gradient_checkpointing(model, model_type: str) -> None:
+    from swift.utils.module_mapping import MODEL_KEYS_MAPPING
+    from .model import MODEL_MAPPING
+    model_info = MODEL_MAPPING[model_type]
+    lora_target_modules = model_info.get('lora_target_modules')
+
+    if not isinstance(lora_target_modules, str):
+        return
+    vision_tower_list = MODEL_KEYS_MAPPING[lora_target_modules].vision_tower
+    for vision_tower_name in vision_tower_list:
+        vision_tower = deep_getattr(model, vision_tower_name)
+        module_list = _find_module_list(vision_tower)
+        if module_list is None:
+            continue
+        _add_gradient_checkpointing(module_list)
+
+
 def find_embedding(model: Module) -> List[str]:
     return _find_layers(model, torch.nn.Embedding)
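For reviewers, a standalone illustration of the trick used in `_add_gradient_checkpointing`: each block's bound `forward` is saved and then shadowed via `types.MethodType` with a wrapper that routes the call through `torch.utils.checkpoint.checkpoint`, so the block's activations are recomputed during backward instead of being kept alive. This is a minimal sketch, not code from the patch; `ToyBlock`/`ToyViT` are hypothetical stand-ins for a vision tower whose blocks live in an `nn.ModuleList`, and `use_reentrant=False` is an assumption of the sketch (the patch relies on the default checkpoint behavior).

```python
# Minimal sketch of the forward-rebinding checkpoint trick (not from the patch).
from types import MethodType

import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyBlock(nn.Module):  # hypothetical stand-in for a ViT encoder block

    def __init__(self, dim: int = 32):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.linear(x))


class ToyViT(nn.Module):  # hypothetical vision tower: blocks in an nn.ModuleList

    def __init__(self, num_blocks: int = 12, dim: int = 32):
        super().__init__()
        # >= 10 entries, so the patch's _find_module_list heuristic would pick it.
        self.blocks = nn.ModuleList([ToyBlock(dim) for _ in range(num_blocks)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


def add_gradient_checkpointing(module_list: nn.ModuleList) -> None:

    def _new_forward(self, *args, **kwargs):
        # Recompute this block's activations in backward instead of storing them.
        # use_reentrant=False is an assumption here; the patch uses the default.
        return torch.utils.checkpoint.checkpoint(self.__old_forward, *args, use_reentrant=False, **kwargs)

    for module in module_list:
        # Save the original bound forward, then shadow it on the instance.
        module.__old_forward = module.forward
        module.forward = MethodType(_new_forward, module)


model = ToyViT()
add_gradient_checkpointing(model.blocks)
x = torch.randn(4, 32, requires_grad=True)
model(x).sum().backward()  # each block's forward now runs under checkpointing
```

With the patch applied, this wrapping is triggered automatically in `prepare_model_template_train` for multimodal models whenever gradient checkpointing is enabled; given the new `vit_use_gc: bool = True` field, passing `--vit_use_gc false` to `swift sft` should opt out.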