From a50e35e20118bc16260363b61101eeeb1cc179f8 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 13 Feb 2026 11:09:00 -0500 Subject: [PATCH 01/19] refactor _inner_training_loop to smaller methods --- src/transformers/trainer.py | 486 +++++++++++++++++------------- src/transformers/trainer_utils.py | 1 + 2 files changed, 281 insertions(+), 206 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 731c922db81f..68edf26cf0bd 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1462,6 +1462,44 @@ def _inner_training_loop( max_steps, ) = self.set_initial_training_values(args, train_dataloader, total_train_batch_size) + model, train_dataloader = self._setup_training(args, max_steps, resume_from_checkpoint, train_dataloader) + + epochs_trained, steps_trained_in_current_epoch, start_time = self._init_loop_state( + args=args, + model=model, + num_update_steps_per_epoch=num_update_steps_per_epoch, + num_train_epochs=num_train_epochs, + max_steps=max_steps, + total_train_batch_size=total_train_batch_size, + num_examples=num_examples, + len_dataloader=len_dataloader, + train_dataloader=train_dataloader, + resume_from_checkpoint=resume_from_checkpoint, + trial=trial, + ignore_keys_for_eval=ignore_keys_for_eval, + ) + + for epoch in range(epochs_trained, num_train_epochs): + self._run_epoch( + model=model, + epoch=epoch, + train_dataloader=train_dataloader, + len_dataloader=len_dataloader, + args=args, + trial=trial, + ignore_keys_for_eval=ignore_keys_for_eval, + start_time=start_time, + resume_from_checkpoint=resume_from_checkpoint, + epochs_trained=epochs_trained, + steps_trained_in_current_epoch=steps_trained_in_current_epoch, + ) + if self.control.should_training_stop: + break + + return self._finalize_training(model, trial, num_train_samples, start_time) + + def _setup_training(self, args, max_steps, resume_from_checkpoint, train_dataloader): + """Create optimizer, wrap model, load checkpoint. Returns (wrapped_model, train_dataloader).""" if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: if self.args.n_gpu > 1: # nn.DataParallel(model) replicates the model, creating new variables and module @@ -1495,7 +1533,6 @@ def _inner_training_loop( cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) ] ) - self.state.is_hyper_param_search = trial is not None self.state.train_batch_size = self._train_batch_size # Compute absolute values for logging, eval, and save if given as ratio @@ -1583,6 +1620,26 @@ def _inner_training_loop( # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc. + return model, train_dataloader + + def _init_loop_state( + self, + args, + model, + num_update_steps_per_epoch, + num_train_epochs, + max_steps, + total_train_batch_size, + num_examples, + len_dataloader, + train_dataloader, + resume_from_checkpoint, + trial, + ignore_keys_for_eval, + ): + """Initialize training loop state. Returns (epochs_trained, steps_trained_in_current_epoch, start_time).""" + self.state.is_hyper_param_search = trial is not None + # Train! 
logger.info("***** Running training *****") logger.info(f" Num examples = {num_examples:,}") @@ -1632,240 +1689,257 @@ def _inner_training_loop( self.state.init_training_references(self, max_steps, num_train_epochs, trial) # tr_loss is a tensor to avoid synchronization of TPUs through .item() - tr_loss = torch.tensor(0.0, device=args.device) + self._tr_loss = torch.tensor(0.0, device=args.device) # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses self._total_loss_scalar = 0.0 self._globalstep_last_logged = self.state.global_step model.zero_grad() - grad_norm: float | None = None - learning_rate = None + self._grad_norm: float | None = None + self._learning_rate = None self.control = self.callback_handler.on_train_begin(args, self.state, self.control) if args.eval_on_start: self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True) - for epoch in range(epochs_trained, num_train_epochs): - epoch_dataloader = train_dataloader + return epochs_trained, steps_trained_in_current_epoch, start_time - steps_in_epoch = ( - len(epoch_dataloader) - if len_dataloader is not None - else args.max_steps * args.gradient_accumulation_steps - ) - self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) - - step = -1 - rng_to_sync = False - - # Handle resumption from checkpoint - if epoch == epochs_trained and resume_from_checkpoint is not None: - if steps_trained_in_current_epoch > 0 and not args.ignore_data_skip: - epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch) - step = steps_trained_in_current_epoch - 1 - rng_to_sync = True - elif steps_trained_in_current_epoch == 0: - self._load_rng_state(resume_from_checkpoint) + def _run_epoch( + self, + model, + epoch, + train_dataloader, + len_dataloader, + args, + trial, + ignore_keys_for_eval, + start_time, + resume_from_checkpoint, + epochs_trained, + steps_trained_in_current_epoch, + ): + """Run one full pass over the dataloader.""" + epoch_dataloader = train_dataloader - if hasattr(epoch_dataloader, "set_epoch"): - epoch_dataloader.set_epoch(epoch) - - epoch_iterator = iter(epoch_dataloader) - # We chunkify the epoch iterator into gradient accumulation steps `n` batches - remainder = steps_in_epoch % args.gradient_accumulation_steps - if remainder == 0: - remainder = args.gradient_accumulation_steps - update_step = -1 - total_updates = steps_in_epoch // args.gradient_accumulation_steps + int( - remainder < args.gradient_accumulation_steps - ) - for _ in range(total_updates): - update_step += 1 - num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder - batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device) - # Store the number of batches for current gradient accumulation - # This is used to correctly scale the loss when the last accumulation step has fewer batches - self.current_gradient_accumulation_steps = len(batch_samples) - for i, inputs in enumerate(batch_samples): - step += 1 - do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch - # Since we perform prefetching, we need to manually set sync_gradients - self.accelerator.gradient_state._set_sync_gradients(do_sync_step) - - if self.args.include_num_input_tokens_seen != "no": - main_input_name = getattr(self.model, "main_input_name", "input_ids") - if main_input_name not in inputs: - logger.warning( - "Tried to track the number of tokens seen, 
however the current model is " - "not configured properly to know what item is the input. To fix this, add " - "a `main_input_name` attribute to the model class you are using." - ) - else: - if self.args.include_num_input_tokens_seen == "non_padding": - if "attention_mask" in inputs: - input_tokens = inputs["attention_mask"].sum() - elif ( - self.processing_class is not None - and hasattr(self.processing_class, "pad_token_id") - and self.processing_class.pad_token_id is not None - ): - input_tokens = ( - inputs[main_input_name] != self.processing_class.pad_token_id - ).sum() - else: - logger.warning( - "Could not determine method to count non-padding tokens, falling back to counting all tokens." - ) - input_tokens = inputs[main_input_name].numel() + steps_in_epoch = ( + len(epoch_dataloader) + if len_dataloader is not None + else args.max_steps * args.gradient_accumulation_steps + ) + self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) + + step = -1 + rng_to_sync = False + + # Handle resumption from checkpoint + if epoch == epochs_trained and resume_from_checkpoint is not None: + if steps_trained_in_current_epoch > 0 and not args.ignore_data_skip: + epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch) + step = steps_trained_in_current_epoch - 1 + rng_to_sync = True + elif steps_trained_in_current_epoch == 0: + self._load_rng_state(resume_from_checkpoint) + + if hasattr(epoch_dataloader, "set_epoch"): + epoch_dataloader.set_epoch(epoch) + + epoch_iterator = iter(epoch_dataloader) + # We chunkify the epoch iterator into gradient accumulation steps `n` batches + remainder = steps_in_epoch % args.gradient_accumulation_steps + if remainder == 0: + remainder = args.gradient_accumulation_steps + update_step = -1 + total_updates = steps_in_epoch // args.gradient_accumulation_steps + int( + remainder < args.gradient_accumulation_steps + ) + for _ in range(total_updates): + update_step += 1 + num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder + batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device) + # Store the number of batches for current gradient accumulation + # This is used to correctly scale the loss when the last accumulation step has fewer batches + self.current_gradient_accumulation_steps = len(batch_samples) + for i, inputs in enumerate(batch_samples): + step += 1 + do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch + # Since we perform prefetching, we need to manually set sync_gradients + self.accelerator.gradient_state._set_sync_gradients(do_sync_step) + + if self.args.include_num_input_tokens_seen != "no": + main_input_name = getattr(self.model, "main_input_name", "input_ids") + if main_input_name not in inputs: + logger.warning( + "Tried to track the number of tokens seen, however the current model is " + "not configured properly to know what item is the input. To fix this, add " + "a `main_input_name` attribute to the model class you are using." 
+ ) + else: + if self.args.include_num_input_tokens_seen == "non_padding": + if "attention_mask" in inputs: + input_tokens = inputs["attention_mask"].sum() + elif ( + self.processing_class is not None + and hasattr(self.processing_class, "pad_token_id") + and self.processing_class.pad_token_id is not None + ): + input_tokens = ( + inputs[main_input_name] != self.processing_class.pad_token_id + ).sum() else: + logger.warning( + "Could not determine method to count non-padding tokens, falling back to counting all tokens." + ) input_tokens = inputs[main_input_name].numel() + else: + input_tokens = inputs[main_input_name].numel() - input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64) - self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item() - - if rng_to_sync: - self._load_rng_state(resume_from_checkpoint) - rng_to_sync = False - - if step % args.gradient_accumulation_steps == 0: - self.control = self.callback_handler.on_step_begin(args, self.state, self.control) + input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64) + self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item() - # We sync the gradients in the following cases: 1. sync_each_batch set to True 2. Using deepspeed 3. when we are at the last batch sample - if ( - self.accelerator.gradient_state.plugin_kwargs.get("sync_each_batch", False) - or self.accelerator.distributed_type == DistributedType.DEEPSPEED - or i == len(batch_samples) - 1 - ): - sync_context = contextlib.nullcontext - else: - sync_context = functools.partial(self.accelerator.no_sync, model=model) - with sync_context(): - tr_loss_step = self.training_step(model, inputs, num_items_in_batch) - - if ( - args.logging_nan_inf_filter - and not is_torch_xla_available() - and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) - ): - # if loss is nan or inf simply add the average of previous logged losses - tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) - else: - if tr_loss.device != tr_loss_step.device: - raise ValueError( - f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}" - ) - tr_loss = tr_loss + tr_loss_step + if rng_to_sync: + self._load_rng_state(resume_from_checkpoint) + rng_to_sync = False + + if step % args.gradient_accumulation_steps == 0: + self.control = self.callback_handler.on_step_begin(args, self.state, self.control) + + # We sync the gradients in the following cases: 1. sync_each_batch set to True 2. Using deepspeed 3. 
when we are at the last batch sample + if ( + self.accelerator.gradient_state.plugin_kwargs.get("sync_each_batch", False) + or self.accelerator.distributed_type == DistributedType.DEEPSPEED + or i == len(batch_samples) - 1 + ): + sync_context = contextlib.nullcontext + else: + sync_context = functools.partial(self.accelerator.no_sync, model=model) + with sync_context(): + tr_loss_step = self.training_step(model, inputs, num_items_in_batch) + + if ( + args.logging_nan_inf_filter + and not is_torch_xla_available() + and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) + ): + # if loss is nan or inf simply add the average of previous logged losses + self._tr_loss += self._tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) + else: + if self._tr_loss.device != tr_loss_step.device: + raise ValueError( + f"Calculated loss must be on the original device: {self._tr_loss.device} but device in use is {tr_loss_step.device}" + ) + self._tr_loss += tr_loss_step - self.current_flos += float(self.floating_point_ops(inputs)) + self.current_flos += float(self.floating_point_ops(inputs)) - if do_sync_step: - # Since we perform prefetching, we need to manually set sync_gradients to True - self.accelerator.gradient_state._set_sync_gradients(True) + if do_sync_step: + # Since we perform prefetching, we need to manually set sync_gradients to True + self.accelerator.gradient_state._set_sync_gradients(True) - # Gradient clipping - if args.max_grad_norm is not None and args.max_grad_norm > 0: - if is_sagemaker_mp_enabled() and args.fp16: - _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm) - else: - grad_norm_context = contextlib.nullcontext - if self.is_tp_enabled: - from torch.distributed._tensor.experimental import implicit_replication - - grad_norm_context = implicit_replication - with grad_norm_context(): - _grad_norm = self.accelerator.clip_grad_norm_( - model.parameters(), - args.max_grad_norm, - ) + # Gradient clipping + if args.max_grad_norm is not None and args.max_grad_norm > 0: + if is_sagemaker_mp_enabled() and args.fp16: + _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm) + else: + grad_norm_context = contextlib.nullcontext + if self.is_tp_enabled: + from torch.distributed._tensor.experimental import implicit_replication + + grad_norm_context = implicit_replication + with grad_norm_context(): + _grad_norm = self.accelerator.clip_grad_norm_( + model.parameters(), + args.max_grad_norm, + ) + + if self.accelerator.distributed_type == DistributedType.DEEPSPEED: + self._grad_norm = model.get_global_grad_norm() + # In some cases the grad norm may not return a float + if hasattr(self._grad_norm, "item"): + self._grad_norm = self._grad_norm.item() + else: + self._grad_norm = _grad_norm + + self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control) + + context = contextlib.nullcontext + if self.is_tp_enabled: + from torch.distributed._tensor.experimental import implicit_replication + + context = implicit_replication + + with context(): + self.optimizer.step() + + self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control) + + # get learning rate before update + self._learning_rate = self._get_learning_rate() + + if not self.accelerator.optimizer_step_was_skipped: + # Delay optimizer scheduling until metrics are generated + if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step() + + model.zero_grad() + self.state.global_step += 1 + 
self.state.epoch = epoch + (step + 1) / steps_in_epoch + self.control = self.callback_handler.on_step_end(args, self.state, self.control) + self._maybe_log_save_evaluate( + self._tr_loss, + self._grad_norm, + model, + trial, + epoch, + ignore_keys_for_eval, + start_time, + learning_rate=self._learning_rate, + ) + else: + self.control = self.callback_handler.on_substep_end(args, self.state, self.control) - if self.accelerator.distributed_type == DistributedType.DEEPSPEED: - grad_norm = model.get_global_grad_norm() - # In some cases the grad norm may not return a float - if hasattr(grad_norm, "item"): - grad_norm = grad_norm.item() - else: - grad_norm = _grad_norm - - self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control) - - context = contextlib.nullcontext - if self.is_tp_enabled: - from torch.distributed._tensor.experimental import implicit_replication - - context = implicit_replication - - with context(): - self.optimizer.step() - - self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control) - - # get leaning rate before update - learning_rate = self._get_learning_rate() - - if not self.accelerator.optimizer_step_was_skipped: - # Delay optimizer scheduling until metrics are generated - if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): - self.lr_scheduler.step() - - model.zero_grad() - self.state.global_step += 1 - self.state.epoch = epoch + (step + 1) / steps_in_epoch - self.control = self.callback_handler.on_step_end(args, self.state, self.control) - self._maybe_log_save_evaluate( - tr_loss, - grad_norm, - model, - trial, - epoch, - ignore_keys_for_eval, - start_time, - learning_rate=learning_rate, - ) - else: - self.control = self.callback_handler.on_substep_end(args, self.state, self.control) - - # PyTorch/XLA relies on the data loader to insert the mark_step for - # each step. Since we are breaking the loop early, we need to manually - # insert the mark_step here. - if self.control.should_epoch_stop or self.control.should_training_stop: - if is_torch_xla_available(): - xm.mark_step() - break - # We also need to break out of the nested loop + # PyTorch/XLA relies on the data loader to insert the mark_step for + # each step. Since we are breaking the loop early, we need to manually + # insert the mark_step here. if self.control.should_epoch_stop or self.control.should_training_stop: if is_torch_xla_available(): xm.mark_step() break - if step < 0: - logger.warning( - "There seems not to be a single sample in your epoch_iterator, stopping training at step" - f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" - f" num_steps ({max_steps}) higher than the number of available samples." - ) - self.control.should_training_stop = True - - self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) - self._maybe_log_save_evaluate( - tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=learning_rate - ) - - if DebugOption.TPU_METRICS_DEBUG in self.args.debug: + # We also need to break out of the nested loop + if self.control.should_epoch_stop or self.control.should_training_stop: if is_torch_xla_available(): - # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) - xm.master_print(met.metrics_report()) - else: - logger.warning( - "You enabled PyTorch/XLA debug metrics but you don't have a TPU " - "configured. Check your training configuration if this is unexpected." 
- ) - if self.control.should_training_stop: + xm.mark_step() break + if step < 0: + logger.warning( + "There seems not to be a single sample in your epoch_iterator, stopping training at step" + f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" + f" num_steps ({self.state.max_steps}) higher than the number of available samples." + ) + self.control.should_training_stop = True + + self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) + self._maybe_log_save_evaluate( + self._tr_loss, self._grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, + learning_rate=self._learning_rate, + ) + + if DebugOption.TPU_METRICS_DEBUG in self.args.debug: + if is_torch_xla_available(): + # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) + xm.master_print(met.metrics_report()) + else: + logger.warning( + "You enabled PyTorch/XLA debug metrics but you don't have a TPU " + "configured. Check your training configuration if this is unexpected." + ) + def _finalize_training(self, model, trial, num_train_samples, start_time): + """Finalize training: metrics, best-model loading, cleanup. Returns TrainOutput.""" logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") - if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: + if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None: self._load_best_model() # add remaining tr_loss - self._total_loss_scalar += tr_loss.item() + self._total_loss_scalar += self._tr_loss.item() effective_global_step = max(self.state.global_step, 0.001) # Avoid ZeroDivisionError train_loss = self._total_loss_scalar / effective_global_step @@ -1897,7 +1971,7 @@ def _inner_training_loop( logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") shutil.rmtree(checkpoint, ignore_errors=True) - self.control = self.callback_handler.on_train_end(args, self.state, self.control) + self.control = self.callback_handler.on_train_end(self.args, self.state, self.control) # Wait for the checkpoint to be uploaded. 
self._finish_current_push() diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index 46582e4069c8..aa8717dc8e90 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -612,6 +612,7 @@ class TrainerMemoryTracker: "__init__": "init", "train": "train", "_inner_training_loop": "train", + "_finalize_training": "train", "evaluate": "eval", "predict": "test", } From 4bbbac199d67a36acb44f99eacea6307c4f26947 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 16 Feb 2026 20:25:35 -0500 Subject: [PATCH 02/19] address PR comments and remove unused args --- src/transformers/trainer.py | 60 +++++++++++++++++++------------------ 1 file changed, 31 insertions(+), 29 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 68edf26cf0bd..5bc652b36760 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1416,16 +1416,7 @@ def train( ignore_keys_for_eval=ignore_keys_for_eval, ) - def _inner_training_loop( - self, - batch_size: int | None = None, - args: TrainingArguments | None = None, - resume_from_checkpoint: str | None = None, - trial: "optuna.Trial | dict[str, Any] | None" = None, - ignore_keys_for_eval: list[str] | None = None, - ) -> TrainOutput: - """Run the actual training loop: forward, backward, optimizer step, logging, and checkpointing.""" - self.accelerator.free_memory() + def _init_train_batch_size(self, batch_size): self._train_batch_size = batch_size if self.args.auto_find_batch_size: if self.state.train_batch_size != self._train_batch_size: @@ -1441,6 +1432,18 @@ def _inner_training_loop( self.args.per_device_train_batch_size = original_bs self.state.train_batch_size = self._train_batch_size logger.debug(f"Currently training with a batch size of: {self._train_batch_size}") + + def _inner_training_loop( + self, + batch_size: int | None = None, + args: TrainingArguments | None = None, + resume_from_checkpoint: str | None = None, + trial: "optuna.Trial | dict[str, Any] | None" = None, + ignore_keys_for_eval: list[str] | None = None, + ) -> TrainOutput: + """Run the actual training loop: forward, backward, optimizer step, logging, and checkpointing.""" + self.accelerator.free_memory() + self._init_train_batch_size(batch_size) # Data loader and number of training steps train_dataloader = self.get_train_dataloader() if self.is_fsdp_xla_v2_enabled: @@ -1450,8 +1453,6 @@ def _inner_training_loop( # number of training epochs: num_train_epochs # number of training steps per epoch: num_update_steps_per_epoch # total number of training steps to execute: max_steps - total_train_batch_size = self.get_total_train_batch_size(args) - ( num_train_epochs, num_update_steps_per_epoch, @@ -1460,9 +1461,11 @@ def _inner_training_loop( epoch_based, len_dataloader, max_steps, - ) = self.set_initial_training_values(args, train_dataloader, total_train_batch_size) + total_train_batch_size, + ) = self.set_initial_training_values(args, train_dataloader) - model, train_dataloader = self._setup_training(args, max_steps, resume_from_checkpoint, train_dataloader) + self._setup_debug_model() + model, train_dataloader = self._setup_training(args, max_steps, resume_from_checkpoint, train_dataloader, trial) epochs_trained, steps_trained_in_current_epoch, start_time = self._init_loop_state( args=args, @@ -1472,13 +1475,14 @@ def _inner_training_loop( max_steps=max_steps, total_train_batch_size=total_train_batch_size, num_examples=num_examples, - len_dataloader=len_dataloader, 
train_dataloader=train_dataloader, resume_from_checkpoint=resume_from_checkpoint, trial=trial, - ignore_keys_for_eval=ignore_keys_for_eval, ) + if args.eval_on_start: + self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True) + for epoch in range(epochs_trained, num_train_epochs): self._run_epoch( model=model, @@ -1496,10 +1500,9 @@ def _inner_training_loop( if self.control.should_training_stop: break - return self._finalize_training(model, trial, num_train_samples, start_time) + return self._finalize_training(trial, num_train_samples, start_time) - def _setup_training(self, args, max_steps, resume_from_checkpoint, train_dataloader): - """Create optimizer, wrap model, load checkpoint. Returns (wrapped_model, train_dataloader).""" + def _setup_debug_model(self): if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: if self.args.n_gpu > 1: # nn.DataParallel(model) replicates the model, creating new variables and module @@ -1510,6 +1513,8 @@ def _setup_training(self, args, max_steps, resume_from_checkpoint, train_dataloa else: DebugUnderflowOverflow(self.model) + def _setup_training(self, args, max_steps, resume_from_checkpoint, train_dataloader, trial): + """Create optimizer, lr_scheduler, wrap model, load checkpoint. Returns (wrapped_model, train_dataloader).""" delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled # Can't delay optimizer creation when using FSDP2: https://github.com/huggingface/accelerate/blob/3f636d626063ffcf9a337c7d3624d61b7d187d59/src/accelerate/accelerator.py#L1404 @@ -1533,6 +1538,7 @@ def _setup_training(self, args, max_steps, resume_from_checkpoint, train_dataloa cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) ] ) + self.state.is_hyper_param_search = trial is not None self.state.train_batch_size = self._train_batch_size # Compute absolute values for logging, eval, and save if given as ratio @@ -1631,15 +1637,11 @@ def _init_loop_state( max_steps, total_train_batch_size, num_examples, - len_dataloader, train_dataloader, resume_from_checkpoint, trial, - ignore_keys_for_eval, ): """Initialize training loop state. Returns (epochs_trained, steps_trained_in_current_epoch, start_time).""" - self.state.is_hyper_param_search = trial is not None - # Train! logger.info("***** Running training *****") logger.info(f" Num examples = {num_examples:,}") @@ -1698,9 +1700,6 @@ def _init_loop_state( self._learning_rate = None self.control = self.callback_handler.on_train_begin(args, self.state, self.control) - if args.eval_on_start: - self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True) - return epochs_trained, steps_trained_in_current_epoch, start_time def _run_epoch( @@ -1932,7 +1931,7 @@ def _run_epoch( "configured. Check your training configuration if this is unexpected." ) - def _finalize_training(self, model, trial, num_train_samples, start_time): + def _finalize_training(self, trial, num_train_samples, start_time): """Finalize training: metrics, best-model loading, cleanup. Returns TrainOutput.""" logger.info("\n\nTraining completed. 
Do not forget to share your model on huggingface.co/models =)\n\n") if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None: @@ -2403,8 +2402,8 @@ def _prepare_context_parallel_inputs( return contextlib.nullcontext, inputs def set_initial_training_values( - self, args: TrainingArguments, dataloader: DataLoader, total_train_batch_size: int - ) -> tuple[int, int, int, int, bool, int | None, int]: + self, args: TrainingArguments, dataloader: DataLoader + ) -> tuple[int, int, int, int, bool, int | None, int, int]: """ Calculates and returns the following values: - `num_train_epochs` @@ -2414,12 +2413,14 @@ def set_initial_training_values( - `epoch_based` - `len_dataloader` - `max_steps` + - `total_train_batch_size` """ # Case 1: we rely on `args.max_steps` first max_steps = args.max_steps # If max_steps is negative, we use the number of epochs to determine the number of total steps later epoch_based = max_steps < 0 len_dataloader = len(dataloader) if has_length(dataloader) else None + total_train_batch_size = self.get_total_train_batch_size(args) # Account for Sequence Parallelism (SP) dataloader adapter's effect sp_size = self.get_sp_size() @@ -2469,6 +2470,7 @@ def set_initial_training_values( epoch_based, len_dataloader, max_steps, + total_train_batch_size, ) def get_total_train_batch_size(self, args: TrainingArguments) -> int: From 48391e9ede8f5c68de028c9403f9362559db4173 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Wed, 18 Feb 2026 13:20:05 +0000 Subject: [PATCH 03/19] move _train_batch_size --- src/transformers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 5bc652b36760..8894a1dea487 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1358,7 +1358,6 @@ def train( # This might change the seed so needs to run first. self._hp_search_setup(trial) - self._train_batch_size = self.args.train_batch_size # Model re-init model_reloaded = False @@ -4363,6 +4362,7 @@ def _hp_search_setup(self, trial: "optuna.Trial | dict[str, Any] | None") -> Non # Simply calling `_reset_state` is enough and doesn't need a version pin. 
AcceleratorState()._reset_state() + self._train_batch_size = self.args.train_batch_size self.create_accelerator_and_postprocess() def _report_to_hp_search( From cda1b847e9a139f88d1f9fa649101e65e6e71749 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 19 Feb 2026 15:23:14 +0000 Subject: [PATCH 04/19] big update --- src/transformers/testing_utils.py | 6 +- src/transformers/trainer.py | 574 ++++++++++++--------------- src/transformers/trainer_callback.py | 2 +- src/transformers/trainer_utils.py | 13 + 4 files changed, 273 insertions(+), 322 deletions(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 57c09c46b6cc..273edeb88d5d 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -1468,11 +1468,7 @@ def get_steps_per_epoch(trainer: Trainer) -> int: training_args = trainer.args train_dataloader = trainer.get_train_dataloader() - initial_training_values = trainer.set_initial_training_values( - args=training_args, - dataloader=train_dataloader, - total_train_batch_size=training_args.per_device_train_batch_size, - ) + initial_training_values = trainer.set_initial_training_values(args=training_args, dataloader=train_dataloader) steps_per_epoch = initial_training_values[1] return steps_per_epoch diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 8894a1dea487..4322d3a0f0d1 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -42,7 +42,6 @@ # ruff: isort: on -import huggingface_hub.utils as hf_hub_utils import numpy as np import safetensors.torch import torch @@ -142,6 +141,7 @@ set_seed, sort_checkpoints, speed_metrics, + suppress_progress_bars, unwrap_peft_model, validate_quantization_for_training, ) @@ -170,7 +170,6 @@ is_torch_hpu_available, is_torch_mlu_available, is_torch_musa_available, - is_torch_neuroncore_available, is_torch_npu_available, is_torch_xla_available, logging, @@ -218,6 +217,7 @@ DataLoaderConfiguration, DistributedDataParallelKwargs, DistributedType, + DummyScheduler, GradientAccumulationPlugin, load_fsdp_model, load_fsdp_optimizer, @@ -702,6 +702,23 @@ def _build_accelerator_args(self, **kwargs) -> dict[str, Any]: } args.update(kwargs) + if self.args.ddp_find_unused_parameters is not None: + find_unused = self.args.ddp_find_unused_parameters + elif isinstance(self.model, PreTrainedModel): + # find_unused_parameters breaks checkpointing as per + # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 + find_unused = not (self.model.is_gradient_checkpointing or self.args.gradient_checkpointing) + else: + find_unused = True + + ddp_kwargs = {"find_unused_parameters": find_unused} + if self.args.ddp_bucket_cap_mb is not None: + ddp_kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb + if self.args.ddp_broadcast_buffers is not None: + ddp_kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers + + args["ddp_handler"] = DistributedDataParallelKwargs(**ddp_kwargs) + # We defer compatibility checks to accelerator if self.args.parallelism_config is not None: min_accelerate_version = "1.12.0" @@ -1129,7 +1146,7 @@ def create_optimizer_and_scheduler(self, num_training_steps: int) -> None: self.create_optimizer() self.create_scheduler(num_training_steps=num_training_steps) - def create_optimizer(self) -> torch.optim.Optimizer: + def create_optimizer(self, model=None) -> torch.optim.Optimizer: """ Setup the optimizer. 
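# --- Illustrative aside (not part of the patch): why create_optimizer() now accepts a
# `model` argument.  When optimizer creation has to be delayed (e.g. under FSDP), the
# parameter groups must be built from the *prepared* (wrapped) module rather than
# self.model, which is why a later hunk calls self.create_optimizer(model) right after
# accelerator.prepare().  A minimal, self-contained sketch of that ordering;
# `build_param_groups` and the SGD choice are stand-ins, not the Trainer's actual setup.
import torch

def build_param_groups(prepared_model: torch.nn.Module) -> list[dict]:
    decay, no_decay = [], []
    for name, param in prepared_model.named_parameters():
        if param.requires_grad:
            (no_decay if name.endswith("bias") else decay).append(param)
    return [{"params": decay, "weight_decay": 0.01}, {"params": no_decay, "weight_decay": 0.0}]

def build_optimizer_after_prepare(prepared_model: torch.nn.Module, lr: float = 1e-3) -> torch.optim.Optimizer:
    # Parameters come from the wrapped module, so FSDP's flattened params are the ones
    # the optimizer actually tracks.
    return torch.optim.SGD(build_param_groups(prepared_model), lr=lr)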
@@ -1139,7 +1156,7 @@ def create_optimizer(self) -> torch.optim.Optimizer: Returns: `torch.optim.Optimizer`: The optimizer instance. """ - opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model + opt_model = self.model if model is None else model if self.optimizer is None: decay_parameters = self.get_decay_parameter_names(opt_model) @@ -1341,6 +1358,26 @@ def train( self.is_in_train = True + # Model re-init + if self.model_init is not None: + # Seed must be set before instantiating the model when using model_init. + enable_full_determinism(args.seed) if args.full_determinism else set_seed(args.seed) + self.model = self.call_model_init(trial) + # Reinitializes optimizer and scheduler + self.optimizer, self.lr_scheduler = None, None + if self.place_model_on_device: + self._move_model_to_device(self.model, args.device) + self.model_wrapped = self.model + + # When fp16/bf16 full eval is enabled, __init__ skips device placement so that + # evaluation_loop can cast dtype and move in one step. Move the model now for training. + if (args.fp16_full_eval or args.bf16_full_eval) and not self.is_model_parallel and self.model_init is None: + self._move_model_to_device(self.model, args.device) + + # Activate gradient checkpointing if needed + if args.gradient_checkpointing: + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs) + # If the model uses a tokenizer, it may have a new tokens for fine-tuning purposes. if isinstance(self.processing_class, (PreTrainedTokenizerBase, ProcessorMixin)) and hasattr( self.model, "config" @@ -1351,23 +1388,18 @@ def train( if self.neftune_noise_alpha is not None: self.neftune_hook_handle = activate_neftune(self.model, self.neftune_noise_alpha, self.accelerator) - # When fp16/bf16 full eval is enabled, __init__ skips device placement so that - # evaluation_loop can cast dtype and move in one step. Move the model now for training. - if (args.fp16_full_eval or args.bf16_full_eval) and not self.is_model_parallel and self.model_init is None: - self._move_model_to_device(self.model, args.device) - # This might change the seed so needs to run first. self._hp_search_setup(trial) - # Model re-init - model_reloaded = False - if self.model_init is not None: - # Seed must be set before instantiating the model when using model_init. - enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) - self.model = self.call_model_init(trial) - model_reloaded = True - # Reinitializes optimizer and scheduler - self.optimizer, self.lr_scheduler = None, None + if DebugOption.UNDERFLOW_OVERFLOW in args.debug: + if args.n_gpu > 1: + # nn.DataParallel(model) replicates the model, creating new variables and module + # references registered here no longer work on other gpus, breaking the module + raise ValueError( + "Currently --debug underflow_overflow is not supported under DP. 
Please use DDP with torchrun" + ) + else: + DebugUnderflowOverflow(self.model) # Load potential model checkpoint if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint: @@ -1376,38 +1408,17 @@ def train( raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})") if resume_from_checkpoint is not None: - if not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled and not self.is_fsdp_enabled: - self._load_from_checkpoint(resume_from_checkpoint) - # In case of repeating the find_executable_batch_size, set `self._train_batch_size` properly state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) - # Only restore the checkpoint's train_batch_size when using auto_find_batch_size, - # as that feature needs to resume with the automatically-found batch size. - # Otherwise, use the current args batch size to allow users to change batch configuration. if state.train_batch_size is not None and args.auto_find_batch_size: + # Only restore the checkpoint's train_batch_size when using auto_find_batch_size, self._train_batch_size = state.train_batch_size - # If model was re-initialized, put it on the right device and update self.model_wrapped - if model_reloaded: - if self.place_model_on_device: - self._move_model_to_device(self.model, args.device) - self.model_wrapped = self.model - inner_training_loop = find_executable_batch_size( self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size ) - if args.push_to_hub: - try: - # Disable progress bars when uploading models during checkpoints to avoid polluting stdout - hf_hub_utils.disable_progress_bars() - return inner_training_loop( - args=args, - resume_from_checkpoint=resume_from_checkpoint, - trial=trial, - ignore_keys_for_eval=ignore_keys_for_eval, - ) - finally: - hf_hub_utils.enable_progress_bars() - else: + # Disable progress bars when uploading models during checkpoints to avoid polluting stdout + ctx = suppress_progress_bars() if args.push_to_hub else contextlib.nullcontext() + with ctx: return inner_training_loop( args=args, resume_from_checkpoint=resume_from_checkpoint, @@ -1415,23 +1426,6 @@ def train( ignore_keys_for_eval=ignore_keys_for_eval, ) - def _init_train_batch_size(self, batch_size): - self._train_batch_size = batch_size - if self.args.auto_find_batch_size: - if self.state.train_batch_size != self._train_batch_size: - release_memory(self.model_wrapped) - self.model_wrapped = self.model - - # Check for DeepSpeed *after* the initial pass and modify the config - if self.is_deepspeed_enabled: - # Temporarily unset `self.args.train_batch_size` - original_bs = self.args.per_device_train_batch_size - self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu) - propagate_args_to_deepspeed(self.accelerator, self.args, auto_find_batch_size=True) - self.args.per_device_train_batch_size = original_bs - self.state.train_batch_size = self._train_batch_size - logger.debug(f"Currently training with a batch size of: {self._train_batch_size}") - def _inner_training_loop( self, batch_size: int | None = None, @@ -1441,44 +1435,65 @@ def _inner_training_loop( ignore_keys_for_eval: list[str] | None = None, ) -> TrainOutput: """Run the actual training loop: forward, backward, optimizer step, logging, and checkpointing.""" - self.accelerator.free_memory() - self._init_train_batch_size(batch_size) + if args.auto_find_batch_size: + self._update_auto_batch_size(batch_size) # Data loader and number of training steps 
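# --- Illustrative aside (not part of the patch): the `suppress_progress_bars` helper used
# in train() above is added to trainer_utils.py elsewhere in this series; its body is not
# shown in this excerpt.  One plausible implementation, mirroring the try/finally that the
# hunk removes (an assumption, not the patch's actual code):
import contextlib
import huggingface_hub.utils as hf_hub_utils

@contextlib.contextmanager
def suppress_progress_bars():
    # Disable hub progress bars while pushing checkpoints, restore them afterwards.
    hf_hub_utils.disable_progress_bars()
    try:
        yield
    finally:
        hf_hub_utils.enable_progress_bars()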
train_dataloader = self.get_train_dataloader() if self.is_fsdp_xla_v2_enabled: train_dataloader = tpu_spmd_dataloader(train_dataloader) # Setting up training control variables: - # number of training epochs: num_train_epochs - # number of training steps per epoch: num_update_steps_per_epoch - # total number of training steps to execute: max_steps ( num_train_epochs, num_update_steps_per_epoch, num_examples, num_train_samples, - epoch_based, - len_dataloader, - max_steps, total_train_batch_size, + max_steps, ) = self.set_initial_training_values(args, train_dataloader) - self._setup_debug_model() - model, train_dataloader = self._setup_training(args, max_steps, resume_from_checkpoint, train_dataloader, trial) - - epochs_trained, steps_trained_in_current_epoch, start_time = self._init_loop_state( - args=args, - model=model, - num_update_steps_per_epoch=num_update_steps_per_epoch, - num_train_epochs=num_train_epochs, - max_steps=max_steps, - total_train_batch_size=total_train_batch_size, - num_examples=num_examples, - train_dataloader=train_dataloader, - resume_from_checkpoint=resume_from_checkpoint, - trial=trial, + epochs_trained, steps_trained_in_current_epoch = self._init_training_state( + max_steps, num_update_steps_per_epoch, num_train_epochs, resume_from_checkpoint, trial + ) + model, train_dataloader = self._prepare_model_and_optimizer( + max_steps, train_dataloader, resume_from_checkpoint ) + # Train! + logger.info("***** Running training *****") + logger.info(f" Num examples = {num_examples:,}") + logger.info(f" Num Epochs = {num_train_epochs:,}") + logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}") + if self.args.per_device_train_batch_size != self._train_batch_size: + logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_steps:,}") + logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}") + + if resume_from_checkpoint is not None: + logger.info( + f" Resuming training from checkpoint with epoch {epochs_trained} and global step {self.state.global_step}" + ) + if not self.args.ignore_data_skip: + logger.info( + f" Fast-forwarding the dataloader past {epochs_trained} epochs and" + f" {steps_trained_in_current_epoch} batches to resume from the exact training state." + ) + + start_time = time.time() + # needed to calculate tokens/s + self.initial_num_input_tokens_seen_for_session = self.state.num_input_tokens_seen + # Logging state: _tr_loss accumulates on-device between logging steps (avoiding costly .item() syncs + # on TPUs), then gets drained into _total_loss_scalar at each logging step. 
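# --- Illustrative aside (not part of the patch): what "accumulate on-device, drain at
# logging time" buys.  Every .item() call forces a blocking device-to-host transfer (and a
# graph execution on XLA/TPU), so the running loss is kept as a tensor between logging
# steps and only converted to a Python float when a log is actually emitted.  Standalone
# sketch under that assumption:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
running_loss = torch.tensor(0.0, device=device)    # stays on device between logging steps
total_loss_scalar = 0.0
for step_loss in (torch.tensor(0.5, device=device), torch.tensor(0.25, device=device)):
    running_loss += step_loss                       # no host sync here
total_loss_scalar += running_loss.item()            # single sync, paid only when logging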
+ self._tr_loss = torch.tensor(0.0, device=args.device) + self._total_loss_scalar = 0.0 + self._globalstep_last_logged = self.state.global_step + + model.zero_grad() + + self.control = self.callback_handler.on_train_begin(args, self.state, self.control) + if args.eval_on_start: self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True) @@ -1487,8 +1502,6 @@ def _inner_training_loop( model=model, epoch=epoch, train_dataloader=train_dataloader, - len_dataloader=len_dataloader, - args=args, trial=trial, ignore_keys_for_eval=ignore_keys_for_eval, start_time=start_time, @@ -1501,19 +1514,39 @@ def _inner_training_loop( return self._finalize_training(trial, num_train_samples, start_time) - def _setup_debug_model(self): - if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: - if self.args.n_gpu > 1: - # nn.DataParallel(model) replicates the model, creating new variables and module - # references registered here no longer work on other gpus, breaking the module - raise ValueError( - "Currently --debug underflow_overflow is not supported under DP. Please use DDP with torchrun" - ) - else: - DebugUnderflowOverflow(self.model) + def _init_training_state( + self, max_steps, num_update_steps_per_epoch, num_train_epochs, resume_from_checkpoint, trial + ) -> tuple[int, int]: + """Initialize TrainerState, optionally restoring from checkpoint. Returns (epochs_trained, steps_trained_in_current_epoch).""" + self.state = TrainerState( + stateful_callbacks=[ + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) + ] + ) + self.state.is_hyper_param_search = trial is not None + self.state.train_batch_size = self._train_batch_size + self.state.compute_steps(self.args, max_steps) + + epochs_trained = 0 + steps_trained_in_current_epoch = 0 + + if resume_from_checkpoint is not None and os.path.isfile( + os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) + ): + self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) + compare_trainer_and_checkpoint_args(self.args, self.state) + self._load_callback_state() + epochs_trained = int(self.state.global_step // num_update_steps_per_epoch) + if not self.args.ignore_data_skip: + steps_trained_in_current_epoch = self.state.global_step % num_update_steps_per_epoch + steps_trained_in_current_epoch *= self.args.gradient_accumulation_steps + + self.state.init_training_references(self, max_steps, num_train_epochs, trial) - def _setup_training(self, args, max_steps, resume_from_checkpoint, train_dataloader, trial): - """Create optimizer, lr_scheduler, wrap model, load checkpoint. Returns (wrapped_model, train_dataloader).""" + return epochs_trained, steps_trained_in_current_epoch + + def _prepare_model_and_optimizer(self, max_steps, train_dataloader, resume_from_checkpoint): + """Wrap model, create optimizer and scheduler, and run accelerator.prepare. 
Returns (model, train_dataloader).""" delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled # Can't delay optimizer creation when using FSDP2: https://github.com/huggingface/accelerate/blob/3f636d626063ffcf9a337c7d3624d61b7d187d59/src/accelerate/accelerator.py#L1404 @@ -1532,54 +1565,27 @@ def _setup_training(self, args, max_steps, resume_from_checkpoint, train_dataloa if not delay_optimizer_creation: self.create_optimizer() - self.state = TrainerState( - stateful_callbacks=[ - cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) - ] - ) - self.state.is_hyper_param_search = trial is not None - self.state.train_batch_size = self._train_batch_size - - # Compute absolute values for logging, eval, and save if given as ratio - self.state.compute_steps(args, max_steps) - - # Activate gradient checkpointing if needed - if args.gradient_checkpointing: - self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs) + model = self._wrap_model(self.model) - model = self._wrap_model(self.model_wrapped) - - # as the model is wrapped, don't use `accelerator.prepare` - # this is for unhandled cases such as - # FSDP-XLA, SageMaker MP/DP, DataParallel + # If the model is wrapped, don't use `accelerator.prepare` + # this is for unhandled cases in accelerate such as FSDP-XLA, SageMaker MP/DP, DataParallel use_accelerator_prepare = model is self.model - if use_accelerator_prepare and self.is_fsdp_enabled: - # In case of auto_find_batch_size=True - # Remove FSDP wrapping from sub-models. - self.model = unwrap_model(self.model, recursive=True) - - if delay_optimizer_creation: - if use_accelerator_prepare: - # configure fsdp plugin for qlora if any - if self.is_fsdp_enabled and _is_peft_model(model): - update_fsdp_plugin_peft(self.model, self.accelerator) - if self.accelerator.mixed_precision != "fp8": - self.model = self.accelerator.prepare(self.model) - self.create_optimizer() - # prepare using `accelerator` prepare if use_accelerator_prepare: - self.model.train() - if self.is_deepspeed_enabled: - from accelerate.utils import DummyScheduler - - if isinstance(self.lr_scheduler, DummyScheduler): - model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( - self.model, self.optimizer, self.lr_scheduler - ) - else: - model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) + if delay_optimizer_creation: + # TODO: check if we can move this somewhere else + if self.is_fsdp_enabled and _is_peft_model(self.model): + update_fsdp_plugin_peft(self.model, self.accelerator) + # we only prepare the model as we don't have an optimizer + model = self.accelerator.prepare(self.model) + # using the model we prepared to create the optimizer + self.create_optimizer(model) + self.optimizer = self.accelerator.prepare(self.optimizer) + elif self.is_deepspeed_enabled and isinstance(self.lr_scheduler, DummyScheduler): + model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( + self.model, self.optimizer, self.lr_scheduler + ) else: model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) else: @@ -1588,26 +1594,35 @@ def _setup_training(self, args, max_steps, resume_from_checkpoint, train_dataloa # Create scheduler now that the optimizer won't change anymore self.create_scheduler(num_training_steps=max_steps) - # since DataLoader was Accelerate prepared w/o a model arg in the same call, we now have to complete the DL wrapping for 
ALST/UlyssesSP, after model has been prepared - pc = getattr(self.accelerator, "parallelism_config", None) - if pc is not None and pc.sp_backend == "deepspeed" and pc.sp_enabled: - train_dataloader = self.accelerator.deepspeed_ulysses_dl_adapter(train_dataloader, model) + # updating self.model_wrapped + self.model_wrapped = model - if self.is_fsdp_enabled: + if self.is_fsdp_enabled or self.is_fsdp_xla_enabled: + # breaking convention for FSDP model + # TODO: check if this is really needed self.model = self.model_wrapped = model - # Fix `got mixed torch.Tensor and DTensor` error in model.generate() for FSDP2 with LoRA - if hasattr(self.model, "generate"): - dist.fsdp.register_fsdp_forward_method(self.model, "generate") - - # for the rest of this function `model` is the outside model, whether it was wrapped or not - if model is not self.model: - self.model_wrapped = model # backward compatibility + # TODO: check if we really need this if self.is_deepspeed_enabled: self.deepspeed = self.model_wrapped - # ckpt loading + # Important: at this point: + # self.model is the Transformers Model except when we are using FSDP + # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), + # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc. + + if self.is_fsdp_enabled: + # Fix `got mixed torch.Tensor and DTensor` error in model.generate() for FSDP2 with LoRA + if hasattr(self.model, "generate"): + dist.fsdp.register_fsdp_forward_method(self.model, "generate") + + # since DataLoader was Accelerate prepared w/o a model arg in the same call, we now have to complete the DL wrapping for ALST/UlyssesSP, after model has been prepared + pc = getattr(self.accelerator, "parallelism_config", None) + if pc is not None and pc.sp_backend == "deepspeed" and pc.sp_enabled: + train_dataloader = self.accelerator.deepspeed_ulysses_dl_adapter(train_dataloader, model) + + # load checkpoint if resume_from_checkpoint is not None: if self.is_deepspeed_enabled: deepspeed_load_checkpoint( @@ -1616,98 +1631,21 @@ def _setup_training(self, args, max_steps, resume_from_checkpoint, train_dataloa elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled: self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped) - # Check if saved optimizer or scheduler states exist - self._load_optimizer_and_scheduler(resume_from_checkpoint) - self._load_scaler(resume_from_checkpoint) - - # important: at this point: - # self.model is the Transformers Model - # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), - # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc. - - return model, train_dataloader - - def _init_loop_state( - self, - args, - model, - num_update_steps_per_epoch, - num_train_epochs, - max_steps, - total_train_batch_size, - num_examples, - train_dataloader, - resume_from_checkpoint, - trial, - ): - """Initialize training loop state. Returns (epochs_trained, steps_trained_in_current_epoch, start_time).""" - # Train! - logger.info("***** Running training *****") - logger.info(f" Num examples = {num_examples:,}") - logger.info(f" Num Epochs = {num_train_epochs:,}") - logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}") - if self.args.per_device_train_batch_size != self._train_batch_size: - logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}") - logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_train_batch_size:,}") - logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {max_steps:,}") - logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}") - - self.state.epoch = 0 - start_time = time.time() - self.initial_num_input_tokens_seen_for_session = self.state.num_input_tokens_seen - epochs_trained = 0 - steps_trained_in_current_epoch = 0 - - # Check if continuing training from a checkpoint - if resume_from_checkpoint is not None and os.path.isfile( - os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) - ): - self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) - compare_trainer_and_checkpoint_args(self.args, self.state) - self._load_callback_state() - epochs_trained = int(self.state.global_step // num_update_steps_per_epoch) - if not args.ignore_data_skip: - steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) - steps_trained_in_current_epoch *= args.gradient_accumulation_steps - else: - steps_trained_in_current_epoch = 0 - - logger.info(" Continuing training from checkpoint, will skip to saved global_step") - logger.info(f" Continuing training from epoch {epochs_trained}") - logger.info(f" Continuing training from global step {self.state.global_step}") - if not args.ignore_data_skip: - logger.info( - f" Will skip the first {epochs_trained} epochs then the first" - f" {steps_trained_in_current_epoch} batches in the first epoch." - ) + self._load_optimizer_and_scheduler(resume_from_checkpoint) + self._load_scaler(resume_from_checkpoint) - # Update the references + # Update the references for the callback_handler for attr in ("model", "optimizer", "lr_scheduler"): setattr(self.callback_handler, attr, getattr(self, attr)) self.callback_handler.train_dataloader = train_dataloader - self.state.init_training_references(self, max_steps, num_train_epochs, trial) - - # tr_loss is a tensor to avoid synchronization of TPUs through .item() - self._tr_loss = torch.tensor(0.0, device=args.device) - # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses - self._total_loss_scalar = 0.0 - self._globalstep_last_logged = self.state.global_step - model.zero_grad() - self._grad_norm: float | None = None - self._learning_rate = None - self.control = self.callback_handler.on_train_begin(args, self.state, self.control) - - return epochs_trained, steps_trained_in_current_epoch, start_time + return model, train_dataloader def _run_epoch( self, model, epoch, train_dataloader, - len_dataloader, - args, trial, ignore_keys_for_eval, start_time, @@ -1716,49 +1654,56 @@ def _run_epoch( steps_trained_in_current_epoch, ): """Run one full pass over the dataloader.""" - epoch_dataloader = train_dataloader - steps_in_epoch = ( - len(epoch_dataloader) - if len_dataloader is not None - else args.max_steps * args.gradient_accumulation_steps + len(train_dataloader) + if has_length(train_dataloader) is not None + else self.args.max_steps * self.args.gradient_accumulation_steps ) - self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) step = -1 + grad_norm: float | None = None + learning_rate = None rng_to_sync = False - # Handle resumption from checkpoint + self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control) + + # Handle resumption from checkpoint: skip 
already-trained batches in the resumed epoch if epoch == epochs_trained and resume_from_checkpoint is not None: - if steps_trained_in_current_epoch > 0 and not args.ignore_data_skip: - epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch) + if steps_trained_in_current_epoch > 0 and not self.args.ignore_data_skip: + train_dataloader = skip_first_batches(train_dataloader, steps_trained_in_current_epoch) step = steps_trained_in_current_epoch - 1 rng_to_sync = True elif steps_trained_in_current_epoch == 0: self._load_rng_state(resume_from_checkpoint) - if hasattr(epoch_dataloader, "set_epoch"): - epoch_dataloader.set_epoch(epoch) + if hasattr(train_dataloader, "set_epoch"): + train_dataloader.set_epoch(epoch) - epoch_iterator = iter(epoch_dataloader) + epoch_iterator = iter(train_dataloader) # We chunkify the epoch iterator into gradient accumulation steps `n` batches - remainder = steps_in_epoch % args.gradient_accumulation_steps + remainder = steps_in_epoch % self.args.gradient_accumulation_steps if remainder == 0: - remainder = args.gradient_accumulation_steps + remainder = self.args.gradient_accumulation_steps update_step = -1 - total_updates = steps_in_epoch // args.gradient_accumulation_steps + int( - remainder < args.gradient_accumulation_steps + total_updates = steps_in_epoch // self.args.gradient_accumulation_steps + int( + remainder < self.args.gradient_accumulation_steps ) for _ in range(total_updates): update_step += 1 - num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder - batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device) + num_batches = self.args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder + batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, self.args.device) # Store the number of batches for current gradient accumulation # This is used to correctly scale the loss when the last accumulation step has fewer batches self.current_gradient_accumulation_steps = len(batch_samples) + + # need to sync after we skipped the batched in `get_batch_samples` + if rng_to_sync: + self._load_rng_state(resume_from_checkpoint) + rng_to_sync = False + for i, inputs in enumerate(batch_samples): step += 1 - do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch + do_sync_step = (step + 1) % self.args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch # Since we perform prefetching, we need to manually set sync_gradients self.accelerator.gradient_state._set_sync_gradients(do_sync_step) @@ -1779,9 +1724,7 @@ def _run_epoch( and hasattr(self.processing_class, "pad_token_id") and self.processing_class.pad_token_id is not None ): - input_tokens = ( - inputs[main_input_name] != self.processing_class.pad_token_id - ).sum() + input_tokens = (inputs[main_input_name] != self.processing_class.pad_token_id).sum() else: logger.warning( "Could not determine method to count non-padding tokens, falling back to counting all tokens." 
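For reference, the accumulation chunking introduced in the hunk above splits each epoch's batches into gradient-accumulation groups. The short standalone sketch below (illustrative names only, not Trainer code) shows the same idea in isolation, including the shorter final group when the batch count does not divide evenly.

from itertools import islice

def chunk_batches(epoch_iterator, steps_in_epoch, grad_accum_steps):
    # Size of the final group: the leftover batches, or a full group if it divides evenly.
    remainder = steps_in_epoch % grad_accum_steps
    if remainder == 0:
        remainder = grad_accum_steps
    total_updates = steps_in_epoch // grad_accum_steps + int(remainder < grad_accum_steps)
    for update_step in range(total_updates):
        num_batches = grad_accum_steps if update_step != total_updates - 1 else remainder
        # Pull exactly `num_batches` batches for this optimizer step.
        yield list(islice(epoch_iterator, num_batches))

# 10 batches with 4 accumulation steps -> groups of 4, 4 and 2
for group in chunk_batches(iter(range(10)), steps_in_epoch=10, grad_accum_steps=4):
    print(group)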
@@ -1793,12 +1736,8 @@ def _run_epoch( input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64) self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item() - if rng_to_sync: - self._load_rng_state(resume_from_checkpoint) - rng_to_sync = False - - if step % args.gradient_accumulation_steps == 0: - self.control = self.callback_handler.on_step_begin(args, self.state, self.control) + if step % self.args.gradient_accumulation_steps == 0: + self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control) # We sync the gradients in the following cases: 1. sync_each_batch set to True 2. Using deepspeed 3. when we are at the last batch sample if ( @@ -1813,7 +1752,7 @@ def _run_epoch( tr_loss_step = self.training_step(model, inputs, num_items_in_batch) if ( - args.logging_nan_inf_filter + self.args.logging_nan_inf_filter and not is_torch_xla_available() and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) ): @@ -1833,9 +1772,9 @@ def _run_epoch( self.accelerator.gradient_state._set_sync_gradients(True) # Gradient clipping - if args.max_grad_norm is not None and args.max_grad_norm > 0: - if is_sagemaker_mp_enabled() and args.fp16: - _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm) + if self.args.max_grad_norm is not None and self.args.max_grad_norm > 0: + if is_sagemaker_mp_enabled() and self.args.fp16: + _grad_norm = self.optimizer.clip_master_grads(self.args.max_grad_norm) else: grad_norm_context = contextlib.nullcontext if self.is_tp_enabled: @@ -1845,18 +1784,18 @@ def _run_epoch( with grad_norm_context(): _grad_norm = self.accelerator.clip_grad_norm_( model.parameters(), - args.max_grad_norm, + self.args.max_grad_norm, ) if self.accelerator.distributed_type == DistributedType.DEEPSPEED: - self._grad_norm = model.get_global_grad_norm() + grad_norm = model.get_global_grad_norm() # In some cases the grad norm may not return a float - if hasattr(self._grad_norm, "item"): - self._grad_norm = self._grad_norm.item() + if hasattr(grad_norm, "item"): + grad_norm = grad_norm.item() else: - self._grad_norm = _grad_norm + grad_norm = _grad_norm - self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control) + self.control = self.callback_handler.on_pre_optimizer_step(self.args, self.state, self.control) context = contextlib.nullcontext if self.is_tp_enabled: @@ -1867,10 +1806,9 @@ def _run_epoch( with context(): self.optimizer.step() - self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control) + self.control = self.callback_handler.on_optimizer_step(self.args, self.state, self.control) - # get learning rate before update - self._learning_rate = self._get_learning_rate() + learning_rate = self._get_learning_rate() if not self.accelerator.optimizer_step_was_skipped: # Delay optimizer scheduling until metrics are generated @@ -1880,19 +1818,19 @@ def _run_epoch( model.zero_grad() self.state.global_step += 1 self.state.epoch = epoch + (step + 1) / steps_in_epoch - self.control = self.callback_handler.on_step_end(args, self.state, self.control) + self.control = self.callback_handler.on_step_end(self.args, self.state, self.control) self._maybe_log_save_evaluate( self._tr_loss, - self._grad_norm, + grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, - learning_rate=self._learning_rate, + learning_rate=learning_rate, ) else: - self.control = self.callback_handler.on_substep_end(args, self.state, self.control) + self.control = 
self.callback_handler.on_substep_end(self.args, self.state, self.control) # PyTorch/XLA relies on the data loader to insert the mark_step for # each step. Since we are breaking the loop early, we need to manually @@ -1914,10 +1852,16 @@ def _run_epoch( ) self.control.should_training_stop = True - self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) + self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control) self._maybe_log_save_evaluate( - self._tr_loss, self._grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, - learning_rate=self._learning_rate, + self._tr_loss, + grad_norm, + model, + trial, + epoch, + ignore_keys_for_eval, + start_time, + learning_rate=learning_rate, ) if DebugOption.TPU_METRICS_DEBUG in self.args.debug: @@ -2402,17 +2346,15 @@ def _prepare_context_parallel_inputs( def set_initial_training_values( self, args: TrainingArguments, dataloader: DataLoader - ) -> tuple[int, int, int, int, bool, int | None, int, int]: + ) -> tuple[int, int, int, int, int, int | None, int]: """ Calculates and returns the following values: - `num_train_epochs` - `num_update_steps_per_epoch` - `num_examples` - `num_train_samples` - - `epoch_based` - - `len_dataloader` - - `max_steps` - `total_train_batch_size` + - `max_steps` """ # Case 1: we rely on `args.max_steps` first max_steps = args.max_steps @@ -2466,10 +2408,8 @@ def set_initial_training_values( num_update_steps_per_epoch, num_examples, num_train_samples, - epoch_based, - len_dataloader, - max_steps, total_train_batch_size, + max_steps, ) def get_total_train_batch_size(self, args: TrainingArguments) -> int: @@ -2522,16 +2462,16 @@ def get_tp_size(self) -> int: def _wrap_model(self, model: nn.Module, training: bool = True, dataloader: DataLoader | None = None) -> nn.Module: """Wrap `model` for distributed training if needed (DDP, FSDP, SageMaker, etc.).""" - if is_sagemaker_mp_enabled(): - # Wrapping the base model twice in a DistributedModel will raise an error. - if isinstance(self.model_wrapped, smp.model.DistributedModel): - return self.model_wrapped - return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps) - # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again if self.accelerator.unwrap_model(model, keep_torch_compile=False) is not model: return model + if is_sagemaker_mp_enabled(): + # Wrapping the base model twice in a DistributedModel will raise an error. 
+ if isinstance(model, smp.model.DistributedModel): + return model + return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps) + # Multi-gpu training, 8bit models does not support DP if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False): model = nn.DataParallel(model) @@ -2543,33 +2483,35 @@ def _wrap_model(self, model: nn.Module, training: bool = True, dataloader: DataL # Distributed training using PyTorch FSDP if self.is_fsdp_xla_enabled: - self.model = model = wrap_model_xla_fsdp(model, self.args, self.is_fsdp_xla_v2_enabled) + model = wrap_model_xla_fsdp(model, self.args, self.is_fsdp_xla_v2_enabled) elif is_sagemaker_dp_enabled(): model = nn.parallel.DistributedDataParallel( model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))] ) - elif self.args.parallel_mode == ParallelMode.DISTRIBUTED: - if is_torch_neuroncore_available(): - return model - kwargs = {} - if self.args.ddp_find_unused_parameters is not None: - kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters - elif isinstance(model, PreTrainedModel): - # find_unused_parameters breaks checkpointing as per - # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 - kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing - else: - kwargs["find_unused_parameters"] = True + return model - if self.args.ddp_bucket_cap_mb is not None: - kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb + def _update_auto_batch_size(self, batch_size): + """Free memory, reset model wrapping, and update DeepSpeed config for the new batch size when using `auto_find_batch_size`""" + # reset everything + self.accelerator.free_memory() + # `_train_batch_size` value might have changed to `auto_find_batch_size` + self._train_batch_size = batch_size + # frees the wrapped model and resets it back to the unwrapped base model + release_memory(self.model_wrapped) - if self.args.ddp_broadcast_buffers is not None: - kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers + if self.is_fsdp_enabled: + # Remove FSDP wrapping from sub-models because self.model points to the wrapped model in FSDP case + self.model = unwrap_model(self.model, recursive=True) - self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs) + self.model_wrapped = self.model - return model + # Check for DeepSpeed *after* the initial pass and modify the config + if self.is_deepspeed_enabled: + # Temporarily unset `self.args.train_batch_size` + original_bs = self.args.per_device_train_batch_size + self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu) + propagate_args_to_deepspeed(self.accelerator, self.args, auto_find_batch_size=True) + self.args.per_device_train_batch_size = original_bs # ---- Evaluation & Prediction ---- @@ -2695,14 +2637,13 @@ def evaluation_loop( if self.is_deepspeed_enabled and self.deepspeed is None: _, _ = deepspeed_init(self, num_training_steps=0, inference=True) - model = self._wrap_model(self.model, training=False, dataloader=dataloader) + model = self._wrap_model(self.model, training=False) if len(self.accelerator._models) == 0 and model is self.model: start_time = time.time() model = ( self.accelerator.prepare(model) - if self.is_deepspeed_enabled - or (self.is_fsdp_enabled and self.accelerator.mixed_precision != "fp8" and not self.args.torch_compile) + if self.is_deepspeed_enabled or (self.is_fsdp_enabled and not self.args.torch_compile) else self.accelerator.prepare_model(model, 
evaluation_mode=True) ) self.model_preparation_time = round(time.time() - start_time, 4) @@ -4362,6 +4303,7 @@ def _hp_search_setup(self, trial: "optuna.Trial | dict[str, Any] | None") -> Non # Simply calling `_reset_state` is enough and doesn't need a version pin. AcceleratorState()._reset_state() + # `train_batch_size` might change when using HPO https://github.com/huggingface/transformers/pull/18918 self._train_batch_size = self.args.train_batch_size self.create_accelerator_and_postprocess() diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py index 92d61eba1ca2..ac9c5b164c9b 100644 --- a/src/transformers/trainer_callback.py +++ b/src/transformers/trainer_callback.py @@ -92,7 +92,7 @@ class TrainerState: Relevant callbacks should implement a `state` and `from_state` function. """ - epoch: float | None = None + epoch: float = 0 global_step: int = 0 max_steps: int = 0 logging_steps: int = 500 diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index aa8717dc8e90..74a94b32e8a9 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -15,6 +15,7 @@ PyTorch-independent utilities for the Trainer class. """ +import contextlib import copy import functools import gc @@ -1237,3 +1238,15 @@ def align_special_tokens(model, processing_class): "The model config and generation config were aligned accordingly, being updated with the tokenizer's " f"values. Updated tokens: {updated_tokens}." ) + + +@contextlib.contextmanager +def suppress_progress_bars(): + """Context manager that suppresses huggingface_hub progress bars.""" + import huggingface_hub.utils as hf_hub_utils + + hf_hub_utils.disable_progress_bars() + try: + yield + finally: + hf_hub_utils.enable_progress_bars() From c9f361036bb035f97684a41bc42cd6d30d2b7044 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 19 Feb 2026 15:35:12 +0000 Subject: [PATCH 05/19] move DummyScheduler --- src/transformers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 4322d3a0f0d1..fc2a50a8f732 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -217,7 +217,6 @@ DataLoaderConfiguration, DistributedDataParallelKwargs, DistributedType, - DummyScheduler, GradientAccumulationPlugin, load_fsdp_model, load_fsdp_optimizer, @@ -1573,6 +1572,7 @@ def _prepare_model_and_optimizer(self, max_steps, train_dataloader, resume_from_ # prepare using `accelerator` prepare if use_accelerator_prepare: + from accelerate.utils import DummyScheduler if delay_optimizer_creation: # TODO: check if we can move this somewhere else if self.is_fsdp_enabled and _is_peft_model(self.model): From f4cad42b55d779138fc0a5a7e4635f4b83ebc489 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 19 Feb 2026 15:36:31 +0000 Subject: [PATCH 06/19] style --- src/transformers/trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index fc2a50a8f732..36395cdbe0bc 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1573,6 +1573,7 @@ def _prepare_model_and_optimizer(self, max_steps, train_dataloader, resume_from_ # prepare using `accelerator` prepare if use_accelerator_prepare: from accelerate.utils import DummyScheduler + if delay_optimizer_creation: # TODO: check if we can move this somewhere else if self.is_fsdp_enabled and _is_peft_model(self.model): From 46e20417f7a8fe7f44f98102303f41daaa7f7f74 Mon Sep 17 
00:00:00 2001 From: Marc Sun Date: Thu, 19 Feb 2026 15:41:15 +0000 Subject: [PATCH 07/19] make it private --- src/transformers/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 36395cdbe0bc..277d7dded0ed 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1482,7 +1482,7 @@ def _inner_training_loop( start_time = time.time() # needed to calculate tokens/s - self.initial_num_input_tokens_seen_for_session = self.state.num_input_tokens_seen + self._initial_num_input_tokens_seen = self.state.num_input_tokens_seen # Logging state: _tr_loss accumulates on-device between logging steps (avoiding costly .item() syncs # on TPUs), then gets drained into _total_loss_scalar at each logging step. self._tr_loss = torch.tensor(0.0, device=args.device) @@ -3865,7 +3865,7 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: logs["num_input_tokens_seen"] = self.state.num_input_tokens_seen if start_time is not None: current_session_num_tokens = ( - self.state.num_input_tokens_seen - self.initial_num_input_tokens_seen_for_session + self.state.num_input_tokens_seen - self._initial_num_input_tokens_seen ) logs.update(speed_metrics("train", start_time, num_tokens=current_session_num_tokens)) From da07a7179f28b845f30ec33c69cc1378497ac443 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 19 Feb 2026 15:53:28 +0000 Subject: [PATCH 08/19] switch to kwargs_handlers --- src/transformers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 277d7dded0ed..9d094c7f736a 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -716,7 +716,7 @@ def _build_accelerator_args(self, **kwargs) -> dict[str, Any]: if self.args.ddp_broadcast_buffers is not None: ddp_kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers - args["ddp_handler"] = DistributedDataParallelKwargs(**ddp_kwargs) + args["kwargs_handlers"] = [DistributedDataParallelKwargs(**ddp_kwargs)] # We defer compatibility checks to accelerator if self.args.parallelism_config is not None: From c6bee6a7a9c37bee6b4e421c58b8906de0abe084 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 19 Feb 2026 15:54:45 +0000 Subject: [PATCH 09/19] style --- src/transformers/trainer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 9d094c7f736a..d106cb61d0e4 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3864,9 +3864,7 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: if self.args.include_num_input_tokens_seen != "no": logs["num_input_tokens_seen"] = self.state.num_input_tokens_seen if start_time is not None: - current_session_num_tokens = ( - self.state.num_input_tokens_seen - self._initial_num_input_tokens_seen - ) + current_session_num_tokens = self.state.num_input_tokens_seen - self._initial_num_input_tokens_seen logs.update(speed_metrics("train", start_time, num_tokens=current_session_num_tokens)) output = {**logs, "step": self.state.global_step} From daec1155e2d9ef3685057917f742a04462e9e316 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 19 Feb 2026 16:57:40 +0000 Subject: [PATCH 10/19] fix resuming --- src/transformers/trainer.py | 65 ++++++++++++++++++------------------- 1 file changed, 31 insertions(+), 34 deletions(-) diff --git a/src/transformers/trainer.py 
b/src/transformers/trainer.py index d106cb61d0e4..83a805ee88a7 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1448,6 +1448,7 @@ def _inner_training_loop( num_examples, num_train_samples, total_train_batch_size, + steps_in_epoch, max_steps, ) = self.set_initial_training_values(args, train_dataloader) @@ -1462,6 +1463,7 @@ def _inner_training_loop( logger.info("***** Running training *****") logger.info(f" Num examples = {num_examples:,}") logger.info(f" Num Epochs = {num_train_epochs:,}") + logger.info(f" Num update steps per epoch = {num_update_steps_per_epoch:,}") logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}") if self.args.per_device_train_batch_size != self._train_batch_size: logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}") @@ -1497,10 +1499,14 @@ def _inner_training_loop( self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True) for epoch in range(epochs_trained, num_train_epochs): + + self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control) self._run_epoch( model=model, epoch=epoch, train_dataloader=train_dataloader, + steps_in_epoch=steps_in_epoch, + num_update_steps_per_epoch=num_update_steps_per_epoch, trial=trial, ignore_keys_for_eval=ignore_keys_for_eval, start_time=start_time, @@ -1572,8 +1578,6 @@ def _prepare_model_and_optimizer(self, max_steps, train_dataloader, resume_from_ # prepare using `accelerator` prepare if use_accelerator_prepare: - from accelerate.utils import DummyScheduler - if delay_optimizer_creation: # TODO: check if we can move this somewhere else if self.is_fsdp_enabled and _is_peft_model(self.model): @@ -1583,7 +1587,7 @@ def _prepare_model_and_optimizer(self, max_steps, train_dataloader, resume_from_ # using the model we prepared to create the optimizer self.create_optimizer(model) self.optimizer = self.accelerator.prepare(self.optimizer) - elif self.is_deepspeed_enabled and isinstance(self.lr_scheduler, DummyScheduler): + elif self.is_deepspeed_enabled and type(self.lr_scheduler).__name__ == "DummyScheduler": model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( self.model, self.optimizer, self.lr_scheduler ) @@ -1647,6 +1651,8 @@ def _run_epoch( model, epoch, train_dataloader, + steps_in_epoch, + num_update_steps_per_epoch, trial, ignore_keys_for_eval, start_time, @@ -1655,53 +1661,51 @@ def _run_epoch( steps_trained_in_current_epoch, ): """Run one full pass over the dataloader.""" - steps_in_epoch = ( - len(train_dataloader) - if has_length(train_dataloader) is not None - else self.args.max_steps * self.args.gradient_accumulation_steps - ) step = -1 - grad_norm: float | None = None + grad_norm = None learning_rate = None rng_to_sync = False - self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control) - # Handle resumption from checkpoint: skip already-trained batches in the resumed epoch + num_update_steps_trained = 0 if epoch == epochs_trained and resume_from_checkpoint is not None: if steps_trained_in_current_epoch > 0 and not self.args.ignore_data_skip: train_dataloader = skip_first_batches(train_dataloader, steps_trained_in_current_epoch) step = steps_trained_in_current_epoch - 1 + num_update_steps_trained = steps_trained_in_current_epoch // self.args.gradient_accumulation_steps rng_to_sync = True elif steps_trained_in_current_epoch == 0: self._load_rng_state(resume_from_checkpoint) if hasattr(train_dataloader, "set_epoch"): 
train_dataloader.set_epoch(epoch) - epoch_iterator = iter(train_dataloader) + # We chunkify the epoch iterator into gradient accumulation steps `n` batches remainder = steps_in_epoch % self.args.gradient_accumulation_steps if remainder == 0: remainder = self.args.gradient_accumulation_steps - update_step = -1 - total_updates = steps_in_epoch // self.args.gradient_accumulation_steps + int( - remainder < self.args.gradient_accumulation_steps - ) - for _ in range(total_updates): - update_step += 1 - num_batches = self.args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder + + # Outer loop: one iteration per optimizer step. Each iteration prefetches + # `gradient_accumulation_steps` batches (fewer for the last step if the epoch + # doesn't divide evenly). + for update_step in range(num_update_steps_trained, num_update_steps_per_epoch): + num_batches = self.args.gradient_accumulation_steps if update_step != (num_update_steps_per_epoch - 1) else remainder batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, self.args.device) - # Store the number of batches for current gradient accumulation - # This is used to correctly scale the loss when the last accumulation step has fewer batches + + # This is used to correctly scale the loss when the last accumulation step has fewer batches. + # Not used if `num_items_in_batch` is not None. self.current_gradient_accumulation_steps = len(batch_samples) - # need to sync after we skipped the batched in `get_batch_samples` + # need to sync after if we skipped the batches in `get_batch_samples` for shuffle order reason if rng_to_sync: self._load_rng_state(resume_from_checkpoint) rng_to_sync = False + # Inner loop: forward + backward for each micro-batch. Gradients are + # accumulated without syncing until the last micro-batch, then we clip, + # step the optimizer, and log/save/evaluate. for i, inputs in enumerate(batch_samples): step += 1 do_sync_step = (step + 1) % self.args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch @@ -1864,16 +1868,8 @@ def _run_epoch( start_time, learning_rate=learning_rate, ) - if DebugOption.TPU_METRICS_DEBUG in self.args.debug: - if is_torch_xla_available(): - # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) - xm.master_print(met.metrics_report()) - else: - logger.warning( - "You enabled PyTorch/XLA debug metrics but you don't have a TPU " - "configured. Check your training configuration if this is unexpected." - ) + xm.master_print(met.metrics_report()) def _finalize_training(self, trial, num_train_samples, start_time): """Finalize training: metrics, best-model loading, cleanup. 
Returns TrainOutput.""" @@ -2355,6 +2351,7 @@ def set_initial_training_values( - `num_examples` - `num_train_samples` - `total_train_batch_size` + - `steps_in_epoch` (total batches per epoch) - `max_steps` """ # Case 1: we rely on `args.max_steps` first @@ -2378,8 +2375,7 @@ def set_initial_training_values( ) # Case 3: We have a length but are using epochs, we can extrapolate the number of steps if epoch_based: - max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) - + max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) # Now we figure out `num_examples`, `num_train_epochs`, and `train_samples` if len_dataloader: num_examples = self.num_examples(dataloader) @@ -2404,12 +2400,14 @@ def set_initial_training_values( "args.max_steps must be set to a positive value if dataloader does not have a length, was" f" {args.max_steps}" ) + steps_in_epoch = len_dataloader if len_dataloader is not None else max_steps * args.gradient_accumulation_steps return ( num_train_epochs, num_update_steps_per_epoch, num_examples, num_train_samples, total_train_batch_size, + steps_in_epoch, max_steps, ) @@ -2608,7 +2606,6 @@ def evaluate( self.log(output.metrics) if DebugOption.TPU_METRICS_DEBUG in self.args.debug: - # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) xm.master_print(met.metrics_report()) self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics) From a434736229445a7d6fba4f7dc506659b535d1f8f Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 19 Feb 2026 16:58:30 +0000 Subject: [PATCH 11/19] style --- src/transformers/trainer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 83a805ee88a7..2ffec505d6cf 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1499,7 +1499,6 @@ def _inner_training_loop( self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True) for epoch in range(epochs_trained, num_train_epochs): - self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control) self._run_epoch( model=model, @@ -1691,7 +1690,9 @@ def _run_epoch( # `gradient_accumulation_steps` batches (fewer for the last step if the epoch # doesn't divide evenly). for update_step in range(num_update_steps_trained, num_update_steps_per_epoch): - num_batches = self.args.gradient_accumulation_steps if update_step != (num_update_steps_per_epoch - 1) else remainder + num_batches = ( + self.args.gradient_accumulation_steps if update_step != (num_update_steps_per_epoch - 1) else remainder + ) batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, self.args.device) # This is used to correctly scale the loss when the last accumulation step has fewer batches. 
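The hunks above and below adjust how the Trainer derives its step counts. As a rough standalone illustration (simplified signature and names, not the actual `set_initial_training_values` implementation) of how the update-step, epoch, and batch counts relate when the dataloader length is known:

import math

def derive_step_counts(len_dataloader, grad_accum_steps, num_train_epochs, max_steps=-1):
    # One optimizer update per `grad_accum_steps` micro-batches, at least one per epoch.
    num_update_steps_per_epoch = max(len_dataloader // grad_accum_steps, 1)
    if max_steps > 0:
        # Step-based training: infer how many (possibly partial) epochs are needed.
        num_train_epochs = math.ceil(max_steps / num_update_steps_per_epoch)
    else:
        # Epoch-based training: extrapolate the total number of optimizer steps.
        max_steps = math.ceil(num_train_epochs * num_update_steps_per_epoch)
    steps_in_epoch = len_dataloader  # total micro-batches per epoch
    return num_update_steps_per_epoch, num_train_epochs, max_steps, steps_in_epoch

# e.g. 1000 batches, accumulation of 8, 3 epochs -> (125, 3, 375, 1000)
print(derive_step_counts(1000, 8, 3))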
@@ -2375,7 +2376,7 @@ def set_initial_training_values( ) # Case 3: We have a length but are using epochs, we can extrapolate the number of steps if epoch_based: - max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) + max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) # Now we figure out `num_examples`, `num_train_epochs`, and `train_samples` if len_dataloader: num_examples = self.num_examples(dataloader) From fa551548ab1da22b5a14dd35a66e2bda811a3553 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 19 Feb 2026 18:24:33 +0000 Subject: [PATCH 12/19] revert loading after --- src/transformers/trainer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 2ffec505d6cf..1bacf23e8cae 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1407,6 +1407,11 @@ def train( raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})") if resume_from_checkpoint is not None: + # Load model checkpoint before accelerator.prepare() for regular models, + # so that buffers and parameters are on the right device after prepare. + # Deepspeed/FSDP models are loaded after prepare in _prepare_model_and_optimizer. + if not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled and not self.is_fsdp_enabled: + self._load_from_checkpoint(resume_from_checkpoint) state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) if state.train_batch_size is not None and args.auto_find_batch_size: # Only restore the checkpoint's train_batch_size when using auto_find_batch_size, From c18c16772cbfbb4d8cd79ca293ce509549b39412 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 19 Feb 2026 19:23:25 +0000 Subject: [PATCH 13/19] way better now --- src/transformers/trainer.py | 135 +++++++++++++++++------------------- 1 file changed, 62 insertions(+), 73 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 1bacf23e8cae..2fee830286a5 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1650,6 +1650,56 @@ def _prepare_model_and_optimizer(self, max_steps, train_dataloader, resume_from_ return model, train_dataloader + def _track_num_input_tokens(self, inputs): + """Count input tokens seen (all or non-padding) and update state.""" + if self.args.include_num_input_tokens_seen != "no": + return + main_input_name = getattr(self.model, "main_input_name", "input_ids") + if main_input_name not in inputs: + logger.warning( + "Tried to track the number of tokens seen, however the current model is " + "not configured properly to know what item is the input. To fix this, add " + "a `main_input_name` attribute to the model class you are using." + ) + return + + if self.args.include_num_input_tokens_seen == "non_padding": + if "attention_mask" in inputs: + input_tokens = inputs["attention_mask"].sum() + elif ( + self.processing_class is not None + and hasattr(self.processing_class, "pad_token_id") + and self.processing_class.pad_token_id is not None + ): + input_tokens = (inputs[main_input_name] != self.processing_class.pad_token_id).sum() + else: + logger.warning( + "Could not determine method to count non-padding tokens, falling back to counting all tokens." 
+ ) + input_tokens = inputs[main_input_name].numel() + else: + input_tokens = inputs[main_input_name].numel() + + input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64) + self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item() + + def _clip_grad_norm(self, model): + """Clip gradients to max_grad_norm. Returns the pre-clip gradient norm.""" + if is_sagemaker_mp_enabled() and self.args.fp16: + return self.optimizer.clip_master_grads(self.args.max_grad_norm) + return self.accelerator.clip_grad_norm_(model.parameters(), self.args.max_grad_norm) + + def _get_grad_norm(self, model, grad_norm=None): + """Return the gradient norm as a Python float.""" + if grad_norm is None: + # Compute norm without clipping (inf means no actual clipping happens) + grad_norm = self.accelerator.clip_grad_norm_(model.parameters(), float("inf")) + + if self.accelerator.distributed_type == DistributedType.DEEPSPEED: + if hasattr(grad_norm, "item"): + grad_norm = grad_norm.item() + return grad_norm + def _run_epoch( self, model, @@ -1718,35 +1768,6 @@ def _run_epoch( # Since we perform prefetching, we need to manually set sync_gradients self.accelerator.gradient_state._set_sync_gradients(do_sync_step) - if self.args.include_num_input_tokens_seen != "no": - main_input_name = getattr(self.model, "main_input_name", "input_ids") - if main_input_name not in inputs: - logger.warning( - "Tried to track the number of tokens seen, however the current model is " - "not configured properly to know what item is the input. To fix this, add " - "a `main_input_name` attribute to the model class you are using." - ) - else: - if self.args.include_num_input_tokens_seen == "non_padding": - if "attention_mask" in inputs: - input_tokens = inputs["attention_mask"].sum() - elif ( - self.processing_class is not None - and hasattr(self.processing_class, "pad_token_id") - and self.processing_class.pad_token_id is not None - ): - input_tokens = (inputs[main_input_name] != self.processing_class.pad_token_id).sum() - else: - logger.warning( - "Could not determine method to count non-padding tokens, falling back to counting all tokens." 
- ) - input_tokens = inputs[main_input_name].numel() - else: - input_tokens = inputs[main_input_name].numel() - - input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64) - self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item() - - if step % self.args.gradient_accumulation_steps == 0: self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control) @@ -1777,48 +1798,18 @@ def _run_epoch( self._tr_loss += tr_loss_step self.current_flos += float(self.floating_point_ops(inputs)) + self._track_num_input_tokens(inputs) if do_sync_step: - # Since we perform prefetching, we need to manually set sync_gradients to True - self.accelerator.gradient_state._set_sync_gradients(True) - - # Gradient clipping - if self.args.max_grad_norm is not None and self.args.max_grad_norm > 0: - if is_sagemaker_mp_enabled() and self.args.fp16: - _grad_norm = self.optimizer.clip_master_grads(self.args.max_grad_norm) - else: - grad_norm_context = contextlib.nullcontext - if self.is_tp_enabled: - from torch.distributed._tensor.experimental import implicit_replication - - grad_norm_context = implicit_replication - with grad_norm_context(): - _grad_norm = self.accelerator.clip_grad_norm_( - model.parameters(), - self.args.max_grad_norm, - ) - - if self.accelerator.distributed_type == DistributedType.DEEPSPEED: - grad_norm = model.get_global_grad_norm() - # In some cases the grad norm may not return a float - if hasattr(grad_norm, "item"): - grad_norm = grad_norm.item() - else: - grad_norm = _grad_norm + if self.args.max_grad_norm > 0: + grad_norm = self._clip_grad_norm(model) + grad_norm = self._get_grad_norm(model, grad_norm=grad_norm) self.control = self.callback_handler.on_pre_optimizer_step(self.args, self.state, self.control) - - context = contextlib.nullcontext - if self.is_tp_enabled: - from torch.distributed._tensor.experimental import implicit_replication - - context = implicit_replication - - with context(): - self.optimizer.step() - + self.optimizer.step() self.control = self.callback_handler.on_optimizer_step(self.args, self.state, self.control) + # get learning rate before update learning_rate = self._get_learning_rate() if not self.accelerator.optimizer_step_was_skipped: @@ -1843,18 +1834,16 @@ def _run_epoch( else: self.control = self.callback_handler.on_substep_end(self.args, self.state, self.control) - # PyTorch/XLA relies on the data loader to insert the mark_step for - # each step. Since we are breaking the loop early, we need to manually - # insert the mark_step here. if self.control.should_epoch_stop or self.control.should_training_stop: - if is_torch_xla_available(): - xm.mark_step() break - # We also need to break out of the nested loop if self.control.should_epoch_stop or self.control.should_training_stop: - if is_torch_xla_available(): - xm.mark_step() break + + # PyTorch/XLA relies on the dataloader to insert mark_step each iteration. + # When we break out of the loop early, we flush the pending graph manually. 
+ if is_torch_xla_available(): + xm.mark_step() + if step < 0: logger.warning( "There seems not to be a single sample in your epoch_iterator, stopping training at step" From 2ced96a5f8d26262197b91041dfeeb8c2e1064b7 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 19 Feb 2026 19:25:08 +0000 Subject: [PATCH 14/19] move method --- src/transformers/trainer.py | 100 ++++++++++++++++++------------------ 1 file changed, 50 insertions(+), 50 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 2fee830286a5..9fe3340a3f2e 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1650,56 +1650,6 @@ def _prepare_model_and_optimizer(self, max_steps, train_dataloader, resume_from_ return model, train_dataloader - def _track_num_input_tokens(self, inputs): - """Count input tokens seen (all or non-padding) and update state.""" - if self.args.include_num_input_tokens_seen != "no": - return - main_input_name = getattr(self.model, "main_input_name", "input_ids") - if main_input_name not in inputs: - logger.warning( - "Tried to track the number of tokens seen, however the current model is " - "not configured properly to know what item is the input. To fix this, add " - "a `main_input_name` attribute to the model class you are using." - ) - return - - if self.args.include_num_input_tokens_seen == "non_padding": - if "attention_mask" in inputs: - input_tokens = inputs["attention_mask"].sum() - elif ( - self.processing_class is not None - and hasattr(self.processing_class, "pad_token_id") - and self.processing_class.pad_token_id is not None - ): - input_tokens = (inputs[main_input_name] != self.processing_class.pad_token_id).sum() - else: - logger.warning( - "Could not determine method to count non-padding tokens, falling back to counting all tokens." - ) - input_tokens = inputs[main_input_name].numel() - else: - input_tokens = inputs[main_input_name].numel() - - input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64) - self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item() - - def _clip_grad_norm(self, model): - """Clip gradients to max_grad_norm. Returns the pre-clip gradient norm.""" - if is_sagemaker_mp_enabled() and self.args.fp16: - return self.optimizer.clip_master_grads(self.args.max_grad_norm) - return self.accelerator.clip_grad_norm_(model.parameters(), self.args.max_grad_norm) - - def _get_grad_norm(self, model, grad_norm=None): - """Return the gradient norm as a Python float.""" - if grad_norm is None: - # Compute norm without clipping (inf means no actual clipping happens) - grad_norm = self.accelerator.clip_grad_norm_(model.parameters(), float("inf")) - - if self.accelerator.distributed_type == DistributedType.DEEPSPEED: - if hasattr(grad_norm, "item"): - grad_norm = grad_norm.item() - return grad_norm - def _run_epoch( self, model, @@ -2507,6 +2457,56 @@ def _update_auto_batch_size(self, batch_size): propagate_args_to_deepspeed(self.accelerator, self.args, auto_find_batch_size=True) self.args.per_device_train_batch_size = original_bs + def _track_num_input_tokens(self, inputs): + """Count input tokens seen (all or non-padding) and update state.""" + if self.args.include_num_input_tokens_seen != "no": + return + main_input_name = getattr(self.model, "main_input_name", "input_ids") + if main_input_name not in inputs: + logger.warning( + "Tried to track the number of tokens seen, however the current model is " + "not configured properly to know what item is the input. 
To fix this, add " + "a `main_input_name` attribute to the model class you are using." + ) + return + + if self.args.include_num_input_tokens_seen == "non_padding": + if "attention_mask" in inputs: + input_tokens = inputs["attention_mask"].sum() + elif ( + self.processing_class is not None + and hasattr(self.processing_class, "pad_token_id") + and self.processing_class.pad_token_id is not None + ): + input_tokens = (inputs[main_input_name] != self.processing_class.pad_token_id).sum() + else: + logger.warning( + "Could not determine method to count non-padding tokens, falling back to counting all tokens." + ) + input_tokens = inputs[main_input_name].numel() + else: + input_tokens = inputs[main_input_name].numel() + + input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64) + self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item() + + def _clip_grad_norm(self, model): + """Clip gradients to max_grad_norm. Returns the pre-clip gradient norm.""" + if is_sagemaker_mp_enabled() and self.args.fp16: + return self.optimizer.clip_master_grads(self.args.max_grad_norm) + return self.accelerator.clip_grad_norm_(model.parameters(), self.args.max_grad_norm) + + def _get_grad_norm(self, model, grad_norm=None): + """Return the gradient norm as a Python float.""" + if grad_norm is None: + # Compute norm without clipping (inf means no actual clipping happens) + grad_norm = self.accelerator.clip_grad_norm_(model.parameters(), float("inf")) + + if self.accelerator.distributed_type == DistributedType.DEEPSPEED: + if hasattr(grad_norm, "item"): + grad_norm = grad_norm.item() + return grad_norm + # ---- Evaluation & Prediction ---- def evaluate( From 7feecabefc0202d3df9e7248ec1b6621ec3dddfb Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 19 Feb 2026 19:27:46 +0000 Subject: [PATCH 15/19] small mistake --- src/transformers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 9fe3340a3f2e..a0f402eda3b4 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2459,7 +2459,7 @@ def _update_auto_batch_size(self, batch_size): def _track_num_input_tokens(self, inputs): """Count input tokens seen (all or non-padding) and update state.""" - if self.args.include_num_input_tokens_seen != "no": + if self.args.include_num_input_tokens_seen == "no": return main_input_name = getattr(self.model, "main_input_name", "input_ids") if main_input_name not in inputs: From b37736d9763ea4bc788216ec77419efd7fe5627a Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 19 Feb 2026 20:15:22 +0000 Subject: [PATCH 16/19] update --- src/transformers/testing_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 273edeb88d5d..3ba75a680c44 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -1469,7 +1469,7 @@ def get_steps_per_epoch(trainer: Trainer) -> int: train_dataloader = trainer.get_train_dataloader() initial_training_values = trainer.set_initial_training_values(args=training_args, dataloader=train_dataloader) - steps_per_epoch = initial_training_values[1] + steps_per_epoch = initial_training_values[5] return steps_per_epoch From a697854fcd13b62559c21e5f092ff2dca8c91efd Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 20 Feb 2026 15:52:36 +0000 Subject: [PATCH 17/19] remove tp_size --- src/transformers/trainer.py | 2 -- 1 file changed, 2 
deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index a0f402eda3b4..3b3d2f7bb845 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -727,9 +727,7 @@ def _build_accelerator_args(self, **kwargs) -> dict[str, Any]: ) args["parallelism_config"] = self.args.parallelism_config - self.is_tp_enabled = False if getattr(self.model, "tp_size", None) is not None and self.model.tp_size > 1: - self.is_tp_enabled = True if self.args.parallelism_config is None: if is_accelerate_available("1.12.0"): if self.args.parallelism_config is None: From cc24432005ec5cd1ce05c90328a9d4aa410dcf3e Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 20 Feb 2026 19:14:55 +0000 Subject: [PATCH 18/19] account for grpo --- src/transformers/trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 3b3d2f7bb845..437d4f1c1b16 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1437,6 +1437,7 @@ def _inner_training_loop( ignore_keys_for_eval: list[str] | None = None, ) -> TrainOutput: """Run the actual training loop: forward, backward, optimizer step, logging, and checkpointing.""" + self.accelerator.free_memory() if args.auto_find_batch_size: self._update_auto_batch_size(batch_size) # Data loader and number of training steps From 2e97639a499964f96e8f486c51904738082ced5b Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Mon, 23 Feb 2026 15:54:56 +0000 Subject: [PATCH 19/19] final touch --- src/transformers/trainer.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 437d4f1c1b16..6794c2cca07d 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1407,7 +1407,7 @@ def train( if resume_from_checkpoint is not None: # Load model checkpoint before accelerator.prepare() for regular models, # so that buffers and parameters are on the right device after prepare. - # Deepspeed/FSDP models are loaded after prepare in _prepare_model_and_optimizer. + # Deepspeed/FSDP models are loaded after prepare in _prepare_for_training. if not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled and not self.is_fsdp_enabled: self._load_from_checkpoint(resume_from_checkpoint) state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) @@ -1437,6 +1437,7 @@ def _inner_training_loop( ignore_keys_for_eval: list[str] | None = None, ) -> TrainOutput: """Run the actual training loop: forward, backward, optimizer step, logging, and checkpointing.""" + # reset everything self.accelerator.free_memory() if args.auto_find_batch_size: self._update_auto_batch_size(batch_size) @@ -1459,9 +1460,7 @@ def _inner_training_loop( epochs_trained, steps_trained_in_current_epoch = self._init_training_state( max_steps, num_update_steps_per_epoch, num_train_epochs, resume_from_checkpoint, trial ) - model, train_dataloader = self._prepare_model_and_optimizer( - max_steps, train_dataloader, resume_from_checkpoint - ) + model, train_dataloader = self._prepare_for_training(max_steps, train_dataloader, resume_from_checkpoint) # Train! 
logger.info("***** Running training *****") @@ -1553,7 +1552,7 @@ def _init_training_state( return epochs_trained, steps_trained_in_current_epoch - def _prepare_model_and_optimizer(self, max_steps, train_dataloader, resume_from_checkpoint): + def _prepare_for_training(self, max_steps, train_dataloader, resume_from_checkpoint): """Wrap model, create optimizer and scheduler, and run accelerator.prepare. Returns (model, train_dataloader).""" delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled @@ -1750,6 +1749,7 @@ def _run_epoch( self._track_num_input_tokens(inputs) if do_sync_step: + grad_norm = None if self.args.max_grad_norm > 0: grad_norm = self._clip_grad_norm(model) grad_norm = self._get_grad_norm(model, grad_norm=grad_norm) @@ -1812,14 +1812,10 @@ def _run_epoch( start_time, learning_rate=learning_rate, ) - if DebugOption.TPU_METRICS_DEBUG in self.args.debug: - xm.master_print(met.metrics_report()) def _finalize_training(self, trial, num_train_samples, start_time): """Finalize training: metrics, best-model loading, cleanup. Returns TrainOutput.""" logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") - if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None: - self._load_best_model() # add remaining tr_loss self._total_loss_scalar += self._tr_loss.item() @@ -1836,15 +1832,14 @@ def _finalize_training(self, trial, num_train_samples, start_time): metrics["total_flos"] = self.state.total_flos metrics["train_loss"] = train_loss - self.is_in_train = False - self._memory_tracker.stop_and_update_metrics(metrics) - self.log(metrics) - run_dir = self._get_output_dir(trial) + if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None: + self._load_best_model() + checkpoints_sorted = sort_checkpoints( - output_dir=run_dir, best_model_checkpoint=self.state.best_model_checkpoint + output_dir=self._get_output_dir(trial), best_model_checkpoint=self.state.best_model_checkpoint ) # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save. @@ -1863,6 +1858,7 @@ def _finalize_training(self, trial, num_train_samples, start_time): # for the embedding layer by removing the forward post hook. if self.neftune_noise_alpha is not None: deactivate_neftune(self.model, self.neftune_hook_handle, self.accelerator) + self.is_in_train = False return TrainOutput(self.state.global_step, train_loss, metrics) @@ -2435,8 +2431,6 @@ def _wrap_model(self, model: nn.Module, training: bool = True, dataloader: DataL def _update_auto_batch_size(self, batch_size): """Free memory, reset model wrapping, and update DeepSpeed config for the new batch size when using `auto_find_batch_size`""" - # reset everything - self.accelerator.free_memory() # `_train_batch_size` value might have changed to `auto_find_batch_size` self._train_batch_size = batch_size # frees the wrapped model and resets it back to the unwrapped base model