diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index 702e05627cb6..fc3b986aedd1 100755 --- a/src/transformers/integrations/integration_utils.py +++ b/src/transformers/integrations/integration_utils.py @@ -41,7 +41,7 @@ if os.getenv("WANDB_MODE") == "offline": print("⚙️ Running in WANDB offline mode") -from .. import PreTrainedModel, TFPreTrainedModel, TrainingArguments +from .. import PreTrainedModel, TrainingArguments from .. import __version__ as version from ..utils import ( PushToHubMixin, @@ -56,6 +56,9 @@ logger = logging.get_logger(__name__) +if is_tf_available(): + from .. import TFPreTrainedModel + if is_torch_available(): import torch import torch.distributed as dist