3 changes: 2 additions & 1 deletion colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -37,7 +37,8 @@ def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp
self.shared_param_process_groups = []
for shared_param in self.shared_params:
if len(shared_param) > 0:
self.stage_manager.init_process_group_by_stages(list(shared_param.keys()))
self.shared_param_process_groups.append(
self.stage_manager.init_process_group_by_stages(list(shared_param.keys())))
if precision == 'fp16':
module = module.half().cuda()
elif precision == 'bf16':
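The fix above keeps the process group returned by `init_process_group_by_stages` instead of discarding it, so shared (tied) parameters can later be synchronized across pipeline stages. A minimal sketch of the corrected pattern, with names simplified for illustration:

```python
# Illustrative sketch only -- simplified from the plugin code above.
# Before the fix, the returned group was dropped and the list stayed empty,
# so tied weights (e.g. input/output embeddings) were never reduced together.
shared_param_process_groups = []
for shared_param in shared_params:    # one dict of {stage_index: param} per tied weight
    if len(shared_param) > 0:
        group = stage_manager.init_process_group_by_stages(list(shared_param.keys()))
        shared_param_process_groups.append(group)    # keep the group for later gradient sync
```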
4 changes: 3 additions & 1 deletion tests/kit/model_zoo/transformers/gpt.py
@@ -72,7 +72,9 @@ def data_gen_for_sequence_classification():
embd_pdrop=0,
resid_pdrop=0,
summary_first_dropout=0,
hidden_dropout=0)
hidden_dropout=0,
problem_type="single_label_classification",
pad_token_id=50256)

config_for_token_classification = copy.deepcopy(config)
config_for_token_classification.num_labels = 2
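For context, the two new config fields follow the standard Hugging Face `transformers` API: `problem_type` selects the classification loss, and `pad_token_id` lets `GPT2ForSequenceClassification` locate the last non-padding token in batched inputs, since GPT-2 has no native pad token. The snippet below is an illustrative sketch, not part of the PR's test code:

```python
# Hedged example of the config fields added above, using the public transformers API.
from transformers import GPT2Config, GPT2ForSequenceClassification

config = GPT2Config(
    num_labels=2,
    problem_type="single_label_classification",  # use CrossEntropyLoss in the classification head
    pad_token_id=50256,                          # reuse GPT-2's eos token id for padding
)
model = GPT2ForSequenceClassification(config)
```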
159 changes: 149 additions & 10 deletions tests/test_shardformer/test_model/_utils.py
@@ -1,11 +1,19 @@
import copy
from contextlib import nullcontext
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.distributed as dist
from torch import Tensor
from torch import distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import Module
from torch.optim import Adam, Optimizer

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.lazy import LazyInitContext
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer._utils import getattr_
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
@@ -79,20 +87,151 @@ def check_state_dict(org_model: Module, sharded_model: Module, name: str = ''):
assert torch.equal(v, shard_v), f'{name} {k} value mismatch'


def check_grad(original_model, sharded_model, layer_suffix, atol=1e-5, rtol=1e-5, dim=0, verbose=False):
def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any]):

use_lazy_init = False
if 'use_lazy_init' in test_config:
use_lazy_init = test_config.pop('use_lazy_init')

if use_lazy_init:
ctx = LazyInitContext()
else:
ctx = nullcontext()

plugin = HybridParallelPlugin(**test_config)
booster = Booster(plugin=plugin)

with ctx:
org_model = model_fn().cuda()
sharded_model = copy.deepcopy(org_model)

if use_lazy_init:
org_model = ctx.materialize(org_model)

org_optimizer = Adam(org_model.parameters(), lr=1e-3)
sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
criterion = loss_fn

sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)

return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster


def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Module, sharded_optimizer: Optimizer,
data_gen_fn: Callable, output_transform_fn: Callable, criterion: Callable,
booster: Booster):

def _criterion(outputs, inputs):
outputs = output_transform_fn(outputs)
loss = criterion(outputs)
return loss

data = data_gen_fn()
sharded_model.train()
if booster.plugin.stage_manager is not None:
data = {
k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
for k, v in data.items()
}
data_iter = iter([data])
sharded_output = booster.execute_pipeline(data_iter,
sharded_model,
_criterion,
sharded_optimizer,
return_loss=True,
return_outputs=True)
sharded_loss = sharded_output['loss']
else:
data = {k: v.cuda() for k, v in data.items()}
sharded_output = sharded_model(**data)
sharded_loss = criterion(sharded_output)
sharded_loss.backward()

org_model.train()
org_output = org_model(**data)
org_loss = criterion(org_output)
org_loss.backward()

return org_loss, org_output, sharded_loss, sharded_output


def check_output_hidden_state(org_output: Tensor,
sharded_output: Tensor,
stage_manager: Optional[PipelineStageManager] = None,
atol: float = 1e-5,
rtol: float = 1e-3):

org_hidden_state = org_output.last_hidden_state

if stage_manager is None:
sharded_hidden_state = sharded_output.last_hidden_state

if stage_manager and stage_manager.is_last_stage():
sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=0)

assert torch.allclose(org_hidden_state, sharded_hidden_state, atol=atol, rtol=rtol), \
f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"


def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3):
assert torch.allclose(org_loss, sharded_loss, atol=atol, rtol=rtol), \
f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}"


def check_weight(org_model: Module,
sharded_model: Module,
layer_suffix: List[str],
tp_group: Optional[ProcessGroup] = None,
dim: int = 0,
atol: float = 1e-5,
rtol: float = 1e-3,
verbose: bool = False):

for suffix in layer_suffix:
org_weight = getattr_(org_model, suffix).weight
sharded_weight = getattr_(sharded_model, suffix).weight

if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
sharded_weight_list = [
torch.zeros([*sharded_weight.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group))
]
dist.all_gather(sharded_weight_list, sharded_weight, tp_group)
sharded_weight = torch.cat(sharded_weight_list, dim=dim)

if verbose and dist.get_rank() == 0:
print(f"'{suffix}' weight: {org_weight}, {sharded_weight}")

assert torch.allclose(org_weight, sharded_weight, atol=atol, rtol=rtol), \
f"shard model weight is not equal to origin model weight\n{org_weight}\n{sharded_weight}"


def check_grad(org_model: Module,
sharded_model: Module,
layer_suffix: List[str],
tp_group: ProcessGroup = None,
dim: int = 0,
atol: float = 1e-5,
rtol: float = 1e-3,
verbose: bool = False):

for suffix in layer_suffix:
org_grad = getattr_(original_model, suffix).weight.grad
org_grad = getattr_(org_model, suffix).weight.grad
shard_grad = getattr_(sharded_model, suffix).weight.grad
shard_weight = getattr_(sharded_model, suffix).weight

if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(dist.get_world_size())]
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
all_shard_grad = torch.cat(shard_grad_list, dim=dim)
else:
all_shard_grad = shard_grad
shard_grad_list = [
torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group))
]
dist.all_gather(shard_grad_list, shard_grad, tp_group)
shard_grad = torch.cat(shard_grad_list, dim=dim)

# embedding may be resized when using tensor parallel
if shard_grad.shape[0] > org_grad.shape[0]:
shard_grad = shard_grad[:org_grad.shape[0], :]

if verbose and dist.get_rank() == 0:
print(f"'{suffix}' grad: {org_grad}, {all_shard_grad}")
print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
assert torch.allclose(
org_grad, all_shard_grad, rtol=rtol, atol=atol
), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{all_shard_grad}"
org_grad, shard_grad, rtol=rtol, atol=atol
), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"
138 changes: 36 additions & 102 deletions tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -1,107 +1,48 @@
import copy
from contextlib import nullcontext

import pytest
import torch
from torch import distributed as dist
from torch.optim import Adam

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.api import (
clear_layout_converter,
is_customized_distributed_tensor,
is_distributed_tensor,
)
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_loss,
check_output_hidden_state,
check_weight,
run_forward_backward_with_hybrid_plugin,
)


def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):

use_lazy_init = False
if 'use_lazy_init' in test_config:
use_lazy_init = test_config.pop('use_lazy_init')
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)

if use_lazy_init:
ctx = LazyInitContext()
else:
ctx = nullcontext()

# prepare booster
plugin = HybridParallelPlugin(**test_config)
booster = Booster(plugin=plugin)
stage_manager = plugin.stage_manager

# prepare models and optimizers
with ctx:
org_model = model_fn().cuda()
sharded_model = copy.deepcopy(org_model)

if use_lazy_init:
org_model = ctx.materialize(org_model)

org_optimizer = Adam(org_model.parameters(), lr=1e-3)
sharded_optimizer = Adam(sharded_model.parameters(), lr=1e-3)
criterion = loss_fn

sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)

def _criterion(outputs, inputs):
outputs = output_transform_fn(outputs)
loss = criterion(outputs)
return loss

# do forward and backward
data = data_gen_fn()
sharded_model.train()
if stage_manager:
data = {
k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
for k, v in data.items()
}
data_iter = iter([data])
sharded_output = booster.execute_pipeline(data_iter,
sharded_model,
_criterion,
sharded_optimizer,
return_loss=True,
return_outputs=True)
sharded_loss = sharded_output['loss']
else:
data = {k: v.cuda() for k, v in data.items()}
sharded_output = sharded_model(**data)
sharded_loss = criterion(sharded_output)
sharded_loss.backward()

org_model.train()
org_output = org_model(**data)
org_loss = criterion(org_output)
org_loss.backward()
org_loss, org_output, sharded_loss, sharded_output = \
run_forward_backward_with_hybrid_plugin(
org_model,
sharded_model,
sharded_optimizer,
data_gen_fn,
output_transform_fn,
criterion,
booster)

stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group

# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():

# check last hidden state
if org_model.__class__.__name__ == 'GPT2Model':
org_hidden_state = org_output.last_hidden_state

if stage_manager is None:
sharded_hidden_state = sharded_output.last_hidden_state

if stage_manager and stage_manager.is_last_stage():
sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']],
dim=0)

assert torch.allclose(org_hidden_state, sharded_hidden_state, atol=1e-5, rtol=1e-3), \
f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)

# check loss
assert torch.allclose(org_loss, sharded_loss, atol=1e-5, rtol=1e-3), \
f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}"
check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3)

# unwrap model
if org_model.__class__.__name__ == 'GPT2Model':
@@ -111,27 +52,19 @@ def _criterion(outputs, inputs):
gpt2 = org_model.transformer
sharded_gpt2 = sharded_model.unwrap().transformer

# check grad
col_layer_for_check = ['h[0].mlp.c_fc']
row_layer_for_check = ['h[0].mlp.c_proj']
check_grad(gpt2, sharded_gpt2, col_layer_for_check, atol=1e-6, rtol=1e-3, dim=1, verbose=False)
check_grad(gpt2, sharded_gpt2, row_layer_for_check, atol=1e-6, rtol=1e-3, dim=0, verbose=False)
row_layer_for_check = ['wte', 'h[0].mlp.c_proj']

# check grad
if stage_manager is None or stage_manager.is_first_stage():
check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False)
check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False)

# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
if stage_manager is None or stage_manager.is_first_stage():

org_weight = org_model.h[0].mlp.c_fc.weight
shard_weight = sharded_model.h[0].mlp.c_fc.weight

if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_weight_list = [torch.zeros([*shard_weight.shape]).to('cuda') for _ in range(plugin.tp_size)]
dist.all_gather(shard_weight_list, shard_weight, plugin.tp_group)
shard_weight = torch.cat(shard_weight_list, dim=1)

assert torch.allclose(org_weight, shard_weight, atol=5e-3, rtol=1e-3), \
f"shard model weight is not equal to origin model weight\n{org_weight}\n{shard_weight}"
check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=5e-3, rtol=1e-3, dim=1, verbose=False)

torch.cuda.empty_cache()

@@ -156,9 +89,11 @@ def _criterion(outputs, inputs):
@clear_cache_before_run()
def run_gpt2_test(test_config):

# TODO: add plugin_config for TP+DP after supporting & debugging it
# TODO: add test_config for TP+DP after supporting & debugging it
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}

# TODO: add test_config for flash attention & jit operator after supporting

sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
test_config['precision'] = 'float' # Do not use fp16/bf16 in testing

Expand All @@ -175,7 +110,6 @@ def check_gpt2(rank, world_size, port):
run_gpt2_test()


@pytest.mark.skip('Have some bug caused by merge')
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
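With the skip marker removed, the test runs again in CI. The entry point follows the usual ColossalAI distributed-test pattern; the snippet below is an assumed sketch of that wrapper (the launch arguments and process count are the conventional ones, not taken verbatim from this diff):

```python
# Assumed sketch of the standard ColossalAI test entry point around run_gpt2_test().
def check_gpt2(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size,
                      host='localhost', port=port, backend='nccl')
    run_gpt2_test()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_gpt2():
    spawn(check_gpt2, 4)    # e.g. 4 processes: tp_size * pp_size = 4
```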