Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions aiter/aot/sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from collections import namedtuple
import os
import concurrent.futures
from csrc.cpp_itfs.sampling.top_k_renorm_probs import (
compile as top_k_renorm_probs_compile,
)
from csrc.cpp_itfs.sampling.top_p_sampling_from_probs import (
compile as top_p_sampling_from_probs_compile,
)
from csrc.cpp_itfs.sampling.top_k_top_p_sampling_from_probs import (
compile as top_k_top_p_sampling_from_probs_compile,
)

TopKRenormConfig = namedtuple(
"TopKRenormConfig",
["vec_size", "func_name"],
)

TopPSamplingConfig = namedtuple(
"TopPSamplingConfig",
["vec_size", "deterministic", "func_name"],
)

TopKTopPSamplingConfig = namedtuple(
"TopKTopPSamplingConfig",
["vec_size", "deterministic", "func_name"],
)


def process_top_k_renorm_config(config):
return top_k_renorm_probs_compile(config.vec_size)


def process_top_p_sampling_config(config):
return top_p_sampling_from_probs_compile(config.vec_size, config.deterministic)


def process_top_k_top_p_sampling_config(config):
return top_k_top_p_sampling_from_probs_compile(
config.vec_size, config.deterministic
)


def main():
# Generate configs for top_k_renorm_probs
top_k_renorm_configs = []
for vec_size in range(1, 5):
top_k_renorm_configs.append(
TopKRenormConfig(
vec_size=vec_size,
func_name="top_k_renorm_probs",
)
)

# Generate configs for top_p_sampling_from_probs
top_p_sampling_configs = []
for vec_size in range(1, 5):
for deterministic in [False, True]:
top_p_sampling_configs.append(
TopPSamplingConfig(
vec_size=vec_size,
deterministic=deterministic,
func_name="top_p_sampling_from_probs",
)
)

# Generate configs for top_k_top_p_sampling_from_probs
top_k_top_p_sampling_configs = []
for vec_size in range(1, 5):
for deterministic in [False, True]:
top_k_top_p_sampling_configs.append(
TopKTopPSamplingConfig(
vec_size=vec_size,
deterministic=deterministic,
func_name="top_k_top_p_sampling_from_probs",
)
)

max_jobs = int(os.environ.get("MAX_JOBS", os.cpu_count() or 16))

# Process all configs in parallel
with concurrent.futures.ProcessPoolExecutor(max_workers=max_jobs) as executor:
executor.map(process_top_k_renorm_config, top_k_renorm_configs)
executor.map(process_top_p_sampling_config, top_p_sampling_configs)
executor.map(process_top_k_top_p_sampling_config, top_k_top_p_sampling_configs)


if __name__ == "__main__":
main()
7 changes: 3 additions & 4 deletions csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
void* top_k_arr_ptr, \
int batch_size, \
int top_k_val, \
int vocab_size, \
void* stream)

extern "C" {
Expand All @@ -32,12 +33,10 @@ FUNCTION_DEFINE;

FUNCTION_DEFINE
{
constexpr uint32_t vec_size = std::gcd(16 / sizeof(float), {{d}});

const uint32_t smem_size = sizeof(aiter::sampling::RenormTempStorage<aiter::sampling::BLOCK_THREADS, aiter::sampling::REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(aiter::sampling::BLOCK_THREADS);
auto kernel = aiter::sampling::TopKRenormProbKernel<aiter::sampling::BLOCK_THREADS, aiter::sampling::REDUCE_ALGO, vec_size, float, int>;
auto kernel = aiter::sampling::TopKRenormProbKernel<aiter::sampling::BLOCK_THREADS, aiter::sampling::REDUCE_ALGO, {{vec_size}}, float, int>;
hipFuncSetAttribute(reinterpret_cast<const void*>(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size);
kernel<<<nblks, nthrs, smem_size, reinterpret_cast<hipStream_t>(stream)>>>(reinterpret_cast<float*>(probs_ptr), reinterpret_cast<float*>(renormed_probs_ptr), reinterpret_cast<int*>(top_k_arr_ptr), top_k_val, {{d}});
kernel<<<nblks, nthrs, smem_size, reinterpret_cast<hipStream_t>(stream)>>>(reinterpret_cast<float*>(probs_ptr), reinterpret_cast<float*>(renormed_probs_ptr), reinterpret_cast<int*>(top_k_arr_ptr), top_k_val, vocab_size);
}
13 changes: 8 additions & 5 deletions csrc/cpp_itfs/sampling/top_k_renorm_probs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from jinja2 import Template
from csrc.cpp_itfs.utils import compile_template_op, AITER_CORE_DIR

import math

MD_NAME = "top_k_renorm_probs"

Expand All @@ -16,7 +16,7 @@


def compile(
d: int,
vec_size: int,
folder: str = None,
):
return compile_template_op(
Expand All @@ -27,7 +27,7 @@ def compile(
f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/sampling.cuh",
f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/vec_dtypes.cuh",
],
d=d,
vec_size=vec_size,
folder=folder,
)

Expand All @@ -46,23 +46,25 @@ def top_k_renorm_probs(

batch_size = probs.size(0)
vocab_size = probs.size(1)

vec_size = math.gcd(16 // probs.element_size(), vocab_size)
renorm_probs = torch.empty_like(probs)

func = compile(vocab_size)
func = compile(vec_size)
(
probs_ptr,
renorm_probs_ptr,
top_k_arr_ptr,
top_k_val,
batch_size,
vocab_size,
stream,
) = torch_to_c_types(
probs,
renorm_probs,
maybe_top_k_arr,
top_k_val,
batch_size,
vocab_size,
torch.cuda.current_stream(),
)
func(
Expand All @@ -71,6 +73,7 @@ def top_k_renorm_probs(
top_k_arr_ptr,
batch_size,
top_k_val,
vocab_size,
stream,
)
return renorm_probs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
float top_p_val, \
int philox_seed, \
int philox_offset, \
int vocab_size, \
void* stream)

extern "C" {
Expand All @@ -37,13 +38,11 @@ FUNCTION_DEFINE;

FUNCTION_DEFINE
{
constexpr uint32_t vec_size = std::gcd(16 / sizeof(float), {{d}});

const uint32_t smem_size = sizeof(aiter::sampling::SamplingTempStorage<aiter::sampling::BLOCK_THREADS, aiter::sampling::SCAN_ALGO, aiter::sampling::REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(aiter::sampling::BLOCK_THREADS);
auto kernel = aiter::sampling::TopKTopPSamplingFromProbKernel<aiter::sampling::BLOCK_THREADS, aiter::sampling::SCAN_ALGO, aiter::sampling::REDUCE_ALGO,
vec_size, {{"true" if deterministic else "false"}}, float, int>;
{{vec_size}}, {{"true" if deterministic else "false"}}, float, int>;
hipFuncSetAttribute(reinterpret_cast<const void*>(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size);
kernel<<<nblks, nthrs, smem_size, reinterpret_cast<hipStream_t>(stream)>>>(reinterpret_cast<float*>(probs_ptr), reinterpret_cast<int*>(top_k_arr_ptr), reinterpret_cast<float*>(top_p_arr_ptr), reinterpret_cast<int*>(output_ptr), reinterpret_cast<int*>(indices_ptr), top_k_val, top_p_val, {{d}}, static_cast<uint64_t>(philox_seed), static_cast<uint64_t>(philox_offset));
kernel<<<nblks, nthrs, smem_size, reinterpret_cast<hipStream_t>(stream)>>>(reinterpret_cast<float*>(probs_ptr), reinterpret_cast<int*>(top_k_arr_ptr), reinterpret_cast<float*>(top_p_arr_ptr), reinterpret_cast<int*>(output_ptr), reinterpret_cast<int*>(indices_ptr), top_k_val, top_p_val, vocab_size, static_cast<uint64_t>(philox_seed), static_cast<uint64_t>(philox_offset));
}
12 changes: 8 additions & 4 deletions csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from jinja2 import Template
from csrc.cpp_itfs.utils import compile_template_op, AITER_CORE_DIR, str_to_bool
import math


MD_NAME = "top_k_top_p_sampling_from_probs"
Expand All @@ -16,7 +17,7 @@


def compile(
d: int,
vec_size: int,
deterministic: bool,
folder: str = None,
):
Expand All @@ -28,7 +29,7 @@ def compile(
f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/sampling.cuh",
f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/vec_dtypes.cuh",
],
d=d,
vec_size=vec_size,
deterministic=deterministic,
folder=folder,
)
Expand Down Expand Up @@ -61,8 +62,8 @@ def top_k_top_p_sampling_from_probs(
philox_seed = generator.seed()

output = torch.empty(batch_size, dtype=torch.int32, device=probs.device)

func = compile(vocab_size, deterministic)
vec_size = math.gcd(16 // probs.element_size(), vocab_size)
func = compile(vec_size, deterministic)
(
probs_ptr,
output_ptr,
Expand All @@ -74,6 +75,7 @@ def top_k_top_p_sampling_from_probs(
batch_size,
philox_seed,
philox_offset,
vocab_size,
stream,
) = torch_to_c_types(
probs,
Expand All @@ -86,6 +88,7 @@ def top_k_top_p_sampling_from_probs(
batch_size,
philox_seed,
philox_offset,
vocab_size,
torch.cuda.current_stream(),
)
func(
Expand All @@ -99,6 +102,7 @@ def top_k_top_p_sampling_from_probs(
top_p_val,
philox_seed,
philox_offset,
vocab_size,
stream,
)
return output
Expand Down
6 changes: 3 additions & 3 deletions csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
float top_p_val, \
int philox_seed, \
int philox_offset, \
int vocab_size, \
void* stream)

extern "C" {
Expand All @@ -34,13 +35,12 @@ FUNCTION_DEFINE;

FUNCTION_DEFINE
{
constexpr uint32_t vec_size = std::gcd(16 / sizeof(float), {{d}});

const uint32_t smem_size = sizeof(aiter::sampling::SamplingTempStorage<aiter::sampling::BLOCK_THREADS, aiter::sampling::SCAN_ALGO, aiter::sampling::REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(aiter::sampling::BLOCK_THREADS);
auto kernel = aiter::sampling::TopPSamplingFromProbKernel<aiter::sampling::BLOCK_THREADS, aiter::sampling::SCAN_ALGO, aiter::sampling::REDUCE_ALGO,
vec_size, {{"true" if deterministic else "false"}}, float, int>;
{{vec_size}}, {{"true" if deterministic else "false"}}, float, int>;
hipFuncSetAttribute(reinterpret_cast<const void*>(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size);
kernel<<<nblks, nthrs, smem_size, reinterpret_cast<hipStream_t>(stream)>>>(reinterpret_cast<float*>(probs_ptr), reinterpret_cast<int*>(output_ptr), reinterpret_cast<int*>(indices_ptr), reinterpret_cast<float*>(top_p_arr_ptr), top_p_val, {{d}}, static_cast<uint64_t>(philox_seed), static_cast<uint64_t>(philox_offset));
kernel<<<nblks, nthrs, smem_size, reinterpret_cast<hipStream_t>(stream)>>>(reinterpret_cast<float*>(probs_ptr), reinterpret_cast<int*>(output_ptr), reinterpret_cast<int*>(indices_ptr), reinterpret_cast<float*>(top_p_arr_ptr), top_p_val, vocab_size, static_cast<uint64_t>(philox_seed), static_cast<uint64_t>(philox_offset));
}
12 changes: 8 additions & 4 deletions csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from jinja2 import Template
from csrc.cpp_itfs.utils import compile_template_op, AITER_CORE_DIR, str_to_bool
import math


MD_NAME = "top_p_sampling_from_probs"
Expand All @@ -16,7 +17,7 @@


def compile(
d: int,
vec_size: int,
deterministic: bool,
folder: str = None,
):
Expand All @@ -28,7 +29,7 @@ def compile(
f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/sampling.cuh",
f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/vec_dtypes.cuh",
],
d=d,
vec_size=vec_size,
deterministic=deterministic,
folder=folder,
)
Expand Down Expand Up @@ -56,9 +57,9 @@ def top_p_sampling_from_probs(

batch_size = probs.size(0)
vocab_size = probs.size(1)

vec_size = math.gcd(16 // probs.element_size(), vocab_size)
samples = torch.empty(batch_size, dtype=torch.int32, device=probs.device)
func = compile(vocab_size, deterministic)
func = compile(vec_size, deterministic)
(
probs_ptr,
samples_ptr,
Expand All @@ -68,6 +69,7 @@ def top_p_sampling_from_probs(
batch_size,
philox_seed,
philox_offset,
vocab_size,
stream,
) = torch_to_c_types(
probs,
Expand All @@ -78,6 +80,7 @@ def top_p_sampling_from_probs(
batch_size,
philox_seed,
philox_offset,
vocab_size,
torch.cuda.current_stream(),
)
func(
Expand All @@ -89,6 +92,7 @@ def top_p_sampling_from_probs(
top_p_val,
philox_seed,
philox_offset,
vocab_size,
stream,
)
return samples
Expand Down