From 6b8608e7c553cfd56a3c8dac64ff04b015e0e17d Mon Sep 17 00:00:00 2001 From: Joshua Hong Date: Fri, 21 Feb 2025 17:00:45 -0500 Subject: [PATCH 01/14] Add bf support --- include/tvm/runtime/data_type.h | 2 ++ python/tvm/relax/expr.py | 5 +++++ src/relax/op/nn/nn.cc | 2 +- src/relax/op/op_common.h | 2 +- src/tir/transforms/unsupported_dtype_legalize.cc | 14 ++++++++++++++ 5 files changed, 23 insertions(+), 2 deletions(-) diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index c49fde1746bc..bd4a189b3f3e 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -120,6 +120,8 @@ class DataType { bool is_bool() const { return code() == DataType::kUInt && bits() == 1; } /*! \return whether type is a float type. */ bool is_float() const { return code() == DataType::kFloat; } + /*! \return whether type is a bfloat type. */ + bool is_bfloat() const { return code() == DataType::kBFloat; } /*! \return whether type is a float8 type. */ bool is_float8() const { return (code() == DataType::kFloat || code() == DataType::kE4M3Float || diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 190df4286056..fd99d6e8437a 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -21,6 +21,11 @@ import numpy as _np # type: ignore +try: + import ml_dtypes +except ImportError: + ml_dtypes = None + import tvm import tvm._ffi import tvm.ir diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 7eccf47e4b06..03cdb11a8187 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -191,7 +191,7 @@ bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, axes_non_neg = NormalizeAxes(call, ctx, data_sinfo->ndim, axes); } int n_axis = axes.size(); - if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) { + if (!data_sinfo->IsUnknownDtype() && (!data_sinfo->dtype.is_float() && !data_sinfo->dtype.is_bfloat())) { ctx->ReportFatal( Diagnostic::Error(call) << op << " requires the input data to have float dtype. However, the given data dtype is " diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index eb9caae4b9e1..4c3c7475928a 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -201,7 +201,7 @@ template inline StructInfo InferStructInfoUnary(const Call& call, const BlockBuilder& ctx, FType f_compute_out_dtype) { TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); - if (require_float_dtype && !input_sinfo->IsUnknownDtype() && !input_sinfo->dtype.is_float()) { + if (require_float_dtype && !input_sinfo->IsUnknownDtype() && (!input_sinfo->dtype.is_float() && !input_sinfo->dtype.is_bfloat())) { ctx->ReportFatal( Diagnostic::Error(call) << call->op diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index c75ecf77e708..9d67a962c29d 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -350,6 +350,9 @@ class ComputeLegalizer : public StmtExprMutator { if (it != var_remap_.end()) { return AttrStmt(it->second, op->attr_key, op->value, op->body); } + } else if (auto reducer = op->node.as()) { + auto identity_elements = reducer->identity_element.Map([this](PrimExpr expr) { return this->VisitExpr(expr); }); + return AttrStmt(CommReducer(reducer->lhs, reducer->rhs, reducer->result, identity_elements, reducer->span), op->attr_key, op->value, op->body); } return ret; } @@ -586,6 +589,9 @@ class StorageLegalizer : public StmtExprMutator { if (it != var_remap_.end()) { return AttrStmt(it->second, op->attr_key, op->value, op->body); } + } else if (auto reducer = op->node.as()) { + auto identity_elements = reducer->identity_element.Map([this](PrimExpr expr) { return this->VisitExpr(expr); }); + return AttrStmt(CommReducer(reducer->lhs, reducer->rhs, reducer->result, identity_elements, reducer->span), op->attr_key, op->value, op->body); } return ret; } @@ -714,6 +720,10 @@ bool CheckDataTypeSupport(const Target& target, const std::string& support_func_ Pass BF16ComputeLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto target = f->GetAttr(tvm::attr::kTarget).value(); + if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_bf16")) { + return f; + } // TODO(tvm-team): skip if the target supports bf16 return BF16ComputeLegalizer().Legalize(f); }; @@ -724,6 +734,10 @@ TVM_REGISTER_GLOBAL("tir.transform.BF16ComputeLegalize").set_body_typed(BF16Comp Pass BF16StorageLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto target = f->GetAttr(tvm::attr::kTarget).value(); + if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_bf16")) { + return f; + } // TODO(tvm-team): skip if the target supports bf16 return BF16StorageLegalizer().Legalize(f); }; From 3db78562fbdb2d25bc3cb8a74a8fc4c24a70f277 Mon Sep 17 00:00:00 2001 From: Joshua Hong Date: Sat, 22 Feb 2025 01:28:39 -0500 Subject: [PATCH 02/14] Add negative/positive case for base bf legalization --- .../test_tir_transform_bf16_legalize.py | 158 +++++++++--------- 1 file changed, 77 insertions(+), 81 deletions(-) diff --git a/tests/python/tir-transform/test_tir_transform_bf16_legalize.py b/tests/python/tir-transform/test_tir_transform_bf16_legalize.py index e2752e8bbb3c..31f03672b49e 100644 --- a/tests/python/tir-transform/test_tir_transform_bf16_legalize.py +++ b/tests/python/tir-transform/test_tir_transform_bf16_legalize.py @@ -16,27 +16,9 @@ # under the License. import tvm import tvm.script +from tvm.target import Target from tvm.script import tir as T - - -def get_before(): - @tvm.script.ir_module - class Before: - @T.prim_func - def main( - Aptr: T.handle("bfloat16"), Bptr: T.handle("bfloat16"), Dptr: T.handle("bfloat16") - ): - T.func_attr({"global_symbol": "main"}) - A = T.decl_buffer((100,), "bfloat16", data=Aptr) - B = T.decl_buffer((100,), "bfloat16", data=Bptr) - D = T.decl_buffer((100,), "bfloat16", data=Dptr) - C = T.decl_buffer((100,), "bfloat16") - for i in T.grid(100): - C[i] = A[i] + B[i] - D[i] = T.exp(C[i]) - - return Before - +from tvm.tir.transform.transform import BindTarget def u16tof32(v): uint32_v = v.astype("uint32") @@ -59,62 +41,7 @@ def f32tou16(v): def f32tobf16(v): return T.reinterpret("bfloat16", f32tou16(v)) - -def get_after_compute_legalize(): - @tvm.script.ir_module - class After: - @T.prim_func - def main( - Aptr: T.handle("bfloat16"), Bptr: T.handle("bfloat16"), Dptr: T.handle("bfloat16") - ): - T.func_attr({"global_symbol": "main"}) - A = T.decl_buffer((100,), "bfloat16", data=Aptr) - B = T.decl_buffer((100,), "bfloat16", data=Bptr) - D = T.decl_buffer((100,), "bfloat16", data=Dptr) - C = T.decl_buffer((100,), "float32") - for i in T.grid(100): - C[i] = bf16tof32(A[i]) + bf16tof32(B[i]) - D[i] = f32tobf16(T.exp(C[i])) - - return After - - -def get_after_storage_legalize(): - @tvm.script.ir_module - class After: - @T.prim_func - def main(Aptr: T.handle("uint16"), Bptr: T.handle("uint16"), Dptr: T.handle("uint16")): - T.func_attr({"global_symbol": "main"}) - A = T.decl_buffer((100,), "uint16", data=Aptr) - B = T.decl_buffer((100,), "uint16", data=Bptr) - D = T.decl_buffer((100,), "uint16", data=Dptr) - C = T.decl_buffer((100,), "float32") - for i in T.grid(100): - C[i] = u16tof32(A[i]) + u16tof32(B[i]) - D[i] = f32tou16(T.exp(C[i])) - - return After - - -def test_bf16_compute_legalize(): - before = get_before() - expected = get_after_compute_legalize() - # run the transform twice to ensure we can afford to deal - # with this repeative optimizations - after = tvm.tir.transform.BF16ComputeLegalize()(before) - after = tvm.tir.transform.BF16ComputeLegalize()(after) - - tvm.ir.assert_structural_equal(after, expected) - - -def test_bf16_storage_legalize(): - before = get_after_compute_legalize() - after = tvm.tir.transform.BF16StorageLegalize()(before) - expected = get_after_storage_legalize() - tvm.ir.assert_structural_equal(after, expected) - - -def test_bf16_storage_scope(): +def test_bf16_storage_compute_scope_will_legalize(): def get_before(): @tvm.script.ir_module class Before: @@ -175,13 +102,82 @@ def main( return After - before = get_before() + target = Target("nvidia/geforce-rtx-2080-ti") + before = BindTarget(target)(get_before()) + after_compute = tvm.tir.transform.BF16ComputeLegalize()(before) + after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute) + tvm.ir.assert_structural_equal(after_compute, BindTarget(target)(after_compute_legalize())) + tvm.ir.assert_structural_equal(after_storage, BindTarget(target)(after_storage_legalize())) + +def test_bf16_storage_compute_scope_wont_legalize(): + def get_before(): + @tvm.script.ir_module + class Before: + @T.prim_func + def main( + Aptr: T.handle("bfloat16", storage_scope="shared"), + Bptr: T.handle("bfloat16", storage_scope="local"), + Dptr: T.handle("bfloat16"), + ): + T.func_attr({"global_symbol": "main"}) + A = T.decl_buffer((100,), "bfloat16", data=Aptr) + B = T.decl_buffer((100,), "bfloat16", data=Bptr) + D = T.decl_buffer((100,), "bfloat16", data=Dptr) + C = T.decl_buffer((100,), "bfloat16") + for i in T.grid(100): + C[i] = A[i] + B[i] + D[i] = T.exp(C[i]) + + return Before + + def after_compute_legalize(): + @tvm.script.ir_module + class After: + @T.prim_func + def main( + Aptr: T.handle("bfloat16", storage_scope="shared"), + Bptr: T.handle("bfloat16", storage_scope="local"), + Dptr: T.handle("bfloat16"), + ): + T.func_attr({"global_symbol": "main"}) + A = T.decl_buffer((100,), "bfloat16", data=Aptr) + B = T.decl_buffer((100,), "bfloat16", data=Bptr) + D = T.decl_buffer((100,), "bfloat16", data=Dptr) + C = T.decl_buffer((100,), "bfloat16") + for i in T.grid(100): + C[i] = A[i] + B[i] + D[i] = T.exp(C[i]) + + return After + + def after_storage_legalize(): + @tvm.script.ir_module + class After: + @T.prim_func + def main( + Aptr: T.handle("bfloat16", storage_scope="shared"), + Bptr: T.handle("bfloat16", storage_scope="local"), + Dptr: T.handle("bfloat16"), + ): + T.func_attr({"global_symbol": "main"}) + A = T.decl_buffer((100,), "bfloat16", data=Aptr) + B = T.decl_buffer((100,), "bfloat16", data=Bptr) + D = T.decl_buffer((100,), "bfloat16", data=Dptr) + C = T.decl_buffer((100,), "bfloat16") + for i in T.grid(100): + C[i] = A[i] + B[i] + D[i] = T.exp(C[i]) + + return After + + target = Target("nvidia/geforce-rtx-3090-ti") + before = BindTarget(target)(get_before()) after_compute = tvm.tir.transform.BF16ComputeLegalize()(before) after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute) - tvm.ir.assert_structural_equal(after_compute, after_compute_legalize()) - tvm.ir.assert_structural_equal(after_storage, after_storage_legalize()) + tvm.ir.assert_structural_equal(after_compute, BindTarget(target)(after_compute_legalize())) + tvm.ir.assert_structural_equal(after_storage, BindTarget(target)(after_storage_legalize())) if __name__ == "__main__": - test_bf16_storage_legalize() - test_bf16_storage_scope() + test_bf16_storage_compute_scope_will_legalize() + test_bf16_storage_compute_scope_wont_legalize() \ No newline at end of file From 735c5ffde2128e66a89cff3da757981284129967 Mon Sep 17 00:00:00 2001 From: Joshua Hong Date: Sat, 22 Feb 2025 22:16:01 -0500 Subject: [PATCH 03/14] Add tests and fix reduce --- include/tvm/script/ir_builder/tir/ir.h | 2 + python/tvm/_ffi/runtime_ctypes.py | 2 + python/tvm/script/ir_builder/tir/ir.py | 3 +- src/script/ir_builder/tir/ir.cc | 2 + .../transforms/unsupported_dtype_legalize.cc | 39 +++- .../test_tir_transform_bf16_legalize.py | 193 +++++++++++++++++- 6 files changed, 232 insertions(+), 9 deletions(-) diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 380c2fcce25d..9058e35352dc 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -473,6 +473,7 @@ inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(), TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##32, FDType(32)); \ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##64, FDType(64)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(BFloat, DataType::BFloat); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Float, DataType::Float); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(UInt, DataType::UInt); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Int, DataType::Int); @@ -490,6 +491,7 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Int, DataType::Int); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##32, FDType, 32); \ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##64, FDType, 64); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(BFloat, DataType::BFloat); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Float, DataType::Float); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(UInt, DataType::UInt); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int); diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index f79df1644e28..c6262d4b727a 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -233,6 +233,8 @@ def itemsize(self): DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "e4m3_float8" DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] = "e5m2_float8" + DataType.STR2DTYPE["bfloat16"] = {"type_code": DataTypeCode.BFLOAT, "bits": 16, "lanes": 1} + RPC_SESS_MASK = 128 diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index da0e2954e83b..5596611029d1 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1457,6 +1457,7 @@ def func( e5m2_float8x32 = func_gen(("E5M2Float8x32")) e5m2_float8x64 = func_gen(("E5M2Float8x64")) +bfloat16 = func_gen(("BFloat16")) # pylint: enable=invalid-name @@ -1953,7 +1954,6 @@ def wrapped(*args, **kwargs): # pylint: enable=invalid-name - __all__ = [ "int8", "int16", @@ -2033,6 +2033,7 @@ def wrapped(*args, **kwargs): "float16x64", "float32x64", "float64x64", + "bfloat16", "buffer", "buffer_decl", "prim_func", diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 17353561ee54..441126bd8ad1 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -752,8 +752,10 @@ TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Float", Float); TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt); TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.BFloat16").set_body_typed(BFloat16); TVM_REGISTER_GLOBAL("script.ir_builder.tir.E4M3Float8").set_body_typed(E4M3Float8); TVM_REGISTER_GLOBAL("script.ir_builder.tir.E5M2Float8").set_body_typed(E5M2Float8); +TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.BFloat16", BFloat16); TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E4M3Float8", E4M3Float8); TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E5M2Float8", E5M2Float8); diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index 9d67a962c29d..11a40c577c59 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -351,8 +351,38 @@ class ComputeLegalizer : public StmtExprMutator { return AttrStmt(it->second, op->attr_key, op->value, op->body); } } else if (auto reducer = op->node.as()) { - auto identity_elements = reducer->identity_element.Map([this](PrimExpr expr) { return this->VisitExpr(expr); }); - return AttrStmt(CommReducer(reducer->lhs, reducer->rhs, reducer->result, identity_elements, reducer->span), op->attr_key, op->value, op->body); + auto legalized_identity_elements = reducer->identity_element.Map([this](PrimExpr expr) { return this->VisitExpr(expr); }); + + // Remap input variables + for (size_t i = 0; i < legalized_identity_elements.size(); i++) { + Var lhs_var = reducer->lhs[i]; + if (lhs_var.dtype() != legalized_identity_elements[i].dtype()) { + var_remap_[lhs_var] = lhs_var.copy_with_dtype(legalized_identity_elements[i].dtype()); + } + Var rhs_var = reducer->rhs[i]; + if (rhs_var.dtype() != legalized_identity_elements[i].dtype()) { + var_remap_[rhs_var] = rhs_var.copy_with_dtype(legalized_identity_elements[i].dtype()); + } + } + + auto legalized_results = reducer->result.Map([this](PrimExpr expr) { return this->VisitExpr(expr); }); + + auto legalized_lhs = reducer->lhs.Map([this](Var var) { + auto it = var_remap_.find(var); + if (it != var_remap_.end()) { + return it->second; + } + return var; + }); + + auto legalized_rhs = reducer->rhs.Map([this](Var var) { + auto it = var_remap_.find(var); + if (it != var_remap_.end()) { + return it->second; + } + return var; + }); + return AttrStmt(CommReducer(legalized_lhs, legalized_rhs, legalized_results, legalized_identity_elements, reducer->span), op->attr_key, op->value, op->body); } return ret; } @@ -589,9 +619,6 @@ class StorageLegalizer : public StmtExprMutator { if (it != var_remap_.end()) { return AttrStmt(it->second, op->attr_key, op->value, op->body); } - } else if (auto reducer = op->node.as()) { - auto identity_elements = reducer->identity_element.Map([this](PrimExpr expr) { return this->VisitExpr(expr); }); - return AttrStmt(CommReducer(reducer->lhs, reducer->rhs, reducer->result, identity_elements, reducer->span), op->attr_key, op->value, op->body); } return ret; } @@ -724,7 +751,6 @@ Pass BF16ComputeLegalize() { if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_bf16")) { return f; } - // TODO(tvm-team): skip if the target supports bf16 return BF16ComputeLegalizer().Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.BF16ComputeLegalize", {}); @@ -738,7 +764,6 @@ Pass BF16StorageLegalize() { if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_bf16")) { return f; } - // TODO(tvm-team): skip if the target supports bf16 return BF16StorageLegalizer().Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.BF16StorageLegalize", {}); diff --git a/tests/python/tir-transform/test_tir_transform_bf16_legalize.py b/tests/python/tir-transform/test_tir_transform_bf16_legalize.py index 31f03672b49e..1d6232fec745 100644 --- a/tests/python/tir-transform/test_tir_transform_bf16_legalize.py +++ b/tests/python/tir-transform/test_tir_transform_bf16_legalize.py @@ -177,7 +177,198 @@ def main( tvm.ir.assert_structural_equal(after_compute, BindTarget(target)(after_compute_legalize())) tvm.ir.assert_structural_equal(after_storage, BindTarget(target)(after_storage_legalize())) +def test_bf16_reduce_will_legalize(): + def get_before(): + @tvm.script.ir_module + class Before: + @T.prim_func(private=True) + def main( + Aptr: T.handle("bfloat16", storage_scope="shared"), + ): + A_flat = T.decl_buffer(4096, "bfloat16", data=Aptr) + + for i in range(128): + threadIdx_x = T.launch_thread("threadIdx.x", 32) + + reduce = T.decl_buffer(1, dtype="bfloat16", scope="local") + + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.bfloat16(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + T.tvm_thread_allreduce( + T.uint32(1), + A_flat[0], + T.bool(True), + reduce[0], + threadIdx_x, + ) + + return Before + + def after_compute_legalize(): + @tvm.script.ir_module + class After: + @T.prim_func(private=True) + def main( + Aptr: T.handle("bfloat16", storage_scope="shared"), + ): + A_flat_1 = T.decl_buffer(4096, "bfloat16", data=Aptr) + + for i in range(128): + threadIdx_x = T.launch_thread("threadIdx.x", 32) + + reduce = T.decl_buffer(1, dtype="float32", scope="local") + + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + T.tvm_thread_allreduce( + T.uint32(1), + T.reinterpret("float32", T.shift_left(T.Cast("uint32", T.reinterpret("uint16", A_flat_1[0])), T.uint32(16))), + T.bool(True), + reduce[0], + threadIdx_x, + ) + + return After + + def after_storage_legalize(): + @tvm.script.ir_module + class After: + @T.prim_func(private=True) + def main( + Aptr: T.handle("uint16", storage_scope="shared"), + ): + A_flat_1 = T.decl_buffer(4096, "uint16", data=Aptr) + + for i in range(128): + threadIdx_x = T.launch_thread("threadIdx.x", 32) + + reduce = T.decl_buffer(1, dtype="float32", scope="local") + + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + T.tvm_thread_allreduce( + T.uint32(1), + T.reinterpret("float32", T.shift_left(T.Cast("uint32", T.reinterpret("uint16", A_flat_1[0])), T.uint32(16))), + T.bool(True), + reduce[0], + threadIdx_x, + ) + + return After + + target = Target("nvidia/geforce-rtx-2080-ti") + before = BindTarget(target)(get_before()) + after_compute = tvm.tir.transform.BF16ComputeLegalize()(before) + after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute) + tvm.ir.assert_structural_equal(after_compute, BindTarget(target)(after_compute_legalize())) + tvm.ir.assert_structural_equal(after_storage, BindTarget(target)(after_storage_legalize())) + +def test_bf16_reduce_wont_legalize(): + def get_before(): + @tvm.script.ir_module + class Before: + @T.prim_func(private=True) + def main( + Aptr: T.handle("bfloat16", storage_scope="shared"), + ): + A_flat = T.decl_buffer(4096, "bfloat16", data=Aptr) + + for i in range(128): + threadIdx_x = T.launch_thread("threadIdx.x", 32) + + reduce = T.decl_buffer(1, dtype="bfloat16", scope="local") + + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.bfloat16(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + T.tvm_thread_allreduce( + T.uint32(1), + A_flat[0], + T.bool(True), + reduce[0], + threadIdx_x, + ) + + return Before + + def after_compute_legalize(): + @tvm.script.ir_module + class After: + @T.prim_func(private=True) + def main( + Aptr: T.handle("bfloat16", storage_scope="shared"), + ): + A_flat = T.decl_buffer(4096, "bfloat16", data=Aptr) + + for i in range(128): + threadIdx_x = T.launch_thread("threadIdx.x", 32) + + reduce = T.decl_buffer(1, dtype="bfloat16", scope="local") + + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.bfloat16(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + T.tvm_thread_allreduce( + T.uint32(1), + A_flat[0], + T.bool(True), + reduce[0], + threadIdx_x, + ) + + return After + + def after_storage_legalize(): + @tvm.script.ir_module + class After: + @T.prim_func(private=True) + def main( + Aptr: T.handle("bfloat16", storage_scope="shared"), + ): + A_flat = T.decl_buffer(4096, "bfloat16", data=Aptr) + + for i in range(128): + threadIdx_x = T.launch_thread("threadIdx.x", 32) + + reduce = T.decl_buffer(1, dtype="bfloat16", scope="local") + + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.bfloat16(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + T.tvm_thread_allreduce( + T.uint32(1), + A_flat[0], + T.bool(True), + reduce[0], + threadIdx_x, + ) + + return After + + target = Target("nvidia/geforce-rtx-3090-ti") + before = BindTarget(target)(get_before()) + after_compute = tvm.tir.transform.BF16ComputeLegalize()(before) + after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute) + tvm.ir.assert_structural_equal(after_compute, BindTarget(target)(after_compute_legalize())) + tvm.ir.assert_structural_equal(after_storage, BindTarget(target)(after_storage_legalize())) if __name__ == "__main__": test_bf16_storage_compute_scope_will_legalize() - test_bf16_storage_compute_scope_wont_legalize() \ No newline at end of file + test_bf16_storage_compute_scope_wont_legalize() + test_bf16_reduce_will_legalize() + test_bf16_reduce_wont_legalize() \ No newline at end of file From e2ee869ae027929d3a7fd66418a9b0b0f5de3677 Mon Sep 17 00:00:00 2001 From: Joshua Hong Date: Sun, 23 Feb 2025 11:28:04 -0500 Subject: [PATCH 04/14] Codegen fix --- src/target/source/codegen_cuda.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 040051825119..05d9d5e17927 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -1542,7 +1542,7 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val os << '('; } if (i % 2 == 0) { - os << "__pack_bfloat162(" << value; + os << "__pack_nv_bfloat162(" << value; } else { os << "," << value << ")"; if (i != t.lanes() - 1) { From cc34c94dc71da3dae9bd1af3817154954ca7661b Mon Sep 17 00:00:00 2001 From: Joshua Hong Date: Fri, 21 Feb 2025 17:00:45 -0500 Subject: [PATCH 05/14] Add bf support --- include/tvm/runtime/data_type.h | 2 ++ python/tvm/relax/expr.py | 5 +++++ src/relax/op/nn/nn.cc | 2 +- src/relax/op/op_common.h | 2 +- src/tir/transforms/unsupported_dtype_legalize.cc | 14 ++++++++++++++ 5 files changed, 23 insertions(+), 2 deletions(-) diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index c49fde1746bc..bd4a189b3f3e 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -120,6 +120,8 @@ class DataType { bool is_bool() const { return code() == DataType::kUInt && bits() == 1; } /*! \return whether type is a float type. */ bool is_float() const { return code() == DataType::kFloat; } + /*! \return whether type is a bfloat type. */ + bool is_bfloat() const { return code() == DataType::kBFloat; } /*! \return whether type is a float8 type. */ bool is_float8() const { return (code() == DataType::kFloat || code() == DataType::kE4M3Float || diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 58845ce48986..050b91b0f0ef 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -21,6 +21,11 @@ import numpy as _np # type: ignore +try: + import ml_dtypes +except ImportError: + ml_dtypes = None + import tvm import tvm._ffi import tvm.ir diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 526b816d0945..136a2ae17b99 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -191,7 +191,7 @@ bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, axes_non_neg = NormalizeAxes(call, ctx, data_sinfo->ndim, axes); } int n_axis = axes.size(); - if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) { + if (!data_sinfo->IsUnknownDtype() && (!data_sinfo->dtype.is_float() && !data_sinfo->dtype.is_bfloat())) { ctx->ReportFatal( Diagnostic::Error(call) << op << " requires the input data to have float dtype. However, the given data dtype is " diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index 6e2ef6bd2bef..bd58b64b9d3c 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -199,7 +199,7 @@ template inline StructInfo InferStructInfoUnary(const Call& call, const BlockBuilder& ctx, FType f_compute_out_dtype) { TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); - if (require_float_dtype && !input_sinfo->IsUnknownDtype() && !input_sinfo->dtype.is_float()) { + if (require_float_dtype && !input_sinfo->IsUnknownDtype() && (!input_sinfo->dtype.is_float() && !input_sinfo->dtype.is_bfloat())) { ctx->ReportFatal( Diagnostic::Error(call) << call->op diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index c75ecf77e708..9d67a962c29d 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -350,6 +350,9 @@ class ComputeLegalizer : public StmtExprMutator { if (it != var_remap_.end()) { return AttrStmt(it->second, op->attr_key, op->value, op->body); } + } else if (auto reducer = op->node.as()) { + auto identity_elements = reducer->identity_element.Map([this](PrimExpr expr) { return this->VisitExpr(expr); }); + return AttrStmt(CommReducer(reducer->lhs, reducer->rhs, reducer->result, identity_elements, reducer->span), op->attr_key, op->value, op->body); } return ret; } @@ -586,6 +589,9 @@ class StorageLegalizer : public StmtExprMutator { if (it != var_remap_.end()) { return AttrStmt(it->second, op->attr_key, op->value, op->body); } + } else if (auto reducer = op->node.as()) { + auto identity_elements = reducer->identity_element.Map([this](PrimExpr expr) { return this->VisitExpr(expr); }); + return AttrStmt(CommReducer(reducer->lhs, reducer->rhs, reducer->result, identity_elements, reducer->span), op->attr_key, op->value, op->body); } return ret; } @@ -714,6 +720,10 @@ bool CheckDataTypeSupport(const Target& target, const std::string& support_func_ Pass BF16ComputeLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto target = f->GetAttr(tvm::attr::kTarget).value(); + if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_bf16")) { + return f; + } // TODO(tvm-team): skip if the target supports bf16 return BF16ComputeLegalizer().Legalize(f); }; @@ -724,6 +734,10 @@ TVM_REGISTER_GLOBAL("tir.transform.BF16ComputeLegalize").set_body_typed(BF16Comp Pass BF16StorageLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto target = f->GetAttr(tvm::attr::kTarget).value(); + if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_bf16")) { + return f; + } // TODO(tvm-team): skip if the target supports bf16 return BF16StorageLegalizer().Legalize(f); }; From e0007e1d65c5c0541b00429cb26ccd2c3619c14b Mon Sep 17 00:00:00 2001 From: Joshua Hong Date: Sat, 22 Feb 2025 01:28:39 -0500 Subject: [PATCH 06/14] Add negative/positive case for base bf legalization --- .../test_tir_transform_bf16_legalize.py | 158 +++++++++--------- 1 file changed, 77 insertions(+), 81 deletions(-) diff --git a/tests/python/tir-transform/test_tir_transform_bf16_legalize.py b/tests/python/tir-transform/test_tir_transform_bf16_legalize.py index e2752e8bbb3c..31f03672b49e 100644 --- a/tests/python/tir-transform/test_tir_transform_bf16_legalize.py +++ b/tests/python/tir-transform/test_tir_transform_bf16_legalize.py @@ -16,27 +16,9 @@ # under the License. import tvm import tvm.script +from tvm.target import Target from tvm.script import tir as T - - -def get_before(): - @tvm.script.ir_module - class Before: - @T.prim_func - def main( - Aptr: T.handle("bfloat16"), Bptr: T.handle("bfloat16"), Dptr: T.handle("bfloat16") - ): - T.func_attr({"global_symbol": "main"}) - A = T.decl_buffer((100,), "bfloat16", data=Aptr) - B = T.decl_buffer((100,), "bfloat16", data=Bptr) - D = T.decl_buffer((100,), "bfloat16", data=Dptr) - C = T.decl_buffer((100,), "bfloat16") - for i in T.grid(100): - C[i] = A[i] + B[i] - D[i] = T.exp(C[i]) - - return Before - +from tvm.tir.transform.transform import BindTarget def u16tof32(v): uint32_v = v.astype("uint32") @@ -59,62 +41,7 @@ def f32tou16(v): def f32tobf16(v): return T.reinterpret("bfloat16", f32tou16(v)) - -def get_after_compute_legalize(): - @tvm.script.ir_module - class After: - @T.prim_func - def main( - Aptr: T.handle("bfloat16"), Bptr: T.handle("bfloat16"), Dptr: T.handle("bfloat16") - ): - T.func_attr({"global_symbol": "main"}) - A = T.decl_buffer((100,), "bfloat16", data=Aptr) - B = T.decl_buffer((100,), "bfloat16", data=Bptr) - D = T.decl_buffer((100,), "bfloat16", data=Dptr) - C = T.decl_buffer((100,), "float32") - for i in T.grid(100): - C[i] = bf16tof32(A[i]) + bf16tof32(B[i]) - D[i] = f32tobf16(T.exp(C[i])) - - return After - - -def get_after_storage_legalize(): - @tvm.script.ir_module - class After: - @T.prim_func - def main(Aptr: T.handle("uint16"), Bptr: T.handle("uint16"), Dptr: T.handle("uint16")): - T.func_attr({"global_symbol": "main"}) - A = T.decl_buffer((100,), "uint16", data=Aptr) - B = T.decl_buffer((100,), "uint16", data=Bptr) - D = T.decl_buffer((100,), "uint16", data=Dptr) - C = T.decl_buffer((100,), "float32") - for i in T.grid(100): - C[i] = u16tof32(A[i]) + u16tof32(B[i]) - D[i] = f32tou16(T.exp(C[i])) - - return After - - -def test_bf16_compute_legalize(): - before = get_before() - expected = get_after_compute_legalize() - # run the transform twice to ensure we can afford to deal - # with this repeative optimizations - after = tvm.tir.transform.BF16ComputeLegalize()(before) - after = tvm.tir.transform.BF16ComputeLegalize()(after) - - tvm.ir.assert_structural_equal(after, expected) - - -def test_bf16_storage_legalize(): - before = get_after_compute_legalize() - after = tvm.tir.transform.BF16StorageLegalize()(before) - expected = get_after_storage_legalize() - tvm.ir.assert_structural_equal(after, expected) - - -def test_bf16_storage_scope(): +def test_bf16_storage_compute_scope_will_legalize(): def get_before(): @tvm.script.ir_module class Before: @@ -175,13 +102,82 @@ def main( return After - before = get_before() + target = Target("nvidia/geforce-rtx-2080-ti") + before = BindTarget(target)(get_before()) + after_compute = tvm.tir.transform.BF16ComputeLegalize()(before) + after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute) + tvm.ir.assert_structural_equal(after_compute, BindTarget(target)(after_compute_legalize())) + tvm.ir.assert_structural_equal(after_storage, BindTarget(target)(after_storage_legalize())) + +def test_bf16_storage_compute_scope_wont_legalize(): + def get_before(): + @tvm.script.ir_module + class Before: + @T.prim_func + def main( + Aptr: T.handle("bfloat16", storage_scope="shared"), + Bptr: T.handle("bfloat16", storage_scope="local"), + Dptr: T.handle("bfloat16"), + ): + T.func_attr({"global_symbol": "main"}) + A = T.decl_buffer((100,), "bfloat16", data=Aptr) + B = T.decl_buffer((100,), "bfloat16", data=Bptr) + D = T.decl_buffer((100,), "bfloat16", data=Dptr) + C = T.decl_buffer((100,), "bfloat16") + for i in T.grid(100): + C[i] = A[i] + B[i] + D[i] = T.exp(C[i]) + + return Before + + def after_compute_legalize(): + @tvm.script.ir_module + class After: + @T.prim_func + def main( + Aptr: T.handle("bfloat16", storage_scope="shared"), + Bptr: T.handle("bfloat16", storage_scope="local"), + Dptr: T.handle("bfloat16"), + ): + T.func_attr({"global_symbol": "main"}) + A = T.decl_buffer((100,), "bfloat16", data=Aptr) + B = T.decl_buffer((100,), "bfloat16", data=Bptr) + D = T.decl_buffer((100,), "bfloat16", data=Dptr) + C = T.decl_buffer((100,), "bfloat16") + for i in T.grid(100): + C[i] = A[i] + B[i] + D[i] = T.exp(C[i]) + + return After + + def after_storage_legalize(): + @tvm.script.ir_module + class After: + @T.prim_func + def main( + Aptr: T.handle("bfloat16", storage_scope="shared"), + Bptr: T.handle("bfloat16", storage_scope="local"), + Dptr: T.handle("bfloat16"), + ): + T.func_attr({"global_symbol": "main"}) + A = T.decl_buffer((100,), "bfloat16", data=Aptr) + B = T.decl_buffer((100,), "bfloat16", data=Bptr) + D = T.decl_buffer((100,), "bfloat16", data=Dptr) + C = T.decl_buffer((100,), "bfloat16") + for i in T.grid(100): + C[i] = A[i] + B[i] + D[i] = T.exp(C[i]) + + return After + + target = Target("nvidia/geforce-rtx-3090-ti") + before = BindTarget(target)(get_before()) after_compute = tvm.tir.transform.BF16ComputeLegalize()(before) after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute) - tvm.ir.assert_structural_equal(after_compute, after_compute_legalize()) - tvm.ir.assert_structural_equal(after_storage, after_storage_legalize()) + tvm.ir.assert_structural_equal(after_compute, BindTarget(target)(after_compute_legalize())) + tvm.ir.assert_structural_equal(after_storage, BindTarget(target)(after_storage_legalize())) if __name__ == "__main__": - test_bf16_storage_legalize() - test_bf16_storage_scope() + test_bf16_storage_compute_scope_will_legalize() + test_bf16_storage_compute_scope_wont_legalize() \ No newline at end of file From 8e46f8e927e62911f453aa86cd4583ac8b7fe360 Mon Sep 17 00:00:00 2001 From: Joshua Hong Date: Sat, 22 Feb 2025 22:16:01 -0500 Subject: [PATCH 07/14] Add tests and fix reduce --- include/tvm/script/ir_builder/tir/ir.h | 2 + python/tvm/_ffi/runtime_ctypes.py | 2 + python/tvm/script/ir_builder/tir/ir.py | 3 +- src/script/ir_builder/tir/ir.cc | 2 + .../transforms/unsupported_dtype_legalize.cc | 39 +++- .../test_tir_transform_bf16_legalize.py | 193 +++++++++++++++++- 6 files changed, 232 insertions(+), 9 deletions(-) diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 380c2fcce25d..9058e35352dc 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -473,6 +473,7 @@ inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(), TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##32, FDType(32)); \ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##64, FDType(64)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(BFloat, DataType::BFloat); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Float, DataType::Float); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(UInt, DataType::UInt); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Int, DataType::Int); @@ -490,6 +491,7 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Int, DataType::Int); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##32, FDType, 32); \ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##64, FDType, 64); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(BFloat, DataType::BFloat); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Float, DataType::Float); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(UInt, DataType::UInt); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int); diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index f79df1644e28..c6262d4b727a 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -233,6 +233,8 @@ def itemsize(self): DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "e4m3_float8" DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] = "e5m2_float8" + DataType.STR2DTYPE["bfloat16"] = {"type_code": DataTypeCode.BFLOAT, "bits": 16, "lanes": 1} + RPC_SESS_MASK = 128 diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index da0e2954e83b..5596611029d1 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1457,6 +1457,7 @@ def func( e5m2_float8x32 = func_gen(("E5M2Float8x32")) e5m2_float8x64 = func_gen(("E5M2Float8x64")) +bfloat16 = func_gen(("BFloat16")) # pylint: enable=invalid-name @@ -1953,7 +1954,6 @@ def wrapped(*args, **kwargs): # pylint: enable=invalid-name - __all__ = [ "int8", "int16", @@ -2033,6 +2033,7 @@ def wrapped(*args, **kwargs): "float16x64", "float32x64", "float64x64", + "bfloat16", "buffer", "buffer_decl", "prim_func", diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 17353561ee54..441126bd8ad1 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -752,8 +752,10 @@ TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Float", Float); TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt); TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.BFloat16").set_body_typed(BFloat16); TVM_REGISTER_GLOBAL("script.ir_builder.tir.E4M3Float8").set_body_typed(E4M3Float8); TVM_REGISTER_GLOBAL("script.ir_builder.tir.E5M2Float8").set_body_typed(E5M2Float8); +TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.BFloat16", BFloat16); TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E4M3Float8", E4M3Float8); TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.E5M2Float8", E5M2Float8); diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index 9d67a962c29d..11a40c577c59 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -351,8 +351,38 @@ class ComputeLegalizer : public StmtExprMutator { return AttrStmt(it->second, op->attr_key, op->value, op->body); } } else if (auto reducer = op->node.as()) { - auto identity_elements = reducer->identity_element.Map([this](PrimExpr expr) { return this->VisitExpr(expr); }); - return AttrStmt(CommReducer(reducer->lhs, reducer->rhs, reducer->result, identity_elements, reducer->span), op->attr_key, op->value, op->body); + auto legalized_identity_elements = reducer->identity_element.Map([this](PrimExpr expr) { return this->VisitExpr(expr); }); + + // Remap input variables + for (size_t i = 0; i < legalized_identity_elements.size(); i++) { + Var lhs_var = reducer->lhs[i]; + if (lhs_var.dtype() != legalized_identity_elements[i].dtype()) { + var_remap_[lhs_var] = lhs_var.copy_with_dtype(legalized_identity_elements[i].dtype()); + } + Var rhs_var = reducer->rhs[i]; + if (rhs_var.dtype() != legalized_identity_elements[i].dtype()) { + var_remap_[rhs_var] = rhs_var.copy_with_dtype(legalized_identity_elements[i].dtype()); + } + } + + auto legalized_results = reducer->result.Map([this](PrimExpr expr) { return this->VisitExpr(expr); }); + + auto legalized_lhs = reducer->lhs.Map([this](Var var) { + auto it = var_remap_.find(var); + if (it != var_remap_.end()) { + return it->second; + } + return var; + }); + + auto legalized_rhs = reducer->rhs.Map([this](Var var) { + auto it = var_remap_.find(var); + if (it != var_remap_.end()) { + return it->second; + } + return var; + }); + return AttrStmt(CommReducer(legalized_lhs, legalized_rhs, legalized_results, legalized_identity_elements, reducer->span), op->attr_key, op->value, op->body); } return ret; } @@ -589,9 +619,6 @@ class StorageLegalizer : public StmtExprMutator { if (it != var_remap_.end()) { return AttrStmt(it->second, op->attr_key, op->value, op->body); } - } else if (auto reducer = op->node.as()) { - auto identity_elements = reducer->identity_element.Map([this](PrimExpr expr) { return this->VisitExpr(expr); }); - return AttrStmt(CommReducer(reducer->lhs, reducer->rhs, reducer->result, identity_elements, reducer->span), op->attr_key, op->value, op->body); } return ret; } @@ -724,7 +751,6 @@ Pass BF16ComputeLegalize() { if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_bf16")) { return f; } - // TODO(tvm-team): skip if the target supports bf16 return BF16ComputeLegalizer().Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.BF16ComputeLegalize", {}); @@ -738,7 +764,6 @@ Pass BF16StorageLegalize() { if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_bf16")) { return f; } - // TODO(tvm-team): skip if the target supports bf16 return BF16StorageLegalizer().Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.BF16StorageLegalize", {}); diff --git a/tests/python/tir-transform/test_tir_transform_bf16_legalize.py b/tests/python/tir-transform/test_tir_transform_bf16_legalize.py index 31f03672b49e..1d6232fec745 100644 --- a/tests/python/tir-transform/test_tir_transform_bf16_legalize.py +++ b/tests/python/tir-transform/test_tir_transform_bf16_legalize.py @@ -177,7 +177,198 @@ def main( tvm.ir.assert_structural_equal(after_compute, BindTarget(target)(after_compute_legalize())) tvm.ir.assert_structural_equal(after_storage, BindTarget(target)(after_storage_legalize())) +def test_bf16_reduce_will_legalize(): + def get_before(): + @tvm.script.ir_module + class Before: + @T.prim_func(private=True) + def main( + Aptr: T.handle("bfloat16", storage_scope="shared"), + ): + A_flat = T.decl_buffer(4096, "bfloat16", data=Aptr) + + for i in range(128): + threadIdx_x = T.launch_thread("threadIdx.x", 32) + + reduce = T.decl_buffer(1, dtype="bfloat16", scope="local") + + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.bfloat16(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + T.tvm_thread_allreduce( + T.uint32(1), + A_flat[0], + T.bool(True), + reduce[0], + threadIdx_x, + ) + + return Before + + def after_compute_legalize(): + @tvm.script.ir_module + class After: + @T.prim_func(private=True) + def main( + Aptr: T.handle("bfloat16", storage_scope="shared"), + ): + A_flat_1 = T.decl_buffer(4096, "bfloat16", data=Aptr) + + for i in range(128): + threadIdx_x = T.launch_thread("threadIdx.x", 32) + + reduce = T.decl_buffer(1, dtype="float32", scope="local") + + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + T.tvm_thread_allreduce( + T.uint32(1), + T.reinterpret("float32", T.shift_left(T.Cast("uint32", T.reinterpret("uint16", A_flat_1[0])), T.uint32(16))), + T.bool(True), + reduce[0], + threadIdx_x, + ) + + return After + + def after_storage_legalize(): + @tvm.script.ir_module + class After: + @T.prim_func(private=True) + def main( + Aptr: T.handle("uint16", storage_scope="shared"), + ): + A_flat_1 = T.decl_buffer(4096, "uint16", data=Aptr) + + for i in range(128): + threadIdx_x = T.launch_thread("threadIdx.x", 32) + + reduce = T.decl_buffer(1, dtype="float32", scope="local") + + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + T.tvm_thread_allreduce( + T.uint32(1), + T.reinterpret("float32", T.shift_left(T.Cast("uint32", T.reinterpret("uint16", A_flat_1[0])), T.uint32(16))), + T.bool(True), + reduce[0], + threadIdx_x, + ) + + return After + + target = Target("nvidia/geforce-rtx-2080-ti") + before = BindTarget(target)(get_before()) + after_compute = tvm.tir.transform.BF16ComputeLegalize()(before) + after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute) + tvm.ir.assert_structural_equal(after_compute, BindTarget(target)(after_compute_legalize())) + tvm.ir.assert_structural_equal(after_storage, BindTarget(target)(after_storage_legalize())) + +def test_bf16_reduce_wont_legalize(): + def get_before(): + @tvm.script.ir_module + class Before: + @T.prim_func(private=True) + def main( + Aptr: T.handle("bfloat16", storage_scope="shared"), + ): + A_flat = T.decl_buffer(4096, "bfloat16", data=Aptr) + + for i in range(128): + threadIdx_x = T.launch_thread("threadIdx.x", 32) + + reduce = T.decl_buffer(1, dtype="bfloat16", scope="local") + + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.bfloat16(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + T.tvm_thread_allreduce( + T.uint32(1), + A_flat[0], + T.bool(True), + reduce[0], + threadIdx_x, + ) + + return Before + + def after_compute_legalize(): + @tvm.script.ir_module + class After: + @T.prim_func(private=True) + def main( + Aptr: T.handle("bfloat16", storage_scope="shared"), + ): + A_flat = T.decl_buffer(4096, "bfloat16", data=Aptr) + + for i in range(128): + threadIdx_x = T.launch_thread("threadIdx.x", 32) + + reduce = T.decl_buffer(1, dtype="bfloat16", scope="local") + + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.bfloat16(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + T.tvm_thread_allreduce( + T.uint32(1), + A_flat[0], + T.bool(True), + reduce[0], + threadIdx_x, + ) + + return After + + def after_storage_legalize(): + @tvm.script.ir_module + class After: + @T.prim_func(private=True) + def main( + Aptr: T.handle("bfloat16", storage_scope="shared"), + ): + A_flat = T.decl_buffer(4096, "bfloat16", data=Aptr) + + for i in range(128): + threadIdx_x = T.launch_thread("threadIdx.x", 32) + + reduce = T.decl_buffer(1, dtype="bfloat16", scope="local") + + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.bfloat16(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + T.tvm_thread_allreduce( + T.uint32(1), + A_flat[0], + T.bool(True), + reduce[0], + threadIdx_x, + ) + + return After + + target = Target("nvidia/geforce-rtx-3090-ti") + before = BindTarget(target)(get_before()) + after_compute = tvm.tir.transform.BF16ComputeLegalize()(before) + after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute) + tvm.ir.assert_structural_equal(after_compute, BindTarget(target)(after_compute_legalize())) + tvm.ir.assert_structural_equal(after_storage, BindTarget(target)(after_storage_legalize())) if __name__ == "__main__": test_bf16_storage_compute_scope_will_legalize() - test_bf16_storage_compute_scope_wont_legalize() \ No newline at end of file + test_bf16_storage_compute_scope_wont_legalize() + test_bf16_reduce_will_legalize() + test_bf16_reduce_wont_legalize() \ No newline at end of file From 321070b19baa0c61d362bd333ca3c5d3fda72d3a Mon Sep 17 00:00:00 2001 From: Joshua Hong Date: Sun, 23 Feb 2025 11:28:04 -0500 Subject: [PATCH 08/14] Codegen fix --- src/target/source/codegen_cuda.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 040051825119..05d9d5e17927 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -1542,7 +1542,7 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val os << '('; } if (i % 2 == 0) { - os << "__pack_bfloat162(" << value; + os << "__pack_nv_bfloat162(" << value; } else { os << "," << value << ")"; if (i != t.lanes() - 1) { From 5491f0c2741e3a042fd0e5462eb636a7a81f21b7 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 26 Feb 2025 10:29:34 -0500 Subject: [PATCH 09/14] Fix lint --- src/relax/op/nn/nn.cc | 3 ++- src/relax/op/op_common.h | 3 ++- src/target/source/codegen_cuda.cc | 4 +-- .../transforms/unsupported_dtype_legalize.cc | 26 +++++++++++-------- .../test_tir_transform_bf16_legalize.py | 26 ++++++++++++++++--- 5 files changed, 43 insertions(+), 19 deletions(-) diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 136a2ae17b99..b4668d65d399 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -191,7 +191,8 @@ bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, axes_non_neg = NormalizeAxes(call, ctx, data_sinfo->ndim, axes); } int n_axis = axes.size(); - if (!data_sinfo->IsUnknownDtype() && (!data_sinfo->dtype.is_float() && !data_sinfo->dtype.is_bfloat())) { + if (!data_sinfo->IsUnknownDtype() && + (!data_sinfo->dtype.is_float() && !data_sinfo->dtype.is_bfloat())) { ctx->ReportFatal( Diagnostic::Error(call) << op << " requires the input data to have float dtype. However, the given data dtype is " diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index bd58b64b9d3c..eea6db22fdda 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -199,7 +199,8 @@ template inline StructInfo InferStructInfoUnary(const Call& call, const BlockBuilder& ctx, FType f_compute_out_dtype) { TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); - if (require_float_dtype && !input_sinfo->IsUnknownDtype() && (!input_sinfo->dtype.is_float() && !input_sinfo->dtype.is_bfloat())) { + if (require_float_dtype && !input_sinfo->IsUnknownDtype() && + (!input_sinfo->dtype.is_float() && !input_sinfo->dtype.is_bfloat())) { ctx->ReportFatal( Diagnostic::Error(call) << call->op diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 05d9d5e17927..7877fe814ef4 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -994,8 +994,8 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { var_idmap_[inverse_index_map->initial_indices[1].get()] = "local_id"; os << "for (int local_id = 0; local_id < 8; ++local_id) {\n"; - os << dst << "[" + this->PrintExpr(dst_ind) + "]" - << " = " << src << "[" << src_offset << " + local_id];\n"; + os << dst << "[" + this->PrintExpr(dst_ind) + "]" << " = " << src << "[" << src_offset + << " + local_id];\n"; os << "}\n"; } else if (op->op.same_as(builtin::mma_fill())) { diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index 11a40c577c59..386e00018ea9 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -351,21 +351,23 @@ class ComputeLegalizer : public StmtExprMutator { return AttrStmt(it->second, op->attr_key, op->value, op->body); } } else if (auto reducer = op->node.as()) { - auto legalized_identity_elements = reducer->identity_element.Map([this](PrimExpr expr) { return this->VisitExpr(expr); }); + auto legalized_identity_elements = + reducer->identity_element.Map([this](PrimExpr expr) { return this->VisitExpr(expr); }); // Remap input variables for (size_t i = 0; i < legalized_identity_elements.size(); i++) { - Var lhs_var = reducer->lhs[i]; - if (lhs_var.dtype() != legalized_identity_elements[i].dtype()) { - var_remap_[lhs_var] = lhs_var.copy_with_dtype(legalized_identity_elements[i].dtype()); - } - Var rhs_var = reducer->rhs[i]; - if (rhs_var.dtype() != legalized_identity_elements[i].dtype()) { - var_remap_[rhs_var] = rhs_var.copy_with_dtype(legalized_identity_elements[i].dtype()); - } + Var lhs_var = reducer->lhs[i]; + if (lhs_var.dtype() != legalized_identity_elements[i].dtype()) { + var_remap_[lhs_var] = lhs_var.copy_with_dtype(legalized_identity_elements[i].dtype()); + } + Var rhs_var = reducer->rhs[i]; + if (rhs_var.dtype() != legalized_identity_elements[i].dtype()) { + var_remap_[rhs_var] = rhs_var.copy_with_dtype(legalized_identity_elements[i].dtype()); + } } - auto legalized_results = reducer->result.Map([this](PrimExpr expr) { return this->VisitExpr(expr); }); + auto legalized_results = + reducer->result.Map([this](PrimExpr expr) { return this->VisitExpr(expr); }); auto legalized_lhs = reducer->lhs.Map([this](Var var) { auto it = var_remap_.find(var); @@ -382,7 +384,9 @@ class ComputeLegalizer : public StmtExprMutator { } return var; }); - return AttrStmt(CommReducer(legalized_lhs, legalized_rhs, legalized_results, legalized_identity_elements, reducer->span), op->attr_key, op->value, op->body); + return AttrStmt(CommReducer(legalized_lhs, legalized_rhs, legalized_results, + legalized_identity_elements, reducer->span), + op->attr_key, op->value, op->body); } return ret; } diff --git a/tests/python/tir-transform/test_tir_transform_bf16_legalize.py b/tests/python/tir-transform/test_tir_transform_bf16_legalize.py index 1d6232fec745..9ef72db35c0b 100644 --- a/tests/python/tir-transform/test_tir_transform_bf16_legalize.py +++ b/tests/python/tir-transform/test_tir_transform_bf16_legalize.py @@ -16,10 +16,11 @@ # under the License. import tvm import tvm.script -from tvm.target import Target from tvm.script import tir as T +from tvm.target import Target from tvm.tir.transform.transform import BindTarget + def u16tof32(v): uint32_v = v.astype("uint32") uint32_v = uint32_v << tvm.tir.const(16, "uint32") @@ -41,6 +42,7 @@ def f32tou16(v): def f32tobf16(v): return T.reinterpret("bfloat16", f32tou16(v)) + def test_bf16_storage_compute_scope_will_legalize(): def get_before(): @tvm.script.ir_module @@ -109,6 +111,7 @@ def main( tvm.ir.assert_structural_equal(after_compute, BindTarget(target)(after_compute_legalize())) tvm.ir.assert_structural_equal(after_storage, BindTarget(target)(after_storage_legalize())) + def test_bf16_storage_compute_scope_wont_legalize(): def get_before(): @tvm.script.ir_module @@ -177,6 +180,7 @@ def main( tvm.ir.assert_structural_equal(after_compute, BindTarget(target)(after_compute_legalize())) tvm.ir.assert_structural_equal(after_storage, BindTarget(target)(after_storage_legalize())) + def test_bf16_reduce_will_legalize(): def get_before(): @tvm.script.ir_module @@ -228,7 +232,13 @@ def main( ): T.tvm_thread_allreduce( T.uint32(1), - T.reinterpret("float32", T.shift_left(T.Cast("uint32", T.reinterpret("uint16", A_flat_1[0])), T.uint32(16))), + T.reinterpret( + "float32", + T.shift_left( + T.Cast("uint32", T.reinterpret("uint16", A_flat_1[0])), + T.uint32(16), + ), + ), T.bool(True), reduce[0], threadIdx_x, @@ -257,7 +267,13 @@ def main( ): T.tvm_thread_allreduce( T.uint32(1), - T.reinterpret("float32", T.shift_left(T.Cast("uint32", T.reinterpret("uint16", A_flat_1[0])), T.uint32(16))), + T.reinterpret( + "float32", + T.shift_left( + T.Cast("uint32", T.reinterpret("uint16", A_flat_1[0])), + T.uint32(16), + ), + ), T.bool(True), reduce[0], threadIdx_x, @@ -272,6 +288,7 @@ def main( tvm.ir.assert_structural_equal(after_compute, BindTarget(target)(after_compute_legalize())) tvm.ir.assert_structural_equal(after_storage, BindTarget(target)(after_storage_legalize())) + def test_bf16_reduce_wont_legalize(): def get_before(): @tvm.script.ir_module @@ -367,8 +384,9 @@ def main( tvm.ir.assert_structural_equal(after_compute, BindTarget(target)(after_compute_legalize())) tvm.ir.assert_structural_equal(after_storage, BindTarget(target)(after_storage_legalize())) + if __name__ == "__main__": test_bf16_storage_compute_scope_will_legalize() test_bf16_storage_compute_scope_wont_legalize() test_bf16_reduce_will_legalize() - test_bf16_reduce_wont_legalize() \ No newline at end of file + test_bf16_reduce_wont_legalize() From 9ddcac3721a9c46e755673cdfe0edb68da65e568 Mon Sep 17 00:00:00 2001 From: Joshua Hong Date: Fri, 7 Mar 2025 12:36:37 -0500 Subject: [PATCH 10/14] Add target cnull heck tand target binding o disallow async strided mem copy --- .../postproc/disallow_async_strided_mem_copy.cc | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc index d654e467f1e7..e006950a34e9 100644 --- a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc +++ b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc @@ -119,7 +119,11 @@ namespace meta_schedule { class DisallowAsyncStridedMemCopyNode : public PostprocNode { public: // Inherited from PostprocNode - void InitializeWithTuneContext(const TuneContext& context) final {} + void InitializeWithTuneContext(const TuneContext& context) final { + /* Null check */ + ICHECK(context->target) << "Context must contain a target"; + this->target = context->target.value(); + } // Inherited from PostprocNode bool Apply(const tir::Schedule& sch) final { IRModule mod = sch->mod(); @@ -130,6 +134,7 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode { IRModule lowered{nullptr}; try { auto pass_list = Array(); + pass_list.push_back(tir::transform::BindTarget(this->target)); pass_list.push_back(tir::transform::LowerInitBlock()); pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); @@ -168,6 +173,8 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode { static constexpr const char* _type_key = "meta_schedule.DisallowAsyncStridedMemCopy"; TVM_DECLARE_FINAL_OBJECT_INFO(DisallowAsyncStridedMemCopyNode, PostprocNode); + private: + tvm::Target target; }; Postproc Postproc::DisallowAsyncStridedMemCopy() { From 00c7b446859f8fd2e195b8ec8ef9bca180494fee Mon Sep 17 00:00:00 2001 From: Joshua Hong Date: Fri, 7 Mar 2025 16:13:24 -0500 Subject: [PATCH 11/14] Formatting fixes --- python/tvm/_ffi/runtime_ctypes.py | 3 +-- python/tvm/script/ir_builder/tir/ir.py | 1 - src/tir/transforms/unsupported_dtype_legalize.cc | 1 - tests/python/tir-transform/test_tir_transform_bf16_legalize.py | 2 ++ 4 files changed, 3 insertions(+), 4 deletions(-) diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index e72a6099192a..8a9c231617a2 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -118,6 +118,7 @@ class DataType(ctypes.Structure): "float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1}, "float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1}, "float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1}, + "bfloat16": {"type_code": DataTypeCode.BFLOAT, "bits": 16, "lanes": 1}, } def __init__(self, type_str): @@ -251,8 +252,6 @@ def itemsize(self): DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] = "float8_e5m2" DataType.NUMPY2STR[np.dtype(ml_dtypes.float4_e2m1fn)] = "float4_e2m1fn" - DataType.STR2DTYPE["bfloat16"] = {"type_code": DataTypeCode.BFLOAT, "bits": 16, "lanes": 1} - RPC_SESS_MASK = 128 diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 8d822ae3333e..3e835e8d9dbe 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1464,7 +1464,6 @@ def func( float4_e2m1fnx32 = func_gen(("Float4E2M1FNx32")) float4_e2m1fnx64 = func_gen(("Float4E2M1FNx64")) - bfloat16 = func_gen(("BFloat16")) # pylint: enable=invalid-name diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index 386e00018ea9..e20ffcff0bfc 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -339,7 +339,6 @@ class ComputeLegalizer : public StmtExprMutator { Stmt VisitStmt_(const AttrStmtNode* op) final { Stmt ret = StmtExprMutator::VisitStmt_(op); op = ret.as(); - if (auto buffer = op->node.as()) { auto it = buffer_remap_.find(buffer.value()); if (it != buffer_remap_.end()) { diff --git a/tests/python/tir-transform/test_tir_transform_bf16_legalize.py b/tests/python/tir-transform/test_tir_transform_bf16_legalize.py index 64fd94691ee6..fa1aa558b6d0 100644 --- a/tests/python/tir-transform/test_tir_transform_bf16_legalize.py +++ b/tests/python/tir-transform/test_tir_transform_bf16_legalize.py @@ -43,6 +43,7 @@ def f32tou16(v): def f32tobf16(v): return T.reinterpret("bfloat16", f32tou16(v)) + def test_bf16_storage_compute_scope_will_legalize(): def get_before(): @tvm.script.ir_module @@ -384,6 +385,7 @@ def main( tvm.ir.assert_structural_equal(after_compute, BindTarget(target)(after_compute_legalize())) tvm.ir.assert_structural_equal(after_storage, BindTarget(target)(after_storage_legalize())) + if __name__ == "__main__": test_bf16_storage_compute_scope_will_legalize() test_bf16_storage_compute_scope_wont_legalize() From 9cad8ca2ece2a7b54b0edf7c1272f49eb09bc762 Mon Sep 17 00:00:00 2001 From: Joshua Hong Date: Fri, 7 Mar 2025 16:29:01 -0500 Subject: [PATCH 12/14] Remove unused import --- python/tvm/relax/expr.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 050b91b0f0ef..58845ce48986 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -21,11 +21,6 @@ import numpy as _np # type: ignore -try: - import ml_dtypes -except ImportError: - ml_dtypes = None - import tvm import tvm._ffi import tvm.ir From 21f1e15bbdb10f0b2c60f67035c3070ad0b3b2c6 Mon Sep 17 00:00:00 2001 From: Joshua Hong Date: Fri, 7 Mar 2025 16:48:54 -0500 Subject: [PATCH 13/14] Lint fix --- src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc index e006950a34e9..a6a71202ae15 100644 --- a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc +++ b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc @@ -173,6 +173,7 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode { static constexpr const char* _type_key = "meta_schedule.DisallowAsyncStridedMemCopy"; TVM_DECLARE_FINAL_OBJECT_INFO(DisallowAsyncStridedMemCopyNode, PostprocNode); + private: tvm::Target target; }; From 52e0d05770da5e6ed34e350f9714c90aba15dfa2 Mon Sep 17 00:00:00 2001 From: Joshua Hong Date: Sat, 8 Mar 2025 18:20:53 -0500 Subject: [PATCH 14/14] Fix for bf16 related codegen --- src/target/source/intrin_rule_cuda.cc | 8 +++++++- src/tir/op/op.cc | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc index 79ea7a458ff0..e762bde69f4d 100644 --- a/src/target/source/intrin_rule_cuda.cc +++ b/src/target/source/intrin_rule_cuda.cc @@ -53,7 +53,13 @@ struct CUDAMath { return ""; } } else if (t.is_bfloat16()) { - return 'h' + name; + if (name == "fabs") { + return "__habs"; + } else if (name == "round") { + return "hrint"; + } else { + return "h" + name; + } } else if (t.is_int() || t.is_uint()) { switch (t.bits()) { case 32: diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 63c82d1d6c11..46c15cb3dfc3 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -801,7 +801,7 @@ PrimExpr abs(PrimExpr x, Span span) { return IntImm(x.dtype(), std::abs(px->value), px->span); } return tir::Select(x >= make_zero(x.dtype()), x, -x, span); - } else if (x.dtype().is_float()) { + } else if (x.dtype().is_float() || x.dtype().is_bfloat()) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) {