diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index f7980a2bce25..25619ca55b3f 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -63,7 +63,6 @@ ) from .utils.chat_parsing_utils import recursive_parse from .utils.chat_template_utils import render_jinja_template -from .utils.import_utils import PROTOBUF_IMPORT_ERROR if TYPE_CHECKING: @@ -76,8 +75,7 @@ def import_protobuf_decode_error(error_message=""): from google.protobuf.message import DecodeError return DecodeError - else: - raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message)) + return () def flatten(arr: list): diff --git a/tests/tokenization/test_tokenization_utils.py b/tests/tokenization/test_tokenization_utils.py index da02adcc484d..1b91903efec7 100644 --- a/tests/tokenization/test_tokenization_utils.py +++ b/tests/tokenization/test_tokenization_utils.py @@ -352,3 +352,24 @@ def test_special_tokens_overwrite(self): new_tokenizer.decode(new_tokenizer.encode(text_with_nonspecial_tokens), skip_special_tokens=True) == text_with_nonspecial_tokens ) + + def test_import_protobuf_decode_error_without_protobuf(self): + from unittest.mock import patch + + from transformers.tokenization_utils_base import import_protobuf_decode_error + + with patch("transformers.tokenization_utils_base.is_protobuf_available", return_value=False): + result = import_protobuf_decode_error() + self.assertEqual(result, ()) + + def test_import_protobuf_decode_error_does_not_mask_exceptions(self): + from unittest.mock import patch + + from transformers.tokenization_utils_base import import_protobuf_decode_error + + with patch("transformers.tokenization_utils_base.is_protobuf_available", return_value=False): + with self.assertRaises(ValueError): + try: + raise ValueError("real error") + except import_protobuf_decode_error(): + pass