2 changes: 1 addition & 1 deletion tests/kit/model_zoo/transformers/bert.py
@@ -102,7 +102,7 @@ def data_gen_for_qa():
output_transform_fn = lambda x: x

# define loss function
-loss_fn_for_bert_model = lambda x: x.pooler_output.mean()
+loss_fn_for_bert_model = lambda x: x.pooler_output.sum()
loss_fn = lambda x: x.loss

config = transformers.BertConfig(hidden_size=128,
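Switching the BertModel test loss from .mean() to .sum() presumably keeps each element's gradient at 1 instead of 1/numel, so a small shard-vs-original discrepancy is not scaled below the comparison tolerance. A quick sketch (hypothetical 1x128 pooler output, not part of the diff):

import torch

pooled = torch.randn(1, 128, requires_grad=True)
pooled.sum().backward()
print(pooled.grad.unique())   # tensor([1.]) -- .mean() would give 1/128 everywhere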
14 changes: 10 additions & 4 deletions tests/kit/model_zoo/transformers/bloom.py
@@ -55,17 +55,23 @@ def data_gen_for_question_answering():
input_ids = torch.tensor(
[[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
-    return dict(input_ids=input_ids, attention_mask=attention_mask)
+    start_positions = torch.tensor([1], dtype=torch.int64)
+    end_positions = torch.tensor([10], dtype=torch.int64)
+    return dict(input_ids=input_ids,
+                attention_mask=attention_mask,
+                start_positions=start_positions,
+                end_positions=end_positions)


# define output transform function
output_transform_fn = lambda x: x

# define loss function
-loss_fn_for_bloom_model = lambda x: x.last_hidden_state.mean()
+loss_fn_for_bloom_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state,
+                                                                 torch.ones_like(x.last_hidden_state))
loss_fn_for_causal_lm = lambda x: x.loss
-loss_fn_for_classification = lambda x: x.logits.mean()
-loss_fn_for_question_answering = lambda x: x.end_logits.mean()
+loss_fn_for_classification = lambda x: x.loss
+loss_fn_for_question_answering = lambda x: x.loss

config = transformers.BloomConfig(n_layer=1,
n_head=4,
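Note on the loss changes in this file: several model-zoo losses move from a bare .mean() over hidden states to an MSE against a constant target. One plausible reason, sketched below with hypothetical shapes: the gradient of .mean() is a constant 1/numel regardless of the forward activations, so it cannot surface forward-pass discrepancies, while the MSE gradient depends on the activations themselves.

import torch
import torch.nn.functional as F

hidden = torch.randn(1, 14, 64, requires_grad=True)   # hypothetical (batch, seq, hidden)

# d(mse)/d(hidden) = 2 * (hidden - 1) / numel -- varies with the activations,
# whereas d(mean)/d(hidden) = 1 / numel everywhere.
loss = F.mse_loss(hidden, torch.ones_like(hidden))
loss.backward()
print(hidden.grad.std())   # nonzero spread, unlike the constant .mean() gradient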
20 changes: 12 additions & 8 deletions tests/kit/model_zoo/transformers/gpt.py
@@ -1,3 +1,5 @@
+import copy
+
import torch
import transformers

@@ -44,22 +46,23 @@ def data_gen_for_token_classification():
# token classification data gen
# `labels` is the type not the token id for token classification, 0 or 1
data = data_gen()
-    data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.int64)
+    data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 1]], dtype=torch.int64)
return data


def data_gen_for_sequence_classification():
# sequence classification data gen
data = data_gen()
-    data['labels'] = torch.tensor([0], dtype=torch.int64)
+    data['labels'] = torch.tensor([1], dtype=torch.int64)
return data


# define output transform function
output_transform_fn = lambda x: x

# define loss function
-loss_fn_for_gpt2_model = lambda x: x.last_hidden_state.mean()
+loss_fn_for_gpt2_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state,
+                                                                torch.ones_like(x.last_hidden_state))
loss_fn = lambda x: x.loss

config = transformers.GPT2Config(n_layer=2,
@@ -69,9 +72,10 @@ def data_gen_for_sequence_classification():
embd_pdrop=0,
resid_pdrop=0,
summary_first_dropout=0,
-                                 hidden_dropout=0,
-                                 problem_type="single_label_classification",
-                                 pad_token_id=50256)
+                                 hidden_dropout=0)

+config_for_token_classification = copy.deepcopy(config)
+config_for_token_classification.num_labels = 2

# register the following models
model_zoo.register(name='transformers_gpt',
@@ -99,13 +103,13 @@ def data_gen_for_sequence_classification():
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_for_token_classification',
-                   model_fn=lambda: transformers.GPT2ForTokenClassification(config),
+                   model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification),
data_gen_fn=data_gen_for_token_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_for_sequence_classification',
-                   model_fn=lambda: transformers.GPT2ForSequenceClassification(config),
+                   model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification),
data_gen_fn=data_gen_for_sequence_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
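The registrations above now deep-copy the shared GPT2Config before setting num_labels, so classification-specific fields do not leak into the other registered models. A minimal sketch of the pattern (tiny hypothetical config values; assumes transformers is installed):

import copy
import transformers

base_config = transformers.GPT2Config(n_layer=2, n_head=4)

# Mutating a deep copy leaves base_config untouched for the other model_zoo entries.
cls_config = copy.deepcopy(base_config)
cls_config.num_labels = 2

model = transformers.GPT2ForTokenClassification(cls_config)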
3 changes: 2 additions & 1 deletion tests/kit/model_zoo/transformers/opt.py
@@ -44,7 +44,8 @@ def data_gen_for_question_answering():


output_transform_fn = lambda x: x
-loss_fn_for_opt_model = lambda x: x.last_hidden_state.mean()
+loss_fn_for_opt_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state,
+                                                               torch.ones_like(x.last_hidden_state))
loss_fn_for_lm = lambda x: x.loss
config = transformers.OPTConfig(
hidden_size=128,
4 changes: 2 additions & 2 deletions tests/kit/model_zoo/transformers/whisper.py
@@ -22,7 +22,7 @@ def data_gen():
# input_features = inputs.input_features
# decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id

-    input_features = torch.randn(1, 80, 3000)
+    input_features = torch.rand(1, 80, 3000)
decoder_input_ids = torch.tensor([[1, 1]]) * 50258
return dict(input_features=input_features, decoder_input_ids=decoder_input_ids)

@@ -53,7 +53,7 @@ def data_gen_for_audio_classification():
output_transform_fn = lambda x: x

# define loss function
-loss_fn = lambda x: x.last_hidden_state.mean()
+loss_fn = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state))
loss_fn_attr = lambda x: x.loss

config = transformers.WhisperConfig(
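torch.rand draws from U[0, 1) while torch.randn draws from N(0, 1), so the new synthetic input features are bounded; plausibly this keeps activations, and thus the original-vs-sharded comparison, in a tighter numerical range (an assumption, the diff does not state the motivation). A quick illustration:

import torch

torch.manual_seed(0)
gaussian = torch.randn(1, 80, 3000)   # unbounded; tails give occasional large values
uniform = torch.rand(1, 80, 3000)     # bounded in [0, 1)
print(gaussian.abs().max(), uniform.abs().max())   # e.g. ~4.8 vs < 1.0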
22 changes: 22 additions & 0 deletions tests/test_shardformer/test_model/_utils.py
@@ -2,10 +2,13 @@
from contextlib import nullcontext

import torch
+import torch.distributed as dist
from torch.nn import Module

from colossalai.lazy import LazyInitContext
from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer._utils import getattr_
+from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor


def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True, use_lazy_init: bool = False):
@@ -74,3 +77,22 @@ def check_state_dict(org_model: Module, sharded_model: Module, name: str = ''):
assert v.shape == shard_v.shape, f'{name} {k} shape mismatch, {v.shape} vs {shard_v.shape}'
assert v.dtype == shard_v.dtype, f'{name} {k} dtype mismatch, {v.dtype} vs {shard_v.dtype}'
assert torch.equal(v, shard_v), f'{name} {k} value mismatch'


+def check_grad(original_model, sharded_model, layer_suffix, atol=1e-5, rtol=1e-5, dim=0, verbose=False):
+    for suffix in layer_suffix:
+        org_grad = getattr_(original_model, suffix).weight.grad
+        shard_grad = getattr_(sharded_model, suffix).weight.grad
+        shard_weight = getattr_(sharded_model, suffix).weight
+
+        if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+            # gather the gradient shards from every rank before comparing;
+            # all_gather fills shard_grad_list in place and returns None
+            shard_grad_list = [torch.zeros_like(shard_grad) for _ in range(dist.get_world_size())]
+            dist.all_gather(shard_grad_list, shard_grad)
+            all_shard_grad = torch.cat(shard_grad_list, dim=dim)
+        else:
+            all_shard_grad = shard_grad
+        if verbose and dist.get_rank() == 0:
+            print(f"'{suffix}' grad: {org_grad}, {all_shard_grad}")
+        assert torch.allclose(
+            org_grad, all_shard_grad, rtol=rtol, atol=atol
+        ), f"error in attribute '{suffix}': original model grad is not equal to shard model grad\n{org_grad}\n{all_shard_grad}"
50 changes: 13 additions & 37 deletions tests/test_shardformer/test_model/test_shard_bert.py
@@ -15,10 +15,18 @@
spawn,
)
from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward
+from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward


def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
+    # unwrap model
+    if org_model.__class__.__name__ == 'BertModel':
+        bert = org_model
+        sharded_bert = sharded_model
+    else:
+        bert = org_model.bert
+        sharded_bert = sharded_model.bert
+
# check forward
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
@@ -32,42 +40,10 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"

# check grad

-    if org_model.__class__.__name__ == 'BertModel':
-        bert = org_model
-        sharded_bert = sharded_model
-    else:
-        bert = org_model.bert
-        sharded_bert = sharded_model.bert
-
-    # compare self attention grad
-    org_grad = bert.encoder.layer[0].attention.self.query.weight.grad
-    shard_grad = sharded_bert.encoder.layer[0].attention.self.query.weight.grad
-    shard_weight = sharded_bert.encoder.layer[0].attention.self.query.weight
-
-    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
-        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-        all_shard_grad = torch.cat(shard_grad_list, dim=0)
-    else:
-        all_shard_grad = shard_grad
-    assert torch.allclose(org_grad, all_shard_grad,
-                          atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
-
-    # compare embedding grad
-    org_grad = bert.embeddings.word_embeddings.weight.grad
-    shard_grad = sharded_bert.embeddings.word_embeddings.weight.grad
-    shard_weight = sharded_bert.embeddings.word_embeddings.weight
-
-    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
-        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-        all_shard_grad = torch.cat(shard_grad_list, dim=0)
-    else:
-        all_shard_grad = shard_grad
-
-    assert torch.allclose(org_grad, all_shard_grad,
-                          atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
+    col_layer_for_check = ['encoder.layer[0].attention.self.query', 'embeddings.word_embeddings']
+    row_layer_for_check = ['encoder.layer[0].attention.output.dense']
+    check_grad(bert, sharded_bert, col_layer_for_check, atol=1e-7, rtol=1e-3, dim=0, verbose=False)
+    check_grad(bert, sharded_bert, row_layer_for_check, atol=1e-7, rtol=1e-3, dim=1, verbose=False)


@parameterize('enable_fused_normalization', [False, True])
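Why dim=0 for the column-parallel layers and dim=1 for the row-parallel ones: a column-parallel Linear shards its (out_features, in_features) weight along dim 0, a row-parallel one along dim 1, so the gathered gradient shards must be concatenated along the matching dim before comparing with the unsharded gradient. A toy check (hypothetical 4x4 weight, world size 2):

import torch

full = torch.arange(16.).reshape(4, 4)      # stand-in for an unsharded weight gradient
col_shards = torch.chunk(full, 2, dim=0)    # what each rank holds under column parallelism
row_shards = torch.chunk(full, 2, dim=1)    # ... and under row parallelism
assert torch.equal(torch.cat(col_shards, dim=0), full)
assert torch.equal(torch.cat(row_shards, dim=1), full)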
58 changes: 12 additions & 46 deletions tests/test_shardformer/test_model/test_shard_blip2.py
@@ -3,7 +3,6 @@

import colossalai
from colossalai.logging import disable_existing_loggers
-from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
@@ -12,7 +11,7 @@
spawn,
)
from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, run_forward
+from tests.test_shardformer.test_model._utils import build_model, check_grad, run_forward


def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
@@ -33,50 +32,17 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
blip2 = org_model
sharded_blip2 = sharded_model

-    # compare vision_model grad
-
-    org_grad = blip2.vision_model.encoder.layers[0].self_attn.qkv.weight.grad
-    shard_grad = sharded_blip2.vision_model.encoder.layers[0].self_attn.qkv.weight.grad
-    shard_weight = sharded_blip2.vision_model.encoder.layers[0].self_attn.qkv.weight
-
-    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
-        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-        all_shard_grad = torch.cat(shard_grad_list, dim=0)
-    else:
-        all_shard_grad = shard_grad
-    assert torch.allclose(org_grad, all_shard_grad,
-                          atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
-
-    # compare qformer grad
-    org_grad = blip2.qformer.encoder.layer[0].attention.attention.query.weight.grad
-    shard_grad = sharded_blip2.qformer.encoder.layer[0].attention.attention.query.weight.grad
-    shard_weight = sharded_blip2.qformer.encoder.layer[0].attention.attention.query.weight
-
-    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
-        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-        all_shard_grad = torch.cat(shard_grad_list, dim=0)
-    else:
-        all_shard_grad = shard_grad
-
-    assert torch.allclose(org_grad, all_shard_grad,
-                          atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
-
-    # compare language_model grad
-    org_grad = blip2.language_model.model.decoder.layers[0].self_attn.k_proj.weight.grad
-    shard_grad = sharded_blip2.language_model.model.decoder.layers[0].self_attn.k_proj.weight.grad
-    shard_weight = sharded_blip2.language_model.model.decoder.layers[0].self_attn.k_proj.weight
-
-    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
-        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-        all_shard_grad = torch.cat(shard_grad_list, dim=0)
-    else:
-        all_shard_grad = shard_grad
-
-    assert torch.allclose(org_grad, all_shard_grad,
-                          atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
+    # check grad
+    col_layer_for_check = [
+        'vision_model.encoder.layers[0].self_attn.qkv', 'qformer.encoder.layer[0].attention.attention.query',
+        'language_model.model.decoder.layers[0].self_attn.k_proj'
+    ]
+    row_layer_for_check = [
+        'vision_model.encoder.layers[0].self_attn.projection', 'qformer.encoder.layer[0].attention.output.dense',
+        'language_model.model.decoder.layers[0].self_attn.out_proj'
+    ]
+    check_grad(blip2, sharded_blip2, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
+    check_grad(blip2, sharded_blip2, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False)


@parameterize('enable_fused_normalization', [True, False])
39 changes: 7 additions & 32 deletions tests/test_shardformer/test_model/test_shard_bloom.py
@@ -3,7 +3,6 @@

import colossalai
from colossalai.logging import disable_existing_loggers
-from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
@@ -12,7 +11,7 @@
spawn,
)
from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward
+from tests.test_shardformer.test_model._utils import build_model, check_grad, check_state_dict, run_forward


def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
@@ -26,7 +25,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
shard_loss.backward()

assert torch.allclose(org_loss, shard_loss,
-                          atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"
+                          atol=1e-6), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"

# unwrap model
if org_model.__class__.__name__ == 'BloomModel':
@@ -36,35 +35,11 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
bloom = org_model.transformer
sharded_bloom = sharded_model.transformer

-    # check attention grad
-    org_grad = bloom.h[0].self_attention.query_key_value.weight.grad
-    shard_grad = sharded_bloom.h[0].self_attention.query_key_value.weight.grad
-    shard_weight = sharded_bloom.h[0].self_attention.query_key_value.weight
-
-    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
-        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-        torch.distributed.all_gather(shard_grad_list, shard_grad)
-        all_shard_grad = torch.cat(shard_grad_list, dim=0)
-    else:
-        all_shard_grad = shard_grad
-
-    assert torch.allclose(org_grad, all_shard_grad,
-                          atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
-
-    # check embedding weights
-    org_grad = bloom.word_embeddings.weight.grad
-    shard_grad = sharded_bloom.word_embeddings.weight.grad
-    shard_weight = sharded_bloom.word_embeddings.weight
-
-    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
-        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-        torch.distributed.all_gather(shard_grad_list, shard_grad)
-        all_shard_grad = torch.cat(shard_grad_list, dim=0)
-    else:
-        all_shard_grad = shard_grad
-
-    assert torch.allclose(org_grad, all_shard_grad,
-                          atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
+    # check grad
+    col_layer_for_check = ['h[0].self_attention.query_key_value']
+    row_layer_for_check = ['h[0].self_attention.dense']
+    check_grad(bloom, sharded_bloom, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
+    check_grad(bloom, sharded_bloom, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False)


@parameterize('enable_fused_normalization', [True, False])