Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions colossalai/shardformer/modeling/llama.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import Callable, List, Optional, Tuple

import torch
Expand Down Expand Up @@ -392,6 +393,13 @@ def get_llama_flash_attention_forward():

from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

llama_version = 2
try:
from transformers.models.llama.modeling_llama import repeat_kv
except:
warnings.warn("using llamav1, llamav1 hasn't repeat_kv function")
llama_version = 1

from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

def forward(
Expand Down Expand Up @@ -424,6 +432,11 @@ def forward(

past_key_value = (key_states, value_states) if use_cache else None

# repeat k/v heads if n_kv_heads < n_heads
if llama_version == 2:
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

me_input_shape = (bsz, q_len, self.num_heads, self.head_dim)
query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape)
key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape)
Expand Down
1 change: 0 additions & 1 deletion colossalai/shardformer/modeling/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,6 @@ def forward(
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."

attention_input_shape = (bsz, -1, self.num_heads, self.head_dim)
# get query proj
Expand Down
6 changes: 2 additions & 4 deletions colossalai/shardformer/policies/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,8 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:

if self.shard_config.enable_tensor_parallelism:
decoder_attribute_replacement = {
"self_attn.hidden_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
}
if getattr(self.model.config, "num_key_value_heads", False):
decoder_attribute_replacement["self_attn.num_key_value_heads"] = \
Expand Down
55 changes: 26 additions & 29 deletions examples/language/bert/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,18 @@ def evaluate_model(
model.eval()

def evaluate_subset(dataloader: DataLoader):
use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()

accum_loss = torch.zeros(1, device=get_current_device())
for batch in dataloader:
batch = move_to_cuda(batch)
labels = batch["labels"]
batch_size = batch["input_ids"].shape[0]
if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
if use_pipeline:
pg_mesh = booster.plugin.pg_mesh
pp_group = booster.plugin.pp_group
current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group)
current_rank = dist.get_rank()
#TODO pass dataloader to execute_pipeline directly
batch = iter([batch])
outputs = booster.execute_pipeline(batch,
model,
Expand All @@ -78,31 +79,25 @@ def evaluate_subset(dataloader: DataLoader):
return_loss=True,
return_outputs=True)

if booster.plugin.stage_manager.is_last_stage():
val_loss = outputs["loss"]

if is_pp_last_stage:
logits = outputs["outputs"]["logits"]

val_loss = outputs["loss"]
accum_loss.add_(val_loss)

if num_labels > 1:
preds = torch.argmax(logits, axis=1)
elif num_labels == 1:
preds = logits.squeeze()

dist.broadcast(preds, src=current_rank, group=pp_group)
dist.broadcast(val_loss, src=current_rank, group=pp_group)
dist.broadcast_object_list([preds, val_loss], src=current_pp_group_ranks[-1], group=pp_group)

metric.add_batch(predictions=preds, references=labels)
elif current_rank in current_pp_group_ranks:
val_loss = torch.empty((1,), device=get_current_device())
preds = torch.empty((batch_size,), dtype=torch.int64, device=get_current_device())
object_list = [None, None]
dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group)

dist.broadcast(preds, src=current_pp_group_ranks[-1], group=pp_group)
dist.broadcast(val_loss, src=current_pp_group_ranks[-1], group=pp_group)

accum_loss.add_(val_loss)
metric.add_batch(predictions=preds, references=labels)
metric.add_batch(predictions=object_list[0].to(get_current_device()), references=labels)
accum_loss.add_(object_list[1].to(get_current_device()))

else:
batch = move_to_cuda(batch)
Expand Down Expand Up @@ -138,31 +133,33 @@ def evaluate_subset(dataloader: DataLoader):
def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler,
train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):

use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
total_step = len(train_dataloader)

model.train()
is_pp_last_stage = hasattr(
booster.plugin,
"stage_manager") and booster.plugin.stage_manager is not None and booster.plugin.stage_manager.is_last_stage()
with tqdm(train_dataloader,
optimizer.zero_grad()
train_dataloader_iter = iter(train_dataloader)
with tqdm(range(total_step),
desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]',
disable=not (coordinator.is_master() or is_pp_last_stage)) as pbar:
for batch in pbar:
# Forward pass
batch = move_to_cuda(batch)
if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
#TODO pass train_dataloader to execute_pipeline directly
batch = iter([batch])
outputs = booster.execute_pipeline(batch,
# Forward pass
for _ in pbar:
if use_pipeline:
outputs = booster.execute_pipeline(train_dataloader_iter,
model,
_criterion,
optimizer,
return_loss=True,
return_outputs=True)
# Backward and optimize
if booster.plugin.stage_manager.is_last_stage():
if is_pp_last_stage:
loss = outputs['loss']
pbar.set_postfix({'loss': loss.item()})
else:
outputs = model(**batch)
data = next(train_dataloader_iter)
data = move_to_cuda(data)
outputs = model(**data)
loss = _criterion(outputs, None)
# Backward
booster.backward(loss, optimizer)
Expand Down
140 changes: 44 additions & 96 deletions examples/language/opt/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,117 +4,65 @@
def parse_demo_args():
    """Parse the command-line options of the OPT fine-tuning demo.

    The parser is obtained from ``get_default_parser()`` (presumably an
    ``argparse.ArgumentParser`` -- confirm against its definition) and each
    demo option is registered exactly once; registering the same option
    string twice would be rejected by argparse's default conflict handling.

    Returns:
        The namespace returned by ``parser.parse_args()``.
    """
    parser = get_default_parser()

    # Model location and output checkpoint.
    parser.add_argument("--model_name_or_path",
                        type=str,
                        default="facebook/opt-350m",
                        help="Path to pretrained model or model identifier from huggingface.co/models.")
    parser.add_argument("--output_path",
                        type=str,
                        default="./output_model.bin",
                        help="The path of your saved model after finetuning.")

    # Parallelism plugin selection.
    parser.add_argument(
        "--plugin",
        type=str,
        default="gemini",
        help=
        "Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'."
    )

    # Optimisation hyper-parameters.
    parser.add_argument("--num_epoch", type=int, default=10, help="Number of epochs.")
    parser.add_argument("--batch_size",
                        type=int,
                        default=32,
                        help="Batch size (per dp group) for the training dataloader.")
    parser.add_argument("--learning_rate",
                        type=float,
                        default=5e-5,
                        help="Initial learning rate (after the potential warmup period) to use.")
    parser.add_argument("--warmup_ratio",
                        type=float,
                        default=0.1,
                        help="Ratio of warmup steps against total training steps.")
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to use.")
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")

    args = parser.parse_args()
    return args



def parse_benchmark_args():
    """Parse the command-line options of the OPT benchmark script.

    The parser is obtained from ``get_default_parser()`` (presumably an
    ``argparse.ArgumentParser`` -- confirm against its definition) and each
    benchmark option is registered exactly once; registering the same option
    string twice would be rejected by argparse's default conflict handling.

    Returns:
        The namespace returned by ``parser.parse_args()``.
    """
    parser = get_default_parser()

    # Model location.
    parser.add_argument("--model_name_or_path",
                        type=str,
                        default="facebook/opt-125m",
                        help="Path to pretrained model or model identifier from huggingface.co/models.")

    # Parallelism plugin selection.
    parser.add_argument(
        "--plugin",
        type=str,
        default="gemini",
        help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'.")

    # Optimisation hyper-parameters and run length.
    parser.add_argument("--batch_size",
                        type=int,
                        default=32,
                        help="Batch size (per dp group) for the training dataloader.")
    parser.add_argument("--learning_rate",
                        type=float,
                        default=5e-5,
                        help="Initial learning rate (after the potential warmup period) to use.")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
    parser.add_argument("--max_train_steps", type=int, default=20, help="Total number of training steps to perform.")
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")

    # Per-GPU memory limit in GB, as described by the option's help text.
    parser.add_argument("--mem_cap", type=int, default=0, help="Limit on the usage of space for each GPU (in GB).")

    args = parser.parse_args()
    return args
Loading