Merged
34 commits
- `fbff5d3` add kv cache memory manager (yuanheng-zhao, Aug 23, 2023)
- `2d55ace` add stateinfo during inference (yuanheng-zhao, Aug 23, 2023)
- `e55e565` add (CjhHa1, Aug 22, 2023)
- `a971535` add infer example (CjhHa1, Aug 23, 2023)
- `0ae5bb7` finish (CjhHa1, Aug 23, 2023)
- `a8f7386` finish (CjhHa1, Aug 23, 2023)
- `cb45cf8` format (yuanheng-zhao, Aug 23, 2023)
- `4f21bc5` format (yuanheng-zhao, Aug 23, 2023)
- `389d0d4` rename file (yuanheng-zhao, Aug 23, 2023)
- `bdba1b5` add kv cache test (yuanheng-zhao, Aug 24, 2023)
- `813e23a` revise on BatchInferState (yuanheng-zhao, Aug 24, 2023)
- `5b08d60` Merge commit 'refs/pull/4495/head' of https://github.com/hpcaitech/Co… (isky-cd, Aug 24, 2023)
- `469a3c5` add inference test for llama (isky-cd, Aug 24, 2023)
- `5993a0f` fix conflict (isky-cd, Aug 24, 2023)
- `a98000f` fix conflict (isky-cd, Aug 24, 2023)
- `7686c07` fix conflict (isky-cd, Aug 24, 2023)
- `ba089d7` feature: add some new features for llama engine (isky-cd, Aug 24, 2023)
- `68b5fe8` adapt colossalai triton interface (isky-cd, Aug 24, 2023)
- `6021b13` Change the parent class of llama policy (isky-cd, Aug 24, 2023)
- `6a1bafa` add nvtx (isky-cd, Aug 25, 2023)
- `f79308e` move llama inference code to tensor_parallel (isky-cd, Aug 27, 2023)
- `a6cc3dd` Merge branch 'feature/colossal-inference' of https://github.com/hpcai… (isky-cd, Aug 28, 2023)
- `2a6a380` fix __init__.py (isky-cd, Aug 28, 2023)
- `d10dcf4` rm tensor_parallel (isky-cd, Aug 28, 2023)
- `fb2603b` fix: fix bugs in auto_policy.py (isky-cd, Aug 28, 2023)
- `92fd955` fix: rm some unused codes (isky-cd, Aug 28, 2023)
- `c747249` mv colossalai/tpinference to colossalai/inference/tensor_parallel (isky-cd, Aug 28, 2023)
- `8507fc5` Merge branch 'feature/colossal-inference' into llama_test_branch (isky-cd, Aug 30, 2023)
- `c27088f` change __init__.py (isky-cd, Aug 30, 2023)
- `af16040` save change (isky-cd, Aug 30, 2023)
- `bfc55cc` fix conflict (isky-cd, Aug 30, 2023)
- `f30f542` fix engine (isky-cd, Aug 30, 2023)
- `4b52ebd` Bug fix: Fix hang (isky-cd, Aug 30, 2023)
- `6d06421` remove llama_infer_engine.py (isky-cd, Aug 30, 2023)
6 changes: 4 additions & 2 deletions colossalai/inference/tensor_parallel/__init__.py
```diff
@@ -1,4 +1,6 @@
+from .modeling.llama import LlamaInferenceForwards
+from .pollcies.llama import LlamaModelInferPolicy
 from .engine import TPInferEngine
 from .kvcache_manager import MemoryManager
 
-__all__ = ['MemoryManager', 'TPInferEngine']
+__all__ = ['LlamaInferenceForwards', 'LlamaModelInferPolicy', 'MemoryManager', 'TPInferEngine']
```
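For reference, the expanded `__all__` means callers can now import the Llama inference pieces from the subpackage root. A minimal sketch of the resulting import surface (the `pollcies` directory spelling comes from the PR itself; the one-line descriptions are inferred from the commit messages and the engine diff below, not stated in this file):

```python
# Import surface added by this change; names mirror the updated __all__.
from colossalai.inference.tensor_parallel import (
    LlamaInferenceForwards,    # inference-optimized forwards for Llama (modeling/llama.py)
    LlamaModelInferPolicy,     # ShardFormer policy used to patch a Llama model
    MemoryManager,             # KV-cache memory manager
    TPInferEngine,             # tensor-parallel inference engine
)
```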
22 changes: 11 additions & 11 deletions colossalai/inference/tensor_parallel/engine.py
```diff
@@ -16,7 +16,7 @@
 
 DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
 
-_supported_models = ['LlamaForCausalLM', 'BloomForCausalLM']
+_supported_models = ['LlamaForCausalLM', 'LlamaModel', 'BloomForCausalLM']
 
 
 class TPInferEngine:
@@ -27,7 +27,7 @@ def __init__(self,
                  max_input_len: int,
                  max_output_len: int,
                  dtype: torch.dtype = torch.float16,
-                 device: torch.device = torch.cuda.current_device()) -> None:
+                 device: str = 'cuda') -> None:
         self.model = model
         self.sharded_model = None
 
@@ -40,7 +40,7 @@ def __init__(self,
         assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
         assert self.max_input_len + self.max_output_len <= 2048, "Max length exceeds the constraint"
 
-        self.device = device
+        torch.device(device=device)
         self.dtype = dtype
 
         self.head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads
@@ -88,7 +88,7 @@ def shard_model_by(self, shardformer: ShardFormer) -> None:
         assert model_name in self._supported_models(), f"Unsupported model cls {model_name} for TP inference."
         policy = get_autopolicy(self.model, inference_only=True)
         self.sharded_model, _ = shardformer.optimize(self.model, policy)
-        self.sharded_model = self.sharded_model.to(self.device)
+        self.sharded_model = self.sharded_model.cuda()
 
     @staticmethod
     def _supported_models() -> List[str]:
@@ -137,7 +137,7 @@ def generate_by_set_infer_state(self, input_tokens, generate_kwargs) -> torch.Tensor:
             input_tokens = dict(input_ids=input_tokens)
         for t in input_tokens:
             if torch.is_tensor(input_tokens[t]):
-                input_tokens[t] = input_tokens[t].to(self.device)
+                input_tokens[t] = input_tokens[t].cuda()
 
         outputs = self.sharded_model.generate(**input_tokens, **generate_kwargs, early_stopping=False)
 
@@ -173,8 +173,8 @@ def prepare_batch_state(self, inputs) -> BatchInferState:
         else:
             batch_size = inputs.shape[0]
 
-        seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device=self.device)
-        seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device=self.device)
+        seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
+        seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
         start_index = 0
 
         max_len_in_batch = -1
@@ -197,10 +197,10 @@ def prepare_batch_state(self, inputs) -> BatchInferState:
 
         block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len),
                                 dtype=torch.long,
-                                device=self.device)
+                                device='cuda')
         batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
-        batch_infer_state.seq_len = seq_lengths.to(self.device)    # might want to assign specific device
-        batch_infer_state.start_loc = seq_start_indexes.to(self.device)
+        batch_infer_state.seq_len = seq_lengths.to('cuda')    # might want to assign specific device
+        batch_infer_state.start_loc = seq_start_indexes.to('cuda')
         batch_infer_state.block_loc = block_loc
         batch_infer_state.decode_layer_id = 0
         batch_infer_state.past_key_values_len = 0
@@ -251,4 +251,4 @@ def update_batch_state(self, infer_state: Optional[BatchInferState]) -> None:
     # => put information already recorded in batchinferstate and pass it to model forward
     # => clear records in engine
     def add_request():
-        raise NotImplementedError()
\ No newline at end of file
+        raise NotImplementedError()
```
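Taken together, these hunks hard-code CUDA placement: the old `self.device` attribute is gone, and the new `torch.device(device=device)` call discards its result, so the engine effectively assumes the default CUDA device throughout. A minimal usage sketch under that reading; the `model`/`max_batch_size` parameter names, the checkpoint path, and the ShardFormer wiring are assumptions inferred from the asserts and method signatures above, not shown in this diff:

```python
import torch
from transformers import AutoTokenizer, LlamaForCausalLM

from colossalai.inference.tensor_parallel import TPInferEngine

# Hypothetical local checkpoint; any Llama-architecture model should work,
# since 'LlamaForCausalLM' is listed in _supported_models.
model = LlamaForCausalLM.from_pretrained("path/to/llama")
tokenizer = AutoTokenizer.from_pretrained("path/to/llama")

# Bounds come from the asserts above: max_batch_size <= 64 and
# max_input_len + max_output_len <= 2048.
engine = TPInferEngine(model=model, max_batch_size=4,
                       max_input_len=1024, max_output_len=128,
                       dtype=torch.float16)

# The model must be sharded before generating; ShardFormer construction
# (ShardConfig, process groups) is omitted here for brevity.
# engine.shard_model_by(shardformer)

input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
# generate_kwargs is a plain dict forwarded to sharded_model.generate().
outputs = engine.generate_by_set_infer_state(input_ids, dict(max_new_tokens=128))
```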
3 changes: 3 additions & 0 deletions colossalai/inference/tensor_parallel/modeling/__init__.py
```diff
@@ -0,0 +1,3 @@
+from .llama import LlamaInferenceForwards
+
+__all__ = ['LlamaInferenceForwards']
```