9 changes: 8 additions & 1 deletion include/tvm/runtime/data_type.h
@@ -160,12 +160,19 @@ class DataType {
*/
static DataType UInt(int bits, int lanes = 1) { return DataType(kDLUInt, bits, lanes); }
/*!
* \brief Construct an uint type.
* \brief Construct a float type.
* \param bits The number of bits in the type.
* \param lanes The number of lanes
* \return The constructed data type.
*/
static DataType Float(int bits, int lanes = 1) { return DataType(kDLFloat, bits, lanes); }
/*!
* \brief Construct a bfloat type.
* \param bits The number of bits in the type.
* \param lanes The number of lanes
* \return The constructed data type.
*/
static DataType BFloat(int bits, int lanes = 1) { return DataType(kDLBfloat, bits, lanes); }
/*!
* \brief Construct a bool type.
* \param lanes The number of lanes
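The new constructor pairs with the "bfloat16" / "bfloat16x<lanes>" dtype strings used on the Python side later in this diff. A minimal sketch of that mapping, assuming a scalar "bfloat16" placeholder behaves like the vectorized form exercised by the unit test at the end:

```python
# Sketch only: "bfloat16" on the Python side is assumed to correspond to
# DataType::BFloat(16); the vectorized "bfloat16x4" form is what the unit
# test at the end of this diff actually exercises.
import tvm
from tvm import te

A = te.placeholder((16,), name="A", dtype="bfloat16")
B = te.compute((16,), lambda i: A[i] + tvm.tir.const(1, A.dtype), name="B")
print(B.dtype)  # bfloat16
```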
16 changes: 15 additions & 1 deletion python/tvm/contrib/nvcc.py
@@ -302,8 +302,22 @@ def have_tensorcore(compute_version=None, target=None):
major, minor = compute_version.split("_")[1]
compute_version = major + "." + minor
major, _ = parse_compute_version(compute_version)
if major == 7:
if major >= 7:
return True

return False


def have_bf16(compute_version):
"""Whether bf16 support is provided in the compute capability or not

Parameters
----------
compute_version : str
compute capability of a GPU (e.g. "8.0")
"""
major, _ = parse_compute_version(compute_version)
if major >= 8:
return True

return False
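A short usage sketch for the new helper, mirroring the guard used by the unit test added later in this diff (assumes a CUDA device is present):

```python
# Sketch: gate bf16 code paths on compute capability >= 8.0.
import tvm
from tvm.contrib.nvcc import have_bf16

gpu = tvm.gpu(0)
if gpu.exist and have_bf16(gpu.compute_version):
    print("bf16 kernels can target this GPU")
else:
    print("skip: bf16 requires compute capability 8.0 or newer")
```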
4 changes: 3 additions & 1 deletion python/tvm/runtime/ndarray.py
@@ -148,7 +148,9 @@ def copyfrom(self, source_array):
source_array.shape, shape
)
)
source_array = np.ascontiguousarray(source_array, dtype=dtype)
source_array = np.ascontiguousarray(
source_array, dtype="uint16" if dtype == "bfloat16" else dtype
)
assert source_array.flags["C_CONTIGUOUS"]
data = source_array.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize)
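With this change a "bfloat16" NDArray is fed raw bf16 bit patterns stored as uint16 on the NumPy side. An illustrative sketch (the literal bit patterns are example values):

```python
# Sketch: copyfrom() treats a bfloat16 destination as uint16 storage,
# so the source array carries bf16 bit patterns.
import numpy as np
import tvm

ctx = tvm.gpu(0)
bits = np.array([0x3F80, 0x4000, 0x4040], dtype="uint16")  # 1.0, 2.0, 3.0 in bf16
a = tvm.nd.empty((3,), "bfloat16", ctx).copyfrom(bits)
print(a.shape, a.dtype)  # (3,) bfloat16
```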
72 changes: 70 additions & 2 deletions src/target/source/codegen_cuda.cc
@@ -61,6 +61,18 @@ std::string CodeGenCUDA::Finish() {
decl_stream << _cuda_half_util;
}

if (enable_bf16_) {
decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)\n";
decl_stream << "#include <cuda_bf16.h>\n";
decl_stream << "__device__ nv_bfloat16 max"
<< "(nv_bfloat16 a, nv_bfloat16 b)\n"
<< "{\n return __hgt(a, b) ? a : b;\n}\n";
decl_stream << "__device__ nv_bfloat16 min(nv_bfloat16 a, nv_bfloat16 b)\n"
<< "{\n return __hlt(a, b) ? a : b;\n}\n";
decl_stream << "#endif\n\n";
decl_stream << _cuda_bfloat16_util;
}

if (enable_warp_shuffle_) {
decl_stream << _cuda_warp_intrinsic_util;
}
@@ -170,6 +182,17 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
os << lanes;
return;
}
} else if (t.is_bfloat16()) {
enable_bf16_ = true;
if (t.is_scalar()) {
os << "nv_bfloat16";
} else if (lanes <= 8) {
ICHECK_EQ(lanes % 2, 0) << "only support even lanes for bfloat16 type";
os << "uint" << lanes / 2;
} else {
fail = true;
}
if (!fail) return;
} else if (t == DataType::Bool()) {
os << "bool";
return;
@@ -382,6 +405,8 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
}
} else if (t.is_float16()) {
os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
} else if (t.is_bfloat16()) {
os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
} else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name;
if (t.bits() == 16) {
@@ -427,6 +452,9 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
} else if (t.is_float16()) {
stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = "
<< value << ";\n";
} else if (t.is_bfloat16()) {
stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]
<< " = " << value << ";\n";
} else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name;
if (t.bits() == 16) {
@@ -687,7 +715,8 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) ||
op->dtype == DataType::UInt(8) || op->dtype == DataType::Int(4) ||
op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1))
op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1) ||
op->dtype == DataType::BFloat(16))
<< "Matrix_a and matrix_b only support half or char or unsigned char "
<< "or uint4 or int4 or int1 type for now";
} else {
@@ -767,6 +796,19 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO
return;
}

if (op->dtype.is_bfloat16()) {
std::string v = PrintExpr(op->value);
os << "make_";
PrintType(op->dtype, os);
os << '(';
for (int i = 0; i < op->lanes / 2; ++i) {
if (i != 0) os << ", ";
os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
}
os << ')';
return;
}

std::string v = PrintExpr(op->value);
os << "make_";
PrintType(op->dtype, os);
@@ -836,6 +878,13 @@ void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream& os) {
}

inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*)
// Type code is kBFloat
if (op->dtype.is_bfloat16()) {
os << "__float2bfloat16_rn";
os << '(' << std::scientific << op->value << 'f' << ')';
return;
}
// Type code is kFloat
switch (op->dtype.bits()) {
case 64:
case 32: {
@@ -938,7 +987,7 @@ void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const LoadNode*
// Cast away volatile qualifier for fp16 types. That is, only loads and
// stores are volatile. The loaded objects are not marked as volatile.
//
if (op->dtype.is_float16() && IsVolatile(op->buffer_var.get())) {
if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) && IsVolatile(op->buffer_var.get())) {
os << "(";
PrintType(op->dtype, os);
os << ")(" << value << ")";
@@ -979,6 +1028,25 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val
return;
}

if (t.is_bfloat16()) {
if (i == 0) {
os << "make_";
PrintType(t, os);
os << '(';
}
if (i % 2 == 0) {
os << "__pack_bfloat162(" << value;
} else {
os << "," << value << ")";
if (i != t.lanes() - 1) {
os << ",";
} else {
os << ")";
}
}
return;
}

if (i == 0) {
os << "make_";
PrintType(t, os);
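To see these codegen paths end to end, one can build a small vectorized bf16 kernel and dump the emitted CUDA source. The sketch below is adapted from the unit test added at the bottom of this diff and assumes an sm_80+ GPU with a CUDA toolkit that ships cuda_bf16.h:

```python
# Sketch: compile a bf16 vector add and inspect the CUDA source produced by
# CodeGenCUDA; the nv_bfloat16 helpers emitted in Finish() should show up.
import tvm
from tvm import te

n, lanes, num_thread = 64, 4, 8
A = te.placeholder((n,), name="A", dtype="bfloat16x%d" % lanes)
B = te.compute((n,), lambda i: A[i] + tvm.tir.const(1, A.dtype), name="B")
s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
s[B].bind(xo, te.thread_axis("blockIdx.x"))
s[B].bind(xi, te.thread_axis("threadIdx.x"))
with tvm.transform.PassContext(
    disabled_pass=["tir.BF16Promote", "tir.BF16CastElimination", "tir.BF16TypeLowering"]
):
    fun = tvm.build(s, [A, B], "cuda")
print(fun.imported_modules[0].get_source())  # look for nv_bfloat16 / __pack_nv_bfloat162
```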
4 changes: 3 additions & 1 deletion src/target/source/codegen_cuda.h
@@ -42,7 +42,7 @@ class CodeGenCUDA final : public CodeGenC {
void Init(bool output_ssa);
std::string Finish();
bool need_include_path() {
return (enable_fp16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_);
return (enable_fp16_ || enable_bf16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_);
}
// override behavior
void PrintFuncPrefix() final;
@@ -88,6 +88,8 @@ class CodeGenCUDA final : public CodeGenC {
std::string vid_global_barrier_expect_;
// whether enable fp16
bool enable_fp16_{false};
// whether enable bf16
bool enable_bf16_{false};
// whether enable int8
bool enable_int8_{false};
// whether enable warp shuffle intrinsics
2 changes: 2 additions & 0 deletions src/target/source/intrin_rule_cuda.cc
@@ -43,6 +43,8 @@ struct CUDAMath {
default:
return "";
}
} else if (t.is_bfloat16()) {
return 'h' + name;
}
return "";
}
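This rule reuses the fp16-style 'h'-prefixed intrinsic names for bf16, which the hpow/htanh helpers added in the next file back up. A hedged sketch of an expression that would take this path:

```python
# Sketch: tanh over a bf16 buffer is expected to lower to the
# htanh(nv_bfloat16) helper defined in cuda_half_t.h below.
import tvm
from tvm import te, tir

n = 64
A = te.placeholder((n,), name="A", dtype="bfloat16")
B = te.compute((n,), lambda i: tir.tanh(A[i]), name="B")
```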
24 changes: 24 additions & 0 deletions src/target/source/literal/cuda_half_t.h
@@ -311,6 +311,30 @@ static inline __device__ __host__ half htanh(half x) {
#endif
)";

static constexpr const char* _cuda_bfloat16_util = R"(
// Pack two bfloat16 values.
static inline __device__ __host__ unsigned
__pack_nv_bfloat162(const nv_bfloat16 x, const nv_bfloat16 y) {
unsigned v0 = *((unsigned short *)&x);
unsigned v1 = *((unsigned short *)&y);
return (v1 << 16) | v0;
}

// fix undefined bf16 math functions
static inline __device__ __host__ nv_bfloat16 hpow(nv_bfloat16 x, nv_bfloat16 y) {
float tmp_x = __bfloat162float(x);
float tmp_y = __bfloat162float(y);
float result = powf(tmp_x, tmp_y);
return __float2bfloat16(result);
}

static inline __device__ __host__ nv_bfloat16 htanh(nv_bfloat16 x) {
float tmp_x = __bfloat162float(x);
float result = tanhf(tmp_x);
return __float2bfloat16(result);
}
)";

static constexpr const char* _cuda_warp_intrinsic_util = R"(
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)
#define __shfl_sync(mask, var, lane, width) \
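A quick bit-level check of the packing convention in __pack_nv_bfloat162 above (the two values are arbitrary examples):

```python
# Worked example: the first operand lands in the low 16 bits and the second
# in the high 16 bits of the packed 32-bit word, matching (v1 << 16) | v0.
x = 0x3F80  # 1.0 as bf16 bits
y = 0x4000  # 2.0 as bf16 bits
packed = (y << 16) | x
assert packed == 0x40003F80
```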
50 changes: 49 additions & 1 deletion tests/python/unittest/test_target_codegen_cuda.py
@@ -19,7 +19,7 @@
import numpy as np
from tvm import topi
import unittest
from tvm.contrib.nvcc import have_fp16, have_int8
from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16
from tvm.contrib import nvcc
import tvm.testing

@@ -67,6 +67,53 @@ def check_cuda(dtype, n, lanes):
check_cuda("float16", 64, 8)


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_bf16_vectorize_add():
if not have_bf16(tvm.gpu(0).compute_version):
print("skip because gpu does not support bf16")
return
num_thread = 8

def np_float2np_bf16(arr):
"""Convert a numpy array of float to a numpy array
of bf16 in uint16"""
orig = arr.view("<u4")
bias = np.bitwise_and(np.right_shift(orig, 16), 1) + 0x7FFF
return np.right_shift(orig + bias, 16).astype("uint16")

def np_bf162np_float(arr):
"""Convert a numpy array of bf16 (uint16) to a numpy array
of float"""
u32 = np.left_shift(arr.astype("uint32"), 16)
return u32.view("<f4")

def check_cuda(n, lanes):
A = te.placeholder((n,), name="A", dtype="bfloat16x%d" % lanes)
B = te.compute((n,), lambda i: A[i] + tvm.tir.const(1, A.dtype), name="B")
s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
s[B].bind(xo, bx)
s[B].bind(xi, tx)
with tvm.transform.PassContext(
disabled_pass=["tir.BF16Promote", "tir.BF16CastElimination", "tir.BF16TypeLowering"]
):
fun = tvm.build(s, [A, B], "cuda")
ctx = tvm.gpu(0)
np_a = np.random.uniform(size=(n, lanes)).astype("float32")
np_a = np_bf162np_float(np_float2np_bf16(np_a))
a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(np_float2np_bf16(np_a))
c = tvm.nd.empty((n,), B.dtype, ctx)
fun(a, c)
c = tvm.nd.empty((n, lanes), "uint16", ctx).copyfrom(c)
tvm.testing.assert_allclose(c.asnumpy(), np_float2np_bf16(np_a + 1))

check_cuda(64, 2)
check_cuda(64, 4)
check_cuda(64, 6)
check_cuda(64, 8)


@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_multiply_add():
@@ -922,6 +969,7 @@ def test_unrolled_vectorization():

if __name__ == "__main__":
test_cuda_vectorize_add()
test_cuda_bf16_vectorize_add()
test_cuda_multiply_add()
test_cuda_vectorize_load()
test_cuda_make_int8()
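For reference, a standalone check of the round-to-nearest-even conversion used by np_float2np_bf16 in the test above:

```python
import numpy as np

def np_float2np_bf16(arr):
    """float32 -> bf16 bit pattern (uint16), round-to-nearest-even."""
    orig = arr.view("<u4")
    bias = np.bitwise_and(np.right_shift(orig, 16), 1) + 0x7FFF
    return np.right_shift(orig + bias, 16).astype("uint16")

def np_bf162np_float(arr):
    """bf16 bit pattern (uint16) -> float32."""
    return np.left_shift(arr.astype("uint32"), 16).view("<f4")

x = np.array([1.2], dtype="float32")                  # bits 0x3F99999A
b = np_float2np_bf16(x)                               # rounds up to 0x3F9A
print(hex(int(b[0])), float(np_bf162np_float(b)[0]))  # 0x3f9a 1.203125
```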