diff --git a/machine/translation/huggingface/hugging_face_nmt_model_trainer.py b/machine/translation/huggingface/hugging_face_nmt_model_trainer.py
index 1192243e..2d8641c7 100644
--- a/machine/translation/huggingface/hugging_face_nmt_model_trainer.py
+++ b/machine/translation/huggingface/hugging_face_nmt_model_trainer.py
@@ -154,6 +154,9 @@ def train(
         tgt_lang = self._tgt_lang
         if tgt_lang is None:
             tgt_lang = "tgt"
+        if src_lang == tgt_lang:
+            src_lang += "_src"
+            tgt_lang += "_trg"
 
         if isinstance(self._corpus, Dataset):
             train_dataset = self._corpus
@@ -206,13 +209,12 @@ def add_tokens(tokenizer: Any, missing_tokens: List[str]) -> Any:
                 use_fast=True,
             )
             # using unofficially supported behavior to set the normalizer
+            lang_codes = []
             tokenizer.backend_tokenizer.normalizer = norm_tok.backend_tokenizer.normalizer  # type: ignore
-            if self._add_unk_src_tokens and self._add_unk_tgt_tokens:
-                lang_codes = [src_lang, tgt_lang]
-            elif self._add_unk_src_tokens:
-                lang_codes = [src_lang]
-            else:
-                lang_codes = [tgt_lang]
+            if self._add_unk_src_tokens:
+                lang_codes.append(src_lang)
+            if self._add_unk_tgt_tokens:
+                lang_codes.append(tgt_lang)
             missing_tokens = find_missing_characters(tokenizer, train_dataset, lang_codes)
             if missing_tokens:
                 tokenizer = add_tokens(tokenizer, missing_tokens)
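The two hunks make independent behavior changes: the first suffixes the language codes with `_src`/`_trg` when source and target resolve to the same code, and the second replaces the `if`/`elif`/`else` with independent appends, which also leaves `lang_codes` empty when neither flag is set instead of defaulting to the target language. Below is a minimal standalone sketch (not part of the patch) distilling that combined logic; the function name `resolve_lang_codes` and its plain-argument signature are hypothetical, standing in for the trainer's `self._src_lang`, `self._tgt_lang`, `self._add_unk_src_tokens`, and `self._add_unk_tgt_tokens` attributes.

```python
from typing import List, Optional


def resolve_lang_codes(
    src_lang: Optional[str],
    tgt_lang: Optional[str],
    add_unk_src_tokens: bool,
    add_unk_tgt_tokens: bool,
) -> List[str]:
    # Fall back to placeholder codes when no language is configured,
    # mirroring the existing "src"/"tgt" defaults in train().
    if src_lang is None:
        src_lang = "src"
    if tgt_lang is None:
        tgt_lang = "tgt"
    # First hunk: disambiguate identical codes so the two sides stay
    # distinct (e.g. "en"/"en" becomes "en_src"/"en_trg").
    if src_lang == tgt_lang:
        src_lang += "_src"
        tgt_lang += "_trg"
    # Second hunk: each flag is considered independently, so the result
    # can be both codes, one code, or empty.
    lang_codes: List[str] = []
    if add_unk_src_tokens:
        lang_codes.append(src_lang)
    if add_unk_tgt_tokens:
        lang_codes.append(tgt_lang)
    return lang_codes


assert resolve_lang_codes("en", "en", True, True) == ["en_src", "en_trg"]
assert resolve_lang_codes("en", "fr", False, True) == ["fr"]
assert resolve_lang_codes(None, None, True, False) == ["src"]
```

The asserts show the cases the old branching handled awkwardly: identical codes no longer collide, and a source-only or target-only request yields exactly that one code.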