3 changes: 2 additions & 1 deletion QEfficient/cloud/finetune_experimental.py
@@ -265,7 +265,8 @@ def _create_trainer(
if num_samples > 0:
# Truncating datasets to a smaller number of samples.
# If you want to use all data, set dataset_num_samples to -1 or remove it from config.
logger.warning("Using fewer samples may impact finetuning quality.")
if (num_samples * split_ratio) / len(train_dataset) <= 0.05:
logger.log_rank_zero("Using fewer samples may impact finetuning quality.", logging.WARNING)
subset_train_indices = list(range(0, int(num_samples * split_ratio)))
subset_eval_indices = list(range(0, int(num_samples - num_samples * split_ratio)))
eval_dataset = eval_dataset.select(subset_eval_indices)
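
For context, a minimal standalone sketch of the truncation behaviour in this hunk, using a plain `logging` logger in place of the project's rank-aware one; the `truncate_datasets` name is hypothetical, and the final `train_dataset.select(...)` call is assumed from the surrounding code rather than shown in the hunk.

```python
import logging

logger = logging.getLogger(__name__)


def truncate_datasets(train_dataset, eval_dataset, num_samples, split_ratio):
    """Hypothetical helper mirroring the hunk above (datasets are assumed to expose .select())."""
    if num_samples > 0:
        # Warn only when the requested subset is at most 5% of the full training split.
        if (num_samples * split_ratio) / len(train_dataset) <= 0.05:
            logger.warning("Using fewer samples may impact finetuning quality.")
        subset_train_indices = list(range(int(num_samples * split_ratio)))
        subset_eval_indices = list(range(int(num_samples - num_samples * split_ratio)))
        train_dataset = train_dataset.select(subset_train_indices)  # assumed, not shown in the hunk
        eval_dataset = eval_dataset.select(subset_eval_indices)
    return train_dataset, eval_dataset
```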
@@ -53,4 +53,3 @@ callbacks:
early_stopping:
early_stopping_patience: 3 # Number of epochs to wait before stopping training
early_stopping_threshold: 0.001 # Minimum change in metric to qualify as improvement
tensorboard:
@@ -46,4 +46,3 @@ callbacks:
early_stopping:
early_stopping_patience: 3 # Number of epochs to wait before stopping training
early_stopping_threshold: 0.001 # Minimum change in metric to qualify as improvement
tensorboard:
@@ -47,4 +47,3 @@ callbacks:
early_stopping:
early_stopping_patience: 3 # Number of epochs to wait before stopping training
early_stopping_threshold: 0.001 # Minimum change in metric to qualify as improvement
tensorboard:
1 change: 0 additions & 1 deletion QEfficient/finetune/experimental/core/config_manager.py
@@ -25,7 +25,6 @@
from QEfficient.utils.device_utils import is_nsp_free

logger = Logger(__name__)
logger.logger.propagate = False


@dataclass
7 changes: 4 additions & 3 deletions QEfficient/finetune/experimental/core/dataset.py
@@ -26,7 +26,6 @@
)

logger = Logger(__name__)
logger.logger.propagate = False


class BaseDataset(Dataset, ABC):
@@ -102,9 +101,11 @@ def __init__(
if not os.path.isfile(self.json_file_path):
raise FileNotFoundError(f"JSON file not found or invalid: '{self.json_file_path}'")
if self.prompt_template and self.prompt_func_path:
logger.info("Both prompt_template and prompt_func are provided. Using prompt_template for preprocessing.")
logger.log_rank_zero(
"Both prompt_template and prompt_func are provided. Using prompt_template for preprocessing."
)
if self.completion_template and self.completion_func_path:
logger.info(
logger.log_rank_zero(
"Both completion_template and completion_func are provided. Using completion_template for preprocessing."
)
if self.prompt_template is None and self.prompt_func_path is None:
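
As an aside, a small sketch of the precedence these warnings imply: an explicit template wins over a function path when both are supplied. The `resolve_preprocessor` name and return shape are illustrative, not part of the PR.

```python
import logging

logger = logging.getLogger(__name__)


def resolve_preprocessor(template, func_path):
    """Illustrative only: an explicit template takes precedence over a function path."""
    if template and func_path:
        logger.warning("Both a template and a function path are provided; using the template.")
    if template is not None:
        return ("template", template)
    if func_path is not None:
        return ("function", func_path)
    raise ValueError("Either a template or a function path must be provided.")
```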
41 changes: 35 additions & 6 deletions QEfficient/finetune/experimental/core/logger.py
@@ -7,13 +7,13 @@


import logging
import sys
from pathlib import Path
from typing import Optional

from transformers.utils.logging import get_logger as hf_get_logger

from QEfficient.finetune.experimental.core.utils.dist_utils import get_local_rank
from QEfficient.finetune.experimental.core.utils.dist_utils import is_global_rank_zero


# -----------------------------------------------------------------------------
# Logger usage:
@@ -27,6 +27,34 @@
# Attach file handler later if needed:
# logger.prepare_for_logs(output_dir="logs", log_level="DEBUG")
# -----------------------------------------------------------------------------
class QEffFormatter(logging.Formatter):
"""
Formatter that applies per-level colors to log messages printed to the console.
"""

cyan: str = "\x1b[38;5;14m"
yellow: str = "\x1b[33;20m"
red: str = "\x1b[31;20m"
bold_red: str = "\x1b[31;1m"
reset: str = "\x1b[0m"
common_format: str = "%(levelname)s - %(name)s - %(message)s" # type: ignore
format_with_line_info = "%(levelname)s - %(name)s - %(message)s (%(filename)s:%(lineno)d)" # type: ignore

FORMATS = {
logging.DEBUG: cyan + format_with_line_info + reset,
logging.INFO: cyan + common_format + reset,
logging.WARNING: yellow + common_format + reset,
logging.ERROR: red + format_with_line_info + reset,
logging.CRITICAL: bold_red + format_with_line_info + reset,
}

def format(self, record):
"""
Overriding the base class method to choose the format based on log level.
"""
log_fmt = self.FORMATS.get(record.levelno)
formatter = logging.Formatter(log_fmt)
return formatter.format(record)


class Logger:
@@ -48,17 +76,17 @@ def __init__(
"""
self.logger = hf_get_logger(name)
self.logger.setLevel(level)

self.logger.propagate = False
# Clear any existing handlers
self.logger.handlers.clear()

# Create formatter
self.formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")

# Console handler
console_handler = logging.StreamHandler(sys.stdout)
console_handler = logging.StreamHandler()
console_handler.setLevel(level)
console_handler.setFormatter(self.formatter)
console_handler.setFormatter(QEffFormatter())
self.logger.addHandler(console_handler)

# File handler (if log_file is provided)
@@ -100,7 +128,7 @@ def log_rank_zero(self, message: str, level: int = logging.INFO) -> None:
message: Message to log
level: Logging level
"""
if get_local_rank() == 0:
if is_global_rank_zero():
self.logger.log(level, message)

def log_exception(self, message: str, exception: Exception, raise_exception: bool = True) -> None:
@@ -130,6 +158,7 @@ def prepare_for_logs(self, output_dir: Optional[str] = None, log_level: str = "I
# Convert string log level to logging constant
level = getattr(logging, log_level.upper(), logging.INFO)
self.logger.setLevel(level)
self.logger.propagate = False

# Update existing handlers' levels
for handler in self.logger.handlers:
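
Taken together, a short usage sketch of the pieces this file touches (import path as shown in the diff; the messages are illustrative):

```python
import logging

from QEfficient.finetune.experimental.core.logger import Logger

logger = Logger(__name__)  # console handler now uses the colored QEffFormatter

# Attach a file handler later if needed, as the usage note at the top of the file suggests.
logger.prepare_for_logs(output_dir="logs", log_level="DEBUG")

logger.debug("emitted on every process")
# With this PR, rank-zero gating uses the *global* rank rather than the local one.
logger.log_rank_zero("emitted once per job", logging.INFO)
```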
1 change: 0 additions & 1 deletion QEfficient/finetune/experimental/core/model.py
@@ -18,7 +18,6 @@
from QEfficient.finetune.experimental.core.utils.dataset_utils import insert_pad_token

logger = Logger(__name__)
logger.logger.propagate = False


class BaseModel(nn.Module, ABC):
17 changes: 17 additions & 0 deletions QEfficient/finetune/experimental/core/utils/dist_utils.py
@@ -4,6 +4,7 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
import os

import torch.distributed as dist

@@ -37,3 +38,19 @@ def get_world_size() -> int:
def is_main_process() -> bool:
"""Check if the current process is the main process (rank 0)."""
return get_rank() == 0


def get_global_rank() -> int:
"""Return global rank if available (torchrun/deepspeed), else fall back to local rank."""
r = os.environ.get("RANK")
if r is not None:
try:
return int(r)
except ValueError:
return 0
# Fallback to local rank
return int(get_local_rank())


def is_global_rank_zero() -> bool:
return get_global_rank() == 0
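
To illustrate the fallback above: launchers such as torchrun export a global `RANK` for every process, so `get_global_rank()` reads it when present and otherwise falls back to the local rank. A small sketch, with environment values simulated rather than set by a real launcher:

```python
import os

from QEfficient.finetune.experimental.core.utils.dist_utils import (
    get_global_rank,
    is_global_rank_zero,
)

# Simulate what torchrun would export for the fourth process of a multi-node job.
os.environ["RANK"] = "3"

print(get_global_rank())      # -> 3, parsed from RANK
print(is_global_rank_zero())  # -> False, so rank-zero-only logs stay quiet on this process
```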
24 changes: 15 additions & 9 deletions QEfficient/finetune/experimental/tests/test_logger.py
@@ -48,6 +48,7 @@ def test_init_with_file(self, tmp_path):
def test_log_levels(self, caplog):
"""Test all log levels work correctly"""
logger = Logger("level_test_logger", level=logging.DEBUG)
logger.logger.propagate = True

with caplog.at_level(logging.DEBUG):
logger.debug("Debug message")
@@ -63,22 +64,24 @@ def test_log_levels(self, caplog):
assert "Error message" in caplog.text
assert "Critical message" in caplog.text

@patch("QEfficient.finetune.experimental.core.logger.get_local_rank")
def test_log_rank_zero_positive_case(self, mock_get_local_rank, caplog):
@patch("QEfficient.finetune.experimental.core.logger.is_global_rank_zero")
def test_log_rank_zero_positive_case(self, mock_get_global_rank, caplog):
"""Test rank zero logging functionality"""
mock_get_local_rank.return_value = 0
mock_get_global_rank.return_value = True
logger = Logger("rank_test_logger")
logger.logger.propagate = True

with caplog.at_level(logging.INFO):
logger.log_rank_zero("Rank zero message")

assert "Rank zero message" in caplog.text

@patch("QEfficient.finetune.experimental.core.logger.get_local_rank")
def test_log_rank_zero_negative_case(self, mock_get_local_rank, caplog):
@patch("QEfficient.finetune.experimental.core.logger.is_global_rank_zero")
def test_log_rank_zero_negative_case(self, mock_get_global_rank, caplog):
"""Test to verify that only rank‑zero messages are logged"""
mock_get_local_rank.return_value = 1
mock_get_global_rank.return_value = False
logger = Logger("rank_test_logger")
logger.logger.propagate = True

with caplog.at_level(logging.INFO):
logger.log_rank_zero("Should not appear")
@@ -88,6 +91,7 @@ def test_log_rank_zero_negative_case(self, mock_get_local_rank, caplog):
def test_log_exception_raise(self, caplog):
"""Test exception logging with raising"""
logger = Logger("exception_test_logger")
logger.logger.propagate = True

with pytest.raises(ValueError), caplog.at_level(logging.ERROR):
logger.log_exception("Custom error", ValueError("Test exception"), raise_exception=True)
@@ -99,6 +103,7 @@ def test_log_exception_raise(self, caplog):
def test_log_exception_no_raise(self, caplog):
"""Test exception logging without raising"""
logger = Logger("exception_test_logger")
logger.logger.propagate = True

with caplog.at_level(logging.ERROR):
logger.log_exception("Custom error", ValueError("Test exception"), raise_exception=False)
@@ -168,7 +173,7 @@ def test_get_logger_with_file(self, tmp_path):

# Check that we have 2 handlers (console + file)
assert len(logger.logger.handlers) == 2 # Console + file
assert isinstance(logger.logger.handlers[1], logging.FileHandler)
assert any(isinstance(h, logging.FileHandler) for h in logger.logger.handlers)

# Check file exists
assert log_file.exists()
@@ -188,6 +193,7 @@ def test_complete_workflow(self, tmp_path, caplog):
# Setup
log_file = tmp_path / "workflow.log"
logger = Logger("workflow_test", str(log_file), logging.DEBUG)
logger.logger.propagate = True

# Test all methods
logger.debug("Debug test")
@@ -203,8 +209,8 @@ def test_complete_workflow(self, tmp_path, caplog):
logger.log_exception("Caught exception", e, raise_exception=False)

# Test rank zero logging
with patch("QEfficient.finetune.experimental.core.logger.get_local_rank") as mock_rank:
mock_rank.return_value = 0
with patch("QEfficient.finetune.experimental.core.logger.is_global_rank_zero") as mock_rank:
mock_rank.return_value = True
logger.log_rank_zero("Rank zero test")

# Verify all messages were logged
2 changes: 0 additions & 2 deletions QEfficient/utils/device_utils.py
@@ -42,8 +42,6 @@ def is_nsp_free():
# Check if NSP free is equal to total NSP
if nsp_free != nsp_total:
raise RuntimeError(f"QAIC device {qid_idx} does not have {nsp_total} NSP free")
Contributor: Why are we changing the device_utils of AOT? Don't we have a separate device_utils for us in finetune/experimental/?

Contributor (Author): Meet told me to put it here; I first placed it in utils.

Contributor: Okay

else:
logger.info(f"QAIC device {qid_idx} has {nsp_free} NSP free")
else:
logger.warning("Failed to parse NSP free information from qaic-util output")

30 changes: 25 additions & 5 deletions docs/source/hf_finetune.md
@@ -50,31 +50,51 @@ export QAIC_DEVICE_LOG_LEVEL=0 # Device-level logs
export QAIC_DEBUG=1 # Show CPU fallback ops, etc.

# Set temp directory
export TMPDIR = $HOME/tmp
export TMPDIR=$HOME/tmp
```

### Step-by-Step Guide to run a fine-tuning job

### For QAIC Training
For Docker-based environments, use the provided `torch-qaic-env` environment.

```bash
source /opt/torch-qaic-env/bin/activate
python -m venv finetune_env
source finetune_env/bin/activate
git clone https://github.com/quic/efficient-transformers.git
git checkout ft_experimental
cd efficient-transformers
git checkout ft_experimental
pip install -e .
pip install --index-url https://download.pytorch.org/whl/cpu --extra-index-url https://devpi.qualcomm.com/qcom/dev/+simple --trusted-host devpi.qualcomm.com "torch==2.9.1+cpu" "torchvision==0.24.1+cpu" "torchaudio==2.9.1+cpu"
pip install trl==0.22.0
git clone https://github.com/quic-swatia/transformers.git
cd .. && git clone https://github.com/quic-swatia/transformers.git
cd transformers
git checkout version-4.55.0 && pip install -e .
cd .. && QAIC_VISIBLE_DEVICES=0 python QEfficient/cloud/finetune_experimental.py QEfficient/finetune/experimental/configs/sft_single_device_gsm8k_config.yaml
cd .. && cd efficient-transformers
QAIC_VISIBLE_DEVICES=0 python QEfficient/cloud/finetune_experimental.py QEfficient/finetune/experimental/configs/sft_single_device_gsm8k_config.yaml

```

> **Note**
> If you’re using the `torch-qaic-env` Docker environment, `torch_qaic` and `accelerate` may already be installed.

### For CUDA Training

```bash
python -m venv finetune_env
source finetune_env/bin/activate
git clone https://github.com/quic/efficient-transformers.git
cd efficient-transformers
git checkout ft_experimental
pip install -e .
pip install torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu130
pip install trl==0.22.0
cd .. && git clone https://github.com/quic-swatia/transformers.git
cd transformers
git checkout version-4.55.0 && pip install -e .
cd .. && cd efficient-transformers
CUDA_VISIBLE_DEVICES=0 torchrun --nproc-per-node 1 -m QEfficient.cloud.finetune_experimental --device cuda --num_epochs 1 --model_name meta-llama/Llama-3.2-3B --dataset_name yahma/alpaca-cleaned --train_batch_size 1 --gradient_accumulation_steps 768 --prompt_func QEfficient.finetune.experimental.preprocessing.alpaca_func:create_alpaca_prompt --completion_template {output}
```
***
## Finetuning
