5 changes: 3 additions & 2 deletions colossalai/inference/tensor_parallel/modeling/bloom.py
@@ -197,7 +197,7 @@ def bloom_model_forward(
all_hidden_states = all_hidden_states + (hidden_states,)

if self.gradient_checkpointing and self.training:
# FIXME: currently our KV cache manager does not handle this condition
# NOTE: currently our KV cache manager does not handle this condition
def create_custom_forward(module):

def custom_forward(*inputs):
@@ -240,7 +240,8 @@ def custom_forward(*inputs):
all_hidden_states = all_hidden_states + (hidden_states,)

# update indices of kv cache block
# TODO: might want to remove this part, instead, better to pass the BatchInferState from model forward,
# NOT READY FOR PRIME TIME
# might want to remove this part, instead, better to pass the BatchInferState from model forward,
# and update this information in engine.generate after the model forward is called
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.seq_len += 1
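# Illustrative sketch, not part of this diff: a standalone worked example of the index
# update above, assuming start_loc holds each sequence's start offset in a packed KV
# cache and seq_len its current length (CPU tensors and made-up values for clarity).
import torch

batch_size = 4
start_loc = torch.tensor([0, 5, 10, 15], dtype=torch.int32)  # packed start of each sequence
seq_len = torch.tensor([5, 5, 5, 5], dtype=torch.int32)
# after every sequence appends one decoded token, sequence i's start shifts by i,
# because each of the i sequences stored before it grew by one slot
start_loc = start_loc + torch.arange(0, batch_size, dtype=torch.int32)
seq_len += 1
print(start_loc.tolist())  # [0, 6, 12, 18]
print(seq_len.tolist())    # [6, 6, 6, 6]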
5 changes: 3 additions & 2 deletions colossalai/inference/tensor_parallel/modeling/llama.py
@@ -77,12 +77,13 @@ def llama_model_forward(
past_key_values_length = 0

if past_key_values is not None:
# TODO dummy but work, revise it
# NOT READY FOR PRIME TIME
# dummy but works; revise it
past_key_values_length = infer_state.cache_manager.past_key_values_length
# past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length

# FIXME: differentiate with prefill stage
# NOTE: differentiate with prefill stage
# block_loc requires a different value-assigning method for each of the two stages
if use_cache and seq_length != 1:
# NOTE assume prefill stage
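# Illustrative sketch, not part of this diff: the length bookkeeping above with
# standalone example values, showing which branch each stage takes.
seq_length, past_key_values_length = 12, 0                    # prefill: 12-token prompt, empty cache
seq_length_with_past = seq_length + past_key_values_length    # 12; seq_length != 1 -> prefill branch
seq_length, past_key_values_length = 1, 12                    # first decode step: 1 new token, 12 cached
seq_length_with_past = seq_length + past_key_values_length    # 13; seq_length == 1 -> decode branch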
8 changes: 4 additions & 4 deletions colossalai/inference/tensor_parallel/policies/llama.py
@@ -1,10 +1,10 @@
from functools import partial

from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm

from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

from ..modeling.llama import LlamaInferenceForwards
from ..modeling.llama import get_llama_vllm_rmsnorm_forward
from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward


class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
@@ -37,8 +37,8 @@ def module_policy(self):
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=LlamaAttention)
# TODO: adding rms_norm caused precision issue, fix @tiandiao123

# NOTE: adding rms_norm caused precision issue, fix @tiandiao123
# infer_forward = get_llama_vllm_rmsnorm_forward()
# if infer_forward is not None:
# method_replacement = {'forward': partial(infer_forward)}
69 changes: 40 additions & 29 deletions colossalai/kernel/triton/self_attention_nofusion.py
@@ -13,8 +13,9 @@
from .qkv_matmul_kernel import qkv_gemm_4d_kernel
from .softmax import softmax_kernel

def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float):
r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels
def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
input_mask: torch.Tensor, scale: float):
r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels
Args:
q (torch.Tensor): Q embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
k (torch.Tensor): K embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
@@ -36,39 +37,49 @@ def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: t
# head_size * num_of_head
d_model = q.shape[-1] * q.shape[-2]

score_output = torch.empty(
(batches, H, M, N), device=q.device, dtype=q.dtype)
score_output = torch.empty((batches, H, M, N), device=q.device, dtype=q.dtype)

grid = lambda meta: (
batches,
H,
triton.cdiv(M, meta["BLOCK_SIZE_M"]) *
triton.cdiv(N, meta["BLOCK_SIZE_N"]),
triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
)

qkv_gemm_4d_kernel[grid](
q, k, score_output,
M, N, K,
q.stride(0), q.stride(2), q.stride(1), q.stride(3),
k.stride(0), k.stride(2), k.stride(3), k.stride(1),
score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3),
q,
k,
score_output,
M,
N,
K,
q.stride(0),
q.stride(2),
q.stride(1),
q.stride(3),
k.stride(0),
k.stride(2),
k.stride(3),
k.stride(1),
score_output.stride(0),
score_output.stride(1),
score_output.stride(2),
score_output.stride(3),
scale=scale,
# currently manually setting, later on we can use auto-tune config to match best setting
# currently manually setting, later on we can use auto-tune config to match best setting
BLOCK_SIZE_M=64,
BLOCK_SIZE_N=32,
BLOCK_SIZE_K=32,
GROUP_SIZE_M=8,
)

softmax_output = torch.empty(
score_output.shape, device=score_output.device, dtype=score_output.dtype)

softmax_output = torch.empty(score_output.shape, device=score_output.device, dtype=score_output.dtype)
score_output_shape = score_output.shape

score_output = score_output.view(-1, score_output.shape[-1])
n_rows, n_cols = score_output.shape

if n_rows <= 350000:

block_size = max(triton.next_power_of_2(n_cols), 2)
num_warps = 4
if block_size >= 4096:
@@ -78,37 +89,39 @@ def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: t
else:
num_warps = 4

softmax_kernel[(n_rows, )](
softmax_kernel[(n_rows,)](
softmax_output,
score_output,
score_output.stride(0),
n_cols,
mask_ptr = input_mask,
mask_ptr=input_mask,
num_warps=num_warps,
BLOCK_SIZE=block_size,
)

else:
#TODO: change softmax kernel functions to make it suitable for large size dimension
# NOTE: change the softmax kernel functions to make them suitable for large dimension sizes
softmax_output = torch.nn.functional.softmax(score_output, dim=-1)
softmax_output = softmax_output.view(*score_output_shape)

batches, H, M, K = softmax_output.shape
N = v.shape[-1]

output = torch.empty(
(batches, M, H, N), device=softmax_output.device, dtype=softmax_output.dtype)
output = torch.empty((batches, M, H, N), device=softmax_output.device, dtype=softmax_output.dtype)

grid = lambda meta: (
batches,
H,
triton.cdiv(M, meta["BLOCK_SIZE_M"]) *
triton.cdiv(N, meta["BLOCK_SIZE_N"]),
triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
)

qkv_gemm_4d_kernel[grid](
softmax_output, v, output,
M, N, K,
softmax_output,
v,
output,
M,
N,
K,
softmax_output.stride(0),
softmax_output.stride(1),
softmax_output.stride(2),
@@ -129,7 +142,6 @@ def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: t
)
return output.view(batches, -1, d_model)


def self_attention_compute_using_triton(qkv,
input_mask,
layer_past,
@@ -152,7 +164,6 @@ def self_attention_compute_using_triton(qkv,
k = k.view(batches, -1, num_of_heads, head_size)
v = v.view(batches, -1, num_of_heads, head_size)

data_output_triton = self_attention_forward_without_fusion(
q, k, v, input_mask, scale)
data_output_triton = self_attention_forward_without_fusion(q, k, v, input_mask, scale)

return data_output_triton
return data_output_triton
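For reference, a minimal usage sketch of the wrapper reformatted above, assuming a CUDA device with Triton installed; the additive-mask layout passed to the softmax kernel is an assumption, not taken from this diff.

import math

import torch

from colossalai.kernel.triton.self_attention_nofusion import self_attention_forward_without_fusion

batch, seq_len, num_heads, head_size = 2, 16, 8, 64
q = torch.randn(batch, seq_len, num_heads, head_size, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
mask = torch.zeros(batch, seq_len, device="cuda", dtype=torch.float16)  # assumed mask shape
out = self_attention_forward_without_fusion(q, k, v, mask, scale=1.0 / math.sqrt(head_size))
print(out.shape)  # torch.Size([2, 16, 512]), i.e. (batch, seq_len, num_heads * head_size)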
31 changes: 17 additions & 14 deletions tests/test_infer/test_llama_infer.py
@@ -1,60 +1,63 @@
import os

import numpy as np
import pytest
import torch
import numpy as np
import torch.distributed as dist
from transformers import LlamaForCausalLM, LlamaTokenizer

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from transformers import LlamaForCausalLM, LlamaTokenizer
from colossalai.cluster import ProcessGroupMesh
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.inference.tensor_parallel.engine import TPInferEngine
import torch.distributed as dist
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
TPSIZE = 2
BATCH_SIZE = 8
MAX_INPUT_LEN = 12
MAX_OUTPUT_LEN = 100


def init_to_get_rotary(self, base=10000):
self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
if not hasattr(self.config, "rope_scaling"):
rope_scaling_factor = 1.0
else:
rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
if hasattr(self.config,"max_sequence_length"):
if hasattr(self.config, "max_sequence_length"):
max_seq_len = self.config.max_sequence_length
elif hasattr(self.config,"max_position_embeddings"):
elif hasattr(self.config, "max_position_embeddings"):
max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
else:
max_seq_len = 2048 * rope_scaling_factor
max_seq_len = 2048 * rope_scaling_factor
base = float(base)
inv_freq = 1.0 / (base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_))
inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) /
self.config.head_dim_))
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
freqs = torch.outer(t, inv_freq)

self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
return
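# Illustrative shape check, not part of this test: mirrors the cache construction in
# init_to_get_rotary above with llama-7b-style values (head_dim = 4096 / 32 = 128,
# rope_scaling_factor = 1.0), kept on CPU for clarity.
import torch

head_dim, base, max_seq_len = 128, 10000.0, 2048
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
t = torch.arange(max_seq_len + 1024 * 64, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
print(inv_freq.shape, freqs.shape)  # torch.Size([64]) torch.Size([67584, 64])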


@parameterize('test_config', [{
'tp_size': TPSIZE,
}])
def run_llama_test(test_config):

llama_model_path = "/data/scratch/llama-7b-hf"
tokenizer = LlamaTokenizer.from_pretrained(llama_model_path)
tokenizer.pad_token_id = tokenizer.unk_token_id
model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id)
init_to_get_rotary(model.model, base=10000)
model = model.half()

text = "how is weather today?"
input_ids = tokenizer.encode(text, return_tensors='pt', device='cuda')

infer_engine = TPInferEngine(model.half(), BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
shardformer = ShardFormer(shard_config=shard_config)
@@ -65,7 +68,7 @@ def run_llama_test(test_config):
generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
outputs = infer_engine.generate(input_ids, generate_kwargs)
print("outputs.shape: ", outputs.shape)

print("outputs: ", outputs)

output_text = tokenizer.decode(outputs[0])