1 change: 1 addition & 0 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -424,6 +424,7 @@ def __init__(
)
self.extra_dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) if self.extra_dp_size > 1 else None
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) if self.tp_size > 1 else None
+ self.dp_size = self.zero_size * self.extra_dp_size

self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
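For context, the effective data-parallel size here is simply the product of the ZeRO shard group and the extra replicated group. A minimal sketch with assumed sizes (illustrative only, not taken from the plugin):

```python
# Illustrative only: assumed group sizes for a 16-GPU job.
zero_size = 4          # ranks that shard ZeRO optimizer states / gradients
extra_dp_size = 2      # replicated copies of each ZeRO group
tp_size = 2            # tensor-parallel ranks

dp_size = zero_size * extra_dp_size   # 8 ranks participate in data parallelism
world_size = dp_size * tp_size        # 16 ranks in total
assert world_size == 16
```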
28 changes: 18 additions & 10 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -34,7 +34,6 @@

from .pp_plugin_base import PipelinePluginBase

- DP_AXIS, PP_AXIS, TP_AXIS, SP_AXIS = 0, 1, 2, 3
SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]

PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}
@@ -987,6 +986,7 @@ def __init__(
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
enable_metadata_cache: bool = True,
make_vocab_size_divisible_by: int = 64,
+ dp_outside: bool = True,
) -> None:
super().__init__()
assert (
@@ -1034,7 +1034,12 @@ def __init__(
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
- self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
+ if dp_outside:
+     self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
+     self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
+ else:
+     self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
+     self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
self.stage_manager = None
self.schedule = None
self.custom_policy = custom_policy
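The effect of `dp_outside` on rank layout can be sketched as follows (a standalone illustration with assumed sizes, not the plugin's own code):

```python
import numpy as np

def build_rank_grid(dp, pp, tp, sp, dp_outside=True):
    """Lay out world_size ranks on a 4-D grid in the chosen axis order."""
    shape = (dp, pp, tp, sp) if dp_outside else (pp, dp, tp, sp)
    return np.arange(int(np.prod(shape))).reshape(shape)

# Assumed 8 ranks: dp=2, pp=2, tp=2, sp=1.
grid = build_rank_grid(2, 2, 2, 1, dp_outside=False)
# With pp outermost, pipeline stage 0 owns ranks 0-3 and stage 1 owns ranks 4-7,
# i.e. each stage sits on a contiguous block of ranks (typically one node).
print(grid[0].ravel(), grid[1].ravel())  # [0 1 2 3] [4 5 6 7]
```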
@@ -1048,7 +1053,7 @@ def __init__(
assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
self.stage_manager = PipelineStageManager(
self.pg_mesh,
- pipeline_axis=PP_AXIS,
+ pipeline_axis=self.pp_axis,
enable_interleave=pp_style == "interleaved",
num_model_chunks=num_model_chunks,
)
@@ -1072,13 +1077,13 @@ def __init__(
else:
raise NotImplementedError()

- self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
- self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
- self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
+ self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
+ self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
+ self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]:
- self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
+ self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
else:
- self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS)
+ self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)

self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
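The group selection above can be summarized in a small helper (a sketch based only on the branches shown here; the axis values are placeholders):

```python
def pick_sp_axis(sequence_parallelism_mode, tp_axis, sp_axis):
    # split_gather and ring shard the sequence across the tensor-parallel ranks,
    # so their collectives run over the TP group; all_to_all redistributes
    # sequence/heads between dedicated SP ranks and needs its own axis.
    if sequence_parallelism_mode in ("split_gather", "ring"):
        return tp_axis
    return sp_axis

assert pick_sp_axis("ring", tp_axis=2, sp_axis=3) == 2
assert pick_sp_axis("all_to_all", tp_axis=2, sp_axis=3) == 3
```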
@@ -1169,7 +1174,7 @@ def configure(
and self.sequence_parallelism_mode == "all_to_all"
)
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
- dp_group = self.pg_mesh.create_group_along_axis([DP_AXIS, SP_AXIS])
+ dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
else:
dp_group = self.dp_group
model = HybridParallelModule(
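As a rough illustration of why the gradient-sync group spans both axes in the all-to-all case (hypothetical sizes and a hand-rolled rank formula, not ProcessGroupMesh internals):

```python
from itertools import product

# Assumed mesh: dp outermost, tp in the middle, sp innermost; dp=2, tp=2, sp=2.
dp_size, tp_size, sp_size = 2, 2, 2

def rank(dp, tp, sp):
    return dp * tp_size * sp_size + tp * sp_size + sp

# Under all_to_all sequence parallelism every rank holds a different slice of
# the batch along both dp and sp, so gradients are averaged over the combined
# dp x sp group (one such group per tp coordinate).
for tp in range(tp_size):
    group = sorted(rank(dp, tp, sp) for dp, sp in product(range(dp_size), range(sp_size)))
    print(f"tp={tp}: dp+sp group = {group}")
# tp=0: dp+sp group = [0, 1, 4, 5]
# tp=1: dp+sp group = [2, 3, 6, 7]
```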
@@ -1317,7 +1322,10 @@ def prepare_dataloader(
_kwargs = kwargs.copy()
distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
sampler = distributed_sampler_cls(
- dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
+ dataset,
+ num_replicas=self.pg_mesh.size(self.dp_axis),
+ rank=self.pg_mesh.coordinate(self.dp_axis),
+ shuffle=shuffle,
)

# Deterministic dataloader
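A minimal, self-contained sketch of the sampler wiring above (`DistributedSampler` is the standard torch.utils.data class; the two-replica setup and toy dataset are assumptions for illustration):

```python
from torch.utils.data import DataLoader, Dataset, DistributedSampler

class ToyDataset(Dataset):
    def __len__(self):
        return 16
    def __getitem__(self, idx):
        return idx

# Assume the mesh reports a data-parallel world of 2 and this process sits at
# dp coordinate 1; each replica then iterates a disjoint half of the dataset.
dataset = ToyDataset()
sampler = DistributedSampler(dataset, num_replicas=2, rank=1, shuffle=True)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)  # keeps shuffling consistent across replicas each epoch
    for batch in loader:
        pass  # 2 batches of 4 samples each on this rank
```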
117 changes: 3 additions & 114 deletions examples/language/llama2/README.md
@@ -1,4 +1,4 @@
- # Pretraining LLaMA-1/2: best practices for building LLaMA-1/2-like base models
+ # Pretraining LLaMA-1/2/3: best practices for building LLaMA-1/2/3-like base models

### LLaMA2
<p align="center">
@@ -16,38 +16,10 @@
- 65-billion-parameter large model pretraining accelerated by 38%
[[blog]](https://www.hpc-ai.tech/blog/large-model-pretraining)

## Dataset

Unlike the original LLaMA, we use the [RedPajama](https://www.together.xyz/blog/redpajama) dataset, which is a reproduction of the LLaMA training dataset containing over 1.2 trillion tokens. The full dataset is ~5TB unzipped on disk and ~3TB compressed to download.

A smaller, more consumable random sample can be downloaded through [Hugging Face](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T). If you just want to try out the pretraining script, you can use a 1B-token sample subset of RedPajama, which is available at [Hugging Face](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample).

RedPajama-Data-1T consists of seven data slices:

| | RedPajama | LLaMA |
|---------------|--------------|---------------|
| CommonCrawl | 878 billion | 852 billion |
| C4 | 175 billion | 190 billion |
| Github | 59 billion | 100 billion |
| Books | 26 billion | 25 billion |
| ArXiv | 28 billion | 33 billion |
| Wikipedia | 24 billion | 25 billion |
| StackExchange | 20 billion | 27 billion |
| Total | 1.2 trillion | 1.25 trillion |

## Training

We follow the hyperparameter settings from the original LLaMA paper. We use AdamW with $\beta_1=0.9$ and $\beta_2=0.95$. We use a cosine learning rate schedule, such that the final learning rate is equal to 10% of the maximal learning rate. We use a weight decay of 0.1 and gradient clipping of 1.0. We use 2,000 warmup steps.

| params | learning rate | batch size |
|--------|---------------|------------|
| 6.7B | 3.0e-4 | 4M |
| 13.0B | 3.0e-4 | 4M |
| 32.5B | 1.5e-4 | 4M |
| 65.2B | 1.5e-4 | 4M |

## Usage

> ⚠ This example only provides a benchmarking script. For training/finetuning, please refer to [applications/Colossal-LLaMA](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Colossal-LLaMA).

### 1. Installation

Please install the latest ColossalAI from source.
@@ -62,52 +34,6 @@ Then install other dependencies.
pip install -r requirements.txt
```

Additionally, we recommend using torch 1.13.1. We have tested our code on this version and found it compatible with flash attention.

### 2. Download the dataset

The dataset can be automatically downloaded using `huggingface/datasets`. You can specify the dataset path with `-d` or `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`.

### 3. Command line arguments

You can use `colossalai run` to launch multi-node training:
```bash
colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
pretrain.py --OTHER_CONFIGURATIONS
```

Here is a sample hostfile:

```text
hostname1
hostname2
hostname3
hostname4
```

Make sure the master node can access all nodes (including itself) via passwordless SSH.

Here are the details of the CLI arguments:

- Model configuration: `-c`, `--config`. `7b`, `13b`, `30b` and `65b` are supported for LLaMA-1; `7b`, `13b`, and `70b` are supported for LLaMA-2.
- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2`, `hybrid_parallel` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins).
- Dataset path: `-d`, `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`. It supports any dataset from `datasets` with the same data format as RedPajama.
- Number of epochs: `-e`, `--num_epochs`. The default value is 1.
- Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2.
- Learning rate: `--lr`. The default value is 3e-4.
- Weight decay: `-w`, `--weight_decay`. The default value is 0.1.
- Warmup steps: `-s`, `--warmup_steps`. The default value is 2000.
- Gradient checkpointing: `-g`, `--gradient_checkpoint`. The default value is `False`. Enabling it saves memory at the cost of speed; it is recommended when training with a large batch size.
- Max length: `-l`, `--max_length`. The default value is 4096.
- Mixed precision: `-x`, `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported.
- Save interval: `-i`, `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000.
- Checkpoint directory: `-o`, `--save_dir`. The directory path to save checkpoints. The default value is `checkpoint`.
- Checkpoint to load: `-f`, `--load`. The checkpoint path to load. The default value is `None`.
- Gradient clipping: `--gradient_clipping`. The default value is 1.0.
- Tensorboard log directory: `-t`, `--tensorboard_dir`. The directory path to save tensorboard logs. The default value is `tb_logs`.
- Flash attention: `-a`, `--flash_attention`. If you want to use flash attention, you must install `flash-attn`. The default value is `False`. It accelerates training while saving memory, so we recommend always enabling it.


### 4. Shell Script Examples

For your convenience, we provide some shell scripts to run benchmarks with various configurations.
@@ -193,40 +119,3 @@ If you run the above command successfully, you will get the following results:
year={2023}
}
```


# Fine-tune Llama2

We also provide an example of fine-tuning LLaMA-2 in `finetune.py`.

Make sure the master node can access all nodes (including itself) via passwordless SSH.

Here are the details of the CLI arguments:

- Pretrained checkpoint path: `--model_path`. The path of your model checkpoint; it can be a local directory or a Hugging Face tag.
- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2`, `hybrid_parallel` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins).
- Dataset path: `-d`, `--dataset`. The default dataset is `yizhongw/self_instruct`. It supports any dataset from `datasets` with the same data format as `yizhongw/self_instruct`.
- Task name: `--task_name`. The task to fine-tune on; it also determines which subset of the dataset is loaded. The default value is `super_natural_instructions`.
- Number of epochs: `-e`, `--num_epochs`. The default value is 1.
- Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2.
- Learning rate: `--lr`. The default value is 3e-4.
- Weight decay: `-w`, `--weight_decay`. The default value is 0.1.
- Gradient checkpointing: `-g`, `--gradient_checkpoint`. The default value is `False`. Enabling it saves memory at the cost of speed; it is recommended when training with a large batch size.
- Max length: `-l`, `--max_length`. The default value is 4096.
- Mixed precision: `-x`, `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported.
- Save interval: `-i`, `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000.
- Checkpoint directory: `-o`, `--save_dir`. The directory path to save checkpoints. The default value is `checkpoint`.
- Checkpoint to load: `-f`, `--load`. The checkpoint path to load. The default value is `None`.
- Gradient clipping: `--gradient_clipping`. The default value is 1.0.
- Tensorboard log directory: `-t`, `--tensorboard_dir`. The directory path to save tensorboard logs. The default value is `tb_logs`.
- Flash attention: `-a`, `--flash_attention`. If you want to use flash attention, you must install `flash-attn`. The default value is `False`. It accelerates training while saving memory, so we recommend always enabling it.


```shell
torchrun --standalone --nproc_per_node 8 finetune.py \
--plugin "hybrid_parallel" \
--dataset "yizhongw/self_instruct" \
--model_path "/path/llama" \
--task_name "super_natural_instructions" \
--save_dir "/path/output"
```
1 change: 0 additions & 1 deletion examples/language/llama2/attn.py

This file was deleted.
