From 3a5a3d36330567cb4b447115385933728a1426d0 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Mon, 9 Oct 2023 15:28:34 +0800 Subject: [PATCH] refactor code --- colossalai/pipeline/schedule/generate.py | 34 +++++++++---------- .../test_pipeline_infer.py | 0 2 files changed, 17 insertions(+), 17 deletions(-) rename tests/{test_generate => test_infer}/test_pipeline_infer.py (100%) diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py index 1a961d3036b8..8f6acd5fcf4b 100644 --- a/colossalai/pipeline/schedule/generate.py +++ b/colossalai/pipeline/schedule/generate.py @@ -127,7 +127,7 @@ def _recv_pre_stage(self) -> Any: return self.comm.p2p_recv() return self.comm.recv_forward() - def LoadStageAction(self, model: Module) -> None: + def _load_stage_action(self, model: Module) -> None: """ In this action, 1.load micro_batch 2.do the forward 3.step to update """ @@ -140,7 +140,7 @@ def LoadStageAction(self, model: Module) -> None: self.mb_manager.step(inputs_dict, output_dict, None) self.action_interval_buffer.hidden_states = output_dict['hidden_states'] - def GenTokenAction(self, model: Module): + def _gen_token_action(self, model: Module): """ In this action, 1.do the forward with hidden_states to generate new tokens 2.step to update """ @@ -158,7 +158,7 @@ def GenTokenAction(self, model: Module): self.action_interval_buffer.new_token = new_token self.action_interval_buffer.hidden_states = None - def HeadEncodingAction(self, model: Module): + def _head_encoding_action(self, model: Module): """ In this action, 1.prepare inputs for encoding for first stage. 2.do the forward to get hidden states 3.step to update """ @@ -170,7 +170,7 @@ def HeadEncodingAction(self, model: Module): self.mb_manager.step(inputs_dict, output_dict, None) self.action_interval_buffer.hidden_states = output_dict['hidden_states'] - def BodyEncodingAction(self, model: Module): + def _body_encoding_action(self, model: Module): hidden_states = self.action_interval_buffer.hidden_states assert hidden_states is not None, "When not first stage, the hidden states should not be None" inputs_dict = self._prepare_inputs_for_interval_stage() @@ -180,7 +180,7 @@ def BodyEncodingAction(self, model: Module): self.mb_manager.step(inputs_dict, output_dict, None) self.action_interval_buffer.hidden_states = output_dict['hidden_states'] - def CommAction(self, recv_pre: bool) -> torch.Tensor: + def _comm_action(self, recv_pre: bool) -> torch.Tensor: """ In this action, 1.receive the hidden_states from previous stage 2.send the hidden_states to next stage """ @@ -189,7 +189,7 @@ def CommAction(self, recv_pre: bool) -> torch.Tensor: self.action_interval_buffer.hidden_states = ret - def genAction(self, model: Module): + def _gen_action(self, model: Module): """ In p2p step method, we use `P2POp` asynchronous communication method, so the communication need to be done at the begin of each microbatch, it's a more clear way to use an action list to do so. In this function, it will @@ -204,19 +204,19 @@ def genAction(self, model: Module): actions = [] if self.stage_manager.is_first_stage(): if self.mb_manager.cur_state is Status.PREFILL: - actions.append(partial(self.CommAction, False)) - actions.append(partial(self.LoadStageAction, model)) + actions.append(partial(self._comm_action, False)) + actions.append(partial(self._load_stage_action, model)) elif self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.GENERATE: - actions.append(partial(self.CommAction, True)) - actions.append(partial(self.GenTokenAction, model)) - actions.append(partial(self.HeadEncodingAction, model)) + actions.append(partial(self._comm_action, True)) + actions.append(partial(self._gen_token_action, model)) + actions.append(partial(self._head_encoding_action, model)) elif self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.COOLDOWN: - actions.append(partial(self.CommAction, True)) - actions.append(partial(self.GenTokenAction, model)) + actions.append(partial(self._comm_action, True)) + actions.append(partial(self._gen_token_action, model)) # other stage else: - actions.append(partial(self.CommAction, True)) - actions.append(partial(self.BodyEncodingAction, model)) + actions.append(partial(self._comm_action, True)) + actions.append(partial(self._body_encoding_action, model)) return actions @@ -252,7 +252,7 @@ def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.T ] if self.verbose and self.stage_manager.is_first_stage() else None self.action_interval_buffer.clear() while self.mb_manager.is_micro_batch_done() is False: - actions = self.genAction(model) + actions = self._gen_action(model) for action in actions: action() self.mb_manager.next() @@ -260,7 +260,7 @@ def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.T if self.stage_manager.is_first_stage(): output_sequence.extend(self.mb_manager.export_new_tokens()) else: - self.CommAction(False) + self._comm_action(False) self.mb_manager.clear() if self.verbose and self.stage_manager.is_first_stage(): whole_timestamp.extend(self.timestamps) diff --git a/tests/test_generate/test_pipeline_infer.py b/tests/test_infer/test_pipeline_infer.py similarity index 100% rename from tests/test_generate/test_pipeline_infer.py rename to tests/test_infer/test_pipeline_infer.py