Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions examples/language/gpt/hybridparallelism/finetune.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
from contextlib import nullcontext
from typing import Callable, List, Union

import evaluate
Expand All @@ -17,6 +18,7 @@
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam

# ==============================
Expand Down Expand Up @@ -186,7 +188,6 @@ def main():
help="only gpt2 now",
)
parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached")
parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context")
args = parser.parse_args()

if args.model_type == "gpt2":
Expand Down Expand Up @@ -250,10 +251,16 @@ def main():
pad_token_id=data_builder.tokenizer.pad_token_id,
)

if model_name == "gpt2":
model = GPT2ForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
else:
raise RuntimeError
init_ctx = (
LazyInitContext(default_device=get_accelerator().get_current_device())
if isinstance(plugin, (GeminiPlugin))
else nullcontext()
)
with init_ctx:
if model_name == "gpt2":
model = GPT2ForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
else:
raise RuntimeError

# optimizer
no_decay = ["bias", "LayerNorm.weight"]
Expand Down
24 changes: 16 additions & 8 deletions examples/language/opt/opt_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time
from contextlib import nullcontext

import torch
import tqdm
Expand All @@ -8,9 +9,11 @@
from transformers.utils.versions import require_version

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam

Expand Down Expand Up @@ -62,14 +65,6 @@ def main():
if args.mem_cap > 0:
colo_memory_cap(args.mem_cap)

# Build OPT model
config = AutoConfig.from_pretrained(args.model_name_or_path)
model = OPTForCausalLM(config=config)
logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])

# Enable gradient checkpointing
model.gradient_checkpointing_enable()

# Set plugin
booster_kwargs = {}
if args.plugin == "torch_ddp_fp16":
Expand All @@ -82,6 +77,19 @@ def main():
plugin = LowLevelZeroPlugin(initial_scale=2**5)
logger.info(f"Set plugin as {args.plugin}", ranks=[0])

# Build OPT model
init_ctx = (
LazyInitContext(default_device=get_accelerator().get_current_device())
if isinstance(plugin, (GeminiPlugin))
else nullcontext()
)
config = AutoConfig.from_pretrained(args.model_name_or_path)
with init_ctx:
model = OPTForCausalLM(config=config)
logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])

# Enable gradient checkpointing
model.gradient_checkpointing_enable()
# Set optimizer
optimizer = HybridAdam(model.parameters(), lr=args.learning_rate)

Expand Down
27 changes: 19 additions & 8 deletions examples/language/opt/opt_train_demo.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from contextlib import nullcontext

import datasets
import torch
import transformers
Expand All @@ -8,9 +10,11 @@
from transformers.utils.versions import require_version

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam

Expand Down Expand Up @@ -78,14 +82,6 @@ def main():
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()

# Build OPT model
config = AutoConfig.from_pretrained(args.model_name_or_path)
model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])

# Enable gradient checkpointing
model.gradient_checkpointing_enable()

# Set plugin
booster_kwargs = {}
if args.plugin == "torch_ddp_fp16":
Expand All @@ -110,6 +106,21 @@ def main():

logger.info(f"Set plugin as {args.plugin}", ranks=[0])

# Build OPT model
config = AutoConfig.from_pretrained(args.model_name_or_path)
# Build OPT model
init_ctx = (
LazyInitContext(default_device=get_accelerator().get_current_device())
if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
else nullcontext()
)
with init_ctx:
model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])

# Enable gradient checkpointing
model.gradient_checkpointing_enable()

# Prepare tokenizer and dataloader
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
dataset = NetflixDataset(tokenizer)
Expand Down