[inference] add int8 rotary embedding kernel for smoothquant #4843
Merged: Xu-Kai merged 62 commits into hpcaitech:feature/smoothquant from Xu-Kai:feature/smoothquant on Sep 29, 2023.
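In short, this PR adds a Triton kernel, int8_rotary_embedding_fwd, for the SmoothQuant inference path: it applies rotary position embeddings in place to an int8 tensor by dequantizing with an input scale, rotating with the supplied cos/sin tables, and requantizing with an output scale. A unit test checks the kernel against a float PyTorch reference.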
Commits (62):
- c7d6975 [shardformer] fix GPT2DoubleHeadsModel (#4703) (flybird11111)
- e2c0e7f [hotfix] Fix import error: colossal.kernel without triton installed (… (yuanheng-zhao)
- 20190b4 [shardformer] to fix whisper test failed due to significant accuracy … (flybird11111)
- ce97790 [doc] fix llama2 code link (#4726) (binmakeswell)
- f911d5b [doc] Add user document for Shardformer (#4702)
- 8c2dda7 [format] applied code formatting on changed files in pull request 472… (github-actions[bot])
- 50e5602 [doc] add shardformer support matrix/update tensor parallel documents…
- e4fc57c Optimized some syntax errors in the documentation and code under appl… (digger-yu)
- 4616263 [shardformer] update pipeline parallel document (#4725) (flybird11111)
- cd4e61d [legacy] remove deterministic data loader test (ppt0011)
- 6a03c93 [shardformer] update seq parallel document (#4730) (FoolPlayer)
- 608cffa [example] add gpt2 HybridParallelPlugin example (#4653) (FoolPlayer)
- 73eb3e8 Merge pull request #4738 from ppt0011/main (ppt0011)
- 451c346 [doc] polish shardformer doc (#4735)
- ac27979 [shardformer] add custom policy in hybrid parallel plugin (#4718) (oahzxl)
- 4c4482f [example] llama2 add fine-tune example (#4673) (flybird11111)
- d151dca [doc] explaination of loading large pretrained models (#4741)
- 32e7f99 [kernel] update triton init #4740 (#4740) (oahzxl)
- b5f9e37 [legacy] clean up legacy code (#4743) (ver217)
- 3c6b831 [format] applied code formatting on changed files in pull request 474… (github-actions[bot])
- 079bf3c [misc] update pre-commit and run all files (#4752) (ver217)
- 10513f2 [doc] explain suitable use case for each plugin (ppt0011)
- a04337b [doc] put individual plugin explanation in front (ppt0011)
- e10d9f0 [doc] add model examples for each plugin (ppt0011)
- 4d7537b [doc] put native colossalai plugins first in description section (ppt0011)
- 07c2e3d Merge pull request #4757 from ppt0011/main (ppt0011)
- 7b9b864 [chat]: update rm, add wandb and fix bugs (#4471) (cwher)
- c0a0337 [shardformer] fix master param sync for hybrid plugin/rewrite unwrapp…
- df66741 [bug] fix get_default_parser in examples (#4764)
- 66f3926 [doc] clean up outdated docs (#4765) (ver217)
- 493a5ef [doc] add shardformer doc to sidebar (#4768)
- 901ab1e [chat]: add lora merge weights config (#4766) (cwher)
- 3e05c07 [lazy] support torch 2.0 (#4763) (ver217)
- 1e0e080 [bug] Fix the version check bug in colossalai run when generating the… (littsk)
- 946ab56 [feature] add gptq for inference (#4754) (Xu-Kai)
- ce7ade3 [inference] chatglm2 infer demo (#4724) (CjhHa1)
- 4146f1c [release] update version (#4775) (ver217)
- 74aa7d9 initial commit: add colossal llama 2 (#4784) (TongLi3701)
- ce77785 [feature] ColossalEval: Evaluation Pipeline for LLMs (#4786) (chengeharrison)
- d512a4d [doc] add llama2 domain-specific solution news (#4789) (binmakeswell)
- 26cd6d8 [fix] fix weekly runing example (#4787) (flybird11111)
- a2db755 [doc] polish shardformer doc (#4779)
- 64a08b2 [checkpointio] support unsharded checkpointIO for hybrid parallel (#4…
- bd01467 update readme (TongLi3701)
- 4965c0d [lazy] support from_pretrained (#4801) (ver217)
- 8cbce61 update (TongLi3701)
- 62b6af1 Merge pull request #4805 from TongLi3701/docs/fix (Desperado-Jia)
- b6cf0ac [hotfix] change llama2 Colossal-LLaMA-2 script filename (#4800) (Chandler-Bing)
- a227063 [misc] add last_epoch in CosineAnnealingWarmupLR (#4778) (hova88)
- da15fdb [doc] add lazy init docs (#4808) (ver217)
- 54b3ad8 [hotfix] fix norm type error in zero optimizer (#4795) (littsk)
- 11f1e42 [hotfix] Correct several erroneous code comments (#4794) (littsk)
- fb46d05 [format] applied code formatting on changed files in pull request 459… (github-actions[bot])
- bbbcac2 fix format (#4815) (TongLi3701)
- be400a0 [chat] fix gemini strategy (#4698) (flybird11111)
- 1fa8c5e Update Qwen-7B results (#4821) (chengeharrison)
- 822051d [doc] update slack link (#4823) (binmakeswell)
- c3bef20 add autotune (#4822) (Xu-Kai)
- ed06731 update Colossal (#4832) (TongLi3701)
- 83f85c8 add int8 rotary embedding kernel (Xu-Kai)
- b4b59d4 remove useless code (Xu-Kai)
- 7d20460 Merge branch 'feature/smoothquant' into feature/smoothquant (Xu-Kai)
colossalai/kernel/triton/int8_rotary_embedding_kernel.py (119 additions, 0 deletions):
```python
# Adapted from ModelTC https://github.com/ModelTC/lightllm
import torch
import triton
import triton.language as tl


@triton.jit
def _rotary_kernel(
    q,
    input_scale,
    output_scale,
    Cos,
    Sin,
    q_bs_stride,
    q_h_stride,
    q_d_stride,
    cos_bs_stride,
    cos_d_stride,
    total_len,
    HEAD_NUM: tl.constexpr,
    BLOCK_HEAD: tl.constexpr,
    BLOCK_SEQ: tl.constexpr,
    HEAD_DIM: tl.constexpr,
):
    current_head_index = tl.program_id(0)
    current_seq_index = tl.program_id(1)

    # split the head dimension into two halves for the rotation
    dim_range0 = tl.arange(0, HEAD_DIM // 2)
    dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)

    current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
    current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)

    off_q0 = (
        current_seq_range[:, None, None] * q_bs_stride
        + current_head_range[None, :, None] * q_h_stride
        + dim_range0[None, None, :] * q_d_stride
    )
    off_q1 = (
        current_seq_range[:, None, None] * q_bs_stride
        + current_head_range[None, :, None] * q_h_stride
        + dim_range1[None, None, :] * q_d_stride
    )

    off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride

    q0 = tl.load(
        q + off_q0,
        mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
        other=0.0,
    )
    q1 = tl.load(
        q + off_q1,
        mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
        other=0.0,
    )

    cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
    sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
    in_scale = tl.load(input_scale)
    o_scale = tl.load(output_scale)

    # dequantize the int8 inputs to float32
    q0 = q0.to(tl.float32) * in_scale
    q1 = q1.to(tl.float32) * in_scale

    # rotate, then requantize by the output scale
    out0 = (q0 * cos - q1 * sin) / o_scale
    out1 = (q0 * sin + q1 * cos) / o_scale

    # explicit casts are unnecessary: tl.store casts to the int8 element type of q
    # out0 = out0.to(tl.int8)
    # out1 = out1.to(tl.int8)

    tl.store(
        q + off_q0,
        out0,
        mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
    )
    tl.store(
        q + off_q1,
        out1,
        mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
    )

    return


@torch.no_grad()
def int8_rotary_embedding_fwd(q, cos, sin, input_scale, output_scale):
    total_len = q.shape[0]
    head_num = q.shape[1]
    head_dim = q.shape[2]
    assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
    BLOCK_HEAD = 4
    BLOCK_SEQ = 32
    # one program per (head block, sequence block)
    grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))
    if head_dim >= 128:
        num_warps = 8
    else:
        num_warps = 4

    _rotary_kernel[grid](
        q,
        input_scale,
        output_scale,
        cos,
        sin,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        cos.stride(0),
        cos.stride(1),
        total_len,
        HEAD_NUM=head_num,
        BLOCK_HEAD=BLOCK_HEAD,
        BLOCK_SEQ=BLOCK_SEQ,
        HEAD_DIM=head_dim,
        num_warps=num_warps,
        num_stages=1,
    )
    return
```
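For reference, the arithmetic the kernel performs can be written in a few lines of plain PyTorch. This is a minimal illustrative sketch, not part of the PR; the helper name rotary_int8_reference and its layout assumption (a contiguous (seq_len, head_num, head_dim) input) are introduced here for illustration only.

```python
import torch


def rotary_int8_reference(q_int8, cos, sin, input_scale, output_scale):
    """Round-trip mirroring _rotary_kernel: dequantize -> rotate -> requantize.

    q_int8: (seq_len, head_num, head_dim) int8 tensor
    cos, sin: (seq_len, head_dim // 2) float tensors
    """
    dim = q_int8.shape[-1]
    # dequantize the two halves of the head dimension
    q0 = q_int8[..., : dim // 2].float() * input_scale
    q1 = q_int8[..., dim // 2 :].float() * input_scale
    # broadcast cos/sin over the head axis
    c = cos.unsqueeze(1)
    s = sin.unsqueeze(1)
    # rotate and requantize by the output scale
    out0 = (q0 * c - q1 * s) / output_scale
    out1 = (q0 * s + q1 * c) / output_scale
    return torch.cat((out0, out1), dim=-1).to(torch.int8)
```

The real kernel fuses these steps and writes the result back into q in place, so no float intermediate ever materializes in global memory.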
The PR also adds a unit test for the kernel (59 additions, 0 deletions):
```python
# Adapted from ModelTC https://github.com/ModelTC/lightllm

import pytest
import torch
from packaging import version

try:
    from colossalai.kernel.triton import int8_rotary_embedding_fwd

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


def torch_rotary_emb(x, cos, sin):
    # float reference: rotate the two halves of the head dimension
    seq_len, h, dim = x.shape
    x0 = x[:, :, 0 : dim // 2]
    x1 = x[:, :, dim // 2 : dim]
    cos = cos.view((seq_len, 1, dim // 2))
    sin = sin.view((seq_len, 1, dim // 2))
    o0 = x0 * cos - x1 * sin
    o1 = x0 * sin + x1 * cos
    return torch.cat((o0, o1), dim=-1)


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_rotary_emb():
    SEQ_LEN = 1
    HEAD_NUM = 32
    HEAD_DIM = 128
    dtype = torch.float
    # create data
    x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM)
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
    cos_shape = (SEQ_LEN, HEAD_DIM // 2)
    cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
    sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
    # forward pass of the float reference
    y_torch = torch_rotary_emb(x, cos, sin)

    # symmetric per-tensor scales: max absolute value over the int8 maximum
    input_scale = torch.max(torch.abs(x)) / 127
    output_scale = torch.max(torch.abs(y_torch)) / 127

    # quantize the input to int8
    x = x / input_scale
    x = x.to(torch.int8)

    # the kernel rotates x in place; dequantize to compare against the reference
    int8_rotary_embedding_fwd(x, cos, sin, input_scale, output_scale)
    y_triton = x.to(torch.float) * output_scale
    assert torch.allclose(y_triton, y_torch, atol=2e-1, rtol=1e-2, equal_nan=True)


if __name__ == "__main__":
    test_rotary_emb()
```
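Note that both scales are symmetric per-tensor scales (max-abs divided by 127, the int8 maximum), and the comparatively loose atol=2e-1 in the assertion leaves headroom for the rounding error introduced by quantizing the input and output to int8.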