112 changes: 112 additions & 0 deletions colossalai/inference/pipeline/benchmark/benchmark.py
@@ -0,0 +1,112 @@
import argparse
import time

import torch
import torch.distributed as dist
import transformers

import colossalai
from colossalai.inference import PPInferEngine
from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy

GIGABYTE = 1024**3
MEGABYTE = 1024 * 1024

colossalai.launch_from_torch(config={})


def data_gen(batch_size: int = 4, seq_len: int = 512):
    # Build a single dummy sample, then tile it along dim 0 to the requested batch size.
    input_ids = torch.randint(10, 30000, (1, seq_len), dtype=torch.int32)
    attention_mask = torch.ones((1, seq_len), dtype=torch.int32)
    data = dict(input_ids=input_ids, attention_mask=attention_mask)
    for k, v in data.items():
        if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
            new_shape = [1] * v.dim()
            new_shape[0] = batch_size
            data[k] = v.to('cuda').repeat(*new_shape)
    return data
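
A quick sanity check of the tiling behavior (illustrative only; assumes a CUDA device is available):

# Illustrative only: the single dummy row of shape (1, seq_len) is tiled along dim 0.
batch = data_gen(batch_size=4, seq_len=512)
assert batch["input_ids"].shape == (4, 512)
assert batch["attention_mask"].is_cuda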

def print_details_info(timestamps, model_config, args, whole_end2end):
    log_file = f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log"
    if dist.get_rank() == 0:
        prefill = []
        encoder = []
        end2end = []
        for timestamp in timestamps:
            prefill.append(timestamp[1] - timestamp[0])
            # Average per-token generation time over the steps after prefill.
            encoder.append(
                sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)) / (len(timestamp) - 2))
            end2end.append(timestamp[-1] - timestamp[0])
        with open(log_file, "w+") as f:
            mb_avg_end2end = sum(end2end) / len(end2end)
            mb_avg_latency = mb_avg_end2end / (args.new_length * args.mb_size)
            whole_avg_latency = whole_end2end / (args.new_length * args.batch_size)
            num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
            # Rough transformer estimate: ~12 * hidden_size^2 parameters per layer,
            # split evenly across the pipeline stages.
            num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size
            num_bytes = 2 if args.dtype in ['fp16', 'bf16'] else 4

            f.write(f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n")
            f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill) / len(prefill) * 1000))
            f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder) / len(encoder) * 1000))
            f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end * 1000))
            f.write("Average micro batch Per Token Latency: {0:8.2f} ms\n".format(mb_avg_latency * 1000))
            f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end * 1000))
            f.write("Whole batch Per Token Latency: {0:8.2f} ms\n".format(whole_avg_latency * 1000))
            f.write("Throughput: {} tokens/s\n".format(1 / whole_avg_latency))
            # Approximates forward FLOPs per token as num_bytes * num_parameters
            # (i.e. 2 * params for fp16/bf16).
            f.write("flops: {0:8.2f} TFlops/s\n".format(1 / whole_avg_latency * num_parameters * num_bytes / 1e12))
            f.write("----------------------------------------------------------\n")


    if torch.cuda.is_available():
        current_device = torch.cuda.current_device()

        # Free and total device memory in bytes, as reported by the CUDA driver.
        global_free_memory, total_gpu_memory = torch.cuda.mem_get_info()
        memory_allocated = torch.cuda.memory_allocated()
        max_memory_allocated = torch.cuda.max_memory_allocated()
        memory_reserved = torch.cuda.memory_reserved()
        max_memory_reserved = torch.cuda.max_memory_reserved()
        with open(log_file, "a") as f:
            f.write(
                f"\nCurrently using GPU: {current_device}\n"
                f"free memory : {global_free_memory / GIGABYTE:.4f} GB,\n"
                f"total memory: {total_gpu_memory / GIGABYTE:.4f} GB,\n"
                f"memory allocated: {memory_allocated / GIGABYTE:.4f} GB,\n"
                f"Max CUDA memory allocated: {max_memory_allocated / GIGABYTE:.4f} GB,\n"
                f"memory reserved/cached: {memory_reserved / GIGABYTE:.4f} GB,\n"
                f"Max CUDA memory reserved/cached: {max_memory_reserved / GIGABYTE:.4f} GB,\n"
            )
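
To make the metric definitions above concrete, a hypothetical worked example (numbers invented purely for illustration):

# Hypothetical numbers, only to illustrate the formulas above.
new_length, batch_size = 128, 8
whole_end2end = 4.0                                              # seconds for the whole batch
whole_avg_latency = whole_end2end / (new_length * batch_size)    # ~0.0039 s/token
throughput = 1 / whole_avg_latency                               # ~256 tokens/s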

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='toy', help='model size: toy, 7b, or 13b')
    parser.add_argument('-b', '--batch_size', type=int, default=8, help='batch size')
    parser.add_argument('-s', '--seq_len', type=int, default=8, help='sequence length')
    parser.add_argument('--new_length', type=int, default=4, help='number of new tokens to generate')
    parser.add_argument('--mb_size', type=int, default=1, help='micro batch size')
    parser.add_argument('--pp_size', type=int, default=2, help='pipeline parallel size')
    parser.add_argument('--log_path', type=str, default='./log', help='where to store the benchmark log')
    parser.add_argument('--dtype', type=str, default='fp16', help='data type')
    args = parser.parse_args()

    if args.model == 'toy':
        model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=8))
    elif args.model == '7b':
        model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-7b-hf'))
    elif args.model == '13b':
        model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-13b-hf'))
    else:
        raise NotImplementedError(f"Unsupported model size: {args.model}")

    engine = PPInferEngine(pp_size=args.pp_size,
                           dtype=args.dtype,
                           micro_batch_size=args.mb_size,
                           new_length=args.new_length,
                           model=model,
                           model_policy=LlamaForCausalLMPipelinePolicy(),
                           verbose=True)
    data = data_gen(args.batch_size, args.seq_len)

    # Bracket the run with synchronize so the timer captures all device work.
    torch.cuda.synchronize()
    whole_end2end = time.time()
    output, timestamps = engine.inference([data])
    torch.cuda.synchronize()
    whole_end2end = time.time() - whole_end2end

    print_details_info(timestamps, model.config, args, whole_end2end)

50 changes: 50 additions & 0 deletions colossalai/inference/pipeline/benchmark/run.sh
@@ -0,0 +1,50 @@
#!/bin/bash
script_dir=$(cd "$(dirname "$0")" && pwd)
cd "${script_dir}"

# 7b, fp16, 2 gpu, 1024, 128
for BATCH_SIZE in 2 4 8 16; do
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
--model="7b" \
--dtype="fp16" \
--batch_size=${BATCH_SIZE} \
--seq_len=1024 \
--new_length=128 \
--mb_size=$((${BATCH_SIZE}/2)) \
--pp_size=2
done

# 7b, fp16, 2 gpu, 512, 512
for BATCH_SIZE in 2 4 8 16 32; do
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
--model="7b" \
--dtype="fp16" \
--batch_size=${BATCH_SIZE} \
--seq_len=512 \
--new_length=512 \
--mb_size=$((${BATCH_SIZE}/2)) \
--pp_size=2
done

# 13b, fp16, 2 gpu, 1024, 128
for BATCH_SIZE in 2 4 8; do
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
--model="13b" \
--dtype="fp16" \
--batch_size=${BATCH_SIZE} \
--seq_len=1024 \
--new_length=128 \
--mb_size=$((${BATCH_SIZE}/2)) \
--pp_size=2
done

# 13b, fp16, 2 gpu, 512, 512
for BATCH_SIZE in 2 4 8 16; do
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
--model="13b" \
--dtype="fp16" \
--batch_size=${BATCH_SIZE} \
--seq_len=512 \
--new_length=512 \
--mb_size=$((${BATCH_SIZE}/2)) \
--pp_size=2
done
25 changes: 14 additions & 11 deletions colossalai/inference/pipeline/engine.py
@@ -1,23 +1,15 @@
import re
from functools import partial
from types import MethodType
from typing import Callable, List, Optional, Set
from typing import Callable, List, Optional, Set, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.schedule.generate import GenerateSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer._utils import getattr_
from colossalai.shardformer.policies.base_policy import Policy

from .microbatch_manager import MicroBatchManager
from .policy.gpt2_ppinfer import GPT2LMHeadModelPipelinePolicy
from .utils import get_suffix_name, set_tensors_to_none


class PPInferEngine:
@@ -54,6 +46,7 @@ class PPInferEngine:
def __init__(
self,
pp_size: int,
dtype: str = 'fp16',
pp_model: nn.Module = None,
model: nn.Module = None,
model_policy: Policy = None,
@@ -72,12 +65,22 @@ def __init__(
self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True)
self.mb_manager = MicroBatchManager(self.stage_manager.stage, new_length, micro_batch_size,
micro_batch_buffer_size or pp_size)
self.verbose = verbose
self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)

assert dtype in ['fp16', 'fp32', 'bf16'], "dtype should be one of 'fp16', 'fp32', 'bf16'"
if dtype == 'fp16':
model.half()
elif dtype == 'bf16':
model.to(torch.bfloat16)
self.model = pp_model or self._shardformer(model, model_policy)

def inference(self, input_list):
out = self.schedule.generate_step(self.model, iter(input_list))
return out
out, timestamp = self.schedule.generate_step(self.model, iter(input_list))
if self.verbose:
return out, timestamp
else:
return out

def _shardformer(self, model, model_policy):
shardconfig = ShardConfig(
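
A minimal usage sketch of the new `dtype` and `verbose` options (a sketch under assumptions: `model`, `policy`, and `data` are built by the caller, as in benchmark.py above):

from colossalai.inference import PPInferEngine

# Sketch only: `model`, `policy`, and `data` come from the caller.
# dtype='bf16' casts the weights via model.to(torch.bfloat16) before sharding.
engine = PPInferEngine(pp_size=2, dtype='bf16', micro_batch_size=1,
                       new_length=4, model=model, model_policy=policy, verbose=True)
output, timestamps = engine.inference([data])  # with verbose=False, only `output` is returned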
37 changes: 21 additions & 16 deletions colossalai/pipeline/p2p.py
@@ -159,12 +159,27 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
return object_list[0]


def _p2p_comm_shape(
def _p2p_comm(
tensor_send_next: torch.Tensor,
recv_prev: bool,
peer: int,
group: ProcessGroup,
comm_dtype: torch.dtype = torch.float16,
):
"""
Send and recv tensor using P2P communication, used when pipeline size is 2 to solve the race communication.

Agrs:
tensor_send_next (torch.Tensor): tensor to be sent to next stage
recv_prev (bool): whether to receive tensor from previous stage
peer (int): rank of the peer
group (ProcessGroup): process group
comm_dtype (torch.dtype): dtype of the tensor to be sent

Returns:
torch.Tensor: tensor received from previous stage
"""
# send and recv shape
send_next_shape = None
recv_prev_shape = None

@@ -194,20 +209,10 @@ def _p2p_comm_shape(
if recv_prev_shape is not None:
recv_prev_shape = recv_prev_shape.tolist()

return recv_prev_shape


def _p2p_comm(
tensor_send_next: torch.Tensor,
recv_pre: bool,
peer: int,
group: ProcessGroup,
comm_type: torch.dtype = torch.float32,
):
# send and recv data
tensor_recv_prev = None
recv_prev_shape = _p2p_comm_shape(tensor_send_next, recv_pre, peer, group)
if recv_pre:
tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_type)
if recv_prev:
tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_dtype)

ops = []
if tensor_send_next is not None:
@@ -296,7 +301,7 @@ def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
cur_rank = self.stage_manager.get_rank()
_send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))

def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None) -> None:
def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16) -> None:
"""
Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch.

@@ -307,5 +312,5 @@ def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None)
if peer is None:
peer = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
recv_tensor = _p2p_comm(output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer))
recv_tensor = _p2p_comm(output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype)
return recv_tensor
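
For orientation, a self-contained sketch of the shape-then-data exchange that `_p2p_comm` performs between two pipeline stages. This illustrates the pattern, not the library's code; `shape_then_data_exchange` is a hypothetical name, and it assumes both peers' tensors have the same number of dimensions:

import torch
import torch.distributed as dist

def shape_then_data_exchange(send: torch.Tensor, peer: int,
                             dtype: torch.dtype = torch.float16) -> torch.Tensor:
    device = torch.device("cuda", torch.cuda.current_device())

    # Phase 1: exchange shapes so each side can size its receive buffer.
    send_shape = torch.tensor(send.shape, device=device, dtype=torch.int64)
    recv_shape = torch.empty(send.dim(), device=device, dtype=torch.int64)
    ops = [dist.P2POp(dist.isend, send_shape, peer),
           dist.P2POp(dist.irecv, recv_shape, peer)]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    torch.cuda.synchronize()

    # Phase 2: exchange the payload. Batching the isend/irecv pair is what
    # avoids the race of two blocking sends meeting head-on when only two
    # stages talk to each other.
    recv = torch.empty(recv_shape.tolist(), device=device, dtype=dtype)
    ops = [dist.P2POp(dist.isend, send.to(dtype), peer),
           dist.P2POp(dist.irecv, recv, peer)]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    torch.cuda.synchronize()
    return recv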