From 4a7f9295aa8b560241474954e957aba28a33d603 Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Mon, 28 Aug 2023 17:56:41 +0800 Subject: [PATCH 1/9] add engine for TP inference --- colossalai/inference/__init__.py | 0 .../inference/tensor_parallel/__init__.py | 0 .../inference/tensor_parallel/engine.py | 250 ++++++++++++++++++ 3 files changed, 250 insertions(+) create mode 100644 colossalai/inference/__init__.py create mode 100644 colossalai/inference/tensor_parallel/__init__.py create mode 100644 colossalai/inference/tensor_parallel/engine.py diff --git a/colossalai/inference/__init__.py b/colossalai/inference/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/inference/tensor_parallel/__init__.py b/colossalai/inference/tensor_parallel/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py new file mode 100644 index 000000000000..0189f6d61b7c --- /dev/null +++ b/colossalai/inference/tensor_parallel/engine.py @@ -0,0 +1,250 @@ +from functools import partial +from types import MethodType +from typing import Any, Callable, List, Optional, Set, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from transformers.generation import GenerationConfig +from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList +from transformers.tokenization_utils_base import BatchEncoding + +from colossalai.cluster import ProcessGroupMesh +from colossalai.shardformer import ShardConfig, ShardFormer +# from colossalai.shardformer.policies.bloom import BloomModelInferPolicy +from colossalai.shardformer.policies.auto_policy import get_autopolicy + +from .batch_infer_state import BatchInferState +from .kvcache_manager import MemoryManager + +DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 + + +class InferenceEngine: + + def __init__(self, model: nn.Module, max_batch_size, max_input_len, max_output_len, tp_size=1) -> None: + self.model = model + self.sharded_model = None + + self.max_batch_size = max_batch_size + self.max_input_len = max_input_len + self.max_output_len = max_output_len + self.max_total_token_num = self.max_batch_size * (self.max_input_len + self.max_output_len) + assert self.max_batch_size <= 64 + assert self.max_input_len + self.max_output_len <= 2048 + + self.tp_size = tp_size + self.pp_size = 1 # only consider tp for now + self.dp_size = 1 # only consider tp for now + + self.head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads + self.head_num = self.model.config.num_attention_heads // self.tp_size + self.layer_num = self.model.config.num_hidden_layers + self.cache_manager = MemoryManager(self.max_total_token_num, torch.float16, self.head_num, self.head_dim, + self.layer_num) + + # self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size) + # self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) + + def shard_model_by(self, shardformer: ShardFormer) -> None: + # TODO Might want to use infer policy only when bs >= 4 + assert self.tp_size == shardformer.shard_config.tensor_parallel_size, "Engine tp size != shardformer tp size" + # shardformer.shard_config.tensor_parallel_process_group = self.tp_group + model_name = self.model.__class__.__name__ + policy = get_autopolicy(self.model, inference_only=True) + if model_name == 'LlamaForCausalLM': + self.sharded_model, _ = shardformer.optimize(self.model, policy) + elif model_name == 
'BloomForCausalLM': + self.sharded_model, _ = shardformer.optimize(self.model, policy) + else: + raise ValueError(f'Unsupported model "{model_name}" for inference') + self.sharded_model = self.sharded_model.cuda() + + # NOTE input_tokens is expected to be BatchEncoding, + # instead of only input token ids + @torch.no_grad() + def generate_by_pass_infer_state(self, + input_tokens, + max_out_length: int, + generation_config: Optional[GenerationConfig] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, + **model_kwargs) -> torch.Tensor: + + input_ids = input_tokens['input_ids'] + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + + if batch_size >= 4: + assert self.sharded_model is not None, "sharded model does not exist" + + batch_infer_state = self.prepare_batch_state(input_tokens) + batch_size = batch_infer_state.batch_size + assert batch_infer_state.max_len_in_batch <= self.max_input_len + + # record sequences finish status, add early stopping, etc, + + for _ in range(min(max_out_length, self.max_output_len)): + # ... + self.sharded_model.forward(..., **model_kwargs) + else: + # Use original model + orig_model = self.model + + for _ in range(min(max_out_length, self.max_output_len)): + + if prepare_inputs_fn is None and hasattr(orig_model, 'prepare_inputs_for_generation'): + prepare_inputs_fn = orig_model.prepare_inputs_for_generation + + model_inputs = prepare_inputs_fn(input_ids, ** + model_kwargs) if prepare_inputs_fn is not None else input_tokens + outputs = orig_model(**model_inputs) + + # next_token_logits = outputs['logits'][:, -1, :] + next_token_logits = outputs.logits[:, -1, :] + # pre-process distribution + # next_token_logits = logits_processor(input_ids, next_token_logits) + + # sample + # probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float) + # next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + + # consider greedy only for now + next_tokens = torch.argmax(next_token_logits, dim=-1) + + # finished sentences should have their next token be a padding token + + # if eos_token_id is not None: + # if pad_token_id is None: + # raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + # next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # # update generated ids, model inputs for next step + # input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + # if update_model_kwargs_fn is not None: + # model_kwargs = update_model_kwargs_fn(outputs, model_kwargs) + + # # if eos_token was found in one sentence, set sentence to finished + # if eos_token_id is not None: + # unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) + + # # stop when each sentence is finished if early_stopping=True + # if early_stopping and _is_sequence_finished(unfinished_sequences): + # break + + return input_ids + + @torch.no_grad() + def generate_by_set_infer_state(self, input_tokens, generate_kwargs, early_stopping=False): + + # for testing, always use sharded model + assert self.sharded_model is not None, "sharded model does not exist" + + batch_infer_state = self.prepare_batch_state(input_tokens) + assert batch_infer_state.max_len_in_batch <= self.max_input_len, "max length in batch exceeds limit" + + # set BatchInferState for the current batch as attr to model + # NOTE this is not an expectable way to pass BatchInferState during inference + # we might 
want to rewrite generate function (e.g. generate_by_pass_infer_state) + # and pass BatchInferState via model forward + if hasattr(self.sharded_model, 'model'): + model = self.sharded_model.model + elif hasattr(self.sharded_model, 'transformer'): + model = self.sharded_model.transformer + setattr(model, 'infer_state', batch_infer_state) + + # add logging + generate_kwargs.update(max_new_tokens=self.max_output_len) + + # convert to dict + if isinstance(input_tokens, torch.Tensor): + input_tokens = dict(input_ids=input_tokens) + for t in input_tokens: + if torch.is_tensor(input_tokens[t]): + input_tokens[t] = input_tokens[t].to(torch.cuda.current_device()) + print(f" input_tokens[{t}].shape: {input_tokens[t].shape}") + + outputs = self.sharded_model.generate(**input_tokens, **generate_kwargs, early_stopping=early_stopping) + + print(f"outputs.shape {outputs.shape}") + return outputs + + # inputs should be one of the following types + # 1. BatchEncoding (e.g. tokenizer batch_encode) + # 2. list of input token ids (e.g. appended result of tokenizer encode) + # 3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt') + # NOTE For torch.Tensor inputs representing a batch of inputs, we are unable to retrieve + # the actual length (e.g. number of tokens) of each input without attention mask + # Hence, for torch.Tensor with shape [bs, l] where bs > 1, we will assume + # all the inputs in the batch has the maximum length l + def prepare_batch_state(self, inputs: [BatchEncoding, torch.Tensor]) -> BatchInferState: + # records length based on attention mask + # Any better method? + if not isinstance(inputs, (BatchEncoding, list, torch.Tensor)): + raise TypeError(f"inputs type {type(inputs)} is not supported in prepare_batch_state") + + if isinstance(inputs, BatchEncoding): + attn_masks = inputs['attention_mask'] + batch_size = attn_masks.shape[0] + max_len_in_batch = attn_masks.shape[1] + elif isinstance(inputs, list): + batch_size = len(inputs) + else: + batch_size = inputs.shape[0] + + # block_loc = torch.empty(batch_size, self.max_input_len + self.max_output_len, dtype=torch.long, device="cuda") + seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + start_index = 0 + + max_len_in_batch = -1 + if isinstance(inputs, BatchEncoding): + for i, attn_mask in enumerate(attn_masks): + curr_seq_len = torch.sum(attn_mask) + seq_lengths[i] = curr_seq_len + seq_start_indexes[i] = start_index + start_index += curr_seq_len + max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch + else: + for i, input_ids in enumerate(inputs): + curr_seq_len = len(input_ids) + seq_lengths[i] = curr_seq_len + seq_start_indexes[i] = start_index + start_index += curr_seq_len + max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch + + block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device="cuda") + batch_infer_state = BatchInferState(batch_size, max_len_in_batch) + batch_infer_state.seq_len = seq_lengths.to('cuda') # might want to assign specific device + batch_infer_state.start_loc = seq_start_indexes.to('cuda') + batch_infer_state.block_loc = block_loc + # NOTE BatchInferState.total_token_num revised (not pushed yet) + # Now we want actual total token num based on seq_len, instead of dummy ones in test + # (Could still use the dummy one for testing usage) + 
batch_infer_state.set_cache_manager(self.cache_manager) + batch_infer_state.decode_layer_id = 0 + batch_infer_state.past_key_values_len = 0 + batch_infer_state.is_context_stage = True + return batch_infer_state + + # BatchInferState is created and kept during generation + # after each iter of model forward, we should update BatchInferState + # NOTE use in rewritten generate method: use after model.forward + def update_batch_state(self, infer_state: Optional[BatchInferState]) -> None: + # self.b_start_loc = self.b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + # self.b_seq_len += 1 + batch_size = infer_state.batch_size + device = infer_state.start_loc.device + infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device=device) + infer_state.seq_len += 1 + + # TODO might want to create a sequence pool + # add a single request/sequence/input text at a time and record its length + # In other words, store the actual length of input tokens representing a single input text + # E.g. "Introduce landmarks in Beijing" + # => add request + # => record token length and other necessary information to be used + # => engine hold all these necessary information until `generate` (or other name) is called, + # => put information already recorded in batchinferstate and pass it to model forward + # => clear records in engine + def add_request(): + pass From e0a38e2d2e433407fc20e5c7a1a9e0e556c98f6c Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Mon, 28 Aug 2023 18:01:12 +0800 Subject: [PATCH 2/9] move file path --- .../inference => inference/tensor_parallel}/batch_infer_state.py | 0 .../inference => inference/tensor_parallel}/kvcache_manager.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename colossalai/{shardformer/inference => inference/tensor_parallel}/batch_infer_state.py (100%) rename colossalai/{shardformer/inference => inference/tensor_parallel}/kvcache_manager.py (100%) diff --git a/colossalai/shardformer/inference/batch_infer_state.py b/colossalai/inference/tensor_parallel/batch_infer_state.py similarity index 100% rename from colossalai/shardformer/inference/batch_infer_state.py rename to colossalai/inference/tensor_parallel/batch_infer_state.py diff --git a/colossalai/shardformer/inference/kvcache_manager.py b/colossalai/inference/tensor_parallel/kvcache_manager.py similarity index 100% rename from colossalai/shardformer/inference/kvcache_manager.py rename to colossalai/inference/tensor_parallel/kvcache_manager.py From eee013f01022d12b707a05ddfb986f242a8a999f Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Mon, 28 Aug 2023 18:18:07 +0800 Subject: [PATCH 3/9] update path --- colossalai/inference/tensor_parallel/__init__.py | 4 ++++ colossalai/inference/tensor_parallel/engine.py | 2 +- tests/test_infer/test_kvcache_manager.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/colossalai/inference/tensor_parallel/__init__.py b/colossalai/inference/tensor_parallel/__init__.py index e69de29bb2d1..e467b4c73e6b 100644 --- a/colossalai/inference/tensor_parallel/__init__.py +++ b/colossalai/inference/tensor_parallel/__init__.py @@ -0,0 +1,4 @@ +from .engine import TPInferEngine +from .kvcache_manager import MemoryManager + +__all__ = ['MemoryManager', 'TPInferEngine'] diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 0189f6d61b7c..95cb62c11513 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ 
b/colossalai/inference/tensor_parallel/engine.py @@ -20,7 +20,7 @@ DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 -class InferenceEngine: +class TPInferEngine: def __init__(self, model: nn.Module, max_batch_size, max_input_len, max_output_len, tp_size=1) -> None: self.model = model diff --git a/tests/test_infer/test_kvcache_manager.py b/tests/test_infer/test_kvcache_manager.py index ef48444f73ca..fb04d7800ea2 100644 --- a/tests/test_infer/test_kvcache_manager.py +++ b/tests/test_infer/test_kvcache_manager.py @@ -3,8 +3,8 @@ import pytest import torch +from colossalai.inference.tensor_parallel import MemoryManager from colossalai.logging import disable_existing_loggers -from colossalai.shardformer.inference import MemoryManager from colossalai.testing import rerun_if_address_is_in_use, spawn BATCH_SIZE = 4 From f986e88a9283c788cc4e04d4ab7e6175c6481ce9 Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Mon, 28 Aug 2023 20:11:55 +0800 Subject: [PATCH 4/9] fix TPInferEngine --- .../tensor_parallel/batch_infer_state.py | 5 +- .../inference/tensor_parallel/engine.py | 217 ++++++++---------- 2 files changed, 105 insertions(+), 117 deletions(-) diff --git a/colossalai/inference/tensor_parallel/batch_infer_state.py b/colossalai/inference/tensor_parallel/batch_infer_state.py index fef23a584b8b..2bff9317283e 100644 --- a/colossalai/inference/tensor_parallel/batch_infer_state.py +++ b/colossalai/inference/tensor_parallel/batch_infer_state.py @@ -21,6 +21,7 @@ class BatchInferState: block_loc: torch.Tensor = None start_loc: torch.Tensor = None seq_len: torch.Tensor = None + past_key_values_len: int = None is_context_stage: bool = False context_mem_index: torch.Tensor = None @@ -34,7 +35,9 @@ class BatchInferState: @property def total_token_num(self): - return self.batch_size * self.max_len_in_batch + # return self.batch_size * self.max_len_in_batch + assert self.seq_len is not None and self.seq_len.size(0) > 0 + return int(torch.sum(self.seq_len)) def set_cache_manager(self, manager: MemoryManager): self.cache_manager = manager diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 95cb62c11513..f693a6a21ec6 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -1,17 +1,13 @@ -from functools import partial -from types import MethodType from typing import Any, Callable, List, Optional, Set, Union import torch -import torch.distributed as dist import torch.nn as nn from transformers.generation import GenerationConfig -from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList +from transformers.generation.stopping_criteria import StoppingCriteriaList from transformers.tokenization_utils_base import BatchEncoding from colossalai.cluster import ProcessGroupMesh from colossalai.shardformer import ShardConfig, ShardFormer -# from colossalai.shardformer.policies.bloom import BloomModelInferPolicy from colossalai.shardformer.policies.auto_policy import get_autopolicy from .batch_infer_state import BatchInferState @@ -22,7 +18,13 @@ class TPInferEngine: - def __init__(self, model: nn.Module, max_batch_size, max_input_len, max_output_len, tp_size=1) -> None: + def __init__(self, + model: nn.Module, + max_batch_size, + max_input_len, + max_output_len, + dtype=torch.float16, + tp_size=1) -> None: self.model = model self.sharded_model = None @@ -30,111 +32,68 @@ def __init__(self, model: nn.Module, max_batch_size, max_input_len, max_output_l self.max_input_len = 
max_input_len self.max_output_len = max_output_len self.max_total_token_num = self.max_batch_size * (self.max_input_len + self.max_output_len) + + # Constraints relatable with specs of devices assert self.max_batch_size <= 64 assert self.max_input_len + self.max_output_len <= 2048 + # NOTE For now, we focus on tensor parallel. + # We might want to merge pp and dp inference in future. self.tp_size = tp_size - self.pp_size = 1 # only consider tp for now - self.dp_size = 1 # only consider tp for now + self.pp_size = 1 + self.dp_size = 1 + self.dtype = dtype self.head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads self.head_num = self.model.config.num_attention_heads // self.tp_size self.layer_num = self.model.config.num_hidden_layers - self.cache_manager = MemoryManager(self.max_total_token_num, torch.float16, self.head_num, self.head_dim, + self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num) - - # self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size) - # self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size) + + def create_shard_config(self) -> ShardConfig: + """ create a ShardConfig consistent with configs and attributes of the engine """ + shard_config = ShardConfig( + tensor_parallel_process_group=None, + pipeline_stage_manager=None, + enable_tensor_parallelism=False, + enable_fused_normalization=False, + enable_all_optimization=False, + enable_flash_attention=False, + enable_jit_fused=False, + inference_only=True, + ) + if self.tp_size > 1: + shard_config.enable_tensor_parallelism = True + tp_process_group = self.pg_mesh.get_group_along_axis(TP_AXIS) + shard_config.tensor_parallel_process_group = tp_process_group + return shard_config def shard_model_by(self, shardformer: ShardFormer) -> None: - # TODO Might want to use infer policy only when bs >= 4 - assert self.tp_size == shardformer.shard_config.tensor_parallel_size, "Engine tp size != shardformer tp size" - # shardformer.shard_config.tensor_parallel_process_group = self.tp_group + """ Shard the model and store the sharded model by given ShardFormer """ + assert self.tp_size == shardformer.shard_config.tensor_parallel_size, \ + "Discrepancy between the tp size of TPInferEngine and the tp size of shard config" model_name = self.model.__class__.__name__ + assert model_name in self._supported_model(), f"Unsupported model cls {model_name} for TP inference." 
policy = get_autopolicy(self.model, inference_only=True) - if model_name == 'LlamaForCausalLM': - self.sharded_model, _ = shardformer.optimize(self.model, policy) - elif model_name == 'BloomForCausalLM': - self.sharded_model, _ = shardformer.optimize(self.model, policy) - else: - raise ValueError(f'Unsupported model "{model_name}" for inference') + self.sharded_model, _ = shardformer.optimize(self.model, policy) self.sharded_model = self.sharded_model.cuda() - # NOTE input_tokens is expected to be BatchEncoding, - # instead of only input token ids - @torch.no_grad() - def generate_by_pass_infer_state(self, - input_tokens, - max_out_length: int, - generation_config: Optional[GenerationConfig] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, - **model_kwargs) -> torch.Tensor: - - input_ids = input_tokens['input_ids'] - unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) - - if batch_size >= 4: - assert self.sharded_model is not None, "sharded model does not exist" - - batch_infer_state = self.prepare_batch_state(input_tokens) - batch_size = batch_infer_state.batch_size - assert batch_infer_state.max_len_in_batch <= self.max_input_len - - # record sequences finish status, add early stopping, etc, - - for _ in range(min(max_out_length, self.max_output_len)): - # ... - self.sharded_model.forward(..., **model_kwargs) - else: - # Use original model - orig_model = self.model - - for _ in range(min(max_out_length, self.max_output_len)): - - if prepare_inputs_fn is None and hasattr(orig_model, 'prepare_inputs_for_generation'): - prepare_inputs_fn = orig_model.prepare_inputs_for_generation - - model_inputs = prepare_inputs_fn(input_ids, ** - model_kwargs) if prepare_inputs_fn is not None else input_tokens - outputs = orig_model(**model_inputs) - - # next_token_logits = outputs['logits'][:, -1, :] - next_token_logits = outputs.logits[:, -1, :] - # pre-process distribution - # next_token_logits = logits_processor(input_ids, next_token_logits) - - # sample - # probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float) - # next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) - - # consider greedy only for now - next_tokens = torch.argmax(next_token_logits, dim=-1) - - # finished sentences should have their next token be a padding token - - # if eos_token_id is not None: - # if pad_token_id is None: - # raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") - # next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) - - # # update generated ids, model inputs for next step - # input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - # if update_model_kwargs_fn is not None: - # model_kwargs = update_model_kwargs_fn(outputs, model_kwargs) - - # # if eos_token was found in one sentence, set sentence to finished - # if eos_token_id is not None: - # unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) - - # # stop when each sentence is finished if early_stopping=True - # if early_stopping and _is_sequence_finished(unfinished_sequences): - # break - - return input_ids + def _supported_model(self) -> List[str]: + supported_models = ['LlamaForCausalLM', 'LlamaModel', 'BloomForCausalLM'] + return supported_models @torch.no_grad() - def generate_by_set_infer_state(self, input_tokens, generate_kwargs, early_stopping=False): + def generate_by_set_infer_state(self, 
input_tokens, generate_kwargs, early_stopping=False) -> torch.Tensor: + """ + Generate output tokens by setting BatchInferState as an attribute to the model and calling model.generate + + Args: + inputs: should be one of the following types + 1. BatchEncoding (e.g. tokenizer batch_encode) + 2. list of input token ids (e.g. appended result of tokenizer encode) + 3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt') + """ # for testing, always use sharded model assert self.sharded_model is not None, "sharded model does not exist" @@ -152,33 +111,36 @@ def generate_by_set_infer_state(self, input_tokens, generate_kwargs, early_stopp model = self.sharded_model.transformer setattr(model, 'infer_state', batch_infer_state) - # add logging generate_kwargs.update(max_new_tokens=self.max_output_len) - # convert to dict if isinstance(input_tokens, torch.Tensor): input_tokens = dict(input_ids=input_tokens) for t in input_tokens: if torch.is_tensor(input_tokens[t]): input_tokens[t] = input_tokens[t].to(torch.cuda.current_device()) - print(f" input_tokens[{t}].shape: {input_tokens[t].shape}") outputs = self.sharded_model.generate(**input_tokens, **generate_kwargs, early_stopping=early_stopping) print(f"outputs.shape {outputs.shape}") return outputs - # inputs should be one of the following types - # 1. BatchEncoding (e.g. tokenizer batch_encode) - # 2. list of input token ids (e.g. appended result of tokenizer encode) - # 3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt') - # NOTE For torch.Tensor inputs representing a batch of inputs, we are unable to retrieve - # the actual length (e.g. number of tokens) of each input without attention mask - # Hence, for torch.Tensor with shape [bs, l] where bs > 1, we will assume - # all the inputs in the batch has the maximum length l def prepare_batch_state(self, inputs: [BatchEncoding, torch.Tensor]) -> BatchInferState: - # records length based on attention mask - # Any better method? + """ + Create and prepare BatchInferState used for inference during model forward, + by processing each sequence of the given inputs + + Args: + inputs: should be one of the following types + 1. BatchEncoding (e.g. tokenizer batch_encode) + 2. list of input token ids (e.g. appended result of tokenizer encode) + 3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt') + NOTE For torch.Tensor inputs representing a batch of inputs, we are unable to retrieve + the actual length (e.g. 
number of tokens) of each input without attention mask + Hence, for torch.Tensor with shape [bs, l] where bs > 1, we will assume + all the inputs in the batch have the maximum length l + Returns: + BatchInferState: the states for the current batch during inference + """ if not isinstance(inputs, (BatchEncoding, list, torch.Tensor)): raise TypeError(f"inputs type {type(inputs)} is not supported in prepare_batch_state") @@ -191,7 +153,6 @@ def prepare_batch_state(self, inputs: [BatchEncoding, torch.Tensor]) -> BatchInf else: batch_size = inputs.shape[0] - # block_loc = torch.empty(batch_size, self.max_input_len + self.max_output_len, dtype=torch.long, device="cuda") seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") start_index = 0 @@ -217,21 +178,45 @@ def prepare_batch_state(self, inputs: [BatchEncoding, torch.Tensor]) -> BatchInf batch_infer_state.seq_len = seq_lengths.to('cuda') # might want to assign specific device batch_infer_state.start_loc = seq_start_indexes.to('cuda') batch_infer_state.block_loc = block_loc - # NOTE BatchInferState.total_token_num revised (not pushed yet) - # Now we want actual total token num based on seq_len, instead of dummy ones in test - # (Could still use the dummy one for testing usage) - batch_infer_state.set_cache_manager(self.cache_manager) batch_infer_state.decode_layer_id = 0 batch_infer_state.past_key_values_len = 0 batch_infer_state.is_context_stage = True + batch_infer_state.set_cache_manager(self.cache_manager) return batch_infer_state + # TODO might want to implement the func that generates output tokens by passing BatchInferState + # as an arg into model.forward + # requires rewriting model generate and replacing model forward + @torch.no_grad() + def generate_by_pass_infer_state(self, + input_tokens, + max_out_length: int, + generation_config: Optional[GenerationConfig] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, + **model_kwargs) -> torch.Tensor: + + # input_ids = input_tokens['input_ids'] + + # if batch_size >= 4: + # assert self.sharded_model is not None, "sharded model does not exist" + + # batch_infer_state = self.prepare_batch_state(input_tokens) + # batch_size = batch_infer_state.batch_size + # assert batch_infer_state.max_len_in_batch <= self.max_input_len + # # record sequences finish status, add early stopping, etc, + # for _ in range(min(max_out_length, self.max_output_len)): + # # ... 
+ # self.sharded_model.forward(..., **model_kwargs) + # else: + # # Use original model + # orig_model = self.model + raise NotImplementedError("generate by passing BatchInferState is not implemented.") + + # NOTE might want to use in rewritten generate method: use after model.forward # BatchInferState is created and kept during generation # after each iter of model forward, we should update BatchInferState - # NOTE use in rewritten generate method: use after model.forward def update_batch_state(self, infer_state: Optional[BatchInferState]) -> None: - # self.b_start_loc = self.b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") - # self.b_seq_len += 1 batch_size = infer_state.batch_size device = infer_state.start_loc.device infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device=device) @@ -247,4 +232,4 @@ def update_batch_state(self, infer_state: Optional[BatchInferState]) -> None: # => put information already recorded in batchinferstate and pass it to model forward # => clear records in engine def add_request(): - pass + raise NotImplementedError() From aebb3f58d43cbb96660c60d3cad9507897d0ed05 Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Mon, 28 Aug 2023 20:21:34 +0800 Subject: [PATCH 5/9] remove unused file --- colossalai/shardformer/inference/__init__.py | 4 ---- 1 file changed, 4 deletions(-) delete mode 100644 colossalai/shardformer/inference/__init__.py diff --git a/colossalai/shardformer/inference/__init__.py b/colossalai/shardformer/inference/__init__.py deleted file mode 100644 index 1bce92653a8e..000000000000 --- a/colossalai/shardformer/inference/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .batch_infer_state import BatchInferState -from .kvcache_manager import MemoryManager - -__all__ = ['BatchInferState', 'MemoryManager'] From 182254766bc5d883f8d714435adf59c33af6501f Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Mon, 28 Aug 2023 20:24:49 +0800 Subject: [PATCH 6/9] add engine test demo --- tests/test_infer/test_infer_engine.py | 35 +++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 tests/test_infer/test_infer_engine.py diff --git a/tests/test_infer/test_infer_engine.py b/tests/test_infer/test_infer_engine.py new file mode 100644 index 000000000000..361d2148d641 --- /dev/null +++ b/tests/test_infer/test_infer_engine.py @@ -0,0 +1,35 @@ +import pytest +from transformers import AutoTokenizer, BloomForCausalLM + +import colossalai +from colossalai.inference.tensor_parallel import TPInferEngine +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig, ShardFormer + +TP_SIZE = 2 + + +def test_tp_infer(): + + model_path = "/data3/data/model_eval_for_commerical_use/phoenix-inst-chat-7b" + tokenizer = AutoTokenizer.from_pretrained(model_path) + tokenizer.pad_token = tokenizer.eos_token + model = BloomForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id) + + text = "Introduce some landmarks in Beijing" + input_ids = tokenizer.encode(text, return_tensors='pt') + + infer_engine = TPInferEngine(model.half(), 4, 12, 8, tp_size=TP_SIZE) + shard_config = infer_engine.create_shard_config() + shardformer = ShardFormer(shard_config=shard_config) + infer_engine.shard_model_by(shardformer) + + generate_kwargs = dict(do_sample=False) + outputs = infer_engine.generate_by_set_infer_state(input_ids, generate_kwargs) + + output_text = tokenizer.decode(outputs) + print(output_text) + + +if __name__ == '__main__': + 
test_tp_infer() From 278f7160b00a31546513cf576d272443d40c7f95 Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Tue, 29 Aug 2023 10:11:43 +0800 Subject: [PATCH 7/9] revise TPInferEngine --- .../inference/tensor_parallel/engine.py | 73 +++++++++++-------- tests/test_infer/test_infer_engine.py | 9 ++- 2 files changed, 51 insertions(+), 31 deletions(-) diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index f693a6a21ec6..08ff59de50fd 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -1,6 +1,7 @@ from typing import Any, Callable, List, Optional, Set, Union import torch +import torch.distributed as dist import torch.nn as nn from transformers.generation import GenerationConfig from transformers.generation.stopping_criteria import StoppingCriteriaList @@ -15,16 +16,18 @@ DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 +_supported_models = ['LlamaForCausalLM', 'LlamaModel', 'BloomForCausalLM'] + class TPInferEngine: def __init__(self, model: nn.Module, - max_batch_size, - max_input_len, - max_output_len, - dtype=torch.float16, - tp_size=1) -> None: + max_batch_size: int, + max_input_len: int, + max_output_len: int, + dtype: torch.dtype = torch.float16, + device: str = 'cuda') -> None: self.model = model self.sharded_model = None @@ -39,34 +42,47 @@ def __init__(self, # NOTE For now, we focus on tensor parallel. # We might want to merge pp and dp inference in future. - self.tp_size = tp_size - self.pp_size = 1 - self.dp_size = 1 + # self.tp_size = tp_size + # self.pp_size = 1 + # self.dp_size = 1 + torch.cuda.set_device(device=device) self.dtype = dtype self.head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads - self.head_num = self.model.config.num_attention_heads // self.tp_size + self.head_num = self.model.config.num_attention_heads self.layer_num = self.model.config.num_hidden_layers + + self.tp_size = -1 # toe be set with given shard config in self.prepare_shard_config + + def _init_manager(self) -> None: + assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig" + assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}" + self.head_num //= self.tp_size # update sharded number of heads self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num) - self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size) - - def create_shard_config(self) -> ShardConfig: - """ create a ShardConfig consistent with configs and attributes of the engine """ - shard_config = ShardConfig( - tensor_parallel_process_group=None, - pipeline_stage_manager=None, - enable_tensor_parallelism=False, - enable_fused_normalization=False, - enable_all_optimization=False, - enable_flash_attention=False, - enable_jit_fused=False, - inference_only=True, - ) - if self.tp_size > 1: - shard_config.enable_tensor_parallelism = True - tp_process_group = self.pg_mesh.get_group_along_axis(TP_AXIS) - shard_config.tensor_parallel_process_group = tp_process_group + + def prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) -> ShardConfig: + """ Prepare the engine with a given ShardConfig, or create a default one with tp size 1 """ + if shard_config is None: + shard_config = ShardConfig( + tensor_parallel_process_group=None, + pipeline_stage_manager=None, + enable_tensor_parallelism=False, + enable_fused_normalization=False, 
+ enable_all_optimization=False, + enable_flash_attention=False, + enable_jit_fused=False, + inference_only=True, + ) + self.tp_size = 1 + else: + shard_config.inference_only = True + shard_config.pipeline_stage_manager = None + if shard_config.enable_tensor_parallelism: + world_size = dist.get_world_size(shard_config.tensor_parallel_process_group) + self.tp_size = world_size + self._init_manager() + return shard_config def shard_model_by(self, shardformer: ShardFormer) -> None: @@ -80,8 +96,7 @@ def shard_model_by(self, shardformer: ShardFormer) -> None: self.sharded_model = self.sharded_model.cuda() def _supported_model(self) -> List[str]: - supported_models = ['LlamaForCausalLM', 'LlamaModel', 'BloomForCausalLM'] - return supported_models + return _supported_models @torch.no_grad() def generate_by_set_infer_state(self, input_tokens, generate_kwargs, early_stopping=False) -> torch.Tensor: diff --git a/tests/test_infer/test_infer_engine.py b/tests/test_infer/test_infer_engine.py index 361d2148d641..2d3cf37f7b52 100644 --- a/tests/test_infer/test_infer_engine.py +++ b/tests/test_infer/test_infer_engine.py @@ -1,4 +1,5 @@ import pytest +import torch.distributed as dist from transformers import AutoTokenizer, BloomForCausalLM import colossalai @@ -19,9 +20,13 @@ def test_tp_infer(): text = "Introduce some landmarks in Beijing" input_ids = tokenizer.encode(text, return_tensors='pt') - infer_engine = TPInferEngine(model.half(), 4, 12, 8, tp_size=TP_SIZE) - shard_config = infer_engine.create_shard_config() + tp_process_group = dist.new_group([0, 1]) + + infer_engine = TPInferEngine(model.half(), 4, 12, 8) + shard_config = ShardConfig(enable_tensor_parallelism=True, tensor_parallel_process_group=tp_process_group) shardformer = ShardFormer(shard_config=shard_config) + + infer_engine.prepare_with_shard_config(shard_config) infer_engine.shard_model_by(shardformer) generate_kwargs = dict(do_sample=False) From d1b2fd7d4c0c7a64e92a7d1a72b1fe9a24ed8728 Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Tue, 29 Aug 2023 17:46:55 +0800 Subject: [PATCH 8/9] fix TPInferEngine, add test --- .../inference/tensor_parallel/engine.py | 88 ++++++++++--------- .../shardformer/policies/auto_policy.py | 4 +- tests/test_infer/test_infer_engine.py | 66 ++++++++++---- 3 files changed, 97 insertions(+), 61 deletions(-) diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 08ff59de50fd..721447a7fb64 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -1,8 +1,8 @@ -from typing import Any, Callable, List, Optional, Set, Union +from typing import Any, Callable, Dict, List, Optional, Set, Union import torch -import torch.distributed as dist import torch.nn as nn +from transformers import BloomForCausalLM, LlamaForCausalLM from transformers.generation import GenerationConfig from transformers.generation.stopping_criteria import StoppingCriteriaList from transformers.tokenization_utils_base import BatchEncoding @@ -16,7 +16,7 @@ DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 -_supported_models = ['LlamaForCausalLM', 'LlamaModel', 'BloomForCausalLM'] +_supported_models = ['LlamaForCausalLM', 'BloomForCausalLM'] class TPInferEngine: @@ -27,7 +27,7 @@ def __init__(self, max_input_len: int, max_output_len: int, dtype: torch.dtype = torch.float16, - device: str = 'cuda') -> None: + device: torch.device = torch.cuda.current_device()) -> None: self.model = model self.sharded_model = None @@ -37,22 +37,18 @@ def 
__init__(self, self.max_total_token_num = self.max_batch_size * (self.max_input_len + self.max_output_len) # Constraints relatable with specs of devices - assert self.max_batch_size <= 64 - assert self.max_input_len + self.max_output_len <= 2048 - - # NOTE For now, we focus on tensor parallel. - # We might want to merge pp and dp inference in future. - # self.tp_size = tp_size - # self.pp_size = 1 - # self.dp_size = 1 - torch.cuda.set_device(device=device) + assert self.max_batch_size <= 64, "Max batch size exceeds the constraint" + assert self.max_input_len + self.max_output_len <= 2048, "Max length exceeds the constraint" + + self.device = device self.dtype = dtype self.head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads self.head_num = self.model.config.num_attention_heads self.layer_num = self.model.config.num_hidden_layers - self.tp_size = -1 # toe be set with given shard config in self.prepare_shard_config + self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config + self.cache_manager = None def _init_manager(self) -> None: assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig" @@ -63,6 +59,7 @@ def _init_manager(self) -> None: def prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) -> ShardConfig: """ Prepare the engine with a given ShardConfig, or create a default one with tp size 1 """ + self.tp_size = 1 if shard_config is None: shard_config = ShardConfig( tensor_parallel_process_group=None, @@ -74,13 +71,11 @@ def prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) enable_jit_fused=False, inference_only=True, ) - self.tp_size = 1 else: shard_config.inference_only = True shard_config.pipeline_stage_manager = None if shard_config.enable_tensor_parallelism: - world_size = dist.get_world_size(shard_config.tensor_parallel_process_group) - self.tp_size = world_size + self.tp_size = shard_config.tensor_parallel_size self._init_manager() return shard_config @@ -90,22 +85,31 @@ def shard_model_by(self, shardformer: ShardFormer) -> None: assert self.tp_size == shardformer.shard_config.tensor_parallel_size, \ "Discrepancy between the tp size of TPInferEngine and the tp size of shard config" model_name = self.model.__class__.__name__ - assert model_name in self._supported_model(), f"Unsupported model cls {model_name} for TP inference." + assert model_name in self._supported_models(), f"Unsupported model cls {model_name} for TP inference." 
policy = get_autopolicy(self.model, inference_only=True) self.sharded_model, _ = shardformer.optimize(self.model, policy) - self.sharded_model = self.sharded_model.cuda() + self.sharded_model = self.sharded_model.to(self.device) - def _supported_model(self) -> List[str]: + @staticmethod + def _supported_models() -> List[str]: return _supported_models + def generate(self, input_tokens, generate_kwargs) -> torch.Tensor: + if isinstance(input_tokens, torch.Tensor): + input_tokens = dict(input_ids=input_tokens, attention_mask=torch.ones_like(input_tokens, dtype=torch.bool)) + if self.sharded_model is not None: + return self.generate_by_set_infer_state(input_tokens, generate_kwargs) + + return self.model.generate(**input_tokens, **generate_kwargs) + @torch.no_grad() - def generate_by_set_infer_state(self, input_tokens, generate_kwargs, early_stopping=False) -> torch.Tensor: + def generate_by_set_infer_state(self, input_tokens, generate_kwargs) -> torch.Tensor: """ Generate output tokens by setting BatchInferState as an attribute to the model and calling model.generate Args: inputs: should be one of the following types - 1. BatchEncoding (e.g. tokenizer batch_encode) + 1. BatchEncoding or dict (e.g. tokenizer batch_encode) 2. list of input token ids (e.g. appended result of tokenizer encode) 3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt') """ # for testing, always use sharded model assert self.sharded_model is not None, "sharded model does not exist" @@ -120,9 +124,10 @@ def generate_by_set_infer_state(self, input_tokens, generate_kwargs, early_stopp # NOTE this is not an expectable way to pass BatchInferState during inference # we might want to rewrite generate function (e.g. generate_by_pass_infer_state) # and pass BatchInferState via model forward - if hasattr(self.sharded_model, 'model'): + model = self.sharded_model + if isinstance(model, LlamaForCausalLM): model = self.sharded_model.model - elif hasattr(self.sharded_model, 'transformer'): + elif isinstance(model, BloomForCausalLM): model = self.sharded_model.transformer setattr(model, 'infer_state', batch_infer_state) @@ -132,21 +137,21 @@ def generate_by_set_infer_state(self, input_tokens, generate_kwargs, early_stopp input_tokens = dict(input_ids=input_tokens) for t in input_tokens: if torch.is_tensor(input_tokens[t]): - input_tokens[t] = input_tokens[t].to(torch.cuda.current_device()) + input_tokens[t] = input_tokens[t].to(self.device) - outputs = self.sharded_model.generate(**input_tokens, **generate_kwargs, early_stopping=early_stopping) + outputs = self.sharded_model.generate(**input_tokens, **generate_kwargs, early_stopping=False) print(f"outputs.shape {outputs.shape}") return outputs - def prepare_batch_state(self, inputs: [BatchEncoding, torch.Tensor]) -> BatchInferState: + def prepare_batch_state(self, inputs) -> BatchInferState: """ Create and prepare BatchInferState used for inference during model forward, by processing each sequence of the given inputs Args: inputs: should be one of the following types - 1. BatchEncoding (e.g. tokenizer batch_encode) + 1. BatchEncoding or dict (e.g. tokenizer batch_encode) 2. list of input token ids (e.g. appended result of tokenizer encode) 3. torch.Tensor (e.g. 
tokenizer encode with return_tensors='pt') NOTE For torch.Tensor inputs representing a batch of inputs, we are unable to retrieve the actual length (e.g. number of tokens) of each input without attention mask Hence, for torch.Tensor with shape [bs, l] where bs > 1, we will assume all the inputs in the batch have the maximum length l Returns: BatchInferState: the states for the current batch during inference """ - if not isinstance(inputs, (BatchEncoding, list, torch.Tensor)): + if not isinstance(inputs, (BatchEncoding, dict, list, torch.Tensor)): raise TypeError(f"inputs type {type(inputs)} is not supported in prepare_batch_state") - if isinstance(inputs, BatchEncoding): + if isinstance(inputs, (BatchEncoding, dict)): attn_masks = inputs['attention_mask'] batch_size = attn_masks.shape[0] max_len_in_batch = attn_masks.shape[1] @@ -168,12 +173,12 @@ def prepare_batch_state(self, inputs) -> BatchInferState: else: batch_size = inputs.shape[0] - seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device=self.device) + seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device=self.device) start_index = 0 max_len_in_batch = -1 - if isinstance(inputs, BatchEncoding): + if isinstance(inputs, (BatchEncoding, dict)): for i, attn_mask in enumerate(attn_masks): curr_seq_len = torch.sum(attn_mask) seq_lengths[i] = curr_seq_len @@ -188,10 +193,14 @@ def prepare_batch_state(self, inputs) -> BatchInferState: start_index += curr_seq_len max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch + print(" 666 ", max_len_in_batch) + + block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), + dtype=torch.long, + device=self.device) batch_infer_state = BatchInferState(batch_size, max_len_in_batch) - batch_infer_state.seq_len = seq_lengths.to('cuda') # might want to assign specific device - batch_infer_state.start_loc = seq_start_indexes.to('cuda') + batch_infer_state.seq_len = seq_lengths.to(self.device) # might want to assign specific device + batch_infer_state.start_loc = seq_start_indexes.to(self.device) batch_infer_state.block_loc = block_loc batch_infer_state.decode_layer_id = 0 batch_infer_state.past_key_values_len = 0 @@ -210,12 +219,8 @@ def generate_by_pass_infer_state(self, stopping_criteria: Optional[StoppingCriteriaList] = None, prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, **model_kwargs) -> torch.Tensor: - - # input_ids = input_tokens['input_ids'] - # if batch_size >= 4: # assert self.sharded_model is not None, "sharded model does not exist" - # batch_infer_state = self.prepare_batch_state(input_tokens) # batch_size = batch_infer_state.batch_size # assert batch_infer_state.max_len_in_batch <= self.max_input_len # # record sequences finish status, add early stopping, etc, # for _ in range(min(max_out_length, self.max_output_len)): # # ... 
# self.sharded_model.forward(..., **model_kwargs) # else: - # # Use original model - # orig_model = self.model + # Use original model to generate raise NotImplementedError("generate by passing BatchInferState is not implemented.") # NOTE might want to use in rewritten generate method: use after model.forward diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 43ea1c5ab7f6..0ffa7fbeeab1 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -134,7 +134,9 @@ class PolicyLocation: _INFER_POLICY_LIST = { # LlaMa "transformers.models.llama.modeling_llama.LlamaModel": - PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy") + PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy"), + "transformers.models.llama.modeling_llama.LlamaForCausalLM": + PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy"), } diff --git a/tests/test_infer/test_infer_engine.py b/tests/test_infer/test_infer_engine.py index 2d3cf37f7b52..7fcb36554b90 100644 --- a/tests/test_infer/test_infer_engine.py +++ b/tests/test_infer/test_infer_engine.py @@ -1,40 +1,70 @@ import pytest -import torch.distributed as dist -from transformers import AutoTokenizer, BloomForCausalLM +import torch +from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM import colossalai from colossalai.inference.tensor_parallel import TPInferEngine from colossalai.logging import disable_existing_loggers from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn TP_SIZE = 2 +BATCH_SIZE = 4 +MAX_INPUT_LEN = 16 +MAX_OUTPUT_LEN = 8 -def test_tp_infer(): +def test_orig_generate(): + input_ids = torch.randint(low=10, high=1000, size=(BATCH_SIZE, MAX_INPUT_LEN)) - model_path = "/data3/data/model_eval_for_commerical_use/phoenix-inst-chat-7b" - tokenizer = AutoTokenizer.from_pretrained(model_path) - tokenizer.pad_token = tokenizer.eos_token - model = BloomForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id) + model_config = LlamaConfig() + model = LlamaForCausalLM(model_config) + shard_config = ShardConfig(enable_tensor_parallelism=False) - text = "Introduce some landmarks in Beijing" - input_ids = tokenizer.encode(text, return_tensors='pt') + # init TPInferEngine and + infer_engine = TPInferEngine(model, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + infer_engine.prepare_with_shard_config(shard_config) + + # original model generate + generate_kwargs = dict(do_sample=False) + infer_engine.generate(input_ids, generate_kwargs) - tp_process_group = dist.new_group([0, 1]) - infer_engine = TPInferEngine(model.half(), 4, 12, 8) - shard_config = ShardConfig(enable_tensor_parallelism=True, tensor_parallel_process_group=tp_process_group) +def run(): + model_config = LlamaConfig() + model = LlamaForCausalLM(model_config) + shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True) shardformer = ShardFormer(shard_config=shard_config) - infer_engine.prepare_with_shard_config(shard_config) + infer_engine = TPInferEngine(model, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + infer_engine.prepare_with_shard_config(shard_config=shard_config) infer_engine.shard_model_by(shardformer) - generate_kwargs = dict(do_sample=False) - outputs = infer_engine.generate_by_set_infer_state(input_ids, generate_kwargs) + assert infer_engine.cache_manager is not None + assert 
infer_engine.tp_size == TP_SIZE + assert infer_engine.head_num == model_config.num_attention_heads // TP_SIZE + + # TODO After adding forward replacement for CausalLM, + # uncomment these lines to test sharded model generate + # generate_kwargs = dict(do_sample=False) + # infer_engine.generate(input_ids, generate_kwargs) + + torch.cuda.empty_cache() + + +def check_engine(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run() + - output_text = tokenizer.decode(outputs) - print(output_text) +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_engine_infer(): + spawn(check_engine, TP_SIZE) if __name__ == '__main__': - test_tp_infer() + test_orig_generate() + test_engine_infer() From de3f021b268bb58396e532b33dfe39a537c10dd0 Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Tue, 29 Aug 2023 18:21:37 +0800 Subject: [PATCH 9/9] fix --- colossalai/inference/tensor_parallel/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 721447a7fb64..f643d892aab9 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -180,7 +180,7 @@ def prepare_batch_state(self, inputs) -> BatchInferState: max_len_in_batch = -1 if isinstance(inputs, (BatchEncoding, dict)): for i, attn_mask in enumerate(attn_masks): - curr_seq_len = torch.sum(attn_mask) + curr_seq_len = int(torch.sum(attn_mask)) seq_lengths[i] = curr_seq_len seq_start_indexes[i] = start_index start_index += curr_seq_len
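
The packed-sequence bookkeeping that prepare_batch_state computes and update_batch_state advances is easiest to see on a toy batch. Below is a minimal CPU-only sketch of the same arithmetic; the batch of lengths [2, 5, 3] is hypothetical and the variable names mirror the engine's.

import torch

# A toy batch of three prompts with 2, 5 and 3 tokens (hypothetical lengths).
seq_len = torch.tensor([2, 5, 3], dtype=torch.int32)
batch_size = seq_len.size(0)

# prepare_batch_state packs sequences back to back: each start offset is the
# running sum of the lengths before it.
start_loc = torch.zeros(batch_size, dtype=torch.int32)
start_index = 0
for i, curr_seq_len in enumerate(seq_len.tolist()):
    start_loc[i] = start_index
    start_index += curr_seq_len
print(start_loc.tolist())    # [0, 2, 7]
print(int(seq_len.max()))    # 5  -> max_len_in_batch
print(int(seq_len.sum()))    # 10 -> BatchInferState.total_token_num after PATCH 4

# After one decoding step every sequence grows by one token, so sequence i's
# start offset shifts right by i slots, which is exactly what
# update_batch_state does:
start_loc = start_loc + torch.arange(0, batch_size, dtype=torch.int32)
seq_len = seq_len + 1
print(start_loc.tolist(), seq_len.tolist())    # [0, 3, 9] [3, 6, 4]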
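
The hard caps asserted in __init__ (max_batch_size <= 64, max_input_len + max_output_len <= 2048) make sense in light of how much KV cache the engine reserves up front. A back-of-the-envelope sizing, assuming MemoryManager keeps one key and one value buffer of head_num * head_dim entries per token per layer (its constructor arguments suggest this, but the class itself is not shown in this series), with Llama-7B-like shapes for illustration:

layer_num, head_num, head_dim = 32, 32, 128    # Llama-7B-like, hypothetical here
dtype_bytes = 2                                # torch.float16
kv_bytes_per_token = 2 * layer_num * head_num * head_dim * dtype_bytes
print(kv_bytes_per_token)                      # 524288 bytes, i.e. 0.5 MiB per token

# At the caps: max_total_token_num = max_batch_size * (max_input_len + max_output_len)
max_total_token_num = 64 * 2048
print(kv_bytes_per_token * max_total_token_num / 2**30)    # 64.0 GiB for the full model
# With tp_size = 2, _init_manager halves head_num per rank, so each GPU
# reserves half of that, which is why the caps are tied to device specs.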
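
For reference, the end-to-end flow the final tests exercise, written out as a minimal sketch. It assumes a CUDA device and, for the tensor-parallel part, a 2-rank process group already initialized via colossalai.launch; note that sharded generation itself is still pending the forward replacements flagged in the test's TODO, so only the setup is shown for that path.

import torch
from transformers import LlamaConfig, LlamaForCausalLM

from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.shardformer import ShardConfig, ShardFormer

# Unsharded path: generate() falls back to the original model.
engine = TPInferEngine(LlamaForCausalLM(LlamaConfig()), 4, 16, 8)
engine.prepare_with_shard_config(ShardConfig(enable_tensor_parallelism=False))
input_ids = torch.randint(low=10, high=1000, size=(4, 16))
outputs = engine.generate(input_ids, dict(do_sample=False))

# Tensor-parallel path: tp_size and the KV cache manager are derived from the
# ShardConfig in prepare_with_shard_config, then ShardFormer swaps in the
# inference policy.
tp_engine = TPInferEngine(LlamaForCausalLM(LlamaConfig()), 4, 16, 8)
shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
tp_engine.prepare_with_shard_config(shard_config)
tp_engine.shard_model_by(ShardFormer(shard_config=shard_config))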