diff --git a/swift/llm/sft.py b/swift/llm/sft.py index 603971e5fe..725997b82b 100644 --- a/swift/llm/sft.py +++ b/swift/llm/sft.py @@ -491,6 +491,7 @@ def trainer_train( json.dump(check_json_format(args_obj.__dict__), f, ensure_ascii=False, indent=2) logging_path = os.path.join(args.output_dir, 'logging.jsonl') logger.info(f'The logging file will be saved in: {logging_path}') + trainer.model_accepts_loss_kwargs = True # fix transformers>=4.46.2 with template.training_context(): trainer.train(training_args.resume_from_checkpoint) last_model_checkpoint = getattr(trainer.state, 'last_model_checkpoint', None)