6 changes: 6 additions & 0 deletions src/transformers/testing_utils.py
@@ -106,6 +106,7 @@
is_hadamard_available,
is_hqq_available,
is_huggingface_hub_greater_or_equal,
is_ipython_available,
is_jinja_available,
is_jmespath_available,
is_jumanpp_available,
@@ -1179,6 +1180,11 @@ def require_faiss(test_case):
return unittest.skipUnless(is_faiss_available(), "test requires `faiss`")(test_case)


def require_ipython(test_case):
"""Decorator marking a test that requires IPython. These tests are skipped when IPython isn't installed."""
return unittest.skipUnless(is_ipython_available(), "test requires `IPython`")(test_case)


def require_optuna(test_case):
"""
Decorator marking a test that requires optuna.
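For reference, a minimal usage sketch of the new decorator; the class name, test name, and body below are illustrative only and not part of this PR:

import unittest

import IPython.display as disp  # assumed available because of the decorator below

from transformers.testing_utils import require_ipython


@require_ipython
class SomeNotebookTest(unittest.TestCase):  # hypothetical test class
    def test_uses_ipython_display(self):
        # Skipped automatically when IPython is not installed.
        disp.display(disp.HTML("<p>ok</p>"))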
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -150,6 +150,7 @@
is_hqq_available,
is_huggingface_hub_greater_or_equal,
is_in_notebook,
is_ipython_available,
is_jinja_available,
is_jmespath_available,
is_jumanpp_available,
5 changes: 5 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -1540,6 +1540,11 @@ def msg_callable():
torch._check_with(error_type, cond, msg_callable)


@lru_cache
def is_ipython_available() -> bool:
return importlib.util.find_spec("IPython") is not None


@lru_cache
def is_in_notebook() -> bool:
try:
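A minimal sketch of how a helper like is_ipython_available can guard an optional IPython import (illustrative only, not code from this PR):

from transformers.utils import is_ipython_available

if is_ipython_available():
    # Safe to import: the IPython package was found.
    import IPython.display as disp

    disp.display(disp.HTML("<b>rich display enabled</b>"))
else:
    # Fall back to plain output when IPython is missing.
    print("IPython not installed; skipping rich display.")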
21 changes: 16 additions & 5 deletions src/transformers/utils/notebook.py
@@ -351,7 +351,9 @@ def on_log(self, args, state, control, logs=None, **kwargs):
tt.write_line(values)

def on_evaluate(self, args, state, control, metrics=None, **kwargs):
tt = _require(self.training_tracker, "on_train_begin must be called before on_evaluate")
# Recompute first_column here since on_evaluate can be called before on_train_begin,
# where it is normally initialized.
self.first_column = "Epoch" if args.eval_strategy == IntervalStrategy.EPOCH else "Step"

Member:
why do we need to overwrite that? we shouldn't have to

Contributor Author:
Because on_evaluate can also be called before on_train_begin (which is the bug this PR fixes), but self.first_column wouldn't exist yet since it's only initialized in on_train_begin. Another option would be to move the initialization to __init__ with a default of "Step" instead, so on_evaluate doesn't need to overwrite it. Defaulting to "Step" could make sense here since if training hasn't started, there are no epochs to reference. So I can do that if you prefer it.

Member:
Understood. Let's keep what you did; maybe just add a comment about that above.

Contributor Author:
Done!


values = {"Training Loss": "No log", "Validation Loss": "No log"}
for log in reversed(state.log_history):
@@ -374,18 +376,27 @@ def on_evaluate(self, args, state, control, metrics=None, **kwargs):
_ = metrics.pop(f"{metric_key_prefix}_runtime", None)
_ = metrics.pop(f"{metric_key_prefix}_samples_per_second", None)
_ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None)
_ = metrics.pop(f"{metric_key_prefix}_model_preparation_time", None)

for k, v in metrics.items():
splits = k.split("_")
name = " ".join([part.capitalize() for part in splits[1:]])
if name == "Loss":
# Single dataset
name = "Validation Loss"
values[name] = v
tt.write_line(values)
tt.remove_child()

if self.training_tracker is not None:
tt = self.training_tracker
tt.write_line(values)
tt.remove_child()
# Evaluation takes a long time so we should force the next update.
self._force_next_update = True
else:
# No training tracker, but still show the metrics
disp.display(disp.HTML(text_to_html_table([list(values.keys()), list(values.values())])))

self.prediction_bar = None
# Evaluation takes a long time so we should force the next update.
self._force_next_update = True

def on_train_end(self, args, state, control, **kwargs):
tt = _require(self.training_tracker, "on_train_begin must be called before on_train_end")
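As context for the review thread above, a minimal sketch of the alternative the author mentioned (initializing first_column in __init__ with a "Step" default so on_evaluate would not need to overwrite it). This is illustrative only and is not what was merged:

from transformers import TrainerCallback


class NotebookProgressCallback(TrainerCallback):
    def __init__(self):
        self.training_tracker = None
        self.prediction_bar = None
        self._force_next_update = False
        # Alternative discussed in the thread (not merged): default to "Step"
        # so on_evaluate can run before on_train_begin without recomputing it.
        self.first_column = "Step"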
74 changes: 73 additions & 1 deletion tests/trainer/test_trainer_callback.py
@@ -43,7 +43,7 @@
is_torch_available,
)
from transformers.integrations.integration_utils import KubeflowCallback, SwanLabCallback
from transformers.testing_utils import require_torch
from transformers.testing_utils import require_ipython, require_torch
from transformers.trainer_callback import CallbackHandler, ExportableState, TrainerControl


@@ -1269,3 +1269,75 @@ def state(self):

self.assertEqual(instance.name, "test")
self.assertEqual(instance.counter, 5)


@require_torch
@require_ipython
class NotebookProgressCallbackTest(unittest.TestCase):
"""Tests for NotebookProgressCallback behavior in notebook environments."""

def setUp(self):
self.output_dir = tempfile.mkdtemp()

def tearDown(self):
shutil.rmtree(self.output_dir)

def _create_trainer(self):
train_dataset = RegressionDataset(length=16)
eval_dataset = RegressionDataset(length=16)
config = RegressionModelConfig(a=0, b=0)
model = RegressionPreTrainedModel(config)

args = TrainingArguments(
self.output_dir,
per_device_train_batch_size=2,
per_device_eval_batch_size=2,
num_train_epochs=1,
logging_strategy="no",
report_to=[],
eval_strategy="epoch",
disable_tqdm=True,
)

from transformers.utils.notebook import NotebookProgressCallback

trainer = Trainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
callbacks=[NotebookProgressCallback()], # force it
)
return trainer

def test_evaluate_before_training(self):
"""Calling evaluate() before training does not crash and returns metrics."""
trainer = self._create_trainer()
metrics = trainer.evaluate()
self.assertIn("eval_loss", metrics)
# Check that the notebook callback exists in callback handler
from transformers.utils.notebook import NotebookProgressCallback

cb = next(
(c for c in trainer.callback_handler.callbacks if isinstance(c, NotebookProgressCallback)),
None,
)
self.assertIsNotNone(cb)

def test_evaluate_after_training(self):
"""Calling evaluate() after training does not crash and returns metrics."""
trainer = self._create_trainer()
trainer.train()
metrics = trainer.evaluate()
self.assertIn("eval_loss", metrics)

def test_multiple_evaluate_calls(self):
"""Calling evaluate() multiple times in a row works in notebook environment."""
trainer = self._create_trainer()
metrics1 = trainer.evaluate()
trainer.train()
metrics2 = trainer.evaluate()
metrics3 = trainer.evaluate()
self.assertIn("eval_loss", metrics1)
self.assertIn("eval_loss", metrics2)
self.assertIn("eval_loss", metrics3)