diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index 65fd0c98fdb7..40664f0c40a1 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -124,6 +124,8 @@ class DataType { bool is_bool() const { return code() == DataType::kUInt && bits() == 1; } /*! \return whether type is a float type. */ bool is_float() const { return code() == DataType::kFloat; } + /*! \return whether type is a bfloat type. */ + bool is_bfloat() const { return code() == DataType::kBFloat; } /*! \return whether type is a float8 type. */ bool is_float8() const { return (code() == DataType::kFloat || code() == DataType::kFloat8_e4m3fn || diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index e60a3859acf5..b01cb8422274 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -473,6 +473,7 @@ inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(), TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##32, FDType(32)); \ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##64, FDType(64)); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(BFloat, DataType::BFloat); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Float, DataType::Float); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(UInt, DataType::UInt); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Int, DataType::Int); @@ -490,6 +491,7 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Int, DataType::Int); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##32, FDType, 32); \ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##64, FDType, 64); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(BFloat, DataType::BFloat); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Float, DataType::Float); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(UInt, DataType::UInt); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int); diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 317bd6bead7c..8a9c231617a2 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -118,6 +118,7 @@ class DataType(ctypes.Structure): "float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1}, "float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1}, "float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1}, + "bfloat16": {"type_code": DataTypeCode.BFLOAT, "bits": 16, "lanes": 1}, } def __init__(self, type_str): diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 2fce022da365..3e835e8d9dbe 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1464,7 +1464,7 @@ def func( float4_e2m1fnx32 = func_gen(("Float4E2M1FNx32")) float4_e2m1fnx64 = func_gen(("Float4E2M1FNx64")) - +bfloat16 = func_gen(("BFloat16")) # pylint: enable=invalid-name @@ -1961,7 +1961,6 @@ def wrapped(*args, **kwargs): # pylint: enable=invalid-name - __all__ = [ "int8", "int16", @@ -2048,6 +2047,7 @@ def wrapped(*args, **kwargs): "float16x64", "float32x64", "float64x64", + "bfloat16", "buffer", "buffer_decl", "prim_func", diff --git a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc index d654e467f1e7..a6a71202ae15 100644 --- a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc +++ b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc @@ -119,7 +119,11 @@ namespace meta_schedule { class DisallowAsyncStridedMemCopyNode : public PostprocNode { public: // Inherited from PostprocNode - void InitializeWithTuneContext(const TuneContext& context) final {} + void InitializeWithTuneContext(const TuneContext& context) final { + /* Null check */ + ICHECK(context->target) << "Context must contain a target"; + this->target = context->target.value(); + } // Inherited from PostprocNode bool Apply(const tir::Schedule& sch) final { IRModule mod = sch->mod(); @@ -130,6 +134,7 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode { IRModule lowered{nullptr}; try { auto pass_list = Array(); + pass_list.push_back(tir::transform::BindTarget(this->target)); pass_list.push_back(tir::transform::LowerInitBlock()); pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); @@ -168,6 +173,9 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode { static constexpr const char* _type_key = "meta_schedule.DisallowAsyncStridedMemCopy"; TVM_DECLARE_FINAL_OBJECT_INFO(DisallowAsyncStridedMemCopyNode, PostprocNode); + + private: + tvm::Target target; }; Postproc Postproc::DisallowAsyncStridedMemCopy() { diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 526b816d0945..b4668d65d399 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -191,7 +191,8 @@ bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, axes_non_neg = NormalizeAxes(call, ctx, data_sinfo->ndim, axes); } int n_axis = axes.size(); - if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) { + if (!data_sinfo->IsUnknownDtype() && + (!data_sinfo->dtype.is_float() && !data_sinfo->dtype.is_bfloat())) { ctx->ReportFatal( Diagnostic::Error(call) << op << " requires the input data to have float dtype. However, the given data dtype is " diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index 6e2ef6bd2bef..eea6db22fdda 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -199,7 +199,8 @@ template inline StructInfo InferStructInfoUnary(const Call& call, const BlockBuilder& ctx, FType f_compute_out_dtype) { TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); - if (require_float_dtype && !input_sinfo->IsUnknownDtype() && !input_sinfo->dtype.is_float()) { + if (require_float_dtype && !input_sinfo->IsUnknownDtype() && + (!input_sinfo->dtype.is_float() && !input_sinfo->dtype.is_bfloat())) { ctx->ReportFatal( Diagnostic::Error(call) << call->op diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index a75a35781001..83e32f5af898 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -752,8 +752,10 @@ TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Float", Float); TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt); TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.BFloat16").set_body_typed(BFloat16); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float8E4M3FN").set_body_typed(Float8E4M3FN); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float8E5M2").set_body_typed(Float8E5M2); +TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.BFloat16", BFloat16); TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3FN", Float8E4M3FN); TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E5M2", Float8E5M2); diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 35973776c818..34023e0bb7d7 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -1661,7 +1661,7 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val os << '('; } if (i % 2 == 0) { - os << "__pack_bfloat162(" << value; + os << "__pack_nv_bfloat162(" << value; } else { os << "," << value << ")"; if (i != t.lanes() - 1) { diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc index 79ea7a458ff0..e762bde69f4d 100644 --- a/src/target/source/intrin_rule_cuda.cc +++ b/src/target/source/intrin_rule_cuda.cc @@ -53,7 +53,13 @@ struct CUDAMath { return ""; } } else if (t.is_bfloat16()) { - return 'h' + name; + if (name == "fabs") { + return "__habs"; + } else if (name == "round") { + return "hrint"; + } else { + return "h" + name; + } } else if (t.is_int() || t.is_uint()) { switch (t.bits()) { case 32: diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 63c82d1d6c11..46c15cb3dfc3 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -801,7 +801,7 @@ PrimExpr abs(PrimExpr x, Span span) { return IntImm(x.dtype(), std::abs(px->value), px->span); } return tir::Select(x >= make_zero(x.dtype()), x, -x, span); - } else if (x.dtype().is_float()) { + } else if (x.dtype().is_float() || x.dtype().is_bfloat()) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) { diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index c75ecf77e708..e20ffcff0bfc 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -339,7 +339,6 @@ class ComputeLegalizer : public StmtExprMutator { Stmt VisitStmt_(const AttrStmtNode* op) final { Stmt ret = StmtExprMutator::VisitStmt_(op); op = ret.as(); - if (auto buffer = op->node.as()) { auto it = buffer_remap_.find(buffer.value()); if (it != buffer_remap_.end()) { @@ -350,6 +349,43 @@ class ComputeLegalizer : public StmtExprMutator { if (it != var_remap_.end()) { return AttrStmt(it->second, op->attr_key, op->value, op->body); } + } else if (auto reducer = op->node.as()) { + auto legalized_identity_elements = + reducer->identity_element.Map([this](PrimExpr expr) { return this->VisitExpr(expr); }); + + // Remap input variables + for (size_t i = 0; i < legalized_identity_elements.size(); i++) { + Var lhs_var = reducer->lhs[i]; + if (lhs_var.dtype() != legalized_identity_elements[i].dtype()) { + var_remap_[lhs_var] = lhs_var.copy_with_dtype(legalized_identity_elements[i].dtype()); + } + Var rhs_var = reducer->rhs[i]; + if (rhs_var.dtype() != legalized_identity_elements[i].dtype()) { + var_remap_[rhs_var] = rhs_var.copy_with_dtype(legalized_identity_elements[i].dtype()); + } + } + + auto legalized_results = + reducer->result.Map([this](PrimExpr expr) { return this->VisitExpr(expr); }); + + auto legalized_lhs = reducer->lhs.Map([this](Var var) { + auto it = var_remap_.find(var); + if (it != var_remap_.end()) { + return it->second; + } + return var; + }); + + auto legalized_rhs = reducer->rhs.Map([this](Var var) { + auto it = var_remap_.find(var); + if (it != var_remap_.end()) { + return it->second; + } + return var; + }); + return AttrStmt(CommReducer(legalized_lhs, legalized_rhs, legalized_results, + legalized_identity_elements, reducer->span), + op->attr_key, op->value, op->body); } return ret; } @@ -714,7 +750,10 @@ bool CheckDataTypeSupport(const Target& target, const std::string& support_func_ Pass BF16ComputeLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - // TODO(tvm-team): skip if the target supports bf16 + auto target = f->GetAttr(tvm::attr::kTarget).value(); + if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_bf16")) { + return f; + } return BF16ComputeLegalizer().Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.BF16ComputeLegalize", {}); @@ -724,7 +763,10 @@ TVM_REGISTER_GLOBAL("tir.transform.BF16ComputeLegalize").set_body_typed(BF16Comp Pass BF16StorageLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - // TODO(tvm-team): skip if the target supports bf16 + auto target = f->GetAttr(tvm::attr::kTarget).value(); + if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_bf16")) { + return f; + } return BF16StorageLegalizer().Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.BF16StorageLegalize", {}); diff --git a/tests/python/tir-transform/test_tir_transform_bf16_legalize.py b/tests/python/tir-transform/test_tir_transform_bf16_legalize.py index e2752e8bbb3c..fa1aa558b6d0 100644 --- a/tests/python/tir-transform/test_tir_transform_bf16_legalize.py +++ b/tests/python/tir-transform/test_tir_transform_bf16_legalize.py @@ -16,26 +16,10 @@ # under the License. import tvm import tvm.script +from tvm.target import Target from tvm.script import tir as T - - -def get_before(): - @tvm.script.ir_module - class Before: - @T.prim_func - def main( - Aptr: T.handle("bfloat16"), Bptr: T.handle("bfloat16"), Dptr: T.handle("bfloat16") - ): - T.func_attr({"global_symbol": "main"}) - A = T.decl_buffer((100,), "bfloat16", data=Aptr) - B = T.decl_buffer((100,), "bfloat16", data=Bptr) - D = T.decl_buffer((100,), "bfloat16", data=Dptr) - C = T.decl_buffer((100,), "bfloat16") - for i in T.grid(100): - C[i] = A[i] + B[i] - D[i] = T.exp(C[i]) - - return Before +from tvm.target import Target +from tvm.tir.transform.transform import BindTarget def u16tof32(v): @@ -60,61 +44,7 @@ def f32tobf16(v): return T.reinterpret("bfloat16", f32tou16(v)) -def get_after_compute_legalize(): - @tvm.script.ir_module - class After: - @T.prim_func - def main( - Aptr: T.handle("bfloat16"), Bptr: T.handle("bfloat16"), Dptr: T.handle("bfloat16") - ): - T.func_attr({"global_symbol": "main"}) - A = T.decl_buffer((100,), "bfloat16", data=Aptr) - B = T.decl_buffer((100,), "bfloat16", data=Bptr) - D = T.decl_buffer((100,), "bfloat16", data=Dptr) - C = T.decl_buffer((100,), "float32") - for i in T.grid(100): - C[i] = bf16tof32(A[i]) + bf16tof32(B[i]) - D[i] = f32tobf16(T.exp(C[i])) - - return After - - -def get_after_storage_legalize(): - @tvm.script.ir_module - class After: - @T.prim_func - def main(Aptr: T.handle("uint16"), Bptr: T.handle("uint16"), Dptr: T.handle("uint16")): - T.func_attr({"global_symbol": "main"}) - A = T.decl_buffer((100,), "uint16", data=Aptr) - B = T.decl_buffer((100,), "uint16", data=Bptr) - D = T.decl_buffer((100,), "uint16", data=Dptr) - C = T.decl_buffer((100,), "float32") - for i in T.grid(100): - C[i] = u16tof32(A[i]) + u16tof32(B[i]) - D[i] = f32tou16(T.exp(C[i])) - - return After - - -def test_bf16_compute_legalize(): - before = get_before() - expected = get_after_compute_legalize() - # run the transform twice to ensure we can afford to deal - # with this repeative optimizations - after = tvm.tir.transform.BF16ComputeLegalize()(before) - after = tvm.tir.transform.BF16ComputeLegalize()(after) - - tvm.ir.assert_structural_equal(after, expected) - - -def test_bf16_storage_legalize(): - before = get_after_compute_legalize() - after = tvm.tir.transform.BF16StorageLegalize()(before) - expected = get_after_storage_legalize() - tvm.ir.assert_structural_equal(after, expected) - - -def test_bf16_storage_scope(): +def test_bf16_storage_compute_scope_will_legalize(): def get_before(): @tvm.script.ir_module class Before: @@ -175,13 +105,289 @@ def main( return After - before = get_before() + target = Target("nvidia/geforce-rtx-2080-ti") + before = BindTarget(target)(get_before()) + after_compute = tvm.tir.transform.BF16ComputeLegalize()(before) + after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute) + tvm.ir.assert_structural_equal(after_compute, BindTarget(target)(after_compute_legalize())) + tvm.ir.assert_structural_equal(after_storage, BindTarget(target)(after_storage_legalize())) + + +def test_bf16_storage_compute_scope_wont_legalize(): + def get_before(): + @tvm.script.ir_module + class Before: + @T.prim_func + def main( + Aptr: T.handle("bfloat16", storage_scope="shared"), + Bptr: T.handle("bfloat16", storage_scope="local"), + Dptr: T.handle("bfloat16"), + ): + T.func_attr({"global_symbol": "main"}) + A = T.decl_buffer((100,), "bfloat16", data=Aptr) + B = T.decl_buffer((100,), "bfloat16", data=Bptr) + D = T.decl_buffer((100,), "bfloat16", data=Dptr) + C = T.decl_buffer((100,), "bfloat16") + for i in T.grid(100): + C[i] = A[i] + B[i] + D[i] = T.exp(C[i]) + + return Before + + def after_compute_legalize(): + @tvm.script.ir_module + class After: + @T.prim_func + def main( + Aptr: T.handle("bfloat16", storage_scope="shared"), + Bptr: T.handle("bfloat16", storage_scope="local"), + Dptr: T.handle("bfloat16"), + ): + T.func_attr({"global_symbol": "main"}) + A = T.decl_buffer((100,), "bfloat16", data=Aptr) + B = T.decl_buffer((100,), "bfloat16", data=Bptr) + D = T.decl_buffer((100,), "bfloat16", data=Dptr) + C = T.decl_buffer((100,), "bfloat16") + for i in T.grid(100): + C[i] = A[i] + B[i] + D[i] = T.exp(C[i]) + + return After + + def after_storage_legalize(): + @tvm.script.ir_module + class After: + @T.prim_func + def main( + Aptr: T.handle("bfloat16", storage_scope="shared"), + Bptr: T.handle("bfloat16", storage_scope="local"), + Dptr: T.handle("bfloat16"), + ): + T.func_attr({"global_symbol": "main"}) + A = T.decl_buffer((100,), "bfloat16", data=Aptr) + B = T.decl_buffer((100,), "bfloat16", data=Bptr) + D = T.decl_buffer((100,), "bfloat16", data=Dptr) + C = T.decl_buffer((100,), "bfloat16") + for i in T.grid(100): + C[i] = A[i] + B[i] + D[i] = T.exp(C[i]) + + return After + + target = Target("nvidia/geforce-rtx-3090-ti") + before = BindTarget(target)(get_before()) + after_compute = tvm.tir.transform.BF16ComputeLegalize()(before) + after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute) + tvm.ir.assert_structural_equal(after_compute, BindTarget(target)(after_compute_legalize())) + tvm.ir.assert_structural_equal(after_storage, BindTarget(target)(after_storage_legalize())) + + +def test_bf16_reduce_will_legalize(): + def get_before(): + @tvm.script.ir_module + class Before: + @T.prim_func(private=True) + def main( + Aptr: T.handle("bfloat16", storage_scope="shared"), + ): + A_flat = T.decl_buffer(4096, "bfloat16", data=Aptr) + + for i in range(128): + threadIdx_x = T.launch_thread("threadIdx.x", 32) + + reduce = T.decl_buffer(1, dtype="bfloat16", scope="local") + + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.bfloat16(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + T.tvm_thread_allreduce( + T.uint32(1), + A_flat[0], + T.bool(True), + reduce[0], + threadIdx_x, + ) + + return Before + + def after_compute_legalize(): + @tvm.script.ir_module + class After: + @T.prim_func(private=True) + def main( + Aptr: T.handle("bfloat16", storage_scope="shared"), + ): + A_flat_1 = T.decl_buffer(4096, "bfloat16", data=Aptr) + + for i in range(128): + threadIdx_x = T.launch_thread("threadIdx.x", 32) + + reduce = T.decl_buffer(1, dtype="float32", scope="local") + + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + T.tvm_thread_allreduce( + T.uint32(1), + T.reinterpret( + "float32", + T.shift_left( + T.Cast("uint32", T.reinterpret("uint16", A_flat_1[0])), + T.uint32(16), + ), + ), + T.bool(True), + reduce[0], + threadIdx_x, + ) + + return After + + def after_storage_legalize(): + @tvm.script.ir_module + class After: + @T.prim_func(private=True) + def main( + Aptr: T.handle("uint16", storage_scope="shared"), + ): + A_flat_1 = T.decl_buffer(4096, "uint16", data=Aptr) + + for i in range(128): + threadIdx_x = T.launch_thread("threadIdx.x", 32) + + reduce = T.decl_buffer(1, dtype="float32", scope="local") + + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + T.tvm_thread_allreduce( + T.uint32(1), + T.reinterpret( + "float32", + T.shift_left( + T.Cast("uint32", T.reinterpret("uint16", A_flat_1[0])), + T.uint32(16), + ), + ), + T.bool(True), + reduce[0], + threadIdx_x, + ) + + return After + + target = Target("nvidia/geforce-rtx-2080-ti") + before = BindTarget(target)(get_before()) + after_compute = tvm.tir.transform.BF16ComputeLegalize()(before) + after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute) + tvm.ir.assert_structural_equal(after_compute, BindTarget(target)(after_compute_legalize())) + tvm.ir.assert_structural_equal(after_storage, BindTarget(target)(after_storage_legalize())) + + +def test_bf16_reduce_wont_legalize(): + def get_before(): + @tvm.script.ir_module + class Before: + @T.prim_func(private=True) + def main( + Aptr: T.handle("bfloat16", storage_scope="shared"), + ): + A_flat = T.decl_buffer(4096, "bfloat16", data=Aptr) + + for i in range(128): + threadIdx_x = T.launch_thread("threadIdx.x", 32) + + reduce = T.decl_buffer(1, dtype="bfloat16", scope="local") + + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.bfloat16(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + T.tvm_thread_allreduce( + T.uint32(1), + A_flat[0], + T.bool(True), + reduce[0], + threadIdx_x, + ) + + return Before + + def after_compute_legalize(): + @tvm.script.ir_module + class After: + @T.prim_func(private=True) + def main( + Aptr: T.handle("bfloat16", storage_scope="shared"), + ): + A_flat = T.decl_buffer(4096, "bfloat16", data=Aptr) + + for i in range(128): + threadIdx_x = T.launch_thread("threadIdx.x", 32) + + reduce = T.decl_buffer(1, dtype="bfloat16", scope="local") + + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.bfloat16(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + T.tvm_thread_allreduce( + T.uint32(1), + A_flat[0], + T.bool(True), + reduce[0], + threadIdx_x, + ) + + return After + + def after_storage_legalize(): + @tvm.script.ir_module + class After: + @T.prim_func(private=True) + def main( + Aptr: T.handle("bfloat16", storage_scope="shared"), + ): + A_flat = T.decl_buffer(4096, "bfloat16", data=Aptr) + + for i in range(128): + threadIdx_x = T.launch_thread("threadIdx.x", 32) + + reduce = T.decl_buffer(1, dtype="bfloat16", scope="local") + + with T.attr( + T.comm_reducer(lambda x, y: x + y, [T.bfloat16(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + T.tvm_thread_allreduce( + T.uint32(1), + A_flat[0], + T.bool(True), + reduce[0], + threadIdx_x, + ) + + return After + + target = Target("nvidia/geforce-rtx-3090-ti") + before = BindTarget(target)(get_before()) after_compute = tvm.tir.transform.BF16ComputeLegalize()(before) after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute) - tvm.ir.assert_structural_equal(after_compute, after_compute_legalize()) - tvm.ir.assert_structural_equal(after_storage, after_storage_legalize()) + tvm.ir.assert_structural_equal(after_compute, BindTarget(target)(after_compute_legalize())) + tvm.ir.assert_structural_equal(after_storage, BindTarget(target)(after_storage_legalize())) if __name__ == "__main__": - test_bf16_storage_legalize() - test_bf16_storage_scope() + test_bf16_storage_compute_scope_will_legalize() + test_bf16_storage_compute_scope_wont_legalize() + test_bf16_reduce_will_legalize() + test_bf16_reduce_wont_legalize()