From 0e044336c3667c4ec1b35d615aa99a1feb5d21d9 Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Thu, 24 Jul 2025 08:14:08 +0000 Subject: [PATCH 01/19] upload sampling.cuh --- csrc/cpp_itfs/sampling/sampling.cuh | 1273 +++++++++++++++++++++++++++ 1 file changed, 1273 insertions(+) create mode 100644 csrc/cpp_itfs/sampling/sampling.cuh diff --git a/csrc/cpp_itfs/sampling/sampling.cuh b/csrc/cpp_itfs/sampling/sampling.cuh new file mode 100644 index 0000000000..df0a117824 --- /dev/null +++ b/csrc/cpp_itfs/sampling/sampling.cuh @@ -0,0 +1,1273 @@ +#include "hip/hip_runtime.h" +/* + * Copyright (C) 2024-2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + #ifndef FLASHINFER_SAMPLING_CUH_ + #define FLASHINFER_SAMPLING_CUH_ + + #include + #include + #include + + #include + #include + #include + #include + #include + #include + + #include "math.cuh" + #include "utils.cuh" + #include "vec_dtypes.cuh" + + namespace aiter { + + namespace sampling { + + using namespace hipcub; + + #define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...) \ + switch (aligned_vec_size) { \ + case 16: { \ + constexpr size_t ALIGNED_VEC_SIZE = 16; \ + __VA_ARGS__ \ + break; \ + } \ + case 8: { \ + constexpr size_t ALIGNED_VEC_SIZE = 8; \ + __VA_ARGS__ \ + break; \ + } \ + case 4: { \ + constexpr size_t ALIGNED_VEC_SIZE = 4; \ + __VA_ARGS__ \ + break; \ + } \ + case 2: { \ + constexpr size_t ALIGNED_VEC_SIZE = 2; \ + __VA_ARGS__ \ + break; \ + } \ + case 1: { \ + constexpr size_t ALIGNED_VEC_SIZE = 1; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \ + throw std::runtime_error(err_msg.str()); \ + } \ + } + + + #define DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, ...) \ + if (deterministic) { \ + constexpr bool DETERMINISTIC = true; \ + __VA_ARGS__ \ + } else { \ + constexpr bool DETERMINISTIC = false; \ + __VA_ARGS__ \ + } + + #define DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, ...) 
\ + if (compute_capacity.first >= 8) { \ + constexpr uint32_t BLOCK_THREADS = 1024; \ + __VA_ARGS__ \ + } else { \ + constexpr uint32_t BLOCK_THREADS = 512; \ + __VA_ARGS__ \ + } + + constexpr BlockScanAlgorithm SCAN_ALGO = BLOCK_SCAN_WARP_SCANS; + constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS; + + #if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120100) + #define FLASHINFER_CUB_SUBTRACTLEFT_DEFINED + #endif + + template + struct ValueCount { + T value; + int count; + + __device__ ValueCount operator+(const ValueCount& other) const { + return {value + other.value, count + other.count}; + } + __device__ ValueCount& operator+=(const ValueCount& other) { + value += other.value; + count += other.count; + return *this; + } + }; + + struct BoolDiffOp { + __device__ __forceinline__ bool operator()(const bool& lhs, const bool& rhs) const { + return lhs != rhs; + } + }; + + template + struct SamplingTempStorage { + union { + float deterministic_scan[BLOCK_THREADS / 32]; + typename BlockScan::TempStorage scan; + typename BlockReduce::TempStorage reduce; + typename BlockReduce::TempStorage reduce_int; + typename BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage + reduce_value_count; + typename BlockAdjacentDifference::TempStorage adj_diff; + } block_prim; + struct { + int32_t sampled_id; + int32_t last_valid_id; + float max_val; + union { + float value; + ValueCount pair; + } block_aggregate; + }; + }; + + template + __device__ __forceinline__ T infinity() { + return __builtin_huge_valf(); + } + + /*! + * \brief Deterministic inclusive scan implementation, use Belloch scan algorithm. + * \note This implementation is slower than the hipcub::BlockScan, but it is deterministic. + */ + template + __device__ __forceinline__ void DeterministicInclusiveSum( + const float* in_data, float* out_data, + SamplingTempStorage* temp_storage) { + float* smem_prefix_sum = temp_storage->block_prim.deterministic_scan; + float thread_data[VEC_SIZE]; + float thread_sum = 0; + #pragma unroll + for (uint32_t i = 0; i < VEC_SIZE; ++i) { + thread_sum += in_data[i]; + thread_data[i] = thread_sum; + } + + float thread_exclusive_prefix_sum = thread_sum; + + #pragma unroll + for (uint32_t offset = 1; offset < 32; offset *= 2) { + float tmp = __shfl_up_sync(0xffffffff, thread_exclusive_prefix_sum, offset); + if ((threadIdx.x + 1) % (offset * 2) == 0) { + thread_exclusive_prefix_sum += tmp; + } + } + + float warp_sum = __shfl_sync(0xffffffff, thread_exclusive_prefix_sum, threadIdx.x | 0xffffffff); + if (threadIdx.x % 32 == 31) { + thread_exclusive_prefix_sum = 0; + } + + #pragma unroll + for (uint32_t offset = 16; offset >= 1; offset /= 2) { + float tmp = __shfl_xor_sync(0xffffffff, thread_exclusive_prefix_sum, offset); + if ((threadIdx.x + 1) % (offset * 2) == 0) { + thread_exclusive_prefix_sum = tmp + thread_exclusive_prefix_sum; + } + if ((threadIdx.x + 1) % (offset * 2) == offset) { + thread_exclusive_prefix_sum = tmp; + } + } + + smem_prefix_sum[threadIdx.x / 32] = warp_sum; + __syncthreads(); + + if (threadIdx.x < 32) { + float warp_exclusive_prefix_sum = + (threadIdx.x < BLOCK_THREADS / 32) ? 
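      // Second level of the (Blelloch-style) scan: the first warp scans the
      // per-warp totals staged in shared memory, so that every warp can later
      // add the exclusive prefix of all warps that precede it.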
smem_prefix_sum[threadIdx.x] : 0; + + #pragma unroll + for (uint32_t offset = 1; offset < 32; offset *= 2) { + float tmp = __shfl_up_sync(0xffffffff, warp_exclusive_prefix_sum, offset); + if ((threadIdx.x + 1) % (offset * 2) == 0) { + warp_exclusive_prefix_sum += tmp; + } + } + + if (threadIdx.x % 32 == 31) { + warp_exclusive_prefix_sum = 0; + } + + #pragma unroll + for (uint32_t offset = 16; offset >= 1; offset /= 2) { + float tmp = __shfl_xor_sync(0xffffffff, warp_exclusive_prefix_sum, offset); + if ((threadIdx.x + 1) % (offset * 2) == 0) { + warp_exclusive_prefix_sum = tmp + warp_exclusive_prefix_sum; + } + if ((threadIdx.x + 1) % (offset * 2) == offset) { + warp_exclusive_prefix_sum = tmp; + } + } + if (threadIdx.x < BLOCK_THREADS / 32) { + smem_prefix_sum[threadIdx.x] = warp_exclusive_prefix_sum; + } + } + __syncthreads(); + + #pragma unroll + for (uint32_t i = 0; i < VEC_SIZE; ++i) { + out_data[i] = smem_prefix_sum[threadIdx.x / 32] + thread_exclusive_prefix_sum + thread_data[i]; + } + } + + template + __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, uint32_t d, + TempStorage& temp_storage) { + const uint32_t tx = threadIdx.x; + vec_t in_data_vec; + + float max_val = 0; + for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { + in_data_vec.fill(0); + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + in_data_vec.cast_load(in_data + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + } + float in_data_[VEC_SIZE]; + #pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + in_data_[j] = in_data_vec[j]; + } + max_val = max( + max_val, BlockReduce(temp_storage.block_prim.reduce) + .Reduce(in_data_, hipcub::Max())); + __syncthreads(); + } + if (tx == 0) { + temp_storage.max_val = max_val; + } + __syncthreads(); + return temp_storage.max_val; + } + + template + __device__ __forceinline__ void DeviceSamplingFromProb( + uint32_t i, uint32_t d, Predicate pred, float u, vec_t prob_vec, + float& aggregate, + SamplingTempStorage* temp_storage) { + const uint32_t tx = threadIdx.x; + float prob_greater_than_threshold[VEC_SIZE]; + float inclusive_cdf[VEC_SIZE]; + bool greater_than_u[VEC_SIZE], valid[VEC_SIZE]; + #pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + prob_greater_than_threshold[j] = pred(prob_vec[j]) ? 
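        // Entries rejected by `pred` contribute zero mass here, so the running
        // block aggregate and the inclusive CDF below are built over admissible
        // tokens only; the uniform draw `u` is inverted against that restricted CDF.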
prob_vec[j] : 0; + valid[j] = pred(prob_vec[j]) && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d; + } + float aggregate_local = + BlockReduce(temp_storage->block_prim.reduce) + .Sum(prob_greater_than_threshold); + if (tx == 0) { + temp_storage->block_aggregate.value = aggregate_local; + } + __syncthreads(); + aggregate_local = temp_storage->block_aggregate.value; + + if (aggregate + aggregate_local > u) { + if constexpr (DETERMINISTIC) { + DeterministicInclusiveSum( + prob_greater_than_threshold, inclusive_cdf, temp_storage); + } else { + BlockScan(temp_storage->block_prim.scan) + .InclusiveSum(prob_greater_than_threshold, inclusive_cdf); + + __syncthreads(); + } + + #pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + greater_than_u[j] = (inclusive_cdf[j] + aggregate > u) && valid[j]; + } + + bool greater_than_u_diff[VEC_SIZE]; + #ifdef FLASHINFER_CUB_SUBTRACTLEFT_DEFINED + BlockAdjacentDifference(temp_storage->block_prim.adj_diff) + .SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp()); + #else + BlockAdjacentDifference(temp_storage->block_prim.adj_diff) + .FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0); + #endif + __syncthreads(); + + #pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + if (greater_than_u_diff[j]) { + atomicMin(&(temp_storage->sampled_id), (i * BLOCK_THREADS + tx) * VEC_SIZE + j); + } + } + __syncthreads(); + } + + // update the last valid index + int valid_index[VEC_SIZE]; + #pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + if (valid[j]) { + valid_index[j] = (i * BLOCK_THREADS + tx) * VEC_SIZE + j; + } else { + valid_index[j] = -1; + } + } + int max_valid_index = + BlockReduce(temp_storage->block_prim.reduce_int) + .Reduce(valid_index, hipcub::Max()); + if (tx == 0 && max_valid_index != -1) { + temp_storage->last_valid_id = max_valid_index; + } + __syncthreads(); + aggregate += aggregate_local; + } + + template + struct DataAndIndex { + DType data; + IdType index; + + __device__ DataAndIndex operator+(const DataAndIndex& other) const { + if (data > other.data) { + return {data, index}; + } else { + return {other.data, other.index}; + } + } + __device__ DataAndIndex& operator+=(const DataAndIndex& other) { + if (data > other.data) { + return *this; + } else { + data = other.data; + index = other.index; + return *this; + } + } + }; + + template + __device__ __forceinline__ vec_t GenerateGumbelNoise(uint64_t philox_seed, + uint64_t philox_offset, + uint64_t subsequence) { + hiprandStatePhilox4_32_10_t state; + vec_t noise; + constexpr float kEPSILON = 1e-20f; + constexpr float kLOG2 = 0.6931471806f; + auto uniform2gumbel = [](float x) { return -kLOG2 * log2f(-log2f(x + kEPSILON) + kEPSILON); }; + // TODO: compare the speed of log2 and log + #pragma unroll + for (uint32_t i = 0; i + 4 <= VEC_SIZE; i += 4) { + hiprand_init(philox_seed, subsequence + i, philox_offset, &state); + float4 noise_vec = hiprand_uniform4(&state); + noise[i] = uniform2gumbel(noise_vec.x); + noise[i + 1] = uniform2gumbel(noise_vec.y); + noise[i + 2] = uniform2gumbel(noise_vec.z); + noise[i + 3] = uniform2gumbel(noise_vec.w); + } + if constexpr (VEC_SIZE % 4 != 0) { + hiprand_init(philox_seed, subsequence + VEC_SIZE / 4 * 4, philox_offset, &state); + float4 noise_vec = hiprand_uniform4(&state); + if constexpr (VEC_SIZE % 4 == 1) { + noise[VEC_SIZE - 1] = uniform2gumbel(noise_vec.x); + } else if constexpr (VEC_SIZE % 4 == 2) { + noise[VEC_SIZE - 2] = uniform2gumbel(noise_vec.x); + noise[VEC_SIZE - 1] = uniform2gumbel(noise_vec.y); + } else if 
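      // Note: uniform2gumbel(u) = -ln2 * log2(-log2(u)) = -ln(-ln(u)) + ln(ln 2)
      // (ignoring the epsilons). The constant offset is shared by every lane, so
      // argmax(logits + noise) is unchanged and the Gumbel-max property holds.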
constexpr (VEC_SIZE % 4 == 3) { + noise[VEC_SIZE - 3] = uniform2gumbel(noise_vec.x); + noise[VEC_SIZE - 2] = uniform2gumbel(noise_vec.y); + noise[VEC_SIZE - 1] = uniform2gumbel(noise_vec.z); + } + } + + if constexpr (std::is_same_v) { + return noise; + } else { + vec_t ret; + #pragma unroll + for (uint32_t i = 0; i < VEC_SIZE; ++i) { + ret[i] = static_cast(noise[i]); + } + return ret; + } + } + + template + __global__ void SamplingFromLogitsKernel(DType* logits, IdType* output, IdType* indices, uint32_t d, + uint64_t philox_seed, uint64_t philox_offset) { + const uint32_t bx = blockIdx.x, tx = threadIdx.x; + const uint32_t row_idx = indices == nullptr ? bx : indices[bx]; + using SharedMem = typename BlockReduce, BLOCK_THREADS, + REDUCE_ALGORITHM>::TempStorage; + extern __shared__ __align__(alignof(SharedMem)) uint8_t smem_sampling[]; + auto& temp_storage = reinterpret_cast(smem_sampling); + + vec_t logits_vec; + DataAndIndex max_data = {-infinity(), 0}; + for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { + logits_vec.fill(-infinity()); + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + logits_vec.cast_load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); + } + + vec_t gumbel_noise = GenerateGumbelNoise( + philox_seed, philox_offset, + static_cast(bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE)); + DataAndIndex cur_data[VEC_SIZE]; + #pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + cur_data[j].data = (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d + ? logits_vec[j] + gumbel_noise[j] + : -infinity(); + cur_data[j].index = (i * BLOCK_THREADS + tx) * VEC_SIZE + j; + } + + max_data += + BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage) + .Sum(cur_data); + } + if (tx == 0) { + output[bx] = max_data.index; + } + } + + template + __global__ void SamplingFromProbKernel(DType* probs, IdType* output, IdType* indices, uint32_t d, + uint64_t philox_seed, uint64_t philox_offset) { + hiprandStatePhilox4_32_10_t state; + const uint32_t bx = blockIdx.x, tx = threadIdx.x; + hiprand_init(philox_seed, bx, philox_offset, &state); + const uint32_t row_idx = indices == nullptr ? 
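    // `indices`, when provided, redirects block bx to an arbitrary row of `probs`;
    // otherwise block bx samples from row bx directly.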
bx : indices[bx]; + + extern __shared__ __align__( + alignof(SamplingTempStorage)) + uint8_t smem_sampling[]; + auto& temp_storage = + reinterpret_cast&>( + smem_sampling); + temp_storage.sampled_id = d; + __syncthreads(); + + vec_t probs_vec; + float aggregate(0); + float u = hiprand_uniform(&state); + + #pragma unroll 2 + for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { + probs_vec.fill(0); + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); + } + + DeviceSamplingFromProb( + i, d, [](float x) { return x > 0; }, u, probs_vec, aggregate, &temp_storage); + if (float(aggregate) > u) { + break; + } + } + int sampled_id = temp_storage.sampled_id; + if (sampled_id == d) { + // NOTE(Zihao): this would happen when u is very close to 1 + // and the sum of probabilities is smaller than u + // In this case, we use the last valid index as the sampled id + sampled_id = temp_storage.last_valid_id; + } + output[bx] = sampled_id; + } + + template + __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType* indices, + IdType* top_k_arr, uint32_t top_k_val, uint32_t d, + uint64_t philox_seed, uint64_t philox_offset) { + const uint32_t batch_size = gridDim.x; + const uint32_t bx = blockIdx.x, tx = threadIdx.x; + hiprandStatePhilox4_32_10_t state; + hiprand_init(philox_seed, bx, philox_offset, &state); + const uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx]; + const uint32_t row_idx = indices == nullptr ? bx : indices[bx]; + + extern __shared__ __align__( + alignof(SamplingTempStorage)) + uint8_t smem_sampling[]; + auto& temp_storage = + reinterpret_cast&>( + smem_sampling); + + vec_t probs_vec; + float aggregate; + float q = 1; + double low = 0, high = 1.f; + int sampled_id; + int round = 0; + do { + round += 1; + temp_storage.sampled_id = d; + __syncthreads(); + float u = hiprand_uniform(&state) * q; + aggregate = 0; + #pragma unroll 2 + for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { + probs_vec.fill(0); + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + } + + DeviceSamplingFromProb( + i, d, [&](float x) { return x > low; }, u, probs_vec, aggregate, &temp_storage); + if (aggregate > u) { + break; + } + } + __syncthreads(); + sampled_id = temp_storage.sampled_id; + if (sampled_id == d) { + // NOTE(Zihao): this would happen when u is very close to 1 + // and the sum of probabilities is smaller than u + // In this case, we use the last valid index as the sampled id + sampled_id = temp_storage.last_valid_id; + } + double pivot_0 = probs[row_idx * d + sampled_id]; + double pivot_1 = (pivot_0 + high) / 2; + + ValueCount aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0}; + #pragma unroll 2 + for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { + probs_vec.fill(0); + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + } + + ValueCount probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE]; + #pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + probs_gt_pivot_0[j] = { + (probs_vec[j] > pivot_0) ? probs_vec[j] : 0, + (probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; + probs_gt_pivot_1[j] = { + (probs_vec[j] > pivot_1) ? 
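          // Accumulate (mass, count) of entries above each candidate pivot: if
          // fewer than k probabilities exceed pivot_0 (the sampled token's
          // probability), the sample is already inside the top-k and is accepted;
          // otherwise [low, high] shrinks and the next round redraws u against
          // the residual mass q above the new low.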
probs_vec[j] : 0, + (probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; + } + + aggregate_gt_pivot_0 += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( + temp_storage.block_prim.reduce_value_count) + .Sum(probs_gt_pivot_0); + if (tx == 0) { + temp_storage.block_aggregate.pair = aggregate_gt_pivot_0; + } + __syncthreads(); + aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair; + + aggregate_gt_pivot_1 += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( + temp_storage.block_prim.reduce_value_count) + .Sum(probs_gt_pivot_1); + if (tx == 0) { + temp_storage.block_aggregate.pair = aggregate_gt_pivot_1; + } + __syncthreads(); + aggregate_gt_pivot_1 = temp_storage.block_aggregate.pair; + } + if (aggregate_gt_pivot_0.count < k) { + // case 1: pivot_0 accepted + break; + } + if (aggregate_gt_pivot_1.count < k) { + // case 2: pivot_0 rejected, pivot_1 accepted + low = pivot_0; + high = pivot_1; + q = aggregate_gt_pivot_0.value; + } else { + // case 3: pivot_0 rejected, pivot_1 rejected + low = pivot_1; + q = aggregate_gt_pivot_1.value; + } + } while (low < high); + __syncthreads(); + if (tx == 0) { + output[bx] = sampled_id; + } + } + + template + __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType* indices, + float* top_p_arr, float top_p_val, uint32_t d, + uint64_t philox_seed, uint64_t philox_offset) { + const uint32_t batch_size = gridDim.x; + const uint32_t bx = blockIdx.x, tx = threadIdx.x; + hiprandStatePhilox4_32_10_t state; + hiprand_init(philox_seed, bx, philox_offset, &state); + const uint32_t row_idx = indices == nullptr ? bx : indices[bx]; + float top_p = (top_p_arr == nullptr) ? top_p_val : top_p_arr[row_idx]; + + extern __shared__ __align__( + alignof(SamplingTempStorage)) + uint8_t smem_sampling[]; + auto& temp_storage = + reinterpret_cast&>( + smem_sampling); + + vec_t probs_vec; + float aggregate; + float q = 1; + double low = 0, high = 1.f; + int sampled_id; + do { + temp_storage.sampled_id = d; + __syncthreads(); + float u = hiprand_uniform(&state) * q; + aggregate = 0; + #pragma unroll 2 + for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { + probs_vec.fill(0); + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + } + + DeviceSamplingFromProb( + i, d, [&](float x) { return x > low; }, u, probs_vec, aggregate, &temp_storage); + if (aggregate > u) { + break; + } + } + __syncthreads(); + sampled_id = temp_storage.sampled_id; + if (sampled_id == d) { + // NOTE(Zihao): this would happen when u is very close to 1 + // and the sum of probabilities is smaller than u + // In this case, we use the last valid index as the sampled id + sampled_id = temp_storage.last_valid_id; + } + double pivot_0 = probs[row_idx * d + sampled_id]; + double pivot_1 = (pivot_0 + high) / 2; + + float aggregate_gt_pivot_0 = 0, aggregate_gt_pivot_1 = 0; + #pragma unroll 2 + for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { + probs_vec.fill(0); + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + } + + float probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE]; + #pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + probs_gt_pivot_0[j] = (probs_vec[j] > pivot_0) ? probs_vec[j] : 0; + probs_gt_pivot_1[j] = (probs_vec[j] > pivot_1) ? 
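          // Same rejection loop as the top-k kernel above, except that acceptance
          // compares the probability mass above pivot_0 against top_p instead of
          // counting entries.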
probs_vec[j] : 0; + } + + aggregate_gt_pivot_0 += BlockReduce(temp_storage.block_prim.reduce) + .Sum(probs_gt_pivot_0); + if (tx == 0) { + temp_storage.block_aggregate.value = aggregate_gt_pivot_0; + } + __syncthreads(); + aggregate_gt_pivot_0 = temp_storage.block_aggregate.value; + + aggregate_gt_pivot_1 += BlockReduce(temp_storage.block_prim.reduce) + .Sum(probs_gt_pivot_1); + if (tx == 0) { + temp_storage.block_aggregate.value = aggregate_gt_pivot_1; + } + __syncthreads(); + aggregate_gt_pivot_1 = temp_storage.block_aggregate.value; + } + if (aggregate_gt_pivot_0 < top_p) { + // case 1: pivot_0 accepted + break; + } + if (aggregate_gt_pivot_1 < top_p) { + // case 2: pivot_0 rejected, pivot_1 accepted + low = pivot_0; + high = pivot_1; + q = aggregate_gt_pivot_0; + } else { + // case 3: pivot_0 rejected, pivot_1 rejected + low = pivot_1; + q = aggregate_gt_pivot_1; + } + } while (low < high); + __syncthreads(); + if (tx == 0) { + output[bx] = sampled_id; + } + } + + template + __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* top_k_arr, float* top_p_arr, + IdType* output, IdType* indices, IdType top_k_val, + float top_p_val, uint32_t d, uint64_t philox_seed, + uint64_t philox_offset) { + const uint32_t batch_size = gridDim.x; + const uint32_t bx = blockIdx.x, tx = threadIdx.x; + hiprandStatePhilox4_32_10_t state; + hiprand_init(philox_seed, bx, philox_offset, &state); + const uint32_t row_idx = indices == nullptr ? bx : indices[bx]; + const uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[row_idx]; + const float p = top_p_arr == nullptr ? top_p_val : top_p_arr[row_idx]; + + extern __shared__ __align__( + alignof(SamplingTempStorage)) + uint8_t smem_sampling[]; + auto& temp_storage = + reinterpret_cast&>( + smem_sampling); + + vec_t probs_vec; + float aggregate; + float q = 1; + double low = 0, high = 1.f; + int sampled_id; + do { + temp_storage.sampled_id = d; + __syncthreads(); + float u = hiprand_uniform(&state) * q; + aggregate = 0; + #pragma unroll 2 + for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { + probs_vec.fill(0); + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + } + + DeviceSamplingFromProb( + i, d, [&](float x) { return x > low; }, u, probs_vec, aggregate, &temp_storage); + if (aggregate > u) { + break; + } + } + __syncthreads(); + sampled_id = temp_storage.sampled_id; + if (sampled_id == d) { + // NOTE(Zihao): this would happen when u is very close to 1 + // and the sum of probabilities is smaller than u + // In this case, we use the last valid index as the sampled id + sampled_id = temp_storage.last_valid_id; + } + double pivot_0 = probs[row_idx * d + sampled_id]; + double pivot_1 = (pivot_0 + high) / 2; + + ValueCount aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0}; + #pragma unroll 2 + for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { + probs_vec.fill(0); + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + } + + ValueCount probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE]; + #pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + probs_gt_pivot_0[j] = { + (probs_vec[j] > pivot_0) ? probs_vec[j] : 0, + (probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; + probs_gt_pivot_1[j] = { + (probs_vec[j] > pivot_1) ? 
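          // Joint top-k/top-p test: pivot_0 is accepted only when both the count
          // and the mass above it pass the k and p thresholds; otherwise the
          // interval is narrowed exactly as in the single-criterion kernels.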
probs_vec[j] : 0, + (probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; + } + + aggregate_gt_pivot_0 += + BlockReduce, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count) + .Sum(probs_gt_pivot_0); + if (tx == 0) { + temp_storage.block_aggregate.pair = aggregate_gt_pivot_0; + } + __syncthreads(); + aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair; + + aggregate_gt_pivot_1 += + BlockReduce, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count) + .Sum(probs_gt_pivot_1); + if (tx == 0) { + temp_storage.block_aggregate.pair = aggregate_gt_pivot_1; + } + __syncthreads(); + aggregate_gt_pivot_1 = temp_storage.block_aggregate.pair; + } + if (aggregate_gt_pivot_0.count < k && aggregate_gt_pivot_0.value < p) { + // case 1: pivot_0 accepted + break; + } + if (aggregate_gt_pivot_1.count < k && aggregate_gt_pivot_1.value < p) { + // case 2: pivot_0 rejected, pivot_1 accepted + low = pivot_0; + high = pivot_1; + q = aggregate_gt_pivot_0.value; + } else { + // case 3: pivot_0 rejected, pivot_1 rejected + low = pivot_1; + q = aggregate_gt_pivot_1.value; + } + } while (low < high); + __syncthreads(); + if (tx == 0) { + output[bx] = sampled_id; + } + } + + template + hipError_t SamplingFromLogits(T* logits, IdType* output, IdType* indices, uint32_t batch_size, + uint32_t d, bool deterministic, uint64_t philox_seed, + uint64_t philox_offset, hipStream_t stream = 0) { + constexpr uint32_t BLOCK_THREADS = 1024; + const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&logits, &output, &indices, &d, &philox_seed, &philox_offset}; + const uint32_t smem_size = sizeof( + typename BlockReduce, BLOCK_THREADS, REDUCE_ALGO>::TempStorage); + + DISPATCH_ALIGNED_VEC_SIZE( + vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { + auto kernel = SamplingFromLogitsKernel; + hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); + })}); + return hipSuccess; + } + + template + hipError_t SamplingFromProb(T* probs, IdType* output, IdType* indices, uint32_t batch_size, + uint32_t d, bool deterministic, uint64_t philox_seed, + uint64_t philox_offset, hipStream_t stream = 0) { + constexpr uint32_t BLOCK_THREADS = 1024; + const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&probs, &output, &indices, &d, &philox_seed, &philox_offset, &d}; + const uint32_t smem_size = sizeof(SamplingTempStorage); + + DISPATCH_ALIGNED_VEC_SIZE( + vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { + auto kernel = SamplingFromProbKernel; + + hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); + })}); + return hipSuccess; + } + + template + hipError_t TopKSamplingFromProb(T* probs, IdType* output, IdType* indices, T* top_k_arr, + uint32_t batch_size, uint32_t top_k_val, uint32_t d, + bool deterministic, uint64_t philox_seed, uint64_t philox_offset, + hipStream_t stream = 0) { + const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + + auto compute_capacity = GetCudaComputeCapability(); + DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, { + const uint32_t smem_size = sizeof(SamplingTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&probs, &output, &indices, &top_k_arr, + &top_k_val, &d, &philox_seed, &philox_offset}; + + DISPATCH_ALIGNED_VEC_SIZE( + vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { + auto 
kernel = TopKSamplingFromProbKernel; + + hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); + })}); + return hipSuccess; + }); + } + + template + hipError_t TopPSamplingFromProb(T* probs, IdType* output, IdType* indices, T* top_p_arr, + uint32_t batch_size, T top_p_val, uint32_t d, bool deterministic, + uint64_t philox_seed, uint64_t philox_offset, + hipStream_t stream = 0) { + constexpr uint32_t BLOCK_THREADS = 1024; + const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + + const uint32_t smem_size = sizeof(SamplingTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&probs, &output, &indices, &top_p_arr, + &top_p_val, &d, &philox_seed, &philox_offset}; + + DISPATCH_ALIGNED_VEC_SIZE( + vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { + auto kernel = TopPSamplingFromProbKernel; + + hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); + })}); + return hipSuccess; + } + + template + hipError_t TopKTopPSamplingFromProb(T* probs, IdType* top_k_arr, T* top_p_arr, IdType* output, + IdType* indices, uint32_t batch_size, IdType top_k_val, + T top_p_val, uint32_t d, bool deterministic, + uint64_t philox_seed, uint64_t philox_offset, + hipStream_t stream = 0) { + const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + + auto compute_capacity = GetCudaComputeCapability(); + DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, { + const uint32_t smem_size = sizeof(SamplingTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&probs, &top_k_arr, &top_p_arr, &output, &indices, + &top_k_val, &top_p_val, &d, &philox_seed, &philox_offset}; + + DISPATCH_ALIGNED_VEC_SIZE( + vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { + auto kernel = TopKTopPSamplingFromProbKernel; + + hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); + })}); + return hipSuccess; + }); + } + + template + struct RenormTempStorage { + union { + typename BlockReduce::TempStorage reduce; + typename BlockReduce::TempStorage reduce_int; + typename BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage + reduce_value_count; + } block_prim; + struct { + float max_val; + float min_val; + union { + struct { + float values[2]; + }; + struct { + int counts[2]; + }; + struct { + ValueCount pairs[2]; + }; + } block_aggregate; + }; + }; + + template + __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float* top_p_arr, + float top_p_val, uint32_t d) { + const uint32_t bx = blockIdx.x, tx = threadIdx.x; + const uint32_t row_idx = bx; + float p = top_p_arr == nullptr ? 
top_p_val : top_p_arr[bx]; + + extern __shared__ __align__(alignof(RenormTempStorage)) + uint8_t smem_renorm[]; + auto& temp_storage = + reinterpret_cast&>(smem_renorm); + temp_storage.max_val = 0; + vec_t probs_vec; + + float max_val = GetMaxValue>(probs, row_idx, d, + temp_storage); + + double low = 0, high = max_val; + float min_gt_low, max_le_high; + float sum_low = 1; + // f(x) = sum(probs[probs > x]), f(x) is non-increasing + // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high} + // loop invariant: + // - f(low) >= p, f(high) < p + // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high) + // stopping condition + // - f(low) >= p, f(min_gt_low) == f(max_le_high) == f(high) < p + do { + double pivot_0 = (high + 2 * low) / 3; + double pivot_1 = (2 * high + low) / 3; + + float aggregate_gt_pivot_0 = 0, aggregate_gt_pivot_1 = 0; + min_gt_low = high; + max_le_high = low; + #pragma unroll 2 + for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { + probs_vec.fill(0); + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); + } + + float probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE]; + #pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + probs_gt_pivot_0[j] = (probs_vec[j] > pivot_0) ? probs_vec[j] : 0; + probs_gt_pivot_1[j] = (probs_vec[j] > pivot_1) ? probs_vec[j] : 0; + + if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { + min_gt_low = min(min_gt_low, probs_vec[j]); + } + if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { + max_le_high = max(max_le_high, probs_vec[j]); + } + } + + aggregate_gt_pivot_0 += + BlockReduce(temp_storage.block_prim.reduce) + .Sum(probs_gt_pivot_0); + __syncthreads(); + + aggregate_gt_pivot_1 += + BlockReduce(temp_storage.block_prim.reduce) + .Sum(probs_gt_pivot_1); + __syncthreads(); + } + min_gt_low = BlockReduce(temp_storage.block_prim.reduce) + .Reduce(min_gt_low, hipcub::Min()); + __syncthreads(); + max_le_high = + BlockReduce(temp_storage.block_prim.reduce) + .Reduce(max_le_high, hipcub::Max()); + if (tx == 0) { + temp_storage.block_aggregate.values[0] = aggregate_gt_pivot_0; + temp_storage.block_aggregate.values[1] = aggregate_gt_pivot_1; + temp_storage.min_val = min_gt_low; + temp_storage.max_val = max_le_high; + } + __syncthreads(); + aggregate_gt_pivot_0 = temp_storage.block_aggregate.values[0]; + aggregate_gt_pivot_1 = temp_storage.block_aggregate.values[1]; + min_gt_low = temp_storage.min_val; + max_le_high = temp_storage.max_val; + + if (aggregate_gt_pivot_1 >= p) { + low = pivot_1; + sum_low = aggregate_gt_pivot_1; + } else if (aggregate_gt_pivot_0 >= p) { + low = pivot_0; + high = min(pivot_1, max_le_high); + sum_low = aggregate_gt_pivot_0; + } else { + high = min(pivot_0, max_le_high); + } + } while (min_gt_low != max_le_high); + + float normalizer = math::ptx_rcp(max(sum_low, 1e-8)); + + // normalize + #pragma unroll 2 + for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { + probs_vec.fill(0); + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); + } + #pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + probs_vec[j] = (probs_vec[j] > low) ? 
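        // Keep only entries above the final threshold `low` and rescale them by
        // normalizer = 1 / sum_low so the surviving top-p mass sums to one; all
        // other entries are written back as zero.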
probs_vec[j] * normalizer : 0; + } + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + probs_vec.cast_store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + + tx * VEC_SIZE); + } + } + } + + template + __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* top_k_arr, + uint32_t top_k_val, uint32_t d) { + const uint32_t bx = blockIdx.x, tx = threadIdx.x; + const uint32_t row_idx = bx; + uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx]; + double pivot = -infinity(), normalizer = 1; + vec_t probs_vec; + if (k < d) { + extern __shared__ __align__(alignof(RenormTempStorage)) + uint8_t smem_renorm[]; + auto& temp_storage = + reinterpret_cast&>(smem_renorm); + temp_storage.max_val = 0; + + float max_val = GetMaxValue>( + probs, row_idx, d, temp_storage); + + double low = 0, high = max_val; + float min_gt_low, max_le_high; + float sum_low = 1; + // f(x) = len(nonzero(probs > x)), f(x) is non-increasing + // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high} + // loop invariant: + // - f(low) >= k, f(high) < k + // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high) + // stopping condition: min_gt_low == max_le_high + // - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k + do { + double pivot_0 = (high + 2 * low) / 3; + double pivot_1 = (2 * high + low) / 3; + + ValueCount aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0}; + min_gt_low = high; + max_le_high = low; + #pragma unroll 2 + for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { + probs_vec.fill(0); + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); + } + ValueCount probs_gt_pivot_0_pair[VEC_SIZE], probs_gt_pivot_1_pair[VEC_SIZE]; + #pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + probs_gt_pivot_0_pair[j] = { + (probs_vec[j] > pivot_0) ? probs_vec[j] : 0, + (probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; + probs_gt_pivot_1_pair[j] = { + (probs_vec[j] > pivot_1) ? 
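            // The (mass, count) sums above the two pivots drive the ternary search
            // toward the threshold separating the k largest entries; sum_low keeps
            // the mass above `low` for the final renormalization below.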
probs_vec[j] : 0, + (probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; + + if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { + min_gt_low = min(min_gt_low, probs_vec[j]); + } + if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { + max_le_high = max(max_le_high, probs_vec[j]); + } + } + + aggregate_gt_pivot_0 += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( + temp_storage.block_prim.reduce_value_count) + .Sum(probs_gt_pivot_0_pair); + __syncthreads(); + + aggregate_gt_pivot_1 += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( + temp_storage.block_prim.reduce_value_count) + .Sum(probs_gt_pivot_1_pair); + __syncthreads(); + } + min_gt_low = + BlockReduce(temp_storage.block_prim.reduce) + .Reduce(min_gt_low, hipcub::Min()); + __syncthreads(); + max_le_high = + BlockReduce(temp_storage.block_prim.reduce) + .Reduce(max_le_high, hipcub::Max()); + if (tx == 0) { + temp_storage.block_aggregate.pairs[0] = aggregate_gt_pivot_0; + temp_storage.block_aggregate.pairs[1] = aggregate_gt_pivot_1; + temp_storage.min_val = min_gt_low; + temp_storage.max_val = max_le_high; + } + __syncthreads(); + aggregate_gt_pivot_0 = temp_storage.block_aggregate.pairs[0]; + aggregate_gt_pivot_1 = temp_storage.block_aggregate.pairs[1]; + min_gt_low = temp_storage.min_val; + max_le_high = temp_storage.max_val; + + if (aggregate_gt_pivot_1.count >= k) { + low = pivot_1; + sum_low = float(aggregate_gt_pivot_1.value); + } else if (aggregate_gt_pivot_0.count >= k) { + low = pivot_0; + high = min(pivot_1, max_le_high); + sum_low = float(aggregate_gt_pivot_0.value); + } else { + high = min(pivot_0, max_le_high); + } + } while (min_gt_low != max_le_high); + + normalizer = math::ptx_rcp(max(sum_low, 1e-8)); + pivot = low; + } + + // normalize + #pragma unroll 2 + for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { + probs_vec.fill(0); + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); + } + #pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + probs_vec[j] = (probs_vec[j] > pivot) ? 
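        // When k >= d, `pivot` stays at -inf and `normalizer` at 1, so the row is
        // copied through unchanged; otherwise survivors above `pivot` are rescaled
        // by 1 / sum_low and the rest are zeroed.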
probs_vec[j] * normalizer : 0; + } + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); + } + } + } + + template + hipError_t TopPRenormProb(DType* probs, DType* renormed_prob, float* top_p_arr, + uint32_t batch_size, float top_p_val, uint32_t d, + hipStream_t stream = 0) { + constexpr uint32_t BLOCK_THREADS = 1024; + const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); + + const uint32_t smem_size = sizeof(RenormTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&probs, &renormed_prob, &top_p_arr, &top_p_val, &d}; + DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { + auto kernel = TopPRenormProbKernel; + hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); + hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); + }); + return hipSuccess; + } + + template + hipError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr, + uint32_t batch_size, uint32_t top_k_val, uint32_t d, + hipStream_t stream = 0) { + const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); + + auto compute_capacity = GetCudaComputeCapability(); + DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, { + const uint32_t smem_size = sizeof(RenormTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &d}; + DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { + auto kernel = TopKRenormProbKernel; + + hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); + hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); + }); + return hipSuccess; + }); + } + + } // namespace sampling + + } // namespace flashinfer + + #endif // FLASHINFER_SAMPLING_CUH_ \ No newline at end of file From 807ae758296d0245bbef22cc8ccb2376a1342dde Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Thu, 24 Jul 2025 08:31:38 +0000 Subject: [PATCH 02/19] update --- csrc/cpp_itfs/sampling/sampling.cuh | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/csrc/cpp_itfs/sampling/sampling.cuh b/csrc/cpp_itfs/sampling/sampling.cuh index df0a117824..8e5f6e01ca 100644 --- a/csrc/cpp_itfs/sampling/sampling.cuh +++ b/csrc/cpp_itfs/sampling/sampling.cuh @@ -27,9 +27,7 @@ #include #include #include - - #include "math.cuh" - #include "utils.cuh" + #include "vec_dtypes.cuh" namespace aiter { @@ -94,9 +92,6 @@ constexpr BlockScanAlgorithm SCAN_ALGO = BLOCK_SCAN_WARP_SCANS; constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS; - #if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120100) - #define FLASHINFER_CUB_SUBTRACTLEFT_DEFINED - #endif template struct ValueCount { @@ -303,13 +298,13 @@ } bool greater_than_u_diff[VEC_SIZE]; - #ifdef FLASHINFER_CUB_SUBTRACTLEFT_DEFINED + BlockAdjacentDifference(temp_storage->block_prim.adj_diff) .SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp()); - #else - BlockAdjacentDifference(temp_storage->block_prim.adj_diff) - .FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0); - #endif + + // BlockAdjacentDifference(temp_storage->block_prim.adj_diff) + // .FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0); + __syncthreads(); #pragma unroll @@ -1082,7 +1077,7 @@ } } while (min_gt_low != max_le_high); - float normalizer = math::ptx_rcp(max(sum_low, 1e-8)); + float normalizer = 
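    // __frcp_rn computes the single-precision reciprocal with round-to-nearest;
    // the max(sum_low, 1e-8) clamp keeps the normalizer finite when almost no
    // probability mass survives the threshold.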
__frcp_rn(max(sum_low, 1e-8)); // normalize #pragma unroll 2 @@ -1204,7 +1199,7 @@ } } while (min_gt_low != max_le_high); - normalizer = math::ptx_rcp(max(sum_low, 1e-8)); + normalizer = __frcp_rn(max(sum_low, 1e-8)); pivot = low; } From 74044c2163536b9ae17d3da4f35f338e3bdd21dc Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Thu, 24 Jul 2025 09:12:50 +0000 Subject: [PATCH 03/19] upload vec_dtypes.cuh --- csrc/cpp_itfs/sampling/sampling.cuh | 115 +- csrc/cpp_itfs/sampling/vec_dtypes.cuh | 1589 +++++++++++++++++++++++++ 2 files changed, 1638 insertions(+), 66 deletions(-) create mode 100644 csrc/cpp_itfs/sampling/vec_dtypes.cuh diff --git a/csrc/cpp_itfs/sampling/sampling.cuh b/csrc/cpp_itfs/sampling/sampling.cuh index 8e5f6e01ca..cf8aa1b7a3 100644 --- a/csrc/cpp_itfs/sampling/sampling.cuh +++ b/csrc/cpp_itfs/sampling/sampling.cuh @@ -79,16 +79,9 @@ constexpr bool DETERMINISTIC = false; \ __VA_ARGS__ \ } - - #define DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, ...) \ - if (compute_capacity.first >= 8) { \ - constexpr uint32_t BLOCK_THREADS = 1024; \ - __VA_ARGS__ \ - } else { \ - constexpr uint32_t BLOCK_THREADS = 512; \ - __VA_ARGS__ \ - } - + + #define BLOCK_THREADS 1024 + constexpr BlockScanAlgorithm SCAN_ALGO = BLOCK_SCAN_WARP_SCANS; constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS; @@ -877,27 +870,23 @@ uint32_t batch_size, uint32_t top_k_val, uint32_t d, bool deterministic, uint64_t philox_seed, uint64_t philox_offset, hipStream_t stream = 0) { - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - - auto compute_capacity = GetCudaComputeCapability(); - DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, { - const uint32_t smem_size = sizeof(SamplingTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &output, &indices, &top_k_arr, - &top_k_val, &d, &philox_seed, &philox_offset}; - - DISPATCH_ALIGNED_VEC_SIZE( - vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = TopKSamplingFromProbKernel; - - hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); - - hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); - })}); - return hipSuccess; - }); + const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + const uint32_t smem_size = sizeof(SamplingTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&probs, &output, &indices, &top_k_arr, + &top_k_val, &d, &philox_seed, &philox_offset}; + + DISPATCH_ALIGNED_VEC_SIZE( + vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { + auto kernel = TopKSamplingFromProbKernel; + + hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); + })}); + return hipSuccess; } template @@ -932,27 +921,24 @@ T top_p_val, uint32_t d, bool deterministic, uint64_t philox_seed, uint64_t philox_offset, hipStream_t stream = 0) { - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - auto compute_capacity = GetCudaComputeCapability(); - DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, { - const uint32_t smem_size = sizeof(SamplingTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &top_k_arr, &top_p_arr, &output, &indices, - &top_k_val, &top_p_val, &d, &philox_seed, 
&philox_offset}; - - DISPATCH_ALIGNED_VEC_SIZE( - vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = TopKTopPSamplingFromProbKernel; + const uint32_t smem_size = sizeof(SamplingTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&probs, &top_k_arr, &top_p_arr, &output, &indices, + &top_k_val, &top_p_val, &d, &philox_seed, &philox_offset}; - hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); - - hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); - })}); - return hipSuccess; - }); + DISPATCH_ALIGNED_VEC_SIZE( + vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { + auto kernel = TopKTopPSamplingFromProbKernel; + + hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); + })}); + return hipSuccess; } template @@ -1243,22 +1229,19 @@ hipError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d, hipStream_t stream = 0) { - const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); - - auto compute_capacity = GetCudaComputeCapability(); - DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, { - const uint32_t smem_size = sizeof(RenormTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &d}; - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = TopKRenormProbKernel; - - hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); - hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); - }); - return hipSuccess; - }); + const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); + + const uint32_t smem_size = sizeof(RenormTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &d}; + DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { + auto kernel = TopKRenormProbKernel; + + hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); + hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); + }); + return hipSuccess; } } // namespace sampling diff --git a/csrc/cpp_itfs/sampling/vec_dtypes.cuh b/csrc/cpp_itfs/sampling/vec_dtypes.cuh new file mode 100644 index 0000000000..e63d146e4b --- /dev/null +++ b/csrc/cpp_itfs/sampling/vec_dtypes.cuh @@ -0,0 +1,1589 @@ +/* + * Copyright (C) 2023-2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + #ifndef VEC_DTYPES_CUH_ + #define VEC_DTYPES_CUH_ + + + #include + #include + #include + #include + #include + #include + + #include + + #if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) + /* + Hacky workaround for the error below: + + /home/git_repos/glen-amd/flashinfer/include/flashinfer/attention/../vec_dtypes_hip.cuh:200:38: error: use of undeclared identifier '__float2bfloat162_rn'; did you mean '__float22bfloat162_rn'? + 200 | const __hip_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + | ^~~~~~~~~~~~~~~~~~~~ + | __float22bfloat162_rn + /opt/rocm-6.3.1/lib/llvm/bin/../../../include/hip/amd_detail/amd_hip_bf16.h:574:45: note: '__float22bfloat162_rn' declared here + 574 | __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __float22bfloat162_rn(const float2 a) { + */ + __HOST_DEVICE__ inline __hip_bfloat162 __float2bfloat162_rn(const float a) { + return __hip_bfloat162{__float2bfloat16(a), __float2bfloat16(a)}; + } + + inline __attribute__((always_inline)) __device__ __hip_bfloat162 make_bfloat162(const __hip_bfloat16 x, const __hip_bfloat16 y) { + __hip_bfloat162 t; + t.x = x; + t.y = y; + return t; + } + #endif + + namespace aiter { + + #define FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + + #define inline __attribute__((always_inline)) __device__ inline __attribute__((always_inline)) __device__ + + + /******************* vec_t type cast *******************/ + + template + struct vec_cast { + template + inline __attribute__((always_inline)) __device__ static void cast(dst_t* dst, const src_t* src) { + #pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + dst[i] = (dst_t)src[i]; + } + } + }; + + template <> + struct vec_cast { + template + inline __attribute__((always_inline)) __device__ static void cast(float* dst, const half* src) { + if constexpr (vec_size == 1) { + // dst[0] = (float)src[0]; + dst[0] = __half2float(src[0]); + } else { + #pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2*)dst)[i] = __half22float2(((half2*)src)[i]); + } + } + } + }; + + template <> + struct vec_cast { + template + inline __attribute__((always_inline)) __device__ static void cast(half* dst, const float* src) { + if constexpr (vec_size == 1) { + dst[0] = __float2half(src[0]); + } else { + #pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2*)dst)[i] = __float22half2_rn(((float2*)src)[i]); + } + } + } + }; + + template + constexpr inline __attribute__((always_inline)) __device__ int get_exponent_bits() { + if constexpr (std::is_same_v) { + return 4; + } else if constexpr (std::is_same_v) { + return 5; + } else if constexpr (std::is_same_v) { + return 5; + } else if constexpr (std::is_same_v) { + return 8; + } + } + + template + constexpr inline __attribute__((always_inline)) __device__ int get_mantissa_bits() { + if constexpr (std::is_same_v) { + return 3; + } else if constexpr (std::is_same_v) { + return 2; + } else if constexpr (std::is_same_v) { + return 11; + } else if constexpr (std::is_same_v) { + return 7; + } + } + + /*! + * \brief Fallback to software fast dequant implementation if hardware dequantization is not + * available. + * \note Inspired by Marlin's fast dequantization, but here we don't have to permute + * weights order. 
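 + * In short: for e5m2 -> half the bit layout already lines up, so a byte
 + * permute that zero-pads the low mantissa bits is enough; for the other pairs
 + * the packed sign/exponent/mantissa fields are masked and shifted into the
 + * fp16/bf16 positions, and the exponent-bias difference is applied by
 + * multiplying with 2^BIAS_OFFSET packed into a half2 / bfloat162 constant.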
+ * \ref + * https://github.com/vllm-project/vllm/blob/6dffa4b0a6120159ef2fe44d695a46817aff65bc/csrc/quantization/fp8/fp8_marlin.cu#L120 + */ + template + __device__ void fast_dequant_f8f16x4(uint32_t* input, uint2* output) { + uint32_t q = *input; + if constexpr (std::is_same_v && std::is_same_v) { + output->x = __byte_perm(0U, q, 0x5140); + output->y = __byte_perm(0U, q, 0x7362); + } else { + constexpr int FP8_EXPONENT = get_exponent_bits(); + constexpr int FP8_MANTISSA = get_mantissa_bits(); + constexpr int FP16_EXPONENT = get_exponent_bits(); + + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; + // Calculate MASK for extracting mantissa and exponent + // XXX: duplicate defs of `MASK1` and `MASK2`, + // in the HIP file "include/hip/amd_detail/amd_device_functions.h". + constexpr int MASK1_orig = 0x80000000; + constexpr int MASK2_orig = MASK1_orig >> (FP8_EXPONENT + FP8_MANTISSA); + constexpr int MASK3 = MASK2_orig & 0x7fffffff; + constexpr int MASK = MASK3 | (MASK3 >> 16); + q = __byte_perm(q, q, 0x1302); + + // Extract and shift FP8 values to FP16 format + uint32_t Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + uint32_t Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); + + constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + // Construct and apply exponent bias + if constexpr (std::is_same_v) { + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + *(half2*)&(output->x) = __hmul2(*reinterpret_cast(&Out1), bias_reg); + *(half2*)&(output->y) = __hmul2(*reinterpret_cast(&Out2), bias_reg); + } else { + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const __hip_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + // Convert to bfloat162 and apply bias + *(__hip_bfloat162*)&(output->x) = + __hmul2(*reinterpret_cast(&Out1), bias_reg); + *(__hip_bfloat162*)&(output->y) = + __hmul2(*reinterpret_cast(&Out2), bias_reg); + } + } + } + + template <> + struct vec_cast<__hip_bfloat16, __hip_fp8_e4m3_fnuz> { + template + inline __attribute__((always_inline)) __device__ static void cast(__hip_bfloat16* dst, const __hip_fp8_e4m3_fnuz* src) { + if constexpr (vec_size == 1) { + dst[0] = __hip_bfloat16(src[0]); + } else if constexpr (vec_size == 2) { + dst[0] = __hip_bfloat16(src[0]); + dst[1] = __hip_bfloat16(src[1]); + } else { + static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4"); + #pragma unroll + for (uint32_t i = 0; i < vec_size / 4; ++i) { + fast_dequant_f8f16x4<__hip_fp8_e4m3_fnuz, __hip_bfloat16>((uint32_t*)&src[i * 4], + (uint2*)&dst[i * 4]); + } + } + } + }; + + template <> + struct vec_cast<__hip_bfloat16, __hip_fp8_e5m2_fnuz> { + template + inline __attribute__((always_inline)) __device__ static void cast(__hip_bfloat16* dst, const __hip_fp8_e5m2_fnuz* src) { + if constexpr (vec_size == 1) { + dst[0] = __hip_bfloat16(src[0]); + } else if constexpr (vec_size == 2) { + dst[0] = __hip_bfloat16(src[0]); + dst[1] = __hip_bfloat16(src[1]); + } else { + static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4"); + #pragma unroll + for (uint32_t i = 0; i < vec_size / 4; ++i) { + fast_dequant_f8f16x4<__hip_fp8_e5m2_fnuz, __hip_bfloat16>((uint32_t*)&src[i * 4], + (uint2*)&dst[i * 4]); + } + } + } + }; + + #if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) + // Function to convert half-precision to e4m3 + __device__ uint8_t convert_f32_to_e4m3(float val) { + // Define the range of e4m3 
+ // 1. Minimum representable value for e4m3 + // 2. Binary 1000.000 in e4m3 + // 3. FLT_MIN is not suitable for e4m3 because e4m3 has a much smaller dynamic range. + float min_e4m3 = -8.0f; + // 1. Maximum representable value for e4m3 + // 2. Binary 0111.111 in e4m3 + // FLT_MAX far exceeds the maximum value representable in e4m3. + float max_e4m3 = 7.875f; + + // Saturate the value to the e4m3 range + val = fminf(fmaxf(val, min_e4m3), max_e4m3); + + // Perform conversion + // Decompose into mantissa and exponent + int exp; + float mantissa = frexpf(val, &exp); + + // Encode sign bit + uint8_t sign = (mantissa < 0) ? 0x80 : 0x00; + + // Normalize mantissa and encode exponent + mantissa = fabsf(mantissa) * 16.0f; // Scale mantissa for e4m3's 3-bit precision + uint8_t exponent = static_cast(exp + 7); // Bias of 7 for e4m3 + + // Quantize mantissa + // Apply round-to-nearest-even to the mantissa + uint8_t quant_mantissa = static_cast(roundf(mantissa)) & 0x07; + + // Combine into 8 bits: [sign][exponent][mantissa] + return sign | (exponent << 3) | quant_mantissa; + } + + __device__ __half2 convert_uint32_to_half2(uint32_t input) { + // Extract the low and high 16 bits + uint16_t low_val = input & 0xFFFF; + uint16_t high_val = (input >> 16) & 0xFFFF; + // Convert to __half + __half low_half = __float2half(static_cast(low_val)); + __half high_half = __float2half(static_cast(high_val)); + // Pack into __half2 + return __halves2half2(low_half, high_half); + } + + + // Convert f16x2 (__half2) to e4m3x2 (packed 16-bit) + __device__ uint16_t convert_f16x2_to_e4m3x2(__half2 x) { + float f32_0 = __half2float(__low2half(x)); + float f32_1 = __half2float(__high2half(x)); + uint8_t e4m3_0 = convert_f32_to_e4m3(f32_0); + uint8_t e4m3_1 = convert_f32_to_e4m3(f32_1); + return (static_cast(e4m3_1) << 8) | e4m3_0; + } + #endif + + template <> + struct vec_cast<__hip_fp8_e4m3_fnuz, half> { + template + inline __attribute__((always_inline)) __device__ static void cast(__hip_fp8_e4m3_fnuz* dst, const half* src) { + #ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + if constexpr (vec_size == 1) { + dst[0] = __hip_fp8_e4m3_fnuz(src[0]); + } else { + #pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + uint16_t y; + uint32_t x = *(uint32_t*)&src[i * 2]; + #if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) + __half2 x_h2 = convert_uint32_to_half2(x); + y = convert_f16x2_to_e4m3x2(x_h2); + #else + asm volatile("cvt.rn.satfinite.e4m3x2.f16x2 %0, %1;" : "=h"(y) : "r"(x)); + #endif + *(uint16_t*)&dst[i * 2] = y; + } + } + #else + #pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + dst[i] = __hip_fp8_e4m3_fnuz(src[i]); + } + #endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + } + }; + + #if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) + __device__ uint16_t convert_f16x2_to_e5m2x2(uint32_t x) { + // Unpack the two 16-bit half-precision floats from the input + // Extract lower 16 bits + __half h1 = __ushort_as_half(x & 0xFFFF); + // Extract upper 16 bits + __half h2 = __ushort_as_half((x >> 16) & 0xFFFF); + + #if 0 + // Alternative with `__uint2half_rn` + uint16_t val1 = x & 0xFFFF; // Lower 16 bits + uint16_t val2 = (x >> 16) & 0xFFFF; // Upper 16 bits + __half h1 = __uint2half_rn(val1); + __half h2 = __uint2half_rn(val2); + #endif + + // Define the range of e5m2 + // Minimum representable value for e5m2 + const float min_e5m2 = -8.0f; + // Maximum representable value for e5m2 + const float max_e5m2 = 7.75f; + + // 
Helper lambda for conversion + auto f32_to_e5m2 = [min_e5m2, max_e5m2](float val) -> uint8_t { + // Saturate the val + val= fminf(fmaxf(val, min_e5m2), max_e5m2); + + // Decompose into mantissa and exponent + int exp; + float mantissa = frexpf(val, &exp); + + // Encode sign bit + uint8_t sign = (mantissa < 0) ? 0x10 : 0x00; // Sign in bit 4 + mantissa = fabsf(mantissa); + + // Normalize mantissa and encode exponent + mantissa *= 4.0f; // Scale for 2-bit mantissa + uint8_t exponent = static_cast(exp + 7); // Apply bias for e5m2 + + // Apply round-to-nearest-even + uint8_t quant_mantissa = static_cast(roundf(mantissa)) & 0x03; + + // Combine into 5 bits: [sign][exponent][mantissa] + return sign | (exponent << 2) | quant_mantissa; + }; + + // Convert the two __half values to e5m2 + uint8_t e5m2_1 = f32_to_e5m2(__half2float(h1)); + uint8_t e5m2_2 = f32_to_e5m2(__half2float(h2)); + + // Pack the two e5m2 values into a single 16-bit output + return (e5m2_2 << 8) | e5m2_1; + } + #endif + + template <> + struct vec_cast<__hip_fp8_e5m2_fnuz, half> { + template + inline __attribute__((always_inline)) __device__ static void cast(__hip_fp8_e5m2_fnuz* dst, const half* src) { + #ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + if constexpr (vec_size == 1) { + dst[0] = __hip_fp8_e5m2_fnuz(src[0]); + } else { + #pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + uint16_t y; + uint32_t x = *(uint32_t*)&src[i * 2]; + #if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) + y = convert_f16x2_to_e5m2x2(x); + #else + asm volatile("cvt.rn.satfinite.e5m2x2.f16x2 %0, %1;" : "=h"(y) : "r"(x)); + #endif + *(uint16_t*)&dst[i * 2] = y; + } + } + #else + #pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + dst[i] = __hip_fp8_e5m2_fnuz(src[i]); + } + #endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + } + }; + + #if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) + __device__ uint32_t convert_e4m3x2_to_f16x2(uint16_t x) { + // Extract two e4m3 values from the 16-bit input + uint8_t e4m3_1 = x & 0xFF; // Lower 8 bits + uint8_t e4m3_2 = (x >> 8) & 0xFF; // Upper 8 bits + + // Decode e4m3 to float + auto e4m3_to_f32 = [](uint8_t e4m3) -> float { + // Extract sign, exponent, and mantissa + int sign = (e4m3 & 0x80) ? 
-1 : 1; + int exponent = ((e4m3 >> 3) & 0x0F) - 7; // 4-bit exponent with bias 7 + int mantissa = e4m3 & 0x07; // 3-bit mantissa + + // Handle special case: zero + if (exponent == -7 && mantissa == 0) { + return 0.0f; + } + + // Convert to float + float f32_val = sign * ldexpf(1.0f + mantissa / 8.0f, exponent); + return f32_val; + }; + + float f1 = e4m3_to_f32(e4m3_1); + float f2 = e4m3_to_f32(e4m3_2); + + // Convert float to IEEE f16 + __half h1 = __float2half_rn(f1); + __half h2 = __float2half_rn(f2); + + // Pack the two f16 values into a single uint32_t + uint32_t f16x2 = (__half_as_ushort(h2) << 16) | __half_as_ushort(h1); + return f16x2; + } + #endif + + template <> + struct vec_cast { + template + inline __attribute__((always_inline)) __device__ static void cast(half* dst, const __hip_fp8_e4m3_fnuz* src) { + #ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + if constexpr (vec_size == 1) { + dst[0] = half(src[0]); + } else { + #pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + uint32_t y; + uint16_t x = *(uint16_t*)&src[i * 2]; + #if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) + y = convert_e4m3x2_to_f16x2(x); + #else + asm volatile("cvt.rn.f16x2.e4m3x2 %0, %1;" : "=r"(y) : "h"(x)); + #endif + *(uint32_t*)&dst[i * 2] = y; + } + } + #else + if constexpr (vec_size == 1) { + dst[0] = half(src[0]); + } else if constexpr (vec_size == 2) { + dst[0] = half(src[0]); + dst[1] = half(src[1]); + } else { + static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4"); + #pragma unroll + for (uint32_t i = 0; i < vec_size / 4; ++i) { + fast_dequant_f8f16x4<__hip_fp8_e4m3_fnuz, half>((uint32_t*)&src[i * 4], (uint2*)&dst[i * 4]); + } + } + #endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + } + }; + + #if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) + __device__ uint32_t convert_e5m2x2_to_f16x2(uint16_t x) { + // Extract two e5m2 values from the 16-bit input + uint8_t e5m2_1 = x & 0xFF; // Lower 8 bits + uint8_t e5m2_2 = (x >> 8) & 0xFF; // Upper 8 bits + + // Decode e5m2 to float + auto e5m2_to_f32 = [](uint8_t e5m2) -> float { + // Extract sign, exponent, and mantissa + int sign = (e5m2 & 0x80) ? 
-1 : 1; // Sign bit + int exponent = ((e5m2 >> 2) & 0x1F) - 15; // 5-bit exponent with bias 15 + int mantissa = e5m2 & 0x03; // 2-bit mantissa + + // Handle special case: zero + if (exponent == -15 && mantissa == 0) { + return 0.0f; + } + + // Convert to float + float value = sign * ldexpf(1.0f + mantissa / 4.0f, exponent); + return value; + }; + + float f1 = e5m2_to_f32(e5m2_1); + float f2 = e5m2_to_f32(e5m2_2); + + // Convert float to IEEE f16 + __half h1 = __float2half_rn(f1); + __half h2 = __float2half_rn(f2); + + // Pack the two f16 values into a single uint32_t + uint32_t f16x2 = (__half_as_ushort(h2) << 16) | __half_as_ushort(h1); + return f16x2; + } + #endif + + template <> + struct vec_cast { + template + inline __attribute__((always_inline)) __device__ static void cast(half* dst, const __hip_fp8_e5m2_fnuz* src) { + #ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + if constexpr (vec_size == 1) { + dst[0] = half(src[0]); + } else { + #pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + uint32_t y; + uint16_t x = *(uint16_t*)&src[i * 2]; + #if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) + y = convert_e5m2x2_to_f16x2(x); + #else + asm volatile("cvt.rn.f16x2.e5m2x2 %0, %1;" : "=r"(y) : "h"(x)); + #endif + *(uint32_t*)&dst[i * 2] = y; + } + } + #else + if constexpr (vec_size == 1) { + dst[0] = half(src[0]); + } else if constexpr (vec_size == 2) { + dst[0] = half(src[0]); + dst[1] = half(src[1]); + } else { + static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4"); + #pragma unroll + for (uint32_t i = 0; i < vec_size / 4; ++i) { + fast_dequant_f8f16x4<__hip_fp8_e5m2_fnuz, half>((uint32_t*)&src[i * 4], (uint2*)&dst[i * 4]); + } + } + #endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + } + }; + + template <> + struct vec_cast { + template + inline __attribute__((always_inline)) __device__ static void cast(float* dst, const __hip_bfloat16* src) { + if constexpr (vec_size == 1) { + dst[0] = (float)src[0]; + } else { + #pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2*)dst)[i] = __bfloat1622float2(((__hip_bfloat162*)src)[i]); + } + } + } + }; + + template <> + struct vec_cast<__hip_bfloat16, float> { + template + inline __attribute__((always_inline)) __device__ static void cast(__hip_bfloat16* dst, const float* src) { + /*if constexpr (vec_size == 1) { + dst[0] = __hip_bfloat16(src[0]); + } else { + #pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((__hip_bfloat162*)dst)[i] = __float22bfloat162_rn(((float2*)src)[i]); + } + }*/ + //fast but unsafe bfloat conversion... 
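+ // The union trick below reinterprets the float's bit pattern and keeps bf[1], i.e. the
+ // upper 16 bits on a little-endian target, so the mantissa is truncated instead of being
+ // rounded to nearest-even as the commented-out __float22bfloat162_rn path would do.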
+ union f2bf { float f; __hip_bfloat16 bf[2]; } _f2bf; + #pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + _f2bf.f = src[i]; + dst[i] = _f2bf.bf[1]; + } + } + }; + + template + struct vec_t { + inline __attribute__((always_inline)) __device__ float_t& operator[](size_t i); + inline __attribute__((always_inline)) __device__ const float_t& operator[](size_t i) const; + inline __attribute__((always_inline)) __device__ void fill(float_t val); + inline __attribute__((always_inline)) __device__ void load(const float_t* ptr); + inline __attribute__((always_inline)) __device__ void store(float_t* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src); + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr); + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const; + inline __attribute__((always_inline)) __device__ static void memcpy(float_t* dst, const float_t* src); + inline __attribute__((always_inline)) __device__ float_t* ptr(); + }; + + template + inline __attribute__((always_inline)) __device__ void cast_from_impl(vec_t& dst, + const vec_t& src) { + #if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) + vec_cast::template cast( + #else + vec_cast::cast( + #endif + dst.ptr(), const_cast*>(&src)->ptr()); + } + + template + inline __attribute__((always_inline)) __device__ void cast_load_impl(vec_t& dst, + const src_float_t* src_ptr) { + if constexpr (std::is_same_v) { + dst.load(src_ptr); + } else { + vec_t tmp; + tmp.load(src_ptr); + dst.cast_from(tmp); + } + } + + template + inline __attribute__((always_inline)) __device__ void cast_store_impl(tgt_float_t* dst_ptr, + const vec_t& src) { + if constexpr (std::is_same_v) { + src.store(dst_ptr); + } else { + vec_t tmp; + tmp.cast_from(src); + tmp.store(dst_ptr); + } + } + + /******************* vec_t<__hip_fp8_e4m3_fnuz> *******************/ + + // __hip_fp8_e4m3_fnuz x 1 + template <> + struct vec_t<__hip_fp8_e4m3_fnuz, 1> { + __hip_fp8_e4m3_fnuz data; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz& operator[](size_t i) { return ((__hip_fp8_e4m3_fnuz*)(&data))[i]; } + inline __attribute__((always_inline)) __device__ const __hip_fp8_e4m3_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e4m3_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz* ptr() { return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e4m3_fnuz val); + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e4m3_fnuz* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e4m3_fnuz* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + inline __attribute__((always_inline)) __device__ static void memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src); + }; + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 1>::fill(__hip_fp8_e4m3_fnuz val) { data = val; } + + inline __attribute__((always_inline)) __device__ void 
vec_t<__hip_fp8_e4m3_fnuz, 1>::load(const __hip_fp8_e4m3_fnuz* ptr) { data = *ptr; } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 1>::store(__hip_fp8_e4m3_fnuz* ptr) const { *ptr = data; } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 1>::memcpy(__hip_fp8_e4m3_fnuz* dst, + const __hip_fp8_e4m3_fnuz* src) { + *dst = *src; + } + + // __hip_fp8_e4m3_fnuz x 2 + template <> + struct vec_t<__hip_fp8_e4m3_fnuz, 2> { + __hip_fp8x2_e4m3_fnuz data; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz& operator[](size_t i) { return ((__hip_fp8_e4m3_fnuz*)(&data))[i]; } + inline __attribute__((always_inline)) __device__ const __hip_fp8_e4m3_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e4m3_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz* ptr() { return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e4m3_fnuz val); + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e4m3_fnuz* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e4m3_fnuz* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src); + }; + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 2>::fill(__hip_fp8_e4m3_fnuz val) { + data.__x = (__hip_fp8x2_storage_t(val.__x) << 8) | __hip_fp8x2_storage_t(val.__x); + } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 2>::load(const __hip_fp8_e4m3_fnuz* ptr) { + data = *((__hip_fp8x2_e4m3_fnuz*)ptr); + } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 2>::store(__hip_fp8_e4m3_fnuz* ptr) const { + *((__hip_fp8x2_e4m3_fnuz*)ptr) = data; + } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 2>::memcpy(__hip_fp8_e4m3_fnuz* dst, + const __hip_fp8_e4m3_fnuz* src) { + *((__hip_fp8x2_e4m3_fnuz*)dst) = *((__hip_fp8x2_e4m3_fnuz*)src); + } + + // __hip_fp8_e4m3_fnuz x 4 + + template <> + struct vec_t<__hip_fp8_e4m3_fnuz, 4> { + __hip_fp8x4_e4m3_fnuz data; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz& operator[](size_t i) { return ((__hip_fp8_e4m3_fnuz*)(&data))[i]; } + inline __attribute__((always_inline)) __device__ const __hip_fp8_e4m3_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e4m3_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz* ptr() { return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e4m3_fnuz val); + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e4m3_fnuz* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e4m3_fnuz* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + inline 
__attribute__((always_inline)) __device__ void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + inline __attribute__((always_inline)) __device__ static void memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src); + }; + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 4>::fill(__hip_fp8_e4m3_fnuz val) { + data.__x = (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 4>::load(const __hip_fp8_e4m3_fnuz* ptr) { + data = *((__hip_fp8x4_e4m3_fnuz*)ptr); + } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 4>::store(__hip_fp8_e4m3_fnuz* ptr) const { + *((__hip_fp8x4_e4m3_fnuz*)ptr) = data; + } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 4>::memcpy(__hip_fp8_e4m3_fnuz* dst, + const __hip_fp8_e4m3_fnuz* src) { + *((__hip_fp8x4_e4m3_fnuz*)dst) = *((__hip_fp8x4_e4m3_fnuz*)src); + } + + // __hip_fp8_e4m3_fnuz x 8 + + template <> + struct vec_t<__hip_fp8_e4m3_fnuz, 8> { + uint2 data; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz& operator[](size_t i) { return ((__hip_fp8_e4m3_fnuz*)(&data))[i]; } + inline __attribute__((always_inline)) __device__ const __hip_fp8_e4m3_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e4m3_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz* ptr() { return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e4m3_fnuz val); + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e4m3_fnuz* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e4m3_fnuz* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + inline __attribute__((always_inline)) __device__ static void memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src); + }; + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 8>::fill(__hip_fp8_e4m3_fnuz val) { + ((__hip_fp8x4_e4m3_fnuz*)(&data.x))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e4m3_fnuz*)(&data.y))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 8>::load(const __hip_fp8_e4m3_fnuz* ptr) { + data = *((uint2*)ptr); + } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 8>::store(__hip_fp8_e4m3_fnuz* ptr) const { + *((uint2*)ptr) = data; + } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 8>::memcpy(__hip_fp8_e4m3_fnuz* dst, + 
const __hip_fp8_e4m3_fnuz* src) { + *((uint2*)dst) = *((uint2*)src); + } + + // __hip_fp8_e4m3_fnuz x 16 or more + template + struct vec_t<__hip_fp8_e4m3_fnuz, vec_size> { + uint4 data[vec_size / 16]; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz& operator[](size_t i) { return ((__hip_fp8_e4m3_fnuz*)data)[i]; } + inline __attribute__((always_inline)) __device__ const __hip_fp8_e4m3_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e4m3_fnuz*)data)[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz* ptr() { return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e4m3_fnuz val) { + #pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((__hip_fp8x4_e4m3_fnuz*)(&(data[i].x)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e4m3_fnuz*)(&(data[i].y)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e4m3_fnuz*)(&(data[i].z)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e4m3_fnuz*)(&(data[i].w)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + } + } + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e4m3_fnuz* ptr) { + #pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + data[i] = ((uint4*)ptr)[i]; + } + } + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e4m3_fnuz* ptr) const { + #pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4*)ptr)[i] = data[i]; + } + } + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + inline __attribute__((always_inline)) __device__ static void memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src) { + #pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4*)dst)[i] = ((uint4*)src)[i]; + } + } + }; + + /******************* vec_t<__hip_fp8_e5m2_fnuz> *******************/ + + // __hip_fp8_e5m2_fnuz x 1 + template <> + struct vec_t<__hip_fp8_e5m2_fnuz, 1> { + __hip_fp8_e5m2_fnuz data; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz& operator[](size_t i) { return ((__hip_fp8_e5m2_fnuz*)(&data))[i]; } + inline __attribute__((always_inline)) __device__ const __hip_fp8_e5m2_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e5m2_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz* ptr() { return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e5m2_fnuz val); + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e5m2_fnuz* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e5m2_fnuz* ptr) 
const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + inline __attribute__((always_inline)) __device__ static void memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src); + }; + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 1>::fill(__hip_fp8_e5m2_fnuz val) { data = val; } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 1>::load(const __hip_fp8_e5m2_fnuz* ptr) { data = *ptr; } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 1>::store(__hip_fp8_e5m2_fnuz* ptr) const { *ptr = data; } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 1>::memcpy(__hip_fp8_e5m2_fnuz* dst, + const __hip_fp8_e5m2_fnuz* src) { + *dst = *src; + } + + // __hip_fp8_e5m2_fnuz x 2 + template <> + struct vec_t<__hip_fp8_e5m2_fnuz, 2> { + __hip_fp8x2_e5m2_fnuz data; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz& operator[](size_t i) { return ((__hip_fp8_e5m2_fnuz*)(&data))[i]; } + inline __attribute__((always_inline)) __device__ const __hip_fp8_e5m2_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e5m2_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz* ptr() { return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e5m2_fnuz val); + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e5m2_fnuz* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e5m2_fnuz* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + inline __attribute__((always_inline)) __device__ static void memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src); + }; + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 2>::fill(__hip_fp8_e5m2_fnuz val) { + data.__x = (__hip_fp8x2_storage_t(val.__x) << 8) | __hip_fp8x2_storage_t(val.__x); + } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 2>::load(const __hip_fp8_e5m2_fnuz* ptr) { + data = *((__hip_fp8x2_e5m2_fnuz*)ptr); + } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 2>::store(__hip_fp8_e5m2_fnuz* ptr) const { + *((__hip_fp8x2_e5m2_fnuz*)ptr) = data; + } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 2>::memcpy(__hip_fp8_e5m2_fnuz* dst, + const __hip_fp8_e5m2_fnuz* src) { + *((__hip_fp8x2_e5m2_fnuz*)dst) = *((__hip_fp8x2_e5m2_fnuz*)src); + } + + // __hip_fp8_e5m2_fnuz x 4 + + template <> + struct vec_t<__hip_fp8_e5m2_fnuz, 4> { + __hip_fp8x4_e5m2_fnuz data; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz& operator[](size_t i) { return ((__hip_fp8_e5m2_fnuz*)(&data))[i]; } + inline 
__attribute__((always_inline)) __device__ const __hip_fp8_e5m2_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e5m2_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz* ptr() { return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e5m2_fnuz val); + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e5m2_fnuz* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e5m2_fnuz* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + inline __attribute__((always_inline)) __device__ static void memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src); + }; + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 4>::fill(__hip_fp8_e5m2_fnuz val) { + data.__x = (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 4>::load(const __hip_fp8_e5m2_fnuz* ptr) { + data = *((__hip_fp8x4_e5m2_fnuz*)ptr); + } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 4>::store(__hip_fp8_e5m2_fnuz* ptr) const { + *((__hip_fp8x4_e5m2_fnuz*)ptr) = data; + } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 4>::memcpy(__hip_fp8_e5m2_fnuz* dst, + const __hip_fp8_e5m2_fnuz* src) { + *((__hip_fp8x4_e5m2_fnuz*)dst) = *((__hip_fp8x4_e5m2_fnuz*)src); + } + + // __hip_fp8_e5m2_fnuz x 8 + + template <> + struct vec_t<__hip_fp8_e5m2_fnuz, 8> { + uint2 data; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz& operator[](size_t i) { return ((__hip_fp8_e5m2_fnuz*)(&data))[i]; } + inline __attribute__((always_inline)) __device__ const __hip_fp8_e5m2_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e5m2_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz* ptr() { return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e5m2_fnuz val); + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e5m2_fnuz* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e5m2_fnuz* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src); + }; + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 8>::fill(__hip_fp8_e5m2_fnuz val) { + ((__hip_fp8x4_e5m2_fnuz*)(&data.x))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + 
(__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e5m2_fnuz*)(&data.y))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 8>::load(const __hip_fp8_e5m2_fnuz* ptr) { + data = *((uint2*)ptr); + } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 8>::store(__hip_fp8_e5m2_fnuz* ptr) const { + *((uint2*)ptr) = data; + } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 8>::memcpy(__hip_fp8_e5m2_fnuz* dst, + const __hip_fp8_e5m2_fnuz* src) { + *((uint2*)dst) = *((uint2*)src); + } + + // __hip_fp8_e5m2_fnuz x 16 or more + + template + struct vec_t<__hip_fp8_e5m2_fnuz, vec_size> { + uint4 data[vec_size / 16]; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz& operator[](size_t i) { return ((__hip_fp8_e5m2_fnuz*)data)[i]; } + inline __attribute__((always_inline)) __device__ const __hip_fp8_e5m2_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e5m2_fnuz*)data)[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz* ptr() { return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e5m2_fnuz val) { + #pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((__hip_fp8x4_e5m2_fnuz*)(&(data[i].x)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e5m2_fnuz*)(&(data[i].y)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e5m2_fnuz*)(&(data[i].z)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e5m2_fnuz*)(&(data[i].w)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + } + } + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e5m2_fnuz* ptr) { + #pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + data[i] = ((uint4*)ptr)[i]; + } + } + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e5m2_fnuz* ptr) const { + #pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4*)ptr)[i] = data[i]; + } + } + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src) { + #pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4*)dst)[i] = ((uint4*)src)[i]; + } + } + }; + + /******************* vec_t *******************/ + + // half x 1 + template <> + struct vec_t { + half data; + + inline __attribute__((always_inline)) __device__ half& 
operator[](size_t i) { return ((half*)(&data))[i]; } + inline __attribute__((always_inline)) __device__ const half& operator[](size_t i) const { return ((const half*)(&data))[i]; } + inline __attribute__((always_inline)) __device__ half* ptr() { return reinterpret_cast(&data); } + inline __attribute__((always_inline)) __device__ void fill(half val); + inline __attribute__((always_inline)) __device__ void load(const half* ptr); + inline __attribute__((always_inline)) __device__ void store(half* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + inline __attribute__((always_inline)) __device__ static void memcpy(half* dst, const half* src); + }; + + inline __attribute__((always_inline)) __device__ void vec_t::fill(half val) { data = val; } + + inline __attribute__((always_inline)) __device__ void vec_t::load(const half* ptr) { data = *ptr; } + + inline __attribute__((always_inline)) __device__ void vec_t::store(half* ptr) const { *ptr = data; } + + inline __attribute__((always_inline)) __device__ void vec_t::memcpy(half* dst, const half* src) { *dst = *src; } + + // half x 2 + template <> + struct vec_t { + half2 data; + + inline __attribute__((always_inline)) __device__ half& operator[](size_t i) { return ((half*)(&data))[i]; } + inline __attribute__((always_inline)) __device__ const half& operator[](size_t i) const { return ((const half*)(&data))[i]; } + inline __attribute__((always_inline)) __device__ half* ptr() { return reinterpret_cast(&data); } + inline __attribute__((always_inline)) __device__ void fill(half val); + inline __attribute__((always_inline)) __device__ void load(const half* ptr); + inline __attribute__((always_inline)) __device__ void store(half* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + inline __attribute__((always_inline)) __device__ static void memcpy(half* dst, const half* src); + }; + + inline __attribute__((always_inline)) __device__ void vec_t::fill(half val) { data = make_half2(val, val); } + + inline __attribute__((always_inline)) __device__ void vec_t::load(const half* ptr) { data = *((half2*)ptr); } + + inline __attribute__((always_inline)) __device__ void vec_t::store(half* ptr) const { *((half2*)ptr) = data; } + + inline __attribute__((always_inline)) __device__ void vec_t::memcpy(half* dst, const half* src) { + *((half2*)dst) = *((half2*)src); + } + + // half x 4 + + template <> + struct vec_t { + uint2 data; + + inline __attribute__((always_inline)) __device__ half& operator[](size_t i) { return ((half*)(&data))[i]; } + inline __attribute__((always_inline)) __device__ const half& operator[](size_t i) const { return ((const half*)(&data))[i]; } + inline __attribute__((always_inline)) __device__ half* ptr() { return reinterpret_cast(&data); } + inline __attribute__((always_inline)) __device__ void fill(half val); + inline __attribute__((always_inline)) 
__device__ void load(const half* ptr); + inline __attribute__((always_inline)) __device__ void store(half* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(half* dst, const half* src); + }; + + inline __attribute__((always_inline)) __device__ void vec_t::fill(half val) { + *(half2*)(&data.x) = make_half2(val, val); + *(half2*)(&data.y) = make_half2(val, val); + } + + inline __attribute__((always_inline)) __device__ void vec_t::load(const half* ptr) { data = *((uint2*)ptr); } + + inline __attribute__((always_inline)) __device__ void vec_t::store(half* ptr) const { *((uint2*)ptr) = data; } + + inline __attribute__((always_inline)) __device__ void vec_t::memcpy(half* dst, const half* src) { + *((uint2*)dst) = *((uint2*)src); + } + + // half x 8 or more + + template + struct vec_t { + uint4 data[vec_size / 8]; + inline __attribute__((always_inline)) __device__ half& operator[](size_t i) { return ((half*)data)[i]; } + inline __attribute__((always_inline)) __device__ const half& operator[](size_t i) const { return ((const half*)data)[i]; } + inline __attribute__((always_inline)) __device__ half* ptr() { return reinterpret_cast(&data); } + inline __attribute__((always_inline)) __device__ void fill(half val) { + #pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + *(half2*)(&(data[i].x)) = make_half2(val, val); + *(half2*)(&(data[i].y)) = make_half2(val, val); + *(half2*)(&(data[i].z)) = make_half2(val, val); + *(half2*)(&(data[i].w)) = make_half2(val, val); + } + } + inline __attribute__((always_inline)) __device__ void load(const half* ptr) { + #pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + data[i] = ((uint4*)ptr)[i]; + } + } + inline __attribute__((always_inline)) __device__ void store(half* ptr) const { + #pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4*)ptr)[i] = data[i]; + } + } + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(half* dst, const half* src) { + #pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4*)dst)[i] = ((uint4*)src)[i]; + } + } + }; + + /******************* vec_t<__hip_bfloat16> *******************/ + + // __hip_bfloat16 x 1 + template <> + struct vec_t<__hip_bfloat16, 1> { + __hip_bfloat16 data; + inline __attribute__((always_inline)) __device__ __hip_bfloat16& operator[](size_t i) { return ((__hip_bfloat16*)(&data))[i]; } + inline __attribute__((always_inline)) __device__ const __hip_bfloat16& operator[](size_t i) const { + return ((const __hip_bfloat16*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_bfloat16* ptr() { return reinterpret_cast<__hip_bfloat16*>(&data); } + inline __attribute__((always_inline)) __device__ void fill(__hip_bfloat16 val); + inline 
__attribute__((always_inline)) __device__ void load(const __hip_bfloat16* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_bfloat16* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src); + }; + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_bfloat16, 1>::fill(__hip_bfloat16 val) { data = val; } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_bfloat16, 1>::load(const __hip_bfloat16* ptr) { data = *ptr; } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_bfloat16, 1>::store(__hip_bfloat16* ptr) const { *ptr = data; } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_bfloat16, 1>::memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src) { + *dst = *src; + } + + // __hip_bfloat16 x 2 + template <> + struct vec_t<__hip_bfloat16, 2> { + __hip_bfloat162 data; + + inline __attribute__((always_inline)) __device__ __hip_bfloat16& operator[](size_t i) { return ((__hip_bfloat16*)(&data))[i]; } + inline __attribute__((always_inline)) __device__ const __hip_bfloat16& operator[](size_t i) const { + return ((const __hip_bfloat16*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_bfloat16* ptr() { return reinterpret_cast<__hip_bfloat16*>(&data); } + inline __attribute__((always_inline)) __device__ void fill(__hip_bfloat16 val); + inline __attribute__((always_inline)) __device__ void load(const __hip_bfloat16* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_bfloat16* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src); + }; + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_bfloat16, 2>::fill(__hip_bfloat16 val) { + data = make_bfloat162(val, val); + } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_bfloat16, 2>::load(const __hip_bfloat16* ptr) { + data = *((__hip_bfloat162*)ptr); + } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_bfloat16, 2>::store(__hip_bfloat16* ptr) const { + *((__hip_bfloat162*)ptr) = data; + } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_bfloat16, 2>::memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src) { + *((__hip_bfloat162*)dst) = *((__hip_bfloat162*)src); + } + + // __hip_bfloat16 x 4 + + template <> + struct vec_t<__hip_bfloat16, 4> { + uint2 data; + + inline __attribute__((always_inline)) __device__ __hip_bfloat16& operator[](size_t i) { return ((__hip_bfloat16*)(&data))[i]; } + inline __attribute__((always_inline)) __device__ const __hip_bfloat16& operator[](size_t i) const { + return ((const __hip_bfloat16*)(&data))[i]; + 
} + inline __attribute__((always_inline)) __device__ __hip_bfloat16* ptr() { return reinterpret_cast<__hip_bfloat16*>(&data); } + inline __attribute__((always_inline)) __device__ void fill(__hip_bfloat16 val); + inline __attribute__((always_inline)) __device__ void load(const __hip_bfloat16* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_bfloat16* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src); + }; + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_bfloat16, 4>::fill(__hip_bfloat16 val) { + *(__hip_bfloat162*)(&data.x) = make_bfloat162(val, val); + *(__hip_bfloat162*)(&data.y) = make_bfloat162(val, val); + } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_bfloat16, 4>::load(const __hip_bfloat16* ptr) { + data = *((uint2*)ptr); + } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_bfloat16, 4>::store(__hip_bfloat16* ptr) const { + *((uint2*)ptr) = data; + } + + inline __attribute__((always_inline)) __device__ void vec_t<__hip_bfloat16, 4>::memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src) { + *((uint2*)dst) = *((uint2*)src); + } + + // __hip_bfloat16 x 8 or more + + template + struct vec_t<__hip_bfloat16, vec_size> { + uint4 data[vec_size / 8]; + + inline __attribute__((always_inline)) __device__ __hip_bfloat16& operator[](size_t i) { return ((__hip_bfloat16*)data)[i]; } + inline __attribute__((always_inline)) __device__ const __hip_bfloat16& operator[](size_t i) const { + return ((const __hip_bfloat16*)data)[i]; + } + inline __attribute__((always_inline)) __device__ __hip_bfloat16* ptr() { return reinterpret_cast<__hip_bfloat16*>(&data); } + inline __attribute__((always_inline)) __device__ void fill(__hip_bfloat16 val) { + #pragma unoll + for (size_t i = 0; i < vec_size / 8; ++i) { + *(__hip_bfloat162*)(&(data[i].x)) = make_bfloat162(val, val); + *(__hip_bfloat162*)(&(data[i].y)) = make_bfloat162(val, val); + *(__hip_bfloat162*)(&(data[i].z)) = make_bfloat162(val, val); + *(__hip_bfloat162*)(&(data[i].w)) = make_bfloat162(val, val); + } + } + inline __attribute__((always_inline)) __device__ void load(const __hip_bfloat16* ptr) { + #pragma unoll + for (size_t i = 0; i < vec_size / 8; ++i) { + data[i] = ((uint4*)ptr)[i]; + } + } + inline __attribute__((always_inline)) __device__ void store(__hip_bfloat16* ptr) const { + #pragma unoll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4*)ptr)[i] = data[i]; + } + } + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src) { + #pragma unoll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4*)dst)[i] = ((uint4*)src)[i]; + } + } + }; + + 
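+ // Usage sketch (illustrative only; the kernel below is an assumption, not part of this
+ // header): the vec_t specializations are meant to be driven through cast_load/cast_store,
+ // e.g. promoting bf16 to f32 with one vectorized load per thread. Kept under `#if 0` in
+ // the same spirit as the alternative snippets above.
+ #if 0
+ template <uint32_t VEC_SIZE>
+ __global__ void upcast_bf16_to_f32(const __hip_bfloat16* in, float* out, uint32_t d) {
+   const uint32_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * VEC_SIZE;
+   if (idx + VEC_SIZE <= d) {
+     vec_t<float, VEC_SIZE> v;
+     // cast_load reads a vec_t<__hip_bfloat16, VEC_SIZE> and converts it through the
+     // vec_cast<float, __hip_bfloat16> specialization defined earlier in this file.
+     v.cast_load(in + idx);
+     v.store(out + idx);
+   }
+ }
+ #endif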
/******************* vec_t *******************/ + + // float x 1 + + template <> + struct vec_t { + float data; + + inline __attribute__((always_inline)) __device__ float& operator[](size_t i) { return ((float*)(&data))[i]; } + inline __attribute__((always_inline)) __device__ const float& operator[](size_t i) const { return ((const float*)(&data))[i]; } + inline __attribute__((always_inline)) __device__ float* ptr() { return reinterpret_cast(&data); } + inline __attribute__((always_inline)) __device__ void fill(float val); + inline __attribute__((always_inline)) __device__ void load(const float* ptr); + inline __attribute__((always_inline)) __device__ void store(float* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(float* dst, const float* src); + }; + + inline __attribute__((always_inline)) __device__ void vec_t::fill(float val) { data = val; } + + inline __attribute__((always_inline)) __device__ void vec_t::load(const float* ptr) { data = *ptr; } + + inline __attribute__((always_inline)) __device__ void vec_t::store(float* ptr) const { *ptr = data; } + + inline __attribute__((always_inline)) __device__ void vec_t::memcpy(float* dst, const float* src) { *dst = *src; } + + // float x 2 + + template <> + struct vec_t { + float2 data; + + inline __attribute__((always_inline)) __device__ float& operator[](size_t i) { return ((float*)(&data))[i]; } + inline __attribute__((always_inline)) __device__ const float& operator[](size_t i) const { return ((const float*)(&data))[i]; } + inline __attribute__((always_inline)) __device__ float* ptr() { return reinterpret_cast(&data); } + inline __attribute__((always_inline)) __device__ void fill(float val); + inline __attribute__((always_inline)) __device__ void load(const float* ptr); + inline __attribute__((always_inline)) __device__ void store(float* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(float* dst, const float* src); + }; + + inline __attribute__((always_inline)) __device__ void vec_t::fill(float val) { data = make_float2(val, val); } + + inline __attribute__((always_inline)) __device__ void vec_t::load(const float* ptr) { data = *((float2*)ptr); } + + inline __attribute__((always_inline)) __device__ void vec_t::store(float* ptr) const { *((float2*)ptr) = data; } + + inline __attribute__((always_inline)) __device__ void vec_t::memcpy(float* dst, const float* src) { + *((float2*)dst) = *((float2*)src); + } + + // float x 4 or more + template + struct vec_t { + float4 data[vec_size / 4]; + + inline __attribute__((always_inline)) __device__ float& operator[](size_t i) { return ((float*)(data))[i]; } + inline __attribute__((always_inline)) __device__ const float& operator[](size_t i) const { return ((const 
float*)(data))[i]; } + inline __attribute__((always_inline)) __device__ float* ptr() { return reinterpret_cast(&data); } + inline __attribute__((always_inline)) __device__ void fill(float val) { + #pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + data[i] = make_float4(val, val, val, val); + } + } + inline __attribute__((always_inline)) __device__ void load(const float* ptr) { + #pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + data[i] = ((float4*)ptr)[i]; + } + } + inline __attribute__((always_inline)) __device__ void store(float* ptr) const { + #pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4*)ptr)[i] = data[i]; + } + } + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(float* dst, const float* src) { + #pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4*)dst)[i] = ((float4*)src)[i]; + } + } + }; + + } // namespace flashinfer + + #endif // VEC_DTYPES_CUH_ \ No newline at end of file From 5b414acd0d71d5d9c1766a207657fc23903bc8fb Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Fri, 25 Jul 2025 03:15:01 +0000 Subject: [PATCH 04/19] add sampling kernels --- csrc/cpp_itfs/file_baton.py | 57 +++++++ csrc/cpp_itfs/sampling/sampling.cuh | 140 +++++++++--------- csrc/cpp_itfs/sampling/sampling_test.py | 78 ++++++++++ .../top_k_top_p_sampling_from_probs.cpp.jinja | 32 ++++ .../top_k_top_p_sampling_from_probs.py | 101 +++++++++++++ csrc/cpp_itfs/torch_utils.py | 29 +++- csrc/cpp_itfs/utils.py | 2 +- 7 files changed, 366 insertions(+), 73 deletions(-) create mode 100644 csrc/cpp_itfs/file_baton.py create mode 100644 csrc/cpp_itfs/sampling/sampling_test.py create mode 100644 csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja create mode 100644 csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py diff --git a/csrc/cpp_itfs/file_baton.py b/csrc/cpp_itfs/file_baton.py new file mode 100644 index 0000000000..40ed604c97 --- /dev/null +++ b/csrc/cpp_itfs/file_baton.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +# mypy: allow-untyped-defs +import os +import time +import logging + +logger = logging.getLogger("aiter") + + +class FileBaton: + """A primitive, file-based synchronization utility.""" + + def __init__(self, lock_file_path, wait_seconds=0.2): + """ + Create a new :class:`FileBaton`. + + Args: + lock_file_path: The path to the file used for locking. + wait_seconds: The seconds to periodically sleep (spin) when + calling ``wait()``. + """ + self.lock_file_path = lock_file_path + self.wait_seconds = wait_seconds + self.fd = None + + def try_acquire(self): + """ + Try to atomically create a file under exclusive access. + + Returns: + True if the file could be created, else False. + """ + try: + self.fd = os.open(self.lock_file_path, os.O_CREAT | os.O_EXCL) + return True + except FileExistsError: + return False + + def wait(self): + """ + Periodically sleeps for a certain amount until the baton is released. + + The amount of time slept depends on the ``wait_seconds`` parameter + passed to the constructor. 
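+
+        A typical acquire/wait/release pattern (illustrative example only)::
+
+            baton = FileBaton("/tmp/aiter_build.lock")
+            if baton.try_acquire():
+                try:
+                    ...  # do the one-time work here
+                finally:
+                    baton.release()
+            else:
+                baton.wait()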
+ """ + logger.info(f"waiting for baton release at {self.lock_file_path}") + while os.path.exists(self.lock_file_path): + time.sleep(self.wait_seconds) + + def release(self): + """Release the baton and removes its file.""" + if self.fd is not None: + os.close(self.fd) + + os.remove(self.lock_file_path) diff --git a/csrc/cpp_itfs/sampling/sampling.cuh b/csrc/cpp_itfs/sampling/sampling.cuh index cf8aa1b7a3..bea46681f8 100644 --- a/csrc/cpp_itfs/sampling/sampling.cuh +++ b/csrc/cpp_itfs/sampling/sampling.cuh @@ -1,3 +1,5 @@ +#pragma once + #include "hip/hip_runtime.h" /* * Copyright (C) 2024-2025 by FlashInfer team. @@ -14,16 +16,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #ifndef FLASHINFER_SAMPLING_CUH_ - #define FLASHINFER_SAMPLING_CUH_ - #include - #include + #include #include - #include - #include - #include + #include + #include + #include #include #include #include @@ -36,6 +35,12 @@ using namespace hipcub; + constexpr uint32_t BLOCK_THREADS = 1024; + + constexpr BlockScanAlgorithm SCAN_ALGO = BLOCK_SCAN_WARP_SCANS; + constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS; + + #define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...) \ switch (aligned_vec_size) { \ case 16: { \ @@ -78,13 +83,7 @@ } else { \ constexpr bool DETERMINISTIC = false; \ __VA_ARGS__ \ - } - - #define BLOCK_THREADS 1024 - - constexpr BlockScanAlgorithm SCAN_ALGO = BLOCK_SCAN_WARP_SCANS; - constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS; - + } template struct ValueCount { @@ -107,6 +106,11 @@ } }; + template + __forceinline__ __device__ __host__ T1 ceil_div(const T1 x, const T2 y) { + return (x + y - 1) / y; + } + template struct SamplingTempStorage { @@ -157,20 +161,20 @@ #pragma unroll for (uint32_t offset = 1; offset < 32; offset *= 2) { - float tmp = __shfl_up_sync(0xffffffff, thread_exclusive_prefix_sum, offset); + float tmp = __shfl_up(thread_exclusive_prefix_sum, offset); if ((threadIdx.x + 1) % (offset * 2) == 0) { thread_exclusive_prefix_sum += tmp; } } - float warp_sum = __shfl_sync(0xffffffff, thread_exclusive_prefix_sum, threadIdx.x | 0xffffffff); + float warp_sum = __shfl(thread_exclusive_prefix_sum, threadIdx.x | 0xffffffff); if (threadIdx.x % 32 == 31) { thread_exclusive_prefix_sum = 0; } #pragma unroll for (uint32_t offset = 16; offset >= 1; offset /= 2) { - float tmp = __shfl_xor_sync(0xffffffff, thread_exclusive_prefix_sum, offset); + float tmp = __shfl_xor(thread_exclusive_prefix_sum, offset); if ((threadIdx.x + 1) % (offset * 2) == 0) { thread_exclusive_prefix_sum = tmp + thread_exclusive_prefix_sum; } @@ -188,7 +192,7 @@ #pragma unroll for (uint32_t offset = 1; offset < 32; offset *= 2) { - float tmp = __shfl_up_sync(0xffffffff, warp_exclusive_prefix_sum, offset); + float tmp = __shfl_up(warp_exclusive_prefix_sum, offset); if ((threadIdx.x + 1) % (offset * 2) == 0) { warp_exclusive_prefix_sum += tmp; } @@ -200,7 +204,7 @@ #pragma unroll for (uint32_t offset = 16; offset >= 1; offset /= 2) { - float tmp = __shfl_xor_sync(0xffffffff, warp_exclusive_prefix_sum, offset); + float tmp = __shfl_xor(warp_exclusive_prefix_sum, offset); if ((threadIdx.x + 1) % (offset * 2) == 0) { warp_exclusive_prefix_sum = tmp + warp_exclusive_prefix_sum; } @@ -240,7 +244,7 @@ } max_val = max( max_val, BlockReduce(temp_storage.block_prim.reduce) - .Reduce(in_data_, hipcub::Max())); + .Reduce(in_data_, hipcub::Max())); __syncthreads(); } if (tx == 0) { @@ -267,7 +271,7 @@ } float 
aggregate_local = BlockReduce(temp_storage->block_prim.reduce) - .Sum(prob_greater_than_threshold); + .Sum(prob_greater_than_threshold); if (tx == 0) { temp_storage->block_aggregate.value = aggregate_local; } @@ -293,10 +297,10 @@ bool greater_than_u_diff[VEC_SIZE]; BlockAdjacentDifference(temp_storage->block_prim.adj_diff) - .SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp()); + .SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp{}); - // BlockAdjacentDifference(temp_storage->block_prim.adj_diff) - // .FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0); + // BlockAdjacentDifference(temp_storage->block_prim.adj_diff) + // .FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp{}, 0); __syncthreads(); @@ -432,7 +436,7 @@ max_data += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage) - .Sum(cur_data); + .Sum(cur_data); } if (tx == 0) { output[bx] = max_data.index; @@ -564,7 +568,7 @@ aggregate_gt_pivot_0 += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( temp_storage.block_prim.reduce_value_count) - .Sum(probs_gt_pivot_0); + .Sum(probs_gt_pivot_0); if (tx == 0) { temp_storage.block_aggregate.pair = aggregate_gt_pivot_0; } @@ -573,7 +577,7 @@ aggregate_gt_pivot_1 += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( temp_storage.block_prim.reduce_value_count) - .Sum(probs_gt_pivot_1); + .Sum(probs_gt_pivot_1); if (tx == 0) { temp_storage.block_aggregate.pair = aggregate_gt_pivot_1; } @@ -672,7 +676,7 @@ } aggregate_gt_pivot_0 += BlockReduce(temp_storage.block_prim.reduce) - .Sum(probs_gt_pivot_0); + .Sum(probs_gt_pivot_0); if (tx == 0) { temp_storage.block_aggregate.value = aggregate_gt_pivot_0; } @@ -680,7 +684,7 @@ aggregate_gt_pivot_0 = temp_storage.block_aggregate.value; aggregate_gt_pivot_1 += BlockReduce(temp_storage.block_prim.reduce) - .Sum(probs_gt_pivot_1); + .Sum(probs_gt_pivot_1); if (tx == 0) { temp_storage.block_aggregate.value = aggregate_gt_pivot_1; } @@ -722,14 +726,14 @@ const uint32_t row_idx = indices == nullptr ? bx : indices[bx]; const uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[row_idx]; const float p = top_p_arr == nullptr ? 
top_p_val : top_p_arr[row_idx]; - + extern __shared__ __align__( alignof(SamplingTempStorage)) uint8_t smem_sampling[]; auto& temp_storage = reinterpret_cast&>( smem_sampling); - + vec_t probs_vec; float aggregate; float q = 1; @@ -764,7 +768,7 @@ } double pivot_0 = probs[row_idx * d + sampled_id]; double pivot_1 = (pivot_0 + high) / 2; - + ValueCount aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0}; #pragma unroll 2 for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { @@ -786,7 +790,7 @@ aggregate_gt_pivot_0 += BlockReduce, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count) - .Sum(probs_gt_pivot_0); + .Sum(probs_gt_pivot_0); if (tx == 0) { temp_storage.block_aggregate.pair = aggregate_gt_pivot_0; } @@ -795,7 +799,7 @@ aggregate_gt_pivot_1 += BlockReduce, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count) - .Sum(probs_gt_pivot_1); + .Sum(probs_gt_pivot_1); if (tx == 0) { temp_storage.block_aggregate.pair = aggregate_gt_pivot_1; } @@ -827,7 +831,7 @@ hipError_t SamplingFromLogits(T* logits, IdType* output, IdType* indices, uint32_t batch_size, uint32_t d, bool deterministic, uint64_t philox_seed, uint64_t philox_offset, hipStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 1024; + const uint32_t vec_size = std::gcd(16 / sizeof(T), d); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); @@ -848,7 +852,6 @@ hipError_t SamplingFromProb(T* probs, IdType* output, IdType* indices, uint32_t batch_size, uint32_t d, bool deterministic, uint64_t philox_seed, uint64_t philox_offset, hipStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(T), d); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); @@ -894,7 +897,7 @@ uint32_t batch_size, T top_p_val, uint32_t d, bool deterministic, uint64_t philox_seed, uint64_t philox_offset, hipStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 1024; + const uint32_t vec_size = std::gcd(16 / sizeof(T), d); const uint32_t smem_size = sizeof(SamplingTempStorage); @@ -915,31 +918,31 @@ return hipSuccess; } - template - hipError_t TopKTopPSamplingFromProb(T* probs, IdType* top_k_arr, T* top_p_arr, IdType* output, - IdType* indices, uint32_t batch_size, IdType top_k_val, - T top_p_val, uint32_t d, bool deterministic, - uint64_t philox_seed, uint64_t philox_offset, - hipStream_t stream = 0) { - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - - const uint32_t smem_size = sizeof(SamplingTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &top_k_arr, &top_p_arr, &output, &indices, - &top_k_val, &top_p_val, &d, &philox_seed, &philox_offset}; +// template +// hipError_t TopKTopPSamplingFromProb(T* probs, IdType* top_k_arr, T* top_p_arr, IdType* output, +// IdType* indices, uint32_t batch_size, IdType top_k_val, +// T top_p_val, uint32_t d, bool deterministic, +// uint64_t philox_seed, uint64_t philox_offset, +// hipStream_t stream = 0) { +// const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + +// const uint32_t smem_size = sizeof(SamplingTempStorage); +// dim3 nblks(batch_size); +// dim3 nthrs(BLOCK_THREADS); +// void* args[] = {&probs, &top_k_arr, &top_p_arr, &output, &indices, +// &top_k_val, &top_p_val, &d, &philox_seed, &philox_offset}; - DISPATCH_ALIGNED_VEC_SIZE( - vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = TopKTopPSamplingFromProbKernel; +// DISPATCH_ALIGNED_VEC_SIZE( +// vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { +// 
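The launchers in this header all derive the per-thread vector width as vec_size = std::gcd(16 / sizeof(T), d): 16 bytes is the widest vectorized load, and taking the gcd with the vocabulary size d guarantees the chosen width divides d, so no row ends on a partial vector. A quick sketch of the values this produces for float32 probabilities (plain Python, just to show the arithmetic):

    from math import gcd

    def aligned_vec_size(dtype_bytes: int, d: int) -> int:
        # widest candidate is 16 bytes; shrink until the width divides d evenly
        return gcd(16 // dtype_bytes, d)

    for d in (111, 500, 32000, 128256):
        print(d, aligned_vec_size(4, d))
    # 111 -> 1, 500 -> 4, 32000 -> 4, 128256 -> 4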
auto kernel = TopKTopPSamplingFromProbKernel; - hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); +// hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); - hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); - })}); - return hipSuccess; - } +// hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); +// })}); +// return hipSuccess; +// } template struct RenormTempStorage { @@ -1025,12 +1028,12 @@ aggregate_gt_pivot_0 += BlockReduce(temp_storage.block_prim.reduce) - .Sum(probs_gt_pivot_0); + .Sum(probs_gt_pivot_0); __syncthreads(); aggregate_gt_pivot_1 += BlockReduce(temp_storage.block_prim.reduce) - .Sum(probs_gt_pivot_1); + .Sum(probs_gt_pivot_1); __syncthreads(); } min_gt_low = BlockReduce(temp_storage.block_prim.reduce) @@ -1146,12 +1149,12 @@ aggregate_gt_pivot_0 += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( temp_storage.block_prim.reduce_value_count) - .Sum(probs_gt_pivot_0_pair); + .Sum(probs_gt_pivot_0_pair); __syncthreads(); aggregate_gt_pivot_1 += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( temp_storage.block_prim.reduce_value_count) - .Sum(probs_gt_pivot_1_pair); + .Sum(probs_gt_pivot_1_pair); __syncthreads(); } min_gt_low = @@ -1210,7 +1213,7 @@ hipError_t TopPRenormProb(DType* probs, DType* renormed_prob, float* top_p_arr, uint32_t batch_size, float top_p_val, uint32_t d, hipStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 1024; + const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); const uint32_t smem_size = sizeof(RenormTempStorage); @@ -1236,16 +1239,13 @@ dim3 nthrs(BLOCK_THREADS); void* args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &d}; DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = TopKRenormProbKernel; - + auto kernel = TopKRenormProbKernel; hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); - hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); + hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); }); return hipSuccess; } } // namespace sampling - } // namespace flashinfer - - #endif // FLASHINFER_SAMPLING_CUH_ \ No newline at end of file + } // namespace aiter \ No newline at end of file diff --git a/csrc/cpp_itfs/sampling/sampling_test.py b/csrc/cpp_itfs/sampling/sampling_test.py new file mode 100644 index 0000000000..beffa28d9f --- /dev/null +++ b/csrc/cpp_itfs/sampling/sampling_test.py @@ -0,0 +1,78 @@ +""" +Copyright (C) 2024-2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import pytest +import torch + +from csrc.cpp_itfs.sampling.top_k_top_p_sampling_from_probs import ( + top_k_top_p_sampling_from_probs, +) + +torch.set_default_device("cuda") + + +def _to_tensor_scalar_tuple(x): + if isinstance(x, torch.Tensor): + return (x, 0) + else: + return (None, x) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("p", [0.1, 0.5]) +def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p): + torch.manual_seed(42) + if p == 0.1: + k = int(vocab_size * 0.5) + elif p == 0.5: + k = int(vocab_size * 0.1) + else: + raise ValueError("p not recognized") + eps = 1e-4 + pre_norm_prob = torch.rand(batch_size, vocab_size) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + # top-p mask + sorted_prob, indices = torch.sort(normalized_prob, descending=False) + cdf = torch.cumsum(sorted_prob, dim=-1) + mask_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32) + mask_top_p.scatter_add_(1, indices, (cdf > (1 - p) - eps).int()) + # top-k mask + sorted_prob, _ = torch.sort(normalized_prob, descending=True) + pivot = sorted_prob[:, k - 1] + mask_top_k = (normalized_prob >= pivot.unsqueeze(-1)).int() + # overall mask + mask = torch.minimum(mask_top_p, mask_top_k) + top_p_tensor = torch.full((batch_size,), p) + top_k_tensor = torch.full((batch_size,), k) + + num_trails = 1000 + for _ in range(num_trails): + samples = top_k_top_p_sampling_from_probs( + normalized_prob, + None, + *_to_tensor_scalar_tuple(top_k_tensor), + *_to_tensor_scalar_tuple(top_p_tensor), + deterministic=True, + ) + assert torch.all(samples < vocab_size) and torch.all(samples >= 0) + assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[ + torch.arange(batch_size), samples + ] + + +if __name__ == "__main__": + test_top_k_top_p_joint_sampling_from_probs(1, 111, 0.1) diff --git a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja new file mode 100644 index 0000000000..82d07c224f --- /dev/null +++ b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja @@ -0,0 +1,32 @@ +#include "sampling.cuh" + + +#define FUNCTION_DEFINE \ + void {{func_name}}(void* probs_ptr, \ + void* output_ptr, \ + void* indices_ptr, \ + void* top_k_arr_ptr, \ + void* top_p_arr_ptr, \ + int batch_size, \ + int top_k_val, \ + float top_p_val, \ + int philox_seed, \ + int philox_offset, \ + void* stream) + +extern "C" { +FUNCTION_DEFINE; +} + +FUNCTION_DEFINE +{ + constexpr uint32_t vec_size = std::gcd(16 / sizeof(float), {{d}}); + + const uint32_t smem_size = sizeof(aiter::sampling::SamplingTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(aiter::sampling::BLOCK_THREADS); + auto kernel = aiter::sampling::TopKTopPSamplingFromProbKernel; + hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); + kernel<<(stream)>>>(reinterpret_cast(probs_ptr), reinterpret_cast(top_k_arr_ptr), reinterpret_cast(top_p_arr_ptr), reinterpret_cast(output_ptr), reinterpret_cast(indices_ptr), top_k_val, top_p_val, {{d}}, static_cast(philox_seed), static_cast(philox_offset)); +} \ No newline at end of file diff --git a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py new file mode 100644 index 0000000000..44bb0db7bc --- /dev/null +++ 
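For readers following the test above: _to_tensor_scalar_tuple simply expands into the (maybe_tensor, scalar) pair that the wrapper in the next file expects, so the same entry point accepts either one k/p per row or a single scalar for the whole batch (the kernel prefers the tensor when it is non-null). A minimal usage sketch of the two call forms, assuming the aiter checkout is importable and a ROCm/CUDA device is present; the shapes and values here are made up:

    import torch
    from csrc.cpp_itfs.sampling.top_k_top_p_sampling_from_probs import (
        top_k_top_p_sampling_from_probs,
    )

    probs = torch.rand(4, 128, device="cuda")
    probs = probs / probs.sum(dim=-1, keepdim=True)

    # scalar k/p for the whole batch: the maybe_*_arr slots are None
    samples = top_k_top_p_sampling_from_probs(
        probs, None, None, 20, None, 0.9, deterministic=True
    )

    # per-row k/p: pass tensors, the scalar slots become don't-cares
    k = torch.full((4,), 20, dtype=torch.int32, device="cuda")
    p = torch.full((4,), 0.9, device="cuda")
    samples = top_k_top_p_sampling_from_probs(
        probs, None, k, 0, p, 0.0, deterministic=True
    )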
b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py @@ -0,0 +1,101 @@ +from jinja2 import Template +from csrc.cpp_itfs.utils import compile_template_op, AITER_CORE_DIR + + +MD_NAME = "top_k_top_p_sampling_from_probs" + +with open( + f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja", + "r", +) as f: + src_template = Template(f.read()) + + +def compile( + d: int, + deterministic: bool, + folder: str = None, +): + return compile_template_op( + src_template, + MD_NAME, + [ + f"{AITER_CORE_DIR}/csrc/cpp_itfs/utils.h", + f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/sampling.cuh", + f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/vec_dtypes.cuh", + ], + d=d, + deterministic=deterministic, + folder=folder, + ) + + +def top_k_top_p_sampling_from_probs( + probs, + indices, + maybe_top_k_arr, + top_k_val, + maybe_top_p_arr, + top_p_val, + deterministic=False, + gen=None, +): + import torch + from csrc.cpp_itfs.torch_utils import torch_to_c_types + + if gen is None: + gen = torch.cuda.default_generators[probs.device.index] + probs = probs.float() + top_p_val = float(top_p_val) + maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None + maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None + + batch_size = indices.size(0) if indices is not None else probs.size(0) + vocab_size = probs.size(1) + philox_offset = gen.get_offset() + philox_seed = gen.seed() + + output = torch.empty(batch_size, dtype=torch.int32, device=probs.device) + + func = compile(vocab_size, deterministic) + ( + probs_ptr, + output_ptr, + indices_ptr, + top_k_arr_ptr, + top_p_arr_ptr, + top_k_val, + top_p_val, + vocab_size, + batch_size, + philox_seed, + philox_offset, + stream, + ) = torch_to_c_types( + probs, + output, + indices, + maybe_top_k_arr, + maybe_top_p_arr, + top_k_val, + top_p_val, + vocab_size, + batch_size, + philox_seed, + philox_offset, + torch.cuda.current_stream(), + ) + func( + probs_ptr, + output_ptr, + indices_ptr, + top_k_arr_ptr, + top_p_arr_ptr, + batch_size, + top_k_val, + top_p_val, + philox_seed, + philox_offset, + stream, + ) + return output diff --git a/csrc/cpp_itfs/torch_utils.py b/csrc/cpp_itfs/torch_utils.py index 09afec4898..cbc7ed75cb 100644 --- a/csrc/cpp_itfs/torch_utils.py +++ b/csrc/cpp_itfs/torch_utils.py @@ -2,8 +2,33 @@ import ctypes from torch.library import Library from typing import Callable, Optional, Tuple -from csrc.cpp_itfs.utils import AITER_LOG_MORE -from aiter.test_common import log_args +from csrc.cpp_itfs.utils import AITER_LOG_MORE, logger + + +def log_args(func, *args, **kwargs): + import inspect + + callargs = inspect.getcallargs(func, *args, **kwargs) + + prefix = f"calling {func.__name__}(" + blanks = " " * (len(prefix)) + + def getTensorInfo(el): + if isinstance(el, torch.Tensor): + return f"{el.shape} {el.dtype} {el.device} {hex(el.data_ptr())}" + elif isinstance(el, tuple): + viewNum = 5 + if len(el) > viewNum: + el = list(el[:viewNum]) + ["..."] + return f'\n{" "*(len(prefix)+31)}'.join( + ["("] + [f" {getTensorInfo(e)}" for e in el] + [")"] + ) + return el + + info = [f"{el:<28} = {getTensorInfo(callargs[el])}" for el in callargs] + info = f",\n{blanks}".join(info) + logger.info(f"\n{prefix}{info})") + return callargs ctypes_map = { diff --git a/csrc/cpp_itfs/utils.py b/csrc/cpp_itfs/utils.py index 4702887701..fc348af2f2 100644 --- a/csrc/cpp_itfs/utils.py +++ b/csrc/cpp_itfs/utils.py @@ -8,7 +8,7 @@ from functools import lru_cache, partial import binascii import hashlib -from 
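As a semantic reference for the wrapper just added: the fused kernel is meant to behave like the following (much slower) pure-PyTorch routine, i.e. intersect the top-k and top-p masks, renormalize the surviving mass, and draw one token per row. This is a sketch for intuition only, assuming scalar k and p with 1 <= k <= vocab_size; it is not bit-exact with the kernel (tie handling and rounding differ):

    import torch

    def top_k_top_p_sample_reference(probs: torch.Tensor, k: int, p: float) -> torch.Tensor:
        # top-k mask: keep entries >= the k-th largest probability of each row
        kth = torch.topk(probs, k, dim=-1).values[:, -1:]
        keep_k = (probs >= kth).int()
        # top-p mask: keep the smallest descending prefix whose mass reaches p
        sorted_probs, order = torch.sort(probs, descending=True, dim=-1)
        prefix_mass = sorted_probs.cumsum(dim=-1) - sorted_probs
        keep_p = torch.zeros_like(keep_k).scatter(-1, order, (prefix_mass < p).int())
        keep = (keep_k * keep_p).bool()
        # renormalize the surviving mass and draw one token per row
        filtered = torch.where(keep, probs, torch.zeros_like(probs))
        filtered = filtered / filtered.sum(dim=-1, keepdim=True)
        return torch.multinomial(filtered, num_samples=1).squeeze(-1)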
aiter.jit.utils.file_baton import FileBaton +from csrc.cpp_itfs.file_baton import FileBaton import logging import time From e537b70f276b970fd89156878f6e3fcb5ce92634 Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Fri, 25 Jul 2025 05:57:02 +0000 Subject: [PATCH 05/19] add topk renorm probs kernel --- csrc/cpp_itfs/sampling/sampling.cuh | 47 ------------ csrc/cpp_itfs/sampling/sampling_test.py | 35 ++++++++- .../sampling/top_k_renorm_probs.cpp.jinja | 25 +++++++ csrc/cpp_itfs/sampling/top_k_renorm_probs.py | 74 +++++++++++++++++++ .../top_k_top_p_sampling_from_probs.py | 1 + 5 files changed, 134 insertions(+), 48 deletions(-) create mode 100644 csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja create mode 100644 csrc/cpp_itfs/sampling/top_k_renorm_probs.py diff --git a/csrc/cpp_itfs/sampling/sampling.cuh b/csrc/cpp_itfs/sampling/sampling.cuh index bea46681f8..80858b5c5c 100644 --- a/csrc/cpp_itfs/sampling/sampling.cuh +++ b/csrc/cpp_itfs/sampling/sampling.cuh @@ -299,9 +299,6 @@ BlockAdjacentDifference(temp_storage->block_prim.adj_diff) .SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp{}); - // BlockAdjacentDifference(temp_storage->block_prim.adj_diff) - // .FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp{}, 0); - __syncthreads(); #pragma unroll @@ -917,32 +914,6 @@ })}); return hipSuccess; } - -// template -// hipError_t TopKTopPSamplingFromProb(T* probs, IdType* top_k_arr, T* top_p_arr, IdType* output, -// IdType* indices, uint32_t batch_size, IdType top_k_val, -// T top_p_val, uint32_t d, bool deterministic, -// uint64_t philox_seed, uint64_t philox_offset, -// hipStream_t stream = 0) { -// const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - -// const uint32_t smem_size = sizeof(SamplingTempStorage); -// dim3 nblks(batch_size); -// dim3 nthrs(BLOCK_THREADS); -// void* args[] = {&probs, &top_k_arr, &top_p_arr, &output, &indices, -// &top_k_val, &top_p_val, &d, &philox_seed, &philox_offset}; - -// DISPATCH_ALIGNED_VEC_SIZE( -// vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { -// auto kernel = TopKTopPSamplingFromProbKernel; - -// hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); - -// hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); -// })}); -// return hipSuccess; -// } template struct RenormTempStorage { @@ -1228,24 +1199,6 @@ return hipSuccess; } - template - hipError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr, - uint32_t batch_size, uint32_t top_k_val, uint32_t d, - hipStream_t stream = 0) { - const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); - - const uint32_t smem_size = sizeof(RenormTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &d}; - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = TopKRenormProbKernel; - hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); - hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); - }); - return hipSuccess; - } - } // namespace sampling } // namespace aiter \ No newline at end of file diff --git a/csrc/cpp_itfs/sampling/sampling_test.py b/csrc/cpp_itfs/sampling/sampling_test.py index beffa28d9f..5c60f14dae 100644 --- a/csrc/cpp_itfs/sampling/sampling_test.py +++ b/csrc/cpp_itfs/sampling/sampling_test.py @@ -21,6 +21,10 @@ top_k_top_p_sampling_from_probs, ) +from csrc.cpp_itfs.sampling.top_k_renorm_probs import ( + 
top_k_renorm_probs, +) + torch.set_default_device("cuda") @@ -31,6 +35,34 @@ def _to_tensor_scalar_tuple(x): return (None, x) +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("k", [10, 100, 500]) +def test_top_k_renorm_probs(batch_size, vocab_size, k): + if k > vocab_size: + pytest.skip("k should be less than vocab_size") + torch.manual_seed(42) + pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + sorted_prob, _ = torch.sort(normalized_prob, descending=True) + pivot = sorted_prob[:, k - 1] + mask = (normalized_prob >= pivot.unsqueeze(-1)).int() + renorm_prob_ground_truth = normalized_prob.clone() + renorm_prob_ground_truth[mask == 0] = 0 + renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum( + dim=-1, keepdim=True + ) + + renorm_prob = top_k_renorm_probs(normalized_prob, *_to_tensor_scalar_tuple(k)) + for i in range(batch_size): + torch.testing.assert_close( + renorm_prob_ground_truth[i], + renorm_prob[i], + rtol=1e-3, + atol=1e-3, + ) + + @pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) @pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) @pytest.mark.parametrize("p", [0.1, 0.5]) @@ -75,4 +107,5 @@ def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p): if __name__ == "__main__": - test_top_k_top_p_joint_sampling_from_probs(1, 111, 0.1) + # test_top_k_top_p_joint_sampling_from_probs(1, 111, 0.1) + test_top_k_renorm_probs(1, 111, 10) diff --git a/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja b/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja new file mode 100644 index 0000000000..a13241dd61 --- /dev/null +++ b/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja @@ -0,0 +1,25 @@ +#include "sampling.cuh" + + +#define FUNCTION_DEFINE \ + void {{func_name}}(void* probs_ptr, \ + void* renormed_probs_ptr, \ + void* top_k_arr_ptr, \ + int batch_size, \ + int top_k_val, \ + void* stream) + +extern "C" { +FUNCTION_DEFINE; +} + +FUNCTION_DEFINE +{ + constexpr uint32_t vec_size = std::gcd(16 / sizeof(float), {{d}}); + + const uint32_t smem_size = sizeof(aiter::sampling::RenormTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(aiter::sampling::BLOCK_THREADS); + auto kernel = aiter::sampling::TopKRenormProbKernel; + kernel<<(stream)>>>(reinterpret_cast(probs_ptr), reinterpret_cast(renormed_probs_ptr), reinterpret_cast(top_k_arr_ptr), top_k_val, {{d}}); +} \ No newline at end of file diff --git a/csrc/cpp_itfs/sampling/top_k_renorm_probs.py b/csrc/cpp_itfs/sampling/top_k_renorm_probs.py new file mode 100644 index 0000000000..2c631ab521 --- /dev/null +++ b/csrc/cpp_itfs/sampling/top_k_renorm_probs.py @@ -0,0 +1,74 @@ +from jinja2 import Template +from csrc.cpp_itfs.utils import compile_template_op, AITER_CORE_DIR + + +MD_NAME = "top_k_renorm_probs" + +with open( + f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja", + "r", +) as f: + src_template = Template(f.read()) + + +def compile( + d: int, + folder: str = None, +): + return compile_template_op( + src_template, + MD_NAME, + [ + f"{AITER_CORE_DIR}/csrc/cpp_itfs/utils.h", + f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/sampling.cuh", + f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/vec_dtypes.cuh", + ], + d=d, + folder=folder, + ) + + +def top_k_renorm_probs( + probs, + maybe_top_k_arr, + top_k_val, +): + import torch + from csrc.cpp_itfs.torch_utils import torch_to_c_types + + 
probs = probs.float() + maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None + top_k_val = int(top_k_val) + + batch_size = probs.size(0) + vocab_size = probs.size(1) + + renorm_probs = torch.empty_like(probs) + + func = compile(vocab_size) + ( + probs_ptr, + renorm_probs_ptr, + top_k_arr_ptr, + top_k_val, + vocab_size, + batch_size, + stream, + ) = torch_to_c_types( + probs, + renorm_probs, + maybe_top_k_arr, + top_k_val, + vocab_size, + batch_size, + torch.cuda.current_stream(), + ) + func( + probs_ptr, + renorm_probs_ptr, + top_k_arr_ptr, + batch_size, + top_k_val, + stream, + ) + return renorm_probs diff --git a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py index 44bb0db7bc..8c3c4fd8cd 100644 --- a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py +++ b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py @@ -47,6 +47,7 @@ def top_k_top_p_sampling_from_probs( gen = torch.cuda.default_generators[probs.device.index] probs = probs.float() top_p_val = float(top_p_val) + top_k_val = int(top_k_val) maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None From c752fd8190f8dd466373483bae23889d2fad91d9 Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Fri, 25 Jul 2025 06:33:50 +0000 Subject: [PATCH 06/19] add top p sampling --- csrc/cpp_itfs/sampling/sampling.cuh | 44 ++++----- csrc/cpp_itfs/sampling/sampling_test.py | 30 ++++++- csrc/cpp_itfs/sampling/top_k_renorm_probs.py | 2 - .../top_k_top_p_sampling_from_probs.py | 12 ++- .../top_p_sampling_from_probs.cpp.jinja | 30 +++++++ .../sampling/top_p_sampling_from_probs.py | 90 +++++++++++++++++++ 6 files changed, 176 insertions(+), 32 deletions(-) create mode 100644 csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja create mode 100644 csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py diff --git a/csrc/cpp_itfs/sampling/sampling.cuh b/csrc/cpp_itfs/sampling/sampling.cuh index 80858b5c5c..f928054a88 100644 --- a/csrc/cpp_itfs/sampling/sampling.cuh +++ b/csrc/cpp_itfs/sampling/sampling.cuh @@ -889,31 +889,31 @@ return hipSuccess; } - template - hipError_t TopPSamplingFromProb(T* probs, IdType* output, IdType* indices, T* top_p_arr, - uint32_t batch_size, T top_p_val, uint32_t d, bool deterministic, - uint64_t philox_seed, uint64_t philox_offset, - hipStream_t stream = 0) { +// template +// hipError_t TopPSamplingFromProb(T* probs, IdType* output, IdType* indices, T* top_p_arr, +// uint32_t batch_size, T top_p_val, uint32_t d, bool deterministic, +// uint64_t philox_seed, uint64_t philox_offset, +// hipStream_t stream = 0) { - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - - const uint32_t smem_size = sizeof(SamplingTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &output, &indices, &top_p_arr, - &top_p_val, &d, &philox_seed, &philox_offset}; - - DISPATCH_ALIGNED_VEC_SIZE( - vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = TopPSamplingFromProbKernel; +// const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + +// const uint32_t smem_size = sizeof(SamplingTempStorage); +// dim3 nblks(batch_size); +// dim3 nthrs(BLOCK_THREADS); +// void* args[] = {&probs, &output, &indices, &top_p_arr, +// &top_p_val, &d, &philox_seed, &philox_offset}; + +// DISPATCH_ALIGNED_VEC_SIZE( +// vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, 
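A short usage sketch for the top_k_renorm_probs wrapper added in this patch (hypothetical shapes; assumes the repository is on PYTHONPATH and a ROCm/CUDA device is available). A scalar k applies one cut-off to every row, while an int32 tensor gives each row its own:

    import torch
    from csrc.cpp_itfs.sampling.top_k_renorm_probs import top_k_renorm_probs

    probs = torch.rand(8, 32000, device="cuda")
    probs = probs / probs.sum(dim=-1, keepdim=True)

    renormed = top_k_renorm_probs(probs, None, 100)        # same k for every row
    per_row_k = torch.randint(10, 200, (8,), dtype=torch.int32, device="cuda")
    renormed = top_k_renorm_probs(probs, per_row_k, 0)     # per-row k, scalar unused
    # each row of `renormed` sums to 1, with mass only on its top-k tokens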
DETERMINISTIC, { +// auto kernel = TopPSamplingFromProbKernel; - hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); +// hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); - hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); - })}); - return hipSuccess; - } +// hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); +// })}); +// return hipSuccess; +// } template struct RenormTempStorage { diff --git a/csrc/cpp_itfs/sampling/sampling_test.py b/csrc/cpp_itfs/sampling/sampling_test.py index 5c60f14dae..7ad1fb4003 100644 --- a/csrc/cpp_itfs/sampling/sampling_test.py +++ b/csrc/cpp_itfs/sampling/sampling_test.py @@ -25,6 +25,11 @@ top_k_renorm_probs, ) +from csrc.cpp_itfs.sampling.top_p_sampling_from_probs import ( + top_p_sampling_from_probs, +) + + torch.set_default_device("cuda") @@ -35,6 +40,28 @@ def _to_tensor_scalar_tuple(x): return (None, x) +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("p", [0.1, 0.5, 0.9]) +def test_top_p_sampling(batch_size, vocab_size, p): + torch.manual_seed(42) + eps = 1e-4 + pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + sorted_prob, indices = torch.sort(normalized_prob, descending=False) + cdf = torch.cumsum(sorted_prob, dim=-1) + mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0) + mask.scatter_add_(1, indices, (cdf > (1 - p) - eps).int()) + + num_trails = 1000 + for _ in range(num_trails): + samples = top_p_sampling_from_probs( + normalized_prob, None, *_to_tensor_scalar_tuple(p), deterministic=True + ) + assert torch.all(samples < vocab_size) and torch.all(samples >= 0) + assert torch.all(mask[torch.arange(batch_size), samples] == 1) + + @pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) @pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) @pytest.mark.parametrize("k", [10, 100, 500]) @@ -107,5 +134,6 @@ def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p): if __name__ == "__main__": - # test_top_k_top_p_joint_sampling_from_probs(1, 111, 0.1) + test_top_k_top_p_joint_sampling_from_probs(1, 111, 0.1) test_top_k_renorm_probs(1, 111, 10) + test_top_p_sampling(1, 111, 0.1) diff --git a/csrc/cpp_itfs/sampling/top_k_renorm_probs.py b/csrc/cpp_itfs/sampling/top_k_renorm_probs.py index 2c631ab521..1bd797b8e1 100644 --- a/csrc/cpp_itfs/sampling/top_k_renorm_probs.py +++ b/csrc/cpp_itfs/sampling/top_k_renorm_probs.py @@ -51,7 +51,6 @@ def top_k_renorm_probs( renorm_probs_ptr, top_k_arr_ptr, top_k_val, - vocab_size, batch_size, stream, ) = torch_to_c_types( @@ -59,7 +58,6 @@ def top_k_renorm_probs( renorm_probs, maybe_top_k_arr, top_k_val, - vocab_size, batch_size, torch.cuda.current_stream(), ) diff --git a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py index 8c3c4fd8cd..d1dc6f8a49 100644 --- a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py +++ b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py @@ -38,13 +38,13 @@ def top_k_top_p_sampling_from_probs( maybe_top_p_arr, top_p_val, deterministic=False, - gen=None, + generator=None, ): import torch from csrc.cpp_itfs.torch_utils import torch_to_c_types - if gen is None: - gen = torch.cuda.default_generators[probs.device.index] + if generator is None: + 
generator = torch.cuda.default_generators[probs.device.index] probs = probs.float() top_p_val = float(top_p_val) top_k_val = int(top_k_val) @@ -53,8 +53,8 @@ def top_k_top_p_sampling_from_probs( batch_size = indices.size(0) if indices is not None else probs.size(0) vocab_size = probs.size(1) - philox_offset = gen.get_offset() - philox_seed = gen.seed() + philox_offset = generator.get_offset() + philox_seed = generator.seed() output = torch.empty(batch_size, dtype=torch.int32, device=probs.device) @@ -67,7 +67,6 @@ def top_k_top_p_sampling_from_probs( top_p_arr_ptr, top_k_val, top_p_val, - vocab_size, batch_size, philox_seed, philox_offset, @@ -80,7 +79,6 @@ def top_k_top_p_sampling_from_probs( maybe_top_p_arr, top_k_val, top_p_val, - vocab_size, batch_size, philox_seed, philox_offset, diff --git a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja new file mode 100644 index 0000000000..fb1d7dad5e --- /dev/null +++ b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja @@ -0,0 +1,30 @@ +#include "sampling.cuh" + + +#define FUNCTION_DEFINE \ + void {{func_name}}(void* probs_ptr, \ + void* output_ptr, \ + void* indices_ptr, \ + void* top_p_arr_ptr, \ + int batch_size, \ + float top_p_val, \ + int philox_seed, \ + int philox_offset, \ + void* stream) + +extern "C" { +FUNCTION_DEFINE; +} + +FUNCTION_DEFINE +{ + constexpr uint32_t vec_size = std::gcd(16 / sizeof(float), {{d}}); + + const uint32_t smem_size = sizeof(aiter::sampling::SamplingTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(aiter::sampling::BLOCK_THREADS); + auto kernel = aiter::sampling::TopPSamplingFromProbKernel; + hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); + kernel<<(stream)>>>(reinterpret_cast(probs_ptr), reinterpret_cast(output_ptr), reinterpret_cast(indices_ptr), reinterpret_cast(top_p_arr_ptr), top_p_val, {{d}}, static_cast(philox_seed), static_cast(philox_offset)); +} \ No newline at end of file diff --git a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py new file mode 100644 index 0000000000..382e9e8ed6 --- /dev/null +++ b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py @@ -0,0 +1,90 @@ +from jinja2 import Template +from csrc.cpp_itfs.utils import compile_template_op, AITER_CORE_DIR + + +MD_NAME = "top_p_sampling_from_probs" + +with open( + f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja", + "r", +) as f: + src_template = Template(f.read()) + + +def compile( + d: int, + deterministic: bool, + folder: str = None, +): + return compile_template_op( + src_template, + MD_NAME, + [ + f"{AITER_CORE_DIR}/csrc/cpp_itfs/utils.h", + f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/sampling.cuh", + f"{AITER_CORE_DIR}/csrc/cpp_itfs/sampling/vec_dtypes.cuh", + ], + d=d, + deterministic=deterministic, + folder=folder, + ) + + +def top_p_sampling_from_probs( + probs, + indices, + maybe_top_p_arr, + top_p_val, + deterministic: bool = False, + generator=None, +): + import torch + from csrc.cpp_itfs.torch_utils import torch_to_c_types + + if generator is None: + generator = torch.cuda.default_generators[probs.device.index] + philox_offset = generator.get_offset() + philox_seed = generator.seed() + + probs = probs.float() + maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None + top_p_val = float(top_p_val) + + batch_size = probs.size(0) + vocab_size = probs.size(1) + + samples = 
torch.empty(batch_size, dtype=torch.int32, device=probs.device) + func = compile(vocab_size, deterministic) + ( + probs_ptr, + samples_ptr, + indices_ptr, + top_p_arr_ptr, + top_p_val, + batch_size, + philox_seed, + philox_offset, + stream, + ) = torch_to_c_types( + probs, + samples, + indices, + maybe_top_p_arr, + top_p_val, + batch_size, + philox_seed, + philox_offset, + torch.cuda.current_stream(), + ) + func( + probs_ptr, + samples_ptr, + indices_ptr, + top_p_arr_ptr, + batch_size, + top_p_val, + philox_seed, + philox_offset, + stream, + ) + return samples From 2da7397c73be8d4d91962b09c80eb5899c1cdba9 Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Fri, 25 Jul 2025 06:57:09 +0000 Subject: [PATCH 07/19] register sampling ops --- aiter/ops/sampling.py | 82 +++++++++++++++++++ .../test_sampling.py | 21 ++--- 2 files changed, 88 insertions(+), 15 deletions(-) create mode 100644 aiter/ops/sampling.py rename csrc/cpp_itfs/sampling/sampling_test.py => op_tests/test_sampling.py (90%) diff --git a/aiter/ops/sampling.py b/aiter/ops/sampling.py new file mode 100644 index 0000000000..434f8de8b4 --- /dev/null +++ b/aiter/ops/sampling.py @@ -0,0 +1,82 @@ +import torch +from typing import Optional + +from csrc.cpp_itfs.sampling.top_k_renorm_probs import ( + top_k_renorm_probs as top_k_renorm_probs_core, +) +from csrc.cpp_itfs.sampling.top_p_sampling_from_probs import ( + top_p_sampling_from_probs as top_p_sampling_from_probs_core, +) +from csrc.cpp_itfs.sampling.top_k_top_p_sampling_from_probs import ( + top_k_top_p_sampling_from_probs as top_k_top_p_sampling_from_probs_core, +) +from csrc.cpp_itfs.torch_utils import direct_register_custom_op + + +def top_k_renorm_probs( + probs: torch.Tensor, + maybe_top_k_arr: Optional[torch.Tensor], + top_k_val: int, +) -> torch.Tensor: + return top_k_renorm_probs_core( + probs, + maybe_top_k_arr, + top_k_val, + ) + + +direct_register_custom_op( + "top_k_renorm_probs", + top_k_renorm_probs, + [], +) + + +def top_p_sampling_from_probs( + probs: torch.Tensor, + indices: torch.Tensor, + maybe_top_p_arr: Optional[torch.Tensor], + top_p_val: float, + deterministic: bool = False, +) -> torch.Tensor: + return top_p_sampling_from_probs_core( + probs, + indices, + maybe_top_p_arr, + top_p_val, + deterministic, + ) + + +direct_register_custom_op( + "top_p_sampling_from_probs", + top_p_sampling_from_probs, + [], +) + + +def top_k_top_p_sampling_from_probs( + probs: torch.Tensor, + indices: torch.Tensor, + maybe_top_k_arr: Optional[torch.Tensor], + top_k_val: int, + maybe_top_p_arr: Optional[torch.Tensor], + top_p_val: float, + deterministic: bool = False, +) -> torch.Tensor: + return top_k_top_p_sampling_from_probs_core( + probs, + indices, + maybe_top_k_arr, + top_k_val, + maybe_top_p_arr, + top_p_val, + deterministic, + ) + + +direct_register_custom_op( + "top_k_top_p_sampling_from_probs", + top_k_top_p_sampling_from_probs, + [], +) diff --git a/csrc/cpp_itfs/sampling/sampling_test.py b/op_tests/test_sampling.py similarity index 90% rename from csrc/cpp_itfs/sampling/sampling_test.py rename to op_tests/test_sampling.py index 7ad1fb4003..ec6e9aaa05 100644 --- a/csrc/cpp_itfs/sampling/sampling_test.py +++ b/op_tests/test_sampling.py @@ -17,18 +17,7 @@ import pytest import torch -from csrc.cpp_itfs.sampling.top_k_top_p_sampling_from_probs import ( - top_k_top_p_sampling_from_probs, -) - -from csrc.cpp_itfs.sampling.top_k_renorm_probs import ( - top_k_renorm_probs, -) - -from csrc.cpp_itfs.sampling.top_p_sampling_from_probs import ( - top_p_sampling_from_probs, -) - +from 
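Once aiter/ops/sampling.py above has been imported, the three wrappers are also reachable through the torch.ops registry, which is how the relocated test below exercises them. A minimal sketch, assuming the custom-op registration succeeds on this branch:

    import torch
    from aiter.ops import sampling  # noqa: F401  (registers the custom ops)

    probs = torch.rand(4, 32000, device="cuda")
    probs = probs / probs.sum(dim=-1, keepdim=True)

    out = torch.ops.aiter.top_p_sampling_from_probs(
        probs, None, None, 0.9, deterministic=True
    )
    renormed = torch.ops.aiter.top_k_renorm_probs(probs, None, 50)
    out = torch.ops.aiter.top_k_top_p_sampling_from_probs(
        probs, None, None, 50, None, 0.9, deterministic=True
    )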
aiter.ops import sampling # noqa: F401 torch.set_default_device("cuda") @@ -55,7 +44,7 @@ def test_top_p_sampling(batch_size, vocab_size, p): num_trails = 1000 for _ in range(num_trails): - samples = top_p_sampling_from_probs( + samples = torch.ops.aiter.top_p_sampling_from_probs( normalized_prob, None, *_to_tensor_scalar_tuple(p), deterministic=True ) assert torch.all(samples < vocab_size) and torch.all(samples >= 0) @@ -80,7 +69,9 @@ def test_top_k_renorm_probs(batch_size, vocab_size, k): dim=-1, keepdim=True ) - renorm_prob = top_k_renorm_probs(normalized_prob, *_to_tensor_scalar_tuple(k)) + renorm_prob = torch.ops.aiter.top_k_renorm_probs( + normalized_prob, *_to_tensor_scalar_tuple(k) + ) for i in range(batch_size): torch.testing.assert_close( renorm_prob_ground_truth[i], @@ -120,7 +111,7 @@ def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p): num_trails = 1000 for _ in range(num_trails): - samples = top_k_top_p_sampling_from_probs( + samples = torch.ops.aiter.top_k_top_p_sampling_from_probs( normalized_prob, None, *_to_tensor_scalar_tuple(top_k_tensor), From 4c01c916672bebea4b65bc66b2b497c3b4dca22e Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Fri, 25 Jul 2025 07:00:00 +0000 Subject: [PATCH 08/19] set shared memory size --- csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja b/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja index a13241dd61..cb06d1f1e0 100644 --- a/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja +++ b/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja @@ -21,5 +21,6 @@ FUNCTION_DEFINE dim3 nblks(batch_size); dim3 nthrs(aiter::sampling::BLOCK_THREADS); auto kernel = aiter::sampling::TopKRenormProbKernel; + hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); kernel<<(stream)>>>(reinterpret_cast(probs_ptr), reinterpret_cast(renormed_probs_ptr), reinterpret_cast(top_k_arr_ptr), top_k_val, {{d}}); } \ No newline at end of file From 4dadfb9664f45ba2840fd021f88042ab719d555a Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Fri, 25 Jul 2025 07:01:54 +0000 Subject: [PATCH 09/19] remove useless code --- csrc/cpp_itfs/sampling/sampling.cuh | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/csrc/cpp_itfs/sampling/sampling.cuh b/csrc/cpp_itfs/sampling/sampling.cuh index f928054a88..cbb76d591b 100644 --- a/csrc/cpp_itfs/sampling/sampling.cuh +++ b/csrc/cpp_itfs/sampling/sampling.cuh @@ -888,33 +888,7 @@ })}); return hipSuccess; } - -// template -// hipError_t TopPSamplingFromProb(T* probs, IdType* output, IdType* indices, T* top_p_arr, -// uint32_t batch_size, T top_p_val, uint32_t d, bool deterministic, -// uint64_t philox_seed, uint64_t philox_offset, -// hipStream_t stream = 0) { -// const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - -// const uint32_t smem_size = sizeof(SamplingTempStorage); -// dim3 nblks(batch_size); -// dim3 nthrs(BLOCK_THREADS); -// void* args[] = {&probs, &output, &indices, &top_p_arr, -// &top_p_val, &d, &philox_seed, &philox_offset}; - -// DISPATCH_ALIGNED_VEC_SIZE( -// vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { -// auto kernel = TopPSamplingFromProbKernel; - -// hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); - -// hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); -// })}); -// return hipSuccess; -// } - template struct RenormTempStorage 
{ union { From 52b9c890b7e3ca387f90126cd3c7c20f749c94d5 Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Fri, 25 Jul 2025 07:12:43 +0000 Subject: [PATCH 10/19] remove cuda code --- csrc/cpp_itfs/sampling/vec_dtypes.cuh | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/csrc/cpp_itfs/sampling/vec_dtypes.cuh b/csrc/cpp_itfs/sampling/vec_dtypes.cuh index e63d146e4b..44fdf5fb15 100644 --- a/csrc/cpp_itfs/sampling/vec_dtypes.cuh +++ b/csrc/cpp_itfs/sampling/vec_dtypes.cuh @@ -289,12 +289,8 @@ for (size_t i = 0; i < vec_size / 2; ++i) { uint16_t y; uint32_t x = *(uint32_t*)&src[i * 2]; - #if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) __half2 x_h2 = convert_uint32_to_half2(x); y = convert_f16x2_to_e4m3x2(x_h2); - #else - asm volatile("cvt.rn.satfinite.e4m3x2.f16x2 %0, %1;" : "=h"(y) : "r"(x)); - #endif *(uint16_t*)&dst[i * 2] = y; } } @@ -315,13 +311,6 @@ // Extract upper 16 bits __half h2 = __ushort_as_half((x >> 16) & 0xFFFF); - #if 0 - // Alternative with `__uint2half_rn` - uint16_t val1 = x & 0xFFFF; // Lower 16 bits - uint16_t val2 = (x >> 16) & 0xFFFF; // Upper 16 bits - __half h1 = __uint2half_rn(val1); - __half h2 = __uint2half_rn(val2); - #endif // Define the range of e5m2 // Minimum representable value for e5m2 @@ -374,11 +363,7 @@ for (size_t i = 0; i < vec_size / 2; ++i) { uint16_t y; uint32_t x = *(uint32_t*)&src[i * 2]; - #if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) y = convert_f16x2_to_e5m2x2(x); - #else - asm volatile("cvt.rn.satfinite.e5m2x2.f16x2 %0, %1;" : "=h"(y) : "r"(x)); - #endif *(uint16_t*)&dst[i * 2] = y; } } @@ -439,11 +424,7 @@ for (size_t i = 0; i < vec_size / 2; ++i) { uint32_t y; uint16_t x = *(uint16_t*)&src[i * 2]; - #if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) y = convert_e4m3x2_to_f16x2(x); - #else - asm volatile("cvt.rn.f16x2.e4m3x2 %0, %1;" : "=r"(y) : "h"(x)); - #endif *(uint32_t*)&dst[i * 2] = y; } } @@ -512,11 +493,7 @@ for (size_t i = 0; i < vec_size / 2; ++i) { uint32_t y; uint16_t x = *(uint16_t*)&src[i * 2]; - #if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) y = convert_e5m2x2_to_f16x2(x); - #else - asm volatile("cvt.rn.f16x2.e5m2x2 %0, %1;" : "=r"(y) : "h"(x)); - #endif *(uint32_t*)&dst[i * 2] = y; } } @@ -594,11 +571,7 @@ template inline __attribute__((always_inline)) __device__ void cast_from_impl(vec_t& dst, const vec_t& src) { - #if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) vec_cast::template cast( - #else - vec_cast::cast( - #endif dst.ptr(), const_cast*>(&src)->ptr()); } From ab9eb4a9f8a7d06cfdd25b561e1bf16b5363bcd1 Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Fri, 25 Jul 2025 07:52:35 +0000 Subject: [PATCH 11/19] fix marco --- csrc/cpp_itfs/sampling/vec_dtypes.cuh | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/csrc/cpp_itfs/sampling/vec_dtypes.cuh b/csrc/cpp_itfs/sampling/vec_dtypes.cuh index 44fdf5fb15..794b1517a0 100644 --- a/csrc/cpp_itfs/sampling/vec_dtypes.cuh +++ b/csrc/cpp_itfs/sampling/vec_dtypes.cuh @@ -53,9 +53,6 @@ #define FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED - #define inline __attribute__((always_inline)) __device__ inline __attribute__((always_inline)) __device__ - - /******************* vec_t type cast *******************/ template @@ -1397,7 +1394,7 @@ } inline __attribute__((always_inline)) __device__ __hip_bfloat16* ptr() { return 
reinterpret_cast<__hip_bfloat16*>(&data); } inline __attribute__((always_inline)) __device__ void fill(__hip_bfloat16 val) { - #pragma unoll + #pragma unroll for (size_t i = 0; i < vec_size / 8; ++i) { *(__hip_bfloat162*)(&(data[i].x)) = make_bfloat162(val, val); *(__hip_bfloat162*)(&(data[i].y)) = make_bfloat162(val, val); @@ -1406,13 +1403,13 @@ } } inline __attribute__((always_inline)) __device__ void load(const __hip_bfloat16* ptr) { - #pragma unoll + #pragma unroll for (size_t i = 0; i < vec_size / 8; ++i) { data[i] = ((uint4*)ptr)[i]; } } inline __attribute__((always_inline)) __device__ void store(__hip_bfloat16* ptr) const { - #pragma unoll + #pragma unroll for (size_t i = 0; i < vec_size / 8; ++i) { ((uint4*)ptr)[i] = data[i]; } @@ -1430,7 +1427,7 @@ cast_store_impl(ptr, *this); } inline __attribute__((always_inline)) __device__ static void memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src) { - #pragma unoll + #pragma unroll for (size_t i = 0; i < vec_size / 8; ++i) { ((uint4*)dst)[i] = ((uint4*)src)[i]; } From ac6106ef514c4aec9b088a78395a649ec3f620a9 Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Fri, 25 Jul 2025 07:54:03 +0000 Subject: [PATCH 12/19] fix macro --- op_tests/test_sampling.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/op_tests/test_sampling.py b/op_tests/test_sampling.py index ec6e9aaa05..0fc9435f1c 100644 --- a/op_tests/test_sampling.py +++ b/op_tests/test_sampling.py @@ -42,8 +42,8 @@ def test_top_p_sampling(batch_size, vocab_size, p): mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0) mask.scatter_add_(1, indices, (cdf > (1 - p) - eps).int()) - num_trails = 1000 - for _ in range(num_trails): + num_trials = 1000 + for _ in range(num_trials): samples = torch.ops.aiter.top_p_sampling_from_probs( normalized_prob, None, *_to_tensor_scalar_tuple(p), deterministic=True ) @@ -84,14 +84,15 @@ def test_top_k_renorm_probs(batch_size, vocab_size, k): @pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) @pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) @pytest.mark.parametrize("p", [0.1, 0.5]) -def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p): +@pytest.mark.parametrize("k", [10, 50]) +def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p, k): torch.manual_seed(42) - if p == 0.1: - k = int(vocab_size * 0.5) - elif p == 0.5: - k = int(vocab_size * 0.1) - else: - raise ValueError("p not recognized") + # if p == 0.1: + # k = int(vocab_size * 0.5) + # elif p == 0.5: + # k = int(vocab_size * 0.1) + # else: + # raise ValueError("p not recognized") eps = 1e-4 pre_norm_prob = torch.rand(batch_size, vocab_size) normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) @@ -109,8 +110,8 @@ def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p): top_p_tensor = torch.full((batch_size,), p) top_k_tensor = torch.full((batch_size,), k) - num_trails = 1000 - for _ in range(num_trails): + num_trials = 1000 + for _ in range(num_trials): samples = torch.ops.aiter.top_k_top_p_sampling_from_probs( normalized_prob, None, @@ -125,6 +126,6 @@ def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p): if __name__ == "__main__": - test_top_k_top_p_joint_sampling_from_probs(1, 111, 0.1) - test_top_k_renorm_probs(1, 111, 10) - test_top_p_sampling(1, 111, 0.1) + test_top_k_top_p_joint_sampling_from_probs(40, 129280, 0.6, 20) + # test_top_k_renorm_probs(1, 129280, 10) + # test_top_p_sampling(1, 129280, 0.1) From 
701fcfc500ec09cf645867a382f6543dc5f3c0b4 Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Mon, 28 Jul 2025 06:40:45 +0000 Subject: [PATCH 13/19] add copyright --- csrc/cpp_itfs/sampling/sampling.cuh | 8 +++++--- .../sampling/top_k_renorm_probs.cpp.jinja | 17 +++++++++++++++++ csrc/cpp_itfs/sampling/top_k_renorm_probs.py | 4 ++++ .../top_k_top_p_sampling_from_probs.cpp.jinja | 17 +++++++++++++++++ .../sampling/top_k_top_p_sampling_from_probs.py | 4 ++++ .../top_p_sampling_from_probs.cpp.jinja | 16 ++++++++++++++++ .../sampling/top_p_sampling_from_probs.py | 4 ++++ csrc/cpp_itfs/sampling/vec_dtypes.cuh | 6 +----- csrc/cpp_itfs/torch_utils.py | 4 ++++ csrc/cpp_itfs/utils.py | 4 ++++ 10 files changed, 76 insertions(+), 8 deletions(-) diff --git a/csrc/cpp_itfs/sampling/sampling.cuh b/csrc/cpp_itfs/sampling/sampling.cuh index cbb76d591b..d21f427619 100644 --- a/csrc/cpp_itfs/sampling/sampling.cuh +++ b/csrc/cpp_itfs/sampling/sampling.cuh @@ -1,6 +1,3 @@ -#pragma once - -#include "hip/hip_runtime.h" /* * Copyright (C) 2024-2025 by FlashInfer team. * @@ -16,6 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + +#pragma once + +#include "hip/hip_runtime.h" + #include #include diff --git a/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja b/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja index cb06d1f1e0..f7d0261f9c 100644 --- a/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja +++ b/csrc/cpp_itfs/sampling/top_k_renorm_probs.cpp.jinja @@ -1,3 +1,20 @@ +/* + * Copyright (C) 2024-2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + #include "sampling.cuh" diff --git a/csrc/cpp_itfs/sampling/top_k_renorm_probs.py b/csrc/cpp_itfs/sampling/top_k_renorm_probs.py index 1bd797b8e1..cc4722811c 100644 --- a/csrc/cpp_itfs/sampling/top_k_renorm_probs.py +++ b/csrc/cpp_itfs/sampling/top_k_renorm_probs.py @@ -1,3 +1,7 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + + from jinja2 import Template from csrc.cpp_itfs.utils import compile_template_op, AITER_CORE_DIR diff --git a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja index 82d07c224f..301b5c9790 100644 --- a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja +++ b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.cpp.jinja @@ -1,3 +1,20 @@ +/* + * Copyright (C) 2024-2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + + #include "sampling.cuh" diff --git a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py index d1dc6f8a49..1af0db5f43 100644 --- a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py +++ b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py @@ -1,3 +1,7 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + + from jinja2 import Template from csrc.cpp_itfs.utils import compile_template_op, AITER_CORE_DIR diff --git a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja index fb1d7dad5e..99c23b44e7 100644 --- a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja +++ b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.cpp.jinja @@ -1,3 +1,19 @@ +/* + * Copyright (C) 2024-2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + #include "sampling.cuh" diff --git a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py index 382e9e8ed6..1d107869d0 100644 --- a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py +++ b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py @@ -1,3 +1,7 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + + from jinja2 import Template from csrc.cpp_itfs.utils import compile_template_op, AITER_CORE_DIR diff --git a/csrc/cpp_itfs/sampling/vec_dtypes.cuh b/csrc/cpp_itfs/sampling/vec_dtypes.cuh index 794b1517a0..7cb336dd1c 100644 --- a/csrc/cpp_itfs/sampling/vec_dtypes.cuh +++ b/csrc/cpp_itfs/sampling/vec_dtypes.cuh @@ -13,9 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #ifndef VEC_DTYPES_CUH_ - #define VEC_DTYPES_CUH_ - +#pragma once #include #include @@ -1555,5 +1553,3 @@ }; } // namespace flashinfer - - #endif // VEC_DTYPES_CUH_ \ No newline at end of file diff --git a/csrc/cpp_itfs/torch_utils.py b/csrc/cpp_itfs/torch_utils.py index cbc7ed75cb..f4ac4ab71c 100644 --- a/csrc/cpp_itfs/torch_utils.py +++ b/csrc/cpp_itfs/torch_utils.py @@ -1,3 +1,7 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + + import torch import ctypes from torch.library import Library diff --git a/csrc/cpp_itfs/utils.py b/csrc/cpp_itfs/utils.py index fc348af2f2..ecf40ca185 100644 --- a/csrc/cpp_itfs/utils.py +++ b/csrc/cpp_itfs/utils.py @@ -1,3 +1,7 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ + import shutil import os import subprocess From 49fb7969260d0fc0672a2dbca4abe0c821468e73 Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Mon, 28 Jul 2025 06:41:57 +0000 Subject: [PATCH 14/19] format code --- csrc/cpp_itfs/sampling/sampling.cuh | 2607 ++++++++++-------- csrc/cpp_itfs/sampling/vec_dtypes.cuh | 3679 ++++++++++++++----------- 2 files changed, 3599 insertions(+), 2687 deletions(-) diff --git a/csrc/cpp_itfs/sampling/sampling.cuh b/csrc/cpp_itfs/sampling/sampling.cuh index d21f427619..4a6afb0fea 100644 --- a/csrc/cpp_itfs/sampling/sampling.cuh +++ b/csrc/cpp_itfs/sampling/sampling.cuh @@ -18,1163 +18,1470 @@ #include "hip/hip_runtime.h" - - #include - #include - - #include - #include - #include - #include - #include - #include - - #include "vec_dtypes.cuh" - - namespace aiter { - - namespace sampling { - - using namespace hipcub; - - constexpr uint32_t BLOCK_THREADS = 1024; - - constexpr BlockScanAlgorithm SCAN_ALGO = BLOCK_SCAN_WARP_SCANS; - constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS; - - - #define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...) \ - switch (aligned_vec_size) { \ - case 16: { \ - constexpr size_t ALIGNED_VEC_SIZE = 16; \ - __VA_ARGS__ \ - break; \ - } \ - case 8: { \ - constexpr size_t ALIGNED_VEC_SIZE = 8; \ - __VA_ARGS__ \ - break; \ - } \ - case 4: { \ - constexpr size_t ALIGNED_VEC_SIZE = 4; \ - __VA_ARGS__ \ - break; \ - } \ - case 2: { \ - constexpr size_t ALIGNED_VEC_SIZE = 2; \ - __VA_ARGS__ \ - break; \ - } \ - case 1: { \ - constexpr size_t ALIGNED_VEC_SIZE = 1; \ - __VA_ARGS__ \ - break; \ - } \ - default: { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \ - throw std::runtime_error(err_msg.str()); \ - } \ - } - - - #define DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, ...) \ - if (deterministic) { \ - constexpr bool DETERMINISTIC = true; \ - __VA_ARGS__ \ - } else { \ - constexpr bool DETERMINISTIC = false; \ - __VA_ARGS__ \ - } - - template - struct ValueCount { - T value; - int count; - - __device__ ValueCount operator+(const ValueCount& other) const { - return {value + other.value, count + other.count}; - } - __device__ ValueCount& operator+=(const ValueCount& other) { - value += other.value; - count += other.count; - return *this; - } - }; - - struct BoolDiffOp { - __device__ __forceinline__ bool operator()(const bool& lhs, const bool& rhs) const { - return lhs != rhs; - } - }; - - template - __forceinline__ __device__ __host__ T1 ceil_div(const T1 x, const T2 y) { - return (x + y - 1) / y; - } - - template - struct SamplingTempStorage { - union { - float deterministic_scan[BLOCK_THREADS / 32]; - typename BlockScan::TempStorage scan; - typename BlockReduce::TempStorage reduce; - typename BlockReduce::TempStorage reduce_int; - typename BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage - reduce_value_count; - typename BlockAdjacentDifference::TempStorage adj_diff; - } block_prim; - struct { - int32_t sampled_id; - int32_t last_valid_id; - float max_val; - union { - float value; - ValueCount pair; - } block_aggregate; - }; - }; - - template - __device__ __forceinline__ T infinity() { - return __builtin_huge_valf(); - } - - /*! - * \brief Deterministic inclusive scan implementation, use Belloch scan algorithm. - * \note This implementation is slower than the hipcub::BlockScan, but it is deterministic. 
- */ - template - __device__ __forceinline__ void DeterministicInclusiveSum( - const float* in_data, float* out_data, - SamplingTempStorage* temp_storage) { - float* smem_prefix_sum = temp_storage->block_prim.deterministic_scan; - float thread_data[VEC_SIZE]; - float thread_sum = 0; - #pragma unroll - for (uint32_t i = 0; i < VEC_SIZE; ++i) { - thread_sum += in_data[i]; - thread_data[i] = thread_sum; - } - - float thread_exclusive_prefix_sum = thread_sum; - - #pragma unroll - for (uint32_t offset = 1; offset < 32; offset *= 2) { - float tmp = __shfl_up(thread_exclusive_prefix_sum, offset); - if ((threadIdx.x + 1) % (offset * 2) == 0) { - thread_exclusive_prefix_sum += tmp; - } - } - - float warp_sum = __shfl(thread_exclusive_prefix_sum, threadIdx.x | 0xffffffff); - if (threadIdx.x % 32 == 31) { - thread_exclusive_prefix_sum = 0; - } - - #pragma unroll - for (uint32_t offset = 16; offset >= 1; offset /= 2) { - float tmp = __shfl_xor(thread_exclusive_prefix_sum, offset); - if ((threadIdx.x + 1) % (offset * 2) == 0) { - thread_exclusive_prefix_sum = tmp + thread_exclusive_prefix_sum; - } - if ((threadIdx.x + 1) % (offset * 2) == offset) { - thread_exclusive_prefix_sum = tmp; - } - } - - smem_prefix_sum[threadIdx.x / 32] = warp_sum; - __syncthreads(); - - if (threadIdx.x < 32) { - float warp_exclusive_prefix_sum = - (threadIdx.x < BLOCK_THREADS / 32) ? smem_prefix_sum[threadIdx.x] : 0; - - #pragma unroll - for (uint32_t offset = 1; offset < 32; offset *= 2) { - float tmp = __shfl_up(warp_exclusive_prefix_sum, offset); - if ((threadIdx.x + 1) % (offset * 2) == 0) { - warp_exclusive_prefix_sum += tmp; - } - } - - if (threadIdx.x % 32 == 31) { - warp_exclusive_prefix_sum = 0; - } - - #pragma unroll - for (uint32_t offset = 16; offset >= 1; offset /= 2) { - float tmp = __shfl_xor(warp_exclusive_prefix_sum, offset); - if ((threadIdx.x + 1) % (offset * 2) == 0) { - warp_exclusive_prefix_sum = tmp + warp_exclusive_prefix_sum; - } - if ((threadIdx.x + 1) % (offset * 2) == offset) { - warp_exclusive_prefix_sum = tmp; - } - } - if (threadIdx.x < BLOCK_THREADS / 32) { - smem_prefix_sum[threadIdx.x] = warp_exclusive_prefix_sum; - } - } - __syncthreads(); - - #pragma unroll - for (uint32_t i = 0; i < VEC_SIZE; ++i) { - out_data[i] = smem_prefix_sum[threadIdx.x / 32] + thread_exclusive_prefix_sum + thread_data[i]; - } - } - - template - __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, uint32_t d, - TempStorage& temp_storage) { - const uint32_t tx = threadIdx.x; - vec_t in_data_vec; - - float max_val = 0; - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - in_data_vec.fill(0); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - in_data_vec.cast_load(in_data + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); - } - float in_data_[VEC_SIZE]; - #pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - in_data_[j] = in_data_vec[j]; - } - max_val = max( - max_val, BlockReduce(temp_storage.block_prim.reduce) - .Reduce(in_data_, hipcub::Max())); - __syncthreads(); - } - if (tx == 0) { - temp_storage.max_val = max_val; - } - __syncthreads(); - return temp_storage.max_val; - } - - template - __device__ __forceinline__ void DeviceSamplingFromProb( - uint32_t i, uint32_t d, Predicate pred, float u, vec_t prob_vec, - float& aggregate, - SamplingTempStorage* temp_storage) { - const uint32_t tx = threadIdx.x; - float prob_greater_than_threshold[VEC_SIZE]; - float inclusive_cdf[VEC_SIZE]; - bool greater_than_u[VEC_SIZE], valid[VEC_SIZE]; - #pragma unroll - for 
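Stripped of the vectorization, the block-wide scan, and the deterministic path, what DeviceSamplingFromProb implements is inverse-transform sampling over the pred-accepted probabilities: accumulate a running total and return the first index whose inclusive CDF exceeds the uniform draw u; the last accepted index is also tracked (last_valid_id) so callers have a fallback when rounding pushes u past the accepted mass. A scalar Python sketch of that control flow, for intuition only:

    def sample_from_prob(probs, u, pred=lambda x: True):
        aggregate = 0.0
        last_valid = None
        for i, p in enumerate(probs):
            if not pred(p):
                continue
            last_valid = i
            aggregate += p
            if aggregate > u:          # first index whose inclusive CDF exceeds u
                return i
        return last_valid              # rounding pushed u past the accepted mass

    print(sample_from_prob([0.1, 0.2, 0.3, 0.4], u=0.45))  # -> 2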
(uint32_t j = 0; j < VEC_SIZE; ++j) { - prob_greater_than_threshold[j] = pred(prob_vec[j]) ? prob_vec[j] : 0; - valid[j] = pred(prob_vec[j]) && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d; - } - float aggregate_local = - BlockReduce(temp_storage->block_prim.reduce) - .Sum(prob_greater_than_threshold); - if (tx == 0) { - temp_storage->block_aggregate.value = aggregate_local; - } - __syncthreads(); - aggregate_local = temp_storage->block_aggregate.value; - - if (aggregate + aggregate_local > u) { - if constexpr (DETERMINISTIC) { - DeterministicInclusiveSum( - prob_greater_than_threshold, inclusive_cdf, temp_storage); - } else { - BlockScan(temp_storage->block_prim.scan) - .InclusiveSum(prob_greater_than_threshold, inclusive_cdf); - - __syncthreads(); - } - - #pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - greater_than_u[j] = (inclusive_cdf[j] + aggregate > u) && valid[j]; - } - - bool greater_than_u_diff[VEC_SIZE]; - - BlockAdjacentDifference(temp_storage->block_prim.adj_diff) - .SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp{}); - - __syncthreads(); - - #pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - if (greater_than_u_diff[j]) { - atomicMin(&(temp_storage->sampled_id), (i * BLOCK_THREADS + tx) * VEC_SIZE + j); - } - } - __syncthreads(); - } - - // update the last valid index - int valid_index[VEC_SIZE]; - #pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - if (valid[j]) { - valid_index[j] = (i * BLOCK_THREADS + tx) * VEC_SIZE + j; - } else { - valid_index[j] = -1; - } - } - int max_valid_index = - BlockReduce(temp_storage->block_prim.reduce_int) - .Reduce(valid_index, hipcub::Max()); - if (tx == 0 && max_valid_index != -1) { - temp_storage->last_valid_id = max_valid_index; - } - __syncthreads(); - aggregate += aggregate_local; - } - - template - struct DataAndIndex { - DType data; - IdType index; - - __device__ DataAndIndex operator+(const DataAndIndex& other) const { - if (data > other.data) { - return {data, index}; - } else { - return {other.data, other.index}; - } - } - __device__ DataAndIndex& operator+=(const DataAndIndex& other) { - if (data > other.data) { - return *this; - } else { - data = other.data; - index = other.index; - return *this; - } - } - }; - - template - __device__ __forceinline__ vec_t GenerateGumbelNoise(uint64_t philox_seed, - uint64_t philox_offset, - uint64_t subsequence) { - hiprandStatePhilox4_32_10_t state; - vec_t noise; - constexpr float kEPSILON = 1e-20f; - constexpr float kLOG2 = 0.6931471806f; - auto uniform2gumbel = [](float x) { return -kLOG2 * log2f(-log2f(x + kEPSILON) + kEPSILON); }; - // TODO: compare the speed of log2 and log - #pragma unroll - for (uint32_t i = 0; i + 4 <= VEC_SIZE; i += 4) { - hiprand_init(philox_seed, subsequence + i, philox_offset, &state); - float4 noise_vec = hiprand_uniform4(&state); - noise[i] = uniform2gumbel(noise_vec.x); - noise[i + 1] = uniform2gumbel(noise_vec.y); - noise[i + 2] = uniform2gumbel(noise_vec.z); - noise[i + 3] = uniform2gumbel(noise_vec.w); - } - if constexpr (VEC_SIZE % 4 != 0) { - hiprand_init(philox_seed, subsequence + VEC_SIZE / 4 * 4, philox_offset, &state); - float4 noise_vec = hiprand_uniform4(&state); - if constexpr (VEC_SIZE % 4 == 1) { - noise[VEC_SIZE - 1] = uniform2gumbel(noise_vec.x); - } else if constexpr (VEC_SIZE % 4 == 2) { - noise[VEC_SIZE - 2] = uniform2gumbel(noise_vec.x); - noise[VEC_SIZE - 1] = uniform2gumbel(noise_vec.y); - } else if constexpr (VEC_SIZE % 4 == 3) { - noise[VEC_SIZE - 3] = uniform2gumbel(noise_vec.x); - noise[VEC_SIZE 
- 2] = uniform2gumbel(noise_vec.y); - noise[VEC_SIZE - 1] = uniform2gumbel(noise_vec.z); - } - } - - if constexpr (std::is_same_v) { - return noise; - } else { - vec_t ret; - #pragma unroll - for (uint32_t i = 0; i < VEC_SIZE; ++i) { - ret[i] = static_cast(noise[i]); - } - return ret; - } - } - - template - __global__ void SamplingFromLogitsKernel(DType* logits, IdType* output, IdType* indices, uint32_t d, - uint64_t philox_seed, uint64_t philox_offset) { - const uint32_t bx = blockIdx.x, tx = threadIdx.x; - const uint32_t row_idx = indices == nullptr ? bx : indices[bx]; - using SharedMem = typename BlockReduce, BLOCK_THREADS, - REDUCE_ALGORITHM>::TempStorage; - extern __shared__ __align__(alignof(SharedMem)) uint8_t smem_sampling[]; - auto& temp_storage = reinterpret_cast(smem_sampling); - - vec_t logits_vec; - DataAndIndex max_data = {-infinity(), 0}; - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - logits_vec.fill(-infinity()); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - logits_vec.cast_load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); - } - - vec_t gumbel_noise = GenerateGumbelNoise( - philox_seed, philox_offset, - static_cast(bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE)); - DataAndIndex cur_data[VEC_SIZE]; - #pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - cur_data[j].data = (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d - ? logits_vec[j] + gumbel_noise[j] - : -infinity(); - cur_data[j].index = (i * BLOCK_THREADS + tx) * VEC_SIZE + j; - } - - max_data += - BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage) - .Sum(cur_data); - } - if (tx == 0) { - output[bx] = max_data.index; - } - } - - template - __global__ void SamplingFromProbKernel(DType* probs, IdType* output, IdType* indices, uint32_t d, - uint64_t philox_seed, uint64_t philox_offset) { - hiprandStatePhilox4_32_10_t state; - const uint32_t bx = blockIdx.x, tx = threadIdx.x; - hiprand_init(philox_seed, bx, philox_offset, &state); - const uint32_t row_idx = indices == nullptr ? bx : indices[bx]; - - extern __shared__ __align__( - alignof(SamplingTempStorage)) - uint8_t smem_sampling[]; - auto& temp_storage = - reinterpret_cast&>( - smem_sampling); - temp_storage.sampled_id = d; - __syncthreads(); - - vec_t probs_vec; - float aggregate(0); - float u = hiprand_uniform(&state); - - #pragma unroll 2 - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(0); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); - } - - DeviceSamplingFromProb( - i, d, [](float x) { return x > 0; }, u, probs_vec, aggregate, &temp_storage); - if (float(aggregate) > u) { - break; - } - } - int sampled_id = temp_storage.sampled_id; - if (sampled_id == d) { - // NOTE(Zihao): this would happen when u is very close to 1 - // and the sum of probabilities is smaller than u - // In this case, we use the last valid index as the sampled id - sampled_id = temp_storage.last_valid_id; - } - output[bx] = sampled_id; - } - - template - __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType* indices, - IdType* top_k_arr, uint32_t top_k_val, uint32_t d, - uint64_t philox_seed, uint64_t philox_offset) { - const uint32_t batch_size = gridDim.x; - const uint32_t bx = blockIdx.x, tx = threadIdx.x; - hiprandStatePhilox4_32_10_t state; - hiprand_init(philox_seed, bx, philox_offset, &state); - const uint32_t k = top_k_arr == nullptr ? 
top_k_val : top_k_arr[bx]; - const uint32_t row_idx = indices == nullptr ? bx : indices[bx]; - - extern __shared__ __align__( - alignof(SamplingTempStorage)) - uint8_t smem_sampling[]; - auto& temp_storage = - reinterpret_cast&>( - smem_sampling); - - vec_t probs_vec; - float aggregate; - float q = 1; - double low = 0, high = 1.f; - int sampled_id; - int round = 0; - do { - round += 1; - temp_storage.sampled_id = d; - __syncthreads(); - float u = hiprand_uniform(&state) * q; - aggregate = 0; - #pragma unroll 2 - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(0); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); - } - - DeviceSamplingFromProb( - i, d, [&](float x) { return x > low; }, u, probs_vec, aggregate, &temp_storage); - if (aggregate > u) { - break; - } - } - __syncthreads(); - sampled_id = temp_storage.sampled_id; - if (sampled_id == d) { - // NOTE(Zihao): this would happen when u is very close to 1 - // and the sum of probabilities is smaller than u - // In this case, we use the last valid index as the sampled id - sampled_id = temp_storage.last_valid_id; - } - double pivot_0 = probs[row_idx * d + sampled_id]; - double pivot_1 = (pivot_0 + high) / 2; - - ValueCount aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0}; - #pragma unroll 2 - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(0); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); - } - - ValueCount probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE]; - #pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - probs_gt_pivot_0[j] = { - (probs_vec[j] > pivot_0) ? probs_vec[j] : 0, - (probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; - probs_gt_pivot_1[j] = { - (probs_vec[j] > pivot_1) ? 
probs_vec[j] : 0, - (probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; - } - - aggregate_gt_pivot_0 += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( - temp_storage.block_prim.reduce_value_count) - .Sum(probs_gt_pivot_0); - if (tx == 0) { - temp_storage.block_aggregate.pair = aggregate_gt_pivot_0; - } - __syncthreads(); - aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair; - - aggregate_gt_pivot_1 += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( - temp_storage.block_prim.reduce_value_count) - .Sum(probs_gt_pivot_1); - if (tx == 0) { - temp_storage.block_aggregate.pair = aggregate_gt_pivot_1; - } - __syncthreads(); - aggregate_gt_pivot_1 = temp_storage.block_aggregate.pair; - } - if (aggregate_gt_pivot_0.count < k) { - // case 1: pivot_0 accepted - break; - } - if (aggregate_gt_pivot_1.count < k) { - // case 2: pivot_0 rejected, pivot_1 accepted - low = pivot_0; - high = pivot_1; - q = aggregate_gt_pivot_0.value; - } else { - // case 3: pivot_0 rejected, pivot_1 rejected - low = pivot_1; - q = aggregate_gt_pivot_1.value; - } - } while (low < high); - __syncthreads(); - if (tx == 0) { - output[bx] = sampled_id; - } - } - - template - __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType* indices, - float* top_p_arr, float top_p_val, uint32_t d, - uint64_t philox_seed, uint64_t philox_offset) { - const uint32_t batch_size = gridDim.x; - const uint32_t bx = blockIdx.x, tx = threadIdx.x; - hiprandStatePhilox4_32_10_t state; - hiprand_init(philox_seed, bx, philox_offset, &state); - const uint32_t row_idx = indices == nullptr ? bx : indices[bx]; - float top_p = (top_p_arr == nullptr) ? top_p_val : top_p_arr[row_idx]; - - extern __shared__ __align__( - alignof(SamplingTempStorage)) - uint8_t smem_sampling[]; - auto& temp_storage = - reinterpret_cast&>( - smem_sampling); - - vec_t probs_vec; - float aggregate; - float q = 1; - double low = 0, high = 1.f; - int sampled_id; - do { - temp_storage.sampled_id = d; - __syncthreads(); - float u = hiprand_uniform(&state) * q; - aggregate = 0; - #pragma unroll 2 - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(0); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); - } - - DeviceSamplingFromProb( - i, d, [&](float x) { return x > low; }, u, probs_vec, aggregate, &temp_storage); - if (aggregate > u) { - break; - } - } - __syncthreads(); - sampled_id = temp_storage.sampled_id; - if (sampled_id == d) { - // NOTE(Zihao): this would happen when u is very close to 1 - // and the sum of probabilities is smaller than u - // In this case, we use the last valid index as the sampled id - sampled_id = temp_storage.last_valid_id; - } - double pivot_0 = probs[row_idx * d + sampled_id]; - double pivot_1 = (pivot_0 + high) / 2; - - float aggregate_gt_pivot_0 = 0, aggregate_gt_pivot_1 = 0; - #pragma unroll 2 - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(0); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); - } - - float probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE]; - #pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - probs_gt_pivot_0[j] = (probs_vec[j] > pivot_0) ? probs_vec[j] : 0; - probs_gt_pivot_1[j] = (probs_vec[j] > pivot_1) ? 
probs_vec[j] : 0; - } - - aggregate_gt_pivot_0 += BlockReduce(temp_storage.block_prim.reduce) - .Sum(probs_gt_pivot_0); - if (tx == 0) { - temp_storage.block_aggregate.value = aggregate_gt_pivot_0; - } - __syncthreads(); - aggregate_gt_pivot_0 = temp_storage.block_aggregate.value; - - aggregate_gt_pivot_1 += BlockReduce(temp_storage.block_prim.reduce) - .Sum(probs_gt_pivot_1); - if (tx == 0) { - temp_storage.block_aggregate.value = aggregate_gt_pivot_1; - } - __syncthreads(); - aggregate_gt_pivot_1 = temp_storage.block_aggregate.value; - } - if (aggregate_gt_pivot_0 < top_p) { - // case 1: pivot_0 accepted - break; - } - if (aggregate_gt_pivot_1 < top_p) { - // case 2: pivot_0 rejected, pivot_1 accepted - low = pivot_0; - high = pivot_1; - q = aggregate_gt_pivot_0; - } else { - // case 3: pivot_0 rejected, pivot_1 rejected - low = pivot_1; - q = aggregate_gt_pivot_1; - } - } while (low < high); - __syncthreads(); - if (tx == 0) { - output[bx] = sampled_id; - } - } - - template - __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* top_k_arr, float* top_p_arr, - IdType* output, IdType* indices, IdType top_k_val, - float top_p_val, uint32_t d, uint64_t philox_seed, - uint64_t philox_offset) { - const uint32_t batch_size = gridDim.x; - const uint32_t bx = blockIdx.x, tx = threadIdx.x; - hiprandStatePhilox4_32_10_t state; - hiprand_init(philox_seed, bx, philox_offset, &state); - const uint32_t row_idx = indices == nullptr ? bx : indices[bx]; - const uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[row_idx]; - const float p = top_p_arr == nullptr ? top_p_val : top_p_arr[row_idx]; - - extern __shared__ __align__( - alignof(SamplingTempStorage)) - uint8_t smem_sampling[]; - auto& temp_storage = - reinterpret_cast&>( - smem_sampling); - - vec_t probs_vec; - float aggregate; - float q = 1; - double low = 0, high = 1.f; - int sampled_id; - do { - temp_storage.sampled_id = d; - __syncthreads(); - float u = hiprand_uniform(&state) * q; - aggregate = 0; - #pragma unroll 2 - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(0); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); - } - - DeviceSamplingFromProb( - i, d, [&](float x) { return x > low; }, u, probs_vec, aggregate, &temp_storage); - if (aggregate > u) { - break; - } - } - __syncthreads(); - sampled_id = temp_storage.sampled_id; - if (sampled_id == d) { - // NOTE(Zihao): this would happen when u is very close to 1 - // and the sum of probabilities is smaller than u - // In this case, we use the last valid index as the sampled id - sampled_id = temp_storage.last_valid_id; - } - double pivot_0 = probs[row_idx * d + sampled_id]; - double pivot_1 = (pivot_0 + high) / 2; - - ValueCount aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0}; - #pragma unroll 2 - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(0); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); - } - - ValueCount probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE]; - #pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - probs_gt_pivot_0[j] = { - (probs_vec[j] > pivot_0) ? probs_vec[j] : 0, - (probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; - probs_gt_pivot_1[j] = { - (probs_vec[j] > pivot_1) ? 
probs_vec[j] : 0, - (probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; - } - - aggregate_gt_pivot_0 += - BlockReduce, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count) - .Sum(probs_gt_pivot_0); - if (tx == 0) { - temp_storage.block_aggregate.pair = aggregate_gt_pivot_0; - } - __syncthreads(); - aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair; - - aggregate_gt_pivot_1 += - BlockReduce, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count) - .Sum(probs_gt_pivot_1); - if (tx == 0) { - temp_storage.block_aggregate.pair = aggregate_gt_pivot_1; - } - __syncthreads(); - aggregate_gt_pivot_1 = temp_storage.block_aggregate.pair; - } - if (aggregate_gt_pivot_0.count < k && aggregate_gt_pivot_0.value < p) { - // case 1: pivot_0 accepted - break; - } - if (aggregate_gt_pivot_1.count < k && aggregate_gt_pivot_1.value < p) { - // case 2: pivot_0 rejected, pivot_1 accepted - low = pivot_0; - high = pivot_1; - q = aggregate_gt_pivot_0.value; - } else { - // case 3: pivot_0 rejected, pivot_1 rejected - low = pivot_1; - q = aggregate_gt_pivot_1.value; - } - } while (low < high); - __syncthreads(); - if (tx == 0) { - output[bx] = sampled_id; - } - } - - template - hipError_t SamplingFromLogits(T* logits, IdType* output, IdType* indices, uint32_t batch_size, - uint32_t d, bool deterministic, uint64_t philox_seed, - uint64_t philox_offset, hipStream_t stream = 0) { - - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&logits, &output, &indices, &d, &philox_seed, &philox_offset}; - const uint32_t smem_size = sizeof( - typename BlockReduce, BLOCK_THREADS, REDUCE_ALGO>::TempStorage); - - DISPATCH_ALIGNED_VEC_SIZE( - vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = SamplingFromLogitsKernel; - hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); - })}); - return hipSuccess; - } - - template - hipError_t SamplingFromProb(T* probs, IdType* output, IdType* indices, uint32_t batch_size, - uint32_t d, bool deterministic, uint64_t philox_seed, - uint64_t philox_offset, hipStream_t stream = 0) { - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &output, &indices, &d, &philox_seed, &philox_offset, &d}; - const uint32_t smem_size = sizeof(SamplingTempStorage); - - DISPATCH_ALIGNED_VEC_SIZE( - vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = SamplingFromProbKernel; - - hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); - })}); - return hipSuccess; - } - - template - hipError_t TopKSamplingFromProb(T* probs, IdType* output, IdType* indices, T* top_k_arr, - uint32_t batch_size, uint32_t top_k_val, uint32_t d, - bool deterministic, uint64_t philox_seed, uint64_t philox_offset, - hipStream_t stream = 0) { +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "vec_dtypes.cuh" + +namespace aiter { + +namespace sampling { + +using namespace hipcub; + +constexpr uint32_t BLOCK_THREADS = 1024; + +constexpr BlockScanAlgorithm SCAN_ALGO = BLOCK_SCAN_WARP_SCANS; +constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS; + +#define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...) 
\ + switch(aligned_vec_size) \ + { \ + case 16: { \ + constexpr size_t ALIGNED_VEC_SIZE = 16; \ + __VA_ARGS__ \ + break; \ + } \ + case 8: { \ + constexpr size_t ALIGNED_VEC_SIZE = 8; \ + __VA_ARGS__ \ + break; \ + } \ + case 4: { \ + constexpr size_t ALIGNED_VEC_SIZE = 4; \ + __VA_ARGS__ \ + break; \ + } \ + case 2: { \ + constexpr size_t ALIGNED_VEC_SIZE = 2; \ + __VA_ARGS__ \ + break; \ + } \ + case 1: { \ + constexpr size_t ALIGNED_VEC_SIZE = 1; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \ + throw std::runtime_error(err_msg.str()); \ + } \ + } + +#define DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, ...) \ + if(deterministic) \ + { \ + constexpr bool DETERMINISTIC = true; \ + __VA_ARGS__ \ + } \ + else \ + { \ + constexpr bool DETERMINISTIC = false; \ + __VA_ARGS__ \ + } + +template +struct ValueCount +{ + T value; + int count; + + __device__ ValueCount operator+(const ValueCount& other) const + { + return {value + other.value, count + other.count}; + } + __device__ ValueCount& operator+=(const ValueCount& other) + { + value += other.value; + count += other.count; + return *this; + } +}; + +struct BoolDiffOp +{ + __device__ __forceinline__ bool operator()(const bool& lhs, const bool& rhs) const + { + return lhs != rhs; + } +}; + +template +__forceinline__ __device__ __host__ T1 ceil_div(const T1 x, const T2 y) +{ + return (x + y - 1) / y; +} + +template +struct SamplingTempStorage +{ + union + { + float deterministic_scan[BLOCK_THREADS / 32]; + typename BlockScan::TempStorage scan; + typename BlockReduce::TempStorage reduce; + typename BlockReduce::TempStorage reduce_int; + typename BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage + reduce_value_count; + typename BlockAdjacentDifference::TempStorage adj_diff; + } block_prim; + struct + { + int32_t sampled_id; + int32_t last_valid_id; + float max_val; + union + { + float value; + ValueCount pair; + } block_aggregate; + }; +}; + +template +__device__ __forceinline__ T infinity() +{ + return __builtin_huge_valf(); +} + +/*! + * \brief Deterministic inclusive scan implementation, use Belloch scan algorithm. + * \note This implementation is slower than the hipcub::BlockScan, but it is deterministic. 
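The note above is worth unpacking: floating-point addition is not associative, so a block scan that is free to combine partial sums in a hardware-dependent order can return results that differ between runs, while this hand-rolled Blelloch-style scan fixes the combination order and is therefore bit-for-bit reproducible, at some cost in speed. Below is a minimal host-side illustration of the underlying effect, assuming nothing beyond the C++ standard library.

#include <cstdio>
#include <vector>

int main() {
    // 1.0e8f swamps the unit terms: whether a 1.0f is added before or after
    // the large values changes how rounding cancels them.
    std::vector<float> xs = {1.0e8f, 1.0f, -1.0e8f, 1.0f};

    float forward = 0.0f, reversed = 0.0f;
    for (size_t i = 0; i < xs.size(); ++i) forward += xs[i];   // one fixed order
    for (size_t i = xs.size(); i-- > 0;) reversed += xs[i];    // a different order

    // Typically prints "forward = 1.0, reversed = 0.0": same inputs, different sums.
    // A deterministic scan pins the combination order so every run agrees exactly.
    std::printf("forward = %.1f, reversed = %.1f\n", forward, reversed);
}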
+ */ +template +__device__ __forceinline__ void DeterministicInclusiveSum( + const float* in_data, + float* out_data, + SamplingTempStorage* temp_storage) +{ + float* smem_prefix_sum = temp_storage->block_prim.deterministic_scan; + float thread_data[VEC_SIZE]; + float thread_sum = 0; +#pragma unroll + for(uint32_t i = 0; i < VEC_SIZE; ++i) + { + thread_sum += in_data[i]; + thread_data[i] = thread_sum; + } + + float thread_exclusive_prefix_sum = thread_sum; + +#pragma unroll + for(uint32_t offset = 1; offset < 32; offset *= 2) + { + float tmp = __shfl_up(thread_exclusive_prefix_sum, offset); + if((threadIdx.x + 1) % (offset * 2) == 0) + { + thread_exclusive_prefix_sum += tmp; + } + } + + float warp_sum = __shfl(thread_exclusive_prefix_sum, threadIdx.x | 0xffffffff); + if(threadIdx.x % 32 == 31) + { + thread_exclusive_prefix_sum = 0; + } + +#pragma unroll + for(uint32_t offset = 16; offset >= 1; offset /= 2) + { + float tmp = __shfl_xor(thread_exclusive_prefix_sum, offset); + if((threadIdx.x + 1) % (offset * 2) == 0) + { + thread_exclusive_prefix_sum = tmp + thread_exclusive_prefix_sum; + } + if((threadIdx.x + 1) % (offset * 2) == offset) + { + thread_exclusive_prefix_sum = tmp; + } + } + + smem_prefix_sum[threadIdx.x / 32] = warp_sum; + __syncthreads(); + + if(threadIdx.x < 32) + { + float warp_exclusive_prefix_sum = + (threadIdx.x < BLOCK_THREADS / 32) ? smem_prefix_sum[threadIdx.x] : 0; + +#pragma unroll + for(uint32_t offset = 1; offset < 32; offset *= 2) + { + float tmp = __shfl_up(warp_exclusive_prefix_sum, offset); + if((threadIdx.x + 1) % (offset * 2) == 0) + { + warp_exclusive_prefix_sum += tmp; + } + } + + if(threadIdx.x % 32 == 31) + { + warp_exclusive_prefix_sum = 0; + } + +#pragma unroll + for(uint32_t offset = 16; offset >= 1; offset /= 2) + { + float tmp = __shfl_xor(warp_exclusive_prefix_sum, offset); + if((threadIdx.x + 1) % (offset * 2) == 0) + { + warp_exclusive_prefix_sum = tmp + warp_exclusive_prefix_sum; + } + if((threadIdx.x + 1) % (offset * 2) == offset) + { + warp_exclusive_prefix_sum = tmp; + } + } + if(threadIdx.x < BLOCK_THREADS / 32) + { + smem_prefix_sum[threadIdx.x] = warp_exclusive_prefix_sum; + } + } + __syncthreads(); + +#pragma unroll + for(uint32_t i = 0; i < VEC_SIZE; ++i) + { + out_data[i] = + smem_prefix_sum[threadIdx.x / 32] + thread_exclusive_prefix_sum + thread_data[i]; + } +} + +template +__device__ __forceinline__ float +GetMaxValue(float* in_data, uint32_t row_idx, uint32_t d, TempStorage& temp_storage) +{ + const uint32_t tx = threadIdx.x; + vec_t in_data_vec; + + float max_val = 0; + for(uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) + { + in_data_vec.fill(0); + if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) + { + in_data_vec.cast_load(in_data + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + } + float in_data_[VEC_SIZE]; +#pragma unroll + for(uint32_t j = 0; j < VEC_SIZE; ++j) + { + in_data_[j] = in_data_vec[j]; + } + max_val = + max(max_val, + BlockReduce(temp_storage.block_prim.reduce) + .Reduce(in_data_, hipcub::Max())); + __syncthreads(); + } + if(tx == 0) + { + temp_storage.max_val = max_val; + } + __syncthreads(); + return temp_storage.max_val; +} + +template +__device__ __forceinline__ void DeviceSamplingFromProb( + uint32_t i, + uint32_t d, + Predicate pred, + float u, + vec_t prob_vec, + float& aggregate, + SamplingTempStorage* temp_storage) +{ + const uint32_t tx = threadIdx.x; + float prob_greater_than_threshold[VEC_SIZE]; + float inclusive_cdf[VEC_SIZE]; + bool greater_than_u[VEC_SIZE], valid[VEC_SIZE]; +#pragma 
unroll + for(uint32_t j = 0; j < VEC_SIZE; ++j) + { + prob_greater_than_threshold[j] = pred(prob_vec[j]) ? prob_vec[j] : 0; + valid[j] = pred(prob_vec[j]) && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d; + } + float aggregate_local = + BlockReduce(temp_storage->block_prim.reduce) + .Sum(prob_greater_than_threshold); + if(tx == 0) + { + temp_storage->block_aggregate.value = aggregate_local; + } + __syncthreads(); + aggregate_local = temp_storage->block_aggregate.value; + + if(aggregate + aggregate_local > u) + { + if constexpr(DETERMINISTIC) + { + DeterministicInclusiveSum( + prob_greater_than_threshold, inclusive_cdf, temp_storage); + } + else + { + BlockScan(temp_storage->block_prim.scan) + .InclusiveSum(prob_greater_than_threshold, inclusive_cdf); + + __syncthreads(); + } + +#pragma unroll + for(uint32_t j = 0; j < VEC_SIZE; ++j) + { + greater_than_u[j] = (inclusive_cdf[j] + aggregate > u) && valid[j]; + } + + bool greater_than_u_diff[VEC_SIZE]; + + BlockAdjacentDifference(temp_storage->block_prim.adj_diff) + .SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp{}); + + __syncthreads(); + +#pragma unroll + for(uint32_t j = 0; j < VEC_SIZE; ++j) + { + if(greater_than_u_diff[j]) + { + atomicMin(&(temp_storage->sampled_id), (i * BLOCK_THREADS + tx) * VEC_SIZE + j); + } + } + __syncthreads(); + } + + // update the last valid index + int valid_index[VEC_SIZE]; +#pragma unroll + for(uint32_t j = 0; j < VEC_SIZE; ++j) + { + if(valid[j]) + { + valid_index[j] = (i * BLOCK_THREADS + tx) * VEC_SIZE + j; + } + else + { + valid_index[j] = -1; + } + } + int max_valid_index = + BlockReduce(temp_storage->block_prim.reduce_int) + .Reduce(valid_index, hipcub::Max()); + if(tx == 0 && max_valid_index != -1) + { + temp_storage->last_valid_id = max_valid_index; + } + __syncthreads(); + aggregate += aggregate_local; +} + +template +struct DataAndIndex +{ + DType data; + IdType index; + + __device__ DataAndIndex operator+(const DataAndIndex& other) const + { + if(data > other.data) + { + return {data, index}; + } + else + { + return {other.data, other.index}; + } + } + __device__ DataAndIndex& operator+=(const DataAndIndex& other) + { + if(data > other.data) + { + return *this; + } + else + { + data = other.data; + index = other.index; + return *this; + } + } +}; + +template +__device__ __forceinline__ vec_t +GenerateGumbelNoise(uint64_t philox_seed, uint64_t philox_offset, uint64_t subsequence) +{ + hiprandStatePhilox4_32_10_t state; + vec_t noise; + constexpr float kEPSILON = 1e-20f; + constexpr float kLOG2 = 0.6931471806f; + auto uniform2gumbel = [](float x) { return -kLOG2 * log2f(-log2f(x + kEPSILON) + kEPSILON); }; +// TODO: compare the speed of log2 and log +#pragma unroll + for(uint32_t i = 0; i + 4 <= VEC_SIZE; i += 4) + { + hiprand_init(philox_seed, subsequence + i, philox_offset, &state); + float4 noise_vec = hiprand_uniform4(&state); + noise[i] = uniform2gumbel(noise_vec.x); + noise[i + 1] = uniform2gumbel(noise_vec.y); + noise[i + 2] = uniform2gumbel(noise_vec.z); + noise[i + 3] = uniform2gumbel(noise_vec.w); + } + if constexpr(VEC_SIZE % 4 != 0) + { + hiprand_init(philox_seed, subsequence + VEC_SIZE / 4 * 4, philox_offset, &state); + float4 noise_vec = hiprand_uniform4(&state); + if constexpr(VEC_SIZE % 4 == 1) + { + noise[VEC_SIZE - 1] = uniform2gumbel(noise_vec.x); + } + else if constexpr(VEC_SIZE % 4 == 2) + { + noise[VEC_SIZE - 2] = uniform2gumbel(noise_vec.x); + noise[VEC_SIZE - 1] = uniform2gumbel(noise_vec.y); + } + else if constexpr(VEC_SIZE % 4 == 3) + { + noise[VEC_SIZE - 3] = 
uniform2gumbel(noise_vec.x); + noise[VEC_SIZE - 2] = uniform2gumbel(noise_vec.y); + noise[VEC_SIZE - 1] = uniform2gumbel(noise_vec.z); + } + } + + if constexpr(std::is_same_v) + { + return noise; + } + else + { + vec_t ret; +#pragma unroll + for(uint32_t i = 0; i < VEC_SIZE; ++i) + { + ret[i] = static_cast(noise[i]); + } + return ret; + } +} + +template +__global__ void SamplingFromLogitsKernel(DType* logits, + IdType* output, + IdType* indices, + uint32_t d, + uint64_t philox_seed, + uint64_t philox_offset) +{ + const uint32_t bx = blockIdx.x, tx = threadIdx.x; + const uint32_t row_idx = indices == nullptr ? bx : indices[bx]; + using SharedMem = typename BlockReduce, + BLOCK_THREADS, + REDUCE_ALGORITHM>::TempStorage; + extern __shared__ __align__(alignof(SharedMem)) uint8_t smem_sampling[]; + auto& temp_storage = reinterpret_cast(smem_sampling); + + vec_t logits_vec; + DataAndIndex max_data = {-infinity(), 0}; + for(uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) + { + logits_vec.fill(-infinity()); + if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) + { + logits_vec.cast_load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + + tx * VEC_SIZE); + } + + vec_t gumbel_noise = GenerateGumbelNoise( + philox_seed, + philox_offset, + static_cast(bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE)); + DataAndIndex cur_data[VEC_SIZE]; +#pragma unroll + for(uint32_t j = 0; j < VEC_SIZE; ++j) + { + cur_data[j].data = (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d + ? logits_vec[j] + gumbel_noise[j] + : -infinity(); + cur_data[j].index = (i * BLOCK_THREADS + tx) * VEC_SIZE + j; + } + + max_data += + BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage) + .Sum(cur_data); + } + if(tx == 0) + { + output[bx] = max_data.index; + } +} + +template +__global__ void SamplingFromProbKernel(DType* probs, + IdType* output, + IdType* indices, + uint32_t d, + uint64_t philox_seed, + uint64_t philox_offset) +{ + hiprandStatePhilox4_32_10_t state; + const uint32_t bx = blockIdx.x, tx = threadIdx.x; + hiprand_init(philox_seed, bx, philox_offset, &state); + const uint32_t row_idx = indices == nullptr ? 
bx : indices[bx]; + + extern __shared__ __align__( + alignof(SamplingTempStorage)) + uint8_t smem_sampling[]; + auto& temp_storage = + reinterpret_cast&>( + smem_sampling); + temp_storage.sampled_id = d; + __syncthreads(); + + vec_t probs_vec; + float aggregate(0); + float u = hiprand_uniform(&state); + +#pragma unroll 2 + for(uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) + { + probs_vec.fill(0); + if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) + { + probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); + } + + DeviceSamplingFromProb( + i, d, [](float x) { return x > 0; }, u, probs_vec, aggregate, &temp_storage); + if(float(aggregate) > u) + { + break; + } + } + int sampled_id = temp_storage.sampled_id; + if(sampled_id == d) + { + // NOTE(Zihao): this would happen when u is very close to 1 + // and the sum of probabilities is smaller than u + // In this case, we use the last valid index as the sampled id + sampled_id = temp_storage.last_valid_id; + } + output[bx] = sampled_id; +} + +template +__global__ void TopKSamplingFromProbKernel(DType* probs, + IdType* output, + IdType* indices, + IdType* top_k_arr, + uint32_t top_k_val, + uint32_t d, + uint64_t philox_seed, + uint64_t philox_offset) +{ + const uint32_t batch_size = gridDim.x; + const uint32_t bx = blockIdx.x, tx = threadIdx.x; + hiprandStatePhilox4_32_10_t state; + hiprand_init(philox_seed, bx, philox_offset, &state); + const uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx]; + const uint32_t row_idx = indices == nullptr ? bx : indices[bx]; + + extern __shared__ __align__( + alignof(SamplingTempStorage)) + uint8_t smem_sampling[]; + auto& temp_storage = + reinterpret_cast&>( + smem_sampling); + + vec_t probs_vec; + float aggregate; + float q = 1; + double low = 0, high = 1.f; + int sampled_id; + int round = 0; + do + { + round += 1; + temp_storage.sampled_id = d; + __syncthreads(); + float u = hiprand_uniform(&state) * q; + aggregate = 0; +#pragma unroll 2 + for(uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) + { + probs_vec.fill(0); + if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) + { + probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + } + + DeviceSamplingFromProb( + i, d, [&](float x) { return x > low; }, u, probs_vec, aggregate, &temp_storage); + if(aggregate > u) + { + break; + } + } + __syncthreads(); + sampled_id = temp_storage.sampled_id; + if(sampled_id == d) + { + // NOTE(Zihao): this would happen when u is very close to 1 + // and the sum of probabilities is smaller than u + // In this case, we use the last valid index as the sampled id + sampled_id = temp_storage.last_valid_id; + } + double pivot_0 = probs[row_idx * d + sampled_id]; + double pivot_1 = (pivot_0 + high) / 2; + + ValueCount aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0}; +#pragma unroll 2 + for(uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) + { + probs_vec.fill(0); + if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) + { + probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + } + + ValueCount probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE]; +#pragma unroll + for(uint32_t j = 0; j < VEC_SIZE; ++j) + { + probs_gt_pivot_0[j] = { + (probs_vec[j] > pivot_0) ? probs_vec[j] : 0, + (probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; + probs_gt_pivot_1[j] = { + (probs_vec[j] > pivot_1) ? 
probs_vec[j] : 0, + (probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; + } + + aggregate_gt_pivot_0 += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( + temp_storage.block_prim.reduce_value_count) + .Sum(probs_gt_pivot_0); + if(tx == 0) + { + temp_storage.block_aggregate.pair = aggregate_gt_pivot_0; + } + __syncthreads(); + aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair; + + aggregate_gt_pivot_1 += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( + temp_storage.block_prim.reduce_value_count) + .Sum(probs_gt_pivot_1); + if(tx == 0) + { + temp_storage.block_aggregate.pair = aggregate_gt_pivot_1; + } + __syncthreads(); + aggregate_gt_pivot_1 = temp_storage.block_aggregate.pair; + } + if(aggregate_gt_pivot_0.count < k) + { + // case 1: pivot_0 accepted + break; + } + if(aggregate_gt_pivot_1.count < k) + { + // case 2: pivot_0 rejected, pivot_1 accepted + low = pivot_0; + high = pivot_1; + q = aggregate_gt_pivot_0.value; + } + else + { + // case 3: pivot_0 rejected, pivot_1 rejected + low = pivot_1; + q = aggregate_gt_pivot_1.value; + } + } while(low < high); + __syncthreads(); + if(tx == 0) + { + output[bx] = sampled_id; + } +} + +template +__global__ void TopPSamplingFromProbKernel(DType* probs, + IdType* output, + IdType* indices, + float* top_p_arr, + float top_p_val, + uint32_t d, + uint64_t philox_seed, + uint64_t philox_offset) +{ + const uint32_t batch_size = gridDim.x; + const uint32_t bx = blockIdx.x, tx = threadIdx.x; + hiprandStatePhilox4_32_10_t state; + hiprand_init(philox_seed, bx, philox_offset, &state); + const uint32_t row_idx = indices == nullptr ? bx : indices[bx]; + float top_p = (top_p_arr == nullptr) ? top_p_val : top_p_arr[row_idx]; + + extern __shared__ __align__( + alignof(SamplingTempStorage)) + uint8_t smem_sampling[]; + auto& temp_storage = + reinterpret_cast&>( + smem_sampling); + + vec_t probs_vec; + float aggregate; + float q = 1; + double low = 0, high = 1.f; + int sampled_id; + do + { + temp_storage.sampled_id = d; + __syncthreads(); + float u = hiprand_uniform(&state) * q; + aggregate = 0; +#pragma unroll 2 + for(uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) + { + probs_vec.fill(0); + if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) + { + probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + } + + DeviceSamplingFromProb( + i, d, [&](float x) { return x > low; }, u, probs_vec, aggregate, &temp_storage); + if(aggregate > u) + { + break; + } + } + __syncthreads(); + sampled_id = temp_storage.sampled_id; + if(sampled_id == d) + { + // NOTE(Zihao): this would happen when u is very close to 1 + // and the sum of probabilities is smaller than u + // In this case, we use the last valid index as the sampled id + sampled_id = temp_storage.last_valid_id; + } + double pivot_0 = probs[row_idx * d + sampled_id]; + double pivot_1 = (pivot_0 + high) / 2; + + float aggregate_gt_pivot_0 = 0, aggregate_gt_pivot_1 = 0; +#pragma unroll 2 + for(uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) + { + probs_vec.fill(0); + if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) + { + probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + } + + float probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE]; +#pragma unroll + for(uint32_t j = 0; j < VEC_SIZE; ++j) + { + probs_gt_pivot_0[j] = (probs_vec[j] > pivot_0) ? probs_vec[j] : 0; + probs_gt_pivot_1[j] = (probs_vec[j] > pivot_1) ? 
probs_vec[j] : 0; + } + + aggregate_gt_pivot_0 += + BlockReduce(temp_storage.block_prim.reduce) + .Sum(probs_gt_pivot_0); + if(tx == 0) + { + temp_storage.block_aggregate.value = aggregate_gt_pivot_0; + } + __syncthreads(); + aggregate_gt_pivot_0 = temp_storage.block_aggregate.value; + + aggregate_gt_pivot_1 += + BlockReduce(temp_storage.block_prim.reduce) + .Sum(probs_gt_pivot_1); + if(tx == 0) + { + temp_storage.block_aggregate.value = aggregate_gt_pivot_1; + } + __syncthreads(); + aggregate_gt_pivot_1 = temp_storage.block_aggregate.value; + } + if(aggregate_gt_pivot_0 < top_p) + { + // case 1: pivot_0 accepted + break; + } + if(aggregate_gt_pivot_1 < top_p) + { + // case 2: pivot_0 rejected, pivot_1 accepted + low = pivot_0; + high = pivot_1; + q = aggregate_gt_pivot_0; + } + else + { + // case 3: pivot_0 rejected, pivot_1 rejected + low = pivot_1; + q = aggregate_gt_pivot_1; + } + } while(low < high); + __syncthreads(); + if(tx == 0) + { + output[bx] = sampled_id; + } +} + +template +__global__ void TopKTopPSamplingFromProbKernel(DType* probs, + IdType* top_k_arr, + float* top_p_arr, + IdType* output, + IdType* indices, + IdType top_k_val, + float top_p_val, + uint32_t d, + uint64_t philox_seed, + uint64_t philox_offset) +{ + const uint32_t batch_size = gridDim.x; + const uint32_t bx = blockIdx.x, tx = threadIdx.x; + hiprandStatePhilox4_32_10_t state; + hiprand_init(philox_seed, bx, philox_offset, &state); + const uint32_t row_idx = indices == nullptr ? bx : indices[bx]; + const uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[row_idx]; + const float p = top_p_arr == nullptr ? top_p_val : top_p_arr[row_idx]; + + extern __shared__ __align__( + alignof(SamplingTempStorage)) + uint8_t smem_sampling[]; + auto& temp_storage = + reinterpret_cast&>( + smem_sampling); + + vec_t probs_vec; + float aggregate; + float q = 1; + double low = 0, high = 1.f; + int sampled_id; + do + { + temp_storage.sampled_id = d; + __syncthreads(); + float u = hiprand_uniform(&state) * q; + aggregate = 0; +#pragma unroll 2 + for(uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) + { + probs_vec.fill(0); + if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) + { + probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + } + + DeviceSamplingFromProb( + i, d, [&](float x) { return x > low; }, u, probs_vec, aggregate, &temp_storage); + if(aggregate > u) + { + break; + } + } + __syncthreads(); + sampled_id = temp_storage.sampled_id; + if(sampled_id == d) + { + // NOTE(Zihao): this would happen when u is very close to 1 + // and the sum of probabilities is smaller than u + // In this case, we use the last valid index as the sampled id + sampled_id = temp_storage.last_valid_id; + } + double pivot_0 = probs[row_idx * d + sampled_id]; + double pivot_1 = (pivot_0 + high) / 2; + + ValueCount aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0}; +#pragma unroll 2 + for(uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) + { + probs_vec.fill(0); + if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) + { + probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + } + + ValueCount probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE]; +#pragma unroll + for(uint32_t j = 0; j < VEC_SIZE; ++j) + { + probs_gt_pivot_0[j] = { + (probs_vec[j] > pivot_0) ? probs_vec[j] : 0, + (probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; + probs_gt_pivot_1[j] = { + (probs_vec[j] > pivot_1) ? 
probs_vec[j] : 0, + (probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; + } + + aggregate_gt_pivot_0 += BlockReduce, BLOCK_THREADS>( + temp_storage.block_prim.reduce_value_count) + .Sum(probs_gt_pivot_0); + if(tx == 0) + { + temp_storage.block_aggregate.pair = aggregate_gt_pivot_0; + } + __syncthreads(); + aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair; + + aggregate_gt_pivot_1 += BlockReduce, BLOCK_THREADS>( + temp_storage.block_prim.reduce_value_count) + .Sum(probs_gt_pivot_1); + if(tx == 0) + { + temp_storage.block_aggregate.pair = aggregate_gt_pivot_1; + } + __syncthreads(); + aggregate_gt_pivot_1 = temp_storage.block_aggregate.pair; + } + if(aggregate_gt_pivot_0.count < k && aggregate_gt_pivot_0.value < p) + { + // case 1: pivot_0 accepted + break; + } + if(aggregate_gt_pivot_1.count < k && aggregate_gt_pivot_1.value < p) + { + // case 2: pivot_0 rejected, pivot_1 accepted + low = pivot_0; + high = pivot_1; + q = aggregate_gt_pivot_0.value; + } + else + { + // case 3: pivot_0 rejected, pivot_1 rejected + low = pivot_1; + q = aggregate_gt_pivot_1.value; + } + } while(low < high); + __syncthreads(); + if(tx == 0) + { + output[bx] = sampled_id; + } +} + +template +hipError_t SamplingFromLogits(T* logits, + IdType* output, + IdType* indices, + uint32_t batch_size, + uint32_t d, + bool deterministic, + uint64_t philox_seed, + uint64_t philox_offset, + hipStream_t stream = 0) +{ + const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&logits, &output, &indices, &d, &philox_seed, &philox_offset}; + const uint32_t smem_size = sizeof( + typename BlockReduce, BLOCK_THREADS, REDUCE_ALGO>::TempStorage); + + DISPATCH_ALIGNED_VEC_SIZE( + vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { + auto kernel = SamplingFromLogitsKernel; + hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); + })}); + return hipSuccess; +} + +template +hipError_t SamplingFromProb(T* probs, + IdType* output, + IdType* indices, + uint32_t batch_size, + uint32_t d, + bool deterministic, + uint64_t philox_seed, + uint64_t philox_offset, + hipStream_t stream = 0) +{ + const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&probs, &output, &indices, &d, &philox_seed, &philox_offset, &d}; + const uint32_t smem_size = sizeof(SamplingTempStorage); + + DISPATCH_ALIGNED_VEC_SIZE( + vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { + auto kernel = SamplingFromProbKernel; + + hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); + })}); + return hipSuccess; +} + +template +hipError_t TopKSamplingFromProb(T* probs, + IdType* output, + IdType* indices, + T* top_k_arr, + uint32_t batch_size, + uint32_t top_k_val, + uint32_t d, + bool deterministic, + uint64_t philox_seed, + uint64_t philox_offset, + hipStream_t stream = 0) +{ + const uint32_t vec_size = std::gcd(16 / sizeof(T), d); const uint32_t smem_size = sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &output, &indices, &top_k_arr, - &top_k_val, &d, &philox_seed, &philox_offset}; + void* args[] = { + &probs, &output, &indices, &top_k_arr, &top_k_val, &d, &philox_seed, &philox_offset}; DISPATCH_ALIGNED_VEC_SIZE( vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = TopKSamplingFromProbKernel; - - 
hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); - - hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); + auto kernel = TopKSamplingFromProbKernel; + + hipFuncSetAttribute(reinterpret_cast(kernel), + hipFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); })}); return hipSuccess; - } - - template - struct RenormTempStorage { - union { - typename BlockReduce::TempStorage reduce; - typename BlockReduce::TempStorage reduce_int; - typename BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage - reduce_value_count; - } block_prim; - struct { - float max_val; - float min_val; - union { - struct { - float values[2]; - }; - struct { - int counts[2]; - }; - struct { - ValueCount pairs[2]; - }; - } block_aggregate; - }; - }; - - template - __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float* top_p_arr, - float top_p_val, uint32_t d) { - const uint32_t bx = blockIdx.x, tx = threadIdx.x; - const uint32_t row_idx = bx; - float p = top_p_arr == nullptr ? top_p_val : top_p_arr[bx]; - - extern __shared__ __align__(alignof(RenormTempStorage)) - uint8_t smem_renorm[]; - auto& temp_storage = - reinterpret_cast&>(smem_renorm); - temp_storage.max_val = 0; - vec_t probs_vec; - - float max_val = GetMaxValue>(probs, row_idx, d, - temp_storage); - - double low = 0, high = max_val; - float min_gt_low, max_le_high; - float sum_low = 1; - // f(x) = sum(probs[probs > x]), f(x) is non-increasing - // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high} - // loop invariant: - // - f(low) >= p, f(high) < p - // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high) - // stopping condition - // - f(low) >= p, f(min_gt_low) == f(max_le_high) == f(high) < p - do { - double pivot_0 = (high + 2 * low) / 3; - double pivot_1 = (2 * high + low) / 3; - - float aggregate_gt_pivot_0 = 0, aggregate_gt_pivot_1 = 0; - min_gt_low = high; - max_le_high = low; - #pragma unroll 2 - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(0); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); - } - - float probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE]; - #pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - probs_gt_pivot_0[j] = (probs_vec[j] > pivot_0) ? probs_vec[j] : 0; - probs_gt_pivot_1[j] = (probs_vec[j] > pivot_1) ? 
probs_vec[j] : 0; - - if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { - min_gt_low = min(min_gt_low, probs_vec[j]); - } - if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { - max_le_high = max(max_le_high, probs_vec[j]); - } - } - - aggregate_gt_pivot_0 += - BlockReduce(temp_storage.block_prim.reduce) - .Sum(probs_gt_pivot_0); - __syncthreads(); - - aggregate_gt_pivot_1 += - BlockReduce(temp_storage.block_prim.reduce) - .Sum(probs_gt_pivot_1); - __syncthreads(); - } - min_gt_low = BlockReduce(temp_storage.block_prim.reduce) - .Reduce(min_gt_low, hipcub::Min()); - __syncthreads(); - max_le_high = - BlockReduce(temp_storage.block_prim.reduce) - .Reduce(max_le_high, hipcub::Max()); - if (tx == 0) { - temp_storage.block_aggregate.values[0] = aggregate_gt_pivot_0; - temp_storage.block_aggregate.values[1] = aggregate_gt_pivot_1; - temp_storage.min_val = min_gt_low; - temp_storage.max_val = max_le_high; - } - __syncthreads(); - aggregate_gt_pivot_0 = temp_storage.block_aggregate.values[0]; - aggregate_gt_pivot_1 = temp_storage.block_aggregate.values[1]; - min_gt_low = temp_storage.min_val; - max_le_high = temp_storage.max_val; - - if (aggregate_gt_pivot_1 >= p) { - low = pivot_1; - sum_low = aggregate_gt_pivot_1; - } else if (aggregate_gt_pivot_0 >= p) { - low = pivot_0; - high = min(pivot_1, max_le_high); - sum_low = aggregate_gt_pivot_0; - } else { - high = min(pivot_0, max_le_high); - } - } while (min_gt_low != max_le_high); - - float normalizer = __frcp_rn(max(sum_low, 1e-8)); - - // normalize - #pragma unroll 2 - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(0); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); - } - #pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - probs_vec[j] = (probs_vec[j] > low) ? probs_vec[j] * normalizer : 0; - } - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.cast_store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + +} + +template +struct RenormTempStorage +{ + union + { + typename BlockReduce::TempStorage reduce; + typename BlockReduce::TempStorage reduce_int; + typename BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage + reduce_value_count; + } block_prim; + struct + { + float max_val; + float min_val; + union + { + struct + { + float values[2]; + }; + struct + { + int counts[2]; + }; + struct + { + ValueCount pairs[2]; + }; + } block_aggregate; + }; +}; + +template +__global__ void TopPRenormProbKernel( + DType* probs, DType* renormed_prob, float* top_p_arr, float top_p_val, uint32_t d) +{ + const uint32_t bx = blockIdx.x, tx = threadIdx.x; + const uint32_t row_idx = bx; + float p = top_p_arr == nullptr ? 
top_p_val : top_p_arr[bx]; + + extern __shared__ __align__(alignof(RenormTempStorage)) + uint8_t smem_renorm[]; + auto& temp_storage = + reinterpret_cast&>(smem_renorm); + temp_storage.max_val = 0; + vec_t probs_vec; + + float max_val = GetMaxValue>( + probs, row_idx, d, temp_storage); + + double low = 0, high = max_val; + float min_gt_low, max_le_high; + float sum_low = 1; + // f(x) = sum(probs[probs > x]), f(x) is non-increasing + // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high} + // loop invariant: + // - f(low) >= p, f(high) < p + // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high) + // stopping condition + // - f(low) >= p, f(min_gt_low) == f(max_le_high) == f(high) < p + do + { + double pivot_0 = (high + 2 * low) / 3; + double pivot_1 = (2 * high + low) / 3; + + float aggregate_gt_pivot_0 = 0, aggregate_gt_pivot_1 = 0; + min_gt_low = high; + max_le_high = low; +#pragma unroll 2 + for(uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) + { + probs_vec.fill(0); + if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) + { + probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + + tx * VEC_SIZE); + } + + float probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE]; +#pragma unroll + for(uint32_t j = 0; j < VEC_SIZE; ++j) + { + probs_gt_pivot_0[j] = (probs_vec[j] > pivot_0) ? probs_vec[j] : 0; + probs_gt_pivot_1[j] = (probs_vec[j] > pivot_1) ? probs_vec[j] : 0; + + if(probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) + { + min_gt_low = min(min_gt_low, probs_vec[j]); + } + if(probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) + { + max_le_high = max(max_le_high, probs_vec[j]); + } + } + + aggregate_gt_pivot_0 += + BlockReduce(temp_storage.block_prim.reduce) + .Sum(probs_gt_pivot_0); + __syncthreads(); + + aggregate_gt_pivot_1 += + BlockReduce(temp_storage.block_prim.reduce) + .Sum(probs_gt_pivot_1); + __syncthreads(); + } + min_gt_low = + BlockReduce(temp_storage.block_prim.reduce) + .Reduce(min_gt_low, hipcub::Min()); + __syncthreads(); + max_le_high = + BlockReduce(temp_storage.block_prim.reduce) + .Reduce(max_le_high, hipcub::Max()); + if(tx == 0) + { + temp_storage.block_aggregate.values[0] = aggregate_gt_pivot_0; + temp_storage.block_aggregate.values[1] = aggregate_gt_pivot_1; + temp_storage.min_val = min_gt_low; + temp_storage.max_val = max_le_high; + } + __syncthreads(); + aggregate_gt_pivot_0 = temp_storage.block_aggregate.values[0]; + aggregate_gt_pivot_1 = temp_storage.block_aggregate.values[1]; + min_gt_low = temp_storage.min_val; + max_le_high = temp_storage.max_val; + + if(aggregate_gt_pivot_1 >= p) + { + low = pivot_1; + sum_low = aggregate_gt_pivot_1; + } + else if(aggregate_gt_pivot_0 >= p) + { + low = pivot_0; + high = min(pivot_1, max_le_high); + sum_low = aggregate_gt_pivot_0; + } + else + { + high = min(pivot_0, max_le_high); + } + } while(min_gt_low != max_le_high); + + float normalizer = __frcp_rn(max(sum_low, 1e-8)); + + // normalize +#pragma unroll 2 + for(uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) + { + probs_vec.fill(0); + if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) + { + probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); + } +#pragma unroll + for(uint32_t j = 0; j < VEC_SIZE; ++j) + { + probs_vec[j] = (probs_vec[j] > low) ? 
probs_vec[j] * normalizer : 0; + } + if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) + { + probs_vec.cast_store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + + tx * VEC_SIZE); + } + } +} + +template +__global__ void TopKRenormProbKernel( + DType* probs, DType* renormed_prob, IdType* top_k_arr, uint32_t top_k_val, uint32_t d) +{ + const uint32_t bx = blockIdx.x, tx = threadIdx.x; + const uint32_t row_idx = bx; + uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx]; + double pivot = -infinity(), normalizer = 1; + vec_t probs_vec; + if(k < d) + { + extern __shared__ __align__(alignof(RenormTempStorage)) + uint8_t smem_renorm[]; + auto& temp_storage = + reinterpret_cast&>(smem_renorm); + temp_storage.max_val = 0; + + float max_val = GetMaxValue>( + probs, row_idx, d, temp_storage); + + double low = 0, high = max_val; + float min_gt_low, max_le_high; + float sum_low = 1; + // f(x) = len(nonzero(probs > x)), f(x) is non-increasing + // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high} + // loop invariant: + // - f(low) >= k, f(high) < k + // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high) + // stopping condition: min_gt_low == max_le_high + // - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k + do + { + double pivot_0 = (high + 2 * low) / 3; + double pivot_1 = (2 * high + low) / 3; + + ValueCount aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0}; + min_gt_low = high; + max_le_high = low; +#pragma unroll 2 + for(uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) + { + probs_vec.fill(0); + if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) + { + probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + + tx * VEC_SIZE); + } + ValueCount probs_gt_pivot_0_pair[VEC_SIZE], probs_gt_pivot_1_pair[VEC_SIZE]; +#pragma unroll + for(uint32_t j = 0; j < VEC_SIZE; ++j) + { + probs_gt_pivot_0_pair[j] = { + (probs_vec[j] > pivot_0) ? probs_vec[j] : 0, + (probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; + probs_gt_pivot_1_pair[j] = { + (probs_vec[j] > pivot_1) ? 
probs_vec[j] : 0, + (probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; + + if(probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) + { + min_gt_low = min(min_gt_low, probs_vec[j]); + } + if(probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) + { + max_le_high = max(max_le_high, probs_vec[j]); + } + } + + aggregate_gt_pivot_0 += + BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( + temp_storage.block_prim.reduce_value_count) + .Sum(probs_gt_pivot_0_pair); + __syncthreads(); + + aggregate_gt_pivot_1 += + BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( + temp_storage.block_prim.reduce_value_count) + .Sum(probs_gt_pivot_1_pair); + __syncthreads(); + } + min_gt_low = + BlockReduce(temp_storage.block_prim.reduce) + .Reduce(min_gt_low, hipcub::Min()); + __syncthreads(); + max_le_high = + BlockReduce(temp_storage.block_prim.reduce) + .Reduce(max_le_high, hipcub::Max()); + if(tx == 0) + { + temp_storage.block_aggregate.pairs[0] = aggregate_gt_pivot_0; + temp_storage.block_aggregate.pairs[1] = aggregate_gt_pivot_1; + temp_storage.min_val = min_gt_low; + temp_storage.max_val = max_le_high; + } + __syncthreads(); + aggregate_gt_pivot_0 = temp_storage.block_aggregate.pairs[0]; + aggregate_gt_pivot_1 = temp_storage.block_aggregate.pairs[1]; + min_gt_low = temp_storage.min_val; + max_le_high = temp_storage.max_val; + + if(aggregate_gt_pivot_1.count >= k) + { + low = pivot_1; + sum_low = float(aggregate_gt_pivot_1.value); + } + else if(aggregate_gt_pivot_0.count >= k) + { + low = pivot_0; + high = min(pivot_1, max_le_high); + sum_low = float(aggregate_gt_pivot_0.value); + } + else + { + high = min(pivot_0, max_le_high); + } + } while(min_gt_low != max_le_high); + + normalizer = __frcp_rn(max(sum_low, 1e-8)); + pivot = low; + } + + // normalize +#pragma unroll 2 + for(uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) + { + probs_vec.fill(0); + if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) + { + probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); + } +#pragma unroll + for(uint32_t j = 0; j < VEC_SIZE; ++j) + { + probs_vec[j] = (probs_vec[j] > pivot) ? probs_vec[j] * normalizer : 0; + } + if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) + { + probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); - } - } - } - - template - __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* top_k_arr, - uint32_t top_k_val, uint32_t d) { - const uint32_t bx = blockIdx.x, tx = threadIdx.x; - const uint32_t row_idx = bx; - uint32_t k = top_k_arr == nullptr ? 
top_k_val : top_k_arr[bx]; - double pivot = -infinity(), normalizer = 1; - vec_t probs_vec; - if (k < d) { - extern __shared__ __align__(alignof(RenormTempStorage)) - uint8_t smem_renorm[]; - auto& temp_storage = - reinterpret_cast&>(smem_renorm); - temp_storage.max_val = 0; - - float max_val = GetMaxValue>( - probs, row_idx, d, temp_storage); - - double low = 0, high = max_val; - float min_gt_low, max_le_high; - float sum_low = 1; - // f(x) = len(nonzero(probs > x)), f(x) is non-increasing - // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high} - // loop invariant: - // - f(low) >= k, f(high) < k - // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high) - // stopping condition: min_gt_low == max_le_high - // - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k - do { - double pivot_0 = (high + 2 * low) / 3; - double pivot_1 = (2 * high + low) / 3; - - ValueCount aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0}; - min_gt_low = high; - max_le_high = low; - #pragma unroll 2 - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(0); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); - } - ValueCount probs_gt_pivot_0_pair[VEC_SIZE], probs_gt_pivot_1_pair[VEC_SIZE]; - #pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - probs_gt_pivot_0_pair[j] = { - (probs_vec[j] > pivot_0) ? probs_vec[j] : 0, - (probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; - probs_gt_pivot_1_pair[j] = { - (probs_vec[j] > pivot_1) ? probs_vec[j] : 0, - (probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; - - if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { - min_gt_low = min(min_gt_low, probs_vec[j]); - } - if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { - max_le_high = max(max_le_high, probs_vec[j]); - } - } - - aggregate_gt_pivot_0 += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( - temp_storage.block_prim.reduce_value_count) - .Sum(probs_gt_pivot_0_pair); - __syncthreads(); - - aggregate_gt_pivot_1 += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( - temp_storage.block_prim.reduce_value_count) - .Sum(probs_gt_pivot_1_pair); - __syncthreads(); - } - min_gt_low = - BlockReduce(temp_storage.block_prim.reduce) - .Reduce(min_gt_low, hipcub::Min()); - __syncthreads(); - max_le_high = - BlockReduce(temp_storage.block_prim.reduce) - .Reduce(max_le_high, hipcub::Max()); - if (tx == 0) { - temp_storage.block_aggregate.pairs[0] = aggregate_gt_pivot_0; - temp_storage.block_aggregate.pairs[1] = aggregate_gt_pivot_1; - temp_storage.min_val = min_gt_low; - temp_storage.max_val = max_le_high; - } - __syncthreads(); - aggregate_gt_pivot_0 = temp_storage.block_aggregate.pairs[0]; - aggregate_gt_pivot_1 = temp_storage.block_aggregate.pairs[1]; - min_gt_low = temp_storage.min_val; - max_le_high = temp_storage.max_val; - - if (aggregate_gt_pivot_1.count >= k) { - low = pivot_1; - sum_low = float(aggregate_gt_pivot_1.value); - } else if (aggregate_gt_pivot_0.count >= k) { - low = pivot_0; - high = min(pivot_1, max_le_high); - sum_low = float(aggregate_gt_pivot_0.value); - } else { - high = min(pivot_0, max_le_high); - } - } while (min_gt_low != max_le_high); - - normalizer = __frcp_rn(max(sum_low, 1e-8)); - pivot = low; - } - - // normalize - #pragma unroll 2 - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(0); - if ((i * 
BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); - } - #pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; ++j) { - probs_vec[j] = (probs_vec[j] > pivot) ? probs_vec[j] * normalizer : 0; - } - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); - } - } - } - - template - hipError_t TopPRenormProb(DType* probs, DType* renormed_prob, float* top_p_arr, - uint32_t batch_size, float top_p_val, uint32_t d, - hipStream_t stream = 0) { - - const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); - - const uint32_t smem_size = sizeof(RenormTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &renormed_prob, &top_p_arr, &top_p_val, &d}; - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = TopPRenormProbKernel; - hipFuncSetAttribute(reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); - hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); - }); - return hipSuccess; - } - - } // namespace sampling - - } // namespace aiter \ No newline at end of file + } + } +} + +template +hipError_t TopPRenormProb(DType* probs, + DType* renormed_prob, + float* top_p_arr, + uint32_t batch_size, + float top_p_val, + uint32_t d, + hipStream_t stream = 0) +{ + + const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); + + const uint32_t smem_size = sizeof(RenormTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&probs, &renormed_prob, &top_p_arr, &top_p_val, &d}; + DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { + auto kernel = TopPRenormProbKernel; + hipFuncSetAttribute(reinterpret_cast(kernel), + hipFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); + }); + return hipSuccess; +} + +} // namespace sampling + +} // namespace aiter \ No newline at end of file diff --git a/csrc/cpp_itfs/sampling/vec_dtypes.cuh b/csrc/cpp_itfs/sampling/vec_dtypes.cuh index 7cb336dd1c..468d7ae8bb 100644 --- a/csrc/cpp_itfs/sampling/vec_dtypes.cuh +++ b/csrc/cpp_itfs/sampling/vec_dtypes.cuh @@ -15,1541 +15,2146 @@ */ #pragma once - #include - #include - #include - #include - #include - #include - - #include - - #if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) - /* - Hacky workaround for the error below: - - /home/git_repos/glen-amd/flashinfer/include/flashinfer/attention/../vec_dtypes_hip.cuh:200:38: error: use of undeclared identifier '__float2bfloat162_rn'; did you mean '__float22bfloat162_rn'? - 200 | const __hip_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); - | ^~~~~~~~~~~~~~~~~~~~ - | __float22bfloat162_rn - /opt/rocm-6.3.1/lib/llvm/bin/../../../include/hip/amd_detail/amd_hip_bf16.h:574:45: note: '__float22bfloat162_rn' declared here - 574 | __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 __float22bfloat162_rn(const float2 a) { +#include +#include +#include +#include +#include +#include + +#include + +#if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) +/* +Hacky workaround for the error below: + + /home/git_repos/glen-amd/flashinfer/include/flashinfer/attention/../vec_dtypes_hip.cuh:200:38: +error: use of undeclared identifier '__float2bfloat162_rn'; did you mean '__float22bfloat162_rn'? 
+ 200 | const __hip_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); | ^~~~~~~~~~~~~~~~~~~~ | __float22bfloat162_rn + /opt/rocm-6.3.1/lib/llvm/bin/../../../include/hip/amd_detail/amd_hip_bf16.h:574:45: note: +'__float22bfloat162_rn' declared here 574 | __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 +__float22bfloat162_rn(const float2 a) { +*/ +__HOST_DEVICE__ inline __hip_bfloat162 __float2bfloat162_rn(const float a) +{ + return __hip_bfloat162{__float2bfloat16(a), __float2bfloat16(a)}; +} + +inline __attribute__((always_inline)) __device__ __hip_bfloat162 +make_bfloat162(const __hip_bfloat16 x, const __hip_bfloat16 y) +{ + __hip_bfloat162 t; + t.x = x; + t.y = y; + return t; +} +#endif + +namespace aiter { + +#define FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + +/******************* vec_t type cast *******************/ + +template +struct vec_cast +{ + template + inline __attribute__((always_inline)) __device__ static void cast(dst_t* dst, const src_t* src) + { +#pragma unroll + for(size_t i = 0; i < vec_size; ++i) + { + dst[i] = (dst_t)src[i]; + } + } +}; + +template <> +struct vec_cast +{ + template + inline __attribute__((always_inline)) __device__ static void cast(float* dst, const half* src) + { + if constexpr(vec_size == 1) + { + // dst[0] = (float)src[0]; + dst[0] = __half2float(src[0]); + } + else + { +#pragma unroll + for(size_t i = 0; i < vec_size / 2; ++i) + { + ((float2*)dst)[i] = __half22float2(((half2*)src)[i]); + } + } + } +}; + +template <> +struct vec_cast +{ + template + inline __attribute__((always_inline)) __device__ static void cast(half* dst, const float* src) + { + if constexpr(vec_size == 1) + { + dst[0] = __float2half(src[0]); + } + else + { +#pragma unroll + for(size_t i = 0; i < vec_size / 2; ++i) + { + ((half2*)dst)[i] = __float22half2_rn(((float2*)src)[i]); + } + } + } +}; + +template +constexpr inline __attribute__((always_inline)) __device__ int get_exponent_bits() +{ + if constexpr(std::is_same_v) + { + return 4; + } + else if constexpr(std::is_same_v) + { + return 5; + } + else if constexpr(std::is_same_v) + { + return 5; + } + else if constexpr(std::is_same_v) + { + return 8; + } +} + +template +constexpr inline __attribute__((always_inline)) __device__ int get_mantissa_bits() +{ + if constexpr(std::is_same_v) + { + return 3; + } + else if constexpr(std::is_same_v) + { + return 2; + } + else if constexpr(std::is_same_v) + { + return 11; + } + else if constexpr(std::is_same_v) + { + return 7; + } +} + +/*! + * \brief Fallback to software fast dequant implementation if hardware dequantization is not + * available. + * \note Inspired by Marlin's fast dequantization, but here we don't have to permute + * weights order. 
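One detail that makes the special-cased branch in the implementation below (the bare __byte_perm path with no arithmetic) easier to follow: in the standard OCP encodings, e5m2 and binary16 both use a 5-bit exponent with bias 15, so moving an e5m2 byte into the high byte of a 16-bit lane already yields the correct half bit pattern, with the two mantissa bits zero-extended; only the remaining format pairs need the shift by RIGHT_SHIFT and the multiply by 2^BIAS_OFFSET. A worked pair of bit patterns, assuming OCP e5m2:

    e5m2 byte 0x44 -> sign 0, exponent 17, mantissa 0 -> 2^(17-15) = 4.0
    half 0x4400    -> sign 0, exponent 17, mantissa 0 -> 4.0 (same value, no further fix-up)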
+ * \ref + * https://github.com/vllm-project/vllm/blob/6dffa4b0a6120159ef2fe44d695a46817aff65bc/csrc/quantization/fp8/fp8_marlin.cu#L120 */ - __HOST_DEVICE__ inline __hip_bfloat162 __float2bfloat162_rn(const float a) { - return __hip_bfloat162{__float2bfloat16(a), __float2bfloat16(a)}; - } - - inline __attribute__((always_inline)) __device__ __hip_bfloat162 make_bfloat162(const __hip_bfloat16 x, const __hip_bfloat16 y) { - __hip_bfloat162 t; - t.x = x; - t.y = y; - return t; - } - #endif - - namespace aiter { - - #define FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED - - /******************* vec_t type cast *******************/ - - template - struct vec_cast { - template - inline __attribute__((always_inline)) __device__ static void cast(dst_t* dst, const src_t* src) { - #pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - dst[i] = (dst_t)src[i]; - } - } - }; - - template <> - struct vec_cast { - template - inline __attribute__((always_inline)) __device__ static void cast(float* dst, const half* src) { - if constexpr (vec_size == 1) { - // dst[0] = (float)src[0]; - dst[0] = __half2float(src[0]); - } else { - #pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((float2*)dst)[i] = __half22float2(((half2*)src)[i]); - } - } - } - }; - - template <> - struct vec_cast { - template - inline __attribute__((always_inline)) __device__ static void cast(half* dst, const float* src) { - if constexpr (vec_size == 1) { - dst[0] = __float2half(src[0]); - } else { - #pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((half2*)dst)[i] = __float22half2_rn(((float2*)src)[i]); - } - } - } - }; - - template - constexpr inline __attribute__((always_inline)) __device__ int get_exponent_bits() { - if constexpr (std::is_same_v) { - return 4; - } else if constexpr (std::is_same_v) { - return 5; - } else if constexpr (std::is_same_v) { - return 5; - } else if constexpr (std::is_same_v) { - return 8; - } - } - - template - constexpr inline __attribute__((always_inline)) __device__ int get_mantissa_bits() { - if constexpr (std::is_same_v) { - return 3; - } else if constexpr (std::is_same_v) { - return 2; - } else if constexpr (std::is_same_v) { - return 11; - } else if constexpr (std::is_same_v) { - return 7; - } - } - - /*! - * \brief Fallback to software fast dequant implementation if hardware dequantization is not - * available. - * \note Inspired by Marlin's fast dequantization, but here we don't have to permute - * weights order. - * \ref - * https://github.com/vllm-project/vllm/blob/6dffa4b0a6120159ef2fe44d695a46817aff65bc/csrc/quantization/fp8/fp8_marlin.cu#L120 - */ - template - __device__ void fast_dequant_f8f16x4(uint32_t* input, uint2* output) { - uint32_t q = *input; - if constexpr (std::is_same_v && std::is_same_v) { - output->x = __byte_perm(0U, q, 0x5140); - output->y = __byte_perm(0U, q, 0x7362); - } else { - constexpr int FP8_EXPONENT = get_exponent_bits(); - constexpr int FP8_MANTISSA = get_mantissa_bits(); - constexpr int FP16_EXPONENT = get_exponent_bits(); - - constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; - // Calculate MASK for extracting mantissa and exponent - // XXX: duplicate defs of `MASK1` and `MASK2`, - // in the HIP file "include/hip/amd_detail/amd_device_functions.h". 
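As a concrete cross-check of the mask-and-shift arithmetic that the following lines set up, the e4m3-to-half path can be reproduced on the host with plain integer operations. This is a minimal standalone sketch, not part of the diff; it assumes the standard OCP e4m3 encoding (bias 7) and plugs in the per-lane constants the code below derives (mask bits 0x7F00, RIGHT_SHIFT = 1, 2^BIAS_OFFSET = 256):

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    int main()
    {
        uint8_t fp8 = 0x48;                 // e4m3: sign 0, exponent 9, mantissa 0 -> 2^(9-7) = 4.0
        uint16_t lane = uint16_t(fp8) << 8; // byte placed in the high byte of a 16-bit lane (the __byte_perm step)
        uint16_t half_bits = uint16_t((lane & 0x8000) | ((lane & 0x7F00) >> 1)); // keep sign, realign exponent+mantissa
        // Decode the resulting binary16 pattern by hand (normal numbers only).
        int exponent = (half_bits >> 10) & 0x1F;                               // = 9
        int mantissa = half_bits & 0x3FF;                                      // = 0
        float as_half = std::ldexp(1.0f + mantissa / 1024.0f, exponent - 15);  // = 0.015625
        float result = as_half * 256.0f;    // 2^BIAS_OFFSET fixes the bias gap (15 - 7 = 8)
        std::printf("%g\n", result);        // prints 4, the original e4m3 value
        return 0;
    }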
- constexpr int MASK1_orig = 0x80000000; - constexpr int MASK2_orig = MASK1_orig >> (FP8_EXPONENT + FP8_MANTISSA); - constexpr int MASK3 = MASK2_orig & 0x7fffffff; - constexpr int MASK = MASK3 | (MASK3 >> 16); - q = __byte_perm(q, q, 0x1302); - - // Extract and shift FP8 values to FP16 format - uint32_t Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - uint32_t Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); - - constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); - // Construct and apply exponent bias - if constexpr (std::is_same_v) { - const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); - - // Convert to half2 and apply bias - *(half2*)&(output->x) = __hmul2(*reinterpret_cast(&Out1), bias_reg); - *(half2*)&(output->y) = __hmul2(*reinterpret_cast(&Out2), bias_reg); - } else { - constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; - const __hip_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); - // Convert to bfloat162 and apply bias - *(__hip_bfloat162*)&(output->x) = - __hmul2(*reinterpret_cast(&Out1), bias_reg); - *(__hip_bfloat162*)&(output->y) = - __hmul2(*reinterpret_cast(&Out2), bias_reg); - } - } - } - - template <> - struct vec_cast<__hip_bfloat16, __hip_fp8_e4m3_fnuz> { - template - inline __attribute__((always_inline)) __device__ static void cast(__hip_bfloat16* dst, const __hip_fp8_e4m3_fnuz* src) { - if constexpr (vec_size == 1) { - dst[0] = __hip_bfloat16(src[0]); - } else if constexpr (vec_size == 2) { - dst[0] = __hip_bfloat16(src[0]); - dst[1] = __hip_bfloat16(src[1]); - } else { - static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4"); - #pragma unroll - for (uint32_t i = 0; i < vec_size / 4; ++i) { - fast_dequant_f8f16x4<__hip_fp8_e4m3_fnuz, __hip_bfloat16>((uint32_t*)&src[i * 4], - (uint2*)&dst[i * 4]); - } - } - } - }; - - template <> - struct vec_cast<__hip_bfloat16, __hip_fp8_e5m2_fnuz> { - template - inline __attribute__((always_inline)) __device__ static void cast(__hip_bfloat16* dst, const __hip_fp8_e5m2_fnuz* src) { - if constexpr (vec_size == 1) { - dst[0] = __hip_bfloat16(src[0]); - } else if constexpr (vec_size == 2) { - dst[0] = __hip_bfloat16(src[0]); - dst[1] = __hip_bfloat16(src[1]); - } else { - static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4"); - #pragma unroll - for (uint32_t i = 0; i < vec_size / 4; ++i) { - fast_dequant_f8f16x4<__hip_fp8_e5m2_fnuz, __hip_bfloat16>((uint32_t*)&src[i * 4], - (uint2*)&dst[i * 4]); - } - } - } - }; - - #if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) - // Function to convert half-precision to e4m3 - __device__ uint8_t convert_f32_to_e4m3(float val) { - // Define the range of e4m3 - // 1. Minimum representable value for e4m3 - // 2. Binary 1000.000 in e4m3 - // 3. FLT_MIN is not suitable for e4m3 because e4m3 has a much smaller dynamic range. - float min_e4m3 = -8.0f; - // 1. Maximum representable value for e4m3 - // 2. Binary 0111.111 in e4m3 - // FLT_MAX far exceeds the maximum value representable in e4m3. - float max_e4m3 = 7.875f; - - // Saturate the value to the e4m3 range - val = fminf(fmaxf(val, min_e4m3), max_e4m3); - - // Perform conversion - // Decompose into mantissa and exponent - int exp; - float mantissa = frexpf(val, &exp); - - // Encode sign bit - uint8_t sign = (mantissa < 0) ? 
0x80 : 0x00; - - // Normalize mantissa and encode exponent - mantissa = fabsf(mantissa) * 16.0f; // Scale mantissa for e4m3's 3-bit precision - uint8_t exponent = static_cast(exp + 7); // Bias of 7 for e4m3 - - // Quantize mantissa - // Apply round-to-nearest-even to the mantissa - uint8_t quant_mantissa = static_cast(roundf(mantissa)) & 0x07; - - // Combine into 8 bits: [sign][exponent][mantissa] - return sign | (exponent << 3) | quant_mantissa; - } - - __device__ __half2 convert_uint32_to_half2(uint32_t input) { - // Extract the low and high 16 bits - uint16_t low_val = input & 0xFFFF; - uint16_t high_val = (input >> 16) & 0xFFFF; - // Convert to __half - __half low_half = __float2half(static_cast(low_val)); - __half high_half = __float2half(static_cast(high_val)); - // Pack into __half2 - return __halves2half2(low_half, high_half); - } - - - // Convert f16x2 (__half2) to e4m3x2 (packed 16-bit) - __device__ uint16_t convert_f16x2_to_e4m3x2(__half2 x) { - float f32_0 = __half2float(__low2half(x)); - float f32_1 = __half2float(__high2half(x)); - uint8_t e4m3_0 = convert_f32_to_e4m3(f32_0); - uint8_t e4m3_1 = convert_f32_to_e4m3(f32_1); - return (static_cast(e4m3_1) << 8) | e4m3_0; - } - #endif - - template <> - struct vec_cast<__hip_fp8_e4m3_fnuz, half> { - template - inline __attribute__((always_inline)) __device__ static void cast(__hip_fp8_e4m3_fnuz* dst, const half* src) { - #ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED - if constexpr (vec_size == 1) { - dst[0] = __hip_fp8_e4m3_fnuz(src[0]); - } else { - #pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - uint16_t y; - uint32_t x = *(uint32_t*)&src[i * 2]; - __half2 x_h2 = convert_uint32_to_half2(x); - y = convert_f16x2_to_e4m3x2(x_h2); - *(uint16_t*)&dst[i * 2] = y; - } - } - #else - #pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - dst[i] = __hip_fp8_e4m3_fnuz(src[i]); - } - #endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED - } - }; - - #if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) - __device__ uint16_t convert_f16x2_to_e5m2x2(uint32_t x) { - // Unpack the two 16-bit half-precision floats from the input - // Extract lower 16 bits - __half h1 = __ushort_as_half(x & 0xFFFF); - // Extract upper 16 bits - __half h2 = __ushort_as_half((x >> 16) & 0xFFFF); - - - // Define the range of e5m2 - // Minimum representable value for e5m2 - const float min_e5m2 = -8.0f; - // Maximum representable value for e5m2 - const float max_e5m2 = 7.75f; - - // Helper lambda for conversion - auto f32_to_e5m2 = [min_e5m2, max_e5m2](float val) -> uint8_t { - // Saturate the val - val= fminf(fmaxf(val, min_e5m2), max_e5m2); - - // Decompose into mantissa and exponent - int exp; - float mantissa = frexpf(val, &exp); - - // Encode sign bit - uint8_t sign = (mantissa < 0) ? 
0x10 : 0x00; // Sign in bit 4 - mantissa = fabsf(mantissa); - - // Normalize mantissa and encode exponent - mantissa *= 4.0f; // Scale for 2-bit mantissa - uint8_t exponent = static_cast(exp + 7); // Apply bias for e5m2 - - // Apply round-to-nearest-even - uint8_t quant_mantissa = static_cast(roundf(mantissa)) & 0x03; - - // Combine into 5 bits: [sign][exponent][mantissa] - return sign | (exponent << 2) | quant_mantissa; - }; - - // Convert the two __half values to e5m2 - uint8_t e5m2_1 = f32_to_e5m2(__half2float(h1)); - uint8_t e5m2_2 = f32_to_e5m2(__half2float(h2)); - - // Pack the two e5m2 values into a single 16-bit output - return (e5m2_2 << 8) | e5m2_1; - } - #endif - - template <> - struct vec_cast<__hip_fp8_e5m2_fnuz, half> { - template - inline __attribute__((always_inline)) __device__ static void cast(__hip_fp8_e5m2_fnuz* dst, const half* src) { - #ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED - if constexpr (vec_size == 1) { - dst[0] = __hip_fp8_e5m2_fnuz(src[0]); - } else { - #pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - uint16_t y; - uint32_t x = *(uint32_t*)&src[i * 2]; - y = convert_f16x2_to_e5m2x2(x); - *(uint16_t*)&dst[i * 2] = y; - } - } - #else - #pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - dst[i] = __hip_fp8_e5m2_fnuz(src[i]); - } - #endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED - } - }; - - #if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) - __device__ uint32_t convert_e4m3x2_to_f16x2(uint16_t x) { - // Extract two e4m3 values from the 16-bit input - uint8_t e4m3_1 = x & 0xFF; // Lower 8 bits - uint8_t e4m3_2 = (x >> 8) & 0xFF; // Upper 8 bits - - // Decode e4m3 to float - auto e4m3_to_f32 = [](uint8_t e4m3) -> float { - // Extract sign, exponent, and mantissa - int sign = (e4m3 & 0x80) ? 
-1 : 1; - int exponent = ((e4m3 >> 3) & 0x0F) - 7; // 4-bit exponent with bias 7 - int mantissa = e4m3 & 0x07; // 3-bit mantissa - - // Handle special case: zero - if (exponent == -7 && mantissa == 0) { - return 0.0f; - } - - // Convert to float - float f32_val = sign * ldexpf(1.0f + mantissa / 8.0f, exponent); - return f32_val; - }; - - float f1 = e4m3_to_f32(e4m3_1); - float f2 = e4m3_to_f32(e4m3_2); - - // Convert float to IEEE f16 - __half h1 = __float2half_rn(f1); - __half h2 = __float2half_rn(f2); - - // Pack the two f16 values into a single uint32_t - uint32_t f16x2 = (__half_as_ushort(h2) << 16) | __half_as_ushort(h1); - return f16x2; - } - #endif - - template <> - struct vec_cast { - template - inline __attribute__((always_inline)) __device__ static void cast(half* dst, const __hip_fp8_e4m3_fnuz* src) { - #ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED - if constexpr (vec_size == 1) { - dst[0] = half(src[0]); - } else { - #pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - uint32_t y; - uint16_t x = *(uint16_t*)&src[i * 2]; - y = convert_e4m3x2_to_f16x2(x); - *(uint32_t*)&dst[i * 2] = y; - } - } - #else - if constexpr (vec_size == 1) { - dst[0] = half(src[0]); - } else if constexpr (vec_size == 2) { - dst[0] = half(src[0]); - dst[1] = half(src[1]); - } else { - static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4"); - #pragma unroll - for (uint32_t i = 0; i < vec_size / 4; ++i) { - fast_dequant_f8f16x4<__hip_fp8_e4m3_fnuz, half>((uint32_t*)&src[i * 4], (uint2*)&dst[i * 4]); - } - } - #endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED - } - }; - - #if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) - __device__ uint32_t convert_e5m2x2_to_f16x2(uint16_t x) { - // Extract two e5m2 values from the 16-bit input - uint8_t e5m2_1 = x & 0xFF; // Lower 8 bits - uint8_t e5m2_2 = (x >> 8) & 0xFF; // Upper 8 bits - - // Decode e5m2 to float - auto e5m2_to_f32 = [](uint8_t e5m2) -> float { - // Extract sign, exponent, and mantissa - int sign = (e5m2 & 0x80) ? 
-1 : 1; // Sign bit - int exponent = ((e5m2 >> 2) & 0x1F) - 15; // 5-bit exponent with bias 15 - int mantissa = e5m2 & 0x03; // 2-bit mantissa - - // Handle special case: zero - if (exponent == -15 && mantissa == 0) { - return 0.0f; - } - - // Convert to float - float value = sign * ldexpf(1.0f + mantissa / 4.0f, exponent); - return value; - }; - - float f1 = e5m2_to_f32(e5m2_1); - float f2 = e5m2_to_f32(e5m2_2); - - // Convert float to IEEE f16 - __half h1 = __float2half_rn(f1); - __half h2 = __float2half_rn(f2); - - // Pack the two f16 values into a single uint32_t - uint32_t f16x2 = (__half_as_ushort(h2) << 16) | __half_as_ushort(h1); - return f16x2; - } - #endif - - template <> - struct vec_cast { - template - inline __attribute__((always_inline)) __device__ static void cast(half* dst, const __hip_fp8_e5m2_fnuz* src) { - #ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED - if constexpr (vec_size == 1) { - dst[0] = half(src[0]); - } else { - #pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - uint32_t y; - uint16_t x = *(uint16_t*)&src[i * 2]; - y = convert_e5m2x2_to_f16x2(x); - *(uint32_t*)&dst[i * 2] = y; - } - } - #else - if constexpr (vec_size == 1) { - dst[0] = half(src[0]); - } else if constexpr (vec_size == 2) { - dst[0] = half(src[0]); - dst[1] = half(src[1]); - } else { - static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4"); - #pragma unroll - for (uint32_t i = 0; i < vec_size / 4; ++i) { - fast_dequant_f8f16x4<__hip_fp8_e5m2_fnuz, half>((uint32_t*)&src[i * 4], (uint2*)&dst[i * 4]); - } - } - #endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED - } - }; - - template <> - struct vec_cast { - template - inline __attribute__((always_inline)) __device__ static void cast(float* dst, const __hip_bfloat16* src) { - if constexpr (vec_size == 1) { - dst[0] = (float)src[0]; - } else { - #pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((float2*)dst)[i] = __bfloat1622float2(((__hip_bfloat162*)src)[i]); - } - } - } - }; - - template <> - struct vec_cast<__hip_bfloat16, float> { - template - inline __attribute__((always_inline)) __device__ static void cast(__hip_bfloat16* dst, const float* src) { - /*if constexpr (vec_size == 1) { - dst[0] = __hip_bfloat16(src[0]); - } else { - #pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((__hip_bfloat162*)dst)[i] = __float22bfloat162_rn(((float2*)src)[i]); - } - }*/ - //fast but unsafe bfloat conversion... 
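The shortcut in the lines that follow relies on bf16 sharing the sign and exponent layout of binary32, so a value's bf16 bit pattern is just the top 16 bits of the float with the low mantissa bits dropped; on a little-endian target, element 1 of the bf[2] union member aliases exactly those top 16 bits. A minimal host-side equivalent, a sketch rather than anything from this diff (the helper name bf16_bits is made up here), spells the same operation out:

    #include <cstdint>
    #include <cstring>

    // Truncating float -> bf16 bit conversion: keep the sign, the 8 exponent bits
    // and the top 7 mantissa bits, and drop the remaining 16 mantissa bits.
    static inline uint16_t bf16_bits(float x)
    {
        uint32_t u;
        std::memcpy(&u, &x, sizeof(u)); // reinterpret the float's bits
        return static_cast<uint16_t>(u >> 16);
    }
    // Unlike the commented-out __float22bfloat162_rn path above, there is no
    // round-to-nearest-even here, so results can differ by one unit in the last place.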
- union f2bf { float f; __hip_bfloat16 bf[2]; } _f2bf; - #pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - _f2bf.f = src[i]; - dst[i] = _f2bf.bf[1]; - } - } - }; - - template - struct vec_t { - inline __attribute__((always_inline)) __device__ float_t& operator[](size_t i); - inline __attribute__((always_inline)) __device__ const float_t& operator[](size_t i) const; - inline __attribute__((always_inline)) __device__ void fill(float_t val); - inline __attribute__((always_inline)) __device__ void load(const float_t* ptr); - inline __attribute__((always_inline)) __device__ void store(float_t* ptr) const; - template - inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src); - template - inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr); - template - inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const; - inline __attribute__((always_inline)) __device__ static void memcpy(float_t* dst, const float_t* src); - inline __attribute__((always_inline)) __device__ float_t* ptr(); - }; - - template - inline __attribute__((always_inline)) __device__ void cast_from_impl(vec_t& dst, - const vec_t& src) { - vec_cast::template cast( - dst.ptr(), const_cast*>(&src)->ptr()); - } - - template - inline __attribute__((always_inline)) __device__ void cast_load_impl(vec_t& dst, - const src_float_t* src_ptr) { - if constexpr (std::is_same_v) { - dst.load(src_ptr); - } else { - vec_t tmp; - tmp.load(src_ptr); - dst.cast_from(tmp); - } - } - - template - inline __attribute__((always_inline)) __device__ void cast_store_impl(tgt_float_t* dst_ptr, - const vec_t& src) { - if constexpr (std::is_same_v) { - src.store(dst_ptr); - } else { - vec_t tmp; - tmp.cast_from(src); - tmp.store(dst_ptr); - } - } - - /******************* vec_t<__hip_fp8_e4m3_fnuz> *******************/ - - // __hip_fp8_e4m3_fnuz x 1 - template <> - struct vec_t<__hip_fp8_e4m3_fnuz, 1> { - __hip_fp8_e4m3_fnuz data; - - inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz& operator[](size_t i) { return ((__hip_fp8_e4m3_fnuz*)(&data))[i]; } - inline __attribute__((always_inline)) __device__ const __hip_fp8_e4m3_fnuz& operator[](size_t i) const { - return ((const __hip_fp8_e4m3_fnuz*)(&data))[i]; - } - inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz* ptr() { return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); } - inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e4m3_fnuz val); - inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e4m3_fnuz* ptr); - inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e4m3_fnuz* ptr) const; - template - inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - - inline __attribute__((always_inline)) __device__ static void memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src); - }; - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 1>::fill(__hip_fp8_e4m3_fnuz val) { data = val; } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 1>::load(const __hip_fp8_e4m3_fnuz* ptr) { data = *ptr; } - - inline __attribute__((always_inline)) __device__ void 
vec_t<__hip_fp8_e4m3_fnuz, 1>::store(__hip_fp8_e4m3_fnuz* ptr) const { *ptr = data; } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 1>::memcpy(__hip_fp8_e4m3_fnuz* dst, - const __hip_fp8_e4m3_fnuz* src) { - *dst = *src; - } - - // __hip_fp8_e4m3_fnuz x 2 - template <> - struct vec_t<__hip_fp8_e4m3_fnuz, 2> { - __hip_fp8x2_e4m3_fnuz data; - - inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz& operator[](size_t i) { return ((__hip_fp8_e4m3_fnuz*)(&data))[i]; } - inline __attribute__((always_inline)) __device__ const __hip_fp8_e4m3_fnuz& operator[](size_t i) const { - return ((const __hip_fp8_e4m3_fnuz*)(&data))[i]; - } - inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz* ptr() { return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); } - inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e4m3_fnuz val); - inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e4m3_fnuz* ptr); - inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e4m3_fnuz* ptr) const; - template - inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - inline __attribute__((always_inline)) __device__ static void memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src); - }; - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 2>::fill(__hip_fp8_e4m3_fnuz val) { - data.__x = (__hip_fp8x2_storage_t(val.__x) << 8) | __hip_fp8x2_storage_t(val.__x); - } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 2>::load(const __hip_fp8_e4m3_fnuz* ptr) { - data = *((__hip_fp8x2_e4m3_fnuz*)ptr); - } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 2>::store(__hip_fp8_e4m3_fnuz* ptr) const { - *((__hip_fp8x2_e4m3_fnuz*)ptr) = data; - } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 2>::memcpy(__hip_fp8_e4m3_fnuz* dst, - const __hip_fp8_e4m3_fnuz* src) { - *((__hip_fp8x2_e4m3_fnuz*)dst) = *((__hip_fp8x2_e4m3_fnuz*)src); - } - - // __hip_fp8_e4m3_fnuz x 4 - - template <> - struct vec_t<__hip_fp8_e4m3_fnuz, 4> { - __hip_fp8x4_e4m3_fnuz data; - - inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz& operator[](size_t i) { return ((__hip_fp8_e4m3_fnuz*)(&data))[i]; } - inline __attribute__((always_inline)) __device__ const __hip_fp8_e4m3_fnuz& operator[](size_t i) const { - return ((const __hip_fp8_e4m3_fnuz*)(&data))[i]; - } - inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz* ptr() { return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); } - inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e4m3_fnuz val); - inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e4m3_fnuz* ptr); - inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e4m3_fnuz* ptr) const; - template - inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - inline __attribute__((always_inline)) 
__device__ void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - - inline __attribute__((always_inline)) __device__ static void memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src); - }; - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 4>::fill(__hip_fp8_e4m3_fnuz val) { - data.__x = (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); - } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 4>::load(const __hip_fp8_e4m3_fnuz* ptr) { - data = *((__hip_fp8x4_e4m3_fnuz*)ptr); - } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 4>::store(__hip_fp8_e4m3_fnuz* ptr) const { - *((__hip_fp8x4_e4m3_fnuz*)ptr) = data; - } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 4>::memcpy(__hip_fp8_e4m3_fnuz* dst, - const __hip_fp8_e4m3_fnuz* src) { - *((__hip_fp8x4_e4m3_fnuz*)dst) = *((__hip_fp8x4_e4m3_fnuz*)src); - } - - // __hip_fp8_e4m3_fnuz x 8 - - template <> - struct vec_t<__hip_fp8_e4m3_fnuz, 8> { - uint2 data; - - inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz& operator[](size_t i) { return ((__hip_fp8_e4m3_fnuz*)(&data))[i]; } - inline __attribute__((always_inline)) __device__ const __hip_fp8_e4m3_fnuz& operator[](size_t i) const { - return ((const __hip_fp8_e4m3_fnuz*)(&data))[i]; - } - inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz* ptr() { return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); } - inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e4m3_fnuz val); - inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e4m3_fnuz* ptr); - inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e4m3_fnuz* ptr) const; - template - inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - - inline __attribute__((always_inline)) __device__ static void memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src); - }; - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 8>::fill(__hip_fp8_e4m3_fnuz val) { - ((__hip_fp8x4_e4m3_fnuz*)(&data.x))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); - ((__hip_fp8x4_e4m3_fnuz*)(&data.y))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); - } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 8>::load(const __hip_fp8_e4m3_fnuz* ptr) { - data = *((uint2*)ptr); - } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 8>::store(__hip_fp8_e4m3_fnuz* ptr) const { - *((uint2*)ptr) = data; - } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e4m3_fnuz, 8>::memcpy(__hip_fp8_e4m3_fnuz* dst, - const __hip_fp8_e4m3_fnuz* src) { - *((uint2*)dst) = *((uint2*)src); - } - - // __hip_fp8_e4m3_fnuz x 16 or more - template - struct vec_t<__hip_fp8_e4m3_fnuz, 
vec_size> { - uint4 data[vec_size / 16]; - - inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz& operator[](size_t i) { return ((__hip_fp8_e4m3_fnuz*)data)[i]; } - inline __attribute__((always_inline)) __device__ const __hip_fp8_e4m3_fnuz& operator[](size_t i) const { - return ((const __hip_fp8_e4m3_fnuz*)data)[i]; - } - inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz* ptr() { return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); } - inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e4m3_fnuz val) { - #pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((__hip_fp8x4_e4m3_fnuz*)(&(data[i].x)))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); - ((__hip_fp8x4_e4m3_fnuz*)(&(data[i].y)))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); - ((__hip_fp8x4_e4m3_fnuz*)(&(data[i].z)))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); - ((__hip_fp8x4_e4m3_fnuz*)(&(data[i].w)))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); - } - } - inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e4m3_fnuz* ptr) { - #pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - data[i] = ((uint4*)ptr)[i]; - } - } - inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e4m3_fnuz* ptr) const { - #pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4*)ptr)[i] = data[i]; - } - } - template - inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - - inline __attribute__((always_inline)) __device__ static void memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src) { - #pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4*)dst)[i] = ((uint4*)src)[i]; - } - } - }; - - /******************* vec_t<__hip_fp8_e5m2_fnuz> *******************/ - - // __hip_fp8_e5m2_fnuz x 1 - template <> - struct vec_t<__hip_fp8_e5m2_fnuz, 1> { - __hip_fp8_e5m2_fnuz data; - - inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz& operator[](size_t i) { return ((__hip_fp8_e5m2_fnuz*)(&data))[i]; } - inline __attribute__((always_inline)) __device__ const __hip_fp8_e5m2_fnuz& operator[](size_t i) const { - return ((const __hip_fp8_e5m2_fnuz*)(&data))[i]; - } - inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz* ptr() { return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); } - inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e5m2_fnuz val); - inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e5m2_fnuz* ptr); - inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e5m2_fnuz* ptr) const; - template - inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - inline 
__attribute__((always_inline)) __device__ void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - - inline __attribute__((always_inline)) __device__ static void memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src); - }; - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 1>::fill(__hip_fp8_e5m2_fnuz val) { data = val; } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 1>::load(const __hip_fp8_e5m2_fnuz* ptr) { data = *ptr; } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 1>::store(__hip_fp8_e5m2_fnuz* ptr) const { *ptr = data; } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 1>::memcpy(__hip_fp8_e5m2_fnuz* dst, - const __hip_fp8_e5m2_fnuz* src) { - *dst = *src; - } - - // __hip_fp8_e5m2_fnuz x 2 - template <> - struct vec_t<__hip_fp8_e5m2_fnuz, 2> { - __hip_fp8x2_e5m2_fnuz data; - - inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz& operator[](size_t i) { return ((__hip_fp8_e5m2_fnuz*)(&data))[i]; } - inline __attribute__((always_inline)) __device__ const __hip_fp8_e5m2_fnuz& operator[](size_t i) const { - return ((const __hip_fp8_e5m2_fnuz*)(&data))[i]; - } - inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz* ptr() { return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); } - inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e5m2_fnuz val); - inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e5m2_fnuz* ptr); - inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e5m2_fnuz* ptr) const; - template - inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - - inline __attribute__((always_inline)) __device__ static void memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src); - }; - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 2>::fill(__hip_fp8_e5m2_fnuz val) { - data.__x = (__hip_fp8x2_storage_t(val.__x) << 8) | __hip_fp8x2_storage_t(val.__x); - } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 2>::load(const __hip_fp8_e5m2_fnuz* ptr) { - data = *((__hip_fp8x2_e5m2_fnuz*)ptr); - } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 2>::store(__hip_fp8_e5m2_fnuz* ptr) const { - *((__hip_fp8x2_e5m2_fnuz*)ptr) = data; - } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 2>::memcpy(__hip_fp8_e5m2_fnuz* dst, - const __hip_fp8_e5m2_fnuz* src) { - *((__hip_fp8x2_e5m2_fnuz*)dst) = *((__hip_fp8x2_e5m2_fnuz*)src); - } - - // __hip_fp8_e5m2_fnuz x 4 - - template <> - struct vec_t<__hip_fp8_e5m2_fnuz, 4> { - __hip_fp8x4_e5m2_fnuz data; - - inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz& operator[](size_t i) { return ((__hip_fp8_e5m2_fnuz*)(&data))[i]; } - inline __attribute__((always_inline)) __device__ const __hip_fp8_e5m2_fnuz& operator[](size_t i) const { - return ((const __hip_fp8_e5m2_fnuz*)(&data))[i]; - } - inline 
__attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz* ptr() { return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); } - inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e5m2_fnuz val); - inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e5m2_fnuz* ptr); - inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e5m2_fnuz* ptr) const; - template - inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - - inline __attribute__((always_inline)) __device__ static void memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src); - }; - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 4>::fill(__hip_fp8_e5m2_fnuz val) { - data.__x = (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); - } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 4>::load(const __hip_fp8_e5m2_fnuz* ptr) { - data = *((__hip_fp8x4_e5m2_fnuz*)ptr); - } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 4>::store(__hip_fp8_e5m2_fnuz* ptr) const { - *((__hip_fp8x4_e5m2_fnuz*)ptr) = data; - } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 4>::memcpy(__hip_fp8_e5m2_fnuz* dst, - const __hip_fp8_e5m2_fnuz* src) { - *((__hip_fp8x4_e5m2_fnuz*)dst) = *((__hip_fp8x4_e5m2_fnuz*)src); - } - - // __hip_fp8_e5m2_fnuz x 8 - - template <> - struct vec_t<__hip_fp8_e5m2_fnuz, 8> { - uint2 data; - - inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz& operator[](size_t i) { return ((__hip_fp8_e5m2_fnuz*)(&data))[i]; } - inline __attribute__((always_inline)) __device__ const __hip_fp8_e5m2_fnuz& operator[](size_t i) const { - return ((const __hip_fp8_e5m2_fnuz*)(&data))[i]; - } - inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz* ptr() { return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); } - inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e5m2_fnuz val); - inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e5m2_fnuz* ptr); - inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e5m2_fnuz* ptr) const; - template - inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - inline __attribute__((always_inline)) __device__ static void memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src); - }; - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 8>::fill(__hip_fp8_e5m2_fnuz val) { - ((__hip_fp8x4_e5m2_fnuz*)(&data.x))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); - ((__hip_fp8x4_e5m2_fnuz*)(&data.y))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | 
(__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); - } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 8>::load(const __hip_fp8_e5m2_fnuz* ptr) { - data = *((uint2*)ptr); - } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 8>::store(__hip_fp8_e5m2_fnuz* ptr) const { - *((uint2*)ptr) = data; - } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_fp8_e5m2_fnuz, 8>::memcpy(__hip_fp8_e5m2_fnuz* dst, - const __hip_fp8_e5m2_fnuz* src) { - *((uint2*)dst) = *((uint2*)src); - } - - // __hip_fp8_e5m2_fnuz x 16 or more - - template - struct vec_t<__hip_fp8_e5m2_fnuz, vec_size> { - uint4 data[vec_size / 16]; - - inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz& operator[](size_t i) { return ((__hip_fp8_e5m2_fnuz*)data)[i]; } - inline __attribute__((always_inline)) __device__ const __hip_fp8_e5m2_fnuz& operator[](size_t i) const { - return ((const __hip_fp8_e5m2_fnuz*)data)[i]; - } - inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz* ptr() { return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); } - inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e5m2_fnuz val) { - #pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((__hip_fp8x4_e5m2_fnuz*)(&(data[i].x)))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); - ((__hip_fp8x4_e5m2_fnuz*)(&(data[i].y)))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); - ((__hip_fp8x4_e5m2_fnuz*)(&(data[i].z)))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); - ((__hip_fp8x4_e5m2_fnuz*)(&(data[i].w)))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); - } - } - inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e5m2_fnuz* ptr) { - #pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - data[i] = ((uint4*)ptr)[i]; - } - } - inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e5m2_fnuz* ptr) const { - #pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4*)ptr)[i] = data[i]; - } - } - template - inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - inline __attribute__((always_inline)) __device__ static void memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src) { - #pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4*)dst)[i] = ((uint4*)src)[i]; - } - } - }; - - /******************* vec_t *******************/ - - // half x 1 - template <> - struct vec_t { - half data; - - inline __attribute__((always_inline)) __device__ half& operator[](size_t i) { return ((half*)(&data))[i]; } - inline __attribute__((always_inline)) __device__ const half& operator[](size_t i) const { return ((const 
half*)(&data))[i]; } - inline __attribute__((always_inline)) __device__ half* ptr() { return reinterpret_cast(&data); } - inline __attribute__((always_inline)) __device__ void fill(half val); - inline __attribute__((always_inline)) __device__ void load(const half* ptr); - inline __attribute__((always_inline)) __device__ void store(half* ptr) const; - template - inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - - inline __attribute__((always_inline)) __device__ static void memcpy(half* dst, const half* src); - }; - - inline __attribute__((always_inline)) __device__ void vec_t::fill(half val) { data = val; } - - inline __attribute__((always_inline)) __device__ void vec_t::load(const half* ptr) { data = *ptr; } - - inline __attribute__((always_inline)) __device__ void vec_t::store(half* ptr) const { *ptr = data; } - - inline __attribute__((always_inline)) __device__ void vec_t::memcpy(half* dst, const half* src) { *dst = *src; } - - // half x 2 - template <> - struct vec_t { - half2 data; - - inline __attribute__((always_inline)) __device__ half& operator[](size_t i) { return ((half*)(&data))[i]; } - inline __attribute__((always_inline)) __device__ const half& operator[](size_t i) const { return ((const half*)(&data))[i]; } - inline __attribute__((always_inline)) __device__ half* ptr() { return reinterpret_cast(&data); } - inline __attribute__((always_inline)) __device__ void fill(half val); - inline __attribute__((always_inline)) __device__ void load(const half* ptr); - inline __attribute__((always_inline)) __device__ void store(half* ptr) const; - template - inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - - inline __attribute__((always_inline)) __device__ static void memcpy(half* dst, const half* src); - }; - - inline __attribute__((always_inline)) __device__ void vec_t::fill(half val) { data = make_half2(val, val); } - - inline __attribute__((always_inline)) __device__ void vec_t::load(const half* ptr) { data = *((half2*)ptr); } - - inline __attribute__((always_inline)) __device__ void vec_t::store(half* ptr) const { *((half2*)ptr) = data; } - - inline __attribute__((always_inline)) __device__ void vec_t::memcpy(half* dst, const half* src) { - *((half2*)dst) = *((half2*)src); - } - - // half x 4 - - template <> - struct vec_t { - uint2 data; - - inline __attribute__((always_inline)) __device__ half& operator[](size_t i) { return ((half*)(&data))[i]; } - inline __attribute__((always_inline)) __device__ const half& operator[](size_t i) const { return ((const half*)(&data))[i]; } - inline __attribute__((always_inline)) __device__ half* ptr() { return reinterpret_cast(&data); } - inline __attribute__((always_inline)) __device__ void fill(half val); - inline __attribute__((always_inline)) __device__ void load(const half* ptr); - inline __attribute__((always_inline)) __device__ void store(half* ptr) const; - template - inline 
__attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - inline __attribute__((always_inline)) __device__ static void memcpy(half* dst, const half* src); - }; - - inline __attribute__((always_inline)) __device__ void vec_t::fill(half val) { - *(half2*)(&data.x) = make_half2(val, val); - *(half2*)(&data.y) = make_half2(val, val); - } - - inline __attribute__((always_inline)) __device__ void vec_t::load(const half* ptr) { data = *((uint2*)ptr); } - - inline __attribute__((always_inline)) __device__ void vec_t::store(half* ptr) const { *((uint2*)ptr) = data; } - - inline __attribute__((always_inline)) __device__ void vec_t::memcpy(half* dst, const half* src) { - *((uint2*)dst) = *((uint2*)src); - } - - // half x 8 or more - - template - struct vec_t { - uint4 data[vec_size / 8]; - inline __attribute__((always_inline)) __device__ half& operator[](size_t i) { return ((half*)data)[i]; } - inline __attribute__((always_inline)) __device__ const half& operator[](size_t i) const { return ((const half*)data)[i]; } - inline __attribute__((always_inline)) __device__ half* ptr() { return reinterpret_cast(&data); } - inline __attribute__((always_inline)) __device__ void fill(half val) { - #pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - *(half2*)(&(data[i].x)) = make_half2(val, val); - *(half2*)(&(data[i].y)) = make_half2(val, val); - *(half2*)(&(data[i].z)) = make_half2(val, val); - *(half2*)(&(data[i].w)) = make_half2(val, val); - } - } - inline __attribute__((always_inline)) __device__ void load(const half* ptr) { - #pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - data[i] = ((uint4*)ptr)[i]; - } - } - inline __attribute__((always_inline)) __device__ void store(half* ptr) const { - #pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4*)ptr)[i] = data[i]; - } - } - template - inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - inline __attribute__((always_inline)) __device__ static void memcpy(half* dst, const half* src) { - #pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4*)dst)[i] = ((uint4*)src)[i]; - } - } - }; - - /******************* vec_t<__hip_bfloat16> *******************/ - - // __hip_bfloat16 x 1 - template <> - struct vec_t<__hip_bfloat16, 1> { - __hip_bfloat16 data; - inline __attribute__((always_inline)) __device__ __hip_bfloat16& operator[](size_t i) { return ((__hip_bfloat16*)(&data))[i]; } - inline __attribute__((always_inline)) __device__ const __hip_bfloat16& operator[](size_t i) const { - return ((const __hip_bfloat16*)(&data))[i]; - } - inline __attribute__((always_inline)) __device__ __hip_bfloat16* ptr() { return reinterpret_cast<__hip_bfloat16*>(&data); } - inline __attribute__((always_inline)) __device__ void fill(__hip_bfloat16 val); - inline __attribute__((always_inline)) __device__ void load(const __hip_bfloat16* ptr); - inline __attribute__((always_inline)) __device__ void 
store(__hip_bfloat16* ptr) const; - template - inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - inline __attribute__((always_inline)) __device__ static void memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src); - }; - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_bfloat16, 1>::fill(__hip_bfloat16 val) { data = val; } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_bfloat16, 1>::load(const __hip_bfloat16* ptr) { data = *ptr; } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_bfloat16, 1>::store(__hip_bfloat16* ptr) const { *ptr = data; } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_bfloat16, 1>::memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src) { - *dst = *src; - } - - // __hip_bfloat16 x 2 - template <> - struct vec_t<__hip_bfloat16, 2> { - __hip_bfloat162 data; - - inline __attribute__((always_inline)) __device__ __hip_bfloat16& operator[](size_t i) { return ((__hip_bfloat16*)(&data))[i]; } - inline __attribute__((always_inline)) __device__ const __hip_bfloat16& operator[](size_t i) const { - return ((const __hip_bfloat16*)(&data))[i]; - } - inline __attribute__((always_inline)) __device__ __hip_bfloat16* ptr() { return reinterpret_cast<__hip_bfloat16*>(&data); } - inline __attribute__((always_inline)) __device__ void fill(__hip_bfloat16 val); - inline __attribute__((always_inline)) __device__ void load(const __hip_bfloat16* ptr); - inline __attribute__((always_inline)) __device__ void store(__hip_bfloat16* ptr) const; - template - inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - inline __attribute__((always_inline)) __device__ static void memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src); - }; - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_bfloat16, 2>::fill(__hip_bfloat16 val) { - data = make_bfloat162(val, val); - } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_bfloat16, 2>::load(const __hip_bfloat16* ptr) { - data = *((__hip_bfloat162*)ptr); - } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_bfloat16, 2>::store(__hip_bfloat16* ptr) const { - *((__hip_bfloat162*)ptr) = data; - } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_bfloat16, 2>::memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src) { - *((__hip_bfloat162*)dst) = *((__hip_bfloat162*)src); - } - - // __hip_bfloat16 x 4 - - template <> - struct vec_t<__hip_bfloat16, 4> { - uint2 data; - - inline __attribute__((always_inline)) __device__ __hip_bfloat16& operator[](size_t i) { return ((__hip_bfloat16*)(&data))[i]; } - inline __attribute__((always_inline)) __device__ const __hip_bfloat16& operator[](size_t i) const { - return ((const __hip_bfloat16*)(&data))[i]; - } - inline __attribute__((always_inline)) __device__ __hip_bfloat16* ptr() { return reinterpret_cast<__hip_bfloat16*>(&data); } - inline 
__attribute__((always_inline)) __device__ void fill(__hip_bfloat16 val); - inline __attribute__((always_inline)) __device__ void load(const __hip_bfloat16* ptr); - inline __attribute__((always_inline)) __device__ void store(__hip_bfloat16* ptr) const; - template - inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - inline __attribute__((always_inline)) __device__ static void memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src); - }; - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_bfloat16, 4>::fill(__hip_bfloat16 val) { - *(__hip_bfloat162*)(&data.x) = make_bfloat162(val, val); - *(__hip_bfloat162*)(&data.y) = make_bfloat162(val, val); - } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_bfloat16, 4>::load(const __hip_bfloat16* ptr) { - data = *((uint2*)ptr); - } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_bfloat16, 4>::store(__hip_bfloat16* ptr) const { - *((uint2*)ptr) = data; - } - - inline __attribute__((always_inline)) __device__ void vec_t<__hip_bfloat16, 4>::memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src) { - *((uint2*)dst) = *((uint2*)src); - } - - // __hip_bfloat16 x 8 or more - - template - struct vec_t<__hip_bfloat16, vec_size> { - uint4 data[vec_size / 8]; - - inline __attribute__((always_inline)) __device__ __hip_bfloat16& operator[](size_t i) { return ((__hip_bfloat16*)data)[i]; } - inline __attribute__((always_inline)) __device__ const __hip_bfloat16& operator[](size_t i) const { - return ((const __hip_bfloat16*)data)[i]; - } - inline __attribute__((always_inline)) __device__ __hip_bfloat16* ptr() { return reinterpret_cast<__hip_bfloat16*>(&data); } - inline __attribute__((always_inline)) __device__ void fill(__hip_bfloat16 val) { - #pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - *(__hip_bfloat162*)(&(data[i].x)) = make_bfloat162(val, val); - *(__hip_bfloat162*)(&(data[i].y)) = make_bfloat162(val, val); - *(__hip_bfloat162*)(&(data[i].z)) = make_bfloat162(val, val); - *(__hip_bfloat162*)(&(data[i].w)) = make_bfloat162(val, val); - } - } - inline __attribute__((always_inline)) __device__ void load(const __hip_bfloat16* ptr) { - #pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - data[i] = ((uint4*)ptr)[i]; - } - } - inline __attribute__((always_inline)) __device__ void store(__hip_bfloat16* ptr) const { - #pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4*)ptr)[i] = data[i]; - } - } - template - inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - inline __attribute__((always_inline)) __device__ static void memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src) { - #pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4*)dst)[i] = ((uint4*)src)[i]; - } - } - }; - - /******************* vec_t *******************/ - - // float x 1 - - template <> - struct vec_t { - float data; - - inline 
__attribute__((always_inline)) __device__ float& operator[](size_t i) { return ((float*)(&data))[i]; } - inline __attribute__((always_inline)) __device__ const float& operator[](size_t i) const { return ((const float*)(&data))[i]; } - inline __attribute__((always_inline)) __device__ float* ptr() { return reinterpret_cast(&data); } - inline __attribute__((always_inline)) __device__ void fill(float val); - inline __attribute__((always_inline)) __device__ void load(const float* ptr); - inline __attribute__((always_inline)) __device__ void store(float* ptr) const; - template - inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - inline __attribute__((always_inline)) __device__ static void memcpy(float* dst, const float* src); - }; - - inline __attribute__((always_inline)) __device__ void vec_t::fill(float val) { data = val; } - - inline __attribute__((always_inline)) __device__ void vec_t::load(const float* ptr) { data = *ptr; } - - inline __attribute__((always_inline)) __device__ void vec_t::store(float* ptr) const { *ptr = data; } - - inline __attribute__((always_inline)) __device__ void vec_t::memcpy(float* dst, const float* src) { *dst = *src; } - - // float x 2 - - template <> - struct vec_t { - float2 data; - - inline __attribute__((always_inline)) __device__ float& operator[](size_t i) { return ((float*)(&data))[i]; } - inline __attribute__((always_inline)) __device__ const float& operator[](size_t i) const { return ((const float*)(&data))[i]; } - inline __attribute__((always_inline)) __device__ float* ptr() { return reinterpret_cast(&data); } - inline __attribute__((always_inline)) __device__ void fill(float val); - inline __attribute__((always_inline)) __device__ void load(const float* ptr); - inline __attribute__((always_inline)) __device__ void store(float* ptr) const; - template - inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) { - cast_from_impl(*this, src); - } - template - inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) { - cast_load_impl(*this, ptr); - } - template - inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const { - cast_store_impl(ptr, *this); - } - inline __attribute__((always_inline)) __device__ static void memcpy(float* dst, const float* src); - }; - - inline __attribute__((always_inline)) __device__ void vec_t::fill(float val) { data = make_float2(val, val); } - - inline __attribute__((always_inline)) __device__ void vec_t::load(const float* ptr) { data = *((float2*)ptr); } - - inline __attribute__((always_inline)) __device__ void vec_t::store(float* ptr) const { *((float2*)ptr) = data; } - - inline __attribute__((always_inline)) __device__ void vec_t::memcpy(float* dst, const float* src) { - *((float2*)dst) = *((float2*)src); - } - - // float x 4 or more - template - struct vec_t { - float4 data[vec_size / 4]; - - inline __attribute__((always_inline)) __device__ float& operator[](size_t i) { return ((float*)(data))[i]; } - inline __attribute__((always_inline)) __device__ const float& operator[](size_t i) const { return ((const float*)(data))[i]; } - inline __attribute__((always_inline)) __device__ float* ptr() { return reinterpret_cast(&data); } - inline 
__attribute__((always_inline)) __device__ void fill(float val) {
-     #pragma unroll
-     for (size_t i = 0; i < vec_size / 4; ++i) {
-       data[i] = make_float4(val, val, val, val);
-     }
-   }
-   inline __attribute__((always_inline)) __device__ void load(const float* ptr) {
-     #pragma unroll
-     for (size_t i = 0; i < vec_size / 4; ++i) {
-       data[i] = ((float4*)ptr)[i];
-     }
-   }
-   inline __attribute__((always_inline)) __device__ void store(float* ptr) const {
-     #pragma unroll
-     for (size_t i = 0; i < vec_size / 4; ++i) {
-       ((float4*)ptr)[i] = data[i];
-     }
-   }
-   template <typename T>
-   inline __attribute__((always_inline)) __device__ void cast_from(const vec_t<T, vec_size>& src) {
-     cast_from_impl(*this, src);
-   }
-   template <typename T>
-   inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) {
-     cast_load_impl(*this, ptr);
-   }
-   template <typename T>
-   inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const {
-     cast_store_impl(ptr, *this);
-   }
-   inline __attribute__((always_inline)) __device__ static void memcpy(float* dst, const float* src) {
-     #pragma unroll
-     for (size_t i = 0; i < vec_size / 4; ++i) {
-       ((float4*)dst)[i] = ((float4*)src)[i];
-     }
-   }
- };
-
- } // namespace flashinfer
+template <typename fp8_type, typename fp16_type>
+__device__ void fast_dequant_f8f16x4(uint32_t* input, uint2* output)
+{
+    uint32_t q = *input;
+    if constexpr(std::is_same_v<fp8_type, __hip_fp8_e5m2_fnuz> && std::is_same_v<fp16_type, half>)
+    {
+        output->x = __byte_perm(0U, q, 0x5140);
+        output->y = __byte_perm(0U, q, 0x7362);
+    }
+    else
+    {
+        constexpr int FP8_EXPONENT = get_exponent_bits<fp8_type>();
+        constexpr int FP8_MANTISSA = get_mantissa_bits<fp8_type>();
+        constexpr int FP16_EXPONENT = get_exponent_bits<fp16_type>();
+
+        constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;
+        // Calculate MASK for extracting mantissa and exponent
+        // XXX: duplicate defs of `MASK1` and `MASK2`,
+        // in the HIP file "include/hip/amd_detail/amd_device_functions.h".
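+        // Worked example (illustrative, for fp8_type = e4m3, fp16_type = half):
+        // FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5, so RIGHT_SHIFT = 1
+        // and the shifts below yield MASK2_orig = 0xFF000000, MASK3 = 0x7F000000,
+        // MASK = 0x7F007F00. (q & 0x80008000) then keeps the two per-lane sign bits
+        // while ((q & MASK) >> RIGHT_SHIFT) moves each exponent/mantissa field into
+        // the half layout; the bias multiply further down corrects the exponent bias.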
+        constexpr int MASK1_orig = 0x80000000;
+        constexpr int MASK2_orig = MASK1_orig >> (FP8_EXPONENT + FP8_MANTISSA);
+        constexpr int MASK3 = MASK2_orig & 0x7fffffff;
+        constexpr int MASK = MASK3 | (MASK3 >> 16);
+        q = __byte_perm(q, q, 0x1302);
+
+        // Extract and shift FP8 values to FP16 format
+        uint32_t Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
+        uint32_t Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);
+
+        constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
+        // Construct and apply exponent bias
+        if constexpr(std::is_same_v<fp16_type, half>)
+        {
+            const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
+
+            // Convert to half2 and apply bias
+            *(half2*)&(output->x) = __hmul2(*reinterpret_cast<const half2*>(&Out1), bias_reg);
+            *(half2*)&(output->y) = __hmul2(*reinterpret_cast<const half2*>(&Out2), bias_reg);
+        }
+        else
+        {
+            constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
+            const __hip_bfloat162 bias_reg =
+                __float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
+            // Convert to bfloat162 and apply bias
+            *(__hip_bfloat162*)&(output->x) =
+                __hmul2(*reinterpret_cast<const __hip_bfloat162*>(&Out1), bias_reg);
+            *(__hip_bfloat162*)&(output->y) =
+                __hmul2(*reinterpret_cast<const __hip_bfloat162*>(&Out2), bias_reg);
+        }
+    }
+}
+
+template <>
+struct vec_cast<__hip_bfloat16, __hip_fp8_e4m3_fnuz>
+{
+    template <size_t vec_size>
+    inline __attribute__((always_inline)) __device__ static void
+    cast(__hip_bfloat16* dst, const __hip_fp8_e4m3_fnuz* src)
+    {
+        if constexpr(vec_size == 1)
+        {
+            dst[0] = __hip_bfloat16(src[0]);
+        }
+        else if constexpr(vec_size == 2)
+        {
+            dst[0] = __hip_bfloat16(src[0]);
+            dst[1] = __hip_bfloat16(src[1]);
+        }
+        else
+        {
+            static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4");
+#pragma unroll
+            for(uint32_t i = 0; i < vec_size / 4; ++i)
+            {
+                fast_dequant_f8f16x4<__hip_fp8_e4m3_fnuz, __hip_bfloat16>((uint32_t*)&src[i * 4],
+                                                                          (uint2*)&dst[i * 4]);
+            }
+        }
+    }
+};
+
+template <>
+struct vec_cast<__hip_bfloat16, __hip_fp8_e5m2_fnuz>
+{
+    template <size_t vec_size>
+    inline __attribute__((always_inline)) __device__ static void
+    cast(__hip_bfloat16* dst, const __hip_fp8_e5m2_fnuz* src)
+    {
+        if constexpr(vec_size == 1)
+        {
+            dst[0] = __hip_bfloat16(src[0]);
+        }
+        else if constexpr(vec_size == 2)
+        {
+            dst[0] = __hip_bfloat16(src[0]);
+            dst[1] = __hip_bfloat16(src[1]);
+        }
+        else
+        {
+            static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4");
+#pragma unroll
+            for(uint32_t i = 0; i < vec_size / 4; ++i)
+            {
+                fast_dequant_f8f16x4<__hip_fp8_e5m2_fnuz, __hip_bfloat16>((uint32_t*)&src[i * 4],
+                                                                          (uint2*)&dst[i * 4]);
+            }
+        }
+    }
+};
+
+#if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__)
+// Function to convert single-precision float to e4m3
+__device__ uint8_t convert_f32_to_e4m3(float val)
+{
+    // Define the range of e4m3
+    // 1. Minimum representable value for e4m3
+    // 2. Binary 1000.000 in e4m3
+    // 3. FLT_MIN is not suitable for e4m3 because e4m3 has a much smaller dynamic range.
+    float min_e4m3 = -8.0f;
+    // 1. Maximum representable value for e4m3
+    // 2. Binary 0111.111 in e4m3
+    // FLT_MAX far exceeds the maximum value representable in e4m3.
+    float max_e4m3 = 7.875f;
+
+    // Saturate the value to the e4m3 range
+    val = fminf(fmaxf(val, min_e4m3), max_e4m3);
+
+    // Perform conversion
+    // Decompose into mantissa and exponent
+    int exp;
+    float mantissa = frexpf(val, &exp);
+
+    // Encode sign bit
+    uint8_t sign = (mantissa < 0) ?
0x80 : 0x00; + + // Normalize mantissa and encode exponent + mantissa = fabsf(mantissa) * 16.0f; // Scale mantissa for e4m3's 3-bit precision + uint8_t exponent = static_cast(exp + 7); // Bias of 7 for e4m3 + + // Quantize mantissa + // Apply round-to-nearest-even to the mantissa + uint8_t quant_mantissa = static_cast(roundf(mantissa)) & 0x07; + + // Combine into 8 bits: [sign][exponent][mantissa] + return sign | (exponent << 3) | quant_mantissa; +} + +__device__ __half2 convert_uint32_to_half2(uint32_t input) +{ + // Extract the low and high 16 bits + uint16_t low_val = input & 0xFFFF; + uint16_t high_val = (input >> 16) & 0xFFFF; + // Convert to __half + __half low_half = __float2half(static_cast(low_val)); + __half high_half = __float2half(static_cast(high_val)); + // Pack into __half2 + return __halves2half2(low_half, high_half); +} + +// Convert f16x2 (__half2) to e4m3x2 (packed 16-bit) +__device__ uint16_t convert_f16x2_to_e4m3x2(__half2 x) +{ + float f32_0 = __half2float(__low2half(x)); + float f32_1 = __half2float(__high2half(x)); + uint8_t e4m3_0 = convert_f32_to_e4m3(f32_0); + uint8_t e4m3_1 = convert_f32_to_e4m3(f32_1); + return (static_cast(e4m3_1) << 8) | e4m3_0; +} +#endif + +template <> +struct vec_cast<__hip_fp8_e4m3_fnuz, half> +{ + template + inline __attribute__((always_inline)) __device__ static void cast(__hip_fp8_e4m3_fnuz* dst, + const half* src) + { +#ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + if constexpr(vec_size == 1) + { + dst[0] = __hip_fp8_e4m3_fnuz(src[0]); + } + else + { +#pragma unroll + for(size_t i = 0; i < vec_size / 2; ++i) + { + uint16_t y; + uint32_t x = *(uint32_t*)&src[i * 2]; + __half2 x_h2 = convert_uint32_to_half2(x); + y = convert_f16x2_to_e4m3x2(x_h2); + *(uint16_t*)&dst[i * 2] = y; + } + } +#else +#pragma unroll + for(size_t i = 0; i < vec_size; ++i) + { + dst[i] = __hip_fp8_e4m3_fnuz(src[i]); + } +#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + } +}; + +#if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) +__device__ uint16_t convert_f16x2_to_e5m2x2(uint32_t x) +{ + // Unpack the two 16-bit half-precision floats from the input + // Extract lower 16 bits + __half h1 = __ushort_as_half(x & 0xFFFF); + // Extract upper 16 bits + __half h2 = __ushort_as_half((x >> 16) & 0xFFFF); + + // Define the range of e5m2 + // Minimum representable value for e5m2 + const float min_e5m2 = -8.0f; + // Maximum representable value for e5m2 + const float max_e5m2 = 7.75f; + + // Helper lambda for conversion + auto f32_to_e5m2 = [min_e5m2, max_e5m2](float val) -> uint8_t { + // Saturate the val + val = fminf(fmaxf(val, min_e5m2), max_e5m2); + + // Decompose into mantissa and exponent + int exp; + float mantissa = frexpf(val, &exp); + + // Encode sign bit + uint8_t sign = (mantissa < 0) ? 
0x10 : 0x00; // Sign in bit 4 + mantissa = fabsf(mantissa); + + // Normalize mantissa and encode exponent + mantissa *= 4.0f; // Scale for 2-bit mantissa + uint8_t exponent = static_cast(exp + 7); // Apply bias for e5m2 + + // Apply round-to-nearest-even + uint8_t quant_mantissa = static_cast(roundf(mantissa)) & 0x03; + + // Combine into 5 bits: [sign][exponent][mantissa] + return sign | (exponent << 2) | quant_mantissa; + }; + + // Convert the two __half values to e5m2 + uint8_t e5m2_1 = f32_to_e5m2(__half2float(h1)); + uint8_t e5m2_2 = f32_to_e5m2(__half2float(h2)); + + // Pack the two e5m2 values into a single 16-bit output + return (e5m2_2 << 8) | e5m2_1; +} +#endif + +template <> +struct vec_cast<__hip_fp8_e5m2_fnuz, half> +{ + template + inline __attribute__((always_inline)) __device__ static void cast(__hip_fp8_e5m2_fnuz* dst, + const half* src) + { +#ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + if constexpr(vec_size == 1) + { + dst[0] = __hip_fp8_e5m2_fnuz(src[0]); + } + else + { +#pragma unroll + for(size_t i = 0; i < vec_size / 2; ++i) + { + uint16_t y; + uint32_t x = *(uint32_t*)&src[i * 2]; + y = convert_f16x2_to_e5m2x2(x); + *(uint16_t*)&dst[i * 2] = y; + } + } +#else +#pragma unroll + for(size_t i = 0; i < vec_size; ++i) + { + dst[i] = __hip_fp8_e5m2_fnuz(src[i]); + } +#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + } +}; + +#if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) +__device__ uint32_t convert_e4m3x2_to_f16x2(uint16_t x) +{ + // Extract two e4m3 values from the 16-bit input + uint8_t e4m3_1 = x & 0xFF; // Lower 8 bits + uint8_t e4m3_2 = (x >> 8) & 0xFF; // Upper 8 bits + + // Decode e4m3 to float + auto e4m3_to_f32 = [](uint8_t e4m3) -> float { + // Extract sign, exponent, and mantissa + int sign = (e4m3 & 0x80) ? 
-1 : 1; + int exponent = ((e4m3 >> 3) & 0x0F) - 7; // 4-bit exponent with bias 7 + int mantissa = e4m3 & 0x07; // 3-bit mantissa + + // Handle special case: zero + if(exponent == -7 && mantissa == 0) + { + return 0.0f; + } + + // Convert to float + float f32_val = sign * ldexpf(1.0f + mantissa / 8.0f, exponent); + return f32_val; + }; + + float f1 = e4m3_to_f32(e4m3_1); + float f2 = e4m3_to_f32(e4m3_2); + + // Convert float to IEEE f16 + __half h1 = __float2half_rn(f1); + __half h2 = __float2half_rn(f2); + + // Pack the two f16 values into a single uint32_t + uint32_t f16x2 = (__half_as_ushort(h2) << 16) | __half_as_ushort(h1); + return f16x2; +} +#endif + +template <> +struct vec_cast +{ + template + inline __attribute__((always_inline)) __device__ static void + cast(half* dst, const __hip_fp8_e4m3_fnuz* src) + { +#ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + if constexpr(vec_size == 1) + { + dst[0] = half(src[0]); + } + else + { +#pragma unroll + for(size_t i = 0; i < vec_size / 2; ++i) + { + uint32_t y; + uint16_t x = *(uint16_t*)&src[i * 2]; + y = convert_e4m3x2_to_f16x2(x); + *(uint32_t*)&dst[i * 2] = y; + } + } +#else + if constexpr(vec_size == 1) + { + dst[0] = half(src[0]); + } + else if constexpr(vec_size == 2) + { + dst[0] = half(src[0]); + dst[1] = half(src[1]); + } + else + { + static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4"); +#pragma unroll + for(uint32_t i = 0; i < vec_size / 4; ++i) + { + fast_dequant_f8f16x4<__hip_fp8_e4m3_fnuz, half>((uint32_t*)&src[i * 4], + (uint2*)&dst[i * 4]); + } + } +#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + } +}; + +#if defined(__HIPCC__) || (defined(__clang__) && defined(__HIP__)) || defined(__HIPCC_RTC__) +__device__ uint32_t convert_e5m2x2_to_f16x2(uint16_t x) +{ + // Extract two e5m2 values from the 16-bit input + uint8_t e5m2_1 = x & 0xFF; // Lower 8 bits + uint8_t e5m2_2 = (x >> 8) & 0xFF; // Upper 8 bits + + // Decode e5m2 to float + auto e5m2_to_f32 = [](uint8_t e5m2) -> float { + // Extract sign, exponent, and mantissa + int sign = (e5m2 & 0x80) ? 
-1 : 1; // Sign bit + int exponent = ((e5m2 >> 2) & 0x1F) - 15; // 5-bit exponent with bias 15 + int mantissa = e5m2 & 0x03; // 2-bit mantissa + + // Handle special case: zero + if(exponent == -15 && mantissa == 0) + { + return 0.0f; + } + + // Convert to float + float value = sign * ldexpf(1.0f + mantissa / 4.0f, exponent); + return value; + }; + + float f1 = e5m2_to_f32(e5m2_1); + float f2 = e5m2_to_f32(e5m2_2); + + // Convert float to IEEE f16 + __half h1 = __float2half_rn(f1); + __half h2 = __float2half_rn(f2); + + // Pack the two f16 values into a single uint32_t + uint32_t f16x2 = (__half_as_ushort(h2) << 16) | __half_as_ushort(h1); + return f16x2; +} +#endif + +template <> +struct vec_cast +{ + template + inline __attribute__((always_inline)) __device__ static void + cast(half* dst, const __hip_fp8_e5m2_fnuz* src) + { +#ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + if constexpr(vec_size == 1) + { + dst[0] = half(src[0]); + } + else + { +#pragma unroll + for(size_t i = 0; i < vec_size / 2; ++i) + { + uint32_t y; + uint16_t x = *(uint16_t*)&src[i * 2]; + y = convert_e5m2x2_to_f16x2(x); + *(uint32_t*)&dst[i * 2] = y; + } + } +#else + if constexpr(vec_size == 1) + { + dst[0] = half(src[0]); + } + else if constexpr(vec_size == 2) + { + dst[0] = half(src[0]); + dst[1] = half(src[1]); + } + else + { + static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4"); +#pragma unroll + for(uint32_t i = 0; i < vec_size / 4; ++i) + { + fast_dequant_f8f16x4<__hip_fp8_e5m2_fnuz, half>((uint32_t*)&src[i * 4], + (uint2*)&dst[i * 4]); + } + } +#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + } +}; + +template <> +struct vec_cast +{ + template + inline __attribute__((always_inline)) __device__ static void cast(float* dst, + const __hip_bfloat16* src) + { + if constexpr(vec_size == 1) + { + dst[0] = (float)src[0]; + } + else + { +#pragma unroll + for(size_t i = 0; i < vec_size / 2; ++i) + { + ((float2*)dst)[i] = __bfloat1622float2(((__hip_bfloat162*)src)[i]); + } + } + } +}; + +template <> +struct vec_cast<__hip_bfloat16, float> +{ + template + inline __attribute__((always_inline)) __device__ static void cast(__hip_bfloat16* dst, + const float* src) + { + /*if constexpr (vec_size == 1) { + dst[0] = __hip_bfloat16(src[0]); + } else { + #pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((__hip_bfloat162*)dst)[i] = __float22bfloat162_rn(((float2*)src)[i]); + } + }*/ + // fast but unsafe bfloat conversion... 
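+        // The union below reinterprets the float's bit pattern: on a little-endian
+        // target bf[1] holds the upper 16 bits of the float, i.e. the bfloat16
+        // truncation of the value (round-toward-zero, no rounding step), which is
+        // why this path is fast but less accurate than the commented-out
+        // __float22bfloat162_rn variant above.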
+        union f2bf
+        {
+            float f;
+            __hip_bfloat16 bf[2];
+        } _f2bf;
+#pragma unroll
+        for(size_t i = 0; i < vec_size; ++i)
+        {
+            _f2bf.f = src[i];
+            dst[i] = _f2bf.bf[1];
+        }
+    }
+};
+
+template <typename float_t, size_t vec_size>
+struct vec_t
+{
+    inline __attribute__((always_inline)) __device__ float_t& operator[](size_t i);
+    inline __attribute__((always_inline)) __device__ const float_t& operator[](size_t i) const;
+    inline __attribute__((always_inline)) __device__ void fill(float_t val);
+    inline __attribute__((always_inline)) __device__ void load(const float_t* ptr);
+    inline __attribute__((always_inline)) __device__ void store(float_t* ptr) const;
+    template <typename T>
+    inline __attribute__((always_inline)) __device__ void cast_from(const vec_t<T, vec_size>& src);
+    template <typename T>
+    inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr);
+    template <typename T>
+    inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const;
+    inline __attribute__((always_inline)) __device__ static void memcpy(float_t* dst,
+                                                                        const float_t* src);
+    inline __attribute__((always_inline)) __device__ float_t* ptr();
+};
+
+template <typename tgt_float_t, typename src_float_t, size_t vec_size>
+inline __attribute__((always_inline)) __device__ void
+cast_from_impl(vec_t<tgt_float_t, vec_size>& dst, const vec_t<src_float_t, vec_size>& src)
+{
+    vec_cast<tgt_float_t, src_float_t>::template cast<vec_size>(
+        dst.ptr(), const_cast<vec_t<src_float_t, vec_size>*>(&src)->ptr());
+}
+
+template <typename tgt_float_t, typename src_float_t, size_t vec_size>
+inline __attribute__((always_inline)) __device__ void
+cast_load_impl(vec_t<tgt_float_t, vec_size>& dst, const src_float_t* src_ptr)
+{
+    if constexpr(std::is_same_v<src_float_t, tgt_float_t>)
+    {
+        dst.load(src_ptr);
+    }
+    else
+    {
+        vec_t<src_float_t, vec_size> tmp;
+        tmp.load(src_ptr);
+        dst.cast_from(tmp);
+    }
+}
+
+template <typename tgt_float_t, typename src_float_t, size_t vec_size>
+inline __attribute__((always_inline)) __device__ void
+cast_store_impl(tgt_float_t* dst_ptr, const vec_t<src_float_t, vec_size>& src)
+{
+    if constexpr(std::is_same_v<src_float_t, tgt_float_t>)
+    {
+        src.store(dst_ptr);
+    }
+    else
+    {
+        vec_t<tgt_float_t, vec_size> tmp;
+        tmp.cast_from(src);
+        tmp.store(dst_ptr);
+    }
+}
+
+/******************* vec_t<__hip_fp8_e4m3_fnuz> *******************/
+
+// __hip_fp8_e4m3_fnuz x 1
+template <>
+struct vec_t<__hip_fp8_e4m3_fnuz, 1>
+{
+    __hip_fp8_e4m3_fnuz data;
+
+    inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz& operator[](size_t i)
+    {
+        return ((__hip_fp8_e4m3_fnuz*)(&data))[i];
+    }
+    inline __attribute__((always_inline)) __device__ const __hip_fp8_e4m3_fnuz&
+    operator[](size_t i) const
+    {
+        return ((const __hip_fp8_e4m3_fnuz*)(&data))[i];
+    }
+    inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz* ptr()
+    {
+        return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data);
+    }
+    inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e4m3_fnuz val);
+    inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e4m3_fnuz* ptr);
+    inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e4m3_fnuz* ptr) const;
+    template <typename T>
+    inline __attribute__((always_inline)) __device__ void cast_from(const vec_t<T, 1>& src)
+    {
+        cast_from_impl(*this, src);
+    }
+    template <typename T>
+    inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr)
+    {
+        cast_load_impl(*this, ptr);
+    }
+    template <typename T>
+    inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const
+    {
+        cast_store_impl(ptr, *this);
+    }
+
+    inline __attribute__((always_inline)) __device__ static void
+    memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src);
+};
+
+inline __attribute__((always_inline)) __device__ void
+vec_t<__hip_fp8_e4m3_fnuz, 1>::fill(__hip_fp8_e4m3_fnuz val)
+{
+    data = val;
+}
+
+inline __attribute__((always_inline)) __device__ void
+vec_t<__hip_fp8_e4m3_fnuz, 1>::load(const __hip_fp8_e4m3_fnuz* ptr)
+{
+    data = *ptr;
+}
+
+inline
__attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 1>::store(__hip_fp8_e4m3_fnuz* ptr) const +{ + *ptr = data; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 1>::memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src) +{ + *dst = *src; +} + +// __hip_fp8_e4m3_fnuz x 2 +template <> +struct vec_t<__hip_fp8_e4m3_fnuz, 2> +{ + __hip_fp8x2_e4m3_fnuz data; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz& operator[](size_t i) + { + return ((__hip_fp8_e4m3_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const __hip_fp8_e4m3_fnuz& + operator[](size_t i) const + { + return ((const __hip_fp8_e4m3_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz* ptr() + { + return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); + } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e4m3_fnuz val); + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e4m3_fnuz* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e4m3_fnuz* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void + memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src); +}; + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 2>::fill(__hip_fp8_e4m3_fnuz val) +{ + data.__x = (__hip_fp8x2_storage_t(val.__x) << 8) | __hip_fp8x2_storage_t(val.__x); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 2>::load(const __hip_fp8_e4m3_fnuz* ptr) +{ + data = *((__hip_fp8x2_e4m3_fnuz*)ptr); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 2>::store(__hip_fp8_e4m3_fnuz* ptr) const +{ + *((__hip_fp8x2_e4m3_fnuz*)ptr) = data; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 2>::memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src) +{ + *((__hip_fp8x2_e4m3_fnuz*)dst) = *((__hip_fp8x2_e4m3_fnuz*)src); +} + +// __hip_fp8_e4m3_fnuz x 4 + +template <> +struct vec_t<__hip_fp8_e4m3_fnuz, 4> +{ + __hip_fp8x4_e4m3_fnuz data; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz& operator[](size_t i) + { + return ((__hip_fp8_e4m3_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const __hip_fp8_e4m3_fnuz& + operator[](size_t i) const + { + return ((const __hip_fp8_e4m3_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz* ptr() + { + return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); + } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e4m3_fnuz val); + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e4m3_fnuz* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e4m3_fnuz* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + 
cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + + inline __attribute__((always_inline)) __device__ static void + memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src); +}; + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 4>::fill(__hip_fp8_e4m3_fnuz val) +{ + data.__x = (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 4>::load(const __hip_fp8_e4m3_fnuz* ptr) +{ + data = *((__hip_fp8x4_e4m3_fnuz*)ptr); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 4>::store(__hip_fp8_e4m3_fnuz* ptr) const +{ + *((__hip_fp8x4_e4m3_fnuz*)ptr) = data; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 4>::memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src) +{ + *((__hip_fp8x4_e4m3_fnuz*)dst) = *((__hip_fp8x4_e4m3_fnuz*)src); +} + +// __hip_fp8_e4m3_fnuz x 8 + +template <> +struct vec_t<__hip_fp8_e4m3_fnuz, 8> +{ + uint2 data; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz& operator[](size_t i) + { + return ((__hip_fp8_e4m3_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const __hip_fp8_e4m3_fnuz& + operator[](size_t i) const + { + return ((const __hip_fp8_e4m3_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz* ptr() + { + return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); + } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e4m3_fnuz val); + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e4m3_fnuz* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e4m3_fnuz* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + + inline __attribute__((always_inline)) __device__ static void + memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src); +}; + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 8>::fill(__hip_fp8_e4m3_fnuz val) +{ + ((__hip_fp8x4_e4m3_fnuz*)(&data.x))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e4m3_fnuz*)(&data.y))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 8>::load(const __hip_fp8_e4m3_fnuz* ptr) +{ + data = *((uint2*)ptr); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 8>::store(__hip_fp8_e4m3_fnuz* ptr) const +{ + *((uint2*)ptr) = data; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e4m3_fnuz, 8>::memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src) +{ + *((uint2*)dst) = 
*((uint2*)src); +} + +// __hip_fp8_e4m3_fnuz x 16 or more +template +struct vec_t<__hip_fp8_e4m3_fnuz, vec_size> +{ + uint4 data[vec_size / 16]; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz& operator[](size_t i) + { + return ((__hip_fp8_e4m3_fnuz*)data)[i]; + } + inline __attribute__((always_inline)) __device__ const __hip_fp8_e4m3_fnuz& + operator[](size_t i) const + { + return ((const __hip_fp8_e4m3_fnuz*)data)[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e4m3_fnuz* ptr() + { + return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); + } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e4m3_fnuz val) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 16; ++i) + { + ((__hip_fp8x4_e4m3_fnuz*)(&(data[i].x)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e4m3_fnuz*)(&(data[i].y)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e4m3_fnuz*)(&(data[i].z)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e4m3_fnuz*)(&(data[i].w)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + } + } + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e4m3_fnuz* ptr) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 16; ++i) + { + data[i] = ((uint4*)ptr)[i]; + } + } + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e4m3_fnuz* ptr) const + { +#pragma unroll + for(size_t i = 0; i < vec_size / 16; ++i) + { + ((uint4*)ptr)[i] = data[i]; + } + } + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + + inline __attribute__((always_inline)) __device__ static void + memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 16; ++i) + { + ((uint4*)dst)[i] = ((uint4*)src)[i]; + } + } +}; + +/******************* vec_t<__hip_fp8_e5m2_fnuz> *******************/ + +// __hip_fp8_e5m2_fnuz x 1 +template <> +struct vec_t<__hip_fp8_e5m2_fnuz, 1> +{ + __hip_fp8_e5m2_fnuz data; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz& operator[](size_t i) + { + return ((__hip_fp8_e5m2_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const __hip_fp8_e5m2_fnuz& + operator[](size_t i) const + { + return ((const __hip_fp8_e5m2_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz* ptr() + { + return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); + } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e5m2_fnuz val); + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e5m2_fnuz* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e5m2_fnuz* ptr) const; + 
template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + + inline __attribute__((always_inline)) __device__ static void + memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src); +}; + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 1>::fill(__hip_fp8_e5m2_fnuz val) +{ + data = val; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 1>::load(const __hip_fp8_e5m2_fnuz* ptr) +{ + data = *ptr; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 1>::store(__hip_fp8_e5m2_fnuz* ptr) const +{ + *ptr = data; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 1>::memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src) +{ + *dst = *src; +} + +// __hip_fp8_e5m2_fnuz x 2 +template <> +struct vec_t<__hip_fp8_e5m2_fnuz, 2> +{ + __hip_fp8x2_e5m2_fnuz data; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz& operator[](size_t i) + { + return ((__hip_fp8_e5m2_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const __hip_fp8_e5m2_fnuz& + operator[](size_t i) const + { + return ((const __hip_fp8_e5m2_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz* ptr() + { + return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); + } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e5m2_fnuz val); + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e5m2_fnuz* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e5m2_fnuz* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + + inline __attribute__((always_inline)) __device__ static void + memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src); +}; + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 2>::fill(__hip_fp8_e5m2_fnuz val) +{ + data.__x = (__hip_fp8x2_storage_t(val.__x) << 8) | __hip_fp8x2_storage_t(val.__x); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 2>::load(const __hip_fp8_e5m2_fnuz* ptr) +{ + data = *((__hip_fp8x2_e5m2_fnuz*)ptr); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 2>::store(__hip_fp8_e5m2_fnuz* ptr) const +{ + *((__hip_fp8x2_e5m2_fnuz*)ptr) = data; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 2>::memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src) +{ + *((__hip_fp8x2_e5m2_fnuz*)dst) = *((__hip_fp8x2_e5m2_fnuz*)src); +} + +// __hip_fp8_e5m2_fnuz x 4 + +template <> +struct vec_t<__hip_fp8_e5m2_fnuz, 4> +{ + __hip_fp8x4_e5m2_fnuz data; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz& operator[](size_t i) + { + return 
((__hip_fp8_e5m2_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const __hip_fp8_e5m2_fnuz& + operator[](size_t i) const + { + return ((const __hip_fp8_e5m2_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz* ptr() + { + return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); + } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e5m2_fnuz val); + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e5m2_fnuz* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e5m2_fnuz* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + + inline __attribute__((always_inline)) __device__ static void + memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src); +}; + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 4>::fill(__hip_fp8_e5m2_fnuz val) +{ + data.__x = (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 4>::load(const __hip_fp8_e5m2_fnuz* ptr) +{ + data = *((__hip_fp8x4_e5m2_fnuz*)ptr); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 4>::store(__hip_fp8_e5m2_fnuz* ptr) const +{ + *((__hip_fp8x4_e5m2_fnuz*)ptr) = data; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 4>::memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src) +{ + *((__hip_fp8x4_e5m2_fnuz*)dst) = *((__hip_fp8x4_e5m2_fnuz*)src); +} + +// __hip_fp8_e5m2_fnuz x 8 + +template <> +struct vec_t<__hip_fp8_e5m2_fnuz, 8> +{ + uint2 data; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz& operator[](size_t i) + { + return ((__hip_fp8_e5m2_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const __hip_fp8_e5m2_fnuz& + operator[](size_t i) const + { + return ((const __hip_fp8_e5m2_fnuz*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz* ptr() + { + return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); + } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e5m2_fnuz val); + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e5m2_fnuz* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e5m2_fnuz* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void + memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src); +}; + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 8>::fill(__hip_fp8_e5m2_fnuz val) +{ + ((__hip_fp8x4_e5m2_fnuz*)(&data.x))->__x = + 
(__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e5m2_fnuz*)(&data.y))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 8>::load(const __hip_fp8_e5m2_fnuz* ptr) +{ + data = *((uint2*)ptr); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 8>::store(__hip_fp8_e5m2_fnuz* ptr) const +{ + *((uint2*)ptr) = data; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_fp8_e5m2_fnuz, 8>::memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src) +{ + *((uint2*)dst) = *((uint2*)src); +} + +// __hip_fp8_e5m2_fnuz x 16 or more + +template +struct vec_t<__hip_fp8_e5m2_fnuz, vec_size> +{ + uint4 data[vec_size / 16]; + + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz& operator[](size_t i) + { + return ((__hip_fp8_e5m2_fnuz*)data)[i]; + } + inline __attribute__((always_inline)) __device__ const __hip_fp8_e5m2_fnuz& + operator[](size_t i) const + { + return ((const __hip_fp8_e5m2_fnuz*)data)[i]; + } + inline __attribute__((always_inline)) __device__ __hip_fp8_e5m2_fnuz* ptr() + { + return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); + } + inline __attribute__((always_inline)) __device__ void fill(__hip_fp8_e5m2_fnuz val) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 16; ++i) + { + ((__hip_fp8x4_e5m2_fnuz*)(&(data[i].x)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e5m2_fnuz*)(&(data[i].y)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e5m2_fnuz*)(&(data[i].z)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e5m2_fnuz*)(&(data[i].w)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + } + } + inline __attribute__((always_inline)) __device__ void load(const __hip_fp8_e5m2_fnuz* ptr) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 16; ++i) + { + data[i] = ((uint4*)ptr)[i]; + } + } + inline __attribute__((always_inline)) __device__ void store(__hip_fp8_e5m2_fnuz* ptr) const + { +#pragma unroll + for(size_t i = 0; i < vec_size / 16; ++i) + { + ((uint4*)ptr)[i] = data[i]; + } + } + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void + memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 16; ++i) + { + ((uint4*)dst)[i] = ((uint4*)src)[i]; + } + } +}; + +/******************* vec_t *******************/ + +// half x 1 
+template <> +struct vec_t +{ + half data; + + inline __attribute__((always_inline)) __device__ half& operator[](size_t i) + { + return ((half*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const half& operator[](size_t i) const + { + return ((const half*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ half* ptr() + { + return reinterpret_cast(&data); + } + inline __attribute__((always_inline)) __device__ void fill(half val); + inline __attribute__((always_inline)) __device__ void load(const half* ptr); + inline __attribute__((always_inline)) __device__ void store(half* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + + inline __attribute__((always_inline)) __device__ static void memcpy(half* dst, const half* src); +}; + +inline __attribute__((always_inline)) __device__ void vec_t::fill(half val) { data = val; } + +inline __attribute__((always_inline)) __device__ void vec_t::load(const half* ptr) +{ + data = *ptr; +} + +inline __attribute__((always_inline)) __device__ void vec_t::store(half* ptr) const +{ + *ptr = data; +} + +inline __attribute__((always_inline)) __device__ void vec_t::memcpy(half* dst, + const half* src) +{ + *dst = *src; +} + +// half x 2 +template <> +struct vec_t +{ + half2 data; + + inline __attribute__((always_inline)) __device__ half& operator[](size_t i) + { + return ((half*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const half& operator[](size_t i) const + { + return ((const half*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ half* ptr() + { + return reinterpret_cast(&data); + } + inline __attribute__((always_inline)) __device__ void fill(half val); + inline __attribute__((always_inline)) __device__ void load(const half* ptr); + inline __attribute__((always_inline)) __device__ void store(half* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + + inline __attribute__((always_inline)) __device__ static void memcpy(half* dst, const half* src); +}; + +inline __attribute__((always_inline)) __device__ void vec_t::fill(half val) +{ + data = make_half2(val, val); +} + +inline __attribute__((always_inline)) __device__ void vec_t::load(const half* ptr) +{ + data = *((half2*)ptr); +} + +inline __attribute__((always_inline)) __device__ void vec_t::store(half* ptr) const +{ + *((half2*)ptr) = data; +} + +inline __attribute__((always_inline)) __device__ void vec_t::memcpy(half* dst, + const half* src) +{ + *((half2*)dst) = *((half2*)src); +} + +// half x 4 + +template <> +struct vec_t +{ + uint2 data; + + inline __attribute__((always_inline)) __device__ half& operator[](size_t i) + { + return ((half*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const half& operator[](size_t i) const + { + return ((const half*)(&data))[i]; + } + inline __attribute__((always_inline)) 
__device__ half* ptr() + { + return reinterpret_cast(&data); + } + inline __attribute__((always_inline)) __device__ void fill(half val); + inline __attribute__((always_inline)) __device__ void load(const half* ptr); + inline __attribute__((always_inline)) __device__ void store(half* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(half* dst, const half* src); +}; + +inline __attribute__((always_inline)) __device__ void vec_t::fill(half val) +{ + *(half2*)(&data.x) = make_half2(val, val); + *(half2*)(&data.y) = make_half2(val, val); +} + +inline __attribute__((always_inline)) __device__ void vec_t::load(const half* ptr) +{ + data = *((uint2*)ptr); +} + +inline __attribute__((always_inline)) __device__ void vec_t::store(half* ptr) const +{ + *((uint2*)ptr) = data; +} + +inline __attribute__((always_inline)) __device__ void vec_t::memcpy(half* dst, + const half* src) +{ + *((uint2*)dst) = *((uint2*)src); +} + +// half x 8 or more + +template +struct vec_t +{ + uint4 data[vec_size / 8]; + inline __attribute__((always_inline)) __device__ half& operator[](size_t i) + { + return ((half*)data)[i]; + } + inline __attribute__((always_inline)) __device__ const half& operator[](size_t i) const + { + return ((const half*)data)[i]; + } + inline __attribute__((always_inline)) __device__ half* ptr() + { + return reinterpret_cast(&data); + } + inline __attribute__((always_inline)) __device__ void fill(half val) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 8; ++i) + { + *(half2*)(&(data[i].x)) = make_half2(val, val); + *(half2*)(&(data[i].y)) = make_half2(val, val); + *(half2*)(&(data[i].z)) = make_half2(val, val); + *(half2*)(&(data[i].w)) = make_half2(val, val); + } + } + inline __attribute__((always_inline)) __device__ void load(const half* ptr) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 8; ++i) + { + data[i] = ((uint4*)ptr)[i]; + } + } + inline __attribute__((always_inline)) __device__ void store(half* ptr) const + { +#pragma unroll + for(size_t i = 0; i < vec_size / 8; ++i) + { + ((uint4*)ptr)[i] = data[i]; + } + } + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(half* dst, const half* src) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 8; ++i) + { + ((uint4*)dst)[i] = ((uint4*)src)[i]; + } + } +}; + +/******************* vec_t<__hip_bfloat16> *******************/ + +// __hip_bfloat16 x 1 +template <> +struct vec_t<__hip_bfloat16, 1> +{ + __hip_bfloat16 data; + inline __attribute__((always_inline)) __device__ __hip_bfloat16& operator[](size_t i) + { + return ((__hip_bfloat16*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const __hip_bfloat16& + operator[](size_t i) const + { + return ((const __hip_bfloat16*)(&data))[i]; + } + 
inline __attribute__((always_inline)) __device__ __hip_bfloat16* ptr() + { + return reinterpret_cast<__hip_bfloat16*>(&data); + } + inline __attribute__((always_inline)) __device__ void fill(__hip_bfloat16 val); + inline __attribute__((always_inline)) __device__ void load(const __hip_bfloat16* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_bfloat16* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(__hip_bfloat16* dst, + const __hip_bfloat16* src); +}; + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_bfloat16, 1>::fill(__hip_bfloat16 val) +{ + data = val; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_bfloat16, 1>::load(const __hip_bfloat16* ptr) +{ + data = *ptr; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_bfloat16, 1>::store(__hip_bfloat16* ptr) const +{ + *ptr = data; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_bfloat16, 1>::memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src) +{ + *dst = *src; +} + +// __hip_bfloat16 x 2 +template <> +struct vec_t<__hip_bfloat16, 2> +{ + __hip_bfloat162 data; + + inline __attribute__((always_inline)) __device__ __hip_bfloat16& operator[](size_t i) + { + return ((__hip_bfloat16*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const __hip_bfloat16& + operator[](size_t i) const + { + return ((const __hip_bfloat16*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_bfloat16* ptr() + { + return reinterpret_cast<__hip_bfloat16*>(&data); + } + inline __attribute__((always_inline)) __device__ void fill(__hip_bfloat16 val); + inline __attribute__((always_inline)) __device__ void load(const __hip_bfloat16* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_bfloat16* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(__hip_bfloat16* dst, + const __hip_bfloat16* src); +}; + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_bfloat16, 2>::fill(__hip_bfloat16 val) +{ + data = make_bfloat162(val, val); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_bfloat16, 2>::load(const __hip_bfloat16* ptr) +{ + data = *((__hip_bfloat162*)ptr); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_bfloat16, 2>::store(__hip_bfloat16* ptr) const +{ + *((__hip_bfloat162*)ptr) = data; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_bfloat16, 2>::memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src) +{ + *((__hip_bfloat162*)dst) = *((__hip_bfloat162*)src); +} + +// __hip_bfloat16 x 4 + +template <> +struct vec_t<__hip_bfloat16, 4> +{ + uint2 data; + + inline 
__attribute__((always_inline)) __device__ __hip_bfloat16& operator[](size_t i) + { + return ((__hip_bfloat16*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const __hip_bfloat16& + operator[](size_t i) const + { + return ((const __hip_bfloat16*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ __hip_bfloat16* ptr() + { + return reinterpret_cast<__hip_bfloat16*>(&data); + } + inline __attribute__((always_inline)) __device__ void fill(__hip_bfloat16 val); + inline __attribute__((always_inline)) __device__ void load(const __hip_bfloat16* ptr); + inline __attribute__((always_inline)) __device__ void store(__hip_bfloat16* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(__hip_bfloat16* dst, + const __hip_bfloat16* src); +}; + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_bfloat16, 4>::fill(__hip_bfloat16 val) +{ + *(__hip_bfloat162*)(&data.x) = make_bfloat162(val, val); + *(__hip_bfloat162*)(&data.y) = make_bfloat162(val, val); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_bfloat16, 4>::load(const __hip_bfloat16* ptr) +{ + data = *((uint2*)ptr); +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_bfloat16, 4>::store(__hip_bfloat16* ptr) const +{ + *((uint2*)ptr) = data; +} + +inline __attribute__((always_inline)) __device__ void +vec_t<__hip_bfloat16, 4>::memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src) +{ + *((uint2*)dst) = *((uint2*)src); +} + +// __hip_bfloat16 x 8 or more + +template +struct vec_t<__hip_bfloat16, vec_size> +{ + uint4 data[vec_size / 8]; + + inline __attribute__((always_inline)) __device__ __hip_bfloat16& operator[](size_t i) + { + return ((__hip_bfloat16*)data)[i]; + } + inline __attribute__((always_inline)) __device__ const __hip_bfloat16& + operator[](size_t i) const + { + return ((const __hip_bfloat16*)data)[i]; + } + inline __attribute__((always_inline)) __device__ __hip_bfloat16* ptr() + { + return reinterpret_cast<__hip_bfloat16*>(&data); + } + inline __attribute__((always_inline)) __device__ void fill(__hip_bfloat16 val) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 8; ++i) + { + *(__hip_bfloat162*)(&(data[i].x)) = make_bfloat162(val, val); + *(__hip_bfloat162*)(&(data[i].y)) = make_bfloat162(val, val); + *(__hip_bfloat162*)(&(data[i].z)) = make_bfloat162(val, val); + *(__hip_bfloat162*)(&(data[i].w)) = make_bfloat162(val, val); + } + } + inline __attribute__((always_inline)) __device__ void load(const __hip_bfloat16* ptr) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 8; ++i) + { + data[i] = ((uint4*)ptr)[i]; + } + } + inline __attribute__((always_inline)) __device__ void store(__hip_bfloat16* ptr) const + { +#pragma unroll + for(size_t i = 0; i < vec_size / 8; ++i) + { + ((uint4*)ptr)[i] = data[i]; + } + } + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) 
__device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(__hip_bfloat16* dst, + const __hip_bfloat16* src) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 8; ++i) + { + ((uint4*)dst)[i] = ((uint4*)src)[i]; + } + } +}; + +/******************* vec_t *******************/ + +// float x 1 + +template <> +struct vec_t +{ + float data; + + inline __attribute__((always_inline)) __device__ float& operator[](size_t i) + { + return ((float*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const float& operator[](size_t i) const + { + return ((const float*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ float* ptr() + { + return reinterpret_cast(&data); + } + inline __attribute__((always_inline)) __device__ void fill(float val); + inline __attribute__((always_inline)) __device__ void load(const float* ptr); + inline __attribute__((always_inline)) __device__ void store(float* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(float* dst, + const float* src); +}; + +inline __attribute__((always_inline)) __device__ void vec_t::fill(float val) +{ + data = val; +} + +inline __attribute__((always_inline)) __device__ void vec_t::load(const float* ptr) +{ + data = *ptr; +} + +inline __attribute__((always_inline)) __device__ void vec_t::store(float* ptr) const +{ + *ptr = data; +} + +inline __attribute__((always_inline)) __device__ void vec_t::memcpy(float* dst, + const float* src) +{ + *dst = *src; +} + +// float x 2 + +template <> +struct vec_t +{ + float2 data; + + inline __attribute__((always_inline)) __device__ float& operator[](size_t i) + { + return ((float*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ const float& operator[](size_t i) const + { + return ((const float*)(&data))[i]; + } + inline __attribute__((always_inline)) __device__ float* ptr() + { + return reinterpret_cast(&data); + } + inline __attribute__((always_inline)) __device__ void fill(float val); + inline __attribute__((always_inline)) __device__ void load(const float* ptr); + inline __attribute__((always_inline)) __device__ void store(float* ptr) const; + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(float* dst, + const float* src); +}; + +inline __attribute__((always_inline)) __device__ void vec_t::fill(float val) +{ + data = make_float2(val, val); +} + +inline __attribute__((always_inline)) __device__ void vec_t::load(const float* ptr) +{ + data = *((float2*)ptr); +} + +inline __attribute__((always_inline)) __device__ void vec_t::store(float* ptr) const +{ + *((float2*)ptr) = data; +} + +inline __attribute__((always_inline)) __device__ void 
vec_t::memcpy(float* dst, + const float* src) +{ + *((float2*)dst) = *((float2*)src); +} + +// float x 4 or more +template +struct vec_t +{ + float4 data[vec_size / 4]; + + inline __attribute__((always_inline)) __device__ float& operator[](size_t i) + { + return ((float*)(data))[i]; + } + inline __attribute__((always_inline)) __device__ const float& operator[](size_t i) const + { + return ((const float*)(data))[i]; + } + inline __attribute__((always_inline)) __device__ float* ptr() + { + return reinterpret_cast(&data); + } + inline __attribute__((always_inline)) __device__ void fill(float val) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 4; ++i) + { + data[i] = make_float4(val, val, val, val); + } + } + inline __attribute__((always_inline)) __device__ void load(const float* ptr) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 4; ++i) + { + data[i] = ((float4*)ptr)[i]; + } + } + inline __attribute__((always_inline)) __device__ void store(float* ptr) const + { +#pragma unroll + for(size_t i = 0; i < vec_size / 4; ++i) + { + ((float4*)ptr)[i] = data[i]; + } + } + template + inline __attribute__((always_inline)) __device__ void cast_from(const vec_t& src) + { + cast_from_impl(*this, src); + } + template + inline __attribute__((always_inline)) __device__ void cast_load(const T* ptr) + { + cast_load_impl(*this, ptr); + } + template + inline __attribute__((always_inline)) __device__ void cast_store(T* ptr) const + { + cast_store_impl(ptr, *this); + } + inline __attribute__((always_inline)) __device__ static void memcpy(float* dst, + const float* src) + { +#pragma unroll + for(size_t i = 0; i < vec_size / 4; ++i) + { + ((float4*)dst)[i] = ((float4*)src)[i]; + } + } +}; + +} // namespace aiter From 812410accf18bb443f1049210efcba7d8b2ca590 Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Mon, 28 Jul 2025 07:08:24 +0000 Subject: [PATCH 15/19] fix deps --- csrc/cpp_itfs/file_baton.py | 57 ------------------------------------- csrc/cpp_itfs/utils.py | 2 +- 2 files changed, 1 insertion(+), 58 deletions(-) delete mode 100644 csrc/cpp_itfs/file_baton.py diff --git a/csrc/cpp_itfs/file_baton.py b/csrc/cpp_itfs/file_baton.py deleted file mode 100644 index 40ed604c97..0000000000 --- a/csrc/cpp_itfs/file_baton.py +++ /dev/null @@ -1,57 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. - -# mypy: allow-untyped-defs -import os -import time -import logging - -logger = logging.getLogger("aiter") - - -class FileBaton: - """A primitive, file-based synchronization utility.""" - - def __init__(self, lock_file_path, wait_seconds=0.2): - """ - Create a new :class:`FileBaton`. - - Args: - lock_file_path: The path to the file used for locking. - wait_seconds: The seconds to periodically sleep (spin) when - calling ``wait()``. - """ - self.lock_file_path = lock_file_path - self.wait_seconds = wait_seconds - self.fd = None - - def try_acquire(self): - """ - Try to atomically create a file under exclusive access. - - Returns: - True if the file could be created, else False. - """ - try: - self.fd = os.open(self.lock_file_path, os.O_CREAT | os.O_EXCL) - return True - except FileExistsError: - return False - - def wait(self): - """ - Periodically sleeps for a certain amount until the baton is released. - - The amount of time slept depends on the ``wait_seconds`` parameter - passed to the constructor. 
- """ - logger.info(f"waiting for baton release at {self.lock_file_path}") - while os.path.exists(self.lock_file_path): - time.sleep(self.wait_seconds) - - def release(self): - """Release the baton and removes its file.""" - if self.fd is not None: - os.close(self.fd) - - os.remove(self.lock_file_path) diff --git a/csrc/cpp_itfs/utils.py b/csrc/cpp_itfs/utils.py index ecf40ca185..9744497193 100644 --- a/csrc/cpp_itfs/utils.py +++ b/csrc/cpp_itfs/utils.py @@ -12,7 +12,7 @@ from functools import lru_cache, partial import binascii import hashlib -from csrc.cpp_itfs.file_baton import FileBaton +from aiter.jit.utils.file_baton import FileBaton import logging import time From 2d28063d9cc62c8fdee9da9290fed4a13c094c69 Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Mon, 28 Jul 2025 07:12:05 +0000 Subject: [PATCH 16/19] fix copyright --- aiter/ops/sampling.py | 3 +++ op_tests/test_sampling.py | 17 ++--------------- 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/aiter/ops/sampling.py b/aiter/ops/sampling.py index 434f8de8b4..0df541569b 100644 --- a/aiter/ops/sampling.py +++ b/aiter/ops/sampling.py @@ -1,3 +1,6 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + import torch from typing import Optional diff --git a/op_tests/test_sampling.py b/op_tests/test_sampling.py index 0fc9435f1c..3fac3ae3bb 100644 --- a/op_tests/test_sampling.py +++ b/op_tests/test_sampling.py @@ -1,18 +1,5 @@ -""" -Copyright (C) 2024-2025 by FlashInfer team. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" +# SPDX-License-Identifier: MIT +# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. import pytest import torch From 646b10278e5a449feb00e333457bb170ca8d4600 Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Mon, 28 Jul 2025 14:08:41 +0000 Subject: [PATCH 17/19] remove useless code --- csrc/cpp_itfs/sampling/sampling.cuh | 669 ---------------------------- 1 file changed, 669 deletions(-) diff --git a/csrc/cpp_itfs/sampling/sampling.cuh b/csrc/cpp_itfs/sampling/sampling.cuh index 4a6afb0fea..b973c634b5 100644 --- a/csrc/cpp_itfs/sampling/sampling.cuh +++ b/csrc/cpp_itfs/sampling/sampling.cuh @@ -41,53 +41,6 @@ constexpr uint32_t BLOCK_THREADS = 1024; constexpr BlockScanAlgorithm SCAN_ALGO = BLOCK_SCAN_WARP_SCANS; constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS; -#define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...) 
\ - switch(aligned_vec_size) \ - { \ - case 16: { \ - constexpr size_t ALIGNED_VEC_SIZE = 16; \ - __VA_ARGS__ \ - break; \ - } \ - case 8: { \ - constexpr size_t ALIGNED_VEC_SIZE = 8; \ - __VA_ARGS__ \ - break; \ - } \ - case 4: { \ - constexpr size_t ALIGNED_VEC_SIZE = 4; \ - __VA_ARGS__ \ - break; \ - } \ - case 2: { \ - constexpr size_t ALIGNED_VEC_SIZE = 2; \ - __VA_ARGS__ \ - break; \ - } \ - case 1: { \ - constexpr size_t ALIGNED_VEC_SIZE = 1; \ - __VA_ARGS__ \ - break; \ - } \ - default: { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \ - throw std::runtime_error(err_msg.str()); \ - } \ - } - -#define DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, ...) \ - if(deterministic) \ - { \ - constexpr bool DETERMINISTIC = true; \ - __VA_ARGS__ \ - } \ - else \ - { \ - constexpr bool DETERMINISTIC = false; \ - __VA_ARGS__ \ - } - template struct ValueCount { @@ -397,358 +350,6 @@ __device__ __forceinline__ void DeviceSamplingFromProb( aggregate += aggregate_local; } -template -struct DataAndIndex -{ - DType data; - IdType index; - - __device__ DataAndIndex operator+(const DataAndIndex& other) const - { - if(data > other.data) - { - return {data, index}; - } - else - { - return {other.data, other.index}; - } - } - __device__ DataAndIndex& operator+=(const DataAndIndex& other) - { - if(data > other.data) - { - return *this; - } - else - { - data = other.data; - index = other.index; - return *this; - } - } -}; - -template -__device__ __forceinline__ vec_t -GenerateGumbelNoise(uint64_t philox_seed, uint64_t philox_offset, uint64_t subsequence) -{ - hiprandStatePhilox4_32_10_t state; - vec_t noise; - constexpr float kEPSILON = 1e-20f; - constexpr float kLOG2 = 0.6931471806f; - auto uniform2gumbel = [](float x) { return -kLOG2 * log2f(-log2f(x + kEPSILON) + kEPSILON); }; -// TODO: compare the speed of log2 and log -#pragma unroll - for(uint32_t i = 0; i + 4 <= VEC_SIZE; i += 4) - { - hiprand_init(philox_seed, subsequence + i, philox_offset, &state); - float4 noise_vec = hiprand_uniform4(&state); - noise[i] = uniform2gumbel(noise_vec.x); - noise[i + 1] = uniform2gumbel(noise_vec.y); - noise[i + 2] = uniform2gumbel(noise_vec.z); - noise[i + 3] = uniform2gumbel(noise_vec.w); - } - if constexpr(VEC_SIZE % 4 != 0) - { - hiprand_init(philox_seed, subsequence + VEC_SIZE / 4 * 4, philox_offset, &state); - float4 noise_vec = hiprand_uniform4(&state); - if constexpr(VEC_SIZE % 4 == 1) - { - noise[VEC_SIZE - 1] = uniform2gumbel(noise_vec.x); - } - else if constexpr(VEC_SIZE % 4 == 2) - { - noise[VEC_SIZE - 2] = uniform2gumbel(noise_vec.x); - noise[VEC_SIZE - 1] = uniform2gumbel(noise_vec.y); - } - else if constexpr(VEC_SIZE % 4 == 3) - { - noise[VEC_SIZE - 3] = uniform2gumbel(noise_vec.x); - noise[VEC_SIZE - 2] = uniform2gumbel(noise_vec.y); - noise[VEC_SIZE - 1] = uniform2gumbel(noise_vec.z); - } - } - - if constexpr(std::is_same_v) - { - return noise; - } - else - { - vec_t ret; -#pragma unroll - for(uint32_t i = 0; i < VEC_SIZE; ++i) - { - ret[i] = static_cast(noise[i]); - } - return ret; - } -} - -template -__global__ void SamplingFromLogitsKernel(DType* logits, - IdType* output, - IdType* indices, - uint32_t d, - uint64_t philox_seed, - uint64_t philox_offset) -{ - const uint32_t bx = blockIdx.x, tx = threadIdx.x; - const uint32_t row_idx = indices == nullptr ? 
bx : indices[bx]; - using SharedMem = typename BlockReduce, - BLOCK_THREADS, - REDUCE_ALGORITHM>::TempStorage; - extern __shared__ __align__(alignof(SharedMem)) uint8_t smem_sampling[]; - auto& temp_storage = reinterpret_cast(smem_sampling); - - vec_t logits_vec; - DataAndIndex max_data = {-infinity(), 0}; - for(uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) - { - logits_vec.fill(-infinity()); - if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) - { - logits_vec.cast_load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + - tx * VEC_SIZE); - } - - vec_t gumbel_noise = GenerateGumbelNoise( - philox_seed, - philox_offset, - static_cast(bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE)); - DataAndIndex cur_data[VEC_SIZE]; -#pragma unroll - for(uint32_t j = 0; j < VEC_SIZE; ++j) - { - cur_data[j].data = (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d - ? logits_vec[j] + gumbel_noise[j] - : -infinity(); - cur_data[j].index = (i * BLOCK_THREADS + tx) * VEC_SIZE + j; - } - - max_data += - BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage) - .Sum(cur_data); - } - if(tx == 0) - { - output[bx] = max_data.index; - } -} - -template -__global__ void SamplingFromProbKernel(DType* probs, - IdType* output, - IdType* indices, - uint32_t d, - uint64_t philox_seed, - uint64_t philox_offset) -{ - hiprandStatePhilox4_32_10_t state; - const uint32_t bx = blockIdx.x, tx = threadIdx.x; - hiprand_init(philox_seed, bx, philox_offset, &state); - const uint32_t row_idx = indices == nullptr ? bx : indices[bx]; - - extern __shared__ __align__( - alignof(SamplingTempStorage)) - uint8_t smem_sampling[]; - auto& temp_storage = - reinterpret_cast&>( - smem_sampling); - temp_storage.sampled_id = d; - __syncthreads(); - - vec_t probs_vec; - float aggregate(0); - float u = hiprand_uniform(&state); - -#pragma unroll 2 - for(uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) - { - probs_vec.fill(0); - if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) - { - probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); - } - - DeviceSamplingFromProb( - i, d, [](float x) { return x > 0; }, u, probs_vec, aggregate, &temp_storage); - if(float(aggregate) > u) - { - break; - } - } - int sampled_id = temp_storage.sampled_id; - if(sampled_id == d) - { - // NOTE(Zihao): this would happen when u is very close to 1 - // and the sum of probabilities is smaller than u - // In this case, we use the last valid index as the sampled id - sampled_id = temp_storage.last_valid_id; - } - output[bx] = sampled_id; -} - -template -__global__ void TopKSamplingFromProbKernel(DType* probs, - IdType* output, - IdType* indices, - IdType* top_k_arr, - uint32_t top_k_val, - uint32_t d, - uint64_t philox_seed, - uint64_t philox_offset) -{ - const uint32_t batch_size = gridDim.x; - const uint32_t bx = blockIdx.x, tx = threadIdx.x; - hiprandStatePhilox4_32_10_t state; - hiprand_init(philox_seed, bx, philox_offset, &state); - const uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx]; - const uint32_t row_idx = indices == nullptr ? 
bx : indices[bx]; - - extern __shared__ __align__( - alignof(SamplingTempStorage)) - uint8_t smem_sampling[]; - auto& temp_storage = - reinterpret_cast&>( - smem_sampling); - - vec_t probs_vec; - float aggregate; - float q = 1; - double low = 0, high = 1.f; - int sampled_id; - int round = 0; - do - { - round += 1; - temp_storage.sampled_id = d; - __syncthreads(); - float u = hiprand_uniform(&state) * q; - aggregate = 0; -#pragma unroll 2 - for(uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) - { - probs_vec.fill(0); - if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) - { - probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); - } - - DeviceSamplingFromProb( - i, d, [&](float x) { return x > low; }, u, probs_vec, aggregate, &temp_storage); - if(aggregate > u) - { - break; - } - } - __syncthreads(); - sampled_id = temp_storage.sampled_id; - if(sampled_id == d) - { - // NOTE(Zihao): this would happen when u is very close to 1 - // and the sum of probabilities is smaller than u - // In this case, we use the last valid index as the sampled id - sampled_id = temp_storage.last_valid_id; - } - double pivot_0 = probs[row_idx * d + sampled_id]; - double pivot_1 = (pivot_0 + high) / 2; - - ValueCount aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0}; -#pragma unroll 2 - for(uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) - { - probs_vec.fill(0); - if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) - { - probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); - } - - ValueCount probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE]; -#pragma unroll - for(uint32_t j = 0; j < VEC_SIZE; ++j) - { - probs_gt_pivot_0[j] = { - (probs_vec[j] > pivot_0) ? probs_vec[j] : 0, - (probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; - probs_gt_pivot_1[j] = { - (probs_vec[j] > pivot_1) ? 
probs_vec[j] : 0, - (probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; - } - - aggregate_gt_pivot_0 += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( - temp_storage.block_prim.reduce_value_count) - .Sum(probs_gt_pivot_0); - if(tx == 0) - { - temp_storage.block_aggregate.pair = aggregate_gt_pivot_0; - } - __syncthreads(); - aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair; - - aggregate_gt_pivot_1 += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( - temp_storage.block_prim.reduce_value_count) - .Sum(probs_gt_pivot_1); - if(tx == 0) - { - temp_storage.block_aggregate.pair = aggregate_gt_pivot_1; - } - __syncthreads(); - aggregate_gt_pivot_1 = temp_storage.block_aggregate.pair; - } - if(aggregate_gt_pivot_0.count < k) - { - // case 1: pivot_0 accepted - break; - } - if(aggregate_gt_pivot_1.count < k) - { - // case 2: pivot_0 rejected, pivot_1 accepted - low = pivot_0; - high = pivot_1; - q = aggregate_gt_pivot_0.value; - } - else - { - // case 3: pivot_0 rejected, pivot_1 rejected - low = pivot_1; - q = aggregate_gt_pivot_1.value; - } - } while(low < high); - __syncthreads(); - if(tx == 0) - { - output[bx] = sampled_id; - } -} - template -hipError_t SamplingFromLogits(T* logits, - IdType* output, - IdType* indices, - uint32_t batch_size, - uint32_t d, - bool deterministic, - uint64_t philox_seed, - uint64_t philox_offset, - hipStream_t stream = 0) -{ - - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&logits, &output, &indices, &d, &philox_seed, &philox_offset}; - const uint32_t smem_size = sizeof( - typename BlockReduce, BLOCK_THREADS, REDUCE_ALGO>::TempStorage); - - DISPATCH_ALIGNED_VEC_SIZE( - vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = SamplingFromLogitsKernel; - hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); - })}); - return hipSuccess; -} - -template -hipError_t SamplingFromProb(T* probs, - IdType* output, - IdType* indices, - uint32_t batch_size, - uint32_t d, - bool deterministic, - uint64_t philox_seed, - uint64_t philox_offset, - hipStream_t stream = 0) -{ - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &output, &indices, &d, &philox_seed, &philox_offset, &d}; - const uint32_t smem_size = sizeof(SamplingTempStorage); - - DISPATCH_ALIGNED_VEC_SIZE( - vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = SamplingFromProbKernel; - - hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); - })}); - return hipSuccess; -} - -template -hipError_t TopKSamplingFromProb(T* probs, - IdType* output, - IdType* indices, - T* top_k_arr, - uint32_t batch_size, - uint32_t top_k_val, - uint32_t d, - bool deterministic, - uint64_t philox_seed, - uint64_t philox_offset, - hipStream_t stream = 0) -{ - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - const uint32_t smem_size = sizeof(SamplingTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = { - &probs, &output, &indices, &top_k_arr, &top_k_val, &d, &philox_seed, &philox_offset}; - - DISPATCH_ALIGNED_VEC_SIZE( - vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = TopKSamplingFromProbKernel; - - hipFuncSetAttribute(reinterpret_cast(kernel), - hipFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - - hipLaunchKernel((void*)kernel, nblks, nthrs, args, 
smem_size, stream); - })}); - return hipSuccess; -} - template struct RenormTempStorage { @@ -1166,146 +663,6 @@ struct RenormTempStorage }; }; -template -__global__ void TopPRenormProbKernel( - DType* probs, DType* renormed_prob, float* top_p_arr, float top_p_val, uint32_t d) -{ - const uint32_t bx = blockIdx.x, tx = threadIdx.x; - const uint32_t row_idx = bx; - float p = top_p_arr == nullptr ? top_p_val : top_p_arr[bx]; - - extern __shared__ __align__(alignof(RenormTempStorage)) - uint8_t smem_renorm[]; - auto& temp_storage = - reinterpret_cast&>(smem_renorm); - temp_storage.max_val = 0; - vec_t probs_vec; - - float max_val = GetMaxValue>( - probs, row_idx, d, temp_storage); - - double low = 0, high = max_val; - float min_gt_low, max_le_high; - float sum_low = 1; - // f(x) = sum(probs[probs > x]), f(x) is non-increasing - // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high} - // loop invariant: - // - f(low) >= p, f(high) < p - // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high) - // stopping condition - // - f(low) >= p, f(min_gt_low) == f(max_le_high) == f(high) < p - do - { - double pivot_0 = (high + 2 * low) / 3; - double pivot_1 = (2 * high + low) / 3; - - float aggregate_gt_pivot_0 = 0, aggregate_gt_pivot_1 = 0; - min_gt_low = high; - max_le_high = low; -#pragma unroll 2 - for(uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) - { - probs_vec.fill(0); - if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) - { - probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + - tx * VEC_SIZE); - } - - float probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE]; -#pragma unroll - for(uint32_t j = 0; j < VEC_SIZE; ++j) - { - probs_gt_pivot_0[j] = (probs_vec[j] > pivot_0) ? probs_vec[j] : 0; - probs_gt_pivot_1[j] = (probs_vec[j] > pivot_1) ? 
probs_vec[j] : 0; - - if(probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) - { - min_gt_low = min(min_gt_low, probs_vec[j]); - } - if(probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) - { - max_le_high = max(max_le_high, probs_vec[j]); - } - } - - aggregate_gt_pivot_0 += - BlockReduce(temp_storage.block_prim.reduce) - .Sum(probs_gt_pivot_0); - __syncthreads(); - - aggregate_gt_pivot_1 += - BlockReduce(temp_storage.block_prim.reduce) - .Sum(probs_gt_pivot_1); - __syncthreads(); - } - min_gt_low = - BlockReduce(temp_storage.block_prim.reduce) - .Reduce(min_gt_low, hipcub::Min()); - __syncthreads(); - max_le_high = - BlockReduce(temp_storage.block_prim.reduce) - .Reduce(max_le_high, hipcub::Max()); - if(tx == 0) - { - temp_storage.block_aggregate.values[0] = aggregate_gt_pivot_0; - temp_storage.block_aggregate.values[1] = aggregate_gt_pivot_1; - temp_storage.min_val = min_gt_low; - temp_storage.max_val = max_le_high; - } - __syncthreads(); - aggregate_gt_pivot_0 = temp_storage.block_aggregate.values[0]; - aggregate_gt_pivot_1 = temp_storage.block_aggregate.values[1]; - min_gt_low = temp_storage.min_val; - max_le_high = temp_storage.max_val; - - if(aggregate_gt_pivot_1 >= p) - { - low = pivot_1; - sum_low = aggregate_gt_pivot_1; - } - else if(aggregate_gt_pivot_0 >= p) - { - low = pivot_0; - high = min(pivot_1, max_le_high); - sum_low = aggregate_gt_pivot_0; - } - else - { - high = min(pivot_0, max_le_high); - } - } while(min_gt_low != max_le_high); - - float normalizer = __frcp_rn(max(sum_low, 1e-8)); - - // normalize -#pragma unroll 2 - for(uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) - { - probs_vec.fill(0); - if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) - { - probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); - } -#pragma unroll - for(uint32_t j = 0; j < VEC_SIZE; ++j) - { - probs_vec[j] = (probs_vec[j] > low) ? 
probs_vec[j] * normalizer : 0; - } - if((i * BLOCK_THREADS + tx) * VEC_SIZE < d) - { - probs_vec.cast_store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + - tx * VEC_SIZE); - } - } -} - template -hipError_t TopPRenormProb(DType* probs, - DType* renormed_prob, - float* top_p_arr, - uint32_t batch_size, - float top_p_val, - uint32_t d, - hipStream_t stream = 0) -{ - - const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); - - const uint32_t smem_size = sizeof(RenormTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &renormed_prob, &top_p_arr, &top_p_val, &d}; - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = TopPRenormProbKernel; - hipFuncSetAttribute(reinterpret_cast(kernel), - hipFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream); - }); - return hipSuccess; -} - } // namespace sampling } // namespace aiter \ No newline at end of file From 8c0f47de0779bdb594f57911cbbd5b2e97d8f5f5 Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Tue, 5 Aug 2025 07:27:13 +0000 Subject: [PATCH 18/19] add python cli --- csrc/cpp_itfs/sampling/top_k_renorm_probs.py | 10 ++++++++++ .../sampling/top_k_top_p_sampling_from_probs.py | 13 ++++++++++++- csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py | 13 ++++++++++++- 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/csrc/cpp_itfs/sampling/top_k_renorm_probs.py b/csrc/cpp_itfs/sampling/top_k_renorm_probs.py index cc4722811c..cfc816798f 100644 --- a/csrc/cpp_itfs/sampling/top_k_renorm_probs.py +++ b/csrc/cpp_itfs/sampling/top_k_renorm_probs.py @@ -74,3 +74,13 @@ def top_k_renorm_probs( stream, ) return renorm_probs + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--d", type=int, required=True) + parser.add_argument("--folder", type=str, default=None) + args = parser.parse_args() + compile(**vars(args)) diff --git a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py index 1af0db5f43..48fbe6e6f3 100644 --- a/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py +++ b/csrc/cpp_itfs/sampling/top_k_top_p_sampling_from_probs.py @@ -3,7 +3,7 @@ from jinja2 import Template -from csrc.cpp_itfs.utils import compile_template_op, AITER_CORE_DIR +from csrc.cpp_itfs.utils import compile_template_op, AITER_CORE_DIR, str_to_bool MD_NAME = "top_k_top_p_sampling_from_probs" @@ -102,3 +102,14 @@ def top_k_top_p_sampling_from_probs( stream, ) return output + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--d", type=int, required=True) + parser.add_argument("--deterministic", type=str_to_bool, required=True) + parser.add_argument("--folder", type=str, default=None) + args = parser.parse_args() + compile(**vars(args)) diff --git a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py index 1d107869d0..7e1500b231 100644 --- a/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py +++ b/csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py @@ -3,7 +3,7 @@ from jinja2 import Template -from csrc.cpp_itfs.utils import compile_template_op, AITER_CORE_DIR +from csrc.cpp_itfs.utils import compile_template_op, AITER_CORE_DIR, str_to_bool MD_NAME = "top_p_sampling_from_probs" @@ -92,3 +92,14 @@ def top_p_sampling_from_probs( stream, ) return samples + + +if __name__ == "__main__": + import 
argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--d", type=int, required=True) + parser.add_argument("--deterministic", type=str_to_bool, required=True) + parser.add_argument("--folder", type=str, default=None) + args = parser.parse_args() + compile(**vars(args)) From b9a024e3f4142283bfd6dda432542e463756e505 Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Tue, 5 Aug 2025 08:29:25 +0000 Subject: [PATCH 19/19] remove useless files --- aiter/aot/triton/compile.cpp | 72 -------- aiter/aot/triton/compile.h | 14 -- aiter/aot/triton/compile.py | 307 ----------------------------------- 3 files changed, 393 deletions(-) delete mode 100644 aiter/aot/triton/compile.cpp delete mode 100644 aiter/aot/triton/compile.h delete mode 100644 aiter/aot/triton/compile.py diff --git a/aiter/aot/triton/compile.cpp b/aiter/aot/triton/compile.cpp deleted file mode 100644 index 124a305ced..0000000000 --- a/aiter/aot/triton/compile.cpp +++ /dev/null @@ -1,72 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -/* clang-format off */ -#include -#include -#include -#include -#include - -// helpers to check for hip errors -#define HIP_CHECK(ans) {{\ - gpuAssert((ans), __FILE__, __LINE__);\ - }}\ - -static inline void gpuAssert(hipError_t code, const char *file, int line) {{ - if (code != hipSuccess) {{ - const char *prefix = "Triton Error [HIP]: "; - const char *str; - hipDrvGetErrorString(code, &str); - char err[1024] = {{0}}; - strcat(err, prefix); - strcat(err, str); - printf("%s\\n", err); - exit(code); - }} -}} - -// globals -#define HSACO_NAME {kernel_name}_hsaco -hipModule_t {kernel_name}_mod = nullptr; -hipFunction_t {kernel_name}_func = nullptr; -unsigned char HSACO_NAME[{bin_size}] = {{ {bin_data} }}; - - -void unload_{kernel_name}(void) {{ - HIP_CHECK(hipModuleUnload({kernel_name}_mod)); -}} - - -void load_{kernel_name}() {{ - int dev = 0; - void *bin = (void *)&HSACO_NAME; - int shared = {shared}; - HIP_CHECK(hipModuleLoadData(&{kernel_name}_mod, bin)); - HIP_CHECK(hipModuleGetFunction(&{kernel_name}_func, {kernel_name}_mod, "{triton_kernel_name}")); - // set dynamic shared memory if necessary - int shared_optin; - HIP_CHECK(hipDeviceGetAttribute(&shared_optin, hipDeviceAttributeSharedMemPerBlockOptin, dev)); - if (shared > 49152 && shared_optin > 49152) {{ - HIP_CHECK(hipFuncSetCacheConfig({kernel_name}_func, hipFuncCachePreferShared)); - HIP_CHECK(hipFuncSetAttribute(reinterpret_cast({kernel_name}_func), hipFuncAttributeMaxDynamicSharedMemorySize, shared_optin)) - }} -}} - -/* -{kernel_docstring} -*/ -hipError_t {kernel_name}(hipStream_t stream, {signature}) {{ - if ({kernel_name}_func == nullptr) - load_{kernel_name}(); - unsigned int gX = {gridX}; - unsigned int gY = {gridY}; - unsigned int gZ = {gridZ}; - hipDeviceptr_t global_scratch = 0; - void *args[{num_args}] = {{ {arg_pointers} }}; - // TODO: shared memory - if(gX * gY * gZ > 0) - return hipModuleLaunchKernel({kernel_name}_func, gX, gY, gZ, {num_warps} * warpSize, 1, 1, {shared}, stream, args, nullptr); - else - return hipErrorInvalidValue; -}} diff --git a/aiter/aot/triton/compile.h b/aiter/aot/triton/compile.h deleted file mode 100644 index c6d856ddfe..0000000000 --- a/aiter/aot/triton/compile.h +++ /dev/null @@ -1,14 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#pragma once - -#include -#include -#include -#include - -void unload_{kernel_name}(void); -void load_{kernel_name}(void); -// tt-linker: {kernel_name}:{full_signature}:{algo_info} -hipError_t{_placeholder} {kernel_name}(hipStream_t stream, {signature}); \ No newline at end of file diff --git a/aiter/aot/triton/compile.py b/aiter/aot/triton/compile.py deleted file mode 100644 index acb92ceae3..0000000000 --- a/aiter/aot/triton/compile.py +++ /dev/null @@ -1,307 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. - -import binascii -import hashlib -import importlib.util -import sys -from argparse import ArgumentParser -from pathlib import Path -from typing import List - -import triton - -try: - old_compiler = True - from triton.compiler.code_generator import kernel_suffix -except ImportError: - old_compiler = False - -from triton.backends.amd.driver import ty_to_cpp - -desc = """ -Triton ahead-of-time compiler: - -This program compiles the kernel with name `kernel-name` in the file at the -provided `path` into self-contained C source-code that embeds the `cubin` -data along with utilities to load, unload and launch the kernel. - -signature is provided as a list of (optionally divisibility-hinted) types -or constexpr values, e.g. - -`compile.py --kernel-name kernel --signature "*fp32:16, i32:16, 1024, i32" --out-name kernel /path/to/kernel.py` - -will compile triton.JITFunction of name `kernel` inside the file `/path/to/kernel.py`. -Said kernel will be specialized such that argument 0, 1 are assumed to be multiple of 16, -and argument 2 is assumed to be a compile-time constant of value 1024, i.e. it won't be part of the generated prototype. - -The resulting entry point will have signature - -CUresult kernel_{specialization_suffix}(CUstream stream, unsigned gX, unsigned gY, unsigned gZ, float* arg0, int32_t arg1, int32_t arg2) - -Different such specialized entry points can be combined using the `linker.py` script. 
- -NOTE: when resolving the scope of /path/to/kernel.py, the file will be executed from within its parent directory with the python interpreter -used to run this `compile.py` script -""" - - -def compile_kernel( - path, - kernel_name: str, - signature: str, - grid: str, - num_warps: int = 1, - num_stages: int = 3, - out_name: str = None, - out_path: Path = None, - waves_per_eu=0, - kpack=2, - matrix_instr_nonkdim=16, -): - out_name = out_name if out_name else kernel_name - out_path = out_path if out_path else Path(out_name) - - arg_path = Path(path) - sys.path.insert(0, str(arg_path.parent)) - spec = importlib.util.spec_from_file_location(arg_path.stem, arg_path) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - kernel = getattr(mod, kernel_name) - grid = grid.split(",") - assert len(grid) == 3 - - # validate and parse signature - signature = list(map(lambda s: s.strip(" "), signature.split(","))) - - def hash_signature(signature: List[str]): - m = hashlib.sha256() - m.update(" ".join(signature).encode()) - return m.hexdigest()[:8] - - meta_sig = f"warps{num_warps}xstages{num_stages}" - sig_hash = hash_signature(signature + [meta_sig]) - - def constexpr(s): - try: - ret = int(s) - return ret - except ValueError: - pass - try: - ret = float(s) - return ret - except ValueError: - pass - return None - - if old_compiler: - hints = { - i: constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s - } - hints = {k: v for k, v in hints.items() if v is not None} - constants = {kernel.arg_names[i]: constexpr(s) for i, s in enumerate(signature)} - constants = {k: v for k, v in constants.items() if v is not None} - signature = { - kernel.arg_names[i]: s.split(":")[0] - for i, s in enumerate(signature) - if kernel.arg_names[i] not in constants - } - const_sig = "x".join([str(v) for v in constants.values()]) - doc_string = [f"{k}={v}" for k, v in constants.items()] - doc_string += [f"num_warps={num_warps}", f"num_stages={num_stages}"] - - # compile ast into cubin - for h in hints.values(): - assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}" - attrs = triton.backends.compiler.AttrsDescriptor.from_hints(hints) - for p, v in attrs.get_constants().items(): - constants.update({kernel.arg_names[p]: v}) - - src = triton.compiler.ASTSource( - fn=kernel, constants=constants, signature=signature, attrs=attrs - ) - opts = { - "num_warps": num_warps, - "num_stages": num_stages, - "waves_per_eu": waves_per_eu, - "kpack": kpack, - "matrix_instr_nonkdim": matrix_instr_nonkdim, - } - ccinfo = triton.compile(src, options=opts) - arg_names = [] - arg_types = [] - arg_names_not_1 = [] - arg_types_not_1 = [] - for i, arg_name in enumerate(kernel.arg_names): - if arg_name not in constants: - arg_names.append(arg_name) - arg_types.append(signature[arg_name]) - arg_names_not_1.append(arg_name) - arg_types_not_1.append(signature[arg_name]) - elif i in attrs.equal_to_1: - arg_names.append(arg_name) - arg_types.append(signature[arg_name]) - - # dump C stub code - suffix = kernel_suffix(signature.values(), attrs) - else: - hints = { - (i,): constexpr(s.split(":")[1]) - for i, s in enumerate(signature) - if ":" in s - } - hints = {k: v for k, v in hints.items() if v is not None} - constants = {kernel.arg_names[i]: constexpr(s) for i, s in enumerate(signature)} - constants = {k: v for k, v in constants.items() if v is not None} - for key, value in hints.items(): - if value == 1: - constants[kernel.arg_names[key[0]]] = value - signature = { - kernel.arg_names[i]: 
s.split(":")[0] for i, s in enumerate(signature) - } - for key in constants: - signature[key] = "constexpr" - const_sig = "x".join([str(v) for v in constants.values()]) - doc_string = [f"{k}={v}" for k, v in constants.items()] - doc_string += [ - f"num_warps={num_warps}", - f"num_stages={num_stages}", - f"waves_per_eu={waves_per_eu}", - f"kpack={kpack}", - f"matrix_instr_nonkdim={matrix_instr_nonkdim}", - ] - # compile ast into cubin - for h in hints.values(): - assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}" - attrs = {k: [["tt.divisibility", 16]] for k, v in hints.items() if v == 16} - src = triton.compiler.ASTSource( - fn=kernel, constexprs=constants, signature=signature, attrs=attrs - ) - opts = { - "num_warps": num_warps, - "num_stages": num_stages, - "waves_per_eu": waves_per_eu, - "kpack": kpack, - "matrix_instr_nonkdim": matrix_instr_nonkdim, - } - ccinfo = triton.compile(src, options=opts) - if ccinfo.metadata.global_scratch_size > 0: - raise RuntimeError( - "AOT compiling kernels with global scratch requirements is not yet implemented" - ) - - arg_names = [] - arg_types = [] - arg_names_not_1 = [] - arg_types_not_1 = [] - for i, arg_name in enumerate(kernel.arg_names): - if arg_name not in constants: - arg_names.append(arg_name) - arg_types.append(signature[arg_name]) - arg_names_not_1.append(arg_name) - arg_types_not_1.append(signature[arg_name]) - elif hints.get((i,), None) == 1: - arg_names.append(arg_name) - arg_types.append("i32") - - # dump C stub code - suffix = "" - for i, ty in enumerate(signature.values()): - suffix += str(i) - if hints.get((i,), None) == 1: - suffix += "c" - if hints.get((i,), None) == 16: - suffix += "d" - - func_name = "_".join([out_name, sig_hash, suffix]) - hex_ = binascii.hexlify(ccinfo.asm["hsaco"]).decode("utf-8") - - params = { - "kernel_name": func_name, - "triton_kernel_name": kernel_name, - "bin_size": len(hex_), - "bin_data": ", ".join([f"0x{x}{y}" for x, y in zip(hex_[::2], hex_[1::2])]), - "signature": ", ".join( - [ - f"{ty_to_cpp(ty)} {name}" - for name, ty in zip(arg_names_not_1, arg_types_not_1) - ] - ), - "full_signature": ", ".join( - [f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)] - ), - "arg_pointers": ", ".join( - [f"&{arg}" for arg in arg_names_not_1] + ["&global_scratch"] - ), - "num_args": len(arg_names_not_1) + 1, - "kernel_docstring": doc_string, - "shared": ccinfo.metadata.shared, - "num_warps": num_warps, - "algo_info": "_".join([const_sig, meta_sig]), - "gridX": grid[0], - "gridY": grid[1], - "gridZ": grid[2], - "_placeholder": "", - } - output_files = [] - for ext in ["h", "cpp"]: - template_path = Path(__file__).parent / f"compile.{ext}" - output_file = out_path.with_suffix(f".{sig_hash}_{suffix}.{ext}") - output_files.append(output_file) - with output_file.open("w") as fp: - fp.write(Path(template_path).read_text().format(**params)) - return func_name, *output_files - - -if __name__ == "__main__": - - # command-line arguments - parser = ArgumentParser(description=desc) - parser.add_argument( - "path", - help="Path to Python source containing desired kernel in its scope. 
File will be executed.", - ) - parser.add_argument( - "--kernel-name", - "-n", - type=str, - default="", - help="Name of the kernel to compile", - required=True, - ) - parser.add_argument( - "--num-warps", - "-w", - type=int, - default=1, - help="Number of warps to launch the kernel", - ) - parser.add_argument("--waves-per-eu", type=int, default=1) - parser.add_argument("--matrix-instr-nonkdim", type=int, default=0) - parser.add_argument("--kpack", type=int, default=1) - parser.add_argument( - "--num-stages", - "-ns", - type=int, - default=3, - help="Number of stages (meta-parameter of the kernel)", - ) - parser.add_argument( - "--out-name", - "-on", - type=str, - default=None, - help="Out name for the compiled kernel", - ) - parser.add_argument( - "--out-path", "-o", type=Path, default=None, help="Out filename" - ) - parser.add_argument( - "--signature", "-s", type=str, help="Signature of the kernel", required=True - ) - parser.add_argument( - "--grid", "-g", type=str, help="Launch grid of the kernel", required=True - ) - args = parser.parse_args() - compile_kernel(**vars(args))
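Usage sketch for the `__main__` entry points added in PATCH 18: each CLI simply forwards its parsed arguments to the module-level `compile()` helper it wraps, so a sampling kernel can be pre-built from the shell or from Python ahead of the first runtime call. The vocabulary size below is an illustrative placeholder, the exact boolean spellings accepted by `str_to_bool` are assumed (its definition is not part of this series), and passing `folder=None` is assumed to fall back to the default build directory.

    # Hypothetical ahead-of-time build of the top-p sampling kernel (d is illustrative):
    #   python csrc/cpp_itfs/sampling/top_p_sampling_from_probs.py --d 32000 --deterministic 1
    # The same thing in-process, using the module-level compile() that the CLI calls:
    from csrc.cpp_itfs.sampling.top_p_sampling_from_probs import compile as build_top_p

    build_top_p(d=32000, deterministic=True, folder=None)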