Merged
Changes from all commits
35 commits
57d4fab
add pipeline policy and bert forward to be done
CjhHa1 Jun 30, 2023
8300f45
add bertmodel pipeline forward and make tests
CjhHa1 Jul 3, 2023
246b6d3
add Bert_Policy and test for policy
CjhHa1 Jul 3, 2023
db0a1f1
update formatting
CjhHa1 Jul 3, 2023
9f57067
update formatting
CjhHa1 Jul 3, 2023
dac6427
update the code
CjhHa1 Jul 3, 2023
8b30a02
fix bugs
CjhHa1 Jul 4, 2023
585eb9d
fix name conflict
CjhHa1 Jul 4, 2023
27fb804
add bloom model and policy, revise the base class of policy
CjhHa1 Jul 4, 2023
86f3d90
Merge branch 'feature/pipeline' into feature/pipeline
CjhHa1 Jul 4, 2023
3ea0ba4
revise
CjhHa1 Jul 4, 2023
cd99e13
Merge branch 'feature/pipeline' of github.com:CjhHa1/ColossalAI into …
CjhHa1 Jul 4, 2023
edb02b2
revision
CjhHa1 Jul 4, 2023
369df2c
add bert_for_pretraining
CjhHa1 Jul 4, 2023
ac114f3
Merge branch 'hpcaitech:feature/pipeline' into feature/pipeline
CjhHa1 Jul 5, 2023
0319c8b
add bert_for_pretraining forward and policy
CjhHa1 Jul 5, 2023
29ef380
fix typos
CjhHa1 Jul 6, 2023
5cd2478
cancel warning
CjhHa1 Jul 6, 2023
ef528e6
change the intermediate output to default dict
CjhHa1 Jul 6, 2023
e3e6c3b
change the default output of get_shared_params
CjhHa1 Jul 6, 2023
9eb6047
update
CjhHa1 Aug 1, 2023
6793bc3
add chatglm
CjhHa1 Aug 1, 2023
d9289dc
add
CjhHa1 Aug 1, 2023
68dd101
chatglm
CjhHa1 Aug 1, 2023
593d347
chatglm
CjhHa1 Aug 1, 2023
c14dc16
finish chatglm
CjhHa1 Aug 1, 2023
d451a17
deletes
CjhHa1 Aug 1, 2023
c6bbb90
fix rmsnorm
CjhHa1 Aug 3, 2023
098990a
Merge branch 'feature/pipeline' into chatglm
CjhHa1 Aug 3, 2023
cf0b554
chatglm
CjhHa1 Aug 3, 2023
907c22c
Merge branch 'chatglm' of github.com:CjhHa1/ColossalAI into chatglm
CjhHa1 Aug 3, 2023
c88d500
Merge branch 'feature/pipeline' into chatglm
CjhHa1 Aug 3, 2023
0beae70
fix chatglm shard
CjhHa1 Aug 4, 2023
1bfee44
Merge branch 'chatglm' of github.com:CjhHa1/ColossalAI into chatglm
CjhHa1 Aug 4, 2023
b868b6a
init
CjhHa1 Aug 4, 2023
189 changes: 189 additions & 0 deletions colossalai/shardformer/modeling/chatglm.py
@@ -0,0 +1,189 @@
""" PyTorch ChatGLM model. """
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss, LayerNorm
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.utils import logging

from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
ChatGLMForConditionalGeneration,
ChatGLMModel,
GLMBlock,
)


class ChatGLMPipelineForwards:
'''
This class serves as a micro library for ChatGLM model forwards under pipeline parallelism.
'''

@staticmethod
def chatglm_model_forward(
self: ChatGLMModel,
input_ids,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
full_attention_mask: Optional[torch.BoolTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
):
logger = logging.get_logger(__name__)
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
if past_key_values:
logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
past_key_values = None
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if use_cache:
logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
use_cache = False
if stage_manager.is_first_stage():
batch_size, seq_length = input_ids.shape
if inputs_embeds is None:
inputs_embeds = self.embedding(input_ids)
hidden_states = inputs_embeds
else:
seq_length, batch_size = hidden_states.shape[:2]
if self.pre_seq_len is not None:
if past_key_values is None:
past_key_values = self.get_prompt(batch_size=batch_size,
device=input_ids.device,
dtype=inputs_embeds.dtype)
if attention_mask is not None:
attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask],
dim=-1)
if full_attention_mask is None:
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
# Rotary positional embeddings
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
if position_ids is not None:
rotary_pos_emb = rotary_pos_emb[position_ids]
else:
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
if not past_key_values:
past_key_values = [None for _ in range(self.num_layers)]
presents = () if use_cache else None
if self.encoder.gradient_checkpointing and self.encoder.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
use_cache = False
all_self_attentions = None
all_hidden_states = () if output_hidden_states else None
start_idx, end_idx = stage_index[0], stage_index[1]
for idx in range(start_idx, end_idx):
layer = self.encoder._get_layer(idx)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.encoder.gradient_checkpointing and self.encoder.training:
layer_ret = torch.utils.checkpoint.checkpoint(layer, hidden_states, attention_mask, rotary_pos_emb,
past_key_values[idx], use_cache)
else:
layer_ret = layer(hidden_states,
full_attention_mask,
rotary_pos_emb,
kv_cache=past_key_values[idx],
use_cache=use_cache)
hidden_states, kv_cache = layer_ret
if use_cache:
presents = presents + (kv_cache,)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if stage_manager.is_last_stage():
# final layer_norm
if self.encoder.post_layer_norm:
hidden_states = self.encoder.final_layernorm(hidden_states)
if not return_dict:
return tuple(
v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
else:
return {'hidden_states': hidden_states}

@staticmethod
def chatglm_for_conditional_generation_forward(
self: ChatGLMForConditionalGeneration,
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
return_last_logit: Optional[bool] = False,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
):
logger = logging.get_logger(__name__)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
transformer_outputs = ChatGLMPipelineForwards.chatglm_model_forward(
self.transformer,
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
)
if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0]
if return_last_logit:
hidden_states = hidden_states[-1:]
lm_logits = self.transformer.output_layer(hidden_states)
lm_logits = lm_logits.transpose(0, 1).contiguous()
loss = None
if labels is not None:
lm_logits = lm_logits.to(torch.float32)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
lm_logits = lm_logits.to(hidden_states.dtype)
loss = loss.to(hidden_states.dtype)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
else:
return transformer_outputs
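
For orientation, here is a minimal sketch of how a staged forward like the one above is typically bound to a model instance. It assumes a ChatGLMModel named `model` plus a `stage_manager` and `stage_index` supplied by the surrounding pipeline setup; the binding itself is illustrative and not part of this diff.

from functools import partial

# Illustrative only: swap the model's forward for the pipeline-aware variant
# defined above. `model`, `stage_manager`, and `stage_index` are assumed to be
# provided by the caller's pipeline/policy setup.
model.forward = partial(
    ChatGLMPipelineForwards.chatglm_model_forward,
    model,
    stage_manager=stage_manager,
    stage_index=stage_index,
)

# Each stage then calls model(...) as usual; intermediate stages return
# {'hidden_states': ...} while the last stage returns a BaseModelOutputWithPast.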
58 changes: 58 additions & 0 deletions colossalai/shardformer/modeling/chatglm2_6b/configuration_chatglm.py
@@ -0,0 +1,58 @@
from transformers import PretrainedConfig


class ChatGLMConfig(PretrainedConfig):
model_type = "chatglm"

def __init__(self,
num_layers=28,
padded_vocab_size=65024,
hidden_size=4096,
ffn_hidden_size=13696,
kv_channels=128,
num_attention_heads=32,
seq_length=2048,
hidden_dropout=0.0,
attention_dropout=0.0,
layernorm_epsilon=1e-5,
rmsnorm=True,
apply_residual_connection_post_layernorm=False,
post_layer_norm=True,
add_bias_linear=False,
add_qkv_bias=False,
bias_dropout_fusion=True,
multi_query_attention=False,
multi_query_group_num=1,
apply_query_key_layer_scaling=True,
attention_softmax_in_fp32=True,
fp32_residual_connection=False,
quantization_bit=0,
pre_seq_len=None,
prefix_projection=False,
**kwargs):
self.num_layers = num_layers
self.vocab_size = padded_vocab_size
self.padded_vocab_size = padded_vocab_size
self.hidden_size = hidden_size
self.ffn_hidden_size = ffn_hidden_size
self.kv_channels = kv_channels
self.num_attention_heads = num_attention_heads
self.seq_length = seq_length
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.layernorm_epsilon = layernorm_epsilon
self.rmsnorm = rmsnorm
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
self.post_layer_norm = post_layer_norm
self.add_bias_linear = add_bias_linear
self.add_qkv_bias = add_qkv_bias
self.bias_dropout_fusion = bias_dropout_fusion
self.multi_query_attention = multi_query_attention
self.multi_query_group_num = multi_query_group_num
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
self.fp32_residual_connection = fp32_residual_connection
self.quantization_bit = quantization_bit
self.pre_seq_len = pre_seq_len
self.prefix_projection = prefix_projection
super().__init__(**kwargs)
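
As a quick, hypothetical usage note (not part of the diff): any of the defaults above can be overridden by keyword when the config is built, for example to shrink the model for a small test.

# Illustrative sketch only; the values are arbitrary and chosen for a tiny test model.
config = ChatGLMConfig(
    num_layers=2,                # far fewer than the 28-layer default
    seq_length=128,
    multi_query_attention=True,
    multi_query_group_num=2,
)
print(config.padded_vocab_size)  # 65024 unless overridden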