diff --git a/aiter/aot/sampling.py b/aiter/aot/sampling.py
new file mode 100644
index 0000000000..0d758d5d83
--- /dev/null
+++ b/aiter/aot/sampling.py
@@ -0,0 +1,111 @@
+"""Ahead-of-time (AOT) build driver for the sampling kernels.
+
+Calls the ``compile`` entry point of each sampling op once per template
+configuration, fanning the calls out over a pool of worker processes.
+"""
+
+from collections import namedtuple
+import os
+import concurrent.futures
+from csrc.cpp_itfs.sampling.top_k_renorm_probs import (
+    compile as top_k_renorm_probs_compile,
+)
+from csrc.cpp_itfs.sampling.top_p_sampling_from_probs import (
+    compile as top_p_sampling_from_probs_compile,
+)
+from csrc.cpp_itfs.sampling.top_k_top_p_sampling_from_probs import (
+    compile as top_k_top_p_sampling_from_probs_compile,
+)
+
+TopKRenormConfig = namedtuple(
+    "TopKRenormConfig",
+    ["vec_size", "func_name"],
+)
+
+TopPSamplingConfig = namedtuple(
+    "TopPSamplingConfig",
+    ["vec_size", "deterministic", "func_name"],
+)
+
+TopKTopPSamplingConfig = namedtuple(
+    "TopKTopPSamplingConfig",
+    ["vec_size", "deterministic", "func_name"],
+)
+
+
+def process_top_k_renorm_config(config):
+    """Compile top_k_renorm_probs for ``config.vec_size``."""
+    return top_k_renorm_probs_compile(config.vec_size)
+
+
+def process_top_p_sampling_config(config):
+    """Compile top_p_sampling_from_probs for one (vec_size, deterministic)."""
+    return top_p_sampling_from_probs_compile(config.vec_size, config.deterministic)
+
+
+def process_top_k_top_p_sampling_config(config):
+    """Compile top_k_top_p_sampling_from_probs for one (vec_size, deterministic)."""
+    return top_k_top_p_sampling_from_probs_compile(
+        config.vec_size, config.deterministic
+    )
+
+
+def main():
+    """Compile every sampling kernel configuration in parallel.
+
+    The worker count comes from the ``MAX_JOBS`` environment variable and
+    falls back to ``os.cpu_count()``, or 16 when that returns ``None``.
+
+    Raises:
+        Exception: Any exception raised by a ``compile`` call in a worker
+            process is re-raised here, so a failed build is not silent.
+    """
+    vec_sizes = range(1, 5)
+    deterministic_modes = (False, True)
+
+    top_k_renorm_configs = [
+        TopKRenormConfig(vec_size=vec_size, func_name="top_k_renorm_probs")
+        for vec_size in vec_sizes
+    ]
+    top_p_sampling_configs = [
+        TopPSamplingConfig(
+            vec_size=vec_size,
+            deterministic=deterministic,
+            func_name="top_p_sampling_from_probs",
+        )
+        for vec_size in vec_sizes
+        for deterministic in deterministic_modes
+    ]
+    top_k_top_p_sampling_configs = [
+        TopKTopPSamplingConfig(
+            vec_size=vec_size,
+            deterministic=deterministic,
+            func_name="top_k_top_p_sampling_from_probs",
+        )
+        for vec_size in vec_sizes
+        for deterministic in deterministic_modes
+    ]
+    jobs = (
+        (process_top_k_renorm_config, top_k_renorm_configs),
+        (process_top_p_sampling_config, top_p_sampling_configs),
+        (process_top_k_top_p_sampling_config, top_k_top_p_sampling_configs),
+    )
+
+    max_jobs = int(os.environ.get("MAX_JOBS", os.cpu_count() or 16))
+
+    with concurrent.futures.ProcessPoolExecutor(max_workers=max_jobs) as executor:
+        # Submit everything up front so all three kernel families build
+        # concurrently.
+        futures = [
+            executor.submit(process, config)
+            for process, configs in jobs
+            for config in configs
+        ]
+        # A worker's exception is only stored on its future; calling result()
+        # re-raises it here instead of letting a failed compile go unnoticed.
+        for future in concurrent.futures.as_completed(futures):
+            future.result()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja b/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja index f7d0261f9c..1c176062db 100644 --- a/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja +++ b/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja @@ -24,6 +24,7 @@ void* top_k_arr_ptr, \ int batch_size, \ int top_k_val, \ + int vocab_size, \ void* stream) extern "C" { @@ -32,12 +33,10 @@ FUNCTION_DEFINE; FUNCTION_DEFINE { - constexpr uint32_t vec_size = std::gcd(16 / sizeof(float), {{d}}); - const uint32_t smem_size = sizeof(aiter::sampling::RenormTempStorage); dim3 nblks(batch_size); dim3 nthrs(aiter::sampling::BLOCK_THREADS); - auto kernel = aiter::sampling::TopKRenormProbKernel; + auto kernel = aiter::sampling::TopKRenormProbKernel; hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); - kernel<<(stream)>>>(reinterpret_cast(probs_ptr), reinterpret_cast(renormed_probs_ptr), reinterpret_cast(top_k_arr_ptr), top_k_val, {{d}}); + kernel<<(stream)>>>(reinterpret_cast(probs_ptr), reinterpret_cast(renormed_probs_ptr), reinterpret_cast(top_k_arr_ptr), top_k_val, vocab_size); } \ No newline at end of file diff --git a/csrc/cpp_itfs/sampling/top_k_renorm_probs.py b/csrc/cpp_itfs/sampling/top_k_renorm_probs.py index cfc816798f..524285c006 100644 --- a/csrc/cpp_itfs/sampling/top_k_renorm_probs.py 
+++ b/csrc/cpp_itfs/sampling/top_k_renorm_probs.py @@ -4,7 +4,7 @@ from jinja2 import Template from csrc.cpp_itfs.utils import compile_template_op, AITER_CORE_DIR - +import math MD_NAME = "top_k_renorm_probs" @@ -16,7 +16,7 @@ def compile( - d: int, + vec_size: int, folder: str = None, ): return compile_template_op( @@ -27,7 +27,7 @@ def compile( f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/sampling.cuh", f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/vec_dtypes.cuh", ], - d=d, + vec_size=vec_size, folder=folder, ) @@ -46,16 +46,17 @@ def top_k_renorm_probs( batch_size = probs.size(0) vocab_size = probs.size(1) - + vec_size = math.gcd(16 // probs.element_size(), vocab_size) renorm_probs = torch.empty_like(probs) - func = compile(vocab_size) + func = compile(vec_size) ( probs_ptr, renorm_probs_ptr, top_k_arr_ptr, top_k_val, batch_size, + vocab_size, stream, ) = torch_to_c_types( probs, @@ -63,6 +64,7 @@ def top_k_renorm_probs( maybe_top_k_arr, top_k_val, batch_size, + vocab_size, torch.cuda.current_stream(), ) func( @@ -71,6 +73,7 @@ def top_k_renorm_probs( top_k_arr_ptr, batch_size, top_k_val, + vocab_size, stream, ) return renorm_probs diff --git a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja index 301b5c9790..6408c41d96 100644 --- a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja +++ b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja @@ -29,6 +29,7 @@ float top_p_val, \ int philox_seed, \ int philox_offset, \ + int vocab_size, \ void* stream) extern "C" { @@ -37,13 +38,11 @@ FUNCTION_DEFINE; FUNCTION_DEFINE { - constexpr uint32_t vec_size = std::gcd(16 / sizeof(float), {{d}}); - const uint32_t smem_size = sizeof(aiter::sampling::SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(aiter::sampling::BLOCK_THREADS); auto kernel = aiter::sampling::TopKTopPSamplingFromProbKernel; + {{vec_size}}, {{"true" if deterministic else "false"}}, float, int>; 
hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); - kernel<<(stream)>>>(reinterpret_cast(probs_ptr), reinterpret_cast(top_k_arr_ptr), reinterpret_cast(top_p_arr_ptr), reinterpret_cast(output_ptr), reinterpret_cast(indices_ptr), top_k_val, top_p_val, {{d}}, static_cast(philox_seed), static_cast(philox_offset)); + kernel<<(stream)>>>(reinterpret_cast(probs_ptr), reinterpret_cast(top_k_arr_ptr), reinterpret_cast(top_p_arr_ptr), reinterpret_cast(output_ptr), reinterpret_cast(indices_ptr), top_k_val, top_p_val, vocab_size, static_cast(philox_seed), static_cast(philox_offset)); } \ No newline at end of file diff --git a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py index 48fbe6e6f3..0ac5520dc8 100644 --- a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py +++ b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py @@ -4,6 +4,7 @@ from jinja2 import Template from csrc.cpp_itfs.utils import compile_template_op, AITER_CORE_DIR, str_to_bool +import math MD_NAME = "top_k_top_p_sampling_from_probs" @@ -16,7 +17,7 @@ def compile( - d: int, + vec_size: int, deterministic: bool, folder: str = None, ): @@ -28,7 +29,7 @@ def compile( f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/sampling.cuh", f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/vec_dtypes.cuh", ], - d=d, + vec_size=vec_size, deterministic=deterministic, folder=folder, ) @@ -61,8 +62,8 @@ def top_k_top_p_sampling_from_probs( philox_seed = generator.seed() output = torch.empty(batch_size, dtype=torch.int32, device=probs.device) - - func = compile(vocab_size, deterministic) + vec_size = math.gcd(16 // probs.element_size(), vocab_size) + func = compile(vec_size, deterministic) ( probs_ptr, output_ptr, @@ -74,6 +75,7 @@ def top_k_top_p_sampling_from_probs( batch_size, philox_seed, philox_offset, + vocab_size, stream, ) = torch_to_c_types( probs, @@ -86,6 +88,7 @@ def top_k_top_p_sampling_from_probs( 
batch_size, philox_seed, philox_offset, + vocab_size, torch.cuda.current_stream(), ) func( @@ -99,6 +102,7 @@ def top_k_top_p_sampling_from_probs( top_p_val, philox_seed, philox_offset, + vocab_size, stream, ) return output diff --git a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja index 99c23b44e7..020b494ffe 100644 --- a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja +++ b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja @@ -26,6 +26,7 @@ float top_p_val, \ int philox_seed, \ int philox_offset, \ + int vocab_size, \ void* stream) extern "C" { @@ -34,13 +35,12 @@ FUNCTION_DEFINE; FUNCTION_DEFINE { - constexpr uint32_t vec_size = std::gcd(16 / sizeof(float), {{d}}); const uint32_t smem_size = sizeof(aiter::sampling::SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(aiter::sampling::BLOCK_THREADS); auto kernel = aiter::sampling::TopPSamplingFromProbKernel; + {{vec_size}}, {{"true" if deterministic else "false"}}, float, int>; hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); - kernel<<(stream)>>>(reinterpret_cast(probs_ptr), reinterpret_cast(output_ptr), reinterpret_cast(indices_ptr), reinterpret_cast(top_p_arr_ptr), top_p_val, {{d}}, static_cast(philox_seed), static_cast(philox_offset)); + kernel<<(stream)>>>(reinterpret_cast(probs_ptr), reinterpret_cast(output_ptr), reinterpret_cast(indices_ptr), reinterpret_cast(top_p_arr_ptr), top_p_val, vocab_size, static_cast(philox_seed), static_cast(philox_offset)); } \ No newline at end of file diff --git a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py index 7e1500b231..3c9f5b9af6 100644 --- a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py +++ b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py @@ -4,6 +4,7 @@ from jinja2 import Template from csrc.cpp_itfs.utils import compile_template_op, AITER_CORE_DIR, 
str_to_bool +import math MD_NAME = "top_p_sampling_from_probs" @@ -16,7 +17,7 @@ def compile( - d: int, + vec_size: int, deterministic: bool, folder: str = None, ): @@ -28,7 +29,7 @@ def compile( f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/sampling.cuh", f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/vec_dtypes.cuh", ], - d=d, + vec_size=vec_size, deterministic=deterministic, folder=folder, ) @@ -56,9 +57,9 @@ def top_p_sampling_from_probs( batch_size = probs.size(0) vocab_size = probs.size(1) - + vec_size = math.gcd(16 // probs.element_size(), vocab_size) samples = torch.empty(batch_size, dtype=torch.int32, device=probs.device) - func = compile(vocab_size, deterministic) + func = compile(vec_size, deterministic) ( probs_ptr, samples_ptr, @@ -68,6 +69,7 @@ def top_p_sampling_from_probs( batch_size, philox_seed, philox_offset, + vocab_size, stream, ) = torch_to_c_types( probs, @@ -78,6 +80,7 @@ def top_p_sampling_from_probs( batch_size, philox_seed, philox_offset, + vocab_size, torch.cuda.current_stream(), ) func( @@ -89,6 +92,7 @@ def top_p_sampling_from_probs( top_p_val, philox_seed, philox_offset, + vocab_size, stream, ) return samples