diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 6e35a836db16..863242a695c6 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -106,6 +106,7 @@
     is_hadamard_available,
     is_hqq_available,
     is_huggingface_hub_greater_or_equal,
+    is_ipython_available,
     is_jinja_available,
     is_jmespath_available,
     is_jumanpp_available,
@@ -1179,6 +1180,11 @@ def require_faiss(test_case):
     return unittest.skipUnless(is_faiss_available(), "test requires `faiss`")(test_case)


+def require_ipython(test_case):
+    """Decorator marking a test that requires IPython. These tests are skipped when IPython isn't installed."""
+    return unittest.skipUnless(is_ipython_available(), "test requires `IPython`")(test_case)
+
+
 def require_optuna(test_case):
     """
     Decorator marking a test that requires optuna.
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 3f5c7cac386b..d12e0b277c1b 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -150,6 +150,7 @@
     is_hqq_available,
     is_huggingface_hub_greater_or_equal,
     is_in_notebook,
+    is_ipython_available,
     is_jinja_available,
     is_jmespath_available,
     is_jumanpp_available,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 1e1ac2545f05..de11d23cbecf 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -1540,6 +1540,11 @@ def msg_callable():
     torch._check_with(error_type, cond, msg_callable)


+@lru_cache
+def is_ipython_available() -> bool:
+    return importlib.util.find_spec("IPython") is not None
+
+
 @lru_cache
 def is_in_notebook() -> bool:
     try:
diff --git a/src/transformers/utils/notebook.py b/src/transformers/utils/notebook.py
index ecbe8271fe13..1c7fb7a77bea 100644
--- a/src/transformers/utils/notebook.py
+++ b/src/transformers/utils/notebook.py
@@ -351,7 +351,9 @@ def on_log(self, args, state, control, logs=None, **kwargs):
             tt.write_line(values)

     def on_evaluate(self, args, state, control, metrics=None, **kwargs):
-        tt = _require(self.training_tracker, "on_train_begin must be called before on_evaluate")
+        # Recompute first_column here since on_evaluate can be called before on_train_begin,
+        # where it is normally initialized.
+        self.first_column = "Epoch" if args.eval_strategy == IntervalStrategy.EPOCH else "Step"

         values = {"Training Loss": "No log", "Validation Loss": "No log"}
         for log in reversed(state.log_history):
@@ -374,6 +376,8 @@ def on_evaluate(self, args, state, control, metrics=None, **kwargs):
             _ = metrics.pop(f"{metric_key_prefix}_runtime", None)
             _ = metrics.pop(f"{metric_key_prefix}_samples_per_second", None)
             _ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None)
+            _ = metrics.pop(f"{metric_key_prefix}_model_preparation_time", None)
+
             for k, v in metrics.items():
                 splits = k.split("_")
                 name = " ".join([part.capitalize() for part in splits[1:]])
@@ -381,11 +385,18 @@ def on_evaluate(self, args, state, control, metrics=None, **kwargs):
                     # Single dataset
                     name = "Validation Loss"
                 values[name] = v
-            tt.write_line(values)
-            tt.remove_child()
+
+            if self.training_tracker is not None:
+                tt = self.training_tracker
+                tt.write_line(values)
+                tt.remove_child()
+                # Evaluation takes a long time so we should force the next update.
+                self._force_next_update = True
+            else:
+                # No training tracker, but still show the metrics
+                disp.display(disp.HTML(text_to_html_table([list(values.keys()), list(values.values())])))
+                self.prediction_bar = None
-            # Evaluation takes a long time so we should force the next update.
-            self._force_next_update = True

     def on_train_end(self, args, state, control, **kwargs):
         tt = _require(self.training_tracker, "on_train_begin must be called before on_train_end")
diff --git a/tests/trainer/test_trainer_callback.py b/tests/trainer/test_trainer_callback.py
index 0d132a9051f5..db0ccd56b1a1 100644
--- a/tests/trainer/test_trainer_callback.py
+++ b/tests/trainer/test_trainer_callback.py
@@ -43,7 +43,7 @@
     is_torch_available,
 )
 from transformers.integrations.integration_utils import KubeflowCallback, SwanLabCallback
-from transformers.testing_utils import require_torch
+from transformers.testing_utils import require_ipython, require_torch
 from transformers.trainer_callback import CallbackHandler, ExportableState, TrainerControl


@@ -1269,3 +1269,75 @@ def state(self):

         self.assertEqual(instance.name, "test")
         self.assertEqual(instance.counter, 5)
+
+
+@require_torch
+@require_ipython
+class NotebookProgressCallbackTest(unittest.TestCase):
+    """Tests for NotebookProgressCallback behavior in notebook environments."""
+
+    def setUp(self):
+        self.output_dir = tempfile.mkdtemp()
+
+    def tearDown(self):
+        shutil.rmtree(self.output_dir)
+
+    def _create_trainer(self):
+        train_dataset = RegressionDataset(length=16)
+        eval_dataset = RegressionDataset(length=16)
+        config = RegressionModelConfig(a=0, b=0)
+        model = RegressionPreTrainedModel(config)
+
+        args = TrainingArguments(
+            self.output_dir,
+            per_device_train_batch_size=2,
+            per_device_eval_batch_size=2,
+            num_train_epochs=1,
+            logging_strategy="no",
+            report_to=[],
+            eval_strategy="epoch",
+            disable_tqdm=True,
+        )
+
+        from transformers.utils.notebook import NotebookProgressCallback
+
+        trainer = Trainer(
+            model=model,
+            args=args,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            callbacks=[NotebookProgressCallback()],  # force it
+        )
+        return trainer
+
+    def test_evaluate_before_training(self):
+        """Calling evaluate() before training does not crash and returns metrics."""
+        trainer = self._create_trainer()
+        metrics = trainer.evaluate()
+        self.assertIn("eval_loss", metrics)
+        # Check that the notebook callback exists in the callback handler
+        from transformers.utils.notebook import NotebookProgressCallback
+
+        cb = next(
+            (c for c in trainer.callback_handler.callbacks if isinstance(c, NotebookProgressCallback)),
+            None,
+        )
+        self.assertIsNotNone(cb)
+
+    def test_evaluate_after_training(self):
+        """Calling evaluate() after training does not crash and returns metrics."""
+        trainer = self._create_trainer()
+        trainer.train()
+        metrics = trainer.evaluate()
+        self.assertIn("eval_loss", metrics)
+
+    def test_multiple_evaluate_calls(self):
+        """Calling evaluate() multiple times in a row works in a notebook environment."""
+        trainer = self._create_trainer()
+        metrics1 = trainer.evaluate()
+        trainer.train()
+        metrics2 = trainer.evaluate()
+        metrics3 = trainer.evaluate()
+        self.assertIn("eval_loss", metrics1)
+        self.assertIn("eval_loss", metrics2)
+        self.assertIn("eval_loss", metrics3)
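
For context, here is a minimal sketch of how the two helpers added by this patch are meant to compose, assuming the patch is applied. ExampleIPythonTest and its test method are hypothetical names used only for illustration; they are not part of the patch:

import unittest

from transformers.testing_utils import require_ipython
from transformers.utils import is_ipython_available

if is_ipython_available():
    # Guarded import: transformers.utils.notebook imports IPython.display at
    # module load, so only pull it in when IPython is actually installed.
    from transformers.utils.notebook import NotebookProgressCallback  # noqa: F401


@require_ipython
class ExampleIPythonTest(unittest.TestCase):
    # Hypothetical test class: unittest skips it with "test requires `IPython`"
    # when IPython is missing, mirroring the new tests above.
    def test_ipython_is_importable(self):
        import IPython  # noqa: F401

        self.assertTrue(is_ipython_available())

Using unittest.skipUnless inside require_ipython keeps the decorator composable with require_torch, so a class can stack both and is skipped as soon as either dependency is absent.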