diff --git a/colossalai/inference/pipeline/benchmark/benchmark.py b/colossalai/inference/pipeline/benchmark/benchmark.py new file mode 100644 index 000000000000..97dfc6336bea --- /dev/null +++ b/colossalai/inference/pipeline/benchmark/benchmark.py @@ -0,0 +1,112 @@ +import torch +import torch.distributed as dist +import transformers + +import colossalai +import time +from colossalai.inference import PPInferEngine +from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy +import argparse +GIGABYTE = 1024 ** 3 +MEGABYTE = 1024 * 1024 + +colossalai.launch_from_torch(config={}) + +def data_gen(batch_size: int=4, seq_len: int=512): + input_ids = torch.randint(10, 30000, (1, seq_len), dtype=torch.int32) + attention_mask = torch.ones((1, seq_len), dtype=torch.int32) + data = dict(input_ids=input_ids, attention_mask=attention_mask) + for k, v in data.items(): + if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: + new_shape = [1] * v.dim() + new_shape[0] = batch_size + data[k] = v.to('cuda').repeat(*new_shape) + return data + +def print_details_info(timestamps, model_config, args, whole_end2end): + if dist.get_rank() == 0: + prefill = [] + encoder = [] + end2end = [] + for timestamp in timestamps: + prefill.append(timestamp[1] - timestamp[0]) + encoder.append( + sum(timestamp[i + 1] - timestamp[i] for i in range(1,len(timestamp) - 1)) / (len(timestamp) - 2)) + end2end.append(timestamp[-1] - timestamp[0]) + print(whole_end2end) + with open(f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","w+") as f: + mb_avg_end2end = sum(end2end)/len(end2end) + mb_avg_latency = mb_avg_end2end/(args.new_length * args.mb_size) + whole_avg_latency = whole_end2end/(args.new_length * args.batch_size) + num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers) + num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size + if args.dtype in ['fp16','bf16']: + num_bytes = 2 + else: + num_bytes = 4 + + f.write(f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n") + f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill)/len(prefill)*1000)) + f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder)/len(encoder)*1000)) + f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end*1000)) + f.write("Average micro batch Per Token Latency: {0:8.2f} ms\n".format(mb_avg_latency * 1000)) + f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end*1000)) + f.write("Whole batch Per Token Latency: {0:8.2f} ms\n".format(whole_avg_latency * 1000)) + f.write("Throughput: {} tokens/s\n".format((1000/(whole_avg_latency * 1000)))) + f.write("flops: {0:8.2f} TFlops/s\n".format(1/whole_avg_latency * num_parameters * num_bytes / 1e12)) + f.write("----------------------------------------------------------\n") + + + if torch.cuda.is_available(): + current_device = torch.cuda.current_device() + + # free memory and the total available memory in bytes + global_free_memory, total_GPU_memory_occupied = torch.cuda.mem_get_info() + memory_allocated = torch.cuda.memory_allocated() + max_memory_allocated = torch.cuda.max_memory_allocated() + memory_reserved = torch.cuda.memory_reserved() + max_memory_reserved = torch.cuda.max_memory_reserved() + with open(f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","a") as f: + f.write( + f"\nCurrently using GPU: {current_device}\n" + f"free memory : {global_free_memory / GIGABYTE:.4f} GB,\n" + f"total memory: {total_GPU_memory_occupied / GIGABYTE:.4f} GB,\n" + f"memory allocated: {memory_allocated / GIGABYTE:.4f} GB,\n" + f"Max CUDA memory allocated: {max_memory_allocated / GIGABYTE:.4f} GB,\n" + f"memory reserved/cached: {memory_reserved / GIGABYTE:.4f} GB,\n" + f"Max CUDA memory reserved/cached: {max_memory_reserved / GIGABYTE:.4f} GB,\n" + ) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--model', default='toy', help='the size of model') + parser.add_argument('-b', '--batch_size', type=int, default=8, help='batch size') + parser.add_argument('-s', '--seq_len', type=int, default=8, help='sequence length') + parser.add_argument('--new_length', type=int, default=4, help='new tokens length') + parser.add_argument('--mb_size', type=int, default=1, help='micro_batch_size') + parser.add_argument('--pp_size', type=int, default=2, help='pipeline size') + parser.add_argument('--log_path', type=str, default='./log' ,help='where to store the benchmark log') + parser.add_argument('--dtype', type=str, default='fp16', help='data type') + args = parser.parse_args() + + if args.model == 'toy': + model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=8)) + elif args.model == '7b': + model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-7b-hf')) + elif args.model == '13b': + model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-13b-hf')) + else: + raise NotImplementedError + + + engine = PPInferEngine(pp_size=args.pp_size, dtype=args.dtype, micro_batch_size=args.mb_size, new_length=args.new_length, model=model, model_policy=LlamaForCausalLMPipelinePolicy(),verbose=True) + data = data_gen(args.batch_size, args.seq_len) + + torch.cuda.synchronize() + whole_end2end = time.time() + output, timestamps = engine.inference([data]) + torch.cuda.synchronize() + whole_end2end = time.time() - whole_end2end + + print_details_info(timestamps, model.config, args, whole_end2end) + diff --git a/colossalai/inference/pipeline/benchmark/run.sh b/colossalai/inference/pipeline/benchmark/run.sh new file mode 100644 index 000000000000..7d8da858692f --- /dev/null +++ b/colossalai/inference/pipeline/benchmark/run.sh @@ -0,0 +1,50 @@ +script_dir=$(cd "$(dirname "$0")" && pwd) +cd "${script_dir}" + +# 7b, fp32, 2 gpu, 1024, 128 +for BATCH_SIZE in 2 4 8 16; do + CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \ + --model="7b" \ + --dtype="fp16" \ + --batch_size=${BATCH_SIZE} \ + --seq_len=1024 \ + --new_length=128 \ + --mb_size=$((${BATCH_SIZE}/2)) \ + --pp_size=2 +done + +# 7b, fp32, 2 gpu, 512, 512 +for BATCH_SIZE in 2 4 8 16 32; do + CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \ + --model="7b" \ + --dtype="fp16" \ + --batch_size=${BATCH_SIZE} \ + --seq_len=512 \ + --new_length=512 \ + --mb_size=$((${BATCH_SIZE}/2)) \ + --pp_size=2 +done + +# 7b, fp32, 2 gpu, 1024, 128 +for BATCH_SIZE in 2 4 8; do + CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \ + --model="13b" \ + --dtype="fp16" \ + --batch_size=${BATCH_SIZE} \ + --seq_len=1024 \ + --new_length=128 \ + --mb_size=$((${BATCH_SIZE}/2)) \ + --pp_size=2 +done + +# 13b, fp16, 2 gpu, 512, 512 +for BATCH_SIZE in 2 4 8 16; do + CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \ + --model="13b" \ + --dtype="fp16" \ + --batch_size=${BATCH_SIZE} \ + --seq_len=512 \ + --new_length=512 \ + --mb_size=$((${BATCH_SIZE}/2)) \ + --pp_size=2 +done diff --git a/colossalai/inference/pipeline/engine.py b/colossalai/inference/pipeline/engine.py index 39366d2d69da..048ead2bccda 100644 --- a/colossalai/inference/pipeline/engine.py +++ b/colossalai/inference/pipeline/engine.py @@ -1,23 +1,15 @@ -import re -from functools import partial -from types import MethodType -from typing import Callable, List, Optional, Set +from typing import Callable, List, Optional, Set, Union -import numpy as np import torch -import torch.distributed as dist import torch.nn as nn from colossalai.cluster import ProcessGroupMesh from colossalai.pipeline.schedule.generate import GenerateSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.shardformer._utils import getattr_ from colossalai.shardformer.policies.base_policy import Policy from .microbatch_manager import MicroBatchManager -from .policy.gpt2_ppinfer import GPT2LMHeadModelPipelinePolicy -from .utils import get_suffix_name, set_tensors_to_none class PPInferEngine: @@ -54,6 +46,7 @@ class PPInferEngine: def __init__( self, pp_size: int, + dtype: str = 'fp16', pp_model: nn.Module = None, model: nn.Module = None, model_policy: Policy = None, @@ -72,12 +65,22 @@ def __init__( self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True) self.mb_manager = MicroBatchManager(self.stage_manager.stage, new_length, micro_batch_size, micro_batch_buffer_size or pp_size) + self.verbose = verbose self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose) + + assert dtype in ['fp16', 'fp32', 'bf16'], "dtype should be one of 'fp16', 'fp32', 'bf16'" + if dtype == 'fp16': + model.half() + elif dtype == 'bf16': + model.to(torch.bfloat16) self.model = pp_model or self._shardformer(model, model_policy) def inference(self, input_list): - out = self.schedule.generate_step(self.model, iter(input_list)) - return out + out, timestamp = self.schedule.generate_step(self.model, iter(input_list)) + if self.verbose: + return out, timestamp + else: + return out def _shardformer(self, model, model_policy): shardconfig = ShardConfig( diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index 227ad2daca0e..c8195d422c03 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -159,12 +159,27 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any: return object_list[0] -def _p2p_comm_shape( +def _p2p_comm( tensor_send_next: torch.Tensor, recv_prev: bool, peer: int, group: ProcessGroup, + comm_dtype: torch.dtype = torch.float16, ): + """ + Send and recv tensor using P2P communication, used when pipeline size is 2 to solve the race communication. + + Agrs: + tensor_send_next (torch.Tensor): tensor to be sent to next stage + recv_prev (bool): whether to receive tensor from previous stage + peer (int): rank of the peer + group (ProcessGroup): process group + comm_dtype (torch.dtype): dtype of the tensor to be sent + + Returns: + torch.Tensor: tensor received from previous stage + """ + # send and recv shape send_next_shape = None recv_prev_shape = None @@ -194,20 +209,10 @@ def _p2p_comm_shape( if recv_prev_shape is not None: recv_prev_shape = recv_prev_shape.tolist() - return recv_prev_shape - - -def _p2p_comm( - tensor_send_next: torch.Tensor, - recv_pre: bool, - peer: int, - group: ProcessGroup, - comm_type: torch.dtype = torch.float32, -): + # send and recv data tensor_recv_prev = None - recv_prev_shape = _p2p_comm_shape(tensor_send_next, recv_pre, peer, group) - if recv_pre: - tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_type) + if recv_prev: + tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_dtype) ops = [] if tensor_send_next is not None: @@ -296,7 +301,7 @@ def send_backward(self, input_object: Any, prev_rank: int = None) -> None: cur_rank = self.stage_manager.get_rank() _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank)) - def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None) -> None: + def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16) -> None: """ Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch. @@ -307,5 +312,5 @@ def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None) if peer is None: peer = self.stage_manager.get_next_rank() cur_rank = self.stage_manager.get_rank() - recv_tensor = _p2p_comm(output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer)) + recv_tensor = _p2p_comm(output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype) return recv_tensor diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py index 85feebea6e6e..1a961d3036b8 100644 --- a/colossalai/pipeline/schedule/generate.py +++ b/colossalai/pipeline/schedule/generate.py @@ -17,6 +17,10 @@ class ActionIntervalBuffer(): + """ + The buffer to save the interval hidden states and new token for stage to use. + + """ def __int__(self): self.hidden_states = None @@ -28,7 +32,7 @@ def clear(self): class GenerateSchedule(PipelineSchedule): - ''' + """ GenerateSchedule is a class that handles the pipeline parallel inference. In our schedule, we place tie weight layer, embedding and lm_head in the same device to save space, so in this schedule, the out for each encoding progress is on rank0. @@ -37,7 +41,7 @@ class GenerateSchedule(PipelineSchedule): stage_manager (`PipelineStageManager`): Pipeline stage manager. mb_manager (`MicroBatchManager`): Micro batch manager. verbose (bool): Whether to verbose the information of the pipeline. - ''' + """ def __init__(self, stage_manager: PipelineStageManager, mb_manager: MicroBatchManager, verbose: bool) -> None: super().__init__(stage_manager) @@ -48,8 +52,10 @@ def __init__(self, stage_manager: PipelineStageManager, mb_manager: MicroBatchMa self.batch_size: Optional[int] = None self.microbatch_offset: Optional[int] = None self.num_microbatches: Optional[int] = None - self.verbose = verbose self.action_interval_buffer = ActionIntervalBuffer() + self.verbose = verbose + self.timestamps = None + self.comm_dtype = None def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. @@ -126,6 +132,9 @@ def LoadStageAction(self, model: Module) -> None: In this action, 1.load micro_batch 2.do the forward 3.step to update """ inputs_dict = self.load_micro_batch() + if self.verbose and self.stage_manager.is_first_stage(): + torch.cuda.synchronize() + self.timestamps[self.mb_manager.idx].append(time.time()) output_dict = model_forward(model, inputs_dict, None) self.mb_manager.step(inputs_dict, output_dict, None) @@ -139,6 +148,9 @@ def GenTokenAction(self, model: Module): assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None" hidden_states = {'hidden_states': hidden_states} logits = model_forward(model, None, hidden_states) + if self.verbose and self.stage_manager.is_first_stage(): + torch.cuda.synchronize() + self.timestamps[self.mb_manager.idx].append(time.time()) assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" new_token = self._get_token_id(logits['logits']) @@ -173,7 +185,7 @@ def CommAction(self, recv_pre: bool) -> torch.Tensor: In this action, 1.receive the hidden_states from previous stage 2.send the hidden_states to next stage """ hidden_states = self.action_interval_buffer.hidden_states - ret = self.comm.p2p_communicate(hidden_states, recv_pre) + ret = self.comm.p2p_communicate(hidden_states, recv_pre, comm_dtype=self.comm_dtype) self.action_interval_buffer.hidden_states = ret @@ -208,20 +220,6 @@ def genAction(self, model: Module): return actions - def verbose_info(self, timestamps: List): - prefill = [] - encoder = [] - end2end = [] - for timestamp in timestamps: - prefill.append(timestamp[1] - timestamp[0]) - encoder.append( - sum(timestamp[i + 1] - timestamp[i] for i in range(1, - len(timestamp) - 1)) / (len(timestamp) - 2)) - end2end.append(timestamp[-1] - timestamp[0]) - print(f"Average prefill time: {sum(prefill)/len(prefill)}") - print(f"Average encode time: {sum(encoder)/len(encoder)}") - print(f"Average end2end time: {sum(end2end)/len(end2end)}") - def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]: if self.stage_manager.num_stages == 2: return self.generate_step_p2p(model, data_iter) @@ -244,11 +242,14 @@ def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.T output_sequence = [] self.load_batch(data_iter) model.eval() + self.comm_dtype = model.dtype whole_timestamp = [] #run by round for _ in range(self.round): + self.timestamps = [[] for _ in range(self.stage_manager.num_stages) + ] if self.verbose and self.stage_manager.is_first_stage() else None self.action_interval_buffer.clear() while self.mb_manager.is_micro_batch_done() is False: actions = self.genAction(model) @@ -261,8 +262,10 @@ def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.T else: self.CommAction(False) self.mb_manager.clear() + if self.verbose and self.stage_manager.is_first_stage(): + whole_timestamp.extend(self.timestamps) - return output_sequence + return output_sequence, whole_timestamp @torch.no_grad() def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]: @@ -283,8 +286,8 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t whole_timestamp = [] # run by round for _ in range(self.round): - timestampes = [[] for _ in range(self.stage_manager.num_stages) - ] if self.verbose and self.stage_manager.is_first_stage() else None + self.timestamps = [[] for _ in range(self.stage_manager.num_stages) + ] if self.verbose and self.stage_manager.is_first_stage() else None while self.mb_manager.is_micro_batch_done() is False: inputs_dict = None new_token = None @@ -294,7 +297,8 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t if self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.PREFILL: inputs_dict = self.load_micro_batch() if self.verbose and self.stage_manager.is_first_stage(): - timestampes[self.mb_manager.idx].append(time.time()) + torch.cuda.synchronize() + self.timestamps[self.mb_manager.idx].append(time.time()) output_dict = model_forward(model, inputs_dict, None) self.mb_manager.step(inputs_dict, output_dict, None) # In GENERATE phase @@ -306,12 +310,13 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None" logits = model_forward(model, None, hidden_states) if self.verbose and self.stage_manager.is_first_stage(): - timestampes[self.mb_manager.idx].append(time.time()) + torch.cuda.synchronize() + self.timestamps[self.mb_manager.idx].append(time.time()) assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" new_token = self._get_token_id(logits['logits']) self.mb_manager.step(None, None, new_token) # If the current micro batch is not DONE, go through blocks - if self.mb_manager.cur_state is Status.GENERATE: + if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN): inputs_dict = self._prepare_inputs_for_new_token(new_token) output_dict = model_forward(model, inputs_dict, None) self.mb_manager.step(inputs_dict, output_dict, None) @@ -322,7 +327,8 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t self.mb_manager.step(inputs_dict, output_dict, None) # Current microbatch is not DONE, send hidden_state to next stage - if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state is Status.GENERATE: + if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in (Status.GENERATE, + Status.COOLDOWN): self.comm.send_forward({'hidden_states': output_dict['hidden_states']}) self.mb_manager.next() @@ -332,9 +338,6 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t output_sequence.extend(self.mb_manager.export_new_tokens()) self.mb_manager.clear() if self.verbose and self.stage_manager.is_first_stage(): - whole_timestamp.extend(timestampes) - - if self.verbose and self.stage_manager.is_first_stage(): - self.verbose_info(whole_timestamp) + whole_timestamp.extend(self.timestamps) - return output_sequence + return output_sequence, whole_timestamp