From e2364522f8ed3504b0118736fb8387099aa5d69b Mon Sep 17 00:00:00 2001
From: xiaolil1
Date: Mon, 14 Jul 2025 08:35:02 +0000
Subject: [PATCH] Add draft gemm_4bit_cutlass kernel

---
 csrc/xpu_cutlass.cpp | 440 +++++++++++++++++++++++++++++++++++++++++++
 csrc/xpu_cutlass.h   |  57 ++++++
 2 files changed, 497 insertions(+)
 create mode 100644 csrc/xpu_cutlass.cpp
 create mode 100644 csrc/xpu_cutlass.h

diff --git a/csrc/xpu_cutlass.cpp b/csrc/xpu_cutlass.cpp
new file mode 100644
index 000000000..e9f74e18b
--- /dev/null
+++ b/csrc/xpu_cutlass.cpp
@@ -0,0 +1,440 @@
+#include "xpu_cutlass.h"
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include "cutlass/cutlass.h"
+#include "cutlass/device_kernel.h"
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/detail/layout.hpp"
+#include "cutlass/detail/mma.hpp"
+#include "cutlass/cuda_host_adapter.hpp"
+#include
+#include
+
+#include "cutlass/kernel_launch.h"
+
+// 2.x
+#include "cutlass/gemm/device/gemm_universal_base.h"
+#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
+#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
+#include "cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h"
+
+// 3.x
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+
+#include "cutlass/util/sycl_event_manager.hpp"
+
+using namespace cute;
+using namespace cutlass;
+using namespace cutlass::gemm;
+
+// Basic type information.
+// B is the weight-only-quantized operand.
+using MmaType = cutlass::bfloat16_t;
+using QuantType = cutlass::uint4_t; // NF4, FP4
+
+using ElementA = MmaType;
+using ElementB = QuantType;
+
+using ElementMMA = ElementA;
+using ElementQuant = QuantType;
+using ElementScale = float;
+
+using ElementAccumulator = float;
+using ElementComputeEpilogue = float;
+using ElementOutput = float;
+
+using ProblemShape = Shape<int, int, int, int>;
+
+using TileShape = Shape<_32, _128, _64>;
+using TiledMma =
+    typename TiledMMAHelper, Layout,
+                             Layout, Stride<_8, _1, _0>>>::TiledMMA;
+constexpr int PipelineStages = 2;
+constexpr int LUT_NUM = 4;
+
+using MmaAtomShape = typename TiledMma::AtomShape_MNK;
+using WorkgroupTileShape = TileShape;
+static constexpr auto BLK_M = get<0>(WorkgroupTileShape{});
+static constexpr auto BLK_N = get<1>(WorkgroupTileShape{});
+static constexpr auto BLK_K = get<2>(WorkgroupTileShape{});
+
+// Number of sub-groups along M/N/K, taken from the TiledMma thread layout.
+static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape());
+static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape());
+static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape());
+
+static_assert(BLK_M % TiledMma{}.template tile_size_mnk<0>() == 0, "TiledMma permutation size must match block size.");
+static_assert(BLK_N % TiledMma{}.template tile_size_mnk<1>() == 0, "TiledMma permutation size must match block size.");
+static_assert(BLK_K % TiledMma{}.template tile_size_mnk<2>() == 0, "TiledMma permutation size must match block size.");
+
+// Per-sub-group tile shape.
+static constexpr auto SG_M = ceil_div(BLK_M, ATOM_M);
+static constexpr auto SG_N = ceil_div(BLK_N, ATOM_N);
+static constexpr auto SG_K = ceil_div(BLK_K, ATOM_K);
+using SubgroupTileShape = Shape<decltype(SG_M), decltype(SG_N), decltype(SG_K)>;
+
+static constexpr auto Num_SGs = ATOM_N * ATOM_M * ATOM_K;
+static constexpr auto SG_QNT_WIDTH = Int{};
+static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{});
+
+using DispatchPolicy = cutlass::gemm::MainloopIntelPVCMixedPrecision<PipelineStages>;
+static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;
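+
+// Work decomposition (from the definitions above): each work-group computes a
+// BLK_M x BLK_N output tile, split across ATOM_M x ATOM_N x ATOM_K sub-groups of
+// SubgroupSize work-items; each sub-group owns an SG_M x SG_N x SG_K sub-tile.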
+
+// Tile scheduler
+using TileScheduler_ = PersistentScheduler;
+static_assert(cute::is_void_v<TileScheduler_> or cute::is_same_v<TileScheduler_, PersistentScheduler>,
+              "Intel PVC does not support specializing the tile scheduler.");
+using ArchTag = typename DispatchPolicy::ArchTag;
+using TileScheduler = typename cutlass::gemm::kernel::detail::TileSchedulerSelector, cute::Int<1>, cute::Int<1>>>::Scheduler;
+using TileSchedulerArguments = typename TileScheduler::Arguments;
+using TileSchedulerParams = typename TileScheduler::Params;
+
+using ClusterShape = typename DispatchPolicy::ClusterShape;
+
+// Copy definitions for the global-memory loads/stores.
+using CopyThreadShape = Shape<_1, Int<SubgroupSize>>;
+using CopyThreadShapeRev = decltype(cute::reverse(CopyThreadShape{}));
+
+using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
+using StrideA = cutlass::gemm::TagToStrideA_t;
+using traits_load_A = Copy_Traits;
+using atom_load_A = Copy_Atom;
+using val_layout_load_A = decltype(make_layout(shape_div(typename traits_load_A::BlockShape{}, CopyThreadShape{})));
+using Copy_A = decltype(make_tiled_copy(atom_load_A{}, Layout{}, val_layout_load_A{}));
+
+using GmemTiledCopyB = XE_2D_U4x32x16_LD_T;
+using StrideB = cutlass::gemm::TagToStrideB_t;
+using traits_load_B = Copy_Traits;
+using atom_load_B = Copy_Atom;
+using val_layout_load_B = decltype(make_layout(shape_div(typename traits_load_B::BlockShape{}, CopyThreadShape{})));
+using Copy_B = decltype(make_tiled_copy(atom_load_B{}, Layout{}, val_layout_load_B{}));
+
+using GmemTiledCopyScale = XE_2D_U32x1x16_LD_N;
+using StrideScale = cute::Stride<_1, int64_t, int64_t>;
+using traits_load_scale = Copy_Traits;
+using atom_load_scale = Copy_Atom;
+using val_layout_load_scale = decltype(make_layout(shape_div(typename traits_load_scale::BlockShape{}, CopyThreadShapeRev{})));
+using Copy_Scale = decltype(make_tiled_copy(atom_load_scale{}, Layout{}, val_layout_load_scale{}));
+
+using GmemTiledCopyD = XE_2D_U32x8x16_ST_N;
+using StrideD = cutlass::gemm::TagToStrideC_t;
+using Trait_D = Copy_Traits;
+using val_layout_store_D = decltype(make_layout(shape_div(typename Trait_D::BlockShape{}, CopyThreadShape{})));
+using Copy_D = decltype(make_tiled_copy(Copy_Atom{}, Layout{}, val_layout_store_D{}));
+
+template <typename T>
+class gemm_4bit_cutlass_kernel {
+public:
+  struct Params {
+    int m, n, k, l;
+    T* A;
+    uint8_t* B;
+    float* out;
+    float *datatype; // LUT
+    int group_size;
+    float* absmax;
+
+    ProblemShape problem_shape{};
+
+    Copy_A tiled_copy_a;
+    Copy_B tiled_copy_b;
+    Copy_Scale tiled_copy_scale;
+    Copy_D tiled_store_d;
+    KernelHardwareInfo hw_info{};
+    TileSchedulerParams scheduler{};
+  };
+
+  static dim3
+  get_block_shape() {
+    return dim3(MaxThreadsPerBlock, 1, 1);
+  }
+
+  static dim3
+  get_grid_shape(Params const& params) {
+    dim3 grid = TileScheduler::get_tiled_cta_shape_mnl(params.problem_shape, TileShape{}, ClusterShape{});
+    if (params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN) {
+      return {grid.y, grid.x, grid.z};
+    } else {
+      return {grid.x, grid.y, grid.z};
+    }
+  }
+
+  CUTLASS_DEVICE
+  void operator()(Params const& params, char* smem_buf) {
+    int thread_idx = int(ThreadIdxX());
+
+    // Load the 16-entry dequantization LUT (4-bit indices) into SLM, replicated LUT_NUM times.
+    alignas(128) float (*quant_map)[16] = reinterpret_cast<float (*)[16]>(smem_buf);
+    if (thread_idx < 16 * LUT_NUM) {
+      quant_map[thread_idx / 16][thread_idx % 16] = params.datatype[thread_idx % 16];
+    }
+    barrier_arrive(3);
+
+    const int m_coord = (params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN)
+                            ? BlockIdxY() : BlockIdxX();
+    const int n_coord = (params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN)
+                            ? BlockIdxX() : BlockIdxY();
+    const int l_coord = BlockIdxZ();
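+
+    // The m/n swap mirrors get_grid_shape(): for AlongN rasterization the launch grid's
+    // x/y axes are exchanged, so the work-group tile coordinates are read back from the
+    // swapped axes here.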
+
+    Tensor mA_mkl = cute::get_pvc_tensor(make_shape(params.m, params.k, params.l));
+    Tensor mB_nkl = cute::get_pvc_tensor(make_shape(params.n, params.k, 1));
+
+    Tensor gA = local_tile(mA_mkl, select<0, 2>(TileShape{}), make_coord(m_coord, _, l_coord));
+    Tensor gB = local_tile(mB_nkl, select<1, 2>(TileShape{}), make_coord(n_coord, _, 0));
+
+    TiledMma tiled_mma;
+    Tensor accumulators = partition_fragment_C(tiled_mma, take<0, 2>(TileShape{}));
+    clear(accumulators);
+
+    auto k_tile_iter = cute::make_coord_iterator(idx2crd(0, make_shape(params.k)), make_shape(params.k));
+    int k_tile_count = ceil_div(params.k, get<2>(TileShape{}));
+
+    auto thr_copy_A = params.tiled_copy_a.get_slice(thread_idx);
+    auto thr_copy_B = params.tiled_copy_b.get_slice(thread_idx);
+    auto thr_copy_scale = params.tiled_copy_scale.get_slice(thread_idx);
+
+    auto first_thread_in_sg_idx = syclcompat::get_nd_item<1>().get_sub_group().get_group_linear_id() * DispatchPolicy::SubgroupSize;
+    auto thr_mma = tiled_mma.get_slice(first_thread_in_sg_idx);
+
+    Tensor tCgA = thr_mma.partition_A(gA);
+    Tensor tCgB = thr_mma.partition_B(gB);
+
+    Tensor mma_A = make_tensor(make_fragment_layout(params.tiled_copy_a, tCgA(_, _, _, 0).shape()));
+    Tensor mma_B = make_tensor(make_fragment_layout(params.tiled_copy_b, tCgB(_, _, _, 0).shape()));
+
+    Tensor dequant_frag = make_tensor(mma_B.layout());
+
+    static constexpr auto scale_traits_size = decltype(size(typename GmemTiledCopyScale::BlockShape{}))::value / DispatchPolicy::SubgroupSize;
+    static constexpr auto scale_traits_num = SG_QNT_WIDTH / decltype(size<1>(typename GmemTiledCopyScale::BlockShape{}))::value;
+    using FragScaleLayout = Layout<Shape<Int<scale_traits_size>, Int<scale_traits_num>, _1>>;
+    Tensor fragment_scale = make_tensor(FragScaleLayout{});
+
+    static_assert(std::is_same_v);
+    static_assert(std::is_same_v);
+    static_assert(std::is_same_v);
+
+    Tensor frag_copy_A = thr_copy_A.retile_D(mma_A);
+    Tensor frag_copy_B = thr_copy_B.retile_D(dequant_frag);
+    Tensor frag_copy_Scale = thr_copy_scale.retile_D(fragment_scale);
+
+    Tensor tAgA = thr_copy_A.retile_S(tCgA);
+    Tensor tBgB = thr_copy_B.retile_S(tCgB);
+
+    auto tiled_prefetch_a = cute::prefetch_selector,Int>, Num_SGs>(params.tiled_copy_a);
+    auto tiled_prefetch_b = cute::prefetch_selector,Int>, Num_SGs>(params.tiled_copy_b);
+    auto thr_prefetch_A = tiled_prefetch_a.get_slice(thread_idx);
+    auto thr_prefetch_B = tiled_prefetch_b.get_slice(thread_idx);
+
+    auto pAgA = thr_prefetch_A.partition_S(gA);
+    auto pBgB = thr_prefetch_B.partition_S(gB);
+
+    // Number of K tiles that share one quantization group (and hence one scale reload).
+    const int k_reload_factor = ceil_div(params.group_size, BLK_K);
+
+    auto tSgS = [&]() {
+      return make_tensor(make_inttuple_iter(make_coord(n_coord * BLK_N + get<2>(thr_mma.thr_vmnk_) * SG_QNT_WIDTH, 0, 0)),
+                         make_layout(make_shape(Int{}, Int{}, _1{}, k_tile_count / k_reload_factor),
+                                     make_stride(E<0>{} * _16{}, E<0>{} * _16{}, _0{}, E<1>{} * _1{})));
+    }();
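+
+    // dequant(): unpack the packed 4-bit indices in dequant_frag (16 indices per uint64_t),
+    // look each index up in the SLM quant_map table, apply the per-group absmax scale, and
+    // write the resulting bf16 values into mma_B for the MMA.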
+
+    auto dequant = [&](int start_lut_id) {
+      constexpr int N = decltype(cute::size<1>(mma_B))::value;
+      constexpr int K = decltype(cute::size(mma_B))::value / N;
+
+      using src_compress_type = uint64_t;
+      using dst_compress_type = uint64_t;
+
+      constexpr int src_compress_size = cute::sizeof_bits_v<src_compress_type> / cute::sizeof_bits_v<ElementQuant>;
+      constexpr int dst_compress_size = cute::sizeof_bits_v<dst_compress_type> / cute::sizeof_bits_v<ElementMMA>;
+      constexpr int src_vec_size = (K / src_compress_size) >= 16 ? 16 : K / src_compress_size; // 16 -> max vec_size of sycl::vec
+      constexpr int dst_vec_size = (K / dst_compress_size) >= 16 ? 16 : K / dst_compress_size; // 16 -> max vec_size of sycl::vec
+      constexpr int src_loop_num = K / src_vec_size / src_compress_size;
+      constexpr int dst_loop_num = K / dst_vec_size / dst_compress_size;
+
+      src_compress_type src[src_vec_size];
+      ElementMMA dst[dst_compress_size * dst_vec_size];
+
+      int lut_id = start_lut_id;
+
+      CUTLASS_PRAGMA_UNROLL
+      for (int n = 0; n < N; n++) {
+        float scale_value = fragment_scale(n);
+
+        CUTLASS_PRAGMA_UNROLL
+        for (int l = 0; l < src_loop_num; l++) {
+          reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(src)[0] =
+              reinterpret_cast<sycl::vec<src_compress_type, src_vec_size>*>(
+                  cute::raw_pointer_cast(dequant_frag.data()))[n * src_loop_num + l];
+
+          CUTLASS_PRAGMA_UNROLL
+          for (int v = 0; v < src_vec_size; v++) {
+            src_compress_type src_value = src[v];
+            int dst_idx = v * src_compress_size;
+
+            CUTLASS_PRAGMA_UNROLL
+            for (int c = 0; c < src_compress_size; c++) {
+              uint8_t bit_value = (src_value >> (4 * (((c + 1) & 1) + (c >> 1) * 2))) & 0xF;
+              dst[dst_idx + c] = static_cast<ElementMMA>(quant_map[lut_id][bit_value] * scale_value);
+              lut_id = (lut_id + 1) % LUT_NUM;
+            }
+          }
+        }
+
+        CUTLASS_PRAGMA_UNROLL
+        for (int l = 0; l < dst_loop_num; l++) {
+          reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(
+              cute::raw_pointer_cast(mma_B.data()))[n * dst_loop_num + l] =
+              reinterpret_cast<sycl::vec<dst_compress_type, dst_vec_size>*>(dst)[0];
+        }
+      }
+    };
+
+    const int k_start_idx = crd2idx((*k_tile_iter), make_shape(params.k));
+    int prefetch_k = k_start_idx;
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < DispatchPolicy::Stages; i++, prefetch_k++) {
+      prefetch(tiled_prefetch_a, pAgA(_, _, _, prefetch_k));
+      prefetch(tiled_prefetch_b, pBgB(_, _, _, prefetch_k));
+    }
+
+    auto sg_idx = syclcompat::get_nd_item<1>().get_sub_group().get_group_linear_id();
+    int start_lut_id = sg_idx % LUT_NUM;
+
+    // Main loop: load the 4-bit B tile, the per-group scales and the A tile, dequantize, then issue the MMA.
+    for (int k_tile = k_start_idx, k_s = 0; k_tile < k_tile_count + k_start_idx; k_tile++, prefetch_k++, k_s++) {
+      copy(params.tiled_copy_b, tBgB(_, _, _, k_tile), frag_copy_B);
+
+      copy(params.tiled_copy_scale, tSgS(_, _, _, (k_start_idx + k_s) / k_reload_factor), frag_copy_Scale);
+
+      copy(params.tiled_copy_a, tAgA(_, _, _, k_tile), frag_copy_A);
+
+      dequant(start_lut_id);
+
+      if (prefetch_k < k_tile_count) {
+        prefetch(tiled_prefetch_a, pAgA(_, _, _, prefetch_k));
+        prefetch(tiled_prefetch_b, pBgB(_, _, _, prefetch_k));
+      }
+
+      cute::gemm(tiled_mma, mma_A, mma_B, accumulators);
+      barrier_wait(3);
+    }
+
+    static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // A frags per sub_group
+    static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group
+
+    auto m_sg = get_sub_group_id() / ATOM_N;
+    auto n_sg = get_sub_group_id() % ATOM_N;
+
+    Tensor mD_mnl = cute::get_pvc_tensor(make_shape(params.m, params.n, params.l)); // Logical full output tensor
+
+    // Tile the output tensor per WG and select the tile for the current WG
+    Tensor g_wg_D = local_tile(mD_mnl, take<0, 2>(TileShape{}), make_coord(m_coord, n_coord, l_coord));
+
+    // Tile the output tensor per SG and select the tile for the current SG
+    Tensor gD = local_tile(g_wg_D, take<0, 2>(SubgroupTileShape{}), make_coord(m_sg, n_sg));
+
+    auto thread_xe_store_d = params.tiled_store_d.get_thread_slice(thread_idx); // partial copy_atom for the current thread
+    Tensor tCgD = thread_xe_store_d.partition_D(gD); // values for the current thread
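+
+    // Epilogue: copy the float accumulator fragments directly to global memory;
+    // no alpha/beta scaling, bias, or other fusion is applied in this draft.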
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int epi_n = 0; epi_n < FragsN; ++epi_n) {
+      CUTLASS_PRAGMA_UNROLL
+      for (int epi_m = 0; epi_m < FragsM; ++epi_m) {
+        copy(params.tiled_store_d, accumulators(_, epi_m, epi_n), tCgD(_, epi_m, epi_n));
+      }
+    }
+  }
+};
+
+template <typename T>
+void gemm_4bit_cutlass(int m, int n, int k, int l, T *A, unsigned char *B,
+                       float *absmax_, float *datatype, float *out, int lda,
+                       int ldb, int ldc, int blocksize, sycl::queue *stream) {
+  sycl::queue q = *stream;
+
+  using GemmKernel = gemm_4bit_cutlass_kernel<T>;
+
+  static constexpr int smem_size = 16 * sizeof(float) * LUT_NUM;
+
+  auto problem_size = ProblemShape{m, n, k, l};
+
+  using Params = GemmKernel::Params;
+  Params params;
+  params.m = m;
+  params.n = n;
+  params.k = k;
+  params.l = l;
+  params.A = A;
+  params.B = B;
+  params.out = out;
+  params.datatype = datatype;
+  params.group_size = blocksize;
+  params.absmax = absmax_;
+
+  StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, l));
+  auto mA_mkl = make_tensor(make_gmem_ptr(A), make_layout(make_shape(m, k, l), stride_A));
+  Copy_A tiled_copy_a{Copy_A{}.with(mA_mkl)};
+
+  StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, l));
+  auto mB_nkl = make_tensor(cute::subbyte_iterator(B), make_layout(make_shape(n, k, l), stride_B));
+  Copy_B tiled_copy_b{Copy_B{}.with(mB_nkl)};
+
+  params.tiled_copy_a = tiled_copy_a;
+  params.tiled_copy_b = tiled_copy_b;
+
+  const int scale_k = cute::ceil_div(k, blocksize);
+  StrideScale stride_S = cutlass::make_cute_packed_stride(StrideScale{}, cute::make_shape(n, scale_k, 1));
+  auto mScale = make_tensor(make_gmem_ptr(absmax_), make_layout(make_shape(n, scale_k, 1), stride_S));
+  Copy_Scale tiled_copy_scale{Copy_Scale{}.with(mScale)};
+
+  params.tiled_copy_scale = tiled_copy_scale;
+
+  cutlass::KernelHardwareInfo hw_info;
+  hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
+  auto problem_shape_MNKL = problem_size;
+
+  StrideD stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, l));
+  auto mD = make_tensor(make_gmem_ptr(out), make_layout(make_shape(m, n, l), stride_D));
+  Copy_D tiled_store_d = {Copy_D{}.with(mD)};
+  params.tiled_store_d = tiled_store_d;
+
+  params.hw_info = hw_info;
+
+  TileSchedulerArguments scheduler{};
+  params.scheduler = TileScheduler::to_underlying_arguments(
+      problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, scheduler, nullptr);
+
+  params.problem_shape = problem_size;
+
+  dim3 const block = GemmKernel::get_block_shape();
+  dim3 const grid = GemmKernel::get_grid_shape(params);
+
+  const syclcompat::dim3 sycl_block(block.x, block.y, block.z);
+  const syclcompat::dim3 sycl_grid(grid.x, grid.y, grid.z);
+
+  auto kernel_props = [] {
+    return syclcompat::experimental::kernel_properties{
+        sycl::ext::oneapi::experimental::sub_group_size<SubgroupSize>
+    };
+  }();
+  syclcompat::experimental::launch_properties launch_props{
+      sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size),
+  };
+  syclcompat::experimental::launch_policy policy{
+      sycl_grid, sycl_block, launch_props, kernel_props
+  };
+
+  syclcompat::experimental::launch<cutlass::device_kernel<GemmKernel>>(policy, q, params);
+}
+
+template void gemm_4bit_cutlass(
+    int m, int n, int k, int l, sycl::ext::oneapi::bfloat16 *A, unsigned char *B,
+    float *absmax, float *datatype, float *out, int lda,
+    int ldb, int ldc, int blocksize, sycl::queue *stream);
+
diff --git a/csrc/xpu_cutlass.h b/csrc/xpu_cutlass.h
new file mode 100644
index 000000000..2bcb5c968
--- /dev/null
+++ b/csrc/xpu_cutlass.h
@@ -0,0 +1,57 @@
+#pragma once
+
+#include
+
+#include "cutlass/epilogue/collective/default_epilogue.hpp"
+#include "cutlass/epilogue/collective/xe_epilogue.hpp"
+#include "cutlass/epilogue/dispatch_policy.hpp"
+#include "cutlass/epilogue/fusion/xe_callbacks.hpp"
+#include "cutlass/gemm/collective/collective_mma.hpp"
+#include "cutlass/gemm/device/gemm_universal.h"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/util/GPU_Clock.hpp"
+#include
+
+#include
+#include
+
+#include "cutlass/util/command_line.h"
+#include "cutlass/util/device_memory.h"
+#include "cutlass/util/packed_stride.hpp"
+#include "cutlass/util/reference/device/gemm_complex.h"
+#include "cutlass/util/reference/device/tensor_compare.h"
+#include "helper.h"
+#include "sycl_common.hpp"
+
+// cute API
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/kernel_hardware_info.hpp"
+
+#include "cute/tensor.hpp"
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include "cutlass/epilogue/collective/collective_epilogue.hpp"
+#include "cutlass/epilogue/collective/detail.hpp"
+#include "cutlass/epilogue/fusion/callbacks.hpp"
+#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
+
+#include "cute/algorithm/functional.hpp"
+#include "cute/atom/mma_atom.hpp"
+#include "cute/tensor_predicate.hpp"
+
+#include "cutlass/device_kernel.h"
+
+template <typename T>
+void gemm_4bit_cutlass(int m, int n, int k, int l, T *A, unsigned char *B,
+                       float *absmax, float *datatype, float *out, int lda,
+                       int ldb, int ldc, int blocksize,
+                       sycl::queue *stream);
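
A minimal host-side usage sketch for the gemm_4bit_cutlass entry point added above. The
problem size, buffer names, and packing assumptions (row-major A, 4-bit B packed two
values per byte, one float scale per blocksize group, a 16-entry dequant LUT) are
illustrative assumptions for this sketch, not something the patch pins down.

    #include <sycl/sycl.hpp>
    #include "xpu_cutlass.h"

    void run_gemm_4bit_example(sycl::queue &q) {
      // Hypothetical problem: out[m x n] = A[m x k] * dequant(B)[k x n], single batch (l = 1).
      const int m = 32, n = 4096, k = 4096, l = 1;
      const int blocksize = 64;                              // quantization group size along k
      const int scale_k = (k + blocksize - 1) / blocksize;   // groups per output column

      auto *A        = sycl::malloc_device<sycl::ext::oneapi::bfloat16>(size_t(m) * k, q);
      auto *B        = sycl::malloc_device<unsigned char>(size_t(n) * k / 2, q); // packed 4-bit weights
      auto *absmax   = sycl::malloc_device<float>(size_t(n) * scale_k, q);       // per-group scales
      auto *datatype = sycl::malloc_device<float>(16, q);                        // NF4/FP4 dequant LUT
      auto *out      = sycl::malloc_device<float>(size_t(m) * n, q);

      // ... fill A, B, absmax and the LUT here ...

      gemm_4bit_cutlass(m, n, k, l, A, B, absmax, datatype, out,
                        /*lda=*/k, /*ldb=*/k, /*ldc=*/n, blocksize, &q);
      q.wait();
    }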