From ee9875443d1b7e6f72a09e03db1bf9696bdd3515 Mon Sep 17 00:00:00 2001
From: "huangjintao.hjt"
Date: Thu, 19 Sep 2024 19:54:23 +0800
Subject: [PATCH 1/6] dynamic gradient_checkpointing

---
 swift/llm/rlhf.py           |  4 ++--
 swift/llm/sft.py            | 10 ++++----
 swift/llm/utils/__init__.py | 11 +++++----
 swift/llm/utils/argument.py |  2 +-
 swift/llm/utils/utils.py    | 47 ++++++++++++++++++++++++++++++++++++-
 5 files changed, 61 insertions(+), 13 deletions(-)

diff --git a/swift/llm/rlhf.py b/swift/llm/rlhf.py
index ccf65520de..d4662f8f09 100644
--- a/swift/llm/rlhf.py
+++ b/swift/llm/rlhf.py
@@ -3,7 +3,7 @@
 
 from swift.trainers import TrainerFactory
 from swift.utils import get_logger, get_main, seed_everything
-from .sft import prepare_dataset, prepare_train_model_template, trainer_train
+from .sft import prepare_dataset, prepare_model_template_train, trainer_train
 from .utils import TEMPLATE_MAPPING, RLHFArguments
 
 logger = get_logger()
@@ -18,7 +18,7 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
             logger.warning(f"Please check if args.template_type: '{args.template_type}' is correct.")
 
     msg = {}
-    model, ref_model, template, callbacks = prepare_train_model_template(args)
+    model, ref_model, template, callbacks = prepare_model_template_train(args)
     with TrainerFactory.patch_template(args, template):
         train_dataset, val_dataset = prepare_dataset(args, template, msg)
diff --git a/swift/llm/sft.py b/swift/llm/sft.py
index d76de2e989..da133c9ae8 100644
--- a/swift/llm/sft.py
+++ b/swift/llm/sft.py
@@ -22,8 +22,8 @@
 from .accelerator import ta_accelerate
 from .tuner import prepare_model
 from .utils import (TEMPLATE_MAPPING, LazyLLMDataset, PtArguments, RLHFArguments, SftArguments, Template, dataset_map,
-                    get_dataset, get_model_tokenizer, get_template, get_time_info, print_example, set_generation_config,
-                    sort_by_max_length, stat_dataset)
+                    dynamic_vit_gradient_checkpointing, get_dataset, get_model_tokenizer, get_template, get_time_info,
+                    print_example, set_generation_config, sort_by_max_length, stat_dataset)
 
 logger = get_logger()
@@ -115,7 +115,7 @@ def llm_sft_megatron(args: SftArguments) -> Dict[str, Any]:
     return {}
 
 
-def prepare_train_model_template(args, msg: Optional[Dict[str, Any]] = None):
+def prepare_model_template_train(args, msg: Optional[Dict[str, Any]] = None):
     if args.gpu_memory_fraction is not None:
         for device_id in range(torch.cuda.device_count()):
@@ -239,6 +239,8 @@ def prepare_model_template_train(args, msg: Optional[Dict[str, Any]] = None):
     model.label_names = label_names
     model.return_loss = return_loss
 
+    if args.is_multimodal:
+        dynamic_vit_gradient_checkpointing(model, args.model_type)
     # Preparing LoRA
     model, callbacks = prepare_model(model, args)
@@ -501,7 +503,7 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
     if args.train_backend == 'megatron':
         return llm_sft_megatron(args)
     msg = {}
-    model, template, callbacks = prepare_train_model_template(args, msg)
+    model, template, callbacks = prepare_model_template_train(args, msg)
     train_dataset, val_dataset = prepare_dataset(args, template, msg)
     return trainer_train(args, model, template, train_dataset, val_dataset, callbacks=callbacks, msg=msg)
diff --git a/swift/llm/utils/__init__.py b/swift/llm/utils/__init__.py
index 6c7a42e307..3c3c66ac40 100644
--- a/swift/llm/utils/__init__.py
+++ b/swift/llm/utils/__init__.py
@@ -22,11 +22,12 @@
                        ModelList, UsageInfo, XRequestConfig, random_uuid)
 from .template import (DEFAULT_SYSTEM, TEMPLATE_MAPPING, History, KTOTemplateMixin, Prompt, RLHFTemplateMixin,
                        StopWords, Template, TemplateType, get_env_args, get_template, register_template)
-from .utils import (LazyLLMDataset, LLMDataset, dataset_map, download_dataset, find_all_linears, find_embedding,
-                    find_ln, get_max_model_len, get_time_info, history_to_messages, inference, inference_stream,
-                    is_lmdeploy_available, is_megatron_available, is_quant_model, is_vllm_available,
-                    limit_history_length, messages_join_observation, messages_to_history, print_example,
-                    safe_tokenizer_decode, set_generation_config, sort_by_max_length, stat_dataset, to_device)
+from .utils import (LazyLLMDataset, LLMDataset, dataset_map, download_dataset, dynamic_vit_gradient_checkpointing,
+                    find_all_linears, find_embedding, find_ln, get_max_model_len, get_time_info, history_to_messages,
+                    inference, inference_stream, is_lmdeploy_available, is_megatron_available, is_quant_model,
+                    is_vllm_available, limit_history_length, messages_join_observation, messages_to_history,
+                    print_example, safe_tokenizer_decode, set_generation_config, sort_by_max_length, stat_dataset,
+                    to_device)
 
 logger = get_logger()
diff --git a/swift/llm/utils/argument.py b/swift/llm/utils/argument.py
index 47ce25c5fc..2655f54a3f 100644
--- a/swift/llm/utils/argument.py
+++ b/swift/llm/utils/argument.py
@@ -139,7 +139,7 @@ def handle_generation_config(self: Union['SftArguments', 'InferArguments']) -> None:
         if self.temperature == 0:
             self.do_sample = False
         if self.do_sample is False and (isinstance(self, InferArguments) and self.infer_backend == 'pt'
-                                        and isinstance(self, SftArguments)):
+                                        or isinstance(self, SftArguments)):
             # fix warning
             self.temperature = 1.
             self.top_p = 1.
diff --git a/swift/llm/utils/utils.py b/swift/llm/utils/utils.py
index db796ac16d..798fe81594 100644
--- a/swift/llm/utils/utils.py
+++ b/swift/llm/utils/utils.py
@@ -6,10 +6,11 @@
 import shutil
 import time
 from copy import deepcopy
-from functools import partial, wraps
+from functools import partial, update_wrapper, wraps
 from queue import Empty, Queue
 from tempfile import TemporaryDirectory
 from threading import Thread
+from types import MethodType
 from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, Set, Tuple, Union
 
 import accelerate
@@ -18,6 +19,8 @@
 import requests
 import torch
 import torch.distributed as dist
+import torch.nn as nn
+import torch.utils.checkpoint
 import transformers
 from datasets import Dataset as HfDataset
 from datasets import IterableDataset as HfIterableDataset
@@ -421,6 +424,48 @@ def find_ln(model: Module) -> List[str]:
     return list(module_names)
 
 
+def _find_module_list(vision_tower) -> Optional[nn.ModuleList]:
+    module_lists = []
+    for m in vision_tower.modules():
+        if isinstance(m, nn.ModuleList) and len(m) >= 10:
+            module_lists.append(m)
+    if module_lists:
+        return max(module_lists, key=lambda x: len(x))
+
+
+def _add_gradient_checkpointing(module_list):
+
+    def _new_forward(self, *args, **kwargs):
+        layer_ret = torch.utils.checkpoint.checkpoint(self.__old_forward, *args, **kwargs, use_reentrant=False)
+        return layer_ret
+
+    for module in module_list:
+        if hasattr(module, '_old_forward'):  # device_map
+            __old_forward = module._old_forward
+            module._old_forward = MethodType(_new_forward, module)
+        else:
+            __old_forward = module.forward
+            module.forward = MethodType(_new_forward, module)
+        module.__old_forward = __old_forward
+
+
+def dynamic_vit_gradient_checkpointing(model, model_type: str) -> None:
+    from swift.utils.module_mapping import MODEL_KEYS_MAPPING
+    from .model import MODEL_MAPPING
+    model_info = MODEL_MAPPING[model_type]
+    lora_target_modules = model_info.get('lora_target_modules')
+
+    if not isinstance(lora_target_modules, str):
+        return
+    vision_tower_list = MODEL_KEYS_MAPPING[lora_target_modules].vision_tower
+    for vision_tower_name in vision_tower_list:
+        vision_tower = getattr(model, vision_tower_name)
+        module_list = _find_module_list(vision_tower)
+        if module_list is None:
+            continue
+        _add_gradient_checkpointing(module_list)
+
+
 def find_embedding(model: Module) -> List[str]:
     return _find_layers(model, torch.nn.Embedding)

From c01cd6f1fd1f8b4d454dcad5ec1f68c13b3b571f Mon Sep 17 00:00:00 2001
From: "huangjintao.hjt"
Date: Thu, 19 Sep 2024 21:10:01 +0800
Subject: [PATCH 2/6] update

---
 swift/llm/utils/utils.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/swift/llm/utils/utils.py b/swift/llm/utils/utils.py
index 798fe81594..888c463803 100644
--- a/swift/llm/utils/utils.py
+++ b/swift/llm/utils/utils.py
@@ -427,6 +427,8 @@ def find_ln(model: Module) -> List[str]:
 def _find_module_list(vision_tower) -> Optional[nn.ModuleList]:
     module_lists = []
     for m in vision_tower.modules():
+        if getattr(m, 'gradient_checkpointing', False):
+            return
         if isinstance(m, nn.ModuleList) and len(m) >= 10:
             module_lists.append(m)
     if module_lists:

From 0a54fc1699bc77d107ea1a9677f07195492d1c8d Mon Sep 17 00:00:00 2001
From: "huangjintao.hjt"
Date: Thu, 19 Sep 2024 21:29:05 +0800
Subject: [PATCH 3/6] update

---
 swift/llm/sft.py         | 2 +-
 swift/llm/utils/utils.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/swift/llm/sft.py b/swift/llm/sft.py
index da133c9ae8..e39c0317b9 100644
--- a/swift/llm/sft.py
+++ b/swift/llm/sft.py
@@ -239,7 +239,7 @@ def prepare_model_template_train(args, msg: Optional[Dict[str, Any]] = None):
     model.label_names = label_names
     model.return_loss = return_loss
 
-    if args.is_multimodal:
+    if args.is_multimodal and args.gradient_checkpointing:
         dynamic_vit_gradient_checkpointing(model, args.model_type)
     # Preparing LoRA
     model, callbacks = prepare_model(model, args)
diff --git a/swift/llm/utils/utils.py b/swift/llm/utils/utils.py
index 888c463803..5774616805 100644
--- a/swift/llm/utils/utils.py
+++ b/swift/llm/utils/utils.py
@@ -427,7 +427,7 @@ def find_ln(model: Module) -> List[str]:
 def _find_module_list(vision_tower) -> Optional[nn.ModuleList]:
     module_lists = []
     for m in vision_tower.modules():
-        if getattr(m, 'gradient_checkpointing', False):
+        if hasattr(m, 'gradient_checkpointing'):
             return
         if isinstance(m, nn.ModuleList) and len(m) >= 10:
             module_lists.append(m)
@@ -438,7 +438,7 @@ def _find_module_list(vision_tower) -> Optional[nn.ModuleList]:
 def _add_gradient_checkpointing(module_list):
 
     def _new_forward(self, *args, **kwargs):
-        layer_ret = torch.utils.checkpoint.checkpoint(self.__old_forward, *args, **kwargs, use_reentrant=False)
+        layer_ret = torch.utils.checkpoint.checkpoint(self.__old_forward, *args, **kwargs)
         return layer_ret
 
     for module in module_list:

From 75245fa07cc99c6af80bf53d14cb020a853e5494 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Thu, 19 Sep 2024 21:33:25 +0800
Subject: [PATCH 4/6] update

---
 swift/llm/utils/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/swift/llm/utils/utils.py b/swift/llm/utils/utils.py
index 5774616805..5f4c464cd1 100644
--- a/swift/llm/utils/utils.py
+++ b/swift/llm/utils/utils.py
@@ -6,7 +6,7 @@
 import shutil
 import time
 from copy import deepcopy
-from functools import partial, update_wrapper, wraps
+from functools import partial, wraps
 from queue import Empty, Queue
 from tempfile import TemporaryDirectory
 from threading import Thread

From 23f75f509d005ad6884454b2bed911cf7c536af9 Mon Sep 17 00:00:00 2001
From: "huangjintao.hjt"
Date: Thu, 19 Sep 2024 23:05:02 +0800
Subject: [PATCH 5/6] update

---
 swift/llm/utils/utils.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/swift/llm/utils/utils.py b/swift/llm/utils/utils.py
index 5f4c464cd1..3b7ffd329b 100644
--- a/swift/llm/utils/utils.py
+++ b/swift/llm/utils/utils.py
@@ -451,6 +451,13 @@ def _new_forward(self, *args, **kwargs):
         module.__old_forward = __old_forward
 
 
+def deep_getattr(model, attr: str):
+    attrs = attr.split('.')
+    for a in attrs:
+        model = getattr(model, a)
+    return model
+
+
 def dynamic_vit_gradient_checkpointing(model, model_type: str) -> None:
     from swift.utils.module_mapping import MODEL_KEYS_MAPPING
     from .model import MODEL_MAPPING
@@ -461,7 +468,7 @@ def dynamic_vit_gradient_checkpointing(model, model_type: str) -> None:
         return
     vision_tower_list = MODEL_KEYS_MAPPING[lora_target_modules].vision_tower
     for vision_tower_name in vision_tower_list:
-        vision_tower = getattr(model, vision_tower_name)
+        vision_tower = deep_getattr(model, vision_tower_name)
         module_list = _find_module_list(vision_tower)
         if module_list is None:
             continue

From e4d680a3747cb6746199be70bb5869778e41747d Mon Sep 17 00:00:00 2001
From: "huangjintao.hjt"
Date: Fri, 20 Sep 2024 10:09:29 +0800
Subject: [PATCH 6/6] update

---
 swift/llm/sft.py            | 2 +-
 swift/llm/utils/argument.py | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/swift/llm/sft.py b/swift/llm/sft.py
index e39c0317b9..edb529a87c 100644
--- a/swift/llm/sft.py
+++ b/swift/llm/sft.py
@@ -239,7 +239,7 @@ def prepare_model_template_train(args, msg: Optional[Dict[str, Any]] = None):
     model.label_names = label_names
     model.return_loss = return_loss
 
-    if args.is_multimodal and args.gradient_checkpointing:
+    if args.is_multimodal and args.gradient_checkpointing and args.vit_use_gc:
         dynamic_vit_gradient_checkpointing(model, args.model_type)
     # Preparing LoRA
     model, callbacks = prepare_model(model, args)
diff --git a/swift/llm/utils/argument.py b/swift/llm/utils/argument.py
index 2655f54a3f..90feb3b1cd 100644
--- a/swift/llm/utils/argument.py
+++ b/swift/llm/utils/argument.py
@@ -780,6 +780,7 @@ class SftArguments(ArgumentsBase):
     use_liger: bool = False
 
     gradient_checkpointing: Optional[bool] = None
+    vit_use_gc: bool = True  # vit use gradient_checkpointing
     # e.g. 'default-zero3', 'default-zero2', 'ds_config/zero2.json', 'zero2-offload', 'zero3-offload'
     deepspeed: Optional[str] = None
     batch_size: int = 1
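
A minimal standalone sketch of the technique these patches apply: wrap each block of a vision tower's nn.ModuleList so its forward runs under torch.utils.checkpoint, trading compute for activation memory. Assumptions: a toy ModuleList stands in for a real ViT, and use_reentrant=False is kept (PATCH 3/6 drops it) so the sketch also works when no input requires grad; swift itself locates the tower via MODEL_KEYS_MAPPING and additionally handles accelerate's _old_forward hook.

    # Sketch under the stated assumptions: recompute each block's activations
    # during backward instead of storing them in the forward pass.
    from types import MethodType

    import torch
    import torch.nn as nn
    import torch.utils.checkpoint


    def add_gradient_checkpointing(module_list: nn.ModuleList) -> None:

        def _new_forward(self, *args, **kwargs):
            # self.__old_forward is the original bound forward saved below.
            return torch.utils.checkpoint.checkpoint(self.__old_forward, *args, **kwargs, use_reentrant=False)

        for module in module_list:
            module.__old_forward = module.forward
            module.forward = MethodType(_new_forward, module)


    # Toy "vision tower": 12 blocks, mirroring the len(m) >= 10 heuristic above.
    blocks = nn.ModuleList([nn.Sequential(nn.Linear(16, 16), nn.GELU()) for _ in range(12)])
    add_gradient_checkpointing(blocks)

    x = torch.randn(2, 16, requires_grad=True)
    for block in blocks:
        x = block(x)
    x.sum().backward()  # each block's forward is re-run here to rebuild activations

Design note: the instance attribute module.forward shadows the class method, so nn.Module.__call__ (and therefore hooks) still work unchanged; with use_reentrant=False, parameter gradients flow even for a frozen tower fed raw pixel values, whereas the reentrant variant would warn and return None grads.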