@@ -873,7 +873,7 @@ def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
 
         self.rotary_pos_emb = RotaryEmbedding(
             rotary_dim // 2,
-            original_impl=config.original_rope,
+            # original_impl=config.original_rope, # config has no attribute original_rope
             device=device,
             dtype=config.torch_dtype,
         )
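Note: commenting the argument out silently falls back to RotaryEmbedding's default. A more defensive variant (a sketch, not what this PR does, assuming RotaryEmbedding still accepts original_impl and that True is a reasonable default) would read the flag only if the config provides it:

        # Hypothetical alternative: use the config value when present, otherwise a default.
        # The fallback value True is an assumption, not taken from this PR.
        self.rotary_pos_emb = RotaryEmbedding(
            rotary_dim // 2,
            original_impl=getattr(config, "original_rope", True),
            device=device,
            dtype=config.torch_dtype,
        )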
1 change: 0 additions & 1 deletion examples/inference/bench_llama.py
@@ -43,7 +43,6 @@ def run_llama_test(args):
     tokenizer.pad_token_id = tokenizer.unk_token_id
     model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id)
     model = model.half()
-
     model_config = model.config
 
     shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
29 changes: 16 additions & 13 deletions tests/test_infer/test_bloom_infer.py
@@ -1,13 +1,14 @@
 import pytest
 import torch
 from packaging import version
+from transformers import BloomForCausalLM
+from transformers.models.bloom.configuration_bloom import BloomConfig
 
 import colossalai
 from colossalai.inference.tensor_parallel import TPInferEngine
 from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer import ShardConfig
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
-from tests.kit.model_zoo import model_zoo
 
 TP_SIZE = 2
 MAX_BATCH_SIZE = 4
@@ -26,21 +27,23 @@
     ],
 )
 def run(test_config):
-    sub_model_zoo = model_zoo.get_sub_registry("transformers_bloom_for_causal_lm")
-    for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
-        orig_model = model_fn()
-        orig_model = orig_model.half()
-        data = data_gen_fn()
+    bloom_config = BloomConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
+    model = BloomForCausalLM(bloom_config)
+    model = model.half()
 
-        shard_config = ShardConfig(
-            enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
-        )
-        infer_engine = TPInferEngine(orig_model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
+    )
+    infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
 
-        generate_kwargs = dict(do_sample=False)
-        outputs = infer_engine.generate(data, **generate_kwargs)
+    input_tokens = {
+        "input_ids": torch.randint(1, 1000, (MAX_BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
+        "attention_mask": torch.ones((MAX_BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
+    }
+    outputs = infer_engine.generate(input_tokens, **generate_kwargs)
 
-        assert outputs is not None
+    assert outputs is not None
 
 
 def check_bloom(rank, world_size, port):
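The hunk ends at the check_bloom signature. For orientation only, a typical spawn-based runner in this test suite looks roughly like the sketch below (assumed from the imports shown above; the actual body is outside the visible diff and may differ):

def check_bloom(rank, world_size, port):
    # Assumed wiring: initialize the distributed context, then run the parameterized test body.
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bloom_infer():
    spawn(check_bloom, TP_SIZE)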
41 changes: 21 additions & 20 deletions tests/test_infer/test_chatglm2_infer.py
@@ -2,17 +2,15 @@
 
 import pytest
 import torch
-import torch.distributed as dist
 from packaging import version
-from transformers import AutoTokenizer
 
 import colossalai
 from colossalai.inference.tensor_parallel.engine import TPInferEngine
 from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer import ShardConfig
+from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
 from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
-from tests.kit.model_zoo.transformers.chatglm2 import infer_config
 
 os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
 TPSIZE = 1
@@ -31,28 +29,31 @@
     ],
 )
 def run_chatglm2_test(test_config):
-    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
-    # pad_token_id = 0
-    model_fn = lambda: ChatGLMForConditionalGeneration(infer_config, empty_init=False)
-    orig_model = model_fn()
-    orig_model = orig_model.half()
-    text = ["how is the weather today?"]
-    input_ids = tokenizer.batch_encode_plus(text, return_tensors="pt", padding=True)
+    chatglm_config = ChatGLMConfig(
+        num_layers=2,
+        vocab_size=1200,
+        use_cache=True,
+        multi_query_attention=True,
+        multi_query_group_num=2,
+        num_attention_heads=8,
+        hidden_size=1024,
+    )
+    model = ChatGLMForConditionalGeneration(chatglm_config)
+    model = model.half()
+
     shard_config = ShardConfig(
         enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
     )
-    infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
-
+    infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
     generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
-    outputs = infer_engine.generate(input_ids, **generate_kwargs)
-    assert outputs is not None
 
-    # print("outputs.shape: ", outputs[0].shape)
-    # print("outputs: ", outputs[0])
-    if not dist.is_initialized() or dist.get_rank() == 0:
-        for o in outputs:
-            output_text = tokenizer.decode(o)
-            print(output_text)
+    input_tokens = {
+        "input_ids": torch.randint(1, 1000, (BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
+        "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
+    }
+    outputs = infer_engine.generate(input_tokens, **generate_kwargs)
+
+    assert outputs is not None
 
 
 def check_chatglm2(rank, world_size, port):
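For reference, the multi-query fields in the new ChatGLMConfig imply a small grouped key/value layout. A quick arithmetic sketch of how these fields relate, under the usual multi-query-attention convention (an assumption, not taken from the diff):

# Illustrative only: relationship between the test config's attention fields.
hidden_size = 1024
num_attention_heads = 8
multi_query_group_num = 2

head_dim = hidden_size // num_attention_heads            # 128 per head
kv_heads = multi_query_group_num                         # 2 shared key/value heads
q_heads_per_kv_group = num_attention_heads // kv_heads   # 4 query heads per K/V group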
30 changes: 17 additions & 13 deletions tests/test_infer/test_llama_infer.py
@@ -3,13 +3,14 @@
 import pytest
 import torch
 from packaging import version
+from transformers import LlamaForCausalLM
+from transformers.models.llama.configuration_llama import LlamaConfig
 
 import colossalai
 from colossalai.inference.tensor_parallel.engine import TPInferEngine
 from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer import ShardConfig
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
-from tests.kit.model_zoo import model_zoo
 
 os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
 TPSIZE = 2
@@ -29,21 +30,24 @@
     ],
 )
 def run_llama_test(test_config):
-    sub_model_zoo = model_zoo.get_sub_registry("transformers_llama_for_casual_lm")
-    for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
-        orig_model = model_fn()
-        orig_model = orig_model.half()
-        data = data_gen_fn()
+    llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
+    model = LlamaForCausalLM(llama_config)
+    model = model.half()
 
-        shard_config = ShardConfig(
-            enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
-        )
-        infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
+    )
+    infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    init_to_get_rotary(model.model, base=10000)
+    generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
 
-        generate_kwargs = dict(do_sample=False)
-        outputs = infer_engine.generate(data, **generate_kwargs)
+    input_tokens = {
+        "input_ids": torch.randint(1, 1000, (BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
+        "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
+    }
+    outputs = infer_engine.generate(input_tokens, **generate_kwargs)
 
-        assert outputs is not None
+    assert outputs is not None
 
 
 def check_llama(rank, world_size, port):
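The newly added init_to_get_rotary(model.model, base=10000) call presumably precomputes the rotary-embedding caches before generation. A rough sketch of what such an initializer typically builds (an assumption about its internals, not ColossalAI's actual code):

import torch

def init_rotary_cache(head_dim: int, max_seq_len: int, base: float = 10000.0, device: str = "cuda"):
    # Inverse frequencies for each dimension pair, then per-position cos/sin tables.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    positions = torch.arange(max_seq_len, device=device).float()
    freqs = torch.outer(positions, inv_freq)      # (max_seq_len, head_dim // 2)
    return torch.cos(freqs), torch.sin(freqs)     # cached cos/sin tables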
4 changes: 1 addition & 3 deletions tests/test_infer_ops/triton/test_llama2_token_attn.py
@@ -38,9 +38,7 @@ def test():
     q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
     k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2)
     v = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
-    o = torch.empty_like()
-    # o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
-
+    o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda")
     max_kv_cache_len = seq_len
     kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda")
     kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda")
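For context: the removed line failed because torch.empty_like requires a tensor argument. Since q already has the target shape, dtype, and device, an equivalent fix would be:

    # Alternative to the explicit torch.empty(...) above: allocate an output buffer
    # matching q's (Z, head_num, head_dim) shape, dtype, and device.
    o = torch.empty_like(q)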