22 changes: 21 additions & 1 deletion QEfficient/cloud/finetune_experimental.py
@@ -14,7 +14,9 @@
from pathlib import Path
from typing import Any, Dict, List, Tuple

from QEfficient.finetune.experimental.core.callbacks import replace_progress_callback
from peft import get_peft_model

from QEfficient.finetune.experimental.core.callbacks import TrainingLogger, replace_progress_callback
from QEfficient.finetune.experimental.core.component_registry import ComponentFactory
from QEfficient.finetune.experimental.core.config_manager import (
ConfigManager,
@@ -38,6 +40,8 @@
f"Unable to import 'torch_qaic' package due to exception: {e}. Moving ahead without the torch_qaic extension.",
level=logging.WARNING,
)
# Separate logger instance for training logs; restricts logging and file writes to rank 0 to avoid conflicts in distributed settings
train_logger = TrainingLogger(rank=0)


class FineTuningPipeline:
@@ -107,6 +111,7 @@ def get_trainer(self):

def _setup_environment(self) -> None:
"""Set up environment variables for output directories."""
self.rank = int(os.environ.get("RANK", "0"))
os.environ["OUTPUT_DIR"] = str(self.output_dir)
os.environ["TRACKIO_DIR"] = str(self.output_dir / "trackio_logs")
os.environ["TENSORBOARD_LOGGING_DIR"] = str(self.output_dir)
@@ -254,10 +259,21 @@ def _create_trainer(
dependencies = {}
if peft_config is not None:
dependencies["peft_config"] = peft_config
if self.rank == 0:
model_configuration = get_peft_model(model, peft_config)
trainable_params, all_param = model_configuration.get_nb_trainable_parameters()
pct = (trainable_params / all_param) * 100
model_configuration.unload() # Removing the peft adapters
train_logger.write(f"TRAINING INFO: Model has {all_param / 1e6:.4f} Million params.")
train_logger.write(
f"TRAINING INFO: Trainable params: {trainable_params} || "
f"all params: {all_param} || trainable%: {pct:.4f}"
)
trainer_cls, args_cls, additional_kwargs = ComponentFactory.create_trainer_config(trainer_type, **dependencies)

# Clean up training config: remove fields that shouldn't be passed to TrainingArguments
training_config.pop("device", None)
training_config.pop("log_file_name", None)
# Note: torch_dtype was already converted to fp16/bf16 flag in prepare_training_config
training_config.pop("deepspeed_config", None)
training_config.pop("torch_dtype", None)
@@ -280,6 +296,10 @@ def _create_trainer(
subset_eval_indices = list(range(0, int(num_samples - num_samples * split_ratio)))
eval_dataset = eval_dataset.select(subset_eval_indices)
train_dataset = train_dataset.select(subset_train_indices)
# Logging the number of training and evaluation samples
if self.rank == 0:
train_logger.write(f"TRAINING INFO: Length of Training Dataset is {len(train_dataset)}")
train_logger.write(f"TRAINING INFO: Length of Evaluation Dataset is {len(eval_dataset)}")
trainer = trainer_cls(
model=model,
processing_class=tokenizer,
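The truncation above relies on datasets.Dataset.select, which keeps only the rows at the given indices. A minimal sketch with toy data (the contents and split ratio are assumptions):

from datasets import Dataset

data = {"text": [f"sample {i}" for i in range(100)]}
train_dataset = Dataset.from_dict(data)
eval_dataset = Dataset.from_dict(data)

split_ratio = 0.8
num_samples = len(train_dataset)

# Keep the first split_ratio fraction for training, the remainder for evaluation.
train_dataset = train_dataset.select(range(int(num_samples * split_ratio)))
eval_dataset = eval_dataset.select(range(int(num_samples - num_samples * split_ratio)))

print(len(train_dataset), len(eval_dataset))  # 80 20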
134 changes: 134 additions & 0 deletions QEfficient/finetune/experimental/core/callbacks.py
@@ -6,7 +6,12 @@
# -----------------------------------------------------------------------------

import json
import logging
import math
import os
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional

from transformers import (
@@ -20,6 +25,8 @@
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState

from QEfficient.finetune.experimental.core.component_registry import ComponentFactory, registry
from QEfficient.finetune.experimental.core.config_manager import ConfigManager
from QEfficient.finetune.experimental.core.logger import Logger
from QEfficient.finetune.experimental.core.utils.profiler_utils import (
get_op_verifier_ctx,
init_qaic_profiling,
@@ -32,6 +39,122 @@
registry.callback("tensorboard")(TensorBoardCallback)


logger = Logger(__name__)

# Set the path for the training log file
config = ConfigManager().config.training

output_dir = Path(config["output_dir"])
log_file_name = config.get("log_file_name")

log_file_name = (
output_dir / log_file_name if log_file_name else output_dir / f"training_logs_{datetime.now():%Y%m%d_%H%M%S}.txt"
)

log_file_name.parent.mkdir(parents=True, exist_ok=True)


@registry.callback("train_logger")
class TrainingLogger(TrainerCallback):
"""
A [`TrainerCallback`] that logs per-epoch time, training loss and metric (perplexity), and evaluation loss and metrics.
Logging is restricted to rank 0.
"""

def __init__(self, rank=0, log_file: str | Path | None = log_file_name):
self.rank = rank # rank-safe logging (only rank 0)
# Log file setup
self.log_file = log_file
# Ensure directory exists
os.makedirs(os.path.dirname(self.log_file), exist_ok=True)
self.epoch_start_time = None
self.best_eval_loss = float("inf")

# ----------------------------------------------------
# Safe write to log (only rank 0)
# ----------------------------------------------------
def write(self, text):
if self.rank != 0:
return
logger.log_rank_zero(text)
try:
with open(self.log_file, "a") as f:
f.write(text + "\n")
f.flush()
os.fsync(f.fileno())

except OSError:
logging.exception("Failed to write to log file: %s", self.log_file)

# ----------------------------------------------------
# EPOCH BEGIN
# ----------------------------------------------------
def on_epoch_begin(self, args, state, control, **kwargs):
if self.rank != 0:
return

epoch = int(state.epoch) + 1
self.epoch_start_time = time.time()
if state.is_world_process_zero:
self.write(f"TRAINING INFO: Starting epoch {epoch}/{int(args.num_train_epochs)}")

# ----------------------------------------------------
# EVALUATION
# ----------------------------------------------------
def on_evaluate(self, args, state, control, metrics, **kwargs):
if self.rank != 0:
return

epoch = int(state.epoch)
eval_loss = None
eval_metric = None

for entry in reversed(state.log_history):
if "eval_loss" in entry:
eval_loss = entry["eval_loss"]
break
if eval_loss is not None:
eval_metric = math.exp(eval_loss)
# Track best eval loss
if eval_loss is not None and eval_loss < self.best_eval_loss:
self.best_eval_loss = eval_loss
if state.is_world_process_zero:
self.write(f"EVALUATION INFO: Best eval loss on epoch {epoch} is {eval_loss:.4f}")
if eval_loss is not None and state.is_world_process_zero:
self.write(f"EVALUATION INFO: Epoch {epoch}: Eval Loss: {eval_loss:.4f} || Eval metric: {eval_metric:.4f}")

# ----------------------------------------------------
# EPOCH END — TRAIN LOSS + METRIC + TIME
# ----------------------------------------------------
def on_epoch_end(self, args, state, control, **kwargs):
if self.rank != 0:
return

epoch = int(state.epoch)
epoch_time = time.time() - self.epoch_start_time

# Extract the last recorded train loss
train_loss = None
for entry in reversed(state.log_history):
if "loss" in entry:
train_loss = entry["loss"]
break

# Compute perplexity safely
train_metric = None
if train_loss is not None:
train_metric = math.exp(train_loss)
if state.is_world_process_zero and train_loss is not None:
self.write(
f"TRAINING INFO: Epoch {epoch}: "
f"Train epoch loss: {train_loss:.4f} || "
f"Train metric: {train_metric:.4f} || "
f"Epoch time {epoch_time:.2f} sec"
)
state.log_history.append({"train/epoch_time_sec": epoch_time, "epoch": state.epoch})
control.should_log = True


@registry.callback("enhanced_progressbar")
class EnhancedProgressCallback(ProgressCallback):
"""
@@ -233,3 +356,14 @@ def replace_progress_callback(trainer: Any, callbacks: list[Any], logger: Any =
import warnings

warnings.warn(f"Could not add enhanced progress callback: {e}")
try:
# Add Train Logger
train_logger = ComponentFactory.create_callback("train_logger")
trainer.add_callback(train_logger)
except Exception as e:
if logger:
logger.log_rank_zero(f"Warning: Could not add train logger callback: {e}", level="warning")
else:
import warnings

warnings.warn(f"Could not add train warning callback: {e}")
4 changes: 4 additions & 0 deletions QEfficient/finetune/experimental/core/config_manager.py
@@ -330,6 +330,10 @@ class TrainingConfig:
default="./training_results",
metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
)
log_file_name: str | None = field(
default=None,
metadata={"help": "Name of the training log file to create inside output_dir."},
)
overwrite_output_dir: bool = field(
default=False,
metadata={"help": "Whether to overwrite the output directory."},
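The new field is consumed at import time in callbacks.py. A minimal sketch of that resolution logic (the helper name is an assumption): an explicit log_file_name is joined onto output_dir, otherwise a timestamped default is generated.

from datetime import datetime
from pathlib import Path


def resolve_log_file(output_dir: str, log_file_name: str | None) -> Path:
    # An explicit name wins; otherwise fall back to a timestamped default.
    name = log_file_name or f"training_logs_{datetime.now():%Y%m%d_%H%M%S}.txt"
    path = Path(output_dir) / name
    path.parent.mkdir(parents=True, exist_ok=True)
    return path


print(resolve_log_file("./training_results", None))  # e.g. training_results/training_logs_20240101_120000.txt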