diff --git a/docs/langref/hybrid_script.rst b/docs/langref/hybrid_script.rst
index 122bcd95e690..7043281fcafb 100644
--- a/docs/langref/hybrid_script.rst
+++ b/docs/langref/hybrid_script.rst
@@ -68,17 +68,23 @@ to LLVM module.
 Tuning
 ~~~~~~
 
-**Under construction, not supported yet.**
-
 Follow up the example above, you can use some tvm like interfaces to tune the code:
 
 .. code-block:: python
 
+    i, j = c.op.axis
     sch = tvm.create_schedule(op)
     jo, ji = sch.split(j, 4)
     sch.vectorize(ji)
 
-``split``, ``reorder``, and loop_annotation will be supported!
+For now, you can use loop annotations (``unroll``, ``parallel``, ``vectorize``, and ``bind``),
+loop manipulation (``split`` and ``fuse``), and ``reorder``.
+
+.. note::
+
+    This is a preliminary feature, so users are responsible for the correctness
+    of the code after tuning. Specifically, users should be careful when
+    fusing and reordering imperfect loops.
 
 Loops
 ~~~~~
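For context on the note above: an imperfect ``split`` stays correct because the split loop is lowered with a tail guard, as implemented by ``LoopSpliter`` in ``src/op/hybrid_op.cc`` below. A minimal pure-Python sketch of that rewrite (``extent`` and ``factor`` are illustrative names, not part of the API):

    # A loop of extent 10 split by factor 4 becomes two loops plus a guard,
    # mirroring likely(outer * factor < extent - inner) in LoopSpliter.
    extent, factor = 10, 4
    visited = []
    for outer in range((extent + factor - 1) // factor):
        for inner in range(factor):
            if outer * factor < extent - inner:  # skip the out-of-range tail
                visited.append(inner + outer * factor)
    assert visited == list(range(extent))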
diff --git a/include/tvm/operation.h b/include/tvm/operation.h
index 02cd0d016f39..3509b133cfc3 100644
--- a/include/tvm/operation.h
+++ b/include/tvm/operation.h
@@ -459,6 +459,8 @@ class HybridOpNode : public OperationNode {
   Array<Tensor> inputs;
   /*! \brief Symbolic placeholder representation of outputs */
   Array<Tensor> outputs;
+  /*! \brief The axis of iterations */
+  Array<IterVar> axis;
   /*! \brief the statement that generates the computation. This is
    * slightly different from the body in ExternOpNode. All the output
    * tensors keep its own name specified by users in the script.
@@ -500,6 +502,7 @@ class HybridOpNode : public OperationNode {
     v->Visit("attrs", &attrs);
     v->Visit("inputs", &inputs);
     v->Visit("outputs", &outputs);
+    v->Visit("axis", &axis);
     v->Visit("body", &body);
   }
   EXPORT static Operation make(std::string name,
diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py
index 9a98e9a6e769..e1345ad373bf 100644
--- a/python/tvm/tensor.py
+++ b/python/tvm/tensor.py
@@ -152,7 +152,7 @@ class ComputeOp(Operation):
     """Compute operation."""
     @property
     def axis(self):
-        """Represent axis of IterVar, only defined when it is a ComputeOp"""
+        """Represent axis of IterVar, defined when it is a ComputeOp"""
         return self.__getattr__("axis")
 
     @property
@@ -184,4 +184,7 @@ class ExternOp(Operation):
 @register_node
 class HybridOp(Operation):
     """Hybrid operation."""
-    pass
+    @property
+    def axis(self):
+        """Represent axis of IterVar, also defined when it is a HybridOp"""
+        return self.__getattr__("axis")
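Since ``HybridOp`` now exposes ``axis``, schedules over hybrid ops can unpack iteration variables the same way as for compute ops. A usage sketch, adapted from the tests further down (``vec_add`` here is illustrative):

    import tvm
    from tvm.hybrid import script

    @script
    def vec_add(a, b):
        c = output_tensor((1000, ), 'float32')
        for i in range(1000):
            c[i] = a[i] + b[i]
        return c

    a = tvm.placeholder((1000, ), dtype='float32', name='a')
    b = tvm.placeholder((1000, ), dtype='float32', name='b')
    c = vec_add(a, b)

    # HybridOp.axis now behaves like ComputeOp.axis: one IterVar per
    # `for` loop in the script, outermost first.
    i, = c.op.axis
    sch = tvm.create_schedule(c.op)
    io, ii = sch[c].split(i, 8)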
diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc
index d4cb2b4c632b..baf42f9367b4 100644
--- a/src/op/compute_op.cc
+++ b/src/op/compute_op.cc
@@ -212,6 +212,7 @@ void ComputeOpNode::GatherBound(
     const Operation& self,
     const std::unordered_map<Tensor, TensorDom>& tensor_dom,
     std::unordered_map<IterVar, Range>* out_dom_map) const {
+  CHECK_EQ(self.operator->(), this);
   const TensorDom& tdom = tensor_dom.at(self.output(0));
   for (size_t i = 0; i < this->axis.size(); ++i) {
     Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom);
diff --git a/src/op/hybrid_op.cc b/src/op/hybrid_op.cc
index 4dbb2c0b964f..acd7b5737c5f 100644
--- a/src/op/hybrid_op.cc
+++ b/src/op/hybrid_op.cc
@@ -1,5 +1,5 @@
 /*!
- * Copyright (c) 2018 by Contributors
+ * Copyright (c) 2019 by Contributors
  * \brief Hybrid computation rule.
  * \file hybrid_op.cc
  */
@@ -7,8 +7,13 @@
 #include
 #include
 #include
+#include
+#include
+#include
 #include
+#include
 #include "op_util.h"
+#include "hybrid_op.h"
 
 namespace tvm {
 using namespace ir;
@@ -25,7 +30,7 @@ int HybridOpNode::num_outputs() const {
 }
 
 Array<IterVar> HybridOpNode::root_iter_vars() const {
-  return {};
+  return this->axis;
 }
 
 Type HybridOpNode::output_dtype(size_t i) const {
@@ -52,6 +57,7 @@ Operation HybridOpNode::make(std::string name,
   n->attrs = std::move(attrs);
   n->inputs = std::move(inputs);
   n->outputs = std::move(outputs);
+  n->axis = op::GatherLoopVars(body);
   n->body = std::move(body);
   Operation res = Operation(n);
   return res;
@@ -62,8 +68,8 @@ Array<Tensor> HybridOpNode::InputTensors() const {
 }
 
 Operation HybridOpNode::ReplaceInputs(
-    const Operation& self,
-    const std::unordered_map<Tensor, Tensor>& rmap) const {
+    const Operation &self,
+    const std::unordered_map<Tensor, Tensor> &rmap) const {
   CHECK_EQ(self.operator->(), this);
   auto n = make_node<HybridOpNode>(*this);
   n->body = op::ReplaceTensor(this->body, rmap);
@@ -83,13 +89,13 @@ Operation HybridOpNode::ReplaceInputs(
 }
 
 void HybridOpNode::PropBoundToInputs(
-    const Operation& self,
-    const std::unordered_map<const Variable*, IntSet>& dom_map,
+    const Operation &self,
+    const std::unordered_map<const Variable*, IntSet> &dom_map,
     std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
   for (Tensor t : this->inputs) {
     auto it = out_dom_map->find(t);
     if (it == out_dom_map->end()) continue;
-    TensorDom& dom = it->second;
+    TensorDom &dom = it->second;
     for (size_t i = 0; i < t->shape.size(); ++i) {
       dom.data[i].emplace_back(IntSet::range(
           Range::make_by_min_extent(
@@ -99,15 +105,20 @@ void HybridOpNode::PropBoundToInputs(
 }
 
 void HybridOpNode::GatherBound(
-    const Operation& self,
-    const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+    const Operation &self,
+    const std::unordered_map<Tensor, TensorDom> &tensor_dom,
     std::unordered_map<IterVar, Range>* out_dom_map) const {
+  for (auto iter_var : axis) {
+    CHECK(!out_dom_map->count(iter_var));
+    out_dom_map->operator[](iter_var) = iter_var->dom;
+  }
 }
 
 Stmt HybridOpNode::BuildRealize(
-    const Stage& stage,
-    const std::unordered_map<IterVar, Range>& realize_map,
-    const Stmt& body) const {
+    const Stage &stage,
+    const std::unordered_map<IterVar, Range> &realize_map,
+    const Stmt &body) const {
+  // TODO(@were): Add attribute inject here and remove it from hybrid parser.
   CHECK_EQ(stage->op.get(), this);
   Stmt realize_body = body;
   for (int k = 0; k < num_outputs(); ++k) {
@@ -126,8 +137,8 @@ Stmt HybridOpNode::BuildRealize(
 }
 
 Stmt HybridOpNode::BuildProvide(
-    const Stage& stage,
-    const std::unordered_map<IterVar, Range>& dom_map,
+    const Stage &stage,
+    const std::unordered_map<IterVar, Range> &dom_map,
     bool debug_keep_trivial_loop) const {
   CHECK_EQ(stage->op.operator->(), this);
   Stmt ret = AttrStmt::make(make_zero(Int(32)), attr::extern_scope, 0, this->body);
@@ -184,6 +195,302 @@ Stmt HybridOpNode::BuildProvide(
  * */
   ret = op::ReplaceTensor(ret, rmap);
   ret = op::ReplaceProvideTensor(ret, rmap);
+
+  ret = op::ApplySchedule(stage, dom_map, ret);
   return ret;
 }
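``BuildProvide`` now pipes the emitted body through ``op::ApplySchedule`` below, so primitives applied to a hybrid stage actually rewrite its loops during lowering. A sketch of the observable effect, adapted from ``test_schedule`` at the end of this diff:

    import tvm
    from tvm.hybrid import script

    @script
    def outer_product(a, b):
        c = output_tensor((64, 64), a.dtype)
        for i in range(64):
            for j in range(64):
                c[i, j] = a[i] * b[j]
        return c

    a = tvm.placeholder((64, ), name='a', dtype='float32')
    b = tvm.placeholder((64, ), name='b', dtype='float32')
    c = outer_product(a, b)

    sch = tvm.create_schedule(c.op)
    i, j = c.op.axis
    jo, ji = sch[c].split(j, 4)
    sch[c].vectorize(ji)
    # The printed IR contains the split loop nest with a vectorized j.inner.
    print(tvm.lower(sch, [a, b, c], simple_mode=True))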
+
+namespace op {
+
+
+Stmt ApplyLoopShapes(const Stage &stage,
+                     const std::unordered_map<IterVar, Range> &dom_map, Stmt stmt) {
+  class LoopSpliter : public IRMutator {
+    Expr factor;
+    const Variable *parent;
+    IterVar inner, outer;
+
+   public:
+    bool splitted;
+    LoopSpliter(const SplitNode *split,
+                const std::unordered_map<IterVar, Range> &dom_map) :
+      factor(split->factor), splitted(false) {
+      parent = split->parent->var.get();
+
+      auto &inner_ = split->inner;
+      CHECK(dom_map.count(inner_));
+      auto &inner_dom = dom_map.find(inner_)->second;
+      CHECK(is_const_int(inner_dom->min, 0));
+
+      auto &outer_ = split->outer;
+      CHECK(dom_map.count(outer_));
+      auto &outer_dom = dom_map.find(outer_)->second;
+      CHECK(is_const_int(outer_dom->min, 0));
+
+      inner = IterVarNode::make(inner_dom, inner_->var, inner_->iter_type);
+      outer = IterVarNode::make(outer_dom, outer_->var, outer_->iter_type);
+    }
+
+    Stmt Mutate_(const For *op, const Stmt &stmt) {
+      if (op->loop_var.get() == parent) {
+        std::unordered_map<const Variable *, Expr> rmap;
+        rmap[op->loop_var.get()] = inner + outer * factor;
+        Stmt ret = ir::Substitute(op->body, rmap);
+        Expr cond = likely(outer * factor < (op->extent - inner));
+        ret = IfThenElse::make(cond, ret);
+        ret = For::make(inner->var, Expr(0), inner->dom->extent,
+                        IterVarTypeToForType(inner->iter_type), op->device_api, ret);
+        ret = For::make(outer->var, Expr(0), outer->dom->extent,
+                        IterVarTypeToForType(outer->iter_type), op->device_api, ret);
+        splitted = true;
+        return ret;
+      }
+      return IRMutator::Mutate_(op, stmt);
+    }
+  };
+
+  class LoopFuser : public IRMutator {
+    const IterVar &parent;
+    const Variable *inner;
+    const Variable *outer;
+    bool under_outer;
+    Expr extent;
+
+   public:
+    bool fused;
+    explicit LoopFuser(const FuseNode *fuse_)
+      : parent(fuse_->fused), inner(fuse_->inner->var.get()),
+        outer(fuse_->outer->var.get()), under_outer(false),
+        extent(0), fused(false) {}
+
+    // TODO(@were): Handle imperfect loops
+
+    Stmt Mutate_(const For *op, const Stmt &stmt) {
+      if (op->loop_var.get() == inner) {
+        CHECK(under_outer);
+        std::unordered_map<const Variable *, Expr> rmap;
+        rmap[op->loop_var.get()] = parent % op->extent;
+        extent = op->extent;
+        fused = true;
+        return ir::Substitute(op->body, rmap);
+      } else if (op->loop_var.get() == outer) {
+        under_outer = true;
+        Stmt body = IRMutator::Mutate(op->body);
+        std::unordered_map<const Variable *, Expr> rmap;
+        rmap[op->loop_var.get()] = parent / extent;
+        body = ir::Substitute(body, rmap);
+        under_outer = false;
+        return For::make(parent->var, Expr(0), extent * op->extent,
+                         op->for_type, op->device_api, body);
+      } else if (under_outer) {
+        Stmt body = IRMutator::Mutate(op->body);
+        std::unordered_map<const Variable *, Expr> rmap;
+        rmap[op->loop_var.get()] = parent / extent % op->extent;
+        body = ir::Substitute(body, rmap);
+        extent = extent * op->extent;
+        return body;
+      }
+      return IRMutator::Mutate_(op, stmt);
+    }
+  };
+
+  for (auto &rel : stage->relations) {
+    if (const SplitNode *split = rel.as<SplitNode>()) {
+      LoopSpliter Spliter(split, dom_map);
+      stmt = Spliter.Mutate(stmt);
+      CHECK(Spliter.splitted);
+    } else if (const FuseNode *fuse = rel.as<FuseNode>()) {
+      LoopFuser Fuser(fuse);
+      stmt = Fuser.Mutate(stmt);
+      CHECK(Fuser.fused);
+    }
+  }
+
+  return stmt;
+}
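The index arithmetic that ``LoopFuser`` relies on can be sanity-checked with plain integers; a minimal sketch (``outer_extent`` and ``inner_extent`` are illustrative names):

    # Fuse: loops over o and i collapse into one loop over f, and the body
    # recovers o = f / inner_extent and i = f % inner_extent, mirroring the
    # substitutions made by LoopFuser above.
    outer_extent, inner_extent = 8, 5
    fused = [(f // inner_extent, f % inner_extent)
             for f in range(outer_extent * inner_extent)]
    original = [(o, i) for o in range(outer_extent) for i in range(inner_extent)]
    assert fused == original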
+
+Stmt ApplyLoopAnnotations(const Stage &stage,
+                          const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) {
+  class LoopAnnotator : public IRMutator {
+    const Variable *var;
+    const IterVarAttr &attr;
+
+   public:
+    LoopAnnotator(const Variable *var_, const IterVarAttr &attr_) : var(var_), attr(attr_) {}
+
+    Stmt Mutate_(const For *op, const Stmt &stmt) {
+      if (op->loop_var.get() == var) {
+        if (attr->bind_thread.defined()) {
+          const auto &iter_var = attr->bind_thread;
+          if (iter_var->dom.defined()) {
+            CHECK(is_const_int(iter_var->dom->min, 0));
+            CHECK(Equal(iter_var->dom->extent, op->extent))
+              << "Thread extent and loop extent mismatch!\n";
+          }
+          std::unordered_map<const Variable *, Expr> rmap;
+          rmap[op->loop_var.get()] = iter_var;
+          Stmt body = ir::Substitute(op->body, rmap);
+          return AttrStmt::make(iter_var, "thread_extent", op->extent, body);
+        } else {
+          return For::make(op->loop_var, op->min, op->extent,
+                           IterVarTypeToForType(attr->iter_type), op->device_api, op->body);
+        }
+      }
+      return IRMutator::Mutate_(op, stmt);
+    }
+  };
+
+  for (auto &iter_var : stage->leaf_iter_vars) {
+    bool need_change = false;
+    int found = 0;
+
+    const IterVar &actual = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
+    const Variable *var = actual->var.get();
+    ForType expected = IterVarTypeToForType(iter_var->iter_type);
+    IterVarAttr attr;
+    if (stage->iter_var_attrs.count(iter_var)) {
+      attr = stage->iter_var_attrs[iter_var];
+      expected = IterVarTypeToForType(attr->iter_type);
+    }
+
+    PostOrderVisit(stmt, [&found, &var, &attr, &expected, &need_change](const NodeRef &node) {
+      if (const For *op = node.as<For>()) {
+        if (op->loop_var.get() == var) {
+          ++found;
+          need_change = expected != op->for_type || (attr.defined() && attr->bind_thread.defined());
+        }
+      }
+    });
+
+    CHECK_EQ(found, 1) << " iter var should be found exactly once!";
+    if (need_change) {
+      stmt = LoopAnnotator(var, attr).Mutate(stmt);
+    }
+  }
+  return stmt;
+}
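``bind`` annotations are handled by ``LoopAnnotator`` above: the ``For`` node is replaced by a ``thread_extent`` ``AttrStmt``. A usage sketch, taken almost verbatim from the CUDA test below:

    import tvm
    from tvm.hybrid import script

    @script
    def raw(a, b):
        c = output_tensor((1000, ), 'float32')
        for i in range(1000):
            c[i] = a[i] + b[i]
        return c

    a = tvm.placeholder((1000, ), dtype='float32', name='a')
    b = tvm.placeholder((1000, ), dtype='float32', name='b')
    c = raw(a, b)

    sch = tvm.create_schedule(c.op)
    # The loop extent (1000) must match the thread extent; LoopAnnotator
    # checks this before rewriting the loop into a thread_extent AttrStmt.
    sch[c].bind(c.op.axis[0], tvm.thread_axis('threadIdx.x'))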
+
+Stmt ApplyLoopOrder(const Stage &stage,
+                    const std::unordered_map<IterVar, Range> &dom_map,
+                    const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt) {
+  std::vector<const Variable*> current_order;
+  PostOrderVisit(stmt, [&current_order](const NodeRef &node) {
+    if (const For *op = node.as<For>())
+      current_order.push_back(op->loop_var.get());
+  });
+  std::reverse(current_order.begin(), current_order.end());
+  auto &required_ord = stage->leaf_iter_vars;
+  CHECK_EQ(current_order.size(), required_ord.size()) << "Cannot reorder the loops!";
+  std::unordered_map<const Variable*, IterVar> reorder;
+  bool need_reorder = false;
+  for (size_t i = 0; i < current_order.size(); ++i) {
+    auto &current = current_order[i];
+    const IterVar &iter_var = required_ord[i];
+    const IterVar &required = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var;
+    CHECK(required->dom.defined() || dom_map.count(required)) << required << "\n";
+    reorder[current] = required;
+    if (current != required->var.get()) {
+      need_reorder = true;
+    }
+  }
+
+  class LoopReorder : public IRMutator {
+    const Stage &stage;
+    const std::unordered_map<IterVar, Range> &dom_map;
+    const std::unordered_map<const Variable*, IterVar> &reorder;
+
+   public:
+    LoopReorder(const Stage &stage,
+                const std::unordered_map<IterVar, Range> &dom_map,
+                const std::unordered_map<const Variable*, IterVar> &reorder)
+      : stage(stage), dom_map(dom_map), reorder(reorder) {}
+
+    Stmt Mutate_(const For *op, const Stmt &stmt) {
+      // Reorder from in to out
+      Stmt body_ = IRMutator::Mutate(op->body);
+      CHECK(reorder.count(op->loop_var.get()));
+      auto target = reorder.find(op->loop_var.get())->second;
+      if (body_.same_as(op->body) && op->loop_var.get() == target->var.get())
+        return stmt;
+      const Stmt &body = op->body.same_as(body_) ? op->body : body_;
+      ForType for_type = IterVarTypeToForType(target->iter_type);
+      if (stage->iter_var_attrs.count(target)) {
+        for_type = IterVarTypeToForType(stage->iter_var_attrs[target]->iter_type);
+      }
+      const Range &range = target->dom.defined() ? target->dom : dom_map.find(target)->second;
+      return For::make(target->var, range->min, range->extent,
+                       for_type, HalideIR::DeviceAPI::None, body);
+    }
+  };
+
+  if (need_reorder)
+    return LoopReorder(stage, dom_map, reorder).Mutate(stmt);
+
+  return stmt;
+}
+
+Stmt ApplySchedule(const Stage &stage,
+                   const std::unordered_map<IterVar, Range> &dom_map, Stmt stmt) {
+  // TODO(@were): Eliminate loop rebase in script parser and move the burden here
+  // Gather rebased variables
+  std::unordered_map<IterVar, IterVar> rebased;
+  for (auto rel : stage->relations) {
+    if (auto rebase = rel.as<RebaseNode>()) {
+      rebased[rebase->rebased] = rebase->parent;
+      CHECK(rebase->parent->dom.defined());
+      CHECK(dom_map.count(rebase->rebased));
+    }
+  }
+  stmt = ApplyLoopShapes(stage, dom_map, stmt);
+  stmt = ApplyLoopOrder(stage, dom_map, rebased, stmt);
+  stmt = ApplyLoopAnnotations(stage, rebased, stmt);
+  return stmt;
+}
+
+std::vector<IterVar> GatherLoopVars(Stmt stmt) {
+  // TODO(@were): Write a comprehensive pass to analyze iter var types
+  std::vector<IterVar> res_;
+  PostOrderVisit(stmt, [&res_](const NodeRef &node) {
+    if (const For *op = node.as<For>()) {
+      Var loop_var(op->loop_var);
+      Range dom = Range::make_by_min_extent(op->min, op->extent);
+      res_.push_back(IterVarNode::make(dom, loop_var, ForTypeToIterVarType(op->for_type)));
+    }
+  });
+  std::reverse(res_.begin(), res_.end());
+  return res_;
+}
+
+// replacer to replace tensors' usage in Provide
+class ProviderReplacer : public ir::IRMutator {
+ public:
+  explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor> &vmap)
+      : vmap_(vmap) {}
+
+  Stmt Mutate_(const ir::Provide* op, const Stmt &s) {
+    Tensor t = Operation(op->func.node_).output(op->value_index);
+    auto it = vmap_.find(t);
+    if (it != vmap_.end()) {
+      Stmt ret = ir::Provide::make(
+          it->second->op, it->second->value_index, op->value, op->args);
+      found = true;
+      return IRMutator::Mutate_(ret.as<ir::Provide>(), ret);
+    }
+    return IRMutator::Mutate_(op, s);
+  }
+
+  // whether any replacement was made.
+  bool found{false};
+
+ private:
+  const std::unordered_map<Tensor, Tensor> &vmap_;
+};
+
+Stmt ReplaceProvideTensor(Stmt stmt,
+                          const std::unordered_map<Tensor, Tensor> &replace) {
+  ProviderReplacer repl(replace);
+  Stmt ret = repl.Mutate(stmt);
+  return repl.found ? ret : stmt;
+}
+}  // namespace op
 }  // namespace tvm
diff --git a/src/op/hybrid_op.h b/src/op/hybrid_op.h
new file mode 100644
index 000000000000..892e420137d6
--- /dev/null
+++ b/src/op/hybrid_op.h
@@ -0,0 +1,80 @@
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \brief Helper utilities to implement hybrid_op.
+ * \file hybrid_op.h
+ */
+#ifndef TVM_OP_HYBRID_OP_H_
+#define TVM_OP_HYBRID_OP_H_
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "../pass/ir_util.h"
+#include "../pass/arg_binder.h"
+#include "../schedule/message_passing.h"
+
+
+namespace tvm {
+namespace op {
+
+/*!
+ * \brief Find all the iteration variables in the given statement body.
+ * \param stmt The body to be inspected.
+ */
+std::vector<IterVar> GatherLoopVars(Stmt stmt);
+
+/*!
+ * \brief Replace the tensor reference (especially in Provide's) in stmt by the replace map.
+ * \param stmt The statement to be processed.
+ * \param replace The replacement rule.
+ */
+Stmt ReplaceProvideTensor(Stmt stmt,
+                          const std::unordered_map<Tensor, Tensor>& replace);
+
+/*!
+ * \brief Apply the schedule manipulation on the function body.
+ * \param stage The schedule information to be applied.
+ * \param dom_map The extents of the iteration variables that may be used.
+ * \param stmt The statement to be processed.
+ */
+Stmt ApplySchedule(const Stage& stage,
+                   const std::unordered_map<IterVar, Range>& dom_map, Stmt stmt);
+
+/*!
+ * \brief Apply loop splits and fuses in the schedule on the function body.
+ * \param stage The schedule information to be applied.
+ * \param dom_map The extents of the iteration variables that may be used.
+ * \param stmt The statement to be processed.
+ */
+Stmt ApplyLoopShapes(const Stage &stage,
+                     const std::unordered_map<IterVar, Range>& dom_map, Stmt stmt);
+
+
+/*!
+ * \brief Apply loop annotations in the schedule on the function body.
+ * \param stage The schedule information to be applied.
+ * \param rebased The map specifying the rebase, a.k.a. rename, relationship of these variables.
+ * \param stmt The statement to be processed.
+ */
+Stmt ApplyLoopAnnotations(const Stage &stage,
+                          const std::unordered_map<IterVar, IterVar>& rebased, Stmt stmt);
+
+/*!
+ * \brief Apply the loop order in the schedule on the function body.
+ * \param stage The schedule information to be applied.
+ * \param dom_map The extents of the iteration variables that may be used.
+ * \param rebased The map specifying the rebase, a.k.a. rename, relationship of these variables.
+ * \param stmt The statement to be processed.
+ */
+Stmt ApplyLoopOrder(const Stage &stage,
+                    const std::unordered_map<IterVar, Range> &dom_map,
+                    const std::unordered_map<IterVar, IterVar> &rebased, Stmt stmt);
+
+}  // namespace op
+}  // namespace tvm
+
+#endif  // TVM_OP_HYBRID_OP_H_
diff --git a/src/op/op_util.cc b/src/op/op_util.cc
index 886f7c912303..b18552d5c562 100644
--- a/src/op/op_util.cc
+++ b/src/op/op_util.cc
@@ -164,38 +164,6 @@ std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) {
   return nest;
 }
 
-// replacer to replace tensors' usage in Provide
-class ProviderReplacer : public ir::IRMutator {
- public:
-  explicit ProviderReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
-      : vmap_(vmap) {}
-
-  Stmt Mutate_(const ir::Provide* op, const Stmt& s) {
-    Tensor t = Operation(op->func.node_).output(op->value_index);
-    auto it = vmap_.find(t);
-    if (it != vmap_.end()) {
-      Stmt ret = ir::Provide::make(
-          it->second->op, it->second->value_index, op->value, op->args);
-      found = true;
-      return IRMutator::Mutate_(ret.as<ir::Provide>(), ret);
-    }
-    return IRMutator::Mutate_(op, s);
-  }
-
-  // whether it is found.
-  bool found{false};
-
- private:
-  const std::unordered_map<Tensor, Tensor>& vmap_;
-};
-
-Stmt ReplaceProvideTensor(Stmt stmt,
-                          const std::unordered_map<Tensor, Tensor>& replace) {
-  ProviderReplacer repl(replace);
-  Stmt ret = repl.Mutate(stmt);
-  return repl.found ? ret : stmt;
-}
-
 // replacer to replace tensors
 class TensorReplacer : public ir::IRMutator {
  public:
@@ -247,5 +215,35 @@ Stmt Substitute(Stmt s,
   return ir::Substitute(s, init);
 }
 
+IterVarType ForTypeToIterVarType(ir::ForType for_type) {
+  switch (for_type) {
+  case ForType::Serial:
+    return kDataPar;
+  case ForType::Parallel:
+    return kParallelized;
+  case ForType::Vectorized:
+    return kVectorized;
+  case ForType::Unrolled:
+    return kUnrolled;
+  default:
+    return kDataPar;
+  }
+}
+
+ir::ForType IterVarTypeToForType(IterVarType iter_type) {
+  switch (iter_type) {
+  case kDataPar:
+    return ForType::Serial;
+  case kParallelized:
+    return ForType::Parallel;
+  case kVectorized:
+    return ForType::Vectorized;
+  case kUnrolled:
+    return ForType::Unrolled;
+  default:
+    return ForType::Serial;
+  }
+}
+
 }  // namespace op
 }  // namespace tvm
diff --git a/src/op/op_util.h b/src/op/op_util.h
index 6971f14eef73..de2e44c2ed59 100644
--- a/src/op/op_util.h
+++ b/src/op/op_util.h
@@ -48,14 +48,6 @@ MakeLoopNest(const Stage& stage,
  */
 std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates);
 
-/*!
- * \brief Replace the tensor reference (especially in Provide's) in stmt by the replace map.
- * \param stmt The statement to be processed.
- * \param replace The replacement rule.
- */
-Stmt ReplaceProvideTensor(Stmt stmt,
-                          const std::unordered_map<Tensor, Tensor>& replace);
-
 /*!
  * \brief Replace the tensor reference (especially in Call's) in stmt by the replace map.
  * \param stmt The statement to be processed.
@@ -80,6 +72,18 @@ Expr ReplaceTensor(Expr expr,
 Stmt Substitute(Stmt stmt,
                 const std::unordered_map<IterVar, Expr>& value_map);
 
+/*!
+ * \brief Converts Halide ForType to its corresponding IterVarType
+ * \param for_type The ForType to be converted
+ */
+IterVarType ForTypeToIterVarType(ir::ForType for_type);
+
+/*!
+ * \brief Converts IterVarType to its corresponding Halide ForType
+ * \param iter_type The IterVarType to be converted
+ */
+ir::ForType IterVarTypeToForType(IterVarType iter_type);
+
 }  // namespace op
 }  // namespace tvm
 #endif  // TVM_OP_OP_UTIL_H_
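Both conversions are deliberately lossy: any type outside the four mapped cases falls back to ``Serial``/``kDataPar``, so, for example, a thread-bound axis lowers to a plain serial loop. A small pure-Python model of that fallback (the dictionaries are illustrative, not a TVM API):

    # Model of the ForType <-> IterVarType mapping added to op_util.cc.
    FOR_TO_ITER = {'Serial': 'kDataPar', 'Parallel': 'kParallelized',
                   'Vectorized': 'kVectorized', 'Unrolled': 'kUnrolled'}
    ITER_TO_FOR = {v: k for k, v in FOR_TO_ITER.items()}

    def iter_var_type(for_type):
        return FOR_TO_ITER.get(for_type, 'kDataPar')

    def for_type(iter_type):
        return ITER_TO_FOR.get(iter_type, 'Serial')

    # The round trip is exact for the four mapped cases...
    assert all(for_type(iter_var_type(f)) == f for f in FOR_TO_ITER)
    # ...but anything else (e.g. a thread-bound axis) falls back to Serial.
    assert for_type('kThreadIndex') == 'Serial'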
diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py
index 668b1598446b..a54fec3a7bf7 100644
--- a/tests/python/unittest/test_hybrid_script.py
+++ b/tests/python/unittest/test_hybrid_script.py
@@ -3,7 +3,7 @@ from tvm.hybrid.intrin import HYBRID_GLOBALS
 
 @nose.tools.nottest
-def run_and_check(func, args, var_dict={}, target='llvm'):
+def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None):
     def tvm_val_2_py_val(val):
         val = tvm.ir_pass.Substitute(val, var_dict)
         val = tvm.ir_pass.Simplify(val)
@@ -13,8 +13,14 @@ def tvm_val_2_py_val(val):
     ctx = tvm.context(target, 0)
     op = None
 
-    outs = func(*tuple(tvm.convert(i) if isinstance(i, list) else i for i in args))
-    op = outs[0].op if isinstance(outs, list) else outs.op
+    if sch is None:
+        outs = func(*tuple(tvm.convert(i) if isinstance(i, list) else i for i in args))
+        op = outs[0].op if isinstance(outs, list) else outs.op
+        sch = tvm.create_schedule(op)
+    else:
+        assert outs is not None
+        assert isinstance(outs, list)
+        op = outs[0].op
 
     emu_args = []
     nd_args = []
@@ -30,13 +36,13 @@ def tvm_val_2_py_val(val):
             assert isinstance(i, list)
             emu_args.append(numpy.array(i))
 
-    sch = tvm.create_schedule(op)
+    compile_args = [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.expr.Var))] + \
+                   (outs if isinstance(outs, list) else [outs])
     module = tvm.build(sch,
-                       [i for i in args if isinstance(i, (tvm.tensor.Tensor, tvm.expr.Var))] + \
-                       (outs if isinstance(outs, list) else [outs]),
+                       compile_args,
                        target=target)
     assert module
-
+
     out_tensors = []
     for i in range(op.num_outputs):
         output = op.output(i)
@@ -47,7 +53,7 @@ def tvm_val_2_py_val(val):
     ref_data = func(*emu_args)
     if isinstance(ref_data, numpy.ndarray):
         ref_data = [ref_data]
-
+
     module(*nd_args)
 
     for nd, np in zip(out_tensors, ref_data):
@@ -282,9 +288,38 @@ def vec_add(a, b):
     a = tvm.placeholder((1000, ), dtype='float32', name='a')
     b = tvm.placeholder((1000, ), dtype='float32', name='b')
 
-    run_and_check(vec_add, [a, b], target='cuda')
+    @script
+    def raw(a, b):
+        c = output_tensor((1000, ), 'float32')
+        for i in range(1000):
+            c[i] = a[i] + b[i]
+        return c
+
+    c = raw(a, b)
+    sch = tvm.create_schedule(c.op)
+    x = tvm.thread_axis('threadIdx.x')
+    sch[c].bind(c.op.axis[0], x)
+    run_and_check(raw, [a, b], sch=sch, outs=[c], target='cuda')
+
+    # Test loop binds
+    @tvm.hybrid.script
+    def goo(a, b):
+        c = output_tensor(a.shape, a.dtype)
+        len_b = len(b)
+        for i in const_range(len_b * 2):
+            if i < len_b:
+                c[i] = a[i] + b[i]
+            else:
+                c[i - len_b] = a[i - len_b] + b[i - len_b]
+        return c
+    a = tvm.placeholder((5, ), name='a', dtype='int32')
+    b = [1, 2, 3, 4, 5]
+    c = goo(a, tvm.convert(b))
+    sch = tvm.create_schedule(c.op)
+    run_and_check(goo, [a, b], sch=sch, outs=[c])
+
 
 def test_math_intrin():
     @script
     def intrin_real(a):
@@ -593,6 +628,68 @@ def hoo(a, b):
     b = [1, 2, 3, 4, 5]
     run_and_check(hoo, [a, b])
 
+def test_schedule():
+    @script
+    def outer_product(a, b):
+        c = output_tensor((64, 64), a.dtype)
+        for i in range(64):
+            for j in range(64):
+                c[i, j] = a[i] * b[j]
+        return c
+    a = tvm.placeholder((64,), name='a', dtype='float32')
+    b = tvm.placeholder((64,), name='b', dtype='float32')
+    c = outer_product(a, b)
+
+    # Test perfect loop split
+    # Test loop reorder
+    # Test loop annotation
+    sch = tvm.create_schedule(c.op)
+    i, j = c.op.axis
+    io, ii = sch[c].split(i, 4)
+    sch[c].parallel(ii)
+    jo, ji = sch[c].split(j, 4)
+    joo, joi = sch[c].split(jo, 4)
+    sch[c].vectorize(ji)
+    sch[c].reorder(ii, io, joo, joi, ji)
+    ir = tvm.lower(sch, [a, b, c], simple_mode=True)
+    assert isinstance(ir, tvm.stmt.ProducerConsumer)
+    ir = ir.body
+    assert isinstance(ir, tvm.stmt.AttrStmt)
+    ir = ir.body
+    assert isinstance(ir, tvm.stmt.For)
+    assert ir.loop_var.name == 'i.inner'
+    ir = ir.body
+    assert isinstance(ir, tvm.stmt.For)
+    assert ir.loop_var.name == 'i.outer'
+    ir = ir.body
+    assert isinstance(ir, tvm.stmt.For)
+    assert ir.loop_var.name == 'j.outer.outer'
+    ir = ir.body
+    assert isinstance(ir, tvm.stmt.For)
+    assert ir.loop_var.name == 'j.outer.inner'
+    ir = ir.body
+    run_and_check(outer_product, [a, b], sch=sch, outs=[c])
+
+    # Test fuse
+    sch = tvm.create_schedule(c.op)
+    sch[c].fuse(c.op.axis[0], c.op.axis[1])
+    ir = tvm.lower(sch, [a, b, c], simple_mode=True)
+    assert isinstance(ir, tvm.stmt.ProducerConsumer)
+    ir = ir.body
+    assert isinstance(ir, tvm.stmt.AttrStmt)
+    ir = ir.body
+    assert isinstance(ir, tvm.stmt.For)
+    assert ir.loop_var.name == 'i.j.fused'
+    run_and_check(outer_product, [a, b], sch=sch, outs=[c])
+
+    # Test imperfect loop split
+    sch = tvm.create_schedule(c.op)
+    sch[c].split(c.op.axis[0], 3)
+    ir = tvm.lower(sch, [a, b, c], simple_mode=True)
+    run_and_check(outer_product, [a, b], sch=sch, outs=[c])
+
+    # Test loop binds
+
 if __name__ == "__main__":
     test_outer_product()
@@ -610,5 +707,6 @@ def hoo(a, b):
     test_func_call()
     test_bool()
     test_const_range()
+    test_schedule()
     # TODO:
     # test_inplace()