-
Notifications
You must be signed in to change notification settings - Fork 4.5k
[Infer] Bug fix rotary embedding in llama #4608
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
tiandiao123
merged 8 commits into
hpcaitech:feature/colossal-inference
from
Xu-Kai:inference
Sep 5, 2023
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
c7d6d8d
fix rotary embedding
Xu-Kai 0eadde2
merge colossal inference
Xu-Kai 64c0782
delete print
Xu-Kai 64d8d55
fix init seq len bug
Xu-Kai a255925
rename pytest
Xu-Kai c525e78
add benchmark for llama
Xu-Kai 117cdf1
refactor codes
Xu-Kai 2af8002
delete useless code
Xu-Kai File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,93 @@ | ||
| # Adapted from ModelTC https://github.com/ModelTC/lightllm | ||
| import torch | ||
| import triton | ||
| import triton.language as tl | ||
|
|
||
|
|
||
| @triton.jit | ||
| def _rotary_kernel( | ||
| q, | ||
| Cos, | ||
| Sin, | ||
| q_bs_stride, | ||
| q_h_stride, | ||
| q_d_stride, | ||
| cos_bs_stride, | ||
| cos_d_stride, | ||
| total_len, | ||
| HEAD_NUM: tl.constexpr, | ||
| BLOCK_HEAD: tl.constexpr, | ||
| BLOCK_SEQ: tl.constexpr, | ||
| HEAD_DIM: tl.constexpr, | ||
| ): | ||
| current_head_index = tl.program_id(0) | ||
| current_seq_index = tl.program_id(1) | ||
|
|
||
| current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) | ||
| current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) | ||
|
|
||
| dim_range0 = tl.arange(0, HEAD_DIM // 2) | ||
| dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) | ||
|
|
||
| off_q0 = current_seq_range[:, None, None] * q_bs_stride + current_head_range[ | ||
| None, :, None] * q_h_stride + dim_range0[None, None, :] * q_d_stride | ||
| off_q1 = current_seq_range[:, None, None] * q_bs_stride + current_head_range[ | ||
| None, :, None] * q_h_stride + dim_range1[None, None, :] * q_d_stride | ||
|
|
||
| off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride | ||
|
|
||
| q0 = tl.load(q + off_q0, | ||
| mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), | ||
| other=0.0) | ||
| q1 = tl.load(q + off_q1, | ||
| mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), | ||
| other=0.0) | ||
|
|
||
| cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) | ||
| sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) | ||
|
|
||
| out0 = q0 * cos - q1 * sin | ||
| out1 = q0 * sin + q1 * cos | ||
|
|
||
| tl.store(q + off_q0, | ||
| out0, | ||
| mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM)) | ||
| tl.store(q + off_q1, | ||
| out1, | ||
| mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM)) | ||
|
|
||
| return | ||
|
|
||
|
|
||
| @torch.no_grad() | ||
| def rotary_embedding_fwd(q, cos, sin): | ||
| total_len = q.shape[0] | ||
| head_num = q.shape[1] | ||
| head_dim = q.shape[2] | ||
| assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" | ||
| BLOCK_HEAD = 4 | ||
| BLOCK_SEQ = 32 | ||
| grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) | ||
| if head_dim >= 128: | ||
| num_warps = 8 | ||
| else: | ||
| num_warps = 4 | ||
|
|
||
| _rotary_kernel[grid]( | ||
| q, | ||
| cos, | ||
| sin, | ||
| q.stride(0), | ||
| q.stride(1), | ||
| q.stride(2), | ||
| cos.stride(0), | ||
| cos.stride(1), | ||
| total_len, | ||
| HEAD_NUM=head_num, | ||
| BLOCK_HEAD=BLOCK_HEAD, | ||
| BLOCK_SEQ=BLOCK_SEQ, | ||
| HEAD_DIM=head_dim, | ||
| num_warps=num_warps, | ||
| num_stages=1, | ||
| ) | ||
| return |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.