32 changes: 26 additions & 6 deletions examples/language/opt/README.md
@@ -19,15 +19,35 @@ Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/fa

The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Causal Language Modelling at low cost.

We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before
the tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling).

## Our Modifications
We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP.

## Quick Start
You can launch training by using the following bash script
We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before
the tokenization).

We adapt the OPT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin, LowLevelZeroPlugin, and GeminiPlugin.
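At a high level, each training script in this example builds a plugin, wraps the model and optimizer with a `Booster`, and routes the backward pass through it. Below is a minimal sketch of that pattern, condensed from `opt_benchmark.py` in this PR; the checkpoint, learning rate, and plugin arguments are illustrative only:

```python
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin  # or TorchDDPPlugin, LowLevelZeroPlugin
from colossalai.nn.optimizer import HybridAdam
from transformers import OPTForCausalLM

# Launch the distributed environment (normally started via torchrun / colossalai run).
colossalai.launch_from_torch(config={})

# Pick one plugin; each corresponds to a training strategy.
# Arguments are condensed here; see opt_benchmark.py for the Gemini-specific
# initialization under ColoInitContext.
plugin = GeminiPlugin(placement_policy='cpu', initial_scale=2**5)

model = OPTForCausalLM.from_pretrained("facebook/opt-125m")
optimizer = HybridAdam(model.parameters(), lr=5e-5)

booster = Booster(plugin=plugin)
model, optimizer, _, _, _ = booster.boost(model, optimizer)

# Inside the training loop, the backward pass goes through the booster:
#   outputs = model(**batch)
#   booster.backward(outputs.loss, optimizer)
#   optimizer.step()
```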

## Run Demo

By running the following script:
```bash
bash ./run_gemini.sh
bash run_demo.sh
```
You will finetune a [facebook/opt-350m](https://huggingface.co/facebook/opt-350m) model on this [dataset](https://huggingface.co/datasets/hugginglearners/netflix-shows), which contains more than 8000 comments on Netflix shows.

The script can be modified if you want to try another set of hyperparameters or switch to an OPT model of a different size.

The demo code is adapted from this [blog](https://medium.com/geekculture/fine-tune-eleutherai-gpt-neo-to-generate-netflix-movie-descriptions-in-only-47-lines-of-code-40c9b4c32475) and the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling).



## Run Benchmark

You can run a benchmark for the OPT model with the following script:
```bash
bash run_benchmark.sh
```
The script tests performance (throughput & peak memory usage) for each combination of hyperparameters. You can also modify this script to configure your own set of hyperparameters for testing, as shown in the sketch below.
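The reported numbers are computed as in `opt_benchmark.py` from this PR: throughput is samples processed per second summed over all data-parallel ranks, and peak memory is `torch.cuda.max_memory_allocated()` on each rank. A condensed sketch (the world size, step count, and batch size below are illustrative):

```python
import time
import torch

# Illustrative values; the real script takes these from parse_benchmark_args().
world_size, max_train_steps, batch_size = 4, 20, 32

start_time = time.time()
# ... run `max_train_steps` training steps here ...
end_time = time.time()

# Throughput: samples processed per second, summed over all data-parallel ranks.
throughput = (world_size * max_train_steps * batch_size) / (end_time - start_time)

# Peak memory: highest GPU memory allocated on this rank during the run (bytes).
if torch.cuda.is_available():
    max_mem = torch.cuda.max_memory_allocated(device=torch.cuda.current_device())
    print(f"throughput: {throughput:.4f} samples/s, peak memory: {max_mem / 1024**2:.2f} MB")
```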



120 changes: 120 additions & 0 deletions examples/language/opt/args.py
@@ -0,0 +1,120 @@
from colossalai import get_default_parser


def parse_demo_args():

    parser = get_default_parser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="facebook/opt-350m",
        help="Path to pretrained model or model identifier from huggingface.co/models."
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./output_model.bin",
        help="The path of your saved model after finetuning."
    )
    parser.add_argument(
        "--plugin",
        type=str,
        default="gemini",
        help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'."
    )
    parser.add_argument(
        "--num_epoch",
        type=int,
        default=10,
        help="Number of epochs."
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size (per dp group) for the training dataloader."
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use."
    )
    parser.add_argument(
        "--warmup_ratio",
        type=float,
        default=0.1,
        help="Ratio of warmup steps against total training steps."
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.01,
        help="Weight decay to use."
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="A seed for reproducible training."
    )

    args = parser.parse_args()
    return args


def parse_benchmark_args():

    parser = get_default_parser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="facebook/opt-125m",
        help="Path to pretrained model or model identifier from huggingface.co/models."
    )
    parser.add_argument(
        "--plugin",
        type=str,
        default="gemini",
        help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'."
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size (per dp group) for the training dataloader."
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use."
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.0,
        help="Weight decay to use."
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=20,
        help="Total number of training steps to perform."
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="A seed for reproducible training."
    )
    parser.add_argument(
        "--mem_cap",
        type=int,
        default=0,
        help="Limit on the usage of space for each GPU (in GB)."
    )
    args = parser.parse_args()

    return args
21 changes: 0 additions & 21 deletions examples/language/opt/benchmark.sh

This file was deleted.

37 changes: 37 additions & 0 deletions examples/language/opt/data.py
@@ -0,0 +1,37 @@
import torch
from torch.utils.data import Dataset
from datasets import load_dataset


class NetflixDataset(Dataset):

    def __init__(self, tokenizer):

        super().__init__()

        self.tokenizer = tokenizer
        self.input_ids = []
        self.attn_masks = []
        self.labels = []

        # Load the raw Netflix show descriptions from the Hugging Face Hub.
        self.txt_list = netflix_descriptions = load_dataset("hugginglearners/netflix-shows", split="train")['description']

        # Pad every sample to the length of the longest description in the dataset.
        self.max_length = max([len(self.tokenizer.encode(description)) for description in netflix_descriptions])

        for txt in self.txt_list:
            # Wrap each description with the EOS token and tokenize to a fixed length.
            encodings_dict = self.tokenizer('</s>' + txt + '</s>',
                                            truncation=True,
                                            max_length=self.max_length,
                                            padding="max_length")
            self.input_ids.append(torch.tensor(encodings_dict['input_ids']))
            self.attn_masks.append(torch.tensor(encodings_dict['attention_mask']))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.attn_masks[idx]


def netflix_collator(data):
    # For causal language modelling the labels are the input ids themselves.
    return {'input_ids': torch.stack([x[0] for x in data]),
            'attention_mask': torch.stack([x[1] for x in data]),
            'labels': torch.stack([x[0] for x in data])}
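
For reference, a minimal usage sketch of the dataset and collator defined above; the tokenizer checkpoint and batch size are illustrative (the demo script loads the tokenizer matching `--model_name_or_path`):

```python
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from data import NetflixDataset, netflix_collator

# Illustrative checkpoint; the demo uses the tokenizer matching --model_name_or_path.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

dataset = NetflixDataset(tokenizer)
dataloader = DataLoader(dataset,
                        batch_size=32,
                        shuffle=True,
                        collate_fn=netflix_collator)

for batch in dataloader:
    # Each batch is a dict with 'input_ids', 'attention_mask', and 'labels'.
    print(batch['input_ids'].shape)
    break
```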
146 changes: 146 additions & 0 deletions examples/language/opt/opt_benchmark.py
@@ -0,0 +1,146 @@
import time

import torch
import transformers
from transformers import AutoConfig, OPTForCausalLM
from transformers.utils.versions import require_version
import tqdm

import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator

from args import parse_benchmark_args

require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")


def format_num(num: int, bytes=False):
    """Scale bytes to its proper format, e.g. 1253656 => '1.20MB'"""
    factor = 1024 if bytes else 1000
    suffix = "B" if bytes else ""
    for unit in ["", " K", " M", " G", " T", " P"]:
        if num < factor:
            return f"{num:.2f}{unit}{suffix}"
        num /= factor


def get_data(batch_size, seq_len, vocab_size):
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask


def colo_memory_cap(size_in_GB):
    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
    cuda_capacity = colo_device_memory_capacity(get_current_device())
    if size_in_GB * (1024**3) < cuda_capacity:
        colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
        print(f"Limiting GPU memory usage to {size_in_GB} GB")


def main():

    args = parse_benchmark_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()
    world_size = coordinator.world_size

    # Manage loggers
    disable_existing_loggers()
    logger = get_dist_logger()
    if coordinator.is_master():
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    # Whether to set limit of memory capacity
    if args.mem_cap > 0:
        colo_memory_cap(args.mem_cap)

    # Build OPT model
    # Initialize the model under ColoInitContext if using GeminiPlugin
    config = AutoConfig.from_pretrained(args.model_name_or_path)
    if args.plugin == 'gemini':
        shard_pg = ProcessGroup(tp_degree=world_size)
        default_dist_spec = ShardSpec([-1], [world_size])
        with ColoInitContext(device='cpu',
                             default_dist_spec=default_dist_spec,
                             default_pg=shard_pg):
            model = OPTForCausalLM(config)
    else:
        model = OPTForCausalLM(config)
    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])

    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()

    # Set plugin
    booster_kwargs = {}
    if args.plugin == 'torch_ddp_fp16':
        booster_kwargs['mixed_precision'] = 'fp16'
    if args.plugin.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif args.plugin == 'gemini':
        plugin = GeminiPlugin(device=get_current_device(),
                              placement_policy='cpu',
                              pin_memory=True,
                              strict_ddp_mode=True,
                              initial_scale=2**5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    logger.info(f"Set plugin as {args.plugin}", ranks=[0])

    # Set optimizer
    optimizer = HybridAdam(model.parameters(), lr=args.learning_rate)

    # Set booster
    booster = Booster(plugin=plugin, **booster_kwargs)
    model, optimizer, _, _, _ = booster.boost(model, optimizer)

    SEQ_LEN = 1024
    VOCAB_SIZE = 50257

    # Start training.
    logger.info(f"Start testing", ranks=[0])
    progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master())

    torch.cuda.synchronize()
    model.train()
    start_time = time.time()

    for _ in range(args.max_train_steps):

        input_ids, attn_mask = get_data(args.batch_size, SEQ_LEN, VOCAB_SIZE)
        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attn_mask, labels=input_ids, use_cache=False)
        loss = outputs['loss']
        booster.backward(loss, optimizer)
        optimizer.step()

        torch.cuda.synchronize()
        progress_bar.update(1)

    # Compute Statistics
    end_time = time.time()
    throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time))
    max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True)

    logger.info(f"Testing finished, "
                f"batch size per gpu: {args.batch_size}, "
                f"plugin: {args.plugin}, "
                f"throughput: {throughput}, "
                f"maximum memory usage per gpu: {max_mem}.",
                ranks=[0])


if __name__ == "__main__":
    main()