Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ class DataType {
bool is_bool() const { return code() == DataType::kUInt && bits() == 1; }
/*! \return whether type is a float type. */
bool is_float() const { return code() == DataType::kFloat; }
/*! \return whether type is a bfloat type. */
bool is_bfloat() const { return code() == DataType::kBFloat; }
/*! \return whether type is a float8 type. */
bool is_float8() const {
return (code() == DataType::kFloat || code() == DataType::kFloat8_e4m3fn ||
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(),
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##32, FDType(32)); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##64, FDType(64));

TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(BFloat, DataType::BFloat);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Float, DataType::Float);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(UInt, DataType::UInt);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Int, DataType::Int);
Expand All @@ -490,6 +491,7 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Int, DataType::Int);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##32, FDType, 32); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##64, FDType, 64);

TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(BFloat, DataType::BFloat);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Float, DataType::Float);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(UInt, DataType::UInt);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int);
Expand Down
1 change: 1 addition & 0 deletions python/tvm/_ffi/runtime_ctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ class DataType(ctypes.Structure):
"float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1},
"float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1},
"float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1},
"bfloat16": {"type_code": DataTypeCode.BFLOAT, "bits": 16, "lanes": 1},
}

def __init__(self, type_str):
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1464,7 +1464,7 @@ def func(
float4_e2m1fnx32 = func_gen(("Float4E2M1FNx32"))
float4_e2m1fnx64 = func_gen(("Float4E2M1FNx64"))


bfloat16 = func_gen(("BFloat16"))
# pylint: enable=invalid-name


Expand Down Expand Up @@ -1961,7 +1961,6 @@ def wrapped(*args, **kwargs):

# pylint: enable=invalid-name


__all__ = [
"int8",
"int16",
Expand Down Expand Up @@ -2048,6 +2047,7 @@ def wrapped(*args, **kwargs):
"float16x64",
"float32x64",
"float64x64",
"bfloat16",
"buffer",
"buffer_decl",
"prim_func",
Expand Down
10 changes: 9 additions & 1 deletion src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,11 @@ namespace meta_schedule {
class DisallowAsyncStridedMemCopyNode : public PostprocNode {
public:
// Inherited from PostprocNode
void InitializeWithTuneContext(const TuneContext& context) final {}
void InitializeWithTuneContext(const TuneContext& context) final {
/* Null check */
ICHECK(context->target) << "Context must contain a target";
this->target = context->target.value();
}
// Inherited from PostprocNode
bool Apply(const tir::Schedule& sch) final {
IRModule mod = sch->mod();
Expand All @@ -130,6 +134,7 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode {
IRModule lowered{nullptr};
try {
auto pass_list = Array<tvm::transform::Pass>();
pass_list.push_back(tir::transform::BindTarget(this->target));
pass_list.push_back(tir::transform::LowerInitBlock());
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
Expand Down Expand Up @@ -168,6 +173,9 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode {

static constexpr const char* _type_key = "meta_schedule.DisallowAsyncStridedMemCopy";
TVM_DECLARE_FINAL_OBJECT_INFO(DisallowAsyncStridedMemCopyNode, PostprocNode);

private:
tvm::Target target;
};

Postproc Postproc::DisallowAsyncStridedMemCopy() {
Expand Down
3 changes: 2 additions & 1 deletion src/relax/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx,
axes_non_neg = NormalizeAxes(call, ctx, data_sinfo->ndim, axes);
}
int n_axis = axes.size();
if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) {
if (!data_sinfo->IsUnknownDtype() &&
(!data_sinfo->dtype.is_float() && !data_sinfo->dtype.is_bfloat())) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< op << " requires the input data to have float dtype. However, the given data dtype is "
Expand Down
3 changes: 2 additions & 1 deletion src/relax/op/op_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,8 @@ template <bool require_float_dtype, typename FType>
inline StructInfo InferStructInfoUnary(const Call& call, const BlockBuilder& ctx,
FType f_compute_out_dtype) {
TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
if (require_float_dtype && !input_sinfo->IsUnknownDtype() && !input_sinfo->dtype.is_float()) {
if (require_float_dtype && !input_sinfo->IsUnknownDtype() &&
(!input_sinfo->dtype.is_float() && !input_sinfo->dtype.is_bfloat())) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< call->op
Expand Down
2 changes: 2 additions & 0 deletions src/script/ir_builder/tir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -752,8 +752,10 @@ TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Float", Float);
TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt);
TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int);

TVM_REGISTER_GLOBAL("script.ir_builder.tir.BFloat16").set_body_typed(BFloat16);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float8E4M3FN").set_body_typed(Float8E4M3FN);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float8E5M2").set_body_typed(Float8E5M2);
TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.BFloat16", BFloat16);
TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3FN", Float8E4M3FN);
TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E5M2", Float8E5M2);

Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1661,7 +1661,7 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val
os << '(';
}
if (i % 2 == 0) {
os << "__pack_bfloat162(" << value;
os << "__pack_nv_bfloat162(" << value;
} else {
os << "," << value << ")";
if (i != t.lanes() - 1) {
Expand Down
8 changes: 7 additions & 1 deletion src/target/source/intrin_rule_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,13 @@ struct CUDAMath {
return "";
}
} else if (t.is_bfloat16()) {
return 'h' + name;
if (name == "fabs") {
return "__habs";
} else if (name == "round") {
return "hrint";
} else {
return "h" + name;
}
} else if (t.is_int() || t.is_uint()) {
switch (t.bits()) {
case 32:
Expand Down
2 changes: 1 addition & 1 deletion src/tir/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -801,7 +801,7 @@ PrimExpr abs(PrimExpr x, Span span) {
return IntImm(x.dtype(), std::abs(px->value), px->span);
}
return tir::Select(x >= make_zero(x.dtype()), x, -x, span);
} else if (x.dtype().is_float()) {
} else if (x.dtype().is_float() || x.dtype().is_bfloat()) {
using tir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) {
Expand Down
48 changes: 45 additions & 3 deletions src/tir/transforms/unsupported_dtype_legalize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,6 @@ class ComputeLegalizer : public StmtExprMutator {
Stmt VisitStmt_(const AttrStmtNode* op) final {
Stmt ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<AttrStmtNode>();

if (auto buffer = op->node.as<Buffer>()) {
auto it = buffer_remap_.find(buffer.value());
if (it != buffer_remap_.end()) {
Expand All @@ -350,6 +349,43 @@ class ComputeLegalizer : public StmtExprMutator {
if (it != var_remap_.end()) {
return AttrStmt(it->second, op->attr_key, op->value, op->body);
}
} else if (auto reducer = op->node.as<CommReducerNode>()) {
auto legalized_identity_elements =
reducer->identity_element.Map([this](PrimExpr expr) { return this->VisitExpr(expr); });

// Remap input variables
for (size_t i = 0; i < legalized_identity_elements.size(); i++) {
Var lhs_var = reducer->lhs[i];
if (lhs_var.dtype() != legalized_identity_elements[i].dtype()) {
var_remap_[lhs_var] = lhs_var.copy_with_dtype(legalized_identity_elements[i].dtype());
}
Var rhs_var = reducer->rhs[i];
if (rhs_var.dtype() != legalized_identity_elements[i].dtype()) {
var_remap_[rhs_var] = rhs_var.copy_with_dtype(legalized_identity_elements[i].dtype());
}
}

auto legalized_results =
reducer->result.Map([this](PrimExpr expr) { return this->VisitExpr(expr); });

auto legalized_lhs = reducer->lhs.Map([this](Var var) {
auto it = var_remap_.find(var);
if (it != var_remap_.end()) {
return it->second;
}
return var;
});

auto legalized_rhs = reducer->rhs.Map([this](Var var) {
auto it = var_remap_.find(var);
if (it != var_remap_.end()) {
return it->second;
}
return var;
});
return AttrStmt(CommReducer(legalized_lhs, legalized_rhs, legalized_results,
legalized_identity_elements, reducer->span),
op->attr_key, op->value, op->body);
}
return ret;
}
Expand Down Expand Up @@ -714,7 +750,10 @@ bool CheckDataTypeSupport(const Target& target, const std::string& support_func_

Pass BF16ComputeLegalize() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
// TODO(tvm-team): skip if the target supports bf16
auto target = f->GetAttr<Target>(tvm::attr::kTarget).value();
if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_bf16")) {
return f;
}
return BF16ComputeLegalizer().Legalize(f);
};
return CreatePrimFuncPass(pass_func, 0, "tir.BF16ComputeLegalize", {});
Expand All @@ -724,7 +763,10 @@ TVM_REGISTER_GLOBAL("tir.transform.BF16ComputeLegalize").set_body_typed(BF16Comp

Pass BF16StorageLegalize() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
// TODO(tvm-team): skip if the target supports bf16
auto target = f->GetAttr<Target>(tvm::attr::kTarget).value();
if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_bf16")) {
return f;
}
return BF16StorageLegalizer().Legalize(f);
};
return CreatePrimFuncPass(pass_func, 0, "tir.BF16StorageLegalize", {});
Expand Down
Loading
Loading