Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions colossalai/inference/tensor_parallel/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ def __init__(self,
self.max_output_len = max_output_len
self.max_total_token_num = self.max_batch_size * (self.max_input_len + self.max_output_len)

# Constraints relatable with specs of devices and model
# This may change into an optional arg in the future
# Constraints relatable with specs of devices
assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
assert self.max_input_len + self.max_output_len <= 4096, "Max length exceeds the constraint"
assert self.max_input_len + self.max_output_len <= 2048, "Max length exceeds the constraint"

torch.device(device=device)
self.dtype = dtype

self.head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads
Expand Down Expand Up @@ -94,20 +94,20 @@ def shard_model_by(self, shardformer: ShardFormer) -> None:
def _supported_models() -> List[str]:
return _supported_models

def generate(self, input_tokens: Union[BatchEncoding, dict, list, torch.Tensor], **generate_kwargs) -> torch.Tensor:
def generate(self, input_tokens, generate_kwargs) -> torch.Tensor:
if isinstance(input_tokens, torch.Tensor):
input_tokens = dict(input_ids=input_tokens, attention_mask=torch.ones_like(input_tokens, dtype=torch.bool))
for t in input_tokens:
if torch.is_tensor(input_tokens[t]):
input_tokens[t] = input_tokens[t].cuda()

if self.sharded_model is not None:
return self.generate_by_set_infer_state(input_tokens, **generate_kwargs)
return self.generate_by_set_infer_state(input_tokens, generate_kwargs)

return self.model.generate(**input_tokens, **generate_kwargs)

@torch.no_grad()
def generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch.Tensor:
def generate_by_set_infer_state(self, input_tokens, generate_kwargs) -> torch.Tensor:
"""
Generate output tokens by setting BatchInferState as an attribute to the model and calling model.generate

Expand Down Expand Up @@ -191,9 +191,8 @@ def prepare_batch_state(self, inputs) -> BatchInferState:
start_index += curr_seq_len
max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
else:
length = max(len(input_id) for input_id in input_ids_list)
for i, input_ids in enumerate(input_ids_list):
curr_seq_len = length
curr_seq_len = len(input_ids)
seq_lengths[i] = curr_seq_len
seq_start_indexes[i] = start_index
start_index += curr_seq_len
Expand Down
5 changes: 3 additions & 2 deletions colossalai/inference/tensor_parallel/kvcache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
# https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py

import torch
from transformers.utils import logging

from colossalai.logging import get_dist_logger


class MemoryManager:
Expand All @@ -26,7 +27,7 @@ def __init__(self,
head_dim: int,
layer_num: int,
device: torch.device = torch.device('cuda')):
self.logger = logging.get_logger(__name__)
self.logger = get_dist_logger(__name__)
self.available_size = size
self.past_key_values_length = 0
self._init_mem_states(size, device)
Expand Down
13 changes: 8 additions & 5 deletions colossalai/shardformer/modeling/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,16 @@
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaForCausalLM,
LlamaForSequenceClassification,
LlamaModel,
apply_rotary_pos_emb,
)
from transformers.utils import logging

from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
from colossalai.pipeline.stage_manager import PipelineStageManager


Expand Down Expand Up @@ -395,10 +402,6 @@ def get_llama_flash_attention_forward():

from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

def forward(
self: LlamaAttention,
hidden_states: torch.Tensor,
Expand Down
4 changes: 0 additions & 4 deletions colossalai/shardformer/shard/shard_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ class ShardConfig:
enable_flash_attention: bool = False
enable_jit_fused: bool = False
inference_only: bool = False
enable_sequence_parallelism: bool = False
enable_sequence_overlap: bool = False

# pipeline_parallel_size: int
# data_parallel_size: int
Expand Down Expand Up @@ -60,8 +58,6 @@ def _turn_on_all_optimization(self):
self.enable_fused_normalization = True
self.enable_flash_attention = True
self.enable_jit_fused = True
self.enable_sequence_parallelism = True
self.enable_sequence_overlap = True

def _infer(self):
"""
Expand Down
21 changes: 9 additions & 12 deletions tests/test_infer/test_infer_engine.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import pytest
from itertools import accumulate
from packaging import version

import pytest
import torch
import torch.nn as nn
from packaging import version
from transformers import BloomConfig, BloomForCausalLM, LlamaConfig, LlamaForCausalLM
from transformers.tokenization_utils_base import BatchEncoding

Expand All @@ -21,7 +21,6 @@

CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
def test_prepare_data():
# dummy module used for testing
Expand Down Expand Up @@ -55,24 +54,22 @@ def __init__(self):
attention_mask[i][max_seq_len - len(li):] = [1 for _ in range(len(li))]
data = dict(input_ids=input_ids_list, attention_mask=attention_mask)
inputs_batch_encoding = BatchEncoding(data=data)

seq_lengths = [len(li) for li in input_ids_list]
start_loc = list(accumulate([0] + seq_lengths[:-1]))
seq_lengths = torch.tensor(seq_lengths, dtype=torch.int32)
start_loc = torch.tensor(start_loc, dtype=torch.int32)

# input token id list as inputs
batch_state_out1 = infer_engine.prepare_batch_state(inputs_batch_encoding)
# BatchEncoding as inputs
batch_state_out2 = infer_engine.prepare_batch_state(input_ids_list)

assert batch_state_out1.batch_size == batch_state_out2.batch_size == batch_size
assert torch.equal(batch_state_out1.seq_len, batch_state_out2.seq_len)

# The following tests are discarded for now, and will be reused after all features are added
# assert torch.equal(batch_state_out1.seq_len.to(seq_lengths.device), seq_lengths)
# assert torch.equal(batch_state_out2.seq_len.to(seq_lengths.device), seq_lengths)
# assert torch.equal(batch_state_out1.start_loc.to(start_loc.device), start_loc)
# assert torch.equal(batch_state_out2.start_loc.to(start_loc.device), start_loc)

assert torch.equal(batch_state_out1.seq_len.to(seq_lengths.device), seq_lengths)
assert torch.equal(batch_state_out2.seq_len.to(seq_lengths.device), seq_lengths)
assert torch.equal(batch_state_out1.start_loc.to(start_loc.device), start_loc)
assert torch.equal(batch_state_out2.start_loc.to(start_loc.device), start_loc)

@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
def test_orig_generate():
Expand All @@ -91,7 +88,7 @@ def test_orig_generate():

# original model generate
generate_kwargs = dict(do_sample=False)
infer_engine.generate(input_ids, **generate_kwargs)
infer_engine.generate(input_ids, generate_kwargs)

torch.cuda.empty_cache()

Expand Down