From 5f8598e941b9564752091e367c9b077bcd24ab40 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 1 May 2025 23:25:48 -0400 Subject: [PATCH] [CUTLASS] Add GeMM kernels for Blackwell GPUs This PR introduces CUTLASS gemm kernels, groupwise-scaled gemm kernels and group gemm kernels for Blackwell GPUs. Files are reorganized a bit so that the exposed global functions are now architecture agnostic. Prior to this PR, our global function names for CUTLASS kernels usually end with `"_sm90"`, which brings extra complexity when the frontend compiler decides to dispatch kernels when there are multiple supported architectures, such as Hopper and Blackwell. Therefore, this PR renames those global function so that the function names are arch agnostic. During the build time, only the kernels that the specific architecture supports will be built. --- 3rdparty/cutlass | 2 +- cmake/modules/contrib/CUTLASS.cmake | 16 +- .../contrib/cutlass/fp16_group_gemm.cuh | 72 ++++++ .../cutlass/fp16_group_gemm_runner_sm100.cuh | 221 ++++++++++++++++++ ...er.cuh => fp16_group_gemm_runner_sm90.cuh} | 10 +- .../contrib/cutlass/fp16_group_gemm_sm100.cu | 54 +++++ ..._group_gemm.cu => fp16_group_gemm_sm90.cu} | 53 +++-- .../cutlass/fp8_blockwise_scaled_gemm.cu | 164 ------------- ...8_group_gemm.cu => fp8_group_gemm_sm90.cu} | 2 +- .../cutlass/fp8_groupwise_scaled_gemm.cuh | 172 ++++++++++++++ ...fp8_groupwise_scaled_gemm_runner_sm100.cuh | 155 ++++++++++++ ...fp8_groupwise_scaled_gemm_runner_sm90.cuh} | 53 +---- .../fp8_groupwise_scaled_gemm_sm100.cu | 77 ++++++ .../cutlass/fp8_groupwise_scaled_gemm_sm90.cu | 77 ++++++ ...oupwise_scaled_group_gemm_runner_sm100.cuh | 220 +++++++++++++++++ .../fp8_groupwise_scaled_group_gemm_sm100.cu | 93 ++++++++ src/target/tag.cc | 2 + tests/python/contrib/test_cutlass_gemm.py | 32 ++- 18 files changed, 1220 insertions(+), 255 deletions(-) create mode 100644 src/runtime/contrib/cutlass/fp16_group_gemm.cuh create mode 100644 src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh rename src/runtime/contrib/cutlass/{group_gemm_runner.cuh => fp16_group_gemm_runner_sm90.cuh} (96%) create mode 100644 src/runtime/contrib/cutlass/fp16_group_gemm_sm100.cu rename src/runtime/contrib/cutlass/{fp16_group_gemm.cu => fp16_group_gemm_sm90.cu} (60%) delete mode 100644 src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu rename src/runtime/contrib/cutlass/{fp8_group_gemm.cu => fp8_group_gemm_sm90.cu} (98%) create mode 100644 src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh create mode 100644 src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh rename src/runtime/contrib/cutlass/{blockwise_scaled_gemm_runner.cuh => fp8_groupwise_scaled_gemm_runner_sm90.cuh} (75%) create mode 100644 src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm100.cu create mode 100644 src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm90.cu create mode 100644 src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_runner_sm100.cuh create mode 100644 src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu diff --git a/3rdparty/cutlass b/3rdparty/cutlass index afa177220367..ad7b2f5e84fc 160000 --- a/3rdparty/cutlass +++ b/3rdparty/cutlass @@ -1 +1 @@ -Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008 +Subproject commit ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e diff --git a/cmake/modules/contrib/CUTLASS.cmake b/cmake/modules/contrib/CUTLASS.cmake index d11777e8514a..b74ce4c8dfe0 100644 --- a/cmake/modules/contrib/CUTLASS.cmake +++ b/cmake/modules/contrib/CUTLASS.cmake @@ -58,19 +58,27 @@ if(USE_CUDA AND USE_CUTLASS) set(TVM_CUTLASS_RUNTIME_SRCS "") if (CMAKE_CUDA_ARCHITECTURES MATCHES "90a") - list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_group_gemm.cu) - list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_group_gemm.cu) + list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_group_gemm_sm90.cu) + list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu) list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_gemm.cu) - list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu) + list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm90.cu) + endif() + if (CMAKE_CUDA_ARCHITECTURES MATCHES "100a") + list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_group_gemm_sm100.cu) + list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm100.cu) + list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu) endif() if(TVM_CUTLASS_RUNTIME_SRCS) add_library(tvm_cutlass_objs OBJECT ${TVM_CUTLASS_RUNTIME_SRCS}) - target_compile_options(tvm_cutlass_objs PRIVATE $<$:--expt-relaxed-constexpr>) + target_compile_options(tvm_cutlass_objs PRIVATE $<$:-lineinfo --expt-relaxed-constexpr>) target_include_directories(tvm_cutlass_objs PRIVATE ${CUTLASS_DIR}/include ${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm/cutlass_extensions/include ) + target_link_libraries(tvm_cutlass_objs PRIVATE tvm_ffi_header) target_compile_definitions(tvm_cutlass_objs PRIVATE DMLC_USE_LOGGING_LIBRARY=) + # Note: enable this to get more detailed logs for cutlass kernels + # target_compile_definitions(tvm_cutlass_objs PRIVATE CUTLASS_DEBUG_TRACE_LEVEL=2) list(APPEND CUTLASS_RUNTIME_OBJS "$<${CUTLASS_GEN_COND}:$>") endif() diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm.cuh b/src/runtime/contrib/cutlass/fp16_group_gemm.cuh new file mode 100644 index 000000000000..ebb8f58a6b18 --- /dev/null +++ b/src/runtime/contrib/cutlass/fp16_group_gemm.cuh @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +#include "cutlass/bfloat16.h" +#include "cutlass/half.h" + +namespace tvm { +namespace runtime { + +template +struct CutlassGroupGemm; + +template +void tvm_cutlass_group_gemm_impl(NDArray x, NDArray weight, NDArray indptr, NDArray workspace, + NDArray out) { + // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace. + // Recommened size is 4MB. + static auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream"); + CHECK_EQ(x->ndim, 2); + CHECK_EQ(weight->ndim, 3); + CHECK_EQ(indptr->ndim, 1); + CHECK_EQ(workspace->ndim, 1); + CHECK_EQ(out->ndim, 2); + int num_groups = weight->shape[0]; + int n = weight->shape[1]; + int k = weight->shape[2]; + float alpha = 1.0f; + float beta = 0.0f; + cudaStream_t stream = static_cast(func().cast()); + + if (DataType(x->dtype) == DataType::Float(16)) { + CHECK(DataType(weight->dtype) == DataType::Float(16)); + CHECK(DataType(out->dtype) == DataType::Float(16)); + using Dtype = cutlass::half_t; + CutlassGroupGemm::run( + static_cast(x->data), static_cast(weight->data), + static_cast(indptr->data), static_cast(workspace->data), + workspace->shape[0], n, k, num_groups, alpha, beta, static_cast(out->data), stream); + } else if (DataType(x->dtype) == DataType::BFloat(16)) { + CHECK(DataType(weight->dtype) == DataType::BFloat(16)); + CHECK(DataType(out->dtype) == DataType::BFloat(16)); + using Dtype = cutlass::bfloat16_t; + CutlassGroupGemm::run( + static_cast(x->data), static_cast(weight->data), + static_cast(indptr->data), static_cast(workspace->data), + workspace->shape[0], n, k, num_groups, alpha, beta, static_cast(out->data), stream); + } +} + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh b/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh new file mode 100644 index 000000000000..f38664915d35 --- /dev/null +++ b/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#include "../../cuda/cuda_common.h" + +// clang-format off +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +// clang-format on + +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + CHECK(error == cutlass::Status::kSuccess) \ + << "Got cutlass error: " << cutlassGetStatusString(error); \ + } + +using namespace cute; +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group + +inline size_t aligned(size_t value, size_t alignment = 16) { + return (value + alignment - 1) / alignment * alignment; +} + +template +struct MMA1SMConfig { + using MmaTileShape = Shape<_128, _256, Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_2, _2, _1>; + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch +}; + +template +struct MMA2SMConfig { + using MmaTileShape = Shape<_256, _256, Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_2, _2, _1>; + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch +}; + +template +struct CutlassGroupGemmRunner { + static constexpr int AlignmentA = + 128 / cutlass::sizeof_bits::value; // Alignment of A matrix in units of elements + // (up to 16 bytes) + + static constexpr int AlignmentB = + 128 / cutlass::sizeof_bits::value; // Alignment of B matrix in units of elements + // (up to 16 bytes) + + static constexpr int AlignmentC = + 128 / cutlass::sizeof_bits::value; // Alignment of C matrix in units of elements + // (up to 16 bytes) + + // Core kernel configurations + using ElementAccumulator = float; // Element type for internal accumulation + using ScaleType = std::variant; + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + using StageCountType = + cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size + + // Different configs for 1SM and 2SM MMA kernel + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, OperatorClass, typename ScheduleConfig::MmaTileShape, + typename ScheduleConfig::ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementC, LayoutC*, + AlignmentC, typename ScheduleConfig::EpilogueSchedule, + cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, LayoutB*, + AlignmentB, ElementAccumulator, typename ScheduleConfig::MmaTileShape, + typename ScheduleConfig::ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + typename ScheduleConfig::KernelSchedule>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + void run_group_gemm(const ElementA** ptr_A, const ElementB** ptr_B, const ElementC** ptr_C, + ElementC** ptr_D, + typename ProblemShape::UnderlyingProblemShape* problem_sizes, + typename ProblemShape::UnderlyingProblemShape* problem_sizes_host, + StrideA* stride_A, StrideB* stride_B, StrideC* stride_C, StrideD* stride_D, + uint8_t* workspace, int64_t workspace_size, int num_groups, ScaleType alpha, + ScaleType beta, cudaStream_t stream) { + typename Gemm::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args; + [&]() { + ICHECK(alpha.index() == beta.index()) << "alpha and beta must have the same type"; + if (std::holds_alternative(alpha)) { + fusion_args.alpha = std::get(alpha); + fusion_args.beta = std::get(beta); + } else if (std::holds_alternative(alpha)) { + fusion_args.alpha_ptr = std::get(alpha); + fusion_args.beta_ptr = std::get(beta); + } else { + LOG(FATAL) << "Unsupported alpha and beta type"; + throw; + } + }(); + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + arguments = typename Gemm::Arguments{cutlass::gemm::GemmUniversalMode::kGrouped, + {num_groups, problem_sizes, problem_sizes_host}, + {ptr_A, stride_A, ptr_B, stride_B}, + {fusion_args, ptr_C, stride_C, ptr_D, stride_D}, + hw_info}; + Gemm gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(arguments)); + CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments)); + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream)); + CUTLASS_CHECK(gemm_op.run(stream)); + } +}; + +template +__global__ void prepare_group_gemm_arguments( + const ElementA** ptr_A, const ElementB** ptr_B, ElementC** ptr_D, + typename ProblemShape::UnderlyingProblemShape* problem_sizes, StrideA* stride_A, + StrideB* stride_B, StrideC* stride_D, const ElementA* x, const ElementB* weight, ElementC* out, + int64_t* indptr, int64_t n, int64_t k, int64_t num_groups) { + int group_id = threadIdx.x; + if (group_id >= num_groups) return; + int prev_rows = group_id == 0 ? 0 : indptr[group_id - 1]; + ptr_A[group_id] = x + prev_rows * k; + ptr_B[group_id] = weight + group_id * k * n; + ptr_D[group_id] = out + prev_rows * n; + problem_sizes[group_id] = {static_cast(indptr[group_id] - prev_rows), static_cast(n), + static_cast(k)}; + stride_A[group_id] = cute::make_stride(k, Int<1>{}, Int<0>{}); + stride_B[group_id] = cute::make_stride(k, Int<1>{}, Int<0>{}); + stride_D[group_id] = cute::make_stride(n, Int<1>{}, Int<0>{}); +} + +template +void cutlass_group_gemm_sm100(ElementA* x, ElementB* weight, int64_t* indptr, uint8_t* workspace, + int64_t workspace_size, int64_t n, int64_t k, int64_t num_groups, + std::variant alpha, + std::variant beta, ElementC* out, + cudaStream_t stream) { + // Note: We use MMA2SMConfig for now. It can be changed to MMA1SMConfig if needed. + using Runner = CutlassGroupGemmRunner, ElementA, ElementB, ElementC>; + using StrideA = typename Runner::StrideA; + using StrideB = typename Runner::StrideB; + using StrideC = typename Runner::StrideC; + + Runner runner; + std::ptrdiff_t offset = 0; + const ElementA** ptr_A = reinterpret_cast(workspace + offset); + offset += aligned(sizeof(ElementA*) * num_groups); + const ElementB** ptr_B = reinterpret_cast(workspace + offset); + offset += aligned(sizeof(ElementB*) * num_groups); + ElementC** ptr_D = reinterpret_cast(workspace + offset); + offset += aligned(sizeof(ElementC*) * num_groups); + typename ProblemShape::UnderlyingProblemShape* problem_sizes = + reinterpret_cast(workspace + offset); + offset += aligned(sizeof(typename ProblemShape::UnderlyingProblemShape) * num_groups); + StrideA* stride_A = reinterpret_cast(workspace + offset); + offset += aligned(sizeof(StrideA) * num_groups); + StrideB* stride_B = reinterpret_cast(workspace + offset); + offset += aligned(sizeof(StrideB) * num_groups); + StrideC* stride_D = reinterpret_cast(workspace + offset); + offset += aligned(sizeof(StrideC) * num_groups); + prepare_group_gemm_arguments<<<1, num_groups, 0, stream>>>(ptr_A, ptr_B, ptr_D, problem_sizes, + stride_A, stride_B, stride_D, x, + weight, out, indptr, n, k, num_groups); + offset = aligned(offset, 256); + runner.run_group_gemm(ptr_A, ptr_B, const_cast(ptr_D), ptr_D, problem_sizes, + nullptr, stride_A, stride_B, stride_D, stride_D, workspace + offset, + workspace_size - offset, num_groups, alpha, beta, stream); +} diff --git a/src/runtime/contrib/cutlass/group_gemm_runner.cuh b/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh similarity index 96% rename from src/runtime/contrib/cutlass/group_gemm_runner.cuh rename to src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh index a3c52e27a9d5..38e1beb2b8f4 100644 --- a/src/runtime/contrib/cutlass/group_gemm_runner.cuh +++ b/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh @@ -169,11 +169,11 @@ __global__ void prepare_group_gemm_arguments( } template -void cutlass_group_gemm(ElementA* x, ElementB* weight, int64_t* indptr, uint8_t* workspace, - int64_t workspace_size, int64_t n, int64_t k, int64_t num_groups, - std::variant alpha, - std::variant beta, ElementC* out, - cudaStream_t stream) { +void cutlass_group_gemm_sm90(ElementA* x, ElementB* weight, int64_t* indptr, uint8_t* workspace, + int64_t workspace_size, int64_t n, int64_t k, int64_t num_groups, + std::variant alpha, + std::variant beta, ElementC* out, + cudaStream_t stream) { using Runner = CutlassGroupGemmRunner; using StrideA = typename Runner::StrideA; using StrideB = typename Runner::StrideB; diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm_sm100.cu b/src/runtime/contrib/cutlass/fp16_group_gemm_sm100.cu new file mode 100644 index 000000000000..29efcbe088ae --- /dev/null +++ b/src/runtime/contrib/cutlass/fp16_group_gemm_sm100.cu @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#include "fp16_group_gemm.cuh" +#include "fp16_group_gemm_runner_sm100.cuh" + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +namespace tvm { +namespace runtime { + +template +struct CutlassGroupGemm<100, ElementA, ElementB, ElementC> { + static void run(ElementA* A, ElementB* B, int64_t* indptr, uint8_t* workspace, int workspace_size, + int N, int K, int num_groups, float alpha, float beta, ElementC* C, + cudaStream_t stream) { + cutlass_group_gemm_sm100( + A, B, indptr, workspace, workspace_size, N, K, num_groups, alpha, beta, C, stream); + } +}; + +void tvm_cutlass_group_gemm_sm100(NDArray x, NDArray weight, NDArray indptr, NDArray workspace, + NDArray out) { + tvm_cutlass_group_gemm_impl<100>(x, weight, indptr, workspace, out); +} + +TVM_FFI_REGISTER_GLOBAL("cutlass.group_gemm").set_body_typed(tvm_cutlass_group_gemm_sm100); + +} // namespace runtime +} // namespace tvm + +#endif // CUTLASS_ARCH_MMA_SM100_SUPPORTED diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm.cu b/src/runtime/contrib/cutlass/fp16_group_gemm_sm90.cu similarity index 60% rename from src/runtime/contrib/cutlass/fp16_group_gemm.cu rename to src/runtime/contrib/cutlass/fp16_group_gemm_sm90.cu index dffe7dc4ffed..93a03a0675b2 100644 --- a/src/runtime/contrib/cutlass/fp16_group_gemm.cu +++ b/src/runtime/contrib/cutlass/fp16_group_gemm_sm90.cu @@ -19,14 +19,28 @@ #include #include -#include #include +#include #include -#include "group_gemm_runner.cuh" +#include "fp16_group_gemm.cuh" +#include "fp16_group_gemm_runner_sm90.cuh" + +namespace tvm { +namespace runtime { #if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) +template +struct CutlassGroupGemm<90, ElementA, ElementB, ElementC> { + static void run(ElementA* A, ElementB* B, int64_t* indptr, uint8_t* workspace, int workspace_size, + int N, int K, int num_groups, float alpha, float beta, ElementC* C, + cudaStream_t stream) { + cutlass_group_gemm_sm90(A, B, indptr, workspace, workspace_size, + N, K, num_groups, alpha, beta, C, stream); + } +}; + template <> struct KernelTraits { using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; @@ -34,36 +48,21 @@ struct KernelTraits { using ClusterShape = Shape<_2, _2, _1>; // Shape of the threadblocks in a cluster }; -namespace tvm { -namespace runtime { +template <> +struct KernelTraits { + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; + using TileShape = Shape<_128, _256, _64>; // Threadblock-level tile size + using ClusterShape = Shape<_2, _2, _1>; // Shape of the threadblocks in a cluster +}; -template void tvm_cutlass_group_gemm_sm90(NDArray x, NDArray weight, NDArray indptr, NDArray workspace, NDArray out) { - // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace. - // Recommened size is 4MB. - static auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream"); - cudaStream_t stream = static_cast(func().cast()); - CHECK_EQ(x->ndim, 2); - CHECK_EQ(weight->ndim, 3); - CHECK_EQ(indptr->ndim, 1); - CHECK_EQ(workspace->ndim, 1); - CHECK_EQ(out->ndim, 2); - int num_groups = weight->shape[0]; - int n = weight->shape[1]; - int k = weight->shape[2]; - float alpha = 1.0f; - float beta = 0.0f; - cutlass_group_gemm(static_cast(x->data), static_cast(weight->data), - static_cast(indptr->data), static_cast(workspace->data), - workspace->shape[0], n, k, num_groups, alpha, beta, - static_cast(out->data), stream); + tvm_cutlass_group_gemm_impl<90>(x, weight, indptr, workspace, out); } -TVM_FFI_REGISTER_GLOBAL("cutlass.group_gemm_fp16_sm90") - .set_body_typed(tvm_cutlass_group_gemm_sm90); +TVM_FFI_REGISTER_GLOBAL("cutlass.group_gemm").set_body_typed(tvm_cutlass_group_gemm_sm90); + +#endif // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED } // namespace runtime } // namespace tvm - -#endif // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED diff --git a/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu b/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu deleted file mode 100644 index 5164958afeb5..000000000000 --- a/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include -#include -#include - -#include "../cublas/cublas_utils.h" -#include "blockwise_scaled_gemm_runner.cuh" - -#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) - -namespace tvm { -namespace runtime { - -void tvm_cutlass_fp8_blockwise_scaled_gemm(NDArray a, NDArray b, NDArray scales_a, NDArray scales_b, - NDArray workspace, int64_t block_size_0, - int64_t block_size_1, NDArray out) { - using TileShape = Shape<_128, _128, _128>; - using ClusterShape = Shape<_1, _1, _1>; - - // Workspace is used for storing device-side gemm arguments and cutlass internal workspace. - // Recommened size is 4MB. - const auto get_stream_func = tvm::ffi::Function::GetGlobal("runtime.get_cuda_stream"); - ICHECK(get_stream_func.has_value()); - cudaStream_t stream = static_cast((*get_stream_func)().cast()); - - CHECK_GE(a->ndim, 2); - CHECK_EQ(scales_a->ndim, a->ndim); - CHECK_EQ(b->ndim, 2); - CHECK_EQ(scales_b->ndim, 2); - CHECK_EQ(workspace->ndim, 1); - CHECK_EQ(out->ndim, a->ndim); - int64_t m = 1; - for (int64_t i = 0; i < a->ndim - 1; ++i) { - m *= a->shape[i]; - } - int64_t n = b->shape[0]; - CHECK_EQ(a->shape[a->ndim - 1], b->shape[1]) << "Only col-major B is supported now."; - int64_t k = a->shape[a->ndim - 1]; - - // scales_a is col-major of (*a_shape[:-1], k / block_size) - CHECK_EQ(scales_a->shape[0] * block_size_1, k); - for (int64_t i = 1; i < scales_a->ndim; ++i) { - CHECK_EQ(scales_a->shape[i], a->shape[i - 1]); - } - // scales_b is col-major of (k / block_size, n / block_size) - CHECK_EQ(scales_b->shape[0] * block_size_0, n); - CHECK_EQ(scales_b->shape[1] * block_size_1, k); - - using tvm::runtime::DataType; - CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN()); - CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN()); - CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32)); - CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32)); - CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8)); - - if (DataType(out->dtype) == DataType::Float(16)) { - cutlass_fp8_blockwise_scaled_gemm( - static_cast(a->data), static_cast(b->data), - static_cast(scales_a->data), static_cast(scales_b->data), - static_cast(out->data), static_cast(workspace->data), - workspace->shape[0] * DataType(workspace->dtype).bytes(), m, n, k, stream); - } else if (DataType(out->dtype) == DataType::BFloat(16)) { - cutlass_fp8_blockwise_scaled_gemm( - static_cast(a->data), static_cast(b->data), - static_cast(scales_a->data), static_cast(scales_b->data), - static_cast(out->data), static_cast(workspace->data), - workspace->shape[0] * DataType(workspace->dtype).bytes(), m, n, k, stream); - } else { - LOG(FATAL) << "Unsupported output dtype: " << DataType(out->dtype); - } -} - -void tvm_cutlass_fp8_blockwise_scaled_bmm(NDArray a, NDArray b, NDArray scales_a, NDArray scales_b, - NDArray workspace, int64_t block_size_0, - int64_t block_size_1, NDArray out) { - using TileShape = Shape<_128, _128, _128>; - using ClusterShape = Shape<_1, _1, _1>; - - // Workspace is used for storing device-side gemm arguments and cutlass internal workspace. - // Recommened size is 4MB. - const auto get_stream_func = tvm::ffi::Function::GetGlobal("runtime.get_cuda_stream"); - ICHECK(get_stream_func.has_value()); - cudaStream_t stream = static_cast((*get_stream_func)().cast()); - - CHECK_EQ(a->ndim, 3); - CHECK_EQ(scales_a->ndim, 3); - CHECK_EQ(b->ndim, 3); - CHECK_EQ(scales_b->ndim, 3); - CHECK_EQ(workspace->ndim, 1); - CHECK_EQ(out->ndim, 3); - int64_t batch_size = a->shape[0]; - int64_t m = a->shape[1]; - int64_t n = b->shape[1]; - CHECK_EQ(a->shape[2], b->shape[2]) << "Only col-major B is supported now."; - int64_t k = a->shape[2]; - CHECK_EQ(b->shape[0], batch_size); - CHECK_EQ(scales_a->shape[0], batch_size); - CHECK_EQ(scales_b->shape[0], batch_size); - CHECK_EQ(out->shape[0], batch_size); - - // scales_a is col-major of (batch_size, m, k / block_size) - CHECK_EQ(scales_a->shape[1] * block_size_1, k); - CHECK_EQ(scales_a->shape[2], m); - // scales_b is col-major of (k / block_size, n / block_size) - CHECK_EQ(scales_b->shape[1] * block_size_0, n); - CHECK_EQ(scales_b->shape[2] * block_size_1, k); - - using tvm::runtime::DataType; - CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN()); - CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN()); - CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32)); - CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32)); - CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8)); - - if (DataType(out->dtype) == DataType::Float(16)) { - cutlass_fp8_blockwise_scaled_bmm( - static_cast(a->data), static_cast(b->data), - static_cast(scales_a->data), static_cast(scales_b->data), - static_cast(out->data), static_cast(workspace->data), - workspace->shape[0] * DataType(workspace->dtype).bytes(), m, n, k, batch_size, stream); - } else if (DataType(out->dtype) == DataType::BFloat(16)) { - cutlass_fp8_blockwise_scaled_bmm( - static_cast(a->data), static_cast(b->data), - static_cast(scales_a->data), static_cast(scales_b->data), - static_cast(out->data), static_cast(workspace->data), - workspace->shape[0] * DataType(workspace->dtype).bytes(), m, n, k, batch_size, stream); - } else { - LOG(FATAL) << "Unsupported output dtype: " << DataType(out->dtype); - } -} - -TVM_FFI_REGISTER_GLOBAL("cutlass.blockwise_scaled_gemm_e4m3fn_e4m3fn") - .set_body_typed(tvm_cutlass_fp8_blockwise_scaled_gemm); -TVM_FFI_REGISTER_GLOBAL("cutlass.blockwise_scaled_bmm_e4m3fn_e4m3fn") - .set_body_typed(tvm_cutlass_fp8_blockwise_scaled_bmm); - -} // namespace runtime -} // namespace tvm - -#endif // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED diff --git a/src/runtime/contrib/cutlass/fp8_group_gemm.cu b/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu similarity index 98% rename from src/runtime/contrib/cutlass/fp8_group_gemm.cu rename to src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu index 62a91dec1809..686a6ebcffeb 100644 --- a/src/runtime/contrib/cutlass/fp8_group_gemm.cu +++ b/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu @@ -23,7 +23,7 @@ #include #include -#include "group_gemm_runner.cuh" +#include "fp16_group_gemm_runner_sm90.cuh" #if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh new file mode 100644 index 000000000000..4ecca5f1d8a9 --- /dev/null +++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +#include "cutlass/bfloat16.h" +#include "cutlass/half.h" + +namespace tvm { +namespace runtime { + +template +struct CutlassFP8GroupwiseGemm; + +template +void tvm_cutlass_fp8_groupwise_scaled_gemm_impl(NDArray a, NDArray b, NDArray scales_a, + NDArray scales_b, NDArray workspace, + int64_t block_size_0, int64_t block_size_1, + NDArray out) { + // Workspace is used for storing device-side gemm arguments and cutlass internal workspace. + // Recommened size is 4MB. + static tvm::ffi::Function get_stream_func = + tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream"); + cudaStream_t stream = static_cast(get_stream_func().cast()); + + CHECK_GE(a->ndim, 2); + CHECK_EQ(scales_a->ndim, a->ndim); + CHECK_EQ(b->ndim, 2); + CHECK_EQ(scales_b->ndim, 2); + CHECK_EQ(workspace->ndim, 1); + CHECK_EQ(out->ndim, a->ndim); + int64_t m = 1; + for (int64_t i = 0; i < a->ndim - 1; ++i) { + m *= a->shape[i]; + } + int64_t n = b->shape[0]; + CHECK_EQ(a->shape[a->ndim - 1], b->shape[1]) << "Only col-major B is supported now."; + int64_t k = a->shape[a->ndim - 1]; + + // scales_a is col-major of (*a_shape[:-1], k / block_size) + CHECK_EQ(scales_a->shape[0] * block_size_1, k); + for (int64_t i = 1; i < scales_a->ndim; ++i) { + CHECK_EQ(scales_a->shape[i], a->shape[i - 1]); + } + // scales_b is col-major of (k / block_size, n / block_size) + CHECK_EQ((n + block_size_0 - 1) / block_size_0, scales_b->shape[0]); + CHECK_EQ(scales_b->shape[1] * block_size_1, k); + + using tvm::runtime::DataType; + CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN()); + CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN()); + CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32)); + CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32)); + CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8)); + + if (DataType(out->dtype) == DataType::Float(16)) { + CutlassFP8GroupwiseGemm::run(static_cast(a->data), + static_cast(b->data), + static_cast(scales_a->data), + static_cast(scales_b->data), + static_cast(out->data), + static_cast(workspace->data), + workspace->shape[0] * DataType(workspace->dtype).bytes(), m, + n, k, 1, stream); + } else if (DataType(out->dtype) == DataType::BFloat(16)) { + CutlassFP8GroupwiseGemm::run(static_cast(a->data), + static_cast(b->data), + static_cast(scales_a->data), + static_cast(scales_b->data), + static_cast(out->data), + static_cast(workspace->data), + workspace->shape[0] * DataType(workspace->dtype).bytes(), m, + n, k, 1, stream); + } else { + LOG(FATAL) << "Unsupported output dtype: " << DataType(out->dtype); + } +} + +template +void tvm_cutlass_fp8_groupwise_scaled_bmm_impl(NDArray a, NDArray b, NDArray scales_a, + NDArray scales_b, NDArray workspace, + int64_t block_size_0, int64_t block_size_1, + NDArray out) { + // Workspace is used for storing device-side gemm arguments and cutlass internal workspace. + // Recommened size is 4MB. + static tvm::ffi::Function get_stream_func = + tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream"); + cudaStream_t stream = static_cast(get_stream_func().cast()); + + CHECK_EQ(a->ndim, 3); + CHECK_EQ(scales_a->ndim, 3); + CHECK_EQ(b->ndim, 3); + CHECK_EQ(scales_b->ndim, 3); + CHECK_EQ(workspace->ndim, 1); + CHECK_EQ(out->ndim, 3); + int64_t batch_size = a->shape[0]; + int64_t m = a->shape[1]; + int64_t n = b->shape[1]; + CHECK_EQ(a->shape[2], b->shape[2]) << "Only col-major B is supported now."; + int64_t k = a->shape[2]; + CHECK_EQ(b->shape[0], batch_size); + CHECK_EQ(scales_a->shape[0], batch_size); + CHECK_EQ(scales_b->shape[0], batch_size); + CHECK_EQ(out->shape[0], batch_size); + + // scales_a is col-major of (batch_size, m, k / block_size) + CHECK_EQ(scales_a->shape[1] * block_size_1, k); + CHECK_EQ(scales_a->shape[2], m); + // scales_b is col-major of (k / block_size, n / block_size) + CHECK_EQ(scales_b->shape[1] * block_size_0, n); + CHECK_EQ(scales_b->shape[2] * block_size_1, k); + + using tvm::runtime::DataType; + CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN()); + CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN()); + CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32)); + CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32)); + CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8)); + + if (DataType(out->dtype) == DataType::Float(16)) { + CutlassFP8GroupwiseGemm::run(static_cast(a->data), + static_cast(b->data), + static_cast(scales_a->data), + static_cast(scales_b->data), + static_cast(out->data), + static_cast(workspace->data), + workspace->shape[0] * DataType(workspace->dtype).bytes(), m, + n, k, batch_size, stream); + } else if (DataType(out->dtype) == DataType::BFloat(16)) { + CutlassFP8GroupwiseGemm::run(static_cast(a->data), + static_cast(b->data), + static_cast(scales_a->data), + static_cast(scales_b->data), + static_cast(out->data), + static_cast(workspace->data), + workspace->shape[0] * DataType(workspace->dtype).bytes(), m, + n, k, batch_size, stream); + } else { + LOG(FATAL) << "Unsupported output dtype: " << DataType(out->dtype); + } +} + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh new file mode 100644 index 000000000000..95fc578fd43f --- /dev/null +++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include + +#include "../../cuda/cuda_common.h" + +// clang-format off +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +// clang-format on + +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + CHECK(error == cutlass::Status::kSuccess) \ + << "Got cutlass error: " << cutlassGetStatusString(error); \ + } + +using namespace cute; +using tvm::runtime::NDArray; + +template +struct CutlassFP8ScaledGroupwiseGemmRunnerSM100 { + using ElementA = cutlass::float_e4m3_t; + using LayoutA = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + + using ElementB = cutlass::float_e4m3_t; + using LayoutB = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + using ElementC = void; + using LayoutC = cutlass::layout::RowMajor; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using LayoutD = LayoutC; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + // MMA type + using ElementAccumulator = float; // Element Accumulator will also be our scale factor type + using ElementCompute = float; + using ElementBlockScale = float; + + static constexpr int ScaleGranularityM = 1; + static constexpr int ScaleGranularityN = 128; + static constexpr int ScaleGranularityK = 128; + using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig< + ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, UMMA::Major::MN, UMMA::Major::K>; + + using LayoutSFA = + decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand + using LayoutSFB = + decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementCompute, ElementC, + LayoutC, AlignmentC, ElementD, LayoutC, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementA, + cute::tuple, AlignmentA, ElementB, cute::tuple, + AlignmentB, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelScheduleSm100Blockwise>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, CollectiveMainloop, CollectiveEpilogue, + void>; // Default to ClusterLaunchControl (CLC) based tile scheduler + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideD = typename Gemm::GemmKernel::StrideD; + + void run_gemm(const ElementA* a_ptr, const ElementB* b_ptr, const ElementBlockScale* scales_a_ptr, + const ElementBlockScale* scales_b_ptr, ElementD* o_ptr, int m, int n, int k, int l, + uint8_t* workspace, int64_t workspace_size, cudaStream_t stream) { + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + StrideA stride_a = + cute::make_stride(static_cast(k), Int<1>{}, static_cast(m * k)); + StrideB stride_b = + cute::make_stride(static_cast(k), Int<1>{}, static_cast(n * k)); + StrideD stride_d = + cute::make_stride(static_cast(n), Int<1>{}, static_cast(m * n)); + auto layout_scales_a = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, l)); + auto layout_scales_b = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, l)); + + typename Gemm::Arguments arguments = {cutlass::gemm::GemmUniversalMode::kGemm, + {m, n, k, l}, + {a_ptr, stride_a, b_ptr, stride_b, scales_a_ptr, + layout_scales_a, scales_b_ptr, layout_scales_b}, + {{}, o_ptr, stride_d, o_ptr, stride_d}, + hw_info}; + + Gemm gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(arguments)); + CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments)); + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream)); + CUTLASS_CHECK(gemm_op.run(stream)); + } +}; + +template +void cutlass_fp8_groupwise_scaled_mm_sm100(ElementA* a, ElementB* b, ElementBlockScale* scales_a, + ElementBlockScale* scales_b, ElementD* out, + uint8_t* workspace, int64_t workspace_size, int64_t m, + int64_t n, int64_t k, int64_t l, cudaStream_t stream) { + using Runner = CutlassFP8ScaledGroupwiseGemmRunnerSM100; + Runner runner; + runner.run_gemm(a, b, scales_a, scales_b, out, m, n, k, l, workspace, workspace_size, stream); +} diff --git a/src/runtime/contrib/cutlass/blockwise_scaled_gemm_runner.cuh b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh similarity index 75% rename from src/runtime/contrib/cutlass/blockwise_scaled_gemm_runner.cuh rename to src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh index f520bf815a94..5ec9ed083916 100644 --- a/src/runtime/contrib/cutlass/blockwise_scaled_gemm_runner.cuh +++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh @@ -58,7 +58,7 @@ using tvm::runtime::NDArray; template -struct CutlassFP8ScaledBlockwiseGemmRunner { +struct CutlassFP8GroupwiseScaledGemmRunner { using ElementAccumulator = float; using ElementCompute = float; using ElementBlockScale = float; @@ -149,53 +149,14 @@ struct CutlassFP8ScaledBlockwiseGemmRunner { template -void cutlass_fp8_blockwise_scaled_gemm(ElementA* a, ElementB* b, ElementBlockScale* scales_a, - ElementBlockScale* scales_b, ElementD* out, - uint8_t* workspace, int64_t workspace_size, int64_t m, - int64_t n, int64_t k, cudaStream_t stream) { +void cutlass_fp8_groupwise_scaled_mm_sm90(ElementA* a, ElementB* b, ElementBlockScale* scales_a, + ElementBlockScale* scales_b, ElementD* out, + uint8_t* workspace, int64_t workspace_size, int64_t m, + int64_t n, int64_t k, int64_t l, cudaStream_t stream) { if (k > 3 * n) { using SchedulerType = cutlass::gemm::StreamKScheduler; using Runner = - CutlassFP8ScaledBlockwiseGemmRunner; - using StrideA = typename Runner::StrideA; - using StrideB = typename Runner::StrideB; - using StrideD = typename Runner::StrideD; - - Runner runner; - StrideA stride_a = cute::make_stride(k, Int<1>{}, m * k); - StrideB stride_b = cute::make_stride(k, Int<1>{}, n * k); - StrideD stride_d = cute::make_stride(n, Int<1>{}, m * n); - ProblemShape problem_size{static_cast(m), static_cast(n), static_cast(k), 1}; - runner.run_gemm(a, b, scales_a, scales_b, out, &problem_size, &stride_a, &stride_b, &stride_d, - workspace, workspace_size, stream); - } else { - using SchedulerType = cutlass::gemm::PersistentScheduler; - using Runner = - CutlassFP8ScaledBlockwiseGemmRunner; - using StrideA = typename Runner::StrideA; - using StrideB = typename Runner::StrideB; - using StrideD = typename Runner::StrideD; - - Runner runner; - StrideA stride_a = cute::make_stride(k, Int<1>{}, m * k); - StrideB stride_b = cute::make_stride(k, Int<1>{}, n * k); - StrideD stride_d = cute::make_stride(n, Int<1>{}, m * n); - ProblemShape problem_size{static_cast(m), static_cast(n), static_cast(k), 1}; - runner.run_gemm(a, b, scales_a, scales_b, out, &problem_size, &stride_a, &stride_b, &stride_d, - workspace, workspace_size, stream); - } -} - -template -void cutlass_fp8_blockwise_scaled_bmm(ElementA* a, ElementB* b, ElementBlockScale* scales_a, - ElementBlockScale* scales_b, ElementD* out, - uint8_t* workspace, int64_t workspace_size, int64_t m, - int64_t n, int64_t k, int64_t l, cudaStream_t stream) { - if (k > 3 * n) { - using SchedulerType = cutlass::gemm::StreamKScheduler; - using Runner = - CutlassFP8ScaledBlockwiseGemmRunner; + CutlassFP8GroupwiseScaledGemmRunner; using StrideA = typename Runner::StrideA; using StrideB = typename Runner::StrideB; using StrideD = typename Runner::StrideD; @@ -211,7 +172,7 @@ void cutlass_fp8_blockwise_scaled_bmm(ElementA* a, ElementB* b, ElementBlockScal } else { using SchedulerType = cutlass::gemm::PersistentScheduler; using Runner = - CutlassFP8ScaledBlockwiseGemmRunner; + CutlassFP8GroupwiseScaledGemmRunner; using StrideA = typename Runner::StrideA; using StrideB = typename Runner::StrideB; using StrideD = typename Runner::StrideD; diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm100.cu b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm100.cu new file mode 100644 index 000000000000..ffa3ae6653e6 --- /dev/null +++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm100.cu @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#include "../cublas/cublas_utils.h" +#include "fp8_groupwise_scaled_gemm.cuh" +#include "fp8_groupwise_scaled_gemm_runner_sm100.cuh" + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +namespace tvm { +namespace runtime { + +template +struct CutlassFP8GroupwiseGemm<100, TileShape, ClusterShape, ElementA, ElementB, ElementC, + ElementBlockScale> { + static void run(ElementA* a, ElementB* b, ElementBlockScale* scales_a, + ElementBlockScale* scales_b, ElementC* out, uint8_t* workspace, + int64_t workspace_size, int64_t m, int64_t n, int64_t k, int64_t l, + cudaStream_t stream) { + cutlass_fp8_groupwise_scaled_mm_sm100( + a, b, scales_a, scales_b, out, workspace, workspace_size, m, n, k, l, stream); + } +}; + +void tvm_cutlass_fp8_groupwise_scaled_gemm_sm100(NDArray a, NDArray b, NDArray scales_a, + NDArray scales_b, NDArray workspace, + int64_t block_size_0, int64_t block_size_1, + NDArray out) { + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + tvm_cutlass_fp8_groupwise_scaled_gemm_impl<100, TileShape, ClusterShape>( + a, b, scales_a, scales_b, workspace, block_size_0, block_size_1, out); +} + +void tvm_cutlass_fp8_groupwise_scaled_bmm_sm100(NDArray a, NDArray b, NDArray scales_a, + NDArray scales_b, NDArray workspace, + int64_t block_size_0, int64_t block_size_1, + NDArray out) { + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + tvm_cutlass_fp8_groupwise_scaled_bmm_impl<100, TileShape, ClusterShape>( + a, b, scales_a, scales_b, workspace, block_size_0, block_size_1, out); +} + +TVM_FFI_REGISTER_GLOBAL("cutlass.groupwise_scaled_gemm_e4m3fn_e4m3fn") + .set_body_typed(tvm_cutlass_fp8_groupwise_scaled_gemm_sm100); +TVM_FFI_REGISTER_GLOBAL("cutlass.groupwise_scaled_bmm_e4m3fn_e4m3fn") + .set_body_typed(tvm_cutlass_fp8_groupwise_scaled_bmm_sm100); + +} // namespace runtime +} // namespace tvm + +#endif // CUTLASS_ARCH_MMA_SM100_SUPPORTED diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm90.cu b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm90.cu new file mode 100644 index 000000000000..e445e97da364 --- /dev/null +++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm90.cu @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#include "../cublas/cublas_utils.h" +#include "fp8_groupwise_scaled_gemm.cuh" +#include "fp8_groupwise_scaled_gemm_runner_sm90.cuh" + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +namespace tvm { +namespace runtime { + +template +struct CutlassFP8GroupwiseGemm<90, TileShape, ClusterShape, ElementA, ElementB, ElementC, + ElementBlockScale> { + static void run(ElementA* a, ElementB* b, ElementBlockScale* scales_a, + ElementBlockScale* scales_b, ElementC* out, uint8_t* workspace, + int64_t workspace_size, int64_t m, int64_t n, int64_t k, int64_t l, + cudaStream_t stream) { + cutlass_fp8_groupwise_scaled_mm_sm90( + a, b, scales_a, scales_b, out, workspace, workspace_size, m, n, k, l, stream); + } +}; + +void tvm_cutlass_fp8_groupwise_scaled_gemm_sm90(NDArray a, NDArray b, NDArray scales_a, + NDArray scales_b, NDArray workspace, + int64_t block_size_0, int64_t block_size_1, + NDArray out) { + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + tvm_cutlass_fp8_groupwise_scaled_gemm_impl<90, TileShape, ClusterShape>( + a, b, scales_a, scales_b, workspace, block_size_0, block_size_1, out); +} + +void tvm_cutlass_fp8_groupwise_scaled_bmm_sm90(NDArray a, NDArray b, NDArray scales_a, + NDArray scales_b, NDArray workspace, + int64_t block_size_0, int64_t block_size_1, + NDArray out) { + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + tvm_cutlass_fp8_groupwise_scaled_bmm_impl<90, TileShape, ClusterShape>( + a, b, scales_a, scales_b, workspace, block_size_0, block_size_1, out); +} + +TVM_FFI_REGISTER_GLOBAL("cutlass.groupwise_scaled_gemm_e4m3fn_e4m3fn") + .set_body_typed(tvm_cutlass_fp8_groupwise_scaled_gemm_sm90); +TVM_FFI_REGISTER_GLOBAL("cutlass.groupwise_scaled_bmm_e4m3fn_e4m3fn") + .set_body_typed(tvm_cutlass_fp8_groupwise_scaled_bmm_sm90); + +} // namespace runtime +} // namespace tvm + +#endif // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_runner_sm100.cuh b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_runner_sm100.cuh new file mode 100644 index 000000000000..19c6b699aa95 --- /dev/null +++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_runner_sm100.cuh @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#include "../../cuda/cuda_common.h" + +// clang-format off +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +// clang-format on + +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + CHECK(error == cutlass::Status::kSuccess) \ + << "Got cutlass error: " << cutlassGetStatusString(error); \ + } + +using namespace cute; +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group + +inline size_t aligned(size_t value, size_t alignment = 16) { + return (value + alignment - 1) / alignment * alignment; +} + +template +struct CutlassFP8ScaledGroupwiseGroupGemmRunnerSM100 { + using LayoutA = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + + using LayoutB = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + using LayoutC = cutlass::layout::RowMajor; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using ElementAccumulator = float; + using ElementCompute = float; + + static constexpr int ScaleGranularityM = 1; + static constexpr int ScaleGranularityN = 128; + static constexpr int ScaleGranularityK = 128; + using ScaleConfig = + cutlass::detail::Sm100BlockwiseScaleConfig; + + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise2SmSm100; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementCompute, ElementC, + LayoutC*, AlignmentC, ElementC, LayoutC*, AlignmentC, EpilogueSchedule>::CollectiveOp; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementA, + cute::tuple, AlignmentA, ElementB, cute::tuple, + AlignmentB, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + using GemmKernel = cutlass::gemm::kernel::GemmUniversal; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + void run_group_gemm(const ElementA** ptr_A, const ElementB** ptr_B, + const ElementBlockScale** ptr_scales_a, + const ElementBlockScale** ptr_scales_b, const ElementC** ptr_C, + ElementC** ptr_D, + typename ProblemShape::UnderlyingProblemShape* problem_sizes, + typename ProblemShape::UnderlyingProblemShape* problem_sizes_host, + StrideA* stride_A, StrideB* stride_B, LayoutSFA* layout_scales_a, + LayoutSFB* layout_scales_b, StrideC* stride_C, StrideD* stride_D, + uint8_t* workspace, int64_t workspace_size, int num_groups, + cudaStream_t stream) { + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGrouped, + {num_groups, problem_sizes, problem_sizes_host}, + {ptr_A, stride_A, ptr_B, stride_B, ptr_scales_a, + layout_scales_a, ptr_scales_b, layout_scales_b}, + {{}, ptr_C, stride_C, ptr_D, stride_D}, + hw_info}; + auto& fusion_args = arguments.epilogue.thread; + fusion_args.alpha = 1.0f; + fusion_args.beta = 0.0f; + + Gemm gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(arguments)); + CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments)); + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream)); + CUTLASS_CHECK(gemm_op.run(stream)); + } +}; + +template +__global__ void prepare_group_gemm_arguments( + const ElementA** ptr_A, const ElementB** ptr_B, const ElementBlockScale** ptr_scales_a, + const ElementBlockScale** ptr_scales_b, ElementC** ptr_D, + typename ProblemShape::UnderlyingProblemShape* problem_sizes, StrideA* stride_A, + StrideB* stride_B, LayoutSFA* layout_scales_a, LayoutSFB* layout_scales_b, StrideC* stride_D, + const ElementA* a, const ElementB* b, const ElementBlockScale* scales_a, + const ElementBlockScale* scales_b, ElementC* out, int64_t* indptr, int64_t n, int64_t k, + int num_groups) { + int group_id = threadIdx.x; + if (group_id >= num_groups) return; + int prev_rows = group_id == 0 ? 0 : indptr[group_id - 1]; + ptr_A[group_id] = a + prev_rows * k; + ptr_B[group_id] = b + group_id * k * n; + ptr_D[group_id] = out + prev_rows * n; + ptr_scales_a[group_id] = scales_a + prev_rows * ((k + 127) / 128); + ptr_scales_b[group_id] = scales_b + group_id * ((k + 127) / 128) * ((n + 127) / 128); + int64_t m = indptr[group_id] - prev_rows; + problem_sizes[group_id] = {static_cast(m), static_cast(n), static_cast(k)}; + stride_A[group_id] = cute::make_stride(k, Int<1>{}, Int<0>{}); + stride_B[group_id] = cute::make_stride(k, Int<1>{}, Int<0>{}); + stride_D[group_id] = cute::make_stride(n, Int<1>{}, Int<0>{}); + layout_scales_a[group_id] = ScaleConfig::tile_atom_to_shape_SFA( + make_shape(static_cast(m), static_cast(n), static_cast(k), 1)); + layout_scales_b[group_id] = ScaleConfig::tile_atom_to_shape_SFB( + make_shape(static_cast(m), static_cast(n), static_cast(k), 1)); +} + +template +void cutlass_fp8_groupwise_scaled_group_gemm_sm100( + ElementA* a, ElementB* b, const ElementBlockScale* scales_a, const ElementBlockScale* scales_b, + int64_t* indptr, uint8_t* workspace, int64_t workspace_size, int64_t n, int64_t k, + int64_t num_groups, ElementC* out, cudaStream_t stream) { + using TileShape = Shape<_256, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using Runner = + CutlassFP8ScaledGroupwiseGroupGemmRunnerSM100; + using ScaleConfig = typename Runner::ScaleConfig; + using StrideA = typename Runner::StrideA; + using StrideB = typename Runner::StrideB; + using StrideC = typename Runner::StrideC; + using LayoutSFA = typename Runner::LayoutSFA; + using LayoutSFB = typename Runner::LayoutSFB; + + Runner runner; + std::ptrdiff_t offset = 0; + const ElementA** ptr_A = reinterpret_cast(workspace + offset); + offset += aligned(sizeof(ElementA*) * num_groups); + const ElementB** ptr_B = reinterpret_cast(workspace + offset); + offset += aligned(sizeof(ElementB*) * num_groups); + const ElementBlockScale** ptr_scales_a = + reinterpret_cast(workspace + offset); + offset += aligned(sizeof(ElementBlockScale*) * num_groups); + const ElementBlockScale** ptr_scales_b = + reinterpret_cast(workspace + offset); + offset += aligned(sizeof(ElementBlockScale*) * num_groups); + ElementC** ptr_D = reinterpret_cast(workspace + offset); + offset += aligned(sizeof(ElementC*) * num_groups); + typename ProblemShape::UnderlyingProblemShape* problem_sizes = + reinterpret_cast(workspace + offset); + offset += aligned(sizeof(typename ProblemShape::UnderlyingProblemShape) * num_groups); + StrideA* stride_A = reinterpret_cast(workspace + offset); + offset += aligned(sizeof(StrideA) * num_groups); + StrideB* stride_B = reinterpret_cast(workspace + offset); + offset += aligned(sizeof(StrideB) * num_groups); + StrideC* stride_D = reinterpret_cast(workspace + offset); + offset += aligned(sizeof(StrideC) * num_groups); + LayoutSFA* layout_scales_a = reinterpret_cast(workspace + offset); + offset += aligned(sizeof(LayoutSFA) * num_groups); + LayoutSFB* layout_scales_b = reinterpret_cast(workspace + offset); + offset += aligned(sizeof(LayoutSFB) * num_groups); + prepare_group_gemm_arguments + <<<1, num_groups, 0, stream>>>(ptr_A, ptr_B, ptr_scales_a, ptr_scales_b, ptr_D, problem_sizes, + stride_A, stride_B, layout_scales_a, layout_scales_b, stride_D, + a, b, scales_a, scales_b, out, indptr, n, k, num_groups); + offset = aligned(offset, 256); + runner.run_group_gemm(ptr_A, ptr_B, ptr_scales_a, ptr_scales_b, + const_cast(ptr_D), ptr_D, problem_sizes, nullptr, + stride_A, stride_B, layout_scales_a, layout_scales_b, stride_D, stride_D, + workspace + offset, workspace_size - offset, num_groups, stream); +} diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu new file mode 100644 index 000000000000..d13481e9dd3f --- /dev/null +++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#include "fp8_groupwise_scaled_group_gemm_runner_sm100.cuh" + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +namespace tvm { +namespace runtime { + +void tvm_fp8_groupwise_scaled_group_gemm_sm100(NDArray a, NDArray b, NDArray scales_a, + NDArray scales_b, NDArray indptr, NDArray workspace, + int64_t block_size_0, int64_t block_size_1, + NDArray out) { + // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace. + // Recommended size is 4MB. + static auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream"); + cudaStream_t stream = static_cast(func().cast()); + CHECK_EQ(a->ndim, 2); + CHECK_EQ(b->ndim, 3); + CHECK_EQ(indptr->ndim, 1); + CHECK_EQ(workspace->ndim, 1); + CHECK_EQ(out->ndim, 2); + int num_groups = b->shape[0]; + int n = b->shape[1]; + int k = b->shape[2]; + + CHECK_EQ(scales_a->ndim, a->ndim); + CHECK_EQ(scales_b->ndim, b->ndim); + // scales_a is row-major of (m, k / block_size) + CHECK_EQ((k + block_size_1 - 1) / block_size_1, scales_a->shape[1]); + CHECK_EQ(scales_a->shape[0], a->shape[0]); + // scales_b is col-major of (k / block_size, n / block_size) + CHECK_EQ(scales_b->shape[0], num_groups); + CHECK_EQ((n + block_size_0 - 1) / block_size_0, scales_b->shape[1]); + CHECK_EQ((k + block_size_1 - 1) / block_size_1, scales_b->shape[2]); + + using tvm::runtime::DataType; + CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN()); + CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN()); + CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32)); + CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32)); + CHECK_EQ(DataType(indptr->dtype), DataType::Int(64)); + CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8)); + + if (DataType(out->dtype) == DataType::Float(16)) { + using Dtype = cutlass::half_t; + cutlass_fp8_groupwise_scaled_group_gemm_sm100( + static_cast(a->data), static_cast(b->data), + static_cast(scales_a->data), static_cast(scales_b->data), + static_cast(indptr->data), static_cast(workspace->data), + workspace->shape[0], n, k, num_groups, static_cast(out->data), stream); + } else if (DataType(out->dtype) == DataType::BFloat(16)) { + using Dtype = cutlass::bfloat16_t; + cutlass_fp8_groupwise_scaled_group_gemm_sm100( + static_cast(a->data), static_cast(b->data), + static_cast(scales_a->data), static_cast(scales_b->data), + static_cast(indptr->data), static_cast(workspace->data), + workspace->shape[0], n, k, num_groups, static_cast(out->data), stream); + } +} + +TVM_FFI_REGISTER_GLOBAL("cutlass.groupwise_scaled_group_gemm_e4m3fn_e4m3fn") + .set_body_typed(tvm_fp8_groupwise_scaled_group_gemm_sm100); + +} // namespace runtime +} // namespace tvm + +#endif // CUTLASS_ARCH_MMA_SM100_SUPPORTED diff --git a/src/target/tag.cc b/src/target/tag.cc index f6e2307b75e1..0df0d8d2c7af 100644 --- a/src/target/tag.cc +++ b/src/target/tag.cc @@ -161,6 +161,8 @@ TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a100", "sm_80", 49152, 65536) .with_config("l2_cache_size_bytes", 41943040); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-h100", "sm_90a", 49152, 65536) .with_config("l2_cache_size_bytes", 52428800); +TVM_REGISTER_CUDA_TAG("nvidia/nvidia-b100", "sm_100a", 49152, 65536) + .with_config("l2_cache_size_bytes", 52428800); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a40", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a30", "sm_80", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a10", "sm_86", 49152, 65536); diff --git a/tests/python/contrib/test_cutlass_gemm.py b/tests/python/contrib/test_cutlass_gemm.py index 7c259e6f7d6d..33f7ef1160a1 100644 --- a/tests/python/contrib/test_cutlass_gemm.py +++ b/tests/python/contrib/test_cutlass_gemm.py @@ -44,8 +44,8 @@ def verify_group_gemm( def get_ref_data(): assert M % num_groups == 0 M_per_group = M // num_groups - a_np = get_random_ndarray((M, K), "float16") - b_np = get_random_ndarray((num_groups, N, K), "float16") + a_np = get_random_ndarray((M, K), x_dtype) + b_np = get_random_ndarray((num_groups, N, K), weight_dtype) indptr_np = np.arange(1, num_groups + 1).astype("int64") * M_per_group c_np = np.concatenate( [a_np[i * M_per_group : (i + 1) * M_per_group] @ b_np[i].T for i in range(num_groups)], @@ -76,7 +76,7 @@ def to_numpy_dtype(dtype): @tvm.testing.requires_cuda_compute_version(9) def test_group_gemm_sm90(): verify_group_gemm( - "cutlass.group_gemm_fp16_sm90", + "cutlass.group_gemm", 8, 128, 128, @@ -116,6 +116,24 @@ def test_group_gemm_sm90(): ) +@tvm.testing.requires_cutlass +@tvm.testing.requires_cuda_compute_version(10) +def test_group_gemm_sm100(): + verify_group_gemm( + "cutlass.group_gemm", + 8, + 128, + 128, + 4, + "bfloat16", + "bfloat16", + "bfloat16", + False, + rtol=1e-2, + atol=1e-3, + ) + + def rowwise_quant_fp8_e4m3(shape: Tuple[int, int], block_size: Tuple[int, int], dtype: str): x_full_np = (np.random.rand(*shape) * 2 - 1).astype(dtype) x_scale_shape = ( @@ -283,14 +301,14 @@ def blockwise_bmm( @tvm.testing.requires_cutlass @tvm.testing.requires_cuda_compute_version(9) -def test_fp8_e4m3_blockwise_scaled_gemm(): +def test_fp8_e4m3_groupwise_scaled_gemm(): M = 16 N = 4608 K = 896 block_size = (128, 128) assert N % 128 == 0 and K % 128 == 0 # Only support N/K are multiple of 128 - func_name = "cutlass.blockwise_scaled_gemm_e4m3fn_e4m3fn" + func_name = "cutlass.groupwise_scaled_gemm_e4m3fn_e4m3fn" gemm_func = tvm.get_global_func(func_name, allow_missing=True) if gemm_func is None: print(f"Skipped as {func_name} is not available") @@ -316,7 +334,7 @@ def test_fp8_e4m3_blockwise_scaled_gemm(): @tvm.testing.requires_cutlass @tvm.testing.requires_cuda_compute_version(9) -def test_fp8_e4m3_blockwise_scaled_bmm(): +def test_fp8_e4m3_groupwise_scaled_bmm(): B = 16 M = 40 N = 512 @@ -324,7 +342,7 @@ def test_fp8_e4m3_blockwise_scaled_bmm(): block_size = (128, 128) assert N % 128 == 0 and K % 128 == 0 # Only support N/K are multiple of 128 - func_name = "cutlass.blockwise_scaled_bmm_e4m3fn_e4m3fn" + func_name = "cutlass.groupwise_scaled_bmm_e4m3fn_e4m3fn" gemm_func = tvm.get_global_func(func_name, allow_missing=True) if gemm_func is None: print(f"Skipped as {func_name} is not available")