diff --git a/examples/pytorch/translation/run_translation.py b/examples/pytorch/translation/run_translation.py index e1d3c4ca387a..74e244994190 100755 --- a/examples/pytorch/translation/run_translation.py +++ b/examples/pytorch/translation/run_translation.py @@ -444,6 +444,9 @@ def main(): ) model.config.forced_bos_token_id = forced_bos_token_id + if hasattr(model, "generation_config") and model.generation_config is not None: + model.generation_config.forced_bos_token_id = forced_bos_token_id + # Get the language codes for input/target. source_lang = data_args.source_lang.split("_")[0] target_lang = data_args.target_lang.split("_")[0]