From a941951fb64d69825553bf84828efb4628d736d3 Mon Sep 17 00:00:00 2001 From: Zihao Date: Fri, 2 Jun 2023 17:21:21 -0700 Subject: [PATCH 01/43] upd --- src/tir/transforms/unsupported_dtype_legalize.cc | 1 - tests/python/unittest/test_target_codegen_cuda.py | 2 ++ 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index be8876b81550..0040951961dc 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -716,7 +716,6 @@ TVM_REGISTER_GLOBAL("tir.transform.FP8ComputeLegalize").set_body_typed(FP8Comput Pass FP8StorageLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - LOG(INFO) << f; // TODO(tvm-team): skip if the target supports fp8 return FP8StorageLegalizer().Legalize(f); }; diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index e60138a9c8d6..7685d64368ef 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -101,6 +101,8 @@ def check_cuda(n, lanes): disabled_pass=["tir.BF16Promote", "tir.BF16CastElimination", "tir.BF16TypeLowering"] ): fun = tvm.build(s, [A, B], "cuda") + print(fun.imported_modules[0].get_source()) + assert False dev = tvm.cuda(0) np_a = np.random.uniform(size=(n, lanes)).astype("float32") np_a = np_bf162np_float(np_float2np_bf16(np_a)) From 32841512dea1e6aa1f73f601041a8af90756e00a Mon Sep 17 00:00:00 2001 From: Zihao Date: Fri, 2 Jun 2023 22:32:23 -0700 Subject: [PATCH 02/43] wip --- python/tvm/contrib/nvcc.py | 3 +- src/driver/driver_api.cc | 3 +- .../transforms/unsupported_dtype_legalize.cc | 34 ++++++++++++++----- .../unittest/test_target_codegen_cuda.py | 12 +++---- 4 files changed, 36 insertions(+), 16 deletions(-) diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index 5eb348009914..2ef3cf9df4dd 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -100,7 +100,8 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target # Because it is hard to do runtime compiler detection, we require nvcc is configured # correctly by default. # if cxx_compiler_path != "": - # cmd += ["-ccbin", cxx_compiler_path] + # cmd += ["-ccbin", cxx_compiler_path] + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index cfc7fa80c7a9..fecacff5b0d0 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -55,6 +55,8 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer); TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_ldg32", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.hardware_support_bf16", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.hardware_support_fp8", Bool); // WARNING: May cause coherency issues resulting data miscompares // Experimental feature that, when enabled by the runtime, bypasses the cache when using DMA. 
When @@ -151,7 +153,6 @@ Array CreatePassList(bool disable_loop_partition) { bool disable_cse_tir = pass_ctx->GetConfig("tir.disable_cse_tir", Bool(false)).value(); bool enable_equiv_terms_in_cse_tir = pass_ctx->GetConfig("tir.enable_equiv_terms_in_cse_tir", Bool(false)).value(); - bool ptx_ldg32 = pass_ctx->GetConfig("tir.ptx_ldg32", Bool(false)).value(); // Get any user-added passes diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index 0040951961dc..74d881ab9e9a 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -686,8 +686,13 @@ namespace transform { Pass BF16ComputeLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - // TODO(tvm-team): skip if the target supports bf16 - return BF16ComputeLegalizer().Legalize(f); + bool target_support_bf16 = + ctx->GetConfig("tir.hardware_support_bf16", Bool(false)).value(); + if (target_support_bf16) { + return f; + } else { + return BF16ComputeLegalizer().Legalize(f); + } }; return CreatePrimFuncPass(pass_func, 0, "tir.BF16ComputeLegalize", {}); } @@ -696,8 +701,13 @@ TVM_REGISTER_GLOBAL("tir.transform.BF16ComputeLegalize").set_body_typed(BF16Comp Pass BF16StorageLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - // TODO(tvm-team): skip if the target supports bf16 - return BF16StorageLegalizer().Legalize(f); + bool target_support_bf16 = + ctx->GetConfig("tir.hardware_support_bf16", Bool(false)).value(); + if (target_support_bf16) { + return f; + } else { + return BF16StorageLegalizer().Legalize(f); + } }; return CreatePrimFuncPass(pass_func, 0, "tir.BF16StorageLegalize", {}); } @@ -706,8 +716,12 @@ TVM_REGISTER_GLOBAL("tir.transform.BF16StorageLegalize").set_body_typed(BF16Stor Pass FP8ComputeLegalize(String promote_dtype_str) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - // TODO(tvm-team): skip if the target supports fp8 - return FP8ComputeLegalizer(DataType(String2DLDataType(promote_dtype_str))).Legalize(f); + bool target_support_fp8 = ctx->GetConfig("tir.hardware_support_fp8", Bool(false)).value(); + if (target_support_fp8) { + return f; + } else { + return FP8ComputeLegalizer(DataType(String2DLDataType(promote_dtype_str))).Legalize(f); + } }; return CreatePrimFuncPass(pass_func, 0, "tir.FP8ComputeLegalize", {}); } @@ -716,8 +730,12 @@ TVM_REGISTER_GLOBAL("tir.transform.FP8ComputeLegalize").set_body_typed(FP8Comput Pass FP8StorageLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - // TODO(tvm-team): skip if the target supports fp8 - return FP8StorageLegalizer().Legalize(f); + bool target_support_fp8 = ctx->GetConfig("tir.hardware_support_fp8", Bool(false)).value(); + if (target_support_fp8) { + return f; + } else { + return FP8StorageLegalizer().Legalize(f); + } }; return CreatePrimFuncPass(pass_func, 0, "tir.FP8StorageLegalize", {}); } diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index 7685d64368ef..fa9474e73d88 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -97,12 +97,12 @@ def check_cuda(n, lanes): xo, xi = s[B].split(B.op.axis[0], factor=num_thread) s[B].bind(xo, bx) s[B].bind(xi, tx) - with tvm.transform.PassContext( - disabled_pass=["tir.BF16Promote", "tir.BF16CastElimination", "tir.BF16TypeLowering"] - ): - fun = tvm.build(s, [A, B], "cuda") - 
print(fun.imported_modules[0].get_source()) - assert False + # with tvm.transform.PassContext( + # disabled_pass=["tir.BF16ComputeLegalize", "tir.BF16StorageLegalize"] + # ): + fun = tvm.build(s, [A, B], "cuda") + print(fun.imported_modules[0].get_source()) + assert False dev = tvm.cuda(0) np_a = np.random.uniform(size=(n, lanes)).astype("float32") np_a = np_bf162np_float(np_float2np_bf16(np_a)) From e753640b6d405575e13b93dbef19d444fa79ddb8 Mon Sep 17 00:00:00 2001 From: Zihao Date: Fri, 9 Jun 2023 17:50:44 -0700 Subject: [PATCH 03/43] upd --- src/target/source/cuda_vector_intrin.cc | 42 +++++++++++++++++++++++++ src/target/source/cuda_vector_intrin.h | 42 +++++++++++++++++++++++++ 2 files changed, 84 insertions(+) create mode 100644 src/target/source/cuda_vector_intrin.cc create mode 100644 src/target/source/cuda_vector_intrin.h diff --git a/src/target/source/cuda_vector_intrin.cc b/src/target/source/cuda_vector_intrin.cc new file mode 100644 index 000000000000..a3d957e42fcb --- /dev/null +++ b/src/target/source/cuda_vector_intrin.cc @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file cuda_vector_intrin.cc + * \brief Code generation with vector intrinsics in CUDA. + */ +#include "cuda_vector_intrin.h" + +namespace tvm { +namespace codegen { +namespace cuda { + +std::string PrintHalf2BinaryOp(const std::string& op, const std::string lhs, + const std::string rhs) { + return ""; +} + +std::string PrintNVBFloat162BinaryOp(const std::string& op, const std::string lhs, + const std::string rhs) { + return ""; +} + +} // namespace cuda +} // namespace codegen +} // namespace tvm diff --git a/src/target/source/cuda_vector_intrin.h b/src/target/source/cuda_vector_intrin.h new file mode 100644 index 000000000000..9f636681c0b4 --- /dev/null +++ b/src/target/source/cuda_vector_intrin.h @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file cuda_vector_intrin.h + * \brief Code generation with vector intrinsics in CUDA. 
+ */ +#ifndef TVM_TARGET_SOURCE_CUDA_VECTOR_INTRIN_H_ +#define TVM_TARGET_SOURCE_CUDA_VECTOR_INTRIN_H_ + +#include + +namespace tvm { +namespace codegen { +namespace cuda { + +std::string PrintHalf2BinaryOp(const std::string& op, const std::string lhs, const std::string rhs); + +std::string PrintNVBFloat162BinaryOp(const std::string& op, const std::string lhs, + const std::string rhs); + +} // namespace cuda +} // namespace codegen +} // namespace tvm + +#endif // TVM_TARGET_SOURCE_CUDA_VECTOR_INTRIN_H_ \ No newline at end of file From 7f316368d67bc9011e46b9f46f2183691176394d Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 29 Jun 2023 05:01:16 -0700 Subject: [PATCH 04/43] upd --- include/tvm/runtime/data_type.h | 6 +++- include/tvm/tir/op.h | 2 +- python/tvm/contrib/nvcc.py | 5 ++-- src/tir/op/op.cc | 40 +++++++++++++++++++------- src/tir/transforms/dtype_conversion.cc | 5 +--- try.py | 23 +++++++++++++++ 6 files changed, 63 insertions(+), 18 deletions(-) create mode 100644 try.py diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index 9fb113f56b2c..f2c7e09df369 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -95,7 +95,7 @@ class DataType { bool is_scalar() const { return lanes() == 1; } /*! \return whether type is a scalar type. */ bool is_bool() const { return code() == DataType::kUInt && bits() == 1; } - /*! \return whether type is a float type. */ + /*! \return whether type is a IEEE 754 standard float type. */ bool is_float() const { return code() == DataType::kFloat; } /*! \return whether type is a float8 type. */ bool is_float8() const { @@ -107,6 +107,10 @@ class DataType { bool is_float16() const { return is_float() && bits() == 16; } /*! \return whether type is a bfloat16 type. */ bool is_bfloat16() const { return code() == DataType::kBFloat && bits() == 16; } + /*! \return whether type is a general floating point data type. */ + bool is_floating_point() const { + return is_float() || is_float8() || is_bfloat16(); + } /*! \return whether type is an int type. */ bool is_int() const { return code() == DataType::kInt; } /*! \return whether type is an uint type. */ diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 3d5e589ab4a4..93738e0d773c 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -939,7 +939,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) return LargeUIntImm(t, static_cast(low), static_cast(high), span); } } - if (t.is_float() || t.is_bfloat16() || t.is_float8()) + if (t.is_floating_point()) return FloatImm(t, static_cast(value), span); // For now, we store const scalar values of custom datatypes within doubles; later, during the // datatypes lowering pass, we will lower the value to its true representation in the format diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index 2ef3cf9df4dd..bd36a325c813 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -99,8 +99,9 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target # However, we cannot use TVM_CXX_COMPILER_PATH because the runtime env. # Because it is hard to do runtime compiler detection, we require nvcc is configured # correctly by default. 
- # if cxx_compiler_path != "": - # cmd += ["-ccbin", cxx_compiler_path] + cxx_compiler_path = os.environ.get("CUDAHOSTCXX", "") + if cxx_compiler_path != "": + cmd += ["-ccbin", cxx_compiler_path] proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 39214c4546dc..da3392184eaf 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -33,6 +33,8 @@ // Centralized header for constant folders. #include "../../arith/const_fold.h" #include "../../target/datatype/registry.h" +#include "../../support/scalars.h" + namespace tvm { @@ -206,10 +208,16 @@ PrimExpr max_value(const DataType& dtype, Span span) { } else if (dtype.bits() == 32) { return FloatImm(dtype, std::numeric_limits::max(), span); } else if (dtype.bits() == 16) { - return FloatImm(dtype, 65504.0, span); + return FloatImm(dtype, support::kMaxFloat16, span); } } else if (dtype.is_bfloat16()) { - return FloatImm(dtype, std::numeric_limits::max(), span); + return FloatImm(dtype, support::kMaxBFloat16, span); + } else if (dtype.is_float8()) { + if (dtype.code() == DataType::kE4M3Float) { + return FloatImm(dtype, support::kMaxE4M3, span); + } else { // E5M2 + return FloatImm(dtype, support::kMaxE5M2, span); + } } LOG(FATAL) << "Cannot decide max_value for type" << dtype; } @@ -240,10 +248,16 @@ PrimExpr min_value(const DataType& dtype, Span span) { } else if (dtype.bits() == 32) { return FloatImm(dtype, std::numeric_limits::lowest(), span); } else if (dtype.bits() == 16) { - return FloatImm(dtype, -65504.0, span); - } + return FloatImm(dtype, -support::kMaxFloat16, span); + } } else if (dtype.is_bfloat16()) { - return FloatImm(dtype, std::numeric_limits::lowest(), span); + return FloatImm(dtype, -support::kMaxBFloat16, span); + } else if (dtype.is_float8()) { + if (dtype.code() == DataType::kE4M3Float) { + return FloatImm(dtype, -support::kMaxE4M3, span); + } else { // E5M2 + return FloatImm(dtype, -support::kMaxE5M2, span); + } } LOG(FATAL) << "Cannot decide min_value for type" << dtype; } @@ -258,6 +272,12 @@ PrimExpr infinity(const DataType& dtype, Span span) { } else if (dtype.bits() == 32 || dtype.bits() == 16) { return FloatImm(dtype, std::numeric_limits::infinity(), span); } + } else if (dtype.is_bfloat16()) { + return FloatImm(dtype, std::numeric_limits::infinity(), span); + } else if (dtype.is_float8()) { + if (dtype.code() == DataType::kE5M2Float) { + return FloatImm(dtype, std::numeric_limits::infinity(), span); + } } LOG(FATAL) << "Cannot decide infinity for type " << dtype; } @@ -661,7 +681,7 @@ TVM_REGISTER_GLOBAL("tir.bitwise_not").set_body_typed([](PrimExpr a, Span span) // pow PrimExpr pow(PrimExpr x, PrimExpr y, Span span) { BinaryOpMatchTypes(x, y, span); - ICHECK(x.dtype().is_float()) << "power only applies to float"; + ICHECK(x.dtype().is_floating_point()) << "power only applies to float"; static auto op = Op::Get("tir.pow"); return tir::Call(x.dtype(), op, {x, y}, span); } @@ -677,7 +697,7 @@ PrimExpr abs(PrimExpr x, Span span) { return IntImm(x.dtype(), std::abs(px->value), px->span); } return tir::Select(x >= make_zero(x.dtype()), x, -x, span); - } else if (x.dtype().is_float()) { + } else if (x.dtype().is_floating_point()) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) { @@ -701,7 +721,7 @@ PrimExpr isnan(PrimExpr x, Span span) { DataType t = DataType::Bool(x.dtype().lanes()); if (x.dtype().is_int() || x.dtype().is_uint()) { return make_const(t, false); - } else if (x.dtype().is_float()) { + } else if 
(x.dtype().is_floating_point()) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) { @@ -723,7 +743,7 @@ PrimExpr isinf(PrimExpr x, Span span) { DataType t = DataType::Bool(x.dtype().lanes()); if (x.dtype().is_int() || x.dtype().is_uint()) { return make_const(t, false, span); - } else if (x.dtype().is_float()) { + } else if (x.dtype().is_floating_point()) { PrimExpr infX = infinity(x.dtype(), span); return abs(x, span) == infX && !isnan(x, span); } else { @@ -787,7 +807,7 @@ PrimExpr prod(PrimExpr source, Array rdom, Array init, Span s // fmod PrimExpr fmod(PrimExpr x, PrimExpr y, Span span) { BinaryOpMatchTypes(x, y, span); - ICHECK(x.dtype().is_float()) << "fmod only applies to float"; + ICHECK(x.dtype().is_floating_point()) << "fmod only applies to float"; static auto op = Op::Get("tir.fmod"); return tir::Call(x.dtype(), op, {x, y}, span); } diff --git a/src/tir/transforms/dtype_conversion.cc b/src/tir/transforms/dtype_conversion.cc index de94cf647387..de9d57bed13d 100644 --- a/src/tir/transforms/dtype_conversion.cc +++ b/src/tir/transforms/dtype_conversion.cc @@ -38,11 +38,8 @@ PrimExpr DTypeConversion(PrimExpr src_value, DataType tgt_dtype, RoundingMode ro // The lanes of src dtype and target dtype must match. CHECK_EQ(src_dtype.lanes(), tgt_dtype.lanes()) << "The lanes for data type for source value must matches the target datatype."; - auto is_floating_point = [](DataType dtype) { - return dtype.is_float() || dtype.is_float8() || dtype.is_bfloat16(); - }; // Both source dtype and target dtype should be floating point. - CHECK(is_floating_point(src_dtype) && is_floating_point(tgt_dtype)); + CHECK(src_dtype.is_floating_point() && tgt_dtype.is_floating_point()); FloatConfig src_fp = FloatConfig::FromDataType(src_value.dtype()), tgt_fp = FloatConfig::FromDataType(tgt_dtype); int exponent_delta = tgt_fp.exponent - src_fp.exponent; diff --git a/try.py b/try.py new file mode 100644 index 000000000000..dd1be2cfebcf --- /dev/null +++ b/try.py @@ -0,0 +1,23 @@ +import tvm +from tvm.script import tir as T + +@T.prim_func +def vectorize_add_fp16(A: T.Buffer([128], "bfloat16"), B: T.Buffer([128], "bfloat16")) -> None: + + for i in range(128): + with T.block("blk"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + T.abs(A[vi]) + + +sch = tvm.tir.Schedule(vectorize_add_fp16, debug_mask="all") +blk = sch.get_block("blk") +i, = sch.get_loops(blk) +io, ii, v = sch.split(i, [None, 32, 2]) +sch.vectorize(v) +sch.bind(ii, "threadIdx.x") +sch.bind(io, "blockIdx.x") + +print(sch.mod["main"]) +f = tvm.build(sch.mod["main"], target="cuda") +print(f.imported_modules[0].get_source()) \ No newline at end of file From 0f6b8589e9dfd151854c4e711e8a997cf1e3f578 Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 29 Jun 2023 05:48:16 -0700 Subject: [PATCH 05/43] upd --- src/driver/driver_api.cc | 24 +++++++++++---- src/target/llvm/codegen_llvm.cc | 6 ++++ src/target/source/codegen_cuda.cc | 15 ++++++++-- src/target/source/codegen_cuda.h | 2 +- src/target/source/intrin_rule_cuda.cc | 8 ++++- src/target/source/literal/cuda_half_t.h | 1 + .../transforms/unsupported_dtype_legalize.cc | 30 +++---------------- try.py | 23 -------------- 8 files changed, 50 insertions(+), 59 deletions(-) delete mode 100644 try.py diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index fecacff5b0d0..5526ef4920b5 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -55,8 +55,8 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", 
Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer); TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_ldg32", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.hardware_support_bf16", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.hardware_support_fp8", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.target_support_bf16", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.target_support_fp8", Bool); // WARNING: May cause coherency issues resulting data miscompares // Experimental feature that, when enabled by the runtime, bypasses the cache when using DMA. When @@ -154,6 +154,8 @@ Array CreatePassList(bool disable_loop_partition) { bool enable_equiv_terms_in_cse_tir = pass_ctx->GetConfig("tir.enable_equiv_terms_in_cse_tir", Bool(false)).value(); bool ptx_ldg32 = pass_ctx->GetConfig("tir.ptx_ldg32", Bool(false)).value(); + bool target_support_bf16 = pass_ctx->GetConfig("tir.target_support_bf16", Bool(false)).value(); + bool target_support_fp8 = pass_ctx->GetConfig("tir.target_support_fp8", Bool(false)).value(); // Get any user-added passes Array> add_lower_pass = @@ -211,8 +213,12 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::InjectSoftwarePipeline()); pass_list.push_back(tir::transform::LowerOpaqueBlock()); pass_list.push_back(tir::transform::FlattenBuffer()); - pass_list.push_back(tir::transform::FP8ComputeLegalize()); - pass_list.push_back(tir::transform::BF16ComputeLegalize()); + if (!target_support_fp8) { + pass_list.push_back(tir::transform::FP8ComputeLegalize()); + } + if (!target_support_bf16) { + pass_list.push_back(tir::transform::BF16ComputeLegalize()); + } pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); @@ -541,6 +547,8 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg, transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) { transform::PassContext pass_ctx = transform::PassContext::Current(); + bool target_support_bf16 = pass_ctx->GetConfig("tir.target_support_bf16", Bool(false)).value(); + bool target_support_fp8 = pass_ctx->GetConfig("tir.target_support_fp8", Bool(false)).value(); Array mixed_pass_list; @@ -588,8 +596,12 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) } else { mixed_pass_list.push_back(tir::transform::MakePackedAPI()); } - mixed_pass_list.push_back(tir::transform::FP8StorageLegalize()); - mixed_pass_list.push_back(tir::transform::BF16StorageLegalize()); + if (!target_support_fp8) { + mixed_pass_list.push_back(tir::transform::FP8StorageLegalize()); + } + if (!target_support_bf16) { + mixed_pass_list.push_back(tir::transform::BF16StorageLegalize()); + } mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions()); mixed_pass_list.push_back(tir::transform::SplitHostDevice()); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index ada51677ac16..d3b00a6a82aa 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -568,6 +568,12 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { default: LOG(FATAL) << "do not support " << dtype; } + } else if (dtype.is_bfloat16()) { +#if TVM_LLVM_VERSION >= 110 + etype = llvm::Type::getBFloatTy(*ctx); +#else + LOG(FATAL) << "bfloat16 is not supported for your LLVM version " << TVM_LLVM_VERSION << "."; +#endif } if (dtype.lanes() != 1) { #if TVM_LLVM_VERSION >= 110 diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 
ec8695a2a038..fd3e06b1c5c1 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -36,6 +36,7 @@ #include "../../tir/transforms/ir_utils.h" #include "literal/cuda_half_t.h" #include "ptx.h" +#include "cuda_vector_intrin.h" namespace tvm { namespace codegen { @@ -210,7 +211,11 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) // h4.w is emitted as *(half2*)(&(u2.y)).y // ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; - os << "uint" << lanes / 2; + if (lanes == 2) { + os << "half2"; + } else { + os << "uint" << lanes / 2; + } } else { fail = true; } @@ -251,7 +256,11 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) os << "nv_bfloat16"; } else if (lanes <= 8) { ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; - os << "uint" << lanes / 2; + if (lanes == 2) { + os << "nv_bfloat162"; + } else { + os << "uint" << lanes / 2; + } } else { fail = true; } @@ -425,6 +434,8 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Cannot convert type " << t << " to CUDA type"; } + + void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, std::ostream& os) { // NOLINT(*) // Delcare the result. diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index c6cf96d460d4..46ecaba16538 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -52,7 +52,7 @@ class CodeGenCUDA final : public CodeGenC { void PrintStorageSync(const CallNode* op) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, - std::ostream& os) final; // NOLINT(*) + std::ostream& os) final; // NOLINT(*) void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) void PrintVecElemLoad(const std::string& vec, DataType t, int i, std::ostream& os) final; // NOLINT(*) diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc index 95fbf7f1a513..47222e75003d 100644 --- a/src/target/source/intrin_rule_cuda.cc +++ b/src/target/source/intrin_rule_cuda.cc @@ -53,7 +53,13 @@ struct CUDAMath { return ""; } } else if (t.is_bfloat16()) { - return 'h' + name; + if (name == "fabs") { + return "habs"; + } else if (name == "round") { + return "hrint"; + } else { + return 'h' + name; + } } return ""; } diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h index c3a1c66e5874..bc6b627f38b8 100644 --- a/src/target/source/literal/cuda_half_t.h +++ b/src/target/source/literal/cuda_half_t.h @@ -357,6 +357,7 @@ CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf) CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf) CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf) CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erf) +CUDA_UNSUPPORTED_HALF_MATH_UNARY(habs, abs) #undef CUDA_UNSUPPORTED_HALF_MATH_BINARY #undef CUDA_UNSUPPORTED_HALF_MATH_UNARY diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index 74d881ab9e9a..a92ce020c94d 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -686,13 +686,7 @@ namespace transform { Pass BF16ComputeLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - bool target_support_bf16 = - ctx->GetConfig("tir.hardware_support_bf16", Bool(false)).value(); - if 
(target_support_bf16) { - return f; - } else { - return BF16ComputeLegalizer().Legalize(f); - } + return BF16ComputeLegalizer().Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.BF16ComputeLegalize", {}); } @@ -701,13 +695,7 @@ TVM_REGISTER_GLOBAL("tir.transform.BF16ComputeLegalize").set_body_typed(BF16Comp Pass BF16StorageLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - bool target_support_bf16 = - ctx->GetConfig("tir.hardware_support_bf16", Bool(false)).value(); - if (target_support_bf16) { - return f; - } else { - return BF16StorageLegalizer().Legalize(f); - } + return BF16StorageLegalizer().Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.BF16StorageLegalize", {}); } @@ -716,12 +704,7 @@ TVM_REGISTER_GLOBAL("tir.transform.BF16StorageLegalize").set_body_typed(BF16Stor Pass FP8ComputeLegalize(String promote_dtype_str) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - bool target_support_fp8 = ctx->GetConfig("tir.hardware_support_fp8", Bool(false)).value(); - if (target_support_fp8) { - return f; - } else { - return FP8ComputeLegalizer(DataType(String2DLDataType(promote_dtype_str))).Legalize(f); - } + return FP8ComputeLegalizer(DataType(String2DLDataType(promote_dtype_str))).Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.FP8ComputeLegalize", {}); } @@ -730,12 +713,7 @@ TVM_REGISTER_GLOBAL("tir.transform.FP8ComputeLegalize").set_body_typed(FP8Comput Pass FP8StorageLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - bool target_support_fp8 = ctx->GetConfig("tir.hardware_support_fp8", Bool(false)).value(); - if (target_support_fp8) { - return f; - } else { - return FP8StorageLegalizer().Legalize(f); - } + return FP8StorageLegalizer().Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.FP8StorageLegalize", {}); } diff --git a/try.py b/try.py deleted file mode 100644 index dd1be2cfebcf..000000000000 --- a/try.py +++ /dev/null @@ -1,23 +0,0 @@ -import tvm -from tvm.script import tir as T - -@T.prim_func -def vectorize_add_fp16(A: T.Buffer([128], "bfloat16"), B: T.Buffer([128], "bfloat16")) -> None: - - for i in range(128): - with T.block("blk"): - vi = T.axis.remap("S", [i]) - B[vi] = A[vi] + T.abs(A[vi]) - - -sch = tvm.tir.Schedule(vectorize_add_fp16, debug_mask="all") -blk = sch.get_block("blk") -i, = sch.get_loops(blk) -io, ii, v = sch.split(i, [None, 32, 2]) -sch.vectorize(v) -sch.bind(ii, "threadIdx.x") -sch.bind(io, "blockIdx.x") - -print(sch.mod["main"]) -f = tvm.build(sch.mod["main"], target="cuda") -print(f.imported_modules[0].get_source()) \ No newline at end of file From edc938afffe1dcc908893297f8258438656703d6 Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 29 Jun 2023 05:48:59 -0700 Subject: [PATCH 06/43] rm redundant files --- src/target/source/codegen_cuda.cc | 1 - src/target/source/cuda_vector_intrin.cc | 42 ------------------------- src/target/source/cuda_vector_intrin.h | 42 ------------------------- 3 files changed, 85 deletions(-) delete mode 100644 src/target/source/cuda_vector_intrin.cc delete mode 100644 src/target/source/cuda_vector_intrin.h diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index fd3e06b1c5c1..72b62557751b 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -36,7 +36,6 @@ #include "../../tir/transforms/ir_utils.h" #include "literal/cuda_half_t.h" #include "ptx.h" -#include "cuda_vector_intrin.h" namespace tvm { namespace codegen { diff --git 
a/src/target/source/cuda_vector_intrin.cc b/src/target/source/cuda_vector_intrin.cc deleted file mode 100644 index a3d957e42fcb..000000000000 --- a/src/target/source/cuda_vector_intrin.cc +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file cuda_vector_intrin.cc - * \brief Code generation with vector intrinsics in CUDA. - */ -#include "cuda_vector_intrin.h" - -namespace tvm { -namespace codegen { -namespace cuda { - -std::string PrintHalf2BinaryOp(const std::string& op, const std::string lhs, - const std::string rhs) { - return ""; -} - -std::string PrintNVBFloat162BinaryOp(const std::string& op, const std::string lhs, - const std::string rhs) { - return ""; -} - -} // namespace cuda -} // namespace codegen -} // namespace tvm diff --git a/src/target/source/cuda_vector_intrin.h b/src/target/source/cuda_vector_intrin.h deleted file mode 100644 index 9f636681c0b4..000000000000 --- a/src/target/source/cuda_vector_intrin.h +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file cuda_vector_intrin.h - * \brief Code generation with vector intrinsics in CUDA. 
- */ -#ifndef TVM_TARGET_SOURCE_CUDA_VECTOR_INTRIN_H_ -#define TVM_TARGET_SOURCE_CUDA_VECTOR_INTRIN_H_ - -#include - -namespace tvm { -namespace codegen { -namespace cuda { - -std::string PrintHalf2BinaryOp(const std::string& op, const std::string lhs, const std::string rhs); - -std::string PrintNVBFloat162BinaryOp(const std::string& op, const std::string lhs, - const std::string rhs); - -} // namespace cuda -} // namespace codegen -} // namespace tvm - -#endif // TVM_TARGET_SOURCE_CUDA_VECTOR_INTRIN_H_ \ No newline at end of file From 4ac20f172600d22219165f8db2b50922ca40f8e2 Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 29 Jun 2023 06:23:26 -0700 Subject: [PATCH 07/43] upd --- python/tvm/_ffi/runtime_ctypes.py | 1 + python/tvm/runtime/ndarray.py | 20 +++++++++++++------- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index adcc3a8e972c..2554098d2c43 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -108,6 +108,7 @@ class DataType(ctypes.Structure): "uint64": {"type_code": DataTypeCode.UINT, "bits": 64, "lanes": 1}, "e4m3_float8": {"type_code": DataTypeCode.E4M3Float, "bits": 8, "lanes": 1}, "e5m2_float8": {"type_code": DataTypeCode.E5M2Float, "bits": 8, "lanes": 1}, + "bfloat16": {"type_code": DataTypeCode.BFLOAT, "bits": 16, "lanes": 1}, "float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1}, "float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1}, "float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1}, diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index a78c68ee67c4..6857c5e2d958 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -167,18 +167,19 @@ def copyfrom(self, source_array): raise ValueError( f"array shape do not match the shape of NDArray {source_array.shape} vs {shape}" ) + numpy_str_map = DataType.NUMPY2STR np_dtype_str = ( numpy_str_map[source_array.dtype] if source_array.dtype in numpy_str_map else str(source_array.dtype) ) - if (not source_array.flags["C_CONTIGUOUS"]) or ( - dtype == "bfloat16" or dtype != np_dtype_str - ): - source_array = np.ascontiguousarray( - source_array, dtype="uint16" if dtype == "bfloat16" else dtype - ) + if (not source_array.flags["C_CONTIGUOUS"]) or dtype != np_dtype_str: + if dtype == "e4m3_float8": + dtype = "float8_e4m3fn" + elif dtype == "e5m2_float8": + dtype = "float8_e5m2" + source_array = np.ascontiguousarray(source_array, dtype) assert source_array.flags["C_CONTIGUOUS"] data = source_array.ctypes.data_as(ctypes.c_void_p) nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize) @@ -221,7 +222,12 @@ def numpy(self): if dtype == "int4": dtype = "int8" if dtype == "bfloat16": - dtype = "uint16" + if ml_dtypes is not None: + dtype = ml_dtypes.bfloat16 + else: + raise RuntimeError( + "ml_dtypes is not installed, cannot convert bfloat16 array to numpy." 
+ ) if dtype == "e4m3_float8": if ml_dtypes is not None: dtype = ml_dtypes.float8_e4m3fn From de6c5b297b57dea7ae4302454fbf7bf415438719 Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 29 Jun 2023 07:33:07 -0700 Subject: [PATCH 08/43] upd --- python/tvm/testing/utils.py | 5 ++ src/target/source/codegen_cuda.cc | 58 ++++++++++--------- src/target/source/codegen_cuda.h | 2 +- src/tir/op/op.cc | 9 ++- .../unittest/test_target_codegen_cuda.py | 41 +++++-------- 5 files changed, 56 insertions(+), 59 deletions(-) diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index f29b7c4394c2..92923a2ed290 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -114,6 +114,11 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7): compares the `abs(actual-desired)` with `atol+rtol*abs(desired)`. Since we often allow `desired` to be close to zero, we generally want non-zero `atol`. """ + # The ml_dtypes v0.2 is not compatible with allclose function, convert to float32. + if actual.dtype == "bfloat16": + actual = actual.astype("float32") + if desired.dtype == "bfloat16": + desired = desired.astype("float32") actual = np.asanyarray(actual) desired = np.asanyarray(desired) np.testing.assert_allclose(actual.shape, desired.shape) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 448e2d4a051d..bff30e65e0aa 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -199,6 +199,8 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) enable_fp16_ = true; if (t.is_scalar()) { os << "half"; + } else if (lanes == 2) { + os << "half2"; } else if (lanes <= 8) { // Emit CUDA code to access fp16 vector elements. // @@ -210,11 +212,7 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) // h4.w is emitted as *(half2*)(&(u2.y)).y // ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; - if (lanes == 2) { - os << "half2"; - } else { - os << "uint" << lanes / 2; - } + os << "uint" << lanes / 2; } else { fail = true; } @@ -253,13 +251,11 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) enable_bf16_ = true; if (t.is_scalar()) { os << "nv_bfloat16"; + } else if (lanes == 2) { + os << "nv_bfloat162"; } else if (lanes <= 8) { ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; - if (lanes == 2) { - os << "nv_bfloat162"; - } else { - os << "uint" << lanes / 2; - } + os << "uint" << lanes / 2; } else { fail = true; } @@ -433,8 +429,6 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Cannot convert type " << t << " to CUDA type"; } - - void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, std::ostream& os) { // NOLINT(*) // Delcare the result. 
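Note: the surrounding codegen_cuda.cc hunks map 2-lane float16/bfloat16 vectors onto CUDA's native half2 and nv_bfloat162 types (4/8 lanes still pack into uint2/uint4), and the broadcast hunk below emits make_half2 / __halves2bfloat162 for the 2-lane case. A minimal sketch for inspecting the generated CUDA, adapted from the throwaway try.py script added and then removed earlier in this series; it assumes a CUDA toolchain and a GPU with bfloat16 support:

    import tvm
    from tvm.script import tir as T

    @T.prim_func
    def vec_add_bf16(A: T.Buffer([128], "bfloat16"), B: T.Buffer([128], "bfloat16")) -> None:
        for i in range(128):
            with T.block("blk"):
                vi = T.axis.remap("S", [i])
                B[vi] = A[vi] + T.abs(A[vi])

    sch = tvm.tir.Schedule(vec_add_bf16)
    (i,) = sch.get_loops(sch.get_block("blk"))
    io, ii, v = sch.split(i, [None, 32, 2])  # inner factor 2 -> 2-lane bfloat16 vectors
    sch.vectorize(v)
    sch.bind(ii, "threadIdx.x")
    sch.bind(io, "blockIdx.x")

    f = tvm.build(sch.mod["main"], target="cuda")
    # The emitted source should use nv_bfloat162 rather than packing lanes into uint.
    print(f.imported_modules[0].get_source())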
@@ -1122,27 +1116,35 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO if (op->dtype.is_float16()) { std::string v = PrintExpr(op->value); - os << "make_"; - PrintType(op->dtype, os); - os << '('; - for (int i = 0; i < op->lanes / 2; ++i) { - if (i != 0) os << ", "; - os << "__pack_half2(" << v << ", " << v << ")"; + if (op->lanes == 2) { + os << "make_half2(" << v << ", " << v << ")"; + } else { + os << "make_"; + PrintType(op->dtype, os); + os << '('; + for (int i = 0; i < op->lanes / 2; ++i) { + if (i != 0) os << ", "; + os << "__pack_half2(" << v << ", " << v << ")"; + } + os << ')'; } - os << ')'; return; } if (op->dtype.is_bfloat16()) { std::string v = PrintExpr(op->value); - os << "make_"; - PrintType(op->dtype, os); - os << '('; - for (int i = 0; i < op->lanes / 2; ++i) { - if (i != 0) os << ", "; - os << "__pack_nv_bfloat162(" << v << ", " << v << ")"; + if (op->lanes == 2) { + os << "__halves2bfloat162(" << v << ", " << v << ")"; + } else { + os << "make_"; + PrintType(op->dtype, os); + os << "("; + for (int i = 0; i < op->lanes / 2; ++i) { + if (i != 0) os << ", "; + os << "__pack_nv_bfloat162(" << v << ", " << v << ")"; + } + os << ')'; } - os << ')'; return; } @@ -1402,7 +1404,7 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val os << '('; } if (i % 2 == 0) { - os << "__pack_half2(" << value; + os << "make_half2(" << value; } else { os << "," << value << ")"; if (i != t.lanes() - 1) { @@ -1421,7 +1423,7 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val os << '('; } if (i % 2 == 0) { - os << "__pack_bfloat162(" << value; + os << " __halves2bfloat162(" << value; } else { os << "," << value << ")"; if (i != t.lanes() - 1) { diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index 46ecaba16538..c6cf96d460d4 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -52,7 +52,7 @@ class CodeGenCUDA final : public CodeGenC { void PrintStorageSync(const CallNode* op) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, - std::ostream& os) final; // NOLINT(*) + std::ostream& os) final; // NOLINT(*) void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) void PrintVecElemLoad(const std::string& vec, DataType t, int i, std::ostream& os) final; // NOLINT(*) diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index da3392184eaf..d9c1d6ff8f72 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -32,9 +32,8 @@ #include // Centralized header for constant folders. 
#include "../../arith/const_fold.h" -#include "../../target/datatype/registry.h" #include "../../support/scalars.h" - +#include "../../target/datatype/registry.h" namespace tvm { @@ -215,7 +214,7 @@ PrimExpr max_value(const DataType& dtype, Span span) { } else if (dtype.is_float8()) { if (dtype.code() == DataType::kE4M3Float) { return FloatImm(dtype, support::kMaxE4M3, span); - } else { // E5M2 + } else { // E5M2 return FloatImm(dtype, support::kMaxE5M2, span); } } @@ -249,13 +248,13 @@ PrimExpr min_value(const DataType& dtype, Span span) { return FloatImm(dtype, std::numeric_limits::lowest(), span); } else if (dtype.bits() == 16) { return FloatImm(dtype, -support::kMaxFloat16, span); - } + } } else if (dtype.is_bfloat16()) { return FloatImm(dtype, -support::kMaxBFloat16, span); } else if (dtype.is_float8()) { if (dtype.code() == DataType::kE4M3Float) { return FloatImm(dtype, -support::kMaxE4M3, span); - } else { // E5M2 + } else { // E5M2 return FloatImm(dtype, -support::kMaxE5M2, span); } } diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index fa9474e73d88..ea145daf1a3f 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -24,6 +24,10 @@ from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16 import tvm.testing import pytest +try: + import ml_dtypes +except ImportError as e: + ml_dtypes = None tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") @@ -75,21 +79,11 @@ def test_cuda_bf16_vectorize_add(): if not have_bf16(tvm.cuda(0).compute_version): print("skip because gpu does not support bf16") return + if ml_dtypes is None: + print("skip because ml_dtypes not installed") + return num_thread = 8 - def np_float2np_bf16(arr): - """Convert a numpy array of float to a numpy array - of bf16 in uint16""" - orig = arr.view(" Date: Thu, 29 Jun 2023 07:35:09 -0700 Subject: [PATCH 09/43] do not change nvcc --- python/tvm/contrib/nvcc.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index bd36a325c813..2ef3cf9df4dd 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -99,9 +99,8 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target # However, we cannot use TVM_CXX_COMPILER_PATH because the runtime env. # Because it is hard to do runtime compiler detection, we require nvcc is configured # correctly by default. 
- cxx_compiler_path = os.environ.get("CUDAHOSTCXX", "") - if cxx_compiler_path != "": - cmd += ["-ccbin", cxx_compiler_path] + # if cxx_compiler_path != "": + # cmd += ["-ccbin", cxx_compiler_path] proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) From fd0ff79157f629e54f752f0e28d07a06e16e5bb9 Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 29 Jun 2023 07:37:04 -0700 Subject: [PATCH 10/43] fix --- tests/python/unittest/test_target_codegen_cuda.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index ea145daf1a3f..633d2335f51c 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -103,7 +103,7 @@ def check_cuda(n, lanes): c = tvm.nd.empty((n, lanes), "bfloat16", dev).copyfrom(c) tvm.testing.assert_allclose(c.numpy(), np_a + 1) - # check_cuda(64, 2) + check_cuda(64, 2) check_cuda(64, 4) check_cuda(64, 6) check_cuda(64, 8) From a0525b9d95ee2ca89c0592c668a768e9cc01bc25 Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 29 Jun 2023 07:38:45 -0700 Subject: [PATCH 11/43] remove empty line --- python/tvm/contrib/nvcc.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index 2ef3cf9df4dd..92bd4041dc66 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -101,7 +101,6 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target # correctly by default. # if cxx_compiler_path != "": # cmd += ["-ccbin", cxx_compiler_path] - proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) From 9c6d63938e93f5618b7950cd9372f9ead551124a Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 29 Jun 2023 07:47:40 -0700 Subject: [PATCH 12/43] fix --- python/tvm/contrib/nvcc.py | 2 +- src/target/source/codegen_cuda.cc | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index 92bd4041dc66..5eb348009914 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -100,7 +100,7 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target # Because it is hard to do runtime compiler detection, we require nvcc is configured # correctly by default. 
# if cxx_compiler_path != "": - # cmd += ["-ccbin", cxx_compiler_path] + # cmd += ["-ccbin", cxx_compiler_path] proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index bff30e65e0aa..87a8c29ccd27 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -1404,7 +1404,7 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val os << '('; } if (i % 2 == 0) { - os << "make_half2(" << value; + os << "__pack_half2(" << value; } else { os << "," << value << ")"; if (i != t.lanes() - 1) { @@ -1423,7 +1423,7 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val os << '('; } if (i % 2 == 0) { - os << " __halves2bfloat162(" << value; + os << "__pack_nv_bfloat162(" << value; } else { os << "," << value << ")"; if (i != t.lanes() - 1) { From bdc3382107b44e3d50afe1ecd5d0b153de507d96 Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 29 Jun 2023 09:46:20 -0700 Subject: [PATCH 13/43] lint --- tests/python/unittest/test_target_codegen_cuda.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index 633d2335f51c..90bdb0d75131 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -24,6 +24,7 @@ from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16 import tvm.testing import pytest + try: import ml_dtypes except ImportError as e: @@ -91,9 +92,7 @@ def check_cuda(n, lanes): xo, xi = s[B].split(B.op.axis[0], factor=num_thread) s[B].bind(xo, bx) s[B].bind(xi, tx) - with tvm.transform.PassContext( - config={"tir.target_support_bf16": True} - ): + with tvm.transform.PassContext(config={"tir.target_support_bf16": True}): fun = tvm.build(s, [A, B], "cuda") dev = tvm.cuda(0) np_a = np.random.uniform(size=(n, lanes)).astype("bfloat16") From a15b600dfa053c34c889d53dc45793c892fd24cd Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 29 Jun 2023 17:12:33 -0700 Subject: [PATCH 14/43] c++ lint --- src/driver/driver_api.cc | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index c8ecd3915cb1..41a6dfa22e9e 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -154,8 +154,10 @@ Array CreatePassList(bool disable_loop_partition) { bool enable_equiv_terms_in_cse_tir = pass_ctx->GetConfig("tir.enable_equiv_terms_in_cse_tir", Bool(false)).value(); bool ptx_ldg32 = pass_ctx->GetConfig("tir.ptx_ldg32", Bool(false)).value(); - bool target_support_bf16 = pass_ctx->GetConfig("tir.target_support_bf16", Bool(false)).value(); - bool target_support_fp8 = pass_ctx->GetConfig("tir.target_support_fp8", Bool(false)).value(); + bool target_support_bf16 = + pass_ctx->GetConfig("tir.target_support_bf16", Bool(false)).value(); + bool target_support_fp8 = + pass_ctx->GetConfig("tir.target_support_fp8", Bool(false)).value(); // Get any user-added passes Array> add_lower_pass = @@ -551,8 +553,10 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg, transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) { transform::PassContext pass_ctx = transform::PassContext::Current(); - bool target_support_bf16 = pass_ctx->GetConfig("tir.target_support_bf16", Bool(false)).value(); - bool target_support_fp8 = 
pass_ctx->GetConfig("tir.target_support_fp8", Bool(false)).value(); + bool target_support_bf16 = + pass_ctx->GetConfig("tir.target_support_bf16", Bool(false)).value(); + bool target_support_fp8 = + pass_ctx->GetConfig("tir.target_support_fp8", Bool(false)).value(); Array mixed_pass_list; From c33423ac397dd4acd19d6d37b3ef15c4c6d76270 Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 29 Jun 2023 20:18:52 -0700 Subject: [PATCH 15/43] use ml_dtypes for llvm codegen test --- .../unittest/test_target_codegen_llvm.py | 41 +++---------------- 1 file changed, 6 insertions(+), 35 deletions(-) diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 6a2f5573b274..aa20c6078f97 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -695,33 +695,6 @@ def _transform(f, *_): tvm.testing.assert_allclose(c_.numpy(), (a_.numpy() * 2).astype("int32")) -def np_float2np_bf16(arr): - """Convert a numpy array of float to a numpy array - of bf16 in uint16""" - orig = arr.view(" Date: Fri, 30 Jun 2023 20:41:33 -0700 Subject: [PATCH 16/43] add ml_dtypes to ci-constraints --- docker/python/ci-constraints.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/docker/python/ci-constraints.txt b/docker/python/ci-constraints.txt index 003c13170411..b8f94c112a9c 100644 --- a/docker/python/ci-constraints.txt +++ b/docker/python/ci-constraints.txt @@ -37,3 +37,4 @@ tflite = "==2.4.0" torch = "==1.11.0" torchvision = "==0.12.0+cpu" #xgboost = "==1.4.2" +ml_dtypes = "==0.1.0" From 40f088ccf7559ce834bb0cfac22b5aec7b5cf3b6 Mon Sep 17 00:00:00 2001 From: Zihao Date: Fri, 30 Jun 2023 20:42:24 -0700 Subject: [PATCH 17/43] alphabetical --- docker/python/ci-constraints.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/python/ci-constraints.txt b/docker/python/ci-constraints.txt index b8f94c112a9c..7336c0821654 100644 --- a/docker/python/ci-constraints.txt +++ b/docker/python/ci-constraints.txt @@ -17,6 +17,7 @@ flowvision = "==0.1.0" #h5py = "==3.1.0" keras = "==2.7" jinja2 = "==3.0.3" +ml_dtypes = "==0.1.0" mxnet = "==1.6.0" mypy = "==0.902" oneflow = "==0.7.0" @@ -37,4 +38,3 @@ tflite = "==2.4.0" torch = "==1.11.0" torchvision = "==0.12.0+cpu" #xgboost = "==1.4.2" -ml_dtypes = "==0.1.0" From 259f27813b2e741c266fbd61f9f73a6fe89fb9eb Mon Sep 17 00:00:00 2001 From: Zihao Date: Fri, 30 Jun 2023 21:03:24 -0700 Subject: [PATCH 18/43] pylint --- include/tvm/runtime/data_type.h | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index f2c7e09df369..53465708354f 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -108,9 +108,7 @@ class DataType { /*! \return whether type is a bfloat16 type. */ bool is_bfloat16() const { return code() == DataType::kBFloat && bits() == 16; } /*! \return whether type is a general floating point data type. */ - bool is_floating_point() const { - return is_float() || is_float8() || is_bfloat16(); - } + bool is_floating_point() const { return is_float() || is_float8() || is_bfloat16(); } /*! \return whether type is an int type. */ bool is_int() const { return code() == DataType::kInt; } /*! \return whether type is an uint type. 
*/ From aff8f610ff618667f1d06c24e30d7e5fd637220a Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 1 Jul 2023 00:08:37 -0700 Subject: [PATCH 19/43] lint --- include/tvm/tir/op.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 93738e0d773c..fd1db1c7ba49 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -939,8 +939,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) return LargeUIntImm(t, static_cast(low), static_cast(high), span); } } - if (t.is_floating_point()) - return FloatImm(t, static_cast(value), span); + if (t.is_floating_point()) return FloatImm(t, static_cast(value), span); // For now, we store const scalar values of custom datatypes within doubles; later, during the // datatypes lowering pass, we will lower the value to its true representation in the format // specified by the datatype. From d883e336a1ad1dda3371cdfa2971ba87e6e82834 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 1 Jul 2023 02:04:59 -0700 Subject: [PATCH 20/43] upd --- src/target/source/codegen_cuda.cc | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 87a8c29ccd27..2ce48f1f1224 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -482,9 +482,21 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i, os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))"; } } else if (t.is_float16()) { - os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; + if (t.lanes() == 2) { + // 2 * float16 is stored as half2, use v.x directly + os << vec << "." << access[i]; + } else { + // 4/8 * float16 is stored as uint2/4, use ((half2*)(&(v.x))).x + os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; + } } else if (t.is_bfloat16()) { - os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; + if (t.lanes() == 2) { + // 2 * bfloat16 is stored as nv_bfloat162, use v.x directly + os << vec << "." << access[i]; + } else { + // 4/8 * bfloat16 is stored as uint2/4, use ((nv_bfloat162*)(&(v.x))).x + os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; + } } else if (t.lanes() > 4 && t.lanes() <= 8) { std::string type_name; if (t.bits() == 16) { From 487c15f40282abb5fe2ccb7b932ceb96166cf479 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 1 Jul 2023 02:17:13 -0700 Subject: [PATCH 21/43] improve comments --- src/target/source/codegen_cuda.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 2ce48f1f1224..96adb3992ef7 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -483,18 +483,18 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i, } } else if (t.is_float16()) { if (t.lanes() == 2) { - // 2 * float16 is stored as half2, use v.x directly + // 2 * float16 is stored as half2, return v.x/y directly os << vec << "." << access[i]; } else { - // 4/8 * float16 is stored as uint2/4, use ((half2*)(&(v.x))).x + // 4/8 * float16 is stored as uint2/4, return ((half2*)(&(v.x/y/z/w)))->x/y os << "((half2*)(&(" << vec << "." 
<< access[i / 2] << ")))->" << access[i % 2]; } } else if (t.is_bfloat16()) { if (t.lanes() == 2) { - // 2 * bfloat16 is stored as nv_bfloat162, use v.x directly + // 2 * bfloat16 is stored as nv_bfloat162, return v.x/y directly os << vec << "." << access[i]; } else { - // 4/8 * bfloat16 is stored as uint2/4, use ((nv_bfloat162*)(&(v.x))).x + // 4/8 * bfloat16 is stored as uint2/4, return ((nv_bfloat162*)(&(v.x/y/z/w)))->x/y os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; } } else if (t.lanes() > 4 && t.lanes() <= 8) { From 73d636141a8cdb8a2325655a5e665553548bf654 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 1 Jul 2023 02:33:29 -0700 Subject: [PATCH 22/43] improved code comment --- src/target/source/codegen_cuda.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 96adb3992ef7..4022a7aaf993 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -483,18 +483,18 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i, } } else if (t.is_float16()) { if (t.lanes() == 2) { - // 2 * float16 is stored as half2, return v.x/y directly + // vec (2 * float16) is stored as half2, return vec.x/y directly os << vec << "." << access[i]; } else { - // 4/8 * float16 is stored as uint2/4, return ((half2*)(&(v.x/y/z/w)))->x/y + // vec (4/8 * float16) is stored as uint2/4, return ((half2*)(&(vec.x/y/z/w)))->x/y os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; } } else if (t.is_bfloat16()) { if (t.lanes() == 2) { - // 2 * bfloat16 is stored as nv_bfloat162, return v.x/y directly + // vec (2 * bfloat16) is stored as nv_bfloat162, return vec.x/y directly os << vec << "." << access[i]; } else { - // 4/8 * bfloat16 is stored as uint2/4, return ((nv_bfloat162*)(&(v.x/y/z/w)))->x/y + // vec (4/8 * bfloat16) is stored as uint2/4, return ((nv_bfloat162*)(&(vec.x/y/z/w)))->x/y os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; } } else if (t.lanes() > 4 && t.lanes() <= 8) { From 0d7d8ba52807e1f41cf93105f0109ac4f644a4d2 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 1 Jul 2023 12:09:21 -0700 Subject: [PATCH 23/43] upd --- python/tvm/testing/utils.py | 20 +++-- src/driver/driver_api.cc | 29 ++----- .../multi_level_tiling_tensor_core.cc | 2 + src/target/source/codegen_cuda.cc | 82 ++++++++++++------- src/target/tag.cc | 8 +- src/tir/transforms/dtype_conversion.h | 2 +- .../transforms/unsupported_dtype_legalize.cc | 63 ++++++++++++++ .../unittest/test_target_codegen_cuda.py | 82 ++++++++++--------- 8 files changed, 189 insertions(+), 99 deletions(-) diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 92923a2ed290..f50790d5f1dc 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -106,6 +106,17 @@ def test_something(): ) +def promote_bf16_to_fp32(x): + r"""Promote the data type of an array-like structure from bfloat16 to float32.""" + if isinstance(x, list): + return [promote_bf16_to_fp32(y) for y in x] + else: + if isinstance(x, np.ndarray) and x.dtype == "bfloat16": + return x.astype("float32") + else: + return x + + def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7): """Version of np.testing.assert_allclose with `atol` and `rtol` fields set in reasonable defaults. 
@@ -114,11 +125,10 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7): compares the `abs(actual-desired)` with `atol+rtol*abs(desired)`. Since we often allow `desired` to be close to zero, we generally want non-zero `atol`. """ - # The ml_dtypes v0.2 is not compatible with allclose function, convert to float32. - if actual.dtype == "bfloat16": - actual = actual.astype("float32") - if desired.dtype == "bfloat16": - desired = desired.astype("float32") + # The ml_dtypes v0.2 is not compatible with numpy's asanyarray function, promote to float32 first. + actual = promote_bf16_to_fp32(actual) + desired = promote_bf16_to_fp32(desired) + actual = np.asanyarray(actual) desired = np.asanyarray(desired) np.testing.assert_allclose(actual.shape, desired.shape) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 41a6dfa22e9e..f5eeda7243d7 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -55,8 +55,6 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer); TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_ldg32", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.target_support_bf16", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.target_support_fp8", Bool); // WARNING: May cause coherency issues resulting data miscompares // Experimental feature that, when enabled by the runtime, bypasses the cache when using DMA. When @@ -154,10 +152,6 @@ Array CreatePassList(bool disable_loop_partition) { bool enable_equiv_terms_in_cse_tir = pass_ctx->GetConfig("tir.enable_equiv_terms_in_cse_tir", Bool(false)).value(); bool ptx_ldg32 = pass_ctx->GetConfig("tir.ptx_ldg32", Bool(false)).value(); - bool target_support_bf16 = - pass_ctx->GetConfig("tir.target_support_bf16", Bool(false)).value(); - bool target_support_fp8 = - pass_ctx->GetConfig("tir.target_support_fp8", Bool(false)).value(); // Get any user-added passes Array> add_lower_pass = @@ -219,12 +213,6 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::TransformMmaBufferLayout()); pass_list.push_back(tir::transform::LowerOpaqueBlock()); pass_list.push_back(tir::transform::FlattenBuffer()); - if (!target_support_fp8) { - pass_list.push_back(tir::transform::FP8ComputeLegalize()); - } - if (!target_support_bf16) { - pass_list.push_back(tir::transform::BF16ComputeLegalize()); - } pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); @@ -553,10 +541,6 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg, transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) { transform::PassContext pass_ctx = transform::PassContext::Current(); - bool target_support_bf16 = - pass_ctx->GetConfig("tir.target_support_bf16", Bool(false)).value(); - bool target_support_fp8 = - pass_ctx->GetConfig("tir.target_support_fp8", Bool(false)).value(); Array mixed_pass_list; @@ -567,6 +551,10 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::BindTarget(target)); + // FP8/BF16 ComputeLegalize passes (after bind target) + mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize()); + mixed_pass_list.push_back(tir::transform::BF16ComputeLegalize()); + mixed_pass_list.push_back(tir::transform::VerifyMemory()); mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc()); @@ -607,13 +595,8 @@ transform::Sequential 
MixedModulePassManager(IRModule mixed_mod, Target target) } else { mixed_pass_list.push_back(tir::transform::MakePackedAPI()); } - if (!target_support_fp8) { - mixed_pass_list.push_back(tir::transform::FP8StorageLegalize()); - } - if (!target_support_bf16) { - mixed_pass_list.push_back(tir::transform::BF16StorageLegalize()); - } - + mixed_pass_list.push_back(tir::transform::FP8StorageLegalize()); + mixed_pass_list.push_back(tir::transform::BF16StorageLegalize()); mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch()); return transform::Sequential(mixed_pass_list); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index 18bd58510d85..5f53e8097eb4 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -605,6 +605,8 @@ std::vector MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore( const DataType& dtype = cache_read_buffer->dtype; if (dtype.is_float16()) { sch->StorageAlign(cache_read, 0, -2, 32, 8); + } else if (dtype.is_bfloat16()) { + sch->StorageAlign(cache_read, 0, -2, 32, 8); } else if (dtype.is_int() && dtype.bits() == 8) { sch->StorageAlign(cache_read, 0, -2, 32, 16); } else { diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 4022a7aaf993..fa3406f0f915 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -626,12 +626,18 @@ std::string CodeGenCUDA::CastFromTo(std::string value, DataType from, DataType t os << "(("; this->PrintType(target, os); os << ")"; - if (from.is_float16() && (target.is_int() || target.is_uint()) && target.bits() == 8) { + if ((from.is_float16() || from.is_bfloat16()) && (target.is_int() || target.is_uint()) && + target.bits() == 8) { + // use int/uint as intermediate data type os << "("; if (target.is_uint()) { os << "u"; } os << "int)"; + } else if ((from.is_bfloat16() && target.is_float16()) || + (from.is_float16() && target.is_bfloat16())) { + // use float as intermediate data type + os << "(float)"; } os << value << ")"; return os.str(); @@ -655,12 +661,9 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { std::string src = SSAGetID(PrintExpr(op->value), from_ty); for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) { std::ostringstream val; - val << "("; - PrintType(target_ty.element_of(), val); - val << ")("; PrintVecElemLoad(src, from_ty, i, val); - val << ")"; - PrintVecElemStore(sret, target_ty, i, val.str()); + std::string casted_val = CastFromTo(val.str(), from_ty.element_of(), target_ty.element_of()); + PrintVecElemStore(sret, target_ty, i, casted_val); } } os << sret; @@ -1146,7 +1149,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO if (op->dtype.is_bfloat16()) { std::string v = PrintExpr(op->value); if (op->lanes == 2) { - os << "__halves2bfloat162(" << v << ", " << v << ")"; + os << "make_bfloat162" << v << ", " << v << ")"; } else { os << "make_"; PrintType(op->dtype, os); @@ -1387,6 +1390,7 @@ void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const BufferLoad // Cast away volatile qualifier for fp16 types. That is, only loads and // stores are volatile. The loaded objects are not marked as volatile. 
// + // TODO(Zihao): figure out what it is if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) && IsVolatile(op->buffer->data.get())) { os << "("; PrintType(op->dtype, os); @@ -1410,38 +1414,58 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val } if (t.is_float16()) { - if (i == 0) { - os << "make_"; - PrintType(t, os); - os << '('; - } - if (i % 2 == 0) { - os << "__pack_half2(" << value; + if (t.lanes() == 2) { + // result data type is half2 + if (i == 0) { + os << "make_half2(" << value; + } else { + os << ", " << value << ")"; + } } else { - os << "," << value << ")"; - if (i != t.lanes() - 1) { - os << ","; + // result data type is uint2/4 + if (i == 0) { + os << "make_"; + PrintType(t, os); + os << '('; + } + if (i % 2 == 0) { + os << "__pack_half2(" << value; } else { - os << ")"; + os << "," << value << ")"; + if (i != t.lanes() - 1) { + os << ","; + } else { + os << ")"; + } } } return; } if (t.is_bfloat16()) { - if (i == 0) { - os << "make_"; - PrintType(t, os); - os << '('; - } - if (i % 2 == 0) { - os << "__pack_nv_bfloat162(" << value; + if (t.lanes() == 2) { + // result data type is nv_bfloat162 + if (i == 0) { + os << "make_bfloat162" << value; + } else { + os << ", " << value << ")"; + } } else { - os << "," << value << ")"; - if (i != t.lanes() - 1) { - os << ","; + // result data type is uint2/4 + if (i == 0) { + os << "make_"; + PrintType(t, os); + os << '('; + } + if (i % 2 == 0) { + os << "__pack_nv_bfloat162(" << value; } else { - os << ")"; + os << "," << value << ")"; + if (i != t.lanes() - 1) { + os << ","; + } else { + os << ")"; + } } } return; diff --git a/src/target/tag.cc b/src/target/tag.cc index 037d2e5937ca..3188c0169a3e 100644 --- a/src/target/tag.cc +++ b/src/target/tag.cc @@ -108,8 +108,8 @@ TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-xavier") // Naming convention for CUDA tags see https://developer.nvidia.com/cuda-gpus // Parameters see Table 15. 
Technical Specifications per Compute Capability // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html -// Check `Maximum y- or z-dimension of a grid of thread blocks` for max threads per block // Check `Maximum amount of shared memory per thread block` for max shared memory per block +// Check `Maximum number of 32-bit registers per thread block` for max registers per block // Note that above 48 KB requires dynamic shared memory TVM_REGISTER_CUDA_TAG("nvidia/tesla-k80", "sm_37", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/tesla-k40", "sm_35", 49152, 65536); @@ -219,6 +219,12 @@ TVM_REGISTER_CUDA_TAG("nvidia/nvidia-nvs-310", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvs-5400m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvs-5200m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvs-4200m", "sm_21", 49152, 32768); +TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-4090", "sm_89", 49152, 65536); +TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-4080", "sm_89", 49152, 65536); +TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-4070-ti", "sm_89", 49152, 65536); +TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-4070", "sm_89", 49152, 65536); +TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-4060-ti", "sm_89", 49152, 65536); +TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-4060", "sm_89", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3090-ti", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3090", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3080-ti", "sm_86", 49152, 65536); diff --git a/src/tir/transforms/dtype_conversion.h b/src/tir/transforms/dtype_conversion.h index b509abb9cd27..80a65b900520 100644 --- a/src/tir/transforms/dtype_conversion.h +++ b/src/tir/transforms/dtype_conversion.h @@ -99,7 +99,7 @@ class FloatConfig { * \return The FloatConfig class containing internal floating point representation. */ static FloatConfig FromDataType(DataType dtype) { - CHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8()) + CHECK(dtype.is_floating_point()) << "FloatConfig is only applicable to floating point data types, got " << dtype << " instead."; if (dtype.is_float()) { diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index a92ce020c94d..70f15d31c415 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -30,11 +30,54 @@ #include #include +#include "../../support/utils.h" #include "dtype_conversion.h" namespace tvm { namespace tir { +namespace { + +/*! + * \brief Return whether target has native BF16 support. + */ +bool TargetSupportBF16(const Target& target) { + if (target->kind->name == "cuda") { + auto cuda_arch = target->GetAttr("arch"); + if (cuda_arch.defined()) { + auto cuda_arch_str = cuda_arch.value(); + CHECK(support::StartsWith(cuda_arch_str, "sm_")) + << "Expect cuda arch to start with \"sm_\", got " << cuda_arch_str << " instead."; + if (std::stoi(std::string(cuda_arch_str).substr(3)) >= 80) { + // sm_80 or later + return true; + } + } + } + return false; +} + +/*! + * \brief Return whether target has native FP8 support. 
+ */ +bool TargetSupportFP8(const Target& target) { + if (target->kind->name == "cuda") { + auto cuda_arch = target->GetAttr("arch"); + if (cuda_arch.defined()) { + auto cuda_arch_str = cuda_arch.value(); + CHECK(support::StartsWith(cuda_arch_str, "sm_")) + << "Expect cuda arch to start with \"sm_\", got " << cuda_arch_str << " instead."; + if (std::stoi(std::string(cuda_arch_str).substr(3)) >= 89) { + // sm_89 or later + return true; + } + } + } + return false; +} + +} // namespace + // NOTE: do not touch buffer on function boundary // remap internal fp8/bf16 buffer to f32 if they meet the following condition // - constant allocation size @@ -686,6 +729,11 @@ namespace transform { Pass BF16ComputeLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto target = f->GetAttr(tvm::attr::kTarget); + if (target.defined() && TargetSupportBF16(target.value())) { + // Do not legalize bf16 compute if target has native bf16 support. + return f; + } return BF16ComputeLegalizer().Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.BF16ComputeLegalize", {}); @@ -695,6 +743,11 @@ TVM_REGISTER_GLOBAL("tir.transform.BF16ComputeLegalize").set_body_typed(BF16Comp Pass BF16StorageLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto target = f->GetAttr(tvm::attr::kTarget); + if (target.defined() && TargetSupportBF16(target.value())) { + // Do not legalize bf16 storage if target has native bf16 support. + return f; + } return BF16StorageLegalizer().Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.BF16StorageLegalize", {}); @@ -704,6 +757,11 @@ TVM_REGISTER_GLOBAL("tir.transform.BF16StorageLegalize").set_body_typed(BF16Stor Pass FP8ComputeLegalize(String promote_dtype_str) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto target = f->GetAttr(tvm::attr::kTarget); + if (target.defined() && TargetSupportFP8(target.value())) { + // Do not legalize fp8 compute if target has native fp8 support. + return f; + } return FP8ComputeLegalizer(DataType(String2DLDataType(promote_dtype_str))).Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.FP8ComputeLegalize", {}); @@ -713,6 +771,11 @@ TVM_REGISTER_GLOBAL("tir.transform.FP8ComputeLegalize").set_body_typed(FP8Comput Pass FP8StorageLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto target = f->GetAttr(tvm::attr::kTarget); + if (target.defined() && TargetSupportFP8(target.value())) { + // Do not legalize fp8 storage if target has native fp8 support. 
+ return f; + } return FP8StorageLegalizer().Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.FP8StorageLegalize", {}); diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index 90bdb0d75131..4b34d20250e1 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -24,11 +24,7 @@ from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16 import tvm.testing import pytest - -try: - import ml_dtypes -except ImportError as e: - ml_dtypes = None +import ml_dtypes tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") @@ -43,6 +39,9 @@ def check_cuda(dtype, n, lanes): if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version): print("Skip because gpu does not have fp16 support") return + if dtype == "bfloat16" and not have_bf16(tvm.cuda(0).compute_version): + print("skip because gpu does not support bf16") + return if dtype == "int8" and not have_int8(tvm.cuda(0).compute_version): print("skip because gpu does not support int8") return @@ -72,40 +71,10 @@ def check_cuda(dtype, n, lanes): check_cuda("float16", 64, 4) check_cuda("float16", 64, 6) check_cuda("float16", 64, 8) - - -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda -def test_cuda_bf16_vectorize_add(): - if not have_bf16(tvm.cuda(0).compute_version): - print("skip because gpu does not support bf16") - return - if ml_dtypes is None: - print("skip because ml_dtypes not installed") - return - num_thread = 8 - - def check_cuda(n, lanes): - A = te.placeholder((n,), name="A", dtype="bfloat16x%d" % lanes) - B = te.compute((n,), lambda i: A[i] + tvm.tir.const(1, A.dtype), name="B") - s = te.create_schedule(B.op) - xo, xi = s[B].split(B.op.axis[0], factor=num_thread) - s[B].bind(xo, bx) - s[B].bind(xi, tx) - with tvm.transform.PassContext(config={"tir.target_support_bf16": True}): - fun = tvm.build(s, [A, B], "cuda") - dev = tvm.cuda(0) - np_a = np.random.uniform(size=(n, lanes)).astype("bfloat16") - a = tvm.nd.empty((n,), A.dtype, dev).copyfrom(np_a) - c = tvm.nd.empty((n,), B.dtype, dev) - fun(a, c) - c = tvm.nd.empty((n, lanes), "bfloat16", dev).copyfrom(c) - tvm.testing.assert_allclose(c.numpy(), np_a + 1) - - check_cuda(64, 2) - check_cuda(64, 4) - check_cuda(64, 6) - check_cuda(64, 8) + check_cuda("bfloat16", 64, 2) + check_cuda("bfloat16", 64, 4) + check_cuda("bfloat16", 64, 6) + check_cuda("bfloat16", 64, 8) @tvm.testing.requires_gpu @@ -468,11 +437,13 @@ def check(device, dtype, m=32, n=32): b_nd = tvm.nd.array(b_np, dev) g_nd = tvm.nd.array(np.zeros(g_np.shape, dtype=g_np.dtype), dev) func(a_nd, b_nd, g_nd) + print(g_nd, g_np) tvm.testing.assert_allclose(g_nd.numpy(), g_np, rtol=1e-3) check("cuda", "float32") check("rocm", "float32") check("cuda", "float16") + # check("cuda", "bfloat16") ignored because of rounding errors @tvm.testing.requires_gpu @@ -486,6 +457,9 @@ def check(device, dtype, m=32, n=32): if dtype == "float16" and not have_fp16(dev.compute_version): print("Skip because gpu does not have fp16 support") return + if dtype == "bfloat16" and not have_bf16(dev.compute_version): + print("skip because gpu does not support bf16") + return a = tvm.te.placeholder((m, n), name="a", dtype=dtype) b = topi.sum(a) @@ -504,6 +478,7 @@ def check(device, dtype, m=32, n=32): check("cuda", "float32") check("rocm", "float32") check("cuda", "float16") + # check("cuda", "bfloat16") ignored because of rounding errors @tvm.testing.requires_gpu @@ -565,6 +540,9 @@ def check(t0, t1, 
factor): if (t0 == "float16" or t1 == "float16") and not have_fp16(tvm.cuda(0).compute_version): print("Skip because gpu does not have fp16 support") return + if (t0 == "bfloat16" or t1 == "bfloat16") and not have_bf16(tvm.cuda(0).compute_version): + print("skip because gpu does not support bf16") + return # compute n = 128 @@ -577,6 +555,7 @@ def check(t0, t1, factor): ob, ib = s[C].split(s[C].op.axis[0], factor=factor) s[C].vectorize(ib) s[C].bind(ob, tx) + func = tvm.build(s, [A, B, C], "cuda") # correctness @@ -603,6 +582,7 @@ def skip(t0, t1): types_4 = [ "float16", "float32", + "bfloat16", "int8", "uint8", "int16", @@ -613,7 +593,17 @@ def skip(t0, t1): "int64", "uint64", ] - types_8 = ["float16", "float32", "int8", "uint8", "int16", "uint16", "int32", "uint32"] + types_8 = [ + "float16", + "float32", + "bfloat16", + "int8", + "uint8", + "int16", + "uint16", + "int32", + "uint32", + ] for t0, t1 in [(x, y) for x in types_4 for y in types_4 if not skip(x, y)]: check(t0, t1, 4) for t0, t1 in [(x, y) for x in types_8 for y in types_8 if not skip(x, y)]: @@ -662,6 +652,9 @@ def run_test(tvm_intrin, np_func, dtype): if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version): print("Skip because gpu does not have fp16 support") return + if dtype == "bfloat16" and not have_bf16(tvm.cuda(0).compute_version): + print("skip because gpu does not support bf16") + return # set of intrinsics does not support fp16 yet. skip_set = { tvm.tir.abs, @@ -675,6 +668,9 @@ def run_test(tvm_intrin, np_func, dtype): if dtype == "float16" and tvm_intrin in skip_set: print("Skip because '{0}' does not support fp16 yet".format(tvm_intrin.__name__)) return + if dtype == "bfloat16" and tvm_intrin in skip_set: + print("Skip because '{0}' does not support bf16 yet".format(tvm_intrin.__name__)) + return n = 128 A = te.placeholder((n,), dtype=dtype, name="A") @@ -684,12 +680,14 @@ def run_test(tvm_intrin, np_func, dtype): dev = tvm.cuda(0) a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), dev) b = tvm.nd.array(np.zeros(shape=(n,)).astype(A.dtype), dev) + f(a, b) tvm.testing.assert_allclose(b.numpy(), np_func(a.numpy()), atol=1e-3, rtol=1e-3) for func in test_funcs: run_test(*func, "float32") run_test(*func, "float16") + run_test(*func, "bfloat16") @tvm.testing.requires_gpu @@ -751,6 +749,9 @@ def check_cuda(dtype, n, l, padding, lanes): if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version): print("Skip because gpu does not have fp16 support") return + if dtype == "bfloat16" and not have_bf16(tvm.cuda(0).compute_version): + print("skip because gpu does not support bf16") + return dev = tvm.cuda(0) A = tvm.te.placeholder((n, l), name="A", dtype=dtype) @@ -786,6 +787,7 @@ def check_cuda(dtype, n, l, padding, lanes): check_cuda("int32", 64, 16, 3, 4) check_cuda("float16", 64, 16, 3, 4) check_cuda("float32", 64, 16, 3, 4) + check_cuda("bfloat16", 64, 16, 3, 4) def vcf_check_common(s, args): From 4f4e8a6b5d82afa69e505bf029ac2774e3082da7 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 1 Jul 2023 12:11:08 -0700 Subject: [PATCH 24/43] bugfix --- src/target/source/codegen_cuda.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index fa3406f0f915..63b3572e3939 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -1149,7 +1149,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO if (op->dtype.is_bfloat16()) { std::string v = 
PrintExpr(op->value); if (op->lanes == 2) { - os << "make_bfloat162" << v << ", " << v << ")"; + os << "make_bfloat162(" << v << ", " << v << ")"; } else { os << "make_"; PrintType(op->dtype, os); @@ -1446,7 +1446,7 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val if (t.lanes() == 2) { // result data type is nv_bfloat162 if (i == 0) { - os << "make_bfloat162" << value; + os << "make_bfloat162(" << value; } else { os << ", " << value << ")"; } From 2162ebaa1b8d2519d6fec81269c450dcf078dffb Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 1 Jul 2023 19:18:29 -0700 Subject: [PATCH 25/43] bugfix --- src/tir/transforms/unsupported_dtype_legalize.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index 70f15d31c415..cd942c5ba68e 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -366,7 +366,8 @@ class ComputeLegalizer : public StmtExprMutator { DataType legalized_dtype = new_buf->dtype.with_lanes(index_lanes * buffer_lanes); value = CastTargetToDType(value, legalized_dtype); } - if (value.dtype() != new_buf->dtype) { + + if (value.dtype().element_of() != new_buf->dtype.element_of()) { // this happens when buffer get rewritten to f32 // but values remain as fp8/bf16 ICHECK(MatchDType(value->dtype)); From 31355fa7144d702c7a2b400ae75a568a6f414cd1 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 1 Jul 2023 19:20:51 -0700 Subject: [PATCH 26/43] lint --- python/tvm/testing/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index f50790d5f1dc..b4a5ada1f0ff 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -125,7 +125,8 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7): compares the `abs(actual-desired)` with `atol+rtol*abs(desired)`. Since we often allow `desired` to be close to zero, we generally want non-zero `atol`. """ - # The ml_dtypes v0.2 is not compatible with numpy's asanyarray function, promote to float32 first. + # The ml_dtypes v0.2 is not compatible with numpy's asanyarray function, promote to + # float32 first. actual = promote_bf16_to_fp32(actual) desired = promote_bf16_to_fp32(desired) From daaae7107cefb7ca35581561c348c6954c37b9c6 Mon Sep 17 00:00:00 2001 From: Zihao Date: Mon, 3 Jul 2023 01:56:55 -0700 Subject: [PATCH 27/43] refactor buildprocess --- include/tvm/driver/driver_api.h | 11 +++++- python/tvm/driver/build_module.py | 64 +++++++++++++++++++++++++------ src/driver/driver_api.cc | 44 ++++++++++++++------- 3 files changed, 92 insertions(+), 27 deletions(-) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index fffcab49667c..ce498af4ce3a 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -50,11 +50,14 @@ using tvm::transform::Pass; /*! * \brief Configures and returns the composite Pass for the fused module (pre split) that contains * device and host code. + * * \param mixed_mod The original mixed module. * \param target The device Target. + * \param apply_lower_passes Whether to apply lowering passes. * \return The composite Pass for the fused module. // */ -TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target); +TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target, + bool apply_lower_passes); /*! 
* \brief Configures and returns the composite Pass for the device Target after device/host from @@ -140,6 +143,12 @@ TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array& args, IRModule ScheduleToModule(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds, GlobalVarSupply global_var_supply); + +TVM_DLL runtime::Module IRModuleToRuntimeModule(const Map& inputs_arg, + const Target& target_host_arg, bool apply_lower_passes) { + +} + /*! * \brief Build a device and host module for a specific target from an IRModule. * \param funcs The functions to be built. diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 9389e7fbee60..4d7dc609147d 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -71,8 +71,8 @@ def schedule_to_module( """According to the given schedule, form a function. This is a low-level function intended for testing purposes, and - does not apply any optimization passes. In general, `tvm.lower` - and `tvm.build` should be used instead. + does not apply any optimization passes. In general, `tvm.build` + should be used instead. Parameters ---------- @@ -91,6 +91,47 @@ def schedule_to_module( return ffi.schedule_to_module(sch, args, name, binds) +def as_ir_module( + inp: Union[te.Schedule, PrimFunc, IRModule], + args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, + name: str = "main", + binds: Optional[Mapping[tensor.Tensor, Buffer]] = None, +) -> IRModule: + """Convert input to IRModule. + + Parameters + ---------- + inp : Union[tvm.te.schedule.Schedule, tvm.tir.PrimFunc, IRModule] + The input TE schedule or TensorIR PrimFunc/IRModule. + + args : Optional[List[Union[tvm.tir.Buffer, tensor.Tensor, Var]]] + The argument lists to the function for TE schedule. + It should be None if :attr:`inp` is a TensorIR PrimFunc/IRModule. + + name : str + The name of the result function. + + binds : Optional[Mapping[tensor.Tensor, tvm.tir.Buffer]] + Dictionary that maps the Tensor to Buffer which specified the data layout + requirement of the function. By default, a new compact buffer is created + for each tensor in the argument. + + Returns + ------- + m : IRModule + The result IRModule. + """ + if isinstance(inp, IRModule): + return inp + if isinstance(inp, PrimFunc): + return IRModule({name: inp.with_attr("global_symbol", name)}) + if isinstance(inp, te.Schedule): + return schedule_to_module(inp, args, name, binds) + raise ValueError( + f"Expected input to be an IRModule, PrimFunc or te.Schedule, but got {type(inp)}" + ) + + def lower( inp: Union[te.Schedule, PrimFunc, IRModule], args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, @@ -100,6 +141,8 @@ def lower( ) -> IRModule: """Lowering step before build into target. + Warning(legacy): This function is maintained for backward compatibility, please use :func:`build` directly. + Parameters ---------- inp : Union[tvm.te.schedule.Schedule, tvm.tir.PrimFunc, IRModule] @@ -199,8 +242,7 @@ def build( B = te.placeholder((n,), name='B') C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') s = tvm.te.create_schedule(C.op) - m = tvm.lower(s, [A, B, C], name="test_add") - rt_mod = tvm.build(m, target="llvm") + rt_mod = tvm.build(s, target="llvm") 2. it is a dict of compilation target to IRModule. 
@@ -213,9 +255,7 @@ def build( s1 = tvm.te.create_schedule(C.op) with tvm.target.cuda() as cuda_tgt: s2 = topi.cuda.schedule_injective(cuda_tgt, [C]) - m1 = tvm.lower(s1, [A, B, C], name="test_add1") - m2 = tvm.lower(s2, [A, B, C], name="test_add2") - rt_mod = tvm.build({"llvm": m1, "cuda": m2}) + rt_mod = tvm.build({"llvm": s1, "cuda": s2}) Note ---- @@ -224,16 +264,16 @@ def build( if isinstance(inputs, te.Schedule): if args is None: raise ValueError("args must be given for build from schedule") - input_mod = lower(inputs, args, name=name, binds=binds) + input_mod = as_ir_module(inputs, args, name=name, binds=binds) elif isinstance(inputs, (list, tuple, container.Array)): merged_mod = tvm.IRModule({}) for x in inputs: - merged_mod.update(lower(x)) + merged_mod.update(as_ir_module(x)) input_mod = merged_mod elif isinstance(inputs, PrimFunc): - input_mod = lower(inputs, name=name) + input_mod = as_ir_module(inputs, name=name) elif isinstance(inputs, tvm.IRModule): - input_mod = lower(inputs) + input_mod = as_ir_module(inputs) elif not isinstance(inputs, (dict, container.Map)): raise ValueError( f"Inputs must be te.Schedule, IRModule, PrimFunc, " @@ -278,7 +318,7 @@ def build( annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) - rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host) + rt_mod_host = _driver_ffi.ir_module_to_runtime_module(annotated_mods, target_host, True) annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index f5eeda7243d7..9b932fd12ca9 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -213,6 +213,8 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::TransformMmaBufferLayout()); pass_list.push_back(tir::transform::LowerOpaqueBlock()); pass_list.push_back(tir::transform::FlattenBuffer()); + pass_list.push_back(tir::transform::FP8ComputeLegalize()); + pass_list.push_back(tir::transform::BF16ComputeLegalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); @@ -402,13 +404,14 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule") * device and host. Then it also applies transformations on the new splitted modules. 
*/ std::pair SplitMixedModule(IRModule mod_mixed, const Target& target_arg, - const Target& target_host_arg) { + const Target& target_host_arg, + bool apply_lower_passes) { Target target = target_arg, target_host = target_host_arg; CheckAndUpdateHostConsistency(&target, &target_host); ICHECK(mod_mixed.defined()) << "This module must be defined"; - mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target)); + mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target, apply_lower_passes)); IRModule host_mod = ApplyPasses(mod_mixed, HostModulePassManager(mod_mixed, target_host)); @@ -427,8 +430,8 @@ std::pair SplitMixedModule(IRModule mod_mixed, const Target& return {host_mod, device_mod}; } -runtime::Module TIRToRuntime(const Map& inputs_arg, - const Target& target_host_arg) { +runtime::Module IRModuleToRuntimeModule(const Map& inputs_arg, + const Target& target_host_arg, bool apply_lower_passes) { std::vector device_modules; Map inputs = inputs_arg; Target target_host = target_host_arg; @@ -464,7 +467,7 @@ runtime::Module TIRToRuntime(const Map& inputs_arg, if (it.second.defined()) { const Target& target = it.first; const IRModule& ir_module = it.second; - auto pair = SplitMixedModule(ir_module, target, target_host); + auto pair = SplitMixedModule(ir_module, target, target_host, apply_lower_passes); auto& host_mod = pair.first; auto& device_mod = pair.second; @@ -501,6 +504,17 @@ runtime::Module TIRToRuntime(const Map& inputs_arg, return mhost; } +TVM_REGISTER_GLOBAL("driver.ir_module_to_runtime_module") + .set_body_typed([](const Map& inputs_arg, Target host_target, + bool apply_lower_passes) { + return IRModuleToRuntimeModule(inputs_arg, host_target, apply_lower_passes); + }); + +runtime::Module TIRToRuntime(const Map& inputs_arg, + const Target& target_host_arg) { + return IRModuleToRuntimeModule(inputs_arg, target_host_arg, false); +} + TVM_REGISTER_GLOBAL("driver.tir_to_runtime") .set_body_typed([](const Map& inputs_arg, Target host_target) { return TIRToRuntime(inputs_arg, host_target); @@ -539,22 +553,24 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg, return TIRToRuntime(inputs, target_host); } -transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) { +transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target, + bool apply_te_passes = false) { transform::PassContext pass_ctx = transform::PassContext::Current(); Array mixed_pass_list; + mixed_pass_list.push_back(tir::transform::BindTarget(target)); + if (apply_te_passes) { + for (auto&& pass : CreatePassList(false)) { + mixed_pass_list.push_back(pass); + } + } + // VerifyVTCMLimit must occur before LowerVtcmAlloc mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(target)); // LowerVtcmAlloc must occur after any transformations that modify memory allocation locations mixed_pass_list.push_back(tir::transform::LowerVtcmAlloc()); - mixed_pass_list.push_back(tir::transform::BindTarget(target)); - - // FP8/BF16 ComputeLegalize passes (after bind target) - mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize()); - mixed_pass_list.push_back(tir::transform::BF16ComputeLegalize()); - mixed_pass_list.push_back(tir::transform::VerifyMemory()); mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc()); @@ -603,8 +619,8 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) } TVM_REGISTER_GLOBAL("driver.mixed_mod_passes") - .set_body_typed([](IRModule mixed_mod, Target target) { - return 
MixedModulePassManager(mixed_mod, target); + .set_body_typed([](IRModule mixed_mod, Target target, bool apply_lower_passes) { + return MixedModulePassManager(mixed_mod, target, apply_lower_passes); }); transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host) { From af72e1e76730046fb7c035f2ba66163f21e94e4f Mon Sep 17 00:00:00 2001 From: Zihao Date: Mon, 3 Jul 2023 02:14:02 -0700 Subject: [PATCH 28/43] remove unused functions --- include/tvm/driver/driver_api.h | 5 ----- 1 file changed, 5 deletions(-) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index ce498af4ce3a..9c582e26badd 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -144,11 +144,6 @@ IRModule ScheduleToModule(te::Schedule sch, const Array& args, const const std::unordered_map& binds, GlobalVarSupply global_var_supply); -TVM_DLL runtime::Module IRModuleToRuntimeModule(const Map& inputs_arg, - const Target& target_host_arg, bool apply_lower_passes) { - -} - /*! * \brief Build a device and host module for a specific target from an IRModule. * \param funcs The functions to be built. From 98cb6e4829628d6f199ca1ec952da7eec2b8843d Mon Sep 17 00:00:00 2001 From: Zihao Date: Mon, 3 Jul 2023 04:48:51 -0700 Subject: [PATCH 29/43] pylint --- python/tvm/driver/build_module.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 4d7dc609147d..b92c6f15d497 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -141,7 +141,8 @@ def lower( ) -> IRModule: """Lowering step before build into target. - Warning(legacy): This function is maintained for backward compatibility, please use :func:`build` directly. + (Warning) This function is obsolete and maintained for backward compatibility with + legacy TE Schedule, please use :func:`build` directly. 
Parameters ---------- From aae82088b6d06fd409687f4c22caeb0a50543f5f Mon Sep 17 00:00:00 2001 From: Zihao Date: Mon, 3 Jul 2023 06:39:59 -0700 Subject: [PATCH 30/43] import error --- python/tvm/runtime/ndarray.py | 5 +---- tests/python/unittest/test_target_codegen_llvm.py | 1 + 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index 6857c5e2d958..84b2026f2a6a 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -19,11 +19,8 @@ import ctypes import warnings import numpy as np +import ml_dtypes -try: - import ml_dtypes -except ImportError: - ml_dtypes = None import tvm._ffi from tvm._ffi.base import _LIB, check_call, c_array, string_types, _FFI_MODE diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index aa20c6078f97..c1a78759faa5 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -22,6 +22,7 @@ import pytest import re import sys +import ml_dtypes import tvm import tvm.testing From 4fc069e1687aa0c83b9552591309bfad255b8076 Mon Sep 17 00:00:00 2001 From: Zihao Date: Mon, 3 Jul 2023 09:14:34 -0700 Subject: [PATCH 31/43] add ml_dtypes to build-environment.yaml --- conda/build-environment.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/conda/build-environment.yaml b/conda/build-environment.yaml index a1b43eb6ef0c..78128e07406e 100644 --- a/conda/build-environment.yaml +++ b/conda/build-environment.yaml @@ -37,3 +37,4 @@ dependencies: - make - scipy - pillow + - ml_dtypes From 3eae70dd9f64853b60fdadbf8fcb15e7803f810d Mon Sep 17 00:00:00 2001 From: Zihao Date: Mon, 3 Jul 2023 09:19:51 -0700 Subject: [PATCH 32/43] update docker scripts --- docker/install/ubuntu2004_install_python_package.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docker/install/ubuntu2004_install_python_package.sh b/docker/install/ubuntu2004_install_python_package.sh index 10c9b680c680..ff262c9cb05c 100644 --- a/docker/install/ubuntu2004_install_python_package.sh +++ b/docker/install/ubuntu2004_install_python_package.sh @@ -43,4 +43,5 @@ pip3 install --upgrade \ junitparser==2.4.2 \ six \ tornado \ - pytest-lazy-fixture + pytest-lazy-fixture \ + ml_dtypes From f1b24f9b04020af0440b5ce5bde3044051a4ef7c Mon Sep 17 00:00:00 2001 From: Zihao Date: Mon, 3 Jul 2023 19:55:13 -0700 Subject: [PATCH 33/43] bugfix --- conda/build-environment.yaml | 3 +++ python/tvm/relay/frontend/common.py | 1 + python/tvm/relay/frontend/pytorch.py | 2 ++ src/support/scalars.h | 2 +- tests/python/frontend/onnx/test_forward.py | 12 ++------- tests/python/frontend/pytorch/test_forward.py | 9 ++++--- tests/python/relay/test_op_level2.py | 26 ------------------- 7 files changed, 15 insertions(+), 40 deletions(-) diff --git a/conda/build-environment.yaml b/conda/build-environment.yaml index 78128e07406e..507105ab92ce 100644 --- a/conda/build-environment.yaml +++ b/conda/build-environment.yaml @@ -38,3 +38,6 @@ dependencies: - scipy - pillow - ml_dtypes + - pip + - pip: + - ml_dtypes diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 0433d3b52ebf..b8d31cb8c201 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -21,6 +21,7 @@ import numpy as np import tvm +import ml_dtypes from tvm.ir import IRModule from tvm.topi.utils import get_const_tuple diff --git a/python/tvm/relay/frontend/pytorch.py 
b/python/tvm/relay/frontend/pytorch.py index 5e4d75599613..c32d8ee263b8 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -4375,6 +4375,8 @@ def _convert_data_type(input_type, default_dtype=None): return "float32" elif input_type in ["half", "float16", "torch.float16"]: return "float16" + elif input_type in ["bfloat16", "torch.bfloat16"]: + return "bfloat16" elif input_type in ["long", "int64", "torch.int64"]: return "int64" elif input_type in ["int", "int32", "torch.int32"]: diff --git a/src/support/scalars.h b/src/support/scalars.h index 2b34914565ed..79ca53f60444 100644 --- a/src/support/scalars.h +++ b/src/support/scalars.h @@ -67,7 +67,7 @@ constexpr double kMaxFloat16 = 65504.0; // 2^127 * (1 + 127/128) // See https://en.wikipedia.org/wiki/Bfloat16_floating-point_format -constexpr double kMaxBFloat16 = 3.895313892515354759047080037148786688e38; +constexpr double kMaxBFloat16 = 3.3895313892515354759047080037148786688e38; // 2^8 * (1 + 6/8) // See https://arxiv.org/pdf/2209.05433.pdf diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 72265e49818c..7d4cfca757c4 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -29,6 +29,7 @@ import pytest import scipy import numpy as np +import ml_dtypes import tvm import tvm.testing @@ -99,13 +100,6 @@ def get_tvm_output_with_vm( freeze_params=freeze_params, convert_config=convert_config, ) - # handle the bfloat16 so we explicitly allocate - # bfloat16 arrays as input - for i, param in enumerate(mod["main"].params): - if param.type_annotation.dtype == "bfloat16": - input_data[i] = tvm.nd.empty(input_data[i].shape, "bfloat16").copyfrom( - input_data[i] - ) if validate_structural_equal: with tvm.testing.enable_span_filling(): @@ -5594,9 +5588,7 @@ def test_onnx_nodes(target, dev, onnx_test): atol = 1e-4 if "to_BFLOAT16" in test_dir: - # the tolerance here is for the comparison in uint16 space, but is not as significant - # of a delta in bfloat16 space because it's representing the mantissa being off by 1 - atol = 1 + atol = 1e-2 if "_sce_" in test_dir: # complicated loss functions like SoftmaxCrossEntropy can have minor variations diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 83930d1ea80b..f525005a9d90 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -24,6 +24,7 @@ import pytest import numpy as np +import ml_dtypes import torch from torch.nn import Module @@ -2694,7 +2695,10 @@ def verify_model_vm(input_model, ishapes, idtype=None, idata=None, targets=None) # Inference for name, inp in zip(input_names, input_data): - params[name] = inp.numpy() + if inp.dtype == torch.bfloat16: + params[name] = inp.float().numpy().astype("bfloat16") + else: + params[name] = inp.numpy() vm_res = evaluator(**params) # Baseline result @@ -3992,8 +3996,7 @@ def forward(self, arg): verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.float64) verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.float32) verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.float16) - # todo(dvisnty): Run the test for bfloat16 when full bfloat16 support is implemented - # verify_script_model(IsFloatingPoint(), [(1,1)], targets, idtype=torch.bfloat16) + verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.bfloat16) 
verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.int64) verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.int32) verify_script_model(IsFloatingPoint(), [(1, 1)], targets, idtype=torch.int16) diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 0a0ae561ab73..15cd92d0450d 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -2022,22 +2022,6 @@ def test_conv2d_rocm_sdot4(): np.testing.assert_equal(out, ref) -def np_float2tvm_bf16(arr): - """Convert a numpy array of float to a TVM array - of bf16""" - orig = arr.view(" Date: Tue, 4 Jul 2023 02:26:08 -0700 Subject: [PATCH 34/43] import ml_dtypes for all --- python/tvm/__init__.py | 1 + python/tvm/_ffi/runtime_ctypes.py | 12 +++------- python/tvm/relay/frontend/common.py | 1 - python/tvm/runtime/ndarray.py | 22 ------------------- tests/python/frontend/onnx/test_forward.py | 1 - tests/python/frontend/pytorch/test_forward.py | 1 - .../unittest/test_target_codegen_cuda.py | 2 +- .../unittest/test_target_codegen_llvm.py | 1 - 8 files changed, 5 insertions(+), 36 deletions(-) diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 5b6fbe7b2546..af8aab9779a6 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -20,6 +20,7 @@ import sys import os import traceback +import ml_dtypes # top-level alias # tvm._ffi diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 80be3a56fc58..3c1cb4c4d0b9 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -21,10 +21,6 @@ import numpy as np -try: - import ml_dtypes -except ImportError: - ml_dtypes = None from .base import _LIB, check_call tvm_shape_index_t = ctypes.c_int64 @@ -96,6 +92,9 @@ class DataType(ctypes.Structure): np.dtype(np.float32): "float32", np.dtype(np.float64): "float64", np.dtype(np.float_): "float64", + np.dtype("bfloat16"): "bfloat16", + np.dtype("float8_e4m3fn"): "e4m3_float8", + np.dtype("float8_e5m2"): "e5m2_float8", } STR2DTYPE = { "void": {"type_code": DataTypeCode.HANDLE, "bits": 0, "lanes": 0}, @@ -204,11 +203,6 @@ def __ne__(self, other): return not self.__eq__(other) -if ml_dtypes is not None: - DataType.NUMPY2STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16" - DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "e4m3_float8" - DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] = "e5m2_float8" - RPC_SESS_MASK = 128 diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index b8d31cb8c201..0433d3b52ebf 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -21,7 +21,6 @@ import numpy as np import tvm -import ml_dtypes from tvm.ir import IRModule from tvm.topi.utils import get_const_tuple diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index 84b2026f2a6a..aa273b2b667f 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -19,7 +19,6 @@ import ctypes import warnings import numpy as np -import ml_dtypes import tvm._ffi @@ -218,27 +217,6 @@ def numpy(self): dtype = str(t) if dtype == "int4": dtype = "int8" - if dtype == "bfloat16": - if ml_dtypes is not None: - dtype = ml_dtypes.bfloat16 - else: - raise RuntimeError( - "ml_dtypes is not installed, cannot convert bfloat16 array to numpy." 
- ) - if dtype == "e4m3_float8": - if ml_dtypes is not None: - dtype = ml_dtypes.float8_e4m3fn - else: - raise RuntimeError( - "ml_dtypes is not installed, cannot convert e4m3_float8 array to numpy." - ) - if dtype == "e5m2_float8": - if ml_dtypes is not None: - dtype = ml_dtypes.float8_e5m2 - else: - raise RuntimeError( - "ml_dtypes is not installed, cannot convert e5m2_float8 array to numpy." - ) np_arr = np.empty(shape, dtype=dtype) assert np_arr.flags["C_CONTIGUOUS"] data = np_arr.ctypes.data_as(ctypes.c_void_p) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 7d4cfca757c4..79e773073dbc 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -29,7 +29,6 @@ import pytest import scipy import numpy as np -import ml_dtypes import tvm import tvm.testing diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index f525005a9d90..96e1d8194d67 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -24,7 +24,6 @@ import pytest import numpy as np -import ml_dtypes import torch from torch.nn import Module diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index 4b34d20250e1..b4690e553cb2 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -24,7 +24,7 @@ from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16 import tvm.testing import pytest -import ml_dtypes + tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index c1a78759faa5..aa20c6078f97 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -22,7 +22,6 @@ import pytest import re import sys -import ml_dtypes import tvm import tvm.testing From 6e1c7cb4c651630a454976754ccc94943144ca93 Mon Sep 17 00:00:00 2001 From: Zihao Date: Tue, 4 Jul 2023 02:41:17 -0700 Subject: [PATCH 35/43] fix bug --- conda/build-environment.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/conda/build-environment.yaml b/conda/build-environment.yaml index 507105ab92ce..220c084a5374 100644 --- a/conda/build-environment.yaml +++ b/conda/build-environment.yaml @@ -37,7 +37,6 @@ dependencies: - make - scipy - pillow - - ml_dtypes - pip - pip: - ml_dtypes From c3415cea1ba8d331df14d27040e8ae3edb8feacf Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 8 Jul 2023 01:40:14 -0700 Subject: [PATCH 36/43] fix --- docker/install/ubuntu2004_install_python_package.sh | 4 ++-- tests/python/relay/test_op_level2.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docker/install/ubuntu2004_install_python_package.sh b/docker/install/ubuntu2004_install_python_package.sh index ff262c9cb05c..46e243e0cd64 100644 --- a/docker/install/ubuntu2004_install_python_package.sh +++ b/docker/install/ubuntu2004_install_python_package.sh @@ -43,5 +43,5 @@ pip3 install --upgrade \ junitparser==2.4.2 \ six \ tornado \ - pytest-lazy-fixture \ - ml_dtypes + pytest-lazy-fixture + diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 15cd92d0450d..87f5148af0d0 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -2053,8 +2053,8 @@ def get_subgraph(dtype): 
for t in ["float32", "bfloat16"]: mod = tvm.IRModule.from_expr(get_subgraph(t)) - data_np = np.random.uniform(1, 10, d_shape).astype("float32") - weight_np = np.random.uniform(1, 10, size=w_shape).astype("float32") + data_np = np.random.uniform(1, 10, d_shape).astype(t) + weight_np = np.random.uniform(1, 10, size=w_shape).astype(t) ref = tvm.topi.testing.conv2d_nchw_python(data_np, weight_np, strides, padding) target = "llvm -mcpu=skylake-avx512 -libs=dnnl" From b79c9d7e9501809b11bef75636d2b90aca7a7c9e Mon Sep 17 00:00:00 2001 From: Zihao Date: Sun, 16 Jul 2023 15:20:42 -0700 Subject: [PATCH 37/43] resolve issues --- python/tvm/relay/frontend/onnx.py | 2 +- python/tvm/topi/arm_cpu/injective.py | 2 +- python/tvm/topi/nn/winograd_util.py | 1 - src/arith/rewrite_simplify.cc | 2 +- src/auto_scheduler/feature.cc | 20 ++++++++++---------- src/autotvm/touch_extractor.h | 10 +++++----- src/ir/expr.cc | 3 +-- src/relay/op/nn/nn.cc | 3 +-- src/relay/transforms/to_mixed_precision.cc | 11 +++++------ src/target/source/codegen_cuda.cc | 1 - tests/python/frontend/onnx/test_forward.py | 8 ++++++++ tests/python/relay/test_op_level2.py | 13 ++++++------- tests/python/unittest/test_tir_imm_values.py | 9 ++++++++- 13 files changed, 47 insertions(+), 38 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 65e308a257e4..cfe25485fd17 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -115,7 +115,7 @@ def get_type(elem_type): raise ImportError(f"Unable to import TensorProto from onnx {e}") # Onnx mapping converts bfloat16 to float16 because - # numpy does not have a bfloat16 data type. However, + # onnx does not have a bfloat16 data type. However, # tvm has one, so we force the return type to be bfloat16 if elem_type == int(TensorProto.BFLOAT16): return "bfloat16" diff --git a/python/tvm/topi/arm_cpu/injective.py b/python/tvm/topi/arm_cpu/injective.py index 5c63e5a513db..5b4c61e1367f 100644 --- a/python/tvm/topi/arm_cpu/injective.py +++ b/python/tvm/topi/arm_cpu/injective.py @@ -68,7 +68,7 @@ def schedule_injective(outs): if list(s[x].op.axis): # do not vectorize for broadcast - dtype = "uint16" if x.dtype == "bfloat16" else x.dtype + dtype = x.dtype (io, ii) = s[x].split(list(s[x].op.axis)[-1], 16 // np.dtype(dtype).itemsize) s[x].vectorize(ii) tvm.te.schedule.AutoInlineInjective(s) diff --git a/python/tvm/topi/nn/winograd_util.py b/python/tvm/topi/nn/winograd_util.py index 8c2f50d7f8a6..77b0dbca824f 100644 --- a/python/tvm/topi/nn/winograd_util.py +++ b/python/tvm/topi/nn/winograd_util.py @@ -169,7 +169,6 @@ def winograd_transform_matrices(tile_size, kernel_size, out_dtype): intp_pts = _interpolation_points(degree) A_data, B_data, G_data = _cook_toom_convolution(intp_pts, tile_size, kernel_size) - out_dtype = "uint16" if out_dtype == "bfloat16" else out_dtype return ( const_matrix(A_data.astype(out_dtype), "A"), const_matrix(B_data.astype(out_dtype), "B"), diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 40088fd963d7..572517f8eaaf 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -621,7 +621,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { // x / 2.0 = x * 0.5 if (const FloatImmNode* ptr = op->b.as()) { - ICHECK(op->dtype.is_float() || op->dtype.is_bfloat16() || + ICHECK(op->dtype.is_floating_point() || datatype::Registry::Global()->GetTypeRegistered(op->dtype.code())); return op->a * make_const(op->b.dtype(), 1.0 / ptr->value); } diff 
--git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc index 65cc13eb61fc..c81667b54549 100644 --- a/src/auto_scheduler/feature.cc +++ b/src/auto_scheduler/feature.cc @@ -248,14 +248,14 @@ int64_t GetLoopExtent(const ForNode* node, const Analyzer& ana) { // Count math ops in an expr class MathOpCounter : public StmtExprVisitor { public: -#define VisitBinary(Type, float_ct, int_ct) \ - void VisitExpr_(const Type* op) final { \ - if (op->a.dtype().is_float() || op->a.dtype().is_bfloat16()) { \ - float_ct += op->a.dtype().lanes(); \ - } else { \ - int_ct += op->a.dtype().lanes(); \ - } \ - StmtExprVisitor::VisitExpr_(op); \ +#define VisitBinary(Type, float_ct, int_ct) \ + void VisitExpr_(const Type* op) final { \ + if (op->a.dtype().is_floating_point()) { \ + float_ct += op->a.dtype().lanes(); \ + } else { \ + int_ct += op->a.dtype().lanes(); \ + } \ + StmtExprVisitor::VisitExpr_(op); \ } VisitBinary(AddNode, float_addsub, int_addsub); @@ -301,13 +301,13 @@ class MathOpCounter : public StmtExprVisitor { effect_kind == CallEffectKind::kPure || effect_kind == CallEffectKind::kExprAnnotation; if (is_pure) { - if (op->dtype.is_float() || op->dtype.is_bfloat16()) { + if (op->dtype.is_floating_point()) { float_math_func++; } else { int_math_func++; } } else { - if (op->dtype.is_float() || op->dtype.is_bfloat16()) { + if (op->dtype.is_floating_point()) { float_other_func++; } else { int_other_func++; diff --git a/src/autotvm/touch_extractor.h b/src/autotvm/touch_extractor.h index 83260e1e0633..d6e8f7cf6053 100644 --- a/src/autotvm/touch_extractor.h +++ b/src/autotvm/touch_extractor.h @@ -87,35 +87,35 @@ class TouchExtractor : public FeatureVisitor { // arithmetic stats void VisitExpr_(const AddNode* op) final { - if (op->dtype.is_float() || op->dtype.is_bfloat16()) { + if (op->dtype.is_floating_point()) { itervar_map[itervar_stack_.back()].add_ct++; } FeatureVisitor::VisitExpr_(op); } void VisitExpr_(const SubNode* op) final { - if (op->dtype.is_float() || op->dtype.is_bfloat16()) { + if (op->dtype.is_floating_point()) { itervar_map[itervar_stack_.back()].add_ct++; } FeatureVisitor::VisitExpr_(op); } void VisitExpr_(const MulNode* op) final { - if (op->dtype.is_float() || op->dtype.is_bfloat16()) { + if (op->dtype.is_floating_point()) { itervar_map[itervar_stack_.back()].mul_ct++; } FeatureVisitor::VisitExpr_(op); } void VisitExpr_(const DivNode* op) final { - if (op->dtype.is_float() || op->dtype.is_bfloat16()) { + if (op->dtype.is_floating_point()) { itervar_map[itervar_stack_.back()].div_ct++; } FeatureVisitor::VisitExpr_(op); } void VisitExpr_(const ModNode* op) final { - if (op->dtype.is_float() || op->dtype.is_bfloat16()) { + if (op->dtype.is_floating_point()) { itervar_map[itervar_stack_.back()].div_ct++; } FeatureVisitor::VisitExpr_(op); diff --git a/src/ir/expr.cc b/src/ir/expr.cc index fdd8c2cd8bc5..ed4051813ed7 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -104,8 +104,7 @@ TVM_REGISTER_NODE_TYPE(IntImmNode); FloatImm::FloatImm(DataType dtype, double value, Span span) { ICHECK_EQ(dtype.lanes(), 1) << "ValueError: FloatImm can only take scalar."; - ICHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || - dtype.code() >= DataType::kCustomBegin) + ICHECK(dtype.is_floating_point() || dtype.code() >= DataType::kCustomBegin) << "ValueError: FloatImm supports only float, but " << dtype << " was supplied."; // check range for float32 and float16 since they have specified range. 
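Aside: the removed checks above spell out dtype.is_float() || dtype.is_bfloat16() (plus dtype.is_float8() in the FloatImm constructor), and dtype.is_floating_point() is substituted for that whole union. A minimal Python-side sketch of what the relaxed FloatImm check is expected to permit follows; it is illustrative only and not taken from the patch, and the large literal matches the bfloat16 max added to test_tir_imm_values.py later in this series:

import tvm

# Largest finite bfloat16 value (same literal the series adds to
# test_tir_imm_values.py).
BF16_MAX = 3.3895313892515354759047080037148786688e38

# With the relaxed ICHECK in expr.cc, FloatImm should accept bfloat16
# scalar immediates directly.
x = tvm.tir.FloatImm("bfloat16", 3.14)
y = tvm.tir.FloatImm("bfloat16", BF16_MAX)
print(x.dtype, x.value)
print(y.dtype, y.value)
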
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 9e2fe63b006a..832aac17b2b5 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -1286,8 +1286,7 @@ bool NLLLossRel(const Array& types, int num_inputs, const Attrs& attrs, << ", weights shape = " << weights->shape); return false; } - if (!(predictions->dtype == weights->dtype && - (predictions->dtype.is_float() || predictions->dtype.is_bfloat16()))) { + if (!(predictions->dtype == weights->dtype && (predictions->dtype.is_floating_point()))) { reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) << "NLLLossRel: predictions and weights should" << " be of the same floating type."); diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc index 820bc6e58e4d..0b86c3703481 100644 --- a/src/relay/transforms/to_mixed_precision.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -161,7 +161,7 @@ class MixedPrecisionPass : public MixedModeMutator { */ DataType cur_type = (attrs->out_dtype); ObjectPtr new_attrs = make_object(*attrs); - if (cur_type.is_float() || cur_type.is_bfloat16() || cur_type.is_void()) { + if (cur_type.is_floating_point() || cur_type.is_void()) { new_attrs->out_dtype = accumulation_dtype; } return Attrs(new_attrs); @@ -177,7 +177,7 @@ class MixedPrecisionPass : public MixedModeMutator { */ DataType cur_type = (attrs->dtype); ObjectPtr new_attrs = make_object(*attrs); - if (cur_type.is_float() || cur_type.is_bfloat16() || cur_type.is_void()) { + if (cur_type.is_floating_point() || cur_type.is_void()) { new_attrs->dtype = accumulation_dtype; } return Attrs(new_attrs); @@ -202,8 +202,7 @@ class MixedPrecisionPass : public MixedModeMutator { If ignore_non_float, then ignore non-floating types. */ if (const TensorTypeNode* tensor_type = t.as()) { - bool is_supported_floating_point_type = - (tensor_type->dtype).is_float() || (tensor_type->dtype).is_bfloat16(); + bool is_supported_floating_point_type = tensor_type->dtype.is_floating_point(); return (ignore_non_float && !is_supported_floating_point_type) || tensor_type->dtype == mixed_precision_type_; } else if (const TupleTypeNode* tuple_type = t.as()) { @@ -220,7 +219,7 @@ class MixedPrecisionPass : public MixedModeMutator { /* Cast tensor to the wanted datatype, returning a cached version if it's already been done. */ // If this is not a floating point type, do not cast. E.g. it might be an integer - if (!(expr_dtype.is_float() || expr_dtype.is_bfloat16())) { + if (!(expr_dtype.is_floating_point())) { return expr; } @@ -302,7 +301,7 @@ class MixedPrecisionPass : public MixedModeMutator { original_dtype_.push_back((root_->checked_type_).as()->dtype); } } - if (!(mixed_precision_type_.is_float() || mixed_precision_type_.is_bfloat16())) { + if (!(mixed_precision_type_.is_floating_point())) { LOG(FATAL) << "Only support IEEE floating point mixed precision types and bfloat16, but got " << mixed_precision_type_; } diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 63b3572e3939..9a78c98994eb 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -1390,7 +1390,6 @@ void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const BufferLoad // Cast away volatile qualifier for fp16 types. That is, only loads and // stores are volatile. The loaded objects are not marked as volatile. 
// - // TODO(Zihao): figure out what it is if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) && IsVolatile(op->buffer->data.get())) { os << "("; PrintType(op->dtype, os); diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 79e773073dbc..5a47706b0000 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -99,6 +99,14 @@ def get_tvm_output_with_vm( freeze_params=freeze_params, convert_config=convert_config, ) + # handle the bfloat16 so we explicitly allocate + # bfloat16 arrays as input + for i, param in enumerate(mod["main"].params): + if param.type_annotation.dtype == "bfloat16": + # cast uint16 to bloat16 + input_data[i] = tvm.nd.empty(input_data[i].shape, "bfloat16").copyfrom( + input_data[i].view("bfloat16") + ) if validate_structural_equal: with tvm.testing.enable_span_filling(): diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 87f5148af0d0..042104fd0dd1 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -2083,8 +2083,8 @@ def test_conv2d_nhwc_dnnl(): built with dnnl=ON" ) return - d_shape = (1, 56, 56, 64) - w_shape = (3, 3, 64, 64) + d_shape = (1, 56, 56, 32) + w_shape = (3, 3, 32, 64) padding = (1, 1) strides = (1, 1) @@ -2108,8 +2108,8 @@ def get_subgraph(dtype): for t in ["float32", "bfloat16"]: mod = tvm.IRModule.from_expr(get_subgraph(t)) - data_np = np.random.uniform(1, 10, d_shape).astype("float32") - weight_np = np.random.uniform(1, 10, size=w_shape).astype("float32") + data_np = np.random.uniform(0, 10, size=d_shape).astype(t) / 10 + weight_np = np.random.uniform(0, 10, size=w_shape).astype(t) / 10 ref = tvm.topi.testing.conv2d_nhwc_python(data_np, weight_np, strides, padding) target = "llvm -mcpu=skylake-avx512 -libs=dnnl" @@ -2121,13 +2121,12 @@ def get_subgraph(dtype): runtime.set_input("data", data_np) runtime.run() - out = runtime.get_output(0).numpy() if t == "bfloat16": - np.testing.assert_allclose(out, ref, rtol=1e-2) + np.testing.assert_allclose(out, ref, rtol=3e-1) else: - np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3) def _test_conv2d_int8_alter_dtype(data_dtype, target, dot_product_instrs): diff --git a/tests/python/unittest/test_tir_imm_values.py b/tests/python/unittest/test_tir_imm_values.py index 416943c85da6..90a65a281583 100644 --- a/tests/python/unittest/test_tir_imm_values.py +++ b/tests/python/unittest/test_tir_imm_values.py @@ -107,7 +107,14 @@ def compare_float_value(value, expect, msg): "dtype, literals", [ ["float16", [-65504.0, 3.14, 65504.0, np.inf, np.nan]], - ["bfloat16", [-3.38953139e38, 3.38953139e38, 3.14]], + [ + "bfloat16", + [ + -3.3895313892515354759047080037148786688e38, + 3.3895313892515354759047080037148786688e38, + 3.14, + ], + ], ["float32", [np.finfo("float32").min, 3.14, np.finfo("float32").max, np.inf, np.nan]], ["float64", [np.finfo("float64").min, 3.14, np.finfo("float64").max, np.inf, np.nan]], ], From e079c28cf184be70d211c1043cf0f049216d1483 Mon Sep 17 00:00:00 2001 From: Zihao Date: Mon, 17 Jul 2023 02:52:09 -0700 Subject: [PATCH 38/43] fix tests --- python/tvm/contrib/nvcc.py | 2 +- python/tvm/runtime/ndarray.py | 2 +- tests/python/frontend/onnx/test_forward.py | 8 +++++--- tests/python/relay/test_op_level2.py | 14 +++++++------- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index 
5eb348009914..6dc0503acc94 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -100,7 +100,7 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target # Because it is hard to do runtime compiler detection, we require nvcc is configured # correctly by default. # if cxx_compiler_path != "": - # cmd += ["-ccbin", cxx_compiler_path] + cmd += ["-ccbin", "/home/expye/g++11/bin"] proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index aa273b2b667f..3ad80393f143 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -175,7 +175,7 @@ def copyfrom(self, source_array): dtype = "float8_e4m3fn" elif dtype == "e5m2_float8": dtype = "float8_e5m2" - source_array = np.ascontiguousarray(source_array, dtype) + source_array = np.ascontiguousarray(source_array).view(dtype) assert source_array.flags["C_CONTIGUOUS"] data = source_array.ctypes.data_as(ctypes.c_void_p) nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 5a47706b0000..083a89b6b3f9 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -103,9 +103,8 @@ def get_tvm_output_with_vm( # bfloat16 arrays as input for i, param in enumerate(mod["main"].params): if param.type_annotation.dtype == "bfloat16": - # cast uint16 to bloat16 input_data[i] = tvm.nd.empty(input_data[i].shape, "bfloat16").copyfrom( - input_data[i].view("bfloat16") + input_data[i] ) if validate_structural_equal: @@ -5558,7 +5557,10 @@ def _load_proto(proto_filename, target_list, model_type_proto): elif model_type_proto.HasField("tensor_type"): tensor = onnx.TensorProto() tensor.ParseFromString(protobuf_content) - target_list.append(numpy_helper.to_array(tensor)) + np_tensor = numpy_helper.to_array(tensor) + if model_type_proto.tensor_type.elem_type == TensorProto.BFLOAT16: + np_tensor = np_tensor.view("bfloat16") + target_list.append(np_tensor) elif model_type_proto.HasField("optional_type"): optional = onnx.OptionalProto() optional.ParseFromString(protobuf_content) diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 042104fd0dd1..1973926f2031 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -1540,7 +1540,7 @@ def test_batch_flatten(): ref_res = batch_flatten(data) for target, dev in tvm.testing.enabled_targets(): op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data) - np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=0.01) + tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=0.01) def _test_upsampling(layout, method, align_corners=False): @@ -2019,7 +2019,7 @@ def test_conv2d_rocm_sdot4(): data_np.astype("int32"), weight_np.astype("int32"), strides, padding ) - np.testing.assert_equal(out, ref) + tvm.testing.assert_equal(out, ref) @tvm.testing.requires_x86 @@ -2070,9 +2070,9 @@ def get_subgraph(dtype): out = runtime.get_output(0).numpy() if t == "bfloat16": - np.testing.assert_allclose(out, ref, rtol=1e-2) + tvm.testing.assert_allclose(out, ref, rtol=3e-1) else: - np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5) + tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5) @tvm.testing.requires_x86 @@ -2124,9 +2124,9 @@ def get_subgraph(dtype): out = runtime.get_output(0).numpy() if t 
== "bfloat16": - np.testing.assert_allclose(out, ref, rtol=3e-1) + tvm.testing.assert_allclose(out, ref, rtol=3e-1) else: - np.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3) + tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3) def _test_conv2d_int8_alter_dtype(data_dtype, target, dot_product_instrs): @@ -2197,7 +2197,7 @@ def get_conv2d_nchw( out = rt_mod.get_output(0).numpy() - np.testing.assert_equal(out, ref) + tvm.testing.assert_equal(out, ref) @tvm.testing.requires_arm_dot From 81ee32eb3ce7a1182afdfded2be705801fecd5c9 Mon Sep 17 00:00:00 2001 From: Zihao Date: Mon, 17 Jul 2023 02:54:14 -0700 Subject: [PATCH 39/43] revert changes in nvcc.py --- python/tvm/contrib/nvcc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index 6dc0503acc94..5eb348009914 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -100,7 +100,7 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target # Because it is hard to do runtime compiler detection, we require nvcc is configured # correctly by default. # if cxx_compiler_path != "": - cmd += ["-ccbin", "/home/expye/g++11/bin"] + # cmd += ["-ccbin", cxx_compiler_path] proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) From 524191f8165b69b3fdb99bc7ed3c8f8790efe448 Mon Sep 17 00:00:00 2001 From: Zihao Date: Mon, 17 Jul 2023 02:57:26 -0700 Subject: [PATCH 40/43] fix lint --- docker/install/ubuntu2004_install_python_package.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/docker/install/ubuntu2004_install_python_package.sh b/docker/install/ubuntu2004_install_python_package.sh index 46e243e0cd64..10c9b680c680 100644 --- a/docker/install/ubuntu2004_install_python_package.sh +++ b/docker/install/ubuntu2004_install_python_package.sh @@ -44,4 +44,3 @@ pip3 install --upgrade \ six \ tornado \ pytest-lazy-fixture - From abbe2ef300e367b07aa8351543dd78e9e2ecfa9b Mon Sep 17 00:00:00 2001 From: Zihao Date: Mon, 17 Jul 2023 11:41:41 -0700 Subject: [PATCH 41/43] fix copyfrom semantics --- python/tvm/runtime/ndarray.py | 2 +- tests/python/frontend/onnx/test_forward.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index 3ad80393f143..aa273b2b667f 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -175,7 +175,7 @@ def copyfrom(self, source_array): dtype = "float8_e4m3fn" elif dtype == "e5m2_float8": dtype = "float8_e5m2" - source_array = np.ascontiguousarray(source_array).view(dtype) + source_array = np.ascontiguousarray(source_array, dtype) assert source_array.flags["C_CONTIGUOUS"] data = source_array.ctypes.data_as(ctypes.c_void_p) nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 083a89b6b3f9..2583a0b4fc77 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -104,7 +104,7 @@ def get_tvm_output_with_vm( for i, param in enumerate(mod["main"].params): if param.type_annotation.dtype == "bfloat16": input_data[i] = tvm.nd.empty(input_data[i].shape, "bfloat16").copyfrom( - input_data[i] + input_data[i].view("bfloat16") ) if validate_structural_equal: From 1f7c54a55733fa6e2dd0768cc3cf7661d1ef487b Mon Sep 17 00:00:00 2001 From: Zihao Date: Mon, 17 Jul 2023 17:26:25 -0700 Subject: [PATCH 42/43] use numpy's impl for assert 
equal --- tests/python/relay/test_op_level2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 1973926f2031..01f8c9439283 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -2019,7 +2019,7 @@ def test_conv2d_rocm_sdot4(): data_np.astype("int32"), weight_np.astype("int32"), strides, padding ) - tvm.testing.assert_equal(out, ref) + np.testing.assert_equal(out, ref) @tvm.testing.requires_x86 @@ -2197,7 +2197,7 @@ def get_conv2d_nchw( out = rt_mod.get_output(0).numpy() - tvm.testing.assert_equal(out, ref) + np.testing.assert_equal(out, ref) @tvm.testing.requires_arm_dot From 32298ea8a665d3b07b54221f25f111b111d7f445 Mon Sep 17 00:00:00 2001 From: Zihao Date: Mon, 24 Jul 2023 04:55:59 -0700 Subject: [PATCH 43/43] test on tlcpack-staging docker images --- ci/jenkins/docker-images.ini | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ci/jenkins/docker-images.ini b/ci/jenkins/docker-images.ini index a5c95533e971..3ce43131357d 100644 --- a/ci/jenkins/docker-images.ini +++ b/ci/jenkins/docker-images.ini @@ -21,9 +21,9 @@ ci_arm: tlcpack/ci-arm:20230615-060132-62a5e7acf ci_cortexm: tlcpack/ci-cortexm:20230613-060122-21361a63a ci_cpu: tlcpack/ci-cpu:20230604-060130-0af9ff90e ci_gpu: tlcpack/ci-gpu:20230504-142417-4d37a0a0 -ci_hexagon: tlcpack/ci-hexagon:20230504-142417-4d37a0a0 -ci_i386: tlcpack/ci-i386:20230504-142417-4d37a0a0 +ci_hexagon: tlcpackstaging/ci_hexagon:20230724-060135-684689e92 +ci_i386: tlcpackstaging/ci_i386:20230724-060135-684689e92 ci_lint: tlcpack/ci-lint:20230504-142417-4d37a0a0 -ci_minimal: tlcpack/ci-minimal:20230504-142417-4d37a0a0 -ci_riscv: tlcpack/ci-riscv:20230504-142417-4d37a0a0 -ci_wasm: tlcpack/ci-wasm:20230504-142417-4d37a0a0 +ci_minimal: tlcpackstaging/ci_minimal:20230724-060135-684689e92 +ci_riscv: tlcpackstaging/ci_riscv:20230724-060135-684689e92 +ci_wasm: tlcpackstaging/ci_wasm:20230724-060135-684689e92