diff --git a/3rdparty/cutlass b/3rdparty/cutlass
index bbe579a9e3be..afa177220367 160000
--- a/3rdparty/cutlass
+++ b/3rdparty/cutlass
@@ -1 +1 @@
-Subproject commit bbe579a9e3beb6ea6626d9227ec32d0dae119a49
+Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008
diff --git a/3rdparty/cutlass_fpA_intB_gemm b/3rdparty/cutlass_fpA_intB_gemm
index f09824e950ed..fdef2307917e 160000
--- a/3rdparty/cutlass_fpA_intB_gemm
+++ b/3rdparty/cutlass_fpA_intB_gemm
@@ -1 +1 @@
-Subproject commit f09824e950ed6678670004bd23578757b3473f21
+Subproject commit fdef2307917ec2c7cc5becc29fb95d77498484bd
diff --git a/cmake/modules/contrib/CUTLASS.cmake b/cmake/modules/contrib/CUTLASS.cmake
index b302622cbce8..b9097a02e93f 100644
--- a/cmake/modules/contrib/CUTLASS.cmake
+++ b/cmake/modules/contrib/CUTLASS.cmake
@@ -58,11 +58,15 @@ if(USE_CUDA AND USE_CUTLASS)
     list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_group_gemm.cu)
     list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_group_gemm.cu)
     list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_gemm.cu)
+    list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu)
   endif()
   if(TVM_CUTLASS_RUNTIME_SRCS)
     add_library(tvm_cutlass_objs OBJECT ${TVM_CUTLASS_RUNTIME_SRCS})
     target_compile_options(tvm_cutlass_objs PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>)
-    target_include_directories(tvm_cutlass_objs PRIVATE ${CUTLASS_DIR}/include)
+    target_include_directories(tvm_cutlass_objs PRIVATE
+      ${CUTLASS_DIR}/include
+      ${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm/cutlass_extensions/include
+    )
     target_compile_definitions(tvm_cutlass_objs PRIVATE DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
     list(APPEND CUTLASS_RUNTIME_OBJS "$<${CUTLASS_GEN_COND}:$<TARGET_OBJECTS:tvm_cutlass_objs>>")
   endif()
diff --git a/src/runtime/contrib/cutlass/blockwise_scaled_gemm_runner.cuh b/src/runtime/contrib/cutlass/blockwise_scaled_gemm_runner.cuh
new file mode 100644
index 000000000000..f520bf815a94
--- /dev/null
+++ b/src/runtime/contrib/cutlass/blockwise_scaled_gemm_runner.cuh
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <cuda_fp16.h>
+#include <float.h>
+#include <tvm/runtime/logging.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include "../../cuda/cuda_common.h"
+
+// clang-format off
+#include "cutlass/cutlass.h"
+
+#include "cute/tensor.hpp"
+#include "cutlass/float8.h"
+#include "cutlass/tensor_ref.h"
+#include "cutlass/epilogue/collective/default_epilogue.hpp"
+#include "cutlass/epilogue/thread/linear_combination.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+
+#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
+#include "cutlass_extensions/gemm/dispatch_policy.hpp"
+// clang-format on
+
+#define CUTLASS_CHECK(status)                                         \
+  {                                                                   \
+    cutlass::Status error = status;                                   \
+    CHECK(error == cutlass::Status::kSuccess)                         \
+        << "Got cutlass error: " << cutlassGetStatusString(error);    \
+  }
+
+using namespace cute;
+using ProblemShape = Shape<int, int, int, int>;
+using tvm::runtime::NDArray;
+
+// TileShape/ClusterShape select the CTA tile and cluster, ElementD the output dtype,
+// SchedulerType the CUTLASS tile scheduler, and ScaleGranularityM the granularity of
+// A's scales along M (1 = one scale per row).
+template <typename TileShape, typename ClusterShape, typename ElementD, typename SchedulerType,
+          int ScaleGranularityM = 1>
+struct CutlassFP8ScaledBlockwiseGemmRunner {
+  using ElementAccumulator = float;
+  using ElementCompute = float;
+  using ElementBlockScale = float;
+
+  using ElementA = cutlass::float_e4m3_t;
+  using LayoutA = cutlass::layout::RowMajor;
+  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
+
+  using ElementB = cutlass::float_e4m3_t;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
+
+  using ElementC = void;
+  using LayoutC = cutlass::layout::RowMajor;
+  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value;
+
+  using LayoutD = cutlass::layout::RowMajor;
+  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
+
+  using ArchTag = cutlass::arch::Sm90;
+  using OperatorClass = cutlass::arch::OpClassTensorOp;
+  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
+  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
+  // Epilogue visitor tree that simply stores the accumulator; scaling is fused in the mainloop.
+  using StoreEpilogueCompute =
+      typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;
+
+  using KernelSchedule =
+      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<
+          ScaleGranularityM>;
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, ElementAccumulator,
+      ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD,
+      EpilogueSchedule, StoreEpilogueCompute>::CollectiveOp;
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB,
+      ElementAccumulator, TileShape, ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
+          sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      KernelSchedule>::CollectiveOp;
+
+  using GemmKernel =
+      cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>,  // Indicates ProblemShape
+                                           CollectiveMainloop, CollectiveEpilogue, SchedulerType>;
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+  using StrideA = typename Gemm::GemmKernel::StrideA;
+  using StrideB = typename Gemm::GemmKernel::StrideB;
+  using StrideD = typename Gemm::GemmKernel::StrideD;
+
+  void run_gemm(const ElementA* a_ptr, const ElementB* b_ptr, const ElementBlockScale* scales_a_ptr,
+                const ElementBlockScale* scales_b_ptr, ElementD* o_ptr, ProblemShape* problem_size,
+                StrideA* stride_a, StrideB* stride_b, StrideD* stride_d, uint8_t* workspace,
+                int64_t workspace_size, cudaStream_t stream) {
+    cutlass::KernelHardwareInfo hw_info;
+    hw_info.device_id = 0;
+    hw_info.sm_count =
+        cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
+
+    typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
+    static constexpr bool UsesStreamKScheduler =
+        cute::is_same_v<SchedulerType, cutlass::gemm::StreamKScheduler>;
+    if constexpr (UsesStreamKScheduler) {
+      using DecompositionMode = typename cutlass::gemm::kernel::detail::
+          PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
+      using ReductionMode = typename cutlass::gemm::kernel::detail::
+          PersistentTileSchedulerSm90StreamKParams::ReductionMode;
+      scheduler.decomposition_mode = DecompositionMode::StreamK;
+      scheduler.reduction_mode = ReductionMode::Nondeterministic;
+    }
+
+    typename Gemm::Arguments arguments = {
+        cutlass::gemm::GemmUniversalMode::kGemm,
+        *problem_size,
+        {a_ptr, *stride_a, b_ptr, *stride_b, scales_a_ptr, scales_b_ptr},
+        {{}, nullptr, *stride_d, o_ptr, *stride_d},
+        hw_info,
+        scheduler};
+
+    Gemm gemm_op;
+    CUTLASS_CHECK(gemm_op.can_implement(arguments));
+    CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
+    CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
+    CUTLASS_CHECK(gemm_op.run(stream));
+  }
+};
+
+template <typename TileShape, typename ClusterShape, typename ElementA, typename ElementB,
+          typename ElementBlockScale, typename ElementD>
+void cutlass_fp8_blockwise_scaled_gemm(ElementA* a, ElementB* b, ElementBlockScale* scales_a,
+                                       ElementBlockScale* scales_b, ElementD* out,
+                                       uint8_t* workspace, int64_t workspace_size, int64_t m,
+                                       int64_t n, int64_t k, cudaStream_t stream) {
+  // Use the stream-K scheduler when K is much larger than N; otherwise use the persistent one.
+  if (k > 3 * n) {
+    using SchedulerType = cutlass::gemm::StreamKScheduler;
+    using Runner =
+        CutlassFP8ScaledBlockwiseGemmRunner<TileShape, ClusterShape, ElementD, SchedulerType>;
+    using StrideA = typename Runner::StrideA;
+    using StrideB = typename Runner::StrideB;
+    using StrideD = typename Runner::StrideD;
+
+    Runner runner;
+    StrideA stride_a = cute::make_stride(k, Int<1>{}, m * k);
+    StrideB stride_b = cute::make_stride(k, Int<1>{}, n * k);
+    StrideD stride_d = cute::make_stride(n, Int<1>{}, m * n);
+    ProblemShape problem_size{static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1};
+    runner.run_gemm(a, b, scales_a, scales_b, out, &problem_size, &stride_a, &stride_b, &stride_d,
+                    workspace, workspace_size, stream);
+  } else {
+    using SchedulerType = cutlass::gemm::PersistentScheduler;
+    using Runner =
+        CutlassFP8ScaledBlockwiseGemmRunner<TileShape, ClusterShape, ElementD, SchedulerType>;
+    using StrideA = typename Runner::StrideA;
+    using StrideB = typename Runner::StrideB;
+    using StrideD = typename Runner::StrideD;
+
+    Runner runner;
+    StrideA stride_a = cute::make_stride(k, Int<1>{}, m * k);
+    StrideB stride_b = cute::make_stride(k, Int<1>{}, n * k);
+    StrideD stride_d = cute::make_stride(n, Int<1>{}, m * n);
+    ProblemShape problem_size{static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), 1};
+    runner.run_gemm(a, b, scales_a, scales_b, out, &problem_size, &stride_a, &stride_b, &stride_d,
+                    workspace, workspace_size, stream);
+  }
+}
+
+template <typename TileShape, typename ClusterShape, typename ElementA, typename ElementB,
+          typename ElementBlockScale, typename ElementD>
+void cutlass_fp8_blockwise_scaled_bmm(ElementA* a, ElementB* b, ElementBlockScale* scales_a,
+                                      ElementBlockScale* scales_b, ElementD* out,
+                                      uint8_t* workspace, int64_t workspace_size, int64_t m,
+                                      int64_t n, int64_t k, int64_t l, cudaStream_t stream) {
+  if (k > 3 * n) {
+    using SchedulerType = cutlass::gemm::StreamKScheduler;
+    using Runner =
+        CutlassFP8ScaledBlockwiseGemmRunner<TileShape, ClusterShape, ElementD, SchedulerType>;
+    using StrideA = typename Runner::StrideA;
+    using StrideB = typename Runner::StrideB;
+    using StrideD = typename Runner::StrideD;
+
+    Runner runner;
+    StrideA stride_a = cute::make_stride(k, Int<1>{}, m * k);
+    StrideB stride_b = cute::make_stride(k, Int<1>{}, n * k);
+    StrideD stride_d = cute::make_stride(n, Int<1>{}, m * n);
+    ProblemShape problem_size{static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
+                              static_cast<int>(l)};
+    runner.run_gemm(a, b, scales_a, scales_b, out, &problem_size, &stride_a, &stride_b, &stride_d,
+                    workspace, workspace_size, stream);
+  } else {
+    using SchedulerType = cutlass::gemm::PersistentScheduler;
+    using Runner =
+        CutlassFP8ScaledBlockwiseGemmRunner<TileShape, ClusterShape, ElementD, SchedulerType>;
+    using StrideA = typename Runner::StrideA;
+    using StrideB = typename Runner::StrideB;
+    using StrideD = typename Runner::StrideD;
+
+    Runner runner;
+    StrideA stride_a = cute::make_stride(k, Int<1>{}, m * k);
+    StrideB stride_b = cute::make_stride(k, Int<1>{}, n * k);
+    StrideD stride_d = cute::make_stride(n, Int<1>{}, m * n);
+    ProblemShape problem_size{static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
+                              static_cast<int>(l)};
+    runner.run_gemm(a, b, scales_a, scales_b, out, &problem_size, &stride_a, &stride_b, &stride_d,
+                    workspace, workspace_size, stream);
+  }
+}
diff --git a/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu b/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu
new file mode 100644
index 000000000000..4ac5a621a006
--- /dev/null
+++ b/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <cuda_fp16.h>
+#include <float.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include "../cublas/cublas_utils.h"
+#include "blockwise_scaled_gemm_runner.cuh"
+
+#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
+
+namespace tvm {
+namespace runtime {
+
+void tvm_cutlass_fp8_blockwise_scaled_gemm(NDArray a, NDArray b, NDArray scales_a, NDArray scales_b,
+                                           NDArray workspace, int64_t block_size_0,
+                                           int64_t block_size_1, NDArray out) {
+  using TileShape = Shape<_128, _128, _128>;
+  using ClusterShape = Shape<_1, _1, _1>;
+
+  // Workspace is used for storing device-side gemm arguments and cutlass internal workspace.
+  // Recommended size is 4MB.
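+  // Run on the CUDA stream TVM is currently using, fetched from the global registry below.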
+  auto get_stream_func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+  ICHECK(get_stream_func != nullptr);
+  cudaStream_t stream = static_cast<cudaStream_t>((*get_stream_func)().operator void*());
+
+  CHECK_GE(a->ndim, 2);
+  CHECK_EQ(scales_a->ndim, a->ndim);
+  CHECK_EQ(b->ndim, 2);
+  CHECK_EQ(scales_b->ndim, 2);
+  CHECK_EQ(workspace->ndim, 1);
+  CHECK_EQ(out->ndim, a->ndim);
+  int64_t m = 1;
+  for (int64_t i = 0; i < a->ndim - 1; ++i) {
+    m *= a->shape[i];
+  }
+  int64_t n = b->shape[0];
+  CHECK_EQ(a->shape[a->ndim - 1], b->shape[1]) << "Only col-major B is supported now.";
+  int64_t k = a->shape[a->ndim - 1];
+
+  // scales_a is col-major of (*a_shape[:-1], k / block_size)
+  CHECK_EQ(scales_a->shape[0] * block_size_1, k);
+  for (int64_t i = 1; i < scales_a->ndim; ++i) {
+    CHECK_EQ(scales_a->shape[i], a->shape[i - 1]);
+  }
+  // scales_b is col-major of (k / block_size, n / block_size)
+  CHECK_EQ(scales_b->shape[0] * block_size_0, n);
+  CHECK_EQ(scales_b->shape[1] * block_size_1, k);
+
+  using tvm::runtime::DataType;
+  CHECK_EQ(DataType(a->dtype), DataType::NVFloat8E4M3());
+  CHECK_EQ(DataType(b->dtype), DataType::NVFloat8E4M3());
+  CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32));
+  CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32));
+  CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8));
+
+  if (DataType(out->dtype) == DataType::Float(16)) {
+    cutlass_fp8_blockwise_scaled_gemm<TileShape, ClusterShape>(
+        static_cast<cutlass::float_e4m3_t*>(a->data), static_cast<cutlass::float_e4m3_t*>(b->data),
+        static_cast<float*>(scales_a->data), static_cast<float*>(scales_b->data),
+        static_cast<cutlass::half_t*>(out->data), static_cast<uint8_t*>(workspace->data),
+        workspace->shape[0] * DataType(workspace->dtype).bytes(), m, n, k, stream);
+  } else if (DataType(out->dtype) == DataType::BFloat(16)) {
+    cutlass_fp8_blockwise_scaled_gemm<TileShape, ClusterShape>(
+        static_cast<cutlass::float_e4m3_t*>(a->data), static_cast<cutlass::float_e4m3_t*>(b->data),
+        static_cast<float*>(scales_a->data), static_cast<float*>(scales_b->data),
+        static_cast<cutlass::bfloat16_t*>(out->data), static_cast<uint8_t*>(workspace->data),
+        workspace->shape[0] * DataType(workspace->dtype).bytes(), m, n, k, stream);
+  } else {
+    LOG(FATAL) << "Unsupported output dtype: " << DataType(out->dtype);
+  }
+}
+
+void tvm_cutlass_fp8_blockwise_scaled_bmm(NDArray a, NDArray b, NDArray scales_a, NDArray scales_b,
+                                          NDArray workspace, int64_t block_size_0,
+                                          int64_t block_size_1, NDArray out) {
+  using TileShape = Shape<_128, _128, _128>;
+  using ClusterShape = Shape<_1, _1, _1>;
+
+  // Workspace is used for storing device-side gemm arguments and cutlass internal workspace.
+  // Recommended size is 4MB.
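+  // The batch dimension is forwarded to the runner as the GEMM's L mode (one GEMM per batch element).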
+  auto get_stream_func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+  ICHECK(get_stream_func != nullptr);
+  cudaStream_t stream = static_cast<cudaStream_t>((*get_stream_func)().operator void*());
+
+  CHECK_EQ(a->ndim, 3);
+  CHECK_EQ(scales_a->ndim, 3);
+  CHECK_EQ(b->ndim, 3);
+  CHECK_EQ(scales_b->ndim, 3);
+  CHECK_EQ(workspace->ndim, 1);
+  CHECK_EQ(out->ndim, 3);
+  int64_t batch_size = a->shape[0];
+  int64_t m = a->shape[1];
+  int64_t n = b->shape[1];
+  CHECK_EQ(a->shape[2], b->shape[2]) << "Only col-major B is supported now.";
+  int64_t k = a->shape[2];
+  CHECK_EQ(b->shape[0], batch_size);
+  CHECK_EQ(scales_a->shape[0], batch_size);
+  CHECK_EQ(scales_b->shape[0], batch_size);
+  CHECK_EQ(out->shape[0], batch_size);
+
+  // scales_a is col-major of (batch_size, m, k / block_size)
+  CHECK_EQ(scales_a->shape[1] * block_size_1, k);
+  CHECK_EQ(scales_a->shape[2], m);
+  // scales_b is col-major of (k / block_size, n / block_size)
+  CHECK_EQ(scales_b->shape[1] * block_size_0, n);
+  CHECK_EQ(scales_b->shape[2] * block_size_1, k);
+
+  using tvm::runtime::DataType;
+  CHECK_EQ(DataType(a->dtype), DataType::NVFloat8E4M3());
+  CHECK_EQ(DataType(b->dtype), DataType::NVFloat8E4M3());
+  CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32));
+  CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32));
+  CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8));
+
+  if (DataType(out->dtype) == DataType::Float(16)) {
+    cutlass_fp8_blockwise_scaled_bmm<TileShape, ClusterShape>(
+        static_cast<cutlass::float_e4m3_t*>(a->data), static_cast<cutlass::float_e4m3_t*>(b->data),
+        static_cast<float*>(scales_a->data), static_cast<float*>(scales_b->data),
+        static_cast<cutlass::half_t*>(out->data), static_cast<uint8_t*>(workspace->data),
+        workspace->shape[0] * DataType(workspace->dtype).bytes(), m, n, k, batch_size, stream);
+  } else if (DataType(out->dtype) == DataType::BFloat(16)) {
+    cutlass_fp8_blockwise_scaled_bmm<TileShape, ClusterShape>(
+        static_cast<cutlass::float_e4m3_t*>(a->data), static_cast<cutlass::float_e4m3_t*>(b->data),
+        static_cast<float*>(scales_a->data), static_cast<float*>(scales_b->data),
+        static_cast<cutlass::bfloat16_t*>(out->data), static_cast<uint8_t*>(workspace->data),
+        workspace->shape[0] * DataType(workspace->dtype).bytes(), m, n, k, batch_size, stream);
+  } else {
+    LOG(FATAL) << "Unsupported output dtype: " << DataType(out->dtype);
+  }
+}
+
+TVM_REGISTER_GLOBAL("cutlass.blockwise_scaled_gemm_e4m3fn_e4m3fn")
+    .set_body_typed(tvm_cutlass_fp8_blockwise_scaled_gemm);
+TVM_REGISTER_GLOBAL("cutlass.blockwise_scaled_bmm_e4m3fn_e4m3fn")
+    .set_body_typed(tvm_cutlass_fp8_blockwise_scaled_bmm);
+
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED
diff --git a/src/runtime/contrib/cutlass/group_gemm_runner.cuh b/src/runtime/contrib/cutlass/group_gemm_runner.cuh
index 71979672b93a..a3c52e27a9d5 100644
--- a/src/runtime/contrib/cutlass/group_gemm_runner.cuh
+++ b/src/runtime/contrib/cutlass/group_gemm_runner.cuh
@@ -105,10 +105,10 @@ struct CutlassGroupGemmRunner {
 
   using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
 
-  using StrideA = typename Gemm::GemmKernel::UnderlyingStrideA;
-  using StrideB = typename Gemm::GemmKernel::UnderlyingStrideB;
-  using StrideC = typename Gemm::GemmKernel::UnderlyingStrideC;
-  using StrideD = typename Gemm::GemmKernel::UnderlyingStrideD;
+  using StrideA = typename Gemm::GemmKernel::InternalStrideA;
+  using StrideB = typename Gemm::GemmKernel::InternalStrideB;
+  using StrideC = typename Gemm::GemmKernel::InternalStrideC;
+  using StrideD = typename Gemm::GemmKernel::InternalStrideD;
 
   void run_group_gemm(const ElementA** ptr_A, const ElementB** ptr_B, const ElementC** ptr_C,
                       ElementC** ptr_D,
@@ -163,9 +163,9 @@ __global__ void prepare_group_gemm_arguments(
   ptr_D[group_id] = out + prev_rows * n;
   problem_sizes[group_id] = {static_cast<int>(indptr[group_id] - prev_rows),
                              static_cast<int>(n), static_cast<int>(k)};
-  stride_A[group_id] = cute::make_stride(k, Int<1>{}, int64_t{0});
-  stride_B[group_id] = cute::make_stride(k, Int<1>{}, int64_t{0});
-  stride_D[group_id] = cute::make_stride(n, Int<1>{}, int64_t{0});
+  stride_A[group_id] = cute::make_stride(k, Int<1>{}, Int<0>{});
+  stride_B[group_id] = cute::make_stride(k, Int<1>{}, Int<0>{});
+  stride_D[group_id] = cute::make_stride(n, Int<1>{}, Int<0>{});
 }
 
 template <typename ElementA, typename ElementB, typename ElementC>
diff --git a/tests/python/contrib/test_cutlass_gemm.py b/tests/python/contrib/test_cutlass_gemm.py
new file mode 100644
index 000000000000..7c259e6f7d6d
--- /dev/null
+++ b/tests/python/contrib/test_cutlass_gemm.py
@@ -0,0 +1,352 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import Tuple
+
+import ml_dtypes
+import numpy as np
+
+import tvm
+import tvm.testing
+from tvm.contrib.pickle_memoize import memoize
+
+
+def get_random_ndarray(shape, dtype):
+    if dtype == "int8":
+        return np.random.randint(-128, 128, shape).astype(dtype)
+    elif dtype == "uint8":
+        return np.random.randint(0, 256, shape).astype(dtype)
+    return np.random.uniform(-1, 1, shape).astype(dtype)
+
+
+def verify_group_gemm(
+    func_name, M, N, K, num_groups, x_dtype, weight_dtype, out_dtype, use_scale, rtol, atol
+):
+    group_gemm_func = tvm.get_global_func(func_name, allow_missing=True)
+    if group_gemm_func is None:
+        print(f"Skipped as {func_name} is not available")
+        return
+
+    @memoize("tvm.contrib.cutlass.test_group_gemm_sm90")
+    def get_ref_data():
+        assert M % num_groups == 0
+        M_per_group = M // num_groups
+        a_np = get_random_ndarray((M, K), "float16")
+        b_np = get_random_ndarray((num_groups, N, K), "float16")
+        indptr_np = np.arange(1, num_groups + 1).astype("int64") * M_per_group
+        c_np = np.concatenate(
+            [a_np[i * M_per_group : (i + 1) * M_per_group] @ b_np[i].T for i in range(num_groups)],
+            axis=0,
+        )
+        return a_np, b_np, indptr_np, c_np
+
+    def to_numpy_dtype(dtype):
+        mapping = {"float8_e5m2": ml_dtypes.float8_e5m2, "float8_e4m3fn": ml_dtypes.float8_e4m3fn}
+        return mapping.get(dtype, dtype)
+
+    a_np, b_np, indptr_np, c_np = get_ref_data()
+    dev = tvm.cuda(0)
+    a_nd = tvm.nd.array(a_np.astype(to_numpy_dtype(x_dtype)), device=dev)
+    b_nd = tvm.nd.array(b_np.astype(to_numpy_dtype(weight_dtype)), device=dev)
+    c_nd = tvm.nd.empty(c_np.shape, dtype=out_dtype, device=dev)
+    indptr_nd = tvm.nd.array(indptr_np, device=dev)
+    workspace = tvm.nd.empty((4096 * 1024,), dtype="uint8", device=dev)
+    if use_scale:
+        scale = tvm.nd.array(np.array([1.0], dtype="float32"), device=dev)
+        group_gemm_func(a_nd, b_nd, indptr_nd, workspace, scale, c_nd)
+    else:
+        group_gemm_func(a_nd, b_nd, indptr_nd, workspace, 
c_nd) + tvm.testing.assert_allclose(c_nd.numpy(), c_np, rtol=rtol, atol=atol) + + +@tvm.testing.requires_cutlass +@tvm.testing.requires_cuda_compute_version(9) +def test_group_gemm_sm90(): + verify_group_gemm( + "cutlass.group_gemm_fp16_sm90", + 8, + 128, + 128, + 4, + "float16", + "float16", + "float16", + False, + rtol=1e-3, + atol=1e-3, + ) + verify_group_gemm( + "cutlass.group_gemm_e5m2_e5m2_fp16", + 8, + 16, + 16, + 4, + "float8_e5m2", + "float8_e5m2", + "float16", + True, + rtol=1e-1, + atol=1, + ) + verify_group_gemm( + "cutlass.group_gemm_e4m3_e4m3_fp16", + 8, + 16, + 16, + 4, + "float8_e4m3fn", + "float8_e4m3fn", + "float16", + True, + rtol=1e-1, + atol=1, + ) + + +def rowwise_quant_fp8_e4m3(shape: Tuple[int, int], block_size: Tuple[int, int], dtype: str): + x_full_np = (np.random.rand(*shape) * 2 - 1).astype(dtype) + x_scale_shape = ( + *shape[:-1], + (shape[-1] + block_size[1] - 1) // block_size[1], + ) + # For each (block_size[1]) block, compute the max abs value of `w_full_np` + x_max_abs_np = np.zeros(x_scale_shape, dtype="float32") + for i in range(x_scale_shape[-1]): + x_max_abs_np[..., i] = np.max( + np.abs(x_full_np[..., i * block_size[1] : min((i + 1) * block_size[1], shape[-1])]), + axis=-1, + )[0] + # Scale is the `x_max_abs_np` divided by the max value of quant_dtype in ml_dtypes + fp8_max = float(ml_dtypes.finfo("float8_e4m3fn").max) + x_scale_np = x_max_abs_np / fp8_max + # `x_np` is the `x_full_np` divided by the `x_scale_np` (with block awareness), + # clamped to (-fp8_max, fp8_max), and cast to `quant_dtype` + x_np = np.zeros_like(x_full_np, dtype="float8_e4m3fn") + for i in range(x_scale_shape[-1]): + x_np[..., i * block_size[1] : min((i + 1) * block_size[1], shape[-1])] = np.clip( + x_full_np[..., i * block_size[1] : min((i + 1) * block_size[1], shape[-1])] + / x_scale_np[..., i : i + 1], + -fp8_max, + fp8_max, + ) + + x_scale_np = np.random.rand(*x_scale_np.shape).astype("float32") / fp8_max + for i in range(x_scale_shape[-1]): + x_full_np[..., i * block_size[1] : min((i + 1) * block_size[1], shape[-1])] = ( + x_np[..., i * block_size[1] : min((i + 1) * block_size[1], shape[-1])].astype( + x_scale_np.dtype + ) + * x_scale_np[..., i : i + 1] + ) + return x_np, x_scale_np + + +def blockwise_quant_fp8_e4m3(shape: Tuple[int, int], block_size: Tuple[int, int], dtype: str): + w_full_np = (np.random.rand(*shape) * 2 - 1).astype(dtype) + w_scale_shape = ( + *shape[:-2], + (shape[-2] + block_size[0] - 1) // block_size[0], + (shape[-1] + block_size[1] - 1) // block_size[1], + ) + # For each (block_size[0], block_size[1]) block, compute the max abs value of `w_full_np` + w_max_abs_np = np.zeros(w_scale_shape, dtype="float32") + for i in range(w_scale_shape[-2]): + for j in range(w_scale_shape[-1]): + block_shape = ( + *shape[:-2], + min(block_size[0], shape[-2] - i * block_size[0]), + min(block_size[1], shape[-1] - j * block_size[1]), + ) + w_max_abs_np[..., i, j] = np.max( + np.abs( + w_full_np[ + ..., + i * block_size[0] : min((i + 1) * block_size[0], shape[-2]), + j * block_size[1] : min((j + 1) * block_size[1], shape[-1]), + ] + ).reshape(*shape[:-2], block_shape[-2] * block_shape[-1]), + axis=-1, + ) + # Scale is the `w_max_abs_np` divided by the max value of quant_dtype in ml_dtypes + fp8_max = float(ml_dtypes.finfo("float8_e4m3fn").max) + w_scale_np = w_max_abs_np / fp8_max + # `w_np` is the `w_full_np` divided by the `w_scale_np` (with block awareness), + # clamped to (-fp8_max, fp8_max), and cast to `quant_dtype` + w_np = np.zeros_like(w_full_np, 
dtype="float8_e4m3fn") + if len(w_scale_shape) == 2: + for i in range(w_scale_shape[-2]): + for j in range(w_scale_shape[-1]): + w_np[ + i * block_size[0] : min((i + 1) * block_size[0], shape[-2]), + j * block_size[1] : min((j + 1) * block_size[1], shape[-1]), + ] = np.clip( + w_full_np[ + i * block_size[0] : min((i + 1) * block_size[0], shape[-2]), + j * block_size[1] : min((j + 1) * block_size[1], shape[-1]), + ] + / w_scale_np[..., i, j], + -fp8_max, + fp8_max, + ) + else: + for e in range(w_scale_shape[0]): + for i in range(w_scale_shape[-2]): + for j in range(w_scale_shape[-1]): + w_np[ + e, + i * block_size[0] : min((i + 1) * block_size[0], shape[-2]), + j * block_size[1] : min((j + 1) * block_size[1], shape[-1]), + ] = np.clip( + w_full_np[ + e, + i * block_size[0] : min((i + 1) * block_size[0], shape[-2]), + j * block_size[1] : min((j + 1) * block_size[1], shape[-1]), + ] + / w_scale_np[e, i, j], + -fp8_max, + fp8_max, + ) + + w_scale_np = np.random.rand(*w_scale_np.shape).astype("float32") / fp8_max + return w_np, w_scale_np + + +def blockwise_matmul( + x_fp8_np: np.ndarray, + x_scale_np: np.ndarray, + w_np: np.ndarray, + w_scale_np: np.ndarray, + block_size: Tuple[int, int], + dtype: str, +): + o_np = np.zeros((x_fp8_np.shape[0], w_np.shape[0]), dtype=dtype) + for j in range(w_scale_np.shape[0]): + for k in range(w_scale_np.shape[1]): + o_np[:, j * block_size[0] : min((j + 1) * block_size[0], w_np.shape[0])] += ( + np.matmul( + x_fp8_np[ + :, k * block_size[1] : min((k + 1) * block_size[1], x_fp8_np.shape[1]) + ].astype(dtype), + w_np[ + j * block_size[0] : min((j + 1) * block_size[0], w_np.shape[0]), + k * block_size[1] : min((k + 1) * block_size[1], w_np.shape[1]), + ].T.astype(dtype), + ) + * x_scale_np[:, k : k + 1] + * w_scale_np[j, k] + ) + return o_np + + +def blockwise_bmm( + x_fp8_np: np.ndarray, + x_scale_np: np.ndarray, + w_np: np.ndarray, + w_scale_np: np.ndarray, + block_size: Tuple[int, int], + dtype: str, +): + o_np = np.zeros((x_fp8_np.shape[0], x_fp8_np.shape[1], w_np.shape[1]), dtype=dtype) + for j in range(w_scale_np.shape[1]): + for k in range(w_scale_np.shape[2]): + o_np[..., j * block_size[0] : min((j + 1) * block_size[0], w_np.shape[1])] += ( + np.matmul( + x_fp8_np[ + ..., k * block_size[1] : min((k + 1) * block_size[1], x_fp8_np.shape[2]) + ].astype(dtype), + w_np[ + ..., + j * block_size[0] : min((j + 1) * block_size[0], w_np.shape[1]), + k * block_size[1] : min((k + 1) * block_size[1], w_np.shape[2]), + ] + .transpose(0, 2, 1) + .astype(dtype), + ) + * x_scale_np[..., k : k + 1] + * w_scale_np[..., j : j + 1, k : k + 1] + ) + return o_np + + +@tvm.testing.requires_cutlass +@tvm.testing.requires_cuda_compute_version(9) +def test_fp8_e4m3_blockwise_scaled_gemm(): + M = 16 + N = 4608 + K = 896 + block_size = (128, 128) + assert N % 128 == 0 and K % 128 == 0 # Only support N/K are multiple of 128 + + func_name = "cutlass.blockwise_scaled_gemm_e4m3fn_e4m3fn" + gemm_func = tvm.get_global_func(func_name, allow_missing=True) + if gemm_func is None: + print(f"Skipped as {func_name} is not available") + return + + device = tvm.cuda(0) + dtype = "bfloat16" + x_np, x_scale_np = rowwise_quant_fp8_e4m3((M, K), block_size, dtype) + w_np, w_scale_np = blockwise_quant_fp8_e4m3((N, K), block_size, dtype) + o_np = blockwise_matmul(x_np, x_scale_np, w_np, w_scale_np, block_size, dtype) + x_tvm = tvm.nd.array(x_np, device=device) + x_scale_tvm = tvm.nd.array(x_scale_np.T, device=device) + w_tvm = tvm.nd.array(w_np, device=device) + w_scale_tvm = tvm.nd.array(w_scale_np, 
device=device) + workspace = tvm.nd.empty((4096 * 1024,), dtype="uint8", device=device) + o_tvm = tvm.nd.empty((M, N), dtype=dtype, device=device) + gemm_func( + x_tvm, w_tvm, x_scale_tvm, w_scale_tvm, workspace, block_size[0], block_size[1], o_tvm + ) + o_tvm = o_tvm.numpy() + tvm.testing.assert_allclose(o_tvm, o_np, rtol=1e-4, atol=0.5) + + +@tvm.testing.requires_cutlass +@tvm.testing.requires_cuda_compute_version(9) +def test_fp8_e4m3_blockwise_scaled_bmm(): + B = 16 + M = 40 + N = 512 + K = 128 + block_size = (128, 128) + assert N % 128 == 0 and K % 128 == 0 # Only support N/K are multiple of 128 + + func_name = "cutlass.blockwise_scaled_bmm_e4m3fn_e4m3fn" + gemm_func = tvm.get_global_func(func_name, allow_missing=True) + if gemm_func is None: + print(f"Skipped as {func_name} is not available") + return + + device = tvm.cuda(0) + dtype = "bfloat16" + x_np, x_scale_np = rowwise_quant_fp8_e4m3((B, M, K), block_size, dtype) + w_np, w_scale_np = blockwise_quant_fp8_e4m3((B, N, K), block_size, dtype) + o_np = blockwise_bmm(x_np, x_scale_np, w_np, w_scale_np, block_size, dtype) + x_tvm = tvm.nd.array(x_np, device=device) + x_scale_tvm = tvm.nd.array(x_scale_np.transpose(0, 2, 1), device=device) + w_tvm = tvm.nd.array(w_np, device=device) + w_scale_tvm = tvm.nd.array(w_scale_np, device=device) + workspace = tvm.nd.empty((4096 * 1024,), dtype="uint8", device=device) + o_tvm = tvm.nd.empty((B, M, N), dtype=dtype, device=device) + gemm_func( + x_tvm, w_tvm, x_scale_tvm, w_scale_tvm, workspace, block_size[0], block_size[1], o_tvm + ) + o_tvm = o_tvm.numpy() + tvm.testing.assert_allclose(o_tvm, o_np, rtol=1e-4, atol=0.5) + + +if __name__ == "__main__": + tvm.testing.main()