From 4f980e7de3a46611d3e7265786a7c597de6b6650 Mon Sep 17 00:00:00 2001
From: tqchen
Date: Sun, 18 Mar 2018 22:47:13 -0700
Subject: [PATCH 1/2] [SCHEDULE][PASS] Enable Warp memory and lower to shuffle

---
 include/tvm/ir.h                                   |   8 +
 include/tvm/ir_pass.h                              |   9 +
 python/tvm/build_module.py                         |   4 +
 src/api/api_pass.cc                                |   1 +
 src/codegen/codegen_c.cc                           |  14 +-
 src/codegen/intrin_rule_cuda.cc                    |  10 +
 src/pass/ir_util.h                                 |  17 +
 src/pass/lower_warp_memory.cc                      | 317 ++++++++++++++++++
 src/schedule/bound.cc                              |  11 +-
 tests/python/integration/test_ewise.py             |  40 +++
 .../unittest/test_pass_lower_warp_memory.py        |  27 ++
 .../unittest/test_schedule_bound_inference.py      |  24 ++
 12 files changed, 470 insertions(+), 12 deletions(-)
 create mode 100644 src/pass/lower_warp_memory.cc
 create mode 100644 tests/python/unittest/test_pass_lower_warp_memory.py

diff --git a/include/tvm/ir.h b/include/tvm/ir.h
index f36d914e621f..9e3c8cbc2be1 100644
--- a/include/tvm/ir.h
+++ b/include/tvm/ir.h
@@ -411,6 +411,14 @@ constexpr const char* tvm_call_packed_lowered = "tvm_call_packed_lowered";
  * }
  */
 constexpr const char* tvm_storage_sync = "tvm_storage_sync";
+/*!
+ * \brief See pseudo code
+ *
+ *  Type tvm_warp_shuffle(Type value, warp_id) {
+ *    return (value passed in by the warp indicated by warp_id);
+ *  }
+ */
+constexpr const char* tvm_warp_shuffle = "tvm_warp_shuffle";
 /*!
  * \brief Initialize the global barrier.
  *  Call this at beginning of kernel that need global barrier.
diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index 572385d9a895..1ae41032cbb8 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -407,6 +407,15 @@ LoweredFunc ThreadSync(LoweredFunc stmt, std::string storage_scope);
  */
 LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size);
 
+/*!
+ * \brief Lower warp memory in stmt.
+ * \param f The device function to be lowered.
+ * \param warp_size The size of a warp, within which no sync is needed.
+ *        This pass only takes effect if warp_size is larger than one.
+ * \return Transformed function.
+ */
+LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size);
+
 /*!
  * \brief Lower packed function call.
  * \param f The function to be lowered.
diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py
index 03a79860e9ee..0b86cde626af 100755
--- a/python/tvm/build_module.py
+++ b/python/tvm/build_module.py
@@ -450,6 +450,10 @@ def build(sch,
         else:
             raise ValueError("unknown function type %d" % func.func_type)
 
+    for i, func in enumerate(fdevice):
+        warp_size = target.thread_warp_size
+        fdevice[i] = ir_pass.LowerWarpMemory(func, warp_size)
+
     if "gpu" in target.keys and not fdevice:
         warnings.warn(
             "Specified target %s, but cannot find device code, did you do bind?"
             % target)
diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc
index 7ec6ef4009e4..6d59cb3ae505 100644
--- a/src/api/api_pass.cc
+++ b/src/api/api_pass.cc
@@ -125,6 +125,7 @@ REGISTER_PASS2(SplitPipeline);
 REGISTER_PASS2(LiftAttrScope);
 REGISTER_PASS1(NarrowChannelAccess);
 REGISTER_PASS2(LowerThreadAllreduce);
+REGISTER_PASS2(LowerWarpMemory);
 REGISTER_PASS2(LowerIntrin);
 REGISTER_PASS1(LowerTVMBuiltin);
 REGISTER_PASS1(CombineContextCall);
diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc
index e00cd82abe48..9732f0ef65af 100644
--- a/src/codegen/codegen_c.cc
+++ b/src/codegen/codegen_c.cc
@@ -5,6 +5,7 @@
 #include
 #include
 #include "./codegen_c.h"
+#include "../pass/ir_util.h"
 #include "../arithmetic/compute_expr.h"
 
 namespace tvm {
@@ -544,15 +545,6 @@ void CodeGenC::PrintVecBinaryOp(
   }
 }
 
-inline bool TryGetRamp1Base(Expr index, int lanes, Expr *base) {
-  const Ramp* r = index.as<Ramp>();
-  if (!r) return false;
-  if (!is_one(r->stride)) return false;
-  CHECK_EQ(r->lanes, lanes);
-  *base = r->base;
-  return true;
-}
-
 void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) {  // NOLINT(*)
   int lanes = op->type.lanes();
   // delcare type.
@@ -563,7 +555,7 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) {  // NOLINT(*)
     CHECK(is_one(op->predicate))
         << "predicated load is not supported";
     Expr base;
-    if (TryGetRamp1Base(op->index, op->type.lanes(), &base)) {
+    if (GetRamp1Base(op->index, op->type.lanes(), &base)) {
       std::string ref = GetVecLoad(op->type, op->buffer_var.get(), base);
       os << ref;
     } else {
@@ -617,7 +609,7 @@ void CodeGenC::VisitStmt_(const Store* op) {
     CHECK(is_one(op->predicate))
         << "Predicated store is not supported";
     Expr base;
-    if (TryGetRamp1Base(op->index, t.lanes(), &base)) {
+    if (GetRamp1Base(op->index, t.lanes(), &base)) {
       std::string value = this->PrintExpr(op->value);
       this->PrintVecStore(op->buffer_var.get(), t, base, value);
     } else {
diff --git a/src/codegen/intrin_rule_cuda.cc b/src/codegen/intrin_rule_cuda.cc
index 9abb99d7c7c5..1d199fe5af28 100644
--- a/src/codegen/intrin_rule_cuda.cc
+++ b/src/codegen/intrin_rule_cuda.cc
@@ -49,6 +49,12 @@ struct CUDAPopcount {
   }
 };
 
+struct CUDAShuffle {
+  std::string operator()(Type t, std::string name) const {
+    return "__shfl";
+  }
+};
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
 .set_body(DispatchExtern);
 
@@ -67,6 +73,10 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount")
 .set_body(DispatchExtern);
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle")
+.set_body(DispatchExtern<CUDAShuffle>);
+
+
 }  // namespace intrin
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/pass/ir_util.h b/src/pass/ir_util.h
index 96a41b120e46..f871133fb74f 100644
--- a/src/pass/ir_util.h
+++ b/src/pass/ir_util.h
@@ -161,6 +161,23 @@ inline int GetTempAllocaAlignment(Type type, int32_t const_size) {
   }
   return align;
 }
+
+/*!
+ * \brief Pattern match index to Ramp with stride=1.
+ *        This is a common pattern in contiguous memory loads.
+ * \param index The index formula
+ * \param lanes number of lanes in the ramp
+ * \param base The result base.
+ * \return true if the pattern matches, and the base is stored to *base.
+ */
+inline bool GetRamp1Base(Expr index, int lanes, Expr *base) {
+  const Ramp* r = index.as<Ramp>();
+  if (!r) return false;
+  if (!is_one(r->stride)) return false;
+  CHECK_EQ(r->lanes, lanes);
+  *base = r->base;
+  return true;
+}
 }  // namespace ir
 }  // namespace tvm
 #endif  // TVM_PASS_IR_UTIL_H_
diff --git a/src/pass/lower_warp_memory.cc b/src/pass/lower_warp_memory.cc
new file mode 100644
index 000000000000..dbd073c0b14c
--- /dev/null
+++ b/src/pass/lower_warp_memory.cc
@@ -0,0 +1,317 @@
+/*!
+ *  Copyright (c) 2018 by Contributors
+ *
+ * Lower warp memory to use local memory
+ * and shuffle intrinsics.
+ *
+ * \file lower_warp_memory.cc
+ */
+// Thanks to Andrew Adams and Vinod Grover for
+// explaining the concept of warp shuffle.
+#include
+#include
+#include
+#include
+#include
+#include "./ir_util.h"
+#include "../arithmetic/compute_expr.h"
+#include "../runtime/thread_storage_scope.h"
+
+namespace tvm {
+namespace ir {
+
+// Rewrite Rule
+//
+// There is no special warp memory in most GPUs.
+// Instead, we can stripe the data across threads
+// and store the data in local memory.
+//
+// This requires us to do the following rewriting:
+// - Rewrite allocation to use local memory.
+// - Rewrite store of warp memory to a local store.
+// - Rewrite load of warp memory to a local load plus a shuffle.
+//
+// Define a generic shuffle intrinsic warp_shuffle(data, warp_index).
+// We can use the following rewriting rule
+//
+// Before rewrite,
+//
+//   alloc warp warp_mem[n * warp_size * m]
+//   store warp_mem[m * warp_index + (warp_size * m) * y + x]
+//   load warp_mem[m * z + (warp_size * m) * y + x]
+//   subject to x \in [0, m), y \in [0, n)
+//
+// After rewrite:
+//
+//   alloc local local_mem[n * m]
+//   store local_mem[m * y + x]
+//   warp_shuffle(load local_mem[m * y + x], z)
+//   subject to (m * y + x) is invariant to warp_index
+
+// Algorithm
+//
+// To implement this rewrite rule, we can do the following steps:
+// For each warp memory alloc
+// - Use a linear pattern detector on the store index to find m
+// - Deduce n given warp_size and the alloc size
+// - Now that we have m, n, warp_size, we can proceed with the rewrite
+
+// Visitor to find m in pattern
+// store warp_mem[m * warp_index + (warp_size * m) * y + x]
+class WarpStoreCoeffFinder : private IRVisitor {
+ public:
+  WarpStoreCoeffFinder(const Variable* buffer,
+                       Var warp_index)
+      : buffer_(buffer), warp_index_(warp_index) {
+  }
+  // find the warp coefficient in the statement given the warp size
+  int Find(const Stmt& stmt) {
+    this->Visit(stmt);
+    return warp_coeff_;
+  }
+
+ private:
+  /// Visitor implementation
+  void Visit_(const Store *op) final {
+    if (op->buffer_var.get() == buffer_) {
+      if (op->value.type().lanes() == 1) {
+        UpdatePattern(op->index);
+      } else {
+        Expr base;
+        CHECK(GetRamp1Base(op->index, op->value.type().lanes(), &base))
+            << "LowerWarpMemory failed due to store index=" << op->index
+            << ", can only handle contiguous store";
+        UpdatePattern(base);
+      }
+    } else {
+      IRVisitor::Visit_(op);
+    }
+  }
+
+  void UpdatePattern(const Expr& index) {
+    Array<Expr> m =
+        arith::DetectLinearEquation(index, {warp_index_});
+    CHECK_EQ(m.size(), 2U)
+        << "LowerWarpMemory failed due to store index=" << index;
+    int coeff;
+    CHECK(arith::GetConstInt(ir::Simplify(m[0]), &coeff) && coeff > 0)
+        << "LowerWarpMemory failed due to store index=" << index
+        << ", require a positive constant coefficient on the warp index";
+    if (warp_coeff_ != 0) {
+      CHECK_EQ(warp_coeff_, coeff)
+          << "LowerWarpMemory failed due to two different store coefficients on the warp index";
+    } else {
+      warp_coeff_ = coeff;
+    }
+  }
+
+  // The buffer variable
+  const Variable* buffer_;
+  // the warp index
+  Var warp_index_;
+  // the coefficient
+  int warp_coeff_{0};
+};
+
+
+// Visitor to find the warp index
+class WarpIndexFinder : private IRVisitor {
+ public:
+  explicit WarpIndexFinder(int warp_size)
+      : warp_size_(warp_size) {
+  }
+  // find the warp index in the statement given the warp size
+  IterVar Find(const Stmt& stmt) {
+    this->Visit(stmt);
+    CHECK(warp_index_.defined())
+        << "Cannot find the warp index (threadIdx.x) within the scope of warp memory";
+    return warp_index_;
+  }
+
+ private:
+  void Visit(const NodeRef &node) final {
+    if (warp_index_.defined()) return;
+    IRVisitor::Visit(node);
+  }
+
+  /// Visitor implementation
+  void Visit_(const AttrStmt *op) final {
+    if (op->attr_key == attr::thread_extent) {
+      IterVar iv(op->node.node_);
+      if (iv->thread_tag == "threadIdx.x") {
+        int value;
+        CHECK(arith::GetConstInt(op->value, &value) &&
+              value == warp_size_)
+            << "Expect the extent of threadIdx.x to equal the warp size ("
+            << warp_size_ << ") to enable warp memory,"
+            << " but get " << op->value << " instead";
+        warp_index_ = iv;
+      }
+    }
+    IRVisitor::Visit_(op);
+  }
+  // warp size
+  int warp_size_{0};
+  // the warp index
+  IterVar warp_index_{nullptr};
+};
+// Mutator to change the access pattern
+class WarpAccessRewriter : protected IRMutator {
+ public:
+  explicit WarpAccessRewriter(int warp_size)
+      : warp_size_(warp_size) {}
+  // Rewrite the allocate statement which transforms
+  // warp memory to local memory.
+  Stmt Rewrite(const Allocate* op, const Stmt& stmt) {
+    buffer_ = op->buffer_var.get();
+    int alloc_size = op->constant_allocation_size();
+    CHECK_GT(alloc_size, 0)
+        << "warp memory only supports a constant alloc size";
+    alloc_size *= op->type.lanes();
+    warp_index_ = WarpIndexFinder(warp_size_).Find(op->body)->var;
+    warp_coeff_ = WarpStoreCoeffFinder(
+        buffer_, warp_index_).Find(op->body);
+    CHECK_EQ(alloc_size % (warp_size_ * warp_coeff_), 0)
+        << "Warp memory size must be a multiple of the warp size";
+    warp_group_ = alloc_size / (warp_size_ * warp_coeff_);
+    return Allocate::make(
+        op->buffer_var,
+        op->type,
+        {make_const(Int(32), alloc_size / warp_size_)},
+        op->condition,
+        this->Mutate(op->body));
+  }
+
+ protected:
+  Expr Mutate_(const Variable* op, const Expr& expr) {
+    CHECK(op != buffer_)
+        << "Cannot access address of warp memory directly";
+    return IRMutator::Mutate_(op, expr);
+  }
+
+  Stmt Mutate_(const Store* op, const Stmt& stmt) {
+    if (op->buffer_var.get() == buffer_) {
+      Expr local_index, group;
+      std::tie(local_index, group) = SplitIndexByGroup(op->index);
+      return Store::make(op->buffer_var, op->value, local_index, op->predicate);
+    } else {
+      return IRMutator::Mutate_(op, stmt);
+    }
+  }
+
+  Expr Mutate_(const Load* op, const Expr& expr) {
+    if (op->buffer_var.get() == buffer_) {
+      Expr local_index, group;
+      std::tie(local_index, group) = SplitIndexByGroup(op->index);
+      // invariant: the local index must not contain the warp id
+      CHECK(!ExprUseVar(local_index, {warp_index_.get()}))
+          << "LowerWarpMemory failed to rewrite load to shuffle for index "
+          << op->index << " local_index=" << local_index;
+      Expr load_value = Load::make(
+          op->type, op->buffer_var, local_index, op->predicate);
+      return Call::make(load_value.type(),
+                        intrinsic::tvm_warp_shuffle,
+                        {load_value, group},
+                        Call::Intrinsic);
+    } else {
+      return IRMutator::Mutate_(op, expr);
+    }
+  }
+  // Split the index into its two components
+  // <local_index, source_index>:
+  // local_index is the index in the local memory;
+  // source_index is the corresponding source lane
+  // in this access pattern.
+  std::pair<Expr, Expr> SplitIndexByGroup(const Expr& index) {
+    if (index.type().lanes() != 1) {
+      Expr base, local_index, group;
+      CHECK(GetRamp1Base(index, index.type().lanes(), &base));
+      std::tie(local_index, group) = SplitIndexByGroup(base);
+      local_index =
+          Ramp::make(local_index, make_const(local_index.type(), 1), index.type().lanes());
+      return std::make_pair(local_index, group);
+    }
+    Expr m = make_const(index.type(), warp_coeff_);
+    Range rng = Range::make_by_min_extent(
+        make_zero(index.type()), make_const(index.type(), warp_size_));
+    Map<Var, Range> vrange({{warp_index_, rng}});
+
+    // Simple case: the warp index is the outermost component.
+    if (warp_group_ == 1) {
+      Expr x = Simplify(index % m, vrange);
+      Expr z = Simplify(index / m, vrange);
+      return std::make_pair(x, z);
+    } else {
+      Expr x = Simplify(index % m, vrange);
+      Expr y = index / make_const(index.type(), warp_coeff_ * warp_size_);
+      y = y * m + x;
+      Expr z = index % make_const(index.type(), warp_coeff_ * warp_size_) / m;
+      return std::make_pair(Simplify(y, vrange), Simplify(z, vrange));
+    }
+  }
+
+ private:
+  // the warp size
+  int warp_size_{0};
+  // The buffer variable
+  const Variable* buffer_;
+  // Warp index
+  Var warp_index_;
+  // the coefficient m
+  int warp_coeff_{0};
+  // the coefficient n
+  int warp_group_{0};
+};
+
+// Mutator to rewrite the storage scope and accesses of warp memory
+class WarpMemoryRewriter : private IRMutator {
+ public:
+  explicit WarpMemoryRewriter(int warp_size)
+      : warp_size_(warp_size) {
+  }
+
+  Stmt Rewrite(Stmt stmt) {
+    if (warp_size_ == 1) return stmt;
+    return this->Mutate(stmt);
+  }
+
+ private:
+  Stmt Mutate_(const Allocate* op, const Stmt& stmt) {
+    if (warp_buffer_.count(op->buffer_var.get())) {
+      WarpAccessRewriter rewriter(warp_size_);
+      return rewriter.Rewrite(op, stmt);
+    } else {
+      return IRMutator::Mutate_(op, stmt);
+    }
+  }
+
+  Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) {
+    using runtime::StorageScope;
+    if (op->attr_key == attr::storage_scope) {
+      const Variable* buf = op->node.as<Variable>();
+      StorageScope scope = StorageScope::make(op->value.as<StringImm>()->value);
+      if (scope.rank == runtime::StorageRank::kWarp) {
+        warp_buffer_.insert(buf);
+        Stmt ret = IRMutator::Mutate_(op, stmt);
+        op = ret.as<AttrStmt>();
+        return AttrStmt::make(
+            op->node, op->attr_key, StringImm::make("local"), op->body);
+      }
+    }
+    return IRMutator::Mutate_(op, stmt);
+  }
+
+  int warp_size_{0};
+  std::unordered_set<const Variable*> warp_buffer_;
+};
+
+LoweredFunc
+LowerWarpMemory(LoweredFunc f, int warp_size) {
+  CHECK_EQ(f->func_type, kDeviceFunc);
+  auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
+  n->body = WarpMemoryRewriter(warp_size).Rewrite(n->body);
+  return LoweredFunc(n);
+}
+
+}  // namespace ir
+}  // namespace tvm
diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc
index 908b579ec9a4..7929969a8502 100644
--- a/src/schedule/bound.cc
+++ b/src/schedule/bound.cc
@@ -42,7 +42,16 @@ bool NeedRelax(const IterVar& iv,
   if (tag.length() == 0 || tag == "pipeline") {
     return !found_attach;
   }
-  return static_cast<int>(scope.rank) <= ThreadScope::make(tag).rank;
+  ThreadScope ts = ThreadScope::make(tag);
+
+  // When there is warp memory,
+  // threadIdx.x must be used as the warp index.
+  if (scope.rank == StorageRank::kWarp &&
+      ts.rank == 1 &&
+      ts.dim_index == 0) {
+    return true;
+  }
+  return static_cast<int>(scope.rank) <= ts.rank;
 }
 
 // infer storage scope, if not given
diff --git a/tests/python/integration/test_ewise.py b/tests/python/integration/test_ewise.py
index 414b1ff008fe..ee880ed1d9fb 100644
--- a/tests/python/integration/test_ewise.py
+++ b/tests/python/integration/test_ewise.py
@@ -1,4 +1,5 @@
 import tvm
+from tvm.contrib import nvcc
 import numpy as np
 import time
 
@@ -155,7 +156,46 @@ def check_device(device):
     run("uint64")
 
 
+def try_warp_memory():
+    """skip this in the default test because it requires a higher arch"""
+    m = 128
+    A = tvm.placeholder((m,), name='A')
+    B = tvm.compute((m,), lambda i: A[i] + 3, name='B')
+    warp_size = 32
+    s = tvm.create_schedule(B.op)
+    AA = s.cache_read(A, "warp", [B])
+    xo, xi = s[B].split(B.op.axis[0], warp_size * 2)
+    xi0, xi1 = s[B].split(xi, factor=warp_size)
+    tx = tvm.thread_axis("threadIdx.x")
+    s[B].bind(xi1, tx)
+    s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
+    s[AA].compute_at(s[B], xo)
+    xo, xi = s[AA].split(s[AA].op.axis[0], warp_size)
+    s[AA].bind(xi, tx)
+
+    @tvm.register_func
+    def tvm_callback_cuda_compile(code):
+        ptx = nvcc.compile_cuda(code, target="ptx")
+        return ptx
+
+    # one line to build the function.
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("skip because %s is not enabled.." % device)
+            return
+        f = tvm.build(s, [A, B], device)
+        a = tvm.nd.array((np.random.uniform(size=m) * 256).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.zeros(m, dtype=B.dtype), ctx)
+        f(a, b)
+        np.testing.assert_allclose(
+            b.asnumpy(), a.asnumpy() + 3, rtol=1e-6)
+
+    check_device("cuda")
+
+
 if __name__ == "__main__":
+    try_warp_memory()
     test_add()
     test_log_pow_llvm()
     test_exp()
diff --git a/tests/python/unittest/test_pass_lower_warp_memory.py b/tests/python/unittest/test_pass_lower_warp_memory.py
new file mode 100644
index 000000000000..9793b21371bd
--- /dev/null
+++ b/tests/python/unittest/test_pass_lower_warp_memory.py
@@ -0,0 +1,27 @@
+import tvm
+
+def test_lower_warp_mem():
+    m = 128
+    A = tvm.placeholder((m,), name='A')
+    B = tvm.compute((m,), lambda i: A[i] + 3, name='B')
+
+    s = tvm.create_schedule(B.op)
+    AA = s.cache_read(A, "warp", [B])
+    xo, xi = s[B].split(B.op.axis[0], 32)
+    xi0, xi1 = s[B].split(xi, factor=16)
+    tx = tvm.thread_axis("threadIdx.x")
+    s[B].bind(xi1, tx)
+    s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
+    s[AA].compute_at(s[B], xo)
+    xo, xi = s[AA].split(s[AA].op.axis[0], 16)
+    s[AA].bind(xi, tx)
+
+    f = tvm.lower(s, [A, B])
+    fhost, fdevice = tvm.ir_pass.SplitHostDevice(f)
+    fdevice = tvm.ir_pass.LowerWarpMemory(fdevice, 16)
+    assert(fdevice.body.body.value.value == "local")
+    assert(fdevice.body.body.body.extents[0].value == 2)
+
+
+if __name__ == "__main__":
+    test_lower_warp_mem()
diff --git a/tests/python/unittest/test_schedule_bound_inference.py b/tests/python/unittest/test_schedule_bound_inference.py
index 3601833de08d..30be3783bbb3 100644
--- a/tests/python/unittest/test_schedule_bound_inference.py
+++ b/tests/python/unittest/test_schedule_bound_inference.py
@@ -53,6 +53,29 @@ def test_bound3():
     assert(bounds[A1.op.axis[0]].extent.value==32)
     assert(bounds[A1.op.axis[1]].extent.value==16)
 
+
+def test_bound_warp():
+    m = tvm.var('m')
+    l = tvm.var('l')
+    A = tvm.placeholder((m, l), name='A')
+    A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
+    A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
+
+    s = tvm.create_schedule(A2.op)
+    s[A1].set_scope("warp")
+    xo, xi = s[A2].split(A2.op.axis[0], 32)
+    xi0, xi1 = s[A2].split(xi, factor=16)
+    tx = tvm.thread_axis("threadIdx.x")
+    s[A2].bind(xi1, tx)
+    s[A2].bind(xi0, tvm.thread_axis("threadIdx.y"))
+    y = s[A2].op.axis[1]
+    s[A1].compute_at(s[A2], y)
+    xo, xi = s[A1].split(s[A1].op.axis[0], factor=16)
+    s[A1].bind(xi, tx)
+    bounds = tvm.schedule.InferBound(s)
+    assert isinstance(bounds, tvm.container.Map)
+    assert(bounds[A1.op.axis[0]].extent.value==16)
+
 def test_bound_scan():
     m = tvm.var("m")
     n = tvm.var("n")
@@ -249,3 +272,4 @@ def test_gemm_bound():
     test_bound_conv1d()
     test_bound2()
     test_gemm_bound()
+    test_bound_warp()

From e049befcfb450ff4df08d100c4656a9d5cd74ac4 Mon Sep 17 00:00:00 2001
From: tqchen
Date: Sun, 25 Mar 2018 19:24:04 -0700
Subject: [PATCH 2/2] OpenCL dispatches for now to intel shuffle

---
 src/codegen/intrin_rule_opencl.cc | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/src/codegen/intrin_rule_opencl.cc b/src/codegen/intrin_rule_opencl.cc
index 924abcade63f..b8b2412215d1 100644
--- a/src/codegen/intrin_rule_opencl.cc
+++ b/src/codegen/intrin_rule_opencl.cc
@@ -27,6 +27,17 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount")
 .set_body(DispatchExtern);
 
+// There is no warp shuffle instruction in standard OpenCL.
+// When shuffle is used, we assume it is Intel's shuffle extension.
+struct IntelShuffle {
+  std::string operator()(Type t, std::string name) const {
+    return "intel_sub_group_shuffle";
+  }
+};
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle")
+.set_body(DispatchExtern<IntelShuffle>);
+
 }  // namespace intrin
 }  // namespace codegen
 }  // namespace tvm
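
For reference, below is a hand-written CUDA sketch of the access pattern this lowering targets conceptually. It is not TVM-generated output; the kernel and variable names are illustrative assumptions (warp_size = 32, m = 1, n = 1). Each lane keeps its slice of the logical "warp memory" in local storage (a register), and __shfl moves a value from the owning lane to the reading lane, so no shared memory or __syncthreads is needed.

// Conceptual sketch only: each lane owns one element of the logical warp
// buffer; other lanes read it via __shfl, mirroring the rewrite rule above.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void warp_shuffle_demo(const float* A, float* B) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  // The "warp memory" allocation of warp_size floats becomes one register per lane.
  float local_val = A[idx] + 3.0f;            // store warp_mem[warp_index]
  int src_lane = (threadIdx.x + 1) & 31;      // lane whose element we want to read
  // load warp_mem[src_lane] becomes a shuffle from the lane that owns it.
  // (On sm_70+ toolchains use __shfl_sync(0xffffffff, local_val, src_lane).)
  float value = __shfl(local_val, src_lane);
  B[idx] = value;
}

int main() {
  const int n = 128;
  float hA[n], hB[n];
  for (int i = 0; i < n; ++i) hA[i] = static_cast<float>(i);
  float *dA = nullptr, *dB = nullptr;
  cudaMalloc(&dA, n * sizeof(float));
  cudaMalloc(&dB, n * sizeof(float));
  cudaMemcpy(dA, hA, n * sizeof(float), cudaMemcpyHostToDevice);
  warp_shuffle_demo<<<n / 32, 32>>>(dA, dB);
  cudaMemcpy(hB, dB, n * sizeof(float), cudaMemcpyDeviceToHost);
  // Each B[i] should hold the next lane's A value plus 3.
  printf("B[0] = %.1f (expect %.1f)\n", hB[0], hA[1] + 3.0f);
  cudaFree(dA);
  cudaFree(dB);
  return 0;
}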