From a941951fb64d69825553bf84828efb4628d736d3 Mon Sep 17 00:00:00 2001 From: Zihao Date: Fri, 2 Jun 2023 17:21:21 -0700 Subject: [PATCH 01/37] upd --- src/tir/transforms/unsupported_dtype_legalize.cc | 1 - tests/python/unittest/test_target_codegen_cuda.py | 2 ++ 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index be8876b81550..0040951961dc 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -716,7 +716,6 @@ TVM_REGISTER_GLOBAL("tir.transform.FP8ComputeLegalize").set_body_typed(FP8Comput Pass FP8StorageLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - LOG(INFO) << f; // TODO(tvm-team): skip if the target supports fp8 return FP8StorageLegalizer().Legalize(f); }; diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index e60138a9c8d6..7685d64368ef 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -101,6 +101,8 @@ def check_cuda(n, lanes): disabled_pass=["tir.BF16Promote", "tir.BF16CastElimination", "tir.BF16TypeLowering"] ): fun = tvm.build(s, [A, B], "cuda") + print(fun.imported_modules[0].get_source()) + assert False dev = tvm.cuda(0) np_a = np.random.uniform(size=(n, lanes)).astype("float32") np_a = np_bf162np_float(np_float2np_bf16(np_a)) From 32841512dea1e6aa1f73f601041a8af90756e00a Mon Sep 17 00:00:00 2001 From: Zihao Date: Fri, 2 Jun 2023 22:32:23 -0700 Subject: [PATCH 02/37] wip --- python/tvm/contrib/nvcc.py | 3 +- src/driver/driver_api.cc | 3 +- .../transforms/unsupported_dtype_legalize.cc | 34 ++++++++++++++----- .../unittest/test_target_codegen_cuda.py | 12 +++---- 4 files changed, 36 insertions(+), 16 deletions(-) diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index 5eb348009914..2ef3cf9df4dd 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -100,7 +100,8 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target # Because it is hard to do runtime compiler detection, we require nvcc is configured # correctly by default. # if cxx_compiler_path != "": - # cmd += ["-ccbin", cxx_compiler_path] + # cmd += ["-ccbin", cxx_compiler_path] + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index cfc7fa80c7a9..fecacff5b0d0 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -55,6 +55,8 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer); TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_ldg32", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.hardware_support_bf16", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.hardware_support_fp8", Bool); // WARNING: May cause coherency issues resulting data miscompares // Experimental feature that, when enabled by the runtime, bypasses the cache when using DMA. 
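For reference, a minimal sketch of how a boolean pass config registered this way can be toggled from Python. This assumes the whole series is applied (the option is introduced here as tir.hardware_support_bf16, later renamed to tir.target_support_bf16, and eventually replaced by target-based gating) and an sm_80+ GPU with native bf16 support; the schedule itself mirrors the existing check_cuda test.

    import tvm
    from tvm import te

    n = 128
    A = te.placeholder((n,), dtype="bfloat16", name="A")
    B = te.compute((n,), lambda i: A[i] + tvm.tir.const(1, A.dtype), name="B")
    s = te.create_schedule(B.op)
    xo, xi = s[B].split(B.op.axis[0], factor=32)
    s[B].bind(xo, te.thread_axis("blockIdx.x"))
    s[B].bind(xi, te.thread_axis("threadIdx.x"))

    # With the option set, BF16ComputeLegalize/BF16StorageLegalize become no-ops and
    # bf16 arithmetic reaches the CUDA codegen directly; left at its default (False),
    # the legalization passes rewrite bf16 compute to fp32 as before.
    with tvm.transform.PassContext(config={"tir.hardware_support_bf16": True}):
        mod = tvm.build(s, [A, B], "cuda")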
When @@ -151,7 +153,6 @@ Array CreatePassList(bool disable_loop_partition) { bool disable_cse_tir = pass_ctx->GetConfig("tir.disable_cse_tir", Bool(false)).value(); bool enable_equiv_terms_in_cse_tir = pass_ctx->GetConfig("tir.enable_equiv_terms_in_cse_tir", Bool(false)).value(); - bool ptx_ldg32 = pass_ctx->GetConfig("tir.ptx_ldg32", Bool(false)).value(); // Get any user-added passes diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index 0040951961dc..74d881ab9e9a 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -686,8 +686,13 @@ namespace transform { Pass BF16ComputeLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - // TODO(tvm-team): skip if the target supports bf16 - return BF16ComputeLegalizer().Legalize(f); + bool target_support_bf16 = + ctx->GetConfig("tir.hardware_support_bf16", Bool(false)).value(); + if (target_support_bf16) { + return f; + } else { + return BF16ComputeLegalizer().Legalize(f); + } }; return CreatePrimFuncPass(pass_func, 0, "tir.BF16ComputeLegalize", {}); } @@ -696,8 +701,13 @@ TVM_REGISTER_GLOBAL("tir.transform.BF16ComputeLegalize").set_body_typed(BF16Comp Pass BF16StorageLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - // TODO(tvm-team): skip if the target supports bf16 - return BF16StorageLegalizer().Legalize(f); + bool target_support_bf16 = + ctx->GetConfig("tir.hardware_support_bf16", Bool(false)).value(); + if (target_support_bf16) { + return f; + } else { + return BF16StorageLegalizer().Legalize(f); + } }; return CreatePrimFuncPass(pass_func, 0, "tir.BF16StorageLegalize", {}); } @@ -706,8 +716,12 @@ TVM_REGISTER_GLOBAL("tir.transform.BF16StorageLegalize").set_body_typed(BF16Stor Pass FP8ComputeLegalize(String promote_dtype_str) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - // TODO(tvm-team): skip if the target supports fp8 - return FP8ComputeLegalizer(DataType(String2DLDataType(promote_dtype_str))).Legalize(f); + bool target_support_fp8 = ctx->GetConfig("tir.hardware_support_fp8", Bool(false)).value(); + if (target_support_fp8) { + return f; + } else { + return FP8ComputeLegalizer(DataType(String2DLDataType(promote_dtype_str))).Legalize(f); + } }; return CreatePrimFuncPass(pass_func, 0, "tir.FP8ComputeLegalize", {}); } @@ -716,8 +730,12 @@ TVM_REGISTER_GLOBAL("tir.transform.FP8ComputeLegalize").set_body_typed(FP8Comput Pass FP8StorageLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - // TODO(tvm-team): skip if the target supports fp8 - return FP8StorageLegalizer().Legalize(f); + bool target_support_fp8 = ctx->GetConfig("tir.hardware_support_fp8", Bool(false)).value(); + if (target_support_fp8) { + return f; + } else { + return FP8StorageLegalizer().Legalize(f); + } }; return CreatePrimFuncPass(pass_func, 0, "tir.FP8StorageLegalize", {}); } diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index 7685d64368ef..fa9474e73d88 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -97,12 +97,12 @@ def check_cuda(n, lanes): xo, xi = s[B].split(B.op.axis[0], factor=num_thread) s[B].bind(xo, bx) s[B].bind(xi, tx) - with tvm.transform.PassContext( - disabled_pass=["tir.BF16Promote", "tir.BF16CastElimination", "tir.BF16TypeLowering"] - ): - fun = tvm.build(s, [A, B], "cuda") - 
print(fun.imported_modules[0].get_source()) - assert False + # with tvm.transform.PassContext( + # disabled_pass=["tir.BF16ComputeLegalize", "tir.BF16StorageLegalize"] + # ): + fun = tvm.build(s, [A, B], "cuda") + print(fun.imported_modules[0].get_source()) + assert False dev = tvm.cuda(0) np_a = np.random.uniform(size=(n, lanes)).astype("float32") np_a = np_bf162np_float(np_float2np_bf16(np_a)) From e753640b6d405575e13b93dbef19d444fa79ddb8 Mon Sep 17 00:00:00 2001 From: Zihao Date: Fri, 9 Jun 2023 17:50:44 -0700 Subject: [PATCH 03/37] upd --- src/target/source/cuda_vector_intrin.cc | 42 +++++++++++++++++++++++++ src/target/source/cuda_vector_intrin.h | 42 +++++++++++++++++++++++++ 2 files changed, 84 insertions(+) create mode 100644 src/target/source/cuda_vector_intrin.cc create mode 100644 src/target/source/cuda_vector_intrin.h diff --git a/src/target/source/cuda_vector_intrin.cc b/src/target/source/cuda_vector_intrin.cc new file mode 100644 index 000000000000..a3d957e42fcb --- /dev/null +++ b/src/target/source/cuda_vector_intrin.cc @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file cuda_vector_intrin.cc + * \brief Code generation with vector intrinsics in CUDA. + */ +#include "cuda_vector_intrin.h" + +namespace tvm { +namespace codegen { +namespace cuda { + +std::string PrintHalf2BinaryOp(const std::string& op, const std::string lhs, + const std::string rhs) { + return ""; +} + +std::string PrintNVBFloat162BinaryOp(const std::string& op, const std::string lhs, + const std::string rhs) { + return ""; +} + +} // namespace cuda +} // namespace codegen +} // namespace tvm diff --git a/src/target/source/cuda_vector_intrin.h b/src/target/source/cuda_vector_intrin.h new file mode 100644 index 000000000000..9f636681c0b4 --- /dev/null +++ b/src/target/source/cuda_vector_intrin.h @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file cuda_vector_intrin.h + * \brief Code generation with vector intrinsics in CUDA. 
+ */ +#ifndef TVM_TARGET_SOURCE_CUDA_VECTOR_INTRIN_H_ +#define TVM_TARGET_SOURCE_CUDA_VECTOR_INTRIN_H_ + +#include + +namespace tvm { +namespace codegen { +namespace cuda { + +std::string PrintHalf2BinaryOp(const std::string& op, const std::string lhs, const std::string rhs); + +std::string PrintNVBFloat162BinaryOp(const std::string& op, const std::string lhs, + const std::string rhs); + +} // namespace cuda +} // namespace codegen +} // namespace tvm + +#endif // TVM_TARGET_SOURCE_CUDA_VECTOR_INTRIN_H_ \ No newline at end of file From 7f316368d67bc9011e46b9f46f2183691176394d Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 29 Jun 2023 05:01:16 -0700 Subject: [PATCH 04/37] upd --- include/tvm/runtime/data_type.h | 6 +++- include/tvm/tir/op.h | 2 +- python/tvm/contrib/nvcc.py | 5 ++-- src/tir/op/op.cc | 40 +++++++++++++++++++------- src/tir/transforms/dtype_conversion.cc | 5 +--- try.py | 23 +++++++++++++++ 6 files changed, 63 insertions(+), 18 deletions(-) create mode 100644 try.py diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index 9fb113f56b2c..f2c7e09df369 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -95,7 +95,7 @@ class DataType { bool is_scalar() const { return lanes() == 1; } /*! \return whether type is a scalar type. */ bool is_bool() const { return code() == DataType::kUInt && bits() == 1; } - /*! \return whether type is a float type. */ + /*! \return whether type is a IEEE 754 standard float type. */ bool is_float() const { return code() == DataType::kFloat; } /*! \return whether type is a float8 type. */ bool is_float8() const { @@ -107,6 +107,10 @@ class DataType { bool is_float16() const { return is_float() && bits() == 16; } /*! \return whether type is a bfloat16 type. */ bool is_bfloat16() const { return code() == DataType::kBFloat && bits() == 16; } + /*! \return whether type is a general floating point data type. */ + bool is_floating_point() const { + return is_float() || is_float8() || is_bfloat16(); + } /*! \return whether type is an int type. */ bool is_int() const { return code() == DataType::kInt; } /*! \return whether type is an uint type. */ diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 3d5e589ab4a4..93738e0d773c 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -939,7 +939,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) return LargeUIntImm(t, static_cast(low), static_cast(high), span); } } - if (t.is_float() || t.is_bfloat16() || t.is_float8()) + if (t.is_floating_point()) return FloatImm(t, static_cast(value), span); // For now, we store const scalar values of custom datatypes within doubles; later, during the // datatypes lowering pass, we will lower the value to its true representation in the format diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index 2ef3cf9df4dd..bd36a325c813 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -99,8 +99,9 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target # However, we cannot use TVM_CXX_COMPILER_PATH because the runtime env. # Because it is hard to do runtime compiler detection, we require nvcc is configured # correctly by default. 
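For illustration, a rough Python counterpart of the new DataType::is_floating_point() predicate, built from the DataTypeCode constants in tvm._ffi.runtime_ctypes. This helper is only a sketch for clarity and is not part of the patch.

    from tvm._ffi.runtime_ctypes import DataType, DataTypeCode

    def is_floating_point(dtype_str):
        # Mirrors the C++ predicate: IEEE float, float8 (e4m3/e5m2), or bfloat16.
        t = DataType(dtype_str)
        return t.type_code in (
            DataTypeCode.FLOAT,
            DataTypeCode.BFLOAT,
            DataTypeCode.E4M3Float,
            DataTypeCode.E5M2Float,
        )

    assert is_floating_point("float16")
    assert is_floating_point("bfloat16")
    assert is_floating_point("e4m3_float8")
    assert not is_floating_point("int32")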
- # if cxx_compiler_path != "": - # cmd += ["-ccbin", cxx_compiler_path] + cxx_compiler_path = os.environ.get("CUDAHOSTCXX", "") + if cxx_compiler_path != "": + cmd += ["-ccbin", cxx_compiler_path] proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 39214c4546dc..da3392184eaf 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -33,6 +33,8 @@ // Centralized header for constant folders. #include "../../arith/const_fold.h" #include "../../target/datatype/registry.h" +#include "../../support/scalars.h" + namespace tvm { @@ -206,10 +208,16 @@ PrimExpr max_value(const DataType& dtype, Span span) { } else if (dtype.bits() == 32) { return FloatImm(dtype, std::numeric_limits::max(), span); } else if (dtype.bits() == 16) { - return FloatImm(dtype, 65504.0, span); + return FloatImm(dtype, support::kMaxFloat16, span); } } else if (dtype.is_bfloat16()) { - return FloatImm(dtype, std::numeric_limits::max(), span); + return FloatImm(dtype, support::kMaxBFloat16, span); + } else if (dtype.is_float8()) { + if (dtype.code() == DataType::kE4M3Float) { + return FloatImm(dtype, support::kMaxE4M3, span); + } else { // E5M2 + return FloatImm(dtype, support::kMaxE5M2, span); + } } LOG(FATAL) << "Cannot decide max_value for type" << dtype; } @@ -240,10 +248,16 @@ PrimExpr min_value(const DataType& dtype, Span span) { } else if (dtype.bits() == 32) { return FloatImm(dtype, std::numeric_limits::lowest(), span); } else if (dtype.bits() == 16) { - return FloatImm(dtype, -65504.0, span); - } + return FloatImm(dtype, -support::kMaxFloat16, span); + } } else if (dtype.is_bfloat16()) { - return FloatImm(dtype, std::numeric_limits::lowest(), span); + return FloatImm(dtype, -support::kMaxBFloat16, span); + } else if (dtype.is_float8()) { + if (dtype.code() == DataType::kE4M3Float) { + return FloatImm(dtype, -support::kMaxE4M3, span); + } else { // E5M2 + return FloatImm(dtype, -support::kMaxE5M2, span); + } } LOG(FATAL) << "Cannot decide min_value for type" << dtype; } @@ -258,6 +272,12 @@ PrimExpr infinity(const DataType& dtype, Span span) { } else if (dtype.bits() == 32 || dtype.bits() == 16) { return FloatImm(dtype, std::numeric_limits::infinity(), span); } + } else if (dtype.is_bfloat16()) { + return FloatImm(dtype, std::numeric_limits::infinity(), span); + } else if (dtype.is_float8()) { + if (dtype.code() == DataType::kE5M2Float) { + return FloatImm(dtype, std::numeric_limits::infinity(), span); + } } LOG(FATAL) << "Cannot decide infinity for type " << dtype; } @@ -661,7 +681,7 @@ TVM_REGISTER_GLOBAL("tir.bitwise_not").set_body_typed([](PrimExpr a, Span span) // pow PrimExpr pow(PrimExpr x, PrimExpr y, Span span) { BinaryOpMatchTypes(x, y, span); - ICHECK(x.dtype().is_float()) << "power only applies to float"; + ICHECK(x.dtype().is_floating_point()) << "power only applies to float"; static auto op = Op::Get("tir.pow"); return tir::Call(x.dtype(), op, {x, y}, span); } @@ -677,7 +697,7 @@ PrimExpr abs(PrimExpr x, Span span) { return IntImm(x.dtype(), std::abs(px->value), px->span); } return tir::Select(x >= make_zero(x.dtype()), x, -x, span); - } else if (x.dtype().is_float()) { + } else if (x.dtype().is_floating_point()) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) { @@ -701,7 +721,7 @@ PrimExpr isnan(PrimExpr x, Span span) { DataType t = DataType::Bool(x.dtype().lanes()); if (x.dtype().is_int() || x.dtype().is_uint()) { return make_const(t, false); - } else if (x.dtype().is_float()) { + } else if 
(x.dtype().is_floating_point()) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) { @@ -723,7 +743,7 @@ PrimExpr isinf(PrimExpr x, Span span) { DataType t = DataType::Bool(x.dtype().lanes()); if (x.dtype().is_int() || x.dtype().is_uint()) { return make_const(t, false, span); - } else if (x.dtype().is_float()) { + } else if (x.dtype().is_floating_point()) { PrimExpr infX = infinity(x.dtype(), span); return abs(x, span) == infX && !isnan(x, span); } else { @@ -787,7 +807,7 @@ PrimExpr prod(PrimExpr source, Array rdom, Array init, Span s // fmod PrimExpr fmod(PrimExpr x, PrimExpr y, Span span) { BinaryOpMatchTypes(x, y, span); - ICHECK(x.dtype().is_float()) << "fmod only applies to float"; + ICHECK(x.dtype().is_floating_point()) << "fmod only applies to float"; static auto op = Op::Get("tir.fmod"); return tir::Call(x.dtype(), op, {x, y}, span); } diff --git a/src/tir/transforms/dtype_conversion.cc b/src/tir/transforms/dtype_conversion.cc index de94cf647387..de9d57bed13d 100644 --- a/src/tir/transforms/dtype_conversion.cc +++ b/src/tir/transforms/dtype_conversion.cc @@ -38,11 +38,8 @@ PrimExpr DTypeConversion(PrimExpr src_value, DataType tgt_dtype, RoundingMode ro // The lanes of src dtype and target dtype must match. CHECK_EQ(src_dtype.lanes(), tgt_dtype.lanes()) << "The lanes for data type for source value must matches the target datatype."; - auto is_floating_point = [](DataType dtype) { - return dtype.is_float() || dtype.is_float8() || dtype.is_bfloat16(); - }; // Both source dtype and target dtype should be floating point. - CHECK(is_floating_point(src_dtype) && is_floating_point(tgt_dtype)); + CHECK(src_dtype.is_floating_point() && tgt_dtype.is_floating_point()); FloatConfig src_fp = FloatConfig::FromDataType(src_value.dtype()), tgt_fp = FloatConfig::FromDataType(tgt_dtype); int exponent_delta = tgt_fp.exponent - src_fp.exponent; diff --git a/try.py b/try.py new file mode 100644 index 000000000000..dd1be2cfebcf --- /dev/null +++ b/try.py @@ -0,0 +1,23 @@ +import tvm +from tvm.script import tir as T + +@T.prim_func +def vectorize_add_fp16(A: T.Buffer([128], "bfloat16"), B: T.Buffer([128], "bfloat16")) -> None: + + for i in range(128): + with T.block("blk"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + T.abs(A[vi]) + + +sch = tvm.tir.Schedule(vectorize_add_fp16, debug_mask="all") +blk = sch.get_block("blk") +i, = sch.get_loops(blk) +io, ii, v = sch.split(i, [None, 32, 2]) +sch.vectorize(v) +sch.bind(ii, "threadIdx.x") +sch.bind(io, "blockIdx.x") + +print(sch.mod["main"]) +f = tvm.build(sch.mod["main"], target="cuda") +print(f.imported_modules[0].get_source()) \ No newline at end of file From 0f6b8589e9dfd151854c4e711e8a997cf1e3f578 Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 29 Jun 2023 05:48:16 -0700 Subject: [PATCH 05/37] upd --- src/driver/driver_api.cc | 24 +++++++++++---- src/target/llvm/codegen_llvm.cc | 6 ++++ src/target/source/codegen_cuda.cc | 15 ++++++++-- src/target/source/codegen_cuda.h | 2 +- src/target/source/intrin_rule_cuda.cc | 8 ++++- src/target/source/literal/cuda_half_t.h | 1 + .../transforms/unsupported_dtype_legalize.cc | 30 +++---------------- try.py | 23 -------------- 8 files changed, 50 insertions(+), 59 deletions(-) delete mode 100644 try.py diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index fecacff5b0d0..5526ef4920b5 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -55,8 +55,8 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", 
Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer); TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_ldg32", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.hardware_support_bf16", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.hardware_support_fp8", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.target_support_bf16", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.target_support_fp8", Bool); // WARNING: May cause coherency issues resulting data miscompares // Experimental feature that, when enabled by the runtime, bypasses the cache when using DMA. When @@ -154,6 +154,8 @@ Array CreatePassList(bool disable_loop_partition) { bool enable_equiv_terms_in_cse_tir = pass_ctx->GetConfig("tir.enable_equiv_terms_in_cse_tir", Bool(false)).value(); bool ptx_ldg32 = pass_ctx->GetConfig("tir.ptx_ldg32", Bool(false)).value(); + bool target_support_bf16 = pass_ctx->GetConfig("tir.target_support_bf16", Bool(false)).value(); + bool target_support_fp8 = pass_ctx->GetConfig("tir.target_support_fp8", Bool(false)).value(); // Get any user-added passes Array> add_lower_pass = @@ -211,8 +213,12 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::InjectSoftwarePipeline()); pass_list.push_back(tir::transform::LowerOpaqueBlock()); pass_list.push_back(tir::transform::FlattenBuffer()); - pass_list.push_back(tir::transform::FP8ComputeLegalize()); - pass_list.push_back(tir::transform::BF16ComputeLegalize()); + if (!target_support_fp8) { + pass_list.push_back(tir::transform::FP8ComputeLegalize()); + } + if (!target_support_bf16) { + pass_list.push_back(tir::transform::BF16ComputeLegalize()); + } pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); @@ -541,6 +547,8 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg, transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) { transform::PassContext pass_ctx = transform::PassContext::Current(); + bool target_support_bf16 = pass_ctx->GetConfig("tir.target_support_bf16", Bool(false)).value(); + bool target_support_fp8 = pass_ctx->GetConfig("tir.target_support_fp8", Bool(false)).value(); Array mixed_pass_list; @@ -588,8 +596,12 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) } else { mixed_pass_list.push_back(tir::transform::MakePackedAPI()); } - mixed_pass_list.push_back(tir::transform::FP8StorageLegalize()); - mixed_pass_list.push_back(tir::transform::BF16StorageLegalize()); + if (!target_support_fp8) { + mixed_pass_list.push_back(tir::transform::FP8StorageLegalize()); + } + if (!target_support_bf16) { + mixed_pass_list.push_back(tir::transform::BF16StorageLegalize()); + } mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions()); mixed_pass_list.push_back(tir::transform::SplitHostDevice()); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index ada51677ac16..d3b00a6a82aa 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -568,6 +568,12 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { default: LOG(FATAL) << "do not support " << dtype; } + } else if (dtype.is_bfloat16()) { +#if TVM_LLVM_VERSION >= 110 + etype = llvm::Type::getBFloatTy(*ctx); +#else + LOG(FATAL) << "bfloat16 is not supported for your LLVM version " << TVM_LLVM_VERSION << "."; +#endif } if (dtype.lanes() != 1) { #if TVM_LLVM_VERSION >= 110 diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 
ec8695a2a038..fd3e06b1c5c1 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -36,6 +36,7 @@ #include "../../tir/transforms/ir_utils.h" #include "literal/cuda_half_t.h" #include "ptx.h" +#include "cuda_vector_intrin.h" namespace tvm { namespace codegen { @@ -210,7 +211,11 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) // h4.w is emitted as *(half2*)(&(u2.y)).y // ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; - os << "uint" << lanes / 2; + if (lanes == 2) { + os << "half2"; + } else { + os << "uint" << lanes / 2; + } } else { fail = true; } @@ -251,7 +256,11 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) os << "nv_bfloat16"; } else if (lanes <= 8) { ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; - os << "uint" << lanes / 2; + if (lanes == 2) { + os << "nv_bfloat162"; + } else { + os << "uint" << lanes / 2; + } } else { fail = true; } @@ -425,6 +434,8 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Cannot convert type " << t << " to CUDA type"; } + + void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, std::ostream& os) { // NOLINT(*) // Delcare the result. diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index c6cf96d460d4..46ecaba16538 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -52,7 +52,7 @@ class CodeGenCUDA final : public CodeGenC { void PrintStorageSync(const CallNode* op) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, - std::ostream& os) final; // NOLINT(*) + std::ostream& os) final; // NOLINT(*) void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) void PrintVecElemLoad(const std::string& vec, DataType t, int i, std::ostream& os) final; // NOLINT(*) diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc index 95fbf7f1a513..47222e75003d 100644 --- a/src/target/source/intrin_rule_cuda.cc +++ b/src/target/source/intrin_rule_cuda.cc @@ -53,7 +53,13 @@ struct CUDAMath { return ""; } } else if (t.is_bfloat16()) { - return 'h' + name; + if (name == "fabs") { + return "habs"; + } else if (name == "round") { + return "hrint"; + } else { + return 'h' + name; + } } return ""; } diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h index c3a1c66e5874..bc6b627f38b8 100644 --- a/src/target/source/literal/cuda_half_t.h +++ b/src/target/source/literal/cuda_half_t.h @@ -357,6 +357,7 @@ CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf) CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf) CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf) CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erf) +CUDA_UNSUPPORTED_HALF_MATH_UNARY(habs, abs) #undef CUDA_UNSUPPORTED_HALF_MATH_BINARY #undef CUDA_UNSUPPORTED_HALF_MATH_UNARY diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index 74d881ab9e9a..a92ce020c94d 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -686,13 +686,7 @@ namespace transform { Pass BF16ComputeLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - bool target_support_bf16 = - ctx->GetConfig("tir.hardware_support_bf16", Bool(false)).value(); - if 
(target_support_bf16) { - return f; - } else { - return BF16ComputeLegalizer().Legalize(f); - } + return BF16ComputeLegalizer().Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.BF16ComputeLegalize", {}); } @@ -701,13 +695,7 @@ TVM_REGISTER_GLOBAL("tir.transform.BF16ComputeLegalize").set_body_typed(BF16Comp Pass BF16StorageLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - bool target_support_bf16 = - ctx->GetConfig("tir.hardware_support_bf16", Bool(false)).value(); - if (target_support_bf16) { - return f; - } else { - return BF16StorageLegalizer().Legalize(f); - } + return BF16StorageLegalizer().Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.BF16StorageLegalize", {}); } @@ -716,12 +704,7 @@ TVM_REGISTER_GLOBAL("tir.transform.BF16StorageLegalize").set_body_typed(BF16Stor Pass FP8ComputeLegalize(String promote_dtype_str) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - bool target_support_fp8 = ctx->GetConfig("tir.hardware_support_fp8", Bool(false)).value(); - if (target_support_fp8) { - return f; - } else { - return FP8ComputeLegalizer(DataType(String2DLDataType(promote_dtype_str))).Legalize(f); - } + return FP8ComputeLegalizer(DataType(String2DLDataType(promote_dtype_str))).Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.FP8ComputeLegalize", {}); } @@ -730,12 +713,7 @@ TVM_REGISTER_GLOBAL("tir.transform.FP8ComputeLegalize").set_body_typed(FP8Comput Pass FP8StorageLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - bool target_support_fp8 = ctx->GetConfig("tir.hardware_support_fp8", Bool(false)).value(); - if (target_support_fp8) { - return f; - } else { - return FP8StorageLegalizer().Legalize(f); - } + return FP8StorageLegalizer().Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.FP8StorageLegalize", {}); } diff --git a/try.py b/try.py deleted file mode 100644 index dd1be2cfebcf..000000000000 --- a/try.py +++ /dev/null @@ -1,23 +0,0 @@ -import tvm -from tvm.script import tir as T - -@T.prim_func -def vectorize_add_fp16(A: T.Buffer([128], "bfloat16"), B: T.Buffer([128], "bfloat16")) -> None: - - for i in range(128): - with T.block("blk"): - vi = T.axis.remap("S", [i]) - B[vi] = A[vi] + T.abs(A[vi]) - - -sch = tvm.tir.Schedule(vectorize_add_fp16, debug_mask="all") -blk = sch.get_block("blk") -i, = sch.get_loops(blk) -io, ii, v = sch.split(i, [None, 32, 2]) -sch.vectorize(v) -sch.bind(ii, "threadIdx.x") -sch.bind(io, "blockIdx.x") - -print(sch.mod["main"]) -f = tvm.build(sch.mod["main"], target="cuda") -print(f.imported_modules[0].get_source()) \ No newline at end of file From edc938afffe1dcc908893297f8258438656703d6 Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 29 Jun 2023 05:48:59 -0700 Subject: [PATCH 06/37] rm redundant files --- src/target/source/codegen_cuda.cc | 1 - src/target/source/cuda_vector_intrin.cc | 42 ------------------------- src/target/source/cuda_vector_intrin.h | 42 ------------------------- 3 files changed, 85 deletions(-) delete mode 100644 src/target/source/cuda_vector_intrin.cc delete mode 100644 src/target/source/cuda_vector_intrin.h diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index fd3e06b1c5c1..72b62557751b 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -36,7 +36,6 @@ #include "../../tir/transforms/ir_utils.h" #include "literal/cuda_half_t.h" #include "ptx.h" -#include "cuda_vector_intrin.h" namespace tvm { namespace codegen { diff --git 
a/src/target/source/cuda_vector_intrin.cc b/src/target/source/cuda_vector_intrin.cc deleted file mode 100644 index a3d957e42fcb..000000000000 --- a/src/target/source/cuda_vector_intrin.cc +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file cuda_vector_intrin.cc - * \brief Code generation with vector intrinsics in CUDA. - */ -#include "cuda_vector_intrin.h" - -namespace tvm { -namespace codegen { -namespace cuda { - -std::string PrintHalf2BinaryOp(const std::string& op, const std::string lhs, - const std::string rhs) { - return ""; -} - -std::string PrintNVBFloat162BinaryOp(const std::string& op, const std::string lhs, - const std::string rhs) { - return ""; -} - -} // namespace cuda -} // namespace codegen -} // namespace tvm diff --git a/src/target/source/cuda_vector_intrin.h b/src/target/source/cuda_vector_intrin.h deleted file mode 100644 index 9f636681c0b4..000000000000 --- a/src/target/source/cuda_vector_intrin.h +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file cuda_vector_intrin.h - * \brief Code generation with vector intrinsics in CUDA. 
- */ -#ifndef TVM_TARGET_SOURCE_CUDA_VECTOR_INTRIN_H_ -#define TVM_TARGET_SOURCE_CUDA_VECTOR_INTRIN_H_ - -#include - -namespace tvm { -namespace codegen { -namespace cuda { - -std::string PrintHalf2BinaryOp(const std::string& op, const std::string lhs, const std::string rhs); - -std::string PrintNVBFloat162BinaryOp(const std::string& op, const std::string lhs, - const std::string rhs); - -} // namespace cuda -} // namespace codegen -} // namespace tvm - -#endif // TVM_TARGET_SOURCE_CUDA_VECTOR_INTRIN_H_ \ No newline at end of file From 4ac20f172600d22219165f8db2b50922ca40f8e2 Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 29 Jun 2023 06:23:26 -0700 Subject: [PATCH 07/37] upd --- python/tvm/_ffi/runtime_ctypes.py | 1 + python/tvm/runtime/ndarray.py | 20 +++++++++++++------- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index adcc3a8e972c..2554098d2c43 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -108,6 +108,7 @@ class DataType(ctypes.Structure): "uint64": {"type_code": DataTypeCode.UINT, "bits": 64, "lanes": 1}, "e4m3_float8": {"type_code": DataTypeCode.E4M3Float, "bits": 8, "lanes": 1}, "e5m2_float8": {"type_code": DataTypeCode.E5M2Float, "bits": 8, "lanes": 1}, + "bfloat16": {"type_code": DataTypeCode.BFLOAT, "bits": 16, "lanes": 1}, "float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1}, "float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1}, "float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1}, diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index a78c68ee67c4..6857c5e2d958 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -167,18 +167,19 @@ def copyfrom(self, source_array): raise ValueError( f"array shape do not match the shape of NDArray {source_array.shape} vs {shape}" ) + numpy_str_map = DataType.NUMPY2STR np_dtype_str = ( numpy_str_map[source_array.dtype] if source_array.dtype in numpy_str_map else str(source_array.dtype) ) - if (not source_array.flags["C_CONTIGUOUS"]) or ( - dtype == "bfloat16" or dtype != np_dtype_str - ): - source_array = np.ascontiguousarray( - source_array, dtype="uint16" if dtype == "bfloat16" else dtype - ) + if (not source_array.flags["C_CONTIGUOUS"]) or dtype != np_dtype_str: + if dtype == "e4m3_float8": + dtype = "float8_e4m3fn" + elif dtype == "e5m2_float8": + dtype = "float8_e5m2" + source_array = np.ascontiguousarray(source_array, dtype) assert source_array.flags["C_CONTIGUOUS"] data = source_array.ctypes.data_as(ctypes.c_void_p) nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize) @@ -221,7 +222,12 @@ def numpy(self): if dtype == "int4": dtype = "int8" if dtype == "bfloat16": - dtype = "uint16" + if ml_dtypes is not None: + dtype = ml_dtypes.bfloat16 + else: + raise RuntimeError( + "ml_dtypes is not installed, cannot convert bfloat16 array to numpy." 
+ ) if dtype == "e4m3_float8": if ml_dtypes is not None: dtype = ml_dtypes.float8_e4m3fn From de6c5b297b57dea7ae4302454fbf7bf415438719 Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 29 Jun 2023 07:33:07 -0700 Subject: [PATCH 08/37] upd --- python/tvm/testing/utils.py | 5 ++ src/target/source/codegen_cuda.cc | 58 ++++++++++--------- src/target/source/codegen_cuda.h | 2 +- src/tir/op/op.cc | 9 ++- .../unittest/test_target_codegen_cuda.py | 41 +++++-------- 5 files changed, 56 insertions(+), 59 deletions(-) diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index f29b7c4394c2..92923a2ed290 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -114,6 +114,11 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7): compares the `abs(actual-desired)` with `atol+rtol*abs(desired)`. Since we often allow `desired` to be close to zero, we generally want non-zero `atol`. """ + # The ml_dtypes v0.2 is not compatible with allclose function, convert to float32. + if actual.dtype == "bfloat16": + actual = actual.astype("float32") + if desired.dtype == "bfloat16": + desired = desired.astype("float32") actual = np.asanyarray(actual) desired = np.asanyarray(desired) np.testing.assert_allclose(actual.shape, desired.shape) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 448e2d4a051d..bff30e65e0aa 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -199,6 +199,8 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) enable_fp16_ = true; if (t.is_scalar()) { os << "half"; + } else if (lanes == 2) { + os << "half2"; } else if (lanes <= 8) { // Emit CUDA code to access fp16 vector elements. // @@ -210,11 +212,7 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) // h4.w is emitted as *(half2*)(&(u2.y)).y // ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; - if (lanes == 2) { - os << "half2"; - } else { - os << "uint" << lanes / 2; - } + os << "uint" << lanes / 2; } else { fail = true; } @@ -253,13 +251,11 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) enable_bf16_ = true; if (t.is_scalar()) { os << "nv_bfloat16"; + } else if (lanes == 2) { + os << "nv_bfloat162"; } else if (lanes <= 8) { ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; - if (lanes == 2) { - os << "nv_bfloat162"; - } else { - os << "uint" << lanes / 2; - } + os << "uint" << lanes / 2; } else { fail = true; } @@ -433,8 +429,6 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Cannot convert type " << t << " to CUDA type"; } - - void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, std::ostream& os) { // NOLINT(*) // Delcare the result. 
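A usage sketch of the new ndarray conversion path: a host-side bfloat16 array created through ml_dtypes can be copied into a TVM NDArray and read back, with .numpy() now returning an ml_dtypes.bfloat16 array instead of raw uint16. This assumes ml_dtypes is installed, as these changes require.

    import numpy as np
    import ml_dtypes
    import tvm

    x = np.random.uniform(size=(4, 4)).astype(ml_dtypes.bfloat16)
    nd = tvm.nd.empty(x.shape, "bfloat16", tvm.cpu(0)).copyfrom(x)
    y = nd.numpy()  # comes back as an ml_dtypes.bfloat16 array after this patch
    np.testing.assert_allclose(x.astype("float32"), y.astype("float32"))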
@@ -1122,27 +1116,35 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO if (op->dtype.is_float16()) { std::string v = PrintExpr(op->value); - os << "make_"; - PrintType(op->dtype, os); - os << '('; - for (int i = 0; i < op->lanes / 2; ++i) { - if (i != 0) os << ", "; - os << "__pack_half2(" << v << ", " << v << ")"; + if (op->lanes == 2) { + os << "make_half2(" << v << ", " << v << ")"; + } else { + os << "make_"; + PrintType(op->dtype, os); + os << '('; + for (int i = 0; i < op->lanes / 2; ++i) { + if (i != 0) os << ", "; + os << "__pack_half2(" << v << ", " << v << ")"; + } + os << ')'; } - os << ')'; return; } if (op->dtype.is_bfloat16()) { std::string v = PrintExpr(op->value); - os << "make_"; - PrintType(op->dtype, os); - os << '('; - for (int i = 0; i < op->lanes / 2; ++i) { - if (i != 0) os << ", "; - os << "__pack_nv_bfloat162(" << v << ", " << v << ")"; + if (op->lanes == 2) { + os << "__halves2bfloat162(" << v << ", " << v << ")"; + } else { + os << "make_"; + PrintType(op->dtype, os); + os << "("; + for (int i = 0; i < op->lanes / 2; ++i) { + if (i != 0) os << ", "; + os << "__pack_nv_bfloat162(" << v << ", " << v << ")"; + } + os << ')'; } - os << ')'; return; } @@ -1402,7 +1404,7 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val os << '('; } if (i % 2 == 0) { - os << "__pack_half2(" << value; + os << "make_half2(" << value; } else { os << "," << value << ")"; if (i != t.lanes() - 1) { @@ -1421,7 +1423,7 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val os << '('; } if (i % 2 == 0) { - os << "__pack_bfloat162(" << value; + os << " __halves2bfloat162(" << value; } else { os << "," << value << ")"; if (i != t.lanes() - 1) { diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index 46ecaba16538..c6cf96d460d4 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -52,7 +52,7 @@ class CodeGenCUDA final : public CodeGenC { void PrintStorageSync(const CallNode* op) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, - std::ostream& os) final; // NOLINT(*) + std::ostream& os) final; // NOLINT(*) void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) void PrintVecElemLoad(const std::string& vec, DataType t, int i, std::ostream& os) final; // NOLINT(*) diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index da3392184eaf..d9c1d6ff8f72 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -32,9 +32,8 @@ #include // Centralized header for constant folders. 
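For context, the scratch script try.py added and removed earlier in the series exercises exactly this lanes==2 path; a trimmed version is below. It assumes a bf16-capable GPU (sm_80 or newer) and the rest of the series applied, and prints the generated CUDA source so the nv_bfloat162 vectorization is visible.

    import tvm
    from tvm.script import tir as T

    @T.prim_func
    def vec_add_bf16(A: T.Buffer([128], "bfloat16"), B: T.Buffer([128], "bfloat16")) -> None:
        for i in range(128):
            with T.block("blk"):
                vi = T.axis.remap("S", [i])
                B[vi] = A[vi] + T.abs(A[vi])

    sch = tvm.tir.Schedule(vec_add_bf16)
    (i,) = sch.get_loops(sch.get_block("blk"))
    io, ii, v = sch.split(i, [None, 32, 2])
    sch.vectorize(v)                 # lanes == 2 -> emitted as nv_bfloat162
    sch.bind(ii, "threadIdx.x")
    sch.bind(io, "blockIdx.x")

    f = tvm.build(sch.mod["main"], target="cuda")
    print(f.imported_modules[0].get_source())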
#include "../../arith/const_fold.h" -#include "../../target/datatype/registry.h" #include "../../support/scalars.h" - +#include "../../target/datatype/registry.h" namespace tvm { @@ -215,7 +214,7 @@ PrimExpr max_value(const DataType& dtype, Span span) { } else if (dtype.is_float8()) { if (dtype.code() == DataType::kE4M3Float) { return FloatImm(dtype, support::kMaxE4M3, span); - } else { // E5M2 + } else { // E5M2 return FloatImm(dtype, support::kMaxE5M2, span); } } @@ -249,13 +248,13 @@ PrimExpr min_value(const DataType& dtype, Span span) { return FloatImm(dtype, std::numeric_limits::lowest(), span); } else if (dtype.bits() == 16) { return FloatImm(dtype, -support::kMaxFloat16, span); - } + } } else if (dtype.is_bfloat16()) { return FloatImm(dtype, -support::kMaxBFloat16, span); } else if (dtype.is_float8()) { if (dtype.code() == DataType::kE4M3Float) { return FloatImm(dtype, -support::kMaxE4M3, span); - } else { // E5M2 + } else { // E5M2 return FloatImm(dtype, -support::kMaxE5M2, span); } } diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index fa9474e73d88..ea145daf1a3f 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -24,6 +24,10 @@ from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16 import tvm.testing import pytest +try: + import ml_dtypes +except ImportError as e: + ml_dtypes = None tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") @@ -75,21 +79,11 @@ def test_cuda_bf16_vectorize_add(): if not have_bf16(tvm.cuda(0).compute_version): print("skip because gpu does not support bf16") return + if ml_dtypes is None: + print("skip because ml_dtypes not installed") + return num_thread = 8 - def np_float2np_bf16(arr): - """Convert a numpy array of float to a numpy array - of bf16 in uint16""" - orig = arr.view(" Date: Thu, 29 Jun 2023 07:35:09 -0700 Subject: [PATCH 09/37] do not change nvcc --- python/tvm/contrib/nvcc.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index bd36a325c813..2ef3cf9df4dd 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -99,9 +99,8 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target # However, we cannot use TVM_CXX_COMPILER_PATH because the runtime env. # Because it is hard to do runtime compiler detection, we require nvcc is configured # correctly by default. 
- cxx_compiler_path = os.environ.get("CUDAHOSTCXX", "") - if cxx_compiler_path != "": - cmd += ["-ccbin", cxx_compiler_path] + # if cxx_compiler_path != "": + # cmd += ["-ccbin", cxx_compiler_path] proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) From fd0ff79157f629e54f752f0e28d07a06e16e5bb9 Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 29 Jun 2023 07:37:04 -0700 Subject: [PATCH 10/37] fix --- tests/python/unittest/test_target_codegen_cuda.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index ea145daf1a3f..633d2335f51c 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -103,7 +103,7 @@ def check_cuda(n, lanes): c = tvm.nd.empty((n, lanes), "bfloat16", dev).copyfrom(c) tvm.testing.assert_allclose(c.numpy(), np_a + 1) - # check_cuda(64, 2) + check_cuda(64, 2) check_cuda(64, 4) check_cuda(64, 6) check_cuda(64, 8) From a0525b9d95ee2ca89c0592c668a768e9cc01bc25 Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 29 Jun 2023 07:38:45 -0700 Subject: [PATCH 11/37] remove empty line --- python/tvm/contrib/nvcc.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index 2ef3cf9df4dd..92bd4041dc66 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -101,7 +101,6 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target # correctly by default. # if cxx_compiler_path != "": # cmd += ["-ccbin", cxx_compiler_path] - proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) From 9c6d63938e93f5618b7950cd9372f9ead551124a Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 29 Jun 2023 07:47:40 -0700 Subject: [PATCH 12/37] fix --- python/tvm/contrib/nvcc.py | 2 +- src/target/source/codegen_cuda.cc | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index 92bd4041dc66..5eb348009914 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -100,7 +100,7 @@ def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target # Because it is hard to do runtime compiler detection, we require nvcc is configured # correctly by default. 
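The re-enabled lanes=2 case still relies on the capability check at the top of the test; the same guard can be used standalone, as in this small sketch (assumes a CUDA device is present).

    import tvm
    from tvm.contrib.nvcc import have_bf16

    dev = tvm.cuda(0)
    if not have_bf16(dev.compute_version):
        print("skip: this GPU does not support bf16")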
# if cxx_compiler_path != "": - # cmd += ["-ccbin", cxx_compiler_path] + # cmd += ["-ccbin", cxx_compiler_path] proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index bff30e65e0aa..87a8c29ccd27 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -1404,7 +1404,7 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val os << '('; } if (i % 2 == 0) { - os << "make_half2(" << value; + os << "__pack_half2(" << value; } else { os << "," << value << ")"; if (i != t.lanes() - 1) { @@ -1423,7 +1423,7 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val os << '('; } if (i % 2 == 0) { - os << " __halves2bfloat162(" << value; + os << "__pack_nv_bfloat162(" << value; } else { os << "," << value << ")"; if (i != t.lanes() - 1) { From bdc3382107b44e3d50afe1ecd5d0b153de507d96 Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 29 Jun 2023 09:46:20 -0700 Subject: [PATCH 13/37] lint --- tests/python/unittest/test_target_codegen_cuda.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index 633d2335f51c..90bdb0d75131 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -24,6 +24,7 @@ from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16 import tvm.testing import pytest + try: import ml_dtypes except ImportError as e: @@ -91,9 +92,7 @@ def check_cuda(n, lanes): xo, xi = s[B].split(B.op.axis[0], factor=num_thread) s[B].bind(xo, bx) s[B].bind(xi, tx) - with tvm.transform.PassContext( - config={"tir.target_support_bf16": True} - ): + with tvm.transform.PassContext(config={"tir.target_support_bf16": True}): fun = tvm.build(s, [A, B], "cuda") dev = tvm.cuda(0) np_a = np.random.uniform(size=(n, lanes)).astype("bfloat16") From a15b600dfa053c34c889d53dc45793c892fd24cd Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 29 Jun 2023 17:12:33 -0700 Subject: [PATCH 14/37] c++ lint --- src/driver/driver_api.cc | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index c8ecd3915cb1..41a6dfa22e9e 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -154,8 +154,10 @@ Array CreatePassList(bool disable_loop_partition) { bool enable_equiv_terms_in_cse_tir = pass_ctx->GetConfig("tir.enable_equiv_terms_in_cse_tir", Bool(false)).value(); bool ptx_ldg32 = pass_ctx->GetConfig("tir.ptx_ldg32", Bool(false)).value(); - bool target_support_bf16 = pass_ctx->GetConfig("tir.target_support_bf16", Bool(false)).value(); - bool target_support_fp8 = pass_ctx->GetConfig("tir.target_support_fp8", Bool(false)).value(); + bool target_support_bf16 = + pass_ctx->GetConfig("tir.target_support_bf16", Bool(false)).value(); + bool target_support_fp8 = + pass_ctx->GetConfig("tir.target_support_fp8", Bool(false)).value(); // Get any user-added passes Array> add_lower_pass = @@ -551,8 +553,10 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg, transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) { transform::PassContext pass_ctx = transform::PassContext::Current(); - bool target_support_bf16 = pass_ctx->GetConfig("tir.target_support_bf16", Bool(false)).value(); - bool target_support_fp8 = 
pass_ctx->GetConfig("tir.target_support_fp8", Bool(false)).value(); + bool target_support_bf16 = + pass_ctx->GetConfig("tir.target_support_bf16", Bool(false)).value(); + bool target_support_fp8 = + pass_ctx->GetConfig("tir.target_support_fp8", Bool(false)).value(); Array mixed_pass_list; From c33423ac397dd4acd19d6d37b3ef15c4c6d76270 Mon Sep 17 00:00:00 2001 From: Zihao Date: Thu, 29 Jun 2023 20:18:52 -0700 Subject: [PATCH 15/37] use ml_dtypes for llvm codegen test --- .../unittest/test_target_codegen_llvm.py | 41 +++---------------- 1 file changed, 6 insertions(+), 35 deletions(-) diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 6a2f5573b274..aa20c6078f97 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -695,33 +695,6 @@ def _transform(f, *_): tvm.testing.assert_allclose(c_.numpy(), (a_.numpy() * 2).astype("int32")) -def np_float2np_bf16(arr): - """Convert a numpy array of float to a numpy array - of bf16 in uint16""" - orig = arr.view(" Date: Fri, 30 Jun 2023 20:41:33 -0700 Subject: [PATCH 16/37] add ml_dtypes to ci-constraints --- docker/python/ci-constraints.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/docker/python/ci-constraints.txt b/docker/python/ci-constraints.txt index 003c13170411..b8f94c112a9c 100644 --- a/docker/python/ci-constraints.txt +++ b/docker/python/ci-constraints.txt @@ -37,3 +37,4 @@ tflite = "==2.4.0" torch = "==1.11.0" torchvision = "==0.12.0+cpu" #xgboost = "==1.4.2" +ml_dtypes = "==0.1.0" From 40f088ccf7559ce834bb0cfac22b5aec7b5cf3b6 Mon Sep 17 00:00:00 2001 From: Zihao Date: Fri, 30 Jun 2023 20:42:24 -0700 Subject: [PATCH 17/37] alphabetical --- docker/python/ci-constraints.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/python/ci-constraints.txt b/docker/python/ci-constraints.txt index b8f94c112a9c..7336c0821654 100644 --- a/docker/python/ci-constraints.txt +++ b/docker/python/ci-constraints.txt @@ -17,6 +17,7 @@ flowvision = "==0.1.0" #h5py = "==3.1.0" keras = "==2.7" jinja2 = "==3.0.3" +ml_dtypes = "==0.1.0" mxnet = "==1.6.0" mypy = "==0.902" oneflow = "==0.7.0" @@ -37,4 +38,3 @@ tflite = "==2.4.0" torch = "==1.11.0" torchvision = "==0.12.0+cpu" #xgboost = "==1.4.2" -ml_dtypes = "==0.1.0" From 259f27813b2e741c266fbd61f9f73a6fe89fb9eb Mon Sep 17 00:00:00 2001 From: Zihao Date: Fri, 30 Jun 2023 21:03:24 -0700 Subject: [PATCH 18/37] pylint --- include/tvm/runtime/data_type.h | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index f2c7e09df369..53465708354f 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -108,9 +108,7 @@ class DataType { /*! \return whether type is a bfloat16 type. */ bool is_bfloat16() const { return code() == DataType::kBFloat && bits() == 16; } /*! \return whether type is a general floating point data type. */ - bool is_floating_point() const { - return is_float() || is_float8() || is_bfloat16(); - } + bool is_floating_point() const { return is_float() || is_float8() || is_bfloat16(); } /*! \return whether type is an int type. */ bool is_int() const { return code() == DataType::kInt; } /*! \return whether type is an uint type. 
*/ From aff8f610ff618667f1d06c24e30d7e5fd637220a Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 1 Jul 2023 00:08:37 -0700 Subject: [PATCH 19/37] lint --- include/tvm/tir/op.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 93738e0d773c..fd1db1c7ba49 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -939,8 +939,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) return LargeUIntImm(t, static_cast(low), static_cast(high), span); } } - if (t.is_floating_point()) - return FloatImm(t, static_cast(value), span); + if (t.is_floating_point()) return FloatImm(t, static_cast(value), span); // For now, we store const scalar values of custom datatypes within doubles; later, during the // datatypes lowering pass, we will lower the value to its true representation in the format // specified by the datatype. From d883e336a1ad1dda3371cdfa2971ba87e6e82834 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 1 Jul 2023 02:04:59 -0700 Subject: [PATCH 20/37] upd --- src/target/source/codegen_cuda.cc | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 87a8c29ccd27..2ce48f1f1224 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -482,9 +482,21 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i, os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))"; } } else if (t.is_float16()) { - os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; + if (t.lanes() == 2) { + // 2 * float16 is stored as half2, use v.x directly + os << vec << "." << access[i]; + } else { + // 4/8 * float16 is stored as uint2/4, use ((half2*)(&(v.x))).x + os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; + } } else if (t.is_bfloat16()) { - os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; + if (t.lanes() == 2) { + // 2 * bfloat16 is stored as nv_bfloat162, use v.x directly + os << vec << "." << access[i]; + } else { + // 4/8 * bfloat16 is stored as uint2/4, use ((nv_bfloat162*)(&(v.x))).x + os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; + } } else if (t.lanes() > 4 && t.lanes() <= 8) { std::string type_name; if (t.bits() == 16) { From 487c15f40282abb5fe2ccb7b932ceb96166cf479 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 1 Jul 2023 02:17:13 -0700 Subject: [PATCH 21/37] improve comments --- src/target/source/codegen_cuda.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 2ce48f1f1224..96adb3992ef7 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -483,18 +483,18 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i, } } else if (t.is_float16()) { if (t.lanes() == 2) { - // 2 * float16 is stored as half2, use v.x directly + // 2 * float16 is stored as half2, return v.x/y directly os << vec << "." << access[i]; } else { - // 4/8 * float16 is stored as uint2/4, use ((half2*)(&(v.x))).x + // 4/8 * float16 is stored as uint2/4, return ((half2*)(&(v.x/y/z/w)))->x/y os << "((half2*)(&(" << vec << "." 
<< access[i / 2] << ")))->" << access[i % 2]; } } else if (t.is_bfloat16()) { if (t.lanes() == 2) { - // 2 * bfloat16 is stored as nv_bfloat162, use v.x directly + // 2 * bfloat16 is stored as nv_bfloat162, return v.x/y directly os << vec << "." << access[i]; } else { - // 4/8 * bfloat16 is stored as uint2/4, use ((nv_bfloat162*)(&(v.x))).x + // 4/8 * bfloat16 is stored as uint2/4, return ((nv_bfloat162*)(&(v.x/y/z/w)))->x/y os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; } } else if (t.lanes() > 4 && t.lanes() <= 8) { From 73d636141a8cdb8a2325655a5e665553548bf654 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 1 Jul 2023 02:33:29 -0700 Subject: [PATCH 22/37] improved code comment --- src/target/source/codegen_cuda.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 96adb3992ef7..4022a7aaf993 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -483,18 +483,18 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i, } } else if (t.is_float16()) { if (t.lanes() == 2) { - // 2 * float16 is stored as half2, return v.x/y directly + // vec (2 * float16) is stored as half2, return vec.x/y directly os << vec << "." << access[i]; } else { - // 4/8 * float16 is stored as uint2/4, return ((half2*)(&(v.x/y/z/w)))->x/y + // vec (4/8 * float16) is stored as uint2/4, return ((half2*)(&(vec.x/y/z/w)))->x/y os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; } } else if (t.is_bfloat16()) { if (t.lanes() == 2) { - // 2 * bfloat16 is stored as nv_bfloat162, return v.x/y directly + // vec (2 * bfloat16) is stored as nv_bfloat162, return vec.x/y directly os << vec << "." << access[i]; } else { - // 4/8 * bfloat16 is stored as uint2/4, return ((nv_bfloat162*)(&(v.x/y/z/w)))->x/y + // vec (4/8 * bfloat16) is stored as uint2/4, return ((nv_bfloat162*)(&(vec.x/y/z/w)))->x/y os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; } } else if (t.lanes() > 4 && t.lanes() <= 8) { From 0d7d8ba52807e1f41cf93105f0109ac4f644a4d2 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 1 Jul 2023 12:09:21 -0700 Subject: [PATCH 23/37] upd --- python/tvm/testing/utils.py | 20 +++-- src/driver/driver_api.cc | 29 ++----- .../multi_level_tiling_tensor_core.cc | 2 + src/target/source/codegen_cuda.cc | 82 ++++++++++++------- src/target/tag.cc | 8 +- src/tir/transforms/dtype_conversion.h | 2 +- .../transforms/unsupported_dtype_legalize.cc | 63 ++++++++++++++ .../unittest/test_target_codegen_cuda.py | 82 ++++++++++--------- 8 files changed, 189 insertions(+), 99 deletions(-) diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 92923a2ed290..f50790d5f1dc 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -106,6 +106,17 @@ def test_something(): ) +def promote_bf16_to_fp32(x): + r"""Promote the data type of an array-like structure from bfloat16 to float32.""" + if isinstance(x, list): + return [promote_bf16_to_fp32(y) for y in x] + else: + if isinstance(x, np.ndarray) and x.dtype == "bfloat16": + return x.astype("float32") + else: + return x + + def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7): """Version of np.testing.assert_allclose with `atol` and `rtol` fields set in reasonable defaults. 
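The promotion helper above exists because ml_dtypes bfloat16 arrays cannot be handed to np.testing.assert_allclose directly. A minimal sketch of the intended usage, assuming ml_dtypes is installed (the shape and tolerances below are illustrative only):

    import numpy as np
    import ml_dtypes
    import tvm.testing

    a = np.random.uniform(size=(64,)).astype(ml_dtypes.bfloat16)
    b = a.astype("float32").astype(ml_dtypes.bfloat16)
    # Both arguments are promoted to float32 internally before the numpy comparison runs.
    tvm.testing.assert_allclose(a, b, rtol=1e-2, atol=1e-2)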
@@ -114,11 +125,10 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7): compares the `abs(actual-desired)` with `atol+rtol*abs(desired)`. Since we often allow `desired` to be close to zero, we generally want non-zero `atol`. """ - # The ml_dtypes v0.2 is not compatible with allclose function, convert to float32. - if actual.dtype == "bfloat16": - actual = actual.astype("float32") - if desired.dtype == "bfloat16": - desired = desired.astype("float32") + # The ml_dtypes v0.2 is not compatible with numpy's asanyarray function, promote to float32 first. + actual = promote_bf16_to_fp32(actual) + desired = promote_bf16_to_fp32(desired) + actual = np.asanyarray(actual) desired = np.asanyarray(desired) np.testing.assert_allclose(actual.shape, desired.shape) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 41a6dfa22e9e..f5eeda7243d7 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -55,8 +55,6 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer); TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_ldg32", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.target_support_bf16", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.target_support_fp8", Bool); // WARNING: May cause coherency issues resulting data miscompares // Experimental feature that, when enabled by the runtime, bypasses the cache when using DMA. When @@ -154,10 +152,6 @@ Array CreatePassList(bool disable_loop_partition) { bool enable_equiv_terms_in_cse_tir = pass_ctx->GetConfig("tir.enable_equiv_terms_in_cse_tir", Bool(false)).value(); bool ptx_ldg32 = pass_ctx->GetConfig("tir.ptx_ldg32", Bool(false)).value(); - bool target_support_bf16 = - pass_ctx->GetConfig("tir.target_support_bf16", Bool(false)).value(); - bool target_support_fp8 = - pass_ctx->GetConfig("tir.target_support_fp8", Bool(false)).value(); // Get any user-added passes Array> add_lower_pass = @@ -219,12 +213,6 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::TransformMmaBufferLayout()); pass_list.push_back(tir::transform::LowerOpaqueBlock()); pass_list.push_back(tir::transform::FlattenBuffer()); - if (!target_support_fp8) { - pass_list.push_back(tir::transform::FP8ComputeLegalize()); - } - if (!target_support_bf16) { - pass_list.push_back(tir::transform::BF16ComputeLegalize()); - } pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); @@ -553,10 +541,6 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg, transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) { transform::PassContext pass_ctx = transform::PassContext::Current(); - bool target_support_bf16 = - pass_ctx->GetConfig("tir.target_support_bf16", Bool(false)).value(); - bool target_support_fp8 = - pass_ctx->GetConfig("tir.target_support_fp8", Bool(false)).value(); Array mixed_pass_list; @@ -567,6 +551,10 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::BindTarget(target)); + // FP8/BF16 ComputeLegalize passes (after bind target) + mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize()); + mixed_pass_list.push_back(tir::transform::BF16ComputeLegalize()); + mixed_pass_list.push_back(tir::transform::VerifyMemory()); mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc()); @@ -607,13 +595,8 @@ transform::Sequential 
MixedModulePassManager(IRModule mixed_mod, Target target) } else { mixed_pass_list.push_back(tir::transform::MakePackedAPI()); } - if (!target_support_fp8) { - mixed_pass_list.push_back(tir::transform::FP8StorageLegalize()); - } - if (!target_support_bf16) { - mixed_pass_list.push_back(tir::transform::BF16StorageLegalize()); - } - + mixed_pass_list.push_back(tir::transform::FP8StorageLegalize()); + mixed_pass_list.push_back(tir::transform::BF16StorageLegalize()); mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch()); return transform::Sequential(mixed_pass_list); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index 18bd58510d85..5f53e8097eb4 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -605,6 +605,8 @@ std::vector MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore( const DataType& dtype = cache_read_buffer->dtype; if (dtype.is_float16()) { sch->StorageAlign(cache_read, 0, -2, 32, 8); + } else if (dtype.is_bfloat16()) { + sch->StorageAlign(cache_read, 0, -2, 32, 8); } else if (dtype.is_int() && dtype.bits() == 8) { sch->StorageAlign(cache_read, 0, -2, 32, 16); } else { diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 4022a7aaf993..fa3406f0f915 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -626,12 +626,18 @@ std::string CodeGenCUDA::CastFromTo(std::string value, DataType from, DataType t os << "(("; this->PrintType(target, os); os << ")"; - if (from.is_float16() && (target.is_int() || target.is_uint()) && target.bits() == 8) { + if ((from.is_float16() || from.is_bfloat16()) && (target.is_int() || target.is_uint()) && + target.bits() == 8) { + // use int/uint as intermediate data type os << "("; if (target.is_uint()) { os << "u"; } os << "int)"; + } else if ((from.is_bfloat16() && target.is_float16()) || + (from.is_float16() && target.is_bfloat16())) { + // use float as intermediate data type + os << "(float)"; } os << value << ")"; return os.str(); @@ -655,12 +661,9 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { std::string src = SSAGetID(PrintExpr(op->value), from_ty); for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) { std::ostringstream val; - val << "("; - PrintType(target_ty.element_of(), val); - val << ")("; PrintVecElemLoad(src, from_ty, i, val); - val << ")"; - PrintVecElemStore(sret, target_ty, i, val.str()); + std::string casted_val = CastFromTo(val.str(), from_ty.element_of(), target_ty.element_of()); + PrintVecElemStore(sret, target_ty, i, casted_val); } } os << sret; @@ -1146,7 +1149,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO if (op->dtype.is_bfloat16()) { std::string v = PrintExpr(op->value); if (op->lanes == 2) { - os << "__halves2bfloat162(" << v << ", " << v << ")"; + os << "make_bfloat162" << v << ", " << v << ")"; } else { os << "make_"; PrintType(op->dtype, os); @@ -1387,6 +1390,7 @@ void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const BufferLoad // Cast away volatile qualifier for fp16 types. That is, only loads and // stores are volatile. The loaded objects are not marked as volatile. 
// + // TODO(Zihao): figure out what it is if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) && IsVolatile(op->buffer->data.get())) { os << "("; PrintType(op->dtype, os); @@ -1410,38 +1414,58 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val } if (t.is_float16()) { - if (i == 0) { - os << "make_"; - PrintType(t, os); - os << '('; - } - if (i % 2 == 0) { - os << "__pack_half2(" << value; + if (t.lanes() == 2) { + // result data type is half2 + if (i == 0) { + os << "make_half2(" << value; + } else { + os << ", " << value << ")"; + } } else { - os << "," << value << ")"; - if (i != t.lanes() - 1) { - os << ","; + // result data type is uint2/4 + if (i == 0) { + os << "make_"; + PrintType(t, os); + os << '('; + } + if (i % 2 == 0) { + os << "__pack_half2(" << value; } else { - os << ")"; + os << "," << value << ")"; + if (i != t.lanes() - 1) { + os << ","; + } else { + os << ")"; + } } } return; } if (t.is_bfloat16()) { - if (i == 0) { - os << "make_"; - PrintType(t, os); - os << '('; - } - if (i % 2 == 0) { - os << "__pack_nv_bfloat162(" << value; + if (t.lanes() == 2) { + // result data type is nv_bfloat162 + if (i == 0) { + os << "make_bfloat162" << value; + } else { + os << ", " << value << ")"; + } } else { - os << "," << value << ")"; - if (i != t.lanes() - 1) { - os << ","; + // result data type is uint2/4 + if (i == 0) { + os << "make_"; + PrintType(t, os); + os << '('; + } + if (i % 2 == 0) { + os << "__pack_nv_bfloat162(" << value; } else { - os << ")"; + os << "," << value << ")"; + if (i != t.lanes() - 1) { + os << ","; + } else { + os << ")"; + } } } return; diff --git a/src/target/tag.cc b/src/target/tag.cc index 037d2e5937ca..3188c0169a3e 100644 --- a/src/target/tag.cc +++ b/src/target/tag.cc @@ -108,8 +108,8 @@ TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-xavier") // Naming convention for CUDA tags see https://developer.nvidia.com/cuda-gpus // Parameters see Table 15. 
Technical Specifications per Compute Capability // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html -// Check `Maximum y- or z-dimension of a grid of thread blocks` for max threads per block // Check `Maximum amount of shared memory per thread block` for max shared memory per block +// Check `Maximum number of 32-bit registers per thread block` for max registers per block // Note that above 48 KB requires dynamic shared memory TVM_REGISTER_CUDA_TAG("nvidia/tesla-k80", "sm_37", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/tesla-k40", "sm_35", 49152, 65536); @@ -219,6 +219,12 @@ TVM_REGISTER_CUDA_TAG("nvidia/nvidia-nvs-310", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvs-5400m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvs-5200m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvs-4200m", "sm_21", 49152, 32768); +TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-4090", "sm_89", 49152, 65536); +TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-4080", "sm_89", 49152, 65536); +TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-4070-ti", "sm_89", 49152, 65536); +TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-4070", "sm_89", 49152, 65536); +TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-4060-ti", "sm_89", 49152, 65536); +TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-4060", "sm_89", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3090-ti", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3090", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3080-ti", "sm_86", 49152, 65536); diff --git a/src/tir/transforms/dtype_conversion.h b/src/tir/transforms/dtype_conversion.h index b509abb9cd27..80a65b900520 100644 --- a/src/tir/transforms/dtype_conversion.h +++ b/src/tir/transforms/dtype_conversion.h @@ -99,7 +99,7 @@ class FloatConfig { * \return The FloatConfig class containing internal floating point representation. */ static FloatConfig FromDataType(DataType dtype) { - CHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8()) + CHECK(dtype.is_floating_point()) << "FloatConfig is only applicable to floating point data types, got " << dtype << " instead."; if (dtype.is_float()) { diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index a92ce020c94d..70f15d31c415 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -30,11 +30,54 @@ #include #include +#include "../../support/utils.h" #include "dtype_conversion.h" namespace tvm { namespace tir { +namespace { + +/*! + * \brief Return whether target has native BF16 support. + */ +bool TargetSupportBF16(const Target& target) { + if (target->kind->name == "cuda") { + auto cuda_arch = target->GetAttr("arch"); + if (cuda_arch.defined()) { + auto cuda_arch_str = cuda_arch.value(); + CHECK(support::StartsWith(cuda_arch_str, "sm_")) + << "Expect cuda arch to start with \"sm_\", got " << cuda_arch_str << " instead."; + if (std::stoi(std::string(cuda_arch_str).substr(3)) >= 80) { + // sm_80 or later + return true; + } + } + } + return false; +} + +/*! + * \brief Return whether target has native FP8 support. 
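+ * For CUDA this means an arch of sm_89 or later.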
+ */ +bool TargetSupportFP8(const Target& target) { + if (target->kind->name == "cuda") { + auto cuda_arch = target->GetAttr("arch"); + if (cuda_arch.defined()) { + auto cuda_arch_str = cuda_arch.value(); + CHECK(support::StartsWith(cuda_arch_str, "sm_")) + << "Expect cuda arch to start with \"sm_\", got " << cuda_arch_str << " instead."; + if (std::stoi(std::string(cuda_arch_str).substr(3)) >= 89) { + // sm_89 or later + return true; + } + } + } + return false; +} + +} // namespace + // NOTE: do not touch buffer on function boundary // remap internal fp8/bf16 buffer to f32 if they meet the following condition // - constant allocation size @@ -686,6 +729,11 @@ namespace transform { Pass BF16ComputeLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto target = f->GetAttr(tvm::attr::kTarget); + if (target.defined() && TargetSupportBF16(target.value())) { + // Do not legalize bf16 compute if target has native bf16 support. + return f; + } return BF16ComputeLegalizer().Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.BF16ComputeLegalize", {}); @@ -695,6 +743,11 @@ TVM_REGISTER_GLOBAL("tir.transform.BF16ComputeLegalize").set_body_typed(BF16Comp Pass BF16StorageLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto target = f->GetAttr(tvm::attr::kTarget); + if (target.defined() && TargetSupportBF16(target.value())) { + // Do not legalize bf16 storage if target has native bf16 support. + return f; + } return BF16StorageLegalizer().Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.BF16StorageLegalize", {}); @@ -704,6 +757,11 @@ TVM_REGISTER_GLOBAL("tir.transform.BF16StorageLegalize").set_body_typed(BF16Stor Pass FP8ComputeLegalize(String promote_dtype_str) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto target = f->GetAttr(tvm::attr::kTarget); + if (target.defined() && TargetSupportFP8(target.value())) { + // Do not legalize fp8 compute if target has native fp8 support. + return f; + } return FP8ComputeLegalizer(DataType(String2DLDataType(promote_dtype_str))).Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.FP8ComputeLegalize", {}); @@ -713,6 +771,11 @@ TVM_REGISTER_GLOBAL("tir.transform.FP8ComputeLegalize").set_body_typed(FP8Comput Pass FP8StorageLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto target = f->GetAttr(tvm::attr::kTarget); + if (target.defined() && TargetSupportFP8(target.value())) { + // Do not legalize fp8 storage if target has native fp8 support. 
+ return f; + } return FP8StorageLegalizer().Legalize(f); }; return CreatePrimFuncPass(pass_func, 0, "tir.FP8StorageLegalize", {}); diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index 90bdb0d75131..4b34d20250e1 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -24,11 +24,7 @@ from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16 import tvm.testing import pytest - -try: - import ml_dtypes -except ImportError as e: - ml_dtypes = None +import ml_dtypes tx = te.thread_axis("threadIdx.x") bx = te.thread_axis("blockIdx.x") @@ -43,6 +39,9 @@ def check_cuda(dtype, n, lanes): if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version): print("Skip because gpu does not have fp16 support") return + if dtype == "bfloat16" and not have_bf16(tvm.cuda(0).compute_version): + print("skip because gpu does not support bf16") + return if dtype == "int8" and not have_int8(tvm.cuda(0).compute_version): print("skip because gpu does not support int8") return @@ -72,40 +71,10 @@ def check_cuda(dtype, n, lanes): check_cuda("float16", 64, 4) check_cuda("float16", 64, 6) check_cuda("float16", 64, 8) - - -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda -def test_cuda_bf16_vectorize_add(): - if not have_bf16(tvm.cuda(0).compute_version): - print("skip because gpu does not support bf16") - return - if ml_dtypes is None: - print("skip because ml_dtypes not installed") - return - num_thread = 8 - - def check_cuda(n, lanes): - A = te.placeholder((n,), name="A", dtype="bfloat16x%d" % lanes) - B = te.compute((n,), lambda i: A[i] + tvm.tir.const(1, A.dtype), name="B") - s = te.create_schedule(B.op) - xo, xi = s[B].split(B.op.axis[0], factor=num_thread) - s[B].bind(xo, bx) - s[B].bind(xi, tx) - with tvm.transform.PassContext(config={"tir.target_support_bf16": True}): - fun = tvm.build(s, [A, B], "cuda") - dev = tvm.cuda(0) - np_a = np.random.uniform(size=(n, lanes)).astype("bfloat16") - a = tvm.nd.empty((n,), A.dtype, dev).copyfrom(np_a) - c = tvm.nd.empty((n,), B.dtype, dev) - fun(a, c) - c = tvm.nd.empty((n, lanes), "bfloat16", dev).copyfrom(c) - tvm.testing.assert_allclose(c.numpy(), np_a + 1) - - check_cuda(64, 2) - check_cuda(64, 4) - check_cuda(64, 6) - check_cuda(64, 8) + check_cuda("bfloat16", 64, 2) + check_cuda("bfloat16", 64, 4) + check_cuda("bfloat16", 64, 6) + check_cuda("bfloat16", 64, 8) @tvm.testing.requires_gpu @@ -468,11 +437,13 @@ def check(device, dtype, m=32, n=32): b_nd = tvm.nd.array(b_np, dev) g_nd = tvm.nd.array(np.zeros(g_np.shape, dtype=g_np.dtype), dev) func(a_nd, b_nd, g_nd) + print(g_nd, g_np) tvm.testing.assert_allclose(g_nd.numpy(), g_np, rtol=1e-3) check("cuda", "float32") check("rocm", "float32") check("cuda", "float16") + # check("cuda", "bfloat16") ignored because of rounding errors @tvm.testing.requires_gpu @@ -486,6 +457,9 @@ def check(device, dtype, m=32, n=32): if dtype == "float16" and not have_fp16(dev.compute_version): print("Skip because gpu does not have fp16 support") return + if dtype == "bfloat16" and not have_bf16(dev.compute_version): + print("skip because gpu does not support bf16") + return a = tvm.te.placeholder((m, n), name="a", dtype=dtype) b = topi.sum(a) @@ -504,6 +478,7 @@ def check(device, dtype, m=32, n=32): check("cuda", "float32") check("rocm", "float32") check("cuda", "float16") + # check("cuda", "bfloat16") ignored because of rounding errors @tvm.testing.requires_gpu @@ -565,6 +540,9 @@ def check(t0, t1, 
factor): if (t0 == "float16" or t1 == "float16") and not have_fp16(tvm.cuda(0).compute_version): print("Skip because gpu does not have fp16 support") return + if (t0 == "bfloat16" or t1 == "bfloat16") and not have_bf16(tvm.cuda(0).compute_version): + print("skip because gpu does not support bf16") + return # compute n = 128 @@ -577,6 +555,7 @@ def check(t0, t1, factor): ob, ib = s[C].split(s[C].op.axis[0], factor=factor) s[C].vectorize(ib) s[C].bind(ob, tx) + func = tvm.build(s, [A, B, C], "cuda") # correctness @@ -603,6 +582,7 @@ def skip(t0, t1): types_4 = [ "float16", "float32", + "bfloat16", "int8", "uint8", "int16", @@ -613,7 +593,17 @@ def skip(t0, t1): "int64", "uint64", ] - types_8 = ["float16", "float32", "int8", "uint8", "int16", "uint16", "int32", "uint32"] + types_8 = [ + "float16", + "float32", + "bfloat16", + "int8", + "uint8", + "int16", + "uint16", + "int32", + "uint32", + ] for t0, t1 in [(x, y) for x in types_4 for y in types_4 if not skip(x, y)]: check(t0, t1, 4) for t0, t1 in [(x, y) for x in types_8 for y in types_8 if not skip(x, y)]: @@ -662,6 +652,9 @@ def run_test(tvm_intrin, np_func, dtype): if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version): print("Skip because gpu does not have fp16 support") return + if dtype == "bfloat16" and not have_bf16(tvm.cuda(0).compute_version): + print("skip because gpu does not support bf16") + return # set of intrinsics does not support fp16 yet. skip_set = { tvm.tir.abs, @@ -675,6 +668,9 @@ def run_test(tvm_intrin, np_func, dtype): if dtype == "float16" and tvm_intrin in skip_set: print("Skip because '{0}' does not support fp16 yet".format(tvm_intrin.__name__)) return + if dtype == "bfloat16" and tvm_intrin in skip_set: + print("Skip because '{0}' does not support bf16 yet".format(tvm_intrin.__name__)) + return n = 128 A = te.placeholder((n,), dtype=dtype, name="A") @@ -684,12 +680,14 @@ def run_test(tvm_intrin, np_func, dtype): dev = tvm.cuda(0) a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), dev) b = tvm.nd.array(np.zeros(shape=(n,)).astype(A.dtype), dev) + f(a, b) tvm.testing.assert_allclose(b.numpy(), np_func(a.numpy()), atol=1e-3, rtol=1e-3) for func in test_funcs: run_test(*func, "float32") run_test(*func, "float16") + run_test(*func, "bfloat16") @tvm.testing.requires_gpu @@ -751,6 +749,9 @@ def check_cuda(dtype, n, l, padding, lanes): if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version): print("Skip because gpu does not have fp16 support") return + if dtype == "bfloat16" and not have_bf16(tvm.cuda(0).compute_version): + print("skip because gpu does not support bf16") + return dev = tvm.cuda(0) A = tvm.te.placeholder((n, l), name="A", dtype=dtype) @@ -786,6 +787,7 @@ def check_cuda(dtype, n, l, padding, lanes): check_cuda("int32", 64, 16, 3, 4) check_cuda("float16", 64, 16, 3, 4) check_cuda("float32", 64, 16, 3, 4) + check_cuda("bfloat16", 64, 16, 3, 4) def vcf_check_common(s, args): From 4f4e8a6b5d82afa69e505bf029ac2774e3082da7 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 1 Jul 2023 12:11:08 -0700 Subject: [PATCH 24/37] bugfix --- src/target/source/codegen_cuda.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index fa3406f0f915..63b3572e3939 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -1149,7 +1149,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO if (op->dtype.is_bfloat16()) { std::string v = 
PrintExpr(op->value); if (op->lanes == 2) { - os << "make_bfloat162" << v << ", " << v << ")"; + os << "make_bfloat162(" << v << ", " << v << ")"; } else { os << "make_"; PrintType(op->dtype, os); @@ -1446,7 +1446,7 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val if (t.lanes() == 2) { // result data type is nv_bfloat162 if (i == 0) { - os << "make_bfloat162" << value; + os << "make_bfloat162(" << value; } else { os << ", " << value << ")"; } From 2162ebaa1b8d2519d6fec81269c450dcf078dffb Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 1 Jul 2023 19:18:29 -0700 Subject: [PATCH 25/37] bugfix --- src/tir/transforms/unsupported_dtype_legalize.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index 70f15d31c415..cd942c5ba68e 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -366,7 +366,8 @@ class ComputeLegalizer : public StmtExprMutator { DataType legalized_dtype = new_buf->dtype.with_lanes(index_lanes * buffer_lanes); value = CastTargetToDType(value, legalized_dtype); } - if (value.dtype() != new_buf->dtype) { + + if (value.dtype().element_of() != new_buf->dtype.element_of()) { // this happens when buffer get rewritten to f32 // but values remain as fp8/bf16 ICHECK(MatchDType(value->dtype)); From 31355fa7144d702c7a2b400ae75a568a6f414cd1 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 1 Jul 2023 19:20:51 -0700 Subject: [PATCH 26/37] lint --- python/tvm/testing/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index f50790d5f1dc..b4a5ada1f0ff 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -125,7 +125,8 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7): compares the `abs(actual-desired)` with `atol+rtol*abs(desired)`. Since we often allow `desired` to be close to zero, we generally want non-zero `atol`. """ - # The ml_dtypes v0.2 is not compatible with numpy's asanyarray function, promote to float32 first. + # The ml_dtypes v0.2 is not compatible with numpy's asanyarray function, promote to + # float32 first. actual = promote_bf16_to_fp32(actual) desired = promote_bf16_to_fp32(desired) From daaae7107cefb7ca35581561c348c6954c37b9c6 Mon Sep 17 00:00:00 2001 From: Zihao Date: Mon, 3 Jul 2023 01:56:55 -0700 Subject: [PATCH 27/37] refactor buildprocess --- include/tvm/driver/driver_api.h | 11 +++++- python/tvm/driver/build_module.py | 64 +++++++++++++++++++++++++------ src/driver/driver_api.cc | 44 ++++++++++++++------- 3 files changed, 92 insertions(+), 27 deletions(-) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index fffcab49667c..ce498af4ce3a 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -50,11 +50,14 @@ using tvm::transform::Pass; /*! * \brief Configures and returns the composite Pass for the fused module (pre split) that contains * device and host code. + * * \param mixed_mod The original mixed module. * \param target The device Target. + * \param apply_lower_passes Whether to apply lowering passes. * \return The composite Pass for the fused module. // */ -TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target); +TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target, + bool apply_lower_passes); /*! 
* \brief Configures and returns the composite Pass for the device Target after device/host from @@ -140,6 +143,12 @@ TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array& args, IRModule ScheduleToModule(te::Schedule sch, const Array& args, const std::string& name, const std::unordered_map& binds, GlobalVarSupply global_var_supply); + +TVM_DLL runtime::Module IRModuleToRuntimeModule(const Map& inputs_arg, + const Target& target_host_arg, bool apply_lower_passes) { + +} + /*! * \brief Build a device and host module for a specific target from an IRModule. * \param funcs The functions to be built. diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 9389e7fbee60..4d7dc609147d 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -71,8 +71,8 @@ def schedule_to_module( """According to the given schedule, form a function. This is a low-level function intended for testing purposes, and - does not apply any optimization passes. In general, `tvm.lower` - and `tvm.build` should be used instead. + does not apply any optimization passes. In general, `tvm.build` + should be used instead. Parameters ---------- @@ -91,6 +91,47 @@ def schedule_to_module( return ffi.schedule_to_module(sch, args, name, binds) +def as_ir_module( + inp: Union[te.Schedule, PrimFunc, IRModule], + args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, + name: str = "main", + binds: Optional[Mapping[tensor.Tensor, Buffer]] = None, +) -> IRModule: + """Convert input to IRModule. + + Parameters + ---------- + inp : Union[tvm.te.schedule.Schedule, tvm.tir.PrimFunc, IRModule] + The input TE schedule or TensorIR PrimFunc/IRModule. + + args : Optional[List[Union[tvm.tir.Buffer, tensor.Tensor, Var]]] + The argument lists to the function for TE schedule. + It should be None if :attr:`inp` is a TensorIR PrimFunc/IRModule. + + name : str + The name of the result function. + + binds : Optional[Mapping[tensor.Tensor, tvm.tir.Buffer]] + Dictionary that maps the Tensor to Buffer which specified the data layout + requirement of the function. By default, a new compact buffer is created + for each tensor in the argument. + + Returns + ------- + m : IRModule + The result IRModule. + """ + if isinstance(inp, IRModule): + return inp + if isinstance(inp, PrimFunc): + return IRModule({name: inp.with_attr("global_symbol", name)}) + if isinstance(inp, te.Schedule): + return schedule_to_module(inp, args, name, binds) + raise ValueError( + f"Expected input to be an IRModule, PrimFunc or te.Schedule, but got {type(inp)}" + ) + + def lower( inp: Union[te.Schedule, PrimFunc, IRModule], args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, @@ -100,6 +141,8 @@ def lower( ) -> IRModule: """Lowering step before build into target. + Warning(legacy): This function is maintained for backward compatibility, please use :func:`build` directly. + Parameters ---------- inp : Union[tvm.te.schedule.Schedule, tvm.tir.PrimFunc, IRModule] @@ -199,8 +242,7 @@ def build( B = te.placeholder((n,), name='B') C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') s = tvm.te.create_schedule(C.op) - m = tvm.lower(s, [A, B, C], name="test_add") - rt_mod = tvm.build(m, target="llvm") + rt_mod = tvm.build(s, target="llvm") 2. it is a dict of compilation target to IRModule. 
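A minimal sketch of how the reworked flow is driven from Python, assuming as_ir_module is imported from tvm.driver.build_module (tensor names and shapes are illustrative only):

    import tvm
    from tvm import te
    from tvm.driver.build_module import as_ir_module

    n = 128
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)
    # as_ir_module only converts; all lowering passes now run inside tvm.build().
    mod = as_ir_module(s, [A, B], name="vector_add")
    rt_mod = tvm.build(mod, target="llvm")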
@@ -213,9 +255,7 @@ def build( s1 = tvm.te.create_schedule(C.op) with tvm.target.cuda() as cuda_tgt: s2 = topi.cuda.schedule_injective(cuda_tgt, [C]) - m1 = tvm.lower(s1, [A, B, C], name="test_add1") - m2 = tvm.lower(s2, [A, B, C], name="test_add2") - rt_mod = tvm.build({"llvm": m1, "cuda": m2}) + rt_mod = tvm.build({"llvm": s1, "cuda": s2}) Note ---- @@ -224,16 +264,16 @@ def build( if isinstance(inputs, te.Schedule): if args is None: raise ValueError("args must be given for build from schedule") - input_mod = lower(inputs, args, name=name, binds=binds) + input_mod = as_ir_module(inputs, args, name=name, binds=binds) elif isinstance(inputs, (list, tuple, container.Array)): merged_mod = tvm.IRModule({}) for x in inputs: - merged_mod.update(lower(x)) + merged_mod.update(as_ir_module(x)) input_mod = merged_mod elif isinstance(inputs, PrimFunc): - input_mod = lower(inputs, name=name) + input_mod = as_ir_module(inputs, name=name) elif isinstance(inputs, tvm.IRModule): - input_mod = lower(inputs) + input_mod = as_ir_module(inputs) elif not isinstance(inputs, (dict, container.Map)): raise ValueError( f"Inputs must be te.Schedule, IRModule, PrimFunc, " @@ -278,7 +318,7 @@ def build( annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) - rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host) + rt_mod_host = _driver_ffi.ir_module_to_runtime_module(annotated_mods, target_host, True) annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index f5eeda7243d7..9b932fd12ca9 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -213,6 +213,8 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::TransformMmaBufferLayout()); pass_list.push_back(tir::transform::LowerOpaqueBlock()); pass_list.push_back(tir::transform::FlattenBuffer()); + pass_list.push_back(tir::transform::FP8ComputeLegalize()); + pass_list.push_back(tir::transform::BF16ComputeLegalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); @@ -402,13 +404,14 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule") * device and host. Then it also applies transformations on the new splitted modules. 
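+ * \param apply_lower_passes Forwarded to MixedModulePassManager to control whether the lowering pass list runs before the split.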
*/ std::pair SplitMixedModule(IRModule mod_mixed, const Target& target_arg, - const Target& target_host_arg) { + const Target& target_host_arg, + bool apply_lower_passes) { Target target = target_arg, target_host = target_host_arg; CheckAndUpdateHostConsistency(&target, &target_host); ICHECK(mod_mixed.defined()) << "This module must be defined"; - mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target)); + mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target, apply_lower_passes)); IRModule host_mod = ApplyPasses(mod_mixed, HostModulePassManager(mod_mixed, target_host)); @@ -427,8 +430,8 @@ std::pair SplitMixedModule(IRModule mod_mixed, const Target& return {host_mod, device_mod}; } -runtime::Module TIRToRuntime(const Map& inputs_arg, - const Target& target_host_arg) { +runtime::Module IRModuleToRuntimeModule(const Map& inputs_arg, + const Target& target_host_arg, bool apply_lower_passes) { std::vector device_modules; Map inputs = inputs_arg; Target target_host = target_host_arg; @@ -464,7 +467,7 @@ runtime::Module TIRToRuntime(const Map& inputs_arg, if (it.second.defined()) { const Target& target = it.first; const IRModule& ir_module = it.second; - auto pair = SplitMixedModule(ir_module, target, target_host); + auto pair = SplitMixedModule(ir_module, target, target_host, apply_lower_passes); auto& host_mod = pair.first; auto& device_mod = pair.second; @@ -501,6 +504,17 @@ runtime::Module TIRToRuntime(const Map& inputs_arg, return mhost; } +TVM_REGISTER_GLOBAL("driver.ir_module_to_runtime_module") + .set_body_typed([](const Map& inputs_arg, Target host_target, + bool apply_lower_passes) { + return IRModuleToRuntimeModule(inputs_arg, host_target, apply_lower_passes); + }); + +runtime::Module TIRToRuntime(const Map& inputs_arg, + const Target& target_host_arg) { + return IRModuleToRuntimeModule(inputs_arg, target_host_arg, false); +} + TVM_REGISTER_GLOBAL("driver.tir_to_runtime") .set_body_typed([](const Map& inputs_arg, Target host_target) { return TIRToRuntime(inputs_arg, host_target); @@ -539,22 +553,24 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg, return TIRToRuntime(inputs, target_host); } -transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) { +transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target, + bool apply_te_passes = false) { transform::PassContext pass_ctx = transform::PassContext::Current(); Array mixed_pass_list; + mixed_pass_list.push_back(tir::transform::BindTarget(target)); + if (apply_te_passes) { + for (auto&& pass : CreatePassList(false)) { + mixed_pass_list.push_back(pass); + } + } + // VerifyVTCMLimit must occur before LowerVtcmAlloc mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(target)); // LowerVtcmAlloc must occur after any transformations that modify memory allocation locations mixed_pass_list.push_back(tir::transform::LowerVtcmAlloc()); - mixed_pass_list.push_back(tir::transform::BindTarget(target)); - - // FP8/BF16 ComputeLegalize passes (after bind target) - mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize()); - mixed_pass_list.push_back(tir::transform::BF16ComputeLegalize()); - mixed_pass_list.push_back(tir::transform::VerifyMemory()); mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc()); @@ -603,8 +619,8 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) } TVM_REGISTER_GLOBAL("driver.mixed_mod_passes") - .set_body_typed([](IRModule mixed_mod, Target target) { - return 
MixedModulePassManager(mixed_mod, target); + .set_body_typed([](IRModule mixed_mod, Target target, bool apply_lower_passes) { + return MixedModulePassManager(mixed_mod, target, apply_lower_passes); }); transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host) { From af72e1e76730046fb7c035f2ba66163f21e94e4f Mon Sep 17 00:00:00 2001 From: Zihao Date: Mon, 3 Jul 2023 02:14:02 -0700 Subject: [PATCH 28/37] remove unused functions --- include/tvm/driver/driver_api.h | 5 ----- 1 file changed, 5 deletions(-) diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index ce498af4ce3a..9c582e26badd 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -144,11 +144,6 @@ IRModule ScheduleToModule(te::Schedule sch, const Array& args, const const std::unordered_map& binds, GlobalVarSupply global_var_supply); -TVM_DLL runtime::Module IRModuleToRuntimeModule(const Map& inputs_arg, - const Target& target_host_arg, bool apply_lower_passes) { - -} - /*! * \brief Build a device and host module for a specific target from an IRModule. * \param funcs The functions to be built. From 98cb6e4829628d6f199ca1ec952da7eec2b8843d Mon Sep 17 00:00:00 2001 From: Zihao Date: Mon, 3 Jul 2023 04:48:51 -0700 Subject: [PATCH 29/37] pylint --- python/tvm/driver/build_module.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 4d7dc609147d..b92c6f15d497 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -141,7 +141,8 @@ def lower( ) -> IRModule: """Lowering step before build into target. - Warning(legacy): This function is maintained for backward compatibility, please use :func:`build` directly. + (Warning) This function is obsolete and maintained for backward compatibility with + legacy TE Schedule, please use :func:`build` directly. 
Parameters ---------- From aae82088b6d06fd409687f4c22caeb0a50543f5f Mon Sep 17 00:00:00 2001 From: Zihao Date: Mon, 3 Jul 2023 06:39:59 -0700 Subject: [PATCH 30/37] import error --- python/tvm/runtime/ndarray.py | 5 +---- tests/python/unittest/test_target_codegen_llvm.py | 1 + 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index 6857c5e2d958..84b2026f2a6a 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -19,11 +19,8 @@ import ctypes import warnings import numpy as np +import ml_dtypes -try: - import ml_dtypes -except ImportError: - ml_dtypes = None import tvm._ffi from tvm._ffi.base import _LIB, check_call, c_array, string_types, _FFI_MODE diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index aa20c6078f97..c1a78759faa5 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -22,6 +22,7 @@ import pytest import re import sys +import ml_dtypes import tvm import tvm.testing From 0a4a67af6a5825dd3302e7a2f39216ef697f4feb Mon Sep 17 00:00:00 2001 From: Zihao Date: Fri, 30 Jun 2023 06:30:23 -0700 Subject: [PATCH 31/37] upd --- src/target/source/codegen_c.cc | 7 +++ src/target/source/codegen_c.h | 2 + src/target/source/codegen_cuda.cc | 69 ++++++++++++++---------- src/target/source/codegen_cuda.h | 2 + src/target/source/intrin_rule_cuda.cc | 16 ++++++ src/target/source/literal/cuda_half_t.h | 71 ++++++++++++++++++++++++- 6 files changed, 139 insertions(+), 28 deletions(-) diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index a7cc320562cb..0b5ffa966857 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -621,6 +621,13 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) } } +void CodeGenC::PrintVecUnaryOp(const std::string& op, DataType t, PrimExpr operand, + std::ostream& os) { // NOLINT(*) + os << op << "("; + this->PrintExpr(operand, os); + os << ")"; +} + void CodeGenC::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, std::ostream& os) { // NOLINT(*) if (isalpha(op[0])) { diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 93f9ea519c23..c6b8dca7f74e 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -171,6 +171,8 @@ class CodeGenC : public ExprFunctor, virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*) virtual void PrintStorageSync(const CallNode* op); // NOLINT(*) // Binary vector op. + virtual void PrintVecUnaryOp(const std::string& op, DataType t, PrimExpr operand, + std::ostream& os); // NOLINT(*) virtual void PrintVecBinaryOp(const std::string& op, DataType op_type, PrimExpr lhs, PrimExpr rhs, std::ostream& os); // NOLINT(*) // print vector load diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 63b3572e3939..bc96145508c6 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -103,6 +103,7 @@ std::string CodeGenCUDA::Finish() { decl_stream << _cuda_half_t_def; decl_stream << "#endif\n\n"; decl_stream << _cuda_half_util; + decl_stream << _cuda_half2_util; } if (enable_bf16_) { @@ -115,6 +116,7 @@ std::string CodeGenCUDA::Finish() { << "{\n return __hlt(a, b) ? 
a : b;\n}\n"; decl_stream << "#endif\n\n"; decl_stream << _cuda_bfloat16_util; + decl_stream << _cuda_bfloat162_util; } if (enable_fp8_) { @@ -429,39 +431,52 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Cannot convert type " << t << " to CUDA type"; } +void CodeGenCUDA::PrintVecUnaryOp(const std::string& op, DataType t, PrimExpr operand, std::ostream& os) { // NOLINT(*) + // TODO(Zihao) +} + void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, std::ostream& os) { // NOLINT(*) // Delcare the result. - std::string sret = name_supply_->FreshName("_"); - this->PrintIndent(); - this->PrintType(t, stream); - stream << ' ' << sret << ";\n"; - int ssa_scope = BeginScope(); - { - // Unpack into individual ops. - std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype()); - std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype()); - - for (int i = 0, lanes = t.lanes(); i < lanes; ++i) { - std::ostringstream value_temp; - if (isalpha(op[0])) { - value_temp << op << "("; - PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp); - value_temp << ", "; - PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp); - value_temp << ")"; - } else { - value_temp << "("; - PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp); - value_temp << op; - PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp); - value_temp << ")"; + if (t.bits() == 16 && t.is_floating_point() && t.lanes() == 2) { + // native half2 and nv_bfloat162 support. + if (isalpha(op[0])) { + os << op << "(" << lhs << ", " << rhs << ")"; + } else { + os << "(" << lhs << " " << op << " " << rhs << ")"; + } + } else { + std::string sret = name_supply_->FreshName("_"); + this->PrintIndent(); + this->PrintType(t, stream); + stream << ' ' << sret << ";\n"; + int ssa_scope = BeginScope(); + { + // Unpack into individual ops. 
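+ // Each lane of the operands is loaded, combined with the scalar form of the op, and stored back into the result variable.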
+ std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype()); + std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype()); + + for (int i = 0, lanes = t.lanes(); i < lanes; ++i) { + std::ostringstream value_temp; + if (isalpha(op[0])) { + value_temp << op << "("; + PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp); + value_temp << ", "; + PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp); + value_temp << ")"; + } else { + value_temp << "("; + PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp); + value_temp << op; + PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp); + value_temp << ")"; + } + PrintVecElemStore(sret, t, i, value_temp.str()); } - PrintVecElemStore(sret, t, i, value_temp.str()); } + EndScope(ssa_scope); + os << sret; } - EndScope(ssa_scope); - os << sret; } void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i, diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index c6cf96d460d4..662ca38b7f6d 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -51,6 +51,8 @@ class CodeGenCUDA final : public CodeGenC { void VisitStmt_(const ForNode* op) final; void PrintStorageSync(const CallNode* op) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) + void PrintVecUnaryOp(const std::string& op, DataType t, PrimExpr operand, + std::ostream& os) final; // NOLINT(*) void PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, std::ostream& os) final; // NOLINT(*) void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc index 47222e75003d..f598d448f618 100644 --- a/src/target/source/intrin_rule_cuda.cc +++ b/src/target/source/intrin_rule_cuda.cc @@ -125,6 +125,22 @@ struct CUDAWarpIntrinsic { } }; +struct CUDAVectorIntrinsic { + std::string operator()(DataType t, std::string name) const { + if (t.bits() == 16 && t.is_floating_point()) { + // half2 and nv_bfloat16 arithmetics + // https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH____HALF2__ARITHMETIC.html#group__CUDA__MATH____HALF2__ARITHMETIC + // https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH____BFLOAT162__ARITHMETIC.html#group__CUDA__MATH____BFLOAT162__ARITHMETIC + if (name == "div") { + return "__h2div"; + } else { + return "__h" + name + "2"; + } + } + return ""; + } +}; + static PrimExpr DispatchCUDAWarpActiveMask(const PrimExpr& e) { const CallNode* call = e.as(); return Call(call->dtype, Op::Get("tir.cuda.__activemask"), call->args); diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h index bc6b627f38b8..91a09003d887 100644 --- a/src/target/source/literal/cuda_half_t.h +++ b/src/target/source/literal/cuda_half_t.h @@ -325,6 +325,41 @@ CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erf) #endif )"; +static constexpr const char* _cuda_half2_util = R"( +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + +static inline __device__ __host__ unsigned +__half2_as_uint(const half2 v) { + return *((unsigned *)(&v)); +} + +static inline __device__ __host__ half2 +__uint_as_half2(const unsigned int v) { + return *((half2 *)(&v)); +} + +#define CUDA_UNSUPPORTED_HALF2_MATH_BINARY(HALF2_MATH_NAME, HALF_MATH_NAME) \ +static inline __device__ __host__ half2 HALF2_MATH_NAME(half2 x, half2 y) { \ + return __halves2half2(HALF_MATH_NAME(x.x, y.x), HALF_MATH_NAME(x.y, y.y)); \ +} + +#define CUDA_UNSUPPORTED_HALF2_MATH_UNARY(HALF2_MATH_NAME, 
HALF_MATH_NAME) \ +static inline __device__ __host__ half2 HALF2_MATH_NAME(half2 x) { \ + return __halves2half2(HALF_MATH_NAME(x.x), HALF_MATH_NAME(x.y)); \ +} + +CUDA_UNSUPPORTED_HALF2_MATH_BINARY(hpow2, hpow) +CUDA_UNSUPPORTED_HALF2_MATH_BINARY(htanh2, htanh) +CUDA_UNSUPPORTED_HALF2_MATH_BINARY(htan2, htan) +CUDA_UNSUPPORTED_HALF2_MATH_BINARY(hatan2, hatan) +CUDA_UNSUPPORTED_HALF2_MATH_BINARY(herf2, herf) + +#undef CUDA_UNSUPPORTED_HALF2_MATH_BINARY +#undef CUDA_UNSUPPORTED_HALF2_MATH_UNARY + +#endif +)"; + static constexpr const char* _cuda_bfloat16_util = R"( // Pack two bfloat16 values. static inline __device__ __host__ unsigned @@ -357,12 +392,46 @@ CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf) CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf) CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf) CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erf) -CUDA_UNSUPPORTED_HALF_MATH_UNARY(habs, abs) #undef CUDA_UNSUPPORTED_HALF_MATH_BINARY #undef CUDA_UNSUPPORTED_HALF_MATH_UNARY )"; +static constexpr const char* _cuda_bfloat162_util = R"( +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + +static inline __device__ __host__ unsigned +__nv_bfloat162_as_uint(const nv_bfloat162 v) { + return *((unsigned *)(&v)); +} + +static inline __device__ __host__ nv_bfloat162 +__uint_as_nv_bfloat162(const unsigned int v) { + return *((nv_bfloat162 *)(&v)); +} + +#define CUDA_UNSUPPORTED_BFLOAT2_MATH_BINARY(BFLOAT2_MATH_NAME, BFLOAT_MATH_NAME) \ +static inline __device__ __host__ nv_bfloat162 BFLOAT2_MATH_NAME(nv_bfloat162 x, nv_bfloat162 y) { \ + return __halves2half2(BFLOAT_MATH_NAME(x.x, y.x), BFLOAT_MATH_NAME(x.y, y.y)); \ +} + +#define CUDA_UNSUPPORTED_BFLOAT2_MATH_UNARY(BFLOAT2_MATH_NAME, BFLOAT_MATH_NAME) \ +static inline __device__ __host__ nv_bfloat162 BFLOAT2_MATH_NAME(nv_bfloat162 x) { \ + return __halves2half2(BFLOAT_MATH_NAME(x.x), BFLOAT_MATH_NAME(x.y)); \ +} + +CUDA_UNSUPPORTED_BFLOAT2_MATH_BINARY(hpow2, hpow) +CUDA_UNSUPPORTED_BFLOAT2_MATH_BINARY(htanh2, htanh) +CUDA_UNSUPPORTED_BFLOAT2_MATH_BINARY(htan2, htan) +CUDA_UNSUPPORTED_BFLOAT2_MATH_BINARY(hatan2, hatan) +CUDA_UNSUPPORTED_BFLOAT2_MATH_BINARY(herf2, herf) + +#undef CUDA_UNSUPPORTED_BFLOAT2_MATH_BINARY +#undef CUDA_UNSUPPORTED_BFLOAT2_MATH_UNARY + +#endif +)"; + static constexpr const char* _cuda_warp_intrinsic_util = R"( #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700) #define __shfl_sync(mask, var, lane, width) \ From 61ed5cc868ecfbc4ace35943fd0f558299d368cd Mon Sep 17 00:00:00 2001 From: Zihao Date: Fri, 30 Jun 2023 06:38:57 -0700 Subject: [PATCH 32/37] fix --- src/target/source/codegen_cuda.cc | 4 ++-- src/target/source/literal/cuda_half_t.h | 16 ++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index bc96145508c6..970bffcb5038 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -441,9 +441,9 @@ void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr l if (t.bits() == 16 && t.is_floating_point() && t.lanes() == 2) { // native half2 and nv_bfloat162 support. 
if (isalpha(op[0])) { - os << op << "(" << lhs << ", " << rhs << ")"; + os << op << "(" << PrintExpr(lhs) << ", " << PrintExpr(rhs) << ")"; } else { - os << "(" << lhs << " " << op << " " << rhs << ")"; + os << "(" << PrintExpr(lhs) << " " << op << " " << PrintExpr(rhs) << ")"; } } else { std::string sret = name_supply_->FreshName("_"); diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h index 91a09003d887..d0787437329d 100644 --- a/src/target/source/literal/cuda_half_t.h +++ b/src/target/source/literal/cuda_half_t.h @@ -349,10 +349,10 @@ static inline __device__ __host__ half2 HALF2_MATH_NAME(half2 x) { \ } CUDA_UNSUPPORTED_HALF2_MATH_BINARY(hpow2, hpow) -CUDA_UNSUPPORTED_HALF2_MATH_BINARY(htanh2, htanh) -CUDA_UNSUPPORTED_HALF2_MATH_BINARY(htan2, htan) -CUDA_UNSUPPORTED_HALF2_MATH_BINARY(hatan2, hatan) -CUDA_UNSUPPORTED_HALF2_MATH_BINARY(herf2, herf) +CUDA_UNSUPPORTED_HALF2_MATH_UNARY(htanh2, htanh) +CUDA_UNSUPPORTED_HALF2_MATH_UNARY(htan2, htan) +CUDA_UNSUPPORTED_HALF2_MATH_UNARY(hatan2, hatan) +CUDA_UNSUPPORTED_HALF2_MATH_UNARY(herf2, herf) #undef CUDA_UNSUPPORTED_HALF2_MATH_BINARY #undef CUDA_UNSUPPORTED_HALF2_MATH_UNARY @@ -421,10 +421,10 @@ static inline __device__ __host__ nv_bfloat162 BFLOAT2_MATH_NAME(nv_bfloat162 x) } CUDA_UNSUPPORTED_BFLOAT2_MATH_BINARY(hpow2, hpow) -CUDA_UNSUPPORTED_BFLOAT2_MATH_BINARY(htanh2, htanh) -CUDA_UNSUPPORTED_BFLOAT2_MATH_BINARY(htan2, htan) -CUDA_UNSUPPORTED_BFLOAT2_MATH_BINARY(hatan2, hatan) -CUDA_UNSUPPORTED_BFLOAT2_MATH_BINARY(herf2, herf) +CUDA_UNSUPPORTED_BFLOAT2_MATH_UNARY(htanh2, htanh) +CUDA_UNSUPPORTED_BFLOAT2_MATH_UNARY(htan2, htan) +CUDA_UNSUPPORTED_BFLOAT2_MATH_UNARY(hatan2, hatan) +CUDA_UNSUPPORTED_BFLOAT2_MATH_UNARY(herf2, herf) #undef CUDA_UNSUPPORTED_BFLOAT2_MATH_BINARY #undef CUDA_UNSUPPORTED_BFLOAT2_MATH_UNARY From 1cbd2f5e184690c320f006d02d51ada9bb060b4e Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 1 Jul 2023 02:00:30 -0700 Subject: [PATCH 33/37] wip --- src/target/source/codegen_cuda.cc | 51 ++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 970bffcb5038..49a5a88eb55f 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -432,14 +432,19 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) } void CodeGenCUDA::PrintVecUnaryOp(const std::string& op, DataType t, PrimExpr operand, std::ostream& os) { // NOLINT(*) - // TODO(Zihao) + if (t.bits() == 16 && t.is_floating_point() && t.lanes() == 2) { + // use native half2 and nv_bfloat162 intrinsics. + os << op << "(" << PrintExpr(operand) << ")"; + } else { + // TODO(Zihao) + } } void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, std::ostream& os) { // NOLINT(*) // Delcare the result. if (t.bits() == 16 && t.is_floating_point() && t.lanes() == 2) { - // native half2 and nv_bfloat162 support. + // native half2 and nv_bfloat162 intrinsics. 
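+ // (half2 / nv_bfloat162 math such as __hadd2 and the operator overloads in cuda_fp16.h / cuda_bf16.h act on both lanes at once, so no per-lane unpacking is needed.)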
if (isalpha(op[0])) { os << op << "(" << PrintExpr(lhs) << ", " << PrintExpr(rhs) << ")"; } else { @@ -456,22 +461,34 @@ void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr l std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype()); std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype()); - for (int i = 0, lanes = t.lanes(); i < lanes; ++i) { - std::ostringstream value_temp; - if (isalpha(op[0])) { - value_temp << op << "("; - PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp); - value_temp << ", "; - PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp); - value_temp << ")"; - } else { - value_temp << "("; - PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp); - value_temp << op; - PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp); - value_temp << ")"; + if (t.bits() == 16 && t.is_floating_point() && t.lanes() % 2 == 0) { + std::string ptr_str = t.is_float16() ? "(half2*)" : "(nv_bfloat162*)"; + for (int i = 0, lanes = t.lanes(); i < lanes / 2; ++i) { + std::ostringstream value_temp; + if (isalpha(op[0])) { + value_temp << op << "(" << ")"; + } else { + value_temp << op << "(" << ")"; + } + } + } else { + for (int i = 0, lanes = t.lanes(); i < lanes; ++i) { + std::ostringstream value_temp; + if (isalpha(op[0])) { + value_temp << op << "("; + PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp); + value_temp << ", "; + PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp); + value_temp << ")"; + } else { + value_temp << "("; + PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp); + value_temp << op; + PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp); + value_temp << ")"; + } + PrintVecElemStore(sret, t, i, value_temp.str()); } - PrintVecElemStore(sret, t, i, value_temp.str()); } } EndScope(ssa_scope); From 617496ac6550ad122eaa31ba3deee352a4312b0d Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 1 Jul 2023 02:17:46 -0700 Subject: [PATCH 34/37] wip --- src/target/source/codegen_cuda.cc | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 49a5a88eb55f..f5c59edf7fad 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -40,6 +40,33 @@ namespace tvm { namespace codegen { + +void PrintVec2xFloat16ElemLoad(const std::string& vec, DataType t, int i, + std::ostream& os) { // NOLINT(*) + ICHECK(!t.is_scalar()) << "Cannot load half2/nv_bfloat162 from scalar."; + ICHECK(t.is_floating_point() && t.bits() == 16) << "Data type not much, PrintVec2xFloat16ElemLoad only supports floating point type with 16 bits, got " << t << " instead."; + + static const char access[] = {'x', 'y', 'z', 'w'}; + if (t.is_float16()) { + if (t.lanes() == 2) { + // 2 * float16 is stored as half2, return itself + os << vec; + } else { + // 4/8 * float16 is stored as uint2/4, use (*(half2*)(&(v.x))) + os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; + } + } else { + ICHECK(t.is_bfloat16()); + if (t.lanes() == 2) { + // 2 * bfloat16 is stored as nv_bfloat162, return itself + os << vec; + } else { + // 4/8 * bfloat16 is stored as uint2/4, use (*(nv_bfloat162*)(&(v.x))) + os << "((nv_bfloat162*)(&(" << vec << "." 
<< access[i / 2] << ")))->" << access[i % 2]; + } + } +} + CodeGenCUDA::CodeGenCUDA() { restrict_keyword_ = "__restrict__"; } void CodeGenCUDA::Init(bool output_ssa) { From e96203203a7587e6450c84b84aa6edd67289c3dd Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 1 Jul 2023 02:31:33 -0700 Subject: [PATCH 35/37] upd --- src/target/source/codegen_cuda.cc | 58 ++++++++++++++++++++++++------- 1 file changed, 46 insertions(+), 12 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index f5c59edf7fad..5df3f7805034 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -40,29 +40,60 @@ namespace tvm { namespace codegen { - void PrintVec2xFloat16ElemLoad(const std::string& vec, DataType t, int i, - std::ostream& os) { // NOLINT(*) + std::ostream& os) { // NOLINT(*) ICHECK(!t.is_scalar()) << "Cannot load half2/nv_bfloat162 from scalar."; - ICHECK(t.is_floating_point() && t.bits() == 16) << "Data type not much, PrintVec2xFloat16ElemLoad only supports floating point type with 16 bits, got " << t << " instead."; + ICHECK(t.is_floating_point() && t.bits() == 16) + << "Data type not much, PrintVec2xFloat16ElemLoad only supports floating point type with 16 " + "bits, got " + << t << " instead."; static const char access[] = {'x', 'y', 'z', 'w'}; if (t.is_float16()) { if (t.lanes() == 2) { - // 2 * float16 is stored as half2, return itself + // vec (2 * float16) is stored as half2, return itself os << vec; } else { - // 4/8 * float16 is stored as uint2/4, use (*(half2*)(&(v.x))) - os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; + // vec (4/8 * float16) is stored as uint2/4, return (*((half2*)(&(vec.x/y/z/w)))) + os << "(*((half2*)(&(" << vec << "." << access[i] << "))))"; } } else { ICHECK(t.is_bfloat16()); if (t.lanes() == 2) { - // 2 * bfloat16 is stored as nv_bfloat162, return itself + // vec (2 * bfloat16) is stored as nv_bfloat162, return itself os << vec; } else { - // 4/8 * bfloat16 is stored as uint2/4, use (*(nv_bfloat162*)(&(v.x))) - os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; + // vec (4/8 * bfloat16) is stored as uint2/4, return (*((nv_bfloat162*)(&(vec.x)))) + os << "(*((nv_bfloat162*)(&(" << vec << "." << access[i] << "))))"; + } + } +} + +void PrintVec2xFloat16ElemStore(const std::string& vec, DataType t, int i, const std::string& value, + std::ostream& os) { + ICHECK(!t.is_scalar()) << "Cannot store half2/nv_bfloat162 to scalar."; + ICHECK(t.is_floating_point() && t.bits() == 16) + << "Data type not much, PrintVec2xFloat16ElemStore only supports floating point type with 16 " + "bits, got " + << t << " instead."; + + static const char access[] = {'x', 'y', 'z', 'w'}; + if (t.is_float16()) { + if (t.lanes() == 2) { + // vec (2 * float16) has type half2, return vec = value; + os << vec << " = " << value << ";\n"; + } else { + // vec (4/8 * float16) is stored as uint2/4, return ((half2*)(&(vec.x/y/z/w))) = value + os << "((half2*)(&(" << vec << "." << access[i] << "))) = " << value << ";\n"; + } + } else { + ICHECK(t.is_bfloat16()); + if (t.lanes() == 2) { + // vec (2 * bfloat16) has type nv_bfloat162, return vec = value; + os << vec << " = " << value << ";\n"; + } else { + // vec (4/8 * bfloat16) is stored as uint2/4, return ((nv_bfloat162*)(&(vec.x/y/z/w))) = value + os << "((nv_bfloat162*)(&(" << vec << "." 
<< access[i] << "))) = " << value << ";\n"; } } } @@ -458,7 +489,8 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Cannot convert type " << t << " to CUDA type"; } -void CodeGenCUDA::PrintVecUnaryOp(const std::string& op, DataType t, PrimExpr operand, std::ostream& os) { // NOLINT(*) +void CodeGenCUDA::PrintVecUnaryOp(const std::string& op, DataType t, PrimExpr operand, + std::ostream& os) { // NOLINT(*) if (t.bits() == 16 && t.is_floating_point() && t.lanes() == 2) { // use native half2 and nv_bfloat162 intrinsics. os << op << "(" << PrintExpr(operand) << ")"; @@ -493,9 +525,11 @@ void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr l for (int i = 0, lanes = t.lanes(); i < lanes / 2; ++i) { std::ostringstream value_temp; if (isalpha(op[0])) { - value_temp << op << "(" << ")"; + value_temp << op << "(" + << ")"; } else { - value_temp << op << "(" << ")"; + value_temp << op << "(" + << ")"; } } } else { From c2bd9eedfff6458bb99098d0cd04843c9907e0b8 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 1 Jul 2023 02:53:51 -0700 Subject: [PATCH 36/37] complete unary --- src/target/source/codegen_cuda.cc | 50 ++++++++++++++++++++++++++----- 1 file changed, 42 insertions(+), 8 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 5df3f7805034..7979ab72f24f 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -495,7 +495,34 @@ void CodeGenCUDA::PrintVecUnaryOp(const std::string& op, DataType t, PrimExpr op // use native half2 and nv_bfloat162 intrinsics. os << op << "(" << PrintExpr(operand) << ")"; } else { - // TODO(Zihao) + std::string sret = name_supply_->FreshName("_"); + this->PrintIndent(); + this->PrintType(t, stream); + stream << ' ' << sret << ";\n"; + int ssa_scope = BeginScope(); + { // begin of ssa_scope + std::string voperand = SSAGetID(PrintExpr(operand), operand.dtype()); + if (t.bits() == 16 && t.is_floating_point() && t.lanes() % 2 == 0) { + // load & store at the granularity of 2 elements. + for (int i = 0, lanes = t.lanes(); i < lanes / 2; ++i) { + std::ostringstream value_temp; + value_temp << op << "("; + PrintVec2xFloat16ElemLoad(voperand, operand.dtype(), i, value_temp); + value_temp << ")"; + PrintVec2xFloat16ElemStore(sret, t, i, value_temp.str(), stream); + } + } else { + for (int i = 0, lanes = t.lanes(); i < lanes; ++i) { + std::ostringstream value_temp; + value_temp << op << "("; + PrintVecElemLoad(voperand, operand.dtype(), i, value_temp); + value_temp << ")"; + PrintVecElemStore(sret, t, i, value_temp.str()); + } + } + } // end of ssa_scope + EndScope(ssa_scope); + os << sret; } } @@ -515,22 +542,29 @@ void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr l this->PrintType(t, stream); stream << ' ' << sret << ";\n"; int ssa_scope = BeginScope(); - { + { // begin of ssa_scope // Unpack into individual ops. std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype()); std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype()); if (t.bits() == 16 && t.is_floating_point() && t.lanes() % 2 == 0) { - std::string ptr_str = t.is_float16() ? "(half2*)" : "(nv_bfloat162*)"; + // load & store at the granularity of 2 elements. 
for (int i = 0, lanes = t.lanes(); i < lanes / 2; ++i) { std::ostringstream value_temp; if (isalpha(op[0])) { - value_temp << op << "(" - << ")"; + value_temp << op << "("; + PrintVec2xFloat16ElemLoad(vlhs, lhs.dtype(), i, value_temp); + value_temp << ", "; + PrintVec2xFloat16ElemLoad(vrhs, rhs.dtype(), i, value_temp); + value_temp << ")"; } else { - value_temp << op << "(" - << ")"; + value_temp << "("; + PrintVec2xFloat16ElemLoad(vlhs, lhs.dtype(), i, value_temp); + value_temp << op; + PrintVec2xFloat16ElemLoad(vrhs, rhs.dtype(), i, value_temp); + value_temp << ")"; } + PrintVec2xFloat16ElemStore(sret, t, i, value_temp.str(), stream); } } else { for (int i = 0, lanes = t.lanes(); i < lanes; ++i) { @@ -551,7 +585,7 @@ void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr l PrintVecElemStore(sret, t, i, value_temp.str()); } } - } + } // end of ssa_scope EndScope(ssa_scope); os << sret; } From 3887ff684362dddec0ae72eb674e2a07fde8685f Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 1 Jul 2023 03:44:25 -0700 Subject: [PATCH 37/37] bugfix --- src/target/source/codegen_cuda.cc | 8 ++++---- src/target/source/literal/cuda_half_t.h | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 7979ab72f24f..a2e05c6d55fc 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -83,8 +83,8 @@ void PrintVec2xFloat16ElemStore(const std::string& vec, DataType t, int i, const // vec (2 * float16) has type half2, return vec = value; os << vec << " = " << value << ";\n"; } else { - // vec (4/8 * float16) is stored as uint2/4, return ((half2*)(&(vec.x/y/z/w))) = value - os << "((half2*)(&(" << vec << "." << access[i] << "))) = " << value << ";\n"; + // vec (4/8 * float16) is stored as uint2/4, return *((half2*)(&(vec.x/y/z/w))) = value + os << "*((half2*)(&(" << vec << "." << access[i] << "))) = " << value << ";\n"; } } else { ICHECK(t.is_bfloat16()); @@ -92,8 +92,8 @@ void PrintVec2xFloat16ElemStore(const std::string& vec, DataType t, int i, const // vec (2 * bfloat16) has type nv_bfloat162, return vec = value; os << vec << " = " << value << ";\n"; } else { - // vec (4/8 * bfloat16) is stored as uint2/4, return ((nv_bfloat162*)(&(vec.x/y/z/w))) = value - os << "((nv_bfloat162*)(&(" << vec << "." << access[i] << "))) = " << value << ";\n"; + // vec (4/8 * bfloat16) is stored as uint2/4, return *((nv_bfloat162*)(&(vec.x/y/z/w))) = value + os << "*((nv_bfloat162*)(&(" << vec << "." 
<< access[i] << "))) = " << value << ";\n"; } } } diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h index d0787437329d..430bbd54485c 100644 --- a/src/target/source/literal/cuda_half_t.h +++ b/src/target/source/literal/cuda_half_t.h @@ -412,12 +412,12 @@ __uint_as_nv_bfloat162(const unsigned int v) { #define CUDA_UNSUPPORTED_BFLOAT2_MATH_BINARY(BFLOAT2_MATH_NAME, BFLOAT_MATH_NAME) \ static inline __device__ __host__ nv_bfloat162 BFLOAT2_MATH_NAME(nv_bfloat162 x, nv_bfloat162 y) { \ - return __halves2half2(BFLOAT_MATH_NAME(x.x, y.x), BFLOAT_MATH_NAME(x.y, y.y)); \ + return __halves2bfloat162(BFLOAT_MATH_NAME(x.x, y.x), BFLOAT_MATH_NAME(x.y, y.y)); \ } #define CUDA_UNSUPPORTED_BFLOAT2_MATH_UNARY(BFLOAT2_MATH_NAME, BFLOAT_MATH_NAME) \ static inline __device__ __host__ nv_bfloat162 BFLOAT2_MATH_NAME(nv_bfloat162 x) { \ - return __halves2half2(BFLOAT_MATH_NAME(x.x), BFLOAT_MATH_NAME(x.y)); \ + return __halves2bfloat162(BFLOAT_MATH_NAME(x.x), BFLOAT_MATH_NAME(x.y)); \ } CUDA_UNSUPPORTED_BFLOAT2_MATH_BINARY(hpow2, hpow)
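---

For reference, a rough sketch of the CUDA source the updated PrintVecUnaryOp / PrintVecBinaryOp paths are expected to emit for a float16x4 value once these patches are applied. The temporary names (_a, _b, _c, _d) are illustrative assumptions, not captured codegen output; the uint2 storage type follows the "4/8 * float16 is stored as uint2/4" convention noted in the diff above.

    // hypothetical output for a lanes=4 float16 tanh followed by an add,
    // processed two elements at a time through half2 reinterpret casts
    uint2 _b;
    *((half2*)(&(_b.x))) = htanh2((*((half2*)(&(_a.x)))));
    *((half2*)(&(_b.y))) = htanh2((*((half2*)(&(_a.y)))));
    uint2 _d;
    *((half2*)(&(_d.x))) = ((*((half2*)(&(_b.x))))+(*((half2*)(&(_c.x)))));
    *((half2*)(&(_d.y))) = ((*((half2*)(&(_b.y))))+(*((half2*)(&(_c.y)))));

For bfloat16 the same pattern should apply with nv_bfloat162 casts, with htanh2 resolving to the __halves2bfloat162-based fallback from cuda_half_t.h on architectures without a native intrinsic.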