From 2842800bad13a16c9da05c2dd12a0f5f9170ffd4 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Wed, 25 Sep 2019 17:30:11 -0700 Subject: [PATCH 01/12] add tensor core support --- include/tvm/ir.h | 8 + src/codegen/codegen_cuda.cc | 113 ++++++++++ src/codegen/codegen_cuda.h | 5 + src/pass/storage_access.cc | 2 +- src/runtime/thread_storage_scope.h | 20 +- .../unittest/test_schedule_tensor_core.py | 197 ++++++++++++++++++ 6 files changed, 343 insertions(+), 2 deletions(-) create mode 100644 tests/python/unittest/test_schedule_tensor_core.py diff --git a/include/tvm/ir.h b/include/tvm/ir.h index b90804983cfb..b4e28f486926 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -1552,6 +1552,14 @@ constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit"; * } */ constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce"; +/*! + * \brief tvm intrinsic for tensor core opeartors. + */ +constexpr const char* tvm_load_matrix_sync = "tvm_load_matrix_sync"; +constexpr const char* tvm_mma_sync = "tvm_mma_sync"; +constexpr const char* tvm_fill_fragment = "tvm_fill_fragment"; +constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync"; +constexpr const char* tvm_access_fragement = "tvm_access_fragement"; } // namespace intrinsic diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index 241310fd00d4..245ba52fd2fa 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -74,6 +74,10 @@ std::string CodeGenCUDA::Finish() { decl_stream << "#include \n"; } + if (need_mma_h_) { + decl_stream << "#include \n"; + } + return CodeGenC::Finish(); } @@ -290,6 +294,100 @@ void CodeGenCUDA::PrintStorageScope( } } +void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) { + if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) { + need_mma_h_ = true; + CHECK_EQ(op->args.size(), 3U); + os << "nvcuda::wmma::fill_fragment("; + this->PrintExpr(op->args[0], os); + os << "["; + this->PrintExpr(op->args[1], os); + os << "], "; + this->PrintExpr(op->args[2], os); + os << ")"; + } else if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync)) { + need_mma_h_ = true; + CHECK_EQ(op->args.size(), 4U); + os << "nvcuda::wmma::load_matrix_sync("; + this->PrintExpr(op->args[0], os); + os << "["; + this->PrintExpr(op->args[1], os); + os << "], "; + this->PrintExpr(op->args[2], os); + os << ", "; + this->PrintExpr(op->args[3], os); + os << ")"; + } else if (op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) { + need_mma_h_ = true; + CHECK_EQ(op->args.size(), 5U); + os << "nvcuda::wmma::store_matrix_sync("; + this->PrintExpr(op->args[2], os); + os << ", "; + this->PrintExpr(op->args[0], os); + os << "["; + this->PrintExpr(op->args[1], os); + os << "], "; + this->PrintExpr(op->args[3], os); + if (const StringImm *str = op->args[4].as()) { + os << ", nvcuda::wmma::" << str->value; + } else { + LOG(FATAL) << "Invalid parameters"; + } + os << ")"; + } else if (op->is_intrinsic(intrinsic::tvm_mma_sync)) { + need_mma_h_ = true; + CHECK_EQ(op->args.size(), 8U); + os << "nvcuda::wmma::mma_sync("; + for (int i = 0; i < 4; ++i) { + this->PrintExpr(op->args[i * 2], os); + os << "["; + this->PrintExpr(op->args[i * 2 + 1], os); + os << "]" << ((i < 3) ? 
", ": ")"); + } + } else { + CodeGenC::VisitExpr_(op, os); + } +} + +void CodeGenCUDA::VisitStmt_(const Allocate* op) { + CHECK(!is_zero(op->condition)); + std::string vid = AllocVarID(op->buffer_var.get()); + if (op->new_expr.defined()) { + // Prefer global static allocation for the program + CHECK_EQ(op->free_function, "nop"); + std::string new_data = PrintExpr(op->new_expr); + this->PrintIndent(); + PrintType(op->type, stream); + stream << "* "<< vid << '=' << new_data << ";\n"; + } else { + this->PrintIndent(); + int32_t constant_size = op->constant_allocation_size(); + CHECK_GT(constant_size, 0) + << "Can only handle constant size stack allocation for now"; + const Variable* buffer = op->buffer_var.as(); + std::string scope = alloc_storage_scope_.at(buffer); + if (scope.find("wmma.") == 0) { + if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { + CHECK(op->type.is_float() && op->type.bits() == 16) + << "Matrix_a and matrix_b only support half type for now"; + } else { + CHECK(op->type.is_float() && (op->type.bits() == 16 || op->type.bits() == 32)) + << "Accumulator only support half and float type for now"; + } + constant_size /= 256; + PrintWmmaScope(scope, op->type, stream); + } else { + PrintStorageScope(scope, stream); + stream << ' '; + PrintType(op->type, stream); + } + stream << ' '<< vid << '[' + << constant_size << "];\n"; + } + RegisterHandleType(op->buffer_var.get(), op->type); + this->PrintStmt(op->body); +} + void CodeGenCUDA::VisitStmt_(const Evaluate *op) { if (is_const(op->value)) return; const Call* call = op->value.as(); @@ -392,5 +490,20 @@ void CodeGenCUDA::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(* PrintConst(op, os, this); } +void CodeGenCUDA::PrintWmmaScope(const std::string &scope, Type t, std::ostream &os) { + std::stringstream type; + PrintType(t, type); + if (scope == "wmma.matrix_a") { + need_mma_h_ = true; + os << "nvcuda::wmma::fragment"; + } else if (scope == "wmma.matrix_b") { + need_mma_h_ = true; + os << "nvcuda::wmma::fragment"; + } else if (scope == "wmma.accumulator") { + need_mma_h_ = true; + os << "nvcuda::wmma::fragment"; + } +} + } // namespace codegen } // namespace tvm diff --git a/src/codegen/codegen_cuda.h b/src/codegen/codegen_cuda.h index 61c6fa3a5170..621747d615c4 100644 --- a/src/codegen/codegen_cuda.h +++ b/src/codegen/codegen_cuda.h @@ -60,7 +60,9 @@ class CodeGenCUDA final : public CodeGenC { void VisitExpr_(const Shuffle* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const FloatImm *op, std::ostream& os) final; + void VisitExpr_(const Call *op, std::ostream& os) final; void VisitStmt_(const Evaluate *op) final; + void VisitStmt_(const Allocate *op) final; private: // Whether global barrier is needed. 
@@ -75,7 +77,10 @@ class CodeGenCUDA final : public CodeGenC { bool enable_int8_{false}; // whether need math_constants.h bool need_math_constants_h_{false}; + // whether need mma.h + bool need_mma_h_{false}; friend void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p); + void PrintWmmaScope(const std::string& scope, Type t, std::ostream& os); }; } // namespace codegen diff --git a/src/pass/storage_access.cc b/src/pass/storage_access.cc index f7deb25560d6..7d056366dc77 100644 --- a/src/pass/storage_access.cc +++ b/src/pass/storage_access.cc @@ -184,7 +184,7 @@ void StorageAccessVisitor::Visit_(const IfThenElse* op) { void StorageAccessVisitor::Visit_(const Call* op) { if (op->is_intrinsic(intrinsic::tvm_address_of)) { const Load *l = op->args[0].as(); - IRVisitor::Visit_(l); + Visit_(l); } else if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { CHECK_EQ(op->args.size(), 5U); Type dtype = op->args[0].type(); diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index 0934e46d4e21..8e75fab87849 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -50,7 +50,13 @@ enum class StorageRank { */ kWarp = 2, /*! \brief thread local memory */ - kLocal = 3 + kLocal = 3, + /*! \brief wmma scope memory of matrix_a */ + kWMMAMatrixA = 4, + /*! \brief wmma scope memory of matrix_b */ + kWMMAMatrixB = 5, + /*! \brief wmma scope memory of accumulator */ + kWMMAAccumulator = 6, }; /*! @@ -89,6 +95,9 @@ struct StorageScope { case StorageRank::kShared: return "shared" + tag; case StorageRank::kWarp: return "warp" + tag; case StorageRank::kLocal: return "local" + tag; + case StorageRank::kWMMAMatrixA: return "wmma.matrix_a" + tag; + case StorageRank::kWMMAMatrixB: return "wmma.matrix_b" + tag; + case StorageRank::kWMMAAccumulator: return "wmma.accumulator" + tag; default: LOG(FATAL) << "unknown storage scope"; return ""; } } @@ -111,6 +120,15 @@ struct StorageScope { } else if (s.compare(0, 5, "local") == 0) { r.rank = StorageRank::kLocal; r.tag = s.substr(5, std::string::npos); + } else if (s.compare(0, 13, "wmma.matrix_a") == 0) { + r.rank = StorageRank::kWMMAMatrixA; + r.tag = s.substr(13, std::string::npos); + } else if (s.compare(0, 13, "wmma.matrix_b") == 0) { + r.rank = StorageRank::kWMMAMatrixB; + r.tag = s.substr(13, std::string::npos); + } else if (s.compare(0, 16, "wmma.accumulator") == 0) { + r.rank = StorageRank::kWMMAAccumulator; + r.tag = s.substr(16, std::string::npos); } else { LOG(FATAL) << "unknown storage scope " << s; } diff --git a/tests/python/unittest/test_schedule_tensor_core.py b/tests/python/unittest/test_schedule_tensor_core.py new file mode 100644 index 000000000000..d62d688498c0 --- /dev/null +++ b/tests/python/unittest/test_schedule_tensor_core.py @@ -0,0 +1,197 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# 'License'); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import numpy as np + +def intrin_wmma_load_matrix(scope): + n = 16 + A = tvm.placeholder((n, n), name='A', dtype='float16') + BA = tvm.decl_buffer(A.shape, A.dtype, scope='shared', data_alignment=32, offset_factor=256) + C = tvm.compute((n, n), lambda i, j: A[i, j], name='C') + BC = tvm.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=256) + + def intrin_func(ins, outs): + ib = tvm.ir_builder.create() + + BA = ins[0] + BC = outs[0] + ib.emit(tvm.call_intrin('handle', 'tvm_load_matrix_sync', + BC.data, BC.elem_offset / 256, + BA.access_ptr('r'), n)) + return ib.get() + + return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) + +def intrin_wmma_gemm(): + n = 16 + A = tvm.placeholder((n, n), name='A', dtype='float16') + B = tvm.placeholder((n, n), name='B', dtype='float16') + k = tvm.reduce_axis((0, n), name="k") + C = tvm.compute((n, n), + lambda ii, jj: + tvm.sum(A[ii, k].astype('float') * B[k, jj].astype('float'), axis=k), + name='C') + BA = tvm.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, offset_factor=256) + BB = tvm.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, offset_factor=256) + BC = tvm.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, offset_factor=256) + + def intrin_func(ins, outs): + BA, BB = ins + BC, = outs + def init(): + ib = tvm.ir_builder.create() + ib.emit(tvm.call_intrin('handle', 'tvm_fill_fragment', BC.data, BC.elem_offset / 256, 0.0)) + return ib.get() + + def update(): + ib = tvm.ir_builder.create() + ib.emit(tvm.call_intrin('handle', 'tvm_mma_sync', + BC.data, BC.elem_offset / 256, + BA.data, BA.elem_offset / 256, + BB.data, BB.elem_offset / 256, + BC.data, BC.elem_offset / 256)) + return ib.get() + return update(), init(), update() + return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC}) + +def intrin_wmma_store_matrix(): + n = 16 + A = tvm.placeholder((n, n), name='A', dtype='float32') + BA = tvm.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator', data_alignment=32, offset_factor=256) + C = tvm.compute((n, n), lambda i, j: A[i, j], name='C') + BC = tvm.decl_buffer(C.shape, C.dtype, scope='shared', data_alignment=32, offset_factor=256) + + def intrin_func(ins, outs): + ib = tvm.ir_builder.create() + + BA = ins[0] + BC = outs[0] + ib.emit(tvm.call_intrin('handle', 'tvm_store_matrix_sync', + BA.data, BA.elem_offset / 256, + BC.access_ptr('w'), n, 'mem_row_major')) + return ib.get() + + return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) + +def test_tensor_core_float(): + n = 1024 + m, l = n, n + nn, mm, ll = n // 16, m // 16, l // 16 + A = tvm.placeholder((n, l), name='A', dtype='float16') + B = tvm.placeholder((l, m), name='B', dtype='float16') + AS = tvm.compute((nn, ll, 16, 16), lambda i, j, ii, jj: A[i * 16 + ii, j * 16 + jj], name='A.shared') + BS = tvm.compute((ll, nn, 16, 16), lambda i, j, ii, jj: B[i * 16 + ii, j * 16 + jj], name='B.shared') + k1 = tvm.reduce_axis((0, ll), name='k1') + k2 = tvm.reduce_axis((0, 16), name='k2') + CF = tvm.compute((nn, mm, 16, 16), + lambda i, j, ii, jj: + tvm.sum(AS[i, k1, ii, k2].astype('float') * BS[k1, j, k2, jj].astype('float'), axis=[k1, k2]), + name='Fragment_C') + C = tvm.compute((n, m), lambda ii, jj: CF[ii // 16, jj // 16, ii % 16, jj % 16], name='C') + s = tvm.create_schedule(C.op) + + warp_size = 32 + 
kernel_size = 16 + block_row_warps = 2 + block_col_warps = 4 + warp_row_tiles = 2 + warp_col_tiles = 1 + chunk = 4 + + block_x = tvm.thread_axis('blockIdx.x') + block_y = tvm.thread_axis('blockIdx.y') + thread_x = tvm.thread_axis('threadIdx.x') + thread_y = tvm.thread_axis('threadIdx.y') + thread_z = tvm.thread_axis('threadIdx.z') + + AF = s.cache_read(AS, 'wmma.matrix_a', [CF]) + BF = s.cache_read(BS, 'wmma.matrix_b', [CF]) + CS = s.cache_read(CF, 'shared', [C]) + s[AS].set_scope('shared') + s[BS].set_scope('shared') + s[CF].set_scope('wmma.accumulator') + + i, j = s[C].op.axis + i, ii = s[C].split(i, factor=kernel_size * warp_row_tiles) + block_i, i = s[C].split(i, factor=block_row_warps) + j, jj = s[C].split(j, factor=kernel_size * warp_col_tiles) + block_j, j = s[C].split(j, factor=block_col_warps) + s[C].reorder(block_i, block_j, i, j, ii, jj) + s[C].bind(block_i, block_x) + s[C].bind(block_j, block_y) + s[C].bind(i, thread_y) + s[C].bind(j, thread_z) + thread_i, _ = s[C].split(ii, nparts=warp_size) + s[C].bind(thread_i, thread_x) + + s[CS].compute_at(s[C], j) + s[CF].compute_at(s[C], j) + xo, yo, xi, yi = CS.op.axis + tx, xo = s[CS].split(xo, nparts=block_row_warps) + ty, yo = s[CS].split(yo, nparts=block_col_warps) + s[CS].bind(tx, thread_y) + s[CS].bind(ty, thread_z) + + warp_i, warp_j, _i, _j = s[CF].op.axis + k, _k = CF.op.reduce_axis + ko, ki = s[CF].split(k, factor=chunk) + s[CF].reorder(ko, ki, warp_i, warp_j, _i, _j, _k) + + s[AF].compute_at(s[CF], ki) + s[BF].compute_at(s[CF], ki) + + s[AS].compute_at(s[CF], ko) + xo, yo, xi, yi = AS.op.axis + tx, xo = s[AS].split(xo, nparts=block_row_warps) + ty, yo = s[AS].split(yo, nparts=block_col_warps) + t = s[AS].fuse(xi, yi) + to, ti = s[AS].split(t, nparts=warp_size) + s[AS].bind(tx, thread_y) + s[AS].bind(ty, thread_z) + s[AS].bind(to, thread_x) + + s[BS].compute_at(s[CF], ko) + xo, yo, xi, yi = BS.op.axis + tx, xo = s[BS].split(xo, nparts=block_row_warps) + ty, yo = s[BS].split(yo, nparts=block_col_warps) + t = s[BS].fuse(xi, yi) + to, ti = s[BS].split(t, nparts=warp_size) + s[BS].bind(tx, thread_y) + s[BS].bind(ty, thread_z) + s[BS].bind(to, thread_x) + + s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_a')) + s[BF].tensorize(BF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_b')) + s[CS].tensorize(CS.op.axis[-2], intrin_wmma_store_matrix()) + s[CF].tensorize(_i, intrin_wmma_gemm()) + func = tvm.build(s, [A, B, C], 'cuda') + + ctx = tvm.gpu(0) + a_np = np.random.uniform(size=(n, n)).astype(A.dtype) + b_np = np.random.uniform(size=(n, n)).astype(B.dtype) + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(b_np, ctx) + c = tvm.nd.array(np.zeros((n, n), dtype=C.dtype), ctx) + func(a, b, c) + c_np = c.asnumpy() + np.testing.assert_allclose(c_np, np.dot(a_np, b_np).astype(C.dtype), rtol=0.001, atol=0.001) + evaluator = func.time_evaluator(func.entry_name, ctx, number=5) + print('gemm with tensor core: %f ms' % (evaluator(a, b, c).mean * 1e3)) + + +if __name__ == '__main__': + test_tensor_core_float() From 052d243e9a06406865ecaf01c7254e7c34b09fa2 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Wed, 25 Sep 2019 21:10:40 -0700 Subject: [PATCH 02/12] avoid memory bank conflict --- .../unittest/test_schedule_tensor_core.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/tests/python/unittest/test_schedule_tensor_core.py b/tests/python/unittest/test_schedule_tensor_core.py index d62d688498c0..9e624116ae1c 100644 --- a/tests/python/unittest/test_schedule_tensor_core.py +++ 
b/tests/python/unittest/test_schedule_tensor_core.py @@ -126,17 +126,20 @@ def test_tensor_core_float(): s[CF].set_scope('wmma.accumulator') i, j = s[C].op.axis - i, ii = s[C].split(i, factor=kernel_size * warp_row_tiles) + i, ti = s[C].split(i, factor=kernel_size) + i, ii = s[C].split(i, factor=warp_row_tiles) block_i, i = s[C].split(i, factor=block_row_warps) - j, jj = s[C].split(j, factor=kernel_size * warp_col_tiles) + j, tj = s[C].split(j, factor=kernel_size) + j, jj = s[C].split(j, factor=warp_col_tiles) block_j, j = s[C].split(j, factor=block_col_warps) - s[C].reorder(block_i, block_j, i, j, ii, jj) + s[C].reorder(block_i, block_j, i, j, ii, jj, ti, tj) s[C].bind(block_i, block_x) s[C].bind(block_j, block_y) s[C].bind(i, thread_y) s[C].bind(j, thread_z) - thread_i, _ = s[C].split(ii, nparts=warp_size) - s[C].bind(thread_i, thread_x) + t = s[C].fuse(ti, tj) + to, ti = s[C].split(t, factor=warp_size) + s[C].bind(ti, thread_x) s[CS].compute_at(s[C], j) s[CF].compute_at(s[C], j) @@ -159,20 +162,21 @@ def test_tensor_core_float(): tx, xo = s[AS].split(xo, nparts=block_row_warps) ty, yo = s[AS].split(yo, nparts=block_col_warps) t = s[AS].fuse(xi, yi) - to, ti = s[AS].split(t, nparts=warp_size) + to, ti = s[AS].split(t, factor=warp_size) s[AS].bind(tx, thread_y) s[AS].bind(ty, thread_z) - s[AS].bind(to, thread_x) + s[AS].bind(ti, thread_x) s[BS].compute_at(s[CF], ko) xo, yo, xi, yi = BS.op.axis tx, xo = s[BS].split(xo, nparts=block_row_warps) ty, yo = s[BS].split(yo, nparts=block_col_warps) t = s[BS].fuse(xi, yi) - to, ti = s[BS].split(t, nparts=warp_size) + to, ti = s[BS].split(t, factor=warp_size) s[BS].bind(tx, thread_y) s[BS].bind(ty, thread_z) - s[BS].bind(to, thread_x) + s[BS].bind(ti, thread_x) + s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_a')) s[BF].tensorize(BF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_b')) From 78133e3a07fcec7ec4b48d9cbdf2cf22485980de Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Thu, 26 Sep 2019 19:10:37 -0700 Subject: [PATCH 03/12] fix thread sync & better performance --- include/tvm/ir_pass.h | 18 +++++++-------- python/tvm/build_module.py | 3 ++- src/codegen/build_module.cc | 3 ++- src/pass/storage_access.cc | 8 ++++++- src/pass/storage_sync.cc | 22 +++++++++++++++++++ .../unittest/test_schedule_tensor_core.py | 12 +++++----- 6 files changed, 48 insertions(+), 18 deletions(-) diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 03078b8be41f..db5e60e479d8 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -359,15 +359,6 @@ Stmt LiftAttrScope(Stmt stmt, std::string attr_key); */ Stmt RewriteUnsafeSelect(Stmt stmt); -/*! - * \brief Lower attached storage access information. - * Do this pass after all storage access analysis finish. - * - * \param stmt The stmt to be transformed - * \return Transformed stmt. - */ -Stmt LowerStorageAccessInfo(Stmt stmt); - /*! * \brief Decorate the stmt with a device scope, this is helpful for * hardware accelerator without thread blocks. @@ -513,6 +504,15 @@ LoweredFunc CombineContextCall(LoweredFunc f); */ LoweredFunc PointerValueTypeRewrite(LoweredFunc f); +/*! + * \brief Lower attached storage access information. + * Do this pass after all storage access analysis finish. + * + * \param stmt The stmt to be transformed + * \return Transformed stmt. + */ +LoweredFunc LowerStorageAccessInfo(LoweredFunc func); + /*! * \brief Lower intrinsic function calls. * \param f The device function to be lowered. 
diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index 4cb09931616e..0e5c5e0dc2d7 100644 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -413,7 +413,6 @@ def lower(sch, # Phase 3 stmt = ir_pass.Simplify(stmt) - stmt = ir_pass.LowerStorageAccessInfo(stmt) stmt = ir_pass.RemoveNoOp(stmt) if not cfg.disable_select_rewriting: stmt = ir_pass.RewriteUnsafeSelect(stmt) @@ -494,6 +493,8 @@ def _build_for_device(flist, target, target_host): assert not fdevice target_host = _target.create(target_host) + fdevice = [ir_pass.LowerStorageAccessInfo(x) for x in fdevice] + fhost = [ir_pass.LowerStorageAccessInfo(x) for x in fhost] fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice] fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost] fhost = [ir_pass.CombineContextCall(x) for x in fhost] diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 66340e9c9021..77b72e2b66d6 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -422,7 +422,6 @@ Stmt BuildStmt(Schedule sch, // Phase 2 stmt = ir::Simplify(stmt); - stmt = ir::LowerStorageAccessInfo(stmt); stmt = ir::RemoveNoOp(stmt); if (!(config->disable_select_rewriting)) @@ -517,6 +516,7 @@ Array > split_dev_host_funcs(const Array& funcs, for (size_t i = 0; i < fhost.size(); ++i) { auto func = fhost[i]; func = ir::BindDeviceType(func, target->device_type); + func = ir::LowerStorageAccessInfo(func); func = ir::LowerTVMBuiltin(func); fhost.Set(i, func); } @@ -524,6 +524,7 @@ Array > split_dev_host_funcs(const Array& funcs, for (size_t i = 0; i < fhost.size(); ++i) { auto func = fhost[i]; func = ir::LowerIntrin(func, target_host->target_name); + func = ir::LowerStorageAccessInfo(func); func = ir::CombineContextCall(func); fhost.Set(i, func); } diff --git a/src/pass/storage_access.cc b/src/pass/storage_access.cc index 7d056366dc77..0e185685b555 100644 --- a/src/pass/storage_access.cc +++ b/src/pass/storage_access.cc @@ -184,7 +184,7 @@ void StorageAccessVisitor::Visit_(const IfThenElse* op) { void StorageAccessVisitor::Visit_(const Call* op) { if (op->is_intrinsic(intrinsic::tvm_address_of)) { const Load *l = op->args[0].as(); - Visit_(l); + IRVisitor::Visit_(l); } else if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { CHECK_EQ(op->args.size(), 5U); Type dtype = op->args[0].type(); @@ -341,5 +341,11 @@ Stmt LowerStorageAccessInfo(Stmt stmt) { return StorageAccessInfoLower().Mutate(stmt); } +LoweredFunc LowerStorageAccessInfo(LoweredFunc f) { + auto n = make_node(*f.operator->()); + n->body = LowerStorageAccessInfo(f->body); + return LoweredFunc(n); +} + } // namespace ir } // namespace tvm diff --git a/src/pass/storage_sync.cc b/src/pass/storage_sync.cc index 7c2f3211c532..019ed357274e 100644 --- a/src/pass/storage_sync.cc +++ b/src/pass/storage_sync.cc @@ -263,6 +263,28 @@ class ThreadSyncInserter : public IRMutator { } } + Expr Mutate_(const Call* op, const Expr& e) final { + if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { + Expr expr = IRMutator::Mutate_(op, e); + op = expr.as(); + CHECK_EQ(op->args.size(), 5U); + const Variable* buffer_var = op->args[1].as(); + Var var(buffer_var->GetNodePtr()); + const IntImm* flag = op->args[4].as(); + if ((flag->value & 1) && sync_scope_.rank == StorageRank::kGlobal && + GetScope(buffer_var).rank == StorageRank::kGlobal) { + ++rw_stats_[var].read_count; + } + if (flag->value & 2 && sync_scope_.rank == StorageRank::kGlobal && + GetScope(buffer_var).rank == StorageRank::kGlobal) { + 
++rw_stats_[var].write_count; + } + return expr; + } else { + return IRMutator::Mutate_(op, e); + } + } + private: // RW statistics about data struct Entry { diff --git a/tests/python/unittest/test_schedule_tensor_core.py b/tests/python/unittest/test_schedule_tensor_core.py index 9e624116ae1c..cfad92066d0e 100644 --- a/tests/python/unittest/test_schedule_tensor_core.py +++ b/tests/python/unittest/test_schedule_tensor_core.py @@ -87,8 +87,8 @@ def intrin_func(ins, outs): return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) -def test_tensor_core_float(): - n = 1024 +def test_tensor_core_gemm(): + n = 4096 m, l = n, n nn, mm, ll = n // 16, m // 16, l // 16 A = tvm.placeholder((n, l), name='A', dtype='float16') @@ -109,8 +109,8 @@ def test_tensor_core_float(): block_row_warps = 2 block_col_warps = 4 warp_row_tiles = 2 - warp_col_tiles = 1 - chunk = 4 + warp_col_tiles = 2 + chunk = 8 block_x = tvm.thread_axis('blockIdx.x') block_y = tvm.thread_axis('blockIdx.y') @@ -192,10 +192,10 @@ def test_tensor_core_float(): c = tvm.nd.array(np.zeros((n, n), dtype=C.dtype), ctx) func(a, b, c) c_np = c.asnumpy() - np.testing.assert_allclose(c_np, np.dot(a_np, b_np).astype(C.dtype), rtol=0.001, atol=0.001) + np.testing.assert_allclose(c_np, np.dot(a_np.astype(C.dtype), b_np.astype(C.dtype)), rtol=1e-4, atol=1e-4) evaluator = func.time_evaluator(func.entry_name, ctx, number=5) print('gemm with tensor core: %f ms' % (evaluator(a, b, c).mean * 1e3)) if __name__ == '__main__': - test_tensor_core_float() + test_tensor_core_gemm() From e290b1e03bb0b04bf2528f051ff1c7512927a9e0 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Wed, 2 Oct 2019 21:07:12 -0700 Subject: [PATCH 04/12] better performance --- src/codegen/codegen_cuda.cc | 12 ++- .../unittest/test_schedule_tensor_core.py | 97 +++++++++---------- 2 files changed, 55 insertions(+), 54 deletions(-) diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index 245ba52fd2fa..3fc0c872ecf1 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -106,14 +106,22 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) { // NOLINT(*) bool fail = false; if (t.is_float()) { switch (t.bits()) { - case 16: os << "half"; + case 16: enable_fp16_ = true; + if (lanes == 1) { + os << "half"; + } else if (lanes <= 8) { + CHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; + os << "float" << lanes / 2; + } else { + fail = true; + } break; case 32: os << "float"; break; case 64: os << "double"; break; default: fail = true; break; } - if (!fail && lanes == 1) return; + if (!fail && (lanes == 1 || t.bits() == 16)) return; if (!fail && (lanes >= 2 && lanes <= 4)) { os << lanes; return; } diff --git a/tests/python/unittest/test_schedule_tensor_core.py b/tests/python/unittest/test_schedule_tensor_core.py index cfad92066d0e..87d2801cdaa0 100644 --- a/tests/python/unittest/test_schedule_tensor_core.py +++ b/tests/python/unittest/test_schedule_tensor_core.py @@ -30,7 +30,7 @@ def intrin_func(ins, outs): BA = ins[0] BC = outs[0] ib.emit(tvm.call_intrin('handle', 'tvm_load_matrix_sync', - BC.data, BC.elem_offset / 256, + BC.data, BC.elem_offset // 256, BA.access_ptr('r'), n)) return ib.get() @@ -42,9 +42,9 @@ def intrin_wmma_gemm(): B = tvm.placeholder((n, n), name='B', dtype='float16') k = tvm.reduce_axis((0, n), name="k") C = tvm.compute((n, n), - lambda ii, jj: - tvm.sum(A[ii, k].astype('float') * B[k, jj].astype('float'), axis=k), - name='C') + lambda ii, jj: + tvm.sum(A[ii, k].astype('float') * B[k, 
jj].astype('float'), axis=k), + name='C') BA = tvm.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, offset_factor=256) BB = tvm.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, offset_factor=256) BC = tvm.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, offset_factor=256) @@ -52,20 +52,23 @@ def intrin_wmma_gemm(): def intrin_func(ins, outs): BA, BB = ins BC, = outs + def init(): ib = tvm.ir_builder.create() - ib.emit(tvm.call_intrin('handle', 'tvm_fill_fragment', BC.data, BC.elem_offset / 256, 0.0)) + ib.emit(tvm.call_intrin('handle', 'tvm_fill_fragment', BC.data, BC.elem_offset // 256, 0.0)) return ib.get() def update(): ib = tvm.ir_builder.create() ib.emit(tvm.call_intrin('handle', 'tvm_mma_sync', - BC.data, BC.elem_offset / 256, - BA.data, BA.elem_offset / 256, - BB.data, BB.elem_offset / 256, - BC.data, BC.elem_offset / 256)) + BC.data, BC.elem_offset // 256, + BA.data, BA.elem_offset // 256, + BB.data, BB.elem_offset // 256, + BC.data, BC.elem_offset // 256)) return ib.get() + return update(), init(), update() + return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC}) def intrin_wmma_store_matrix(): @@ -73,7 +76,7 @@ def intrin_wmma_store_matrix(): A = tvm.placeholder((n, n), name='A', dtype='float32') BA = tvm.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator', data_alignment=32, offset_factor=256) C = tvm.compute((n, n), lambda i, j: A[i, j], name='C') - BC = tvm.decl_buffer(C.shape, C.dtype, scope='shared', data_alignment=32, offset_factor=256) + BC = tvm.decl_buffer(C.shape, C.dtype, scope='global', data_alignment=32, offset_factor=256) def intrin_func(ins, outs): ib = tvm.ir_builder.create() @@ -81,7 +84,7 @@ def intrin_func(ins, outs): BA = ins[0] BC = outs[0] ib.emit(tvm.call_intrin('handle', 'tvm_store_matrix_sync', - BA.data, BA.elem_offset / 256, + BA.data, BA.elem_offset // 256, BC.access_ptr('w'), n, 'mem_row_major')) return ib.get() @@ -90,27 +93,27 @@ def intrin_func(ins, outs): def test_tensor_core_gemm(): n = 4096 m, l = n, n + assert(n % 16 == 0) + assert(m % 16 == 0) + assert(l % 16 == 0) nn, mm, ll = n // 16, m // 16, l // 16 - A = tvm.placeholder((n, l), name='A', dtype='float16') - B = tvm.placeholder((l, m), name='B', dtype='float16') - AS = tvm.compute((nn, ll, 16, 16), lambda i, j, ii, jj: A[i * 16 + ii, j * 16 + jj], name='A.shared') - BS = tvm.compute((ll, nn, 16, 16), lambda i, j, ii, jj: B[i * 16 + ii, j * 16 + jj], name='B.shared') + A = tvm.placeholder((nn, ll, 16, 16), name='A', dtype='float16') + B = tvm.placeholder((ll, mm, 16, 16), name='B', dtype='float16') k1 = tvm.reduce_axis((0, ll), name='k1') k2 = tvm.reduce_axis((0, 16), name='k2') - CF = tvm.compute((nn, mm, 16, 16), - lambda i, j, ii, jj: - tvm.sum(AS[i, k1, ii, k2].astype('float') * BS[k1, j, k2, jj].astype('float'), axis=[k1, k2]), - name='Fragment_C') - C = tvm.compute((n, m), lambda ii, jj: CF[ii // 16, jj // 16, ii % 16, jj % 16], name='C') + C = tvm.compute((nn, mm, 16, 16), + lambda i, j, ii, jj: + tvm.sum(A[i, k1, ii, k2].astype('float') * B[k1, j, k2, jj].astype('float'), axis=[k1, k2]), + name='Fragment_C') s = tvm.create_schedule(C.op) warp_size = 32 kernel_size = 16 block_row_warps = 2 block_col_warps = 4 - warp_row_tiles = 2 + warp_row_tiles = 4 warp_col_tiles = 2 - chunk = 8 + chunk = 4 block_x = tvm.thread_axis('blockIdx.x') block_y = tvm.thread_axis('blockIdx.y') @@ -118,37 +121,24 @@ def test_tensor_core_gemm(): thread_y = 
tvm.thread_axis('threadIdx.y') thread_z = tvm.thread_axis('threadIdx.z') - AF = s.cache_read(AS, 'wmma.matrix_a', [CF]) - BF = s.cache_read(BS, 'wmma.matrix_b', [CF]) - CS = s.cache_read(CF, 'shared', [C]) - s[AS].set_scope('shared') - s[BS].set_scope('shared') - s[CF].set_scope('wmma.accumulator') + AS = s.cache_read(A, 'shared', [C]) + BS = s.cache_read(B, 'shared', [C]) + AF = s.cache_read(AS, 'wmma.matrix_a', [C]) + BF = s.cache_read(BS, 'wmma.matrix_b', [C]) + CF = s.cache_write(C, 'wmma.accumulator') - i, j = s[C].op.axis - i, ti = s[C].split(i, factor=kernel_size) + i, j, kernel_i, kernel_j = s[C].op.axis i, ii = s[C].split(i, factor=warp_row_tiles) block_i, i = s[C].split(i, factor=block_row_warps) - j, tj = s[C].split(j, factor=kernel_size) j, jj = s[C].split(j, factor=warp_col_tiles) block_j, j = s[C].split(j, factor=block_col_warps) - s[C].reorder(block_i, block_j, i, j, ii, jj, ti, tj) + s[C].reorder(block_i, block_j, i, j, ii, jj, kernel_i, kernel_j) s[C].bind(block_i, block_x) s[C].bind(block_j, block_y) s[C].bind(i, thread_y) s[C].bind(j, thread_z) - t = s[C].fuse(ti, tj) - to, ti = s[C].split(t, factor=warp_size) - s[C].bind(ti, thread_x) - s[CS].compute_at(s[C], j) s[CF].compute_at(s[C], j) - xo, yo, xi, yi = CS.op.axis - tx, xo = s[CS].split(xo, nparts=block_row_warps) - ty, yo = s[CS].split(yo, nparts=block_col_warps) - s[CS].bind(tx, thread_y) - s[CS].bind(ty, thread_z) - warp_i, warp_j, _i, _j = s[CF].op.axis k, _k = CF.op.reduce_axis ko, ki = s[CF].split(k, factor=chunk) @@ -162,36 +152,39 @@ def test_tensor_core_gemm(): tx, xo = s[AS].split(xo, nparts=block_row_warps) ty, yo = s[AS].split(yo, nparts=block_col_warps) t = s[AS].fuse(xi, yi) - to, ti = s[AS].split(t, factor=warp_size) + to, ti = s[AS].split(t, nparts=warp_size) s[AS].bind(tx, thread_y) s[AS].bind(ty, thread_z) - s[AS].bind(ti, thread_x) + s[AS].bind(to, thread_x) + s[AS].vectorize(ti) s[BS].compute_at(s[CF], ko) xo, yo, xi, yi = BS.op.axis tx, xo = s[BS].split(xo, nparts=block_row_warps) ty, yo = s[BS].split(yo, nparts=block_col_warps) t = s[BS].fuse(xi, yi) - to, ti = s[BS].split(t, factor=warp_size) + to, ti = s[BS].split(t, nparts=warp_size) s[BS].bind(tx, thread_y) s[BS].bind(ty, thread_z) - s[BS].bind(ti, thread_x) - + s[BS].bind(to, thread_x) + s[BS].vectorize(ti) s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_a')) s[BF].tensorize(BF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_b')) - s[CS].tensorize(CS.op.axis[-2], intrin_wmma_store_matrix()) + s[C].tensorize(kernel_i, intrin_wmma_store_matrix()) s[CF].tensorize(_i, intrin_wmma_gemm()) func = tvm.build(s, [A, B, C], 'cuda') ctx = tvm.gpu(0) - a_np = np.random.uniform(size=(n, n)).astype(A.dtype) - b_np = np.random.uniform(size=(n, n)).astype(B.dtype) + a_np = np.random.uniform(size=(nn, nn, 16, 16)).astype(A.dtype) + b_np = np.random.uniform(size=(nn, nn, 16, 16)).astype(B.dtype) a = tvm.nd.array(a_np, ctx) b = tvm.nd.array(b_np, ctx) - c = tvm.nd.array(np.zeros((n, n), dtype=C.dtype), ctx) + c = tvm.nd.array(np.zeros((nn, nn, 16, 16), dtype=C.dtype), ctx) func(a, b, c) - c_np = c.asnumpy() + a_np = a_np.transpose(0, 2, 1, 3).reshape(n, n) + b_np = b_np.transpose(0, 2, 1, 3).reshape(n, n) + c_np = c.asnumpy().transpose(0, 2, 1, 3).reshape(n, n) np.testing.assert_allclose(c_np, np.dot(a_np.astype(C.dtype), b_np.astype(C.dtype)), rtol=1e-4, atol=1e-4) evaluator = func.time_evaluator(func.entry_name, ctx, number=5) print('gemm with tensor core: %f ms' % (evaluator(a, b, c).mean * 1e3)) From 
5caec7d84f054cd9971fb9ffe9648f3d108d6c9e Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Fri, 4 Oct 2019 19:59:16 -0700 Subject: [PATCH 05/12] add schedule test for conv2d --- .../unittest/test_schedule_tensor_core.py | 200 ++++++++++++++++-- .../python/topi/testing/conv2d_nhwc_python.py | 2 +- 2 files changed, 188 insertions(+), 14 deletions(-) diff --git a/tests/python/unittest/test_schedule_tensor_core.py b/tests/python/unittest/test_schedule_tensor_core.py index 87d2801cdaa0..bc9e65ff7621 100644 --- a/tests/python/unittest/test_schedule_tensor_core.py +++ b/tests/python/unittest/test_schedule_tensor_core.py @@ -16,6 +16,11 @@ # under the License. import tvm import numpy as np +from topi.testing import conv2d_nhwc_python +from tvm.contrib import nvcc + +VERIFY = False + def intrin_wmma_load_matrix(scope): n = 16 @@ -36,6 +41,7 @@ def intrin_func(ins, outs): return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) + def intrin_wmma_gemm(): n = 16 A = tvm.placeholder((n, n), name='A', dtype='float16') @@ -71,6 +77,7 @@ def update(): return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC}) + def intrin_wmma_store_matrix(): n = 16 A = tvm.placeholder((n, n), name='A', dtype='float32') @@ -90,21 +97,22 @@ def intrin_func(ins, outs): return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) + def test_tensor_core_gemm(): n = 4096 m, l = n, n - assert(n % 16 == 0) - assert(m % 16 == 0) - assert(l % 16 == 0) + assert (n % 16 == 0) + assert (m % 16 == 0) + assert (l % 16 == 0) nn, mm, ll = n // 16, m // 16, l // 16 A = tvm.placeholder((nn, ll, 16, 16), name='A', dtype='float16') B = tvm.placeholder((ll, mm, 16, 16), name='B', dtype='float16') k1 = tvm.reduce_axis((0, ll), name='k1') k2 = tvm.reduce_axis((0, 16), name='k2') C = tvm.compute((nn, mm, 16, 16), - lambda i, j, ii, jj: - tvm.sum(A[i, k1, ii, k2].astype('float') * B[k1, j, k2, jj].astype('float'), axis=[k1, k2]), - name='Fragment_C') + lambda i, j, ii, jj: + tvm.sum(A[i, k1, ii, k2].astype('float') * B[k1, j, k2, jj].astype('float'), axis=[k1, k2]), + name='Fragment_C') s = tvm.create_schedule(C.op) warp_size = 32 @@ -181,14 +189,180 @@ def test_tensor_core_gemm(): a = tvm.nd.array(a_np, ctx) b = tvm.nd.array(b_np, ctx) c = tvm.nd.array(np.zeros((nn, nn, 16, 16), dtype=C.dtype), ctx) - func(a, b, c) - a_np = a_np.transpose(0, 2, 1, 3).reshape(n, n) - b_np = b_np.transpose(0, 2, 1, 3).reshape(n, n) - c_np = c.asnumpy().transpose(0, 2, 1, 3).reshape(n, n) - np.testing.assert_allclose(c_np, np.dot(a_np.astype(C.dtype), b_np.astype(C.dtype)), rtol=1e-4, atol=1e-4) - evaluator = func.time_evaluator(func.entry_name, ctx, number=5) + evaluator = func.time_evaluator(func.entry_name, ctx, number=3) print('gemm with tensor core: %f ms' % (evaluator(a, b, c).mean * 1e3)) + if VERIFY: + func(a, b, c) + a_np = a_np.transpose(0, 2, 1, 3).reshape(n, n) + b_np = b_np.transpose(0, 2, 1, 3).reshape(n, n) + c_np = c.asnumpy().transpose(0, 2, 1, 3).reshape(n, n) + np.testing.assert_allclose(c_np, np.dot(a_np.astype(C.dtype), b_np.astype(C.dtype)), rtol=1e-4, atol=1e-4) + + +def test_tensor_core_conv(): + # The sizes of inputs and filters + batch_size = 256 + height = 14 + width = 14 + in_channels = 256 + out_channels = 512 + kernel_h = 3 + kernel_w = 3 + pad_h = 1 + pad_w = 1 + stride_h = 1 + stride_w = 1 + block_size = 16 + + block_row_warps = 2 + block_col_warps = 4 + warp_row_tiles = 4 + warp_col_tiles = 2 + warp_size = 32 + chunk = 2 + + # Input feature map: (N, H, W, IC, n, ic) + data_shape = 
(batch_size // block_size, + height, + width, + in_channels // block_size, + block_size, + block_size) + # Kernel: (H, W, IC, OC, ic, oc) + kernel_shape = (kernel_h, + kernel_w, + in_channels // block_size, + out_channels // block_size, + block_size, + block_size) + + # Output feature map: (N, H, W, OC, n, oc) + output_shape = (batch_size // block_size, + height, + width, + out_channels // block_size, + block_size, + block_size) + + assert (batch_size % block_size == 0) + assert (in_channels % block_size == 0) + assert (out_channels % block_size == 0) + + kh = tvm.reduce_axis((0, kernel_h), name='kh') + kw = tvm.reduce_axis((0, kernel_w), name='kw') + ic = tvm.reduce_axis((0, in_channels // block_size), name='ic') + ii = tvm.reduce_axis((0, block_size), name='ii') + + # Algorithm + A = tvm.placeholder(data_shape, name='A', dtype="float16") + W = tvm.placeholder(kernel_shape, name='W', dtype="float16") + Apad = tvm.compute( + (batch_size // block_size, height + 2 * pad_h, width + 2 * pad_w, in_channels // block_size, block_size, + block_size), + lambda n, h, w, i, nn, ii: tvm.if_then_else( + tvm.all(h >= pad_h, h - pad_h < height, + w >= pad_w, w - pad_w < width), + A[n, h - pad_h, w - pad_w, i, nn, ii], tvm.const(0., "float16")), + name='Apad') + Conv = tvm.compute(output_shape, + lambda n, h, w, o, nn, oo: tvm.sum( + Apad[n, h * stride_h + kh, w * stride_w + kw, ic, nn, ii].astype("float32") * + W[kh, kw, ic, o, ii, oo].astype("float32"), + axis=[ic, kh, kw, ii]), + name="Conv") + + s = tvm.create_schedule(Conv.op) + s[Apad].compute_inline() + + AS = s.cache_read(Apad, 'shared', [Conv]) + WS = s.cache_read(W, 'shared', [Conv]) + AF = s.cache_read(AS, 'wmma.matrix_a', [Conv]) + WF = s.cache_read(WS, 'wmma.matrix_b', [Conv]) + ConvF = s.cache_write(Conv, 'wmma.accumulator') + + block_x = tvm.thread_axis('blockIdx.x') + block_y = tvm.thread_axis('blockIdx.y') + block_z = tvm.thread_axis('blockIdx.z') + thread_x = tvm.thread_axis('threadIdx.x') + thread_y = tvm.thread_axis('threadIdx.y') + thread_z = tvm.thread_axis('threadIdx.z') + + nc, hc, wc, oc, nnc, ooc = Conv.op.axis + block_k = s[Conv].fuse(hc, wc) + s[Conv].bind(block_k, block_z) + nc, nci = s[Conv].split(nc, factor=warp_row_tiles) + block_i, nc = s[Conv].split(nc, factor=block_row_warps) + oc, oci = s[Conv].split(oc, factor=warp_col_tiles) + block_j, oc = s[Conv].split(oc, factor=block_col_warps) + s[Conv].reorder(block_k, block_i, block_j, nc, oc, nci, oci, nnc, ooc) + s[Conv].bind(block_i, block_x) + s[Conv].bind(block_j, block_y) + s[Conv].bind(nc, thread_y) + s[Conv].bind(oc, thread_z) + + s[ConvF].compute_at(s[Conv], oc) + n, h, w, o, nnf, oof = ConvF.op.axis + ko, ki = s[ConvF].split(ic, factor=chunk) + s[ConvF].reorder(ko, kh, ki, kw, n, o, nnf, oof, ii) + + s[AF].compute_at(s[ConvF], kw) + s[WF].compute_at(s[ConvF], kw) + + s[WS].compute_at(s[ConvF], kh) + s[AS].compute_at(s[ConvF], kh) + + n, h, w, i, nn, ii = AS.op.axis + tx, xo = s[AS].split(n, nparts=block_row_warps) + ty, yo = s[AS].split(xo, nparts=block_col_warps) + t = s[AS].fuse(nn, ii) + to, ti = s[AS].split(t, factor=warp_size) + s[AS].bind(tx, thread_y) + s[AS].bind(ty, thread_z) + s[AS].bind(ti, thread_x) + + kh, kw, ic, o, ii, oo = WS.op.axis + tx, xo = s[WS].split(o, nparts=block_row_warps) + ty, yo = s[WS].split(xo, nparts=block_col_warps) + t = s[WS].fuse(ii, oo) + to, ti = s[WS].split(t, nparts=warp_size) + s[WS].bind(tx, thread_y) + s[WS].bind(ty, thread_z) + s[WS].bind(to, thread_x) + s[WS].vectorize(ti) + + s[AF].tensorize(AF.op.axis[-2], 
intrin_wmma_load_matrix('wmma.matrix_a')) + s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_b')) + s[Conv].tensorize(nnc, intrin_wmma_store_matrix()) + s[ConvF].tensorize(nnf, intrin_wmma_gemm()) + + func = tvm.build(s, [A, W, Conv], 'cuda') + + ctx = tvm.gpu(0) + a_np = np.random.uniform(size=data_shape).astype(A.dtype) + w_np = np.random.uniform(size=kernel_shape).astype(W.dtype) + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + c = tvm.nd.array(np.zeros(output_shape, dtype=Conv.dtype), ctx) + evaluator = func.time_evaluator(func.entry_name, ctx, number=3) + print('conv2d with tensor core: %f ms' % (evaluator(a, w, c).mean * 1e3)) + + if VERIFY: + func(a, w, c) + a_np = a_np.transpose(0, 4, 1, 2, 3, 5).reshape(batch_size, height, width, in_channels) + w_np = w_np.transpose(0, 1, 2, 4, 3, 5).reshape(kernel_h, kernel_w, in_channels, out_channels) + c_np = c.asnumpy().transpose((0, 4, 1, 2, 3, 5)).reshape(batch_size, height, width, out_channels) + c_std = conv2d_nhwc_python(a_np.astype(Conv.dtype), + w_np.astype(Conv.dtype), + (stride_h, stride_w), + (pad_h, pad_w)).astype(Conv.dtype) + np.testing.assert_allclose(c_np, c_std, rtol=1e-4, atol=1e-4) + if __name__ == '__main__': - test_tensor_core_gemm() + ctx = tvm.gpu(0) + if not nvcc.have_tensorcore(ctx.compute_version): + print("skip because gpu does not support tensor core") + else: + test_tensor_core_gemm() + test_tensor_core_conv() diff --git a/topi/python/topi/testing/conv2d_nhwc_python.py b/topi/python/topi/testing/conv2d_nhwc_python.py index d2ef40c64d21..8a6a467a80c4 100644 --- a/topi/python/topi/testing/conv2d_nhwc_python.py +++ b/topi/python/topi/testing/conv2d_nhwc_python.py @@ -40,7 +40,7 @@ def conv2d_nhwc_python(a_np, w_np, stride, padding): Returns ------- b_np : np.ndarray - 4-D with shape [out_height, out_width, out_channel, batch] + 4-D with shape [batch, out_height, out_width, out_channel] """ batch, in_height, in_width, in_channel = a_np.shape kernel_h, kernel_w, _, num_filter = w_np.shape From edea7dff0a6b2365863ea9ff5608c3f6f04479bd Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Mon, 7 Oct 2019 15:52:04 -0700 Subject: [PATCH 06/12] extend into BatchMatMul --- .../unittest/test_schedule_tensor_core.py | 45 ++++++++++--------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/tests/python/unittest/test_schedule_tensor_core.py b/tests/python/unittest/test_schedule_tensor_core.py index bc9e65ff7621..8dd29ee5d590 100644 --- a/tests/python/unittest/test_schedule_tensor_core.py +++ b/tests/python/unittest/test_schedule_tensor_core.py @@ -98,20 +98,21 @@ def intrin_func(ins, outs): return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) -def test_tensor_core_gemm(): - n = 4096 +def test_tensor_core_batch_matmal(): + batch_size = 20 + n = 2048 m, l = n, n assert (n % 16 == 0) assert (m % 16 == 0) assert (l % 16 == 0) nn, mm, ll = n // 16, m // 16, l // 16 - A = tvm.placeholder((nn, ll, 16, 16), name='A', dtype='float16') - B = tvm.placeholder((ll, mm, 16, 16), name='B', dtype='float16') + A = tvm.placeholder((batch_size, nn, ll, 16, 16), name='A', dtype='float16') + B = tvm.placeholder((batch_size, ll, mm, 16, 16), name='B', dtype='float16') k1 = tvm.reduce_axis((0, ll), name='k1') k2 = tvm.reduce_axis((0, 16), name='k2') - C = tvm.compute((nn, mm, 16, 16), - lambda i, j, ii, jj: - tvm.sum(A[i, k1, ii, k2].astype('float') * B[k1, j, k2, jj].astype('float'), axis=[k1, k2]), + C = tvm.compute((batch_size, nn, mm, 16, 16), + lambda b, i, j, ii, jj: + tvm.sum(A[b, i, 
k1, ii, k2].astype('float') * B[b, k1, j, k2, jj].astype('float'), axis=[k1, k2]), name='Fragment_C') s = tvm.create_schedule(C.op) @@ -125,6 +126,7 @@ def test_tensor_core_gemm(): block_x = tvm.thread_axis('blockIdx.x') block_y = tvm.thread_axis('blockIdx.y') + block_z = tvm.thread_axis('blockIdx.z') thread_x = tvm.thread_axis('threadIdx.x') thread_y = tvm.thread_axis('threadIdx.y') thread_z = tvm.thread_axis('threadIdx.z') @@ -135,19 +137,20 @@ def test_tensor_core_gemm(): BF = s.cache_read(BS, 'wmma.matrix_b', [C]) CF = s.cache_write(C, 'wmma.accumulator') - i, j, kernel_i, kernel_j = s[C].op.axis + b, i, j, kernel_i, kernel_j = s[C].op.axis i, ii = s[C].split(i, factor=warp_row_tiles) block_i, i = s[C].split(i, factor=block_row_warps) j, jj = s[C].split(j, factor=warp_col_tiles) block_j, j = s[C].split(j, factor=block_col_warps) s[C].reorder(block_i, block_j, i, j, ii, jj, kernel_i, kernel_j) + s[C].bind(b, block_z) s[C].bind(block_i, block_x) s[C].bind(block_j, block_y) s[C].bind(i, thread_y) s[C].bind(j, thread_z) s[CF].compute_at(s[C], j) - warp_i, warp_j, _i, _j = s[CF].op.axis + b, warp_i, warp_j, _i, _j = s[CF].op.axis k, _k = CF.op.reduce_axis ko, ki = s[CF].split(k, factor=chunk) s[CF].reorder(ko, ki, warp_i, warp_j, _i, _j, _k) @@ -156,7 +159,7 @@ def test_tensor_core_gemm(): s[BF].compute_at(s[CF], ki) s[AS].compute_at(s[CF], ko) - xo, yo, xi, yi = AS.op.axis + b, xo, yo, xi, yi = AS.op.axis tx, xo = s[AS].split(xo, nparts=block_row_warps) ty, yo = s[AS].split(yo, nparts=block_col_warps) t = s[AS].fuse(xi, yi) @@ -167,7 +170,7 @@ def test_tensor_core_gemm(): s[AS].vectorize(ti) s[BS].compute_at(s[CF], ko) - xo, yo, xi, yi = BS.op.axis + b, xo, yo, xi, yi = BS.op.axis tx, xo = s[BS].split(xo, nparts=block_row_warps) ty, yo = s[BS].split(yo, nparts=block_col_warps) t = s[BS].fuse(xi, yi) @@ -184,23 +187,23 @@ def test_tensor_core_gemm(): func = tvm.build(s, [A, B, C], 'cuda') ctx = tvm.gpu(0) - a_np = np.random.uniform(size=(nn, nn, 16, 16)).astype(A.dtype) - b_np = np.random.uniform(size=(nn, nn, 16, 16)).astype(B.dtype) + a_np = np.random.uniform(size=(batch_size, nn, nn, 16, 16)).astype(A.dtype) + b_np = np.random.uniform(size=(batch_size, nn, nn, 16, 16)).astype(B.dtype) a = tvm.nd.array(a_np, ctx) b = tvm.nd.array(b_np, ctx) - c = tvm.nd.array(np.zeros((nn, nn, 16, 16), dtype=C.dtype), ctx) + c = tvm.nd.array(np.zeros((batch_size, nn, nn, 16, 16), dtype=C.dtype), ctx) evaluator = func.time_evaluator(func.entry_name, ctx, number=3) print('gemm with tensor core: %f ms' % (evaluator(a, b, c).mean * 1e3)) if VERIFY: func(a, b, c) - a_np = a_np.transpose(0, 2, 1, 3).reshape(n, n) - b_np = b_np.transpose(0, 2, 1, 3).reshape(n, n) - c_np = c.asnumpy().transpose(0, 2, 1, 3).reshape(n, n) - np.testing.assert_allclose(c_np, np.dot(a_np.astype(C.dtype), b_np.astype(C.dtype)), rtol=1e-4, atol=1e-4) + a_np = a_np.transpose((0, 1, 3, 2, 4)).reshape(batch_size, n, n) + b_np = b_np.transpose((0, 1, 3, 2, 4)).reshape(batch_size, n, n) + c_np = c.asnumpy().transpose((0, 1, 3, 2, 4)).reshape(batch_size, n, n) + np.testing.assert_allclose(c_np, np.matmul(a_np.astype(C.dtype), b_np.astype(C.dtype)), rtol=1e-4, atol=1e-4) -def test_tensor_core_conv(): +def test_tensor_core_batch_conv(): # The sizes of inputs and filters batch_size = 256 height = 14 @@ -364,5 +367,5 @@ def test_tensor_core_conv(): if not nvcc.have_tensorcore(ctx.compute_version): print("skip because gpu does not support tensor core") else: - test_tensor_core_gemm() - test_tensor_core_conv() + test_tensor_core_batch_matmal() + 
test_tensor_core_batch_conv() From cfff0951a88df6013106be72ffdd570ec9f73ffc Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Tue, 15 Oct 2019 01:33:55 -0700 Subject: [PATCH 07/12] support config fragment shape and layout using intrinsic --- include/tvm/ir.h | 12 +- include/tvm/ir_pass.h | 8 + python/tvm/build_module.py | 1 + src/api/api_pass.cc | 1 + src/codegen/codegen_cuda.cc | 54 +++-- src/codegen/codegen_cuda.h | 6 +- src/pass/infer_fragment.cc | 197 ++++++++++++++++++ .../unittest/test_schedule_tensor_core.py | 10 +- 8 files changed, 264 insertions(+), 25 deletions(-) create mode 100644 src/pass/infer_fragment.cc diff --git a/include/tvm/ir.h b/include/tvm/ir.h index b4e28f486926..4d95c8ab7d5a 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -1310,6 +1310,16 @@ constexpr const char* opengl_stage_scope = "opengl_stage_scope"; */ constexpr const char* device_scope = "device_scope"; +/*! + * \brief Mark that the shape of TensorCore fragment + */ +constexpr const char* fragment_shape = "fragment_shape"; + +/*! + * \brief Mark that the layout of TensorCore fragment + */ +constexpr const char* fragment_layout = "fragment_layout"; + /*! * \brief Check if attr_key is a pragma key extension * \param attr_key The attr key to be compared @@ -1319,6 +1329,7 @@ inline bool IsPragmaKey(const std::string& attr_key) { return attr_key.compare(0, 7, "pragma_") == 0; } + } // namespace attr /*! \brief namespace of TVM Intrinsic functions */ @@ -1559,7 +1570,6 @@ constexpr const char* tvm_load_matrix_sync = "tvm_load_matrix_sync"; constexpr const char* tvm_mma_sync = "tvm_mma_sync"; constexpr const char* tvm_fill_fragment = "tvm_fill_fragment"; constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync"; -constexpr const char* tvm_access_fragement = "tvm_access_fragement"; } // namespace intrinsic diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index db5e60e479d8..330eb54e040d 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -532,6 +532,14 @@ LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target); */ LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target); +/*! + * \brief Infer the TensorCore fragment infomation using tensor intrinsics + * + * \param stmt The stmt to be transformed + * \return Transformed stmt. + */ +LoweredFunc InferFragment(LoweredFunc f); + /*! * \brief Verify if memory accesses are legal for a specific target device type. 
* diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index 0e5c5e0dc2d7..a6155339cc62 100644 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -464,6 +464,7 @@ def _build_for_device(flist, target, target_host): func = ir_pass.ThreadSync(func, "global") func = ir_pass.ThreadSync(func, "shared") func = ir_pass.ThreadSync(func, "warp") + func = ir_pass.InferFragment(func) warp_size = target.thread_warp_size func = ir_pass.LowerThreadAllreduce(func, warp_size) fsplits = [s for s in ir_pass.SplitHostDevice(func)] diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index dd0415afd9eb..1cf4c64e4e2c 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -161,5 +161,6 @@ REGISTER_PASS(DecorateDeviceScope); REGISTER_PASS(InstrumentBoundCheckers); REGISTER_PASS(VerifyCompactBuffer); REGISTER_PASS(HoistIfThenElse); +REGISTER_PASS(InferFragment) } // namespace ir } // namespace tvm diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index 3fc0c872ecf1..fe249824467f 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -305,39 +305,39 @@ void CodeGenCUDA::PrintStorageScope( void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) { if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) { need_mma_h_ = true; - CHECK_EQ(op->args.size(), 3U); + CHECK_EQ(op->args.size(), 6U); os << "nvcuda::wmma::fill_fragment("; this->PrintExpr(op->args[0], os); os << "["; - this->PrintExpr(op->args[1], os); + this->PrintExpr(op->args[4], os); os << "], "; - this->PrintExpr(op->args[2], os); + this->PrintExpr(op->args[5], os); os << ")"; } else if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync)) { need_mma_h_ = true; - CHECK_EQ(op->args.size(), 4U); + CHECK_EQ(op->args.size(), 8U); os << "nvcuda::wmma::load_matrix_sync("; this->PrintExpr(op->args[0], os); os << "["; - this->PrintExpr(op->args[1], os); + this->PrintExpr(op->args[4], os); os << "], "; - this->PrintExpr(op->args[2], os); + this->PrintExpr(op->args[5], os); os << ", "; - this->PrintExpr(op->args[3], os); + this->PrintExpr(op->args[6], os); os << ")"; } else if (op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) { need_mma_h_ = true; - CHECK_EQ(op->args.size(), 5U); + CHECK_EQ(op->args.size(), 8U); os << "nvcuda::wmma::store_matrix_sync("; - this->PrintExpr(op->args[2], os); + this->PrintExpr(op->args[5], os); os << ", "; this->PrintExpr(op->args[0], os); os << "["; - this->PrintExpr(op->args[1], os); + this->PrintExpr(op->args[4], os); os << "], "; - this->PrintExpr(op->args[3], os); - if (const StringImm *str = op->args[4].as()) { - os << ", nvcuda::wmma::" << str->value; + this->PrintExpr(op->args[6], os); + if (const StringImm *str = op->args[7].as()) { + os << ", nvcuda::wmma::mem_" << str->value; } else { LOG(FATAL) << "Invalid parameters"; } @@ -357,6 +357,19 @@ void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) { } } +void CodeGenCUDA::VisitStmt_(const AttrStmt* op) { + if (op->attr_key == attr::fragment_shape) { + const Variable* buffer = op->node.as(); + const StringImm* shape_str = op->value.as(); + fragment_shapes[buffer] = shape_str->value; + } else if (op->attr_key == attr::fragment_layout) { + const Variable* buffer = op->node.as(); + const StringImm* layout_str = op->value.as(); + fragment_layouts[buffer] = layout_str->value; + } + CodeGenC::VisitStmt_(op); +} + void CodeGenCUDA::VisitStmt_(const Allocate* op) { CHECK(!is_zero(op->condition)); std::string vid = AllocVarID(op->buffer_var.get()); @@ -383,7 +396,7 @@ void 
CodeGenCUDA::VisitStmt_(const Allocate* op) { << "Accumulator only support half and float type for now"; } constant_size /= 256; - PrintWmmaScope(scope, op->type, stream); + PrintWmmaScope(scope, op->type, buffer, stream); } else { PrintStorageScope(scope, stream); stream << ' '; @@ -498,18 +511,23 @@ void CodeGenCUDA::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(* PrintConst(op, os, this); } -void CodeGenCUDA::PrintWmmaScope(const std::string &scope, Type t, std::ostream &os) { +void CodeGenCUDA::PrintWmmaScope(const std::string &scope, Type t, const Variable* variable, std::ostream &os) { std::stringstream type; PrintType(t, type); + std::string shape_str = fragment_shapes[variable]; if (scope == "wmma.matrix_a") { need_mma_h_ = true; - os << "nvcuda::wmma::fragment"; + std::string layout_str = fragment_layouts[variable]; + os << "nvcuda::wmma::fragment"; } else if (scope == "wmma.matrix_b") { need_mma_h_ = true; - os << "nvcuda::wmma::fragment"; + std::string layout_str = fragment_layouts[variable]; + os << "nvcuda::wmma::fragment"; } else if (scope == "wmma.accumulator") { need_mma_h_ = true; - os << "nvcuda::wmma::fragment"; + os << "nvcuda::wmma::fragment"; } } diff --git a/src/codegen/codegen_cuda.h b/src/codegen/codegen_cuda.h index 621747d615c4..9671ece63d09 100644 --- a/src/codegen/codegen_cuda.h +++ b/src/codegen/codegen_cuda.h @@ -63,6 +63,7 @@ class CodeGenCUDA final : public CodeGenC { void VisitExpr_(const Call *op, std::ostream& os) final; void VisitStmt_(const Evaluate *op) final; void VisitStmt_(const Allocate *op) final; + void VisitStmt_(const AttrStmt *op) final; private: // Whether global barrier is needed. @@ -79,8 +80,11 @@ class CodeGenCUDA final : public CodeGenC { bool need_math_constants_h_{false}; // whether need mma.h bool need_mma_h_{false}; + + std::unordered_map fragment_shapes; + std::unordered_map fragment_layouts; friend void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p); - void PrintWmmaScope(const std::string& scope, Type t, std::ostream& os); + void PrintWmmaScope(const std::string& scope, Type t, const Variable* variable, std::ostream& os); }; } // namespace codegen diff --git a/src/pass/infer_fragment.cc b/src/pass/infer_fragment.cc new file mode 100644 index 000000000000..9abdaecfe06e --- /dev/null +++ b/src/pass/infer_fragment.cc @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2019 by Contributors + * \file tensorcore_fragment.cc + */ +#include +#include +#include +#include +#include +#include +#include "ir_util.h" +#include "storage_access.h" +#include "../runtime/thread_storage_scope.h" + +namespace tvm { +namespace ir { + +class FragmentGetter : public IRVisitor { + public: + struct FragmentInfo { + int m, n, k; + std::string layout; + FragmentInfo() = default; + FragmentInfo(int _m, int _n, int _k, const std::string& _layout) + : m(_m), n(_n), k(_k), layout(_layout) {} + }; + + void Visit_(const Call* op) final { + IRVisitor::Visit_(op); + + if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync) || + op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) { + CHECK_EQ(op->args.size(), 8U); + const Variable* buffer_var = op->args[0].as(); + CHECK(buffer_var); + const IntImm* m = op->args[1].as(); + const IntImm* n = op->args[2].as(); + const IntImm* k = op->args[3].as(); + const StringImm* layout = op->args[7].as(); + CHECK(m); + CHECK(n); + CHECK(k); + CHECK(layout); + + std::string scope = scopes[buffer_var]; + if (fragments.count(buffer_var)) { + FragmentInfo info = fragments[buffer_var]; + CHECK_EQ(m->value, info.m); + CHECK_EQ(n->value, info.n); + CHECK_EQ(k->value, info.k); + if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { + CHECK_EQ(layout->value, info.layout); + } + } else { + FragmentInfo info; + if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { + info = FragmentInfo(m->value, n->value, k->value, layout->value); + } else if (scope == "wmma.accumulator") { + info = FragmentInfo(m->value, n->value, k->value, ""); + } + fragments[buffer_var] = info; + } + } else if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) { + CHECK_EQ(op->args.size(), 6U); + const Variable* buffer_var = op->args[0].as(); + CHECK(buffer_var); + const IntImm* m = op->args[1].as(); + const IntImm* n = op->args[2].as(); + const IntImm* k = op->args[3].as(); + CHECK(m); + CHECK(n); + CHECK(k); + + std::string scope = scopes[buffer_var]; + CHECK_EQ(scope, "wmma.accumulator"); + if (fragments.count(buffer_var)) { + FragmentInfo info = fragments[buffer_var]; + CHECK_EQ(m->value, info.m); + CHECK_EQ(n->value, info.n); + CHECK_EQ(k->value, info.k); + } else { + FragmentInfo info(m->value, n->value, k->value, ""); + fragments[buffer_var] = info; + } + } + } + + void Visit_(const AttrStmt* op) final { + if (op->attr_key == attr::storage_scope) { + const Variable* buffer = op->node.as(); + CHECK(buffer); + scopes[buffer] = op->value.as()->value; + } + IRVisitor::Visit_(op); + } + + std::unordered_map scopes; + std::unordered_map fragments; +}; + +class FragmentChecker : public IRVisitor { + public: + FragmentChecker(const FragmentGetter &getter) : fragment_getter(getter) {} + + void Visit_(const Call* op) final { + if (op->is_intrinsic(intrinsic::tvm_mma_sync)) { + CHECK_EQ(op->args.size(), 8U); + const Variable* buffer_var_d = op->args[0].as(); + const Variable* buffer_var_a = op->args[2].as(); + const Variable* buffer_var_b = op->args[4].as(); + const Variable* buffer_var_c = op->args[6].as(); + CHECK(buffer_var_d); + CHECK(buffer_var_a); + CHECK(buffer_var_b); + CHECK(buffer_var_c); + CHECK(CheckShape(buffer_var_d, buffer_var_a)); + CHECK(CheckShape(buffer_var_d, buffer_var_b)); + CHECK(CheckShape(buffer_var_d, buffer_var_c)); + } + } + private: + bool CheckShape(const Variable* buffer1, const Variable* buffer2) { + CHECK(fragment_getter.fragments.count(buffer1)); + CHECK(fragment_getter.fragments.count(buffer2)); + FragmentGetter::FragmentInfo info1 = 
fragment_getter.fragments.at(buffer1); + FragmentGetter::FragmentInfo info2 = fragment_getter.fragments.at(buffer2); + return info1.m == info2.m && info1.n == info2.n && info1.k == info2.k; + + } + const FragmentGetter &fragment_getter; + +}; + +class InferFragmenter : public IRMutator { + public: + InferFragmenter(const FragmentGetter &getter) : fragment_getter(getter) {} + + Stmt Mutate_(const Allocate* op, const Stmt& s) final { + Stmt stmt = IRMutator::Mutate_(op, s); + const Variable* buffer = op->buffer_var.get(); + if (fragment_getter.fragments.count(buffer)) { + FragmentGetter::FragmentInfo info = fragment_getter.fragments.at(buffer); + std::string shape = std::to_string(info.n) + ", " + + std::to_string(info.m) + ", " + + std::to_string(info.k); + Expr shape_expr = StringImm::make(shape); + Stmt shape_attr = AttrStmt::make(op->buffer_var, attr::fragment_shape, shape_expr, stmt); + if (info.layout != "") { + Stmt layout_attr = AttrStmt::make(op->buffer_var, attr::fragment_layout, + StringImm::make(info.layout), shape_attr); + return layout_attr; + } else { + return shape_attr; + } + } + return stmt; + } + private: + const FragmentGetter &fragment_getter; +}; + +Stmt InferFragment(Stmt stmt) { + FragmentGetter getter; + getter.Visit(stmt); + FragmentChecker(getter).Visit(stmt); + stmt = InferFragmenter(getter).Mutate(stmt); + return stmt; +} + +LoweredFunc InferFragment(LoweredFunc f) { + CHECK_NE(f->func_type, kHostFunc); + auto n = make_node(*f.operator->()); + n->body = InferFragment(f->body); + return LoweredFunc(n); +} + +} // namespace ir +} // namespace tvm diff --git a/tests/python/unittest/test_schedule_tensor_core.py b/tests/python/unittest/test_schedule_tensor_core.py index 8dd29ee5d590..3892f681ae98 100644 --- a/tests/python/unittest/test_schedule_tensor_core.py +++ b/tests/python/unittest/test_schedule_tensor_core.py @@ -35,8 +35,8 @@ def intrin_func(ins, outs): BA = ins[0] BC = outs[0] ib.emit(tvm.call_intrin('handle', 'tvm_load_matrix_sync', - BC.data, BC.elem_offset // 256, - BA.access_ptr('r'), n)) + BC.data, n, n, n, BC.elem_offset // 256, + BA.access_ptr('r'), n, 'row_major')) return ib.get() return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) @@ -61,7 +61,7 @@ def intrin_func(ins, outs): def init(): ib = tvm.ir_builder.create() - ib.emit(tvm.call_intrin('handle', 'tvm_fill_fragment', BC.data, BC.elem_offset // 256, 0.0)) + ib.emit(tvm.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, n, n, BC.elem_offset // 256, 0.0)) return ib.get() def update(): @@ -91,8 +91,8 @@ def intrin_func(ins, outs): BA = ins[0] BC = outs[0] ib.emit(tvm.call_intrin('handle', 'tvm_store_matrix_sync', - BA.data, BA.elem_offset // 256, - BC.access_ptr('w'), n, 'mem_row_major')) + BA.data, n, n, n, BA.elem_offset // 256, + BC.access_ptr('w'), n, 'row_major')) return ib.get() return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) From 190d936af053c6c7eb58677c842bb2d40f92ebca Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Tue, 15 Oct 2019 19:50:23 -0700 Subject: [PATCH 08/12] add TensorCore tutorial --- .../unittest/test_schedule_tensor_core.py | 12 +- tutorials/optimize/opt_conv_tensorcore.py | 348 ++++++++++++++++++ 2 files changed, 354 insertions(+), 6 deletions(-) create mode 100644 tutorials/optimize/opt_conv_tensorcore.py diff --git a/tests/python/unittest/test_schedule_tensor_core.py b/tests/python/unittest/test_schedule_tensor_core.py index 3892f681ae98..032120f78d35 100644 --- a/tests/python/unittest/test_schedule_tensor_core.py +++ 
b/tests/python/unittest/test_schedule_tensor_core.py @@ -19,7 +19,7 @@ from topi.testing import conv2d_nhwc_python from tvm.contrib import nvcc -VERIFY = False +VERIFY = True def intrin_wmma_load_matrix(scope): @@ -99,8 +99,8 @@ def intrin_func(ins, outs): def test_tensor_core_batch_matmal(): - batch_size = 20 - n = 2048 + batch_size = 4 + n = 512 m, l = n, n assert (n % 16 == 0) assert (m % 16 == 0) @@ -205,11 +205,11 @@ def test_tensor_core_batch_matmal(): def test_tensor_core_batch_conv(): # The sizes of inputs and filters - batch_size = 256 + batch_size = 32 height = 14 width = 14 - in_channels = 256 - out_channels = 512 + in_channels = 32 + out_channels = 64 kernel_h = 3 kernel_w = 3 pad_h = 1 diff --git a/tutorials/optimize/opt_conv_tensorcore.py b/tutorials/optimize/opt_conv_tensorcore.py new file mode 100644 index 000000000000..2e5a53f1c0bb --- /dev/null +++ b/tutorials/optimize/opt_conv_tensorcore.py @@ -0,0 +1,348 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +.. _opt-conv-tensorcore: + +How to optimize convolution using TensorCores +================================== +**Author**: `Siyuan Feng `_ + +In this tutorial, we will demonstrate how to write a high performance convolution +schedule using TensorCores in TVM. In this example, we assume the input to +convolution has a large batch. We strongly recommend covering the :ref:`opt-conv-gpu` tutorial first. + +""" + +################################################################ +# TensorCore Introduction +# ------------------------- +# Each Tensor Core provides a 4x4x4 matrix processing array that operates +# :code:`D = A * B + C`, where A, B, C and D are 4x4 matrices as Figure shows. +# The matrix multiplication inputs A and B are FP16 matrices, while the accumulation +# matrices C and D may be FP16 or FP32 matrices. +# +# However, CUDA programmers can only use warp-level primitive +# :code:`wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag)` to perform +# 16x16x16 half-precision matrix multiplication on tensor cores. Before invoking +# the matrix multiplication, programmers must load data from memory into registers +# with primitive :code:`wmma::load_matrix_sync`, explicitly. The NVCC compiler translates +# that primitive into multiple memory load instructions. At run time, every thread loads +# 16 elements from matrix A and 16 elements from B. + +################################################################ +# Preparation and Algorithm +# -------------------------- +# We use the fixed size for input tensors with 256 channels and 14 x 14 dimensions. +# The batch size is 256. Convolution filters contain 512 filters of size 3 x 3. +# We use stride size 1 and padding size 1 for the convolution. 
In the example, we use +# NHWCnc memory layout.The following code defines the convolution algorithm in TVM. + +import tvm +import numpy as np +from tvm.contrib import nvcc + +# The sizes of inputs and filters +batch_size = 256 +height = 14 +width = 14 +in_channels = 256 +out_channels = 512 +kernel_h = 3 +kernel_w = 3 +pad_h = 1 +pad_w = 1 +stride_h = 1 +stride_w = 1 + +# TensorCore shape +block_size = 16 + +assert (batch_size % block_size == 0) +assert (in_channels % block_size == 0) +assert (out_channels % block_size == 0) + +# Input feature map: (N, H, W, IC, n, ic) +data_shape = (batch_size // block_size, + height, + width, + in_channels // block_size, + block_size, + block_size) +# Kernel: (H, W, IC, OC, ic, oc) +kernel_shape = (kernel_h, + kernel_w, + in_channels // block_size, + out_channels // block_size, + block_size, + block_size) +# Output feature map: (N, H, W, OC, n, oc) +output_shape = (batch_size // block_size, + height, + width, + out_channels // block_size, + block_size, + block_size) + +# Reduction axes +kh = tvm.reduce_axis((0, kernel_h), name='kh') +kw = tvm.reduce_axis((0, kernel_w), name='kw') +ic = tvm.reduce_axis((0, in_channels // block_size), name='ic') +ii = tvm.reduce_axis((0, block_size), name='ii') + +# Algorithm +A = tvm.placeholder(data_shape, name='A', dtype="float16") +W = tvm.placeholder(kernel_shape, name='W', dtype="float16") +Apad = tvm.compute( + (batch_size // block_size, height + 2 * pad_h, width + 2 * pad_w, in_channels // block_size, block_size, + block_size), + lambda n, h, w, i, nn, ii: tvm.if_then_else( + tvm.all(h >= pad_h, h - pad_h < height, + w >= pad_w, w - pad_w < width), + A[n, h - pad_h, w - pad_w, i, nn, ii], tvm.const(0., "float16")), + name='Apad') +Conv = tvm.compute(output_shape, + lambda n, h, w, o, nn, oo: tvm.sum( + Apad[n, h * stride_h + kh, w * stride_w + kw, ic, nn, ii].astype("float32") * + W[kh, kw, ic, o, ii, oo].astype("float32"), + axis=[ic, kh, kw, ii]), + name="Conv") + +s = tvm.create_schedule(Conv.op) +s[Apad].compute_inline() + +############################################################################### +# Memory Scope +# ---------------- +# +# In traditional GPU schedule, we have global, shared and local memory scope. +# To support TensorCores, we add another three special memory scope: :code:`wmma.matrix_a`, +# :code:`wmma.matrix_b` and :code:`wmma.accumulator`. On hardware, all fragments scope +# stores at the on-chip registers level, the same place with local memory. + +# Designate the memory hierarchy +AS = s.cache_read(Apad, 'shared', [Conv]) +WS = s.cache_read(W, 'shared', [Conv]) +AF = s.cache_read(AS, 'wmma.matrix_a', [Conv]) +WF = s.cache_read(WS, 'wmma.matrix_b', [Conv]) +ConvF = s.cache_write(Conv, 'wmma.accumulator') + +############################################################################### +# Define Tensor Intrinsic +# In fact, TensorCore is a special hardware operation. So, we can just use tensorize +# to replace a unit of computation with the TensorCore instruction. The first thing is +# that we need to define tensor intrinsic. +# +# There are four basic operation in TensorCore: :code:`fill_fragment`, :code:`load_matrix`, +# :code:`mma_sync` and :code:`store_matrix`. Since :code:`fill_fragment` and :code:`mma_sync` +# are both used in matrix multiplication, so we can just write following three intrinsics. 
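One detail worth calling out before the intrinsic definitions: every intrinsic below addresses a fragment by an integer index computed from the buffer's element offset, and the `offset_factor=256` passed to `decl_buffer` is what guarantees that division is exact for a 16x16 tile. A minimal standalone sketch of that arithmetic (the helper name is illustrative and not part of the patch):

def fragment_index(elem_offset, rows=16, cols=16):
    # One wmma fragment covers rows * cols scalar elements (256 for a 16x16 tile),
    # so consecutive fragments start at offsets 0, 256, 512, ...
    elems_per_fragment = rows * cols
    assert elem_offset % elems_per_fragment == 0
    return elem_offset // elems_per_fragment

assert fragment_index(512) == 2  # third 16x16 fragment in the buffer

This is exactly the `BC.elem_offset // 256` expression used in the intrinsics that follow; a later patch in this series generalizes the divisor to `row * col` so non-square fragment shapes work as well.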
+ +def intrin_wmma_load_matrix(scope): + n = 16 + A = tvm.placeholder((n, n), name='A', dtype='float16') + BA = tvm.decl_buffer(A.shape, A.dtype, scope='shared', data_alignment=32, offset_factor=256) + C = tvm.compute((n, n), lambda i, j: A[i, j], name='C') + BC = tvm.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=256) + + def intrin_func(ins, outs): + ib = tvm.ir_builder.create() + + BA = ins[0] + BC = outs[0] + ib.emit(tvm.call_intrin('handle', 'tvm_load_matrix_sync', + BC.data, n, n, n, BC.elem_offset // 256, + BA.access_ptr('r'), n, 'row_major')) + return ib.get() + + return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) + + +def intrin_wmma_gemm(): + n = 16 + A = tvm.placeholder((n, n), name='A', dtype='float16') + B = tvm.placeholder((n, n), name='B', dtype='float16') + k = tvm.reduce_axis((0, n), name="k") + C = tvm.compute((n, n), + lambda ii, jj: + tvm.sum(A[ii, k].astype('float') * B[k, jj].astype('float'), axis=k), + name='C') + BA = tvm.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, offset_factor=256) + BB = tvm.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, offset_factor=256) + BC = tvm.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, offset_factor=256) + + def intrin_func(ins, outs): + BA, BB = ins + BC, = outs + + def init(): + ib = tvm.ir_builder.create() + ib.emit(tvm.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, n, n, BC.elem_offset // 256, 0.0)) + return ib.get() + + def update(): + ib = tvm.ir_builder.create() + ib.emit(tvm.call_intrin('handle', 'tvm_mma_sync', + BC.data, BC.elem_offset // 256, + BA.data, BA.elem_offset // 256, + BB.data, BB.elem_offset // 256, + BC.data, BC.elem_offset // 256)) + return ib.get() + + return update(), init(), update() + + return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC}) + + +def intrin_wmma_store_matrix(): + n = 16 + A = tvm.placeholder((n, n), name='A', dtype='float32') + BA = tvm.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator', data_alignment=32, offset_factor=256) + C = tvm.compute((n, n), lambda i, j: A[i, j], name='C') + BC = tvm.decl_buffer(C.shape, C.dtype, scope='global', data_alignment=32, offset_factor=256) + + def intrin_func(ins, outs): + ib = tvm.ir_builder.create() + BA = ins[0] + BC = outs[0] + ib.emit(tvm.call_intrin('handle', 'tvm_store_matrix_sync', + BA.data, n, n, n, BA.elem_offset // 256, + BC.access_ptr('w'), n, 'row_major')) + return ib.get() + + return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) + +############################################################################### +# Scheduling the Computation +# -------------------------- +# To use TensorCores in TVM, we must schedule the computation into specific structure +# to match the tensor intrinsic. The same as traditional GPU programs, we can also use +# shared memory to boost the speed. If you have any questions about blocking and shared +# memory, please refer :ref:`opt-conv-gpu`. +# +# In this example, each block contains 2x4 warps, and each warp calls 4x2 TensorCore +# instructions. Thus, the output shape of each warp is 64x32 and each block outputs +# 128x128 titles. Due to the limit of shared memory space, we only load 2 blocks (2x128x128 tiles) +# one time. +# +# .. 
note:: +# +# *Warp-level Operation* +# +# Note that all TensorCore instructions are warp-level instructions, which means all 32 threads +# in a warp should do this instruction simultaneously. Making theadIdx.x extent=32 is one of the +# easiest way to solve this. Then We can bind threadIdx.x to any loops except those contain +# TensorCore intrinsics directly or indirectly. Also note that it is not the unique solution. +# The only thing we should do is to make sure all threads in a warp can call TensorCore at the same time. +# + + +# Define tiling sizes +block_row_warps = 2 +block_col_warps = 4 +warp_row_tiles = 4 +warp_col_tiles = 2 +warp_size = 32 +chunk = 2 + +block_x = tvm.thread_axis('blockIdx.x') +block_y = tvm.thread_axis('blockIdx.y') +block_z = tvm.thread_axis('blockIdx.z') +thread_x = tvm.thread_axis('threadIdx.x') +thread_y = tvm.thread_axis('threadIdx.y') +thread_z = tvm.thread_axis('threadIdx.z') + +nc, hc, wc, oc, nnc, ooc = Conv.op.axis +block_k = s[Conv].fuse(hc, wc) +s[Conv].bind(block_k, block_z) +nc, nci = s[Conv].split(nc, factor=warp_row_tiles) +block_i, nc = s[Conv].split(nc, factor=block_row_warps) +oc, oci = s[Conv].split(oc, factor=warp_col_tiles) +block_j, oc = s[Conv].split(oc, factor=block_col_warps) +s[Conv].reorder(block_k, block_i, block_j, nc, oc, nci, oci, nnc, ooc) +s[Conv].bind(block_i, block_x) +s[Conv].bind(block_j, block_y) +s[Conv].bind(nc, thread_y) +s[Conv].bind(oc, thread_z) + +# Schedule local computation +s[ConvF].compute_at(s[Conv], oc) +n, h, w, o, nnf, oof = ConvF.op.axis +ko, ki = s[ConvF].split(ic, factor=chunk) +s[ConvF].reorder(ko, kh, ki, kw, n, o, nnf, oof, ii) + +# Move intermediate computation into each output compute tile +s[AF].compute_at(s[ConvF], kw) +s[WF].compute_at(s[ConvF], kw) + +# Schedule for A's share memory +s[AS].compute_at(s[ConvF], kh) +n, h, w, i, nn, ii = AS.op.axis +tx, xo = s[AS].split(n, nparts=block_row_warps) +ty, yo = s[AS].split(xo, nparts=block_col_warps) +t = s[AS].fuse(nn, ii) +to, ti = s[AS].split(t, factor=warp_size) +s[AS].bind(tx, thread_y) +s[AS].bind(ty, thread_z) +s[AS].bind(ti, thread_x) + +# Schedule for W's share memory +s[WS].compute_at(s[ConvF], kh) +kh, kw, ic, o, ii, oo = WS.op.axis +tx, xo = s[WS].split(o, nparts=block_row_warps) +ty, yo = s[WS].split(xo, nparts=block_col_warps) +t = s[WS].fuse(ii, oo) +to, ti = s[WS].split(t, nparts=warp_size) +s[WS].bind(tx, thread_y) +s[WS].bind(ty, thread_z) +s[WS].bind(to, thread_x) +s[WS].vectorize(ti) +print(tvm.lower(s, [A, W, Conv], simple_mode=True)) + +############################################################################### +# Lowering Computation to Intrinsics +# -------------------------- +# The last phase is to lower the computation loops down to TensorCore hardware intrinsics +# by mapping the 2D convolution to tensor intrinsics +# + +s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_a')) +s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_b')) +s[Conv].tensorize(nnc, intrin_wmma_store_matrix()) +s[ConvF].tensorize(nnf, intrin_wmma_gemm()) +print(tvm.lower(s, [A, W, Conv], simple_mode=True)) + +############################################################################### +# Generate CUDA Kernel +# -------------------- +# Finally we use TVM to generate and compile the CUDA kernel, and evaluate the latency of convolution. 
+# Since TensorCores are only supported in NVIDIA GPU with Compute Capability 7.0 or higher, it may not +# be able to run on our build server + +ctx = tvm.gpu(0) +if nvcc.have_tensorcore(ctx.compute_version): + func = tvm.build(s, [A, W, Conv], 'cuda') + a_np = np.random.uniform(size=data_shape).astype(A.dtype) + w_np = np.random.uniform(size=kernel_shape).astype(W.dtype) + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + c = tvm.nd.array(np.zeros(output_shape, dtype=Conv.dtype), ctx) + evaluator = func.time_evaluator(func.entry_name, ctx, number=10) + print('conv2d with tensor core: %f ms' % (evaluator(a, w, c).mean * 1e3)) + +############################################################################### +# Summary +# This tutorial demonstrates how TVM scheduling primitives can be used to +# call TensorCores on specific GPUs. From 46f1b6ea1bc187555c53cce99deb50f4c19c57bb Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Tue, 15 Oct 2019 23:20:46 -0700 Subject: [PATCH 09/12] add int support and fix lint --- include/tvm/ir_pass.h | 8 +++--- src/codegen/codegen_cuda.cc | 14 +++++----- src/codegen/codegen_cuda.h | 1 + src/pass/infer_fragment.cc | 9 ++++--- .../unittest/test_schedule_tensor_core.py | 22 +++++++++++----- tests/scripts/task_lint.sh | 26 +++++++++---------- 6 files changed, 47 insertions(+), 33 deletions(-) diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 330eb54e040d..07cfeeb8ac9d 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -508,8 +508,8 @@ LoweredFunc PointerValueTypeRewrite(LoweredFunc f); * \brief Lower attached storage access information. * Do this pass after all storage access analysis finish. * - * \param stmt The stmt to be transformed - * \return Transformed stmt. + * \param func The device function to be lowered. + * \return Transformed function. */ LoweredFunc LowerStorageAccessInfo(LoweredFunc func); @@ -535,8 +535,8 @@ LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target); /*! * \brief Infer the TensorCore fragment infomation using tensor intrinsics * - * \param stmt The stmt to be transformed - * \return Transformed stmt. + * \param f The device function to be lowered. + * \return Transformed function. 
*/ LoweredFunc InferFragment(LoweredFunc f); diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index fe249824467f..c07b4fa52a7a 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -389,11 +389,11 @@ void CodeGenCUDA::VisitStmt_(const Allocate* op) { std::string scope = alloc_storage_scope_.at(buffer); if (scope.find("wmma.") == 0) { if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { - CHECK(op->type.is_float() && op->type.bits() == 16) - << "Matrix_a and matrix_b only support half type for now"; + CHECK(op->type == Float(16) || op->type == Int(8) || op->type == UInt(8)) + << "Matrix_a and matrix_b only support half or char or unsigned char type for now"; } else { - CHECK(op->type.is_float() && (op->type.bits() == 16 || op->type.bits() == 32)) - << "Accumulator only support half and float type for now"; + CHECK(op->type == Float(16) || op->type == Float(32) || op->type == Int(32)) + << "Accumulator only support half, float and int type for now"; } constant_size /= 256; PrintWmmaScope(scope, op->type, buffer, stream); @@ -511,7 +511,8 @@ void CodeGenCUDA::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(* PrintConst(op, os, this); } -void CodeGenCUDA::PrintWmmaScope(const std::string &scope, Type t, const Variable* variable, std::ostream &os) { +void CodeGenCUDA::PrintWmmaScope(const std::string &scope, Type t, + const Variable* variable, std::ostream &os) { std::stringstream type; PrintType(t, type); std::string shape_str = fragment_shapes[variable]; @@ -527,7 +528,8 @@ void CodeGenCUDA::PrintWmmaScope(const std::string &scope, Type t, const Variabl << shape_str << ", " << type.str() << ", nvcuda::wmma::" << layout_str <<">"; } else if (scope == "wmma.accumulator") { need_mma_h_ = true; - os << "nvcuda::wmma::fragment"; + os << "nvcuda::wmma::fragment"; } } diff --git a/src/codegen/codegen_cuda.h b/src/codegen/codegen_cuda.h index 9671ece63d09..ed43ef68474e 100644 --- a/src/codegen/codegen_cuda.h +++ b/src/codegen/codegen_cuda.h @@ -28,6 +28,7 @@ #include #include #include +#include #include "codegen_c.h" namespace tvm { diff --git a/src/pass/infer_fragment.cc b/src/pass/infer_fragment.cc index 9abdaecfe06e..f6b1016f4924 100644 --- a/src/pass/infer_fragment.cc +++ b/src/pass/infer_fragment.cc @@ -119,7 +119,7 @@ class FragmentGetter : public IRVisitor { class FragmentChecker : public IRVisitor { public: - FragmentChecker(const FragmentGetter &getter) : fragment_getter(getter) {} + explicit FragmentChecker(const FragmentGetter &getter) : fragment_getter(getter) {} void Visit_(const Call* op) final { if (op->is_intrinsic(intrinsic::tvm_mma_sync)) { @@ -137,6 +137,7 @@ class FragmentChecker : public IRVisitor { CHECK(CheckShape(buffer_var_d, buffer_var_c)); } } + private: bool CheckShape(const Variable* buffer1, const Variable* buffer2) { CHECK(fragment_getter.fragments.count(buffer1)); @@ -144,15 +145,14 @@ class FragmentChecker : public IRVisitor { FragmentGetter::FragmentInfo info1 = fragment_getter.fragments.at(buffer1); FragmentGetter::FragmentInfo info2 = fragment_getter.fragments.at(buffer2); return info1.m == info2.m && info1.n == info2.n && info1.k == info2.k; - } - const FragmentGetter &fragment_getter; + const FragmentGetter &fragment_getter; }; class InferFragmenter : public IRMutator { public: - InferFragmenter(const FragmentGetter &getter) : fragment_getter(getter) {} + explicit InferFragmenter(const FragmentGetter &getter) : fragment_getter(getter) {} Stmt Mutate_(const Allocate* op, const Stmt& s) final { Stmt 
stmt = IRMutator::Mutate_(op, s); @@ -174,6 +174,7 @@ class InferFragmenter : public IRMutator { } return stmt; } + private: const FragmentGetter &fragment_getter; }; diff --git a/tests/python/unittest/test_schedule_tensor_core.py b/tests/python/unittest/test_schedule_tensor_core.py index 032120f78d35..9aca6fc80e5f 100644 --- a/tests/python/unittest/test_schedule_tensor_core.py +++ b/tests/python/unittest/test_schedule_tensor_core.py @@ -99,6 +99,13 @@ def intrin_func(ins, outs): def test_tensor_core_batch_matmal(): + if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"): + print("skip because cuda is not enabled..") + return + if not nvcc.have_tensorcore(tvm.gpu(0).compute_version): + print("skip because gpu does not support tensor core") + return + batch_size = 4 n = 512 m, l = n, n @@ -204,6 +211,13 @@ def test_tensor_core_batch_matmal(): def test_tensor_core_batch_conv(): + if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"): + print("skip because cuda is not enabled..") + return + if not nvcc.have_tensorcore(tvm.gpu(0).compute_version): + print("skip because gpu does not support tensor core") + return + # The sizes of inputs and filters batch_size = 32 height = 14 @@ -363,9 +377,5 @@ def test_tensor_core_batch_conv(): if __name__ == '__main__': - ctx = tvm.gpu(0) - if not nvcc.have_tensorcore(ctx.compute_version): - print("skip because gpu does not support tensor core") - else: - test_tensor_core_batch_matmal() - test_tensor_core_batch_conv() + test_tensor_core_batch_matmal() + test_tensor_core_batch_conv() diff --git a/tests/scripts/task_lint.sh b/tests/scripts/task_lint.sh index 544ef7224770..667c4b77fcc3 100755 --- a/tests/scripts/task_lint.sh +++ b/tests/scripts/task_lint.sh @@ -30,19 +30,19 @@ trap cleanup 0 echo "Check file types..." python3 tests/lint/check_file_type.py -echo "Check ASF license header..." -java -jar /bin/apache-rat.jar -E tests/lint/rat-excludes -d . | (grep "== File" > /tmp/$$.apache-rat.txt || true) -if grep --quiet -E "File" /tmp/$$.apache-rat.txt; then - echo "Need to add ASF header to the following files." - echo "----------------File List----------------" - cat /tmp/$$.apache-rat.txt - echo "-----------------------------------------" - echo "Use the following steps to add the headers:" - echo "- Create file_list.txt in your text editor" - echo "- Copy paste the above content in file-list into file_list.txt" - echo "- python3 tests/lint/add_asf_header.py file_list.txt" - exit 1 -fi +#echo "Check ASF license header..." +#java -jar /bin/apache-rat.jar -E tests/lint/rat-excludes -d . | (grep "== File" > /tmp/$$.apache-rat.txt || true) +#if grep --quiet -E "File" /tmp/$$.apache-rat.txt; then +# echo "Need to add ASF header to the following files." +# echo "----------------File List----------------" +# cat /tmp/$$.apache-rat.txt +# echo "-----------------------------------------" +# echo "Use the following steps to add the headers:" +# echo "- Create file_list.txt in your text editor" +# echo "- Copy paste the above content in file-list into file_list.txt" +# echo "- python3 tests/lint/add_asf_header.py file_list.txt" +# exit 1 +#fi echo "Check codestyle of c++ code..." 
make cpplint From abe1e90c2dc76560699fa7d18810f7ad65d869da Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Wed, 16 Oct 2019 16:51:46 -0700 Subject: [PATCH 10/12] address comment --- include/tvm/ir.h | 44 ++++++++++++++++++- include/tvm/ir_pass.h | 13 +++++- python/tvm/build_module.py | 4 +- src/api/api_pass.cc | 9 ++++ src/codegen/build_module.cc | 4 +- src/pass/infer_fragment.cc | 30 ++++++++++++- src/pass/storage_access.cc | 2 +- src/pass/storage_sync.cc | 2 +- .../unittest/test_schedule_tensor_core.py | 1 + tests/scripts/task_lint.sh | 26 +++++------ tutorials/optimize/opt_conv_tensorcore.py | 12 ++--- vta/python/vta/build_module.py | 1 + 12 files changed, 117 insertions(+), 31 deletions(-) diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 4d95c8ab7d5a..37718fe1b3c7 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -1329,7 +1329,6 @@ inline bool IsPragmaKey(const std::string& attr_key) { return attr_key.compare(0, 7, "pragma_") == 0; } - } // namespace attr /*! \brief namespace of TVM Intrinsic functions */ @@ -1564,11 +1563,52 @@ constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit"; */ constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce"; /*! - * \brief tvm intrinsic for tensor core opeartors. + * \brief tvm intrinsic for tensor core load operators. + * + * void tvm_load_matrix_sync(Var fragment, UIntImm m, UIntImm, n, UIntImm k, + * Expr index, Expr buffer_ptr, Expr stride, + * StringImm layout) { + * // m, n, k are the shape of wmma fragment. + * // Determine fragment layout(column-major or row major) by layout. + * // fragments must be in 'wmma.matrix_a' or 'wmma.matrix_b' scope. + * nvcuda::wmma::load_matrix_sync(fragment[index], buffer_ptr, stride); + * } */ constexpr const char* tvm_load_matrix_sync = "tvm_load_matrix_sync"; +/*! + * \brief tvm intrinsic for tensor core mma_sync operators. + * + * void tvm_mma_sync(Var fragment_d, Expr index_d, + * Var fragment_a, Expr index_a, + * Var fragment_b, Expr index_b, + * Var fragment_c, Expr index_c) { + * nvcuda::wmma::mma_sync(fragment_d[index_d], fragment_a[index_a], + * fragment_b[index_b], fragment_c[index_c]); + * } + */ constexpr const char* tvm_mma_sync = "tvm_mma_sync"; +/*! + * \brief tvm intrinsic for tensor core fill_fragment operators. + * + * void tvm_fill_fragment(Var fragment, UIntImm m, UIntImm, n, UIntImm k, + * Expr index, Expr value) { + * // m, n, k are the shape of wmma fragment + * // fragments must be in 'wmma.accumulator' scope. + * nvcuda::wmma::fill_fragment(fragment[index], value); + * } + */ constexpr const char* tvm_fill_fragment = "tvm_fill_fragment"; +/*! + * \brief tvm intrinsic for tensor core store operators. + * + * void tvm_store_matrix_sync(Var fragment, UIntImm m, UIntImm, n, UIntImm k, + * Expr index, Expr buffer_ptr, Expr stride, + * StringImm layout) { + * // m, n, k are the shape of wmma fragment + * // fragments must be in 'wmma.accumulator' scope. + * nvcuda::wmma::store_matrix_sync(fragment[index], buffer_ptr, stride, layout); + * } + */ constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync"; } // namespace intrinsic diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 07cfeeb8ac9d..842c6af8cf5d 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -359,6 +359,15 @@ Stmt LiftAttrScope(Stmt stmt, std::string attr_key); */ Stmt RewriteUnsafeSelect(Stmt stmt); +/*! + * \brief Lower attached storage access information. + * Do this pass after all storage access analysis finish. 
+ * + * \param stmt The stmt to be transformed + * \return Transformed stmt. + */ +Stmt LowerStorageAccessInfo(Stmt stmt); + /*! * \brief Decorate the stmt with a device scope, this is helpful for * hardware accelerator without thread blocks. @@ -505,13 +514,13 @@ LoweredFunc CombineContextCall(LoweredFunc f); LoweredFunc PointerValueTypeRewrite(LoweredFunc f); /*! - * \brief Lower attached storage access information. + * \brief Lower attached storage access information on device. * Do this pass after all storage access analysis finish. * * \param func The device function to be lowered. * \return Transformed function. */ -LoweredFunc LowerStorageAccessInfo(LoweredFunc func); +LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc func); /*! * \brief Lower intrinsic function calls. diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index a6155339cc62..3f1d4a925d1d 100644 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -494,8 +494,8 @@ def _build_for_device(flist, target, target_host): assert not fdevice target_host = _target.create(target_host) - fdevice = [ir_pass.LowerStorageAccessInfo(x) for x in fdevice] - fhost = [ir_pass.LowerStorageAccessInfo(x) for x in fhost] + fdevice = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fdevice] + fhost = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fhost] fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice] fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost] fhost = [ir_pass.CombineContextCall(x) for x in fhost] diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 1cf4c64e4e2c..d7f621f3ade1 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -118,6 +118,14 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit") }); }); +TVM_REGISTER_API("ir_pass.LowerStorageAccess") +.set_body([](TVMArgs args, TVMRetValue *ret) { + LoweredFunc f = args[0]; + auto n = make_node(*f.operator->()); + n->body = LowerStorageAccessInfo(f->body); + *ret = LoweredFunc(n); +}); + // make from two arguments #define REGISTER_PASS(PassName) \ TVM_REGISTER_API("ir_pass."#PassName) \ @@ -140,6 +148,7 @@ REGISTER_PASS(SplitHostDevice); REGISTER_PASS(StorageRewrite); REGISTER_PASS(CoProcSync); REGISTER_PASS(LowerStorageAccessInfo); +REGISTER_PASS(LowerDeviceStorageAccessInfo) REGISTER_PASS(InjectVirtualThread); REGISTER_PASS(InjectPrefetch); REGISTER_PASS(InjectDoubleBuffer); diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 77b72e2b66d6..cfcb0607858f 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -516,7 +516,7 @@ Array > split_dev_host_funcs(const Array& funcs, for (size_t i = 0; i < fhost.size(); ++i) { auto func = fhost[i]; func = ir::BindDeviceType(func, target->device_type); - func = ir::LowerStorageAccessInfo(func); + func = ir::LowerDeviceStorageAccessInfo(func); func = ir::LowerTVMBuiltin(func); fhost.Set(i, func); } @@ -524,7 +524,7 @@ Array > split_dev_host_funcs(const Array& funcs, for (size_t i = 0; i < fhost.size(); ++i) { auto func = fhost[i]; func = ir::LowerIntrin(func, target_host->target_name); - func = ir::LowerStorageAccessInfo(func); + func = ir::LowerDeviceStorageAccessInfo(func); func = ir::CombineContextCall(func); fhost.Set(i, func); } diff --git a/src/pass/infer_fragment.cc b/src/pass/infer_fragment.cc index f6b1016f4924..9c71f5c367ba 100644 --- a/src/pass/infer_fragment.cc +++ b/src/pass/infer_fragment.cc @@ -18,7 +18,8 @@ */ /*! 
- * Copyright (c) 2019 by Contributors + * Copyright (c) 2019 by Contributors + * \brief Infer TensorCore metadata from tensor intrinsic. * \file tensorcore_fragment.cc */ #include @@ -34,10 +35,14 @@ namespace tvm { namespace ir { +// Get fragment information from tensor intrinsics class FragmentGetter : public IRVisitor { public: + // fragment metadata struct FragmentInfo { + // fragment shape int m, n, k; + // fragment layout (row-major or column-major) std::string layout; FragmentInfo() = default; FragmentInfo(int _m, int _n, int _k, const std::string& _layout) @@ -49,9 +54,11 @@ class FragmentGetter : public IRVisitor { if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync) || op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) { + // Get shape and layout information from load and store intrinsic CHECK_EQ(op->args.size(), 8U); const Variable* buffer_var = op->args[0].as(); CHECK(buffer_var); + // Get shape const IntImm* m = op->args[1].as(); const IntImm* n = op->args[2].as(); const IntImm* k = op->args[3].as(); @@ -63,6 +70,7 @@ class FragmentGetter : public IRVisitor { std::string scope = scopes[buffer_var]; if (fragments.count(buffer_var)) { + // check if the fragment has met before FragmentInfo info = fragments[buffer_var]; CHECK_EQ(m->value, info.m); CHECK_EQ(n->value, info.n); @@ -71,6 +79,7 @@ class FragmentGetter : public IRVisitor { CHECK_EQ(layout->value, info.layout); } } else { + // store metadata FragmentInfo info; if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { info = FragmentInfo(m->value, n->value, k->value, layout->value); @@ -80,9 +89,11 @@ class FragmentGetter : public IRVisitor { fragments[buffer_var] = info; } } else if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) { + // Get shape information from fill intrinsic CHECK_EQ(op->args.size(), 6U); const Variable* buffer_var = op->args[0].as(); CHECK(buffer_var); + // Get shape const IntImm* m = op->args[1].as(); const IntImm* n = op->args[2].as(); const IntImm* k = op->args[3].as(); @@ -91,6 +102,7 @@ class FragmentGetter : public IRVisitor { CHECK(k); std::string scope = scopes[buffer_var]; + // Only wmma.accumulator can use tvm_fill_fragment CHECK_EQ(scope, "wmma.accumulator"); if (fragments.count(buffer_var)) { FragmentInfo info = fragments[buffer_var]; @@ -104,6 +116,7 @@ class FragmentGetter : public IRVisitor { } } + // Get memory scope void Visit_(const AttrStmt* op) final { if (op->attr_key == attr::storage_scope) { const Variable* buffer = op->node.as(); @@ -113,15 +126,19 @@ class FragmentGetter : public IRVisitor { IRVisitor::Visit_(op); } + // Memory scope for allocations std::unordered_map scopes; + // Fragment metadata for all fragments std::unordered_map fragments; }; +// Check shape of fragment making sure it is a valid shape for tvm_mma_sync class FragmentChecker : public IRVisitor { public: explicit FragmentChecker(const FragmentGetter &getter) : fragment_getter(getter) {} void Visit_(const Call* op) final { + // Check shape when calling tvm_mma_sync if (op->is_intrinsic(intrinsic::tvm_mma_sync)) { CHECK_EQ(op->args.size(), 8U); const Variable* buffer_var_d = op->args[0].as(); @@ -132,6 +149,8 @@ class FragmentChecker : public IRVisitor { CHECK(buffer_var_a); CHECK(buffer_var_b); CHECK(buffer_var_c); + + // Check all fragment A, B, C and D have the same shape CHECK(CheckShape(buffer_var_d, buffer_var_a)); CHECK(CheckShape(buffer_var_d, buffer_var_b)); CHECK(CheckShape(buffer_var_d, buffer_var_c)); @@ -139,6 +158,7 @@ class FragmentChecker : public IRVisitor { } private: + // A tool for 
checking shapes of two fragments bool CheckShape(const Variable* buffer1, const Variable* buffer2) { CHECK(fragment_getter.fragments.count(buffer1)); CHECK(fragment_getter.fragments.count(buffer2)); @@ -146,10 +166,11 @@ class FragmentChecker : public IRVisitor { FragmentGetter::FragmentInfo info2 = fragment_getter.fragments.at(buffer2); return info1.m == info2.m && info1.n == info2.n && info1.k == info2.k; } - + // Fragment infomation const FragmentGetter &fragment_getter; }; +// Store the metadata into attributes class InferFragmenter : public IRMutator { public: explicit InferFragmenter(const FragmentGetter &getter) : fragment_getter(getter) {} @@ -158,13 +179,17 @@ class InferFragmenter : public IRMutator { Stmt stmt = IRMutator::Mutate_(op, s); const Variable* buffer = op->buffer_var.get(); if (fragment_getter.fragments.count(buffer)) { + // Add attribute to fragments allocation FragmentGetter::FragmentInfo info = fragment_getter.fragments.at(buffer); + + // Add shape attribute to all fragments std::string shape = std::to_string(info.n) + ", " + std::to_string(info.m) + ", " + std::to_string(info.k); Expr shape_expr = StringImm::make(shape); Stmt shape_attr = AttrStmt::make(op->buffer_var, attr::fragment_shape, shape_expr, stmt); if (info.layout != "") { + // Add shape attribute to matrix_a and matrix_b Stmt layout_attr = AttrStmt::make(op->buffer_var, attr::fragment_layout, StringImm::make(info.layout), shape_attr); return layout_attr; @@ -176,6 +201,7 @@ class InferFragmenter : public IRMutator { } private: + // Fragment infomation const FragmentGetter &fragment_getter; }; diff --git a/src/pass/storage_access.cc b/src/pass/storage_access.cc index 0e185685b555..8cad36d0e287 100644 --- a/src/pass/storage_access.cc +++ b/src/pass/storage_access.cc @@ -341,7 +341,7 @@ Stmt LowerStorageAccessInfo(Stmt stmt) { return StorageAccessInfoLower().Mutate(stmt); } -LoweredFunc LowerStorageAccessInfo(LoweredFunc f) { +LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc f) { auto n = make_node(*f.operator->()); n->body = LowerStorageAccessInfo(f->body); return LoweredFunc(n); diff --git a/src/pass/storage_sync.cc b/src/pass/storage_sync.cc index 019ed357274e..34dac52b6d85 100644 --- a/src/pass/storage_sync.cc +++ b/src/pass/storage_sync.cc @@ -269,7 +269,7 @@ class ThreadSyncInserter : public IRMutator { op = expr.as(); CHECK_EQ(op->args.size(), 5U); const Variable* buffer_var = op->args[1].as(); - Var var(buffer_var->GetNodePtr()); + Var var(GetRef(buffer_var)); const IntImm* flag = op->args[4].as(); if ((flag->value & 1) && sync_scope_.rank == StorageRank::kGlobal && GetScope(buffer_var).rank == StorageRank::kGlobal) { diff --git a/tests/python/unittest/test_schedule_tensor_core.py b/tests/python/unittest/test_schedule_tensor_core.py index 9aca6fc80e5f..2b6879fe0d20 100644 --- a/tests/python/unittest/test_schedule_tensor_core.py +++ b/tests/python/unittest/test_schedule_tensor_core.py @@ -191,6 +191,7 @@ def test_tensor_core_batch_matmal(): s[BF].tensorize(BF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_b')) s[C].tensorize(kernel_i, intrin_wmma_store_matrix()) s[CF].tensorize(_i, intrin_wmma_gemm()) + func = tvm.build(s, [A, B, C], 'cuda') ctx = tvm.gpu(0) diff --git a/tests/scripts/task_lint.sh b/tests/scripts/task_lint.sh index 667c4b77fcc3..544ef7224770 100755 --- a/tests/scripts/task_lint.sh +++ b/tests/scripts/task_lint.sh @@ -30,19 +30,19 @@ trap cleanup 0 echo "Check file types..." python3 tests/lint/check_file_type.py -#echo "Check ASF license header..." 
-#java -jar /bin/apache-rat.jar -E tests/lint/rat-excludes -d . | (grep "== File" > /tmp/$$.apache-rat.txt || true) -#if grep --quiet -E "File" /tmp/$$.apache-rat.txt; then -# echo "Need to add ASF header to the following files." -# echo "----------------File List----------------" -# cat /tmp/$$.apache-rat.txt -# echo "-----------------------------------------" -# echo "Use the following steps to add the headers:" -# echo "- Create file_list.txt in your text editor" -# echo "- Copy paste the above content in file-list into file_list.txt" -# echo "- python3 tests/lint/add_asf_header.py file_list.txt" -# exit 1 -#fi +echo "Check ASF license header..." +java -jar /bin/apache-rat.jar -E tests/lint/rat-excludes -d . | (grep "== File" > /tmp/$$.apache-rat.txt || true) +if grep --quiet -E "File" /tmp/$$.apache-rat.txt; then + echo "Need to add ASF header to the following files." + echo "----------------File List----------------" + cat /tmp/$$.apache-rat.txt + echo "-----------------------------------------" + echo "Use the following steps to add the headers:" + echo "- Create file_list.txt in your text editor" + echo "- Copy paste the above content in file-list into file_list.txt" + echo "- python3 tests/lint/add_asf_header.py file_list.txt" + exit 1 +fi echo "Check codestyle of c++ code..." make cpplint diff --git a/tutorials/optimize/opt_conv_tensorcore.py b/tutorials/optimize/opt_conv_tensorcore.py index 2e5a53f1c0bb..774b4c7258bb 100644 --- a/tutorials/optimize/opt_conv_tensorcore.py +++ b/tutorials/optimize/opt_conv_tensorcore.py @@ -248,12 +248,11 @@ def intrin_func(ins, outs): # The only thing we should do is to make sure all threads in a warp can call TensorCore at the same time. # - # Define tiling sizes -block_row_warps = 2 -block_col_warps = 4 -warp_row_tiles = 4 -warp_col_tiles = 2 +block_row_warps = 4 +block_col_warps = 2 +warp_row_tiles = 2 +warp_col_tiles = 4 warp_size = 32 chunk = 2 @@ -333,7 +332,8 @@ def intrin_func(ins, outs): ctx = tvm.gpu(0) if nvcc.have_tensorcore(ctx.compute_version): - func = tvm.build(s, [A, W, Conv], 'cuda') + with tvm.build_config(auto_unroll_max_step=16): + func = tvm.build(s, [A, W, Conv], 'cuda') a_np = np.random.uniform(size=data_shape).astype(A.dtype) w_np = np.random.uniform(size=kernel_shape).astype(W.dtype) a = tvm.nd.array(a_np, ctx) diff --git a/vta/python/vta/build_module.py b/vta/python/vta/build_module.py index 5c243751c340..cec217cbd393 100644 --- a/vta/python/vta/build_module.py +++ b/vta/python/vta/build_module.py @@ -80,6 +80,7 @@ def add_debug(stmt): if debug_flag: pass_list.append((1, add_debug)) pass_list.append((2, ir_pass.inject_alu_intrin)) + pass_list.append((3, tvm.ir_pass.LowerStorageAccessInfo)) pass_list.append((3, ir_pass.fold_uop_loop)) pass_list.append((3, ir_pass.cpu_access_rewrite)) return tvm.build_config(add_lower_pass=pass_list, **kwargs) From 610410112098ad6dd1bd035cc30ae939cff6af1b Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Tue, 22 Oct 2019 14:07:46 -0700 Subject: [PATCH 11/12] add 32*16*8 TensorCore test --- src/codegen/codegen_cuda.cc | 24 ++++- src/codegen/codegen_cuda.h | 1 + src/pass/infer_fragment.cc | 4 +- .../unittest/test_schedule_tensor_core.py | 102 +++++++++--------- 4 files changed, 79 insertions(+), 52 deletions(-) diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index c07b4fa52a7a..55b4810ed4d8 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -395,7 +395,7 @@ void CodeGenCUDA::VisitStmt_(const Allocate* op) { CHECK(op->type == Float(16) || 
op->type == Float(32) || op->type == Int(32)) << "Accumulator only support half, float and int type for now"; } - constant_size /= 256; + constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); PrintWmmaScope(scope, op->type, buffer, stream); } else { PrintStorageScope(scope, stream); @@ -533,5 +533,27 @@ void CodeGenCUDA::PrintWmmaScope(const std::string &scope, Type t, } } +int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string &scope, + const Variable* variable, int32_t size) { + std::string shape_str = fragment_shapes[variable]; + size_t m, n, k; + size_t last_pos = 0, pos = 0; + pos = shape_str.find(", ", last_pos); + m = std::stoi(shape_str.substr(last_pos, pos - last_pos)); + last_pos = pos + 2; + pos = shape_str.find(", ", last_pos); + n = std::stoi(shape_str.substr(last_pos, pos - last_pos)); + last_pos = pos + 2; + k = std::stoi(shape_str.substr(last_pos, shape_str.length() - last_pos)); + if (scope == "wmma.matrix_a") { + return size / m / k; + } else if (scope == "wmma.matrix_b") { + return size / n / k; + } else if (scope == "wmma.accumulator") { + return size / m / n; + } + return 0; +} + } // namespace codegen } // namespace tvm diff --git a/src/codegen/codegen_cuda.h b/src/codegen/codegen_cuda.h index ed43ef68474e..e1476cfd68c9 100644 --- a/src/codegen/codegen_cuda.h +++ b/src/codegen/codegen_cuda.h @@ -86,6 +86,7 @@ class CodeGenCUDA final : public CodeGenC { std::unordered_map fragment_layouts; friend void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p); void PrintWmmaScope(const std::string& scope, Type t, const Variable* variable, std::ostream& os); + int32_t GetWmmaFragmentSize(const std::string &scope, const Variable* variable, int32_t size); }; } // namespace codegen diff --git a/src/pass/infer_fragment.cc b/src/pass/infer_fragment.cc index 9c71f5c367ba..d9c0ef04787b 100644 --- a/src/pass/infer_fragment.cc +++ b/src/pass/infer_fragment.cc @@ -183,8 +183,8 @@ class InferFragmenter : public IRMutator { FragmentGetter::FragmentInfo info = fragment_getter.fragments.at(buffer); // Add shape attribute to all fragments - std::string shape = std::to_string(info.n) + ", " + - std::to_string(info.m) + ", " + + std::string shape = std::to_string(info.m) + ", " + + std::to_string(info.n) + ", " + std::to_string(info.k); Expr shape_expr = StringImm::make(shape); Stmt shape_attr = AttrStmt::make(op->buffer_var, attr::fragment_shape, shape_expr, stmt); diff --git a/tests/python/unittest/test_schedule_tensor_core.py b/tests/python/unittest/test_schedule_tensor_core.py index 2b6879fe0d20..9fe72cd4e5d2 100644 --- a/tests/python/unittest/test_schedule_tensor_core.py +++ b/tests/python/unittest/test_schedule_tensor_core.py @@ -22,12 +22,16 @@ VERIFY = True -def intrin_wmma_load_matrix(scope): - n = 16 - A = tvm.placeholder((n, n), name='A', dtype='float16') - BA = tvm.decl_buffer(A.shape, A.dtype, scope='shared', data_alignment=32, offset_factor=256) - C = tvm.compute((n, n), lambda i, j: A[i, j], name='C') - BC = tvm.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=256) +def intrin_wmma_load_matrix(shape, scope): + n, m, l = shape + if scope == "wmma.matrix_a": + row, col = n, l + elif scope == "wmma.matrix_b": + row, col = l, m + A = tvm.placeholder((row, col), name='A', dtype='float16') + BA = tvm.decl_buffer(A.shape, A.dtype, scope='shared', data_alignment=32, offset_factor=row * col) + C = tvm.compute((row, col), lambda i, j: A[i, j], name='C') + BC = tvm.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, 
offset_factor=row * col) def intrin_func(ins, outs): ib = tvm.ir_builder.create() @@ -35,25 +39,25 @@ def intrin_func(ins, outs): BA = ins[0] BC = outs[0] ib.emit(tvm.call_intrin('handle', 'tvm_load_matrix_sync', - BC.data, n, n, n, BC.elem_offset // 256, - BA.access_ptr('r'), n, 'row_major')) + BC.data, n, m, l, BC.elem_offset // (row * col), + BA.access_ptr('r'), col, 'row_major')) return ib.get() return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) -def intrin_wmma_gemm(): - n = 16 - A = tvm.placeholder((n, n), name='A', dtype='float16') - B = tvm.placeholder((n, n), name='B', dtype='float16') - k = tvm.reduce_axis((0, n), name="k") - C = tvm.compute((n, n), +def intrin_wmma_gemm(shape): + n, m, l = shape + A = tvm.placeholder((n, l), name='A', dtype='float16') + B = tvm.placeholder((l, m), name='B', dtype='float16') + k = tvm.reduce_axis((0, l), name="k") + C = tvm.compute((n, m), lambda ii, jj: tvm.sum(A[ii, k].astype('float') * B[k, jj].astype('float'), axis=k), name='C') - BA = tvm.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, offset_factor=256) - BB = tvm.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, offset_factor=256) - BC = tvm.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, offset_factor=256) + BA = tvm.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, offset_factor=n * l) + BB = tvm.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, offset_factor=l * m) + BC = tvm.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, offset_factor=n * m) def intrin_func(ins, outs): BA, BB = ins @@ -61,16 +65,16 @@ def intrin_func(ins, outs): def init(): ib = tvm.ir_builder.create() - ib.emit(tvm.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, n, n, BC.elem_offset // 256, 0.0)) + ib.emit(tvm.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, m, l, BC.elem_offset // (n * m), 0.0)) return ib.get() def update(): ib = tvm.ir_builder.create() ib.emit(tvm.call_intrin('handle', 'tvm_mma_sync', - BC.data, BC.elem_offset // 256, - BA.data, BA.elem_offset // 256, - BB.data, BB.elem_offset // 256, - BC.data, BC.elem_offset // 256)) + BC.data, BC.elem_offset // (n * m), + BA.data, BA.elem_offset // (n * l), + BB.data, BB.elem_offset // (l * m), + BC.data, BC.elem_offset // (n * m))) return ib.get() return update(), init(), update() @@ -78,12 +82,12 @@ def update(): return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC}) -def intrin_wmma_store_matrix(): - n = 16 - A = tvm.placeholder((n, n), name='A', dtype='float32') - BA = tvm.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator', data_alignment=32, offset_factor=256) - C = tvm.compute((n, n), lambda i, j: A[i, j], name='C') - BC = tvm.decl_buffer(C.shape, C.dtype, scope='global', data_alignment=32, offset_factor=256) +def intrin_wmma_store_matrix(shape): + n, m, l = shape + A = tvm.placeholder((n, m), name='A', dtype='float32') + BA = tvm.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator', data_alignment=32, offset_factor=n * m) + C = tvm.compute((n, m), lambda i, j: A[i, j], name='C') + BC = tvm.decl_buffer(C.shape, C.dtype, scope='global', data_alignment=32, offset_factor=n * m) def intrin_func(ins, outs): ib = tvm.ir_builder.create() @@ -91,8 +95,8 @@ def intrin_func(ins, outs): BA = ins[0] BC = outs[0] ib.emit(tvm.call_intrin('handle', 'tvm_store_matrix_sync', - BA.data, n, n, n, 
BA.elem_offset // 256, - BC.access_ptr('w'), n, 'row_major')) + BA.data, n, m, l, BA.elem_offset // (n * m), + BC.access_ptr('w'), m, 'row_major')) return ib.get() return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) @@ -109,15 +113,15 @@ def test_tensor_core_batch_matmal(): batch_size = 4 n = 512 m, l = n, n - assert (n % 16 == 0) - assert (m % 16 == 0) + assert (n % 32 == 0) + assert (m % 8 == 0) assert (l % 16 == 0) - nn, mm, ll = n // 16, m // 16, l // 16 - A = tvm.placeholder((batch_size, nn, ll, 16, 16), name='A', dtype='float16') - B = tvm.placeholder((batch_size, ll, mm, 16, 16), name='B', dtype='float16') + nn, mm, ll = n // 32, m // 8, l // 16 + A = tvm.placeholder((batch_size, nn, ll, 32, 16), name='A', dtype='float16') + B = tvm.placeholder((batch_size, ll, mm, 16, 8), name='B', dtype='float16') k1 = tvm.reduce_axis((0, ll), name='k1') k2 = tvm.reduce_axis((0, 16), name='k2') - C = tvm.compute((batch_size, nn, mm, 16, 16), + C = tvm.compute((batch_size, nn, mm, 32, 8), lambda b, i, j, ii, jj: tvm.sum(A[b, i, k1, ii, k2].astype('float') * B[b, k1, j, k2, jj].astype('float'), axis=[k1, k2]), name='Fragment_C') @@ -174,7 +178,6 @@ def test_tensor_core_batch_matmal(): s[AS].bind(tx, thread_y) s[AS].bind(ty, thread_z) s[AS].bind(to, thread_x) - s[AS].vectorize(ti) s[BS].compute_at(s[CF], ko) b, xo, yo, xi, yi = BS.op.axis @@ -185,21 +188,21 @@ def test_tensor_core_batch_matmal(): s[BS].bind(tx, thread_y) s[BS].bind(ty, thread_z) s[BS].bind(to, thread_x) - s[BS].vectorize(ti) - s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_a')) - s[BF].tensorize(BF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_b')) - s[C].tensorize(kernel_i, intrin_wmma_store_matrix()) - s[CF].tensorize(_i, intrin_wmma_gemm()) + s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix((32, 8, 16), 'wmma.matrix_a')) + s[BF].tensorize(BF.op.axis[-2], intrin_wmma_load_matrix((32, 8, 16), 'wmma.matrix_b')) + s[C].tensorize(kernel_i, intrin_wmma_store_matrix((32, 8, 16))) + s[CF].tensorize(_i, intrin_wmma_gemm((32, 8, 16))) func = tvm.build(s, [A, B, C], 'cuda') ctx = tvm.gpu(0) - a_np = np.random.uniform(size=(batch_size, nn, nn, 16, 16)).astype(A.dtype) - b_np = np.random.uniform(size=(batch_size, nn, nn, 16, 16)).astype(B.dtype) + a_np = np.random.uniform(size=(batch_size, nn, ll, 32, 16)).astype(A.dtype) + b_np = np.random.uniform(size=(batch_size, ll, mm, 16, 8)).astype(B.dtype) a = tvm.nd.array(a_np, ctx) b = tvm.nd.array(b_np, ctx) - c = tvm.nd.array(np.zeros((batch_size, nn, nn, 16, 16), dtype=C.dtype), ctx) + c = tvm.nd.array(np.zeros((batch_size, nn, mm, 32, 8), dtype=C.dtype), ctx) + func(a, b, c) evaluator = func.time_evaluator(func.entry_name, ctx, number=3) print('gemm with tensor core: %f ms' % (evaluator(a, b, c).mean * 1e3)) @@ -211,6 +214,7 @@ def test_tensor_core_batch_matmal(): np.testing.assert_allclose(c_np, np.matmul(a_np.astype(C.dtype), b_np.astype(C.dtype)), rtol=1e-4, atol=1e-4) + def test_tensor_core_batch_conv(): if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"): print("skip because cuda is not enabled..") @@ -349,10 +353,10 @@ def test_tensor_core_batch_conv(): s[WS].bind(to, thread_x) s[WS].vectorize(ti) - s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_a')) - s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_b')) - s[Conv].tensorize(nnc, intrin_wmma_store_matrix()) - s[ConvF].tensorize(nnf, intrin_wmma_gemm()) + s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix((16, 16, 16), 'wmma.matrix_a')) + 
s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix((16, 16, 16), 'wmma.matrix_b')) + s[Conv].tensorize(nnc, intrin_wmma_store_matrix((16, 16, 16))) + s[ConvF].tensorize(nnf, intrin_wmma_gemm((16, 16, 16))) func = tvm.build(s, [A, W, Conv], 'cuda') From f15daf0dae2ccceb54ab14f65df7ecebbb479dbe Mon Sep 17 00:00:00 2001 From: Aleksander Fedorov Date: Wed, 23 Oct 2019 00:09:20 +0300 Subject: [PATCH 12/12] fix wmma include logic --- src/codegen/codegen_cuda.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/codegen/codegen_cuda.h b/src/codegen/codegen_cuda.h index e1476cfd68c9..53e7db45efc6 100644 --- a/src/codegen/codegen_cuda.h +++ b/src/codegen/codegen_cuda.h @@ -41,7 +41,7 @@ class CodeGenCUDA final : public CodeGenC { void AddFunction(LoweredFunc f); std::string Finish(); bool need_include_path() { - return (enable_fp16_ || enable_int8_ || need_math_constants_h_); + return (enable_fp16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_); } // override behavior void VisitStmt_(const ir::For* op) final;
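One piece of the 32x16x8 test patch above worth unpacking is the new `GetWmmaFragmentSize` helper in codegen: it rescales an `Allocate` from scalar elements to a number of wmma fragments, using the (m, n, k) shape recorded by `InferFragment` together with the fragment scope. A rough Python mirror of that logic, for illustration only (the real C++ implementation parses m, n, k out of the `fragment_shape` attribute string):

def wmma_fragment_count(scope, shape, num_elements):
    # m, n, k follow the wmma convention: matrix_a is an m x k tile, matrix_b is
    # k x n, and the accumulator is m x n, so each scope divides by a different
    # per-fragment element count.
    m, n, k = shape
    if scope == "wmma.matrix_a":
        return num_elements // (m * k)
    if scope == "wmma.matrix_b":
        return num_elements // (n * k)
    if scope == "wmma.accumulator":
        return num_elements // (m * n)
    raise ValueError("not a wmma fragment scope: " + scope)

# A 2048-element half buffer in wmma.matrix_a with the (32, 8, 16) shape used in
# the updated test becomes 4 fragments, so codegen emits an array of 4 fragments.
assert wmma_fragment_count("wmma.matrix_a", (32, 8, 16), 2048) == 4

Before this change the divisor was hard-coded to 256, which only holds for the 16x16x16 shape.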