Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions colossalai/shardformer/modeling/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ def gpt2_model_forward(
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
input_shape = input_ids.size()
input_ids = input_ids.view(-1, seq_length)
input_ids = input_ids.view(-1, input_shape[-1])
batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size = inputs_embeds.shape[0]
Expand All @@ -89,13 +89,14 @@ def gpt2_model_forward(

device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, seq_length)
token_type_ids = token_type_ids.view(-1, input_shape[-1])
else:
if hidden_states is None:
raise ValueError("hidden_states shouldn't be None for stages other than the first stage.")
input_shape = hidden_states.size()[:-1]
batch_size, seq_length = input_shape[0], input_shape[1]
batch_size = input_shape[0]
device = hidden_states.device
hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:])
Comment thread
Fridge003 marked this conversation as resolved.

# GPT2Attention mask.
if attention_mask is not None:
Expand Down Expand Up @@ -136,9 +137,9 @@ def gpt2_model_forward(

if stage_manager.is_first_stage():
if position_ids is not None:
position_ids = position_ids.view(-1, seq_length)
position_ids = position_ids.view(-1, input_shape[-1])
else:
position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

if inputs_embeds is None:
Expand Down Expand Up @@ -721,7 +722,6 @@ def forward(
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
_, tgt_len, _ = hidden_states.size()

if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"):
Expand Down
38 changes: 27 additions & 11 deletions tests/kit/model_zoo/transformers/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,27 @@ def data_gen_for_sequence_classification():


def date_gen_for_double_heads():
data = data_gen_for_lm()
data['mc_labels'] = torch.zeros(data['input_ids'].shape[0], dtype=torch.int64)
return data
num_choices = 2
batch_size = 2
input_ids = torch.tensor(
[[15496, 11, 616, 3290, 318, 13779, 318, 13779], [15496, 11, 616, 3290, 318, 13779, 318, 13779]],
dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
mc_labels = torch.zeros(input_ids.shape[0], dtype=torch.int64)

mc_token_ids = torch.arange(0, num_choices, dtype=torch.int64)
mc_token_ids = mc_token_ids.expand((batch_size, num_choices))
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, num_choices, -1).contiguous()
multiple_choice_input_mask = attention_mask.unsqueeze(1).expand(-1, num_choices, -1).contiguous()

inputs = {
"input_ids": multiple_choice_inputs_ids,
"mc_token_ids": mc_token_ids,
"attention_mask": multiple_choice_input_mask,
"labels": multiple_choice_inputs_ids,
"mc_labels": mc_labels,
}
return inputs


# define output transform function
Expand Down Expand Up @@ -98,14 +116,12 @@ def date_gen_for_double_heads():
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))

# TODO The model training is failing, there is a bug in GPT2DoubleHeadsModel in transformers.
# model_zoo.register(name='transformers_gpt_double_heads',
# model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
# data_gen_fn=date_gen_for_double_heads,
# output_transform_fn=lambda x: dict(loss=x.loss + x.mc_loss),
# loss_fn=loss_fn,
# model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_double_heads',
model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
data_gen_fn=date_gen_for_double_heads,
output_transform_fn=output_transform_fn,
loss_fn=lambda x: x.loss + x.mc_loss,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_for_question_answering',
model_fn=lambda: transformers.GPT2ForQuestionAnswering(config),
data_gen_fn=data_gen_for_question_answering,
Expand Down
3 changes: 2 additions & 1 deletion tests/test_booster/test_plugin/test_gemini_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ def check_gemini_plugin(subset: str, init_method: str = 'none', early_stop: bool
'transformers_t5_encoder_model', # does not support apex rmsnorm
'transformers_chatglm',
'transformers_sam',
'transformers_vit'
'transformers_vit',
'transformers_gpt_double_heads', # TODO check why does the model fail to run using Gemini
]:
continue

Expand Down
4 changes: 2 additions & 2 deletions tests/test_shardformer/test_model/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,13 @@ def _criterion(outputs, inputs):
data = data_gen_fn()

if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0:
seq_len = data['input_ids'].shape[1]
seq_len = data['input_ids'].shape[-1]
lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len)
times = lcm // seq_len
input_shape = data['input_ids'].shape
for k, v in data.items():
if v.shape == input_shape:
data[k] = v.repeat(1, times)
data[k] = v.repeat(input_shape[:-1] + (input_shape[-1] * times,))

sharded_model.train()
if booster.plugin.stage_manager is not None:
Expand Down
8 changes: 0 additions & 8 deletions tests/test_shardformer/test_model/test_shard_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'num_microbatches': 4,
'enable_all_optimization': True,
'use_lazy_init': True,
'enable_sequence_parallelism': True,
Comment thread
Fridge003 marked this conversation as resolved.
'precision': 'fp32',
}, {
'tp_size': 4,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': True,
'enable_sequence_parallelism': True,
'precision': 'fp32',
}, {
'tp_size': 2,
Expand Down