3 changes: 2 additions & 1 deletion colossalai/inference/pipeline/engine.py
@@ -69,7 +69,8 @@ def __init__(
        self.pp_size = pp_size
        self.pg_mesh = ProcessGroupMesh(pp_size)
        self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True)
-        self.mb_manager = MicroBatchManager(new_length, micro_batch_size, micro_batch_buffer_size or pp_size)
+        self.mb_manager = MicroBatchManager(self.stage_manager.stage, new_length, micro_batch_size,
+                                            micro_batch_buffer_size or pp_size)
        self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager)
        self.model = pp_model or self._shardformer(model, model_policy)

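The only change here threads the stage id into `MicroBatchManager`, which the next file uses to choose head-stage vs. body-stage bookkeeping. A minimal sketch of the new wiring follows; the placeholder sizes and import paths are assumptions based on the ColossalAI layout, not shown in the hunk:

# Sketch of the constructor wiring (placeholder sizes; module paths assumed).
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.pipeline.microbatch_manager import MicroBatchManager
from colossalai.pipeline.stage_manager import PipelineStageManager

pp_size, new_length, micro_batch_size = 2, 32, 4

pg_mesh = ProcessGroupMesh(pp_size)
stage_manager = PipelineStageManager(pg_mesh, 0, True)

# Each rank passes its own stage id, so the manager knows whether it owns the
# prompt tensors (stage 0) or only hidden states (later stages).
mb_manager = MicroBatchManager(stage_manager.stage, new_length, micro_batch_size,
                               micro_batch_buffer_size=pp_size)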
197 changes: 139 additions & 58 deletions colossalai/inference/pipeline/microbatch_manager.py
@@ -1,5 +1,5 @@
from enum import Enum
-from typing import Dict
+from typing import Dict, Tuple

import torch

@@ -13,121 +13,202 @@ class Status(Enum):


class MicroBatchDescription():
"""
This is the class to record the infomation of each microbatch, and also do some update operation.
This clase is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more
details, please refer to the doc of these two classes blow.

Args:
inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`.
output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
"""

    def __init__(
        self,
-        mb_inputs: Dict[str, torch.Tensor],
-        interval_inputs: Dict[str, torch.Tensor],
+        inputs_dict: Dict[str, torch.Tensor],
+        output_dict: Dict[str, torch.Tensor],
        new_length: int,
    ) -> None:
-        if mb_inputs is not None:
-            assert mb_inputs.get('input_ids') is not None and mb_inputs.get('attention_mask') is not None
-            self.mb_length = mb_inputs['input_ids'].shape[-1]
-            self.attn_mask = mb_inputs['attention_mask']
-            self.input_ids = mb_inputs['input_ids']
-        elif interval_inputs is not None:
-            assert interval_inputs.get('hidden_states') is not None
-            self.mb_length = interval_inputs['hidden_states'].shape[-2]
-        else:
-            raise ValueError('mb_inputs and interval_inputs can not be None at the same time')
+        assert output_dict.get('hidden_states') is not None
+        self.mb_length = output_dict['hidden_states'].shape[-2]
        self.target_length = self.mb_length + new_length
        self.kv_cache = ()

-    def update(self, kv_cache):
+    def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
+        if output_dict is not None:
+            self._update_kvcache(output_dict['past_key_values'])
+
+    def _update_kvcache(self, kv_cache: Tuple):
+        assert type(kv_cache) == tuple
        self.kv_cache = kv_cache

    @property
    def state(self):
        """
        Return the state of the current micro batch: when the current length equals the target length
        the state is DONE, otherwise GENERATE.

        """
+        # TODO: add the condition for early stopping
        if self.cur_length == self.target_length:
            return Status.DONE
        else:
            return Status.GENERATE

    @property
    def cur_length(self):
        """
-        Return the current sequnence length of micro batch, when there is no kv_cache, the length is mb_length,
-        otherwise the sequence length is `kv_cache[0][0].shape[-2]` plus 1
+        Return the current sequence length of the micro batch

        """
-        if len(self.kv_cache) == 0:
-            return self.mb_length
-        else:
-            return self.kv_cache[0][0].shape[-2] + 1
+        pass


+class HeadMicroBatchDescription(MicroBatchDescription):
+    """
+    This class is used to record the information of the first stage of the pipeline. The first stage should have
+    the attributes `input_ids`, `attention_mask` and `new_tokens`, where `new_tokens` holds the tokens generated by
+    the first stage. Also, due to the schedule of the pipeline, the operation to update the information and the
+    condition to determine the state differ from those of the other stages.
+
+    Args:
+        inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`.
+        output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
+        new_length (int): the new length of the input sequence.
+
+    """
+
+    def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor],
+                 new_length: int) -> None:
+        super().__init__(inputs_dict, output_dict, new_length)
+        assert inputs_dict is not None
+        assert inputs_dict.get('input_ids') is not None and inputs_dict.get('attention_mask') is not None
+        self.input_ids = inputs_dict['input_ids']
+        self.attn_mask = inputs_dict['attention_mask']
+        self.new_tokens = None
+
+    def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
+        super().update(output_dict, new_token)
+        if new_token is not None:
+            self._update_newtokens(new_token)
+        if self.state is not Status.DONE and new_token is not None:
+            self._update_attnmask()
+
+    def _update_newtokens(self, new_token: torch.Tensor):
+        if self.new_tokens is None:
+            self.new_tokens = new_token
+        else:
+            self.new_tokens = torch.cat([self.new_tokens, new_token], dim=-1)
+
+    def _update_attnmask(self):
+        self.attn_mask = torch.cat(
+            (self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device='cuda')), dim=-1)
+
+    @property
+    def cur_length(self):
+        """
+        When there is no new_token, the length is mb_length; otherwise the sequence length is `mb_length` plus the length of new_token.
+
+        """
+        if self.new_tokens is None:
+            return self.mb_length
+        else:
+            return self.mb_length + len(self.new_tokens[0])


+class BodyMicroBatchDescription(MicroBatchDescription):
+    """
+    This class is used to record the information of the stages other than the first stage of the pipeline;
+    those stages only need the attributes `hidden_states` and `past_key_values`.
+
+    Args:
+        inputs_dict (Dict[str, torch.Tensor]): will always be `None`. Other stages only receive hidden states from the previous stage.
+        output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
+    """
+
+    def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor],
+                 new_length: int) -> None:
+        super().__init__(inputs_dict, output_dict, new_length)
+
+    def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
+        super().update(output_dict, new_token)
+
+    @property
+    def cur_length(self):
+        """
+        When there is no kv_cache, the length is mb_length; otherwise the sequence length is `kv_cache[0][0].shape[-2]` plus 1.
+
+        """
+        if len(self.kv_cache) == 0:
+            return self.mb_length
+        else:
+            return self.kv_cache[0][0].shape[-2] + 1


class MicroBatchManager():
    '''
    MicroBatchManager is a class that manages the micro batch.

    Args:
+        stage (int): stage id of current stage.
        new_length (int): the new length of the input sequence.
        micro_batch_size (int): the micro batch size.
        micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.

    '''

-    def __init__(self, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int):
+    def __init__(self, stage: int, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int):
+        self.stage = stage
        self.new_length = new_length
        self.micro_batch_size = micro_batch_size
        self.buffer_size = micro_batch_buffer_size
        self.mb_descrption_buffer = {}
        self.new_tokens_buffer = {}
        self.idx = 0

-    def _add_descrption(self, mb_inputs: Dict[str, torch.Tensor], inter_inputs: Dict[str, torch.Tensor]):
-        self.mb_descrption_buffer[self.idx] = MicroBatchDescription(mb_inputs, inter_inputs, self.new_length)
-
-    def _update_descrption(self, present_kv):
-        self.mb_descrption_buffer[self.idx].update(present_kv)
-
-    def _remove_descrption(self):
-        self.mb_descrption_buffer.pop(self.idx)
-
-    def step(self, mb_inputs=None, inter_inputs=None, present_kv=None):
+    def step(self, inputs_dict=None, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
        """
-        Update the state if microbatch manager
+        Update the state of the microbatch manager. There are 2 conditions:
+        1. For the first stage in PREFILL: receive inputs and outputs; `_add_descrption` will save the inputs.
+        2. Otherwise: only receive the output of the previous stage and update the description.

        Args:
-            mb_inputs (int, optional): The input of first stage when in prefill, should be a dict like {'input_ids': torch.Tensor, 'attention_mask': torch.Tensor}.
-            inter_inputs ([type], optional): The input of intermediate stage (the output of previous stage), should be a dict like {'hidden_state': torch.Tensor}.
-            present_kv ([type], optional): The kvcache of current microbatch in current stage.
+            inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`.
+            output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
+            new_token (torch.Tensor): the new token generated by current stage.
        """
        # Add a description first if there is none yet
+        if inputs_dict is None and output_dict is None and new_token is None:
+            return Status.PREFILL
        if self.mb_descrption_buffer.get(self.idx) is None:
-            self._add_descrption(mb_inputs, inter_inputs)
-        self._update_descrption(present_kv)
-        state = self.cur_state
-        self.next()
-        return state
+            self._add_descrption(inputs_dict, output_dict)
+        self.cur_descrption.update(output_dict, new_token)
+        return self.cur_state

-    def next(self):
-        self.idx = (self.idx + 1) % self.buffer_size
+    def export_new_tokens(self):
+        new_tokens_list = [i.new_tokens[0].tolist() for i in self.mb_descrption_buffer.values()]
+        return new_tokens_list

    def is_micro_batch_done(self):
        if len(self.mb_descrption_buffer) == 0:
            return False
        for mb in self.mb_descrption_buffer.values():
            if mb.state != Status.DONE:
                return False
-        self.mb_descrption_buffer.clear()
        return True

-    def add_new_tokens(self, new_token):
-        if self.idx not in self.new_tokens_buffer:
-            self.new_tokens_buffer[self.idx] = new_token
-        else:
-            self.new_tokens_buffer[self.idx] = torch.cat([self.new_tokens_buffer[self.idx], new_token], dim=-1)
-
-    def export_new_tokens(self):
-        list = [item.tolist() for item in self.new_tokens_buffer.values()]
-        flat_list = [item for sublist in list for item in sublist]
-        self.new_tokens_buffer.clear()
-        return flat_list
+    def clear(self):
+        self.mb_descrption_buffer.clear()
+
+    def next(self):
+        self.idx = (self.idx + 1) % self.buffer_size
+
+    def _add_descrption(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor]):
+        if self.stage == 0:
+            self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(inputs_dict, output_dict, self.new_length)
+        else:
+            self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(inputs_dict, output_dict, self.new_length)
+
+    def _remove_descrption(self):
+        self.mb_descrption_buffer.pop(self.idx)

    @property
    def cur_descrption(self) -> MicroBatchDescription:
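Taken together, the new classes give each stage a small state machine: `step()` reports PREFILL until the first outputs arrive, then GENERATE until `cur_length` reaches `target_length`, then DONE. Below is a toy walk-through on the head stage, driving the manager by hand instead of through `GenerateSchedule`; shapes and token ids are made up, and a GPU is assumed because `_update_attnmask` hard-codes `device='cuda'`:

# Toy walk-through of the manager's state machine on the head stage (stage 0).
import torch

from colossalai.inference.pipeline.microbatch_manager import MicroBatchManager, Status

manager = MicroBatchManager(stage=0, new_length=4, micro_batch_size=1, micro_batch_buffer_size=2)

# No inputs or outputs yet: the manager reports PREFILL.
assert manager.step() == Status.PREFILL

# After the prefill forward pass: stage 0 registers a HeadMicroBatchDescription
# from its prompt inputs plus the stage outputs, and caches the kv tuple.
inputs = {'input_ids': torch.tensor([[1, 2, 3]], device='cuda'),
          'attention_mask': torch.ones(1, 3, dtype=torch.int64, device='cuda')}
outputs = {'hidden_states': torch.randn(1, 3, 8, device='cuda'),
           'past_key_values': ((torch.randn(1, 1, 3, 8), torch.randn(1, 1, 3, 8)),)}
state = manager.step(inputs, outputs, new_token=torch.tensor([[42]], device='cuda'))

# Each decode step appends one token; the attention mask grows by one column
# until cur_length reaches target_length (prompt length + new_length), i.e. DONE.
while state != Status.DONE:
    state = manager.step(new_token=torch.tensor([[42]], device='cuda'))

print(manager.export_new_tokens())    # [[42, 42, 42, 42]]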
28 changes: 7 additions & 21 deletions colossalai/inference/pipeline/modeling/gpt2.py
@@ -251,6 +251,12 @@ def gpt2_lmhead_model_forward(
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# If is first stage and after warmup, go throught lm_head first
if stage_manager.is_first_stage() and hidden_states is not None:
lm_logits = self.lm_head(hidden_states)
return {'logits': lm_logits}

# Not first stage or before warmup, go through gpt2 model
outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
input_ids,
past_key_values=past_key_values,
@@ -269,24 +275,4 @@ def gpt2_lmhead_model_forward(
                                                          hidden_states=hidden_states,
                                                          stage_index=stage_index)

-        # If not at the last stage, return hidden_states as in GPT2Model
-        if not stage_manager.is_last_stage():
-            return outputs
-
-        hidden_states = outputs['hidden_states']
-        lm_logits = self.lm_head(hidden_states)
-        loss = None
-        if labels is not None:
-            # move labels to correct device to enable model parallelism
-            labels = labels.to(lm_logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-        if not return_dict:
-            output = (lm_logits,) + outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
-        return {'hidden_states': lm_logits, 'past_key_values': outputs['past_key_values']}
+        return outputs
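With the lm_head and loss logic stripped from the last stage, decoding now runs as a ring: the last stage sends its hidden states back to stage 0, which applies `lm_head` where the embeddings live. A sketch of the control flow the new forward implements, with illustrative names rather than the real API:

# Hedged sketch (illustrative names) of the per-stage dispatch during generation.
def lmhead_forward_sketch(stage_manager, lm_head, run_transformer_blocks,
                          input_ids=None, hidden_states=None, **kwargs):
    # Decode phase on the first stage: hidden states just arrived from the LAST
    # stage, so produce logits here instead of running transformer blocks.
    if stage_manager.is_first_stage() and hidden_states is not None:
        return {'logits': lm_head(hidden_states)}
    # Prefill on the first stage (no incoming hidden states) or any later stage:
    # run this stage's share of the transformer and pass hidden states onward.
    return run_transformer_blocks(input_ids=input_ids, hidden_states=hidden_states, **kwargs)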
4 changes: 3 additions & 1 deletion colossalai/inference/pipeline/policy/gpt2_ppinfer.py
@@ -38,7 +38,9 @@ def module_policy(self):

    def get_held_layers(self) -> List[nn.Module]:
        held_layers = super().get_held_layers()
-        if self.pipeline_stage_manager.is_last_stage():
+        # keep the tied lm_head and embedding weights on the same device to save memory
+        # if self.pipeline_stage_manager.is_first_stage():
+        if self.pipeline_stage_manager.is_first_stage():
            held_layers.append(self.model.lm_head)
        return held_layers

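The relocation works because of weight tying, as the added comment notes: GPT-2's `lm_head.weight` is the same tensor as the input embedding `transformer.wte.weight`, which the first stage already holds. A quick way to see the tie (assumes the stock `transformers` GPT-2 with its default tied embeddings):

# Why first-stage placement saves memory: lm_head shares storage with the embedding.
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained('gpt2')
assert model.lm_head.weight.data_ptr() == model.transformer.wte.weight.data_ptr()
# Holding lm_head on the last stage would force a full-vocab copy of this tensor;
# holding it on the first stage reuses the embedding's storage.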