From 9e719a26313becfa398471b334409c99d6777051 Mon Sep 17 00:00:00 2001
From: Anusha Bhamidipati
Date: Mon, 27 Apr 2026 12:08:42 +0530
Subject: [PATCH] Add support for easy visualization of training and
 validation statistics

This patch adds:

1. A train_logger callback that captures the per-epoch time, per-epoch
   loss, and per-epoch perplexity.
2. The callback also records the number of trainable parameters and the
   number of samples in the training and evaluation datasets.
3. All of these are written to a log file whose name the user can set
   via the log_file_name field in the input config .yaml file; when the
   field is unset, a timestamped file is created under output_dir.

Signed-off-by: abhamidi
---
 QEfficient/cloud/finetune_experimental.py    |  22 ++-
 .../finetune/experimental/core/callbacks.py  | 134 ++++++++++++++++++
 .../experimental/core/config_manager.py      |   4 +
 3 files changed, 159 insertions(+), 1 deletion(-)

diff --git a/QEfficient/cloud/finetune_experimental.py b/QEfficient/cloud/finetune_experimental.py
index ce024828c8..0edb47f7a8 100644
--- a/QEfficient/cloud/finetune_experimental.py
+++ b/QEfficient/cloud/finetune_experimental.py
@@ -14,7 +14,9 @@
 from pathlib import Path
 from typing import Any, Dict, List, Tuple
 
-from QEfficient.finetune.experimental.core.callbacks import replace_progress_callback
+from peft import get_peft_model
+
+from QEfficient.finetune.experimental.core.callbacks import TrainingLogger, replace_progress_callback
 from QEfficient.finetune.experimental.core.component_registry import ComponentFactory
 from QEfficient.finetune.experimental.core.config_manager import (
     ConfigManager,
@@ -38,6 +40,8 @@
         f"Unable to import 'torch_qaic' package due to exception: {e}. Moving ahead without the torch_qaic extension.",
         level=logging.WARNING,
     )
+# Separate logger instance for training logs; writes only on rank 0 to avoid file conflicts in distributed settings
+train_logger = TrainingLogger(rank=0)
 
 
 class FineTuningPipeline:
@@ -107,6 +111,7 @@ def get_trainer(self):
 
     def _setup_environment(self) -> None:
         """Set up environment variables for output directories."""
+        self.rank = int(os.environ.get("RANK", "0"))
         os.environ["OUTPUT_DIR"] = str(self.output_dir)
         os.environ["TRACKIO_DIR"] = str(self.output_dir / "trackio_logs")
         os.environ["TENSORBOARD_LOGGING_DIR"] = str(self.output_dir)
@@ -254,10 +259,21 @@ def _create_trainer(
         dependencies = {}
         if peft_config is not None:
             dependencies["peft_config"] = peft_config
+            if self.rank == 0:
+                model_configuration = get_peft_model(model, peft_config)
+                trainable_params, all_param = model_configuration.get_nb_trainable_parameters()
+                pct = (trainable_params / all_param) * 100
+                model_configuration.unload()  # Remove the temporary PEFT adapters
+                train_logger.write(f"TRAINING INFO: Model has {all_param / 1e6:.4f} Million params.")
+                train_logger.write(
+                    f"TRAINING INFO: Trainable params: {trainable_params} || "
+                    f"all params: {all_param} || trainable%: {pct:.4f}"
+                )
 
         trainer_cls, args_cls, additional_kwargs = ComponentFactory.create_trainer_config(trainer_type, **dependencies)
         # Clean up training config: remove fields that shouldn't be passed to TrainingArguments
         training_config.pop("device", None)
+        training_config.pop("log_file_name", None)
         # Note: torch_dtype was already converted to fp16/bf16 flag in prepare_training_config
         training_config.pop("deepspeed_config", None)
         training_config.pop("torch_dtype", None)
@@ -280,6 +296,10 @@ def _create_trainer(
             subset_eval_indices = list(range(0, int(num_samples - num_samples * split_ratio)))
             eval_dataset = eval_dataset.select(subset_eval_indices)
             train_dataset = train_dataset.select(subset_train_indices)
+        # Log the number of training and evaluation samples
+        if self.rank == 0:
+            train_logger.write(f"TRAINING INFO: Length of Training Dataset is {len(train_dataset)}")
+            train_logger.write(f"TRAINING INFO: Length of Evaluation Dataset is {len(eval_dataset)}")
         trainer = trainer_cls(
             model=model,
             processing_class=tokenizer,
diff --git a/QEfficient/finetune/experimental/core/callbacks.py b/QEfficient/finetune/experimental/core/callbacks.py
index bd1ce91c2e..840a72fe46 100644
--- a/QEfficient/finetune/experimental/core/callbacks.py
+++ b/QEfficient/finetune/experimental/core/callbacks.py
@@ -6,7 +6,12 @@
 # -----------------------------------------------------------------------------
 
 import json
+import logging
+import math
 import os
+import time
+from datetime import datetime
+from pathlib import Path
 from typing import Any, Dict, Optional
 
 from transformers import (
@@ -20,6 +25,8 @@
 from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
 
 from QEfficient.finetune.experimental.core.component_registry import ComponentFactory, registry
+from QEfficient.finetune.experimental.core.config_manager import ConfigManager
+from QEfficient.finetune.experimental.core.logger import Logger
 from QEfficient.finetune.experimental.core.utils.profiler_utils import (
     get_op_verifier_ctx,
     init_qaic_profiling,
@@ -32,6 +39,122 @@
 
 registry.callback("tensorboard")(TensorBoardCallback)
 
+logger = Logger(__name__)
+
+# Set the path for dumping the log file
+config = ConfigManager().config.training
+
+output_dir = Path(config["output_dir"])
+log_file_name = config.get("log_file_name")
+
+log_file_name = (
+    output_dir / log_file_name if log_file_name else output_dir / f"training_logs_{datetime.now():%Y%m%d_%H%M%S}.txt"
+)
+
+log_file_name.parent.mkdir(parents=True, exist_ok=True)
+
+
+@registry.callback("train_logger")
+class TrainingLogger(TrainerCallback):
+    """
+    A [`TrainerCallback`] that logs the per-epoch time, training loss, training metric (perplexity),
+    and evaluation loss/metrics. These are only logged for rank 0.
+ """ + + def __init__(self, rank=0, log_file: str | None = log_file_name): + self.rank = rank # rank-safe logging (only rank 0) + # Log file setup + self.log_file = log_file + # Ensure directory exists + os.makedirs(os.path.dirname(self.log_file), exist_ok=True) + self.epoch_start_time = None + self.best_eval_loss = float("inf") + + # ---------------------------------------------------- + # Safe write to log (only rank 0) + # ---------------------------------------------------- + def write(self, text): + if self.rank != 0: + return + logger.log_rank_zero(text) + try: + with open(self.log_file, "a") as f: + f.write(text + "\n") + f.flush() + os.fsync(f.fileno()) + + except OSError: + logging.exception("Failed to write to log file: %s", self.log_file) + + # ---------------------------------------------------- + # EPOCH BEGIN + # ---------------------------------------------------- + def on_epoch_begin(self, args, state, control, **kwargs): + if self.rank != 0: + return + + epoch = int(state.epoch) + 1 + self.epoch_start_time = time.time() + if state.is_world_process_zero: + self.write(f"TRAINING INFO: Starting epoch {epoch}/{int(args.num_train_epochs)}") + + # ---------------------------------------------------- + # EVALUATION + # ---------------------------------------------------- + def on_evaluate(self, args, state, control, metrics, **kwargs): + if self.rank != 0: + return + + epoch = int(state.epoch) + eval_loss = None + eval_metric = None + + for entry in reversed(state.log_history): + if "eval_loss" in entry: + eval_loss = entry["eval_loss"] + break + if eval_loss is not None: + eval_metric = math.exp(eval_loss) + # Track best eval loss + if eval_loss is not None and eval_loss < self.best_eval_loss: + self.best_eval_loss = eval_loss + if state.is_world_process_zero: + self.write(f"EVALUATION INFO: Best eval loss on epoch {epoch} is {eval_loss:.4f}") + if state.is_world_process_zero: + self.write(f"EVALUATION INFO: Epoch {epoch}: Eval Loss: {eval_loss:.4f} || Eval metric: {eval_metric:.4f}") + + # ---------------------------------------------------- + # EPOCH END — TRAIN LOSS + METRIC + TIME + # ---------------------------------------------------- + def on_epoch_end(self, args, state, control, **kwargs): + if self.rank != 0: + return + + epoch = int(state.epoch) + epoch_time = time.time() - self.epoch_start_time + + # Extract the last recorded train loss + train_loss = None + for entry in reversed(state.log_history): + if "loss" in entry: + train_loss = entry["loss"] + break + + # Compute perplexity safely + train_metric = None + if train_loss is not None: + train_metric = math.exp(train_loss) + if state.is_world_process_zero: + self.write( + f"TRAINING INFO: Epoch {epoch}: " + f" Train epoch loss: {train_loss:.4f} || " + f" Train metric: {train_metric} || " + f" Epoch time {epoch_time:.2f} sec" + ) + state.log_history.append({"train/epoch_time_sec": epoch_time, "epoch": state.epoch}) + control.should_log = True + + @registry.callback("enhanced_progressbar") class EnhancedProgressCallback(ProgressCallback): """ @@ -233,3 +356,14 @@ def replace_progress_callback(trainer: Any, callbacks: list[Any], logger: Any = import warnings warnings.warn(f"Could not add enhanced progress callback: {e}") + try: + # Add Train Logger + train_logger = ComponentFactory.create_callback("train_logger") + trainer.add_callback(train_logger) + except Exception as e: + if logger: + logger.log_rank_zero(f"Warning: Could not add train logger callback: {e}", level="warning") + else: + import warnings + + 
warnings.warn(f"Could not add train warning callback: {e}") diff --git a/QEfficient/finetune/experimental/core/config_manager.py b/QEfficient/finetune/experimental/core/config_manager.py index 9846c91944..d0fc9e47c6 100644 --- a/QEfficient/finetune/experimental/core/config_manager.py +++ b/QEfficient/finetune/experimental/core/config_manager.py @@ -330,6 +330,10 @@ class TrainingConfig: default="./training_results", metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, ) + log_file_name: str = field( + default=None, + metadata={"help": "The log_file output name."}, + ) overwrite_output_dir: bool = field( default=False, metadata={"help": "Whether to overwrite the output directory."},