6 changes: 4 additions & 2 deletions colossalai/inference/modeling/backends/attention_backend.py
@@ -2,7 +2,6 @@
from dataclasses import dataclass

import torch
-from flash_attn import flash_attn_varlen_func

from colossalai.inference.config import ModelShardInferenceConfig
from colossalai.kernel.kernel_loader import InferenceOpsLoader
@@ -44,14 +43,17 @@ class CudaAttentionBackend(AttentionBackend):
    it uses Triton op `context_attention_unpadded` for prefilling and our cuda op `flash_decoding_attention` for decoding.
    """

-    def __init__(self, use_flash_attn: bool):
+    def __init__(self, use_flash_attn: bool = False):
        super().__init__()
        self.inference_ops = InferenceOpsLoader().load()
        self.use_flash_attn = use_flash_attn

    def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
        if self.use_flash_attn:
            token_nums = kwargs.get("token_nums", -1)
+
+            from flash_attn import flash_attn_varlen_func
+
            attn_output = flash_attn_varlen_func(
                attn_metadata.query_states,
                attn_metadata.key_states,
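The net effect of this change is that `flash_attn` becomes an optional dependency: the top-level import is removed, and the import now happens lazily inside `prefill`, only when `use_flash_attn` is set. A minimal sketch of the same pattern, not the ColossalAI backend itself (the helper name, arguments and fallback are illustrative):

def varlen_prefill(q, k, v, cu_seqlens, max_seqlen, use_flash_attn: bool = False):
    if use_flash_attn:
        # Deferred import: flash_attn is only required when this branch is taken,
        # so the module can still be imported on machines without flash-attn installed.
        from flash_attn import flash_attn_varlen_func

        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
            causal=True,
        )
    raise NotImplementedError("fall back to the Triton/CUDA kernel path here")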
2 changes: 0 additions & 2 deletions colossalai/inference/modeling/models/nopadding_baichuan.py
@@ -200,8 +200,6 @@ def forward(

        self.pre_attention_backend.decode(
            attn_metadata,
-            cos=cos_sin[0],
-            sin=cos_sin[1],
            q_len=q_len,
        )
        attn_output = self.attention_backend.decode(
4 changes: 2 additions & 2 deletions colossalai/inference/modeling/models/nopadding_llama.py
@@ -114,7 +114,7 @@ def llama_model_forward(

    elif use_cuda_kernel:
        if can_use_flash_attn2(inputmetadata.dtype):
-            cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
+            cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))

        hidden_dim = self._cos_cached.size(-1)
        total_length = hidden_states.size(0)
@@ -265,7 +265,7 @@ def __init__(
        mlp_dproj: ParallelModule = None,
        process_group: ProcessGroup = None,
    ):
-        """A Unified Layer for
+        """Replacement of LlamaMLP layer.

        Args:
            config (LlamaConfig): Holding the Llama model config.
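The first hunk fixes the doubled `torch.` in the dtype argument when building `cu_seqlens`, the cumulative sequence-length offsets that FlashAttention's varlen path consumes for a packed (unpadded) batch. A toy example of the construction (the lengths are made up for illustration):

import torch
import torch.nn.functional as F

sequence_lengths = torch.tensor([3, 5, 2])  # three sequences packed back to back
cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)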
2 changes: 2 additions & 0 deletions colossalai/inference/utils.py
@@ -152,6 +152,8 @@ def can_use_flash_attn2(dtype: torch.dtype) -> bool:
        return False

    try:
+        from flash_attn import flash_attn_varlen_func  # noqa
+
        return True
    except ImportError:
        logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
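With this change `can_use_flash_attn2` probes the import itself, so a missing flash-attn package downgrades gracefully to the Triton kernel instead of failing later at attention time. A rough, self-contained sketch of the same guard; the dtype check here is an assumption for illustration, not copied from the PR:

import torch

def flash_attn2_available(dtype: torch.dtype) -> bool:
    # Hypothetical stand-in for can_use_flash_attn2: FlashAttention-2 only supports
    # half-precision inputs, and the probing import verifies the package is installed.
    if dtype not in (torch.float16, torch.bfloat16):
        return False
    try:
        from flash_attn import flash_attn_varlen_func  # noqa: F401
        return True
    except ImportError:
        return False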
2 changes: 1 addition & 1 deletion colossalai/shardformer/layer/attn.py
@@ -50,7 +50,7 @@ def get_pad_info(padding_mask: torch.Tensor) -> Tuple[int, torch.Tensor, torch.T
    seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return max_seqlen_in_batch, cu_seqlens, indices

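As a sanity check of what `get_pad_info` returns, here is the computation on a small hand-written padding mask (the values are illustrative only):

import torch
import torch.nn.functional as F

padding_mask = torch.tensor([[1, 1, 1, 0],
                             [1, 1, 0, 0]])  # 1 = real token, 0 = padding
seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)              # values [3, 2]
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()   # values [0, 1, 2, 4, 5]
max_seqlen_in_batch = seqlens_in_batch.max().item()                         # 3
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # values [0, 3, 5]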
@@ -26,7 +26,7 @@ def prepare_data(
    num_tokens = torch.sum(context_lengths).item()

    max_seq_len_in_batch = context_lengths.max()
-    cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
+    cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.int32), (1, 0))

    kv_size = (num_tokens, num_kv_heads, HEAD_DIM)
    key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
161 changes: 161 additions & 0 deletions tests/test_infer/test_models/test_custom_model.py
@@ -0,0 +1,161 @@
import os
import random

import numpy as np
import pytest
import torch
import torch.distributed as dist
from torch.multiprocessing import Manager
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, LlamaForCausalLM, LlamaTokenizer

import colossalai
import colossalai.inference.modeling.policy as policy
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

# NOTE: To test a model with the inference engine, you need to provide the path to your
# local pretrained model weights in the MODEL_MAP dictionary
MODEL_MAP = {
"baichuan": {
"model": AutoModelForCausalLM,
"tokenizer": AutoTokenizer,
"policy": policy.NoPaddingBaichuanModelInferPolicy,
"model_name_or_path": "baichuan-inc/Baichuan2-13B-Base", # provide the path to local model weights
},
"llama": {
"model": LlamaForCausalLM,
"tokenizer": LlamaTokenizer,
"policy": policy.NoPaddingLlamaModelInferPolicy,
"model_name_or_path": "meta-llama/Llama-2-70b-hf",
},
}

MODELS_TO_TEST = ["llama", "baichuan"] # Specify the models to test


@parameterize("model", MODELS_TO_TEST)
@parameterize("prompt_template", [None, "model_specific"])
@parameterize("do_sample", [False])
@parameterize("use_cuda_kernel", [True])
@pytest.mark.largedist
@rerun_if_address_is_in_use()
def test_model(model, prompt_template, do_sample, use_cuda_kernel):
    model_path = MODEL_MAP[model]["model_name_or_path"]
    if not os.path.exists(model_path):
        pytest.skip(
            f"There is no local model address included for {model}, please replace this address with a valid one."
        )

    if prompt_template == "model_specific":
        prompt_template = model

    model_config = MODEL_MAP[model]

    kwargs1 = {
        "model": model,
        "use_engine": True,
        "prompt_template": prompt_template,
        "do_sample": do_sample,
        "policy": model_config["policy"](),
        "use_cuda_kernel": use_cuda_kernel,
    }

    kwargs2 = {
        "model": model,
        "use_engine": False,
        "prompt_template": prompt_template,
        "do_sample": do_sample,
        "policy": None,
        "use_cuda_kernel": use_cuda_kernel,
    }

    colossal_tp_1_output = run_engine(1, **kwargs1)
    colossal_tp_2_output = run_engine(2, **kwargs1)
    transformer_tp_1_output = run_engine(1, **kwargs2)

    for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output):
        assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}"
        assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}"


def run_engine(world_size, **kwargs):
    manager = Manager()
    result_list = manager.list([-1] * world_size)  # Create a shared list
    spawn(run_dist, world_size, func_to_run=_run_engine, ret=result_list, **kwargs)
    return result_list[0]


def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")

    if ret:
        ret[rank] = func_to_run(**kwargs)
    else:
        func_to_run(**kwargs)


def _run_engine(model, use_engine=False, do_sample=False, use_cuda_kernel=False, prompt_template=None, policy=None):
    setup_seed(20)
    model_config = MODEL_MAP[model]
    model_name_or_path = model_config["model_name_or_path"]
    tokenizer = model_config["tokenizer"].from_pretrained(model_name_or_path, use_fast=False, trust_remote_code=True)
    model = model_config["model"].from_pretrained(model_name_or_path, trust_remote_code=True).half().cuda()
    model = model.eval()

    inputs = [
        "Introduce some landmarks in Paris:",
    ]

    output_len = 38

    if do_sample:
        top_p = 0.5
        top_k = 50
    else:
        top_p = None
        top_k = None

    if use_engine:
        inference_config = InferenceConfig(
            max_output_len=output_len,
            prompt_template=prompt_template,
            use_cuda_kernel=use_cuda_kernel,
            tp_size=dist.get_world_size(),
        )
        inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy)
        assert inference_engine.generation_config.max_new_tokens == output_len
        inference_engine.add_request(prompts=inputs)
        assert inference_engine.request_handler._has_waiting()
        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k, max_new_tokens=output_len)
        outputs = inference_engine.generate(generation_config=generation_config)
    else:
        if prompt_template:
            # apply prompt template
            inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
        inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
        inputs = inputs.cuda()
        generation_config = GenerationConfig(
            do_sample=do_sample,
            top_p=top_p,
            top_k=top_k,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=output_len,
        )
        outputs = model.generate(inputs, generation_config=generation_config)
        outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return outputs


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


if __name__ == "__main__":
    test_model()
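Note on running the new test: `test_model` skips itself unless the `model_name_or_path` entries in `MODEL_MAP` point at weights that exist on the local filesystem, so those paths need to be edited first. A possible workflow, where the paths and the marker-based pytest invocation are assumptions about the local setup rather than part of the PR:

# 1. Point MODEL_MAP at locally downloaded checkpoints, e.g.
#      "model_name_or_path": "/data/models/Llama-2-7b-hf",
# 2. Run through pytest, selecting the project's `largedist` marker:
#      pytest -m largedist tests/test_infer/test_models/test_custom_model.py
# 3. Or execute the file directly, which calls test_model() via __main__:
#      python tests/test_infer/test_models/test_custom_model.py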