diff --git a/tests/trainer/test_trainer_callback.py b/tests/trainer/test_trainer_callback.py index 4676c0da06e4..76778b01e4b6 100644 --- a/tests/trainer/test_trainer_callback.py +++ b/tests/trainer/test_trainer_callback.py @@ -27,7 +27,8 @@ import shutil import tempfile import unittest -from unittest.mock import patch +from types import SimpleNamespace +from unittest.mock import Mock, patch from transformers import ( DefaultFlowCallback, @@ -41,6 +42,7 @@ TrainingArguments, is_torch_available, ) +from transformers.integrations.integration_utils import SwanLabCallback from transformers.testing_utils import require_torch from transformers.trainer_callback import CallbackHandler, ExportableState, TrainerControl @@ -732,6 +734,73 @@ class MockArgs: self.assertEqual(state.save_steps, 50) +class SwanLabCallbackTest(unittest.TestCase): + def _create_callback(self, fake_swanlab): + with patch("transformers.integrations.integration_utils.is_swanlab_available", return_value=True): + with patch.dict("sys.modules", {"swanlab": fake_swanlab}): + callback = SwanLabCallback() + return callback + + @staticmethod + def _create_args(): + class SwanLabArgs: + run_name = "swanlab-run" + resume_from_checkpoint = False + + @staticmethod + def to_dict(): + return {} + + return SwanLabArgs() + + @staticmethod + def _create_state(): + return SimpleNamespace(is_world_process_zero=True, trial_name=None) + + @staticmethod + def _create_model(): + class DummyConfig: + @staticmethod + def to_dict(): + return {} + + class DummyModel: + config = DummyConfig() + peft_config = None + + @staticmethod + def num_parameters(): + return 1 + + return DummyModel() + + def test_setup_does_not_forward_id_or_resume_by_default(self): + fake_swanlab = Mock() + fake_swanlab.get_run.return_value = None + fake_swanlab.config = {} + callback = self._create_callback(fake_swanlab) + + with patch.dict(os.environ, {}, clear=True): + callback.setup(self._create_args(), self._create_state(), self._create_model()) + + init_kwargs = fake_swanlab.init.call_args.kwargs + self.assertNotIn("id", init_kwargs) + self.assertNotIn("resume", init_kwargs) + + def test_setup_forwards_id_and_resume_from_env(self): + fake_swanlab = Mock() + fake_swanlab.get_run.return_value = None + fake_swanlab.config = {} + callback = self._create_callback(fake_swanlab) + + with patch.dict(os.environ, {"SWANLAB_RUN_ID": "run-123", "SWANLAB_RESUME": "must"}, clear=True): + callback.setup(self._create_args(), self._create_state(), self._create_model()) + + init_kwargs = fake_swanlab.init.call_args.kwargs + self.assertEqual(init_kwargs["id"], "run-123") + self.assertEqual(init_kwargs["resume"], "must") + + class TrainerControlTest(unittest.TestCase): """Tests for TrainerControl functionality."""