From c36063ecdc21856a0b6d868398f8b962505d6c68 Mon Sep 17 00:00:00 2001
From: "Tatsuya.Nishiyama"
Date: Fri, 20 Jul 2018 21:14:53 +0900
Subject: [PATCH 1/4] [TVM][CUDA] Add int8 support for cuda

---
 src/codegen/codegen_cuda.cc                |  7 +++-
 src/codegen/codegen_cuda.h                 |  4 +-
 src/codegen/opt/build_cuda_on.cc           |  3 --
 tests/python/unittest/test_codegen_cuda.py | 45 +++++++++++++++++++++-
 4 files changed, 53 insertions(+), 6 deletions(-)

diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc
index 38c559ba2ba5..71e3aa4de94c 100644
--- a/src/codegen/codegen_cuda.cc
+++ b/src/codegen/codegen_cuda.cc
@@ -34,6 +34,10 @@ std::string CodeGenCUDA::Finish() {
     decl_stream << "#include <cuda_fp16.h>\n";
   }
 
+  if (enable_int8_) {
+    decl_stream << "#include <sm_61_intrinsics.h>\n";
+  }
+
   return CodeGenC::Finish();
 }
 
@@ -83,7 +87,8 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) {  // NOLINT(*)
     }
     if (t.bits() == 8 && t.lanes() == 4) {
       // directly 4 8 bit int in integer.
-      os << "int"; return;
+      enable_int8_ = true;
+      os << "char4"; return;
     }
     switch (t.bits()) {
       case 8: {
diff --git a/src/codegen/codegen_cuda.h b/src/codegen/codegen_cuda.h
index dd559f5232f9..f5d9861ec6b2 100644
--- a/src/codegen/codegen_cuda.h
+++ b/src/codegen/codegen_cuda.h
@@ -20,7 +20,7 @@ class CodeGenCUDA final : public CodeGenC {
   void Init(bool output_ssa);
   void AddFunction(LoweredFunc f);
   std::string Finish();
-  bool need_include_path() { return enable_fp16_; }
+  bool need_include_path() { return (enable_fp16_ || enable_int8_); }
   // override behavior
   void VisitStmt_(const ir::For* op) final;
   void PrintStorageSync(const Call* op) final;
@@ -49,6 +49,8 @@ class CodeGenCUDA final : public CodeGenC {
   std::string vid_global_barrier_expect_;
   // whether enable fp16
  bool enable_fp16_{false};
+  // whether enable int8
+  bool enable_int8_{false};
 };
 
 }  // namespace codegen
diff --git a/src/codegen/opt/build_cuda_on.cc b/src/codegen/opt/build_cuda_on.cc
index 72fd38b925e5..2e5766f53b76 100644
--- a/src/codegen/opt/build_cuda_on.cc
+++ b/src/codegen/opt/build_cuda_on.cc
@@ -64,7 +64,6 @@ std::string FindCUDAIncludePath() {
 std::string NVRTCCompile(const std::string& code, bool include_path = false) {
   std::vector<std::string> compile_params;
   std::vector<const char*> param_cstrings{};
-  int num_options = 0;
   nvrtcProgram prog;
   cudaDeviceProp device_prop;
   std::string cc = "30";
@@ -78,13 +77,11 @@ std::string NVRTCCompile(const std::string& code, bool include_path = false) {
   }
 
   compile_params.push_back("-arch=compute_" + cc);
-  num_options++;
 
   if (include_path) {
     std::string include_option = "--include-path=" + FindCUDAIncludePath();
 
     compile_params.push_back(include_option);
-    num_options++;
   }
 
   for (const auto& string : compile_params) {
diff --git a/tests/python/unittest/test_codegen_cuda.py b/tests/python/unittest/test_codegen_cuda.py
index 044c40696e59..8a79134b1d73 100644
--- a/tests/python/unittest/test_codegen_cuda.py
+++ b/tests/python/unittest/test_codegen_cuda.py
@@ -1,6 +1,12 @@
 import tvm
 import numpy as np
-from tvm.contrib.nvcc import have_fp16
+from tvm.contrib.nvcc import have_fp16, have_int8
+from tvm.contrib import nvcc
+
+@tvm.register_func
+def tvm_callback_cuda_compile(code):
+    ptx = nvcc.compile_cuda(code, target="ptx")
+    return ptx
 
 def test_cuda_vectorize_add():
     num_thread = 8
@@ -11,6 +17,9 @@ def check_cuda(dtype, n, lanes):
         if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
             print("skip because gpu does not support fp16")
             return
+        if dtype == "int8" and not have_int8(tvm.gpu(0).compute_version):
+            print("skip because gpu does not support int8")
+            return
         A = tvm.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes))
         B = tvm.compute((n,), lambda i: A[i]+tvm.const(1, A.dtype), name='B')
         s = tvm.create_schedule(B.op)
@@ -27,6 +36,40 @@ def check_cuda(dtype, n, lanes):
 
     check_cuda("float32", 64, 2)
     check_cuda("float16", 64, 2)
+    check_cuda("int8", 64, 4)
+
+def test_cuda_multiply_add():
+    num_thread = 8
+    def check_cuda(dtype, n, lanes):
+        if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
+            print("skip because cuda is not enabled..")
+            return
+        if dtype == "int8" and not have_int8(tvm.gpu(0).compute_version):
+            print("skip because gpu does not support int8")
+            return
+        A = tvm.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes))
+        B = tvm.placeholder((n,), name='B', dtype="%sx%d" % (dtype, lanes))
+        C = tvm.placeholder((n,), name='C', dtype="int32")
+        D = tvm.compute((n,),
+                        lambda i: tvm.call_pure_extern("int32", "__dp4a", A[i], B[i], C[i]), name='D')
+        s = tvm.create_schedule(D.op)
+        xo, xi = s[D].split(D.op.axis[0], factor=num_thread)
+        s[D].bind(xo, tvm.thread_axis("blockIdx.x"))
+        s[D].bind(xi, tvm.thread_axis("threadIdx.x"))
+        fun = tvm.build(s, [A, B, C, D], "cuda")
+        np_a = np.random.randint(low=-128, high=127, size=(n,lanes))
+        np_b = np.random.randint(low=-128, high=127, size=(n,lanes))
+        np_c = np.random.randint(low=0, high=127, size=(n,))
+        np_d = [sum(x * y) + z for x, y, z in zip(np_a, np_b, np_c)]
+        ctx = tvm.gpu(0)
+        a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(np_a)
+        b = tvm.nd.empty((n,), B.dtype, ctx).copyfrom(np_b)
+        c = tvm.nd.empty((n,), C.dtype, ctx).copyfrom(np_c)
+        d = tvm.nd.empty((n,), D.dtype, ctx)
+        fun(a, b, c, d)
+        np.testing.assert_allclose(d.asnumpy(), np_d)
+    check_cuda("int8", 64, 4)
 
 if __name__ == "__main__":
     test_cuda_vectorize_add()
+    test_cuda_multiply_add()
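Note on the reference computation above: __dp4a(a, b, c) on compute capability 6.1+ treats a and b as four packed signed 8-bit lanes, multiplies them pairwise, widens the products to 32 bits, sums them, and adds c. The np_d list comprehension in test_cuda_multiply_add mirrors exactly that on the host. A minimal NumPy sketch of the same reference math (the values here are illustrative, not taken from the patch):

    import numpy as np

    # dp4a(a, b, c) == sum(a[k] * b[k] for k in range(4)) + c, with products widened to int32
    a = np.random.randint(-128, 127, size=4).astype(np.int32)
    b = np.random.randint(-128, 127, size=4).astype(np.int32)
    c = np.int32(3)
    ref = int(np.dot(a, b)) + int(c)  # same formula as np_d in the test above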
From 924d7695d2d746b992f607a60a425fd43c8b8830 Mon Sep 17 00:00:00 2001
From: "Tatsuya.Nishiyama"
Date: Sun, 29 Jul 2018 01:56:14 +0900
Subject: [PATCH 2/4] Remove registering tvm_callback_cuda_compile

---
 tests/python/unittest/test_codegen_cuda.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/tests/python/unittest/test_codegen_cuda.py b/tests/python/unittest/test_codegen_cuda.py
index 8a79134b1d73..36d28c32c6c8 100644
--- a/tests/python/unittest/test_codegen_cuda.py
+++ b/tests/python/unittest/test_codegen_cuda.py
@@ -3,11 +3,6 @@
 from tvm.contrib.nvcc import have_fp16, have_int8
 from tvm.contrib import nvcc
 
-@tvm.register_func
-def tvm_callback_cuda_compile(code):
-    ptx = nvcc.compile_cuda(code, target="ptx")
-    return ptx
-
 def test_cuda_vectorize_add():
     num_thread = 8
     def check_cuda(dtype, n, lanes):
@@ -38,6 +33,7 @@ def check_cuda(dtype, n, lanes):
     check_cuda("float16", 64, 2)
     check_cuda("int8", 64, 4)
 
+
 def test_cuda_multiply_add():
     num_thread = 8
     def check_cuda(dtype, n, lanes):

From 78634a9149f266af727ec6bb4ec42f3d42634d70 Mon Sep 17 00:00:00 2001
From: "Tatsuya.Nishiyama"
Date: Tue, 31 Jul 2018 00:49:02 +0900
Subject: [PATCH 3/4] support lanes = 8 and lanes = 16 for int8

---
 src/codegen/codegen_cuda.cc | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc
index 71e3aa4de94c..70ab807c9509 100644
--- a/src/codegen/codegen_cuda.cc
+++ b/src/codegen/codegen_cuda.cc
@@ -85,14 +85,19 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) {  // NOLINT(*)
         os << "unsigned ";
       }
     }
-    if (t.bits() == 8 && t.lanes() == 4) {
-      // directly 4 8 bit int in integer.
-      enable_int8_ = true;
-      os << "char4"; return;
-    }
     switch (t.bits()) {
       case 8: {
-        if (!t.is_uint() && t.lanes() == 1) {
+        if (t.lanes() == 4) {
+          // directly 4 8 bit int in integer.
+          enable_int8_ = true;
+          os << "char4"; return;
+        } else if (t.lanes() == 8) {
+          enable_int8_ = true;
+          os << "int2"; return;
+        } else if (t.lanes() == 16) {
+          enable_int8_ = true;
+          os << "int4"; return;
+        } else if (!t.is_uint() && t.lanes() == 1) {
           os << "signed char"; break;
         } else {
           os << "char"; break;
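The lane-to-type mapping introduced here keeps the CUDA vector type exactly as wide as the vectorized int8 value: 4 lanes fit one 32-bit char4, while 8 and 16 lanes travel as two and four packed 32-bit words (int2 and int4), so loads and stores stay fully vectorized. A small bookkeeping sketch of that assumption (the dictionary below is illustrative, not part of the patch):

    # one int8 lane is one byte; the emitted CUDA type must cover lanes * 1 byte
    cuda_type_for_int8_lanes = {4: ("char4", 4), 8: ("int2", 8), 16: ("int4", 16)}  # type, bytes
    for lanes, (cuda_type, nbytes) in cuda_type_for_int8_lanes.items():
        assert nbytes == lanes  # same total storage as the int8xN value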
From 7c037eeeaed2ec3bef4bf95b27d872a30e1cfdbd Mon Sep 17 00:00:00 2001
From: "Tatsuya.Nishiyama"
Date: Tue, 31 Jul 2018 14:53:46 +0900
Subject: [PATCH 4/4] Add vectorize_load_test

---
 tests/python/unittest/test_codegen_cuda.py | 23 ++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/tests/python/unittest/test_codegen_cuda.py b/tests/python/unittest/test_codegen_cuda.py
index 36d28c32c6c8..664883528efc 100644
--- a/tests/python/unittest/test_codegen_cuda.py
+++ b/tests/python/unittest/test_codegen_cuda.py
@@ -66,6 +66,29 @@ def check_cuda(dtype, n, lanes):
         np.testing.assert_allclose(d.asnumpy(), np_d)
     check_cuda("int8", 64, 4)
 
+def test_cuda_vectorize_load():
+    num_thread = 8
+    def check_cuda(dtype, n, lanes):
+        if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
+            print("skip because cuda is not enabled..")
+            return
+        ctx = tvm.gpu(0)
+        A = tvm.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes))
+        B = tvm.compute((n,), lambda i: A[i], name='B')
+        s = tvm.create_schedule(B.op)
+        bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
+        s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
+        s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
+        fun = tvm.build(s, [A, B], "cuda", name="vector_load")
+        np_a = np.random.randint(low=-128, high=127, size=(n,lanes))
+        a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(np_a)
+        b = tvm.nd.empty((n,), B.dtype, ctx)
+        fun(a, b)
+        np.testing.assert_allclose(a.asnumpy(), b.asnumpy())
+    check_cuda("int8", 64, 8)
+    check_cuda("int8", 64, 16)
+
 if __name__ == "__main__":
     test_cuda_vectorize_add()
     test_cuda_multiply_add()
+    test_cuda_vectorize_load()
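A quick way to sanity-check the codegen outside the unit tests is to build a small int8x4 kernel with the same 2018-era tvm.placeholder/tvm.build API used throughout this series and inspect the emitted CUDA source, which should contain char4 once the int8 path is active (and int2/int4 for 8 and 16 lanes). A hedged sketch, assuming a CUDA-capable GPU is present:

    import tvm

    n, lanes = 64, 4
    A = tvm.placeholder((n,), name='A', dtype="int8x%d" % lanes)
    B = tvm.compute((n,), lambda i: A[i] + tvm.const(1, A.dtype), name='B')
    s = tvm.create_schedule(B.op)
    xo, xi = s[B].split(B.op.axis[0], factor=8)
    s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
    s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
    fun = tvm.build(s, [A, B], "cuda")
    print(fun.imported_modules[0].get_source())  # expect "char4" in the kernel source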