Merged
34 changes: 34 additions & 0 deletions colossalai/inference/dynamic_batching/get_tokenizer.py
@@ -0,0 +1,34 @@
from transformers import AutoTokenizer

_FAST_LLAMA_TOKENIZER = "/home/lccd/share/llama-tokenizer"


def get_tokenizer(
tokenizer=None,
tokenizer_name: str = "",
trust_remote_code: bool = False,
use_fast: bool = True,
):
if tokenizer is not None:
tokenizer = tokenizer
else:
if "llama" in tokenizer_name.lower() and use_fast == True:
print(
"For some LLaMA-based models, initializing the fast tokenizer may "
"take a long time. To eliminate the initialization time, consider "
f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
"tokenizer. This is done automatically in Colossalai."
)

tokenizer_name = _FAST_LLAMA_TOKENIZER

try:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
)
except TypeError:
use_fast = False
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
)
return tokenizer
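
Usage sketch (not part of the diff): the helper can be exercised on its own as below, assuming `gpt2` as an arbitrary Hugging Face tokenizer name. Note that any name containing "llama" combined with `use_fast=True` is redirected to the hardcoded `_FAST_LLAMA_TOKENIZER` path, which is machine-specific.

```python
# Minimal usage sketch (illustration only); "gpt2" is an arbitrary Hugging Face
# tokenizer name and not part of this diff.
from colossalai.inference.dynamic_batching.get_tokenizer import get_tokenizer

tokenizer = get_tokenizer(tokenizer_name="gpt2", use_fast=True)
print(tokenizer("Hello, world!").input_ids)
```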
15 changes: 12 additions & 3 deletions colossalai/inference/dynamic_batching/io_struct.py
@@ -4,7 +4,7 @@


class Req:
def __init__(self, request_id, prompt_ids, sample_params: SamplingParams):
def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str):
self.request_id = request_id
self.prompt_ids = prompt_ids
self.input_len = len(prompt_ids)
@@ -14,6 +14,7 @@ def __init__(self, request_id, prompt_ids, sample_params: SamplingParams):
self.output_metadata_list = []
self.has_generate_finished = False
self.aborted = False
self.prompts = prompts

def to_rpc_obj(self):
return {
@@ -36,7 +37,11 @@ def stop_sequences_matched(self):
if self.sample_params.stop_sequences is not None:
for stop_token_ids in self.sample_params.stop_sequences:
stop_len = len(stop_token_ids)
if stop_len > 0 and len(self.output_ids) >= stop_len and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)):
if (
stop_len > 0
and len(self.output_ids) >= stop_len
and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len))
):
return True
return False

@@ -102,17 +107,21 @@ def mark_finished_req(self, eos_id):
has_new_finish = True
return has_new_finish

def filter_finished(self):
def filter_finished(self) -> List[Req]:
"""
        Filter finished requests from the batch; the finished ones are removed from 'reqs' and returned.
"""
# TODO: the logic of return should be defined here.
unfinished_req = []
finished_req = []
for req in self.reqs:
if not req.has_generate_finished:
unfinished_req.append(req)
else:
finished_req.append(req)
self.reqs = unfinished_req
self.id_to_reqs = {req.request_id: req for req in self.reqs}
return finished_req

def is_clear(self):
return len(self.reqs) == 0
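
Illustration of the updated interface (the request id, token ids, and the default `SamplingParams()` below are placeholders, not values from this PR): a `Req` now carries the raw prompt string alongside its token ids, and `filter_finished()` hands the completed requests back to the caller instead of discarding them.

```python
# Illustration only; token ids and a default-constructed SamplingParams() are assumptions.
from colossalai.inference.dynamic_batching.io_struct import Req
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams

req = Req(request_id=0, prompt_ids=[1, 2, 3], sample_params=SamplingParams(), prompts="Hello")
print(req.prompts, req.input_len)  # "Hello" 3
```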
164 changes: 164 additions & 0 deletions colossalai/inference/dynamic_batching/ray_dist_init.py
@@ -0,0 +1,164 @@
import asyncio
import logging
import os
from typing import List

import ray
import ray.util.collective as collective
import torch
from transformers import AutoModelForCausalLM

import colossalai
from colossalai.inference.dynamic_batching.get_tokenizer import get_tokenizer
from colossalai.inference.dynamic_batching.ray_init_config import EngineArgsClass, RooterArgsClass
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.manager import start_dynamic_batching
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.shardformer import ShardConfig
from colossalai.testing import free_port

ray_serve_logger = logging.getLogger("ray.serve")


def log_cuda_info(scope_name: str):
ray_serve_logger.info(f" {scope_name}: ray.get_gpu_ids(): {ray.get_gpu_ids()}")
ray_serve_logger.info(
f" {scope_name}: CUDA_VISIBLE_DEVICES: {os.getenv('CUDA_VISIBLE_DEVICES', 'NO DEVICES FOUND!')}"
)
if torch.cuda.is_available():
ray_serve_logger.info(
f" {scope_name}: cuda current_device: {torch.cuda.current_device()}, cuda device count: {torch.cuda.device_count()}"
)
else:
ray_serve_logger.info(f" {scope_name}: cuda is not available!")


@ray.remote(num_gpus=1)
class Worker:
def __init__(
self,
model_path: str,
tensor_parallel_size: int,
max_batch_size: int,
max_input_len: int,
max_output_len: int,
router_config: RooterArgsClass,
):
log_cuda_info("Worker.init")
self.tensor_parallel_size = tensor_parallel_size
self.model_path = model_path
self.max_batch_size = max_batch_size
self.max_input_len = max_input_len
self.max_output_len = max_output_len
self.router_config = router_config

def setup(self, world_size, rank, port):
# initialize a ray collective group, otherwise colossalai distributed env won't be built successfully
collective.init_collective_group(world_size, rank, "nccl", "default")
# initialize and set distributed environment
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
ray_serve_logger.info(f"Worker with rank {rank} (world size {world_size}) setting up..")
log_cuda_info("Worker.setup")

# Load model
self.tokenizer = get_tokenizer(tokenizer_name=self.model_path)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.model = AutoModelForCausalLM.from_pretrained(
self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16
)

        shard_config = ShardConfig(enable_tensor_parallelism=(world_size > 1), inference_only=True)
self.infer_engine = TPInferEngine(
self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
)
self.start_dynamic_batching = start_dynamic_batching(self.router_config, self.infer_engine, [])

return True

def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams) -> str:
ray_serve_logger.info(f"text: {prompt}")

results_generator = self.start_dynamic_batching.generate(prompt, sampling_params, request_id)

final_output = None
for request_output in results_generator:
final_output = request_output

assert final_output is not None
ray_serve_logger.info(f"Generated text: {final_output}")
return final_output

def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams):
self.start_dynamic_batching.add_input(request_id, sampling_params, prompt)

def abort(self, request_id: str):
self.start_dynamic_batching.abort(request_id)

def step(self):
self.start_dynamic_batching._step()

def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str):
self.start_dynamic_batching.add_req(prompt_ids, sampling_params, request_id, prompt)


class Driver:
def __init__(self, router_config: RooterArgsClass, engine_config: EngineArgsClass):
log_cuda_info("Driver:init")
model_path = engine_config.model
tensor_parallel_size = engine_config.tensor_parallel_size

self.num_workers = tensor_parallel_size
self.workers = []
init_rets = []

# Just grab a free port on localhost
# NOTE workers in this communication group listen to the same port
available_port = free_port()

for i in range(self.num_workers):
worker_name = "worker_idx_{}".format(i)
w = Worker.options(name=worker_name).remote(
model_path,
self.num_workers,
engine_config.max_batch_size,
engine_config.max_input_len,
engine_config.max_output_len,
router_config,
)
self.workers.append(w)
init_rets.append(w.setup.remote(self.num_workers, i, available_port))
_options = {
"group_name": "default_driver",
"world_size": self.num_workers,
"ranks": [i for i in range(self.num_workers)],
"backend": "nccl",
}
collective.create_collective_group(self.workers, **_options)
_ = ray.get(init_rets)

# set batch wait delay in seconds and maximum number of sequences in a batch
def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams):
results = ray.get([w.generate.remote(request_id, prompt, sampling_params) for w in self.workers])
text_res = results[0] # get any one of the copies
return text_res

async def async_generate(self, request_id: str, prompt: str, sampling_params: SamplingParams):
all_outputs = []
for worker in self.workers:
all_outputs.append(worker.generate.remote(request_id, prompt, sampling_params))
all_outputs = await asyncio.gather(*all_outputs)
text_res = all_outputs[0] # get any one of the copies
return text_res

    def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams):
        # pass arguments in Worker.add_input's declared order (request_id, prompt, sampling_params)
        ray.get([w.add_input.remote(request_id, prompt, sampling_params) for w in self.workers])

def abort(self, request_id: str):
ray.get([w.abort.remote(request_id) for w in self.workers])

    def step(self):
        # Worker exposes step(), not _step(); call the public method on each actor
        ray.get([w.step.remote() for w in self.workers])

def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str):
ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers])
58 changes: 58 additions & 0 deletions colossalai/inference/dynamic_batching/ray_init_config.py
@@ -0,0 +1,58 @@
import logging

import yaml
from pydantic import BaseModel

logger = logging.getLogger(__name__)


class EngineArgsClass(BaseModel):
"""Config for Engine"""

model: str
tensor_parallel_size: int = 2
max_batch_size: int = 4
max_input_len: int = 128
max_output_len: int = 32


class RooterArgsClass(BaseModel):
"""Config for Rooter"""

max_total_token_num: int = 42
batch_max_tokens: int = 42
eos_id: int = 0
disable_log_stats: bool = False
log_stats_interval: int = 10
model: str


class RayInitConfig(BaseModel):
"""All-together configs without app router config"""

engine_config_data: EngineArgsClass
router_config_data: RooterArgsClass

@classmethod
def from_yaml_path(cls, path: str):
try:
with open(path, "r") as yaml_file:
try:
config = yaml.safe_load(yaml_file)
# serve deployment config
engine_config = config.get("engine_config", {})
router_config = config.get("router_config", {})

return cls(
engine_config_data=engine_config,
router_config_data=router_config,
)
except yaml.YAMLError as e:
logger.error(f"An Error occurred when parsing yaml: {e}")
raise
except FileNotFoundError:
logger.error(f"The file '{path}' does not exist!")
raise
except OSError as e:
logger.error(f"An Error occurred: {e}")
raise
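
For reference, `from_yaml_path` expects a YAML document with top-level `engine_config` and `router_config` mappings. A self-contained loading sketch follows; the model path and file location are placeholders, and omitted fields fall back to the pydantic defaults.

```python
# Self-contained sketch; paths and values are placeholders, not part of this diff.
from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig

yaml_text = """
engine_config:
  model: /path/to/llama-7b
  tensor_parallel_size: 2
router_config:
  model: /path/to/llama-7b
  max_total_token_num: 4096
  batch_max_tokens: 4096
"""
with open("/tmp/ray_config.yaml", "w") as f:
    f.write(yaml_text)

config = RayInitConfig.from_yaml_path("/tmp/ray_config.yaml")
print(config.engine_config_data.model, config.router_config_data.eos_id)  # eos_id falls back to 0
```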