diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index d705be6c4deb..7d914ce6bff9 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -160,12 +160,19 @@ class DataType {
    */
   static DataType UInt(int bits, int lanes = 1) { return DataType(kDLUInt, bits, lanes); }
   /*!
-   * \brief Construct an uint type.
+   * \brief Construct a float type.
    * \param bits The number of bits in the type.
    * \param lanes The number of lanes
    * \return The constructed data type.
    */
   static DataType Float(int bits, int lanes = 1) { return DataType(kDLFloat, bits, lanes); }
+  /*!
+   * \brief Construct a bfloat type.
+   * \param bits The number of bits in the type.
+   * \param lanes The number of lanes
+   * \return The constructed data type.
+   */
+  static DataType BFloat(int bits, int lanes = 1) { return DataType(kDLBfloat, bits, lanes); }
   /*!
    * \brief Construct a bool type.
    * \param lanes The number of lanes
diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py
index 2a97b0b31d1e..f33603b923a5 100644
--- a/python/tvm/contrib/nvcc.py
+++ b/python/tvm/contrib/nvcc.py
@@ -302,8 +302,22 @@ def have_tensorcore(compute_version=None, target=None):
         major, minor = compute_version.split("_")[1]
         compute_version = major + "." + minor
     major, _ = parse_compute_version(compute_version)
+    if major >= 7:
+        return True
+
+    return False
+
+
+def have_bf16(compute_version):
+    """Whether bf16 support is provided in the compute capability or not
-    if major == 7:
+
+    Parameters
+    ----------
+    compute_version : str
+        compute capability of a GPU (e.g. "8.0")
+    """
+    major, _ = parse_compute_version(compute_version)
+    if major >= 8:
         return True
 
     return False
diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py
index 75da3d4a5c17..5c60515e3448 100644
--- a/python/tvm/runtime/ndarray.py
+++ b/python/tvm/runtime/ndarray.py
@@ -148,7 +148,9 @@ def copyfrom(self, source_array):
                         source_array.shape, shape
                     )
                 )
-        source_array = np.ascontiguousarray(source_array, dtype=dtype)
+        source_array = np.ascontiguousarray(
+            source_array, dtype="uint16" if dtype == "bfloat16" else dtype
+        )
         assert source_array.flags["C_CONTIGUOUS"]
         data = source_array.ctypes.data_as(ctypes.c_void_p)
         nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize)
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 2e9babacc441..e54acd2221d1 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -61,6 +61,18 @@ std::string CodeGenCUDA::Finish() {
     decl_stream << _cuda_half_util;
   }
 
+  if (enable_bf16_) {
+    decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)\n";
+    decl_stream << "#include <cuda_bf16.h>\n";
+    decl_stream << "__device__ nv_bfloat16 max"
+                << "(nv_bfloat16 a, nv_bfloat16 b)\n"
+                << "{\n  return __hgt(a, b) ? a : b;\n}\n";
+    decl_stream << "__device__ nv_bfloat16 min(nv_bfloat16 a, nv_bfloat16 b)\n"
+                << "{\n  return __hlt(a, b) ? a : b;\n}\n";
+    decl_stream << "#endif\n\n";
+    decl_stream << _cuda_bfloat16_util;
+  }
+
   if (enable_warp_shuffle_) {
     decl_stream << _cuda_warp_intrinsic_util;
   }
@@ -170,6 +182,17 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
       os << lanes;
       return;
     }
+  } else if (t.is_bfloat16()) {
+    enable_bf16_ = true;
+    if (t.is_scalar()) {
+      os << "nv_bfloat16";
+    } else if (lanes <= 8) {
+      ICHECK_EQ(lanes % 2, 0) << "only support even lane for bfloat16 type";
+      os << "uint" << lanes / 2;
+    } else {
+      fail = true;
+    }
+    if (!fail) return;
   } else if (t == DataType::Bool()) {
     os << "bool";
     return;
@@ -382,6 +405,8 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
     }
   } else if (t.is_float16()) {
     os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
+  } else if (t.is_bfloat16()) {
+    os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
   } else if (t.lanes() > 4 && t.lanes() <= 8) {
     std::string type_name;
     if (t.bits() == 16) {
@@ -427,6 +452,9 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
   } else if (t.is_float16()) {
     stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = "
            << value << ";\n";
+  } else if (t.is_bfloat16()) {
+    stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]
+           << " = " << value << ";\n";
   } else if (t.lanes() > 4 && t.lanes() <= 8) {
     std::string type_name;
     if (t.bits() == 16) {
@@ -687,7 +715,8 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
   if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
     ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) ||
            op->dtype == DataType::UInt(8) || op->dtype == DataType::Int(4) ||
-           op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1))
-        << "Matrix_a and matrix_b only support half or char or unsigned char "
-        << "or uint4 or int4 or int1 type for now";
+           op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1) ||
+           op->dtype == DataType::BFloat(16))
+        << "Matrix_a and matrix_b only support half or bfloat or char or unsigned char "
+        << "or uint4 or int4 or int1 type for now";
   } else {
@@ -767,6 +796,19 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) {  // NOLINT(*)
     return;
   }
 
+  if (op->dtype.is_bfloat16()) {
+    std::string v = PrintExpr(op->value);
+    os << "make_";
+    PrintType(op->dtype, os);
+    os << '(';
+    for (int i = 0; i < op->lanes / 2; ++i) {
+      if (i != 0) os << ", ";
+      os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
+    }
+    os << ')';
+    return;
+  }
+
   std::string v = PrintExpr(op->value);
   os << "make_";
   PrintType(op->dtype, os);
@@ -836,6 +878,13 @@ void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream& os) {
 }
 
 inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) {  // NOLINT(*)
+  // Type code is kBFloat
+  if (op->dtype.is_bfloat16()) {
+    os << "__float2bfloat16_rn";
+    os << '(' << std::scientific << op->value << 'f' << ')';
+    return;
+  }
+  // Type code is kFloat
   switch (op->dtype.bits()) {
     case 64:
     case 32: {
@@ -938,7 +987,7 @@ void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const LoadNode* op,
   // Cast away volatile qualifier for fp16 types. That is, only loads and
   // stores are volatile. The loaded objects are not marked as volatile.
   //
-  if (op->dtype.is_float16() && IsVolatile(op->buffer_var.get())) {
+  if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) && IsVolatile(op->buffer_var.get())) {
     os << "(";
     PrintType(op->dtype, os);
     os << ")(" << value << ")";
@@ -979,6 +1028,25 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& value,
     return;
   }
 
+  if (t.is_bfloat16()) {
+    if (i == 0) {
+      os << "make_";
+      PrintType(t, os);
+      os << '(';
+    }
+    if (i % 2 == 0) {
+      os << "__pack_nv_bfloat162(" << value;
+    } else {
+      os << "," << value << ")";
+      if (i != t.lanes() - 1) {
+        os << ",";
+      } else {
+        os << ")";
+      }
+    }
+    return;
+  }
+
   if (i == 0) {
     os << "make_";
     PrintType(t, os);
diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h
index 3cde8e379eb4..2098b8ac8344 100644
--- a/src/target/source/codegen_cuda.h
+++ b/src/target/source/codegen_cuda.h
@@ -42,7 +42,7 @@ class CodeGenCUDA final : public CodeGenC {
   void Init(bool output_ssa);
   std::string Finish();
   bool need_include_path() {
-    return (enable_fp16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_);
+    return (enable_fp16_ || enable_bf16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_);
   }
   // override behavior
   void PrintFuncPrefix() final;
@@ -88,6 +88,8 @@ class CodeGenCUDA final : public CodeGenC {
   std::string vid_global_barrier_expect_;
   // whether enable fp16
   bool enable_fp16_{false};
+  // whether enable bf16
+  bool enable_bf16_{false};
   // whether enable int8
   bool enable_int8_{false};
   // whether enable warp shuffle intrinsics
diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc
index 5c562f7b1643..965b86c24d9e 100644
--- a/src/target/source/intrin_rule_cuda.cc
+++ b/src/target/source/intrin_rule_cuda.cc
@@ -43,6 +43,8 @@ struct CUDAMath {
         default:
           return "";
       }
+    } else if (t.is_bfloat16()) {
+      return 'h' + name;
     }
     return "";
   }
diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h
index f8e92d508d88..3888f3a4fb07 100644
--- a/src/target/source/literal/cuda_half_t.h
+++ b/src/target/source/literal/cuda_half_t.h
@@ -311,6 +311,30 @@ static inline __device__ __host__ half htanh(half x) {
 #endif
 )";
 
+static constexpr const char* _cuda_bfloat16_util = R"(
+// Pack two bfloat16 values.
+static inline __device__ __host__ unsigned
+__pack_nv_bfloat162(const nv_bfloat16 x, const nv_bfloat16 y) {
+  unsigned v0 = *((unsigned short *)&x);
+  unsigned v1 = *((unsigned short *)&y);
+  return (v1 << 16) | v0;
+}
+
+// fix undefined bfloat16 math functions
+static inline __device__ __host__ nv_bfloat16 hpow(nv_bfloat16 x, nv_bfloat16 y) {
+  float tmp_x = __bfloat162float(x);
+  float tmp_y = __bfloat162float(y);
+  float result = powf(tmp_x, tmp_y);
+  return __float2bfloat16(result);
+}
+
+static inline __device__ __host__ nv_bfloat16 htanh(nv_bfloat16 x) {
+  float tmp_x = __bfloat162float(x);
+  float result = tanhf(tmp_x);
+  return __float2bfloat16(result);
+}
+)";
+
 static constexpr const char* _cuda_warp_intrinsic_util = R"(
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)
 #define __shfl_sync(mask, var, lane, width) \
diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py
index a228a640f108..06d7cb4bb7bb 100644
--- a/tests/python/unittest/test_target_codegen_cuda.py
+++ b/tests/python/unittest/test_target_codegen_cuda.py
@@ -19,7 +19,7 @@
 import numpy as np
 from tvm import topi
 import unittest
-from tvm.contrib.nvcc import have_fp16, have_int8
+from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16
 from tvm.contrib import nvcc
 import tvm.testing
 
@@ -67,6 +67,53 @@ def check_cuda(dtype, n, lanes):
     check_cuda("float16", 64, 8)
 
 
+@tvm.testing.requires_gpu
+@tvm.testing.requires_cuda
+def test_cuda_bf16_vectorize_add():
+    if not have_bf16(tvm.gpu(0).compute_version):
+        print("skip because gpu does not support bf16")
+        return
+    num_thread = 8
+
+    def np_float2np_bf16(arr):
+        """Convert a numpy array of float to a numpy array
+        of bf16 in uint16"""
+        orig = arr.view("
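Note: bfloat16 is the upper 16 bits of an IEEE-754 float32, so the test's NumPy helpers can convert with plain bit manipulation on a uint32 view. The sketch below is illustrative rather than the patch's verbatim code: it assumes round-to-nearest-even for the float32 -> bf16 direction, a little-endian float32 input, and a hypothetical inverse helper named np_bf162np_float.

import numpy as np

def np_float2np_bf16(arr):
    """Round a float32 ndarray to bf16, stored bitwise in uint16."""
    orig = arr.view("<u4")
    # Round to nearest even: add 0x7FFF plus the lowest kept mantissa bit,
    # then drop the low 16 bits.
    bias = np.bitwise_and(np.right_shift(orig, 16), 1) + 0x7FFF
    return np.right_shift(orig + bias, 16).astype("uint16")

def np_bf162np_float(arr):
    """Widen bf16 (stored bitwise in uint16) back to float32 by
    zero-filling the low 16 bits."""
    return np.left_shift(arr.astype("uint32"), 16).view("<f4")

Round-tripping a float32 array through these two helpers keeps about 7 bits of mantissa (relative error on the order of 2**-8), which is why such bf16 tests typically compare results against a loose tolerance rather than exact equality.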