2 changes: 2 additions & 0 deletions colossalai/shardformer/policies/auto_policy.py
@@ -68,6 +68,8 @@ class PolicyLocation:
PolicyLocation(file_name="gpt2", class_name="GPT2LMHeadModelPolicy"),
"transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel":
PolicyLocation(file_name="gpt2", class_name="GPT2DoubleHeadsModelPolicy"),
"transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering":
PolicyLocation(file_name="gpt2", class_name="GPT2ForQuestionAnsweringPolicy"),
"transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification":
PolicyLocation(file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"),
"transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification":
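For context on why this two-line registration is sufficient: the auto policy resolves a model to its policy through the model class's fully qualified name. A minimal sketch of that lookup, assuming the registry dict is named `_POLICY_LIST` and policies are loaded by module file name as `PolicyLocation` suggests (the repo's actual helper may differ in shape):

import importlib

from colossalai.shardformer.policies.auto_policy import _POLICY_LIST  # assumed registry name

def resolve_policy(model):
    # Key the registry on the model's fully qualified class name, e.g.
    # "transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering".
    qualname = f"{model.__class__.__module__}.{model.__class__.__qualname__}"
    location = _POLICY_LIST.get(qualname)
    if location is None:
        raise NotImplementedError(f"no shardformer policy registered for {qualname}")
    # PolicyLocation names the policy module file and the class inside it.
    module = importlib.import_module(f"colossalai.shardformer.policies.{location.file_name}")
    return getattr(module, location.class_name)()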
136 changes: 127 additions & 9 deletions colossalai/shardformer/policies/gpt2.py
@@ -1,6 +1,4 @@
import logging
from functools import partial
from types import MethodType
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
@@ -298,6 +296,33 @@ def postprocess(self):
return self.model


# GPT2ForQuestionAnswering
class GPT2ForQuestionAnsweringPolicy(GPT2Policy):

def __init__(self) -> None:
super().__init__()

def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2ForQuestionAnswering

module_policy = super().module_policy()
self.set_pipeline_forward(model_cls=GPT2ForQuestionAnswering,
new_forward=GPT2PipelineForwards.gpt2_for_question_answering_forward,
policy=module_policy)

return module_policy

def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.qa_outputs)
return held_layers

def get_shared_params(self) -> List[Dict[int, Tensor]]:
'''No shared params in GPT2 for QA.'''
return []


# GPT2ForTokenClassification
class GPT2ForTokenClassificationPolicy(GPT2Policy):

@@ -391,6 +416,8 @@ def gpt2_model_forward(
# Please refer to the original transformers code for more details.

from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.utils import logging
logger = logging.get_logger(__name__)

# Preprocess passed in arguments
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -416,7 +443,8 @@
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, seq_length)
else:
assert hidden_states is not None
if hidden_states is None:
raise ValueError("hidden_states shouldn't be None for stages other than the first stage.")
input_shape = hidden_states.size()[:-1]
batch_size, seq_length = input_shape[0], input_shape[1]
device = hidden_states.device
@@ -478,21 +506,21 @@ def gpt2_model_forward(

# TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
if past_key_values:
logging.warning('Non-empty past_key_values is not supported for pipeline models at the moment.')
logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
past_key_values = None
if output_attentions:
logging.warning('output_attentions=True is not supported for pipeline models at the moment.')
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
if output_hidden_states:
logging.warning('output_hidden_states=True is not supported for pipeline models at the moment.')
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if use_cache:
logging.warning('use_cache=True is not supported for pipeline models at the moment.')
logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
use_cache = False

if self.gradient_checkpointing and self.training:
if use_cache:
logging.warning(
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
use_cache = False
presents = () if use_cache else None
@@ -751,6 +779,94 @@ def gpt2_double_heads_model_forward(
attentions=outputs.attentions,
)

@staticmethod
def gpt2_for_question_answering_forward(
self: 'GPT2ForQuestionAnswering',
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None) -> Union[Tuple, 'QuestionAnsweringModelOutput']:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
are not taken into account for computing the loss.
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
are not taken into account for computing the loss.

# This function is adapted from transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering.forward.
# Please refer to the original transformers code for more details.
"""
from transformers.modeling_outputs import QuestionAnsweringModelOutput

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index)

# If not at the last stage, return hidden_states as in GPT2Model
if not stage_manager.is_last_stage():
return {'hidden_states': outputs['hidden_states']}

sequence_output = outputs[0]

logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()

total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, the targets may carry an extra dimension; squeeze it off
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1).to(start_logits.device)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1).to(end_logits.device)
# sometimes the start/end positions fall outside our model inputs; we ignore these terms
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index)
end_positions = end_positions.clamp(0, ignored_index)

loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2

if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output

return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

@staticmethod
def gpt2_for_token_classification_forward(
self: 'GPT2ForTokenClassification',
@@ -852,6 +968,8 @@ def gpt2_for_sequence_classification_forward(
# Please refer to the original transformers code for more details.
"""
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
from transformers.utils import logging
logger = logging.get_logger(__name__)

if input_ids is not None:
batch_size, _ = input_ids.shape[:2]
@@ -892,7 +1010,7 @@ def gpt2_for_sequence_classification_forward(
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
else:
sequence_lengths = -1
logging.warning(
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`")

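A quick toy-shape walkthrough of the span-loss computation added in `gpt2_for_question_answering_forward`, self-contained and runnable with made-up shapes:

import torch

# Toy shapes: batch=2, seq_len=5; qa_outputs maps each token's hidden state to 2 logits.
logits = torch.randn(2, 5, 2)
start_logits, end_logits = logits.split(1, dim=-1)    # two tensors of shape (2, 5, 1)
start_logits = start_logits.squeeze(-1).contiguous()  # (2, 5)
end_logits = end_logits.squeeze(-1).contiguous()      # (2, 5)

# Clamp targets into [0, seq_len]; the sentinel index seq_len is ignored by the loss.
ignored_index = start_logits.size(1)                  # 5
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignored_index)
start_positions = torch.tensor([0, 3]).clamp(0, ignored_index)
end_positions = torch.tensor([1, 7]).clamp(0, ignored_index)  # 7 clamps to 5 -> ignored
total_loss = (loss_fct(start_logits, start_positions) + loss_fct(end_logits, end_positions)) / 2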
2 changes: 1 addition & 1 deletion tests/kit/model_zoo/torchrec/__init__.py
@@ -1 +1 @@
#from .torchrec import *
from .torchrec import *
17 changes: 17 additions & 0 deletions tests/kit/model_zoo/transformers/gpt.py
@@ -29,6 +29,17 @@ def data_gen_for_lm():
return data


def data_gen_for_question_answering():
# question answering data gen
# `start_positions`/`end_positions` index the answer span within the input sequence
data = data_gen()
start_positions = torch.tensor([0], dtype=torch.int64)
data['start_positions'] = start_positions
end_positions = torch.tensor([1], dtype=torch.int64)
data['end_positions'] = end_positions
return data


def data_gen_for_token_classification():
# token classification data gen
# `labels` is the per-token class id (0 or 1), not a token id
@@ -82,6 +93,12 @@ def data_gen_for_sequence_classification():
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_for_question_answering',
model_fn=lambda: transformers.GPT2ForQuestionAnswering(config),
data_gen_fn=data_gen_for_question_answering,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_for_token_classification',
model_fn=lambda: transformers.GPT2ForTokenClassification(config),
data_gen_fn=data_gen_for_token_classification,
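To see what the new zoo entry exercises end to end, a hedged usage sketch; the small `GPT2Config` below is hypothetical (gpt.py defines its own `config` earlier in the file):

import torch
import transformers

config = transformers.GPT2Config(n_layer=2, n_head=4, n_embd=128)  # hypothetical small config
model = transformers.GPT2ForQuestionAnswering(config)

batch = {
    'input_ids': torch.tensor([[15496, 11, 616, 3290, 318, 13779]], dtype=torch.int64),
    'attention_mask': torch.ones(1, 6, dtype=torch.int64),
    'start_positions': torch.tensor([0], dtype=torch.int64),
    'end_positions': torch.tensor([1], dtype=torch.int64),
}
outputs = model(**batch)                         # QuestionAnsweringModelOutput
print(outputs.loss, outputs.start_logits.shape)  # scalar loss, torch.Size([1, 6])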
@@ -27,7 +27,6 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz

sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():

inputs = data_gen_fn()
inputs = {k: v.cuda() for k, v in inputs.items()}
input_ids, _ = inputs['input_ids'], inputs['attention_mask']