diff --git a/aiter/aot/triton/compile.cpp b/aiter/aot/triton/compile.cpp deleted file mode 100644 index 124a305ced..0000000000 --- a/aiter/aot/triton/compile.cpp +++ /dev/null @@ -1,72 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -/* clang-format off */ -#include -#include -#include -#include -#include - -// helpers to check for hip errors -#define HIP_CHECK(ans) {{\ - gpuAssert((ans), __FILE__, __LINE__);\ - }}\ - -static inline void gpuAssert(hipError_t code, const char *file, int line) {{ - if (code != hipSuccess) {{ - const char *prefix = "Triton Error [HIP]: "; - const char *str; - hipDrvGetErrorString(code, &str); - char err[1024] = {{0}}; - strcat(err, prefix); - strcat(err, str); - printf("%s\\n", err); - exit(code); - }} -}} - -// globals -#define HSACO_NAME {kernel_name}_hsaco -hipModule_t {kernel_name}_mod = nullptr; -hipFunction_t {kernel_name}_func = nullptr; -unsigned char HSACO_NAME[{bin_size}] = {{ {bin_data} }}; - - -void unload_{kernel_name}(void) {{ - HIP_CHECK(hipModuleUnload({kernel_name}_mod)); -}} - - -void load_{kernel_name}() {{ - int dev = 0; - void *bin = (void *)&HSACO_NAME; - int shared = {shared}; - HIP_CHECK(hipModuleLoadData(&{kernel_name}_mod, bin)); - HIP_CHECK(hipModuleGetFunction(&{kernel_name}_func, {kernel_name}_mod, "{triton_kernel_name}")); - // set dynamic shared memory if necessary - int shared_optin; - HIP_CHECK(hipDeviceGetAttribute(&shared_optin, hipDeviceAttributeSharedMemPerBlockOptin, dev)); - if (shared > 49152 && shared_optin > 49152) {{ - HIP_CHECK(hipFuncSetCacheConfig({kernel_name}_func, hipFuncCachePreferShared)); - HIP_CHECK(hipFuncSetAttribute(reinterpret_cast({kernel_name}_func), hipFuncAttributeMaxDynamicSharedMemorySize, shared_optin)) - }} -}} - -/* -{kernel_docstring} -*/ -hipError_t {kernel_name}(hipStream_t stream, {signature}) {{ - if ({kernel_name}_func == nullptr) - load_{kernel_name}(); - unsigned int gX = {gridX}; - unsigned int gY = {gridY}; - unsigned int gZ = {gridZ}; - hipDeviceptr_t global_scratch = 0; - void *args[{num_args}] = {{ {arg_pointers} }}; - // TODO: shared memory - if(gX * gY * gZ > 0) - return hipModuleLaunchKernel({kernel_name}_func, gX, gY, gZ, {num_warps} * warpSize, 1, 1, {shared}, stream, args, nullptr); - else - return hipErrorInvalidValue; -}} diff --git a/aiter/aot/triton/compile.h b/aiter/aot/triton/compile.h deleted file mode 100644 index c6d856ddfe..0000000000 --- a/aiter/aot/triton/compile.h +++ /dev/null @@ -1,14 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include -#include -#include - -void unload_{kernel_name}(void); -void load_{kernel_name}(void); -// tt-linker: {kernel_name}:{full_signature}:{algo_info} -hipError_t{_placeholder} {kernel_name}(hipStream_t stream, {signature}); \ No newline at end of file diff --git a/aiter/aot/triton/compile.py b/aiter/aot/triton/compile.py deleted file mode 100644 index acb92ceae3..0000000000 --- a/aiter/aot/triton/compile.py +++ /dev/null @@ -1,307 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
- -import binascii -import hashlib -import importlib.util -import sys -from argparse import ArgumentParser -from pathlib import Path -from typing import List - -import triton - -try: - old_compiler = True - from triton.compiler.code_generator import kernel_suffix -except ImportError: - old_compiler = False - -from triton.backends.amd.driver import ty_to_cpp - -desc = """ -Triton ahead-of-time compiler: - -This program compiles the kernel with name `kernel-name` in the file at the -provided `path` into self-contained C source-code that embeds the `cubin` -data along with utilities to load, unload and launch the kernel. - -signature is provided as a list of (optionally divisibility-hinted) types -or constexpr values, e.g. - -`compile.py --kernel-name kernel --signature "*fp32:16, i32:16, 1024, i32" --out-name kernel /path/to/kernel.py` - -will compile triton.JITFunction of name `kernel` inside the file `/path/to/kernel.py`. -Said kernel will be specialized such that argument 0, 1 are assumed to be multiple of 16, -and argument 2 is assumed to be a compile-time constant of value 1024, i.e. it won't be part of the generated prototype. - -The resulting entry point will have signature - -CUresult kernel_{specialization_suffix}(CUstream stream, unsigned gX, unsigned gY, unsigned gZ, float* arg0, int32_t arg1, int32_t arg2) - -Different such specialized entry points can be combined using the `linker.py` script. - -NOTE: when resolving the scope of /path/to/kernel.py, the file will be executed from within its parent directory with the python interpreter -used to run this `compile.py` script -""" - - -def compile_kernel( - path, - kernel_name: str, - signature: str, - grid: str, - num_warps: int = 1, - num_stages: int = 3, - out_name: str = None, - out_path: Path = None, - waves_per_eu=0, - kpack=2, - matrix_instr_nonkdim=16, -): - out_name = out_name if out_name else kernel_name - out_path = out_path if out_path else Path(out_name) - - arg_path = Path(path) - sys.path.insert(0, str(arg_path.parent)) - spec = importlib.util.spec_from_file_location(arg_path.stem, arg_path) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - kernel = getattr(mod, kernel_name) - grid = grid.split(",") - assert len(grid) == 3 - - # validate and parse signature - signature = list(map(lambda s: s.strip(" "), signature.split(","))) - - def hash_signature(signature: List[str]): - m = hashlib.sha256() - m.update(" ".join(signature).encode()) - return m.hexdigest()[:8] - - meta_sig = f"warps{num_warps}xstages{num_stages}" - sig_hash = hash_signature(signature + [meta_sig]) - - def constexpr(s): - try: - ret = int(s) - return ret - except ValueError: - pass - try: - ret = float(s) - return ret - except ValueError: - pass - return None - - if old_compiler: - hints = { - i: constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s - } - hints = {k: v for k, v in hints.items() if v is not None} - constants = {kernel.arg_names[i]: constexpr(s) for i, s in enumerate(signature)} - constants = {k: v for k, v in constants.items() if v is not None} - signature = { - kernel.arg_names[i]: s.split(":")[0] - for i, s in enumerate(signature) - if kernel.arg_names[i] not in constants - } - const_sig = "x".join([str(v) for v in constants.values()]) - doc_string = [f"{k}={v}" for k, v in constants.items()] - doc_string += [f"num_warps={num_warps}", f"num_stages={num_stages}"] - - # compile ast into cubin - for h in hints.values(): - assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}" - attrs 
= triton.backends.compiler.AttrsDescriptor.from_hints(hints) - for p, v in attrs.get_constants().items(): - constants.update({kernel.arg_names[p]: v}) - - src = triton.compiler.ASTSource( - fn=kernel, constants=constants, signature=signature, attrs=attrs - ) - opts = { - "num_warps": num_warps, - "num_stages": num_stages, - "waves_per_eu": waves_per_eu, - "kpack": kpack, - "matrix_instr_nonkdim": matrix_instr_nonkdim, - } - ccinfo = triton.compile(src, options=opts) - arg_names = [] - arg_types = [] - arg_names_not_1 = [] - arg_types_not_1 = [] - for i, arg_name in enumerate(kernel.arg_names): - if arg_name not in constants: - arg_names.append(arg_name) - arg_types.append(signature[arg_name]) - arg_names_not_1.append(arg_name) - arg_types_not_1.append(signature[arg_name]) - elif i in attrs.equal_to_1: - arg_names.append(arg_name) - arg_types.append(signature[arg_name]) - - # dump C stub code - suffix = kernel_suffix(signature.values(), attrs) - else: - hints = { - (i,): constexpr(s.split(":")[1]) - for i, s in enumerate(signature) - if ":" in s - } - hints = {k: v for k, v in hints.items() if v is not None} - constants = {kernel.arg_names[i]: constexpr(s) for i, s in enumerate(signature)} - constants = {k: v for k, v in constants.items() if v is not None} - for key, value in hints.items(): - if value == 1: - constants[kernel.arg_names[key[0]]] = value - signature = { - kernel.arg_names[i]: s.split(":")[0] for i, s in enumerate(signature) - } - for key in constants: - signature[key] = "constexpr" - const_sig = "x".join([str(v) for v in constants.values()]) - doc_string = [f"{k}={v}" for k, v in constants.items()] - doc_string += [ - f"num_warps={num_warps}", - f"num_stages={num_stages}", - f"waves_per_eu={waves_per_eu}", - f"kpack={kpack}", - f"matrix_instr_nonkdim={matrix_instr_nonkdim}", - ] - # compile ast into cubin - for h in hints.values(): - assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}" - attrs = {k: [["tt.divisibility", 16]] for k, v in hints.items() if v == 16} - src = triton.compiler.ASTSource( - fn=kernel, constexprs=constants, signature=signature, attrs=attrs - ) - opts = { - "num_warps": num_warps, - "num_stages": num_stages, - "waves_per_eu": waves_per_eu, - "kpack": kpack, - "matrix_instr_nonkdim": matrix_instr_nonkdim, - } - ccinfo = triton.compile(src, options=opts) - if ccinfo.metadata.global_scratch_size > 0: - raise RuntimeError( - "AOT compiling kernels with global scratch requirements is not yet implemented" - ) - - arg_names = [] - arg_types = [] - arg_names_not_1 = [] - arg_types_not_1 = [] - for i, arg_name in enumerate(kernel.arg_names): - if arg_name not in constants: - arg_names.append(arg_name) - arg_types.append(signature[arg_name]) - arg_names_not_1.append(arg_name) - arg_types_not_1.append(signature[arg_name]) - elif hints.get((i,), None) == 1: - arg_names.append(arg_name) - arg_types.append("i32") - - # dump C stub code - suffix = "" - for i, ty in enumerate(signature.values()): - suffix += str(i) - if hints.get((i,), None) == 1: - suffix += "c" - if hints.get((i,), None) == 16: - suffix += "d" - - func_name = "_".join([out_name, sig_hash, suffix]) - hex_ = binascii.hexlify(ccinfo.asm["hsaco"]).decode("utf-8") - - params = { - "kernel_name": func_name, - "triton_kernel_name": kernel_name, - "bin_size": len(hex_), - "bin_data": ", ".join([f"0x{x}{y}" for x, y in zip(hex_[::2], hex_[1::2])]), - "signature": ", ".join( - [ - f"{ty_to_cpp(ty)} {name}" - for name, ty in zip(arg_names_not_1, arg_types_not_1) - ] - ), - "full_signature": ", ".join( 
- [f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)] - ), - "arg_pointers": ", ".join( - [f"&{arg}" for arg in arg_names_not_1] + ["&global_scratch"] - ), - "num_args": len(arg_names_not_1) + 1, - "kernel_docstring": doc_string, - "shared": ccinfo.metadata.shared, - "num_warps": num_warps, - "algo_info": "_".join([const_sig, meta_sig]), - "gridX": grid[0], - "gridY": grid[1], - "gridZ": grid[2], - "_placeholder": "", - } - output_files = [] - for ext in ["h", "cpp"]: - template_path = Path(__file__).parent / f"compile.{ext}" - output_file = out_path.with_suffix(f".{sig_hash}_{suffix}.{ext}") - output_files.append(output_file) - with output_file.open("w") as fp: - fp.write(Path(template_path).read_text().format(**params)) - return func_name, *output_files - - -if __name__ == "__main__": - - # command-line arguments - parser = ArgumentParser(description=desc) - parser.add_argument( - "path", - help="Path to Python source containing desired kernel in its scope. File will be executed.", - ) - parser.add_argument( - "--kernel-name", - "-n", - type=str, - default="", - help="Name of the kernel to compile", - required=True, - ) - parser.add_argument( - "--num-warps", - "-w", - type=int, - default=1, - help="Number of warps to launch the kernel", - ) - parser.add_argument("--waves-per-eu", type=int, default=1) - parser.add_argument("--matrix-instr-nonkdim", type=int, default=0) - parser.add_argument("--kpack", type=int, default=1) - parser.add_argument( - "--num-stages", - "-ns", - type=int, - default=3, - help="Number of stages (meta-parameter of the kernel)", - ) - parser.add_argument( - "--out-name", - "-on", - type=str, - default=None, - help="Out name for the compiled kernel", - ) - parser.add_argument( - "--out-path", "-o", type=Path, default=None, help="Out filename" - ) - parser.add_argument( - "--signature", "-s", type=str, help="Signature of the kernel", required=True - ) - parser.add_argument( - "--grid", "-g", type=str, help="Launch grid of the kernel", required=True - ) - args = parser.parse_args() - compile_kernel(**vars(args)) diff --git a/aiter/ops/sampling.py b/aiter/ops/sampling.py new file mode 100644 index 0000000000..0df541569b --- /dev/null +++ b/aiter/ops/sampling.py @@ -0,0 +1,85 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +import torch +from typing import Optional + +from csrc.cpp_itfs.sampling.top_k_renorm_probs import ( + top_k_renorm_probs as top_k_renorm_probs_core, +) +from csrc.cpp_itfs.sampling.top_p_sampling_from_probs import ( + top_p_sampling_from_probs as top_p_sampling_from_probs_core, +) +from csrc.cpp_itfs.sampling.top_k_top_p_sampling_from_probs import ( + top_k_top_p_sampling_from_probs as top_k_top_p_sampling_from_probs_core, +) +from csrc.cpp_itfs.torch_utils import direct_register_custom_op + + +def top_k_renorm_probs( + probs: torch.Tensor, + maybe_top_k_arr: Optional[torch.Tensor], + top_k_val: int, +) -> torch.Tensor: + return top_k_renorm_probs_core( + probs, + maybe_top_k_arr, + top_k_val, + ) + + +direct_register_custom_op( + "top_k_renorm_probs", + top_k_renorm_probs, + [], +) + + +def top_p_sampling_from_probs( + probs: torch.Tensor, + indices: torch.Tensor, + maybe_top_p_arr: Optional[torch.Tensor], + top_p_val: float, + deterministic: bool = False, +) -> torch.Tensor: + return top_p_sampling_from_probs_core( + probs, + indices, + maybe_top_p_arr, + top_p_val, + deterministic, + ) + + +direct_register_custom_op( + "top_p_sampling_from_probs", + top_p_sampling_from_probs, + [], +) + + +def top_k_top_p_sampling_from_probs( + probs: torch.Tensor, + indices: torch.Tensor, + maybe_top_k_arr: Optional[torch.Tensor], + top_k_val: int, + maybe_top_p_arr: Optional[torch.Tensor], + top_p_val: float, + deterministic: bool = False, +) -> torch.Tensor: + return top_k_top_p_sampling_from_probs_core( + probs, + indices, + maybe_top_k_arr, + top_k_val, + maybe_top_p_arr, + top_p_val, + deterministic, + ) + + +direct_register_custom_op( + "top_k_top_p_sampling_from_probs", + top_k_top_p_sampling_from_probs, + [], +) diff --git a/csrc/cpp_itfs/sampling/sampling.cuh b/csrc/cpp_itfs/sampling/sampling.cuh new file mode 100644 index 0000000000..b973c634b5 --- /dev/null +++ b/csrc/cpp_itfs/sampling/sampling.cuh @@ -0,0 +1,818 @@ +/* + * Copyright (C) 2024-2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "hip/hip_runtime.h" + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "vec_dtypes.cuh" + +namespace aiter { + +namespace sampling { + +using namespace hipcub; + +constexpr uint32_t BLOCK_THREADS = 1024; + +constexpr BlockScanAlgorithm SCAN_ALGO = BLOCK_SCAN_WARP_SCANS; +constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS; + +template +struct ValueCount +{ + T value; + int count; + + __device__ ValueCount operator+(const ValueCount& other) const + { + return {value + other.value, count + other.count}; + } + __device__ ValueCount& operator+=(const ValueCount& other) + { + value += other.value; + count += other.count; + return *this; + } +}; + +struct BoolDiffOp +{ + __device__ __forceinline__ bool operator()(const bool& lhs, const bool& rhs) const + { + return lhs != rhs; + } +}; + +template +__forceinline__ __device__ __host__ T1 ceil_div(const T1 x, const T2 y) +{ + return (x + y - 1) / y; +} + +template +struct SamplingTempStorage +{ + union + { + float deterministic_scan[BLOCK_THREADS / 32]; + typename BlockScan::TempStorage scan; + typename BlockReduce::TempStorage reduce; + typename BlockReduce::TempStorage reduce_int; + typename BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage + reduce_value_count; + typename BlockAdjacentDifference::TempStorage adj_diff; + } block_prim; + struct + { + int32_t sampled_id; + int32_t last_valid_id; + float max_val; + union + { + float value; + ValueCount pair; + } block_aggregate; + }; +}; + +template +__device__ __forceinline__ T infinity() +{ + return __builtin_huge_valf(); +} + +/*! + * \brief Deterministic inclusive scan implementation, use Belloch scan algorithm. + * \note This implementation is slower than the hipcub::BlockScan, but it is deterministic. + */ +template +__device__ __forceinline__ void DeterministicInclusiveSum( + const float* in_data, + float* out_data, + SamplingTempStorage* temp_storage) +{ + float* smem_prefix_sum = temp_storage->block_prim.deterministic_scan; + float thread_data[VEC_SIZE]; + float thread_sum = 0; +#pragma unroll + for(uint32_t i = 0; i < VEC_SIZE; ++i) + { + thread_sum += in_data[i]; + thread_data[i] = thread_sum; + } + + float thread_exclusive_prefix_sum = thread_sum; + +#pragma unroll + for(uint32_t offset = 1; offset < 32; offset *= 2) + { + float tmp = __shfl_up(thread_exclusive_prefix_sum, offset); + if((threadIdx.x + 1) % (offset * 2) == 0) + { + thread_exclusive_prefix_sum += tmp; + } + } + + float warp_sum = __shfl(thread_exclusive_prefix_sum, threadIdx.x | 0xffffffff); + if(threadIdx.x % 32 == 31) + { + thread_exclusive_prefix_sum = 0; + } + +#pragma unroll + for(uint32_t offset = 16; offset >= 1; offset /= 2) + { + float tmp = __shfl_xor(thread_exclusive_prefix_sum, offset); + if((threadIdx.x + 1) % (offset * 2) == 0) + { + thread_exclusive_prefix_sum = tmp + thread_exclusive_prefix_sum; + } + if((threadIdx.x + 1) % (offset * 2) == offset) + { + thread_exclusive_prefix_sum = tmp; + } + } + + smem_prefix_sum[threadIdx.x / 32] = warp_sum; + __syncthreads(); + + if(threadIdx.x < 32) + { + float warp_exclusive_prefix_sum = + (threadIdx.x < BLOCK_THREADS / 32) ? 
smem_prefix_sum[threadIdx.x] : 0; + +#pragma unroll + for(uint32_t offset = 1; offset < 32; offset *= 2) + { + float tmp = __shfl_up(warp_exclusive_prefix_sum, offset); + if((threadIdx.x + 1) % (offset * 2) == 0) + { + warp_exclusive_prefix_sum += tmp; + } + } + + if(threadIdx.x % 32 == 31) + { + warp_exclusive_prefix_sum = 0; + } + +#pragma unroll + for(uint32_t offset = 16; offset >= 1; offset /= 2) + { + float tmp = __shfl_xor(warp_exclusive_prefix_sum, offset); + if((threadIdx.x + 1) % (offset * 2) == 0) + { + warp_exclusive_prefix_sum = tmp + warp_exclusive_prefix_sum; + } + if((threadIdx.x + 1) % (offset * 2) == offset) + { + warp_exclusive_prefix_sum = tmp; + } + } + if(threadIdx.x < BLOCK_THREADS / 32) + { + smem_prefix_sum[threadIdx.x] = warp_exclusive_prefix_sum; + } + } + __syncthreads(); + +#pragma unroll + for(uint32_t i = 0; i < VEC_SIZE; ++i) + { + out_data[i] = + smem_prefix_sum[threadIdx.x / 32] + thread_exclusive_prefix_sum + thread_data[i]; + } +} + +template +__device__ __forceinline__ float +GetMaxValue(float* in_data, uint32_t row_idx, uint32_t d, TempStorage& temp_storage) +{ + const uint32_t tx = threadIdx.x; + vec_t in_data_vec; + + float max_val = 0; + for(uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) + { + in_data_vec.fill(0); + if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) + { + in_data_vec.cast_load(in_data + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + } + float in_data_[VEC_SIZE]; +#pragma unroll + for(uint32_t j = 0; j < VEC_SIZE; ++j) + { + in_data_[j] = in_data_vec[j]; + } + max_val = + max(max_val, + BlockReduce(temp_storage.block_prim.reduce) + .Reduce(in_data_, hipcub::Max())); + __syncthreads(); + } + if(tx == 0) + { + temp_storage.max_val = max_val; + } + __syncthreads(); + return temp_storage.max_val; +} + +template +__device__ __forceinline__ void DeviceSamplingFromProb( + uint32_t i, + uint32_t d, + Predicate pred, + float u, + vec_t prob_vec, + float& aggregate, + SamplingTempStorage* temp_storage) +{ + const uint32_t tx = threadIdx.x; + float prob_greater_than_threshold[VEC_SIZE]; + float inclusive_cdf[VEC_SIZE]; + bool greater_than_u[VEC_SIZE], valid[VEC_SIZE]; +#pragma unroll + for(uint32_t j = 0; j < VEC_SIZE; ++j) + { + prob_greater_than_threshold[j] = pred(prob_vec[j]) ? 
prob_vec[j] : 0; + valid[j] = pred(prob_vec[j]) && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d; + } + float aggregate_local = + BlockReduce(temp_storage->block_prim.reduce) + .Sum(prob_greater_than_threshold); + if(tx == 0) + { + temp_storage->block_aggregate.value = aggregate_local; + } + __syncthreads(); + aggregate_local = temp_storage->block_aggregate.value; + + if(aggregate + aggregate_local > u) + { + if constexpr(DETERMINISTIC) + { + DeterministicInclusiveSum( + prob_greater_than_threshold, inclusive_cdf, temp_storage); + } + else + { + BlockScan(temp_storage->block_prim.scan) + .InclusiveSum(prob_greater_than_threshold, inclusive_cdf); + + __syncthreads(); + } + +#pragma unroll + for(uint32_t j = 0; j < VEC_SIZE; ++j) + { + greater_than_u[j] = (inclusive_cdf[j] + aggregate > u) && valid[j]; + } + + bool greater_than_u_diff[VEC_SIZE]; + + BlockAdjacentDifference(temp_storage->block_prim.adj_diff) + .SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp{}); + + __syncthreads(); + +#pragma unroll + for(uint32_t j = 0; j < VEC_SIZE; ++j) + { + if(greater_than_u_diff[j]) + { + atomicMin(&(temp_storage->sampled_id), (i * BLOCK_THREADS + tx) * VEC_SIZE + j); + } + } + __syncthreads(); + } + + // update the last valid index + int valid_index[VEC_SIZE]; +#pragma unroll + for(uint32_t j = 0; j < VEC_SIZE; ++j) + { + if(valid[j]) + { + valid_index[j] = (i * BLOCK_THREADS + tx) * VEC_SIZE + j; + } + else + { + valid_index[j] = -1; + } + } + int max_valid_index = + BlockReduce(temp_storage->block_prim.reduce_int) + .Reduce(valid_index, hipcub::Max()); + if(tx == 0 && max_valid_index != -1) + { + temp_storage->last_valid_id = max_valid_index; + } + __syncthreads(); + aggregate += aggregate_local; +} + +template +__global__ void TopPSamplingFromProbKernel(DType* probs, + IdType* output, + IdType* indices, + float* top_p_arr, + float top_p_val, + uint32_t d, + uint64_t philox_seed, + uint64_t philox_offset) +{ + const uint32_t batch_size = gridDim.x; + const uint32_t bx = blockIdx.x, tx = threadIdx.x; + hiprandStatePhilox4_32_10_t state; + hiprand_init(philox_seed, bx, philox_offset, &state); + const uint32_t row_idx = indices == nullptr ? bx : indices[bx]; + float top_p = (top_p_arr == nullptr) ? 
top_p_val : top_p_arr[row_idx]; + + extern __shared__ __align__( + alignof(SamplingTempStorage)) + uint8_t smem_sampling[]; + auto& temp_storage = + reinterpret_cast&>( + smem_sampling); + + vec_t probs_vec; + float aggregate; + float q = 1; + double low = 0, high = 1.f; + int sampled_id; + do + { + temp_storage.sampled_id = d; + __syncthreads(); + float u = hiprand_uniform(&state) * q; + aggregate = 0; +#pragma unroll 2 + for(uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) + { + probs_vec.fill(0); + if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) + { + probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + } + + DeviceSamplingFromProb( + i, d, [&](float x) { return x > low; }, u, probs_vec, aggregate, &temp_storage); + if(aggregate > u) + { + break; + } + } + __syncthreads(); + sampled_id = temp_storage.sampled_id; + if(sampled_id == d) + { + // NOTE(Zihao): this would happen when u is very close to 1 + // and the sum of probabilities is smaller than u + // In this case, we use the last valid index as the sampled id + sampled_id = temp_storage.last_valid_id; + } + double pivot_0 = probs[row_idx * d + sampled_id]; + double pivot_1 = (pivot_0 + high) / 2; + + float aggregate_gt_pivot_0 = 0, aggregate_gt_pivot_1 = 0; +#pragma unroll 2 + for(uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) + { + probs_vec.fill(0); + if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) + { + probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + } + + float probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE]; +#pragma unroll + for(uint32_t j = 0; j < VEC_SIZE; ++j) + { + probs_gt_pivot_0[j] = (probs_vec[j] > pivot_0) ? probs_vec[j] : 0; + probs_gt_pivot_1[j] = (probs_vec[j] > pivot_1) ? probs_vec[j] : 0; + } + + aggregate_gt_pivot_0 += + BlockReduce(temp_storage.block_prim.reduce) + .Sum(probs_gt_pivot_0); + if(tx == 0) + { + temp_storage.block_aggregate.value = aggregate_gt_pivot_0; + } + __syncthreads(); + aggregate_gt_pivot_0 = temp_storage.block_aggregate.value; + + aggregate_gt_pivot_1 += + BlockReduce(temp_storage.block_prim.reduce) + .Sum(probs_gt_pivot_1); + if(tx == 0) + { + temp_storage.block_aggregate.value = aggregate_gt_pivot_1; + } + __syncthreads(); + aggregate_gt_pivot_1 = temp_storage.block_aggregate.value; + } + if(aggregate_gt_pivot_0 < top_p) + { + // case 1: pivot_0 accepted + break; + } + if(aggregate_gt_pivot_1 < top_p) + { + // case 2: pivot_0 rejected, pivot_1 accepted + low = pivot_0; + high = pivot_1; + q = aggregate_gt_pivot_0; + } + else + { + // case 3: pivot_0 rejected, pivot_1 rejected + low = pivot_1; + q = aggregate_gt_pivot_1; + } + } while(low < high); + __syncthreads(); + if(tx == 0) + { + output[bx] = sampled_id; + } +} + +template +__global__ void TopKTopPSamplingFromProbKernel(DType* probs, + IdType* top_k_arr, + float* top_p_arr, + IdType* output, + IdType* indices, + IdType top_k_val, + float top_p_val, + uint32_t d, + uint64_t philox_seed, + uint64_t philox_offset) +{ + const uint32_t batch_size = gridDim.x; + const uint32_t bx = blockIdx.x, tx = threadIdx.x; + hiprandStatePhilox4_32_10_t state; + hiprand_init(philox_seed, bx, philox_offset, &state); + const uint32_t row_idx = indices == nullptr ? bx : indices[bx]; + const uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[row_idx]; + const float p = top_p_arr == nullptr ? 
top_p_val : top_p_arr[row_idx]; + + extern __shared__ __align__( + alignof(SamplingTempStorage)) + uint8_t smem_sampling[]; + auto& temp_storage = + reinterpret_cast&>( + smem_sampling); + + vec_t probs_vec; + float aggregate; + float q = 1; + double low = 0, high = 1.f; + int sampled_id; + do + { + temp_storage.sampled_id = d; + __syncthreads(); + float u = hiprand_uniform(&state) * q; + aggregate = 0; +#pragma unroll 2 + for(uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) + { + probs_vec.fill(0); + if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) + { + probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + } + + DeviceSamplingFromProb( + i, d, [&](float x) { return x > low; }, u, probs_vec, aggregate, &temp_storage); + if(aggregate > u) + { + break; + } + } + __syncthreads(); + sampled_id = temp_storage.sampled_id; + if(sampled_id == d) + { + // NOTE(Zihao): this would happen when u is very close to 1 + // and the sum of probabilities is smaller than u + // In this case, we use the last valid index as the sampled id + sampled_id = temp_storage.last_valid_id; + } + double pivot_0 = probs[row_idx * d + sampled_id]; + double pivot_1 = (pivot_0 + high) / 2; + + ValueCount aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0}; +#pragma unroll 2 + for(uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) + { + probs_vec.fill(0); + if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) + { + probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + } + + ValueCount probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE]; +#pragma unroll + for(uint32_t j = 0; j < VEC_SIZE; ++j) + { + probs_gt_pivot_0[j] = { + (probs_vec[j] > pivot_0) ? probs_vec[j] : 0, + (probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; + probs_gt_pivot_1[j] = { + (probs_vec[j] > pivot_1) ? 
probs_vec[j] : 0, + (probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; + } + + aggregate_gt_pivot_0 += BlockReduce, BLOCK_THREADS>( + temp_storage.block_prim.reduce_value_count) + .Sum(probs_gt_pivot_0); + if(tx == 0) + { + temp_storage.block_aggregate.pair = aggregate_gt_pivot_0; + } + __syncthreads(); + aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair; + + aggregate_gt_pivot_1 += BlockReduce, BLOCK_THREADS>( + temp_storage.block_prim.reduce_value_count) + .Sum(probs_gt_pivot_1); + if(tx == 0) + { + temp_storage.block_aggregate.pair = aggregate_gt_pivot_1; + } + __syncthreads(); + aggregate_gt_pivot_1 = temp_storage.block_aggregate.pair; + } + if(aggregate_gt_pivot_0.count < k && aggregate_gt_pivot_0.value < p) + { + // case 1: pivot_0 accepted + break; + } + if(aggregate_gt_pivot_1.count < k && aggregate_gt_pivot_1.value < p) + { + // case 2: pivot_0 rejected, pivot_1 accepted + low = pivot_0; + high = pivot_1; + q = aggregate_gt_pivot_0.value; + } + else + { + // case 3: pivot_0 rejected, pivot_1 rejected + low = pivot_1; + q = aggregate_gt_pivot_1.value; + } + } while(low < high); + __syncthreads(); + if(tx == 0) + { + output[bx] = sampled_id; + } +} + +template +struct RenormTempStorage +{ + union + { + typename BlockReduce::TempStorage reduce; + typename BlockReduce::TempStorage reduce_int; + typename BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage + reduce_value_count; + } block_prim; + struct + { + float max_val; + float min_val; + union + { + struct + { + float values[2]; + }; + struct + { + int counts[2]; + }; + struct + { + ValueCount pairs[2]; + }; + } block_aggregate; + }; +}; + +template +__global__ void TopKRenormProbKernel( + DType* probs, DType* renormed_prob, IdType* top_k_arr, uint32_t top_k_val, uint32_t d) +{ + const uint32_t bx = blockIdx.x, tx = threadIdx.x; + const uint32_t row_idx = bx; + uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx]; + double pivot = -infinity(), normalizer = 1; + vec_t probs_vec; + if(k < d) + { + extern __shared__ __align__(alignof(RenormTempStorage)) + uint8_t smem_renorm[]; + auto& temp_storage = + reinterpret_cast&>(smem_renorm); + temp_storage.max_val = 0; + + float max_val = GetMaxValue>( + probs, row_idx, d, temp_storage); + + double low = 0, high = max_val; + float min_gt_low, max_le_high; + float sum_low = 1; + // f(x) = len(nonzero(probs > x)), f(x) is non-increasing + // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high} + // loop invariant: + // - f(low) >= k, f(high) < k + // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high) + // stopping condition: min_gt_low == max_le_high + // - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k + do + { + double pivot_0 = (high + 2 * low) / 3; + double pivot_1 = (2 * high + low) / 3; + + ValueCount aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0}; + min_gt_low = high; + max_le_high = low; +#pragma unroll 2 + for(uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) + { + probs_vec.fill(0); + if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) + { + probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + + tx * VEC_SIZE); + } + ValueCount probs_gt_pivot_0_pair[VEC_SIZE], probs_gt_pivot_1_pair[VEC_SIZE]; +#pragma unroll + for(uint32_t j = 0; j < VEC_SIZE; ++j) + { + probs_gt_pivot_0_pair[j] = { + (probs_vec[j] > pivot_0) ? 
probs_vec[j] : 0, + (probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; + probs_gt_pivot_1_pair[j] = { + (probs_vec[j] > pivot_1) ? probs_vec[j] : 0, + (probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; + + if(probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) + { + min_gt_low = min(min_gt_low, probs_vec[j]); + } + if(probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) + { + max_le_high = max(max_le_high, probs_vec[j]); + } + } + + aggregate_gt_pivot_0 += + BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( + temp_storage.block_prim.reduce_value_count) + .Sum(probs_gt_pivot_0_pair); + __syncthreads(); + + aggregate_gt_pivot_1 += + BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( + temp_storage.block_prim.reduce_value_count) + .Sum(probs_gt_pivot_1_pair); + __syncthreads(); + } + min_gt_low = + BlockReduce(temp_storage.block_prim.reduce) + .Reduce(min_gt_low, hipcub::Min()); + __syncthreads(); + max_le_high = + BlockReduce(temp_storage.block_prim.reduce) + .Reduce(max_le_high, hipcub::Max()); + if(tx == 0) + { + temp_storage.block_aggregate.pairs[0] = aggregate_gt_pivot_0; + temp_storage.block_aggregate.pairs[1] = aggregate_gt_pivot_1; + temp_storage.min_val = min_gt_low; + temp_storage.max_val = max_le_high; + } + __syncthreads(); + aggregate_gt_pivot_0 = temp_storage.block_aggregate.pairs[0]; + aggregate_gt_pivot_1 = temp_storage.block_aggregate.pairs[1]; + min_gt_low = temp_storage.min_val; + max_le_high = temp_storage.max_val; + + if(aggregate_gt_pivot_1.count >= k) + { + low = pivot_1; + sum_low = float(aggregate_gt_pivot_1.value); + } + else if(aggregate_gt_pivot_0.count >= k) + { + low = pivot_0; + high = min(pivot_1, max_le_high); + sum_low = float(aggregate_gt_pivot_0.value); + } + else + { + high = min(pivot_0, max_le_high); + } + } while(min_gt_low != max_le_high); + + normalizer = __frcp_rn(max(sum_low, 1e-8)); + pivot = low; + } + + // normalize +#pragma unroll 2 + for(uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) + { + probs_vec.fill(0); + if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) + { + probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); + } +#pragma unroll + for(uint32_t j = 0; j < VEC_SIZE; ++j) + { + probs_vec[j] = (probs_vec[j] > pivot) ? probs_vec[j] * normalizer : 0; + } + if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) + { + probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + + tx * VEC_SIZE); + } + } +} + +} // namespace sampling + +} // namespace aiter \ No newline at end of file diff --git a/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja b/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja new file mode 100644 index 0000000000..f7d0261f9c --- /dev/null +++ b/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja @@ -0,0 +1,43 @@ +/* + * Copyright (C) 2024-2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + +#include "sampling.cuh" + + +#define FUNCTION_DEFINE \ + void {{func_name}}(void* probs_ptr, \ + void* renormed_probs_ptr, \ + void* top_k_arr_ptr, \ + int batch_size, \ + int top_k_val, \ + void* stream) + +extern "C" { +FUNCTION_DEFINE; +} + +FUNCTION_DEFINE +{ + constexpr uint32_t vec_size = std::gcd(16 / sizeof(float), {{d}}); + + const uint32_t smem_size = sizeof(aiter::sampling::RenormTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(aiter::sampling::BLOCK_THREADS); + auto kernel = aiter::sampling::TopKRenormProbKernel; + hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); + kernel<<(stream)>>>(reinterpret_cast(probs_ptr), reinterpret_cast(renormed_probs_ptr), reinterpret_cast(top_k_arr_ptr), top_k_val, {{d}}); +} \ No newline at end of file diff --git a/csrc/cpp_itfs/sampling/top_k_renorm_probs.py b/csrc/cpp_itfs/sampling/top_k_renorm_probs.py new file mode 100644 index 0000000000..cfc816798f --- /dev/null +++ b/csrc/cpp_itfs/sampling/top_k_renorm_probs.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + + +from jinja2 import Template +from csrc.cpp_itfs.utils import compile_template_op, AITER_CORE_DIR + + +MD_NAME = "top_k_renorm_probs" + +with open( + f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja", + "r", +) as f: + src_template = Template(f.read()) + + +def compile( + d: int, + folder: str = None, +): + return compile_template_op( + src_template, + MD_NAME, + [ + f"{AITER_CORE_DIR}/csrc/cpp_itfs/utils.h", + f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/sampling.cuh", + f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/vec_dtypes.cuh", + ], + d=d, + folder=folder, + ) + + +def top_k_renorm_probs( + probs, + maybe_top_k_arr, + top_k_val, +): + import torch + from csrc.cpp_itfs.torch_utils import torch_to_c_types + + probs = probs.float() + maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None + top_k_val = int(top_k_val) + + batch_size = probs.size(0) + vocab_size = probs.size(1) + + renorm_probs = torch.empty_like(probs) + + func = compile(vocab_size) + ( + probs_ptr, + renorm_probs_ptr, + top_k_arr_ptr, + top_k_val, + batch_size, + stream, + ) = torch_to_c_types( + probs, + renorm_probs, + maybe_top_k_arr, + top_k_val, + batch_size, + torch.cuda.current_stream(), + ) + func( + probs_ptr, + renorm_probs_ptr, + top_k_arr_ptr, + batch_size, + top_k_val, + stream, + ) + return renorm_probs + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--d", type=int, required=True) + parser.add_argument("--folder", type=str, default=None) + args = parser.parse_args() + compile(**vars(args)) diff --git a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja new file mode 100644 index 0000000000..301b5c9790 --- /dev/null +++ b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2024-2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#include "sampling.cuh" + + +#define FUNCTION_DEFINE \ + void {{func_name}}(void* probs_ptr, \ + void* output_ptr, \ + void* indices_ptr, \ + void* top_k_arr_ptr, \ + void* top_p_arr_ptr, \ + int batch_size, \ + int top_k_val, \ + float top_p_val, \ + int philox_seed, \ + int philox_offset, \ + void* stream) + +extern "C" { +FUNCTION_DEFINE; +} + +FUNCTION_DEFINE +{ + constexpr uint32_t vec_size = std::gcd(16 / sizeof(float), {{d}}); + + const uint32_t smem_size = sizeof(aiter::sampling::SamplingTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(aiter::sampling::BLOCK_THREADS); + auto kernel = aiter::sampling::TopKTopPSamplingFromProbKernel; + hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); + kernel<<(stream)>>>(reinterpret_cast(probs_ptr), reinterpret_cast(top_k_arr_ptr), reinterpret_cast(top_p_arr_ptr), reinterpret_cast(output_ptr), reinterpret_cast(indices_ptr), top_k_val, top_p_val, {{d}}, static_cast(philox_seed), static_cast(philox_offset)); +} \ No newline at end of file diff --git a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py new file mode 100644 index 0000000000..48fbe6e6f3 --- /dev/null +++ b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
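+# For intuition only: a rough eager-mode reference of the joint filter that the underlying kernel
+# realizes via rejection sampling (a token is drawn only if it lies in both the top-k set and the
+# top-p nucleus). This is a hedged sketch, not the GPU path; `_reference_top_k_top_p_mask` is a
+# hypothetical helper and tie-breaking may differ from the kernel.
+#
+#   def _reference_top_k_top_p_mask(probs, k, p):            # probs: [batch, vocab], rows sum to 1
+#       sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
+#       cdf = torch.cumsum(sorted_probs, dim=-1)
+#       in_top_k = torch.arange(probs.size(-1), device=probs.device) < k
+#       in_top_p = (cdf - sorted_probs) < p                   # mass strictly above the token < p
+#       keep = torch.zeros_like(probs, dtype=torch.bool)
+#       keep.scatter_(-1, sorted_idx, in_top_k & in_top_p)
+#       return keep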
+ + +from jinja2 import Template +from csrc.cpp_itfs.utils import compile_template_op, AITER_CORE_DIR, str_to_bool + + +MD_NAME = "top_k_top_p_sampling_from_probs" + +with open( + f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja", + "r", +) as f: + src_template = Template(f.read()) + + +def compile( + d: int, + deterministic: bool, + folder: str = None, +): + return compile_template_op( + src_template, + MD_NAME, + [ + f"{AITER_CORE_DIR}/csrc/cpp_itfs/utils.h", + f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/sampling.cuh", + f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/vec_dtypes.cuh", + ], + d=d, + deterministic=deterministic, + folder=folder, + ) + + +def top_k_top_p_sampling_from_probs( + probs, + indices, + maybe_top_k_arr, + top_k_val, + maybe_top_p_arr, + top_p_val, + deterministic=False, + generator=None, +): + import torch + from csrc.cpp_itfs.torch_utils import torch_to_c_types + + if generator is None: + generator = torch.cuda.default_generators[probs.device.index] + probs = probs.float() + top_p_val = float(top_p_val) + top_k_val = int(top_k_val) + maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None + maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None + + batch_size = indices.size(0) if indices is not None else probs.size(0) + vocab_size = probs.size(1) + philox_offset = generator.get_offset() + philox_seed = generator.seed() + + output = torch.empty(batch_size, dtype=torch.int32, device=probs.device) + + func = compile(vocab_size, deterministic) + ( + probs_ptr, + output_ptr, + indices_ptr, + top_k_arr_ptr, + top_p_arr_ptr, + top_k_val, + top_p_val, + batch_size, + philox_seed, + philox_offset, + stream, + ) = torch_to_c_types( + probs, + output, + indices, + maybe_top_k_arr, + maybe_top_p_arr, + top_k_val, + top_p_val, + batch_size, + philox_seed, + philox_offset, + torch.cuda.current_stream(), + ) + func( + probs_ptr, + output_ptr, + indices_ptr, + top_k_arr_ptr, + top_p_arr_ptr, + batch_size, + top_k_val, + top_p_val, + philox_seed, + philox_offset, + stream, + ) + return output + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--d", type=int, required=True) + parser.add_argument("--deterministic", type=str_to_bool, required=True) + parser.add_argument("--folder", type=str, default=None) + args = parser.parse_args() + compile(**vars(args)) diff --git a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja new file mode 100644 index 0000000000..99c23b44e7 --- /dev/null +++ b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja @@ -0,0 +1,46 @@ +/* + * Copyright (C) 2024-2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "sampling.cuh" + + +#define FUNCTION_DEFINE \ + void {{func_name}}(void* probs_ptr, \ + void* output_ptr, \ + void* indices_ptr, \ + void* top_p_arr_ptr, \ + int batch_size, \ + float top_p_val, \ + int philox_seed, \ + int philox_offset, \ + void* stream) + +extern "C" { +FUNCTION_DEFINE; +} + +FUNCTION_DEFINE +{ + constexpr uint32_t vec_size = std::gcd(16 / sizeof(float), {{d}}); + + const uint32_t smem_size = sizeof(aiter::sampling::SamplingTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(aiter::sampling::BLOCK_THREADS); + auto kernel = aiter::sampling::TopPSamplingFromProbKernel; + hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); + kernel<<(stream)>>>(reinterpret_cast(probs_ptr), reinterpret_cast(output_ptr), reinterpret_cast(indices_ptr), reinterpret_cast(top_p_arr_ptr), top_p_val, {{d}}, static_cast(philox_seed), static_cast(philox_offset)); +} \ No newline at end of file diff --git a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py new file mode 100644 index 0000000000..7e1500b231 --- /dev/null +++ b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py @@ -0,0 +1,105 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + + +from jinja2 import Template +from csrc.cpp_itfs.utils import compile_template_op, AITER_CORE_DIR, str_to_bool + + +MD_NAME = "top_p_sampling_from_probs" + +with open( + f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja", + "r", +) as f: + src_template = Template(f.read()) + + +def compile( + d: int, + deterministic: bool, + folder: str = None, +): + return compile_template_op( + src_template, + MD_NAME, + [ + f"{AITER_CORE_DIR}/csrc/cpp_itfs/utils.h", + f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/sampling.cuh", + f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/vec_dtypes.cuh", + ], + d=d, + deterministic=deterministic, + folder=folder, + ) + + +def top_p_sampling_from_probs( + probs, + indices, + maybe_top_p_arr, + top_p_val, + deterministic: bool = False, + generator=None, +): + import torch + from csrc.cpp_itfs.torch_utils import torch_to_c_types + + if generator is None: + generator = torch.cuda.default_generators[probs.device.index] + philox_offset = generator.get_offset() + philox_seed = generator.seed() + + probs = probs.float() + maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None + top_p_val = float(top_p_val) + + batch_size = probs.size(0) + vocab_size = probs.size(1) + + samples = torch.empty(batch_size, dtype=torch.int32, device=probs.device) + func = compile(vocab_size, deterministic) + ( + probs_ptr, + samples_ptr, + indices_ptr, + top_p_arr_ptr, + top_p_val, + batch_size, + philox_seed, + philox_offset, + stream, + ) = torch_to_c_types( + probs, + samples, + indices, + maybe_top_p_arr, + top_p_val, + batch_size, + philox_seed, + philox_offset, + torch.cuda.current_stream(), + ) + func( + probs_ptr, + samples_ptr, + indices_ptr, + top_p_arr_ptr, + batch_size, + top_p_val, + philox_seed, + philox_offset, + stream, + ) + return samples + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--d", type=int, required=True) + parser.add_argument("--deterministic", type=str_to_bool, required=True) + parser.add_argument("--folder", type=str, default=None) + args = parser.parse_args() + compile(**vars(args)) diff --git a/csrc/cpp_itfs/sampling/vec_dtypes.cuh 
b/csrc/cpp_itfs/sampling/vec_dtypes.cuh new file mode 100644 index 0000000000..468d7ae8bb --- /dev/null +++ b/csrc/cpp_itfs/sampling/vec_dtypes.cuh @@ -0,0 +1,2160 @@ +/* + * Copyright (C) 2023-2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +#if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) +/* +Hacky workaround for the error below: + + /home/git_repos/glen-amd/flashinfer/include/flashinfer/attention/../vec_dtypes_hip.cuh:200:38: +error: use of undeclared identifier '__float2bfloat162_rn'; did you mean '__float22bfloat162_rn'? + 200 | const __hip_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); | ^~~~~~~~~~~~~~~~~~~~ | __float22bfloat162_rn + /opt/rocm-6.3.1/lib/llvm/bin/../../../include/hip/amd_detail/amd_hip_bf16.h:574:45: note: +'__float22bfloat162_rn' declared here 574 | __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 +__float22bfloat162_rn(const float2 a) { +*/ +__HOST_DEVICE__ inline __hip_bfloat162 __float2bfloat162_rn(const float a) +{ + return __hip_bfloat162{__float2bfloat16(a), __float2bfloat16(a)}; +} + +inline __attribute__((always_inline)) __device__ __hip_bfloat162 +make_bfloat162(const __hip_bfloat16 x, const __hip_bfloat16 y) +{ + __hip_bfloat162 t; + t.x = x; + t.y = y; + return t; +} +#endif + +namespace aiter { + +#define FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + +/******************* vec_t type cast *******************/ + +template +struct vec_cast +{ + template + inline __attribute__((always_inline)) __device__ static void cast(dst_t* dst, const src_t* src) + { +#pragma unroll + for(size_t i = 0; i < vec_size; ++i) + { + dst[i] = (dst_t)src[i]; + } + } +}; + +template <> +struct vec_cast +{ + template + inline __attribute__((always_inline)) __device__ static void cast(float* dst, const half* src) + { + if constexpr(vec_size == 1) + { + // dst[0] = (float)src[0]; + dst[0] = __half2float(src[0]); + } + else + { +#pragma unroll + for(size_t i = 0; i < vec_size / 2; ++i) + { + ((float2*)dst)[i] = __half22float2(((half2*)src)[i]); + } + } + } +}; + +template <> +struct vec_cast +{ + template + inline __attribute__((always_inline)) __device__ static void cast(half* dst, const float* src) + { + if constexpr(vec_size == 1) + { + dst[0] = __float2half(src[0]); + } + else + { +#pragma unroll + for(size_t i = 0; i < vec_size / 2; ++i) + { + ((half2*)dst)[i] = __float22half2_rn(((float2*)src)[i]); + } + } + } +}; + +template +constexpr inline __attribute__((always_inline)) __device__ int get_exponent_bits() +{ + if constexpr(std::is_same_v) + { + return 4; + } + else if constexpr(std::is_same_v) + { + return 5; + } + else if constexpr(std::is_same_v) + { + return 5; + } + else if constexpr(std::is_same_v) + { + return 8; + } +} + +template +constexpr inline __attribute__((always_inline)) __device__ int get_mantissa_bits() +{ + if constexpr(std::is_same_v) + { + return 3; + } + else if 
constexpr(std::is_same_v) + { + return 2; + } + else if constexpr(std::is_same_v) + { + return 11; + } + else if constexpr(std::is_same_v) + { + return 7; + } +} + +/*! + * \brief Fallback to software fast dequant implementation if hardware dequantization is not + * available. + * \note Inspired by Marlin's fast dequantization, but here we don't have to permute + * weights order. + * \ref + * https://github.com/vllm-project/vllm/blob/6dffa4b0a6120159ef2fe44d695a46817aff65bc/csrc/quantization/fp8/fp8_marlin.cu#L120 + */ +template +__device__ void fast_dequant_f8f16x4(uint32_t* input, uint2* output) +{ + uint32_t q = *input; + if constexpr(std::is_same_v && std::is_same_v) + { + output->x = __byte_perm(0U, q, 0x5140); + output->y = __byte_perm(0U, q, 0x7362); + } + else + { + constexpr int FP8_EXPONENT = get_exponent_bits(); + constexpr int FP8_MANTISSA = get_mantissa_bits(); + constexpr int FP16_EXPONENT = get_exponent_bits(); + + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; + // Calculate MASK for extracting mantissa and exponent + // XXX: duplicate defs of `MASK1` and `MASK2`, + // in the HIP file "include/hip/amd_detail/amd_device_functions.h". + constexpr int MASK1_orig = 0x80000000; + constexpr int MASK2_orig = MASK1_orig >> (FP8_EXPONENT + FP8_MANTISSA); + constexpr int MASK3 = MASK2_orig & 0x7fffffff; + constexpr int MASK = MASK3 | (MASK3 >> 16); + q = __byte_perm(q, q, 0x1302); + + // Extract and shift FP8 values to FP16 format + uint32_t Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + uint32_t Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); + + constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + // Construct and apply exponent bias + if constexpr(std::is_same_v) + { + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + *(half2*)&(output->x) = __hmul2(*reinterpret_cast(&Out1), bias_reg); + *(half2*)&(output->y) = __hmul2(*reinterpret_cast(&Out2), bias_reg); + } + else + { + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const __hip_bfloat162 bias_reg = + __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + // Convert to bfloat162 and apply bias + *(__hip_bfloat162*)&(output->x) = + __hmul2(*reinterpret_cast(&Out1), bias_reg); + *(__hip_bfloat162*)&(output->y) = + __hmul2(*reinterpret_cast(&Out2), bias_reg); + } + } +} + +template <> +struct vec_cast<__hip_bfloat16, __hip_fp8_e4m3_fnuz> +{ + template + inline __attribute__((always_inline)) __device__ static void + cast(__hip_bfloat16* dst, const __hip_fp8_e4m3_fnuz* src) + { + if constexpr(vec_size == 1) + { + dst[0] = __hip_bfloat16(src[0]); + } + else if constexpr(vec_size == 2) + { + dst[0] = __hip_bfloat16(src[0]); + dst[1] = __hip_bfloat16(src[1]); + } + else + { + static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4"); +#pragma unroll + for(uint32_t i = 0; i < vec_size / 4; ++i) + { + fast_dequant_f8f16x4<__hip_fp8_e4m3_fnuz, __hip_bfloat16>((uint32_t*)&src[i * 4], + (uint2*)&dst[i * 4]); + } + } + } +}; + +template <> +struct vec_cast<__hip_bfloat16, __hip_fp8_e5m2_fnuz> +{ + template + inline __attribute__((always_inline)) __device__ static void + cast(__hip_bfloat16* dst, const __hip_fp8_e5m2_fnuz* src) + { + if constexpr(vec_size == 1) + { + dst[0] = __hip_bfloat16(src[0]); + } + else if constexpr(vec_size == 2) + { + dst[0] = __hip_bfloat16(src[0]); + dst[1] = __hip_bfloat16(src[1]); + } + else + { + static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4"); 
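+            // Software fast-dequant path: each iteration unpacks four packed e5m2 bytes,
+            // shifts their sign/exponent/mantissa fields into bf16 bit positions, and fixes up
+            // the exponent bias with one packed multiply (see fast_dequant_f8f16x4 above).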
+#pragma unroll + for(uint32_t i = 0; i < vec_size / 4; ++i) + { + fast_dequant_f8f16x4<__hip_fp8_e5m2_fnuz, __hip_bfloat16>((uint32_t*)&src[i * 4], + (uint2*)&dst[i * 4]); + } + } + } +}; + +#if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) +// Function to convert half-precision to e4m3 +__device__ uint8_t convert_f32_to_e4m3(float val) +{ + // Define the range of e4m3 + // 1. Minimum representable value for e4m3 + // 2. Binary 1000.000 in e4m3 + // 3. FLT_MIN is not suitable for e4m3 because e4m3 has a much smaller dynamic range. + float min_e4m3 = -8.0f; + // 1. Maximum representable value for e4m3 + // 2. Binary 0111.111 in e4m3 + // FLT_MAX far exceeds the maximum value representable in e4m3. + float max_e4m3 = 7.875f; + + // Saturate the value to the e4m3 range + val = fminf(fmaxf(val, min_e4m3), max_e4m3); + + // Perform conversion + // Decompose into mantissa and exponent + int exp; + float mantissa = frexpf(val, &exp); + + // Encode sign bit + uint8_t sign = (mantissa < 0) ? 0x80 : 0x00; + + // Normalize mantissa and encode exponent + mantissa = fabsf(mantissa) * 16.0f; // Scale mantissa for e4m3's 3-bit precision + uint8_t exponent = static_cast(exp + 7); // Bias of 7 for e4m3 + + // Quantize mantissa + // Apply round-to-nearest-even to the mantissa + uint8_t quant_mantissa = static_cast(roundf(mantissa)) & 0x07; + + // Combine into 8 bits: [sign][exponent][mantissa] + return sign | (exponent << 3) | quant_mantissa; +} + +__device__ __half2 convert_uint32_to_half2(uint32_t input) +{ + // Extract the low and high 16 bits + uint16_t low_val = input & 0xFFFF; + uint16_t high_val = (input >> 16) & 0xFFFF; + // Convert to __half + __half low_half = __float2half(static_cast(low_val)); + __half high_half = __float2half(static_cast(high_val)); + // Pack into __half2 + return __halves2half2(low_half, high_half); +} + +// Convert f16x2 (__half2) to e4m3x2 (packed 16-bit) +__device__ uint16_t convert_f16x2_to_e4m3x2(__half2 x) +{ + float f32_0 = __half2float(__low2half(x)); + float f32_1 = __half2float(__high2half(x)); + uint8_t e4m3_0 = convert_f32_to_e4m3(f32_0); + uint8_t e4m3_1 = convert_f32_to_e4m3(f32_1); + return (static_cast(e4m3_1) << 8) | e4m3_0; +} +#endif + +template <> +struct vec_cast<__hip_fp8_e4m3_fnuz, half> +{ + template + inline __attribute__((always_inline)) __device__ static void cast(__hip_fp8_e4m3_fnuz* dst, + const half* src) + { +#ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + if constexpr(vec_size == 1) + { + dst[0] = __hip_fp8_e4m3_fnuz(src[0]); + } + else + { +#pragma unroll + for(size_t i = 0; i < vec_size / 2; ++i) + { + uint16_t y; + uint32_t x = *(uint32_t*)&src[i * 2]; + __half2 x_h2 = convert_uint32_to_half2(x); + y = convert_f16x2_to_e4m3x2(x_h2); + *(uint16_t*)&dst[i * 2] = y; + } + } +#else +#pragma unroll + for(size_t i = 0; i < vec_size; ++i) + { + dst[i] = __hip_fp8_e4m3_fnuz(src[i]); + } +#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + } +}; + +#if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) +__device__ uint16_t convert_f16x2_to_e5m2x2(uint32_t x) +{ + // Unpack the two 16-bit half-precision floats from the input + // Extract lower 16 bits + __half h1 = __ushort_as_half(x & 0xFFFF); + // Extract upper 16 bits + __half h2 = __ushort_as_half((x >> 16) & 0xFFFF); + + // Define the range of e5m2 + // Minimum representable value for e5m2 + const float min_e5m2 = -8.0f; + // Maximum representable value for e5m2 + const float max_e5m2 = 7.75f; + + // Helper 
lambda for conversion + auto f32_to_e5m2 = [min_e5m2, max_e5m2](float val) -> uint8_t { + // Saturate the val + val = fminf(fmaxf(val, min_e5m2), max_e5m2); + + // Decompose into mantissa and exponent + int exp; + float mantissa = frexpf(val, &exp); + + // Encode sign bit + uint8_t sign = (mantissa < 0) ? 0x10 : 0x00; // Sign in bit 4 + mantissa = fabsf(mantissa); + + // Normalize mantissa and encode exponent + mantissa *= 4.0f; // Scale for 2-bit mantissa + uint8_t exponent = static_cast(exp + 7); // Apply bias for e5m2 + + // Apply round-to-nearest-even + uint8_t quant_mantissa = static_cast(roundf(mantissa)) & 0x03; + + // Combine into 5 bits: [sign][exponent][mantissa] + return sign | (exponent << 2) | quant_mantissa; + }; + + // Convert the two __half values to e5m2 + uint8_t e5m2_1 = f32_to_e5m2(__half2float(h1)); + uint8_t e5m2_2 = f32_to_e5m2(__half2float(h2)); + + // Pack the two e5m2 values into a single 16-bit output + return (e5m2_2 << 8) | e5m2_1; +} +#endif + +template <> +struct vec_cast<__hip_fp8_e5m2_fnuz, half> +{ + template + inline __attribute__((always_inline)) __device__ static void cast(__hip_fp8_e5m2_fnuz* dst, + const half* src) + { +#ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + if constexpr(vec_size == 1) + { + dst[0] = __hip_fp8_e5m2_fnuz(src[0]); + } + else + { +#pragma unroll + for(size_t i = 0; i < vec_size / 2; ++i) + { + uint16_t y; + uint32_t x = *(uint32_t*)&src[i * 2]; + y = convert_f16x2_to_e5m2x2(x); + *(uint16_t*)&dst[i * 2] = y; + } + } +#else +#pragma unroll + for(size_t i = 0; i < vec_size; ++i) + { + dst[i] = __hip_fp8_e5m2_fnuz(src[i]); + } +#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + } +}; + +#if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) +__device__ uint32_t convert_e4m3x2_to_f16x2(uint16_t x) +{ + // Extract two e4m3 values from the 16-bit input + uint8_t e4m3_1 = x & 0xFF; // Lower 8 bits + uint8_t e4m3_2 = (x >> 8) & 0xFF; // Upper 8 bits + + // Decode e4m3 to float + auto e4m3_to_f32 = [](uint8_t e4m3) -> float { + // Extract sign, exponent, and mantissa + int sign = (e4m3 & 0x80) ? 
-1 : 1; + int exponent = ((e4m3 >> 3) & 0x0F) - 7; // 4-bit exponent with bias 7 + int mantissa = e4m3 & 0x07; // 3-bit mantissa + + // Handle special case: zero + if(exponent == -7 && mantissa == 0) + { + return 0.0f; + } + + // Convert to float + float f32_val = sign * ldexpf(1.0f + mantissa / 8.0f, exponent); + return f32_val; + }; + + float f1 = e4m3_to_f32(e4m3_1); + float f2 = e4m3_to_f32(e4m3_2); + + // Convert float to IEEE f16 + __half h1 = __float2half_rn(f1); + __half h2 = __float2half_rn(f2); + + // Pack the two f16 values into a single uint32_t + uint32_t f16x2 = (__half_as_ushort(h2) << 16) | __half_as_ushort(h1); + return f16x2; +} +#endif + +template <> +struct vec_cast +{ + template + inline __attribute__((always_inline)) __device__ static void + cast(half* dst, const __hip_fp8_e4m3_fnuz* src) + { +#ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + if constexpr(vec_size == 1) + { + dst[0] = half(src[0]); + } + else + { +#pragma unroll + for(size_t i = 0; i < vec_size / 2; ++i) + { + uint32_t y; + uint16_t x = *(uint16_t*)&src[i * 2]; + y = convert_e4m3x2_to_f16x2(x); + *(uint32_t*)&dst[i * 2] = y; + } + } +#else + if constexpr(vec_size == 1) + { + dst[0] = half(src[0]); + } + else if constexpr(vec_size == 2) + { + dst[0] = half(src[0]); + dst[1] = half(src[1]); + } + else + { + static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4"); +#pragma unroll + for(uint32_t i = 0; i < vec_size / 4; ++i) + { + fast_dequant_f8f16x4<__hip_fp8_e4m3_fnuz, half>((uint32_t*)&src[i * 4], + (uint2*)&dst[i * 4]); + } + } +#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + } +}; + +#if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) +__device__ uint32_t convert_e5m2x2_to_f16x2(uint16_t x) +{ + // Extract two e5m2 values from the 16-bit input + uint8_t e5m2_1 = x & 0xFF; // Lower 8 bits + uint8_t e5m2_2 = (x >> 8) & 0xFF; // Upper 8 bits + + // Decode e5m2 to float + auto e5m2_to_f32 = [](uint8_t e5m2) -> float { + // Extract sign, exponent, and mantissa + int sign = (e5m2 & 0x80) ? 
-1 : 1; // Sign bit + int exponent = ((e5m2 >> 2) & 0x1F) - 15; // 5-bit exponent with bias 15 + int mantissa = e5m2 & 0x03; // 2-bit mantissa + + // Handle special case: zero + if(exponent == -15 && mantissa == 0) + { + return 0.0f; + } + + // Convert to float + float value = sign * ldexpf(1.0f + mantissa / 4.0f, exponent); + return value; + }; + + float f1 = e5m2_to_f32(e5m2_1); + float f2 = e5m2_to_f32(e5m2_2); + + // Convert float to IEEE f16 + __half h1 = __float2half_rn(f1); + __half h2 = __float2half_rn(f2); + + // Pack the two f16 values into a single uint32_t + uint32_t f16x2 = (__half_as_ushort(h2) << 16) | __half_as_ushort(h1); + return f16x2; +} +#endif + +template <> +struct vec_cast +{ + template + inline __attribute__((always_inline)) __device__ static void + cast(half* dst, const __hip_fp8_e5m2_fnuz* src) + { +#ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + if constexpr(vec_size == 1) + { + dst[0] = half(src[0]); + } + else + { +#pragma unroll + for(size_t i = 0; i < vec_size / 2; ++i) + { + uint32_t y; + uint16_t x = *(uint16_t*)&src[i * 2]; + y = convert_e5m2x2_to_f16x2(x); + *(uint32_t*)&dst[i * 2] = y; + } + } +#else + if constexpr(vec_size == 1) + { + dst[0] = half(src[0]); + } + else if constexpr(vec_size == 2) + { + dst[0] = half(src[0]); + dst[1] = half(src[1]); + } + else + { + static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4"); +#pragma unroll + for(uint32_t i = 0; i < vec_size / 4; ++i) + { + fast_dequant_f8f16x4<__hip_fp8_e5m2_fnuz, half>((uint32_t*)&src[i * 4], + (uint2*)&dst[i * 4]); + } + } +#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + } +}; + +template <> +struct vec_cast +{ + template + inline __attribute__((always_inline)) __device__ static void cast(float* dst, + const __hip_bfloat16* src) + { + if constexpr(vec_size == 1) + { + dst[0] = (float)src[0]; + } + else + { +#pragma unroll + for(size_t i = 0; i < vec_size / 2; ++i) + { + ((float2*)dst)[i] = __bfloat1622float2(((__hip_bfloat162*)src)[i]); + } + } + } +}; + +template <> +struct vec_cast<__hip_bfloat16, float> +{ + template + inline __attribute__((always_inline)) __device__ static void cast(__hip_bfloat16* dst, + const float* src) + { + /*if constexpr (vec_size == 1) { + dst[0] = __hip_bfloat16(src[0]); + } else { + #pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((__hip_bfloat162*)dst)[i] = __float22bfloat162_rn(((float2*)src)[i]); + } + }*/ + // fast but unsafe bfloat conversion... 
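+        // The loop below truncates: bf16 is the upper 16 bits of an IEEE-754 float32,
+        // so reading bf[1] of the union (the high half-word on the little-endian
+        // targets this code runs on) keeps the sign, the 8 exponent bits and the top
+        // 7 mantissa bits and simply drops the low 16 mantissa bits.  Unlike the
+        // commented-out __float22bfloat162_rn path above it does not round to
+        // nearest, so results can differ from that path by one bf16 ulp.
+        // An equivalent, endianness-agnostic way to get the same bits would be:
+        //   uint32_t bits;  __builtin_memcpy(&bits, &src[i], sizeof(bits));
+        //   uint16_t bf   = static_cast<uint16_t>(bits >> 16);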
+ union f2bf + { + float f; + __hip_bfloat16 bf[2]; + } _f2bf; +#pragma unroll + for(size_t i = 0; i < vec_size; ++i) + { + _f2bf.f = src[i]; + dst[i] = _f2bf.bf[1]; + } + } +}; + +template +struct vec_t +{ + inline __attribute__((always_inline)) __device__ float_t& operator[](size_t i); + inline __attribute__((always_inline)) __device__ const float_t& operator[](size_t i) const; + inline __attribute__((always_inline)) __device__ void fill(float_t val); + inline __attribute__((always_inline)) __device__ void load(const float_t* ptr); + inline __attribute__((always_inline)) __device__ void store(float_t* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src); + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr); + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const; + inline __attribute__((always_inline)) __device__ static void memcpy(float_t* dst, + const float_t* src); + inline __attribute__((always_inline)) __device__ float_t* ptr(); +}; + +template +inline __attribute__((always_inline)) __device__ void +cast_from_impl(vec_t& dst, const vec_t& src) +{ + vec_cast::template cast( + dst.ptr(), const_cast*>(&src)->ptr()); +} + +template +inline __attribute__((always_inline)) __device__ void +cast_load_impl(vec_t& dst, const src_float_t* src_ptr) +{ + if constexpr(std::is_same_v) + { + dst.load(src_ptr); + } + else + { + vec_t tmp; + tmp.load(src_ptr); + dst.cast_from(tmp); + } +} + +template +inline __attribute__((always_inline)) __device__ void +cast_store_impl(tgt_float_t* dst_ptr, const vec_t& src) +{ + if constexpr(std::is_same_v) + { + src.store(dst_ptr); + } + else + { + vec_t tmp; + tmp.cast_from(src); + tmp.store(dst_ptr); + } +} + +/******************* vec_t<__hip_fp8_e4m3_fnuz> *******************/ + +// __hip_fp8_e4m3_fnuz x 1 +template <> +struct vec_t<__hip_fp8_e4m3_fnuz, 1> +{ + __hip_fp8_e4m3_fnuz data; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz& operator[](size_t i) + { + return ((__hip_fp8_e4m3_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const __hip_fp8_e4m3_fnuz& + operator[](size_t i) const + { + return ((const __hip_fp8_e4m3_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz* ptr() + { + return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); + } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e4m3_fnuz val); + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e4m3_fnuz* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e4m3_fnuz* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + + inline __attribute__((always_inline)) __device__ static void + memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src); +}; + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 1>::fill(__hip_fp8_e4m3_fnuz val) +{ + data = val; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 1>::load(const __hip_fp8_e4m3_fnuz* ptr) +{ + data = *ptr; +} + +inline 
__attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 1>::store(__hip_fp8_e4m3_fnuz* ptr) const +{ + *ptr = data; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 1>::memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src) +{ + *dst = *src; +} + +// __hip_fp8_e4m3_fnuz x 2 +template <> +struct vec_t<__hip_fp8_e4m3_fnuz, 2> +{ + __hip_fp8x2_e4m3_fnuz data; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz& operator[](size_t i) + { + return ((__hip_fp8_e4m3_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const __hip_fp8_e4m3_fnuz& + operator[](size_t i) const + { + return ((const __hip_fp8_e4m3_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz* ptr() + { + return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); + } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e4m3_fnuz val); + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e4m3_fnuz* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e4m3_fnuz* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void + memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src); +}; + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 2>::fill(__hip_fp8_e4m3_fnuz val) +{ + data.__x = (__hip_fp8x2_storage_t(val.__x) << 8) | __hip_fp8x2_storage_t(val.__x); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 2>::load(const __hip_fp8_e4m3_fnuz* ptr) +{ + data = *((__hip_fp8x2_e4m3_fnuz*)ptr); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 2>::store(__hip_fp8_e4m3_fnuz* ptr) const +{ + *((__hip_fp8x2_e4m3_fnuz*)ptr) = data; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 2>::memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src) +{ + *((__hip_fp8x2_e4m3_fnuz*)dst) = *((__hip_fp8x2_e4m3_fnuz*)src); +} + +// __hip_fp8_e4m3_fnuz x 4 + +template <> +struct vec_t<__hip_fp8_e4m3_fnuz, 4> +{ + __hip_fp8x4_e4m3_fnuz data; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz& operator[](size_t i) + { + return ((__hip_fp8_e4m3_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const __hip_fp8_e4m3_fnuz& + operator[](size_t i) const + { + return ((const __hip_fp8_e4m3_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz* ptr() + { + return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); + } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e4m3_fnuz val); + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e4m3_fnuz* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e4m3_fnuz* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + 
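+        // Load vec_size values of type T from ptr and convert them into this vector:
+        // cast_load_impl is a plain load() when T is already __hip_fp8_e4m3_fnuz and a
+        // vec_cast<> elementwise conversion otherwise.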
cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + + inline __attribute__((always_inline)) __device__ static void + memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src); +}; + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 4>::fill(__hip_fp8_e4m3_fnuz val) +{ + data.__x = (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 4>::load(const __hip_fp8_e4m3_fnuz* ptr) +{ + data = *((__hip_fp8x4_e4m3_fnuz*)ptr); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 4>::store(__hip_fp8_e4m3_fnuz* ptr) const +{ + *((__hip_fp8x4_e4m3_fnuz*)ptr) = data; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 4>::memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src) +{ + *((__hip_fp8x4_e4m3_fnuz*)dst) = *((__hip_fp8x4_e4m3_fnuz*)src); +} + +// __hip_fp8_e4m3_fnuz x 8 + +template <> +struct vec_t<__hip_fp8_e4m3_fnuz, 8> +{ + uint2 data; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz& operator[](size_t i) + { + return ((__hip_fp8_e4m3_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const __hip_fp8_e4m3_fnuz& + operator[](size_t i) const + { + return ((const __hip_fp8_e4m3_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz* ptr() + { + return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); + } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e4m3_fnuz val); + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e4m3_fnuz* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e4m3_fnuz* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + + inline __attribute__((always_inline)) __device__ static void + memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src); +}; + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 8>::fill(__hip_fp8_e4m3_fnuz val) +{ + ((__hip_fp8x4_e4m3_fnuz*)(&data.x))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e4m3_fnuz*)(&data.y))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 8>::load(const __hip_fp8_e4m3_fnuz* ptr) +{ + data = *((uint2*)ptr); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 8>::store(__hip_fp8_e4m3_fnuz* ptr) const +{ + *((uint2*)ptr) = data; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 8>::memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src) +{ + *((uint2*)dst) = 
*((uint2*)src); +} + +// __hip_fp8_e4m3_fnuz x 16 or more +template +struct vec_t<__hip_fp8_e4m3_fnuz, vec_size> +{ + uint4 data[vec_size / 16]; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz& operator[](size_t i) + { + return ((__hip_fp8_e4m3_fnuz*)data)[i]; + } + inline __attribute__((always_inline)) __device__ const __hip_fp8_e4m3_fnuz& + operator[](size_t i) const + { + return ((const __hip_fp8_e4m3_fnuz*)data)[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz* ptr() + { + return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); + } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e4m3_fnuz val) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 16; ++i) + { + ((__hip_fp8x4_e4m3_fnuz*)(&(data[i].x)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e4m3_fnuz*)(&(data[i].y)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e4m3_fnuz*)(&(data[i].z)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e4m3_fnuz*)(&(data[i].w)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + } + } + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e4m3_fnuz* ptr) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 16; ++i) + { + data[i] = ((uint4*)ptr)[i]; + } + } + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e4m3_fnuz* ptr) const + { +#pragma unroll + for(size_t i = 0; i < vec_size / 16; ++i) + { + ((uint4*)ptr)[i] = data[i]; + } + } + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + + inline __attribute__((always_inline)) __device__ static void + memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 16; ++i) + { + ((uint4*)dst)[i] = ((uint4*)src)[i]; + } + } +}; + +/******************* vec_t<__hip_fp8_e5m2_fnuz> *******************/ + +// __hip_fp8_e5m2_fnuz x 1 +template <> +struct vec_t<__hip_fp8_e5m2_fnuz, 1> +{ + __hip_fp8_e5m2_fnuz data; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz& operator[](size_t i) + { + return ((__hip_fp8_e5m2_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const __hip_fp8_e5m2_fnuz& + operator[](size_t i) const + { + return ((const __hip_fp8_e5m2_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz* ptr() + { + return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); + } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e5m2_fnuz val); + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e5m2_fnuz* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e5m2_fnuz* ptr) const; + 
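+    // cast_from / cast_load / cast_store delegate to the vec_cast<> specializations
+    // defined earlier in this file (through cast_from_impl, cast_load_impl and
+    // cast_store_impl), so fp8 <-> half/bf16 conversions reuse the fast paths above.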
template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + + inline __attribute__((always_inline)) __device__ static void + memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src); +}; + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 1>::fill(__hip_fp8_e5m2_fnuz val) +{ + data = val; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 1>::load(const __hip_fp8_e5m2_fnuz* ptr) +{ + data = *ptr; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 1>::store(__hip_fp8_e5m2_fnuz* ptr) const +{ + *ptr = data; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 1>::memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src) +{ + *dst = *src; +} + +// __hip_fp8_e5m2_fnuz x 2 +template <> +struct vec_t<__hip_fp8_e5m2_fnuz, 2> +{ + __hip_fp8x2_e5m2_fnuz data; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz& operator[](size_t i) + { + return ((__hip_fp8_e5m2_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const __hip_fp8_e5m2_fnuz& + operator[](size_t i) const + { + return ((const __hip_fp8_e5m2_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz* ptr() + { + return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); + } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e5m2_fnuz val); + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e5m2_fnuz* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e5m2_fnuz* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + + inline __attribute__((always_inline)) __device__ static void + memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src); +}; + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 2>::fill(__hip_fp8_e5m2_fnuz val) +{ + data.__x = (__hip_fp8x2_storage_t(val.__x) << 8) | __hip_fp8x2_storage_t(val.__x); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 2>::load(const __hip_fp8_e5m2_fnuz* ptr) +{ + data = *((__hip_fp8x2_e5m2_fnuz*)ptr); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 2>::store(__hip_fp8_e5m2_fnuz* ptr) const +{ + *((__hip_fp8x2_e5m2_fnuz*)ptr) = data; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 2>::memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src) +{ + *((__hip_fp8x2_e5m2_fnuz*)dst) = *((__hip_fp8x2_e5m2_fnuz*)src); +} + +// __hip_fp8_e5m2_fnuz x 4 + +template <> +struct vec_t<__hip_fp8_e5m2_fnuz, 4> +{ + __hip_fp8x4_e5m2_fnuz data; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz& operator[](size_t i) + { + return 
((__hip_fp8_e5m2_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const __hip_fp8_e5m2_fnuz& + operator[](size_t i) const + { + return ((const __hip_fp8_e5m2_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz* ptr() + { + return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); + } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e5m2_fnuz val); + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e5m2_fnuz* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e5m2_fnuz* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + + inline __attribute__((always_inline)) __device__ static void + memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src); +}; + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 4>::fill(__hip_fp8_e5m2_fnuz val) +{ + data.__x = (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 4>::load(const __hip_fp8_e5m2_fnuz* ptr) +{ + data = *((__hip_fp8x4_e5m2_fnuz*)ptr); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 4>::store(__hip_fp8_e5m2_fnuz* ptr) const +{ + *((__hip_fp8x4_e5m2_fnuz*)ptr) = data; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 4>::memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src) +{ + *((__hip_fp8x4_e5m2_fnuz*)dst) = *((__hip_fp8x4_e5m2_fnuz*)src); +} + +// __hip_fp8_e5m2_fnuz x 8 + +template <> +struct vec_t<__hip_fp8_e5m2_fnuz, 8> +{ + uint2 data; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz& operator[](size_t i) + { + return ((__hip_fp8_e5m2_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const __hip_fp8_e5m2_fnuz& + operator[](size_t i) const + { + return ((const __hip_fp8_e5m2_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz* ptr() + { + return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); + } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e5m2_fnuz val); + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e5m2_fnuz* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e5m2_fnuz* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void + memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src); +}; + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 8>::fill(__hip_fp8_e5m2_fnuz val) +{ + ((__hip_fp8x4_e5m2_fnuz*)(&data.x))->__x = + 
(__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e5m2_fnuz*)(&data.y))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 8>::load(const __hip_fp8_e5m2_fnuz* ptr) +{ + data = *((uint2*)ptr); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 8>::store(__hip_fp8_e5m2_fnuz* ptr) const +{ + *((uint2*)ptr) = data; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 8>::memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src) +{ + *((uint2*)dst) = *((uint2*)src); +} + +// __hip_fp8_e5m2_fnuz x 16 or more + +template +struct vec_t<__hip_fp8_e5m2_fnuz, vec_size> +{ + uint4 data[vec_size / 16]; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz& operator[](size_t i) + { + return ((__hip_fp8_e5m2_fnuz*)data)[i]; + } + inline __attribute__((always_inline)) __device__ const __hip_fp8_e5m2_fnuz& + operator[](size_t i) const + { + return ((const __hip_fp8_e5m2_fnuz*)data)[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz* ptr() + { + return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); + } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e5m2_fnuz val) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 16; ++i) + { + ((__hip_fp8x4_e5m2_fnuz*)(&(data[i].x)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e5m2_fnuz*)(&(data[i].y)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e5m2_fnuz*)(&(data[i].z)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e5m2_fnuz*)(&(data[i].w)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + } + } + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e5m2_fnuz* ptr) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 16; ++i) + { + data[i] = ((uint4*)ptr)[i]; + } + } + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e5m2_fnuz* ptr) const + { +#pragma unroll + for(size_t i = 0; i < vec_size / 16; ++i) + { + ((uint4*)ptr)[i] = data[i]; + } + } + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void + memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 16; ++i) + { + ((uint4*)dst)[i] = ((uint4*)src)[i]; + } + } +}; + +/******************* vec_t *******************/ + +// half x 1 
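+/*
+ * Usage sketch (hypothetical kernel, not part of this patch): the vec_t<> wrappers
+ * are meant to be used roughly as below, with cast_load() picking up the vec_cast<>
+ * fast paths defined above.  The kernel name, vector width and bounds handling are
+ * illustrative assumptions only.
+ *
+ *   __global__ void dequant_e4m3_to_half(const __hip_fp8_e4m3_fnuz* in, half* out, size_t n)
+ *   {
+ *       constexpr size_t vec = 8;
+ *       size_t i = (size_t(blockIdx.x) * blockDim.x + threadIdx.x) * vec;
+ *       if(i + vec <= n)
+ *       {
+ *           vec_t<half, vec> v;
+ *           v.cast_load(in + i); // 8 x fp8 -> 8 x half via vec_cast<half, __hip_fp8_e4m3_fnuz>
+ *           v.store(out + i);
+ *       }
+ *   }
+ *
+ * The half, __hip_bfloat16 and float specializations below follow the same layout
+ * pattern as the fp8 ones above, only with wider storage types.
+ */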
+template <> +struct vec_t +{ + half data; + + inline __attribute__((always_inline)) __device__ half& operator[](size_t i) + { + return ((half*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const half& operator[](size_t i) const + { + return ((const half*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ half* ptr() + { + return reinterpret_cast(&data); + } + inline __attribute__((always_inline)) __device__ void fill(half val); + inline __attribute__((always_inline)) __device__ void load(const half* ptr); + inline __attribute__((always_inline)) __device__ void store(half* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + + inline __attribute__((always_inline)) __device__ static void memcpy(half* dst, const half* src); +}; + +inline __attribute__((always_inline)) __device__ void vec_t::fill(half val) { data = val; } + +inline __attribute__((always_inline)) __device__ void vec_t::load(const half* ptr) +{ + data = *ptr; +} + +inline __attribute__((always_inline)) __device__ void vec_t::store(half* ptr) const +{ + *ptr = data; +} + +inline __attribute__((always_inline)) __device__ void vec_t::memcpy(half* dst, + const half* src) +{ + *dst = *src; +} + +// half x 2 +template <> +struct vec_t +{ + half2 data; + + inline __attribute__((always_inline)) __device__ half& operator[](size_t i) + { + return ((half*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const half& operator[](size_t i) const + { + return ((const half*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ half* ptr() + { + return reinterpret_cast(&data); + } + inline __attribute__((always_inline)) __device__ void fill(half val); + inline __attribute__((always_inline)) __device__ void load(const half* ptr); + inline __attribute__((always_inline)) __device__ void store(half* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + + inline __attribute__((always_inline)) __device__ static void memcpy(half* dst, const half* src); +}; + +inline __attribute__((always_inline)) __device__ void vec_t::fill(half val) +{ + data = make_half2(val, val); +} + +inline __attribute__((always_inline)) __device__ void vec_t::load(const half* ptr) +{ + data = *((half2*)ptr); +} + +inline __attribute__((always_inline)) __device__ void vec_t::store(half* ptr) const +{ + *((half2*)ptr) = data; +} + +inline __attribute__((always_inline)) __device__ void vec_t::memcpy(half* dst, + const half* src) +{ + *((half2*)dst) = *((half2*)src); +} + +// half x 4 + +template <> +struct vec_t +{ + uint2 data; + + inline __attribute__((always_inline)) __device__ half& operator[](size_t i) + { + return ((half*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const half& operator[](size_t i) const + { + return ((const half*)(&data))[i]; + } + inline __attribute__((always_inline)) 
__device__ half* ptr() + { + return reinterpret_cast(&data); + } + inline __attribute__((always_inline)) __device__ void fill(half val); + inline __attribute__((always_inline)) __device__ void load(const half* ptr); + inline __attribute__((always_inline)) __device__ void store(half* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(half* dst, const half* src); +}; + +inline __attribute__((always_inline)) __device__ void vec_t::fill(half val) +{ + *(half2*)(&data.x) = make_half2(val, val); + *(half2*)(&data.y) = make_half2(val, val); +} + +inline __attribute__((always_inline)) __device__ void vec_t::load(const half* ptr) +{ + data = *((uint2*)ptr); +} + +inline __attribute__((always_inline)) __device__ void vec_t::store(half* ptr) const +{ + *((uint2*)ptr) = data; +} + +inline __attribute__((always_inline)) __device__ void vec_t::memcpy(half* dst, + const half* src) +{ + *((uint2*)dst) = *((uint2*)src); +} + +// half x 8 or more + +template +struct vec_t +{ + uint4 data[vec_size / 8]; + inline __attribute__((always_inline)) __device__ half& operator[](size_t i) + { + return ((half*)data)[i]; + } + inline __attribute__((always_inline)) __device__ const half& operator[](size_t i) const + { + return ((const half*)data)[i]; + } + inline __attribute__((always_inline)) __device__ half* ptr() + { + return reinterpret_cast(&data); + } + inline __attribute__((always_inline)) __device__ void fill(half val) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 8; ++i) + { + *(half2*)(&(data[i].x)) = make_half2(val, val); + *(half2*)(&(data[i].y)) = make_half2(val, val); + *(half2*)(&(data[i].z)) = make_half2(val, val); + *(half2*)(&(data[i].w)) = make_half2(val, val); + } + } + inline __attribute__((always_inline)) __device__ void load(const half* ptr) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 8; ++i) + { + data[i] = ((uint4*)ptr)[i]; + } + } + inline __attribute__((always_inline)) __device__ void store(half* ptr) const + { +#pragma unroll + for(size_t i = 0; i < vec_size / 8; ++i) + { + ((uint4*)ptr)[i] = data[i]; + } + } + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(half* dst, const half* src) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 8; ++i) + { + ((uint4*)dst)[i] = ((uint4*)src)[i]; + } + } +}; + +/******************* vec_t<__hip_bfloat16> *******************/ + +// __hip_bfloat16 x 1 +template <> +struct vec_t<__hip_bfloat16, 1> +{ + __hip_bfloat16 data; + inline __attribute__((always_inline)) __device__ __hip_bfloat16& operator[](size_t i) + { + return ((__hip_bfloat16*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const __hip_bfloat16& + operator[](size_t i) const + { + return ((const __hip_bfloat16*)(&data))[i]; + } + 
inline __attribute__((always_inline)) __device__ __hip_bfloat16* ptr() + { + return reinterpret_cast<__hip_bfloat16*>(&data); + } + inline __attribute__((always_inline)) __device__ void fill(__hip_bfloat16 val); + inline __attribute__((always_inline)) __device__ void load(const __hip_bfloat16* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_bfloat16* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(__hip_bfloat16* dst, + const __hip_bfloat16* src); +}; + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_bfloat16, 1>::fill(__hip_bfloat16 val) +{ + data = val; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_bfloat16, 1>::load(const __hip_bfloat16* ptr) +{ + data = *ptr; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_bfloat16, 1>::store(__hip_bfloat16* ptr) const +{ + *ptr = data; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_bfloat16, 1>::memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src) +{ + *dst = *src; +} + +// __hip_bfloat16 x 2 +template <> +struct vec_t<__hip_bfloat16, 2> +{ + __hip_bfloat162 data; + + inline __attribute__((always_inline)) __device__ __hip_bfloat16& operator[](size_t i) + { + return ((__hip_bfloat16*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const __hip_bfloat16& + operator[](size_t i) const + { + return ((const __hip_bfloat16*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_bfloat16* ptr() + { + return reinterpret_cast<__hip_bfloat16*>(&data); + } + inline __attribute__((always_inline)) __device__ void fill(__hip_bfloat16 val); + inline __attribute__((always_inline)) __device__ void load(const __hip_bfloat16* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_bfloat16* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(__hip_bfloat16* dst, + const __hip_bfloat16* src); +}; + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_bfloat16, 2>::fill(__hip_bfloat16 val) +{ + data = make_bfloat162(val, val); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_bfloat16, 2>::load(const __hip_bfloat16* ptr) +{ + data = *((__hip_bfloat162*)ptr); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_bfloat16, 2>::store(__hip_bfloat16* ptr) const +{ + *((__hip_bfloat162*)ptr) = data; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_bfloat16, 2>::memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src) +{ + *((__hip_bfloat162*)dst) = *((__hip_bfloat162*)src); +} + +// __hip_bfloat16 x 4 + +template <> +struct vec_t<__hip_bfloat16, 4> +{ + uint2 data; + + inline 
__attribute__((always_inline)) __device__ __hip_bfloat16& operator[](size_t i) + { + return ((__hip_bfloat16*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const __hip_bfloat16& + operator[](size_t i) const + { + return ((const __hip_bfloat16*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_bfloat16* ptr() + { + return reinterpret_cast<__hip_bfloat16*>(&data); + } + inline __attribute__((always_inline)) __device__ void fill(__hip_bfloat16 val); + inline __attribute__((always_inline)) __device__ void load(const __hip_bfloat16* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_bfloat16* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(__hip_bfloat16* dst, + const __hip_bfloat16* src); +}; + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_bfloat16, 4>::fill(__hip_bfloat16 val) +{ + *(__hip_bfloat162*)(&data.x) = make_bfloat162(val, val); + *(__hip_bfloat162*)(&data.y) = make_bfloat162(val, val); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_bfloat16, 4>::load(const __hip_bfloat16* ptr) +{ + data = *((uint2*)ptr); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_bfloat16, 4>::store(__hip_bfloat16* ptr) const +{ + *((uint2*)ptr) = data; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_bfloat16, 4>::memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src) +{ + *((uint2*)dst) = *((uint2*)src); +} + +// __hip_bfloat16 x 8 or more + +template +struct vec_t<__hip_bfloat16, vec_size> +{ + uint4 data[vec_size / 8]; + + inline __attribute__((always_inline)) __device__ __hip_bfloat16& operator[](size_t i) + { + return ((__hip_bfloat16*)data)[i]; + } + inline __attribute__((always_inline)) __device__ const __hip_bfloat16& + operator[](size_t i) const + { + return ((const __hip_bfloat16*)data)[i]; + } + inline __attribute__((always_inline)) __device__ __hip_bfloat16* ptr() + { + return reinterpret_cast<__hip_bfloat16*>(&data); + } + inline __attribute__((always_inline)) __device__ void fill(__hip_bfloat16 val) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 8; ++i) + { + *(__hip_bfloat162*)(&(data[i].x)) = make_bfloat162(val, val); + *(__hip_bfloat162*)(&(data[i].y)) = make_bfloat162(val, val); + *(__hip_bfloat162*)(&(data[i].z)) = make_bfloat162(val, val); + *(__hip_bfloat162*)(&(data[i].w)) = make_bfloat162(val, val); + } + } + inline __attribute__((always_inline)) __device__ void load(const __hip_bfloat16* ptr) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 8; ++i) + { + data[i] = ((uint4*)ptr)[i]; + } + } + inline __attribute__((always_inline)) __device__ void store(__hip_bfloat16* ptr) const + { +#pragma unroll + for(size_t i = 0; i < vec_size / 8; ++i) + { + ((uint4*)ptr)[i] = data[i]; + } + } + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) 
__device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(__hip_bfloat16* dst, + const __hip_bfloat16* src) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 8; ++i) + { + ((uint4*)dst)[i] = ((uint4*)src)[i]; + } + } +}; + +/******************* vec_t *******************/ + +// float x 1 + +template <> +struct vec_t +{ + float data; + + inline __attribute__((always_inline)) __device__ float& operator[](size_t i) + { + return ((float*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const float& operator[](size_t i) const + { + return ((const float*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ float* ptr() + { + return reinterpret_cast(&data); + } + inline __attribute__((always_inline)) __device__ void fill(float val); + inline __attribute__((always_inline)) __device__ void load(const float* ptr); + inline __attribute__((always_inline)) __device__ void store(float* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(float* dst, + const float* src); +}; + +inline __attribute__((always_inline)) __device__ void vec_t::fill(float val) +{ + data = val; +} + +inline __attribute__((always_inline)) __device__ void vec_t::load(const float* ptr) +{ + data = *ptr; +} + +inline __attribute__((always_inline)) __device__ void vec_t::store(float* ptr) const +{ + *ptr = data; +} + +inline __attribute__((always_inline)) __device__ void vec_t::memcpy(float* dst, + const float* src) +{ + *dst = *src; +} + +// float x 2 + +template <> +struct vec_t +{ + float2 data; + + inline __attribute__((always_inline)) __device__ float& operator[](size_t i) + { + return ((float*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const float& operator[](size_t i) const + { + return ((const float*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ float* ptr() + { + return reinterpret_cast(&data); + } + inline __attribute__((always_inline)) __device__ void fill(float val); + inline __attribute__((always_inline)) __device__ void load(const float* ptr); + inline __attribute__((always_inline)) __device__ void store(float* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(float* dst, + const float* src); +}; + +inline __attribute__((always_inline)) __device__ void vec_t::fill(float val) +{ + data = make_float2(val, val); +} + +inline __attribute__((always_inline)) __device__ void vec_t::load(const float* ptr) +{ + data = *((float2*)ptr); +} + +inline __attribute__((always_inline)) __device__ void vec_t::store(float* ptr) const +{ + *((float2*)ptr) = data; +} + +inline __attribute__((always_inline)) __device__ void 
vec_t::memcpy(float* dst, + const float* src) +{ + *((float2*)dst) = *((float2*)src); +} + +// float x 4 or more +template +struct vec_t +{ + float4 data[vec_size / 4]; + + inline __attribute__((always_inline)) __device__ float& operator[](size_t i) + { + return ((float*)(data))[i]; + } + inline __attribute__((always_inline)) __device__ const float& operator[](size_t i) const + { + return ((const float*)(data))[i]; + } + inline __attribute__((always_inline)) __device__ float* ptr() + { + return reinterpret_cast(&data); + } + inline __attribute__((always_inline)) __device__ void fill(float val) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 4; ++i) + { + data[i] = make_float4(val, val, val, val); + } + } + inline __attribute__((always_inline)) __device__ void load(const float* ptr) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 4; ++i) + { + data[i] = ((float4*)ptr)[i]; + } + } + inline __attribute__((always_inline)) __device__ void store(float* ptr) const + { +#pragma unroll + for(size_t i = 0; i < vec_size / 4; ++i) + { + ((float4*)ptr)[i] = data[i]; + } + } + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(float* dst, + const float* src) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 4; ++i) + { + ((float4*)dst)[i] = ((float4*)src)[i]; + } + } +}; + +} // namespace aiter diff --git a/csrc/cpp_itfs/torch_utils.py b/csrc/cpp_itfs/torch_utils.py index 09afec4898..f4ac4ab71c 100644 --- a/csrc/cpp_itfs/torch_utils.py +++ b/csrc/cpp_itfs/torch_utils.py @@ -1,9 +1,38 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + + import torch import ctypes from torch.library import Library from typing import Callable, Optional, Tuple -from csrc.cpp_itfs.utils import AITER_LOG_MORE -from aiter.test_common import log_args +from csrc.cpp_itfs.utils import AITER_LOG_MORE, logger + + +def log_args(func, *args, **kwargs): + import inspect + + callargs = inspect.getcallargs(func, *args, **kwargs) + + prefix = f"calling {func.__name__}(" + blanks = " " * (len(prefix)) + + def getTensorInfo(el): + if isinstance(el, torch.Tensor): + return f"{el.shape} {el.dtype} {el.device} {hex(el.data_ptr())}" + elif isinstance(el, tuple): + viewNum = 5 + if len(el) > viewNum: + el = list(el[:viewNum]) + ["..."] + return f'\n{" "*(len(prefix)+31)}'.join( + ["("] + [f" {getTensorInfo(e)}" for e in el] + [")"] + ) + return el + + info = [f"{el:<28} = {getTensorInfo(callargs[el])}" for el in callargs] + info = f",\n{blanks}".join(info) + logger.info(f"\n{prefix}{info})") + return callargs ctypes_map = { diff --git a/csrc/cpp_itfs/utils.py b/csrc/cpp_itfs/utils.py index 4702887701..9744497193 100644 --- a/csrc/cpp_itfs/utils.py +++ b/csrc/cpp_itfs/utils.py @@ -1,3 +1,7 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ + import shutil import os import subprocess diff --git a/op_tests/test_sampling.py b/op_tests/test_sampling.py new file mode 100644 index 0000000000..3fac3ae3bb --- /dev/null +++ b/op_tests/test_sampling.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. + +import pytest +import torch + +from aiter.ops import sampling # noqa: F401 + +torch.set_default_device("cuda") + + +def _to_tensor_scalar_tuple(x): + if isinstance(x, torch.Tensor): + return (x, 0) + else: + return (None, x) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("p", [0.1, 0.5, 0.9]) +def test_top_p_sampling(batch_size, vocab_size, p): + torch.manual_seed(42) + eps = 1e-4 + pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + sorted_prob, indices = torch.sort(normalized_prob, descending=False) + cdf = torch.cumsum(sorted_prob, dim=-1) + mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0) + mask.scatter_add_(1, indices, (cdf > (1 - p) - eps).int()) + + num_trials = 1000 + for _ in range(num_trials): + samples = torch.ops.aiter.top_p_sampling_from_probs( + normalized_prob, None, *_to_tensor_scalar_tuple(p), deterministic=True + ) + assert torch.all(samples < vocab_size) and torch.all(samples >= 0) + assert torch.all(mask[torch.arange(batch_size), samples] == 1) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("k", [10, 100, 500]) +def test_top_k_renorm_probs(batch_size, vocab_size, k): + if k > vocab_size: + pytest.skip("k should be less than vocab_size") + torch.manual_seed(42) + pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + sorted_prob, _ = torch.sort(normalized_prob, descending=True) + pivot = sorted_prob[:, k - 1] + mask = (normalized_prob >= pivot.unsqueeze(-1)).int() + renorm_prob_ground_truth = normalized_prob.clone() + renorm_prob_ground_truth[mask == 0] = 0 + renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum( + dim=-1, keepdim=True + ) + + renorm_prob = torch.ops.aiter.top_k_renorm_probs( + normalized_prob, *_to_tensor_scalar_tuple(k) + ) + for i in range(batch_size): + torch.testing.assert_close( + renorm_prob_ground_truth[i], + renorm_prob[i], + rtol=1e-3, + atol=1e-3, + ) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("p", [0.1, 0.5]) +@pytest.mark.parametrize("k", [10, 50]) +def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p, k): + torch.manual_seed(42) + # if p == 0.1: + # k = int(vocab_size * 0.5) + # elif p == 0.5: + # k = int(vocab_size * 0.1) + # else: + # raise ValueError("p not recognized") + eps = 1e-4 + pre_norm_prob = torch.rand(batch_size, vocab_size) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + # top-p mask + sorted_prob, indices = torch.sort(normalized_prob, descending=False) + cdf = torch.cumsum(sorted_prob, dim=-1) + mask_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32) + mask_top_p.scatter_add_(1, indices, (cdf > (1 - p) - eps).int()) + # top-k mask + sorted_prob, _ = torch.sort(normalized_prob, descending=True) + pivot = 
sorted_prob[:, k - 1] + mask_top_k = (normalized_prob >= pivot.unsqueeze(-1)).int() + # overall mask + mask = torch.minimum(mask_top_p, mask_top_k) + top_p_tensor = torch.full((batch_size,), p) + top_k_tensor = torch.full((batch_size,), k) + + num_trials = 1000 + for _ in range(num_trials): + samples = torch.ops.aiter.top_k_top_p_sampling_from_probs( + normalized_prob, + None, + *_to_tensor_scalar_tuple(top_k_tensor), + *_to_tensor_scalar_tuple(top_p_tensor), + deterministic=True, + ) + assert torch.all(samples < vocab_size) and torch.all(samples >= 0) + assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[ + torch.arange(batch_size), samples + ] + + +if __name__ == "__main__": + test_top_k_top_p_joint_sampling_from_probs(40, 129280, 0.6, 20) + # test_top_k_renorm_probs(1, 129280, 10) + # test_top_p_sampling(1, 129280, 0.1)
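+
+# Reference for the masks built in the tests above (notation only; the eps tolerance
+# and the tie handling of torch.sort are ignored).  With normalized probabilities
+# p_1, ..., p_V, a sampled token i is accepted iff it lies in
+#
+#   top-p:  S_p = { i : sum_{j : p_j > p_i} p_j < p }
+#   top-k:  S_k = { i : p_i >= p_(k) }        where p_(k) is the k-th largest probability
+#   joint:  the intersection of S_k and S_p   (the elementwise minimum of the two masks)
+#
+# and top_k_renorm_probs is checked against q_i = p_i * 1[i in S_k] / sum_{j in S_k} p_j.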