From 51c50103bd53eee04676c4a62c66536246bda67c Mon Sep 17 00:00:00 2001
From: Meet Patel
Date: Fri, 28 Nov 2025 17:30:12 +0530
Subject: [PATCH 1/9] [QEff. Finetune]: Added component registry and factory
 functionality. (#645)

- Added functionality to register dataset, model, optimizer, and trainer
  objects in a registry and fetch the class of a given object based on the
  configuration provided.
- Also added simple test cases to verify the functionality.

---------

Signed-off-by: Meet Patel
---
 .../experimental/core/component_registry.py   | 194 ++++++++++++++++++
 .../experimental/tests/test_registry.py       | 167 +++++++++++++++
 2 files changed, 361 insertions(+)
 create mode 100644 QEfficient/finetune/experimental/tests/test_registry.py

diff --git a/QEfficient/finetune/experimental/core/component_registry.py b/QEfficient/finetune/experimental/core/component_registry.py
index d647b73a65..7744d71e6a 100644
--- a/QEfficient/finetune/experimental/core/component_registry.py
+++ b/QEfficient/finetune/experimental/core/component_registry.py
@@ -4,3 +4,197 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
+
+
+import logging
+from typing import Callable, Dict, Optional, Type
+
+# from QEfficient.finetune.experimental.core.logger import get_logger
+
+# logger = get_logger()
+logger = logging.getLogger(__name__)
+
+
+def get_object(obj_dict: Dict, name: str, object_type: str, list_fn: Callable) -> Optional[Type]:
+    """Utility to get an object from a dictionary with error handling."""
+    obj = obj_dict.get(name)
+    if obj is None:
+        raise ValueError(f"Unknown {object_type}: {name}. Available: {list_fn()}")
+    return obj
+
+
+class ComponentRegistry:
+    """Registry for managing different training components."""
+
+    def __init__(self):
+        self._optimizers: Dict[str, Type] = {}
+        self._schedulers: Dict[str, Type] = {}
+        self._datasets: Dict[str, Type] = {}
+        self._models: Dict[str, Type] = {}
+        self._data_collators: Dict[str, Type] = {}
+        self._metrics: Dict[str, Type] = {}
+        self._loss_functions: Dict[str, Type] = {}
+        self._callbacks: Dict[str, Type] = {}
+        self._hooks: Dict[str, Type] = {}
+        self._trainer_modules: Dict[str, Type] = {}
+
+    def trainer_module(self, name: str, args_cls=None, required_kwargs=None):
+        """
+        Decorator to register a trainer module with its configuration.
+        Each trainer module has to be bound to its args class and required kwargs.
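+
+        Example (illustrative; ``SFTTrainer`` and ``SFTTrainingArgs`` are hypothetical
+        names, not part of this patch):
+
+            @registry.trainer_module("sft", args_cls=SFTTrainingArgs,
+                                     required_kwargs={"max_seq_length": 512})
+            class SFTTrainer:
+                ...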
+
+        Args:
+            name: Name of the trainer type
+            args_cls: The arguments class for this trainer
+            required_kwargs: Dictionary of required keyword arguments and their default values
+        """
+        required_kwargs = required_kwargs or {}
+
+        def decorator(trainer_cls):
+            self._trainer_modules[name] = {
+                "trainer_cls": trainer_cls,
+                "args_cls": args_cls,
+                "required_kwargs": required_kwargs,
+            }
+            logger.info(f"Registered trainer module: {name}")
+            # Return the class itself so the decorated name still refers to the
+            # trainer class rather than to the registry entry.
+            return trainer_cls
+
+        return decorator
+
+    def optimizer(self, name: str):
+        """Decorator to register an optimizer class."""
+
+        def decorator(cls: Type):
+            self._optimizers[name] = cls
+            logger.info(f"Registered optimizer: {name}")
+            return cls
+
+        return decorator
+
+    def scheduler(self, name: str):
+        """Decorator to register a scheduler class."""
+
+        def decorator(cls: Type):
+            self._schedulers[name] = cls
+            logger.info(f"Registered scheduler: {name}")
+            return cls
+
+        return decorator
+
+    def dataset(self, name: str):
+        """Decorator to register a dataset class."""
+
+        def decorator(cls: Type):
+            self._datasets[name] = cls
+            logger.info(f"Registered dataset: {name}")
+            return cls
+
+        return decorator
+
+    def model(self, name: str):
+        """Decorator to register a model class."""
+
+        def decorator(cls: Type):
+            self._models[name] = cls
+            logger.info(f"Registered model: {name}")
+            return cls
+
+        return decorator
+
+    def data_collator(self, name: str):
+        """Decorator to register a data collator function."""
+
+        def decorator(fn_pointer: Callable):
+            self._data_collators[name] = fn_pointer
+            logger.info(f"Registered data collator: {name}")
+            return fn_pointer
+
+        return decorator
+
+    def loss_function(self, name: str):
+        """Decorator to register a loss function class."""
+
+        def decorator(cls: Type):
+            self._loss_functions[name] = cls
+            logger.info(f"Registered loss function: {name}")
+            return cls
+
+        return decorator
+
+    def callback(self, name: str):
+        """Decorator to register a callback class."""
+
+        def decorator(cls: Type):
+            self._callbacks[name] = cls
+            logger.info(f"Registered callback: {name}")
+            return cls
+
+        return decorator
+
+    def get_trainer_module(self, name: str) -> Optional[Type]:
+        """Get trainer module registration info (class, args class, required kwargs) by name."""
+        return get_object(self._trainer_modules, name, "trainer module", self.list_trainer_modules)
+
+    def get_optimizer(self, name: str) -> Optional[Type]:
+        """Get optimizer class by name."""
+        return get_object(self._optimizers, name, "optimizer", self.list_optimizers)
+
+    def get_scheduler(self, name: str) -> Optional[Type]:
+        """Get scheduler class by name."""
+        return get_object(self._schedulers, name, "scheduler", self.list_schedulers)
+
+    def get_dataset(self, name: str) -> Optional[Type]:
+        """Get dataset class by name."""
+        return get_object(self._datasets, name, "dataset", self.list_datasets)
+
+    def get_model(self, name: str) -> Optional[Type]:
+        """Get model class by name."""
+        return get_object(self._models, name, "model", self.list_models)
+
+    def get_data_collator(self, name: str) -> Optional[Type]:
+        """Get data collator class by name."""
+        return get_object(self._data_collators, name, "data collator", self.list_data_collators)
+
+    def get_loss_function(self, name: str) -> Optional[Type]:
+        """Get loss function class by name."""
+        return get_object(self._loss_functions, name, "loss function", self.list_loss_functions)
+
+    def get_callback(self, name: str) -> Optional[Type]:
+        """Get callback class by name."""
+        return get_object(self._callbacks, name, "callback", self.list_callbacks)
+
+    def 
list_trainer_modules(self) -> list[str]: + """List all registered trainer modules.""" + return list(self._trainer_modules.keys()) + + def list_optimizers(self) -> list[str]: + """List all registered optimizers.""" + return list(self._optimizers.keys()) + + def list_schedulers(self) -> list[str]: + """List all registered schedulers.""" + return list(self._schedulers.keys()) + + def list_datasets(self) -> list[str]: + """List all registered datasets.""" + return list(self._datasets.keys()) + + def list_models(self) -> list[str]: + """List all registered models.""" + return list(self._models.keys()) + + def list_data_collators(self) -> list[str]: + """List all registered data collators.""" + return list(self._data_collators.keys()) + + def list_loss_functions(self) -> list[str]: + """List all registered loss functions.""" + return list(self._loss_functions.keys()) + + def list_callbacks(self) -> list[str]: + """List all registered callbacks.""" + return list(self._callbacks.keys()) + + +# Global registry instance +registry = ComponentRegistry() diff --git a/QEfficient/finetune/experimental/tests/test_registry.py b/QEfficient/finetune/experimental/tests/test_registry.py new file mode 100644 index 0000000000..3e10aa8208 --- /dev/null +++ b/QEfficient/finetune/experimental/tests/test_registry.py @@ -0,0 +1,167 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import pytest + +from QEfficient.finetune.experimental.core.component_registry import ComponentRegistry, get_object, registry + + +class TestComponentRegistry: + @pytest.fixture(autouse=True) + def setUp(self): + """Set up test fixtures before each test method.""" + self.registry = ComponentRegistry() + + @pytest.mark.parametrize( + "register_method, get_method, object_name", + [ + ("trainer_module", "get_trainer_module", "test_trainer"), + ("optimizer", "get_optimizer", "test_optimizer"), + ("scheduler", "get_scheduler", "test_scheduler"), + ("dataset", "get_dataset", "test_dataset"), + ("model", "get_model", "test_model"), + ("data_collator", "get_data_collator", "test_collator"), + ("loss_function", "get_loss_function", "test_loss"), + ("callback", "get_callback", "test_callback"), + ], + ) + def test_object_success(self, register_method: str, get_method: str, object_name: str): + """Test object registration decorator.""" + + class MockObject: + pass + + # Register with decorator + getattr(self.registry, register_method)(object_name)(MockObject) + + # Verify registration + retrieved = getattr(self.registry, get_method)(object_name) + if register_method == "trainer_module": + retrieved = retrieved["trainer_cls"] + assert retrieved == MockObject + + @pytest.mark.parametrize( + "object_type, get_method", + [ + ("trainer module", "get_trainer_module"), + ("optimizer", "get_optimizer"), + ("scheduler", "get_scheduler"), + ("dataset", "get_dataset"), + ("model", "get_model"), + ("data collator", "get_data_collator"), + ("loss function", "get_loss_function"), + ("callback", "get_callback"), + ], + ) + def test_object_failure(self, object_type: str, get_method: str, object_name: str = "non_existent"): + """Test failure when retrieving non-existent object.""" + with pytest.raises(ValueError) as exc_info: + getattr(self.registry, get_method)(object_name) + + assert f"Unknown {object_type}" in str(exc_info.value) + + 
def test_init_empty_registries(self): + """Test that all registries are initialized as empty dictionaries.""" + assert len(self.registry._optimizers) == 0 + assert len(self.registry._schedulers) == 0 + assert len(self.registry._datasets) == 0 + assert len(self.registry._models) == 0 + assert len(self.registry._data_collators) == 0 + assert len(self.registry._metrics) == 0 + assert len(self.registry._loss_functions) == 0 + assert len(self.registry._callbacks) == 0 + assert len(self.registry._hooks) == 0 + assert len(self.registry._trainer_modules) == 0 + + def test_trainer_module_with_args_and_kwargs(self): + """Test trainer module registration with args class and required kwargs.""" + + class MockArgs: + pass + + class MockTrainer: + pass + + # Register with decorator including args class and required kwargs + self.registry.trainer_module( + "test_trainer_with_args", args_cls=MockArgs, required_kwargs={"param1": "default1", "param2": "default2"} + )(MockTrainer) + + # Verify registration details + module_info = self.registry.get_trainer_module("test_trainer_with_args") + assert module_info["trainer_cls"] == MockTrainer + assert module_info["args_cls"] == MockArgs + assert module_info["required_kwargs"] == {"param1": "default1", "param2": "default2"} + + def test_list_methods(self): + """Test all list methods return correct keys.""" + + # Register some dummy items + class DummyClass: + pass + + self.registry.optimizer("opt1")(DummyClass) + self.registry.scheduler("sched1")(DummyClass) + self.registry.dataset("ds1")(DummyClass) + self.registry.model("model1")(DummyClass) + self.registry.data_collator("coll1")(lambda x: x) + self.registry.loss_function("loss1")(DummyClass) + self.registry.callback("cb1")(DummyClass) + self.registry.trainer_module("tm1")(DummyClass) + + # Test lists + assert self.registry.list_optimizers() == ["opt1"] + assert self.registry.list_schedulers() == ["sched1"] + assert self.registry.list_datasets() == ["ds1"] + assert self.registry.list_models() == ["model1"] + assert self.registry.list_data_collators() == ["coll1"] + assert self.registry.list_loss_functions() == ["loss1"] + assert self.registry.list_callbacks() == ["cb1"] + assert self.registry.list_trainer_modules() == ["tm1"] + + def test_logging_on_registration(self, mocker): + """Test that registration logs messages.""" + mock_logger = mocker.patch("QEfficient.finetune.experimental.core.component_registry.logger") + + class MockClass: + pass + + # Test optimizer registration logging + self.registry.optimizer("test_opt")(MockClass) + mock_logger.info.assert_called_with("Registered optimizer: test_opt") + + # Reset mock + mock_logger.reset_mock() + + # Test trainer module registration logging + self.registry.trainer_module("test_tm")(MockClass) + mock_logger.info.assert_called_with("Registered trainer module: test_tm") + + +class TestGetObjectFunction: + def test_get_object_success(self): + """Test get_object function success case.""" + test_dict = {"key1": "value1", "key2": "value2"} + + result = get_object(test_dict, "key1", "test_type", lambda: ["key1", "key2"]) + assert result == "value1" + + def test_get_object_failure(self): + """Test get_object function failure case.""" + test_dict = {"key1": "value1"} + + with pytest.raises(ValueError) as exc_info: + get_object(test_dict, "nonexistent", "test_type", lambda: ["key1", "key2"]) + + assert "Unknown test_type: nonexistent" in str(exc_info.value) + assert "Available: ['key1', 'key2']" in str(exc_info.value) + + +class TestGlobalRegistry: + def 
test_global_registry_instance(self):
+        """Test that the global registry instance exists and is of the correct type."""
+        assert isinstance(registry, ComponentRegistry)

From 3ab4dd70e44e7b7a751cc246c30ce116d995f795 Mon Sep 17 00:00:00 2001
From: Tanisha Chawada
Date: Fri, 5 Dec 2025 15:07:40 +0530
Subject: [PATCH 2/9] [QEff. Finetune]: Adding optimizer registry and its test
 cases (#649)

Adds a script for registering and retrieving optimizer classes.

The script includes prepare_optimizer(), which returns the optimizer
class and its kwargs. Additionally, there is a test_optimizer.py script
that validates the optimizer registration and retrieval process.

---------

Signed-off-by: Tanisha Chawada
---
 .../finetune/experimental/core/optimizer.py   | 25 +++++
 .../experimental/tests/test_optimizer.py      | 96 +++++++++++++++++++
 2 files changed, 121 insertions(+)
 create mode 100644 QEfficient/finetune/experimental/tests/test_optimizer.py

diff --git a/QEfficient/finetune/experimental/core/optimizer.py b/QEfficient/finetune/experimental/core/optimizer.py
index d647b73a65..d4f82cbebb 100644
--- a/QEfficient/finetune/experimental/core/optimizer.py
+++ b/QEfficient/finetune/experimental/core/optimizer.py
@@ -4,3 +4,28 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
+
+"""
+Optimizer components for the training system.
+"""
+
+import torch.optim as optim
+
+from QEfficient.finetune.experimental.core.component_registry import registry
+
+registry.optimizer("Adam")(optim.Adam)
+registry.optimizer("AdamW")(optim.AdamW)
+registry.optimizer("SGD")(optim.SGD)
+
+
+def prepare_optimizer(opt_config):
+    """
+    Create an optimizer from config.
+
+    Args:
+        opt_config: Dictionary containing the optimizer configuration.
+
+    Returns:
+        Tuple of the optimizer class and its keyword arguments.
+    """
+    opt_name = opt_config.pop("optimizer_name")
+    opt_cls = registry.get_optimizer(opt_name)
+    # Guard against configs that do not specify a learning rate.
+    if "lr" in opt_config:
+        opt_config["lr"] = float(opt_config["lr"])
+    optimizer_cls_and_kwargs = (opt_cls, opt_config)
+    return optimizer_cls_and_kwargs
diff --git a/QEfficient/finetune/experimental/tests/test_optimizer.py b/QEfficient/finetune/experimental/tests/test_optimizer.py
new file mode 100644
index 0000000000..e105d5ddf9
--- /dev/null
+++ b/QEfficient/finetune/experimental/tests/test_optimizer.py
@@ -0,0 +1,96 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import copy
+
+import pytest
+import torch.nn as nn
+import torch.optim as optim
+
+from QEfficient.finetune.experimental.core.component_registry import registry
+from QEfficient.finetune.experimental.core.optimizer import prepare_optimizer
+
+OPTIMIZER_CONFIGS = {
+    "Adam": {
+        "optimizer_name": "Adam",
+        "opt_cls": optim.Adam,
+        "lr": 1e-4,
+        "weight_decay": 0.01,
+        "betas": (0.9, 0.999),
+        "eps": 1e-8,
+        "amsgrad": False,
+    },
+    "AdamW": {
+        "optimizer_name": "AdamW",
+        "opt_cls": optim.AdamW,
+        "lr": 1e-4,
+        "weight_decay": 0.01,
+        "betas": (0.9, 0.999),
+        "eps": 1e-8,
+        "amsgrad": False,
+    },
+    "SGD": {
+        "optimizer_name": "SGD",
+        "opt_cls": optim.SGD,
+        "lr": 1e-4,
+        "momentum": 0.9,
+        "weight_decay": 0.01,
+        "dampening": 0.0,
+        "nesterov": False,
+    },
+    "RMSprop": {
+        "optimizer_name": "RMSprop",
+        "opt_cls": optim.RMSprop,
+    },
+}
+
+# Maps an optimizer name directly to the class to be registered.
+REGISTRY_CONFIG = {
+    "RMSprop": optim.RMSprop,
+}
+
+
+@pytest.fixture
+def dummy_model():
+    return nn.Sequential(
+        nn.Linear(10, 5),
+        nn.ReLU(),
+        nn.Linear(5, 1),
+    )
+
+
+@pytest.mark.parametrize("opt_name", OPTIMIZER_CONFIGS.keys())
+def test_optimizers(opt_name, dummy_model):
+    """Test that all registered optimizers can be created with their configs."""
+    config = copy.deepcopy(OPTIMIZER_CONFIGS[opt_name])
+
+    config.pop("opt_cls")
+    try:
+        optimizer_class_and_kwargs = prepare_optimizer(config)
+        assert optimizer_class_and_kwargs is not None
+    except ValueError as e:
+        assert "Unknown optimizer" in str(e)
+        return
+    optimizer_class = optimizer_class_and_kwargs[0]
+    opt_inst = optimizer_class(dummy_model.parameters(), **optimizer_class_and_kwargs[1])
+    assert isinstance(opt_inst, optim.Optimizer)
+    assert len(list(opt_inst.param_groups)) == 1
+
+    for key in ["lr", "weight_decay", "betas", "eps", "momentum", "dampening", "nesterov", "amsgrad"]:
+        if key in config:
+            assert opt_inst.param_groups[0][key] == config[key], f"{key} mismatch"
+
+
+@pytest.mark.parametrize("opt_name, opt_cls", REGISTRY_CONFIG.items())
+def test_registered_optimizer(opt_name, opt_cls):
+    """Test that the optimizer registers correctly."""
+    registry.optimizer(opt_name)(opt_cls)
+    optimizer_class = registry.get_optimizer(opt_name)
+    assert optimizer_class is not None
+    assert optimizer_class == opt_cls

From dc86ad2196dcea8de90694f5462de00f48dda51c Mon Sep 17 00:00:00 2001
From: Dhiraj Kumar Sah
Date: Fri, 5 Dec 2025 17:39:32 +0530
Subject: [PATCH 3/9] [QEff. Finetune]: Added Base dataset class and SFT
 dataset classes along with their test cases. (#647)

Edited the SFTDataset class to enable custom dataset loading. Updated
the dataset.py file to support only the SFTDataset type. Created a test
file to check the functionality.
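
A minimal usage sketch (the JSON path, column names, and templates below are
illustrative, not part of the patch):

    ds = SFTDataset(
        dataset_name="unused",  # ignored when json_file_path is provided
        split="train",
        json_file_path="data/qa.json",
        prompt_template="Q: {question}",
        completion_template="A: {answer}",
    )
    print(ds[0])  # {"prompt": "Q: ...", "completion": "A: ..."}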
---------

Signed-off-by: Dhiraj Kumar Sah
---
 .../finetune/experimental/core/dataset.py     | 251 +++++++++
 .../experimental/core/utils/dataset_utils.py  |  25 +
 .../experimental/tests/test_dataset.py        | 528 ++++++++++++++++++
 3 files changed, 804 insertions(+)
 create mode 100644 QEfficient/finetune/experimental/tests/test_dataset.py

diff --git a/QEfficient/finetune/experimental/core/dataset.py b/QEfficient/finetune/experimental/core/dataset.py
index d647b73a65..4a243c40b2 100644
--- a/QEfficient/finetune/experimental/core/dataset.py
+++ b/QEfficient/finetune/experimental/core/dataset.py
@@ -4,3 +4,254 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
+
+"""
+Dataset components for the training system.
+"""
+
+import importlib
+import os
+import re
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Dict
+
+from datasets import load_dataset, load_dataset_builder
+from torch.utils.data import Dataset
+
+from QEfficient.finetune.experimental.core.component_registry import registry
+from QEfficient.finetune.experimental.core.utils.dataset_utils import (
+    apply_train_test_split,
+)
+
+
+class BaseDataset(Dataset, ABC):
+    """Base class for all datasets to ensure a consistent interface."""
+
+    def __init__(self, dataset_name: str, split: str, seed: int = 42, **kwargs):
+        self.dataset_name = dataset_name
+        self.split = split
+        self.seed = seed
+        self.kwargs = kwargs
+        self._initialize_dataset()
+
+    @abstractmethod
+    def _initialize_dataset(self):
+        """Subclasses should implement this to load and prepare the dataset."""
+        pass
+
+    @abstractmethod
+    def __len__(self):
+        """Return the number of samples in the dataset."""
+        pass
+
+    @abstractmethod
+    def __getitem__(self, idx):
+        """Should return a dictionary with 'input_ids', 'attention_mask', and 'labels'."""
+        pass
+
+
+@registry.dataset("sft_dataset")
+class SFTDataset(BaseDataset):
+    """
+    A Supervised Fine-Tuning (SFT) dataset class for text data.
+
+    This class handles loading data from Hugging Face datasets or custom JSON files,
+    filtering out invalid samples, and applying prompt/completion templating for SFT tasks.
+
+    Args:
+        dataset_name (str): The name of the dataset to load from Hugging Face datasets.
+            Ignored if json_file_path is provided.
+        split (str): The dataset split to use (e.g., "train", "validation", "test").
+        split_ratio (float): Ratio for the train/test split when only one split is available.
+        seed (int): Random seed for reproducibility.
+        json_file_path (str, optional): Path to a custom JSON file containing the dataset.
+            If provided, this takes precedence over dataset_name.
+        prompt_template (str): A string template for constructing the prompt. Variables in the
+            template should be enclosed in curly braces, e.g., "Answer the question: {question}".
+        completion_template (str): A string template for constructing the completion (target).
+            Variables should be enclosed in curly braces, e.g., "{answer}".
+
+    Raises:
+        RuntimeError: If any variables specified in `prompt_template` or `completion_template`
+            are not found as columns in the loaded dataset.
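+
+    Example (illustrative; assumes a Hugging Face dataset that exposes
+    ``instruction`` and ``output`` columns):
+
+        >>> ds = SFTDataset(
+        ...     dataset_name="tatsu-lab/alpaca",
+        ...     split="train",
+        ...     prompt_template="Instruction: {instruction}",
+        ...     completion_template="{output}",
+        ... )
+        >>> ds[0]["prompt"]  # "Instruction: ..."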
+ """ + + def __init__( + self, + dataset_name: str, + split: str, + split_ratio: float = 0.8, + seed: int = 42, + **kwargs, + ): + self.split_ratio = split_ratio + self.json_file_path = kwargs.get("json_file_path", None) + self.prompt_template = kwargs.get("prompt_template", None) + self.completion_template = kwargs.get("completion_template", None) + self.prompt_func_path = kwargs.get("prompt_func", None) + self.completion_func_path = kwargs.get("completion_func", None) + self.remove_samples_with_empty_columns = kwargs.get("remove_samples_with_empty_columns", True) + + if self.json_file_path not in (None, ""): + if not os.path.isfile(self.json_file_path): + raise FileNotFoundError(f"JSON file not found or invalid: '{self.json_file_path}'") + if (self.prompt_template is None and self.prompt_func_path is None) or ( + self.prompt_template is not None and self.prompt_func_path is not None + ): + raise RuntimeError("Either provide prompt_template or prompt_func in the config.") + if (self.completion_template is None and self.completion_func_path is None) or ( + self.completion_template is not None and self.completion_func_path is not None + ): + raise RuntimeError("Either provide completion_template or completion_func in the config.") + + # Call parent class __init__ which will call _initialize_dataset + super().__init__(dataset_name, split, seed, **kwargs) + + def _initialize_dataset(self): + """ + Initialize the dataset from either HuggingFace or a custom JSON file. + + This method loads the dataset, applies splitting if necessary, and prepares + it for preprocessing with prompt/completion templates. + """ + if self.json_file_path: + # Load dataset from JSON file + self.dataset = load_dataset("json", data_files=self.json_file_path, split="train") + + # Apply train/test split if needed + if self.split in ["train", "test"]: + self.dataset = apply_train_test_split(self.dataset, self.split_ratio, self.split, self.seed) + else: + # Load dataset from HuggingFace + db = load_dataset_builder(self.dataset_name) + available_splits = [] + if db.info.splits is not None: + available_splits = list(db.info.splits.keys()) + + if self.split not in available_splits: + raise ValueError(f"Split {self.split} is not available for dataset {self.dataset_name}.") + + # FIXME: Add streaming support for larger datasets. + self.dataset = load_dataset(self.dataset_name, split=self.split) + + if len(available_splits) == 1: + self.dataset = apply_train_test_split(self.dataset, self.split_ratio, self.split, self.seed) + + self.dataset = self._setup_templates(self.dataset, self.dataset.column_names) + + def _setup_templates(self, dataset, dataset_columns): + """ + Set up prompt/completion templates or functions and apply preprocessing. + """ + if self.prompt_template: + self.prompt_func = None + # Extract variables from templates and check if they exist in dataset columns + prompt_variables = re.findall(r"\{(.*?)\}", self.prompt_template) + for var in prompt_variables: + if var not in dataset_columns: + raise RuntimeError( + f"Prompt template variable '{var}' not found in dataset columns: {dataset_columns}." 
+ ) + else: + prompt_variables = dataset_columns + self.prompt_func = self.import_func(self.prompt_func_path) + + if self.completion_template: + self.completion_func = None + # Extract variables from templates and check if they exist in dataset columns + completion_variables = re.findall(r"\{(.*?)\}", self.completion_template) + for var in completion_variables: + if var not in dataset_columns: + raise RuntimeError( + f"Completion template variable '{var}' not found in dataset columns: {dataset_columns}." + ) + else: + completion_variables = dataset_columns + self.completion_func = self.import_func(self.completion_func_path) + + # Filter out samples with None or empty strings in relevant columns + relevant_columns = list(set(prompt_variables + completion_variables)) + if self.remove_samples_with_empty_columns: + dataset = dataset.filter(lambda example: self._filter_empty_or_none_samples(example, relevant_columns)) + return dataset + + def import_func(self, func_path: str) -> Callable: + if ":" not in func_path: + raise ValueError("func_path must be in the format 'module_file_path:function_name'.") + module_file_path, function_name = func_path.split(":") + + try: + module = importlib.import_module(module_file_path) + except Exception: + raise RuntimeError(f"Unable to import module : {module_file_path}.") + if not hasattr(module, function_name): + raise ValueError(f"Function {function_name} not found in module {module_file_path}.") + return getattr(module, function_name) + + def _filter_empty_or_none_samples(self, example: Dict[str, Any], relevant_columns: list) -> bool: + """ + Filters out samples where any of the relevant columns are None or contain only whitespace. + + Args: + example (Dict[str, Any]): A single sample from the dataset. + relevant_columns (list): List of column names to check for empty or None values. + + Returns: + bool: True if the sample should be kept, False otherwise. + """ + for column in relevant_columns: + value = example.get(column) + if value is None or (isinstance(value, str) and not value.strip()): + return False + return True + + def _preprocess_sample(self, example: Dict[str, Any]) -> Dict[str, str]: + """ + Applies the prompt and completion templates to a single example. + + Args: + example (Dict[str, Any]): A single sample from the dataset. + + Returns: + Dict[str, str]: A dictionary containing the 'prompt' and 'completion' strings. + """ + prompt_text = ( + self.prompt_func(example) if self.prompt_func is not None else self.prompt_template.format(**example) + ) + completion_text = ( + self.completion_func(example) + if self.completion_func is not None + else self.completion_template.format(**example) + ) + return { + "prompt": prompt_text, + "completion": completion_text, + } + + def __len__(self) -> int: + """ + Returns the number of samples in the dataset. + + Returns: + int: The total number of samples. + """ + return self.dataset.num_rows + + def __getitem__(self, idx: int) -> Dict[str, str]: + """ + Retrieves a processed sample from the dataset at the given index. + This method doesn't tokenize the input items, it is expected that the SFTTrainer will handle tokenization. + + Args: + idx (int): The index of the sample to retrieve. + + Returns: + Dict[str, str]: A dictionary containing the processed 'prompt' and 'completion' for the sample. 
+ """ + # Get the raw example using .select and access the first element + example = self.dataset.select(indices=[int(idx)])[0] + + # Apply preprocessing (templating) on the fly + processed_example = self._preprocess_sample(example) + + return processed_example diff --git a/QEfficient/finetune/experimental/core/utils/dataset_utils.py b/QEfficient/finetune/experimental/core/utils/dataset_utils.py index d647b73a65..11e2fecfc3 100644 --- a/QEfficient/finetune/experimental/core/utils/dataset_utils.py +++ b/QEfficient/finetune/experimental/core/utils/dataset_utils.py @@ -4,3 +4,28 @@ # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- +def insert_pad_token(tokenizer): + # Add pad token if it doesn't exist + if tokenizer.pad_token is None: + # Try to use existing special token as pad token + if tokenizer.eos_token is not None: + tokenizer.pad_token = tokenizer.eos_token + elif tokenizer.bos_token is not None: + tokenizer.pad_token = tokenizer.bos_token + elif tokenizer.sep_token is not None: + tokenizer.pad_token = tokenizer.sep_token + else: + # Add a new pad token + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + + +def apply_train_test_split(dataset, split_ratio, split, seed): + """ + Apply train/test split to the dataset based on split_ratio. + """ + splitted_dataset = dataset.train_test_split(test_size=(1 - split_ratio), seed=seed) + if split == "test": + dataset = splitted_dataset["test"] + else: + dataset = splitted_dataset["train"] + return dataset diff --git a/QEfficient/finetune/experimental/tests/test_dataset.py b/QEfficient/finetune/experimental/tests/test_dataset.py new file mode 100644 index 0000000000..ca2fc14505 --- /dev/null +++ b/QEfficient/finetune/experimental/tests/test_dataset.py @@ -0,0 +1,528 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Tests for dataset components. 
+""" + +import json +import os +import tempfile +import unittest +from unittest.mock import MagicMock, patch + +from QEfficient.finetune.experimental.core.dataset import BaseDataset, SFTDataset + +SEED = 42 +SPLIT_RATIO = 0.8 + + +class TestBaseDataset(unittest.TestCase): + """Tests for BaseDataset abstract class.""" + + def test_base_dataset_cannot_be_instantiated(self): + """Test that BaseDataset cannot be instantiated directly.""" + with self.assertRaises(TypeError): + BaseDataset(dataset_name="test", split="train") + + +class TestSFTDataset(unittest.TestCase): + """Tests for SFTDataset class.""" + + def setUp(self): + """Set up test fixtures.""" + # Create a temporary directory for test files + self.test_dir = tempfile.mkdtemp() + self.json_file_path = os.path.join(self.test_dir, "test_dataset.json") + + # Create a dummy JSON dataset + self.dummy_data = [ + {"question": "What is AI?", "answer": "Artificial Intelligence"}, + {"question": "What is ML?", "answer": "Machine Learning"}, + {"question": "What is DL?", "answer": "Deep Learning"}, + {"question": "What is NLP?", "answer": "Natural Language Processing"}, + {"question": "", "answer": "Empty question"}, # Empty question + {"question": "Valid question", "answer": ""}, # Empty answer + {"question": None, "answer": "None question"}, # None question + {"question": "Valid question 2", "answer": None}, # None answer + ] + + with open(self.json_file_path, "w") as f: + json.dump(self.dummy_data, f) + + def tearDown(self): + """Clean up test fixtures.""" + # Remove temporary files and directories + import shutil + + if os.path.exists(self.test_dir): + shutil.rmtree(self.test_dir) + + @patch("QEfficient.finetune.experimental.core.dataset.load_dataset") + @patch("QEfficient.finetune.experimental.core.dataset.load_dataset_builder") + def test_sft_dataset_with_huggingface_dataset_and_templates(self, mock_builder, mock_load): + """Test loading from HuggingFace dataset with templates using mocked data.""" + # Create mock dataset with dummy data + mock_dataset = MagicMock() + mock_dataset.column_names = ["text", "label"] + mock_dataset.num_rows = 3 + + # Mock the select method to return individual samples + def mock_select(indices): + sample_data = [ + {"text": "Sample text 1", "label": "Label 1"}, + {"text": "Sample text 2", "label": "Label 2"}, + {"text": "Sample text 3", "label": "Label 3"}, + ] + return [sample_data[indices[0]]] + + mock_dataset.select = mock_select + mock_dataset.filter = lambda func: mock_dataset # Return self for filtering + + # Mock train_test_split to return a dict with train/test splits + mock_split_result = {"train": mock_dataset, "test": mock_dataset} + mock_dataset.train_test_split = lambda test_size, seed: mock_split_result + + # Mock the dataset builder to indicate multiple splits are available + mock_info = MagicMock() + mock_info.splits = {"train": MagicMock(), "test": MagicMock()} + mock_builder.return_value.info = mock_info + + # Mock load_dataset to return our mock dataset + mock_load.return_value = mock_dataset + + # Create the dataset + dataset = SFTDataset( + dataset_name="dummy_hf_dataset", + split="train", + prompt_template="Text: {text}", + completion_template="Label: {label}", + ) + + self.assertIsNotNone(dataset) + self.assertEqual(len(dataset), 3) + + # Test __getitem__ + sample = dataset[0] + self.assertIn("prompt", sample) + self.assertIn("completion", sample) + self.assertTrue(sample["prompt"].startswith("Text:")) + self.assertTrue(sample["completion"].startswith("Label:")) + + def 
test_sft_dataset_with_json_file_and_templates(self): + """Test loading from JSON file with templates.""" + dataset = SFTDataset( + dataset_name="dummy", # Ignored when json_file_path is provided + split="train", + json_file_path=self.json_file_path, + prompt_template="Q: {question}", + completion_template="A: {answer}", + ) + + self.assertIsNotNone(dataset) + # After filtering empty/None values and applying train split (default 0.8) + # we get a subset of the 4 valid samples + self.assertGreater(len(dataset), 0) + self.assertLessEqual(len(dataset), 4) + + # Test __getitem__ + sample = dataset[0] + self.assertIn("prompt", sample) + self.assertIn("completion", sample) + self.assertTrue(sample["prompt"].startswith("Q:")) + self.assertTrue(sample["completion"].startswith("A:")) + + def test_sft_dataset_json_file_without_filtering(self): + """Test loading from JSON file without filtering empty samples.""" + dataset = SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=self.json_file_path, + prompt_template="Q: {question}", + completion_template="A: {answer}", + remove_samples_with_empty_columns=False, + ) + + # When filtering is disabled and split="train" is used, it still applies train/test split + # So we get ~80% of 8 samples = ~6 samples + self.assertGreater(len(dataset), 0) + self.assertLessEqual(len(dataset), 8) + + def test_sft_dataset_train_test_split_from_json(self): + """Test train/test split when loading from JSON file.""" + train_dataset = SFTDataset( + dataset_name="dummy", + split="train", + split_ratio=SPLIT_RATIO, + json_file_path=self.json_file_path, + prompt_template="Q: {question}", + completion_template="A: {answer}", + seed=SEED, + ) + + test_dataset = SFTDataset( + dataset_name="dummy", + split="test", + split_ratio=SPLIT_RATIO, + json_file_path=self.json_file_path, + prompt_template="Q: {question}", + completion_template="A: {answer}", + seed=SEED, + ) + + # After filtering, we have 4 valid samples + # With split ratio, train should have ~3 samples, test should have ~1 sample + self.assertGreater(len(train_dataset), 0) + self.assertGreater(len(test_dataset), 0) + # Total should equal the filtered dataset size + self.assertEqual(len(train_dataset) + len(test_dataset), 4) + + def test_sft_dataset_with_custom_prompt_function(self): + """Test loading with custom prompt function.""" + # Create a temporary module file with custom functions + func_file_path = os.path.join(self.test_dir, "custom_funcs.py") + with open(func_file_path, "w") as f: + f.write(""" +def custom_prompt(example): + return f"Custom prompt: {example['question']}" + +def custom_completion(example): + return f"Custom completion: {example['answer']}" +""") + + # Add the test directory to sys.path temporarily + import sys + + sys.path.insert(0, self.test_dir) + + try: + dataset = SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=self.json_file_path, + prompt_func="custom_funcs:custom_prompt", + completion_func="custom_funcs:custom_completion", + ) + + self.assertIsNotNone(dataset) + self.assertGreater(len(dataset), 0) + + # Test that custom functions are applied + sample = dataset[0] + self.assertTrue(sample["prompt"].startswith("Custom prompt:")) + self.assertTrue(sample["completion"].startswith("Custom completion:")) + finally: + # Clean up + sys.path.remove(self.test_dir) + if os.path.exists(func_file_path): + os.remove(func_file_path) + + def test_sft_dataset_missing_template_variable(self): + """Test error when template variable is not in dataset columns.""" + with 
self.assertRaises(RuntimeError) as context: + SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=self.json_file_path, + prompt_template="Q: {nonexistent_column}", + completion_template="A: {answer}", + ) + + self.assertIn("not found in dataset columns", str(context.exception)) + + def test_sft_dataset_missing_completion_template_variable(self): + """Test error when completion template variable is not in dataset columns.""" + with self.assertRaises(RuntimeError) as context: + SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=self.json_file_path, + prompt_template="Q: {question}", + completion_template="A: {nonexistent_column}", + ) + + self.assertIn("not found in dataset columns", str(context.exception)) + + def test_sft_dataset_no_prompt_template_or_func(self): + """Test error when neither prompt_template nor prompt_func is provided.""" + with self.assertRaises(RuntimeError) as context: + SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=self.json_file_path, + completion_template="A: {answer}", + ) + + self.assertIn("Either provide prompt_template or prompt_func", str(context.exception)) + + def test_sft_dataset_both_prompt_template_and_func(self): + """Test error when both prompt_template and prompt_func are provided.""" + with self.assertRaises(RuntimeError) as context: + SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=self.json_file_path, + prompt_template="Q: {question}", + prompt_func="module:function", + completion_template="A: {answer}", + ) + + self.assertIn("Either provide prompt_template or prompt_func", str(context.exception)) + + def test_sft_dataset_no_completion_template_or_func(self): + """Test error when neither completion_template nor completion_func is provided.""" + with self.assertRaises(RuntimeError) as context: + SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=self.json_file_path, + prompt_template="Q: {question}", + ) + + self.assertIn( + "Either provide completion_template or completion_func", + str(context.exception), + ) + + def test_sft_dataset_both_completion_template_and_func(self): + """Test error when both completion_template and completion_func are provided.""" + with self.assertRaises(RuntimeError) as context: + SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=self.json_file_path, + prompt_template="Q: {question}", + completion_template="A: {answer}", + completion_func="module:function", + ) + + self.assertIn( + "Either provide completion_template or completion_func", + str(context.exception), + ) + + def test_sft_dataset_invalid_func_path_format(self): + """Test error when func_path doesn't contain colon separator.""" + with self.assertRaises(ValueError) as context: + SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=self.json_file_path, + prompt_func="invalid_format", + completion_template="A: {answer}", + ) + + self.assertIn("must be in the format", str(context.exception)) + + def test_sft_dataset_invalid_module_import(self): + """Test error when module cannot be imported.""" + with self.assertRaises(RuntimeError) as context: + SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=self.json_file_path, + prompt_func="nonexistent_module:function", + completion_template="A: {answer}", + ) + + self.assertIn("Unable to import module", str(context.exception)) + + def test_sft_dataset_invalid_function_name(self): + """Test error when function doesn't exist in module.""" + # Create a temporary module file 
without the expected function + func_file_path = os.path.join(self.test_dir, "test_module.py") + with open(func_file_path, "w") as f: + f.write("def some_other_function():\n pass\n") + + import sys + + sys.path.insert(0, self.test_dir) + + try: + with self.assertRaises(ValueError) as context: + SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=self.json_file_path, + prompt_func="test_module:nonexistent_function", + completion_template="A: {answer}", + ) + + self.assertIn("not found in module", str(context.exception)) + finally: + sys.path.remove(self.test_dir) + if os.path.exists(func_file_path): + os.remove(func_file_path) + + def test_sft_dataset_filter_empty_or_none_samples(self): + """Test filtering of samples with empty or None values.""" + dataset = SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=self.json_file_path, + prompt_template="Q: {question}", + completion_template="A: {answer}", + remove_samples_with_empty_columns=True, + ) + + # Verify that all samples have valid (non-empty) questions and answers + for i in range(len(dataset)): + sample = dataset[i] + # Extract the actual question and answer from the formatted strings + question = sample["prompt"].replace("Q: ", "").strip() + answer = sample["completion"].replace("A: ", "").strip() + # Verify neither is empty + self.assertTrue(len(question) > 0, f"Question should not be empty: {sample['prompt']}") + self.assertTrue(len(answer) > 0, f"Answer should not be empty: {sample['completion']}") + + def test_sft_dataset_getitem_returns_correct_format(self): + """Test that __getitem__ returns the correct format.""" + dataset = SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=self.json_file_path, + prompt_template="Q: {question}", + completion_template="A: {answer}", + ) + + sample = dataset[0] + + # Check that sample is a dictionary + self.assertIsInstance(sample, dict) + + # Check that it has the required keys + self.assertIn("prompt", sample) + self.assertIn("completion", sample) + + # Check that values are strings + self.assertIsInstance(sample["prompt"], str) + self.assertIsInstance(sample["completion"], str) + + def test_sft_dataset_len(self): + """Test __len__ method.""" + dataset = SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=self.json_file_path, + prompt_template="Q: {question}", + completion_template="A: {answer}", + ) + + # Check that len returns an integer + self.assertIsInstance(len(dataset), int) + + # Check that len is positive + self.assertGreater(len(dataset), 0) + + # Check that we can iterate through all samples + for i in range(len(dataset)): + sample = dataset[i] + self.assertIsNotNone(sample) + + def test_sft_dataset_with_multiple_template_variables(self): + """Test templates with multiple variables.""" + # Create a more complex JSON dataset + complex_data = [ + {"context": "The sky", "question": "What color?", "answer": "Blue"}, + {"context": "Math", "question": "What is 2+2?", "answer": "4"}, + ] + + complex_json_path = os.path.join(self.test_dir, "complex_dataset.json") + with open(complex_json_path, "w") as f: + json.dump(complex_data, f) + + try: + dataset = SFTDataset( + dataset_name="dummy", + split="train", + json_file_path=complex_json_path, + prompt_template="Context: {context}\nQuestion: {question}", + completion_template="Answer: {answer}", + ) + + # With split="train", it applies train/test split, so we get ~80% of 2 samples + self.assertGreater(len(dataset), 0) + self.assertLessEqual(len(dataset), 2) + + sample = 
dataset[0]
+            self.assertIn("Context:", sample["prompt"])
+            self.assertIn("Question:", sample["prompt"])
+            self.assertIn("Answer:", sample["completion"])
+        finally:
+            if os.path.exists(complex_json_path):
+                os.remove(complex_json_path)
+
+    def test_sft_dataset_seed_reproducibility(self):
+        """Test that using the same seed produces the same split."""
+        dataset1 = SFTDataset(
+            dataset_name="dummy",
+            split="train",
+            split_ratio=SPLIT_RATIO,
+            json_file_path=self.json_file_path,
+            prompt_template="Q: {question}",
+            completion_template="A: {answer}",
+            seed=SEED,
+        )
+
+        dataset2 = SFTDataset(
+            dataset_name="dummy",
+            split="train",
+            split_ratio=SPLIT_RATIO,
+            json_file_path=self.json_file_path,
+            prompt_template="Q: {question}",
+            completion_template="A: {answer}",
+            seed=SEED,
+        )
+
+        # Both datasets should have the same length
+        self.assertEqual(len(dataset1), len(dataset2))
+
+        # Both datasets should have the same samples
+        for i in range(len(dataset1)):
+            sample1 = dataset1[i]
+            sample2 = dataset2[i]
+            self.assertEqual(sample1["prompt"], sample2["prompt"])
+            self.assertEqual(sample1["completion"], sample2["completion"])
+
+    @patch("QEfficient.finetune.experimental.core.dataset.load_dataset")
+    @patch("QEfficient.finetune.experimental.core.dataset.load_dataset_builder")
+    def test_sft_dataset_invalid_split(self, mock_builder, mock_load):
+        """Test error when requesting an invalid split."""
+        # Mock the dataset builder to return specific splits
+        mock_info = MagicMock()
+        mock_info.splits = {"train": MagicMock(), "validation": MagicMock()}
+        mock_builder.return_value.info = mock_info
+
+        with self.assertRaises(ValueError) as context:
+            SFTDataset(
+                dataset_name="dummy_dataset",
+                split="nonexistent_split",
+                prompt_template="Q: {question}",
+                completion_template="A: {answer}",
+            )
+
+        self.assertIn("not available", str(context.exception))
+
+    def test_sft_dataset_invalid_json_path(self):
+        """Test error when an invalid JSON file path is provided."""
+        invalid_path = "/path/to/nonexistent/file.json"
+
+        with self.assertRaises(FileNotFoundError) as context:
+            SFTDataset(
+                dataset_name="dummy",
+                split="train",
+                json_file_path=invalid_path,
+                prompt_template="Q: {question}",
+                completion_template="A: {answer}",
+            )
+
+        self.assertIn("JSON file not found or invalid", str(context.exception))
+        self.assertIn(invalid_path, str(context.exception))
+
+
+if __name__ == "__main__":
+    unittest.main()

From af9d99c821e2943b0a093691cb52aa345c5acbe0 Mon Sep 17 00:00:00 2001
From: Tanisha Chawada
Date: Mon, 8 Dec 2025 16:44:20 +0530
Subject: [PATCH 4/9] [QEff. Finetune] Adding callback and its test cases. (#652)

Adds a script for registering and retrieving callback classes. It has a
create_callbacks() function which creates an instance of a callback.
Additionally, there is a test_callback.py script that validates the
callback registration and retrieval process.
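
A minimal usage sketch (the patience value is illustrative):

    from QEfficient.finetune.experimental.core.callbacks import create_callbacks

    cb = create_callbacks("early_stopping", early_stopping_patience=3)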
--------- Signed-off-by: Tanisha Chawada --- .../finetune/experimental/core/callbacks.py | 199 ++++++++++++++++++ .../experimental/core/utils/profiler_utils.py | 88 ++++++++ .../experimental/tests/test_callback.py | 63 ++++++ 3 files changed, 350 insertions(+) create mode 100644 QEfficient/finetune/experimental/tests/test_callback.py diff --git a/QEfficient/finetune/experimental/core/callbacks.py b/QEfficient/finetune/experimental/core/callbacks.py index d647b73a65..30659e3bbd 100644 --- a/QEfficient/finetune/experimental/core/callbacks.py +++ b/QEfficient/finetune/experimental/core/callbacks.py @@ -4,3 +4,202 @@ # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- + +import json +import os +from typing import Any, Dict, Optional + +from transformers import ( + DefaultFlowCallback, + EarlyStoppingCallback, + PrinterCallback, + ProgressCallback, + TrainingArguments, +) +from transformers.integrations.integration_utils import TensorBoardCallback +from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState + +from QEfficient.finetune.experimental.core.component_registry import registry +from QEfficient.finetune.experimental.core.utils.profiler_utils import ( + get_op_verifier_ctx, + init_qaic_profiling, + stop_qaic_profiling, +) + +registry.callback("early_stopping")(EarlyStoppingCallback) +registry.callback("printer")(PrinterCallback) +registry.callback("default_flow")(DefaultFlowCallback) +registry.callback("tensorboard")(TensorBoardCallback) + + +@registry.callback("enhanced_progressbar") +class EnhancedProgressCallback(ProgressCallback): + """ + A [`TrainerCallback`] that displays the progress of training or evaluation. + You can modify `max_str_len` to control how long strings are truncated when logging. + """ + + def __init__(self, *args, **kwargs): + """ + Initialize the callback with optional max_str_len parameter to control string truncation length. + + Args: + max_str_len (`int`): + Maximum length of strings to display in logs. + Longer strings will be truncated with a message. + """ + super().__init__(*args, **kwargs) + + def on_train_begin(self, args, state, control, **kwargs): + """Set progress bar description at the start of training.""" + super().on_train_begin(args, state, control, **kwargs) + if self.training_bar is not None: + self.training_bar.set_description("Training Progress") + + def on_log(self, args, state, control, logs=None, **kwargs): + """ + Override the default `on_log` behavior during training to display + the current epoch number, loss, and learning rate in the logs. + """ + if state.is_world_process_zero and self.training_bar is not None: + # make a shallow copy of logs so we can mutate the fields copied + # but avoid doing any value pickling. + shallow_logs = {} + for k, v in logs.items(): + if isinstance(v, str) and len(v) > self.max_str_len: + shallow_logs[k] = ( + f"[String too long to display, length: {len(v)} > {self.max_str_len}. 
" + "Consider increasing `max_str_len` if needed.]" + ) + else: + shallow_logs[k] = v + _ = shallow_logs.pop("total_flos", None) + # round numbers so that it looks better in console + if "epoch" in shallow_logs: + shallow_logs["epoch"] = round(shallow_logs["epoch"], 2) + + updated_dict = {} + if "epoch" in shallow_logs: + updated_dict["epoch"] = shallow_logs["epoch"] + if "loss" in shallow_logs: + updated_dict["loss"] = shallow_logs["loss"] + if "learning_rate" in shallow_logs: + updated_dict["lr"] = shallow_logs["learning_rate"] + self.training_bar.set_postfix(updated_dict) + + +@registry.callback("json_logger") +class JSONLoggerCallback(TrainerCallback): + """ + A [`TrainerCallback`] that logs training and evaluation metrics to a JSON file. + """ + + def __init__(self, log_path=None, *args, **kwargs): + """ + Initialize the callback with the path to the JSON log file. + + Args: + log_path (`str`): + Path to the jsonl file where logs will be saved. + """ + super().__init__(*args, **kwargs) + if log_path is None: + log_path = os.path.join(os.environ.get("OUTPUT_DIR", "./"), "training_logs.jsonl") + self.log_path = log_path + # Ensure the log file is created and empty + with open(self.log_path, "w") as _: + pass + + def on_log( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + logs: Optional[Dict] = None, + **kwargs, + ): + """Append sanitized log metrics (including global_step) to a JSONL file.""" + if logs is None: + return + logs.pop("entropy", None) + logs.pop("mean_token_accuracy", None) + if state.global_step: + logs["global_step"] = state.global_step + if logs is not None: + with open(self.log_path, "a") as f: + json_line = json.dumps(logs, separators=(",", ":")) + f.write(json_line + "\n") + + +@registry.callback("qaic_profiler_callback") +class QAICProfilerCallback(TrainerCallback): + """Callback to profile QAIC devices over a specified training step range.""" + + def __init__(self, *args, **kwargs): + """ + Initialize QAIC profiler settings (start/end steps and target device IDs). + """ + + self.start_step = kwargs.get("start_step", -1) + self.end_step = kwargs.get("end_step", -1) + self.device_ids = kwargs.get("device_ids", [0]) + + def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called at the beginning of a training step. If using gradient accumulation, one training step might take + several inputs. + """ + if state.global_step == self.start_step: + for device_id in self.device_ids: + init_qaic_profiling(True, f"qaic:{device_id}") + elif state.global_step == self.end_step: + for device_id in self.device_ids: + stop_qaic_profiling(True, f"qaic:{device_id}") + + +@registry.callback("qaic_op_by_op_verifier_callback") +class QAICOpByOpVerifierCallback(TrainerCallback): + """Callback to verify QAIC operations step-by-step during a specified training range.""" + + def __init__(self, *args, **kwargs): + """ " + Initialize QAIC Op-by-Op verifier callback with profiling and tolerance settings. + """ + self.start_step = kwargs.get("start_step", -1) + self.end_step = kwargs.get("end_step", -1) + self.trace_dir = kwargs.get("trace_dir", "qaic_op_by_op_traces") + self.atol = kwargs.get("atol", 1e-1) + self.rtol = kwargs.get("rtol", 1e-5) + + def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called at the beginning of a training step. If using gradient accumulation, one training step might take + several inputs. 
+ """ + if self.start_step <= state.global_step < self.end_step: + self.op_verifier_ctx_step = get_op_verifier_ctx( + use_op_by_op_verifier=True, + device_type="qaic", + dump_dir=self.trace_dir, + step=state.global_step, + atol=self.atol, + rtol=self.rtol, + ) + self.op_verifier_ctx_step.__enter__() + + def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called at the end of a training step. If using gradient accumulation, one training step might take + several inputs. + """ + if self.start_step <= state.global_step < self.end_step: + if self.op_verifier_ctx_step is not None: + self.op_verifier_ctx_step.__exit__(None, None, None) + + +def create_callbacks(name: str, **kwargs) -> Any: + """Create a callback instance.""" + callback_class = registry.get_callback(name) + if callback_class is None: + raise ValueError(f"Unknown callback: {name}. Available: {registry.list_callbacks()}") + return callback_class(**kwargs) diff --git a/QEfficient/finetune/experimental/core/utils/profiler_utils.py b/QEfficient/finetune/experimental/core/utils/profiler_utils.py index d647b73a65..e24508e831 100644 --- a/QEfficient/finetune/experimental/core/utils/profiler_utils.py +++ b/QEfficient/finetune/experimental/core/utils/profiler_utils.py @@ -4,3 +4,91 @@ # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- + + +from contextlib import nullcontext +from typing import ContextManager + +import torch + + +def get_op_verifier_ctx( + use_op_by_op_verifier: bool, + device_type: str, + dump_dir: str, + step: int, + ref_device: str = "cpu", + ref_dtype: torch.dtype = torch.float32, + atol: float = 1e-1, + rtol: float = 1e-5, + use_ref_output_on_mismatch: bool = True, +) -> ContextManager: + """Get the op-by-op verifier context manager when op-by-op verification is + enabled. It helps in debuging operator related issues by matching the + operator execution on qaic v/s cpu. This is meant only for qaic backend. + + Args: + use_op_by_op_verifier (bool): Boolean flag to enable op-by-op verifier. + device_type (str): Device on which the model is being executed. + dump_dir (str): Directory to dump the op-by-op verification results. + step (int): Step number for which the op-by-op verification is to be performed. + ref_device (str, optional): Device to use as reference for verification. + Defaults to "cpu". + ref_dtype (torch.dtype, optional): Data type to use as reference + datatype for verification. Defaults to torch.float32. + atol (float, optional): Absolute tolerance to match the results. Defaults to 1e-1. + rtol (float, optional): Relative tolerance to match the results. Defaults to 1e-5. + use_ref_output_on_mismatch (bool, optional): If an operator has a + mismatch with respect to the reference device, use the reference + device outputs and continue rest of the verification. Defaults to True. + + Returns: + ContextManager: Instance of context manager used to verify the operators. + """ + if (not use_op_by_op_verifier) or ("qaic" in device_type): + return nullcontext() + + # Lazily imported qaic_debug when it is actually needed. 
diff --git a/QEfficient/finetune/experimental/core/utils/profiler_utils.py b/QEfficient/finetune/experimental/core/utils/profiler_utils.py
index d647b73a65..e24508e831 100644
--- a/QEfficient/finetune/experimental/core/utils/profiler_utils.py
+++ b/QEfficient/finetune/experimental/core/utils/profiler_utils.py
@@ -4,3 +4,91 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
+
+
+from contextlib import nullcontext
+from typing import ContextManager
+
+import torch
+
+
+def get_op_verifier_ctx(
+    use_op_by_op_verifier: bool,
+    device_type: str,
+    dump_dir: str,
+    step: int,
+    ref_device: str = "cpu",
+    ref_dtype: torch.dtype = torch.float32,
+    atol: float = 1e-1,
+    rtol: float = 1e-5,
+    use_ref_output_on_mismatch: bool = True,
+) -> ContextManager:
+    """Get the op-by-op verifier context manager when op-by-op verification is
+    enabled. It helps in debugging operator-related issues by matching the
+    operator execution on qaic against cpu. This is meant only for the qaic
+    backend.
+
+    Args:
+        use_op_by_op_verifier (bool): Boolean flag to enable op-by-op verifier.
+        device_type (str): Device on which the model is being executed.
+        dump_dir (str): Directory to dump the op-by-op verification results.
+        step (int): Step number for which the op-by-op verification is to be performed.
+        ref_device (str, optional): Device to use as reference for verification.
+            Defaults to "cpu".
+        ref_dtype (torch.dtype, optional): Data type to use as reference
+            datatype for verification. Defaults to torch.float32.
+        atol (float, optional): Absolute tolerance to match the results. Defaults to 1e-1.
+        rtol (float, optional): Relative tolerance to match the results. Defaults to 1e-5.
+        use_ref_output_on_mismatch (bool, optional): If an operator has a
+            mismatch with respect to the reference device, use the reference
+            device outputs and continue rest of the verification. Defaults to True.
+
+    Returns:
+        ContextManager: Instance of context manager used to verify the operators.
+    """
+    # Verification applies only to the qaic backend; everywhere else it is a no-op.
+    if (not use_op_by_op_verifier) or ("qaic" not in device_type):
+        return nullcontext()
+
+    # Lazily import qaic_debug only when it is actually needed.
+    import torch_qaic.debug as qaic_debug
+
+    filter_config = qaic_debug.DispatchFilterConfig.default(device_type)
+    dump_dir = dump_dir + "/mismatches/step_" + str(step)
+    return qaic_debug.OpByOpVerifierMode(
+        ref_device=ref_device,
+        ref_dtype=ref_dtype,
+        atol=atol,
+        rtol=rtol,
+        use_ref_output_on_mismatch=use_ref_output_on_mismatch,
+        filter_config=filter_config,
+        dump_root_dir=dump_dir,
+    )
+
+
+def init_qaic_profiling(use_profiler: bool, device_type: str) -> None:
+    """Initialize the qaic profiling tool. Note: The profiler only works
+    for the qaic backend.
+
+    Args:
+        use_profiler (bool): Boolean flag to enable profiler.
+        device_type (str): Device on which the model is being executed.
+    """
+    if use_profiler and ("qaic" in device_type):
+        # Lazily import qaic's profiling module only when it is actually needed.
+        import torch_qaic.profile as qaic_profile
+
+        qaic_profile.start_profiling(device_type, 1)
+
+
+def stop_qaic_profiling(use_profiler: bool, device_type: str) -> None:
+    """Stop the qaic profiling tool. Note: The profiler only works
+    for the qaic backend.
+
+    Args:
+        use_profiler (bool): Boolean flag to enable profiler.
+        device_type (str): Device on which the model is being executed.
+    """
+    if use_profiler and ("qaic" in device_type):
+        # Lazily import qaic's profiling module only when it is actually needed.
+        import torch_qaic.profile as qaic_profile
+
+        qaic_profile.stop_profiling(device_type)
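+
+
+# Typical pairing (illustrative, hedged): to profile steps [start, end) on
+# device 0, a caller would do roughly
+#
+#     init_qaic_profiling(True, "qaic:0")   # when state.global_step == start
+#     ...run training steps...
+#     stop_qaic_profiling(True, "qaic:0")   # when state.global_step == end
+#
+# Both helpers are deliberate no-ops for non-qaic device strings.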
diff --git a/QEfficient/finetune/experimental/tests/test_callback.py b/QEfficient/finetune/experimental/tests/test_callback.py
new file mode 100644
index 0000000000..59ff4d1173
--- /dev/null
+++ b/QEfficient/finetune/experimental/tests/test_callback.py
@@ -0,0 +1,63 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import pytest
+from transformers import TrainerCallback
+
+from QEfficient.finetune.experimental.core.callbacks import create_callbacks
+from QEfficient.finetune.experimental.core.component_registry import registry
+
+
+class ModelSummaryCallback(TrainerCallback):
+    def __init__(self):
+        pass
+
+
+# Setup test data
+CALLBACK_CONFIGS = {
+    "early_stopping": {
+        "name": "early_stopping",
+        "early_stopping_patience": 3,
+        "early_stopping_threshold": 0.001,
+    },
+    "tensorboard": {"name": "tensorboard", "tb_writer": "SummaryWriter"},
+    "model_summary": {
+        "name": "model_summary",
+        "max_depth": 1,
+    },
+}
+
+REGISTRY_CALLBACK_CONFIGS = {
+    "model_summary": {
+        "name": "model_summary",
+        "max_depth": 1,
+        "callback_class": ModelSummaryCallback,
+    },
+}
+
+
+@pytest.mark.parametrize("callback_name", CALLBACK_CONFIGS.keys())
+def test_callbacks(callback_name):
+    """Test that registered callbacks can be created with their configs."""
+    # Create callbacks using the factory
+    config = CALLBACK_CONFIGS[callback_name]
+    try:
+        callback_inst = create_callbacks(**config)
+    except ValueError as e:
+        assert "Unknown callback" in str(e)
+        return
+    assert callback_inst is not None
+    assert isinstance(callback_inst, TrainerCallback)
+
+
+@pytest.mark.parametrize(
+    "callback_name,callback_class",
+    [(name, cfg["callback_class"]) for name, cfg in REGISTRY_CALLBACK_CONFIGS.items()],
+)
+def test_callbacks_registry(callback_name, callback_class):
+    """Test that a callback is registered and resolved correctly."""
+    registry.callback(callback_name)(callback_class)
+    callback = registry.get_callback(callback_name)
+    assert callback is not None
+    assert callback == callback_class
From d0918e9d4e1a7f1edb8c856cae7aab0227e28a1b Mon Sep 17 00:00:00 2001
From: Tanisha Chawada
Date: Mon, 15 Dec 2025 11:56:54 +0530
Subject: [PATCH 5/9] "[QEff.finetuning] Adding config_manager and its test cases." (#656)

Added ConfigManager to parse the training, model, and dataset related
arguments.

---------

Signed-off-by: Tanisha Chawada
---
 .../experimental/core/config_manager.py       | 749 ++++++++++++++++++
 .../experimental/tests/test_config.yaml       | 104 +++
 .../experimental/tests/test_config_manager.py |  62 ++
 3 files changed, 915 insertions(+)
 create mode 100644 QEfficient/finetune/experimental/tests/test_config.yaml
 create mode 100644 QEfficient/finetune/experimental/tests/test_config_manager.py

diff --git a/QEfficient/finetune/experimental/core/config_manager.py b/QEfficient/finetune/experimental/core/config_manager.py
index d647b73a65..244967f39c 100644
--- a/QEfficient/finetune/experimental/core/config_manager.py
+++ b/QEfficient/finetune/experimental/core/config_manager.py
@@ -4,3 +4,752 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
+"""
+Configuration manager for handling all training configurations.
+Provides centralized configuration loading, validation, and management.
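+
+Illustrative flow (hedged; the trainer wiring that consumes these objects
+lands in later patches, and the file name is hypothetical):
+
+    master_config = parse_arguments(config_path="train_config.yaml")
+    manager = ConfigManager(master_config)
+    manager.validate_config()
+    training_cfg = manager.get_training_config()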
+""" + +import json +import os +from dataclasses import asdict, dataclass, field, fields, is_dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import yaml +from transformers.hf_argparser import HfArgumentParser + +from QEfficient.finetune.experimental.core.component_registry import registry + + +@dataclass +class OptimizerConfig: + """Configuration for optimizers.""" + + optimizer_name: str = field( + default="adamw", + metadata={"help": "The name of the optimizer to use."}, + ) + lr: float = field( + default=5e-5, + metadata={"help": "The initial learning rate for the optimizer."}, + ) + weight_decay: float = field( + default=0.01, + metadata={"help": "The weight decay to apply (if any)."}, + ) + + +@dataclass +class SchedulerConfig: + """Configuration for learning rate schedulers.""" + + scheduler_name: str = field( + default="cosine", + metadata={"help": "The name of the scheduler to use (e.g., 'linear', 'cosine')."}, + ) + warmup_steps: int = field( + default=100, + metadata={ + "help": "Number of steps for the warmup phase. If provided " + "value is within [0-1) range then it will be interpreted as " + "ratio of total training steps for the warmup phase." + }, + ) + + +@dataclass +class DatasetConfig: + """Configuration for datasets.""" + + tokenizer_name: str = field( + default="HuggingFaceTB/SmolLM-135M", + metadata={"help": "The name or path of the tokenizer to use."}, + ) + dataset_type: str = field( + default="seq_completion", + metadata={"help": "The type of dataset (e.g., 'seq_completion')."}, + ) + dataset_name: str = field( + default="knkarthick/samsum", + metadata={"help": "The name or path of the dataset."}, + ) + dataset_subset: str = field( + default="default", + metadata={"help": "The subset of the dataset to use, if applicable."}, + ) + train_split: str = field( + default="train", + metadata={"help": "The name of the training split."}, + ) + test_split: str = field( + default="test", + metadata={"help": "The name of the test/validation split."}, + ) + max_seq_length: int = field( + default=512, + metadata={"help": "The maximum sequence length for tokenization."}, + ) + split_ratio: float = field( + default=0.8, + metadata={"help": "Ratio for train/test split, used when only train_split is provided."}, + ) + input_columns: list[str] = field( + default_factory=lambda: ["text"], + metadata={"help": "List of column names containing input text."}, + ) + target_column: Optional[str] = field( + default=None, + metadata={"help": "Name of the column containing target labels (if applicable)."}, + ) + train_batch_size: int = field( + default=1, + metadata={"help": "Batch size per device during training."}, + ) + eval_batch_size: int = field( + default=1, + metadata={"help": "Batch size per device during evaluation."}, + ) + num_workers: int = field( + default=4, + metadata={"help": "Number of workers for dataset processing."}, + ) + collate_fn: str = field( + default="dynamic_padding", + metadata={"help": "The collation function to use (e.g., 'dynamic_padding')."}, + ) + group_by_length: bool = field( + default=True, + metadata={"help": "Whether to group samples by length to minimize padding."}, + ) + length_column_name: str = field( + default="input_ids", + metadata={"help": "The column name containing the length of the input sequences."}, + ) + dataloader_pin_memory: bool = field( + default=True, + metadata={"help": "Whether to pin GPU memory for dataloaders."}, + ) + dataloader_persistent_workers: bool = field( + default=True, + 
metadata={"help": "Whether to keep dataloader workers alive across epochs."}, + ) + dataloader_prefetch_factor: int = field( + default=1, + metadata={"help": "Number of samples loaded in advance by each worker."}, + ) + dataloader_drop_last: bool = field( + default=False, + metadata={"help": "Whether to drop the last incomplete batch."}, + ) + dataloader_num_workers: int = field( + default=1, + metadata={"help": "Number of workers for the DataLoader."}, + ) + + +@dataclass +class PeftConfig: + """Configuration for PEFT (Parameter-Efficient Fine-Tuning) methods.""" + + lora_r: int = field( + default=8, + metadata={"help": "Lora attention dimension."}, + ) + lora_alpha: int = field( + default=16, + metadata={"help": "Lora alpha."}, + ) + lora_dropout: float = field( + default=0.1, + metadata={"help": "The dropout probability for Lora layers."}, + ) + target_modules: list[str] = field( + default_factory=lambda: ["q_proj", "v_proj"], + metadata={"help": "The modules to apply Lora to."}, + ) + bias: str = field( + default="none", + metadata={"help": "Bias type for Lora ('none', 'all', 'lora_only')."}, + ) + task_type: str = field( + default="CAUSAL_LM", + metadata={"help": "The task type for PEFT (e.g., 'CAUSAL_LM', 'SEQ_2_SEQ_LM')."}, + ) + peft_type: str = field( + default="LORA", + metadata={"help": "The PEFT method to use (e.g., 'LORA', 'IA3')."}, + ) + + +@dataclass +class ModelConfig: + """Configuration for models.""" + + model_name: str = field( + default="HuggingFaceTB/SmolLM-135M", + metadata={"help": "The name or path of the pretrained model."}, + ) + model_type: str = field( + default="hf", + metadata={"help": "The type of model ('hf' for Hugging Face, 'custom' for custom models)."}, + ) + auto_class_name: str = field( + default="AutoModelForCausalLM", + metadata={"help": "The AutoClass name to load the model (e.g., 'AutoModelForCausalLM')."}, + ) + load_in_4bit: bool = field( + default=False, + metadata={"help": "Whether to load the model in 4-bit quantization."}, + ) + use_peft: bool = field( + default=True, + metadata={"help": "Whether to use PEFT (Parameter-Efficient Fine-Tuning)."}, + ) + peft_config: Optional[PeftConfig] = field( + default_factory=PeftConfig, + metadata={"help": "Configuration for PEFT."}, + ) + use_cache: bool = field( + default=False, + metadata={"help": "Whether to use the past key/values in the model for faster decoding."}, + ) + attn_implementation: str = field( + default="sdpa", + metadata={"help": "The attention implementation to use (e.g., 'sdpa', 'eager')."}, + ) + device_map: Optional[str] = field( + default=None, + metadata={"help": "The device map to use for model distribution (e.g., 'auto')."}, + ) + + +@dataclass +class CallbackConfig: + """Configuration for callbacks.""" + + callbacks: Dict[str, Dict[str, Any]] = field( + default_factory=dict, + metadata={"help": "Dictionary of callback configurations, keyed by callback name."}, + ) + + +@dataclass +class GradientCheckpointingKwargs: + """Arguments for gradient checkpointing.""" + + preserve_rng_state: bool = field( + default=True, + metadata={"help": "Whether to preserve the RNG state when checkpointing."}, + ) + use_reenrant: bool = field( + default=False, + metadata={"help": "Whether to use reentrant gradient checkpointing."}, + ) + + +@dataclass +class DdpConfig: + """Arguments for Distributed Data Parallel (DDP) training.""" + + ddp_backend: str = field( + default="qccl", + metadata={"help": "The DDP backend to use (e.g., 'nccl', 'gloo', 'qccl')."}, + ) + ddp_find_unused_parameters: bool = 
field( + default=False, + metadata={"help": "Whether to find unused parameters in DDP."}, + ) + ddp_bucket_cap_mb: Optional[int] = field( + default=25, + metadata={"help": "The bucket size in MB for DDP communication."}, + ) + ddp_broadcast_buffers: bool = field( + default=True, + metadata={"help": "Whether to broadcast buffers in DDP."}, + ) + ddp_timeout: int = field( + default=1800, + metadata={"help": "Timeout for DDP operations in seconds."}, + ) + + +@dataclass +class TrainingConfig: + """Configuration for training.""" + + type: str = field( + default="sft", + metadata={"help": "The type of training (e.g., 'sft' for Supervised Fine-Tuning)."}, + ) + output_dir: str = field( + default="./training_results", + metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, + ) + overwrite_output_dir: bool = field( + default=False, + metadata={"help": "Whether to overwrite the output directory."}, + ) + seed: int = field( + default=42, + metadata={"help": "Random seed for reproducibility."}, + ) + device: str = field( + default="qaic", + metadata={"help": "The device to use for training ('cuda', 'cpu', etc.)."}, + ) + do_eval: bool = field( + default=True, + metadata={"help": "Whether to run evaluation during training."}, + ) + eval_strategy: str = field( + default="epoch", + metadata={"help": "The evaluation strategy to use ('no', 'steps', 'epoch')."}, + ) + eval_steps: int = field( + default=100, + metadata={"help": "Number of update steps between two evaluations."}, + ) + per_device_train_batch_size: int = field( + default=1, + metadata={"help": "Batch size per device during training."}, + ) + per_device_eval_batch_size: int = field( + default=1, + metadata={"help": "Batch size per device during evaluation."}, + ) + gradient_accumulation_steps: int = field( + default=1, + metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."}, + ) + num_train_epochs: int = field( + default=1, + metadata={"help": "Total number of training epochs to perform."}, + ) + max_steps: int = field( + default=-1, + metadata={"help": "If > 0: set total number of training steps to perform."}, + ) + + log_level: str = field( + default="info", + metadata={"help": "Set the verbosity level of the logs ('debug', 'info', 'warning', 'error')."}, + ) + log_on_each_node: bool = field( + default=True, + metadata={"help": "Whether to log on each node in a distributed setup."}, + ) + logging_strategy: str = field( + default="steps", + metadata={"help": "The logging strategy to use ('no', 'steps', 'epoch')."}, + ) + logging_steps: int = field( + default=10, + metadata={"help": "Number of update steps between two loggings."}, + ) + + save_strategy: str = field( + default="epoch", + metadata={"help": "The checkpoint save strategy to use ('no', 'steps', 'epoch')."}, + ) + save_steps: int = field( + default=100, + metadata={"help": "Number of update steps between two checkpoints (if save_strategy is 'steps')."}, + ) + save_total_limit: int = field( + default=5, + metadata={"help": "Limit the total amount of checkpoints. 
Deletes older checkpoints to stay within limit."}, + ) + metric_for_best_model: str = field( + default="eval_loss", + metadata={"help": "The metric to use to compare two models ('eval_loss', etc.)."}, + ) + + dtype: str = field( + default="fp16", + metadata={"help": "The data type to use for training (e.g., 'fp16', 'bf16')."}, + ) + + gradient_checkpointing: bool = field( + default=False, + metadata={"help": "Whether to use gradient checkpointing."}, + ) + gradient_checkpointing_kwargs: Optional[GradientCheckpointingKwargs] = field( + default_factory=GradientCheckpointingKwargs, + metadata={"help": "Arguments for gradient checkpointing."}, + ) + + torch_compile: bool = field( + default=True, + metadata={"help": "Whether to compile the model with `torch.compile`."}, + ) + include_num_input_tokens_seen: bool = field( + default=True, + metadata={"help": "Whether to include the number of input tokens seen in logs."}, + ) + average_tokens_across_devices: bool = field( + default=True, + metadata={"help": "Whether to average tokens across devices in distributed training."}, + ) + + disable_tqdm: Optional[bool] = field( + default=None, + metadata={"help": "Whether to disable the tqdm progress bar."}, + ) + fsdp_config: Optional[Dict[str, Any]] = field( + default=None, + metadata={"help": "FSDP configuration dictionary."}, + ) + deepspeed_config: Optional[Dict[str, Any]] = field( + default=None, + metadata={"help": "DeepSpeed configuration dictionary."}, + ) + accelerator_config: Optional[Dict[str, Any]] = field( + default=None, + metadata={"help": "Accelerate configuration dictionary."}, + ) + ddp_config: Optional[DdpConfig] = field( + default_factory=DdpConfig, + metadata={"help": "DDP configuration dictionary."}, + ) + use_cpu: Optional[bool] = field( + default=None, + metadata={"help": "Whether to explicitly run training on CPU."}, + ) + resume_from_checkpoint: Optional[str] = field( + default=None, + metadata={"help": "Path to a checkpoint to resume training from."}, + ) + restore_callback_states_from_checkpoint: Optional[bool] = field( + default=None, + metadata={"help": "Whether to restore callback states from checkpoint."}, + ) + report_to: Optional[List[str]] = field( + default=None, + metadata={"help": "The list of integrations to report the results and logs to."}, + ) + completion_only_loss: Optional[bool] = field( + default=False, + metadata={"help": "Whether to compute loss only on completion tokens."}, + ) + + +@dataclass +class MasterConfig: + """Main training configuration.""" + + model: ModelConfig = field(default_factory=ModelConfig, metadata={"help": "Configuration for the model."}) + + dataset: DatasetConfig = field(default_factory=DatasetConfig, metadata={"help": "Configuration for the dataset."}) + + optimizers: OptimizerConfig = field( + default_factory=OptimizerConfig, metadata={"help": "Configuration for optimizers."} + ) + + scheduler: SchedulerConfig = field( + default_factory=SchedulerConfig, metadata={"help": "Configuration for the learning rate scheduler."} + ) + + callbacks: CallbackConfig = field(default_factory=CallbackConfig, metadata={"help": "Configuration for callbacks."}) + + training: TrainingConfig = field( + default_factory=TrainingConfig, metadata={"help": "Configuration for training parameters."} + ) + + extra_params: Dict[str, Any] = field( + default_factory=dict, metadata={"help": "Additional top-level parameters not explicitly defined."} + ) + + +def parse_arguments(config_path: Optional[str] = None, args: Optional[List[str]] = None) -> MasterConfig: + 
"""Create argument parser for the new finetuning interface.""" + parser = HfArgumentParser(MasterConfig) + + if config_path: + config_path = os.path.abspath(config_path) + if not os.path.exists(config_path): + raise FileNotFoundError(f"Config file not found: {config_path}") + if not (config_path.endswith(".yaml") or config_path.endswith(".yml")): + raise ValueError(f"Expected a .yaml/.yml file, got: {config_path}") + + try: + (master_config,) = parser.parse_yaml_file(yaml_file=config_path) + return master_config + except Exception as e: + raise ValueError(f"Failed to parse YAML config '{config_path}': {e}") + + args = [] if args is None else args + # If a single positional YAML file was passed via args, parse it as YAML + if len(args) == 1 and (args[0].endswith(".yaml") or args[0].endswith(".yml")): + yaml_path = os.path.abspath(args[0]) + (master_config,) = parser.parse_yaml_file(yaml_file=yaml_path) + else: + (master_config,) = parser.parse_args_into_dataclasses(args=args) + master_config = asdict(master_config) + master_config = MasterConfig(**master_config) + + return master_config + + +class ConfigManager: + """Manages configuration loading, validation, and updates.""" + + def __init__(self, config: MasterConfig): + """ + Initialize ConfigManager with either: + - Path to config file (str or Path) + - Configuration dictionary + - None (creates empty config) + """ + self.config = config + + def load_config(self, config_path: Union[str, Path]) -> None: + """Load configuration from file.""" + config_path = Path(config_path) + + if not config_path.exists(): + raise FileNotFoundError(f"Configuration file not found: {config_path}") + + if config_path.suffix.lower() in [".yaml", ".yml"]: + with open(config_path, "r") as f: + config_dict = yaml.safe_load(f) + elif config_path.suffix.lower() == ".json": + with open(config_path, "r") as f: + config_dict = json.load(f) + else: + raise ValueError(f"Unsupported configuration file format: {config_path.suffix}") + + self.update_config(config_dict) + + def _ensure_extra_params(self, obj) -> Dict[str, Any]: + """Ensure obj.extra_params exists and is a dict; return it.""" + ep = getattr(obj, "extra_params", None) + if ep is None: + setattr(obj, "extra_params", {}) + ep = obj.extra_params + if not isinstance(ep, dict): + raise TypeError("extra_params must be a dict.") + return ep + + def _stash_top_level_extra(self, section: str, nested_key: str, value: Any) -> None: + """Store unknown nested values under MasterConfig.extra_params['section.nested_key'].""" + ep = self._ensure_extra_params(self.config) + ep[f"{section}.{nested_key}"] = value + + def update_config(self, config_dict: Dict[str, Any]) -> None: + """Update configuration with dictionary values.""" + + SPECIAL_KEYS = {"callbacks"} + + for key, value in config_dict.items(): + if hasattr(self.config, key): + target = getattr(self.config, key) + + # Special handling for callbacks (dict inside CallbackConfig) + if key in SPECIAL_KEYS and isinstance(value, dict): + if is_dataclass(target) and hasattr(target, "callbacks") and isinstance(target.callbacks, dict): + for component_name, component_cfg in value.items(): + target.callbacks[component_name] = component_cfg + elif isinstance(target, dict): + target.update(value) + else: + self._stash_top_level_extra(key, "__all__", value) + continue + + if isinstance(value, dict) and is_dataclass(target): + known = {f.name for f in fields(target)} + for nested_key, nested_value in value.items(): + if nested_key in known: + setattr(target, nested_key, 
+
+    def save_config(self, output_path: Union[str, Path]) -> None:
+        """Save current configuration to file."""
+        output_path = Path(output_path)
+        output_path.parent.mkdir(parents=True, exist_ok=True)
+
+        # Serialize the dataclass tree; yaml/json cannot dump MasterConfig directly.
+        config_dict = asdict(self.config)
+
+        if output_path.suffix.lower() in [".yaml", ".yml"]:
+            with open(output_path, "w") as f:
+                yaml.dump(config_dict, f, default_flow_style=False, indent=2)
+        elif output_path.suffix.lower() == ".json":
+            with open(output_path, "w") as f:
+                json.dump(config_dict, f, indent=2)
+        else:
+            raise ValueError(f"Unsupported output file format: {output_path.suffix}")
+
+    def _push(self, errs: List[str], cond: bool, msg: str) -> None:
+        """Append msg to errs if cond is True."""
+        if cond:
+            errs.append(msg)
+
+    def validate_config(self) -> None:
+        """
+        Validate configuration parameters for MasterConfig.
+        """
+        errors: List[str] = []
+
+        cfg = self.config
+        model = getattr(cfg, "model", {})
+        optimizers = getattr(cfg, "optimizers", {})
+        dataset = getattr(cfg, "dataset", {})
+        training = getattr(cfg, "training", {})
+
+        # ---------- Model ----------
+        self._push(errors, not model.get("model_name"), "model.model_name is required.")
+
+        # PEFT validation
+        if model.get("use_peft"):
+            pc = model.get("peft_config", {})
+            self._push(errors, not isinstance(pc, dict), "model.peft_config must be a dict when use_peft=True.")
+            if isinstance(pc, dict):
+                self._push(
+                    errors,
+                    not isinstance(pc.get("lora_r", 0), int) or pc.get("lora_r", 0) <= 0,
+                    "model.peft_config.lora_r must be a positive integer.",
+                )
+                self._push(
+                    errors,
+                    not isinstance(pc.get("lora_alpha", 0), int) or pc.get("lora_alpha", 0) <= 0,
+                    "model.peft_config.lora_alpha must be a positive integer.",
+                )
+                self._push(
+                    errors,
+                    not (0.0 <= float(pc.get("lora_dropout", 0.0)) < 1.0),
+                    "model.peft_config.lora_dropout must be in [0,1).",
+                )
+
+        # ---------- Dataset ----------
+        self._push(errors, not dataset.get("dataset_name"), "dataset.dataset_name is required.")
+        self._push(errors, not dataset.get("tokenizer_name"), "dataset.tokenizer_name is required.")
+        self._push(errors, dataset.get("max_seq_length", 0) <= 0, "dataset.max_seq_length must be positive.")
+
+        # ---------- Training ----------
+        # Batch sizes
+        self._push(
+            errors,
+            training.get("per_device_train_batch_size", 0) <= 0,
+            "training.per_device_train_batch_size must be positive.",
+        )
+        self._push(
+            errors,
+            training.get("per_device_eval_batch_size", 0) <= 0,
+            "training.per_device_eval_batch_size must be positive.",
+        )
+
+        # Epochs / steps
+        n_epochs = training.get("num_train_epochs", 0)
+        max_steps = training.get("max_steps", -1)
+        self._push(
+            errors,
+            n_epochs <= 0 and max_steps <= 0,
+            "Either training.num_train_epochs > 0 or training.max_steps > 0 must be set.",
+        )
+
+        # Gradient accumulation
+        self._push(
+            errors,
+            training.get("gradient_accumulation_steps", 0) <= 0,
+            "training.gradient_accumulation_steps must be positive.",
+        )
+
+        # Logging / saving configs
+        self._push(errors, training.get("logging_steps", 0) < 0, "training.logging_steps must be >= 0.")
+        self._push(errors, training.get("save_total_limit", 0) < 0, "training.save_total_limit must be >= 0.")
training.get("device", None) + if training_device not in valid_devices: + self._push(errors, training_device not in valid_devices, f"training.device must be one of {valid_devices}.") + + # DDP config + ddp = training.get("ddp_config", {}) + if isinstance(ddp, dict): + backend = ddp.get("ddp_backend") + # Accept qccl for Qualcomm, nccl for CUDA, gloo for CPU + self._push( + errors, + backend not in {"qccl", "nccl", "gloo", None}, + "training.ddp_config.ddp_backend must be one of {'qccl','nccl','gloo'} or omitted.", + ) + # -----------Optimizers---------- + self._push(errors, float(optimizers.get("lr", 0)) <= 0, "optimizer.lr must be positive.") + # ---------- Final ---------- + if errors: + # Join messages with bullet points for readability + raise ValueError("Configuration validation failed:\n- " + "\n- ".join(errors)) + + def get_callback_config(self) -> Dict[str, Any]: + """Get callback configuration as dictionary.""" + return self.config.callbacks + + def get_optimizer_config(self) -> Dict[str, Any]: + """Get optimizer configuration as dictionary.""" + return self.config.optimizers + + def get_training_config(self) -> Dict[str, Any]: + """Get training configuration as dictionary.""" + return self.config.training + + def get_scheduler_config(self) -> Dict[str, Any]: + """Get scheduler configuration as dictionary.""" + return self.config.scheduler + + def get_dataset_config(self) -> Dict[str, Any]: + """Get dataset configuration as dictionary.""" + return self.config.dataset + + def get_model_config(self) -> Dict[str, Any]: + """Get model configuration as dictionary.""" + return self.config.model + + def to_dict(self) -> Dict[str, Any]: + """Convert configuration to dictionary.""" + return asdict(self.config) + + def __getattr__(self, name: str) -> Any: + """Allow direct access to config attributes.""" + if hasattr(self.config, name): + return getattr(self.config, name) + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + + +def create_trainer_config(name: str, **dependencies) -> tuple: + """ + Create trainer configuration based on registered trainer modules. + + Args: + name: Name of the trainer type + **dependencies: Any dependencies needed to configure the trainer + + Returns: + tuple: (trainer_class, args_class, additional_kwargs) + """ + config = registry.get_trainer_module(name) + + # Process required kwargs based on available dependencies + additional_kwargs = {} + for kwarg, default in config["required_kwargs"].items(): + if kwarg in dependencies: + additional_kwargs[kwarg] = dependencies[kwarg] + elif default != "REQUIRED": + additional_kwargs[kwarg] = default + + # Check for missing required arguments + for kwarg, default in config["required_kwargs"].items(): + if kwarg not in additional_kwargs and default == "REQUIRED": + raise ValueError(f"Required argument '{kwarg}' not provided for trainer '{name}'") + + return config["trainer_cls"], config["args_cls"], additional_kwargs diff --git a/QEfficient/finetune/experimental/tests/test_config.yaml b/QEfficient/finetune/experimental/tests/test_config.yaml new file mode 100644 index 0000000000..e97e99d583 --- /dev/null +++ b/QEfficient/finetune/experimental/tests/test_config.yaml @@ -0,0 +1,104 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
diff --git a/QEfficient/finetune/experimental/tests/test_config.yaml b/QEfficient/finetune/experimental/tests/test_config.yaml
new file mode 100644
index 0000000000..e97e99d583
--- /dev/null
+++ b/QEfficient/finetune/experimental/tests/test_config.yaml
@@ -0,0 +1,104 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+# Model configuration
+model:
+  model_type: "hf"
+  auto_class_name: "AutoModelForCausalLM"
+  model_name: "HuggingFaceTB/SmolLM-135M"  # Pretrained model name
+  load_in_4bit: false
+  use_peft: true
+  peft_config:
+    lora_r: 8
+    lora_alpha: 16
+    lora_dropout: 0.1
+    target_modules: ["q_proj", "v_proj"]
+    bias: "none"
+    task_type: "CAUSAL_LM"
+    peft_type: "LORA"
+
+# Dataset configuration
+dataset:
+  tokenizer_name: "HuggingFaceTB/SmolLM-135M"
+  dataset_type: "seq_completion"
+  # dataset_name: "Arthur-LAGACHERIE/very-smollm-corpus-0.5M"
+  dataset_name: "knkarthick/samsum"
+  train_split: "train"
+  max_seq_length: 512
+  split_ratio: 0.8  # Ratio for train/test split, used when only train_split is provided
+  test_split: "test"
+  group_by_length: True
+  num_workers: 4
+  dataloader_pin_memory: True
+  dataloader_persistent_workers: True
+  dataloader_prefetch_factor: 1
+  dataloader_drop_last: False
+
+# Training configuration
+training:
+  type: "sft"
+  output_dir: "./training_results"
+  overwrite_output_dir: False
+  seed: 42
+  device: "qaic"
+  do_eval: True
+  eval_strategy: "epoch"
+  eval_steps: 100
+
+  per_device_train_batch_size: 1
+  per_device_eval_batch_size: 1
+  gradient_accumulation_steps: 1
+  num_train_epochs: 1
+  max_steps: -1
+
+  log_level: "info"
+  log_on_each_node: True
+  logging_strategy: "steps"
+  logging_steps: 10
+
+  save_strategy: "epoch"
+  save_total_limit: 5
+  metric_for_best_model: "eval_loss"
+
+  dtype: "fp16"
+  completion_only_loss: True
+  report_to: ["trackio"]
+
+  ddp_config:
+    ddp_backend: "qccl"
+    ddp_find_unused_parameters: False
+    ddp_bucket_cap_mb: 25
+    ddp_broadcast_buffers: null
+    ddp_timeout: 1800
+
+  use_cpu: False
+
+  gradient_checkpointing: False
+  gradient_checkpointing_kwargs:
+    preserve_rng_state: True
+    use_reentrant: False
+
+  torch_compile: True
+  include_num_input_tokens_seen: True
+  average_tokens_across_devices: True
+
+# Optimizer configuration
+optimizers:
+  optimizer_name: "adamw"
+  lr: 5e-5
+  weight_decay: 0.01
+
+scheduler:
+  scheduler_name: "cosine"
+  warmup_steps: 100  # values in [0, 1) are treated as a warmup ratio instead
+
+callbacks:
+  early_stopping:
+    early_stopping_patience: 3
+    early_stopping_threshold: 0.001
+  tensorboard:
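+
+# Note: the empty `tensorboard:` entry is read by YAML as null, i.e. a
+# callback with no options; option mappings (as under `early_stopping`) are
+# stored verbatim in CallbackConfig.callbacks by ConfigManager.update_config.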
diff --git a/QEfficient/finetune/experimental/tests/test_config_manager.py b/QEfficient/finetune/experimental/tests/test_config_manager.py
new file mode 100644
index 0000000000..fd2abfd482
--- /dev/null
+++ b/QEfficient/finetune/experimental/tests/test_config_manager.py
@@ -0,0 +1,62 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+
+from pathlib import Path
+
+import pytest
+
+from QEfficient.finetune.experimental.core.config_manager import ConfigManager, parse_arguments
+
+
+@pytest.fixture
+def config_path() -> Path:
+    here = Path(__file__).resolve().parent
+    return (here / "test_config.yaml").resolve()
+
+
+def test_config(config_path):
+    master_config = parse_arguments(args=[])
+    config_manager = ConfigManager(master_config)
+    assert isinstance(config_manager, ConfigManager)
+    config_manager.load_config(config_path)
+    try:
+        config_manager.validate_config()
+    except Exception as e:
+        pytest.fail(f"Config validation failed with error: {e}")
+
+    # Test that all required sections are present
+    missing = [
+        a
+        for a in ("model", "dataset", "optimizers", "scheduler", "callbacks", "training")
+        if not hasattr(config_manager, a)
+    ]
+    assert not missing, f"Missing attributes: {missing}"
+    trainer_config = config_manager.get_training_config()
+    assert trainer_config is not None
+    assert isinstance(trainer_config, dict)
+    assert all(k in trainer_config for k in ("output_dir", "per_device_train_batch_size", "num_train_epochs", "ddp_config"))
+    dataset_config = config_manager.get_dataset_config()
+    assert dataset_config is not None
+    assert isinstance(dataset_config, dict)
+    assert all(k in dataset_config for k in ("dataset_type", "dataset_name", "tokenizer_name"))
+    model_config = config_manager.get_model_config()
+    assert model_config is not None
+    assert isinstance(model_config, dict)
+    assert all(k in model_config for k in ("model_type", "model_name", "use_peft", "peft_config"))
+    scheduler_config = config_manager.get_scheduler_config()
+    assert scheduler_config is not None
+    assert isinstance(scheduler_config, dict)
+    assert all(k in scheduler_config for k in ("scheduler_name",))
+    callback_config = config_manager.get_callback_config()
+    assert callback_config is not None
+    assert isinstance(callback_config, dict)
+    assert all(k in callback_config for k in ("early_stopping",))
+    optimizer_config = config_manager.get_optimizer_config()
+    assert optimizer_config is not None
+    assert isinstance(optimizer_config, dict)
+    assert all(k in optimizer_config for k in ("optimizer_name", "lr"))
From cc9705d88976e1e4a941a1151b6319ae6dfdf4d8 Mon Sep 17 00:00:00 2001
From: Ann Kuruvilla
Date: Mon, 15 Dec 2025 12:01:32 +0530
Subject: [PATCH 6/9] Revert " "[QEff.finetuning] Adding config_manager and its test cases."" (#666)

Reverts quic/efficient-transformers#656

Signed-off-by: Sharvari Medhe
Signed-off-by: Ann Kuruvilla
---
 .../experimental/core/config_manager.py       | 749 ------------------
 .../experimental/tests/test_config.yaml       | 104 ---
 .../experimental/tests/test_config_manager.py |  62 --
 3 files changed, 915 deletions(-)
 delete mode 100644 QEfficient/finetune/experimental/tests/test_config.yaml
 delete mode 100644 QEfficient/finetune/experimental/tests/test_config_manager.py

diff --git a/QEfficient/finetune/experimental/core/config_manager.py b/QEfficient/finetune/experimental/core/config_manager.py
index 244967f39c..d647b73a65 100644
--- a/QEfficient/finetune/experimental/core/config_manager.py
+++ b/QEfficient/finetune/experimental/core/config_manager.py
@@ -4,752 +4,3 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
-"""
-Configuration manager for handling all training configurations.
-Provides centralized configuration loading, validation, and management. -""" - -import json -import os -from dataclasses import asdict, dataclass, field, fields, is_dataclass -from pathlib import Path -from typing import Any, Dict, List, Optional, Union - -import yaml -from transformers.hf_argparser import HfArgumentParser - -from QEfficient.finetune.experimental.core.component_registry import registry - - -@dataclass -class OptimizerConfig: - """Configuration for optimizers.""" - - optimizer_name: str = field( - default="adamw", - metadata={"help": "The name of the optimizer to use."}, - ) - lr: float = field( - default=5e-5, - metadata={"help": "The initial learning rate for the optimizer."}, - ) - weight_decay: float = field( - default=0.01, - metadata={"help": "The weight decay to apply (if any)."}, - ) - - -@dataclass -class SchedulerConfig: - """Configuration for learning rate schedulers.""" - - scheduler_name: str = field( - default="cosine", - metadata={"help": "The name of the scheduler to use (e.g., 'linear', 'cosine')."}, - ) - warmup_steps: int = field( - default=100, - metadata={ - "help": "Number of steps for the warmup phase. If provided " - "value is within [0-1) range then it will be interpreted as " - "ratio of total training steps for the warmup phase." - }, - ) - - -@dataclass -class DatasetConfig: - """Configuration for datasets.""" - - tokenizer_name: str = field( - default="HuggingFaceTB/SmolLM-135M", - metadata={"help": "The name or path of the tokenizer to use."}, - ) - dataset_type: str = field( - default="seq_completion", - metadata={"help": "The type of dataset (e.g., 'seq_completion')."}, - ) - dataset_name: str = field( - default="knkarthick/samsum", - metadata={"help": "The name or path of the dataset."}, - ) - dataset_subset: str = field( - default="default", - metadata={"help": "The subset of the dataset to use, if applicable."}, - ) - train_split: str = field( - default="train", - metadata={"help": "The name of the training split."}, - ) - test_split: str = field( - default="test", - metadata={"help": "The name of the test/validation split."}, - ) - max_seq_length: int = field( - default=512, - metadata={"help": "The maximum sequence length for tokenization."}, - ) - split_ratio: float = field( - default=0.8, - metadata={"help": "Ratio for train/test split, used when only train_split is provided."}, - ) - input_columns: list[str] = field( - default_factory=lambda: ["text"], - metadata={"help": "List of column names containing input text."}, - ) - target_column: Optional[str] = field( - default=None, - metadata={"help": "Name of the column containing target labels (if applicable)."}, - ) - train_batch_size: int = field( - default=1, - metadata={"help": "Batch size per device during training."}, - ) - eval_batch_size: int = field( - default=1, - metadata={"help": "Batch size per device during evaluation."}, - ) - num_workers: int = field( - default=4, - metadata={"help": "Number of workers for dataset processing."}, - ) - collate_fn: str = field( - default="dynamic_padding", - metadata={"help": "The collation function to use (e.g., 'dynamic_padding')."}, - ) - group_by_length: bool = field( - default=True, - metadata={"help": "Whether to group samples by length to minimize padding."}, - ) - length_column_name: str = field( - default="input_ids", - metadata={"help": "The column name containing the length of the input sequences."}, - ) - dataloader_pin_memory: bool = field( - default=True, - metadata={"help": "Whether to pin GPU memory for dataloaders."}, - ) 
- dataloader_persistent_workers: bool = field( - default=True, - metadata={"help": "Whether to keep dataloader workers alive across epochs."}, - ) - dataloader_prefetch_factor: int = field( - default=1, - metadata={"help": "Number of samples loaded in advance by each worker."}, - ) - dataloader_drop_last: bool = field( - default=False, - metadata={"help": "Whether to drop the last incomplete batch."}, - ) - dataloader_num_workers: int = field( - default=1, - metadata={"help": "Number of workers for the DataLoader."}, - ) - - -@dataclass -class PeftConfig: - """Configuration for PEFT (Parameter-Efficient Fine-Tuning) methods.""" - - lora_r: int = field( - default=8, - metadata={"help": "Lora attention dimension."}, - ) - lora_alpha: int = field( - default=16, - metadata={"help": "Lora alpha."}, - ) - lora_dropout: float = field( - default=0.1, - metadata={"help": "The dropout probability for Lora layers."}, - ) - target_modules: list[str] = field( - default_factory=lambda: ["q_proj", "v_proj"], - metadata={"help": "The modules to apply Lora to."}, - ) - bias: str = field( - default="none", - metadata={"help": "Bias type for Lora ('none', 'all', 'lora_only')."}, - ) - task_type: str = field( - default="CAUSAL_LM", - metadata={"help": "The task type for PEFT (e.g., 'CAUSAL_LM', 'SEQ_2_SEQ_LM')."}, - ) - peft_type: str = field( - default="LORA", - metadata={"help": "The PEFT method to use (e.g., 'LORA', 'IA3')."}, - ) - - -@dataclass -class ModelConfig: - """Configuration for models.""" - - model_name: str = field( - default="HuggingFaceTB/SmolLM-135M", - metadata={"help": "The name or path of the pretrained model."}, - ) - model_type: str = field( - default="hf", - metadata={"help": "The type of model ('hf' for Hugging Face, 'custom' for custom models)."}, - ) - auto_class_name: str = field( - default="AutoModelForCausalLM", - metadata={"help": "The AutoClass name to load the model (e.g., 'AutoModelForCausalLM')."}, - ) - load_in_4bit: bool = field( - default=False, - metadata={"help": "Whether to load the model in 4-bit quantization."}, - ) - use_peft: bool = field( - default=True, - metadata={"help": "Whether to use PEFT (Parameter-Efficient Fine-Tuning)."}, - ) - peft_config: Optional[PeftConfig] = field( - default_factory=PeftConfig, - metadata={"help": "Configuration for PEFT."}, - ) - use_cache: bool = field( - default=False, - metadata={"help": "Whether to use the past key/values in the model for faster decoding."}, - ) - attn_implementation: str = field( - default="sdpa", - metadata={"help": "The attention implementation to use (e.g., 'sdpa', 'eager')."}, - ) - device_map: Optional[str] = field( - default=None, - metadata={"help": "The device map to use for model distribution (e.g., 'auto')."}, - ) - - -@dataclass -class CallbackConfig: - """Configuration for callbacks.""" - - callbacks: Dict[str, Dict[str, Any]] = field( - default_factory=dict, - metadata={"help": "Dictionary of callback configurations, keyed by callback name."}, - ) - - -@dataclass -class GradientCheckpointingKwargs: - """Arguments for gradient checkpointing.""" - - preserve_rng_state: bool = field( - default=True, - metadata={"help": "Whether to preserve the RNG state when checkpointing."}, - ) - use_reenrant: bool = field( - default=False, - metadata={"help": "Whether to use reentrant gradient checkpointing."}, - ) - - -@dataclass -class DdpConfig: - """Arguments for Distributed Data Parallel (DDP) training.""" - - ddp_backend: str = field( - default="qccl", - metadata={"help": "The DDP backend to use (e.g., 
'nccl', 'gloo', 'qccl')."}, - ) - ddp_find_unused_parameters: bool = field( - default=False, - metadata={"help": "Whether to find unused parameters in DDP."}, - ) - ddp_bucket_cap_mb: Optional[int] = field( - default=25, - metadata={"help": "The bucket size in MB for DDP communication."}, - ) - ddp_broadcast_buffers: bool = field( - default=True, - metadata={"help": "Whether to broadcast buffers in DDP."}, - ) - ddp_timeout: int = field( - default=1800, - metadata={"help": "Timeout for DDP operations in seconds."}, - ) - - -@dataclass -class TrainingConfig: - """Configuration for training.""" - - type: str = field( - default="sft", - metadata={"help": "The type of training (e.g., 'sft' for Supervised Fine-Tuning)."}, - ) - output_dir: str = field( - default="./training_results", - metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, - ) - overwrite_output_dir: bool = field( - default=False, - metadata={"help": "Whether to overwrite the output directory."}, - ) - seed: int = field( - default=42, - metadata={"help": "Random seed for reproducibility."}, - ) - device: str = field( - default="qaic", - metadata={"help": "The device to use for training ('cuda', 'cpu', etc.)."}, - ) - do_eval: bool = field( - default=True, - metadata={"help": "Whether to run evaluation during training."}, - ) - eval_strategy: str = field( - default="epoch", - metadata={"help": "The evaluation strategy to use ('no', 'steps', 'epoch')."}, - ) - eval_steps: int = field( - default=100, - metadata={"help": "Number of update steps between two evaluations."}, - ) - per_device_train_batch_size: int = field( - default=1, - metadata={"help": "Batch size per device during training."}, - ) - per_device_eval_batch_size: int = field( - default=1, - metadata={"help": "Batch size per device during evaluation."}, - ) - gradient_accumulation_steps: int = field( - default=1, - metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."}, - ) - num_train_epochs: int = field( - default=1, - metadata={"help": "Total number of training epochs to perform."}, - ) - max_steps: int = field( - default=-1, - metadata={"help": "If > 0: set total number of training steps to perform."}, - ) - - log_level: str = field( - default="info", - metadata={"help": "Set the verbosity level of the logs ('debug', 'info', 'warning', 'error')."}, - ) - log_on_each_node: bool = field( - default=True, - metadata={"help": "Whether to log on each node in a distributed setup."}, - ) - logging_strategy: str = field( - default="steps", - metadata={"help": "The logging strategy to use ('no', 'steps', 'epoch')."}, - ) - logging_steps: int = field( - default=10, - metadata={"help": "Number of update steps between two loggings."}, - ) - - save_strategy: str = field( - default="epoch", - metadata={"help": "The checkpoint save strategy to use ('no', 'steps', 'epoch')."}, - ) - save_steps: int = field( - default=100, - metadata={"help": "Number of update steps between two checkpoints (if save_strategy is 'steps')."}, - ) - save_total_limit: int = field( - default=5, - metadata={"help": "Limit the total amount of checkpoints. 
Deletes older checkpoints to stay within limit."}, - ) - metric_for_best_model: str = field( - default="eval_loss", - metadata={"help": "The metric to use to compare two models ('eval_loss', etc.)."}, - ) - - dtype: str = field( - default="fp16", - metadata={"help": "The data type to use for training (e.g., 'fp16', 'bf16')."}, - ) - - gradient_checkpointing: bool = field( - default=False, - metadata={"help": "Whether to use gradient checkpointing."}, - ) - gradient_checkpointing_kwargs: Optional[GradientCheckpointingKwargs] = field( - default_factory=GradientCheckpointingKwargs, - metadata={"help": "Arguments for gradient checkpointing."}, - ) - - torch_compile: bool = field( - default=True, - metadata={"help": "Whether to compile the model with `torch.compile`."}, - ) - include_num_input_tokens_seen: bool = field( - default=True, - metadata={"help": "Whether to include the number of input tokens seen in logs."}, - ) - average_tokens_across_devices: bool = field( - default=True, - metadata={"help": "Whether to average tokens across devices in distributed training."}, - ) - - disable_tqdm: Optional[bool] = field( - default=None, - metadata={"help": "Whether to disable the tqdm progress bar."}, - ) - fsdp_config: Optional[Dict[str, Any]] = field( - default=None, - metadata={"help": "FSDP configuration dictionary."}, - ) - deepspeed_config: Optional[Dict[str, Any]] = field( - default=None, - metadata={"help": "DeepSpeed configuration dictionary."}, - ) - accelerator_config: Optional[Dict[str, Any]] = field( - default=None, - metadata={"help": "Accelerate configuration dictionary."}, - ) - ddp_config: Optional[DdpConfig] = field( - default_factory=DdpConfig, - metadata={"help": "DDP configuration dictionary."}, - ) - use_cpu: Optional[bool] = field( - default=None, - metadata={"help": "Whether to explicitly run training on CPU."}, - ) - resume_from_checkpoint: Optional[str] = field( - default=None, - metadata={"help": "Path to a checkpoint to resume training from."}, - ) - restore_callback_states_from_checkpoint: Optional[bool] = field( - default=None, - metadata={"help": "Whether to restore callback states from checkpoint."}, - ) - report_to: Optional[List[str]] = field( - default=None, - metadata={"help": "The list of integrations to report the results and logs to."}, - ) - completion_only_loss: Optional[bool] = field( - default=False, - metadata={"help": "Whether to compute loss only on completion tokens."}, - ) - - -@dataclass -class MasterConfig: - """Main training configuration.""" - - model: ModelConfig = field(default_factory=ModelConfig, metadata={"help": "Configuration for the model."}) - - dataset: DatasetConfig = field(default_factory=DatasetConfig, metadata={"help": "Configuration for the dataset."}) - - optimizers: OptimizerConfig = field( - default_factory=OptimizerConfig, metadata={"help": "Configuration for optimizers."} - ) - - scheduler: SchedulerConfig = field( - default_factory=SchedulerConfig, metadata={"help": "Configuration for the learning rate scheduler."} - ) - - callbacks: CallbackConfig = field(default_factory=CallbackConfig, metadata={"help": "Configuration for callbacks."}) - - training: TrainingConfig = field( - default_factory=TrainingConfig, metadata={"help": "Configuration for training parameters."} - ) - - extra_params: Dict[str, Any] = field( - default_factory=dict, metadata={"help": "Additional top-level parameters not explicitly defined."} - ) - - -def parse_arguments(config_path: Optional[str] = None, args: Optional[List[str]] = None) -> MasterConfig: - 
"""Create argument parser for the new finetuning interface.""" - parser = HfArgumentParser(MasterConfig) - - if config_path: - config_path = os.path.abspath(config_path) - if not os.path.exists(config_path): - raise FileNotFoundError(f"Config file not found: {config_path}") - if not (config_path.endswith(".yaml") or config_path.endswith(".yml")): - raise ValueError(f"Expected a .yaml/.yml file, got: {config_path}") - - try: - (master_config,) = parser.parse_yaml_file(yaml_file=config_path) - return master_config - except Exception as e: - raise ValueError(f"Failed to parse YAML config '{config_path}': {e}") - - args = [] if args is None else args - # If a single positional YAML file was passed via args, parse it as YAML - if len(args) == 1 and (args[0].endswith(".yaml") or args[0].endswith(".yml")): - yaml_path = os.path.abspath(args[0]) - (master_config,) = parser.parse_yaml_file(yaml_file=yaml_path) - else: - (master_config,) = parser.parse_args_into_dataclasses(args=args) - master_config = asdict(master_config) - master_config = MasterConfig(**master_config) - - return master_config - - -class ConfigManager: - """Manages configuration loading, validation, and updates.""" - - def __init__(self, config: MasterConfig): - """ - Initialize ConfigManager with either: - - Path to config file (str or Path) - - Configuration dictionary - - None (creates empty config) - """ - self.config = config - - def load_config(self, config_path: Union[str, Path]) -> None: - """Load configuration from file.""" - config_path = Path(config_path) - - if not config_path.exists(): - raise FileNotFoundError(f"Configuration file not found: {config_path}") - - if config_path.suffix.lower() in [".yaml", ".yml"]: - with open(config_path, "r") as f: - config_dict = yaml.safe_load(f) - elif config_path.suffix.lower() == ".json": - with open(config_path, "r") as f: - config_dict = json.load(f) - else: - raise ValueError(f"Unsupported configuration file format: {config_path.suffix}") - - self.update_config(config_dict) - - def _ensure_extra_params(self, obj) -> Dict[str, Any]: - """Ensure obj.extra_params exists and is a dict; return it.""" - ep = getattr(obj, "extra_params", None) - if ep is None: - setattr(obj, "extra_params", {}) - ep = obj.extra_params - if not isinstance(ep, dict): - raise TypeError("extra_params must be a dict.") - return ep - - def _stash_top_level_extra(self, section: str, nested_key: str, value: Any) -> None: - """Store unknown nested values under MasterConfig.extra_params['section.nested_key'].""" - ep = self._ensure_extra_params(self.config) - ep[f"{section}.{nested_key}"] = value - - def update_config(self, config_dict: Dict[str, Any]) -> None: - """Update configuration with dictionary values.""" - - SPECIAL_KEYS = {"callbacks"} - - for key, value in config_dict.items(): - if hasattr(self.config, key): - target = getattr(self.config, key) - - # Special handling for callbacks (dict inside CallbackConfig) - if key in SPECIAL_KEYS and isinstance(value, dict): - if is_dataclass(target) and hasattr(target, "callbacks") and isinstance(target.callbacks, dict): - for component_name, component_cfg in value.items(): - target.callbacks[component_name] = component_cfg - elif isinstance(target, dict): - target.update(value) - else: - self._stash_top_level_extra(key, "__all__", value) - continue - - if isinstance(value, dict) and is_dataclass(target): - known = {f.name for f in fields(target)} - for nested_key, nested_value in value.items(): - if nested_key in known: - setattr(target, nested_key, 
nested_value) - else: - self._stash_top_level_extra(key, nested_key, nested_value) - continue - - if isinstance(value, dict) and isinstance(target, dict): - target.update(value) - continue - setattr(self.config, key, value) - - else: - ep = self._ensure_extra_params(self.config) - ep[key] = value - - def save_config(self, output_path: Union[str, Path]) -> None: - """Save current configuration to file.""" - output_path = Path(output_path) - output_path.parent.mkdir(parents=True, exist_ok=True) - - config_dict = self.config - - if output_path.suffix.lower() in [".yaml", ".yml"]: - with open(output_path, "w") as f: - yaml.dump(config_dict, f, default_flow_style=False, indent=2) - elif output_path.suffix.lower() == ".json": - with open(output_path, "w") as f: - json.dump(config_dict, f, indent=2) - else: - raise ValueError(f"Unsupported output file format: {output_path.suffix}") - - def _push(self, errs: List[str], cond: bool, msg: str) -> None: - """Append msg to errs if cond is True.""" - if cond: - errs.append(msg) - - def validate_config(self) -> None: - """ - Validate configuration parameters for MasterConfig. - """ - errors: List[str] = [] - - cfg = self.config - model = getattr(cfg, "model", {}) - optimizers = getattr(cfg, "optimizers", {}) - dataset = getattr(cfg, "dataset", {}) - training = getattr(cfg, "training", {}) - - # ---------- Model ---------- - self._push(errors, not model.get("model_name"), "model.model_name is required.") - - # PEFT validation - if model.get("use_peft"): - pc = model.get("peft_config", {}) - self._push(errors, not isinstance(pc, dict), "model.peft_config must be a dict when use_peft=True.") - if isinstance(pc, dict): - self._push( - errors, - not isinstance(pc.get("lora_r", 0), int) or pc.get("lora_r", 0) <= 0, - "model.peft_config.lora_r must be a positive integer.", - ) - self._push( - errors, - not isinstance(pc.get("lora_alpha", 0), int) or pc.get("lora_alpha", 0) <= 0, - "model.peft_config.lora_alpha must be a positive integer.", - ) - self._push( - errors, - not (0.0 <= float(pc.get("lora_dropout", 0.0)) < 1.0), - "model.peft_config.lora_dropout must be in [0,1).", - ) - - # ---------- Dataset ---------- - self._push(errors, not dataset.get("dataset_name"), "dataset.dataset_name is required.") - self._push(errors, not dataset.get("tokenizer_name"), "dataset.tokenizer_name is required.") - self._push(errors, dataset.get("max_seq_length", 0) <= 0, "dataset.max_seq_length must be positive.") - - # ---------- Training ---------- - # Batch sizes - self._push( - errors, - training.get("per_device_train_batch_size", 0) <= 0, - "training.per_device_train_batch_size must be positive.", - ) - self._push( - errors, - training.get("per_device_eval_batch_size", 0) <= 0, - "training.per_device_eval_batch_size must be positive.", - ) - - # Epochs / steps - n_epochs = training.get("num_train_epochs", 0) - max_steps = training.get("max_steps", -1) - self._push( - errors, - n_epochs <= 0 and max_steps <= 0, - "Either training.num_train_epochs > 0 or training.max_steps > 0 must be set.", - ) - - # Gradient accumulation - self._push( - errors, - training.get("gradient_accumulation_steps", 0) <= 0, - "training.gradient_accumulation_steps must be positive.", - ) - - # Logging / saving configs - self._push(errors, training.get("logging_steps", 0) < 0, "training.logging_steps must be >= 0.") - self._push(errors, training.get("save_total_limit", 0) < 0, "training.save_total_limit must be >= 0.") - - # Device - valid_devices = ["cpu", "cuda", "qaic"] - training_device = 
training.get("device", None) - if training_device not in valid_devices: - self._push(errors, training_device not in valid_devices, f"training.device must be one of {valid_devices}.") - - # DDP config - ddp = training.get("ddp_config", {}) - if isinstance(ddp, dict): - backend = ddp.get("ddp_backend") - # Accept qccl for Qualcomm, nccl for CUDA, gloo for CPU - self._push( - errors, - backend not in {"qccl", "nccl", "gloo", None}, - "training.ddp_config.ddp_backend must be one of {'qccl','nccl','gloo'} or omitted.", - ) - # -----------Optimizers---------- - self._push(errors, float(optimizers.get("lr", 0)) <= 0, "optimizer.lr must be positive.") - # ---------- Final ---------- - if errors: - # Join messages with bullet points for readability - raise ValueError("Configuration validation failed:\n- " + "\n- ".join(errors)) - - def get_callback_config(self) -> Dict[str, Any]: - """Get callback configuration as dictionary.""" - return self.config.callbacks - - def get_optimizer_config(self) -> Dict[str, Any]: - """Get optimizer configuration as dictionary.""" - return self.config.optimizers - - def get_training_config(self) -> Dict[str, Any]: - """Get training configuration as dictionary.""" - return self.config.training - - def get_scheduler_config(self) -> Dict[str, Any]: - """Get scheduler configuration as dictionary.""" - return self.config.scheduler - - def get_dataset_config(self) -> Dict[str, Any]: - """Get dataset configuration as dictionary.""" - return self.config.dataset - - def get_model_config(self) -> Dict[str, Any]: - """Get model configuration as dictionary.""" - return self.config.model - - def to_dict(self) -> Dict[str, Any]: - """Convert configuration to dictionary.""" - return asdict(self.config) - - def __getattr__(self, name: str) -> Any: - """Allow direct access to config attributes.""" - if hasattr(self.config, name): - return getattr(self.config, name) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") - - -def create_trainer_config(name: str, **dependencies) -> tuple: - """ - Create trainer configuration based on registered trainer modules. - - Args: - name: Name of the trainer type - **dependencies: Any dependencies needed to configure the trainer - - Returns: - tuple: (trainer_class, args_class, additional_kwargs) - """ - config = registry.get_trainer_module(name) - - # Process required kwargs based on available dependencies - additional_kwargs = {} - for kwarg, default in config["required_kwargs"].items(): - if kwarg in dependencies: - additional_kwargs[kwarg] = dependencies[kwarg] - elif default != "REQUIRED": - additional_kwargs[kwarg] = default - - # Check for missing required arguments - for kwarg, default in config["required_kwargs"].items(): - if kwarg not in additional_kwargs and default == "REQUIRED": - raise ValueError(f"Required argument '{kwarg}' not provided for trainer '{name}'") - - return config["trainer_cls"], config["args_cls"], additional_kwargs diff --git a/QEfficient/finetune/experimental/tests/test_config.yaml b/QEfficient/finetune/experimental/tests/test_config.yaml deleted file mode 100644 index e97e99d583..0000000000 --- a/QEfficient/finetune/experimental/tests/test_config.yaml +++ /dev/null @@ -1,104 +0,0 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
-# SPDX-License-Identifier: BSD-3-Clause -# -# ----------------------------------------------------------------------------- - -# model configuration -model: - model_type: "hf" - auto_class_name: "AutoModelForCausalLM" - model_name: "HuggingFaceTB/SmolLM-135M" # Pretrained model name - load_in_4bit: false - use_peft: true - peft_config: - lora_r: 8 - lora_alpha: 16 - lora_dropout: 0.1 - target_modules: ["q_proj", "v_proj"] - bias: "none" - task_type: "CAUSAL_LM" - peft_type: "LORA" - -# Dataset configuration -dataset: - tokenizer_name: "HuggingFaceTB/SmolLM-135M" - dataset_type: "seq_completion" - # dataset_name: "Arthur-LAGACHERIE/very-smollm-corpus-0.5M" - dataset_name: "knkarthick/samsum" - train_split: "train" - max_seq_length: 512 - split_ratio: 0.8 # Ratio for train/test split, used when only train_split is provided - test_split: "test" - group_by_length: True - num_workers: 4 - dataloader_pin_memory: True - dataloader_persistent_workers: True - dataloader_prefetch_factor: 1 - dataloader_drop_last: False - -# Training configuration -training: - type: "sft" - output_dir: "./training_results" - overwrite_output_dir: False - seed: 42 - device: "qaic" - do_eval: True - eval_strategy: "epoch" - eval_steps: 100 - - per_device_train_batch_size: 1 - per_device_eval_batch_size: 1 - gradient_accumulation_steps: 1 - num_train_epochs: 1 - max_steps: -1 - - log_level: "info" - log_on_each_node: True - logging_strategy: "steps" - logging_steps: 10 - - save_strategy: "epoch" - save_total_limit: 5 - metric_for_best_model: "eval_loss" - - dtype: "fp16" - completion_only_loss: True - report_to: "trackio" - - ddp_config: - ddp_backend: "qccl" - ddp_find_unused_parameters: False - ddp_bucket_cap_mb: 25 - ddp_broadcast_buffers: null - ddp_timeout: 1800 - - use_cpu: False - - gradient_checkpointing: False - gradient_checkpointing_kwargs: - preserve_rng_state : True - use_reenrant: False - - torch_compile: True - include_num_input_tokens_seen: True - average_tokens_across_devices: True - -# Optimizer configuration -optimizers: - optimizer_name: "adamw" - lr: 5e-5 - weight_decay: 0.01 - -scheduler: - scheduler_name: "cosine" - warmup_steps: 100 # warmup_steps or warmup_ratio - -callbacks: - early_stopping: - early_stopping_patience: 3 - early_stopping_threshold: 0.001 - tensorboard: - diff --git a/QEfficient/finetune/experimental/tests/test_config_manager.py b/QEfficient/finetune/experimental/tests/test_config_manager.py deleted file mode 100644 index fd2abfd482..0000000000 --- a/QEfficient/finetune/experimental/tests/test_config_manager.py +++ /dev/null @@ -1,62 +0,0 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
-# SPDX-License-Identifier: BSD-3-Clause -# -# ----------------------------------------------------------------------------- - - -from pathlib import Path - -import pytest - -from QEfficient.finetune.experimental.core.config_manager import ConfigManager, parse_arguments - - -@pytest.fixture -def config_path() -> Path: - here = Path(__file__).resolve().parent - return (here / "test_config.yaml").resolve() - - -def test_config(config_path): - master_config = parse_arguments(args=[]) - config_manager = ConfigManager(master_config) - assert isinstance(config_manager, ConfigManager) - config_manager.load_config(config_path) - try: - config_manager.validate_config() - except Exception as e: - pytest.fail(f"Config validation failed with error: {e}") - - # Test that all required fields are present - missing = [ - a - for a in ("model", "dataset", "optimizers", "scheduler", "callbacks", "training") - if not hasattr(config_manager, a) - ] - assert not missing, f"Missing attributes: {missing}" - trainer_config = config_manager.get_training_config() - assert trainer_config is not None - assert isinstance(trainer_config, dict) - assert (hasattr(trainer_config, attr) for attr in ("output_dir", "train_batch_size", "num_epochs", "ddp_config")) - dataset_config = config_manager.get_dataset_config() - assert dataset_config is not None - assert isinstance(dataset_config, dict) - assert (hasattr(dataset_config, attr) for attr in ("dataset_type", "dataset_name", "tokenizer_name")) - model_config = config_manager.get_model_config() - assert model_config is not None - assert isinstance(model_config, dict) - assert (hasattr(model_config, attr) for attr in ("model_type", "model_name", "use_peft", "peft_config")) - scheduler_config = config_manager.get_scheduler_config() - assert scheduler_config is not None - assert isinstance(scheduler_config, dict) - assert (hasattr(scheduler_config, attr) for attr in ("scheduler_name")) - callback_config = config_manager.get_callback_config() - assert callback_config is not None - assert isinstance(callback_config, dict) - assert (hasattr(callback_config, attr) for attr in ("earlystopping")) - optimizer_config = config_manager.get_optimizer_config() - assert optimizer_config is not None - assert isinstance(optimizer_config, dict) - assert (hasattr(optimizer_config, attr) for attr in ("optimizer_name", "lr")) From 64c3bf3ba6657b7d139b637b98a075fad9bc8462 Mon Sep 17 00:00:00 2001 From: Tanisha Chawada Date: Mon, 15 Dec 2025 13:30:42 +0530 Subject: [PATCH 7/9] "[QEff.finetuning} Rebasing: hf_config_mananger." (#667) Signed-off-by: Tanisha Chawada --- .../experimental/core/config_manager.py | 747 ++++++++++++++++++ .../experimental/tests/test_config.yaml | 104 +++ .../experimental/tests/test_config_manager.py | 62 ++ 3 files changed, 913 insertions(+) create mode 100644 QEfficient/finetune/experimental/tests/test_config.yaml create mode 100644 QEfficient/finetune/experimental/tests/test_config_manager.py diff --git a/QEfficient/finetune/experimental/core/config_manager.py b/QEfficient/finetune/experimental/core/config_manager.py index d647b73a65..b28c2e1e33 100644 --- a/QEfficient/finetune/experimental/core/config_manager.py +++ b/QEfficient/finetune/experimental/core/config_manager.py @@ -4,3 +4,750 @@ # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- +""" +Configuration manager for handling all training configurations. +Provides centralized configuration loading, validation, and management. 
+""" + +import json +import os +from dataclasses import asdict, dataclass, field, fields, is_dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import yaml +from transformers.hf_argparser import HfArgumentParser + +from QEfficient.finetune.experimental.core.component_registry import registry + + +@dataclass +class OptimizerConfig: + """Configuration for optimizers.""" + + optimizer_name: str = field( + default="adamw", + metadata={"help": "The name of the optimizer to use."}, + ) + lr: float = field( + default=5e-5, + metadata={"help": "The initial learning rate for the optimizer."}, + ) + weight_decay: float = field( + default=0.01, + metadata={"help": "The weight decay to apply (if any)."}, + ) + + +@dataclass +class SchedulerConfig: + """Configuration for learning rate schedulers.""" + + scheduler_name: str = field( + default="cosine", + metadata={"help": "The name of the scheduler to use (e.g., 'linear', 'cosine')."}, + ) + warmup_steps: int = field( + default=100, + metadata={ + "help": "Number of steps for the warmup phase. If provided " + "value is within [0-1) range then it will be interpreted as " + "ratio of total training steps for the warmup phase." + }, + ) + + +@dataclass +class DatasetConfig: + """Configuration for datasets.""" + + tokenizer_name: str = field( + default="HuggingFaceTB/SmolLM-135M", + metadata={"help": "The name or path of the tokenizer to use."}, + ) + dataset_type: str = field( + default="seq_completion", + metadata={"help": "The type of dataset (e.g., 'seq_completion')."}, + ) + dataset_name: str = field( + default="knkarthick/samsum", + metadata={"help": "The name or path of the dataset."}, + ) + dataset_subset: str = field( + default="default", + metadata={"help": "The subset of the dataset to use, if applicable."}, + ) + train_split: str = field( + default="train", + metadata={"help": "The name of the training split."}, + ) + test_split: str = field( + default="test", + metadata={"help": "The name of the test/validation split."}, + ) + max_seq_length: int = field( + default=512, + metadata={"help": "The maximum sequence length for tokenization."}, + ) + split_ratio: float = field( + default=0.8, + metadata={"help": "Ratio for train/test split, used when only train_split is provided."}, + ) + input_columns: list[str] = field( + default_factory=lambda: ["text"], + metadata={"help": "List of column names containing input text."}, + ) + target_column: Optional[str] = field( + default=None, + metadata={"help": "Name of the column containing target labels (if applicable)."}, + ) + train_batch_size: int = field( + default=1, + metadata={"help": "Batch size per device during training."}, + ) + eval_batch_size: int = field( + default=1, + metadata={"help": "Batch size per device during evaluation."}, + ) + num_workers: int = field( + default=4, + metadata={"help": "Number of workers for dataset processing."}, + ) + collate_fn: str = field( + default="dynamic_padding", + metadata={"help": "The collation function to use (e.g., 'dynamic_padding')."}, + ) + group_by_length: bool = field( + default=True, + metadata={"help": "Whether to group samples by length to minimize padding."}, + ) + length_column_name: str = field( + default="input_ids", + metadata={"help": "The column name containing the length of the input sequences."}, + ) + dataloader_pin_memory: bool = field( + default=True, + metadata={"help": "Whether to pin GPU memory for dataloaders."}, + ) + dataloader_persistent_workers: bool = field( + default=True, + 
metadata={"help": "Whether to keep dataloader workers alive across epochs."}, + ) + dataloader_prefetch_factor: int = field( + default=1, + metadata={"help": "Number of samples loaded in advance by each worker."}, + ) + dataloader_drop_last: bool = field( + default=False, + metadata={"help": "Whether to drop the last incomplete batch."}, + ) + dataloader_num_workers: int = field( + default=1, + metadata={"help": "Number of workers for the DataLoader."}, + ) + + +@dataclass +class PeftConfig: + """Configuration for PEFT (Parameter-Efficient Fine-Tuning) methods.""" + + lora_r: int = field( + default=8, + metadata={"help": "Lora attention dimension."}, + ) + lora_alpha: int = field( + default=16, + metadata={"help": "Lora alpha."}, + ) + lora_dropout: float = field( + default=0.1, + metadata={"help": "The dropout probability for Lora layers."}, + ) + target_modules: list[str] = field( + default_factory=lambda: ["q_proj", "v_proj"], + metadata={"help": "The modules to apply Lora to."}, + ) + bias: str = field( + default="none", + metadata={"help": "Bias type for Lora ('none', 'all', 'lora_only')."}, + ) + task_type: str = field( + default="CAUSAL_LM", + metadata={"help": "The task type for PEFT (e.g., 'CAUSAL_LM', 'SEQ_2_SEQ_LM')."}, + ) + peft_type: str = field( + default="LORA", + metadata={"help": "The PEFT method to use (e.g., 'LORA', 'IA3')."}, + ) + + +@dataclass +class ModelConfig: + """Configuration for models.""" + + model_name: str = field( + default="HuggingFaceTB/SmolLM-135M", + metadata={"help": "The name or path of the pretrained model."}, + ) + model_type: str = field( + default="hf", + metadata={"help": "The type of model ('hf' for Hugging Face, 'custom' for custom models)."}, + ) + auto_class_name: str = field( + default="AutoModelForCausalLM", + metadata={"help": "The AutoClass name to load the model (e.g., 'AutoModelForCausalLM')."}, + ) + load_in_4bit: bool = field( + default=False, + metadata={"help": "Whether to load the model in 4-bit quantization."}, + ) + use_peft: bool = field( + default=True, + metadata={"help": "Whether to use PEFT (Parameter-Efficient Fine-Tuning)."}, + ) + peft_config: Optional[PeftConfig] = field( + default_factory=PeftConfig, + metadata={"help": "Configuration for PEFT."}, + ) + use_cache: bool = field( + default=False, + metadata={"help": "Whether to use the past key/values in the model for faster decoding."}, + ) + attn_implementation: str = field( + default="sdpa", + metadata={"help": "The attention implementation to use (e.g., 'sdpa', 'eager')."}, + ) + device_map: Optional[str] = field( + default=None, + metadata={"help": "The device map to use for model distribution (e.g., 'auto')."}, + ) + + +@dataclass +class CallbackConfig: + """Configuration for callbacks.""" + + callbacks: Dict[str, Dict[str, Any]] = field( + default_factory=dict, + metadata={"help": "Dictionary of callback configurations, keyed by callback name."}, + ) + + +@dataclass +class GradientCheckpointingKwargs: + """Arguments for gradient checkpointing.""" + + preserve_rng_state: bool = field( + default=True, + metadata={"help": "Whether to preserve the RNG state when checkpointing."}, + ) + use_reenrant: bool = field( + default=False, + metadata={"help": "Whether to use reentrant gradient checkpointing."}, + ) + + +@dataclass +class DdpConfig: + """Arguments for Distributed Data Parallel (DDP) training.""" + + ddp_backend: str = field( + default="qccl", + metadata={"help": "The DDP backend to use (e.g., 'nccl', 'gloo', 'qccl')."}, + ) + ddp_find_unused_parameters: bool = 
field( + default=False, + metadata={"help": "Whether to find unused parameters in DDP."}, + ) + ddp_bucket_cap_mb: Optional[int] = field( + default=25, + metadata={"help": "The bucket size in MB for DDP communication."}, + ) + ddp_broadcast_buffers: bool = field( + default=True, + metadata={"help": "Whether to broadcast buffers in DDP."}, + ) + ddp_timeout: int = field( + default=1800, + metadata={"help": "Timeout for DDP operations in seconds."}, + ) + + +@dataclass +class TrainingConfig: + """Configuration for training.""" + + type: str = field( + default="sft", + metadata={"help": "The type of training (e.g., 'sft' for Supervised Fine-Tuning)."}, + ) + output_dir: str = field( + default="./training_results", + metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, + ) + overwrite_output_dir: bool = field( + default=False, + metadata={"help": "Whether to overwrite the output directory."}, + ) + seed: int = field( + default=42, + metadata={"help": "Random seed for reproducibility."}, + ) + device: str = field( + default="qaic", + metadata={"help": "The device to use for training ('cuda', 'cpu', etc.)."}, + ) + do_eval: bool = field( + default=True, + metadata={"help": "Whether to run evaluation during training."}, + ) + eval_strategy: str = field( + default="epoch", + metadata={"help": "The evaluation strategy to use ('no', 'steps', 'epoch')."}, + ) + eval_steps: int = field( + default=100, + metadata={"help": "Number of update steps between two evaluations."}, + ) + per_device_train_batch_size: int = field( + default=1, + metadata={"help": "Batch size per device during training."}, + ) + per_device_eval_batch_size: int = field( + default=1, + metadata={"help": "Batch size per device during evaluation."}, + ) + gradient_accumulation_steps: int = field( + default=1, + metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."}, + ) + num_train_epochs: int = field( + default=1, + metadata={"help": "Total number of training epochs to perform."}, + ) + max_steps: int = field( + default=-1, + metadata={"help": "If > 0: set total number of training steps to perform."}, + ) + + log_level: str = field( + default="info", + metadata={"help": "Set the verbosity level of the logs ('debug', 'info', 'warning', 'error')."}, + ) + log_on_each_node: bool = field( + default=True, + metadata={"help": "Whether to log on each node in a distributed setup."}, + ) + logging_strategy: str = field( + default="steps", + metadata={"help": "The logging strategy to use ('no', 'steps', 'epoch')."}, + ) + logging_steps: int = field( + default=10, + metadata={"help": "Number of update steps between two loggings."}, + ) + + save_strategy: str = field( + default="epoch", + metadata={"help": "The checkpoint save strategy to use ('no', 'steps', 'epoch')."}, + ) + save_steps: int = field( + default=100, + metadata={"help": "Number of update steps between two checkpoints (if save_strategy is 'steps')."}, + ) + save_total_limit: int = field( + default=5, + metadata={"help": "Limit the total amount of checkpoints. 
Deletes older checkpoints to stay within limit."}, + ) + metric_for_best_model: str = field( + default="eval_loss", + metadata={"help": "The metric to use to compare two models ('eval_loss', etc.)."}, + ) + + dtype: str = field( + default="fp16", + metadata={"help": "The data type to use for training (e.g., 'fp16', 'bf16')."}, + ) + + gradient_checkpointing: bool = field( + default=False, + metadata={"help": "Whether to use gradient checkpointing."}, + ) + gradient_checkpointing_kwargs: Optional[GradientCheckpointingKwargs] = field( + default_factory=GradientCheckpointingKwargs, + metadata={"help": "Arguments for gradient checkpointing."}, + ) + + torch_compile: bool = field( + default=True, + metadata={"help": "Whether to compile the model with `torch.compile`."}, + ) + include_num_input_tokens_seen: bool = field( + default=True, + metadata={"help": "Whether to include the number of input tokens seen in logs."}, + ) + average_tokens_across_devices: bool = field( + default=True, + metadata={"help": "Whether to average tokens across devices in distributed training."}, + ) + + disable_tqdm: Optional[bool] = field( + default=None, + metadata={"help": "Whether to disable the tqdm progress bar."}, + ) + fsdp_config: Optional[Dict[str, Any]] = field( + default=None, + metadata={"help": "FSDP configuration dictionary."}, + ) + deepspeed_config: Optional[Dict[str, Any]] = field( + default=None, + metadata={"help": "DeepSpeed configuration dictionary."}, + ) + accelerator_config: Optional[Dict[str, Any]] = field( + default=None, + metadata={"help": "Accelerate configuration dictionary."}, + ) + ddp_config: Optional[DdpConfig] = field( + default_factory=DdpConfig, + metadata={"help": "DDP configuration dictionary."}, + ) + use_cpu: Optional[bool] = field( + default=None, + metadata={"help": "Whether to explicitly run training on CPU."}, + ) + resume_from_checkpoint: Optional[str] = field( + default=None, + metadata={"help": "Path to a checkpoint to resume training from."}, + ) + restore_callback_states_from_checkpoint: Optional[bool] = field( + default=None, + metadata={"help": "Whether to restore callback states from checkpoint."}, + ) + report_to: Optional[List[str]] = field( + default=None, + metadata={"help": "The list of integrations to report the results and logs to."}, + ) + completion_only_loss: Optional[bool] = field( + default=False, + metadata={"help": "Whether to compute loss only on completion tokens."}, + ) + + +@dataclass +class MasterConfig: + """Main training configuration.""" + + model: ModelConfig = field(default_factory=ModelConfig, metadata={"help": "Configuration for the model."}) + + dataset: DatasetConfig = field(default_factory=DatasetConfig, metadata={"help": "Configuration for the dataset."}) + + optimizers: OptimizerConfig = field( + default_factory=OptimizerConfig, metadata={"help": "Configuration for optimizers."} + ) + + scheduler: SchedulerConfig = field( + default_factory=SchedulerConfig, metadata={"help": "Configuration for the learning rate scheduler."} + ) + + callbacks: CallbackConfig = field(default_factory=CallbackConfig, metadata={"help": "Configuration for callbacks."}) + + training: TrainingConfig = field( + default_factory=TrainingConfig, metadata={"help": "Configuration for training parameters."} + ) + + extra_params: Dict[str, Any] = field( + default_factory=dict, metadata={"help": "Additional top-level parameters not explicitly defined."} + ) + + +def parse_arguments(config_path: Optional[str] = None, args: Optional[List[str]] = None) -> MasterConfig: + 
"""Create argument parser for the new finetuning interface.""" + parser = HfArgumentParser(MasterConfig) + + if config_path: + config_path = os.path.abspath(config_path) + if not os.path.exists(config_path): + raise FileNotFoundError(f"Config file not found: {config_path}") + if not (config_path.endswith(".yaml") or config_path.endswith(".yml")): + raise ValueError(f"Expected a .yaml/.yml file, got: {config_path}") + + try: + (master_config,) = parser.parse_yaml_file(yaml_file=config_path) + return master_config + except Exception as e: + raise ValueError(f"Failed to parse YAML config '{config_path}': {e}") + + args = [] if args is None else args + # If a single positional YAML file was passed via args, parse it as YAML + if len(args) == 1 and (args[0].endswith(".yaml") or args[0].endswith(".yml")): + yaml_path = os.path.abspath(args[0]) + (master_config,) = parser.parse_yaml_file(yaml_file=yaml_path) + else: + (master_config,) = parser.parse_args_into_dataclasses(args=args) + master_config = asdict(master_config) + master_config = MasterConfig(**master_config) + + return master_config + + +class ConfigManager: + """Manages configuration loading, validation, and updates.""" + + def __init__(self, config: MasterConfig): + """ + Initialize ConfigManager with either: + - Path to config file (str or Path) + - Configuration dictionary + - None (creates empty config) + """ + self.config = config + + def load_config(self, config_path: Union[str, Path]) -> None: + """Load configuration from file.""" + config_path = Path(config_path) + + if not config_path.exists(): + raise FileNotFoundError(f"Configuration file not found: {config_path}") + + if config_path.suffix.lower() in [".yaml", ".yml"]: + with open(config_path, "r") as f: + config_dict = yaml.safe_load(f) + elif config_path.suffix.lower() == ".json": + with open(config_path, "r") as f: + config_dict = json.load(f) + else: + raise ValueError(f"Unsupported configuration file format: {config_path.suffix}") + + self.update_config(config_dict) + + def _ensure_extra_params(self, obj) -> Dict[str, Any]: + """Ensure obj.extra_params exists and is a dict; return it.""" + ep = getattr(obj, "extra_params", None) + if ep is None: + setattr(obj, "extra_params", {}) + ep = obj.extra_params + if not isinstance(ep, dict): + raise TypeError("extra_params must be a dict.") + return ep + + def _stash_top_level_extra(self, section: str, nested_key: str, value: Any) -> None: + """Store unknown nested values under MasterConfig.extra_params['section.nested_key'].""" + ep = self._ensure_extra_params(self.config) + ep[f"{section}.{nested_key}"] = value + + def update_config(self, config_dict: Dict[str, Any]) -> None: + """Update configuration with dictionary values.""" + + SPECIAL_KEYS = {"callbacks"} + + for key, value in config_dict.items(): + if hasattr(self.config, key): + target = getattr(self.config, key) + + # Special handling for callbacks (dict inside CallbackConfig) + if key in SPECIAL_KEYS and isinstance(value, dict): + if is_dataclass(target) and hasattr(target, "callbacks") and isinstance(target.callbacks, dict): + for component_name, component_cfg in value.items(): + target.callbacks[component_name] = component_cfg + elif isinstance(target, dict): + target.update(value) + else: + self._stash_top_level_extra(key, "__all__", value) + continue + + if isinstance(value, dict) and is_dataclass(target): + known = {f.name for f in fields(target)} + for nested_key, nested_value in value.items(): + if nested_key in known: + setattr(target, nested_key, 
nested_value) + else: + self._stash_top_level_extra(key, nested_key, nested_value) + continue + + if isinstance(value, dict) and isinstance(target, dict): + target.update(value) + continue + setattr(self.config, key, value) + + else: + ep = self._ensure_extra_params(self.config) + ep[key] = value + + def save_config(self, output_path: Union[str, Path]) -> None: + """Save current configuration to file.""" + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Convert nested dataclasses to plain dicts so yaml/json can serialize them. + config_dict = self.to_dict() + + if output_path.suffix.lower() in [".yaml", ".yml"]: + with open(output_path, "w") as f: + yaml.dump(config_dict, f, default_flow_style=False, indent=2) + elif output_path.suffix.lower() == ".json": + with open(output_path, "w") as f: + json.dump(config_dict, f, indent=2) + else: + raise ValueError(f"Unsupported output file format: {output_path.suffix}") + + def _push(self, errs: List[str], cond: bool, msg: str) -> None: + """Append msg to errs if cond is True.""" + if cond: + errs.append(msg) + + def validate_config(self) -> None: + """ + Validate configuration parameters for MasterConfig. + """ + errors: List[str] = [] + + cfg = self.config + model = getattr(cfg, "model", {}) + dataset = getattr(cfg, "dataset", {}) + training = getattr(cfg, "training", {}) + + # ---------- Model ---------- + self._push(errors, not model.get("model_name"), "model.model_name is required.") + + # PEFT validation + if model.get("use_peft"): + pc = model.get("peft_config", {}) + self._push(errors, not isinstance(pc, dict), "model.peft_config must be a dict when use_peft=True.") + if isinstance(pc, dict): + self._push( + errors, + not isinstance(pc.get("lora_r", 0), int) or pc.get("lora_r", 0) <= 0, + "model.peft_config.lora_r must be a positive integer.", + ) + self._push( + errors, + not isinstance(pc.get("lora_alpha", 0), int) or pc.get("lora_alpha", 0) <= 0, + "model.peft_config.lora_alpha must be a positive integer.", + ) + self._push( + errors, + not (0.0 <= float(pc.get("lora_dropout", 0.0)) < 1.0), + "model.peft_config.lora_dropout must be in [0,1).", + ) + + # ---------- Dataset ---------- + self._push(errors, not dataset.get("dataset_name"), "dataset.dataset_name is required.") + self._push(errors, not dataset.get("tokenizer_name"), "dataset.tokenizer_name is required.") + self._push(errors, dataset.get("max_seq_length", 0) <= 0, "dataset.max_seq_length must be positive.") + + # ---------- Training ---------- + # Batch sizes + self._push( + errors, + training.get("per_device_train_batch_size", 0) <= 0, + "training.per_device_train_batch_size must be positive.", + ) + self._push( + errors, + training.get("per_device_eval_batch_size", 0) <= 0, + "training.per_device_eval_batch_size must be positive.", + ) + + # Epochs / steps + n_epochs = training.get("num_train_epochs", 0) + max_steps = training.get("max_steps", -1) + self._push( + errors, + n_epochs <= 0 and max_steps <= 0, + "Either training.num_train_epochs > 0 or training.max_steps > 0 must be set.", + ) + + # Gradient accumulation + self._push( + errors, + training.get("gradient_accumulation_steps", 0) <= 0, + "training.gradient_accumulation_steps must be positive.", + ) + + # Logging / saving configs + self._push(errors, training.get("logging_steps", 0) < 0, "training.logging_steps must be >= 0.") + self._push(errors, training.get("save_total_limit", 0) < 0, "training.save_total_limit must be >= 0.") + + # Device + valid_devices = ["cpu", "cuda", "qaic"] + training_device = training.get("device", None) + self._push(errors, training_device
not in valid_devices, f"training.device must be one of {valid_devices}.") + + # DDP config + ddp = training.get("ddp_config", {}) + if isinstance(ddp, dict): + backend = ddp.get("ddp_backend") + # Accept qccl for Qualcomm, nccl for CUDA, gloo for CPU + self._push( + errors, + backend not in {"qccl", "nccl", "gloo", None}, + "training.ddp_config.ddp_backend must be one of {'qccl','nccl','gloo'} or omitted.", + ) + + # ---------- Final ---------- + if errors: + # Join messages with bullet points for readability + raise ValueError("Configuration validation failed:\n- " + "\n- ".join(errors)) + + def get_callback_config(self) -> Dict[str, Any]: + """Get callback configuration as dictionary.""" + return self.config.callbacks + + def get_optimizer_config(self) -> Dict[str, Any]: + """Get optimizer configuration as dictionary.""" + return self.config.optimizers + + def get_training_config(self) -> Dict[str, Any]: + """Get training configuration as dictionary.""" + return self.config.training + + def get_scheduler_config(self) -> Dict[str, Any]: + """Get scheduler configuration as dictionary.""" + return self.config.scheduler + + def get_dataset_config(self) -> Dict[str, Any]: + """Get dataset configuration as dictionary.""" + return self.config.dataset + + def get_model_config(self) -> Dict[str, Any]: + """Get model configuration as dictionary.""" + return self.config.model + + def to_dict(self) -> Dict[str, Any]: + """Convert configuration to dictionary.""" + return asdict(self.config) + + def __getattr__(self, name: str) -> Any: + """Allow direct access to config attributes.""" + if hasattr(self.config, name): + return getattr(self.config, name) + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + + +def create_trainer_config(name: str, **dependencies) -> tuple: + """ + Create trainer configuration based on registered trainer modules. + + Args: + name: Name of the trainer type + **dependencies: Any dependencies needed to configure the trainer + + Returns: + tuple: (trainer_class, args_class, additional_kwargs) + """ + config = registry.get_trainer_module(name) + + # Process required kwargs based on available dependencies + additional_kwargs = {} + for kwarg, default in config["required_kwargs"].items(): + if kwarg in dependencies: + additional_kwargs[kwarg] = dependencies[kwarg] + elif default != "REQUIRED": + additional_kwargs[kwarg] = default + + # Check for missing required arguments + for kwarg, default in config["required_kwargs"].items(): + if kwarg not in additional_kwargs and default == "REQUIRED": + raise ValueError(f"Required argument '{kwarg}' not provided for trainer '{name}'") + + return config["trainer_cls"], config["args_cls"], additional_kwargs diff --git a/QEfficient/finetune/experimental/tests/test_config.yaml b/QEfficient/finetune/experimental/tests/test_config.yaml new file mode 100644 index 0000000000..e97e99d583 --- /dev/null +++ b/QEfficient/finetune/experimental/tests/test_config.yaml @@ -0,0 +1,104 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
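For orientation, a minimal usage sketch of the config flow defined above, mirroring the path exercised in test_config_manager.py later in this patch. The YAML path and the `my_lora_config`, `model`, and `train_ds` objects are hypothetical placeholders; "sft" refers to the trainer module registered in a later patch of this series:

    from QEfficient.finetune.experimental.core.config_manager import (
        ConfigManager,
        create_trainer_config,
        parse_arguments,
    )

    # Build defaults, overlay a YAML file, then validate; validate_config()
    # raises a single ValueError listing every failed check.
    master_config = parse_arguments(args=[])
    manager = ConfigManager(master_config)
    manager.load_config("configs/sft_smollm.yaml")  # hypothetical path
    manager.validate_config()

    # Resolve the registered trainer module and its args class; peft_config is
    # forwarded because the registration declares it as a required kwarg.
    trainer_cls, args_cls, extra_kwargs = create_trainer_config("sft", peft_config=my_lora_config)
    training_args = args_cls(output_dir="./out")
    trainer = trainer_cls(model=model, args=training_args, train_dataset=train_ds, **extra_kwargs)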
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +# model configuration +model: + model_type: "hf" + auto_class_name: "AutoModelForCausalLM" + model_name: "HuggingFaceTB/SmolLM-135M" # Pretrained model name + load_in_4bit: false + use_peft: true + peft_config: + lora_r: 8 + lora_alpha: 16 + lora_dropout: 0.1 + target_modules: ["q_proj", "v_proj"] + bias: "none" + task_type: "CAUSAL_LM" + peft_type: "LORA" + +# Dataset configuration +dataset: + tokenizer_name: "HuggingFaceTB/SmolLM-135M" + dataset_type: "seq_completion" + # dataset_name: "Arthur-LAGACHERIE/very-smollm-corpus-0.5M" + dataset_name: "knkarthick/samsum" + train_split: "train" + max_seq_length: 512 + split_ratio: 0.8 # Ratio for train/test split, used when only train_split is provided + test_split: "test" + group_by_length: True + num_workers: 4 + dataloader_pin_memory: True + dataloader_persistent_workers: True + dataloader_prefetch_factor: 1 + dataloader_drop_last: False + +# Training configuration +training: + type: "sft" + output_dir: "./training_results" + overwrite_output_dir: False + seed: 42 + device: "qaic" + do_eval: True + eval_strategy: "epoch" + eval_steps: 100 + + per_device_train_batch_size: 1 + per_device_eval_batch_size: 1 + gradient_accumulation_steps: 1 + num_train_epochs: 1 + max_steps: -1 + + log_level: "info" + log_on_each_node: True + logging_strategy: "steps" + logging_steps: 10 + + save_strategy: "epoch" + save_total_limit: 5 + metric_for_best_model: "eval_loss" + + dtype: "fp16" + completion_only_loss: True + report_to: "trackio" + + ddp_config: + ddp_backend: "qccl" + ddp_find_unused_parameters: False + ddp_bucket_cap_mb: 25 + ddp_broadcast_buffers: null + ddp_timeout: 1800 + + use_cpu: False + + gradient_checkpointing: False + gradient_checkpointing_kwargs: + preserve_rng_state : True + use_reenrant: False + + torch_compile: True + include_num_input_tokens_seen: True + average_tokens_across_devices: True + +# Optimizer configuration +optimizers: + optimizer_name: "adamw" + lr: 5e-5 + weight_decay: 0.01 + +scheduler: + scheduler_name: "cosine" + warmup_steps: 100 # warmup_steps or warmup_ratio + +callbacks: + early_stopping: + early_stopping_patience: 3 + early_stopping_threshold: 0.001 + tensorboard: + diff --git a/QEfficient/finetune/experimental/tests/test_config_manager.py b/QEfficient/finetune/experimental/tests/test_config_manager.py new file mode 100644 index 0000000000..fd2abfd482 --- /dev/null +++ b/QEfficient/finetune/experimental/tests/test_config_manager.py @@ -0,0 +1,62 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
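One contract of update_config above that this fixture does not exercise: unknown keys are preserved rather than silently dropped. A small sketch of that behavior, assuming a freshly constructed MasterConfig so the nested sections are still dataclasses (the flag and tag names are hypothetical):

    from QEfficient.finetune.experimental.core.config_manager import ConfigManager, MasterConfig

    manager = ConfigManager(MasterConfig())
    manager.update_config(
        {
            "training": {"logging_steps": 5, "my_experimental_flag": True},  # second key is unknown
            "project_tag": "run-42",  # unknown top-level key
        }
    )
    # Known keys land on the dataclass field; unknown ones are stashed:
    # manager.config.extra_params == {"training.my_experimental_flag": True, "project_tag": "run-42"}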
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + + +from pathlib import Path + +import pytest + +from QEfficient.finetune.experimental.core.config_manager import ConfigManager, parse_arguments + + +@pytest.fixture +def config_path() -> Path: + here = Path(__file__).resolve().parent + return (here / "test_config.yaml").resolve() + + +def test_config(config_path): + master_config = parse_arguments(args=[]) + config_manager = ConfigManager(master_config) + assert isinstance(config_manager, ConfigManager) + config_manager.load_config(config_path) + try: + config_manager.validate_config() + except Exception as e: + pytest.fail(f"Config validation failed with error: {e}") + + # Test that all required fields are present + missing = [ + a + for a in ("model", "dataset", "optimizers", "scheduler", "callbacks", "training") + if not hasattr(config_manager, a) + ] + assert not missing, f"Missing attributes: {missing}" + # Use all(...) with membership checks: a bare generator expression is always + # truthy, and the section configs are plain dicts after parse_arguments. + trainer_config = config_manager.get_training_config() + assert trainer_config is not None + assert isinstance(trainer_config, dict) + assert all(attr in trainer_config for attr in ("output_dir", "per_device_train_batch_size", "num_train_epochs", "ddp_config")) + dataset_config = config_manager.get_dataset_config() + assert dataset_config is not None + assert isinstance(dataset_config, dict) + assert all(attr in dataset_config for attr in ("dataset_type", "dataset_name", "tokenizer_name")) + model_config = config_manager.get_model_config() + assert model_config is not None + assert isinstance(model_config, dict) + assert all(attr in model_config for attr in ("model_type", "model_name", "use_peft", "peft_config")) + scheduler_config = config_manager.get_scheduler_config() + assert scheduler_config is not None + assert isinstance(scheduler_config, dict) + assert all(attr in scheduler_config for attr in ("scheduler_name",)) + callback_config = config_manager.get_callback_config() + assert callback_config is not None + assert isinstance(callback_config, dict) + assert all(attr in callback_config for attr in ("early_stopping",)) + optimizer_config = config_manager.get_optimizer_config() + assert optimizer_config is not None + assert isinstance(optimizer_config, dict) + assert all(attr in optimizer_config for attr in ("optimizer_name", "lr")) From 0768935facd6820ed11aba933eba84975bffc788 Mon Sep 17 00:00:00 2001 From: Swati Allabadi Date: Thu, 25 Dec 2025 06:38:46 +0530 Subject: [PATCH 8/9] [QEff. Finetune]: Adding base class and HF class (#658) - Added Base Model class and HF model class. - Base Model class will support FT for any custom model and will be a common skeleton for any model, including any HF model. - Added unit tests for these.
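The base class described above is the extension point for custom models. A minimal sketch of what a registration is expected to look like, mirroring the TestCustomModel used in the unit tests below (the registry name and layer sizes are hypothetical):

    import torch.nn as nn

    from QEfficient.finetune.experimental.core.component_registry import ComponentFactory, registry
    from QEfficient.finetune.experimental.core.model import BaseModel


    @registry.model("tiny_mlp")  # hypothetical registry name
    class TinyMLPModel(BaseModel):
        # BaseModel only obliges subclasses to return an nn.Module here.
        def load_model(self) -> nn.Module:
            return nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))


    # The factory routes through BaseModel.create(), which loads the module
    # only after __init__ has finished.
    model = ComponentFactory.create_model("tiny_mlp", "tiny-mlp-demo")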
--------- Signed-off-by: Swati Allabadi Co-authored-by: Swati Allabadi Signed-off-by: Sharvari Medhe --- .../experimental/core/component_registry.py | 12 +- .../finetune/experimental/core/model.py | 132 +++++++++++++++++ .../finetune/experimental/tests/test_model.py | 136 ++++++++++++++++++ 3 files changed, 279 insertions(+), 1 deletion(-) create mode 100644 QEfficient/finetune/experimental/tests/test_model.py diff --git a/QEfficient/finetune/experimental/core/component_registry.py b/QEfficient/finetune/experimental/core/component_registry.py index 7744d71e6a..d1f9480311 100644 --- a/QEfficient/finetune/experimental/core/component_registry.py +++ b/QEfficient/finetune/experimental/core/component_registry.py @@ -5,7 +5,6 @@ # # ----------------------------------------------------------------------------- - import logging from typing import Callable, Dict, Optional, Type @@ -198,3 +197,14 @@ def list_callbacks(self) -> list[str]: # Global registry instance registry = ComponentRegistry() + + +class ComponentFactory: + @staticmethod + def create_model(model_type: str, model_name: str, **kwargs) -> any: + """Create a model instance.""" + model_class = registry.get_model(model_type) + if model_class is None: + raise ValueError(f"Unknown model: {model_type}. Available: {registry.list_models()}") + model_instance = model_class.create(model_name, **kwargs) + return model_instance diff --git a/QEfficient/finetune/experimental/core/model.py b/QEfficient/finetune/experimental/core/model.py index d647b73a65..0f087e6653 100644 --- a/QEfficient/finetune/experimental/core/model.py +++ b/QEfficient/finetune/experimental/core/model.py @@ -4,3 +4,135 @@ # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- + +import warnings +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional, Type + +import torch.nn as nn +import transformers +from transformers import AutoTokenizer + +from QEfficient.finetune.experimental.core.component_registry import registry +from QEfficient.finetune.experimental.core.logger import Logger +from QEfficient.finetune.experimental.core.utils.dataset_utils import insert_pad_token + +logger = Logger(__name__) + + +class BaseModel(nn.Module, ABC): + """Shared skeleton for every finetunable model in the system.""" + + def __init__(self, model_name: str, **model_kwargs: Any) -> None: + super().__init__() + self.model_name = model_name + self.model_kwargs: Dict[str, Any] = model_kwargs + self._model: Optional[nn.Module] = None + self._tokenizer: Any = None # HF tokenizers are not nn.Modules. + + # Factory constructor: load model after __init__ finishes + @classmethod + def create(cls, model_name: str, **model_kwargs: Any) -> "BaseModel": + obj = cls(model_name, **model_kwargs) + # load model after __init__ finishes + module = obj.load_model() + if not isinstance(module, nn.Module): + raise TypeError(f"load_model() must return nn.Module, got {type(module)}") + obj._model = module + return obj + + @abstractmethod + def load_model(self) -> nn.Module: + """Load and return the underlying torch.nn.Module.""" + pass + + def load_tokenizer(self) -> Any: + """Override if the model exposes a tokenizer.""" + warnings.warn(f"{type(self).__name__} does not provide a tokenizer.", category=UserWarning) + return None + + # Lazy accessors + @property + def model(self) -> nn.Module: + if self._model is None: + raise RuntimeError("Model not loaded; use .create(...) 
to load.") + return self._model + + @property + def tokenizer(self) -> Any: + if self._tokenizer is None: + self._tokenizer = self.load_tokenizer() + return self._tokenizer + + # nn.Module API surface + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def to(self, *args, **kwargs): + self.model.to(*args, **kwargs) + return self + + def train(self, mode: bool = True): + self.model.train(mode) + return super().train(mode) + + def eval(self): + return self.train(False) + + +@registry.model("hf") +class HFModel(BaseModel): + """HuggingFace-backed model with optional quantization.""" + + def __init__( + self, + model_name: str, + auto_class_name: str = "AutoModelForCausalLM", + *, + tokenizer_name: Optional[str] = None, + **model_kwargs: Any, + ) -> None: + super().__init__(model_name, **model_kwargs) + self.tokenizer_name = tokenizer_name or model_name + self.auto_class: Type = self._resolve_auto_class(auto_class_name) + + @staticmethod + def _resolve_auto_class(auto_class_name: str) -> Type: + if not hasattr(transformers, auto_class_name): + candidates = sorted(name for name in dir(transformers) if name.startswith("AutoModel")) + raise ValueError( + f"Unsupported Auto class '{auto_class_name}'. Available candidates: {', '.join(candidates)}" + ) + return getattr(transformers, auto_class_name) + + # def _build_quant_config(self) -> Optional[BitsAndBytesConfig]: + # if not self.model_kwargs.get("load_in_4bit"): + # return None + # return BitsAndBytesConfig( + # load_in_4bit=True, + # bnb_4bit_quant_type=self.model_kwargs.get("bnb_4bit_quant_type", "nf4"), + # bnb_4bit_compute_dtype=self.model_kwargs.get("bnb_4bit_compute_dtype", torch.float16), + # bnb_4bit_use_double_quant=self.model_kwargs.get("bnb_4bit_use_double_quant", True), + # ) + + def configure_model_kwargs(self) -> Dict[str, Any]: + """Hook for subclasses to tweak HF `.from_pretrained` kwargs.""" + + extra = dict(self.model_kwargs) + # extra["quantization_config"] = self._build_quant_config() + return extra + + def load_model(self) -> nn.Module: + logger.log_rank_zero(f"Loading HuggingFace model '{self.model_name}' via {self.auto_class.__name__}") + + return self.auto_class.from_pretrained( + self.model_name, + **self.configure_model_kwargs(), + ) + + def load_tokenizer(self) -> AutoTokenizer: + """Load Hugging Face tokenizer.""" + logger.log_rank_zero(f"Loading tokenizer '{self.tokenizer_name}'") + tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name) + insert_pad_token(tokenizer) + return tokenizer diff --git a/QEfficient/finetune/experimental/tests/test_model.py b/QEfficient/finetune/experimental/tests/test_model.py new file mode 100644 index 0000000000..e83abf3898 --- /dev/null +++ b/QEfficient/finetune/experimental/tests/test_model.py @@ -0,0 +1,136 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from unittest import mock + +import pytest +import torch +import torch.nn as nn + +from QEfficient.finetune.experimental.core import model +from QEfficient.finetune.experimental.core.component_registry import ComponentFactory, registry +from QEfficient.finetune.experimental.core.model import BaseModel + + +class TestMockModel(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(2, 2) + + def forward(self, x): + return self.linear(x) + + +@registry.model("testcustom") +class TestCustomModel(BaseModel): + def __init__(self, model_name): + super().__init__(model_name) + print("init of custom class") + + def load_model(self) -> nn.Module: + return TestMockModel() + + def load_tokenizer(self): + return "dummy-tokenizer" + + +# BaseModel tests +def test_model_property_errors_if_not_created(): + m = TestCustomModel("dummy") + with pytest.raises(RuntimeError): + _ = m.model # must call .create() + + +def test_create_builds_and_registers(): + m = ComponentFactory.create_model("testcustom", "dummy") + # inner model exists and registered + assert "_model" in m._modules + assert isinstance(m.model, TestMockModel) + # forward works + out = m(torch.zeros(1, 2)) + assert out.shape == (1, 2) + + +def test_tokenizer_lazy_loading(): + m = ComponentFactory.create_model("testcustom", "dummy") + assert m._tokenizer is None + tok = m.tokenizer + assert tok == "dummy-tokenizer" + assert m._tokenizer == tok + + +def test_to_moves_inner_and_returns_self(): + m = ComponentFactory.create_model("testcustom", "dummy") + with mock.patch.object(TestMockModel, "to", autospec=True) as mocked_to: + ret = m.to("cpu:0") + assert mocked_to.call_args[0][0] is m.model + assert mocked_to.call_args[0][1] == "cpu:0" + assert ret is m + + +def test_train_eval_sync_flags(): + m = ComponentFactory.create_model("testcustom", "dummy") + m.eval() + assert m.training is False + assert m.model.training is False + m.train() + assert m.training is True + assert m.model.training is True + + +def test_state_dict_contains_inner_params(): + m = ComponentFactory.create_model("testcustom", "dummy") + sd = m.state_dict() + # should contain params from TestMockModel.linear + assert any("linear.weight" in k for k in sd) + assert any("linear.bias" in k for k in sd) + + +# HFModel tests +def test_hfmodel_invalid_auto_class_raises(): + with pytest.raises(ValueError): + ComponentFactory.create_model("hf", "hf-name", auto_class_name="AutoDoesNotExist") + + +def test_hfmodel_loads_auto_and_tokenizer(monkeypatch): + # fake HF Auto class + class FakeAuto(nn.Module): + @classmethod + def from_pretrained(cls, name, **kwargs): + inst = cls() + inst.loaded = (name, kwargs) + return inst + + def forward(self, x): + return x + + fake_tok = mock.Mock() + + # Monkeypatch transformer classes used in HFModel + monkeypatch.setattr( + "QEfficient.finetune.experimental.core.model.transformers.AutoModelForCausalLM", + FakeAuto, + raising=False, + ) + monkeypatch.setattr( + model, + "AutoTokenizer", + mock.Mock(from_pretrained=mock.Mock(return_value=fake_tok)), + ) + monkeypatch.setattr( + "QEfficient.finetune.experimental.core.model.insert_pad_token", + mock.Mock(), + raising=False, + ) + m = ComponentFactory.create_model("hf", "hf-name") + assert isinstance(m.model, FakeAuto) + + # load tokenizer + tok = m.load_tokenizer() + + assert hasattr(tok, "pad_token_id") + assert m.model.loaded[0] == "hf-name" From 
59e93ef02d3a955be4bcad45e365d5335c7847db Mon Sep 17 00:00:00 2001 From: Dhiraj Kumar Sah Date: Fri, 2 Jan 2026 19:29:15 +0530 Subject: [PATCH 9/9] Added Trainer classes and tests for FT (#697) This PR contains all the changes of PR #660 along with all the comments being addressed. The new PR was created due a rebase issue. Signed-off-by: Dhiraj Kumar Sah --- .../experimental/core/trainer/base_trainer.py | 73 +++ .../experimental/core/trainer/sft_trainer.py | 9 + .../experimental/tests/test_trainer.py | 493 ++++++++++++++++++ 3 files changed, 575 insertions(+) create mode 100644 QEfficient/finetune/experimental/tests/test_trainer.py diff --git a/QEfficient/finetune/experimental/core/trainer/base_trainer.py b/QEfficient/finetune/experimental/core/trainer/base_trainer.py index d647b73a65..0a3c50f7f1 100644 --- a/QEfficient/finetune/experimental/core/trainer/base_trainer.py +++ b/QEfficient/finetune/experimental/core/trainer/base_trainer.py @@ -4,3 +4,76 @@ # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- +from typing import Optional + +from peft import get_peft_model +from transformers import Trainer, TrainingArguments + +from QEfficient.finetune.experimental.core.component_registry import registry +from QEfficient.finetune.experimental.core.config_manager import PeftConfig + + +@registry.trainer_module(name="base", args_cls=TrainingArguments, required_kwargs={"peft_config": PeftConfig}) +class BaseTrainer(Trainer): + """ + Extended Trainer class that supports PEFT (Parameter-Efficient Fine-Tuning). + + This trainer extends the standard HuggingFace Trainer to optionally apply + PEFT configurations to the model before training. + """ + + def __init__( + self, + model=None, + args=None, + data_collator=None, + train_dataset=None, + eval_dataset=None, + processing_class=None, + model_init=None, + compute_metrics=None, + callbacks=None, + optimizers=(None, None), + preprocess_logits_for_metrics=None, + peft_config: Optional[PeftConfig] = None, + **kwargs, + ): + """ + Initialize the BaseTrainer with optional PEFT support. + + Args: + model: The model to train + args: Training arguments + data_collator: Data collator for batching + train_dataset: Training dataset + eval_dataset: Evaluation dataset + processing_class: Tokenizer or processor + model_init: Function to initialize model + compute_metrics: Function to compute metrics + callbacks: List of callbacks + optimizers: Tuple of (optimizer, scheduler) + preprocess_logits_for_metrics: Function to preprocess logits + peft_config: Optional PEFT configuration. If provided, the model will be + wrapped with PEFT before training. 
+ **kwargs: Additional keyword arguments + """ + # Apply PEFT to model if peft_config is provided + if peft_config is not None and model is not None: + model = get_peft_model(model, peft_config) + model.print_trainable_parameters() + + # Initialize the parent Trainer class + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + **kwargs, + ) diff --git a/QEfficient/finetune/experimental/core/trainer/sft_trainer.py b/QEfficient/finetune/experimental/core/trainer/sft_trainer.py index d647b73a65..3223c5966b 100644 --- a/QEfficient/finetune/experimental/core/trainer/sft_trainer.py +++ b/QEfficient/finetune/experimental/core/trainer/sft_trainer.py @@ -4,3 +4,12 @@ # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- +from trl import SFTConfig, SFTTrainer + +from QEfficient.finetune.experimental.core.component_registry import registry +from QEfficient.finetune.experimental.core.config_manager import PeftConfig + + +@registry.trainer_module(name="sft", args_cls=SFTConfig, required_kwargs={"peft_config": PeftConfig}) +class SFTTrainerModule(SFTTrainer): + pass # Just using the standard SFTTrainer diff --git a/QEfficient/finetune/experimental/tests/test_trainer.py b/QEfficient/finetune/experimental/tests/test_trainer.py new file mode 100644 index 0000000000..20af61e36c --- /dev/null +++ b/QEfficient/finetune/experimental/tests/test_trainer.py @@ -0,0 +1,493 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
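Before the tests, a sketch of the PEFT path through BaseTrainer above. Note that get_peft_model expects a peft-library config object such as peft.LoraConfig, not the PeftConfig dataclass from config_manager; `model`, `tokenizer`, and `train_ds` are assumed to already exist:

    from peft import LoraConfig
    from transformers import TrainingArguments

    from QEfficient.finetune.experimental.core.trainer.base_trainer import BaseTrainer

    # The registration decorator returns the registry entry, so the module-level
    # name is a dict; the class itself sits under "trainer_cls" (see tests below).
    trainer_cls = BaseTrainer["trainer_cls"]

    lora = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, target_modules=["q_proj", "v_proj"])
    args = TrainingArguments(output_dir="./out", num_train_epochs=1, save_strategy="no")

    trainer = trainer_cls(
        model=model,             # assumed: an already-loaded causal LM
        args=args,
        train_dataset=train_ds,  # assumed: a tokenized datasets.Dataset
        processing_class=tokenizer,
        peft_config=lora,        # triggers get_peft_model() inside __init__
    )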
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import os +import shutil + +import pytest +import torch +from datasets import Dataset +from peft import LoraConfig, PeftModel +from transformers import Trainer, TrainingArguments +from trl import SFTConfig, SFTTrainer + +from QEfficient.finetune.experimental.core.component_registry import ComponentFactory, registry +from QEfficient.finetune.experimental.core.model import HFModel # noqa: F401 - needed for registration +from QEfficient.finetune.experimental.core.trainer.base_trainer import BaseTrainer +from QEfficient.finetune.experimental.core.trainer.sft_trainer import ( + SFTTrainerModule, +) + +LORA_R = 8 +LORA_ALPHA = 16 +LORA_DROPOUT = 0.1 +MAX_LENGTH = 128 + + +class TestBaseTrainer: + """Test suite for BaseTrainer class.""" + + def test_base_trainer_registered(self): + """Test that BaseTrainer is registered in the registry.""" + trainer_list = registry.list_trainer_modules() + assert "base" in trainer_list + + def test_base_trainer_info_structure(self): + """Test that BaseTrainer registration has correct structure.""" + trainer_info = registry.get_trainer_module("base") + + assert isinstance(trainer_info, dict) + assert "trainer_cls" in trainer_info + assert "args_cls" in trainer_info + assert "required_kwargs" in trainer_info + + def test_base_trainer_class(self): + """Test that BaseTrainer class is correct.""" + + trainer_info = registry.get_trainer_module("base") + trainer_cls = trainer_info["trainer_cls"] + + # The decorator returns the dict, but BaseTrainer is the original class + assert trainer_cls.__name__ == "BaseTrainer" + assert issubclass(trainer_cls, Trainer) + assert trainer_info["args_cls"] == TrainingArguments + + def test_base_trainer_required_kwargs(self): + """Test that BaseTrainer has peft_config in required_kwargs.""" + trainer_info = registry.get_trainer_module("base") + + assert "peft_config" in trainer_info["required_kwargs"] + assert callable(trainer_info["required_kwargs"]["peft_config"]) + + +class TestSFTTrainerModule: + """Test suite for SFTTrainerModule class.""" + + def test_sft_trainer_registered(self): + """Test that SFTTrainerModule is registered in the registry.""" + trainer_list = registry.list_trainer_modules() + assert "sft" in trainer_list + + def test_sft_trainer_info_structure(self): + """Test that SFTTrainerModule registration has correct structure.""" + trainer_info = registry.get_trainer_module("sft") + + assert isinstance(trainer_info, dict) + assert "trainer_cls" in trainer_info + assert "args_cls" in trainer_info + assert "required_kwargs" in trainer_info + + def test_sft_trainer_class(self): + """Test that SFTTrainerModule class is correct.""" + + trainer_info = registry.get_trainer_module("sft") + trainer_cls = trainer_info["trainer_cls"] + + assert trainer_cls == SFTTrainerModule["trainer_cls"] + assert issubclass(trainer_cls, SFTTrainer) + assert trainer_info["args_cls"] == SFTConfig + + def test_sft_trainer_required_kwargs(self): + """Test that SFTTrainerModule has peft_config in required_kwargs.""" + trainer_info = registry.get_trainer_module("sft") + + assert "peft_config" in trainer_info["required_kwargs"] + assert callable(trainer_info["required_kwargs"]["peft_config"]) + + +class TestTrainerRegistry: + """Test suite for trainer registration in the component registry.""" + + def test_both_trainers_registered(self): + """Test that both base and sft trainers are registered.""" + trainer_list = 
registry.list_trainer_modules() + + assert "base" in trainer_list + assert "sft" in trainer_list + assert len(trainer_list) >= 2 + + def test_registry_returns_dict(self): + """Test that registry returns dict for trainer modules.""" + base_info = registry.get_trainer_module("base") + sft_info = registry.get_trainer_module("sft") + + assert isinstance(base_info, dict) + assert isinstance(sft_info, dict) + + def test_trainer_classes_correct(self): + """Test that trainer classes are correctly stored.""" + base_info = registry.get_trainer_module("base") + sft_info = registry.get_trainer_module("sft") + assert base_info["trainer_cls"] == BaseTrainer["trainer_cls"] + assert sft_info["trainer_cls"] == SFTTrainerModule["trainer_cls"] + + +class TestBaseTrainerWithModel: + """Test suite for BaseTrainer integration with model loading and PEFT.""" + + @pytest.fixture(autouse=True) + def cleanup_output_dirs(self): + """Fixture to clean up test output directories after each test.""" + # Setup: yield control to the test + yield + + # Teardown: clean up output directories + output_dirs = ["./test_output", "./test_output_peft", "./test_output_base", "./test_output_base_peft"] + for output_dir in output_dirs: + if os.path.exists(output_dir): + try: + shutil.rmtree(output_dir) + print(f"\nCleaned up: {output_dir}") + except Exception as e: + print(f"\nWarning: Failed to clean up {output_dir}: {e}") + + @pytest.fixture + def model_config(self): + """Fixture for basic model configuration.""" + return { + "model_name": "HuggingFaceTB/SmolLM-135M", + "auto_class_name": "AutoModelForCausalLM", + "use_cache": False, + "torch_dtype": "float16", + "attn_implementation": "eager", + "device_map": None, + "num_hidden_layers": 1, + } + + @pytest.fixture + def peft_model_config(self): + """Fixture for PEFT configuration.""" + return { + "r": LORA_R, + "lora_alpha": LORA_ALPHA, + "lora_dropout": LORA_DROPOUT, + "target_modules": ["q_proj", "v_proj"], + "bias": "none", + } + + @pytest.fixture + def dummy_dataset(self): + """Fixture for creating a dummy dataset.""" + data = { + "text": [ + "This is a test sentence for training.", + "Another example text for the model.", + "Third sample to ensure proper batching.", + ] + } + return Dataset.from_dict(data) + + def test_base_trainer_instantiation_with_model(self, model_config, dummy_dataset): + """Test that BaseTrainer can be instantiated with a loaded model.""" + # Load model and tokenizer + model_name = model_config.pop("model_name") + hf_model = ComponentFactory.create_model("hf", model_name, **model_config) + model = hf_model.model + tokenizer = hf_model.tokenizer + + # Create training config + training_args = TrainingArguments( + output_dir="./test_output_base", + per_device_train_batch_size=1, + num_train_epochs=1, + logging_steps=1, + save_strategy="no", + bf16=False, + fp16=True, + ) + + # Get BaseTrainer from registry + trainer_info = registry.get_trainer_module("base") + trainer_cls = trainer_info["trainer_cls"] + + # Instantiate trainer without PEFT + trainer = trainer_cls( + model=model, + args=training_args, + train_dataset=dummy_dataset, + processing_class=tokenizer, + ) + + assert trainer is not None + assert trainer.model is not None + assert trainer.processing_class is not None + + def test_base_trainer_with_peft_model(self, model_config, peft_model_config, dummy_dataset): + """Test that BaseTrainer works with PEFT-enabled models.""" + # Load model and tokenizer + model_name = model_config.pop("model_name") + hf_model = ComponentFactory.create_model("hf", 
model_name, **model_config) + model = hf_model.model + tokenizer = hf_model.tokenizer + + # Load PEFT Config + peft_config = LoraConfig(**peft_model_config) + + # Create training config + training_args = TrainingArguments( + output_dir="./test_output_base_peft", + per_device_train_batch_size=1, + num_train_epochs=1, + logging_steps=1, + save_strategy="no", + bf16=False, + fp16=True, + ) + + # Get BaseTrainer from registry + trainer_info = registry.get_trainer_module("base") + trainer_cls = trainer_info["trainer_cls"] + + # Instantiate trainer with PEFT config + trainer = trainer_cls( + model=model, + args=training_args, + train_dataset=dummy_dataset, + processing_class=tokenizer, + peft_config=peft_config, + ) + + assert trainer is not None + assert trainer.model is not None + + # Verify that the model is now a PEFT model + assert isinstance(trainer.model, PeftModel), "Model should be wrapped as a PeftModel" + + # Verify that the model has the expected PEFT config + assert hasattr(trainer.model, "peft_config"), "Model should have peft_config attribute" + assert trainer.model.peft_config is not None, "PEFT config should not be None" + + # Verify trainable parameters are reduced (PEFT should make only a subset trainable) + trainable_params = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad) + total_params = sum(p.numel() for p in trainer.model.parameters()) + + assert trainable_params < total_params, "PEFT should reduce the number of trainable parameters" + print(f"\nTrainable params: {trainable_params:,} / Total params: {total_params:,}") + + def test_base_trainer_without_peft_config(self, model_config, dummy_dataset): + """Test that BaseTrainer works without PEFT config (standard training).""" + # Load model and tokenizer + model_name = model_config.pop("model_name") + hf_model = ComponentFactory.create_model("hf", model_name, **model_config) + model = hf_model.model + tokenizer = hf_model.tokenizer + + # Create training config + training_args = TrainingArguments( + output_dir="./test_output_base", + per_device_train_batch_size=1, + num_train_epochs=1, + logging_steps=1, + save_strategy="no", + bf16=False, + fp16=True, + ) + + # Get BaseTrainer from registry + trainer_info = registry.get_trainer_module("base") + trainer_cls = trainer_info["trainer_cls"] + + # Instantiate trainer without PEFT config + trainer = trainer_cls( + model=model, + args=training_args, + train_dataset=dummy_dataset, + processing_class=tokenizer, + peft_config=None, # Explicitly pass None + ) + + assert trainer is not None + assert trainer.model is not None + + # Verify that the model is NOT a PEFT model + assert not isinstance(trainer.model, PeftModel), ( + "Model should not be wrapped as a PeftModel when peft_config is None" + ) + + +class TestSFTTrainerWithModel: + """Test suite for SFTTrainer integration with model loading.""" + + @pytest.fixture(autouse=True) + def cleanup_output_dirs(self): + """Fixture to clean up test output directories after each test.""" + # Setup: yield control to the test + yield + + # Teardown: clean up output directories + output_dirs = ["./test_output", "./test_output_peft"] + for output_dir in output_dirs: + if os.path.exists(output_dir): + try: + shutil.rmtree(output_dir) + print(f"\nCleaned up: {output_dir}") + except Exception as e: + print(f"\nWarning: Failed to clean up {output_dir}: {e}") + + @pytest.fixture + def model_config(self): + """Fixture for basic model configuration.""" + return { + "model_name": "HuggingFaceTB/SmolLM-135M", + "auto_class_name": 
"AutoModelForCausalLM", + "use_cache": False, + "torch_dtype": "float16", + "attn_implementation": "eager", + "device_map": None, + "num_hidden_layers": 1, + } + + @pytest.fixture + def peft_model_config(self): + """Fixture for PEFT configuration.""" + return { + "lora_r": LORA_R, + "lora_alpha": LORA_ALPHA, + "lora_dropout": LORA_DROPOUT, + "target_modules": ["q_proj", "v_proj"], + "bias": "none", + } + + @pytest.fixture + def dummy_dataset(self): + """Fixture for creating a dummy dataset.""" + + data = { + "text": [ + "This is a test sentence for training.", + "Another example text for the model.", + "Third sample to ensure proper batching.", + ] + } + return Dataset.from_dict(data) + + def test_model_forward_pass(self, model_config): + """Test that the loaded model can perform a forward pass.""" + + model_name = model_config.pop("model_name") + hf_model = ComponentFactory.create_model("hf", model_name, **model_config) + loaded_model = hf_model.model + tokenizer = hf_model.tokenizer + + # Prepare input + text = "This is a test." + inputs = tokenizer(text, return_tensors="pt") + + # Perform forward pass + with torch.no_grad(): + outputs = loaded_model(**inputs) + + assert outputs is not None + assert hasattr(outputs, "logits") + assert outputs.logits.shape[0] == 1 # batch size + + def test_sft_trainer_instantiation_with_model(self, model_config, dummy_dataset): + """Test that SFTTrainer can be instantiated with a loaded model.""" + + # Load model and tokenizer + model_name = model_config.pop("model_name") + hf_model = ComponentFactory.create_model("hf", model_name, **model_config) + model = hf_model.model + tokenizer = hf_model.tokenizer + + # Create SFT config + sft_config = SFTConfig( + output_dir="./test_output", + max_length=MAX_LENGTH, + per_device_train_batch_size=1, + num_train_epochs=1, + logging_steps=1, + save_strategy="no", + bf16=False, + fp16=True, + ) + + # Get SFTTrainer from registry + trainer_info = registry.get_trainer_module("sft") + trainer_cls = trainer_info["trainer_cls"] + + # Instantiate trainer + trainer = trainer_cls( + model=model, + args=sft_config, + train_dataset=dummy_dataset, + processing_class=tokenizer, + ) + + assert trainer is not None + assert trainer.model is not None + assert trainer.tokenizer is not None + + def test_sft_trainer_with_peft_model(self, model_config, peft_model_config, dummy_dataset): + """Test that SFTTrainer works with PEFT-enabled models.""" + + # Load model and tokenizer + model_name = model_config.pop("model_name") + hf_model = ComponentFactory.create_model("hf", model_name, **model_config) + model = hf_model.model + # Load PEFT Config + peft_config = LoraConfig(peft_model_config) + tokenizer = hf_model.tokenizer + + # Create SFT config + sft_config = SFTConfig( + output_dir="./test_output_peft", + max_length=MAX_LENGTH, + per_device_train_batch_size=1, + num_train_epochs=1, + logging_steps=1, + save_strategy="no", + bf16=False, + fp16=True, + ) + + # Get SFTTrainer from registry + trainer_info = registry.get_trainer_module("sft") + trainer_cls = trainer_info["trainer_cls"] + + # Instantiate trainer with PEFT config + trainer = trainer_cls( + model=model, + args=sft_config, + train_dataset=dummy_dataset, + processing_class=tokenizer, + peft_config=peft_config, + ) + + assert trainer is not None + assert trainer.model is not None + + def test_sft_trainer_train_dataset_required(self, model_config): + """Test that SFTTrainer requires a training dataset.""" + + # Load model and tokenizer + model_name = model_config.pop("model_name") + 
hf_model = ComponentFactory.create_model("hf", model_name, **model_config) + model = hf_model.model + tokenizer = hf_model.tokenizer + + # Create SFT config + sft_config = SFTConfig( + output_dir="./test_output", + max_length=MAX_LENGTH, + per_device_train_batch_size=1, + num_train_epochs=1, + bf16=False, + fp16=True, + ) + + # Get SFTTrainer from registry + trainer_info = registry.get_trainer_module("sft") + trainer_cls = trainer_info["trainer_cls"] + + # Attempt to instantiate without dataset should raise TypeError + with pytest.raises(TypeError, match="'NoneType' object is not iterable"): + trainer_cls( + model=model, + args=sft_config, + processing_class=tokenizer, + )
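+
+
+if __name__ == "__main__":
+    # Illustrative sketch only, not collected by pytest: a minimal end-to-end
+    # run of the registry lookup the suites above exercise. The model kwargs
+    # mirror the model_config fixture; "./sketch_output" and the sample text
+    # are placeholders for this example.
+    trainer_info = registry.get_trainer_module("base")
+    trainer_cls = trainer_info["trainer_cls"]  # the registered Trainer subclass
+    args_cls = trainer_info["args_cls"]  # TrainingArguments for "base"
+    hf_model = ComponentFactory.create_model(
+        "hf",
+        "HuggingFaceTB/SmolLM-135M",
+        auto_class_name="AutoModelForCausalLM",
+        use_cache=False,
+        torch_dtype="float16",
+        attn_implementation="eager",
+        device_map=None,
+        num_hidden_layers=1,
+    )
+    trainer = trainer_cls(
+        model=hf_model.model,
+        args=args_cls(output_dir="./sketch_output", save_strategy="no"),
+        train_dataset=Dataset.from_dict({"text": ["A short example sentence."]}),
+        processing_class=hf_model.tokenizer,
+    )
+    print(f"Constructed {type(trainer).__name__} via registry lookup")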