From 3bd06df4dcd82198487f2561d95953a0c298df74 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 22 Dec 2025 12:39:48 +0000 Subject: [PATCH 1/5] add sampling aot --- aiter/aot/sampling.py | 89 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 aiter/aot/sampling.py diff --git a/aiter/aot/sampling.py b/aiter/aot/sampling.py new file mode 100644 index 0000000000..0d758d5d83 --- /dev/null +++ b/aiter/aot/sampling.py @@ -0,0 +1,89 @@ +from collections import namedtuple +import os +import concurrent.futures +from csrc.cpp_itfs.sampling.top_k_renorm_probs import ( + compile as top_k_renorm_probs_compile, +) +from csrc.cpp_itfs.sampling.top_p_sampling_from_probs import ( + compile as top_p_sampling_from_probs_compile, +) +from csrc.cpp_itfs.sampling.top_k_top_p_sampling_from_probs import ( + compile as top_k_top_p_sampling_from_probs_compile, +) + +TopKRenormConfig = namedtuple( + "TopKRenormConfig", + ["vec_size", "func_name"], +) + +TopPSamplingConfig = namedtuple( + "TopPSamplingConfig", + ["vec_size", "deterministic", "func_name"], +) + +TopKTopPSamplingConfig = namedtuple( + "TopKTopPSamplingConfig", + ["vec_size", "deterministic", "func_name"], +) + + +def process_top_k_renorm_config(config): + return top_k_renorm_probs_compile(config.vec_size) + + +def process_top_p_sampling_config(config): + return top_p_sampling_from_probs_compile(config.vec_size, config.deterministic) + + +def process_top_k_top_p_sampling_config(config): + return top_k_top_p_sampling_from_probs_compile( + config.vec_size, config.deterministic + ) + + +def main(): + # Generate configs for top_k_renorm_probs + top_k_renorm_configs = [] + for vec_size in range(1, 5): + top_k_renorm_configs.append( + TopKRenormConfig( + vec_size=vec_size, + func_name="top_k_renorm_probs", + ) + ) + + # Generate configs for top_p_sampling_from_probs + top_p_sampling_configs = [] + for vec_size in range(1, 5): + for deterministic in [False, True]: + top_p_sampling_configs.append( + TopPSamplingConfig( + vec_size=vec_size, + deterministic=deterministic, + func_name="top_p_sampling_from_probs", + ) + ) + + # Generate configs for top_k_top_p_sampling_from_probs + top_k_top_p_sampling_configs = [] + for vec_size in range(1, 5): + for deterministic in [False, True]: + top_k_top_p_sampling_configs.append( + TopKTopPSamplingConfig( + vec_size=vec_size, + deterministic=deterministic, + func_name="top_k_top_p_sampling_from_probs", + ) + ) + + max_jobs = int(os.environ.get("MAX_JOBS", os.cpu_count() or 16)) + + # Process all configs in parallel + with concurrent.futures.ProcessPoolExecutor(max_workers=max_jobs) as executor: + executor.map(process_top_k_renorm_config, top_k_renorm_configs) + executor.map(process_top_p_sampling_config, top_p_sampling_configs) + executor.map(process_top_k_top_p_sampling_config, top_k_top_p_sampling_configs) + + +if __name__ == "__main__": + main() From 4aa11150bd0a72d18cbfc8eafc5ccc39bc0cffc0 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 22 Dec 2025 12:42:10 +0000 Subject: [PATCH 2/5] simple compile --- csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja | 4 +--- csrc/cpp_itfs/sampling/top_k_renorm_probs.py | 10 +++++----- .../sampling/top_k_top_p_sampling_from_probs.cpp.jinja | 4 +--- .../sampling/top_k_top_p_sampling_from_probs.py | 9 +++++---- .../sampling/top_p_sampling_from_probs.cpp.jinja | 3 +-- csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py | 9 +++++---- 6 files changed, 18 insertions(+), 21 deletions(-) diff --git a/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja b/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja index f7d0261f9c..ba6aea1f13 100644 --- a/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja +++ b/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja @@ -32,12 +32,10 @@ FUNCTION_DEFINE; FUNCTION_DEFINE { - constexpr uint32_t vec_size = std::gcd(16 / sizeof(float), {{d}}); - const uint32_t smem_size = sizeof(aiter::sampling::RenormTempStorage); dim3 nblks(batch_size); dim3 nthrs(aiter::sampling::BLOCK_THREADS); - auto kernel = aiter::sampling::TopKRenormProbKernel; + auto kernel = aiter::sampling::TopKRenormProbKernel; hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); kernel<<(stream)>>>(reinterpret_cast(probs_ptr), reinterpret_cast(renormed_probs_ptr), reinterpret_cast(top_k_arr_ptr), top_k_val, {{d}}); } \ No newline at end of file diff --git a/csrc/cpp_itfs/sampling/top_k_renorm_probs.py b/csrc/cpp_itfs/sampling/top_k_renorm_probs.py index cfc816798f..2e6467e082 100644 --- a/csrc/cpp_itfs/sampling/top_k_renorm_probs.py +++ b/csrc/cpp_itfs/sampling/top_k_renorm_probs.py @@ -4,7 +4,7 @@ from jinja2 import Template from csrc.cpp_itfs.utils import compile_template_op, AITER_CORE_DIR - +import math MD_NAME = "top_k_renorm_probs" @@ -16,7 +16,7 @@ def compile( - d: int, + vec_size: int, folder: str = None, ): return compile_template_op( @@ -27,7 +27,7 @@ def compile( f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/sampling.cuh", f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/vec_dtypes.cuh", ], - d=d, + vec_size=vec_size, folder=folder, ) @@ -46,10 +46,10 @@ def top_k_renorm_probs( batch_size = probs.size(0) vocab_size = probs.size(1) - + vec_size = math.gcd(16 // probs.element_size(), vocab_size) renorm_probs = torch.empty_like(probs) - func = compile(vocab_size) + func = compile(vec_size) ( probs_ptr, renorm_probs_ptr, diff --git a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja index 301b5c9790..14a259082c 100644 --- a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja +++ b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja @@ -37,13 +37,11 @@ FUNCTION_DEFINE; FUNCTION_DEFINE { - constexpr uint32_t vec_size = std::gcd(16 / sizeof(float), {{d}}); - const uint32_t smem_size = sizeof(aiter::sampling::SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(aiter::sampling::BLOCK_THREADS); auto kernel = aiter::sampling::TopKTopPSamplingFromProbKernel; + {{vec_size}}, {{"true" if deterministic else "false"}}, float, int>; hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); kernel<<(stream)>>>(reinterpret_cast(probs_ptr), reinterpret_cast(top_k_arr_ptr), reinterpret_cast(top_p_arr_ptr), reinterpret_cast(output_ptr), reinterpret_cast(indices_ptr), top_k_val, top_p_val, {{d}}, static_cast(philox_seed), static_cast(philox_offset)); } \ No newline at end of file diff --git a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py index 48fbe6e6f3..cb16129fc5 100644 --- a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py +++ b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py @@ -4,6 +4,7 @@ from jinja2 import Template from csrc.cpp_itfs.utils import compile_template_op, AITER_CORE_DIR, str_to_bool +import math MD_NAME = "top_k_top_p_sampling_from_probs" @@ -16,7 +17,7 @@ def compile( - d: int, + vec_size: int, deterministic: bool, folder: str = None, ): @@ -28,7 +29,7 @@ def compile( f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/sampling.cuh", f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/vec_dtypes.cuh", ], - d=d, + vec_size=vec_size, deterministic=deterministic, folder=folder, ) @@ -61,8 +62,8 @@ def top_k_top_p_sampling_from_probs( philox_seed = generator.seed() output = torch.empty(batch_size, dtype=torch.int32, device=probs.device) - - func = compile(vocab_size, deterministic) + vec_size = math.gcd(16 // probs.element_size(), vocab_size) + func = compile(vec_size, deterministic) ( probs_ptr, output_ptr, diff --git a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja index 99c23b44e7..20ebf6c862 100644 --- a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja +++ b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja @@ -34,13 +34,12 @@ FUNCTION_DEFINE; FUNCTION_DEFINE { - constexpr uint32_t vec_size = std::gcd(16 / sizeof(float), {{d}}); const uint32_t smem_size = sizeof(aiter::sampling::SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(aiter::sampling::BLOCK_THREADS); auto kernel = aiter::sampling::TopPSamplingFromProbKernel; + {{vec_size}}, {{"true" if deterministic else "false"}}, float, int>; hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); kernel<<(stream)>>>(reinterpret_cast(probs_ptr), reinterpret_cast(output_ptr), reinterpret_cast(indices_ptr), reinterpret_cast(top_p_arr_ptr), top_p_val, {{d}}, static_cast(philox_seed), static_cast(philox_offset)); } \ No newline at end of file diff --git a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py index 7e1500b231..6a9e26d269 100644 --- a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py +++ b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py @@ -4,6 +4,7 @@ from jinja2 import Template from csrc.cpp_itfs.utils import compile_template_op, AITER_CORE_DIR, str_to_bool +import math MD_NAME = "top_p_sampling_from_probs" @@ -16,7 +17,7 @@ def compile( - d: int, + vec_size: int, deterministic: bool, folder: str = None, ): @@ -28,7 +29,7 @@ def compile( f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/sampling.cuh", f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/vec_dtypes.cuh", ], - d=d, + vec_size=vec_size, deterministic=deterministic, folder=folder, ) @@ -56,9 +57,9 @@ def top_p_sampling_from_probs( batch_size = probs.size(0) vocab_size = probs.size(1) - + vec_size = math.gcd(16 // probs.element_size(), vocab_size) samples = torch.empty(batch_size, dtype=torch.int32, device=probs.device) - func = compile(vocab_size, deterministic) + func = compile(vec_size, deterministic) ( probs_ptr, samples_ptr, From 9c663bd9f64adba581070b9bf85f9349977cc6fb Mon Sep 17 00:00:00 2001 From: root Date: Mon, 22 Dec 2025 12:56:32 +0000 Subject: [PATCH 3/5] fix compile bugs --- .../sampling/top_k_renorm_probs.cpp.jinja | 3 +- csrc/cpp_itfs/sampling/top_k_renorm_probs.py | 30 ++++------- .../top_k_top_p_sampling_from_probs.cpp.jinja | 3 +- .../top_k_top_p_sampling_from_probs.py | 50 ++++++------------- .../top_p_sampling_from_probs.cpp.jinja | 3 +- .../sampling/top_p_sampling_from_probs.py | 42 +++++----------- 6 files changed, 41 insertions(+), 90 deletions(-) diff --git a/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja b/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja index ba6aea1f13..1c176062db 100644 --- a/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja +++ b/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja @@ -24,6 +24,7 @@ void* top_k_arr_ptr, \ int batch_size, \ int top_k_val, \ + int vocab_size, \ void* stream) extern "C" { @@ -37,5 +38,5 @@ FUNCTION_DEFINE dim3 nthrs(aiter::sampling::BLOCK_THREADS); auto kernel = aiter::sampling::TopKRenormProbKernel; hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); - kernel<<(stream)>>>(reinterpret_cast(probs_ptr), reinterpret_cast(renormed_probs_ptr), reinterpret_cast(top_k_arr_ptr), top_k_val, {{d}}); + kernel<<(stream)>>>(reinterpret_cast(probs_ptr), reinterpret_cast(renormed_probs_ptr), reinterpret_cast(top_k_arr_ptr), top_k_val, vocab_size); } \ No newline at end of file diff --git a/csrc/cpp_itfs/sampling/top_k_renorm_probs.py b/csrc/cpp_itfs/sampling/top_k_renorm_probs.py index 2e6467e082..3cd5b05cab 100644 --- a/csrc/cpp_itfs/sampling/top_k_renorm_probs.py +++ b/csrc/cpp_itfs/sampling/top_k_renorm_probs.py @@ -50,28 +50,16 @@ def top_k_renorm_probs( renorm_probs = torch.empty_like(probs) func = compile(vec_size) - ( - probs_ptr, - renorm_probs_ptr, - top_k_arr_ptr, - top_k_val, - batch_size, - stream, - ) = torch_to_c_types( - probs, - renorm_probs, - maybe_top_k_arr, - top_k_val, - batch_size, - torch.cuda.current_stream(), - ) func( - probs_ptr, - renorm_probs_ptr, - top_k_arr_ptr, - batch_size, - top_k_val, - stream, + *torch_to_c_types( + probs, + renorm_probs, + maybe_top_k_arr, + top_k_val, + batch_size, + vocab_size, + torch.cuda.current_stream(), + ) ) return renorm_probs diff --git a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja index 14a259082c..6408c41d96 100644 --- a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja +++ b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja @@ -29,6 +29,7 @@ float top_p_val, \ int philox_seed, \ int philox_offset, \ + int vocab_size, \ void* stream) extern "C" { @@ -43,5 +44,5 @@ FUNCTION_DEFINE auto kernel = aiter::sampling::TopKTopPSamplingFromProbKernel; hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); - kernel<<(stream)>>>(reinterpret_cast(probs_ptr), reinterpret_cast(top_k_arr_ptr), reinterpret_cast(top_p_arr_ptr), reinterpret_cast(output_ptr), reinterpret_cast(indices_ptr), top_k_val, top_p_val, {{d}}, static_cast(philox_seed), static_cast(philox_offset)); + kernel<<(stream)>>>(reinterpret_cast(probs_ptr), reinterpret_cast(top_k_arr_ptr), reinterpret_cast(top_p_arr_ptr), reinterpret_cast(output_ptr), reinterpret_cast(indices_ptr), top_k_val, top_p_val, vocab_size, static_cast(philox_seed), static_cast(philox_offset)); } \ No newline at end of file diff --git a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py index cb16129fc5..b3471a726b 100644 --- a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py +++ b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py @@ -64,43 +64,21 @@ def top_k_top_p_sampling_from_probs( output = torch.empty(batch_size, dtype=torch.int32, device=probs.device) vec_size = math.gcd(16 // probs.element_size(), vocab_size) func = compile(vec_size, deterministic) - ( - probs_ptr, - output_ptr, - indices_ptr, - top_k_arr_ptr, - top_p_arr_ptr, - top_k_val, - top_p_val, - batch_size, - philox_seed, - philox_offset, - stream, - ) = torch_to_c_types( - probs, - output, - indices, - maybe_top_k_arr, - maybe_top_p_arr, - top_k_val, - top_p_val, - batch_size, - philox_seed, - philox_offset, - torch.cuda.current_stream(), - ) func( - probs_ptr, - output_ptr, - indices_ptr, - top_k_arr_ptr, - top_p_arr_ptr, - batch_size, - top_k_val, - top_p_val, - philox_seed, - philox_offset, - stream, + *torch_to_c_types( + probs, + output, + indices, + maybe_top_k_arr, + maybe_top_p_arr, + top_k_val, + top_p_val, + batch_size, + philox_seed, + philox_offset, + vocab_size, + torch.cuda.current_stream(), + ) ) return output diff --git a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja index 20ebf6c862..020b494ffe 100644 --- a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja +++ b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja @@ -26,6 +26,7 @@ float top_p_val, \ int philox_seed, \ int philox_offset, \ + int vocab_size, \ void* stream) extern "C" { @@ -41,5 +42,5 @@ FUNCTION_DEFINE auto kernel = aiter::sampling::TopPSamplingFromProbKernel; hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); - kernel<<(stream)>>>(reinterpret_cast(probs_ptr), reinterpret_cast(output_ptr), reinterpret_cast(indices_ptr), reinterpret_cast(top_p_arr_ptr), top_p_val, {{d}}, static_cast(philox_seed), static_cast(philox_offset)); + kernel<<(stream)>>>(reinterpret_cast(probs_ptr), reinterpret_cast(output_ptr), reinterpret_cast(indices_ptr), reinterpret_cast(top_p_arr_ptr), top_p_val, vocab_size, static_cast(philox_seed), static_cast(philox_offset)); } \ No newline at end of file diff --git a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py index 6a9e26d269..c9f01532b7 100644 --- a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py +++ b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py @@ -60,37 +60,19 @@ def top_p_sampling_from_probs( vec_size = math.gcd(16 // probs.element_size(), vocab_size) samples = torch.empty(batch_size, dtype=torch.int32, device=probs.device) func = compile(vec_size, deterministic) - ( - probs_ptr, - samples_ptr, - indices_ptr, - top_p_arr_ptr, - top_p_val, - batch_size, - philox_seed, - philox_offset, - stream, - ) = torch_to_c_types( - probs, - samples, - indices, - maybe_top_p_arr, - top_p_val, - batch_size, - philox_seed, - philox_offset, - torch.cuda.current_stream(), - ) func( - probs_ptr, - samples_ptr, - indices_ptr, - top_p_arr_ptr, - batch_size, - top_p_val, - philox_seed, - philox_offset, - stream, + torch_to_c_types( + probs, + samples, + indices, + maybe_top_p_arr, + top_p_val, + batch_size, + philox_seed, + philox_offset, + vocab_size, + torch.cuda.current_stream(), + ) ) return samples From 7bea886dbeb3a402338a1b79b6f599ab51f0a176 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 22 Dec 2025 12:57:53 +0000 Subject: [PATCH 4/5] fix a bug --- csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py index c9f01532b7..ebfe374264 100644 --- a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py +++ b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py @@ -61,7 +61,7 @@ def top_p_sampling_from_probs( samples = torch.empty(batch_size, dtype=torch.int32, device=probs.device) func = compile(vec_size, deterministic) func( - torch_to_c_types( + *torch_to_c_types( probs, samples, indices, From 8edda5fe249eb736ea6be147cfc20aefe9e5259f Mon Sep 17 00:00:00 2001 From: root Date: Mon, 22 Dec 2025 14:18:55 +0000 Subject: [PATCH 5/5] revert changes --- csrc/cpp_itfs/sampling/top_k_renorm_probs.py | 33 ++++++++---- .../top_k_top_p_sampling_from_probs.py | 53 ++++++++++++++----- .../sampling/top_p_sampling_from_probs.py | 45 +++++++++++----- 3 files changed, 96 insertions(+), 35 deletions(-) diff --git a/csrc/cpp_itfs/sampling/top_k_renorm_probs.py b/csrc/cpp_itfs/sampling/top_k_renorm_probs.py index 3cd5b05cab..524285c006 100644 --- a/csrc/cpp_itfs/sampling/top_k_renorm_probs.py +++ b/csrc/cpp_itfs/sampling/top_k_renorm_probs.py @@ -50,16 +50,31 @@ def top_k_renorm_probs( renorm_probs = torch.empty_like(probs) func = compile(vec_size) + ( + probs_ptr, + renorm_probs_ptr, + top_k_arr_ptr, + top_k_val, + batch_size, + vocab_size, + stream, + ) = torch_to_c_types( + probs, + renorm_probs, + maybe_top_k_arr, + top_k_val, + batch_size, + vocab_size, + torch.cuda.current_stream(), + ) func( - *torch_to_c_types( - probs, - renorm_probs, - maybe_top_k_arr, - top_k_val, - batch_size, - vocab_size, - torch.cuda.current_stream(), - ) + probs_ptr, + renorm_probs_ptr, + top_k_arr_ptr, + batch_size, + top_k_val, + vocab_size, + stream, ) return renorm_probs diff --git a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py index b3471a726b..0ac5520dc8 100644 --- a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py +++ b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py @@ -64,21 +64,46 @@ def top_k_top_p_sampling_from_probs( output = torch.empty(batch_size, dtype=torch.int32, device=probs.device) vec_size = math.gcd(16 // probs.element_size(), vocab_size) func = compile(vec_size, deterministic) + ( + probs_ptr, + output_ptr, + indices_ptr, + top_k_arr_ptr, + top_p_arr_ptr, + top_k_val, + top_p_val, + batch_size, + philox_seed, + philox_offset, + vocab_size, + stream, + ) = torch_to_c_types( + probs, + output, + indices, + maybe_top_k_arr, + maybe_top_p_arr, + top_k_val, + top_p_val, + batch_size, + philox_seed, + philox_offset, + vocab_size, + torch.cuda.current_stream(), + ) func( - *torch_to_c_types( - probs, - output, - indices, - maybe_top_k_arr, - maybe_top_p_arr, - top_k_val, - top_p_val, - batch_size, - philox_seed, - philox_offset, - vocab_size, - torch.cuda.current_stream(), - ) + probs_ptr, + output_ptr, + indices_ptr, + top_k_arr_ptr, + top_p_arr_ptr, + batch_size, + top_k_val, + top_p_val, + philox_seed, + philox_offset, + vocab_size, + stream, ) return output diff --git a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py index ebfe374264..3c9f5b9af6 100644 --- a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py +++ b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py @@ -60,19 +60,40 @@ def top_p_sampling_from_probs( vec_size = math.gcd(16 // probs.element_size(), vocab_size) samples = torch.empty(batch_size, dtype=torch.int32, device=probs.device) func = compile(vec_size, deterministic) + ( + probs_ptr, + samples_ptr, + indices_ptr, + top_p_arr_ptr, + top_p_val, + batch_size, + philox_seed, + philox_offset, + vocab_size, + stream, + ) = torch_to_c_types( + probs, + samples, + indices, + maybe_top_p_arr, + top_p_val, + batch_size, + philox_seed, + philox_offset, + vocab_size, + torch.cuda.current_stream(), + ) func( - *torch_to_c_types( - probs, - samples, - indices, - maybe_top_p_arr, - top_p_val, - batch_size, - philox_seed, - philox_offset, - vocab_size, - torch.cuda.current_stream(), - ) + probs_ptr, + samples_ptr, + indices_ptr, + top_p_arr_ptr, + batch_size, + top_p_val, + philox_seed, + philox_offset, + vocab_size, + stream, ) return samples