From 9aa509c6176bdc46de895ff566c6719dea9dff2f Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Fri, 15 Sep 2023 10:16:21 +0800
Subject: [PATCH 1/7] add benchmark script

---
 .../inference/pipeline/benchmark/benchmark.py | 101 ++++++++++++++++++
 .../inference/pipeline/benchmark/run.sh       |  34 ++++++
 colossalai/pipeline/schedule/generate.py      |  41 +++----
 3 files changed, 152 insertions(+), 24 deletions(-)
 create mode 100644 colossalai/inference/pipeline/benchmark/benchmark.py
 create mode 100644 colossalai/inference/pipeline/benchmark/run.sh

diff --git a/colossalai/inference/pipeline/benchmark/benchmark.py b/colossalai/inference/pipeline/benchmark/benchmark.py
new file mode 100644
index 000000000000..bda14187480d
--- /dev/null
+++ b/colossalai/inference/pipeline/benchmark/benchmark.py
@@ -0,0 +1,101 @@
+import torch
+import torch.distributed as dist
+import transformers
+
+import colossalai
+from colossalai.inference import PPInferEngine
+from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy
+import argparse
+GIGABYTE = 1024 ** 3
+MEGABYTE = 1024 * 1024
+
+colossalai.launch_from_torch(config={})
+
+def data_gen(batch_size: int=4, seq_len: int=512):
+    input_ids = torch.randint(10, 30000, (1, seq_len), dtype=torch.int32)
+    attention_mask = torch.ones((1, seq_len), dtype=torch.int32)
+    data = dict(input_ids=input_ids, attention_mask=attention_mask)
+    for k, v in data.items():
+        if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
+            new_shape = [1] * v.dim()
+            new_shape[0] = batch_size
+            data[k] = v.to('cuda').repeat(*new_shape)
+    return data
+
+def print_details_info(timestamps, model_config, args):
+    if dist.get_rank() == 0:
+        prefill = []
+        encoder = []
+        end2end = []
+        for timestamp in timestamps:
+            prefill.append(timestamp[1] - timestamp[0])
+            encoder.append(
+                sum(timestamp[i + 1] - timestamp[i] for i in range(1,len(timestamp) - 1)) / (len(timestamp) - 2))
+            end2end.append(timestamp[-1] - timestamp[0])
+        with open(f"llama-{args.model}{'fp16' if args.fp16 is True else 'fp32'}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","w+") as f:
+            avg_latency = sum(end2end)/(args.new_length * args.batch_size)
+            num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
+            num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size
+            if args.fp16:
+                num_bytes = 2
+            else:
+                num_bytes = 4
+
+            f.write(f"llama-{args.model} {'fp16' if args.fp16 is True else 'fp32'} {args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n")
+            f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill)/len(prefill)*1000))
+            f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder)/len(encoder)*1000))
+            f.write("Average end2end time: {0:8.2f} ms\n".format(sum(end2end)/len(end2end)*1000))
+            f.write("Average Per Token Latency: {0:8.2f} ms\n".format(avg_latency * 1000))
+            f.write("Avg flops: {0:8.2f} TFlops/s\n".format(1/avg_latency * num_parameters * num_bytes / 1e12))
+            f.write("Average Throughput: {} tokens/s\n".format((1000/(avg_latency * 1000))))
+            f.write("----------------------------------------------------------\n")
+
+
+    if torch.cuda.is_available():
+        current_device = torch.cuda.current_device()
+
+        # free memory and the total available memory in bytes
+        global_free_memory, total_GPU_memory_occupied = torch.cuda.mem_get_info()
+        memory_allocated = torch.cuda.memory_allocated()
+        max_memory_allocated = torch.cuda.max_memory_allocated()
+        memory_reserved = torch.cuda.memory_reserved()
+        max_memory_reserved = torch.cuda.max_memory_reserved()
+        with open(f"llama-{args.model}{'fp16' if args.fp16 is True else 'fp32'}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","a") as f:
+            f.write(
+                f"\nCurrently using GPU: {current_device}\n"
+                f"free memory : {global_free_memory / GIGABYTE:.4f} GB,\n"
+                f"total memory: {total_GPU_memory_occupied / GIGABYTE:.4f} GB,\n"
+                f"memory allocated: {memory_allocated / GIGABYTE:.4f} GB,\n"
+                f"Max CUDA memory allocated: {max_memory_allocated / GIGABYTE:.4f} GB,\n"
+                f"memory reserved/cached: {memory_reserved / GIGABYTE:.4f} GB,\n"
+                f"Max CUDA memory reserved/cached: {max_memory_reserved / GIGABYTE:.4f} GB,\n"
+            )
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model', default='toy', help='the size of model')
+    parser.add_argument('-b', '--batch_size', type=int, default=8, help='batch size')
+    parser.add_argument('-s', '--seq_len', type=int, default=8, help='sequence length')
+    parser.add_argument('--new_length', type=int, default=4, help='new tokens length')
+    parser.add_argument('--mb_size', type=int, default=1, help='micro_batch_size')
+    parser.add_argument('--pp_size', type=int, default=2, help='pipeline size')
+    parser.add_argument('--fp16', type=bool, default=True, help='whether to use fp16')
+    args = parser.parse_args()
+
+    if args.model == 'toy':
+        model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=8))
+    elif args.model == '7b':
+        model = transformers.LlamaForCausalLM.from_pretrained('decapoda-research/llama-7b-hf')
+    elif args.model == '13b':
+        model = transformers.LlamaForCausalLM.from_pretrained('decapoda-research/llama-13b-hf')
+    else:
+        raise NotImplementedError
+    # if args.fp16:
+    #     model = model.half()
+    engine = PPInferEngine(pp_size=args.pp_size, micro_batch_size=args.mb_size, new_length=args.new_length, model=model, model_policy=LlamaForCausalLMPipelinePolicy(),verbose=True)
+    data = data_gen(args.batch_size, args.seq_len)
+    output, timestamps = engine.inference([data])
+    if dist.get_rank() == 0:
+        print(len(output), len(output[0]))
+        print_details_info(timestamps, model.config, args)
diff --git a/colossalai/inference/pipeline/benchmark/run.sh b/colossalai/inference/pipeline/benchmark/run.sh
new file mode 100644
index 000000000000..00e7783406a9
--- /dev/null
+++ b/colossalai/inference/pipeline/benchmark/run.sh
@@ -0,0 +1,34 @@
+script_dir=$(cd "$(dirname "$0")" && pwd)
+cd "${script_dir}"
+
+CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
+    --model="7b" \
+    --batch_size=2 \
+    --seq_len=1024 \
+    --new_length=128 \
+    --mb_size=1 \
+    --pp_size=2
+
+CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
+    --model="7b" \
+    --batch_size=4 \
+    --seq_len=1024 \
+    --new_length=128 \
+    --mb_size=2 \
+    --pp_size=2
+
+CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
+    --model="7b" \
+    --batch_size=8 \
+    --seq_len=1024 \
+    --new_length=128 \
+    --mb_size=4 \
+    --pp_size=2
+
+CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
+    --model="7b" \
+    --batch_size=16 \
+    --seq_len=1024 \
+    --new_length=128 \
+    --mb_size=8 \
+    --pp_size=2
diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py
index 85feebea6e6e..869b6d1e1f95 100644
--- a/colossalai/pipeline/schedule/generate.py
+++ b/colossalai/pipeline/schedule/generate.py
@@ -48,8 +48,9 @@ def __init__(self, stage_manager: PipelineStageManager, mb_manager: MicroBatchMa
         self.batch_size: Optional[int] = None
         self.microbatch_offset: Optional[int] = None
         self.num_microbatches: Optional[int] = None
-        self.verbose = verbose
         self.action_interval_buffer = ActionIntervalBuffer()
+        self.verbose = verbose
+        self.timestamps = None

     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
         """Load a batch from data iterator.
@@ -126,6 +127,8 @@ def LoadStageAction(self, model: Module) -> None:
         In this action, 1.load micro_batch 2.do the forward 3.step to update
         """
         inputs_dict = self.load_micro_batch()
+        if self.verbose and self.stage_manager.is_first_stage():
+            self.timestamps[self.mb_manager.idx].append(time.time())
         output_dict = model_forward(model, inputs_dict, None)

         self.mb_manager.step(inputs_dict, output_dict, None)
@@ -139,6 +142,8 @@ def GenTokenAction(self, model: Module):
         assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
         hidden_states = {'hidden_states': hidden_states}
         logits = model_forward(model, None, hidden_states)
+        if self.verbose and self.stage_manager.is_first_stage():
+            self.timestamps[self.mb_manager.idx].append(time.time())
         assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
         new_token = self._get_token_id(logits['logits'])
@@ -208,21 +213,8 @@ def genAction(self, model: Module):

         return actions

-    def verbose_info(self, timestamps: List):
-        prefill = []
-        encoder = []
-        end2end = []
-        for timestamp in timestamps:
-            prefill.append(timestamp[1] - timestamp[0])
-            encoder.append(
-                sum(timestamp[i + 1] - timestamp[i] for i in range(1,
-                                                                   len(timestamp) - 1)) / (len(timestamp) - 2))
-            end2end.append(timestamp[-1] - timestamp[0])
-        print(f"Average prefill time: {sum(prefill)/len(prefill)}")
-        print(f"Average encode time: {sum(encoder)/len(encoder)}")
-        print(f"Average end2end time: {sum(end2end)/len(end2end)}")
-
     def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
+        model = model.half()
         if self.stage_manager.num_stages == 2:
             return self.generate_step_p2p(model, data_iter)
         else:
@@ -249,6 +241,8 @@ def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.T

         #run by round
         for _ in range(self.round):
+            self.timestamps = [[] for _ in range(self.stage_manager.num_stages)
+                               ] if self.verbose and self.stage_manager.is_first_stage() else None
             self.action_interval_buffer.clear()
             while self.mb_manager.is_micro_batch_done() is False:
                 actions = self.genAction(model)
@@ -261,8 +255,10 @@ def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.T
                 else:
                     self.CommAction(False)
             self.mb_manager.clear()
+            if self.verbose and self.stage_manager.is_first_stage():
+                whole_timestamp.extend(self.timestamps)

-        return output_sequence
+        return output_sequence, whole_timestamp

     @torch.no_grad()
     def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
@@ -283,7 +279,7 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t
         whole_timestamp = []

         # run by round
         for _ in range(self.round):
-            timestampes = [[] for _ in range(self.stage_manager.num_stages)
+            self.timestamps = [[] for _ in range(self.stage_manager.num_stages)
                            ] if self.verbose and self.stage_manager.is_first_stage() else None
             while self.mb_manager.is_micro_batch_done() is False:
                 inputs_dict = None
@@ -294,7 +290,7 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t
                 if self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.PREFILL:
                     inputs_dict = self.load_micro_batch()
                     if self.verbose and self.stage_manager.is_first_stage():
-                        timestampes[self.mb_manager.idx].append(time.time())
+                        self.timestamps[self.mb_manager.idx].append(time.time())
                     output_dict = model_forward(model, inputs_dict, None)
                     self.mb_manager.step(inputs_dict, output_dict, None)
                 # In GENERATE phase
@@ -306,7 +302,7 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t
                         assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
                         logits = model_forward(model, None, hidden_states)
                         if self.verbose and self.stage_manager.is_first_stage():
-                            timestampes[self.mb_manager.idx].append(time.time())
+                            self.timestamps[self.mb_manager.idx].append(time.time())
                         assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
                         new_token = self._get_token_id(logits['logits'])
                         self.mb_manager.step(None, None, new_token)
@@ -332,9 +328,6 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t
                     output_sequence.extend(self.mb_manager.export_new_tokens())
             self.mb_manager.clear()
             if self.verbose and self.stage_manager.is_first_stage():
-                whole_timestamp.extend(timestampes)
-
-        if self.verbose and self.stage_manager.is_first_stage():
-            self.verbose_info(whole_timestamp)
+                whole_timestamp.extend(self.timestamps)

-        return output_sequence
+        return output_sequence, whole_timestamp

From b11059c169f6d4c6202d0133e0a3ec72fe5584f6 Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Fri, 15 Sep 2023 10:50:53 +0800
Subject: [PATCH 2/7] update argparse

---
 colossalai/inference/pipeline/benchmark/benchmark.py | 7 +++----
 colossalai/inference/pipeline/benchmark/run.sh       | 4 ++++
 colossalai/inference/pipeline/engine.py              | 3 +++
 colossalai/pipeline/p2p.py                           | 8 ++++----
 colossalai/pipeline/schedule/generate.py             | 5 +++--
 5 files changed, 17 insertions(+), 10 deletions(-)

diff --git a/colossalai/inference/pipeline/benchmark/benchmark.py b/colossalai/inference/pipeline/benchmark/benchmark.py
index bda14187480d..2d073840cb61 100644
--- a/colossalai/inference/pipeline/benchmark/benchmark.py
+++ b/colossalai/inference/pipeline/benchmark/benchmark.py
@@ -79,7 +79,7 @@ def print_details_info(timestamps, model_config, args):
     parser.add_argument('--new_length', type=int, default=4, help='new tokens length')
     parser.add_argument('--mb_size', type=int, default=1, help='micro_batch_size')
     parser.add_argument('--pp_size', type=int, default=2, help='pipeline size')
-    parser.add_argument('--fp16', type=bool, default=True, help='whether to use fp16')
+    parser.add_argument('--fp16', action="store_true", help='whether to use fp16')
     args = parser.parse_args()

     if args.model == 'toy':
@@ -90,9 +90,8 @@ def print_details_info(timestamps, model_config, args):
         model = transformers.LlamaForCausalLM.from_pretrained('decapoda-research/llama-13b-hf')
     else:
         raise NotImplementedError
-    # if args.fp16:
-    #     model = model.half()
-    engine = PPInferEngine(pp_size=args.pp_size, micro_batch_size=args.mb_size, new_length=args.new_length, model=model, model_policy=LlamaForCausalLMPipelinePolicy(),verbose=True)
+
+    engine = PPInferEngine(pp_size=args.pp_size, fp16=args.fp16, micro_batch_size=args.mb_size, new_length=args.new_length, model=model, model_policy=LlamaForCausalLMPipelinePolicy(),verbose=True)
     data = data_gen(args.batch_size, args.seq_len)
     output, timestamps = engine.inference([data])
     if dist.get_rank() == 0:
diff --git a/colossalai/inference/pipeline/benchmark/run.sh b/colossalai/inference/pipeline/benchmark/run.sh
index 00e7783406a9..6670ba4ac26d 100644
--- a/colossalai/inference/pipeline/benchmark/run.sh
+++ b/colossalai/inference/pipeline/benchmark/run.sh
@@ -3,6 +3,7 @@ cd "${script_dir}"

 CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
     --model="7b" \
+    --fp16 \
     --batch_size=2 \
     --seq_len=1024 \
     --new_length=128 \
@@ -11,6 +12,7 @@ CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 .
     --model="7b" \
+    --fp16 \
     --batch_size=4 \
     --seq_len=1024 \
     --new_length=128 \
@@ -19,6 +21,7 @@ CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 .
     --model="7b" \
+    --fp16 \
     --batch_size=8 \
     --seq_len=1024 \
     --new_length=128 \
@@ -27,6 +30,7 @@ CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 .
     --model="7b" \
+    --fp16 \
     --batch_size=16 \
     --seq_len=1024 \
     --new_length=128 \
diff --git a/colossalai/inference/pipeline/engine.py b/colossalai/inference/pipeline/engine.py
index 39366d2d69da..42a88322411d 100644
--- a/colossalai/inference/pipeline/engine.py
+++ b/colossalai/inference/pipeline/engine.py
@@ -54,6 +54,7 @@ class PPInferEngine:
     def __init__(
         self,
         pp_size: int,
+        fp16: bool = True,
         pp_model: nn.Module = None,
         model: nn.Module = None,
         model_policy: Policy = None,
@@ -74,6 +75,8 @@ def __init__(
                                             micro_batch_buffer_size or pp_size)
         self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)
         self.model = pp_model or self._shardformer(model, model_policy)
+        if fp16:
+            self.model.half()

     def inference(self, input_list):
         out = self.schedule.generate_step(self.model, iter(input_list))
diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py
index 227ad2daca0e..0690538351e7 100644
--- a/colossalai/pipeline/p2p.py
+++ b/colossalai/pipeline/p2p.py
@@ -202,12 +202,12 @@ def _p2p_comm(
     recv_pre: bool,
     peer: int,
     group: ProcessGroup,
-    comm_type: torch.dtype = torch.float32,
+    comm_dtype: torch.dtype = torch.float16,
 ):
     tensor_recv_prev = None
     recv_prev_shape = _p2p_comm_shape(tensor_send_next, recv_pre, peer, group)
     if recv_pre:
-        tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_type)
+        tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_dtype)

     ops = []
     if tensor_send_next is not None:
@@ -296,7 +296,7 @@ def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
         cur_rank = self.stage_manager.get_rank()
         _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))

-    def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None) -> None:
+    def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16) -> None:
         """
             Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch.
@@ -307,5 +307,5 @@ def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None)
         if peer is None:
             peer = self.stage_manager.get_next_rank()
         cur_rank = self.stage_manager.get_rank()
-        recv_tensor = _p2p_comm(output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer))
+        recv_tensor = _p2p_comm(output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype)
         return recv_tensor
diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py
index 869b6d1e1f95..6c931553e303 100644
--- a/colossalai/pipeline/schedule/generate.py
+++ b/colossalai/pipeline/schedule/generate.py
@@ -51,6 +51,7 @@ def __init__(self, stage_manager: PipelineStageManager, mb_manager: MicroBatchMa
         self.action_interval_buffer = ActionIntervalBuffer()
         self.verbose = verbose
         self.timestamps = None
+        self.comm_dtype = None

     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
         """Load a batch from data iterator.
@@ -178,7 +179,7 @@ def CommAction(self, recv_pre: bool) -> torch.Tensor:
         In this action, 1.receive the hidden_states from previous stage 2.send the hidden_states to next stage
         """
         hidden_states = self.action_interval_buffer.hidden_states
-        ret = self.comm.p2p_communicate(hidden_states, recv_pre)
+        ret = self.comm.p2p_communicate(hidden_states, recv_pre, comm_dtype=self.comm_dtype)

         self.action_interval_buffer.hidden_states = ret
@@ -214,7 +215,6 @@ def genAction(self, model: Module):
         return actions

     def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
-        model = model.half()
         if self.stage_manager.num_stages == 2:
             return self.generate_step_p2p(model, data_iter)
         else:
@@ -236,6 +236,7 @@ def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.T
         output_sequence = []
         self.load_batch(data_iter)
         model.eval()
+        self.comm_dtype = model.dtype

         whole_timestamp = []

From 588625d6479c38aeab4d0a7d53532a39c864deec Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Mon, 18 Sep 2023 13:50:34 +0800
Subject: [PATCH 3/7] fix fp16 load

---
 colossalai/inference/pipeline/engine.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/colossalai/inference/pipeline/engine.py b/colossalai/inference/pipeline/engine.py
index 42a88322411d..f7cc2fc25e78 100644
--- a/colossalai/inference/pipeline/engine.py
+++ b/colossalai/inference/pipeline/engine.py
@@ -74,9 +74,10 @@ def __init__(
         self.mb_manager = MicroBatchManager(self.stage_manager.stage, new_length, micro_batch_size,
                                             micro_batch_buffer_size or pp_size)
         self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)
-        self.model = pp_model or self._shardformer(model, model_policy)
         if fp16:
-            self.model.half()
+            model.half()
+        self.model = pp_model or self._shardformer(model, model_policy)
+

     def inference(self, input_list):
         out = self.schedule.generate_step(self.model, iter(input_list))

From ff902e6a414444371f36ddb0cc07a3a2adfc4135 Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Mon, 25 Sep 2023 10:27:19 +0800
Subject: [PATCH 4/7] refactor code style

---
 .../inference/pipeline/benchmark/benchmark.py | 35 ++++++++++++-------
 colossalai/pipeline/p2p.py                    | 18 +++-------
 2 files changed, 27 insertions(+), 26 deletions(-)

diff --git a/colossalai/inference/pipeline/benchmark/benchmark.py b/colossalai/inference/pipeline/benchmark/benchmark.py
index 2d073840cb61..94f62cf41460 100644
--- a/colossalai/inference/pipeline/benchmark/benchmark.py
+++ b/colossalai/inference/pipeline/benchmark/benchmark.py
@@ -3,6 +3,7 @@
 import transformers

 import colossalai
+import time
 from colossalai.inference import PPInferEngine
 from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy
 import argparse
@@ -22,7 +23,7 @@ def data_gen(batch_size: int=4, seq_len: int=512):
             data[k] = v.to('cuda').repeat(*new_shape)
     return data

-def print_details_info(timestamps, model_config, args):
+def print_details_info(timestamps, model_config, args, whole_end2end):
     if dist.get_rank() == 0:
         prefill = []
         encoder = []
@@ -32,8 +33,11 @@ def print_details_info(timestamps, model_config, args):
             encoder.append(
                 sum(timestamp[i + 1] - timestamp[i] for i in range(1,len(timestamp) - 1)) / (len(timestamp) - 2))
             end2end.append(timestamp[-1] - timestamp[0])
-        with open(f"llama-{args.model}{'fp16' if args.fp16 is True else 'fp32'}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","w+") as f:
-            avg_latency = sum(end2end)/(args.new_length * args.batch_size)
+        print(whole_end2end)
+        with open(f"./log/llama-{args.model}{'fp16' if args.fp16 is True else 'fp32'}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","w+") as f:
+            mb_avg_end2end = sum(end2end)/len(end2end)
+            mb_avg_latency = mb_avg_end2end/(args.new_length * args.mb_size)
+            whole_avg_latency = whole_end2end/(args.new_length * args.batch_size)
             num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
             num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size
             if args.fp16:
@@ -44,10 +48,12 @@ def print_details_info(timestamps, model_config, args, whole_end2end):
             f.write(f"llama-{args.model} {'fp16' if args.fp16 is True else 'fp32'} {args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n")
             f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill)/len(prefill)*1000))
             f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder)/len(encoder)*1000))
-            f.write("Average end2end time: {0:8.2f} ms\n".format(sum(end2end)/len(end2end)*1000))
-            f.write("Average Per Token Latency: {0:8.2f} ms\n".format(avg_latency * 1000))
-            f.write("Avg flops: {0:8.2f} TFlops/s\n".format(1/avg_latency * num_parameters * num_bytes / 1e12))
-            f.write("Average Throughput: {} tokens/s\n".format((1000/(avg_latency * 1000))))
+            f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end*1000))
+            f.write("Average micro batch Per Token Latency: {0:8.2f} ms\n".format(mb_avg_latency * 1000))
+            f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end*1000))
+            f.write("Whole batch Per Token Latency: {0:8.2f} ms\n".format(whole_avg_latency * 1000))
+            f.write("Throughput: {} tokens/s\n".format((1000/(whole_avg_latency * 1000))))
+            f.write("flops: {0:8.2f} TFlops/s\n".format(1/whole_avg_latency * num_parameters * num_bytes / 1e12))
             f.write("----------------------------------------------------------\n")

@@ -60,7 +66,7 @@ def print_details_info(timestamps, model_config, args, whole_end2end):
         max_memory_allocated = torch.cuda.max_memory_allocated()
         memory_reserved = torch.cuda.memory_reserved()
         max_memory_reserved = torch.cuda.max_memory_reserved()
-        with open(f"llama-{args.model}{'fp16' if args.fp16 is True else 'fp32'}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","a") as f:
+        with open(f"./log/llama-{args.model}{'fp16' if args.fp16 is True else 'fp32'}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","a") as f:
             f.write(
                 f"\nCurrently using GPU: {current_device}\n"
                 f"free memory : {global_free_memory / GIGABYTE:.4f} GB,\n"
@@ -85,16 +91,19 @@ def print_details_info(timestamps, model_config, args, whole_end2end):
     if args.model == 'toy':
         model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=8))
     elif args.model == '7b':
-        model = transformers.LlamaForCausalLM.from_pretrained('decapoda-research/llama-7b-hf')
+        model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-7b-hf'))
     elif args.model == '13b':
-        model = transformers.LlamaForCausalLM.from_pretrained('decapoda-research/llama-13b-hf')
+        model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-13b-hf'))
     else:
         raise NotImplementedError
+
     engine = PPInferEngine(pp_size=args.pp_size, fp16=args.fp16, micro_batch_size=args.mb_size, new_length=args.new_length, model=model, model_policy=LlamaForCausalLMPipelinePolicy(),verbose=True)
     data = data_gen(args.batch_size, args.seq_len)
+
+    whole_end2end = time.time()
     output, timestamps = engine.inference([data])
-    if dist.get_rank() == 0:
-        print(len(output), len(output[0]))
-        print_details_info(timestamps, model.config, args)
+    whole_end2end = time.time() - whole_end2end
+
+    print_details_info(timestamps, model.config, args, whole_end2end)
diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py
index 0690538351e7..d12af3015caf 100644
--- a/colossalai/pipeline/p2p.py
+++ b/colossalai/pipeline/p2p.py
@@ -159,12 +159,14 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
     return object_list[0]

-def _p2p_comm_shape(
+def _p2p_comm(
     tensor_send_next: torch.Tensor,
     recv_prev: bool,
     peer: int,
     group: ProcessGroup,
+    comm_dtype: torch.dtype = torch.float16,
 ):
+    # send and recv shape
     send_next_shape = None
     recv_prev_shape = None

@@ -194,19 +196,9 @@ def _p2p_comm_shape(
     if recv_prev_shape is not None:
         recv_prev_shape = recv_prev_shape.tolist()

-    return recv_prev_shape
-
-
-def _p2p_comm(
-    tensor_send_next: torch.Tensor,
-    recv_pre: bool,
-    peer: int,
-    group: ProcessGroup,
-    comm_dtype: torch.dtype = torch.float16,
-):
+    # send and recv data
     tensor_recv_prev = None
-    recv_prev_shape = _p2p_comm_shape(tensor_send_next, recv_pre, peer, group)
-    if recv_pre:
+    if recv_prev:
         tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_dtype)

     ops = []

From ad7504c68312e1981ce5f32b44f3fcf8e6f1662a Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Mon, 25 Sep 2023 10:41:54 +0800
Subject: [PATCH 5/7] add docstring

---
 colossalai/inference/pipeline/engine.py  |  4 ----
 colossalai/pipeline/p2p.py               | 13 +++++++++++++
 colossalai/pipeline/schedule/generate.py |  8 ++++++--
 3 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/colossalai/inference/pipeline/engine.py b/colossalai/inference/pipeline/engine.py
index f7cc2fc25e78..772669d119bf 100644
--- a/colossalai/inference/pipeline/engine.py
+++ b/colossalai/inference/pipeline/engine.py
@@ -12,13 +12,9 @@
 from colossalai.pipeline.schedule.generate import GenerateSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.shardformer._utils import getattr_
 from colossalai.shardformer.policies.base_policy import Policy

 from .microbatch_manager import MicroBatchManager
-from .policy.gpt2_ppinfer import GPT2LMHeadModelPipelinePolicy
-from .utils import get_suffix_name, set_tensors_to_none
-

 class PPInferEngine:
     '''
diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py
index d12af3015caf..c8195d422c03 100644
--- a/colossalai/pipeline/p2p.py
+++ b/colossalai/pipeline/p2p.py
@@ -166,6 +166,19 @@ def _p2p_comm(
     group: ProcessGroup,
     comm_dtype: torch.dtype = torch.float16,
 ):
+    """
+    Send and recv tensor using P2P communication, used when pipeline size is 2 to solve the race communication.
+
+    Args:
+        tensor_send_next (torch.Tensor): tensor to be sent to next stage
+        recv_prev (bool): whether to receive tensor from previous stage
+        peer (int): rank of the peer
+        group (ProcessGroup): process group
+        comm_dtype (torch.dtype): dtype of the tensor to be sent
+
+    Returns:
+        torch.Tensor: tensor received from previous stage
+    """
     # send and recv shape
     send_next_shape = None
     recv_prev_shape = None
diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py
index 6c931553e303..ffe5056c2fab 100644
--- a/colossalai/pipeline/schedule/generate.py
+++ b/colossalai/pipeline/schedule/generate.py
@@ -17,6 +17,10 @@

 class ActionIntervalBuffer():
+    """
+    The buffer to save the interval hidden states and new token for stage to use.
+
+    """

     def __int__(self):
         self.hidden_states = None
@@ -28,7 +32,7 @@ def clear(self):

 class GenerateSchedule(PipelineSchedule):
-    '''
+    """
     GenerateSchedule is a class that handles the pipeline parallel inference.
     In our schedule, we place tie weight layer, embedding and lm_head in the same device to save space,
     so in this schedule, the out for each encoding progress is on rank0.

     Args:
         stage_manager (`PipelineStageManager`): Pipeline stage manager.
         mb_manager (`MicroBatchManager`): Micro batch manager.
         verbose (bool): Whether to verbose the information of the pipeline.
-    '''
+    """

     def __init__(self, stage_manager: PipelineStageManager, mb_manager: MicroBatchManager, verbose: bool) -> None:
         super().__init__(stage_manager)
From a467d002202a359e2e13d95660beb2c2fe2243f7 Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Tue, 26 Sep 2023 13:53:34 +0800
Subject: [PATCH 6/7] polish code

---
 .../inference/pipeline/benchmark/benchmark.py | 15 ++--
 .../inference/pipeline/benchmark/run.sh       | 76 +++++++++++--------
 colossalai/inference/pipeline/engine.py       | 15 ++--
 colossalai/pipeline/schedule/generate.py      |  4 +
 4 files changed, 64 insertions(+), 46 deletions(-)

diff --git a/colossalai/inference/pipeline/benchmark/benchmark.py b/colossalai/inference/pipeline/benchmark/benchmark.py
index 94f62cf41460..97dfc6336bea 100644
--- a/colossalai/inference/pipeline/benchmark/benchmark.py
+++ b/colossalai/inference/pipeline/benchmark/benchmark.py
@@ -34,18 +34,18 @@ def print_details_info(timestamps, model_config, args, whole_end2end):
                 sum(timestamp[i + 1] - timestamp[i] for i in range(1,len(timestamp) - 1)) / (len(timestamp) - 2))
             end2end.append(timestamp[-1] - timestamp[0])
         print(whole_end2end)
-        with open(f"./log/llama-{args.model}{'fp16' if args.fp16 is True else 'fp32'}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","w+") as f:
+        with open(f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","w+") as f:
             mb_avg_end2end = sum(end2end)/len(end2end)
             mb_avg_latency = mb_avg_end2end/(args.new_length * args.mb_size)
             whole_avg_latency = whole_end2end/(args.new_length * args.batch_size)
             num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
             num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size
-            if args.fp16:
+            if args.dtype in ['fp16','bf16']:
                 num_bytes = 2
             else:
                 num_bytes = 4

-            f.write(f"llama-{args.model} {'fp16' if args.fp16 is True else 'fp32'} {args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n")
+            f.write(f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n")
             f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill)/len(prefill)*1000))
             f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder)/len(encoder)*1000))
             f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end*1000))
@@ -66,7 +66,7 @@ def print_details_info(timestamps, model_config, args, whole_end2end):
         max_memory_allocated = torch.cuda.max_memory_allocated()
         memory_reserved = torch.cuda.memory_reserved()
         max_memory_reserved = torch.cuda.max_memory_reserved()
-        with open(f"./log/llama-{args.model}{'fp16' if args.fp16 is True else 'fp32'}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","a") as f:
+        with open(f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","a") as f:
             f.write(
                 f"\nCurrently using GPU: {current_device}\n"
                 f"free memory : {global_free_memory / GIGABYTE:.4f} GB,\n"
@@ -85,7 +85,8 @@ def print_details_info(timestamps, model_config, args, whole_end2end):
     parser.add_argument('--new_length', type=int, default=4, help='new tokens length')
     parser.add_argument('--mb_size', type=int, default=1, help='micro_batch_size')
     parser.add_argument('--pp_size', type=int, default=2, help='pipeline size')
-    parser.add_argument('--fp16', action="store_true", help='whether to use fp16')
+    parser.add_argument('--log_path', type=str, default='./log', help='where to store the benchmark log')
+    parser.add_argument('--dtype', type=str, default='fp16', help='data type')
     args = parser.parse_args()

     if args.model == 'toy':
@@ -98,11 +99,13 @@ def print_details_info(timestamps, model_config, args, whole_end2end):
         raise NotImplementedError

-    engine = PPInferEngine(pp_size=args.pp_size, fp16=args.fp16, micro_batch_size=args.mb_size, new_length=args.new_length, model=model, model_policy=LlamaForCausalLMPipelinePolicy(),verbose=True)
+    engine = PPInferEngine(pp_size=args.pp_size, dtype=args.dtype, micro_batch_size=args.mb_size, new_length=args.new_length, model=model, model_policy=LlamaForCausalLMPipelinePolicy(),verbose=True)
     data = data_gen(args.batch_size, args.seq_len)

+    torch.cuda.synchronize()
     whole_end2end = time.time()
     output, timestamps = engine.inference([data])
+    torch.cuda.synchronize()
     whole_end2end = time.time() - whole_end2end

     print_details_info(timestamps, model.config, args, whole_end2end)
diff --git a/colossalai/inference/pipeline/benchmark/run.sh b/colossalai/inference/pipeline/benchmark/run.sh
index 6670ba4ac26d..7d8da858692f 100644
--- a/colossalai/inference/pipeline/benchmark/run.sh
+++ b/colossalai/inference/pipeline/benchmark/run.sh
@@ -1,38 +1,50 @@
 script_dir=$(cd "$(dirname "$0")" && pwd)
 cd "${script_dir}"

-CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
-    --model="7b" \
-    --fp16 \
-    --batch_size=2 \
-    --seq_len=1024 \
-    --new_length=128 \
-    --mb_size=1 \
-    --pp_size=2
+# 7b, fp16, 2 gpu, 1024, 128
+for BATCH_SIZE in 2 4 8 16; do
+    CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
+        --model="7b" \
+        --dtype="fp16" \
+        --batch_size=${BATCH_SIZE} \
+        --seq_len=1024 \
+        --new_length=128 \
+        --mb_size=$((${BATCH_SIZE}/2)) \
+        --pp_size=2
+done

-CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
-    --model="7b" \
-    --fp16 \
-    --batch_size=4 \
-    --seq_len=1024 \
-    --new_length=128 \
-    --mb_size=2 \
-    --pp_size=2
+# 7b, fp16, 2 gpu, 512, 512
+for BATCH_SIZE in 2 4 8 16 32; do
+    CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
+        --model="7b" \
+        --dtype="fp16" \
+        --batch_size=${BATCH_SIZE} \
+        --seq_len=512 \
+        --new_length=512 \
+        --mb_size=$((${BATCH_SIZE}/2)) \
+        --pp_size=2
+done

-CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
-    --model="7b" \
-    --fp16 \
-    --batch_size=8 \
-    --seq_len=1024 \
-    --new_length=128 \
-    --mb_size=4 \
-    --pp_size=2
+# 13b, fp16, 2 gpu, 1024, 128
+for BATCH_SIZE in 2 4 8; do
+    CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
+        --model="13b" \
+        --dtype="fp16" \
+        --batch_size=${BATCH_SIZE} \
+        --seq_len=1024 \
+        --new_length=128 \
+        --mb_size=$((${BATCH_SIZE}/2)) \
+        --pp_size=2
+done

-CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
-    --model="7b" \
-    --fp16 \
-    --batch_size=16 \
-    --seq_len=1024 \
-    --new_length=128 \
-    --mb_size=8 \
-    --pp_size=2
+# 13b, fp16, 2 gpu, 512, 512
+for BATCH_SIZE in 2 4 8 16; do
+    CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
+        --model="13b" \
+        --dtype="fp16" \
+        --batch_size=${BATCH_SIZE} \
+        --seq_len=512 \
+        --new_length=512 \
+        --mb_size=$((${BATCH_SIZE}/2)) \
+        --pp_size=2
+done
diff --git a/colossalai/inference/pipeline/engine.py b/colossalai/inference/pipeline/engine.py
index 772669d119bf..1f4cb97d6e55 100644
--- a/colossalai/inference/pipeline/engine.py
+++ b/colossalai/inference/pipeline/engine.py
@@ -1,11 +1,6 @@
-import re
-from functools import partial
-from types import MethodType
-from typing import Callable, List, Optional, Set
+from typing import Callable, List, Optional, Set, Union

-import numpy as np
 import torch
-import torch.distributed as dist
 import torch.nn as nn

 from colossalai.cluster import ProcessGroupMesh
@@ -50,7 +45,7 @@ class PPInferEngine:
     def __init__(
         self,
         pp_size: int,
-        fp16: bool = True,
+        dtype: str = 'fp16',
         pp_model: nn.Module = None,
         model: nn.Module = None,
         model_policy: Policy = None,
@@ -70,8 +65,12 @@ def __init__(
         self.mb_manager = MicroBatchManager(self.stage_manager.stage, new_length, micro_batch_size,
                                             micro_batch_buffer_size or pp_size)
         self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)
-        if fp16:
+
+        assert dtype in ['fp16', 'fp32', 'bf16'], "dtype should be one of 'fp16', 'fp32', 'bf16'"
+        if dtype == 'fp16':
             model.half()
+        elif dtype == 'bf16':
+            model.to(torch.bfloat16)
         self.model = pp_model or self._shardformer(model, model_policy)

diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py
index ffe5056c2fab..8443ff4d555b 100644
--- a/colossalai/pipeline/schedule/generate.py
+++ b/colossalai/pipeline/schedule/generate.py
@@ -133,6 +133,7 @@ def LoadStageAction(self, model: Module) -> None:
         """
         inputs_dict = self.load_micro_batch()
         if self.verbose and self.stage_manager.is_first_stage():
+            torch.cuda.synchronize()
             self.timestamps[self.mb_manager.idx].append(time.time())
         output_dict = model_forward(model, inputs_dict, None)

@@ -148,6 +149,7 @@ def GenTokenAction(self, model: Module):
         hidden_states = {'hidden_states': hidden_states}
         logits = model_forward(model, None, hidden_states)
         if self.verbose and self.stage_manager.is_first_stage():
+            torch.cuda.synchronize()
             self.timestamps[self.mb_manager.idx].append(time.time())
         assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
         new_token = self._get_token_id(logits['logits'])
@@ -295,6 +297,7 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t
                 if self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.PREFILL:
                     inputs_dict = self.load_micro_batch()
                     if self.verbose and self.stage_manager.is_first_stage():
+                        torch.cuda.synchronize()
                         self.timestamps[self.mb_manager.idx].append(time.time())
                     output_dict = model_forward(model, inputs_dict, None)
                     self.mb_manager.step(inputs_dict, output_dict, None)
@@ -307,6 +310,7 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t
                         assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
                         logits = model_forward(model, None, hidden_states)
                         if self.verbose and self.stage_manager.is_first_stage():
+                            torch.cuda.synchronize()
                             self.timestamps[self.mb_manager.idx].append(time.time())
                         assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
                         new_token = self._get_token_id(logits['logits'])

From 69fc4dbd72b8fa532ffdf8fed23cdfc290721b79 Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Tue, 26 Sep 2023 14:29:30 +0800
Subject: [PATCH 7/7] fix test bug

---
 colossalai/inference/pipeline/engine.py  | 12 ++++++++----
 colossalai/pipeline/schedule/generate.py |  9 +++++----
 2 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/colossalai/inference/pipeline/engine.py b/colossalai/inference/pipeline/engine.py
index 1f4cb97d6e55..048ead2bccda 100644
--- a/colossalai/inference/pipeline/engine.py
+++ b/colossalai/inference/pipeline/engine.py
@@ -11,6 +11,7 @@
 from .microbatch_manager import MicroBatchManager

+
 class PPInferEngine:
     '''
     PPInferEngine is a class that handles the pipeline parallel inference.
@@ -64,8 +65,9 @@ def __init__(
         self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True)
         self.mb_manager = MicroBatchManager(self.stage_manager.stage, new_length, micro_batch_size,
                                             micro_batch_buffer_size or pp_size)
+        self.verbose = verbose
         self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)
-
+
         assert dtype in ['fp16', 'fp32', 'bf16'], "dtype should be one of 'fp16', 'fp32', 'bf16'"
         if dtype == 'fp16':
             model.half()
@@ -73,10 +75,12 @@ def __init__(
             model.to(torch.bfloat16)
         self.model = pp_model or self._shardformer(model, model_policy)

-
     def inference(self, input_list):
-        out = self.schedule.generate_step(self.model, iter(input_list))
-        return out
+        out, timestamp = self.schedule.generate_step(self.model, iter(input_list))
+        if self.verbose:
+            return out, timestamp
+        else:
+            return out

     def _shardformer(self, model, model_policy):
         shardconfig = ShardConfig(
diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py
index 8443ff4d555b..1a961d3036b8 100644
--- a/colossalai/pipeline/schedule/generate.py
+++ b/colossalai/pipeline/schedule/generate.py
@@ -249,7 +249,7 @@ def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.T
         #run by round
         for _ in range(self.round):
             self.timestamps = [[] for _ in range(self.stage_manager.num_stages)
-                               ] if self.verbose and self.stage_manager.is_first_stage() else None
+                              ] if self.verbose and self.stage_manager.is_first_stage() else None
             self.action_interval_buffer.clear()
             while self.mb_manager.is_micro_batch_done() is False:
                 actions = self.genAction(model)
@@ -287,7 +287,7 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t
         # run by round
         for _ in range(self.round):
             self.timestamps = [[] for _ in range(self.stage_manager.num_stages)
-                               ] if self.verbose and self.stage_manager.is_first_stage() else None
+                              ] if self.verbose and self.stage_manager.is_first_stage() else None
             while self.mb_manager.is_micro_batch_done() is False:
                 inputs_dict = None
                 new_token = None
@@ -316,7 +316,7 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t
                         new_token = self._get_token_id(logits['logits'])
                         self.mb_manager.step(None, None, new_token)
                 # If the current micro batch is not DONE, go through blocks
-                if self.mb_manager.cur_state is Status.GENERATE:
+                if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN):
                     inputs_dict = self._prepare_inputs_for_new_token(new_token)
                     output_dict = model_forward(model, inputs_dict, None)
                     self.mb_manager.step(inputs_dict, output_dict, None)
@@ -327,7 +327,8 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t
                     self.mb_manager.step(inputs_dict, output_dict, None)

                 # Current microbatch is not DONE, send hidden_state to next stage
-                if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state is Status.GENERATE:
+                if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in (Status.GENERATE,
+                                                                                            Status.COOLDOWN):
                     self.comm.send_forward({'hidden_states': output_dict['hidden_states']})

             self.mb_manager.next()
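Taken together, the series leaves `PPInferEngine` accepting a `dtype` string ('fp16', 'fp32' or 'bf16') and, when constructed with `verbose=True`, returning `(output, timestamps)` from `inference()`. The sketch below condenses the benchmark script into a minimal driver; the toy configuration and argument values are illustrative only, and the script is assumed to be launched with `colossalai run --nproc_per_node 2` as in `run.sh`.

    import torch
    import transformers

    import colossalai
    from colossalai.inference import PPInferEngine
    from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy

    colossalai.launch_from_torch(config={})

    # Toy LLaMA config so the sketch fits on small GPUs; the benchmark uses the 7b/13b configs.
    model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=8))

    engine = PPInferEngine(
        pp_size=2,
        dtype='fp16',                                  # 'fp16' | 'fp32' | 'bf16' after patch 6
        micro_batch_size=1,
        new_length=4,
        model=model,
        model_policy=LlamaForCausalLMPipelinePolicy(),
        verbose=True,                                  # verbose=True makes inference() also return timestamps (patch 7)
    )

    input_ids = torch.randint(10, 30000, (2, 8), dtype=torch.int32, device='cuda')
    data = dict(input_ids=input_ids, attention_mask=torch.ones_like(input_ids))
    output, timestamps = engine.inference([data])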