From 212df3d501c5f50ea070de6ad71f2875d01fd0a4 Mon Sep 17 00:00:00 2001 From: "huangjintao.hjt" Date: Fri, 8 Nov 2024 11:30:00 +0800 Subject: [PATCH 1/2] fix qwen_vl dpo --- swift/llm/utils/model.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/swift/llm/utils/model.py b/swift/llm/utils/model.py index 52d1f7d281..7c1c560cdd 100644 --- a/swift/llm/utils/model.py +++ b/swift/llm/utils/model.py @@ -5987,8 +5987,9 @@ def _get_cast_dtype(self) -> torch.dtype: tokenizer_cls: Type[PreTrainedTokenizerBase] = get_class_from_dynamic_module(class_ref, model_dir) tokenizer_cls._auto_class = 'AutoTokenizer' tokenizer_cls.IMAGE_ST = () # fix no attr `self.IMAGE_ST` bug - tokenizer_cls._old_decode = tokenizer_cls._decode - tokenizer_cls._decode = _qwen_vl_audio_decode + if not hasattr(tokenizer_cls, '_old_decode'): + tokenizer_cls._old_decode = tokenizer_cls._decode + tokenizer_cls._decode = _qwen_vl_audio_decode # fix device_map is 4 n_gpu = torch.cuda.device_count() local_world_size = get_dist_setting()[3] @@ -5999,8 +6000,8 @@ def _get_cast_dtype(self) -> torch.dtype: kwargs['tokenizer'] = tokenizer_cls.from_pretrained(model_dir, trust_remote_code=True) model, tokenizer = get_qwen_function(model_dir, torch_dtype, model_kwargs, load_model, **kwargs) - device_type = next(model.parameters()).device.type if model is not None: + device_type = next(model.parameters()).device.type fix_qwen_inplace_bug(model) # fix device_map is 4 if n_gpu // local_world_size >= 4: @@ -6040,8 +6041,9 @@ def get_model_tokenizer_qwen_audio(model_dir: str, tokenizer_cls: Type[PreTrainedTokenizerBase] = get_class_from_dynamic_module(class_ref, model_dir) tokenizer_cls._auto_class = 'AutoTokenizer' tokenizer_cls.AUDIO_ST = () # fix no attr `self.AUDIO_ST` bug - tokenizer_cls._old_decode = tokenizer_cls._decode - tokenizer_cls._decode = _qwen_vl_audio_decode + if not hasattr(tokenizer_cls, '_old_decode'): + tokenizer_cls._old_decode = tokenizer_cls._decode + tokenizer_cls._decode = _qwen_vl_audio_decode kwargs['tokenizer'] = tokenizer_cls.from_pretrained(model_dir, trust_remote_code=True) model, tokenizer = get_qwen_function(model_dir, torch_dtype, model_kwargs, load_model, **kwargs) if model is not None: From 7fe723b7c434578b1514406d79e8704065d9945d Mon Sep 17 00:00:00 2001 From: "huangjintao.hjt" Date: Fri, 8 Nov 2024 13:19:25 +0800 Subject: [PATCH 2/2] compat transformers==4.46.2 loss --- swift/llm/sft.py | 1 + 1 file changed, 1 insertion(+) diff --git a/swift/llm/sft.py b/swift/llm/sft.py index 603971e5fe..725997b82b 100644 --- a/swift/llm/sft.py +++ b/swift/llm/sft.py @@ -491,6 +491,7 @@ def trainer_train( json.dump(check_json_format(args_obj.__dict__), f, ensure_ascii=False, indent=2) logging_path = os.path.join(args.output_dir, 'logging.jsonl') logger.info(f'The logging file will be saved in: {logging_path}') + trainer.model_accepts_loss_kwargs = True # fix transformers>=4.46.2 with template.training_context(): trainer.train(training_args.resume_from_checkpoint) last_model_checkpoint = getattr(trainer.state, 'last_model_checkpoint', None)