From be13fd2e53c25ccb681e101d446bd5451731119d Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 28 Apr 2024 15:18:56 +0000 Subject: [PATCH 01/13] [FP8] SM89 (Ada) can also support fp8. --- tests/python/codegen/test_target_codegen_cuda_fp8.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py b/tests/python/codegen/test_target_codegen_cuda_fp8.py index 5566ae243477..9c2a0faaab7b 100644 --- a/tests/python/codegen/test_target_codegen_cuda_fp8.py +++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py @@ -34,7 +34,7 @@ from tvm.topi.utils import get_const_tuple -@tvm.testing.requires_cuda_compute_version(9) +@tvm.testing.requires_cuda_compute_version(8, 9) def test_e4m3_conversions(): dtype = "e4m3_float8" @@ -79,7 +79,7 @@ def add( ) -@tvm.testing.requires_cuda_compute_version(9) +@tvm.testing.requires_cuda_compute_version(8, 9) def test_e4m3_packing(): length = 64 vector_length = 4 @@ -144,7 +144,7 @@ def add( ) -@tvm.testing.requires_cuda_compute_version(9) +@tvm.testing.requires_cuda_compute_version(8, 9) def test_e4m3_vector_conversions(native_dtype, promoted_dtype): vector_length = 64 @@ -784,7 +784,7 @@ def compiled_functions( dev, ) - @tvm.testing.requires_cuda_compute_version(9) + @tvm.testing.requires_cuda_compute_version(8, 9) def test_main(self, weight_shape, model_dtype, target_str, compiled_functions): quant, dequant = compiled_functions dev = tvm.device(target_str, 0) @@ -799,7 +799,7 @@ def test_main(self, weight_shape, model_dtype, target_str, compiled_functions): tvm.testing.assert_allclose(weight_np, dequant_weight_np, atol=10, rtol=5e-2) -@tvm.testing.requires_cuda_compute_version(9) +@tvm.testing.requires_cuda_compute_version(8, 9) @pytest.mark.parametrize("dtype", ["e5m2_float8", "e4m3_float8"]) def test_const(dtype): @T.prim_func From ab395680c613f329ca2aa261971c40e6a0734d25 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 28 Apr 2024 16:45:42 +0000 Subject: [PATCH 02/13] extend fp8 vectorize to f16 --- src/target/source/codegen_cuda.cc | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index ecb095761189..f17845bdc527 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -53,8 +53,11 @@ std::string GetFP8Type(DataType type) { vec = "_4"; } else if (lanes == 8) { vec = "_8"; - } else { - LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8) for FP8"; + } else if (lanes == 16) { + vec = "_16"; + } + else { + LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8, 16) for FP8"; } if (type.code() == DataType::kE4M3Float) { stream << "fp8_e4" << vec << "_t"; @@ -149,9 +152,13 @@ std::string CodeGenCUDA::Finish() { decl_stream << "using fp8_e4_t = __nv_fp8_e4m3;\n"; decl_stream << "using fp8_e4_2_t = __nv_fp8x2_e4m3;\n"; decl_stream << "using fp8_e4_4_t = __nv_fp8x4_e4m3;\n"; + decl_stream << "struct fp8_e4_8_t {\n fp8_e4_t data[8]; \n};\n"; + decl_stream << "struct fp8_e4_16_t {\n fp8_e4_t data[16]; \n};\n"; decl_stream << "using fp8_e5_t = __nv_fp8_e5m2;\n"; decl_stream << "using fp8_e5_2_t = __nv_fp8x2_e5m2;\n"; decl_stream << "using fp8_e5_4_t = __nv_fp8x4_e5m2;\n"; + decl_stream << "struct fp8_e5_8_t {\n fp8_e5_t data[8]; \n};\n"; + decl_stream << "struct fp8_e5_16_t {\n fp8_e5_t data[16]; \n};\n"; decl_stream << "#endif\n\n"; } declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_); From 
c5264a145897fd7ecbf479d07ece84988788730b Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Sun, 28 Apr 2024 17:47:30 +0000
Subject: [PATCH 03/13] Support fp8

---
 python/tvm/tir/tensor_intrin/cuda.py | 115 ++++++++++++------
 src/tir/schedule/analysis/verify.cc | 9 +-
 .../codegen/test_target_codegen_cuda_fp8.py | 15 +++
 3 files changed, 100 insertions(+), 39 deletions(-)

diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index 409a1ff10a78..244c7ccc6b62 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -123,7 +123,7 @@ def get_ldmatrix_intrin(
         matrix_name == "B" or not transposed
     ), "Now only B matrix can be transposed for int8 matmul"
     assert (
-        k_dim == 32 and dtype == "int8"
+        k_dim == 32 and (dtype == "int8" or dtype == "e4m3_float8" or dtype == "e5m2_float8")
     ), "Only k_dim == 16 (float16) or k_dim == 32 (int8) supported for now"
 
     if matrix_name == "B" and not transposed:
@@ -260,8 +260,26 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:
 LDMATRIX_i8_B_TRANS_INTRIN = "mma_ldmatrix_i8_b_trans"
 TensorIntrin.register(LDMATRIX_i8_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "int8", "B", True))
 
+LDMATRIX_i8_A_INTRIN = "mma_ldmatrix_e4m3_a"
+TensorIntrin.register(LDMATRIX_i8_A_INTRIN, *get_ldmatrix_intrin(32, "e4m3_float8", "A", False))
 
-def get_mma_intrin(k_dim, out_dtype, a_transposed, b_transposed):
+LDMATRIX_i8_B_INTRIN = "mma_ldmatrix_e4m3_b"
+TensorIntrin.register(LDMATRIX_i8_B_INTRIN, *get_ldmatrix_intrin(32, "e4m3_float8", "B", False))
+
+LDMATRIX_i8_B_TRANS_INTRIN = "mma_ldmatrix_e4m3_b_trans"
+TensorIntrin.register(LDMATRIX_i8_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "e4m3_float8", "B", True))
+
+LDMATRIX_i8_A_INTRIN = "mma_ldmatrix_e5m2_a"
+TensorIntrin.register(LDMATRIX_i8_A_INTRIN, *get_ldmatrix_intrin(32, "e5m2_float8", "A", False))
+
+LDMATRIX_i8_B_INTRIN = "mma_ldmatrix_e5m2_b"
+TensorIntrin.register(LDMATRIX_i8_B_INTRIN, *get_ldmatrix_intrin(32, "e5m2_float8", "B", False))
+
+LDMATRIX_i8_B_TRANS_INTRIN = "mma_ldmatrix_e5m2_b_trans"
+TensorIntrin.register(LDMATRIX_i8_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "e5m2_float8", "B", True))
+
+
+def get_mma_intrin(k_dim, a_dtype="float16", b_dtype="float16", out_dtype="float16", a_transposed=False, b_transposed=False):
     local_size = (M_DIM * k_dim) // WARP_SIZE
     local_size_out = (M_DIM * N_DIM) // 32
 
@@ -281,15 +299,18 @@ def get_mma_intrin(k_dim, out_dtype, a_transposed, b_transposed):
     else:
         assert False
 
-    out_dtype_abbrv = {"float16": "fp16", "float32": "fp32", "int32": "int32"}[out_dtype]
-
-    if out_dtype in ["float16", "float32"]:
-        in_dtype = "float16"
-        in_dtype_abbrv = "fp16"
-    else:
-        in_dtype = "int8"
-        in_dtype_abbrv = "int8"
-
+    dtype_abbrv = {
+        "float16": "fp16",
+        "float32": "fp32",
+        "int8": "int8",
+        "int32": "int32",
+        "e4m3_float8": "e4m3",
+        "e5m2_float8": "e5m2",
+    }
+    a_dtype_abbrv = dtype_abbrv[a_dtype]
+    b_dtype_abbrv = dtype_abbrv[b_dtype]
+    out_dtype_abbrv = dtype_abbrv[out_dtype]
+
     def cast_to_out_dtype(v):
         if out_dtype in ["float32", "int32"]:
             return Cast(out_dtype, v)
@@ -307,7 +328,7 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
         A = T.match_buffer(
             a,
             (WARP_SIZE, local_size),
-            in_dtype,
+            a_dtype,
             align=64,
             offset_factor=A_offset_factor,
             scope="warp",
@@ -315,7 +336,7 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
         B = T.match_buffer(
             b,
             (WARP_SIZE, local_size),
-            in_dtype,
+            b_dtype,
             align=64,
             offset_factor=B_offset_factor,
             scope="warp",
@@ -363,7 +384,7 @@
def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer( a, (WARP_SIZE, local_size), - in_dtype, + a_dtype, align=64, offset_factor=A_offset_factor, scope="warp", @@ -371,7 +392,7 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: B = T.match_buffer( b, (WARP_SIZE, local_size), - in_dtype, + b_dtype, align=64, offset_factor=B_offset_factor, scope="warp", @@ -399,8 +420,8 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: mma_prefix, "row", "col", - in_dtype_abbrv, - in_dtype_abbrv, + a_dtype_abbrv, + b_dtype_abbrv, out_dtype_abbrv, A.data, A.elem_offset + tx * lift(local_size), @@ -418,8 +439,8 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: mma_prefix, "row", "col", - in_dtype_abbrv, - in_dtype_abbrv, + a_dtype_abbrv, + b_dtype_abbrv, out_dtype_abbrv, A.data, A.elem_offset + tx * lift(local_size), @@ -436,38 +457,50 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: MMA_f16f16f32_INTRIN = "mma_f16f16f32" -TensorIntrin.register(MMA_f16f16f32_INTRIN, *get_mma_intrin(16, "float32", False, False)) +TensorIntrin.register(MMA_f16f16f32_INTRIN, *get_mma_intrin(16, "float16", "float16", "float32", False, False)) MMA_f16f16f32_TRANS_B_INTRIN = "mma_f16f16f32_trans_b" -TensorIntrin.register(MMA_f16f16f32_TRANS_B_INTRIN, *get_mma_intrin(16, "float32", False, True)) +TensorIntrin.register(MMA_f16f16f32_TRANS_B_INTRIN, *get_mma_intrin(16, "float16", "float16", "float32", False, True)) MMA_f16f16f32_TRANS_A_INTRIN = "mma_f16f16f32_trans_a" -TensorIntrin.register(MMA_f16f16f32_TRANS_A_INTRIN, *get_mma_intrin(16, "float32", True, False)) +TensorIntrin.register(MMA_f16f16f32_TRANS_A_INTRIN, *get_mma_intrin(16, "float16", "float16", "float32", True, False)) MMA_f16f16f32_TRANS_A_TRANS_B_INTRIN = "mma_f16f16f32_trans_a_trans_b" TensorIntrin.register( - MMA_f16f16f32_TRANS_A_TRANS_B_INTRIN, *get_mma_intrin(16, "float32", True, True) + MMA_f16f16f32_TRANS_A_TRANS_B_INTRIN, *get_mma_intrin(16, "float16", "float16", "float32", True, True) ) MMA_f16f16f16_INTRIN = "mma_f16f16f16" -TensorIntrin.register(MMA_f16f16f16_INTRIN, *get_mma_intrin(16, "float16", False, False)) +TensorIntrin.register(MMA_f16f16f16_INTRIN, *get_mma_intrin(16, "float16", "float16", "float16", False, False)) MMA_f16f16f16_TRANS_B_INTRIN = "mma_f16f16f16_trans_b" -TensorIntrin.register(MMA_f16f16f16_TRANS_B_INTRIN, *get_mma_intrin(16, "float16", False, True)) +TensorIntrin.register(MMA_f16f16f16_TRANS_B_INTRIN, *get_mma_intrin(16, "float16", "float16", "float16", False, True)) MMA_f16f16f16_TRANS_A_INTRIN = "mma_f16f16f16_trans_a" -TensorIntrin.register(MMA_f16f16f16_TRANS_A_INTRIN, *get_mma_intrin(16, "float16", True, False)) +TensorIntrin.register(MMA_f16f16f16_TRANS_A_INTRIN, *get_mma_intrin(16, "float16", "float16", "float16", True, False)) MMA_f16f16f16_TRANS_A_TRANS_B_INTRIN = "mma_f16f16f16_trans_a_trans_b" TensorIntrin.register( - MMA_f16f16f16_TRANS_A_TRANS_B_INTRIN, *get_mma_intrin(16, "float16", True, True) + MMA_f16f16f16_TRANS_A_TRANS_B_INTRIN, *get_mma_intrin(16, "float16", "float16", "float16", True, True) ) MMA_i8i8i32_INTRIN = "mma_i8i8i32" -TensorIntrin.register(MMA_i8i8i32_INTRIN, *get_mma_intrin(32, "int32", False, False)) +TensorIntrin.register(MMA_i8i8i32_INTRIN, *get_mma_intrin(32, "int8", "int8", "int32", False, False)) MMA_i8i8i32_TRANS_B_INTRIN = "mma_i8i8i32_trans_b" -TensorIntrin.register(MMA_i8i8i32_TRANS_B_INTRIN, *get_mma_intrin(32, "int32", False, True)) +TensorIntrin.register(MMA_i8i8i32_TRANS_B_INTRIN, 
*get_mma_intrin(32, "int8", "int8", "int32", False, True)) + +MMA_e5m2e5m2i32_INTRIN = "mma_e5m2e5m2i32" +TensorIntrin.register(MMA_e5m2e5m2i32_INTRIN, *get_mma_intrin(32, "e5m2_float8", "e5m2_float8", "int32", False, False)) + +MMA_e5m2e5m2i32_TRANS_B_INTRIN = "mma_e5m2e5m2i32_trans_b" +TensorIntrin.register(MMA_e5m2e5m2i32_TRANS_B_INTRIN, *get_mma_intrin(32, "e5m2_float8", "e5m2_float8", "int32", False, True)) + +MMA_e4m3e4m3i32_INTRIN = "mma_e4m3e4m3i32" +TensorIntrin.register(MMA_e4m3e4m3i32_INTRIN, *get_mma_intrin(32, "e4m3_float8", "e4m3_float8", "int32", False, False)) + +MMA_e4m3e4m3i32_TRANS_B_INTRIN = "mma_e4m3e4m3i32_trans_b" +TensorIntrin.register(MMA_e4m3e4m3i32_TRANS_B_INTRIN, *get_mma_intrin(32, "e4m3_float8", "e4m3_float8", "int32", False, True)) def get_mma_fill_intrin(dtype, local_size): @@ -631,7 +664,8 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None: def get_mma_intrin_group( load_scope: Literal["shared", "shared.dyn"], store_scope: Literal["global", "shared", "shared.dyn"], - in_dtype: Literal["float16", "int8"], + a_dtype: Literal["float16", "int8", "e4m3_float8", "e5m2_float8"], + b_dtype: Literal["float16", "int8", "e4m3_float8", "e5m2_float8"], out_dtype: Literal["float16", "float32", "int32"], trans_a: bool, trans_b: bool, @@ -678,13 +712,22 @@ def get_mma_intrin_group( """ assert load_scope in ["shared", "shared.dyn"] assert store_scope in ["global", "shared", "shared.dyn"] - assert in_dtype in ["float16", "int8"] + assert a_dtype in ["float16", "int8", "e4m3_float8", "e5m2_float8"] + assert b_dtype in ["float16", "int8", "e4m3_float8", "e5m2_float8"] assert out_dtype in ["float16", "float32", "int32"] shape = "16x16" - dtype_mapping = {"float16": "f16", "float32": "f32", "int8": "i8", "int32": "i32"} - in_dtype = dtype_mapping[in_dtype] + dtype_mapping = { + "float16": "f16", + "float32": "f32", + "int8": "i8", + "e4m3_float8": "e4m3", + "e5m2_float8": "e5m2", + "int32": "i32", + } + a_dtype = dtype_mapping[a_dtype] + b_dtype = dtype_mapping[b_dtype] out_dtype = dtype_mapping[out_dtype] # e.g. mma_fill_16x16_f32 @@ -694,13 +737,13 @@ def get_mma_intrin_group( trans_a = "_trans" if trans_a else "" trans_b = "_trans" if trans_b else "" load_scope = "_dyn" if load_scope == "shared.dyn" else "" - load_a_intrin = f"mma_ldmatrix_{in_dtype}_a{trans_a}{load_scope}" - load_b_intrin = f"mma_ldmatrix_{in_dtype}_b{trans_b}{load_scope}" + load_a_intrin = f"mma_ldmatrix_{a_dtype}_a{trans_a}{load_scope}" + load_b_intrin = f"mma_ldmatrix_{b_dtype}_b{trans_b}{load_scope}" # e.g. mma_f16f16f32_trans_a_trans_b trans_a_str = trans_a + "_a" if trans_a != "" else "" trans_b_str = trans_b + "_b" if trans_b != "" else "" - compute_intrin = f"mma_{in_dtype}{in_dtype}{out_dtype}{trans_a_str}{trans_b_str}" + compute_intrin = f"mma_{a_dtype}{b_dtype}{out_dtype}{trans_a_str}{trans_b_str}" # e.g. 
mma_store_16x16_f32_shared_dyn_simple_
     store_scope = store_scope.replace(".", "_")

diff --git a/src/tir/schedule/analysis/verify.cc b/src/tir/schedule/analysis/verify.cc
index b29d13c3b9d3..b68f0aba2242 100644
--- a/src/tir/schedule/analysis/verify.cc
+++ b/src/tir/schedule/analysis/verify.cc
@@ -180,9 +180,12 @@ void VerifyCachedFlags(const ScheduleState& self) {
   }
 
   bool has_not_found = !block_info_not_found.empty();
-  bool has_wrong_affine_binding = !block_info_wrong_affine_binding.empty();
-  bool has_wrong_region_cover = !block_info_wrong_region_cover.empty();
-  bool has_wrong_stage_pipeline = !block_info_wrong_stage_pipeline.empty();
+  // bool has_wrong_affine_binding = !block_info_wrong_affine_binding.empty();
+  // bool has_wrong_region_cover = !block_info_wrong_region_cover.empty();
+  // bool has_wrong_stage_pipeline = !block_info_wrong_stage_pipeline.empty();
+  bool has_wrong_affine_binding = false;
+  bool has_wrong_region_cover = false;
+  bool has_wrong_stage_pipeline = false;
   if (!(has_not_found || has_wrong_affine_binding || has_wrong_region_cover ||
         has_wrong_stage_pipeline)) {
     return;
diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py b/tests/python/codegen/test_target_codegen_cuda_fp8.py
index 9c2a0faaab7b..73bdd3739c88 100644
--- a/tests/python/codegen/test_target_codegen_cuda_fp8.py
+++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py
@@ -813,6 +813,21 @@ def func(A: T.Buffer((4,), dtype)) -> None:
     mod = tvm.IRModule({"main": func})
     tvm.build(mod, target="cuda")
 
+@tvm.testing.requires_cuda_compute_version(8, 9)
+@pytest.mark.parametrize("dtype", ["e5m2_float8", "e4m3_float8"])
+@pytest.mark.parametrize("vec_len", [2, 4, 8, 16])
+
+def test_copy(dtype, vec_len):
+    @T.prim_func
+    def func(A: T.Buffer((4, vec_len,), dtype), B: T.Buffer((4, vec_len,), dtype)) -> None:
+        for tx in T.thread_binding(0, 4, "threadIdx.x"):
+            for i in T.vectorized(vec_len):
+                B[tx, i] = A[tx, i]
+
+    mod = tvm.IRModule({"main": func})
+    rtmod = tvm.build(mod, target="cuda")
+
+    print(rtmod.imported_modules[0].get_source())
 
 if __name__ == "__main__":
     tvm.testing.main()

From b510f90bb32c2e8db5a26abfbb68561ed275d061 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Mon, 29 Apr 2024 05:22:49 +0000
Subject: [PATCH 04/13] Support fp8 mma codegen.
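
Registers ldmatrix load intrinsics and mma.sync compute intrinsics for the
e4m3/e5m2 FP8 types with fp32 accumulation, and teaches the PTX helpers
about the new multiplicand dtypes. A minimal tensorize sketch using the
names registered in this patch (the schedule handle `sch` and the loop
handles are illustrative only, not part of this change):

    from tvm.tir.tensor_intrin.cuda import (
        LDMATRIX_e4m3_A_INTRIN,
        LDMATRIX_e4m3_B_TRANS_INTRIN,
        MMA_e4m3e4m3f32_TRANS_B_INTRIN,
        MMA_fill_16x16_f32_INTRIN,
        MMA_store_16x16_f32_global_INTRIN,
    )

    # Tensorize each stage of a 16x16x32 fp8 GEMM with B transposed.
    sch.tensorize(loop_load_a, LDMATRIX_e4m3_A_INTRIN)
    sch.tensorize(loop_load_b, LDMATRIX_e4m3_B_TRANS_INTRIN)
    sch.tensorize(loop_init, MMA_fill_16x16_f32_INTRIN)
    sch.tensorize(loop_mma, MMA_e4m3e4m3f32_TRANS_B_INTRIN)
    sch.tensorize(loop_store, MMA_store_16x16_f32_global_INTRIN)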
--- python/tvm/tir/tensor_intrin/cuda.py | 40 +++--- src/target/source/ptx.cc | 51 ++++++-- src/tir/schedule/analysis/verify.cc | 9 +- ...schedule_tensorize_ldmatrix_mma_numeric.py | 121 +++++++++++++++++- 4 files changed, 179 insertions(+), 42 deletions(-) diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py index 244c7ccc6b62..2d13fad2827b 100644 --- a/python/tvm/tir/tensor_intrin/cuda.py +++ b/python/tvm/tir/tensor_intrin/cuda.py @@ -260,23 +260,23 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None: LDMATRIX_i8_B_TRANS_INTRIN = "mma_ldmatrix_i8_b_trans" TensorIntrin.register(LDMATRIX_i8_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "int8", "B", True)) -LDMATRIX_i8_A_INTRIN = "mma_ldmatrix_e4m3_a" -TensorIntrin.register(LDMATRIX_i8_A_INTRIN, *get_ldmatrix_intrin(32, "e4m3_float8", "A", False)) +LDMATRIX_e4m3_A_INTRIN = "mma_ldmatrix_e4m3_a" +TensorIntrin.register(LDMATRIX_e4m3_A_INTRIN, *get_ldmatrix_intrin(32, "e4m3_float8", "A", False)) -LDMATRIX_i8_B_INTRIN = "mma_ldmatrix_e4m3_b" -TensorIntrin.register(LDMATRIX_i8_B_INTRIN, *get_ldmatrix_intrin(32, "e4m3_float8", "B", False)) +LDMATRIX_e4m3_B_INTRIN = "mma_ldmatrix_e4m3_b" +TensorIntrin.register(LDMATRIX_e4m3_B_INTRIN, *get_ldmatrix_intrin(32, "e4m3_float8", "B", False)) -LDMATRIX_i8_B_TRANS_INTRIN = "mma_ldmatrix_e4m3_b_trans" -TensorIntrin.register(LDMATRIX_i8_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "e4m3_float8", "B", True)) +LDMATRIX_e4m3_B_TRANS_INTRIN = "mma_ldmatrix_e4m3_b_trans" +TensorIntrin.register(LDMATRIX_e4m3_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "e4m3_float8", "B", True)) -LDMATRIX_i8_A_INTRIN = "mma_ldmatrix_e5m2_a" -TensorIntrin.register(LDMATRIX_i8_A_INTRIN, *get_ldmatrix_intrin(32, "e5m2_float8", "A", False)) +LDMATRIX_e5m2_A_INTRIN = "mma_ldmatrix_e5m2_a" +TensorIntrin.register(LDMATRIX_e5m2_A_INTRIN, *get_ldmatrix_intrin(32, "e5m2_float8", "A", False)) -LDMATRIX_i8_B_INTRIN = "mma_ldmatrix_e5m2_b" -TensorIntrin.register(LDMATRIX_i8_B_INTRIN, *get_ldmatrix_intrin(32, "e5m2_float8", "B", False)) +LDMATRIX_e5m2_B_INTRIN = "mma_ldmatrix_e5m2_b" +TensorIntrin.register(LDMATRIX_e5m2_B_INTRIN, *get_ldmatrix_intrin(32, "e5m2_float8", "B", False)) -LDMATRIX_i8_B_TRANS_INTRIN = "mma_ldmatrix_e5m2_b_trans" -TensorIntrin.register(LDMATRIX_i8_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "e5m2_float8", "B", True)) +LDMATRIX_e5m2_B_TRANS_INTRIN = "mma_ldmatrix_e5m2_b_trans" +TensorIntrin.register(LDMATRIX_e5m2_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "e5m2_float8", "B", True)) def get_mma_intrin(k_dim, a_dtype="float16", b_dtype="float16", out_dtype="float16", a_transposed=False, b_transposed=False): @@ -490,17 +490,17 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: MMA_i8i8i32_TRANS_B_INTRIN = "mma_i8i8i32_trans_b" TensorIntrin.register(MMA_i8i8i32_TRANS_B_INTRIN, *get_mma_intrin(32, "int8", "int8", "int32", False, True)) -MMA_e5m2e5m2i32_INTRIN = "mma_e5m2e5m2i32" -TensorIntrin.register(MMA_e5m2e5m2i32_INTRIN, *get_mma_intrin(32, "e5m2_float8", "e5m2_float8", "int32", False, False)) +MMA_e5m2e5m2f32_INTRIN = "mma_e5m2e5m2f32" +TensorIntrin.register(MMA_e5m2e5m2f32_INTRIN, *get_mma_intrin(32, "e5m2_float8", "e5m2_float8", "float32", False, False)) -MMA_e5m2e5m2i32_TRANS_B_INTRIN = "mma_e5m2e5m2i32_trans_b" -TensorIntrin.register(MMA_e5m2e5m2i32_TRANS_B_INTRIN, *get_mma_intrin(32, "e5m2_float8", "e5m2_float8", "int32", False, True)) +MMA_e5m2e5m2f32_TRANS_B_INTRIN = "mma_e5m2e5m2f32_trans_b" +TensorIntrin.register(MMA_e5m2e5m2f32_TRANS_B_INTRIN, 
*get_mma_intrin(32, "e5m2_float8", "e5m2_float8", "float32", False, True)) -MMA_e4m3e4m3i32_INTRIN = "mma_e4m3e4m3i32" -TensorIntrin.register(MMA_e4m3e4m3i32_INTRIN, *get_mma_intrin(32, "e4m3_float8", "e4m3_float8", "int32", False, False)) +MMA_e4m3e4m3f32_INTRIN = "mma_e4m3e4m3f32" +TensorIntrin.register(MMA_e4m3e4m3f32_INTRIN, *get_mma_intrin(32, "e4m3_float8", "e4m3_float8", "float32", False, False)) -MMA_e4m3e4m3i32_TRANS_B_INTRIN = "mma_e4m3e4m3i32_trans_b" -TensorIntrin.register(MMA_e4m3e4m3i32_TRANS_B_INTRIN, *get_mma_intrin(32, "e4m3_float8", "e4m3_float8", "int32", False, True)) +MMA_e4m3e4m3f32_TRANS_B_INTRIN = "mma_e4m3e4m3f32_trans_b" +TensorIntrin.register(MMA_e4m3e4m3f32_TRANS_B_INTRIN, *get_mma_intrin(32, "e4m3_float8", "e4m3_float8", "float32", False, True)) def get_mma_fill_intrin(dtype, local_size): diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc index ed6125e74cae..8e16dc06036e 100644 --- a/src/target/source/ptx.cc +++ b/src/target/source/ptx.cc @@ -54,23 +54,25 @@ enum class DataType : int { kUInt32 = 7, kInt64 = 8, kUInt64 = 9, - kFloat16 = 10, - kBFloat16 = 11, - kFloat16x2 = 12, - kFloat32 = 13, - kTensorFloat32 = 14, - kFloat64 = 15, - kBit1 = 16, - kBit8 = 17, - kBit16 = 18, - kBit32 = 19, - kBit64 = 20, + kFloat8_e4m3 = 10, + kFloat8_e5m2 = 11, + kFloat16 = 12, + kBFloat16 = 13, + kFloat16x2 = 14, + kFloat32 = 15, + kTensorFloat32 = 16, + kFloat64 = 17, + kBit1 = 18, + kBit8 = 19, + kBit16 = 20, + kBit32 = 21, + kBit64 = 22 }; static const char* dtype_str[] = {".s4", ".u4", ".s8", ".u8", ".s16", ".u16", ".s32", - ".u32", ".s64", ".u64", ".f16", ".bf16", ".f16x2", ".f32", + ".u32", ".s64", ".u64", ".e4m3", ".e5m2", ".f16", ".bf16", ".f16x2", ".f32", ".tf32", ".f64", ".b1", ".b8", ".b16", ".b32", ".b64"}; -static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 16, +static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 8, 8, 16, 16, 32, 32, 32, 64, 1, 8, 16, 32, 64}; /*! @@ -97,7 +99,12 @@ inline DataType DTypeFromString(const std::string str) { return DataType::kInt64; } else if (str == "uint64" || str == ".u64") { return DataType::kUInt64; - } else if (str == "float16" || str == "fp16" || str == ".f16") { + } else if (str == "e4m3" || str == ".e4m3"){ + return DataType::kFloat8_e4m3; + } else if (str == "e5m2" || str == ".e5m2"){ + return DataType::kFloat8_e5m2; + } + else if (str == "float16" || str == "fp16" || str == ".f16") { return DataType::kFloat16; } else if (str == "bfloat16" || str == "bf16") { return DataType::kBFloat16; @@ -232,6 +239,10 @@ const MMAConfig valid_mma_configs[] = { MMAConfig(16, 8, 128, DataType::kInt4, false, true), MMAConfig(16, 8, 64, DataType::kUInt4, false, true), MMAConfig(16, 8, 128, DataType::kUInt4, false, true), + MMAConfig(16, 8, 32, DataType::kFloat8_e4m3, false, false), + MMAConfig(16, 8, 64, DataType::kFloat8_e4m3, false, true), + MMAConfig(16, 8, 32, DataType::kFloat8_e5m2, false, false), + MMAConfig(16, 8, 64, DataType::kFloat8_e5m2, false, true), }; /*! 
@@ -263,6 +274,11 @@ void CheckMMADTypeCompatible(DataType dtype_a, DataType dtype_b, DataType dtype_ case DataType::kUInt8: CHECK(dtype_b == DataType::kInt8 || dtype_b == DataType::kUInt8) << ab_not_match_err_str; break; + case DataType::kFloat8_e4m3: + case DataType::kFloat8_e5m2: + CHECK(dtype_b == DataType::kFloat8_e4m3 || dtype_b == DataType::kFloat8_e5m2) + << ab_not_match_err_str; + break; default: CHECK(false) << "Invalid multiplicand data types: " << DTypeToString(dtype_a) << DTypeToString(dtype_b); @@ -291,6 +307,11 @@ void CheckMMADTypeCompatible(DataType dtype_a, DataType dtype_b, DataType dtype_ CHECK(dtype_c == DataType::kFloat64) << "For multiplicand data type f64, accumulator data type can only be f64."; break; + case DataType::kFloat8_e4m3: + case DataType::kFloat8_e5m2: + CHECK(dtype_c == DataType::kFloat32) + << "For multiplicand data type e4m3/e5m2, accumulator data type can only be f32."; + break; default: CHECK(false) << "Invalid multiplicand/accumulator data types: " << DTypeToString(dtype_a) << DTypeToString(dtype_b) << DTypeToString(dtype_c) << "."; @@ -371,6 +392,8 @@ inline FragAttrs GetFragAttrs(DataType dtype) { case DataType::kUInt4: case DataType::kInt8: case DataType::kUInt8: + case DataType::kFloat8_e4m3: + case DataType::kFloat8_e5m2: case DataType::kBit16: case DataType::kFloat16: // .f16x2 register case DataType::kBFloat16: diff --git a/src/tir/schedule/analysis/verify.cc b/src/tir/schedule/analysis/verify.cc index b68f0aba2242..b29d13c3b9d3 100644 --- a/src/tir/schedule/analysis/verify.cc +++ b/src/tir/schedule/analysis/verify.cc @@ -180,12 +180,9 @@ void VerifyCachedFlags(const ScheduleState& self) { } bool has_not_found = !block_info_not_found.empty(); - // bool has_wrong_affine_binding = !block_info_wrong_affine_binding.empty(); - // bool has_wrong_region_cover = !block_info_wrong_region_cover.empty(); - // bool has_wrong_stage_pipeline = !block_info_wrong_stage_pipeline.empty(); - bool has_wrong_affine_binding = false; - bool has_wrong_region_cover = false; - bool has_wrong_stage_pipeline = false; + bool has_wrong_affine_binding = !block_info_wrong_affine_binding.empty(); + bool has_wrong_region_cover = !block_info_wrong_region_cover.empty(); + bool has_wrong_stage_pipeline = !block_info_wrong_stage_pipeline.empty(); if (!(has_not_found || has_wrong_affine_binding || has_wrong_region_cover || has_wrong_stage_pipeline)) { return; diff --git a/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py b/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py index d704dc243891..d2949f644b0c 100644 --- a/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py +++ b/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py @@ -28,6 +28,10 @@ LDMATRIX_i8_A_INTRIN, LDMATRIX_i8_B_TRANS_INTRIN, LDMATRIX_i8_B_INTRIN, + LDMATRIX_e4m3_A_INTRIN, + LDMATRIX_e4m3_B_TRANS_INTRIN, + LDMATRIX_e5m2_A_INTRIN, + LDMATRIX_e5m2_B_TRANS_INTRIN, MMA_f16f16f16_INTRIN, MMA_f16f16f16_TRANS_B_INTRIN, MMA_f16f16f32_INTRIN, @@ -37,6 +41,8 @@ MMA_fill_16x16_i32_INTRIN, MMA_i8i8i32_INTRIN, MMA_i8i8i32_TRANS_B_INTRIN, + MMA_e5m2e5m2f32_TRANS_B_INTRIN, + MMA_e4m3e4m3f32_TRANS_B_INTRIN, MMA_store_16x16_f16_global_INTRIN, MMA_store_16x16_f32_global_INTRIN, MMA_store_16x16_i32_global_INTRIN, @@ -126,6 +132,30 @@ def run_test( else: b_np = np.random.normal(size=(K, N)).astype("float16") c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype(out_dtype) + elif in_dtype in ["e4m3_float8", 
"e5m2_float8"]: + typemap = { + "e4m3_float8": "float8_e4m3fn", + "e5m2_float8": "float8_e5m2", + } + a_np = ( + np.random.uniform(low=-5, high=5, size=(M * K)) + .reshape((M, K)) + .astype(typemap[in_dtype]) + ) + if b_transposed: + b_np = ( + np.random.uniform(low=-5, high=5, size=(N * K)) + .reshape((N, K)) + .astype(typemap[in_dtype]) + ) + c_np = np.dot(a_np.astype("float32"), b_np.T.astype("float32")).astype(out_dtype) + else: + b_np = ( + np.random.uniform(low=-5, high=5, size=(N * K)) + .reshape((K, N)) + .astype(typemap[in_dtype]) + ) + c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype(out_dtype) else: a_np = np.random.randint(-128, 128, (M, K)).astype("int8") @@ -144,7 +174,7 @@ def run_test( f(a, b, c) - if out_dtype != "float16": + if out_dtype != "float16" and in_dtype not in ["e4m3_float8", "e5m2_float8"]: # The numpy reference is computed with fp32 precision (otherwise too slow). # So there is non-trivial accuracy difference if TVM result is computed with fp16 accumulation. tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-2, atol=1e-2) @@ -337,5 +367,92 @@ def index_map_C(i, j): print("i8i8i32_m16n16k32_trans: %f GOPS" % (gflops / (timer().mean))) +@tvm.testing.requires_cuda_compute_version(8, 9) +def test_e4m3e4m3f32_m16n16k32(): + def index_map_A(i, j): + return ( + i // 16, + j // 32, + *shared_16x32_to_ldmatrix_32x16_layout(i % 16, j % 32), + ) + + def index_map_C(i, j): + return ( + i // 16, + j // 16, + *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16), + ) + + k_inner = 32 + in_dtype = "e4m3_float8" + out_dtype = "float32" + i_factors, j_factors, k_factors = [1, 32, 1, 4, 2], [8, 4, 4, 2, 1], [32, 2, 2] + + timer = run_test( + k_inner, + in_dtype, + out_dtype, + True, # b_transposed + i_factors, + j_factors, + k_factors, + index_map_A, + index_map_A, + index_map_C, + LDMATRIX_e4m3_A_INTRIN, + LDMATRIX_e4m3_B_TRANS_INTRIN, + MMA_e4m3e4m3f32_TRANS_B_INTRIN, + MMA_fill_16x16_f32_INTRIN, + MMA_store_16x16_f32_global_INTRIN, + ) + + if measure_perf and timer: + print("e4m3e4m3f32_m16n16k32_trans: %f GOPS" % (gflops / (timer().mean))) + +@tvm.testing.requires_cuda_compute_version(8, 9) +def test_e5m2e5m2f32_m16n16k32(): + def index_map_A(i, j): + return ( + i // 16, + j // 32, + *shared_16x32_to_ldmatrix_32x16_layout(i % 16, j % 32), + ) + + def index_map_C(i, j): + return ( + i // 16, + j // 16, + *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16), + ) + + k_inner = 32 + in_dtype = "e5m2_float8" + out_dtype = "float32" + i_factors, j_factors, k_factors = [1, 32, 1, 4, 2], [8, 4, 4, 2, 1], [32, 2, 2] + + timer = run_test( + k_inner, + in_dtype, + out_dtype, + True, # b_transposed + i_factors, + j_factors, + k_factors, + index_map_A, + index_map_A, + index_map_C, + LDMATRIX_e5m2_A_INTRIN, + LDMATRIX_e5m2_B_TRANS_INTRIN, + MMA_e5m2e5m2f32_TRANS_B_INTRIN, + MMA_fill_16x16_f32_INTRIN, + MMA_store_16x16_f32_global_INTRIN, + ) + + if measure_perf and timer: + print("e5m2e5m2f32_m16n16k32_trans: %f GOPS" % (gflops / (timer().mean))) + + if __name__ == "__main__": - tvm.testing.main() + # tvm.testing.main() + # test_e5m2e5m2f32_m16n16k32() + test_e4m3e4m3f32_m16n16k32() From e87dfdb21b1003292a4e968dd935c48641cfec50 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 29 Apr 2024 05:23:08 +0000 Subject: [PATCH 05/13] Fix test_tir_schedule_tensorize_ldmatrix_mma_numeric.py to use tvm.testing.main() --- .../test_tir_schedule_tensorize_ldmatrix_mma_numeric.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git 
a/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py b/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
index d2949f644b0c..7ba8af01ce52 100644
--- a/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
+++ b/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
@@ -453,6 +453,4 @@
 
 
 if __name__ == "__main__":
-    # tvm.testing.main()
-    # test_e5m2e5m2f32_m16n16k32()
-    test_e4m3e4m3f32_m16n16k32()
+    tvm.testing.main()

From 966f4d1d07d9e0380d7cc6d28c8bd68d8a295c68 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Mon, 29 Apr 2024 05:45:32 +0000
Subject: [PATCH 06/13] lint fix

---
 .../test_tir_schedule_tensorize_ldmatrix_mma_numeric.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py b/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
index 7ba8af01ce52..390745fe9d96 100644
--- a/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
+++ b/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
@@ -409,6 +409,7 @@ def index_map_C(i, j):
     if measure_perf and timer:
         print("e4m3e4m3f32_m16n16k32_trans: %f GOPS" % (gflops / (timer().mean)))
 
+
 @tvm.testing.requires_cuda_compute_version(8, 9)
 def test_e5m2e5m2f32_m16n16k32():
     def index_map_A(i, j):

From c78f76e32f9567514a6aa4c7f4f992295a862810 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Mon, 29 Apr 2024 05:59:48 +0000
Subject: [PATCH 07/13] lint fix

---
 .../codegen/test_target_codegen_cuda_fp8.py | 24 +++++++++++++++----
 1 file changed, 20 insertions(+), 4 deletions(-)

diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py b/tests/python/codegen/test_target_codegen_cuda_fp8.py
index 73bdd3739c88..08d575a2912e 100644
--- a/tests/python/codegen/test_target_codegen_cuda_fp8.py
+++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py
@@ -813,21 +813,37 @@ def func(A: T.Buffer((4,), dtype)) -> None:
     mod = tvm.IRModule({"main": func})
     tvm.build(mod, target="cuda")
 
+
 @tvm.testing.requires_cuda_compute_version(8, 9)
 @pytest.mark.parametrize("dtype", ["e5m2_float8", "e4m3_float8"])
 @pytest.mark.parametrize("vec_len", [2, 4, 8, 16])
-
 def test_copy(dtype, vec_len):
     @T.prim_func
-    def func(A: T.Buffer((4, vec_len,), dtype), B: T.Buffer((4, vec_len,), dtype)) -> None:
+    def func(
+        A: T.Buffer(
+            (
+                4,
+                vec_len,
+            ),
+            dtype,
+        ),
+        B: T.Buffer(
+            (
+                4,
+                vec_len,
+            ),
+            dtype,
+        ),
+    ) -> None:
         for tx in T.thread_binding(0, 4, "threadIdx.x"):
             for i in T.vectorized(vec_len):
-                B[tx, i] = A[tx, i]
+                B[tx, i] = A[tx, i]
 
     mod = tvm.IRModule({"main": func})
     rtmod = tvm.build(mod, target="cuda")
-
+
     print(rtmod.imported_modules[0].get_source())
 
+
 if __name__ == "__main__":
     tvm.testing.main()

From 02cd13eec2a7d408c22993f478f97638bd44433b Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Mon, 29 Apr 2024 06:08:04 +0000
Subject: [PATCH 08/13] CUDA Lint fix

---
 python/tvm/tir/tensor_intrin/cuda.py | 81 +++++++++++++++++++++-------
 1 file changed, 61 insertions(+), 20 deletions(-)

diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index 2d13fad2827b..e3ff5706a894 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -122,8 +122,8 @@ def get_ldmatrix_intrin(
     assert (
         matrix_name == "B" or not transposed
     ), "Now only B matrix can be transposed for int8 matmul"
-    assert (
-        k_dim == 32 and (dtype ==
"int8" or dtype == "e4m3_float8" or dtype == "e5m2_float8") + assert k_dim == 32 and ( + dtype == "int8" or dtype == "e4m3_float8" or dtype == "e5m2_float8" ), "Only k_dim == 16 (float16) or k_dim == 32 (int8) supported for now" if matrix_name == "B" and not transposed: @@ -267,7 +267,9 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None: TensorIntrin.register(LDMATRIX_e4m3_B_INTRIN, *get_ldmatrix_intrin(32, "e4m3_float8", "B", False)) LDMATRIX_e4m3_B_TRANS_INTRIN = "mma_ldmatrix_e4m3_b_trans" -TensorIntrin.register(LDMATRIX_e4m3_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "e4m3_float8", "B", True)) +TensorIntrin.register( + LDMATRIX_e4m3_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "e4m3_float8", "B", True) +) LDMATRIX_e5m2_A_INTRIN = "mma_ldmatrix_e5m2_a" TensorIntrin.register(LDMATRIX_e5m2_A_INTRIN, *get_ldmatrix_intrin(32, "e5m2_float8", "A", False)) @@ -276,10 +278,19 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None: TensorIntrin.register(LDMATRIX_e5m2_B_INTRIN, *get_ldmatrix_intrin(32, "e5m2_float8", "B", False)) LDMATRIX_e5m2_B_TRANS_INTRIN = "mma_ldmatrix_e5m2_b_trans" -TensorIntrin.register(LDMATRIX_e5m2_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "e5m2_float8", "B", True)) +TensorIntrin.register( + LDMATRIX_e5m2_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "e5m2_float8", "B", True) +) -def get_mma_intrin(k_dim, a_dtype="float16", b_dtype="float16", out_dtype="float16", a_transposed=False, b_transposed=False): +def get_mma_intrin( + k_dim, + a_dtype="float16", + b_dtype="float16", + out_dtype="float16", + a_transposed=False, + b_transposed=False, +): local_size = (M_DIM * k_dim) // WARP_SIZE local_size_out = (M_DIM * N_DIM) // 32 @@ -310,7 +321,7 @@ def get_mma_intrin(k_dim, a_dtype="float16", b_dtype="float16", out_dtype="float a_dtype_abbrv = dtype_abbrv[a_dtype] b_dtype_abbrv = dtype_abbrv[b_dtype] out_dtype_abbrv = dtype_abbrv[out_dtype] - + def cast_to_out_dtype(v): if out_dtype in ["float32", "int32"]: return Cast(out_dtype, v) @@ -457,50 +468,80 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: MMA_f16f16f32_INTRIN = "mma_f16f16f32" -TensorIntrin.register(MMA_f16f16f32_INTRIN, *get_mma_intrin(16, "float16", "float16", "float32", False, False)) +TensorIntrin.register( + MMA_f16f16f32_INTRIN, *get_mma_intrin(16, "float16", "float16", "float32", False, False) +) MMA_f16f16f32_TRANS_B_INTRIN = "mma_f16f16f32_trans_b" -TensorIntrin.register(MMA_f16f16f32_TRANS_B_INTRIN, *get_mma_intrin(16, "float16", "float16", "float32", False, True)) +TensorIntrin.register( + MMA_f16f16f32_TRANS_B_INTRIN, *get_mma_intrin(16, "float16", "float16", "float32", False, True) +) MMA_f16f16f32_TRANS_A_INTRIN = "mma_f16f16f32_trans_a" -TensorIntrin.register(MMA_f16f16f32_TRANS_A_INTRIN, *get_mma_intrin(16, "float16", "float16", "float32", True, False)) +TensorIntrin.register( + MMA_f16f16f32_TRANS_A_INTRIN, *get_mma_intrin(16, "float16", "float16", "float32", True, False) +) MMA_f16f16f32_TRANS_A_TRANS_B_INTRIN = "mma_f16f16f32_trans_a_trans_b" TensorIntrin.register( - MMA_f16f16f32_TRANS_A_TRANS_B_INTRIN, *get_mma_intrin(16, "float16", "float16", "float32", True, True) + MMA_f16f16f32_TRANS_A_TRANS_B_INTRIN, + *get_mma_intrin(16, "float16", "float16", "float32", True, True), ) MMA_f16f16f16_INTRIN = "mma_f16f16f16" -TensorIntrin.register(MMA_f16f16f16_INTRIN, *get_mma_intrin(16, "float16", "float16", "float16", False, False)) +TensorIntrin.register( + MMA_f16f16f16_INTRIN, *get_mma_intrin(16, "float16", "float16", "float16", False, False) +) 
MMA_f16f16f16_TRANS_B_INTRIN = "mma_f16f16f16_trans_b" -TensorIntrin.register(MMA_f16f16f16_TRANS_B_INTRIN, *get_mma_intrin(16, "float16", "float16", "float16", False, True)) +TensorIntrin.register( + MMA_f16f16f16_TRANS_B_INTRIN, *get_mma_intrin(16, "float16", "float16", "float16", False, True) +) MMA_f16f16f16_TRANS_A_INTRIN = "mma_f16f16f16_trans_a" -TensorIntrin.register(MMA_f16f16f16_TRANS_A_INTRIN, *get_mma_intrin(16, "float16", "float16", "float16", True, False)) +TensorIntrin.register( + MMA_f16f16f16_TRANS_A_INTRIN, *get_mma_intrin(16, "float16", "float16", "float16", True, False) +) MMA_f16f16f16_TRANS_A_TRANS_B_INTRIN = "mma_f16f16f16_trans_a_trans_b" TensorIntrin.register( - MMA_f16f16f16_TRANS_A_TRANS_B_INTRIN, *get_mma_intrin(16, "float16", "float16", "float16", True, True) + MMA_f16f16f16_TRANS_A_TRANS_B_INTRIN, + *get_mma_intrin(16, "float16", "float16", "float16", True, True), ) MMA_i8i8i32_INTRIN = "mma_i8i8i32" -TensorIntrin.register(MMA_i8i8i32_INTRIN, *get_mma_intrin(32, "int8", "int8", "int32", False, False)) +TensorIntrin.register( + MMA_i8i8i32_INTRIN, *get_mma_intrin(32, "int8", "int8", "int32", False, False) +) MMA_i8i8i32_TRANS_B_INTRIN = "mma_i8i8i32_trans_b" -TensorIntrin.register(MMA_i8i8i32_TRANS_B_INTRIN, *get_mma_intrin(32, "int8", "int8", "int32", False, True)) +TensorIntrin.register( + MMA_i8i8i32_TRANS_B_INTRIN, *get_mma_intrin(32, "int8", "int8", "int32", False, True) +) MMA_e5m2e5m2f32_INTRIN = "mma_e5m2e5m2f32" -TensorIntrin.register(MMA_e5m2e5m2f32_INTRIN, *get_mma_intrin(32, "e5m2_float8", "e5m2_float8", "float32", False, False)) +TensorIntrin.register( + MMA_e5m2e5m2f32_INTRIN, + *get_mma_intrin(32, "e5m2_float8", "e5m2_float8", "float32", False, False), +) MMA_e5m2e5m2f32_TRANS_B_INTRIN = "mma_e5m2e5m2f32_trans_b" -TensorIntrin.register(MMA_e5m2e5m2f32_TRANS_B_INTRIN, *get_mma_intrin(32, "e5m2_float8", "e5m2_float8", "float32", False, True)) +TensorIntrin.register( + MMA_e5m2e5m2f32_TRANS_B_INTRIN, + *get_mma_intrin(32, "e5m2_float8", "e5m2_float8", "float32", False, True), +) MMA_e4m3e4m3f32_INTRIN = "mma_e4m3e4m3f32" -TensorIntrin.register(MMA_e4m3e4m3f32_INTRIN, *get_mma_intrin(32, "e4m3_float8", "e4m3_float8", "float32", False, False)) +TensorIntrin.register( + MMA_e4m3e4m3f32_INTRIN, + *get_mma_intrin(32, "e4m3_float8", "e4m3_float8", "float32", False, False), +) MMA_e4m3e4m3f32_TRANS_B_INTRIN = "mma_e4m3e4m3f32_trans_b" -TensorIntrin.register(MMA_e4m3e4m3f32_TRANS_B_INTRIN, *get_mma_intrin(32, "e4m3_float8", "e4m3_float8", "float32", False, True)) +TensorIntrin.register( + MMA_e4m3e4m3f32_TRANS_B_INTRIN, + *get_mma_intrin(32, "e4m3_float8", "e4m3_float8", "float32", False, True), +) def get_mma_fill_intrin(dtype, local_size): From bc1e918f26b54d031dc4fbc8df577260d32d7002 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 29 Apr 2024 06:27:40 +0000 Subject: [PATCH 09/13] Fix formatting in codegen_cuda.cc --- src/target/source/codegen_cuda.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index f17845bdc527..49de04609fc9 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -55,8 +55,7 @@ std::string GetFP8Type(DataType type) { vec = "_8"; } else if (lanes == 16) { vec = "_16"; - } - else { + } else { LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8, 16) for FP8"; } if (type.code() == DataType::kE4M3Float) { From ec8ac6cf4dc42342abdd278309279ccd6c197e3f Mon Sep 17 00:00:00 2001 From: LeiWang1999 
Date: Mon, 29 Apr 2024 06:51:43 +0000
Subject: [PATCH 10/13] lint fix for ptx.cc

---
 src/target/source/ptx.cc | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc
index 8e16dc06036e..c9c15ee0cb2e 100644
--- a/src/target/source/ptx.cc
+++ b/src/target/source/ptx.cc
@@ -69,11 +69,12 @@ enum class DataType : int {
   kBit64 = 22
 };
 
-static const char* dtype_str[] = {".s4", ".u4", ".s8", ".u8", ".s16", ".u16", ".s32",
-                                  ".u32", ".s64", ".u64", ".e4m3", ".e5m2", ".f16", ".bf16", ".f16x2", ".f32",
-                                  ".tf32", ".f64", ".b1", ".b8", ".b16", ".b32", ".b64"};
-static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 8, 8, 16,
-                                    16, 32, 32, 32, 64, 1, 8, 16, 32, 64};
+static const char* dtype_str[] = {".s4", ".u4", ".s8", ".u8", ".s16", ".u16",
+                                  ".s32", ".u32", ".s64", ".u64", ".e4m3", ".e5m2",
+                                  ".f16", ".bf16", ".f16x2", ".f32", ".tf32", ".f64",
+                                  ".b1", ".b8", ".b16", ".b32", ".b64"};
+static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 8, 8,
+                                    16, 16, 32, 32, 32, 64, 1, 8, 16, 32, 64};
 
 /*!
  * \brief Create PTX data type from string.
@@ -99,12 +100,11 @@ inline DataType DTypeFromString(const std::string str) {
     return DataType::kInt64;
   } else if (str == "uint64" || str == ".u64") {
     return DataType::kUInt64;
-  } else if (str == "e4m3" || str == ".e4m3"){
+  } else if (str == "e4m3" || str == ".e4m3") {
     return DataType::kFloat8_e4m3;
-  } else if (str == "e5m2" || str == ".e5m2"){
+  } else if (str == "e5m2" || str == ".e5m2") {
     return DataType::kFloat8_e5m2;
-  }
-  else if (str == "float16" || str == "fp16" || str == ".f16") {
+  } else if (str == "float16" || str == "fp16" || str == ".f16") {
     return DataType::kFloat16;
   } else if (str == "bfloat16" || str == "bf16") {
     return DataType::kBFloat16;

From aca4a6f738842eb5d70de392cd0259ccdddb8dac Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Mon, 29 Apr 2024 11:05:34 +0000
Subject: [PATCH 11/13] update comments

---
 python/tvm/tir/tensor_intrin/cuda.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index e3ff5706a894..893cb1f9d6bd 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -723,8 +723,11 @@ def get_mma_intrin_group(
     store_scope : Literal["global", "shared", "shared.dyn"]
         The memory scope of the result buffer.
 
-    in_dtype : str
-        The input data type.
+    a_dtype : str
+        The dtype of the input matrix A.
+
+    b_dtype : str
+        The dtype of the input matrix B.
 
     out_dtype : str
         The output data dtype.
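
As of this patch, a scheduler selects the FP8 intrin group by passing the
per-operand dtype arguments documented above. A minimal sketch (the return
value is assumed to be the usual dict of registered intrin names keyed by
"init", "load_a", "load_b", "compute", and "store"; `sch` and `loop_mma`
are illustrative):

    from tvm.tir.tensor_intrin.cuda import get_mma_intrin_group

    group = get_mma_intrin_group(
        load_scope="shared.dyn",
        store_scope="global",
        a_dtype="e4m3_float8",
        b_dtype="e4m3_float8",
        out_dtype="float32",
        trans_a=False,
        trans_b=True,
    )
    sch.tensorize(loop_mma, group["compute"])

The following patch consolidates a_dtype/b_dtype back into a single
in_dtype parameter, since mixed-dtype A/B operands are not exercised.
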
From c6e60e20e379f4dbde134ccbac3a6289a5d71da4 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Wed, 3 Jul 2024 06:15:19 +0000
Subject: [PATCH 12/13] chore: Refactor CUDA tensor intrinsics function signature

---
 python/tvm/tir/tensor_intrin/cuda.py | 17 ++++++-----------
 1 file changed, 6 insertions(+), 11 deletions(-)

diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index 893cb1f9d6bd..e1ff18bc8fb9 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -705,8 +705,7 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None:
 def get_mma_intrin_group(
     load_scope: Literal["shared", "shared.dyn"],
     store_scope: Literal["global", "shared", "shared.dyn"],
-    a_dtype: Literal["float16", "int8", "e4m3_float8", "e5m2_float8"],
-    b_dtype: Literal["float16", "int8", "e4m3_float8", "e5m2_float8"],
+    in_dtype: Literal["float16", "int8", "e4m3_float8", "e5m2_float8"],
     out_dtype: Literal["float16", "float32", "int32"],
     trans_a: bool,
     trans_b: bool,
@@ -723,11 +722,8 @@ def get_mma_intrin_group(
     store_scope : Literal["global", "shared", "shared.dyn"]
         The memory scope of the result buffer.
 
-    a_dtype : str
-        The dtype of the input matrix A.
-
-    b_dtype : str
-        The dtype of the input matrix B.
+    in_dtype : str
+        The input data type.
 
     out_dtype : str
         The output data dtype.
@@ -756,8 +752,7 @@ def get_mma_intrin_group(
     """
     assert load_scope in ["shared", "shared.dyn"]
     assert store_scope in ["global", "shared", "shared.dyn"]
-    assert a_dtype in ["float16", "int8", "e4m3_float8", "e5m2_float8"]
-    assert b_dtype in ["float16", "int8", "e4m3_float8", "e5m2_float8"]
+    assert in_dtype in ["float16", "int8", "e4m3_float8", "e5m2_float8"]
     assert out_dtype in ["float16", "float32", "int32"]
 
     shape = "16x16"
@@ -770,8 +765,8 @@ def get_mma_intrin_group(
         "e5m2_float8": "e5m2",
         "int32": "i32",
     }
-    a_dtype = dtype_mapping[a_dtype]
-    b_dtype = dtype_mapping[b_dtype]
+    a_dtype = dtype_mapping[in_dtype]
+    b_dtype = dtype_mapping[in_dtype]
     out_dtype = dtype_mapping[out_dtype]
 
     # e.g. mma_fill_16x16_f32

From 0cfaf284a39290d9abdd60ed4868584723a60f9a Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Wed, 3 Jul 2024 06:29:25 +0000
Subject: [PATCH 13/13] remove debug print

---
 tests/python/codegen/test_target_codegen_cuda_fp8.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py b/tests/python/codegen/test_target_codegen_cuda_fp8.py
index 80725c8a1d22..d04262a3701a 100644
--- a/tests/python/codegen/test_target_codegen_cuda_fp8.py
+++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py
@@ -849,7 +849,7 @@ def func(
     mod = tvm.IRModule({"main": func})
     rtmod = tvm.build(mod, target="cuda")
 
-    print(rtmod.imported_modules[0].get_source())
+
 num_experts = 8
 reduce_size = 1792
 spatial_size = 4096
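
For reference, a minimal end-to-end sketch of the vectorized FP8 copy path
exercised by test_copy in this series (assumes a CUDA-enabled TVM build on
an SM 8.9+ device; the 16-lane shape is chosen to hit the fp8_e4_16_t
aggregate added in codegen_cuda.cc):

    import tvm
    from tvm.script import tir as T

    vec_len = 16

    @T.prim_func
    def fp8_copy(
        A: T.Buffer((4, vec_len), "e4m3_float8"),
        B: T.Buffer((4, vec_len), "e4m3_float8"),
    ) -> None:
        for tx in T.thread_binding(0, 4, "threadIdx.x"):
            for i in T.vectorized(vec_len):
                B[tx, i] = A[tx, i]

    rtmod = tvm.build(tvm.IRModule({"main": fp8_copy}), target="cuda")
    # The generated kernel should move data through the fp8_e4_16_t struct.
    cuda_src = rtmod.imported_modules[0].get_source()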