11 changes: 10 additions & 1 deletion colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -148,7 +148,10 @@ def __init__(
precision: str = 'fp16',
zero_stage: int = 0,
cpu_offload: bool = False,
enable_all_optimization: bool = False,
enable_fused_normalization: bool = False,
enable_flash_attention: bool = False,
enable_jit_fused: bool = False,
num_microbatches: Optional[int] = None,
initial_scale: float = 2**16,
min_scale: float = 1,
@@ -171,7 +174,10 @@ def __init__(
self.precision = precision
self.zero_stage = zero_stage
self.cpu_offload = cpu_offload
self.enable_all_optimization = enable_all_optimization
self.enable_fused_normalization = enable_fused_normalization
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
self.stage_manager = None
self.schedule = None
@@ -186,7 +192,10 @@ def __init__(
self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group,
pipeline_stage_manager=self.stage_manager,
enable_tensor_parallelism=self.tp_size > 1,
enable_fused_normalization=self.enable_fused_normalization)
enable_all_optimization=self.enable_all_optimization,
enable_fused_normalization=self.enable_fused_normalization,
enable_flash_attention=self.enable_flash_attention,
enable_jit_fused=self.enable_jit_fused)
self.amp_config = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
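For context, a minimal sketch of how a caller could enable the newly added switches when constructing the plugin; the parallel sizes, precision, and flag values below are illustrative assumptions, not taken from this PR:

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Sketch only: tp_size/pp_size, precision, and the flag values are assumptions for illustration.
plugin = HybridParallelPlugin(tp_size=2,
                              pp_size=2,
                              num_microbatches=4,
                              precision='fp16',
                              enable_all_optimization=False,
                              enable_fused_normalization=True,   # pre-existing switch
                              enable_flash_attention=True,       # added by this PR
                              enable_jit_fused=True)             # added by this PR
booster = Booster(plugin=plugin)

Judging from the forwarded arguments, enable_all_optimization appears to act as an umbrella switch whose interaction with the individual flags is resolved inside ShardConfig.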
2 changes: 1 addition & 1 deletion requirements/requirements-test.txt
@@ -18,4 +18,4 @@ requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggi
SentencePiece
datasets
ninja
flash-attn
flash-attn>=2.0
2 changes: 1 addition & 1 deletion requirements/requirements.txt
@@ -10,4 +10,4 @@ contexttimer
ninja
torch>=1.11
safetensors
flash-attn
flash-attn>=2.0
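Both requirement files now pin flash-attn to 2.0 or newer. A quick environment sanity check, as a sketch that assumes flash-attn exposes __version__ and that the packaging library is installed:

from packaging import version

import flash_attn

# Sketch: fail early if the environment still carries a pre-2.0 flash-attn build.
assert version.parse(flash_attn.__version__) >= version.parse("2.0"), \
    f"flash-attn {flash_attn.__version__} does not satisfy the >=2.0 pin"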
16 changes: 9 additions & 7 deletions tests/test_shardformer/test_model/_utils.py
@@ -1,6 +1,5 @@
import copy
from contextlib import nullcontext
from typing import Optional
from typing import Any, Callable, Dict, List, Optional

import torch
@@ -16,8 +15,8 @@
from colossalai.lazy import LazyInitContext
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.auto_policy import Policy
from colossalai.shardformer._utils import getattr_
from colossalai.shardformer.policies.auto_policy import Policy
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor


@@ -156,10 +155,12 @@ def _criterion(outputs, inputs):
else:
data = {k: v.cuda() for k, v in data.items()}
sharded_output = sharded_model(**data)

sharded_loss = criterion(sharded_output)
sharded_loss.backward()
sharded_optimizer.backward(sharded_loss)

org_model.train()
data = {k: v.cuda() for k, v in data.items()}
org_output = org_model(**data)
org_loss = criterion(org_output)
org_loss.backward()
@@ -181,12 +182,12 @@ def check_output_hidden_state(org_output: Tensor,
if stage_manager and stage_manager.is_last_stage():
sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=0)

assert torch.allclose(org_hidden_state, sharded_hidden_state, atol=atol, rtol=rtol), \
assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \
f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"


def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3):
assert torch.allclose(org_loss, sharded_loss, atol=atol, rtol=rtol), \
assert torch.allclose(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol), \
f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}"


@@ -213,7 +214,7 @@ def check_weight(org_model: Module,
if verbose and dist.get_rank() == 0:
print(f"'{suffix}' weight: {org_weight}, {sharded_weight}")

assert torch.allclose(org_weight, sharded_weight, atol=atol, rtol=rtol), \
assert torch.allclose(org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol), \
f"shard model weight is not equal to origin model weight\n{org_weight}\n{sharded_weight}"


@@ -244,6 +245,7 @@ def check_grad(org_model: Module,

if verbose and dist.get_rank() == 0:
print(f"'{suffix}' grad: {org_grad}, {shard_grad}")

assert torch.allclose(
org_grad, shard_grad, rtol=rtol, atol=atol
org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol
), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"
59 changes: 39 additions & 20 deletions tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -3,6 +3,7 @@
from torch import distributed as dist

import colossalai
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
@@ -38,33 +39,49 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,

# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3

if org_model.__class__.__name__ == 'GPT2Model':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=1e-5, rtol=1e-3)
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)

check_loss(org_loss, sharded_loss, atol=1e-5, rtol=1e-3)
# check loss
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)

def unwrap(module):
if isinstance(module, HybridParallelModule):
module = module.unwrap()
if module.__class__.__name__ == 'GPT2Model':
return module
return module.transformer

# unwrap model
if org_model.__class__.__name__ == 'GPT2Model':
gpt2 = org_model
sharded_gpt2 = sharded_model.unwrap()
else:
gpt2 = org_model.transformer
sharded_gpt2 = sharded_model.unwrap().transformer
gpt2 = unwrap(org_model)
sharded_gpt2 = unwrap(sharded_model)

col_layer_for_check = ['h[0].mlp.c_fc']
row_layer_for_check = ['wte', 'h[0].mlp.c_proj']

# check grad
if test_config['precision'] == 'fp32':
atol, rtol = 1e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=1, verbose=False)
check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=1e-4, rtol=1e-3, dim=0, verbose=False)
check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)

# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
if test_config['precision'] == 'fp32':
atol, rtol = 5e-3, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=5e-3, rtol=1e-3, dim=1, verbose=False)
check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)

torch.cuda.empty_cache()

@@ -73,29 +90,31 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_fused_normalization': True,
'use_lazy_init': True
'enable_all_optimization': True,
'use_lazy_init': True,
'precision': 'fp32',
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'use_lazy_init': False
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp16',
'initial_scale': 1,
}, {
'tp_size': 4,
'pp_size': 1,
'enable_fused_normalization': True,
'use_lazy_init': False
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
}])
@clear_cache_before_run()
def run_gpt2_test(test_config):

# TODO: add test_config for TP+DP after supporting & debugging it
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}

# TODO: add test_config for flash attention & jit operator after supporting
# TODO: check and debug TP+AMP

sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
test_config['precision'] = 'float' # Do not use fp16/bf16 in testing

for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
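The three tolerance blocks above repeat the same fp32-versus-reduced-precision switch; as a sketch only (this helper is not part of the PR), the values used could be gathered in one place:

from typing import Tuple

def get_tolerance(precision: str, phase: str) -> Tuple[float, float]:
    # Mirrors the (atol, rtol) pairs chosen in check_forward_backward above.
    if precision == 'fp32':
        return {'output': (1e-5, 1e-3), 'grad': (1e-4, 1e-3), 'weight': (5e-3, 1e-3)}[phase]
    return (5e-3, 5e-3)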