[QEff. Finetuning]: Adding PP support in HF trainer stack #813
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
QEfficient/finetune/experimental/configs/sample_pp_config.yaml (109 additions, 0 deletions)
```yaml
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
#
# Sample configuration for Pipeline Parallelism (PP) without DDP.
# This config demonstrates how to enable PP support on a single node without
# distributed training.
#
# To run with PP only (no DDP):
#   python -m QEfficient.cloud.finetune_experimental configs/sample_pp_config.yaml
#
# To Do: Since the config is not updated properly through YAML, it gets
# overwritten (a fix for this is added in #795). Once #795 is merged, redundant
# params (params for which the value matches the value in config_manager) can
# be removed from here. Dataset can also be kept in sync with

# Model configuration
model:
  model_type: "hf"  # Hugging Face model
  auto_class_name: "AutoModelForCausalLM"
  model_name: "meta-llama/Llama-3.2-1B"  # Pretrained model name
  use_cache: False
  attn_implementation: "sdpa"
  use_peft: True
  peft_config:
    lora_r: 8
    lora_alpha: 16
    lora_dropout: 0.1
    target_modules: ["q_proj", "v_proj"]
    task_type: "CAUSAL_LM"
    peft_type: "LORA"
    bias: "none"  # Options: "none", "all", "lora_only"

# Dataset configuration
dataset:
  tokenizer_name: "meta-llama/Llama-3.2-1B"
  dataset_type: "sft_dataset"
  dataset_name: "openai/gsm8k"
  prompt_template: "Solve the following math problem step by step.\n\n### Question:\n{question}\n\n### Answer:\n"
  config_name: "main"
  train_split: "train"
  test_split: "test"
  max_seq_length: 512
  completion_template: "{answer}"
  dataloader_num_workers: 1
  dataloader_pin_memory: True
  dataloader_persistent_workers: False
  group_by_length: True

# Training configuration
training:
  type: "sft"
  output_dir: "./training_results_pp"
  overwrite_output_dir: false
  seed: 42
  device: "qaic"  # Use 'cuda' for NVIDIA GPUs, 'qaic' for Qualcomm Cloud AI
  do_eval: True
  torch_dtype: "fp16"
  eval_strategy: "epoch"
  eval_steps: 100
  per_device_train_batch_size: 1
  per_device_eval_batch_size: 1
  gradient_accumulation_steps: 4
  num_train_epochs: 5
  max_steps: -1
  log_level: "info"
  log_on_each_node: True
  logging_strategy: "steps"
  logging_steps: 10
  save_strategy: "epoch"
  save_steps: 100
  save_total_limit: 5
  metric_for_best_model: "eval_loss"
  completion_only_loss: True

  # Pipeline Parallelism configuration (PP without DDP)
  enable_pp: True
  num_pp_stages: 2  # Split the model into 2 pipeline stages

  # Gradient checkpointing (optional, saves memory)
  gradient_checkpointing: False
  gradient_checkpointing_kwargs:
    preserve_rng_state: True
    use_reentrant: False

  torch_compile: false
  include_num_input_tokens_seen: True
  average_tokens_across_devices: True

# Optimizer configuration
optimizers:
  optimizer_name: "AdamW"
  lr: 5e-5
  weight_decay: 0.01

# Scheduler configuration
scheduler:
  scheduler_name: "cosine"
  warmup_steps: 100

# Callbacks
callbacks:
  early_stopping:
    early_stopping_patience: 3
    early_stopping_threshold: 0.001
  tensorboard: {}
```
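Once a config like the one above is parsed into a dict (e.g. by the config manager), the PP keys can be read and sanity-checked before the trainer is built. The sketch below is illustrative only: `resolve_pp_settings` is a hypothetical helper, and the hand-written dict stands in for the parsed YAML file.

```python
def resolve_pp_settings(cfg: dict) -> tuple:
    """Read enable_pp / num_pp_stages from the training section, with safe defaults."""
    training = cfg.get("training", {})
    enable_pp = bool(training.get("enable_pp", False))
    num_pp_stages = int(training.get("num_pp_stages", 1))
    # PP only makes sense with at least two stages.
    if enable_pp and num_pp_stages < 2:
        raise ValueError("enable_pp requires num_pp_stages >= 2")
    return enable_pp, num_pp_stages


# Stand-in for the parsed sample_pp_config.yaml
cfg = {"training": {"enable_pp": True, "num_pp_stages": 2}}
print(resolve_pp_settings(cfg))  # (True, 2)
```

With a config that omits the PP keys entirely, the helper falls back to `(False, 1)`, i.e. PP disabled.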
QEfficient/finetune/experimental/core/utils/device_map_utils.py (169 additions, 0 deletions)
```python
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

"""
Utility functions for creating device maps for pipeline parallelism.
"""

from typing import Dict, Optional

import numpy as np
import torch
from transformers import AutoConfig

from QEfficient.finetune.experimental.core.utils.dist_utils import get_local_rank
from QEfficient.utils._utils import get_num_layers_from_config


def get_device_map(
    model_name: str,
    device: str,
    pp_degree: int = 1,
) -> Optional[Dict[str, int]]:
    """
    Return a device map for the given model based on the PP and DDP configuration.

    Args:
        model_name: Name of the model to load the configuration from.
        device: Device type (e.g., 'cuda', 'qaic').
        pp_degree: Pipeline parallelism degree (number of pipeline stages). A value > 1 enables PP.

    Returns:
        Dict: A dictionary mapping layer names to device IDs, or None if PP is disabled.
    """
    if pp_degree <= 1:
        return None

    torch_device = torch.device(device)
    num_available_devices = getattr(torch, torch_device.type).device_count()

    if pp_degree > num_available_devices:
        raise ValueError(
            f"pp_degree ({pp_degree}) cannot exceed the number of available {device} devices "
            f"({num_available_devices}). Reduce pp_degree or use a node with more devices."
        )
    elif pp_degree == num_available_devices:
        device_map = "auto"
    else:  # pp_degree < num_available_devices
        device_map = custom_device_map(model_name, device, pp_degree)

    return device_map


def custom_device_map(model_name: str, device: str, pp_degree: int) -> Dict[str, int]:
    """
    Return a custom device map for model layers based on the number of pipeline stages
    and the process rank.

    Args:
        model_name: Name of the model to load the configuration from.
        device: Device type (e.g., 'cuda', 'qaic').
        pp_degree: Pipeline parallelism degree (number of pipeline stages).

    Returns:
        Dict: A dictionary mapping layer names to device IDs.

    Notes:
        - This device map structure has been verified primarily for Llama models.
        - For other architectures, the layer naming conventions may need to be adjusted.
        - Layers are distributed as evenly as possible: the first (num_layers % pp_degree)
          stages receive one extra layer each.

    Example:
        An example config for PP + DDP is shown below; the same scheme applies to
        PP-only runs.
        Configuration for meta-llama/Llama-3.2-1B
        Total devices: 4 (2x PP x 2x DDP)

        PP (Pipeline Parallelism): each copy of the model is split into 2 stages
        DDP (Distributed Data Parallel): 2 model copies run in parallel

        |--------------------------------------------------------------------------|
        | Process Rank | Assigned Device IDs | Model Component                     |
        |--------------------------------------------------------------------------|
        | Rank 0       | 0                   | model.embed_tokens                  |
        |              |                     | model.lm_head                       |
        |              |                     | model.layers.0 - model.layers.7     |
        |--------------------------------------------------------------------------|
        | Rank 0       | 1                   | model.norm                          |
        |              |                     | model.rotary_emb                    |
        |              |                     | model.layers.8 - model.layers.15    |
        |--------------------------------------------------------------------------|
        | Rank 1       | 2                   | model.embed_tokens                  |
        |              |                     | model.lm_head                       |
        |              |                     | model.layers.0 - model.layers.7     |
        |--------------------------------------------------------------------------|
        | Rank 1       | 3                   | model.norm                          |
        |              |                     | model.rotary_emb                    |
        |              |                     | model.layers.8 - model.layers.15    |
        |--------------------------------------------------------------------------|
    """
    model_config = AutoConfig.from_pretrained(model_name)
    num_layers = get_num_layers_from_config(model_config)
    local_rank = get_local_rank()

    if num_layers < pp_degree:
        raise ValueError(
            f"Number of model layers ({num_layers}) must be >= pp_degree ({pp_degree}). "
            f"Cannot split {num_layers} layers across {pp_degree} pipeline stages."
        )

    first_device = local_rank * pp_degree
    last_device = local_rank * pp_degree + (pp_degree - 1)

    # Tied embeddings: keep lm_head on the same device as embed_tokens.
    if model_config.tie_word_embeddings:
        lm_head_device = first_device
    else:
        lm_head_device = last_device

    device_map = {
        "model.embed_tokens": first_device,
        "lm_head": lm_head_device,
        "model.norm": last_device,
        "model.rotary_emb": last_device,
    }

    # Distribute layers as evenly as possible across stages.
    # The first (num_layers % pp_degree) stages get one extra layer each.
    base_layers, remainder = divmod(num_layers, pp_degree)
    layers_per_stage = np.array([base_layers + (1 if i < remainder else 0) for i in range(pp_degree)])

    # Per-layer stage assignment; offset by this rank's first device ID.
    pp_device_map = np.repeat(np.arange(pp_degree), layers_per_stage)

    for i in range(num_layers):
        device_map[f"model.layers.{i}"] = int(pp_device_map[i] + local_rank * pp_degree)

    return device_map


def validate_pp_config(
    pp_degree: int,
    device: str,
    local_world_size: int = 1,
) -> None:
    """
    Validate the pipeline parallelism configuration.

    Args:
        pp_degree: Pipeline parallelism degree (number of pipeline stages). Must be > 1 to enable PP.
        device: Device type (e.g., 'cuda', 'qaic').
        local_world_size: Number of processes per node for DDP.

    Raises:
        AssertionError: If the configuration is invalid.
    """
    if pp_degree > 1:
        # Validate device availability
        torch_device = torch.device(device)
        num_available_devices = getattr(torch, torch_device.type).device_count()

        assert local_world_size * pp_degree <= num_available_devices, (
            f"Number of devices required per node (LOCAL_WORLD_SIZE * pp_degree = "
            f"{local_world_size} * {pp_degree} = {local_world_size * pp_degree}) "
            f"should be <= locally available devices ({num_available_devices})."
        )
```
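The layer-distribution arithmetic in `custom_device_map` can be checked in isolation, since it depends only on `num_layers`, `pp_degree`, and the local rank. The sketch below mirrors just the `divmod`/`np.repeat` logic; `split_layers` is a hypothetical standalone helper, not part of the PR.

```python
import numpy as np


def split_layers(num_layers: int, pp_degree: int, local_rank: int = 0) -> dict:
    """Mirror the layer-to-stage assignment used in custom_device_map."""
    base_layers, remainder = divmod(num_layers, pp_degree)
    # First `remainder` stages get one extra layer each.
    layers_per_stage = np.array([base_layers + (1 if i < remainder else 0) for i in range(pp_degree)])
    # Stage index per layer, offset by this rank's first device ID.
    pp_device_map = np.repeat(np.arange(pp_degree), layers_per_stage)
    return {f"model.layers.{i}": int(pp_device_map[i] + local_rank * pp_degree) for i in range(num_layers)}


# Llama-3.2-1B has 16 decoder layers; with pp_degree=2 on rank 0,
# layers 0-7 land on device 0 and layers 8-15 on device 1, matching
# the table in the docstring above.
dm = split_layers(16, 2)
print(dm["model.layers.7"], dm["model.layers.8"])  # 0 1
```

With an uneven split such as 5 layers over 2 stages, the first stage takes 3 layers and the second takes 2; on rank 1 with `pp_degree=2`, all device IDs shift up by 2.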