Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions colossalai/pipeline/schedule/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _recv_pre_stage(self) -> Any:
return self.comm.p2p_recv()
return self.comm.recv_forward()

def LoadStageAction(self, model: Module) -> None:
def _load_stage_action(self, model: Module) -> None:
"""
In this action, 1.load micro_batch 2.do the forward 3.step to update
"""
Expand All @@ -140,7 +140,7 @@ def LoadStageAction(self, model: Module) -> None:
self.mb_manager.step(inputs_dict, output_dict, None)
self.action_interval_buffer.hidden_states = output_dict['hidden_states']

def GenTokenAction(self, model: Module):
def _gen_token_action(self, model: Module):
"""
In this action, 1.do the forward with hidden_states to generate new tokens 2.step to update
"""
Expand All @@ -158,7 +158,7 @@ def GenTokenAction(self, model: Module):
self.action_interval_buffer.new_token = new_token
self.action_interval_buffer.hidden_states = None

def HeadEncodingAction(self, model: Module):
def _head_encoding_action(self, model: Module):
"""
In this action, 1.prepare inputs for encoding for first stage. 2.do the forward to get hidden states 3.step to update
"""
Expand All @@ -170,7 +170,7 @@ def HeadEncodingAction(self, model: Module):
self.mb_manager.step(inputs_dict, output_dict, None)
self.action_interval_buffer.hidden_states = output_dict['hidden_states']

def BodyEncodingAction(self, model: Module):
def _body_encoding_action(self, model: Module):
hidden_states = self.action_interval_buffer.hidden_states
assert hidden_states is not None, "When not first stage, the hidden states should not be None"
inputs_dict = self._prepare_inputs_for_interval_stage()
Expand All @@ -180,7 +180,7 @@ def BodyEncodingAction(self, model: Module):
self.mb_manager.step(inputs_dict, output_dict, None)
self.action_interval_buffer.hidden_states = output_dict['hidden_states']

def CommAction(self, recv_pre: bool) -> torch.Tensor:
def _comm_action(self, recv_pre: bool) -> torch.Tensor:
"""
In this action, 1.receive the hidden_states from previous stage 2.send the hidden_states to next stage
"""
Expand All @@ -189,7 +189,7 @@ def CommAction(self, recv_pre: bool) -> torch.Tensor:

self.action_interval_buffer.hidden_states = ret

def genAction(self, model: Module):
def _gen_action(self, model: Module):
"""
In p2p step method, we use `P2POp` asynchronous communication method, so the communication need to be done
at the begin of each microbatch, it's a more clear way to use an action list to do so. In this function, it will
Expand All @@ -204,19 +204,19 @@ def genAction(self, model: Module):
actions = []
if self.stage_manager.is_first_stage():
if self.mb_manager.cur_state is Status.PREFILL:
actions.append(partial(self.CommAction, False))
actions.append(partial(self.LoadStageAction, model))
actions.append(partial(self._comm_action, False))
actions.append(partial(self._load_stage_action, model))
elif self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.GENERATE:
actions.append(partial(self.CommAction, True))
actions.append(partial(self.GenTokenAction, model))
actions.append(partial(self.HeadEncodingAction, model))
actions.append(partial(self._comm_action, True))
actions.append(partial(self._gen_token_action, model))
actions.append(partial(self._head_encoding_action, model))
elif self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.COOLDOWN:
actions.append(partial(self.CommAction, True))
actions.append(partial(self.GenTokenAction, model))
actions.append(partial(self._comm_action, True))
actions.append(partial(self._gen_token_action, model))
# other stage
else:
actions.append(partial(self.CommAction, True))
actions.append(partial(self.BodyEncodingAction, model))
actions.append(partial(self._comm_action, True))
actions.append(partial(self._body_encoding_action, model))

return actions

Expand Down Expand Up @@ -252,15 +252,15 @@ def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.T
] if self.verbose and self.stage_manager.is_first_stage() else None
self.action_interval_buffer.clear()
while self.mb_manager.is_micro_batch_done() is False:
actions = self.genAction(model)
actions = self._gen_action(model)
for action in actions:
action()
self.mb_manager.next()
# All microbatch in current round is DONE
if self.stage_manager.is_first_stage():
output_sequence.extend(self.mb_manager.export_new_tokens())
else:
self.CommAction(False)
self._comm_action(False)
self.mb_manager.clear()
if self.verbose and self.stage_manager.is_first_stage():
whole_timestamp.extend(self.timestamps)
Expand Down