From f307f7f4b5f459e309b4d35f3c4f846a5a9524dd Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Thu, 7 Sep 2023 16:54:43 +0800
Subject: [PATCH 1/6] add benchmark verbose

---
 colossalai/inference/pipeline/engine.py  |  3 ++-
 colossalai/pipeline/schedule/generate.py | 34 ++++++++++++++++++++----
 2 files changed, 31 insertions(+), 6 deletions(-)

diff --git a/colossalai/inference/pipeline/engine.py b/colossalai/inference/pipeline/engine.py
index 9236ee0a7bff..39366d2d69da 100644
--- a/colossalai/inference/pipeline/engine.py
+++ b/colossalai/inference/pipeline/engine.py
@@ -60,6 +60,7 @@ def __init__(
         new_length: int = 32,
         micro_batch_size: int = 1,
         micro_batch_buffer_size: int = None,
+        verbose: bool = False,
         # TODO: implement early_stopping, and various gerneration options
         early_stopping: bool = False,
         do_sample: bool = False,
@@ -71,7 +72,7 @@ def __init__(
         self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True)
         self.mb_manager = MicroBatchManager(self.stage_manager.stage, new_length, micro_batch_size,
                                             micro_batch_buffer_size or pp_size)
-        self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager)
+        self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)
         self.model = pp_model or self._shardformer(model, model_policy)
 
     def inference(self, input_list):
diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py
index a6ca76a83468..2de48fcfa0d2 100644
--- a/colossalai/pipeline/schedule/generate.py
+++ b/colossalai/pipeline/schedule/generate.py
@@ -1,5 +1,6 @@
+import time
 from functools import partial
-from typing import Any, Dict, Iterable, Optional, Union
+from typing import Any, Iterable, List, Optional, Union
 
 import torch
 import torch.cuda
@@ -22,11 +23,12 @@ class GenerateSchedule(PipelineSchedule):
     this schedule, the out for each encoding progress is on rank0.
 
     Args:
-        stage_manager (PipelineStageManager): Pipeline stage manager.
-        mb_manager (MicroBatchManager): Micro batch manager.
+        stage_manager (`PipelineStageManager`): Pipeline stage manager.
+        mb_manager (`MicroBatchManager`): Micro batch manager.
+        verbose (bool): Whether to print verbose timing information for the pipeline.
     '''
 
-    def __init__(self, stage_manager: PipelineStageManager, mb_manager: MicroBatchManager) -> None:
+    def __init__(self, stage_manager: PipelineStageManager, mb_manager: MicroBatchManager, verbose: bool) -> None:
         super().__init__(stage_manager)
         self.comm = PipelineP2PCommunication(stage_manager)
         self.mb_manager = mb_manager
@@ -35,6 +37,7 @@ def __init__(self, stage_manager: PipelineStageManager, mb_manager: MicroBatchMa
         self.batch_size: Optional[int] = None
         self.microbatch_offset: Optional[int] = None
         self.num_microbatches: Optional[int] = None
+        self.verbose = verbose
 
     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
         """Load a batch from data iterator.
@@ -95,6 +98,14 @@ def _get_token_id(self, hidden_state: torch.Tensor) -> torch.Tensor:
         input_ids = torch.argmax(last_hidden_state, dim=-1).unsqueeze(1)
         return input_ids
 
+    def verbose_info(self, timestamps: List):
+        prefill = []
+        encoder = []
+        for timestamp in timestamps:
+            prefill.append(timestamp[1] - timestamp[0])
+            encoder.append(sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)))
+        print(f"Average prefill time: {sum(prefill)/len(prefill)}, Average encode time: {sum(encoder)/len(encoder)}")
+
     @torch.no_grad()
     def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
         """Forward one step of the pipeline
@@ -110,8 +121,11 @@ def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tenso
         self.load_batch(data_iter)
         model.eval()
 
+        whole_timestamp = []
         # run by round
         for _ in range(self.round):
+            timestamps = [[] for _ in range(self.stage_manager.num_stages)
+                          ] if self.verbose and self.stage_manager.is_first_stage() else None
             while self.mb_manager.is_micro_batch_done() is False:
                 inputs_dict = None
                 new_token = None
@@ -120,6 +134,8 @@ def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tenso
                 # First stage and in PREFILL phase, just load the inputs
                 if self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.PREFILL:
                     inputs_dict = self.load_micro_batch()
+                    if self.verbose and self.stage_manager.is_first_stage():
+                        timestamps[self.mb_manager.idx].append(time.time())
                     output_dict = model_forward(model, inputs_dict, None)
                     self.mb_manager.step(inputs_dict, output_dict, None)
                 # In GENERATE phase
@@ -130,7 +146,9 @@ def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tenso
                         # First just generate a new token
                         assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
                         logits = model_forward(model, None, hidden_states)
-                        assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {output.keys()}"
+                        if self.verbose and self.stage_manager.is_first_stage():
+                            timestamps[self.mb_manager.idx].append(time.time())
+                        assert 'logits' in logits, f"When first stage in GENERATE phase, the output should have attribute `logits`, but has {logits.keys()}"
                         new_token = self._get_token_id(logits['logits'])
                         self.mb_manager.step(None, None, new_token)
                 # If the current micro batch is not DONE, go through blocks
@@ -154,4 +172,10 @@ def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tenso
             if self.stage_manager.is_first_stage():
                 output_sequence.extend(self.mb_manager.export_new_tokens())
             self.mb_manager.clear()
+            if self.verbose and self.stage_manager.is_first_stage():
+                whole_timestamp.extend(timestamps)
+
+        if self.verbose and self.stage_manager.is_first_stage():
+            self.verbose_info(whole_timestamp)
+
         return output_sequence
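For reference, a minimal sketch of how the new flag is meant to be driven from user code. The engine class name `PPInferEngine` and every argument not visible in the hunks above are assumptions for illustration; `model`, `model_policy`, and `inputs` are placeholders, not part of this series:

    # Hypothetical driver script: enabling the new pipeline benchmark output.
    import colossalai
    from colossalai.inference.pipeline.engine import PPInferEngine

    colossalai.launch_from_torch(config={})
    engine = PPInferEngine(
        pp_size=2,
        model=model,                  # assumed: a HuggingFace causal LM
        model_policy=model_policy,    # assumed: the matching shardformer policy
        new_length=32,
        micro_batch_size=1,
        verbose=True,                 # new in this patch: print timing stats on rank 0
    )
    outputs = engine.inference([inputs])  # the timing summary prints after generation
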
From 5a969e23326449351a0022942242232ff5b6648e Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Tue, 12 Sep 2023 12:01:24 +0800
Subject: [PATCH 2/6] fix export tokens

---
 colossalai/inference/pipeline/microbatch_manager.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/colossalai/inference/pipeline/microbatch_manager.py b/colossalai/inference/pipeline/microbatch_manager.py
index 7f4b14c17748..38ef450d2be6 100644
--- a/colossalai/inference/pipeline/microbatch_manager.py
+++ b/colossalai/inference/pipeline/microbatch_manager.py
@@ -184,7 +184,9 @@ def step(self, inputs_dict=None, output_dict: Dict[str, torch.Tensor] = None, ne
         return self.cur_state
 
     def export_new_tokens(self):
-        new_tokens_list = [i.new_tokens[0].tolist() for i in self.mb_descrption_buffer.values()]
+        new_tokens_list = []
+        for i in self.mb_descrption_buffer.values():
+            new_tokens_list.extend(i.new_tokens.tolist())
         return new_tokens_list
 
     def is_micro_batch_done(self):

From f56a62595ea552f1d9660d2d39ddf998d98d05b3 Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Tue, 12 Sep 2023 12:02:49 +0800
Subject: [PATCH 3/6] fix benchmark verbose

---
 colossalai/pipeline/schedule/generate.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py
index 2de48fcfa0d2..2983d776c525 100644
--- a/colossalai/pipeline/schedule/generate.py
+++ b/colossalai/pipeline/schedule/generate.py
@@ -101,10 +101,16 @@ def _get_token_id(self, hidden_state: torch.Tensor) -> torch.Tensor:
     def verbose_info(self, timestamps: List):
         prefill = []
         encoder = []
+        end2end = []
         for timestamp in timestamps:
             prefill.append(timestamp[1] - timestamp[0])
-            encoder.append(sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)))
-        print(f"Average prefill time: {sum(prefill)/len(prefill)}, Average encode time: {sum(encoder)/len(encoder)}")
+            encoder.append(
+                sum(timestamp[i + 1] - timestamp[i] for i in range(1,
+                                                                   len(timestamp) - 1)) / (len(timestamp) - 2))
+            end2end.append(timestamp[-1] - timestamp[0])
+        print(f"Average prefill time: {sum(prefill)/len(prefill)}")
+        print(f"Average encode time: {sum(encoder)/len(encoder)}")
+        print(f"Average end2end time: {sum(end2end)/len(end2end)}")
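To make the three averages concrete, the arithmetic for a single micro batch works out as below. This is a self-contained sketch: the timestamp values are synthetic, and the list layout (prefill start followed by one entry per generated token) follows the collection code added in PATCH 1:

    # Synthetic timestamps for one micro batch: [prefill_start, token_1, token_2, token_3]
    timestamp = [0.00, 0.50, 0.60, 0.72]          # assumed: seconds

    prefill = timestamp[1] - timestamp[0]          # 0.50 -> prefill plus first token
    # mean per-token decode gap after the first token, as in the patched verbose_info
    encode = sum(timestamp[i + 1] - timestamp[i]
                 for i in range(1, len(timestamp) - 1)) / (len(timestamp) - 2)
    end2end = timestamp[-1] - timestamp[0]         # 0.72 -> whole request

    print(prefill, round(encode, 3), end2end)      # 0.5 0.11 0.72
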
From 9a22382632fe95e1113a692d35859edb85314749 Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Tue, 12 Sep 2023 17:04:30 +0800
Subject: [PATCH 4/6] add P2POp style to do p2p communication

---
 colossalai/pipeline/p2p.py | 95 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 95 insertions(+)

diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py
index aed85cf91512..55dd7e86a29f 100644
--- a/colossalai/pipeline/p2p.py
+++ b/colossalai/pipeline/p2p.py
@@ -159,6 +159,87 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
     return object_list[0]
 
 
+def _p2p_comm_shape(
+    tensor_send_next: torch.Tensor,
+    recv_prev: bool,
+    peer: int,
+    group: ProcessGroup,
+):
+    send_next_shape = None
+    recv_prev_shape = None
+
+    if tensor_send_next is not None:
+        send_next_shape = torch.tensor(tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64)
+    if recv_prev:
+        recv_prev_shape = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64)
+
+    ops = []
+    # print(f"send shape: {None if send_next_shape is None else send_next_shape.shape}, recv_prev_shape: {None if recv_prev_shape is None else recv_prev_shape.shape}")
+    if send_next_shape is not None:
+        send_next_op = dist.P2POp(dist.isend, send_next_shape, peer=peer, group=group)
+        ops.append(send_next_op)
+    if recv_prev_shape is not None:
+        recv_prev_op = dist.P2POp(
+            dist.irecv,
+            recv_prev_shape,
+            peer=peer,
+            group=group,
+        )
+        ops.append(recv_prev_op)
+
+    print(f"rank: {dist.get_rank()}, {ops}")
+    if len(ops) > 0:
+        reqs = dist.batch_isend_irecv(ops)
+        for req in reqs:
+            req.wait()
+    # torch.cuda.synchronize()
+
+    if recv_prev_shape is not None:
+        recv_prev_shape = recv_prev_shape.tolist()
+
+    return recv_prev_shape
+
+
+def _p2p_comm(
+    tensor_send_next: torch.Tensor,
+    recv_pre: bool,
+    peer: int,
+    group: ProcessGroup,
+    comm_type: torch.dtype = torch.float32,
+):
+    tensor_recv_prev = None
+    recv_prev_shape = _p2p_comm_shape(tensor_send_next, recv_pre, peer, group)
+    print(f"get shape: {recv_prev_shape}, rank: {dist.get_rank()}")
+    if recv_pre:
+        tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_type)
+        print(tensor_recv_prev.shape)
+
+    ops = []
+    if tensor_send_next is not None:
+        send_next_op = dist.P2POp(
+            dist.isend,
+            tensor_send_next,
+            peer=peer,
+            group=group,
+        )
+        ops.append(send_next_op)
+
+    if tensor_recv_prev is not None:
+        recv_prev_op = dist.P2POp(
+            dist.irecv,
+            tensor_recv_prev,
+            peer=peer,
+            group=group,
+        )
+        ops.append(recv_prev_op)
+    if len(ops) > 0:
+        reqs = dist.batch_isend_irecv(ops)
+        for req in reqs:
+            req.wait()
+    # torch.cuda.synchronize()
+    return tensor_recv_prev
+
+
 class PipelineP2PCommunication:
 
     def __init__(self, stage_manager: PipelineStageManager) -> None:
@@ -220,3 +301,17 @@ def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
             prev_rank = self.stage_manager.get_prev_rank()
         cur_rank = self.stage_manager.get_rank()
         _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
+
+    def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None) -> Any:
+        """
+        Sends the input tensor to the next stage in the pipeline, using `P2POp` from torch.
+
+        Args:
+            output_object (Any): Object to be sent.
+            peer (int, optional): The rank of the recipient of the tensor.
+        """
+        if peer is None:
+            peer = self.stage_manager.get_next_rank()
+        cur_rank = self.stage_manager.get_rank()
+        recv_tensor = _p2p_comm(output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer))
+        return recv_tensor
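Because the payload exchange is preceded by a shape exchange, `_p2p_comm` can be exercised in isolation. A minimal two-rank sketch under stated assumptions: launched via `torchrun --nproc_per_node=2` with NCCL available, and passing `group=None` to use the default process group (the real schedule instead goes through `PipelineP2PCommunication.p2p_communicate`):

    # Sketch: each rank sends a 3-D tensor to its peer and receives one back.
    import torch
    import torch.distributed as dist
    from colossalai.pipeline.p2p import _p2p_comm

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    peer = 1 - rank                      # two-stage ring: the other rank

    send = torch.full((2, 3, 4), float(rank), device="cuda")
    # the shape handshake runs first (hence the hard-coded 3-element shape buffer),
    # then the payload itself is exchanged with batched isend/irecv
    recv = _p2p_comm(send, recv_pre=True, peer=peer, group=None)
    print(f"rank {rank} got a tensor of shape {tuple(recv.shape)} from rank {peer}")
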
From 585bce51000c3c737095546e8193c6e2c83fb426 Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Wed, 13 Sep 2023 15:34:44 +0800
Subject: [PATCH 5/6] modify schedule as p2p type when ppsize is 2

---
 .../inference/pipeline/microbatch_manager.py |   3 +
 colossalai/pipeline/p2p.py                   |   6 -
 colossalai/pipeline/schedule/generate.py     | 127 ++++++++++++++++++
 3 files changed, 130 insertions(+), 6 deletions(-)

diff --git a/colossalai/inference/pipeline/microbatch_manager.py b/colossalai/inference/pipeline/microbatch_manager.py
index 38ef450d2be6..b6b008442cfd 100644
--- a/colossalai/inference/pipeline/microbatch_manager.py
+++ b/colossalai/inference/pipeline/microbatch_manager.py
@@ -10,6 +10,7 @@ class Status(Enum):
     PREFILL = 1
     GENERATE = 2
     DONE = 3
+    COOLDOWN = 4
 
 
 class MicroBatchDescription():
@@ -52,6 +53,8 @@ def state(self):
         # TODO: add the condition for early stopping
         if self.cur_length == self.target_length:
             return Status.DONE
+        elif self.cur_length == self.target_length - 1:
+            return Status.COOLDOWN
         else:
             return Status.GENERATE

diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py
index 55dd7e86a29f..227ad2daca0e 100644
--- a/colossalai/pipeline/p2p.py
+++ b/colossalai/pipeline/p2p.py
@@ -174,7 +174,6 @@ def _p2p_comm_shape(
         recv_prev_shape = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64)
 
     ops = []
-    # print(f"send shape: {None if send_next_shape is None else send_next_shape.shape}, recv_prev_shape: {None if recv_prev_shape is None else recv_prev_shape.shape}")
     if send_next_shape is not None:
         send_next_op = dist.P2POp(dist.isend, send_next_shape, peer=peer, group=group)
         ops.append(send_next_op)
@@ -187,12 +186,10 @@ def _p2p_comm_shape(
         )
         ops.append(recv_prev_op)
 
-    print(f"rank: {dist.get_rank()}, {ops}")
     if len(ops) > 0:
         reqs = dist.batch_isend_irecv(ops)
         for req in reqs:
             req.wait()
-    # torch.cuda.synchronize()
 
     if recv_prev_shape is not None:
         recv_prev_shape = recv_prev_shape.tolist()
@@ -209,10 +206,8 @@ def _p2p_comm(
 ):
     tensor_recv_prev = None
     recv_prev_shape = _p2p_comm_shape(tensor_send_next, recv_pre, peer, group)
-    print(f"get shape: {recv_prev_shape}, rank: {dist.get_rank()}")
     if recv_pre:
         tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_type)
-        print(tensor_recv_prev.shape)
 
     ops = []
     if tensor_send_next is not None:
@@ -236,7 +231,6 @@ def _p2p_comm(
         reqs = dist.batch_isend_irecv(ops)
         for req in reqs:
             req.wait()
-    # torch.cuda.synchronize()
     return tensor_recv_prev

diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py
index 2983d776c525..a86bc2e0a662 100644
--- a/colossalai/pipeline/schedule/generate.py
+++ b/colossalai/pipeline/schedule/generate.py
@@ -16,6 +16,17 @@
 from .base import PipelineSchedule
 
 
+class ActionIntervalBuffer():
+
+    def __init__(self):
+        self.hidden_states = None
+        self.new_token = None
+
+    def clear(self):
+        self.hidden_states = None
+        self.new_token = None
+
+
 class GenerateSchedule(PipelineSchedule):
     '''
     GenerateSchedule is a class that handles the pipeline parallel inference.
@@ -38,6 +49,7 @@ def __init__(self, stage_manager: PipelineStageManager, mb_manager: MicroBatchMa
         self.microbatch_offset: Optional[int] = None
         self.num_microbatches: Optional[int] = None
         self.verbose = verbose
+        self.action_interval_buffer = ActionIntervalBuffer()
 
     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
         """Load a batch from data iterator.
@@ -98,6 +110,93 @@ def _get_token_id(self, hidden_state: torch.Tensor) -> torch.Tensor:
         input_ids = torch.argmax(last_hidden_state, dim=-1).unsqueeze(1)
         return input_ids
 
+    def _recv_pre_stage(self) -> Any:
+        '''
+        Receive the output from previous stage
+
+        Returns:
+            Any: The output from previous stage
+        '''
+        if self.stage_manager.num_stages == 2:
+            return self.comm.p2p_recv()
+        return self.comm.recv_forward()
+
+    def LoadStageAction(self, model: Module) -> None:
+        """
+        In this action, 1.load micro_batch 2.do the forward 3.step to update
+        """
+        inputs_dict = self.load_micro_batch()
+        output_dict = model_forward(model, inputs_dict, None)
+
+        self.mb_manager.step(inputs_dict, output_dict, None)
+        self.action_interval_buffer.hidden_states = output_dict['hidden_states']
+
+    def GenTokenAction(self, model: Module):
+        """
+        In this action, 1.do the forward with hidden_states to generate new tokens 2.step to update
+        """
+        hidden_states = self.action_interval_buffer.hidden_states
+        assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
+        hidden_states = {'hidden_states': hidden_states}
+        logits = model_forward(model, None, hidden_states)
+        assert 'logits' in logits, f"When first stage in GENERATE phase, the output should have attribute `logits`, but has {logits.keys()}"
+        new_token = self._get_token_id(logits['logits'])
+
+        self.mb_manager.step(None, None, new_token)
+        self.action_interval_buffer.new_token = new_token
+        self.action_interval_buffer.hidden_states = None
+
+    def HeadEncodingAction(self, model: Module):
+        """
+        In this action, 1.prepare inputs for encoding for first stage 2.do the forward to get hidden states 3.step to update
+        """
+        new_token = self.action_interval_buffer.new_token
+        assert new_token is not None, "When first stage in GENERATE phase, the new token should not be None"
+        inputs_dict = self._prepare_inputs_for_new_token(new_token)
+        output_dict = model_forward(model, inputs_dict, None)
+
+        self.mb_manager.step(inputs_dict, output_dict, None)
+        self.action_interval_buffer.hidden_states = output_dict['hidden_states']
+
+    def BodyEncodingAction(self, model: Module):
+        hidden_states = self.action_interval_buffer.hidden_states
+        assert hidden_states is not None, "When not first stage, the hidden states should not be None"
+        inputs_dict = self._prepare_inputs_for_interval_stage()
+        hidden_states = {'hidden_states': hidden_states}
+        output_dict = model_forward(model, inputs_dict, hidden_states)
+
+        self.mb_manager.step(inputs_dict, output_dict, None)
+        self.action_interval_buffer.hidden_states = output_dict['hidden_states']
+
+    def CommAction(self, recv_pre: bool) -> torch.Tensor:
+        """
+        In this action, 1.receive the hidden_states from previous stage 2.send the hidden_states to next stage
+        """
+        hidden_states = self.action_interval_buffer.hidden_states
+        ret = self.comm.p2p_communicate(hidden_states, recv_pre)
+
+        self.action_interval_buffer.hidden_states = ret
+
+    def genAction(self, model: Module):
+        actions = []
+        if self.stage_manager.is_first_stage():
+            if self.mb_manager.cur_state is Status.PREFILL:
+                actions.append(partial(self.CommAction, False))
+                actions.append(partial(self.LoadStageAction, model))
+            elif self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.GENERATE:
+                actions.append(partial(self.CommAction, True))
+                actions.append(partial(self.GenTokenAction, model))
+                actions.append(partial(self.HeadEncodingAction, model))
+            elif self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.COOLDOWN:
+                actions.append(partial(self.CommAction, True))
+                actions.append(partial(self.GenTokenAction, model))
+        # other stage
+        else:
+            actions.append(partial(self.CommAction, True))
+            actions.append(partial(self.BodyEncodingAction, model))
+
+        return actions
+
     def verbose_info(self, timestamps: List):
         prefill = []
         encoder = []
@@ -114,6 +213,34 @@ def verbose_info(self, timestamps: List):
 
     @torch.no_grad()
     def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
+        output_sequence = []
+        self.load_batch(data_iter)
+        model.eval()
+
+        whole_timestamp = []
+
+        #run by round
+        for _ in range(self.round):
+            self.action_interval_buffer.clear()
+            while self.mb_manager.is_micro_batch_done() is False:
+                actions = self.genAction(model)
+                for action in actions:
+                    print(f"rank {self.stage_manager.stage}, {action.func.__name__}")
+                    action()
+                # print(f"rank {self.stage_manager.stage}, {self.action_interval_buffer}")
+                self.mb_manager.next()
+            # All microbatch in current round is DONE
+            if self.stage_manager.is_first_stage():
+                output_sequence.extend(self.mb_manager.export_new_tokens())
+            else:
+                self.CommAction(False)
+                print(f"rank {self.stage_manager.stage}, Final Comm")
+            self.mb_manager.clear()
+
+        return output_sequence
+
+    @torch.no_grad()
+    def generate_step1(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
         """Forward one step of the pipeline
 
         Args:
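Putting this patch together: with the new COOLDOWN state (entered when exactly one token remains for a micro batch), the action lists that genAction produces collapse to a small table. The sketch below is a condensed restatement of the branches above, not additional behavior; the real code appends functools.partial objects bound to the model:

    # (stage, state) -> actions executed for one micro batch step
    SCHEDULE = {
        ("first stage", "PREFILL"):  ["CommAction(recv_pre=False)", "LoadStageAction"],
        ("first stage", "GENERATE"): ["CommAction(recv_pre=True)", "GenTokenAction", "HeadEncodingAction"],
        ("first stage", "COOLDOWN"): ["CommAction(recv_pre=True)", "GenTokenAction"],  # last token: no re-encoding needed
        ("other stage", "any"):      ["CommAction(recv_pre=True)", "BodyEncodingAction"],
    }
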
From 272357dd3a2aec4ed76beb7a7ca27cbe53b41649 Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Wed, 13 Sep 2023 15:47:10 +0800
Subject: [PATCH 6/6] remove unused code and add docstring

---
 colossalai/pipeline/schedule/generate.py | 38 ++++++++++++++++++++----
 1 file changed, 32 insertions(+), 6 deletions(-)

diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py
index a86bc2e0a662..85feebea6e6e 100644
--- a/colossalai/pipeline/schedule/generate.py
+++ b/colossalai/pipeline/schedule/generate.py
@@ -178,6 +178,17 @@ def CommAction(self, recv_pre: bool) -> torch.Tensor:
         self.action_interval_buffer.hidden_states = ret
 
     def genAction(self, model: Module):
+        """
+        In the p2p step method we use the asynchronous `P2POp` communication primitive, so communication has to be
+        issued at the beginning of each micro batch; an action list is a clearer way to do that. This function
+        generates the action sequence for the current state, and the caller executes the actions one by one.
+
+        Args:
+            model (Module): Model to be run.
+
+        Returns:
+            List[Callable]: A list of actions; each action is a callable and they are invoked in order.
+        """
         actions = []
         if self.stage_manager.is_first_stage():
             if self.mb_manager.cur_state is Status.PREFILL:
                 actions.append(partial(self.CommAction, False))
@@ -211,8 +222,25 @@ def verbose_info(self, timestamps: List):
         print(f"Average encode time: {sum(encoder)/len(encoder)}")
         print(f"Average end2end time: {sum(end2end)/len(end2end)}")
 
-    @torch.no_grad()
     def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
+        if self.stage_manager.num_stages == 2:
+            return self.generate_step_p2p(model, data_iter)
+        else:
+            return self.generate_step_broadcast(model, data_iter)
+
+    @torch.no_grad()
+    def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
+        """
+        Forward one step of the pipeline. When the pipeline size is 2 the schedule forms a ring, so broadcast
+        communication would block; we therefore use the asynchronous `P2POp` communication method instead.
+
+        Args:
+            model (Module): Model to be run.
+            data_iter (Iterable): Data iterator.
+
+        Returns:
+            Union[torch.Tensor, dict]: The new tokens generated in this step, collected on the first stage.
+        """
         output_sequence = []
         self.load_batch(data_iter)
         model.eval()
@@ -225,23 +253,21 @@ def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tenso
             while self.mb_manager.is_micro_batch_done() is False:
                 actions = self.genAction(model)
                 for action in actions:
-                    print(f"rank {self.stage_manager.stage}, {action.func.__name__}")
                     action()
-                # print(f"rank {self.stage_manager.stage}, {self.action_interval_buffer}")
                 self.mb_manager.next()
             # All microbatch in current round is DONE
             if self.stage_manager.is_first_stage():
                 output_sequence.extend(self.mb_manager.export_new_tokens())
             else:
                 self.CommAction(False)
-                print(f"rank {self.stage_manager.stage}, Final Comm")
             self.mb_manager.clear()
 
         return output_sequence
 
     @torch.no_grad()
-    def generate_step1(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
-        """Forward one step of the pipeline
+    def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
+        """
+        Forward one step of the pipeline
 
         Args:
             model (Module): Model to be run.
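With the dispatcher in place, a two-stage run automatically takes the `P2POp` path. Continuing the hedged driver sketch from the note after PATCH 1 (all engine arguments remain assumptions apart from the parameter names visible in this series):

    # Sketch: a two-stage launch dispatches to generate_step_p2p internally.
    engine = PPInferEngine(pp_size=2, model=model, model_policy=model_policy,
                           new_length=32, micro_batch_size=1, verbose=True)
    tokens = engine.inference([inputs])   # num_stages == 2 -> generate_step_p2p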