From 7ee0902f86c61c010f3760c53e388efbd38783dd Mon Sep 17 00:00:00 2001 From: Chenfan Date: Tue, 26 May 2020 13:42:13 +0800 Subject: [PATCH 01/45] Code migration Start (#1) * Init commit: Code migration Start * Add loop_state.cc/h * Add ComputeDAG basic test --- CMakeLists.txt | 1 + src/ansor/compute_dag.cc | 1245 +++++++++++++++++++++++++++ src/ansor/compute_dag.h | 161 ++++ src/ansor/expr_hasher.h | 97 +++ src/ansor/loop_state.cc | 1729 ++++++++++++++++++++++++++++++++++++++ src/ansor/loop_state.h | 732 ++++++++++++++++ src/ansor/utils.cc | 102 +++ src/ansor/utils.h | 482 +++++++++++ tests/cpp/ansor_test.cc | 95 +++ 9 files changed, 4644 insertions(+) create mode 100644 src/ansor/compute_dag.cc create mode 100644 src/ansor/compute_dag.h create mode 100644 src/ansor/expr_hasher.h create mode 100644 src/ansor/loop_state.cc create mode 100644 src/ansor/loop_state.h create mode 100644 src/ansor/utils.cc create mode 100644 src/ansor/utils.h create mode 100644 tests/cpp/ansor_test.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index d7faa8a4b666..5550b5f6b3a8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -185,6 +185,7 @@ assign_source_group("Include" ${GROUP_INCLUDE}) # Source file lists file(GLOB_RECURSE COMPILER_SRCS + src/ansor/*.cc src/node/*.cc src/ir/*.cc src/arith/*.cc diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc new file mode 100644 index 000000000000..31136985b330 --- /dev/null +++ b/src/ansor/compute_dag.cc @@ -0,0 +1,1245 @@ +/*! + * Copyright (c) 2020 by Contributors + */ +#include "compute_dag.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +// #include "loop_state.h" +#include "utils.h" +// #include "../relay/pass/kernel_layout_transform.h" + +namespace tvm { +namespace ansor { + +using namespace tvm::tir; + +TVM_REGISTER_NODE_TYPE(ComputeDAGNode); + +template +using OperationMap = AccessAnalyzerNode::OperationMap; + +using OperationSet = std::unordered_set; + +// Topo-sort ops from tensors according to their read-write relations. 
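+// (This is Kahn's algorithm: each op's in-degree is the number of tensors it
+// reads; ops whose in-degree drops to zero are emitted, with ties broken
+// deterministically by the DFS discovery priority recorded below.)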
+// Results are stored in ops +void TopoSortOps(const Array& tensors, std::vector* ops) { + std::unordered_map degree; + std::unordered_map > edge_set; + std::unordered_map priority; + std::unordered_set visited; + + // traverse to build edge_set and count degree + std::vector stack; + stack.reserve(tensors.size()); + for (const auto& x : tensors) { + stack.push_back(x->op.operator->()); + } + + int ct = 0; + while (!stack.empty()) { + const te::OperationNode* op = stack.back(); + stack.pop_back(); + if (visited.count(op)) { + continue; + } + + priority[op] = ct; + ct++; + visited.insert(op); + + if (op->IsInstance()) { + degree[op] = 0; + } else if (auto cop = GetRef(op).as()) { + const Array& input_tensors = cop->InputTensors(); + degree[op] = input_tensors.size(); + for (const auto& ten : input_tensors) { + edge_set[ten->op.operator->()].push_back(op); + stack.push_back(ten->op.operator->()); + } + } else { + LOG(FATAL) << "Unsupported op " << GetRef(op); + } + } + + // topo sort + ops->clear(); + + using Item = std::pair; + auto cmp = [](const Item& left, const Item& right) { + return left.second < right.second; + }; + std::priority_queue, decltype(cmp)> queue(cmp); + for (const auto& iter : degree) { + if (iter.second == 0) { + queue.push(Item(iter.first, priority[iter.first])); + } + } + + ops->reserve(degree.size()); + while (!queue.empty()) { + Item item = queue.top(); + queue.pop(); + ops->push_back(GetRef(item.first)); + for (const auto& dst : edge_set[item.first]) { + degree[dst] -= 1; + if (degree[dst] == 0) { + queue.push(Item(dst, priority[dst])); + } + } + } +} + +// Extract all tensor accesses in an expr +class TensorAccessExtractor : public StmtExprVisitor { + public: + void Extract(PrimExpr expr) { + this->VisitExpr(expr); + } + + void VisitExpr_(const CallNode *op) final { + if (op->call_type == CallNode::CallType::Halide) { + buf_accesses[Downcast(op->func)].emplace_back( + op->args.begin(), op->args.end()); + } + if (op->name == tir::intrinsic::tvm_if_then_else) { + has_branch = true; + } + StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const IfThenElseNode* op) final { + has_branch = true; + StmtExprVisitor::VisitStmt_(op); + } + + void VisitExpr_(const SelectNode* op) final { + has_branch = true; + StmtExprVisitor::VisitExpr_(op); + } + + OperationMap > > buf_accesses; + bool has_branch{false}; +}; + +// Returns whether the expr equals to the var with a const shift +bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) { + if (auto pv = expr.as()) { + return pv == var.get(); + } else if (auto padd = expr.as()) { + return ((padd->a.get() == var.get() && padd->b->IsInstance()) || + (padd->b.get() == var.get() && padd->a->IsInstance())); + } else if (auto psub = expr.as()) { + return ((psub->a.get() == var.get() && psub->b->IsInstance()) || + (psub->b.get() == var.get() && psub->a->IsInstance())); + } else { + return false; + } +} + +// Return whether the access is injective +bool IsInjective(const te::Operation& op, const std::vector& index, + bool* axis_missing, bool* axis_duplicated, bool* same_order) { + auto cop = op.as(); + if (cop == nullptr) { return false; } + + std::vector index_to_var_idx; + std::vector var_idx_ct(cop->axis.size(), 0); + + for (const auto& expr : index) { + if (!is_const(expr)) { + bool found = false; + for (size_t i = 0; i < cop->axis.size(); ++i) { + if (IsConstShiftEqual(cop->axis[i]->var, expr)) { + index_to_var_idx.push_back(i); + var_idx_ct[i]++; + found = true; + break; + } + } + if (!found) { + return false; + } + } + 
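+      // (A non-constant index that matches no axis variable makes the access
+      // non-injective and we return false above; constant indices are skipped.)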
} + + *axis_missing = false; // Some axes are missing + *axis_duplicated = false; // Some axes appear more than once + *same_order = true; // The axis order is the same as op->axis + for (int ct : var_idx_ct) { + if (ct == 0) { + *axis_missing = true; + } else if (ct > 1) { + *axis_duplicated = true; + } + } + for (size_t i = 1; i < index_to_var_idx.size(); ++i) { + if (index_to_var_idx[i] < index_to_var_idx[i - 1]) { + *same_order = false; + break; + } + } + + return true; +} + +// Gather all VarNodes in an expr +static void GatherVars(const PrimExpr& expr, std::unordered_set* vars) { + PostOrderVisit(expr, [&vars](const ObjectRef &node) { + if (const VarNode* op = node.as()) { + vars->insert(op); + } + }); +} + +// Check whether an expr has expensive operations (e.g. exp) +static bool HasExpensiveOp(const PrimExpr& expr) { + bool found = false; + PostOrderVisit(expr, [&found](const ObjectRef &node) { + if (const CallNode* op = node.as()) { + if (op->call_type == CallNode::CallType::PureIntrinsic && op->name == "exp") { + found = true; + } + } + }); + return found; +} + +AccessAnalyzer AccessAnalyzerNode::make(const Array& tensors) { + auto node = make_object(); + OperationMap has_branch; + + // get all ops + TopoSortOps(tensors, &node->ops_topo_order); + + // build read & write access map + for (const auto& op : node->ops_topo_order) { + if (op->IsInstance()) { + node->read_from[op] = OperationMap > >(); + } else if (auto cop = op.as()) { + TensorAccessExtractor extractor; + for (const auto& exp : cop->body) { + extractor.Extract(exp); + } + + for (const auto& iter : extractor.buf_accesses) { + std::vector >& accesses = node->read_by[iter.first][op]; + accesses.insert(accesses.begin(), iter.second.begin(), iter.second.end()); + } + + node->read_from[op] = std::move(extractor.buf_accesses); + has_branch[op] = extractor.has_branch; + } else { + LOG(FATAL) << "Invalid op: " << op; + } + } + + // do some static analysis + for (const auto& op : node->ops_topo_order) { + if (op->IsInstance()) { + node->is_injective[op] = true; + node->needs_multi_level_tiling[op] = false; + node->is_strict_inlineable[op] = false; + node->is_output[op] = false; + } else if (auto pop = op.as()) { + // check whether is element-wise and strict-inlineable (see definition in compute_dag.h) + bool is_injective = true; + bool is_strict_inlineable = true; + + bool axis_missing, axis_duplicated, same_order; + for (const auto& pair : node->read_from[op]) { + const std::vector >& access = pair.second; + for (const auto& index : access) { + if (!IsInjective(op, index, &axis_missing, &axis_duplicated, &same_order)) { + is_injective = false; + is_strict_inlineable = false; + break; + } + if (!same_order || axis_duplicated) { // do not strictly inline transpose + is_strict_inlineable = false; + } + } + if (!is_injective) { break; } + } + if (has_branch[op]) { + is_strict_inlineable = false; + } + + // don't strictly inline expensive op (e.g. 
exp) + bool has_expensive_op = false; + for (const auto& expr : pop->body) { + has_expensive_op |= HasExpensiveOp(expr); + } + + node->is_injective[op] = is_injective; + node->is_strict_inlineable[op] = is_strict_inlineable && !has_expensive_op; + + // check whether the op needs multi-level tiling (see definition in compute_dag.h) + bool needs_multi_level_tiling = false; + int n_missing = 0; + + for (const auto& pair : node->read_from[op]) { + const std::vector > &access = pair.second; + std::unordered_set vars; + for (const std::vector &indices : access) { + for (const PrimExpr& expr : indices) { + GatherVars(expr, &vars); + } + } + bool missing = false; + for (const auto& axis : pop->axis) { + if (GetIntImm(axis->dom->extent) > 1 && vars.count(axis->var.get()) == 0) { + missing = true; + } + } + if (missing) { + n_missing++; + } + + if (n_missing >= 2 || (n_missing >= 1 && !pop->reduce_axis.empty())) { + needs_multi_level_tiling = true; + break; + } + } + + node->needs_multi_level_tiling[op] = needs_multi_level_tiling; + + // check whether is output + node->is_output[op] = node->read_by[op].empty(); + } else { + LOG(FATAL) << "Invalid op" << op; + } + } + + return AccessAnalyzer(node); +} + +bool AccessAnalyzer::NeedsMultiLevelTiling(const te::Operation &op) const { + return operator->()->needs_multi_level_tiling.at(op); +} + +bool AccessAnalyzer::IsOutput(const te::Operation& op) const { + return operator->()->is_output.at(op); +} + +bool AccessAnalyzer::IsInjective(const te::Operation& op) const { + return operator->()->is_injective.at(op); +} + +bool AccessAnalyzer::IsStrictInlineable(const te::Operation &op) const { + return operator->()->is_strict_inlineable.at(op); +} + +void AccessAnalyzer::GetProducers(const State& state, const te::Operation& op, + OperationSet* producers) const { + producers->clear(); + for (const auto& iter : operator->()->read_from.at(op)) { + producers->insert(iter.first); + } +} + +// void AccessAnalyzer::GetConsumers(const State& state, const te::Operation& op, +// OperationSet* consumers) const { +// OperationSet inlined_ops; + +// for (const auto& stage : state->stages) { +// if (stage->compute_at == kInlined) { +// inlined_ops.insert(stage->op); +// } +// } +// std::function collect; + +// collect = [this, &collect, &inlined_ops, &consumers](const Operation& op) { +// for (const auto& iter : operator->()->read_by.at(op)) { +// if (inlined_ops.count(iter.first)) { +// collect(iter.first); +// } else { +// consumers->insert(iter.first); +// } +// } +// }; + +// consumers->clear(); +// collect(op); +// } + +bool IntArrayEqual(const Array& arr1, const Array& arr2) { + if (arr1.size() != arr2.size()) { + return false; + } + + for (size_t i = 0; i < arr1.size(); ++i) { + auto int1 = arr1[i].as(); + auto int2 = arr2[i].as(); + CHECK(int1 != nullptr); + CHECK(int2 != nullptr); + if (int1->value != int2->value) { + return false; + } + } + return true; +} + +bool AccessAnalyzer::ElementWiseMatch(const te::Operation& op, + const te::Operation& target_op) const { + te::Operation cur_op = op; + while (cur_op != target_op) { + const AccessAnalyzerNode::OperationMap > >& map = + operator->()->read_by.at(cur_op); + + if (map.size() != 1) { + return false; + } + te::Operation next_op = map.begin()->first; + + // Check condition 1: has the same output size + auto p_cur = cur_op.as(); + auto p_next = next_op.as(); + if (p_cur == nullptr || p_next == nullptr) { + return false; + } + + Array output_shape = p_cur->output_shape(0); + for (int i = 1; i < p_cur->num_outputs(); 
++i) { + if (!IntArrayEqual(p_cur->output_shape(i), output_shape)) { + return false; + } + } + for (int i = 0; i < p_next->num_outputs(); ++i) { + if (!IntArrayEqual(p_next->output_shape(i), output_shape)) { + return false; + } + } + + // Check condition 2: read is elementwise + const std::vector > reads = map.begin()->second; + bool is_injective, axis_missing, axis_duplicated, same_order; + for (const auto& read : reads) { + is_injective = ::tvm::ansor::IsInjective( + next_op, read, &axis_missing, &axis_duplicated, &same_order); + if (!is_injective || axis_missing || axis_duplicated || !same_order) { + return false; + } + } + + cur_op = std::move(next_op); + } + return true; +} + +// Estimate number of float operations in an expression +class FlopEstimator: public ExprFunctor { + public: + double EstimateFlop(const Array& ops) { + double ret = 0; + for (const auto& op : ops) { + if (auto pop = op.as()) { + double num_element = AxisLengthProd(pop->axis); + if (num_element == -1) { + fail = true; + break; + } + double op_per_element = 0; + for (const auto& x : pop->body) { + op_per_element += VisitExpr(x); + } + ret += num_element * op_per_element; + } else if (op->IsInstance()) { + {} // do nothing + } else { + LOG(FATAL) << "Invalid op type " << op; + } + } + + return fail ? -1 : ret; + } + + double VisitExpr_(const ReduceNode* op) final { + uint64_t num_iter = 1; + for (const auto& x : op->axis) { + if (auto imm = x->dom->extent.as()) { + num_iter *= imm->value; + } else { + fail = true; + num_iter = -1; + } + } + double body_flop = 0; + for (size_t i = 0; i < op->combiner->result.size(); ++i) { + body_flop += VisitExpr(op->combiner->result[i]); + body_flop += VisitExpr(op->source[i]); + } + return num_iter * body_flop; + } + + double VisitExpr_(const FloatImmNode* op) final { return 0.0; } + double VisitExpr_(const IntImmNode* op) final { return 0.0; } +// double VisitExpr_(const UIntImm* op) final { return 0.0; } + + double VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); } + double VisitExpr_(const VarNode* op) final { return 0.0; } + + double VisitExpr_(const SelectNode* op) final { + return VisitExpr(op->condition) + std::max(VisitExpr(op->true_value), + VisitExpr(op->false_value)); + } + +#define VisitBinary(Node) \ + double VisitExpr_(const Node* op) final { \ + return 1.0 + VisitExpr(op->a) + VisitExpr(op->b); \ + } +#define VisitUnary(Node) \ + double VisitExpr_(const Node* op) final { \ + return 1.0 + VisitExpr(op->a); \ + } + + VisitBinary(AddNode); VisitBinary(SubNode); VisitBinary(MulNode) + VisitBinary(DivNode); VisitBinary(ModNode); VisitBinary(FloorDivNode) + VisitBinary(FloorModNode); VisitBinary(MaxNode); VisitBinary(MinNode); + VisitBinary(EQNode); VisitBinary(NENode); VisitBinary(LTNode); + VisitBinary(LENode); VisitBinary(GTNode); VisitBinary(GENode); + VisitBinary(AndNode); VisitBinary(OrNode); VisitUnary(NotNode); + + double VisitExpr_(const CallNode* op) final { + if (op->call_type == CallNode::CallType::Halide) { + // ignore flops in index expressions + return 0.0; + } + + double ret = 0.0; + for (const auto&x : op->args) { + ret += VisitExpr(x); + } + return ret; + } + + double VisitExprDefault_(const Object* op) final { + fail = true; + return -1.0; + } + + bool fail{false}; +}; + +void UpdateStageAxis(const te::Stage& stage, StageToAxesMap *stage_to_axes) { + if (auto pop = stage->op.as()) { + std::vector& axes = (*stage_to_axes)[stage]; + axes.clear(); + for (const auto& axis : pop->axis) { + axes.push_back(axis); + } + for (const auto& 
axis : pop->reduce_axis) { + axes.push_back(axis); + } + } else if (stage->op->IsInstance()) { + {} // do nothing + } else { + LOG(FATAL) << "Invalid op " << stage->op; + } +} + +// State ComputeDAG::GetInitState() const { +// return Downcast(operator->()->init_state); +// } + +ComputeDAG ComputeDAGNode::make(Array tensors) { + auto node = make_object(); + FlopEstimator estimator; + + node->tensors = std::move(tensors); + node->access_analyzer = AccessAnalyzerNode::make(node->tensors); + node->ops = Array(node->access_analyzer->ops_topo_order); + node->flop_ct = estimator.EstimateFlop(node->ops); +// node->init_state = StateNode::make(node->ops); + + return ComputeDAG(node); +} + +ComputeDAG ComputeDAGNode::make_by_workload_key(const std::string& workload_key) { + Array tens; + // Call python function to decode the workload_key and get the I/O tensors + if (const auto* f = runtime::Registry::Get("ansor.workload_key_to_tensors")) { + tens = (*f)(workload_key); + } else { + LOG(FATAL) << "ansor.workload_key_to_tensors is not registered"; + } + return ComputeDAGNode::make(std::move(tens)); +} + +void ComputeDAGNode::VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("tensors", &tensors); + v->Visit("ops", &ops); + v->Visit("flop_ct", &flop_ct); + v->Visit("access_analyzer", &access_analyzer); +// State s = Downcast(init_state); +// v->Visit("init_state", &s); +} + +// Implemented in multi_stage_policy.cc +// Extract primitive iterators from a nested fused or splitted iterator's name +extern void ExtractOriginalIterators(const std::string& name, std::set* rets); + +// Implemented in loop_state.cc +extern std::string CleanName(const std::string& str); + +std::string BaseName(const std::string& str) { + return str.substr(0, str.rfind("_")); +} + +// class IndexRewriter : public ExprMutator { +// public: +// IndexRewriter(const OperationMap >& placeholder_new_names, +// const OperationMap >& placeholder_new_shapes): +// placeholder_new_names_(placeholder_new_names), +// placeholder_new_shapes_(placeholder_new_shapes) {} + +// Expr Mutate_(const Call* op, const Expr& e) { +// Expr op_ = IRMutator::Mutate_(op, e); + +// const Call* call = op_.as(); + +// if (call->call_type == Call::CallType::Halide) { +// Tensor t = Downcast(call->func).output(call->value_index); +// auto it = placeholder_new_names_.find(t->op); +// if (it != placeholder_new_names_.end()) { +// const std::vector& new_names = it->second; +// const Array& new_shape = placeholder_new_shapes_.at(t->op); +// std::unordered_map name_to_arg; +// for (const auto& arg : call->args) { +// std::string axis_name; +// if (const auto* pimm = arg.as()) { +// CHECK_EQ(pimm->value, 0); +// axis_name = "IntImm"; +// } else { +// axis_name = BaseName(CleanName(Downcast(arg)->name_hint)); +// CHECK_EQ(name_to_arg.count(axis_name), 0); +// name_to_arg[axis_name] = arg; +// } +// } + +// std::unordered_map div_factors; +// std::vector r_new_args; +// for (int i = new_names.size() - 1; i >= 0; --i) { +// auto ori_iter_name = new_names[i]; +// auto name_it = name_to_arg.find(ori_iter_name); +// CHECK(name_it != name_to_arg.end()); +// Expr ori_arg = name_it->second; + +// Expr mod_factor = new_shape[i]; + +// Expr div_factor = 1; +// if (div_factors.count(ori_iter_name)) { +// div_factor = div_factors[ori_iter_name]; +// } +// div_factors[ori_iter_name] = div_factor * new_shape[i]; + +// Expr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor); + +// r_new_args.push_back(new_arg); +// } + +// Array 
new_args(std::make_move_iterator(r_new_args.rbegin()), +// std::make_move_iterator(r_new_args.rend())); + +// return Call::make(call->type, call->name, new_args, call->call_type, +// call->func, call->value_index); +// } +// } +// return op_; +// } + +// private: +// const OperationMap >& placeholder_new_names_; +// const OperationMap >& placeholder_new_shapes_; +// }; + +// // TODO(minminsun): spill out new functions +// void ComputeDAG::RewriteLayout( +// const std::vector &transform_steps, LayoutRewriteLevel layout_rewrite_level) const { +// ComputeDAGNode* pdag = const_cast(this)->CopyOnWrite(); +// const State& state = ReplayAndInferBound(transform_steps); + +// OperationMap > placeholder_new_names; +// OperationMap > placeholder_new_shapes; +// int stage_id = -1; +// for (const auto& stage : state->stages) { +// stage_id += 1; +// const Operation& op = stage->op; +// if (op->IsInstance()) { +// const Map& attrs = op->attrs; +// if (attrs.count(_layout_free_placeholders_key)) { +// const ObjectRef& attr_value = attrs[_layout_free_placeholders_key]; +// Array placeholders = Downcast>(attr_value); +// for (auto& placeholder : placeholders) { +// const auto placeholder_op = placeholder->op; + +// // Check whether this placeholder has already been handled +// if (placeholder_new_names.count(placeholder_op)) { +// continue; +// } + +// // skip the op that is not direct consumer of this placeholder, +// // mostly due to cache read/write. +// bool direct_consumer = false; +// for (auto& t : op->InputTensors()) { +// if (t->op == placeholder_op) { +// direct_consumer = true; +// break; +// } +// } +// if (!direct_consumer) { +// continue; +// } + +// std::set placeholder_axis_names; +// TensorAccessExtractor extractor; +// for (const auto& exp : op.as()->body) { +// extractor.Extract(exp); +// } +// bool rewrite_placeholder = (layout_rewrite_level == kPlaceholderRewrite || +// layout_rewrite_level == kBothRewrite); +// bool rewrite_body = (layout_rewrite_level == kComputeRewrite || +// layout_rewrite_level == kBothRewrite); +// std::ostringstream os; + +// uint i = 0; +// if (extractor.buf_accesses.count(placeholder_op)) { +// for (const auto& ev : extractor.buf_accesses[placeholder_op]) { +// for (const auto& e : ev) { +// // TODO(minminsun): check whether the extents match the shape of placeholder +// std::string axis_name; +// if (const auto* pimm = e.as()) { +// CHECK_EQ(pimm->value, 0); +// // CHECK_EQ(placeholder->shape[i].as()->value, 1); +// axis_name = "IntImm"; +// } else { +// axis_name = BaseName(CleanName(Downcast(e)->name_hint)); +// } + +// placeholder_axis_names.insert(axis_name); +// if (rewrite_placeholder) { +// os << placeholder->shape[i++] << axis_name; +// } +// } +// } + +// if (rewrite_placeholder) { +// CHECK_EQ(placeholder_axis_names.size(), placeholder->shape.size()); +// std::string ori_layout = os.str(); +// os.str(""); +// ::tvm::relay::KernelLayoutVisitor::global_ori_layouts_queue.push_back(ori_layout); +// } +// } + +// std::vector stage_iters; + +// auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id); +// int attach_pos = -1; +// size_t iters_before_attach = 0; +// if (attach_it != state->attach_map->stage_to_attach_iter.end()) { +// auto attach = attach_it->second; +// const auto& attach_stage = state->stages[attach.first]; +// attach_pos = attach.second; +// stage_iters.insert(stage_iters.end(), +// attach_stage->iters.begin(), +// attach_stage->iters.begin() + attach_pos + 1); +// } + +// stage_iters.insert(stage_iters.end(), 
stage->iters.begin(), stage->iters.end()); + +// std::vector iters; +// for (size_t i = 0; i < stage_iters.size(); ++i) { +// const auto& iter = stage_iters[i]; +// if (iter->ori_iters.empty()) { +// iters.push_back(iter); +// } else { +// for (const Iterator& ori_iter : iter->ori_iters) { +// iters.push_back(ori_iter); +// } +// } +// if (static_cast(i) == attach_pos) { +// iters_before_attach = iters.size(); +// } +// } + +// std::vector new_names; +// Array new_shape; +// std::vector new_axis_names; +// for (const Iterator& iter : iters) { +// std::set ori_iter_names; +// ExtractOriginalIterators(iter->name, &ori_iter_names); +// // fused iters have been replaced with iter->ori_iters. +// // So there should be only one ori iter name extracted from iter->name. +// CHECK_EQ(ori_iter_names.size(), 1); +// auto ori_iter_name = BaseName(*ori_iter_names.begin()); +// new_axis_names.push_back(ori_iter_name); +// } +// for (size_t i = 0; i < new_axis_names.size(); ++i) { +// auto iter = iters[i]; +// std::string ori_iter_name; +// if (i < iters_before_attach) { +// ori_iter_name = new_axis_names[i + iters_before_attach]; +// } else { +// ori_iter_name = new_axis_names[i]; +// } +// if (placeholder_axis_names.count(ori_iter_name)) { +// os << iter->range->extent << ori_iter_name; +// new_names.push_back(ori_iter_name); +// new_shape.push_back(iter->range->extent); +// } +// } +// std::string new_layout = os.str(); +// os.str(""); +// ::tvm::relay::KernelLayoutVisitor::global_new_layouts_queue.push_back(new_layout); +// placeholder_new_names[placeholder_op] = new_names; +// placeholder_new_shapes[placeholder_op] = new_shape; + +// Array old_ops = pdag->ops; +// ArrayNode* pops = pdag->ops.CopyOnWrite(); + +// // Create new placeholder +// Operation new_placeholder_op; +// if (rewrite_placeholder) { +// new_placeholder_op = +// te::PlaceholderOpNode::make(placeholder_op->name, +// new_shape, +// placeholder_op.as()->dtype); +// } else { +// new_placeholder_op = placeholder_op; +// } + +// Operation new_compute_op, old_compute_op; +// if (rewrite_body) { +// Array new_body; +// IndexRewriter index_rewriter(placeholder_new_names, +// placeholder_new_shapes); +// for (auto& op : old_ops) { +// if (auto* pop = op.as()) { +// bool need_update = false; +// for (auto& t : op->InputTensors()) { +// if (t->op == placeholder_op) { +// need_update = true; +// break; +// } +// } +// if (need_update) { +// for (auto& body : pop->body) { +// new_body.push_back(index_rewriter.Mutate(body)); +// } +// old_compute_op = op; +// CHECK(!new_compute_op.defined()); +// new_compute_op = ComputeOpNode::make( +// pop->name, pop->tag, pop->attrs, pop->axis, new_body); +// } +// } +// } +// } + +// // construct the map from old_op to new_op +// std::unordered_map updated_ops; +// for (size_t i = 0; i < old_ops.size(); ++i) { +// auto old_op = old_ops[i]; +// if (rewrite_placeholder && old_op == placeholder_op) { +// pops->data[i] = new_placeholder_op; +// updated_ops[placeholder_op] = new_placeholder_op; +// } else if (rewrite_body && old_op == old_compute_op) { +// pops->data[i] = new_compute_op; +// updated_ops[old_compute_op] = new_compute_op; +// } else { +// pops->data[i] = old_op; +// } +// } + +// // Because ops is sorted in topo-order, only do one pass linear scan here. 
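+// // (An input op may map to an op that was itself replaced earlier in this
+// // scan, so updated_ops is followed transitively by the while-loop below.)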
+// for (size_t i = 0; i < pops->data.size(); ++i) { +// auto old_op = Downcast(pops->data[i]); +// if (auto* pop = old_op.as()) { +// auto inputs = pop->InputTensors(); +// std::unordered_map rmap; +// for (auto input : inputs) { +// auto it = updated_ops.find(input->op); +// Operation new_op; +// while (it != updated_ops.end()) { +// new_op = it->second; +// it = updated_ops.find(new_op); +// } +// if (new_op.defined()) { +// int index = input->value_index; +// rmap[input] = new_op.output(index); +// } +// } +// if (!rmap.empty()) { +// Operation new_op = pop->ReplaceInputs(old_op, rmap); +// updated_ops[old_op] = new_op; +// pops->data[i] = new_op; +// } +// } +// } + +// pdag->init_state = StateNode::make(pdag->ops); + +// Array old_tensors = pdag->tensors; +// ArrayNode* ptensors = pdag->tensors.CopyOnWrite(); + +// for (size_t i = 0; i < old_tensors.size(); ++i) { +// const auto& old_tensor = old_tensors[i]; +// auto it = updated_ops.find(old_tensor->op); +// Operation new_op; +// while (it != updated_ops.end()) { +// new_op = it->second; +// it = updated_ops.find(new_op); +// } +// if (new_op.defined()) { +// if (layout_rewrite_level == kBothRewrite) { +// auto index = old_tensor->value_index; +// ptensors->data[i] = new_op.output(index); +// } else if (layout_rewrite_level == kComputeRewrite) { +// TensorNode* old_tensor_node = const_cast(old_tensor.as()); +// old_tensor_node->op = new_op; +// } +// } +// } +// } // end for placeholder +// } +// } +// } // end for stage +// } + +std::pair > ComputeDAG::ApplySteps( + const std::vector& transform_steps, + LayoutRewriteLevel layout_rewrite_level) const { + std::vector stages; + StageToAxesMap stage_to_axes; + if (layout_rewrite_level != kNoRewrite && !transform_steps.empty()) { + ComputeDAG new_dag = *this; + new_dag.RewriteLayout(transform_steps, layout_rewrite_level); + return new_dag.ReplaySteps(transform_steps, &stages, &stage_to_axes); + } else { + return ReplaySteps(transform_steps, &stages, &stage_to_axes); + } +} + +// std::string ComputeDAG::PrintStepsAsPython( +// const std::vector& transform_steps) const { +// std::vector stages; +// StageToAxesMap stage_to_axes; +// Array ops; +// for (const auto& op : operator->()->ops) { +// if (!op->IsInstance()) { +// ops.push_back(op); +// } +// } +// te::Schedule schedule = te::create_schedule({ops.back()}); + +// // init axes +// for (const auto& x : operator->()->ops) { +// const te::Stage& stage = schedule.operator[](x); +// stages.push_back(stage); +// UpdateStageAxis(stage, &stage_to_axes); +// } + +// std::stringstream ss; + +// for (const auto& stage : stages) { +// if (stage->op->IsInstance()) { +// for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { +// ss << stage->leaf_iter_vars[i]->var->name_hint; +// if (i != stage->leaf_iter_vars.size() - 1) { +// ss << ", "; +// } +// } +// ss << " = " << "tuple(" << stage->op->func_name() << ".op.axis)" +// << " + " << "tuple(" << stage->op->func_name() << ".op.reduce_axis)\n"; +// } +// } + +// for (const auto& step : transform_steps) { +// ss << step->PrintAsPythonAPI(&stages, &stage_to_axes, &schedule, transform_steps); +// } + +// return ss.str(); +// } + +// State ComputeDAG::ReplayAndInferBound(const std::vector& transform_steps) const { +// State ret_state = GetInitState(); +// StateNode* pstate = ret_state.CopyOnWrite(); +// pstate->transform_steps = transform_steps; +// ret_state.DoSteps(transform_steps, *this); + +// InferBoundCommon(pstate); + +// return ret_state; +// } + +// State ComputeDAG::InferBound(const 
State& state) const { +// State ret_state = state; +// StateNode* pstate = ret_state.CopyOnWrite(); + +// InferBoundCommon(pstate); + +// return ret_state; +// } + +// void ComputeDAG::InferBound(std::vector* states) const { +// std::vector out_states(states->size(), State()); + +// auto worker_func = [&states, &out_states, this](int idx) { +// try { +// out_states[idx] = this->InferBound((*states)[idx]); +// } catch (dmlc::Error &e) { +// LOG(WARNING) << "InferBound fails on the state:\n" << (*states)[idx] +// << "\n" << e.what() << std::endl; +// } +// }; + +// // Lower states in parallel +// ThreadPool& pool = ThreadPool::Global(); +// pool.BeginBatch(states->size()); +// for (size_t i = 0; i < states->size(); ++i) { +// pool.Enqueue(worker_func, i); +// } +// pool.WaitBatch(); + +// *states = std::move(out_states); +// } + +void ComputeDAG::ReplayAndGetDAG(const std::vector &transform_steps, + ComputeDAG *task_dag) const { + std::vector stages; + StageToAxesMap stage_to_axes; + te::Schedule sch; + Array old_tensors; + + std::tie(sch, old_tensors) = ReplaySteps(transform_steps, &stages, &stage_to_axes); + + Array new_tensors; + for (auto stage : sch->stages) { + if (stage->op->IsInstance() || + stage->is_output) { + for (auto i = 0; i < stage->op->num_outputs(); ++i) { + new_tensors.push_back(stage->op.output(i)); + } + } + } + + *task_dag = ComputeDAGNode::make(new_tensors); +} + + +// void ComputeDAG::InferBoundCommon(StateNode* pstate) const { +// std::vector stages; +// StageToAxesMap stage_to_axes; +// te::Schedule sch; +// Array tensors; +// Map bounds; + +// std::tie(sch, tensors) = ReplaySteps(pstate->transform_steps, &stages, &stage_to_axes); +// sch = sch.normalize(); +// bounds = schedule::InferBound(sch); + +// for (size_t i = 0; i < pstate->stages.size(); ++i) { +// const Stage& stage = pstate->stages[i]; + +// if (stage->compute_at == kInlined) { +// continue; +// } + +// std::vector new_iters; +// new_iters.reserve(stage->iters.size()); +// for (size_t j = 0; j < stage->iters.size(); ++j) { +// const Iterator& iter = stage->iters[j]; +// const IterVar& axis = stage_to_axes.at(stages[i])[j]; + +// auto find_res = bounds.find(axis); +// if (find_res != bounds.end()) { +// new_iters.push_back(IteratorNode::make(iter->name, (*find_res).second, +// iter->iter_type, iter->annotation, +// &iter->ori_iters)); +// } else { +// LOG(FATAL) << "Infer bound fails"; +// } +// } + +// pstate->stages[i] = StageNode::make(stage->op, stage->op_type, +// std::move(new_iters), stage->compute_at, stage->auto_unroll_max_step, +// stage->storage_offset); +// } +// } + +// std::pair > ComputeDAG::ReplaySteps( +// const std::vector &transform_steps, +// std::vector *stages, +// StageToAxesMap *stage_to_axes) const { +// std::vector ops; +// for (const auto& op : operator->()->ops) { +// if (!op->IsInstance()) { +// ops.push_back(op); +// } +// } + +// te::Schedule schedule = te::create_schedule({ops.back()}); + +// // init axes +// stages->reserve(operator->()->ops.size()); +// for (const auto& x : operator->()->ops) { +// const te::Stage& stage = schedule.operator[](x); +// stages->push_back(stage); +// UpdateStageAxis(stage, stage_to_axes); +// } + +// // todo(lmzheng): should we maintain the attach_map and keep the validity of compute_at +// // an splitted axis? 
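+// // (ANSOR_PROGRAM_COMPLETE_RATE, read from the environment below, stops the
+// // replay after roughly that fraction of the transform steps; it is unset or
+// // negative in normal use and only enabled for the study in the paper.)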
+ +// // Use complete rate for the study in the paper +// const char* complete_rate_str = getenv("ANSOR_PROGRAM_COMPLETE_RATE"); +// double complete_rate = -1.0; +// if (complete_rate_str) { +// complete_rate = std::stod(complete_rate_str); +// } +// size_t ct = 0; + +// // replay history +// for (const auto& step : transform_steps) { +// if (complete_rate >= 0 && ct++ > transform_steps.size() * complete_rate) { +// break; +// } + +// if (auto ps = step.as()) { +// ps->ApplyToSchedule(stages, stage_to_axes); +// } else if (auto ps = step.as()) { +// ps->ApplyToSchedule(stages, stage_to_axes); +// } else if (auto ps = step.as()) { +// ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); +// } else if (auto ps = step.as()) { +// ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); +// } else if (auto ps = step.as()) { +// ps->ApplyToSchedule(stages, stage_to_axes); +// } else if (auto ps = step.as()) { +// ps->ApplyToSchedule(stages, stage_to_axes); +// } else if (auto ps = step.as()) { +// ps->ApplyToSchedule(stages, stage_to_axes); +// } else if (auto ps = step.as()) { +// ps->ApplyToSchedule(stages, stage_to_axes); +// } else if (auto ps = step.as()) { +// ps->ApplyToSchedule(stages, stage_to_axes); +// } else if (auto ps = step.as()) { +// ps->ApplyToSchedule(stages, stage_to_axes, &schedule); +// } else if (auto ps = step.as()) { +// ps->ApplyToSchedule(stages, stage_to_axes, &schedule); +// } else if (auto ps = step.as()) { +// ps->ApplyToSchedule(stages, stage_to_axes); +// } else if (auto ps = step.as()) { +// ps->ApplyToSchedule(stages, stage_to_axes, &schedule); +// } else if (auto ps = step.as()) { +// ps->ApplyToSchedule(stages, stage_to_axes); +// } else { +// LOG(FATAL) << "Invalid Step"; +// } +// } + +// return std::make_pair(schedule, operator->()->tensors); +// } + + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { + auto* node = static_cast(ref.get()); + std::stringstream ss; + + for (const auto& op : node->ops) { + if (op->IsInstance()) { + ss << op->func_name() << " = PLACEHOLDER " << op.output(0)->shape << "\n"; + } else if (auto pop = op.as()) { + for (size_t k = 0; k < pop->body.size(); ++k) { + ss << op->func_name() << "("; + for (size_t i = 0; i < pop->axis.size(); i++) { + ss << pop->axis[i]->var->name_hint; + if (i != pop->axis.size() - 1) { + ss << ", "; + } + } + ss << ")"; + if (pop->body.size() > 1) { + ss << ".v" << k; + } + if (auto preduce = pop->body[k].as()) { + CHECK_LT(k, preduce->combiner->result.size()); + PrimExpr combiner = preduce->combiner->result[k]; + if (combiner->IsInstance()) { + ss << " += " << preduce->source[0] << "\n"; + } else if (combiner->IsInstance()) { + ss << " max= " << preduce->source[0] << "\n"; + } else if (combiner->IsInstance()) { + ss << " min= " << preduce->source[0] << "\n"; + } else if (combiner->IsInstance()) { + const auto& select = combiner.as(); + ss << " select(" << select->condition << ", " << select->true_value + << ", " << select->false_value << ")= " + << '(' << preduce->source[0] << ',' << preduce->source[1] << ")\n"; + } else { + LOG(FATAL) << "Unsupported reduction operator" << combiner; + } + } else { + ss << " = " << pop->body[k] << "\n"; + } + } + } else { + LOG(FATAL) << "Invalid op"; + } + } + + p->stream << ss.str(); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { + auto* node = static_cast(ref.get()); + for (const auto& op : node->ops_topo_order) { + p->stream << op << 
std::endl; + p->stream << "is_injective:\t" << node->is_injective.at(op) << "\t\t"; + p->stream << "needs_multi_level_tiling:\t" + << node->needs_multi_level_tiling.at(op) << std::endl; + p->stream << "is_strict_inlinable:\t" << node->is_strict_inlineable.at(op) << "\t"; + p->stream << "is_output:\t" << node->is_output.at(op) << std::endl; + p->stream << "Read from:\t"; + for (const auto& pair : node->read_from.at(op)) { + for (const auto& index : pair.second) { + p->stream << pair.first->func_name() << Array(index) << ", "; + } + } + p->stream << "\n"; + p->stream << "Read by:\t"; + for (const auto& pair : node->read_by.at(op)) { + for (const auto& index : pair.second) { + p->stream << pair.first->func_name() << Array(index) << ", "; + } + } + p->stream << "\n"; + p->stream << "==================================================\n"; + } + + AccessAnalyzer ana = GetRef(node); + + p->stream << "ElementwiseMatch: \n"; + for (size_t i = 0; i < node->ops_topo_order.size(); ++i) { + for (size_t j = 0; j < node->ops_topo_order.size(); ++j) { + if (i == j) { continue; } + if (ana.ElementWiseMatch(node->ops_topo_order[i], node->ops_topo_order[j])) { + p->stream << node->ops_topo_order[i]->func_name() << " -> " + << node->ops_topo_order[j]->func_name() << "\n"; + } + } + } +}); + +} // namespace ansor +} // namespace tvm diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h new file mode 100644 index 000000000000..c8da44fee828 --- /dev/null +++ b/src/ansor/compute_dag.h @@ -0,0 +1,161 @@ +/*! + * Copyright (c) 2020 by Contributors + * \file ansor/compute_dag.h + * \brief Compute declaration graph and its related analysis tools + */ + +#ifndef TVM_ANSOR_COMPUTE_DAG_H_ +#define TVM_ANSOR_COMPUTE_DAG_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "utils.h" + +namespace tvm { +namespace ansor { + +class ComputeDAG; class AccessAnalyzer; +class StateNode; class State; class Step; + +typedef std::unordered_map, ObjectHash, ObjectEqual> + StageToAxesMap; + +// Update StageToAxes Map during replay +void UpdateStageAxis(const tvm::te::Stage& stage, StageToAxesMap *stage_to_axes); + +/*! \brief Read/Write access static analysis result */ +class AccessAnalyzerNode : public Object { + public: + template + using OperationMap = std::unordered_map; + + OperationMap > > > read_from; + OperationMap > > > read_by; + OperationMap is_injective; + OperationMap is_strict_inlineable; + OperationMap needs_multi_level_tiling; + OperationMap is_output; + std::vector ops_topo_order; + + static AccessAnalyzer make(const Array& tensors); + + static constexpr const char* _type_key = "ansor.AccessAnalyzer"; + TVM_DECLARE_FINAL_OBJECT_INFO(AccessAnalyzerNode, Object); +}; + +/*! \brief Read/Write access static analysis result */ +class AccessAnalyzer : public ObjectRef { + public: + // read/write access analysis + bool NeedsMultiLevelTiling(const te::Operation& op) const; + bool IsInjective(const te::Operation& op) const; + bool IsStrictInlineable(const te::Operation& op) const; + bool IsOutput(const te::Operation& op) const; + + // Get all producers of an op + void GetProducers(const State& state, const te::Operation& op, + std::unordered_set* producers) const; + // Get all consumers of an op. This func deals with inlined op correctly. + void GetConsumers(const State& state, const te::Operation& op, + std::unordered_set* consumers) const; + // Check whether two ops are elementwise matched + // (e.g. 
conv2d and relu are elementwise matched) + bool ElementWiseMatch(const te::Operation& op, + const te::Operation& target_op) const; + + /*! \Note The current implementation follows these (rough) definitions. + * + * Definition of data-reuse : Exists axis in (op->axis union op->reduce_axis) + * and acc in read accesses, such that axis not in acc. + * (e.g. A[i][j] = B[i] has data reuse, while A[i][j] = B[i][j] does not) + * Definition of NeedsMultiLevelTiling: Exists two acc, both of them make this op have data reuse. + * Definition of injective : For all index expressions, they are single axis variable + * plus an optional const shift. + * (e.g. A[i][j] = B[i][j], A[i][j] = B[i+1][j] are injective, while A[i][j] = B[i*j] is not) + * Definition of strict-inlineable : All read accesses are elementwise, and no branch in the body + * (e.g. A[i][j] = B[i][j] + C[i][j] is strict-inlineable, + * while A[i][j] = tvm_if_then_else(B[i][j] > 0, C[i][j], 0) is not + */ + TVM_DEFINE_OBJECT_REF_METHODS(AccessAnalyzer, ObjectRef, AccessAnalyzerNode); +}; + +/*! \brief Compute declaration graph */ +class ComputeDAGNode : public Object { + public: + Array tensors; // Input and output tensors + Array ops; // All related operations in topo order + double flop_ct; // Number of float operations + AccessAnalyzer access_analyzer; // Read/Write accesss static analyzer + ObjectRef init_state; // initial states + + void VisitAttrs(tvm::AttrVisitor* v); + + static ComputeDAG make(Array tensors); + static ComputeDAG make_by_workload_key(const std::string& workload_key); + + static constexpr const char* _type_key = "ansor.ComputeDAG"; + TVM_DECLARE_FINAL_OBJECT_INFO(ComputeDAGNode, Object); +}; + +enum LayoutRewriteLevel { + kNoRewrite = 0, // No layout rewrite + kPlaceholderRewrite = 1, // Only rewrite layout of placeholder in the compute dag + kComputeRewrite = 2, // Only rewrite compute body for new layout in the compute dag + kBothRewrite = 3, // Rewrite both placeholder and compute body in the compute dag +}; + +/*! \brief Compute declaration graph */ +class ComputeDAG: public ObjectRef { + public: + // Apply transform steps to the init state of this DAG, and get the equivalent tvm::schedule. + // The return values can be used as arguments to tvm.build or tvm.lower + std::pair > ApplySteps( + const std::vector& transform_steps, + LayoutRewriteLevel layout_rewrite_level = kNoRewrite) const; + + // Rewrite the the layout of "layout free" placeholders according to transform steps + void RewriteLayout(const std::vector& transform_steps, + LayoutRewriteLevel layout_rewrite_level = kNoRewrite) const {}; + + // Print transform steps as equivalent python schedule API + std::string PrintStepsAsPython(const std::vector& steps) const; + + // Replay the transform steps and call ir_pass::InferBound to fill correct bound information + State ReplayAndInferBound(const std::vector& transform_steps) const; + + // Fill the correct bound information for a given state + State InferBound(const State& state) const; + + // Fill the correct bound information for a list of given states. 
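+  // (The commented-out implementation in compute_dag.cc lowers the states in
+  // parallel on the global ThreadPool and logs a warning when InferBound fails.)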
+ // Return the new states inplace + void InferBound(std::vector* states) const; + + // Replay the transform steps and get the new ops + void ReplayAndGetDAG(const std::vector& steps, ComputeDAG* task_dag) const; + + // Get the init state + State GetInitState() const; + + TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode); + + private: + // Internal common parts for replaying steps + std::pair > ReplaySteps( + const std::vector& transform_steps, std::vector* stages, + StageToAxesMap* stage_to_axes) const {}; + static constexpr const char* _layout_free_placeholders_key = "layout_free_placeholders"; + + // Internal common parts for inferring bound + void InferBoundCommon(StateNode* pstate) const; +}; + +} // namespace ansor +} // namespace tvm + +#endif // TVM_ANSOR_COMPUTE_DAG_H_ diff --git a/src/ansor/expr_hasher.h b/src/ansor/expr_hasher.h new file mode 100644 index 000000000000..1c743ed9a5c4 --- /dev/null +++ b/src/ansor/expr_hasher.h @@ -0,0 +1,97 @@ +/*! + * Copyright (c) 2020 by Contributors + * \file auto_scheduler/expr_hasher.h + * \brief Hash function for a tvm::Expr + */ + +#ifndef TVM_ANSOR_EXPR_HASHER_H_ +#define TVM_ANSOR_EXPR_HASHER_H_ + +#include +#include +#include +#include + +namespace tvm { + +/*! \brief Assign a hash value for a tvm::Expr */ +class ExprHasher: public tir::ExprFunctor { + public: + size_t VisitExpr_(const tir::AddNode* op) final { + return VisitExpr(op->a) + VisitExpr(op->b); + } + + size_t VisitExpr_(const tir::SubNode* op) final { + return VisitExpr(op->a) - VisitExpr(op->b); + } + + size_t VisitExpr_(const tir::MulNode* op) final { + return VisitExpr(op->a) * VisitExpr(op->b); + } + + size_t VisitExpr_(const tir::DivNode* op) final { + size_t t = VisitExpr(op->b); + if (t != 0) { + return VisitExpr(op->a) / t; + } else { + return dmlc::HashCombine(VisitExpr(op->a), 0x5A); + } + } + + size_t VisitExpr_(const tir::FloorDivNode* op) final { + size_t t = VisitExpr(op->b); + if (t != 0) { + return VisitExpr(op->a) / t; + } else { + return dmlc::HashCombine(VisitExpr(op->a), 0x5B); + } + } + + size_t VisitExpr_(const tir::ModNode* op) final { + size_t t = VisitExpr(op->b); + if (t != 0) { + return VisitExpr(op->a) % t; + } else { + return dmlc::HashCombine(VisitExpr(op->a), 0x5C); + } + } + + size_t VisitExpr_(const tir::FloorModNode* op) final { + size_t t = VisitExpr(op->b); + if (t != 0) { + return VisitExpr(op->a) % t; + } else { + return dmlc::HashCombine(VisitExpr(op->a), 0x5D); + } + } + + size_t VisitExpr_(const tir::CallNode* op) final { + size_t ret = ObjectHash()(op->func); + for (size_t i = 0; i < op->args.size(); ++i) { + ret = dmlc::HashCombine(ret, VisitExpr(op->args[i])); + } + return ret; + } + + size_t VisitExpr_(const tir::VarNode* op) final { + return std::hash()(op); + } + + size_t VisitExpr_(const tir::FloatImmNode* op) final { + return std::hash()(op->value); + } + + size_t VisitExpr_(const tir::IntImmNode* op) final { + return std::hash()(op->value); + } + + size_t VisitExprDefault_(const Object* op) final { + LOG(WARNING) << "Encounter undefined node in ExprHasher: " + << Object::_type_key; + return std::hash()(op); + } +}; + +} // namespace tvm + +#endif // TVM_ANSOR_EXPR_HASHER_H_ diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc new file mode 100644 index 000000000000..92157edc463d --- /dev/null +++ b/src/ansor/loop_state.cc @@ -0,0 +1,1729 @@ +/*! 
+ * Copyright (c) 2020 by Contributors + */ +#include "loop_state.h" +#include +#include "utils.h" + +namespace tvm { +namespace ansor { + +TVM_REGISTER_OBJECT_TYPE(StepNode); +TVM_REGISTER_NODE_TYPE(StateNode); + +inline std::string CleanName(const std::string& str) { + // to make the name valid in python code + std::string ret = str; + StrReplace(&ret, ".", "_"); + StrReplace(&ret, "@", "_"); + StrReplace(&ret, "outer", "o"); + StrReplace(&ret, "inner", "i"); + return ret; +} + +/********** Reorder **********/ +ReorderStep ReorderStepNode::make(int stage_id, const std::vector& after_ids) { + auto node = make_object(); + node->stage_id = stage_id; + node->after_ids = after_ids; + return ReorderStep(node); +} + +void ReorderStepNode::ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const { + te::Stage& stage = (*stages)[stage_id]; + const std::vector& axes = (*stage_to_axes)[stage]; + CHECK_EQ(after_ids.size(), axes.size()); + + std::vector new_axes; + new_axes.reserve(axes.size()); + for (auto i : after_ids) { + new_axes.push_back(axes[i]); + } + stage.reorder(new_axes); + (*stage_to_axes)[stage] = std::move(new_axes); +} + +std::string ReorderStepNode::PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + const te::Stage& stage = (*stages)[stage_id]; + std::stringstream ss; + + ss << "s[" << CleanName(stage->op->func_name()) << "].reorder("; + for (size_t i = 0; i < after_ids.size(); ++i) { + ss << CleanName((*stage_to_axes)[stage][after_ids[i]]->var->name_hint); + if (i != after_ids.size() - 1) { + ss << ", "; + } + } + ss << ")\n"; + + ApplyToSchedule(stages, stage_to_axes); + return ss.str(); +} + +/********** Split **********/ +std::vector ApplySplitToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, + int stage_id, + int iter_id, + const std::vector& lengths, + bool inner_to_outer) { + te::Stage& stage = (*stages)[stage_id]; + const std::vector& axes = (*stage_to_axes)[stage]; + + std::vector outs; + if (inner_to_outer) { + IterVar outer = axes[iter_id], inner; + for (int i = static_cast(lengths.size()) - 1; i >= 0; i--) { + IterVar to_split = outer; + stage.split(to_split, lengths[i], &outer, &inner); + outs.push_back(inner); + } + outs.push_back(outer); + } else { + IterVar outer, inner = axes[iter_id]; + for (size_t i = 0; i < lengths.size(); i++) { + IterVar to_split = inner; + stage.split_by_nparts(to_split, lengths[i], &outer, &inner); + outs.push_back(outer); + } + outs.push_back(inner); + } + + std::vector new_axes; + new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + iter_id); + if (inner_to_outer) { + new_axes.insert(new_axes.end(), outs.rbegin(), outs.rend()); + } else { + new_axes.insert(new_axes.end(), outs.begin(), outs.end()); + } + new_axes.insert(new_axes.end(), axes.begin() + iter_id + 1, axes.end()); + (*stage_to_axes)[stage] = std::move(new_axes); + + return outs; +} + +std::string PrintSplitAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + int stage_id, + int iter_id, + const std::vector& lengths, + bool inner_to_outer) { + te::Stage& stage = (*stages)[stage_id]; + auto to_split = (*stage_to_axes)[stage][iter_id]; + const auto& func_name = CleanName(stage->op->func_name()); + const auto& outs = ApplySplitToSchedule(stages, stage_to_axes, stage_id, + iter_id, lengths, inner_to_outer); + + std::stringstream ss; + int size = static_cast(lengths.size()); + if (inner_to_outer) { + for (int i = size - 1; i >= 0; i--) 
{
+      ss << CleanName(outs[size - i]->var->name_hint) << ", "
+         << CleanName(outs[size - i - 1]->var->name_hint)
+         << " = s[" << func_name << "].split("
+         << CleanName(to_split->var->name_hint)
+         << ", factor=" << lengths[i] << ")\n";
+      to_split = outs[size - i];
+    }
+  } else {
+    for (int i = 0; i < size; i++) {
+      ss << CleanName(outs[i]->var->name_hint) << ", "
+         << CleanName(outs[i + 1]->var->name_hint)
+         << " = s[" << func_name << "].split("
+         << CleanName(to_split->var->name_hint)
+         << ", nparts=" << lengths[i] << ")\n";
+      to_split = outs[i + 1];
+    }
+  }
+
+  return ss.str();
+}
+
+SplitStep SplitStepNode::make(int stage_id, int iter_id,
+                              PrimExpr extent, const std::vector<PrimExpr>& lengths,
+                              bool inner_to_outer) {
+  auto node = make_object<SplitStepNode>();
+  node->stage_id = stage_id;
+  // Extent can be an irreducible expression in some special cases
+  if (extent->IsInstance<IntImmNode>()) {
+    node->extent = std::move(extent);
+  }
+  node->iter_id = iter_id;
+  node->lengths = lengths;
+  node->inner_to_outer = inner_to_outer;
+  return SplitStep(node);
+}
+
+std::vector<IterVar> SplitStepNode::ApplyToSchedule(
+    std::vector<te::Stage> *stages, StageToAxesMap *stage_to_axes) const {
+  return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id,
+                              lengths, inner_to_outer);
+}
+
+std::string SplitStepNode::PrintAsPythonAPI(
+    std::vector<te::Stage> *stages, StageToAxesMap *stage_to_axes,
+    te::Schedule *schedule, const std::vector<Step>& transform_steps) const {
+  return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id,
+                               lengths, inner_to_outer);
+}
+
+/********** Follow Split **********/
+FollowSplitStep FollowSplitStepNode::make(int stage_id, int iter_id,
+                                          int src_step_id, int n_split) {
+  auto node = make_object<FollowSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_id = src_step_id;
+  node->n_split = n_split;
+  return FollowSplitStep(node);
+}
+
+void FollowSplitStepNode::ExtractSplitLengths(const std::vector<Step>& transform_steps,
+                                              std::vector<PrimExpr>* lengths) const {
+  CHECK_LT(src_step_id, transform_steps.size());
+  auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+  CHECK(ps != nullptr);
+
+  // get lengths from src step
+  lengths->reserve(n_split);
+  int j = 0;
+  for (; j < n_split - 1; ++j) {
+    lengths->push_back(ps->lengths[j]);
+  }
+  PrimExpr last_factor = 1;
+  for (; j < static_cast<int>(ps->lengths.size()); ++j) {
+    if (ps->lengths[j].defined()) {
+      last_factor *= ps->lengths[j];
+    } else {
+      last_factor = PrimExpr();
+      break;
+    }
+  }
+  lengths->push_back(std::move(last_factor));
+}
+
+std::vector<IterVar> FollowSplitStepNode::ApplyToSchedule(
+    std::vector<te::Stage> *stages, StageToAxesMap *stage_to_axes,
+    const std::vector<Step>& transform_steps) const {
+  std::vector<PrimExpr> lengths;
+  ExtractSplitLengths(transform_steps, &lengths);
+  return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id,
+                              lengths, true);
+}
+
+std::string FollowSplitStepNode::PrintAsPythonAPI(
+    std::vector<te::Stage> *stages, StageToAxesMap *stage_to_axes,
+    te::Schedule *schedule, const std::vector<Step>& transform_steps) const {
+  std::vector<PrimExpr> lengths;
+  ExtractSplitLengths(transform_steps, &lengths);
+  return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id,
+                               lengths, true);
+}
+
+/********** Follow Fused Split **********/
+FollowFusedSplitStep FollowFusedSplitStepNode::make(int stage_id, int iter_id,
+    const std::vector<int>& src_step_ids, int level, bool factor_or_nparts) {
+  auto node = make_object<FollowFusedSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_ids = src_step_ids;
+  node->level = level;
+  node->factor_or_nparts = factor_or_nparts;
+  return
FollowFusedSplitStep(node); +} + +PrimExpr FollowFusedSplitStepNode::ExtractSplitLength(const std::vector& transform_steps) const { + PrimExpr ret(1); + + for (int src_step_id : src_step_ids) { + CHECK_LT(src_step_id, transform_steps.size()); + auto ps = transform_steps[src_step_id].as(); + CHECK(ps != nullptr); + if (ps->lengths[level].defined() && ret.defined()) { + ret *= ps->lengths[level]; + } else { + return PrimExpr(); + } + } + + return ret; +} + +std::vector FollowFusedSplitStepNode::ApplyToSchedule( + std::vector *stages, StageToAxesMap *stage_to_axes, + const std::vector& transform_steps) const { + const PrimExpr& length = ExtractSplitLength(transform_steps); + return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, + {length}, factor_or_nparts); +} + +std::string FollowFusedSplitStepNode::PrintAsPythonAPI( + std::vector *stages, StageToAxesMap *stage_to_axes, + te::Schedule *schedule, const std::vector& transform_steps) const { + const PrimExpr& length = ExtractSplitLength(transform_steps); + return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, + {length}, factor_or_nparts); +} + + +/********** Fuse **********/ +FuseStep FuseStepNode::make(int stage_id, const std::vector& fused_ids) { + auto node = make_object(); + node->stage_id = stage_id; + node->fused_ids = fused_ids; + return FuseStep(node); +} + +IterVar FuseStepNode::ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const { + te::Stage& stage = (*stages)[stage_id]; + const std::vector& axes = (*stage_to_axes)[stage]; + + Array to_fuse; + for (auto i : fused_ids) { + to_fuse.push_back(axes[i]); + } + IterVar fused_axis; + stage.fuse(to_fuse, &fused_axis); + std::vector new_axes; + new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids[0]); + new_axes.push_back(fused_axis); + new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1, + axes.end()); + (*stage_to_axes)[stage] = std::move(new_axes); + + return fused_axis; +} + +std::string FuseStepNode::PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + const auto& stage = (*stages)[stage_id]; + std::stringstream to_fuse; + + for (size_t i = 0; i < fused_ids.size(); ++i) { + to_fuse << CleanName((*stage_to_axes)[stage][fused_ids[i]]->var->name_hint); + if (i != fused_ids.size() - 1) { + to_fuse << ", "; + } + } + + std::stringstream ss; + const auto& fused = ApplyToSchedule(stages, stage_to_axes); + + ss << CleanName(fused->var->name_hint) << " = s[" + << CleanName(stage->op->func_name()) << "].fuse(" + << to_fuse.str() << ")\n"; + + return ss.str(); +} + +/********** Annotation **********/ +AnnotationStep AnnotationStepNode::make(int stage_id, int iter_id, IteratorAnnotation ann) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->annotation = ann; + return AnnotationStep(node); +} + +void AnnotationStepNode::ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const { + te::Stage& stage = (*stages)[stage_id]; + const std::vector& axes = (*stage_to_axes)[stage]; + + switch (annotation) { + case kUnroll: stage.unroll(axes[iter_id]); break; + case kVectorize: stage.vectorize(axes[iter_id]); break; + case kParallel: stage.parallel(axes[iter_id]); break; + case kVThread: stage.bind(axes[iter_id], te::thread_axis(Range(), "vthread")); break; + case kBlockX: stage.bind(axes[iter_id], te::thread_axis(Range(), "blockIdx.x")); break; + case kBlockY: 
stage.bind(axes[iter_id], te::thread_axis(Range(), "blockIdx.y")); break;
+    case kThreadX:
+      if (axes[iter_id]->iter_type == kCommReduce) {
+        const auto &thread_x = te::thread_axis(Range(), "threadIdx.x");
+        stage.bind(axes[iter_id], thread_x);
+        stage.set_store_predicate(thread_x->var == 0);
+      } else {
+        stage.bind(axes[iter_id], te::thread_axis(Range(), "threadIdx.x"));
+      }
+      break;
+    case kThreadY: stage.bind(axes[iter_id], te::thread_axis(Range(), "threadIdx.y")); break;
+    case kNone: break;
+    default: LOG(FATAL) << "Invalid Annotation " << annotation; break;
+  }
+}
+
+std::string AnnotationStepNode::PrintAsPythonAPI(std::vector<te::Stage> *stages,
+                                                 StageToAxesMap *stage_to_axes,
+                                                 te::Schedule *schedule,
+                                                 const std::vector<Step>& transform_steps) const {
+  std::stringstream ss;
+  const auto& stage = (*stages)[stage_id];
+  const auto& iter = (*stage_to_axes)[stage][iter_id];
+
+  bool bind_reduce_iter = iter->iter_type == kCommReduce && annotation == kThreadX;
+  if (bind_reduce_iter) {
+    ss << "thread_x = tvm.thread_axis(\"threadIdx.x\")\n";
+  }
+
+  ss << "s[" << CleanName(stage->op->func_name()) << "].";
+  switch (annotation) {
+    case kUnroll: ss << "unroll("; break;
+    case kVectorize: ss << "vectorize("; break;
+    case kParallel: ss << "parallel("; break;
+    case kVThread:
+    case kBlockX:
+    case kBlockY:
+    case kThreadX:
+    case kThreadY: ss << "bind("; break;
+    case kNone: break;
+    default:
+      LOG(FATAL) << "Invalid annotation " << annotation; break;
+  }
+  ss << CleanName(iter->var->name_hint);
+  switch (annotation) {
+    case kVThread: ss << ", tvm.thread_axis(\"vthread\")"; break;
+    case kBlockX: ss << ", tvm.thread_axis(\"blockIdx.x\")"; break;
+    case kBlockY: ss << ", tvm.thread_axis(\"blockIdx.y\")"; break;
+    case kThreadX:
+      if (bind_reduce_iter) {
+        ss << ", thread_x";
+      } else {
+        ss << ", tvm.thread_axis(\"threadIdx.x\")";
+      }
+      break;
+    case kThreadY: ss << ", tvm.thread_axis(\"threadIdx.y\")"; break;
+    default: break;
+  }
+  ss << ")\n";
+
+  if (bind_reduce_iter) {
+    ss << "s[" << CleanName(stage->op->func_name()) << "]"
+       << ".set_store_predicate(thread_x.var.equal(0))\n";
+  }
+
+  ApplyToSchedule(stages, stage_to_axes);
+  return ss.str();
+}
+
+/********** Compute at **********/
+ComputeAtStep ComputeAtStepNode::make(int stage_id, int target_stage_id, int target_iter_id) {
+  auto node = make_object<ComputeAtStepNode>();
+  node->stage_id = stage_id;
+  node->target_stage_id = target_stage_id;
+  node->target_iter_id = target_iter_id;
+  return ComputeAtStep(node);
+}
+
+void ComputeAtStepNode::ApplyToSchedule(std::vector<te::Stage> *stages,
+                                        StageToAxesMap *stage_to_axes) const {
+  te::Stage& stage = (*stages)[stage_id];
+  const IterVar& target_axis =
+      (*stage_to_axes)[(*stages)[target_stage_id]][target_iter_id];
+  stage.compute_at((*stages)[target_stage_id], target_axis);
+}
+
+std::string ComputeAtStepNode::PrintAsPythonAPI(std::vector<te::Stage> *stages,
+                                                StageToAxesMap *stage_to_axes,
+                                                te::Schedule *schedule,
+                                                const std::vector<Step>& transform_steps) const {
+  std::stringstream ss;
+  const auto& stage = (*stages)[stage_id];
+  const auto& target_stage = (*stages)[target_stage_id];
+
+  ss << "s[" << CleanName(stage->op->func_name()) << "].compute_at(s["
+     << CleanName(target_stage->op->func_name()) << "], "
+     << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint);
+
+  ss << ")\n";
+  ApplyToSchedule(stages, stage_to_axes);
+  return ss.str();
+}
+
+/********** Compute Root **********/
+ComputeRootStep ComputeRootStepNode::make(int stage_id) {
+  auto node = make_object<ComputeRootStepNode>();
+  node->stage_id = stage_id;
+  return
ComputeRootStep(node); +} + +void ComputeRootStepNode::ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const { + (*stages)[stage_id].compute_root(); +} + +std::string ComputeRootStepNode::PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + + ss << "s[" << CleanName(stage->op->func_name()) << "].compute_root()\n"; + ApplyToSchedule(stages, stage_to_axes); + + return ss.str(); +} + +/********** Compute Inline **********/ +ComputeInlineStep ComputeInlineStepNode::make(int stage_id) { + auto node = make_object(); + node->stage_id = stage_id; + return ComputeInlineStep(node); +} + +void ComputeInlineStepNode::ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const { + (*stages)[stage_id].compute_inline(); +} + +std::string ComputeInlineStepNode::PrintAsPythonAPI( + std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + + ss << "s[" << CleanName(stage->op->func_name()) << "].compute_inline()\n"; + ApplyToSchedule(stages, stage_to_axes); + + return ss.str(); +} + +/********** Pack for vec **********/ +PackForVecStep PackForVecStepNode::make(int stage_id, int iter_id, int vec_size) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->vec_size = vec_size; + return PackForVecStep(node); +} + +void PackForVecStepNode::ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, te::Schedule *schedule) const { + LOG(FATAL) << "Not implemented"; +} + +std::string PackForVecStepNode::PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + LOG(FATAL) << "Not implemented"; + return ""; +} + +/********** Cache read **********/ +CacheReadStep CacheReadStepNode::make(int stage_id, std::string scope_name, + const std::vector& reader_stage_ids) { + auto node = make_object(); + node->stage_id = stage_id; + node->scope_name = std::move(scope_name); + node->reader_stage_ids = reader_stage_ids; + return CacheReadStep(node); +} + +te::Tensor CacheReadStepNode::ApplyToSchedule(std::vector* stages, + StageToAxesMap *stage_to_axes, te::Schedule *schedule) const { + te::Stage& stage = (*stages)[stage_id]; + + Array readers; + for (const auto& i : reader_stage_ids) { + readers.push_back((*stages)[i]->origin_op); + } + auto out = schedule->cache_read(stage->origin_op.output(0), scope_name, readers); + + const auto& new_stage = (*schedule)[out->op]; + UpdateStageAxis(new_stage, stage_to_axes); + stages->insert(stages->begin() + stage_id + 1, new_stage); + + return out; +} + +std::string CacheReadStepNode::PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + std::stringstream ss; + // copy stage here, for the original stage will change after apply + auto stage = (*stages)[stage_id]; + std::vector reader_stages; + for (size_t i = 0; i < reader_stage_ids.size(); ++i) { + reader_stages.push_back((*stages)[reader_stage_ids[i]]); + } + + auto out = ApplyToSchedule(stages, stage_to_axes, schedule); + + ss << CleanName(out->op->func_name()) << " = " + << "s.cache_read(" << CleanName(stage->op->func_name()) << ", \"" + << scope_name << "\", [" + << 
CleanName(reader_stages[0]->op->func_name()); + for (size_t i = 1; i < reader_stage_ids.size(); ++i) { + ss << ", " << CleanName(reader_stages[i]->op->func_name()); + } + ss << "])\n"; + + const auto& iters = out->op->root_iter_vars(); + for (size_t i = 0; i < iters.size(); ++i) { + ss << CleanName(iters[i]->var->name_hint); + if (i != iters.size() - 1) { + ss << ", "; + } + } + ss << " = " << "tuple(" << CleanName(out->op->func_name()) + << ".op.axis)\n"; + + return ss.str(); +} + +/********** Cache write **********/ +CacheWriteStep CacheWriteStepNode::make(int stage_id, std::string scope_name) { + auto node = make_object(); + node->stage_id = stage_id; + node->scope_name = std::move(scope_name); + return CacheWriteStep(node); +} + +Array CacheWriteStepNode::ApplyToSchedule( + std::vector *stages, StageToAxesMap *stage_to_axes, + te::Schedule *schedule) const { + te::Stage& stage = (*stages)[stage_id]; + + Array tensor_array; + // If the target stage has multi outputs, TVM requires to cache_write + // all of them or schedule.cache_write will raise an error + for (auto i = 0; i < stage->op->num_outputs(); ++i) { + tensor_array.push_back(stage->origin_op.output(i)); + } + auto outs = schedule->cache_write(tensor_array, scope_name); + + UpdateStageAxis(stage, stage_to_axes); + // Even if there is multi outputs, TVM schedule only generate one + // new stage + const auto& new_stage = (*schedule)[outs[0]->op]; + UpdateStageAxis(new_stage, stage_to_axes); + stages->insert(stages->begin() + stage_id, new_stage); + + return outs; +} + +std::string CacheWriteStepNode::PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + std::stringstream ss; + // copy stage here, for the original stage will change after apply + te::Stage stage = (*stages)[stage_id]; + + auto outs = ApplyToSchedule(stages, stage_to_axes, schedule); + + for (size_t i = 0; i < outs.size(); ++i) { + ss << CleanName(outs[i]->op->func_name()) << ", "; + } + ss << "= " << "s.cache_write([" + << CleanName(stage->op.output(0)->op->name); + for (auto i = 1; i < stage->op->num_outputs(); ++i) { + ss << ", " << CleanName(stage->op.output(i)->op->name); + } + ss << "], \"" << scope_name << "\")\n"; + + for (const auto& out : outs) { + const auto& iters = out->op->root_iter_vars(); + for (size_t i = 0; i < iters.size(); ++i) { + ss << CleanName(iters[i]->var->name_hint); + if (i != iters.size() - 1) { + ss << ", "; + } + } + ss << " = " << "tuple(" << CleanName(out->op->func_name()) + << ".op.axis)" + << " + " << "tuple(" << CleanName(out->op->func_name()) + << ".op.reduce_axis)\n"; + } + + return ss.str(); +} + +/********** Pragma **********/ +PragmaStep PragmaStepNode::make(int stage_id, int iter_id, + std::string pragma_type) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->pragma_type = std::move(pragma_type); + return PragmaStep(node); +} + +void PragmaStepNode::ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const { + te::Stage& stage = (*stages)[stage_id]; + const std::vector& axes = (*stage_to_axes)[stage]; + if (StrStartsWith(pragma_type, "auto_unroll_max_step")) { + size_t pos = pragma_type.find('$'); + int value = atoi(pragma_type.c_str() + pos + 1); + stage.pragma(axes[iter_id], "auto_unroll_max_step", value); + stage.pragma(axes[iter_id], "unroll_explicit", true); + } else { + stage.pragma(axes[iter_id], pragma_type); + } +} + +std::string 
PragmaStepNode::PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + + if (StrStartsWith(pragma_type, "auto_unroll_max_step")) { + size_t pos = pragma_type.find('$'); + int value = atoi(pragma_type.c_str() + pos + 1); + ss << "s[" << CleanName(stage->op->func_name()) << "].pragma(" + << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) + << ", \"auto_unroll_max_step\", " << value << ")\n"; + ss << "s[" << CleanName(stage->op->func_name()) << "].pragma(" + << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) + << ", \"unroll_explicit\", True)\n"; + } else { + ss << "s[" << CleanName(stage->op->func_name()) << "].pragma(" + << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", \"" + << pragma_type << "\")\n"; + } + + ApplyToSchedule(stages, stage_to_axes); + return ss.str(); +} + +/********** Rfactor **********/ +RfactorStep RfactorStepNode::make(int stage_id, int iter_id, int factor_iter_id) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->factor_iter_id = factor_iter_id; + return RfactorStep(node); +} + +Array RfactorStepNode::ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, te::Schedule *schedule) const { + const auto& stage = (*stages)[stage_id]; + const std::vector& axes = (*stage_to_axes)[stage]; + + const te::Tensor& tensor = stage->origin_op.output(0); + const IterVar& axis = axes[iter_id]; + auto outs = schedule->rfactor(tensor, axis, factor_iter_id); + + UpdateStageAxis(stage, stage_to_axes); + + const auto& new_stage = (*schedule)[outs[0]->op]; + UpdateStageAxis(new_stage, stage_to_axes); + stages->insert(stages->begin() + stage_id, new_stage); + + return outs; +} + +std::string RfactorStepNode::PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + + const auto& tensor_name = CleanName(stage->origin_op.output(0)->op->name); + const auto& axis_name = CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint); + + const auto& outs = ApplyToSchedule(stages, stage_to_axes, schedule); + + for (size_t i = 0; i < outs.size(); ++i) { + ss << CleanName(outs[i]->op->func_name()); + if (i != outs.size() - 1) { + ss << ", "; + } + } + ss << " = " << "s.rfactor(" + << tensor_name << ", " + << axis_name << ", " + << factor_iter_id << ")\n"; + + for (const auto& out : outs) { + const auto& iters = out->op->root_iter_vars(); + for (size_t i = 0; i < iters.size(); ++i) { + ss << CleanName(iters[i]->var->name_hint); + if (i != iters.size() - 1) { + ss << ", "; + } + } + ss << " = " << "tuple(" << CleanName(out->op->func_name()) + << ".op.axis)" + << " + " << "tuple(" << CleanName(out->op->func_name()) + << ".op.reduce_axis)\n"; + } + + const auto& output = (*stages)[stage_id + 1]->op.output(0); + const auto& iters = output->op->root_iter_vars(); + for (size_t i = 0; i < iters.size(); ++i) { + ss << CleanName(iters[i]->var->name_hint); + if (i != iters.size() - 1) { + ss << ", "; + } + } + ss << " = " << "tuple(s[" << CleanName(output->op->func_name()) + << "].op.axis)" + << " + " << "tuple(s[" << CleanName(output->op->func_name()) + << "].op.reduce_axis)\n"; + + return ss.str(); +} + +/********** StorageAlign **********/ + +StorageAlignStep StorageAlignStepNode::make(int stage_id, int 
iter_id, + int factor, int offset) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->factor = factor; + node->offset = offset; + return StorageAlignStep(node); +} + +void StorageAlignStepNode::ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const { + te::Stage& stage = (*stages)[stage_id]; + const std::vector& axes = (*stage_to_axes)[stage]; + stage.storage_align(axes[iter_id], factor, offset); +} + +std::string StorageAlignStepNode::PrintAsPythonAPI( + std::vector *stages, StageToAxesMap *stage_to_axes, + te::Schedule *schedule, const std::vector& transform_steps) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + ss << "s[" << CleanName(stage->op->func_name()) << "].storage_align(" + << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", " + << factor << ", " << offset << ")\n"; + + ApplyToSchedule(stages, stage_to_axes); + return ss.str(); +} + +// Maker for other classes +Iterator IteratorNode::make(std::string name, Range range, + IteratorType iter_type, IteratorAnnotation annotation, + const std::vector* ori_iters) { + auto node = make_object(); + node->name = std::move(name); + node->range = std::move(range); + node->iter_type = iter_type; + node->annotation = annotation; + if (ori_iters != nullptr) { + node->ori_iters = *ori_iters; + } + return Iterator(node); +} + +Stage StageNode::make(te::Operation op) { + auto node = make_object(); + if (op->IsInstance()) { + node->op_type = kCompute; + auto *pop = op.as(); + + for (const auto& axis : pop->axis) { + node->iters.push_back(IteratorNode::make(CleanName(axis->var->name_hint), + axis->dom, kSpace, kNone)); + } + for (const auto& axis : pop->reduce_axis) { + node->iters.push_back(IteratorNode::make(CleanName(axis->var->name_hint), + axis->dom, kReduce, kNone)); + } + } else if (op->IsInstance()) { + node->op_type = kPlaceholder; + } else { + LOG(FATAL) << "Unsupported operator type" << op->_type_key; + } + + node->compute_at = kRoot; + node->op = std::move(op); + node->auto_unroll_max_step = 0; + node->storage_offset = 0; + return Stage(node); +} + +Stage StageNode::make(te::Operation op, StageType op_type, const std::vector& iters, + ComputeAtType compute_at, int16_t auto_unroll_max_step, int storage_offset) { + auto node = make_object(); + node->op = std::move(op); + node->op_type = op_type; + node->iters = iters; + node->compute_at = compute_at; + node->auto_unroll_max_step = auto_unroll_max_step; + node->storage_offset = storage_offset; + return Stage(node); +} + +Stage StageNode::make(te::Operation op, StageType op_type, std::vector&& iters, + ComputeAtType compute_at, int16_t auto_unroll_max_step, int storage_offset) { + auto node = make_object(); + node->op = std::move(op); + node->op_type = op_type; + node->iters = std::move(iters); + node->compute_at = compute_at; + node->auto_unroll_max_step = auto_unroll_max_step; + node->storage_offset = storage_offset; + return Stage(node); +} + +State StateNode::make_empty_state() { + auto node = make_object(); + node->attach_map = AttachMapNode::make(); + node->complete = false; + node->aux_info = ObjectRef(); + return State(node); +} + +State StateNode::make(const Array& ops) { + auto node = make_object(); + for (const auto& op : ops) { + node->stages.push_back(StageNode::make(op)); + } + node->attach_map = AttachMapNode::make(); + node->complete = true; + node->aux_info = ObjectRef(); + return State(node); +} + +State StateNode::make(const std::vector& stages, + const 
std::vector& transform_steps, + bool complete, ObjectRef aux_info) { + auto node = make_object(); + node->stages = stages; + node->transform_steps = transform_steps; + node->attach_map = AttachMapNode::make(); + node->complete = complete; + node->aux_info = std::move(aux_info); + return State(node); +} + +AttachMap AttachMapNode::make() { + auto node = make_object(); + return AttachMap(node); +} + +// Schedule primitives api +void State::reorder(int stage_id, const std::vector& order) { + const Stage& stage = operator->()->stages[stage_id]; + + CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators " + "should be specified"; + std::vector after_ids; + GetIndices(stage->iters, order, &after_ids); + ReorderStep step = ReorderStepNode::make(stage_id, after_ids); + CopyOnWrite()->transform_steps.push_back(step); + DoReorderStep(step); +} + +std::vector State::split(int stage_id, + const Iterator& it, const std::vector& lengths, bool inner_to_outer) { + const Stage& stage = operator->()->stages[stage_id]; + + SplitStep step = SplitStepNode::make(stage_id, GetIndex(stage->iters, it), + it->range.defined() ? it->range->extent : PrimExpr(), lengths, + inner_to_outer); + CopyOnWrite()->transform_steps.push_back(step); + return DoSplitStep(step); +} + +std::vector State::follow_split(int stage_id, + const Iterator& it, int src_step_id, int n_split) { + const Stage& stage = operator->()->stages[stage_id]; + + FollowSplitStep step = FollowSplitStepNode::make(stage_id, + GetIndex(stage->iters, it), src_step_id, n_split); + CopyOnWrite()->transform_steps.push_back(step); + return DoFollowSplitStep(step); +} + + +std::vector State::follow_fused_split(int stage_id, const Iterator& it, + const std::vector& src_step_ids, int level, bool factor_or_nparts) { + const Stage& stage = operator->()->stages[stage_id]; + + FollowFusedSplitStep step = FollowFusedSplitStepNode::make(stage_id, + GetIndex(stage->iters, it), src_step_ids, level, factor_or_nparts); + CopyOnWrite()->transform_steps.push_back(step); + return DoFollowFusedSplitStep(step); +} + +Iterator State::fuse(int stage_id, const std::vector& iters) { + const Stage& stage = operator->()->stages[stage_id]; + std::vector indices; + GetIndices(stage->iters, iters, &indices); + FuseStep step = FuseStepNode::make(stage_id, indices); + CopyOnWrite()->transform_steps.push_back(step); + return DoFuseStep(step); +} + +Iterator State::vectorize(int stage_id, const Iterator& it) { + const Stage& stage = operator->()->stages[stage_id]; + AnnotationStep step = AnnotationStepNode::make(stage_id, GetIndex(stage->iters, it), + kVectorize); + CopyOnWrite()->transform_steps.push_back(step); + return DoAnnotationStep(step); +} + +Iterator State::parallel(int stage_id, const Iterator& it) { + const Stage& stage = operator->()->stages[stage_id]; + AnnotationStep step = AnnotationStepNode::make(stage_id, GetIndex(stage->iters, it), + kParallel); + CopyOnWrite()->transform_steps.push_back(step); + return DoAnnotationStep(step); +} + +Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) { + const Stage& stage = operator->()->stages[stage_id]; + AnnotationStep step = AnnotationStepNode::make(stage_id, GetIndex(stage->iters, it), + kUnroll); + + // don't unroll if the extent is larger than max_unroll + if (max_unroll != -1 && it->range.defined()) { + if (auto imm = it->range->extent.as()) { + if (imm->value > max_unroll) { + return it; + } + } + } + + CopyOnWrite()->transform_steps.push_back(step); + return DoAnnotationStep(step); +} + 
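+// A minimal usage sketch of the primitives above (the stage id, iterator
+// names, and extents here are hypothetical, not taken from a real workload):
+//
+//   State s0 = StateNode::make(dag->ops);
+//   const Iterator& i = s0->stages[1]->iters[0];  // assume extent 512
+//   // an inner_to_outer split by {8, 32} returns {i.0, i.1, i.2}, outermost
+//   // first, with extents {2, 8, 32}
+//   std::vector<Iterator> its = s0.split(1, i, {PrimExpr(8), PrimExpr(32)}, true);
+//   s0.parallel(1, its[0]);
+//   s0.unroll(1, its[1], -1);
+//   s0.vectorize(1, its[2]);
+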
+void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) {
+  const Stage& target_stage = operator->()->stages[target_stage_id];
+  ComputeAtStep step = ComputeAtStepNode::make(stage_id, target_stage_id,
+      GetIndex(target_stage->iters, target_iter));
+  CopyOnWrite()->transform_steps.push_back(step);
+  return DoComputeAtStep(step);
+}
+
+void State::compute_root(int stage_id) {
+  ComputeRootStep step = ComputeRootStepNode::make(stage_id);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return DoComputeRootStep(step);
+}
+
+void State::compute_inline(int stage_id) {
+  ComputeInlineStep step = ComputeInlineStepNode::make(stage_id);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return DoComputeInlineStep(step);
+}
+
+void State::pack_for_vec(int stage_id, const Iterator& target_iter, int vec_size) {
+  const Stage& stage = operator->()->stages[stage_id];
+  PackForVecStep step = PackForVecStepNode::make(stage_id,
+      GetIndex(stage->iters, target_iter), vec_size);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return DoPackForVecStep(step);
+}
+
+Iterator State::bind_thread(int stage_id, const Iterator& it,
+                            IteratorAnnotation thread_type) {
+  const Stage& stage = operator->()->stages[stage_id];
+  if (thread_type < kVThread || thread_type > kThreadY) {
+    LOG(FATAL) << "Invalid thread_type. Valid types are: kVThread, kBlockX, "
+               << "kThreadX, kBlockY, kThreadY";
+  }
+  AnnotationStep step = AnnotationStepNode::make(stage_id,
+      GetIndex(stage->iters, it), thread_type);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return DoAnnotationStep(step);
+}
+
+int State::cache_read(int stage_id, const std::string& scope_name,
+    const std::vector<int>& reader_stage_ids, const ComputeDAG& task_dag) {
+  CacheReadStep step = CacheReadStepNode::make(stage_id, scope_name, reader_stage_ids);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return DoCacheReadStep(step, task_dag);
+}
+
+int State::cache_write(int stage_id, const std::string& scope_name,
+                       const ComputeDAG& task_dag) {
+  CacheWriteStep step = CacheWriteStepNode::make(stage_id, scope_name);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return DoCacheWriteStep(step, task_dag);
+}
+
+void State::pragma(int stage_id, const Iterator& it, const std::string& pragma_type) {
+  const Stage& stage = operator->()->stages[stage_id];
+  PragmaStep step = PragmaStepNode::make(stage_id, GetIndex(stage->iters, it),
+                                         pragma_type);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return DoPragmaStep(step);
+}
+
+int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id,
+                   const ComputeDAG& task_dag) {
+  const Stage& stage = operator->()->stages[stage_id];
+  RfactorStep step = RfactorStepNode::make(stage_id, GetIndex(stage->iters, it), factor_iter_id);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return DoRfactorStep(step, task_dag);
+}
+
+void State::storage_align(int stage_id, const Iterator& it, int factor,
+                          int offset) {
+  const Stage& stage = operator->()->stages[stage_id];
+  StorageAlignStep step = StorageAlignStepNode::make(stage_id,
+      GetIndex(stage->iters, it), factor, offset);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return DoStorageAlignStep(step);
+}
+
+// Steps' implementations
+void State::DoReorderStep(const ReorderStep& step) {
+  const Stage& stage = operator->()->stages[step->stage_id];
+
+  std::vector<Iterator> iters;
+  for (auto x : step->after_ids) {
+    iters.push_back(stage->iters[x]);
+  }
+
+  StateNode* pstate = CopyOnWrite();
+  pstate->stages[step->stage_id] = StageNode::make(stage->op,
stage->op_type, + std::move(iters), stage->compute_at, + stage->auto_unroll_max_step, + stage->storage_offset); +} + +// common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep +std::vector State::DoSplitStepCommon(int stage_id, int iter_id, + const std::vector& lengths, + bool inner_to_outer) { + const Stage& stage = operator->()->stages[stage_id]; + const Iterator& it = stage->iters[iter_id]; + size_t old_iter_size = stage->iters.size(); + + PrimExpr tosplit_min, tosplit_extent; + if (it->range.defined()) { + tosplit_min = it->range->min; + tosplit_extent = it->range->extent; + } else { + tosplit_min = tosplit_extent = PrimExpr(); + } + + std::vector outs; + for (size_t i = 0; i < lengths.size(); ++i) { + PrimExpr l; std::string name; + if (inner_to_outer) { + l = lengths[lengths.size() - i - 1]; + name = it->name + "." + std::to_string(lengths.size() - i); + } else { + l = lengths[i]; + name = it->name + "." + std::to_string(i); + } + Iterator res; + if (l.defined() && tosplit_min.defined() && tosplit_extent.defined()) { + res = IteratorNode::make(name, Range::make_by_min_extent(tosplit_min, l), + it->iter_type, kNone); + tosplit_min = 0; + tosplit_extent = indexdiv(tosplit_extent + l - 1, l); + } else { + res = IteratorNode::make(name, Range(), it->iter_type, kNone); + tosplit_min = tosplit_extent = PrimExpr(); + } + outs.push_back(std::move(res)); + } + + Range range; + if (tosplit_min.defined() && tosplit_extent.defined()) { + range = Range::make_by_min_extent(tosplit_min, tosplit_extent); + } + if (inner_to_outer) { + outs.push_back(IteratorNode::make(it->name + ".0", range, it->iter_type, kNone)); + std::reverse(outs.begin(), outs.end()); + } else { + outs.push_back(IteratorNode::make(it->name + "." + std::to_string(lengths.size()), + range, it->iter_type, kNone)); + } + + std::vector new_iters; + new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + iter_id); + new_iters.insert(new_iters.end(), outs.begin(), outs.end()); + new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id+1, stage->iters.end()); + + StateNode* pstate = CopyOnWrite(); + pstate->stages[stage_id] = StageNode::make(stage->op, stage->op_type, + std::move(new_iters), stage->compute_at, stage->auto_unroll_max_step, + stage->storage_offset); + + // we have to replace the iterators in attach map, these two vectors keep the replacement mapping + std::vector from_iters; + std::vector to_iters; + for (size_t i = iter_id; i < old_iter_size; ++i) { + from_iters.emplace_back(stage_id, i); + to_iters.emplace_back(stage_id, i + lengths.size()); + } + pstate->attach_map.ReplaceIters(from_iters, to_iters); + return outs; +} + +std::vector State::DoSplitStep(const SplitStep& step) { + return DoSplitStepCommon(step->stage_id, step->iter_id, step->lengths, + step->inner_to_outer); +} + +std::vector State::DoFollowSplitStep(const FollowSplitStep& step) { + std::vector lengths; + step->ExtractSplitLengths(operator->()->transform_steps, &lengths); + return DoSplitStepCommon(step->stage_id, step->iter_id, lengths, true); +} + +std::vector State::DoFollowFusedSplitStep(const FollowFusedSplitStep& step) { + const PrimExpr& length = step->ExtractSplitLength(operator->()->transform_steps); + return DoSplitStepCommon(step->stage_id, step->iter_id, {length}, step->factor_or_nparts); +} + +Iterator State::DoFuseStep(const FuseStep& step) { + int stage_id = step->stage_id; + const Stage& stage = operator->()->stages[stage_id]; + int old_iter_size = static_cast(stage->iters.size()); + + 
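+  // A sketch of what the loop below computes (names are hypothetical): fusing
+  // i.1 (extent 8) with j.1 (extent 8) yields one iterator named "i.1@j.1@"
+  // with extent 8 * 8 = 64; if any input extent is undefined, the fused
+  // extent is left undefined as well.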
std::string new_name; + PrimExpr new_extent = 1; + IteratorType new_iter_type = kSpecial; + + std::vector ori_iters; + for (size_t i = 0; i < step->fused_ids.size(); ++i) { + if (i > 0) { + CHECK_EQ(step->fused_ids[i], step->fused_ids[i-1] + 1); + } + + if (i != step->fused_ids.size() - 1) { + const auto& iter_to_attached_stage = operator->()->attach_map->iter_to_attached_stages; + if (iter_to_attached_stage.find(std::make_pair(stage_id, step->fused_ids[i])) + != iter_to_attached_stage.end()) { + LOG(FATAL) << "Invalid Fuse. Because you want to fuse iterators " + "that have been attached by some stages"; + } + } + + const Iterator& it = stage->iters[step->fused_ids[i]]; + ori_iters.push_back(it); + new_name += it->name + "@"; + + if (it->range.defined() && new_extent.defined()) { + new_extent = new_extent * it->range->extent; + } else { + new_extent = PrimExpr(); + } + + if (i == 0) { + new_iter_type = it->iter_type; + } else { + if (new_iter_type != it->iter_type) { + new_iter_type = kMixed; + } + } + } + + Range range; + if (new_extent.defined()) { + range = Range::make_by_min_extent(0, new_extent); + } + Iterator new_it = IteratorNode::make(new_name, range, new_iter_type, kNone, &ori_iters); + std::vector new_iters; + new_iters.insert(new_iters.end(), stage->iters.begin(), + stage->iters.begin() + step->fused_ids.front()); + new_iters.push_back(new_it); + new_iters.insert(new_iters.end(), stage->iters.begin() + step->fused_ids.back() + 1, + stage->iters.end()); + + StateNode* pstate = CopyOnWrite(); + pstate->stages[stage_id] = StageNode::make(stage->op, stage->op_type, + std::move(new_iters), stage->compute_at, stage->auto_unroll_max_step, + stage->storage_offset); + + // we have to replace the iterators in attach map, these two vectors keep the replacement mapping + std::vector from_iters; + std::vector to_iters; + const int begin_id = step->fused_ids.front(), end_id = step->fused_ids.back(); + for (int i = 0; i < old_iter_size; ++i) { + if (i <= begin_id) { + continue; + } else if (i > end_id) { // move forward + from_iters.emplace_back(stage_id, i); + to_iters.emplace_back(stage_id, i - end_id + begin_id); + } else { // move to the fused id + from_iters.emplace_back(stage_id, i); + to_iters.emplace_back(stage_id, begin_id); + } + } + pstate->attach_map.ReplaceIters(from_iters, to_iters); + return new_it; +} + +Iterator State::DoAnnotationStep(const AnnotationStep& step) { + const Stage& stage = operator->()->stages[step->stage_id]; + Iterator it = stage->iters[step->iter_id]; + + Iterator new_it = IteratorNode::make(it->name, it->range, it->iter_type, + step->annotation, &it->ori_iters); + Stage new_stage = stage; + new_stage.CopyOnWrite()->iters[step->iter_id] = new_it; + StateNode* pstate = CopyOnWrite(); + pstate->stages[step->stage_id] = std::move(new_stage); + return new_it; +} + +void State::DoComputeAtStep(const ComputeAtStep& step) { + const Stage& stage = operator->()->stages[step->stage_id]; + + // after compute_at, we don't know the accurate length information any more + // If we do want to know the accurate lengths, we can call ComputeDAG::ReplayAndInferBound + std::vector new_iters; + for (const Iterator& it : stage->iters) { + size_t s = it->name.size(); + if (s >= 2 && it->name[s-2] == '.' && it->name[s-1] >= '1' && it->name[s-1] <= '4') { + // We use a dangerous heuristic rule here : For multi level splitted iterators, we assume + // their length does not change after compute_at. 
+ // Reason: These iterators are generated in MultiStagePolicy by multi level tiling, they will + // be carefully compute_at their consumers. In this case, their lengths do not change. + // We do this to keep the AnnotateCPU pass to annotate more efficiently. + new_iters.push_back(it); + } else { + new_iters.push_back(IteratorNode::make(it->name, Range(), it->iter_type, + it->annotation, &it->ori_iters)); + } + } + + StateNode* pstate = CopyOnWrite(); + pstate->stages[step->stage_id] = StageNode::make(stage->op, stage->op_type, + std::move(new_iters), kIter, stage->auto_unroll_max_step, + stage->storage_offset); + pstate->attach_map.SetComputeAtIter(step->stage_id, step->target_stage_id, step->target_iter_id); +} + +void State::DoComputeRootStep(const ComputeRootStep& step) { + const Stage& stage = operator->()->stages[step->stage_id]; + + // after compute_root, we don't know the accurate length information any more + // If we do want to know the accurate lengths, we can call ComputeDAG::ReplayAndInferBound + std::vector new_iters; + for (const Iterator& it : stage->iters) { + new_iters.push_back(IteratorNode::make(it->name, Range(), it->iter_type, + it->annotation, &it->ori_iters)); + } + + // update attach map + StateNode* pstate = CopyOnWrite(); + pstate->stages[step->stage_id] = StageNode::make(stage->op, stage->op_type, + std::move(new_iters), kRoot, stage->auto_unroll_max_step, + stage->storage_offset); + pstate->attach_map.DeleteStage(step->stage_id); +} + +void State::DoComputeInlineStep(const ComputeInlineStep& step) { + const Stage& stage = operator->()->stages[step->stage_id]; + + StateNode* pstate = CopyOnWrite(); + + // CHECK the validity of compute_inline + const auto& iter_to_attached_stages = pstate->attach_map->iter_to_attached_stages; + for (size_t i = 0; i < stage->iters.size(); ++i) { + CHECK_EQ(iter_to_attached_stages.count(std::make_pair(step->stage_id, i)), 0) + << "Invalid compute_inline: Because there are some other stages " + "that are attached to the target stage"; + } + + pstate->stages[step->stage_id].CopyOnWrite()->compute_at = kInlined; + pstate->attach_map.DeleteStage(step->stage_id); +} + +void State::DoPackForVecStep(const PackForVecStep& step) { + LOG(FATAL) << "Not implemented"; +} + +// Common part for steps that add new stages (e.g. 
CacheReadStep, CacheWriteStep, RfactorStep) +void AddStageModificationSteps(size_t step_id, const std::vector& transform_steps, + std::vector* replay_steps) { + const Step& step = transform_steps[step_id]; + if (step->IsInstance() || step->IsInstance()) { + replay_steps->push_back(step); + } else if (step->IsInstance()) { + // add FuseStepNode required by rfactor + if (step_id >= 2 && transform_steps[step_id - 2]->IsInstance()) { + const Step& fuse_step = transform_steps[step_id - 2]; + if (fuse_step->stage_id == step->stage_id) { + replay_steps->push_back(fuse_step); + } + } + // add SplitStepNode required by rfactor + CHECK_GE(step_id, 1); + CHECK(transform_steps[step_id - 1]->IsInstance()); + const Step& split_step = transform_steps[step_id - 1]; + CHECK_EQ(split_step->stage_id, step->stage_id); + replay_steps->push_back(split_step); + // add RfactorStepNode + replay_steps->push_back(step); + } +} + +int State::DoCacheReadStep(const CacheReadStep& step, const ComputeDAG& dag) { + StateNode* pstate = CopyOnWrite(); + std::vector replay_steps; + for (size_t i = 0; i < pstate->transform_steps.size(); ++i) { + AddStageModificationSteps(i, pstate->transform_steps, &replay_steps); + if (pstate->transform_steps[i].same_as(step)) { + break; + } + } + dag.ReplayAndGetDAG(replay_steps, &(pstate->task_dag)); + + // target -> target + target_store + // Should update target's op, insert new stage, update the later stage's op + pstate->stages[step->stage_id].CopyOnWrite()->op = + operator->()->task_dag->ops[step->stage_id]; + pstate->stages.insert(pstate->stages.begin() + step->stage_id + 1, + StageNode::make(operator->()->task_dag->ops[step->stage_id + 1])); + for (size_t i = step->stage_id + 2; i < operator->()->stages.size(); ++i) { + pstate->stages[i].CopyOnWrite()->op = operator->()->task_dag->ops[i]; + } + pstate->attach_map = + operator->()->attach_map.ApplyStageIdOfffset(step->stage_id + 1, 1); + + return step->stage_id + 1; +} + +int State::DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag) { + StateNode* pstate = CopyOnWrite(); + std::vector replay_steps; + for (size_t i = 0; i < pstate->transform_steps.size(); ++i) { + AddStageModificationSteps(i, pstate->transform_steps, &replay_steps); + if (pstate->transform_steps[i].same_as(step)) { + break; + } + } + dag.ReplayAndGetDAG(replay_steps, &(pstate->task_dag)); + + // target -> target_compute + target + // Assume target stage has never been applied any steps before cache_write + // Should insert new stage, update target stage, update the later stage's op + pstate->stages.insert(pstate->stages.begin() + step->stage_id, + StageNode::make(operator->()->task_dag->ops[step->stage_id])); + pstate->stages[step->stage_id + 1] = + StageNode::make(operator->()->task_dag->ops[step->stage_id + 1]); + for (size_t i = step->stage_id + 2; i < operator->()->stages.size(); ++i) { + pstate->stages[i].CopyOnWrite()->op = operator->()->task_dag->ops[i]; + } + pstate->attach_map = + operator->()->attach_map.ApplyStageIdOfffset(step->stage_id, 1); + + return step->stage_id; +} + +void State::DoPragmaStep(const PragmaStep& step) { + if (step->pragma_type == "debug_skip_region") { + StateNode* pstate = CopyOnWrite(); + pstate->attach_map.DeleteStage(step->stage_id); + } else if (StrStartsWith(step->pragma_type, "auto_unroll_max_step")) { + StateNode* pstate = CopyOnWrite(); + StageNode* stage = pstate->stages[step->stage_id].CopyOnWrite(); + size_t pos = step->pragma_type.find('$'); + stage->auto_unroll_max_step = atoi(step->pragma_type.c_str() + 
pos + 1); + } else if (step->pragma_type == "tensor_core") { + // Nothing needs to be done here + } else { + LOG(FATAL) << "Invalid pragma: " << step->pragma_type; + } +} + +int State::DoRfactorStep(const RfactorStep& step, const ComputeDAG& dag) { + StateNode* pstate = CopyOnWrite(); + const auto compute_at_type = pstate->stages[step->stage_id]->compute_at; + std::vector replay_steps; + for (size_t i = 0; i < pstate->transform_steps.size(); ++i) { + AddStageModificationSteps(i, pstate->transform_steps, &replay_steps); + if (pstate->transform_steps[i].same_as(step)) { + break; + } + } + dag.ReplayAndGetDAG(replay_steps, &(pstate->task_dag)); + + // target -> target_compute + target + // Should insert new stage, update target stage, update the later stage's op + pstate->stages.insert(pstate->stages.begin() + step->stage_id, + StageNode::make(operator->()->task_dag->ops[step->stage_id])); + // maintain the compute_at type of target stage + Stage target_stage = StageNode::make(operator->()->task_dag->ops[step->stage_id + 1]); + target_stage.CopyOnWrite()->compute_at = compute_at_type; + pstate->stages[step->stage_id + 1] = target_stage; + + for (size_t i = step->stage_id + 2; i < operator->()->stages.size(); ++i) { + pstate->stages[i].CopyOnWrite()->op = operator->()->task_dag->ops[i]; + } + pstate->attach_map = + operator->()->attach_map.ApplyStageIdOfffset(step->stage_id, 1); + + return step->stage_id; +} + +void State::DoStorageAlignStep(const StorageAlignStep& step) { + StateNode* pstate = CopyOnWrite(); + StageNode* stage = pstate->stages[step->stage_id].CopyOnWrite(); + stage->storage_offset = step->offset; +} + +void State::DoStep(const Step& step, const ComputeDAG& dag) { + if (auto ps = step.as()) { + DoReorderStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoSplitStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoFollowSplitStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoFollowFusedSplitStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoFuseStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoAnnotationStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoComputeAtStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoComputeRootStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoComputeInlineStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoPackForVecStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoCacheReadStep(GetRef(ps), dag); + } else if (auto ps = step.as()) { + DoCacheWriteStep(GetRef(ps), dag); + } else if (auto ps = step.as()) { + DoPragmaStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoRfactorStep(GetRef(ps), dag); + } else if (auto ps = step.as()) { + DoStorageAlignStep(GetRef(ps)); + } else { + LOG(FATAL) << "Invalid step: " << step; + } +} + +void State::DoSteps(const std::vector& steps, const ComputeDAG& dag) { + // Use complete rate for the study in the paper + const char* complete_rate_str = getenv("ANSOR_PROGRAM_COMPLETE_RATE"); + double complete_rate = -1.0; + if (complete_rate_str) { + complete_rate = std::stod(complete_rate_str); + } + size_t ct = 0; + + for (const auto& step : steps) { + if (complete_rate >= 0 && ct++ > steps.size() * complete_rate) { + break; + } + DoStep(step, dag); + } +} + + +void PrintStage(std::ostream* os, int stage_id, const StateNode* state, size_t base_indent, + bool delete_trivial_loop) { + const Stage& stage = state->stages[stage_id]; + + if (stage->auto_unroll_max_step != 0) { + for (size_t j = 0; j < base_indent; ++j) { + *os << " "; + } + *os << 
stage->op->func_name() << " auto_unroll: " + << stage->auto_unroll_max_step << "\n"; + } + if (stage->storage_offset != 0) { + for (size_t j = 0; j < base_indent; ++j) { + *os << " "; + } + *os << stage->op->func_name() << " storage_offset: " + << stage->storage_offset << "\n"; + } + + size_t indent = 0; + for (size_t i = 0; i < stage->iters.size(); ++i) { + const Iterator& iter = stage->iters[i]; + + if (!(delete_trivial_loop && iter->range.defined() && is_one(iter->range->extent))) { + for (size_t j = 0; j < base_indent + indent; ++j) { + *os << " "; + } + switch (iter->annotation) { + case kNone: *os << "for "; break; + case kUnroll: *os << "unroll "; break; + case kParallel: *os << "parallel "; break; + case kVectorize: *os << "vectorize "; break; + case kVThread: *os << "vthread "; break; + case kBlockX: *os << "gpu.blockIdx.x "; break; + case kBlockY: *os << "gpu.blockIdx.y "; break; + case kThreadX: *os << "gpu.threadIdx.x "; break; + case kThreadY: *os << "gpu.threadIdx.y "; break; + } + if (iter->range.defined()) { + *os << iter->name << " (" << iter->range->min << "," << iter->range->extent << ")" << "\n"; + } else { + *os << iter->name << " (None)" << "\n"; + } + + indent += 2; + } + + if (state != nullptr) { + AttachMap::IterKey iter_key(stage_id, i); + auto pair = state->attach_map->iter_to_attached_stages.find(iter_key); + if (pair != state->attach_map->iter_to_attached_stages.end()) { + for (const auto& attach_stage_id : pair->second) { + PrintStage(os, attach_stage_id, state, base_indent + indent, delete_trivial_loop); + } + } + } + } + + for (size_t j = 0; j < base_indent + indent; ++j) { + *os << " "; + } + *os << stage->op->func_name() << " = ...\n"; +} + +void PrintState(std::ostream* os, const StateNode* node, bool delete_trivial_loop) { + // Gather placeholders + std::vector placeholders; + for (const auto& stage : node->stages) { + if (stage->op_type == kPlaceholder) { + placeholders.push_back(stage->op->name); + } + } + + *os << "Placeholder: "; + for (size_t i = 0; i < placeholders.size(); ++i) { + *os << placeholders[i]; + if (i != placeholders.size() - 1) { + *os << ", "; + } + } + *os << "\n"; + + // Print all stages + for (size_t i = 0; i < node->stages.size(); ++i) { + const Stage& stage = node->stages[i]; + if (stage->op_type == kPlaceholder) { + continue; + } else if (stage->op_type == kCompute) { + if (stage->compute_at == kRoot) { + PrintStage(os, i, node, 0, delete_trivial_loop); + } + } else { + LOG(FATAL) << "Invalid op type"; + } + } +} + +std::string State::ToStr(bool delete_trivial_loop) const { + std::ostringstream os; + PrintState(&os, operator->(), delete_trivial_loop); + return os.str(); +} + +void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id) { + AttachMapNode* pnode = CopyOnWrite(); + + // delete the current entry of stage + DeleteStageEntry(pnode, stage_id); + + // store the new relation + IterKey iter_key(target_stage_id, target_iter_id); + pnode->stage_to_attach_iter[stage_id] = std::make_pair(target_stage_id, target_iter_id); + pnode->iter_to_attached_stages[iter_key].push_back(stage_id); +} + +void AttachMap::DeleteStage(int stage_id) { + AttachMapNode* pnode = CopyOnWrite(); + + // delete the entry of old stage + DeleteStageEntry(pnode, stage_id); +} + +void AttachMap::ReplaceIters(const std::vector& old_iters, + const std::vector& new_iters) { + AttachMapNode* pnode = CopyOnWrite(); + + CHECK_EQ(old_iters.size(), new_iters.size()); + for (size_t i = 0; i < old_iters.size(); ++i) { + auto entry = 
pnode->iter_to_attached_stages.find(old_iters[i]); + if (entry == pnode->iter_to_attached_stages.end()) { + continue; + } + + // replace iter in the value of `stage_to_attach_iter` + for (const auto& s : entry->second) { + pnode->stage_to_attach_iter[s] = new_iters[i]; + } + + // replace iter in the key of `iter_to_attached_stages` + std::vector attached_stages = std::move(entry->second); + pnode->iter_to_attached_stages.erase(entry); + pnode->iter_to_attached_stages[new_iters[i]] = std::move(attached_stages); + } +} + +void AttachMap::DeleteStageEntry(AttachMapNode *pnode, int stage_id) { + auto old_entry = pnode->stage_to_attach_iter.find(stage_id); + if (old_entry != pnode->stage_to_attach_iter.end()) { + // delete value in `iter_to_attached_stages` + auto entry2 = pnode->iter_to_attached_stages.find(old_entry->second); + DeleteItem(&entry2->second, stage_id); + if (entry2->second.size() == 0) { + pnode->iter_to_attached_stages.erase(entry2); + } + // delete key in `stage_to_attach_iter` + pnode->stage_to_attach_iter.erase(old_entry); + } +} + +AttachMap AttachMap::ApplyStageIdOfffset(int start_id, int offset) const { + AttachMap map = AttachMapNode::make(); + auto pmap = map.CopyOnWrite(); + for (const auto& x : operator->()->stage_to_attach_iter) { + auto key = x.first; + if (key >= start_id) { + key += offset; + } + auto value = x.second; + if (value.first >= start_id) { + value.first += offset; + } + pmap->stage_to_attach_iter.insert(std::make_pair(key, value)); + } + for (const auto& x : operator->()->iter_to_attached_stages) { + auto key = x.first; + if (key.first >= start_id) { + key.first += offset; + } + auto value = x.second; + for (auto& i : value) { + if (i >= start_id) { + i += offset; + } + } + pmap->iter_to_attached_stages.insert(std::make_pair(key, value)); + } + return map; +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { + auto* node = static_cast(ref.get()); + PrintState(&p->stream, node, true); +}); + +} // namespace ansor +} // namespace tvm diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h new file mode 100644 index 000000000000..3ffe8a7feafb --- /dev/null +++ b/src/ansor/loop_state.h @@ -0,0 +1,732 @@ +/*! + * Copyright (c) 2020 by Contributors + * \file ansor/interfaces.h + * \brief Data structures for loop transformations + + * Basically this is a simplified TVM IR with schedule primitives. + * We don't use the existing TVM IR because + * 1. We want fast incremental change to the loop structures + * 2. We want serializable history for replay and backtracking + * 3. We want simplified IR for easy and clean feature extraction + * 4. We may create some Macro schedule primitives + + * After search is done, we will lower this IR to TVM IR and TVM schedule primitives. + * Because we share a lot common objects during search, the transformation is + * implemented in copy on write style. All objects are immutable, which is + * similar to TVM IR. + */ + +#ifndef TVM_ANSOR_LOOP_STATE_H_ +#define TVM_ANSOR_LOOP_STATE_H_ + +// #include +// #include +// #include +#include +#include +#include +#include +#include +#include +#include "expr_hasher.h" +#include "utils.h" +#include "compute_dag.h" + +namespace tvm { +namespace ansor { + +using namespace tvm::tir; + +enum IteratorType { + kSpace, // spatial iterator + kReduce, // reduction iterator + kMixed, // fused spatial and reduction iterator + kSpecial // special iterator (e.g. 
virtual root iterator)
+};
+
+enum IteratorAnnotation {
+  kNone, kUnroll, kVectorize, kParallel,
+  kVThread, kBlockX, kThreadX, kBlockY, kThreadY
+};
+
+enum StageType {
+  kPlaceholder, kCompute
+};
+
+enum ComputeAtType {
+  kRoot,     // compute at root
+  kInlined,  // inlined
+  kIter,     // compute at some iterator
+};
+
+/* Iterator and Stage */
+class Iterator; class Stage; class State;
+
+/*!
+ * \brief A for-loop iterator
+ * Similar to tvm::IterVar in `include/expr.h`
+ */
+class IteratorNode : public Object {
+ public:
+  std::string name;
+  Range range;  // domain of the for-loop range
+  IteratorType iter_type;
+  IteratorAnnotation annotation;
+  std::vector<Iterator> ori_iters;
+
+  static Iterator make(std::string name, Range range,
+                       IteratorType iter_type, IteratorAnnotation annotation,
+                       const std::vector<Iterator>* ori_iters = nullptr);
+
+  static constexpr const char *_type_key = "ansor.Iterator";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object);
+};
+TVM_DEFINE_COW_NODE_REF(Iterator, ObjectRef, IteratorNode);
+
+/*!
+ * \brief A stage in the compute declaration
+ * Similar to te::Stage in `include/schedule.h`
+ */
+class StageNode : public Object {
+ public:
+  te::Operation op;
+  StageType op_type;
+  std::vector<Iterator> iters;
+  ComputeAtType compute_at;
+  int16_t auto_unroll_max_step;
+  int storage_offset;
+
+  static Stage make(te::Operation op);
+  static Stage make(te::Operation op, StageType op_type, const std::vector<Iterator>& iters,
+                    ComputeAtType compute_at, int16_t auto_unroll_max_step, int storage_offset);
+  static Stage make(te::Operation op, StageType op_type, std::vector<Iterator>&& iters,
+                    ComputeAtType compute_at, int16_t auto_unroll_max_step, int storage_offset);
+
+  static constexpr const char *_type_key = "ansor.Stage";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object);
+};
+TVM_DEFINE_COW_NODE_REF(Stage, ObjectRef, StageNode);
+
+
+/*! \brief The base class for a transformation step */
+class StepNode: public Object {
+ public:
+  int stage_id;
+
+  // Print the step as an equivalent python schedule API call
+  virtual std::string PrintAsPythonAPI(std::vector<te::Stage> *stages,
+                                       StageToAxesMap *stage_to_axes,
+                                       te::Schedule *schedule,
+                                       const std::vector<Step>& transform_steps) const = 0;
+
+  static constexpr const char* _type_key = "ansor.Step";
+  TVM_DECLARE_BASE_OBJECT_INFO(StepNode, Object);
+};
+TVM_DEFINE_MUTABLE_NODE_REF(Step, StepNode);
+
+/*
+ * Note on how to add a new transform step
+ *
+ * Take fuse for example:
+ * 1. Define class FuseStepNode and FuseStep in loop_state.h, and implement the make function
+ *    FuseStepNode::make(...) in loop_state.cc
+ * 2. Implement FuseStepNode::ApplyToSchedule and FuseStepNode::PrintAsPythonAPI.
+ *    - In these two functions you need to lower this step with tvm's schedule API
+ * 3. Implement State::fuse and State::DoFuseStep.
+ *    - In these two functions you need to incrementally update all data structures in State with
+ *      CopyOnWrite style
+ * 4. Add your step to ComputeDAG::ReplaySteps and make sure it works.
+ * 5. Add serialization support in `struct Handler<std::vector<::tvm::ansor::Step> >`
+ *    (in serialization.cc)
+ * 6.
Add hash support in `struct hash<::tvm::ansor::Step>` (search for this function in this file) + */ + +class ReorderStep; class SplitStep; class FollowSplitStep; +class FollowFusedSplitStep; +class FuseStep; class AnnotationStep; +class ComputeAtStep; class ComputeRootStep; class ComputeInlineStep; +class PackForVecStep; class CacheReadStep; class CacheWriteStep; +class PragmaStep; class RfactorStep; class StorageAlignStep; +class AttachMap; + +class ReorderStepNode: public StepNode { + public: + std::vector after_ids; + + static ReorderStep make(int stage_id, const std::vector& after_ids); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.ReorderStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(ReorderStep, Step, ReorderStepNode); + + +class SplitStepNode: public StepNode { + public: + int iter_id; + PrimExpr extent; // the extent of the axis to split + std::vector lengths; // The split factors + bool inner_to_outer; + + static SplitStep make(int stage_id, int iter_id, PrimExpr extent, + const std::vector& lengths, + bool inner_to_outer); + + std::vector ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.SplitStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(SplitStep, Step, SplitStepNode); + +// Similar to SplitStepNode, but use split factor from another step(i.e. Follow another split step) +class FollowSplitStepNode: public StepNode { + public: + int iter_id; + int src_step_id; + int n_split; + + static FollowSplitStep make(int stage_id, int iter_id, + int src_step_id, int n_split); + + void ExtractSplitLengths(const std::vector& transform_steps, + std::vector* lengths) const; + + std::vector ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, + const std::vector& transform_steps) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.FollowSplitStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(FollowSplitStep, Step, FollowSplitStepNode); + + +// Similar to FollowSplitStep, but use split factors from multiple steps +// This can be used for the split in cooperative fetching. +class FollowFusedSplitStepNode: public StepNode { + public: + int iter_id; + std::vector src_step_ids; + int level; // Use the length in this split level + bool factor_or_nparts; // If this is true, use factor. 
Otherwise, use nparts + + static FollowFusedSplitStep make(int stage_id, int iter_id, + const std::vector& src_step_ids, int level, bool factor_or_nparts); + + PrimExpr ExtractSplitLength(const std::vector& transform_steps) const; + + std::vector ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, + const std::vector& transform_steps) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.FollowFusedSplitStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(FollowFusedSplitStep, Step, FollowFusedSplitStepNode); + + +class FuseStepNode: public StepNode { + public: + std::vector fused_ids; + + static FuseStep make(int stage_id, const std::vector& fused_ids); + + IterVar ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.FuseStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(FuseStep, Step, FuseStepNode); + + +class AnnotationStepNode: public StepNode { + public: + int iter_id; + IteratorAnnotation annotation; + + static AnnotationStep make(int stage_id, int iter_id, IteratorAnnotation ann); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.AnnotationStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(AnnotationStep, Step, AnnotationStepNode); + + +class ComputeAtStepNode: public StepNode { + public: + int target_stage_id; + int target_iter_id; + + static ComputeAtStep make(int stage_id, int target_stage_id, int target_iter_id); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.ComputeAtStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(ComputeAtStep, Step, ComputeAtStepNode); + + +class ComputeRootStepNode: public StepNode { + public: + static ComputeRootStep make(int stage_id); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.ComputeRootStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(ComputeRootStep, Step, ComputeRootStepNode); + + +class ComputeInlineStepNode: public StepNode { + public: + static ComputeInlineStep make(int stage_id); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr 
const char* _type_key = "ansor.ComputeInlineStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(ComputeInlineStep, Step, ComputeInlineStepNode); + +class PackForVecStepNode: public StepNode { + public: + int iter_id; + int vec_size; + + static PackForVecStep make(int stage_id, int iter_id, int vec_size); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, te::Schedule *schedule) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.PackForVecStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(PackForVecStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(PackForVecStep, Step, PackForVecStepNode); + + +/*! \brief Apply cache_read to a stage + * TVM Api: te::Schedule::cache_read(tensor, scope, readers) */ +class CacheReadStepNode: public StepNode { + public: + std::string scope_name; + std::vector reader_stage_ids; + + static CacheReadStep make(int stage_id, std::string scope_name, + const std::vector& reader_stage_id); + + te::Tensor ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, te::Schedule *schedule) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.CacheReadStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(CacheReadStep, Step, CacheReadStepNode); + + +/*! \brief Apply cache_write to a stage + * TVM Api: te::Schedule::cache_write(tensor, scope) + * This step will cache_write all output tensors of target stage */ +class CacheWriteStepNode: public StepNode { + public: + std::string scope_name; + + static CacheWriteStep make(int stage_id, std::string scope_name); + + Array ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, te::Schedule *schedule) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.CacheWriteStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(CacheWriteStep, Step, CacheWriteStepNode); + +/*! \brief Add pragma to a specific iterator */ +class PragmaStepNode: public StepNode { + public: + int iter_id; + std::string pragma_type; + + static PragmaStep make(int stage_id, int iter_id, std::string pragma_type); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.PragmaStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(PragmaStep, Step, PragmaStepNode); + +/*! 
\brief Factor a reduction axis + * TVM Api: te::Schedule::rfactor(tensor, axis, factor_axis) */ +class RfactorStepNode: public StepNode { + public: + int iter_id; + int factor_iter_id; + + static RfactorStep make(int stage_id, int iter_id, int factor_iter_id); + + Array ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.RfactorStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(RfactorStep, Step, RfactorStepNode); + +class StorageAlignStepNode: public StepNode { + public: + int iter_id; + int factor; + int offset; + + static StorageAlignStep make(int stage_id, int iter_id, int factor, + int offset); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.StorageAlignStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(StorageAlignStep, Step, StorageAlignStepNode); + +/*! \brief stores the compute_at relation between stages */ +class AttachMapNode: public Object { + public: + using StageKey = int; + using IterKey = std::pair; // stage_id and iter_id + + std::unordered_map stage_to_attach_iter; + std::unordered_map> iter_to_attached_stages; + + static AttachMap make(); + + static constexpr const char* _type_key = "ansor.AttachMap"; + TVM_DECLARE_FINAL_OBJECT_INFO(AttachMapNode, Object); +}; + +/*! \brief stores the compute_at relation between stages + * This stores a bi-directional mapping from stages and iter: + * 1. Stage to its attached iterator 2. Iterator to the stage attached to it + * + * You can use AttachMapNode::stage_to_attach_iter and AttachMapNode::iter_to_attached_stages + * to query the relations */ +class AttachMap : public ObjectRef { + public: + using StageKey = int; + using IterKey = std::pair; // stage_id and iter_id + + void SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id); + void DeleteStage(int stage_id); + void ReplaceIters(const std::vector& old_iters, + const std::vector& new_iters); + AttachMap ApplyStageIdOfffset(int start_id, int offset) const; + + TVM_DEFINE_OBJECT_REF_METHODS(AttachMap, ObjectRef, AttachMapNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(AttachMapNode); + + private: + static void DeleteStageEntry(AttachMapNode* pnode, int stage_id); +}; + +/*! \brief The loop state and corresponding history steps to reach this state */ +class StateNode: public Object { + public: + std::vector stages; // Current stages and loop structures + std::vector transform_steps; // History transformation steps to reach this state + bool complete; // Indicate whether this state has unfilled tile sizes + AttachMap attach_map; // stores the compute_at relation between stages + ObjectRef aux_info; // Used to store any auxiliary info about this state + ComputeDAG task_dag; // The up-to-date ComputeDAG of this state. 
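+                            // The default value is an empty NodeRef
+                            // (means no modification to the DAG)
+
+  /*
+   * A minimal sketch (hypothetical stage and iterator ids) of the two maps
+   * behind the attach_map field above:
+   *
+   *   AttachMap attach_map = AttachMapNode::make();
+   *   attach_map.SetComputeAtIter(1, 2, 0);  // stage 1 -> iterator 0 of stage 2
+   *   // attach_map->stage_to_attach_iter[1]         == {2, 0}
+   *   // attach_map->iter_to_attached_stages[{2, 0}] == {1}
+   *   attach_map.DeleteStage(1);  // removes both directions at once
+   */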
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("complete", &complete);
+    v->Visit("aux_info", &aux_info);
+  }
+
+  static State make_empty_state();
+  static State make(const Array<te::Operation>& ops);
+  static State make(const std::vector<Stage>& stages,
+                    const std::vector<Step>& transform_steps, bool complete,
+                    ObjectRef aux_info);
+
+  static constexpr const char* _type_key = "ansor.State";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StateNode, Object);
+};
+
+/*! \brief The loop state and corresponding history steps to reach this state */
+class State : public ObjectRef {
+ public:
+  // Schedule primitives
+  void reorder(int stage_id, const std::vector<Iterator>& order);
+  std::vector<Iterator> split(int stage_id, const Iterator& it,
+                              const std::vector<PrimExpr>& lengths,
+                              bool inner_to_outer = true);
+  std::vector<Iterator> follow_split(int stage_id, const Iterator& it,
+                                     int src_step_id, int n_split);
+  std::vector<Iterator> follow_fused_split(int stage_id, const Iterator& it,
+      const std::vector<int>& src_step_ids, int level, bool factor_or_nparts);
+  Iterator fuse(int stage_id, const std::vector<Iterator>& iters);
+  Iterator vectorize(int stage_id, const Iterator& it);
+  Iterator parallel(int stage_id, const Iterator& it);
+  Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1);
+  // Valid thread_type values: kVThread, kBlockX, kThreadX, kThreadY
+  Iterator bind_thread(int stage_id, const Iterator& it,
+                       IteratorAnnotation thread_type);
+  void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter);
+  void compute_root(int stage_id);
+  void compute_inline(int stage_id);
+  void pack_for_vec(int stage_id, const Iterator& target_iter, int vec_size);
+  int cache_read(int stage_id, const std::string& scope_name,
+                 const std::vector<int>& reader_stage_ids,
+                 const ComputeDAG& task_dag);
+  int cache_write(int stage_id, const std::string& scope_name,
+                  const ComputeDAG& task_dag);
+  void pragma(int stage_id, const Iterator& it, const std::string& pragma_type);
+  int rfactor(int stage_id, const Iterator& it, int factor_iter_id,
+              const ComputeDAG& task_dag);
+  void storage_align(int stage_id, const Iterator& it, int factor, int offset);
+
+  /* We separate these functions out so that they can be called directly when replaying history steps */
+  void DoReorderStep(const ReorderStep& step);
+  std::vector<Iterator> DoSplitStep(const SplitStep& step);
+  std::vector<Iterator> DoFollowSplitStep(const FollowSplitStep& step);
+  std::vector<Iterator> DoFollowFusedSplitStep(const FollowFusedSplitStep& step);
+  Iterator DoFuseStep(const FuseStep& step);
+  Iterator DoAnnotationStep(const AnnotationStep& step);
+  void DoComputeAtStep(const ComputeAtStep& step);
+  void DoComputeRootStep(const ComputeRootStep& step);
+  void DoComputeInlineStep(const ComputeInlineStep& step);
+  void DoPackForVecStep(const PackForVecStep& step);
+  int DoCacheReadStep(const CacheReadStep& step, const ComputeDAG& dag);
+  int DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag);
+  void DoPragmaStep(const PragmaStep& step);
+  int DoRfactorStep(const RfactorStep& step, const ComputeDAG& dag);
+  void DoStorageAlignStep(const StorageAlignStep& step);
+
+  /* Do transform steps
+   * Note: The following functions only change the loop state; they do not change the transform history.
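+   *
+   * A short replay sketch with hypothetical indices (it assumes the dag has
+   * at least three stages and that stage 2 carries a splittable iterator):
+   *
+   *   State state = dag.GetInitState();
+   *   state.split(2, state->stages[2]->iters[0], {PrimExpr(8)});  // appends a SplitStep
+   *   State fresh = dag.GetInitState();
+   *   fresh.DoSteps(state->transform_steps, dag);  // re-applies the recorded history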
+   */
+  void DoStep(const Step& step, const ComputeDAG& dag);
+  void DoSteps(const std::vector<Step>& steps, const ComputeDAG& dag);
+
+  // Print the state to a human-readable string
+  std::string ToStr(bool delete_trivial_loop = true) const;
+
+  TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode);
+
+ private:
+  // Common implementation shared by DoSplitStep and DoFollowSplitStep
+  std::vector<Iterator> DoSplitStepCommon(int stage_id, int iter_id,
+                                          const std::vector<PrimExpr>& lengths,
+                                          bool inner_to_outer);
+};
+
+}  // namespace ansor
+}  // namespace tvm
+
+
+// Hash and equal functions for State, Stage, Iterator and Step
+namespace std {
+
+template <>
+struct hash<::tvm::ansor::Step> {
+  std::size_t operator()(const ::tvm::ansor::Step& step) const {
+    if (auto ps = step.as<::tvm::ansor::ReorderStepNode>()) {
+      return ::dmlc::HashCombine(1,
+          ::dmlc::HashCombine(std::hash<int>()(ps->stage_id),
+                              ps->after_ids));
+    } else if (auto ps = step.as<::tvm::ansor::SplitStepNode>()) {
+      size_t ret = ::dmlc::HashCombine(2,
+          ::dmlc::HashCombine(std::hash<int>()(ps->stage_id),
+          ::dmlc::HashCombine(std::hash<int>()(ps->iter_id),
+                              ps->inner_to_outer)));
+      for (const auto& len : ps->lengths) {
+        if (len.defined()) {
+          auto pint = len.as<::tvm::tir::IntImmNode>();
+          CHECK(pint != nullptr);
+          ret = ::dmlc::HashCombine(ret, pint->value);
+        } else {
+          ret = ::dmlc::HashCombine(ret, 0x5D);  // a magic number for undefined lengths
+        }
+      }
+      return ret;
+    } else if (auto ps = step.as<::tvm::ansor::FollowSplitStepNode>()) {
+      return ::dmlc::HashCombine(3,
+          ::dmlc::HashCombine(std::hash<int>()(ps->stage_id),
+          ::dmlc::HashCombine(std::hash<int>()(ps->iter_id),
+          ::dmlc::HashCombine(std::hash<int>()(ps->src_step_id),
+                              ps->n_split))));
+    } else if (auto ps = step.as<::tvm::ansor::FollowFusedSplitStepNode>()) {
+      return ::dmlc::HashCombine(4,
+          ::dmlc::HashCombine(std::hash<int>()(ps->stage_id),
+          ::dmlc::HashCombine(std::hash<int>()(ps->iter_id),
+          ::dmlc::HashCombine(std::hash<std::vector<int>>()(ps->src_step_ids),
+          ::dmlc::HashCombine(std::hash<int>()(ps->level),
+                              ps->factor_or_nparts)))));
+    } else if (auto ps = step.as<::tvm::ansor::FuseStepNode>()) {
+      return ::dmlc::HashCombine(5,
+          ::dmlc::HashCombine(std::hash<int>()(ps->stage_id),
+                              ps->fused_ids));
+    } else if (auto ps = step.as<::tvm::ansor::AnnotationStepNode>()) {
+      return ::dmlc::HashCombine(6,
+          ::dmlc::HashCombine(std::hash<int>()(ps->stage_id),
+          ::dmlc::HashCombine(std::hash<int>()(ps->iter_id),
+                              static_cast<int>(ps->annotation))));
+    } else if (auto ps = step.as<::tvm::ansor::ComputeAtStepNode>()) {
+      return ::dmlc::HashCombine(7,
+          ::dmlc::HashCombine(std::hash<int>()(ps->stage_id),
+          ::dmlc::HashCombine(std::hash<int>()(ps->target_stage_id),
+                              ps->target_iter_id)));
+    } else if (auto ps = step.as<::tvm::ansor::ComputeRootStepNode>()) {
+      return ::dmlc::HashCombine(8,
+                                 ps->stage_id);
+    } else if (auto ps = step.as<::tvm::ansor::ComputeInlineStepNode>()) {
+      return ::dmlc::HashCombine(9,
+                                 ps->stage_id);
+    } else if (auto ps = step.as<::tvm::ansor::PackForVecStepNode>()) {
+      return ::dmlc::HashCombine(10,
+          ::dmlc::HashCombine(std::hash<int>()(ps->stage_id),
+          ::dmlc::HashCombine(std::hash<int>()(ps->iter_id),
+                              ps->vec_size)));
+    } else if (auto ps = step.as<::tvm::ansor::CacheReadStepNode>()) {
+      return ::dmlc::HashCombine(11,
+          ::dmlc::HashCombine(std::hash<int>()(ps->stage_id),
+          ::dmlc::HashCombine(std::hash<std::string>()(ps->scope_name),
+                              ps->reader_stage_ids)));
+    } else if (auto ps = step.as<::tvm::ansor::CacheWriteStepNode>()) {
+      return ::dmlc::HashCombine(12,
+          ::dmlc::HashCombine(std::hash<int>()(ps->stage_id),
+                              ps->scope_name));
+    } else if (auto ps =
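+        /*
+         * Each branch folds a distinct leading tag (1 for Reorder, 2 for
+         * Split, ...) into the combined hash, so two different step kinds
+         * with identical fields still land on different values. A sketch
+         * (hypothetical stage id):
+         *
+         *   Step a = ComputeRootStepNode::make(5);    // tagged with 8
+         *   Step b = ComputeInlineStepNode::make(5);  // tagged with 9
+         *   // hash<Step>()(a) and hash<Step>()(b) differ although both
+         *   // steps only carry stage_id == 5
+         */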
step.as<::tvm::ansor::PragmaStepNode>()) { + return ::dmlc::HashCombine(13, + ::dmlc::HashCombine(std::hash()(ps->stage_id), + ::dmlc::HashCombine(std::hash()(ps->iter_id), + ps->pragma_type))); + } else if (auto ps = step.as<::tvm::ansor::RfactorStepNode>()) { + return ::dmlc::HashCombine(14, + ::dmlc::HashCombine(std::hash()(ps->stage_id), + ::dmlc::HashCombine(std::hash()(ps->iter_id), + ps->factor_iter_id))); + } else if (auto ps = step.as<::tvm::ansor::StorageAlignStepNode>()) { + return ::dmlc::HashCombine(15, + ::dmlc::HashCombine(std::hash()(ps->stage_id), + ::dmlc::HashCombine(std::hash()(ps->iter_id), + ::dmlc::HashCombine(std::hash()(ps->factor), + ps->offset)))); + } else { + LOG(FATAL) << "Invalid step"; + } + return 0; + } +}; + +template <> +struct hash<::tvm::ansor::State> { + std::size_t operator()(const ::tvm::ansor::State& state) const { + return std::hash()(state.ToStr()); + } +}; + +template <> +struct equal_to<::tvm::ansor::State> { + bool operator() (const ::tvm::ansor::State& lhs, + const ::tvm::ansor::State& rhs) const { + return lhs.ToStr() == rhs.ToStr(); + } +}; + +} // namespace std + +#endif // TVM_ANSOR_LOOP_STATE_H_ diff --git a/src/ansor/utils.cc b/src/ansor/utils.cc new file mode 100644 index 000000000000..2018cf33d1a2 --- /dev/null +++ b/src/ansor/utils.cc @@ -0,0 +1,102 @@ +/*! + * Copyright (c) 2020 by Contributors + */ + +#include "utils.h" +#include + +namespace tvm { +namespace ansor { + + +NullStream& NullStream::Global() { + static NullStream stream; + return stream; +} + +const std::vector >& SplitFactorizationMemo::GetFactorizationSchemes( + int extent, int n_lengths, int max_innermost_factor) { + QueryKey key = std::make_tuple(extent, n_lengths, max_innermost_factor); + auto it = memory_.find(key); + if (it != memory_.end()) { + return it->second; + } + + tmp_stack_.assign(n_lengths, PrimExpr()); + results_ = &memory_[key]; + n_lengths_ = n_lengths; + + DfsEnumerate(0, extent, max_innermost_factor); + + return *results_; +} + +void SplitFactorizationMemo::DfsEnumerate(int now, int remaining_lenght, int max_innermost_factor) { + if (now == n_lengths_) { + if (tmp_stack_.back().as()->value <= max_innermost_factor) { + results_->push_back(tmp_stack_); + } + } else { + for (const auto& f : GetFactors(remaining_lenght)) { + tmp_stack_[now] = PrimExpr(f); + DfsEnumerate(now + 1, remaining_lenght / f, max_innermost_factor); + } + } +} + +const std::vector& SplitFactorizationMemo::GetFactors(int n) { + auto it = factor_memory_.find(n); + if (it != factor_memory_.end()) { + return it->second; + } + + std::vector& res = factor_memory_[n]; + int step = n % 2 == 0 ? 
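+      /*
+       * (For odd n only odd divisors exist, so the loop below may step by 2
+       * and skip even candidates.) A worked example of the two helpers, with
+       * assumed inputs: GetFactors(12) probes i = 1, 2, 3 and collects each
+       * i together with 12 / i, returning {1, 2, 3, 4, 6, 12};
+       * GetFactorizationSchemes(4, 2, 2) then enumerates every tuple of
+       * factors whose product divides 4, keeping only schemes whose last
+       * (innermost) length is <= max_innermost_factor:
+       *
+       *   SplitFactorizationMemo memo;
+       *   memo.GetFactorizationSchemes(4, 2, 2);
+       *   // -> {1,1}, {1,2}, {2,1}, {2,2}, {4,1}
+       */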
1 : 2; + for (size_t i = 1; i < static_cast(std::sqrt(n)) + 1; i += step) { + if (n % i == 0) { + res.push_back(i); + if (n / i != i) { + res.push_back(n/i); + } + } + } + std::sort(res.begin(), res.end()); + return res; +} + +ThreadPool& ThreadPool::Global() { + static ThreadPool* pool = new ThreadPool(); + static int ct = 0; + + ct = (ct + 1) % ThreadPool::REFRESH_EVERY; + + if (ct == 0) { + pool->Abort(); + delete pool; + pool = new ThreadPool(); + } + + if (pool->NumWorkers() == 0) { + pool->Launch(std::thread::hardware_concurrency()); + } + + return *pool; +} + +TVM_REGISTER_GLOBAL("ansor.utils.GetFactorizationSchemes") +.set_body([](TVMArgs args, TVMRetValue *ret) { + int extent = args[0]; + int n_lengths = args[1]; + int max_innermost_factor = args[2]; + SplitFactorizationMemo memo; + + Array > result; + for (const auto& lens : memo.GetFactorizationSchemes(extent, n_lengths, max_innermost_factor)) { + result.push_back(lens); + } + + *ret = result; +}); + +} // namespace ansor +} // namespace tvm diff --git a/src/ansor/utils.h b/src/ansor/utils.h new file mode 100644 index 000000000000..4ea7f283ad09 --- /dev/null +++ b/src/ansor/utils.h @@ -0,0 +1,482 @@ +/*! + * Copyright (c) 2020 by Contributors + * \file ansor/utils.h + * \brief Common utilities + */ + +#ifndef TVM_ANSOR_UTILS_H_ +#define TVM_ANSOR_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace std { + +// hash function for std::pair, std::vector and std::tuple +template +struct hash > { + std::size_t operator()(const std::pair& k) const { + return ::dmlc::HashCombine(std::hash()(k.first), std::hash()(k.second)); + } +}; + +template +struct hash > { + std::size_t operator()(const std::tuple& k) const { + return ::dmlc::HashCombine( + ::dmlc::HashCombine(std::hash()(std::get<0>(k)), std::hash()(std::get<1>(k))), + std::hash()(std::get<2>(k))); + } +}; + +template +struct hash > { + std::size_t operator()(const std::vector& vec) const { + if (vec.empty()) { + return 0; + } + std::size_t ret = std::hash()(vec[0]); + for (size_t i = 1; i < vec.size(); ++i) { + ret = ::dmlc::HashCombine(ret, std::hash()(vec[i])); + } + return ret; + } +}; + +} // namespace std + +namespace tvm { +namespace ansor { + +/*! \brief Macro to make it easy to define mutable node ref type given node */ +#define TVM_DEFINE_MUTABLE_NODE_REF(TypeName, NodeName) \ + class TypeName : public ObjectRef { \ + public: \ + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ObjectRef, NodeName); \ + }; \ + +/*! + * \brief Macro to make it easy to define node ref type that + * has a CopyOnWrite member function. + */ +#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \ + class TypeName : public BaseType { \ + public: \ + TVM_DEFINE_OBJECT_REF_METHODS(TypeName, BaseType, NodeName); \ + TVM_DEFINE_OBJECT_REF_COW_METHOD(NodeName); \ + }; + +/********** Utilities for std::vector, std::set **********/ + +/*! \brief Get the first appearance index of elements in a vector */ +template +inline void GetIndices(const std::vector& array, + const std::vector& to_locate, + std::vector* indices) { + for (const auto& v : to_locate) { + auto it = std::find(array.begin(), array.end(), v); + if (it != array.end()) { + indices->push_back(it - array.begin()); + } else { + LOG(FATAL) << "Cannot find the item"; + } + } +} + +/*! 
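+ */
+
+/*
+ * A quick usage sketch (assumed workflow) for the ThreadPool declared later
+ * in this header; ThreadPool::Global() above hands out the shared instance:
+ *
+ *   ThreadPool& pool = ThreadPool::Global();
+ *   pool.BeginBatch(2);
+ *   pool.Enqueue([](int x) { return x * x; }, 3);
+ *   pool.Enqueue([](int x) { return x + 1; }, 4);
+ *   pool.WaitBatch();  // returns once both enqueued tasks have finished
+ */
+
+/*!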
\brief Get the first appearance index of an element in a vector */ +template +inline int GetIndex(const std::vector& array, const T& to_locate) { + for (size_t i = 0; i < array.size(); ++i) { + if (array[i] == to_locate) { + return i; + } + } + LOG(FATAL) << "Cannot find the item"; + return -1; +} + +/*! \brief Delete an element in a vector */ +template +inline void DeleteItem(std::vector* array, const T& to_delete) { + auto iter = std::find(array->begin(), array->end(), to_delete); + if (iter != array->end()) { + array->erase(iter); + } +} + +/*! \brief Compute the product of all elements in a vector */ +inline int64_t ElementProduct(const std::vector& array) { + int64_t ret = 1; + for (auto x : array) { + ret *= x; + } + return ret; +} + +/* \brief Get the maximum element in a vector */ +template +T MaximumElement(const std::vector& array) { + CHECK(!array.empty()); + const T* pmax = &array[0]; + for (size_t i = 1; i < array.size(); ++i) { + if (array[i] > *pmax) { + pmax = &array[i]; + } + } + return *pmax; +} + +/*! \brief Move elements from multiple vectors to one vector */ +template +std::vector& ConcatenateMove(std::vector* out, std::vector* in) { + out->insert(out->end(), std::make_move_iterator(in->begin()), + std::make_move_iterator(in->end())); + return *out; +} + +/*! \brief Move elements from multiple vectors to one vector */ +template +std::vector& ConcatenateMove(std::vector* out, std::vector* first, Args... args) { + ConcatenateMove(out, first); + ConcatenateMove(out, args...); + return *out; +} + +/* \brief Get a random permutation of integers [0, n-1] */ +template +void RandomPermutation(int n, std::vector* out, G* gen) { + out->assign(n, 0); + std::iota(out->begin(), out->end(), 0); + std::shuffle(out->begin(), out->end(), *gen); +} + +/* \brief Random sample without replacement */ +template +void RandomSample(std::vector* in_data, size_t out_size, G* gen) { + // Note: This function is inefficient in the cases when out_size << in_data.size() + out_size = std::min(in_data->size(), out_size); + + if (in_data->size() <= out_size) { // return all + return; + } + std::vector indices; + RandomPermutation(in_data->size(), &indices, gen); + + std::vector tmp_data; + tmp_data.reserve(out_size); + for (size_t i = 0; i < out_size; ++i) { + tmp_data.push_back(std::move((*in_data)[indices[i]])); + } + + *in_data = std::move(tmp_data); +} + +/*! \brief Argsort. Order: largest to smallest */ +template +inline void Argsort(const std::vector& scores, std::vector* index) { + index->clear(); index->reserve(scores.size()); + for (size_t i = 0; i < scores.size(); ++i) { + index->push_back(i); + } + auto cmp = [&scores](int l, int r) { + return scores[l] > scores[r]; + }; + std::sort(index->begin(), index->end(), cmp); +} + +// Do x++ for all x in the set such that x >= threshold +inline void SetAddOne(std::set* set, int threshold = 0) { + std::set new_set; + for (int x : *set) { + if (x >= threshold) { + new_set.insert(x + 1); + } else { + new_set.insert(x); + } + } + *set = std::move(new_set); +} + +// Compute Jaccard Similarity of two sets +template +double JaccardSimilarity(std::set s1, std::set s2) { + std::vector intersect; + std::set_intersection(s1.begin(), s1.end(), s2.begin(), s2.end(), + std::back_inserter(intersect)); + return 1.0 * intersect.size() / (s1.size() + s2.size() - intersect.size()); +} + +/********** Utilities for std::string **********/ + +/*! 
Return whether a string ends with another substring */
+inline bool StrEndsWith(const std::string& a, const std::string& b) {
+  if (b.size() > a.size()) return false;
+  return std::equal(a.begin() + a.size() - b.size(), a.end(), b.begin());
+}
+
+/*! Return whether a string starts with another substring */
+inline bool StrStartsWith(const std::string& a, const std::string& b) {
+  if (b.size() > a.size()) return false;
+  return std::equal(a.begin(), a.begin() + b.size(), b.begin());
+}
+
+/*! Replace a sub-string with another sub-string in a string */
+inline void StrReplace(std::string* base, const std::string& from, const std::string& to) {
+  auto pos = base->find(from);
+  while (pos != std::string::npos) {
+    base->replace(pos, from.size(), to);
+    pos = base->find(from, pos + to.size());
+  }
+}
+
+/********** Utilities for TVM Containers / ByteArray **********/
+
+/*! \brief Compute mean of a FloatImm array */
+inline double FloatArrayMean(const Array<PrimExpr>& float_array) {
+  double sum = 0;
+  if (float_array.empty()) {
+    return 0.0;
+  }
+
+  for (const auto& x : float_array) {
+    auto floatimm = x.as<FloatImmNode>();
+    CHECK(floatimm != nullptr);
+    sum += floatimm->value;
+  }
+  return sum / float_array.size();
+}
+
+/*! \brief Serialize a 2-dimensional vector to TVMByteArray.
+ * This is used for sending data to python code */
+template <typename T>
+inline TVMByteArray Serialize2dVector(std::vector<std::vector<T>>&& in_data,
+                                      std::vector<char>* out_data) {
+  size_t total_bytes = 0;
+  std::vector<int> size_vector;
+
+  // serialize sizes
+  total_bytes += (1 + in_data.size()) * sizeof(int);
+  size_vector.reserve(in_data.size() + 1);
+  size_vector.push_back(in_data.size());
+  for (const auto& x : in_data) {
+    size_vector.push_back(static_cast<int>(x.size()));
+    total_bytes += sizeof(T) * x.size();
+  }
+
+  // resize (not just reserve), so that writing through data() stays within
+  // the vector's valid range
+  out_data->resize(total_bytes);
+  char* ptr = out_data->data();
+  memmove(ptr, reinterpret_cast<char*>(size_vector.data()), (1 + in_data.size()) * sizeof(int));
+  ptr += (1 + in_data.size()) * sizeof(int);
+
+  // serialize in_data
+  for (auto& x : in_data) {
+    memmove(ptr, x.data(), sizeof(T) * x.size());
+    ptr += sizeof(T) * x.size();
+    x.clear();
+  }
+
+  CHECK_EQ(ptr - out_data->data(), total_bytes);
+
+  return TVMByteArray{out_data->data(), total_bytes};
+}
+
+/********** Other Utilities **********/
+
+// Get an int value from an Expr
+inline int64_t GetIntImm(const PrimExpr& expr) {
+  auto pint = expr.as<IntImmNode>();
+  CHECK(pint != nullptr);
+  return pint->value;
+}
+
+
+// Compute the product of the lengths of axes
+inline int64_t AxisLengthProd(const Array<tir::IterVar>& axes) {
+  int64_t ret = 1;
+  for (const auto& x : axes) {
+    if (const IntImmNode* imm = x->dom->extent.as<IntImmNode>()) {
+      ret *= imm->value;
+    } else {
+      return -1;
+    }
+  }
+  return ret;
+}
+
+
+// An empty output stream
+class NullStream : public std::ostream {
+ public:
+  NullStream() : std::ostream(nullptr) {}
+  NullStream(const NullStream&) : std::ostream(nullptr) {}
+  static NullStream& Global();
+};
+
+template <typename T>
+NullStream& operator<<(NullStream& os, const T& value) {
+  return os;
+}
+
+/*! \brief Get std cout with verbose control */
+inline std::ostream& StdCout(int verbose) {
+  if (verbose >= 1) {
+    return std::cout;
+  } else {
+    return NullStream::Global();
+  }
+}
+
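+/*
+ * A sketch of the byte layout Serialize2dVector above produces, assuming
+ * T = float and in_data = {{1.0f, 2.0f}, {3.0f}}:
+ *
+ *   [int 2]                <- number of rows
+ *   [int 2] [int 1]        <- length of each row
+ *   [1.0f] [2.0f] [3.0f]   <- row payloads, back to back
+ *
+ * i.e. a header of (1 + rows) ints followed by the concatenated row data,
+ * which the receiving Python code can decode with matching reads.
+ */
+
+/*!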
\brief Print a title */ +inline void PrintTitle(const std::string& title, int verbose) { + if (verbose >= 1) { + std::cout << "------------------------------------------------------------" << "\n"; + std::cout << "----------------------- [ " << title << " ]\n"; + std::cout << "------------------------------------------------------------" << std::endl; + } +} + +/*! \brief A simple thread pool */ +class ThreadPool { + public: + void Launch(size_t n = 1) { + for (std::size_t i = 0; i < n; ++i) { + threads_.emplace_back([this] {WorkerFunc();}); + } + } + + void BeginBatch(int n) { + finish_ct_ = n; + is_finished_ = n <= 0; + } + + template::type> + std::future Enqueue(F&& f, Args&&... args) { + std::packaged_task p(std::bind(f, args...)); + + auto r = p.get_future(); + { + std::unique_lock l(m_); + work_.emplace_back(std::move(p)); + } + work_signal_.notify_one(); + return r; + } + + void WaitBatch() { + std::unique_lock l(finish_mutex_); + if (!is_finished_) { + finish_signal_.wait(l); + } + } + + void Abort() { + CancelPending(); + Join(); + } + + void CancelPending() { + std::unique_lock l(m_); + work_.clear(); + } + + void Join() { + { + std::unique_lock l(m_); + for (size_t i = 0; i < threads_.size(); ++i) { + work_.push_back({}); + } + } + work_signal_.notify_all(); + for (auto& t : threads_) { + t.join(); + } + threads_.clear(); + } + + size_t NumWorkers() { + return threads_.size(); + } + + static const int REFRESH_EVERY = 128; + static ThreadPool& Global(); + + ~ThreadPool() { + Join(); + } + + private: + void WorkerFunc() { + while (true) { + std::packaged_task f; + { + std::unique_lock l(m_); + if (work_.empty()) { + work_signal_.wait(l, [&]{ return !work_.empty(); }); + } + f = std::move(work_.front()); + work_.pop_front(); + } + if (!f.valid()) { return; } + f(); + + finish_ct_--; + if (finish_ct_ == 0) { + std::unique_lock l(finish_mutex_); + + is_finished_ = true; + finish_signal_.notify_one(); + } + } + } + + std::mutex m_; + std::condition_variable work_signal_; + std::deque> work_; + std::vector threads_; + + bool is_finished_; + std::mutex finish_mutex_; + std::atomic finish_ct_; + std::condition_variable finish_signal_; +}; + +/*! + * \brief Enumerate all possible factorization schemes for splitting an axes. + * \note This class will memorize the results for reuse. + */ +class SplitFactorizationMemo { + public: + using QueryKey = std::tuple; + + const std::vector >& GetFactorizationSchemes( + int extent, int n_lengths, int max_innermost_factor); + const std::vector& GetFactors(int n); + + private: + void DfsEnumerate(int now, int remaining_lenght, int max_innermost_factor); + + std::unordered_map > > memory_; + + int n_lengths_; + std::vector tmp_stack_; + std::vector >* results_; + std::unordered_map> factor_memory_; +}; + +} // namespace ansor +} // namespace tvm + +#endif // TVM_ANSOR_UTILS_H_ diff --git a/tests/cpp/ansor_test.cc b/tests/cpp/ansor_test.cc new file mode 100644 index 000000000000..b9a4f25023bf --- /dev/null +++ b/tests/cpp/ansor_test.cc @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <gtest/gtest.h>
+#include <topi/nn.h>
+
+#include <tvm/te/operation.h>
+#include <tvm/tir/expr.h>
+#include "../../src/ansor/compute_dag.h"
+
+tvm::Array<tvm::te::Tensor> matmul_func(int n, int m, int k) {
+  using namespace tvm;
+  using namespace tvm::te;
+
+  Tensor A = placeholder({n, k}, DataType::Float(32), "A");
+  Tensor B = placeholder({k, m}, DataType::Float(32), "B");
+  IterVar K = IterVarNode::make({0, k}, Var("k"), kCommReduce);
+  const auto& C = compute(
+      {n, m},
+      [&](Var i, Var j) { return tvm::sum(A[i][K] * B[K][j], {K}); },
+      "C");
+
+  return {A, B, C};
+}
+
+tvm::Array<tvm::te::Tensor> conv2d_nchw_bn_relu_func(int N, int H, int W,
+    int CI, int CO, int kernel_size, int strides, int padding,
+    int dilation = 1) {
+  using namespace tvm;
+  using namespace tvm::te;
+
+  Tensor data = placeholder({N, CI, H, W}, DataType::Float(32), "Data");
+  Tensor kernel = placeholder({CO, CI, kernel_size, kernel_size},
+                              DataType::Float(32), "Kernel");
+  Tensor bias = placeholder({CO, 1, 1}, DataType::Float(32), "Bias");
+  Tensor bn_scale = placeholder({CO, 1, 1}, DataType::Float(32), "Bn_scale");
+  Tensor bn_offset = placeholder({CO, 1, 1}, DataType::Float(32), "Bn_offset");
+
+  int OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) / strides + 1;
+  int OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) / strides + 1;
+
+  const auto& conv = topi::conv2d_nchw(data, kernel, strides, padding,
+                                       dilation);
+  const auto& bias_add = compute(
+      {N, CO, OH, OW},
+      [&](Var i, Var j, Var k, Var l) {
+        return conv[i][j][k][l] + bias[j][0][0];
+      },
+      "Bias_add");
+  const auto& bn_mul = compute(
+      {N, CO, OH, OW},
+      [&](Var i, Var j, Var k, Var l) {
+        return bias_add[i][j][k][l] * bn_scale[j][0][0];
+      },
+      "Bn_mul");
+  const auto& bn_add = compute(
+      {N, CO, OH, OW},
+      [&](Var i, Var j, Var k, Var l) {
+        return bn_mul[i][j][k][l] + bn_offset[j][0][0];
+      },
+      "Bn_add");
+  const auto& out = topi::relu(bn_add);
+
+  return {data, kernel, bias, bn_scale, bn_offset, out};
+}
+
+TEST(ComputeDAG, Basic) {
+  const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3);
+  auto dag = tvm::ansor::ComputeDAGNode::make(tensors);
+
+  LOG(INFO) << "\n" << dag;
+  LOG(INFO) << "\n" << dag->access_analyzer;
+}
+
+int main(int argc, char** argv) {
+  testing::InitGoogleTest(&argc, argv);
+  testing::FLAGS_gtest_death_test_style = "threadsafe";
+  return RUN_ALL_TESTS();
+}

From 9fcbf0bc6e3a0e985ad507dff868458a3124eb6f Mon Sep 17 00:00:00 2001
From: Chenfan
Date: Wed, 27 May 2020 18:06:42 +0800
Subject: [PATCH 02/45] Split transform_step out & Update more UTs (#3)

* Split transform_step out
* Update GetProducers & GetConsumers
* Update UTs
* Add UT for CacheReadWrite & Some bug fix
---
 src/ansor/compute_dag.cc    |  188 +++---
 src/ansor/compute_dag.h     |    2 +-
 src/ansor/loop_state.cc     | 1004 +++++------------------------------
 src/ansor/loop_state.h      |  547 +------------------
 src/ansor/transform_step.cc |  820 ++++++++++++++++++++++++++++
 src/ansor/transform_step.h  |  551 +++++++++++++++++++
 tests/cpp/ansor_test.cc     |  466 +++++++++++++++-
 7 files changed, 2074 insertions(+), 1504 deletions(-)
 create mode 100644 src/ansor/transform_step.cc
 create mode 100644 src/ansor/transform_step.h

diff --git
a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 31136985b330..e1ae3250d1a5 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -14,7 +14,7 @@ #include #include #include -// #include "loop_state.h" +#include "loop_state.h" #include "utils.h" // #include "../relay/pass/kernel_layout_transform.h" @@ -347,30 +347,30 @@ void AccessAnalyzer::GetProducers(const State& state, const te::Operation& op, } } -// void AccessAnalyzer::GetConsumers(const State& state, const te::Operation& op, -// OperationSet* consumers) const { -// OperationSet inlined_ops; +void AccessAnalyzer::GetConsumers(const State& state, const te::Operation& op, + OperationSet* consumers) const { + OperationSet inlined_ops; -// for (const auto& stage : state->stages) { -// if (stage->compute_at == kInlined) { -// inlined_ops.insert(stage->op); -// } -// } -// std::function collect; + for (const auto& stage : state->stages) { + if (stage->compute_at == kInlined) { + inlined_ops.insert(stage->op); + } + } + std::function collect; -// collect = [this, &collect, &inlined_ops, &consumers](const Operation& op) { -// for (const auto& iter : operator->()->read_by.at(op)) { -// if (inlined_ops.count(iter.first)) { -// collect(iter.first); -// } else { -// consumers->insert(iter.first); -// } -// } -// }; + collect = [this, &collect, &inlined_ops, &consumers](const te::Operation& op) { + for (const auto& iter : operator->()->read_by.at(op)) { + if (inlined_ops.count(iter.first)) { + collect(iter.first); + } else { + consumers->insert(iter.first); + } + } + }; -// consumers->clear(); -// collect(op); -// } + consumers->clear(); + collect(op); +} bool IntArrayEqual(const Array& arr1, const Array& arr2) { if (arr1.size() != arr2.size()) { @@ -547,9 +547,9 @@ void UpdateStageAxis(const te::Stage& stage, StageToAxesMap *stage_to_axes) { } } -// State ComputeDAG::GetInitState() const { -// return Downcast(operator->()->init_state); -// } +State ComputeDAG::GetInitState() const { + return Downcast(operator->()->init_state); +} ComputeDAG ComputeDAGNode::make(Array tensors) { auto node = make_object(); @@ -559,7 +559,7 @@ ComputeDAG ComputeDAGNode::make(Array tensors) { node->access_analyzer = AccessAnalyzerNode::make(node->tensors); node->ops = Array(node->access_analyzer->ops_topo_order); node->flop_ct = estimator.EstimateFlop(node->ops); -// node->init_state = StateNode::make(node->ops); + node->init_state = StateNode::make(node->ops); return ComputeDAG(node); } @@ -580,8 +580,8 @@ void ComputeDAGNode::VisitAttrs(tvm::AttrVisitor* v) { v->Visit("ops", &ops); v->Visit("flop_ct", &flop_ct); v->Visit("access_analyzer", &access_analyzer); -// State s = Downcast(init_state); -// v->Visit("init_state", &s); + State s = Downcast(init_state); + v->Visit("init_state", &s); } // Implemented in multi_stage_policy.cc @@ -1075,79 +1075,79 @@ void ComputeDAG::ReplayAndGetDAG(const std::vector &transform_steps, // } // } -// std::pair > ComputeDAG::ReplaySteps( -// const std::vector &transform_steps, -// std::vector *stages, -// StageToAxesMap *stage_to_axes) const { -// std::vector ops; -// for (const auto& op : operator->()->ops) { -// if (!op->IsInstance()) { -// ops.push_back(op); -// } -// } +std::pair > ComputeDAG::ReplaySteps( + const std::vector &transform_steps, + std::vector *stages, + StageToAxesMap *stage_to_axes) const { + std::vector ops; + for (const auto& op : operator->()->ops) { + if (!op->IsInstance()) { + ops.push_back(op); + } + } -// te::Schedule schedule = te::create_schedule({ops.back()}); + te::Schedule 
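+  /*
+   * A minimal sketch (hypothetical caller and state) of driving this replay
+   * path from the outside:
+   *
+   *   std::vector<te::Stage> stages;
+   *   StageToAxesMap stage_to_axes;
+   *   te::Schedule sch;
+   *   Array<te::Tensor> tensors;
+   *   std::tie(sch, tensors) =
+   *       dag.ReplaySteps(state->transform_steps, &stages, &stage_to_axes);
+   */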
schedule = te::create_schedule({ops.back()}); -// // init axes -// stages->reserve(operator->()->ops.size()); -// for (const auto& x : operator->()->ops) { -// const te::Stage& stage = schedule.operator[](x); -// stages->push_back(stage); -// UpdateStageAxis(stage, stage_to_axes); -// } + // init axes + stages->reserve(operator->()->ops.size()); + for (const auto& x : operator->()->ops) { + const te::Stage& stage = schedule.operator[](x); + stages->push_back(stage); + UpdateStageAxis(stage, stage_to_axes); + } -// // todo(lmzheng): should we maintain the attach_map and keep the validity of compute_at -// // an splitted axis? + // todo(lmzheng): should we maintain the attach_map and keep the validity of compute_at + // an splitted axis? -// // Use complete rate for the study in the paper -// const char* complete_rate_str = getenv("ANSOR_PROGRAM_COMPLETE_RATE"); -// double complete_rate = -1.0; -// if (complete_rate_str) { -// complete_rate = std::stod(complete_rate_str); -// } -// size_t ct = 0; + // Use complete rate for the study in the paper + const char* complete_rate_str = getenv("ANSOR_PROGRAM_COMPLETE_RATE"); + double complete_rate = -1.0; + if (complete_rate_str) { + complete_rate = std::stod(complete_rate_str); + } + size_t ct = 0; -// // replay history -// for (const auto& step : transform_steps) { -// if (complete_rate >= 0 && ct++ > transform_steps.size() * complete_rate) { -// break; -// } + // replay history + for (const auto& step : transform_steps) { + if (complete_rate >= 0 && ct++ > transform_steps.size() * complete_rate) { + break; + } -// if (auto ps = step.as()) { -// ps->ApplyToSchedule(stages, stage_to_axes); -// } else if (auto ps = step.as()) { -// ps->ApplyToSchedule(stages, stage_to_axes); -// } else if (auto ps = step.as()) { -// ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); -// } else if (auto ps = step.as()) { -// ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); -// } else if (auto ps = step.as()) { -// ps->ApplyToSchedule(stages, stage_to_axes); -// } else if (auto ps = step.as()) { -// ps->ApplyToSchedule(stages, stage_to_axes); -// } else if (auto ps = step.as()) { -// ps->ApplyToSchedule(stages, stage_to_axes); -// } else if (auto ps = step.as()) { -// ps->ApplyToSchedule(stages, stage_to_axes); -// } else if (auto ps = step.as()) { -// ps->ApplyToSchedule(stages, stage_to_axes); -// } else if (auto ps = step.as()) { -// ps->ApplyToSchedule(stages, stage_to_axes, &schedule); -// } else if (auto ps = step.as()) { -// ps->ApplyToSchedule(stages, stage_to_axes, &schedule); -// } else if (auto ps = step.as()) { -// ps->ApplyToSchedule(stages, stage_to_axes); -// } else if (auto ps = step.as()) { -// ps->ApplyToSchedule(stages, stage_to_axes, &schedule); -// } else if (auto ps = step.as()) { -// ps->ApplyToSchedule(stages, stage_to_axes); -// } else { -// LOG(FATAL) << "Invalid Step"; -// } -// } + if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + 
ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes, &schedule); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes, &schedule); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes, &schedule); + } else if (auto ps = step.as()) { + ps->ApplyToSchedule(stages, stage_to_axes); + } else { + LOG(FATAL) << "Invalid Step"; + } + } -// return std::make_pair(schedule, operator->()->tensors); -// } + return std::make_pair(schedule, operator->()->tensors); +} TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index c8da44fee828..9d0708a77f1c 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -148,7 +148,7 @@ class ComputeDAG: public ObjectRef { // Internal common parts for replaying steps std::pair > ReplaySteps( const std::vector& transform_steps, std::vector* stages, - StageToAxesMap* stage_to_axes) const {}; + StageToAxesMap* stage_to_axes) const; static constexpr const char* _layout_free_placeholders_key = "layout_free_placeholders"; // Internal common parts for inferring bound diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index 92157edc463d..f01899c4c793 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -8,825 +8,8 @@ namespace tvm { namespace ansor { -TVM_REGISTER_OBJECT_TYPE(StepNode); TVM_REGISTER_NODE_TYPE(StateNode); -inline std::string CleanName(const std::string& str) { - // to make the name valid in python code - std::string ret = str; - StrReplace(&ret, ".", "_"); - StrReplace(&ret, "@", "_"); - StrReplace(&ret, "outer", "o"); - StrReplace(&ret, "inner", "i"); - return ret; -} - -/********** Reorder **********/ -ReorderStep ReorderStepNode::make(int stage_id, const std::vector& after_ids) { - auto node = make_object(); - node->stage_id = stage_id; - node->after_ids = after_ids; - return ReorderStep(node); -} - -void ReorderStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { - te::Stage& stage = (*stages)[stage_id]; - const std::vector& axes = (*stage_to_axes)[stage]; - CHECK_EQ(after_ids.size(), axes.size()); - - std::vector new_axes; - new_axes.reserve(axes.size()); - for (auto i : after_ids) { - new_axes.push_back(axes[i]); - } - stage.reorder(new_axes); - (*stage_to_axes)[stage] = std::move(new_axes); -} - -std::string ReorderStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - const te::Stage& stage = (*stages)[stage_id]; - std::stringstream ss; - - ss << "s[" << CleanName(stage->op->func_name()) << "].reorder("; - for (size_t i = 0; i < after_ids.size(); ++i) { - ss << CleanName((*stage_to_axes)[stage][after_ids[i]]->var->name_hint); - if (i != after_ids.size() - 1) { - ss << ", "; - } - } - ss << ")\n"; - - ApplyToSchedule(stages, stage_to_axes); - return ss.str(); -} - -/********** Split **********/ -std::vector ApplySplitToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, - int stage_id, - int iter_id, - const std::vector& lengths, - bool inner_to_outer) { - te::Stage& stage = (*stages)[stage_id]; - const std::vector& axes = (*stage_to_axes)[stage]; - - std::vector outs; - if (inner_to_outer) { - IterVar outer = axes[iter_id], inner; 
- for (int i = static_cast(lengths.size()) - 1; i >= 0; i--) { - IterVar to_split = outer; - stage.split(to_split, lengths[i], &outer, &inner); - outs.push_back(inner); - } - outs.push_back(outer); - } else { - IterVar outer, inner = axes[iter_id]; - for (size_t i = 0; i < lengths.size(); i++) { - IterVar to_split = inner; - stage.split_by_nparts(to_split, lengths[i], &outer, &inner); - outs.push_back(outer); - } - outs.push_back(inner); - } - - std::vector new_axes; - new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + iter_id); - if (inner_to_outer) { - new_axes.insert(new_axes.end(), outs.rbegin(), outs.rend()); - } else { - new_axes.insert(new_axes.end(), outs.begin(), outs.end()); - } - new_axes.insert(new_axes.end(), axes.begin() + iter_id + 1, axes.end()); - (*stage_to_axes)[stage] = std::move(new_axes); - - return outs; -} - -std::string PrintSplitAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - int stage_id, - int iter_id, - const std::vector& lengths, - bool inner_to_outer) { - te::Stage& stage = (*stages)[stage_id]; - auto to_split = (*stage_to_axes)[stage][iter_id]; - const auto& func_name = CleanName(stage->op->func_name()); - const auto& outs = ApplySplitToSchedule(stages, stage_to_axes, stage_id, - iter_id, lengths, inner_to_outer); - - std::stringstream ss; - int size = static_cast(lengths.size()); - if (inner_to_outer) { - for (int i = size - 1; i >= 0; i--) { - ss << CleanName(outs[size - i]->var->name_hint) << ", " - << CleanName(outs[size - i - 1]->var->name_hint) - << " = s[" << func_name << "].split(" - << CleanName(to_split->var->name_hint) - << ", factor=" << lengths[i] << ")\n"; - to_split = outs[size - i]; - } - } else { - for (int i = 0; i < size; i++) { - ss << CleanName(outs[i]->var->name_hint) << ", " - << CleanName(outs[i + 1]->var->name_hint) - << " = s[" << func_name << "].split(" - << CleanName(to_split->var->name_hint) - << ", nparts=" << lengths[i] << ")\n"; - to_split = outs[i + 1]; - } - } - - return ss.str(); -} - -SplitStep SplitStepNode::make(int stage_id, int iter_id, - PrimExpr extent, const std::vector& lengths, - bool inner_to_outer) { - auto node = make_object(); - node->stage_id = stage_id; - // Extent can be a unreducible expression in some special cases - if (extent->IsInstance()) { - node->extent = std::move(extent); - } - node->iter_id = iter_id; - node->lengths = lengths; - node->inner_to_outer = inner_to_outer; - return SplitStep(node); -} - -std::vector SplitStepNode::ApplyToSchedule( - std::vector *stages, StageToAxesMap *stage_to_axes) const { - return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, - lengths, inner_to_outer); -} - -std::string SplitStepNode::PrintAsPythonAPI( - std::vector *stages, StageToAxesMap *stage_to_axes, - te::Schedule *schedule, const std::vector& transform_steps) const { - return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, - lengths, inner_to_outer); -} - -/********** Follow Split **********/ -FollowSplitStep FollowSplitStepNode::make(int stage_id, int iter_id, - int src_step_id, int n_split) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->src_step_id = src_step_id; - node->n_split = n_split; - return FollowSplitStep(node); -} - -void FollowSplitStepNode::ExtractSplitLengths(const std::vector& transform_steps, - std::vector* lengths) const { - CHECK_LT(src_step_id, transform_steps.size()); - auto ps = transform_steps[src_step_id].as(); - CHECK(ps != nullptr); - - // get lengths from src step - 
lengths->reserve(n_split); - int j = 0; - for (; j < n_split - 1; ++j) { - lengths->push_back(ps->lengths[j]); - } - PrimExpr last_factor = 1; - for (; j < static_cast(ps->lengths.size()); ++j) { - if (ps->lengths[j].defined()) { - last_factor *= ps->lengths[j]; - } else { - last_factor = PrimExpr(); - break; - } - } - lengths->push_back(std::move(last_factor)); -} - -std::vector FollowSplitStepNode::ApplyToSchedule( - std::vector *stages, StageToAxesMap *stage_to_axes, - const std::vector& transform_steps) const { - std::vector lengths; - ExtractSplitLengths(transform_steps, &lengths); - return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, - lengths, true); -} - -std::string FollowSplitStepNode::PrintAsPythonAPI( - std::vector *stages, StageToAxesMap *stage_to_axes, - te::Schedule *schedule, const std::vector& transform_steps) const { - std::vector lengths; - ExtractSplitLengths(transform_steps, &lengths); - return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, - lengths, true); -} - -/********** Follow Fused Split **********/ -FollowFusedSplitStep FollowFusedSplitStepNode::make(int stage_id, int iter_id, - const std::vector& src_step_ids, int level, bool factor_or_nparts) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->src_step_ids = src_step_ids;; - node->level = level; - node->factor_or_nparts = factor_or_nparts; - return FollowFusedSplitStep(node); -} - -PrimExpr FollowFusedSplitStepNode::ExtractSplitLength(const std::vector& transform_steps) const { - PrimExpr ret(1); - - for (int src_step_id : src_step_ids) { - CHECK_LT(src_step_id, transform_steps.size()); - auto ps = transform_steps[src_step_id].as(); - CHECK(ps != nullptr); - if (ps->lengths[level].defined() && ret.defined()) { - ret *= ps->lengths[level]; - } else { - return PrimExpr(); - } - } - - return ret; -} - -std::vector FollowFusedSplitStepNode::ApplyToSchedule( - std::vector *stages, StageToAxesMap *stage_to_axes, - const std::vector& transform_steps) const { - const PrimExpr& length = ExtractSplitLength(transform_steps); - return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, - {length}, factor_or_nparts); -} - -std::string FollowFusedSplitStepNode::PrintAsPythonAPI( - std::vector *stages, StageToAxesMap *stage_to_axes, - te::Schedule *schedule, const std::vector& transform_steps) const { - const PrimExpr& length = ExtractSplitLength(transform_steps); - return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, - {length}, factor_or_nparts); -} - - -/********** Fuse **********/ -FuseStep FuseStepNode::make(int stage_id, const std::vector& fused_ids) { - auto node = make_object(); - node->stage_id = stage_id; - node->fused_ids = fused_ids; - return FuseStep(node); -} - -IterVar FuseStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { - te::Stage& stage = (*stages)[stage_id]; - const std::vector& axes = (*stage_to_axes)[stage]; - - Array to_fuse; - for (auto i : fused_ids) { - to_fuse.push_back(axes[i]); - } - IterVar fused_axis; - stage.fuse(to_fuse, &fused_axis); - std::vector new_axes; - new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids[0]); - new_axes.push_back(fused_axis); - new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1, - axes.end()); - (*stage_to_axes)[stage] = std::move(new_axes); - - return fused_axis; -} - -std::string FuseStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, 
- const std::vector& transform_steps) const { - const auto& stage = (*stages)[stage_id]; - std::stringstream to_fuse; - - for (size_t i = 0; i < fused_ids.size(); ++i) { - to_fuse << CleanName((*stage_to_axes)[stage][fused_ids[i]]->var->name_hint); - if (i != fused_ids.size() - 1) { - to_fuse << ", "; - } - } - - std::stringstream ss; - const auto& fused = ApplyToSchedule(stages, stage_to_axes); - - ss << CleanName(fused->var->name_hint) << " = s[" - << CleanName(stage->op->func_name()) << "].fuse(" - << to_fuse.str() << ")\n"; - - return ss.str(); -} - -/********** Annotation **********/ -AnnotationStep AnnotationStepNode::make(int stage_id, int iter_id, IteratorAnnotation ann) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->annotation = ann; - return AnnotationStep(node); -} - -void AnnotationStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { - te::Stage& stage = (*stages)[stage_id]; - const std::vector& axes = (*stage_to_axes)[stage]; - - switch (annotation) { - case kUnroll: stage.unroll(axes[iter_id]); break; - case kVectorize: stage.vectorize(axes[iter_id]); break; - case kParallel: stage.parallel(axes[iter_id]); break; - case kVThread: stage.bind(axes[iter_id], te::thread_axis(Range(), "vthread")); break; - case kBlockX: stage.bind(axes[iter_id], te::thread_axis(Range(), "blockIdx.x")); break; - case kBlockY: stage.bind(axes[iter_id], te::thread_axis(Range(), "blockIdx.y")); break; - case kThreadX: - if (axes[iter_id]->iter_type == kCommReduce) { - const auto &thread_x = te::thread_axis(Range(), "threadIdx.x"); - stage.bind(axes[iter_id], thread_x); - stage.set_store_predicate(thread_x->var == 0); - } else { - stage.bind(axes[iter_id], te::thread_axis(Range(), "threadIdx.x")); - } - break; - case kThreadY: stage.bind(axes[iter_id], te::thread_axis(Range(), "threadIdx.y")); break; - case kNone: break; - default: LOG(FATAL) << "Invalid Annotation " << annotation; break; - } -} - -std::string AnnotationStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - const auto& stage = (*stages)[stage_id]; - const auto& iter = (*stage_to_axes)[stage][iter_id]; - - bool bind_reduce_iter = iter->iter_type == kCommReduce && annotation == kThreadX; - if (bind_reduce_iter) { - ss << "thread_x = tvm.thread_axis(\"threadIdx.x\")\n"; - } - - ss << "s[" << CleanName(stage->op->func_name()) << "]."; - switch (annotation) { - case kUnroll: ss << "unroll("; break; - case kVectorize: ss << "vectorize("; break; - case kParallel: ss << "parallel("; break; - case kVThread: - case kBlockX: - case kBlockY: - case kThreadX: - case kThreadY: ss << "bind("; break; - case kNone: break; - default: - LOG(FATAL) << "Invalid annotation " << annotation; break; - } - ss << CleanName(iter->var->name_hint); - switch (annotation) { - case kVThread: ss << ", tvm.thread_axis(\"vthread\")"; break; - case kBlockX: ss << ", tvm.thread_axis(\"blockIdx.x\")"; break; - case kBlockY: ss << ", tvm.thread_axis(\"blockIdy.y\")"; break; - case kThreadX: - if (bind_reduce_iter) { - ss << ", thread_x"; - } else { - ss << ", tvm.thread_axis(\"threadIdx.x\")"; - } - break; - case kThreadY: ss << ", tvm.thread_axis(\"threadIdx.y\")"; break; - default: break; - } - ss << ")\n"; - - if (bind_reduce_iter) { - ss << "s[" << CleanName(stage->op->func_name()) << "]" - << ".set_store_predicate(thread_x.var.equal(0))\n"; - } - - 
ApplyToSchedule(stages, stage_to_axes); - return ss.str(); -} - -/********** Compute at **********/ -ComputeAtStep ComputeAtStepNode::make(int stage_id, int target_stage_id, int target_iter_id) { - auto node = make_object(); - node->stage_id = stage_id; - node->target_stage_id = target_stage_id; - node->target_iter_id = target_iter_id; - return ComputeAtStep(node); -} - -void ComputeAtStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { - te::Stage& stage = (*stages)[stage_id]; - const IterVar& target_axis = - (*stage_to_axes)[(*stages)[target_stage_id]][target_iter_id]; - stage.compute_at((*stages)[target_stage_id], target_axis); -} - -std::string ComputeAtStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - const auto& stage = (*stages)[stage_id]; - const auto& target_stage = (*stages)[target_stage_id]; - - ss << "s[" << CleanName(stage->op->func_name()) << "].compute_at(s[" - << CleanName(target_stage->op->func_name()) << "], " - << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint); - - ss << ")\n"; - ApplyToSchedule(stages, stage_to_axes); - return ss.str(); -} - -/********** Compute Root **********/ -ComputeRootStep ComputeRootStepNode::make(int stage_id) { - auto node = make_object(); - node->stage_id = stage_id; - return ComputeRootStep(node); -} - -void ComputeRootStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { - (*stages)[stage_id].compute_root(); -} - -std::string ComputeRootStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - const auto& stage = (*stages)[stage_id]; - - ss << "s[" << CleanName(stage->op->func_name()) << "].compute_root()\n"; - ApplyToSchedule(stages, stage_to_axes); - - return ss.str(); -} - -/********** Compute Inline **********/ -ComputeInlineStep ComputeInlineStepNode::make(int stage_id) { - auto node = make_object(); - node->stage_id = stage_id; - return ComputeInlineStep(node); -} - -void ComputeInlineStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { - (*stages)[stage_id].compute_inline(); -} - -std::string ComputeInlineStepNode::PrintAsPythonAPI( - std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - const auto& stage = (*stages)[stage_id]; - - ss << "s[" << CleanName(stage->op->func_name()) << "].compute_inline()\n"; - ApplyToSchedule(stages, stage_to_axes); - - return ss.str(); -} - -/********** Pack for vec **********/ -PackForVecStep PackForVecStepNode::make(int stage_id, int iter_id, int vec_size) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->vec_size = vec_size; - return PackForVecStep(node); -} - -void PackForVecStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, te::Schedule *schedule) const { - LOG(FATAL) << "Not implemented"; -} - -std::string PackForVecStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - LOG(FATAL) << "Not implemented"; - return ""; -} - -/********** Cache read **********/ -CacheReadStep CacheReadStepNode::make(int stage_id, std::string scope_name, - const std::vector& 
reader_stage_ids) { - auto node = make_object(); - node->stage_id = stage_id; - node->scope_name = std::move(scope_name); - node->reader_stage_ids = reader_stage_ids; - return CacheReadStep(node); -} - -te::Tensor CacheReadStepNode::ApplyToSchedule(std::vector* stages, - StageToAxesMap *stage_to_axes, te::Schedule *schedule) const { - te::Stage& stage = (*stages)[stage_id]; - - Array readers; - for (const auto& i : reader_stage_ids) { - readers.push_back((*stages)[i]->origin_op); - } - auto out = schedule->cache_read(stage->origin_op.output(0), scope_name, readers); - - const auto& new_stage = (*schedule)[out->op]; - UpdateStageAxis(new_stage, stage_to_axes); - stages->insert(stages->begin() + stage_id + 1, new_stage); - - return out; -} - -std::string CacheReadStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - // copy stage here, for the original stage will change after apply - auto stage = (*stages)[stage_id]; - std::vector reader_stages; - for (size_t i = 0; i < reader_stage_ids.size(); ++i) { - reader_stages.push_back((*stages)[reader_stage_ids[i]]); - } - - auto out = ApplyToSchedule(stages, stage_to_axes, schedule); - - ss << CleanName(out->op->func_name()) << " = " - << "s.cache_read(" << CleanName(stage->op->func_name()) << ", \"" - << scope_name << "\", [" - << CleanName(reader_stages[0]->op->func_name()); - for (size_t i = 1; i < reader_stage_ids.size(); ++i) { - ss << ", " << CleanName(reader_stages[i]->op->func_name()); - } - ss << "])\n"; - - const auto& iters = out->op->root_iter_vars(); - for (size_t i = 0; i < iters.size(); ++i) { - ss << CleanName(iters[i]->var->name_hint); - if (i != iters.size() - 1) { - ss << ", "; - } - } - ss << " = " << "tuple(" << CleanName(out->op->func_name()) - << ".op.axis)\n"; - - return ss.str(); -} - -/********** Cache write **********/ -CacheWriteStep CacheWriteStepNode::make(int stage_id, std::string scope_name) { - auto node = make_object(); - node->stage_id = stage_id; - node->scope_name = std::move(scope_name); - return CacheWriteStep(node); -} - -Array CacheWriteStepNode::ApplyToSchedule( - std::vector *stages, StageToAxesMap *stage_to_axes, - te::Schedule *schedule) const { - te::Stage& stage = (*stages)[stage_id]; - - Array tensor_array; - // If the target stage has multi outputs, TVM requires to cache_write - // all of them or schedule.cache_write will raise an error - for (auto i = 0; i < stage->op->num_outputs(); ++i) { - tensor_array.push_back(stage->origin_op.output(i)); - } - auto outs = schedule->cache_write(tensor_array, scope_name); - - UpdateStageAxis(stage, stage_to_axes); - // Even if there is multi outputs, TVM schedule only generate one - // new stage - const auto& new_stage = (*schedule)[outs[0]->op]; - UpdateStageAxis(new_stage, stage_to_axes); - stages->insert(stages->begin() + stage_id, new_stage); - - return outs; -} - -std::string CacheWriteStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - // copy stage here, for the original stage will change after apply - te::Stage stage = (*stages)[stage_id]; - - auto outs = ApplyToSchedule(stages, stage_to_axes, schedule); - - for (size_t i = 0; i < outs.size(); ++i) { - ss << CleanName(outs[i]->op->func_name()) << ", "; - } - ss << "= " << "s.cache_write([" - << CleanName(stage->op.output(0)->op->name); - for (auto i = 1; i 
< stage->op->num_outputs(); ++i) { - ss << ", " << CleanName(stage->op.output(i)->op->name); - } - ss << "], \"" << scope_name << "\")\n"; - - for (const auto& out : outs) { - const auto& iters = out->op->root_iter_vars(); - for (size_t i = 0; i < iters.size(); ++i) { - ss << CleanName(iters[i]->var->name_hint); - if (i != iters.size() - 1) { - ss << ", "; - } - } - ss << " = " << "tuple(" << CleanName(out->op->func_name()) - << ".op.axis)" - << " + " << "tuple(" << CleanName(out->op->func_name()) - << ".op.reduce_axis)\n"; - } - - return ss.str(); -} - -/********** Pragma **********/ -PragmaStep PragmaStepNode::make(int stage_id, int iter_id, - std::string pragma_type) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->pragma_type = std::move(pragma_type); - return PragmaStep(node); -} - -void PragmaStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { - te::Stage& stage = (*stages)[stage_id]; - const std::vector& axes = (*stage_to_axes)[stage]; - if (StrStartsWith(pragma_type, "auto_unroll_max_step")) { - size_t pos = pragma_type.find('$'); - int value = atoi(pragma_type.c_str() + pos + 1); - stage.pragma(axes[iter_id], "auto_unroll_max_step", value); - stage.pragma(axes[iter_id], "unroll_explicit", true); - } else { - stage.pragma(axes[iter_id], pragma_type); - } -} - -std::string PragmaStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - const auto& stage = (*stages)[stage_id]; - - if (StrStartsWith(pragma_type, "auto_unroll_max_step")) { - size_t pos = pragma_type.find('$'); - int value = atoi(pragma_type.c_str() + pos + 1); - ss << "s[" << CleanName(stage->op->func_name()) << "].pragma(" - << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) - << ", \"auto_unroll_max_step\", " << value << ")\n"; - ss << "s[" << CleanName(stage->op->func_name()) << "].pragma(" - << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) - << ", \"unroll_explicit\", True)\n"; - } else { - ss << "s[" << CleanName(stage->op->func_name()) << "].pragma(" - << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", \"" - << pragma_type << "\")\n"; - } - - ApplyToSchedule(stages, stage_to_axes); - return ss.str(); -} - -/********** Rfactor **********/ -RfactorStep RfactorStepNode::make(int stage_id, int iter_id, int factor_iter_id) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->factor_iter_id = factor_iter_id; - return RfactorStep(node); -} - -Array RfactorStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, te::Schedule *schedule) const { - const auto& stage = (*stages)[stage_id]; - const std::vector& axes = (*stage_to_axes)[stage]; - - const te::Tensor& tensor = stage->origin_op.output(0); - const IterVar& axis = axes[iter_id]; - auto outs = schedule->rfactor(tensor, axis, factor_iter_id); - - UpdateStageAxis(stage, stage_to_axes); - - const auto& new_stage = (*schedule)[outs[0]->op]; - UpdateStageAxis(new_stage, stage_to_axes); - stages->insert(stages->begin() + stage_id, new_stage); - - return outs; -} - -std::string RfactorStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - const auto& stage = (*stages)[stage_id]; - - const auto& tensor_name = 
CleanName(stage->origin_op.output(0)->op->name); - const auto& axis_name = CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint); - - const auto& outs = ApplyToSchedule(stages, stage_to_axes, schedule); - - for (size_t i = 0; i < outs.size(); ++i) { - ss << CleanName(outs[i]->op->func_name()); - if (i != outs.size() - 1) { - ss << ", "; - } - } - ss << " = " << "s.rfactor(" - << tensor_name << ", " - << axis_name << ", " - << factor_iter_id << ")\n"; - - for (const auto& out : outs) { - const auto& iters = out->op->root_iter_vars(); - for (size_t i = 0; i < iters.size(); ++i) { - ss << CleanName(iters[i]->var->name_hint); - if (i != iters.size() - 1) { - ss << ", "; - } - } - ss << " = " << "tuple(" << CleanName(out->op->func_name()) - << ".op.axis)" - << " + " << "tuple(" << CleanName(out->op->func_name()) - << ".op.reduce_axis)\n"; - } - - const auto& output = (*stages)[stage_id + 1]->op.output(0); - const auto& iters = output->op->root_iter_vars(); - for (size_t i = 0; i < iters.size(); ++i) { - ss << CleanName(iters[i]->var->name_hint); - if (i != iters.size() - 1) { - ss << ", "; - } - } - ss << " = " << "tuple(s[" << CleanName(output->op->func_name()) - << "].op.axis)" - << " + " << "tuple(s[" << CleanName(output->op->func_name()) - << "].op.reduce_axis)\n"; - - return ss.str(); -} - -/********** StorageAlign **********/ - -StorageAlignStep StorageAlignStepNode::make(int stage_id, int iter_id, - int factor, int offset) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->factor = factor; - node->offset = offset; - return StorageAlignStep(node); -} - -void StorageAlignStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { - te::Stage& stage = (*stages)[stage_id]; - const std::vector& axes = (*stage_to_axes)[stage]; - stage.storage_align(axes[iter_id], factor, offset); -} - -std::string StorageAlignStepNode::PrintAsPythonAPI( - std::vector *stages, StageToAxesMap *stage_to_axes, - te::Schedule *schedule, const std::vector& transform_steps) const { - std::stringstream ss; - const auto& stage = (*stages)[stage_id]; - ss << "s[" << CleanName(stage->op->func_name()) << "].storage_align(" - << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", " - << factor << ", " << offset << ")\n"; - - ApplyToSchedule(stages, stage_to_axes); - return ss.str(); -} - -// Maker for other classes -Iterator IteratorNode::make(std::string name, Range range, - IteratorType iter_type, IteratorAnnotation annotation, - const std::vector* ori_iters) { - auto node = make_object(); - node->name = std::move(name); - node->range = std::move(range); - node->iter_type = iter_type; - node->annotation = annotation; - if (ori_iters != nullptr) { - node->ori_iters = *ori_iters; - } - return Iterator(node); -} - Stage StageNode::make(te::Operation op) { auto node = make_object(); if (op->IsInstance()) { @@ -854,8 +37,10 @@ Stage StageNode::make(te::Operation op) { return Stage(node); } -Stage StageNode::make(te::Operation op, StageType op_type, const std::vector& iters, - ComputeAtType compute_at, int16_t auto_unroll_max_step, int storage_offset) { +Stage StageNode::make(te::Operation op, StageType op_type, + const std::vector& iters, + ComputeAtType compute_at, int16_t auto_unroll_max_step, + int storage_offset) { auto node = make_object(); node->op = std::move(op); node->op_type = op_type; @@ -866,8 +51,10 @@ Stage StageNode::make(te::Operation op, StageType op_type, const std::vector&& iters, - ComputeAtType compute_at, int16_t 
auto_unroll_max_step, int storage_offset) { +Stage StageNode::make(te::Operation op, StageType op_type, + std::vector&& iters, + ComputeAtType compute_at, int16_t auto_unroll_max_step, + int storage_offset) { auto node = make_object(); node->op = std::move(op); node->op_type = op_type; @@ -927,8 +114,9 @@ void State::reorder(int stage_id, const std::vector& order) { DoReorderStep(step); } -std::vector State::split(int stage_id, - const Iterator& it, const std::vector& lengths, bool inner_to_outer) { +std::vector State::split(int stage_id, const Iterator& it, + const std::vector& lengths, + bool inner_to_outer) { const Stage& stage = operator->()->stages[stage_id]; SplitStep step = SplitStepNode::make(stage_id, GetIndex(stage->iters, it), @@ -949,8 +137,9 @@ std::vector State::follow_split(int stage_id, } -std::vector State::follow_fused_split(int stage_id, const Iterator& it, - const std::vector& src_step_ids, int level, bool factor_or_nparts) { +std::vector State::follow_fused_split( + int stage_id, const Iterator& it, const std::vector& src_step_ids, + int level, bool factor_or_nparts) { const Stage& stage = operator->()->stages[stage_id]; FollowFusedSplitStep step = FollowFusedSplitStepNode::make(stage_id, @@ -970,24 +159,24 @@ Iterator State::fuse(int stage_id, const std::vector& iters) { Iterator State::vectorize(int stage_id, const Iterator& it) { const Stage& stage = operator->()->stages[stage_id]; - AnnotationStep step = AnnotationStepNode::make(stage_id, GetIndex(stage->iters, it), - kVectorize); + AnnotationStep step = AnnotationStepNode::make( + stage_id, GetIndex(stage->iters, it), kVectorize); CopyOnWrite()->transform_steps.push_back(step); return DoAnnotationStep(step); } Iterator State::parallel(int stage_id, const Iterator& it) { const Stage& stage = operator->()->stages[stage_id]; - AnnotationStep step = AnnotationStepNode::make(stage_id, GetIndex(stage->iters, it), - kParallel); + AnnotationStep step = AnnotationStepNode::make( + stage_id, GetIndex(stage->iters, it), kParallel); CopyOnWrite()->transform_steps.push_back(step); return DoAnnotationStep(step); } Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) { const Stage& stage = operator->()->stages[stage_id]; - AnnotationStep step = AnnotationStepNode::make(stage_id, GetIndex(stage->iters, it), - kUnroll); + AnnotationStep step = AnnotationStepNode::make(stage_id, + GetIndex(stage->iters, it), kUnroll); // don't unroll if the extent is larger than max_unroll if (max_unroll != -1 && it->range.defined()) { @@ -1002,7 +191,8 @@ Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) { return DoAnnotationStep(step); } -void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) { +void State::compute_at(int stage_id, int target_stage_id, + const Iterator& target_iter) { const Stage& target_stage = operator->()->stages[target_stage_id]; ComputeAtStep step = ComputeAtStepNode::make(stage_id, target_stage_id, GetIndex(target_stage->iters, target_iter)); @@ -1022,7 +212,8 @@ void State::compute_inline(int stage_id) { return DoComputeInlineStep(step); } -void State::pack_for_vec(int stage_id, const Iterator& target_iter, int vec_size) { +void State::pack_for_vec(int stage_id, const Iterator& target_iter, + int vec_size) { const Stage& stage = operator->()->stages[stage_id]; PackForVecStep step = PackForVecStepNode::make(stage_id, GetIndex(stage->iters, target_iter), vec_size); @@ -1044,8 +235,10 @@ Iterator State::bind_thread(int stage_id, const Iterator& it, } 
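 // The State methods in this hunk all share one record-and-replay pattern:
 // build the corresponding Step node, append it to transform_steps through
 // CopyOnWrite, then apply it eagerly with the matching Do*Step helper.
 // A minimal usage sketch (the stage/iterator indices and the tile size 8
 // are hypothetical, not taken from this patch):
 //   State state = StateNode::make(dag->ops);
 //   const Stage& s0 = state->stages[0];
 //   auto tiles = state.split(0, s0->iters[0], {8}, true);
 //   state.parallel(0, tiles[0]);   // outer loop
 //   state.vectorize(0, tiles[1]);  // inner loop of length 8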
int State::cache_read(int stage_id, const std::string& scope_name, - const std::vector& reader_stage_ids, const ComputeDAG& task_dag) { - CacheReadStep step = CacheReadStepNode::make(stage_id, scope_name, reader_stage_ids); + const std::vector& reader_stage_ids, + const ComputeDAG& task_dag) { + CacheReadStep step = CacheReadStepNode::make(stage_id, scope_name, + reader_stage_ids); CopyOnWrite()->transform_steps.push_back(step); return DoCacheReadStep(step, task_dag); } @@ -1057,7 +250,8 @@ int State::cache_write(int stage_id, const std::string& scope_name, return DoCacheWriteStep(step, task_dag); } -void State::pragma(int stage_id, const Iterator& it, const std::string& pragma_type) { +void State::pragma(int stage_id, const Iterator& it, + const std::string& pragma_type) { const Stage& stage = operator->()->stages[stage_id]; PragmaStep step = PragmaStepNode::make(stage_id, GetIndex(stage->iters, it), pragma_type); @@ -1068,7 +262,8 @@ void State::pragma(int stage_id, const Iterator& it, const std::string& pragma_t int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& task_dag) { const Stage& stage = operator->()->stages[stage_id]; - RfactorStep step = RfactorStepNode::make(stage_id, GetIndex(stage->iters, it), factor_iter_id); + RfactorStep step = RfactorStepNode::make(stage_id, GetIndex(stage->iters, it), + factor_iter_id); CopyOnWrite()->transform_steps.push_back(step); return DoRfactorStep(step, task_dag); } @@ -1093,15 +288,16 @@ void State::DoReorderStep(const ReorderStep& step) { StateNode* pstate = CopyOnWrite(); pstate->stages[step->stage_id] = StageNode::make(stage->op, stage->op_type, - std::move(iters), stage->compute_at, + std::move(iters), + stage->compute_at, stage->auto_unroll_max_step, stage->storage_offset); } // common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep -std::vector State::DoSplitStepCommon(int stage_id, int iter_id, - const std::vector& lengths, - bool inner_to_outer) { +std::vector State::DoSplitStepCommon( + int stage_id, int iter_id, const std::vector& lengths, + bool inner_to_outer) { const Stage& stage = operator->()->stages[stage_id]; const Iterator& it = stage->iters[iter_id]; size_t old_iter_size = stage->iters.size(); @@ -1142,24 +338,29 @@ std::vector State::DoSplitStepCommon(int stage_id, int iter_id, range = Range::make_by_min_extent(tosplit_min, tosplit_extent); } if (inner_to_outer) { - outs.push_back(IteratorNode::make(it->name + ".0", range, it->iter_type, kNone)); + outs.push_back(IteratorNode::make(it->name + ".0", range, it->iter_type, + kNone)); std::reverse(outs.begin(), outs.end()); } else { - outs.push_back(IteratorNode::make(it->name + "." + std::to_string(lengths.size()), - range, it->iter_type, kNone)); + outs.push_back(IteratorNode::make( + it->name + "." 
+ std::to_string(lengths.size()), range, it->iter_type, + kNone)); } std::vector new_iters; - new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + iter_id); + new_iters.insert(new_iters.end(), stage->iters.begin(), + stage->iters.begin() + iter_id); new_iters.insert(new_iters.end(), outs.begin(), outs.end()); - new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id+1, stage->iters.end()); + new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id+1, + stage->iters.end()); StateNode* pstate = CopyOnWrite(); pstate->stages[stage_id] = StageNode::make(stage->op, stage->op_type, std::move(new_iters), stage->compute_at, stage->auto_unroll_max_step, stage->storage_offset); - // we have to replace the iterators in attach map, these two vectors keep the replacement mapping + // we have to replace the iterators in attach map, + // these two vectors keep the replacement mapping std::vector from_iters; std::vector to_iters; for (size_t i = iter_id; i < old_iter_size; ++i) { @@ -1181,9 +382,12 @@ std::vector State::DoFollowSplitStep(const FollowSplitStep& step) { return DoSplitStepCommon(step->stage_id, step->iter_id, lengths, true); } -std::vector State::DoFollowFusedSplitStep(const FollowFusedSplitStep& step) { - const PrimExpr& length = step->ExtractSplitLength(operator->()->transform_steps); - return DoSplitStepCommon(step->stage_id, step->iter_id, {length}, step->factor_or_nparts); +std::vector State::DoFollowFusedSplitStep( + const FollowFusedSplitStep& step) { + const PrimExpr& length = step->ExtractSplitLength( + operator->()->transform_steps); + return DoSplitStepCommon(step->stage_id, step->iter_id, {length}, + step->factor_or_nparts); } Iterator State::DoFuseStep(const FuseStep& step) { @@ -1202,8 +406,10 @@ Iterator State::DoFuseStep(const FuseStep& step) { } if (i != step->fused_ids.size() - 1) { - const auto& iter_to_attached_stage = operator->()->attach_map->iter_to_attached_stages; - if (iter_to_attached_stage.find(std::make_pair(stage_id, step->fused_ids[i])) + const auto& iter_to_attached_stage = + operator->()->attach_map->iter_to_attached_stages; + if (iter_to_attached_stage.find(std::make_pair(stage_id, + step->fused_ids[i])) != iter_to_attached_stage.end()) { LOG(FATAL) << "Invalid Fuse. 
Because you want to fuse iterators " "that have been attached by some stages";
@@ -1233,20 +439,23 @@ Iterator State::DoFuseStep(const FuseStep& step) {
 if (new_extent.defined()) {
 range = Range::make_by_min_extent(0, new_extent);
 }
- Iterator new_it = IteratorNode::make(new_name, range, new_iter_type, kNone, &ori_iters);
+ Iterator new_it = IteratorNode::make(new_name, range, new_iter_type, kNone,
+ &ori_iters);
 std::vector new_iters;
 new_iters.insert(new_iters.end(), stage->iters.begin(),
- stage->iters.begin() + step->fused_ids.front());
+ stage->iters.begin() + step->fused_ids.front());
 new_iters.push_back(new_it);
- new_iters.insert(new_iters.end(), stage->iters.begin() + step->fused_ids.back() + 1,
- stage->iters.end());
+ new_iters.insert(new_iters.end(),
+ stage->iters.begin() + step->fused_ids.back() + 1,
+ stage->iters.end());
 StateNode* pstate = CopyOnWrite();
 pstate->stages[stage_id] = StageNode::make(stage->op, stage->op_type,
 std::move(new_iters), stage->compute_at,
 stage->auto_unroll_max_step, stage->storage_offset);
- // we have to replace the iterators in attach map, these two vectors keep the replacement mapping
+ // we have to replace the iterators in attach map,
+ // these two vectors keep the replacement mapping
 std::vector from_iters;
 std::vector to_iters;
 const int begin_id = step->fused_ids.front(), end_id = step->fused_ids.back();
@@ -1282,15 +491,18 @@ void State::DoComputeAtStep(const ComputeAtStep& step) {
 const Stage& stage = operator->()->stages[step->stage_id];
 // after compute_at, we don't know the accurate length information any more
- // If we do want to know the accurate lengths, we can call ComputeDAG::ReplayAndInferBound
+ // If we do want to know the accurate lengths, we can call
+ // ComputeDAG::ReplayAndInferBound
 std::vector new_iters;
 for (const Iterator& it : stage->iters) {
 size_t s = it->name.size();
- if (s >= 2 && it->name[s-2] == '.' && it->name[s-1] >= '1' && it->name[s-1] <= '4') {
- // We use a dangerous heuristic rule here : For multi level splitted iterators, we assume
- // their length does not change after compute_at.
- // Reason: These iterators are generated in MultiStagePolicy by multi level tiling, they will
- // be carefully compute_at their consumers. In this case, their lengths do not change.
+ if (s >= 2 && it->name[s-2] == '.' && it->name[s-1] >= '1' &&
+ it->name[s-1] <= '4') {
+ // We use a dangerous heuristic rule here: for multi-level split
+ // iterators, we assume their lengths do not change after compute_at.
+ // Reason: these iterators are generated by multi-level tiling in
+ // MultiStagePolicy and will be carefully compute_at'ed at their
+ // consumers, so their lengths do not change.
+ // We do this so that the AnnotateCPU pass can annotate more efficiently.
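+ // For example, a two-level split of an iterator i yields i.0, i.1 and i.2;
+ // the ".1"-".4" suffixes are assumed to mark such inner tile iterators,
+ // so the branch below keeps their ranges instead of resetting them.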
new_iters.push_back(it); } else { @@ -1303,14 +515,16 @@ void State::DoComputeAtStep(const ComputeAtStep& step) { pstate->stages[step->stage_id] = StageNode::make(stage->op, stage->op_type, std::move(new_iters), kIter, stage->auto_unroll_max_step, stage->storage_offset); - pstate->attach_map.SetComputeAtIter(step->stage_id, step->target_stage_id, step->target_iter_id); + pstate->attach_map.SetComputeAtIter(step->stage_id, step->target_stage_id, + step->target_iter_id); } void State::DoComputeRootStep(const ComputeRootStep& step) { const Stage& stage = operator->()->stages[step->stage_id]; // after compute_root, we don't know the accurate length information any more - // If we do want to know the accurate lengths, we can call ComputeDAG::ReplayAndInferBound + // If we do want to know the accurate lengths, we can call + // ComputeDAG::ReplayAndInferBound std::vector new_iters; for (const Iterator& it : stage->iters) { new_iters.push_back(IteratorNode::make(it->name, Range(), it->iter_type, @@ -1331,7 +545,8 @@ void State::DoComputeInlineStep(const ComputeInlineStep& step) { StateNode* pstate = CopyOnWrite(); // CHECK the validity of compute_inline - const auto& iter_to_attached_stages = pstate->attach_map->iter_to_attached_stages; + const auto& iter_to_attached_stages = + pstate->attach_map->iter_to_attached_stages; for (size_t i = 0; i < stage->iters.size(); ++i) { CHECK_EQ(iter_to_attached_stages.count(std::make_pair(step->stage_id, i)), 0) << "Invalid compute_inline: Because there are some other stages " @@ -1346,15 +561,18 @@ void State::DoPackForVecStep(const PackForVecStep& step) { LOG(FATAL) << "Not implemented"; } -// Common part for steps that add new stages (e.g. CacheReadStep, CacheWriteStep, RfactorStep) -void AddStageModificationSteps(size_t step_id, const std::vector& transform_steps, - std::vector* replay_steps) { +// Common part for steps that add new stages +// (e.g. CacheReadStep, CacheWriteStep, RfactorStep) +void AddStageModificationSteps(size_t step_id, + const std::vector& transform_steps, std::vector* replay_steps) { const Step& step = transform_steps[step_id]; - if (step->IsInstance() || step->IsInstance()) { + if (step->IsInstance() || + step->IsInstance()) { replay_steps->push_back(step); } else if (step->IsInstance()) { // add FuseStepNode required by rfactor - if (step_id >= 2 && transform_steps[step_id - 2]->IsInstance()) { + if (step_id >= 2 && + transform_steps[step_id - 2]->IsInstance()) { const Step& fuse_step = transform_steps[step_id - 2]; if (fuse_step->stage_id == step->stage_id) { replay_steps->push_back(fuse_step); @@ -1406,7 +624,12 @@ int State::DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag) { break; } } + + int last_dag_op_size = pstate->task_dag.defined() ? 
+ pstate->task_dag->ops.size() : dag->ops.size();
 dag.ReplayAndGetDAG(replay_steps, &(pstate->task_dag));
+ int added_ops = pstate->task_dag->ops.size() - last_dag_op_size;
+ CHECK_GE(added_ops, 1);
 // target -> target_compute + target
 // Assume target stage has never been applied any steps before cache_write
@@ -1415,11 +638,24 @@ int State::DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag) {
 StageNode::make(operator->()->task_dag->ops[step->stage_id]));
 pstate->stages[step->stage_id + 1] =
 StageNode::make(operator->()->task_dag->ops[step->stage_id + 1]);
- for (size_t i = step->stage_id + 2; i < operator->()->stages.size(); ++i) {
+ int next_stage_id = step->stage_id + 2;
+ // Notice: added_ops is expected to be exactly 1;
+ // the added_ops == 2 branch below works around a TVM cache_write bug
+ // with multiple outputs, see tests/cpp/ansor_test.cc: CacheReadWrite
+ // test for more information
+ // TODO(jcf94): Fix this
+ if (added_ops == 2) {
+ pstate->stages.insert(pstate->stages.begin() + next_stage_id,
+ StageNode::make(operator->()->task_dag->ops[next_stage_id]));
+ next_stage_id++;
+ } else if (added_ops > 2) {
+ LOG(ERROR) << "Unexpected behavior of CacheWrite.";
+ }
+ for (size_t i = next_stage_id; i < operator->()->task_dag->ops.size(); ++i) {
 pstate->stages[i].CopyOnWrite()->op = operator->()->task_dag->ops[i];
 }
 pstate->attach_map =
- operator->()->attach_map.ApplyStageIdOfffset(step->stage_id, 1);
+ operator->()->attach_map.ApplyStageIdOfffset(step->stage_id, added_ops);
 return step->stage_id;
 }
@@ -1530,8 +766,8 @@ void State::DoSteps(const std::vector& steps, const ComputeDAG& dag) {
 }
-void PrintStage(std::ostream* os, int stage_id, const StateNode* state, size_t base_indent,
- bool delete_trivial_loop) {
+void PrintStage(std::ostream* os, int stage_id, const StateNode* state,
+ size_t base_indent, bool delete_trivial_loop) {
 const Stage& stage = state->stages[stage_id];
 if (stage->auto_unroll_max_step != 0) {
@@ -1553,7 +789,8 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, size_t b
 for (size_t i = 0; i < stage->iters.size(); ++i) {
 const Iterator& iter = stage->iters[i];
- if (!(delete_trivial_loop && iter->range.defined() && is_one(iter->range->extent))) {
+ if (!(delete_trivial_loop && iter->range.defined() &&
+ is_one(iter->range->extent))) {
 for (size_t j = 0; j < base_indent + indent; ++j) {
 *os << " ";
 }
@@ -1569,7 +806,8 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, size_t b
 case kThreadY: *os << "gpu.threadIdx.y "; break;
 }
 if (iter->range.defined()) {
- *os << iter->name << " (" << iter->range->min << "," << iter->range->extent << ")" << "\n";
+ *os << iter->name << " (" << iter->range->min << ","
+ << iter->range->extent << ")" << "\n";
 } else {
 *os << iter->name << " (None)" << "\n";
 }
@@ -1582,7 +820,8 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, size_t b
 auto pair = state->attach_map->iter_to_attached_stages.find(iter_key);
 if (pair != state->attach_map->iter_to_attached_stages.end()) {
 for (const auto& attach_stage_id : pair->second) {
- PrintStage(os, attach_stage_id, state, base_indent + indent, delete_trivial_loop);
+ PrintStage(os, attach_stage_id, state, base_indent + indent,
+ delete_trivial_loop);
 }
 }
 }
@@ -1594,7 +833,8 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, size_t b
 *os << stage->op->func_name() << " = ...\n";
 }
-void PrintState(std::ostream* os, const StateNode* node, bool delete_trivial_loop) {
+void
PrintState(std::ostream* os, const StateNode* node, + bool delete_trivial_loop) { // Gather placeholders std::vector placeholders; for (const auto& stage : node->stages) { @@ -1633,7 +873,8 @@ std::string State::ToStr(bool delete_trivial_loop) const { return os.str(); } -void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id) { +void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, + int target_iter_id) { AttachMapNode* pnode = CopyOnWrite(); // delete the current entry of stage @@ -1641,7 +882,8 @@ void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, int target_i // store the new relation IterKey iter_key(target_stage_id, target_iter_id); - pnode->stage_to_attach_iter[stage_id] = std::make_pair(target_stage_id, target_iter_id); + pnode->stage_to_attach_iter[stage_id] = std::make_pair(target_stage_id, + target_iter_id); pnode->iter_to_attached_stages[iter_key].push_back(stage_id); } diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h index 3ffe8a7feafb..dd56e267c0a0 100644 --- a/src/ansor/loop_state.h +++ b/src/ansor/loop_state.h @@ -19,36 +19,18 @@ #ifndef TVM_ANSOR_LOOP_STATE_H_ #define TVM_ANSOR_LOOP_STATE_H_ -// #include -// #include -// #include -#include #include #include #include #include #include -#include "expr_hasher.h" -#include "utils.h" -#include "compute_dag.h" +#include "transform_step.h" namespace tvm { namespace ansor { using namespace tvm::tir; -enum IteratorType { - kSpace, // spatial iterator - kReduce, // reduction iterator - kMixed, // fused spatial and reduction iterator - kSpecial // special iterator (e.g. virtual root iterator) -}; - -enum IteratorAnnotation { - kNone, kUnroll, kVectorize, kParallel, - kVThread, kBlockX, kThreadX, kBlockY, kThreadY -}; - enum StageType { kPlaceholder, kCompute }; @@ -59,29 +41,7 @@ enum ComputeAtType { kIter, // compute at some iterator }; -/* Iterator and Stage */ -class Iterator; class Stage; class State; - -/*! - * \brief An for loop iterator - * Similar to tvm::IterVar in `include/expr.h` - */ -class IteratorNode : public Object { - public: - std::string name; - Range range; // domain of for loop range - IteratorType iter_type; - IteratorAnnotation annotation; - std::vector ori_iters; - - static Iterator make(std::string name, Range range, - IteratorType iter_type, IteratorAnnotation annotation, - const std::vector* ori_iters = nullptr); - - static constexpr const char *_type_key = "ansor.Iterator"; - TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(Iterator, ObjectRef, IteratorNode); +class Stage; class State; /*! 
* \brief A stage in the compute declaration @@ -97,389 +57,20 @@ class StageNode : public Object { int storage_offset; static Stage make(te::Operation op); - static Stage make(te::Operation op, StageType op_type, const std::vector& iters, - ComputeAtType compute_at, int16_t auto_unroll_max_step, int storage_offset); - static Stage make(te::Operation op, StageType op_type, std::vector&& iters, - ComputeAtType compute_at, int16_t auto_unroll_max_step, int storage_offset); + static Stage make(te::Operation op, StageType op_type, + const std::vector& iters, + ComputeAtType compute_at, int16_t auto_unroll_max_step, + int storage_offset); + static Stage make(te::Operation op, StageType op_type, + std::vector&& iters, + ComputeAtType compute_at, int16_t auto_unroll_max_step, + int storage_offset); static constexpr const char *_type_key = "ansor.Stage"; - TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object); + TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object); }; TVM_DEFINE_COW_NODE_REF(Stage, ObjectRef, StageNode); - -/*! \brief The base class for a transformation step */ -class StepNode: public Object { - public: - int stage_id; - - // Print step as equivalent python schedule API - virtual std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const = 0; - - static constexpr const char* _type_key = "ansor.Step"; - TVM_DECLARE_BASE_OBJECT_INFO(StepNode, Object); -}; -TVM_DEFINE_MUTABLE_NODE_REF(Step, StepNode); - -/* - * Note on how to add a new transform step - * - * Take fuse for example: - * 1. Define class FuseStepNode, FuseStep in loop_state.h, and implement its make function - * in FuseStepNode::make(...) loop_state.cc - * 2. Implement FuseStepNode::ApplyToSchedule and FuseStepNode::PrintAsPythonAPI. - * - In these two functions you need to lower this step with tvm's schedule API - * 3. Implement State::fuse and State::DoFuseStep. - * - In these two functions you need to incrementally update all data structures in State with - * CopyOnWrite style - * 4. Add you step to ComputeDAG::ReplaySteps and make sure it works. - * 5. Add serialization support in `struct Handler >` - * (in serialization.cc) - * 6. 
Add hash support in `struct hash<::tvm::ansor::Step>` (search for this function in this file) - */ - -class ReorderStep; class SplitStep; class FollowSplitStep; -class FollowFusedSplitStep; -class FuseStep; class AnnotationStep; -class ComputeAtStep; class ComputeRootStep; class ComputeInlineStep; -class PackForVecStep; class CacheReadStep; class CacheWriteStep; -class PragmaStep; class RfactorStep; class StorageAlignStep; -class AttachMap; - -class ReorderStepNode: public StepNode { - public: - std::vector after_ids; - - static ReorderStep make(int stage_id, const std::vector& after_ids); - - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.ReorderStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(ReorderStep, Step, ReorderStepNode); - - -class SplitStepNode: public StepNode { - public: - int iter_id; - PrimExpr extent; // the extent of the axis to split - std::vector lengths; // The split factors - bool inner_to_outer; - - static SplitStep make(int stage_id, int iter_id, PrimExpr extent, - const std::vector& lengths, - bool inner_to_outer); - - std::vector ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.SplitStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(SplitStep, Step, SplitStepNode); - -// Similar to SplitStepNode, but use split factor from another step(i.e. Follow another split step) -class FollowSplitStepNode: public StepNode { - public: - int iter_id; - int src_step_id; - int n_split; - - static FollowSplitStep make(int stage_id, int iter_id, - int src_step_id, int n_split); - - void ExtractSplitLengths(const std::vector& transform_steps, - std::vector* lengths) const; - - std::vector ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, - const std::vector& transform_steps) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.FollowSplitStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(FollowSplitStep, Step, FollowSplitStepNode); - - -// Similar to FollowSplitStep, but use split factors from multiple steps -// This can be used for the split in cooperative fetching. -class FollowFusedSplitStepNode: public StepNode { - public: - int iter_id; - std::vector src_step_ids; - int level; // Use the length in this split level - bool factor_or_nparts; // If this is true, use factor. 
Otherwise, use nparts - - static FollowFusedSplitStep make(int stage_id, int iter_id, - const std::vector& src_step_ids, int level, bool factor_or_nparts); - - PrimExpr ExtractSplitLength(const std::vector& transform_steps) const; - - std::vector ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, - const std::vector& transform_steps) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.FollowFusedSplitStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(FollowFusedSplitStep, Step, FollowFusedSplitStepNode); - - -class FuseStepNode: public StepNode { - public: - std::vector fused_ids; - - static FuseStep make(int stage_id, const std::vector& fused_ids); - - IterVar ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.FuseStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(FuseStep, Step, FuseStepNode); - - -class AnnotationStepNode: public StepNode { - public: - int iter_id; - IteratorAnnotation annotation; - - static AnnotationStep make(int stage_id, int iter_id, IteratorAnnotation ann); - - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.AnnotationStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(AnnotationStep, Step, AnnotationStepNode); - - -class ComputeAtStepNode: public StepNode { - public: - int target_stage_id; - int target_iter_id; - - static ComputeAtStep make(int stage_id, int target_stage_id, int target_iter_id); - - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.ComputeAtStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(ComputeAtStep, Step, ComputeAtStepNode); - - -class ComputeRootStepNode: public StepNode { - public: - static ComputeRootStep make(int stage_id); - - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.ComputeRootStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(ComputeRootStep, Step, ComputeRootStepNode); - - -class ComputeInlineStepNode: public StepNode { - public: - static ComputeInlineStep make(int stage_id); - - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr 
const char* _type_key = "ansor.ComputeInlineStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(ComputeInlineStep, Step, ComputeInlineStepNode); - -class PackForVecStepNode: public StepNode { - public: - int iter_id; - int vec_size; - - static PackForVecStep make(int stage_id, int iter_id, int vec_size); - - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, te::Schedule *schedule) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.PackForVecStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(PackForVecStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(PackForVecStep, Step, PackForVecStepNode); - - -/*! \brief Apply cache_read to a stage - * TVM Api: te::Schedule::cache_read(tensor, scope, readers) */ -class CacheReadStepNode: public StepNode { - public: - std::string scope_name; - std::vector reader_stage_ids; - - static CacheReadStep make(int stage_id, std::string scope_name, - const std::vector& reader_stage_id); - - te::Tensor ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, te::Schedule *schedule) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.CacheReadStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(CacheReadStep, Step, CacheReadStepNode); - - -/*! \brief Apply cache_write to a stage - * TVM Api: te::Schedule::cache_write(tensor, scope) - * This step will cache_write all output tensors of target stage */ -class CacheWriteStepNode: public StepNode { - public: - std::string scope_name; - - static CacheWriteStep make(int stage_id, std::string scope_name); - - Array ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, te::Schedule *schedule) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.CacheWriteStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(CacheWriteStep, Step, CacheWriteStepNode); - -/*! \brief Add pragma to a specific iterator */ -class PragmaStepNode: public StepNode { - public: - int iter_id; - std::string pragma_type; - - static PragmaStep make(int stage_id, int iter_id, std::string pragma_type); - - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.PragmaStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(PragmaStep, Step, PragmaStepNode); - -/*! 
\brief Factor a reduction axis - * TVM Api: te::Schedule::rfactor(tensor, axis, factor_axis) */ -class RfactorStepNode: public StepNode { - public: - int iter_id; - int factor_iter_id; - - static RfactorStep make(int stage_id, int iter_id, int factor_iter_id); - - Array ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.RfactorStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(RfactorStep, Step, RfactorStepNode); - -class StorageAlignStepNode: public StepNode { - public: - int iter_id; - int factor; - int offset; - - static StorageAlignStep make(int stage_id, int iter_id, int factor, - int offset); - - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.StorageAlignStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, Object); -}; -TVM_DEFINE_COW_NODE_REF(StorageAlignStep, Step, StorageAlignStepNode); - /*! \brief stores the compute_at relation between stages */ class AttachMapNode: public Object { public: @@ -523,13 +114,13 @@ class AttachMap : public ObjectRef { class StateNode: public Object { public: std::vector stages; // Current stages and loop structures - std::vector transform_steps; // History transformation steps to reach this state + std::vector transform_steps; // History transformation steps bool complete; // Indicate whether this state has unfilled tile sizes AttachMap attach_map; // stores the compute_at relation between stages - ObjectRef aux_info; // Used to store any auxiliary info about this state + ObjectRef aux_info; // Used to store any auxiliary info about this state ComputeDAG task_dag; // The up-to-date ComputeDAG of this state. 
- // The default value is an empty NodeRef
- // (means no modification to the DAG)
+ // The default value is an empty NodeRef
+ // (means no modification to the DAG)
 void VisitAttrs(tvm::AttrVisitor* v) {
 v->Visit("complete", &complete);
@@ -539,7 +130,8 @@ class StateNode: public Object {
 static State make_empty_state();
 static State make(const Array& ops);
 static State make(const std::vector& stages,
- const std::vector& transform_steps, bool complete, ObjectRef aux_info);
+ const std::vector& transform_steps, bool complete,
+ ObjectRef aux_info);
 static constexpr const char* _type_key = "ansor.State";
 TVM_DECLARE_FINAL_OBJECT_INFO(StateNode, Object);
@@ -556,7 +148,8 @@ class State : public ObjectRef {
 std::vector follow_split(int stage_id, const Iterator& it,
 int src_step_id, int n_split);
 std::vector follow_fused_split(int stage_id, const Iterator& it,
- const std::vector& src_step_ids, int level, bool factor_or_nparts);
+ const std::vector& src_step_ids,
+ int level, bool factor_or_nparts);
 Iterator fuse(int stage_id, const std::vector& iters);
 Iterator vectorize(int stage_id, const Iterator& it);
 Iterator parallel(int stage_id, const Iterator& it);
@@ -564,7 +157,8 @@ class State : public ObjectRef {
 // Valid thread_type: kVThread, kBlockX, kThreadX, kThreadY
 Iterator bind_thread(int stage_id, const Iterator& it,
 IteratorAnnotation thread_type);
- void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter);
+ void compute_at(int stage_id, int target_stage_id,
+ const Iterator& target_iter);
 void compute_root(int stage_id);
 void compute_inline(int stage_id);
 void pack_for_vec(int stage_id, const Iterator& target_iter, int vec_size);
@@ -578,7 +172,8 @@ class State : public ObjectRef {
 const ComputeDAG& task_dag);
 void storage_align(int stage_id, const Iterator& it, int factor, int offset);
- /* We separate these functions out, so you can call them for replay easily given history steps */
+ // We separate these functions out,
+ // so you can call them for replay easily given history steps
 void DoReorderStep(const ReorderStep& step);
 std::vector DoSplitStep(const SplitStep& step);
 std::vector DoFollowSplitStep(const FollowSplitStep& step);
@@ -596,7 +191,9 @@ class State : public ObjectRef {
 void DoStorageAlignStep(const StorageAlignStep& step);
 /* Do transform steps
- * Note: The following function only change loop state. They do not change transform_history. */
+ * Note: The following functions only change the loop state.
+ * They do not change transform_history.
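+ * For example, DoSplitStep only rewrites the iterator list of the
+ * affected Stage; the SplitStep already recorded in transform_steps is
+ * what makes the schedule replayable later.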
+ */ void DoStep(const Step& step, const ComputeDAG& dag); void DoSteps(const std::vector& step, const ComputeDAG& dag); @@ -620,98 +217,6 @@ class State : public ObjectRef { // Hash and equal function for State, Stage, Iterator and Step namespace std { -template <> -struct hash<::tvm::ansor::Step> { - std::size_t operator()(const ::tvm::ansor::Step& step) const { - if (auto ps = step.as<::tvm::ansor::ReorderStepNode>()) { - return ::dmlc::HashCombine(1, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ps->after_ids)); - } else if (auto ps = step.as<::tvm::ansor::SplitStepNode>()) { - size_t ret = ::dmlc::HashCombine(2, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - ps->inner_to_outer))); - for (const auto& len : ps->lengths) { - if (len.defined()) { - auto pint = len.as<::tvm::tir::IntImmNode>(); - CHECK(pint != nullptr); - ret = ::dmlc::HashCombine(ret, pint->value); - } else { - ret = ::dmlc::HashCombine(ret, 0x5D); // a magic number - } - return ret; - } - } else if (auto ps = step.as<::tvm::ansor::FollowSplitStepNode>()) { - return ::dmlc::HashCombine(3, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - ::dmlc::HashCombine(std::hash()(ps->src_step_id), - ps->n_split)))); - } else if (auto ps = step.as<::tvm::ansor::FollowFusedSplitStepNode>()) { - return ::dmlc::HashCombine(4, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - ::dmlc::HashCombine(std::hash>()(ps->src_step_ids), - ::dmlc::HashCombine(std::hash()(ps->level), - ps->factor_or_nparts))))); - } else if (auto ps = step.as<::tvm::ansor::FuseStepNode>()) { - return ::dmlc::HashCombine(5, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ps->fused_ids)); - } else if (auto ps = step.as<::tvm::ansor::AnnotationStepNode>()) { - return ::dmlc::HashCombine(6, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - static_cast(ps->annotation)))); - } else if (auto ps = step.as<::tvm::ansor::ComputeAtStepNode>()) { - return ::dmlc::HashCombine(7, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->target_stage_id), - ps->target_iter_id))); - } else if (auto ps = step.as<::tvm::ansor::ComputeRootStepNode>()) { - return ::dmlc::HashCombine(8, - ps->stage_id); - } else if (auto ps = step.as<::tvm::ansor::ComputeInlineStepNode>()) { - return ::dmlc::HashCombine(9, - ps->stage_id); - } else if (auto ps = step.as<::tvm::ansor::PackForVecStepNode>()) { - return ::dmlc::HashCombine(10, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - ps->vec_size))); - } else if (auto ps = step.as<::tvm::ansor::CacheReadStepNode>()) { - return ::dmlc::HashCombine(11, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->scope_name), - ps->reader_stage_ids))); - } else if (auto ps = step.as<::tvm::ansor::CacheWriteStepNode>()) { - return ::dmlc::HashCombine(12, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ps->scope_name)); - } else if (auto ps = step.as<::tvm::ansor::PragmaStepNode>()) { - return ::dmlc::HashCombine(13, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - ps->pragma_type))); - } else if (auto ps = step.as<::tvm::ansor::RfactorStepNode>()) { - return ::dmlc::HashCombine(14, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - 
ps->factor_iter_id))); - } else if (auto ps = step.as<::tvm::ansor::StorageAlignStepNode>()) { - return ::dmlc::HashCombine(15, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - ::dmlc::HashCombine(std::hash()(ps->factor), - ps->offset)))); - } else { - LOG(FATAL) << "Invalid step"; - } - return 0; - } -}; - template <> struct hash<::tvm::ansor::State> { std::size_t operator()(const ::tvm::ansor::State& state) const { diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc new file mode 100644 index 000000000000..8cd8233ae9be --- /dev/null +++ b/src/ansor/transform_step.cc @@ -0,0 +1,820 @@ +/*! + * Copyright (c) 2020 by Contributors + */ +#include "transform_step.h" +#include +#include "utils.h" + +namespace tvm { +namespace ansor { + +TVM_REGISTER_OBJECT_TYPE(StepNode); + +/********** Reorder **********/ +ReorderStep ReorderStepNode::make(int stage_id, const std::vector& after_ids) { + auto node = make_object(); + node->stage_id = stage_id; + node->after_ids = after_ids; + return ReorderStep(node); +} + +void ReorderStepNode::ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const { + te::Stage& stage = (*stages)[stage_id]; + const std::vector& axes = (*stage_to_axes)[stage]; + CHECK_EQ(after_ids.size(), axes.size()); + + std::vector new_axes; + new_axes.reserve(axes.size()); + for (auto i : after_ids) { + new_axes.push_back(axes[i]); + } + stage.reorder(new_axes); + (*stage_to_axes)[stage] = std::move(new_axes); +} + +std::string ReorderStepNode::PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + const te::Stage& stage = (*stages)[stage_id]; + std::stringstream ss; + + ss << "s[" << CleanName(stage->op->func_name()) << "].reorder("; + for (size_t i = 0; i < after_ids.size(); ++i) { + ss << CleanName((*stage_to_axes)[stage][after_ids[i]]->var->name_hint); + if (i != after_ids.size() - 1) { + ss << ", "; + } + } + ss << ")\n"; + + ApplyToSchedule(stages, stage_to_axes); + return ss.str(); +} + +/********** Split **********/ +std::vector ApplySplitToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, + int stage_id, + int iter_id, + const std::vector& lengths, + bool inner_to_outer) { + te::Stage& stage = (*stages)[stage_id]; + const std::vector& axes = (*stage_to_axes)[stage]; + + std::vector outs; + if (inner_to_outer) { + IterVar outer = axes[iter_id], inner; + for (int i = static_cast(lengths.size()) - 1; i >= 0; i--) { + IterVar to_split = outer; + stage.split(to_split, lengths[i], &outer, &inner); + outs.push_back(inner); + } + outs.push_back(outer); + } else { + IterVar outer, inner = axes[iter_id]; + for (size_t i = 0; i < lengths.size(); i++) { + IterVar to_split = inner; + stage.split_by_nparts(to_split, lengths[i], &outer, &inner); + outs.push_back(outer); + } + outs.push_back(inner); + } + + std::vector new_axes; + new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + iter_id); + if (inner_to_outer) { + new_axes.insert(new_axes.end(), outs.rbegin(), outs.rend()); + } else { + new_axes.insert(new_axes.end(), outs.begin(), outs.end()); + } + new_axes.insert(new_axes.end(), axes.begin() + iter_id + 1, axes.end()); + (*stage_to_axes)[stage] = std::move(new_axes); + + return outs; +} + +std::string PrintSplitAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + int stage_id, + int iter_id, + const std::vector& lengths, + bool inner_to_outer) { + te::Stage& stage = 
(*stages)[stage_id]; + auto to_split = (*stage_to_axes)[stage][iter_id]; + const auto& func_name = CleanName(stage->op->func_name()); + const auto& outs = ApplySplitToSchedule(stages, stage_to_axes, stage_id, + iter_id, lengths, inner_to_outer); + + std::stringstream ss; + int size = static_cast(lengths.size()); + if (inner_to_outer) { + for (int i = size - 1; i >= 0; i--) { + ss << CleanName(outs[size - i]->var->name_hint) << ", " + << CleanName(outs[size - i - 1]->var->name_hint) + << " = s[" << func_name << "].split(" + << CleanName(to_split->var->name_hint) + << ", factor=" << lengths[i] << ")\n"; + to_split = outs[size - i]; + } + } else { + for (int i = 0; i < size; i++) { + ss << CleanName(outs[i]->var->name_hint) << ", " + << CleanName(outs[i + 1]->var->name_hint) + << " = s[" << func_name << "].split(" + << CleanName(to_split->var->name_hint) + << ", nparts=" << lengths[i] << ")\n"; + to_split = outs[i + 1]; + } + } + + return ss.str(); +} + +SplitStep SplitStepNode::make(int stage_id, int iter_id, + PrimExpr extent, const std::vector& lengths, + bool inner_to_outer) { + auto node = make_object(); + node->stage_id = stage_id; + // Extent can be a unreducible expression in some special cases + if (extent->IsInstance()) { + node->extent = std::move(extent); + } + node->iter_id = iter_id; + node->lengths = lengths; + node->inner_to_outer = inner_to_outer; + return SplitStep(node); +} + +std::vector SplitStepNode::ApplyToSchedule( + std::vector *stages, StageToAxesMap *stage_to_axes) const { + return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, + lengths, inner_to_outer); +} + +std::string SplitStepNode::PrintAsPythonAPI( + std::vector *stages, StageToAxesMap *stage_to_axes, + te::Schedule *schedule, const std::vector& transform_steps) const { + return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, + lengths, inner_to_outer); +} + +/********** Follow Split **********/ +FollowSplitStep FollowSplitStepNode::make(int stage_id, int iter_id, + int src_step_id, int n_split) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->src_step_id = src_step_id; + node->n_split = n_split; + return FollowSplitStep(node); +} + +void FollowSplitStepNode::ExtractSplitLengths(const std::vector& transform_steps, + std::vector* lengths) const { + CHECK_LT(src_step_id, transform_steps.size()); + auto ps = transform_steps[src_step_id].as(); + CHECK(ps != nullptr); + + // get lengths from src step + lengths->reserve(n_split); + int j = 0; + for (; j < n_split - 1; ++j) { + lengths->push_back(ps->lengths[j]); + } + PrimExpr last_factor = 1; + for (; j < static_cast(ps->lengths.size()); ++j) { + if (ps->lengths[j].defined()) { + last_factor *= ps->lengths[j]; + } else { + last_factor = PrimExpr(); + break; + } + } + lengths->push_back(std::move(last_factor)); +} + +std::vector FollowSplitStepNode::ApplyToSchedule( + std::vector *stages, StageToAxesMap *stage_to_axes, + const std::vector& transform_steps) const { + std::vector lengths; + ExtractSplitLengths(transform_steps, &lengths); + return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, + lengths, true); +} + +std::string FollowSplitStepNode::PrintAsPythonAPI( + std::vector *stages, StageToAxesMap *stage_to_axes, + te::Schedule *schedule, const std::vector& transform_steps) const { + std::vector lengths; + ExtractSplitLengths(transform_steps, &lengths); + return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, + lengths, true); +} + +/********** Follow 
+/********** Follow Fused Split **********/
+FollowFusedSplitStep FollowFusedSplitStepNode::make(int stage_id, int iter_id,
+    const std::vector<int>& src_step_ids, int level, bool factor_or_nparts) {
+  auto node = make_object<FollowFusedSplitStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->src_step_ids = src_step_ids;
+  node->level = level;
+  node->factor_or_nparts = factor_or_nparts;
+  return FollowFusedSplitStep(node);
+}
+
+PrimExpr FollowFusedSplitStepNode::ExtractSplitLength(const std::vector<Step>& transform_steps) const {
+  PrimExpr ret(1);
+
+  for (int src_step_id : src_step_ids) {
+    CHECK_LT(src_step_id, transform_steps.size());
+    auto ps = transform_steps[src_step_id].as<SplitStepNode>();
+    CHECK(ps != nullptr);
+    if (ps->lengths[level].defined() && ret.defined()) {
+      ret *= ps->lengths[level];
+    } else {
+      return PrimExpr();
+    }
+  }
+
+  return ret;
+}
+
+std::vector<IterVar> FollowFusedSplitStepNode::ApplyToSchedule(
+    std::vector<te::Stage> *stages, StageToAxesMap *stage_to_axes,
+    const std::vector<Step>& transform_steps) const {
+  const PrimExpr& length = ExtractSplitLength(transform_steps);
+  return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id,
+                              {length}, factor_or_nparts);
+}
+
+std::string FollowFusedSplitStepNode::PrintAsPythonAPI(
+    std::vector<te::Stage> *stages, StageToAxesMap *stage_to_axes,
+    te::Schedule *schedule, const std::vector<Step>& transform_steps) const {
+  const PrimExpr& length = ExtractSplitLength(transform_steps);
+  return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id,
+                               {length}, factor_or_nparts);
+}
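+// Worked example for ExtractSplitLength above (the numbers are illustrative):
+// if src_step_ids points at two SplitSteps whose lengths are {2, 8} and
+// {4, 8}, and level = 1, the extracted length is 8 * 8 = 64; if any factor
+// at that level is undefined, the extracted length is undefined too.
+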
+
+/********** Fuse **********/
+FuseStep FuseStepNode::make(int stage_id, const std::vector<int>& fused_ids) {
+  auto node = make_object<FuseStepNode>();
+  node->stage_id = stage_id;
+  node->fused_ids = fused_ids;
+  return FuseStep(node);
+}
+
+IterVar FuseStepNode::ApplyToSchedule(std::vector<te::Stage> *stages,
+                                      StageToAxesMap *stage_to_axes) const {
+  te::Stage& stage = (*stages)[stage_id];
+  const std::vector<IterVar>& axes = (*stage_to_axes)[stage];
+
+  Array<IterVar> to_fuse;
+  for (auto i : fused_ids) {
+    to_fuse.push_back(axes[i]);
+  }
+  IterVar fused_axis;
+  stage.fuse(to_fuse, &fused_axis);
+  std::vector<IterVar> new_axes;
+  new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids[0]);
+  new_axes.push_back(fused_axis);
+  new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1,
+                  axes.end());
+  (*stage_to_axes)[stage] = std::move(new_axes);
+
+  return fused_axis;
+}
+
+std::string FuseStepNode::PrintAsPythonAPI(std::vector<te::Stage> *stages,
+                                           StageToAxesMap *stage_to_axes,
+                                           te::Schedule *schedule,
+                                           const std::vector<Step>& transform_steps) const {
+  const auto& stage = (*stages)[stage_id];
+  std::stringstream to_fuse;
+
+  for (size_t i = 0; i < fused_ids.size(); ++i) {
+    to_fuse << CleanName((*stage_to_axes)[stage][fused_ids[i]]->var->name_hint);
+    if (i != fused_ids.size() - 1) {
+      to_fuse << ", ";
+    }
+  }
+
+  std::stringstream ss;
+  const auto& fused = ApplyToSchedule(stages, stage_to_axes);
+
+  ss << CleanName(fused->var->name_hint) << " = s["
+     << CleanName(stage->op->func_name()) << "].fuse("
+     << to_fuse.str() << ")\n";
+
+  return ss.str();
+}
+
+/********** Annotation **********/
+AnnotationStep AnnotationStepNode::make(int stage_id, int iter_id, IteratorAnnotation ann) {
+  auto node = make_object<AnnotationStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->annotation = ann;
+  return AnnotationStep(node);
+}
+
+void AnnotationStepNode::ApplyToSchedule(std::vector<te::Stage> *stages,
+                                         StageToAxesMap *stage_to_axes) const {
+  te::Stage& stage = (*stages)[stage_id];
+  const std::vector<IterVar>& axes = (*stage_to_axes)[stage];
+
+  switch (annotation) {
+    case kUnroll: stage.unroll(axes[iter_id]); break;
+    case kVectorize: stage.vectorize(axes[iter_id]); break;
+    case kParallel: stage.parallel(axes[iter_id]); break;
+    case kVThread: stage.bind(axes[iter_id], te::thread_axis(Range(), "vthread")); break;
+    case kBlockX: stage.bind(axes[iter_id], te::thread_axis(Range(), "blockIdx.x")); break;
+    case kBlockY: stage.bind(axes[iter_id], te::thread_axis(Range(), "blockIdx.y")); break;
+    case kThreadX:
+      if (axes[iter_id]->iter_type == kCommReduce) {
+        const auto &thread_x = te::thread_axis(Range(), "threadIdx.x");
+        stage.bind(axes[iter_id], thread_x);
+        stage.set_store_predicate(thread_x->var == 0);
+      } else {
+        stage.bind(axes[iter_id], te::thread_axis(Range(), "threadIdx.x"));
+      }
+      break;
+    case kThreadY: stage.bind(axes[iter_id], te::thread_axis(Range(), "threadIdx.y")); break;
+    case kNone: break;
+    default: LOG(FATAL) << "Invalid Annotation " << annotation; break;
+  }
+}
+
+std::string AnnotationStepNode::PrintAsPythonAPI(std::vector<te::Stage> *stages,
+                                                 StageToAxesMap *stage_to_axes,
+                                                 te::Schedule *schedule,
+                                                 const std::vector<Step>& transform_steps) const {
+  std::stringstream ss;
+  const auto& stage = (*stages)[stage_id];
+  const auto& iter = (*stage_to_axes)[stage][iter_id];
+
+  bool bind_reduce_iter = iter->iter_type == kCommReduce && annotation == kThreadX;
+  if (bind_reduce_iter) {
+    ss << "thread_x = tvm.thread_axis(\"threadIdx.x\")\n";
+  }
+
+  ss << "s[" << CleanName(stage->op->func_name()) << "].";
+  switch (annotation) {
+    case kUnroll: ss << "unroll("; break;
+    case kVectorize: ss << "vectorize("; break;
+    case kParallel: ss << "parallel("; break;
+    case kVThread:
+    case kBlockX:
+    case kBlockY:
+    case kThreadX:
+    case kThreadY: ss << "bind("; break;
+    case kNone: break;
+    default:
+      LOG(FATAL) << "Invalid annotation " << annotation; break;
+  }
+  ss << CleanName(iter->var->name_hint);
+  switch (annotation) {
+    case kVThread: ss << ", tvm.thread_axis(\"vthread\")"; break;
+    case kBlockX: ss << ", tvm.thread_axis(\"blockIdx.x\")"; break;
+    case kBlockY: ss << ", tvm.thread_axis(\"blockIdx.y\")"; break;
+    case kThreadX:
+      if (bind_reduce_iter) {
+        ss << ", thread_x";
+      } else {
+        ss << ", tvm.thread_axis(\"threadIdx.x\")";
+      }
+      break;
+    case kThreadY: ss << ", tvm.thread_axis(\"threadIdx.y\")"; break;
+    default: break;
+  }
+  ss << ")\n";
+
+  if (bind_reduce_iter) {
+    ss << "s[" << CleanName(stage->op->func_name()) << "]"
+       << ".set_store_predicate(thread_x.var.equal(0))\n";
+  }
+
+  ApplyToSchedule(stages, stage_to_axes);
+  return ss.str();
+}
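+// For reference, the printer above emits python like the following for a
+// kThreadX annotation on a reduction iterator (the stage/iterator names are
+// illustrative):
+//   thread_x = tvm.thread_axis("threadIdx.x")
+//   s[C].bind(k, thread_x)
+//   s[C].set_store_predicate(thread_x.var.equal(0))
+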
+/********** Compute at **********/
+ComputeAtStep ComputeAtStepNode::make(int stage_id, int target_stage_id, int target_iter_id) {
+  auto node = make_object<ComputeAtStepNode>();
+  node->stage_id = stage_id;
+  node->target_stage_id = target_stage_id;
+  node->target_iter_id = target_iter_id;
+  return ComputeAtStep(node);
+}
+
+void ComputeAtStepNode::ApplyToSchedule(std::vector<te::Stage> *stages,
+                                        StageToAxesMap *stage_to_axes) const {
+  te::Stage& stage = (*stages)[stage_id];
+  const IterVar& target_axis =
+      (*stage_to_axes)[(*stages)[target_stage_id]][target_iter_id];
+  stage.compute_at((*stages)[target_stage_id], target_axis);
+}
+
+std::string ComputeAtStepNode::PrintAsPythonAPI(std::vector<te::Stage> *stages,
+                                                StageToAxesMap *stage_to_axes,
+                                                te::Schedule *schedule,
+                                                const std::vector<Step>& transform_steps) const {
+  std::stringstream ss;
+  const auto& stage = (*stages)[stage_id];
+  const auto& target_stage = (*stages)[target_stage_id];
+
+  ss << "s[" << CleanName(stage->op->func_name()) << "].compute_at(s["
+     << CleanName(target_stage->op->func_name()) << "], "
+     << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint);
+
+  ss << ")\n";
+  ApplyToSchedule(stages, stage_to_axes);
+  return ss.str();
+}
+
+/********** Compute Root **********/
+ComputeRootStep ComputeRootStepNode::make(int stage_id) {
+  auto node = make_object<ComputeRootStepNode>();
+  node->stage_id = stage_id;
+  return ComputeRootStep(node);
+}
+
+void ComputeRootStepNode::ApplyToSchedule(std::vector<te::Stage> *stages,
+                                          StageToAxesMap *stage_to_axes) const {
+  (*stages)[stage_id].compute_root();
+}
+
+std::string ComputeRootStepNode::PrintAsPythonAPI(std::vector<te::Stage> *stages,
+                                                  StageToAxesMap *stage_to_axes,
+                                                  te::Schedule *schedule,
+                                                  const std::vector<Step>& transform_steps) const {
+  std::stringstream ss;
+  const auto& stage = (*stages)[stage_id];
+
+  ss << "s[" << CleanName(stage->op->func_name()) << "].compute_root()\n";
+  ApplyToSchedule(stages, stage_to_axes);
+
+  return ss.str();
+}
+
+/********** Compute Inline **********/
+ComputeInlineStep ComputeInlineStepNode::make(int stage_id) {
+  auto node = make_object<ComputeInlineStepNode>();
+  node->stage_id = stage_id;
+  return ComputeInlineStep(node);
+}
+
+void ComputeInlineStepNode::ApplyToSchedule(std::vector<te::Stage> *stages,
+                                            StageToAxesMap *stage_to_axes) const {
+  (*stages)[stage_id].compute_inline();
+}
+
+std::string ComputeInlineStepNode::PrintAsPythonAPI(
+    std::vector<te::Stage> *stages,
+    StageToAxesMap *stage_to_axes,
+    te::Schedule *schedule,
+    const std::vector<Step>& transform_steps) const {
+  std::stringstream ss;
+  const auto& stage = (*stages)[stage_id];
+
+  ss << "s[" << CleanName(stage->op->func_name()) << "].compute_inline()\n";
+  ApplyToSchedule(stages, stage_to_axes);
+
+  return ss.str();
+}
+
+/********** Pack for vec **********/
+PackForVecStep PackForVecStepNode::make(int stage_id, int iter_id, int vec_size) {
+  auto node = make_object<PackForVecStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->vec_size = vec_size;
+  return PackForVecStep(node);
+}
+
+void PackForVecStepNode::ApplyToSchedule(std::vector<te::Stage> *stages,
+    StageToAxesMap *stage_to_axes, te::Schedule *schedule) const {
+  LOG(FATAL) << "Not implemented";
+}
+
+std::string PackForVecStepNode::PrintAsPythonAPI(std::vector<te::Stage> *stages,
+                                                 StageToAxesMap *stage_to_axes,
+                                                 te::Schedule *schedule,
+                                                 const std::vector<Step>& transform_steps) const {
+  LOG(FATAL) << "Not implemented";
+  return "";
+}
+
+/********** Cache read **********/
+CacheReadStep CacheReadStepNode::make(int stage_id, std::string scope_name,
+                                      const std::vector<int>& reader_stage_ids) {
+  auto node = make_object<CacheReadStepNode>();
+  node->stage_id = stage_id;
+  node->scope_name = std::move(scope_name);
+  node->reader_stage_ids = reader_stage_ids;
+  return CacheReadStep(node);
+}
+
+te::Tensor CacheReadStepNode::ApplyToSchedule(std::vector<te::Stage>* stages,
+    StageToAxesMap *stage_to_axes, te::Schedule *schedule) const {
+  te::Stage& stage = (*stages)[stage_id];
+
+  Array<te::Operation> readers;
+  for (const auto& i : reader_stage_ids) {
+    readers.push_back((*stages)[i]->origin_op);
+  }
+  auto out = schedule->cache_read(stage->origin_op.output(0), scope_name, readers);
+
+  const auto& new_stage = (*schedule)[out->op];
+  UpdateStageAxis(new_stage, stage_to_axes);
+  stages->insert(stages->begin() + stage_id + 1, new_stage);
+
+  return out;
+}
+
+std::string CacheReadStepNode::PrintAsPythonAPI(std::vector<te::Stage> *stages,
+                                                StageToAxesMap *stage_to_axes,
+                                                te::Schedule *schedule,
+                                                const std::vector<Step>& transform_steps) const {
+  std::stringstream ss;
+  // Copy the stage here, because the original stage will be changed
+  // after the schedule is applied.
+  auto stage = (*stages)[stage_id];
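+  // Also snapshot the reader stages: ApplyToSchedule below inserts the new
+  // cache stage right after stage_id, which shifts the indices of all
+  // later stages.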
+  std::vector<te::Stage> reader_stages;
+  for (size_t i = 0; i < reader_stage_ids.size(); ++i) {
+    reader_stages.push_back((*stages)[reader_stage_ids[i]]);
+  }
+
+  auto out = ApplyToSchedule(stages, stage_to_axes, schedule);
+
+  ss << CleanName(out->op->func_name()) << " = "
+     << "s.cache_read(" << CleanName(stage->op->func_name()) << ", \""
+     << scope_name << "\", ["
+     << CleanName(reader_stages[0]->op->func_name());
+  for (size_t i = 1; i < reader_stage_ids.size(); ++i) {
+    ss << ", " << CleanName(reader_stages[i]->op->func_name());
+  }
+  ss << "])\n";
+
+  const auto& iters = out->op->root_iter_vars();
+  for (size_t i = 0; i < iters.size(); ++i) {
+    ss << CleanName(iters[i]->var->name_hint);
+    if (i != iters.size() - 1) {
+      ss << ", ";
+    }
+  }
+  ss << " = " << "tuple(" << CleanName(out->op->func_name())
+     << ".op.axis)\n";
+
+  return ss.str();
+}
+
+/********** Cache write **********/
+CacheWriteStep CacheWriteStepNode::make(int stage_id, std::string scope_name) {
+  auto node = make_object<CacheWriteStepNode>();
+  node->stage_id = stage_id;
+  node->scope_name = std::move(scope_name);
+  return CacheWriteStep(node);
+}
+
+Array<te::Tensor> CacheWriteStepNode::ApplyToSchedule(
+    std::vector<te::Stage> *stages, StageToAxesMap *stage_to_axes,
+    te::Schedule *schedule) const {
+  te::Stage& stage = (*stages)[stage_id];
+
+  Array<te::Tensor> tensor_array;
+  // If the target stage has multiple outputs, TVM requires us to cache_write
+  // all of them, or schedule.cache_write will raise an error
+  for (auto i = 0; i < stage->op->num_outputs(); ++i) {
+    tensor_array.push_back(stage->origin_op.output(i));
+  }
+  auto outs = schedule->cache_write(tensor_array, scope_name);
+
+  UpdateStageAxis(stage, stage_to_axes);
+  // Even if there are multiple outputs, the TVM schedule generates only one
+  // new stage
+  const auto& new_stage = (*schedule)[outs[0]->op];
+  UpdateStageAxis(new_stage, stage_to_axes);
+  stages->insert(stages->begin() + stage_id, new_stage);
+
+  return outs;
+}
+
+std::string CacheWriteStepNode::PrintAsPythonAPI(std::vector<te::Stage> *stages,
+                                                 StageToAxesMap *stage_to_axes,
+                                                 te::Schedule *schedule,
+                                                 const std::vector<Step>& transform_steps) const {
+  std::stringstream ss;
+  // Copy the stage here, because the original stage will be changed
+  // after the schedule is applied.
+  te::Stage stage = (*stages)[stage_id];
+
+  auto outs = ApplyToSchedule(stages, stage_to_axes, schedule);
+
+  for (size_t i = 0; i < outs.size(); ++i) {
+    ss << CleanName(outs[i]->op->func_name()) << ", ";
+  }
+  ss << "= " << "s.cache_write(["
+     << CleanName(stage->op.output(0)->op->name);
+  for (auto i = 1; i < stage->op->num_outputs(); ++i) {
+    ss << ", " << CleanName(stage->op.output(i)->op->name);
+  }
+  ss << "], \"" << scope_name << "\")\n";
+
+  for (const auto& out : outs) {
+    const auto& iters = out->op->root_iter_vars();
+    for (size_t i = 0; i < iters.size(); ++i) {
+      ss << CleanName(iters[i]->var->name_hint);
+      if (i != iters.size() - 1) {
+        ss << ", ";
+      }
+    }
+    ss << " = " << "tuple(" << CleanName(out->op->func_name())
+       << ".op.axis)"
+       << " + " << "tuple(" << CleanName(out->op->func_name())
+       << ".op.reduce_axis)\n";
+  }
+
+  return ss.str();
+}
+
+/********** Pragma **********/
+PragmaStep PragmaStepNode::make(int stage_id, int iter_id,
+                                std::string pragma_type) {
+  auto node = make_object<PragmaStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->pragma_type = std::move(pragma_type);
+  return PragmaStep(node);
+}
+
+void PragmaStepNode::ApplyToSchedule(std::vector<te::Stage> *stages,
+                                     StageToAxesMap *stage_to_axes) const {
+  te::Stage& stage = (*stages)[stage_id];
+  const std::vector<IterVar>& axes = (*stage_to_axes)[stage];
+  if 
(StrStartsWith(pragma_type, "auto_unroll_max_step")) { + size_t pos = pragma_type.find('$'); + int value = atoi(pragma_type.c_str() + pos + 1); + stage.pragma(axes[iter_id], "auto_unroll_max_step", value); + stage.pragma(axes[iter_id], "unroll_explicit", true); + } else { + stage.pragma(axes[iter_id], pragma_type); + } +} + +std::string PragmaStepNode::PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + + if (StrStartsWith(pragma_type, "auto_unroll_max_step")) { + size_t pos = pragma_type.find('$'); + int value = atoi(pragma_type.c_str() + pos + 1); + ss << "s[" << CleanName(stage->op->func_name()) << "].pragma(" + << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) + << ", \"auto_unroll_max_step\", " << value << ")\n"; + ss << "s[" << CleanName(stage->op->func_name()) << "].pragma(" + << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) + << ", \"unroll_explicit\", True)\n"; + } else { + ss << "s[" << CleanName(stage->op->func_name()) << "].pragma(" + << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", \"" + << pragma_type << "\")\n"; + } + + ApplyToSchedule(stages, stage_to_axes); + return ss.str(); +} + +/********** Rfactor **********/ +RfactorStep RfactorStepNode::make(int stage_id, int iter_id, int factor_iter_id) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->factor_iter_id = factor_iter_id; + return RfactorStep(node); +} + +Array RfactorStepNode::ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, te::Schedule *schedule) const { + const auto& stage = (*stages)[stage_id]; + const std::vector& axes = (*stage_to_axes)[stage]; + + const te::Tensor& tensor = stage->origin_op.output(0); + const IterVar& axis = axes[iter_id]; + auto outs = schedule->rfactor(tensor, axis, factor_iter_id); + + UpdateStageAxis(stage, stage_to_axes); + + const auto& new_stage = (*schedule)[outs[0]->op]; + UpdateStageAxis(new_stage, stage_to_axes); + stages->insert(stages->begin() + stage_id, new_stage); + + return outs; +} + +std::string RfactorStepNode::PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + + const auto& tensor_name = CleanName(stage->origin_op.output(0)->op->name); + const auto& axis_name = CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint); + + const auto& outs = ApplyToSchedule(stages, stage_to_axes, schedule); + + for (size_t i = 0; i < outs.size(); ++i) { + ss << CleanName(outs[i]->op->func_name()); + if (i != outs.size() - 1) { + ss << ", "; + } + } + ss << " = " << "s.rfactor(" + << tensor_name << ", " + << axis_name << ", " + << factor_iter_id << ")\n"; + + for (const auto& out : outs) { + const auto& iters = out->op->root_iter_vars(); + for (size_t i = 0; i < iters.size(); ++i) { + ss << CleanName(iters[i]->var->name_hint); + if (i != iters.size() - 1) { + ss << ", "; + } + } + ss << " = " << "tuple(" << CleanName(out->op->func_name()) + << ".op.axis)" + << " + " << "tuple(" << CleanName(out->op->func_name()) + << ".op.reduce_axis)\n"; + } + + const auto& output = (*stages)[stage_id + 1]->op.output(0); + const auto& iters = output->op->root_iter_vars(); + for (size_t i = 0; i < iters.size(); ++i) { + ss << CleanName(iters[i]->var->name_hint); + if (i != 
iters.size() - 1) {
+      ss << ", ";
+    }
+  }
+  ss << " = " << "tuple(s[" << CleanName(output->op->func_name())
+     << "].op.axis)"
+     << " + " << "tuple(s[" << CleanName(output->op->func_name())
+     << "].op.reduce_axis)\n";
+
+  return ss.str();
+}
+
+/********** StorageAlign **********/
+
+StorageAlignStep StorageAlignStepNode::make(int stage_id, int iter_id,
+                                            int factor, int offset) {
+  auto node = make_object<StorageAlignStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->factor = factor;
+  node->offset = offset;
+  return StorageAlignStep(node);
+}
+
+void StorageAlignStepNode::ApplyToSchedule(std::vector<te::Stage> *stages,
+                                           StageToAxesMap *stage_to_axes) const {
+  te::Stage& stage = (*stages)[stage_id];
+  const std::vector<IterVar>& axes = (*stage_to_axes)[stage];
+  stage.storage_align(axes[iter_id], factor, offset);
+}
+
+std::string StorageAlignStepNode::PrintAsPythonAPI(
+    std::vector<te::Stage> *stages, StageToAxesMap *stage_to_axes,
+    te::Schedule *schedule, const std::vector<Step>& transform_steps) const {
+  std::stringstream ss;
+  const auto& stage = (*stages)[stage_id];
+  ss << "s[" << CleanName(stage->op->func_name()) << "].storage_align("
+     << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", "
+     << factor << ", " << offset << ")\n";
+
+  ApplyToSchedule(stages, stage_to_axes);
+  return ss.str();
+}
+
+// Makers for other classes
+Iterator IteratorNode::make(std::string name, Range range,
+                            IteratorType iter_type, IteratorAnnotation annotation,
+                            const std::vector<Iterator>* ori_iters) {
+  auto node = make_object<IteratorNode>();
+  node->name = std::move(name);
+  node->range = std::move(range);
+  node->iter_type = iter_type;
+  node->annotation = annotation;
+  if (ori_iters != nullptr) {
+    node->ori_iters = *ori_iters;
+  }
+  return Iterator(node);
+}
+
+} // namespace ansor
+} // namespace tvm
diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h
new file mode 100644
index 000000000000..9b430be99bd3
--- /dev/null
+++ b/src/ansor/transform_step.h
@@ -0,0 +1,551 @@
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file ansor/transform_step.h
+ * \brief Data structures for loop transformations
+
+ * Basically this is a simplified TVM IR with schedule primitives.
+ * We don't use the existing TVM IR because
+ * 1. We want fast incremental changes to the loop structures
+ * 2. We want a serializable history for replay and backtracking
+ * 3. We want a simplified IR for easy and clean feature extraction
+ * 4. We may create some macro schedule primitives
+
+ * After the search is done, we will lower this IR to TVM IR and TVM schedule primitives.
+ * Because we share a lot of common objects during search, the transformation is
+ * implemented in copy-on-write style. All objects are immutable, which is
+ * similar to TVM IR.
+ */
+
+#ifndef TVM_ANSOR_TRANSFORM_STEP_H_
+#define TVM_ANSOR_TRANSFORM_STEP_H_
+
+#include
+#include
+#include
+#include "compute_dag.h"
+
+namespace tvm {
+namespace ansor {
+
+using namespace tvm::tir;
+
+inline std::string CleanName(const std::string& str) {
+  // to make the name valid in python code
+  std::string ret = str;
+  StrReplace(&ret, ".", "_");
+  StrReplace(&ret, "@", "_");
+  StrReplace(&ret, "outer", "o");
+  StrReplace(&ret, "inner", "i");
+  return ret;
+}
+
+enum IteratorType {
+  kSpace,    // spatial iterator
+  kReduce,   // reduction iterator
+  kMixed,    // fused spatial and reduction iterator
+  kSpecial   // special iterator (e.g. virtual root iterator)
+};
+
+enum IteratorAnnotation {
+  kNone, kUnroll, kVectorize, kParallel,
+  kVThread, kBlockX, kThreadX, kBlockY, kThreadY
+};
+
+class Iterator;
+
+/*!
+ * \brief A for-loop iterator
+ * Similar to tvm::IterVar in `include/expr.h`
+ */
+class IteratorNode : public Object {
+ public:
+  std::string name;
+  Range range;   // domain of the for-loop
+  IteratorType iter_type;
+  IteratorAnnotation annotation;
+  std::vector<Iterator> ori_iters;
+
+  static Iterator make(std::string name, Range range,
+                       IteratorType iter_type, IteratorAnnotation annotation,
+                       const std::vector<Iterator>* ori_iters = nullptr);
+
+  static constexpr const char *_type_key = "ansor.Iterator";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object);
+};
+TVM_DEFINE_COW_NODE_REF(Iterator, ObjectRef, IteratorNode);
+
+/*! \brief The base class for a transformation step */
+class StepNode: public Object {
+ public:
+  int stage_id;
+
+  // Print the step as the equivalent python schedule API
+  virtual std::string PrintAsPythonAPI(std::vector<te::Stage> *stages,
+                                       StageToAxesMap *stage_to_axes,
+                                       te::Schedule *schedule,
+                                       const std::vector<Step>& transform_steps) const = 0;
+
+  static constexpr const char* _type_key = "ansor.Step";
+  TVM_DECLARE_BASE_OBJECT_INFO(StepNode, Object);
+};
+TVM_DEFINE_MUTABLE_NODE_REF(Step, StepNode);
+
+/*
+ * Note on how to add a new transform step
+ *
+ * Take fuse for example:
+ * 1. Define class FuseStepNode, FuseStep in transform_step.h, and implement its make function
+ *    FuseStepNode::make(...) in transform_step.cc
+ * 2. Implement FuseStepNode::ApplyToSchedule and FuseStepNode::PrintAsPythonAPI.
+ *    - In these two functions you need to lower this step with tvm's schedule API
+ * 3. Implement State::fuse and State::DoFuseStep.
+ *    - In these two functions you need to incrementally update all data structures in State with
+ *      CopyOnWrite style
+ * 4. Add your step to ComputeDAG::ReplaySteps and make sure it works.
+ * 5. Add serialization support in `struct Handler >`
+ *    (in serialization.cc)
+ * 6. Add hash support in `struct hash<::tvm::ansor::Step>` (search for this function in this file)
+ */
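+/*
+ * For illustration only (not part of the patch): a rough sketch of what
+ * steps (1) and (2) of the note above would produce for a hypothetical
+ * `ExampleStepNode` that unrolls one iterator. All names here are made up
+ * for the example:
+ *
+ *   class ExampleStepNode : public StepNode {
+ *    public:
+ *     int iter_id;
+ *
+ *     static ExampleStep make(int stage_id, int iter_id);
+ *
+ *     void ApplyToSchedule(std::vector<te::Stage>* stages,
+ *                          StageToAxesMap* stage_to_axes) const {
+ *       te::Stage& stage = (*stages)[stage_id];
+ *       stage.unroll((*stage_to_axes)[stage][iter_id]);
+ *     }
+ *
+ *     // PrintAsPythonAPI would emit, e.g.:  s[C].unroll(i0)
+ *   };
+ */
+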
+class ReorderStep; class SplitStep; class FollowSplitStep;
+class FollowFusedSplitStep;
+class FuseStep; class AnnotationStep;
+class ComputeAtStep; class ComputeRootStep; class ComputeInlineStep;
+class PackForVecStep; class CacheReadStep; class CacheWriteStep;
+class PragmaStep; class RfactorStep; class StorageAlignStep;
+class AttachMap;
+
+class ReorderStepNode: public StepNode {
+ public:
+  std::vector<int> after_ids;
+
+  static ReorderStep make(int stage_id, const std::vector<int>& after_ids);
+
+  void ApplyToSchedule(std::vector<te::Stage> *stages,
+                       StageToAxesMap *stage_to_axes) const;
+
+  std::string PrintAsPythonAPI(std::vector<te::Stage> *stages,
+                               StageToAxesMap *stage_to_axes,
+                               te::Schedule *schedule,
+                               const std::vector<Step>& transform_steps) const final;
+
+  static constexpr const char* _type_key = "ansor.ReorderStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, Object);
+};
+TVM_DEFINE_COW_NODE_REF(ReorderStep, Step, ReorderStepNode);
+
+
+class SplitStepNode: public StepNode {
+ public:
+  int iter_id;
+  PrimExpr extent;                // the extent of the axis to split
+  std::vector<PrimExpr> lengths;  // the split factors
+  bool inner_to_outer;
+
+  static SplitStep make(int stage_id, int iter_id, PrimExpr extent,
+                        const std::vector<PrimExpr>& lengths,
+                        bool inner_to_outer);
+
+  std::vector<IterVar> ApplyToSchedule(std::vector<te::Stage> *stages,
+                                       StageToAxesMap *stage_to_axes) const;
+
+  std::string PrintAsPythonAPI(std::vector<te::Stage> *stages,
+                               StageToAxesMap *stage_to_axes,
+                               te::Schedule *schedule,
+                               const std::vector<Step>& transform_steps) const final;
+
+  static constexpr const char* _type_key = "ansor.SplitStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, Object);
+};
+TVM_DEFINE_COW_NODE_REF(SplitStep, Step, SplitStepNode);
+
+// Similar to SplitStepNode, but uses the split factors from another step
+// (i.e. follows another split step)
+class FollowSplitStepNode: public StepNode {
+ public:
+  int iter_id;
+  int src_step_id;
+  int n_split;
+
+  static FollowSplitStep make(int stage_id, int iter_id,
+                              int src_step_id, int n_split);
+
+  void ExtractSplitLengths(const std::vector<Step>& transform_steps,
+                           std::vector<PrimExpr>* lengths) const;
+
+  std::vector<IterVar> ApplyToSchedule(std::vector<te::Stage> *stages,
+                                       StageToAxesMap *stage_to_axes,
+                                       const std::vector<Step>& transform_steps) const;
+
+  std::string PrintAsPythonAPI(std::vector<te::Stage> *stages,
+                               StageToAxesMap *stage_to_axes,
+                               te::Schedule *schedule,
+                               const std::vector<Step>& transform_steps) const final;
+
+  static constexpr const char* _type_key = "ansor.FollowSplitStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object);
+};
+TVM_DEFINE_COW_NODE_REF(FollowSplitStep, Step, FollowSplitStepNode);
+
+
+// Similar to FollowSplitStep, but uses split factors from multiple steps.
+// This can be used for the split in cooperative fetching.
+class FollowFusedSplitStepNode: public StepNode {
+ public:
+  int iter_id;
+  std::vector<int> src_step_ids;
+  int level;              // Use the length in this split level
+  bool factor_or_nparts;  // If this is true, use factor. 
Otherwise, use nparts + + static FollowFusedSplitStep make(int stage_id, int iter_id, + const std::vector& src_step_ids, + int level, bool factor_or_nparts); + + PrimExpr ExtractSplitLength(const std::vector& transform_steps) const; + + std::vector ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes, + const std::vector& transform_steps) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.FollowFusedSplitStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(FollowFusedSplitStep, Step, FollowFusedSplitStepNode); + + +class FuseStepNode: public StepNode { + public: + std::vector fused_ids; + + static FuseStep make(int stage_id, const std::vector& fused_ids); + + IterVar ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.FuseStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(FuseStep, Step, FuseStepNode); + + +class AnnotationStepNode: public StepNode { + public: + int iter_id; + IteratorAnnotation annotation; + + static AnnotationStep make(int stage_id, int iter_id, IteratorAnnotation ann); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.AnnotationStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(AnnotationStep, Step, AnnotationStepNode); + + +class ComputeAtStepNode: public StepNode { + public: + int target_stage_id; + int target_iter_id; + + static ComputeAtStep make(int stage_id, int target_stage_id, + int target_iter_id); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.ComputeAtStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(ComputeAtStep, Step, ComputeAtStepNode); + + +class ComputeRootStepNode: public StepNode { + public: + static ComputeRootStep make(int stage_id); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.ComputeRootStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(ComputeRootStep, Step, ComputeRootStepNode); + + +class ComputeInlineStepNode: public StepNode { + public: + static ComputeInlineStep make(int stage_id); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr 
const char* _type_key = "ansor.ComputeInlineStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, Object);
+};
+TVM_DEFINE_COW_NODE_REF(ComputeInlineStep, Step, ComputeInlineStepNode);
+
+class PackForVecStepNode: public StepNode {
+ public:
+  int iter_id;
+  int vec_size;
+
+  static PackForVecStep make(int stage_id, int iter_id, int vec_size);
+
+  void ApplyToSchedule(std::vector<te::Stage> *stages,
+                       StageToAxesMap *stage_to_axes, te::Schedule *schedule) const;
+
+  std::string PrintAsPythonAPI(std::vector<te::Stage> *stages,
+                               StageToAxesMap *stage_to_axes,
+                               te::Schedule *schedule,
+                               const std::vector<Step>& transform_steps) const final;
+
+  static constexpr const char* _type_key = "ansor.PackForVecStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PackForVecStepNode, Object);
+};
+TVM_DEFINE_COW_NODE_REF(PackForVecStep, Step, PackForVecStepNode);
+
+
+/*! \brief Apply cache_read to a stage
+ *  TVM API: te::Schedule::cache_read(tensor, scope, readers) */
+class CacheReadStepNode: public StepNode {
+ public:
+  std::string scope_name;
+  std::vector<int> reader_stage_ids;
+
+  static CacheReadStep make(int stage_id, std::string scope_name,
+                            const std::vector<int>& reader_stage_ids);
+
+  te::Tensor ApplyToSchedule(std::vector<te::Stage> *stages,
+                             StageToAxesMap *stage_to_axes, te::Schedule *schedule) const;
+
+  std::string PrintAsPythonAPI(std::vector<te::Stage> *stages,
+                               StageToAxesMap *stage_to_axes,
+                               te::Schedule *schedule,
+                               const std::vector<Step>& transform_steps) const final;
+
+  static constexpr const char* _type_key = "ansor.CacheReadStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, Object);
+};
+TVM_DEFINE_COW_NODE_REF(CacheReadStep, Step, CacheReadStepNode);
+
+
+/*! \brief Apply cache_write to a stage
+ *  TVM API: te::Schedule::cache_write(tensor, scope)
+ *  This step will cache_write all output tensors of the target stage */
+class CacheWriteStepNode: public StepNode {
+ public:
+  std::string scope_name;
+
+  static CacheWriteStep make(int stage_id, std::string scope_name);
+
+  Array<te::Tensor> ApplyToSchedule(std::vector<te::Stage> *stages,
+                                    StageToAxesMap *stage_to_axes, te::Schedule *schedule) const;
+
+  std::string PrintAsPythonAPI(std::vector<te::Stage> *stages,
+                               StageToAxesMap *stage_to_axes,
+                               te::Schedule *schedule,
+                               const std::vector<Step>& transform_steps) const final;
+
+  static constexpr const char* _type_key = "ansor.CacheWriteStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, Object);
+};
+TVM_DEFINE_COW_NODE_REF(CacheWriteStep, Step, CacheWriteStepNode);
+
+/*! \brief Add a pragma to a specific iterator */
+class PragmaStepNode: public StepNode {
+ public:
+  int iter_id;
+  std::string pragma_type;
+
+  static PragmaStep make(int stage_id, int iter_id, std::string pragma_type);
+
+  void ApplyToSchedule(std::vector<te::Stage> *stages,
+                       StageToAxesMap *stage_to_axes) const;
+
+  std::string PrintAsPythonAPI(std::vector<te::Stage> *stages,
+                               StageToAxesMap *stage_to_axes,
+                               te::Schedule *schedule,
+                               const std::vector<Step>& transform_steps) const final;
+
+  static constexpr const char* _type_key = "ansor.PragmaStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, Object);
+};
+TVM_DEFINE_COW_NODE_REF(PragmaStep, Step, PragmaStepNode);
+
+/*! \brief Factor a reduction axis
+ *  TVM API: te::Schedule::rfactor(tensor, axis, factor_axis) */
+class RfactorStepNode: public StepNode {
+ public:
+  int iter_id;
+  int factor_iter_id;
+
+  static RfactorStep make(int stage_id, int iter_id, int factor_iter_id);
+
+  Array<te::Tensor> ApplyToSchedule(std::vector<te::Stage> *stages,
+                                    StageToAxesMap *stage_to_axes,
+                                    te::Schedule *schedule) const;
+
+  std::string PrintAsPythonAPI(std::vector<te::Stage> *stages,
+                               StageToAxesMap *stage_to_axes,
+                               te::Schedule *schedule,
+                               const std::vector<Step>& transform_steps) const final;
+
+  static constexpr const char* _type_key = "ansor.RfactorStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, Object);
+};
+TVM_DEFINE_COW_NODE_REF(RfactorStep, Step, RfactorStepNode);
+
+class StorageAlignStepNode: public StepNode {
+ public:
+  int iter_id;
+  int factor;
+  int offset;
+
+  static StorageAlignStep make(int stage_id, int iter_id, int factor,
+                               int offset);
+
+  void ApplyToSchedule(std::vector<te::Stage> *stages,
+                       StageToAxesMap *stage_to_axes) const;
+
+  std::string PrintAsPythonAPI(std::vector<te::Stage> *stages,
+                               StageToAxesMap *stage_to_axes,
+                               te::Schedule *schedule,
+                               const std::vector<Step>& transform_steps) const final;
+
+  static constexpr const char* _type_key = "ansor.StorageAlignStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, Object);
+};
+TVM_DEFINE_COW_NODE_REF(StorageAlignStep, Step, StorageAlignStepNode);
+
+} // namespace ansor
+} // namespace tvm
+
+// Hash and equal functions for State, Stage, Iterator and Step
+namespace std {
+
+template <>
+struct hash<::tvm::ansor::Step> {
+  std::size_t operator()(const ::tvm::ansor::Step& step) const {
+    if (auto ps = step.as<::tvm::ansor::ReorderStepNode>()) {
+      return ::dmlc::HashCombine(1,
+          ::dmlc::HashCombine(std::hash<int>()(ps->stage_id),
+                              ps->after_ids));
+    } else if (auto ps = step.as<::tvm::ansor::SplitStepNode>()) {
+      size_t ret = ::dmlc::HashCombine(2,
+          ::dmlc::HashCombine(std::hash<int>()(ps->stage_id),
+          ::dmlc::HashCombine(std::hash<int>()(ps->iter_id),
+                              ps->inner_to_outer)));
+      for (const auto& len : ps->lengths) {
+        if (len.defined()) {
+          auto pint = len.as<::tvm::tir::IntImmNode>();
+          CHECK(pint != nullptr);
+          ret = ::dmlc::HashCombine(ret, pint->value);
+        } else {
+          ret = ::dmlc::HashCombine(ret, 0x5D);  // a magic number
+        }
+      }
+      // Return after combining all lengths, not inside the loop,
+      // so that every split factor contributes to the hash
+      return ret;
+    } else if (auto ps = step.as<::tvm::ansor::FollowSplitStepNode>()) {
+      return ::dmlc::HashCombine(3,
+          ::dmlc::HashCombine(std::hash<int>()(ps->stage_id),
+          ::dmlc::HashCombine(std::hash<int>()(ps->iter_id),
+          ::dmlc::HashCombine(std::hash<int>()(ps->src_step_id),
+                              ps->n_split))));
+    } else if (auto ps = step.as<::tvm::ansor::FollowFusedSplitStepNode>()) {
+      return ::dmlc::HashCombine(4,
+          ::dmlc::HashCombine(std::hash<int>()(ps->stage_id),
+          ::dmlc::HashCombine(std::hash<int>()(ps->iter_id),
+          ::dmlc::HashCombine(std::hash<std::vector<int>>()(ps->src_step_ids),
+          ::dmlc::HashCombine(std::hash<int>()(ps->level),
+                              ps->factor_or_nparts)))));
+    } else if (auto ps = step.as<::tvm::ansor::FuseStepNode>()) {
+      return ::dmlc::HashCombine(5,
+          ::dmlc::HashCombine(std::hash<int>()(ps->stage_id),
+                              ps->fused_ids));
+    } else if (auto ps = step.as<::tvm::ansor::AnnotationStepNode>()) {
+      return ::dmlc::HashCombine(6,
+          ::dmlc::HashCombine(std::hash<int>()(ps->stage_id),
+          ::dmlc::HashCombine(std::hash<int>()(ps->iter_id),
+                              static_cast<int>(ps->annotation))));
+    } else if (auto ps = step.as<::tvm::ansor::ComputeAtStepNode>()) {
+      return ::dmlc::HashCombine(7,
+          ::dmlc::HashCombine(std::hash<int>()(ps->stage_id),
+          ::dmlc::HashCombine(std::hash<int>()(ps->target_stage_id),
+                              ps->target_iter_id)));
+    } else if (auto ps = 
step.as<::tvm::ansor::ComputeRootStepNode>()) { + return ::dmlc::HashCombine(8, + ps->stage_id); + } else if (auto ps = step.as<::tvm::ansor::ComputeInlineStepNode>()) { + return ::dmlc::HashCombine(9, + ps->stage_id); + } else if (auto ps = step.as<::tvm::ansor::PackForVecStepNode>()) { + return ::dmlc::HashCombine(10, + ::dmlc::HashCombine(std::hash()(ps->stage_id), + ::dmlc::HashCombine(std::hash()(ps->iter_id), + ps->vec_size))); + } else if (auto ps = step.as<::tvm::ansor::CacheReadStepNode>()) { + return ::dmlc::HashCombine(11, + ::dmlc::HashCombine(std::hash()(ps->stage_id), + ::dmlc::HashCombine(std::hash()(ps->scope_name), + ps->reader_stage_ids))); + } else if (auto ps = step.as<::tvm::ansor::CacheWriteStepNode>()) { + return ::dmlc::HashCombine(12, + ::dmlc::HashCombine(std::hash()(ps->stage_id), + ps->scope_name)); + } else if (auto ps = step.as<::tvm::ansor::PragmaStepNode>()) { + return ::dmlc::HashCombine(13, + ::dmlc::HashCombine(std::hash()(ps->stage_id), + ::dmlc::HashCombine(std::hash()(ps->iter_id), + ps->pragma_type))); + } else if (auto ps = step.as<::tvm::ansor::RfactorStepNode>()) { + return ::dmlc::HashCombine(14, + ::dmlc::HashCombine(std::hash()(ps->stage_id), + ::dmlc::HashCombine(std::hash()(ps->iter_id), + ps->factor_iter_id))); + } else if (auto ps = step.as<::tvm::ansor::StorageAlignStepNode>()) { + return ::dmlc::HashCombine(15, + ::dmlc::HashCombine(std::hash()(ps->stage_id), + ::dmlc::HashCombine(std::hash()(ps->iter_id), + ::dmlc::HashCombine(std::hash()(ps->factor), + ps->offset)))); + } else { + LOG(FATAL) << "Invalid step"; + } + return 0; + } +}; +} // namespace std + +#endif // TVM_ANSOR_TRANSFORM_STEP_H_ diff --git a/tests/cpp/ansor_test.cc b/tests/cpp/ansor_test.cc index b9a4f25023bf..87e7ad71a7c0 100644 --- a/tests/cpp/ansor_test.cc +++ b/tests/cpp/ansor_test.cc @@ -19,10 +19,10 @@ #include #include - +#include #include #include -#include "../../src/ansor/compute_dag.h" +#include "../../src/ansor/loop_state.h" tvm::Array matmul_func(int n, int m, int k) { using namespace tvm; @@ -52,11 +52,14 @@ tvm::Array conv2d_nchw_bn_relu_func(int N, int H, int W, Tensor bn_scale = placeholder({CO, 1, 1}, DataType::Float(32), "Bn_scale"); Tensor bn_offset = placeholder({CO, 1, 1}, DataType::Float(32), "Bn_offset"); - int OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1); - int OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1); + int OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) / strides + 1; + int OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) / strides + 1; + + const auto& conv = topi::conv2d_nchw(data, kernel, padding, padding, strides, + strides); + CHECK(conv->shape[2].as()->value == OH); + CHECK(conv->shape[3].as()->value == OW); - const auto& conv = topi::conv2d_nchw(data, kernel, strides, padding, - dilation); const auto& bias_add = compute( {N, CO, OH, OW}, [&](Var i, Var j, Var k, Var l) { @@ -82,12 +85,461 @@ tvm::Array conv2d_nchw_bn_relu_func(int N, int H, int W, TEST(ComputeDAG, Basic) { const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3); - auto dag = tvm::ansor::ComputeDAGNode::make(tensors); + const auto& dag = tvm::ansor::ComputeDAGNode::make(tensors); + const auto& state = tvm::ansor::StateNode::make(dag->ops); + CHECK(std::equal_to()(state, dag.GetInitState())); + LOG(INFO) << "\n" << state; LOG(INFO) << "\n" << dag; LOG(INFO) << "\n" << dag->access_analyzer; } +TEST(ComputeDAG, GetProducersConsumers) { + using namespace tvm::ansor; + + const auto& tensors = 
conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3); + const auto& dag = tvm::ansor::ComputeDAGNode::make(tensors); + int data = 0, padding = 1, kernel = 2, conv = 3, bias = 4, bias_add = 5; + int bn_scale = 6, bn_mul = 7, bn_offset = 8, bn_add = 9, relu = 10; + + State s0 = dag.GetInitState(); + std::unordered_set set; + { + std::vector> consumer_list = { + {data, padding}, {padding, conv}, {kernel, conv}, {conv, bias_add}, + {bias, bias_add}, {bias_add, bn_mul}, {bn_scale, bn_mul}, + {bn_mul, bn_add}, {bn_offset, bn_add}, {bn_add, relu} + }; + for (const auto& pair : consumer_list) { + dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op, &set); + CHECK_EQ(set.size(), 1); + CHECK_EQ((*set.begin()), s0->stages[pair.second]->op); + } + std::vector>> producer_list = { + {padding, {data}}, {conv, {padding, kernel}}, {bias_add, {conv, bias}}, + {bn_mul, {bias_add, bn_scale}}, {bn_add, {bn_mul, bn_offset}}, + {relu, {bn_add}} + }; + for (const auto& pair : producer_list) { + dag->access_analyzer.GetProducers(s0, s0->stages[pair.first]->op, &set); + CHECK_EQ(set.size(), pair.second.size()); + for (const auto& target : pair.second) { + CHECK(set.count(s0->stages[target]->op)); + } + } + } + + s0.compute_inline(bn_add); + s0.compute_inline(bn_mul); + s0.compute_inline(bias_add); + s0.compute_inline(padding); + { + std::vector> consumer_list = { + {data, conv}, {kernel, conv}, {conv, relu} + }; + for (const auto& pair : consumer_list) { + dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op, &set); + CHECK_EQ(set.size(), 1); + CHECK_EQ((*set.begin()), s0->stages[pair.second]->op); + } + std::vector>> producer_list = { + {padding, {data}}, {conv, {padding, kernel}}, {bias_add, {conv, bias}}, + {bn_mul, {bias_add, bn_scale}}, {bn_add, {bn_mul, bn_offset}}, + {relu, {bn_add}} + }; + for (const auto& pair : producer_list) { + dag->access_analyzer.GetProducers(s0, s0->stages[pair.first]->op, &set); + CHECK_EQ(set.size(), pair.second.size()); + for (const auto& target : pair.second) { + CHECK(set.count(s0->stages[target]->op)); + } + } + } +} + +TEST(Step, SplitFuseReorder) { + using namespace tvm::ansor; + + const auto& tensors = matmul_func(512, 512, 512); + const auto& dag = ComputeDAGNode::make(tensors); + + State s0 = dag.GetInitState(); + State s1 = s0; + Iterator ti = s0->stages[2]->iters[0]; + Iterator tj = s0->stages[2]->iters[1]; + Iterator tk = s0->stages[2]->iters[2]; + std::vector its; + + CHECK_EQ(s1->stages[2]->iters[0]->range->extent.as()->value, 512); + + its = s0.split(2, ti, {16}); + CHECK_EQ(s0->stages[2]->iters[0]->range->extent.as()->value, 32); + CHECK_EQ(s0->stages[2]->iters[1]->range->extent.as()->value, 16); + + Iterator tio = its[0], tii = its[1]; + its = s0.split(2, tj, {8}); + CHECK_EQ(s0->stages[2]->iters[2]->range->extent.as()->value, 64); + CHECK_EQ(s0->stages[2]->iters[3]->range->extent.as()->value, 8); + + Iterator tjo = its[0], tji = its[1]; + s0.reorder(2, {tio, tjo, tk, tji, tii}); + CHECK_EQ(s0->stages[2]->iters[2]->range->extent.as()->value, 512); + + s0.fuse(2, {tio, tjo}); + CHECK_EQ(s0->stages[2]->iters[0]->range->extent.as()->value, 2048); + + s1.split(2, ti, {8, 2}); + s1.split(2, tj, {32, 8}, false); + CHECK_EQ(s1->stages[2]->iters[0]->range->extent.as()->value, 32); + CHECK_EQ(s1->stages[2]->iters[1]->range->extent.as()->value, 8); + CHECK_EQ(s1->stages[2]->iters[2]->range->extent.as()->value, 2); + CHECK_EQ(s1->stages[2]->iters[3]->range->extent.as()->value, 32); + CHECK_EQ(s1->stages[2]->iters[4]->range->extent.as()->value, 8); + 
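+  // Both splits above produce extents 32 x 8 x 2 from 512:
+  // split(2, ti, {8, 2}) takes {8, 2} as the two innermost factors, while
+  // split(2, tj, {32, 8}, false) takes {32, 8} as nparts, read outer to inner.
+ 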
CHECK_EQ(s1->stages[2]->iters[5]->range->extent.as()->value, 2); +} + +TEST(Step, ComputeAtRootInline) { + using namespace tvm::ansor; + + const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3); + const auto& dag = tvm::ansor::ComputeDAGNode::make(tensors); + // int data = 0, padding = 1, kernel = 2; + int conv = 3; + // int bias = 4; + int bias_add = 5; + // int bn_scale = 6; + int bn_mul = 7; + // int bn_offset = 8; + int bn_add = 9, relu = 10; + + State s0 = dag.GetInitState(); + s0.compute_inline(bn_add); + s0.compute_inline(bn_mul); + s0.compute_inline(bias_add); + s0.compute_at(conv, relu, s0->stages[relu]->iters[2]); + const auto& conv_stage_attach = s0->attach_map->stage_to_attach_iter.find(conv); + std::pair iterkey(relu, 2); + CHECK(conv_stage_attach->second == iterkey); + const auto& conv_iter_attach = s0->attach_map->iter_to_attached_stages.find(iterkey); + CHECK_EQ(conv_iter_attach->second.size(), 1); + CHECK_EQ(conv_iter_attach->second[0], conv); + std::stringstream ss; + ss << "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + << "for ax1 (0,3)\n" + << " for ax2 (0,230)\n" + << " for ax3 (0,230)\n" + << " T_pad = ...\n" + << "for ax1 (0,64)\n" + << " for ax2 (0,112)\n" + << " for ax0 (None)\n" + << " for ax1 (None)\n" + << " for ax2 (None)\n" + << " for ax3 (None)\n" + << " for i (None)\n" + << " for kh (None)\n" + << " for kw (None)\n" + << " T_conv2d_nchw = ...\n" + << " for ax3 (0,112)\n" + << " T_relu = ...\n"; + CHECK_EQ(s0.ToStr().compare(ss.str()), 0); + + s0.compute_root(conv); + s0.compute_root(bn_mul); + CHECK_EQ(s0->attach_map->stage_to_attach_iter.size(), 0); + CHECK_EQ(s0->attach_map->iter_to_attached_stages.size(), 0); + ss.str(std::string()); + ss << "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + << "for ax1 (0,3)\n" + << " for ax2 (0,230)\n" + << " for ax3 (0,230)\n" + << " T_pad = ...\n" + << "for ax0 (None)\n" + << " for ax1 (None)\n" + << " for ax2 (None)\n" + << " for ax3 (None)\n" + << " for i (None)\n" + << " for kh (None)\n" + << " for kw (None)\n" + << " T_conv2d_nchw = ...\n" + << "for ax0 (None)\n" + << " for ax1 (None)\n" + << " for ax2 (None)\n" + << " for ax3 (None)\n" + << " Bn_mul = ...\n" + << "for ax1 (0,64)\n" + << " for ax2 (0,112)\n" + << " for ax3 (0,112)\n" + << " T_relu = ...\n"; + CHECK_EQ(s0.ToStr().compare(ss.str()), 0); +} + +TEST(Step, CacheReadWrite) { + using namespace tvm; + using namespace tvm::te; + using namespace tvm::ansor; + + const auto& test_func = []() -> Array { + int N = 4, H = 7, W = 7, CO = 512, CI = 512, KH = 3, KW = 3, stride = 1; + int padding = 1; + Tensor data = placeholder({N, CI, H, W}, DataType::Float(32), "Data"); + Tensor kernel_data = placeholder({CO, CI, KH, KW}, DataType::Float(32), + "kernel_data"); + const auto& k_split = compute(kernel_data->shape, + [&](const Array& i) { + return Array({kernel_data[i[0]][i[1]][i[2]][i[3]] + 1, + div(kernel_data[i[0]][i[1]][i[2]][i[3]], 2)}); + }, + "Kernel_split"); + const auto& kernel = compute(kernel_data->shape, + [&](Var i, Var j, Var k, Var l) { + return (k_split[0])[i][j][k][l] + (k_split[1])[i][j][k][l]; + }, + "Kernel"); + const auto& conv = topi::conv2d_nchw(data, kernel, padding, padding, stride, + stride); + const auto& relu = topi::relu(conv); + const auto& out = compute(relu->shape, + [&](Var i, Var j, Var k, Var l) { + return data[i][j][k][l] + relu[i][j][k][l]; + }, + "Add"); + return {data, kernel_data, out}; + }; + const auto& dag = ComputeDAGNode::make(test_func()); + + int data = 0, pad_temp = 1, 
kernel_data = 2, kernel_split = 3, kernel = 4; + int conv = 5, relu = 6, add = 7; + + // 0: init state + auto s0 = dag.GetInitState(); + std::vector ori_its = s0->stages[add]->iters; + auto its = s0.split(add, s0->stages[add]->iters[0], {2}); + s0.reorder(add, {its[0], ori_its[1], its[1], ori_its[2], ori_its[3]}); + s0.compute_inline(relu); + + // 1: simple cache_write with compute_at + int conv_global = s0.cache_write(conv, "global", dag); + conv++; relu++; add++; + s0.compute_at(conv_global, conv, s0->stages[conv]->iters[3]); + + // 2: simple cache_read with compute_at + int kernel_global = s0.cache_read(kernel, "global", {conv_global}, dag); + conv_global++; conv++; relu++; add++; + s0.compute_at(kernel_global, conv_global, s0->stages[conv_global]->iters[4]); + std::stringstream ss; + ss << "Placeholder: Data, kernel_data\n" + << "for ax0 (0,4)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,9)\n" + << " for ax3 (0,9)\n" + << " T_pad = ...\n" + << "for ax0 (0,512)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,3)\n" + << " for ax3 (0,3)\n" + << " Kernel_split = ...\n" + << "for ax0 (0,512)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,3)\n" + << " for ax3 (0,3)\n" + << " Kernel = ...\n" + << "for ax0 (0,4)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,7)\n" + << " for ax3 (0,7)\n" + << " for ax0_c (None)\n" + << " for ax1_c (None)\n" + << " for ax2_c (None)\n" + << " for ax3_c (None)\n" + << " for i (None)\n" + << " for ax0 (None)\n" + << " for ax1 (None)\n" + << " for ax2 (None)\n" + << " for ax3 (None)\n" + << " Kernel.global = ...\n" + << " for kh (None)\n" + << " for kw (None)\n" + << " T_conv2d_nchw.global = ...\n" + << " T_conv2d_nchw = ...\n" + << "for ax0.0 (0,2)\n" + << " for ax1 (0,512)\n" + << " for ax0.1 (0,2)\n" + << " for ax2 (0,7)\n" + << " for ax3 (0,7)\n" + << " Add = ...\n"; + CHECK_EQ(s0.ToStr().compare(ss.str()), 0); + + // 3: two level cache_read with compute_at + // preparing for GPU's shared memory & local memory + int pad_temp_global = s0.cache_read(pad_temp, "global", {conv_global}, dag); + kernel_data++; kernel_split++; kernel++; kernel_global++; + conv_global++; conv++; relu++; add++; + int pad_temp_shared = s0.cache_read(pad_temp_global, "shared", {conv_global}, + dag); + kernel_data++; kernel_split++; kernel++; kernel_global++; + conv_global++; conv++; relu++; add++; + s0.compute_at(pad_temp_global, conv_global, + s0->stages[conv_global]->iters[2]); + s0.compute_at(pad_temp_shared, conv_global, + s0->stages[conv_global]->iters[4]); + + // 4: cache_read with multi readers + // This stage cannot be compute at to its consumer + s0.cache_read(data, "global", {pad_temp, add}, dag); + pad_temp++; pad_temp_global++; pad_temp_shared++; + kernel_data++; kernel_split++; kernel++; kernel_global++; + conv_global++; conv++; relu++; add++; + ss.str(std::string()); + ss << "Placeholder: Data, kernel_data\n" + << "for ax0 (0,4)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,7)\n" + << " for ax3 (0,7)\n" + << " Data.global = ...\n" + << "for ax0 (0,4)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,9)\n" + << " for ax3 (0,9)\n" + << " T_pad = ...\n" + << "for ax0 (0,512)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,3)\n" + << " for ax3 (0,3)\n" + << " Kernel_split = ...\n" + << "for ax0 (0,512)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,3)\n" + << " for ax3 (0,3)\n" + << " Kernel = ...\n" + << "for ax0 (0,4)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,7)\n" + << " for ax3 (0,7)\n" + << " for ax0_c (None)\n" + << " for ax1_c (None)\n" + << " for ax2_c (None)\n" + << 
" for ax0 (None)\n" + << " for ax1 (None)\n" + << " for ax2 (None)\n" + << " for ax3 (None)\n" + << " T_pad.global = ...\n" + << " for ax3_c (None)\n" + << " for i (None)\n" + << " for ax0 (None)\n" + << " for ax1 (None)\n" + << " for ax2 (None)\n" + << " for ax3 (None)\n" + << " Kernel.global = ...\n" + << " for ax0 (None)\n" + << " for ax1 (None)\n" + << " for ax2 (None)\n" + << " for ax3 (None)\n" + << " T_pad.global.shared = ...\n" + << " for kh (None)\n" + << " for kw (None)\n" + << " T_conv2d_nchw.global = ...\n" + << " T_conv2d_nchw = ...\n" + << "for ax0.0 (0,2)\n" + << " for ax1 (0,512)\n" + << " for ax0.1 (0,2)\n" + << " for ax2 (0,7)\n" + << " for ax3 (0,7)\n" + << " Add = ...\n"; + CHECK_EQ(s0.ToStr().compare(ss.str()), 0); + + // 5: cache_write with multiple outputs + // TVM's cache_write actually has a bug with this case: + + // After schedule.cache_write, TVM generates one new stage: + // From: kernel_data -> kernel_split -> kernel + // To: kernel_data -> kernel_split_global -> kernel_split -> kernel + + // But with the topo-sort analysis, we get: + // kernel_data -> kernel_split_global -> kernel_split -> kernel + // \ / + // ----------------> kernel_split ----------------> + + // There seems to be a bug with the input/output tensors. Such a multi-output + // case should be unusual, so we apply a workaround in DoCacheWrite. + // To be fixed in the future. + s0.cache_write(kernel_split, "global", dag); + ss.str(std::string()); + ss << "Placeholder: Data, kernel_data\n" + << "for ax0 (0,4)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,7)\n" + << " for ax3 (0,7)\n" + << " Data.global = ...\n" + << "for ax0 (0,4)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,9)\n" + << " for ax3 (0,9)\n" + << " T_pad = ...\n" + << "for ax0_c (0,512)\n" + << " for ax1_c (0,512)\n" + << " for ax2_c (0,3)\n" + << " for ax3_c (0,3)\n" + << " Kernel_split.global = ...\n" + << "for ax0 (0,512)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,3)\n" + << " for ax3 (0,3)\n" + << " Kernel_split = ...\n" + << "for ax0 (0,512)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,3)\n" + << " for ax3 (0,3)\n" + << " Kernel_split = ...\n" + << "for ax0 (0,512)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,3)\n" + << " for ax3 (0,3)\n" + << " Kernel = ...\n" + << "for ax0 (0,4)\n" + << " for ax1 (0,512)\n" + << " for ax2 (0,7)\n" + << " for ax3 (0,7)\n" + << " for ax0_c (None)\n" + << " for ax1_c (None)\n" + << " for ax2_c (None)\n" + << " for ax0 (None)\n" + << " for ax1 (None)\n" + << " for ax2 (None)\n" + << " for ax3 (None)\n" + << " T_pad.global = ...\n" + << " for ax3_c (None)\n" + << " for i (None)\n" + << " for ax0 (None)\n" + << " for ax1 (None)\n" + << " for ax2 (None)\n" + << " for ax3 (None)\n" + << " Kernel.global = ...\n" + << " for ax0 (None)\n" + << " for ax1 (None)\n" + << " for ax2 (None)\n" + << " for ax3 (None)\n" + << " T_pad.global.shared = ...\n" + << " for kh (None)\n" + << " for kw (None)\n" + << " T_conv2d_nchw.global = ...\n" + << " T_conv2d_nchw = ...\n" + << "for ax0.0 (0,2)\n" + << " for ax1 (0,512)\n" + << " for ax0.1 (0,2)\n" + << " for ax2 (0,7)\n" + << " for ax3 (0,7)\n" + << " Add = ...\n"; + CHECK_EQ(s0.ToStr().compare(ss.str()), 0); +} + +TEST(Step, FollowSplitFollowFusedSplit) { + // todo +} + +TEST(Step, Rfactor) { + // todo +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; From f43e82f0ba4353f8fff8fcd830ce08c3bc94c793 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Thu, 28 May 2020 15:56:38 +0800 Subject: [PATCH 03/45] 
Add search_task, measure and serialization (#4) * Add FollowSplit & FollowFusedSplit tests * Update dag.InferBound & its UT * Add search_task, measure and serialization * Update Serialization UT --- include/tvm/ir/expr.h | 5 + include/tvm/runtime/device_api.h | 3 +- src/ansor/compute_dag.cc | 261 ++++++----- src/ansor/measure.cc | 314 +++++++++++++ src/ansor/measure.h | 262 +++++++++++ src/ansor/search_task.cc | 120 +++++ src/ansor/search_task.h | 92 ++++ src/ansor/serialization.cc | 573 ++++++++++++++++++++++++ src/ansor/serialization.h | 78 ++++ src/ansor/utils.h | 7 + src/ir/expr.cc | 2 + src/runtime/cuda/cuda_device_api.cc | 4 + src/runtime/opencl/opencl_device_api.cc | 3 + tests/cpp/ansor_test.cc | 122 ++++- 14 files changed, 1723 insertions(+), 123 deletions(-) create mode 100644 src/ansor/measure.cc create mode 100644 src/ansor/measure.h create mode 100644 src/ansor/search_task.cc create mode 100644 src/ansor/search_task.h create mode 100644 src/ansor/serialization.cc create mode 100644 src/ansor/serialization.h diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index b2ce50d91f58..b3e527ca6fd9 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -112,6 +112,11 @@ class PrimExpr : public BaseExpr { * \param value The value to be constructed. */ TVM_DLL PrimExpr(float value); // NOLINT(*) + /*! + * \brief construct from double. + * \param value The value to be constructed. + */ + TVM_DLL PrimExpr(double value); // NOLINT(*) /*! \return the data type of this expression. */ DataType dtype() const { return static_cast(get())->dtype; } diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index 421811a52c3b..9b2eb6be2160 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -44,7 +44,8 @@ enum DeviceAttrKind : int { kMaxClockRate = 6, kMultiProcessorCount = 7, kMaxThreadDimensions = 8, - kGcnArch = 9 + kGcnArch = 9, + kMaxRegistersPerBlock = 10 }; /*! \brief Number of bytes each allocation must align to */ diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index e1ae3250d1a5..feaefe9f8e9f 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -3,6 +3,7 @@ */ #include "compute_dag.h" #include +#include #include #include #include @@ -32,7 +33,8 @@ using OperationSet = std::unordered_set; // Topo-sort ops from tensors according to their read-write relations. 
// Results are stored in ops -void TopoSortOps(const Array& tensors, std::vector* ops) { +void TopoSortOps(const Array& tensors, + std::vector* ops) { std::unordered_map degree; std::unordered_map > edge_set; std::unordered_map priority; @@ -193,7 +195,8 @@ bool IsInjective(const te::Operation& op, const std::vector& index, } // Gather all VarNodes in an expr -static void GatherVars(const PrimExpr& expr, std::unordered_set* vars) { +static void GatherVars(const PrimExpr& expr, + std::unordered_set* vars) { PostOrderVisit(expr, [&vars](const ObjectRef &node) { if (const VarNode* op = node.as()) { vars->insert(op); @@ -206,7 +209,8 @@ static bool HasExpensiveOp(const PrimExpr& expr) { bool found = false; PostOrderVisit(expr, [&found](const ObjectRef &node) { if (const CallNode* op = node.as()) { - if (op->call_type == CallNode::CallType::PureIntrinsic && op->name == "exp") { + if (op->call_type == CallNode::CallType::PureIntrinsic && + op->name == "exp") { found = true; } } @@ -224,7 +228,8 @@ AccessAnalyzer AccessAnalyzerNode::make(const Array& tensors) { // build read & write access map for (const auto& op : node->ops_topo_order) { if (op->IsInstance()) { - node->read_from[op] = OperationMap > >(); + node->read_from[op] = + OperationMap > >(); } else if (auto cop = op.as()) { TensorAccessExtractor extractor; for (const auto& exp : cop->body) { @@ -232,8 +237,10 @@ AccessAnalyzer AccessAnalyzerNode::make(const Array& tensors) { } for (const auto& iter : extractor.buf_accesses) { - std::vector >& accesses = node->read_by[iter.first][op]; - accesses.insert(accesses.begin(), iter.second.begin(), iter.second.end()); + std::vector >& accesses = + node->read_by[iter.first][op]; + accesses.insert(accesses.begin(), iter.second.begin(), + iter.second.end()); } node->read_from[op] = std::move(extractor.buf_accesses); @@ -251,7 +258,8 @@ AccessAnalyzer AccessAnalyzerNode::make(const Array& tensors) { node->is_strict_inlineable[op] = false; node->is_output[op] = false; } else if (auto pop = op.as()) { - // check whether is element-wise and strict-inlineable (see definition in compute_dag.h) + // check whether is element-wise and strict-inlineable + // (see definition in compute_dag.h) bool is_injective = true; bool is_strict_inlineable = true; @@ -259,12 +267,14 @@ AccessAnalyzer AccessAnalyzerNode::make(const Array& tensors) { for (const auto& pair : node->read_from[op]) { const std::vector >& access = pair.second; for (const auto& index : access) { - if (!IsInjective(op, index, &axis_missing, &axis_duplicated, &same_order)) { + if (!IsInjective(op, index, &axis_missing, &axis_duplicated, + &same_order)) { is_injective = false; is_strict_inlineable = false; break; } - if (!same_order || axis_duplicated) { // do not strictly inline transpose + if (!same_order || axis_duplicated) { + // do not strictly inline transpose is_strict_inlineable = false; } } @@ -281,9 +291,11 @@ AccessAnalyzer AccessAnalyzerNode::make(const Array& tensors) { } node->is_injective[op] = is_injective; - node->is_strict_inlineable[op] = is_strict_inlineable && !has_expensive_op; + node->is_strict_inlineable[op] = is_strict_inlineable && + !has_expensive_op; - // check whether the op needs multi-level tiling (see definition in compute_dag.h) + // check whether the op needs multi-level tiling + // (see definition in compute_dag.h) bool needs_multi_level_tiling = false; int n_missing = 0; @@ -297,7 +309,8 @@ AccessAnalyzer AccessAnalyzerNode::make(const Array& tensors) { } bool missing = false; for (const auto& axis : pop->axis) { - 
if (GetIntImm(axis->dom->extent) > 1 && vars.count(axis->var.get()) == 0) { + if (GetIntImm(axis->dom->extent) > 1 && + vars.count(axis->var.get()) == 0) { missing = true; } } @@ -928,89 +941,90 @@ std::pair > ComputeDAG::ApplySteps( } } -// std::string ComputeDAG::PrintStepsAsPython( -// const std::vector& transform_steps) const { -// std::vector stages; -// StageToAxesMap stage_to_axes; -// Array ops; -// for (const auto& op : operator->()->ops) { -// if (!op->IsInstance()) { -// ops.push_back(op); -// } -// } -// te::Schedule schedule = te::create_schedule({ops.back()}); +std::string ComputeDAG::PrintStepsAsPython(const std::vector& transform_steps) const { + std::vector stages; + StageToAxesMap stage_to_axes; + Array ops; + for (const auto& op : operator->()->ops) { + if (!op->IsInstance()) { + ops.push_back(op); + } + } + te::Schedule schedule = te::create_schedule({ops.back()}); -// // init axes -// for (const auto& x : operator->()->ops) { -// const te::Stage& stage = schedule.operator[](x); -// stages.push_back(stage); -// UpdateStageAxis(stage, &stage_to_axes); -// } + // init axes + for (const auto& x : operator->()->ops) { + const te::Stage& stage = schedule.operator[](x); + stages.push_back(stage); + UpdateStageAxis(stage, &stage_to_axes); + } -// std::stringstream ss; + std::stringstream ss; -// for (const auto& stage : stages) { -// if (stage->op->IsInstance()) { -// for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { -// ss << stage->leaf_iter_vars[i]->var->name_hint; -// if (i != stage->leaf_iter_vars.size() - 1) { -// ss << ", "; -// } -// } -// ss << " = " << "tuple(" << stage->op->func_name() << ".op.axis)" -// << " + " << "tuple(" << stage->op->func_name() << ".op.reduce_axis)\n"; -// } -// } + for (const auto& stage : stages) { + if (stage->op->IsInstance()) { + for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { + ss << stage->leaf_iter_vars[i]->var->name_hint; + if (i != stage->leaf_iter_vars.size() - 1) { + ss << ", "; + } + } + ss << " = " << "tuple(" << stage->op->func_name() << ".op.axis)" + << " + " << "tuple(" << stage->op->func_name() << ".op.reduce_axis)\n"; + } + } -// for (const auto& step : transform_steps) { -// ss << step->PrintAsPythonAPI(&stages, &stage_to_axes, &schedule, transform_steps); -// } + for (const auto& step : transform_steps) { + ss << step->PrintAsPythonAPI(&stages, &stage_to_axes, &schedule, + transform_steps); + } -// return ss.str(); -// } + return ss.str(); +} -// State ComputeDAG::ReplayAndInferBound(const std::vector& transform_steps) const { -// State ret_state = GetInitState(); -// StateNode* pstate = ret_state.CopyOnWrite(); -// pstate->transform_steps = transform_steps; -// ret_state.DoSteps(transform_steps, *this); +State ComputeDAG::ReplayAndInferBound( + const std::vector& transform_steps) const { + State ret_state = GetInitState(); + StateNode* pstate = ret_state.CopyOnWrite(); + pstate->transform_steps = transform_steps; + ret_state.DoSteps(transform_steps, *this); -// InferBoundCommon(pstate); + InferBoundCommon(pstate); -// return ret_state; -// } + return ret_state; +} -// State ComputeDAG::InferBound(const State& state) const { -// State ret_state = state; -// StateNode* pstate = ret_state.CopyOnWrite(); +State ComputeDAG::InferBound(const State& state) const { + State ret_state = state; + StateNode* pstate = ret_state.CopyOnWrite(); -// InferBoundCommon(pstate); + InferBoundCommon(pstate); -// return ret_state; -// } + return ret_state; +} -// void ComputeDAG::InferBound(std::vector* states) const { -// 
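+// For reference, on a schedule with one split step PrintStepsAsPython is
+// expected to emit python of roughly this shape (a sketch; the actual
+// variable names come from each op's func_name and iterator names):
+//
+//   i, j, k = tuple(C.op.axis) + tuple(C.op.reduce_axis)
+//   i_o, i_i = s[C].split(i, factor=8)
+//
+// i.e. one tuple-unpacking line per ComputeOp stage, then one line per
+// transform step, produced by each step's PrintAsPythonAPI.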
std::vector out_states(states->size(), State()); +void ComputeDAG::InferBound(std::vector* states) const { + std::vector out_states(states->size(), State()); -// auto worker_func = [&states, &out_states, this](int idx) { -// try { -// out_states[idx] = this->InferBound((*states)[idx]); -// } catch (dmlc::Error &e) { -// LOG(WARNING) << "InferBound fails on the state:\n" << (*states)[idx] -// << "\n" << e.what() << std::endl; -// } -// }; + auto worker_func = [&states, &out_states, this](int idx) { + try { + out_states[idx] = this->InferBound((*states)[idx]); + } catch (dmlc::Error &e) { + LOG(WARNING) << "InferBound fails on the state:\n" << (*states)[idx] + << "\n" << e.what() << std::endl; + } + }; -// // Lower states in parallel -// ThreadPool& pool = ThreadPool::Global(); -// pool.BeginBatch(states->size()); -// for (size_t i = 0; i < states->size(); ++i) { -// pool.Enqueue(worker_func, i); -// } -// pool.WaitBatch(); + // Lower states in parallel + ThreadPool& pool = ThreadPool::Global(); + pool.BeginBatch(states->size()); + for (size_t i = 0; i < states->size(); ++i) { + pool.Enqueue(worker_func, i); + } + pool.WaitBatch(); -// *states = std::move(out_states); -// } + *states = std::move(out_states); +} void ComputeDAG::ReplayAndGetDAG(const std::vector &transform_steps, ComputeDAG *task_dag) const { @@ -1019,7 +1033,8 @@ void ComputeDAG::ReplayAndGetDAG(const std::vector &transform_steps, te::Schedule sch; Array old_tensors; - std::tie(sch, old_tensors) = ReplaySteps(transform_steps, &stages, &stage_to_axes); + std::tie(sch, old_tensors) = ReplaySteps(transform_steps, &stages, + &stage_to_axes); Array new_tensors; for (auto stage : sch->stages) { @@ -1035,45 +1050,47 @@ void ComputeDAG::ReplayAndGetDAG(const std::vector &transform_steps, } -// void ComputeDAG::InferBoundCommon(StateNode* pstate) const { -// std::vector stages; -// StageToAxesMap stage_to_axes; -// te::Schedule sch; -// Array tensors; -// Map bounds; +void ComputeDAG::InferBoundCommon(StateNode* pstate) const { + std::vector stages; + StageToAxesMap stage_to_axes; + te::Schedule sch; + Array tensors; + Map bounds; -// std::tie(sch, tensors) = ReplaySteps(pstate->transform_steps, &stages, &stage_to_axes); -// sch = sch.normalize(); -// bounds = schedule::InferBound(sch); + std::tie(sch, tensors) = ReplaySteps(pstate->transform_steps, &stages, + &stage_to_axes); + sch = sch.normalize(); + bounds = te::InferBound(sch); -// for (size_t i = 0; i < pstate->stages.size(); ++i) { -// const Stage& stage = pstate->stages[i]; + for (size_t i = 0; i < pstate->stages.size(); ++i) { + const Stage& stage = pstate->stages[i]; -// if (stage->compute_at == kInlined) { -// continue; -// } + if (stage->compute_at == kInlined) { + continue; + } -// std::vector new_iters; -// new_iters.reserve(stage->iters.size()); -// for (size_t j = 0; j < stage->iters.size(); ++j) { -// const Iterator& iter = stage->iters[j]; -// const IterVar& axis = stage_to_axes.at(stages[i])[j]; - -// auto find_res = bounds.find(axis); -// if (find_res != bounds.end()) { -// new_iters.push_back(IteratorNode::make(iter->name, (*find_res).second, -// iter->iter_type, iter->annotation, -// &iter->ori_iters)); -// } else { -// LOG(FATAL) << "Infer bound fails"; -// } -// } + std::vector new_iters; + new_iters.reserve(stage->iters.size()); + for (size_t j = 0; j < stage->iters.size(); ++j) { + const Iterator& iter = stage->iters[j]; + const IterVar& axis = stage_to_axes.at(stages[i])[j]; + + auto find_res = bounds.find(axis); + if (find_res != bounds.end()) { + 
new_iters.push_back(IteratorNode::make(iter->name, (*find_res).second, + iter->iter_type, + iter->annotation, + &iter->ori_iters)); + } else { + LOG(FATAL) << "Infer bound fails"; + } + } -// pstate->stages[i] = StageNode::make(stage->op, stage->op_type, -// std::move(new_iters), stage->compute_at, stage->auto_unroll_max_step, -// stage->storage_offset); -// } -// } + pstate->stages[i] = StageNode::make(stage->op, stage->op_type, + std::move(new_iters), stage->compute_at, + stage->auto_unroll_max_step, stage->storage_offset); + } +} std::pair > ComputeDAG::ReplaySteps( const std::vector &transform_steps, @@ -1096,8 +1113,8 @@ std::pair > ComputeDAG::ReplaySteps( UpdateStageAxis(stage, stage_to_axes); } - // todo(lmzheng): should we maintain the attach_map and keep the validity of compute_at - // an splitted axis? + // todo(lmzheng): should we maintain the attach_map and keep the validity of + // compute_at an splitted axis? // Use complete rate for the study in the paper const char* complete_rate_str = getenv("ANSOR_PROGRAM_COMPLETE_RATE"); @@ -1183,8 +1200,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) } else if (combiner->IsInstance()) { const auto& select = combiner.as(); ss << " select(" << select->condition << ", " << select->true_value - << ", " << select->false_value << ")= " - << '(' << preduce->source[0] << ',' << preduce->source[1] << ")\n"; + << ", " << select->false_value << ")= " << '(' + << preduce->source[0] << ',' << preduce->source[1] << ")\n"; } else { LOG(FATAL) << "Unsupported reduction operator" << combiner; } @@ -1208,7 +1225,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "is_injective:\t" << node->is_injective.at(op) << "\t\t"; p->stream << "needs_multi_level_tiling:\t" << node->needs_multi_level_tiling.at(op) << std::endl; - p->stream << "is_strict_inlinable:\t" << node->is_strict_inlineable.at(op) << "\t"; + p->stream << "is_strict_inlinable:\t" << node->is_strict_inlineable.at(op) + << "\t"; p->stream << "is_output:\t" << node->is_output.at(op) << std::endl; p->stream << "Read from:\t"; for (const auto& pair : node->read_from.at(op)) { @@ -1233,7 +1251,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) for (size_t i = 0; i < node->ops_topo_order.size(); ++i) { for (size_t j = 0; j < node->ops_topo_order.size(); ++j) { if (i == j) { continue; } - if (ana.ElementWiseMatch(node->ops_topo_order[i], node->ops_topo_order[j])) { + if (ana.ElementWiseMatch(node->ops_topo_order[i], + node->ops_topo_order[j])) { p->stream << node->ops_topo_order[i]->func_name() << " -> " << node->ops_topo_order[j]->func_name() << "\n"; } diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc new file mode 100644 index 000000000000..1bae02b3f2c5 --- /dev/null +++ b/src/ansor/measure.cc @@ -0,0 +1,314 @@ +/*! 
+ * Copyright (c) 2020 by Contributors + */ +#include "measure.h" +// #include +#include +#include +#include +#include +#include +#include +// #include "search_policy/search_policy.h" + +namespace tvm { +namespace ansor { + +TVM_REGISTER_NODE_TYPE(MeasureInputNode); +TVM_REGISTER_NODE_TYPE(BuildResultNode); +TVM_REGISTER_NODE_TYPE(MeasureResultNode); +TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode); +TVM_REGISTER_OBJECT_TYPE(RunnerNode); +TVM_REGISTER_OBJECT_TYPE(BuilderNode); +TVM_REGISTER_OBJECT_TYPE(LocalBuilderNode); +TVM_REGISTER_OBJECT_TYPE(RPCRunnerNode); +TVM_REGISTER_OBJECT_TYPE(LocalRunnerNode); +TVM_REGISTER_OBJECT_TYPE(ProgramMeasurerNode); + +const char *ErrorNoToStr[] = { + "NoError", + "InstantiationError", + "CompileHostError", + "CompileDeviceError", + "RuntimeDeviceError", + "WrongAnswerError", + "BuildTimeoutError", + "RunTimeoutError", + "UnknownError", +}; + +// Maker +MeasureInput MeasureInputNode::make(SearchTask task, State state) { + auto node = make_object(); + node->task = std::move(task); + node->state = std::move(state); + return MeasureInput(node); +} + +MeasureInput MeasureInputNode::copy() const { + auto node = make_object(); + node->task = task; + node->state = state; + return MeasureInput(node); +} + +BuildResult BuildResultNode::make(std::string filename, Array args, int error_no, + std::string error_msg, double time_cost) { + auto node = make_object(); + node->filename = std::move(filename); + node->args = std::move(args); + node->error_no = error_no; + node->error_msg = std::move(error_msg); + node->time_cost = time_cost; + return BuildResult(node); +} + +MeasureResult MeasureResultNode::make(Array costs, int error_no, + std::string error_msg, double all_cost, double timestamp) { + auto node = make_object(); + node->costs = std::move(costs); + node->error_no = error_no; + node->error_msg = std::move(error_msg); + node->all_cost = all_cost; + node->timestamp = timestamp; + return MeasureResult(node); +} + +MeasureResult MeasureResultNode::copy() const { + auto node = make_object(); + node->costs = costs; + node->error_no = error_no; + node->error_msg = error_msg; + node->all_cost = all_cost; + node->timestamp = timestamp; + return MeasureResult(node); +} + +Builder LocalBuilderNode::make(int timeout, int n_parallel, const std::string& build_func) { + auto node = make_object(); + node->timeout = timeout; + node->n_parallel = n_parallel; + node->build_func = build_func; + return Builder(node); +} + +// LocalBuilder and LocalRunner +Array LocalBuilderNode::Build(const Array &inputs, int verbose) { + if (const auto* f = runtime::Registry::Get("ansor.local_builder.build")) { + Array results = (*f)(inputs, timeout, n_parallel, build_func, verbose); + return results; + } else { + LOG(FATAL) << "ansor.local_builder.build is not registered"; + } + return Array(); +} + +Runner RPCRunnerNode::make(const std::string& key, const std::string& host, int port, + int priority, int timeout, int n_parallel, int number, + int repeat, int min_repeat_ms, double cooldown_interval) { + auto node = make_object(); + node->key = key; + node->host = host; + node->port = port; + node->priority = priority; + node->timeout = timeout; + node->n_parallel = n_parallel; + node->number = number; + node->repeat = repeat; + node->min_repeat_ms = min_repeat_ms; + node->cooldown_interval = cooldown_interval; + return Runner(node); +} + +Array RPCRunnerNode::Run(const Array& inputs, + const Array& build_results, + int verbose) { + if (const auto* f = 
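+// Build forwards the whole batch to a PackedFunc that the python side is
+// expected to register under "ansor.local_builder.build". For a pure C++
+// test, a stub with the same calling convention could be registered too
+// (a sketch, not part of this patch; the argument order mirrors the call
+// below):
+//
+//   TVM_REGISTER_GLOBAL("ansor.local_builder.build")
+//   .set_body([](TVMArgs args, TVMRetValue* rv) {
+//     Array<MeasureInput> inputs = args[0];
+//     // args[1..4] carry timeout, n_parallel, build_func, verbose
+//     Array<BuildResult> results;  // to be filled by compiling each input
+//     *rv = results;
+//   });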
runtime::Registry::Get("ansor.rpc_runner.run")) { + Array results = (*f)(inputs, build_results, key, host, port, priority, + timeout, n_parallel, number, repeat, + min_repeat_ms, cooldown_interval, verbose); + return results; + } else { + LOG(FATAL) << "ansor.rpc_runner.run is not registered"; + } + return Array(); +} + +Runner LocalRunnerNode::make(int timeout, int number, int repeat, + int min_repeat_ms, double cooldown_interval) { + ObjectPtr node = make_object(); + node->timeout = timeout; + node->number = number; + node->repeat = repeat; + node->min_repeat_ms = min_repeat_ms; + node->cooldown_interval = cooldown_interval; + return Runner(node); +} + +Array LocalRunnerNode::Run(const Array& inputs, + const Array& build_results, + int verbose) { + if (const auto* f = runtime::Registry::Get("ansor.local_runner.run")) { + Array results = (*f)(inputs, build_results, timeout, number, + repeat, min_repeat_ms, cooldown_interval, verbose); + return results; + } else { + LOG(FATAL) << "ansor.local_runner.run is not registered"; + } + return Array(); +} + +ProgramMeasurer ProgramMeasurerNode::make(Builder builder, Runner runner, + Array callbacks, + int verbose, + int max_continous_error) { + auto node = make_object(); + node->builder = std::move(builder); + node->runner = std::move(runner); + node->callbacks = std::move(callbacks); + node->verbose = verbose; + node->max_continous_error = max_continous_error < 0 ? + DEFAULT_MAX_CONTINOUS_ERROR : max_continous_error; + return ProgramMeasurer(node); +} + +void ProgramMeasurerNode::Reset() { + ct = error_ct = 0; + best_flops.clear(); + best_ct.clear(); + best_state.clear(); +} + +void ProgramMeasurerNode::Measure(const SearchTask& task, + const SearchPolicy& policy, + const std::vector& inputs, + std::vector* results, + int batch_size) { + results->clear(); + results->reserve(inputs.size()); + + if (batch_size == -1) { + // set default batch size + batch_size = builder->n_parallel * 2; + } + + StdCout(verbose) << "Get " << inputs.size() << " programs for measure. 
(This may take a while)" + << std::endl; + + for (size_t i = 0; i < inputs.size(); i += batch_size) { + std::vector input_batch(inputs.begin() + i, + inputs.begin() + std::min(i + batch_size, inputs.size())); + std::vector result_batch; + + // build and run + SilentMeasure(task, input_batch, &result_batch); + + // update current best state according to the new measure result + for (size_t j = 0; j < input_batch.size(); ++j) { + double flops; + if (result_batch[j]->error_no == 0) { + flops = task->compute_dag->flop_ct / FloatArrayMean(result_batch[j]->costs); + error_ct = 0; + } else { + flops = 0.0; + error_ct++; + } + + const std::string& workload_key = input_batch[j]->task->workload_key; + if (flops > best_flops[workload_key]) { + best_flops[workload_key] = flops; + best_state[workload_key] = input_batch[j]->state; + best_ct[workload_key] = ct; + } + + ct++; + if (verbose >= 1) { + std::cout << std::fixed << std::setprecision(2); + std::cout << "===============================================\n"; + std::cout << "No: " << ct + << "\tGFLOPS: " << flops / 1e9 << " / " << best_flops[workload_key] / 1e9 + << "\tresults: " << result_batch[j] << "\n"; + std::cout << "===============================================\n"; + std::cout << input_batch[j]->state << "\n"; + } + } + + // Call callback functions + for (const auto& callback : callbacks) { + callback->callback(policy, input_batch, result_batch); + } + + // Store result batch + for (auto& res : result_batch) { + results->push_back(res); + } + + if (error_ct > max_continous_error) { + LOG(FATAL) << "Too many errors happened during tuning"; + } + } +} + +void ProgramMeasurerNode::SilentMeasure(const SearchTask& task, + const std::vector& inputs, + std::vector* results) { + // Close the thread pool to avoid the conflits with python environment + ThreadPool::Global().Abort(); + + results->clear(); + results->reserve(inputs.size()); + Array input_batch(inputs.begin(), inputs.end()); + + // Call builder and runner + Array build_res_batch = builder->Build(input_batch, verbose); + Array result_batch = runner->Run(input_batch, build_res_batch, verbose); + + // Store result batch + for (auto& res : result_batch) { + results->push_back(res); + } +} + +// Printing functions +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { + p->stream << "MeasureInput()"; +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { + auto* node = static_cast(ref.get()); + if (node->error_no == kNoError) { + p->stream << "MeasureResult(cost:["; + auto old_config = p->stream.precision(4); + for (size_t i = 0; i < node->costs.size(); ++i) { + auto pf = node->costs[i].as(); + CHECK(pf != nullptr); + p->stream << pf->value; + if (i != node->costs.size() - 1) { + p->stream << ","; + } + } + p->stream.precision(old_config); + p->stream << "], "; + p->stream << "error_no:" << 0 << ", " + << "all_cost:" << node->all_cost << ", " + << "Tstamp:" << node->timestamp << ")"; + } else { + p->stream << "MeasureResult(" + << "error_type:" << ErrorNoToStr[node->error_no] << ", " + << "error_msg:" << node->error_msg << ", " + << "all_cost:" << node->all_cost << ", " + << "Tstamp:" << node->timestamp << ")"; + } +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) +.set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { + auto* node = static_cast(ref.get()); + p->stream << "BuildResult(" << node->filename << ", " << node->error_no + << ", " << node->time_cost << ")"; +}); + +} // namespace 
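+// A minimal end-to-end sketch of the classes above (assuming the
+// python-side packed functions are registered; the concrete argument
+// values here are illustrative only):
+//
+//   Builder builder = LocalBuilderNode::make(15, 4, "default");
+//   Runner runner = LocalRunnerNode::make(10, 3, 1, 100, 0.0);
+//   ProgramMeasurer measurer = ProgramMeasurerNode::make(
+//       builder, runner, Array<MeasureCallback>(), /*verbose=*/1);
+//   std::vector<MeasureResult> results;
+//   measurer->Measure(task, policy, inputs, &results);
+//
+// where `task`, `policy` and `inputs` come from the search policy driving
+// the tuning.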
ansor
+}  // namespace tvm
diff --git a/src/ansor/measure.h b/src/ansor/measure.h
new file mode 100644
index 000000000000..4ea1562315ff
--- /dev/null
+++ b/src/ansor/measure.h
@@ -0,0 +1,262 @@
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file ansor/measure.h
+ * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs
+ */
+
+#ifndef TVM_ANSOR_MEASURE_H_
+#define TVM_ANSOR_MEASURE_H_
+
+// #include
+#include
+#include
+#include
+#include "search_task.h"
+#include "loop_state.h"
+
+namespace tvm {
+namespace ansor {
+
+class SearchPolicy;
+class MeasureInput; class BuildResult; class MeasureResult;
+class Builder; class Runner; class MeasureCallback; class ProgramMeasurer;
+
+extern const char *ErrorNoToStr[];
+
+enum MeasureErrorNO {
+  kNoError = 0,              // No error
+  kInstantiationError = 1,   // Errors happen when applying transform steps from init state
+  kCompileHostError = 2,     // Errors happen when compiling code on host (when build module)
+  kCompileDeviceError = 3,   // Errors happen when compiling code on device (when load module)
+  kRuntimeDeviceError = 4,   // Errors happen when running program on device
+  kWrongAnswerError = 5,     // Answer is wrong when compared to a reference output
+  kBuildTimeoutError = 6,    // Timeout during compilation
+  kRunTimeoutError = 7,      // Timeout during run
+  kUnknonwError = 8,         // Unknown error
+};
+
+// Inputs and results of one measurement
+
+/*! \brief Store the input of a measurement */
+class MeasureInputNode: public Object {
+ public:
+  SearchTask task;
+  State state;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("task", &task);
+    v->Visit("state", &state);
+  }
+
+  static MeasureInput make(SearchTask task, State state);
+  MeasureInput copy() const;  // Do deep copy
+
+  static constexpr const char* _type_key = "ansor.MeasureInput";
+  TVM_DECLARE_FINAL_OBJECT_INFO(MeasureInputNode, Object);
+};
+TVM_DEFINE_NODE_REF(MeasureInput, MeasureInputNode);
+
+/*! \brief Store the result of a build */
+class BuildResultNode: public Object {
+ public:
+  std::string filename;
+  Array<te::Tensor> args;
+  int error_no;
+  std::string error_msg;
+  double time_cost;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("filename", &filename);
+    v->Visit("args", &args);
+    v->Visit("error_no", &error_no);
+    v->Visit("error_msg", &error_msg);
+    v->Visit("time_cost", &time_cost);
+  }
+
+  static BuildResult make(std::string filename, Array<te::Tensor> args,
+                          int error_no, std::string error_msg, double time_cost);
+
+  static constexpr const char* _type_key = "ansor.BuildResult";
+  TVM_DECLARE_FINAL_OBJECT_INFO(BuildResultNode, Object);
+};
+TVM_DEFINE_NODE_REF(BuildResult, BuildResultNode);
+
+/*! \brief Store the results of a measurement */
+class MeasureResultNode: public Object {
+ public:
+  Array<PrimExpr> costs;
+  int error_no;
+  std::string error_msg;
+  double all_cost;
+  double timestamp;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("costs", &costs);
+    v->Visit("error_no", &error_no);
+    v->Visit("error_msg", &error_msg);
+    v->Visit("all_cost", &all_cost);
+    v->Visit("timestamp", &timestamp);
+  }
+
+  MeasureResult copy() const;  // Do deep copy
+
+  static MeasureResult make(Array<PrimExpr> costs, int error_no, std::string error_msg,
+                            double all_cost, double timestamp);
+
+  static constexpr const char* _type_key = "ansor.MeasureResult";
+  TVM_DECLARE_FINAL_OBJECT_INFO(MeasureResultNode, Object);
+};
+TVM_DEFINE_NODE_REF(MeasureResult, MeasureResultNode);
+
+
+// Measure callback
+class MeasureCallbackNode: public Object {
+ public:
+  virtual void callback(const
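+// Subclassing sketch for a custom callback (the class name and body are
+// hypothetical; LogToFileNode in serialization.h follows this same pattern):
+//
+//   class BestRecorderNode : public MeasureCallbackNode {
+//    public:
+//     void callback(const SearchPolicy& policy,
+//                   const Array<MeasureInput>& inputs,
+//                   const Array<MeasureResult>& results) final {
+//       // e.g. scan `results` and remember the lowest-cost input
+//     }
+//     static constexpr const char* _type_key = "ansor.BestRecorder";
+//     TVM_DECLARE_FINAL_OBJECT_INFO(BestRecorderNode, MeasureCallbackNode);
+//   };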
SearchPolicy& policy,
+                        const Array<MeasureInput>& inputs,
+                        const Array<MeasureResult>& results) = 0;
+  static constexpr const char *_type_key = "ansor.MeasureCallback";
+  TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object);
+};
+TVM_DEFINE_MUTABLE_NODE_REF(MeasureCallback, MeasureCallbackNode);
+
+
+// Base class for builder and runner
+
+/*! \brief Builder that builds the programs */
+class BuilderNode: public Object {
+ public:
+  int n_parallel;
+  int timeout;
+
+  virtual Array<BuildResult> Build(const Array<MeasureInput>& inputs, int verbose) = 0;
+
+  static constexpr const char* _type_key = "ansor.Builder";
+  TVM_DECLARE_BASE_OBJECT_INFO(BuilderNode, Object);
+};
+TVM_DEFINE_MUTABLE_NODE_REF(Builder, BuilderNode);
+
+/*! \brief Runner that runs the built programs and measures the time cost */
+class RunnerNode: public Object {
+ public:
+  int timeout;
+
+  virtual Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
+                                   const Array<BuildResult>& build_results,
+                                   int verbose) = 0;
+
+  static constexpr const char* _type_key = "ansor.Runner";
+  TVM_DECLARE_BASE_OBJECT_INFO(RunnerNode, Object);
+};
+TVM_DEFINE_MUTABLE_NODE_REF(Runner, RunnerNode);
+
+
+// Implementation of various builders and runners
+/*! \brief LocalBuilder uses local CPU cores to build programs in parallel */
+class LocalBuilderNode: public BuilderNode {
+ public:
+  std::string build_func;
+
+  static Builder make(int timeout, int n_parallel, const std::string& build_func);
+
+  Array<BuildResult> Build(const Array<MeasureInput>& inputs, int verbose) final;
+
+  static constexpr const char* _type_key = "ansor.LocalBuilder";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LocalBuilderNode, BuilderNode);
+};
+
+class RPCRunnerNode : public RunnerNode {
+ public:
+  std::string key;
+  std::string host;
+  int port;
+  int priority;
+  int n_parallel;
+  int number;
+  int repeat;
+  int min_repeat_ms;
+  double cooldown_interval;
+
+  static Runner make(const std::string& key, const std::string& host, int port,
+                     int priority, int timeout, int n_parallel, int number,
+                     int repeat, int min_repeat_ms, double cooldown_interval);
+
+  Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
+                           const Array<BuildResult>& build_results,
+                           int verbose) final;
+
+  static constexpr const char* _type_key = "ansor.RPCRunner";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RPCRunnerNode, RunnerNode);
+};
+
+/*! \brief LocalRunner uses the local CPU/GPU to run programs in serial and measure the time cost */
+class LocalRunnerNode: public RunnerNode {
+ public:
+  int number;
+  int repeat;
+  int min_repeat_ms;
+  double cooldown_interval;
+
+  static Runner make(int timeout, int number, int repeat,
+                     int min_repeat_ms, double cooldown_interval);
+
+  Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
+                           const Array<BuildResult>& build_results,
+                           int verbose) final;
+
+  static constexpr const char* _type_key = "ansor.LocalRunner";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LocalRunnerNode, RunnerNode);
+};
+
+
+/*!
+ * \brief Measurer measures the time costs of TVM programs.
+ * This class combines Builder and Runner, and provides a simpler API.
+ */
+class ProgramMeasurerNode: public Object {
+ public:
+  static const int DEFAULT_MAX_CONTINOUS_ERROR = 150;
+
+  int ct;
+  int error_ct;  // continuous error counter
+  std::unordered_map<std::string, double> best_flops;
+  std::unordered_map<std::string, State> best_state;
+  std::unordered_map<std::string, int> best_ct;
+
+  Builder builder;
+  Runner runner;
+  Array<MeasureCallback> callbacks;
+  int verbose;
+  int max_continous_error;
+
+  static ProgramMeasurer make(Builder builder, Runner runner,
+                              Array<MeasureCallback> callbacks,
+                              int verbose,
+                              int max_continous_error = -1);
+
+  /*! \brief Reset bookkeeping variables */
+  void Reset();
+
+  /*!
\brief Do measurement */
+  void Measure(const SearchTask& task,
+               const SearchPolicy& policy,
+               const std::vector<MeasureInput>& inputs,
+               std::vector<MeasureResult>* results,
+               int batch_size = -1);
+
+  /*! \brief Do measurement silently */
+  void SilentMeasure(const SearchTask& task,
+                     const std::vector<MeasureInput>& inputs,
+                     std::vector<MeasureResult>* results);
+
+  static constexpr const char* _type_key = "ansor.ProgramMeasurer";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ProgramMeasurerNode, Object);
+};
+TVM_DEFINE_MUTABLE_NODE_REF(ProgramMeasurer, ProgramMeasurerNode);
+
+
+}  // namespace ansor
+}  // namespace tvm
+
+#endif  // TVM_ANSOR_MEASURE_H_
diff --git a/src/ansor/search_task.cc b/src/ansor/search_task.cc
new file mode 100644
index 000000000000..b9cda9168b9e
--- /dev/null
+++ b/src/ansor/search_task.cc
@@ -0,0 +1,120 @@
+/*!
+ * Copyright (c) 2020 by Contributors
+ */
+#include "search_task.h"
+#include
+#include
+#include
+#include
+#include
+
+namespace tvm {
+namespace ansor {
+
+TVM_REGISTER_OBJECT_TYPE(HardwareParamsNode);
+TVM_REGISTER_OBJECT_TYPE(SearchTaskNode);
+
+HardwareParams HardwareParamsNode::make(int num_cores, int vector_unit_bytes,
+                                        int cache_line_bytes, int max_unroll_vec,
+                                        int max_innermost_split_factor) {
+  auto node = make_object<HardwareParamsNode>();
+  node->num_cores = num_cores;
+  node->vector_unit_bytes = vector_unit_bytes;
+  node->cache_line_bytes = cache_line_bytes;
+  node->max_unroll_vec = max_unroll_vec;
+  node->max_innermost_split_factor = max_innermost_split_factor;
+  return HardwareParams(node);
+}
+
+HardwareParams HardwareParamsNode::GetDefaultHardwareParams(
+    const Target& target, const Target& target_host) {
+  if (target->target_name == "llvm") {
+    return HardwareParamsNode::make(tvm::runtime::threading::MaxConcurrency(),
+                                    32, 64, 16, 64);
+  } else if (target->device_type == kDLGPU) {
+    // TODO(jcf94): temp implementation, max vectorize size in GPU is related
+    // to the data type
+    auto hardware_params = HardwareParamsNode::make(100000, 16, 64, 4, 64);
+    auto* p_hardware_params = hardware_params.CopyOnWrite();
+
+    auto ctx = TVMContext{kDLGPU, 0};
+    auto func = tvm::runtime::Registry::Get("device_api.gpu");
+    CHECK(func != nullptr) << "Cannot find GPU device_api in registry";
+    auto device_api = static_cast<tvm::runtime::DeviceAPI*>(((*func)()).operator void*());
+
+    tvm::runtime::TVMRetValue ret;
+    device_api->GetAttr(ctx,
+                        tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock,
+                        &ret);
+    p_hardware_params->max_shared_memory_per_block = ret;
+
+    device_api->GetAttr(ctx,
+                        tvm::runtime::DeviceAttrKind::kMaxRegistersPerBlock,
+                        &ret);
+    p_hardware_params->max_registers_per_block = ret;
+
+    device_api->GetAttr(ctx,
+                        tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock,
+                        &ret);
+    p_hardware_params->max_threads_per_block = ret;
+
+    device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kWarpSize, &ret);
+    p_hardware_params->warp_size = ret;
+
+    // Manually set now
+    p_hardware_params->max_vthread_extent = 4;
+
+    return hardware_params;
+  } else if (target->device_type == kDLOpenCL) {
+    // TODO(jcf94): temp implementation
+    auto hardware_params = HardwareParamsNode::make(100000, 16, 64, 4, 64);
+    auto p_hardware_params = hardware_params.CopyOnWrite();
+
+    auto ctx = TVMContext{kDLOpenCL, 0};
+    auto func = tvm::runtime::Registry::Get("device_api.opencl");
+    CHECK(func != nullptr) << "Cannot find OpenCL device_api in registry";
+    auto device_api = static_cast<tvm::runtime::DeviceAPI*>(((*func)()).operator void*());
+
+    tvm::runtime::TVMRetValue ret;
+    device_api->GetAttr(ctx,
+                        tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock,
+                        &ret);
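+// These GetAttr queries rely on TVMRetValue's implicit conversion when the
+// result is assigned into an int field. Any other DeviceAttrKind can be
+// read the same way, e.g. (illustrative only):
+//
+//   device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kExist, &ret);
+//   bool device_present = static_cast<int>(ret) != 0;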
p_hardware_params->max_shared_memory_per_block = ret; + + device_api->GetAttr(ctx, + tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, + &ret); + p_hardware_params->max_threads_per_block = ret; + + device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kWarpSize, &ret); + p_hardware_params->warp_size = ret; + + // Manually set now + p_hardware_params->max_vthread_extent = 4; + + return hardware_params; + } else { + LOG(FATAL) << "No default hardware parameters for target: " << target; + } + return HardwareParams(); +} + + +SearchTask SearchTaskNode::make(ComputeDAG compute_dag, std::string workload_key, + Target target, Target target_host, HardwareParams hardware_params) { + auto node = make_object(); + node->compute_dag = std::move(compute_dag); + node->workload_key = std::move(workload_key); + node->target = std::move(target); + node->target_host = std::move(target_host); + if (hardware_params.defined()) { + node->hardware_params = std::move(hardware_params); + } else { + node->hardware_params = HardwareParamsNode::GetDefaultHardwareParams( + node->target, node->target_host); + } + return SearchTask(node); +} + +} // namespace ansor +} // namespace tvm diff --git a/src/ansor/search_task.h b/src/ansor/search_task.h new file mode 100644 index 000000000000..7db98a5197a5 --- /dev/null +++ b/src/ansor/search_task.h @@ -0,0 +1,92 @@ +/*! + * Copyright (c) 2020 by Contributors + * \file ansor/search_task.h + * \brief Meta information for a search task + */ + +#ifndef TVM_ANSOR_SEARCH_TASK_H_ +#define TVM_ANSOR_SEARCH_TASK_H_ + +#include +#include +#include "compute_dag.h" + +namespace tvm { +namespace ansor { + +class HardwareParams; class SearchTask; + +/*! \brief Hardware related parameters */ +class HardwareParamsNode : public Object { + public: + int num_cores; + int vector_unit_bytes; + int cache_line_bytes; + // The max length of the axis to be unrolled or vectorized + int max_unroll_vec; + // The max split factor for the innermost tile + int max_innermost_split_factor; + + // Limit params for GPU schedule + int max_shared_memory_per_block{INT32_MAX}; + int max_registers_per_block{INT32_MAX}; + int max_threads_per_block{INT32_MAX}; + int max_vthread_extent{INT32_MAX}; + int warp_size{INT32_MAX}; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("num_cores", &num_cores); + v->Visit("vector_unit_bytes", &vector_unit_bytes); + v->Visit("cache_line_bytes", &cache_line_bytes); + v->Visit("max_unroll_vec", &max_unroll_vec); + v->Visit("max_innermost_split_factor", &max_innermost_split_factor); + + v->Visit("max_shared_memory_per_block", &max_shared_memory_per_block); + v->Visit("max_registers_per_block", &max_registers_per_block); + v->Visit("max_threads_per_block", &max_threads_per_block); + v->Visit("max_vthread_extent", &max_vthread_extent); + v->Visit("warp_size", &warp_size); + } + + static HardwareParams make(int num_cores, int vector_unit_bytes, + int cache_line_bytes, int max_unroll_vec, + int max_innermost_split_factor); + static HardwareParams GetDefaultHardwareParams(const Target& target, + const Target& target_host); + + static constexpr const char *_type_key = "ansor.HardwareParams"; + TVM_DECLARE_FINAL_OBJECT_INFO(HardwareParamsNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(HardwareParams, ObjectRef, HardwareParamsNode); + + +/*! 
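+ * A construction sketch (assuming a ComputeDAG `dag` built elsewhere; an
+ * undefined HardwareParams falls back to GetDefaultHardwareParams):
+ *
+ *   SearchTask task = SearchTaskNode::make(
+ *       dag, "matmul_512", Target::Create("llvm"),
+ *       Target::Create("llvm"), HardwareParams());
+ *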
\brief Meta-info for a search task */ +class SearchTaskNode : public Object { + public: + ComputeDAG compute_dag; + std::string workload_key; + Target target; + Target target_host; + HardwareParams hardware_params; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("compute_dag", &compute_dag); + v->Visit("workload_key", &workload_key); + v->Visit("target", &target); + v->Visit("target_host", &target_host); + v->Visit("hardware_params", &hardware_params); + } + + static SearchTask make(ComputeDAG compute_dag, std::string workload_key, + Target target, Target target_host, + HardwareParams hardware_params); + + static constexpr const char *_type_key = "ansor.SearchTask"; + TVM_DECLARE_FINAL_OBJECT_INFO(SearchTaskNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(SearchTask, ObjectRef, SearchTaskNode); + +} // namespace ansor +} // namespace tvm + +#endif // TVM_ANSOR_SEARCH_TASK_H_ diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc new file mode 100644 index 000000000000..0e2b0be42587 --- /dev/null +++ b/src/ansor/serialization.cc @@ -0,0 +1,573 @@ +/*! + * Copyright (c) 2020 by Contributors + */ +#include +// #include +#include +#include +#include +#include +#include +#include +#include "serialization.h" +#include "loop_state.h" +#include "utils.h" + +// Json serialization handler for MeasureInput, MeasureResult +// (and recursively SearchTask, State, Step, ... +namespace dmlc { +namespace json { + +inline std::vector& FloatArrayToVector(std::vector* out, + const ::tvm::Array<::tvm::PrimExpr>& data) { + out->clear(); + for (const auto&x : data) { + auto pf = x.as<::tvm::tir::FloatImmNode>(); + CHECK(pf != nullptr) << "Cost can only contain float values"; + out->push_back(pf->value); + } + return *out; +} + +inline std::vector& IntArrayToVector(std::vector* out, + const ::tvm::Array<::tvm::PrimExpr>& data) { + out->clear(); + for (const auto&x : data) { + auto pi = x.as<::tvm::tir::IntImmNode>(); + CHECK(pi != nullptr) << "Cost can only contain int values"; + out->push_back(pi->value); + } + return *out; +} + +template <> +struct Handler > { + inline static void Write(dmlc::JSONWriter* writer, + const std::vector<::tvm::ansor::Stage> & data) { + // todo(lmzheng): support serialization of Stage + writer->BeginArray(false); + writer->EndArray(); + } + inline static void Read(dmlc::JSONReader* reader, + std::vector<::tvm::ansor::Stage> * data) { + bool s; + reader->BeginArray(); + s = reader->NextArrayItem(); CHECK(!s); + } +}; + +template <> +struct Handler > { + inline static void Write(dmlc::JSONWriter* writer, + const std::vector<::tvm::ansor::Step> & data) { + std::vector tmp; + writer->BeginArray(false); + for (size_t i = 0; i < data.size(); ++i) { + writer->WriteArraySeperator(); + writer->BeginArray(false); + if (auto ps = data[i].as<::tvm::ansor::ReorderStepNode>()) { + writer->WriteArrayItem(std::string("RS")); + writer->WriteArrayItem(ps->stage_id); + + writer->WriteArraySeperator(); + writer->BeginArray(false); + for (int x : ps->after_ids) { + writer->WriteArrayItem(x); + } + writer->EndArray(); + } else if (auto ps = data[i].as<::tvm::ansor::SplitStepNode>()) { + writer->WriteArrayItem(std::string("SS")); + writer->WriteArrayItem(ps->stage_id); + writer->WriteArrayItem(ps->iter_id); + if (ps->extent.defined()) { + writer->WriteArrayItem(::tvm::ansor::GetIntImm(ps->extent)); + } else { + writer->WriteArrayItem(0); + } + writer->WriteArrayItem(IntArrayToVector(&tmp, ps->lengths)); + writer->WriteArrayItem(static_cast(ps->inner_to_outer)); + } else if (auto ps = 
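+// For illustration: following the "SS" branch above, a SplitStep on stage 2,
+// iterator 0, extent 512, lengths {4, 8}, inner_to_outer = true would be
+// written out as the json fragment
+//
+//   ["SS", 2, 0, 512, [4, 8], 1]
+//
+// and the Read side below rebuilds the step from the same positional layout.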
data[i].as<::tvm::ansor::FollowSplitStepNode>()) { + writer->WriteArrayItem(std::string("FSS")); + writer->WriteArrayItem(ps->stage_id); + writer->WriteArrayItem(ps->iter_id); + writer->WriteArrayItem(ps->src_step_id); + writer->WriteArrayItem(ps->n_split); + } else if (auto ps = data[i].as<::tvm::ansor::FollowFusedSplitStepNode>()) { + writer->WriteArrayItem(std::string("FFSS")); + writer->WriteArrayItem(ps->stage_id); + writer->WriteArrayItem(ps->iter_id); + + writer->WriteArraySeperator(); + writer->BeginArray(false); + for (int x : ps->src_step_ids) { + writer->WriteArrayItem(x); + } + writer->EndArray(); + + writer->WriteArrayItem(ps->level); + writer->WriteArrayItem(static_cast(ps->factor_or_nparts)); + } else if (auto ps = data[i].as<::tvm::ansor::FuseStepNode>()) { + writer->WriteArrayItem(std::string("FS")); + writer->WriteArrayItem(ps->stage_id); + + writer->WriteArraySeperator(); + writer->BeginArray(false); + for (int x : ps->fused_ids) { + writer->WriteArrayItem(x); + } + writer->EndArray(); + } else if (auto ps = data[i].as<::tvm::ansor::AnnotationStepNode>()) { + writer->WriteArrayItem(std::string("AS")); + writer->WriteArrayItem(ps->stage_id); + writer->WriteArrayItem(ps->iter_id); + writer->WriteArrayItem(static_cast(ps->annotation)); + } else if (auto ps = data[i].as<::tvm::ansor::ComputeAtStepNode>()) { + writer->WriteArrayItem(std::string("CA")); + writer->WriteArrayItem(ps->stage_id); + writer->WriteArrayItem(ps->target_stage_id); + writer->WriteArrayItem(ps->target_iter_id); + } else if (auto ps = data[i].as<::tvm::ansor::ComputeRootStepNode>()) { + writer->WriteArrayItem(std::string("CR")); + writer->WriteArrayItem(ps->stage_id); + } else if (auto ps = data[i].as<::tvm::ansor::ComputeInlineStepNode>()) { + writer->WriteArrayItem(std::string("CI")); + writer->WriteArrayItem(ps->stage_id); + } else if (auto ps = data[i].as<::tvm::ansor::CacheReadStepNode>()) { + writer->WriteArrayItem(std::string("CHR")); + writer->WriteArrayItem(ps->stage_id); + writer->WriteArrayItem(ps->scope_name); + writer->WriteArrayItem(ps->reader_stage_ids); + } else if (auto ps = data[i].as<::tvm::ansor::CacheWriteStepNode>()) { + writer->WriteArrayItem(std::string("CHW")); + writer->WriteArrayItem(ps->stage_id); + writer->WriteArrayItem(ps->scope_name); + } else if (auto ps = data[i].as<::tvm::ansor::PragmaStepNode>()) { + writer->WriteArrayItem(std::string("PS")); + writer->WriteArrayItem(ps->stage_id); + writer->WriteArrayItem(ps->iter_id); + writer->WriteArrayItem(ps->pragma_type); + } else if (auto ps = data[i].as<::tvm::ansor::RfactorStepNode>()) { + writer->WriteArrayItem(std::string("RFS")); + writer->WriteArrayItem(ps->stage_id); + writer->WriteArrayItem(ps->iter_id); + writer->WriteArrayItem(ps->factor_iter_id); + } else if (auto ps = data[i].as<::tvm::ansor::StorageAlignStepNode>()) { + writer->WriteArrayItem(std::string("SA")); + writer->WriteArrayItem(ps->stage_id); + writer->WriteArrayItem(ps->iter_id); + writer->WriteArrayItem(ps->factor); + writer->WriteArrayItem(ps->offset); + } else { + LOG(FATAL) << "Invalid step: " << data[i]; + } + writer->EndArray(); + } + writer->EndArray(); + } + inline static void Read(dmlc::JSONReader* reader, + std::vector<::tvm::ansor::Step> * data) { + std::vector int_list; + bool s, inner_to_outer, factor_or_nparts; + std::string name, scope_name, pragma_type; + int stage_id, target_stage_id, iter_id, src_step_id, n_split, ann, extent; + int level, factor_iter_id, factor, offset; + + reader->BeginArray(); + data->clear(); + while 
(reader->NextArrayItem()) { + reader->BeginArray(); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&name); + if (name == "RS") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&int_list); + data->push_back(::tvm::ansor::ReorderStepNode::make(stage_id, int_list)); + } else if (name == "SS") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&iter_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&extent); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&int_list); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&inner_to_outer); + data->push_back(::tvm::ansor::SplitStepNode::make( + stage_id, iter_id, extent, + std::vector<::tvm::PrimExpr>(int_list.begin(), int_list.end()), + inner_to_outer)); + } else if (name == "FSS") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&iter_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&src_step_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&n_split); + data->push_back(::tvm::ansor::FollowSplitStepNode::make( + stage_id, iter_id, src_step_id, n_split)); + } else if (name == "FFSS") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&iter_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&int_list); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&level); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&factor_or_nparts); + data->push_back(::tvm::ansor::FollowFusedSplitStepNode::make( + stage_id, iter_id, int_list, level, factor_or_nparts)); + } else if (name == "FS") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&int_list); + data->push_back(::tvm::ansor::FuseStepNode::make(stage_id, int_list)); + } else if (name == "AS") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&iter_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&ann); + data->push_back(::tvm::ansor::AnnotationStepNode::make(stage_id, + iter_id, ::tvm::ansor::IteratorAnnotation(ann))); + } else if (name == "CA") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&target_stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&iter_id); + data->push_back(::tvm::ansor::ComputeAtStepNode::make( + stage_id, target_stage_id, iter_id)); + } else if (name == "CR") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + data->push_back(::tvm::ansor::ComputeRootStepNode::make(stage_id)); + } else if (name == "CI") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + data->push_back(::tvm::ansor::ComputeInlineStepNode::make(stage_id)); + } else if (name == "CHR") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&scope_name); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&int_list); + data->push_back(::tvm::ansor::CacheReadStepNode::make( + stage_id, scope_name, int_list)); + } else if (name == "CHW") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&scope_name); + 
data->push_back(::tvm::ansor::CacheWriteStepNode::make( + stage_id, scope_name)); + } else if (name == "PS") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&iter_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&pragma_type); + data->push_back(::tvm::ansor::PragmaStepNode::make( + stage_id, iter_id, pragma_type)); + } else if (name == "RFS") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&iter_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&factor_iter_id); + data->push_back(::tvm::ansor::RfactorStepNode::make( + stage_id, iter_id, factor_iter_id)); + } else if (name == "SA") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&iter_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&factor); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&offset); + data->push_back(::tvm::ansor::StorageAlignStepNode::make( + stage_id, iter_id, factor, offset)); + } else { + LOG(FATAL) << "Invalid step format"; + } + s = reader->NextArrayItem(); CHECK(!s); + } + } +}; + +template <> +struct Handler<::tvm::ansor::StateNode> { + inline static void Write(dmlc::JSONWriter* writer, + const ::tvm::ansor::StateNode& data) { + writer->BeginArray(false); + writer->WriteArrayItem(data.stages); + writer->WriteArrayItem(data.transform_steps); + writer->EndArray(); + } + inline static void Read(dmlc::JSONReader* reader, + ::tvm::ansor::StateNode* data) { + reader->BeginArray(); + bool s; + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&data->stages); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&data->transform_steps); + s = reader->NextArrayItem(); CHECK(!s); + } +}; + +template <> +struct Handler<::tvm::ansor::SearchTaskNode> { + inline static void Write(dmlc::JSONWriter* writer, + const ::tvm::ansor::SearchTaskNode& data) { + writer->BeginArray(false); + writer->WriteArrayItem(data.workload_key); + writer->WriteArrayItem(data.target->str()); + writer->EndArray(); + } + inline static void Read(dmlc::JSONReader* reader, + ::tvm::ansor::SearchTaskNode* data) { + std::string target_str; + bool s; + + reader->BeginArray(); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&data->workload_key); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&target_str); + data->target = ::tvm::Target::Create(target_str); + s = reader->NextArrayItem(); CHECK(!s); + } +}; + +template <> +struct Handler<::tvm::ansor::MeasureInputNode> { + inline static void Write(dmlc::JSONWriter* writer, + const ::tvm::ansor::MeasureInputNode& data) { + writer->BeginArray(false); + writer->WriteArrayItem(*data.task.operator->()); + writer->WriteArrayItem(*data.state.operator->()); + writer->EndArray(); + } + inline static void Read(dmlc::JSONReader* reader, + ::tvm::ansor::MeasureInputNode* data) { + bool s; + auto task_node = ::tvm::make_object<::tvm::ansor::SearchTaskNode>(); + auto state_node = ::tvm::make_object<::tvm::ansor::StateNode>(); + state_node->complete = true; + + reader->BeginArray(); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(task_node.get()); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(state_node.get()); + s = reader->NextArrayItem(); CHECK(!s); + + data->task = ::tvm::ansor::SearchTask(task_node); + data->state = ::tvm::ansor::State(state_node); + } +}; + +template <> +struct 
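+// Putting the handlers together, a full record written by the log writer in
+// serialization.cc has this overall shape (a sketch; the exact contents
+// depend on the task and schedule):
+//
+//   {"i": [["<workload_key>", "<target str>"], [[], [<steps>...]]],
+//    "r": [[0.0012], 0, 1.5, 1590000000], "v": "v0.1"}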
Handler<::tvm::ansor::MeasureResultNode> { + inline static void Write(dmlc::JSONWriter* writer, + const ::tvm::ansor::MeasureResultNode& data) { + writer->BeginArray(false); + writer->WriteArraySeperator(); + writer->BeginArray(false); + for (const auto&x : data.costs) { + auto pf = x.as<::tvm::tir::FloatImmNode>(); + CHECK(pf != nullptr) << "Cost can only contain float values"; + writer->WriteArrayItem(pf->value); + } + writer->EndArray(); + writer->WriteArrayItem(data.error_no); + writer->WriteArrayItem(data.all_cost); + writer->WriteArrayItem(static_cast((data.timestamp))); + writer->EndArray(); + } + inline static void Read(dmlc::JSONReader* reader, + ::tvm::ansor::MeasureResultNode* data) { + bool s; + std::vector tmp; + + reader->BeginArray(); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&tmp); + data->costs.clear(); + for (const auto& i : tmp) { + data->costs.push_back(i); + } + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&data->error_no); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&data->all_cost); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&data->timestamp); + s = reader->NextArrayItem(); CHECK(!s); + } +}; + +} // namespace json +} // namespace dmlc + +namespace tvm { +namespace ansor { + +TVM_REGISTER_OBJECT_TYPE(LogToFileNode); +TVM_REGISTER_OBJECT_TYPE(LogReaderNode); + +const std::string ansor_LOG_VERSION = "v0.1"; // NOLINT(*) + +MeasureCallback LogToFileNode::make(std::string filename) { + auto node = make_object(); + node->filename = std::move(filename); + return MeasureCallback(node); +} + +void WriteMeasureRecords(std::ostream* os, + const Array& inputs, + const Array& results) { + dmlc::JSONWriter writer(os); + for (size_t i = 0; i < inputs.size(); ++i) { + writer.BeginObject(false); + writer.WriteObjectKeyValue("i", *inputs[i].operator->()); + writer.WriteObjectKeyValue("r", *results[i].operator->()); + writer.WriteObjectKeyValue("v", ansor_LOG_VERSION); + writer.EndObject(); + *os << "\n"; + } +} + +void ReadMeasureRecords(std::string str, + MeasureInputNode* inp, + MeasureResultNode* res, + std::string* log_version) { + std::istringstream ss(str); + dmlc::JSONReader reader(&ss); + std::string key; + + reader.BeginObject(); + while (reader.NextObjectItem(&key)) { + if (key == "i") { + reader.Read(inp); + } else if (key == "r") { + reader.Read(res); + } else if (key == "v") { + reader.Read(log_version); + } else { + LOG(FATAL) << "Invalid key in json log: " << key; + } + } +} + +TVM_REGISTER_GLOBAL("ansor.write_measure_records_to_file") +.set_body([](TVMArgs args, TVMRetValue *ret) { + std::string filename = args[0]; + Array in = args[1]; + Array res = args[2]; + std::ofstream ofs(filename, std::ofstream::app); + WriteMeasureRecords(&ofs, in, res); +}); + +void LogToFileNode::callback(const SearchPolicy& policy, + const Array& inputs, + const Array& results) { + std::ofstream ofs(filename, std::ofstream::app); + WriteMeasureRecords(&ofs, inputs, results); +} + +LogReader LogReaderNode::make(std::string filename) { + auto node = make_object(); + node->filename = filename; + node->infile.open(filename, std::ifstream::in); + return LogReader(node); +} + +bool LogReaderNode::ReadNext(MeasureInputNode* inp, MeasureResultNode* res) { + std::string log_version; + + while (std::getline(infile, cur_line)) { + if (cur_line[0] == '#' || cur_line[0] == ' ') { + // skip comment lines begin with '#' or ' ' + continue; + } + + try { + ReadMeasureRecords(cur_line, inp, res, &log_version); + } catch (...) 
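+// Reading records back (sketch; `tuning.log` is a placeholder path):
+//
+//   LogReader reader = LogReaderNode::make("tuning.log");
+//   auto inp = make_object<MeasureInputNode>();
+//   auto res = make_object<MeasureResultNode>();
+//   while (reader->ReadNext(inp.get(), res.get())) {
+//     // inp/res are overwritten on every call; use copy() to keep a record
+//   }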
{ + return false; + } + + return true; + } + + return false; +} + +std::pair, Array > LogReaderNode::ReadLines( + int max_size, int skip_size) { + auto inp = make_object(); + auto res = make_object(); + Array inputs; + Array results; + + while (ReadNext(inp.get(), res.get())) { + if (skip_size > 0) { + skip_size--; + continue; + } + + inputs.push_back(inp->copy()); + results.push_back(res->copy()); + + if (max_size > 0 && static_cast(inputs.size()) >= max_size) { + break; + } + } + + return std::make_pair(inputs, results); +} + +std::pair BestMeasurePairInFile(const std::string& filename, + const std::string& workload_key, + const Target& target) { + std::pair best_pair; + double best_cost = 1e30; + + auto inp = make_object(); + auto res = make_object(); + LogReader reader = LogReaderNode::make(filename); + + while (reader->ReadNext(inp.get(), res.get())) { + if (res->error_no != kNoError || inp->task->workload_key != workload_key + || inp->task->target->target_name != target->target_name) { + continue; + } + + double cost = FloatArrayMean(res->costs); + + if (cost < best_cost) { + best_cost = cost; + best_pair = std::make_pair(inp->copy(), res->copy()); + } + } + + return best_pair; +} + +} // namespace ansor +} // namespace tvm \ No newline at end of file diff --git a/src/ansor/serialization.h b/src/ansor/serialization.h new file mode 100644 index 000000000000..96dfb0ee320b --- /dev/null +++ b/src/ansor/serialization.h @@ -0,0 +1,78 @@ +/*! + * Copyright (c) 2020 by Contributors + * \file ansor/serialization.h + * \brief Json serialization format for dumping and loading tuning records + */ + +#ifndef TVM_ANSOR_SERIALIZATION_H_ +#define TVM_ANSOR_SERIALIZATION_H_ + +#include +#include +#include +#include "measure.h" +// #include "search_policy/search_policy.h" + +namespace tvm { +namespace ansor { + +class LogReader; + +/*! \brief Log the input and results of measurments to file */ +class LogToFileNode: public MeasureCallbackNode { + public: + std::string filename; + + static MeasureCallback make(std::string filename); + + /*! \brief Log measure pairs to file. This is called by the search policy */ + void callback(const SearchPolicy& policy, + const Array& inputs, + const Array& results) final; + + static constexpr const char *_type_key = "ansor.LogToFile"; + TVM_DECLARE_FINAL_OBJECT_INFO(LogToFileNode, MeasureCallbackNode); +}; + +/*! \brief Log reader */ +class LogReaderNode: public Object { + public: + std::string filename; + std::ifstream infile; + + static LogReader make(std::string filename); + + /*! \brief Read next line in the log file + * \return Whether the read is successful */ + bool ReadNext(MeasureInputNode* inp, MeasureResultNode* res); + + /*! \brief Read multiple lines from the log file + * \param max_size The maximum number of lines. 
-1 means read all lines + * \param skip_size Skip the first n lines */ + std::pair, Array > ReadLines( + int max_size = -1, int skip_size = 0); + + static constexpr const char* _type_key = "ansor.LogReader"; + TVM_DECLARE_FINAL_OBJECT_INFO(LogReaderNode, Object); + private: + std::string cur_line; +}; +TVM_DEFINE_MUTABLE_NODE_REF(LogReader, LogReaderNode); + +void WriteMeasureRecords(std::ostream* os, + const Array& inputs, + const Array& results); + +void ReadMeasureRecords(std::string str, + MeasureInputNode* inp, + MeasureResultNode* res, + std::string* log_version); + +std::pair BestMeasurePairInFile(const std::string& filename, + const std::string& workload_key, + const Target& target); + +} // namespace ansor +} // namespace tvm + +#endif // TVM_ANSOR_SERIALIZATION_H_ diff --git a/src/ansor/utils.h b/src/ansor/utils.h index 4ea7f283ad09..67ebb836c680 100644 --- a/src/ansor/utils.h +++ b/src/ansor/utils.h @@ -61,6 +61,13 @@ struct hash > { namespace tvm { namespace ansor { +/*! \brief Macro to make it easy to define node ref type given node */ +#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \ + class TypeName : public ObjectRef { \ + public: \ + TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ObjectRef, NodeName); \ + }; \ + /*! \brief Macro to make it easy to define mutable node ref type given node */ #define TVM_DEFINE_MUTABLE_NODE_REF(TypeName, NodeName) \ class TypeName : public ObjectRef { \ diff --git a/src/ir/expr.cc b/src/ir/expr.cc index fd380aa33f86..6e898dd5ddb4 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -38,6 +38,8 @@ PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm(DataType::Int(32), value)) { PrimExpr::PrimExpr(float value) : PrimExpr(FloatImm(DataType::Float(32), value)) {} +PrimExpr::PrimExpr(double value) : PrimExpr(FloatImm(DataType::Float(64), value)) {} + PrimExpr PrimExpr::FromObject_(ObjectRef ref) { using runtime::ObjectTypeChecker; if (auto* ptr = ref.as()) { diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index a6d4a5499469..4e71383cc1bb 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -94,6 +94,10 @@ class CUDADeviceAPI final : public DeviceAPI { } case kGcnArch: return; + case kMaxRegistersPerBlock: { + CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrMaxRegistersPerBlock, ctx.device_id)); + break; + } } *rv = value; } diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index 6d9835e6231c..71d3232ca4d5 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -109,6 +109,9 @@ void OpenCLWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* } case kGcnArch: return; + default: { + LOG(WARNING) << "Attr not implemented."; + } } } diff --git a/tests/cpp/ansor_test.cc b/tests/cpp/ansor_test.cc index 87e7ad71a7c0..c43ec5c0a751 100644 --- a/tests/cpp/ansor_test.cc +++ b/tests/cpp/ansor_test.cc @@ -23,6 +23,7 @@ #include #include #include "../../src/ansor/loop_state.h" +#include "../../src/ansor/serialization.h" tvm::Array matmul_func(int n, int m, int k) { using namespace tvm; @@ -157,6 +158,63 @@ TEST(ComputeDAG, GetProducersConsumers) { } } +TEST(ComputeDAG, InferBoundSerialization) { + using namespace tvm::ansor; + + const auto& tensors = matmul_func(512, 512, 512); + const auto& dag = ComputeDAGNode::make(tensors); + int A = 0, B = 1, C = 2; + + State s0 = dag.GetInitState(); + int C_global = s0.cache_write(C, "global", dag); + C++; + const auto& its0 = s0.split(C, 
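+// Bookkeeping note for this test: cache_write/cache_read insert a new stage
+// into the state, so every stage index at or after the insertion point
+// shifts by one; the bare `C++;` / `B++;` style statements keep the plain
+// int stage handles in sync with the shifted layout after each cache step.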
s0->stages[C]->iters[0], {4, 8, 8}); + const auto& its1 = s0.split(C, s0->stages[C]->iters[4], {8, 4, 4}); + s0.reorder(C, {its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], + its0[3], its1[3]}); + s0.compute_at(C_global, C, s0->stages[C]->iters[3]); + s0.split(C_global, s0->stages[C_global]->iters[2], {16}); + int B_global = s0.cache_read(B, "global", {C_global}, dag); + C++; C_global++; + s0.compute_at(B_global, C_global, s0->stages[C_global]->iters[0]); + int A_global = s0.cache_read(A, "global", {C_global}, dag); + B++; B_global++; C++; C_global++; + s0.compute_at(A_global, C_global, s0->stages[C_global]->iters[2]); + + const auto& s1 = dag.InferBound(s0); + std::vector s2 = {s0}; + dag.InferBound(&s2); + const auto& s3 = dag.ReplayAndInferBound(s0->transform_steps); + + CHECK_EQ(s1->stages[B_global]->iters[0]->range->extent.as()->value, + 512); + CHECK_EQ(s1->stages[B_global]->iters[1]->range->extent.as()->value, + 16); + CHECK_EQ(s1->stages[A_global]->iters[0]->range->extent.as()->value, + 1); + CHECK_EQ(s1->stages[A_global]->iters[1]->range->extent.as()->value, + 16); + CHECK_EQ(s1->stages[C_global]->iters[0]->range->extent.as()->value, + 64); + CHECK(std::equal_to()(s1, s2[0])); + CHECK(std::equal_to()(s1, s3)); + + const auto& minp0 = MeasureInputNode::make( + SearchTaskNode::make(dag, "test", tvm::target::llvm(), + tvm::target::llvm(), + HardwareParams()), + s0); + const auto& mres0 = MeasureResultNode::make({0.1}, 0, "", 0.1, 0.1); + std::stringstream ss; + WriteMeasureRecords(&ss, {minp0}, {mres0}); + auto minp1 = tvm::make_object(); + auto mres1 = tvm::make_object(); + std::string log_version; + ReadMeasureRecords(ss.str(), minp1.get(), mres1.get(), &log_version); + const auto& s4 = dag.ReplayAndInferBound(minp1->state->transform_steps); + CHECK(std::equal_to()(s1, s4)); +} + TEST(Step, SplitFuseReorder) { using namespace tvm::ansor; @@ -533,7 +591,69 @@ TEST(Step, CacheReadWrite) { } TEST(Step, FollowSplitFollowFusedSplit) { - // todo + using namespace tvm::ansor; + + const auto& tensors = matmul_func(512, 512, 512); + const auto& dag = ComputeDAGNode::make(tensors); + + State s0 = dag.GetInitState(); + int C = 2; + + int C_global = s0.cache_write(C, "global", dag); + C++; + + // FollowSplitStep currently only support `inner_to_outer = true` + const auto& its0 = s0.split(C, s0->stages[C]->iters[0], {4, 2, 8, 4}, true); + int split_step0 = s0->transform_steps.size() - 1; + // const auto& its1 = s0.split(C, s0->stages[C]->iters[5], {4, 2, 8, 4}, false); + // int split_step1 = s0->transform_steps.size() - 1; + for (int level = 1; level <= 5; level++) { + State tmp = s0; + tmp.follow_split(C_global, s0->stages[C_global]->iters[0], split_step0, + level); + // tmp.follow_split(C_global, s0->stages[C_global]->iters[5], split_step1, + // level); + const auto& stage_C = tmp->stages[C]; + const auto& stage_C_global = tmp->stages[C_global]; + for (int i = 0; i < level; i++) { + CHECK_EQ(stage_C->iters[i]->range->extent.as()->value, + stage_C_global->iters[i]->range->extent.as()->value); + } + // for (int i = 0; i < level; i++) { + // CHECK(stage_C->iters[i+5]->range->extent.as()->value == + // stage_C_global->iters[i+5]->range->extent.as()->value); + // } + } + + const auto& its1 = s0.split(C, s0->stages[C]->iters[5], {2, 2, 4, 8}); + int split_step1 = s0->transform_steps.size() - 1; + std::vector its; + for (int i = 0; i < 5; i++) { + its.push_back(its0[i]); + its.push_back(its1[i]); + } + s0.reorder(C, its); + for (int i = 0; i < 5; i++) { + s0.fuse(C, {s0->stages[C]->iters[i], 
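+// follow_split in brief: the recorded split of a 512 extent with factors
+// {4, 2, 8, 4} yields five iterators whose extents multiply back to 512
+// (the four factors plus a remainder of 512 / (4*2*8*4) = 2); follow_split
+// then re-applies the first `level` factors of that recorded step to
+// another stage, which is why the checks above compare only the first
+// `level` iterator extents of C and C_global.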
s0->stages[C]->iters[i+1]}); + } + for (int level = 0; level < 4; level++) { + State tmp = s0; + tmp.follow_fused_split(C_global, tmp->stages[C_global]->iters[0], + {split_step0, split_step1}, level, false); + const auto& stage_C = tmp->stages[C]; + const auto& stage_C_global = tmp->stages[C_global]; + CHECK_EQ(stage_C->iters[level+1]->range->extent.as()->value, + stage_C_global->iters[0]->range->extent.as()->value); + } + for (int level = 0; level < 4; level++) { + State tmp = s0; + tmp.follow_fused_split(C_global, tmp->stages[C_global]->iters[0], + {split_step0, split_step1}, level, true); + const auto& stage_C = tmp->stages[C]; + const auto& stage_C_global = tmp->stages[C_global]; + CHECK_EQ(stage_C->iters[level+1]->range->extent.as()->value, + stage_C_global->iters[1]->range->extent.as()->value); + } } TEST(Step, Rfactor) { From e0a5ed58b1f9e8296f1a6e9fb269a3426037cbf1 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Fri, 29 May 2020 18:46:01 +0800 Subject: [PATCH 04/45] Add MetaTileRewritePolicy (#5) * Add feature * Add cost_model, meta_tile_rewrite_policy * Add MetaTileRewritePolicy basic UT --- src/ansor/cost_model/cost_model.cc | 163 ++ src/ansor/cost_model/cost_model.h | 98 ++ src/ansor/feature.cc | 1386 ++++++++++++++++ src/ansor/feature.h | 63 + .../search_policy/meta_tile_rewrite_policy.cc | 1420 +++++++++++++++++ .../search_policy/meta_tile_rewrite_policy.h | 101 ++ src/ansor/search_policy/search_policy.cc | 14 + src/ansor/search_policy/search_policy.h | 53 + src/ansor/search_policy/utils.cc | 609 +++++++ src/ansor/search_policy/utils.h | 428 +++++ tests/cpp/ansor_test.cc | 99 +- 11 files changed, 4420 insertions(+), 14 deletions(-) create mode 100644 src/ansor/cost_model/cost_model.cc create mode 100644 src/ansor/cost_model/cost_model.h create mode 100644 src/ansor/feature.cc create mode 100644 src/ansor/feature.h create mode 100644 src/ansor/search_policy/meta_tile_rewrite_policy.cc create mode 100644 src/ansor/search_policy/meta_tile_rewrite_policy.h create mode 100644 src/ansor/search_policy/search_policy.cc create mode 100644 src/ansor/search_policy/search_policy.h create mode 100644 src/ansor/search_policy/utils.cc create mode 100644 src/ansor/search_policy/utils.h diff --git a/src/ansor/cost_model/cost_model.cc b/src/ansor/cost_model/cost_model.cc new file mode 100644 index 000000000000..d4304bccb4bf --- /dev/null +++ b/src/ansor/cost_model/cost_model.cc @@ -0,0 +1,163 @@ +/*! 
+ * Copyright (c) 2020 by Contributors + */ +#include "cost_model.h" +#include +#include +#include + +namespace tvm { +namespace ansor { + +using ::tvm::runtime::NDArray; + +TVM_REGISTER_OBJECT_TYPE(CostModelNode); +TVM_REGISTER_OBJECT_TYPE(RandomModelNode); +TVM_REGISTER_OBJECT_TYPE(MeasureModelNode); +TVM_REGISTER_OBJECT_TYPE(PythonBasedCostModelNode); + +void RandomNumber(TVMArgs args, TVMRetValue* rv) { + int n = args[0]; + void* data = args[1]; + float* fdata = reinterpret_cast(data); + for (int i = 0; i < n; i++) { + fdata[i] = static_cast(rand_r(0)) / (static_cast(RAND_MAX)); + } +} + +CostModel RandomModelNode::make() { + ObjectPtr node = make_object(); + node->random_number_func = + runtime::Registry::Get("ansor.cost_model.random_number"); + if (node->random_number_func == nullptr) { + LOG(WARNING) << "ansor.cost_model.random_number is not registered, " + << "use C++ default random_number func instead."; + static PackedFunc cost_model_random_number(RandomNumber); + node->random_number_func = &cost_model_random_number; + } + return CostModel(node); +} + +void RandomModelNode::Update(const Array& inputs, + const Array& results) { +} + +void RandomModelNode::Predict(const SearchTask& task, + const std::vector& states, + std::vector* scores) { + scores->resize(states.size()); + (*random_number_func)(states.size(), static_cast(scores->data())); +} + +CostModel MeasureModelNode::make(Builder builder, Runner runner) { + ObjectPtr node = make_object(); + node->measurer = ProgramMeasurerNode::make(std::move(builder), std::move(runner), + Array(), 0); + return CostModel(node); +} + +void MeasureModelNode::Update(const Array& inputs, + const Array& results) { +} + +void MeasureModelNode::Predict(const SearchTask& task, + const std::vector& states, + std::vector* scores) { + std::vector inputs; + std::vector results; + + inputs.clear(); inputs.reserve(states.size()); + for (const auto& state : states) { + inputs.push_back(MeasureInputNode::make(task, state)); + } + measurer->SilentMeasure(task, inputs, &results); + + scores->clear(); + scores->reserve(results.size()); + for (const auto& res : results) { + scores->push_back(1.0 / FloatArrayMean(res->costs)); + } +} + +CostModel PythonBasedCostModelNode::make(PackedFunc update_func, PackedFunc predict_func, + PackedFunc predict_stage_func) { + auto node = make_object(); + node->update_func = std::move(update_func); + node->predict_func = std::move(predict_func); + node->predict_stage_func = std::move(predict_stage_func); + return CostModel(node); +} + +void PythonBasedCostModelNode::Update(const Array& inputs, + const Array& results) { + update_func(inputs, results); +} + +void PythonBasedCostModelNode::Predict(const SearchTask& task, + const std::vector& states, + std::vector* scores) { + scores->resize(states.size()); + predict_func(task, Array(states.begin(), states.end()), + static_cast(scores->data())); +} + +void PythonBasedCostModelNode::PredictStages(const SearchTask& task, + const std::vector& states, + std::vector* state_scores, + std::vector>* stage_scores) { + int n_states = states.size(); + int n_stages = task->compute_dag.GetInitState()->stages.size(); + std::vector flatten_scores; + flatten_scores.resize(n_states * n_stages * 2); // Allocate sufficient spaces. + predict_stage_func(task, Array(states.begin(), states.end()), + static_cast(flatten_scores.data())); + + // Unpack flatten scores. + state_scores->clear(); + stage_scores->clear(); + + // Score of each states. 
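+  // Layout of `flatten_scores`, as implied by the unpacking code below:
+  //   [score of state 0, ..., score of state n_states-1,
+  //    #scored stages of state 0, per-stage scores of state 0, ...,
+  //    #scored stages of state 1, per-stage scores of state 1, ...]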
+ for (int i = 0; i < n_states; ++i) { + state_scores->push_back(flatten_scores[i]); + } + + // Score of each stage in each states. + size_t idx = n_states; + for (int i = 0; i < n_states; ++i) { + CHECK_LE(idx, flatten_scores.size()); + + // Number of scored stages of this state. + int s_length = (int)flatten_scores[idx++]; + + if (s_length > 0) { + std::vector scores; + int offset = 0; + + if ((*state_scores)[i] > -INFINITY) { + // If the score is valid. Copy scored stages and assign 0 to placeholder and inlined stages. + // If the score is 0, meaning this state failed to be lowered. Just bypass to update offset. + for (const Stage& stage : states[i]->stages) { + if (stage->op_type == kPlaceholder) { + scores.push_back(0); + continue; + } + if (stage->compute_at == kInlined) { + scores.push_back(0); + continue; + } + scores.push_back(flatten_scores[idx + offset]); + offset++; + } + CHECK_EQ(offset, s_length); + stage_scores->push_back(std::move(scores)); + } + idx += s_length; + } else { + // Cost model does not provide any stage score details. + stage_scores->push_back({}); + } + } +} + +} // namespace ansor +} // namespace tvm diff --git a/src/ansor/cost_model/cost_model.h b/src/ansor/cost_model/cost_model.h new file mode 100644 index 000000000000..36179573c617 --- /dev/null +++ b/src/ansor/cost_model/cost_model.h @@ -0,0 +1,98 @@ +/*! + * Copyright (c) 2020 by Contributors + * \file ansor/cost_model.h + * \brief Base class of cost model + */ + +#ifndef TVM_ANSOR_COST_MODEL_COST_MODEL_H_ +#define TVM_ANSOR_COST_MODEL_COST_MODEL_H_ + +#include +#include +#include +#include +#include "../measure.h" + +namespace tvm { +namespace ansor { + +using runtime::PackedFunc; + +class CostModel; + +/*! \brief The base class for cost model */ +class CostModelNode: public Object { + public: + virtual void Update(const Array& inputs, const Array& results) = 0; + virtual void Predict(const SearchTask& task, const std::vector& states, + std::vector* scores) = 0; + virtual void PredictStages(const SearchTask& task, const std::vector& states, + std::vector* state_scores, + std::vector>* stage_scores) = 0; + + static constexpr const char *_type_key = "ansor.CostModel"; + TVM_DECLARE_BASE_OBJECT_INFO(CostModelNode, Object); +}; +TVM_DEFINE_MUTABLE_NODE_REF(CostModel, CostModelNode); + +/*! \brief The cost model returns random value for all predictions */ +class RandomModelNode: public CostModelNode { + public: + const PackedFunc* random_number_func; + + static CostModel make(); + + void Update(const Array& inputs, const Array& results) final; + void Predict(const SearchTask& task, const std::vector& states, + std::vector* scores) final; + void PredictStages(const SearchTask& task, const std::vector& states, + std::vector* state_scores, + std::vector>* stage_scores) { ; } + + static constexpr const char *_type_key = "ansor.RandomModel"; + TVM_DECLARE_FINAL_OBJECT_INFO(RandomModelNode, CostModelNode); +}; + +class MeasureModelNode : public CostModelNode { + public: + ProgramMeasurer measurer; + + static CostModel make(Builder builder, Runner runner); + + void Update(const Array& inputs, const Array& results) final; + void Predict(const SearchTask& task, const std::vector& states, + std::vector* scores) final; + void PredictStages(const SearchTask& task, const std::vector& states, + std::vector* state_scores, + std::vector>* stage_scores) { ; } + + static constexpr const char* _type_key = "ansor.MeasureModel"; + TVM_DECLARE_FINAL_OBJECT_INFO(MeasureModelNode, CostModelNode); +}; + +/*! 
\brief A wrapper for cost model defined by python code + * This class will call python's function */ +class PythonBasedCostModelNode: public CostModelNode { + public: + PackedFunc update_func; + PackedFunc predict_func; + PackedFunc predict_stage_func; + + static CostModel make(PackedFunc update_func, PackedFunc predict_func, + PackedFunc predict_stage_func); + + void Update(const Array& inputs, const Array& results) final; + void Predict(const SearchTask& task, const std::vector& states, + std::vector* scores) final; + void PredictStages(const SearchTask& task, const std::vector& states, + std::vector* state_scores, + std::vector>* stage_scores) final; + + static constexpr const char *_type_key = "ansor.PythonBasedCostModel"; + TVM_DECLARE_FINAL_OBJECT_INFO(PythonBasedCostModelNode, CostModelNode); +}; + +} // namespace ansor +} // namespace tvm + +#endif // TVM_ANSOR_COST_MODEL_COST_MODEL_H_ diff --git a/src/ansor/feature.cc b/src/ansor/feature.cc new file mode 100644 index 000000000000..cb865bc3b5ae --- /dev/null +++ b/src/ansor/feature.cc @@ -0,0 +1,1386 @@ +/*! + * Copyright (c) 2020 by Contributors + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "measure.h" +#include "serialization.h" +#include "utils.h" +// #include "../arithmetic/compute_expr.h" + +namespace tvm { +/* Import the function from build_module.cc */ +extern void GetBinds(const Array& args, + bool compact, + const std::unordered_map& binds, + Map* out_binds, + Array* out_arg_list, + const BuildConfig& config); +} // namespace tvm + + +namespace tvm { +namespace ansor { + +using namespace tvm::tir; +using arith::ConstIntBound; +using arith::Analyzer; + +static const int ARITH_INTENSITY_CURVE_SAMPLE_N = 10; + +// Annotation position encoding +enum AnnotationPosType { + kPosNone, kPosInnerSpatial, kPosMiddleSpatial, kPosOuterSpatial, + kPosInnerReduce, kPosMiddleReduce, kPosOuterReduce, kPosMixed +}; + +// Buffer access type +enum BufferAccessType { + kRead, kWrite, kReadWrite, kUnknownRW +}; + +// Accesses to a buffer +struct BufferAccess { + BufferAccessType acc_type{kUnknownRW}; + std::vector > indices; +}; + +// Data reuse type +enum ReuseType { + kLoopMultipleRead, kSerialMultipleReadWrite, kNoReuse +}; + +// Feature for an access of a buffer +struct BufferAccessFeature { + std::string tensor_name; + BufferAccessType acc_type; + float bytes; + float unique_bytes; + float lines; + float unique_lines; + ReuseType reuse_type; + float reuse_dis_iter; // reuse distance in iterator number + float reuse_dis_bytes; // reuse distance in total touched bytes + float reuse_ct; // reuse times + float bytes_d_reuse_ct; + float unique_bytes_d_reuse_ct; + float lines_d_reuse_ct; + float unique_lines_d_reuse_ct; + float stride; +}; + +// Feature set of a statement +struct FeatureSet { + // compute feature + float float_mad; + float float_addsub; + float float_mul; + float float_divmod; + float float_cmp; + float float_math_func; + float float_other_func; + float int_mad; + float int_addsub; + float int_mul; + float int_divmod; + float int_cmp; + float int_math_func; + float int_other_func; + float bool_op; + float select_op; + float vec_num; // The number of vectorized iterators + float vec_prod; // The product of the lengths of vectorized iterators + float vec_len; // The length of the innermost vectorized iterator + AnnotationPosType vec_type; + float unroll_num; // The number of unrolled iterators + float unroll_prod; // The product of the lengths of vectorized 
iterators
+  float unroll_len;       // The length of the innermost unrolled iterator
+  AnnotationPosType unroll_type;
+  float parallel_num;     // The number of paralleled iterators
+  float parallel_prod;    // The product of the lengths of paralleled iterators
+  float parallel_len;     // The length of the innermost paralleled iterator
+  AnnotationPosType parallel_type;
+  float is_gpu;
+  float blockIdx_x_len;
+  float blockIdx_y_len;
+  float blockIdx_z_len;
+  float threadIdx_x_len;
+  float threadIdx_y_len;
+  float threadIdx_z_len;
+  float vthread_len;
+
+  float arith_intensity_curve[ARITH_INTENSITY_CURVE_SAMPLE_N];
+
+  // buffer access feature (per buffer)
+  std::vector<BufferAccessFeature> access_feas;
+
+  // allocation feature
+  float alloc_size;
+  float alloc_prod;
+  float alloc_outer_prod;
+  float alloc_inner_prod;
+
+  // overall feature
+  float outer_prod;
+  float num_loops;
+  float auto_unroll_max_step;
+};
+
+// Return whether a var is in an expr
+bool VarInExpr(const Var& var, const PrimExpr& expr) {
+  bool find = false;
+
+  PostOrderVisit(expr, [&find, &var](const ObjectRef &node) {
+    if (find) {
+      return;
+    }
+
+    if (const VarNode* op = node.as<VarNode>()) {
+      if (op == var.get()) {
+        find = true;
+      }
+    }
+  });
+
+  return find;
+}
+
+// Get position encoding for annotation
+AnnotationPosType GetAnnotationPosEncoding(
+    const Var& var, const Array<PrimExpr>& spatial_args,
+    const Array<IterVar>& axis, const Array<IterVar>& reduce_axis) {
+  // Try to match spatial args first
+  size_t find_i = 0;
+  size_t find_ct = 0;
+  for (size_t i = 0; i < spatial_args.size(); ++i) {
+    if (VarInExpr(var, spatial_args[i])) {
+      find_i = i;
+      find_ct += 1;
+    }
+  }
+
+  if (find_ct == 0) {
+    // If it is not found in the spatial args, then it is a reduce iterator.
+    // Use the name to match
+    for (size_t i = 0; i < reduce_axis.size(); ++i) {
+      if (var->name_hint.find(reduce_axis[i]->var->name_hint) != std::string::npos) {
+        find_i = i;
+        find_ct++;
+      }
+    }
+    if (find_ct >= 1) {
+      if (find_i == 0) {
+        return kPosInnerReduce;
+      } else if (find_i == reduce_axis.size() - 1) {
+        return kPosOuterReduce;
+      } else {
+        return kPosMiddleReduce;
+      }
+    } else {
+      // If the axis is found in neither the spatial args nor the reduce axis,
+      // this stage must compute_at somewhere under this axis and the axis is simplified out.
+      // We assume it is an outer spatial axis.
+      return kPosOuterSpatial;
+    }
+  } else if (find_ct == 1) {
+    if (find_i == spatial_args.size() - 1) {
+      return kPosInnerSpatial;
+    } else if (find_i == 0) {
+      return kPosOuterSpatial;
+    } else {
+      return kPosMiddleSpatial;
+    }
+  } else {
+    return kPosMixed;
+  }
+}
+
+// Count math ops in an expr
+class MathOpCounter : public StmtExprVisitor {
+ public:
+#define VisitBinary(Type, float_ct, int_ct)  \
+  void VisitExpr_(const Type* op) final {    \
+    if (op->a.dtype().is_float()) {          \
+      float_ct++;                            \
+    } else {                                 \
+      int_ct++;                              \
+    }                                        \
+    StmtExprVisitor::VisitExpr_(op);         \
+  }                                          \
+
+  VisitBinary(AddNode, float_addsub, int_addsub);
+  VisitBinary(SubNode, float_addsub, int_addsub);
+  VisitBinary(MulNode, float_mul, int_mul);
+  VisitBinary(DivNode, float_divmod, int_divmod);
+  VisitBinary(ModNode, float_divmod, int_divmod);
+  VisitBinary(FloorDivNode, float_divmod, int_divmod);
+  VisitBinary(FloorModNode, float_divmod, int_divmod);
+  VisitBinary(MaxNode, float_cmp, int_cmp);
+  VisitBinary(MinNode, float_cmp, int_cmp);
+  VisitBinary(EQNode, float_cmp, int_cmp);
+  VisitBinary(NENode, float_cmp, int_cmp);
+  VisitBinary(LTNode, float_cmp, int_cmp);
+  VisitBinary(LENode, float_cmp, int_cmp);
+  VisitBinary(GTNode, float_cmp, int_cmp);
+  VisitBinary(GENode, float_cmp, int_cmp);
+
+  void VisitExpr_(const AndNode* op) final { bool_op++; StmtExprVisitor::VisitExpr_(op); }
+  void VisitExpr_(const OrNode* op) final { bool_op++; StmtExprVisitor::VisitExpr_(op); }
+  void VisitExpr_(const NotNode* op) final { bool_op++; StmtExprVisitor::VisitExpr_(op); }
+  void VisitExpr_(const SelectNode* op) final { select_op++; StmtExprVisitor::VisitExpr_(op); }
+
+  // TODO(...): CallNode with type CallNode::Halide has been modified to BufferLoadNode
+  void VisitExpr_(const CallNode* op) final {
+    if (op->call_type == CallNode::CallType::PureIntrinsic) {
+      if (op->dtype.is_float()) {
+        float_math_func++;
+      } else {
+        int_math_func++;
+      }
+    } else if (op->call_type != CallNode::CallType::Halide) {
+      if (op->dtype.is_float()) {
+        float_other_func++;
+      } else {
+        int_other_func++;
+      }
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  // todo(lmzheng): detect mad
+  size_t float_mad{0}, float_addsub{0}, float_mul{0}, float_divmod{0},
+      float_cmp{0}, float_math_func{0}, float_other_func{0};
+  size_t int_mad{0}, int_addsub{0}, int_mul{0}, int_divmod{0},
+      int_cmp{0}, int_math_func{0}, int_other_func{0};
+  size_t bool_op{0}, select_op{0};
+};
+
+
+// Extract all buffer accesses in an expr
+class BufferAccessExtractor : public StmtExprVisitor {
+ public:
+  void ExtractReads(const PrimExpr& expr) {
+    this->VisitExpr(expr);
+  }
+
+  void InsertAccess(const te::Tensor& ten, BufferAccessType acc_type,
+                    const Array<PrimExpr>& indices) {
+    BufferAccess& acc = buf_accesses[ten];
+    acc.acc_type = acc_type;
+    acc.indices.push_back(std::vector<PrimExpr>(indices.begin(), indices.end()));
+  }
+
+  // TODO(...): CallNode with type CallNode::Halide has been modified to BufferLoadNode
+  void VisitExpr_(const CallNode *op) final {
+    if (op->call_type == CallNode::CallType::Halide) {
+      te::Tensor ten = Downcast<te::Operation>(op->func).output(op->value_index);
+      BufferAccess& acc = buf_accesses[ten];
+      switch (acc.acc_type) {
+        case kRead:
+          break;
+        case kWrite:
+          acc.acc_type = kReadWrite; break;
+        case kReadWrite:
+          break;
+        case kUnknownRW:
+        default:
+          acc.acc_type = kRead; break;
+      }
+
+      if (acc.acc_type != kReadWrite) {
+        // If a buffer is both read and written, in the tvm DSL it must be an update,
+        // so the indices should be the same. Then we can skip appending indices for it.
+        // Otherwise we do the following.
+        buf_accesses[ten].indices.push_back(
+            std::vector<PrimExpr>(op->args.begin(), op->args.end()));
+      }
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  std::unordered_map<te::Tensor, BufferAccess> buf_accesses;
+};
+
+// Compute the coefficient for a loop iterator in an expression
+// Note: we use an approximation strategy to find the coefficient.
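+// For example (illustrative): in the index expression `i * 16 + j`, the
+// extracted coefficient of `i` is 16 (the constant factor of a Mul over the
+// variable), while the coefficient of `j` is 1 (the variable is only reached
+// through an Add, with no enclosing Mul).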
+// Hopefully, it is faster than DetectLinearEquation and can handle more cases (non-linear) +class CoefficientExtractor : public StmtExprVisitor { + public: + void VisitExpr_(const MulNode *node) final { + StmtExprVisitor::VisitExpr_(node); + if (visited_var) { + if (!visited_add) { + if (auto a = node->a.as()) { + visited_mul = true; + stride = a->value; + } else if (auto b = node->b.as()) { + visited_mul = true; + stride = b->value; + } + } + } + } + + void VisitExpr_(const AddNode *node) final { + StmtExprVisitor::VisitExpr_(node); + if (visited_var) { + if (!visited_mul) { + visited_add = true; + stride = 1; + } + } + } + + void VisitExpr_(const VarNode *node) final { + if (node == var_) { + visited_var = true; + // This is a magic default stride in case our approximation strategy fails + stride = 2; + } + } + + int ExtractCoefficient(const PrimExpr& expr, const VarNode* var) { + visited_var = visited_mul = visited_add = false; + var_ = var; + + this->VisitExpr(expr); + + if (visited_var && !visited_mul && !visited_add) { + return 1; + } else { + return stride; + } + } + + bool visited_var{false}; + bool visited_mul{false}; + bool visited_add{false}; + int stride{0}; + + private: + const VarNode* var_{nullptr}; +}; + +// Compute stride for the accesses to a buffer +int64_t ComputeStride(const std::vector >& indices, + const std::vector& shape, + const VarNode* stride_var) { + int64_t min_stride = std::numeric_limits::max(); + bool find = false; + CoefficientExtractor extractor; + + for (const auto &index : indices) { + int64_t shape_stride = 1; + for (int i = static_cast(index.size()) - 1; i >= 0; i--) { + int coefficient = extractor.ExtractCoefficient(index[i], stride_var); + if (extractor.visited_var) { + find = true; + min_stride = std::min(min_stride, std::abs(coefficient) * shape_stride); + break; + } + shape_stride *= shape[i]; + } + } + + return find ? min_stride : 0; +} + +// Compute touched bytes and cache lines for accesses to a buffer +void ComputeRegion( + const std::vector > &indices, + arith::Analyzer* ana, + std::vector* region) { + region->clear(); + + if (indices.empty()) { + return; + } + + region->reserve(indices[0].size()); + + if (indices.size() == 1) { + for (const auto& index : indices[0]) { + ConstIntBound bound = ana->const_int_bound(index); + region->push_back(bound->max_value - bound->min_value + 1); + } + } else { + // future(lmzheng): implement a more accurate IntSet? 
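+    // For multiple accesses, take the per-dimension union of the const-int
+    // bounds over all accesses. This over-approximates the touched region:
+    // e.g. the two accesses A[i] and A[i + 100] give an extent of 101 in
+    // that dimension, even though only two elements are read per iteration.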
+ for (size_t i = 0; i < indices[0].size(); ++i) { + int64_t minimum = ConstIntBound::kPosInf, maximum = ConstIntBound::kNegInf; + for (size_t j = 0; j < indices.size(); ++j) { + ConstIntBound bound = ana->const_int_bound(indices[j][i]); + + minimum = std::min(minimum, bound->min_value); + maximum = std::max(maximum, bound->max_value); + } + region->push_back(maximum - minimum + 1); + } + } +} + +// Compute reuse distance and reuse ratio for accesses to a buffer +// return values: reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct +std::tuple ComputeReuse( + const te::Tensor& t, + const std::vector >& indices, + const std::vector& for_loop_stack, + const std::unordered_map > > >& for_touch_regions) { + float reuse_dis_iter = 1.0f; + float reuse_dis_bytes = -1.0f; + + for (int i = static_cast(for_loop_stack.size()) - 1; i >= 0; --i) { + const ForNode* cur_for = for_loop_stack[i]; + bool find = false; + + for (size_t j = 0; j < indices.size(); j++) { + for (size_t k = 0; k < indices[j].size(); k++) { + if (VarInExpr(cur_for->loop_var, indices[j][k])) { + find = true; + break; + } + } + if (find) { + break; + } + } + + int64_t extent = GetIntImm(for_loop_stack[i]->extent); + if (find) { + // accumulate/update reuse distance + reuse_dis_iter *= extent; + reuse_dis_bytes = 0.0f; + for (const auto& iter : for_touch_regions.at(cur_for)) { + for (const auto& access : iter.second) { + reuse_dis_bytes += std::get<1>(access) * std::get<2>(access); + } + } + } else { + // Have LoopMultipleRead reuse + if (reuse_dis_bytes < 0) { + // For the reuse in the innermost axis, the above code won't be executed. + // So we compute bytes here + reuse_dis_bytes = 0.0f; + for (const auto& iter : for_touch_regions.at(cur_for)) { + for (const auto& access : iter.second) { + reuse_dis_bytes += 1 * std::get<2>(access); + } + } + } + return std::make_tuple(kLoopMultipleRead, reuse_dis_iter, reuse_dis_bytes, extent); + } + + const std::unordered_map > >& + tensor_map = for_touch_regions.at(cur_for); + + int serial_reuse = static_cast(tensor_map.at(t).size()) - 1; + if (serial_reuse > 0) { + int64_t extent = GetIntImm(cur_for->extent); + + // Have SerialMultipleReadWrite reuse + reuse_dis_iter = std::numeric_limits::max(); + for (const auto& acc_info : tensor_map.at(t)) { + reuse_dis_iter = std::min(reuse_dis_iter, static_cast(std::get<1>(acc_info))); + } + + reuse_dis_bytes = 0.0f; + for (const auto& iter : for_touch_regions.at(cur_for)) { + for (const auto& access : iter.second) { + reuse_dis_bytes += std::get<1>(access) * std::get<2>(access); + } + } + + return std::make_tuple(kSerialMultipleReadWrite, + reuse_dis_iter / extent, reuse_dis_bytes / extent, serial_reuse); + } + } + + return std::make_tuple(kNoReuse, 0, 0, 0); +} + +// Extract features for every Provide statement +class PerStmtFeatureExtractor : public StmtExprVisitor { + public: + explicit PerStmtFeatureExtractor(int cache_line_size) : + cache_line_size_(cache_line_size) {} + + void VisitStmt_(const AttrStmtNode* node) final { + if (node->attr_key == tir::attr::thread_extent || + node->attr_key == tir::attr::virtual_thread) { + const Var& var = node->node.as()->var; + int extent = GetIntImm(node->value); + + int* plen = nullptr; + + const std::string& name = var.get()->name_hint; + if (node->attr_key == tir::attr::thread_extent) { + if (name == "blockIdx.x") { + plen = &blockIdx_x_len; + } else if (name == "blockIdx.y") { + plen = &blockIdx_y_len; + } else if (name == "blockIdx.z") { + plen = &blockIdx_z_len; + } else if (name == "threadIdx.x") { + 
plen = &threadIdx_x_len; + } else if (name == "threadIdx.y") { + plen = &threadIdx_y_len; + } else if (name == "threadIdx.z") { + plen = &threadIdx_z_len; + } else { + LOG(FATAL) << "invalid thread itervar " + name; + } + } else { + plen = &vthread_len; + } + + int extent_before = *plen; + if (node->attr_key == tir::attr::thread_extent) { + *plen = extent; + } else { + *plen *= extent; + } + + is_gpu = true; + + // make a fake for node for blockIdx.x or threadIdx.x + Stmt fake_for_node = ForNode::make(var, 0, extent, ForType::Parallel, + DeviceAPI::None, node->body); + + outer_loop_prod *= extent; + for_loop_stack.push_back(fake_for_node.as()); + StmtExprVisitor::VisitStmt_(node); + for_loop_stack.pop_back(); + outer_loop_prod /= extent; + + *plen = extent_before; + } else if (node->attr_key == "pragma_auto_unroll_max_step") { + int value = GetIntImm(node->value); + + int16_t old_value = cur_auto_unroll_max_step; + cur_auto_unroll_max_step = value; + StmtExprVisitor::VisitStmt_(node); + cur_auto_unroll_max_step = old_value; + } else { + StmtExprVisitor::VisitStmt_(node); + } + } + + void VisitStmt_(const ForNode* node) final { + int64_t loop_extent = GetIntImm(node->extent); + + if (node->for_type == ForType::Vectorized) { + vec_for_stack.push_back(node); + } else if (node->for_type == ForType::Unrolled) { + unroll_for_stack.push_back(node); + } else if (node->for_type == ForType::Parallel) { + parallel_for_stack.push_back(node); + } + + outer_loop_prod *= loop_extent; + for_loop_stack.push_back(node); + StmtExprVisitor::VisitStmt_(node); + for_loop_stack.pop_back(); + outer_loop_prod /= loop_extent; + + if (node->for_type == ForType::Vectorized) { + vec_for_stack.pop_back(); + } else if (node->for_type == ForType::Unrolled) { + unroll_for_stack.pop_back(); + } else if (node->for_type == ForType::Parallel) { + parallel_for_stack.pop_back(); + } + } + + // TODO(...): ProvideNode is deprecated, move to BufferStoreNode + void VisitStmt_(const ProvideNode* node) final { + te::Operation op = Downcast(node->func); + te::Tensor ten = op.output(node->value_index); + const te::ComputeOpNode* pcompute = op.as(); + + FeatureSet &fea = op_features[ten]; + + // compute feature + MathOpCounter mathops; + mathops(node->value); + fea.float_mad = outer_loop_prod * mathops.float_mad; + fea.float_addsub = outer_loop_prod * mathops.float_addsub; + fea.float_mul = outer_loop_prod * mathops.float_mul; + fea.float_divmod = outer_loop_prod * mathops.float_divmod; + fea.float_cmp = outer_loop_prod * mathops.float_cmp; + fea.float_math_func = outer_loop_prod * mathops.float_math_func; + fea.float_other_func = outer_loop_prod * mathops.float_other_func; + fea.int_mad = outer_loop_prod * mathops.int_mad; + fea.int_addsub = outer_loop_prod * mathops.int_addsub; + fea.int_mul = outer_loop_prod * mathops.int_mul; + fea.int_divmod = outer_loop_prod * mathops.int_divmod; + fea.int_math_func = outer_loop_prod * mathops.int_math_func; + fea.int_cmp = outer_loop_prod * mathops.int_cmp; + fea.int_other_func = outer_loop_prod * mathops.int_other_func; + fea.bool_op = outer_loop_prod * mathops.bool_op; + fea.select_op = outer_loop_prod * mathops.select_op; + + fea.outer_prod = outer_loop_prod; + fea.num_loops = for_loop_stack.size(); + fea.auto_unroll_max_step = cur_auto_unroll_max_step; + fea.vec_len = fea.unroll_len = fea.parallel_len = 0.0f; + fea.vec_type = fea.unroll_type = fea.parallel_type = kPosNone; + + fea.vec_num = vec_for_stack.size(); + if (!vec_for_stack.empty()) { + fea.vec_len = 
GetIntImm(vec_for_stack.back()->extent); + fea.vec_prod = 1.0; + for (const ForNode* pfor : vec_for_stack) { + fea.vec_prod *= GetIntImm(pfor->extent); + } + fea.vec_type = GetAnnotationPosEncoding(vec_for_stack.back()->loop_var, + node->args, pcompute->axis, pcompute->reduce_axis); + } + + fea.unroll_num = unroll_for_stack.size(); + if (!unroll_for_stack.empty()) { + fea.unroll_len = GetIntImm(unroll_for_stack.back()->extent); + fea.unroll_prod = 1.0; + for (const ForNode* pfor : unroll_for_stack) { + fea.unroll_prod *= GetIntImm(pfor->extent); + } + fea.unroll_type = GetAnnotationPosEncoding(unroll_for_stack.back()->loop_var, + node->args, pcompute->axis, pcompute->reduce_axis); + } + + fea.parallel_num = parallel_for_stack.size(); + if (!parallel_for_stack.empty()) { + fea.parallel_len = GetIntImm(parallel_for_stack.back()->extent); + fea.parallel_prod = 1.0; + for (const ForNode* pfor : parallel_for_stack) { + fea.parallel_prod *= GetIntImm(pfor->extent); + } + fea.parallel_type = GetAnnotationPosEncoding(parallel_for_stack.back()->loop_var, + node->args, pcompute->axis, pcompute->reduce_axis); + } + + // GPU threads + fea.is_gpu = is_gpu; + fea.blockIdx_x_len = blockIdx_x_len; + fea.blockIdx_y_len = blockIdx_y_len; + fea.blockIdx_z_len = blockIdx_z_len; + fea.threadIdx_x_len = threadIdx_x_len; + fea.threadIdx_y_len = threadIdx_y_len; + fea.threadIdx_z_len = threadIdx_z_len; + fea.vthread_len = vthread_len; + + // Extract all buffer access + std::vector acc_feas; + BufferAccessExtractor buf_extractor; + buf_extractor.InsertAccess(ten, kWrite, node->args); + buf_extractor.ExtractReads(node->value); + + // Compute touched region for all outer loops + Analyzer ana; + for (auto x : for_loop_stack) { + ana.Bind(x->loop_var, Range::make_by_min_extent(x->min, 1)); + } + + std::vector mem_bytes_list; + std::vector compute_ops_list; + + mem_bytes_list.reserve(for_loop_stack.size()); + compute_ops_list.reserve(for_loop_stack.size()); + + int cur_compute_ops = mathops.float_mad + mathops.float_addsub + mathops.float_mul + + mathops.float_divmod + mathops.float_cmp + + mathops.float_math_func + mathops.float_other_func; + + std::vector tmp_region; + for (int i = static_cast(for_loop_stack.size()) - 1; i >= 0; i--) { + const ForNode* p_for = for_loop_stack[i]; + + ana.Bind(p_for->loop_var, + Range::make_by_min_extent(for_loop_stack[i]->min, for_loop_stack[i]->extent)); + + // Note, here we do overwrite. + // So if there are multiple Provides, the last one will overwrite the first few. + // e.g. The update part in gemm will overwrite the init part. + std::unordered_map > >& + tensor_regions_map = for_touch_regions[p_for]; + + int64_t mem_bytes = 0; + for (const auto &x : buf_extractor.buf_accesses) { + const te::Tensor& t = x.first; + const BufferAccess& acc = x.second; + + ComputeRegion(acc.indices, &ana, &tmp_region); + int64_t touched_size = ElementProduct(tmp_region); + tensor_regions_map[t].push_back(std::make_tuple(acc.acc_type, + touched_size, t->dtype.bytes())); + mem_bytes += touched_size * t->dtype.bytes(); + } + + mem_bytes_list.push_back(std::log2(mem_bytes)); + cur_compute_ops *= GetIntImm(for_loop_stack[i]->extent); + compute_ops_list.push_back(std::log2(cur_compute_ops)); + } + + // Compute arithmetic intensity curve (y axis : arithmetic intensity, x axis : flops). + // We use piecewise linear interpolation to fit this curve. 
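+    // Sketch of the sampling below: the x axis (log2 of cumulative compute
+    // ops, from `compute_ops_list`) is sampled at ARITH_INTENSITY_CURVE_SAMPLE_N
+    // evenly spaced points; each sample linearly interpolates the ratio
+    // compute_ops / mem_bytes (both in log2 space) between the two nearest
+    // recorded loop levels.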
+ int pt = 0; + if (cur_compute_ops <= 0 || compute_ops_list.empty()) { + std::fill(fea.arith_intensity_curve, + fea.arith_intensity_curve + ARITH_INTENSITY_CURVE_SAMPLE_N, 0.0); + } else { + for (size_t i = 0; i < ARITH_INTENSITY_CURVE_SAMPLE_N; ++i) { + float cur_compute_ops = compute_ops_list.back() * (i+1) / ARITH_INTENSITY_CURVE_SAMPLE_N; + while (compute_ops_list[pt] < cur_compute_ops - 1e-4) { + pt++; + } + CHECK_LT(pt, compute_ops_list.size()); + + float value; + if (pt == 0) { + value = compute_ops_list[pt] / mem_bytes_list[pt]; + } else { + float base = compute_ops_list[pt-1] / mem_bytes_list[pt-1]; + float slope = (compute_ops_list[pt] / mem_bytes_list[pt] - + compute_ops_list[pt-1] / mem_bytes_list[pt-1]) / + (compute_ops_list[pt] - compute_ops_list[pt-1]); + value = base + slope * (cur_compute_ops - compute_ops_list[pt-1]); + } + fea.arith_intensity_curve[i] = value; + } + } + + // Compute buffer access feature + for (const auto &x : buf_extractor.buf_accesses) { + const te::Tensor& t = x.first; + const BufferAccess& acc = x.second; + + std::vector int_shape; + for (const auto& dim : t->shape) { + int_shape.push_back(GetIntImm(dim)); + } + + size_t ele_bytes = t->dtype.bytes(); + + // calculate bytes + float bytes = outer_loop_prod * ele_bytes; + float unique_bytes; + + // calculate cache lines + int64_t stride; + float lines; + float unique_lines; + + if (for_loop_stack.empty()) { + unique_bytes = ele_bytes; + stride = 0; + lines = 1.0f; + unique_lines = 1.0f; + } else { + unique_bytes = std::get<1>(for_touch_regions[for_loop_stack.front()][t].front()) + * ele_bytes; + + stride = 0; + int64_t reduce_ratio = 1; + + int i; + for (i = static_cast(for_loop_stack.size()) - 1; i >= 0; i--) { + stride = ComputeStride(acc.indices, int_shape, for_loop_stack[i]->loop_var.get()); + if (stride != 0) { + break; + } + reduce_ratio *= GetIntImm(for_loop_stack.back()->extent); + } + + lines = outer_loop_prod / reduce_ratio * + std::min(1.0f, 1.0f * stride * ele_bytes / cache_line_size_); + lines = std::max(lines, 1.0f); + + // convert `stride` back to the stride of the innermost iterator + stride = (i == static_cast(for_loop_stack.size()) - 1 ? 
stride : 0); + + float n_continuous = ele_bytes; + for (int i = static_cast(tmp_region.size()) - 1; i >= 0; i--) { + if (tmp_region[i] == int_shape[i]) { + n_continuous *= tmp_region[i]; + break; + } + } + unique_lines = unique_bytes / std::min(n_continuous, + static_cast(cache_line_size_)); + unique_lines = std::max(unique_lines, 1.0f); + } + + ReuseType reuse_type; + float reuse_dis_iter, reuse_dis_bytes, reuse_ct; + std::tie(reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct) = + ComputeReuse(t, acc.indices, for_loop_stack, for_touch_regions); + + acc_feas.emplace_back(); + BufferAccessFeature& acc_fea = acc_feas.back(); + + acc_fea.tensor_name = t->op->func_name(); + acc_fea.acc_type = acc.acc_type; + acc_fea.stride = stride; + acc_fea.bytes = bytes; + acc_fea.unique_bytes = unique_bytes; + acc_fea.lines = lines; + acc_fea.unique_lines = unique_lines; + acc_fea.reuse_type = reuse_type; + acc_fea.reuse_dis_iter = reuse_dis_iter; + acc_fea.reuse_dis_bytes = reuse_dis_bytes; + acc_fea.reuse_ct = reuse_ct; + if (acc_fea.reuse_ct > 0.5) { + acc_fea.bytes_d_reuse_ct = bytes / reuse_ct; + acc_fea.unique_bytes_d_reuse_ct = unique_bytes / reuse_ct; + acc_fea.lines_d_reuse_ct = lines / reuse_ct; + acc_fea.unique_lines_d_reuse_ct = unique_lines / reuse_ct; + } else { + // no reuse, multiply by a magic number '2' + acc_fea.bytes_d_reuse_ct = bytes * 2; + acc_fea.unique_bytes_d_reuse_ct = unique_bytes * 2; + acc_fea.lines_d_reuse_ct = lines * 2; + acc_fea.unique_lines_d_reuse_ct = unique_lines* 2; + } + } + + fea.access_feas = acc_feas; + } + + // TODO(...): RealizeNode is deprecated, move to BufferRealizeNode + void VisitStmt_(const RealizeNode *node) final { + StmtExprVisitor::VisitStmt_(node); + + te::Operation op = Downcast(node->func); + te::Tensor ten = op.output(node->value_index); + + FeatureSet& fea = op_features[ten]; + + float allocation_size = 1.0f; + for (const auto& x : node->bounds) { + allocation_size *= GetIntImm(x->extent); + } + // allocation feature + fea.alloc_size = allocation_size * ten->dtype.bytes(); + fea.alloc_prod = allocation_size * outer_loop_prod; + fea.alloc_outer_prod = outer_loop_prod; + fea.alloc_inner_prod = fea.outer_prod / outer_loop_prod; + } + + float outer_loop_prod = 1.0f; + + std::vector for_loop_stack; + std::vector parallel_for_stack; + std::vector vec_for_stack; + std::vector unroll_for_stack; + + bool is_gpu; + int blockIdx_x_len{1}; + int blockIdx_y_len{1}; + int blockIdx_z_len{1}; + int threadIdx_x_len{1}; + int threadIdx_y_len{1}; + int threadIdx_z_len{1}; + int vthread_len{1}; + int16_t cur_auto_unroll_max_step{0}; + + std::unordered_map op_features; + + // for a loop, for all its touched tensors, for all different accesses to the tensors, + // its (access type, number of touched elements, number of bytes of single element) + std::unordered_map > > > for_touch_regions; + + private: + const int cache_line_size_ = 64; +}; + +// shifted log to incorporate the property that slog(0) = 0 +inline float slog(float x) { + return x < 0 ? -std::log2(-x+1) : std::log2(x+1); +} + +// Get features for all ir::Provide statements in a TVM program. 
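+// (the first element pushed into `ret` is the number of extracted statements;
+// one fixed-length feature block per statement follows).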
+// So we call it `PerStmt` feature +void GetPerStmtFeature(const Stmt& stmt, + int cache_line_size, + int max_n_bufs, + std::vector* ret) { + LOG(WARNING) << "RealizeNode & ProvideNode deprecated, " + << "need to fix the implementation of PerStmtFeatureExtractor."; + PerStmtFeatureExtractor extractor(cache_line_size); + extractor(stmt); + + ret->push_back(extractor.op_features.size()); + + for (const auto& x : extractor.op_features) { + const FeatureSet& fea_set = x.second; + + /***** compute feature *****/ + ret->push_back(slog(fea_set.float_mad)); + ret->push_back(slog(fea_set.float_addsub)); + ret->push_back(slog(fea_set.float_mul)); + ret->push_back(slog(fea_set.float_divmod)); + ret->push_back(slog(fea_set.float_cmp)); + ret->push_back(slog(fea_set.float_math_func)); + ret->push_back(slog(fea_set.float_other_func)); + ret->push_back(slog(fea_set.int_mad)); + ret->push_back(slog(fea_set.int_addsub)); + ret->push_back(slog(fea_set.int_mul)); + ret->push_back(slog(fea_set.int_divmod)); + ret->push_back(slog(fea_set.int_cmp)); + ret->push_back(slog(fea_set.int_math_func)); + ret->push_back(slog(fea_set.int_other_func)); + ret->push_back(slog(fea_set.bool_op)); + ret->push_back(slog(fea_set.select_op)); + + ret->push_back(slog(fea_set.vec_num)); + ret->push_back(slog(fea_set.vec_prod)); + ret->push_back(slog(fea_set.vec_len)); + for (int i = 0; i <= kPosMixed; i++) { + ret->push_back(i == fea_set.vec_type); + } + + ret->push_back(slog(fea_set.unroll_num)); + ret->push_back(slog(fea_set.unroll_prod)); + ret->push_back(slog(fea_set.unroll_len)); + for (int i = 0; i <= kPosMixed; i++) { + ret->push_back(i == fea_set.unroll_type); + } + + ret->push_back(slog(fea_set.parallel_num)); + ret->push_back(slog(fea_set.parallel_prod)); + ret->push_back(slog(fea_set.parallel_len)); + for (int i = 0; i <= kPosMixed; i++) { + ret->push_back(i == fea_set.parallel_type); + } + + ret->push_back(fea_set.is_gpu); + ret->push_back(slog(fea_set.blockIdx_x_len)); + ret->push_back(slog(fea_set.blockIdx_y_len)); + ret->push_back(slog(fea_set.blockIdx_z_len)); + ret->push_back(slog(fea_set.threadIdx_x_len)); + ret->push_back(slog(fea_set.threadIdx_y_len)); + ret->push_back(slog(fea_set.threadIdx_z_len)); + ret->push_back(slog(fea_set.vthread_len)); + + for (size_t i = 0; i < ARITH_INTENSITY_CURVE_SAMPLE_N; ++i) { + ret->push_back(fea_set.arith_intensity_curve[i]); + } + + /***** access feature *****/ + // sort according to pair (lines, bytes) + std::vector > buf_order_key; + for (const auto& acc_fea : fea_set.access_feas) { + buf_order_key.emplace_back(acc_fea.lines, acc_fea.bytes); + } + std::vector buf_order(buf_order_key.size()); + std::iota(buf_order.begin(), buf_order.end(), 0); + + auto cmp = [&buf_order_key](int l, int r) { + return buf_order_key[l].first > buf_order_key[r].first + || (buf_order_key[l].first == buf_order_key[r].first + && buf_order_key[l].second > buf_order_key[r].second); + }; + std::sort(buf_order.begin(), buf_order.end(), cmp); + int n_bufs = std::min(max_n_bufs, static_cast(buf_order.size())); + buf_order.resize(n_bufs); + + for (int idx : buf_order) { + const auto& acc_fea = fea_set.access_feas[idx]; + for (int j = 0; j <= kReadWrite; ++j) { + ret->push_back(j == acc_fea.acc_type); + } + ret->push_back(slog(acc_fea.bytes)); + ret->push_back(slog(acc_fea.unique_bytes)); + ret->push_back(slog(acc_fea.lines)); + ret->push_back(slog(acc_fea.unique_lines)); + for (int j = 0; j <= kNoReuse; ++j) { + ret->push_back(acc_fea.reuse_type == j); + } + ret->push_back(slog(acc_fea.reuse_dis_iter)); 
+ ret->push_back(slog(acc_fea.reuse_dis_bytes)); + ret->push_back(slog(acc_fea.reuse_ct)); + ret->push_back(slog(acc_fea.bytes_d_reuse_ct)); + ret->push_back(slog(acc_fea.unique_bytes_d_reuse_ct)); + ret->push_back(slog(acc_fea.lines_d_reuse_ct)); + ret->push_back(slog(acc_fea.unique_lines_d_reuse_ct)); + ret->push_back(slog(acc_fea.stride)); + } + // - fill padding + for (int i = 0; i < max_n_bufs - n_bufs; ++i) { + for (int j = 0; j <= kReadWrite; ++j) { // 3 + ret->push_back(0.0f); + } + ret->push_back(0.0f); ret->push_back(0.0f); ret->push_back(0.0f); ret->push_back(0.0f); + for (int j = 0; j <= kNoReuse; ++j) { // 3 + ret->push_back(0.0f); + } + ret->push_back(0.0f); ret->push_back(0.0f); ret->push_back(0.0f); ret->push_back(0.0f); + ret->push_back(0.0f); ret->push_back(0.0f); ret->push_back(0.0f); ret->push_back(0.0f); + } + + /***** allocation feature *****/ + ret->push_back(slog(fea_set.alloc_size)); + ret->push_back(slog(fea_set.alloc_prod)); + ret->push_back(slog(fea_set.alloc_outer_prod)); + ret->push_back(slog(fea_set.alloc_inner_prod)); + + /***** overall feature *****/ + ret->push_back(slog(fea_set.outer_prod)); + ret->push_back(slog(fea_set.num_loops)); + ret->push_back(slog(fea_set.auto_unroll_max_step)); + } +} + + +/* \brief Get the name of every element in the feature vector. Use this for debug and inspection */ +void GetPerStmtFeatureName(int max_n_bufs, std::vector *ret) { + /***** compute feature *****/ + ret->push_back(("float_mad")); + ret->push_back(("float_addsub")); + ret->push_back(("float_mul")); + ret->push_back(("float_divmod")); + ret->push_back(("float_cmp")); + ret->push_back(("float_mathfunc")); + ret->push_back(("float_otherfunc")); + ret->push_back(("int_mad")); + ret->push_back(("int_addsub")); + ret->push_back(("int_mul")); + ret->push_back(("int_divmod")); + ret->push_back(("int_cmp")); + ret->push_back(("int_mathfunc")); + ret->push_back(("int_otherfunc")); + ret->push_back(("bool_op")); + ret->push_back(("select_op")); + ret->push_back(("vec_num")); + ret->push_back(("vec_prod")); + ret->push_back(("vec_len")); + ret->push_back(("vec_type.kPosNone")); + ret->push_back(("vec_type.kPosInnerSpatial")); + ret->push_back(("vec_type.kPosMiddleSpatial")); + ret->push_back(("vec_type.kPosOuterSpatial")); + ret->push_back(("vec_type.kPosInnerReduce")); + ret->push_back(("vec_type.kPosMiddleReduce")); + ret->push_back(("vec_type.kPosOuterReduce")); + ret->push_back(("vec_type.kPosMixed")); + ret->push_back(("unroll_num")); + ret->push_back(("unroll_prod")); + ret->push_back(("unroll_len")); + ret->push_back(("unroll_type.kPosNone")); + ret->push_back(("unroll_type.kPosInnerSpatial")); + ret->push_back(("unroll_type.kPosMiddleSpatial")); + ret->push_back(("unroll_type.kPosOuterSpatial")); + ret->push_back(("unroll_type.kPosInnerReduce")); + ret->push_back(("unroll_type.kPosMiddleReduce")); + ret->push_back(("unroll_type.kPosOuterReduce")); + ret->push_back(("unroll_type.kPosMixed")); + ret->push_back(("parallel_num")); + ret->push_back(("parallel_prod")); + ret->push_back(("parallel_len")); + ret->push_back(("parallel_type.kPosNone")); + ret->push_back(("parallel_type.kPosInnerSpatial")); + ret->push_back(("parallel_type.kPosMiddleSpatial")); + ret->push_back(("parallel_type.kPosOuterSpatial")); + ret->push_back(("parallel_type.kPosInnerReduce")); + ret->push_back(("parallel_type.kPosMiddleReduce")); + ret->push_back(("parallel_type.kPosOuterReduce")); + ret->push_back(("parallel_type.kPosMixed")); + ret->push_back(("is_gpu")); + 
ret->push_back(("blockIdx_x_len")); + ret->push_back(("blockIdx_y_len")); + ret->push_back(("blockIdx_z_len")); + ret->push_back(("threadIdx_x_len")); + ret->push_back(("threadIdx_y_len")); + ret->push_back(("threadIdx_z_len")); + ret->push_back(("vthread_len")); + for (size_t i = 0; i < ARITH_INTENSITY_CURVE_SAMPLE_N; ++i) { + ret->push_back(("arith_intensity_curve_" + std::to_string(i))); + } + // section total: 55 + ARITH_INTENSITY_CURVE_SAMPLE_N = 65 + + /***** access feature *****/ + for (size_t i = 0; i < static_cast(max_n_bufs); ++i) { + std::string prefix = "B" + std::to_string(i) + "."; + ret->push_back((prefix + "acc_type.kRead")); + ret->push_back((prefix + "acc_type.kWrite")); + ret->push_back((prefix + "acc_type.kReadWrite")); + ret->push_back((prefix + "bytes")); + ret->push_back((prefix + "unique_bytes")); + ret->push_back((prefix + "lines")); + ret->push_back((prefix + "unique_lines")); + ret->push_back((prefix + "reuse_type.kLoopMultipleRead")); + ret->push_back((prefix + "reuse_type.kSerialMultipleReadWrite")); + ret->push_back((prefix + "reuse_type.kNoReuse")); + ret->push_back((prefix + "reuse_dis_iter")); + ret->push_back((prefix + "reuse_dis_bytes")); + ret->push_back((prefix + "reuse_ct")); + ret->push_back((prefix + "bytes_d_reuse_ct")); + ret->push_back((prefix + "unique_bytes_d_reuse_ct")); + ret->push_back((prefix + "lines_d_reuse_ct")); + ret->push_back((prefix + "unique_lines_d_reuse_ct")); + ret->push_back((prefix + "stride")); + } + // section total : max_n_bufs * 18 + + /***** allocation feature *****/ + ret->push_back(("alloc_size")); + ret->push_back(("alloc_prod")); + ret->push_back(("alloc_outer_prod")); + ret->push_back(("alloc_inner_prod")); + // section total : 4 + + /***** overall feature *****/ + ret->push_back(("outer_prod")); + ret->push_back(("num_loops")); + ret->push_back(("auto_unroll_max_step")); + // section total : 2 +} + +void GetPerStmtFeaturesWorkerFunc(const SearchTask& task, const State& state, + int max_n_bufs, std::vector* feature, std::atomic* error_ct) { + te::Schedule sch; + Array tensors; + Map bounds; + GlobalVar g("main"); + + std::tie(sch, tensors) = task->compute_dag.ApplySteps(state->transform_steps); + sch = sch.normalize(); + bounds = te::InferBound(sch); + + try { + auto stmt = te::ScheduleOps(sch, bounds, false); + Map out_binds; Array out_arg_list; + bool compact = te::VerifyCompactBuffer(stmt); + GetBinds(tensors, compact, std::unordered_map(), + &out_binds, &out_arg_list, BuildConfig::Create()); + tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, + std::move(stmt), out_binds); + f = WithAttr(std::move(f), "global_symbol", runtime::String("main")); + auto mod = IRModule(Map({{g, f}})); + auto pass_list = Array(); + if (task->target->device_type == kDLGPU) { + pass_list.push_back(tir::transform::InjectPrefetch()); + pass_list.push_back(tir::transform::StorageFlatten(64)); + pass_list.push_back(tir::transform::Simplify()); + pass_list.push_back(tir::transform::VectorizeLoop()); + pass_list.push_back(tir::transform::InjectVirtualThread()); + pass_list.push_back(tir::transform::StorageRewrite()); + pass_list.push_back(tir::transform::Simplify()); + tvm::Map gpu_params { + {"max_shared_memory_per_block", + task->hardware_params->max_shared_memory_per_block}, + {"max_local_memory_per_block", + task->hardware_params->max_registers_per_block}, + {"max_threads_per_block", + task->hardware_params->max_threads_per_block}, + {"max_vector_bytes", + task->hardware_params->vector_unit_bytes} + }; + 
pass_list.push_back(tir::transform::VerifyGPUCode(gpu_params)); + const auto& optimize = tir::transform::Sequential(pass_list); + optimize(mod); + } + pass_list.clear(); + pass_list.push_back(tir::transform::Simplify()); + const auto& optimize = tir::transform::Sequential(pass_list); + mod = optimize(std::move(mod)); + const auto& it = mod->functions.find(g); + CHECK(it != mod->functions.end()); + const auto& prim_func = (*it).second.as(); + GetPerStmtFeature(prim_func->body, + task->hardware_params->cache_line_bytes, + max_n_bufs, feature); + } catch (dmlc::Error &e) { + (*error_ct)++; + } +} + +void GetPerStmtFeaturesFromStates(const Array& states, + const SearchTask& task, + int max_n_bufs, + int skip_first_n_feature_extraction, + std::vector >* features) { + // extract features + features->assign(states.size(), std::vector()); + + std::atomic error_ct(0); + + ThreadPool& pool = ThreadPool::Global(); + pool.BeginBatch(static_cast(states.size()) - skip_first_n_feature_extraction); + for (size_t i = skip_first_n_feature_extraction; i < states.size(); ++i) { + pool.Enqueue(GetPerStmtFeaturesWorkerFunc, task, states[i], + max_n_bufs, &(*features)[i], &error_ct); + } + pool.WaitBatch(); + + if (error_ct > 0) { + std::cerr << "Encountered " << error_ct + << " errors during feature extraction. Ignored." << std::endl; + } +} + + +void GetPerStmtFeaturesFromStates(const Array& states, + const std::vector& tasks, + int max_n_bufs, + int skip_first_n_feature_extraction, + std::vector >* features) { + // extract features + features->assign(states.size(), std::vector()); + + std::atomic error_ct(0); + + ThreadPool& pool = ThreadPool::Global(); + pool.BeginBatch(static_cast(states.size()) - skip_first_n_feature_extraction); + for (size_t i = skip_first_n_feature_extraction; i < states.size(); ++i) { + pool.Enqueue(GetPerStmtFeaturesWorkerFunc, tasks[i], states[i], + max_n_bufs, &(*features)[i], &error_ct); + } + pool.WaitBatch(); + + if (error_ct > 0) { + std::cerr << "Encountered " << error_ct + << " errors during feature extraction. Ignored." 
<< std::endl; + } +} + +void GetPerStmtFeaturesFromFile(const std::string& filename, + int n_lines, + int max_n_bufs, + std::vector >* features, + std::vector* normalized_throughputs, + std::vector* task_ids) { + Array states; + // ArrayNode* pstates = states.CopyOnWrite(); + std::vector tasks; + + normalized_throughputs->clear(); + task_ids->clear(); + + // (workload_key, target) -> (search_task, task_id) + std::unordered_map, std::pair> task_cache; + // task_id -> min_cost + std::vector min_costs; + + // read from file + LogReader reader = LogReaderNode::make(filename); + auto cur_inp = make_object(); + auto cur_res = make_object(); + while (reader->ReadNext(cur_inp.get(), cur_res.get())) { + float cost = static_cast(FloatArrayMean(cur_res->costs)); + const std::string& workload_key = cur_inp->task->workload_key; + + SearchTask task; + size_t task_id; + std::pair key(workload_key, cur_inp->task->target->str()); + auto find_res = task_cache.find(key); + if (find_res == task_cache.end()) { + // rebuild task + task = SearchTaskNode::make(ComputeDAGNode::make_by_workload_key(workload_key), + workload_key, + cur_inp->task->target, + cur_inp->task->target_host, + cur_inp->task->hardware_params); + task_id = task_cache.size(); + + // compute min cost for each task + task_cache.insert(std::make_pair(key, std::make_pair(task, task_id))); + min_costs.push_back(cost); + } else { + std::tie(task, task_id) = find_res->second; + min_costs[task_id] = std::min(min_costs[task_id], cost); + } + + tasks.push_back(std::move(task)); + task_ids->push_back(task_id); + // pstates->data.push_back(cur_inp->state); + states.push_back(cur_inp->state); + normalized_throughputs->push_back(cost); + + if (n_lines > 0 && static_cast(states.size()) >= n_lines) { + break; + } + } + + for (size_t i = 0; i < normalized_throughputs->size(); ++i) { + (*normalized_throughputs)[i] = min_costs[(*task_ids)[i]] / (*normalized_throughputs)[i]; + } + + GetPerStmtFeaturesFromStates(states, tasks, max_n_bufs, 0, features); +} + +void GetPerStmtFeaturesFromMeasurePairs(const Array& inputs, + const Array& results, + int max_n_bufs, + int skip_first_n_feature_extraction, + std::vector >* features, + std::vector* normalized_throughputs, + std::vector* task_ids) { + Array states; + // ArrayNode* pstates = states.CopyOnWrite(); + std::vector tasks; + + normalized_throughputs->clear(); + task_ids->clear(); + + // (workload_key, target) -> (search_task, task_id) + std::unordered_map, std::pair> task_cache; + // task_id -> min_cost + std::vector min_costs; + + tasks.reserve(inputs.size()); + normalized_throughputs->reserve(inputs.size()); + task_ids->reserve(inputs.size()); + for (size_t i = 0; i < inputs.size(); ++i) { + float cost = static_cast(FloatArrayMean(results[i]->costs)); + const std::string& workload_key = inputs[i]->task->workload_key; + SearchTask task; + + size_t task_id; + std::pair key(workload_key, inputs[i]->task->target->str()); + auto find_res = task_cache.find(key); + if (find_res == task_cache.end()) { + if (inputs[i]->task->compute_dag.defined()) { // the measure input is complete + task = inputs[i]->task; + } else { // the measure input is incomplete + // rebuild task for incomplete measure pairs read from file + task = SearchTaskNode::make(ComputeDAGNode::make_by_workload_key(workload_key), + workload_key, + inputs[i]->task->target, + inputs[i]->task->target_host, + inputs[i]->task->hardware_params); + } + task_id = task_cache.size(); + + // compute min cost for each task + task_cache.insert(std::make_pair(key, 
std::make_pair(task, task_id))); + min_costs.push_back(cost); + } else { + std::tie(task, task_id) = find_res->second; + min_costs[task_id] = std::min(min_costs[task_id], cost); + } + + tasks.push_back(std::move(task)); + task_ids->push_back(task_id); + // pstates->data.push_back(inputs[i]->state); + states.push_back(inputs[i]->state); + normalized_throughputs->push_back(cost); + } + + for (size_t i = 0; i < normalized_throughputs->size(); ++i) { + (*normalized_throughputs)[i] = min_costs[(*task_ids)[i]] / (*normalized_throughputs)[i]; + } + + GetPerStmtFeaturesFromStates(states, tasks, max_n_bufs, + skip_first_n_feature_extraction, features); +} + +} // namespace ansor +} // namespace tvm diff --git a/src/ansor/feature.h b/src/ansor/feature.h new file mode 100644 index 000000000000..149c59e8cb7d --- /dev/null +++ b/src/ansor/feature.h @@ -0,0 +1,63 @@ +/*! + * Copyright (c) 2020 by Contributors + * \file ansor/search_task.h + * \brief Meta inforamtion for a search task + */ + +#ifndef TVM_ANSOR_FEATURE_H_ +#define TVM_ANSOR_FEATURE_H_ + +// #include +#include +#include +#include "compute_dag.h" +#include "measure.h" + +namespace tvm { +namespace ansor { + +/*! \brief Get PerStmt feature from a tvm IR stmt */ +void GetPerStmtFeature(const Stmt& stmt, + int cache_line_size, + int max_n_bufs, + std::vector* ret); + +/* \brief Get the name of every element in the feature vector. Use this for debug and inspection */ +void GetPerStmtFeatureName(int max_n_bufs, std::vector *ret); + + +/*! \brief Get PerStmt feature from states */ +void GetPerStmtFeaturesFromStates(const Array& states, + const SearchTask& task, + int max_n_bufs, + int skip_first_n_feature_extraction, + std::vector >* features); + +/*! \brief Get PerStmt feature from states */ +void GetPerStmtFeaturesFromStates(const Array& states, + const std::vector& tasks, + int max_n_bufs, + int skip_first_n_feature_extraction, + std::vector >* features); + +/*! \brief Get PerStmt feature from a log file */ +void GetPerStmtFeaturesFromFile(const std::string& filename, + int n_lines, + int max_n_bufs, + std::vector >* features, + std::vector* normalized_throughputs, + std::vector* task_ids); + +/*! \brief Get PerStmt feature from measure pairs */ +void GetPerStmtFeaturesFromMeasurePairs(const Array& inputs, + const Array& results, + int max_n_bufs, + int skip_first_n_feature_extraction, + std::vector >* features, + std::vector* normalized_throughputs, + std::vector* task_ids); + +} // namespace ansor +} // namespace tvm + +#endif // TVM_ANSOR_FEATURE_H_ diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.cc b/src/ansor/search_policy/meta_tile_rewrite_policy.cc new file mode 100644 index 000000000000..b3b93ec9c839 --- /dev/null +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.cc @@ -0,0 +1,1420 @@ +/*! 
+ * Copyright (c) 2020 by Contributors
+ */
+
+#include "meta_tile_rewrite_policy.h"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "utils.h"
+
+#define IS_GPU(task) ((task)->target->device_type == kDLGPU || \
+                      (task)->target->device_type == kDLOpenCL)
+
+namespace tvm {
+namespace ansor {
+
+TVM_REGISTER_OBJECT_TYPE(MetaTileRewritePolicyNode);
+
+// All possible candidates for auto_unroll
+const std::vector<int> MetaTileRewritePolicyNode::auto_unroll_configs{0, 16, 64, 512, 1024};
+
+SearchPolicy MetaTileRewritePolicyNode::make(CostModel program_cost_model,
+                                             Map<String, ObjectRef> params,
+                                             int seed) {
+  auto node = make_object<MetaTileRewritePolicyNode>();
+  node->program_cost_model = std::move(program_cost_model);
+  node->rand_gen_ = std::mt19937(seed);
+  node->params = std::move(params);
+  return SearchPolicy(node);
+}
+
+State MetaTileRewritePolicyNode::Search(SearchTask task, int n_trials,
+                                        int early_stopping, int num_measure_per_iter,
+                                        int verbose, ProgramMeasurer measurer) {
+  std::vector<State> best_states, random_states;
+  cur_task_ = task;
+  verbose_ = verbose;
+  num_measure_per_iter_ = num_measure_per_iter;
+
+  if (n_trials <= 1) {  // no measurement is allowed
+    SearchOneRound(&best_states, 0, &random_states);
+    CHECK_GT(best_states.size(), 0);
+    return best_states[0];
+  } else {
+    std::vector<MeasureInput> inputs;
+    std::vector<MeasureResult> results;
+    int num_random = static_cast<int>(GetDoubleParam(params, "eps_greedy") * num_measure_per_iter);
+
+    measurer->Reset();
+
+    early_stopping = early_stopping < 0 ? std::numeric_limits<int>::max() >> 1 : early_stopping;
+
+    int ct = 0;
+    while (ct < n_trials) {
+      if (!inputs.empty()) {
+        // retrain cost models
+        PrintTitle("Train cost model", verbose_);
+        program_cost_model->Update(inputs, results);
+      }
+
+      // Search one round to get promising states
+      PrintTitle("Search", verbose_);
+      SearchOneRound(&best_states, num_random, &random_states);
+
+      // Fill correct bounds. This is necessary for computing the correct ToStr() for the redundancy check
+      cur_task_->compute_dag.InferBound(&best_states);
+      cur_task_->compute_dag.InferBound(&random_states);
+
+      // Pick `num_measure_per_iter` states to measure; check the hash to remove already measured states.
+      // Also pick some random states for eps-greedy
+      PickStatesWithEpsGreedy(&inputs, best_states, random_states, n_trials - ct);
+
+      // We have traversed all of the search space
+      if (inputs.empty()) {
+        StdCout(verbose) << "All candidates in the search space have been measured." << std::endl;
+        break;
+      }
+
+      // Measure candidate states
+      PrintTitle("Measure", verbose_);
+      measurer->Measure(cur_task_, GetRef<SearchPolicy>(this), inputs, &results);
+      ct += inputs.size();
+
+      if (ct - measurer->best_ct[cur_task_->workload_key] > early_stopping) {
+        StdCout(verbose) << "Meet the early stopping condition." << std::endl;
+        break;
+      }
+
+      // Update measured states. These states will join the LocalMutation in later rounds.
+      for (const auto& res : results) {
+        measured_states_throughputs_.push_back(1.0 / FloatArrayMean(res->costs));
+      }
+    }
+    PrintTitle("Done", verbose_);
+
+    return measurer->best_state[cur_task_->workload_key];
+  }
+}
+
+std::pair<Array<MeasureInput>, Array<MeasureResult> >
+    MetaTileRewritePolicyNode::ContinueSearchOneRound(
+    SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) {
+  if (cur_task_.defined()) {
+    CHECK_EQ(cur_task_, task);
+  } else {
+    cur_task_ = task;
+  }
+  verbose_ = verbose;
+  num_measure_per_iter_ = num_measure;
+
+  std::vector<State> best_states, random_states;
+  std::vector<MeasureInput> inputs;
+  std::vector<MeasureResult> results;
+  int num_random = static_cast<int>(GetDoubleParam(params, "eps_greedy") * num_measure);
+
+  // Search one round to get promising states
+  PrintTitle("Search", verbose);
+  SearchOneRound(&best_states, num_random * 2, &random_states);
+
+  // Fill correct bounds. This is necessary for computing the correct ToStr() for the redundancy check
+  cur_task_->compute_dag.InferBound(&best_states);
+  cur_task_->compute_dag.InferBound(&random_states);
+
+  // Pick `num_measure` states to measure; check the hash to remove already measured states.
+  // Also pick some random states for eps-greedy
+  PickStatesWithEpsGreedy(&inputs, best_states, random_states, num_measure);
+
+  // Measure candidate states
+  PrintTitle("Measure", verbose);
+  measurer->Measure(cur_task_, GetRef<SearchPolicy>(this), inputs, &results);
+
+  // Update throughputs of measured states. These states will join the LocalMutation in later rounds.
+  for (const auto& res : results) {
+    measured_states_throughputs_.push_back(1.0 / FloatArrayMean(res->costs));
+  }
+
+  // Update the cost model
+  Array<MeasureInput> inputs_arr(std::make_move_iterator(inputs.begin()),
+                                 std::make_move_iterator(inputs.end()));
+  Array<MeasureResult> results_arr(std::make_move_iterator(results.begin()),
+                                   std::make_move_iterator(results.end()));
+
+  PrintTitle("Train cost model", verbose);
+  program_cost_model->Update(inputs_arr, results_arr);
+  return std::make_pair(std::move(inputs_arr), std::move(results_arr));
+}
+
+void MetaTileRewritePolicyNode::PickStatesWithEpsGreedy(
+    std::vector<MeasureInput>* inputs,
+    const std::vector<State>& best_states,
+    const std::vector<State>& random_states,
+    int remaining_n_trials) {
+  int num_random = static_cast<int>(GetDoubleParam(params, "eps_greedy") * num_measure_per_iter_);
+  int num_good = num_measure_per_iter_ - num_random;
+
+  inputs->clear();
+  size_t offset_best = 0, offset_random = 0;
+
+  while (static_cast<int>(inputs->size()) < std::min(num_measure_per_iter_, remaining_n_trials)) {
+    const State* pstate;
+
+    bool has_best = offset_best < best_states.size();
+    bool has_random = offset_random < random_states.size();
+
+    if (static_cast<int>(inputs->size()) < num_good) {
+      // prefer best states
+      if (has_best) {
+        pstate = &best_states[offset_best++];
+      } else if (has_random) {
+        pstate = &random_states[offset_random++];
+      } else {
+        break;
+      }
+    } else {
+      // prefer random states
+      if (has_random) {
+        pstate = &random_states[offset_random++];
+      } else if (has_best) {
+        pstate = &best_states[offset_best++];
+      } else {
+        break;
+      }
+    }
+
+    // Check if it has already been measured
+    std::string state_str = pstate->ToStr();
+
+    if (measured_states_set_.count(state_str)) { continue; }
+    measured_states_set_.insert(state_str);
+
+    inputs->push_back(MeasureInputNode::make(cur_task_, *pstate));
+    measured_states_vector_.push_back(std::move(*pstate));
+  }
+}
+
+void MetaTileRewritePolicyNode::SearchOneRound(std::vector* best_states,
+                                               int num_random_states,
std::vector* random_states) { + best_states->clear(); + random_states->clear(); + + // Get parameters + int population = GetIntParam(params, "evolutionary_search_population"); + int num_use_measured = std::min(static_cast(measured_states_vector_.size()), + static_cast( + GetDoubleParam(params, "evolutionary_search_use_measured_ratio") * population)); + bool have_cost_model = !program_cost_model->IsInstance(); + + if (!have_cost_model) { + num_use_measured = 0; + } + + // Synthesize meta structure + std::vector meta_structures; + SynthesizeMetaStructure(&meta_structures); + + // PrintAllStates(meta_structures); + // exit(0); + + // Sample the init population + std::vector init_population; + SampleInitPopulation(meta_structures, population - num_use_measured, &init_population); + + // PrintAllStates(init_population); + // exit(0); + + if (have_cost_model) { + // Also insert already measured good states to the initial population + std::vector indices; + Argsort(measured_states_throughputs_, &indices); + for (int i = 0; i < num_use_measured; i++) { + init_population.push_back(measured_states_vector_[indices[i]]); + } + + // Perform evolutionary search + EvolutionarySearch(init_population, num_measure_per_iter_ * 2, best_states); + } else { + // If the cost model is useless (i.e. RandomCostModel), skip evolutionary search + RandomSampleStates(init_population, &rand_gen_, num_measure_per_iter_ * 3, best_states); + } + + // Sample some random states for eps-greedy + RandomSampleStates(init_population, &rand_gen_, num_random_states * 10, random_states); +} + +// The baseclass of derivation rules used in meta structure synthesis +class StructureSynthesisRule { + public: + enum ConditionEnum { + kPass, kApply, kApplyAndSkipRest + }; + + virtual ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, + const State& state, int stage_id) = 0; + virtual std::vector > Apply(const MetaTileRewritePolicyNode* policy, + const State& state, int stage_id) = 0; +}; + +static inline bool ShouldBeCacheRead( + const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) { + const SearchTask& task = policy->cur_task_; + const Stage& stage = state->stages[stage_id]; + + if (HasAttrsFlag(state, stage_id, + SearchPolicyNode::no_cache_read_key)) { + return false; + } + + std::unordered_set consumers; + GetConsumers(task, state, stage->op, &consumers); + if (consumers.size() != 1) { + return false; + } + + int target_stage_id = OperationToStage(*consumers.begin(), state); + if (!NeedsMultilevelTiling(task, state, + state->stages[target_stage_id]->op)) { + return false; + } + + std::unordered_set producers; + GetProducers(task, state, state->stages[target_stage_id]->op, &producers); + // Only those directly mapped stages can do CacheRead + if (producers.find(stage->op) == producers.end()) { + return false; + } + + return true; +} + +static inline bool ShouldAlwaysBeInlined( + const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) { + const SearchTask& task = policy->cur_task_; + const Stage& stage = state->stages[stage_id]; + + if (stage->op->IsInstance()) { + return false; + } + + // Inline limitation of TVM + if (!IsOutputOp(task, state, stage->op) && !HasReduceIter(stage)) { + // Always inline condition: + // 1. Has attrs that this must be inlined + // 2. Analyse shows this is strict inlineable + // 3. 
A GPU stage can be inlined(If it should be cache read, do it first) + if (HasAttrsFlag(state, stage_id, + SearchPolicyNode::always_compute_inline_key) || + IsStrictInlineable(task, state, stage->op) || + (IS_GPU(policy->cur_task_) && + !ShouldBeCacheRead(policy, state, stage_id))) { + return true; + } + } + + return false; +} + +// The rule that inlines simple elementwise ops +class RuleAlwaysInline : public StructureSynthesisRule { + public: + ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, + const State& state, int stage_id) final { + return ShouldAlwaysBeInlined(policy, state, stage_id) ? + kApplyAndSkipRest : kPass; + } + + std::vector > Apply(const MetaTileRewritePolicyNode* policy, + const State& state, int stage_id) final { + State tmp_s = state; + tmp_s.compute_inline(stage_id); + return {std::make_pair(std::move(tmp_s), stage_id - 1)}; + } +}; + +// The rule that simply skip the current stage +class RuleSkipStage : public StructureSynthesisRule { + public: + ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, + const State& state, int stage_id) final { + const SearchTask& task = policy->cur_task_; + const Stage& stage = state->stages[stage_id]; + + const auto& attrs = stage->op->attrs; + if ((attrs.count(SearchPolicyNode::no_split_at_inner_key) || + attrs.count(SearchPolicyNode::no_split_at_outer_key)) && + NeedsMultilevelTiling(task, state, stage->op)) { + // for the transform stages in Winograd + return kPass; + } + + return kApply; + } + + std::vector > Apply(const MetaTileRewritePolicyNode* policy, + const State& state, int stage_id) final { + return {std::make_pair(state, stage_id - 1)}; + } +}; + +// The rule that performs multi-level tiling +class RuleMultiLevelTiling : public StructureSynthesisRule { + public: + ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, + const State& state, int stage_id) final { + const SearchTask& task = policy->cur_task_; + const Stage& stage = state->stages[stage_id]; + + return NeedsMultilevelTiling(task, state, stage->op) ? + (IS_GPU(policy->cur_task_) ? kApplyAndSkipRest : kApply) : kPass; + } + + std::vector > Apply(const MetaTileRewritePolicyNode* policy, + const State& state, int stage_id) final { + std::string multi_level_tiling_structure = IS_GPU(policy->cur_task_) ? + GetStringParam(policy->params, "gpu_multi_level_tiling_structure") : + GetStringParam(policy->params, "cpu_multi_level_tiling_structure"); + + std::vector spatial_split_step_ids; + State tmp_s = state; + tmp_s = DoMultiLevelTiling(tmp_s, stage_id, multi_level_tiling_structure, + &spatial_split_step_ids); + return {std::make_pair(std::move(tmp_s), stage_id-1)}; + } +}; + +// The rule that performs multi-level tiling and fuses later consumers +class RuleMultiLevelTilingWithFusion : public StructureSynthesisRule { + public: + ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, + const State& state, int stage_id) final { + const SearchTask& task = policy->cur_task_; + const Stage& stage = state->stages[stage_id]; + + int target_stage_id; + + if (IS_GPU(policy->cur_task_)) { + return NeedsMultilevelTiling(task, state, stage->op) && + HasSingleElementwiseMatchedConsumer(task, state, stage, + &target_stage_id) && + (!HasCacheReadStage(state, stage_id) || + HasCacheWriteStage(state, stage_id)) ? + kApplyAndSkipRest : kPass; + } + + return NeedsMultilevelTiling(task, state, stage->op) && + HasSingleElementwiseMatchedConsumer(task, state, stage, + &target_stage_id) ? 
+ kApply : kPass; + } + + std::vector > Apply(const MetaTileRewritePolicyNode* policy, + const State& state, int stage_id) final { + const SearchTask& task = policy->cur_task_; + const Stage& stage = state->stages[stage_id]; + std::string multi_level_tiling_structure = IS_GPU(policy->cur_task_) ? + GetStringParam(policy->params, "gpu_multi_level_tiling_structure") : + GetStringParam(policy->params, "cpu_multi_level_tiling_structure"); + + std::vector spatial_split_step_ids; + int target_stage_id; + std::unordered_set consumers; + + GetConsumers(task, state, state->stages[stage_id]->op, &consumers); + CHECK(HasSingleElementwiseMatchedConsumer(task, state, stage, &target_stage_id)); + + State base_state = state; + base_state = DoMultiLevelTiling(base_state, stage_id, + multi_level_tiling_structure, &spatial_split_step_ids); + std::vector follow_tiling_levels; + if (IS_GPU(policy->cur_task_)) { + follow_tiling_levels.push_back(3); + } else { + follow_tiling_levels.push_back(1); + follow_tiling_levels.push_back(2); + } + + std::vector > ret; + for (int level : follow_tiling_levels) { + if (tolower(multi_level_tiling_structure[level-1]) != 's') { + continue; + } + State tmp_s = base_state; + tmp_s = FollowTiling(tmp_s, target_stage_id, spatial_split_step_ids, level); + const Iterator &target_iter = tmp_s->stages[target_stage_id]->iters[ + level * spatial_split_step_ids.size() - 1]; + tmp_s.compute_at(stage_id, target_stage_id, target_iter); + + ret.emplace_back(std::move(tmp_s), stage_id - 1); + } + + return ret; + } +}; + +// The rule that adds a cache write stage +class RuleAddCacheWrite : public StructureSynthesisRule { + public: + ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, + const State& state, int stage_id) final { + const SearchTask& task = policy->cur_task_; + const Stage& stage = state->stages[stage_id]; + + int target_stage_id; + + // Add cache write if a stage needs multi-level tiling, + // but does not have a element-wise matched consumer + return NeedsMultilevelTiling(task, state, stage->op) && + !HasAttrsFlag(state, stage_id, SearchPolicyNode::no_cache_write_key) && + (!HasSingleElementwiseMatchedConsumer(task, state, stage, + &target_stage_id) || + (HasCacheReadStage(state, stage_id) && + !HasCacheWriteStage(state, stage_id))) ? + kApply : kPass; + } + + std::vector > Apply(const MetaTileRewritePolicyNode* policy, + const State& state, int stage_id) final { + const SearchTask& task = policy->cur_task_; + + State tmp_s = state; + tmp_s.cache_write(stage_id, "local", task->compute_dag); + return {std::make_pair(std::move(tmp_s), stage_id)}; + } +}; + +// The rule that adds a cache read stage +// Mainly used for GPU cooperative fetching +// Currently only support 1 to 1 match cache read +class RuleAddCacheRead : public StructureSynthesisRule { + public: + ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, + const State& state, int stage_id) final { + return ShouldBeCacheRead(policy, state, stage_id) ? 
+ kApplyAndSkipRest : kPass; + } + + std::vector > Apply(const MetaTileRewritePolicyNode* policy, + const State& state, int stage_id) final { + const SearchTask& task = policy->cur_task_; + const Stage& stage = state->stages[stage_id]; + + std::unordered_set consumers; + GetConsumers(task, state, stage->op, &consumers); + CHECK_EQ(consumers.size(), 1); + int target_stage_id = OperationToStage(*consumers.begin(), state); + State tmp_s = state; + int added_stage_id = tmp_s.cache_read(stage_id, "shared", + {target_stage_id}, + task->compute_dag); + target_stage_id++; + const auto& share_read_pos = GetLastReduceIteratorInOutermostReduceTile( + tmp_s->stages[target_stage_id]); + tmp_s.compute_at(added_stage_id, target_stage_id, share_read_pos); + + return {std::make_pair(std::move(tmp_s), stage_id)}; + } +}; + +// The rule that adds rfactor stage +class RuleAddRfactor : public StructureSynthesisRule { + public: + ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, + const State& state, int stage_id) final { + const SearchTask& task = policy->cur_task_; + const Stage& stage = state->stages[stage_id]; + + return NeedsRfactor(task, state, stage->op) && + !HasCacheWriteStage(state, stage_id) ? + kApply : kPass; + } + + std::vector > Apply(const MetaTileRewritePolicyNode* policy, + const State& state, int stage_id) final { + const SearchTask& task = policy->cur_task_; + const Stage& stage = state->stages[stage_id]; + + std::vector > ret; + + State tmp_s = state; + + // fuse reduce iters + std::vector space_iters, reduce_iters; + for (const auto &iter : stage->iters) { + if (iter->iter_type == kSpace) { + space_iters.push_back(iter); + } else if (iter->iter_type == kReduce) { + reduce_iters.push_back(iter); + } + } + CHECK(!reduce_iters.empty()); + Iterator fused_reduce_iter; + if (reduce_iters.size() > 1) { + fused_reduce_iter = tmp_s.fuse(stage_id, reduce_iters); + } else { + fused_reduce_iter = reduce_iters[0]; + } + + // split reduce iters + const auto &split_res = tmp_s.split(stage_id, fused_reduce_iter, {1}); + int factor_axis_id = static_cast(space_iters.size()); + State base_state = tmp_s; + for (const auto &split_iter : split_res) { + tmp_s = base_state; + tmp_s.rfactor(stage_id, split_iter, factor_axis_id, task->compute_dag); + + // reorder the space iterator to innermost for vectorization + if (split_iter == split_res[1]) { + std::vector new_order; + for (size_t i = 0; i < tmp_s->stages[stage_id]->iters.size(); ++i) { + if (i != space_iters.size()) { + new_order.push_back(tmp_s->stages[stage_id]->iters[i]); + } + } + new_order.push_back(tmp_s->stages[stage_id]->iters[space_iters.size()]); + tmp_s.reorder(stage_id, new_order); + } + ret.emplace_back(std::move(tmp_s), stage_id - 1); + } + + return ret; + } +}; + +void MetaTileRewritePolicyNode::SynthesizeMetaStructure(std::vector* out_states) { + State init_state = cur_task_->compute_dag.GetInitState(); + std::string cpu_multi_level_tiling_structure = + GetStringParam(params, "cpu_multi_level_tiling_structure"); + + // two ping pong buffers to avoid copy + std::vector states_buf1, states_buf2; + std::vector *pnow, *pnext; + pnow = &states_buf1; + pnext = &states_buf2; + pnow->push_back(init_state); + + // A map that maps state to its current working position (stage_id) + std::unordered_map cur_stage_id_map; + cur_stage_id_map[init_state] = static_cast(init_state->stages.size() - 1); + + static RuleSkipStage rule_skip_stage; + static RuleAlwaysInline rule_always_inline; + static RuleMultiLevelTiling rule_multi_level_tiling; + 
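// A minimal, self-contained sketch of the derivation loop that follows: states
// are expanded stage by stage (from the last stage towards stage 0), and each
// rule either passes, applies, or applies and skips the remaining rules.
// ToyState, ToyRule and Derive are illustrative stand-ins, not types from this
// patch.
#include <utility>
#include <vector>

struct ToyState { int stage_id; };  // only tracks the working position

struct ToyRule {
  enum Cond { kPass, kApply, kApplyAndSkipRest };
  virtual ~ToyRule() = default;
  virtual Cond Meet(const ToyState& s) const = 0;
  // Returns derived states; each already carries its next working stage.
  virtual std::vector<ToyState> Apply(const ToyState& s) const = 0;
};

std::vector<ToyState> Derive(ToyState init, const std::vector<ToyRule*>& rules) {
  std::vector<ToyState> now{init}, next, done;
  while (!now.empty()) {
    next.clear();
    for (const ToyState& s : now) {
      if (s.stage_id < 0) { done.push_back(s); continue; }  // terminal state
      for (ToyRule* r : rules) {
        ToyRule::Cond c = r->Meet(s);
        if (c == ToyRule::kPass) continue;
        for (ToyState& d : r->Apply(s)) next.push_back(d);
        if (c == ToyRule::kApplyAndSkipRest) break;  // earlier rules take priority
      }
    }
    std::swap(now, next);
  }
  return done;
}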
static RuleMultiLevelTilingWithFusion rule_multi_level_tiling_with_fusion;
+  static RuleAddCacheWrite rule_add_cache_write_stage;
+  static RuleAddCacheRead rule_add_cache_read_stage;
+  static RuleAddRfactor rule_add_rfactor;
+  // We may apply and skip the rest when processing some rules,
+  // should take care of the rule vector order here
+  static std::vector<StructureSynthesisRule*> all_rules {
+    &rule_always_inline, &rule_add_cache_write_stage,
+    &rule_multi_level_tiling_with_fusion, &rule_multi_level_tiling,
+    &rule_add_rfactor, &rule_skip_stage
+  };
+  if (IS_GPU(cur_task_)) {
+    // Try cache read first before cache write
+    all_rules.insert(all_rules.begin() + 1, &rule_add_cache_read_stage);
+  }
+  // TODO(xian): Add a new rule to try combination of multi-level tiling + rfactor
+
+  // Derivation rule based synthesizer
+  while (!pnow->empty()) {
+    pnext->clear();
+
+    for (const State& state : *pnow) {
+      int stage_id = cur_stage_id_map[state];
+
+      // Reached the terminal stage
+      if (stage_id < 0) {
+        out_states->push_back(state);
+        continue;
+      }
+
+      // Try all derivation rules
+      for (const auto& rule : all_rules) {
+        auto rule_check = rule->MeetCondition(this, state, stage_id);
+        if (rule_check > StructureSynthesisRule::ConditionEnum::kPass) {
+          for (const auto& pair : rule->Apply(this, state, stage_id)) {
+            cur_stage_id_map[pair.first] = pair.second;
+            pnext->push_back(pair.first);
+          }
+          // Skip the rest of the rules
+          if (rule_check == StructureSynthesisRule::ConditionEnum::kApplyAndSkipRest) {
+            break;
+          }
+        }
+      }
+    }
+
+    std::swap(pnow, pnext);
+  }
+
+  // Hack for rfactor: Replace the split factor for rfactor to the undefined Expr(),
+  // so later we can sample random value for the split factor.
+  // Why don't we use Expr() when doing the split for rfactor at the first time?
+  // Because during ApplySteps, a rfactor with undefined Expr() will crash TVM.
+ // So rfactor with undefined Expr() will conflict with cache_write, cache_read, rfactor + // in other stages + for (size_t i = 0; i < out_states->size(); ++i) { + auto pstate = (*out_states)[i].CopyOnWrite(); + for (size_t step_id = 0; step_id < pstate->transform_steps.size(); ++step_id) { + if (pstate->transform_steps[step_id]->IsInstance()) { + CHECK_GE(step_id, 1); + int split_step_id = step_id - 1; + auto step = pstate->transform_steps[split_step_id].as(); + CHECK(step != nullptr); + pstate->transform_steps[split_step_id] + = SplitStepNode::make(step->stage_id, step->iter_id, step->extent, {PrimExpr()}, + step->inner_to_outer); + } + } + } + + StdCout(verbose_) << "Synthesize Meta Structure\t\t#s: " << out_states->size() << std::endl; +} + +int InitPopulationFillTileSize(const MetaTileRewritePolicyNode* policy, + State* state, std::mt19937* rand_gen, + SplitFactorizationMemo* split_memo) { + for (size_t step_id = 0; step_id < (*state)->transform_steps.size(); ++step_id) { + if (auto ps = (*state)->transform_steps[step_id].as()) { + bool defined = true; + for (const PrimExpr& len : ps->lengths) { + if (!len.defined()) { + defined = false; + } + } + + if (defined) { + continue; + } + + int extent = GetIntImm(ps->extent); + const std::vector >& candidate_lens = + split_memo->GetFactorizationSchemes( + extent, ps->lengths.size(), + policy->cur_task_->hardware_params->max_innermost_split_factor); + + StateNode* pstate = state->CopyOnWrite(); + pstate->transform_steps[step_id] = SplitStepNode::make( + ps->stage_id, ps->iter_id, ps->extent, + candidate_lens[(*rand_gen)() % candidate_lens.size()], + ps->inner_to_outer); + } + } + + return 0; +} + +int InitPopulationThreadBind(const MetaTileRewritePolicyNode* policy, + State* state) { + for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { + const Stage& stage = (*state)->stages[stage_id]; + auto pop = stage->op.as(); + + if (stage->compute_at != kRoot || stage->op_type == kPlaceholder) { + continue; + } + + std::vector to_fuse; + + // This stage has not been tiled, but in GPU schedule, we must tile it + // to do thread binding + if (!HasSplitStep(*state, stage_id)) { + for (const auto& it : (*state)->stages[stage_id]->iters) { + if (it->iter_type == kReduce) { + break; + } + to_fuse.push_back(it); + } + const auto& fused_it = state->fuse(stage_id, to_fuse); + // Set default vthread=1 & threadIdx.x=default_warp_size + // EvolutionarySearch will try more possiblity + if (GetExtent(fused_it) <= + policy->cur_task_->hardware_params->warp_size) { + state->bind_thread(stage_id, fused_it, kThreadX); + } else { + const auto& split_its = state->split(stage_id, fused_it, + {1, policy->cur_task_->hardware_params->warp_size}); + state->bind_thread(stage_id, split_its[0], kBlockX); + state->bind_thread(stage_id, split_its[1], kVThread); + state->bind_thread(stage_id, split_its[2], kThreadX); + } + + continue; + } + + int total_space_extent = 1; + for (const auto& i : pop->root_iter_vars()) { + CHECK(i->dom.defined()); + const auto& pint = i->dom->extent.as(); + CHECK(pint); + total_space_extent *= pint->value; + } + + // TODO(..): Add ThreadBind support for rfactor + if (total_space_extent <= policy->cur_task_->hardware_params->warp_size) { + for (const auto& it : (*state)->stages[stage_id]->iters) { + if (it->iter_type == kReduce) { + break; + } + to_fuse.push_back(it); + } + const auto& fused_it = state->fuse(stage_id, to_fuse); + state->bind_thread(stage_id, fused_it, kThreadX); + + continue; + } + + // Fuse the outermost space 
tile as blockIdx + for (size_t i = 0; i < pop->axis.size(); i++) { + const auto& it = (*state)->stages[stage_id]->iters[i]; + if (!StringEndWith(it->name, ".0")) { + break; + } + to_fuse.push_back(it); + } + const auto& blockidx_it = state->fuse(stage_id, to_fuse); + state->bind_thread(stage_id, blockidx_it, kBlockX); + + // Fuse the second outermost space tile as vthread + to_fuse.clear(); + for (size_t i = 1; i < pop->axis.size() + 1; i++) { + const auto& it = (*state)->stages[stage_id]->iters[i]; + if (!StringEndWith(it->name, ".1")) { + break; + } + to_fuse.push_back((*state)->stages[stage_id]->iters[i]); + } + const auto& vthread_it = state->fuse(stage_id, to_fuse); + if (GetExtent(vthread_it) > + policy->cur_task_->hardware_params->max_vthread_extent) { + return -1; + } + state->bind_thread(stage_id, vthread_it, kVThread); + + // Fuse the third outermost space tile as threadIdx + to_fuse.clear(); + for (size_t i = 2; i < pop->axis.size() + 2; i++) { + const auto& it = (*state)->stages[stage_id]->iters[i]; + if (!StringEndWith(it->name, ".2")) { + break; + } + to_fuse.push_back((*state)->stages[stage_id]->iters[i]); + } + const auto& threadidx_it = state->fuse(stage_id, to_fuse); + if (GetExtent(threadidx_it) < + policy->cur_task_->hardware_params->warp_size) { + return -1; + } + state->bind_thread(stage_id, threadidx_it, kThreadX); + } + + return 0; +} + +int InitPopulationCooperativeFetching(const MetaTileRewritePolicyNode* policy, + State* state) { + for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { + // Do cooperative fetching with cache read stage + // For two stages: A -> B + // 1. A -> A_cache_read -> B + // * + // 2. A -> A_cache_write -> A_cache_read -> B + // * + if ((stage_id > 0 && HasCacheReadStage((*state), stage_id - 1) && + !HasCacheWriteStage((*state), stage_id - 1)) || + (stage_id > 1 && HasCacheReadStage((*state), stage_id - 2) && + HasCacheWriteStage((*state), stage_id - 2))) { + // Get spatial_split_step_ids from the root stage + std::unordered_set consumers; + std::vector spatial_split_step_ids; + const Stage& target_stage = (*state)->stages[stage_id]; + GetConsumers(policy->cur_task_, (*state), target_stage->op, &consumers); + CHECK_EQ(consumers.size(), 1); + int target_stage_id = OperationToStage(*consumers.begin(), (*state)); + GetSpaceSplitStepIds((*state), target_stage_id, &spatial_split_step_ids); + + // Fuse all axis to to do cooperative fetching + Iterator fused = state->fuse(stage_id, + (*state)->stages[stage_id]->iters); + // Left a vectorized cooperative fetching split placeholder + const auto& iters0 = state->split(stage_id, fused, {1}); + state->vectorize(stage_id, iters0[1]); + // Follow split to keep a same thread extent with the root stage + const auto& iters1 = state->follow_fused_split(stage_id, iters0[0], + spatial_split_step_ids, + 1, true); + state->bind_thread(stage_id, iters1[1], kThreadX); + } + } + + return 0; +} + +int InitPopulationChangeComputeLocation(const MetaTileRewritePolicyNode* policy, + State* state, std::mt19937* rand_gen) { + if(GetIntParam(policy->params, "disable_change_compute_location")) { + return 0; + } + + for (int stage_id = static_cast((*state)->stages.size()) - 1; stage_id >= 0; stage_id--) { + const Stage& stage = (*state)->stages[stage_id]; + + if (stage->op_type == kPlaceholder) { + continue; + } + + if (IsTiled(stage) || stage->compute_at == kInlined) { + continue; + } + + if (NeedsMultilevelTiling(policy->cur_task_, (*state), stage->op)) { + continue; + } + + std::unordered_set 
consumers; + + GetConsumers(policy->cur_task_, (*state), stage->op, &consumers); + if (consumers.empty()) { + continue; + } + + int target_stage_id; + if (consumers.size() == 1) { + target_stage_id = OperationToStage(*consumers.begin(), *state); + } else { + // check all consumers share a common root + int common_root_id = -1; + bool mismatch = false; + for (const auto& consumer : consumers) { + int consumer_stage_id = OperationToStage(consumer, *state); + int root_id = -1; + if ((*state)->stages[consumer_stage_id]->compute_at == kRoot) { + root_id = consumer_stage_id; + } else if ((*state)->stages[consumer_stage_id]->compute_at == kIter) { + root_id = (*state)->attach_map->stage_to_attach_iter.at(consumer_stage_id).first; + } else { + LOG(FATAL) << "Invalid case"; + } + + if (common_root_id == -1) { + common_root_id = root_id; + } else { + if (common_root_id != root_id) { + mismatch = true; + break; + } + } + } + + if (mismatch) { + continue; + } + target_stage_id = common_root_id; + } + + const Stage& target_stage = (*state)->stages[target_stage_id]; + std::set to_unroll_name_set; + if (target_stage->op->attrs.count(policy->always_unroll_key)) { + to_unroll_name_set = GetIterNameSetParam(target_stage->op->attrs, + policy->always_unroll_key); + } + + std::vector > candidates; + bool target_compute_at_other = target_stage->compute_at == kIter; + bool target_is_tiled = IsTiled(target_stage); + + bool visited_reduce = false; + // enumerate compute_at location at target_stage + int ct = 0; + for (const auto& target_iter : target_stage->iters) { + if (target_iter->iter_type == kReduce) { + visited_reduce = true; + if (!target_is_tiled) { // do not go into reduce iter + break; + } + } else if (target_iter->iter_type == kSpace) { + if (visited_reduce) { // do not go into inner tile + break; + } + } + + if (to_unroll_name_set.count(target_iter->name)) { + // Do not go into always unroll region + break; + } + + if (GetExtent(target_iter) == 1) { // skip iterators with length of 1 + continue; + } + if (target_compute_at_other && target_iter->iter_type == kSpace && + StrEndsWith(target_iter->name, ".0")) { + // skip the first level iterators if target stage compute_at another stage + // In this case, the lengths of first level iterators are always one + continue; + } + candidates.emplace_back(target_stage_id, target_iter); + + if ((*state)->attach_map->iter_to_attached_stages.count( + std::make_pair(target_stage_id, ct++))) { + break; + } + } + + // if the target_stage is already compute_at another stage X, try also compute_at X + // We call stage X as `target_target_stage` + if (target_compute_at_other) { + int target_target_stage_id; + target_target_stage_id = (*state)->attach_map->stage_to_attach_iter.at( + target_stage_id).first; + const Stage& target_target_stage = (*state)->stages[target_target_stage_id]; + if (target_target_stage->op->attrs.count(policy->always_unroll_key)) { + to_unroll_name_set = GetIterNameSetParam(target_target_stage->op->attrs, + policy->always_unroll_key); + } else { + to_unroll_name_set.clear(); + } + + int ct = 0; + for (const auto& target_target_iter : target_target_stage->iters) { + if (target_target_iter->iter_type == kReduce || + (*state)->attach_map->iter_to_attached_stages.count( + std::make_pair(target_target_stage_id, ct++))) { + break; + } + + if (to_unroll_name_set.count(target_target_iter->name)) { + // Do not go into always unroll region + break; + } + + if (GetExtent(target_target_iter) == 1) { // skip iterators with length of 1 + continue; + } + + 
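// Once the candidate list is assembled, the choice a few lines below draws
// uniformly from {compute_inline, compute_root} plus every (stage, iterator)
// candidate: choice 0 -> inline, 1 -> root, k >= 2 -> candidates[k - 2].
// A self-contained illustration (PickComputeLocation is a hypothetical helper):
#include <cstddef>
#include <string>

std::string PickComputeLocation(std::size_t num_candidates, unsigned rnd) {
  std::size_t choice = rnd % (num_candidates + 2);
  if (choice == 0) return "compute_inline";
  if (choice == 1) return "compute_root";
  return "compute_at(candidate " + std::to_string(choice - 2) + ")";
}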
candidates.push_back(std::make_pair(target_target_stage_id, target_target_iter)); + } + } + + int choice = (*rand_gen)() % (candidates.size() + 2); + + if (choice == 0) { + if (!HasReduceIter(stage)) { + state->compute_inline(stage_id); + } + } else if (choice == 1) { + state->compute_root(stage_id); + } else { + choice = choice - 2; + state->compute_at(stage_id, candidates[choice].first, candidates[choice].second); + } + } + + return 0; +} + +int InitPopulationParallel(const MetaTileRewritePolicyNode* policy, + State* state) { + std::function annotate_parallel; + + annotate_parallel = [&annotate_parallel]( + const MetaTileRewritePolicyNode* policy, State* state, int stage_id, int iter_offset) { + const Stage& stage = (*state)->stages[stage_id]; + + std::vector to_fuse; + int64_t parallel_degree = 1; + + // strategy: try to fuse and parallel the outermost n iterators + // Stop if we meet reduce iterator or we have enough parallel degree + size_t iter_id = iter_offset; + for (; iter_id < stage->iters.size(); ++iter_id) { + const Iterator& it = stage->iters[iter_id]; + if (it->iter_type == kReduce || it->annotation != kNone) { + break; + } + + to_fuse.push_back(it); + parallel_degree *= GetExtent(it); + + if (parallel_degree > policy->cur_task_->hardware_params->num_cores * 16) { + break; + } + + if ((*state)->attach_map->iter_to_attached_stages.count( + std::make_pair(stage_id, iter_id))) { + break; + } + } + + if (parallel_degree == 1) { + auto res = (*state)->attach_map->iter_to_attached_stages.find(std::make_pair(stage_id, iter_id)); + if (res != (*state)->attach_map->iter_to_attached_stages.end()) { + for (int attached_stage_id : res->second) { + annotate_parallel(policy, state, attached_stage_id, 0); + } + annotate_parallel(policy, state, stage_id, iter_id + 1); + } + } + + if (!to_fuse.empty()) { + if (to_fuse.size() == 1) { + state->parallel(stage_id, to_fuse[0]); + } else { + Iterator fused_iter = state->fuse(stage_id, to_fuse); + state->parallel(stage_id, fused_iter); + } + } + }; + + for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { + const Stage& stage = (*state)->stages[stage_id]; + if (stage->compute_at != kRoot || stage->op_type == kPlaceholder) { + continue; + } + + annotate_parallel(policy, state, stage_id, 0); + } + + return 0; +} + +int InitPopulationVectorization(const MetaTileRewritePolicyNode* policy, + State* state, std::mt19937* rand_gen) { + for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { + const Stage& stage = (*state)->stages[stage_id]; + + if (stage->op_type == kPlaceholder) { + continue; + } + + // Skip cooperative fetching stage + if (IS_GPU(policy->cur_task_) && + HasCacheReadStage((*state), stage_id - 1)) { + continue; + } + + // try to fuse and vectorize the space iterators in the inner most tile + int cum_length_prod = 1; + + std::set to_unroll_name_set; + if (stage->op->attrs.count(policy->always_unroll_key)) { + to_unroll_name_set = GetIterNameSetParam(stage->op->attrs, + policy->always_unroll_key); + } + + int num_fusible = 0; + while (num_fusible < static_cast(stage->iters.size())) { + int iter_id = static_cast(stage->iters.size()) - 1 - num_fusible; + if ((*state)->attach_map->iter_to_attached_stages.count( + std::make_pair(stage_id, iter_id))) { + break; + } + + const Iterator& it = stage->iters[iter_id]; + + // Stop if we meet a reduce iterator + if (it->iter_type == kReduce || it->annotation != kNone || + to_unroll_name_set.count(it->name)) { + break; + } + + // Stop if the memory access is not 
continuous (vectorizable) + // Note: The check is too hard, so we use heuristic here + if (IsTiled(stage) && num_fusible != 0) { + // If the stage is tiled, then the memory access must not be continuous + // for the innermost two iterators + break; + } + + cum_length_prod *= GetExtent(it); + if (cum_length_prod > policy->cur_task_->hardware_params->max_unroll_vec) { + break; + } + + num_fusible++; + } + + if (num_fusible > 1) { + num_fusible = 1 + (*rand_gen)() % (num_fusible - 1); // Select a random range to fuse + } + + if (num_fusible == 1) { + state->vectorize(stage_id, stage->iters.back()); + } else if (num_fusible > 1) { + std::vector to_fuse(stage->iters.end() - num_fusible, + stage->iters.end()); + state->vectorize(stage_id, state->fuse(stage_id, to_fuse)); + } + } + + return 0; +} + +int InitPopulationUnroll(const MetaTileRewritePolicyNode* policy, + State* state, std::mt19937* rand_gen) { + for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { + const Stage& stage = (*state)->stages[stage_id]; + + if (stage->op_type == kPlaceholder) { + continue; + } + + if (stage->op->attrs.count(policy->always_unroll_inner_key)) { + // Special unroll policy + auto to_unroll_name_set = GetIterNameSetParam(stage->op->attrs, + policy->always_unroll_inner_key); + std::set visited_names; + + // Unroll the space iterators and reduce iterators listed in the attrs + // in the innermost tile + int n = static_cast(stage->iters.size()) - 1; + visited_names.clear(); + while (n >= 0) { + const Iterator& it = stage->iters[n]; + + // If we meet two iterators that come from a same original iterator, + // then we are out of the innermost tile + size_t size_before = visited_names.size(); + ExtractOriginalIterators(it->name, &visited_names); + if (size_before == visited_names.size()) { + break; + } + + std::set name; + ExtractOriginalIterators(it->name, &name); + if (name.size() == 1 && to_unroll_name_set.count(*name.begin())) { + state->unroll(stage_id, it); + } + + n--; + } + } else if (stage->op->attrs.count(policy->always_unroll_key)) { + // Special unroll policy + auto to_unroll_name_set = GetIterNameSetParam(stage->op->attrs, + policy->always_unroll_key); + + // Unroll the space iterators and reduce iterators listed in the attrs + int n = static_cast(stage->iters.size()) - 1; + while (n >= 0) { + const Iterator& it = stage->iters[n]; + if (to_unroll_name_set.count(it->name)) { + state->unroll(stage_id, it); + } + n--; + } + } else if (HasReduceIter(stage)) { + // use auto unroll for multi level tiled stage + int value = policy->auto_unroll_configs[ + (*rand_gen)() % policy->auto_unroll_configs.size()]; + state->pragma(stage_id, (*state)->stages[stage_id]->iters[0], + std::string("auto_unroll_max_step") + "$" + std::to_string(value)); + } + } + + return 0; +} + +void MetaTileRewritePolicyNode::SampleInitPopulation(const std::vector& meta_structures, + int out_size, std::vector* out_states) { + std::uniform_real_distribution<> dis(0.0, 1.0); + int continue_count = 0; + + // TODO(...): Maybe try muti thread here + while (static_cast(out_states->size()) < out_size && + continue_count < out_size * 10) { + State tmp_s = meta_structures[rand_gen_() % meta_structures.size()]; + + InitPopulationFillTileSize(this, &tmp_s, &rand_gen_, &split_memo_); + + if (IS_GPU(cur_task_)) { + tmp_s = cur_task_->compute_dag.InferBound(tmp_s); + + if (InitPopulationThreadBind(this, &tmp_s)) { + continue_count++; + continue; + } + + InitPopulationCooperativeFetching(this, &tmp_s); + } else { + 
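// For orientation, the init-population pipeline around this branch has a fixed
// shape: GPU candidates get thread binding plus cooperative fetching, CPU
// candidates get compute-location mutation plus parallel annotation, and both
// finish with vectorization and unrolling; a failing pass rejects the sample.
// A minimal sketch under these assumptions (Pass and ApplyInitPasses are
// illustrative, not names from this patch):
#include <functional>
#include <vector>

using Pass = std::function<int(int* /*toy state*/)>;  // 0 = ok, nonzero = reject

bool ApplyInitPasses(int* state, bool is_gpu,
                     const std::vector<Pass>& gpu_passes,
                     const std::vector<Pass>& cpu_passes,
                     const std::vector<Pass>& common_passes) {
  for (const Pass& p : is_gpu ? gpu_passes : cpu_passes) {
    if (p(state) != 0) return false;  // e.g. thread binding can fail, caller resamples
  }
  for (const Pass& p : common_passes) {
    if (p(state) != 0) return false;
  }
  return true;
}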
InitPopulationChangeComputeLocation(this, &tmp_s, &rand_gen_); + + tmp_s = cur_task_->compute_dag.InferBound(tmp_s); + + InitPopulationParallel(this, &tmp_s); + } + + InitPopulationVectorization(this, &tmp_s, &rand_gen_); + + InitPopulationUnroll(this, &tmp_s, &rand_gen_); + + out_states->push_back(std::move(tmp_s)); + } + + StdCout(verbose_) << "Sample Initial Population\t\t#s: " + << out_states->size() << std::endl; +} + +void MetaTileRewritePolicyNode::EvolutionarySearch( + const std::vector& init_population, + int num_best_states, std::vector* best_states) { + auto tic_begin = std::chrono::high_resolution_clock::now(); + + // Set parameters for genetic algorithm + int population = GetIntParam(params, "evolutionary_search_population"); + int num_iters = GetIntParam(params, "evolutionary_search_num_iters"); + double mutation_prob = GetDoubleParam(params, "evolutionary_search_mutation_prob"); + int num_cross_over = static_cast(population * 0.0); // NOT IMPLEMENTED currently + int num_cross_over_trial_upper_bound = num_cross_over * 3; + CostModel cost_model = program_cost_model; + + // Two ping pong buffers to avoid copy + std::vector states_buf1, states_buf2; + std::vector *pnow = &states_buf1, *pnext = &states_buf2; + states_buf1.reserve(population); + states_buf2.reserve(population); + states_buf1.insert(states_buf1.begin(), init_population.begin(), init_population.end()); + + // A heap to keep the best states during evolution + using StateItem = std::pair; + auto cmp = [](const StateItem& left, const StateItem& right) { + return left.second > right.second; + }; + std::vector heap; + std::unordered_set in_heap(measured_states_set_); + heap.reserve(num_best_states); + + // auxiliary global variables + std::vector scores; + std::vector prefix_sum_probs; + double max_score = 0.0; + scores.reserve(population); + prefix_sum_probs.reserve(population); + std::uniform_real_distribution<> dis(0.0, 1.0); + int mutation_fail_ct = 0; + + // Genetic Algorithm + for (int k = 0; k < num_iters + 1; ++k) { + // Maintain the heap + cur_task_->compute_dag.InferBound(pnow); + PruneUndefined(pnow); + cost_model->Predict(cur_task_, *pnow, &scores); + + for (size_t i = 0; i < pnow->size(); ++i) { + const State& state = (*pnow)[i]; + std::string state_str = state.ToStr(); + + if (in_heap.count(state_str) == 0) { + if (static_cast(heap.size()) < num_best_states) { + heap.emplace_back((*pnow)[i], scores[i]); + std::push_heap(heap.begin(), heap.end(), cmp); + in_heap.insert(state_str); + } else if (scores[i] > heap.front().second) { + std::string old_state_str = heap.front().first.ToStr(); + in_heap.erase(old_state_str); + in_heap.insert(state_str); + + std::pop_heap(heap.begin(), heap.end(), cmp); + heap.back() = StateItem(state, scores[i]); + std::push_heap(heap.begin(), heap.end(), cmp); + } + if (scores[i] > max_score) { + max_score = scores[i]; + } + } + } + + if (k % 5 == 0 || k == num_iters) { + StdCout(verbose_) << "GA Iter: " << k << std::fixed << std::setprecision(4) + << "\tMax score: " << max_score + << "\tMin score: " << heap.front().second + << "\tPop size: " << pnow->size() << std::endl; + } + + if (k == num_iters) { + break; + } + + // Compute selection probability + double sum = 0.0; + prefix_sum_probs.resize(scores.size()); + for (size_t i = 0; i < scores.size(); ++i) { + sum += std::max(scores[i], 0.0f); + prefix_sum_probs[i] = sum; + } + for (size_t i = 0; i < scores.size(); ++i) { + prefix_sum_probs[i] = prefix_sum_probs[i] / sum; + } + + // Do cross over + int ct = 0; + while 
(static_cast(pnext->size()) < num_cross_over + && ct < num_cross_over_trial_upper_bound) { + int p1 = RandomChoose(prefix_sum_probs, &rand_gen_); + int p2 = RandomChoose(prefix_sum_probs, &rand_gen_); + + if (p1 == p2) { + pnext->push_back((*pnow)[p1]); + } else { + State tmp_s = CrossOverState((*pnow)[p1], (*pnow)[p2]); + if (tmp_s.defined()) { + pnext->push_back(std::move(tmp_s)); + } + } + ct++; + } + + // Do mutation + mutation_fail_ct = 0; + while (static_cast(pnext->size()) < population) { + int id = RandomChoose(prefix_sum_probs, &rand_gen_); + + if (dis(rand_gen_) < mutation_prob) { + const std::vector rule_prefix_sum_probs{0.9, 0.95, 1.0}; + + int rule_id = RandomChoose(rule_prefix_sum_probs, &rand_gen_); + + State tmp_s; + if (rule_id == 0) { + tmp_s = RandomMutateTileSize((*pnow)[id], &split_memo_, &rand_gen_, + cur_task_->hardware_params->max_innermost_split_factor); + } else if (rule_id == 1) { + tmp_s = RandomMutateMaxUnrollStep((*pnow)[id], &rand_gen_, auto_unroll_configs); + } else if (rule_id == 2) { + tmp_s = MutataParallel((*pnow)[id], &split_memo_, &rand_gen_, cur_task_); + } + + if (tmp_s.defined()) { + pnext->push_back(std::move(tmp_s)); + } else { + mutation_fail_ct++; + } + } else { + pnext->push_back((*pnow)[id]); + } + } + + std::swap(pnext, pnow); pnext->clear(); + } + + // Copy best states in the heap to out_states + std::sort(heap.begin(), heap.end(), cmp); + best_states->clear(); + for (auto& item : heap) { + best_states->push_back(std::move(item.first)); + } + + double duration = std::chrono::duration_cast >( + std::chrono::high_resolution_clock::now()- tic_begin).count(); + StdCout(verbose_) << "EvolutionarySearch\t\t#s: " << best_states->size() + << "\tTime elapsed: " + << std::fixed << std::setprecision(2) << duration << std::endl; +} + +} // namespace ansor +} // namespace tvm diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.h b/src/ansor/search_policy/meta_tile_rewrite_policy.h new file mode 100644 index 000000000000..56a75f8e52fe --- /dev/null +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.h @@ -0,0 +1,101 @@ +/*! + * Copyright (c) 2020 by Contributors + * \file ansor/meta_tile_rewrite_policy.h + * \brief A search policy that search with meta tiling structure and random rewrite + */ +#ifndef TVM_ANSOR_SEARCH_POLICY_META_TILE_REWRITE_POLICY_H_ +#define TVM_ANSOR_SEARCH_POLICY_META_TILE_REWRITE_POLICY_H_ + +#include +#include +#include +#include +#include +#include "search_policy.h" +#include "../cost_model/cost_model.h" +#include "../utils.h" + + +namespace tvm { +namespace ansor { + +/*! 
Multi stage search policy */ +class MetaTileRewritePolicyNode: public SearchPolicyNode { + public: + CostModel program_cost_model; + + /* this->params is used to store the following arguments + * int evolutionary_search_population // The population size for evolutionary search + * int evolutionary_search_mutation_prob // The probability of mutation for evolutionary search + * int evolutionary_search_num_iters; // The number of iterations for evolutionary search + * double local_mutation_use_measured_ratio; // The maximum percentage of measured states in the initial + * // population for evolutionary search + * double eps_greedy; // Always allocate this percentage of measurements to random sampled states + * str cpu_multi_level_tiling_structure // The structure of multi-level tiling for CPU + * str gpu_multi_level_tiling_structure // The structure of multi-level tiling for GPU + */ + Map params; + + static SearchPolicy make(CostModel program_cost_model, + Map params, + int seed); + + // Search and make n_trails measurements + // Return the best state + State Search(SearchTask task, int n_trials, + int early_stopping, int num_measure_per_iter, + int verbose, ProgramMeasurer measurer) final; + + // Continue search. This is used by JointTuner + std::pair, Array > ContinueSearchOneRound( + SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) final; + + static constexpr const char *_type_key = "ansor.MetaTileRewritePolicy"; + static const std::vector auto_unroll_configs; + + TVM_DECLARE_FINAL_OBJECT_INFO(MetaTileRewritePolicyNode, SearchPolicyNode); + + SearchTask cur_task_; // The current task + + friend class MetaTileRewritePolicyNodeTest; // Hack friend class for UT + protected: + // Pick states from best states and random states with eps-greedy policy + void PickStatesWithEpsGreedy(std::vector* inputs, + const std::vector& best_states, + const std::vector& random_states, int remaining_n_trials); + + private: + // Run one round of the search pipeline + void SearchOneRound(std::vector* best_states, + int num_random_states, std::vector* random_states); + + // Synthesize meta tiling structure without tile size + void SynthesizeMetaStructure(std::vector* out_states); + + // Sample init population + void SampleInitPopulation(const std::vector& meta_structures, + int out_size, std::vector* out_states); + + // Perform evolutionary search + void EvolutionarySearch(const std::vector& init_population, + int num_best_states, std::vector* best_states); + + SplitFactorizationMemo split_memo_; // Memorize split space for Split + std::mt19937 rand_gen_; // Random generator + int verbose_; // Verbose level (0 means silent) + int num_measure_per_iter_; // The number of states to measure per iteration + + // The set of the already measured states. We store the string format for redundancy check + std::unordered_set measured_states_set_; + + // The array of already measured states. + std::vector measured_states_vector_; + + // The throughputs of already measured states + std::vector measured_states_throughputs_; +}; + +} // namespace ansor +} // namespace tvm + +#endif // TVM_ANSOR_SEARCH_POLICY_META_TILE_REWRITE_POLICY_H_ diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc new file mode 100644 index 000000000000..89bfeb1a8edd --- /dev/null +++ b/src/ansor/search_policy/search_policy.cc @@ -0,0 +1,14 @@ +/*! 
+ * Copyright (c) 2020 by Contributors + */ + +#include "search_policy.h" + +namespace tvm { +namespace ansor { + +TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode); + +} // namespace ansor +} // namespace tvm + diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h new file mode 100644 index 000000000000..5bd9fb3118b1 --- /dev/null +++ b/src/ansor/search_policy/search_policy.h @@ -0,0 +1,53 @@ +/*! + * Copyright (c) 2020 by Contributors + * \file ansor/search_policy.h + * \brief Base class of search policy + */ +#ifndef TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ +#define TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ + +#include +#include +#include +#include +#include "../search_task.h" +#include "../measure.h" + +namespace tvm { +namespace ansor { + +class SearchPolicy; + +/*! \brief The base class for search policy */ +class SearchPolicyNode : public Object { + public: + virtual State Search(SearchTask task, int n_trials, + int early_stopping, int num_measure_per_iter, + int verbose, ProgramMeasurer measurer) = 0; + + virtual std::pair, Array > ContinueSearchOneRound( + SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) = 0; + + // Dict keys + static constexpr const char* always_unroll_inner_key = "ansor_always_unroll_inner"; + static constexpr const char* always_unroll_key = "ansor_always_unroll"; + static constexpr const char* no_split_at_inner_key = "ansor_no_split_at_inner"; + static constexpr const char* no_split_at_outer_key = "ansor_no_split_at_outer"; + static constexpr const char* debug_skip_region_key = "ansor_debug_skip_region"; + static constexpr const char* last_split_is_one_key = "ansor_last_split_is_one"; + + // Flag keys + static constexpr const char* always_compute_inline_key = "ansor_always_compute_inline"; + static constexpr const char* no_cache_write_key = "ansor_no_cache_write"; + static constexpr const char* no_cache_read_key = "ansor_no_cache_read"; + static constexpr const char* tensor_core_support_key = "ansor_tensor_core_support"; + + static constexpr const char *_type_key = "ansor.SearchPolicy"; + TVM_DECLARE_BASE_OBJECT_INFO(SearchPolicyNode, Object); +}; +TVM_DEFINE_MUTABLE_NODE_REF(SearchPolicy, SearchPolicyNode); + +} // namespace ansor +} // namespace tvm + +#endif // TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ diff --git a/src/ansor/search_policy/utils.cc b/src/ansor/search_policy/utils.cc new file mode 100644 index 000000000000..9c597b4eb811 --- /dev/null +++ b/src/ansor/search_policy/utils.cc @@ -0,0 +1,609 @@ +/*! 
+ * Copyright (c) 2020 by Contributors + */ + +#include "utils.h" +#include "search_policy.h" + +namespace tvm { +namespace ansor { + +void GetSpaceSplitStepIds(const State& s, int stage_id, std::vector* spatial_split_step_ids) { + auto pop = s->stages[stage_id]->op.as(); + CHECK(pop != nullptr); + + auto no_split_name_pair = QueryNoSplitAxis(s->stages[stage_id]); + std::set no_split_at_inner_name_set = no_split_name_pair.first; + std::set no_split_at_outer_name_set = no_split_name_pair.second; + size_t reduce_count = 0; + for (const auto axis : pop->reduce_axis) { + if (!no_split_at_inner_name_set.count(axis->var->name_hint) && + !no_split_at_outer_name_set.count(axis->var->name_hint)) { + reduce_count++; + } + } + + for (int i = static_cast(s->transform_steps.size()) - 1; i >= 0; --i) { + if (s->transform_steps[i]->IsInstance() || + s->transform_steps[i]->IsInstance() || + s->transform_steps[i]->IsInstance()) { + if (stage_id > s->transform_steps[i]->stage_id) { + stage_id--; + } + } else if (auto ps = s->transform_steps[i].as()) { + if (stage_id == ps->stage_id) { + if (reduce_count) { + reduce_count--; + } else { + spatial_split_step_ids->push_back(i); + } + } + } + } +} + +// Query axes that should not be splitted according to the attribute from tvm.compute +std::pair, std::set > QueryNoSplitAxis(const Stage& stage) { + std::pair, std::set > ret; + if (stage->op->attrs.count(SearchPolicyNode::no_split_at_inner_key)) { + ret.first = GetIterNameSetParam(stage->op->attrs, SearchPolicyNode::no_split_at_inner_key); + } + if (stage->op->attrs.count(SearchPolicyNode::no_split_at_outer_key)) { + ret.second = GetIterNameSetParam(stage->op->attrs, SearchPolicyNode::no_split_at_outer_key); + } + return ret; +} + +// Query axes that last split is one +std::set QueryLastSplitIsOneAxis(const Stage& stage) { + std::set ret; + if (stage->op->attrs.count(SearchPolicyNode::last_split_is_one_key)) { + ret = GetIterNameSetParam(stage->op->attrs, SearchPolicyNode::last_split_is_one_key); + } + return ret; +} + +// Apply multi-tiling structure according to a string format +State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format, + std::vector* spatial_split_step_ids) { + std::vector > space_levels; + std::vector > reduce_levels; + std::vector space_outer, space_inner, reduce_outer, reduce_inner; + std::vector split_res; + + for (const auto c : format) { + if (tolower(c) == 's') { + space_levels.emplace_back(); + } else if (tolower(c) == 'r') { + reduce_levels.emplace_back(); + } else { + LOG(FATAL) << "Invalid multi level tiling format: " << format; + } + } + size_t n_space = space_levels.size(); + size_t n_reduce = reduce_levels.size(); + + spatial_split_step_ids->clear(); + + State tmp_s = state; + const Stage& stage = state->stages[stage_id]; + auto no_split_name_pair = QueryNoSplitAxis(stage); // handle special split strategy + auto last_split_is_one_name_set = QueryLastSplitIsOneAxis(stage); + std::set no_split_at_inner_name_set = no_split_name_pair.first; + std::set no_split_at_outer_name_set = no_split_name_pair.second; + + for (const auto& iter : state->stages[stage_id]->iters) { + if (iter->iter_type == kSpace) { + if (!no_split_at_inner_name_set.count(iter->name) && + !no_split_at_outer_name_set.count(iter->name)) { + CHECK_GE(n_space, 1); + int tmp_n_space = n_space; + + if (last_split_is_one_name_set.count(iter->name)) { + tmp_n_space--; + } + + if (tmp_n_space == 1) { + space_levels[0].push_back(iter); + } else { + split_res = tmp_s.split(stage_id, iter, 
std::vector(tmp_n_space - 1)); + for (int i = 0; i < tmp_n_space; i++) { + space_levels[i].push_back(std::move(split_res[i])); + } + spatial_split_step_ids->push_back(tmp_s->transform_steps.size() - 1); + } + } else { + if (no_split_at_inner_name_set.count(iter->name)) { + space_inner.push_back(iter); + } + if (no_split_at_outer_name_set.count(iter->name)) { + space_outer.push_back(iter); + } + } + } else if (iter->iter_type == kReduce) { + // for reduce iterator, split it into two iterators + if (!no_split_at_inner_name_set.count(iter->name) && + !no_split_at_outer_name_set.count(iter->name)) { + CHECK_GE(n_reduce, 1); + if (n_reduce == 1) { + reduce_levels[0].push_back(iter); + } else { + split_res = tmp_s.split(stage_id, iter, std::vector(n_reduce - 1)); + for (size_t i = 0; i < n_reduce; i++) { + reduce_levels[i].push_back(std::move(split_res[i])); + } + } + } else { + if (no_split_at_inner_name_set.count(iter->name)) { + reduce_inner.push_back(iter); + } + if (no_split_at_outer_name_set.count(iter->name)) { + reduce_outer.push_back(iter); + } + } + } else { + LOG(FATAL) << "Invalid iter type: " << iter->iter_type; + } + } + + if (!space_outer.empty()) { + CHECK(!space_levels.empty()); + space_levels.front().insert(space_levels.front().begin(), + space_outer.begin(), space_outer.end()); + } + if (!space_inner.empty()) { + CHECK(!space_levels.empty()); + space_levels.back().insert(space_levels.back().begin(), + space_inner.begin(), space_inner.end()); + } + + if (!reduce_outer.empty()) { + CHECK(!reduce_levels.empty()); + reduce_levels.front().insert(reduce_levels.front().begin(), + reduce_outer.begin(), reduce_outer.end()); + } + if (!reduce_inner.empty()) { + CHECK(!reduce_levels.empty()); + reduce_levels.back().insert(reduce_levels.back().begin(), + reduce_inner.begin(), reduce_inner.end()); + } + + std::vector order; + int space_ct = 0, reduce_ct = 0; + for (const auto c : format) { + if (tolower(c) == 's') { + order.insert(order.end(), std::make_move_iterator(space_levels[space_ct].begin()), + std::make_move_iterator(space_levels[space_ct].end())); + space_ct++; + } else if (tolower(c) == 'r') { + order.insert(order.end(), std::make_move_iterator(reduce_levels[reduce_ct].begin()), + std::make_move_iterator(reduce_levels[reduce_ct].end())); + reduce_ct++; + } else { + LOG(FATAL) << "Invalid multi level tiling format: " << format; + } + } + + tmp_s.reorder(stage_id, order); + return tmp_s; +} + +// Apply tiling structure: space, space +// But use tile sizes from other SplitStep +State FollowTiling(const State& state, int stage_id, + const std::vector& split_step_ids, int n_split) { + if (n_split < 1 || n_split > 3) { + LOG(FATAL) << "Invalid split parts, currently only support 1, 2 and 3"; + } + // Apply up to three-level tiling structure: space_L0, space_L1, space_L2 + std::vector space_0, space_1, space_2, space_3; + std::vector split_res, tmp_order; + + auto pop = state->stages[stage_id]->op.as(); + CHECK(pop != nullptr); + const Stage& stage = state->stages[stage_id]; + auto no_split_name_pair = QueryNoSplitAxis(stage); // handle special split strategy + const std::set& no_split_at_inner_name_set = no_split_name_pair.first; + const std::set& no_split_at_outer_name_set = no_split_name_pair.second; + int no_split_at_inner_name_in_stage_cnt = 0; + int no_split_at_outer_name_in_stage_cnt = 0; + for (const auto& iter : state->stages[stage_id]->iters) { + no_split_at_inner_name_in_stage_cnt += no_split_at_inner_name_set.count(iter->name); + no_split_at_outer_name_in_stage_cnt += 
no_split_at_outer_name_set.count(iter->name); + } + + CHECK_EQ(state->stages[stage_id]->iters.size() + - no_split_at_inner_name_in_stage_cnt + - no_split_at_outer_name_in_stage_cnt, + split_step_ids.size()); + + State tmp_s = state; + int ct = 0; + for (const auto& iter : state->stages[stage_id]->iters) { + if (iter->iter_type == kSpace) { + // For spatial iterator, split it into multi iterators + if (!no_split_at_inner_name_set.count(iter->name) && + !no_split_at_outer_name_set.count(iter->name)) { + IteratorAnnotation ann_type = iter->annotation; + split_res = tmp_s.follow_split(stage_id, iter, split_step_ids[ct], + n_split); + // Restore annotation. Move unroll and vectorize to inner, move parallel + // to outer + switch (ann_type) { + case kUnroll: + split_res[n_split] = tmp_s.unroll(stage_id, split_res[n_split]); + break; + case kVectorize: + split_res[n_split] = tmp_s.vectorize(stage_id, split_res[n_split]); + break; + case kParallel: + split_res[0] = tmp_s.parallel(stage_id, split_res[0]); break; + default: + break; + } + + space_0.push_back(std::move(split_res[0])); + space_1.push_back(std::move(split_res[1])); + if (n_split >= 2) { + space_2.push_back(std::move(split_res[2])); + if (n_split == 3) { + space_3.push_back(std::move(split_res[3])); + } + } + ct++; + } else { + if (no_split_at_outer_name_set.count(iter->name)) { + space_0.push_back(iter); + } + if (no_split_at_inner_name_set.count(iter->name)) { + if (n_split == 1) { + space_1.push_back(iter); + } else if (n_split == 2) { + space_2.push_back(iter); + } else { + CHECK_EQ(n_split, 3); + space_3.push_back(iter); + } + } + } + } else { + LOG(FATAL) << "Invalid iter type: " << iter->iter_type; + } + } + if (n_split == 3) { + ConcatenateMove(&tmp_order, &space_0, &space_1, &space_2, &space_3); + } else if (n_split == 2) { + ConcatenateMove(&tmp_order, &space_0, &space_1, &space_2); + } else { + ConcatenateMove(&tmp_order, &space_0, &space_1); + } + tmp_s.reorder(stage_id, tmp_order); + return tmp_s; +} + +// Randomly mutate the tile size of one SplitStep +State RandomMutateTileSize(const State& old_state, SplitFactorizationMemo* split_memo, + std::mt19937* random_gen, int max_innermost_split_factor) { + State tmp_s = old_state; + + // Extract all SplitStep + std::vector split_step_ids; + for (size_t i = 0; i < tmp_s->transform_steps.size(); ++i) { + if (auto ps = tmp_s->transform_steps[i].as()) { + if (ps->extent.defined() && ps->extent->IsInstance() && + GetIntImm(ps->lengths.back()) <= max_innermost_split_factor) { + split_step_ids.push_back(i); + } + } + } + if (split_step_ids.empty()) { + return State(); + } + + // Find a SplitStep with extent != 1 + int retry_ct = 0; + int64_t extent = 1; + int step_id; + const SplitStepNode* ps; + + do { + step_id = split_step_ids[(*random_gen)() % split_step_ids.size()]; + ps = tmp_s->transform_steps[step_id].as(); + CHECK(ps != nullptr); + extent = GetIntImm(ps->extent); + retry_ct += 1; + } while (retry_ct < static_cast(split_step_ids.size()) << 2 && extent == 1); + + if (extent == 1) { + return State(); + } + + // Mutate tile size + std::vector lengths(ps->lengths.size() + 1, 1); + for (int i = 0; i < static_cast(ps->lengths.size()); ++i) { + lengths[i + 1] = GetIntImm(ps->lengths[i]); + } + lengths[0] = extent / ElementProduct(lengths); + + std::vector random_perm; + RandomPermutation(lengths.size(), &random_perm, random_gen); + + for (size_t i = 0; i < random_perm.size(); ++i) { + size_t src_idx = random_perm[i]; + int length = lengths[src_idx]; + + if (length == 1) { + continue; 
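// The code below moves one factor between two tile levels while keeping the
// product of all lengths (the split's total extent) unchanged, so the schedule
// stays valid. A minimal numeric sketch (MoveFactor is illustrative):
#include <cassert>
#include <cstddef>
#include <vector>

void MoveFactor(std::vector<int>* lengths, std::size_t src, std::size_t dst, int factor) {
  assert((*lengths)[src] % factor == 0);  // factor must divide the source length
  (*lengths)[src] /= factor;
  (*lengths)[dst] *= factor;
  // e.g. {4, 8, 2} with src=1, dst=2, factor=4 becomes {4, 2, 8}; product stays 64.
}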
+ } + + // Divide one factor from lengths[src_idx] and multiply it to lengths[dst_idx] + size_t dst_idx = random_perm[(i + 1) % random_perm.size()]; + + const std::vector& factors = split_memo->GetFactors(length); + CHECK_GE(factors.size(), 1); + + int divide_factor; + if (dst_idx == lengths.size() - 1) { + // Maintain the restriction of hardware_params.max_innermost_split_factor + int max_factor_index = static_cast(factors.size()) - 1; + for (; max_factor_index >= 1; max_factor_index--) { + if (factors[max_factor_index] * lengths[dst_idx] <= max_innermost_split_factor) { + break; + } + } + if (max_factor_index == 0) { + // failed on this dst_idx, try next one + continue; + } + divide_factor = factors[1 + (*random_gen)() % (max_factor_index)]; + } else { + divide_factor = factors[1 + (*random_gen)() % (factors.size() - 1)]; + } + + std::vector new_lengths; + for (size_t j = 1; j < lengths.size(); ++j) { + if (j == src_idx) { + new_lengths.emplace_back(lengths[j] / divide_factor); + } else if (j == dst_idx) { + new_lengths.emplace_back(lengths[j] * divide_factor); + } else { + new_lengths.emplace_back(lengths[j]); + } + } + + CHECK_LE(GetIntImm(new_lengths.back()), max_innermost_split_factor); + + auto pstate = tmp_s.CopyOnWrite(); + pstate->transform_steps[step_id] = + SplitStepNode::make(ps->stage_id, ps->iter_id, ps->extent, new_lengths, ps->inner_to_outer); + return tmp_s; + } + + return State(); +} + +// Randomly mutate the value of one auto_unroll_max_step PragmaStep +State RandomMutateMaxUnrollStep(const State& old_state, std::mt19937* random_gen, + const std::vector& auto_unroll_configs) { + State tmp_s = old_state; + + // Extract all auto_unroll_max_step pragma steps. + std::vector annotate_steps; + for (size_t i = 0; i < old_state->transform_steps.size(); ++i) { + if (auto ps = tmp_s->transform_steps[i].as()) { + if (ps->pragma_type.find("auto_unroll_max_step") != std::string::npos) { + annotate_steps.push_back(i); + } + } + } + if (annotate_steps.empty()) { + return State(); + } + + // Randomly pick one step. + auto step_id = annotate_steps[(*random_gen)() % annotate_steps.size()]; + auto ps = tmp_s->transform_steps[step_id].as(); + auto val = std::to_string(auto_unroll_configs[(*random_gen)() % auto_unroll_configs.size()]); + + auto pstate = tmp_s.CopyOnWrite(); + pstate->transform_steps[step_id] = PragmaStepNode::make( + ps->stage_id, ps->iter_id, std::string("auto_unroll_max_step") + "$" + val); + return tmp_s; +} + +// Mutate a parallel loop. +State MutataParallel(const State& state, SplitFactorizationMemo* split_memo, + std::mt19937* random_gen, SearchTask& task, int verbose) { + // To make this mutation simple but promising, we only focus on a specific case that + // parallel was added to the outermost loop and the loop is generated by fusing other loops. + // In short, we mutate the step pattern of (fuse -> parallel). + + // Extract all parallel steps. + std::vector parallel_steps; + for (size_t s = 0; s < state->transform_steps.size(); ++s) { + auto ps = state->transform_steps[s].as(); + if (!ps || ps->annotation != kParallel) { + continue; + } + parallel_steps.push_back(s); + } + if (parallel_steps.size() == 0) { + StdCout(verbose) << "Parallel mutation failed: No parallel annotations" << std::endl; + return State(); + } + + // Randomly pick one step. 
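+  // A hypothetical example of the two directions mutated below: if the
+  // outermost parallel loop came from fuse([i0, j0, k0]), "fuse less"
+  // replays the trace with fuse([i0, j0]) while "fuse more" replays it
+  // with fuse([i0, j0, k0, i1]), shrinking or growing the parallel
+  // granularity by one loop level.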
+  int retry_ct = 0;
+  size_t step_id = 0;
+  size_t stage_id = 0;
+  do {
+    step_id = parallel_steps[(*random_gen)() % parallel_steps.size()];
+    auto step = state->transform_steps[step_id].as<AnnotationStepNode>();
+    stage_id = step->stage_id;
+
+    // Check assumptions.
+    auto iter_id = step->iter_id;
+    if (iter_id == 0 && step_id > 0 && state->transform_steps[step_id - 1].as<FuseStepNode>()) {
+      break;
+    }
+    retry_ct++;
+  } while (retry_ct <= 3);
+
+  if (retry_ct > 3) {
+    StdCout(verbose) << "Parallel mutation failed: No valid parallel annotations" << std::endl;
+    return State();
+  }
+
+  // 0: fuse less; 1: fuse more.
+  std::vector<double> fuse_dir = {0.5, 1.0};
+
+  // The iter is an attach target, so we can only fuse less.
+  if (state->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, 0)) > 0) {
+    fuse_dir[0] = 1.0;
+  }
+
+  // Determine the fuse direction.
+  auto fuse_step = state->transform_steps[step_id - 1].as<FuseStepNode>();
+  std::vector<int> fused_ids = fuse_step->fused_ids;
+  int iter_offset = 0;
+  if (RandomChoose(fuse_dir, random_gen) == 0) {
+    StdCout(verbose) << "Parallel mutation: release iter " << fused_ids.back() << std::endl;
+    fused_ids.pop_back();
+    iter_offset = 1;
+  } else {
+    StdCout(verbose) << "Parallel mutation: include iter " << fused_ids.back() + 1 << std::endl;
+    fused_ids.push_back(fused_ids.back() + 1);
+    iter_offset = -1;
+  }
+
+  // Replay a new state.
+  State tmp_s = task->compute_dag.GetInitState();
+  for (size_t s = 0; s < state->transform_steps.size(); ++s) {
+    auto step = state->transform_steps[s];
+    if (s == step_id - 1) {
+      step = FuseStepNode::make(step->stage_id, fused_ids);
+    } else if (s > step_id && step->stage_id == static_cast<int>(stage_id)) {
+      // Since we changed the loop structure, iterator IDs in later steps
+      // applied to the same stage have to be adjusted.
+      auto ps = step.as<AnnotationStepNode>();
+      if (ps) {
+        CHECK_LE(ps->iter_id + iter_offset, tmp_s->stages[stage_id]->iters.size());
+        step = AnnotationStepNode::make(ps->stage_id, ps->iter_id + iter_offset, ps->annotation);
+      } else {
+        StdCout(verbose) << "Parallel mutation: Cannot apply " << step << " after fuse"
+                         << std::endl;
+        return State();
+      }
+    }
+    tmp_s.CopyOnWrite()->transform_steps.push_back(step);
+    tmp_s.DoStep(step, task->compute_dag);
+  }
+  return tmp_s;
+}
+
+// Create all possible tile size states for all SplitStep
+void GridMutateTileSize(const State& old_state, std::vector<State>* cands,
+                        SplitFactorizationMemo* split_memo, int max_innermost_split_factor) {
+  // Extract all SplitStep.
+  std::vector<size_t> split_step_ids;
+  for (size_t i = 0; i < old_state->transform_steps.size(); ++i) {
+    if (old_state->transform_steps[i]->IsInstance<SplitStepNode>()) {
+      split_step_ids.push_back(i);
+    }
+  }
+  if (split_step_ids.empty()) {
+    return;
+  }
+
+  // Move tile sizes and generate candidates.
+  for (size_t step_id : split_step_ids) {
+    const SplitStepNode* ps = old_state->transform_steps[step_id].as<SplitStepNode>();
+    CHECK(ps != nullptr);
+
+    int extent = GetIntImm(ps->extent);
+    if (extent == 1) {
+      continue;
+    }
+
+    // Get the current tile sizes.
+    std::vector<int> lengths(ps->lengths.size(), 1);
+    for (int i = 0; i < static_cast<int>(ps->lengths.size()); ++i) {
+      lengths[i] = GetIntImm(ps->lengths[i]);
+    }
+
+    const std::vector<int>& const_factors = split_memo->GetFactors(extent);
+    CHECK_GE(const_factors.size(), 1);
+
+    // Move tile size.
+    for (size_t i = 0; i < ps->lengths.size(); ++i) {
+      int old_length = lengths[i];
+
+      for (int factor : const_factors) {
+        if (i == ps->lengths.size() - 1 && factor > max_innermost_split_factor) {
+          // Limit the innermost factor.
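+          // E.g., assuming GetFactors returns factors in ascending order:
+          // with extent = 512 and max_innermost_split_factor = 16, the scan
+          // stops at the first factor above 16 for the innermost dimension.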
+          break;
+        }
+
+        // Make new length expressions and a new state.
+        std::vector<PrimExpr> length_exprs;
+        lengths[i] = factor;
+        int outermost = extent / ElementProduct(lengths);
+        if (outermost == 0) {
+          break;
+        }
+
+        // std::cout << "Mutated extent " << extent << ": " << outermost;
+        for (size_t j = 0; j < lengths.size(); ++j) {
+          // std::cout << ", " << lengths[j];
+          length_exprs.emplace_back(lengths[j]);
+        }
+        // std::cout << std::endl;
+
+        State tmp_s = old_state;
+        const SplitStepNode* new_ps = tmp_s->transform_steps[step_id].as<SplitStepNode>();
+        auto pstate = tmp_s.CopyOnWrite();
+        pstate->transform_steps[step_id] =
+            SplitStepNode::make(new_ps->stage_id, new_ps->iter_id, new_ps->extent, length_exprs,
+                                new_ps->inner_to_outer);
+        if (tmp_s.defined()) {
+          cands->push_back(std::move(tmp_s));
+        }
+      }
+      lengths[i] = old_length;
+    }
+  }
+}
+
+// Randomly choose an index according to a prefix sum probability
+int RandomChoose(const std::vector<double>& prefix_sum_probs, std::mt19937* random_gen) {
+  std::uniform_real_distribution<> dis(0.0, 1.0);
+  double x = dis(*random_gen);
+
+  CHECK(!prefix_sum_probs.empty());
+
+  return std::lower_bound(prefix_sum_probs.begin(), prefix_sum_probs.end(), x) -
+         prefix_sum_probs.begin();
+}
+
+// Prune undefined states.
+void PruneUndefined(std::vector<State>* states) {
+  size_t pt = 0;
+  for (size_t i = 0; i < states->size(); ++i) {
+    if (!(*states)[i].defined()) {
+      continue;
+    }
+    (*states)[pt++] = std::move((*states)[i]);
+  }
+
+  if (pt == 0) {
+    LOG(FATAL) << "All states are undefined.";
+  } else {
+    states->resize(pt);
+  }
+}
+
+State CrossOverState(const State& p1, const State& p2) { return State(); }
+
+} // namespace ansor
+} // namespace tvm
+
diff --git a/src/ansor/search_policy/utils.h b/src/ansor/search_policy/utils.h
new file mode 100644
index 000000000000..05b50775b52d
--- /dev/null
+++ b/src/ansor/search_policy/utils.h
@@ -0,0 +1,428 @@
+/*!
+ * Copyright (c) 2020 by Contributors + * \file ansor/search_policy/utils.h + * \brief Common utilities for local mutation in search policy + */ + +#ifndef TVM_ANSOR_SEARCH_POLICY_UTILS_H_ +#define TVM_ANSOR_SEARCH_POLICY_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include "../cost_model/cost_model.h" +#include "../utils.h" +#include "search_policy.h" + +namespace tvm { +namespace ansor { + +inline bool StringEndWith(const std::string& str, const std::string& target) { + int str_len = str.length(); + int target_len = target.length(); + if (str_len <= target_len) { + return false; + } + return str.compare(str_len - target_len, target_len, target) == 0; +} + +// Get an integer from a tvm str Map +inline int GetIntParam(const Map& attr_dict, + const std::string& key) { + CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; + auto pint = attr_dict[key].as(); + CHECK(pint != nullptr); + return pint->value; +} + +// Get a double from a tvm str Map +inline double GetDoubleParam(const Map& attr_dict, + const std::string& key) { + CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; + auto pdouble = attr_dict[key].as(); + CHECK(pdouble != nullptr); + return pdouble->value; +} + +// Get a string from a tvm str Map +inline std::string GetStringParam(const Map& attr_dict, + const std::string& key) { + CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; + auto pstr = attr_dict[key].as(); + CHECK(pstr != nullptr); + return pstr->value; +} + +// Get a iterator name set from a tvm str Map +inline std::set GetIterNameSetParam(const Map& attr_dict, + const std::string& key) { + std::set ret; + CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; + auto names = attr_dict[key].as(); + CHECK(names != nullptr); + for (auto name = names->begin(); name != names->end(); name++) { + ret.insert(name->as()->value); + } + return ret; +} + +// Convert operation to stage id +inline int OperationToStage(const te::Operation& op, const State& state) { + for (size_t i = 0; i < state->stages.size(); ++i) { + if (op == state->stages[i]->op) { + return i; + } + } + LOG(FATAL) << "Cannot find op: " << op; + return -1; +} + +// Return the extent of an iterator +inline int64_t GetExtent(const Iterator& it) { + if (it->range.defined()) { + if (auto pint = it->range->extent.as()) { + return pint->value; + } + } + return -1; +} + +// Return whether an op is strict inlineable +inline bool IsStrictInlineable(const SearchTask& task, const State& state, const te::Operation& op) { + if (state->task_dag.defined()) { + return state->task_dag->access_analyzer.IsStrictInlineable(op); + } else { + return task->compute_dag->access_analyzer.IsStrictInlineable(op); + } +} + +// Return whether an op is an output op +inline bool IsOutputOp(const SearchTask& task, const State& state, const te::Operation& op) { + if (state->task_dag.defined()) { + return state->task_dag->access_analyzer.IsOutput(op); + } else { + return task->compute_dag->access_analyzer.IsOutput(op); + } +} + +// Return whether the stage has an attribute flag +inline bool HasAttrsFlag(const State& state, int stage_id, const char* target) { + if (state->stages[stage_id]->op->attrs.count(target)) { + return GetStringParam(state->stages[stage_id]->op->attrs, target) == "True"; + } + return false; +} + +// Return whether the stage has reduce iterators +inline bool HasReduceIter(const Stage& stage) { + 
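+  // Any iterator that is not a pure space (data-parallel) iterator is
+  // treated as evidence of a reduction.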
for (const auto& iter : stage->iters) { + if (iter->iter_type != kSpace) { + return true; + } + } + return false; +} + +// Return whether an op needs multi level tiling +inline bool NeedsMultilevelTiling(const SearchTask& task, const State& state, const te::Operation& op) { + if (state->task_dag.defined()) { + return state->task_dag->access_analyzer.NeedsMultiLevelTiling(op); + } else { + return task->compute_dag->access_analyzer.NeedsMultiLevelTiling(op); + } +} + +// Get all consumers for an op. This will take inline into consideration +inline void GetConsumers(const SearchTask& task, const State& state, const te::Operation& op, + std::unordered_set* consumers) { + if (state->task_dag.defined()) { + state->task_dag->access_analyzer.GetConsumers(state, op, consumers); + } else { + task->compute_dag->access_analyzer.GetConsumers(state, op, consumers); + } +} + +inline void GetProducers(const SearchTask& task, const State& state, const te::Operation& op, + std::unordered_set* producers) { + if (state->task_dag.defined()) { + state->task_dag->access_analyzer.GetProducers(state, op, producers); + } else { + task->compute_dag->access_analyzer.GetProducers(state, op, producers); + } +} + +// Return whether two ops are elementwise-matched +inline bool ElementwiseMatch(const SearchTask& task, const State& state, const te::Operation& op, + const te::Operation& target_op) { + if (state->task_dag.defined()) { + return state->task_dag->access_analyzer.ElementWiseMatch(op, target_op); + } else { + return task->compute_dag->access_analyzer.ElementWiseMatch(op, target_op); + } +} + +// Return whether the stage has only one consumer and they are elementwise-matched +inline bool HasSingleElementwiseMatchedConsumer(const SearchTask& task, + const State& state, const Stage& stage, + int* target_stage_id) { + std::unordered_set consumers; + + GetConsumers(task, state, stage->op, &consumers); + if (consumers.size() == 1) { + *target_stage_id = OperationToStage(*consumers.begin(), state); + const Stage& target_stage = state->stages[*target_stage_id]; + if (ElementwiseMatch(task, state, stage->op, target_stage->op) && + (!(HasReduceIter(stage) && HasReduceIter(target_stage)))) { + return true; + } + } + return false; +} + +// Return whether this stage needs rfactor +inline bool NeedsRfactor(const SearchTask& task, const State& state, const te::Operation& op) { + if (op->IsInstance()) { + // Compute the product of lengths of all space iters and all reduce iters + int64_t cum_space_len = 1, cum_reduce_len = 1; + int stage_id = OperationToStage(op, state); + for (const auto& iter : state->stages[stage_id]->iters) { + if (iter->iter_type == kSpace) { + cum_space_len *= GetExtent(iter); + } else if (iter->iter_type == kReduce) { + cum_reduce_len *= GetExtent(iter); + } + } + + if (NeedsMultilevelTiling(task, state, op)) { + // Do not use rfactor if we have enough parallelism on space iters + if (cum_space_len > cum_reduce_len + || cum_space_len > task->hardware_params->num_cores * 16) { + return false; + } else { + return true; + } + } else if (cum_reduce_len > 1) { + // Always try rfactor for reduction ops + return true; + } + } + + return false; +} + +// Return whether the state did cache_write for stage_id +inline bool HasCacheWriteStage(const State& s, int stage_id) { + for (int i = static_cast(s->transform_steps.size()) - 1; i >= 0; --i) { + if (auto ps = s->transform_steps[i].as()) { + if (stage_id > ps->stage_id) { + stage_id--; + } else if (stage_id == ps->stage_id) { + return true; + } + } else if (auto ps 
= s->transform_steps[i].as<CacheReadStepNode>()) {
+      if (stage_id > ps->stage_id) {
+        stage_id--;
+      }
+    } else if (auto ps = s->transform_steps[i].as<RfactorStepNode>()) {
+      if (stage_id > ps->stage_id) {
+        stage_id--;
+      }
+    }
+  }
+  return false;
+}
+
+inline bool HasCacheReadStage(const State& s, int stage_id) {
+  for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
+    if (auto ps = s->transform_steps[i].as<CacheWriteStepNode>()) {
+      if (stage_id > ps->stage_id) {
+        stage_id--;
+      }
+    } else if (auto ps = s->transform_steps[i].as<CacheReadStepNode>()) {
+      if (stage_id > ps->stage_id) {
+        stage_id--;
+      } else if (stage_id == ps->stage_id) {
+        return true;
+      }
+    } else if (auto ps = s->transform_steps[i].as<RfactorStepNode>()) {
+      if (stage_id > ps->stage_id) {
+        stage_id--;
+      }
+    }
+  }
+  return false;
+}
+
+void GetSpaceSplitStepIds(const State& s, int stage_id, std::vector<int>* spatial_split_step_ids);
+
+inline bool HasSplitStep(const State& s, int stage_id) {
+  for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
+    if (s->transform_steps[i]->IsInstance<CacheWriteStepNode>() ||
+        s->transform_steps[i]->IsInstance<CacheReadStepNode>() ||
+        s->transform_steps[i]->IsInstance<RfactorStepNode>()) {
+      if (stage_id > s->transform_steps[i]->stage_id) {
+        stage_id--;
+      }
+    } else if (s->transform_steps[i]->IsInstance<SplitStepNode>() ||
+               s->transform_steps[i]->IsInstance<FollowSplitStepNode>() ||
+               s->transform_steps[i]->IsInstance<FollowFusedSplitStepNode>()) {
+      if (stage_id == s->transform_steps[i]->stage_id) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+// Return whether the stage has been tiled already
+inline bool IsTiled(const Stage& stage) {
+  auto op = stage->op.as<te::ComputeOpNode>();
+  CHECK(op != nullptr);
+  return stage->iters.size() != op->axis.size() + op->reduce_axis.size();
+}
+
+// Query axes that should not be split, according to the attribute from tvm.compute
+std::pair<std::set<std::string>, std::set<std::string> > QueryNoSplitAxis(const Stage& stage);
+// Query axes whose last split factor is one
+std::set<std::string> QueryLastSplitIsOneAxis(const Stage& stage);
+
+// Extract primitive iterators from the name of a nested fused or split iterator
+inline void ExtractOriginalIterators(const std::string& name, std::set<std::string>* rets) {
+  size_t last_pos = 0;
+  for (size_t i = 0; i < name.size(); ++i) {
+    if (name[i] == '@' || name[i] == '.') {  // '@' for fuse and '.' for split
+      if (!isdigit(name[last_pos]) && name[last_pos] != '@' && name[last_pos] != '.') {
+        rets->insert(name.substr(last_pos, i - last_pos));
+      }
+      last_pos = i + 1;
+    }
+  }
+
+  if (last_pos < name.size() && !isdigit(name[last_pos]) &&
+      name[last_pos] != '@' && name[last_pos] != '.') {
+    rets->insert(name.substr(last_pos, name.size() - last_pos));
+  }
+}
+
+// Get the last space iterator in the outermost tile
+inline const Iterator& GetLastSpaceIteratorInOutermostTile(const Stage& stage) {
+  auto pop = stage->op.as<te::ComputeOpNode>();
+  CHECK(pop != nullptr);
+  std::set<std::string> original_names;
+
+  for (const auto& iter : stage->iters) {
+    ExtractOriginalIterators(iter->name, &original_names);
+    if (original_names.size() == pop->axis.size()) {
+      return iter;
+    }
+  }
+
+  LOG(FATAL) << "Cannot find the iterator.";
+  return stage->iters[0];
+}
+
+inline const Iterator& GetLastReduceIteratorInOutermostReduceTile(const Stage& stage) {
+  auto pop = stage->op.as<te::ComputeOpNode>();
+  CHECK(pop != nullptr);
+  std::set<std::string> original_names;
+
+  auto no_split_name_pair = QueryNoSplitAxis(stage);
+  std::set<std::string> no_split_at_inner_name_set = no_split_name_pair.first;
+  size_t axis_size = 0;
+  for (const auto& axis : pop->axis) {
+    if (!no_split_at_inner_name_set.count(axis->var->name_hint)) {
+      axis_size++;
+    }
+  }
+  size_t reduce_axis_size = 0;
+  for (const auto& axis : pop->reduce_axis) {
+    if (!no_split_at_inner_name_set.count(axis->var->name_hint)) {
+      reduce_axis_size++;
+    }
+  }
+
+  if (reduce_axis_size) {
+    for (const auto& iter : stage->iters) {
+      ExtractOriginalIterators(iter->name, &original_names);
+      if (original_names.size() == axis_size + reduce_axis_size) {
+        return iter;
+      }
+    }
+  } else {
+    for (size_t i = 0; i < stage->iters.size(); i++) {
+      ExtractOriginalIterators(stage->iters[i]->name, &original_names);
+      if (original_names.size() == axis_size + 1) {
+        return stage->iters[i - 1];
+      }
+    }
+  }
+
+  LOG(FATAL) << "Cannot find the iterator.";
+  return stage->iters[0];
+}
+
+// Randomly sample states
+inline void RandomSampleStates(const std::vector<State>& in_states, std::mt19937* random_gen,
+                               size_t out_size, std::vector<State>* out_states) {
+  out_states->clear();
+  for (size_t i = 0; i < out_size; i++) {
+    out_states->push_back(in_states[(*random_gen)() % in_states.size()]);
+  }
+}
+
+// Randomly choose an index according to a prefix sum probability
+int RandomChoose(const std::vector<double>& prefix_sum_probs, std::mt19937* random_gen);
+
+// Prune undefined states.
+void PruneUndefined(std::vector<State>* states);
+
+// Print all states
+inline void PrintAllStates(const std::vector<State>& states) {
+  for (size_t i = 0; i < states.size(); ++i) {
+    std::cerr << i << std::endl;
+    std::cerr << states[i];
+    std::cerr << "==============================================" << std::endl;
+  }
+}
+
+// Apply a multi-level tiling structure according to a string format,
+// where "S" stands for a space level and "R" stands for a reduction level.
+// For example, if the format is "SSRSRS", then we will
+// use the tiling structure: space_L0, space_L1, reduce_L0, space_L2, reduce_L1, space_L3.
+// For example, if we apply "SSRSRS" to matrix multiplication,
+// we have space iterators i and j, and reduce iterator k.
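+// (Each "S" adds one more split level to every space iterator and each "R"
+// adds one to every reduce iterator, so "SSRSRS" produces four space levels
+// and two reduce levels.)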
+// Then the tiling structure is : i0, j0, i1, j1, k0, i2, j2, k1, i3, j3 +State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format, + std::vector* spatial_split_step_ids); + +// Apply tiling structure: space, space +// But use tile sizes from other SplitStep +State FollowTiling(const State& state, int stage_id, + const std::vector& split_step_ids, int n_split); + +// Randomly mutate the tile size of one SplitStep +State RandomMutateTileSize(const State& old_state, SplitFactorizationMemo* split_memo, + std::mt19937* random_gen, int max_innermost_split_factor); + +// Randomly mutate the value of one auto_unroll_max_step PragmaStep +State RandomMutateMaxUnrollStep(const State& old_state, std::mt19937* random_gen, + const std::vector& auto_unroll_configs); + +// Mutate a parallel loop. +State MutataParallel(const State& old_state, SplitFactorizationMemo* split_memo, + std::mt19937* random_gen, SearchTask& task, int verbose = 0); + +// Create all possible tile size states for all SplitStep +void GridMutateTileSize(const State& old_state, std::vector* cands, + SplitFactorizationMemo* split_memo, int max_innermost_split_factor); + +// GA: Crossover two states +State CrossOverState(const State& p1, const State& p2); + +} // namespace ansor +} // namespace tvm + +#endif // TVM_ANSOR_SEARCH_POLICY_UTILS_H_ diff --git a/tests/cpp/ansor_test.cc b/tests/cpp/ansor_test.cc index c43ec5c0a751..bbcef05f31fc 100644 --- a/tests/cpp/ansor_test.cc +++ b/tests/cpp/ansor_test.cc @@ -24,6 +24,8 @@ #include #include "../../src/ansor/loop_state.h" #include "../../src/ansor/serialization.h" +#include "../../src/ansor/feature.h" +#include "../../src/ansor/search_policy/meta_tile_rewrite_policy.h" tvm::Array matmul_func(int n, int m, int k) { using namespace tvm; @@ -84,11 +86,13 @@ tvm::Array conv2d_nchw_bn_relu_func(int N, int H, int W, return {data, kernel, bias, bn_scale, bn_offset, out}; } +using namespace tvm::ansor; + TEST(ComputeDAG, Basic) { const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3); - const auto& dag = tvm::ansor::ComputeDAGNode::make(tensors); - const auto& state = tvm::ansor::StateNode::make(dag->ops); - CHECK(std::equal_to()(state, dag.GetInitState())); + const auto& dag = ComputeDAGNode::make(tensors); + const auto& state = StateNode::make(dag->ops); + CHECK(std::equal_to()(state, dag.GetInitState())); LOG(INFO) << "\n" << state; LOG(INFO) << "\n" << dag; @@ -96,8 +100,6 @@ TEST(ComputeDAG, Basic) { } TEST(ComputeDAG, GetProducersConsumers) { - using namespace tvm::ansor; - const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3); const auto& dag = tvm::ansor::ComputeDAGNode::make(tensors); int data = 0, padding = 1, kernel = 2, conv = 3, bias = 4, bias_add = 5; @@ -159,8 +161,6 @@ TEST(ComputeDAG, GetProducersConsumers) { } TEST(ComputeDAG, InferBoundSerialization) { - using namespace tvm::ansor; - const auto& tensors = matmul_func(512, 512, 512); const auto& dag = ComputeDAGNode::make(tensors); int A = 0, B = 1, C = 2; @@ -216,8 +216,6 @@ TEST(ComputeDAG, InferBoundSerialization) { } TEST(Step, SplitFuseReorder) { - using namespace tvm::ansor; - const auto& tensors = matmul_func(512, 512, 512); const auto& dag = ComputeDAGNode::make(tensors); @@ -257,8 +255,6 @@ TEST(Step, SplitFuseReorder) { } TEST(Step, ComputeAtRootInline) { - using namespace tvm::ansor; - const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3); const auto& dag = tvm::ansor::ComputeDAGNode::make(tensors); // int data = 0, padding = 
1, kernel = 2; @@ -334,7 +330,6 @@ TEST(Step, ComputeAtRootInline) { TEST(Step, CacheReadWrite) { using namespace tvm; using namespace tvm::te; - using namespace tvm::ansor; const auto& test_func = []() -> Array { int N = 4, H = 7, W = 7, CO = 512, CI = 512, KH = 3, KW = 3, stride = 1; @@ -591,8 +586,6 @@ TEST(Step, CacheReadWrite) { } TEST(Step, FollowSplitFollowFusedSplit) { - using namespace tvm::ansor; - const auto& tensors = matmul_func(512, 512, 512); const auto& dag = ComputeDAGNode::make(tensors); @@ -660,6 +653,84 @@ TEST(Step, Rfactor) { // todo } +TEST(Feature, ExtractionMatmul) { + const auto& tensors = matmul_func(512, 512, 512); + const auto& dag = ComputeDAGNode::make(tensors); + State s0 = dag.GetInitState(); + + Iterator ti = s0->stages[2]->iters[0]; + Iterator tj = s0->stages[2]->iters[1]; + Iterator tk = s0->stages[2]->iters[2]; + std::vector its; + its = s0.split(2, ti, {16}); + Iterator tio = its[0], tii = its[1]; + its = s0.split(2, tj, {8}); + Iterator tjo = its[0], tji = its[1]; + s0.reorder(2, {tio, tjo, tk, tji, tii}); + s0.vectorize(2, tji); + s0.parallel(2, tio); + s0.parallel(2, tjo); + s0.unroll(2, tk); + + int max_n_bufs = 5; + std::vector> features; + std::vector feature_names; + GetPerStmtFeatureName(max_n_bufs, &feature_names); + GetPerStmtFeaturesFromStates({s0}, + SearchTaskNode::make(dag, "test", tvm::target::llvm(), + tvm::target::llvm(), + HardwareParams()), + max_n_bufs, 0, &features); + int num_states = 1; + CHECK_EQ(feature_names.size(), (features[0].size() - 1) / num_states); + // TODO(...): Add feature check here +} + +namespace tvm { +namespace ansor { +class MetaTileRewritePolicyNodeTest { + public: + MetaTileRewritePolicyNodeTest(CostModel cost_model, SearchTask task) { + policy = make_object(); + policy->program_cost_model = std::move(cost_model); + policy->rand_gen_ = std::mt19937(0); + policy->params.Set("cpu_multi_level_tiling_structure", + te::StringImmNode::make("SSRSRS")); + policy->params.Set("disable_change_compute_location", + IntImm(DataType::Int(32), 0)); + policy->cur_task_ = task; + } + void SynthesizeMetaStructure(std::vector* meta_structures) { + policy->SynthesizeMetaStructure(meta_structures); + } + void SampleInitPopulation(const std::vector& meta_structures, + int out_size, std::vector* out_states) { + policy->SampleInitPopulation(meta_structures, out_size, out_states); + } + tvm::runtime::ObjectPtr policy; +}; +} // namespace ansor +} // namespace tvm + +TEST(MetaTileRewritePolicy, Basic) { + const auto& tensors = matmul_func(512, 512, 512); + const auto& dag = ComputeDAGNode::make(tensors); + const auto& task = SearchTaskNode::make( + dag, "test", tvm::target::llvm(), tvm::target::llvm(), HardwareParams()); + const auto& cost_model = RandomModelNode::make(); + MetaTileRewritePolicyNodeTest test(cost_model, task); + + std::vector meta_structures, init_population; + test.SynthesizeMetaStructure(&meta_structures); + CHECK_GE(meta_structures.size(), 0); + LOG(INFO) << "SynthesizeMetaStructure get " << meta_structures.size() + << " states."; + test.SampleInitPopulation(meta_structures, 100, &init_population); + CHECK_GE(init_population.size(), 0); + LOG(INFO) << "SampleInitPopulation get " << init_population.size() + << " states."; +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; From 359905a0dd2b161dab662ebf7a7ae911812ee29b Mon Sep 17 00:00:00 2001 From: Chenfan Date: Wed, 3 Jun 2020 15:36:02 +0800 Subject: [PATCH 05/45] Basic Python API for 
State (#6) * Add Basic Python API for State * Add UTs for State --- python/tvm/ansor/__init__.py | 20 + python/tvm/ansor/_ffi_api.py | 21 + python/tvm/ansor/compute_dag.py | 34 ++ python/tvm/ansor/state.py | 387 +++++++++++++++++ src/ansor/compute_dag.cc | 2 + src/ansor/loop_state.cc | 149 +++++++ src/ansor/transform_step.cc | 1 + src/ansor/transform_step.h | 5 + tests/cpp/ansor_test.cc | 212 +++++---- tests/python/unittest/test_ansor_common.py | 475 +++++++++++++++++++++ 10 files changed, 1222 insertions(+), 84 deletions(-) create mode 100644 python/tvm/ansor/__init__.py create mode 100644 python/tvm/ansor/_ffi_api.py create mode 100644 python/tvm/ansor/compute_dag.py create mode 100644 python/tvm/ansor/state.py create mode 100644 tests/python/unittest/test_ansor_common.py diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py new file mode 100644 index 000000000000..aaa0e9c9174d --- /dev/null +++ b/python/tvm/ansor/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import, redefined-builtin +"""Namespace for Ansor autoSchedule""" + +from .compute_dag import ComputeDAG diff --git a/python/tvm/ansor/_ffi_api.py b/python/tvm/ansor/_ffi_api.py new file mode 100644 index 000000000000..177299e67d21 --- /dev/null +++ b/python/tvm/ansor/_ffi_api.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.ansor""" +import tvm._ffi + + +tvm._ffi._init_api("ansor", __name__) diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py new file mode 100644 index 000000000000..3c46440f75ba --- /dev/null +++ b/python/tvm/ansor/compute_dag.py @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import +""" ... """ + +import tvm._ffi +from tvm.runtime import Object + +from .state import State + +from . import _ffi_api + + +@tvm._ffi.register_object("ansor.ComputeDAG") +class ComputeDAG(Object): + def __init__(self, tensors): + self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, tensors) + + def get_init_state(self) -> State: + return self.init_state diff --git a/python/tvm/ansor/state.py b/python/tvm/ansor/state.py new file mode 100644 index 000000000000..9a8810190199 --- /dev/null +++ b/python/tvm/ansor/state.py @@ -0,0 +1,387 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import +""" ... """ + +import tvm._ffi +from tvm.runtime import Object + +from . 
import _ffi_api + + +@tvm._ffi.register_object("ansor.Iterator") +class Iterator(Object): + pass + + +@tvm._ffi.register_object("ansor.Stage") +class Stage(Object): + + def iterator(self, index): + return _ffi_api.StageGetIterator(self, index) + + def iterators(self): + return _ffi_api.StageGetIterators(self) + + +@tvm._ffi.register_object("ansor.State") +class State(Object): + + def stage(self, index): + """ + Parameters + ---------- + index : Int + + Returns + ------- + stage : Stage + """ + return _ffi_api.StateGetStage(self, index) + + def transform_steps_size(self): + """ Return the size of transform_steps + """ + return _ffi_api.StateGetTransformStepsSize(self) + + def reorder(self, stage_id, order): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + order : List[Iterator] + Iterators in expected order + + Returns + ------- + state : State + The updated state + """ + state = _ffi_api.StateReorder(self, stage_id, order) + return state + + def split(self, stage_id, it, lengths, inner_to_outer=True): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + it : Iterator + The target Iterator + lengths: List[Int] + The split factor + inner_to_outer: Bool + True to use `factor` for split from inner to outer, + False to use `nparts` for split from outer to inner + + Returns + ------- + state : State + The updated state + """ + state = _ffi_api.StateSplit(self, stage_id, it, lengths, + inner_to_outer) + return state + + def follow_split(self, stage_id, it, src_step_id, n_split): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + it : Iterator + The target Iterator + src_step_id : Int + The index of target step that this split follows + n_split : Int + Indecate how many level needs to be split out + + Returns + ------- + state : State + The updated state + """ + state = _ffi_api.StateFollowSplit(self, stage_id, it, src_step_id, + n_split) + return state + + def follow_fused_split(self, stage_id, it, src_step_ids, level, + factor_or_nparts): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + it : Iterator + The target Iterator + src_step_ids : List[Int] + The indexes of target step that this split follows + level : Int + factor_or_nparts : Bool + True to use `factor` for split from inner to outer, + False to use `nparts` for split from outer to inner + + Returns + ------- + state : State + The updated state + """ + state = _ffi_api.StateFollowFusedSplit(self, stage_id, it, src_step_ids, + level, factor_or_nparts) + return state + + def fuse(self, stage_id, iters): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + iters : List[Iterator] + The target Iterators to be fused + + Returns + ------- + state : State + The updated state + """ + state = _ffi_api.StateFuse(self, stage_id, iters) + return state + + def vectorize(self, stage_id, it): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + it : Iterator + The target Iterator to be vectorized + + Returns + ------- + state : State + The updated state + """ + state = _ffi_api.StateVectorize(self, stage_id, it) + return state + + def parallel(self, stage_id, it): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + it : Iterator + The target Iterator to be paralleled + + Returns + ------- + state : State + The updated state + """ + state = _ffi_api.StateParallel(self, stage_id, it) + return state + + def unroll(self, stage_id, it, max_unroll=-1): + """ + 
Parameters + ---------- + stage_id : Int + The index of target stage + it : Iterator + The target Iterator to be unrolled + max_unroll : Int + + Returns + ------- + state : State + The updated state + """ + state = _ffi_api.StateUnroll(self, stage_id, it, max_unroll) + return state + + def bind_thread(self, stage_id, it, thread_type): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + it : Iterator + The target Iterator to be vectorized + thread_type : ... + Supported type: kVThread, kBlockX, kThreadX, kThreadY + + Returns + ------- + state : State + The updated state + """ + state = _ffi_api.StateBindThread(self, stage_id, it, thread_type) + return state + + def compute_at(self, stage_id, target_stage_id, target_iter): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + target_stage_id : Int + The index of compute at target stage + target_iter : Iterator + The target Iterator to be compute at + + Returns + ------- + state : State + The updated state + """ + return _ffi_api.StateComputeAt(self, stage_id, target_stage_id, + target_iter) + + def compute_root(self, stage_id): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + + Returns + ------- + state : State + The updated state + """ + return _ffi_api.StateComputeRoot(self, stage_id) + + def compute_inline(self, stage_id): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + + Returns + ------- + state : State + The updated state + """ + return _ffi_api.StateComputeInline(self, stage_id) + + def pack_for_vec(self, stage_id, target_iter, vec_size): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + target_iter : Iterator + The target Iterator + vec_size : Int + + Returns + ------- + state : State + The updated state + """ + return _ffi_api.StatePackForVec(self, stage_id, target_iter, vec_size) + + def cache_read(self, stage_id, scope_name, reader_stage_ids, task_dag): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + scope_name : Str + reader_stage_ids : List[Int] + task_dag : ComputeDAG + + Returns + ------- + state : State + The updated state + """ + state = _ffi_api.StateCacheRead(self, stage_id, scope_name, + reader_stage_ids, task_dag) + return state + + def cache_write(self, stage_id, scope_name, task_dag): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + scope_name : Str + task_dag : ComputeDAG + + Returns + ------- + state : State + The updated state + """ + state = _ffi_api.StateCacheWrite(self, stage_id, scope_name, task_dag) + return state + + def pragma(self, stage_id, it, pragma_type): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + it : Iterator + The target Iterator + pragma_type : Str + + Returns + ------- + state : State + The updated state + """ + return _ffi_api.StatePragma(self, stage_id, it, pragma_type) + + def rfactor(self, stage_id, it, factor_iter_id, task_dag): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + it : Iterator + factor_iter_id : Int + task_dag : ComputeDAG + + Returns + ------- + state : State + The updated state + """ + state = _ffi_api.StateRfactor(self, stage_id, it, factor_iter_id, + task_dag) + return state + + def storage_align(self, stage_id, it, factor, offset): + """ + Parameters + ---------- + stage_id : Int + The index of target stage + it : Iterator + factor : Int + offset : Int + + Returns + ------- + state : State + The updated state + 
""" + return _ffi_api.StateStorageAlign(self, stage_id, it, factor, offset) diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index feaefe9f8e9f..1e33068e4965 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -1166,6 +1166,8 @@ std::pair > ComputeDAG::ReplaySteps( return std::make_pair(schedule, operator->()->tensors); } +TVM_REGISTER_GLOBAL("ansor.ComputeDAG") +.set_body_typed([](Array tensors) { return ComputeDAGNode::make(tensors); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index f01899c4c793..ebea5a1e472a 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -3,11 +3,13 @@ */ #include "loop_state.h" #include +#include #include "utils.h" namespace tvm { namespace ansor { +TVM_REGISTER_OBJECT_TYPE(StageNode); TVM_REGISTER_NODE_TYPE(StateNode); Stage StageNode::make(te::Operation op) { @@ -65,6 +67,16 @@ Stage StageNode::make(te::Operation op, StageType op_type, return Stage(node); } +TVM_REGISTER_GLOBAL("ansor.StageGetIterator") + .set_body_typed([](const Stage& stage, int index) { + return stage->iters[index]; + }); + +TVM_REGISTER_GLOBAL("ansor.StageGetIterators") + .set_body_typed([](const Stage& stage) { + return Array(stage->iters); + }); + State StateNode::make_empty_state() { auto node = make_object(); node->attach_map = AttachMapNode::make(); @@ -873,6 +885,143 @@ std::string State::ToStr(bool delete_trivial_loop) const { return os.str(); } +TVM_REGISTER_GLOBAL("ansor.StateGetStage") + .set_body_typed([](const State& state, int index) { + return state->stages[index]; + }); + +TVM_REGISTER_GLOBAL("ansor.StateGetTransformStepsSize") + .set_body_typed([](const State& state) { + return static_cast(state->transform_steps.size()); + }); + +TVM_REGISTER_GLOBAL("ansor.StateReorder") + .set_body_typed([](State state, int stage_id, + const Array& order) { + std::vector ord; + for (const auto& i : order) { + ord.push_back(i); + } + state.reorder(stage_id, ord); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StateSplit") + .set_body_typed([](State state, int stage_id, const Iterator& it, + const Array& lengths, + bool inner_to_outer) { + std::vector len; + for (const auto& i : lengths) { + len.push_back(i); + } + state.split(stage_id, it, len, inner_to_outer); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StateFollowSplit") + .set_body_typed([](State state, int stage_id, const Iterator& it, + int src_step_id, int n_split) { + state.follow_split(stage_id, it, src_step_id, n_split); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StateFollowFusedSplit") + .set_body_typed([](State state, int stage_id, const Iterator& it, + const Array& src_step_ids, int level, + bool factor_or_nparts) { + std::vector array_src_step_ids; + for (const auto& i : src_step_ids) { + array_src_step_ids.push_back(i->value); + } + state.follow_fused_split(stage_id, it, array_src_step_ids, level, + factor_or_nparts); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StateFuse") + .set_body_typed([](State state, int stage_id, + const Array& iters) { + std::vector its; + for (const auto& i : iters) { + its.push_back(i); + } + state.fuse(stage_id, its); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StateVectorize") + .set_body_typed([](State state, int stage_id, + const Iterator& it) { + state.vectorize(stage_id, it); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StateParallel") + .set_body_typed([](State state, int 
stage_id, + const Iterator& it) { + state.parallel(stage_id, it); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StateUnroll") + .set_body_typed([](State state, int stage_id, + const Iterator& it, int max_unroll) { + state.unroll(stage_id, it, max_unroll); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StateBindThread") + .set_body_typed([](State state, int stage_id, + const Iterator& it, int thread_type) { + state.bind_thread(stage_id, it, IteratorAnnotation(thread_type)); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StateComputeAt") + .set_body_typed([](State state, int stage_id, int target_stage_id, + const Iterator& target_iter) { + state.compute_at(stage_id, target_stage_id, target_iter); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StateComputeRoot") + .set_body_typed([](State state, int stage_id) { + state.compute_root(stage_id); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StateComputeInline") + .set_body_typed([](State state, int stage_id) { + state.compute_inline(stage_id); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StatePackForVec") + .set_body_typed([](State state, int stage_id, + const Iterator& target_iter, int vec_size) { + state.pack_for_vec(stage_id, target_iter, vec_size); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StateCacheRead") + .set_body_typed([](State state, int stage_id, const std::string& scope_name, + const Array& reader_stage_ids, + const ComputeDAG& task_dag) { + std::vector array_reader_stage_ids; + for (const auto& i : reader_stage_ids) { + array_reader_stage_ids.push_back(i->value); + } + state.cache_read(stage_id, scope_name, array_reader_stage_ids, task_dag); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StateCacheWrite") + .set_body_typed([](State state, int stage_id, const std::string& scope_name, + const ComputeDAG& task_dag) { + state.cache_write(stage_id, scope_name, task_dag); + return state; + }); + void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id) { AttachMapNode* pnode = CopyOnWrite(); diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc index 8cd8233ae9be..5f4a6a8dcef9 100644 --- a/src/ansor/transform_step.cc +++ b/src/ansor/transform_step.cc @@ -8,6 +8,7 @@ namespace tvm { namespace ansor { +TVM_REGISTER_NODE_TYPE(IteratorNode); TVM_REGISTER_OBJECT_TYPE(StepNode); /********** Reorder **********/ diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h index 9b430be99bd3..627ce02b60e1 100644 --- a/src/ansor/transform_step.h +++ b/src/ansor/transform_step.h @@ -69,6 +69,11 @@ class IteratorNode : public Object { IteratorType iter_type, IteratorAnnotation annotation, const std::vector* ori_iters = nullptr); + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("range", &range); + } + static constexpr const char *_type_key = "ansor.Iterator"; TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object); }; diff --git a/tests/cpp/ansor_test.cc b/tests/cpp/ansor_test.cc index bbcef05f31fc..75a6cc00b802 100644 --- a/tests/cpp/ansor_test.cc +++ b/tests/cpp/ansor_test.cc @@ -19,13 +19,15 @@ #include #include -#include -#include #include -#include "../../src/ansor/loop_state.h" -#include "../../src/ansor/serialization.h" +#include + +#include + #include "../../src/ansor/feature.h" +#include "../../src/ansor/loop_state.h" #include "../../src/ansor/search_policy/meta_tile_rewrite_policy.h" +#include "../../src/ansor/serialization.h" tvm::Array matmul_func(int n, int m, int k) { using namespace tvm; @@ -35,16 +37,17 @@ 
tvm::Array matmul_func(int n, int m, int k) { Tensor B = placeholder({k, m}, DataType::Float(32), "B"); IterVar K = IterVarNode::make({0, k}, Var("k"), kCommReduce); const auto& C = compute( - {n, m}, - [&](Var i, Var j) { return tvm::sum(A[i][K] * B[K][j], {K}); }, + {n, m}, [&](Var i, Var j) { return tvm::sum(A[i][K] * B[K][j], {K}); }, "C"); return {A, B, C}; } tvm::Array conv2d_nchw_bn_relu_func(int N, int H, int W, - int CI, int CO, int kernel_size, int strides, int padding, - int dilation = 1) { + int CI, int CO, + int kernel_size, + int strides, int padding, + int dilation = 1) { using namespace tvm; using namespace tvm::te; @@ -58,27 +61,27 @@ tvm::Array conv2d_nchw_bn_relu_func(int N, int H, int W, int OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) / strides + 1; int OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) / strides + 1; - const auto& conv = topi::conv2d_nchw(data, kernel, padding, padding, strides, - strides); + const auto& conv = + topi::conv2d_nchw(data, kernel, padding, padding, strides, strides); CHECK(conv->shape[2].as()->value == OH); CHECK(conv->shape[3].as()->value == OW); const auto& bias_add = compute( {N, CO, OH, OW}, [&](Var i, Var j, Var k, Var l) { - return conv[i][j][k][l] + bias[j][0][0]; + return conv[i][j][k][l] + bias[j][0][0]; }, "Bias_add"); const auto& bn_mul = compute( {N, CO, OH, OW}, [&](Var i, Var j, Var k, Var l) { - return bias_add[i][j][k][l] * bn_scale[j][0][0]; + return bias_add[i][j][k][l] * bn_scale[j][0][0]; }, "Bn_mul"); const auto& bn_add = compute( {N, CO, OH, OW}, [&](Var i, Var j, Var k, Var l) { - return bn_mul[i][j][k][l] + bn_offset[j][0][0]; + return bn_mul[i][j][k][l] + bn_offset[j][0][0]; }, "Bn_add"); const auto& out = topi::relu(bn_add); @@ -109,20 +112,22 @@ TEST(ComputeDAG, GetProducersConsumers) { std::unordered_set set; { std::vector> consumer_list = { - {data, padding}, {padding, conv}, {kernel, conv}, {conv, bias_add}, - {bias, bias_add}, {bias_add, bn_mul}, {bn_scale, bn_mul}, - {bn_mul, bn_add}, {bn_offset, bn_add}, {bn_add, relu} - }; + {data, padding}, {padding, conv}, {kernel, conv}, + {conv, bias_add}, {bias, bias_add}, {bias_add, bn_mul}, + {bn_scale, bn_mul}, {bn_mul, bn_add}, {bn_offset, bn_add}, + {bn_add, relu}}; for (const auto& pair : consumer_list) { dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op, &set); CHECK_EQ(set.size(), 1); CHECK_EQ((*set.begin()), s0->stages[pair.second]->op); } std::vector>> producer_list = { - {padding, {data}}, {conv, {padding, kernel}}, {bias_add, {conv, bias}}, - {bn_mul, {bias_add, bn_scale}}, {bn_add, {bn_mul, bn_offset}}, - {relu, {bn_add}} - }; + {padding, {data}}, + {conv, {padding, kernel}}, + {bias_add, {conv, bias}}, + {bn_mul, {bias_add, bn_scale}}, + {bn_add, {bn_mul, bn_offset}}, + {relu, {bn_add}}}; for (const auto& pair : producer_list) { dag->access_analyzer.GetProducers(s0, s0->stages[pair.first]->op, &set); CHECK_EQ(set.size(), pair.second.size()); @@ -138,18 +143,19 @@ TEST(ComputeDAG, GetProducersConsumers) { s0.compute_inline(padding); { std::vector> consumer_list = { - {data, conv}, {kernel, conv}, {conv, relu} - }; + {data, conv}, {kernel, conv}, {conv, relu}}; for (const auto& pair : consumer_list) { dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op, &set); CHECK_EQ(set.size(), 1); CHECK_EQ((*set.begin()), s0->stages[pair.second]->op); } std::vector>> producer_list = { - {padding, {data}}, {conv, {padding, kernel}}, {bias_add, {conv, bias}}, - {bn_mul, {bias_add, bn_scale}}, {bn_add, {bn_mul, 
bn_offset}}, - {relu, {bn_add}} - }; + {padding, {data}}, + {conv, {padding, kernel}}, + {bias_add, {conv, bias}}, + {bn_mul, {bias_add, bn_scale}}, + {bn_add, {bn_mul, bn_offset}}, + {relu, {bn_add}}}; for (const auto& pair : producer_list) { dag->access_analyzer.GetProducers(s0, s0->stages[pair.first]->op, &set); CHECK_EQ(set.size(), pair.second.size()); @@ -170,15 +176,19 @@ TEST(ComputeDAG, InferBoundSerialization) { C++; const auto& its0 = s0.split(C, s0->stages[C]->iters[0], {4, 8, 8}); const auto& its1 = s0.split(C, s0->stages[C]->iters[4], {8, 4, 4}); - s0.reorder(C, {its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], - its0[3], its1[3]}); + s0.reorder(C, {its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], + its1[3]}); s0.compute_at(C_global, C, s0->stages[C]->iters[3]); s0.split(C_global, s0->stages[C_global]->iters[2], {16}); int B_global = s0.cache_read(B, "global", {C_global}, dag); - C++; C_global++; + C++; + C_global++; s0.compute_at(B_global, C_global, s0->stages[C_global]->iters[0]); int A_global = s0.cache_read(A, "global", {C_global}, dag); - B++; B_global++; C++; C_global++; + B++; + B_global++; + C++; + C_global++; s0.compute_at(A_global, C_global, s0->stages[C_global]->iters[2]); const auto& s1 = dag.InferBound(s0); @@ -186,23 +196,26 @@ TEST(ComputeDAG, InferBoundSerialization) { dag.InferBound(&s2); const auto& s3 = dag.ReplayAndInferBound(s0->transform_steps); - CHECK_EQ(s1->stages[B_global]->iters[0]->range->extent.as()->value, - 512); - CHECK_EQ(s1->stages[B_global]->iters[1]->range->extent.as()->value, - 16); - CHECK_EQ(s1->stages[A_global]->iters[0]->range->extent.as()->value, - 1); - CHECK_EQ(s1->stages[A_global]->iters[1]->range->extent.as()->value, - 16); - CHECK_EQ(s1->stages[C_global]->iters[0]->range->extent.as()->value, - 64); + CHECK_EQ( + s1->stages[B_global]->iters[0]->range->extent.as()->value, + 512); + CHECK_EQ( + s1->stages[B_global]->iters[1]->range->extent.as()->value, + 16); + CHECK_EQ( + s1->stages[A_global]->iters[0]->range->extent.as()->value, 1); + CHECK_EQ( + s1->stages[A_global]->iters[1]->range->extent.as()->value, + 16); + CHECK_EQ( + s1->stages[C_global]->iters[0]->range->extent.as()->value, + 64); CHECK(std::equal_to()(s1, s2[0])); CHECK(std::equal_to()(s1, s3)); const auto& minp0 = MeasureInputNode::make( SearchTaskNode::make(dag, "test", tvm::target::llvm(), - tvm::target::llvm(), - HardwareParams()), + tvm::target::llvm(), HardwareParams()), s0); const auto& mres0 = MeasureResultNode::make({0.1}, 0, "", 0.1, 0.1); std::stringstream ss; @@ -242,7 +255,8 @@ TEST(Step, SplitFuseReorder) { CHECK_EQ(s0->stages[2]->iters[2]->range->extent.as()->value, 512); s0.fuse(2, {tio, tjo}); - CHECK_EQ(s0->stages[2]->iters[0]->range->extent.as()->value, 2048); + CHECK_EQ(s0->stages[2]->iters[0]->range->extent.as()->value, + 2048); s1.split(2, ti, {8, 2}); s1.split(2, tj, {32, 8}, false); @@ -271,10 +285,12 @@ TEST(Step, ComputeAtRootInline) { s0.compute_inline(bn_mul); s0.compute_inline(bias_add); s0.compute_at(conv, relu, s0->stages[relu]->iters[2]); - const auto& conv_stage_attach = s0->attach_map->stage_to_attach_iter.find(conv); + const auto& conv_stage_attach = + s0->attach_map->stage_to_attach_iter.find(conv); std::pair iterkey(relu, 2); CHECK(conv_stage_attach->second == iterkey); - const auto& conv_iter_attach = s0->attach_map->iter_to_attached_stages.find(iterkey); + const auto& conv_iter_attach = + s0->attach_map->iter_to_attached_stages.find(iterkey); CHECK_EQ(conv_iter_attach->second.size(), 1); 
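   // (The attach map is expected to record the relation in both directions:
   // stage -> attach iterator and iterator -> attached stages.)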
CHECK_EQ(conv_iter_attach->second[0], conv); std::stringstream ss; @@ -335,25 +351,28 @@ TEST(Step, CacheReadWrite) { int N = 4, H = 7, W = 7, CO = 512, CI = 512, KH = 3, KW = 3, stride = 1; int padding = 1; Tensor data = placeholder({N, CI, H, W}, DataType::Float(32), "Data"); - Tensor kernel_data = placeholder({CO, CI, KH, KW}, DataType::Float(32), - "kernel_data"); - const auto& k_split = compute(kernel_data->shape, + Tensor kernel_data = + placeholder({CO, CI, KH, KW}, DataType::Float(32), "Kernel_data"); + const auto& k_split = compute( + kernel_data->shape, [&](const Array& i) { - return Array({kernel_data[i[0]][i[1]][i[2]][i[3]] + 1, - div(kernel_data[i[0]][i[1]][i[2]][i[3]], 2)}); + return Array({kernel_data[i[0]][i[1]][i[2]][i[3]] + 1, + div(kernel_data[i[0]][i[1]][i[2]][i[3]], 2)}); }, "Kernel_split"); - const auto& kernel = compute(kernel_data->shape, + const auto& kernel = compute( + kernel_data->shape, [&](Var i, Var j, Var k, Var l) { - return (k_split[0])[i][j][k][l] + (k_split[1])[i][j][k][l]; + return (k_split[0])[i][j][k][l] + (k_split[1])[i][j][k][l]; }, "Kernel"); - const auto& conv = topi::conv2d_nchw(data, kernel, padding, padding, stride, - stride); + const auto& conv = + topi::conv2d_nchw(data, kernel, padding, padding, stride, stride); const auto& relu = topi::relu(conv); - const auto& out = compute(relu->shape, + const auto& out = compute( + relu->shape, [&](Var i, Var j, Var k, Var l) { - return data[i][j][k][l] + relu[i][j][k][l]; + return data[i][j][k][l] + relu[i][j][k][l]; }, "Add"); return {data, kernel_data, out}; @@ -372,15 +391,20 @@ TEST(Step, CacheReadWrite) { // 1: simple cache_write with compute_at int conv_global = s0.cache_write(conv, "global", dag); - conv++; relu++; add++; + conv++; + relu++; + add++; s0.compute_at(conv_global, conv, s0->stages[conv]->iters[3]); // 2: simple cache_read with compute_at int kernel_global = s0.cache_read(kernel, "global", {conv_global}, dag); - conv_global++; conv++; relu++; add++; + conv_global++; + conv++; + relu++; + add++; s0.compute_at(kernel_global, conv_global, s0->stages[conv_global]->iters[4]); std::stringstream ss; - ss << "Placeholder: Data, kernel_data\n" + ss << "Placeholder: Data, Kernel_data\n" << "for ax0 (0,4)\n" << " for ax1 (0,512)\n" << " for ax2 (0,9)\n" @@ -425,25 +449,45 @@ TEST(Step, CacheReadWrite) { // 3: two level cache_read with compute_at // preparing for GPU's shared memory & local memory int pad_temp_global = s0.cache_read(pad_temp, "global", {conv_global}, dag); - kernel_data++; kernel_split++; kernel++; kernel_global++; - conv_global++; conv++; relu++; add++; - int pad_temp_shared = s0.cache_read(pad_temp_global, "shared", {conv_global}, - dag); - kernel_data++; kernel_split++; kernel++; kernel_global++; - conv_global++; conv++; relu++; add++; + kernel_data++; + kernel_split++; + kernel++; + kernel_global++; + conv_global++; + conv++; + relu++; + add++; + int pad_temp_shared = + s0.cache_read(pad_temp_global, "shared", {conv_global}, dag); + kernel_data++; + kernel_split++; + kernel++; + kernel_global++; + conv_global++; + conv++; + relu++; + add++; s0.compute_at(pad_temp_global, conv_global, s0->stages[conv_global]->iters[2]); s0.compute_at(pad_temp_shared, conv_global, s0->stages[conv_global]->iters[4]); // 4: cache_read with multi readers - // This stage cannot be compute at to its consumer + // This stage cannot be compute at to its consumer s0.cache_read(data, "global", {pad_temp, add}, dag); - pad_temp++; pad_temp_global++; pad_temp_shared++; - kernel_data++; kernel_split++; 
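+  // (Each cache_read/cache_write inserts a new stage into the state, so
+  // every stage index at or after the insertion point shifts by one.)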
kernel++; kernel_global++; - conv_global++; conv++; relu++; add++; + pad_temp++; + pad_temp_global++; + pad_temp_shared++; + kernel_data++; + kernel_split++; + kernel++; + kernel_global++; + conv_global++; + conv++; + relu++; + add++; ss.str(std::string()); - ss << "Placeholder: Data, kernel_data\n" + ss << "Placeholder: Data, Kernel_data\n" << "for ax0 (0,4)\n" << " for ax1 (0,512)\n" << " for ax2 (0,7)\n" @@ -517,7 +561,7 @@ TEST(Step, CacheReadWrite) { // To be fixed in the future s0.cache_write(kernel_split, "global", dag); ss.str(std::string()); - ss << "Placeholder: Data, kernel_data\n" + ss << "Placeholder: Data, Kernel_data\n" << "for ax0 (0,4)\n" << " for ax1 (0,512)\n" << " for ax2 (0,7)\n" @@ -598,8 +642,8 @@ TEST(Step, FollowSplitFollowFusedSplit) { // FollowSplitStep currently only support `inner_to_outer = true` const auto& its0 = s0.split(C, s0->stages[C]->iters[0], {4, 2, 8, 4}, true); int split_step0 = s0->transform_steps.size() - 1; - // const auto& its1 = s0.split(C, s0->stages[C]->iters[5], {4, 2, 8, 4}, false); - // int split_step1 = s0->transform_steps.size() - 1; + // const auto& its1 = s0.split(C, s0->stages[C]->iters[5], {4, 2, 8, 4}, + // false); int split_step1 = s0->transform_steps.size() - 1; for (int level = 1; level <= 5; level++) { State tmp = s0; tmp.follow_split(C_global, s0->stages[C_global]->iters[0], split_step0, @@ -610,7 +654,7 @@ TEST(Step, FollowSplitFollowFusedSplit) { const auto& stage_C_global = tmp->stages[C_global]; for (int i = 0; i < level; i++) { CHECK_EQ(stage_C->iters[i]->range->extent.as()->value, - stage_C_global->iters[i]->range->extent.as()->value); + stage_C_global->iters[i]->range->extent.as()->value); } // for (int i = 0; i < level; i++) { // CHECK(stage_C->iters[i+5]->range->extent.as()->value == @@ -627,7 +671,7 @@ TEST(Step, FollowSplitFollowFusedSplit) { } s0.reorder(C, its); for (int i = 0; i < 5; i++) { - s0.fuse(C, {s0->stages[C]->iters[i], s0->stages[C]->iters[i+1]}); + s0.fuse(C, {s0->stages[C]->iters[i], s0->stages[C]->iters[i + 1]}); } for (int level = 0; level < 4; level++) { State tmp = s0; @@ -635,8 +679,8 @@ TEST(Step, FollowSplitFollowFusedSplit) { {split_step0, split_step1}, level, false); const auto& stage_C = tmp->stages[C]; const auto& stage_C_global = tmp->stages[C_global]; - CHECK_EQ(stage_C->iters[level+1]->range->extent.as()->value, - stage_C_global->iters[0]->range->extent.as()->value); + CHECK_EQ(stage_C->iters[level + 1]->range->extent.as()->value, + stage_C_global->iters[0]->range->extent.as()->value); } for (int level = 0; level < 4; level++) { State tmp = s0; @@ -644,8 +688,8 @@ TEST(Step, FollowSplitFollowFusedSplit) { {split_step0, split_step1}, level, true); const auto& stage_C = tmp->stages[C]; const auto& stage_C_global = tmp->stages[C_global]; - CHECK_EQ(stage_C->iters[level+1]->range->extent.as()->value, - stage_C_global->iters[1]->range->extent.as()->value); + CHECK_EQ(stage_C->iters[level + 1]->range->extent.as()->value, + stage_C_global->iters[1]->range->extent.as()->value); } } @@ -676,10 +720,10 @@ TEST(Feature, ExtractionMatmul) { std::vector> features; std::vector feature_names; GetPerStmtFeatureName(max_n_bufs, &feature_names); - GetPerStmtFeaturesFromStates({s0}, + GetPerStmtFeaturesFromStates( + {s0}, SearchTaskNode::make(dag, "test", tvm::target::llvm(), - tvm::target::llvm(), - HardwareParams()), + tvm::target::llvm(), HardwareParams()), max_n_bufs, 0, &features); int num_states = 1; CHECK_EQ(feature_names.size(), (features[0].size() - 1) / num_states); @@ -704,7 +748,7 @@ class 
MetaTileRewritePolicyNodeTest { policy->SynthesizeMetaStructure(meta_structures); } void SampleInitPopulation(const std::vector& meta_structures, - int out_size, std::vector* out_states) { + int out_size, std::vector* out_states) { policy->SampleInitPopulation(meta_structures, out_size, out_states); } tvm::runtime::ObjectPtr policy; diff --git a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py new file mode 100644 index 000000000000..4782f9130cea --- /dev/null +++ b/tests/python/unittest/test_ansor_common.py @@ -0,0 +1,475 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import te +from tvm import ansor +import topi + + +def matmul_nkkm(N, M, K): + A = te.placeholder((N, K), name='A') + B = te.placeholder((K, M), name='B') + k = te.reduce_axis((0, K), name='k') + C = te.compute((N, M), lambda i, j: te.sum( + A[i][k] * B[k][j], axis=[k]), name='C') + + return [A, B, C] + + +def conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation=1): + data = te.placeholder((N, CI, H, W), name='Data') + kernel = te.placeholder((CO, CI, kernel_size, kernel_size), name='Kernel') + bias = te.placeholder((CO, 1, 1), name='Bias') + bn_scale = te.placeholder((CO, 1, 1), name='Bn_scale') + bn_offset = te.placeholder((CO, 1, 1), name='Bn_offset') + + OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 + OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 + + conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation) + conv = te.compute((N, CO, OH, OW), + lambda i, j, k, l: conv[i, j, k, l] + bias[j, 0, 0], + name='Bias_add') + conv = te.compute((N, CO, OH, OW), + lambda i, j, k, l: conv[i, j, k, l] * bn_scale[j, 0, 0], + name='Bn_mul') + conv = te.compute((N, CO, OH, OW), + lambda i, j, k, l: conv[i, j, k, l] + bn_offset[j, 0, 0], + name='Bn_add') + out = topi.nn.relu(conv) + + return [data, kernel, bias, bn_offset, bn_scale, out] + + +def test_compute_dag_basic(): + dag = ansor.ComputeDAG(conv2d_nchw_bn_relu(1, 224, 224, 3, 64, 7, 2, 3)) + + print(dag) + print(dag.access_analyzer) + print(dag.get_init_state()) + + +def test_state_split_fuse_reorder(): + dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) + s0 = dag.get_init_state() + s1 = s0 + ti = s0.stage(2).iterator(0) + tj = s0.stage(2).iterator(1) + tk = s0.stage(2).iterator(2) + + assert ti.range.extent == 512 + + s0 = s0.split(2, ti, [16]) + assert s0.stage(2).iterator(0).range.extent == 32 + assert s0.stage(2).iterator(1).range.extent == 16 + tio = s0.stage(2).iterator(0) + tii = s0.stage(2).iterator(1) + + s0 = s0.split(2, tj, [8]) + assert s0.stage(2).iterator(2).range.extent == 64 + assert s0.stage(2).iterator(3).range.extent == 8 + tjo = s0.stage(2).iterator(2) + tji = 
s0.stage(2).iterator(3) + + s0 = s0.reorder(2, [tio, tjo, tk, tji, tii]) + assert s0.stage(2).iterator(2).range.extent == 512 + + s0 = s0.fuse(2, [tio, tjo]) + assert s0.stage(2).iterator(0).range.extent == 2048 + + s1 = s1.split(2, ti, [8, 2]) + s1 = s1.split(2, tj, [32, 8], False) + assert s1.stage(2).iterator(0).range.extent == 32 + assert s1.stage(2).iterator(1).range.extent == 8 + assert s1.stage(2).iterator(2).range.extent == 2 + assert s1.stage(2).iterator(3).range.extent == 32 + assert s1.stage(2).iterator(4).range.extent == 8 + assert s1.stage(2).iterator(5).range.extent == 2 + + +def test_state_compute_at_root_inline(): + dag = ansor.ComputeDAG(conv2d_nchw_bn_relu(1, 224, 224, 3, 64, 7, 2, 3)) + + # data, padding, kernel = 0, 1, 2 + conv = 3 + # bias = 4 + bias_add = 5 + # bn_scale = 6 + bn_mul = 7 + # bn_offset = 8 + bn_add, relu = 9, 10 + + s0 = dag.get_init_state() + s0 = s0.compute_inline(bn_add) + s0 = s0.compute_inline(bn_mul) + s0 = s0.compute_inline(bias_add) + s0 = s0.compute_at(conv, relu, s0.stage(relu).iterator(2)) + assert str(s0) == \ + "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \ + "for i1 (0,3)\n" + \ + " for i2 (0,230)\n" + \ + " for i3 (0,230)\n" + \ + " pad_temp = ...\n" + \ + "for i1 (0,64)\n" + \ + " for i2 (0,112)\n" + \ + " for nn (None)\n" + \ + " for ff (None)\n" + \ + " for yy (None)\n" + \ + " for xx (None)\n" + \ + " for rc (None)\n" + \ + " for ry (None)\n" + \ + " for rx (None)\n" + \ + " compute = ...\n" + \ + " for i3 (0,112)\n" + \ + " compute = ...\n" + + s0 = s0.compute_root(conv) + s0 = s0.compute_root(bn_mul) + assert str(s0) == \ + "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \ + "for i1 (0,3)\n" + \ + " for i2 (0,230)\n" + \ + " for i3 (0,230)\n" + \ + " pad_temp = ...\n" + \ + "for nn (None)\n" + \ + " for ff (None)\n" + \ + " for yy (None)\n" + \ + " for xx (None)\n" + \ + " for rc (None)\n" + \ + " for ry (None)\n" + \ + " for rx (None)\n" + \ + " compute = ...\n" + \ + "for i (None)\n" + \ + " for j (None)\n" + \ + " for k (None)\n" + \ + " for l (None)\n" + \ + " Bn_mul = ...\n" + \ + "for i1 (0,64)\n" + \ + " for i2 (0,112)\n" + \ + " for i3 (0,112)\n" + \ + " compute = ...\n" + + +def test_state_cache_read_write(): + N, H, W, CO, CI, KH, KW, strides, padding = 4, 7, 7, 512, 512, 3, 3, ( + 1, 1), (1, 1) + + data = te.placeholder((N, CI, H, W), name='Data') + kernel_data = te.placeholder((CO, CI, KH, KW), name='Kernel_data') + k0, k1 = te.compute(kernel_data.shape, + lambda *i: (kernel_data(*i)+1, kernel_data(*i)/2), + name='Kernel_split') + kernel = te.compute(kernel_data.shape, + lambda *i: k0(*i) + k1(*i), + name='Kernel') + conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation=1) + relu = topi.nn.relu(conv) + out = topi.add(data, relu) + + dag = ansor.ComputeDAG([data, kernel_data, out]) + data, pad_temp, kernel_data, kernel_split, kernel, conv, relu, add = 0, 1, 2, 3, 4, 5, 6, 7 + + # 0: init state + s0 = dag.get_init_state() + ori_its = s0.stage(add).iterators() + s0 = s0.split(add, s0.stage(add).iterator(0), [2]) + s0 = s0.reorder(add, [s0.stage(add).iterator(0), ori_its[1], + s0.stage(add).iterator(1), ori_its[2], ori_its[3]]) + s0 = s0.compute_inline(relu) + + # 1: simple cache_write with compute_at + s0 = s0.cache_write(conv, "global", dag) + conv_global = conv + conv += 1 + relu += 1 + add += 1 + s0 = s0.compute_at(conv_global, conv, s0.stage(conv).iterator(3)) + + # 2: simple cache_read with compute_at + s0 = s0.cache_read(kernel, "global", [conv_global], dag) + kernel_global = 
kernel + 1 + conv_global += 1 + conv += 1 + relu += 1 + add += 1 + s0 = s0.compute_at(kernel_global, conv_global, + s0.stage(conv_global).iterator(4)) + assert str(s0) == \ + "Placeholder: Data, Kernel_data\n" + \ + "for i0 (0,4)\n" + \ + " for i1 (0,512)\n" + \ + " for i2 (0,9)\n" + \ + " for i3 (0,9)\n" + \ + " pad_temp = ...\n" + \ + "for i0 (0,512)\n" + \ + " for i1 (0,512)\n" + \ + " for i2 (0,3)\n" + \ + " for i3 (0,3)\n" + \ + " Kernel_split = ...\n" + \ + "for i0 (0,512)\n" + \ + " for i1 (0,512)\n" + \ + " for i2 (0,3)\n" + \ + " for i3 (0,3)\n" + \ + " Kernel = ...\n" + \ + "for nn (0,4)\n" + \ + " for ff (0,512)\n" + \ + " for yy (0,7)\n" + \ + " for xx (0,7)\n" + \ + " for nn_c (None)\n" + \ + " for ff_c (None)\n" + \ + " for yy_c (None)\n" + \ + " for xx_c (None)\n" + \ + " for rc (None)\n" + \ + " for ax0 (None)\n" + \ + " for ax1 (None)\n" + \ + " for ax2 (None)\n" + \ + " for ax3 (None)\n" + \ + " Kernel.global = ...\n" + \ + " for ry (None)\n" + \ + " for rx (None)\n" + \ + " compute.global = ...\n" + \ + " compute = ...\n" + \ + "for ax0.0 (0,2)\n" + \ + " for ax1 (0,512)\n" + \ + " for ax0.1 (0,2)\n" + \ + " for ax2 (0,7)\n" + \ + " for ax3 (0,7)\n" + \ + " T_add = ...\n" + + # 3: two level cache_read with compute_at + # preparing for GPU's shared memory & local memory + s0 = s0.cache_read(pad_temp, "global", [conv_global], dag) + pad_temp_global = pad_temp + 1 + kernel_data += 1 + kernel_split += 1 + kernel += 1 + kernel_global += 1 + conv_global += 1 + conv += 1 + relu += 1 + add += 1 + s0 = s0.cache_read(pad_temp_global, "shared", [conv_global], dag) + pad_temp_shared = pad_temp_global + 1 + kernel_data += 1 + kernel_split += 1 + kernel += 1 + kernel_global += 1 + conv_global += 1 + conv += 1 + relu += 1 + add += 1 + s0 = s0.compute_at(pad_temp_global, conv_global, + s0.stage(conv_global).iterator(2)) + s0 = s0.compute_at(pad_temp_shared, conv_global, + s0.stage(conv_global).iterator(4)) + + # 4: cache_read with multi readers + # This stage cannot be compute at to its consumer + s0 = s0.cache_read(data, "global", [pad_temp, add], dag) + pad_temp += 1 + pad_temp_global += 1 + pad_temp_shared += 1 + kernel_data += 1 + kernel_split += 1 + kernel += 1 + kernel_global += 1 + conv_global += 1 + conv += 1 + relu += 1 + add += 1 + assert str(s0) == \ + "Placeholder: Data, Kernel_data\n" + \ + "for ax0 (0,4)\n" + \ + " for ax1 (0,512)\n" + \ + " for ax2 (0,7)\n" + \ + " for ax3 (0,7)\n" + \ + " Data.global = ...\n" + \ + "for i0 (0,4)\n" + \ + " for i1 (0,512)\n" + \ + " for i2 (0,9)\n" + \ + " for i3 (0,9)\n" + \ + " pad_temp = ...\n" + \ + "for i0 (0,512)\n" + \ + " for i1 (0,512)\n" + \ + " for i2 (0,3)\n" + \ + " for i3 (0,3)\n" + \ + " Kernel_split = ...\n" + \ + "for i0 (0,512)\n" + \ + " for i1 (0,512)\n" + \ + " for i2 (0,3)\n" + \ + " for i3 (0,3)\n" + \ + " Kernel = ...\n" + \ + "for nn (0,4)\n" + \ + " for ff (0,512)\n" + \ + " for yy (0,7)\n" + \ + " for xx (0,7)\n" + \ + " for nn_c (None)\n" + \ + " for ff_c (None)\n" + \ + " for yy_c (None)\n" + \ + " for ax0 (None)\n" + \ + " for ax1 (None)\n" + \ + " for ax2 (None)\n" + \ + " for ax3 (None)\n" + \ + " pad_temp.global = ...\n" + \ + " for xx_c (None)\n" + \ + " for rc (None)\n" + \ + " for ax0 (None)\n" + \ + " for ax1 (None)\n" + \ + " for ax2 (None)\n" + \ + " for ax3 (None)\n" + \ + " Kernel.global = ...\n" + \ + " for ax0 (None)\n" + \ + " for ax1 (None)\n" + \ + " for ax2 (None)\n" + \ + " for ax3 (None)\n" + \ + " pad_temp.global.shared = ...\n" + \ + " for ry (None)\n" + \ + " for rx (None)\n" + \ + " 
compute.global = ...\n" + \ + " compute = ...\n" + \ + "for ax0.0 (0,2)\n" + \ + " for ax1 (0,512)\n" + \ + " for ax0.1 (0,2)\n" + \ + " for ax2 (0,7)\n" + \ + " for ax3 (0,7)\n" + \ + " T_add = ...\n" + + # 5: cache_write with multi outputs + # See tests/cpp/ansor_test.cc for more information + s0 = s0.cache_write(kernel_split, "global", dag) + assert str(s0) == \ + "Placeholder: Data, Kernel_data\n" + \ + "for ax0 (0,4)\n" + \ + " for ax1 (0,512)\n" + \ + " for ax2 (0,7)\n" + \ + " for ax3 (0,7)\n" + \ + " Data.global = ...\n" + \ + "for i0 (0,4)\n" + \ + " for i1 (0,512)\n" + \ + " for i2 (0,9)\n" + \ + " for i3 (0,9)\n" + \ + " pad_temp = ...\n" + \ + "for i0_c (0,512)\n" + \ + " for i1_c (0,512)\n" + \ + " for i2_c (0,3)\n" + \ + " for i3_c (0,3)\n" + \ + " Kernel_split.global = ...\n" + \ + "for i0 (0,512)\n" + \ + " for i1 (0,512)\n" + \ + " for i2 (0,3)\n" + \ + " for i3 (0,3)\n" + \ + " Kernel_split = ...\n" + \ + "for i0 (0,512)\n" + \ + " for i1 (0,512)\n" + \ + " for i2 (0,3)\n" + \ + " for i3 (0,3)\n" + \ + " Kernel_split = ...\n" + \ + "for i0 (0,512)\n" + \ + " for i1 (0,512)\n" + \ + " for i2 (0,3)\n" + \ + " for i3 (0,3)\n" + \ + " Kernel = ...\n" + \ + "for nn (0,4)\n" + \ + " for ff (0,512)\n" + \ + " for yy (0,7)\n" + \ + " for xx (0,7)\n" + \ + " for nn_c (None)\n" + \ + " for ff_c (None)\n" + \ + " for yy_c (None)\n" + \ + " for ax0 (None)\n" + \ + " for ax1 (None)\n" + \ + " for ax2 (None)\n" + \ + " for ax3 (None)\n" + \ + " pad_temp.global = ...\n" + \ + " for xx_c (None)\n" + \ + " for rc (None)\n" + \ + " for ax0 (None)\n" + \ + " for ax1 (None)\n" + \ + " for ax2 (None)\n" + \ + " for ax3 (None)\n" + \ + " Kernel.global = ...\n" + \ + " for ax0 (None)\n" + \ + " for ax1 (None)\n" + \ + " for ax2 (None)\n" + \ + " for ax3 (None)\n" + \ + " pad_temp.global.shared = ...\n" + \ + " for ry (None)\n" + \ + " for rx (None)\n" + \ + " compute.global = ...\n" + \ + " compute = ...\n" + \ + "for ax0.0 (0,2)\n" + \ + " for ax1 (0,512)\n" + \ + " for ax0.1 (0,2)\n" + \ + " for ax2 (0,7)\n" + \ + " for ax3 (0,7)\n" + \ + " T_add = ...\n" + + +def test_follow_split_follow_fused_split(): + dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) + s0 = dag.get_init_state() + C = 2 + + s0 = s0.cache_write(C, "global", dag) + C_global = C + C += 1 + + s0 = s0.split(C, s0.stage(C).iterator(0), [4, 2, 8, 4], True) + split_step0 = s0.transform_steps_size() - 1 + for level in range(1, 6): + tmp = s0 + tmp = tmp.follow_split(C_global, tmp.stage( + C_global).iterator(0), split_step0, level) + for i in range(0, level): + assert tmp.stage(C).iterator(i).range.extent == \ + tmp.stage(C_global).iterator(i).range.extent + + s0 = s0.split(C, s0.stage(C).iterator(5), [2, 2, 4, 8]) + split_step1 = s0.transform_steps_size() - 1 + its = s0.stage(C).iterators() + s0 = s0.reorder(C, [its[0], its[5], its[1], its[6], its[2], its[7], + its[3], its[8], its[4], its[9]]) + s0 = s0.fuse(C, [s0.stage(C).iterator(0), s0.stage(C).iterator(1)]) + s0 = s0.fuse(C, [s0.stage(C).iterator(1), s0.stage(C).iterator(2)]) + s0 = s0.fuse(C, [s0.stage(C).iterator(2), s0.stage(C).iterator(3)]) + s0 = s0.fuse(C, [s0.stage(C).iterator(3), s0.stage(C).iterator(4)]) + s0 = s0.fuse(C, [s0.stage(C).iterator(4), s0.stage(C).iterator(5)]) + for level in range(0, 4): + tmp = s0 + tmp = tmp.follow_fused_split(C_global, tmp.stage(C_global).iterator(0), + [split_step0, split_step1], level, False) + assert tmp.stage(C).iterator(level+1).range.extent == \ + tmp.stage(C_global).iterator(0).range.extent + for level in range(0, 4): + tmp = 
s0 + tmp = tmp.follow_fused_split(C_global, tmp.stage(C_global).iterator(0), + [split_step0, split_step1], level, True) + assert tmp.stage(C).iterator(level+1).range.extent == \ + tmp.stage(C_global).iterator(1).range.extent + + +def test_rfactor(): + pass + + +if __name__ == "__main__": + test_compute_dag_basic() + test_state_split_fuse_reorder() + test_state_compute_at_root_inline() + test_state_cache_read_write() + test_follow_split_follow_fused_split() + test_rfactor() From 2032a64356dc88e341162281b94746e61cdabfe2 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Thu, 4 Jun 2020 16:05:07 +0800 Subject: [PATCH 06/45] Add Python API: Measure & Task (#7) * Update the return value of state operation * Add task * Copy measure.py & utils.py * Fix LocalBuilder * Fix LocalRunner --- python/tvm/ansor/__init__.py | 2 + python/tvm/ansor/compute_dag.py | 50 +- python/tvm/ansor/measure.py | 434 +++++++++++++++++ python/tvm/ansor/state.py | 91 +++- python/tvm/ansor/task.py | 59 +++ python/tvm/ansor/utils.py | 229 +++++++++ src/ansor/compute_dag.cc | 33 +- src/ansor/compute_dag.h | 8 +- src/ansor/loop_state.cc | 522 +++++++++++---------- src/ansor/measure.cc | 202 +++++--- src/ansor/search_task.cc | 66 ++- src/ansor/search_task.h | 10 +- tests/cpp/ansor_test.cc | 4 +- tests/python/unittest/test_ansor_common.py | 120 +++-- 14 files changed, 1401 insertions(+), 429 deletions(-) create mode 100644 python/tvm/ansor/measure.py create mode 100644 python/tvm/ansor/task.py create mode 100644 python/tvm/ansor/utils.py diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index aaa0e9c9174d..cb039cf07d5f 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -18,3 +18,5 @@ """Namespace for Ansor autoSchedule""" from .compute_dag import ComputeDAG +from .task import SearchTask +from .measure import MeasureInput, LocalBuilder, LocalRunner diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index 3c46440f75ba..a66a181f054c 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -25,10 +25,56 @@ from . 
import _ffi_api
+class LayoutRewriteLevel(object):
+ NO_REWRITE = 0 # No layout rewrite
+ PLACEHOLDER_REWRITE = 1 # Only rewrite layout of placeholder in the compute dag
+ COMPUTE_REWRITE = 2 # Only rewrite compute body for new layout in the compute dag
+ BOTH_REWRITE = 3 # Rewrite both placeholder and compute body in the compute dag
+
+
 @tvm._ffi.register_object("ansor.ComputeDAG")
 class ComputeDAG(Object):
+ """
+ Parameters
+ ----------
+ tensors : List[Tensor]
+ """
+
 def __init__(self, tensors):
 self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, tensors)
- def get_init_state(self) -> State:
- return self.init_state
+ def get_init_state(self):
+ """ Get the init state of this ComputeDAG
+
+ Returns
+ -------
+ state : State
+ """
+ return _ffi_api.ComputeDAGGetInitState(self)
+
+ def apply_steps_from_state(self, state, layout_rewrite_level):
+ """
+ Parameters
+ ----------
+ state : State
+ layout_rewrite_level : LayoutRewriteLevel
+
+ Returns
+ -------
+ sch : Schedule
+ args : List[Tensor]
+ """
+ # Note: layout_rewrite_level is accepted here but not yet forwarded;
+ # the FFI call currently only takes (dag, state).
+ sch, args = _ffi_api.ComputeDAGApplyStepsFromState(self, state)
+ return sch, args
+
+ def print_python_code_from_state(self, state):
+ """
+ Parameters
+ ----------
+ state : State
+
+ Returns
+ -------
+ code : Str
+ The schedule printed as Python code
+ """
+ return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state)
diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py
new file mode 100644
index 000000000000..72dd3cbfcf92
--- /dev/null
+++ b/python/tvm/ansor/measure.py
@@ -0,0 +1,434 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import
+"""Distributed measurement infrastructure to measure the runtime costs of tensor programs
+
+These functions are responsible for building the tvm module, uploading it to
+remote devices, recording the running time costs, and checking the correctness of the output.
+
+We implement these in python to utilize python's multiprocessing and error handling
+"""
+from typing import List
+import os
+import time
+import shutil
+import logging
+import traceback
+import tempfile
+import multiprocessing
+
+import tvm._ffi
+from tvm.runtime import Object, module, ndarray
+from tvm.driver import build_module
+from tvm.target import build_config
+from ..contrib import tar, ndk
+from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout, request_remote, check_remote
+from .compute_dag import LayoutRewriteLevel
+
+from . 
import _ffi_api
+
+logger = logging.getLogger('ansor')
+
+
+@tvm._ffi.register_object("ansor.MeasureInput")
+class MeasureInput(Object):
+ """
+ Parameters
+ ----------
+ task : SearchTask
+ state : State
+ """
+
+ def __init__(self, task, state):
+ self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state)
+
+
+@tvm._ffi.register_object("ansor.BuildResult")
+class BuildResult(Object):
+ """
+ Parameters
+ ----------
+ filename : Str
+ args : List[Tensor]
+ error_no : Int
+ error_msg : Str
+ time_cost : Float
+ """
+
+ def __init__(self, filename, args, error_no, error_msg, time_cost):
+ self.__init_handle_by_constructor__(
+ _ffi_api.BuildResult, filename, args, error_no,
+ error_msg if error_msg else "", time_cost)
+
+
+@tvm._ffi.register_object("ansor.MeasureResult")
+class MeasureResult(Object):
+ """
+ Parameters
+ ----------
+ costs : List[Float]
+ error_no : Int
+ error_msg : Str
+ all_cost : Float
+ timestamp : Float
+ """
+
+ def __init__(self, costs, error_no, error_msg, all_cost, timestamp):
+ self.__init_handle_by_constructor__(
+ _ffi_api.MeasureResult, costs, error_no,
+ error_msg if error_msg else "", all_cost, timestamp)
+
+
+@tvm._ffi.register_object("ansor.Builder")
+class Builder(Object):
+ def build(self, measure_inputs, verbose=0):
+ """
+ Parameters
+ ----------
+ measure_inputs : List[MeasureInput]
+ verbose : Int
+
+ Returns
+ -------
+ res : List[BuildResult]
+ """
+ return _ffi_api.BuilderBuild(self, measure_inputs, verbose)
+
+
+@tvm._ffi.register_object("ansor.Runner")
+class Runner(Object):
+ def run(self, measure_inputs, build_results, verbose=0):
+ """
+ Parameters
+ ----------
+ measure_inputs : List[MeasureInput]
+ build_results : List[BuildResult]
+ verbose : Int
+
+ Returns
+ -------
+ res : List[MeasureResult]
+ """
+ return _ffi_api.RunnerRun(self, measure_inputs, build_results, verbose)
+
+
+@tvm._ffi.register_object("ansor.LocalBuilder")
+class LocalBuilder(Builder):
+ """
+ Parameters
+ ----------
+ timeout : Int
+ n_parallel : Int
+ build_func : Str
+ """
+
+ def __init__(self,
+ timeout=15,
+ n_parallel=multiprocessing.cpu_count(),
+ build_func='default'):
+ self.__init_handle_by_constructor__(
+ _ffi_api.LocalBuilder, timeout, n_parallel, build_func)
+
+
+@tvm._ffi.register_object("ansor.LocalRunner")
+class LocalRunner(Runner):
+ """
+ Parameters
+ ----------
+ timeout : Int
+ number : Int
+ repeat : Int
+ min_repeat_ms : Int
+ cooldown_interval : Float
+ """
+
+ def __init__(self,
+ timeout=10,
+ number=3,
+ repeat=1,
+ min_repeat_ms=0,
+ cooldown_interval=0.0):
+ self.__init_handle_by_constructor__(
+ _ffi_api.LocalRunner, timeout, number, repeat, min_repeat_ms, cooldown_interval)
+
+
+MAX_ERROR_MSG_LEN = 512
+
+
+class MeasureErrorNo(object):
+ """Error type for MeasureResult"""
+ NO_ERROR = 0 # No error
+ INSTANTIATION_ERROR = 1 # Errors happen when applying transform steps from the init state
+ # Errors happen when compiling code on host (e.g. tvm.build)
+ COMPILE_HOST = 2
+ COMPILE_DEVICE = 3 # Errors happen when compiling code on device
+ # (e.g. 
OpenCL JIT on the device) + RUNTIME_DEVICE = 4 # Errors happen when run program on device + WRONG_ANSWER = 5 # Answer is wrong when compared to a reference output + BUILD_TIMEOUT = 6 # Timeout during compilation + RUN_TIMEOUT = 7 # Timeout during run + UNKNOWN_ERROR = 8 # Unknown error + + +def make_error_msg(): + error_msg = str(traceback.format_exc()) + if len(error_msg) > MAX_ERROR_MSG_LEN: + error_msg = error_msg[:MAX_ERROR_MSG_LEN//2] + \ + "\n...\n" + error_msg[-MAX_ERROR_MSG_LEN//2:] + return error_msg + + +global global_build_arguments +global global_run_arguments + + +def local_build_worker(index): + # We use fork to copy arguments from a global variable. + # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool + measure_inputs, build_func, timeout, verbose = global_build_arguments + assert isinstance(build_func, str) + if build_func == 'default': + build_func = tar.tar + elif build_func == 'ndk': + build_func = ndk.create_shared + else: + raise ValueError("Invalid build_func" + build_func) + + def timed_func(): + tic = time.time() + inp = measure_inputs[index] + task = inp.task + + error_no = MeasureErrorNo.NO_ERROR + error_msg = None + args = [] + + try: + sch, args = task.compute_dag.apply_steps_from_state( + inp.state, LayoutRewriteLevel.BOTH_REWRITE) + except Exception: + error_no = MeasureErrorNo.INSTANTIATION_ERROR + error_msg = make_error_msg() + + if error_no == 0: + dirname = tempfile.mkdtemp() + filename = os.path.join( + dirname, "tmp_func." + build_func.output_format) + + try: + with build_config(unroll_max_extent=task.hardware_params.max_unroll_vec): + func = build_module.build( + sch, args, target=task.target, target_host=task.target_host) + func.export_library(filename, build_func) + except Exception: + error_no = MeasureErrorNo.COMPILE_HOST + error_msg = make_error_msg() + else: + filename = "" + + if verbose >= 1: + if error_no == MeasureErrorNo.NO_ERROR: + print(".", end="") + else: + print(".E", end="") # Build error + return filename, args, error_no, error_msg, time.time() - tic + + res = call_func_with_timeout(timeout, timed_func) + if isinstance(res, TimeoutError): + if verbose >= 1: + print(".T", end="") # Build timeout + res = None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout + + return res + + +@tvm._ffi.register_func("ansor.local_builder.build") +def local_builder_build(inputs: List[MeasureInput], timeout: float, n_parallel: int, build_func: str, verbose: int): + # We use fork to copy arguments from a global variable. 
+ # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool + global global_build_arguments + global_build_arguments = (inputs, build_func, timeout, verbose) + + pool = NoDaemonPool(n_parallel) + tuple_res = pool.map(local_build_worker, range(len(inputs))) + pool.terminate() + pool.join() + del pool + + results = [] + for res in tuple_res: + results.append(BuildResult(*res)) + + return results + + +@tvm._ffi.register_func("ansor.rpc_runner.run") +def rpc_runner_run(inputs: List[MeasureInput], build_results: List[BuildResult], + key: str, host: str, port: int, priority: int, timeout: float, + n_parallel: int, number: int, repeat: int, min_repeat_ms: int, + cooldown_interval: float, verbose: int): + global global_run_arguments + global_run_arguments = (inputs, build_results, key, host, port, priority, timeout, number, + repeat, min_repeat_ms, cooldown_interval, verbose) + + assert len(inputs) == len(build_results), \ + "Measure input size should be equal to build results" + pool = NoDaemonPool(n_parallel) + tuple_res = pool.map(rpc_run_worker, range(len(build_results))) + pool.terminate() + pool.join() + del pool + + results = [] + for res in tuple_res: + results.append(MeasureResult(*res)) + + if verbose >= 1: + print("") + + return results + + +def rpc_run_worker(index): + inputs, build_results, key, host, port, priority, timeout, number, \ + repeat, min_repeat_ms, cooldown_interval, verbose = global_run_arguments + + MAX_FLOAT = 1e10 # We use 1e10 instead of sys.float_info.max for better readability in log + inp = inputs[index] + build_res = build_results[index] + + if build_res.error_no != MeasureErrorNo.NO_ERROR: + return (MAX_FLOAT,), build_res.error_no, build_res.error_msg, build_res.time_cost, time.time() + + def timed_func(): + tic = time.time() + error_no = 0 + error_msg = None + try: + # upload built module + remote = request_remote(key, host, port, priority, timeout) + remote.upload(build_res.filename) + func = remote.load_module(os.path.split(build_res.filename)[1]) + ctx = remote.context(str(inp.task.target), 0) + time_f = func.time_evaluator( + func.entry_name, ctx, number=number, repeat=repeat, min_repeat_ms=min_repeat_ms) + except Exception: + costs = (MAX_FLOAT,) + error_no = MeasureErrorNo.COMPILE_DEVICE + error_msg = make_error_msg() + + if error_no == 0: + try: + args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in + build_res.args] + ctx.sync() + + costs = time_f(*args).results + # clean up remote files + remote.remove(build_res.filename) + remote.remove(os.path.splitext(build_res.filename)[0] + '.so') + remote.remove('') + except Exception: + costs = (MAX_FLOAT,) + error_no = MeasureErrorNo.RUNTIME_DEVICE + error_msg = make_error_msg() + + shutil.rmtree(os.path.dirname(build_res.filename)) + toc = time.time() + + time.sleep(cooldown_interval) + if verbose >= 1: + if error_no == MeasureErrorNo.NO_ERROR: + print("*", end="") + else: + print("*E", end="") # Run error + + return costs, error_no, error_msg, toc - tic + build_res.time_cost, toc + + res = call_func_with_timeout(timeout, timed_func) + + if isinstance(res, TimeoutError): + if verbose >= 1: + print("*T", end="") # Run timeout + res = (MAX_FLOAT,), MeasureErrorNo.RUN_TIMEOUT, None, build_res.time_cost + \ + timeout, time.time() + return res + + +@tvm._ffi.register_func("ansor.local_runner.run") +def local_run(inputs: List[MeasureInput], build_results: List[BuildResult], + timeout: float, number: int, repeat: int, min_repeat_ms: int, + cooldown_interval: float, verbose: 
int): + MAX_FLOAT = 1e10 # We use 1e10 instead of sys.float_info.max for better readability in log + + def timed_func(inp, build_res): + tic = time.time() + error_no = 0 + error_msg = None + try: + func = module.load_module(build_res.filename) + ctx = ndarray.context(str(inp.task.target), 0) + time_f = func.time_evaluator( + func.entry_name, ctx, number=number, repeat=repeat, min_repeat_ms=min_repeat_ms) + except Exception: + costs = (MAX_FLOAT,) + error_no = MeasureErrorNo.COMPILE_DEVICE + error_msg = make_error_msg() + + if error_no == 0: + try: + args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in + build_res.args] + ctx.sync() + + costs = time_f(*args).results + except Exception: + costs = (MAX_FLOAT,) + error_no = MeasureErrorNo.RUNTIME_DEVICE + error_msg = make_error_msg() + + shutil.rmtree(os.path.dirname(build_res.filename)) + toc = time.time() + time.sleep(cooldown_interval) + + if verbose >= 1: + if error_no == MeasureErrorNo.NO_ERROR: + print("*", end="") + else: + print("*E", end="") # Run error + return costs, error_no, error_msg, toc - tic + build_res.time_cost, toc + + measure_results = [] + assert len(inputs) == len(build_results), \ + "Measure input size should be equal to build results" + for inp, build_res in zip(inputs, build_results): + if build_res.error_no != 0: + res = ( + MAX_FLOAT,), build_res.error_no, build_res.error_msg, build_res.time_cost, time.time() + else: + res = call_func_with_timeout( + timeout, timed_func, args=(inp, build_res)) + if isinstance(res, TimeoutError): + if verbose >= 1: + print("*T", end="") # Run timeout + res = ( + MAX_FLOAT,), MeasureErrorNo.RUN_TIMEOUT, None, build_res.time_cost + timeout, time.time() + measure_results.append(MeasureResult(*res)) + + if verbose >= 1: + print("") + + return measure_results diff --git a/python/tvm/ansor/state.py b/python/tvm/ansor/state.py index 9a8810190199..7de95a8a74af 100644 --- a/python/tvm/ansor/state.py +++ b/python/tvm/ansor/state.py @@ -25,21 +25,41 @@ @tvm._ffi.register_object("ansor.Iterator") class Iterator(Object): + """ ... + """ pass @tvm._ffi.register_object("ansor.Stage") class Stage(Object): + """ ... + """ def iterator(self, index): + """ + Parameters + ---------- + index : Int + + Returns + ------- + iter : Iterator + """ return _ffi_api.StageGetIterator(self, index) def iterators(self): + """ + Returns + ------- + iters : List[Iterator] + """ return _ffi_api.StageGetIterators(self) @tvm._ffi.register_object("ansor.State") class State(Object): + """ ... 
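+
+ Examples
+ --------
+ A minimal sketch (assuming a three-stage matmul DAG, where stage 2 is
+ the compute stage; concrete stage/iterator indices depend on the
+ ComputeDAG):
+
+ s = dag.get_init_state()
+ s, its = s.split(2, s.stage(2).iterator(0), [16])
+ s, fused = s.fuse(2, [its[0], its[1]])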
+ """ def stage(self, index): """ @@ -93,10 +113,12 @@ def split(self, stage_id, it, lengths, inner_to_outer=True): ------- state : State The updated state + res_its : List[Iterator] + The splited Iterators result """ - state = _ffi_api.StateSplit(self, stage_id, it, lengths, - inner_to_outer) - return state + state, res_its = _ffi_api.StateSplit(self, stage_id, it, lengths, + inner_to_outer) + return state, res_its def follow_split(self, stage_id, it, src_step_id, n_split): """ @@ -115,10 +137,12 @@ def follow_split(self, stage_id, it, src_step_id, n_split): ------- state : State The updated state + res_its : List[Iterator] + The splited Iterators result """ - state = _ffi_api.StateFollowSplit(self, stage_id, it, src_step_id, - n_split) - return state + state, res_its = _ffi_api.StateFollowSplit(self, stage_id, it, + src_step_id, n_split) + return state, res_its def follow_fused_split(self, stage_id, it, src_step_ids, level, factor_or_nparts): @@ -140,10 +164,13 @@ def follow_fused_split(self, stage_id, it, src_step_ids, level, ------- state : State The updated state + res_its : List[Iterator] + The splited Iterators result """ - state = _ffi_api.StateFollowFusedSplit(self, stage_id, it, src_step_ids, - level, factor_or_nparts) - return state + state, res_its = _ffi_api.StateFollowFusedSplit(self, stage_id, it, + src_step_ids, level, + factor_or_nparts) + return state, res_its def fuse(self, stage_id, iters): """ @@ -158,9 +185,11 @@ def fuse(self, stage_id, iters): ------- state : State The updated state + res_it : Iterator + The fused Iterator """ - state = _ffi_api.StateFuse(self, stage_id, iters) - return state + state, res_it = _ffi_api.StateFuse(self, stage_id, iters) + return state, res_it def vectorize(self, stage_id, it): """ @@ -175,9 +204,11 @@ def vectorize(self, stage_id, it): ------- state : State The updated state + res_it : Iterator + The vectorized Iterator """ - state = _ffi_api.StateVectorize(self, stage_id, it) - return state + state, res_it = _ffi_api.StateVectorize(self, stage_id, it) + return state, res_it def parallel(self, stage_id, it): """ @@ -192,9 +223,11 @@ def parallel(self, stage_id, it): ------- state : State The updated state + res_it : Iterator + The paralleled Iterator """ - state = _ffi_api.StateParallel(self, stage_id, it) - return state + state, res_it = _ffi_api.StateParallel(self, stage_id, it) + return state, res_it def unroll(self, stage_id, it, max_unroll=-1): """ @@ -210,9 +243,11 @@ def unroll(self, stage_id, it, max_unroll=-1): ------- state : State The updated state + res_it : Iterator + The unrolled Iterator """ - state = _ffi_api.StateUnroll(self, stage_id, it, max_unroll) - return state + state, res_it = _ffi_api.StateUnroll(self, stage_id, it, max_unroll) + return state, res_it def bind_thread(self, stage_id, it, thread_type): """ @@ -229,9 +264,12 @@ def bind_thread(self, stage_id, it, thread_type): ------- state : State The updated state + res_it : Iterator + The thread binded Iterator """ - state = _ffi_api.StateBindThread(self, stage_id, it, thread_type) - return state + state, res_it = _ffi_api.StateBindThread(self, stage_id, it, + thread_type) + return state, res_it def compute_at(self, stage_id, target_stage_id, target_iter): """ @@ -311,10 +349,12 @@ def cache_read(self, stage_id, scope_name, reader_stage_ids, task_dag): ------- state : State The updated state + new_stage_id : Int + The added staged id """ - state = _ffi_api.StateCacheRead(self, stage_id, scope_name, - reader_stage_ids, task_dag) - return state + state, 
new_stage_id = _ffi_api.StateCacheRead(self, stage_id,
+ scope_name, reader_stage_ids, task_dag)
+ return state, int(new_stage_id)
 def cache_write(self, stage_id, scope_name, task_dag):
 """
@@ -329,9 +369,12 @@ def cache_write(self, stage_id, scope_name, task_dag):
 -------
 state : State
 The updated state
+ new_stage_id : Int
+ The id of the added stage
 """
- state = _ffi_api.StateCacheWrite(self, stage_id, scope_name, task_dag)
- return state
+ state, new_stage_id = _ffi_api.StateCacheWrite(self, stage_id,
+ scope_name, task_dag)
+ return state, int(new_stage_id)
 def pragma(self, stage_id, it, pragma_type):
 """
diff --git a/python/tvm/ansor/task.py b/python/tvm/ansor/task.py
new file mode 100644
index 000000000000..245cf4c727ae
--- /dev/null
+++ b/python/tvm/ansor/task.py
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import
+""" ... """
+
+import tvm._ffi
+from tvm.runtime import Object
+
+from . import _ffi_api
+
+@tvm._ffi.register_object("ansor.HardwareParams")
+class HardwareParams(Object):
+ """
+ Parameters
+ ----------
+ num_cores : Int
+ vector_unit_bytes : Int
+ cache_line_bytes : Int
+ max_unroll_vec : Int
+ max_innermost_split_factor : Int
+ """
+
+ def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes,
+ max_unroll_vec, max_innermost_split_factor):
+ self.__init_handle_by_constructor__(_ffi_api.HardwareParams, num_cores,
+ vector_unit_bytes, cache_line_bytes, max_unroll_vec,
+ max_innermost_split_factor)
+
+
+@tvm._ffi.register_object("ansor.SearchTask")
+class SearchTask(Object):
+ """
+ Parameters
+ ----------
+ dag : ComputeDAG
+ workload_key : Str
+ target : tvm.target
+ target_host : tvm.target
+ hardware_params : HardwareParams
+ """
+
+ def __init__(self, dag, workload_key, target, target_host=None,
+ hardware_params=None):
+ self.__init_handle_by_constructor__(_ffi_api.SearchTask, dag,
+ workload_key, target, target_host, hardware_params)
diff --git a/python/tvm/ansor/utils.py b/python/tvm/ansor/utils.py
new file mode 100644
index 000000000000..0216549c184a
--- /dev/null
+++ b/python/tvm/ansor/utils.py
@@ -0,0 +1,229 @@
+"""Common utilities"""
+import multiprocessing
+import multiprocessing.pool
+import queue
+import signal
+import threading
+import os
+
+import numpy as np
+
+try:
+ import psutil
+except ImportError:
+ psutil = None
+
+from .. 
import rpc as _rpc +from tvm.tir import expr +from tvm.tir.transform import Simplify +from tvm.ir.transform import Sequential + + +def get_func_name(func): + """Get name of a function + + Parameters + ---------- + func: Function + The function + Returns + ------- + name: str + The name + """ + + return func.func_name if hasattr(func, 'func_name') else func.__name__ + + +def get_const_int(exp): + """Verifies expr is integer and get the constant value. + + Parameters + ---------- + exp : tvm.Expr or int + The input expression. + + Returns + ------- + out_value : int + The output. + """ + if isinstance(exp, int): + return exp + if not isinstance(exp, (expr.IntImm)): + opt = Sequential([Simplify()]) + exp = opt(exp) + if not isinstance(exp, (expr.IntImm)): + raise ValueError("Expect value to be constant int") + return exp.value + + +def get_const_tuple(in_tuple): + """Verifies input tuple is IntImm, returns tuple of int. + + Parameters + ---------- + in_tuple : tuple of Expr + The input. + + Returns + ------- + out_tuple : tuple of int + The output. + """ + return tuple(get_const_int(x) for x in in_tuple) + + +def to_str_round(x, decimal=6): + """Convert object to str and round float numbers""" + if isinstance(x, str): + return x + if isinstance(x, (list, tuple)) or isinstance(x, np.ndarray): + return "[" + ", ".join([to_str_round(y, decimal=decimal) + for y in x]) + "]" + if isinstance(x, dict): + return str({k: eval(to_str_round(v)) for k, v in x.items()}) + if isinstance(x, int): + return str(x) + if isinstance(x, (np.float32, np.float64, float)): + format_str = "%%.%df" % decimal + return format_str % x + raise ValueError("Invalid value: " + str(x) + "\ttype: " + str(type(x))) + + +def array_mean(arr): + """Mean function for tvm array (Array)""" + return sum(x.value for x in arr) / len(arr) + + +class NoDaemonProcess(multiprocessing.Process): + @property + def daemon(self): + return False + + @daemon.setter + def daemon(self, value): + pass + + +class NoDaemonContext(type(multiprocessing.get_context())): + Process = NoDaemonProcess + + +class NoDaemonPool(multiprocessing.pool.Pool): + """A no daemon pool version of multiprocessing.Pool. 
This allows us to start new processes inside the worker function"""
+
+ def __init__(self, *args, **kwargs):
+ kwargs['context'] = NoDaemonContext()
+ super().__init__(*args, **kwargs)
+
+
+def kill_child_processes(parent_pid, sig=signal.SIGTERM):
+ """Kill all child processes recursively"""
+ try:
+ parent = psutil.Process(parent_pid)
+ except psutil.NoSuchProcess:
+ return
+ children = parent.children(recursive=True)
+ for process in children:
+ try:
+ process.send_signal(sig)
+ except psutil.NoSuchProcess:
+ return
+
+
+def call_func_with_timeout(timeout, func, args=(), kwargs=None):
+ """Call a function with timeout"""
+ def func_wrapper(que):
+ if kwargs:
+ que.put(func(*args, **kwargs))
+ else:
+ que.put(func(*args))
+
+ que = multiprocessing.Queue(2)
+ process = multiprocessing.Process(target=func_wrapper, args=(que,))
+ process.start()
+ process.join(timeout)
+
+ try:
+ res = que.get(block=False)
+ except queue.Empty:
+ res = TimeoutError()
+
+ # clean up the queue and the worker process
+ kill_child_processes(process.pid)
+ process.terminate()
+ process.join()
+ que.close()
+ que.join_thread()
+ del process
+ del que
+
+ return res
+
+
+def request_remote(device_key, host=None, port=None, priority=1, timeout=60):
+ """Request a remote session
+
+ Parameters
+ ----------
+ device_key: string
+ The device key of the registered device in the tracker
+ host: str, optional
+ The host address of the rpc tracker.
+ If None, use environment variable "TVM_TRACKER_HOST"
+ port: int, optional
+ The port of the rpc tracker.
+ If None, use environment variable "TVM_TRACKER_PORT"
+ priority: int, optional
+ The priority of this request; a larger value means higher priority
+ timeout: float, optional
+ The timeout of this session (units: seconds)
+
+ Returns
+ -------
+ session: RPCSession
+ """
+ # connect to the tracker
+ host = host or os.environ['TVM_TRACKER_HOST']
+ port = port or int(os.environ['TVM_TRACKER_PORT'])
+
+ tracker = _rpc.connect_tracker(host, port)
+ remote = tracker.request(device_key, priority=priority,
+ session_timeout=timeout)
+ return remote
+
+
+def check_remote(device_key, host=None, port=None, priority=100, timeout=10):
+ """
+ Check the availability of a remote device
+
+ Parameters
+ ----------
+ device_key: string
+ The device key of the registered device in the tracker
+ host: str, optional
+ The host address of the rpc tracker.
+ If None, use environment variable "TVM_TRACKER_HOST"
+ port: int, optional
+ The port of the rpc tracker.
+ If None, use environment variable "TVM_TRACKER_PORT"
+ priority: int, optional
+ The priority of this request; a larger value means higher priority
+ timeout: float, optional
+ The timeout of this check (units: seconds). 
+ + Returns + ------- + available: bool + True if can find available device + """ + + def _check(): + remote = request_remote(device_key, host, port, priority) + + t = threading.Thread(target=_check, ) + t.start() + t.join(timeout) + return not t.is_alive() diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 1e33068e4965..c9415a70c303 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -588,15 +588,6 @@ ComputeDAG ComputeDAGNode::make_by_workload_key(const std::string& workload_key) return ComputeDAGNode::make(std::move(tens)); } -void ComputeDAGNode::VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("tensors", &tensors); - v->Visit("ops", &ops); - v->Visit("flop_ct", &flop_ct); - v->Visit("access_analyzer", &access_analyzer); - State s = Downcast(init_state); - v->Visit("init_state", &s); -} - // Implemented in multi_stage_policy.cc // Extract primitive iterators from a nested fused or splitted iterator's name extern void ExtractOriginalIterators(const std::string& name, std::set* rets); @@ -1166,9 +1157,6 @@ std::pair > ComputeDAG::ReplaySteps( return std::make_pair(schedule, operator->()->tensors); } -TVM_REGISTER_GLOBAL("ansor.ComputeDAG") -.set_body_typed([](Array tensors) { return ComputeDAGNode::make(tensors); }); - TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { auto* node = static_cast(ref.get()); @@ -1262,5 +1250,26 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) } }); +TVM_REGISTER_GLOBAL("ansor.ComputeDAG") +.set_body_typed([](Array tensors) { + return ComputeDAGNode::make(tensors); +}); + +TVM_REGISTER_GLOBAL("ansor.ComputeDAGGetInitState") +.set_body_method(&ComputeDAG::GetInitState); + +TVM_REGISTER_GLOBAL("ansor.ComputeDAGApplyStepsFromState") +.set_body_typed([](const ComputeDAG& dag, const State& state) { + te::Schedule sch; + Array return_tensors; + std::tie(sch, return_tensors) = dag.ApplySteps(state->transform_steps); + return Array{sch, return_tensors}; +}); + +TVM_REGISTER_GLOBAL("ansor.ComputeDAGPrintPythonCodeFromState") +.set_body_typed([](const ComputeDAG& dag, const State& state) { + return dag.PrintStepsAsPython(state->transform_steps); +}); + } // namespace ansor } // namespace tvm diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index 9d0708a77f1c..3b4c80c50ad8 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -93,7 +93,13 @@ class ComputeDAGNode : public Object { AccessAnalyzer access_analyzer; // Read/Write accesss static analyzer ObjectRef init_state; // initial states - void VisitAttrs(tvm::AttrVisitor* v); + void VisitAttrs(tvm::AttrVisitor* v) { + LOG(INFO) << "ComputeDAG"; + v->Visit("tensors", &tensors); + v->Visit("ops", &ops); + v->Visit("flop_ct", &flop_ct); + v->Visit("access_analyzer", &access_analyzer); + } static ComputeDAG make(Array tensors); static ComputeDAG make_by_workload_key(const std::string& workload_key); diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index ebea5a1e472a..e18d36e34581 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -2,8 +2,10 @@ * Copyright (c) 2020 by Contributors */ #include "loop_state.h" -#include + #include +#include + #include "utils.h" namespace tvm { @@ -16,15 +18,15 @@ Stage StageNode::make(te::Operation op) { auto node = make_object(); if (op->IsInstance()) { node->op_type = kCompute; - auto *pop = op.as(); + auto* pop = op.as(); for (const auto& axis : pop->axis) { node->iters.push_back(IteratorNode::make(CleanName(axis->var->name_hint), - axis->dom, 
kSpace, kNone)); + axis->dom, kSpace, kNone)); } for (const auto& axis : pop->reduce_axis) { node->iters.push_back(IteratorNode::make(CleanName(axis->var->name_hint), - axis->dom, kReduce, kNone)); + axis->dom, kReduce, kNone)); } } else if (op->IsInstance()) { node->op_type = kPlaceholder; @@ -54,9 +56,8 @@ Stage StageNode::make(te::Operation op, StageType op_type, } Stage StageNode::make(te::Operation op, StageType op_type, - std::vector&& iters, - ComputeAtType compute_at, int16_t auto_unroll_max_step, - int storage_offset) { + std::vector&& iters, ComputeAtType compute_at, + int16_t auto_unroll_max_step, int storage_offset) { auto node = make_object(); node->op = std::move(op); node->op_type = op_type; @@ -67,16 +68,6 @@ Stage StageNode::make(te::Operation op, StageType op_type, return Stage(node); } -TVM_REGISTER_GLOBAL("ansor.StageGetIterator") - .set_body_typed([](const Stage& stage, int index) { - return stage->iters[index]; - }); - -TVM_REGISTER_GLOBAL("ansor.StageGetIterators") - .set_body_typed([](const Stage& stage) { - return Array(stage->iters); - }); - State StateNode::make_empty_state() { auto node = make_object(); node->attach_map = AttachMapNode::make(); @@ -97,8 +88,8 @@ State StateNode::make(const Array& ops) { } State StateNode::make(const std::vector& stages, - const std::vector& transform_steps, - bool complete, ObjectRef aux_info) { + const std::vector& transform_steps, bool complete, + ObjectRef aux_info) { auto node = make_object(); node->stages = stages; node->transform_steps = transform_steps; @@ -131,31 +122,32 @@ std::vector State::split(int stage_id, const Iterator& it, bool inner_to_outer) { const Stage& stage = operator->()->stages[stage_id]; - SplitStep step = SplitStepNode::make(stage_id, GetIndex(stage->iters, it), - it->range.defined() ? it->range->extent : PrimExpr(), lengths, - inner_to_outer); + SplitStep step = + SplitStepNode::make(stage_id, GetIndex(stage->iters, it), + it->range.defined() ? 
it->range->extent : PrimExpr(), + lengths, inner_to_outer); CopyOnWrite()->transform_steps.push_back(step); return DoSplitStep(step); } -std::vector State::follow_split(int stage_id, - const Iterator& it, int src_step_id, int n_split) { +std::vector State::follow_split(int stage_id, const Iterator& it, + int src_step_id, int n_split) { const Stage& stage = operator->()->stages[stage_id]; - FollowSplitStep step = FollowSplitStepNode::make(stage_id, - GetIndex(stage->iters, it), src_step_id, n_split); + FollowSplitStep step = FollowSplitStepNode::make( + stage_id, GetIndex(stage->iters, it), src_step_id, n_split); CopyOnWrite()->transform_steps.push_back(step); return DoFollowSplitStep(step); } - std::vector State::follow_fused_split( int stage_id, const Iterator& it, const std::vector& src_step_ids, int level, bool factor_or_nparts) { const Stage& stage = operator->()->stages[stage_id]; - FollowFusedSplitStep step = FollowFusedSplitStepNode::make(stage_id, - GetIndex(stage->iters, it), src_step_ids, level, factor_or_nparts); + FollowFusedSplitStep step = + FollowFusedSplitStepNode::make(stage_id, GetIndex(stage->iters, it), + src_step_ids, level, factor_or_nparts); CopyOnWrite()->transform_steps.push_back(step); return DoFollowFusedSplitStep(step); } @@ -179,16 +171,16 @@ Iterator State::vectorize(int stage_id, const Iterator& it) { Iterator State::parallel(int stage_id, const Iterator& it) { const Stage& stage = operator->()->stages[stage_id]; - AnnotationStep step = AnnotationStepNode::make( - stage_id, GetIndex(stage->iters, it), kParallel); + AnnotationStep step = + AnnotationStepNode::make(stage_id, GetIndex(stage->iters, it), kParallel); CopyOnWrite()->transform_steps.push_back(step); return DoAnnotationStep(step); } Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) { const Stage& stage = operator->()->stages[stage_id]; - AnnotationStep step = AnnotationStepNode::make(stage_id, - GetIndex(stage->iters, it), kUnroll); + AnnotationStep step = + AnnotationStepNode::make(stage_id, GetIndex(stage->iters, it), kUnroll); // don't unroll if the extent is larger than max_unroll if (max_unroll != -1 && it->range.defined()) { @@ -206,8 +198,8 @@ Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) { void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) { const Stage& target_stage = operator->()->stages[target_stage_id]; - ComputeAtStep step = ComputeAtStepNode::make(stage_id, target_stage_id, - GetIndex(target_stage->iters, target_iter)); + ComputeAtStep step = ComputeAtStepNode::make( + stage_id, target_stage_id, GetIndex(target_stage->iters, target_iter)); CopyOnWrite()->transform_steps.push_back(step); return DoComputeAtStep(step); } @@ -227,8 +219,8 @@ void State::compute_inline(int stage_id) { void State::pack_for_vec(int stage_id, const Iterator& target_iter, int vec_size) { const Stage& stage = operator->()->stages[stage_id]; - PackForVecStep step = PackForVecStepNode::make(stage_id, - GetIndex(stage->iters, target_iter), vec_size); + PackForVecStep step = PackForVecStepNode::make( + stage_id, GetIndex(stage->iters, target_iter), vec_size); CopyOnWrite()->transform_steps.push_back(step); return DoPackForVecStep(step); } @@ -240,8 +232,8 @@ Iterator State::bind_thread(int stage_id, const Iterator& it, LOG(FATAL) << "thread_type error, valide: kVThread, kBlockX, kThreadX, " << "kThreadY"; } - AnnotationStep step = AnnotationStepNode::make(stage_id, - GetIndex(stage->iters, it), thread_type); + AnnotationStep step = 
AnnotationStepNode::make( + stage_id, GetIndex(stage->iters, it), thread_type); CopyOnWrite()->transform_steps.push_back(step); return DoAnnotationStep(step); } @@ -249,14 +241,14 @@ Iterator State::bind_thread(int stage_id, const Iterator& it, int State::cache_read(int stage_id, const std::string& scope_name, const std::vector& reader_stage_ids, const ComputeDAG& task_dag) { - CacheReadStep step = CacheReadStepNode::make(stage_id, scope_name, - reader_stage_ids); + CacheReadStep step = + CacheReadStepNode::make(stage_id, scope_name, reader_stage_ids); CopyOnWrite()->transform_steps.push_back(step); return DoCacheReadStep(step, task_dag); } int State::cache_write(int stage_id, const std::string& scope_name, - const ComputeDAG& task_dag) { + const ComputeDAG& task_dag) { CacheWriteStep step = CacheWriteStepNode::make(stage_id, scope_name); CopyOnWrite()->transform_steps.push_back(step); return DoCacheWriteStep(step, task_dag); @@ -265,14 +257,14 @@ int State::cache_write(int stage_id, const std::string& scope_name, void State::pragma(int stage_id, const Iterator& it, const std::string& pragma_type) { const Stage& stage = operator->()->stages[stage_id]; - PragmaStep step = PragmaStepNode::make(stage_id, GetIndex(stage->iters, it), - pragma_type); + PragmaStep step = + PragmaStepNode::make(stage_id, GetIndex(stage->iters, it), pragma_type); CopyOnWrite()->transform_steps.push_back(step); return DoPragmaStep(step); } int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id, - const ComputeDAG& task_dag) { + const ComputeDAG& task_dag) { const Stage& stage = operator->()->stages[stage_id]; RfactorStep step = RfactorStepNode::make(stage_id, GetIndex(stage->iters, it), factor_iter_id); @@ -283,8 +275,8 @@ int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id, void State::storage_align(int stage_id, const Iterator& it, int factor, int offset) { const Stage& stage = operator->()->stages[stage_id]; - StorageAlignStep step = StorageAlignStepNode::make(stage_id, - GetIndex(stage->iters, it), factor, offset); + StorageAlignStep step = StorageAlignStepNode::make( + stage_id, GetIndex(stage->iters, it), factor, offset); CopyOnWrite()->transform_steps.push_back(step); return DoStorageAlignStep(step); } @@ -299,11 +291,9 @@ void State::DoReorderStep(const ReorderStep& step) { } StateNode* pstate = CopyOnWrite(); - pstate->stages[step->stage_id] = StageNode::make(stage->op, stage->op_type, - std::move(iters), - stage->compute_at, - stage->auto_unroll_max_step, - stage->storage_offset); + pstate->stages[step->stage_id] = StageNode::make( + stage->op, stage->op_type, std::move(iters), stage->compute_at, + stage->auto_unroll_max_step, stage->storage_offset); } // common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep @@ -324,7 +314,8 @@ std::vector State::DoSplitStepCommon( std::vector outs; for (size_t i = 0; i < lengths.size(); ++i) { - PrimExpr l; std::string name; + PrimExpr l; + std::string name; if (inner_to_outer) { l = lengths[lengths.size() - i - 1]; name = it->name + "." + std::to_string(lengths.size() - i); @@ -350,26 +341,26 @@ std::vector State::DoSplitStepCommon( range = Range::make_by_min_extent(tosplit_min, tosplit_extent); } if (inner_to_outer) { - outs.push_back(IteratorNode::make(it->name + ".0", range, it->iter_type, - kNone)); + outs.push_back( + IteratorNode::make(it->name + ".0", range, it->iter_type, kNone)); std::reverse(outs.begin(), outs.end()); } else { - outs.push_back(IteratorNode::make( - it->name + "." 
+ std::to_string(lengths.size()), range, it->iter_type, - kNone)); + outs.push_back( + IteratorNode::make(it->name + "." + std::to_string(lengths.size()), + range, it->iter_type, kNone)); } std::vector new_iters; new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + iter_id); new_iters.insert(new_iters.end(), outs.begin(), outs.end()); - new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id+1, + new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id + 1, stage->iters.end()); StateNode* pstate = CopyOnWrite(); - pstate->stages[stage_id] = StageNode::make(stage->op, stage->op_type, - std::move(new_iters), stage->compute_at, stage->auto_unroll_max_step, - stage->storage_offset); + pstate->stages[stage_id] = StageNode::make( + stage->op, stage->op_type, std::move(new_iters), stage->compute_at, + stage->auto_unroll_max_step, stage->storage_offset); // we have to replace the iterators in attach map, // these two vectors keep the replacement mapping @@ -396,8 +387,8 @@ std::vector State::DoFollowSplitStep(const FollowSplitStep& step) { std::vector State::DoFollowFusedSplitStep( const FollowFusedSplitStep& step) { - const PrimExpr& length = step->ExtractSplitLength( - operator->()->transform_steps); + const PrimExpr& length = + step->ExtractSplitLength(operator->()->transform_steps); return DoSplitStepCommon(step->stage_id, step->iter_id, {length}, step->factor_or_nparts); } @@ -414,15 +405,14 @@ Iterator State::DoFuseStep(const FuseStep& step) { std::vector ori_iters; for (size_t i = 0; i < step->fused_ids.size(); ++i) { if (i > 0) { - CHECK_EQ(step->fused_ids[i], step->fused_ids[i-1] + 1); + CHECK_EQ(step->fused_ids[i], step->fused_ids[i - 1] + 1); } if (i != step->fused_ids.size() - 1) { const auto& iter_to_attached_stage = - operator->()->attach_map->iter_to_attached_stages; - if (iter_to_attached_stage.find(std::make_pair(stage_id, - step->fused_ids[i])) - != iter_to_attached_stage.end()) { + operator->()->attach_map->iter_to_attached_stages; + if (iter_to_attached_stage.find(std::make_pair( + stage_id, step->fused_ids[i])) != iter_to_attached_stage.end()) { LOG(FATAL) << "Invalid Fuse. 
Because you want to fuse iterators " "that have been attached by some stages"; } @@ -451,8 +441,8 @@ Iterator State::DoFuseStep(const FuseStep& step) { if (new_extent.defined()) { range = Range::make_by_min_extent(0, new_extent); } - Iterator new_it = IteratorNode::make(new_name, range, new_iter_type, kNone, - &ori_iters); + Iterator new_it = + IteratorNode::make(new_name, range, new_iter_type, kNone, &ori_iters); std::vector new_iters; new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + step->fused_ids.front()); @@ -462,9 +452,9 @@ Iterator State::DoFuseStep(const FuseStep& step) { stage->iters.end()); StateNode* pstate = CopyOnWrite(); - pstate->stages[stage_id] = StageNode::make(stage->op, stage->op_type, - std::move(new_iters), stage->compute_at, stage->auto_unroll_max_step, - stage->storage_offset); + pstate->stages[stage_id] = StageNode::make( + stage->op, stage->op_type, std::move(new_iters), stage->compute_at, + stage->auto_unroll_max_step, stage->storage_offset); // we have to replace the iterators in attach map, // these two vectors keep the replacement mapping @@ -477,7 +467,7 @@ Iterator State::DoFuseStep(const FuseStep& step) { } else if (i > end_id) { // move forward from_iters.emplace_back(stage_id, i); to_iters.emplace_back(stage_id, i - end_id + begin_id); - } else { // move to the fused id + } else { // move to the fused id from_iters.emplace_back(stage_id, i); to_iters.emplace_back(stage_id, begin_id); } @@ -491,7 +481,7 @@ Iterator State::DoAnnotationStep(const AnnotationStep& step) { Iterator it = stage->iters[step->iter_id]; Iterator new_it = IteratorNode::make(it->name, it->range, it->iter_type, - step->annotation, &it->ori_iters); + step->annotation, &it->ori_iters); Stage new_stage = stage; new_stage.CopyOnWrite()->iters[step->iter_id] = new_it; StateNode* pstate = CopyOnWrite(); @@ -508,8 +498,8 @@ void State::DoComputeAtStep(const ComputeAtStep& step) { std::vector new_iters; for (const Iterator& it : stage->iters) { size_t s = it->name.size(); - if (s >= 2 && it->name[s-2] == '.' && it->name[s-1] >= '1' && - it->name[s-1] <= '4') { + if (s >= 2 && it->name[s - 2] == '.' && it->name[s - 1] >= '1' && + it->name[s - 1] <= '4') { // We use a dangerous heuristic rule here : For multi level splitted // iterators, we assume their length does not change after compute_at. 
// Reason: These iterators are generated in MultiStagePolicy by multi @@ -519,14 +509,14 @@ void State::DoComputeAtStep(const ComputeAtStep& step) { new_iters.push_back(it); } else { new_iters.push_back(IteratorNode::make(it->name, Range(), it->iter_type, - it->annotation, &it->ori_iters)); + it->annotation, &it->ori_iters)); } } StateNode* pstate = CopyOnWrite(); - pstate->stages[step->stage_id] = StageNode::make(stage->op, stage->op_type, - std::move(new_iters), kIter, stage->auto_unroll_max_step, - stage->storage_offset); + pstate->stages[step->stage_id] = + StageNode::make(stage->op, stage->op_type, std::move(new_iters), kIter, + stage->auto_unroll_max_step, stage->storage_offset); pstate->attach_map.SetComputeAtIter(step->stage_id, step->target_stage_id, step->target_iter_id); } @@ -540,14 +530,14 @@ void State::DoComputeRootStep(const ComputeRootStep& step) { std::vector new_iters; for (const Iterator& it : stage->iters) { new_iters.push_back(IteratorNode::make(it->name, Range(), it->iter_type, - it->annotation, &it->ori_iters)); + it->annotation, &it->ori_iters)); } // update attach map StateNode* pstate = CopyOnWrite(); - pstate->stages[step->stage_id] = StageNode::make(stage->op, stage->op_type, - std::move(new_iters), kRoot, stage->auto_unroll_max_step, - stage->storage_offset); + pstate->stages[step->stage_id] = + StageNode::make(stage->op, stage->op_type, std::move(new_iters), kRoot, + stage->auto_unroll_max_step, stage->storage_offset); pstate->attach_map.DeleteStage(step->stage_id); } @@ -560,9 +550,10 @@ void State::DoComputeInlineStep(const ComputeInlineStep& step) { const auto& iter_to_attached_stages = pstate->attach_map->iter_to_attached_stages; for (size_t i = 0; i < stage->iters.size(); ++i) { - CHECK_EQ(iter_to_attached_stages.count(std::make_pair(step->stage_id, i)), 0) - << "Invalid compute_inline: Because there are some other stages " - "that are attached to the target stage"; + CHECK_EQ(iter_to_attached_stages.count(std::make_pair(step->stage_id, i)), + 0) + << "Invalid compute_inline: Because there are some other stages " + "that are attached to the target stage"; } pstate->stages[step->stage_id].CopyOnWrite()->compute_at = kInlined; @@ -576,7 +567,8 @@ void State::DoPackForVecStep(const PackForVecStep& step) { // Common part for steps that add new stages // (e.g. 
CacheReadStep, CacheWriteStep, RfactorStep) void AddStageModificationSteps(size_t step_id, - const std::vector& transform_steps, std::vector* replay_steps) { + const std::vector& transform_steps, + std::vector* replay_steps) { const Step& step = transform_steps[step_id]; if (step->IsInstance() || step->IsInstance()) { @@ -615,14 +607,15 @@ int State::DoCacheReadStep(const CacheReadStep& step, const ComputeDAG& dag) { // target -> target + target_store // Should update target's op, insert new stage, update the later stage's op pstate->stages[step->stage_id].CopyOnWrite()->op = - operator->()->task_dag->ops[step->stage_id]; - pstate->stages.insert(pstate->stages.begin() + step->stage_id + 1, + operator->()->task_dag->ops[step->stage_id]; + pstate->stages.insert( + pstate->stages.begin() + step->stage_id + 1, StageNode::make(operator->()->task_dag->ops[step->stage_id + 1])); for (size_t i = step->stage_id + 2; i < operator->()->stages.size(); ++i) { pstate->stages[i].CopyOnWrite()->op = operator->()->task_dag->ops[i]; } - pstate->attach_map = - operator->()->attach_map.ApplyStageIdOfffset(step->stage_id + 1, 1); + pstate->attach_map = operator->()->attach_map.ApplyStageIdOfffset( + step->stage_id + 1, 1); return step->stage_id + 1; } @@ -637,8 +630,9 @@ int State::DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag) { } } - int last_dag_op_size = pstate->task_dag.defined() ? - pstate->task_dag->ops.size() : dag->ops.size(); + int last_dag_op_size = pstate->task_dag.defined() + ? pstate->task_dag->ops.size() + : dag->ops.size(); dag.ReplayAndGetDAG(replay_steps, &(pstate->task_dag)); int added_ops = pstate->task_dag->ops.size() - last_dag_op_size; CHECK_GE(added_ops, 1); @@ -646,7 +640,8 @@ int State::DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag) { // target -> target_compute + target // Assume target stage has never been applied any steps before cache_write // Should insert new stage, update target stage, update the later stage's op - pstate->stages.insert(pstate->stages.begin() + step->stage_id, + pstate->stages.insert( + pstate->stages.begin() + step->stage_id, StageNode::make(operator->()->task_dag->ops[step->stage_id])); pstate->stages[step->stage_id + 1] = StageNode::make(operator->()->task_dag->ops[step->stage_id + 1]); @@ -657,7 +652,8 @@ int State::DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag) { // for more information // TODO(jcf94): Fix this if (added_ops == 2) { - pstate->stages.insert(pstate->stages.begin() + next_stage_id, + pstate->stages.insert( + pstate->stages.begin() + next_stage_id, StageNode::make(operator->()->task_dag->ops[next_stage_id])); next_stage_id++; } else if (added_ops > 2) { @@ -666,8 +662,8 @@ int State::DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag) { for (size_t i = next_stage_id; i < operator->()->task_dag->ops.size(); ++i) { pstate->stages[i].CopyOnWrite()->op = operator->()->task_dag->ops[i]; } - pstate->attach_map = - operator->()->attach_map.ApplyStageIdOfffset(step->stage_id, added_ops); + pstate->attach_map = operator->()->attach_map.ApplyStageIdOfffset( + step->stage_id, added_ops); return step->stage_id; } @@ -702,18 +698,20 @@ int State::DoRfactorStep(const RfactorStep& step, const ComputeDAG& dag) { // target -> target_compute + target // Should insert new stage, update target stage, update the later stage's op - pstate->stages.insert(pstate->stages.begin() + step->stage_id, + pstate->stages.insert( + pstate->stages.begin() + step->stage_id, 
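// A concrete renumbering sketch (hypothetical stage ids): applying
// cache_read to stage 1 of [A(0), B(1), C(2)] inserts the new buffer as
// stage 2 and shifts C to stage 3; ApplyStageIdOfffset(2, 1) then
// rewrites every stage id >= 2 in the attach map so that existing
// compute_at relations stay valid. cache_write and rfactor insert the
// new stage *before* the target, so they offset ids starting from
// step->stage_id itself.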
StageNode::make(operator->()->task_dag->ops[step->stage_id])); // maintain the compute_at type of target stage - Stage target_stage = StageNode::make(operator->()->task_dag->ops[step->stage_id + 1]); + Stage target_stage = + StageNode::make(operator->()->task_dag->ops[step->stage_id + 1]); target_stage.CopyOnWrite()->compute_at = compute_at_type; pstate->stages[step->stage_id + 1] = target_stage; for (size_t i = step->stage_id + 2; i < operator->()->stages.size(); ++i) { pstate->stages[i].CopyOnWrite()->op = operator->()->task_dag->ops[i]; } - pstate->attach_map = - operator->()->attach_map.ApplyStageIdOfffset(step->stage_id, 1); + pstate->attach_map = operator->()->attach_map.ApplyStageIdOfffset( + step->stage_id, 1); return step->stage_id; } @@ -777,7 +775,6 @@ void State::DoSteps(const std::vector& steps, const ComputeDAG& dag) { } } - void PrintStage(std::ostream* os, int stage_id, const StateNode* state, size_t base_indent, bool delete_trivial_loop) { const Stage& stage = state->stages[stage_id]; @@ -786,15 +783,15 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, for (size_t j = 0; j < base_indent; ++j) { *os << " "; } - *os << stage->op->func_name() << " auto_unroll: " - << stage->auto_unroll_max_step << "\n"; + *os << stage->op->func_name() + << " auto_unroll: " << stage->auto_unroll_max_step << "\n"; } if (stage->storage_offset != 0) { for (size_t j = 0; j < base_indent; ++j) { *os << " "; } - *os << stage->op->func_name() << " storage_offset: " - << stage->storage_offset << "\n"; + *os << stage->op->func_name() + << " storage_offset: " << stage->storage_offset << "\n"; } size_t indent = 0; @@ -802,26 +799,46 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, const Iterator& iter = stage->iters[i]; if (!(delete_trivial_loop && iter->range.defined() && - is_one(iter->range->extent))) { + is_one(iter->range->extent))) { for (size_t j = 0; j < base_indent + indent; ++j) { *os << " "; } switch (iter->annotation) { - case kNone: *os << "for "; break; - case kUnroll: *os << "unroll "; break; - case kParallel: *os << "parallel "; break; - case kVectorize: *os << "vectorize "; break; - case kVThread: *os << "vthread "; break; - case kBlockX: *os << "gpu.blockIdx.x "; break; - case kBlockY: *os << "gpu.blockIdx.y "; break; - case kThreadX: *os << "gpu.threadIdx.x "; break; - case kThreadY: *os << "gpu.threadIdx.y "; break; + case kNone: + *os << "for "; + break; + case kUnroll: + *os << "unroll "; + break; + case kParallel: + *os << "parallel "; + break; + case kVectorize: + *os << "vectorize "; + break; + case kVThread: + *os << "vthread "; + break; + case kBlockX: + *os << "gpu.blockIdx.x "; + break; + case kBlockY: + *os << "gpu.blockIdx.y "; + break; + case kThreadX: + *os << "gpu.threadIdx.x "; + break; + case kThreadY: + *os << "gpu.threadIdx.y "; + break; } if (iter->range.defined()) { *os << iter->name << " (" << iter->range->min << "," - << iter->range->extent << ")" << "\n"; + << iter->range->extent << ")" + << "\n"; } else { - *os << iter->name << " (None)" << "\n"; + *os << iter->name << " (None)" + << "\n"; } indent += 2; @@ -885,6 +902,110 @@ std::string State::ToStr(bool delete_trivial_loop) const { return os.str(); } +void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, + int target_iter_id) { + AttachMapNode* pnode = CopyOnWrite(); + + // delete the current entry of stage + DeleteStageEntry(pnode, stage_id); + + // store the new relation + IterKey iter_key(target_stage_id, target_iter_id); + 
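// The two maps are mirror images and must stay consistent: after a
// hypothetical compute_at of stage 3 at iterator 1 of stage 2,
//   stage_to_attach_iter[3]         == (2, 1)
//   iter_to_attached_stages[(2, 1)] contains 3
// DeleteStageEntry above drops any stale pair for `stage_id` before the
// new relation is recorded below.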
pnode->stage_to_attach_iter[stage_id] = + std::make_pair(target_stage_id, target_iter_id); + pnode->iter_to_attached_stages[iter_key].push_back(stage_id); +} + +void AttachMap::DeleteStage(int stage_id) { + AttachMapNode* pnode = CopyOnWrite(); + + // delete the entry of old stage + DeleteStageEntry(pnode, stage_id); +} + +void AttachMap::ReplaceIters(const std::vector& old_iters, + const std::vector& new_iters) { + AttachMapNode* pnode = CopyOnWrite(); + + CHECK_EQ(old_iters.size(), new_iters.size()); + for (size_t i = 0; i < old_iters.size(); ++i) { + auto entry = pnode->iter_to_attached_stages.find(old_iters[i]); + if (entry == pnode->iter_to_attached_stages.end()) { + continue; + } + + // replace iter in the value of `stage_to_attach_iter` + for (const auto& s : entry->second) { + pnode->stage_to_attach_iter[s] = new_iters[i]; + } + + // replace iter in the key of `iter_to_attached_stages` + std::vector attached_stages = std::move(entry->second); + pnode->iter_to_attached_stages.erase(entry); + pnode->iter_to_attached_stages[new_iters[i]] = std::move(attached_stages); + } +} + +void AttachMap::DeleteStageEntry(AttachMapNode* pnode, int stage_id) { + auto old_entry = pnode->stage_to_attach_iter.find(stage_id); + if (old_entry != pnode->stage_to_attach_iter.end()) { + // delete value in `iter_to_attached_stages` + auto entry2 = pnode->iter_to_attached_stages.find(old_entry->second); + DeleteItem(&entry2->second, stage_id); + if (entry2->second.size() == 0) { + pnode->iter_to_attached_stages.erase(entry2); + } + // delete key in `stage_to_attach_iter` + pnode->stage_to_attach_iter.erase(old_entry); + } +} + +AttachMap AttachMap::ApplyStageIdOfffset(int start_id, int offset) const { + AttachMap map = AttachMapNode::make(); + auto pmap = map.CopyOnWrite(); + for (const auto& x : operator->()->stage_to_attach_iter) { + auto key = x.first; + if (key >= start_id) { + key += offset; + } + auto value = x.second; + if (value.first >= start_id) { + value.first += offset; + } + pmap->stage_to_attach_iter.insert(std::make_pair(key, value)); + } + for (const auto& x : operator->()->iter_to_attached_stages) { + auto key = x.first; + if (key.first >= start_id) { + key.first += offset; + } + auto value = x.second; + for (auto& i : value) { + if (i >= start_id) { + i += offset; + } + } + pmap->iter_to_attached_stages.insert(std::make_pair(key, value)); + } + return map; +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + PrintState(&p->stream, node, true); + }); + +TVM_REGISTER_GLOBAL("ansor.StageGetIterator") + .set_body_typed([](const Stage& stage, int index) { + return stage->iters[index]; + }); + +TVM_REGISTER_GLOBAL("ansor.StageGetIterators") + .set_body_typed([](const Stage& stage) { + return Array(stage->iters); + }); + TVM_REGISTER_GLOBAL("ansor.StateGetStage") .set_body_typed([](const State& state, int index) { return state->stages[index]; @@ -908,21 +1029,20 @@ TVM_REGISTER_GLOBAL("ansor.StateReorder") TVM_REGISTER_GLOBAL("ansor.StateSplit") .set_body_typed([](State state, int stage_id, const Iterator& it, - const Array& lengths, - bool inner_to_outer) { + const Array& lengths, bool inner_to_outer) { std::vector len; for (const auto& i : lengths) { len.push_back(i); } - state.split(stage_id, it, len, inner_to_outer); - return state; + const auto& res = state.split(stage_id, it, len, inner_to_outer); + return Array{state, Array(res)}; }); TVM_REGISTER_GLOBAL("ansor.StateFollowSplit") 
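// Note the convention shared by the bindings above and below: a schedule
// primitive that produces a value returns an Array packing the updated
// State together with that value (the new iterators, the fused iterator,
// or a new stage id wrapped in an IntImm) rather than the State alone.
// A Python-side sketch of the unpacking:
//   s0, its = s0.split(2, ti, [16])
//   s0, fused = s0.fuse(2, [its[0], its[1]])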
.set_body_typed([](State state, int stage_id, const Iterator& it, int src_step_id, int n_split) { - state.follow_split(stage_id, it, src_step_id, n_split); - return state; + const auto& res = state.follow_split(stage_id, it, src_step_id, n_split); + return Array{state, Array(res)}; }); TVM_REGISTER_GLOBAL("ansor.StateFollowFusedSplit") @@ -933,9 +1053,9 @@ TVM_REGISTER_GLOBAL("ansor.StateFollowFusedSplit") for (const auto& i : src_step_ids) { array_src_step_ids.push_back(i->value); } - state.follow_fused_split(stage_id, it, array_src_step_ids, level, - factor_or_nparts); - return state; + const auto& res = state.follow_fused_split( + stage_id, it, array_src_step_ids, level, factor_or_nparts); + return Array{state, Array(res)}; }); TVM_REGISTER_GLOBAL("ansor.StateFuse") @@ -945,36 +1065,35 @@ TVM_REGISTER_GLOBAL("ansor.StateFuse") for (const auto& i : iters) { its.push_back(i); } - state.fuse(stage_id, its); - return state; + const auto& res = state.fuse(stage_id, its); + return Array{state, res}; }); TVM_REGISTER_GLOBAL("ansor.StateVectorize") - .set_body_typed([](State state, int stage_id, - const Iterator& it) { - state.vectorize(stage_id, it); - return state; + .set_body_typed([](State state, int stage_id, const Iterator& it) { + const auto& res = state.vectorize(stage_id, it); + return Array{state, res}; }); TVM_REGISTER_GLOBAL("ansor.StateParallel") - .set_body_typed([](State state, int stage_id, - const Iterator& it) { - state.parallel(stage_id, it); - return state; + .set_body_typed([](State state, int stage_id, const Iterator& it) { + const auto& res = state.parallel(stage_id, it); + return Array{state, res}; }); TVM_REGISTER_GLOBAL("ansor.StateUnroll") - .set_body_typed([](State state, int stage_id, - const Iterator& it, int max_unroll) { - state.unroll(stage_id, it, max_unroll); - return state; + .set_body_typed([](State state, int stage_id, const Iterator& it, + int max_unroll) { + const auto& res = state.unroll(stage_id, it, max_unroll); + return Array{state, res}; }); TVM_REGISTER_GLOBAL("ansor.StateBindThread") - .set_body_typed([](State state, int stage_id, - const Iterator& it, int thread_type) { - state.bind_thread(stage_id, it, IteratorAnnotation(thread_type)); - return state; + .set_body_typed([](State state, int stage_id, const Iterator& it, + int thread_type) { + const auto& res = + state.bind_thread(stage_id, it, IteratorAnnotation(thread_type)); + return Array{state, res}; }); TVM_REGISTER_GLOBAL("ansor.StateComputeAt") @@ -997,8 +1116,8 @@ TVM_REGISTER_GLOBAL("ansor.StateComputeInline") }); TVM_REGISTER_GLOBAL("ansor.StatePackForVec") - .set_body_typed([](State state, int stage_id, - const Iterator& target_iter, int vec_size) { + .set_body_typed([](State state, int stage_id, const Iterator& target_iter, + int vec_size) { state.pack_for_vec(stage_id, target_iter, vec_size); return state; }); @@ -1011,110 +1130,17 @@ TVM_REGISTER_GLOBAL("ansor.StateCacheRead") for (const auto& i : reader_stage_ids) { array_reader_stage_ids.push_back(i->value); } - state.cache_read(stage_id, scope_name, array_reader_stage_ids, task_dag); - return state; + int res = state.cache_read(stage_id, scope_name, array_reader_stage_ids, + task_dag); + return Array{state, IntImm(DataType::Int(32), res)}; }); TVM_REGISTER_GLOBAL("ansor.StateCacheWrite") .set_body_typed([](State state, int stage_id, const std::string& scope_name, const ComputeDAG& task_dag) { - state.cache_write(stage_id, scope_name, task_dag); - return state; + int res = state.cache_write(stage_id, scope_name, task_dag); + return 
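// Python-side usage of the two bindings above, taken from the updated
// unit tests (the names refer to stages of a convolution DAG there):
//   s0, conv_global = s0.cache_write(conv, "global", dag)
//   s0, kernel_global = s0.cache_read(kernel, "global", [conv_global], dag)
// The extra return value is the id of the newly inserted stage.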
Array{state, IntImm(DataType::Int(32), res)}; }); -void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, - int target_iter_id) { - AttachMapNode* pnode = CopyOnWrite(); - - // delete the current entry of stage - DeleteStageEntry(pnode, stage_id); - - // store the new relation - IterKey iter_key(target_stage_id, target_iter_id); - pnode->stage_to_attach_iter[stage_id] = std::make_pair(target_stage_id, - target_iter_id); - pnode->iter_to_attached_stages[iter_key].push_back(stage_id); -} - -void AttachMap::DeleteStage(int stage_id) { - AttachMapNode* pnode = CopyOnWrite(); - - // delete the entry of old stage - DeleteStageEntry(pnode, stage_id); -} - -void AttachMap::ReplaceIters(const std::vector& old_iters, - const std::vector& new_iters) { - AttachMapNode* pnode = CopyOnWrite(); - - CHECK_EQ(old_iters.size(), new_iters.size()); - for (size_t i = 0; i < old_iters.size(); ++i) { - auto entry = pnode->iter_to_attached_stages.find(old_iters[i]); - if (entry == pnode->iter_to_attached_stages.end()) { - continue; - } - - // replace iter in the value of `stage_to_attach_iter` - for (const auto& s : entry->second) { - pnode->stage_to_attach_iter[s] = new_iters[i]; - } - - // replace iter in the key of `iter_to_attached_stages` - std::vector attached_stages = std::move(entry->second); - pnode->iter_to_attached_stages.erase(entry); - pnode->iter_to_attached_stages[new_iters[i]] = std::move(attached_stages); - } -} - -void AttachMap::DeleteStageEntry(AttachMapNode *pnode, int stage_id) { - auto old_entry = pnode->stage_to_attach_iter.find(stage_id); - if (old_entry != pnode->stage_to_attach_iter.end()) { - // delete value in `iter_to_attached_stages` - auto entry2 = pnode->iter_to_attached_stages.find(old_entry->second); - DeleteItem(&entry2->second, stage_id); - if (entry2->second.size() == 0) { - pnode->iter_to_attached_stages.erase(entry2); - } - // delete key in `stage_to_attach_iter` - pnode->stage_to_attach_iter.erase(old_entry); - } -} - -AttachMap AttachMap::ApplyStageIdOfffset(int start_id, int offset) const { - AttachMap map = AttachMapNode::make(); - auto pmap = map.CopyOnWrite(); - for (const auto& x : operator->()->stage_to_attach_iter) { - auto key = x.first; - if (key >= start_id) { - key += offset; - } - auto value = x.second; - if (value.first >= start_id) { - value.first += offset; - } - pmap->stage_to_attach_iter.insert(std::make_pair(key, value)); - } - for (const auto& x : operator->()->iter_to_attached_stages) { - auto key = x.first; - if (key.first >= start_id) { - key.first += offset; - } - auto value = x.second; - for (auto& i : value) { - if (i >= start_id) { - i += offset; - } - } - pmap->iter_to_attached_stages.insert(std::make_pair(key, value)); - } - return map; -} - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { - auto* node = static_cast(ref.get()); - PrintState(&p->stream, node, true); -}); - } // namespace ansor } // namespace tvm diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index 1bae02b3f2c5..b2cff24973bc 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -3,12 +3,13 @@ */ #include "measure.h" // #include -#include #include +#include + +#include #include #include #include -#include // #include "search_policy/search_policy.h" namespace tvm { @@ -25,16 +26,16 @@ TVM_REGISTER_OBJECT_TYPE(RPCRunnerNode); TVM_REGISTER_OBJECT_TYPE(LocalRunnerNode); TVM_REGISTER_OBJECT_TYPE(ProgramMeasurerNode); -const char *ErrorNoToStr[] = { - "NoError", - "InstantiationError", - 
"CompileHostError", - "CompileDeviceError", - "RuntimeDeviceError", - "WrongAnswerError", - "BuildTimeoutError", - "RunTimeoutError", - "UnknownError", +const char* ErrorNoToStr[] = { + "NoError", + "InstantiationError", + "CompileHostError", + "CompileDeviceError", + "RuntimeDeviceError", + "WrongAnswerError", + "BuildTimeoutError", + "RunTimeoutError", + "UnknownError", }; // Maker @@ -52,8 +53,9 @@ MeasureInput MeasureInputNode::copy() const { return MeasureInput(node); } -BuildResult BuildResultNode::make(std::string filename, Array args, int error_no, - std::string error_msg, double time_cost) { +BuildResult BuildResultNode::make(std::string filename, Array args, + int error_no, std::string error_msg, + double time_cost) { auto node = make_object(); node->filename = std::move(filename); node->args = std::move(args); @@ -64,7 +66,8 @@ BuildResult BuildResultNode::make(std::string filename, Array args, } MeasureResult MeasureResultNode::make(Array costs, int error_no, - std::string error_msg, double all_cost, double timestamp) { + std::string error_msg, double all_cost, + double timestamp) { auto node = make_object(); node->costs = std::move(costs); node->error_no = error_no; @@ -84,7 +87,8 @@ MeasureResult MeasureResultNode::copy() const { return MeasureResult(node); } -Builder LocalBuilderNode::make(int timeout, int n_parallel, const std::string& build_func) { +Builder LocalBuilderNode::make(int timeout, int n_parallel, + const std::string& build_func) { auto node = make_object(); node->timeout = timeout; node->n_parallel = n_parallel; @@ -93,9 +97,11 @@ Builder LocalBuilderNode::make(int timeout, int n_parallel, const std::string& b } // LocalBuilder and LocalRunner -Array LocalBuilderNode::Build(const Array &inputs, int verbose) { +Array LocalBuilderNode::Build(const Array& inputs, + int verbose) { if (const auto* f = runtime::Registry::Get("ansor.local_builder.build")) { - Array results = (*f)(inputs, timeout, n_parallel, build_func, verbose); + Array results = + (*f)(inputs, timeout, n_parallel, build_func, verbose); return results; } else { LOG(FATAL) << "ansor.local_builder.build is not registered"; @@ -103,9 +109,10 @@ Array LocalBuilderNode::Build(const Array &inputs, in return Array(); } -Runner RPCRunnerNode::make(const std::string& key, const std::string& host, int port, - int priority, int timeout, int n_parallel, int number, - int repeat, int min_repeat_ms, double cooldown_interval) { +Runner RPCRunnerNode::make(const std::string& key, const std::string& host, + int port, int priority, int timeout, int n_parallel, + int number, int repeat, int min_repeat_ms, + double cooldown_interval) { auto node = make_object(); node->key = key; node->host = host; @@ -124,9 +131,9 @@ Array RPCRunnerNode::Run(const Array& inputs, const Array& build_results, int verbose) { if (const auto* f = runtime::Registry::Get("ansor.rpc_runner.run")) { - Array results = (*f)(inputs, build_results, key, host, port, priority, - timeout, n_parallel, number, repeat, - min_repeat_ms, cooldown_interval, verbose); + Array results = (*f)( + inputs, build_results, key, host, port, priority, timeout, n_parallel, + number, repeat, min_repeat_ms, cooldown_interval, verbose); return results; } else { LOG(FATAL) << "ansor.rpc_runner.run is not registered"; @@ -145,12 +152,13 @@ Runner LocalRunnerNode::make(int timeout, int number, int repeat, return Runner(node); } -Array LocalRunnerNode::Run(const Array& inputs, - const Array& build_results, - int verbose) { +Array LocalRunnerNode::Run( + const Array& inputs, 
const Array& build_results, + int verbose) { if (const auto* f = runtime::Registry::Get("ansor.local_runner.run")) { - Array results = (*f)(inputs, build_results, timeout, number, - repeat, min_repeat_ms, cooldown_interval, verbose); + Array results = + (*f)(inputs, build_results, timeout, number, repeat, min_repeat_ms, + cooldown_interval, verbose); return results; } else { LOG(FATAL) << "ansor.local_runner.run is not registered"; @@ -167,8 +175,9 @@ ProgramMeasurer ProgramMeasurerNode::make(Builder builder, Runner runner, node->runner = std::move(runner); node->callbacks = std::move(callbacks); node->verbose = verbose; - node->max_continous_error = max_continous_error < 0 ? - DEFAULT_MAX_CONTINOUS_ERROR : max_continous_error; + node->max_continous_error = max_continous_error < 0 + ? DEFAULT_MAX_CONTINOUS_ERROR + : max_continous_error; return ProgramMeasurer(node); } @@ -192,12 +201,14 @@ void ProgramMeasurerNode::Measure(const SearchTask& task, batch_size = builder->n_parallel * 2; } - StdCout(verbose) << "Get " << inputs.size() << " programs for measure. (This may take a while)" - << std::endl; + StdCout(verbose) << "Get " << inputs.size() + << " programs for measure. (This may take a while)" + << std::endl; for (size_t i = 0; i < inputs.size(); i += batch_size) { - std::vector input_batch(inputs.begin() + i, - inputs.begin() + std::min(i + batch_size, inputs.size())); + std::vector input_batch( + inputs.begin() + i, + inputs.begin() + std::min(i + batch_size, inputs.size())); std::vector result_batch; // build and run @@ -207,7 +218,8 @@ void ProgramMeasurerNode::Measure(const SearchTask& task, for (size_t j = 0; j < input_batch.size(); ++j) { double flops; if (result_batch[j]->error_no == 0) { - flops = task->compute_dag->flop_ct / FloatArrayMean(result_batch[j]->costs); + flops = + task->compute_dag->flop_ct / FloatArrayMean(result_batch[j]->costs); error_ct = 0; } else { flops = 0.0; @@ -225,8 +237,8 @@ void ProgramMeasurerNode::Measure(const SearchTask& task, if (verbose >= 1) { std::cout << std::fixed << std::setprecision(2); std::cout << "===============================================\n"; - std::cout << "No: " << ct - << "\tGFLOPS: " << flops / 1e9 << " / " << best_flops[workload_key] / 1e9 + std::cout << "No: " << ct << "\tGFLOPS: " << flops / 1e9 << " / " + << best_flops[workload_key] / 1e9 << "\tresults: " << result_batch[j] << "\n"; std::cout << "===============================================\n"; std::cout << input_batch[j]->state << "\n"; @@ -261,7 +273,8 @@ void ProgramMeasurerNode::SilentMeasure(const SearchTask& task, // Call builder and runner Array build_res_batch = builder->Build(input_batch, verbose); - Array result_batch = runner->Run(input_batch, build_res_batch, verbose); + Array result_batch = + runner->Run(input_batch, build_res_batch, verbose); // Store result batch for (auto& res : result_batch) { @@ -271,44 +284,89 @@ void ProgramMeasurerNode::SilentMeasure(const SearchTask& task, // Printing functions TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { - p->stream << "MeasureInput()"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + p->stream << "MeasureInput()"; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { - auto* node = static_cast(ref.get()); - if (node->error_no == kNoError) { - p->stream << "MeasureResult(cost:["; - auto old_config = p->stream.precision(4); - for (size_t i = 0; i < node->costs.size(); ++i) { - auto pf = 
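// The measurement path ProgramMeasurer drives, end to end, as a Python
// sketch mirroring the unit test added in this series:
//   minp  = ansor.MeasureInput(task, s0)
//   bress = ansor.LocalBuilder().build([minp])      # -> [BuildResult]
//   mress = ansor.LocalRunner().run([minp], bress)  # -> [MeasureResult]
// Builders and runners dispatch to packed functions registered from
// Python ("ansor.local_builder.build", "ansor.local_runner.run",
// "ansor.rpc_runner.run"); a missing registration is a fatal error.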
node->costs[i].as(); - CHECK(pf != nullptr); - p->stream << pf->value; - if (i != node->costs.size() - 1) { - p->stream << ","; + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + if (node->error_no == kNoError) { + p->stream << "MeasureResult(cost:["; + auto old_config = p->stream.precision(4); + for (size_t i = 0; i < node->costs.size(); ++i) { + auto pf = node->costs[i].as(); + CHECK(pf != nullptr); + p->stream << pf->value; + if (i != node->costs.size() - 1) { + p->stream << ","; + } + } + p->stream.precision(old_config); + p->stream << "], "; + p->stream << "error_no:" << 0 << ", " + << "all_cost:" << node->all_cost << ", " + << "Tstamp:" << node->timestamp << ")"; + } else { + p->stream << "MeasureResult(" + << "error_type:" << ErrorNoToStr[node->error_no] << ", " + << "error_msg:" << node->error_msg << ", " + << "all_cost:" << node->all_cost << ", " + << "Tstamp:" << node->timestamp << ")"; } - } - p->stream.precision(old_config); - p->stream << "], "; - p->stream << "error_no:" << 0 << ", " - << "all_cost:" << node->all_cost << ", " - << "Tstamp:" << node->timestamp << ")"; - } else { - p->stream << "MeasureResult(" - << "error_type:" << ErrorNoToStr[node->error_no] << ", " - << "error_msg:" << node->error_msg << ", " - << "all_cost:" << node->all_cost << ", " - << "Tstamp:" << node->timestamp << ")"; - } -}); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { - auto* node = static_cast(ref.get()); - p->stream << "BuildResult(" << node->filename << ", " << node->error_no - << ", " << node->time_cost << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "BuildResult(" << node->filename << ", " << node->error_no + << ", " << node->time_cost << ")"; + }); + +TVM_REGISTER_GLOBAL("ansor.MeasureInput") + .set_body_typed([](SearchTask task, State state) { + return MeasureInputNode::make(task, state); + }); + +TVM_REGISTER_GLOBAL("ansor.BuildResult") + .set_body_typed([](std::string filename, Array args, + int error_no, std::string error_msg, double time_cost) { + return BuildResultNode::make(filename, args, error_no, error_msg, + time_cost); + }); + +TVM_REGISTER_GLOBAL("ansor.MeasureResult") + .set_body_typed([](Array costs, int error_no, + std::string error_msg, double all_cost, + double timestamp) { + return MeasureResultNode::make(costs, error_no, error_msg, all_cost, + timestamp); + }); + +TVM_REGISTER_GLOBAL("ansor.BuilderBuild") + .set_body_typed([](const Builder& builder, + const Array& inputs, int verbose) { + return builder->Build(inputs, verbose); + }); + +TVM_REGISTER_GLOBAL("ansor.RunnerRun") + .set_body_typed([](const Runner& runner, const Array& inputs, + const Array& build_results, int verbose) { + return runner->Run(inputs, build_results, verbose); + }); + +TVM_REGISTER_GLOBAL("ansor.LocalBuilder") + .set_body_typed([](int timeout, int n_parallel, + const std::string& build_func) { + return LocalBuilderNode::make(timeout, n_parallel, build_func); + }); + +TVM_REGISTER_GLOBAL("ansor.LocalRunner") + .set_body_typed([](int timeout, int number, int repeat, int min_repeat_ms, + double cooldown_interval) { + return LocalRunnerNode::make(timeout, number, repeat, min_repeat_ms, + cooldown_interval); + }); } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_task.cc b/src/ansor/search_task.cc index b9cda9168b9e..93f3f60ea768 100644 --- a/src/ansor/search_task.cc +++ 
b/src/ansor/search_task.cc @@ -2,20 +2,23 @@ * Copyright (c) 2020 by Contributors */ #include "search_task.h" -#include -#include + #include -#include +#include +#include + #include +#include namespace tvm { namespace ansor { -TVM_REGISTER_OBJECT_TYPE(HardwareParamsNode); -TVM_REGISTER_OBJECT_TYPE(SearchTaskNode); +TVM_REGISTER_NODE_TYPE(HardwareParamsNode); +TVM_REGISTER_NODE_TYPE(SearchTaskNode); HardwareParams HardwareParamsNode::make(int num_cores, int vector_unit_bytes, - int cache_line_bytes, int max_unroll_vec, + int cache_line_bytes, + int max_unroll_vec, int max_innermost_split_factor) { auto node = make_object(); node->num_cores = num_cores; @@ -40,21 +43,19 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams( auto ctx = TVMContext{kDLGPU, 0}; auto func = tvm::runtime::Registry::Get("device_api.gpu"); CHECK(func != nullptr) << "Cannot find GPU device_api in registry"; - auto device_api = static_cast(((*func)()).operator void*()); + auto device_api = + static_cast(((*func)()).operator void*()); tvm::runtime::TVMRetValue ret; - device_api->GetAttr(ctx, - tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, - &ret); + device_api->GetAttr( + ctx, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret); p_hardware_params->max_shared_memory_per_block = ret; - device_api->GetAttr(ctx, - tvm::runtime::DeviceAttrKind::kMaxRegistersPerBlock, - &ret); + device_api->GetAttr( + ctx, tvm::runtime::DeviceAttrKind::kMaxRegistersPerBlock, &ret); p_hardware_params->max_registers_per_block = ret; - device_api->GetAttr(ctx, - tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, + device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, &ret); p_hardware_params->max_threads_per_block = ret; @@ -73,16 +74,15 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams( auto ctx = TVMContext{kDLOpenCL, 0}; auto func = tvm::runtime::Registry::Get("device_api.opencl"); CHECK(func != nullptr) << "Cannot find GPU device_api in registry"; - auto device_api = static_cast(((*func)()).operator void*()); + auto device_api = + static_cast(((*func)()).operator void*()); tvm::runtime::TVMRetValue ret; - device_api->GetAttr(ctx, - tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, - &ret); + device_api->GetAttr( + ctx, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret); p_hardware_params->max_shared_memory_per_block = ret; - device_api->GetAttr(ctx, - tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, + device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, &ret); p_hardware_params->max_threads_per_block = ret; @@ -99,9 +99,10 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams( return HardwareParams(); } - -SearchTask SearchTaskNode::make(ComputeDAG compute_dag, std::string workload_key, - Target target, Target target_host, HardwareParams hardware_params) { +SearchTask SearchTaskNode::make(ComputeDAG compute_dag, + std::string workload_key, Target target, + Target target_host, + HardwareParams hardware_params) { auto node = make_object(); node->compute_dag = std::move(compute_dag); node->workload_key = std::move(workload_key); @@ -116,5 +117,22 @@ SearchTask SearchTaskNode::make(ComputeDAG compute_dag, std::string workload_key return SearchTask(node); } +TVM_REGISTER_GLOBAL("ansor.HardwareParams") + .set_body_typed([](int num_cores, int vector_unit_bytes, + int cache_line_bytes, int max_unroll_vec, + int max_innermost_split_factor) { + return HardwareParamsNode::make(num_cores, vector_unit_bytes, + cache_line_bytes, 
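// Exposed to Python as ansor.HardwareParams. A hypothetical CPU
// configuration (the values are illustrative, not defaults from this
// patch):
//   params = ansor.HardwareParams(num_cores=4, vector_unit_bytes=32,
//                                 cache_line_bytes=64, max_unroll_vec=16,
//                                 max_innermost_split_factor=64)
// For GPU/OpenCL targets, GetDefaultHardwareParams above fills the
// hardware limits by querying the registered device API instead.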
max_unroll_vec, + max_innermost_split_factor); + }); + +TVM_REGISTER_GLOBAL("ansor.SearchTask") + .set_body_typed([](ComputeDAG compute_dag, std::string workload_key, + Target target, Target target_host, + HardwareParams hardware_params) { + return SearchTaskNode::make(compute_dag, workload_key, target, + target_host, hardware_params); + }); + } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_task.h b/src/ansor/search_task.h index 7db98a5197a5..9512013848b6 100644 --- a/src/ansor/search_task.h +++ b/src/ansor/search_task.h @@ -8,13 +8,16 @@ #define TVM_ANSOR_SEARCH_TASK_H_ #include + #include + #include "compute_dag.h" namespace tvm { namespace ansor { -class HardwareParams; class SearchTask; +class HardwareParams; +class SearchTask; /*! \brief Hardware related parameters */ class HardwareParamsNode : public Object { @@ -54,12 +57,11 @@ class HardwareParamsNode : public Object { static HardwareParams GetDefaultHardwareParams(const Target& target, const Target& target_host); - static constexpr const char *_type_key = "ansor.HardwareParams"; + static constexpr const char* _type_key = "ansor.HardwareParams"; TVM_DECLARE_FINAL_OBJECT_INFO(HardwareParamsNode, Object); }; TVM_DEFINE_COW_NODE_REF(HardwareParams, ObjectRef, HardwareParamsNode); - /*! \brief Meta-info for a search task */ class SearchTaskNode : public Object { public: @@ -81,7 +83,7 @@ class SearchTaskNode : public Object { Target target, Target target_host, HardwareParams hardware_params); - static constexpr const char *_type_key = "ansor.SearchTask"; + static constexpr const char* _type_key = "ansor.SearchTask"; TVM_DECLARE_FINAL_OBJECT_INFO(SearchTaskNode, Object); }; TVM_DEFINE_COW_NODE_REF(SearchTask, ObjectRef, SearchTaskNode); diff --git a/tests/cpp/ansor_test.cc b/tests/cpp/ansor_test.cc index 75a6cc00b802..e5a2c98c02a9 100644 --- a/tests/cpp/ansor_test.cc +++ b/tests/cpp/ansor_test.cc @@ -242,15 +242,15 @@ TEST(Step, SplitFuseReorder) { CHECK_EQ(s1->stages[2]->iters[0]->range->extent.as()->value, 512); its = s0.split(2, ti, {16}); + Iterator tio = its[0], tii = its[1]; CHECK_EQ(s0->stages[2]->iters[0]->range->extent.as()->value, 32); CHECK_EQ(s0->stages[2]->iters[1]->range->extent.as()->value, 16); - Iterator tio = its[0], tii = its[1]; its = s0.split(2, tj, {8}); + Iterator tjo = its[0], tji = its[1]; CHECK_EQ(s0->stages[2]->iters[2]->range->extent.as()->value, 64); CHECK_EQ(s0->stages[2]->iters[3]->range->extent.as()->value, 8); - Iterator tjo = its[0], tji = its[1]; s0.reorder(2, {tio, tjo, tk, tji, tii}); CHECK_EQ(s0->stages[2]->iters[2]->range->extent.as()->value, 512); diff --git a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py index 4782f9130cea..da87ea5fe9cf 100644 --- a/tests/python/unittest/test_ansor_common.py +++ b/tests/python/unittest/test_ansor_common.py @@ -73,26 +73,26 @@ def test_state_split_fuse_reorder(): assert ti.range.extent == 512 - s0 = s0.split(2, ti, [16]) + s0, its = s0.split(2, ti, [16]) + tio = its[0] + tii = its[1] assert s0.stage(2).iterator(0).range.extent == 32 assert s0.stage(2).iterator(1).range.extent == 16 - tio = s0.stage(2).iterator(0) - tii = s0.stage(2).iterator(1) - s0 = s0.split(2, tj, [8]) + s0, its = s0.split(2, tj, [8]) + tjo = its[0] + tji = its[1] assert s0.stage(2).iterator(2).range.extent == 64 assert s0.stage(2).iterator(3).range.extent == 8 - tjo = s0.stage(2).iterator(2) - tji = s0.stage(2).iterator(3) s0 = s0.reorder(2, [tio, tjo, tk, tji, tii]) assert s0.stage(2).iterator(2).range.extent == 512 - s0 = 
s0.fuse(2, [tio, tjo]) - assert s0.stage(2).iterator(0).range.extent == 2048 + s0, res_it = s0.fuse(2, [tio, tjo]) + assert res_it.range.extent == 2048 - s1 = s1.split(2, ti, [8, 2]) - s1 = s1.split(2, tj, [32, 8], False) + s1, _ = s1.split(2, ti, [8, 2]) + s1, _ = s1.split(2, tj, [32, 8], False) assert s1.stage(2).iterator(0).range.extent == 32 assert s1.stage(2).iterator(1).range.extent == 8 assert s1.stage(2).iterator(2).range.extent == 2 @@ -186,22 +186,19 @@ def test_state_cache_read_write(): # 0: init state s0 = dag.get_init_state() ori_its = s0.stage(add).iterators() - s0 = s0.split(add, s0.stage(add).iterator(0), [2]) - s0 = s0.reorder(add, [s0.stage(add).iterator(0), ori_its[1], - s0.stage(add).iterator(1), ori_its[2], ori_its[3]]) + s0, its = s0.split(add, s0.stage(add).iterator(0), [2]) + s0 = s0.reorder(add, [its[0], ori_its[1], its[1], ori_its[2], ori_its[3]]) s0 = s0.compute_inline(relu) # 1: simple cache_write with compute_at - s0 = s0.cache_write(conv, "global", dag) - conv_global = conv + s0, conv_global = s0.cache_write(conv, "global", dag) conv += 1 relu += 1 add += 1 s0 = s0.compute_at(conv_global, conv, s0.stage(conv).iterator(3)) # 2: simple cache_read with compute_at - s0 = s0.cache_read(kernel, "global", [conv_global], dag) - kernel_global = kernel + 1 + s0, kernel_global = s0.cache_read(kernel, "global", [conv_global], dag) conv_global += 1 conv += 1 relu += 1 @@ -252,8 +249,7 @@ def test_state_cache_read_write(): # 3: two level cache_read with compute_at # preparing for GPU's shared memory & local memory - s0 = s0.cache_read(pad_temp, "global", [conv_global], dag) - pad_temp_global = pad_temp + 1 + s0, pad_temp_global = s0.cache_read(pad_temp, "global", [conv_global], dag) kernel_data += 1 kernel_split += 1 kernel += 1 @@ -262,8 +258,8 @@ def test_state_cache_read_write(): conv += 1 relu += 1 add += 1 - s0 = s0.cache_read(pad_temp_global, "shared", [conv_global], dag) - pad_temp_shared = pad_temp_global + 1 + s0, pad_temp_shared = s0.cache_read( + pad_temp_global, "shared", [conv_global], dag) kernel_data += 1 kernel_split += 1 kernel += 1 @@ -279,7 +275,7 @@ def test_state_cache_read_write(): # 4: cache_read with multi readers # This stage cannot be compute at to its consumer - s0 = s0.cache_read(data, "global", [pad_temp, add], dag) + s0, data_global = s0.cache_read(data, "global", [pad_temp, add], dag) pad_temp += 1 pad_temp_global += 1 pad_temp_shared += 1 @@ -350,7 +346,7 @@ def test_state_cache_read_write(): # 5: cache_write with multi outputs # See tests/cpp/ansor_test.cc for more information - s0 = s0.cache_write(kernel_split, "global", dag) + s0, _ = s0.cache_write(kernel_split, "global", dag) assert str(s0) == \ "Placeholder: Data, Kernel_data\n" + \ "for ax0 (0,4)\n" + \ @@ -424,40 +420,39 @@ def test_follow_split_follow_fused_split(): s0 = dag.get_init_state() C = 2 - s0 = s0.cache_write(C, "global", dag) - C_global = C + s0, C_global = s0.cache_write(C, "global", dag) C += 1 - s0 = s0.split(C, s0.stage(C).iterator(0), [4, 2, 8, 4], True) + s0, its0 = s0.split(C, s0.stage(C).iterator(0), [4, 2, 8, 4], True) split_step0 = s0.transform_steps_size() - 1 for level in range(1, 6): tmp = s0 - tmp = tmp.follow_split(C_global, tmp.stage( + tmp, _ = tmp.follow_split(C_global, tmp.stage( C_global).iterator(0), split_step0, level) for i in range(0, level): assert tmp.stage(C).iterator(i).range.extent == \ tmp.stage(C_global).iterator(i).range.extent - s0 = s0.split(C, s0.stage(C).iterator(5), [2, 2, 4, 8]) + s0, its1 = s0.split(C, s0.stage(C).iterator(5), [2, 2, 
4, 8]) split_step1 = s0.transform_steps_size() - 1 - its = s0.stage(C).iterators() - s0 = s0.reorder(C, [its[0], its[5], its[1], its[6], its[2], its[7], - its[3], its[8], its[4], its[9]]) - s0 = s0.fuse(C, [s0.stage(C).iterator(0), s0.stage(C).iterator(1)]) - s0 = s0.fuse(C, [s0.stage(C).iterator(1), s0.stage(C).iterator(2)]) - s0 = s0.fuse(C, [s0.stage(C).iterator(2), s0.stage(C).iterator(3)]) - s0 = s0.fuse(C, [s0.stage(C).iterator(3), s0.stage(C).iterator(4)]) - s0 = s0.fuse(C, [s0.stage(C).iterator(4), s0.stage(C).iterator(5)]) + its = [] + for i0, i1 in zip(its0, its1): + its.append(i0) + its.append(i1) + s0 = s0.reorder(C, its) + for i in range(0, 5): + s0, _ = s0.fuse(C, [s0.stage(C).iterator(i), + s0.stage(C).iterator(i+1)]) for level in range(0, 4): tmp = s0 - tmp = tmp.follow_fused_split(C_global, tmp.stage(C_global).iterator(0), - [split_step0, split_step1], level, False) + tmp, _ = tmp.follow_fused_split(C_global, tmp.stage(C_global).iterator(0), + [split_step0, split_step1], level, False) assert tmp.stage(C).iterator(level+1).range.extent == \ tmp.stage(C_global).iterator(0).range.extent for level in range(0, 4): tmp = s0 - tmp = tmp.follow_fused_split(C_global, tmp.stage(C_global).iterator(0), - [split_step0, split_step1], level, True) + tmp, _ = tmp.follow_fused_split(C_global, tmp.stage(C_global).iterator(0), + [split_step0, split_step1], level, True) assert tmp.stage(C).iterator(level+1).range.extent == \ tmp.stage(C_global).iterator(1).range.extent @@ -466,6 +461,49 @@ def test_rfactor(): pass +def test_measure_local_builder_runner(): + dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) + + s0 = dag.get_init_state() + A, B, C = 0, 1, 2 + s0, C_global = s0.cache_write(C, "global", dag) + C += 1 + s0, its0 = s0.split(C, s0.stage(C).iterator(0), [4, 8, 8]) + s0, its1 = s0.split(C, s0.stage(C).iterator(4), [8, 4, 4]) + s0 = s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], + its0[3], its1[3]]) + s0 = s0.compute_at(C_global, C, s0.stage(C).iterator(3)) + s0, _ = s0.split(C_global, s0.stage(C_global).iterator(2), [16]) + s0, B_global = s0.cache_read(B, "global", [C_global], dag) + C += 1 + C_global += 1 + s0 = s0.compute_at(B_global, C_global, s0.stage(C_global).iterator(0)) + s0, A_global = s0.cache_read(A, "global", [C_global], dag) + B += 1 + B_global += 1 + C += 1 + C_global += 1 + s0 = s0.compute_at(A_global, C_global, s0.stage(C_global).iterator(2)) + + tgt = tvm.target.create("llvm") + task = ansor.SearchTask(dag, "test", tgt) + + minp = ansor.MeasureInput(task, s0) + local_builder = ansor.LocalBuilder() + local_runner = ansor.LocalRunner() + + bress = local_builder.build([minp]) + assert bress[0].error_no == 0 + mress = local_runner.run([minp], bress) + assert mress[0].error_no == 0 + + +def test_search_basic(): + dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) + tgt = tvm.target.create("llvm") + task = ansor.SearchTask(dag, "test", tgt) + + if __name__ == "__main__": test_compute_dag_basic() test_state_split_fuse_reorder() @@ -473,3 +511,5 @@ def test_rfactor(): test_state_cache_read_write() test_follow_split_follow_fused_split() test_rfactor() + test_measure_local_builder_runner() + # test_search_basic() From 6b21dc6e7318bb64827382f60ca07871860efa0a Mon Sep 17 00:00:00 2001 From: Chenfan Date: Thu, 4 Jun 2020 21:02:38 +0800 Subject: [PATCH 07/45] Add ansor.auto_schedule() API; First AutoSchedule working version(#8) * Add basic Python support for ansor.auto_schedule * Update AutoSchedule API * Bug fix for get the attach point of a fused iter * 
Update UT after infer bug fix --- python/tvm/ansor/__init__.py | 4 +- python/tvm/ansor/compute_dag.py | 2 +- python/tvm/ansor/cost_model/__init__.py | 20 +++ python/tvm/ansor/cost_model/cost_model.py | 48 ++++++ python/tvm/ansor/state.py | 6 +- python/tvm/ansor/task.py | 162 +++++++++++++++++- src/ansor/auto_schedule.cc | 85 +++++++++ src/ansor/auto_schedule.h | 61 +++++++ src/ansor/cost_model/cost_model.cc | 42 +++-- src/ansor/loop_state.cc | 21 +++ .../search_policy/meta_tile_rewrite_policy.cc | 11 +- src/ansor/search_policy/utils.h | 11 +- src/te/schedule/schedule_dataflow_rewrite.cc | 66 ++++++- tests/python/unittest/test_ansor_common.py | 40 ++++- 14 files changed, 547 insertions(+), 32 deletions(-) create mode 100644 python/tvm/ansor/cost_model/__init__.py create mode 100644 python/tvm/ansor/cost_model/cost_model.py create mode 100644 src/ansor/auto_schedule.cc create mode 100644 src/ansor/auto_schedule.h diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index cb039cf07d5f..70834ba8936f 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -18,5 +18,7 @@ """Namespace for Ansor autoSchedule""" from .compute_dag import ComputeDAG -from .task import SearchTask +from .task import SearchTask, MetaTileRewritePolicy, TuneOption +from .task import auto_schedule from .measure import MeasureInput, LocalBuilder, LocalRunner +from .cost_model import RandomModel diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index a66a181f054c..f3d27884d622 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -52,7 +52,7 @@ def get_init_state(self): """ return _ffi_api.ComputeDAGGetInitState(self) - def apply_steps_from_state(self, state, layout_rewrite_level): + def apply_steps_from_state(self, state, layout_rewrite_level=None): """ Parameters ---------- diff --git a/python/tvm/ansor/cost_model/__init__.py b/python/tvm/ansor/cost_model/__init__.py new file mode 100644 index 000000000000..aac062e964fd --- /dev/null +++ b/python/tvm/ansor/cost_model/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import, redefined-builtin +""" ... """ + +from .cost_model import RandomModel diff --git a/python/tvm/ansor/cost_model/cost_model.py b/python/tvm/ansor/cost_model/cost_model.py new file mode 100644 index 000000000000..aebc89f465a1 --- /dev/null +++ b/python/tvm/ansor/cost_model/cost_model.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import +""" ... """ +import ctypes +import numpy as np + +import tvm._ffi +from tvm.runtime import Object + +from .. import _ffi_api + + +@tvm._ffi.register_object("ansor.CostModel") +class CostModel(Object): + pass + + +@tvm._ffi.register_object("ansor.RandomModel") +class RandomModel(Object): + """ + """ + + def __init__(self): + self.__init_handle_by_constructor__(_ffi_api.RandomModel) + +# A random number generator func for c++'s RandomModel +@tvm._ffi.register_func("ansor.cost_model.random_number") +def random_number(n, return_ptr): + if n == 0: + return + return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_float)) + array_wrapper = np.ctypeslib.as_array(return_ptr, shape=(n,)) + array_wrapper[:] = np.random.uniform(0, 1, (n,)) diff --git a/python/tvm/ansor/state.py b/python/tvm/ansor/state.py index 7de95a8a74af..aa231ab6f4c6 100644 --- a/python/tvm/ansor/state.py +++ b/python/tvm/ansor/state.py @@ -408,9 +408,9 @@ def rfactor(self, stage_id, it, factor_iter_id, task_dag): state : State The updated state """ - state = _ffi_api.StateRfactor(self, stage_id, it, factor_iter_id, - task_dag) - return state + state, new_stage_id = _ffi_api.StateRfactor(self, stage_id, it, + factor_iter_id, task_dag) + return state, new_stage_id def storage_align(self, stage_id, it, factor, offset): """ diff --git a/python/tvm/ansor/task.py b/python/tvm/ansor/task.py index 245cf4c727ae..5fab57c28f48 100644 --- a/python/tvm/ansor/task.py +++ b/python/tvm/ansor/task.py @@ -16,12 +16,16 @@ # under the License. # pylint: disable=unused-import """ ... """ +import random import tvm._ffi from tvm.runtime import Object +from .measure import LocalBuilder, LocalRunner +from .cost_model import RandomModel from . 
import _ffi_api
+
+
@tvm._ffi.register_object("ansor.HardwareParams")
@@ -37,8 +41,9 @@ class HardwareParams(Object):
    def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes,
                 max_unroll_vec, max_innermost_split_factor):
        self.__init_handle_by_constructor__(_ffi_api.HardwareParams, num_cores,
-                                           vector_unit_bytes, cache_line_bytes, max_unroll_vec,
-                                           max_innermost_split_factor)
+                                           vector_unit_bytes, cache_line_bytes,
+                                           max_unroll_vec,
+                                           max_innermost_split_factor)


@tvm._ffi.register_object("ansor.SearchTask")
@@ -56,4 +61,155 @@ class SearchTask(Object):
    def __init__(self, dag, workload_key, target, target_host=None,
                 hardware_params=None):
        self.__init_handle_by_constructor__(_ffi_api.SearchTask, dag,
-                                           workload_key, target, target_host, hardware_params)
+                                           workload_key, target, target_host,
+                                           hardware_params)
+
+
+@tvm._ffi.register_object("ansor.SearchPolicy")
+class SearchPolicy(Object):
+    pass
+
+
+@tvm._ffi.register_object("ansor.MetaTileRewritePolicy")
+class MetaTileRewritePolicy(Object):
+    """ The search policy that searches with meta tiling and random rewrite
+
+    Parameters
+    ----------
+    program_cost_model: CostModel
+        Cost model for complete programs
+    params: dict
+        Parameters of the search policy; see meta_tile_rewrite_policy.h for
+        the definitions, and the code below for the default values
+    seed: int
+        Random seed
+    """
+
+    def __init__(self,
+                 program_cost_model,
+                 params=None,
+                 seed=None):
+        # set default parameters
+        default_params = {
+            "eps_greedy": 0.05,
+
+            'evolutionary_search_population': 2048,
+            'evolutionary_search_num_iters': 15,
+            "evolutionary_search_mutation_prob": 0.85,
+            "evolutionary_search_use_measured_ratio": 0.2,
+
+            'cpu_multi_level_tiling_structure': 'SSRSRS',
+            'gpu_multi_level_tiling_structure': 'SSSRRSRS',
+
+            'disable_change_compute_location': 0,
+        }
+
+        if params is None:
+            params = default_params
+        else:
+            for key, value in default_params.items():
+                if key not in params:
+                    params[key] = value
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.MetaTileRewritePolicy, program_cost_model, params,
+            seed or random.randint(1, 1 << 30))
+
+
+@tvm._ffi.register_object("ansor.TuneOption")
+class TuneOption(Object):
+    """ The options for tuning
+
+    Parameters
+    ----------
+    n_trials: int
+        Number of total measurement trials
+    early_stopping: int
+        Stop the tuning early if there is no improvement after n measurements
+    num_measure_per_iter: int
+        The number of programs to be measured at each iteration
+    verbose: int
+        Verbosity level. 0 means silent.
+    builder: Builder
+        Builder which builds the program
+    runner: Runner
+        Runner which runs the program and measures time costs
+    callbacks: List[MeasureCallback]
+        Callback functions
+    """
+
+    def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_iter=64,
+                 verbose=1, builder='local', runner='local', callbacks=None):
+        if isinstance(builder, str):
+            if builder == 'local':
+                builder = LocalBuilder()
+            else:
+                raise ValueError("Invalid builder: " + builder)
+
+        if isinstance(runner, str):
+            if runner == 'local':
+                runner = LocalRunner()
+            else:
+                raise ValueError("Invalid runner: " + runner)
+
+        if callbacks is None:
+            callbacks = []
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.TuneOption, n_trials, early_stopping, num_measure_per_iter,
+            verbose, builder, runner, callbacks)
+
+
+def auto_schedule(workload, search_policy='default', target=None,
+                  target_host=None, hardware_params=None,
+                  tune_option=None):
+    """ Do auto scheduling for a compute declaration.
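+
+    A minimal usage sketch (illustrative only; assumes `dag` is an
+    ansor.ComputeDAG built from a compute declaration, with an llvm
+    target as in the unit tests):
+
+        tgt = tvm.target.create("llvm")
+        task = SearchTask(dag, "test", tgt)
+        state = auto_schedule(task, tune_option=TuneOption(n_trials=0))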
+
+
+def auto_schedule(workload, search_policy='default', target=None,
+                  target_host=None, hardware_params=None,
+                  tune_option=None):
+    """ Do auto schedule for a compute declaration.
+
+    The workload parameter can be a `string` as workload_key, or directly
+    passing a `SearchTask` as input.
+
+    Parameters
+    ----------
+    workload : Str or SearchTask
+
+    search_policy : Union[SearchPolicy, str]
+
+    target : Target
+
+    target_host : Target = None
+
+    hardware_params : HardwareParams
+
+    tune_option : TuneOption
+
+    Returns
+    -------
+    state : State
+
+    sch : tvm.Schedule
+
+    tensors : List[Tensor]
+    """
+    if isinstance(search_policy, str):
+        if search_policy == 'default':
+            search_policy = MetaTileRewritePolicy(RandomModel())
+        else:
+            raise ValueError("Invalid search policy: " + search_policy)
+
+    if tune_option is None:
+        tune_option = TuneOption(n_trials=0)
+
+    if isinstance(workload, str):
+        sch, tensors = _ffi_api.AutoScheduleByWorkloadKey(
+            workload, target, target_host, search_policy, hardware_params,
+            tune_option)
+        return sch, tensors
+    elif isinstance(workload, SearchTask):
+        state = _ffi_api.AutoScheduleBySearchTask(workload, search_policy,
+                                                  tune_option)
+        return state
+    else:
+        raise ValueError("Invalid workload: " + str(workload) +
+                         ", should be String or SearchTask")
diff --git a/src/ansor/auto_schedule.cc b/src/ansor/auto_schedule.cc
new file mode 100644
index 000000000000..974e7e5d9f58
--- /dev/null
+++ b/src/ansor/auto_schedule.cc
@@ -0,0 +1,85 @@
+#include "auto_schedule.h"
+
+#include <tvm/runtime/registry.h>
+
+#include <string>
+#include <utility>
+
+#include "search_policy/meta_tile_rewrite_policy.h"
+
+namespace tvm {
+namespace ansor {
+
+TVM_REGISTER_NODE_TYPE(TuneOptionNode);
+
+TuneOption TuneOptionNode::make(int n_trials, int early_stopping,
+                                int num_measure_per_iter, int verbose,
+                                Builder builder, Runner runner,
+                                Array<MeasureCallback> callbacks) {
+  auto node = make_object<TuneOptionNode>();
+  node->n_trials = n_trials;
+  node->early_stopping = early_stopping;
+  node->num_measure_per_iter = num_measure_per_iter;
+  node->verbose = verbose;
+  node->builder = std::move(builder);
+  node->runner = std::move(runner);
+  node->callbacks = std::move(callbacks);
+  return TuneOption(node);
+}
+
+State AutoSchedule(SearchTask task, SearchPolicy search_policy,
+                   TuneOption tune_option) {
+  // Search for the best schedule
+  ProgramMeasurer measurer =
+      ProgramMeasurerNode::make(tune_option->builder, tune_option->runner,
+                                tune_option->callbacks, tune_option->verbose);
+
+  return search_policy->Search(
+      task, tune_option->n_trials, tune_option->early_stopping,
+      tune_option->num_measure_per_iter, tune_option->verbose, measurer);
+}
+
+std::pair<te::Schedule, Array<te::Tensor> > AutoSchedule(
+    std::string workload_key, Target target, Target target_host,
+    SearchPolicy search_policy, HardwareParams hardware_params,
+    TuneOption tune_option) {
+  ComputeDAG dag = ComputeDAGNode::make_by_workload_key(workload_key);
+  SearchTask task = SearchTaskNode::make(
+      std::move(dag), std::move(workload_key), std::move(target),
+      std::move(target_host), std::move(hardware_params));
+  // Note: `task` must not be moved here; it is used again below to apply
+  // the transform steps of the returned state.
+  State state = AutoSchedule(task, std::move(search_policy),
+                             std::move(tune_option));
+
+  return task->compute_dag.ApplySteps(state->transform_steps);
+}
+
+TVM_REGISTER_GLOBAL("ansor.TuneOption")
+    .set_body_typed([](int n_trials, int early_stopping,
+                       int num_measure_per_iter, int verbose, Builder builder,
+                       Runner runner, Array<MeasureCallback> callbacks) {
+      return TuneOptionNode::make(n_trials, early_stopping,
+                                  num_measure_per_iter, verbose, builder,
+                                  runner, callbacks);
+    });
+
+TVM_REGISTER_GLOBAL("ansor.AutoScheduleBySearchTask")
+    .set_body_typed([](SearchTask task, SearchPolicy search_policy,
+                       TuneOption tune_option) {
+      return AutoSchedule(task, search_policy,
tune_option); + }); + +TVM_REGISTER_GLOBAL("ansor.AutoScheduleByWorkloadKey") + .set_body_typed([](std::string workload_key, Target target, + Target target_host, SearchPolicy search_policy, + HardwareParams hardware_params, TuneOption tune_option) { + te::Schedule sch; + Array return_tensors; + std::tie(sch, return_tensors) = + AutoSchedule(workload_key, target, target_host, search_policy, + hardware_params, tune_option); + + return Array{sch, return_tensors}; + }); + +} // namespace ansor +} // namespace tvm diff --git a/src/ansor/auto_schedule.h b/src/ansor/auto_schedule.h new file mode 100644 index 000000000000..c354751390fe --- /dev/null +++ b/src/ansor/auto_schedule.h @@ -0,0 +1,61 @@ +/*! + * Copyright (c) 2020 by Contributors + * \file ansor/search_task.h + * \brief Meta information for a search task + */ + +#ifndef TVM_ANSOR_AUTO_SCHEDULE_H_ +#define TVM_ANSOR_AUTO_SCHEDULE_H_ + +#include "measure.h" + +namespace tvm { +namespace ansor { + +/*! \brief Tuning and measurement options */ +class TuneOption; +class TuneOptionNode : public Object { + public: + int n_trials; // Number of total measurement trials + int early_stopping; // Stops early the tuning if no improvement after n + // measurements + int num_measure_per_iter; // The number of programs to be measured at each + // iteration + int verbose; // Verbosity level. 0 means silent. + Builder builder; // Builder which builds the program + Runner runner; // Runner which runs the program and measure time + // costs + Array callbacks; // Callback functions + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("n_trials", &n_trials); + v->Visit("early_stopping", &early_stopping); + v->Visit("num_measure_per_iter", &num_measure_per_iter); + v->Visit("verbose", &verbose); + v->Visit("builder", &builder); + v->Visit("runner", &runner); + v->Visit("callbacks", &callbacks); + } + + static TuneOption make(int n_trials, int early_stopping, + int num_measure_per_iter, int verbose, Builder builder, + Runner runner, Array callbacks); + + static constexpr const char* _type_key = "ansor.TuneOption"; + TVM_DECLARE_FINAL_OBJECT_INFO(TuneOptionNode, Object); +}; +TVM_DEFINE_COW_NODE_REF(TuneOption, ObjectRef, TuneOptionNode); + +/*! 
\brief Auto schedule for a compute declaration */ +State AutoSchedule(SearchTask task, SearchPolicy search_policy, + TuneOption tune_option); + +std::pair > AutoSchedule( + std::string workload_key, Target target, Target target_host, + SearchPolicy search_policy, HardwareParams hardware_params, + TuneOption tune_option); + +} // namespace ansor +} // namespace tvm + +#endif // TVM_ANSOR_AUTO_SCHEDULE_H_ \ No newline at end of file diff --git a/src/ansor/cost_model/cost_model.cc b/src/ansor/cost_model/cost_model.cc index d4304bccb4bf..060d2b703287 100644 --- a/src/ansor/cost_model/cost_model.cc +++ b/src/ansor/cost_model/cost_model.cc @@ -2,8 +2,10 @@ * Copyright (c) 2020 by Contributors */ #include "cost_model.h" -#include + #include +#include + #include namespace tvm { @@ -39,8 +41,7 @@ CostModel RandomModelNode::make() { } void RandomModelNode::Update(const Array& inputs, - const Array& results) { -} + const Array& results) {} void RandomModelNode::Predict(const SearchTask& task, const std::vector& states, @@ -51,14 +52,13 @@ void RandomModelNode::Predict(const SearchTask& task, CostModel MeasureModelNode::make(Builder builder, Runner runner) { ObjectPtr node = make_object(); - node->measurer = ProgramMeasurerNode::make(std::move(builder), std::move(runner), - Array(), 0); + node->measurer = ProgramMeasurerNode::make( + std::move(builder), std::move(runner), Array(), 0); return CostModel(node); } void MeasureModelNode::Update(const Array& inputs, - const Array& results) { -} + const Array& results) {} void MeasureModelNode::Predict(const SearchTask& task, const std::vector& states, @@ -66,7 +66,8 @@ void MeasureModelNode::Predict(const SearchTask& task, std::vector inputs; std::vector results; - inputs.clear(); inputs.reserve(states.size()); + inputs.clear(); + inputs.reserve(states.size()); for (const auto& state : states) { inputs.push_back(MeasureInputNode::make(task, state)); } @@ -79,7 +80,8 @@ void MeasureModelNode::Predict(const SearchTask& task, } } -CostModel PythonBasedCostModelNode::make(PackedFunc update_func, PackedFunc predict_func, +CostModel PythonBasedCostModelNode::make(PackedFunc update_func, + PackedFunc predict_func, PackedFunc predict_stage_func) { auto node = make_object(); node->update_func = std::move(update_func); @@ -89,7 +91,7 @@ CostModel PythonBasedCostModelNode::make(PackedFunc update_func, PackedFunc pred } void PythonBasedCostModelNode::Update(const Array& inputs, - const Array& results) { + const Array& results) { update_func(inputs, results); } @@ -101,14 +103,15 @@ void PythonBasedCostModelNode::Predict(const SearchTask& task, static_cast(scores->data())); } -void PythonBasedCostModelNode::PredictStages(const SearchTask& task, - const std::vector& states, - std::vector* state_scores, - std::vector>* stage_scores) { +void PythonBasedCostModelNode::PredictStages( + const SearchTask& task, const std::vector& states, + std::vector* state_scores, + std::vector>* stage_scores) { int n_states = states.size(); int n_stages = task->compute_dag.GetInitState()->stages.size(); std::vector flatten_scores; - flatten_scores.resize(n_states * n_stages * 2); // Allocate sufficient spaces. + // Allocate sufficient spaces. + flatten_scores.resize(n_states * n_stages * 2); predict_stage_func(task, Array(states.begin(), states.end()), static_cast(flatten_scores.data())); @@ -134,8 +137,9 @@ void PythonBasedCostModelNode::PredictStages(const SearchTask& task, int offset = 0; if ((*state_scores)[i] > -INFINITY) { - // If the score is valid. 
Copy scored stages and assign 0 to placeholder and inlined stages. - // If the score is 0, meaning this state failed to be lowered. Just bypass to update offset. + // If the score is valid. Copy scored stages and assign 0 to placeholder + // and inlined stages. If the score is 0, meaning this state failed to + // be lowered. Just bypass to update offset. for (const Stage& stage : states[i]->stages) { if (stage->op_type == kPlaceholder) { scores.push_back(0); @@ -159,5 +163,9 @@ void PythonBasedCostModelNode::PredictStages(const SearchTask& task, } } +TVM_REGISTER_GLOBAL("ansor.RandomModel").set_body_typed([]() { + return RandomModelNode::make(); +}); + } // namespace ansor } // namespace tvm diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index e18d36e34581..32940da0773a 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -1142,5 +1142,26 @@ TVM_REGISTER_GLOBAL("ansor.StateCacheWrite") return Array{state, IntImm(DataType::Int(32), res)}; }); +TVM_REGISTER_GLOBAL("ansor.StatePragma") + .set_body_typed([](State state, int stage_id, const Iterator& it, + const std::string& pragma_type) { + state.pragma(stage_id, it, pragma_type); + return state; + }); + +TVM_REGISTER_GLOBAL("ansor.StateRfactor") + .set_body_typed([](State state, int stage_id, const Iterator& it, + int factor_iter_id, const ComputeDAG& task_dag) { + int res = state.rfactor(stage_id, it, factor_iter_id, task_dag); + return Array{state, IntImm(DataType::Int(32), res)}; + }); + +TVM_REGISTER_GLOBAL("ansor.StateStorageAlign") + .set_body_typed([](State state, int stage_id, const Iterator& it, + int factor, int offset) { + state.storage_align(stage_id, it, factor, offset); + return state; + }); + } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.cc b/src/ansor/search_policy/meta_tile_rewrite_policy.cc index b3b93ec9c839..b4501804607a 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.cc +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.cc @@ -3,6 +3,7 @@ */ #include "meta_tile_rewrite_policy.h" +#include #include #include #include @@ -586,7 +587,8 @@ class RuleAddRfactor : public StructureSynthesisRule { } }; -void MetaTileRewritePolicyNode::SynthesizeMetaStructure(std::vector* out_states) { +void MetaTileRewritePolicyNode::SynthesizeMetaStructure( + std::vector* out_states) { State init_state = cur_task_->compute_dag.GetInitState(); std::string cpu_multi_level_tiling_structure = GetStringParam(params, "cpu_multi_level_tiling_structure"); @@ -1416,5 +1418,12 @@ void MetaTileRewritePolicyNode::EvolutionarySearch( << std::fixed << std::setprecision(2) << duration << std::endl; } +TVM_REGISTER_GLOBAL("ansor.MetaTileRewritePolicy") +.set_body_typed([](CostModel program_cost_model, + Map params, + int seed){ + return MetaTileRewritePolicyNode::make(program_cost_model, params, seed); +}); + } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_policy/utils.h b/src/ansor/search_policy/utils.h index 05b50775b52d..3337975d7a88 100644 --- a/src/ansor/search_policy/utils.h +++ b/src/ansor/search_policy/utils.h @@ -50,10 +50,15 @@ inline double GetDoubleParam(const Map& attr_dict, // Get a string from a tvm str Map inline std::string GetStringParam(const Map& attr_dict, const std::string& key) { - CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; - auto pstr = attr_dict[key].as(); + CHECK_GT(attr_dict.count(key), 0) + << "Cannot find key: \"" << key << "\" in " << attr_dict; + const 
auto& target = attr_dict[key]; + if (auto pstr = target.as()) { + return pstr->value; + } + auto pstr = target.as(); CHECK(pstr != nullptr); - return pstr->value; + return pstr->data; } // Get a iterator name set from a tvm str Map diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index af72d3b1a1df..04a3f0b25bee 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -461,7 +461,7 @@ void RebaseNonZeroMinLoop(const Schedule& sch) { for (IterVar iv : root_iter_vars) { size_t idx = FindNodeRef(leaf_vars, iv); auto it = s->iter_var_attrs.find(iv); - // don;t need to rebase path that are binded. + // don't need to rebase path that are binded. if (it != s->iter_var_attrs.end() && (*it).second->bind_thread.defined()) { continue; } @@ -614,10 +614,74 @@ void InjectInline(ScheduleNode* sch) { } } +void LegalizeInvalidAttach(ScheduleNode* sch) { + std::unordered_map replace_map; + + for (Stage stage : sch->stages) { + for (Stage s = stage; s.defined();) { + Stage spec = s.GetAttachSpec(); + if (spec->attach_type != kScope) { + break; + } + bool start_attach = false; + IterVar attach_ivar = spec->attach_ivar; + s = spec->attach_stage; + CHECK(attach_ivar.defined()); + CHECK(s.defined()); + + for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) { + IterVar iv = s->leaf_iter_vars[i - 1]; + if (!start_attach && iv.same_as(attach_ivar)) { + start_attach = true; + } + } + if (!start_attach) { + // If the attach_var is fused into another iter_var, update the + // attach_var to be the fused one + // Do this recursively. + IterVar new_attach_ivar = attach_ivar;; + bool updated = true; + while (updated) { + updated = false; + for (const auto& rel : s->relations) { + if (const FuseNode* r = rel.as()) { + if (new_attach_ivar.same_as(r->inner)) { + new_attach_ivar = r->fused; + updated = true; + } + } else if (const SplitNode* r = rel.as()) { + if (new_attach_ivar.same_as(r->parent)) { + new_attach_ivar = r->inner; + updated = true; + } + } + } + replace_map[attach_ivar] = new_attach_ivar; + } + } + } + } + + // remap the parent relation + for (Stage s : sch->stages) { + if (s->attach_type != kScope) continue; + if (replace_map.count(s->attach_ivar)) { + s->attach_ivar = replace_map.at(s->attach_ivar); + } + } + for (Stage s : sch->groups) { + if (s->attach_type != kScope) continue; + if (replace_map.count(s->attach_ivar)) { + s->attach_ivar = replace_map.at(s->attach_ivar); + } + } +} + Schedule Schedule::normalize() { Schedule sn = copy(); InjectInline(sn.operator->()); RebaseNonZeroMinLoop(sn); + LegalizeInvalidAttach(sn.operator->()); return sn; } diff --git a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py index da87ea5fe9cf..8f04d003ff94 100644 --- a/tests/python/unittest/test_ansor_common.py +++ b/tests/python/unittest/test_ansor_common.py @@ -14,6 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
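
The LegalizeInvalidAttach pass added to schedule_dataflow_rewrite.cc above
chases fuse/split relations to a fixed point: when the iterator a stage was
attached to has been fused or split away, the attachment is moved onto the
surviving iterator. A rough Python rendering of that remapping loop (the
relation types below are stand-ins for the te.schedule relation nodes, not
a real TVM API):

    from collections import namedtuple

    Fuse = namedtuple("Fuse", ["inner", "outer", "fused"])
    Split = namedtuple("Split", ["parent", "outer", "inner"])

    def legalize_attach_ivar(attach_ivar, relations):
        """Follow fuse/split relations until no more rules apply."""
        new_ivar, updated = attach_ivar, True
        while updated:
            updated = False
            for rel in relations:
                if isinstance(rel, Fuse) and new_ivar == rel.inner:
                    new_ivar, updated = rel.fused, True   # fused away
                elif isinstance(rel, Split) and new_ivar == rel.parent:
                    new_ivar, updated = rel.inner, True   # split away
        return new_ivar

    # e.g. an attach point on "j" survives "j" being fused into "ij":
    assert legalize_attach_ivar("j", [Fuse("j", "i", "ij")]) == "ij"
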
+import random
+import numpy as np
+
 import tvm
 from tvm import te
 from tvm import ansor
@@ -499,10 +502,43 @@ def test_measure_local_builder_runner():
 
 
 def test_search_basic():
-    dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512))
+    print("Test schedule search with default search policy")
+
+    N = 128
+    A, B, C = matmul_nkkm(N, N, N)
+    dag = ansor.ComputeDAG([A, B, C])
     tgt = tvm.target.create("llvm")
     task = ansor.SearchTask(dag, "test", tgt)
 
+    cost_model = ansor.RandomModel()
+    # seed = random.randint(1, 1 << 30)
+    seed = 944563397
+    search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed)
+    state = ansor.auto_schedule(task, search_policy,
+                                tune_option=ansor.TuneOption(n_trials=2))
+    sch, args = dag.apply_steps_from_state(state)
+
+    print("==== Get State ====")
+    print(state)
+    print("==== Get Python Code ====")
+    print(dag.print_python_code_from_state(state))
+
+    try:
+        print("==== Get Lowered Stmt ====")
+        print(tvm.lower(sch, args, simple_mode=True))
+        mod = tvm.build(sch, args, tgt)
+
+        ctx = tvm.context("llvm", 0)
+        a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(B.dtype), ctx)
+        c = tvm.nd.array(np.zeros((N, N), dtype=C.dtype), ctx)
+        mod(a, b, c)
+        tvm.testing.assert_allclose(c.asnumpy(), np.dot(
+            a.asnumpy(), b.asnumpy()), rtol=1e-5)
+        print("==== Verification passed ====")
+    except Exception:
+        raise Exception("Error encountered with seed: %d" % (seed))
+
 
 if __name__ == "__main__":
     test_compute_dag_basic()
@@ -512,4 +548,4 @@
     test_follow_split_follow_fused_split()
     test_rfactor()
     test_measure_local_builder_runner()
-    # test_search_basic()
+    test_search_basic()

From e52135f37418d86ca547ff090fdbd47cca38e706 Mon Sep 17 00:00:00 2001
From: Chenfan
Date: Fri, 5 Jun 2020 17:25:31 +0800
Subject: [PATCH 08/45] Bug fix & Add python serialization API (#10)

* Delete C++ UT hack since Python is ready
* Add ndarray.non_empty
* Update Serialization python API
---
 include/tvm/runtime/c_runtime_api.h           |  23 +++
 include/tvm/runtime/ndarray.h                 |  12 +-
 python/tvm/ansor/__init__.py                  |   1 +
 python/tvm/ansor/compute_dag.py               |  12 ++
 python/tvm/ansor/measure.py                   |   8 +-
 python/tvm/ansor/serialization.py             |  98 ++++++++++++
 python/tvm/runtime/ndarray.py                 |  33 ++++
 src/ansor/compute_dag.cc                      |   5 +
 .../search_policy/meta_tile_rewrite_policy.h  |  71 +++++----
 src/ansor/serialization.cc                    | 143 +++++++++---------
 src/ansor/serialization.h                     |  31 ++--
 src/runtime/ndarray.cc                        |  80 +++++++++-
 tests/cpp/ansor_test.cc                       |  45 ------
 tests/python/unittest/test_ansor_common.py    |  18 ++-
 14 files changed, 408 insertions(+), 172 deletions(-)
 create mode 100644 python/tvm/ansor/serialization.py

diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h
index 213c7059a5f9..5a32ac7d3d9f 100644
--- a/include/tvm/runtime/c_runtime_api.h
+++ b/include/tvm/runtime/c_runtime_api.h
@@ -384,6 +384,29 @@ TVM_DLL int TVMFuncListGlobalNames(int* out_size, const char*** out_array);
 TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape,
                           int ndim,
                           int dtype_code,
                           int dtype_bits,
                           int dtype_lanes,
                           int device_type,
                           int device_id,
                           TVMArrayHandle* out);
+
+/*!
+ * \brief Allocate an nd-array's memory filled with non-empty (random)
+ *  values, including space of shape, of the given spec.
+ *
+ * \param shape The shape of the array.
+ * \param ndim The number of dimensions of the array.
+ * \param dtype_code The type code of the dtype + * \param dtype_bits The number of bits of dtype + * \param dtype_lanes The number of lanes in the dtype. + * \param device_type The device type of context + * \param device_id The device id of context. + * \param out The output handle. + * \return 0 when success, -1 when failure happens + */ +TVM_DLL int TVMArrayAllocNonEmpty(const tvm_index_t* shape, + int ndim, + int dtype_code, + int dtype_bits, + int dtype_lanes, + int device_type, + int device_id, + TVMArrayHandle* out); + /*! * \brief Free the TVM Array. * \param handle The array handle to be freed. diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index e69d802652fd..9cc66a371974 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -138,7 +138,17 @@ class NDArray : public ObjectRef { * \param ctx The context of the Array. * \return The created Array */ - TVM_DLL static NDArray Empty(std::vector shape, DLDataType dtype, DLContext ctx); + TVM_DLL static NDArray Empty(std::vector shape, + DLDataType dtype, DLContext ctx); + /*! + * \brief Create an NDArray with non-empty values. + * \param shape The shape of the new array. + * \param dtype The data type of the new array. + * \param ctx The context of the Array. + * \return The created Array + */ + TVM_DLL static NDArray NonEmpty(std::vector shape, + DLDataType dtype, DLContext ctx); /*! * \brief Create a NDArray backed by a dlpack tensor. * diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index 70834ba8936f..af1ca24cd4dc 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -22,3 +22,4 @@ from .task import auto_schedule from .measure import MeasureInput, LocalBuilder, LocalRunner from .cost_model import RandomModel +from .serialization import LogToFile, LogReader, best_measure_pair_in_file diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index f3d27884d622..aa50864548a4 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -78,3 +78,15 @@ def print_python_code_from_state(self, state): str : Str """ return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state) + + def infer_bound_from_state(self, state): + """ + Parameters + ---------- + state : State + + Returns + ------- + state : State + """ + return _ffi_api.ComputeDAGInferBoundFromState(self, state) diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index 72dd3cbfcf92..d7d0e64eb14b 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -44,6 +44,10 @@ logger = logging.getLogger('ansor') +@tvm._ffi.register_object("ansor.MeasureCallback") +class MeasureCallback(Object): + pass + @tvm._ffi.register_object("ansor.MeasureInput") class MeasureInput(Object): """ @@ -332,7 +336,7 @@ def timed_func(): if error_no == 0: try: - args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in + args = [ndarray.non_empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args] ctx.sync() @@ -390,7 +394,7 @@ def timed_func(inp, build_res): if error_no == 0: try: - args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in + args = [ndarray.non_empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args] ctx.sync() diff --git a/python/tvm/ansor/serialization.py b/python/tvm/ansor/serialization.py new file mode 100644 index 000000000000..172405ce7ddb --- /dev/null +++ b/python/tvm/ansor/serialization.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software 
Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import
+""" ... """
+import numpy as np
+
+import tvm._ffi
+from tvm.runtime import Object
+
+from .measure import MeasureCallback, MeasureErrorNo
+
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.LogToFile")
+class LogToFile(MeasureCallback):
+    """
+    Parameters
+    ----------
+    filename : Str
+    """
+
+    def __init__(self, filename="ansor_tuning.json"):
+        self.__init_handle_by_constructor__(_ffi_api.LogToFile, filename)
+
+
+@tvm._ffi.register_object("ansor.LogReader")
+class LogReader(Object):
+    def __init__(self, filename="ansor_tuning.json"):
+        self.__init_handle_by_constructor__(_ffi_api.LogReader, filename)
+
+    def read_lines(self, max_size=-1, skip_size=0):
+        inputs, results = _ffi_api.LogReaderReadLines(
+            self, max_size, skip_size)
+        return inputs, results
+
+    def __iter__(self):
+        while True:
+            ret = _ffi_api.LogReaderReadNext(self)
+            if ret is None or not len(ret):
+                break
+            yield ret[0], ret[1]  # (input, result)
+
+
+def best_measure_pair_in_file(filename, workload_key=None, target=None):
+    """ Return the best measurement pair from a log file
+
+    Parameters
+    ----------
+    filename : Str
+
+    workload_key : Str
+
+    target : Str
+
+    Returns
+    -------
+    inp : MeasureInput
+
+    res : MeasureResult
+    """
+    log_reader = LogReader(filename)
+    best_cost = 1e30
+    best_inp = None
+    best_res = None
+
+    for inp, res in log_reader:
+        if res.error_no != MeasureErrorNo.NO_ERROR:
+            continue
+        if workload_key and inp.task.workload_key != workload_key:
+            continue
+        if target and inp.task.target.target_name != target.target_name:
+            continue
+
+        costs = []
+        for value in res.costs:
+            costs.append(value.value)
+        cost = np.mean(costs)
+        if cost < best_cost:
+            best_cost = cost
+            best_inp = inp
+            best_res = res
+
+    return best_inp, best_res
diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py
index 060673dc19c6..967bfcdd3cde 100644
--- a/python/tvm/runtime/ndarray.py
+++ b/python/tvm/runtime/ndarray.py
@@ -279,6 +279,39 @@ def empty(shape, dtype="float32", ctx=context(1, 0)):
     return _make_array(handle, False, False)
 
 
+def non_empty(shape, dtype="float32", ctx=context(1, 0)):
+    """Create a non-empty array of the given shape and device
+
+    Parameters
+    ----------
+    shape : tuple of int
+        The shape of the array
+
+    dtype : type or str
+        The data type of the array.
+
+    ctx : TVMContext
+        The context of the array
+
+    Returns
+    -------
+    arr : tvm.nd.NDArray
+        The created array.
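+
+    Examples
+    --------
+    A minimal sketch (illustrative only; the contents are random, drawn
+    roughly uniformly from [1.0, 10.0)):
+
+    .. code-block:: python
+
+        from tvm.runtime import ndarray
+        arr = ndarray.non_empty((2, 3), "float32")  # defaults to cpu(0)
+        data = arr.asnumpy()                        # random, non-zero values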
+ """ + shape = c_array(tvm_shape_index_t, shape) + ndim = ctypes.c_int(len(shape)) + handle = TVMArrayHandle() + dtype = DataType(dtype) + check_call(_LIB.TVMArrayAllocNonEmpty( + shape, ndim, + ctypes.c_int(dtype.type_code), + ctypes.c_int(dtype.bits), + ctypes.c_int(dtype.lanes), + ctx.device_type, + ctx.device_id, + ctypes.byref(handle))) + return _make_array(handle, False, False) + def from_dlpack(dltensor): """Produce an array from a DLPack tensor without memory copy. Retreives the underlying DLPack tensor's pointer to create an array from the diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index c9415a70c303..7fad0ce5b28a 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -1271,5 +1271,10 @@ TVM_REGISTER_GLOBAL("ansor.ComputeDAGPrintPythonCodeFromState") return dag.PrintStepsAsPython(state->transform_steps); }); +TVM_REGISTER_GLOBAL("ansor.ComputeDAGInferBoundFromState") +.set_body_typed([](const ComputeDAG& dag, const State& state) { + return dag.ReplayAndInferBound(state->transform_steps); +}); + } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.h b/src/ansor/search_policy/meta_tile_rewrite_policy.h index 56a75f8e52fe..ca9033ad866e 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.h +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.h @@ -1,91 +1,100 @@ /*! * Copyright (c) 2020 by Contributors * \file ansor/meta_tile_rewrite_policy.h - * \brief A search policy that search with meta tiling structure and random rewrite + * \brief A search policy that search with meta tiling structure and random + * rewrite */ #ifndef TVM_ANSOR_SEARCH_POLICY_META_TILE_REWRITE_POLICY_H_ #define TVM_ANSOR_SEARCH_POLICY_META_TILE_REWRITE_POLICY_H_ -#include +#include #include -#include #include -#include -#include "search_policy.h" +#include +#include + #include "../cost_model/cost_model.h" #include "../utils.h" - +#include "search_policy.h" namespace tvm { namespace ansor { /*! 
Multi stage search policy */ -class MetaTileRewritePolicyNode: public SearchPolicyNode { +class MetaTileRewritePolicyNode : public SearchPolicyNode { public: CostModel program_cost_model; /* this->params is used to store the following arguments - * int evolutionary_search_population // The population size for evolutionary search - * int evolutionary_search_mutation_prob // The probability of mutation for evolutionary search - * int evolutionary_search_num_iters; // The number of iterations for evolutionary search - * double local_mutation_use_measured_ratio; // The maximum percentage of measured states in the initial - * // population for evolutionary search - * double eps_greedy; // Always allocate this percentage of measurements to random sampled states - * str cpu_multi_level_tiling_structure // The structure of multi-level tiling for CPU - * str gpu_multi_level_tiling_structure // The structure of multi-level tiling for GPU + * int evolutionary_search_population + * The population size for evolutionary search + * int evolutionary_search_mutation_prob + * The probability of mutation for evolutionary search + * int evolutionary_search_num_iters + * The number of iterations for evolutionary search + * double local_mutation_use_measured_ratio + * The maximum percentage of measured states in the initial population + * for evolutionary search + * double eps_greedy + * Always allocate this percentage of measurements to random sampled states + * str cpu_multi_level_tiling_structure + * The structure of multi-level tiling for CPU + * str gpu_multi_level_tiling_structure + * The structure of multi-level tiling for GPU */ Map params; static SearchPolicy make(CostModel program_cost_model, - Map params, - int seed); + Map params, int seed); // Search and make n_trails measurements // Return the best state - State Search(SearchTask task, int n_trials, - int early_stopping, int num_measure_per_iter, - int verbose, ProgramMeasurer measurer) final; + State Search(SearchTask task, int n_trials, int early_stopping, + int num_measure_per_iter, int verbose, + ProgramMeasurer measurer) final; // Continue search. 
This is used by JointTuner std::pair, Array > ContinueSearchOneRound( - SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) final; + SearchTask task, int num_measure, int verbose, + ProgramMeasurer measurer) final; - static constexpr const char *_type_key = "ansor.MetaTileRewritePolicy"; + static constexpr const char* _type_key = "ansor.MetaTileRewritePolicy"; static const std::vector auto_unroll_configs; TVM_DECLARE_FINAL_OBJECT_INFO(MetaTileRewritePolicyNode, SearchPolicyNode); - SearchTask cur_task_; // The current task + SearchTask cur_task_; // The current task - friend class MetaTileRewritePolicyNodeTest; // Hack friend class for UT protected: // Pick states from best states and random states with eps-greedy policy void PickStatesWithEpsGreedy(std::vector* inputs, const std::vector& best_states, - const std::vector& random_states, int remaining_n_trials); + const std::vector& random_states, + int remaining_n_trials); private: // Run one round of the search pipeline - void SearchOneRound(std::vector* best_states, - int num_random_states, std::vector* random_states); + void SearchOneRound(std::vector* best_states, int num_random_states, + std::vector* random_states); // Synthesize meta tiling structure without tile size void SynthesizeMetaStructure(std::vector* out_states); // Sample init population void SampleInitPopulation(const std::vector& meta_structures, - int out_size, std::vector* out_states); + int out_size, std::vector* out_states); // Perform evolutionary search void EvolutionarySearch(const std::vector& init_population, - int num_best_states, std::vector* best_states); + int num_best_states, std::vector* best_states); SplitFactorizationMemo split_memo_; // Memorize split space for Split std::mt19937 rand_gen_; // Random generator int verbose_; // Verbose level (0 means silent) - int num_measure_per_iter_; // The number of states to measure per iteration + int num_measure_per_iter_; // The number of states to measure per iteration - // The set of the already measured states. We store the string format for redundancy check + // The set of the already measured states. We store the string format for + // redundancy check std::unordered_set measured_states_set_; // The array of already measured states. diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index 0e2b0be42587..fc4917409cc0 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -1,15 +1,17 @@ /*! 
* Copyright (c) 2020 by Contributors */ +#include "serialization.h" + #include -// #include #include + #include #include -#include #include #include -#include "serialization.h" +#include + #include "loop_state.h" #include "utils.h" @@ -18,10 +20,10 @@ namespace dmlc { namespace json { -inline std::vector& FloatArrayToVector(std::vector* out, - const ::tvm::Array<::tvm::PrimExpr>& data) { +inline std::vector& FloatArrayToVector( + std::vector* out, const ::tvm::Array<::tvm::PrimExpr>& data) { out->clear(); - for (const auto&x : data) { + for (const auto& x : data) { auto pf = x.as<::tvm::tir::FloatImmNode>(); CHECK(pf != nullptr) << "Cost can only contain float values"; out->push_back(pf->value); @@ -29,10 +31,10 @@ inline std::vector& FloatArrayToVector(std::vector* out, return *out; } -inline std::vector& IntArrayToVector(std::vector* out, - const ::tvm::Array<::tvm::PrimExpr>& data) { +inline std::vector& IntArrayToVector( + std::vector* out, const ::tvm::Array<::tvm::PrimExpr>& data) { out->clear(); - for (const auto&x : data) { + for (const auto& x : data) { auto pi = x.as<::tvm::tir::IntImmNode>(); CHECK(pi != nullptr) << "Cost can only contain int values"; out->push_back(pi->value); @@ -41,15 +43,15 @@ inline std::vector& IntArrayToVector(std::vector* out, } template <> -struct Handler > { +struct Handler> { inline static void Write(dmlc::JSONWriter* writer, - const std::vector<::tvm::ansor::Stage> & data) { + const std::vector<::tvm::ansor::Stage>& data) { // todo(lmzheng): support serialization of Stage writer->BeginArray(false); writer->EndArray(); } inline static void Read(dmlc::JSONReader* reader, - std::vector<::tvm::ansor::Stage> * data) { + std::vector<::tvm::ansor::Stage>* data) { bool s; reader->BeginArray(); s = reader->NextArrayItem(); CHECK(!s); @@ -57,9 +59,9 @@ struct Handler > { }; template <> -struct Handler > { +struct Handler> { inline static void Write(dmlc::JSONWriter* writer, - const std::vector<::tvm::ansor::Step> & data) { + const std::vector<::tvm::ansor::Step>& data) { std::vector tmp; writer->BeginArray(false); for (size_t i = 0; i < data.size(); ++i) { @@ -92,7 +94,8 @@ struct Handler > { writer->WriteArrayItem(ps->iter_id); writer->WriteArrayItem(ps->src_step_id); writer->WriteArrayItem(ps->n_split); - } else if (auto ps = data[i].as<::tvm::ansor::FollowFusedSplitStepNode>()) { + } else if (auto ps = + data[i].as<::tvm::ansor::FollowFusedSplitStepNode>()) { writer->WriteArrayItem(std::string("FFSS")); writer->WriteArrayItem(ps->stage_id); writer->WriteArrayItem(ps->iter_id); @@ -165,7 +168,7 @@ struct Handler > { writer->EndArray(); } inline static void Read(dmlc::JSONReader* reader, - std::vector<::tvm::ansor::Step> * data) { + std::vector<::tvm::ansor::Step>* data) { std::vector int_list; bool s, inner_to_outer, factor_or_nparts; std::string name, scope_name, pragma_type; @@ -183,7 +186,8 @@ struct Handler > { reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&int_list); - data->push_back(::tvm::ansor::ReorderStepNode::make(stage_id, int_list)); + data->push_back( + ::tvm::ansor::ReorderStepNode::make(stage_id, int_list)); } else if (name == "SS") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); @@ -236,8 +240,8 @@ struct Handler > { reader->Read(&iter_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&ann); - data->push_back(::tvm::ansor::AnnotationStepNode::make(stage_id, - iter_id, ::tvm::ansor::IteratorAnnotation(ann))); + data->push_back(::tvm::ansor::AnnotationStepNode::make( + stage_id, iter_id, 
::tvm::ansor::IteratorAnnotation(ann))); } else if (name == "CA") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); @@ -269,8 +273,8 @@ struct Handler > { reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&scope_name); - data->push_back(::tvm::ansor::CacheWriteStepNode::make( - stage_id, scope_name)); + data->push_back( + ::tvm::ansor::CacheWriteStepNode::make(stage_id, scope_name)); } else if (name == "PS") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); @@ -278,8 +282,8 @@ struct Handler > { reader->Read(&iter_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&pragma_type); - data->push_back(::tvm::ansor::PragmaStepNode::make( - stage_id, iter_id, pragma_type)); + data->push_back( + ::tvm::ansor::PragmaStepNode::make(stage_id, iter_id, pragma_type)); } else if (name == "RFS") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); @@ -287,8 +291,8 @@ struct Handler > { reader->Read(&iter_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&factor_iter_id); - data->push_back(::tvm::ansor::RfactorStepNode::make( - stage_id, iter_id, factor_iter_id)); + data->push_back(::tvm::ansor::RfactorStepNode::make(stage_id, iter_id, + factor_iter_id)); } else if (name == "SA") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); @@ -388,7 +392,7 @@ struct Handler<::tvm::ansor::MeasureResultNode> { writer->BeginArray(false); writer->WriteArraySeperator(); writer->BeginArray(false); - for (const auto&x : data.costs) { + for (const auto& x : data.costs) { auto pf = x.as<::tvm::tir::FloatImmNode>(); CHECK(pf != nullptr) << "Cost can only contain float values"; writer->WriteArrayItem(pf->value); @@ -430,7 +434,7 @@ namespace ansor { TVM_REGISTER_OBJECT_TYPE(LogToFileNode); TVM_REGISTER_OBJECT_TYPE(LogReaderNode); -const std::string ansor_LOG_VERSION = "v0.1"; // NOLINT(*) +const std::string ansor_LOG_VERSION = "v0.1"; // NOLINT(*) MeasureCallback LogToFileNode::make(std::string filename) { auto node = make_object(); @@ -438,8 +442,7 @@ MeasureCallback LogToFileNode::make(std::string filename) { return MeasureCallback(node); } -void WriteMeasureRecords(std::ostream* os, - const Array& inputs, +void WriteMeasureRecords(std::ostream* os, const Array& inputs, const Array& results) { dmlc::JSONWriter writer(os); for (size_t i = 0; i < inputs.size(); ++i) { @@ -452,10 +455,8 @@ void WriteMeasureRecords(std::ostream* os, } } -void ReadMeasureRecords(std::string str, - MeasureInputNode* inp, - MeasureResultNode* res, - std::string* log_version) { +void ReadMeasureRecords(std::string str, MeasureInputNode* inp, + MeasureResultNode* res, std::string* log_version) { std::istringstream ss(str); dmlc::JSONReader reader(&ss); std::string key; @@ -474,15 +475,6 @@ void ReadMeasureRecords(std::string str, } } -TVM_REGISTER_GLOBAL("ansor.write_measure_records_to_file") -.set_body([](TVMArgs args, TVMRetValue *ret) { - std::string filename = args[0]; - Array in = args[1]; - Array res = args[2]; - std::ofstream ofs(filename, std::ofstream::app); - WriteMeasureRecords(&ofs, in, res); -}); - void LogToFileNode::callback(const SearchPolicy& policy, const Array& inputs, const Array& results) { @@ -518,8 +510,8 @@ bool LogReaderNode::ReadNext(MeasureInputNode* inp, MeasureResultNode* res) { return false; } -std::pair, Array > LogReaderNode::ReadLines( - int max_size, int skip_size) { +std::pair, Array> LogReaderNode::ReadLines( + int max_size, int skip_size) { auto inp = make_object(); auto res = make_object(); Array inputs; @@ 
-542,32 +534,41 @@ std::pair, Array > LogReaderNode::ReadLines( return std::make_pair(inputs, results); } -std::pair BestMeasurePairInFile(const std::string& filename, - const std::string& workload_key, - const Target& target) { - std::pair best_pair; - double best_cost = 1e30; - - auto inp = make_object(); - auto res = make_object(); - LogReader reader = LogReaderNode::make(filename); - - while (reader->ReadNext(inp.get(), res.get())) { - if (res->error_no != kNoError || inp->task->workload_key != workload_key - || inp->task->target->target_name != target->target_name) { - continue; - } - - double cost = FloatArrayMean(res->costs); - - if (cost < best_cost) { - best_cost = cost; - best_pair = std::make_pair(inp->copy(), res->copy()); - } - } - - return best_pair; -} +TVM_REGISTER_GLOBAL("ansor.write_measure_records_to_file") + .set_body([](TVMArgs args, TVMRetValue* ret) { + std::string filename = args[0]; + Array in = args[1]; + Array res = args[2]; + std::ofstream ofs(filename, std::ofstream::app); + WriteMeasureRecords(&ofs, in, res); + }); + +TVM_REGISTER_GLOBAL("ansor.LogToFile") + .set_body_typed([](const std::string& filename) { + return LogToFileNode::make(filename); + }); + +TVM_REGISTER_GLOBAL("ansor.LogReader") + .set_body_typed([](const std::string& filename) { + return LogReaderNode::make(filename); + }); + +TVM_REGISTER_GLOBAL("ansor.LogReaderReadLines") + .set_body_typed([](LogReader reader, int size, int skip_size) { + const auto& res = reader->ReadLines(size, skip_size); + return Array{res.first, res.second}; + }); + +TVM_REGISTER_GLOBAL("ansor.LogReaderReadNext") + .set_body_typed([](LogReader reader) { + auto inp = make_object(); + auto res = make_object(); + if (reader->ReadNext(inp.get(), res.get())) { + return Array{ObjectRef(inp), ObjectRef(res)}; + } else { + return Array(); + } + }); } // namespace ansor -} // namespace tvm \ No newline at end of file +} // namespace tvm diff --git a/src/ansor/serialization.h b/src/ansor/serialization.h index 96dfb0ee320b..ef4132169652 100644 --- a/src/ansor/serialization.h +++ b/src/ansor/serialization.h @@ -7,11 +7,11 @@ #ifndef TVM_ANSOR_SERIALIZATION_H_ #define TVM_ANSOR_SERIALIZATION_H_ -#include #include +#include #include + #include "measure.h" -// #include "search_policy/search_policy.h" namespace tvm { namespace ansor { @@ -19,23 +19,22 @@ namespace ansor { class LogReader; /*! \brief Log the input and results of measurments to file */ -class LogToFileNode: public MeasureCallbackNode { +class LogToFileNode : public MeasureCallbackNode { public: std::string filename; static MeasureCallback make(std::string filename); /*! \brief Log measure pairs to file. This is called by the search policy */ - void callback(const SearchPolicy& policy, - const Array& inputs, + void callback(const SearchPolicy& policy, const Array& inputs, const Array& results) final; - static constexpr const char *_type_key = "ansor.LogToFile"; + static constexpr const char* _type_key = "ansor.LogToFile"; TVM_DECLARE_FINAL_OBJECT_INFO(LogToFileNode, MeasureCallbackNode); }; /*! \brief Log reader */ -class LogReaderNode: public Object { +class LogReaderNode : public Object { public: std::string filename; std::ifstream infile; @@ -50,27 +49,25 @@ class LogReaderNode: public Object { * \param max_size The maximum number of lines. 
-1 means read all lines * \param skip_size Skip the first n lines */ std::pair, Array > ReadLines( - int max_size = -1, int skip_size = 0); + int max_size = -1, int skip_size = 0); static constexpr const char* _type_key = "ansor.LogReader"; TVM_DECLARE_FINAL_OBJECT_INFO(LogReaderNode, Object); + private: std::string cur_line; }; TVM_DEFINE_MUTABLE_NODE_REF(LogReader, LogReaderNode); -void WriteMeasureRecords(std::ostream* os, - const Array& inputs, +void WriteMeasureRecords(std::ostream* os, const Array& inputs, const Array& results); -void ReadMeasureRecords(std::string str, - MeasureInputNode* inp, - MeasureResultNode* res, - std::string* log_version); +void ReadMeasureRecords(std::string str, MeasureInputNode* inp, + MeasureResultNode* res, std::string* log_version); -std::pair BestMeasurePairInFile(const std::string& filename, - const std::string& workload_key, - const Target& target); +std::pair BestMeasurePairInFile( + const std::string& filename, const std::string& workload_key, + const Target& target); } // namespace ansor } // namespace tvm diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index 800a9167dadc..714535ecc8a6 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -26,6 +26,9 @@ #include #include +#include +#include + #include "runtime_base.h" extern "C" { @@ -180,7 +183,8 @@ NDArray NDArray::CreateView(std::vector shape, DLDataType dtype) { DLManagedTensor* NDArray::ToDLPack() const { return Internal::ToDLPack(get_mutable()); } -NDArray NDArray::Empty(std::vector shape, DLDataType dtype, DLContext ctx) { +NDArray NDArray::Empty(std::vector shape, DLDataType dtype, + DLContext ctx) { NDArray ret = Internal::Create(shape, dtype, ctx); // setup memory content size_t size = GetDataSize(ret.get_mutable()->dl_tensor); @@ -190,6 +194,59 @@ NDArray NDArray::Empty(std::vector shape, DLDataType dtype, DLContext c return ret; } + +NDArray NDArray::NonEmpty(std::vector shape, DLDataType dtype, + DLContext ctx) { + NDArray ret = Internal::Create(shape, dtype, ctx); + NDArray dummy_cpu_arr = Internal::Create(shape, dtype, {kDLCPU, 0}); + + // setup memory content + size_t size = GetDataSize(ret.get_mutable()->dl_tensor); + size_t alignment = GetDataAlignment(ret.get_mutable()->dl_tensor); + dummy_cpu_arr.get_mutable()->dl_tensor.data = + DeviceAPI::Get(dummy_cpu_arr->ctx)->AllocDataSpace( + {kDLCPU, 0}, size, alignment, dummy_cpu_arr->dtype); + size_t elem_cnt = 1; + for (tvm_index_t i = 0; i < dummy_cpu_arr->ndim; ++i) { + elem_cnt *= static_cast(dummy_cpu_arr->shape[i]); + } + + // TODO(..): maybe we could have better solution for assigning values + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution<> dis(1.0, 10.0); + // Use float representation could make us work well on float / int type too. 
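+  // (The loop below dispatches on dtype.bits: 1- and 8-bit values are
+  // stored through narrow integer pointers, 16-bit values are converted
+  // with __truncXfYf2__ into half-precision bit patterns, and 32/64-bit
+  // values take the random double directly, cast to float if needed.)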
+ for (size_t i = 0; i < elem_cnt; ++i) { + if (dummy_cpu_arr->dtype.bits == 1) { + (reinterpret_cast( + dummy_cpu_arr.get_mutable()->dl_tensor.data))[i] = dis(gen); + } else if (dummy_cpu_arr->dtype.bits == 8) { + (reinterpret_cast( + dummy_cpu_arr.get_mutable()->dl_tensor.data))[i] = dis(gen); + } else if (dummy_cpu_arr->dtype.bits == 16) { + (reinterpret_cast( + dummy_cpu_arr.get_mutable()->dl_tensor.data))[i] = + __truncXfYf2__( + static_cast(dis(gen))); + } else if (dummy_cpu_arr->dtype.bits == 32) { + (reinterpret_cast( + dummy_cpu_arr.get_mutable()->dl_tensor.data))[i] = dis(gen); + } else if (dummy_cpu_arr->dtype.bits == 64) { + (reinterpret_cast( + dummy_cpu_arr.get_mutable()->dl_tensor.data))[i] = dis(gen); + } else { + LOG(FATAL) << "Doesn't support dtype code " << dtype.code + << " dtype bits " << dtype.bits; + } + } + ret.get_mutable()->dl_tensor.data = + DeviceAPI::Get(ret->ctx)->AllocDataSpace( + ret->ctx, size, alignment, ret->dtype); + CopyFromTo(&(dummy_cpu_arr.get_mutable()->dl_tensor), + &(ret.get_mutable()->dl_tensor)); + return ret; +} + NDArray NDArray::FromDLPack(DLManagedTensor* tensor) { NDArray::Container* data = new NDArray::Container(); // construct header @@ -257,8 +314,9 @@ int TVMArrayGetTypeIndex(TVMArrayHandle handle, unsigned* out_tindex) { API_END(); } -int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_bits, - int dtype_lanes, int device_type, int device_id, TVMArrayHandle* out) { +int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, + int dtype_bits, int dtype_lanes, int device_type, + int device_id, TVMArrayHandle* out) { API_BEGIN(); DLDataType dtype; dtype.code = static_cast(dtype_code); @@ -272,6 +330,22 @@ int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_ API_END(); } +int TVMArrayAllocNonEmpty(const tvm_index_t* shape, int ndim, int dtype_code, + int dtype_bits, int dtype_lanes, int device_type, + int device_id, TVMArrayHandle* out) { + API_BEGIN(); + DLDataType dtype; + dtype.code = static_cast(dtype_code); + dtype.bits = static_cast(dtype_bits); + dtype.lanes = static_cast(dtype_lanes); + DLContext ctx; + ctx.device_type = static_cast(device_type); + ctx.device_id = device_id; + *out = NDArray::Internal::MoveToFFIHandle( + NDArray::NonEmpty(std::vector(shape, shape + ndim), dtype, ctx)); + API_END(); +} + int TVMArrayFree(TVMArrayHandle handle) { API_BEGIN(); NDArray::Internal::FFIDecRef(handle); diff --git a/tests/cpp/ansor_test.cc b/tests/cpp/ansor_test.cc index e5a2c98c02a9..00e748204fde 100644 --- a/tests/cpp/ansor_test.cc +++ b/tests/cpp/ansor_test.cc @@ -730,51 +730,6 @@ TEST(Feature, ExtractionMatmul) { // TODO(...): Add feature check here } -namespace tvm { -namespace ansor { -class MetaTileRewritePolicyNodeTest { - public: - MetaTileRewritePolicyNodeTest(CostModel cost_model, SearchTask task) { - policy = make_object(); - policy->program_cost_model = std::move(cost_model); - policy->rand_gen_ = std::mt19937(0); - policy->params.Set("cpu_multi_level_tiling_structure", - te::StringImmNode::make("SSRSRS")); - policy->params.Set("disable_change_compute_location", - IntImm(DataType::Int(32), 0)); - policy->cur_task_ = task; - } - void SynthesizeMetaStructure(std::vector* meta_structures) { - policy->SynthesizeMetaStructure(meta_structures); - } - void SampleInitPopulation(const std::vector& meta_structures, - int out_size, std::vector* out_states) { - policy->SampleInitPopulation(meta_structures, out_size, out_states); - } - tvm::runtime::ObjectPtr policy; -}; -} 
// namespace ansor -} // namespace tvm - -TEST(MetaTileRewritePolicy, Basic) { - const auto& tensors = matmul_func(512, 512, 512); - const auto& dag = ComputeDAGNode::make(tensors); - const auto& task = SearchTaskNode::make( - dag, "test", tvm::target::llvm(), tvm::target::llvm(), HardwareParams()); - const auto& cost_model = RandomModelNode::make(); - MetaTileRewritePolicyNodeTest test(cost_model, task); - - std::vector meta_structures, init_population; - test.SynthesizeMetaStructure(&meta_structures); - CHECK_GE(meta_structures.size(), 0); - LOG(INFO) << "SynthesizeMetaStructure get " << meta_structures.size() - << " states."; - test.SampleInitPopulation(meta_structures, 100, &init_population); - CHECK_GE(init_population.size(), 0); - LOG(INFO) << "SampleInitPopulation get " << init_population.size() - << " states."; -} - int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; diff --git a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py index 8f04d003ff94..d701ef5b7bbd 100644 --- a/tests/python/unittest/test_ansor_common.py +++ b/tests/python/unittest/test_ansor_common.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import random +import os import numpy as np import tvm @@ -510,12 +511,17 @@ def test_search_basic(): tgt = tvm.target.create("llvm") task = ansor.SearchTask(dag, "test", tgt) - cost_model = ansor.RandomModel() # seed = random.randint(1, 1 << 30) seed = 944563397 + log_file = "/tmp/_ansor_python_ut_test.json" + + random.seed(seed) + cost_model = ansor.RandomModel() search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed) + tune_option = ansor.TuneOption(n_trials=2, + callbacks=[ansor.LogToFile(log_file)]) state = ansor.auto_schedule(task, search_policy, - tune_option=ansor.TuneOption(n_trials=2)) + tune_option=tune_option) sch, args = dag.apply_steps_from_state(state) print("==== Get State ====") @@ -539,6 +545,14 @@ def test_search_basic(): except Exception: raise Exception("Error encounterd with seed: %d" % (seed)) + inp, res = ansor.best_measure_pair_in_file(log_file) + s0 = dag.infer_bound_from_state(state) + s1 = dag.infer_bound_from_state(inp.state) + assert str(s0) == str(s1) + + if os.path.isfile(log_file): + os.system("rm -rf %s" % log_file) + if __name__ == "__main__": test_compute_dag_basic() From 1fe663878d00fa490b2f2b4ce2a200882aec0317 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 7 Jun 2020 08:31:54 -0700 Subject: [PATCH 09/45] Improve code style, python wrapper and test cases (#11) * Update c++ code style and unit test * Update python State wrapper and test cases --- python/tvm/ansor/__init__.py | 8 + python/tvm/ansor/_ffi_api.py | 3 +- python/tvm/ansor/compute_dag.py | 18 +- python/tvm/ansor/cost_model/__init__.py | 2 +- python/tvm/ansor/cost_model/cost_model.py | 8 +- python/tvm/ansor/loop_state.py | 439 +++++++++++++ python/tvm/ansor/measure.py | 3 +- python/tvm/ansor/serialization.py | 24 +- python/tvm/ansor/state.py | 430 ------------- python/tvm/ansor/task.py | 7 +- python/tvm/ansor/utils.py | 20 +- src/ansor/auto_schedule.cc | 69 +- src/ansor/auto_schedule.h | 30 +- src/ansor/compute_dag.cc | 78 ++- src/ansor/compute_dag.h | 53 +- src/ansor/cost_model/cost_model.cc | 27 +- src/ansor/cost_model/cost_model.h | 42 +- src/ansor/expr_hasher.h | 97 --- src/ansor/feature.cc | 3 +- src/ansor/loop_state.cc | 337 +++++----- src/ansor/loop_state.h | 158 ++++- src/ansor/measure.cc | 138 
++-- src/ansor/measure.h | 69 +- .../search_policy/meta_tile_rewrite_policy.cc | 28 +- .../search_policy/meta_tile_rewrite_policy.h | 94 +-- src/ansor/search_policy/search_policy.cc | 23 +- src/ansor/search_policy/search_policy.h | 27 +- src/ansor/search_policy/utils.cc | 56 +- src/ansor/search_policy/utils.h | 93 ++- src/ansor/search_task.cc | 51 +- src/ansor/search_task.h | 39 +- src/ansor/serialization.cc | 219 ++++--- src/ansor/serialization.h | 56 +- src/ansor/transform_step.cc | 78 +-- src/ansor/transform_step.h | 272 +++----- src/ansor/utils.cc | 23 +- src/ansor/utils.h | 133 ++-- tests/cpp/ansor_test.cc | 597 +----------------- tests/python/unittest/test_ansor_common.py | 515 +-------------- .../python/unittest/test_ansor_compute_dag.py | 66 ++ .../python/unittest/test_ansor_loop_state.py | 475 ++++++++++++++ tests/python/unittest/test_ansor_measure.py | 67 ++ .../unittest/test_ansor_search_policy.py | 81 +++ 43 files changed, 2461 insertions(+), 2595 deletions(-) create mode 100644 python/tvm/ansor/loop_state.py delete mode 100644 python/tvm/ansor/state.py delete mode 100644 src/ansor/expr_hasher.h create mode 100644 tests/python/unittest/test_ansor_compute_dag.py create mode 100644 tests/python/unittest/test_ansor_loop_state.py create mode 100644 tests/python/unittest/test_ansor_measure.py create mode 100644 tests/python/unittest/test_ansor_search_policy.py diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index af1ca24cd4dc..1be7ed404c17 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -17,6 +17,14 @@ # pylint: disable=unused-import, redefined-builtin """Namespace for Ansor autoSchedule""" +from . import compute_dag +from . import measure +from . import serialization +from . import loop_state +from . import task +from . import utils + +# Shortcut from .compute_dag import ComputeDAG from .task import SearchTask, MetaTileRewritePolicy, TuneOption from .task import auto_schedule diff --git a/python/tvm/ansor/_ffi_api.py b/python/tvm/ansor/_ffi_api.py index 177299e67d21..e7b8a59eb83b 100644 --- a/python/tvm/ansor/_ffi_api.py +++ b/python/tvm/ansor/_ffi_api.py @@ -14,7 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""FFI APIs for tvm.ansor""" + +"""Register FFI APIs from C++ for the namespace tvm.ansor""" import tvm._ffi diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index aa50864548a4..0b51ebb402cc 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -14,14 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=unused-import -""" ... """ + +""" Computational graph and its analysis tools """ import tvm._ffi from tvm.runtime import Object - -from .state import State - +from .loop_state import State from . 
import _ffi_api @@ -50,13 +48,13 @@ def get_init_state(self): ------- state : State """ - return _ffi_api.ComputeDAGGetInitState(self) + return State(_ffi_api.ComputeDAGGetInitState(self)) def apply_steps_from_state(self, state, layout_rewrite_level=None): """ Parameters ---------- - state : State + state : StateObject layout_rewrite_level : LayoutRewriteLevel(***) Returns @@ -71,7 +69,7 @@ def print_python_code_from_state(self, state): """ Parameters ---------- - state : State + state : StateObject Returns ------- @@ -83,10 +81,10 @@ def infer_bound_from_state(self, state): """ Parameters ---------- - state : State + state : StateObject Returns ------- - state : State + state : StateObject """ return _ffi_api.ComputeDAGInferBoundFromState(self, state) diff --git a/python/tvm/ansor/cost_model/__init__.py b/python/tvm/ansor/cost_model/__init__.py index aac062e964fd..fc3821cf7998 100644 --- a/python/tvm/ansor/cost_model/__init__.py +++ b/python/tvm/ansor/cost_model/__init__.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=unused-import, redefined-builtin -""" ... """ +""" Cost model that estimates the performance of programs """ from .cost_model import RandomModel diff --git a/python/tvm/ansor/cost_model/cost_model.py b/python/tvm/ansor/cost_model/cost_model.py index aebc89f465a1..a0e586d69cec 100644 --- a/python/tvm/ansor/cost_model/cost_model.py +++ b/python/tvm/ansor/cost_model/cost_model.py @@ -14,14 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=unused-import -""" ... """ + +""" Cost model that estimates the performance of programs """ import ctypes import numpy as np import tvm._ffi from tvm.runtime import Object - from .. import _ffi_api @@ -32,9 +31,6 @@ class CostModel(Object): @tvm._ffi.register_object("ansor.RandomModel") class RandomModel(Object): - """ - """ - def __init__(self): self.__init_handle_by_constructor__(_ffi_api.RandomModel) diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py new file mode 100644 index 000000000000..557bb9d3102b --- /dev/null +++ b/python/tvm/ansor/loop_state.py @@ -0,0 +1,439 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import + +""" +The definition of the "state" in search. A state consists a current loop structure +and the transform history to reach its current loop structure. +To enable flexible manipulation of the loop structure, we implemented a lightweight +loop structure IR (Intermediate Representation) specifically for search. + +Basically this is a simplified TVM IR with schedule primitives. +We don't use the existing TVM IR because +1. 
diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py
new file mode 100644
index 000000000000..557bb9d3102b
--- /dev/null
+++ b/python/tvm/ansor/loop_state.py
@@ -0,0 +1,439 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import
+
+"""
+The definition of the "state" in search. A state consists of a current loop structure
+and the transform history used to reach it.
+To enable flexible manipulation of the loop structure, we implemented a lightweight
+loop structure IR (Intermediate Representation) specifically for search.
+
+Basically this is a simplified TVM IR with schedule primitives.
+We don't use the existing TVM IR because
+1. We want fast incremental change to the loop structures
+2. We want serializable history for replay and backtracking
+3. We may create some macro schedule primitives
+
+After search is done, we will lower this IR to TVM IR with TVM schedule primitives.
+Because we share a lot of common objects during search, the transformation is
+implemented in copy-on-write style. All objects are immutable, which is
+similar to TVM IR.
+"""
+
+import tvm._ffi
+from tvm.runtime import Object
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.Iterator")
+class Iterator(Object):
+    """A for loop iterator"""
+    pass
+
+
+@tvm._ffi.register_object("ansor.Stage")
+class Stage(Object):
+    """A stage in the compute declaration. Similar to tvm.te.schedule.Stage"""
+
+    @property
+    def iters(self):
+        """
+        Returns
+        -------
+        iters : List[Iterator]
+        """
+        if not hasattr(self, "iterators_cache"):
+            setattr(self, "iterators_cache", _ffi_api.StageGetIterators(self))
+        return getattr(self, "iterators_cache")
+
+    def iter(self, index):
+        """
+        Parameters
+        ----------
+        index : Int
+
+        Returns
+        -------
+        iter : Iterator
+        """
+        return _ffi_api.StageGetIterator(self, index)
+
+
+@tvm._ffi.register_object("ansor.State")
+class StateObject(Object):
+    """The internal State object"""
+    def __eq__(self, other):
+        return _ffi_api.StateEqual(self, other)
+
+
+class State:
+    """
+    A state in the search process. It consists of the current loop structure
+    and the history steps to reach this state.
+
+    Notes
+    -----
+    This is a wrapper class of StateObject to deal with copy-on-write property
+    """
+    def __init__(self, state_object):
+        self.state_object = state_object
+
+        self.stages_cache = None
+
+    def clear_cache(self):
+        self.stages_cache = None
+
+    def copy(self):
+        return State(self.state_object)
+
+    @property
+    def stages(self):
+        """
+        Returns
+        -------
+        stages : List[Stage]
+        """
+        if not self.stages_cache:
+            self.stages_cache = _ffi_api.StateGetStages(self.state_object)
+        return self.stages_cache
+
+    def transform_steps_size(self):
+        """ Return the size of transform_steps
+        """
+        return _ffi_api.StateGetTransformStepsSize(self.state_object)
+
+    def reorder(self, stage_id, order):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to reorder
+        order : List[Iterator]
+            Iterators in the expected order
+        """
+        self.state_object = _ffi_api.StateReorder(self.state_object, stage_id, order)
+        self.clear_cache()
+
+    def split(self, stage_id, it, lengths, inner_to_outer=True):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to split
+        it : Iterator
+            The iterator to split
+        lengths: List[Int]
+            The split factors
+        inner_to_outer: Bool
+            True to use `factor` to split from inner to outer,
+            False to use `nparts` to split from outer to inner
+
+        Returns
+        -------
+        res_its : List[Iterator]
+            The new Iterators produced by the split
+        """
+        self.state_object, res = _ffi_api.StateSplit(self.state_object, stage_id, it, lengths,
+                                                     inner_to_outer)
+        self.clear_cache()
+        return res
+
+    def follow_split(self, stage_id, it, src_step_id, n_split):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to split
+        it : Iterator
+            The iterator to split
+        src_step_id : Int
+            The index of the split step to follow in the history
+        n_split : Int
+            The number of split levels
+
+        Returns
+        -------
+        res_its : List[Iterator]
+            The new Iterators produced by the split
+        """
+        self.state_object, res = _ffi_api.StateFollowSplit(self.state_object, stage_id, it,
+                                                           src_step_id, n_split)
+        self.clear_cache()
+        return res
+
+    def follow_fused_split(self, stage_id, it, src_step_ids, level,
+                           factor_or_nparts):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to split
+        it : Iterator
+            The iterator to split
+        src_step_ids : List[Int]
+            The indices of the split steps to follow in the history
+        level : Int
+            Use the length in this split level
+        factor_or_nparts : Bool
+            True to use `factor` for split from inner to outer,
+            False to use `nparts` for split from outer to inner
+
+        Returns
+        -------
+        res_its : List[Iterator]
+            The new Iterators produced by the split
+        """
+        self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object, stage_id, it,
+                                                                src_step_ids, level,
+                                                                factor_or_nparts)
+        self.clear_cache()
+        return res
+
+    def fuse(self, stage_id, iters):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to fuse
+        iters : List[Iterator]
+            The iterators to be fused
+
+        Returns
+        -------
+        res_it : Iterator
+            The fused Iterator
+        """
+        self.state_object, res = _ffi_api.StateFuse(self.state_object, stage_id, iters)
+        self.clear_cache()
+        return res
+
+    def vectorize(self, stage_id, it):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to vectorize
+        it : Iterator
+            The iterator to be vectorized
+
+        Returns
+        -------
+        res_it : Iterator
+            The vectorized Iterator
+        """
+        self.state_object, res = _ffi_api.StateVectorize(self.state_object, stage_id, it)
+        self.clear_cache()
+        return res
+
+    def parallel(self, stage_id, it):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to parallel
+        it : Iterator
+            The iterator to be parallelized
+
+        Returns
+        -------
+        res_it : Iterator
+            The parallelized Iterator
+        """
+        self.state_object, res = _ffi_api.StateParallel(self.state_object, stage_id, it)
+        self.clear_cache()
+        return res
+
+    def unroll(self, stage_id, it, max_unroll=-1):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to unroll
+        it : Iterator
+            The iterator to be unrolled
+        max_unroll: Int
+            The maximum length of the iterator that can be unrolled
+
+        Returns
+        -------
+        res_it : Iterator
+            The unrolled Iterator
+        """
+        self.state_object, res = _ffi_api.StateUnroll(self.state_object, stage_id, it, max_unroll)
+        self.clear_cache()
+        return res
+
+    def bind_thread(self, stage_id, it, thread_name):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to bind
+        it : Iterator
+            The iterator to be bound
+        thread_name : str
+            The name of the thread (e.g. "blockIdx.x", "threadIdx.y", "vthread")
+
+        Returns
+        -------
+        res_it : Iterator
+            The bound Iterator
+        """
+        trans_table = {
+            "vthread": 4,
+            "blockIdx.x": 5,
+            "threadIdx.x": 6,
+            "blockIdx.y": 7,
+            "threadIdx.y": 8,
+        }
+        thread_id = trans_table[thread_name]
+
+        self.state_object, res = _ffi_api.StateBindThread(self.state_object, stage_id, it,
+                                                          thread_id)
+        self.clear_cache()
+        return res
+
+    def compute_at(self, stage_id, target_stage_id, target_iter):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of source stage
+        target_stage_id : Int
+            The index of the target stage of compute_at
+        target_iter : Iterator
+            The target Iterator of compute_at
+        """
+        self.state_object = _ffi_api.StateComputeAt(self.state_object, stage_id,
+                                                    target_stage_id, target_iter)
+        self.clear_cache()
+
+    def compute_root(self, stage_id):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to compute root
+        """
+        self.state_object = _ffi_api.StateComputeRoot(self.state_object, stage_id)
+        self.clear_cache()
+
+    def compute_inline(self, stage_id):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to compute inline
+        """
+        self.state_object = _ffi_api.StateComputeInline(self.state_object, stage_id)
+        self.clear_cache()
+
+    def cache_read(self, stage_id, scope_name, reader_stage_ids, task_dag):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to do cache_read
+        scope_name : Str
+        reader_stage_ids : List[Int]
+        task_dag : ComputeDAG
+
+        Returns
+        -------
+        new_stage_id : Int
+            The added stage id
+        """
+        self.state_object, new_stage_id = _ffi_api.StateCacheRead(self.state_object, stage_id,
+                                                                  scope_name, reader_stage_ids,
+                                                                  task_dag)
+        self.clear_cache()
+        return int(new_stage_id)
+
+    def cache_write(self, stage_id, scope_name, task_dag):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to do cache_write
+        scope_name : Str
+        task_dag : ComputeDAG
+
+        Returns
+        -------
+        new_stage_id : Int
+            The added stage id
+        """
+        self.state_object, new_stage_id = _ffi_api.StateCacheWrite(self.state_object, stage_id,
+                                                                   scope_name, task_dag)
+        self.clear_cache()
+        return int(new_stage_id)
+
+    def pragma(self, stage_id, it, pragma_type):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to add pragma
+        it : Iterator
+            The iterator to add pragma
+        pragma_type : Str
+        """
+        self.state_object = _ffi_api.StatePragma(self.state_object, stage_id, it, pragma_type)
+        self.clear_cache()
+
+    def rfactor(self, stage_id, it, factor_iter_id, task_dag):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to do reduction factor
+        it : Iterator
+        factor_iter_id : Int
+        task_dag : ComputeDAG
+
+        Returns
+        -------
+        new_stage_id : Int
+            The added stage id
+        """
+        self.state_object, new_stage_id = _ffi_api.StateRfactor(self.state_object, stage_id, it,
+                                                                factor_iter_id, task_dag)
+        self.clear_cache()
+        return int(new_stage_id)
+
+    def storage_align(self, stage_id, it, factor, offset):
+        """
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to do storage align
+        it : Iterator
+        factor : Int
+        offset : Int
+        """
+        self.state_object = _ffi_api.StateStorageAlign(self.state_object, stage_id, it, factor, offset)
+        self.clear_cache()
+
+    def __str__(self):
+        return str(self.state_object)
+
+    def __eq__(self, other):
+        return _ffi_api.StateEqual(self.state_object, other.state_object)
diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py
index d7d0e64eb14b..5438edfaa6b2 100644
--- a/python/tvm/ansor/measure.py
+++ b/python/tvm/ansor/measure.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=unused-import
+
 """Distributed measurement infrastructure to measure the runtime costs of tensor programs
 
 These functions are responsible for building the tvm module, uploading it to
@@ -38,7 +38,6 @@
 from ..contrib import tar, ndk
 from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout, request_remote, check_remote
 from .compute_dag import LayoutRewriteLevel
-
 from . import _ffi_api
 
 logger = logging.getLogger('ansor')
diff --git a/python/tvm/ansor/serialization.py b/python/tvm/ansor/serialization.py
index 172405ce7ddb..bd9a69944057 100644
--- a/python/tvm/ansor/serialization.py
+++ b/python/tvm/ansor/serialization.py
@@ -14,21 +14,22 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=unused-import
-""" ... """
+
+"""Tuning log I/O utilities"""
+
 import numpy as np
 
 import tvm._ffi
 from tvm.runtime import Object
-
 from .measure import MeasureCallback, MeasureErrorNo
-
 from . import _ffi_api
 
 
 @tvm._ffi.register_object("ansor.LogToFile")
 class LogToFile(MeasureCallback):
     """
+    A measurement callback that writes tuning logs into a file
+
     Parameters
     ----------
     filename : Str
@@ -40,6 +41,13 @@ def __init__(self, filename="ansor_tuning.json"):
 
 @tvm._ffi.register_object("ansor.LogReader")
 class LogReader(Object):
+    """
+    Reader of the json log file
+
+    Parameters
+    ----------
+    filename : Str
+    """
     def __init__(self, filename="ansor_tuning.json"):
         self.__init_handle_by_constructor__(_ffi_api.LogReader, filename)
 
@@ -56,21 +64,23 @@ def __iter__(self):
         yield ret[0], ret[1]  # (input, result)
 
 
+def write_measure_records_to_file(filename, inputs, results):
+    """Write (append) measure records to a file"""
+    _ffi_api.WriteMeasureRecordsToFile(filename, inputs, results)
+
+
 def best_measure_pair_in_file(filename, workload_key=None, target=None):
-    """ Return best results form log file
+    """ Return the best result from a log file
 
     Parameters
     ----------
     filename : Str
-    workload_key : Str
-    target : Str
 
     Returns
     -------
     inp : MeasureInput
-    res : MeasureResult
     """
     log_reader = LogReader(filename)
diff --git a/python/tvm/ansor/state.py b/python/tvm/ansor/state.py
deleted file mode 100644
index aa231ab6f4c6..000000000000
--- a/python/tvm/ansor/state.py
+++ /dev/null
@@ -1,430 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=unused-import
-""" ... """
-
-import tvm._ffi
-from tvm.runtime import Object
-
-from .
import _ffi_api - - -@tvm._ffi.register_object("ansor.Iterator") -class Iterator(Object): - """ ... - """ - pass - - -@tvm._ffi.register_object("ansor.Stage") -class Stage(Object): - """ ... - """ - - def iterator(self, index): - """ - Parameters - ---------- - index : Int - - Returns - ------- - iter : Iterator - """ - return _ffi_api.StageGetIterator(self, index) - - def iterators(self): - """ - Returns - ------- - iters : List[Iterator] - """ - return _ffi_api.StageGetIterators(self) - - -@tvm._ffi.register_object("ansor.State") -class State(Object): - """ ... - """ - - def stage(self, index): - """ - Parameters - ---------- - index : Int - - Returns - ------- - stage : Stage - """ - return _ffi_api.StateGetStage(self, index) - - def transform_steps_size(self): - """ Return the size of transform_steps - """ - return _ffi_api.StateGetTransformStepsSize(self) - - def reorder(self, stage_id, order): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - order : List[Iterator] - Iterators in expected order - - Returns - ------- - state : State - The updated state - """ - state = _ffi_api.StateReorder(self, stage_id, order) - return state - - def split(self, stage_id, it, lengths, inner_to_outer=True): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - it : Iterator - The target Iterator - lengths: List[Int] - The split factor - inner_to_outer: Bool - True to use `factor` for split from inner to outer, - False to use `nparts` for split from outer to inner - - Returns - ------- - state : State - The updated state - res_its : List[Iterator] - The splited Iterators result - """ - state, res_its = _ffi_api.StateSplit(self, stage_id, it, lengths, - inner_to_outer) - return state, res_its - - def follow_split(self, stage_id, it, src_step_id, n_split): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - it : Iterator - The target Iterator - src_step_id : Int - The index of target step that this split follows - n_split : Int - Indecate how many level needs to be split out - - Returns - ------- - state : State - The updated state - res_its : List[Iterator] - The splited Iterators result - """ - state, res_its = _ffi_api.StateFollowSplit(self, stage_id, it, - src_step_id, n_split) - return state, res_its - - def follow_fused_split(self, stage_id, it, src_step_ids, level, - factor_or_nparts): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - it : Iterator - The target Iterator - src_step_ids : List[Int] - The indexes of target step that this split follows - level : Int - factor_or_nparts : Bool - True to use `factor` for split from inner to outer, - False to use `nparts` for split from outer to inner - - Returns - ------- - state : State - The updated state - res_its : List[Iterator] - The splited Iterators result - """ - state, res_its = _ffi_api.StateFollowFusedSplit(self, stage_id, it, - src_step_ids, level, - factor_or_nparts) - return state, res_its - - def fuse(self, stage_id, iters): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - iters : List[Iterator] - The target Iterators to be fused - - Returns - ------- - state : State - The updated state - res_it : Iterator - The fused Iterator - """ - state, res_it = _ffi_api.StateFuse(self, stage_id, iters) - return state, res_it - - def vectorize(self, stage_id, it): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - it : Iterator - The target Iterator to be vectorized - - Returns - 
------- - state : State - The updated state - res_it : Iterator - The vectorized Iterator - """ - state, res_it = _ffi_api.StateVectorize(self, stage_id, it) - return state, res_it - - def parallel(self, stage_id, it): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - it : Iterator - The target Iterator to be paralleled - - Returns - ------- - state : State - The updated state - res_it : Iterator - The paralleled Iterator - """ - state, res_it = _ffi_api.StateParallel(self, stage_id, it) - return state, res_it - - def unroll(self, stage_id, it, max_unroll=-1): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - it : Iterator - The target Iterator to be unrolled - max_unroll : Int - - Returns - ------- - state : State - The updated state - res_it : Iterator - The unrolled Iterator - """ - state, res_it = _ffi_api.StateUnroll(self, stage_id, it, max_unroll) - return state, res_it - - def bind_thread(self, stage_id, it, thread_type): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - it : Iterator - The target Iterator to be vectorized - thread_type : ... - Supported type: kVThread, kBlockX, kThreadX, kThreadY - - Returns - ------- - state : State - The updated state - res_it : Iterator - The thread binded Iterator - """ - state, res_it = _ffi_api.StateBindThread(self, stage_id, it, - thread_type) - return state, res_it - - def compute_at(self, stage_id, target_stage_id, target_iter): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - target_stage_id : Int - The index of compute at target stage - target_iter : Iterator - The target Iterator to be compute at - - Returns - ------- - state : State - The updated state - """ - return _ffi_api.StateComputeAt(self, stage_id, target_stage_id, - target_iter) - - def compute_root(self, stage_id): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - - Returns - ------- - state : State - The updated state - """ - return _ffi_api.StateComputeRoot(self, stage_id) - - def compute_inline(self, stage_id): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - - Returns - ------- - state : State - The updated state - """ - return _ffi_api.StateComputeInline(self, stage_id) - - def pack_for_vec(self, stage_id, target_iter, vec_size): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - target_iter : Iterator - The target Iterator - vec_size : Int - - Returns - ------- - state : State - The updated state - """ - return _ffi_api.StatePackForVec(self, stage_id, target_iter, vec_size) - - def cache_read(self, stage_id, scope_name, reader_stage_ids, task_dag): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - scope_name : Str - reader_stage_ids : List[Int] - task_dag : ComputeDAG - - Returns - ------- - state : State - The updated state - new_stage_id : Int - The added staged id - """ - state, new_stage_id = _ffi_api.StateCacheRead(self, stage_id, - scope_name, reader_stage_ids, task_dag) - return state, int(new_stage_id) - - def cache_write(self, stage_id, scope_name, task_dag): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - scope_name : Str - task_dag : ComputeDAG - - Returns - ------- - state : State - The updated state - new_stage_id : Int - The added staged id - """ - state, new_stage_id = _ffi_api.StateCacheWrite(self, stage_id, - scope_name, task_dag) - return state, int(new_stage_id) - - def pragma(self, stage_id, 
it, pragma_type): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - it : Iterator - The target Iterator - pragma_type : Str - - Returns - ------- - state : State - The updated state - """ - return _ffi_api.StatePragma(self, stage_id, it, pragma_type) - - def rfactor(self, stage_id, it, factor_iter_id, task_dag): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - it : Iterator - factor_iter_id : Int - task_dag : ComputeDAG - - Returns - ------- - state : State - The updated state - """ - state, new_stage_id = _ffi_api.StateRfactor(self, stage_id, it, - factor_iter_id, task_dag) - return state, new_stage_id - - def storage_align(self, stage_id, it, factor, offset): - """ - Parameters - ---------- - stage_id : Int - The index of target stage - it : Iterator - factor : Int - offset : Int - - Returns - ------- - state : State - The updated state - """ - return _ffi_api.StateStorageAlign(self, stage_id, it, factor, offset) diff --git a/python/tvm/ansor/task.py b/python/tvm/ansor/task.py index 5fab57c28f48..affcf4a6e195 100644 --- a/python/tvm/ansor/task.py +++ b/python/tvm/ansor/task.py @@ -14,15 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=unused-import -""" ... """ + +"""Meta information for a search task""" + import random import tvm._ffi from tvm.runtime import Object from .measure import LocalBuilder, LocalRunner from .cost_model import RandomModel - from . import _ffi_api @@ -137,7 +137,6 @@ class TuneOption(Object): callbacks: List[MeasureCallback] Callback functions """ - def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_iter=64, verbose=1, builder='local', runner='local', callbacks=None): if isinstance(builder, str): diff --git a/python/tvm/ansor/utils.py b/python/tvm/ansor/utils.py index 0216549c184a..5ed9bd46d355 100644 --- a/python/tvm/ansor/utils.py +++ b/python/tvm/ansor/utils.py @@ -1,4 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + """Common utilities""" + import multiprocessing import multiprocessing.pool import queue @@ -7,7 +25,6 @@ import os import numpy as np - try: import psutil except ImportError: @@ -31,7 +48,6 @@ def get_func_name(func): name: str The name """ - return func.func_name if hasattr(func, 'func_name') else func.__name__ diff --git a/src/ansor/auto_schedule.cc b/src/ansor/auto_schedule.cc index 974e7e5d9f58..a0fa18874a69 100644 --- a/src/ansor/auto_schedule.cc +++ b/src/ansor/auto_schedule.cc @@ -1,10 +1,31 @@ -#include "auto_schedule.h" +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ -#include +/*! + * \file ansor/auto_schedule.cc + * \brief The user interface of the auto-scheduler + */ +#include "auto_schedule.h" +#include #include #include - #include "search_policy/meta_tile_rewrite_policy.h" namespace tvm { @@ -54,32 +75,32 @@ std::pair > AutoSchedule( } TVM_REGISTER_GLOBAL("ansor.TuneOption") - .set_body_typed([](int n_trials, int early_stopping, - int num_measure_per_iter, int verbose, Builder builder, - Runner runner, Array callbacks) { - return TuneOptionNode::make(n_trials, early_stopping, - num_measure_per_iter, verbose, builder, - runner, callbacks); - }); +.set_body_typed([](int n_trials, int early_stopping, + int num_measure_per_iter, int verbose, Builder builder, + Runner runner, Array callbacks) { + return TuneOptionNode::make(n_trials, early_stopping, + num_measure_per_iter, verbose, builder, + runner, callbacks); +}); TVM_REGISTER_GLOBAL("ansor.AutoScheduleBySearchTask") - .set_body_typed([](SearchTask task, SearchPolicy search_policy, - TuneOption tune_option) { - return AutoSchedule(task, search_policy, tune_option); - }); +.set_body_typed([](SearchTask task, SearchPolicy search_policy, + TuneOption tune_option) { + return AutoSchedule(task, search_policy, tune_option); +}); TVM_REGISTER_GLOBAL("ansor.AutoScheduleByWorkloadKey") - .set_body_typed([](std::string workload_key, Target target, - Target target_host, SearchPolicy search_policy, - HardwareParams hardware_params, TuneOption tune_option) { - te::Schedule sch; - Array return_tensors; - std::tie(sch, return_tensors) = - AutoSchedule(workload_key, target, target_host, search_policy, - hardware_params, tune_option); +.set_body_typed([](std::string workload_key, Target target, + Target target_host, SearchPolicy search_policy, + HardwareParams hardware_params, TuneOption tune_option) { + te::Schedule sch; + Array return_tensors; + std::tie(sch, return_tensors) = + AutoSchedule(workload_key, target, target_host, search_policy, + hardware_params, tune_option); - return Array{sch, return_tensors}; - }); + return Array{sch, return_tensors}; +}); } // namespace ansor } // namespace tvm diff --git a/src/ansor/auto_schedule.h b/src/ansor/auto_schedule.h index c354751390fe..f68e844ba776 100644 --- a/src/ansor/auto_schedule.h +++ b/src/ansor/auto_schedule.h @@ -1,12 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! - * Copyright (c) 2020 by Contributors - * \file ansor/search_task.h - * \brief Meta information for a search task + * \file ansor/auto_schedule.h + * \brief The user interface of the auto-scheduler */ #ifndef TVM_ANSOR_AUTO_SCHEDULE_H_ #define TVM_ANSOR_AUTO_SCHEDULE_H_ +#include +#include #include "measure.h" namespace tvm { @@ -44,7 +64,7 @@ class TuneOptionNode : public Object { static constexpr const char* _type_key = "ansor.TuneOption"; TVM_DECLARE_FINAL_OBJECT_INFO(TuneOptionNode, Object); }; -TVM_DEFINE_COW_NODE_REF(TuneOption, ObjectRef, TuneOptionNode); +TVM_DEFINE_COW_OBJECT_REF(TuneOption, ObjectRef, TuneOptionNode); /*! \brief Auto schedule for a compute declaration */ State AutoSchedule(SearchTask task, SearchPolicy search_policy, @@ -58,4 +78,4 @@ std::pair > AutoSchedule( } // namespace ansor } // namespace tvm -#endif // TVM_ANSOR_AUTO_SCHEDULE_H_ \ No newline at end of file +#endif // TVM_ANSOR_AUTO_SCHEDULE_H_ diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 7fad0ce5b28a..f3979ef0d259 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -1,6 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! 
- * Copyright (c) 2020 by Contributors + * \file ansor/compute_dag.cc + * \brief Compute declaration graph and its related analysis tools */ + #include "compute_dag.h" #include #include @@ -15,7 +36,7 @@ #include #include #include -#include "loop_state.h" +#include "transform_step.h" #include "utils.h" // #include "../relay/pass/kernel_layout_transform.h" @@ -385,6 +406,7 @@ void AccessAnalyzer::GetConsumers(const State& state, const te::Operation& op, collect(op); } +// Return whether two int arrays are elementwise-equal bool IntArrayEqual(const Array& arr1, const Array& arr2) { if (arr1.size() != arr2.size()) { return false; @@ -543,23 +565,6 @@ class FlopEstimator: public ExprFunctor { bool fail{false}; }; -void UpdateStageAxis(const te::Stage& stage, StageToAxesMap *stage_to_axes) { - if (auto pop = stage->op.as()) { - std::vector& axes = (*stage_to_axes)[stage]; - axes.clear(); - for (const auto& axis : pop->axis) { - axes.push_back(axis); - } - for (const auto& axis : pop->reduce_axis) { - axes.push_back(axis); - } - } else if (stage->op->IsInstance()) { - {} // do nothing - } else { - LOG(FATAL) << "Invalid op " << stage->op; - } -} - State ComputeDAG::GetInitState() const { return Downcast(operator->()->init_state); } @@ -588,13 +593,6 @@ ComputeDAG ComputeDAGNode::make_by_workload_key(const std::string& workload_key) return ComputeDAGNode::make(std::move(tens)); } -// Implemented in multi_stage_policy.cc -// Extract primitive iterators from a nested fused or splitted iterator's name -extern void ExtractOriginalIterators(const std::string& name, std::set* rets); - -// Implemented in loop_state.cc -extern std::string CleanName(const std::string& str); - std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); } @@ -680,8 +678,8 @@ std::string BaseName(const std::string& str) { // const Operation& op = stage->op; // if (op->IsInstance()) { // const Map& attrs = op->attrs; -// if (attrs.count(_layout_free_placeholders_key)) { -// const ObjectRef& attr_value = attrs[_layout_free_placeholders_key]; +// if (attrs.count(layout_free_placeholders_key)) { +// const ObjectRef& attr_value = attrs[layout_free_placeholders_key]; // Array placeholders = Downcast>(attr_value); // for (auto& placeholder : placeholders) { // const auto placeholder_op = placeholder->op; @@ -907,7 +905,8 @@ std::string BaseName(const std::string& str) { // auto index = old_tensor->value_index; // ptensors->data[i] = new_op.output(index); // } else if (layout_rewrite_level == kComputeRewrite) { -// TensorNode* old_tensor_node = const_cast(old_tensor.as()); +// TensorNode* old_tensor_node = +// const_cast(old_tensor.as()); // old_tensor_node->op = new_op; // } // } @@ -918,6 +917,24 @@ std::string BaseName(const std::string& str) { // } // end for stage // } + +void UpdateStageAxis(const te::Stage& stage, StageToAxesMap *stage_to_axes) { + if (auto pop = stage->op.as()) { + std::vector& axes = (*stage_to_axes)[stage]; + axes.clear(); + for (const auto& axis : pop->axis) { + axes.push_back(axis); + } + for (const auto& axis : pop->reduce_axis) { + axes.push_back(axis); + } + } else if (stage->op->IsInstance()) { + {} // do nothing + } else { + LOG(FATAL) << "Invalid op " << stage->op; + } +} + std::pair > ComputeDAG::ApplySteps( const std::vector& transform_steps, LayoutRewriteLevel layout_rewrite_level) const { @@ -1104,9 +1121,6 @@ std::pair > ComputeDAG::ReplaySteps( UpdateStageAxis(stage, stage_to_axes); } - // todo(lmzheng): should we maintain the attach_map and keep the validity of - // 
compute_at an splitted axis? - // Use complete rate for the study in the paper const char* complete_rate_str = getenv("ANSOR_PROGRAM_COMPLETE_RATE"); double complete_rate = -1.0; diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index 3b4c80c50ad8..60c1790a0cfb 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -1,5 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! - * Copyright (c) 2020 by Contributors * \file ansor/compute_dag.h * \brief Compute declaration graph and its related analysis tools */ @@ -22,12 +40,6 @@ namespace ansor { class ComputeDAG; class AccessAnalyzer; class StateNode; class State; class Step; -typedef std::unordered_map, ObjectHash, ObjectEqual> - StageToAxesMap; - -// Update StageToAxes Map during replay -void UpdateStageAxis(const tvm::te::Stage& stage, StageToAxesMap *stage_to_axes); - /*! \brief Read/Write access static analysis result */ class AccessAnalyzerNode : public Object { public: @@ -60,9 +72,11 @@ class AccessAnalyzer : public ObjectRef { // Get all producers of an op void GetProducers(const State& state, const te::Operation& op, std::unordered_set* producers) const; + // Get all consumers of an op. This func deals with inlined op correctly. void GetConsumers(const State& state, const te::Operation& op, std::unordered_set* consumers) const; + // Check whether two ops are elementwise matched // (e.g. conv2d and relu are elementwise matched) bool ElementWiseMatch(const te::Operation& op, @@ -84,17 +98,23 @@ class AccessAnalyzer : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(AccessAnalyzer, ObjectRef, AccessAnalyzerNode); }; +typedef std::unordered_map, ObjectHash, ObjectEqual> + StageToAxesMap; + +// Update StageToAxes Map during replay +void UpdateStageAxis(const tvm::te::Stage& stage, StageToAxesMap *stage_to_axes); + + /*! 
\brief Compute declaration graph */ class ComputeDAGNode : public Object { public: - Array tensors; // Input and output tensors - Array ops; // All related operations in topo order - double flop_ct; // Number of float operations + Array tensors; // Input and output tensors + Array ops; // All related operations in topo order + double flop_ct; // Number of float operations AccessAnalyzer access_analyzer; // Read/Write accesss static analyzer - ObjectRef init_state; // initial states + ObjectRef init_state; // The initial state void VisitAttrs(tvm::AttrVisitor* v) { - LOG(INFO) << "ComputeDAG"; v->Visit("tensors", &tensors); v->Visit("ops", &ops); v->Visit("flop_ct", &flop_ct); @@ -126,7 +146,7 @@ class ComputeDAG: public ObjectRef { // Rewrite the the layout of "layout free" placeholders according to transform steps void RewriteLayout(const std::vector& transform_steps, - LayoutRewriteLevel layout_rewrite_level = kNoRewrite) const {}; + LayoutRewriteLevel layout_rewrite_level = kNoRewrite) const {} // Print transform steps as equivalent python schedule API std::string PrintStepsAsPython(const std::vector& steps) const; @@ -134,19 +154,21 @@ class ComputeDAG: public ObjectRef { // Replay the transform steps and call ir_pass::InferBound to fill correct bound information State ReplayAndInferBound(const std::vector& transform_steps) const; - // Fill the correct bound information for a given state + // Fill the correct bound information for a given state by calling ir_pass::InferBound State InferBound(const State& state) const; // Fill the correct bound information for a list of given states. // Return the new states inplace void InferBound(std::vector* states) const; - // Replay the transform steps and get the new ops + // Replay the transform steps and get the new DAG void ReplayAndGetDAG(const std::vector& steps, ComputeDAG* task_dag) const; // Get the init state State GetInitState() const; + static constexpr const char* layout_free_placeholders_key = "layout_free_placeholders"; + TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode); @@ -155,7 +177,6 @@ class ComputeDAG: public ObjectRef { std::pair > ReplaySteps( const std::vector& transform_steps, std::vector* stages, StageToAxesMap* stage_to_axes) const; - static constexpr const char* _layout_free_placeholders_key = "layout_free_placeholders"; // Internal common parts for inferring bound void InferBoundCommon(StateNode* pstate) const; diff --git a/src/ansor/cost_model/cost_model.cc b/src/ansor/cost_model/cost_model.cc index 060d2b703287..8e0936071774 100644 --- a/src/ansor/cost_model/cost_model.cc +++ b/src/ansor/cost_model/cost_model.cc @@ -1,6 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! 
- * Copyright (c) 2020 by Contributors + * \file ansor/cost_model.h + * \brief Cost model that estimates the performance of programs */ + #include "cost_model.h" #include @@ -23,7 +44,7 @@ void RandomNumber(TVMArgs args, TVMRetValue* rv) { void* data = args[1]; float* fdata = reinterpret_cast(data); for (int i = 0; i < n; i++) { - fdata[i] = static_cast(rand_r(0)) / (static_cast(RAND_MAX)); + fdata[i] = static_cast(rand_r(nullptr)) / (static_cast(RAND_MAX)); } } @@ -130,7 +151,7 @@ void PythonBasedCostModelNode::PredictStages( CHECK_LE(idx, flatten_scores.size()); // Number of scored stages of this state. - int s_length = (int)flatten_scores[idx++]; + int s_length = static_cast(flatten_scores[idx++]); if (s_length > 0) { std::vector scores; diff --git a/src/ansor/cost_model/cost_model.h b/src/ansor/cost_model/cost_model.h index 36179573c617..9daf01197bbf 100644 --- a/src/ansor/cost_model/cost_model.h +++ b/src/ansor/cost_model/cost_model.h @@ -1,8 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! - * Copyright (c) 2020 by Contributors * \file ansor/cost_model.h - * \brief Base class of cost model - */ + * \brief Cost model that estimates the performance of programs +*/ #ifndef TVM_ANSOR_COST_MODEL_COST_MODEL_H_ #define TVM_ANSOR_COST_MODEL_COST_MODEL_H_ @@ -23,17 +41,24 @@ class CostModel; /*! \brief The base class for cost model */ class CostModelNode: public Object { public: + // Update the cost model according to new measurement pairs virtual void Update(const Array& inputs, const Array& results) = 0; + + // Predict the scores of states virtual void Predict(const SearchTask& task, const std::vector& states, std::vector* scores) = 0; + + // Predict the scores of all stages in states virtual void PredictStages(const SearchTask& task, const std::vector& states, std::vector* state_scores, - std::vector>* stage_scores) = 0; + std::vector>* stage_scores) { + LOG(FATAL) << "Not Implemented"; + } static constexpr const char *_type_key = "ansor.CostModel"; TVM_DECLARE_BASE_OBJECT_INFO(CostModelNode, Object); }; -TVM_DEFINE_MUTABLE_NODE_REF(CostModel, CostModelNode); +TVM_DEFINE_MUTABLE_OBJECT_REF(CostModel, CostModelNode); /*! \brief The cost model returns random value for all predictions */ class RandomModelNode: public CostModelNode { @@ -45,14 +70,12 @@ class RandomModelNode: public CostModelNode { void Update(const Array& inputs, const Array& results) final; void Predict(const SearchTask& task, const std::vector& states, std::vector* scores) final; - void PredictStages(const SearchTask& task, const std::vector& states, - std::vector* state_scores, - std::vector>* stage_scores) { ; } static constexpr const char *_type_key = "ansor.RandomModel"; TVM_DECLARE_FINAL_OBJECT_INFO(RandomModelNode, CostModelNode); }; +/*! 
\brief The cost model returns actual cost by measurement */ class MeasureModelNode : public CostModelNode { public: ProgramMeasurer measurer; @@ -62,9 +85,6 @@ class MeasureModelNode : public CostModelNode { void Update(const Array& inputs, const Array& results) final; void Predict(const SearchTask& task, const std::vector& states, std::vector* scores) final; - void PredictStages(const SearchTask& task, const std::vector& states, - std::vector* state_scores, - std::vector>* stage_scores) { ; } static constexpr const char* _type_key = "ansor.MeasureModel"; TVM_DECLARE_FINAL_OBJECT_INFO(MeasureModelNode, CostModelNode); diff --git a/src/ansor/expr_hasher.h b/src/ansor/expr_hasher.h deleted file mode 100644 index 1c743ed9a5c4..000000000000 --- a/src/ansor/expr_hasher.h +++ /dev/null @@ -1,97 +0,0 @@ -/*! - * Copyright (c) 2020 by Contributors - * \file auto_scheduler/expr_hasher.h - * \brief Hash function for a tvm::Expr - */ - -#ifndef TVM_ANSOR_EXPR_HASHER_H_ -#define TVM_ANSOR_EXPR_HASHER_H_ - -#include -#include -#include -#include - -namespace tvm { - -/*! \brief Assign a hash value for a tvm::Expr */ -class ExprHasher: public tir::ExprFunctor { - public: - size_t VisitExpr_(const tir::AddNode* op) final { - return VisitExpr(op->a) + VisitExpr(op->b); - } - - size_t VisitExpr_(const tir::SubNode* op) final { - return VisitExpr(op->a) - VisitExpr(op->b); - } - - size_t VisitExpr_(const tir::MulNode* op) final { - return VisitExpr(op->a) * VisitExpr(op->b); - } - - size_t VisitExpr_(const tir::DivNode* op) final { - size_t t = VisitExpr(op->b); - if (t != 0) { - return VisitExpr(op->a) / t; - } else { - return dmlc::HashCombine(VisitExpr(op->a), 0x5A); - } - } - - size_t VisitExpr_(const tir::FloorDivNode* op) final { - size_t t = VisitExpr(op->b); - if (t != 0) { - return VisitExpr(op->a) / t; - } else { - return dmlc::HashCombine(VisitExpr(op->a), 0x5B); - } - } - - size_t VisitExpr_(const tir::ModNode* op) final { - size_t t = VisitExpr(op->b); - if (t != 0) { - return VisitExpr(op->a) % t; - } else { - return dmlc::HashCombine(VisitExpr(op->a), 0x5C); - } - } - - size_t VisitExpr_(const tir::FloorModNode* op) final { - size_t t = VisitExpr(op->b); - if (t != 0) { - return VisitExpr(op->a) % t; - } else { - return dmlc::HashCombine(VisitExpr(op->a), 0x5D); - } - } - - size_t VisitExpr_(const tir::CallNode* op) final { - size_t ret = ObjectHash()(op->func); - for (size_t i = 0; i < op->args.size(); ++i) { - ret = dmlc::HashCombine(ret, VisitExpr(op->args[i])); - } - return ret; - } - - size_t VisitExpr_(const tir::VarNode* op) final { - return std::hash()(op); - } - - size_t VisitExpr_(const tir::FloatImmNode* op) final { - return std::hash()(op->value); - } - - size_t VisitExpr_(const tir::IntImmNode* op) final { - return std::hash()(op->value); - } - - size_t VisitExprDefault_(const Object* op) final { - LOG(WARNING) << "Encounter undefined node in ExprHasher: " - << Object::_type_key; - return std::hash()(op); - } -}; - -} // namespace tvm - -#endif // TVM_ANSOR_EXPR_HASHER_H_ diff --git a/src/ansor/feature.cc b/src/ansor/feature.cc index cb865bc3b5ae..31afe931361c 100644 --- a/src/ansor/feature.cc +++ b/src/ansor/feature.cc @@ -272,7 +272,8 @@ class BufferAccessExtractor : public StmtExprVisitor { this->VisitExpr(expr); } - void InsertAccess(const te::Tensor& ten, BufferAccessType acc_type, const Array& indices) { + void InsertAccess(const te::Tensor& ten, BufferAccessType acc_type, + const Array& indices) { BufferAccess& acc = buf_accesses[ten]; acc.acc_type = acc_type; 
acc.indices.push_back(std::vector(indices.begin(), indices.end())); diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index 32940da0773a..faaac94f3323 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -1,18 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! - * Copyright (c) 2020 by Contributors + * \file ansor/loop_state.h + * \brief An IR (intermediate representation) for loop structures. */ -#include "loop_state.h" +#include "loop_state.h" #include #include - +#include "transform_step.h" #include "utils.h" namespace tvm { namespace ansor { -TVM_REGISTER_OBJECT_TYPE(StageNode); +TVM_REGISTER_OBJECT_TYPE(StepNode); +TVM_REGISTER_NODE_TYPE(StageNode); TVM_REGISTER_NODE_TYPE(StateNode); +TVM_REGISTER_NODE_TYPE(IteratorNode); + +// Maker for other classes +Iterator IteratorNode::make(std::string name, Range range, + IteratorType iter_type, IteratorAnnotation annotation, + const std::vector* ori_iters) { + auto node = make_object(); + node->name = std::move(name); + node->range = std::move(range); + node->iter_type = iter_type; + node->annotation = annotation; + if (ori_iters != nullptr) { + node->ori_iters = *ori_iters; + } + return Iterator(node); +} + Stage StageNode::make(te::Operation op) { auto node = make_object(); @@ -43,7 +81,7 @@ Stage StageNode::make(te::Operation op) { Stage StageNode::make(te::Operation op, StageType op_type, const std::vector& iters, - ComputeAtType compute_at, int16_t auto_unroll_max_step, + ComputeAtType compute_at, int auto_unroll_max_step, int storage_offset) { auto node = make_object(); node->op = std::move(op); @@ -57,7 +95,7 @@ Stage StageNode::make(te::Operation op, StageType op_type, Stage StageNode::make(te::Operation op, StageType op_type, std::vector&& iters, ComputeAtType compute_at, - int16_t auto_unroll_max_step, int storage_offset) { + int auto_unroll_max_step, int storage_offset) { auto node = make_object(); node->op = std::move(op); node->op_type = op_type; @@ -216,15 +254,6 @@ void State::compute_inline(int stage_id) { return DoComputeInlineStep(step); } -void State::pack_for_vec(int stage_id, const Iterator& target_iter, - int vec_size) { - const Stage& stage = operator->()->stages[stage_id]; - PackForVecStep step = PackForVecStepNode::make( - stage_id, GetIndex(stage->iters, target_iter), vec_size); - CopyOnWrite()->transform_steps.push_back(step); - return DoPackForVecStep(step); -} - Iterator State::bind_thread(int stage_id, const Iterator& it, IteratorAnnotation thread_type) { const Stage& stage = operator->()->stages[stage_id]; @@ -560,10 +589,6 @@ void State::DoComputeInlineStep(const ComputeInlineStep& step) { pstate->attach_map.DeleteStage(step->stage_id); } -void State::DoPackForVecStep(const PackForVecStep& step) { - LOG(FATAL) << 
"Not implemented"; -} - // Common part for steps that add new stages // (e.g. CacheReadStep, CacheWriteStep, RfactorStep) void AddStageModificationSteps(size_t step_id, @@ -741,8 +766,6 @@ void State::DoStep(const Step& step, const ComputeDAG& dag) { DoComputeRootStep(GetRef(ps)); } else if (auto ps = step.as()) { DoComputeInlineStep(GetRef(ps)); - } else if (auto ps = step.as()) { - DoPackForVecStep(GetRef(ps)); } else if (auto ps = step.as()) { DoCacheReadStep(GetRef(ps), dag); } else if (auto ps = step.as()) { @@ -991,177 +1014,175 @@ AttachMap AttachMap::ApplyStageIdOfffset(int start_id, int offset) const { } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - PrintState(&p->stream, node, true); - }); - -TVM_REGISTER_GLOBAL("ansor.StageGetIterator") - .set_body_typed([](const Stage& stage, int index) { - return stage->iters[index]; - }); - -TVM_REGISTER_GLOBAL("ansor.StageGetIterators") - .set_body_typed([](const Stage& stage) { - return Array(stage->iters); - }); - -TVM_REGISTER_GLOBAL("ansor.StateGetStage") - .set_body_typed([](const State& state, int index) { - return state->stages[index]; - }); - -TVM_REGISTER_GLOBAL("ansor.StateGetTransformStepsSize") - .set_body_typed([](const State& state) { - return static_cast(state->transform_steps.size()); - }); +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + PrintState(&p->stream, node, true); +}); + + +TVM_REGISTER_GLOBAL("ansor.StageGetIterator").set_body_typed([](const Stage& stage, int index) { + return stage->iters[index]; +}); + +TVM_REGISTER_GLOBAL("ansor.StageGetIterators").set_body_typed([](const Stage& stage) { + return Array(stage->iters); +}); + +TVM_REGISTER_GLOBAL("ansor.StateGetStages").set_body_typed([](const State& state) { + return Array(state->stages); +}); + +TVM_REGISTER_GLOBAL("ansor.StateGetStage").set_body_typed([](const State& state, int index) { + return state->stages[index]; +}); + +TVM_REGISTER_GLOBAL("ansor.StateGetTransformStepsSize").set_body_typed([](const State& state) { + return static_cast(state->transform_steps.size()); +}); TVM_REGISTER_GLOBAL("ansor.StateReorder") - .set_body_typed([](State state, int stage_id, - const Array& order) { - std::vector ord; - for (const auto& i : order) { - ord.push_back(i); - } - state.reorder(stage_id, ord); - return state; - }); +.set_body_typed([](State state, int stage_id, const Array& order) { + std::vector ord; + for (const auto& i : order) { + ord.push_back(i); + } + state.reorder(stage_id, ord); + return state; +}); TVM_REGISTER_GLOBAL("ansor.StateSplit") - .set_body_typed([](State state, int stage_id, const Iterator& it, - const Array& lengths, bool inner_to_outer) { - std::vector len; - for (const auto& i : lengths) { - len.push_back(i); - } - const auto& res = state.split(stage_id, it, len, inner_to_outer); - return Array{state, Array(res)}; - }); +.set_body_typed([](State state, int stage_id, const Iterator& it, + const Array& lengths, bool inner_to_outer) { + std::vector len; + for (const auto& i : lengths) { + len.push_back(i); + } + const auto& res = state.split(stage_id, it, len, inner_to_outer); + return Array{state, Array(res)}; +}); TVM_REGISTER_GLOBAL("ansor.StateFollowSplit") - .set_body_typed([](State state, int stage_id, const Iterator& it, - int src_step_id, int n_split) { - const auto& res = state.follow_split(stage_id, it, src_step_id, n_split); - return Array{state, Array(res)}; - }); 
+.set_body_typed([](State state, int stage_id, const Iterator& it, + int src_step_id, int n_split) { + const auto& res = state.follow_split(stage_id, it, src_step_id, n_split); + return Array{state, Array(res)}; +}); TVM_REGISTER_GLOBAL("ansor.StateFollowFusedSplit") - .set_body_typed([](State state, int stage_id, const Iterator& it, - const Array& src_step_ids, int level, - bool factor_or_nparts) { - std::vector array_src_step_ids; - for (const auto& i : src_step_ids) { - array_src_step_ids.push_back(i->value); - } - const auto& res = state.follow_fused_split( - stage_id, it, array_src_step_ids, level, factor_or_nparts); - return Array{state, Array(res)}; - }); +.set_body_typed([](State state, int stage_id, const Iterator& it, + const Array& src_step_ids, int level, + bool factor_or_nparts) { + std::vector array_src_step_ids; + for (const auto& i : src_step_ids) { + array_src_step_ids.push_back(i->value); + } + const auto& res = state.follow_fused_split( + stage_id, it, array_src_step_ids, level, factor_or_nparts); + return Array{state, Array(res)}; +}); TVM_REGISTER_GLOBAL("ansor.StateFuse") - .set_body_typed([](State state, int stage_id, - const Array& iters) { - std::vector its; - for (const auto& i : iters) { - its.push_back(i); - } - const auto& res = state.fuse(stage_id, its); - return Array{state, res}; - }); +.set_body_typed([](State state, int stage_id, + const Array& iters) { + std::vector its; + for (const auto& i : iters) { + its.push_back(i); + } + const auto& res = state.fuse(stage_id, its); + return Array{state, res}; +}); TVM_REGISTER_GLOBAL("ansor.StateVectorize") - .set_body_typed([](State state, int stage_id, const Iterator& it) { - const auto& res = state.vectorize(stage_id, it); - return Array{state, res}; - }); +.set_body_typed([](State state, int stage_id, const Iterator& it) { + const auto& res = state.vectorize(stage_id, it); + return Array{state, res}; +}); TVM_REGISTER_GLOBAL("ansor.StateParallel") - .set_body_typed([](State state, int stage_id, const Iterator& it) { - const auto& res = state.parallel(stage_id, it); - return Array{state, res}; - }); +.set_body_typed([](State state, int stage_id, const Iterator& it) { + const auto& res = state.parallel(stage_id, it); + return Array{state, res}; +}); TVM_REGISTER_GLOBAL("ansor.StateUnroll") - .set_body_typed([](State state, int stage_id, const Iterator& it, - int max_unroll) { - const auto& res = state.unroll(stage_id, it, max_unroll); - return Array{state, res}; - }); +.set_body_typed([](State state, int stage_id, const Iterator& it, + int max_unroll) { + const auto& res = state.unroll(stage_id, it, max_unroll); + return Array{state, res}; +}); TVM_REGISTER_GLOBAL("ansor.StateBindThread") - .set_body_typed([](State state, int stage_id, const Iterator& it, - int thread_type) { - const auto& res = - state.bind_thread(stage_id, it, IteratorAnnotation(thread_type)); - return Array{state, res}; - }); +.set_body_typed([](State state, int stage_id, const Iterator& it, + int thread_type) { + const auto& res = + state.bind_thread(stage_id, it, IteratorAnnotation(thread_type)); + return Array{state, res}; +}); TVM_REGISTER_GLOBAL("ansor.StateComputeAt") - .set_body_typed([](State state, int stage_id, int target_stage_id, - const Iterator& target_iter) { - state.compute_at(stage_id, target_stage_id, target_iter); - return state; - }); +.set_body_typed([](State state, int stage_id, int target_stage_id, + const Iterator& target_iter) { + state.compute_at(stage_id, target_stage_id, target_iter); + return state; +}); 
TVM_REGISTER_GLOBAL("ansor.StateComputeRoot") - .set_body_typed([](State state, int stage_id) { - state.compute_root(stage_id); - return state; - }); +.set_body_typed([](State state, int stage_id) { + state.compute_root(stage_id); + return state; +}); TVM_REGISTER_GLOBAL("ansor.StateComputeInline") - .set_body_typed([](State state, int stage_id) { - state.compute_inline(stage_id); - return state; - }); - -TVM_REGISTER_GLOBAL("ansor.StatePackForVec") - .set_body_typed([](State state, int stage_id, const Iterator& target_iter, - int vec_size) { - state.pack_for_vec(stage_id, target_iter, vec_size); - return state; - }); +.set_body_typed([](State state, int stage_id) { + state.compute_inline(stage_id); + return state; +}); TVM_REGISTER_GLOBAL("ansor.StateCacheRead") - .set_body_typed([](State state, int stage_id, const std::string& scope_name, - const Array& reader_stage_ids, - const ComputeDAG& task_dag) { - std::vector array_reader_stage_ids; - for (const auto& i : reader_stage_ids) { - array_reader_stage_ids.push_back(i->value); - } - int res = state.cache_read(stage_id, scope_name, array_reader_stage_ids, - task_dag); - return Array{state, IntImm(DataType::Int(32), res)}; - }); +.set_body_typed([](State state, int stage_id, const std::string& scope_name, + const Array& reader_stage_ids, + const ComputeDAG& task_dag) { + std::vector array_reader_stage_ids; + for (const auto& i : reader_stage_ids) { + array_reader_stage_ids.push_back(i->value); + } + int res = state.cache_read(stage_id, scope_name, array_reader_stage_ids, + task_dag); + return Array{state, IntImm(DataType::Int(32), res)}; +}); TVM_REGISTER_GLOBAL("ansor.StateCacheWrite") - .set_body_typed([](State state, int stage_id, const std::string& scope_name, - const ComputeDAG& task_dag) { - int res = state.cache_write(stage_id, scope_name, task_dag); - return Array{state, IntImm(DataType::Int(32), res)}; - }); +.set_body_typed([](State state, int stage_id, const std::string& scope_name, + const ComputeDAG& task_dag) { + int res = state.cache_write(stage_id, scope_name, task_dag); + return Array{state, IntImm(DataType::Int(32), res)}; +}); TVM_REGISTER_GLOBAL("ansor.StatePragma") - .set_body_typed([](State state, int stage_id, const Iterator& it, - const std::string& pragma_type) { - state.pragma(stage_id, it, pragma_type); - return state; - }); +.set_body_typed([](State state, int stage_id, const Iterator& it, + const std::string& pragma_type) { + state.pragma(stage_id, it, pragma_type); + return state; +}); TVM_REGISTER_GLOBAL("ansor.StateRfactor") - .set_body_typed([](State state, int stage_id, const Iterator& it, - int factor_iter_id, const ComputeDAG& task_dag) { - int res = state.rfactor(stage_id, it, factor_iter_id, task_dag); - return Array{state, IntImm(DataType::Int(32), res)}; - }); +.set_body_typed([](State state, int stage_id, const Iterator& it, + int factor_iter_id, const ComputeDAG& task_dag) { + int res = state.rfactor(stage_id, it, factor_iter_id, task_dag); + return Array{state, IntImm(DataType::Int(32), res)}; +}); TVM_REGISTER_GLOBAL("ansor.StateStorageAlign") - .set_body_typed([](State state, int stage_id, const Iterator& it, - int factor, int offset) { - state.storage_align(stage_id, it, factor, offset); - return state; - }); +.set_body_typed([](State state, int stage_id, const Iterator& it, + int factor, int offset) { + state.storage_align(stage_id, it, factor, offset); + return state; +}); + +TVM_REGISTER_GLOBAL("ansor.StateEqual") +.set_body_typed([](State state1, State state2) { + return 
std::equal_to()(state1, state2); +}); } // namespace ansor } // namespace tvm diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h index dd56e267c0a0..90ba48cd92ac 100644 --- a/src/ansor/loop_state.h +++ b/src/ansor/loop_state.h @@ -1,16 +1,36 @@ -/*! - * Copyright (c) 2020 by Contributors - * \file ansor/interfaces.h - * \brief Data structures for loop transformations +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file ansor/loop_state.h + * \brief The definition of the "state" in search. A state consists of a current loop structure + * and the transform history to reach its current loop structure. + * To enable flexible manipulation of the loop structure, we implemented a lightweight + * loop structure IR (Intermediate Representation) specifically for search. + * * Basically this is a simplified TVM IR with schedule primitives. * We don't use the existing TVM IR because * 1. We want fast incremental change to the loop structures * 2. We want serializable history for replay and backtracking - * 3. We want simplified IR for easy and clean feature extraction - * 4. We may create some Macro schedule primitives - - * After search is done, we will lower this IR to TVM IR and TVM schedule primitives. + * 3. We may create some Macro schedule primitives + * + * After search is done, we will lower this IR to TVM IR with TVM schedule primitives. * Because we share a lot of common objects during search, the transformation is * implemented in copy on write style. All objects are immutable, which is * similar to TVM IR. @@ -24,24 +44,77 @@ #include #include #include -#include "transform_step.h" +#include "compute_dag.h" namespace tvm { namespace ansor { using namespace tvm::tir; +/*! \brief The type of a stage */ enum StageType { kPlaceholder, kCompute }; +/*! \brief The type of compute location */ enum ComputeAtType { kRoot, // compute at root kInlined, // inlined kIter, // compute at some iterator }; +/*! \brief The type of an iterator */ +enum IteratorType { + kSpace, // spatial iterator + kReduce, // reduction iterator + kMixed, // fused spatial and reduction iterator + kSpecial // special iterator (e.g. virtual root iterator) +}; + +/*! \brief The type of an iterator's annotation */ +enum IteratorAnnotation { + kNone, kUnroll, kVectorize, kParallel, + kVThread, kBlockX, kThreadX, kBlockY, kThreadY +}; + +class Iterator; + +/*!
+ * \brief A for loop iterator + * Similar to tvm::IterVar in `include/tvm/tir/expr.h` + */ +class IteratorNode : public Object { + public: + std::string name; + Range range; + IteratorType iter_type; + IteratorAnnotation annotation; + std::vector ori_iters; // The original iterators before fusion + + static Iterator make(std::string name, Range range, + IteratorType iter_type, IteratorAnnotation annotation, + const std::vector* ori_iters = nullptr); + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("range", &range); + } + + static constexpr const char *_type_key = "ansor.Iterator"; + TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object); +}; +TVM_DEFINE_COW_OBJECT_REF(Iterator, ObjectRef, IteratorNode); + +// Forward declarations class Stage; class State; +class AttachMap; + +class ReorderStep; class SplitStep; class FollowSplitStep; +class FollowFusedSplitStep; +class FuseStep; class AnnotationStep; +class ComputeAtStep; class ComputeRootStep; class ComputeInlineStep; +class CacheReadStep; class CacheWriteStep; +class PragmaStep; class RfactorStep; class StorageAlignStep; /*! * \brief A stage in the compute declaration @@ -53,25 +126,32 @@ class StageNode : public Object { StageType op_type; std::vector iters; ComputeAtType compute_at; - int16_t auto_unroll_max_step; + int auto_unroll_max_step; int storage_offset; + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("op", &op); + } + static Stage make(te::Operation op); static Stage make(te::Operation op, StageType op_type, const std::vector& iters, - ComputeAtType compute_at, int16_t auto_unroll_max_step, + ComputeAtType compute_at, int auto_unroll_max_step, int storage_offset); static Stage make(te::Operation op, StageType op_type, std::vector&& iters, - ComputeAtType compute_at, int16_t auto_unroll_max_step, + ComputeAtType compute_at, int auto_unroll_max_step, int storage_offset); static constexpr const char *_type_key = "ansor.Stage"; TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object); }; -TVM_DEFINE_COW_NODE_REF(Stage, ObjectRef, StageNode); +TVM_DEFINE_COW_OBJECT_REF(Stage, ObjectRef, StageNode); -/*! \brief stores the compute_at relation between stages */ +/*! \brief stores the compute_at relation between stages + * This stores a bi-directional mapping between stages and iterators: + * 1. Stage to its attached iterator 2. Iterator to the stage attached to it + */ class AttachMapNode: public Object { public: using StageKey = int; @@ -110,6 +190,22 @@ class AttachMap : public ObjectRef { static void DeleteStageEntry(AttachMapNode* pnode, int stage_id); }; +/*! \brief The base class for a transformation step */ +class StepNode: public Object { + public: + int stage_id; + + // Print step as equivalent python schedule API + virtual std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const = 0; + + static constexpr const char* _type_key = "ansor.Step"; + TVM_DECLARE_BASE_OBJECT_INFO(StepNode, Object); +}; +TVM_DEFINE_MUTABLE_OBJECT_REF(Step, StepNode); + /*! \brief The loop state and corresponding history steps to reach this state */ class StateNode: public Object { public: @@ -125,6 +221,7 @@ class StateNode: public Object { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("complete", &complete); v->Visit("aux_info", &aux_info); + v->Visit("task_dag", &task_dag); } static State make_empty_state(); @@ -137,7 +234,8 @@ class StateNode: public Object { TVM_DECLARE_FINAL_OBJECT_INFO(StateNode, Object); }; -/*!
\brief The loop state and corresponding history steps to reach this state */ +/*! \brief A state in the search process. + * It consists of the current loop structure and the history steps to reach this state. */ class State : public ObjectRef { public: // Schedule primitives @@ -154,14 +252,12 @@ class State : public ObjectRef { Iterator vectorize(int stage_id, const Iterator& it); Iterator parallel(int stage_id, const Iterator& it); Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1); - // Valide thread_type: kVThread, kBlockX, kThreadX, kThreadY Iterator bind_thread(int stage_id, const Iterator& it, IteratorAnnotation thread_type); void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter); void compute_root(int stage_id); void compute_inline(int stage_id); - void pack_for_vec(int stage_id, const Iterator& target_iter, int vec_size); int cache_read(int stage_id, const std::string& scope_name, const std::vector& reader_stage_ids, const ComputeDAG& task_dag); @@ -172,8 +268,10 @@ class State : public ObjectRef { const ComputeDAG& task_dag); void storage_align(int stage_id, const Iterator& it, int factor, int offset); - // We separate these functions out, - // so you can call them for replay easily given history steps + /* Do transform steps + * Note: The following functions only change loop state but do not change transform_history. + * We separate these functions out, + * so you can call them for replay easily given history steps */ void DoReorderStep(const ReorderStep& step); std::vector DoSplitStep(const SplitStep& step); std::vector DoFollowSplitStep(const FollowSplitStep& step); @@ -183,38 +281,44 @@ class State : public ObjectRef { void DoComputeAtStep(const ComputeAtStep& step); void DoComputeRootStep(const ComputeRootStep& step); void DoComputeInlineStep(const ComputeInlineStep& step); - void DoPackForVecStep(const PackForVecStep& step); int DoCacheReadStep(const CacheReadStep& step, const ComputeDAG& dag); int DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag); void DoPragmaStep(const PragmaStep& step); int DoRfactorStep(const RfactorStep& step, const ComputeDAG& dag); void DoStorageAlignStep(const StorageAlignStep& step); - /* Do transform steps - * Note: The following function only change loop state. - * They do not change transform_history. - */ + // General do step functions with a runtime dynamic dispatcher void DoStep(const Step& step, const ComputeDAG& dag); void DoSteps(const std::vector& step, const ComputeDAG& dag); - // Print to str + // Print the state to a string std::string ToStr(bool delete_trivial_loop = true) const; TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode); private: - // common function for DoSplitStep and DoFollowSplitStep + // Common function for DoSplitStep and DoFollowSplitStep std::vector DoSplitStepCommon(int stage_id, int iter_id, const std::vector& lengths, bool inner_to_outer); }; +/*! 
\brief Clean the name of an iterator to make it valid in python code */ +inline std::string CleanName(const std::string& str) { + std::string ret = str; + StrReplace(&ret, ".", "_"); + StrReplace(&ret, "@", "_"); + StrReplace(&ret, "outer", "o"); + StrReplace(&ret, "inner", "i"); + return ret; +} + } // namespace ansor } // namespace tvm -// Hash and equal function for State, Stage, Iterator and Step +// Hash and equal function for State namespace std { template <> diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index b2cff24973bc..43be530f2a35 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -2,15 +2,12 @@ * Copyright (c) 2020 by Contributors */ #include "measure.h" -// #include #include #include - #include #include #include #include -// #include "search_policy/search_policy.h" namespace tvm { namespace ansor { @@ -38,7 +35,7 @@ const char* ErrorNoToStr[] = { "UnknownError", }; -// Maker +// Measure input and result MeasureInput MeasureInputNode::make(SearchTask task, State state) { auto node = make_object(); node->task = std::move(task); @@ -87,6 +84,7 @@ MeasureResult MeasureResultNode::copy() const { return MeasureResult(node); } +// LocalBuilder Builder LocalBuilderNode::make(int timeout, int n_parallel, const std::string& build_func) { auto node = make_object(); @@ -96,7 +94,6 @@ Builder LocalBuilderNode::make(int timeout, int n_parallel, return Builder(node); } -// LocalBuilder and LocalRunner Array LocalBuilderNode::Build(const Array& inputs, int verbose) { if (const auto* f = runtime::Registry::Get("ansor.local_builder.build")) { @@ -109,6 +106,7 @@ Array LocalBuilderNode::Build(const Array& inputs, return Array(); } +// RPC Runner Runner RPCRunnerNode::make(const std::string& key, const std::string& host, int port, int priority, int timeout, int n_parallel, int number, int repeat, int min_repeat_ms, @@ -141,6 +139,7 @@ Array RPCRunnerNode::Run(const Array& inputs, return Array(); } +// Local Runner Runner LocalRunnerNode::make(int timeout, int number, int repeat, int min_repeat_ms, double cooldown_interval) { ObjectPtr node = make_object(); @@ -166,6 +165,7 @@ Array LocalRunnerNode::Run( return Array(); } +// Program Measurer ProgramMeasurer ProgramMeasurerNode::make(Builder builder, Runner runner, Array callbacks, int verbose, @@ -284,89 +284,89 @@ void ProgramMeasurerNode::SilentMeasure(const SearchTask& task, // Printing functions TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - p->stream << "MeasureInput()"; - }); +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + p->stream << "MeasureInput()"; +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - if (node->error_no == kNoError) { - p->stream << "MeasureResult(cost:["; - auto old_config = p->stream.precision(4); - for (size_t i = 0; i < node->costs.size(); ++i) { - auto pf = node->costs[i].as(); - CHECK(pf != nullptr); - p->stream << pf->value; - if (i != node->costs.size() - 1) { - p->stream << ","; - } - } - p->stream.precision(old_config); - p->stream << "], "; - p->stream << "error_no:" << 0 << ", " - << "all_cost:" << node->all_cost << ", " - << "Tstamp:" << node->timestamp << ")"; - } else { - p->stream << "MeasureResult(" - << "error_type:" << ErrorNoToStr[node->error_no] << ", " - << "error_msg:" << node->error_msg << ", " - << "all_cost:" << node->all_cost << ", " - << "Tstamp:" << node->timestamp << ")"; +.set_dispatch([](const 
ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + if (node->error_no == kNoError) { + p->stream << "MeasureResult(cost:["; + auto old_config = p->stream.precision(4); + for (size_t i = 0; i < node->costs.size(); ++i) { + auto pf = node->costs[i].as(); + CHECK(pf != nullptr); + p->stream << pf->value; + if (i != node->costs.size() - 1) { + p->stream << ","; } - }); + } + p->stream.precision(old_config); + p->stream << "], "; + p->stream << "error_no:" << 0 << ", " + << "all_cost:" << node->all_cost << ", " + << "Tstamp:" << node->timestamp << ")"; + } else { + p->stream << "MeasureResult(" + << "error_type:" << ErrorNoToStr[node->error_no] << ", " + << "error_msg:" << node->error_msg << ", " + << "all_cost:" << node->all_cost << ", " + << "Tstamp:" << node->timestamp << ")"; + } +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "BuildResult(" << node->filename << ", " << node->error_no - << ", " << node->time_cost << ")"; - }); +.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "BuildResult(" << node->filename << ", " << node->error_no + << ", " << node->time_cost << ")"; +}); TVM_REGISTER_GLOBAL("ansor.MeasureInput") - .set_body_typed([](SearchTask task, State state) { - return MeasureInputNode::make(task, state); - }); +.set_body_typed([](SearchTask task, State state) { + return MeasureInputNode::make(task, state); +}); TVM_REGISTER_GLOBAL("ansor.BuildResult") - .set_body_typed([](std::string filename, Array args, - int error_no, std::string error_msg, double time_cost) { - return BuildResultNode::make(filename, args, error_no, error_msg, - time_cost); - }); +.set_body_typed([](std::string filename, Array args, + int error_no, std::string error_msg, double time_cost) { + return BuildResultNode::make(filename, args, error_no, error_msg, + time_cost); +}); TVM_REGISTER_GLOBAL("ansor.MeasureResult") - .set_body_typed([](Array costs, int error_no, - std::string error_msg, double all_cost, - double timestamp) { - return MeasureResultNode::make(costs, error_no, error_msg, all_cost, - timestamp); - }); +.set_body_typed([](Array costs, int error_no, + std::string error_msg, double all_cost, + double timestamp) { + return MeasureResultNode::make(costs, error_no, error_msg, all_cost, + timestamp); +}); TVM_REGISTER_GLOBAL("ansor.BuilderBuild") - .set_body_typed([](const Builder& builder, - const Array& inputs, int verbose) { - return builder->Build(inputs, verbose); - }); +.set_body_typed([](const Builder& builder, + const Array& inputs, int verbose) { + return builder->Build(inputs, verbose); +}); TVM_REGISTER_GLOBAL("ansor.RunnerRun") - .set_body_typed([](const Runner& runner, const Array& inputs, - const Array& build_results, int verbose) { - return runner->Run(inputs, build_results, verbose); - }); +.set_body_typed([](const Runner& runner, const Array& inputs, + const Array& build_results, int verbose) { + return runner->Run(inputs, build_results, verbose); +}); TVM_REGISTER_GLOBAL("ansor.LocalBuilder") - .set_body_typed([](int timeout, int n_parallel, - const std::string& build_func) { - return LocalBuilderNode::make(timeout, n_parallel, build_func); - }); +.set_body_typed([](int timeout, int n_parallel, + const std::string& build_func) { + return LocalBuilderNode::make(timeout, n_parallel, build_func); +}); TVM_REGISTER_GLOBAL("ansor.LocalRunner") - .set_body_typed([](int timeout, int number, int 
repeat, int min_repeat_ms, - double cooldown_interval) { - return LocalRunnerNode::make(timeout, number, repeat, min_repeat_ms, - cooldown_interval); - }); +.set_body_typed([](int timeout, int number, int repeat, int min_repeat_ms, + double cooldown_interval) { + return LocalRunnerNode::make(timeout, number, repeat, min_repeat_ms, + cooldown_interval); +}); } // namespace ansor } // namespace tvm diff --git a/src/ansor/measure.h b/src/ansor/measure.h index 4ea1562315ff..780a30514d46 100644 --- a/src/ansor/measure.h +++ b/src/ansor/measure.h @@ -7,7 +7,6 @@ #ifndef TVM_ANSOR_MEASURE_H_ #define TVM_ANSOR_MEASURE_H_ -// #include #include #include #include @@ -22,8 +21,7 @@ class SearchPolicy; class MeasureInput; class BuildResult; class MeasureResult; class Builder; class Runner; class MeasureCallback; class ProgramMeasurer; -extern const char *ErrorNoToStr[]; - +/* \brief The error code of one measurement */ enum MeasureErrorNO { kNoError = 0, // No error kInstantiationError = 1, // Errors happen when apply transform steps from init state @@ -35,14 +33,15 @@ enum MeasureErrorNO { kRunTimeoutError = 7, // Timeout during run kUnknonwError = 8, // Unknown error }; +extern const char *ErrorNoToStr[]; // Inputs and results of one measurement -/* \brief Store the input of a meansurement */ +/* \brief Store the input of a measurement */ class MeasureInputNode: public Object { public: - SearchTask task; - State state; + SearchTask task; // The search task + State state; // The program state to be measured void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("task", &task); @@ -55,16 +54,16 @@ class MeasureInputNode: public Object { static constexpr const char* _type_key = "ansor.MeasureInput"; TVM_DECLARE_FINAL_OBJECT_INFO(MeasureInputNode, Object); }; -TVM_DEFINE_NODE_REF(MeasureInput, MeasureInputNode); +TVM_DEFINE_OBJECT_REF(MeasureInput, MeasureInputNode); /* \brief Store the input of a build */ class BuildResultNode: public Object { public: - std::string filename; - Array args; - int error_no; - std::string error_msg; - double time_cost; + std::string filename; // The filename of built binary file + Array args; // The arguments + int error_no; // The error code (see MeasureErrorNO). 0 means no error. + std::string error_msg; // The error message if there is any error + double time_cost; // The time cost of build void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("filename", &filename); @@ -80,16 +79,16 @@ class BuildResultNode: public Object { static constexpr const char* _type_key = "ansor.BuildResult"; TVM_DECLARE_FINAL_OBJECT_INFO(BuildResultNode, Object); }; -TVM_DEFINE_NODE_REF(BuildResult, BuildResultNode); +TVM_DEFINE_OBJECT_REF(BuildResult, BuildResultNode); /* \brief Store the results of a measurement */ class MeasureResultNode: public Object { public: - Array costs; - int error_no; - std::string error_msg; - double all_cost; - double timestamp; + Array costs; // The time costs of execution + int error_no; // The error code (see MeasureErrorNO). 0 means no error. 
+ std::string error_msg; // The error message if there is any error + double all_cost; // The time cost of build and run + double timestamp; // The timestamp of this measurement void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("costs", &costs); @@ -107,19 +106,21 @@ static constexpr const char* _type_key = "ansor.MeasureResult"; TVM_DECLARE_FINAL_OBJECT_INFO(MeasureResultNode, Object); }; -TVM_DEFINE_NODE_REF(MeasureResult, MeasureResultNode); +TVM_DEFINE_OBJECT_REF(MeasureResult, MeasureResultNode); -// Measure callback +/* \brief Base class of measurement callbacks */ class MeasureCallbackNode: public Object { public: + /*! \brief Callback function that will be called on measurement input/result pairs + * after measurement */ virtual void callback(const SearchPolicy& policy, const Array& inputs, const Array& results) = 0; static constexpr const char *_type_key = "ansor.MeasureCallback"; TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object); }; -TVM_DEFINE_MUTABLE_NODE_REF(MeasureCallback, MeasureCallbackNode); +TVM_DEFINE_MUTABLE_OBJECT_REF(MeasureCallback, MeasureCallbackNode); // Base class for builder and runner /* \brief Builder that builds the programs */ class BuilderNode: public Object { public: - int n_parallel; - int timeout; + int n_parallel; // The number of tasks to run in parallel + int timeout; // Timeout of a build + /*! \brief Build programs and return results */ virtual Array Build(const Array& inputs, int verbose) = 0; static constexpr const char* _type_key = "ansor.Builder"; TVM_DECLARE_BASE_OBJECT_INFO(BuilderNode, Object); }; -TVM_DEFINE_MUTABLE_NODE_REF(Builder, BuilderNode); +TVM_DEFINE_MUTABLE_OBJECT_REF(Builder, BuilderNode); /* \brief Runner that runs the built programs and measures the time cost */ class RunnerNode: public Object { public: - int timeout; + int timeout; // Timeout of a run + /*! \brief Run measurement and return results */ virtual Array Run(const Array& inputs, const Array& build_results, int verbose) = 0; @@ -149,14 +152,14 @@ static constexpr const char* _type_key = "ansor.Runner"; TVM_DECLARE_BASE_OBJECT_INFO(RunnerNode, Object); }; -TVM_DEFINE_MUTABLE_NODE_REF(Runner, RunnerNode); +TVM_DEFINE_MUTABLE_OBJECT_REF(Runner, RunnerNode); // Implementation of various builders and runners /* \brief LocalBuilder uses local CPU cores to build programs in parallel */ class LocalBuilderNode: public BuilderNode { public: - std::string build_func; + std::string build_func; // Build function static Builder make(int timeout, int n_parallel, const std::string& build_func); @@ -166,6 +169,7 @@ TVM_DECLARE_FINAL_OBJECT_INFO(LocalBuilderNode, BuilderNode); }; +/* \brief RPCRunner that uses RPC calls to measure the time cost of programs on remote devices */ class RPCRunnerNode : public RunnerNode { public: std::string key; @@ -182,6 +186,7 @@ int priority, int timeout, int n_parallel, int number, int repeat, int min_repeat_ms, double cooldown_interval); + /*!
\brief Run measurement and return results */ Array Run(const Array& inputs, const Array& build_results, int verbose) final; @@ -190,7 +195,7 @@ TVM_DECLARE_FINAL_OBJECT_INFO(RPCRunnerNode, RunnerNode); }; -/* \brief LocalRunner use local CPU/GPU to runs programs in serial and measure the time cost */ +/* \brief LocalRunner that uses local CPU/GPU to measure the time cost of programs */ class LocalRunnerNode: public RunnerNode { public: int number; @@ -201,6 +206,7 @@ class LocalRunnerNode: public RunnerNode { static Runner make(int timeout, int number, int repeat, int min_repeat_ms, double cooldown_interval); + /*! \brief Run measurement and return results */ Array Run(const Array& inputs, const Array& build_results, int verbose) final; @@ -211,9 +217,8 @@ /*! - * \brief Measurer measures the time costs of tvm programs - * This class combines Builder and Runner, and provides a simpler API - */ + * \brief Measurer that measures the time costs of tvm programs + * This class combines Builder and Runner, and provides a simpler API */ class ProgramMeasurerNode: public Object { public: static const int DEFAULT_MAX_CONTINOUS_ERROR = 150; @@ -253,7 +258,7 @@ class ProgramMeasurerNode: public Object { static constexpr const char* _type_key = "ansor.ProgramMeasurer"; TVM_DECLARE_FINAL_OBJECT_INFO(ProgramMeasurerNode, Object); }; -TVM_DEFINE_MUTABLE_NODE_REF(ProgramMeasurer, ProgramMeasurerNode); +TVM_DEFINE_MUTABLE_OBJECT_REF(ProgramMeasurer, ProgramMeasurerNode); } // namespace ansor diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.cc b/src/ansor/search_policy/meta_tile_rewrite_policy.cc index b4501804607a..c22d890a8b51 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.cc +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.cc @@ -1,5 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*!
- * Copyright (c) 2020 by Contributors + * \file ansor/search_policy/meta_tile_rewrite_policy.cc + * \brief The search policy that searches by program sampling and evolutionary search */ #include "meta_tile_rewrite_policy.h" @@ -776,7 +796,7 @@ int InitPopulationThreadBind(const MetaTileRewritePolicyNode* policy, // Fuse the outermost space tile as blockIdx for (size_t i = 0; i < pop->axis.size(); i++) { const auto& it = (*state)->stages[stage_id]->iters[i]; - if (!StringEndWith(it->name, ".0")) { + if (!StrEndsWith(it->name, ".0")) { break; } to_fuse.push_back(it); @@ -788,7 +808,7 @@ to_fuse.clear(); for (size_t i = 1; i < pop->axis.size() + 1; i++) { const auto& it = (*state)->stages[stage_id]->iters[i]; - if (!StringEndWith(it->name, ".1")) { + if (!StrEndsWith(it->name, ".1")) { break; } to_fuse.push_back((*state)->stages[stage_id]->iters[i]); @@ -804,7 +824,7 @@ to_fuse.clear(); for (size_t i = 2; i < pop->axis.size() + 2; i++) { const auto& it = (*state)->stages[stage_id]->iters[i]; - if (!StringEndWith(it->name, ".2")) { + if (!StrEndsWith(it->name, ".2")) { break; } to_fuse.push_back((*state)->stages[stage_id]->iters[i]); diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.h b/src/ansor/search_policy/meta_tile_rewrite_policy.h index ca9033ad866e..0c8c44b9c5ea 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.h +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.h @@ -1,100 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! - * Copyright (c) 2020 by Contributors - * \file ansor/meta_tile_rewrite_policy.h - * \brief A search policy that search with meta tiling structure and random - * rewrite + * \file ansor/search_policy/meta_tile_rewrite_policy.h + * \brief The search policy that searches by program sampling and evolutionary search */ + #ifndef TVM_ANSOR_SEARCH_POLICY_META_TILE_REWRITE_POLICY_H_ #define TVM_ANSOR_SEARCH_POLICY_META_TILE_REWRITE_POLICY_H_ -#include +#include #include -#include #include -#include - +#include +#include +#include "search_policy.h" #include "../cost_model/cost_model.h" #include "../utils.h" -#include "search_policy.h" + namespace tvm { namespace ansor { /*!
Multi stage search policy */ -class MetaTileRewritePolicyNode : public SearchPolicyNode { +class MetaTileRewritePolicyNode: public SearchPolicyNode { public: CostModel program_cost_model; /* this->params is used to store the following arguments - * int evolutionary_search_population - * The population size for evolutionary search - * int evolutionary_search_mutation_prob - * The probability of mutation for evolutionary search - * int evolutionary_search_num_iters - * The number of iterations for evolutionary search - * double local_mutation_use_measured_ratio - * The maximum percentage of measured states in the initial population - * for evolutionary search - * double eps_greedy - * Always allocate this percentage of measurements to random sampled states - * str cpu_multi_level_tiling_structure - * The structure of multi-level tiling for CPU - * str gpu_multi_level_tiling_structure - * The structure of multi-level tiling for GPU + * int evolutionary_search_population // The population size for evolutionary search + * int evolutionary_search_mutation_prob // The probability of mutation for evolutionary search + * int evolutionary_search_num_iters; // The number of iterations for evolutionary search + * double local_mutation_use_measured_ratio; // The maximum percentage of measured states in the initial + * // population for evolutionary search + * double eps_greedy; // Always allocate this percentage of measurements to random sampled states + * str cpu_multi_level_tiling_structure // The structure of multi-level tiling for CPU + * str gpu_multi_level_tiling_structure // The structure of multi-level tiling for GPU */ Map params; static SearchPolicy make(CostModel program_cost_model, - Map params, int seed); + Map params, + int seed); // Search and make n_trials measurements // Return the best state - State Search(SearchTask task, int n_trials, int early_stopping, - int num_measure_per_iter, int verbose, - ProgramMeasurer measurer) final; + State Search(SearchTask task, int n_trials, + int early_stopping, int num_measure_per_iter, + int verbose, ProgramMeasurer measurer) final; // Continue search.
This is used by JointTuner std::pair, Array > ContinueSearchOneRound( - SearchTask task, int num_measure, int verbose, - ProgramMeasurer measurer) final; + SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) final; - static constexpr const char* _type_key = "ansor.MetaTileRewritePolicy"; + static constexpr const char *_type_key = "ansor.MetaTileRewritePolicy"; static const std::vector auto_unroll_configs; TVM_DECLARE_FINAL_OBJECT_INFO(MetaTileRewritePolicyNode, SearchPolicyNode); - SearchTask cur_task_; // The current task + SearchTask cur_task_; // The current task + friend class MetaTileRewritePolicyNodeTest; // Hack friend class for UT protected: // Pick states from best states and random states with eps-greedy policy void PickStatesWithEpsGreedy(std::vector* inputs, const std::vector& best_states, - const std::vector& random_states, - int remaining_n_trials); + const std::vector& random_states, int remaining_n_trials); private: // Run one round of the search pipeline - void SearchOneRound(std::vector* best_states, int num_random_states, - std::vector* random_states); + void SearchOneRound(std::vector* best_states, + int num_random_states, std::vector* random_states); // Synthesize meta tiling structure without tile size void SynthesizeMetaStructure(std::vector* out_states); // Sample init population void SampleInitPopulation(const std::vector& meta_structures, - int out_size, std::vector* out_states); + int out_size, std::vector* out_states); // Perform evolutionary search void EvolutionarySearch(const std::vector& init_population, - int num_best_states, std::vector* best_states); + int num_best_states, std::vector* best_states); SplitFactorizationMemo split_memo_; // Memorize split space for Split std::mt19937 rand_gen_; // Random generator int verbose_; // Verbose level (0 means silent) - int num_measure_per_iter_; // The number of states to measure per iteration + int num_measure_per_iter_; // The number of states to measure per iteration - // The set of the already measured states. We store the string format for - // redundancy check + // The set of the already measured states. We store the string format for redundancy check std::unordered_set measured_states_set_; // The array of already measured states. diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc index 89bfeb1a8edd..866922d0001e 100644 --- a/src/ansor/search_policy/search_policy.cc +++ b/src/ansor/search_policy/search_policy.cc @@ -1,5 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! 
- * Copyright (c) 2020 by Contributors + * \file ansor/search_policy/search_policy.cc + * \brief The base class for search policy */ #include "search_policy.h" @@ -11,4 +31,3 @@ TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode); } // namespace ansor } // namespace tvm - diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h index 5bd9fb3118b1..f2071deab447 100644 --- a/src/ansor/search_policy/search_policy.h +++ b/src/ansor/search_policy/search_policy.h @@ -1,8 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! - * Copyright (c) 2020 by Contributors - * \file ansor/search_policy.h - * \brief Base class of search policy + * \file ansor/search_policy/search_policy.h + * \brief The base class for search policy */ + #ifndef TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ #define TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ @@ -45,7 +64,7 @@ class SearchPolicyNode : public Object { static constexpr const char *_type_key = "ansor.SearchPolicy"; TVM_DECLARE_BASE_OBJECT_INFO(SearchPolicyNode, Object); }; -TVM_DEFINE_MUTABLE_NODE_REF(SearchPolicy, SearchPolicyNode); +TVM_DEFINE_MUTABLE_OBJECT_REF(SearchPolicy, SearchPolicyNode); } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_policy/utils.cc b/src/ansor/search_policy/utils.cc index 9c597b4eb811..608b89da118c 100644 --- a/src/ansor/search_policy/utils.cc +++ b/src/ansor/search_policy/utils.cc @@ -1,5 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! 
- * Copyright (c) 2020 by Contributors + * \file ansor/search_policy/utils.cc + * \brief Common utilities for search policies */ #include "utils.h" @@ -42,27 +62,6 @@ void GetSpaceSplitStepIds(const State& s, int stage_id, std::vector* spatia } } -// Query axes that should not be splitted according to the attribute from tvm.compute -std::pair, std::set > QueryNoSplitAxis(const Stage& stage) { - std::pair, std::set > ret; - if (stage->op->attrs.count(SearchPolicyNode::no_split_at_inner_key)) { - ret.first = GetIterNameSetParam(stage->op->attrs, SearchPolicyNode::no_split_at_inner_key); - } - if (stage->op->attrs.count(SearchPolicyNode::no_split_at_outer_key)) { - ret.second = GetIterNameSetParam(stage->op->attrs, SearchPolicyNode::no_split_at_outer_key); - } - return ret; -} - -// Query axes that last split is one -std::set QueryLastSplitIsOneAxis(const Stage& stage) { - std::set ret; - if (stage->op->attrs.count(SearchPolicyNode::last_split_is_one_key)) { - ret = GetIterNameSetParam(stage->op->attrs, SearchPolicyNode::last_split_is_one_key); - } - return ret; -} - // Apply multi-tiling structure according to a string format State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format, std::vector* spatial_split_step_ids) { @@ -413,7 +412,7 @@ State RandomMutateMaxUnrollStep(const State& old_state, std::mt19937* random_gen // Mutate a parallel loop. State MutataParallel(const State& state, SplitFactorizationMemo* split_memo, - std::mt19937* random_gen, SearchTask& task, int verbose) { + std::mt19937* random_gen, const SearchTask& task, int verbose) { // To make this mutation simple but promising, we only focus on a specific case that // parallel was added to the outermost loop and the loop is generated by fusing other loops. // In short, we mutate the step pattern of (fuse -> parallel). @@ -574,17 +573,6 @@ void GridMutateTileSize(const State& old_state, std::vector* cands, } } -// Random choose an index according to a prefix sum probability -int RandomChoose(const std::vector& prefix_sum_probs, std::mt19937* random_gen) { - std::uniform_real_distribution<> dis(0.0, 1.0); - double x = dis(*random_gen); - - CHECK(!prefix_sum_probs.empty()); - - return std::lower_bound(prefix_sum_probs.begin(), prefix_sum_probs.end(), x) - - prefix_sum_probs.begin(); -} - // Prune undefined states. void PruneUndefined(std::vector* states) { size_t pt = 0; diff --git a/src/ansor/search_policy/utils.h b/src/ansor/search_policy/utils.h index 3337975d7a88..607a549e1b8a 100644 --- a/src/ansor/search_policy/utils.h +++ b/src/ansor/search_policy/utils.h @@ -1,7 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! 
- * Copyright (c) 2020 by Contributors - * \file ansor/search_policy/utils.h - * \brief Common utilities for local mutation in search policy + * \file ansor/search_policy/utils.h + * \brief Common utilities for search policies */ #ifndef TVM_ANSOR_SEARCH_POLICY_UTILS_H_ @@ -15,20 +33,13 @@ #include #include "../cost_model/cost_model.h" #include "../utils.h" +#include "../loop_state.h" +#include "../transform_step.h" #include "search_policy.h" namespace tvm { namespace ansor { -inline bool StringEndWith(const std::string& str, const std::string& target) { - int str_len = str.length(); - int target_len = target.length(); - if (str_len <= target_len) { - return false; - } - return str.compare(str_len - target_len, target_len, target) == 0; -} - // Get an integer from a tvm str Map inline int GetIntParam(const Map& attr_dict, const std::string& key) { @@ -96,7 +107,8 @@ inline int64_t GetExtent(const Iterator& it) { } // Return whether an op is strict inlineable -inline bool IsStrictInlineable(const SearchTask& task, const State& state, const te::Operation& op) { +inline bool IsStrictInlineable(const SearchTask& task, + const State& state, const te::Operation& op) { if (state->task_dag.defined()) { return state->task_dag->access_analyzer.IsStrictInlineable(op); } else { @@ -132,7 +144,8 @@ inline bool HasReduceIter(const Stage& stage) { } // Return whether an op needs multi level tiling -inline bool NeedsMultilevelTiling(const SearchTask& task, const State& state, const te::Operation& op) { +inline bool NeedsMultilevelTiling(const SearchTask& task, + const State& state, const te::Operation& op) { if (state->task_dag.defined()) { return state->task_dag->access_analyzer.NeedsMultiLevelTiling(op); } else { @@ -142,7 +155,7 @@ inline bool NeedsMultilevelTiling(const SearchTask& task, const State& state, co // Get all consumers for an op.
This will take inline into consideration inline void GetConsumers(const SearchTask& task, const State& state, const te::Operation& op, - std::unordered_set* consumers) { + std::unordered_set* consumers) { if (state->task_dag.defined()) { state->task_dag->access_analyzer.GetConsumers(state, op, consumers); } else { @@ -161,7 +174,7 @@ inline void GetProducers(const SearchTask& task, const State& state, const te::O // Return whether two ops are elementwise-matched inline bool ElementwiseMatch(const SearchTask& task, const State& state, const te::Operation& op, - const te::Operation& target_op) { + const te::Operation& target_op) { if (state->task_dag.defined()) { return state->task_dag->access_analyzer.ElementWiseMatch(op, target_op); } else { @@ -171,8 +184,7 @@ inline bool ElementwiseMatch(const SearchTask& task, const State& state, const t // Return whether the stage has only one consumer and they are elementwise-matched inline bool HasSingleElementwiseMatchedConsumer(const SearchTask& task, - const State& state, const Stage& stage, - int* target_stage_id) { + const State& state, const Stage& stage, int* target_stage_id) { std::unordered_set consumers; GetConsumers(task, state, stage->op, &consumers); @@ -203,8 +215,8 @@ inline bool NeedsRfactor(const SearchTask& task, const State& state, const te::O if (NeedsMultilevelTiling(task, state, op)) { // Do not use rfactor if we have enough parallelism on space iters - if (cum_space_len > cum_reduce_len - || cum_space_len > task->hardware_params->num_cores * 16) { + if (cum_space_len > cum_reduce_len || + cum_space_len > task->hardware_params->num_cores * 16) { return false; } else { return true; @@ -240,6 +252,7 @@ inline bool HasCacheWriteStage(const State& s, int stage_id) { return false; } +// Return whether the state did cache_read for stage_id inline bool HasCacheReadStage(const State& s, int stage_id) { for (int i = static_cast(s->transform_steps.size()) - 1; i >= 0; --i) { if (auto ps = s->transform_steps[i].as()) { @@ -261,8 +274,10 @@ inline bool HasCacheReadStage(const State& s, int stage_id) { return false; } +// Get all split step on spatial iterators void GetSpaceSplitStepIds(const State& s, int stage_id, std::vector* spatial_split_step_ids); +// Return whether the state did split/follow_split/follow_fused_split in stage_id inline bool HasSplitStep(const State& s, int stage_id) { for (int i = static_cast(s->transform_steps.size()) - 1; i >= 0; --i) { if (s->transform_steps[i]->IsInstance() || @@ -290,9 +305,26 @@ inline bool IsTiled(const Stage& stage) { } // Query axes that should not be splitted according to the attribute from tvm.compute -std::pair, std::set > QueryNoSplitAxis(const Stage& stage); +inline std::pair, std::set > QueryNoSplitAxis( + const Stage& stage) { + std::pair, std::set > ret; + if (stage->op->attrs.count(SearchPolicyNode::no_split_at_inner_key)) { + ret.first = GetIterNameSetParam(stage->op->attrs, SearchPolicyNode::no_split_at_inner_key); + } + if (stage->op->attrs.count(SearchPolicyNode::no_split_at_outer_key)) { + ret.second = GetIterNameSetParam(stage->op->attrs, SearchPolicyNode::no_split_at_outer_key); + } + return ret; +} + // Query axes that last split is one -std::set QueryLastSplitIsOneAxis(const Stage& stage); +inline std::set QueryLastSplitIsOneAxis(const Stage& stage) { + std::set ret; + if (stage->op->attrs.count(SearchPolicyNode::last_split_is_one_key)) { + ret = GetIterNameSetParam(stage->op->attrs, SearchPolicyNode::last_split_is_one_key); + } + return ret; +} // Extract primitive iterators 
from a nested fused or splitted iterator's name inline void ExtractOriginalIterators(const std::string& name, std::set* rets) { @@ -329,6 +361,7 @@ inline const Iterator& GetLastSpaceIteratorInOutermostTile(const Stage& stage) { return stage->iters[0]; } +// Get the last reduce iterator in the outermost reduce tile inline const Iterator& GetLastReduceIteratorInOutermostReduceTile(const Stage& stage) { auto pop = stage->op.as(); CHECK(pop != nullptr); @@ -379,10 +412,15 @@ inline void RandomSampleStates(const std::vector& in_states, std::mt19937 } // Random choose an index according to a prefix sum probability -int RandomChoose(const std::vector& prefix_sum_probs, std::mt19937* random_gen); +inline int RandomChoose(const std::vector& prefix_sum_probs, std::mt19937* random_gen) { + std::uniform_real_distribution<> dis(0.0, 1.0); + double x = dis(*random_gen); -// Prune undefined states. -void PruneUndefined(std::vector* states); + CHECK(!prefix_sum_probs.empty()); + + return std::lower_bound(prefix_sum_probs.begin(), prefix_sum_probs.end(), x) - + prefix_sum_probs.begin(); +} // Print all states inline void PrintAllStates(const std::vector& states) { @@ -418,7 +456,7 @@ State RandomMutateMaxUnrollStep(const State& old_state, std::mt19937* random_gen // Mutate a parallel loop. State MutataParallel(const State& old_state, SplitFactorizationMemo* split_memo, - std::mt19937* random_gen, SearchTask& task, int verbose = 0); + std::mt19937* random_gen, const SearchTask& task, int verbose = 0); // Create all possible tile size states for all SplitStep void GridMutateTileSize(const State& old_state, std::vector* cands, @@ -427,6 +465,9 @@ void GridMutateTileSize(const State& old_state, std::vector* cands, // GA: Crossover two states State CrossOverState(const State& p1, const State& p2); +// Prune undefined states. +void PruneUndefined(std::vector* states); + } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_task.cc b/src/ansor/search_task.cc index 93f3f60ea768..c65516150f30 100644 --- a/src/ansor/search_task.cc +++ b/src/ansor/search_task.cc @@ -1,12 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! 
- * Copyright (c) 2020 by Contributors + * \file ansor/search_task.cc + * \brief Meta information and hardware parameters for a search task */ -#include "search_task.h" +#include "search_task.h" #include #include #include - #include #include @@ -118,21 +137,21 @@ SearchTask SearchTaskNode::make(ComputeDAG compute_dag, } TVM_REGISTER_GLOBAL("ansor.HardwareParams") - .set_body_typed([](int num_cores, int vector_unit_bytes, - int cache_line_bytes, int max_unroll_vec, - int max_innermost_split_factor) { - return HardwareParamsNode::make(num_cores, vector_unit_bytes, - cache_line_bytes, max_unroll_vec, - max_innermost_split_factor); - }); +.set_body_typed([](int num_cores, int vector_unit_bytes, + int cache_line_bytes, int max_unroll_vec, + int max_innermost_split_factor) { + return HardwareParamsNode::make(num_cores, vector_unit_bytes, + cache_line_bytes, max_unroll_vec, + max_innermost_split_factor); +}); TVM_REGISTER_GLOBAL("ansor.SearchTask") - .set_body_typed([](ComputeDAG compute_dag, std::string workload_key, - Target target, Target target_host, - HardwareParams hardware_params) { - return SearchTaskNode::make(compute_dag, workload_key, target, - target_host, hardware_params); - }); +.set_body_typed([](ComputeDAG compute_dag, std::string workload_key, + Target target, Target target_host, + HardwareParams hardware_params) { + return SearchTaskNode::make(compute_dag, workload_key, target, + target_host, hardware_params); +}); } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_task.h b/src/ansor/search_task.h index 9512013848b6..cfa5500c39f4 100644 --- a/src/ansor/search_task.h +++ b/src/ansor/search_task.h @@ -1,36 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! - * Copyright (c) 2020 by Contributors * \file ansor/search_task.h - * \brief Meta information for a search task + * \brief Meta information and hardware parameters for a search task */ #ifndef TVM_ANSOR_SEARCH_TASK_H_ #define TVM_ANSOR_SEARCH_TASK_H_ #include - #include - #include "compute_dag.h" namespace tvm { namespace ansor { -class HardwareParams; -class SearchTask; +class HardwareParams; class SearchTask; /*! 
\brief Hardware related parameters */ class HardwareParamsNode : public Object { public: + // The number of cores int num_cores; + // The width of vector units in bytes int vector_unit_bytes; + // The size of cache line in bytes int cache_line_bytes; - // The max length of the axis to be unrolled or vectorized + // The max length of an axis to be unrolled or vectorized int max_unroll_vec; // The max split factor for the innermost tile int max_innermost_split_factor; - // Limit params for GPU schedule + // Limitation params for GPU int max_shared_memory_per_block{INT32_MAX}; int max_registers_per_block{INT32_MAX}; int max_threads_per_block{INT32_MAX}; @@ -54,13 +72,14 @@ class HardwareParamsNode : public Object { static HardwareParams make(int num_cores, int vector_unit_bytes, int cache_line_bytes, int max_unroll_vec, int max_innermost_split_factor); + static HardwareParams GetDefaultHardwareParams(const Target& target, const Target& target_host); static constexpr const char* _type_key = "ansor.HardwareParams"; TVM_DECLARE_FINAL_OBJECT_INFO(HardwareParamsNode, Object); }; -TVM_DEFINE_COW_NODE_REF(HardwareParams, ObjectRef, HardwareParamsNode); +TVM_DEFINE_COW_OBJECT_REF(HardwareParams, ObjectRef, HardwareParamsNode); /*! \brief Meta-info for a search task */ class SearchTaskNode : public Object { @@ -86,7 +105,7 @@ class SearchTaskNode : public Object { static constexpr const char* _type_key = "ansor.SearchTask"; TVM_DECLARE_FINAL_OBJECT_INFO(SearchTaskNode, Object); }; -TVM_DEFINE_COW_NODE_REF(SearchTask, ObjectRef, SearchTaskNode); +TVM_DEFINE_COW_OBJECT_REF(SearchTask, ObjectRef, SearchTaskNode); } // namespace ansor } // namespace tvm diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index fc4917409cc0..53c75a13f197 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -1,57 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! - * Copyright (c) 2020 by Contributors + * \file ansor/serialization.cc + * \brief Json serialization format for dumping and loading tuning records */ -#include "serialization.h" #include #include - #include #include +#include #include #include -#include - +#include "serialization.h" #include "loop_state.h" +#include "transform_step.h" #include "utils.h" // Json serialization handler for MeasureInput, MeasureResult -// (and recursively SearchTask, State, Step, ... +// (and recursively for SearchTask, State, Step, ...) 
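// A minimal round-trip sketch (illustrative, not part of this patch): the
// handlers below plug into dmlc-core's JSON dispatch, so writer.Write(steps)
// routes to Handler<std::vector<Step>>::Write and emits one short array per
// step, tagged with the two-letter codes used in this file ("RE", "SP", ...).
// This assumes dmlc-core's dmlc/json.h API; the function name is illustrative.
#include <dmlc/json.h>
#include <sstream>
#include <vector>

void RoundTripSteps(const std::vector<tvm::ansor::Step>& steps) {
  std::ostringstream os;
  dmlc::JSONWriter writer(&os);
  writer.Write(steps);   // dispatches to Handler<std::vector<Step>>::Write

  std::istringstream is(os.str());
  dmlc::JSONReader reader(&is);
  std::vector<tvm::ansor::Step> loaded;
  reader.Read(&loaded);  // dispatches to Handler<std::vector<Step>>::Read
}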
namespace dmlc { namespace json { -inline std::vector& FloatArrayToVector( - std::vector* out, const ::tvm::Array<::tvm::PrimExpr>& data) { +inline std::vector& IntArrayToVector(std::vector* out, + const ::tvm::Array<::tvm::PrimExpr>& data) { out->clear(); - for (const auto& x : data) { - auto pf = x.as<::tvm::tir::FloatImmNode>(); - CHECK(pf != nullptr) << "Cost can only contain float values"; - out->push_back(pf->value); - } - return *out; -} - -inline std::vector& IntArrayToVector( - std::vector* out, const ::tvm::Array<::tvm::PrimExpr>& data) { - out->clear(); - for (const auto& x : data) { + for (const auto&x : data) { auto pi = x.as<::tvm::tir::IntImmNode>(); - CHECK(pi != nullptr) << "Cost can only contain int values"; + CHECK(pi != nullptr) << "Can only contain int values"; out->push_back(pi->value); } return *out; } template <> -struct Handler> { +struct Handler > { inline static void Write(dmlc::JSONWriter* writer, - const std::vector<::tvm::ansor::Stage>& data) { + const std::vector<::tvm::ansor::Stage> & data) { // todo(lmzheng): support serialization of Stage writer->BeginArray(false); writer->EndArray(); } inline static void Read(dmlc::JSONReader* reader, - std::vector<::tvm::ansor::Stage>* data) { + std::vector<::tvm::ansor::Stage> * data) { bool s; reader->BeginArray(); s = reader->NextArrayItem(); CHECK(!s); @@ -59,16 +67,16 @@ struct Handler> { }; template <> -struct Handler> { +struct Handler > { inline static void Write(dmlc::JSONWriter* writer, - const std::vector<::tvm::ansor::Step>& data) { + const std::vector<::tvm::ansor::Step> & data) { std::vector tmp; writer->BeginArray(false); for (size_t i = 0; i < data.size(); ++i) { writer->WriteArraySeperator(); writer->BeginArray(false); if (auto ps = data[i].as<::tvm::ansor::ReorderStepNode>()) { - writer->WriteArrayItem(std::string("RS")); + writer->WriteArrayItem(std::string("RE")); writer->WriteArrayItem(ps->stage_id); writer->WriteArraySeperator(); @@ -78,7 +86,7 @@ struct Handler> { } writer->EndArray(); } else if (auto ps = data[i].as<::tvm::ansor::SplitStepNode>()) { - writer->WriteArrayItem(std::string("SS")); + writer->WriteArrayItem(std::string("SP")); writer->WriteArrayItem(ps->stage_id); writer->WriteArrayItem(ps->iter_id); if (ps->extent.defined()) { @@ -89,14 +97,13 @@ struct Handler> { writer->WriteArrayItem(IntArrayToVector(&tmp, ps->lengths)); writer->WriteArrayItem(static_cast(ps->inner_to_outer)); } else if (auto ps = data[i].as<::tvm::ansor::FollowSplitStepNode>()) { - writer->WriteArrayItem(std::string("FSS")); + writer->WriteArrayItem(std::string("FSP")); writer->WriteArrayItem(ps->stage_id); writer->WriteArrayItem(ps->iter_id); writer->WriteArrayItem(ps->src_step_id); writer->WriteArrayItem(ps->n_split); - } else if (auto ps = - data[i].as<::tvm::ansor::FollowFusedSplitStepNode>()) { - writer->WriteArrayItem(std::string("FFSS")); + } else if (auto ps = data[i].as<::tvm::ansor::FollowFusedSplitStepNode>()) { + writer->WriteArrayItem(std::string("FFSP")); writer->WriteArrayItem(ps->stage_id); writer->WriteArrayItem(ps->iter_id); @@ -110,7 +117,7 @@ struct Handler> { writer->WriteArrayItem(ps->level); writer->WriteArrayItem(static_cast(ps->factor_or_nparts)); } else if (auto ps = data[i].as<::tvm::ansor::FuseStepNode>()) { - writer->WriteArrayItem(std::string("FS")); + writer->WriteArrayItem(std::string("FU")); writer->WriteArrayItem(ps->stage_id); writer->WriteArraySeperator(); @@ -120,7 +127,7 @@ struct Handler> { } writer->EndArray(); } else if (auto ps = 
data[i].as<::tvm::ansor::AnnotationStepNode>()) { - writer->WriteArrayItem(std::string("AS")); + writer->WriteArrayItem(std::string("AN")); writer->WriteArrayItem(ps->stage_id); writer->WriteArrayItem(ps->iter_id); writer->WriteArrayItem(static_cast(ps->annotation)); @@ -145,12 +152,12 @@ struct Handler> { writer->WriteArrayItem(ps->stage_id); writer->WriteArrayItem(ps->scope_name); } else if (auto ps = data[i].as<::tvm::ansor::PragmaStepNode>()) { - writer->WriteArrayItem(std::string("PS")); + writer->WriteArrayItem(std::string("PR")); writer->WriteArrayItem(ps->stage_id); writer->WriteArrayItem(ps->iter_id); writer->WriteArrayItem(ps->pragma_type); } else if (auto ps = data[i].as<::tvm::ansor::RfactorStepNode>()) { - writer->WriteArrayItem(std::string("RFS")); + writer->WriteArrayItem(std::string("RF")); writer->WriteArrayItem(ps->stage_id); writer->WriteArrayItem(ps->iter_id); writer->WriteArrayItem(ps->factor_iter_id); @@ -167,8 +174,9 @@ struct Handler> { } writer->EndArray(); } + inline static void Read(dmlc::JSONReader* reader, - std::vector<::tvm::ansor::Step>* data) { + std::vector<::tvm::ansor::Step> * data) { std::vector int_list; bool s, inner_to_outer, factor_or_nparts; std::string name, scope_name, pragma_type; @@ -181,14 +189,13 @@ struct Handler> { reader->BeginArray(); s = reader->NextArrayItem(); CHECK(s); reader->Read(&name); - if (name == "RS") { + if (name == "RE") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&int_list); - data->push_back( - ::tvm::ansor::ReorderStepNode::make(stage_id, int_list)); - } else if (name == "SS") { + data->push_back(::tvm::ansor::ReorderStepNode::make(stage_id, int_list)); + } else if (name == "SP") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); @@ -203,7 +210,7 @@ struct Handler> { stage_id, iter_id, extent, std::vector<::tvm::PrimExpr>(int_list.begin(), int_list.end()), inner_to_outer)); - } else if (name == "FSS") { + } else if (name == "FSP") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); @@ -214,7 +221,7 @@ struct Handler> { reader->Read(&n_split); data->push_back(::tvm::ansor::FollowSplitStepNode::make( stage_id, iter_id, src_step_id, n_split)); - } else if (name == "FFSS") { + } else if (name == "FFSP") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); @@ -227,21 +234,21 @@ struct Handler> { reader->Read(&factor_or_nparts); data->push_back(::tvm::ansor::FollowFusedSplitStepNode::make( stage_id, iter_id, int_list, level, factor_or_nparts)); - } else if (name == "FS") { + } else if (name == "FU") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&int_list); data->push_back(::tvm::ansor::FuseStepNode::make(stage_id, int_list)); - } else if (name == "AS") { + } else if (name == "AN") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&iter_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&ann); - data->push_back(::tvm::ansor::AnnotationStepNode::make( - stage_id, iter_id, ::tvm::ansor::IteratorAnnotation(ann))); + data->push_back(::tvm::ansor::AnnotationStepNode::make(stage_id, + iter_id, ::tvm::ansor::IteratorAnnotation(ann))); } else if (name == "CA") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); @@ -273,26 +280,26 @@ struct Handler> { 
reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&scope_name); - data->push_back( - ::tvm::ansor::CacheWriteStepNode::make(stage_id, scope_name)); - } else if (name == "PS") { + data->push_back(::tvm::ansor::CacheWriteStepNode::make( + stage_id, scope_name)); + } else if (name == "PR") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&iter_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&pragma_type); - data->push_back( - ::tvm::ansor::PragmaStepNode::make(stage_id, iter_id, pragma_type)); - } else if (name == "RFS") { + data->push_back(::tvm::ansor::PragmaStepNode::make( + stage_id, iter_id, pragma_type)); + } else if (name == "RF") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&iter_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&factor_iter_id); - data->push_back(::tvm::ansor::RfactorStepNode::make(stage_id, iter_id, - factor_iter_id)); + data->push_back(::tvm::ansor::RfactorStepNode::make( + stage_id, iter_id, factor_iter_id)); } else if (name == "SA") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); @@ -392,7 +399,7 @@ struct Handler<::tvm::ansor::MeasureResultNode> { writer->BeginArray(false); writer->WriteArraySeperator(); writer->BeginArray(false); - for (const auto& x : data.costs) { + for (const auto&x : data.costs) { auto pf = x.as<::tvm::tir::FloatImmNode>(); CHECK(pf != nullptr) << "Cost can only contain float values"; writer->WriteArrayItem(pf->value); @@ -434,7 +441,7 @@ namespace ansor { TVM_REGISTER_OBJECT_TYPE(LogToFileNode); TVM_REGISTER_OBJECT_TYPE(LogReaderNode); -const std::string ansor_LOG_VERSION = "v0.1"; // NOLINT(*) +const std::string ANSOR_LOG_VERSION = "v0.1"; // NOLINT(*) MeasureCallback LogToFileNode::make(std::string filename) { auto node = make_object(); @@ -442,21 +449,24 @@ MeasureCallback LogToFileNode::make(std::string filename) { return MeasureCallback(node); } -void WriteMeasureRecords(std::ostream* os, const Array& inputs, +void WriteMeasureRecords(std::ostream* os, + const Array& inputs, const Array& results) { dmlc::JSONWriter writer(os); for (size_t i = 0; i < inputs.size(); ++i) { writer.BeginObject(false); writer.WriteObjectKeyValue("i", *inputs[i].operator->()); writer.WriteObjectKeyValue("r", *results[i].operator->()); - writer.WriteObjectKeyValue("v", ansor_LOG_VERSION); + writer.WriteObjectKeyValue("v", ANSOR_LOG_VERSION); writer.EndObject(); *os << "\n"; } } -void ReadMeasureRecords(std::string str, MeasureInputNode* inp, - MeasureResultNode* res, std::string* log_version) { +void ReadMeasureRecord(const std::string& str, + MeasureInputNode* inp, + MeasureResultNode* res, + std::string* log_version) { std::istringstream ss(str); dmlc::JSONReader reader(&ss); std::string key; @@ -499,7 +509,7 @@ bool LogReaderNode::ReadNext(MeasureInputNode* inp, MeasureResultNode* res) { } try { - ReadMeasureRecords(cur_line, inp, res, &log_version); + ReadMeasureRecord(cur_line, inp, res, &log_version); } catch (...) 
{ return false; } @@ -510,8 +520,8 @@ bool LogReaderNode::ReadNext(MeasureInputNode* inp, MeasureResultNode* res) { return false; } -std::pair, Array> LogReaderNode::ReadLines( - int max_size, int skip_size) { +std::pair, Array > LogReaderNode::ReadLines( + int max_size, int skip_size) { auto inp = make_object(); auto res = make_object(); Array inputs; @@ -534,41 +544,68 @@ std::pair, Array> LogReaderNode::ReadLines( return std::make_pair(inputs, results); } -TVM_REGISTER_GLOBAL("ansor.write_measure_records_to_file") - .set_body([](TVMArgs args, TVMRetValue* ret) { - std::string filename = args[0]; - Array in = args[1]; - Array res = args[2]; - std::ofstream ofs(filename, std::ofstream::app); - WriteMeasureRecords(&ofs, in, res); - }); +std::pair BestMeasurePairInFile(const std::string& filename, + const std::string& workload_key, + const Target& target) { + std::pair best_pair; + double best_cost = 1e30; + + auto inp = make_object(); + auto res = make_object(); + LogReader reader = LogReaderNode::make(filename); + + while (reader->ReadNext(inp.get(), res.get())) { + if (res->error_no != kNoError || inp->task->workload_key != workload_key + || inp->task->target->target_name != target->target_name) { + continue; + } + + double cost = FloatArrayMean(res->costs); + + if (cost < best_cost) { + best_cost = cost; + best_pair = std::make_pair(inp->copy(), res->copy()); + } + } + + return best_pair; +} + +TVM_REGISTER_GLOBAL("ansor.WriteMeasureRecordsToFile") +.set_body([](TVMArgs args, TVMRetValue *ret) { + std::string filename = args[0]; + Array in = args[1]; + Array res = args[2]; + std::ofstream ofs(filename, std::ofstream::app); + WriteMeasureRecords(&ofs, in, res); +}); TVM_REGISTER_GLOBAL("ansor.LogToFile") - .set_body_typed([](const std::string& filename) { - return LogToFileNode::make(filename); - }); +.set_body_typed([](const std::string& filename) { + return LogToFileNode::make(filename); +}); TVM_REGISTER_GLOBAL("ansor.LogReader") - .set_body_typed([](const std::string& filename) { - return LogReaderNode::make(filename); - }); +.set_body_typed([](const std::string& filename) { + return LogReaderNode::make(filename); +}); TVM_REGISTER_GLOBAL("ansor.LogReaderReadLines") - .set_body_typed([](LogReader reader, int size, int skip_size) { - const auto& res = reader->ReadLines(size, skip_size); - return Array{res.first, res.second}; - }); +.set_body_typed([](LogReader reader, int size, int skip_size) { + const auto& res = reader->ReadLines(size, skip_size); + return Array{res.first, res.second}; +}); TVM_REGISTER_GLOBAL("ansor.LogReaderReadNext") - .set_body_typed([](LogReader reader) { - auto inp = make_object(); - auto res = make_object(); - if (reader->ReadNext(inp.get(), res.get())) { - return Array{ObjectRef(inp), ObjectRef(res)}; - } else { - return Array(); - } - }); +.set_body_typed([](LogReader reader) { + auto inp = make_object(); + auto res = make_object(); + if (reader->ReadNext(inp.get(), res.get())) { + return Array{ObjectRef(inp), ObjectRef(res)}; + } else { + return Array(); + } +}); } // namespace ansor } // namespace tvm diff --git a/src/ansor/serialization.h b/src/ansor/serialization.h index ef4132169652..a12760bb3acc 100644 --- a/src/ansor/serialization.h +++ b/src/ansor/serialization.h @@ -1,5 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! - * Copyright (c) 2020 by Contributors * \file ansor/serialization.h * \brief Json serialization format for dumping and loading tuning records */ @@ -7,18 +25,15 @@ #ifndef TVM_ANSOR_SERIALIZATION_H_ #define TVM_ANSOR_SERIALIZATION_H_ -#include #include +#include #include - #include "measure.h" namespace tvm { namespace ansor { -class LogReader; - -/*! \brief Log the input and results of measurments to file */ +/*! \brief Callback for logging the input and results of measurements to file */ class LogToFileNode : public MeasureCallbackNode { public: std::string filename; @@ -26,13 +41,16 @@ class LogToFileNode : public MeasureCallbackNode { static MeasureCallback make(std::string filename); /*! \brief Log measure pairs to file. This is called by the search policy */ - void callback(const SearchPolicy& policy, const Array& inputs, + void callback(const SearchPolicy& policy, + const Array& inputs, const Array& results) final; - static constexpr const char* _type_key = "ansor.LogToFile"; + static constexpr const char *_type_key = "ansor.LogToFile"; TVM_DECLARE_FINAL_OBJECT_INFO(LogToFileNode, MeasureCallbackNode); }; +class LogReader; + /*! \brief Log reader */ class LogReaderNode : public Object { public: @@ -49,7 +67,7 @@ class LogReaderNode : public Object { * \param max_size The maximum number of lines. -1 means read all lines * \param skip_size Skip the first n lines */ std::pair, Array > ReadLines( - int max_size = -1, int skip_size = 0); + int max_size = -1, int skip_size = 0); static constexpr const char* _type_key = "ansor.LogReader"; TVM_DECLARE_FINAL_OBJECT_INFO(LogReaderNode, Object); @@ -57,17 +75,23 @@ class LogReaderNode : public Object { private: std::string cur_line; }; -TVM_DEFINE_MUTABLE_NODE_REF(LogReader, LogReaderNode); +TVM_DEFINE_MUTABLE_OBJECT_REF(LogReader, LogReaderNode); -void WriteMeasureRecords(std::ostream* os, const Array& inputs, +/*! \brief Write measure records to an output stream */ +void WriteMeasureRecords(std::ostream* os, + const Array& inputs, const Array& results); -void ReadMeasureRecords(std::string str, MeasureInputNode* inp, - MeasureResultNode* res, std::string* log_version); +/*! \brief Read one measure record from a string */ +void ReadMeasureRecord(const std::string& str, + MeasureInputNode* inp, + MeasureResultNode* res, + std::string* log_version); -std::pair BestMeasurePairInFile( - const std::string& filename, const std::string& workload_key, - const Target& target); +/*! 
\brief Return the best measure pair with lowest cost in a file */ +std::pair BestMeasurePairInFile(const std::string& filename, + const std::string& workload_key, + const Target& target); } // namespace ansor } // namespace tvm diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc index 5f4a6a8dcef9..3f59ff736e9d 100644 --- a/src/ansor/transform_step.cc +++ b/src/ansor/transform_step.cc @@ -1,16 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! - * Copyright (c) 2020 by Contributors + * \file ansor/transform_step.cc + * \brief Transformation steps. For each schedule primitive, there is a corresponding transform step. + * + * See the note in transform_step.h on how to add a new step */ + #include "transform_step.h" #include +#include #include "utils.h" namespace tvm { namespace ansor { -TVM_REGISTER_NODE_TYPE(IteratorNode); -TVM_REGISTER_OBJECT_TYPE(StepNode); - /********** Reorder **********/ ReorderStep ReorderStepNode::make(int stage_id, const std::vector& after_ids) { auto node = make_object(); @@ -226,7 +247,8 @@ FollowFusedSplitStep FollowFusedSplitStepNode::make(int stage_id, int iter_id, return FollowFusedSplitStep(node); } -PrimExpr FollowFusedSplitStepNode::ExtractSplitLength(const std::vector& transform_steps) const { +PrimExpr FollowFusedSplitStepNode::ExtractSplitLength( + const std::vector& transform_steps) const { PrimExpr ret(1); for (int src_step_id : src_step_ids) { @@ -402,7 +424,7 @@ std::string AnnotationStepNode::PrintAsPythonAPI(std::vector *stages, return ss.str(); } -/********** Compute at **********/ +/********** Compute At **********/ ComputeAtStep ComputeAtStepNode::make(int stage_id, int target_stage_id, int target_iter_id) { auto node = make_object(); node->stage_id = stage_id; @@ -487,29 +509,7 @@ std::string ComputeInlineStepNode::PrintAsPythonAPI( return ss.str(); } -/********** Pack for vec **********/ -PackForVecStep PackForVecStepNode::make(int stage_id, int iter_id, int vec_size) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->vec_size = vec_size; - return PackForVecStep(node); -} - -void PackForVecStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, te::Schedule *schedule) const { - LOG(FATAL) << "Not implemented"; -} - -std::string PackForVecStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - LOG(FATAL) << "Not implemented"; - return ""; -} - -/********** Cache read **********/ +/********** Cache Read **********/ CacheReadStep CacheReadStepNode::make(int stage_id, std::string scope_name, const std::vector& reader_stage_ids) { auto node = make_object(); @@ -572,7 +572,7 @@ std::string 
CacheReadStepNode::PrintAsPythonAPI(std::vector *stages, return ss.str(); } -/********** Cache write **********/ +/********** Cache Write **********/ CacheWriteStep CacheWriteStepNode::make(int stage_id, std::string scope_name) { auto node = make_object(); node->stage_id = stage_id; @@ -770,8 +770,7 @@ std::string RfactorStepNode::PrintAsPythonAPI(std::vector *stages, return ss.str(); } -/********** StorageAlign **********/ - +/********** Storage Align **********/ StorageAlignStep StorageAlignStepNode::make(int stage_id, int iter_id, int factor, int offset) { auto node = make_object(); @@ -802,20 +801,5 @@ std::string StorageAlignStepNode::PrintAsPythonAPI( return ss.str(); } -// Maker for other classes -Iterator IteratorNode::make(std::string name, Range range, - IteratorType iter_type, IteratorAnnotation annotation, - const std::vector* ori_iters) { - auto node = make_object(); - node->name = std::move(name); - node->range = std::move(range); - node->iter_type = iter_type; - node->annotation = annotation; - if (ori_iters != nullptr) { - node->ori_iters = *ori_iters; - } - return Iterator(node); -} - } // namespace ansor } // namespace tvm diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h index 627ce02b60e1..8240623ae3b1 100644 --- a/src/ansor/transform_step.h +++ b/src/ansor/transform_step.h @@ -1,106 +1,30 @@ -/*! - * Copyright (c) 2020 by Contributors - * \file ansor/transform_step.h - * \brief Data structures for loop transformations - - * Basically this is a simplified TVM IR with schedule primitives. - * We don't use the existing TVM IR because - * 1. We want fast incremental change to the loop structures - * 2. We want serializable history for replay and backtracking - * 3. We want simplified IR for easy and clean feature extraction - * 4. We may create some Macro schedule primitives - - * After search is done, we will lower this IR to TVM IR and TVM schedule primitives. - * Because we share a lot common objects during search, the transformation is - * implemented in copy on write style. All objects are immutable, which is - * similar to TVM IR. +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
*/
-#ifndef TVM_ANSOR_TRANSFORM_STEP_H_
-#define TVM_ANSOR_TRANSFORM_STEP_H_
-
-#include 
-#include 
-#include 
-#include "compute_dag.h"
-
-namespace tvm {
-namespace ansor {
-
-using namespace tvm::tir;
-
-inline std::string CleanName(const std::string& str) {
-  // to make the name valid in python code
-  std::string ret = str;
-  StrReplace(&ret, ".", "_");
-  StrReplace(&ret, "@", "_");
-  StrReplace(&ret, "outer", "o");
-  StrReplace(&ret, "inner", "i");
-  return ret;
-}
-
-enum IteratorType {
-  kSpace,     // spatial iterator
-  kReduce,    // reduction iterator
-  kMixed,     // fused spatial and reduction iterator
-  kSpecial    // special iterator (e.g. virtual root iterator)
-};
-
-enum IteratorAnnotation {
-  kNone, kUnroll, kVectorize, kParallel,
-  kVThread, kBlockX, kThreadX, kBlockY, kThreadY
-};
-
-class Iterator;
-
 /*!
- * \brief A for loop iterator
- * Similar to tvm::IterVar in `include/expr.h`
- */
-class IteratorNode : public Object {
- public:
-  std::string name;
-  Range range;  // domain of for loop range
-  IteratorType iter_type;
-  IteratorAnnotation annotation;
-  std::vector ori_iters;
-
-  static Iterator make(std::string name, Range range,
-                       IteratorType iter_type, IteratorAnnotation annotation,
-                       const std::vector* ori_iters = nullptr);
-
-  void VisitAttrs(tvm::AttrVisitor* v) {
-    v->Visit("name", &name);
-    v->Visit("range", &range);
-  }
-
-  static constexpr const char *_type_key = "ansor.Iterator";
-  TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object);
-};
-TVM_DEFINE_COW_NODE_REF(Iterator, ObjectRef, IteratorNode);
-
-/*! \brief The base class for a transformation step */
-class StepNode: public Object {
- public:
-  int stage_id;
-
-  // Print step as equivalent python schedule API
-  virtual std::string PrintAsPythonAPI(std::vector *stages,
-                                       StageToAxesMap *stage_to_axes,
-                                       te::Schedule *schedule,
-                                       const std::vector& transform_steps) const = 0;
-
-  static constexpr const char* _type_key = "ansor.Step";
-  TVM_DECLARE_BASE_OBJECT_INFO(StepNode, Object);
-};
-TVM_DEFINE_MUTABLE_NODE_REF(Step, StepNode);
-
-/*
- * Note on how to add a new transform step
+ * \file ansor/transform_step.h
+ * \brief Transformation steps. For each schedule primitive, there is a corresponding transform step.
 *
+ * \note How to add a new transform step.
 * Take fuse for example:
- * 1. Define class FuseStepNode, FuseStep in loop_state.h, and implement its make function
- *    in FuseStepNode::make(...)  loop_state.cc
+ * 1. Define class FuseStepNode, FuseStep in transform_step.h, and implement its make function
+ *    in FuseStepNode::make(...) in transform_step.cc
 * 2. Implement FuseStepNode::ApplyToSchedule and FuseStepNode::PrintAsPythonAPI.
 *    - In these two functions you need to lower this step with tvm's schedule API
 * 3. Implement State::fuse and State::DoFuseStep.
@@ -112,17 +36,24 @@ TVM_DEFINE_MUTABLE_NODE_REF(Step, StepNode);
 * 6. Add hash support in `struct hash<::tvm::ansor::Step>` (search for this function in this file)
 */

-class ReorderStep; class SplitStep; class FollowSplitStep;
-class FollowFusedSplitStep;
-class FuseStep; class AnnotationStep;
-class ComputeAtStep; class ComputeRootStep; class ComputeInlineStep;
-class PackForVecStep; class CacheReadStep; class CacheWriteStep;
-class PragmaStep; class RfactorStep; class StorageAlignStep;
-class AttachMap;
+#ifndef TVM_ANSOR_TRANSFORM_STEP_H_
+#define TVM_ANSOR_TRANSFORM_STEP_H_
+
+#include 
+#include 
+#include 
+#include "loop_state.h"
+
+namespace tvm {
+namespace ansor {
+
+using namespace tvm::tir;
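+
+// To make the \note recipe in the file header concrete, a new step follows the
+// same shape as the step classes below. This is only an illustrative sketch:
+// the name `MyStep`, its field, and its registry key are hypothetical and are
+// not part of this patch. The signatures mirror the StepNode interface.
+//
+//   /*! \brief My step that corresponds to some te::Stage primitive */
+//   class MyStepNode : public StepNode {
+//    public:
+//     int iter_id;  // The id of the iter this step applies to
+//
+//     static MyStep make(int stage_id, int iter_id);
+//
+//     // Lower this step with TVM's schedule API when replaying the history
+//     void ApplyToSchedule(std::vector<te::Stage>* stages,
+//                          StageToAxesMap* stage_to_axes) const;
+//
+//     // Print this step as the equivalent python schedule API call
+//     std::string PrintAsPythonAPI(std::vector<te::Stage>* stages,
+//                                  StageToAxesMap* stage_to_axes,
+//                                  te::Schedule* schedule,
+//                                  const std::vector<Step>& transform_steps) const final;
+//
+//     static constexpr const char* _type_key = "ansor.MyStep";
+//     TVM_DECLARE_FINAL_OBJECT_INFO(MyStepNode, Object);
+//   };
+//   TVM_DEFINE_COW_OBJECT_REF(MyStep, Step, MyStepNode);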
+/*! \brief Reorder step that corresponds to te::Stage::reorder */
 class ReorderStepNode: public StepNode {
  public:
-  std::vector after_ids;
+  std::vector after_ids;  // The iterator ids after reorder.
+                          // This array should specify the order of all iterators.
 
   static ReorderStep make(int stage_id, const std::vector& after_ids);
 
@@ -137,15 +68,17 @@ class ReorderStepNode: public StepNode {
   static constexpr const char* _type_key = "ansor.ReorderStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, Object);
 };
-TVM_DEFINE_COW_NODE_REF(ReorderStep, Step, ReorderStepNode);
-
+TVM_DEFINE_COW_OBJECT_REF(ReorderStep, Step, ReorderStepNode);
+/*! \brief Split step that corresponds to te::Stage::split with additional
+ * support of multiple levels of split factors */
 class SplitStepNode: public StepNode {
  public:
-  int iter_id;
-  PrimExpr extent;              // the extent of the axis to split
+  int iter_id;                  // The id of the iter to split
+  PrimExpr extent;              // The extent (length) of the axis to split
   std::vector lengths;  // The split factors
-  bool inner_to_outer;
+  bool inner_to_outer;          // If true, the `lengths` denote the lengths of
+                                // iterators from inner level to outer level
 
   static SplitStep make(int stage_id, int iter_id, PrimExpr extent,
                         const std::vector& lengths,
@@ -162,15 +95,15 @@ class SplitStepNode: public StepNode {
   static constexpr const char* _type_key = "ansor.SplitStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, Object);
 };
-TVM_DEFINE_COW_NODE_REF(SplitStep, Step, SplitStepNode);
+TVM_DEFINE_COW_OBJECT_REF(SplitStep, Step, SplitStepNode);
 
-// Similar to SplitStepNode, but use split factor from another step
-// (i.e. Follow another split step)
+/*! \brief Similar to SplitStepNode, but uses the split factor from another step
+ * (i.e. follows another split step) */
 class FollowSplitStepNode: public StepNode {
  public:
-  int iter_id;
-  int src_step_id;
-  int n_split;
+  int iter_id;      // The id of the iter to split
+  int src_step_id;  // The index of the split step to follow in the history
+  int n_split;      // The number of split levels
 
   static FollowSplitStep make(int stage_id, int iter_id,
                               int src_step_id, int n_split);
@@ -190,17 +123,17 @@ class FollowSplitStepNode: public StepNode {
   static constexpr const char* _type_key = "ansor.FollowSplitStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object);
 };
-TVM_DEFINE_COW_NODE_REF(FollowSplitStep, Step, FollowSplitStepNode);
-
+TVM_DEFINE_COW_OBJECT_REF(FollowSplitStep, Step, FollowSplitStepNode);
 
-// Similar to FollowSplitStep, but use split factors from multiple steps
-// This can be used for the split in cooperative fetching.
+/*! \brief Similar to FollowSplitStep, but uses split factors from multiple steps.
+ * \note This can be used for the split in cooperative fetching
+ */
class FollowFusedSplitStepNode: public StepNode {
 public:
-  int iter_id;
-  std::vector src_step_ids;
-  int level;              // Use the length in this split level
-  bool factor_or_nparts;  // If this is true, use factor. Otherwise, use nparts
+  int iter_id;                    // The id of the iter to split
+  std::vector src_step_ids;  // The indices of the split steps to follow in the history
+  int level;                      // Use the length in this split level
+  bool factor_or_nparts;          // If this is true, use factor. Otherwise, use nparts
 
  static FollowFusedSplitStep make(int stage_id, int iter_id,
                                   const std::vector& src_step_ids,
@@ -220,12 +153,12 @@ class FollowFusedSplitStepNode: public StepNode {
   static constexpr const char* _type_key = "ansor.FollowFusedSplitStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object);
 };
-TVM_DEFINE_COW_NODE_REF(FollowFusedSplitStep, Step, FollowFusedSplitStepNode);
-
+TVM_DEFINE_COW_OBJECT_REF(FollowFusedSplitStep, Step, FollowFusedSplitStepNode);
+/*! \brief Fuse step that corresponds to te::Stage::fuse */
 class FuseStepNode: public StepNode {
  public:
-  std::vector fused_ids;
+  std::vector fused_ids;  // The ids of iterators to fuse
 
   static FuseStep make(int stage_id, const std::vector& fused_ids);
 
@@ -240,9 +173,11 @@ class FuseStepNode: public StepNode {
   static constexpr const char* _type_key = "ansor.FuseStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object);
 };
-TVM_DEFINE_COW_NODE_REF(FuseStep, Step, FuseStepNode);
-
+TVM_DEFINE_COW_OBJECT_REF(FuseStep, Step, FuseStepNode);
+/*! \brief Annotation step that corresponds to vectorize, parallel, unroll and thread binding.
+ * (i.e. te::Stage::vectorize, te::Stage::parallel, te::Stage::unroll, te::Stage::bind)
+ */
 class AnnotationStepNode: public StepNode {
  public:
   int iter_id;
@@ -261,9 +196,9 @@ class AnnotationStepNode: public StepNode {
   static constexpr const char* _type_key = "ansor.AnnotationStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, Object);
 };
-TVM_DEFINE_COW_NODE_REF(AnnotationStep, Step, AnnotationStepNode);
-
+TVM_DEFINE_COW_OBJECT_REF(AnnotationStep, Step, AnnotationStepNode);
+/*! \brief Compute at step that corresponds to te::Stage::compute_at */
 class ComputeAtStepNode: public StepNode {
  public:
  int target_stage_id;
@@ -283,9 +218,9 @@ class ComputeAtStepNode: public StepNode {
   static constexpr const char* _type_key = "ansor.ComputeAtStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, Object);
 };
-TVM_DEFINE_COW_NODE_REF(ComputeAtStep, Step, ComputeAtStepNode);
-
+TVM_DEFINE_COW_OBJECT_REF(ComputeAtStep, Step, ComputeAtStepNode);
+/*! \brief Compute root step that corresponds to te::Stage::compute_root */
 class ComputeRootStepNode: public StepNode {
  public:
   static ComputeRootStep make(int stage_id);
@@ -301,9 +236,9 @@ class ComputeRootStepNode: public StepNode {
   static constexpr const char* _type_key = "ansor.ComputeRootStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, Object);
 };
-TVM_DEFINE_COW_NODE_REF(ComputeRootStep, Step, ComputeRootStepNode);
-
+TVM_DEFINE_COW_OBJECT_REF(ComputeRootStep, Step, ComputeRootStepNode);
+/*! \brief Compute inline step that corresponds to te::Stage::compute_inline */
 class ComputeInlineStepNode: public StepNode {
  public:
   static ComputeInlineStep make(int stage_id);
@@ -319,31 +254,9 @@ class ComputeInlineStepNode: public StepNode {
   static constexpr const char* _type_key = "ansor.ComputeInlineStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, Object);
 };
-TVM_DEFINE_COW_NODE_REF(ComputeInlineStep, Step, ComputeInlineStepNode);
+TVM_DEFINE_COW_OBJECT_REF(ComputeInlineStep, Step, ComputeInlineStepNode);
 
-class PackForVecStepNode: public StepNode {
- public:
-  int iter_id;
-  int vec_size;
-
-  static PackForVecStep make(int stage_id, int iter_id, int vec_size);
-
-  void ApplyToSchedule(std::vector *stages,
-                       StageToAxesMap *stage_to_axes, te::Schedule *schedule) const;
-
-  std::string PrintAsPythonAPI(std::vector *stages,
-                               StageToAxesMap *stage_to_axes,
-                               te::Schedule *schedule,
-                               const std::vector& transform_steps) const final;
-
-  static constexpr const char* _type_key = "ansor.PackForVecStep";
-  TVM_DECLARE_FINAL_OBJECT_INFO(PackForVecStepNode, Object);
-};
-TVM_DEFINE_COW_NODE_REF(PackForVecStep, Step, PackForVecStepNode);
-
-
-/*! \brief Apply cache_read to a stage
- * TVM Api: te::Schedule::cache_read(tensor, scope, readers) */
+/*! \brief Cache read step that corresponds to te::Schedule::cache_read */
 class CacheReadStepNode: public StepNode {
  public:
   std::string scope_name;
@@ -363,12 +276,10 @@ class CacheReadStepNode: public StepNode {
   static constexpr const char* _type_key = "ansor.CacheReadStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, Object);
 };
-TVM_DEFINE_COW_NODE_REF(CacheReadStep, Step, CacheReadStepNode);
+TVM_DEFINE_COW_OBJECT_REF(CacheReadStep, Step, CacheReadStepNode);
 
-
-/*! \brief Apply cache_write to a stage
- * TVM Api: te::Schedule::cache_write(tensor, scope)
- * This step will cache_write all output tensors of target stage */
+/*! \brief Cache write step that corresponds to te::Schedule::cache_write
+ * \note This step will cache_write all output tensors of the target stage */
 class CacheWriteStepNode: public StepNode {
  public:
   std::string scope_name;
@@ -386,9 +297,9 @@ class CacheWriteStepNode: public StepNode {
   static constexpr const char* _type_key = "ansor.CacheWriteStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, Object);
 };
-TVM_DEFINE_COW_NODE_REF(CacheWriteStep, Step, CacheWriteStepNode);
+TVM_DEFINE_COW_OBJECT_REF(CacheWriteStep, Step, CacheWriteStepNode);
 
-/*! \brief Add pragma to a specific iterator */
+/*! \brief Pragma step that corresponds to te::Schedule::pragma */
 class PragmaStepNode: public StepNode {
  public:
   int iter_id;
@@ -407,10 +318,9 @@ class PragmaStepNode: public StepNode {
   static constexpr const char* _type_key = "ansor.PragmaStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, Object);
 };
-TVM_DEFINE_COW_NODE_REF(PragmaStep, Step, PragmaStepNode);
+TVM_DEFINE_COW_OBJECT_REF(PragmaStep, Step, PragmaStepNode);
 
-/*! \brief Factor a reduction axis
- * TVM Api: te::Schedule::rfactor(tensor, axis, factor_axis) */
+/*! \brief Reduction factor step that corresponds to te::Schedule::rfactor */
 class RfactorStepNode: public StepNode {
  public:
   int iter_id;
@@ -430,8 +340,9 @@ class RfactorStepNode: public StepNode {
   static constexpr const char* _type_key = "ansor.RfactorStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, Object);
 };
-TVM_DEFINE_COW_NODE_REF(RfactorStep, Step, RfactorStepNode);
+TVM_DEFINE_COW_OBJECT_REF(RfactorStep, Step, RfactorStepNode);
+/*!
\brief Storage align step that corresponds to te::Schedule::storage_align */ class StorageAlignStepNode: public StepNode { public: int iter_id; @@ -452,12 +363,12 @@ class StorageAlignStepNode: public StepNode { static constexpr const char* _type_key = "ansor.StorageAlignStep"; TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, Object); }; -TVM_DEFINE_COW_NODE_REF(StorageAlignStep, Step, StorageAlignStepNode); +TVM_DEFINE_COW_OBJECT_REF(StorageAlignStep, Step, StorageAlignStepNode); } // namespace ansor } // namespace tvm -// Hash and equal function for State, Stage, Iterator and Step +// Hash and equal function for Step namespace std { template <> @@ -515,32 +426,27 @@ struct hash<::tvm::ansor::Step> { } else if (auto ps = step.as<::tvm::ansor::ComputeInlineStepNode>()) { return ::dmlc::HashCombine(9, ps->stage_id); - } else if (auto ps = step.as<::tvm::ansor::PackForVecStepNode>()) { - return ::dmlc::HashCombine(10, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - ps->vec_size))); } else if (auto ps = step.as<::tvm::ansor::CacheReadStepNode>()) { - return ::dmlc::HashCombine(11, + return ::dmlc::HashCombine(10, ::dmlc::HashCombine(std::hash()(ps->stage_id), ::dmlc::HashCombine(std::hash()(ps->scope_name), ps->reader_stage_ids))); } else if (auto ps = step.as<::tvm::ansor::CacheWriteStepNode>()) { - return ::dmlc::HashCombine(12, + return ::dmlc::HashCombine(11, ::dmlc::HashCombine(std::hash()(ps->stage_id), ps->scope_name)); } else if (auto ps = step.as<::tvm::ansor::PragmaStepNode>()) { - return ::dmlc::HashCombine(13, + return ::dmlc::HashCombine(12, ::dmlc::HashCombine(std::hash()(ps->stage_id), ::dmlc::HashCombine(std::hash()(ps->iter_id), ps->pragma_type))); } else if (auto ps = step.as<::tvm::ansor::RfactorStepNode>()) { - return ::dmlc::HashCombine(14, + return ::dmlc::HashCombine(13, ::dmlc::HashCombine(std::hash()(ps->stage_id), ::dmlc::HashCombine(std::hash()(ps->iter_id), ps->factor_iter_id))); } else if (auto ps = step.as<::tvm::ansor::StorageAlignStepNode>()) { - return ::dmlc::HashCombine(15, + return ::dmlc::HashCombine(14, ::dmlc::HashCombine(std::hash()(ps->stage_id), ::dmlc::HashCombine(std::hash()(ps->iter_id), ::dmlc::HashCombine(std::hash()(ps->factor), diff --git a/src/ansor/utils.cc b/src/ansor/utils.cc index 2018cf33d1a2..27aac7e8b315 100644 --- a/src/ansor/utils.cc +++ b/src/ansor/utils.cc @@ -1,5 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! 
- * Copyright (c) 2020 by Contributors + * \file ansor/utils.cc + * \brief Common utilities */ #include "utils.h" @@ -8,7 +28,6 @@ namespace tvm { namespace ansor { - NullStream& NullStream::Global() { static NullStream stream; return stream; diff --git a/src/ansor/utils.h b/src/ansor/utils.h index 67ebb836c680..cb90364b01b5 100644 --- a/src/ansor/utils.h +++ b/src/ansor/utils.h @@ -1,5 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! - * Copyright (c) 2020 by Contributors * \file ansor/utils.h * \brief Common utilities */ @@ -25,7 +43,7 @@ namespace std { -// hash function for std::pair, std::vector and std::tuple +/*! \brief Hash function for std::pair */ template struct hash > { std::size_t operator()(const std::pair& k) const { @@ -33,6 +51,7 @@ struct hash > { } }; +/*! \brief Hash function for std::tuple */ template struct hash > { std::size_t operator()(const std::tuple& k) const { @@ -42,6 +61,7 @@ struct hash > { } }; +/*! \brief Hash function for std::vector */ template struct hash > { std::size_t operator()(const std::vector& vec) const { @@ -61,38 +81,37 @@ struct hash > { namespace tvm { namespace ansor { -/*! \brief Macro to make it easy to define node ref type given node */ -#define TVM_DEFINE_NODE_REF(TypeName, NodeName) \ +/*! \brief Macro to make it easy to define object ref type given node */ +#define TVM_DEFINE_OBJECT_REF(TypeName, ObjectName) \ class TypeName : public ObjectRef { \ public: \ - TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ObjectRef, NodeName); \ + TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ObjectRef, ObjectName); \ }; \ -/*! \brief Macro to make it easy to define mutable node ref type given node */ -#define TVM_DEFINE_MUTABLE_NODE_REF(TypeName, NodeName) \ +/*! \brief Macro to make it easy to define mutable object ref type given node */ +#define TVM_DEFINE_MUTABLE_OBJECT_REF(TypeName, ObjectName) \ class TypeName : public ObjectRef { \ public: \ - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ObjectRef, NodeName); \ + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ObjectRef, ObjectName); \ }; \ /*! * \brief Macro to make it easy to define node ref type that * has a CopyOnWrite member function. 
*/ -#define TVM_DEFINE_COW_NODE_REF(TypeName, BaseType, NodeName) \ - class TypeName : public BaseType { \ - public: \ - TVM_DEFINE_OBJECT_REF_METHODS(TypeName, BaseType, NodeName); \ - TVM_DEFINE_OBJECT_REF_COW_METHOD(NodeName); \ +#define TVM_DEFINE_COW_OBJECT_REF(TypeName, BaseType, ObjectName) \ + class TypeName : public BaseType { \ + public: \ + TVM_DEFINE_OBJECT_REF_METHODS(TypeName, BaseType, ObjectName); \ + TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName); \ }; -/********** Utilities for std::vector, std::set **********/ - +/********** Utilities for std::vector, std::set, std::string **********/ /*! \brief Get the first appearance index of elements in a vector */ template inline void GetIndices(const std::vector& array, - const std::vector& to_locate, - std::vector* indices) { + const std::vector& to_locate, + std::vector* indices) { for (const auto& v : to_locate) { auto it = std::find(array.begin(), array.end(), v); if (it != array.end()) { @@ -133,7 +152,7 @@ inline int64_t ElementProduct(const std::vector& array) { return ret; } -/* \brief Get the maximum element in a vector */ +/*! \brief Get the maximum element in a vector */ template T MaximumElement(const std::vector& array) { CHECK(!array.empty()); @@ -162,7 +181,7 @@ std::vector& ConcatenateMove(std::vector* out, std::vector* first, Args return *out; } -/* \brief Get a random permutation of integers [0, n-1] */ +/*! \brief Get a random permutation of integers [0, n-1] */ template void RandomPermutation(int n, std::vector* out, G* gen) { out->assign(n, 0); @@ -170,7 +189,7 @@ void RandomPermutation(int n, std::vector* out, G* gen) { std::shuffle(out->begin(), out->end(), *gen); } -/* \brief Random sample without replacement */ +/*! \brief Random sample without replacement */ template void RandomSample(std::vector* in_data, size_t out_size, G* gen) { // Note: This function is inefficient in the cases when out_size << in_data.size() @@ -204,43 +223,19 @@ inline void Argsort(const std::vector& scores, std::vector* index) { std::sort(index->begin(), index->end(), cmp); } -// Do x++ for all x in the set such that x >= threshold -inline void SetAddOne(std::set* set, int threshold = 0) { - std::set new_set; - for (int x : *set) { - if (x >= threshold) { - new_set.insert(x + 1); - } else { - new_set.insert(x); - } - } - *set = std::move(new_set); -} - -// Compute Jaccard Similarity of two sets -template -double JaccardSimilarity(std::set s1, std::set s2) { - std::vector intersect; - std::set_intersection(s1.begin(), s1.end(), s2.begin(), s2.end(), - std::back_inserter(intersect)); - return 1.0 * intersect.size() / (s1.size() + s2.size() - intersect.size()); -} - -/********** Utilities for std::string **********/ - -/*! Return whether a string ends with a another substring */ +/*! \brief Return whether a string ends with another substring */ inline bool StrEndsWith(const std::string& a, const std::string& b) { if (b.size() > a.size()) return false; return std::equal(a.begin() + a.size() - b.size(), a.end(), b.begin()); } -/*! Return whether a string starts with a another substring */ +/*! \brief Return whether a string starts with another substring */ inline bool StrStartsWith(const std::string& a, const std::string& b) { if (b.size() > a.size()) return false; return std::equal(a.begin(), a.begin() + b.size(), b.begin()); } -/*! Replace a sub-string to another sub-string in a string */ +/*! 
\brief Replace a sub-string with another sub-string in a string */
 inline void StrReplace(std::string* base, const std::string& from, const std::string& to) {
   auto pos = base->find(from);
   while (pos != std::string::npos) {
@@ -250,7 +245,6 @@ inline void StrReplace(std::string* base, const std::st
 }
 
 /********** Utilities for TVM Containers / ByteArray **********/
-
 /*! \brief Compute mean of a FloatImm array */
 inline double FloatArrayMean(const Array& float_array) {
   double sum = 0;
@@ -266,51 +260,15 @@ inline double FloatArrayMean(const Array& float_array) {
   return sum / float_array.size();
 }
 
-/*! \brief Serialize a 2-dimensional vector to TVMByteArray.
- * This is used for sending data to python code */
-template
-inline TVMByteArray Serialize2dVector(std::vector >&& in_data,
-                                      std::vector* out_data) {
-  size_t total_bytes = 0;
-  std::vector size_vector;
-
-  // serialize sizes
-  total_bytes += (1 + in_data.size()) * sizeof(int);
-  size_vector.reserve(in_data.size() + 1);
-  size_vector.push_back(in_data.size());
-  for (const auto& x : in_data) {
-    size_vector.push_back(static_cast(x.size()));
-    total_bytes += sizeof(T) * x.size();
-  }
-
-  out_data->reserve(total_bytes);
-  char* ptr = out_data->data();
-  memmove(ptr, reinterpret_cast(size_vector.data()), (1 + in_data.size()) * sizeof(int));
-  ptr += (1 + in_data.size()) * sizeof(int);
-
-  // serialize in_data
-  for (auto& x : in_data) {
-    memmove(ptr, x.data(), sizeof(T) * x.size());
-    ptr += sizeof(T) * x.size();
-    x.clear();
-  }
-
-  CHECK_EQ(ptr - out_data->data(), total_bytes);
-
-  return TVMByteArray{out_data->data(), total_bytes};
-}
-
-/********** Other Utilities **********/
-
-// Get an int value from an Expr
+/********** Other Utilities **********/
+/*! \brief Get an int value from an Expr */
 inline int64_t GetIntImm(const PrimExpr& expr) {
   auto pint = expr.as();
   CHECK(pint != nullptr);
   return pint->value;
 }
-
-// Compute the product of the lengths of axes
+/*! \brief Compute the product of the lengths of axes */
 inline int64_t AxisLengthProd(const Array& axes) {
   int64_t ret = 1.0;
   for (const auto& x : axes) {
@@ -323,8 +281,7 @@ inline int64_t AxisLengthProd(const Array& axes) {
   return ret;
 }
-
-// An empty output stream
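+// For example (illustrative only): for three axes with constant extents
+// {4, 8, 32}, AxisLengthProd returns 4 * 8 * 32 = 1024; if any extent is not
+// a constant IntImm, it returns -1 instead.
+
+/*!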
\brief An empty output stream */ class NullStream : public std::ostream { public: NullStream() : std::ostream(nullptr) {} diff --git a/tests/cpp/ansor_test.cc b/tests/cpp/ansor_test.cc index 00e748204fde..5f1dea0f1ea5 100644 --- a/tests/cpp/ansor_test.cc +++ b/tests/cpp/ansor_test.cc @@ -21,28 +21,13 @@ #include #include #include - +#include #include -#include "../../src/ansor/feature.h" +// todo(jcf94): do not use relative path #include "../../src/ansor/loop_state.h" -#include "../../src/ansor/search_policy/meta_tile_rewrite_policy.h" -#include "../../src/ansor/serialization.h" - -tvm::Array matmul_func(int n, int m, int k) { - using namespace tvm; - using namespace tvm::te; - - Tensor A = placeholder({n, k}, DataType::Float(32), "A"); - Tensor B = placeholder({k, m}, DataType::Float(32), "B"); - IterVar K = IterVarNode::make({0, k}, Var("k"), kCommReduce); - const auto& C = compute( - {n, m}, [&](Var i, Var j) { return tvm::sum(A[i][K] * B[K][j], {K}); }, - "C"); - - return {A, B, C}; -} +// Compute declaration for test tvm::Array conv2d_nchw_bn_relu_func(int N, int H, int W, int CI, int CO, int kernel_size, @@ -91,17 +76,7 @@ tvm::Array conv2d_nchw_bn_relu_func(int N, int H, int W, using namespace tvm::ansor; -TEST(ComputeDAG, Basic) { - const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3); - const auto& dag = ComputeDAGNode::make(tensors); - const auto& state = StateNode::make(dag->ops); - CHECK(std::equal_to()(state, dag.GetInitState())); - - LOG(INFO) << "\n" << state; - LOG(INFO) << "\n" << dag; - LOG(INFO) << "\n" << dag->access_analyzer; -} - +// Test Access Analyzer TEST(ComputeDAG, GetProducersConsumers) { const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3); const auto& dag = tvm::ansor::ComputeDAGNode::make(tensors); @@ -166,570 +141,6 @@ TEST(ComputeDAG, GetProducersConsumers) { } } -TEST(ComputeDAG, InferBoundSerialization) { - const auto& tensors = matmul_func(512, 512, 512); - const auto& dag = ComputeDAGNode::make(tensors); - int A = 0, B = 1, C = 2; - - State s0 = dag.GetInitState(); - int C_global = s0.cache_write(C, "global", dag); - C++; - const auto& its0 = s0.split(C, s0->stages[C]->iters[0], {4, 8, 8}); - const auto& its1 = s0.split(C, s0->stages[C]->iters[4], {8, 4, 4}); - s0.reorder(C, {its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], - its1[3]}); - s0.compute_at(C_global, C, s0->stages[C]->iters[3]); - s0.split(C_global, s0->stages[C_global]->iters[2], {16}); - int B_global = s0.cache_read(B, "global", {C_global}, dag); - C++; - C_global++; - s0.compute_at(B_global, C_global, s0->stages[C_global]->iters[0]); - int A_global = s0.cache_read(A, "global", {C_global}, dag); - B++; - B_global++; - C++; - C_global++; - s0.compute_at(A_global, C_global, s0->stages[C_global]->iters[2]); - - const auto& s1 = dag.InferBound(s0); - std::vector s2 = {s0}; - dag.InferBound(&s2); - const auto& s3 = dag.ReplayAndInferBound(s0->transform_steps); - - CHECK_EQ( - s1->stages[B_global]->iters[0]->range->extent.as()->value, - 512); - CHECK_EQ( - s1->stages[B_global]->iters[1]->range->extent.as()->value, - 16); - CHECK_EQ( - s1->stages[A_global]->iters[0]->range->extent.as()->value, 1); - CHECK_EQ( - s1->stages[A_global]->iters[1]->range->extent.as()->value, - 16); - CHECK_EQ( - s1->stages[C_global]->iters[0]->range->extent.as()->value, - 64); - CHECK(std::equal_to()(s1, s2[0])); - CHECK(std::equal_to()(s1, s3)); - - const auto& minp0 = MeasureInputNode::make( - SearchTaskNode::make(dag, "test", tvm::target::llvm(), - 
tvm::target::llvm(), HardwareParams()),
-      s0);
-  const auto& mres0 = MeasureResultNode::make({0.1}, 0, "", 0.1, 0.1);
-  std::stringstream ss;
-  WriteMeasureRecords(&ss, {minp0}, {mres0});
-  auto minp1 = tvm::make_object<MeasureInputNode>();
-  auto mres1 = tvm::make_object<MeasureResultNode>();
-  std::string log_version;
-  ReadMeasureRecords(ss.str(), minp1.get(), mres1.get(), &log_version);
-  const auto& s4 = dag.ReplayAndInferBound(minp1->state->transform_steps);
-  CHECK(std::equal_to<State>()(s1, s4));
-}
-
-TEST(Step, SplitFuseReorder) {
-  const auto& tensors = matmul_func(512, 512, 512);
-  const auto& dag = ComputeDAGNode::make(tensors);
-
-  State s0 = dag.GetInitState();
-  State s1 = s0;
-  Iterator ti = s0->stages[2]->iters[0];
-  Iterator tj = s0->stages[2]->iters[1];
-  Iterator tk = s0->stages[2]->iters[2];
-  std::vector<Iterator> its;
-
-  CHECK_EQ(s1->stages[2]->iters[0]->range->extent.as<IntImmNode>()->value, 512);
-
-  its = s0.split(2, ti, {16});
-  Iterator tio = its[0], tii = its[1];
-  CHECK_EQ(s0->stages[2]->iters[0]->range->extent.as<IntImmNode>()->value, 32);
-  CHECK_EQ(s0->stages[2]->iters[1]->range->extent.as<IntImmNode>()->value, 16);
-
-  its = s0.split(2, tj, {8});
-  Iterator tjo = its[0], tji = its[1];
-  CHECK_EQ(s0->stages[2]->iters[2]->range->extent.as<IntImmNode>()->value, 64);
-  CHECK_EQ(s0->stages[2]->iters[3]->range->extent.as<IntImmNode>()->value, 8);
-
-  s0.reorder(2, {tio, tjo, tk, tji, tii});
-  CHECK_EQ(s0->stages[2]->iters[2]->range->extent.as<IntImmNode>()->value, 512);
-
-  s0.fuse(2, {tio, tjo});
-  CHECK_EQ(s0->stages[2]->iters[0]->range->extent.as<IntImmNode>()->value,
-           2048);
-
-  s1.split(2, ti, {8, 2});
-  s1.split(2, tj, {32, 8}, false);
-  CHECK_EQ(s1->stages[2]->iters[0]->range->extent.as<IntImmNode>()->value, 32);
-  CHECK_EQ(s1->stages[2]->iters[1]->range->extent.as<IntImmNode>()->value, 8);
-  CHECK_EQ(s1->stages[2]->iters[2]->range->extent.as<IntImmNode>()->value, 2);
-  CHECK_EQ(s1->stages[2]->iters[3]->range->extent.as<IntImmNode>()->value, 32);
-  CHECK_EQ(s1->stages[2]->iters[4]->range->extent.as<IntImmNode>()->value, 8);
-  CHECK_EQ(s1->stages[2]->iters[5]->range->extent.as<IntImmNode>()->value, 2);
-}
-
-TEST(Step, ComputeAtRootInline) {
-  const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3);
-  const auto& dag = tvm::ansor::ComputeDAGNode::make(tensors);
-  // int data = 0, padding = 1, kernel = 2;
-  int conv = 3;
-  // int bias = 4;
-  int bias_add = 5;
-  // int bn_scale = 6;
-  int bn_mul = 7;
-  // int bn_offset = 8;
-  int bn_add = 9, relu = 10;
-
-  State s0 = dag.GetInitState();
-  s0.compute_inline(bn_add);
-  s0.compute_inline(bn_mul);
-  s0.compute_inline(bias_add);
-  s0.compute_at(conv, relu, s0->stages[relu]->iters[2]);
-  const auto& conv_stage_attach =
-      s0->attach_map->stage_to_attach_iter.find(conv);
-  std::pair<int, int> iterkey(relu, 2);
-  CHECK(conv_stage_attach->second == iterkey);
-  const auto& conv_iter_attach =
-      s0->attach_map->iter_to_attached_stages.find(iterkey);
-  CHECK_EQ(conv_iter_attach->second.size(), 1);
-  CHECK_EQ(conv_iter_attach->second[0], conv);
-  std::stringstream ss;
-  ss << "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n"
-     << "for ax1 (0,3)\n"
-     << "  for ax2 (0,230)\n"
-     << "    for ax3 (0,230)\n"
-     << "      T_pad = ...\n"
-     << "for ax1 (0,64)\n"
-     << "  for ax2 (0,112)\n"
-     << "    for ax0 (None)\n"
-     << "      for ax1 (None)\n"
-     << "        for ax2 (None)\n"
-     << "          for ax3 (None)\n"
-     << "            for i (None)\n"
-     << "              for kh (None)\n"
-     << "                for kw (None)\n"
-     << "                  T_conv2d_nchw = ...\n"
-     << "    for ax3 (0,112)\n"
-     << "      T_relu = ...\n";
-  CHECK_EQ(s0.ToStr().compare(ss.str()), 0);
-
-  s0.compute_root(conv);
-  s0.compute_root(bn_mul);
-  CHECK_EQ(s0->attach_map->stage_to_attach_iter.size(), 0);
-  CHECK_EQ(s0->attach_map->iter_to_attached_stages.size(), 0);
-  ss.str(std::string());
-  ss << "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n"
-     << "for ax1 (0,3)\n"
-     << "  for ax2 (0,230)\n"
-     << "    for ax3 (0,230)\n"
-     << "      T_pad = ...\n"
-     << "for ax0 (None)\n"
-     << "  for ax1 (None)\n"
-     << "    for ax2 (None)\n"
-     << "      for ax3 (None)\n"
-     << "        for i (None)\n"
-     << "          for kh (None)\n"
-     << "            for kw (None)\n"
-     << "              T_conv2d_nchw = ...\n"
-     << "for ax0 (None)\n"
-     << "  for ax1 (None)\n"
-     << "    for ax2 (None)\n"
-     << "      for ax3 (None)\n"
-     << "        Bn_mul = ...\n"
-     << "for ax1 (0,64)\n"
-     << "  for ax2 (0,112)\n"
-     << "    for ax3 (0,112)\n"
-     << "      T_relu = ...\n";
-  CHECK_EQ(s0.ToStr().compare(ss.str()), 0);
-}
-
-TEST(Step, CacheReadWrite) {
-  using namespace tvm;
-  using namespace tvm::te;
-
-  const auto& test_func = []() -> Array<Tensor> {
-    int N = 4, H = 7, W = 7, CO = 512, CI = 512, KH = 3, KW = 3, stride = 1;
-    int padding = 1;
-    Tensor data = placeholder({N, CI, H, W}, DataType::Float(32), "Data");
-    Tensor kernel_data =
-        placeholder({CO, CI, KH, KW}, DataType::Float(32), "Kernel_data");
-    const auto& k_split = compute(
-        kernel_data->shape,
-        [&](const Array<Var>& i) {
-          return Array<PrimExpr>({kernel_data[i[0]][i[1]][i[2]][i[3]] + 1,
-                                  div(kernel_data[i[0]][i[1]][i[2]][i[3]], 2)});
-        },
-        "Kernel_split");
-    const auto& kernel = compute(
-        kernel_data->shape,
-        [&](Var i, Var j, Var k, Var l) {
-          return (k_split[0])[i][j][k][l] + (k_split[1])[i][j][k][l];
-        },
-        "Kernel");
-    const auto& conv =
-        topi::conv2d_nchw(data, kernel, padding, padding, stride, stride);
-    const auto& relu = topi::relu(conv);
-    const auto& out = compute(
-        relu->shape,
-        [&](Var i, Var j, Var k, Var l) {
-          return data[i][j][k][l] + relu[i][j][k][l];
-        },
-        "Add");
-    return {data, kernel_data, out};
-  };
-  const auto& dag = ComputeDAGNode::make(test_func());
-
-  int data = 0, pad_temp = 1, kernel_data = 2, kernel_split = 3, kernel = 4;
-  int conv = 5, relu = 6, add = 7;
-
-  // 0: init state
-  auto s0 = dag.GetInitState();
-  std::vector<Iterator> ori_its = s0->stages[add]->iters;
-  auto its = s0.split(add, s0->stages[add]->iters[0], {2});
-  s0.reorder(add, {its[0], ori_its[1], its[1], ori_its[2], ori_its[3]});
-  s0.compute_inline(relu);
-
-  // 1: simple cache_write with compute_at
-  int conv_global = s0.cache_write(conv, "global", dag);
-  conv++;
-  relu++;
-  add++;
-  s0.compute_at(conv_global, conv, s0->stages[conv]->iters[3]);
-
-  // 2: simple cache_read with compute_at
-  int kernel_global = s0.cache_read(kernel, "global", {conv_global}, dag);
-  conv_global++;
-  conv++;
-  relu++;
-  add++;
-  s0.compute_at(kernel_global, conv_global, s0->stages[conv_global]->iters[4]);
-  std::stringstream ss;
-  ss << "Placeholder: Data, Kernel_data\n"
-     << "for ax0 (0,4)\n"
-     << "  for ax1 (0,512)\n"
-     << "    for ax2 (0,9)\n"
-     << "      for ax3 (0,9)\n"
-     << "        T_pad = ...\n"
-     << "for ax0 (0,512)\n"
-     << "  for ax1 (0,512)\n"
-     << "    for ax2 (0,3)\n"
-     << "      for ax3 (0,3)\n"
-     << "        Kernel_split = ...\n"
-     << "for ax0 (0,512)\n"
-     << "  for ax1 (0,512)\n"
-     << "    for ax2 (0,3)\n"
-     << "      for ax3 (0,3)\n"
-     << "        Kernel = ...\n"
-     << "for ax0 (0,4)\n"
-     << "  for ax1 (0,512)\n"
-     << "    for ax2 (0,7)\n"
-     << "      for ax3 (0,7)\n"
-     << "        for ax0_c (None)\n"
-     << "          for ax1_c (None)\n"
-     << "            for ax2_c (None)\n"
-     << "              for ax3_c (None)\n"
-     << "                for i (None)\n"
-     << "                  for ax0 (None)\n"
-     << "                    for ax1 (None)\n"
-     << "                      for ax2 (None)\n"
-     << "                        for ax3 (None)\n"
-     << "                          Kernel.global = ...\n"
-     << "                  for kh (None)\n"
-     << "                    for kw (None)\n"
-     << "                      T_conv2d_nchw.global = ...\n"
-     << "        T_conv2d_nchw = ...\n"
-     << "for ax0.0 (0,2)\n"
-     << "  for ax1 (0,512)\n"
-     << "    for ax0.1 (0,2)\n"
-     << "      for ax2 (0,7)\n"
-     << "        for ax3 (0,7)\n"
-     << "          Add = ...\n";
-  CHECK_EQ(s0.ToStr().compare(ss.str()), 0);
-
-  // 3: two level cache_read with compute_at
-  // preparing for GPU's shared memory & local memory
-  int pad_temp_global = s0.cache_read(pad_temp, "global", {conv_global}, dag);
-  kernel_data++;
-  kernel_split++;
-  kernel++;
-  kernel_global++;
-  conv_global++;
-  conv++;
-  relu++;
-  add++;
-  int pad_temp_shared =
-      s0.cache_read(pad_temp_global, "shared", {conv_global}, dag);
-  kernel_data++;
-  kernel_split++;
-  kernel++;
-  kernel_global++;
-  conv_global++;
-  conv++;
-  relu++;
-  add++;
-  s0.compute_at(pad_temp_global, conv_global,
-                s0->stages[conv_global]->iters[2]);
-  s0.compute_at(pad_temp_shared, conv_global,
-                s0->stages[conv_global]->iters[4]);
-
-  // 4: cache_read with multi readers
-  // This stage cannot be computed at its consumer
-  s0.cache_read(data, "global", {pad_temp, add}, dag);
-  pad_temp++;
-  pad_temp_global++;
-  pad_temp_shared++;
-  kernel_data++;
-  kernel_split++;
-  kernel++;
-  kernel_global++;
-  conv_global++;
-  conv++;
-  relu++;
-  add++;
-  ss.str(std::string());
-  ss << "Placeholder: Data, Kernel_data\n"
-     << "for ax0 (0,4)\n"
-     << "  for ax1 (0,512)\n"
-     << "    for ax2 (0,7)\n"
-     << "      for ax3 (0,7)\n"
-     << "        Data.global = ...\n"
-     << "for ax0 (0,4)\n"
-     << "  for ax1 (0,512)\n"
-     << "    for ax2 (0,9)\n"
-     << "      for ax3 (0,9)\n"
-     << "        T_pad = ...\n"
-     << "for ax0 (0,512)\n"
-     << "  for ax1 (0,512)\n"
-     << "    for ax2 (0,3)\n"
-     << "      for ax3 (0,3)\n"
-     << "        Kernel_split = ...\n"
-     << "for ax0 (0,512)\n"
-     << "  for ax1 (0,512)\n"
-     << "    for ax2 (0,3)\n"
-     << "      for ax3 (0,3)\n"
-     << "        Kernel = ...\n"
-     << "for ax0 (0,4)\n"
-     << "  for ax1 (0,512)\n"
-     << "    for ax2 (0,7)\n"
-     << "      for ax3 (0,7)\n"
-     << "        for ax0_c (None)\n"
-     << "          for ax1_c (None)\n"
-     << "            for ax2_c (None)\n"
-     << "              for ax0 (None)\n"
-     << "                for ax1 (None)\n"
-     << "                  for ax2 (None)\n"
-     << "                    for ax3 (None)\n"
-     << "                      T_pad.global = ...\n"
-     << "              for ax3_c (None)\n"
-     << "                for i (None)\n"
-     << "                  for ax0 (None)\n"
-     << "                    for ax1 (None)\n"
-     << "                      for ax2 (None)\n"
-     << "                        for ax3 (None)\n"
-     << "                          Kernel.global = ...\n"
-     << "                  for ax0 (None)\n"
-     << "                    for ax1 (None)\n"
-     << "                      for ax2 (None)\n"
-     << "                        for ax3 (None)\n"
-     << "                          T_pad.global.shared = ...\n"
-     << "                  for kh (None)\n"
-     << "                    for kw (None)\n"
-     << "                      T_conv2d_nchw.global = ...\n"
-     << "        T_conv2d_nchw = ...\n"
-     << "for ax0.0 (0,2)\n"
-     << "  for ax1 (0,512)\n"
-     << "    for ax0.1 (0,2)\n"
-     << "      for ax2 (0,7)\n"
-     << "        for ax3 (0,7)\n"
-     << "          Add = ...\n";
-  CHECK_EQ(s0.ToStr().compare(ss.str()), 0);
-
-  // 5: cache_write with multi outputs
-  // TVM's cache_write actually has a bug with this case:
-
-  // After schedule.cache_write, TVM generates one new stage:
-  //   From: kernel_data -> kernel_split -> kernel
-  //   To:   kernel_data -> kernel_split_global -> kernel_split -> kernel
-
-  // But with topo-sort analysis, we get:
-  //   kernel_data -> kernel_split_global -> kernel_split -> kernel
-  //        \                                                /
-  //         ----------------> kernel_split --------------->
-
-  // Seems there's a bug with the input/output tensors. Such multi-output cases
-  // should be unusual, so we make some hack in DoCacheWrite.
-  // To be fixed in the future.
-  s0.cache_write(kernel_split, "global", dag);
-  ss.str(std::string());
-  ss << "Placeholder: Data, Kernel_data\n"
-     << "for ax0 (0,4)\n"
-     << "  for ax1 (0,512)\n"
-     << "    for ax2 (0,7)\n"
-     << "      for ax3 (0,7)\n"
-     << "        Data.global = ...\n"
-     << "for ax0 (0,4)\n"
-     << "  for ax1 (0,512)\n"
-     << "    for ax2 (0,9)\n"
-     << "      for ax3 (0,9)\n"
-     << "        T_pad = ...\n"
-     << "for ax0_c (0,512)\n"
-     << "  for ax1_c (0,512)\n"
-     << "    for ax2_c (0,3)\n"
-     << "      for ax3_c (0,3)\n"
-     << "        Kernel_split.global = ...\n"
-     << "for ax0 (0,512)\n"
-     << "  for ax1 (0,512)\n"
-     << "    for ax2 (0,3)\n"
-     << "      for ax3 (0,3)\n"
-     << "        Kernel_split = ...\n"
-     << "for ax0 (0,512)\n"
-     << "  for ax1 (0,512)\n"
-     << "    for ax2 (0,3)\n"
-     << "      for ax3 (0,3)\n"
-     << "        Kernel_split = ...\n"
-     << "for ax0 (0,512)\n"
-     << "  for ax1 (0,512)\n"
-     << "    for ax2 (0,3)\n"
-     << "      for ax3 (0,3)\n"
-     << "        Kernel = ...\n"
-     << "for ax0 (0,4)\n"
-     << "  for ax1 (0,512)\n"
-     << "    for ax2 (0,7)\n"
-     << "      for ax3 (0,7)\n"
-     << "        for ax0_c (None)\n"
-     << "          for ax1_c (None)\n"
-     << "            for ax2_c (None)\n"
-     << "              for ax0 (None)\n"
-     << "                for ax1 (None)\n"
-     << "                  for ax2 (None)\n"
-     << "                    for ax3 (None)\n"
-     << "                      T_pad.global = ...\n"
-     << "              for ax3_c (None)\n"
-     << "                for i (None)\n"
-     << "                  for ax0 (None)\n"
-     << "                    for ax1 (None)\n"
-     << "                      for ax2 (None)\n"
-     << "                        for ax3 (None)\n"
-     << "                          Kernel.global = ...\n"
-     << "                  for ax0 (None)\n"
-     << "                    for ax1 (None)\n"
-     << "                      for ax2 (None)\n"
-     << "                        for ax3 (None)\n"
-     << "                          T_pad.global.shared = ...\n"
-     << "                  for kh (None)\n"
-     << "                    for kw (None)\n"
-     << "                      T_conv2d_nchw.global = ...\n"
-     << "        T_conv2d_nchw = ...\n"
-     << "for ax0.0 (0,2)\n"
-     << "  for ax1 (0,512)\n"
-     << "    for ax0.1 (0,2)\n"
-     << "      for ax2 (0,7)\n"
-     << "        for ax3 (0,7)\n"
-     << "          Add = ...\n";
-  CHECK_EQ(s0.ToStr().compare(ss.str()), 0);
-}
-
-TEST(Step, FollowSplitFollowFusedSplit) {
-  const auto& tensors = matmul_func(512, 512, 512);
-  const auto& dag = ComputeDAGNode::make(tensors);
-
-  State s0 = dag.GetInitState();
-  int C = 2;
-
-  int C_global = s0.cache_write(C, "global", dag);
-  C++;
-
-  // FollowSplitStep currently only supports `inner_to_outer = true`
-  const auto& its0 = s0.split(C, s0->stages[C]->iters[0], {4, 2, 8, 4}, true);
-  int split_step0 = s0->transform_steps.size() - 1;
-  // const auto& its1 = s0.split(C, s0->stages[C]->iters[5], {4, 2, 8, 4},
-  // false); int split_step1 = s0->transform_steps.size() - 1;
-  for (int level = 1; level <= 5; level++) {
-    State tmp = s0;
-    tmp.follow_split(C_global, s0->stages[C_global]->iters[0], split_step0,
-                     level);
-    // tmp.follow_split(C_global, s0->stages[C_global]->iters[5], split_step1,
-    // level);
-    const auto& stage_C = tmp->stages[C];
-    const auto& stage_C_global = tmp->stages[C_global];
-    for (int i = 0; i < level; i++) {
-      CHECK_EQ(stage_C->iters[i]->range->extent.as<IntImmNode>()->value,
-               stage_C_global->iters[i]->range->extent.as<IntImmNode>()->value);
-    }
-    // for (int i = 0; i < level; i++) {
-    //   CHECK(stage_C->iters[i+5]->range->extent.as<IntImmNode>()->value ==
-    //         stage_C_global->iters[i+5]->range->extent.as<IntImmNode>()->value);
-    // }
-  }
-
-  const auto& its1 = s0.split(C, s0->stages[C]->iters[5], {2, 2, 4, 8});
-  int split_step1 = s0->transform_steps.size() - 1;
-  std::vector<Iterator> its;
-  for (int i = 0; i < 5; i++) {
-    its.push_back(its0[i]);
-    its.push_back(its1[i]);
-  }
-  s0.reorder(C, its);
-  for (int i = 0; i < 5; i++) {
-    s0.fuse(C, {s0->stages[C]->iters[i], s0->stages[C]->iters[i + 1]});
-  }
-  for (int level = 0; level < 4; level++) {
-    State tmp = s0;
-    tmp.follow_fused_split(C_global, tmp->stages[C_global]->iters[0],
-                           {split_step0, split_step1}, level, false);
-    const auto& stage_C = tmp->stages[C];
-    const auto& stage_C_global = tmp->stages[C_global];
-    CHECK_EQ(stage_C->iters[level + 1]->range->extent.as<IntImmNode>()->value,
-             stage_C_global->iters[0]->range->extent.as<IntImmNode>()->value);
-  }
-  for (int level = 0; level < 4; level++) {
-    State tmp = s0;
-    tmp.follow_fused_split(C_global, tmp->stages[C_global]->iters[0],
-                           {split_step0, split_step1}, level, true);
-    const auto& stage_C = tmp->stages[C];
-    const auto& stage_C_global = tmp->stages[C_global];
-    CHECK_EQ(stage_C->iters[level + 1]->range->extent.as<IntImmNode>()->value,
-             stage_C_global->iters[1]->range->extent.as<IntImmNode>()->value);
-  }
-}
-
-TEST(Step, Rfactor) {
-  // todo
-}
-
-TEST(Feature, ExtractionMatmul) {
-  const auto& tensors = matmul_func(512, 512, 512);
-  const auto& dag = ComputeDAGNode::make(tensors);
-  State s0 = dag.GetInitState();
-
-  Iterator ti = s0->stages[2]->iters[0];
-  Iterator tj = s0->stages[2]->iters[1];
-  Iterator tk = s0->stages[2]->iters[2];
-  std::vector<Iterator> its;
-  its = s0.split(2, ti, {16});
-  Iterator tio = its[0], tii = its[1];
-  its = s0.split(2, tj, {8});
-  Iterator tjo = its[0], tji = its[1];
-  s0.reorder(2, {tio, tjo, tk, tji, tii});
-  s0.vectorize(2, tji);
-  s0.parallel(2, tio);
-  s0.parallel(2, tjo);
-  s0.unroll(2, tk);
-
-  int max_n_bufs = 5;
-  std::vector<std::vector<float>> features;
-  std::vector<std::string> feature_names;
-  GetPerStmtFeatureName(max_n_bufs, &feature_names);
-  GetPerStmtFeaturesFromStates(
-      {s0},
-      SearchTaskNode::make(dag, "test", tvm::target::llvm(),
-                           tvm::target::llvm(), HardwareParams()),
-      max_n_bufs, 0, &features);
-  int num_states = 1;
-  CHECK_EQ(feature_names.size(), (features[0].size() - 1) / num_states);
-  // TODO(...): Add feature check here
-}
-
 int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
   testing::FLAGS_gtest_death_test_style = "threadsafe";
diff --git a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py
index d701ef5b7bbd..cd8a1eedb162 100644
--- a/tests/python/unittest/test_ansor_common.py
+++ b/tests/python/unittest/test_ansor_common.py
@@ -14,13 +14,11 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import random -import os -import numpy as np -import tvm -from tvm import te -from tvm import ansor +"""Common functions for ansor test cases""" + + +from tvm import te, ansor import topi @@ -59,507 +57,26 @@ def conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation return [data, kernel, bias, bn_offset, bn_scale, out] -def test_compute_dag_basic(): - dag = ansor.ComputeDAG(conv2d_nchw_bn_relu(1, 224, 224, 3, 64, 7, 2, 3)) - - print(dag) - print(dag.access_analyzer) - print(dag.get_init_state()) - - -def test_state_split_fuse_reorder(): - dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) - s0 = dag.get_init_state() - s1 = s0 - ti = s0.stage(2).iterator(0) - tj = s0.stage(2).iterator(1) - tk = s0.stage(2).iterator(2) - - assert ti.range.extent == 512 - - s0, its = s0.split(2, ti, [16]) - tio = its[0] - tii = its[1] - assert s0.stage(2).iterator(0).range.extent == 32 - assert s0.stage(2).iterator(1).range.extent == 16 - - s0, its = s0.split(2, tj, [8]) - tjo = its[0] - tji = its[1] - assert s0.stage(2).iterator(2).range.extent == 64 - assert s0.stage(2).iterator(3).range.extent == 8 - - s0 = s0.reorder(2, [tio, tjo, tk, tji, tii]) - assert s0.stage(2).iterator(2).range.extent == 512 - - s0, res_it = s0.fuse(2, [tio, tjo]) - assert res_it.range.extent == 2048 - - s1, _ = s1.split(2, ti, [8, 2]) - s1, _ = s1.split(2, tj, [32, 8], False) - assert s1.stage(2).iterator(0).range.extent == 32 - assert s1.stage(2).iterator(1).range.extent == 8 - assert s1.stage(2).iterator(2).range.extent == 2 - assert s1.stage(2).iterator(3).range.extent == 32 - assert s1.stage(2).iterator(4).range.extent == 8 - assert s1.stage(2).iterator(5).range.extent == 2 - - -def test_state_compute_at_root_inline(): - dag = ansor.ComputeDAG(conv2d_nchw_bn_relu(1, 224, 224, 3, 64, 7, 2, 3)) - - # data, padding, kernel = 0, 1, 2 - conv = 3 - # bias = 4 - bias_add = 5 - # bn_scale = 6 - bn_mul = 7 - # bn_offset = 8 - bn_add, relu = 9, 10 - - s0 = dag.get_init_state() - s0 = s0.compute_inline(bn_add) - s0 = s0.compute_inline(bn_mul) - s0 = s0.compute_inline(bias_add) - s0 = s0.compute_at(conv, relu, s0.stage(relu).iterator(2)) - assert str(s0) == \ - "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \ - "for i1 (0,3)\n" + \ - " for i2 (0,230)\n" + \ - " for i3 (0,230)\n" + \ - " pad_temp = ...\n" + \ - "for i1 (0,64)\n" + \ - " for i2 (0,112)\n" + \ - " for nn (None)\n" + \ - " for ff (None)\n" + \ - " for yy (None)\n" + \ - " for xx (None)\n" + \ - " for rc (None)\n" + \ - " for ry (None)\n" + \ - " for rx (None)\n" + \ - " compute = ...\n" + \ - " for i3 (0,112)\n" + \ - " compute = ...\n" - - s0 = s0.compute_root(conv) - s0 = s0.compute_root(bn_mul) - assert str(s0) == \ - "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \ - "for i1 (0,3)\n" + \ - " for i2 (0,230)\n" + \ - " for i3 (0,230)\n" + \ - " pad_temp = ...\n" + \ - "for nn (None)\n" + \ - " for ff (None)\n" + \ - " for yy (None)\n" + \ - " for xx (None)\n" + \ - " for rc (None)\n" + \ - " for ry (None)\n" + \ - " for rx (None)\n" + \ - " compute = ...\n" + \ - "for i (None)\n" + \ - " for j (None)\n" + \ - " for k (None)\n" + \ - " for l (None)\n" + \ - " Bn_mul = ...\n" + \ - "for i1 (0,64)\n" + \ - " for i2 (0,112)\n" + \ - " for i3 (0,112)\n" + \ - " compute = ...\n" - - -def test_state_cache_read_write(): - N, H, W, CO, CI, KH, KW, strides, padding = 4, 7, 7, 512, 512, 3, 3, ( - 1, 1), (1, 1) - - data = te.placeholder((N, CI, H, W), name='Data') - kernel_data = te.placeholder((CO, CI, KH, KW), name='Kernel_data') - k0, 
k1 = te.compute(kernel_data.shape, - lambda *i: (kernel_data(*i)+1, kernel_data(*i)/2), - name='Kernel_split') - kernel = te.compute(kernel_data.shape, - lambda *i: k0(*i) + k1(*i), - name='Kernel') - conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation=1) - relu = topi.nn.relu(conv) - out = topi.add(data, relu) - - dag = ansor.ComputeDAG([data, kernel_data, out]) - data, pad_temp, kernel_data, kernel_split, kernel, conv, relu, add = 0, 1, 2, 3, 4, 5, 6, 7 - - # 0: init state - s0 = dag.get_init_state() - ori_its = s0.stage(add).iterators() - s0, its = s0.split(add, s0.stage(add).iterator(0), [2]) - s0 = s0.reorder(add, [its[0], ori_its[1], its[1], ori_its[2], ori_its[3]]) - s0 = s0.compute_inline(relu) - - # 1: simple cache_write with compute_at - s0, conv_global = s0.cache_write(conv, "global", dag) - conv += 1 - relu += 1 - add += 1 - s0 = s0.compute_at(conv_global, conv, s0.stage(conv).iterator(3)) - - # 2: simple cache_read with compute_at - s0, kernel_global = s0.cache_read(kernel, "global", [conv_global], dag) - conv_global += 1 - conv += 1 - relu += 1 - add += 1 - s0 = s0.compute_at(kernel_global, conv_global, - s0.stage(conv_global).iterator(4)) - assert str(s0) == \ - "Placeholder: Data, Kernel_data\n" + \ - "for i0 (0,4)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,9)\n" + \ - " for i3 (0,9)\n" + \ - " pad_temp = ...\n" + \ - "for i0 (0,512)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,3)\n" + \ - " for i3 (0,3)\n" + \ - " Kernel_split = ...\n" + \ - "for i0 (0,512)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,3)\n" + \ - " for i3 (0,3)\n" + \ - " Kernel = ...\n" + \ - "for nn (0,4)\n" + \ - " for ff (0,512)\n" + \ - " for yy (0,7)\n" + \ - " for xx (0,7)\n" + \ - " for nn_c (None)\n" + \ - " for ff_c (None)\n" + \ - " for yy_c (None)\n" + \ - " for xx_c (None)\n" + \ - " for rc (None)\n" + \ - " for ax0 (None)\n" + \ - " for ax1 (None)\n" + \ - " for ax2 (None)\n" + \ - " for ax3 (None)\n" + \ - " Kernel.global = ...\n" + \ - " for ry (None)\n" + \ - " for rx (None)\n" + \ - " compute.global = ...\n" + \ - " compute = ...\n" + \ - "for ax0.0 (0,2)\n" + \ - " for ax1 (0,512)\n" + \ - " for ax0.1 (0,2)\n" + \ - " for ax2 (0,7)\n" + \ - " for ax3 (0,7)\n" + \ - " T_add = ...\n" - - # 3: two level cache_read with compute_at - # preparing for GPU's shared memory & local memory - s0, pad_temp_global = s0.cache_read(pad_temp, "global", [conv_global], dag) - kernel_data += 1 - kernel_split += 1 - kernel += 1 - kernel_global += 1 - conv_global += 1 - conv += 1 - relu += 1 - add += 1 - s0, pad_temp_shared = s0.cache_read( - pad_temp_global, "shared", [conv_global], dag) - kernel_data += 1 - kernel_split += 1 - kernel += 1 - kernel_global += 1 - conv_global += 1 - conv += 1 - relu += 1 - add += 1 - s0 = s0.compute_at(pad_temp_global, conv_global, - s0.stage(conv_global).iterator(2)) - s0 = s0.compute_at(pad_temp_shared, conv_global, - s0.stage(conv_global).iterator(4)) - - # 4: cache_read with multi readers - # This stage cannot be compute at to its consumer - s0, data_global = s0.cache_read(data, "global", [pad_temp, add], dag) - pad_temp += 1 - pad_temp_global += 1 - pad_temp_shared += 1 - kernel_data += 1 - kernel_split += 1 - kernel += 1 - kernel_global += 1 - conv_global += 1 - conv += 1 - relu += 1 - add += 1 - assert str(s0) == \ - "Placeholder: Data, Kernel_data\n" + \ - "for ax0 (0,4)\n" + \ - " for ax1 (0,512)\n" + \ - " for ax2 (0,7)\n" + \ - " for ax3 (0,7)\n" + \ - " Data.global = ...\n" + \ - "for i0 (0,4)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 
(0,9)\n" + \ - " for i3 (0,9)\n" + \ - " pad_temp = ...\n" + \ - "for i0 (0,512)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,3)\n" + \ - " for i3 (0,3)\n" + \ - " Kernel_split = ...\n" + \ - "for i0 (0,512)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,3)\n" + \ - " for i3 (0,3)\n" + \ - " Kernel = ...\n" + \ - "for nn (0,4)\n" + \ - " for ff (0,512)\n" + \ - " for yy (0,7)\n" + \ - " for xx (0,7)\n" + \ - " for nn_c (None)\n" + \ - " for ff_c (None)\n" + \ - " for yy_c (None)\n" + \ - " for ax0 (None)\n" + \ - " for ax1 (None)\n" + \ - " for ax2 (None)\n" + \ - " for ax3 (None)\n" + \ - " pad_temp.global = ...\n" + \ - " for xx_c (None)\n" + \ - " for rc (None)\n" + \ - " for ax0 (None)\n" + \ - " for ax1 (None)\n" + \ - " for ax2 (None)\n" + \ - " for ax3 (None)\n" + \ - " Kernel.global = ...\n" + \ - " for ax0 (None)\n" + \ - " for ax1 (None)\n" + \ - " for ax2 (None)\n" + \ - " for ax3 (None)\n" + \ - " pad_temp.global.shared = ...\n" + \ - " for ry (None)\n" + \ - " for rx (None)\n" + \ - " compute.global = ...\n" + \ - " compute = ...\n" + \ - "for ax0.0 (0,2)\n" + \ - " for ax1 (0,512)\n" + \ - " for ax0.1 (0,2)\n" + \ - " for ax2 (0,7)\n" + \ - " for ax3 (0,7)\n" + \ - " T_add = ...\n" - - # 5: cache_write with multi outputs - # See tests/cpp/ansor_test.cc for more information - s0, _ = s0.cache_write(kernel_split, "global", dag) - assert str(s0) == \ - "Placeholder: Data, Kernel_data\n" + \ - "for ax0 (0,4)\n" + \ - " for ax1 (0,512)\n" + \ - " for ax2 (0,7)\n" + \ - " for ax3 (0,7)\n" + \ - " Data.global = ...\n" + \ - "for i0 (0,4)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,9)\n" + \ - " for i3 (0,9)\n" + \ - " pad_temp = ...\n" + \ - "for i0_c (0,512)\n" + \ - " for i1_c (0,512)\n" + \ - " for i2_c (0,3)\n" + \ - " for i3_c (0,3)\n" + \ - " Kernel_split.global = ...\n" + \ - "for i0 (0,512)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,3)\n" + \ - " for i3 (0,3)\n" + \ - " Kernel_split = ...\n" + \ - "for i0 (0,512)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,3)\n" + \ - " for i3 (0,3)\n" + \ - " Kernel_split = ...\n" + \ - "for i0 (0,512)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,3)\n" + \ - " for i3 (0,3)\n" + \ - " Kernel = ...\n" + \ - "for nn (0,4)\n" + \ - " for ff (0,512)\n" + \ - " for yy (0,7)\n" + \ - " for xx (0,7)\n" + \ - " for nn_c (None)\n" + \ - " for ff_c (None)\n" + \ - " for yy_c (None)\n" + \ - " for ax0 (None)\n" + \ - " for ax1 (None)\n" + \ - " for ax2 (None)\n" + \ - " for ax3 (None)\n" + \ - " pad_temp.global = ...\n" + \ - " for xx_c (None)\n" + \ - " for rc (None)\n" + \ - " for ax0 (None)\n" + \ - " for ax1 (None)\n" + \ - " for ax2 (None)\n" + \ - " for ax3 (None)\n" + \ - " Kernel.global = ...\n" + \ - " for ax0 (None)\n" + \ - " for ax1 (None)\n" + \ - " for ax2 (None)\n" + \ - " for ax3 (None)\n" + \ - " pad_temp.global.shared = ...\n" + \ - " for ry (None)\n" + \ - " for rx (None)\n" + \ - " compute.global = ...\n" + \ - " compute = ...\n" + \ - "for ax0.0 (0,2)\n" + \ - " for ax1 (0,512)\n" + \ - " for ax0.1 (0,2)\n" + \ - " for ax2 (0,7)\n" + \ - " for ax3 (0,7)\n" + \ - " T_add = ...\n" - - -def test_follow_split_follow_fused_split(): - dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) - s0 = dag.get_init_state() - C = 2 - - s0, C_global = s0.cache_write(C, "global", dag) - C += 1 - - s0, its0 = s0.split(C, s0.stage(C).iterator(0), [4, 2, 8, 4], True) - split_step0 = s0.transform_steps_size() - 1 - for level in range(1, 6): - tmp = s0 - tmp, _ = tmp.follow_split(C_global, tmp.stage( - C_global).iterator(0), split_step0, 
level) - for i in range(0, level): - assert tmp.stage(C).iterator(i).range.extent == \ - tmp.stage(C_global).iterator(i).range.extent - - s0, its1 = s0.split(C, s0.stage(C).iterator(5), [2, 2, 4, 8]) - split_step1 = s0.transform_steps_size() - 1 - its = [] - for i0, i1 in zip(its0, its1): - its.append(i0) - its.append(i1) - s0 = s0.reorder(C, its) - for i in range(0, 5): - s0, _ = s0.fuse(C, [s0.stage(C).iterator(i), - s0.stage(C).iterator(i+1)]) - for level in range(0, 4): - tmp = s0 - tmp, _ = tmp.follow_fused_split(C_global, tmp.stage(C_global).iterator(0), - [split_step0, split_step1], level, False) - assert tmp.stage(C).iterator(level+1).range.extent == \ - tmp.stage(C_global).iterator(0).range.extent - for level in range(0, 4): - tmp = s0 - tmp, _ = tmp.follow_fused_split(C_global, tmp.stage(C_global).iterator(0), - [split_step0, split_step1], level, True) - assert tmp.stage(C).iterator(level+1).range.extent == \ - tmp.stage(C_global).iterator(1).range.extent - - -def test_rfactor(): - pass - - -def test_measure_local_builder_runner(): +def get_tiled_matmul(): dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) s0 = dag.get_init_state() A, B, C = 0, 1, 2 - s0, C_global = s0.cache_write(C, "global", dag) + C_global = s0.cache_write(C, "global", dag) C += 1 - s0, its0 = s0.split(C, s0.stage(C).iterator(0), [4, 8, 8]) - s0, its1 = s0.split(C, s0.stage(C).iterator(4), [8, 4, 4]) - s0 = s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], - its0[3], its1[3]]) - s0 = s0.compute_at(C_global, C, s0.stage(C).iterator(3)) - s0, _ = s0.split(C_global, s0.stage(C_global).iterator(2), [16]) - s0, B_global = s0.cache_read(B, "global", [C_global], dag) + its0 = s0.split(C, s0.stages[C].iters[0], [4, 8, 8]) + its1 = s0.split(C, s0.stages[C].iters[4], [8, 4, 4]) + s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], its1[3]]) + s0.compute_at(C_global, C, s0.stages[C].iters[3]) + s0.split(C_global, s0.stages[C_global].iters[2], [16]) + B_global = s0.cache_read(B, "global", [C_global], dag) C += 1 C_global += 1 - s0 = s0.compute_at(B_global, C_global, s0.stage(C_global).iterator(0)) - s0, A_global = s0.cache_read(A, "global", [C_global], dag) + s0.compute_at(B_global, C_global, s0.stages[C_global].iters[0]) + A_global = s0.cache_read(A, "global", [C_global], dag) B += 1 B_global += 1 C += 1 C_global += 1 - s0 = s0.compute_at(A_global, C_global, s0.stage(C_global).iterator(2)) - - tgt = tvm.target.create("llvm") - task = ansor.SearchTask(dag, "test", tgt) - - minp = ansor.MeasureInput(task, s0) - local_builder = ansor.LocalBuilder() - local_runner = ansor.LocalRunner() - - bress = local_builder.build([minp]) - assert bress[0].error_no == 0 - mress = local_runner.run([minp], bress) - assert mress[0].error_no == 0 - - -def test_search_basic(): - print("Test schedule search with default search policy") - - N = 128 - A, B, C = matmul_nkkm(N, N, N) - dag = ansor.ComputeDAG([A, B, C]) - tgt = tvm.target.create("llvm") - task = ansor.SearchTask(dag, "test", tgt) - - # seed = random.randint(1, 1 << 30) - seed = 944563397 - log_file = "/tmp/_ansor_python_ut_test.json" - - random.seed(seed) - cost_model = ansor.RandomModel() - search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed) - tune_option = ansor.TuneOption(n_trials=2, - callbacks=[ansor.LogToFile(log_file)]) - state = ansor.auto_schedule(task, search_policy, - tune_option=tune_option) - sch, args = dag.apply_steps_from_state(state) - - print("==== Get State ====") - print(state) - print("==== Get Python 
Code ====") - print(dag.print_python_code_from_state(state)) - - try: - print("==== Get Lowered Stmt ====") - print(tvm.lower(sch, args, simple_mode=True)) - mod = tvm.build(sch, args, tgt) - - ctx = tvm.context("llvm", 0) - a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(A.dtype), ctx) - b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(B.dtype), ctx) - c = tvm.nd.array(np.zeros((N, N), dtype=C.dtype), ctx) - mod(a, b, c) - tvm.testing.assert_allclose(c.asnumpy(), np.dot( - a.asnumpy(), b.asnumpy()), rtol=1e-5) - print("==== Verification passed ====") - except Exception: - raise Exception("Error encounterd with seed: %d" % (seed)) - - inp, res = ansor.best_measure_pair_in_file(log_file) - s0 = dag.infer_bound_from_state(state) - s1 = dag.infer_bound_from_state(inp.state) - assert str(s0) == str(s1) - - if os.path.isfile(log_file): - os.system("rm -rf %s" % log_file) - - -if __name__ == "__main__": - test_compute_dag_basic() - test_state_split_fuse_reorder() - test_state_compute_at_root_inline() - test_state_cache_read_write() - test_follow_split_follow_fused_split() - test_rfactor() - test_measure_local_builder_runner() - test_search_basic() + s0.compute_at(A_global, C_global, s0.stages[C_global].iters[2]) + return dag, s0.state_object diff --git a/tests/python/unittest/test_ansor_compute_dag.py b/tests/python/unittest/test_ansor_compute_dag.py new file mode 100644 index 000000000000..61eb0153a87c --- /dev/null +++ b/tests/python/unittest/test_ansor_compute_dag.py @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Test ComputeDAG (replay, infer bound)""" + +import tvm +from tvm import ansor, te + +from test_ansor_common import get_tiled_matmul + + +def test_apply_steps(): + dag, s = get_tiled_matmul() + dag.print_python_code_from_state(s) + sch, tensors = dag.apply_steps_from_state(s) + stmt = tvm.lower(sch, tensors, simple_mode=True) + + +def test_infer_bound(): + dag, s = get_tiled_matmul() + s = dag.infer_bound_from_state(s) + s = ansor.loop_state.State(s) + + A_global, B_global, C_global = 1, 3, 4 + assert s.stages[B_global].iters[0].range.extent == 512 + assert s.stages[B_global].iters[1].range.extent == 16 + assert s.stages[A_global].iters[0].range.extent == 1 + assert s.stages[A_global].iters[1].range.extent == 16 + assert s.stages[C_global].iters[0].range.extent == 64 + + +def test_lower_legalize_invalid_attach(): + N, M = 10, 10 + + A = te.compute((N, M), lambda i, j: 1.0, name='A') + B = te.compute((N, M), lambda i, j: A[i][j], name='B') + + dag = ansor.ComputeDAG([A, B]) + s = dag.get_init_state() + + A, B = 0, 1 + s.compute_at(A, B, s.stages[B].iters[1]) + s.split(B, s.stages[B].iters[1], [2]) + + sch, tensors = dag.apply_steps_from_state(s.state_object) + stmt = tvm.lower(sch, tensors, simple_mode=True) + + +if __name__ == "__main__": + test_apply_steps() + test_infer_bound() + test_lower_legalize_invalid_attach() diff --git a/tests/python/unittest/test_ansor_loop_state.py b/tests/python/unittest/test_ansor_loop_state.py new file mode 100644 index 000000000000..34b720e7e1af --- /dev/null +++ b/tests/python/unittest/test_ansor_loop_state.py @@ -0,0 +1,475 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Test loop state and schedule primitives""" + +from tvm import ansor, te +import topi + +from test_ansor_common import matmul_nkkm, conv2d_nchw_bn_relu + + +def test_state_split_fuse_reorder_annotation(): + dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) + s0 = dag.get_init_state() + C = 2 + i, j, k = s0.stages[C].iters + + assert i.range.extent == 512 + + io, ii = s0.split(C, i, [16]) + assert s0.stages[C].iters[0] == io + assert s0.stages[C].iters[1] == ii + assert io.range.extent == 32 + assert ii.range.extent == 16 + + jo, ji = s0.split(C, j, [8]) + assert jo.range.extent == 64 + assert ji.range.extent == 8 + + s0.reorder(C, [io, jo, k, ji, ii]) + assert s0.stages[C].iters[2].range.extent == 512 + + fused_it = s0.fuse(C, [io, jo]) + assert fused_it.range.extent == 2048 + + s1 = dag.get_init_state() + i, j, _ = s1.stages[C].iters + i1, i2, i3 = s1.split(C, i, [8, 2]) + j1, j2, j3 = s1.split(C, j, [32, 8], False) + assert s1.stages[C].iters[0].range.extent == 32 + assert s1.stages[C].iters[1].range.extent == 8 + assert s1.stages[C].iters[2].range.extent == 2 + assert s1.stages[C].iters[3].range.extent == 32 + assert s1.stages[C].iters[4].range.extent == 8 + assert s1.stages[C].iters[5].range.extent == 2 + + s1.parallel(C, j1) + s1.unroll(C, j2) + s1.vectorize(C, j3) + s1.bind_thread(C, i1, "blockIdx.x") + s1.bind_thread(C, i2, "vthread") + s1.bind_thread(C, i3, "threadIdx.y") + + +def test_follow_split_follow_fused_split(): + dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) + s0 = dag.get_init_state() + C = 2 + + C_global = s0.cache_write(C, "global", dag) + C += 1 + + its0 = s0.split(C, s0.stages[C].iters[0], [4, 2, 8, 4], True) + split_step0 = s0.transform_steps_size() - 1 + for level in range(1, 6): + tmp = s0.copy() + tmp.follow_split(C_global, tmp.stages[C_global].iters[0], split_step0, level) + for i in range(0, level): + assert tmp.stages[C].iters[i].range.extent == \ + tmp.stages[C_global].iters[i].range.extent + + its1 = s0.split(C, s0.stages[C].iters[5], [2, 2, 4, 8]) + split_step1 = s0.transform_steps_size() - 1 + its = [] + for i0, i1 in zip(its0, its1): + its.append(i0) + its.append(i1) + s0.reorder(C, its) + for i in range(0, 5): + s0.fuse(C, [s0.stages[C].iters[i], s0.stages[C].iters[i + 1]]) + + for level in range(0, 4): + tmp = s0.copy() + tmp.follow_fused_split(C_global, tmp.stages[C_global].iters[0], + [split_step0, split_step1], level, False) + assert tmp.stages[C].iters[level + 1].range.extent == \ + tmp.stages[C_global].iters[0].range.extent + + for level in range(0, 4): + tmp = s0.copy() + tmp.follow_fused_split(C_global, tmp.stages[C_global].iters[0], + [split_step0, split_step1], level, True) + assert tmp.stages[C].iters[level + 1].range.extent == \ + tmp.stages[C_global].iters[1].range.extent + + +def test_state_compute_at_root_inline(): + dag = ansor.ComputeDAG(conv2d_nchw_bn_relu(1, 224, 224, 3, 64, 7, 2, 3)) + + # data, padding, kernel = 0, 1, 2 + conv = 3 + # bias = 4 + bias_add = 5 + # bn_scale = 6 + bn_mul = 7 + # bn_offset = 8 + bn_add, relu = 9, 10 + + s0 = dag.get_init_state() + s0.compute_inline(bn_add) + s0.compute_inline(bn_mul) + s0.compute_inline(bias_add) + s0.compute_at(conv, relu, s0.stages[relu].iters[2]) + assert str(s0) == \ + "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \ + "for i1 (0,3)\n" + \ + " for i2 (0,230)\n" + \ + " for i3 (0,230)\n" + \ + " pad_temp = ...\n" + \ + "for i1 (0,64)\n" + \ + " for i2 (0,112)\n" + \ + " for nn (None)\n" + \ + " for ff (None)\n" + \ + " for yy (None)\n" + \ + " for xx (None)\n" + \ 
+ " for rc (None)\n" + \ + " for ry (None)\n" + \ + " for rx (None)\n" + \ + " compute = ...\n" + \ + " for i3 (0,112)\n" + \ + " compute = ...\n" + + s0 = s0.compute_root(conv) + s0 = s0.compute_root(bn_mul) + assert str(s0) == \ + "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \ + "for i1 (0,3)\n" + \ + " for i2 (0,230)\n" + \ + " for i3 (0,230)\n" + \ + " pad_temp = ...\n" + \ + "for nn (None)\n" + \ + " for ff (None)\n" + \ + " for yy (None)\n" + \ + " for xx (None)\n" + \ + " for rc (None)\n" + \ + " for ry (None)\n" + \ + " for rx (None)\n" + \ + " compute = ...\n" + \ + "for i (None)\n" + \ + " for j (None)\n" + \ + " for k (None)\n" + \ + " for l (None)\n" + \ + " Bn_mul = ...\n" + \ + "for i1 (0,64)\n" + \ + " for i2 (0,112)\n" + \ + " for i3 (0,112)\n" + \ + " compute = ...\n" + + +def test_state_cache_read_write(): + N, H, W, CO, CI, KH, KW, strides, padding = 4, 7, 7, 512, 512, 3, 3, ( + 1, 1), (1, 1) + + data = te.placeholder((N, CI, H, W), name='Data') + kernel_data = te.placeholder((CO, CI, KH, KW), name='Kernel_data') + k0, k1 = te.compute(kernel_data.shape, + lambda *i: (kernel_data(*i)+1, kernel_data(*i)/2), + name='Kernel_split') + kernel = te.compute(kernel_data.shape, + lambda *i: k0(*i) + k1(*i), + name='Kernel') + conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation=1) + relu = topi.nn.relu(conv) + out = topi.add(data, relu) + + dag = ansor.ComputeDAG([data, kernel_data, out]) + data, pad_temp, kernel_data, kernel_split, kernel, conv, relu, add = 0, 1, 2, 3, 4, 5, 6, 7 + + # 0: init state + s0 = dag.get_init_state() + ori_its = s0.stages[add].iters + its = s0.split(add, s0.stages[add].iters[0], [2]) + s0.reorder(add, [its[0], ori_its[1], its[1], ori_its[2], ori_its[3]]) + s0.compute_inline(relu) + + # 1: simple cache_write with compute_at + conv_global = s0.cache_write(conv, "global", dag) + conv += 1 + relu += 1 + add += 1 + s0.compute_at(conv_global, conv, s0.stages[conv].iters[3]) + + # 2: simple cache_read with compute_at + kernel_global = s0.cache_read(kernel, "global", [conv_global], dag) + conv_global += 1 + conv += 1 + relu += 1 + add += 1 + s0.compute_at(kernel_global, conv_global, + s0.stages[conv_global].iters[4]) + assert str(s0) == \ + "Placeholder: Data, Kernel_data\n" + \ + "for i0 (0,4)\n" + \ + " for i1 (0,512)\n" + \ + " for i2 (0,9)\n" + \ + " for i3 (0,9)\n" + \ + " pad_temp = ...\n" + \ + "for i0 (0,512)\n" + \ + " for i1 (0,512)\n" + \ + " for i2 (0,3)\n" + \ + " for i3 (0,3)\n" + \ + " Kernel_split = ...\n" + \ + "for i0 (0,512)\n" + \ + " for i1 (0,512)\n" + \ + " for i2 (0,3)\n" + \ + " for i3 (0,3)\n" + \ + " Kernel = ...\n" + \ + "for nn (0,4)\n" + \ + " for ff (0,512)\n" + \ + " for yy (0,7)\n" + \ + " for xx (0,7)\n" + \ + " for nn_c (None)\n" + \ + " for ff_c (None)\n" + \ + " for yy_c (None)\n" + \ + " for xx_c (None)\n" + \ + " for rc (None)\n" + \ + " for ax0 (None)\n" + \ + " for ax1 (None)\n" + \ + " for ax2 (None)\n" + \ + " for ax3 (None)\n" + \ + " Kernel.global = ...\n" + \ + " for ry (None)\n" + \ + " for rx (None)\n" + \ + " compute.global = ...\n" + \ + " compute = ...\n" + \ + "for ax0.0 (0,2)\n" + \ + " for ax1 (0,512)\n" + \ + " for ax0.1 (0,2)\n" + \ + " for ax2 (0,7)\n" + \ + " for ax3 (0,7)\n" + \ + " T_add = ...\n" + + # 3: two level cache_read with compute_at + # preparing for GPU's shared memory & local memory + pad_temp_global = s0.cache_read(pad_temp, "global", [conv_global], dag) + kernel_data += 1 + kernel_split += 1 + kernel += 1 + kernel_global += 1 + conv_global += 1 + conv += 1 + 
relu += 1
+    add += 1
+    pad_temp_shared = s0.cache_read(pad_temp_global, "shared", [conv_global], dag)
+    kernel_data += 1
+    kernel_split += 1
+    kernel += 1
+    kernel_global += 1
+    conv_global += 1
+    conv += 1
+    relu += 1
+    add += 1
+    s0.compute_at(pad_temp_global, conv_global, s0.stages[conv_global].iters[2])
+    s0.compute_at(pad_temp_shared, conv_global, s0.stages[conv_global].iters[4])
+
+    # 4: cache_read with multi readers
+    # This stage cannot be computed at its consumer
+    data_global = s0.cache_read(data, "global", [pad_temp, add], dag)
+    pad_temp += 1
+    pad_temp_global += 1
+    pad_temp_shared += 1
+    kernel_data += 1
+    kernel_split += 1
+    kernel += 1
+    kernel_global += 1
+    conv_global += 1
+    conv += 1
+    relu += 1
+    add += 1
+    assert str(s0) == \
+        "Placeholder: Data, Kernel_data\n" + \
+        "for ax0 (0,4)\n" + \
+        "  for ax1 (0,512)\n" + \
+        "    for ax2 (0,7)\n" + \
+        "      for ax3 (0,7)\n" + \
+        "        Data.global = ...\n" + \
+        "for i0 (0,4)\n" + \
+        "  for i1 (0,512)\n" + \
+        "    for i2 (0,9)\n" + \
+        "      for i3 (0,9)\n" + \
+        "        pad_temp = ...\n" + \
+        "for i0 (0,512)\n" + \
+        "  for i1 (0,512)\n" + \
+        "    for i2 (0,3)\n" + \
+        "      for i3 (0,3)\n" + \
+        "        Kernel_split = ...\n" + \
+        "for i0 (0,512)\n" + \
+        "  for i1 (0,512)\n" + \
+        "    for i2 (0,3)\n" + \
+        "      for i3 (0,3)\n" + \
+        "        Kernel = ...\n" + \
+        "for nn (0,4)\n" + \
+        "  for ff (0,512)\n" + \
+        "    for yy (0,7)\n" + \
+        "      for xx (0,7)\n" + \
+        "        for nn_c (None)\n" + \
+        "          for ff_c (None)\n" + \
+        "            for yy_c (None)\n" + \
+        "              for ax0 (None)\n" + \
+        "                for ax1 (None)\n" + \
+        "                  for ax2 (None)\n" + \
+        "                    for ax3 (None)\n" + \
+        "                      pad_temp.global = ...\n" + \
+        "              for xx_c (None)\n" + \
+        "                for rc (None)\n" + \
+        "                  for ax0 (None)\n" + \
+        "                    for ax1 (None)\n" + \
+        "                      for ax2 (None)\n" + \
+        "                        for ax3 (None)\n" + \
+        "                          Kernel.global = ...\n" + \
+        "                  for ax0 (None)\n" + \
+        "                    for ax1 (None)\n" + \
+        "                      for ax2 (None)\n" + \
+        "                        for ax3 (None)\n" + \
+        "                          pad_temp.global.shared = ...\n" + \
+        "                  for ry (None)\n" + \
+        "                    for rx (None)\n" + \
+        "                      compute.global = ...\n" + \
+        "        compute = ...\n" + \
+        "for ax0.0 (0,2)\n" + \
+        "  for ax1 (0,512)\n" + \
+        "    for ax0.1 (0,2)\n" + \
+        "      for ax2 (0,7)\n" + \
+        "        for ax3 (0,7)\n" + \
+        "          T_add = ...\n"
+
+    # 5: cache_write with multi outputs
+    # TVM's cache_write actually has a bug with this case:
+    #
+    # After schedule.cache_write, TVM generates one new stage:
+    #   From: kernel_data -> kernel_split -> kernel
+    #   To:   kernel_data -> kernel_split_global -> kernel_split -> kernel
+    #
+    # But with topo-sort analysis, we get:
+    #   kernel_data -> kernel_split_global -> kernel_split -> kernel
+    #        \                                                /
+    #         ----------------> kernel_split --------------->
+    #
+    # Seems there's a bug with the input/output tensors. Such multi-output cases
+    # should be unusual, so we make some hack in DoCacheWrite.
+    # To be fixed in the future.
+    s0.cache_write(kernel_split, "global", dag)
+    assert str(s0) == \
+        "Placeholder: Data, Kernel_data\n" + \
+        "for ax0 (0,4)\n" + \
+        "  for ax1 (0,512)\n" + \
+        "    for ax2 (0,7)\n" + \
+        "      for ax3 (0,7)\n" + \
+        "        Data.global = ...\n" + \
+        "for i0 (0,4)\n" + \
+        "  for i1 (0,512)\n" + \
+        "    for i2 (0,9)\n" + \
+        "      for i3 (0,9)\n" + \
+        "        pad_temp = ...\n" + \
+        "for i0_c (0,512)\n" + \
+        "  for i1_c (0,512)\n" + \
+        "    for i2_c (0,3)\n" + \
+        "      for i3_c (0,3)\n" + \
+        "        Kernel_split.global = ...\n" + \
+        "for i0 (0,512)\n" + \
+        "  for i1 (0,512)\n" + \
+        "    for i2 (0,3)\n" + \
+        "      for i3 (0,3)\n" + \
+        "        Kernel_split = ...\n" + \
+        "for i0 (0,512)\n" + \
+        "  for i1 (0,512)\n" + \
+        "    for i2 (0,3)\n" + \
+        "      for i3 (0,3)\n" + \
+        "        Kernel_split = ...\n" + \
+        "for i0 (0,512)\n" + \
+        "  for i1 (0,512)\n" + \
+        "    for i2 (0,3)\n" + \
+        "      for i3 (0,3)\n" + \
+        "        Kernel = ...\n" + \
+        "for nn (0,4)\n" + \
+        "  for ff (0,512)\n" + \
+        "    for yy (0,7)\n" + \
+        "      for xx (0,7)\n" + \
+        "        for nn_c (None)\n" + \
+        "          for ff_c (None)\n" + \
+        "            for yy_c (None)\n" + \
+        "              for ax0 (None)\n" + \
+        "                for ax1 (None)\n" + \
+        "                  for ax2 (None)\n" + \
+        "                    for ax3 (None)\n" + \
+        "                      pad_temp.global = ...\n" + \
+        "              for xx_c (None)\n" + \
+        "                for rc (None)\n" + \
+        "                  for ax0 (None)\n" + \
+        "                    for ax1 (None)\n" + \
+        "                      for ax2 (None)\n" + \
+        "                        for ax3 (None)\n" + \
+        "                          Kernel.global = ...\n" + \
+        "                  for ax0 (None)\n" + \
+        "                    for ax1 (None)\n" + \
+        "                      for ax2 (None)\n" + \
+        "                        for ax3 (None)\n" + \
+        "                          pad_temp.global.shared = ...\n" + \
+        "                  for ry (None)\n" + \
+        "                    for rx (None)\n" + \
+        "                      compute.global = ...\n" + \
+        "        compute = ...\n" + \
+        "for ax0.0 (0,2)\n" + \
+        "  for ax1 (0,512)\n" + \
+        "    for ax0.1 (0,2)\n" + \
+        "      for ax2 (0,7)\n" + \
+        "        for ax3 (0,7)\n" + \
+        "          T_add = ...\n"
+
+
+def test_rfactor():
+    dag = ansor.ComputeDAG(matmul_nkkm(8, 8, 512))
+    s0 = dag.get_init_state()
+    C = 2
+
+    ko, ki = s0.split(C, s0.stages[C].iters[2], [16])
+
+    s1 = s0.copy()
+    s1.rfactor(C, ko, 2, dag)
+    assert str(s1) == \
+        "Placeholder: A, B\n" + \
+        "for i (0,8)\n" + \
+        "  for j (0,8)\n" + \
+        "    for k_o (0,32)\n" + \
+        "      for k_i (0,16)\n" + \
+        "        C.rf = ...\n" + \
+        "for ax0 (0,8)\n" + \
+        "  for ax1 (0,8)\n" + \
+        "    for k_o_v (0,32)\n" + \
+        "      C.repl = ...\n"
+
+    s2 = s0.copy()
+    s2.rfactor(C, ki, 2, dag)
+    assert str(s2) == \
+        "Placeholder: A, B\n" + \
+        "for i (0,8)\n" + \
+        "  for j (0,8)\n" + \
+        "    for k_i (0,16)\n" + \
+        "      for k_o (0,32)\n" + \
+        "        C.rf = ...\n" + \
+        "for ax0 (0,8)\n" + \
+        "  for ax1 (0,8)\n" + \
+        "    for k_i_v (0,16)\n" + \
+        "      C.repl = ...\n"
+
+
+if __name__ == "__main__":
+    test_state_split_fuse_reorder_annotation()
+    test_follow_split_follow_fused_split()
+    test_state_cache_read_write()
+    test_rfactor()
diff --git a/tests/python/unittest/test_ansor_measure.py b/tests/python/unittest/test_ansor_measure.py
new file mode 100644
index 000000000000..baf8a0c4efa2
--- /dev/null
+++ b/tests/python/unittest/test_ansor_measure.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Test measurement and log serialization""" + +import tvm +from tvm import ansor +import tempfile + +from test_ansor_common import get_tiled_matmul + + +def test_serialization(): + dag, s = get_tiled_matmul() + target = tvm.target.create("llvm") + task = ansor.SearchTask(dag, "test", target) + + inp = ansor.measure.MeasureInput(task, s) + res = ansor.measure.MeasureResult([0.1], 0, "", 0.2, 1) + + with tempfile.NamedTemporaryFile() as fp: + ansor.serialization.write_measure_records_to_file(fp.name, [inp], [res]) + + log_reader = ansor.serialization.LogReader(fp.name) + inputs, results = log_reader.read_lines() + assert len(inputs) == 1 + + s1 = dag.infer_bound_from_state(s) + s2 = dag.infer_bound_from_state(inputs[0].state) + + assert s1 == s2 + assert not (s1 == dag.get_init_state().state_object) + + +def test_measure_local_builder_runner(): + dag, s0 = get_tiled_matmul() + + tgt = tvm.target.create("llvm") + task = ansor.SearchTask(dag, "test", tgt) + + minp = ansor.MeasureInput(task, s0) + local_builder = ansor.LocalBuilder() + local_runner = ansor.LocalRunner() + + bress = local_builder.build([minp]) + assert bress[0].error_no == 0 + mress = local_runner.run([minp], bress) + assert mress[0].error_no == 0 + + +if __name__ == "__main__": + test_serialization() + test_measure_local_builder_runner() diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py new file mode 100644 index 000000000000..eea3f5cfbda3 --- /dev/null +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Test search policy""" + +import random +import os +import numpy as np +import tempfile + +import tvm +from tvm import ansor + +from test_ansor_common import matmul_nkkm + +def test_search_basic(): + print("Test schedule search with the default search policy") + + N = 128 + A, B, C = matmul_nkkm(N, N, N) + dag = ansor.ComputeDAG([A, B, C]) + tgt = tvm.target.create("llvm") + task = ansor.SearchTask(dag, "test", tgt) + + seed = 944563397 + random.seed(seed) + + with tempfile.NamedTemporaryFile() as fp: + log_file = fp.name + + cost_model = ansor.RandomModel() + search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed) + tune_option = ansor.TuneOption(n_trials=2, + callbacks=[ansor.LogToFile(log_file)]) + state = ansor.auto_schedule(task, search_policy, + tune_option=tune_option) + sch, args = dag.apply_steps_from_state(state) + + print("==== Get State ====") + print(state) + print("==== Get Python Code ====") + print(dag.print_python_code_from_state(state)) + + try: + print("==== Get Lowered Stmt ====") + print(tvm.lower(sch, args, simple_mode=True)) + mod = tvm.build(sch, args, tgt) + + ctx = tvm.context("llvm", 0) + a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(A.dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(B.dtype), ctx) + c = tvm.nd.array(np.zeros((N, N), dtype=C.dtype), ctx) + mod(a, b, c) + tvm.testing.assert_allclose(c.asnumpy(), np.dot( + a.asnumpy(), b.asnumpy()), rtol=1e-5) + print("==== Verification passed ====") + except Exception: + raise Exception("Error encountered with seed: %d" % (seed)) + + inp, res = ansor.best_measure_pair_in_file(log_file) + s0 = dag.infer_bound_from_state(state) + s1 = dag.infer_bound_from_state(inp.state) + assert s0 == s1 + + +if __name__ == "__main__": + test_search_basic() From 43d1530a253dc65aaf9f8da9cc818e9e0c4a1db0 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 7 Jun 2020 23:30:18 -0700 Subject: [PATCH 10/45] fix unit tests --- tests/python/unittest/test_ansor_loop_state.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/python/unittest/test_ansor_loop_state.py b/tests/python/unittest/test_ansor_loop_state.py index 34b720e7e1af..287a1b773395 100644 --- a/tests/python/unittest/test_ansor_loop_state.py +++ b/tests/python/unittest/test_ansor_loop_state.py @@ -23,7 +23,7 @@ from test_ansor_common import matmul_nkkm, conv2d_nchw_bn_relu -def test_state_split_fuse_reorder_annotation(): +def test_split_fuse_reorder_annotation(): dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) s0 = dag.get_init_state() C = 2 @@ -108,7 +108,7 @@ def test_follow_split_follow_fused_split(): tmp.stages[C_global].iters[1].range.extent -def test_state_compute_at_root_inline(): +def test_compute_at_root_inline(): dag = ansor.ComputeDAG(conv2d_nchw_bn_relu(1, 224, 224, 3, 64, 7, 2, 3)) # data, padding, kernel = 0, 1, 2 @@ -144,8 +144,8 @@ def test_state_compute_at_root_inline(): " for i3 (0,112)\n" + \ " compute = ...\n" - s0 = s0.compute_root(conv) - s0 = s0.compute_root(bn_mul) + s0.compute_root(conv) + s0.compute_root(bn_mul) assert str(s0) == \ "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \ "for i1 (0,3)\n" + \ @@ -171,7 +171,7 @@ def test_state_compute_at_root_inline(): " compute = ...\n" -def test_state_cache_read_write(): +def test_cache_read_write(): N, H, W, CO, CI, KH, KW, strides, padding = 4, 7, 7, 512, 512, 3, 3, ( 1, 1), (1, 1) @@ -469,7 +469,8 @@ def test_rfactor(): if __name__ == "__main__": - test_state_split_fuse_reorder_annotation() + 
test_split_fuse_reorder_annotation() test_follow_split_follow_fused_split() - test_state_cache_read_write() + test_compute_at_root_inline() + test_cache_read_write() test_rfactor() From f367d1533a10c2d476b7a12e54c5261f71b08cfb Mon Sep 17 00:00:00 2001 From: Chenfan Date: Mon, 8 Jun 2020 14:36:42 +0800 Subject: [PATCH 11/45] Add RPCRunner & OpenCL/CUDA test (#12) * Add RPCRunner & OpenCL search test * Add CUDA search test * Add RPCRunner test --- python/tvm/ansor/__init__.py | 2 +- python/tvm/ansor/measure.py | 22 +++++++ python/tvm/rpc/server.py | 3 +- src/ansor/measure.cc | 8 +++ .../search_policy/meta_tile_rewrite_policy.h | 1 - tests/python/unittest/test_ansor_measure.py | 29 +++++++++ .../unittest/test_ansor_search_policy.py | 61 +++++++++++++++++-- 7 files changed, 117 insertions(+), 9 deletions(-) diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index 1be7ed404c17..7552878a3c50 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -28,6 +28,6 @@ from .compute_dag import ComputeDAG from .task import SearchTask, MetaTileRewritePolicy, TuneOption from .task import auto_schedule -from .measure import MeasureInput, LocalBuilder, LocalRunner +from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner from .cost_model import RandomModel from .serialization import LogToFile, LogReader, best_measure_pair_in_file diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index 5438edfaa6b2..b80de7c01633 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -168,6 +168,28 @@ def __init__(self, _ffi_api.LocalRunner, timeout, number, repeat, min_repeat_ms, cooldown_interval) +@tvm._ffi.register_object("ansor.RPCRunner") +class RPCRunner(Runner): + def __init__(self, key, host, port, priority=1, + n_parallel=1, + timeout=10, + number=3, + repeat=1, + min_repeat_ms=0, + cooldown_interval=0.0): + self.__init_handle_by_constructor__( + _ffi_api.RPCRunner, key, host, port, priority, timeout, n_parallel, + number, repeat, min_repeat_ms, cooldown_interval) + + if check_remote(key, host, port, priority, timeout): + logger.info("Get devices for measurement successfully!") + else: + raise RuntimeError("Cannot get remote devices from the tracker. 
" + "Please check the status of tracker by " + "'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' " + "and make sure you have free devices on the queue status.") + + MAX_ERROR_MSG_LEN = 512 diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 15a3c7de789d..42bcb00a9117 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -348,7 +348,8 @@ def __init__(self, cmd = [sys.executable, "-m", "tvm.exec.rpc_server", "--host=%s" % host, - "--port=%s" % port] + "--port=%s" % port, + "--port-end=%s" % port_end] if tracker_addr: assert key cmd += ["--tracker=%s:%d" % tracker_addr, diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index 43be530f2a35..e3593753d3ff 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -368,5 +368,13 @@ TVM_REGISTER_GLOBAL("ansor.LocalRunner") cooldown_interval); }); +TVM_REGISTER_GLOBAL("ansor.RPCRunner") +.set_body_typed([](const std::string& key, const std::string& host, int port, + int priority, int timeout, int n_parallel, int number, + int repeat, int min_repeat_ms, double cooldown_interval) { + return RPCRunnerNode::make(key, host, port, priority, timeout, n_parallel, + number, repeat, min_repeat_ms, cooldown_interval); +}); + } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.h b/src/ansor/search_policy/meta_tile_rewrite_policy.h index 0c8c44b9c5ea..823ef6df4983 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.h +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.h @@ -76,7 +76,6 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode { SearchTask cur_task_; // The current task - friend class MetaTileRewritePolicyNodeTest; // Hack friend class for UT protected: // Pick states from best states and random states with eps-greedy policy void PickStatesWithEpsGreedy(std::vector* inputs, diff --git a/tests/python/unittest/test_ansor_measure.py b/tests/python/unittest/test_ansor_measure.py index baf8a0c4efa2..0385568894fe 100644 --- a/tests/python/unittest/test_ansor_measure.py +++ b/tests/python/unittest/test_ansor_measure.py @@ -19,6 +19,8 @@ import tvm from tvm import ansor +from tvm.rpc.tracker import Tracker +from tvm.rpc.server import Server import tempfile from test_ansor_common import get_tiled_matmul @@ -62,6 +64,33 @@ def test_measure_local_builder_runner(): assert mress[0].error_no == 0 +def test_measure_local_builder_rpc_runner(): + dag, s0 = get_tiled_matmul() + + tgt = tvm.target.create("llvm") + task = ansor.SearchTask(dag, "test", tgt) + + minp = ansor.MeasureInput(task, s0) + local_builder = ansor.LocalBuilder() + host = '0.0.0.0' + tracker = Tracker(host, port=9000, port_end=10000, silent=True) + device_key = '$local$device$%d' % tracker.port + server = Server(host, port=tracker.port, port_end=10000, + key=device_key, + use_popen=True, silent=True, + tracker_addr=(tracker.host, tracker.port)) + rpc_runner = ansor.RPCRunner(device_key, host, tracker.port) + + bress = local_builder.build([minp]) + assert bress[0].error_no == 0 + mress = rpc_runner.run([minp], bress) + assert mress[0].error_no == 0 + + tracker.terminate() + server.terminate() + + if __name__ == "__main__": test_serialization() test_measure_local_builder_runner() + test_measure_local_builder_rpc_runner() diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index eea3f5cfbda3..9a57691aba22 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ 
b/tests/python/unittest/test_ansor_search_policy.py @@ -24,19 +24,20 @@ import tvm from tvm import ansor +from tvm.rpc.tracker import Tracker +from tvm.rpc.server import Server from test_ansor_common import matmul_nkkm -def test_search_basic(): - print("Test schedule search with the default search policy") +def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local'): + print("Test %s schedule search with the default search policy" % (target)) N = 128 A, B, C = matmul_nkkm(N, N, N) dag = ansor.ComputeDAG([A, B, C]) - tgt = tvm.target.create("llvm") + tgt = tvm.target.create(target) task = ansor.SearchTask(dag, "test", tgt) - seed = 944563397 random.seed(seed) with tempfile.NamedTemporaryFile() as fp: @@ -44,7 +45,7 @@ def test_search_basic(): cost_model = ansor.RandomModel() search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed) - tune_option = ansor.TuneOption(n_trials=2, + tune_option = ansor.TuneOption(n_trials=2, runner=runner, callbacks=[ansor.LogToFile(log_file)]) state = ansor.auto_schedule(task, search_policy, tune_option=tune_option) @@ -60,7 +61,7 @@ def test_search_basic(): print(tvm.lower(sch, args, simple_mode=True)) mod = tvm.build(sch, args, tgt) - ctx = tvm.context("llvm", 0) + ctx = tvm.context(target, 0) a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(A.dtype), ctx) b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(B.dtype), ctx) c = tvm.nd.array(np.zeros((N, N), dtype=C.dtype), ctx) @@ -75,7 +76,55 @@ def test_search_basic(): s0 = dag.infer_bound_from_state(state) s1 = dag.infer_bound_from_state(inp.state) assert s0 == s1 + print() + + +def test_search_basic(): + search_common(seed=944563397) + + +def test_search_opencl(): + if tvm.context("opencl", 0).exist: + host = '0.0.0.0' + tracker = Tracker(host, port=9000, port_end=10000, silent=True) + device_key = '$local$device$%d' % tracker.port + server = Server(host, port=tracker.port, port_end=10000, + key=device_key, + use_popen=True, silent=True, + tracker_addr=(tracker.host, tracker.port)) + rpc_runner = ansor.RPCRunner(device_key, host, tracker.port) + + search_common("opencl", 380344973, rpc_runner) + + tracker.terminate() + server.terminate() + else: + print("OpenCL device not found, skip this test.") + + +def test_search_cuda(): + ctx = tvm.context("cuda", 0) + if ctx.exist: + cuda_arch = "sm_" + "".join(ctx.compute_version.split('.')) + tvm.autotvm.measure.measure_methods.set_cuda_target_arch(cuda_arch) + host = '0.0.0.0' + tracker = Tracker(host, port=9000, port_end=10000, silent=True) + device_key = '$local$device$%d' % tracker.port + server = Server(host, port=tracker.port, port_end=10000, + key=device_key, + use_popen=True, silent=True, + tracker_addr=(tracker.host, tracker.port)) + rpc_runner = ansor.RPCRunner(device_key, host, tracker.port) + + search_common("cuda", 903667810, rpc_runner) + + tracker.terminate() + server.terminate() + else: + print("CUDA device not found, skip this test.") if __name__ == "__main__": test_search_basic() + test_search_opencl() + test_search_cuda() From 2bd6471d6cc3126bea111b373bbfc273dbf8e595 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 8 Jun 2020 01:10:27 -0700 Subject: [PATCH 12/45] rebase to upstream/master --- .gitignore | 1 + python/tvm/ansor/measure.py | 4 ++-- src/ansor/compute_dag.cc | 20 +++++++++---------- src/ansor/{feature.cc => feature.ccc} | 0 .../search_policy/meta_tile_rewrite_policy.cc | 4 ++-- .../search_policy/meta_tile_rewrite_policy.h | 4 ++-- src/ansor/search_policy/utils.h | 8 ++++---- 
.../python/unittest/test_ansor_compute_dag.py | 8 ++++++++ 8 files changed, 28 insertions(+), 21 deletions(-) rename src/ansor/{feature.cc => feature.ccc} (100%) diff --git a/.gitignore b/.gitignore index b9357018a64c..506e54d93067 100644 --- a/.gitignore +++ b/.gitignore @@ -196,6 +196,7 @@ tvm_t.* .python_history .pytest_cache .local +cmake-build-debug # Visual Studio Code .vscode diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index b80de7c01633..e10da09e4b5a 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -34,7 +34,7 @@ import tvm._ffi from tvm.runtime import Object, module, ndarray from tvm.driver import build_module -from tvm.target import build_config +from tvm.ir import transform from ..contrib import tar, ndk from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout, request_remote, check_remote from .compute_dag import LayoutRewriteLevel @@ -254,7 +254,7 @@ def timed_func(): dirname, "tmp_func." + build_func.output_format) try: - with build_config(unroll_max_extent=task.hardware_params.max_unroll_vec): + with transform.PassContext(): # todo(lmzheng): port the unroll pass func = build_module.build( sch, args, target=task.target, target_host=task.target_host) func.export_library(filename, build_func) diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index f3979ef0d259..de3b98a5106b 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -129,17 +129,19 @@ class TensorAccessExtractor : public StmtExprVisitor { this->VisitExpr(expr); } - void VisitExpr_(const CallNode *op) final { - if (op->call_type == CallNode::CallType::Halide) { - buf_accesses[Downcast(op->func)].emplace_back( - op->args.begin(), op->args.end()); - } + void VisitExpr_(const CallNode* op) final { if (op->name == tir::intrinsic::tvm_if_then_else) { has_branch = true; } StmtExprVisitor::VisitExpr_(op); } + void VisitExpr_(const ProducerLoadNode* op) final { + buf_accesses[Downcast(op->producer)->op].emplace_back( + op->indices.begin(), op->indices.end()); + StmtExprVisitor::VisitExpr_(op); + } + void VisitStmt_(const IfThenElseNode* op) final { has_branch = true; StmtExprVisitor::VisitStmt_(op); @@ -518,7 +520,7 @@ class FlopEstimator: public ExprFunctor { double VisitExpr_(const FloatImmNode* op) final { return 0.0; } double VisitExpr_(const IntImmNode* op) final { return 0.0; } -// double VisitExpr_(const UIntImm* op) final { return 0.0; } + double VisitExpr_(const ProducerLoadNode* op) final { return 0.0; } double VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); } double VisitExpr_(const VarNode* op) final { return 0.0; } @@ -545,11 +547,6 @@ class FlopEstimator: public ExprFunctor { VisitBinary(AndNode); VisitBinary(OrNode); VisitUnary(NotNode); double VisitExpr_(const CallNode* op) final { - if (op->call_type == CallNode::CallType::Halide) { - // ignore flops in index expressions - return 0.0; - } - double ret = 0.0; for (const auto&x : op->args) { ret += VisitExpr(x); @@ -557,6 +554,7 @@ class FlopEstimator: public ExprFunctor { return ret; } + double VisitExprDefault_(const Object* op) final { fail = true; return -1.0; diff --git a/src/ansor/feature.cc b/src/ansor/feature.ccc similarity index 100% rename from src/ansor/feature.cc rename to src/ansor/feature.ccc diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.cc b/src/ansor/search_policy/meta_tile_rewrite_policy.cc index c22d890a8b51..86a7eba1da3a 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.cc +++ 
b/src/ansor/search_policy/meta_tile_rewrite_policy.cc @@ -47,7 +47,7 @@ TVM_REGISTER_OBJECT_TYPE(MetaTileRewritePolicyNode); const std::vector MetaTileRewritePolicyNode::auto_unroll_configs{0, 16, 64, 512, 1024}; SearchPolicy MetaTileRewritePolicyNode::make(CostModel program_cost_model, - Map params, + Map params, int seed) { auto node = make_object(); node->program_cost_model = std::move(program_cost_model); @@ -1440,7 +1440,7 @@ void MetaTileRewritePolicyNode::EvolutionarySearch( TVM_REGISTER_GLOBAL("ansor.MetaTileRewritePolicy") .set_body_typed([](CostModel program_cost_model, - Map params, + Map params, int seed){ return MetaTileRewritePolicyNode::make(program_cost_model, params, seed); }); diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.h b/src/ansor/search_policy/meta_tile_rewrite_policy.h index 823ef6df4983..f92813b11273 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.h +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.h @@ -53,10 +53,10 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode { * str cpu_multi_level_tiling_structure // The structure of multi-level tiling for CPU * str gpu_multi_level_tiling_structure // The structure of multi-level tiling for GPU */ - Map params; + Map params; static SearchPolicy make(CostModel program_cost_model, - Map params, + Map params, int seed); // Search and make n_trails measurements diff --git a/src/ansor/search_policy/utils.h b/src/ansor/search_policy/utils.h index 607a549e1b8a..3d0611173c94 100644 --- a/src/ansor/search_policy/utils.h +++ b/src/ansor/search_policy/utils.h @@ -41,7 +41,7 @@ namespace tvm { namespace ansor { // Get an integer from a tvm str Map -inline int GetIntParam(const Map& attr_dict, +inline int GetIntParam(const Map& attr_dict, const std::string& key) { CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; auto pint = attr_dict[key].as(); @@ -50,7 +50,7 @@ inline int GetIntParam(const Map& attr_dict, } // Get a double from a tvm str Map -inline double GetDoubleParam(const Map& attr_dict, +inline double GetDoubleParam(const Map& attr_dict, const std::string& key) { CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; auto pdouble = attr_dict[key].as(); @@ -59,7 +59,7 @@ inline double GetDoubleParam(const Map& attr_dict, } // Get a string from a tvm str Map -inline std::string GetStringParam(const Map& attr_dict, +inline std::string GetStringParam(const Map& attr_dict, const std::string& key) { CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; @@ -73,7 +73,7 @@ inline std::string GetStringParam(const Map& attr_dict, } // Get a iterator name set from a tvm str Map -inline std::set GetIterNameSetParam(const Map& attr_dict, +inline std::set GetIterNameSetParam(const Map& attr_dict, const std::string& key) { std::set ret; CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; diff --git a/tests/python/unittest/test_ansor_compute_dag.py b/tests/python/unittest/test_ansor_compute_dag.py index 61eb0153a87c..b60136d4265f 100644 --- a/tests/python/unittest/test_ansor_compute_dag.py +++ b/tests/python/unittest/test_ansor_compute_dag.py @@ -43,6 +43,12 @@ def test_infer_bound(): assert s.stages[C_global].iters[0].range.extent == 64 +def test_estimate_flop(): + dag, s = get_tiled_matmul() + + assert abs(dag.flop_ct - 2 * 512 ** 3) < 0.5 + + def test_lower_legalize_invalid_attach(): N, M = 10, 10 @@ -63,4 +69,6 @@ def 
test_lower_legalize_invalid_attach():
 if __name__ == "__main__":
     test_apply_steps()
     test_infer_bound()
+    test_estimate_flop()
     test_lower_legalize_invalid_attach()
+

From c860f2c27f46733798c5deb488e5856f1d63d77c Mon Sep 17 00:00:00 2001
From: Chenfan
Date: Mon, 8 Jun 2020 21:04:42 +0800
Subject: [PATCH 13/45] Add Ansor basic tutorial (#13)

* Add basic tutorial
---
 docs/conf.py                            |   1 +
 tutorials/ansor/README.txt              |   4 +
 tutorials/ansor/tune_simple_subgraph.py | 204 ++++++++++++++++++++++++
 tutorials/autotvm/README.txt            |   4 +-
 4 files changed, 211 insertions(+), 2 deletions(-)
 create mode 100644 tutorials/ansor/README.txt
 create mode 100644 tutorials/ansor/tune_simple_subgraph.py

diff --git a/docs/conf.py b/docs/conf.py
index 7ece63bd7aa8..5cbaab7f7b6d 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -197,6 +197,7 @@
 ['../tutorials/frontend',
  '../tutorials/language',
  '../tutorials/optimize',
+ '../tutorials/ansor',
  '../tutorials/autotvm',
  '../tutorials/dev',
  '../tutorials/topi',
diff --git a/tutorials/ansor/README.txt b/tutorials/ansor/README.txt
new file mode 100644
index 000000000000..85b6ba401dae
--- /dev/null
+++ b/tutorials/ansor/README.txt
@@ -0,0 +1,4 @@
+.. _tutorial-ansor-auto-schedule:
+
+Ansor: Template Free Auto Scheduling
+------------------------------------
diff --git a/tutorials/ansor/tune_simple_subgraph.py b/tutorials/ansor/tune_simple_subgraph.py
new file mode 100644
index 000000000000..8555d6163c32
--- /dev/null
+++ b/tutorials/ansor/tune_simple_subgraph.py
@@ -0,0 +1,204 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Writing a Compute Expression and Using the Ansor Auto-scheduler
+===============================================================
+**Author**: `Lianmin Zheng `_, \
+            `Chengfan Jia `_, \
+            `Minmin Sun `_, \
+            `Zhao Wu `_
+
+This is an introduction tutorial to the auto-scheduler module in TVM.
+
+There are two steps in auto-scheduling.
+The first step is defining the target task.
+The second step is running a search algorithm to automatically explore the schedule space.
+In this tutorial, you can learn how to perform these two steps in TVM.
+The whole workflow is illustrated by a matrix multiplication with bias add example.
+"""
+
+######################################################################
+# Install dependencies
+# --------------------
+# To use the Ansor package in TVM, we need to install some extra dependencies.
+# This step (installing xgboost) can be skipped if you do not use the XGBoost
+# cost model (change "3" to "2" if you use python2):
+#
+# .. code-block:: bash
+#
+#   pip3 install --user psutil xgboost
+#
+# To make TVM run faster in tuning, it is recommended to use cython
+# as the FFI of TVM. In the root directory of TVM, execute
+# (change "3" to "2" if you use python2):
+#
+# .. code-block:: bash
+#
+#   pip3 install --user cython
+#   sudo make cython3
+#
+# Now return to python code. Import packages.
+
+import random
+import sys
+
+import numpy as np
+import tvm
+from tvm import te
+
+# the module is called `ansor`
+from tvm import ansor
+
+######################################################################
+# Step 1: Define the target compute subgraph
+# -------------------------------------------
+# In this section, we will write a deterministic TVM compute expression
+# that defines a compute subgraph.
+#
+# .. note:: Compared to :ref:`tutorials-autotvm-sec`
+#
+#  In Ansor, users do not need to provide a schedule template; the only input
+#  is the compute expression, written with the :code:`tvm.te` API or the topi op API.
+#
+# Here is how we implement a matrix multiplication subgraph in TVM.
+
+# Matmul with bias add
+def matmul_add(N, L, M, dtype):
+    A = te.placeholder((N, L), name='A', dtype=dtype)
+    B = te.placeholder((L, M), name='B', dtype=dtype)
+    C = te.placeholder((N, M), name='C', dtype=dtype)
+
+    k = te.reduce_axis((0, L), name='k')
+    mul = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
+                     name='Mul')
+    D = te.compute((N, M), lambda i, j: C[i, j] + mul[i, j], name='D')
+
+    return [A, B, C, D]
+
+######################################################################
+# Step 2: Search through the schedule space
+# ------------------------------------------
+# In step 1, we built the compute subgraph.
+# The next step is to pick a cost model as well as a search policy and explore
+# the possible schedules.
+#
+# Auto-scheduler in TVM
+# ^^^^^^^^^^^^^^^^^^^^^
+# The job of the Ansor auto-scheduler can be described by the following pseudo code
+#
+# .. code-block:: c
+#
+#   ct = 0
+#   while ct < max_number_of_trials:
+#     auto generate a batch of schedules
+#     measure this batch of schedules on real hardware and get results
+#     ct += batch_size
+#
+# When proposing the next batch of schedules, Ansor can use different cost models to
+# guide the schedule generation process.
+#
+# * :any:`RandomModel`: Generate and pick new schedules randomly
+# * :any:`XGBModel`: Use an XGBoost model to estimate the performance of potential schedules, and try to pick schedules with better performance in each step
+#
+# XGBModel can explore more efficiently and find better schedules.

+################################################################
+# Begin tuning
+# ^^^^^^^^^^^^
+# Here we continue our matrix multiplication example.
+#
+# The :code:`ansor.ComputeDAG` takes the tensor list as input and generates
+# a DAG structure. During this process, :code:`ansor.ComputeDAG` runs some
+# analyses on the target subgraph, and the results will be used by the
+# search policy later.
+#
+# Then we create the :code:`tvm.target` and a tuning task.

+N, L, M = 64, 64, 64
+A, B, C, D = matmul_add(N, L, M, 'float32')
+dag = ansor.ComputeDAG([A, B, C, D])
+
+print(dag)
+print(dag.access_analyzer)
+
+tgt = tvm.target.create("llvm")
+task = ansor.SearchTask(dag, "test", tgt)
+
+################################################################
+# Next, we choose the random cost model and create the default search policy:
+# :code:`ansor.MetaTileRewritePolicy`.
+#
+# We only run 5 trials in this tutorial for demonstration. In practice,
+# you can do more trials according to your time budget.
+# The :code:`ansor.LogToFile` callback will log the tuning results into a
+# log file, which can be used to get the best config later.
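+#
+# As a minimal sketch (assuming the log file already contains at least one
+# record, e.g. after the tuning below finishes), the logged records can also
+# be read back programmatically with :code:`ansor.LogReader`:
+#
+# .. code-block:: python
+#
+#   reader = ansor.LogReader(log_file)
+#   inputs, results = reader.read_lines(-1)  # -1 reads all records
+#   print("number of records:", len(inputs))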
+#
+# Then just call :code:`ansor.auto_schedule` and Ansor will try to find a high
+# performance schedule for the target subgraph automatically.

+log_file = "matmul_add.json"
+
+seed = 0
+random.seed(seed)
+cost_model = ansor.RandomModel()
+search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed)
+
+tune_option = ansor.TuneOption(n_trials=5,
+                               callbacks=[ansor.LogToFile(log_file)])
+
+state = ansor.auto_schedule(task, search_policy,
+                            tune_option=tune_option)
+print(state)
+
+#########################################################################
+# Finally, we apply the best schedule in history to get a TVM schedule.
+#
+# We can call the function :code:`apply_steps_from_state` directly using the returned
+# :code:`state` structure.
+# :code:`state` can also be used to print out user-friendly Python code on demand.
+#
+# Since we've recorded the tuning results to a file, we can also use the following
+# code to replay the best schedule from the log file:
+#
+# .. code-block:: python
+#
+#   inp, res = ansor.best_measure_pair_in_file(log_file)
+#   state = inp.state
+#   s, arg_bufs = dag.apply_steps_from_state(state)
+#
+# With the :code:`state` above, we can get the lowered result and its Python code:

+s, arg_bufs = dag.apply_steps_from_state(state)
+print("==== Get Lowered Stmt ====")
+print(tvm.lower(s, arg_bufs, simple_mode=True))
+print("==== Get Python Code ====")
+print(dag.print_python_code_from_state(state))
+
+#########################################################################
+# Check the correctness to make sure we generated a correct schedule.

+func = tvm.build(s, arg_bufs)
+
+# check correctness
+a_np = np.random.uniform(size=(N, L)).astype(np.float32)
+b_np = np.random.uniform(size=(L, M)).astype(np.float32)
+c_np = np.random.uniform(size=(N, M)).astype(np.float32)
+d_np = a_np.dot(b_np) + c_np
+
+d_tvm = tvm.nd.empty(d_np.shape)
+func(tvm.nd.array(a_np), tvm.nd.array(b_np), tvm.nd.array(c_np), d_tvm)
+
+tvm.testing.assert_allclose(d_np, d_tvm.asnumpy(), rtol=1e-2)
diff --git a/tutorials/autotvm/README.txt b/tutorials/autotvm/README.txt
index 38e3b3343f4e..4ad36c000e3c 100644
--- a/tutorials/autotvm/README.txt
+++ b/tutorials/autotvm/README.txt
@@ -1,4 +1,4 @@
 .. _tutorials-autotvm-sec:

-Auto tuning
------------
+AutoTVM: Template Based Auto Tuning
+-----------------------------------

From f60d1a60dc96ac408625a85f34eea099c78dd8eb Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Mon, 8 Jun 2020 06:28:31 -0700
Subject: [PATCH 14/45] migrate feature extraction (#14)

---
 python/tvm/ansor/__init__.py                |   3 +-
 python/tvm/ansor/feature.py                 | 147 +++++++
 src/ansor/{feature.ccc => feature.cc}       | 401 ++++++++++++++------
 src/ansor/feature.h                         |  35 +-
 tests/python/unittest/test_ansor_feature.py |  97 +++++
 5 files changed, 566 insertions(+), 117 deletions(-)
 create mode 100644 python/tvm/ansor/feature.py
 rename src/ansor/{feature.ccc => feature.cc} (79%)
 create mode 100644 tests/python/unittest/test_ansor_feature.py

diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py
index 7552878a3c50..3e9b76c2f6ad 100644
--- a/python/tvm/ansor/__init__.py
+++ b/python/tvm/ansor/__init__.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=unused-import, redefined-builtin
-"""Namespace for Ansor autoSchedule"""
+"""Namespace for Ansor auto-scheduler"""

 from . import compute_dag
 from . import measure
@@ -23,6 +23,7 @@
 from . import loop_state
 from . import task
 from . import utils
+from . import feature

 # Shortcut
 from .compute_dag import ComputeDAG
diff --git a/python/tvm/ansor/feature.py b/python/tvm/ansor/feature.py
new file mode 100644
index 000000000000..fb5fadf16296
--- /dev/null
+++ b/python/tvm/ansor/feature.py
@@ -0,0 +1,147 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Python API for feature extraction.
+The specification of features can be found in `autoscheduler_doc/per_stage_feature.md`
+"""
+
+from typing import List, Tuple
+import struct
+import numpy as np
+
+from .loop_state import StateObject
+from .task import SearchTask
+from .measure import MeasureInput, MeasureResult
+from . import _ffi_api
+
+
+DEFAULT_MAX_N_BUFS = 5
+
+DEFAULT_FEATURE_VEC_LEN = 164
+
+
+def unpack_feature(byte_arr: bytearray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    """Unpack the encoded feature (in byte array format) from c++"""
+    size_of_int = 4
+    size_of_float = 4
+
+    """
+    The format for n records is:
+    {
+        int n;
+        int[n+2] sizes
+
+        float[sizes[0]] feature for record 1
+        float[sizes[1]] feature for record 2
+        ... feature for record i...
+        float[sizes[n-1]] feature for record n
+
+        float[sizes[n]] normalized throughput for n records
+        int[sizes[n+1]] task id for n records
+    }
+    """
+    vec_len = DEFAULT_FEATURE_VEC_LEN
+
+    # unpack sizes
+    offset = 0
+    n = struct.unpack_from("1i", byte_arr, offset=offset)[0]
+    offset += size_of_int
+
+    sizes = struct.unpack_from("%di" % (n+2), byte_arr, offset=offset)
+    offset += size_of_int * (n+2)
+
+    # unpack features
+    features = []
+    for size in sizes[:-2]:
+        row = []
+
+        """
+        Now we need to unpack the feature for multiple statements.
+        The format is:
+        {
+            int n_stmts
+            float[n_stmt][vec_len] feature_vecs
+        }
+        where vec_len can be calculated by `(size - 1) / n_stmts`
+        """
+        if size == 0:
+            # failed during lowering
+            features.append(np.zeros((1, vec_len)))
+        else:
+            n_stmts = struct.unpack_from("f", byte_arr, offset=offset)
+            offset += size_of_float
+
+            n_stmts = int(n_stmts[0] + 0.5)
+            tmp_vec_len = (size - 1) // n_stmts
+            assert tmp_vec_len == vec_len, "The length of the feature vector is wrong. " \
+                                           "Expected %d but got %d."
% (vec_len, tmp_vec_len) + assert (size - 1) % n_stmts == 0 + for _ in range(n_stmts): + x = struct.unpack_from("%df" % vec_len, byte_arr, offset=offset) + offset += vec_len * size_of_float + row.append(x) + + features.append(np.array(row)) + + # unpack normalized_throughputs + m = sizes[-2] + normalized_throughputs = struct.unpack_from("%df" % m, byte_arr, offset=offset) + offset += m * size_of_int + + # unpack task_ids + m = sizes[-1] + task_ids = struct.unpack_from("%di" % m, byte_arr, offset=offset) + offset += m * size_of_int + + assert offset == len(byte_arr), "%d vs %d" % (offset, len(byte_arr)) + return np.array(features), np.array(normalized_throughputs), np.array(task_ids) + + +def get_per_stmt_features_from_file(filename: str, + n_lines: int, + max_n_bufs: int = None) \ + -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Get per_stmt features from a log file""" + byte_arr = _ffi_api.GetPerStmtFeaturesFromFile( + filename, n_lines, max_n_bufs or DEFAULT_MAX_N_BUFS) + return unpack_feature(byte_arr) + + +def get_per_stmt_features_from_measure_pairs(inputs: List[MeasureInput], + results: List[MeasureResult], + skip_first_n_feature_extraction: int = 0, + max_n_bufs: int = None,) \ + -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Get per_stmt features from measurement pairs""" + byte_arr = _ffi_api.GetPerStmtFeaturesFromMeasurePairs( + inputs, results, skip_first_n_feature_extraction, max_n_bufs or DEFAULT_MAX_N_BUFS) + return unpack_feature(byte_arr) + + +def get_per_stmt_features_from_states(states: List[StateObject], + task: SearchTask, + max_n_bufs: int = None) -> List[np.ndarray]: + """Get per_stmt features from states""" + byte_arr = _ffi_api.GetPerStmtFeaturesFromStates( + states, task, max_n_bufs or DEFAULT_MAX_N_BUFS) + return unpack_feature(byte_arr)[0] + + +def get_per_stmt_feature_names(max_n_bufs: int = None) -> List[str]: + """Get names of the elements in the flatten feature vector""" + return [x for x in + _ffi_api.GetPerStmtFeatureNames(max_n_bufs or DEFAULT_MAX_N_BUFS)] diff --git a/src/ansor/feature.ccc b/src/ansor/feature.cc similarity index 79% rename from src/ansor/feature.ccc rename to src/ansor/feature.cc index 31afe931361c..16ddb73ebf47 100644 --- a/src/ansor/feature.ccc +++ b/src/ansor/feature.cc @@ -1,5 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! 
- * Copyright (c) 2020 by Contributors + * \file ansor/feature.cc + * \brief Feature extraction for the cost model */ #include @@ -7,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -15,16 +36,12 @@ #include "measure.h" #include "serialization.h" #include "utils.h" -// #include "../arithmetic/compute_expr.h" namespace tvm { -/* Import the function from build_module.cc */ -extern void GetBinds(const Array& args, - bool compact, - const std::unordered_map& binds, - Map* out_binds, - Array* out_arg_list, - const BuildConfig& config); +/* Import the function from driver_api.cc */ +extern void GetBinds(const Array& args, bool compact, + const std::unordered_map& binds, + Map* out_binds, Array* out_arg_list); } // namespace tvm @@ -35,6 +52,9 @@ using namespace tvm::tir; using arith::ConstIntBound; using arith::Analyzer; +template +using BufferMap = std::unordered_map; + static const int ARITH_INTENSITY_CURVE_SAMPLE_N = 10; // Annotation position encoding @@ -61,7 +81,7 @@ enum ReuseType { // Feature for an access of a buffer struct BufferAccessFeature { - std::string tensor_name; + std::string buffer_name; BufferAccessType acc_type; float bytes; float unique_bytes; @@ -169,10 +189,11 @@ AnnotationPosType GetAnnotationPosEncoding( } if (find_ct == 0) { - // If not find in spatial args, then it is a reduce iteartor. + // If not find in spacial args, then it is a reduce iterator. // Use name to match + const std::string& var_name = var->name_hint; for (size_t i = 0; i < reduce_axis.size(); ++i) { - if (var->name_hint.find(reduce_axis[i]->var->name_hint) != std::string::npos) { + if (var_name.find(reduce_axis[i]->var->name_hint) != std::string::npos) { find_i = i; find_ct++; } @@ -238,7 +259,6 @@ class MathOpCounter : public StmtExprVisitor { void VisitExpr_(const NotNode* op) final { bool_op++; StmtExprVisitor::VisitExpr_(op); } void VisitExpr_(const SelectNode* op) final { select_op++; StmtExprVisitor::VisitExpr_(op); } - // TODO(...): CallNode with type CallNode::Halide has been modified to BufferLoadNode void VisitExpr_(const CallNode* op) final { if (op->call_type == CallNode::CallType::PureIntrinsic) { if (op->dtype.is_float()) { @@ -246,8 +266,8 @@ class MathOpCounter : public StmtExprVisitor { } else { int_math_func++; } - } else if (op->call_type != CallNode::CallType::Halide) { - if (op->dtype.is_float()) { + } else { + if (op->dtype.is_float()) { float_other_func++; } else { int_other_func++; @@ -272,42 +292,38 @@ class BufferAccessExtractor : public StmtExprVisitor { this->VisitExpr(expr); } - void InsertAccess(const te::Tensor& ten, BufferAccessType acc_type, + void InsertAccess(const Buffer& buf, BufferAccessType acc_type, const Array& indices) { - BufferAccess& acc = buf_accesses[ten]; + BufferAccess& acc = buf_accesses[buf]; acc.acc_type = acc_type; acc.indices.push_back(std::vector(indices.begin(), indices.end())); } - // TODO(...): CallNode with type CallNode::Halide has been modified to BufferLoadNode - void VisitExpr_(const CallNode *op) final { - if (op->call_type == CallNode::CallType::Halide) { - te::Tensor ten = Downcast(op->func).output(op->value_index); - BufferAccess& acc = buf_accesses[ten]; - switch (acc.acc_type) { - case kRead: - break; - case kWrite: - acc.acc_type = kReadWrite; break; - case kReadWrite: - break; - case kUnknownRW: - default: - acc.acc_type = kRead; break; - } + void VisitExpr_(const BufferLoadNode *op) final { + BufferAccess& acc = buf_accesses[op->buffer]; + switch (acc.acc_type) { + case kRead: + break; + case kWrite: 
+ acc.acc_type = kReadWrite; break; + case kReadWrite: + break; + case kUnknownRW: + default: + acc.acc_type = kRead; break; + } - if (acc.acc_type != kReadWrite) { - // If a buffer is both read and written, in the tvm DSL, it must be a update, - // so the indices should be the same. Then we can skip appending indices for it. - // Otherwise we do the following. - buf_accesses[ten].indices.push_back( - std::vector(op->args.begin(), op->args.end())); - } + if (acc.acc_type != kReadWrite) { + // If a buffer is both read and written, in the tvm DSL, it must be a update, + // so the indices should be the same. Then we can skip appending indices for it. + // Otherwise we do the following. + buf_accesses[op->buffer].indices.push_back( + std::vector(op->indices.begin(), op->indices.end())); } StmtExprVisitor::VisitExpr_(op); } - std::unordered_map buf_accesses; + BufferMap buf_accesses; }; // Compute coefficient for an loop iterator in an expression @@ -430,11 +446,11 @@ void ComputeRegion( // Compute reuse distance and reuse ratio for accesses to a buffer // return values: reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct std::tuple ComputeReuse( - const te::Tensor& t, + const Buffer& buf, const std::vector >& indices, const std::vector& for_loop_stack, - const std::unordered_map > > >& for_touch_regions) { + const std::unordered_map > > >& for_touch_regions) { float reuse_dis_iter = 1.0f; float reuse_dis_bytes = -1.0f; @@ -479,16 +495,16 @@ std::tuple ComputeReuse( return std::make_tuple(kLoopMultipleRead, reuse_dis_iter, reuse_dis_bytes, extent); } - const std::unordered_map > >& - tensor_map = for_touch_regions.at(cur_for); + const BufferMap > >& buffer_map + = for_touch_regions.at(cur_for); - int serial_reuse = static_cast(tensor_map.at(t).size()) - 1; + int serial_reuse = static_cast(buffer_map.at(buf).size()) - 1; if (serial_reuse > 0) { int64_t extent = GetIntImm(cur_for->extent); // Have SerialMultipleReadWrite reuse reuse_dis_iter = std::numeric_limits::max(); - for (const auto& acc_info : tensor_map.at(t)) { + for (const auto& acc_info : buffer_map.at(buf)) { reuse_dis_iter = std::min(reuse_dis_iter, static_cast(std::get<1>(acc_info))); } @@ -600,13 +616,8 @@ class PerStmtFeatureExtractor : public StmtExprVisitor { } } - // TODO(...): ProvideNode is deprecated, move to BufferStoreNode - void VisitStmt_(const ProvideNode* node) final { - te::Operation op = Downcast(node->func); - te::Tensor ten = op.output(node->value_index); - const te::ComputeOpNode* pcompute = op.as(); - - FeatureSet &fea = op_features[ten]; + void VisitStmt_(const BufferStoreNode* node) final { + FeatureSet &fea = buffer_features[node->buffer]; // compute feature MathOpCounter mathops; @@ -641,8 +652,10 @@ class PerStmtFeatureExtractor : public StmtExprVisitor { for (const ForNode* pfor : vec_for_stack) { fea.vec_prod *= GetIntImm(pfor->extent); } - fea.vec_type = GetAnnotationPosEncoding(vec_for_stack.back()->loop_var, - node->args, pcompute->axis, pcompute->reduce_axis); + fea.vec_type = kPosMixed; + // todo(lmzheng): this feature requires operation (tvm.compute) information + //GetAnnotationPosEncoding(vec_for_stack.back()->loop_var, + //node->args, pcompute->axis, pcompute->reduce_axis); } fea.unroll_num = unroll_for_stack.size(); @@ -652,8 +665,9 @@ class PerStmtFeatureExtractor : public StmtExprVisitor { for (const ForNode* pfor : unroll_for_stack) { fea.unroll_prod *= GetIntImm(pfor->extent); } - fea.unroll_type = GetAnnotationPosEncoding(unroll_for_stack.back()->loop_var, - node->args, pcompute->axis, 
pcompute->reduce_axis); + fea.unroll_type = kPosMixed; + //GetAnnotationPosEncoding(unroll_for_stack.back()->loop_var, + //node->args, pcompute->axis, pcompute->reduce_axis); } fea.parallel_num = parallel_for_stack.size(); @@ -663,8 +677,9 @@ class PerStmtFeatureExtractor : public StmtExprVisitor { for (const ForNode* pfor : parallel_for_stack) { fea.parallel_prod *= GetIntImm(pfor->extent); } - fea.parallel_type = GetAnnotationPosEncoding(parallel_for_stack.back()->loop_var, - node->args, pcompute->axis, pcompute->reduce_axis); + fea.parallel_type = kPosMixed; + //GetAnnotationPosEncoding(parallel_for_stack.back()->loop_var, + //node->args, pcompute->axis, pcompute->reduce_axis); } // GPU threads @@ -680,13 +695,13 @@ class PerStmtFeatureExtractor : public StmtExprVisitor { // Extract all buffer access std::vector acc_feas; BufferAccessExtractor buf_extractor; - buf_extractor.InsertAccess(ten, kWrite, node->args); + buf_extractor.InsertAccess(node->buffer, kWrite, node->indices); buf_extractor.ExtractReads(node->value); // Compute touched region for all outer loops Analyzer ana; for (auto x : for_loop_stack) { - ana.Bind(x->loop_var, Range::make_by_min_extent(x->min, 1)); + ana.Bind(x->loop_var, Range::make_by_min_extent(x->min, 1), true); } std::vector mem_bytes_list; @@ -704,22 +719,22 @@ class PerStmtFeatureExtractor : public StmtExprVisitor { const ForNode* p_for = for_loop_stack[i]; ana.Bind(p_for->loop_var, - Range::make_by_min_extent(for_loop_stack[i]->min, for_loop_stack[i]->extent)); + Range::make_by_min_extent(for_loop_stack[i]->min, for_loop_stack[i]->extent), true); // Note, here we do overwrite. // So if there are multiple Provides, the last one will overwrite the first few. // e.g. The update part in gemm will overwrite the init part. 
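      // (For illustration, assuming a standard gemm lowering: both the init
      //  statement C[i, j] = 0 and the update statement
      //  C[i, j] = C[i, j] + A[i, k] * B[k, j] store to the same buffer C;
      //  the update statement is visited last, so its touched region is the
      //  one that remains recorded for the enclosing loops.)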
- std::unordered_map > >& - tensor_regions_map = for_touch_regions[p_for]; + BufferMap > >& + buffer_regions_map = for_touch_regions[p_for]; int64_t mem_bytes = 0; for (const auto &x : buf_extractor.buf_accesses) { - const te::Tensor& t = x.first; + const Buffer& t = x.first; const BufferAccess& acc = x.second; ComputeRegion(acc.indices, &ana, &tmp_region); int64_t touched_size = ElementProduct(tmp_region); - tensor_regions_map[t].push_back(std::make_tuple(acc.acc_type, + buffer_regions_map[t].push_back(std::make_tuple(acc.acc_type, touched_size, t->dtype.bytes())); mem_bytes += touched_size * t->dtype.bytes(); } @@ -759,7 +774,7 @@ class PerStmtFeatureExtractor : public StmtExprVisitor { // Compute buffer access feature for (const auto &x : buf_extractor.buf_accesses) { - const te::Tensor& t = x.first; + const Buffer& t = x.first; const BufferAccess& acc = x.second; std::vector int_shape; @@ -826,7 +841,7 @@ class PerStmtFeatureExtractor : public StmtExprVisitor { acc_feas.emplace_back(); BufferAccessFeature& acc_fea = acc_feas.back(); - acc_fea.tensor_name = t->op->func_name(); + acc_fea.buffer_name = t->name; acc_fea.acc_type = acc.acc_type; acc_fea.stride = stride; acc_fea.bytes = bytes; @@ -854,21 +869,17 @@ class PerStmtFeatureExtractor : public StmtExprVisitor { fea.access_feas = acc_feas; } - // TODO(...): RealizeNode is deprecated, move to BufferRealizeNode - void VisitStmt_(const RealizeNode *node) final { + void VisitStmt_(const BufferRealizeNode *node) final { StmtExprVisitor::VisitStmt_(node); - te::Operation op = Downcast(node->func); - te::Tensor ten = op.output(node->value_index); - - FeatureSet& fea = op_features[ten]; + FeatureSet& fea = buffer_features[node->buffer]; float allocation_size = 1.0f; for (const auto& x : node->bounds) { allocation_size *= GetIntImm(x->extent); } // allocation feature - fea.alloc_size = allocation_size * ten->dtype.bytes(); + fea.alloc_size = allocation_size * node->buffer->dtype.bytes(); fea.alloc_prod = allocation_size * outer_loop_prod; fea.alloc_outer_prod = outer_loop_prod; fea.alloc_inner_prod = fea.outer_prod / outer_loop_prod; @@ -891,12 +902,12 @@ class PerStmtFeatureExtractor : public StmtExprVisitor { int vthread_len{1}; int16_t cur_auto_unroll_max_step{0}; - std::unordered_map op_features; + BufferMap buffer_features; - // for a loop, for all its touched tensors, for all different accesses to the tensors, + // for a loop, for all its touched buffers, for all different accesses to the buffers, // its (access type, number of touched elements, number of bytes of single element) - std::unordered_map > > > for_touch_regions; + std::unordered_map > > > for_touch_regions; private: const int cache_line_size_ = 64; @@ -913,14 +924,12 @@ void GetPerStmtFeature(const Stmt& stmt, int cache_line_size, int max_n_bufs, std::vector* ret) { - LOG(WARNING) << "RealizeNode & ProvideNode deprecated, " - << "need to fix the implementation of PerStmtFeatureExtractor."; PerStmtFeatureExtractor extractor(cache_line_size); extractor(stmt); - ret->push_back(extractor.op_features.size()); + ret->push_back(extractor.buffer_features.size()); - for (const auto& x : extractor.op_features) { + for (const auto& x : extractor.buffer_features) { const FeatureSet& fea_set = x.second; /***** compute feature *****/ @@ -1148,33 +1157,49 @@ void GetPerStmtFeaturesWorkerFunc(const SearchTask& task, const State& state, int max_n_bufs, std::vector* feature, std::atomic* error_ct) { te::Schedule sch; Array tensors; - Map bounds; - GlobalVar g("main"); std::tie(sch, tensors) 
= task->compute_dag.ApplySteps(state->transform_steps); sch = sch.normalize(); - bounds = te::InferBound(sch); + auto bounds = te::InferBound(sch); try { auto stmt = te::ScheduleOps(sch, bounds, false); Map out_binds; Array out_arg_list; bool compact = te::VerifyCompactBuffer(stmt); + const std::string& name = "main"; + GlobalVar global_var(name); + + // Copied from driver_api.cc::lower + auto pass_ctx = tvm::transform::PassContext::Current(); GetBinds(tensors, compact, std::unordered_map(), - &out_binds, &out_arg_list, BuildConfig::Create()); - tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, - std::move(stmt), out_binds); - f = WithAttr(std::move(f), "global_symbol", runtime::String("main")); - auto mod = IRModule(Map({{g, f}})); - auto pass_list = Array(); + &out_binds, &out_arg_list); + tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); + f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); + + bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); + bool disable_vectorize = + pass_ctx->GetConfig("tir.disable_vectorize", Bool(false)).value(); + bool instrument_bound_checkers = + pass_ctx->GetConfig("tir.instrument_bound_checkers", Bool(false)).value(); + + if (noalias) { + f = WithAttr(std::move(f), "tir.noalias", Bool(true)); + } + auto mod = IRModule(Map({{global_var, f}})); + if (task->target->device_type == kDLGPU) { + auto pass_list = Array(); + // Phase 0 pass_list.push_back(tir::transform::InjectPrefetch()); - pass_list.push_back(tir::transform::StorageFlatten(64)); + pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); + // Phase 1 + pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); - pass_list.push_back(tir::transform::VectorizeLoop()); + pass_list.push_back(tir::transform::VectorizeLoop(disable_vectorize)); pass_list.push_back(tir::transform::InjectVirtualThread()); pass_list.push_back(tir::transform::StorageRewrite()); pass_list.push_back(tir::transform::Simplify()); - tvm::Map gpu_params { + tvm::Map gpu_params { {"max_shared_memory_per_block", task->hardware_params->max_shared_memory_per_block}, {"max_local_memory_per_block", @@ -1188,11 +1213,9 @@ void GetPerStmtFeaturesWorkerFunc(const SearchTask& task, const State& state, const auto& optimize = tir::transform::Sequential(pass_list); optimize(mod); } - pass_list.clear(); - pass_list.push_back(tir::transform::Simplify()); - const auto& optimize = tir::transform::Sequential(pass_list); + const auto& optimize = tir::transform::Sequential(Array{tir::transform::Simplify()}); mod = optimize(std::move(mod)); - const auto& it = mod->functions.find(g); + const auto& it = mod->functions.find(global_var); CHECK(it != mod->functions.end()); const auto& prim_func = (*it).second.as(); GetPerStmtFeature(prim_func->body, @@ -1205,8 +1228,8 @@ void GetPerStmtFeaturesWorkerFunc(const SearchTask& task, const State& state, void GetPerStmtFeaturesFromStates(const Array& states, const SearchTask& task, - int max_n_bufs, int skip_first_n_feature_extraction, + int max_n_bufs, std::vector >* features) { // extract features features->assign(states.size(), std::vector()); @@ -1230,8 +1253,8 @@ void GetPerStmtFeaturesFromStates(const Array& states, void GetPerStmtFeaturesFromStates(const Array& states, const std::vector& tasks, - int max_n_bufs, int skip_first_n_feature_extraction, + int max_n_bufs, std::vector >* features) { // extract features features->assign(states.size(), 
std::vector()); @@ -1314,13 +1337,13 @@ void GetPerStmtFeaturesFromFile(const std::string& filename, (*normalized_throughputs)[i] = min_costs[(*task_ids)[i]] / (*normalized_throughputs)[i]; } - GetPerStmtFeaturesFromStates(states, tasks, max_n_bufs, 0, features); + GetPerStmtFeaturesFromStates(states, tasks, 0, max_n_bufs, features); } void GetPerStmtFeaturesFromMeasurePairs(const Array& inputs, const Array& results, - int max_n_bufs, int skip_first_n_feature_extraction, + int max_n_bufs, std::vector >* features, std::vector* normalized_throughputs, std::vector* task_ids) { @@ -1379,9 +1402,173 @@ void GetPerStmtFeaturesFromMeasurePairs(const Array& inputs, (*normalized_throughputs)[i] = min_costs[(*task_ids)[i]] / (*normalized_throughputs)[i]; } - GetPerStmtFeaturesFromStates(states, tasks, max_n_bufs, - skip_first_n_feature_extraction, features); + GetPerStmtFeaturesFromStates(states, tasks, skip_first_n_feature_extraction, + max_n_bufs, features); } +TVMByteArray SerializeFeatures(std::vector >&& features, + std::vector&& normalized_throughputs, + std::vector&& task_ids, + std::vector* out_data) { + size_t total_bytes = 0; + std::vector size_vector; + + int n = features.size(); + + // serialize sizes + size_t size_vector_size = 1 + n + 2; + total_bytes += size_vector_size * sizeof(int); + + size_vector.reserve(size_vector_size); + size_vector.push_back(features.size()); + for (const auto& x : features) { + size_vector.push_back(static_cast(x.size())); + total_bytes += sizeof(float) * x.size(); + } + size_vector.push_back(static_cast(normalized_throughputs.size())); + total_bytes += sizeof(float) * normalized_throughputs.size(); + size_vector.push_back(static_cast(task_ids.size())); + total_bytes += sizeof(int) * task_ids.size(); + + CHECK_EQ(size_vector.size(), size_vector_size); + + // allocate memory + out_data->reserve(total_bytes); + char* ptr = out_data->data(); + + // serialize size_vector + memmove(ptr, reinterpret_cast(size_vector.data()), size_vector.size() * sizeof(int)); + ptr += size_vector.size() * sizeof(int); + + // serialize features + for (auto& x : features) { + memmove(ptr, x.data(), sizeof(float) * x.size()); + ptr += sizeof(float) * x.size(); + x.clear(); + } + + // serialize normalized_throughputs + memmove(ptr, reinterpret_cast(normalized_throughputs.data()), + normalized_throughputs.size() * sizeof(int)); + ptr += normalized_throughputs.size() * sizeof(int); + + // serialize task_ids + memmove(ptr, reinterpret_cast(task_ids.data()), task_ids.size() * sizeof(int)); + ptr += task_ids.size() * sizeof(int); + + CHECK_EQ(ptr - out_data->data(), total_bytes); + + return TVMByteArray{out_data->data(), total_bytes}; +} + + +TVM_REGISTER_GLOBAL("ansor.GetPerStmtFeaturesFromFile") +.set_body([](TVMArgs args, TVMRetValue *ret) { + std::string filename = args[0]; + int n_lines = args[1]; + int max_n_bufs = args[2]; + + std::vector > features; + std::vector normalized_throughputs; + std::vector task_ids; + + GetPerStmtFeaturesFromFile(filename, n_lines, max_n_bufs, + &features, &normalized_throughputs, &task_ids); + + // serialization format for n records: + // + // int n; + // int[n+2] sizes + // + // float[sizes[0]] feature for record 1 + // float[sizes[1]] feature for record 2 + // ... feature for record i... 
+ // float[sizes[n-1]] feature for record n + // + // float[sizes[n]] normalized throughput for n records + // int[sizes[n+1]] task id for n records + + std::vector byte_data; + *ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs), + std::move(task_ids), &byte_data); +}); + +TVM_REGISTER_GLOBAL("ansor.GetPerStmtFeaturesFromMeasurePairs") +.set_body([](TVMArgs args, TVMRetValue *ret) { + Array inputs = args[0]; + Array results = args[1]; + int skip_first_n_feature_extraction = args[2]; + int max_n_bufs = args[3]; + + std::vector > features; + std::vector normalized_throughputs; + std::vector task_ids; + + GetPerStmtFeaturesFromMeasurePairs(inputs, results, skip_first_n_feature_extraction, max_n_bufs, + &features, &normalized_throughputs, &task_ids); + + // serialization format for n records: + // + // int n; + // int[n+2] sizes + // + // float[sizes[0]] feature for record 1 + // float[sizes[1]] feature for record 2 + // ... feature for record i... + // float[sizes[n-1]] feature for record n + // + // float[sizes[n]] normalized throughput for n records + // int[sizes[n+1]] task id for n records + + std::vector byte_data; + *ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs), + std::move(task_ids), &byte_data); +}); + +TVM_REGISTER_GLOBAL("ansor.GetPerStmtFeaturesFromStates") +.set_body([](TVMArgs args, TVMRetValue *ret) { + Array states = args[0]; + SearchTask task = args[1]; + int max_n_bufs = args[2]; + + std::vector > features; + std::vector normalized_throughputs; + std::vector task_ids; + + GetPerStmtFeaturesFromStates(states, task, 0, max_n_bufs, &features); + + // serialization format for n records: + // + // int n; + // int[n+2] sizes + // + // float[sizes[0]] feature for record 1 + // float[sizes[1]] feature for record 2 + // ... feature for record i... + // float[sizes[n-1]] feature for record n + // + // float[sizes[n]] normalized throughput for n records + // int[sizes[n+1]] task id for n records + + std::vector byte_data; + *ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs), + std::move(task_ids), &byte_data); +}); + +TVM_REGISTER_GLOBAL("ansor.GetPerStmtFeatureNames") + .set_body([](TVMArgs args, TVMRetValue *ret) { + int max_n_bufs = args[0]; + std::vector names; + + GetPerStmtFeatureName(max_n_bufs, &names); + + Array arr; + for (const auto& x : names) { + arr.push_back(x); + } + *ret = arr; +}); + + } // namespace ansor } // namespace tvm diff --git a/src/ansor/feature.h b/src/ansor/feature.h index 149c59e8cb7d..e507149643e2 100644 --- a/src/ansor/feature.h +++ b/src/ansor/feature.h @@ -1,13 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! 
- * Copyright (c) 2020 by Contributors - * \file ansor/search_task.h - * \brief Meta inforamtion for a search task + * \file ansor/feature.h + * \brief Feature extraction for the cost model */ #ifndef TVM_ANSOR_FEATURE_H_ #define TVM_ANSOR_FEATURE_H_ -// #include #include #include #include "compute_dag.h" @@ -26,18 +43,18 @@ void GetPerStmtFeature(const Stmt& stmt, void GetPerStmtFeatureName(int max_n_bufs, std::vector *ret); -/*! \brief Get PerStmt feature from states */ +/*! \brief Get PerStmt feature from states and the same task */ void GetPerStmtFeaturesFromStates(const Array& states, const SearchTask& task, - int max_n_bufs, int skip_first_n_feature_extraction, + int max_n_bufs, std::vector >* features); -/*! \brief Get PerStmt feature from states */ +/*! \brief Get PerStmt feature from states and different tasks */ void GetPerStmtFeaturesFromStates(const Array& states, const std::vector& tasks, - int max_n_bufs, int skip_first_n_feature_extraction, + int max_n_bufs, std::vector >* features); /*! \brief Get PerStmt feature from a log file */ @@ -51,8 +68,8 @@ void GetPerStmtFeaturesFromFile(const std::string& filename, /*! \brief Get PerStmt feature from measure pairs */ void GetPerStmtFeaturesFromMeasurePairs(const Array& inputs, const Array& results, - int max_n_bufs, int skip_first_n_feature_extraction, + int max_n_bufs, std::vector >* features, std::vector* normalized_throughputs, std::vector* task_ids); diff --git a/tests/python/unittest/test_ansor_feature.py b/tests/python/unittest/test_ansor_feature.py new file mode 100644 index 000000000000..abd304a9c2d7 --- /dev/null +++ b/tests/python/unittest/test_ansor_feature.py @@ -0,0 +1,97 @@ +"""Test feature extraction""" + +import math +import tempfile + +import tvm +from tvm import te, ansor + +from test_ansor_common import matmul_nkkm + + +def fequal(a, b): + return math.fabs(a - b) < 1e-6 + + +def test_cpu_matmul(): + dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) + s = dag.get_init_state() + C = 2 + + i, j, k = s.stages[C].iters + io, ii = s.split(C, i, [16]) + jo, ji = s.split(C, j, [8]) + s.reorder(C, [io, jo, k, ji, ii]) + s.vectorize(C, ji) + s.parallel(C, io) + s.parallel(C, jo) + s.unroll(2, k) + + target = tvm.target.create('llvm') + task = ansor.SearchTask(dag, "test", target) + names = ansor.feature.get_per_stmt_feature_names() + fea = ansor.feature.get_per_stmt_features_from_states([s.state_object], task)[0] + + stage_0 = fea[0] + assert len(stage_0) == len(names), "%d vs %d" % (len(stage_0), len(names)) + fea_dict = {} + for name, value in zip(names, stage_0): + fea_dict[name] = value + + for name in ["B0", "B1", "B2"]: + if fequal(fea_dict[name + ".acc_type.kReadWrite"], 1.0): + c_name = name + if fequal(fea_dict[name + ".acc_type.kRead"], 1.0): + if fequal(fea_dict[name + ".stride"], 0.0): + b_name = name + else: + a_name = name + + assert fequal(fea_dict[c_name + ".bytes"], math.log2(512 ** 3 * 4 + 1)) + assert fequal(fea_dict[b_name + ".unique_bytes"], math.log2(512 ** 2 * 4 + 1)) + assert fequal(fea_dict[c_name + ".reuse_dis_iter"], math.log2(8 * 16 + 1)) + assert fequal(fea_dict[c_name + ".reuse_dis_bytes"], math.log2((8 * 16 + 8 + 16) * 4 + 1)) + assert fequal(fea_dict[c_name + ".reuse_ct"], math.log2(512 + 1)) + + assert fequal(fea_dict["unroll_num"], math.log2(1 + 1)) + # assert fequal(fea_dict["unroll_type.kPosInnerReduce"], 1.0) + assert fequal(fea_dict["vec_num"], math.log2(1 + 1)) + assert fequal(fea_dict["parallel_num"], math.log2(2 + 1)) + assert fequal(fea_dict["parallel_prod"], math.log2((512 * 
512 / 16 / 8) + 1)) + + +def test_cpu_fusion(): + def fusion_test(N, M): + A = te.placeholder((N, M), name='A') + B = te.compute((N, M), lambda i, j: A[i][j], name='B') + C = te.compute((N, M), lambda i, j: B[i][j], name='C') + return [A, B, C] + + dag = ansor.ComputeDAG(fusion_test(64, 32)) + s = dag.get_init_state() + s.compute_at(1, 2, s.stages[2].iters[1]) + + target = tvm.target.create('llvm') + task = ansor.SearchTask(dag, "test", target) + names = ansor.feature.get_per_stmt_feature_names() + fea = ansor.feature.get_per_stmt_features_from_states([s.state_object], task)[0] + + found = False + for stage_fea in fea: + for i, (name, value) in enumerate(zip(names, stage_fea)): + if 'reuse_type.kSerialMultipleReadWrite' in name and value > 0.5: + assert fequal(stage_fea[i + 2], 1.0) + assert fequal(stage_fea[i + 3], math.log2(16 + 1)) + found = True + assert found + + +def test_gpu_feature(): + # todo(lmzheng) + pass + + +if __name__ == "__main__": + test_cpu_matmul() + test_cpu_fusion() + test_gpu_feature() From b839c0f6b8c45f4dcfdd96a7a60338b40387c5d4 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Tue, 9 Jun 2020 13:55:15 +0800 Subject: [PATCH 15/45] Add XGBModel & RPCRunnerWarpper (#15) * Add XGBModel & RPCRunnerWarpper * Revert "Add Parallel Granularity Mutation" --- python/tvm/ansor/__init__.py | 3 +- python/tvm/ansor/cost_model/cost_model.py | 29 ++ python/tvm/ansor/cost_model/xgb_model.py | 476 ++++++++++++++++++ python/tvm/ansor/measure.py | 48 ++ src/ansor/cost_model/cost_model.cc | 29 +- src/ansor/cost_model/cost_model.h | 6 +- .../search_policy/meta_tile_rewrite_policy.cc | 27 +- .../unittest/test_ansor_search_policy.py | 53 +- 8 files changed, 607 insertions(+), 64 deletions(-) create mode 100644 python/tvm/ansor/cost_model/xgb_model.py diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index 3e9b76c2f6ad..2d27995e328e 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -29,6 +29,7 @@ from .compute_dag import ComputeDAG from .task import SearchTask, MetaTileRewritePolicy, TuneOption from .task import auto_schedule -from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner +from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, RPCRunnerWarpper from .cost_model import RandomModel +from .cost_model.xgb_model import XGBModel from .serialization import LogToFile, LogReader, best_measure_pair_in_file diff --git a/python/tvm/ansor/cost_model/cost_model.py b/python/tvm/ansor/cost_model/cost_model.py index a0e586d69cec..fd9b67927185 100644 --- a/python/tvm/ansor/cost_model/cost_model.py +++ b/python/tvm/ansor/cost_model/cost_model.py @@ -42,3 +42,32 @@ def random_number(n, return_ptr): return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_float)) array_wrapper = np.ctypeslib.as_array(return_ptr, shape=(n,)) array_wrapper[:] = np.random.uniform(0, 1, (n,)) + +@tvm._ffi.register_object("ansor.PythonBasedModel") +class PythonBasedModel(CostModel): + def __init__(self): + def update_func(inputs, results): + self.update(inputs, results) + + def predict_func(task, states, return_ptr): + return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_float)) + array_wrapper = np.ctypeslib.as_array(return_ptr, shape=(len(states),)) + array_wrapper[:] = self.predict(task, states) + + def predict_stage_func(task, states, return_ptr): + ret = self.predict_stages(task, states) + return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_float)) + array_wrapper = np.ctypeslib.as_array(return_ptr, shape=ret.shape) 
+            array_wrapper[:] = ret
+
+        self.__init_handle_by_constructor__(_ffi_api.PythonBasedModel, update_func,
+                                            predict_func, predict_stage_func)
+
+    def update(self, inputs, results):
+        raise NotImplementedError
+
+    def predict(self, task, states):
+        raise NotImplementedError
+
+    def predict_stages(self, task, states):
+        raise NotImplementedError
diff --git a/python/tvm/ansor/cost_model/xgb_model.py b/python/tvm/ansor/cost_model/xgb_model.py
new file mode 100644
index 000000000000..e61acfbd168f
--- /dev/null
+++ b/python/tvm/ansor/cost_model/xgb_model.py
@@ -0,0 +1,476 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Cost model based on xgboost"""
+from typing import List
+import multiprocessing
+import logging
+import time
+from collections import defaultdict
+
+import numpy as np
+import xgboost as xgb
+
+from ...autotvm.tuner.xgboost_cost_model import get_rank, recall_curve, max_curve
+from .cost_model import PythonBasedModel
+from ..feature import get_per_stmt_features_from_measure_pairs, get_per_stmt_features_from_states
+from ..serialization import LogReader
+
+logger = logging.getLogger('ansor')
+
+class XGBDMatrixContext:
+    """Context to hold additional attributes of xgb.DMatrix"""
+    def __init__(self):
+        self.context_dict = defaultdict(dict)
+
+    def get(self, key, matrix, default=None):
+        return self.context_dict[key].get(matrix.handle.value, default)
+
+    def put(self, key, matrix, value):
+        self.context_dict[key][matrix.handle.value] = value
+
+dmatrix_context = XGBDMatrixContext()
+
+class XGBModel(PythonBasedModel):
+    """Train an XGBoost model to predict the runtime cost of a program.
+    The cost of a program = the sum of the costs of all stages in this program.
+    i.e. Cost(p) = cost_s0 + cost_s1 + ... + cost_sn, where cost_si is the cost of Stage i
+
+    The xgboost model makes predictions per stage, then we sum them up.
+    The final prediction made by this class is the normalized throughput
+    (from 0 to 1, larger is better).
+
+    To support this stage decomposition, we have to implement a custom loss function for
+    XGBoost, which is the `pack_sum` in the code below.
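+
+    As an illustrative example (assuming the default `LOSS_TYPE == 3` defined
+    later in this file): if a program has two stages with raw per-stage
+    predictions x0 and x1, the predicted normalized throughput is -(x0 + x1),
+    and the training loss against the measured throughput y is
+    loss = 1/2 * (-(x0 + x1) - y) ^ 2.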
+ """ + def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None): + self.xgb_params = { + 'max_depth': 10, + 'gamma': 0.001, + 'min_child_weight': 0, + 'eta': 0.2, + # todo(lmzheng): automatically decrease learning rate when the loss is too large + + 'n_gpus': 0, + 'n_threads': multiprocessing.cpu_count() / 2, + 'silent': 0, + 'seed': seed or 43, + 'disable_default_eval_metric': 1 + } + self.bst = None + self.plan_size = 32 + self.num_warmup_sample = num_warmup_sample + self.verbose_eval = verbose_eval + + super().__init__() + + # measurement input/result pairs + self.inputs = [] + self.results = [] + self.inputs_feature_cache = [] + + def update(self, inputs, results): + if len(inputs) <= 0: + return + + self.inputs.extend(inputs) + self.results.extend(results) + + # extract feature + n_cached = len(self.inputs_feature_cache) + features, normalized_throughputs, task_ids = \ + get_per_stmt_features_from_measure_pairs(self.inputs, self.results, + skip_first_n_feature_extraction=n_cached) + if n_cached > 0: + features = list(features) + features[:n_cached] = self.inputs_feature_cache + features = np.array(features) + self.inputs_feature_cache = features + dtrain = pack_sum_xgbmatrix(features, normalized_throughputs, task_ids, normalized_throughputs) + + # train xgb model + self.bst = xgb.train(self.xgb_params, dtrain, + num_boost_round=10000, + obj=pack_sum_square_error, + callbacks=[custom_callback( + stopping_rounds=50, + metric='tr-p-rmse', + fevals=[ + pack_sum_rmse, pack_sum_average_peak_score(self.plan_size), + ], + evals=[(dtrain, 'tr')], + maximize=False, + verbose_eval=self.verbose_eval)]) + + def predict(self, task, states): + features = get_per_stmt_features_from_states(states, task) + if self.bst is not None and len(self.inputs) > self.num_warmup_sample: + dtest, pack_ids = pack_sum_xgbmatrix_for_prediction(features) + raw_preds = self.bst.predict(dtest) + ret = pack_sum_predict_throughput(raw_preds, pack_ids) + else: + ret = np.random.uniform(0, 1, (len(states),)) + + # Predict 0 for invalid states that failed to be lowered. + for idx, feature in enumerate(features): + if feature.min() == feature.max() == 0: + ret[idx] = float('-inf') + + return ret + + def predict_stages(self, task, states): + # Format: (s0 score, ..., sN score, s0 n_stage, s0 stage 0, ..., s1 n_stage, s1 stage 0,) + + features = get_per_stmt_features_from_states(states, task) + if self.bst is not None and len(self.inputs) > self.num_warmup_sample: + dtest, pack_ids = pack_sum_xgbmatrix_for_prediction(features) + raw_preds = self.bst.predict(dtest) + breakdown = pack_sum_predict_throughput(raw_preds, pack_ids) + stage_scores = [[] for _ in range(len(states))] + for pred, pack_id in zip(raw_preds, pack_ids): + stage_scores[pack_id].append(pred) + for idx, stage_score in enumerate(stage_scores): + breakdown = np.append(breakdown, len(stage_score)) + breakdown = np.concatenate((breakdown, -np.array(stage_score))) + else: + breakdown = np.concatenate( + (np.random.uniform(0, 1, (len(states), )), np.zeros(len(states), ))) + + # Predict 0 for invalid states that failed to be lowered. 
+ for idx, feature in enumerate(features): + if feature.min() == feature.max() == 0: + breakdown[idx] = float('-inf') + + return breakdown + + def load_log_file(self, file_name, n_lines=-1): + inputs, results = LogReader(file_name).read_lines(n_lines) + logger.info("XGBModel: Loaded %s lines of history log from %s", len(inputs), file_name) + self.update(inputs, results) + + def save(self, file_name: str): + self.bst.save_model(file_name) + + def load(self, file_name: str): + if self.bst is None: + self.bst = xgb.Booster(self.xgb_params) + self.bst.load_model(file_name) + self.num_warmup_sample = -1 + + +def pack_sum_xgbmatrix_for_prediction(xs): + x_flatten = [] + pack_ids = [] + + for ct, x in enumerate(xs): + for row in x: + x_flatten.append(row) + pack_ids.append(ct) + + return xgb.DMatrix(x_flatten), pack_ids + + +def pack_sum_xgbmatrix(xs, ys, gids=None, weights=None): + if gids is not None: + # sort by group + indices = gids.argsort() + xs, ys = xs[indices], ys[indices] + group_sizes = np.bincount(gids) + if weights is not None: + weights = weights[indices] + else: + # assume it has only one group + group_sizes = [len(xs)] + + x_flatten = [] + y_flatten = [] + weights_flatten = [] + pack_ids = [] + + if weights is not None: + for ct, (x, y, w) in enumerate(zip(xs, ys, weights)): + for row in x: + x_flatten.append(row) + y_flatten.append(y) + weights_flatten.append(w) + pack_ids.append(ct) + else: + for ct, (x, y) in enumerate(zip(xs, ys)): + for row in x: + x_flatten.append(row) + y_flatten.append(y) + pack_ids.append(ct) + + ret = xgb.DMatrix(x_flatten, y_flatten) + if weights is not None: + ret.set_weight(weights_flatten) + dmatrix_context.put('pack_ids', ret, np.array(pack_ids)) + dmatrix_context.put('group_sizes', ret, group_sizes) + return ret + +LOSS_TYPE = 3 + +# Type 0 +# The model predicts cost. Use square error of throughput as loss +# loss = 1/2 * (1 / sum(x_i) - y) ^ 2 +# +# Type 1 +# The model predicts cost. Use square error of cost as loss +# loss = 1/2 * (sum(x_i) - 1 / y) ^ 2 +# +# Type 2 +# The model predicts throughput. Use square error of throughput as loss. +# loss = 1/2 * (1 / sum(1 / x_i) - y) ^ 2 +# +# Type 3 +# The model predicts throughput. Use square error of throughput as loss. +# But approximate 1 / (1 / a_1 + 1 / a_2 + ... + 1 / a_n) with -(b_1 + b_2 + b_3) +# loss = 1/2 * (-sum(x_i) - y) ^ 2 +# +# Type 4 +# The model predicts throughput. Use square error of throughput as loss. +# But approximate 1 / (1 / a_1 + 1 / a_2 + ... 
+ 1 / a_n) with -(b_1 + b_2 + b_3)
+# Also add a sigmoid to force the prediction to be within the range of (0, 1)
+# loss = 1/2 * (sigmoid(-sum(x_i)) - y) ^ 2
+#
+
+def pack_sum_predict_throughput(raw_preds, pack_ids):
+    if LOSS_TYPE == 0:
+        sum_pred = np.bincount(pack_ids, weights=raw_preds)
+        return 1 / sum_pred
+    elif LOSS_TYPE == 1:
+        sum_pred = np.bincount(pack_ids, weights=raw_preds)
+        return 1 / sum_pred
+    elif LOSS_TYPE == 2:
+        sum_inverse_preds = np.bincount(pack_ids, weights=1 / raw_preds)
+        return 1 / sum_inverse_preds
+    elif LOSS_TYPE == 3:
+        sum_pred = np.bincount(pack_ids, weights=raw_preds)
+        return - sum_pred  # pylint: disable=invalid-unary-operand-type
+    elif LOSS_TYPE == 4:
+        sum_pred = np.bincount(pack_ids, weights=raw_preds)
+        return 1 / (1 + np.exp(sum_pred))
+    else:
+        raise ValueError("Invalid loss type: " + str(LOSS_TYPE))
+
+def pack_sum_square_error(preds, dtrain):
+    pack_ids = dmatrix_context.get("pack_ids", dtrain)
+    weight = dtrain.get_weight()
+
+    if LOSS_TYPE == 0:
+        sum_pred = np.bincount(pack_ids, weights=preds)
+        x = sum_pred[pack_ids]
+        y = dtrain.get_label()
+        gradient = (x * y - 1) / np.power(x, 3)
+        hessian = (3 - 2 * x * y) / np.power(x, 4)
+    elif LOSS_TYPE == 1:
+        sum_pred = np.bincount(pack_ids, weights=preds)
+        x = sum_pred[pack_ids]
+        y = dtrain.get_label()
+        gradient = x - 1 / np.minimum(y, 1e6)
+        hessian = np.ones_like(gradient)
+    elif LOSS_TYPE == 2:
+        sum_inverse_preds = np.bincount(pack_ids, weights=1 / preds)[pack_ids]
+        y = dtrain.get_label()
+        gradient = (1 / sum_inverse_preds - y) / (np.power(preds * sum_inverse_preds, 2))
+        hessian = (2 * preds * y * np.power(sum_inverse_preds, 2) - 2 * y * sum_inverse_preds - 2 * preds * sum_inverse_preds + 3) / (np.power(preds * sum_inverse_preds, 4))
+    elif LOSS_TYPE == 3:
+        sum_pred = np.bincount(pack_ids, weights=preds)
+        x = sum_pred[pack_ids]
+        y = dtrain.get_label()
+        gradient = x + y
+        hessian = np.ones_like(gradient)
+    elif LOSS_TYPE == 4:
+        sum_pred = np.bincount(pack_ids, weights=preds)
+        exp_x = np.exp(sum_pred[pack_ids])
+        exp_2x = np.power(exp_x, 2)
+        y = dtrain.get_label()
+        gradient = exp_x * (exp_x * y + y - 1) / np.power(exp_x + 1, 3)
+        hessian = exp_x * (-exp_2x * y + 2 * exp_x + y - 1) / np.power(exp_x + 1, 4)
+    else:
+        raise ValueError("Invalid loss type: " + str(LOSS_TYPE))
+
+    if len(weight) == 0:
+        return gradient, hessian
+    else:
+        return gradient * weight, hessian * weight
+
+def pack_sum_rmse(raw_preds, dtrain):
+    pack_ids = dmatrix_context.get("pack_ids", dtrain)
+    preds = pack_sum_predict_throughput(raw_preds, pack_ids)[pack_ids]
+    return 'p-rmse', np.sqrt(np.mean(np.square((preds - dtrain.get_label()))))
+
+def pack_sum_average_peak_score(N):
+    """Evaluate pack sum average peak score for xgb"""
+
+    def feval(preds, labels):
+        group_sizes = dmatrix_context.get('group_sizes', labels, [len(preds)])
+        pack_ids = dmatrix_context.get("pack_ids", labels)
+
+        preds = pack_sum_predict_throughput(preds, pack_ids)
+        labels = (np.bincount(pack_ids, weights=labels.get_label())
+                  / np.unique(pack_ids, return_counts=True)[1])
+
+        scores = []
+        offset = 0
+        for size in group_sizes:
+            preds_group = preds[offset:offset + size]
+            labels_group = labels[offset:offset + size]
+            offset += size
+
+            trials = np.argsort(preds_group)[::-1][:N]
+            trial_scores = labels_group[trials]
+            curve = max_curve(trial_scores) / np.max(labels_group)
+            scores.append(np.mean(curve))
+        return "a-peak@%d" % N, np.mean(scores)
+    return feval

+def pack_sum_average_recall_score(N):
+    """evaluate average recall score for xgb"""
+
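+    # Editor's gloss (not in the original patch): mirroring the peak-score
+    # metric above, this feval groups per-program predictions, ranks each
+    # group by predicted throughput, and uses autotvm's get_rank and
+    # recall_curve to measure how early the truly best states appear in that
+    # ranking; the mean over groups is reported as "a-recall@N".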
+ def feval(preds, labels): + group_sizes = dmatrix_context.get('group_sizes', labels, [len(preds)]) + pack_ids = dmatrix_context.get("pack_ids", labels) + + preds = pack_sum_predict_throughput(preds, pack_ids) + labels = (np.bincount(pack_ids, weights=labels.get_label()) + / np.unique(pack_ids, return_counts=True)[1]) + + scores = [] + offset = 0 + for size in group_sizes: + preds_group = preds[offset:offset + size] + labels_group = labels[offset:offset + size] + offset += size + + trials = np.argsort(preds_group)[::-1] + ranks = get_rank(labels_group[trials])[:N] + curve = recall_curve(ranks) + scores.append(np.mean(curve)) + return "a-recall@%d" % N, np.mean(scores) + return feval + + +def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None, + maximize=False, verbose_eval=True, skip_every=2): + """Callback function for xgboost to support multiple custom evaluation functions""" + from xgboost.core import EarlyStopException + from xgboost.callback import _fmt_metric + from xgboost.training import aggcv + + state = {} + metric_shortname = metric.split("-")[1] + + def init(env): + """internal function""" + bst = env.model + + state['maximize_score'] = maximize + state['best_iteration'] = 0 + if maximize: + state['best_score'] = float('-inf') + else: + state['best_score'] = float('inf') + + if bst is not None: + if bst.attr('best_score') is not None: + state['best_score'] = float(bst.attr('best_score')) + state['best_iteration'] = int(bst.attr('best_iteration')) + state['best_msg'] = bst.attr('best_msg') + else: + bst.set_attr(best_iteration=str(state['best_iteration'])) + bst.set_attr(best_score=str(state['best_score'])) + else: + assert env.cvfolds is not None + + def callback(env): + """internal function""" + if not state: + init(env) + + bst = env.model + i = env.iteration + cvfolds = env.cvfolds + + res_dict = {} + + if i % skip_every == 1: + return + + ##### evaluation ##### + if cvfolds is not None: + for feval in fevals: + tmp = aggcv([f.eval(i, feval) for f in cvfolds]) + for k, mean, std in tmp: + res_dict[k] = [mean, std] + else: + for feval in fevals: + bst_eval = bst.eval_set(evals, i, feval) + res = [x.split(':') for x in bst_eval.split()] + for kv in res[1:]: + res_dict[kv[0]] = [float(kv[1])] + + eval_res = [] + keys = list(res_dict.keys()) + keys.sort(key=lambda x: x if metric_shortname not in x else "a" + x) + for key in keys: + v = res_dict[key] + eval_res.append([key] + v) + + ##### print eval result ##### + if not isinstance(verbose_eval, bool) and verbose_eval and i % verbose_eval == 0: + infos = ["XGB iter: %3d" % i] + for item in eval_res: + if 'null' in item[0]: + continue + infos.append("%s: %.6f" % (item[0], item[1])) + + logger.debug("\t".join(infos)) + if log_file: + with open(log_file, "a") as fout: + fout.write("\t".join(infos) + '\n') + + ##### choose score and do early stopping ##### + score = None + for item in eval_res: + if item[0] == metric: + score = item[1] + break + assert score is not None + + best_score = state['best_score'] + best_iteration = state['best_iteration'] + maximize_score = state['maximize_score'] + if (maximize_score and score > best_score) or \ + (not maximize_score and score < best_score): + msg = '[%d] %s' % ( + env.iteration, + '\t'.join([_fmt_metric(x) for x in eval_res])) + state['best_msg'] = msg + state['best_score'] = score + state['best_iteration'] = env.iteration + # save the property to attributes, so they will occur in checkpoint. 
+                if env.model is not None:
+                    env.model.set_attr(best_score=str(state['best_score']),
+                                       best_iteration=str(state['best_iteration']),
+                                       best_msg=state['best_msg'])
+        elif env.iteration - best_iteration >= stopping_rounds:
+            best_msg = state['best_msg']
+            if verbose_eval and env.rank == 0:
+                logger.debug("XGB stopped. Best iteration: %s ", best_msg)
+            raise EarlyStopException(best_iteration)
+
+    return callback
diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py
index e10da09e4b5a..e35a73148f3a 100644
--- a/python/tvm/ansor/measure.py
+++ b/python/tvm/ansor/measure.py
@@ -35,6 +35,8 @@
 from tvm.runtime import Object, module, ndarray
 from tvm.driver import build_module
 from tvm.ir import transform
+from tvm.rpc.tracker import Tracker
+from tvm.rpc.server import Server
 from ..contrib import tar, ndk
 from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout, request_remote, check_remote
 from .compute_dag import LayoutRewriteLevel
@@ -190,6 +192,52 @@ def __init__(self, key, host, port, priority=1,
               "and make sure you have free devices on the queue status.")
 
 
+class RPCRunnerWarpper:
+    def __init__(self, target=None, priority=1,
+                 n_parallel=1,
+                 timeout=10,
+                 number=3,
+                 repeat=1,
+                 min_repeat_ms=0,
+                 cooldown_interval=0.0):
+        self.target = target
+        self.priority = priority
+        self.n_parallel = n_parallel
+        self.timeout = timeout
+        self.number = number
+        self.repeat = repeat
+        self.min_repeat_ms = min_repeat_ms
+        self.cooldown_interval = cooldown_interval
+
+        self.tracker = None
+        self.server = None
+        self.runner = None
+
+    def __enter__(self):
+        if self.target == "cuda":
+            ctx = tvm.context("cuda", 0)
+            cuda_arch = "sm_" + "".join(ctx.compute_version.split('.'))
+            tvm.autotvm.measure.measure_methods.set_cuda_target_arch(cuda_arch)
+        host = '0.0.0.0'
+        self.tracker = Tracker(host, port=9000, port_end=10000, silent=True)
+        device_key = '$local$device$%d' % self.tracker.port
+        self.server = Server(host, port=self.tracker.port, port_end=10000,
+                             key=device_key,
+                             use_popen=True, silent=True,
+                             tracker_addr=(self.tracker.host, self.tracker.port))
+        self.runner = RPCRunner(device_key, host, self.tracker.port, self.priority,
+                                self.n_parallel, self.timeout, self.number, self.repeat,
+                                self.min_repeat_ms, self.cooldown_interval)
+
+        return self
+
+    def __exit__(self, type, value, trace):
+        # Always clean up the helper tracker and server. A pending exception,
+        # if any, propagates automatically when __exit__ returns None, so an
+        # explicit re-raise (which would skip the cleanup) is not needed.
+        self.tracker.terminate()
+        self.server.terminate()
+
 
 MAX_ERROR_MSG_LEN = 512
diff --git a/src/ansor/cost_model/cost_model.cc b/src/ansor/cost_model/cost_model.cc
index 8e0936071774..bbf15a241974 100644
--- a/src/ansor/cost_model/cost_model.cc
+++ b/src/ansor/cost_model/cost_model.cc
@@ -37,7 +37,7 @@ using ::tvm::runtime::NDArray;
 TVM_REGISTER_OBJECT_TYPE(CostModelNode);
 TVM_REGISTER_OBJECT_TYPE(RandomModelNode);
 TVM_REGISTER_OBJECT_TYPE(MeasureModelNode);
-TVM_REGISTER_OBJECT_TYPE(PythonBasedCostModelNode);
+TVM_REGISTER_OBJECT_TYPE(PythonBasedModelNode);
 
 void RandomNumber(TVMArgs args, TVMRetValue* rv) {
   int n = args[0];
@@ -101,30 +101,30 @@ void MeasureModelNode::Predict(const SearchTask& task,
   }
 }
 
-CostModel PythonBasedCostModelNode::make(PackedFunc update_func,
-                                         PackedFunc predict_func,
-                                         PackedFunc predict_stage_func) {
-  auto node = make_object<PythonBasedCostModelNode>();
+CostModel PythonBasedModelNode::make(PackedFunc update_func,
+                                     PackedFunc predict_func,
+                                     PackedFunc predict_stage_func) {
+  auto node = make_object<PythonBasedModelNode>();
   node->update_func = std::move(update_func);
   node->predict_func = std::move(predict_func);
   node->predict_stage_func = std::move(predict_stage_func);
   return 
CostModel(node); } -void PythonBasedCostModelNode::Update(const Array& inputs, - const Array& results) { +void PythonBasedModelNode::Update(const Array& inputs, + const Array& results) { update_func(inputs, results); } -void PythonBasedCostModelNode::Predict(const SearchTask& task, - const std::vector& states, - std::vector* scores) { +void PythonBasedModelNode::Predict(const SearchTask& task, + const std::vector& states, + std::vector* scores) { scores->resize(states.size()); predict_func(task, Array(states.begin(), states.end()), static_cast(scores->data())); } -void PythonBasedCostModelNode::PredictStages( +void PythonBasedModelNode::PredictStages( const SearchTask& task, const std::vector& states, std::vector* state_scores, std::vector>* stage_scores) { @@ -188,5 +188,12 @@ TVM_REGISTER_GLOBAL("ansor.RandomModel").set_body_typed([]() { return RandomModelNode::make(); }); +TVM_REGISTER_GLOBAL("ansor.PythonBasedModel") +.set_body_typed([](PackedFunc update_func, PackedFunc predict_func, + PackedFunc predict_stage_func) { + return PythonBasedModelNode::make(update_func, predict_func, + predict_stage_func); +}); + } // namespace ansor } // namespace tvm diff --git a/src/ansor/cost_model/cost_model.h b/src/ansor/cost_model/cost_model.h index 9daf01197bbf..472a3c201068 100644 --- a/src/ansor/cost_model/cost_model.h +++ b/src/ansor/cost_model/cost_model.h @@ -92,7 +92,7 @@ class MeasureModelNode : public CostModelNode { /*! \brief A wrapper for cost model defined by python code * This class will call python's function */ -class PythonBasedCostModelNode: public CostModelNode { +class PythonBasedModelNode: public CostModelNode { public: PackedFunc update_func; PackedFunc predict_func; @@ -108,8 +108,8 @@ class PythonBasedCostModelNode: public CostModelNode { std::vector* state_scores, std::vector>* stage_scores) final; - static constexpr const char *_type_key = "ansor.PythonBasedCostModel"; - TVM_DECLARE_FINAL_OBJECT_INFO(PythonBasedCostModelNode, CostModelNode); + static constexpr const char *_type_key = "ansor.PythonBasedModel"; + TVM_DECLARE_FINAL_OBJECT_INFO(PythonBasedModelNode, CostModelNode); }; } // namespace ansor diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.cc b/src/ansor/search_policy/meta_tile_rewrite_policy.cc index 86a7eba1da3a..f086a8879abb 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.cc +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.cc @@ -1397,24 +1397,27 @@ void MetaTileRewritePolicyNode::EvolutionarySearch( int id = RandomChoose(prefix_sum_probs, &rand_gen_); if (dis(rand_gen_) < mutation_prob) { - const std::vector rule_prefix_sum_probs{0.9, 0.95, 1.0}; + const std::vector rule_prefix_sum_probs{0.9, 1.0}; int rule_id = RandomChoose(rule_prefix_sum_probs, &rand_gen_); - State tmp_s; if (rule_id == 0) { - tmp_s = RandomMutateTileSize((*pnow)[id], &split_memo_, &rand_gen_, + // Mutate Tile Size + State tmp_s = RandomMutateTileSize((*pnow)[id], &split_memo_, &rand_gen_, cur_task_->hardware_params->max_innermost_split_factor); + if (tmp_s.defined()) { + pnext->push_back(std::move(tmp_s)); + } else { + mutation_fail_ct++; + } } else if (rule_id == 1) { - tmp_s = RandomMutateMaxUnrollStep((*pnow)[id], &rand_gen_, auto_unroll_configs); - } else if (rule_id == 2) { - tmp_s = MutataParallel((*pnow)[id], &split_memo_, &rand_gen_, cur_task_); - } - - if (tmp_s.defined()) { - pnext->push_back(std::move(tmp_s)); - } else { - mutation_fail_ct++; + // Mutate auto-unroll max step. 
+ State tmp_s = RandomMutateMaxUnrollStep((*pnow)[id], &rand_gen_, auto_unroll_configs); + if (tmp_s.defined()) { + pnext->push_back(std::move(tmp_s)); + } else { + mutation_fail_ct++; + } } } else { pnext->push_back((*pnow)[id]); diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index 9a57691aba22..6636787e807f 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -24,28 +24,25 @@ import tvm from tvm import ansor -from tvm.rpc.tracker import Tracker -from tvm.rpc.server import Server from test_ansor_common import matmul_nkkm -def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local'): +def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local', + cost_model=ansor.RandomModel(), n_trials=2): print("Test %s schedule search with the default search policy" % (target)) + random.seed(seed) N = 128 A, B, C = matmul_nkkm(N, N, N) dag = ansor.ComputeDAG([A, B, C]) tgt = tvm.target.create(target) task = ansor.SearchTask(dag, "test", tgt) - random.seed(seed) - with tempfile.NamedTemporaryFile() as fp: log_file = fp.name - cost_model = ansor.RandomModel() search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed) - tune_option = ansor.TuneOption(n_trials=2, runner=runner, + tune_option = ansor.TuneOption(n_trials=n_trials, runner=runner, callbacks=[ansor.LogToFile(log_file)]) state = ansor.auto_schedule(task, search_policy, tune_option=tune_option) @@ -83,48 +80,30 @@ def test_search_basic(): search_common(seed=944563397) +def test_search_xgb_model_rpc_runner(): + with ansor.RPCRunnerWarpper() as rpc_runner: + search_common(seed=456787236, cost_model=ansor.XGBModel(), + runner=rpc_runner.runner) + + def test_search_opencl(): if tvm.context("opencl", 0).exist: - host = '0.0.0.0' - tracker = Tracker(host, port=9000, port_end=10000, silent=True) - device_key = '$local$device$%d' % tracker.port - server = Server(host, port=tracker.port, port_end=10000, - key=device_key, - use_popen=True, silent=True, - tracker_addr=(tracker.host, tracker.port)) - rpc_runner = ansor.RPCRunner(device_key, host, tracker.port) - - search_common("opencl", 380344973, rpc_runner) - - tracker.terminate() - server.terminate() + with ansor.RPCRunnerWarpper() as rpc_runner: + search_common("opencl", 380344973, rpc_runner.runner) else: print("OpenCL device not found, skip this test.") def test_search_cuda(): - ctx = tvm.context("cuda", 0) - if ctx.exist: - cuda_arch = "sm_" + "".join(ctx.compute_version.split('.')) - tvm.autotvm.measure.measure_methods.set_cuda_target_arch(cuda_arch) - host = '0.0.0.0' - tracker = Tracker(host, port=9000, port_end=10000, silent=True) - device_key = '$local$device$%d' % tracker.port - server = Server(host, port=tracker.port, port_end=10000, - key=device_key, - use_popen=True, silent=True, - tracker_addr=(tracker.host, tracker.port)) - rpc_runner = ansor.RPCRunner(device_key, host, tracker.port) - - search_common("cuda", 903667810, rpc_runner) - - tracker.terminate() - server.terminate() + if tvm.context("cuda", 0).exist: + with ansor.RPCRunnerWarpper("cuda") as rpc_runner: + search_common("cuda", 903667810, rpc_runner.runner) else: print("CUDA device not found, skip this test.") if __name__ == "__main__": test_search_basic() + test_search_xgb_model_rpc_runner() test_search_opencl() test_search_cuda() From cfe58d7829cd649f4b1a4af8f4af3200dbc5174f Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 9 Jun 2020 
01:16:58 -0700 Subject: [PATCH 16/45] Migrate workload_registry.py (#16) * add workload registry * update * update --- python/tvm/ansor/__init__.py | 8 +- .../tvm/ansor/{task.py => auto_schedule.py} | 0 python/tvm/ansor/feature.py | 2 +- python/tvm/ansor/measure.py | 5 +- python/tvm/ansor/serialization.py | 5 + python/tvm/ansor/workload_registry.py | 190 ++++++++++++++++++ src/ansor/feature.cc | 2 + src/ansor/serialization.cc | 62 +++++- src/tir/analysis/verify_gpu_code.cc | 44 +++- tests/python/unittest/test_ansor_common.py | 11 +- tests/python/unittest/test_ansor_feature.py | 62 +++++- .../python/unittest/test_ansor_loop_state.py | 8 +- .../unittest/test_ansor_search_policy.py | 4 +- 13 files changed, 364 insertions(+), 39 deletions(-) rename python/tvm/ansor/{task.py => auto_schedule.py} (100%) create mode 100644 python/tvm/ansor/workload_registry.py diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index 2d27995e328e..bb4822409757 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -21,15 +21,17 @@ from . import measure from . import serialization from . import loop_state -from . import task +from . import auto_schedule from . import utils from . import feature +from . import workload_registry # Shortcut from .compute_dag import ComputeDAG -from .task import SearchTask, MetaTileRewritePolicy, TuneOption -from .task import auto_schedule +from .auto_schedule import SearchTask, MetaTileRewritePolicy, TuneOption, HardwareParams +from .auto_schedule import auto_schedule from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, RPCRunnerWarpper from .cost_model import RandomModel from .cost_model.xgb_model import XGBModel from .serialization import LogToFile, LogReader, best_measure_pair_in_file +from .workload_registry import register_auto_scheduler_workload_func, workload_key_to_dag diff --git a/python/tvm/ansor/task.py b/python/tvm/ansor/auto_schedule.py similarity index 100% rename from python/tvm/ansor/task.py rename to python/tvm/ansor/auto_schedule.py diff --git a/python/tvm/ansor/feature.py b/python/tvm/ansor/feature.py index fb5fadf16296..a0885aabdc20 100644 --- a/python/tvm/ansor/feature.py +++ b/python/tvm/ansor/feature.py @@ -24,7 +24,7 @@ import numpy as np from .loop_state import StateObject -from .task import SearchTask +from .auto_schedule import SearchTask from .measure import MeasureInput, MeasureResult from . 
import _ffi_api
diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py
index e35a73148f3a..0209a717cf0e 100644
--- a/python/tvm/ansor/measure.py
+++ b/python/tvm/ansor/measure.py
@@ -44,6 +44,8 @@
 
 logger = logging.getLogger('ansor')
 
+MAX_ERROR_MSG_LEN = 512
+
 
 @tvm._ffi.register_object("ansor.MeasureCallback")
 class MeasureCallback(Object):
@@ -238,8 +240,6 @@ def __exit__(self, type, value, trace):
         self.tracker.terminate()
         self.server.terminate()
 
-MAX_ERROR_MSG_LEN = 512
-
 
 class MeasureErrorNo(object):
     """Error type for MeasureResult"""
@@ -505,3 +505,4 @@ def timed_func(inp, build_res):
         print("")
 
     return measure_results
+
diff --git a/python/tvm/ansor/serialization.py b/python/tvm/ansor/serialization.py
index bd9a69944057..387825034a09 100644
--- a/python/tvm/ansor/serialization.py
+++ b/python/tvm/ansor/serialization.py
@@ -69,6 +69,11 @@ def write_measure_records_to_file(filename, inputs, results):
     _ffi_api.WriteMeasureRecordsToFile(filename, inputs, results)
 
 
+def get_states_from_measure_inputs(inputs, task):
+    """Get states from measure inputs"""
+    return _ffi_api.GetStatesFromMeasureInputs(inputs, task)
+
+
 def best_measure_pair_in_file(filename, workload_key=None, target=None):
     """ Return best results from log file
diff --git a/python/tvm/ansor/workload_registry.py b/python/tvm/ansor/workload_registry.py
new file mode 100644
index 000000000000..c8b12f0244b2
--- /dev/null
+++ b/python/tvm/ansor/workload_registry.py
@@ -0,0 +1,190 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+"""
+Workload registration and serialization.
+
+We use a json string to represent a workload (a compute dag).
+The format of the string is `[func_name, [args...]]`.
+The dag should be the return value of this `func_name(*args)`.
+
+Rationale: The workload is actually a compute dag defined by tvm dsl. But serializing compute dags
+and matching them efficiently is not easy. Therefore, we use the above string to encode a compute dag.
+These strings are efficient for serialization/matching and won't be too long.
+When we need the dag, we decode the string and call the function, which will return the dag.
+"""
+
+from typing import List, Tuple, Callable, Union
+from collections.abc import Hashable
+import pickle
+import json
+import hashlib
+
+import tvm._ffi
+from ..te import Tensor, PlaceholderOp, ComputeOp, placeholder
+from .utils import get_const_tuple
+from .compute_dag import ComputeDAG
+
+WORKLOAD_FUNC_REGISTRY = {}
+
+
+def register_auto_scheduler_workload_func(func: Callable):
+    """Register a workload generation function
+    The input function should take hashable and jsonable arguments
+    (int, float, tuple of int, tvm.tensor.Tensor, ...) and return a list of tvm.tensor.Tensor.
+
+    Examples
+    --------
+    @register_auto_scheduler_workload_func
+    def matmul(N, M, K):
+        A = tvm.placeholder((N, K), name='A')
+        B = tvm.placeholder((K, M), name='B')
+        k = tvm.reduce_axis((0, K), name='k')
+        C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i][k] * B[k][j], axis=[k]), name='C')
+        return [A, B, C]
+    """
+    func_name = func.__name__
+    if func_name in WORKLOAD_FUNC_REGISTRY:
+        raise RuntimeError('%s has been registered already' % func_name)
+    WORKLOAD_FUNC_REGISTRY[func_name] = func
+    return func
+
+
+def compute_dag_hash(dag: ComputeDAG):
+    # todo: implement this more carefully and move this to c++ as a member function of ComputeDAG
+    str_key = ''
+    for op in dag.ops:
+        t = op.output(0)
+        if isinstance(op, PlaceholderOp):
+            str_key += 'placeholder,'
+            str_key += str(get_const_tuple(t.shape)) + ','
+            str_key += t.dtype + ';'
+        elif isinstance(op, ComputeOp):
+            str_key += str(t.op.body) + ','
+            str_key += str(get_const_tuple(t.shape)) + ','
+            str_key += t.dtype + ';'
+        else:
+            raise ValueError("Invalid op: " + str(op))
+
+    str_key = str_key.encode(encoding='utf-8')
+    return hashlib.md5(str_key).hexdigest()
+
+
+def register_auto_scheduler_workload_bufs(bufs: List[Tensor]) -> str:
+    """Directly register buffers of a workload and return the workload_key
+    The buffers can be looked up with workload_key_to_tensors by the workload_key
+    """
+    dag = ComputeDAG(bufs)
+    key = compute_dag_hash(dag)
+    WORKLOAD_FUNC_REGISTRY[key] = bufs
+    return json.dumps((key,))
+
+
+def list_to_tuple(x: List) -> Tuple:
+    """Convert a list to a tuple recursively"""
+    assert isinstance(x, list)
+    return tuple(list_to_tuple(y) if isinstance(y, list) else y for y in x)
+
+
+def serialize_args(args: Tuple) -> Tuple:
+    """
+    Serialize arguments of a function to a hashable and jsonable tuple.
+    Currently this is mainly used for tvm.tensor.Tensor
+    """
+    ret = []
+    for t in args:
+        if isinstance(t, Tensor):
+            t = ('TENSOR', get_const_tuple(t.shape), t.dtype)
+        elif isinstance(t, list):
+            t = list_to_tuple(t)
+
+        assert isinstance(t, Hashable), str(t) + " is not hashable"
+        ret.append(t)
+
+    return tuple(ret)
+
+
+def deserialize_args(args: Tuple) -> List:
+    """The inverse function of :code:`serialize_args`"""
+    ret = []
+    for t in args:
+        if isinstance(t, (tuple, list)) and t[0] == 'TENSOR':
+            ret.append(placeholder(shape=t[1], dtype=t[2]))
+        else:
+            ret.append(t)
+    return ret
+
+
+@tvm._ffi.register_func("auto_scheduler.workload_key_to_tensors")
+def workload_key_to_tensors(workload_key: str) -> List[Tensor]:
+    """Decode a workload key to the input/output tensors"""
+    workload = json.loads(workload_key)
+    name = workload[0]
+    lookup = WORKLOAD_FUNC_REGISTRY[name]
+
+    if callable(lookup):
+        args = deserialize_args(workload[1:])
+        return lookup(*args)
+    else:
+        return lookup
+
+
+@tvm._ffi.register_func("auto_scheduler.workload_key_to_dag")
+def workload_key_to_dag(workload_key: str) -> ComputeDAG:
+    """Decode a workload key to a compute dag"""
+    tensors = workload_key_to_tensors(workload_key)
+    return ComputeDAG(tensors)
+
+
+def make_workload_key_func(func: Union[str, Callable], args: Tuple) -> str:
+    """Make a workload key from a function and its arguments"""
+    args = serialize_args(args)
+
+    if callable(func):
+        func_name = func.__name__
+    elif isinstance(func, str):
+        func_name = func
+    else:
+        raise ValueError("Invalid function: " + str(func))
+
+    assert func_name in WORKLOAD_FUNC_REGISTRY, \
+        "%s is not registered. 
Please register it with register_auto_scheduler_workload_func" % func + + return json.dumps((func_name,) + args) + + +def make_workload_key_bufs(bufs: List[Tensor]) -> str: + """make a workload key from bufs""" + dag = ComputeDAG(bufs) + key = compute_dag_hash(dag) + return json.dumps((key,)) + + +def dump_workload_func_registry(filename: str): + """Dump workload function registry to a pickle binary file""" + global WORKLOAD_FUNC_REGISTRY + + pickle.dump(WORKLOAD_FUNC_REGISTRY, open(filename, 'wb')) + + +def load_workload_func_registry(filename: str): + """Load workload function registry from a pickle binary file""" + global WORKLOAD_FUNC_REGISTRY + + WORKLOAD_FUNC_REGISTRY = pickle.load(open(filename, 'rb')) + diff --git a/src/ansor/feature.cc b/src/ansor/feature.cc index 16ddb73ebf47..497a3ac4222b 100644 --- a/src/ansor/feature.cc +++ b/src/ansor/feature.cc @@ -1241,6 +1241,8 @@ void GetPerStmtFeaturesFromStates(const Array& states, for (size_t i = skip_first_n_feature_extraction; i < states.size(); ++i) { pool.Enqueue(GetPerStmtFeaturesWorkerFunc, task, states[i], max_n_bufs, &(*features)[i], &error_ct); + //GetPerStmtFeaturesWorkerFunc(task, states[i], + // max_n_bufs, &(*features)[i], &error_ct); } pool.WaitBatch(); diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index 53c75a13f197..76f5d4449001 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -507,13 +507,7 @@ bool LogReaderNode::ReadNext(MeasureInputNode* inp, MeasureResultNode* res) { // skip comment lines begin with '#' or ' ' continue; } - - try { - ReadMeasureRecord(cur_line, inp, res, &log_version); - } catch (...) { - return false; - } - + ReadMeasureRecord(cur_line, inp, res, &log_version); return true; } @@ -607,5 +601,59 @@ TVM_REGISTER_GLOBAL("ansor.LogReaderReadNext") } }); +TVM_REGISTER_GLOBAL("ansor.GetStatesFromMeasureInputs") +.set_body([](TVMArgs args, TVMRetValue *ret) { + Array inputs = args[0]; + SearchTask external_task; + + if (args.size() > 1) { + external_task = args[1]; + } + + Array states; + states.reserve(inputs.size()); + + // (workload_key, target) -> (search_task) + std::unordered_map, SearchTask> task_cache; + + for (const auto& inp : inputs) { + const std::string& workload_key = inp->task->workload_key; + std::pair key(workload_key, inp->task->target->str()); + + const SearchTaskNode* ptask; + if (external_task.defined()) { + ptask = external_task.operator->(); + } else { + auto find_res = task_cache.find(key); + if (find_res == task_cache.end()) { + if (inp->task->compute_dag.defined()) { // the measure input is complete + ptask = inp->task.operator->(); + } else { // the measure input is incomplete + // rebuild task for incomplete measure pairs read from file + SearchTask new_task = SearchTaskNode::make( + ComputeDAGNode::make_by_workload_key(workload_key), + workload_key, + inp->task->target, + inp->task->target_host, + inp->task->hardware_params); + task_cache.insert(std::make_pair(key, new_task)); + ptask = new_task.operator->(); + } + } else { + ptask = find_res->second.operator->(); + } + } + + State tmp_s = ptask->compute_dag.GetInitState(); + StateNode *ps = tmp_s.CopyOnWrite(); + ps->transform_steps = inp->state->transform_steps; + tmp_s.DoSteps(ps->transform_steps, ptask->compute_dag); + states.push_back(std::move(tmp_s)); + } + + *ret = states; +}); + + } // namespace ansor } // namespace tvm diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index 1fbae0fd2dcd..f6a8ad034aa5 100644 --- 
a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -33,20 +33,22 @@ namespace tvm { namespace tir { -class GPUCodeVerifier : public StmtVisitor { +class GPUCodeVerifier : public StmtExprVisitor { public: bool Verify(Stmt stmt, int64_t max_local_memory_per_block, int64_t max_shared_memory_per_block, int64_t max_threads_per_block, int64_t max_thread_x, int64_t max_thread_y, - int64_t max_thread_z) { + int64_t max_thread_z, int64_t max_vector_bytes) { max_local_memory_per_block_ = static_cast(max_local_memory_per_block); max_shared_memory_per_block_ = static_cast(max_shared_memory_per_block); max_threads_per_block_ = static_cast(max_threads_per_block); max_thread_x_ = static_cast(max_thread_x); max_thread_y_ = static_cast(max_thread_y); max_thread_z_ = static_cast(max_thread_z); + max_vector_bytes_ = static_cast(max_vector_bytes); Reset_(); + // TODO(jcf94): Add support of detecting CUDA Misaligned Address error this->VisitStmt(stmt); return valid_; @@ -62,6 +64,10 @@ class GPUCodeVerifier : public StmtVisitor { size_t size = static_cast(op->constant_allocation_size()); shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes(); } + + if (op->dtype.lanes() > 1) { + valid_ &= op->dtype.lanes() * op->dtype.bytes() <= static_cast(max_vector_bytes_); + } } void VisitStmt_(const AttrStmtNode* op) final { @@ -129,6 +135,18 @@ class GPUCodeVerifier : public StmtVisitor { } } + void VisitExpr_(const LoadNode* op) { + // Currently not able to check: + // if the index expression failed to be simplified to a Ramp + if (op->index->IsInstance()) { + if (op->dtype.lanes() > 1) { + valid_ &= op->dtype.lanes() * op->dtype.bytes() <= + static_cast(max_vector_bytes_); + } + } + ExprVisitor::VisitExpr_(op); + } + private: int nest_level_{0}; @@ -146,6 +164,7 @@ class GPUCodeVerifier : public StmtVisitor { size_t max_shared_memory_per_block_; size_t max_threads_per_block_; size_t max_thread_x_, max_thread_y_, max_thread_z_; + size_t max_vector_bytes_; bool valid_{true}; @@ -169,27 +188,32 @@ bool VerifyGPUCode(const PrimFunc& func, Map constraints) { int64_t max_thread_x = INT64_MAX; int64_t max_thread_y = INT64_MAX; int64_t max_thread_z = INT64_MAX; + int64_t max_vector_bytes = INT64_MAX; for (auto iter : constraints) { const IntImmNode* val = iter.second.as(); - if (iter.first == "max_local_memory_per_block") + if (iter.first == "max_local_memory_per_block") { max_local_memory_per_block = val->value; - else if (iter.first == "max_shared_memory_per_block") + } else if (iter.first == "max_shared_memory_per_block") { max_shared_memory_per_block = val->value; - else if (iter.first == "max_threads_per_block") + } else if (iter.first == "max_threads_per_block") { max_threads_per_block = val->value; - else if (iter.first == "max_thread_x") + } else if (iter.first == "max_thread_x") { max_thread_x = val->value; - else if (iter.first == "max_thread_y") + } else if (iter.first == "max_thread_y") { max_thread_y = val->value; - else if (iter.first == "max_thread_z") + } else if (iter.first == "max_thread_z") { max_thread_z = val->value; - else + } else if (iter.first == "max_vector_bytes") { + max_vector_bytes = val->value; + } else { LOG(FATAL) << "Invalid check item: " << iter.first; + } } return verifier.Verify(func->body, max_local_memory_per_block, max_shared_memory_per_block, - max_threads_per_block, max_thread_x, max_thread_y, max_thread_z); + max_threads_per_block, max_thread_x, max_thread_y, max_thread_z, + max_vector_bytes); } 
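+// Editor's note (usage sketch, not part of the original change): the new
+// "max_vector_bytes" limit is read from the same constraint map as the other
+// entries, so a caller can pass, e.g., {"max_vector_bytes": 16} alongside
+// "max_threads_per_block"; any key left unset keeps its INT64_MAX default
+// above, which disables that particular check.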
TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode); diff --git a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py index cd8a1eedb162..1790b06bcb60 100644 --- a/tests/python/unittest/test_ansor_common.py +++ b/tests/python/unittest/test_ansor_common.py @@ -17,18 +17,16 @@ """Common functions for ansor test cases""" - from tvm import te, ansor import topi -def matmul_nkkm(N, M, K): +@ansor.register_auto_scheduler_workload_func +def matmul_ansor_test(N, M, K): A = te.placeholder((N, K), name='A') B = te.placeholder((K, M), name='B') k = te.reduce_axis((0, K), name='k') - C = te.compute((N, M), lambda i, j: te.sum( - A[i][k] * B[k][j], axis=[k]), name='C') - + C = te.compute((N, M), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C') return [A, B, C] @@ -58,7 +56,7 @@ def conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation def get_tiled_matmul(): - dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) + dag = ansor.ComputeDAG(matmul_ansor_test(512, 512, 512)) s0 = dag.get_init_state() A, B, C = 0, 1, 2 @@ -80,3 +78,4 @@ def get_tiled_matmul(): C_global += 1 s0.compute_at(A_global, C_global, s0.stages[C_global].iters[2]) return dag, s0.state_object + diff --git a/tests/python/unittest/test_ansor_feature.py b/tests/python/unittest/test_ansor_feature.py index abd304a9c2d7..3da1c7aa332e 100644 --- a/tests/python/unittest/test_ansor_feature.py +++ b/tests/python/unittest/test_ansor_feature.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ """Test feature extraction""" import math @@ -6,7 +23,7 @@ import tvm from tvm import te, ansor -from test_ansor_common import matmul_nkkm +from test_ansor_common import matmul_ansor_test def fequal(a, b): @@ -14,7 +31,7 @@ def fequal(a, b): def test_cpu_matmul(): - dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) + dag = ansor.ComputeDAG(matmul_ansor_test(512, 512, 512)) s = dag.get_init_state() C = 2 @@ -87,11 +104,48 @@ def fusion_test(N, M): def test_gpu_feature(): - # todo(lmzheng) - pass + ctx = tvm.context("cuda", 0) + if not ctx.exist: + return + + json_records = "\n".join(( + """{"i": [["[\\"matmul_ansor_test\\", 512, 512, 512]", "cuda"], [[], [["CHW", 2, "local"], ["SP", 2, 0, 512, [1, 16, 32, 1], 1], ["SP", 2, 5, 512, [4, 1, 1, 16], 1], ["SP", 2, 10, 512, [1, 2], 1], ["RE", 2, [0, 5, 1, 6, 2, 7, 10, 11, 3, 8, 12, 4, 9]], ["FSP", 3, 0, 1, 3], ["FSP", 3, 4, 2, 3], ["RE", 3, [0, 4, 1, 5, 2, 6, 3, 7]], ["FU", 2, [0, 1]], ["FU", 3, [0, 1]], ["FU", 2, [1, 2]], ["FU", 3, [1, 2]], ["FU", 2, [2, 3]], ["FU", 3, [2, 3]], ["CA", 2, 3, 2], ["CHR", 1, "shared", [2]], ["CA", 2, 3, 3], ["FU", 2, [0, 1]], ["FFSP", 2, 0, [1, 2], 1, 1], ["AN", 2, 1, 6], ["CHR", 0, "shared", [3]], ["CA", 1, 4, 3], ["FU", 1, [0, 1]], ["FFSP", 1, 0, [1, 2], 1, 1], ["AN", 1, 1, 6], ["AN", 5, 0, 5], ["AN", 5, 1, 4], ["AN", 5, 2, 6], ["PR", 4, 0, "auto_unroll_max_step$1024"]]]], "r": [[0.00536798], 0, 2.49277, 1585564852], "v": "v0.1"}""", + )) + + # load states + with tempfile.NamedTemporaryFile(mode='w') as f: + f.write(json_records) + f.flush() + inputs, results = ansor.LogReader(f.name).read_lines() + + inp = inputs[0] + dag = ansor.workload_key_to_dag(inp.task.workload_key) + task = ansor.SearchTask(dag, inp.task.workload_key, inp.task.target, None, ansor.HardwareParams(100000, 16, 64, 4, 64)) + + state = ansor.serialization.get_states_from_measure_inputs(inputs, task)[0] + state = dag.infer_bound_from_state(state) + fea = ansor.feature.get_per_stmt_features_from_states([state], task)[0] + names = ansor.feature.get_per_stmt_feature_names() + + # build feature dict + fea_dicts = [] + for i in range(len(fea)): + tmp_dict = {} + for j in range(len(names)): + tmp_dict[names[j]] = fea[i][j] + fea_dicts.append(tmp_dict) + + # check values + assert fequal(fea_dicts[0]['blockIdx_x_len'], math.log2(8 + 1)) + assert fequal(fea_dicts[0]['vthread_len'], math.log2(4 + 1)) + assert fequal(fea_dicts[1]['threadIdx_x_len'], math.log2(16 + 1)) + assert fequal(fea_dicts[0]['threadIdx_y_len'], math.log2(1 + 1)) + assert fequal(fea_dicts[2]['blockIdx_z_len'], math.log2(1 + 1)) + assert fequal(fea_dicts[0]['is_gpu'], 1.0) if __name__ == "__main__": test_cpu_matmul() test_cpu_fusion() test_gpu_feature() + diff --git a/tests/python/unittest/test_ansor_loop_state.py b/tests/python/unittest/test_ansor_loop_state.py index 287a1b773395..612d320036d8 100644 --- a/tests/python/unittest/test_ansor_loop_state.py +++ b/tests/python/unittest/test_ansor_loop_state.py @@ -20,11 +20,11 @@ from tvm import ansor, te import topi -from test_ansor_common import matmul_nkkm, conv2d_nchw_bn_relu +from test_ansor_common import matmul_ansor_test, conv2d_nchw_bn_relu def test_split_fuse_reorder_annotation(): - dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) + dag = ansor.ComputeDAG(matmul_ansor_test(512, 512, 512)) s0 = dag.get_init_state() C = 2 i, j, k = s0.stages[C].iters @@ -67,7 +67,7 @@ def test_split_fuse_reorder_annotation(): def test_follow_split_follow_fused_split(): - dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) + dag = 
ansor.ComputeDAG(matmul_ansor_test(512, 512, 512)) s0 = dag.get_init_state() C = 2 @@ -433,7 +433,7 @@ def test_cache_read_write(): def test_rfactor(): - dag = ansor.ComputeDAG(matmul_nkkm(8, 8, 512)) + dag = ansor.ComputeDAG(matmul_ansor_test(8, 8, 512)) s0 = dag.get_init_state() C = 2 diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index 6636787e807f..a28456574abe 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -25,7 +25,7 @@ import tvm from tvm import ansor -from test_ansor_common import matmul_nkkm +from test_ansor_common import matmul_ansor_test def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local', cost_model=ansor.RandomModel(), n_trials=2): @@ -33,7 +33,7 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local' random.seed(seed) N = 128 - A, B, C = matmul_nkkm(N, N, N) + A, B, C = matmul_ansor_test(N, N, N) dag = ansor.ComputeDAG([A, B, C]) tgt = tvm.target.create(target) task = ansor.SearchTask(dag, "test", tgt) From 143ea451bfe848ed7ef5424ffaf468344c38ea4c Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 9 Jun 2020 02:10:48 -0700 Subject: [PATCH 17/45] add task scheduler (#17) --- python/tvm/ansor/__init__.py | 2 + python/tvm/ansor/auto_schedule.py | 7 +- python/tvm/ansor/cost_model/__init__.py | 1 + python/tvm/ansor/cost_model/xgb_model.py | 10 +- python/tvm/ansor/feature.py | 5 +- python/tvm/ansor/measure.py | 7 + python/tvm/ansor/task_scheduler.py | 274 ++++++++++++++++++ python/tvm/ansor/workload_registry.py | 1 - src/ansor/measure.cc | 39 +-- src/ansor/search_policy/search_policy.cc | 17 ++ .../unittest/test_ansor_task_scheduler.py | 43 +++ 11 files changed, 365 insertions(+), 41 deletions(-) create mode 100644 python/tvm/ansor/task_scheduler.py create mode 100644 tests/python/unittest/test_ansor_task_scheduler.py diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index bb4822409757..4e57c16d18a5 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -25,6 +25,7 @@ from . import utils from . import feature from . import workload_registry +from . import task_scheduler # Shortcut from .compute_dag import ComputeDAG @@ -35,3 +36,4 @@ from .cost_model.xgb_model import XGBModel from .serialization import LogToFile, LogReader, best_measure_pair_in_file from .workload_registry import register_auto_scheduler_workload_func, workload_key_to_dag +from .task_scheduler import TaskScheduler, SimpleTaskScheduler diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index affcf4a6e195..5f4b7946b087 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -22,7 +22,7 @@ import tvm._ffi from tvm.runtime import Object from .measure import LocalBuilder, LocalRunner -from .cost_model import RandomModel +from .cost_model import RandomModel, XGBModel from . 
import _ffi_api @@ -67,11 +67,12 @@ def __init__(self, dag, workload_key, target, target_host=None, @tvm._ffi.register_object("ansor.SearchPolicy") class SearchPolicy(Object): - pass + def continue_search(self, task, num_measure, verbose, measurer): + return _ffi_api.SearchPolicyContinueSearchOneRound(self, task, num_measure, verbose, measurer) @tvm._ffi.register_object("ansor.MetaTileRewritePolicy") -class MetaTileRewritePolicy(Object): +class MetaTileRewritePolicy(SearchPolicy): """ The search policy that searches with meta tiling and random rewrite Parameters diff --git a/python/tvm/ansor/cost_model/__init__.py b/python/tvm/ansor/cost_model/__init__.py index fc3821cf7998..56e4a5f9128b 100644 --- a/python/tvm/ansor/cost_model/__init__.py +++ b/python/tvm/ansor/cost_model/__init__.py @@ -18,3 +18,4 @@ """ Cost model that estimates the performance of programs """ from .cost_model import RandomModel +from .xgb_model import XGBModel diff --git a/python/tvm/ansor/cost_model/xgb_model.py b/python/tvm/ansor/cost_model/xgb_model.py index e61acfbd168f..fce3f16d18ba 100644 --- a/python/tvm/ansor/cost_model/xgb_model.py +++ b/python/tvm/ansor/cost_model/xgb_model.py @@ -92,14 +92,15 @@ def update(self, inputs, results): # extract feature n_cached = len(self.inputs_feature_cache) features, normalized_throughputs, task_ids = \ - get_per_stmt_features_from_measure_pairs(self.inputs, self.results, - skip_first_n_feature_extraction=n_cached) + get_per_stmt_features_from_measure_pairs(self.inputs, self.results, + skip_first_n_feature_extraction=n_cached) if n_cached > 0: features = list(features) features[:n_cached] = self.inputs_feature_cache features = np.array(features) self.inputs_feature_cache = features - dtrain = pack_sum_xgbmatrix(features, normalized_throughputs, task_ids, normalized_throughputs) + dtrain = pack_sum_xgbmatrix(features, normalized_throughputs, + task_ids, normalized_throughputs) # train xgb model self.bst = xgb.train(self.xgb_params, dtrain, @@ -133,7 +134,6 @@ def predict(self, task, states): def predict_stages(self, task, states): # Format: (s0 score, ..., sN score, s0 n_stage, s0 stage 0, ..., s1 n_stage, s1 stage 0,) - features = get_per_stmt_features_from_states(states, task) if self.bst is not None and len(self.inputs) > self.num_warmup_sample: dtest, pack_ids = pack_sum_xgbmatrix_for_prediction(features) @@ -339,7 +339,7 @@ def feval(preds, labels): return feval def pack_sum_average_recall_score(N): - """evaluate average recall score for xgb""" + """Evaluate average recall score for xgb""" def feval(preds, labels): group_sizes = dmatrix_context.get('group_sizes', labels, [len(preds)]) diff --git a/python/tvm/ansor/feature.py b/python/tvm/ansor/feature.py index a0885aabdc20..f91d7da169f5 100644 --- a/python/tvm/ansor/feature.py +++ b/python/tvm/ansor/feature.py @@ -24,7 +24,6 @@ import numpy as np from .loop_state import StateObject -from .auto_schedule import SearchTask from .measure import MeasureInput, MeasureResult from . 
import _ffi_api @@ -124,7 +123,7 @@ def get_per_stmt_features_from_file(filename: str, def get_per_stmt_features_from_measure_pairs(inputs: List[MeasureInput], results: List[MeasureResult], skip_first_n_feature_extraction: int = 0, - max_n_bufs: int = None,) \ + max_n_bufs: int = None) \ -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """Get per_stmt features from measurement pairs""" byte_arr = _ffi_api.GetPerStmtFeaturesFromMeasurePairs( @@ -133,7 +132,7 @@ def get_per_stmt_features_from_measure_pairs(inputs: List[MeasureInput], def get_per_stmt_features_from_states(states: List[StateObject], - task: SearchTask, + task: "SearchTask", max_n_bufs: int = None) -> List[np.ndarray]: """Get per_stmt features from states""" byte_arr = _ffi_api.GetPerStmtFeaturesFromStates( diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index 0209a717cf0e..b062eb585d12 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -171,6 +171,13 @@ def __init__(self, self.__init_handle_by_constructor__( _ffi_api.LocalRunner, timeout, number, repeat, min_repeat_ms, cooldown_interval) +@tvm._ffi.register_object("ansor.ProgramMeasurer") +class ProgramMeasurer(Object): + def __init__(self, builder: Builder, runner: Runner, + callbacks: List[MeasureCallback], + verbose: int, max_continuous_error: int = -1): + self.__init_handle_by_constructor__( + _ffi_api.ProgramMeasurer, builder, runner, callbacks, verbose, max_continuous_error) @tvm._ffi.register_object("ansor.RPCRunner") class RPCRunner(Runner): diff --git a/python/tvm/ansor/task_scheduler.py b/python/tvm/ansor/task_scheduler.py new file mode 100644 index 000000000000..5144591d4f98 --- /dev/null +++ b/python/tvm/ansor/task_scheduler.py @@ -0,0 +1,274 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
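+# Editor's note (illustrative numbers, not part of the original patch): under
+# the "gradient" strategy implemented below, with objective_func = sum and
+# best_costs = [1.0, 2.0], the chain-rule term
+#   chain_grad = (f(best_costs) - f(best_costs with cost_i lowered by delta)) / delta
+# equals 1 for every task, so selection is driven by the per-task backward and
+# forward terms: the task whose recent cost history fell fastest gets the most
+# negative combined gradient, and argmin schedules it for the next round.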
+ +"""TaskScheduler that allocates the time resources when tuning multiple tasks together""" +from typing import List, Union, Callable +import time + +import numpy as np + +from .auto_schedule import SearchTask, SearchPolicy, MetaTileRewritePolicy, TuneOption +from .cost_model import RandomModel, XGBModel +from .measure import ProgramMeasurer +from .utils import array_mean, to_str_round + + +class TaskScheduler: + """Allocate the time resources when tuning multiple tasks together""" + def __init__(self, + tasks: List[SearchTask], + objective_func: Callable = None): + self.tasks = tasks + self.objective_func = objective_func or sum + + def compute_score(self, costs: List[float]) -> float: + return self.objective_func(costs) + + +def get_search_policies(search_policy: Union[str, List[SearchPolicy]], tasks: List[SearchTask], + num_measure_per_iter, load_model_file=None, load_log_file=None): + if search_policy == 'default': + search_policy = 'meta-rewrite.xgb' + + if isinstance(search_policy, str): + policy_type, model_type = search_policy.split('.') + if model_type == 'xgb': + cost_model = XGBModel(num_warmup_sample=len(tasks) * num_measure_per_iter) + if load_model_file: + print("Load pretrained model...") + cost_model.load(load_model_file) + elif load_log_file: + cost_model.load_log_file(load_log_file) + elif model_type == 'random': + cost_model = RandomModel() + else: + raise ValueError("Invalid search policy: " + search_policy) + + if policy_type == 'meta-rewrite': + search_policies = [MetaTileRewritePolicy(cost_model) for _ in range(len(tasks))] + elif policy_type == 'limit-space': + search_policies = [MetaTileRewritePolicy(cost_model, + params={'cpu_multi_level_tiling_structure': 'SRS', + 'disable_change_compute_location': 1}) + for _ in range(len(tasks))] + elif policy_type == 'beam-search': + search_policies = [MetaTileRewritePolicy(cost_model, + params={'use_beam_search': 1}) + for _ in range(len(tasks))] + else: + raise ValueError("Invalid search policy: " + search_policy) + else: + # check type + assert isinstance(search_policy, (tuple, list)) + for item in search_policy: + assert isinstance(item, SearchPolicy) + search_policies = search_policy + + return search_policies + + +class SimpleTaskScheduler(TaskScheduler): + """The default task scheduler with several strategies + + Parameters + ---------- + tasks: List[SearchTask] + All workloads to tune + weights: List[float] + Weights of tasks (i.e. the number of occurrence of a task in the whole network) + strategy: str + The joint tuning strategy. + "sequential" : Tune tasks sequentially. Divide n_trials equally to every task. + "round-robin": Tune tasks in round robin order. + "gradient" : Tune tasks with gradient descent. + load_log_file: str + Load history log file to pre-train cost model + eps-random: float + Always allocate this percent of n_trials to select tasks randomly. This is for encouraging exploration. + verbose: int + The level of verbosity. 0 means silent. 
+ alpha: float + The parameter used for 'gradient' strategy + beta: float + The parameter used for 'gradient' strategy + backward_window_size: int + The parameter used for 'gradient' strategy + """ + def __init__(self, + tasks: List[SearchTask], + objective_func: Callable = None, + strategy: str = 'gradient', + load_log_file: str = None, + load_model_file: str = None, + eps_random: float = 0.05, + verbose: int = 1, + alpha: float = 0.2, + beta: float = 2, + gamma: float = 0.5, + backward_window_size: int = 3, + use_debug_measurement_simulator=None): + super().__init__(tasks, objective_func) + self.strategy = strategy + self.eps_random = eps_random + self.verbose = verbose + self.load_log_file = load_log_file + self.load_model_file = load_model_file + self.alpha = alpha + self.beta = beta + self.gamma = gamma + self.backward_window_size = backward_window_size + self.use_debug_measurement_simulator = use_debug_measurement_simulator + + assert self.strategy in ['round-robin', 'gradient'] + + self.task_cts = [] + self.task_costs_history = [] + self.best_costs = self.cur_score = None + self.tune_option = self.measurer = self.search_policies = self.ct = self.tic = None + self.num_measure_per_iter = None + self.dead_tasks = set() + self.sequential_now_task_idx = 0 + self.sequential_now_task_begin_ct = 0 + + def tune(self, tune_option: TuneOption, search_policy: Union[str, List[SearchPolicy]] = 'default'): + # init members + self.task_cts = [0 for _ in range(len(self.tasks))] + self.task_costs_history = [[] for _ in range(len(self.tasks))] + self.best_costs = 1e10 * np.ones(len(self.tasks)) + self.cur_score = self.compute_score(self.best_costs) + self.tune_option = tune_option + if self.use_debug_measurement_simulator is None: + self.measurer = ProgramMeasurer(tune_option.builder, tune_option.runner, + tune_option.callbacks, tune_option.verbose) + self.ct = 0 + self.tic = time.time() + # reset num_measure_per_iter to make sure every task is tuned at least once + self.num_measure_per_iter = min(tune_option.num_measure_per_iter, + tune_option.n_trials // len(self.tasks)) + self.search_policies = get_search_policies(search_policy, self.tasks, + self.num_measure_per_iter, + self.load_model_file, + self.load_log_file) + self.dead_tasks = set() + self.sequential_now_task_idx = 0 + self.sequential_now_task_begin_ct = 0 + + # do a round robin first + if self.strategy != 'sequential': + for i in range(len(self.tasks)): + self.tune_task(i) + + # use the specific strategy to choose workload to tune + task_idx = -1 + while self.ct < tune_option.n_trials and len(self.dead_tasks) < len(self.tasks): + if self.strategy == 'sequential': + allocated_total_ct = ((tune_option.n_trials - self.sequential_now_task_begin_ct) + / (len(self.tasks) - self.sequential_now_task_idx)) + used_ct = self.ct - self.sequential_now_task_begin_ct + + if self.sequential_now_task_idx in self.dead_tasks or used_ct >= allocated_total_ct: + self.sequential_now_task_idx += 1 + self.sequential_now_task_begin_ct = self.ct + task_idx = self.sequential_now_task_idx + if task_idx >= len(self.tasks): + break + elif self.strategy == 'round-robin': + task_idx = (task_idx + 1) % len(self.tasks) + while task_idx in self.dead_tasks: + task_idx = (task_idx + 1) % len(self.tasks) + elif self.strategy == 'gradient': + gradients = [] + for i in range(len(self.tasks)): + if i in self.dead_tasks: + gradients.append(0) + continue + + # compute gradient from chain rule : (delta f / delta g_i) + delta = 1e-7 + new_costs = list(self.best_costs) + new_costs[i] 
+
+    def tune_task(self, task_idx):
+        if self.use_debug_measurement_simulator is not None:
+            measure_inputs, measure_results = \
+                self.use_debug_measurement_simulator.get_next_batch(
+                    self.tasks[task_idx],
+                    self.num_measure_per_iter,
+                )
+        else:
+            measure_inputs, measure_results = \
+                self.search_policies[task_idx].continue_search(
+                    self.tasks[task_idx],
+                    self.num_measure_per_iter,
+                    self.tune_option.verbose,
+                    self.measurer)
+
+        for inp, res in zip(measure_inputs, measure_results):
+            cost = array_mean(res.costs)
+            if cost < self.best_costs[task_idx]:
+                self.best_costs[task_idx] = cost
+
+        if len(measure_inputs) == 0:
+            self.dead_tasks.add(task_idx)
+
+        self.task_cts[task_idx] += 1
+        self.task_costs_history[task_idx].append(self.best_costs[task_idx])
+
+        self.ct += len(measure_inputs)
+        self.cur_score = self.compute_score(self.best_costs)
+
+        if self.verbose >= 1:
+            print(("TaskScheduler\tct: %d\testimated cost (ms): %.3f\ttime elapsed: %.2f\t" +
+                   "best_costs (ms): %s\ttask_ct: %s") %
+                  (self.ct, self.cur_score * 1e3, time.time() - self.tic,
+                   to_str_round(self.best_costs * 1e3, decimal=3),
+                   self.task_cts))
+
+    def remove_dead_task(self, prob):
+        for idx in self.dead_tasks:
+            prob[idx] = 0
+        return prob / prob.sum()
diff --git a/python/tvm/ansor/workload_registry.py b/python/tvm/ansor/workload_registry.py
index c8b12f0244b2..381e6009eea8 100644
--- a/python/tvm/ansor/workload_registry.py
+++ b/python/tvm/ansor/workload_registry.py
@@ -187,4 +187,3 @@ def load_workload_func_registry(filename: str):
 
     global WORKLOAD_FUNC_REGISTRY
     WORKLOAD_FUNC_REGISTRY = pickle.load(open(filename, 'rb'))
-
diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc
index e3593753d3ff..73bbade241c5 100644
--- a/src/ansor/measure.cc
+++ b/src/ansor/measure.cc
@@ -324,24 +324,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 });
 
 TVM_REGISTER_GLOBAL("ansor.MeasureInput")
-.set_body_typed([](SearchTask task, State state) {
-  return MeasureInputNode::make(task, state);
-});
+.set_body_typed(MeasureInputNode::make);
 
 TVM_REGISTER_GLOBAL("ansor.BuildResult")
-.set_body_typed([](std::string filename, Array<te::Tensor> args,
-                   int error_no, std::string error_msg, double time_cost) {
-  return BuildResultNode::make(filename, args, error_no, error_msg,
-                               time_cost);
-});
+.set_body_typed(BuildResultNode::make);
 
 TVM_REGISTER_GLOBAL("ansor.MeasureResult")
-.set_body_typed([](Array<PrimExpr> costs, int error_no,
-                   std::string 
error_msg, double all_cost, - double timestamp) { - return MeasureResultNode::make(costs, error_no, error_msg, all_cost, - timestamp); -}); +.set_body_typed(MeasureResultNode::make); TVM_REGISTER_GLOBAL("ansor.BuilderBuild") .set_body_typed([](const Builder& builder, @@ -356,25 +345,17 @@ TVM_REGISTER_GLOBAL("ansor.RunnerRun") }); TVM_REGISTER_GLOBAL("ansor.LocalBuilder") -.set_body_typed([](int timeout, int n_parallel, - const std::string& build_func) { - return LocalBuilderNode::make(timeout, n_parallel, build_func); -}); +.set_body_typed(LocalBuilderNode::make); TVM_REGISTER_GLOBAL("ansor.LocalRunner") -.set_body_typed([](int timeout, int number, int repeat, int min_repeat_ms, - double cooldown_interval) { - return LocalRunnerNode::make(timeout, number, repeat, min_repeat_ms, - cooldown_interval); -}); +.set_body_typed(LocalRunnerNode::make); TVM_REGISTER_GLOBAL("ansor.RPCRunner") -.set_body_typed([](const std::string& key, const std::string& host, int port, - int priority, int timeout, int n_parallel, int number, - int repeat, int min_repeat_ms, double cooldown_interval) { - return RPCRunnerNode::make(key, host, port, priority, timeout, n_parallel, - number, repeat, min_repeat_ms, cooldown_interval); -}); +.set_body_typed(RPCRunnerNode::make); + +TVM_REGISTER_GLOBAL("ansor.ProgramMeasurer") +.set_body_typed(ProgramMeasurerNode::make); + } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc index 866922d0001e..f3072fda4956 100644 --- a/src/ansor/search_policy/search_policy.cc +++ b/src/ansor/search_policy/search_policy.cc @@ -23,11 +23,28 @@ */ #include "search_policy.h" +#include namespace tvm { namespace ansor { TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode); +// Search Policy +TVM_REGISTER_GLOBAL("ansor.SearchPolicyContinueSearchOneRound") +.set_body([](TVMArgs args, TVMRetValue *ret) { + SearchPolicy policy = args[0]; + SearchTask task = args[1]; + int num_measure = args[2]; + int verbose = args[3]; + ProgramMeasurer measurer = args[4]; + + Array inputs; + Array results; + std::tie(inputs, results) = policy->ContinueSearchOneRound(task, num_measure, verbose, measurer); + + *ret = Array{inputs, results}; +}); + } // namespace ansor } // namespace tvm diff --git a/tests/python/unittest/test_ansor_task_scheduler.py b/tests/python/unittest/test_ansor_task_scheduler.py new file mode 100644 index 000000000000..e95d65d4b5ce --- /dev/null +++ b/tests/python/unittest/test_ansor_task_scheduler.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+"""Test the task scheduler"""
+
+import tvm
+from tvm import ansor
+
+from test_ansor_common import matmul_ansor_test
+
+def test_task_scheduler_basic():
+    N = 128
+    A, B, C = matmul_ansor_test(N, N, N)
+    dag = ansor.ComputeDAG([A, B, C])
+    tgt = tvm.target.create("llvm")
+    task1 = ansor.SearchTask(dag, "test", tgt)
+    task2 = ansor.SearchTask(dag, "test", tgt)
+
+    def objective(costs):
+        return sum(costs)
+
+    task_scheduler = ansor.SimpleTaskScheduler([task1, task2], objective)
+    tune_option = ansor.TuneOption(n_trials=3, runner='local')
+
+    task_scheduler.tune(tune_option)
+
+
+if __name__ == "__main__":
+    test_task_scheduler_basic()

From ed075c276c3fecc3ed3ff16b87a707b5482ff6f9 Mon Sep 17 00:00:00 2001
From: Chenfan 
Date: Tue, 9 Jun 2020 17:53:06 +0800
Subject: [PATCH 18/45] Add conv2d cuda tutorial with workload registry (#18)

---
 docs/conf.py                            |   2 +-
 python/tvm/ansor/__init__.py            |   3 +-
 tutorials/ansor/tune_conv2d_cuda.py     | 164 ++++++++++++++++++++++++
 tutorials/ansor/tune_simple_subgraph.py |   2 +
 4 files changed, 169 insertions(+), 2 deletions(-)
 create mode 100644 tutorials/ansor/tune_conv2d_cuda.py

diff --git a/docs/conf.py b/docs/conf.py
index 5cbaab7f7b6d..5826526d55b0 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -197,8 +197,8 @@
 ['../tutorials/frontend',
  '../tutorials/language',
  '../tutorials/optimize',
-'../tutorials/ansor',
 '../tutorials/autotvm',
+'../tutorials/ansor',
 '../tutorials/dev',
 '../tutorials/topi',
 '../tutorials/deployment',
diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py
index 4e57c16d18a5..bfdbaf9c8c8c 100644
--- a/python/tvm/ansor/__init__.py
+++ b/python/tvm/ansor/__init__.py
@@ -35,5 +35,6 @@
 from .cost_model import RandomModel
 from .cost_model.xgb_model import XGBModel
 from .serialization import LogToFile, LogReader, best_measure_pair_in_file
-from .workload_registry import register_auto_scheduler_workload_func, workload_key_to_dag
+from .workload_registry import register_auto_scheduler_workload_func, workload_key_to_dag, \
+    make_workload_key_func
 from .task_scheduler import TaskScheduler, SimpleTaskScheduler
diff --git a/tutorials/ansor/tune_conv2d_cuda.py b/tutorials/ansor/tune_conv2d_cuda.py
new file mode 100644
index 000000000000..82a5e8572ba2
--- /dev/null
+++ b/tutorials/ansor/tune_conv2d_cuda.py
@@ -0,0 +1,164 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Auto-scheduling High Performance Convolution on NVIDIA GPUs
+===========================================================
+**Author**: `Lianmin Zheng `_, \
+            `Chengfan Jia `_, \
+            `Minmin Sun `_, \
+            `Zhao Wu `_
+
+This is a tutorial on searching for a high-performance schedule for NVIDIA GPUs
+using the Ansor auto-scheduler. By running Ansor on the search task defined below,
+we can outperform the vendor-provided library cuDNN in many cases.
+"""
+
+######################################################################
+# Install dependencies
+# --------------------
+# To use the Ansor package in tvm, we need to install some extra dependencies.
+# (change "3" to "2" if you use python2):
+#
+# .. code-block:: bash
+#
+#   pip3 install --user psutil xgboost tornado
+#
+# To make TVM run faster in tuning, it is recommended to use cython
+# as the FFI of tvm. In the root directory of tvm, execute
+#
+# .. code-block:: bash
+#
+#   pip3 install --user cython
+#   sudo make cython3
+#
+# Now return to python code. Import packages.

+import random
+import sys
+
+import numpy as np
+import tvm
+import topi
+from topi.testing import conv2d_nchw_python
+from tvm import te
+
+# the module is called `ansor`
+from tvm import ansor
+
+######################################################################
+# Step 1: Define the search task
+# -------------------------------
+# There are plenty of useful schedule primitives in tvm. You can also find
+# some tutorials that describe them in more detail, such as
+# (1). :ref:`opt-conv-gpu`
+# (2). `Optimizing DepthwiseConv on NVIDIA GPU `_
+#
+# Getting a high-performance schedule for a specific workload is usually a
+# hard job. Even writing an AutoTVM tunable template requires the user to
+# have expertise in how each schedule primitive works and in how the
+# primitives map to the hardware architecture.
+#
+# However, with Ansor this becomes quite simple. First, define the target
+# workload; either the :code:`tvm.te` API or the topi op API can be used.
+#
+# We could use the returned :code:`Tensors` to create a ComputeDAG just like
+# what we did in :ref:`ansor-simple-subgraph`, but using the workload registry
+# is the recommended way.

+# Use a function decorator to register this workload
+@ansor.register_auto_scheduler_workload_func
+def conv2d_nchw(N, H, W, CO, CI, KH, KW, stride, padding):
+    data = te.placeholder((N, CI, H, W), name='data')
+    kernel = te.placeholder((CO, CI, KH, KW), name='kernel')
+    conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype='float32')
+
+    return [data, kernel, conv]
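+
+######################################################################
+# The decorated function is now identified by a workload key: a JSON string
+# recording the function name and its concrete arguments, roughly of the form
+# '["conv2d_nchw", 1, 7, 7, 512, 512, ...]' (the exact layout is an internal
+# detail of the workload registry and is sketched here only for intuition).
+# The key is what gets written into tuning logs, so the same function and
+# arguments always map back to the same ComputeDAG.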
+""" + +###################################################################### +# Install dependencies +# -------------------- +# To use autotvm package in tvm, we need to install some extra dependencies. +# (change "3" to "2" if you use python2): +# +# .. code-block:: bash +# +# pip3 install --user psutil xgboost tornado +# +# To make TVM run faster in tuning, it is recommended to use cython +# as FFI of tvm. In the root directory of tvm, execute +# +# .. code-block:: bash +# +# pip3 install --user cython +# sudo make cython3 +# +# Now return to python code. Import packages. + +import random +import sys + +import numpy as np +import tvm +import topi +from topi.testing import conv2d_nchw_python +from tvm import te + +# the module is called `ansor` +from tvm import ansor + +###################################################################### +# Step 1: Define the search task +# ------------------------------- +# There are plenty of useful schedule primitives in tvm. You can also find +# some tutorials that describe them in more details, such as +# (1). :ref:`opt-conv-gpu` +# (2). `Optimizing DepthwiseConv on NVIDIA GPU `_ +# +# It's usually a hard job if one wants to get a high performance schedule for a +# specific workload. Even writing an AutoTVM tunable template needs user to have +# expertises on how each schedule primitive works as well as how they finally +# reflect on the hardward architecture. +# +# However, with Ansor this will be quite simple. Firstly, define the target workload. +# Both :code:`tvm.te` API or topi op API are fine to be used. +# +# We can use the retuned :code:`Tensors` to create a ComputeDAG just like what we do +# in :ref:`ansor-simple-subgraph`, while the way to use workload registry is more +# recommended. + +# Use an extra function decorator to regist this workload +@ansor.register_auto_scheduler_workload_func +def conv2d_nchw(N, H, W, CO, CI, KH, KW, stride, padding): + data = te.placeholder((N, CI, H, W), name='data') + kernel = te.placeholder((CO, CI, KH, KW), name='kernel') + conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype='float32') + + return [data, kernel, conv] + +###################################################################### +# Step 2: Search through the schedule space +# ------------------------------------------ +# We pick the last layer on resnet as test case. +# Since our space is very large, :code:`XGBModel` is most suitable +# for our case. Here we only do 20 trials for demonstration. +# In practice, making 1000 trials usually can find some good kernels +# for this workload. + +tgt = tvm.target.cuda() + +# The last layer in resnet +N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1) +# Generate workload key with the ansor API +wkl_key = ansor.make_workload_key_func(conv2d_nchw, (N, H, W, CO, CI, KH, KW, strides, padding)) +# Generate ComputeDAG using the workload key +dag = ansor.workload_key_to_dag(wkl_key) +task = ansor.SearchTask(dag, wkl_key, target=tgt) + +log_file = "conv2d_nchw.json" +seed = 0 +random.seed(seed) +cost_model = ansor.XGBModel() +search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed) + +######################################################################### +# The :code:`ansor.RPCRunnerWarpper` is used to create a RPC runner environment, +# +# Use local gpu, measure 10 times for every schedule to reduce variance. The timeout +# for each running is set to 4 seconds. 
+# +# During the searching process, we may generate several invalid schedules and they +# will be filtered out. It's fine to see "Encountered errors during feature extraction." +# in the tuning logs. + +with ansor.RPCRunnerWarpper("cuda", repeat=3, min_repeat_ms=100, timeout=4) as rpc_runner: + tune_option = ansor.TuneOption(n_trials=20, + runner=rpc_runner.runner, + callbacks=[ansor.LogToFile(log_file)]) + state = ansor.auto_schedule(task, search_policy, + tune_option=tune_option) + print(state) + +######################################################################### +# Finally we can directly use the returned result to get the generated schedule, +# while in the following tutorial we'll show how to inspect the best config from +# log file, check correctness, and measure running time. + +# Get history best from log file +inp, res = ansor.best_measure_pair_in_file(log_file) +# Get the task ComputeDAG from log result +dag = ansor.workload_key_to_dag(inp.task.workload_key) +# Apply log result to TVM schedule +s, arg_bufs = dag.apply_steps_from_state(inp.state) +func = tvm.build(s, arg_bufs, target=tgt) + +# check correctness +a_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32) +w_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32) +c_np = conv2d_nchw_python(a_np, w_np, strides, padding) + +ctx = tvm.gpu() +a_tvm = tvm.nd.array(a_np, ctx=ctx) +w_tvm = tvm.nd.array(w_np, ctx=ctx) +c_tvm = tvm.nd.empty(c_np.shape, ctx=ctx) +func(a_tvm, w_tvm, c_tvm) + +tvm.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2) + +# Evaluate running time. Here we choose a large repeat number (400) to reduce the noise +# and the overhead of kernel launch. You can also use nvprof to validate the result. +evaluator = func.time_evaluator(func.entry_name, ctx, number=400) +print('Time cost of this operator: %f' % evaluator(a_tvm, w_tvm, c_tvm).mean) + diff --git a/tutorials/ansor/tune_simple_subgraph.py b/tutorials/ansor/tune_simple_subgraph.py index 8555d6163c32..2af33c1e88ba 100644 --- a/tutorials/ansor/tune_simple_subgraph.py +++ b/tutorials/ansor/tune_simple_subgraph.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """ +.. 
_ansor-simple-subgraph: + Writing compute expression and Using Ansor auto-scheduler ========================================================= **Author**: `Lianmin Zheng `_, \ From 74ec7d0b792c31d04993d4b0e4ae1ea912e4e792 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 9 Jun 2020 04:40:01 -0700 Subject: [PATCH 19/45] add tune_test.py (the old tune_wkl.py) (#19) * add tune_test.py (the old tune_wkl.py) * update * fix measure * fix for gpu --- .gitignore | 3 + python/tvm/ansor/__init__.py | 8 +- python/tvm/ansor/auto_schedule.py | 21 +- python/tvm/ansor/measure.py | 52 +- python/tvm/ansor/workload_registry.py | 4 +- scripts/common.py | 1017 +++++++++++++++++ scripts/tune_test.py | 195 ++++ src/ansor/auto_schedule.cc | 16 +- src/ansor/auto_schedule.h | 4 +- tests/python/unittest/test_ansor_measure.py | 17 +- .../unittest/test_ansor_search_policy.py | 51 +- 11 files changed, 1285 insertions(+), 103 deletions(-) create mode 100644 scripts/common.py create mode 100644 scripts/tune_test.py diff --git a/.gitignore b/.gitignore index 506e54d93067..3c03e8ecda7a 100644 --- a/.gitignore +++ b/.gitignore @@ -234,3 +234,6 @@ conda/pkg # antlr files *.tokens *.interp + +# ansor tuning logs +scripts/*.json diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index bfdbaf9c8c8c..2e3553cf725c 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -31,10 +31,10 @@ from .compute_dag import ComputeDAG from .auto_schedule import SearchTask, MetaTileRewritePolicy, TuneOption, HardwareParams from .auto_schedule import auto_schedule -from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, RPCRunnerWarpper +from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, LocalRPCMeasureContext from .cost_model import RandomModel from .cost_model.xgb_model import XGBModel -from .serialization import LogToFile, LogReader, best_measure_pair_in_file -from .workload_registry import register_auto_scheduler_workload_func, workload_key_to_dag, \ - make_workload_key_func +from .serialization import LogToFile, LogReader, best_measure_pair_in_file, write_measure_records_to_file +from .workload_registry import register_auto_scheduler_workload_func, \ + workload_key_to_dag, make_workload_key_func from .task_scheduler import TaskScheduler, SimpleTaskScheduler diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index 5f4b7946b087..1192e6d551e5 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -160,12 +160,12 @@ def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_iter=64, verbose, builder, runner, callbacks) -def auto_schedule(workload, search_policy='default', target=None, - target_host=None, hardware_params=None, - tune_option=None): +def auto_schedule(workload, target=None, + target_host=None, search_policy='default', + hardware_params=None, tune_option=None): """ Do auto schedule for a compute declaration. - The workload paramter can be a `string` as workload_key, or directly + The workload parameter can be a `string` as workload_key, or directly passing a `SearchTask` as input. 
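+
+    Example (an illustrative sketch; assumes `wkl_key` comes from
+    `make_workload_key_func` and `tune_option` is a `TuneOption`)::
+
+        sch, tensors = auto_schedule(wkl_key, target=tgt, tune_option=tune_option)
+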
Parameters @@ -174,8 +174,6 @@ def auto_schedule(workload, search_policy='default', target=None, target : Target - task : SearchTask - target_host : Target = None search_policy : Union[SearchPolicy, str] @@ -203,13 +201,12 @@ def auto_schedule(workload, search_policy='default', target=None, if isinstance(workload, str): sch, tensors = _ffi_api.AutoScheduleByWorkloadKey( - workload, target, target_host, search_policy, hardware_params, - tune_option) + workload, target, target_host, search_policy, hardware_params, tune_option) return sch, tensors elif isinstance(workload, SearchTask): - state = _ffi_api.AutoScheduleBySearchTask(workload, search_policy, - tune_option) - return state + sch, tensors = _ffi_api.AutoScheduleBySearchTask(workload, search_policy, tune_option) + return sch, tensors else: raise ValueError("Invalid workload: " + workload + - ", should be String or SearchTask") + ". Expect a string or SearchTask") + diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index b062eb585d12..299c004f756d 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -37,6 +37,7 @@ from tvm.ir import transform from tvm.rpc.tracker import Tracker from tvm.rpc.server import Server +from tvm.autotvm.measure.measure_methods import set_cuda_target_arch from ..contrib import tar, ndk from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout, request_remote, check_remote from .compute_dag import LayoutRewriteLevel @@ -78,7 +79,7 @@ class BuildResult(Object): def __init__(self, filename, args, error_no, error_msg, time_cost): self.__init_handle_by_constructor__( - _ffi_api.BuildResult, filename, args, error_no, + _ffi_api.BuildResult, filename if filename else "", args, error_no, error_msg if error_msg else "", time_cost) @@ -201,49 +202,32 @@ def __init__(self, key, host, port, priority=1, "and make sure you have free devices on the queue status.") -class RPCRunnerWarpper: - def __init__(self, target=None, priority=1, +class LocalRPCMeasureContext: + def __init__(self, + priority=1, n_parallel=1, timeout=10, - number=3, + number=10, repeat=1, min_repeat_ms=0, cooldown_interval=0.0): - self.target = target - self.priority = priority - self.n_parallel = n_parallel - self.timeout = timeout - self.number = number - self.repeat = repeat - self.min_repeat_ms = min_repeat_ms - self.cooldown_interval = cooldown_interval - - self.tracker = None - self.server = None - self.runner = None - - def __enter__(self): - if self.target == "cuda": - ctx = tvm.context("cuda", 0) + ctx = tvm.context("cuda", 0) + if ctx.exist: cuda_arch = "sm_" + "".join(ctx.compute_version.split('.')) - tvm.autotvm.measure.measure_methods.set_cuda_target_arch(cuda_arch) + set_cuda_target_arch(cuda_arch) host = '0.0.0.0' self.tracker = Tracker(host, port=9000, port_end=10000, silent=True) device_key = '$local$device$%d' % self.tracker.port self.server = Server(host, port=self.tracker.port, port_end=10000, - key=device_key, - use_popen=True, silent=True, - tracker_addr=(self.tracker.host, self.tracker.port)) - self.runner = RPCRunner(device_key, host, self.tracker.port, self.priority, - self.n_parallel, self.timeout, self.number, self.repeat, - self.min_repeat_ms, self.cooldown_interval) - - return self - - def __exit__(self, type, value, trace): - if value: - raise value - + key=device_key, use_popen=True, silent=True, + tracker_addr=(self.tracker.host, self.tracker.port)) + self.runner = RPCRunner(device_key, host, self.tracker.port, priority, + n_parallel, timeout, number, repeat, + 
min_repeat_ms, cooldown_interval) + # wait for the processes to start + time.sleep(0.5) + + def __del__(self): self.tracker.terminate() self.server.terminate() diff --git a/python/tvm/ansor/workload_registry.py b/python/tvm/ansor/workload_registry.py index 381e6009eea8..fccdcf8864be 100644 --- a/python/tvm/ansor/workload_registry.py +++ b/python/tvm/ansor/workload_registry.py @@ -130,7 +130,7 @@ def deserialize_args(args: Tuple) -> List: return ret -@tvm._ffi.register_func("auto_scheduler.workload_key_to_tensors") +@tvm._ffi.register_func("ansor.workload_key_to_tensors") def workload_key_to_tensors(workload_key: str) -> List[Tensor]: """Decode a workload key to the input/output tensors""" workload = json.loads(workload_key) @@ -144,7 +144,7 @@ def workload_key_to_tensors(workload_key: str) -> List[Tensor]: return lookup -@ tvm._ffi.register_func("auto_scheduler.workload_key_to_dag") +@ tvm._ffi.register_func("ansor.workload_key_to_dag") def workload_key_to_dag(workload_key: str) -> ComputeDAG: """Decode a workload key to a compute dag""" tensors = workload_key_to_tensors(workload_key) diff --git a/scripts/common.py b/scripts/common.py new file mode 100644 index 000000000000..4400104bdfe6 --- /dev/null +++ b/scripts/common.py @@ -0,0 +1,1017 @@ +"""Common utility for scripts""" +import argparse +import math +import os +import re +import time +from collections import defaultdict, namedtuple +from typing import Dict, List, Tuple + +import numpy as np +import matplotlib.pyplot as plt + +import topi +import tvm +from tvm import te +from tvm.ansor import (LogReader, make_workload_key_func, + register_auto_scheduler_workload_func, + write_measure_records_to_file) +from tvm.contrib import ndk, util + +############################################################ +###################### Test Workloads #################### +############################################################ + +@register_auto_scheduler_workload_func +def min_mn(M, N): + A = te.placeholder((M, N), name='A') + B = topi.min(A, axis=1) + + return [A, B] + +@register_auto_scheduler_workload_func +def argmin_mn(M, N): + A = te.placeholder((M, N), name='A') + B = topi.argmin(A, axis=1) + + return [A, B] + +@register_auto_scheduler_workload_func +def softmax_mn(M, N): + A = te.placeholder((M, N), name='A') + B = topi.nn.softmax(A, axis=1) + + return [A, B] + +@register_auto_scheduler_workload_func +def norm_bmn(B, M, N): + A = te.placeholder((B, M, N), name='A') + i = te.reduce_axis((0, M)) + j = te.reduce_axis((0, N)) + C = te.compute((B,), lambda b: te.sum(A[b][i][j] * A[b][i][j], axis=[i, j]), name='C') + D = te.compute((B,), lambda b: te.sqrt(C[b]), name='D') + + return [A, D] + +@register_auto_scheduler_workload_func +def add_mn(M, N): + A = te.placeholder((M, N), name='A') + B = te.placeholder((M, N), name='B') + C = te.compute((M, N), lambda i, j: A[i][j] + B[i][j], name='C') + + return [A, B, C] + +@register_auto_scheduler_workload_func +def matmul_nkkm(N, M, K, in_type='float32', out_type='float32', + tensor_core_support=False): + A = te.placeholder((N, K), name='A', dtype=in_type) + B = te.placeholder((K, M), name='B', dtype=in_type) + k = te.reduce_axis((0, K), name='k') + if in_type == out_type: + if not (in_type == 'float16' and out_type == 'float16'): + tensor_core_support = False + C = te.compute((N, M), + lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), + name='C', + attrs={"auto_scheduler_tensor_core_support": "True" if tensor_core_support else "False"}) + else: + if not ((in_type == 'float16' and out_type == 
'float32') or \ + (in_type == 'int8' and out_type == 'int32')): + tensor_core_support = False + C = te.compute((N, M), + lambda i, j: te.sum(A[i][k].astype(out_type) * B[k][j].astype(out_type), + axis=[k]), + name='C', + attrs={"auto_scheduler_tensor_core_support": "True" if tensor_core_support else "False"}) + + return [A, B, C] + +@register_auto_scheduler_workload_func +def dense_layer(batch, in_dim, out_dim): + A = te.placeholder((batch, in_dim), name='A') + B = te.placeholder((out_dim, in_dim), name='B') + k = te.reduce_axis((0, in_dim), name='k') + C = te.compute((batch, out_dim), lambda i, j: te.sum(A[i][k] * B[j][k], axis=[k]), name='C') + + return [A, B, C] + +@register_auto_scheduler_workload_func +def max_pool_2d_nchw(N, C, H, W): + data = te.placeholder((N, C, H, W), name='data') + out = topi.nn.pool(data, (2, 2), (1, 1), (0, 0, 0, 0), pool_type='max', ceil_mode=True, + layout="NCHW", count_include_pad=True) + + return [data, out] + +@register_auto_scheduler_workload_func +def add_min_relu(M, N): + A = te.placeholder((M, N), name='A') + B = te.placeholder((M, N), name='B') + C = topi.add(A, B) + D = topi.min(C, axis=1) + out = topi.nn.relu(D) + return [A, B, out] + +@register_auto_scheduler_workload_func +def conv2d_relu_softmax_min(N, H, W, CI, CO, KH, KW, strides, padding, dilation): + data = te.placeholder((N, CI, H, W), name='data') + kernel = te.placeholder((CO, CI, KH, KW), name='kernel') + conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation) + relu = topi.nn.relu(conv) + softmax = topi.nn.softmax(relu, axis=1) + out = topi.min(softmax, axis=1) + + return [data, kernel, out] + +@register_auto_scheduler_workload_func +def conv2d_nchw_bias(N, H, W, CI, CO, KH, KW, strides, padding, dilation): + data = te.placeholder((N, CI, H, W), name='data') + kernel = te.placeholder((CO, CI, KH, KW), name='kernel') + bias = te.placeholder((CO, 1, 1), name='bias') + conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation) + #out = topi.nn.relu(conv) + out = topi.add(conv, bias) + return [data, kernel, bias, out] + +def conv2d_nhwc_without_layout_rewrite(Input, Filter, stride, padding, dilation, out_dtype='float32'): + """A copy of `topi.nn.conv2d_nhwc` but without the 'layout_free` attribute. + We use this in single op and subgraph evaluation because we don't want to introduce graph level optimization. 
+ """ + assert isinstance(stride, int) or len(stride) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + batch, in_height, in_width, in_channel = Input.shape + if len(Filter.shape) == 10: + kernel_h = Filter.shape[2] * Filter.shape[6] + kernel_w = Filter.shape[3] * Filter.shape[7] + channel = Filter.shape[4] * Filter.shape[8] + num_filter = Filter.shape[0] * Filter.shape[1] * Filter.shape[5] * Filter.shape[9] + #Filter = te.placeholder([kernel_h, kernel_w, channel, num_filter], Filter.dtype, Filter.name) + elif len(Filter.shape) == 11: + kernel_h = Filter.shape[3] * Filter.shape[7] + kernel_w = Filter.shape[4] * Filter.shape[8] + channel = Filter.shape[5] * Filter.shape[9] + num_filter = Filter.shape[0] * Filter.shape[1] * Filter.shape[2] * Filter.shape[6] * Filter.shape[10] + else: + kernel_h, kernel_w, channel, num_filter = Filter.shape + + # compute the output shape + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_top, pad_left, pad_down, pad_right = topi.nn.util.get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w)) + out_channel = num_filter + out_height = topi.util.simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) + out_width = topi.util.simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) + pad_before = [0, pad_top, pad_left, 0] + pad_after = [0, pad_down, pad_right, 0] + PaddedInput = topi.nn.pad(Input, pad_before, pad_after, name="PaddedInput") + rc = te.reduce_axis((0, in_channel), name='rc') + ry = te.reduce_axis((0, kernel_h), name='ry') + rx = te.reduce_axis((0, kernel_w), name='rx') + Output = te.compute( + (batch, out_height, out_width, out_channel), + lambda nn, yy, xx, ff: te.sum( + PaddedInput[nn, yy * stride_h + ry * dilation_h, + xx * stride_w + rx * dilation_w, rc].astype(out_dtype) * + Filter[ry, rx, rc, ff].astype(out_dtype) + , axis=[ry, rx, rc]), + name="Conv2dOutput", tag="conv2d_nhwc") + return Output + + +@register_auto_scheduler_workload_func +def conv2d_nhwc_bias_with_rewrite(N, H, W, CI, CO, KH, KW, strides, padding, dilation): + data = te.placeholder((N, H, W, CI), name='data') + kernel = te.placeholder((KH, KW, CI, CO), name='kernel') + bias = te.placeholder((CO, ), name='bias') + conv = topi.nn.conv2d_nhwc(data, kernel, strides, padding, dilation) + out = topi.add(conv, bias) + return [data, kernel, bias, out] + +@register_auto_scheduler_workload_func +def depthwise_conv2d_nhwc_bias_with_rewrite(N, H, W, CI, CO, KH, KW, strides, padding, dilation): + data = te.placeholder((N, H, W, CI), name='data') + kernel = te.placeholder((KH, KW, CI, 1), name='kernel') + bias = te.placeholder((CO, ), name='bias') + conv = topi.nn.depthwise_conv2d_nhwc(data, kernel, strides, padding, dilation) + out = topi.add(conv, bias) + return [data, kernel, bias, out] + +@register_auto_scheduler_workload_func +def conv2d_nhwc_bias(N, H, W, CI, CO, KH, KW, strides, padding, dilation): + data = te.placeholder((N, H, W, CI), name='data') + kernel = te.placeholder((KH, KW, CI, CO), name='kernel') + bias = te.placeholder((CO, ), name='bias') + conv = conv2d_nhwc_without_layout_rewrite(data, kernel, strides, padding, dilation) + out = topi.add(conv, bias) + return [data, kernel, bias, out] + + +@register_auto_scheduler_workload_func +def 
conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation=1): + data = te.placeholder((N, CI, H, W), name='data') + kernel = te.placeholder((CO, CI, kernel_size, kernel_size), name='kernel') + bias = te.placeholder((CO, 1, 1), name='bias') + bn_scale = te.placeholder((CO, 1, 1), name='bn_scale') + bn_offset = te.placeholder((CO, 1, 1), name='bn_offset') + + OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 + OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 + + conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation) + conv = te.compute((N, CO, OH, OW), + lambda i, j, k, l: conv[i, j, k, l] + bias[j, 0, 0], + name='bias_add') + conv = te.compute((N, CO, OH, OW), + lambda i, j, k, l: conv[i, j, k, l] * bn_scale[j, 0, 0], + name='bn_mul') + conv = te.compute((N, CO, OH, OW), + lambda i, j, k, l: conv[i, j, k, l] + bn_offset[j, 0, 0], + name='bn_add') + out = topi.nn.relu(conv) + + return [data, kernel, bias, bn_offset, bn_scale, out] + +@register_auto_scheduler_workload_func +def conv2d_nhwc_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation=1): + data = te.placeholder((N, H, W, CI), name='data') + kernel = te.placeholder((kernel_size, kernel_size, CI, CO), name='kernel') + bias = te.placeholder((CO,), name='bias') + bn_scale = te.placeholder((CO,), name='bn_scale') + bn_offset = te.placeholder((CO,), name='bn_offset') + + OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 + OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 + + conv = conv2d_nhwc_without_layout_rewrite(data, kernel, strides, padding, dilation) + conv = te.compute((N, OH, OW, CO), + lambda i, j, k, l: conv[i, j, k, l] + bias[l], + name='bias_add') + conv = te.compute((N, OH, OW, CO), + lambda i, j, k, l: conv[i, j, k, l] * bn_scale[l], + name='bn_mul') + conv = te.compute((N, OH, OW, CO), + lambda i, j, k, l: conv[i, j, k, l] + bn_offset[l], + name='bn_add') + out = topi.nn.relu(conv) + + return [data, kernel, bias, bn_offset, bn_scale, out] + +resnet_conv2d_configs = { + # format : N, H, W, CI, CO, KH, KW, strides, padding, dilation + '18': [ + (1, 224, 224, 3, 64, 7, 7, (2, 2), (3, 3), (1, 1)), + (1, 56, 56, 64, 128, 3, 3, (2, 2), (1, 1), (1, 1)), + (1, 56, 56, 64, 128, 1, 1, (2, 2), (0, 0), (1, 1)), + (1, 56, 56, 64, 64, 3, 3, (1, 1), (1, 1), (1, 1)), + (1, 56, 56, 64, 64, 1, 1, (1, 1), (0, 0), (1, 1)), + (1, 28, 28, 128, 256, 3, 3, (2, 2), (1, 1), (1, 1)), + (1, 28, 28, 128, 256, 1, 1, (2, 2), (0, 0), (1, 1)), + (1, 28, 28, 128, 128, 3, 3, (1, 1), (1, 1), (1, 1)), + (1, 14, 14, 256, 512, 3, 3, (2, 2), (1, 1), (1, 1)), + (1, 14, 14, 256, 512, 1, 1, (2, 2), (0, 0), (1, 1)), + (1, 14, 14, 256, 256, 3, 3, (1, 1), (1, 1), (1, 1)), + (1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1), (1, 1)), + ], + '50': [ + (1, 224, 224, 3, 64, 7, 7, (2, 2), (3, 3), (1, 1)), + (1, 56, 56, 256, 512, 1, 1, (2, 2), (0, 0), (1, 1)), + (1, 56, 56, 256, 128, 1, 1, (2, 2), (0, 0), (1, 1)), + (1, 56, 56, 256, 64, 1, 1, (1, 1), (0, 0), (1, 1)), + (1, 56, 56, 64, 256, 1, 1, (1, 1), (0, 0), (1, 1)), + (1, 56, 56, 64, 64, 3, 3, (1, 1), (1, 1), (1, 1)), + (1, 56, 56, 64, 64, 1, 1, (1, 1), (0, 0), (1, 1)), + (1, 28, 28, 512, 1024, 1, 1, (2, 2), (0, 0), (1, 1)), + (1, 28, 28, 512, 256, 1, 1, (2, 2), (0, 0), (1, 1)), + (1, 28, 28, 512, 128, 1, 1, (1, 1), (0, 0), (1, 1)), + (1, 28, 28, 128, 512, 1, 1, (1, 1), (0, 0), (1, 1)), + (1, 28, 28, 128, 128, 3, 3, (1, 1), (1, 1), (1, 1)), + (1, 14, 14, 1024, 2048, 1, 1, (2, 2), (0, 0), (1, 
1)), + (1, 14, 14, 1024, 512, 1, 1, (2, 2), (0, 0), (1, 1)), + (1, 14, 14, 1024, 256, 1, 1, (1, 1), (0, 0), (1, 1)), + (1, 14, 14, 256, 1024, 1, 1, (1, 1), (0, 0), (1, 1)), + (1, 14, 14, 256, 256, 3, 3, (1, 1), (1, 1), (1, 1)), + (1, 7, 7, 2048, 512, 1, 1, (1, 1), (0, 0), (1, 1)), + (1, 7, 7, 512, 2048, 1, 1, (1, 1), (0, 0), (1, 1)), + (1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1), (1, 1)), + ], +} + +# number of appearance for all conv2ds in resnet +resnet_conv2d_weights = { + '18': [1, 1, 1, 4, 1, 1, 1, 3, 1, 1, 3, 3], + '50': [1, 1, 1, 2, 4, 3, 1, 1, 1, 3, 4, 4, 1, 1, 5, 6, 6, 2, 3, 3], +} + + +def parse_workload_name(name: str) -> List[str]: + """Parse workload name with wildcard character and abbreviation to standard names""" + if name.startswith('matmul-'): # e.g. matmul-512, matmul-1024, matmul-+ + N = name.split('-', maxsplit=1)[1] + if N == '+': + cfg_list = [256, 512, 1024] + else: + cfg_list = [N] + return ["matmul-%s" % x for x in cfg_list] + elif name.startswith('dense-'): # e.g. dense-1-512-1024, dense-16-512-512 + N = name.split('-', maxsplit=1)[1] + if N == '+': + cfg_list = ["1-512-512", "16-512-512"] + else: + cfg_list = [N] + return ["dense-%s" % x for x in cfg_list] + elif name.startswith('min-'): # e.g. min-4096 + N = name.split('-', maxsplit=1)[1] + if N == '+': + cfg_list = [4096, 8192, 16384] + else: + cfg_list = [N] + return ["min-%s" % x for x in cfg_list] + elif name.startswith('argmin-'): # e.g. argmin-4096 + N = name.split('-', maxsplit=1)[1] + if N == '+': + cfg_list = [4096, 8192, 16384] + else: + cfg_list = [N] + return ["argmin-%s" % x for x in cfg_list] + elif name.startswith('softmax-'): # e.g. softmax-4096 + N = name.split('-', maxsplit=1)[1] + if N == '+': + cfg_list = [4096, 8192, 16384] + else: + cfg_list = [N] + return ["softmax-%s" % x for x in cfg_list] + elif name.startswith('add-'): # e.g. add-4096 + N = name.split('-', maxsplit=1)[1] + if N == '+': + cfg_list = [4096, 8192, 16384] + else: + cfg_list = [N] + return ["add-%s" % x for x in cfg_list] + elif name.startswith('norm-'): # e.g. norm-1024 + N = name.split('-', maxsplit=1)[1] + if N == '+': + cfg_list = [4096, 8192, 16384] + else: + cfg_list = [N] + return ["norm-%s" % x for x in cfg_list] + elif name.startswith('add-min-relu'): # e.g. add-min-relu-4096 + N = name.split('-', maxsplit=3)[3] + if N == '+': + cfg_list = [4096, 8192, 16384] + else: + cfg_list = [N] + return ["add-min-relu-%s" % x for x in cfg_list] + elif name.startswith('nhwc-resnet-'): # e.g. nhwc-resnet-50.C1 + res = re.match(r'nhwc-resnet-(\d+).C([\d\+]+)(.B(\d+))?', name) + n_layers = res.group(1) + if res.group(2) == '+': + idx_list = range(len(resnet_conv2d_configs[n_layers])) + else: + idx_list = [int(res.group(2))] + + batch_size = 1 if res.group(4) is None else int(res.group(4)) + return ['nhwc-resnet-%s.C%d.B%d' % (n_layers, i, batch_size) for i in idx_list] + elif name.startswith('resnet-'): # e.g. 
resnet-50.C1, resnet-50.C1.B2, resnet-50.C+.B2 + res = re.match(r'resnet-(\d+).C([\d\+]+)(.B(\d+))?', name) + n_layers = res.group(1) + if res.group(2) == '+': + idx_list = range(len(resnet_conv2d_configs[n_layers])) + else: + idx_list = [int(res.group(2))] + + batch_size = 1 if res.group(4) is None else int(res.group(4)) + return ['resnet-%s.C%d.B%d' % (n_layers, i, batch_size) for i in idx_list] + elif name in ['conv2d-bn-relu', 'conv2d-relu-softmax-min', 'max-pool-2d', 'conv2d-rewrite', 'depthwise-conv2d-rewrite']: + return [name] + else: + raise ValueError("Invalid workload " + name) + + +def get_workload_keys(name: str) -> List[str]: + """Parse workload name and return the workload keys""" + normalized_names = parse_workload_name(name) + + ret = [] + for name in normalized_names: + if name.startswith('matmul-'): + name_split = name.split('-') + in_type = out_type = 'float32' + tensor_core_support = False + if len(name_split) == 2: # e.g. matmul-512 + N = K = M = int(name_split[1]) + elif len(name_split) == 4: # e.g. matmul-32-256-512 + N = int(name_split[1]) + K = int(name_split[2]) + M = int(name_split[3]) + elif len(name_split) == 6: # e.g. matmul-32-512-512-float16-float32 + N = int(name_split[1]) + K = int(name_split[2]) + M = int(name_split[3]) + in_type = name_split[4] + out_type = name_split[5] + elif len(name_split) == 7: # e.g. matmul-32-512-512-float16-float32-tc + N = int(name_split[1]) + K = int(name_split[2]) + M = int(name_split[3]) + in_type = name_split[4] + out_type = name_split[5] + tensor_core_support = name_split[6] == "tc" + else: + raise ValueError("Invalid matmul workload") + ret.append(make_workload_key_func(matmul_nkkm, + (N, M, K, in_type, out_type, tensor_core_support))) + elif name.startswith('dense-'): # e.g. dense-1-512-1024, dense-16-512-512 + name_split = name.split('-') + assert len(name_split) == 4 + batch = int(name_split[1]) + in_dim = int(name_split[2]) + out_dim = int(name_split[3]) + ret.append(make_workload_key_func(dense_layer, (batch, in_dim, out_dim))) + elif name.startswith('min-'): # e.g. min-4096 + name_split = name.split('-') + if len(name_split) == 2: + M = 64 + N = int(name_split[1]) + elif len(name_split) == 3: + M = int(name_split[1]) + N = int(name_split[2]) + else: + raise ValueError("Invalid min workload") + ret.append(make_workload_key_func(min_mn, (M, N))) + elif name.startswith('argmin-'): # e.g. argmin-4096 + name_split = name.split('-') + if len(name_split) == 2: + M = 64 + N = int(name_split[1]) + elif len(name_split) == 3: + M = int(name_split[1]) + N = int(name_split[2]) + else: + raise ValueError("Invalid argmin workload") + ret.append(make_workload_key_func(argmin_mn, (M, N))) + elif name.startswith('softmax-'): # e.g. softmax-4096 + name_split = name.split('-') + if len(name_split) == 2: + M = 64 + N = int(name_split[1]) + elif len(name_split) == 3: + M = int(name_split[1]) + N = int(name_split[2]) + else: + raise ValueError("Invalid softmax workload") + ret.append(make_workload_key_func(softmax_mn, (M, N))) + elif name.startswith('add-min-relu'): # e.g. add-min-relu-4096 + name_split = name.split('-') + if len(name_split) == 4: + M = 64 + N = int(name_split[3]) + elif len(name_split) == 5: + M = int(name_split[3]) + N = int(name_split[4]) + else: + raise ValueError("Invalid workload") + ret.append(make_workload_key_func(add_min_relu, (M, N))) + elif name.startswith('add-'): # e.g. 
add-4096 + name_split = name.split('-') + if len(name_split) == 2: + N = M = int(name_split[1]) + elif len(name_split) == 3: + M = int(name_split[1]) + N = int(name_split[2]) + else: + raise ValueError("Invalid add workload") + ret.append(make_workload_key_func(add_mn, (M, N))) + elif name.startswith('norm-'): # e.g. norm-4096 + name_split = name.split('-') + B = 2 + if len(name_split) == 2: + N = M = int(name_split[1]) + elif len(name_split) == 3: + M = int(name_split[1]) + N = int(name_split[2]) + else: + raise ValueError("Invalid norm workload") + ret.append(make_workload_key_func(norm_bmn, (B, M, N))) + elif name.startswith('nhwc-resnet-'): # e.g. nhwc-resnet-50.C1.B2 + res = re.match(r'nhwc-resnet-(\d+).C(\d+).B(\d+)', name) + n_layers = res.group(1) + idx = int(res.group(2)) + batch_size = 1 if res.group(3) is None else int(res.group(3)) + args = list(resnet_conv2d_configs[n_layers][idx]) + args[0] = batch_size + ret.append(make_workload_key_func(conv2d_nhwc_bias, args)) + elif name.startswith('resnet-'): # e.g. resnet-50.C1.B2 + res = re.match(r'resnet-(\d+).C(\d+).B(\d+)', name) + n_layers = res.group(1) + idx = int(res.group(2)) + batch_size = 1 if res.group(3) is None else int(res.group(3)) + args = list(resnet_conv2d_configs[n_layers][idx]) + args[0] = batch_size + ret.append(make_workload_key_func(conv2d_nchw_bias, args)) + elif name == 'max-pool-2d': + return [make_workload_key_func(max_pool_2d_nchw, (2, 512, 7, 7))] + elif name == 'conv2d-bn-relu': + return [make_workload_key_func(conv2d_nhwc_bn_relu, + (1, 7, 7, 512, 512, 3, 1, 1, 1)) ] + elif name == 'conv2d-rewrite': + return [ make_workload_key_func(conv2d_nhwc_bias_with_rewrite, + (1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1), (1, 1)))] + elif name == 'depthwise-conv2d-rewrite': + return [ make_workload_key_func(depthwise_conv2d_nhwc_bias_with_rewrite, + (1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1), (1, 1)))] + elif name == 'conv2d-relu-softmax-min': + return [make_workload_key_func(conv2d_relu_softmax_min, + (1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1), (1, 1)))] + else: + raise ValueError("Invalid workload " + name) + + return ret + + +def get_workload_weights(name: str) -> List[float]: + """Return weights for workload name""" + if name.startswith('resnet-'): + res = re.match(r'resnet-(\d+).C+', name) + n_layers = res.group(1) + return np.array(resnet_conv2d_weights[n_layers]) + else: + return np.ones(len(get_workload_keys(name))) + + +############################################################ +###################### Measure Tools #################### +############################################################ + + +def measure_schedule(s, + bufs, + target, + target_host=None, + remote=None, + ndk_cc=None, + number=10, + repeat=3, + min_repeat_ms=500): + """Measure the time cost of a schedule""" + func = tvm.build(s, bufs, target=target, target_host=target_host) + if remote: + ctx = remote.context(str(target), 0) + temp = util.tempdir() + remote_path = temp.relpath("tmp_deploy_lib.so") + os.environ['TVM_NDK_CC'] = ndk_cc + func.export_library(remote_path, ndk.create_shared) + remote.upload(remote_path) + func = remote.load_module("tmp_deploy_lib.so") + else: + ctx = tvm.context(str(target), 0) + + if os.environ.get('TVM_AUTO_CACHE_FLUSH', '0') == '1': + min_repeat_ms = 0 + number = 1 + + time_f = func.time_evaluator(func.entry_name, + ctx, + number=number, + repeat=repeat, + min_repeat_ms=min_repeat_ms) + + np_args = [np.ones(topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs] + args = [tvm.nd.array(x, ctx=ctx) for x 
in np_args]
+    ctx.sync()
+
+    costs = time_f(*args).results
+
+    return costs
+
+def check_correctness(s, bufs, s_ref, buf_ref, target, target_host=None, remote=None, ndk_cc=None):
+    """Check the correctness of a schedule against a reference schedule"""
+    func = tvm.build(s, bufs, target=target, target_host=target_host)
+    func_ref = tvm.build(s_ref, buf_ref, target='llvm')
+
+    if remote:
+        raise NotImplementedError("correctness check on a remote device is not supported yet")
+    else:
+        ctx = tvm.context(str(target), 0)
+        ctx_ref = tvm.cpu()
+
+    np_args = [np.ones(topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs]
+    args = [tvm.nd.array(x, ctx=ctx) for x in np_args]
+    args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args]
+    ctx.sync()
+
+    func(*args)
+    func_ref(*args_ref)
+
+    for arr, arr_ref in zip(args, args_ref):
+        np.testing.assert_allclose(arr.asnumpy(), arr_ref.asnumpy())
+
+
+############################################################
+#####################  Other Utilities  ####################
+############################################################
+
+
+def geomean(xs):
+    """Compute geometric mean"""
+    return math.exp(math.fsum(math.log(x) for x in xs) / len(xs))
+
+
+def str2bool(v):
+    if isinstance(v, bool):
+        return v
+    if v.lower() in ('yes', 'true', 't', 'y', '1'):
+        return True
+    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+        return False
+    else:
+        raise argparse.ArgumentTypeError('Boolean value expected.')
+
+
+global last_tic
+last_tic = None
+
+
+def PRINT_TIME(msg):
+    """Print the time interval between different calls. This is for debugging,
+    so the name is in capital letters."""
+    global last_tic
+    now = time.time()
+
+    if last_tic is None:
+        last_tic = now
+
+    print(msg, now - last_tic)
+    last_tic = now
+
+
+############################################################
+######################  I/O Utilities  #####################
+############################################################
+
+# The format of a line in the results file
+BenchmarkRecord = namedtuple("BenchmarkRecord", [
+    'device', 'backend', 'workload_type', 'workload_name', 'library', 'algorithm', 'value',
+    'time_stamp'
+])
+
+
+class BaselineDatabase:
+    """A class for querying records in the baseline database"""
+    def __init__(self, filename):
+        self.filename = filename
+
+        self.lines = []
+        for line in open(filename):
+            if line.startswith('#') or line.isspace():
+                continue
+            self.lines.append(line.split('\t'))
+
+    def filter_records(self, devices=None, backends=None, wkl_names=None, libraries=None):
+        ret = []
+        for line in self.lines:
+            line = BenchmarkRecord(*line)
+
+            if devices is not None and line.device not in devices:
+                continue
+            if backends is not None and line.backend not in backends:
+                continue
+            if wkl_names is not None and line.workload_name not in wkl_names:
+                continue
+            if libraries is not None and line.library not in libraries:
+                continue
+
+            ret.append(line)
+        return ret
+
+    def get_data_dict(self, device, target, wkl_names) -> Tuple[Dict, List]:
+        """Return a data dict s.t. 
data[wkl][library] = cost""" + data = defaultdict(lambda: defaultdict(lambda: 1e10)) + + all_libraries = set() + + if "cpu" in target.keys: + backends = ['cpu'] + elif "gpu" in target.keys: + backends = ['gpu'] + else: + raise ValueError("Invalid target: " + target) + + # Read costs for baselines + records = self.filter_records(devices=[device], backends=backends, wkl_names=wkl_names) + for record in records: + # use min over (possible) multiple algorithms + all_libraries.add(record.library) + data[record.workload_name][record.library] = \ + min(data[record.workload_name][record.library], + np.mean(eval(record.value)['costs'])) + + return data, list(all_libraries) + + +class LogFileDatabase: + """A class for indexing best records in a log file""" + def __init__(self, filename: str, n_lines: int = -1): + inputs, results = LogReader(filename).read_lines(n_lines) + + # best records, search by (target_key, workload_key). e.g. ('gpu', 'conv2d...') + self.best_by_targetkey = {} + + # best according to (model, workload_key). e.g. ('1080ti', 'conv2d...')) + self.best_by_model = {} + + # find best records and build the index + for inp, res in zip(inputs, results): + if res.error_no != 0: + continue + + # use target keys in tvm target system as key to build best map + for target_key in inp.task.target.keys: + key = (target_key, inp.task.workload_key) + if key not in self.best_by_targetkey: + self.best_by_targetkey[key] = (inp, res) + else: + _, other_res = self.best_by_targetkey[key] + if np.mean([x.value for x in other_res.costs]) > \ + np.mean([x.value for x in res.costs]): + self.best_by_targetkey[key] = (inp, res) + + # use model as key to build best map + key = (inp.task.target.model, inp.task.workload_key) + if key not in self.best_by_model: + if inp.task.target.model != 'unknown': + self.best_by_model[key] = (inp, res) + else: + _, other_res = self.best_by_model[key] + if np.mean([x.value for x in other_res.costs]) > \ + np.mean([x.value for x in res.costs]): + self.best_by_model[key] = (inp, res) + + def write_best(self, filename: str): + best_records = list(self.best_by_targetkey.values()) + inputs = [x[0] for x in best_records] + results = [x[1] for x in best_records] + write_measure_records_to_file(filename, inputs, results) + + +############################################################ +###################### Plot Utilities #################### +############################################################ + +def max_curve(raw_curve): + """Return b[i] = max(a[:i]) """ + ret = [] + cur_max = -np.inf + for x in raw_curve: + cur_max = max(cur_max, x) + ret.append(cur_max) + return ret + +def min_curve(raw_curve): + """Return b[i] = min(a[:i]) """ + ret = [] + cur_min = np.inf + for x in raw_curve: + cur_min = min(cur_min, x) + ret.append(cur_min) + return ret + +def mean_curve(raw_curve, window_size=None): + """Return b[i] = mean(a[:i]) """ + ret = [] + mean = 0 + if window_size is None: + for i, x in enumerate(raw_curve): + mean = (mean * i + x) / (i + 1) + ret.append(mean) + else: + for i, x in enumerate(raw_curve): + if i >= window_size: + mean = (mean * window_size + x - raw_curve[i - window_size]) / window_size + else: + mean = (mean * i + x) / (i + 1) + ret.append(mean) + return ret + + +def enhance_color(color, h=1, l=1, s=1): + """Make color looks better for pyplot""" + import matplotlib.colors as mc + import colorsys + try: + c = mc.cnames[color] + except: + c = color + c = np.array(colorsys.rgb_to_hls(*mc.to_rgb(c))) + + h, l, s = h * c[0], l * c[1], s * c[2] + h, l, s = 
[max(min(x, 1), 0) for x in [h, l, s]] + + return colorsys.hls_to_rgb(h, l, s) + + +method_color_dict = { + 'ours': 'C0', + 'AutoTVM': 'C1', + + 'tensorflow': 'C2', + 'tensorflow-tensorrt': 'C9', + 'tflite': 'C2', + + 'pytorch': enhance_color('C3', l=1.1, s=0.9), + + 'FlexTensor': enhance_color('C5'), + 'halide': enhance_color('teal', l=1.25), + + 'Limit space': 'C7', + 'No fine-tuning': 'C8', + 'No task scheduler': 'C1', +} + +def method2color(method): + if '-batch-' in method: + method, batch_size = method.split('-batch-') + #return enhance_color(method_color_dict[method], s=1.1, l=1.5) + return method_color_dict[method] + else: + return method_color_dict[method] + +method_order_list = [ + 'pytorch', 'tensorflow', 'tensorflow-xla', 'tensorflow-tensorrt', + 'tflite', 'halide', 'FlexTensor', 'AutoTVM', + + 'Limit space', 'No fine-tuning', + 'ours', +] + +def method2order(method): + if '-batch-' in method: + method, batch_size = method.split('-batch-') + batch_size = int(batch_size) + return method_order_list.index(method) + batch_size / 100 + else: + return method_order_list.index(method) + +show_name_replace_dict = { + 'pytorch': "PyTorch", + 'tensorflow-tensorrt': 'TensorRT-TF', + 'tensorflow': 'TensorFlow', + 'tflite': 'TensorFlow Lite', + 'halide': 'Halide', + + 'ours': 'Ansor (ours)', + 'batch-16': 'batch', + + 'resnet_50': 'ResNet-50', + 'mobilenet_v2': 'Mobilenet V2', + 'resnet_18_3d': '3D-ResNet', + 'dcgan': 'DCGAN', + 'dqn': 'DQN', + 'bert': 'BERT', +} + +def show_name(name): + # if name.startswith('resnet-'): + # return name.split('.')[1] + for key, value in show_name_replace_dict.items(): + name = name.replace(key, value) + + return name + +def draw_grouped_bar_chart(data, baseline='pytorch', output='out.png', + yscale_log=False, yticks=None, y_max=None, + legend_bbox_to_anchor=None, legend_nrow=None, + figure_size=None, figax=None, draw_ylabel=True, draw_legend=True): + width = 1 + gap = 1.5 + fontsize = 19 + xticks_font_size = fontsize - 2 + + figure_size = figure_size or (11, 4) + legend_bbox_to_anchor = legend_bbox_to_anchor or (0.45, 1.35) + + all_methods = set() + legend_set = {} + + if figax is None: + fig, ax = plt.subplots() + axes = [] + axes.append(ax) + else: + ax = figax + + x0 = 0 + xticks = [] + xlabels = [] + + workloads = list(data.keys()) + for wkl in workloads: + ys = [] + colors = [] + + methods = list(data[wkl].keys()) + + if baseline in data[wkl]: + baseline_cost = data[wkl][baseline] + else: + # normalize to best library + baseline_cost = 1e10 + for method in methods: + if data[wkl][method] < baseline_cost: + baseline_cost = data[wkl][method] + + methods.sort(key=lambda x: method2order(x)) + for method in methods: + relative_speedup = baseline_cost / data[wkl][method] + if yticks is None: + ys.append(relative_speedup) + else: + ys.append(max(relative_speedup, yticks[0] * 1.1)) + colors.append(method2color(method)) + + # draw the bars + xs = np.arange(x0, x0 + len(ys)) + bars = ax.bar(xs, ys, width=width, color=colors) + + for method, bar_obj in zip(methods, bars): + all_methods.add(method) + if method not in legend_set: + legend_set[method] = bar_obj + + # tick and label + x0 += len(ys) + gap + + xticks.append(x0 - gap - len(ys)*width/2.0 - width/2.0) + xlabels.append(show_name(wkl)) + + ax.set_xticks(xticks) + ax.set_xticklabels(xlabels, fontsize=xticks_font_size) + plt.tick_params(axis='x', which='both', bottom='off', top='off') + + if draw_ylabel is True: + ax.set_ylabel('Relative Speedup', fontsize=fontsize) + elif isinstance(draw_ylabel, str): + 
ax.set_ylabel(draw_ylabel, fontsize=fontsize) + + if yscale_log: + ax.set_yscale('log', basey=2) + if yticks is not None: + ax.set_yticks(yticks) + if y_max: + ax.set_ylim(top=y_max) + + from matplotlib.ticker import FormatStrFormatter + ax.set_yticklabels(ax.get_yticks(), fontsize=fontsize) + ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f')) + ax.yaxis.grid(linewidth=0.4, linestyle='dotted') # draw grid line + ax.set_axisbelow(True) # grid lines are behind the rest + ax.tick_params(bottom=False, top=False, right=False) + + # put legend outside the plot + all_methods = list(all_methods) + all_methods.sort(key=lambda x : method2order(x)) + + if draw_legend: + legend_nrow = legend_nrow or 2 + ncol = (len(all_methods) + legend_nrow - 1)// legend_nrow + ax.legend([legend_set[x] for x in all_methods], + [show_name(x) for x in all_methods], + fontsize=fontsize-1, + loc='upper center', + bbox_to_anchor=legend_bbox_to_anchor, + ncol=ncol, + handlelength=1.0, + handletextpad=0.5, + columnspacing=1.1) + + if figax is None: + fig.set_size_inches(figure_size) + fig.savefig(output, bbox_inches='tight') + print("Output the plot to %s" % output) + + +def to_str_round(x, decimal=6): + if isinstance(x, str): + return x + if isinstance(x, (list, tuple)) or isinstance(x, np.ndarray): + return "[" + ", ".join([to_str_round(y, decimal=decimal) + for y in x]) + "]" + if isinstance(x, dict): + return str({k: eval(to_str_round(v)) for k, v in x.items()}) + if isinstance(x, int): + return str(x) + if isinstance(x, float): + format_str = "%%.%df" % decimal + return format_str % x + raise ValueError("Invalid value: " + str(x)) + diff --git a/scripts/tune_test.py b/scripts/tune_test.py new file mode 100644 index 000000000000..68f9dfadb8d4 --- /dev/null +++ b/scripts/tune_test.py @@ -0,0 +1,195 @@ +"""Use auto scheduler to tune workloads""" +import argparse +import logging +import os +import random + +import numpy as np + +import tvm +from tvm import ansor +from tvm.ansor.utils import request_remote + +from common import get_workload_keys, get_workload_weights, measure_schedule, str2bool + + +def make_cost_model(model_type, load_model_file, load_log_file): + if model_type == 'xgb': + model = ansor.XGBModel() + if load_model_file: + print("Load pretrained model...") + model.load(load_model_file) + elif load_log_file: + model.load_log_file(load_log_file) + elif model_type == "random": + model = ansor.RandomModel() + else: + raise ValueError("Invalid model: " + model_type) + return model + + +def tune_workload(wkl_key, target, target_host, n_trials, num_measure_per_iter, + policy, log_file, verbose, + model_type, load_model_file, load_log_file, + build_timeout, local_measure=True, device_key=None, host="0.0.0.0", + port=9190, n_parallel=1, ndk_cc=None, remeasure=True): + """Tune a workload""" + + if False: + # Debug info. 
Print static analysis results from the access analyzer + dag = auto_scheduler.workload_key_to_dag(wkl_key) + print(dag.access_analyzer) + exit() + + model = make_cost_model(model_type, load_model_file, load_log_file) + + if policy == 'meta-rewrite': + policy = ansor.MetaTileRewritePolicy(program_cost_model=model) + elif policy == 'beam-search': + policy = ansor.MetaTileRewritePolicy(program_cost_model=model, + params={'use_beam_search': 1}) + else: + raise ValueError("Invalid search policy: " + policy) + + if local_measure: + builder = ansor.LocalBuilder(build_timeout) + if target.target_name == "cuda": + measure_ctx = ansor.LocalRPCMeasureContext(repeat=1, min_repeat_ms=400) + runner = measure_ctx.runner + else: + runner = ansor.LocalRunner(repeat=1, min_repeat_ms=400) + else: + os.environ['TVM_NDK_CC'] = ndk_cc + builder = ansor.LocalBuilder(build_timeout, build_func='ndk') + runner = ansor.RPCRunner(device_key, host=host, port=port, + repeat=1, min_repeat_ms=400, + n_parallel=n_parallel) + + tune_option = ansor.TuneOption(n_trials=n_trials, + num_measure_per_iter=num_measure_per_iter, + verbose=verbose, + builder=builder, + runner=runner, + callbacks=[ansor.LogToFile(log_file)]) + s, bufs = ansor.auto_schedule(wkl_key, + target=target, target_host=target_host, + search_policy=policy, + tune_option=tune_option) + + if remeasure: + print("Found schedule:") + print(tvm.lower(s, bufs, simple_mode=True)) + print("Redo measurement for double check...") + if local_measure: + remote = None + else: + remote = request_remote(device_key, host, port, 1) + cost = np.mean((measure_schedule(s, bufs, target, remote=remote, ndk_cc=ndk_cc))) + print("Best schedule: %.2f GFLOPS\tcost: %.3f ms" % + (ansor.ComputeDAG(bufs).flop_ct / cost / 1e9, cost * 1e3)) + + +def tune_workloads_jointly(wkl_keys, weights, joint_tuner, target, target_host, + n_trials, num_measure_per_iter, + search_policy, log_file, verbose, + model_type, load_model_file, load_log_file, + build_timeout, local_measure=True, device_key=None, + host="0.0.0.0", port=9190, n_parallel=1, ndk_cc=None): + """Tune for multiple workloads jointly""" + if local_measure: + builder = ansor.LocalBuilder(timeout=build_timeout) + if target.target_name == "cuda": + measure_ctx = ansor.LocalRPCMeasureContext(repeat=1, min_repeat_ms=400) + runner = measure_ctx.runner + else: + runner = ansor.LocalRunner(repeat=1, min_repeat_ms=400) + else: + os.environ['TVM_NDK_CC'] = ndk_cc + builder = ansor.LocalBuilder(build_func='ndk', timeout=build_timeout) + runner = ansor.RPCRunner(device_key, host=host, port=port, + repeat=1, min_repeat_ms=400, + n_parallel=n_parallel) + + tasks = [] + for wkl_key in wkl_keys: + dag = ansor.workload_key_to_dag(wkl_key) + tasks.append(ansor.SearchTask(dag, wkl_key, target, target_host)) + + def objective_func(costs): + return sum(c * w for c, w in zip(costs, weights)) + + tuner = ansor.SimpleTaskScheduler(tasks, objective_func, strategy=joint_tuner, + load_log_file=load_log_file, load_model_file=load_model_file) + + search_policy = "%s.%s" % (search_policy, model_type) + tune_option = ansor.TuneOption(n_trials=n_trials, + num_measure_per_iter=num_measure_per_iter, + builder=builder, + verbose=verbose, + runner=runner, + callbacks=[ansor.LogToFile(log_file)]) + tuner.tune(tune_option, search_policy) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--wkl", type=str, required=True) + parser.add_argument("--n-trials", type=int, default=1000) + parser.add_argument("--target", type=str, 
default='llvm -mcpu=core-avx2') + parser.add_argument("--target-host", type=str, default=None) + parser.add_argument("--policy", type=str, choices=['meta-rewrite', 'beam-search'], default='meta-rewrite') + parser.add_argument("--log-file", type=str, help="Write log of measurement results to this file") + parser.add_argument("--build-timeout", type=int, default=10) + parser.add_argument("--run-timeout", type=int, default=60) + parser.add_argument("--model-type", type=str, choices=['xgb', 'random', 'no-share'], default='xgb') + parser.add_argument("--load-model", type=str) + parser.add_argument("--load-log", type=str, help="Load history log for pre-training the cost model") + parser.add_argument("--seed", type=int, default=0, help='random seed') + parser.add_argument("--verbose", type=int, default=1) + parser.add_argument("--task-scheduler", type=str, default='no', + choices=['no', 'gradient', 'round-robin'], + help='The strategy of task scheduler') + parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True) + parser.add_argument("--device-key", type=str, default=None) + parser.add_argument("--host", type=str, default='0.0.0.0') + parser.add_argument("--port", type=int, default=9190) + parser.add_argument("--n-parallel", type=int, default=1) + parser.add_argument("--ndk-cc", type=str, default=None) + parser.add_argument("--num-measure-per-iter", type=int, default=48, + help="The number of programs to be measured at each iteration") + args = parser.parse_args() + + np.random.seed(args.seed) + random.seed(args.seed) + + logging.basicConfig() + logging.getLogger('auto_scheduler').setLevel(logging.DEBUG) + + log_file = args.log_file or args.wkl + ".json" + load_log_file = args.load_log or log_file + + target = tvm.target.create(args.target) + wkl_keys = get_workload_keys(args.wkl) + weights = get_workload_weights(args.wkl) + if args.task_scheduler == 'no': + # tune workloads one by one + for wkl_key in wkl_keys: + tune_workload(wkl_key, target, args.target_host, args.n_trials, + args.num_measure_per_iter, + args.policy, log_file, args.verbose, + args.model_type, args.load_model, load_log_file, + args.build_timeout, + args.local_measure, args.device_key, args.host, + args.port, args.n_parallel, args.ndk_cc, + remeasure=len(wkl_keys) == 1) + else: + # tune workloads jointly using JointTuner + tune_workloads_jointly(wkl_keys, weights, args.joint_tuner, + target, args.target_host, + args.n_trials, args.num_measure_per_iter, + args.policy, log_file, args.verbose, + args.model_type, args.load_model, args.load_log, + args.build_timeout, + args.local_measure, args.device_key, args.host, + args.port, args.n_parallel, args.ndk_cc) + diff --git a/src/ansor/auto_schedule.cc b/src/ansor/auto_schedule.cc index a0fa18874a69..3c793e5957f5 100644 --- a/src/ansor/auto_schedule.cc +++ b/src/ansor/auto_schedule.cc @@ -48,16 +48,18 @@ TuneOption TuneOptionNode::make(int n_trials, int early_stopping, return TuneOption(node); } -State AutoSchedule(SearchTask task, SearchPolicy search_policy, +std::pair > AutoSchedule(SearchTask task, SearchPolicy search_policy, TuneOption tune_option) { // Search for the best schedule ProgramMeasurer measurer = ProgramMeasurerNode::make(tune_option->builder, tune_option->runner, tune_option->callbacks, tune_option->verbose); - return search_policy->Search( + State state = search_policy->Search( task, tune_option->n_trials, tune_option->early_stopping, tune_option->num_measure_per_iter, tune_option->verbose, measurer); + + return 
task->compute_dag.ApplySteps(state->transform_steps); } std::pair > AutoSchedule( @@ -68,10 +70,8 @@ std::pair > AutoSchedule( SearchTask task = SearchTaskNode::make( std::move(dag), std::move(workload_key), std::move(target), std::move(target_host), std::move(hardware_params)); - State state = AutoSchedule(std::move(task), std::move(search_policy), + return AutoSchedule(std::move(task), std::move(search_policy), std::move(tune_option)); - - return task->compute_dag.ApplySteps(state->transform_steps); } TVM_REGISTER_GLOBAL("ansor.TuneOption") @@ -86,7 +86,11 @@ TVM_REGISTER_GLOBAL("ansor.TuneOption") TVM_REGISTER_GLOBAL("ansor.AutoScheduleBySearchTask") .set_body_typed([](SearchTask task, SearchPolicy search_policy, TuneOption tune_option) { - return AutoSchedule(task, search_policy, tune_option); + te::Schedule sch; + Array return_tensors; + std::tie(sch, return_tensors) = AutoSchedule(task, search_policy, tune_option); + + return Array{sch, return_tensors}; }); TVM_REGISTER_GLOBAL("ansor.AutoScheduleByWorkloadKey") diff --git a/src/ansor/auto_schedule.h b/src/ansor/auto_schedule.h index f68e844ba776..3737f8c5d096 100644 --- a/src/ansor/auto_schedule.h +++ b/src/ansor/auto_schedule.h @@ -67,8 +67,8 @@ class TuneOptionNode : public Object { TVM_DEFINE_COW_OBJECT_REF(TuneOption, ObjectRef, TuneOptionNode); /*! \brief Auto schedule for a compute declaration */ -State AutoSchedule(SearchTask task, SearchPolicy search_policy, - TuneOption tune_option); +std::pair > AutoSchedule( + SearchTask task, SearchPolicy search_policy, TuneOption tune_option); std::pair > AutoSchedule( std::string workload_key, Target target, Target target_host, diff --git a/tests/python/unittest/test_ansor_measure.py b/tests/python/unittest/test_ansor_measure.py index 0385568894fe..2ac54d3c765b 100644 --- a/tests/python/unittest/test_ansor_measure.py +++ b/tests/python/unittest/test_ansor_measure.py @@ -19,8 +19,6 @@ import tvm from tvm import ansor -from tvm.rpc.tracker import Tracker -from tvm.rpc.server import Server import tempfile from test_ansor_common import get_tiled_matmul @@ -69,26 +67,17 @@ def test_measure_local_builder_rpc_runner(): tgt = tvm.target.create("llvm") task = ansor.SearchTask(dag, "test", tgt) - minp = ansor.MeasureInput(task, s0) + local_builder = ansor.LocalBuilder() - host = '0.0.0.0' - tracker = Tracker(host, port=9000, port_end=10000, silent=True) - device_key = '$local$device$%d' % tracker.port - server = Server(host, port=tracker.port, port_end=10000, - key=device_key, - use_popen=True, silent=True, - tracker_addr=(tracker.host, tracker.port)) - rpc_runner = ansor.RPCRunner(device_key, host, tracker.port) + measure_ctx = ansor.LocalRPCMeasureContext() + rpc_runner = measure_ctx.runner bress = local_builder.build([minp]) assert bress[0].error_no == 0 mress = rpc_runner.run([minp], bress) assert mress[0].error_no == 0 - tracker.terminate() - server.terminate() - if __name__ == "__main__": test_serialization() diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index a28456574abe..5cb67dba39fe 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -18,7 +18,6 @@ """Test search policy""" import random -import os import numpy as np import tempfile @@ -33,10 +32,10 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local' random.seed(seed) N = 128 - A, B, C = matmul_ansor_test(N, N, N) - dag = ansor.ComputeDAG([A, B, C]) - tgt = 
tvm.target.create(target) - task = ansor.SearchTask(dag, "test", tgt) + workload_key = ansor.make_workload_key_func(matmul_ansor_test, (N, N, N)) + dag = ansor.workload_key_to_dag(workload_key) + target = tvm.target.create(target) + task = ansor.SearchTask(dag, workload_key, target) with tempfile.NamedTemporaryFile() as fp: log_file = fp.name @@ -44,35 +43,29 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local' search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed) tune_option = ansor.TuneOption(n_trials=n_trials, runner=runner, callbacks=[ansor.LogToFile(log_file)]) - state = ansor.auto_schedule(task, search_policy, + sch, args = ansor.auto_schedule(task, search_policy, tune_option=tune_option) - sch, args = dag.apply_steps_from_state(state) + inp, res = ansor.best_measure_pair_in_file(log_file, workload_key, target) - print("==== Get State ====") - print(state) - print("==== Get Python Code ====") - print(dag.print_python_code_from_state(state)) + print("==== Python Code ====") + print(dag.print_python_code_from_state(inp.state)) try: - print("==== Get Lowered Stmt ====") + print("==== Lowered Stmt ====") print(tvm.lower(sch, args, simple_mode=True)) - mod = tvm.build(sch, args, tgt) + mod = tvm.build(sch, args, target) - ctx = tvm.context(target, 0) - a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(A.dtype), ctx) - b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(B.dtype), ctx) - c = tvm.nd.array(np.zeros((N, N), dtype=C.dtype), ctx) + ctx = tvm.context(str(target), 0) + dtype = dag.tensors[0].dtype + a = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=(N, N)).astype(dtype), ctx) + c = tvm.nd.array(np.zeros((N, N), dtype=dtype), ctx) mod(a, b, c) tvm.testing.assert_allclose(c.asnumpy(), np.dot( a.asnumpy(), b.asnumpy()), rtol=1e-5) print("==== Verification passed ====") except Exception: raise Exception("Error encountered with seed: %d" % (seed)) - - inp, res = ansor.best_measure_pair_in_file(log_file) - s0 = dag.infer_bound_from_state(state) - s1 = dag.infer_bound_from_state(inp.state) - assert s0 == s1 print() @@ -81,23 +74,23 @@ def test_search_basic(): def test_search_xgb_model_rpc_runner(): - with ansor.RPCRunnerWarpper() as rpc_runner: - search_common(seed=456787236, cost_model=ansor.XGBModel(), - runner=rpc_runner.runner) + measure_ctx = ansor.LocalRPCMeasureContext() + search_common(seed=456787236, cost_model=ansor.XGBModel(), + runner=measure_ctx.runner) def test_search_opencl(): if tvm.context("opencl", 0).exist: - with ansor.RPCRunnerWarpper() as rpc_runner: - search_common("opencl", 380344973, rpc_runner.runner) + measure_ctx = ansor.LocalRPCMeasureContext() + search_common("opencl", 380344973, measure_ctx.runner) else: print("OpenCL device not found, skip this test.") def test_search_cuda(): if tvm.context("cuda", 0).exist: - with ansor.RPCRunnerWarpper("cuda") as rpc_runner: - search_common("cuda", 903667810, rpc_runner.runner) + measure_ctx = ansor.LocalRPCMeasureContext() + search_common("cuda", 903667810, measure_ctx.runner) else: print("CUDA device not found, skip this test.") From cd0a516271c2d7b5f239fa601247f969929a90d3 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Wed, 10 Jun 2020 21:01:11 +0800 Subject: [PATCH 20/45] Code refine for tune_test.py & Add a pre load callback (#20) * Bug fix for tutorials * Add PreLoadMeasuredStates * Add search_callback support for task tuner * Code refine for tune_test.py * Update * Update * Update * Update * Bug fix --- 
python/tvm/ansor/__init__.py | 2 +- python/tvm/ansor/auto_schedule.py | 38 +++- python/tvm/ansor/measure.py | 40 +++- python/tvm/ansor/task_scheduler.py | 9 +- scripts/tune_test.py | 212 ++++++++---------- src/ansor/auto_schedule.cc | 23 +- src/ansor/auto_schedule.h | 11 +- .../search_policy/meta_tile_rewrite_policy.cc | 5 +- .../search_policy/meta_tile_rewrite_policy.h | 9 +- src/ansor/search_policy/search_policy.cc | 82 ++++++- src/ansor/search_policy/search_policy.h | 36 ++- src/ansor/serialization.cc | 4 + src/ansor/serialization.h | 1 + .../unittest/test_ansor_search_policy.py | 6 +- tutorials/ansor/tune_conv2d_cuda.py | 29 ++- tutorials/ansor/tune_simple_subgraph.py | 41 +--- 16 files changed, 355 insertions(+), 193 deletions(-) diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index 2e3553cf725c..1029875917aa 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -29,7 +29,7 @@ # Shortcut from .compute_dag import ComputeDAG -from .auto_schedule import SearchTask, MetaTileRewritePolicy, TuneOption, HardwareParams +from .auto_schedule import SearchTask, MetaTileRewritePolicy, TuneOption, HardwareParams, PreLoadMeasuredStatesCallback from .auto_schedule import auto_schedule from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, LocalRPCMeasureContext from .cost_model import RandomModel diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index 1192e6d551e5..5b5eb4fe183b 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -69,7 +69,15 @@ def __init__(self, dag, workload_key, target, target_host=None, class SearchPolicy(Object): def continue_search(self, task, num_measure, verbose, measurer): return _ffi_api.SearchPolicyContinueSearchOneRound(self, task, num_measure, verbose, measurer) + + def set_task(self, task): + _ffi_api.SearchPolicySetTask(self, task); + def set_verbose(self, verbose): + _ffi_api.SearchPolicySetVerbose(self, verbose); + + def run_callbacks(self, callbacks): + _ffi_api.SearchPolicyRunCallbacks(self, callbacks) @tvm._ffi.register_object("ansor.MetaTileRewritePolicy") class MetaTileRewritePolicy(SearchPolicy): @@ -117,6 +125,21 @@ def __init__(self, seed or random.randint(1, 1 << 30)) +@tvm._ffi.register_object("ansor.SearchCallback") +class SearchCallback(Object): + pass + + +@tvm._ffi.register_object("ansor.PreLoadMeasuredStatesCallback") +class PreLoadMeasuredStatesCallback(SearchCallback): + """ A SearchCallback that used for search policy to load measured hash + from the log file. 
+ """ + def __init__(self, filename: str): + self.__init_handle_by_constructor__( + _ffi_api.PreLoadMeasuredStatesCallback, filename) + + @tvm._ffi.register_object("ansor.TuneOption") class TuneOption(Object): """ The options for tuning @@ -135,11 +158,13 @@ class TuneOption(Object): Builder which builds the program runner: Runner Runner which runs the program and measure time costs - callbacks: List[MeasureCallback] + measure_callbacks: List[MeasureCallback] Callback functions + pre_search_callbacks: List[SearchCallback] """ def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_iter=64, - verbose=1, builder='local', runner='local', callbacks=None): + verbose=1, builder='local', runner='local', measure_callbacks=None, + pre_search_callbacks=None): if isinstance(builder, str): if builder == 'local': builder = LocalBuilder() @@ -152,12 +177,15 @@ def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_iter=64, else: raise ValueError("Invalid builder: " + runner) - if callbacks is None: - callbacks = [] + if measure_callbacks is None: + measure_callbacks = [] + + if pre_search_callbacks is None: + pre_search_callbacks = [] self.__init_handle_by_constructor__( _ffi_api.TuneOption, n_trials, early_stopping, num_measure_per_iter, - verbose, builder, runner, callbacks) + verbose, builder, runner, measure_callbacks, pre_search_callbacks) def auto_schedule(workload, target=None, diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index 299c004f756d..610e9529090f 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -174,6 +174,16 @@ def __init__(self, @tvm._ffi.register_object("ansor.ProgramMeasurer") class ProgramMeasurer(Object): + """ + Parameters + ---------- + builder : Builder + runner : Runner + callbacks : List[MeasureCallback] + verbose : Int + max_continuous_error : Float + """ + def __init__(self, builder: Builder, runner: Runner, callbacks: List[MeasureCallback], verbose: int, max_continuous_error: int = -1): @@ -182,6 +192,21 @@ def __init__(self, builder: Builder, runner: Runner, @tvm._ffi.register_object("ansor.RPCRunner") class RPCRunner(Runner): + """ + Parameters + ---------- + key : Str + host : Str + port : Int + priority : Int + n_parallel : Int + timeout : Int + number : Int + repeat : Int + min_repeat_ms : Int + cooldown_interval : Float + """ + def __init__(self, key, host, port, priority=1, n_parallel=1, timeout=10, @@ -203,6 +228,19 @@ def __init__(self, key, host, port, priority=1, class LocalRPCMeasureContext: + """ A context wrapper for RPCRunner. 
+ + Parameters + ---------- + priority : Int + n_parallel : Int + timeout : Int + number : Int + repeat : Int + min_repeat_ms : Int + cooldown_interval : Float + """ + def __init__(self, priority=1, n_parallel=1, @@ -228,8 +266,8 @@ def __init__(self, time.sleep(0.5) def __del__(self): - self.tracker.terminate() self.server.terminate() + self.tracker.terminate() class MeasureErrorNo(object): diff --git a/python/tvm/ansor/task_scheduler.py b/python/tvm/ansor/task_scheduler.py index 5144591d4f98..082b2d265140 100644 --- a/python/tvm/ansor/task_scheduler.py +++ b/python/tvm/ansor/task_scheduler.py @@ -153,7 +153,7 @@ def tune(self, tune_option: TuneOption, search_policy: Union[str, List[SearchPol self.tune_option = tune_option if self.use_debug_measurement_simulator is None: self.measurer = ProgramMeasurer(tune_option.builder, tune_option.runner, - tune_option.callbacks, tune_option.verbose) + tune_option.measure_callbacks, tune_option.verbose) self.ct = 0 self.tic = time.time() # reset num_measure_per_iter to make sure every task is tuned at least once @@ -167,6 +167,13 @@ def tune(self, tune_option: TuneOption, search_policy: Union[str, List[SearchPol self.sequential_now_task_idx = 0 self.sequential_now_task_begin_ct = 0 + for i in range(len(self.tasks)): + search_policy = self.search_policies[i] + task = self.tasks[i] + search_policy.set_task(task) + search_policy.set_verbose(tune_option.verbose) + search_policy.run_callbacks(tune_option.pre_search_callbacks) + # do a round robin first if self.strategy != 'sequential': for i in range(len(self.tasks)): diff --git a/scripts/tune_test.py b/scripts/tune_test.py index 68f9dfadb8d4..1f75f0dd583e 100644 --- a/scripts/tune_test.py +++ b/scripts/tune_test.py @@ -13,102 +13,67 @@ from common import get_workload_keys, get_workload_weights, measure_schedule, str2bool -def make_cost_model(model_type, load_model_file, load_log_file): - if model_type == 'xgb': - model = ansor.XGBModel() - if load_model_file: - print("Load pretrained model...") - model.load(load_model_file) - elif load_log_file: - model.load_log_file(load_log_file) - elif model_type == "random": - model = ansor.RandomModel() +def replay_workload(wkl_key, target, target_host, log_file, + local_measure=True, device_key=None, host="0.0.0.0", + port=9190, ndk_cc=None): + inp, res = ansor.best_measure_pair_in_file(log_file, wkl_key, target) + if inp is None: + print("Cannot find log for: %s" % (wkl_key)) else: - raise ValueError("Invalid model: " + model_type) - return model + dag = ansor.workload_key_to_dag(inp.task.workload_key) + s, bufs = dag.apply_steps_from_state(inp.state) + + print("Found schedule for: %s" % (wkl_key)) + print(tvm.lower(s, bufs, simple_mode=True)) + if local_measure: + remote = None + else: + remote = request_remote(device_key, host, port, 1) + cost = np.mean((measure_schedule(s, bufs, target, remote=remote, ndk_cc=ndk_cc))) + print("Best schedule: %.2f GFLOPS\tcost: %.3f ms" % + (ansor.ComputeDAG(bufs).flop_ct / cost / 1e9, cost * 1e3)) -def tune_workload(wkl_key, target, target_host, n_trials, num_measure_per_iter, - policy, log_file, verbose, - model_type, load_model_file, load_log_file, - build_timeout, local_measure=True, device_key=None, host="0.0.0.0", - port=9190, n_parallel=1, ndk_cc=None, remeasure=True): +def tune_workload(wkl_key, target, target_host, policy, model_type, load_model_file, + load_log_file, tune_option): """Tune a workload""" if False: # Debug info. 
Print static analysis results from the access analyzer - dag = auto_scheduler.workload_key_to_dag(wkl_key) + dag = ansor.workload_key_to_dag(wkl_key) print(dag.access_analyzer) exit() - model = make_cost_model(model_type, load_model_file, load_log_file) + if model_type == 'xgb': + model = ansor.XGBModel() + if load_model_file: + print("Load pretrained model...") + model.load(load_model_file) + elif load_log_file: + model.load_log_file(load_log_file) + elif model_type == "random": + model = ansor.RandomModel() + else: + raise ValueError("Invalid model: " + model_type) if policy == 'meta-rewrite': policy = ansor.MetaTileRewritePolicy(program_cost_model=model) elif policy == 'beam-search': policy = ansor.MetaTileRewritePolicy(program_cost_model=model, - params={'use_beam_search': 1}) + params={'use_beam_search': 1}) else: raise ValueError("Invalid search policy: " + policy) - if local_measure: - builder = ansor.LocalBuilder(build_timeout) - if target.target_name == "cuda": - measure_ctx = ansor.LocalRPCMeasureContext(repeat=1, min_repeat_ms=400) - runner = measure_ctx.runner - else: - runner = ansor.LocalRunner(repeat=1, min_repeat_ms=400) - else: - os.environ['TVM_NDK_CC'] = ndk_cc - builder = ansor.LocalBuilder(build_timeout, build_func='ndk') - runner = ansor.RPCRunner(device_key, host=host, port=port, - repeat=1, min_repeat_ms=400, - n_parallel=n_parallel) - - tune_option = ansor.TuneOption(n_trials=n_trials, - num_measure_per_iter=num_measure_per_iter, - verbose=verbose, - builder=builder, - runner=runner, - callbacks=[ansor.LogToFile(log_file)]) s, bufs = ansor.auto_schedule(wkl_key, target=target, target_host=target_host, search_policy=policy, tune_option=tune_option) - if remeasure: - print("Found schedule:") - print(tvm.lower(s, bufs, simple_mode=True)) - print("Redo measurement for double check...") - if local_measure: - remote = None - else: - remote = request_remote(device_key, host, port, 1) - cost = np.mean((measure_schedule(s, bufs, target, remote=remote, ndk_cc=ndk_cc))) - print("Best schedule: %.2f GFLOPS\tcost: %.3f ms" % - (ansor.ComputeDAG(bufs).flop_ct / cost / 1e9, cost * 1e3)) - -def tune_workloads_jointly(wkl_keys, weights, joint_tuner, target, target_host, - n_trials, num_measure_per_iter, - search_policy, log_file, verbose, - model_type, load_model_file, load_log_file, - build_timeout, local_measure=True, device_key=None, - host="0.0.0.0", port=9190, n_parallel=1, ndk_cc=None): +def tune_workloads_jointly(wkl_keys, weights, task_scheduler, target, target_host, + search_policy, model_type, load_model_file, load_log_file, + tune_option): """Tune for multiple workloads jointly""" - if local_measure: - builder = ansor.LocalBuilder(timeout=build_timeout) - if target.target_name == "cuda": - measure_ctx = ansor.LocalRPCMeasureContext(repeat=1, min_repeat_ms=400) - runner = measure_ctx.runner - else: - runner = ansor.LocalRunner(repeat=1, min_repeat_ms=400) - else: - os.environ['TVM_NDK_CC'] = ndk_cc - builder = ansor.LocalBuilder(build_func='ndk', timeout=build_timeout) - runner = ansor.RPCRunner(device_key, host=host, port=port, - repeat=1, min_repeat_ms=400, - n_parallel=n_parallel) tasks = [] for wkl_key in wkl_keys: @@ -118,78 +83,99 @@ def tune_workloads_jointly(wkl_keys, weights, joint_tuner, target, target_host, def objective_func(costs): return sum(c * w for c, w in zip(costs, weights)) - tuner = ansor.SimpleTaskScheduler(tasks, objective_func, strategy=joint_tuner, + tuner = ansor.SimpleTaskScheduler(tasks, objective_func, strategy=task_scheduler, 
load_log_file=load_log_file, load_model_file=load_model_file) - search_policy = "%s.%s" % (search_policy, model_type) - tune_option = ansor.TuneOption(n_trials=n_trials, - num_measure_per_iter=num_measure_per_iter, - builder=builder, - verbose=verbose, - runner=runner, - callbacks=[ansor.LogToFile(log_file)]) tuner.tune(tune_option, search_policy) if __name__ == "__main__": parser = argparse.ArgumentParser() + # Task related options parser.add_argument("--wkl", type=str, required=True) - parser.add_argument("--n-trials", type=int, default=1000) parser.add_argument("--target", type=str, default='llvm -mcpu=core-avx2') parser.add_argument("--target-host", type=str, default=None) + parser.add_argument("--n-trials", type=int, default=1000) + parser.add_argument("--num-measure-per-iter", type=int, default=48, + help="The number of programs to be measured at each iteration") + parser.add_argument("--tune", type=str2bool, nargs='?', const=True, default=True) + # Strategy related options + parser.add_argument("--seed", type=int, default=0, help='random seed') parser.add_argument("--policy", type=str, choices=['meta-rewrite', 'beam-search'], default='meta-rewrite') - parser.add_argument("--log-file", type=str, help="Write log of measurement results to this file") - parser.add_argument("--build-timeout", type=int, default=10) - parser.add_argument("--run-timeout", type=int, default=60) parser.add_argument("--model-type", type=str, choices=['xgb', 'random', 'no-share'], default='xgb') - parser.add_argument("--load-model", type=str) - parser.add_argument("--load-log", type=str, help="Load history log for pre-training the cost model") - parser.add_argument("--seed", type=int, default=0, help='random seed') - parser.add_argument("--verbose", type=int, default=1) parser.add_argument("--task-scheduler", type=str, default='no', choices=['no', 'gradient', 'round-robin'], help='The strategy of task scheduler') + # File related options + parser.add_argument("--log-file", type=str, help="Write log of measurement results to this file") + parser.add_argument("--load-model", type=str, help="Load pre trained cost model file") + parser.add_argument("--load-log", type=str, help="Load history log for pre-training the cost model") + # Detailed control options + parser.add_argument("--build-timeout", type=int, default=10) + parser.add_argument("--run-timeout", type=int, default=60) + parser.add_argument("--verbose", type=int, default=1) parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True) parser.add_argument("--device-key", type=str, default=None) parser.add_argument("--host", type=str, default='0.0.0.0') parser.add_argument("--port", type=int, default=9190) parser.add_argument("--n-parallel", type=int, default=1) parser.add_argument("--ndk-cc", type=str, default=None) - parser.add_argument("--num-measure-per-iter", type=int, default=48, - help="The number of programs to be measured at each iteration") args = parser.parse_args() np.random.seed(args.seed) random.seed(args.seed) logging.basicConfig() - logging.getLogger('auto_scheduler').setLevel(logging.DEBUG) + logging.getLogger('ansor').setLevel(logging.DEBUG) log_file = args.log_file or args.wkl + ".json" - load_log_file = args.load_log or log_file target = tvm.target.create(args.target) wkl_keys = get_workload_keys(args.wkl) - weights = get_workload_weights(args.wkl) - if args.task_scheduler == 'no': - # tune workloads one by one - for wkl_key in wkl_keys: - tune_workload(wkl_key, target, args.target_host, args.n_trials, - 
args.num_measure_per_iter, - args.policy, log_file, args.verbose, - args.model_type, args.load_model, load_log_file, - args.build_timeout, - args.local_measure, args.device_key, args.host, - args.port, args.n_parallel, args.ndk_cc, - remeasure=len(wkl_keys) == 1) - else: - # tune workloads jointly using JointTuner - tune_workloads_jointly(wkl_keys, weights, args.joint_tuner, - target, args.target_host, - args.n_trials, args.num_measure_per_iter, - args.policy, log_file, args.verbose, - args.model_type, args.load_model, args.load_log, - args.build_timeout, - args.local_measure, args.device_key, args.host, - args.port, args.n_parallel, args.ndk_cc) + if args.tune: + load_log_file = args.load_log or log_file + weights = get_workload_weights(args.wkl) + + builder = runner = measure_ctx = None + if args.local_measure: + builder = ansor.LocalBuilder(timeout=args.build_timeout) + if target.target_name == "cuda": + measure_ctx = ansor.LocalRPCMeasureContext(repeat=1, min_repeat_ms=400) + runner = measure_ctx.runner + else: + runner = ansor.LocalRunner(repeat=1, min_repeat_ms=400) + else: + os.environ['TVM_NDK_CC'] = args.ndk_cc + builder = ansor.LocalBuilder(timeout=args.build_timeout, build_func='ndk') + runner = ansor.RPCRunner(args.device_key, host=args.host, port=args.port, + repeat=1, min_repeat_ms=400, n_parallel=args.n_parallel) + + tune_option = ansor.TuneOption(n_trials=args.n_trials, + num_measure_per_iter=args.num_measure_per_iter, + verbose=args.verbose, + builder=builder, + runner=runner, + measure_callbacks=[ansor.LogToFile(log_file)], + pre_search_callbacks=[ansor.PreLoadMeasuredStatesCallback(log_file)]) + + if args.task_scheduler == 'no': + # tune workloads one by one + for wkl_key in wkl_keys: + tune_workload(wkl_key, target, args.target_host, args.policy, + args.model_type, args.load_model, load_log_file, + tune_option) + else: + # tune workloads jointly using JointTuner + tune_workloads_jointly(wkl_keys, weights, args.task_scheduler, + target, args.target_host, args.policy, + args.model_type, args.load_model, load_log_file, + tune_option) + if measure_ctx: + del measure_ctx + + if not args.tune or len(wkl_keys) == 1: + for wkl_key in wkl_keys: + replay_workload(wkl_key, target, args.target_host, log_file, + args.local_measure, args.device_key, args.host, + args.port, args.ndk_cc) diff --git a/src/ansor/auto_schedule.cc b/src/ansor/auto_schedule.cc index 3c793e5957f5..200118cf708b 100644 --- a/src/ansor/auto_schedule.cc +++ b/src/ansor/auto_schedule.cc @@ -36,7 +36,8 @@ TVM_REGISTER_NODE_TYPE(TuneOptionNode); TuneOption TuneOptionNode::make(int n_trials, int early_stopping, int num_measure_per_iter, int verbose, Builder builder, Runner runner, - Array callbacks) { + Array measure_callbacks, + Array pre_search_callbacks) { auto node = make_object(); node->n_trials = n_trials; node->early_stopping = early_stopping; @@ -44,20 +45,23 @@ TuneOption TuneOptionNode::make(int n_trials, int early_stopping, node->verbose = verbose; node->builder = std::move(builder); node->runner = std::move(runner); - node->callbacks = std::move(callbacks); + node->measure_callbacks = std::move(measure_callbacks); + node->pre_search_callbacks = std::move(pre_search_callbacks); return TuneOption(node); } -std::pair > AutoSchedule(SearchTask task, SearchPolicy search_policy, - TuneOption tune_option) { +std::pair > AutoSchedule(SearchTask task, + SearchPolicy search_policy, TuneOption tune_option) { // Search for the best schedule ProgramMeasurer measurer = ProgramMeasurerNode::make(tune_option->builder, 
tune_option->runner, - tune_option->callbacks, tune_option->verbose); + tune_option->measure_callbacks, + tune_option->verbose); State state = search_policy->Search( task, tune_option->n_trials, tune_option->early_stopping, - tune_option->num_measure_per_iter, tune_option->verbose, measurer); + tune_option->num_measure_per_iter, tune_option->verbose, measurer, + tune_option->pre_search_callbacks); return task->compute_dag.ApplySteps(state->transform_steps); } @@ -71,16 +75,17 @@ std::pair > AutoSchedule( std::move(dag), std::move(workload_key), std::move(target), std::move(target_host), std::move(hardware_params)); return AutoSchedule(std::move(task), std::move(search_policy), - std::move(tune_option)); + std::move(tune_option)); } TVM_REGISTER_GLOBAL("ansor.TuneOption") .set_body_typed([](int n_trials, int early_stopping, int num_measure_per_iter, int verbose, Builder builder, - Runner runner, Array callbacks) { + Runner runner, Array measure_callbacks, + Array pre_search_callbacks) { return TuneOptionNode::make(n_trials, early_stopping, num_measure_per_iter, verbose, builder, - runner, callbacks); + runner, measure_callbacks, pre_search_callbacks); }); TVM_REGISTER_GLOBAL("ansor.AutoScheduleBySearchTask") diff --git a/src/ansor/auto_schedule.h b/src/ansor/auto_schedule.h index 3737f8c5d096..4e70ac0b577a 100644 --- a/src/ansor/auto_schedule.h +++ b/src/ansor/auto_schedule.h @@ -28,6 +28,7 @@ #include #include #include "measure.h" +#include "search_policy/search_policy.h" namespace tvm { namespace ansor { @@ -45,7 +46,9 @@ class TuneOptionNode : public Object { Builder builder; // Builder which builds the program Runner runner; // Runner which runs the program and measure time // costs - Array callbacks; // Callback functions + Array measure_callbacks; // MeasureCallback functions + Array pre_search_callbacks; // SearchCallback functions + // run before search void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("n_trials", &n_trials); @@ -54,12 +57,14 @@ class TuneOptionNode : public Object { v->Visit("verbose", &verbose); v->Visit("builder", &builder); v->Visit("runner", &runner); - v->Visit("callbacks", &callbacks); + v->Visit("measure_callbacks", &measure_callbacks); + v->Visit("pre_search_callbacks", &pre_search_callbacks); } static TuneOption make(int n_trials, int early_stopping, int num_measure_per_iter, int verbose, Builder builder, - Runner runner, Array callbacks); + Runner runner, Array measure_callbacks, + Array pre_search_callbacks); static constexpr const char* _type_key = "ansor.TuneOption"; TVM_DECLARE_FINAL_OBJECT_INFO(TuneOptionNode, Object); diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.cc b/src/ansor/search_policy/meta_tile_rewrite_policy.cc index f086a8879abb..0a9f97ab9170 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.cc +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.cc @@ -58,12 +58,15 @@ SearchPolicy MetaTileRewritePolicyNode::make(CostModel program_cost_model, State MetaTileRewritePolicyNode::Search(SearchTask task, int n_trials, int early_stopping, int num_measure_per_iter, - int verbose, ProgramMeasurer measurer) { + int verbose, ProgramMeasurer measurer, + Array pre_search_callbacks) { std::vector best_states, random_states; cur_task_ = task; verbose_ = verbose; num_measure_per_iter_ = num_measure_per_iter; + RunCallbacks(pre_search_callbacks); + if (n_trials <= 1) { // no measurement is allowed SearchOneRound(&best_states, 0, &random_states); CHECK_GT(best_states.size(), 0); diff --git 
a/src/ansor/search_policy/meta_tile_rewrite_policy.h b/src/ansor/search_policy/meta_tile_rewrite_policy.h index f92813b11273..8cf61b4d1e11 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.h +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.h @@ -63,7 +63,8 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode { // Return the best state State Search(SearchTask task, int n_trials, int early_stopping, int num_measure_per_iter, - int verbose, ProgramMeasurer measurer) final; + int verbose, ProgramMeasurer measurer, + Array pre_search_callbacks) final; // Continue search. This is used by JointTuner std::pair, Array > ContinueSearchOneRound( @@ -74,8 +75,6 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode { TVM_DECLARE_FINAL_OBJECT_INFO(MetaTileRewritePolicyNode, SearchPolicyNode); - SearchTask cur_task_; // The current task - protected: // Pick states from best states and random states with eps-greedy policy void PickStatesWithEpsGreedy(std::vector* inputs, @@ -100,12 +99,8 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode { SplitFactorizationMemo split_memo_; // Memorize split space for Split std::mt19937 rand_gen_; // Random generator - int verbose_; // Verbose level (0 means silent) int num_measure_per_iter_; // The number of states to measure per iteration - // The set of the already measured states. We store the string format for redundancy check - std::unordered_set measured_states_set_; - // The array of already measured states. std::vector measured_states_vector_; diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc index f3072fda4956..b2ba27bfc6ba 100644 --- a/src/ansor/search_policy/search_policy.cc +++ b/src/ansor/search_policy/search_policy.cc @@ -23,27 +23,91 @@ */ #include "search_policy.h" + #include +#include "../serialization.h" + namespace tvm { namespace ansor { TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode); +TVM_REGISTER_OBJECT_TYPE(PreLoadMeasuredStatesCallbackNode); + +void SearchPolicyNode::PreLoadMeasuredStates(const std::string& log_file) { + LogReader reader = LogReaderNode::make(log_file); + const auto& res = reader->ReadLines(-1); + if (res.first.size()) { + std::vector measured_states; + for (const auto& inp : res.first) { + if (inp->task->workload_key == cur_task_->workload_key && + inp->task->target->target_name.compare( + cur_task_->target->target_name) == 0) { + State state = cur_task_->compute_dag.GetInitState(); + state.CopyOnWrite()->transform_steps = inp->state->transform_steps; + state.DoSteps(inp->state->transform_steps, cur_task_->compute_dag); + measured_states.push_back(std::move(state)); + } + } + cur_task_->compute_dag.InferBound(&measured_states); + for (auto state : measured_states) { + measured_states_set_.insert(state.ToStr()); + } + + StdCout(verbose_) << "Measured States Set: " + << measured_states_set_.size() + << " state hashes loaded from " << log_file << std::endl; + } +} + +void SearchPolicyNode::RunCallbacks(const Array& callbacks) { + if (callbacks.defined() && callbacks.size()) { + PrintTitle("Process search callbacks", verbose_); + for (const auto& callback : callbacks) { + callback->callback(this); + } + } +} + +SearchCallback PreLoadMeasuredStatesCallbackNode::make(std::string filename) { + auto node = make_object(); + node->filename = std::move(filename); + return SearchCallback(node); +} + +void PreLoadMeasuredStatesCallbackNode::callback(SearchPolicyNode* policy) { + policy->PreLoadMeasuredStates(filename); +} // Search Policy 
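// The registrations below are thin FFI wrappers that let the Python side
// drive a search policy step by step; the task scheduler calls set_task,
// set_verbose and run_callbacks in that order before tuning starts
// (see python/tvm/ansor/task_scheduler.py).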
TVM_REGISTER_GLOBAL("ansor.SearchPolicyContinueSearchOneRound") -.set_body([](TVMArgs args, TVMRetValue *ret) { - SearchPolicy policy = args[0]; - SearchTask task = args[1]; - int num_measure = args[2]; - int verbose = args[3]; - ProgramMeasurer measurer = args[4]; - +.set_body_typed([](SearchPolicy policy, SearchTask task, int num_measure, + int verbose, ProgramMeasurer measurer) { Array inputs; Array results; - std::tie(inputs, results) = policy->ContinueSearchOneRound(task, num_measure, verbose, measurer); + std::tie(inputs, results) = policy->ContinueSearchOneRound(task, num_measure, + verbose, measurer); + return Array{inputs, results}; +}); + +TVM_REGISTER_GLOBAL("ansor.SearchPolicyRunCallbacks") +.set_body_typed([](SearchPolicy policy, Array callbacks) { + policy->RunCallbacks(callbacks); +}); + +TVM_REGISTER_GLOBAL("ansor.SearchPolicySetTask") +.set_body_typed([](SearchPolicy policy, SearchTask task) { + policy->cur_task_ = task; +}); + +TVM_REGISTER_GLOBAL("ansor.SearchPolicySetVerbose") +.set_body_typed([](SearchPolicy policy, int verbose) { + policy->verbose_ = verbose; +}); - *ret = Array{inputs, results}; +TVM_REGISTER_GLOBAL("ansor.PreLoadMeasuredStatesCallback") +.set_body_typed([](std::string filename) { + return PreLoadMeasuredStatesCallbackNode::make(filename); }); } // namespace ansor diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h index f2071deab447..0d7ebe94c14f 100644 --- a/src/ansor/search_policy/search_policy.h +++ b/src/ansor/search_policy/search_policy.h @@ -26,6 +26,7 @@ #define TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ #include +#include #include #include #include @@ -36,17 +37,45 @@ namespace tvm { namespace ansor { class SearchPolicy; +class SearchPolicyNode; + +class SearchCallbackNode : public Object { + public: + virtual void callback(SearchPolicyNode* policy) = 0; + static constexpr const char *_type_key = "ansor.SearchCallback"; + TVM_DECLARE_BASE_OBJECT_INFO(SearchCallbackNode, Object); +}; +TVM_DEFINE_MUTABLE_OBJECT_REF(SearchCallback, SearchCallbackNode); + +class PreLoadMeasuredStatesCallbackNode : public SearchCallbackNode { + public: + std::string filename; + + static SearchCallback make(std::string filename); + + void callback(SearchPolicyNode* policy) final; + + static constexpr const char *_type_key = "ansor.PreLoadMeasuredStatesCallback"; + TVM_DECLARE_FINAL_OBJECT_INFO(PreLoadMeasuredStatesCallbackNode, SearchCallbackNode); +}; /*! \brief The base class for search policy */ class SearchPolicyNode : public Object { public: virtual State Search(SearchTask task, int n_trials, int early_stopping, int num_measure_per_iter, - int verbose, ProgramMeasurer measurer) = 0; + int verbose, ProgramMeasurer measurer, + Array pre_search_callbacks) = 0; virtual std::pair, Array > ContinueSearchOneRound( SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) = 0; + void PreLoadMeasuredStates(const std::string& log_file); + void RunCallbacks(const Array& callbacks); + + SearchTask cur_task_; // The current task + int verbose_; // Verbose level (0 means silent) + // Dict keys static constexpr const char* always_unroll_inner_key = "ansor_always_unroll_inner"; static constexpr const char* always_unroll_key = "ansor_always_unroll"; @@ -63,6 +92,11 @@ class SearchPolicyNode : public Object { static constexpr const char *_type_key = "ansor.SearchPolicy"; TVM_DECLARE_BASE_OBJECT_INFO(SearchPolicyNode, Object); + + protected: + // The set of the already measured states. 
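+  // It is filled by the PreLoadMeasuredStates callback (see the
+  // PreLoadMeasuredStates implementation in search_policy.cc).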
+ // We store the string format for redundancy check + std::unordered_set measured_states_set_; }; TVM_DEFINE_MUTABLE_OBJECT_REF(SearchPolicy, SearchPolicyNode); diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index 76f5d4449001..b03acb1edc3c 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -499,6 +499,10 @@ LogReader LogReaderNode::make(std::string filename) { return LogReader(node); } +LogReaderNode::~LogReaderNode() { + infile.close(); +} + bool LogReaderNode::ReadNext(MeasureInputNode* inp, MeasureResultNode* res) { std::string log_version; diff --git a/src/ansor/serialization.h b/src/ansor/serialization.h index a12760bb3acc..d877717db9cb 100644 --- a/src/ansor/serialization.h +++ b/src/ansor/serialization.h @@ -58,6 +58,7 @@ class LogReaderNode : public Object { std::ifstream infile; static LogReader make(std::string filename); + ~LogReaderNode(); /*! \brief Read next line in the log file * \return Whether the read is successful */ diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index 5cb67dba39fe..6fe1012e6629 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -42,9 +42,9 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local' search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed) tune_option = ansor.TuneOption(n_trials=n_trials, runner=runner, - callbacks=[ansor.LogToFile(log_file)]) - sch, args = ansor.auto_schedule(task, search_policy, - tune_option=tune_option) + measure_callbacks=[ansor.LogToFile(log_file)]) + sch, args = ansor.auto_schedule(task, search_policy=search_policy, + tune_option=tune_option) inp, res = ansor.best_measure_pair_in_file(log_file, workload_key, target) print("==== Python Code ====") diff --git a/tutorials/ansor/tune_conv2d_cuda.py b/tutorials/ansor/tune_conv2d_cuda.py index 82a5e8572ba2..caa040d1b3bc 100644 --- a/tutorials/ansor/tune_conv2d_cuda.py +++ b/tutorials/ansor/tune_conv2d_cuda.py @@ -110,11 +110,11 @@ def conv2d_nchw(N, H, W, CO, CI, KH, KW, stride, padding): log_file = "conv2d_nchw.json" seed = 0 random.seed(seed) -cost_model = ansor.XGBModel() +cost_model = ansor.XGBModel(seed=seed) search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed) ######################################################################### -# The :code:`ansor.RPCRunnerWarpper` is used to create a RPC runner environment, +# The :code:`ansor.LocalRPCMeasureContext` is used to create a RPC runner environment. # # Use local gpu, measure 10 times for every schedule to reduce variance. The timeout # for each running is set to 4 seconds. @@ -123,15 +123,24 @@ def conv2d_nchw(N, H, W, CO, CI, KH, KW, stride, padding): # will be filtered out. It's fine to see "Encountered errors during feature extraction." # in the tuning logs. 
-with ansor.RPCRunnerWarpper("cuda", repeat=3, min_repeat_ms=100, timeout=4) as rpc_runner:
-    tune_option = ansor.TuneOption(n_trials=20,
-                                   runner=rpc_runner.runner,
-                                   callbacks=[ansor.LogToFile(log_file)])
-    state = ansor.auto_schedule(task, search_policy,
-                                tune_option=tune_option)
-    print(state)
+measure_ctx = ansor.LocalRPCMeasureContext(repeat=3, min_repeat_ms=100, timeout=4)
+tune_option = ansor.TuneOption(n_trials=20,
+                               runner=measure_ctx.runner,
+                               measure_callbacks=[ansor.LogToFile(log_file)])
+s, arg_bufs = ansor.auto_schedule(task, search_policy=search_policy, tune_option=tune_option)
+
+print("==== Get Lowered Stmt ====")
+print(tvm.lower(s, arg_bufs, simple_mode=True))
+
+# Release the RPC runner environment
+del measure_ctx
 
 #########################################################################
+# From the lowered result shown above, we can see that Ansor has tried
+# techniques such as `Shared Memory Cooperative Fetching`, `Kernel Fusion`,
+# `Axis unroll`, `Axis Vectorize` and so on. There is no need for users to
+# worry about these details; Ansor handles them well.
+#
 # Finally we can directly use the returned result to get the generated schedule,
 # while in the following tutorial we'll show how to inspect the best config from
 # log file, check correctness, and measure running time.
@@ -160,5 +169,5 @@ def conv2d_nchw(N, H, W, CO, CI, KH, KW, stride, padding):
 # Evaluate running time. Here we choose a large repeat number (400) to reduce the noise
 # and the overhead of kernel launch. You can also use nvprof to validate the result.
 evaluator = func.time_evaluator(func.entry_name, ctx, number=400)
-print('Time cost of this operator: %f' % evaluator(a_tvm, w_tvm, c_tvm).mean)
+print('Time cost of this operator: %f s' % evaluator(a_tvm, w_tvm, c_tvm).mean)
 
diff --git a/tutorials/ansor/tune_simple_subgraph.py b/tutorials/ansor/tune_simple_subgraph.py
index 2af33c1e88ba..fedbb399d0cf 100644
--- a/tutorials/ansor/tune_simple_subgraph.py
+++ b/tutorials/ansor/tune_simple_subgraph.py
@@ -113,8 +113,8 @@ def matmul_add(N, L, M, dtype):
 # When proposing the next batch of schedules, Ansor can take different cost models to
 # guide the schedule generating process.
 #
-# * :any:`RandomModel`: Generate and take new schedule randomly
-# * :any:`XGBModel`: Use XGBoost model to estimate the performance of potential schedules, try to pick schedules with better performance in each step
+# * :code:`RandomModel`: Generate and take new schedule randomly
+# * :code:`XGBModel`: Use XGBoost model to estimate the performance of potential schedules, try to pick schedules with better performance in each step
 #
 # XGBModel can explore more efficiently and find better schedules.
@@ -130,7 +130,7 @@ def matmul_add(N, L, M, dtype):
 #
 # Then we create the :code:`tvm.target` and a tuning task.
 
-N, L, M = 64, 64, 64
+N, L, M = 128, 128, 128
 A, B, C, D = matmul_add(N, L, M, 'float32')
 dag = ansor.ComputeDAG([A, B, C, D])
@@ -148,9 +148,6 @@ def matmul_add(N, L, M, dtype):
 # you can do more trials according to your time budget.
 # The :code:`ansor.LogToFile` callback will log the tuning results into a
 # log file, which can be used to get the best config later.
-#
-# Then just call :code:`ansor.auto_schedule` and Ansor will try to find a high
-# performance schedule for the target subgraph automatically.
log_file = "matmul_add.json" @@ -160,34 +157,20 @@ def matmul_add(N, L, M, dtype): search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed) tune_option = ansor.TuneOption(n_trials=5, - callbacks=[ansor.LogToFile(log_file)]) + measure_callbacks=[ansor.LogToFile(log_file)]) -state = ansor.auto_schedule(task, search_policy, - tune_option=tune_option) -print(state) - -######################################################################### -# Finally we apply the history best to be a TVM schedule. -# -# We can call the function :code:`apply_steps_from_state` directly using the returned -# :code:`state` structure. -# :code:`state` can also be used to print out the user friendly Python code on demand. -# -# And since we've record the runing results to file, we can also use the following -# code to reply the best schedule from the log file: -# .. code-block:: c -# -# inp, res = ansor.best_measure_pair_in_file(log_file) -# state = inp.state -# s, arg_bufs = dag.apply_steps_from_state(state) +################################################################ +# Then just call :code:`ansor.auto_schedule` and Ansor will try to find a high +# performance schedule for the target subgraph automatically. # -# With the :code:`state` above, we have lowered result and its python code: +# The returned result will be a :code:`te.schedule` and a list of :code:`te.Tensor`, +# which can be used as the input of :code:`tvm.lower` or :code:`tvm.build`. + +s, arg_bufs = ansor.auto_schedule(task, search_policy=search_policy, + tune_option=tune_option) -s, arg_bufs = dag.apply_steps_from_state(state) print("==== Get Lowered Stmt ====") print(tvm.lower(s, arg_bufs, simple_mode=True)) -print("==== Get Python Code ====") -print(dag.print_python_code_from_state(state)) ######################################################################### # Check the correctness to make sure we generate a right schedule. 
From 3a24e49ee7b7e5d3b09e2fb6062c45923a95abd3 Mon Sep 17 00:00:00 2001
From: Chenfan
Date: Thu, 11 Jun 2020 19:09:24 +0800
Subject: [PATCH 21/45] Add python custom sketch rule (#21)

* Add custom sketch rule

* Bug fix
---
 python/tvm/ansor/__init__.py                  |   3 +-
 python/tvm/ansor/auto_schedule.py             |  46 +++++--
 scripts/tune_test.py                          |   2 +-
 .../search_policy/meta_tile_rewrite_policy.cc | 116 ++++++++++++++----
 .../search_policy/meta_tile_rewrite_policy.h  |  21 +++-
 src/ansor/search_policy/search_policy.cc      |  12 +-
 src/ansor/search_policy/search_policy.h       |  10 +-
 .../unittest/test_ansor_search_policy.py      |  68 +++++++++-
 8 files changed, 230 insertions(+), 48 deletions(-)

diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py
index 1029875917aa..845d1b5e477d 100644
--- a/python/tvm/ansor/__init__.py
+++ b/python/tvm/ansor/__init__.py
@@ -29,7 +29,8 @@
 # Shortcut
 from .compute_dag import ComputeDAG
-from .auto_schedule import SearchTask, MetaTileRewritePolicy, TuneOption, HardwareParams, PreLoadMeasuredStatesCallback
+from .auto_schedule import SearchTask, MetaTileRewritePolicy, TuneOption, HardwareParams, \
+    PreLoadMeasuredStates, PreAddCustomRule
 from .auto_schedule import auto_schedule
 from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, LocalRPCMeasureContext
 from .cost_model import RandomModel
diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py
index 5b5eb4fe183b..e1a0711a80be 100644
--- a/python/tvm/ansor/auto_schedule.py
+++ b/python/tvm/ansor/auto_schedule.py
@@ -67,14 +67,16 @@ def __init__(self, dag, workload_key, target, target_host=None,
 
 @tvm._ffi.register_object("ansor.SearchPolicy")
 class SearchPolicy(Object):
+    """ The base class of search policies.
+    """
     def continue_search(self, task, num_measure, verbose, measurer):
         return _ffi_api.SearchPolicyContinueSearchOneRound(self, task, num_measure, verbose, measurer)
-
+
     def set_task(self, task):
-        _ffi_api.SearchPolicySetTask(self, task);
+        _ffi_api.SearchPolicySetTask(self, task)
 
     def set_verbose(self, verbose):
-        _ffi_api.SearchPolicySetVerbose(self, verbose);
+        _ffi_api.SearchPolicySetVerbose(self, verbose)
 
     def run_callbacks(self, callbacks):
         _ffi_api.SearchPolicyRunCallbacks(self, callbacks)
@@ -130,14 +132,39 @@ class SearchCallback(Object):
     pass
 
 
-@tvm._ffi.register_object("ansor.PreLoadMeasuredStatesCallback")
-class PreLoadMeasuredStatesCallback(SearchCallback):
+@tvm._ffi.register_object("ansor.PreLoadMeasuredStates")
+class PreLoadMeasuredStates(SearchCallback):
     """ A SearchCallback that used for search policy to load measured hash
     from the log file.
+
+    Parameters
+    ----------
+    filename: Str
     """
     def __init__(self, filename: str):
         self.__init_handle_by_constructor__(
-            _ffi_api.PreLoadMeasuredStatesCallback, filename)
+            _ffi_api.PreLoadMeasuredStates, filename)
+
+
+@tvm._ffi.register_object("ansor.PreAddCustomRule")
+class PreAddCustomRule(SearchCallback):
+    """
+    A SearchCallback for MetaTileRewritePolicy that allows users to add
+    custom sketch rules.
+
+    Notice: This is an advanced feature. Make sure you understand how it
+    works before using it; it should only be used in MetaTileRewritePolicy.
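+
+    A rough sketch of the two hooks (the function bodies here are
+    hypothetical; see the parameter descriptions below for the expected
+    signatures)::
+
+        def meet_condition_func(policy, state, stage_id):
+            return 1  # 0 = pass, 1 = apply, 2 = apply and skip the rest
+
+        def apply_func(policy, state, stage_id):
+            # return a list of [new_state, next_stage_id] pairs
+            return [[state, stage_id - 1]]
+
+        ansor.PreAddCustomRule(meet_condition_func, apply_func)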
+ + Parameters + ---------- + meet_condition_func: Function + A function with `(policy, state, stage_id) -> int` + apply_func: Function + A function with `(policy, state, stage_id) -> [[State, int], ...]` + """ + def __init__(self, meet_condition_func, apply_func): + self.__init_handle_by_constructor__( + _ffi_api.PreAddCustomRule, meet_condition_func, apply_func) @tvm._ffi.register_object("ansor.TuneOption") @@ -159,8 +186,13 @@ class TuneOption(Object): runner: Runner Runner which runs the program and measure time costs measure_callbacks: List[MeasureCallback] - Callback functions + Callback functions called after each measure + Candidates: + - ansor.LogToFile pre_search_callbacks: List[SearchCallback] + Callback functions called before the search process + Candidates: + - ansor.PreLoadMeasuredStates """ def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_iter=64, verbose=1, builder='local', runner='local', measure_callbacks=None, diff --git a/scripts/tune_test.py b/scripts/tune_test.py index 1f75f0dd583e..08f0cc19ade2 100644 --- a/scripts/tune_test.py +++ b/scripts/tune_test.py @@ -157,7 +157,7 @@ def objective_func(costs): builder=builder, runner=runner, measure_callbacks=[ansor.LogToFile(log_file)], - pre_search_callbacks=[ansor.PreLoadMeasuredStatesCallback(log_file)]) + pre_search_callbacks=[ansor.PreLoadMeasuredStates(log_file)]) if args.task_scheduler == 'no': # tune workloads one by one diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.cc b/src/ansor/search_policy/meta_tile_rewrite_policy.cc index 0a9f97ab9170..5703e17ba29f 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.cc +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.cc @@ -41,7 +41,8 @@ namespace tvm { namespace ansor { -TVM_REGISTER_OBJECT_TYPE(MetaTileRewritePolicyNode); +TVM_REGISTER_NODE_TYPE(MetaTileRewritePolicyNode); +TVM_REGISTER_OBJECT_TYPE(PreAddCustomRuleNode); // All possible candidates for auto_unroll const std::vector MetaTileRewritePolicyNode::auto_unroll_configs{0, 16, 64, 512, 1024}; @@ -241,7 +242,7 @@ void MetaTileRewritePolicyNode::SearchOneRound(std::vector* best_states, // Synthesize meta structure std::vector meta_structures; - SynthesizeMetaStructure(&meta_structures); + GenerateMetaSketch(&meta_structures); // PrintAllStates(meta_structures); // exit(0); @@ -272,8 +273,8 @@ void MetaTileRewritePolicyNode::SearchOneRound(std::vector* best_states, RandomSampleStates(init_population, &rand_gen_, num_random_states * 10, random_states); } -// The baseclass of derivation rules used in meta structure synthesis -class StructureSynthesisRule { +// The baseclass of derivation rules used in meta sketch generation +class SketchGenerationRule { public: enum ConditionEnum { kPass, kApply, kApplyAndSkipRest @@ -345,7 +346,7 @@ static inline bool ShouldAlwaysBeInlined( } // The rule that inlines simple elementwise ops -class RuleAlwaysInline : public StructureSynthesisRule { +class RuleAlwaysInline : public SketchGenerationRule { public: ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { @@ -362,7 +363,7 @@ class RuleAlwaysInline : public StructureSynthesisRule { }; // The rule that simply skip the current stage -class RuleSkipStage : public StructureSynthesisRule { +class RuleSkipStage : public SketchGenerationRule { public: ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { @@ -387,7 +388,7 @@ class RuleSkipStage : public StructureSynthesisRule { }; 
// The rule that performs multi-level tiling -class RuleMultiLevelTiling : public StructureSynthesisRule { +class RuleMultiLevelTiling : public SketchGenerationRule { public: ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { @@ -413,7 +414,7 @@ class RuleMultiLevelTiling : public StructureSynthesisRule { }; // The rule that performs multi-level tiling and fuses later consumers -class RuleMultiLevelTilingWithFusion : public StructureSynthesisRule { +class RuleMultiLevelTilingWithFusion : public SketchGenerationRule { public: ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { @@ -482,7 +483,7 @@ class RuleMultiLevelTilingWithFusion : public StructureSynthesisRule { }; // The rule that adds a cache write stage -class RuleAddCacheWrite : public StructureSynthesisRule { +class RuleAddCacheWrite : public SketchGenerationRule { public: ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { @@ -515,7 +516,7 @@ class RuleAddCacheWrite : public StructureSynthesisRule { // The rule that adds a cache read stage // Mainly used for GPU cooperative fetching // Currently only support 1 to 1 match cache read -class RuleAddCacheRead : public StructureSynthesisRule { +class RuleAddCacheRead : public SketchGenerationRule { public: ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { @@ -546,7 +547,7 @@ class RuleAddCacheRead : public StructureSynthesisRule { }; // The rule that adds rfactor stage -class RuleAddRfactor : public StructureSynthesisRule { +class RuleAddRfactor : public SketchGenerationRule { public: ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { @@ -610,7 +611,7 @@ class RuleAddRfactor : public StructureSynthesisRule { } }; -void MetaTileRewritePolicyNode::SynthesizeMetaStructure( +void MetaTileRewritePolicyNode::GenerateMetaSketch( std::vector* out_states) { State init_state = cur_task_->compute_dag.GetInitState(); std::string cpu_multi_level_tiling_structure = @@ -634,18 +635,22 @@ void MetaTileRewritePolicyNode::SynthesizeMetaStructure( static RuleAddCacheWrite rule_add_cache_write_stage; static RuleAddCacheRead rule_add_cache_read_stage; static RuleAddRfactor rule_add_rfactor; - // We may apply and skip the rest when processing some rules, - // should take care of the rule vector order here - static std::vector all_rules { - &rule_always_inline, &rule_add_cache_write_stage, - &rule_multi_level_tiling_with_fusion, &rule_multi_level_tiling, - &rule_add_rfactor, &rule_skip_stage - }; - if (IS_GPU(cur_task_)) { - // Try cache read first before cache write - all_rules.insert(all_rules.begin() + 1, &rule_add_cache_read_stage); + if (sketch_rules.empty()) { + // We may apply and skip the rest when processing some rules, + // should take care of the rule vector order here + sketch_rules.push_back(&rule_always_inline); + sketch_rules.push_back(&rule_add_cache_write_stage); + sketch_rules.push_back(&rule_multi_level_tiling_with_fusion); + sketch_rules.push_back(&rule_multi_level_tiling); + sketch_rules.push_back(&rule_add_rfactor); + sketch_rules.push_back(&rule_skip_stage); + if (IS_GPU(cur_task_)) { + // Try cache read first before cache write + sketch_rules.insert(sketch_rules.begin() + 1, &rule_add_cache_read_stage); + } + // TODO(xian): Add a new rule to try combination of multi-level + // tiling + 
rfactor } - // TODO(xian): Add a new rule to try combination of multi-level tiling + rfactor // Derivation rule based synthesizer while (!pnow->empty()) { @@ -661,15 +666,15 @@ void MetaTileRewritePolicyNode::SynthesizeMetaStructure( } // Try all derivation rules - for (const auto& rule : all_rules) { + for (const auto& rule : sketch_rules) { auto rule_check = rule->MeetCondition(this, state, stage_id); - if (rule_check > StructureSynthesisRule::ConditionEnum::kPass) { + if (rule_check > SketchGenerationRule::ConditionEnum::kPass) { for (const auto& pair : rule->Apply(this, state, stage_id)) { cur_stage_id_map[pair.first] = pair.second; pnext->push_back(pair.first); } // Skip the reset rules - if (rule_check == StructureSynthesisRule::ConditionEnum::kApplyAndSkipRest) { + if (rule_check == SketchGenerationRule::ConditionEnum::kApplyAndSkipRest) { break; } } @@ -1444,6 +1449,60 @@ void MetaTileRewritePolicyNode::EvolutionarySearch( << std::fixed << std::setprecision(2) << duration << std::endl; } +class RuleCustomSketch : public SketchGenerationRule { + public: + RuleCustomSketch(PackedFunc meet_condition_func, PackedFunc apply_func) : + meet_condition_func_(meet_condition_func), apply_func_(apply_func) {} + + inline ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, + const State& state, int stage_id) final { + auto ret = meet_condition_func_( + tvm::runtime::GetRef(policy), state, stage_id); + if (ret.type_code() == 0) { + return ConditionEnum(static_cast(ret)); + } else { + return kApplyAndSkipRest; + } + } + + inline std::vector > Apply( + const MetaTileRewritePolicyNode* policy, + const State& state, int stage_id) final { + std::vector > ret; + + Array> apply_ret = apply_func_( + tvm::runtime::GetRef(policy), state, stage_id); + + for (const auto& item : apply_ret) { + CHECK_EQ(item.size(), 2); + State state = Downcast(item[0]); + auto next = item[1].as(); + ret.emplace_back(state, next->value); + } + return ret; + } + + private: + PackedFunc meet_condition_func_; + PackedFunc apply_func_; +}; + +SearchCallback PreAddCustomRuleNode::make(PackedFunc meet_condition_func, + PackedFunc apply_func) { + auto node = make_object(); + node->meet_condition_func = meet_condition_func; + node->apply_func = apply_func; + return SearchCallback(node); +} + +void PreAddCustomRuleNode::callback(SearchPolicyNode* policy) { + CHECK(policy->IsInstance()); + auto meta_policy = dynamic_cast(policy); + meta_policy->sketch_rules.emplace_back( + new RuleCustomSketch(meet_condition_func, apply_func)); + StdCout(policy->verbose_) << "Custom sketch rule added." << std::endl; +} + TVM_REGISTER_GLOBAL("ansor.MetaTileRewritePolicy") .set_body_typed([](CostModel program_cost_model, Map params, @@ -1451,5 +1510,10 @@ TVM_REGISTER_GLOBAL("ansor.MetaTileRewritePolicy") return MetaTileRewritePolicyNode::make(program_cost_model, params, seed); }); +TVM_REGISTER_GLOBAL("ansor.PreAddCustomRule") +.set_body_typed([](PackedFunc meet_condition_func, PackedFunc apply_func) { + return PreAddCustomRuleNode::make(meet_condition_func, apply_func); +}); + } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.h b/src/ansor/search_policy/meta_tile_rewrite_policy.h index 8cf61b4d1e11..befc002b6aa2 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.h +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.h @@ -38,6 +38,8 @@ namespace tvm { namespace ansor { +class SketchGenerationRule; + /*! 
Multi stage search policy */ class MetaTileRewritePolicyNode: public SearchPolicyNode { public: @@ -54,6 +56,7 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode { * str gpu_multi_level_tiling_structure // The structure of multi-level tiling for GPU */ Map params; + std::vector sketch_rules; static SearchPolicy make(CostModel program_cost_model, Map params, @@ -87,7 +90,7 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode { int num_random_states, std::vector* random_states); // Synthesize meta tiling structure without tile size - void SynthesizeMetaStructure(std::vector* out_states); + void GenerateMetaSketch(std::vector* out_states); // Sample init population void SampleInitPopulation(const std::vector& meta_structures, @@ -107,6 +110,22 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode { // The throughputs of already measured states std::vector measured_states_throughputs_; }; +TVM_DEFINE_MUTABLE_OBJECT_REF(MetaTileRewritePolicy, MetaTileRewritePolicyNode); + +class PreAddCustomRuleNode : public SearchCallbackNode { + public: + // TODO(jcf94): Use tvm::runtime::TypedPackedFunc? + PackedFunc meet_condition_func; + PackedFunc apply_func; + + static SearchCallback make(PackedFunc meet_condition_func, + PackedFunc apply_func); + + void callback(SearchPolicyNode* policy) final; + + static constexpr const char *_type_key = "ansor.PreAddCustomRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(PreAddCustomRuleNode, SearchCallbackNode); +}; } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc index b2ba27bfc6ba..d52b868e180d 100644 --- a/src/ansor/search_policy/search_policy.cc +++ b/src/ansor/search_policy/search_policy.cc @@ -32,7 +32,7 @@ namespace tvm { namespace ansor { TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode); -TVM_REGISTER_OBJECT_TYPE(PreLoadMeasuredStatesCallbackNode); +TVM_REGISTER_OBJECT_TYPE(PreLoadMeasuredStatesNode); void SearchPolicyNode::PreLoadMeasuredStates(const std::string& log_file) { LogReader reader = LogReaderNode::make(log_file); @@ -69,13 +69,13 @@ void SearchPolicyNode::RunCallbacks(const Array& callbacks) { } } -SearchCallback PreLoadMeasuredStatesCallbackNode::make(std::string filename) { - auto node = make_object(); +SearchCallback PreLoadMeasuredStatesNode::make(std::string filename) { + auto node = make_object(); node->filename = std::move(filename); return SearchCallback(node); } -void PreLoadMeasuredStatesCallbackNode::callback(SearchPolicyNode* policy) { +void PreLoadMeasuredStatesNode::callback(SearchPolicyNode* policy) { policy->PreLoadMeasuredStates(filename); } @@ -105,9 +105,9 @@ TVM_REGISTER_GLOBAL("ansor.SearchPolicySetVerbose") policy->verbose_ = verbose; }); -TVM_REGISTER_GLOBAL("ansor.PreLoadMeasuredStatesCallback") +TVM_REGISTER_GLOBAL("ansor.PreLoadMeasuredStates") .set_body_typed([](std::string filename) { - return PreLoadMeasuredStatesCallbackNode::make(filename); + return PreLoadMeasuredStatesNode::make(filename); }); } // namespace ansor diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h index 0d7ebe94c14f..2dfbd9429648 100644 --- a/src/ansor/search_policy/search_policy.h +++ b/src/ansor/search_policy/search_policy.h @@ -47,7 +47,7 @@ class SearchCallbackNode : public Object { }; TVM_DEFINE_MUTABLE_OBJECT_REF(SearchCallback, SearchCallbackNode); -class PreLoadMeasuredStatesCallbackNode : public SearchCallbackNode { +class PreLoadMeasuredStatesNode : public SearchCallbackNode { public: std::string 
filename; @@ -55,8 +55,8 @@ class PreLoadMeasuredStatesCallbackNode : public SearchCallbackNode { void callback(SearchPolicyNode* policy) final; - static constexpr const char *_type_key = "ansor.PreLoadMeasuredStatesCallback"; - TVM_DECLARE_FINAL_OBJECT_INFO(PreLoadMeasuredStatesCallbackNode, SearchCallbackNode); + static constexpr const char *_type_key = "ansor.PreLoadMeasuredStates"; + TVM_DECLARE_FINAL_OBJECT_INFO(PreLoadMeasuredStatesNode, SearchCallbackNode); }; /*! \brief The base class for search policy */ @@ -76,6 +76,10 @@ class SearchPolicyNode : public Object { SearchTask cur_task_; // The current task int verbose_; // Verbose level (0 means silent) + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("cur_task", &cur_task_); + } + // Dict keys static constexpr const char* always_unroll_inner_key = "ansor_always_unroll_inner"; static constexpr const char* always_unroll_key = "ansor_always_unroll"; diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index 6fe1012e6629..b86dfa95f9bd 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -27,7 +27,8 @@ from test_ansor_common import matmul_ansor_test def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local', - cost_model=ansor.RandomModel(), n_trials=2): + cost_model=ansor.RandomModel(), n_trials=2, params=None, + pre_search_callbacks=None): print("Test %s schedule search with the default search policy" % (target)) random.seed(seed) @@ -40,9 +41,11 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local' with tempfile.NamedTemporaryFile() as fp: log_file = fp.name - search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed) + search_policy = ansor.MetaTileRewritePolicy(cost_model, params=params, + seed=seed) tune_option = ansor.TuneOption(n_trials=n_trials, runner=runner, - measure_callbacks=[ansor.LogToFile(log_file)]) + measure_callbacks=[ansor.LogToFile(log_file)], + pre_search_callbacks=pre_search_callbacks) sch, args = ansor.auto_schedule(task, search_policy=search_policy, tune_option=tune_option) inp, res = ansor.best_measure_pair_in_file(log_file, workload_key, target) @@ -95,8 +98,67 @@ def test_search_cuda(): print("CUDA device not found, skip this test.") +def test_search_custom_sketch_rule(): + def meet_condition_func(meta_policy, state, stage_id): + # Apply and Skip the Rest if this function does not return + pass + + # Expecting: + # i.0 + # i.1 + # i.2 + # j.0 + # j.1 + # ax0 + # ax1 + # B.global + # j.2 + # k + # C + def apply_func1(meta_policy, state, stage_id): + # Stage by stage way + ret = [] + if stage_id == 2: + state = ansor.loop_state.State(state) + state.split(2, state.stages[2].iters[0], [4, 4]) + state.split(2, state.stages[2].iters[3], [4, 4]) + ret.append([state.state_object, stage_id - 1]) + elif stage_id == 1: + state = ansor.loop_state.State(state) + state.cache_read(1, "global", [2], meta_policy.cur_task.compute_dag) + state.compute_at(2, 3, state.stages[3].iters[4]) + ret.append([state.state_object, stage_id - 1]) + else: + ret.append([state, stage_id - 1]) + return ret + + def apply_func2(meta_policy, state, stage_id): + # More template like way + ret = [] + state = ansor.loop_state.State(state) + + state.split(2, state.stages[2].iters[0], [4, 4]) + state.split(2, state.stages[2].iters[3], [4, 4]) + state.cache_read(1, "global", [2], meta_policy.cur_task.compute_dag) + state.compute_at(2, 3, 
state.stages[3].iters[4]) + + ret.append([state.state_object, -1]) + return ret + + measure_ctx = ansor.LocalRPCMeasureContext() + search_common(seed=887823438, runner=measure_ctx.runner, + pre_search_callbacks=[ansor.PreAddCustomRule(meet_condition_func, + apply_func1)], + params={'disable_change_compute_location': 1}) + search_common(seed=887823438, runner=measure_ctx.runner, + pre_search_callbacks=[ansor.PreAddCustomRule(meet_condition_func, + apply_func2)], + params={'disable_change_compute_location': 1}) + + if __name__ == "__main__": test_search_basic() test_search_xgb_model_rpc_runner() test_search_opencl() test_search_cuda() + test_search_custom_sketch_rule() From a155c1f46fdbbe44d2189b790313ae16cc42ce52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Minmin=20Sun=20=28=E5=AD=99=E6=95=8F=E6=95=8F=29?= Date: Fri, 12 Jun 2020 16:25:26 +0800 Subject: [PATCH 22/45] Ansor Relay Integration (without layout rewrite) (#22) * relay integration --- python/tvm/ansor/__init__.py | 10 +- python/tvm/ansor/compute_dag.py | 11 + python/tvm/ansor/dispatcher.py | 518 ++++++++++++++++++++++++++ python/tvm/ansor/env.py | 8 + python/tvm/ansor/relay_integration.py | 209 +++++++++++ python/tvm/ansor/serialization.py | 3 + python/tvm/ansor/topi_integration.py | 215 +++++++++++ scripts/tune_network.py | 497 ++++++++++++++++++++++++ topi/python/topi/ansor.py | 95 +++++ topi/python/topi/arm_cpu/__init__.py | 5 + topi/python/topi/generic/__init__.py | 5 + topi/python/topi/x86/__init__.py | 5 + 12 files changed, 1579 insertions(+), 2 deletions(-) create mode 100644 python/tvm/ansor/dispatcher.py create mode 100644 python/tvm/ansor/env.py create mode 100644 python/tvm/ansor/relay_integration.py create mode 100644 python/tvm/ansor/topi_integration.py create mode 100644 scripts/tune_network.py create mode 100644 topi/python/topi/ansor.py diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index 845d1b5e477d..6ea8a0ce904f 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -28,14 +28,20 @@ from . 
import task_scheduler
 
 # Shortcut
-from .compute_dag import ComputeDAG
+from .compute_dag import ComputeDAG, LayoutRewriteLevel, gen_schedule
 from .auto_schedule import SearchTask, MetaTileRewritePolicy, TuneOption, HardwareParams, \
     PreLoadMeasuredStates, PreAddCustomRule
 from .auto_schedule import auto_schedule
 from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, LocalRPCMeasureContext
 from .cost_model import RandomModel
 from .cost_model.xgb_model import XGBModel
-from .serialization import LogToFile, LogReader, best_measure_pair_in_file, write_measure_records_to_file
+from .serialization import LogToFile, LogReader, best_measure_pair_in_file, \
+    load_from_file, write_measure_records_to_file
 from .workload_registry import register_auto_scheduler_workload_func, \
     workload_key_to_dag, make_workload_key_func
 from .task_scheduler import TaskScheduler, SimpleTaskScheduler
+from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest as apply_history_best, \
+    FallbackContext, clear_fallback_cache, ApplyGraphBest, BlockingEmptyContext
+from .topi_integration import register_topi_schedule, TaskExtractEnv
+from .relay_integration import extract_from_program, extract_from_multiple_program, \
+    finish_layout_rewrite
diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py
index 0b51ebb402cc..0c8aa2055482 100644
--- a/python/tvm/ansor/compute_dag.py
+++ b/python/tvm/ansor/compute_dag.py
@@ -19,6 +19,7 @@
 import tvm._ffi
 from tvm.runtime import Object
+from tvm import te
 from .loop_state import State
 from . import _ffi_api
@@ -88,3 +89,13 @@ def infer_bound_from_state(self, state):
         state : StateObject
         """
         return _ffi_api.ComputeDAGInferBoundFromState(self, state)
+
+def gen_schedule(state, bufs):
+    if not state or not state.complete:
+        return te.create_schedule([x.op for x in bufs])
+    else:
+        dag = ComputeDAG(bufs)
+        # Only update the compute body (layout_rewrite_level = LayoutRewriteLevel.COMPUTE_REWRITE),
+        # since the kernel layout has already been rewritten by the relay pass.
+        schedule, _ = dag.apply_steps_from_state(state, layout_rewrite_level=LayoutRewriteLevel.COMPUTE_REWRITE)
+        return schedule
diff --git a/python/tvm/ansor/dispatcher.py b/python/tvm/ansor/dispatcher.py
new file mode 100644
index 000000000000..2f00c355d285
--- /dev/null
+++ b/python/tvm/ansor/dispatcher.py
@@ -0,0 +1,518 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Template dispatcher module.
+
+A dispatcher is a function that can contain multiple behaviors.
+Its specific behavior can be controlled by DispatchContext.
+
+DispatchContext is used in two ways, usually via different implementations
+of the DispatchContext base class.
+
+- During search, we can use it to pass the current proposal from the tuner.
+- During evaluation, we can use it to pick the best policy.
+"""
+# pylint: disable=invalid-name
+
+from __future__ import absolute_import as _abs
+
+import logging
+
+import numpy as np
+from decorator import decorate
+
+from tvm import target as _target
+from tvm.tir.expr import StringImm, FloatImm
+
+from .loop_state import State, StateObject
+
+logger = logging.getLogger('auto_scheduler')
+
+
+class DispatchContext(object):
+    """
+    Base class of dispatch context.
+
+    DispatchContext enables the target and workload
+    specific dispatch mechanism for templates.
+    """
+    current = None
+
+    def __init__(self):
+        self._old_ctx = DispatchContext.current
+
+    def query(self, target, workload):
+        """
+        Query the context to get the specific config for a template.
+        If the result cannot be found inside this context, this function will query it
+        from the upper contexts.
+
+        Parameters
+        ----------
+        target: Target
+            The current target
+        workload : Workload
+            The current workload.
+
+        Returns
+        -------
+        cfg : State or str
+            The specific state for auto scheduler.
+        """
+        ret = self._query_inside(target, workload)
+        #if ret is None:
+        #    ret = self._old_ctx.query(target, workload)
+        return ret
+
+    def update(self, target, workload, cfg):
+        """
+        Update context with a specific config.
+
+        Parameters
+        ----------
+        target: Target
+            The current target
+        workload : Workload
+            The current workload.
+        cfg : State or str
+            The specific state for auto scheduler.
+
+        Note
+        ----
+        This interface is for cases when TVM decides to replace an operator in the graph.
+        For example, the `AlterOpLayout` pass (enabled when `opt_level = 3`) replaces `NCHW`
+        convolution with the `NCHW[x]c` implementation on x86 CPUs.
+        Thus in TOPI, we first query the schedule using the original `NCHW` workload,
+        then update the dispatcher with the new `NCHW[x]c` workload.
+        So that later on, `NCHW[x]c` convolution can get its schedule from the dispatcher using
+        its own workload directly.
+
+        .. code-block:: python
+
+            @conv2d_alter_layout.register("cpu")
+            def _alter_conv2d_layout(attrs, inputs, tinfo):
+                workload = get_conv2d_workload(...)
+                dispatch_ctx = auto_scheduler.DispatchContext.current
+                target = tvm.target.current_target()
+                config = dispatch_ctx.query(target, workload)
+
+                # Get conv2d_NCHWc workload from config
+                # new_workload = ...
+                # new_inputs = ...
+                # new_attrs = ...
+
+                # Store altered operator's config
+                dispatch_ctx.update(target, new_workload, config)
+                return sym.contrib.conv2d_NCHWc(*new_inputs, **new_attrs)
+
+        We directly store `config` back because `conv2d_NCHW` and `conv2d_NCHWc`
+        share the same schedule parameters.
+        One can construct a new `State` if this is not the case.
+        """
+        raise NotImplementedError()
+
+    def _query_inside(self, target, workload):
+        """
+        Query the context to get the specific config for a template.
+        This function only queries the config inside this context.
+
+        Parameters
+        ----------
+        target: Target
+            The current target
+        workload : Workload
+            The current workload.
+
+        Returns
+        -------
+        cfg : State or str
+            The specific state for auto scheduler.
+        """
+        raise NotImplementedError()
+
+    def __enter__(self):
+        self._old_ctx = DispatchContext.current
+        DispatchContext.current = self
+        return self
+
+    def __exit__(self, ptype, value, trace):
+        DispatchContext.current = self._old_ctx
+
+
+def dispatcher(fworkload):
+    """Wrap a workload dispatcher function.
+
+    Parameters
+    ----------
+    fworkload : function
+        The workload extraction function from arguments.
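The `query`/`update`/`_query_inside` contract above, combined with the `__enter__`/`__exit__` stacking through `_old_ctx`, is the whole dispatch mechanism. A minimal sketch of a concrete context, assuming a plain dict as the backing store (the class is illustrative and not part of this patch):

    class DictContext(DispatchContext):
        """Toy context that answers queries from a fixed dict."""
        def __init__(self, table):
            super(DictContext, self).__init__()
            self._table = table

        def _query_inside(self, target, workload):
            return self._table.get((str(target), workload))

        def update(self, target, workload, cfg):
            self._table[(str(target), workload)] = cfg

    # with DictContext({}) as ctx:      # pushes ctx onto the context stack
    #     cfg = DispatchContext.current.query(target, workload)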
+ + Returns + ------- + fdispatcher : function + A wrapped dispatcher function, which will + dispatch based on DispatchContext and + the current workload. + """ + dispatch_dict = {} + func_name = fworkload.__name__ + + def register(key, func=None, override=False): + """Register template function. + + Parameters + ---------- + key : str or List of str + The template key to identify the template + under this dispatcher. + func : function + The function to be registered. + The first argument of the function is always + cfg returned by DispatchContext, + the rest arguments are the same as the fworkload. + override : bool + Whether override existing registration. + + Returns + ------- + The register function if necessary. + """ + if isinstance(key, str): + key = [key] + + def _do_reg(myf): + for x in key: + if x in dispatch_dict and not override: + raise ValueError( + "Key %s is already registered for %s" % (x, func_name)) + dispatch_dict[x] = myf + return myf + + if func: + return _do_reg(func) + return _do_reg + + def dispatch_func(func, *args, **kwargs): + """The wrapped dispatch function""" + tgt = _target.current_target() + workload = func(*args, **kwargs) + cfg = DispatchContext.current.query(tgt, workload) + return dispatch_dict['direct'](cfg, *args, **kwargs) + + fdecorate = decorate(fworkload, dispatch_func) + fdecorate.register = register + return fdecorate + + +class ApplyConfig(DispatchContext): + """Apply a deterministic config entity for all queries. + + Parameters + ---------- + config : State + The specific state for auto scheduler. + """ + def __init__(self, config): + super(ApplyConfig, self).__init__() + self._config = config + self.workload = None + + def _query_inside(self, target, workload): + """Override query""" + self.workload = workload + return self._config + + def update(self, target, workload, cfg): + """Override update""" + self.workload = workload + self._config = cfg + + +class ApplyHistoryBest(DispatchContext): + """ + Apply the history best config + + Parameters + ---------- + records : str or iterator of (MeasureInput, MeasureResult) + Collection of tuning records. + If is str, then it should be the filename of a records log file. + Each row of this file is an encoded record pair. + Otherwise, it is an iterator. + n_lines: int (optional) + if it is not None, only load the first `n_lines` lines of log + """ + def __init__(self, records, n_lines=None): + super(ApplyHistoryBest, self).__init__() + + self.best_by_targetkey = {} + self.best_by_model = {} + self._best_user_defined = {} + + if records: + self.load(records, n_lines) + + def load(self, records, n_lines=None): + """Load records to this dispatch context + + Parameters + ---------- + records : str or iterator of (MeasureInput, MeasureResult) + Collection of tuning records. + If is str, then it should be the filename of a records log file. + Each row of this file is an encoded record pair. + Otherwise, it is an iterator. + n_lines: int (optional) + if it is not None, only load the first `n_lines` lines of log + """ + from pathlib import Path + from . 
import load_from_file
+
+        if isinstance(records, Path):
+            records = str(records)
+
+        if isinstance(records, str):
+            records = load_from_file(records)
+        if not records:
+            return
+
+        best_by_targetkey = self.best_by_targetkey
+        best_by_model = self.best_by_model
+
+        counter = 0
+        for inp, res in records:
+            if n_lines is not None and counter >= n_lines:
+                break
+            counter += 1
+            if res.error_no != 0:
+                continue
+
+            # use target keys in the tvm target system as the key to build the best map
+            for k in inp.task.target.keys:
+                key = (k, inp.task.workload_key)
+                if key not in best_by_targetkey:
+                    best_by_targetkey[key] = (inp, res)
+                else:
+                    _, other_res = best_by_targetkey[key]
+                    other_costs = [x.value for x in other_res.costs if isinstance(x, FloatImm)]
+                    costs = [x.value for x in res.costs if isinstance(x, FloatImm)]
+                    if np.mean(other_costs) > np.mean(costs):
+                        best_by_targetkey[key] = (inp, res)
+
+            # use the model as the key to build the best map
+            key = (inp.task.target.model, inp.task.workload_key)
+            if key not in best_by_model:
+                if inp.task.target.model != 'unknown':
+                    best_by_model[key] = (inp, res)
+            else:
+                _, other_res = best_by_model[key]
+                other_costs = [x.value for x in other_res.costs if isinstance(x, FloatImm)]
+                costs = [x.value for x in res.costs if isinstance(x, FloatImm)]
+                if np.mean(other_costs) > np.mean(costs):
+                    best_by_model[key] = (inp, res)
+
+        logger.debug("Finished loading %d records", counter)
+
+    def _query_inside(self, target, workload):
+        if target is None:
+            raise RuntimeError("Need a target context to find the history best. "
+                               "Hint: If your target is llvm, use `with tvm.target.create('llvm'):`"
+                               " above the dispatcher call. The same applies to other targets.")
+
+        # first try matching by model
+        key = (target.model, workload)
+        if key in self._best_user_defined:
+            return self._best_user_defined[key]
+        if key in self.best_by_model:
+            return self.best_by_model[key][0].state
+
+        # then try matching by target key
+        for k in target.keys:
+            key = (k, workload)
+            if key in self._best_user_defined:
+                return self._best_user_defined[key]
+            if key in self.best_by_targetkey:
+                return self.best_by_targetkey[key][0].state
+
+        return None
+
+    def update(self, target, workload, state):
+        model = target.model
+        key = (model, workload)
+        self._best_user_defined[key] = state
+
+        for k in target.keys:
+            key = (k, workload)
+            self._best_user_defined[key] = state
+
+
+class BlockingEmptyContext(DispatchContext):
+    """
+    An empty context which returns an empty State() for all queries.
+    This also blocks the queries, so the queries won't affect the global FallbackContext.
+    """
+    def __init__(self):
+        super(BlockingEmptyContext, self).__init__()
+
+    def query(self, target, workload):
+        #return StateObject()
+        return None
+
+
+class FallbackContext(DispatchContext):
+    """
+    A fallback dispatch context.
+
+    Any tunable template can be called under this context.
+    This is the root context.
+    """
+
+    def __init__(self):
+        super(FallbackContext, self).__init__()
+        self.memory = {}
+        self.silent = False
+
+        # a set to prevent printing duplicated messages
+        self.messages = set()
+
+    def _query_inside(self, target, workload):
+        key = (str(target), workload)
+        if key in self.memory:
+            return self.memory[key]
+
+        if not self.silent:
+            msg = "Cannot find config for target=%s, workload=%s. A fallback configuration "\
+                  "is used, which may bring great performance regression."
% (target, workload) + if msg not in self.messages: + self.messages.add(msg) + logger.warning(msg) + #cfg = StateObject() + cfg = None + + # cache this config + self.memory[key] = cfg + return cfg + + def clear_cache(self, target, workload): + """Clear fallback cache. Pass the same argument as _query_inside to this function + to clean the cache. + + Parameters + ---------- + target: Target + The current target + workload : Workload + The current workload. + """ + key = (str(target), workload) + if key in self.memory: + del self.memory[key] + + def update(self, target, workload, cfg): + key = (str(target), workload) + self.memory[key] = cfg + + +DispatchContext.current = FallbackContext() + + +def clear_fallback_cache(target, workload): + """Clear fallback cache. Pass the same argument as _query_inside to this function + to clean the cache. + + Parameters + ---------- + target: Target + The current target + workload : Workload + The current workload. + + Note + ---- + This is used in alter_op_layout to clear the bad cache created before call topi compute function + """ + context = DispatchContext.current + while not isinstance(context, FallbackContext): + context = context._old_ctx + context.clear_cache(target, workload) + + +class ApplyGraphBest(DispatchContext): + """Load the graph level tuning optimal schedules. + + The input records should be in the ascending order of + node index for target operator. Usually this can be obtained + with graph tuner. + + This context maintains an internal counter to indicate the current + node index. + """ + def __init__(self, records): + """ + Parameters + ---------- + records : str or iterator of (MeasureInput, MeasureResult) + Collection of tuning records. + If is str, then it should be the filename of a records log file. + Each row of this file is an encoded record pair. + Otherwise, it is an iterator. + """ + from . import load_from_file + + super(ApplyGraphBest, self).__init__() + if isinstance(records, str): + records = load_from_file(records) + self._records = list(records) + self._counter = 0 + self._global_cfg_dict = {} + + def _query_inside(self, target, workload): + """ + Query the context to get config from records. + + Parameters + ---------- + target : Target + The current target + workload : Workload + The current workload. + + Returns + ------- + cfg : State or str + The specific state for auto scheduler. + """ + if self._counter < len(self._records): + cfg = self._records[self._counter][0].config + self._counter += 1 + self.update(target, workload, cfg) + return cfg + key = (str(target), workload) + if key not in self._global_cfg_dict: + msg = "Config for target=%s, workload=%s is missing in ApplyGraphBest context. " \ + "A fallback configuration is used, which may bring great performance " \ + "regression." 
% (target, workload)
+            logger.warning(msg)
+            cfg = None
+            self._global_cfg_dict[key] = cfg
+        else:
+            cfg = self._global_cfg_dict[key]
+        return cfg
+
+    def update(self, target, workload, cfg):
+        key = (str(target), workload)
+        self._global_cfg_dict[key] = cfg
diff --git a/python/tvm/ansor/env.py b/python/tvm/ansor/env.py
new file mode 100644
index 000000000000..6d2bbd2c92af
--- /dev/null
+++ b/python/tvm/ansor/env.py
@@ -0,0 +1,8 @@
+""" The scope to store global variables in auto_scheduler """
+
+class AutoschedulerGlobalScope(object):
+    def __init__(self):
+        self.topi_in_compute_rewrite_mode = False
+
+GLOBAL_SCOPE = AutoschedulerGlobalScope()
+
diff --git a/python/tvm/ansor/relay_integration.py b/python/tvm/ansor/relay_integration.py
new file mode 100644
index 000000000000..7d7e18a94ddf
--- /dev/null
+++ b/python/tvm/ansor/relay_integration.py
@@ -0,0 +1,209 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-variable,invalid-name
+"""
+Decorator and utilities for the integration with TOPI and Relay
+99.9% copy-paste of implementation by @MerryMercy
+
+"""
+import threading
+import warnings
+import tvm
+
+
+from .topi_integration import TaskExtractEnv
+from .dispatcher import BlockingEmptyContext
+from .env import GLOBAL_SCOPE
+
+def _lower(mod,
+           target,
+           params):
+    """Helper to lower a relay module, handling VTA properly."""
+    # pylint: disable=import-outside-toplevel
+    from tvm import relay
+    from tvm.relay.backend import graph_runtime_codegen
+
+    if hasattr(target, 'device_name') and target.device_name == "vta":
+        import vta
+        with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
+            mod, _ = relay.optimize(mod, target, params)
+            grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
+            grc.codegen(mod["main"])
+            return
+
+    # default case
+    # Try graph codegen first to extract tuning tasks.
+    # If it fails to compile, fall back to the VM compiler.
+    # TODO: Currently VM compiler is likely to stack overflow for large models.
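The `GLOBAL_SCOPE.topi_in_compute_rewrite_mode` flag added in env.py above is a plain process-wide boolean. A sketch of its intended lifecycle, mirroring `prepare_layout_rewrite` and `finish_layout_rewrite` defined later in this file:

    from tvm.ansor.env import GLOBAL_SCOPE

    # prepare_layout_rewrite() turns the flag on once at least one kernel
    # layout has been rewritten, so TOPI code can branch on it during build.
    GLOBAL_SCOPE.topi_in_compute_rewrite_mode = True
    try:
        pass  # build the relay module here
    finally:
        # finish_layout_rewrite() simply clears the flag again
        GLOBAL_SCOPE.topi_in_compute_rewrite_mode = False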
+ try: + with relay.build_config(opt_level=3): + opt_mod, _ = relay.optimize(mod, target, params) + grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target) + grc.codegen(opt_mod["main"]) + except tvm.TVMError: + compiler = relay.vm.VMCompiler() + if params: + compiler.set_params(params) + compiler.lower(mod, target=target) + +OP_TO_SCHEDULE = {} + +def init_op_to_schedule_map(): + # init the global map OP_TO_SCHEDULE inside a function, this is used to resolve import issues + global OP_TO_SCHEDULE + from tvm import relay + import topi + + if OP_TO_SCHEDULE: + return + + OP_TO_SCHEDULE = { + relay.op.nn.conv2d: [topi.generic.schedule_conv2d_nchw, + topi.generic.schedule_conv2d_nhwc, + topi.generic.schedule_depthwise_conv2d_nchw, + topi.generic.schedule_depthwise_conv2d_nhwc, + topi.generic.schedule_group_conv2d_nchw, + topi.generic.schedule_conv2d_winograd_without_weight_transform], + relay.op.nn.conv2d_transpose: [topi.generic.schedule_conv2d_transpose_nchw], + relay.op.nn.dense: [topi.generic.schedule_dense], + relay.op.nn.softmax: [topi.generic.schedule_softmax], + relay.op.nn.max_pool2d: [topi.generic.schedule_pool], + relay.op.nn.avg_pool2d: [topi.generic.schedule_pool], + relay.op.nn.global_avg_pool2d: [topi.generic.schedule_adaptive_pool], + relay.op.nn.global_max_pool2d: [topi.generic.schedule_adaptive_pool], + relay.op.nn.deformable_conv2d: [topi.generic.schedule_deformable_conv2d_nchw], + relay.op.mean: [topi.generic.schedule_reduce], + relay.op.prod: [topi.generic.schedule_reduce], + relay.op.nn.conv3d: [topi.generic.schedule_conv3d_ncdhw, + topi.generic.schedule_conv3d_ndhwc], + relay.op.nn.adaptive_avg_pool3d: [topi.generic.schedule_adaptive_pool], + relay.op.nn.batch_matmul: [topi.generic.schedule_batch_matmul], + } + +def extract_from_program(mod, params, ops, target, target_host=None): + """ Extract tuning tasks from a relay program. + + This function is the single program version of extract_from_multiple_program. + + Parameters + ---------- + mod : relay.Module + The module to extract. + params: dict of str to numpy array + The associated parameters of the program + ops: List of relay op + List of relay ops to be tuned + target: tvm.target.Target + The compilation target + target_host: tvm.target.Target + The host compilation target + + Returns + ------- + workloads: Array of Tuple(wkl_key, target) + """ + return extract_from_multiple_program([mod], [params], ops, target, target_host) + +def extract_from_multiple_program(mods, params, ops, target, target_host=None): + """ Extract tuning tasks from multiple relay programs. + + This function collects tuning tasks by building a list of programs + with a "tracing" target and tracing all the calls to topi. + + Parameters + ---------- + mods : List of relay.Module + The modules to extract. + params: List of dict of str to numpy array + The associated parameters of the programs + ops: List of relay op + List of relay ops to be tuned + target: tvm.target.Target + The compilation target + target_host: tvm.target.Target + The host compilation target + + Returns + ------- + workloads: Array of Tuple(wkl_key, target) + """ + from tvm import relay + + env = TaskExtractEnv.get() + + init_op_to_schedule_map() + topi_scheds = [] + for op_name in ops: + if op_name in OP_TO_SCHEDULE: + topi_scheds.extend(OP_TO_SCHEDULE[op_name]) + else: + warnings.warn("Op %s is not tunable, ignored." 
% op_name)
+
+    # run the compiler to collect all TOPI calls during compilation
+    env.reset(topi_scheds)
+    with env:
+        for mod, param in zip(mods, params):
+            # wrap the build call in a thread to avoid multiprocessing problems
+            with BlockingEmptyContext():
+                build_thread = threading.Thread(target=_lower,
+                                                args=(mod, target, param))
+                build_thread.start()
+                build_thread.join()
+            relay.backend.compile_engine.get().clear()
+
+    # create tasks for target
+    wkl_keys = []
+    wkl_weights = []
+    for wkl_key, wkl_weight in env.get_wkl_keys().items():
+        wkl_keys.append(wkl_key)
+        wkl_weights.append(wkl_weight)
+
+    return wkl_keys, wkl_weights
+
+def prepare_layout_rewrite(mod, params, ops, target):
+    """Prepare for kernel layout rewrite. This function writes layout info to a global
+    static variable; the info is then used by the relay pass `kernel_layout_transform`.
+    """
+    from .. import relay
+
+    env = TaskExtractEnv.get(do_layout_rewrite=True)
+
+    init_op_to_schedule_map()
+    topi_scheds = []
+    for op_name in ops:
+        if op_name in OP_TO_SCHEDULE:
+            topi_scheds.extend(OP_TO_SCHEDULE[op_name])
+        else:
+            warnings.warn("Op %s is not tunable, ignored." % op_name)
+
+    with env:
+        env.reset(topi_scheds)
+
+        # wrap the build call in a thread to avoid multiprocessing problems
+        build_thread = threading.Thread(target=_lower,
+                                        args=(mod, target, params))
+        build_thread.start()
+        build_thread.join()
+        relay.backend.compile_engine.get().clear()
+
+    if env.layout_rewrite_success_ct > 0:
+        GLOBAL_SCOPE.topi_in_compute_rewrite_mode = True
+
+def finish_layout_rewrite():
+    """Clear the global flag for layout rewrite"""
+    GLOBAL_SCOPE.topi_in_compute_rewrite_mode = False
diff --git a/python/tvm/ansor/serialization.py b/python/tvm/ansor/serialization.py
index 387825034a09..3d7ed7733a78 100644
--- a/python/tvm/ansor/serialization.py
+++ b/python/tvm/ansor/serialization.py
@@ -63,6 +63,9 @@ def __iter__(self):
                 break
             yield ret[0], ret[1]  # (input, result)
 
+def load_from_file(filename: str):
+    return zip(*LogReader(filename).read_lines())
+
 def write_measure_records_to_file(filename, inputs, results):
     """Write(append) measure records to file"""
diff --git a/python/tvm/ansor/topi_integration.py b/python/tvm/ansor/topi_integration.py
new file mode 100644
index 000000000000..b4c15f74ea44
--- /dev/null
+++ b/python/tvm/ansor/topi_integration.py
@@ -0,0 +1,215 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-variable,invalid-name,unused-argument
+"""
+Decorators for registering tunable templates to TOPI.
+
+These decorators let a simple implementation use different configurations
+for different workloads.
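Putting `extract_from_program` and the serialization helper above together, a hedged end-to-end extraction sketch (the network and op list are illustrative; this mirrors what scripts/tune_network.py below does):

    import tvm
    from tvm import ansor, relay
    from tvm.relay import testing

    mod, params = testing.mlp.get_workload(batch_size=1)
    target = tvm.target.create('llvm')

    wkl_keys, wkl_weights = ansor.extract_from_program(
        mod, params, ops=(relay.op.nn.dense,), target=target)

    # Each workload key can be turned back into a ComputeDAG and a SearchTask.
    tasks = [ansor.SearchTask(ansor.workload_key_to_dag(key), key, target, None)
             for key in wkl_keys]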
+Here we directly use all arguments to the TOPI call as the "workload", so make sure all the arguments
+(except tvm.te.Tensor) in your calls are hashable. For tvm.te.Tensor,
+we will serialize it to a hashable tuple.
+
+See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
+"""
+import json
+
+import tvm.te._ffi_api
+from tvm import target as _target
+from tvm.te import tensor
+from tvm.te.tensor import PlaceholderOp, ComputeOp
+
+from .compute_dag import ComputeDAG
+from .dispatcher import DispatchContext, BlockingEmptyContext
+from .workload_registry import register_auto_scheduler_workload_bufs, \
+    make_workload_key_bufs, compute_dag_hash
+
+def traverse_to_get_io_tensors(outs):
+    layout_free_ops = []
+    inputs = []
+
+    visited = set()
+
+    def traverse(t):
+        if t in visited:
+            return
+        if isinstance(t.op, PlaceholderOp):
+            inputs.append(t)
+        elif isinstance(t.op, ComputeOp):
+            if "layout_free_placeholders" in t.op.attrs:
+                layout_free_ops.append(t.op)
+            for x in t.op.input_tensors:
+                traverse(x)
+        visited.add(t)
+
+    for t in outs:
+        traverse(t)
+
+    has_layout_free = (len(layout_free_ops) > 0)
+    return inputs + [t for t in outs], has_layout_free
+
+# Task extractor for relay program
+class TaskExtractEnv:
+    """Global environment for extracting tuning tasks from a relay program"""
+    current = None
+    registered = None
+
+    def __init__(self, do_layout_rewrite=False):
+        self.do_layout_rewrite = do_layout_rewrite
+        self.wanted_relay_ops = None
+        self.modified_funcs = []
+        self.tracing = False
+        self.relay_disable_build_cache_ = "false"
+        self.layout_rewrite_success_ct = 0
+        self.wkl_key_collection = {}
+
+    def __enter__(self):
+        self.tracing = True
+        self.wkl_key_collection = {}
+
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.tracing = False
+
+    def reset(self, wanted_relay_ops=None):
+        """Reset task collections
+
+        Parameters
+        ----------
+        wanted_relay_ops: List of tvm.ir.Op
+            The relay ops to be extracted
+        """
+        self.wanted_relay_ops = wanted_relay_ops
+        self.relay_disable_build_cache_ = "false"
+        self.layout_rewrite_success_ct = 0
+        self.wkl_key_collection = {}
+
+    def add_task(self, key):
+        """Add an auto-scheduler workload key
+
+        Parameters
+        ----------
+        key: str
+            The workload key of the traced task.
+        """
+        if key in self.wkl_key_collection:
+            self.wkl_key_collection[key] += 1
+        else:
+            self.wkl_key_collection[key] = 1
+
+    def get_tasks(self):
+        """Get collected tasks
+
+        Returns
+        -------
+        tasks: Dict of workload_key to int
+            The tasks extracted from the graph and their occurrence counts
+        """
+        return self.wkl_key_collection
+
+    def get_wkl_keys(self):
+        """Get collected workload keys
+
+        Returns
+        -------
+        wkl_keys: Dict of auto-scheduler workload_key to int
+        """
+        return self.wkl_key_collection
+
+    @staticmethod
+    def get(do_layout_rewrite=False):
+        """Get the single instance of TaskExtractEnv
+
+        Parameters
+        ----------
+        do_layout_rewrite: bool
+            Whether kernel layout rewrite is enabled for this extraction
+
+        Returns
+        -------
+        env: TaskExtractEnv
+            The single instance of TaskExtractEnv
+        """
+        if not TaskExtractEnv.current:
+            TaskExtractEnv.current = TaskExtractEnv(do_layout_rewrite)
+        else:
+            TaskExtractEnv.current.do_layout_rewrite = do_layout_rewrite
+        return TaskExtractEnv.current
+
+def register_topi_schedule(func=None):
+    """Register a tunable template for a topi schedule function.
+
+    The registration will wrap this topi schedule to take `cfg` as the first argument,
+    followed by the original argument list.
+
+    Note that this function will try to find the "workload" from all the ComputeOp in the input.
+    You can attach a "workload" to your compute op by using :any:`register_topi_compute`.
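For instance, a schedule registration built on the decorator being defined here could plausibly look like the following sketch (the function name is illustrative; `gen_schedule` is the fallback-aware helper added to compute_dag.py in this patch):

    from tvm.ansor import gen_schedule
    from tvm.ansor.topi_integration import register_topi_schedule

    @register_topi_schedule
    def schedule_dense(cfg, outs):
        # cfg is whatever DispatchContext.current.query() returned for the
        # traced workload key (a State, or None when nothing matched);
        # gen_schedule falls back to te.create_schedule when cfg is unusable.
        return gen_schedule(cfg, outs)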
+ + The task name has to be the same as that of the corresponding topi compute function. + + Parameters + ---------- + task_name: str + The AutoTVM task name + + func: None or callable + If it is None, return a decorator. + If is callable, decorate this function. + + Returns + ------- + decorator: callable + A decorator + + Examples + -------- + See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage. + """ + def _decorate(topi_schedule): + def wrapper(outs, *args, **kwargs): + io_tensors, has_layout_free = traverse_to_get_io_tensors(outs) + key = register_auto_scheduler_workload_bufs(io_tensors) + task_env = TaskExtractEnv.current + if task_env is not None and task_env.tracing: + if task_env.do_layout_rewrite and has_layout_free: + # Rewrite the dag and update the transform history for + # the new dag in DispatchContext + dispatch_ctx = DispatchContext.current + tgt = _target.current_target() + state = dispatch_ctx.query(tgt, key) + dag = ComputeDAG(outs) + new_dag = dag.rewrite_layout_from_state(state) + new_key = json.dumps((compute_dag_hash(new_dag),)) + dispatch_ctx.update(tgt, new_key, state) + + if new_key != key: + task_env.layout_rewrite_success_ct += 1 + + # Call schedule_func under FallbackContext() to avoid layout rewrite + tgt = _target.Target.current() + cfg = BlockingEmptyContext().query(tgt, key) + return topi_schedule(cfg, outs) + + task_env.add_task(key) + + """wrapper function for topi schedule""" + tgt = _target.Target.current() + cfg = DispatchContext.current.query(tgt, key) + return topi_schedule(cfg, outs) + return wrapper + if func: + return _decorate(func) + return _decorate diff --git a/scripts/tune_network.py b/scripts/tune_network.py new file mode 100644 index 000000000000..3d858ce60ab0 --- /dev/null +++ b/scripts/tune_network.py @@ -0,0 +1,497 @@ +"""Tune all workloads in a network""" +import argparse +import logging +import random +import os +import time +import numpy as np + +import tvm +from tvm.rpc.tracker import Tracker +from tvm.rpc.server import Server +from tvm import ansor as auto_scheduler +from tvm import relay +from tvm.rpc.tracker import Tracker +from tvm.rpc.server import Server +from tvm.relay import testing +#from tvm._ffi.function import get_global_func +import tvm.contrib.graph_runtime as runtime +from tvm.contrib.debugger import debug_runtime +from tvm.contrib import util, ndk +from common import str2bool +from tvm.ansor import LocalRunner, LogToFile, TuneOption, SimpleTaskScheduler, \ + RPCRunner, LocalBuilder +from tvm.ansor.utils import request_remote +#from baseline.utils import log_line, BenchmarkRecord + +dtype = "float32" + +def get_network(name, model_path, batch_size, layout): + """Get the symbol definition and random weight of a network""" + input_shape = (batch_size, 3, 224, 224) + output_shape = (batch_size, 1000) + input_name = 'data' + + if name.startswith("resnet3d"): + n_layer = int(name.split('-')[1]) + layout = "NDHWC" + image_shape = (16, 112, 112, 3) + input_shape = (batch_size, *image_shape) + mod, params = relay.testing.resnet3d.get_workload(num_layers=n_layer, batch_size=batch_size, image_shape=image_shape, dtype=dtype, layout=layout) + elif name.startswith("resnet"): + n_layer = int(name.split('-')[1]) + image_shape = (224, 224, 3) if layout == 'NHWC' else (3, 224, 224) + input_shape = (batch_size, *image_shape) + mod, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, layout=layout, image_shape=image_shape, dtype=dtype) + print(mod) + elif "lstm" in name: + mod, params 
= relay.testing.lstm.get_workload(iterations=10, num_hidden=512, batch_size=batch_size, dtype=dtype) + elif "mlp" in name: + input_shape = (batch_size, 1, 28, 28) + mod, params = relay.testing.mlp.get_workload(batch_size=batch_size, dtype=dtype) + elif "vgg" in name: + n_layer = int(name.split('-')[1]) + mod, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype) + elif name == 'dcgan': + input_shape = (batch_size, 100) + mod, params = relay.testing.dcgan.get_workload(batch_size=batch_size, layout=layout) + elif name == 'dqn': + image_shape = (84, 84, 4) if layout == 'NHWC' else (4, 84, 84) + input_shape = (batch_size, *image_shape) + mod, params = relay.testing.dqn.get_workload(batch_size=batch_size, image_shape=image_shape, dtype=dtype) + elif name == 'mobilenet': + image_shape = (224, 224, 3) if layout == 'NHWC' else (3, 224, 224) + input_shape = (batch_size, *image_shape) + mod, params = relay.testing.mobilenet.get_workload(batch_size=batch_size, layout=layout, image_shape=image_shape, dtype=dtype) + elif name == 'r3d_18': + import torch + import torchvision + + model = getattr(torchvision.models.video, name)(pretrained=False) + model = model.eval() + + # We grab the TorchScripted model via tracing + input_shape = [batch_size, 3, 16, 112, 112] + input_data = torch.randn(input_shape) + scripted_model = torch.jit.trace(model, input_data).eval() + + input_name = 'input0' # only one input, set it to this name + shape_list = {input_name: input_shape} + mod, params = relay.frontend.from_pytorch(scripted_model, + shape_list) + elif name == 'squeezenet_v1.1': + mod, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype) + elif name == 'inception_v3': + input_shape = (batch_size, 3, 299, 299) + mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype) + elif name == 'mxnet': + # an example for mxnet model + from mxnet.gluon.model_zoo.vision import get_model + block = get_model('resnet18_v1', pretrained=True) + mod, params = relay.frontend.from_mxnet(block, shape={"input_name": input_shape}, dtype=dtype) + net = mod["main"] + net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs) + mod = relay.Module.from_expr(net) + elif name == 'tflite-mobilenet-v2' or name == 'tflite-resnet-v2-50': + try: + import tflite.Model + except ImportError: + raise ImportError("The tflite package must be installed") + input_name = "input" + input_shape = (1, 224, 224, 3) + output_shape = (1, 1001) + input_dtype = "float32" + tflite_model_buf = open(model_path, "rb").read() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) + mod, params = relay.frontend.from_tflite(tflite_model, + shape_dict={input_name: input_shape}, + dtype_dict={input_name: input_dtype}) + elif name == 'pytorch-mobilenet-v2': + import torch + + model = torch.hub.load('pytorch/vision:v0.5.0', 'mobilenet_v2', pretrained=False) + model.eval() + + input_shape = [batch_size, 3, 224, 224] + input_data = torch.randn(input_shape) + scripted_model = torch.jit.trace(model, input_data).eval() + + input_name = 'input0' + shape_list = {input_name: input_shape} + mod, params = relay.frontend.from_pytorch(scripted_model, + shape_list) + elif name == 'bert': + import tensorflow as tf + + bert_pb = './baseline/tensorflow/tf_models/bert/bert-B%d.pb' % batch_size + try: + with tf.compat.v1.gfile.GFile(bert_pb, 'rb') as f: + graph_def = tf.compat.v1.GraphDef() + 
graph_def.ParseFromString(f.read()) + except: + raise ValueError("Need to run ./baseline/tensorflow/bert/generate_bert_pb.py to get model first") + + input_shape = (batch_size, 128) + input_name = ['input'] + shape_dict = { + 'input': input_shape + } + out_names = [ + 'bert/pooler/dense/Tanh' + ] + + mod, params = relay.frontend.from_tensorflow(graph_def, + shape=shape_dict, + outputs=out_names) + elif name == 'tflite-textcnn': + try: + import tflite.Model + except ImportError: + raise ImportError("The tflite package must be installed") + model_path = './baseline/tensorflow/fake_textcnn.tflite' + input_name = "Placeholder" + input_shape = (batch_size, 200, 128, 1) + output_shape = (1, 1001) + input_dtype = "float32" + tflite_model_buf = open(model_path, "rb").read() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) + mod, params = relay.frontend.from_tflite(tflite_model, + shape_dict={input_name: input_shape}, + dtype_dict={input_name: input_dtype}) + print(mod['main']) + elif name == 'textcnn': + import tensorflow as tf + + bert_pb = './baseline/tensorflow/fake_textcnn.pb' + try: + with tf.compat.v1.gfile.GFile(bert_pb, 'rb') as f: + graph_def = tf.compat.v1.GraphDef() + graph_def.ParseFromString(f.read()) + except: + raise ValueError("Need to run ./baseline/tensorflow/bert/generate_bert_pb.py to get model first") + + input_shape = (batch_size, 200, 128, 1) + input_name = ['Placeholder'] + shape_dict = { + 'Placeholder': input_shape + } + out_names = [ + 'concat/concat_dim' + ] + + mod, params = relay.frontend.from_tensorflow(graph_def, + shape=shape_dict, + outputs=out_names) + print(mod['main']) + elif name == 'tdnn': + import tensorflow as tf + + pb = './baseline/tensorflow/pruned_model_0407.pb' + #pb = './baseline/tensorflow/tdnn_4001.pb' + try: + with tf.compat.v1.gfile.GFile(pb, 'rb') as f: + graph_def = tf.compat.v1.GraphDef() + graph_def.ParseFromString(f.read()) + except: + raise ValueError("Need to run ./baseline/tensorflow/bert_convert.py to get model first") + + input_shape = (batch_size, 600, 64) + input_name = ['tf_loss_fn/Placeholder'] + + shape_dict = { + 'tf_loss_fn/Placeholder': input_shape, + } + out_names = [ + #"tf_loss_fn/ForwardPass/w2l_encoder/conv91/Conv2D" + "tf_loss_fn/ForwardPass/Softmax" + ] + mod, params = relay.frontend.from_tensorflow(graph_def, + shape=shape_dict, + outputs=out_names) + else: + raise ValueError("Unsupported network: " + name) + + return mod, params, input_name, input_shape, output_shape + + +def create_module(data_shape, graph, lib, target, input_name, params, debug_profile, + local_measure, ndk_cc, device_key, host, port, run_timeout, num_threads, seed=43): + # Upload parameters to device + if local_measure: + if target.target_name == "cuda": + ctx = tvm.gpu() + else: + ctx = tvm.cpu() + if num_threads: + config_threadpool = get_global_func('runtime.config_threadpool') + config_threadpool(0, num_threads) + else: + print("=============== Request Remote ===============") + if 'TVM_NDK_CC' not in os.environ: + os.environ['TVM_NDK_CC'] = ndk_cc + remote = request_remote(device_key, host, port, timeout=run_timeout) + + print("=============== Export ===============") + ctx = remote.cpu() + temp = util.tempdir() + path_lib = temp.relpath("deploy_lib.so") + lib.export_library(path_lib, ndk.create_shared) + + print("=============== Upload ===============") + remote.upload(path_lib) + + print("=============== Load ===============") + lib = remote.load_module("deploy_lib.so") + if num_threads: + config_threadpool = 
remote.get_function('runtime.config_threadpool') + config_threadpool(0, num_threads) + + np.random.seed(seed) + data_tvm = tvm.nd.array(100 * (np.random.uniform(size=data_shape)).astype(dtype), ctx=ctx) + if debug_profile: + module = debug_runtime.create(graph, lib, ctx) + else: + module = runtime.create(graph, lib, ctx) + if type(input_name) == list: + for name in input_name: + module.set_input(name, data_tvm) + else: + module.set_input(input_name, data_tvm) + for k, v in params.items(): + module.set_input(k, v) + + return module, ctx + + +def tune_and_evaluate(network_name, model_path, batch_size, target, target_host, + local_measure, device_key, host, port, n_parallel, ndk_cc, + build_timeout, run_timeout, num_threads, tune, check_correctness, + debug_profile, tuning_parameters, record_file, layout_set): + joint_tuner, model_type, policy, log_file, load_log_file = (tuning_parameters['joint_tuner'], + tuning_parameters['model_type'], tuning_parameters['policy'], + tuning_parameters['log_file'], tuning_parameters['load_log_file']) + + if layout_set: + layout = layout_set + elif target.target_name == 'cuda': + layout = 'NCHW' + else: + layout = "NHWC" + + # Extract workloads from relay program + print("=============== Extract workloads ===============") + mod, params, input_name, data_shape, out_shape = get_network(network_name, model_path, batch_size, layout) + + if tune: + workloads, wkl_weights = auto_scheduler.extract_from_program(mod, target=target, + params=params, ops=(relay.op.nn.dense, relay.op.nn.softmax, + relay.op.nn.conv2d, relay.op.nn.conv2d_transpose, + relay.op.nn.max_pool2d, relay.op.nn.avg_pool2d, + relay.op.nn.global_max_pool2d, relay.op.nn.global_avg_pool2d, + relay.op.nn.conv3d, relay.op.nn.adaptive_avg_pool3d, + relay.op.nn.batch_matmul, relay.op.mean, + )) + print("Total workload number: %d" % (len(workloads))) + #workloads = workloads[1:2] + #wkl_weights = wkl_weights[1:2] + #workloads = ['["2543426b0070d4a379a1f75a362a5f1b"]'] + + + # Tune workloads with auto scheduler + print("=============== Tuning ===============") + tasks = [] + for i, wkl_key in enumerate(workloads): + dag = auto_scheduler.workload_key_to_dag(wkl_key) + print("[========= Task %d =========]\n" % i, dag) + tasks.append(auto_scheduler.SearchTask(dag, wkl_key, target, target_host)) + + if joint_tuner != 'rl': + tuner = SimpleTaskScheduler(tasks, load_log_file=load_log_file) + elif joint_tuner == 'rl': + # put import here to remove pytorch dependency + from tvm.auto_scheduler.joint_tuner.rl_joint_tuner import RLJointTuner + tuner = RLJointTuner(tasks, weights=wkl_weights, load_log_file=load_log_file) + else: + raise ValueError("Invalid joint tuner: " + joint_tuner) + + if local_measure: + builder = LocalBuilder(timeout=build_timeout) + if target.target_name == "cuda": + ctx = tvm.context("cuda", 0) + cuda_arch = "sm_" + "".join(ctx.compute_version.split('.')) + tvm.autotvm.measure.measure_methods.set_cuda_target_arch(cuda_arch) + + tracker = Tracker('0.0.0.0', port=port, port_end=10000, silent=True) + if device_key is None: + device_key = '$local$device$%d' % tracker.port + server = Server('0.0.0.0', port=tracker.port, port_end=10000, + key=device_key, use_popen=True, silent=True, + tracker_addr=(tracker.host, tracker.port)) + runner = RPCRunner(device_key, host=host, port=tracker.port, + repeat=1, min_repeat_ms=400, + n_parallel=n_parallel) + else: + os.environ['TVM_AUTO_CACHE_FLUSH'] = "1" + runner = LocalRunner(repeat=10, number=1, min_repeat_ms=0, timeout=run_timeout) + else: + 
os.environ['TVM_NDK_CC'] = ndk_cc + builder = LocalBuilder(build_func='ndk', timeout=build_timeout) + runner = RPCRunner(device_key, host=host, port=port, + repeat=1, min_repeat_ms=400, + n_parallel=n_parallel, timeout=run_timeout) + + search_policy = "%s.%s" % (policy, model_type) + tune_option = TuneOption(n_trials=tuning_parameters['n_trials'], + early_stopping=tuning_parameters['early_stopping'], + num_measure_per_iter=tuning_parameters['num_measure_per_iter'], + builder=builder, + verbose=tuning_parameters['verbose'], + runner=runner, + measure_callbacks=[LogToFile(log_file)]) + if local_measure and target.target_name != 'cuda': + os.environ['TVM_BIND_MASTER_CORE_0'] = "1" + tuner.tune(tune_option, search_policy) + else: + tuner.tune(tune_option, search_policy) + + kernel_layout_rewrite = False + + # Compile graph with best states found by auto-scheduler + print("=============== Compile ===============") + with auto_scheduler.apply_history_best(log_file, args.log_n_lines): + #if True: + #with auto_scheduler.BlockingEmptyContext(): + os.environ['TVM_AUTO_CACHE_FLUSH'] = "0" + os.environ['TVM_BIND_MASTER_CORE_0'] = "1" + if kernel_layout_rewrite: + auto_scheduler.prepare_layout_rewrite(mod, target=target, + params=params, + ops=(relay.op.nn.dense, relay.op.nn.conv2d, relay.op.nn.conv3d)) + else: + # disable layout rewrite + auto_scheduler.LayoutRewriteLevel.BOTH_REWRITE = auto_scheduler.LayoutRewriteLevel.NO_REWRITE + auto_scheduler.LayoutRewriteLevel.COMPUTE_REWRITE = auto_scheduler.LayoutRewriteLevel.NO_REWRITE + + with relay.build_config(opt_level=3): + graph, lib, opt_params = relay.build_module.build( + mod, target=target, params=params) + ''' + from tvm.relay.backend import graph_runtime_codegen + with relay.build_config(opt_level=3): + opt_mod, _ = relay.optimize(mod, target, params) + grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target) + grc.codegen(opt_mod["main"]) + with tvm.transform.PassContext(opt_level=3): + graph, lib, opt_params = relay.build_module.build( + mod, target=target, params=params) + ''' + auto_scheduler.finish_layout_rewrite() + print("=============== Compile Finish ===============") + + module, ctx = create_module(data_shape, graph, lib, target, input_name, opt_params, + debug_profile, local_measure, ndk_cc, + device_key, host, port, run_timeout, num_threads) + + # Evaluate + print("========== Evaluate ==========") + ftimer = module.module.time_evaluator("run", ctx, number=10, repeat=3) + prof_res = np.array(ftimer().results) + # display profile information + if debug_profile or check_correctness: + module.run() + if check_correctness: + actual_output = module.get_output(0).asnumpy() + print(actual_output) + + print("Mean inference time (std dev): %.2f ms (%.2f ms)" % + (np.mean(prof_res) * 1000, np.std(prof_res) * 1000)) + #log_line(BenchmarkRecord(target.target_name, 'gpu' if target.target_name == 'cuda' else 'cpu', 'network', + # "%s.B%d" % (network_name, batch_size), 'AutoSchedule', layout, + # {"costs": prof_res}, time.time()), record_file) + + if check_correctness: + print("========== Check Correctness ==========") + # clean relay cache + relay.backend.compile_engine.get().clear() + + # disable layout rewrite + auto_scheduler.LayoutRewriteLevel.BOTH_REWRITE = auto_scheduler.LayoutRewriteLevel.NO_REWRITE + auto_scheduler.LayoutRewriteLevel.COMPUTE_REWRITE = auto_scheduler.LayoutRewriteLevel.NO_REWRITE + target = tvm.target.create('llvm') + with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): + graph, lib, opt_params = 
relay.build_module.build( + mod, target=target, params=params) + + module, _ = create_module(data_shape, graph, lib, target, input_name, opt_params, + debug_profile, local_measure, ndk_cc, + device_key, host, port, run_timeout, num_threads) + module.run() + + expected_output = module.get_output(0).asnumpy() + np.testing.assert_allclose(actual_output, expected_output, rtol=1e-3, atol=1e-3) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--network", type=str, required=True) + parser.add_argument("--model-path", type=str, default=None, help="The path of tflite model") + parser.add_argument("--n-trials", type=int, default=1000) + parser.add_argument("--target", type=str, default='llvm -mcpu=core-avx2') + parser.add_argument("--target-host", type=str, default=None) + parser.add_argument("--policy", type=str, choices=['multi-stage', 'meta-rewrite'], + default='meta-rewrite') + parser.add_argument("--tune", type=str2bool, nargs='?', const=True, default=True) + parser.add_argument("--check-correctness", type=str2bool, nargs='?', const=True, default=False) + parser.add_argument("--debug-profile", type=str2bool, nargs='?', const=True, default=False) + parser.add_argument("--build-timeout", type=int, default=10) + parser.add_argument("--run-timeout", type=int, default=10) + parser.add_argument("--log-file", type=str, help="Write log of measurement results to this file") + parser.add_argument("--model-type", type=str, choices=['xgb', 'random', 'no-share'], default='xgb') + parser.add_argument("--load-model", action='store_true') + parser.add_argument("--model-file", type=str, default='saved_model.xgb') + parser.add_argument("--load-log", type=str, help="Load history log for pre-training the cost model") + parser.add_argument("--out-file", type=str, default='results.tsv') + parser.add_argument("--seed", type=int, default=0, help='random seed') + parser.add_argument("--verbose", type=int, default=1) + parser.add_argument("--joint-tuner", type=str, default='bottleneck-decay', help='The type of joint tuner', + choices=['no', 'uniform', 'weighted', 'bottleneck', 'bottleneck-decay', 'sequential', 'round-robin', 'rl']) + parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True) + parser.add_argument("--device-key", type=str, default=None) + parser.add_argument("--host", type=str, default='0.0.0.0') + parser.add_argument("--port", type=int, default=9190) + parser.add_argument("--n-parallel", type=int, default=1) + parser.add_argument("--ndk-cc", type=str, default=None) + parser.add_argument("--num-threads", type=int, default=None) + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--num-measure-per-iter", type=int, default=48, + help="The number of programs to be measured at each iteration") + parser.add_argument("--layout", type=str, default=None) + parser.add_argument("--log-n-lines", type=int) + args = parser.parse_args() + + np.random.seed(args.seed) + random.seed(args.seed) + + logging.basicConfig() + logging.getLogger('auto_scheduler').setLevel(logging.DEBUG) + + target = tvm.target.create(args.target) + + tuning_parameters = { + 'n_trials': args.n_trials, + 'num_measure_per_iter': args.num_measure_per_iter, + 'log_file': args.log_file if args.log_file else "%s-B%d.json" % (args.network, args.batch_size), + 'model_type': args.model_type, + 'joint_tuner': args.joint_tuner, + 'policy': args.policy, + 'early_stopping': -1, + 'verbose': 1, + } + tuning_parameters['load_log_file'] = args.load_log or 
tuning_parameters['log_file'] + + os.environ["TOPHUB_LOCATION"] = "NONE" + tune_and_evaluate(args.network, args.model_path, args.batch_size, target, args.target_host, + args.local_measure, args.device_key, args.host, + args.port, args.n_parallel, args.ndk_cc, args.build_timeout, + args.run_timeout, args.num_threads, args.tune, args.check_correctness, + args.debug_profile, tuning_parameters, args.out_file, args.layout) diff --git a/topi/python/topi/ansor.py b/topi/python/topi/ansor.py new file mode 100644 index 000000000000..e821fd5bd42f --- /dev/null +++ b/topi/python/topi/ansor.py @@ -0,0 +1,95 @@ +"""All AutoSchedule Supported Operators""" +from __future__ import absolute_import as _abs +from tvm import ansor + +@ansor.register_topi_schedule() +def schedule_dense_nopack(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_conv2d_nhwc(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_conv2d_NCHWc(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_reduce(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_pool(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_adaptive_pool(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_softmax(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_conv2d_nchw_int8(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_conv2d_nchw(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_depthwise_conv2d_nchw(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_depthwise_conv2d_nhwc(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_conv2d_NCHWc_int8(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_depthwise_conv2d_NCHWc(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_conv2d_transpose_nchw(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_conv3d_ncdhw(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_conv3d_ndhwc(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_conv1d_ncw(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_conv1d_nwc(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_dense_pack(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_batch_matmul(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_bitserial_conv2d_nchw(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_bitserial_conv2d_nhwc(cfg, outs): + return ansor.gen_schedule(cfg, outs) + +@ansor.register_topi_schedule() +def schedule_bitserial_dense(cfg, outs): + return ansor.gen_schedule(cfg, outs) diff --git a/topi/python/topi/arm_cpu/__init__.py b/topi/python/topi/arm_cpu/__init__.py index e121fbc7ec6d..e6ccadd4755f 100644 --- a/topi/python/topi/arm_cpu/__init__.py +++ 
b/topi/python/topi/arm_cpu/__init__.py @@ -26,3 +26,8 @@ from .bitserial_dense import * from .injective import * from . import cortex_m7 + +import os +use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "true") +if use_auto_scheduler.lower() == "true": + from ..ansor import * diff --git a/topi/python/topi/generic/__init__.py b/topi/python/topi/generic/__init__.py index 6171317cd80f..7f37ba78a06c 100644 --- a/topi/python/topi/generic/__init__.py +++ b/topi/python/topi/generic/__init__.py @@ -39,3 +39,8 @@ from .sort import * from .search import * from .image import * + +import os +use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "true") +if use_auto_scheduler.lower() == "true": + from ..ansor import * diff --git a/topi/python/topi/x86/__init__.py b/topi/python/topi/x86/__init__.py index 659668cbbe4c..a334397249e3 100644 --- a/topi/python/topi/x86/__init__.py +++ b/topi/python/topi/x86/__init__.py @@ -39,3 +39,8 @@ from .conv3d_transpose import * from .sparse import * from .conv2d_alter_op import * + +import os +use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "true") +if use_auto_scheduler.lower() == "true": + from ..ansor import * From 674027f8d6b9943508ee9eaf0fba703189a1c781 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Fri, 12 Jun 2020 18:16:28 +0800 Subject: [PATCH 23/45] Add tune_op_subgraph.py & Some code clean for tune_network.py (#23) * Add single op tune scripts * Add tune subgraph support * Merge all op & all subgraph to one file * Rename file --- python/tvm/ansor/auto_schedule.py | 1 + python/tvm/ansor/relay_integration.py | 2 +- scripts/common.py | 2 +- scripts/shape_configs.py | 248 ++++++++++ scripts/tune_network.py | 230 +++------ scripts/tune_op_subgraph.py | 599 +++++++++++++++++++++++ scripts/tune_test.py | 79 +-- src/ansor/search_policy/search_policy.cc | 6 +- 8 files changed, 968 insertions(+), 199 deletions(-) create mode 100644 scripts/shape_configs.py create mode 100644 scripts/tune_op_subgraph.py diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index e1a0711a80be..09895302d25a 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -193,6 +193,7 @@ class TuneOption(Object): Callback functions called before the search process Candidates: - ansor.PreLoadMeasuredStates + - ansor.PreAddCustomRule """ def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_iter=64, verbose=1, builder='local', runner='local', measure_callbacks=None, diff --git a/python/tvm/ansor/relay_integration.py b/python/tvm/ansor/relay_integration.py index 7d7e18a94ddf..de2e12e389e7 100644 --- a/python/tvm/ansor/relay_integration.py +++ b/python/tvm/ansor/relay_integration.py @@ -196,7 +196,7 @@ def prepare_layout_rewrite(mod, params, ops, target): # wrap build call in thread to avoid multiprocessing problems build_thread = threading.Thread(target=_lower, - args=(mod, target, param)) + args=(mod, target, params)) build_thread.start() build_thread.join() relay.backend.compile_engine.get().clear() diff --git a/scripts/common.py b/scripts/common.py index 4400104bdfe6..84fbf8d6c731 100644 --- a/scripts/common.py +++ b/scripts/common.py @@ -168,7 +168,7 @@ def conv2d_nhwc_without_layout_rewrite(Input, Filter, stride, padding, dilation, # compute the output shape dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 - pad_top, pad_left, pad_down, pad_right = topi.nn.util.get_pad_tuple( + pad_top, pad_left, pad_down, pad_right = 
topi.nn.get_pad_tuple(
         padding, (dilated_kernel_h, dilated_kernel_w))
     out_channel = num_filter
     out_height = topi.util.simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
diff --git a/scripts/shape_configs.py b/scripts/shape_configs.py
new file mode 100644
index 000000000000..95a1ba69634d
--- /dev/null
+++ b/scripts/shape_configs.py
@@ -0,0 +1,248 @@
+""" Shape configurations for single operator evaluation
+This file is shared by tune_all_single_op.py and scripts in baseline/
+"""
+
+matmul_shapes = [
+    (1, 128, 128, 128),
+    (1, 512, 32, 512),
+    (1, 512, 512, 512),
+    (1, 1024, 1024, 1024),
+]
+
+conv1d_shapes = [
+    # derived from conv2d_shapes
+    (1, 256, 64, 128, 3, 2, 1),
+#    (1, 256, 64, 128, 1, 2, 0),
+#    (1, 256, 64, 64, 1, 1, 0),
+#    (1, 128, 128, 256, 3, 2, 1),
+    (1, 128, 128, 256, 1, 2, 0),
+#    (1, 128, 128, 128, 3, 1, 1),
+#    (1, 64, 256, 512, 3, 2, 1),
+#    (1, 64, 256, 512, 1, 2, 0),
+    (1, 64, 256, 256, 5, 1, 2),
+    (1, 32, 512, 512, 3, 1, 1),
+]
+
+conv2d_shapes = [
+    # all conv2d layers in resnet-18
+    (1, 224, 224, 3, 64, 7, 2, 3),
+#    (1, 56, 56, 64, 128, 3, 2, 1),
+#    (1, 56, 56, 64, 128, 1, 2, 0),
+#    (1, 56, 56, 64, 64, 3, 1, 1),
+    (1, 56, 56, 64, 64, 1, 1, 0),
+#    (1, 28, 28, 128, 256, 3, 2, 1),
+#    (1, 28, 28, 128, 256, 1, 2, 0),
+#    (1, 28, 28, 128, 128, 3, 1, 1),
+#    (1, 14, 14, 256, 512, 3, 2, 1),
+#    (1, 14, 14, 256, 512, 1, 2, 0),
+    (1, 14, 14, 256, 256, 3, 1, 1),
+    (1, 7, 7, 512, 512, 3, 1, 1),
+]
+
+conv3d_shapes = [
+    # Derived from conv2d_shapes. Use depth=16 for all configurations
+    (1, 16, 224, 224, 3, 64, 7, 2, 3),
+#    (1, 16, 56, 56, 64, 128, 3, 2, 1),
+#    (1, 16, 56, 56, 64, 128, 1, 2, 0),
+#    (1, 16, 56, 56, 64, 64, 3, 1, 1),
+    (1, 16, 56, 56, 64, 64, 1, 1, 0),
+#    (1, 16, 28, 28, 128, 256, 3, 2, 1),
+#    (1, 16, 28, 28, 128, 256, 1, 2, 0),
+#    (1, 16, 28, 28, 128, 128, 3, 1, 1),
+#    (1, 16, 14, 14, 256, 512, 3, 2, 1),
+#    (1, 16, 14, 14, 256, 512, 1, 2, 0),
+    (1, 16, 14, 14, 256, 256, 3, 1, 1),
+    (1, 16, 7, 7, 512, 512, 3, 1, 1),
+]
+
+group_conv2d_shapes = [
+    # Derived from conv2d_shapes. Use group=4 for all configurations
+    (1, 56, 56, 64, 128, 3, 2, 1 , 1, 4),
+#    (1, 56, 56, 64, 128, 1, 2, 0 , 1, 4),
+#    (1, 56, 56, 64, 64, 3, 1, 1 , 1, 4),
+    (1, 56, 56, 64, 64, 1, 1, 0 , 1, 4),
+#    (1, 28, 28, 128, 256, 3, 2, 1, 1, 4),
+#    (1, 28, 28, 128, 256, 1, 2, 0, 1, 4),
+#    (1, 28, 28, 128, 128, 3, 1, 1, 1, 4),
+#    (1, 14, 14, 256, 512, 3, 2, 1, 1, 4),
+#    (1, 14, 14, 256, 512, 1, 2, 0, 1, 4),
+    (1, 14, 14, 256, 256, 3, 1, 1, 1, 4),
+    (1, 7, 7, 512, 512, 3, 1, 1 , 1, 4),
+]
+
+dilation_conv2d_shapes = [
+    # Derived from conv2d_shapes. Use dilation=2 for all configurations
+    (1, 224, 224, 3, 64, 7, 2, 3 , 2),
+#    (1, 56, 56, 64, 128, 3, 2, 1 , 2),
+#    (1, 56, 56, 64, 128, 1, 2, 0 , 2),
+#    (1, 56, 56, 64, 64, 3, 1, 1 , 2),
+    (1, 56, 56, 64, 64, 1, 1, 0 , 2),
+#    (1, 28, 28, 128, 256, 3, 2, 1, 2),
+#    (1, 28, 28, 128, 256, 1, 2, 0, 2),
+#    (1, 28, 28, 128, 128, 3, 1, 1, 2),
+#    (1, 14, 14, 256, 512, 3, 2, 1, 2),
+#    (1, 14, 14, 256, 512, 1, 2, 0, 2),
+    (1, 14, 14, 256, 256, 3, 1, 1, 2),
+    (1, 7, 7, 512, 512, 3, 1, 1 , 2),
+]
+
+depthwise_conv2d_shapes = [
+    # all depthwise conv2d layers in mobilenet
+    (1, 112, 112, 32, 3, 1, 1),
+    (1, 112, 112, 64, 3, 2, 1),
+#    (1, 56, 56, 128, 3, 1, 1),
+#    (1, 56, 56, 128, 3, 2, 1),
+#    (1, 28, 28, 256, 3, 1, 1),
+#    (1, 28, 28, 256, 3, 2, 1),
+#    (1, 14, 14, 512, 3, 1, 1),
+    (1, 14, 14, 512, 3, 2, 1),
+    (1, 7, 7, 1024, 3, 1, 1),
+]
+
+conv2d_transpose_shapes = [
+    # all conv2d transpose layers in DCGAN
+    (1, 4, 4, 512, 256, 4, 2, 1),
+    (1, 8, 8, 256, 128, 4, 2, 1),
+    (1, 16, 16, 128, 64, 4, 2, 1),
+    (1, 32, 32, 64, 3, 4, 2, 1),
+]
+
+conv2d_capsule_shapes = [
+    # all conv2d capsule layers in matrix capsules with EM routing (ICLR 2018)
+    (1, 16, 16, 32, 32, 3, 2, 1),
+    (1, 8, 8, 32, 32, 3, 1, 1),
+    (1, 16, 16, 8, 16, 3, 2, 1),
+    (1, 8, 8, 16, 16, 3, 1, 1),
+]
+
+conv2d_winograd_nhwc_shapes = [
+    (1, 56, 56, 64, 64, 3, 1, 1),
+    (1, 28, 28, 128, 128, 3, 1, 1),
+    (1, 14, 14, 256, 256, 3, 1, 1),
+    (1, 7, 7, 512, 512, 3, 1, 1),
+]
+
+conv2d_winograd_nchw_shapes = [
+    (1, 64, 56, 56, 64, 3, 1, 1),
+    (1, 128, 28, 28, 128, 3, 1, 1),
+    (1, 256, 14, 14, 256, 3, 1, 1),
+    (1, 512, 7, 7, 512, 3, 1, 1),
+]
+
+matmul_tensor_core_shapes = [
+    (16, 512, 512, 'float16', 'float32', True),
+    (32, 512, 512, 'float16', 'float32', True),
+    (512, 512, 512, 'float16', 'float32', True),
+]
+
+norm_shapes = [
+    (1, 256, 256),
+    (1, 512, 512),
+    (1, 1024, 1024),
+    (1, 4096, 1024),
+]
+
+softmax_shapes = [
+    (1, 1024),
+    (1, 4096),
+    (1, 16384),
+    (1, 65536),
+]
+
+single_op_shape_dict = {
+    'C1D': conv1d_shapes,
+    'C2D': conv2d_shapes,
+    'C3D': conv3d_shapes,
+    'GMM': matmul_shapes,
+    'GRP': group_conv2d_shapes,
+    'DIL': dilation_conv2d_shapes,
+    'DEP': depthwise_conv2d_shapes,
+    'T2D': conv2d_transpose_shapes,
+    'CAP': conv2d_capsule_shapes,
+    'NRM': norm_shapes,
+    #'SMX': softmax_shapes,
+
+# The following workloads are not in our single op evaluation plan.
+# They should be moved to `common.py` and be used by `tune_wkl.py`.
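+# Note: each shape tuple above is unpacked as the argument list of the
+# matching workload function in tune_op_subgraph.py, e.g. conv2d entries
+# follow conv2d_nhwc(N, H, W, CI, CO, kernel_size, stride, padding); the
+# GRP and DIL lists carry the extra (dilation, groups) / (dilation,) fields.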
+# 'C2D_NCHW': conv2d_nchw_shapes, + 'C2DWG_NHWC': conv2d_winograd_nhwc_shapes, +# 'C2DWG_NCHW': conv2d_winograd_nchw_shapes, +# 'GMM_TC': matmul_tensor_core_shapes, +} + +conv2d_bn_relu_shapes = [ + (1, 224, 224, 3, 64, 7, 2, 3), + (1, 56, 56, 64, 128, 3, 2, 1), + (1, 28, 28, 128, 256, 1, 2, 0), + (1, 7, 7, 512, 512, 3, 1, 1, 1), + (16, 224, 224, 3, 64, 7, 2, 3), + (16, 56, 56, 64, 128, 3, 2, 1), + (16, 28, 28, 128, 256, 1, 2, 0), + (16, 7, 7, 512, 512, 3, 1, 1, 1), +] + +transpose_batch_matmul_shapes = [ + (1, 128, 12, 64), + (1, 128, 16, 64), + (1, 64, 12, 128), + (1, 128, 12, 128), + (16, 128, 12, 64), + (16, 128, 16, 64), + (16, 64, 12, 128), + (16, 128, 12, 128), +] + + +batch_norm_shapes = [ + (16, 256), + (16, 1024), + (16, 4096), + (16, 16384), + (16, 65536), +] + +subgraph_shape_dict = { + "conv2d_bn_relu": conv2d_bn_relu_shapes, + "transpose_batch_matmul": transpose_batch_matmul_shapes, + #"batch_norm": batch_norm_shapes, +} + +resnet_shapes = [ + (1, ), + (16, ), +] + +mobilenet_v2_shapes = [ + (1, ), + (16, ), +] + +dcgan_shapes = [ + (1, ), + (16, ), +] + +dqn_shapes = [ + (1, ), + (16, ), +] + +bert_shapes = [ + (1, ), + (16, ), +] + +resnet18_3d_shapes = [ + (1, ), + (16, ), +] + +network_shape_dict = { + 'resnet_50': resnet_shapes, + 'mobilenet_v2': mobilenet_v2_shapes, + 'dcgan': dcgan_shapes, + 'dqn': dqn_shapes, + 'bert': bert_shapes, + 'resnet_18_3d': resnet18_3d_shapes, +} + diff --git a/scripts/tune_network.py b/scripts/tune_network.py index 3d858ce60ab0..f1f7cd54f8c6 100644 --- a/scripts/tune_network.py +++ b/scripts/tune_network.py @@ -7,23 +7,17 @@ import numpy as np import tvm -from tvm.rpc.tracker import Tracker -from tvm.rpc.server import Server -from tvm import ansor as auto_scheduler -from tvm import relay -from tvm.rpc.tracker import Tracker -from tvm.rpc.server import Server -from tvm.relay import testing -#from tvm._ffi.function import get_global_func +from tvm import _ffi, relay, ansor import tvm.contrib.graph_runtime as runtime from tvm.contrib.debugger import debug_runtime from tvm.contrib import util, ndk -from common import str2bool -from tvm.ansor import LocalRunner, LogToFile, TuneOption, SimpleTaskScheduler, \ - RPCRunner, LocalBuilder +from tvm.relay import testing from tvm.ansor.utils import request_remote #from baseline.utils import log_line, BenchmarkRecord +from common import str2bool +from tune_test import create_tune_option + dtype = "float32" def get_network(name, model_path, batch_size, layout): @@ -43,7 +37,6 @@ def get_network(name, model_path, batch_size, layout): image_shape = (224, 224, 3) if layout == 'NHWC' else (3, 224, 224) input_shape = (batch_size, *image_shape) mod, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, layout=layout, image_shape=image_shape, dtype=dtype) - print(mod) elif "lstm" in name: mod, params = relay.testing.lstm.get_workload(iterations=10, num_hidden=512, batch_size=batch_size, dtype=dtype) elif "mlp" in name: @@ -54,9 +47,9 @@ def get_network(name, model_path, batch_size, layout): mod, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype) elif name == 'dcgan': input_shape = (batch_size, 100) - mod, params = relay.testing.dcgan.get_workload(batch_size=batch_size, layout=layout) + mod, params = relay.testing.dcgan.get_workload(batch_size=batch_size) elif name == 'dqn': - image_shape = (84, 84, 4) if layout == 'NHWC' else (4, 84, 84) + image_shape = (4, 84, 84) input_shape = (batch_size, *image_shape) mod, params = 
relay.testing.dqn.get_workload(batch_size=batch_size, image_shape=image_shape, dtype=dtype) elif name == 'mobilenet': @@ -143,71 +136,6 @@ def get_network(name, model_path, batch_size, layout): mod, params = relay.frontend.from_tensorflow(graph_def, shape=shape_dict, outputs=out_names) - elif name == 'tflite-textcnn': - try: - import tflite.Model - except ImportError: - raise ImportError("The tflite package must be installed") - model_path = './baseline/tensorflow/fake_textcnn.tflite' - input_name = "Placeholder" - input_shape = (batch_size, 200, 128, 1) - output_shape = (1, 1001) - input_dtype = "float32" - tflite_model_buf = open(model_path, "rb").read() - tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) - mod, params = relay.frontend.from_tflite(tflite_model, - shape_dict={input_name: input_shape}, - dtype_dict={input_name: input_dtype}) - print(mod['main']) - elif name == 'textcnn': - import tensorflow as tf - - bert_pb = './baseline/tensorflow/fake_textcnn.pb' - try: - with tf.compat.v1.gfile.GFile(bert_pb, 'rb') as f: - graph_def = tf.compat.v1.GraphDef() - graph_def.ParseFromString(f.read()) - except: - raise ValueError("Need to run ./baseline/tensorflow/bert/generate_bert_pb.py to get model first") - - input_shape = (batch_size, 200, 128, 1) - input_name = ['Placeholder'] - shape_dict = { - 'Placeholder': input_shape - } - out_names = [ - 'concat/concat_dim' - ] - - mod, params = relay.frontend.from_tensorflow(graph_def, - shape=shape_dict, - outputs=out_names) - print(mod['main']) - elif name == 'tdnn': - import tensorflow as tf - - pb = './baseline/tensorflow/pruned_model_0407.pb' - #pb = './baseline/tensorflow/tdnn_4001.pb' - try: - with tf.compat.v1.gfile.GFile(pb, 'rb') as f: - graph_def = tf.compat.v1.GraphDef() - graph_def.ParseFromString(f.read()) - except: - raise ValueError("Need to run ./baseline/tensorflow/bert_convert.py to get model first") - - input_shape = (batch_size, 600, 64) - input_name = ['tf_loss_fn/Placeholder'] - - shape_dict = { - 'tf_loss_fn/Placeholder': input_shape, - } - out_names = [ - #"tf_loss_fn/ForwardPass/w2l_encoder/conv91/Conv2D" - "tf_loss_fn/ForwardPass/Softmax" - ] - mod, params = relay.frontend.from_tensorflow(graph_def, - shape=shape_dict, - outputs=out_names) else: raise ValueError("Unsupported network: " + name) @@ -223,7 +151,7 @@ def create_module(data_shape, graph, lib, target, input_name, params, debug_prof else: ctx = tvm.cpu() if num_threads: - config_threadpool = get_global_func('runtime.config_threadpool') + config_threadpool = _ffi.get_global_func('runtime.config_threadpool') config_threadpool(0, num_threads) else: print("=============== Request Remote ===============") @@ -267,23 +195,19 @@ def tune_and_evaluate(network_name, model_path, batch_size, target, target_host, local_measure, device_key, host, port, n_parallel, ndk_cc, build_timeout, run_timeout, num_threads, tune, check_correctness, debug_profile, tuning_parameters, record_file, layout_set): - joint_tuner, model_type, policy, log_file, load_log_file = (tuning_parameters['joint_tuner'], + task_scheduler, model_type, policy, log_file, load_log_file = (tuning_parameters['task_scheduler'], tuning_parameters['model_type'], tuning_parameters['policy'], tuning_parameters['log_file'], tuning_parameters['load_log_file']) if layout_set: layout = layout_set - elif target.target_name == 'cuda': - layout = 'NCHW' - else: - layout = "NHWC" # Extract workloads from relay program print("=============== Extract workloads ===============") mod, params, input_name, 
data_shape, out_shape = get_network(network_name, model_path, batch_size, layout) if tune: - workloads, wkl_weights = auto_scheduler.extract_from_program(mod, target=target, + workloads, wkl_weights = ansor.extract_from_program(mod, target=target, params=params, ops=(relay.op.nn.dense, relay.op.nn.softmax, relay.op.nn.conv2d, relay.op.nn.conv2d_transpose, relay.op.nn.max_pool2d, relay.op.nn.avg_pool2d, @@ -292,85 +216,54 @@ def tune_and_evaluate(network_name, model_path, batch_size, target, target_host, relay.op.nn.batch_matmul, relay.op.mean, )) print("Total workload number: %d" % (len(workloads))) - #workloads = workloads[1:2] - #wkl_weights = wkl_weights[1:2] - #workloads = ['["2543426b0070d4a379a1f75a362a5f1b"]'] - # Tune workloads with auto scheduler print("=============== Tuning ===============") tasks = [] for i, wkl_key in enumerate(workloads): - dag = auto_scheduler.workload_key_to_dag(wkl_key) + dag = ansor.workload_key_to_dag(wkl_key) print("[========= Task %d =========]\n" % i, dag) - tasks.append(auto_scheduler.SearchTask(dag, wkl_key, target, target_host)) - - if joint_tuner != 'rl': - tuner = SimpleTaskScheduler(tasks, load_log_file=load_log_file) - elif joint_tuner == 'rl': - # put import here to remove pytorch dependency - from tvm.auto_scheduler.joint_tuner.rl_joint_tuner import RLJointTuner - tuner = RLJointTuner(tasks, weights=wkl_weights, load_log_file=load_log_file) - else: - raise ValueError("Invalid joint tuner: " + joint_tuner) - - if local_measure: - builder = LocalBuilder(timeout=build_timeout) - if target.target_name == "cuda": - ctx = tvm.context("cuda", 0) - cuda_arch = "sm_" + "".join(ctx.compute_version.split('.')) - tvm.autotvm.measure.measure_methods.set_cuda_target_arch(cuda_arch) - - tracker = Tracker('0.0.0.0', port=port, port_end=10000, silent=True) - if device_key is None: - device_key = '$local$device$%d' % tracker.port - server = Server('0.0.0.0', port=tracker.port, port_end=10000, - key=device_key, use_popen=True, silent=True, - tracker_addr=(tracker.host, tracker.port)) - runner = RPCRunner(device_key, host=host, port=tracker.port, - repeat=1, min_repeat_ms=400, - n_parallel=n_parallel) - else: - os.environ['TVM_AUTO_CACHE_FLUSH'] = "1" - runner = LocalRunner(repeat=10, number=1, min_repeat_ms=0, timeout=run_timeout) - else: - os.environ['TVM_NDK_CC'] = ndk_cc - builder = LocalBuilder(build_func='ndk', timeout=build_timeout) - runner = RPCRunner(device_key, host=host, port=port, - repeat=1, min_repeat_ms=400, - n_parallel=n_parallel, timeout=run_timeout) + tasks.append(ansor.SearchTask(dag, wkl_key, target, target_host)) + + def objective_func(costs): + return sum(c * w for c, w in zip(costs, wkl_weights)) + + tuner = ansor.SimpleTaskScheduler(tasks, objective_func, strategy=task_scheduler, + load_log_file=load_log_file, + load_model_file=tuning_parameters['load_model']) + tune_option, measure_ctx = create_tune_option(target, log_file, + tuning_parameters['n_trials'], tuning_parameters['num_measure_per_iter'], + tuning_parameters['verbose'], n_parallel, build_timeout, + local_measure, device_key, host, port, ndk_cc, + tuning_parameters['early_stopping']) search_policy = "%s.%s" % (policy, model_type) - tune_option = TuneOption(n_trials=tuning_parameters['n_trials'], - early_stopping=tuning_parameters['early_stopping'], - num_measure_per_iter=tuning_parameters['num_measure_per_iter'], - builder=builder, - verbose=tuning_parameters['verbose'], - runner=runner, - measure_callbacks=[LogToFile(log_file)]) + if local_measure and target.target_name != 
'cuda': os.environ['TVM_BIND_MASTER_CORE_0'] = "1" - tuner.tune(tune_option, search_policy) - else: - tuner.tune(tune_option, search_policy) + + tuner.tune(tune_option, search_policy) + + if measure_ctx: + del measure_ctx kernel_layout_rewrite = False # Compile graph with best states found by auto-scheduler print("=============== Compile ===============") - with auto_scheduler.apply_history_best(log_file, args.log_n_lines): + with ansor.apply_history_best(log_file, args.log_n_lines): #if True: - #with auto_scheduler.BlockingEmptyContext(): + #with ansor.BlockingEmptyContext(): os.environ['TVM_AUTO_CACHE_FLUSH'] = "0" os.environ['TVM_BIND_MASTER_CORE_0'] = "1" if kernel_layout_rewrite: - auto_scheduler.prepare_layout_rewrite(mod, target=target, + ansor.prepare_layout_rewrite(mod, target=target, params=params, ops=(relay.op.nn.dense, relay.op.nn.conv2d, relay.op.nn.conv3d)) else: # disable layout rewrite - auto_scheduler.LayoutRewriteLevel.BOTH_REWRITE = auto_scheduler.LayoutRewriteLevel.NO_REWRITE - auto_scheduler.LayoutRewriteLevel.COMPUTE_REWRITE = auto_scheduler.LayoutRewriteLevel.NO_REWRITE + ansor.LayoutRewriteLevel.BOTH_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE + ansor.LayoutRewriteLevel.COMPUTE_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE with relay.build_config(opt_level=3): graph, lib, opt_params = relay.build_module.build( @@ -385,7 +278,7 @@ def tune_and_evaluate(network_name, model_path, batch_size, target, target_host, graph, lib, opt_params = relay.build_module.build( mod, target=target, params=params) ''' - auto_scheduler.finish_layout_rewrite() + ansor.finish_layout_rewrite() print("=============== Compile Finish ===============") module, ctx = create_module(data_shape, graph, lib, target, input_name, opt_params, @@ -415,8 +308,8 @@ def tune_and_evaluate(network_name, model_path, batch_size, target, target_host, relay.backend.compile_engine.get().clear() # disable layout rewrite - auto_scheduler.LayoutRewriteLevel.BOTH_REWRITE = auto_scheduler.LayoutRewriteLevel.NO_REWRITE - auto_scheduler.LayoutRewriteLevel.COMPUTE_REWRITE = auto_scheduler.LayoutRewriteLevel.NO_REWRITE + ansor.LayoutRewriteLevel.BOTH_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE + ansor.LayoutRewriteLevel.COMPUTE_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE target = tvm.target.create('llvm') with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): graph, lib, opt_params = relay.build_module.build( @@ -433,28 +326,40 @@ def tune_and_evaluate(network_name, model_path, batch_size, target, target_host, if __name__ == "__main__": parser = argparse.ArgumentParser() + # Task related options parser.add_argument("--network", type=str, required=True) parser.add_argument("--model-path", type=str, default=None, help="The path of tflite model") - parser.add_argument("--n-trials", type=int, default=1000) + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--layout", type=str, default='NHWC') parser.add_argument("--target", type=str, default='llvm -mcpu=core-avx2') parser.add_argument("--target-host", type=str, default=None) - parser.add_argument("--policy", type=str, choices=['multi-stage', 'meta-rewrite'], - default='meta-rewrite') + parser.add_argument("--n-trials", type=int, default=1000) + parser.add_argument("--num-measure-per-iter", type=int, default=48, + help="The number of programs to be measured at each iteration") parser.add_argument("--tune", type=str2bool, nargs='?', const=True, default=True) parser.add_argument("--check-correctness", type=str2bool, 
nargs='?', const=True, default=False) parser.add_argument("--debug-profile", type=str2bool, nargs='?', const=True, default=False) - parser.add_argument("--build-timeout", type=int, default=10) - parser.add_argument("--run-timeout", type=int, default=10) - parser.add_argument("--log-file", type=str, help="Write log of measurement results to this file") + + # Strategy related options + parser.add_argument("--seed", type=int, default=0, help='random seed') + parser.add_argument("--policy", type=str, choices=['multi-stage', 'meta-rewrite'], + default='meta-rewrite') parser.add_argument("--model-type", type=str, choices=['xgb', 'random', 'no-share'], default='xgb') - parser.add_argument("--load-model", action='store_true') - parser.add_argument("--model-file", type=str, default='saved_model.xgb') + parser.add_argument("--task-scheduler", type=str, default='gradient', + choices=['no', 'gradient', 'round-robin'], + help='The strategy of task scheduler') + + # File related options + parser.add_argument("--log-file", type=str, help="Write log of measurement results to this file") + parser.add_argument("--load-model", type=str, help="Load pre trained cost model file") parser.add_argument("--load-log", type=str, help="Load history log for pre-training the cost model") parser.add_argument("--out-file", type=str, default='results.tsv') - parser.add_argument("--seed", type=int, default=0, help='random seed') + parser.add_argument("--log-n-lines", type=int) + + # Detailed control options + parser.add_argument("--build-timeout", type=int, default=10) + parser.add_argument("--run-timeout", type=int, default=10) parser.add_argument("--verbose", type=int, default=1) - parser.add_argument("--joint-tuner", type=str, default='bottleneck-decay', help='The type of joint tuner', - choices=['no', 'uniform', 'weighted', 'bottleneck', 'bottleneck-decay', 'sequential', 'round-robin', 'rl']) parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True) parser.add_argument("--device-key", type=str, default=None) parser.add_argument("--host", type=str, default='0.0.0.0') @@ -462,27 +367,22 @@ def tune_and_evaluate(network_name, model_path, batch_size, target, target_host, parser.add_argument("--n-parallel", type=int, default=1) parser.add_argument("--ndk-cc", type=str, default=None) parser.add_argument("--num-threads", type=int, default=None) - parser.add_argument("--batch-size", type=int, default=1) - parser.add_argument("--num-measure-per-iter", type=int, default=48, - help="The number of programs to be measured at each iteration") - parser.add_argument("--layout", type=str, default=None) - parser.add_argument("--log-n-lines", type=int) args = parser.parse_args() np.random.seed(args.seed) random.seed(args.seed) - logging.basicConfig() - logging.getLogger('auto_scheduler').setLevel(logging.DEBUG) + logging.getLogger('ansor').setLevel(logging.DEBUG) target = tvm.target.create(args.target) tuning_parameters = { 'n_trials': args.n_trials, 'num_measure_per_iter': args.num_measure_per_iter, - 'log_file': args.log_file if args.log_file else "%s-B%d.json" % (args.network, args.batch_size), + 'log_file': args.log_file or "%s-B%d.json" % (args.network, args.batch_size), + 'load_model': args.load_model, 'model_type': args.model_type, - 'joint_tuner': args.joint_tuner, + 'task_scheduler': args.task_scheduler, 'policy': args.policy, 'early_stopping': -1, 'verbose': 1, diff --git a/scripts/tune_op_subgraph.py b/scripts/tune_op_subgraph.py new file mode 100644 index 000000000000..bf5cbe83c952 --- 
/dev/null +++ b/scripts/tune_op_subgraph.py @@ -0,0 +1,599 @@ +"""Tune all operators for single op & subgraph evaluation""" +import argparse +import logging +import os +import random + +import numpy as np + +import tvm +from tvm import te, ansor +import topi +from topi.nn.winograd_util import winograd_transform_matrices +from topi.util import get_const_tuple + +from common import measure_schedule, str2bool, \ + norm_bmn, softmax_mn, conv2d_nhwc_bn_relu, conv2d_nchw_bn_relu +from shape_configs import single_op_shape_dict, subgraph_shape_dict +from tune_test import tune_workloads_jointly, replay_workload, create_tune_option + +# ========================== Single Ops ========================== + +@ansor.register_auto_scheduler_workload_func +def batch_matmul_nkkm(B, N, M, K): + X = te.placeholder((B, N, K), name='A') + Y = te.placeholder((B, K, M), name='B') + k = te.reduce_axis((0, K), name='k') + Z = te.compute((B, N, M), lambda b, i, j: te.sum(X[b][i][k] * Y[b][k][j], axis=[k]), name='C') + return [X, Y, Z] + +@ansor.register_auto_scheduler_workload_func +def conv1d_nlc(N, L, CI, CO, kernel_size, stride=1, padding=0, dilation=1, groups=1): + inputs = te.placeholder((N, L, CI), name='inputs') + weight = te.placeholder((kernel_size, CI//groups, CO), name='weight') + + batch_size, in_len, in_channel = inputs.shape + k_len, channel_per_group, out_channel = weight.shape + out_channel_per_group = out_channel // groups + out_len = (in_len + 2 * padding - dilation * (k_len - 1) - 1) // stride + 1 + rc = te.reduce_axis((0, channel_per_group), name='rc') + rl = te.reduce_axis((0, k_len), name='rl') + + padded = topi.nn.pad(inputs, [0, padding, 0]) + output = te.compute( + (batch_size, out_len, out_channel), + lambda n, l, co: te.sum( + (padded[n, l * stride + rl * dilation, co // out_channel_per_group * channel_per_group + rc] * + weight[rl, rc, co]), axis=[rl, rc]), + name='conv1d_nlc' + ) + return [inputs, weight, output] + +@ansor.register_auto_scheduler_workload_func +def conv2d_nhwc(N, H, W, CI, CO, kernel_size, stride=1, padding=0, dilation=1, groups=1): + inputs = te.placeholder((N, H, W, CI), name='inputs') + weight = te.placeholder((kernel_size, kernel_size, CI//groups, CO), name='weight') + batch_size, in_h, in_w, in_channel = inputs.shape + k_h, k_w, channel_per_group, out_channel = weight.shape + out_channel_per_group = out_channel // groups + + out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1 + out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1 + rh = te.reduce_axis((0, k_h), name="rh") + rw = te.reduce_axis((0, k_w), name="rw") + rc = te.reduce_axis((0, channel_per_group), name="rc") + + padded = topi.nn.pad(inputs, [0, padding, padding, 0]) + output = te.compute( + (batch_size, out_h, out_w, out_channel), + lambda n, h, w, co: te.sum( + (padded[n, h * stride + rh * dilation, w * stride + rw * dilation, + co // out_channel_per_group * channel_per_group + rc] + * weight[rh, rw, rc, co]), axis=[rh, rw, rc] + ), + name='conv2d_nhwc' + ) + return [inputs, weight, output] + +@ansor.register_auto_scheduler_workload_func +def conv2d_nchw(N, CI, H, W, CO, kernel_size, stride=1, padding=0, dilation=1, groups=1): + inputs = te.placeholder((N, CI, H, W), name='inputs') + weight = te.placeholder((CO, CI//groups, kernel_size, kernel_size), name='weight') + batch_size, in_channel, in_h, in_w = inputs.shape + out_channel, channel_per_group, k_h, k_w, = weight.shape + out_channel_per_group = out_channel // groups + + out_h = (in_h + 2 * padding - dilation * (k_h 
- 1) - 1) // stride + 1 + out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1 + rc = te.reduce_axis((0, channel_per_group), name="rc") + rh = te.reduce_axis((0, k_h), name="rh") + rw = te.reduce_axis((0, k_w), name="rw") + + padded = topi.nn.pad(inputs, [0, 0, padding, padding]) + output = te.compute( + (batch_size, out_channel, out_h, out_w), + lambda n, co, h, w: te.sum( + (padded[n, co // out_channel_per_group * channel_per_group + rc, + h * stride + rh * dilation, w * stride + rw * dilation] + * weight[co, rc, rh, rw]), axis=[rc, rh, rw] + ), + name='conv2d_nchw' + ) + return [inputs, weight, output] + +@ansor.register_auto_scheduler_workload_func +def conv3d_ndhwc(N, D, H, W, CI, CO, kernel_size, stride=1, padding=0, dilation=1, groups=1): + inputs = te.placeholder((N, D, H, W, CI)) + weight = te.placeholder((kernel_size, kernel_size, kernel_size, CI//groups, CO)) + batch_size, in_d, in_h, in_w, in_channel = inputs.shape + k_d, k_h, k_w, channel_per_group, out_channel = weight.shape + out_channel_per_group = out_channel // groups + + out_d = (in_d + 2 * padding - dilation * (k_d - 1) - 1) // stride + 1 + out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1 + out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1 + rd = te.reduce_axis((0, k_d), name='rd') + rh = te.reduce_axis((0, k_h), name='rh') + rw = te.reduce_axis((0, k_w), name='rw') + rc = te.reduce_axis((0, channel_per_group), name='rc') + + padded = topi.nn.pad(inputs, [0, padding, padding, padding, 0]) + output = te.compute( + (batch_size, out_d, out_h, out_w, out_channel), + lambda n, d, h, w, co: te.sum( + (padded[n, d * stride + rd * dilation, + h * stride + rh * dilation, w * stride + rw * dilation, + co // out_channel_per_group * channel_per_group + rc] + * weight[rd, rh, rw, rc, co]), + axis=[rd, rh, rw, rc] + ), + name='conv3d_ndhwc' + ) + return [inputs, weight, output] + +@ansor.register_auto_scheduler_workload_func +def depthwise_conv2d_nhwc(N, H, W, C, kernel_size, stride=1, padding=0, dilation=1, factor=1): + inputs = te.placeholder((N, H, W, C)) + weight = te.placeholder((factor, kernel_size, kernel_size, C)) + + batch_size, in_h, in_w, in_channel = inputs.shape + factor, k_h, k_w, in_channel = weight.shape + out_channel = in_channel * factor + + assert factor.value == 1, "Not optimized for factor != 1" + + out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1 + out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1 + rh = te.reduce_axis((0, k_h), name='rh') + rw = te.reduce_axis((0, k_w), name='rw') + + padded = topi.nn.pad(inputs, [0, padding, padding, 0]) + output = te.compute( + (batch_size, out_h, out_w, out_channel), + lambda n, h, w, c: te.sum( + (padded[n, h * stride + rh * dilation, w * stride + rw * dilation, c // factor] + * weight[c % factor, rh, rw, c // factor]), + axis=[rh, rw] + ), + name="depth_conv2d_nhwc" + ) + return [inputs, weight, output] + +@ansor.register_auto_scheduler_workload_func +def conv2d_transpose_nhwc(N, H, W, CI, CO, kernel_size, stride=1, padding=0): + inputs = te.placeholder((N, H, W, CI), name='inputs') + weight = te.placeholder((kernel_size, kernel_size, CI, CO), name='weight') + + batch, in_h, in_w, in_c = inputs.shape + filter_h, filter_w, in_c, out_c = weight.shape + stride_h, stride_w = (stride, stride) + + # compute padding + fpad_top, fpad_left, fpad_bottom, fpad_right = topi.nn.get_pad_tuple(padding, (filter_h, filter_w)) + bpad_top = filter_h - 1 - fpad_top + bpad_bottom = filter_h - 1 - 
fpad_bottom + bpad_left = filter_w - 1 - fpad_left + bpad_right = filter_w - 1 - fpad_right + + # padding stage + padded = topi.nn.pad(inputs, + [0, (bpad_top + stride_h - 1) // stride_h, + (bpad_left + stride_w - 1) // stride_w, 0], + [0, (bpad_bottom + stride_h - 1) // stride_h, + (bpad_right + stride_w - 1) // stride_w, 0]) + + # remove extra padding introduced by dilatation + idxdiv = te.indexdiv + idxmod = te.indexmod + border_h = idxmod(stride_h - idxmod(bpad_top, stride_h), stride_h) + border_w = idxmod(stride_w - idxmod(bpad_left, stride_w), stride_w) + + # dilation stage + strides = [1, stride_h, stride_w, 1] + n = len(padded.shape) + + # We should embed this dilation directly into te.compute rather than creating a new te.compute. + # Only in this way can we use unroll to eliminate the multiplication of zeros. + def _dilate(*indices): + not_zero = [] + index_tuple = [] + for i in range(n): + if not strides[i] == 1: + index_tuple.append(idxdiv(indices[i], strides[i])) + not_zero.append(idxmod(indices[i], strides[i]).equal(0)) + else: + index_tuple.append(indices[i]) + if not_zero: + not_zero = te.all(*not_zero) + return te.if_then_else(not_zero, padded(*index_tuple), tvm.tir.const(0.0, padded.dtype)) + return padded(*index_tuple) + + # convolution stage + out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h + out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w + rc = te.reduce_axis((0, in_c), name='rc') + rh = te.reduce_axis((0, filter_h), name='rh') + rw = te.reduce_axis((0, filter_w), name='rw') + + output = te.compute( + (batch, out_h, out_w, out_c), + lambda n, h, w, co: te.sum( + _dilate(n, h + rh + border_h, w + rw + border_w, rc) * + weight[filter_h - 1 - rh, filter_w - 1 - rw, rc, co], + axis=[rh, rw, rc]), + name="conv2d_transpose_nhwc", + attrs={"auto_scheduler_always_unroll_inner": ["h", "w", "rh", "rw", "h_c", "w_c"]}) + # todo(lmzheng): add constraints on the tile size of h and w + + return [inputs, weight, output] + +@ansor.register_auto_scheduler_workload_func +def conv2d_capsule_nhwijc(N, H, W, CI, CO, kernel_size, stride=1, padding=0, capsule_size=4): + inputs = te.placeholder((N, H, W, capsule_size, capsule_size, CI), name='inputs') + weight = te.placeholder((kernel_size, kernel_size, capsule_size, capsule_size, CI, CO), name='weight') + batch_size, in_h, in_w, _, _, in_channel = inputs.shape + k_h, k_w, _, _, _, out_channel = weight.shape + + out_h = (in_h + 2 * padding - kernel_size) // stride + 1 + out_w = (in_w + 2 * padding - kernel_size) // stride + 1 + + rh = te.reduce_axis((0, k_h), name="rh") + rw = te.reduce_axis((0, k_w), name="rw") + cap_k = te.reduce_axis((0, capsule_size), name='cap_k') + rc = te.reduce_axis((0, in_channel), name="rc") + + padded = topi.nn.pad(inputs, [0, padding, padding, 0, 0, 0]) + output = te.compute( + (batch_size, out_h, out_w, capsule_size, capsule_size, out_channel), + lambda n, h, w, cap_i, cap_j, co: te.sum( + (padded[n, h * stride + rh, w * stride + rw, cap_i, cap_k, rc] + * weight[rh, rw, cap_k, cap_j, rc, co]), axis=[rh, rw, cap_k, rc] + ), + name='conv2d_capsule_nhwijc' + ) + return [inputs, weight, output] + + +@ansor.register_auto_scheduler_workload_func +def conv2d_winograd_nhwc(N, H, W, CI, CO, kernel_size=3, stride=1, padding=0, dilation=1): + # TODO: implement tile_size + tile_size = 4 #_infer_tile_size(data, kernel) + inputs = te.placeholder((N, H, W, CI), name='inputs') + #weight = te.placeholder((kernel_size, kernel_size, CI, CO), name='weight') + N, H, W, CI = 
get_const_tuple(inputs.shape) + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + # if dilation_h != 1 or dilation_w != 1: + # weight = topi.nn.dilate(weight, (1, 1, dilation_h, dilation_w)) + KH = KW = kernel_size + HPAD, WPAD, _, _ = topi.nn.get_pad_tuple(padding, (KH, KW)) + HSTR, WSTR = (stride, stride) if isinstance(stride, int) else stride + assert HSTR == 1 and WSTR == 1 and KH == KW + + data_pad = topi.nn.pad(inputs, (0, HPAD, WPAD, 0), (0, HPAD, WPAD, 0), name="data_pad") + + r = KW + m = tile_size + alpha = m + r - 1 + A, B, G = winograd_transform_matrices(m, r, 'float32') + + H = (H + 2 * HPAD - KH) // HSTR + 1 + W = (W + 2 * WPAD - KW) // WSTR + 1 + nH, nW = (H + m - 1) // m, (W + m - 1) // m + P = N * nH * nW + r_kh = te.reduce_axis((0, KH), name='r_kh') + r_kw = te.reduce_axis((0, KW), name='r_kw') + # kernel_pack = te.compute((alpha, alpha, CO, CI), lambda eps, nu, co, ci: + # weight[0][0][0][0], + # name='kernel_pack') + kshape = (alpha, alpha, CO, CI) + kernel_pack = te.placeholder(kshape, inputs.dtype, name="weight") + + idxdiv = te.indexdiv + idxmod = te.indexmod + # pack input tile + input_tile = te.compute((alpha, alpha, P, CI), lambda eps, nu, p, ci: + data_pad[idxdiv(p, (nH * nW))][idxmod(idxdiv(p, nW), nH) * m + eps] + [idxmod(p, nW) * m + nu][ci], name='input_tile',) + + # transform data + r_a = te.reduce_axis((0, alpha), 'r_a') + r_b = te.reduce_axis((0, alpha), 'r_b') + data_pack = te.compute((alpha, alpha, P, CI), lambda eps, nu, p, ci: + te.sum(input_tile[r_a][r_b][p][ci] * B[r_a][eps] * B[r_b][nu], + axis=[r_a, r_b]), name='data_pack', + attrs={"auto_scheduler_no_split_at_inner": ["eps", "nu", "r_a", "r_b"], + "auto_scheduler_last_split_is_one": ["ci", "p"], + "auto_scheduler_always_unroll": ["eps", "nu", "r_a", "r_b"], + "auto_scheduler_no_cache_write": "True", + }) + + # do batch gemm + ci = te.reduce_axis((0, CI), name='ci') + bgemm = te.compute((alpha, alpha, P, CO), lambda eps, nu, p, co: + te.sum(data_pack[eps][nu][p][ci] * + kernel_pack[eps][nu][co][ci], + axis=[ci]), name='bgemm') + + # inverse transform + r_a = te.reduce_axis((0, alpha), 'r_a') + r_b = te.reduce_axis((0, alpha), 'r_b') + inverse = te.compute((m, m, P, CO), lambda vh, vw, p, co: + te.sum(bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw], + axis=[r_a, r_b]), name='inverse', + attrs={"auto_scheduler_no_split_at_inner": ["vh", "vw", "r_a", "r_b"], + "auto_scheduler_always_unroll": ["vh", "vw", "r_a", "r_b"], + "auto_scheduler_last_split_is_one": ["co", "p"], + "auto_scheduler_no_cache_write": "True", + }) + + # output + output = te.compute((N, H, W, CO), lambda n, h, w, co: + inverse[idxmod(h, m), + idxmod(w, m), + n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m), + co], + name='conv2d_winograd', + tag='conv2d_winograd_nhwc', + attrs={"auto_scheduler_no_split_at_outer": ["n", "h", "w", "co"],}) + return [inputs, kernel_pack, output] + +@ansor.register_auto_scheduler_workload_func +def conv2d_winograd_nchw(N, CI, H, W, CO, kernel_size=3, stride=1, padding=0, dilation=1, precompute=False): + # TODO: implement tile_size + tile_size = 4 #_infer_tile_size(data, kernel) + inputs = te.placeholder((N, CI, H, W), name='inputs') + #weight = te.placeholder((CO, CI, kernel_size, kernel_size), name='weight') + N, CI, H, W = get_const_tuple(inputs.shape) + # if isinstance(dilation, int): + # dilation_h = dilation_w = dilation + # else: + # dilation_h, dilation_w = dilation + # if dilation_h != 1 or dilation_w != 1: + # weight = 
topi.nn.dilate(weight, (1, 1, dilation_h, dilation_w)) + KH = KW = kernel_size + HPAD, WPAD, _, _ = topi.nn.get_pad_tuple(padding, (KH, KW)) + HSTR, WSTR = (stride, stride) if isinstance(stride, int) else stride + assert HSTR == 1 and WSTR == 1 and KH == KW + + data_pad = topi.nn.pad(inputs, (0, 0, HPAD, WPAD), (0, 0, HPAD, WPAD), name="data_pad") + + r = KW + m = tile_size + alpha = m + r - 1 + A, B, G = winograd_transform_matrices(m, r, 'float32') + + H = (H + 2 * HPAD - KH) // HSTR + 1 + W = (W + 2 * WPAD - KW) // WSTR + 1 + nH, nW = (H + m - 1) // m, (W + m - 1) // m + P = N * nH * nW + r_kh = te.reduce_axis((0, KH), name='r_kh') + r_kw = te.reduce_axis((0, KW), name='r_kw') + # kernel_pack = te.compute((alpha, alpha, CI, CO), lambda eps, nu, ci, co: + # weight[0][0][0][0], + # name='kernel_pack') + kshape = (alpha, alpha, CI, CO) + kernel_pack = te.placeholder(kshape, inputs.dtype, name="weight") + + idxdiv = te.indexdiv + idxmod = te.indexmod + # pack input tile + input_tile = te.compute((CI, P, alpha, alpha), lambda ci, p, eps, nu: + data_pad[idxdiv(p, (nH * nW))][ci][idxmod(idxdiv(p, nW), nH) * m + eps] + [idxmod(p, nW) * m + nu], name='input_tile') + + # transform data + r_a = te.reduce_axis((0, alpha), 'r_a') + r_b = te.reduce_axis((0, alpha), 'r_b') + data_pack = te.compute((alpha, alpha, CI, P), lambda eps, nu, ci, p: + te.sum(input_tile[ci][p][r_a][r_b] * B[r_a][eps] * B[r_b][nu], + axis=[r_a, r_b]), name='data_pack', + attrs={"auto_scheduler_no_split_at_inner": ["eps", "nu", "r_a", "r_b"], + "auto_scheduler_no_split_at_outer": ["ci", "p"], + "auto_scheduler_always_unroll": ["eps", "nu", "r_a", "r_b"], + "auto_scheduler_no_cache_write": "True", + }) + + # do batch gemm + ci = te.reduce_axis((0, CI), name='ci') + bgemm = te.compute((alpha, alpha, CO, P), lambda eps, nu, co, p: + te.sum(data_pack[eps][nu][ci][p] * + kernel_pack[eps][nu][ci][co], + axis=[ci]), name='bgemm') + + # inverse transform + r_a = te.reduce_axis((0, alpha), 'r_a') + r_b = te.reduce_axis((0, alpha), 'r_b') + inverse = te.compute((CO, P, m, m), lambda co, p, vh, vw: + te.sum(bgemm[r_a][r_b][co][p] * A[r_a][vh] * A[r_b][vw], + axis=[r_a, r_b]), name='inverse', + attrs={"auto_scheduler_no_split_at_outer": ["co", "p", "vh", "vw", "r_a", "r_b"], + "auto_scheduler_always_unroll": ["vh", "vw", "r_a", "r_b"], + "auto_scheduler_no_cache_write": "True"}) + + # output + output = te.compute((N, CO, H, W), lambda n, co, h, w: + inverse[co, n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m), + idxmod(h, m), + idxmod(w, m)], + name='conv2d_winograd', + attrs={"auto_scheduler_no_split_at_outer": ["n", "co", "h", "w"],}) + return [inputs, kernel_pack, output] + +# ========================== Subgraphs ========================== + +@ansor.register_auto_scheduler_workload_func +def transpose_batch_matmul(batch, seq_len, n_head, n_dim): + query = te.placeholder((batch, seq_len, n_head, n_dim), name='query') + value = te.placeholder((batch, seq_len, n_head, n_dim), name='value') + query_T = te.compute((batch, n_head, seq_len, n_dim), + lambda b, h, l, d: query[b, l, h, d], name="query_T") + value_T = te.compute((batch, n_head, n_dim, seq_len), + lambda b, h, d, l: value[b, l, h, d], name="value_T") + k = te.reduce_axis((0, n_dim), name='k') + out = te.compute((batch, n_head, seq_len, seq_len), lambda b, h, i, j: te.sum(query_T[b][h][i][k] * value_T[b][h][k][j], axis=[k]), name='C') + return [query, value, out] + +@ansor.register_auto_scheduler_workload_func +def batch_norm(M, N, eps=1e-5): + A = te.placeholder((M, N), name='A') + 
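+    # The normalization is computed in two reduction passes: `mean` first,
+    # then `var`, which reads `mean`. Each te.compute takes its own
+    # reduce_axis (k1, k2) rather than sharing one across stages, and `var`
+    # uses the unbiased (M - 1) denominator.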
k1 = te.reduce_axis((0, M), name='k1')
+    k2 = te.reduce_axis((0, M), name='k2')
+    mean = te.compute((N,), lambda j: te.sum(A[k1][j] / M, axis=k1), name="mean")
+    var = te.compute((N,),
+                     lambda j: te.sum((A[k2][j] - mean[j]) * (A[k2][j] - mean[j]) / (M - 1), k2),
+                     name="var")
+    B = te.compute((M, N), lambda i, j: (A[i][j] - mean[j]) / te.sqrt(var[j] + eps), name='B')
+
+    return [A, B]
+
+# ========================== Tune func & Dicts ==========================
+
+def tune_wkl(task_func_dict, shape_dict, wkl_type, args):
+    target = tvm.target.create(args.target)
+
+    for wkl_meta_name, func in task_func_dict.items():
+        if not args.wkl in ["all", wkl_type, wkl_meta_name]:
+            continue
+
+        log_file = args.log_file or wkl_meta_name + ".json"
+        wkl_keys = []
+        for shape in shape_dict[wkl_meta_name]:
+            if shape[0] == 1:
+                shape = list(shape)
+                shape[0] = args.batch_size
+            wkl_key = ansor.make_workload_key_func(func, shape)
+
+            wkl_keys.append(wkl_key)
+            if args.fast_check:
+                break
+
+            if not args.tune:
+                cost, gflops = replay_workload(
+                    wkl_key, target, args.target_host, log_file,
+                    args.local_measure, args.device_key, args.host,
+                    args.port, args.ndk_cc, False)
+                # TODO(): Add log record
+                # log_line(BenchmarkRecord(target.name, 'gpu' if target.name == 'cuda' else 'cpu', 'subgraph',
+                #                          workload_name, "AutoSchedule", "default",
+                #                          {"costs": [cost]}, time.time()), args.out_file)
+
+        if args.tune:
+            print("========== Tune for %s (%d shapes) ========== " % (wkl_meta_name, len(wkl_keys)))
+
+            load_log_file = args.load_log or log_file
+            n_trials = args.n_trials_per_shape * len(wkl_keys)
+
+            tune_option, measure_ctx = create_tune_option(target, log_file,
+                n_trials, args.num_measure_per_iter, args.verbose,
+                args.n_parallel, args.build_timeout, args.local_measure,
+                args.device_key, args.host, args.port, args.ndk_cc)
+
+            # tune workloads jointly using JointTuner
+            tune_workloads_jointly(wkl_keys, np.ones(len(wkl_keys)), args.task_scheduler,
+                                   target, args.target_host, args.policy, args.model_type,
+                                   args.load_model, load_log_file, tune_option)
+
+            if measure_ctx:
+                del measure_ctx
+
+
+single_op_task_func_dict = {
+    'GMM': batch_matmul_nkkm,
+    'C1D': conv1d_nlc,
+    'C2D': conv2d_nhwc,
+    'C3D': conv3d_ndhwc,
+    'GRP': conv2d_nhwc,
+    'DIL': conv2d_nhwc,
+    'DEP': depthwise_conv2d_nhwc,
+    'T2D': conv2d_transpose_nhwc,
+    'CAP': conv2d_capsule_nhwijc,
+    'NRM': norm_bmn,
+    #'SMX': softmax_mn,
+
+# The following workloads are not in our single op evaluation plan.
+# They should be moved to `common.py` and be used by `tune_wkl.py`.
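+# Note: 'GRP' and 'DIL' deliberately reuse conv2d_nhwc, since grouping and
+# dilation are ordinary parameters of that workload function; they are
+# selected by the extra fields in group_conv2d_shapes / dilation_conv2d_shapes.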
+ +subgraph_task_func_dict = { + 'conv2d_bn_relu': conv2d_nhwc_bn_relu, +
#'conv2d_bn_relu': conv2d_nchw_bn_relu, # some old logs use conv2d_nchw_bn_relu +
'transpose_batch_matmul': transpose_batch_matmul, +} + +
if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Task related options +
parser.add_argument("--wkl", type=str, required=True, + help="all - For all workloads; \ +
op - For all single ops; \ + subgraph - For all subgraphs; \ + Or specific wkl name") +
parser.add_argument("--target", type=str, default='llvm -mcpu=core-avx2') +
parser.add_argument("--target-host", type=str, default=None) +
parser.add_argument("--n-trials-per-shape", type=int, default=1000) +
parser.add_argument("--num-measure-per-iter", type=int, default=48, +
help="The number of programs to be measured at each iteration") +
parser.add_argument("--batch-size", type=int, default=1) +
parser.add_argument("--tune", type=str2bool, nargs='?', const=True, default=True) +
parser.add_argument("--fast-check", action='store_true', +
help='Only run one shape for each workload. This is used for fast checking') + +
# Strategy related options + parser.add_argument("--seed", type=int, default=0, help='random seed') +
parser.add_argument("--policy", type=str, choices=['meta-rewrite', 'beam-search'], default='meta-rewrite') +
parser.add_argument("--model-type", type=str, choices=['xgb', 'random', 'no-share'], default='xgb') +
parser.add_argument("--task-scheduler", type=str, default='gradient', +
choices=['no', 'gradient', 'round-robin'], + help='The strategy of task scheduler') + +
# File related options +
parser.add_argument("--log-file", type=str, help="Write log of measurement results to this file") +
parser.add_argument("--load-model", type=str, help="Load pre-trained cost model file") +
parser.add_argument("--load-log", type=str, help="Load history log for pre-training the cost model") +
parser.add_argument("--out-file", type=str, default='results.tsv') + +
# Detailed control options + parser.add_argument("--build-timeout", type=int, default=10) +
parser.add_argument("--run-timeout", type=int, default=60) + parser.add_argument("--verbose", type=int, default=1) +
parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True) +
parser.add_argument("--device-key", type=str, default=None) +
parser.add_argument("--host", type=str, default='0.0.0.0') + parser.add_argument("--port", type=int, default=9190) +
parser.add_argument("--n-parallel", type=int, default=1) + parser.add_argument("--ndk-cc", type=str, default=None) +
args = parser.parse_args() + + np.random.seed(args.seed) + random.seed(args.seed) +
logging.basicConfig() + logging.getLogger('ansor').setLevel(logging.DEBUG) + +
# compute the number of tasks + num_tasks = 0 + for wkl_meta_name in single_op_task_func_dict: +
if args.wkl not in ["all", "op", wkl_meta_name]: + continue + if args.fast_check: + num_tasks += 1 +
else: + num_tasks += len(single_op_shape_dict[wkl_meta_name]) +
for wkl_meta_name in subgraph_task_func_dict: + if args.wkl not in ["all", "subgraph", wkl_meta_name]: +
continue + if args.fast_check: + num_tasks += 1 + else: +
num_tasks += len(subgraph_shape_dict[wkl_meta_name]) +
print("Number of tasks: %d\tTotal trials: %d" % (num_tasks, num_tasks * args.n_trials_per_shape)) + +
# tune for tasks + tune_wkl(single_op_task_func_dict, single_op_shape_dict, "op", args) +
tune_wkl(subgraph_task_func_dict, subgraph_shape_dict, "subgraph", args) diff --git a/scripts/tune_test.py b/scripts/tune_test.py index 08f0cc19ade2..d6f552affbb1 100644 --- a/scripts/tune_test.py +++ b/scripts/tune_test.py @@ -12,26 +12,61 @@ from common import get_workload_keys, get_workload_weights, measure_schedule, str2bool +def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose, + n_parallel, build_timeout, local_measure, device_key, host, + port, ndk_cc, early_stopping=-1): + builder = runner = measure_ctx = None + if local_measure: + builder = ansor.LocalBuilder(timeout=build_timeout) + if target.target_name == "cuda": + measure_ctx = ansor.LocalRPCMeasureContext(repeat=1, min_repeat_ms=400) + runner = measure_ctx.runner + else: + runner = ansor.LocalRunner(repeat=1, min_repeat_ms=400) + else: + os.environ['TVM_NDK_CC'] = ndk_cc + builder = ansor.LocalBuilder(timeout=build_timeout, build_func='ndk') + runner = ansor.RPCRunner(key=device_key, host=host, port=port, + n_parallel=n_parallel, repeat=1, min_repeat_ms=400) + + tune_option = ansor.TuneOption(n_trials=n_trials, early_stopping=early_stopping, + num_measure_per_iter=num_measure_per_iter, + verbose=verbose, + builder=builder, + runner=runner, + measure_callbacks=[ansor.LogToFile(log_file)], + pre_search_callbacks=[ansor.PreLoadMeasuredStates(log_file)]) + + return tune_option, measure_ctx + def replay_workload(wkl_key, target, target_host, log_file, local_measure=True, device_key=None, host="0.0.0.0", - port=9190, ndk_cc=None): + port=9190, ndk_cc=None, show_lower_result=True): + cost = gflops = None + inp, res = ansor.best_measure_pair_in_file(log_file, wkl_key, target) if inp is None: print("Cannot find log for: %s" % (wkl_key)) else: dag = ansor.workload_key_to_dag(inp.task.workload_key) + print("Found schedule for: %s" % (wkl_key)) + s, bufs = dag.apply_steps_from_state(inp.state) + if show_lower_result: + print(tvm.lower(s, bufs, simple_mode=True)) - print("Found schedule for: %s" % (wkl_key)) - print(tvm.lower(s, bufs, simple_mode=True)) if local_measure: remote = None else: remote = request_remote(device_key, host, port, 1) + cost = np.mean((measure_schedule(s, bufs, target, remote=remote, ndk_cc=ndk_cc))) + gflops = ansor.ComputeDAG(bufs).flop_ct / cost / 1e9 print("Best schedule: %.2f GFLOPS\tcost: %.3f ms" % - (ansor.ComputeDAG(bufs).flop_ct / cost / 1e9, cost * 1e3)) + (gflops, cost * 1e3)) + + return cost, gflops def tune_workload(wkl_key, target, target_host, policy, model_type, load_model_file, @@ -99,6 +134,7 @@ def objective_func(costs): parser.add_argument("--num-measure-per-iter", type=int, default=48, help="The number of programs to be measured at each iteration") parser.add_argument("--tune", type=str2bool, nargs='?', const=True, default=True) + # Strategy related options parser.add_argument("--seed", type=int, default=0, help='random seed') parser.add_argument("--policy", type=str, choices=['meta-rewrite', 'beam-search'], default='meta-rewrite') @@ -106,13 +142,15 @@ def objective_func(costs): parser.add_argument("--task-scheduler", type=str, default='no', choices=['no', 'gradient', 'round-robin'], help='The strategy of task scheduler') + # File related options parser.add_argument("--log-file", type=str, help="Write log of measurement results to this file") parser.add_argument("--load-model", type=str, help="Load pre trained cost model file") parser.add_argument("--load-log", type=str, help="Load history log for pre-training the cost model") + # Detailed control options 
parser.add_argument("--build-timeout", type=int, default=10) - parser.add_argument("--run-timeout", type=int, default=60) + parser.add_argument("--run-timeout", type=int, default=60) parser.add_argument("--verbose", type=int, default=1) parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True) parser.add_argument("--device-key", type=str, default=None) @@ -124,40 +162,21 @@ def objective_func(costs): np.random.seed(args.seed) random.seed(args.seed) - logging.basicConfig() logging.getLogger('ansor').setLevel(logging.DEBUG) - log_file = args.log_file or args.wkl + ".json" - - target = tvm.target.create(args.target) wkl_keys = get_workload_keys(args.wkl) + target = tvm.target.create(args.target) + log_file = args.log_file or args.wkl + ".json" if args.tune: load_log_file = args.load_log or log_file weights = get_workload_weights(args.wkl) - builder = runner = measure_ctx = None - if args.local_measure: - builder = ansor.LocalBuilder(timeout=args.build_timeout) - if target.target_name == "cuda": - measure_ctx = ansor.LocalRPCMeasureContext(repeat=1, min_repeat_ms=400) - runner = measure_ctx.runner - else: - runner = ansor.LocalRunner(repeat=1, min_repeat_ms=400) - else: - os.environ['TVM_NDK_CC'] = args.ndk_cc - builder = ansor.LocalBuilder(timeout=args.build_timeout, build_func='ndk') - runner = ansor.RPCRunner(args.device_key, host=args.host, port=args.port, - repeat=1, min_repeat_ms=400, n_parallel=args.n_parallel) - - tune_option = ansor.TuneOption(n_trials=args.n_trials, - num_measure_per_iter=args.num_measure_per_iter, - verbose=args.verbose, - builder=builder, - runner=runner, - measure_callbacks=[ansor.LogToFile(log_file)], - pre_search_callbacks=[ansor.PreLoadMeasuredStates(log_file)]) + tune_option, measure_ctx = create_tune_option(target, log_file, + args.n_trials, args.num_measure_per_iter, args.verbose, + args.n_parallel, args.build_timeout, args.local_measure, + args.device_key, args.host, args.port, args.ndk_cc) if args.task_scheduler == 'no': # tune workloads one by one diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc index d52b868e180d..c07a3af7473c 100644 --- a/src/ansor/search_policy/search_policy.cc +++ b/src/ansor/search_policy/search_policy.cc @@ -54,9 +54,11 @@ void SearchPolicyNode::PreLoadMeasuredStates(const std::string& log_file) { measured_states_set_.insert(state.ToStr()); } - StdCout(verbose_) << "Measured States Set: " - << measured_states_set_.size() + StdCout(verbose_) << "Measured States Set: " << measured_states_set_.size() << " state hashes loaded from " << log_file << std::endl; + } else { + StdCout(verbose_) << "Measured States Set: no states found from " + << log_file << std::endl; } } From 2f241ed4f83763979e827da2d6cca55c7f28cb77 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 12 Jun 2020 15:42:38 -0700 Subject: [PATCH 24/45] add explicit_unroll_max_extent (#25) --- src/tir/transforms/unroll_loop.cc | 19 ++++++++++++--- .../test_tir_transform_unroll_loop.py | 24 +++++++++++++++++++ 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index a15190665949..1c84304fb0e7 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -43,6 +43,7 @@ struct UnrollLoopConfigNode : public tvm::AttrsNode { int auto_max_depth; int auto_max_extent; int explicit_unroll; + int explicit_unroll_max_extent; TVM_DECLARE_ATTRS(UnrollLoopConfigNode, "tir.transform.UnrollLoopConfig") 
{ TVM_ATTR_FIELD(auto_max_step) @@ -57,6 +58,9 @@ struct UnrollLoopConfigNode : public tvm::AttrsNode { TVM_ATTR_FIELD(explicit_unroll) .describe("Whether to explicitly unroll the loop instead of setting a pragma") .set_default(true); + TVM_ATTR_FIELD(explicit_unroll_max_extent) + .describe("The maximum extent of a loop that can be unrolled explicitly (-1 means infinite)") + .set_default(32); } }; @@ -71,11 +75,12 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig); class LoopUnroller : public StmtExprMutator { public: explicit LoopUnroller(int auto_max_step, int auto_max_depth, int auto_max_extent, - bool explicit_unroll) + bool explicit_unroll, int explicit_unroll_max_extent) : auto_max_step_(auto_max_step), auto_max_depth_(auto_max_depth), auto_max_extent_(auto_max_extent), - explicit_unroll_(explicit_unroll) {} + explicit_unroll_(explicit_unroll), + explicit_unroll_max_extent_(explicit_unroll_max_extent) {} Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == "pragma_auto_unroll_max_step") { @@ -165,6 +170,11 @@ class LoopUnroller : public StmtExprMutator { // For loop must have a constant integer extent CHECK_NE(value, -1) << "loop doesn't have a constant integer extent"; if (value == 0) return Evaluate(0); + if (explicit_unroll_max_extent_ > 0 && value > explicit_unroll_max_extent_ && explicit_unroll_) { + // Do not unroll too long loops + ForType for_type = op->for_type == ForType::Unrolled ? ForType::Serial : op->for_type; + return ForNode::make(op->loop_var, op->min, op->extent, for_type, op->device_api, op->body); + } Stmt body = op->body; Map vmap; Array unrolled; @@ -197,7 +207,10 @@ class LoopUnroller : public StmtExprMutator { // max extent of loop to auto unroll // this not not count the total steps, only count the number of loops int auto_max_extent_; + // Whether to explicitly unroll the loop instead of setting a pragma bool explicit_unroll_; + // The maximum extent of a loop that can be unrolled explicitly (-1 means infinite) + int explicit_unroll_max_extent_; // Number of normal loops in scope int normal_loop_depth_{0}; // number of unrolled cases in current scope. 
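+ // (Note: when explicit_unroll_ is set and a loop extent exceeds
+ // explicit_unroll_max_extent_, VisitStmt_ above demotes ForType::Unrolled
+ // back to ForType::Serial instead of expanding the loop body.)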
@@ -210,7 +223,7 @@ class LoopUnroller : public StmtExprMutator { Stmt UnrollLoop(Stmt stmt, UnrollLoopConfig cfg) { Stmt ret = LoopUnroller(cfg->auto_max_step, cfg->auto_max_depth, cfg->auto_max_extent, - cfg->explicit_unroll)(stmt); + cfg->explicit_unroll, cfg->explicit_unroll_max_extent)(stmt); if (!ret.same_as(stmt)) { return ConvertSSA(ret); } else { diff --git a/tests/python/unittest/test_tir_transform_unroll_loop.py b/tests/python/unittest/test_tir_transform_unroll_loop.py index 68639940bb05..12c686634548 100644 --- a/tests/python/unittest/test_tir_transform_unroll_loop.py +++ b/tests/python/unittest/test_tir_transform_unroll_loop.py @@ -110,7 +110,31 @@ def test_unroll_single_count_loops(): ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body assert ret == stmt +def test_unroll_explicitly_max_extent(): + n = 64 + A = te.placeholder((n,), name='A') + B = te.compute((n,), lambda *i: A(*i), name='B') + s = te.create_schedule(B.op) + s = s.normalize() + dom_map = tvm.te.schedule.InferBound(s) + stmt = tvm.te.schedule.ScheduleOps(s, dom_map) + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + + with tvm.transform.PassContext(config={ + "tir.UnrollLoop": {"explicit_unroll_max_extent": n-1} + }): + ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body + assert tvm.ir.structural_equal(ret, stmt) + + with tvm.transform.PassContext(config={ + "tir.UnrollLoop": {"explicit_unroll_max_extent": n} + }): + ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body + assert not tvm.ir.structural_equal(ret, stmt) + + if __name__ == "__main__": test_unroll_loop() test_unroll_fake_loop() test_unroll_single_count_loops() + test_unroll_explicitly_max_extent() From 18d44b8cff0a7048e79394d9ef16da986ebc3ca5 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Mon, 15 Jun 2020 18:17:56 +0800 Subject: [PATCH 25/45] Add Index simplification & API update (#26) * Add vectorized cooperative_fetching test * Update math simplify for vectorized CF * File rename * Update tune_network * API update --- python/tvm/ansor/auto_schedule.py | 3 - python/tvm/ansor/compute_dag.py | 24 ++- python/tvm/ansor/feature.py | 10 +- python/tvm/ansor/loop_state.py | 14 +- python/tvm/ansor/measure.py | 2 +- python/tvm/ansor/serialization.py | 4 +- scripts/tune_network.py | 136 ++++++++-------- scripts/tune_test.py | 4 +- src/ansor/loop_state.cc | 5 - .../search_policy/meta_tile_rewrite_policy.cc | 3 + src/arith/rewrite_simplify.cc | 71 +++++++- tests/python/unittest/test_ansor_common.py | 2 +- .../python/unittest/test_ansor_compute_dag.py | 3 +- tests/python/unittest/test_ansor_feature.py | 4 +- tests/python/unittest/test_ansor_measure.py | 2 +- ...t_ansor_vectorized_cooperative_fetching.py | 152 ++++++++++++++++++ tutorials/ansor/tune_conv2d_cuda.py | 8 +- tutorials/ansor/tune_simple_subgraph.py | 8 +- 18 files changed, 344 insertions(+), 111 deletions(-) create mode 100644 tests/python/unittest/test_ansor_vectorized_cooperative_fetching.py diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index 09895302d25a..127be4c7ad22 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -245,8 +245,6 @@ def auto_schedule(workload, target=None, Returns ------- - state : State - sch : tvm.Schedule tensors : List[Tensor] @@ -270,4 +268,3 @@ def auto_schedule(workload, target=None, else: raise ValueError("Invalid workload: " + workload + ". 
Expect a string or SearchTask") - diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index 0c8aa2055482..23ba1b32f5c4 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -20,7 +20,7 @@ import tvm._ffi from tvm.runtime import Object from tvm import te -from .loop_state import State +from .loop_state import State, StateObject from . import _ffi_api @@ -63,8 +63,12 @@ def apply_steps_from_state(self, state, layout_rewrite_level=None): sch : Schedule args : List[Tensor] """ - sch, args = _ffi_api.ComputeDAGApplyStepsFromState(self, state) - return sch, args + if isinstance(state, State): + return _ffi_api.ComputeDAGApplyStepsFromState(self, state.state_object) + elif isinstance(state, StateObject): + return _ffi_api.ComputeDAGApplyStepsFromState(self, state) + else: + raise ValueError("The input must be a State or StateObject") def print_python_code_from_state(self, state): """ @@ -76,7 +80,12 @@ def print_python_code_from_state(self, state): ------- str : Str """ - return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state) + if isinstance(state, State): + return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state.state_object) + elif isinstance(state, StateObject): + return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state) + else: + raise ValueError("The input must be a State or StateObject") def infer_bound_from_state(self, state): """ @@ -88,7 +97,12 @@ def infer_bound_from_state(self, state): ------- state : StateObject """ - return _ffi_api.ComputeDAGInferBoundFromState(self, state) + if isinstance(state, State): + return State(_ffi_api.ComputeDAGInferBoundFromState(self, state.state_object)) + elif isinstance(state, StateObject): + return State(_ffi_api.ComputeDAGInferBoundFromState(self, state)) + else: + raise ValueError("The input must be a State or StateObject") def gen_schedule(state, bufs): if not state or not state.complete: diff --git a/python/tvm/ansor/feature.py b/python/tvm/ansor/feature.py index f91d7da169f5..4f9fdeb9e6cd 100644 --- a/python/tvm/ansor/feature.py +++ b/python/tvm/ansor/feature.py @@ -23,7 +23,7 @@ import struct import numpy as np -from .loop_state import StateObject +from .loop_state import State, StateObject from .measure import MeasureInput, MeasureResult from . 
import _ffi_api @@ -131,12 +131,16 @@ def get_per_stmt_features_from_measure_pairs(inputs: List[MeasureInput], return unpack_feature(byte_arr) -def get_per_stmt_features_from_states(states: List[StateObject], +def get_per_stmt_features_from_states(states, task: "SearchTask", max_n_bufs: int = None) -> List[np.ndarray]: """Get per_stmt features from states""" + if isinstance(states[0], State): + state_objects = [s.state_object for s in states] + elif isinstance(states[0], StateObject): + state_objects = states byte_arr = _ffi_api.GetPerStmtFeaturesFromStates( - states, task, max_n_bufs or DEFAULT_MAX_N_BUFS) + state_objects, task, max_n_bufs or DEFAULT_MAX_N_BUFS) return unpack_feature(byte_arr)[0] diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py index 557bb9d3102b..0cf157147423 100644 --- a/python/tvm/ansor/loop_state.py +++ b/python/tvm/ansor/loop_state.py @@ -60,18 +60,6 @@ def iters(self): setattr(self, "iterators_cache", _ffi_api.StageGetIterators(self)) return getattr(self, "iterators_cache") - def iter(self, index): - """ - Parameters - ---------- - index : Int - - Returns - ------- - iter : Iterator - """ - return _ffi_api.StageGetIterator(self, index) - @tvm._ffi.register_object("ansor.State") class StateObject(Object): @@ -302,7 +290,7 @@ def bind_thread(self, stage_id, it, thread_name): } thread_id = trans_table[thread_name] - self.state_object, res = _ffi_api.StateUnroll(self.state_object, stage_id, it, thread_id) + self.state_object, res = _ffi_api.StateBindThread(self.state_object, stage_id, it, thread_id) self.clear_cache() return res diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index 610e9529090f..b82327ec67c4 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -62,7 +62,7 @@ class MeasureInput(Object): """ def __init__(self, task, state): - self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state) + self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state.state_object) @tvm._ffi.register_object("ansor.BuildResult") diff --git a/python/tvm/ansor/serialization.py b/python/tvm/ansor/serialization.py index 3d7ed7733a78..e11a589a7522 100644 --- a/python/tvm/ansor/serialization.py +++ b/python/tvm/ansor/serialization.py @@ -22,6 +22,7 @@ import tvm._ffi from tvm.runtime import Object from .measure import MeasureCallback, MeasureErrorNo +from .loop_state import State from . 
import _ffi_api @@ -74,7 +75,8 @@ def write_measure_records_to_file(filename, inputs, results):
def get_states_from_measure_inputs(inputs, task): + """Get states from measure inputs""" -
return _ffi_api.GetStatesFromMeasureInputs(inputs, task) +
state_objects = _ffi_api.GetStatesFromMeasureInputs(inputs, task) +
return [State(s) for s in state_objects] def best_measure_pair_in_file(filename, workload_key=None, target=None):
diff --git a/scripts/tune_network.py index f1f7cd54f8c6..5f22e31d50f7 100644 --- a/scripts/tune_network.py
+++ b/scripts/tune_network.py @@ -191,22 +191,14 @@ def create_module(data_shape, graph, lib, target, input_name, params, debug_prof
return module, ctx -def tune_and_evaluate(network_name, model_path, batch_size, target, target_host, -
local_measure, device_key, host, port, n_parallel, ndk_cc, -
build_timeout, run_timeout, num_threads, tune, check_correctness, -
debug_profile, tuning_parameters, record_file, layout_set): -
task_scheduler, model_type, policy, log_file, load_log_file = (tuning_parameters['task_scheduler'], -
tuning_parameters['model_type'], tuning_parameters['policy'], -
tuning_parameters['log_file'], tuning_parameters['load_log_file']) - - if layout_set: - layout = layout_set -
+def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, +
debug_profile, check_correctness, network_parameters, +
task_scheduler_parameters, tune_parameters, module_parameters): # Extract workloads from relay program -
print("=============== Extract workloads ===============") -
mod, params, input_name, data_shape, out_shape = get_network(network_name, model_path, batch_size, layout) +
mod, params, input_name, data_shape, out_shape = get_network(**network_parameters) if tune: +
print("=============== Extracting workloads ===============")
workloads, wkl_weights = ansor.extract_from_program(mod, target=target, params=params,
ops=(relay.op.nn.dense, relay.op.nn.softmax, relay.op.nn.conv2d, relay.op.nn.conv2d_transpose,
@@ -215,7 +207,7 @@ def tune_and_evaluate(network_name, model_path, batch_size, target, target_host,
relay.op.nn.conv3d, relay.op.nn.adaptive_avg_pool3d, relay.op.nn.batch_matmul, relay.op.mean, )) -
print("Total workload number: %d" % (len(workloads))) + print("Extracted %d workloads in total."
% (len(workloads))) # Tune workloads with auto scheduler print("=============== Tuning ===============") @@ -225,23 +217,13 @@ def tune_and_evaluate(network_name, model_path, batch_size, target, target_host, print("[========= Task %d =========]\n" % i, dag) tasks.append(ansor.SearchTask(dag, wkl_key, target, target_host)) - def objective_func(costs): - return sum(c * w for c, w in zip(costs, wkl_weights)) - - tuner = ansor.SimpleTaskScheduler(tasks, objective_func, strategy=task_scheduler, - load_log_file=load_log_file, - load_model_file=tuning_parameters['load_model']) + tuner = ansor.SimpleTaskScheduler(tasks, + lambda costs: sum(c * w for c, w in zip(costs, wkl_weights)), + **task_scheduler_parameters) + tune_option, measure_ctx = create_tune_option(target, **tune_parameters) - tune_option, measure_ctx = create_tune_option(target, log_file, - tuning_parameters['n_trials'], tuning_parameters['num_measure_per_iter'], - tuning_parameters['verbose'], n_parallel, build_timeout, - local_measure, device_key, host, port, ndk_cc, - tuning_parameters['early_stopping']) - search_policy = "%s.%s" % (policy, model_type) - - if local_measure and target.target_name != 'cuda': + if tune_parameters['local_measure'] and target.target_name != 'cuda': os.environ['TVM_BIND_MASTER_CORE_0'] = "1" - tuner.tune(tune_option, search_policy) if measure_ctx: @@ -251,15 +233,13 @@ def objective_func(costs): # Compile graph with best states found by auto-scheduler print("=============== Compile ===============") - with ansor.apply_history_best(log_file, args.log_n_lines): - #if True: - #with ansor.BlockingEmptyContext(): + with ansor.apply_history_best(tune_parameters['log_file'], log_n_lines): os.environ['TVM_AUTO_CACHE_FLUSH'] = "0" os.environ['TVM_BIND_MASTER_CORE_0'] = "1" if kernel_layout_rewrite: ansor.prepare_layout_rewrite(mod, target=target, - params=params, - ops=(relay.op.nn.dense, relay.op.nn.conv2d, relay.op.nn.conv3d)) + params=params, + ops=(relay.op.nn.dense, relay.op.nn.conv2d, relay.op.nn.conv3d)) else: # disable layout rewrite ansor.LayoutRewriteLevel.BOTH_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE @@ -268,22 +248,12 @@ def objective_func(costs): with relay.build_config(opt_level=3): graph, lib, opt_params = relay.build_module.build( mod, target=target, params=params) - ''' - from tvm.relay.backend import graph_runtime_codegen - with relay.build_config(opt_level=3): - opt_mod, _ = relay.optimize(mod, target, params) - grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target) - grc.codegen(opt_mod["main"]) - with tvm.transform.PassContext(opt_level=3): - graph, lib, opt_params = relay.build_module.build( - mod, target=target, params=params) - ''' + ansor.finish_layout_rewrite() print("=============== Compile Finish ===============") - module, ctx = create_module(data_shape, graph, lib, target, input_name, opt_params, - debug_profile, local_measure, ndk_cc, - device_key, host, port, run_timeout, num_threads) + module, ctx = create_module(data_shape, graph, lib, target, input_name, + opt_params, debug_profile, **module_parameters) # Evaluate print("========== Evaluate ==========") @@ -315,9 +285,8 @@ def objective_func(costs): graph, lib, opt_params = relay.build_module.build( mod, target=target, params=params) - module, _ = create_module(data_shape, graph, lib, target, input_name, opt_params, - debug_profile, local_measure, ndk_cc, - device_key, host, port, run_timeout, num_threads) + module, _ = create_module(data_shape, graph, lib, target, input_name, + opt_params, debug_profile, 
**module_parameters) module.run() expected_output = module.get_output(0).asnumpy() @@ -343,7 +312,7 @@ def objective_func(costs): # Strategy related options parser.add_argument("--seed", type=int, default=0, help='random seed') parser.add_argument("--policy", type=str, choices=['multi-stage', 'meta-rewrite'], - default='meta-rewrite') + default='meta-rewrite') parser.add_argument("--model-type", type=str, choices=['xgb', 'random', 'no-share'], default='xgb') parser.add_argument("--task-scheduler", type=str, default='gradient', choices=['no', 'gradient', 'round-robin'], @@ -359,6 +328,7 @@ def objective_func(costs): # Detailed control options parser.add_argument("--build-timeout", type=int, default=10) parser.add_argument("--run-timeout", type=int, default=10) + parser.add_argument("--early-stopping", type=int, default=-1) parser.add_argument("--verbose", type=int, default=1) parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True) parser.add_argument("--device-key", type=str, default=None) @@ -375,23 +345,59 @@ def objective_func(costs): logging.getLogger('ansor').setLevel(logging.DEBUG) target = tvm.target.create(args.target) + log_file = args.log_file or "%s-B%d-%s.json" % (args.network, args.batch_size, + target.target_name) + load_log_file = args.load_log or log_file + search_policy = "%s.%s" % (args.policy, args.model_type) + if args.layout: + layout = args.layout + elif target.target_name == "cuda": + layout = "NCHW" + else: + layout = "NHWC" + + network_parameters = { + 'name': args.network, + 'model_path': args.model_path, + 'batch_size': args.batch_size, + 'layout': layout + } + + task_scheduler_parameters = { + 'strategy': args.task_scheduler, + 'load_log_file': load_log_file, + 'load_model_file': args.load_model, + 'verbose': args.verbose, + } - tuning_parameters = { + control_parameters = { + 'local_measure': args.local_measure, + 'device_key': args.device_key, + 'host': args.host, + 'port': args.port, + 'ndk_cc': args.ndk_cc, + } + + tune_parameters = { + 'log_file': log_file, 'n_trials': args.n_trials, 'num_measure_per_iter': args.num_measure_per_iter, - 'log_file': args.log_file or "%s-B%d.json" % (args.network, args.batch_size), - 'load_model': args.load_model, - 'model_type': args.model_type, - 'task_scheduler': args.task_scheduler, - 'policy': args.policy, - 'early_stopping': -1, - 'verbose': 1, + 'verbose': args.verbose, + 'n_parallel': args.n_parallel, + 'build_timeout': args.build_timeout, + 'run_timeout': args.run_timeout, + 'early_stopping': args.early_stopping, + **control_parameters + } + + module_parameters = { + 'run_timeout': args.run_timeout, + 'num_threads': args.num_threads, + **control_parameters } - tuning_parameters['load_log_file'] = args.load_log or tuning_parameters['log_file'] os.environ["TOPHUB_LOCATION"] = "NONE" - tune_and_evaluate(args.network, args.model_path, args.batch_size, target, args.target_host, - args.local_measure, args.device_key, args.host, - args.port, args.n_parallel, args.ndk_cc, args.build_timeout, - args.run_timeout, args.num_threads, args.tune, args.check_correctness, - args.debug_profile, tuning_parameters, args.out_file, args.layout) + tune_and_evaluate(target, args.target_host, args.log_n_lines, search_policy, + args.tune, args.debug_profile, args.check_correctness, + network_parameters, task_scheduler_parameters, tune_parameters, + module_parameters) diff --git a/scripts/tune_test.py b/scripts/tune_test.py index d6f552affbb1..a49ecd088afc 100644 --- a/scripts/tune_test.py +++ 
b/scripts/tune_test.py @@ -14,7 +14,7 @@ def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose, n_parallel, build_timeout, local_measure, device_key, host, - port, ndk_cc, early_stopping=-1): + port, ndk_cc, early_stopping=-1, run_timeout=10): builder = runner = measure_ctx = None if local_measure: builder = ansor.LocalBuilder(timeout=build_timeout) @@ -26,7 +26,7 @@ def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose else: os.environ['TVM_NDK_CC'] = ndk_cc builder = ansor.LocalBuilder(timeout=build_timeout, build_func='ndk') - runner = ansor.RPCRunner(key=device_key, host=host, port=port, + runner = ansor.RPCRunner(key=device_key, host=host, port=port, timeout=run_timeout, n_parallel=n_parallel, repeat=1, min_repeat_ms=400) tune_option = ansor.TuneOption(n_trials=n_trials, early_stopping=early_stopping, diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index faaac94f3323..77361dbf837c 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -1019,11 +1019,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) PrintState(&p->stream, node, true); }); - -TVM_REGISTER_GLOBAL("ansor.StageGetIterator").set_body_typed([](const Stage& stage, int index) { - return stage->iters[index]; -}); - TVM_REGISTER_GLOBAL("ansor.StageGetIterators").set_body_typed([](const Stage& stage) { return Array(stage->iters); }); diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.cc b/src/ansor/search_policy/meta_tile_rewrite_policy.cc index 5703e17ba29f..4a045d31a487 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.cc +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.cc @@ -1267,6 +1267,9 @@ void MetaTileRewritePolicyNode::SampleInitPopulation(const std::vector& m if (InitPopulationThreadBind(this, &tmp_s)) { continue_count++; + if (continue_count == out_size) { + StdCout(verbose_) << "Initial Population Sampling..." 
<< std::endl; + } continue; } diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 4887ef0ee47d..d3af64a4f576 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -132,6 +132,13 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { TVM_TRY_REWRITE(ramp(b1, s1, lanes) + broadcast(x, lanes), ramp(b1 + x, s1, lanes)); TVM_TRY_REWRITE(broadcast(x, lanes) + ramp(b1, s1, lanes), ramp(x + b1, s1, lanes)); TVM_TRY_REWRITE(broadcast(x, lanes) + broadcast(y, lanes), broadcast(x + y, lanes)); + if ((x + broadcast(y, lanes)).Match(ret)) { + if (auto ps = y.Eval().as()) { + if (ps->value == 0.0) { + return x.Eval(); + } + } + } } if (IsIndexType(op->dtype)) { @@ -422,6 +429,13 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) { TVM_TRY_REWRITE(broadcast(x, lanes) * broadcast(y, lanes), broadcast(x * y, lanes)); TVM_TRY_REWRITE(ramp(b1, s1, lanes) * broadcast(x, lanes), ramp(b1 * x, s1 * x, lanes)); TVM_TRY_REWRITE(broadcast(x, lanes) * ramp(b1, s1, lanes), ramp(b1 * x, s1 * x, lanes)); + if ((broadcast(x, lanes) * y).Match(ret)) { + if (auto ps = x.Eval().as()) { + if (ps->value == 0.0) { + return make_const(op->dtype, 0.0); + } + } + } } if (IsIndexType(op->dtype)) { @@ -700,9 +714,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression - PVar x, y, z, b1; + PVar w, x, y, z, b1; // Pattern var match IntImm - PVar c1, c2, c3; + PVar c1, c2, c3, c4; // Pattern var for lanes in broadcast and ramp PVar lanes; @@ -767,6 +781,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE_IF(floordiv(max(y, x * c1), c2), max(floordiv(y, c2), x * floordiv(c1, c2)), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), floordiv(x * c1, c2), + c1.Eval()->value > 0 && c2.Eval()->value > 0 && + c2.Eval()->value % c1.Eval()->value == 0 && + CanProveGreaterEqual(-y.Eval(), -c1.Eval()->value + 1)); + // Rules involving 3-operands. 
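+ // (For the new 2-operand rule just above: with c2 % c1 == 0 and
+ // 0 <= y <= c1 - 1, x * c1 + y can never cross a multiple of c2, e.g.
+ // floordiv(x * 4 + 3, 16) == floordiv(x * 4, 16).)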
TVM_TRY_REWRITE_IF(floordiv(x * c1 + y + z, c2), x * floordiv(c1, c2) + floordiv(y + z, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); @@ -783,6 +802,13 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE_IF(floordiv(x + c1, c2), floordiv(x, c2) + floordiv(c1, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floordiv(x * c1 + y * c2 + z, c3), floordiv(x * c1 + y * c2, c3), + c1.Eval()->value > 0 && c2.Eval()->value > 0 && c3.Eval()->value > 0 && + c3.Eval()->value % c1.Eval()->value == 0 && + c3.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(-z.Eval(), + std::max(-c1.Eval()->value, -c2.Eval()->value) + 1)); + TVM_TRY_REWRITE_IF(floordiv(x + y, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); TVM_TRY_REWRITE_IF(floordiv(y + x, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); @@ -807,6 +833,18 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { CanProveGreaterEqual(z.Eval(), 0)); TVM_TRY_REWRITE_IF(floordiv(y + z * x, z), floordiv(y, z) + x, CanProveGreaterEqual(z.Eval(), 0)); + + // Rules involving 4-operands + TVM_TRY_REWRITE_IF(floordiv(w * c1 + x * c2 + y * c3 + z, c4), + floordiv(w * c1 + x * c2 + y * c3, c4), + c1.Eval()->value > 0 && c2.Eval()->value > 0 && + c3.Eval()->value > 0 && c4.Eval()->value > 0 && + c4.Eval()->value % c1.Eval()->value == 0 && + c4.Eval()->value % c2.Eval()->value == 0 && + c4.Eval()->value % c3.Eval()->value == 0 && + CanProveGreaterEqual(-z.Eval(), + std::max(-c1.Eval()->value, + std::max(-c2.Eval()->value, -c3.Eval()->value)) + 1)); } return ret; } @@ -818,9 +856,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { if (const_res.defined()) return const_res; // Pattern var to match any expression - PVar x, y, z, b1; + PVar w, x, y, z, b1; // Pattern var match IntImm - PVar c1, c2; + PVar c1, c2, c3, c4; // Pattern var for lanes in broadcast and ramp PVar lanes; @@ -864,6 +902,31 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x, floordiv(c2, c1)) * c1 + y, + c1.Eval()->value > 0 && c2.Eval()->value > 0 && + c2.Eval()->value % c1.Eval()->value == 0 && + CanProveGreaterEqual(-y.Eval(), -c1.Eval()->value + 1)); + + // TODO(jcf94): For the next three rules, better use the max common factor + // of c1, c2, c3 to do the simplify + TVM_TRY_REWRITE_IF(floormod(x * c1 + y * c2 + z, c3), + floormod(x * floordiv(c1, c2) + y, floordiv(c3, c2)) * c2 + z, + c1.Eval()->value > 0 && c2.Eval()->value > 0 && + c3.Eval()->value > 0 && + c3.Eval()->value % c2.Eval()->value == 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(-z.Eval(), -c2.Eval()->value + 1)); + + TVM_TRY_REWRITE_IF(floormod(w * c1 + x * c2 + y * c3 + z, c4), + floormod(w * floordiv(c1, c3) + x * floordiv(c2, c3) + y, + floordiv(c4, c3)) * c3 + z, + c1.Eval()->value > 0 && c2.Eval()->value > 0 && + c3.Eval()->value > 0 && c4.Eval()->value > 0 && + c4.Eval()->value % c3.Eval()->value == 0 && + c1.Eval()->value % c3.Eval()->value == 0 && + c2.Eval()->value % c3.Eval()->value == 0 && + CanProveGreaterEqual(-z.Eval(), -c3.Eval()->value + 1)); + // try modular analysis if (floormod(x, c1).Match(ret)) { ModularSet mod = analyzer_->modular_set(x.Eval()); diff --git 
a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py index 1790b06bcb60..e23dba2aa4e3 100644 --- a/tests/python/unittest/test_ansor_common.py +++ b/tests/python/unittest/test_ansor_common.py @@ -77,5 +77,5 @@ def get_tiled_matmul(): C += 1 C_global += 1 s0.compute_at(A_global, C_global, s0.stages[C_global].iters[2]) - return dag, s0.state_object + return dag, s0 diff --git a/tests/python/unittest/test_ansor_compute_dag.py b/tests/python/unittest/test_ansor_compute_dag.py index b60136d4265f..b8afcc6a5b23 100644 --- a/tests/python/unittest/test_ansor_compute_dag.py +++ b/tests/python/unittest/test_ansor_compute_dag.py @@ -33,7 +33,6 @@ def test_apply_steps(): def test_infer_bound(): dag, s = get_tiled_matmul() s = dag.infer_bound_from_state(s) - s = ansor.loop_state.State(s) A_global, B_global, C_global = 1, 3, 4 assert s.stages[B_global].iters[0].range.extent == 512 @@ -62,7 +61,7 @@ def test_lower_legalize_invalid_attach(): s.compute_at(A, B, s.stages[B].iters[1]) s.split(B, s.stages[B].iters[1], [2]) - sch, tensors = dag.apply_steps_from_state(s.state_object) + sch, tensors = dag.apply_steps_from_state(s) stmt = tvm.lower(sch, tensors, simple_mode=True) diff --git a/tests/python/unittest/test_ansor_feature.py b/tests/python/unittest/test_ansor_feature.py index 3da1c7aa332e..567fc080c6f8 100644 --- a/tests/python/unittest/test_ansor_feature.py +++ b/tests/python/unittest/test_ansor_feature.py @@ -47,7 +47,7 @@ def test_cpu_matmul(): target = tvm.target.create('llvm') task = ansor.SearchTask(dag, "test", target) names = ansor.feature.get_per_stmt_feature_names() - fea = ansor.feature.get_per_stmt_features_from_states([s.state_object], task)[0] + fea = ansor.feature.get_per_stmt_features_from_states([s], task)[0] stage_0 = fea[0] assert len(stage_0) == len(names), "%d vs %d" % (len(stage_0), len(names)) @@ -91,7 +91,7 @@ def fusion_test(N, M): target = tvm.target.create('llvm') task = ansor.SearchTask(dag, "test", target) names = ansor.feature.get_per_stmt_feature_names() - fea = ansor.feature.get_per_stmt_features_from_states([s.state_object], task)[0] + fea = ansor.feature.get_per_stmt_features_from_states([s], task)[0] found = False for stage_fea in fea: diff --git a/tests/python/unittest/test_ansor_measure.py b/tests/python/unittest/test_ansor_measure.py index 2ac54d3c765b..d457dd2c55cc 100644 --- a/tests/python/unittest/test_ansor_measure.py +++ b/tests/python/unittest/test_ansor_measure.py @@ -43,7 +43,7 @@ def test_serialization(): s2 = dag.infer_bound_from_state(inputs[0].state) assert s1 == s2 - assert not (s1 == dag.get_init_state().state_object) + assert not (s1 == dag.get_init_state()) def test_measure_local_builder_runner(): diff --git a/tests/python/unittest/test_ansor_vectorized_cooperative_fetching.py b/tests/python/unittest/test_ansor_vectorized_cooperative_fetching.py new file mode 100644 index 000000000000..c41abc7bcb3d --- /dev/null +++ b/tests/python/unittest/test_ansor_vectorized_cooperative_fetching.py @@ -0,0 +1,152 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" Test for vectorized cooperative fetching """ + +import numpy as np +import tvm +from tvm import ansor, te +import topi + +from test_ansor_common import matmul_ansor_test, conv2d_nchw_bn_relu + + +def init_common(): + dag = ansor.ComputeDAG(matmul_ansor_test(512, 512, 512)) + s0 = dag.get_init_state() + A, B, C = 0, 1, 2 + B_shared = s0.cache_read(B, "shared", [C], dag) + C += 1 + B_local = s0.cache_read(B_shared, "local", [C], dag) + C += 1 + A_shared = s0.cache_read(A, "shared", [C], dag) + B += 1 + B_shared += 1 + B_local += 1 + C += 1 + A_local = s0.cache_read(A_shared, "local", [C], dag) + B += 1 + B_shared += 1 + B_local += 1 + C += 1 + + return A_shared, A_local, B_shared, B_local, C, dag, s0 + +def check_common(dag, state): + s, args = dag.apply_steps_from_state(state) + # To check if every vectorize loop transforms to ramp expr successfully + # TODO(jcf94): Find a better way to process the check in AST + print(tvm.lower(s, args)) + + if tvm.context("cuda", 0).exist: + tgt = tvm.target.cuda() + mod = tvm.build(s, args, tgt) + # To check if every vectorize loop transforms to correct instruction + print(mod.imported_modules[0].get_source()) + + ctx = tvm.context("cuda", 0) + dtype = dag.tensors[0].dtype + a = tvm.nd.array(np.random.uniform(size=(512, 512)).astype(dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=(512, 512)).astype(dtype), ctx) + c = tvm.nd.array(np.zeros((512, 512), dtype=dtype), ctx) + mod(a, b, c) + tvm.testing.assert_allclose(c.asnumpy(), np.dot( + a.asnumpy(), b.asnumpy()), rtol=1e-5) + else: + print("CUDA device not found, skip this test.") + +def test_vectorized_cooperative_fetching_x(): + A_shared, A_local, B_shared, B_local, C, dag, s0 = init_common() + + its0 = s0.split(C, s0.stages[C].iters[0], [1, 8, 2, 4]) + its1 = s0.split(C, s0.stages[C].iters[5], [2, 8, 2, 4]) + its2 = s0.split(C, s0.stages[C].iters[10], [8, 8]) + s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its2[0], + its2[1], its0[3], its1[3], its2[2], its0[4], its1[4]]) + s0.fuse(C, [s0.stages[C].iters[0], s0.stages[C].iters[1]]) + s0.bind_thread(C, s0.stages[C].iters[0], "blockIdx.x") + s0.fuse(C, [s0.stages[C].iters[1], s0.stages[C].iters[2]]) + s0.bind_thread(C, s0.stages[C].iters[1], "vthread") + s0.fuse(C, [s0.stages[C].iters[2], s0.stages[C].iters[3]]) + s0.bind_thread(C, s0.stages[C].iters[2], "threadIdx.x") + s0.vectorize(C, its1[4]) + + s0.compute_at(B_shared, C, s0.stages[C].iters[3]) + fused_it = s0.fuse(B_shared, s0.stages[B_shared].iters[:]) + its = s0.split(B_shared, fused_it, [64, 4]) + s0.bind_thread(B_shared, its[1], "threadIdx.x") + s0.vectorize(B_shared, its[2]) + s0.compute_at(B_local, C, s0.stages[C].iters[4]) + fused_it = s0.fuse(B_local, s0.stages[B_local].iters[:]) + its = s0.split(B_local, fused_it, [4]) + s0.vectorize(B_local, its[1]) + + s0.compute_at(A_shared, C, s0.stages[C].iters[3]) + fused_it = s0.fuse(A_shared, s0.stages[A_shared].iters[:]) + its = s0.split(A_shared, fused_it, [64, 4]) + s0.bind_thread(A_shared, its[1], "threadIdx.x") + s0.vectorize(A_shared, its[2]) + s0.compute_at(A_local, C, s0.stages[C].iters[4]) + 
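+ # Same pattern as B_local above: fuse all axes of the cache-read stage,
+ # split out an innermost lane of 4, and vectorize it so the local load
+ # lowers to a ramp (vectorized) access -- the behavior under test here.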
fused_it = s0.fuse(A_local, s0.stages[A_local].iters[:]) + its = s0.split(A_local, fused_it, [4]) +
s0.vectorize(A_local, its[1]) + + check_common(dag, s0) + +
def test_vectorized_cooperative_fetching_xy(): +
A_shared, A_local, B_shared, B_local, C, dag, s0 = init_common() + +
its0 = s0.split(C, s0.stages[C].iters[0], [1, 8, 2, 4]) + its1 = s0.split(C, s0.stages[C].iters[5], [2, 8, 2, 4]) +
its2 = s0.split(C, s0.stages[C].iters[10], [8, 8]) +
s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its2[0], +
its2[1], its0[3], its1[3], its2[2], its0[4], its1[4]]) +
s0.fuse(C, [s0.stages[C].iters[0], s0.stages[C].iters[1]]) +
s0.bind_thread(C, s0.stages[C].iters[0], "blockIdx.x") +
s0.fuse(C, [s0.stages[C].iters[1], s0.stages[C].iters[2]]) +
s0.bind_thread(C, s0.stages[C].iters[1], "vthread") +
s0.bind_thread(C, s0.stages[C].iters[2], "threadIdx.x") +
s0.bind_thread(C, s0.stages[C].iters[3], "threadIdx.y") + s0.vectorize(C, its1[4]) + +
s0.compute_at(B_shared, C, s0.stages[C].iters[4]) +
fused_it = s0.fuse(B_shared, s0.stages[B_shared].iters[:]) + its = s0.split(B_shared, fused_it, [8, 8, 4]) +
s0.bind_thread(B_shared, its[1], "threadIdx.x") + s0.bind_thread(B_shared, its[2], "threadIdx.y") +
s0.vectorize(B_shared, its[3]) + s0.compute_at(B_local, C, s0.stages[C].iters[5]) +
fused_it = s0.fuse(B_local, s0.stages[B_local].iters[:]) + its = s0.split(B_local, fused_it, [4]) +
s0.vectorize(B_local, its[1]) + + s0.compute_at(A_shared, C, s0.stages[C].iters[4]) +
fused_it = s0.fuse(A_shared, s0.stages[A_shared].iters[:]) + its = s0.split(A_shared, fused_it, [8, 8, 4]) +
s0.bind_thread(A_shared, its[1], "threadIdx.x") + s0.bind_thread(A_shared, its[2], "threadIdx.y") +
s0.vectorize(A_shared, its[3]) + s0.compute_at(A_local, C, s0.stages[C].iters[5]) +
fused_it = s0.fuse(A_local, s0.stages[A_local].iters[:]) + its = s0.split(A_local, fused_it, [4]) +
s0.vectorize(A_local, its[1]) + + check_common(dag, s0) + +
if __name__ == "__main__": + test_vectorized_cooperative_fetching_x() + test_vectorized_cooperative_fetching_xy()
diff --git a/tutorials/ansor/tune_conv2d_cuda.py index caa040d1b3bc..14a6ee797276 ---
a/tutorials/ansor/tune_conv2d_cuda.py +++ b/tutorials/ansor/tune_conv2d_cuda.py @@ -122,11 +122,17 @@
def conv2d_nchw(N, H, W, CO, CI, KH, KW, stride, padding):
# During the searching process, we may generate several invalid schedules and they
# will be filtered out. It's fine to see "Encountered errors during feature extraction."
# in the tuning logs.
+# The :code:`ansor.LogToFile` callback will log the tuning results into a
+# log file, which can be used to get the best config later.
+# The :code:`ansor.PreLoadMeasuredStates` callback will load measured states
+# from the history log before the schedule search; adding this callback makes
+# sure the same schedule is never measured multiple times.
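+# (:code:`ansor.LocalRPCMeasureContext` below provides a local RPC-based
+# runner; its :code:`runner` attribute is passed to the TuneOption so that
+# measurements go through the same flow as remote RPC measurement.)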
measure_ctx = ansor.LocalRPCMeasureContext(repeat=3, min_repeat_ms=100, timeout=4)
tune_option = ansor.TuneOption(n_trials=20, runner=measure_ctx.runner, -
measure_callbacks=[ansor.LogToFile(log_file)]) + measure_callbacks=[ansor.LogToFile(log_file)], +
pre_search_callbacks=[ansor.PreLoadMeasuredStates(log_file)])
s, arg_bufs = ansor.auto_schedule(task, search_policy=search_policy, tune_option=tune_option)
print("==== Get Lowered Stmt ====") diff --git a/tutorials/ansor/tune_simple_subgraph.py index
fedbb399d0cf..dfd36e89fd4c --- a/tutorials/ansor/tune_simple_subgraph.py +++
b/tutorials/ansor/tune_simple_subgraph.py @@ -146,8 +146,11 @@ def matmul_add(N, L, M, dtype): #
# We only make 5 trials in this tutorial for demonstration. In practice,
# you can do more trials according to your time budget.
-# The :code:`ansor.LogToFile` callback will log the tuning results into a
+# :code:`ansor.LogToFile` callback will log the tuning results into a
# log file, which can be used to get the best config later.
+# The :code:`ansor.PreLoadMeasuredStates` callback will load measured states
+# from the history log before the schedule search; adding this callback makes
+# sure the same schedule is never measured multiple times.
log_file = "matmul_add.json" @@ -157,7 +160,8 @@ def matmul_add(N, L, M, dtype):
search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed)
tune_option = ansor.TuneOption(n_trials=5, - measure_callbacks=[ansor.LogToFile(log_file)]) +
measure_callbacks=[ansor.LogToFile(log_file)], +
pre_search_callbacks=[ansor.PreLoadMeasuredStates(log_file)])
################################################################
# Then just call :code:`ansor.auto_schedule` and Ansor will try to find a high
From 4ea67122b82a9ce50f5299018a99eeeed1d37ee5 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Tue, 16
Jun 2020 14:00:52 +0800 Subject: [PATCH 26/45] Update PreLoadMeasuredStates & Some bug fix (#27)
* Add a threading wrapper to fix the test bug * Set default TVM_USE_AUTO_SCHEDULER to false *
Update PreLoadMeasuredStates callback --- python/tvm/ansor/auto_schedule.py | 3 +-
python/tvm/ansor/relay_integration.py | 18 +++- python/tvm/ansor/task_scheduler.py | 11 +++
scripts/tune_network.py | 2 +- .../search_policy/meta_tile_rewrite_policy.h | 6 --
src/ansor/search_policy/search_policy.cc | 30 ++++-- src/ansor/search_policy/search_policy.h | 4 +
src/ansor/search_policy/utils.cc | 5 +- .../unittest/test_ansor_relay_Integration.py | 96 +++++++++++++++++++
.../unittest/test_ansor_search_policy.py | 8 +- .../unittest/test_ansor_task_scheduler.py | 19 +++-
topi/python/topi/arm_cpu/__init__.py | 2 +- topi/python/topi/generic/__init__.py | 2 +-
topi/python/topi/x86/__init__.py | 2 +- 14 files changed, 178 insertions(+), 30 deletions(-)
create mode 100644 tests/python/unittest/test_ansor_relay_Integration.py diff --git
a/python/tvm/ansor/auto_schedule.py index 127be4c7ad22..232c24ee89ea ---
a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -81,6 +81,7 @@
def set_verbose(self, verbose): def run_callbacks(self, callbacks):
_ffi_api.SearchPolicyRunCallbacks(self, callbacks) + @tvm._ffi.register_object("ansor.MetaTileRewritePolicy")
class MetaTileRewritePolicy(SearchPolicy): """ The search policy that searches with meta tiling and random rewrite
@@ -231,7 +232,7 @@ def auto_schedule(workload, target=None, Parameters ---------- -
workload : Str or SearchTask + workload : Union[SearchTask, str]
target : Target diff --git a/python/tvm/ansor/relay_integration.py b/python/tvm/ansor/relay_integration.py index de2e12e389e7..348828eec4b4 100644 --- a/python/tvm/ansor/relay_integration.py +++ b/python/tvm/ansor/relay_integration.py @@ -20,6 +20,9 @@ 99.9% copy-paste of implementation by @MerryMercy """ +import os +os.environ['TVM_USE_AUTO_SCHEDULER'] = 'true' + import threading import warnings import tvm @@ -95,7 +98,7 @@ def init_op_to_schedule_map(): relay.op.nn.batch_matmul: [topi.generic.schedule_batch_matmul], } -def extract_from_program(mod, params, ops, target, target_host=None): +def extract_from_program(mod, params, target, target_host=None, ops=None): """ Extract tuning tasks from a relay program. This function is the single program version of extract_from_multiple_program. @@ -117,9 +120,9 @@ def extract_from_program(mod, params, ops, target, target_host=None): ------- workloads: Array of Tuple(wkl_key, target) """ - return extract_from_multiple_program([mod], [params], ops, target, target_host) + return extract_from_multiple_program([mod], [params], target, target_host, ops) -def extract_from_multiple_program(mods, params, ops, target, target_host=None): +def extract_from_multiple_program(mods, params, target, target_host=None, ops=None): """ Extract tuning tasks from multiple relay programs. This function collects tuning tasks by building a list of programs @@ -148,6 +151,15 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None): init_op_to_schedule_map() topi_scheds = [] + + if not ops: + ops = [relay.op.nn.dense, relay.op.nn.softmax, relay.op.nn.conv2d, + relay.op.nn.conv2d_transpose, relay.op.nn.max_pool2d, + relay.op.nn.avg_pool2d, relay.op.nn.global_max_pool2d, + relay.op.nn.global_avg_pool2d, relay.op.nn.conv3d, + relay.op.nn.adaptive_avg_pool3d, relay.op.nn.batch_matmul, + relay.op.mean] + for op_name in ops: if op_name in OP_TO_SCHEDULE: topi_scheds.extend(OP_TO_SCHEDULE[op_name]) diff --git a/python/tvm/ansor/task_scheduler.py b/python/tvm/ansor/task_scheduler.py index 082b2d265140..f8d3f419dcb4 100644 --- a/python/tvm/ansor/task_scheduler.py +++ b/python/tvm/ansor/task_scheduler.py @@ -145,6 +145,17 @@ def __init__(self, self.sequential_now_task_begin_ct = 0 def tune(self, tune_option: TuneOption, search_policy: Union[str, List[SearchPolicy]] = 'default'): + """ Tune tasks. + + Notice: This method does not have return value, make sure to set `LogToFile` + measure callback in `tune_option`. 
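+
+        A minimal usage sketch (mirroring test_ansor_task_scheduler.py):
+
+            task_scheduler = ansor.SimpleTaskScheduler([task1, task2], objective)
+            tune_option = ansor.TuneOption(n_trials=3, runner='local')
+            task_scheduler.tune(tune_option)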
+ + Parameters + ---------- + tune_option: TuneOption + + search_policy: Str or List[SearchPolicy] + """ # init members self.task_cts = [0 for _ in range(len(self.tasks))] self.task_costs_history = [[] for _ in range(len(self.tasks))] diff --git a/scripts/tune_network.py b/scripts/tune_network.py index 5f22e31d50f7..5e5a337c7bce 100644 --- a/scripts/tune_network.py +++ b/scripts/tune_network.py @@ -7,7 +7,7 @@ import numpy as np import tvm -from tvm import _ffi, relay, ansor +from tvm import _ffi, ansor, relay import tvm.contrib.graph_runtime as runtime from tvm.contrib.debugger import debug_runtime from tvm.contrib import util, ndk diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.h b/src/ansor/search_policy/meta_tile_rewrite_policy.h index befc002b6aa2..6930a71038a3 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.h +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.h @@ -103,12 +103,6 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode { SplitFactorizationMemo split_memo_; // Memorize split space for Split std::mt19937 rand_gen_; // Random generator int num_measure_per_iter_; // The number of states to measure per iteration - - // The array of already measured states. - std::vector measured_states_vector_; - - // The throughputs of already measured states - std::vector measured_states_throughputs_; }; TVM_DEFINE_MUTABLE_OBJECT_REF(MetaTileRewritePolicy, MetaTileRewritePolicyNode); diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc index c07a3af7473c..685052f3f71f 100644 --- a/src/ansor/search_policy/search_policy.cc +++ b/src/ansor/search_policy/search_policy.cc @@ -37,28 +37,44 @@ TVM_REGISTER_OBJECT_TYPE(PreLoadMeasuredStatesNode); void SearchPolicyNode::PreLoadMeasuredStates(const std::string& log_file) { LogReader reader = LogReaderNode::make(log_file); const auto& res = reader->ReadLines(-1); - if (res.first.size()) { + size_t log_size = res.first.size(); + CHECK_EQ(log_size, res.second.size()); + if (log_size) { std::vector measured_states; - for (const auto& inp : res.first) { + std::vector measured_throughputs; + for (size_t i = 0; i < log_size; i++) { + const auto& inp = res.first[i]; if (inp->task->workload_key == cur_task_->workload_key && inp->task->target->target_name.compare( cur_task_->target->target_name) == 0) { State state = cur_task_->compute_dag.GetInitState(); state.CopyOnWrite()->transform_steps = inp->state->transform_steps; state.DoSteps(inp->state->transform_steps, cur_task_->compute_dag); - measured_states.push_back(std::move(state)); + measured_states.emplace_back(std::move(state)); + measured_throughputs.push_back(res.second[i]->error_no == 0 ? 
+ (1.0 / FloatArrayMean(res.second[i]->costs)) : 0.0); } } cur_task_->compute_dag.InferBound(&measured_states); - for (auto state : measured_states) { - measured_states_set_.insert(state.ToStr()); + for (size_t i = 0; i < measured_states.size(); i ++) { + auto& state = measured_states[i]; + const auto& state_str = state.ToStr(); + if (!measured_states_set_.count(state_str)) { + measured_states_set_.insert(state_str); + if (measured_throughputs[i] != 0.0) { + measured_states_vector_.emplace_back(std::move(state)); + measured_states_throughputs_.emplace_back(measured_throughputs[i]); + } + } } StdCout(verbose_) << "Measured States Set: " << measured_states_set_.size() - << " state hashes loaded from " << log_file << std::endl; + << " state hashes loaded from " << log_file + << " for " << cur_task_->workload_key << std::endl; } else { StdCout(verbose_) << "Measured States Set: no states found from " - << log_file << std::endl; + << log_file << " for " << cur_task_->workload_key + << std::endl; } } diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h index 2dfbd9429648..6085fd1816e8 100644 --- a/src/ansor/search_policy/search_policy.h +++ b/src/ansor/search_policy/search_policy.h @@ -101,6 +101,10 @@ class SearchPolicyNode : public Object { // The set of the already measured states. // We store the string format for redundancy check std::unordered_set measured_states_set_; + // The array of already measured states. + std::vector measured_states_vector_; + // The throughputs of already measured states + std::vector measured_states_throughputs_; }; TVM_DEFINE_MUTABLE_OBJECT_REF(SearchPolicy, SearchPolicyNode); diff --git a/src/ansor/search_policy/utils.cc b/src/ansor/search_policy/utils.cc index 608b89da118c..e0fd00b23e7b 100644 --- a/src/ansor/search_policy/utils.cc +++ b/src/ansor/search_policy/utils.cc @@ -311,9 +311,10 @@ State RandomMutateTileSize(const State& old_state, SplitFactorizationMemo* split CHECK(ps != nullptr); extent = GetIntImm(ps->extent); retry_ct += 1; - } while (retry_ct < static_cast(split_step_ids.size()) << 2 && extent == 1); + } while (retry_ct < static_cast(split_step_ids.size()) << 2 && + (extent == 1 || extent == 0)); - if (extent == 1) { + if (extent == 0 || extent == 1) { return State(); } diff --git a/tests/python/unittest/test_ansor_relay_Integration.py b/tests/python/unittest/test_ansor_relay_Integration.py new file mode 100644 index 000000000000..9c423220844c --- /dev/null +++ b/tests/python/unittest/test_ansor_relay_Integration.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+""" Test Relay Integration """ + +import tempfile +import numpy as np + +import tvm +from tvm import ansor, relay +import tvm.contrib.graph_runtime as runtime + +from test_ansor_common import get_tiled_matmul + +def dense_graph(N, dtype="float32"): + ori_data = relay.var("data", shape=(N, N), dtype=dtype) + weight = relay.var("weight", shape=(N, N), dtype=dtype) + data = relay.multiply(ori_data, relay.const(2, dtype=dtype)) + dense = relay.nn.dense(data, weight, out_dtype=dtype) + dense = relay.add(dense, weight) + dense = relay.nn.dense(dense, weight, out_dtype=dtype) + return ori_data, weight, dense + +def test_dense_integration(): + N = 128 + data, weight, dense = dense_graph(N) + mod = relay.Function([data, weight], dense) + mod = tvm.IRModule.from_expr(mod) + + ctx = tvm.context("llvm") + target = tvm.target.create("llvm") + d = tvm.nd.array(np.random.uniform(size=(N, N)).astype(data.type_annotation.dtype), ctx) + w = tvm.nd.array(np.random.uniform(size=(N, N)).astype(weight.type_annotation.dtype), ctx) + workloads, wkl_weights = ansor.extract_from_program(mod, {}, target=target) + + assert len(workloads) == 2 + assert len(wkl_weights) == 2 + + tasks = [] + for wkl_key in workloads: + dag = ansor.workload_key_to_dag(wkl_key) + tasks.append(ansor.SearchTask(dag, wkl_key, target)) + + assert str(tasks[0].compute_dag) == "placeholder = PLACEHOLDER [128, 128]\n" + \ + "placeholder = PLACEHOLDER [128, 128]\n" + \ + "compute(z, y, x) += (placeholder[z, ((k*16) + x)]*placeholder[y, ((k*16) + x)])\n" + \ + "compute(y, x) += compute[y, x, kk]\n" + + assert str(tasks[1].compute_dag) == "placeholder = PLACEHOLDER [128, 128]\n" + \ + "placeholder = PLACEHOLDER [128, 128]\n" + \ + "compute(z, y, x) += (placeholder[z, ((k*16) + x)]*placeholder[y, ((k*16) + x)])\n" + \ + "compute(y, x) += compute[y, x, kk]\n" + \ + "T_add(ax0, ax1) = (compute[ax0, ax1] + placeholder[ax0, ax1])\n" + + tuner = ansor.SimpleTaskScheduler(tasks) + measure_ctx = ansor.LocalRPCMeasureContext() + with tempfile.NamedTemporaryFile() as fp: + tuner.tune(ansor.TuneOption(n_trials=4, runner=measure_ctx.runner, + measure_callbacks=[ansor.LogToFile(fp.name)])) + with ansor.apply_history_best(fp.name): + with relay.build_config(opt_level=3): + graph, lib, opt_params = relay.build_module.build( + mod, target=target) + + m = runtime.create(graph, lib, ctx) + m.set_input('data', d) + m.set_input('weight', w) + m.run() + res = m.get_output(0) + if measure_ctx: + del measure_ctx + + d = d.asnumpy() + d = d * 2 + w = w.asnumpy() + d = np.dot(d, np.transpose(w)) + d = d + w + d = np.dot(d, np.transpose(w)) + + tvm.testing.assert_allclose(res.asnumpy(), d, rtol=1e-5) + +if __name__ == "__main__": + test_dense_integration() diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index b86dfa95f9bd..839992c67e0f 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -20,6 +20,7 @@ import random import numpy as np import tempfile +import threading import tvm from tvm import ansor @@ -73,8 +74,11 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local' def test_search_basic(): - search_common(seed=944563397) - + # Ansor search process with local runner has some modification on thread + # binding, wrap this to a subprocess to eliminate the impacts to other tests + t = threading.Thread(target=search_common, kwargs={'seed': 944563397}) + t.start() + t.join() def 
 def test_search_xgb_model_rpc_runner():
     measure_ctx = ansor.LocalRPCMeasureContext()
diff --git a/tests/python/unittest/test_ansor_task_scheduler.py b/tests/python/unittest/test_ansor_task_scheduler.py
index e95d65d4b5ce..53cf2059c1f3 100644
--- a/tests/python/unittest/test_ansor_task_scheduler.py
+++ b/tests/python/unittest/test_ansor_task_scheduler.py
@@ -17,6 +17,8 @@

 """Test the task scheduler
 """
+import threading
+
 import tvm
 from tvm import ansor

@@ -30,13 +32,20 @@ def test_task_scheduler_basic():
     task1 = ansor.SearchTask(dag, "test", tgt)
     task2 = ansor.SearchTask(dag, "test", tgt)

-    def objective(costs):
-        return sum(costs)
+    def basic_test_func(task1, task2):
+        def objective(costs):
+            return sum(costs)

-    task_scheduler = ansor.SimpleTaskScheduler([task1, task2], objective)
-    tune_option = ansor.TuneOption(n_trials=3, runner='local')
+        task_scheduler = ansor.SimpleTaskScheduler([task1, task2], objective)
+        tune_option = ansor.TuneOption(n_trials=3, runner='local')
+        task_scheduler.tune(tune_option)

-    task_scheduler.tune(tune_option)
+    # The Ansor search process with the local runner modifies thread binding;
+    # run it in a separate thread so it does not affect other tests
+    t = threading.Thread(target=basic_test_func,
+                         kwargs={'task1': task1, 'task2': task2})
+    t.start()
+    t.join()

 if __name__ == "__main__":
diff --git a/topi/python/topi/arm_cpu/__init__.py b/topi/python/topi/arm_cpu/__init__.py
index e6ccadd4755f..0c0979763dba 100644
--- a/topi/python/topi/arm_cpu/__init__.py
+++ b/topi/python/topi/arm_cpu/__init__.py
@@ -28,6 +28,6 @@
 from . import cortex_m7

 import os
-use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "true")
+use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "false")
 if use_auto_scheduler.lower() == "true":
     from ..ansor import *
diff --git a/topi/python/topi/generic/__init__.py b/topi/python/topi/generic/__init__.py
index 7f37ba78a06c..d44fca8548d2 100644
--- a/topi/python/topi/generic/__init__.py
+++ b/topi/python/topi/generic/__init__.py
@@ -41,6 +41,6 @@
 from .image import *

 import os
-use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "true")
+use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "false")
 if use_auto_scheduler.lower() == "true":
     from ..ansor import *
diff --git a/topi/python/topi/x86/__init__.py b/topi/python/topi/x86/__init__.py
index a334397249e3..28e9e862f4d8 100644
--- a/topi/python/topi/x86/__init__.py
+++ b/topi/python/topi/x86/__init__.py
@@ -41,6 +41,6 @@
 from .conv2d_alter_op import *

 import os
-use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "true")
+use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "false")
 if use_auto_scheduler.lower() == "true":
     from ..ansor import *

From 6126cdbefe7c30be19bc88c59af73e396161e81b Mon Sep 17 00:00:00 2001
From: Chenfan
Date: Fri, 19 Jun 2020 14:47:19 +0800
Subject: [PATCH 27/45] Add tensorize step for loop_state (#31)

* Add tensorize step

---
 python/tvm/ansor/loop_state.py                |  25 +++++++-
 python/tvm/ansor/task_scheduler.py            |   2 +
 src/ansor/compute_dag.cc                      |   5 +-
 src/ansor/loop_state.cc                       |  59 ++++++++++++++---
 src/ansor/loop_state.h                        |  20 +++++--
 .../search_policy/meta_tile_rewrite_policy.cc |  20 ++++++-
 src/ansor/search_policy/utils.h               |  10 ++++
 src/ansor/serialization.cc                    |  16 ++++-
 src/ansor/transform_step.cc                   |  36 +++++++++++
 src/ansor/transform_step.h                    |  43 +++++++++++---
 .../python/unittest/test_ansor_loop_state.py  |  38 ++++++++++++
 11 files changed, 246 insertions(+), 28 deletions(-)

diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py
index 0cf157147423..67ec3ed12b05 100644
--- a/python/tvm/ansor/loop_state.py
+++ b/python/tvm/ansor/loop_state.py
@@ -411,14 +411,33 @@ def storage_align(self, stage_id, it, factor, offset):
         it : Iterator
         factor : Int
         offset : Int
+        """
+        self.state_object = _ffi_api.StateStorageAlign(self.state_object, stage_id, it, factor, offset)
+        self.clear_cache()
+
+    def tensorize(self, stage_id, it, ti_func_name):
+        """ The `ti_func_name` corresponds to a globally registered function
+        that returns a TensorIntrin
+
+        Parameters
+        ----------
+        stage_id : Int
+            The index of the stage to do tensorize
+        it : Iterator
+            The target iterator
+        ti_func_name : Str
+            Tensorize intrinsic function name

         Returns
         -------
-        state : State
-            The updated state
+        res_it : Iterator
+            The tensorized Iterator
         """
-        self.state_object = _ffi_api.StateStorageAlign(self.state_object, stage_id, it, factor, offset)
+        self.state_object, res = _ffi_api.StateTensorize(self.state_object,
+                                                         stage_id, it,
+                                                         ti_func_name)
         self.clear_cache()
+        return res

     def __str__(self):
         return str(self.state_object)
diff --git a/python/tvm/ansor/task_scheduler.py b/python/tvm/ansor/task_scheduler.py
index f8d3f419dcb4..89b4afd84e86 100644
--- a/python/tvm/ansor/task_scheduler.py
+++ b/python/tvm/ansor/task_scheduler.py
@@ -248,6 +248,8 @@ def tune(self, tune_option: TuneOption, search_policy: Union[str, List[SearchPol
             else:
                 raise ValueError("Invalid strategy: " + self.strategy)

+            if self.verbose >= 1:
+                print("Next tuning task: %d" % task_idx)
             self.tune_task(task_idx)

     def tune_task(self, task_idx):
diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc
index de3b98a5106b..5ca0c8503662 100644
--- a/src/ansor/compute_dag.cc
+++ b/src/ansor/compute_dag.cc
@@ -1086,7 +1086,8 @@ void ComputeDAG::InferBoundCommon(StateNode* pstate) const {
         new_iters.push_back(IteratorNode::make(iter->name, (*find_res).second,
                                                iter->iter_type, iter->annotation,
-                                               &iter->ori_iters));
+                                               &iter->ori_iters,
+                                               iter->attr));
       } else {
         LOG(FATAL) << "Infer bound fails";
       }
@@ -1161,6 +1162,8 @@ std::pair<te::Schedule, Array<te::Tensor> > ComputeDAG::ReplaySteps(
       ps->ApplyToSchedule(stages, stage_to_axes, &schedule);
     } else if (auto ps = step.as<StorageAlignStepNode>()) {
       ps->ApplyToSchedule(stages, stage_to_axes);
+    } else if (auto ps = step.as<TensorizeStepNode>()) {
+      ps->ApplyToSchedule(stages, stage_to_axes);
     } else {
       LOG(FATAL) << "Invalid Step";
     }
diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc
index 77361dbf837c..b6e6d854e3e5 100644
--- a/src/ansor/loop_state.cc
+++ b/src/ansor/loop_state.cc
@@ -39,7 +39,8 @@ TVM_REGISTER_NODE_TYPE(IteratorNode);
 // Maker for other classes
 Iterator IteratorNode::make(std::string name, Range range,
                             IteratorType iter_type, IteratorAnnotation annotation,
-                            const std::vector<Iterator>* ori_iters) {
+                            const std::vector<Iterator>* ori_iters,
+                            std::string attr) {
   auto node = make_object<IteratorNode>();
   node->name = std::move(name);
   node->range = std::move(range);
@@ -48,6 +49,7 @@ Iterator IteratorNode::make(std::string name, Range range,
   if (ori_iters != nullptr) {
     node->ori_iters = *ori_iters;
   }
+  node->attr = std::move(attr);
   return Iterator(node);
 }

@@ -310,6 +312,15 @@ void State::storage_align(int stage_id, const Iterator& it, int factor,
   return DoStorageAlignStep(step);
 }

+Iterator State::tensorize(int stage_id, const Iterator& it,
+                          std::string ti_func_name) {
+  const Stage& stage = operator->()->stages[stage_id];
+  TensorizeStep step = TensorizeStepNode::make(
+      stage_id, GetIndex(stage->iters, it), ti_func_name);
+
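// The tensorize primitive follows the same record-then-apply pattern as the
// other schedule primitives: the TensorizeStep is first appended to
// transform_steps (the replayable history), and DoTensorizeStep then
// rewrites the target iterator in the cached loop structure.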
CopyOnWrite()->transform_steps.push_back(step); + return DoTensorizeStep(step); +} + // Steps' implementations void State::DoReorderStep(const ReorderStep& step) { const Stage& stage = operator->()->stages[step->stage_id]; @@ -509,8 +520,10 @@ Iterator State::DoAnnotationStep(const AnnotationStep& step) { const Stage& stage = operator->()->stages[step->stage_id]; Iterator it = stage->iters[step->iter_id]; + CHECK_EQ(it->annotation, IteratorAnnotation::kNone); Iterator new_it = IteratorNode::make(it->name, it->range, it->iter_type, - step->annotation, &it->ori_iters); + step->annotation, &it->ori_iters, + it->attr); Stage new_stage = stage; new_stage.CopyOnWrite()->iters[step->iter_id] = new_it; StateNode* pstate = CopyOnWrite(); @@ -538,7 +551,8 @@ void State::DoComputeAtStep(const ComputeAtStep& step) { new_iters.push_back(it); } else { new_iters.push_back(IteratorNode::make(it->name, Range(), it->iter_type, - it->annotation, &it->ori_iters)); + it->annotation, &it->ori_iters, + it->attr)); } } @@ -559,7 +573,8 @@ void State::DoComputeRootStep(const ComputeRootStep& step) { std::vector new_iters; for (const Iterator& it : stage->iters) { new_iters.push_back(IteratorNode::make(it->name, Range(), it->iter_type, - it->annotation, &it->ori_iters)); + it->annotation, &it->ori_iters, + it->attr)); } // update attach map @@ -747,6 +762,18 @@ void State::DoStorageAlignStep(const StorageAlignStep& step) { stage->storage_offset = step->offset; } +Iterator State::DoTensorizeStep(const TensorizeStep& step) { + const Stage& stage = operator->()->stages[step->stage_id]; + Iterator it = stage->iters[step->iter_id]; + Iterator new_it = IteratorNode::make(it->name, it->range, it->iter_type, + IteratorAnnotation::kTensorized, &it->ori_iters, step->ti_func_name); + Stage new_stage = stage; + new_stage.CopyOnWrite()->iters[step->iter_id] = new_it; + StateNode* pstate = CopyOnWrite(); + pstate->stages[step->stage_id] = std::move(new_stage); + return new_it; +} + void State::DoStep(const Step& step, const ComputeDAG& dag) { if (auto ps = step.as()) { DoReorderStep(GetRef(ps)); @@ -776,6 +803,8 @@ void State::DoStep(const Step& step, const ComputeDAG& dag) { DoRfactorStep(GetRef(ps), dag); } else if (auto ps = step.as()) { DoStorageAlignStep(GetRef(ps)); + } else if (auto ps = step.as()) { + DoTensorizeStep(GetRef(ps)); } else { LOG(FATAL) << "Invalid step: " << step; } @@ -854,15 +883,22 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, case kThreadY: *os << "gpu.threadIdx.y "; break; + case kTensorized: + *os << "tensorize "; + break; + default: + LOG(FATAL) << "Invalid Annotation " << iter->annotation; break; } if (iter->range.defined()) { *os << iter->name << " (" << iter->range->min << "," - << iter->range->extent << ")" - << "\n"; + << iter->range->extent << ")"; } else { - *os << iter->name << " (None)" - << "\n"; + *os << iter->name << " (None)"; } + if (!iter->attr.empty()) { + *os << " " << iter->attr; + } + *os << "\n"; indent += 2; } @@ -1174,6 +1210,13 @@ TVM_REGISTER_GLOBAL("ansor.StateStorageAlign") return state; }); +TVM_REGISTER_GLOBAL("ansor.StateTensorize") +.set_body_typed([](State state, int stage_id, const Iterator& it, + std::string ti_func) { + const auto& res = state.tensorize(stage_id, it, ti_func); + return Array{state, res}; +}); + TVM_REGISTER_GLOBAL("ansor.StateEqual") .set_body_typed([](State state1, State state2) { return std::equal_to()(state1, state2); diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h index 90ba48cd92ac..6eef404ae272 100644 
--- a/src/ansor/loop_state.h +++ b/src/ansor/loop_state.h @@ -74,7 +74,8 @@ enum IteratorType { /*! \brief The type of an iterator's annotation */ enum IteratorAnnotation { kNone, kUnroll, kVectorize, kParallel, - kVThread, kBlockX, kThreadX, kBlockY, kThreadY + kVThread, kBlockX, kThreadX, kBlockY, kThreadY, + kTensorized }; class Iterator; @@ -90,14 +91,17 @@ class IteratorNode : public Object { IteratorType iter_type; IteratorAnnotation annotation; std::vector ori_iters; // The original iterators before fusion + std::string attr; static Iterator make(std::string name, Range range, IteratorType iter_type, IteratorAnnotation annotation, - const std::vector* ori_iters = nullptr); + const std::vector* ori_iters = nullptr, + std::string attr = ""); void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); v->Visit("range", &range); + v->Visit("attr", &attr); } static constexpr const char *_type_key = "ansor.Iterator"; @@ -115,6 +119,7 @@ class FuseStep; class AnnotationStep; class ComputeAtStep; class ComputeRootStep; class ComputeInlineStep; class CacheReadStep; class CacheWriteStep; class PragmaStep; class RfactorStep; class StorageAlignStep; +class TensorizeStep; /*! * \brief A stage in the compute declaration @@ -254,19 +259,21 @@ class State : public ObjectRef { Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1); Iterator bind_thread(int stage_id, const Iterator& it, IteratorAnnotation thread_type); + Iterator tensorize(int stage_id, const Iterator& it, + std::string ti_func_name); void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter); void compute_root(int stage_id); void compute_inline(int stage_id); + void pragma(int stage_id, const Iterator& it, const std::string& pragma_type); + void storage_align(int stage_id, const Iterator& it, int factor, int offset); int cache_read(int stage_id, const std::string& scope_name, const std::vector& reader_stage_ids, const ComputeDAG& task_dag); int cache_write(int stage_id, const std::string& scope_name, const ComputeDAG& task_dag); - void pragma(int stage_id, const Iterator& it, const std::string& pragma_type); int rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& task_dag); - void storage_align(int stage_id, const Iterator& it, int factor, int offset); /* Do transform steps * Note: The following functions only change loop state but do not change transform_history. 
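From the Python side, the new tensorize primitive declared above is driven purely by name: the string must identify a globally registered function that returns a te.TensorIntrin. A condensed usage sketch, mirroring the unit test added later in this patch (stage index 2 is the matmul output stage):

dag = ansor.ComputeDAG(matmul_ansor_test(1024, 512, 64))
s0 = dag.get_init_state()
its = s0.split(2, s0.stages[2].iters[1], [16])
s0.tensorize(2, its[1], "test_intrin_gemv")    # records a TensorizeStep
sch, tensors = dag.apply_steps_from_state(s0)  # ReplaySteps lowers it via stage.tensorize
print(tvm.lower(sch, tensors, simple_mode=True))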
@@ -278,14 +285,15 @@ class State : public ObjectRef { std::vector DoFollowFusedSplitStep(const FollowFusedSplitStep& step); Iterator DoFuseStep(const FuseStep& step); Iterator DoAnnotationStep(const AnnotationStep& step); + Iterator DoTensorizeStep(const TensorizeStep& step); void DoComputeAtStep(const ComputeAtStep& step); void DoComputeRootStep(const ComputeRootStep& step); void DoComputeInlineStep(const ComputeInlineStep& step); + void DoPragmaStep(const PragmaStep& step); + void DoStorageAlignStep(const StorageAlignStep& step); int DoCacheReadStep(const CacheReadStep& step, const ComputeDAG& dag); int DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag); - void DoPragmaStep(const PragmaStep& step); int DoRfactorStep(const RfactorStep& step, const ComputeDAG& dag); - void DoStorageAlignStep(const StorageAlignStep& step); // General do step functions with a runtime dynamic dispatcher void DoStep(const Step& step, const ComputeDAG& dag); diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.cc b/src/ansor/search_policy/meta_tile_rewrite_policy.cc index 4a045d31a487..7e022e3be3c3 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.cc +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.cc @@ -751,6 +751,11 @@ int InitPopulationThreadBind(const MetaTileRewritePolicyNode* policy, continue; } + if (HasAnnotationIter(stage, IteratorAnnotation::kThreadX)) { + // Skip if this stage has already done thread bind + continue; + } + std::vector to_fuse; // This stage has not been tiled, but in GPU schedule, we must tile it @@ -861,10 +866,16 @@ int InitPopulationCooperativeFetching(const MetaTileRewritePolicyNode* policy, !HasCacheWriteStage((*state), stage_id - 1)) || (stage_id > 1 && HasCacheReadStage((*state), stage_id - 2) && HasCacheWriteStage((*state), stage_id - 2))) { + const Stage& target_stage = (*state)->stages[stage_id]; + if (HasAnnotationIter(target_stage, IteratorAnnotation::kThreadX) || + HasAnnotationIter(target_stage, IteratorAnnotation::kTensorized)) { + // Skip if this stage has already done thread bind or has been + // tensorized + continue; + } // Get spatial_split_step_ids from the root stage std::unordered_set consumers; std::vector spatial_split_step_ids; - const Stage& target_stage = (*state)->stages[stage_id]; GetConsumers(policy->cur_task_, (*state), target_stage->op, &consumers); CHECK_EQ(consumers.size(), 1); int target_stage_id = OperationToStage(*consumers.begin(), (*state)); @@ -1129,6 +1140,11 @@ int InitPopulationVectorization(const MetaTileRewritePolicyNode* policy, continue; } + if (HasAnnotationIter(stage, IteratorAnnotation::kTensorized)) { + // Skip if this stage has been tensorized + continue; + } + // try to fuse and vectorize the space iterators in the inner most tile int cum_length_prod = 1; @@ -1224,7 +1240,7 @@ int InitPopulationUnroll(const MetaTileRewritePolicyNode* policy, n--; } - } else if (stage->op->attrs.count(policy->always_unroll_key)) { + } else if (stage->op->attrs.count(policy->always_unroll_key)) { // Special unroll policy auto to_unroll_name_set = GetIterNameSetParam(stage->op->attrs, policy->always_unroll_key); diff --git a/src/ansor/search_policy/utils.h b/src/ansor/search_policy/utils.h index 3d0611173c94..472e90771879 100644 --- a/src/ansor/search_policy/utils.h +++ b/src/ansor/search_policy/utils.h @@ -143,6 +143,16 @@ inline bool HasReduceIter(const Stage& stage) { return false; } +// Return whether the stage has specific annotated iterators +inline bool HasAnnotationIter(const Stage& stage, 
IteratorAnnotation type) { + for (const auto& iter : stage->iters) { + if (iter->annotation == type) { + return true; + } + } + return false; +} + // Return whether an op needs multi level tiling inline bool NeedsMultilevelTiling(const SearchTask& task, const State& state, const te::Operation& op) { diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index b03acb1edc3c..ed5d4b868c27 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -167,6 +167,11 @@ struct Handler > { writer->WriteArrayItem(ps->iter_id); writer->WriteArrayItem(ps->factor); writer->WriteArrayItem(ps->offset); + } else if (auto ps = data[i].as<::tvm::ansor::TensorizeStepNode>()) { + writer->WriteArrayItem(std::string("TS")); + writer->WriteArrayItem(ps->stage_id); + writer->WriteArrayItem(ps->iter_id); + writer->WriteArrayItem(ps->ti_func_name); } else { LOG(FATAL) << "Invalid step: " << data[i]; } @@ -179,7 +184,7 @@ struct Handler > { std::vector<::tvm::ansor::Step> * data) { std::vector int_list; bool s, inner_to_outer, factor_or_nparts; - std::string name, scope_name, pragma_type; + std::string name, scope_name, pragma_type, ti_func_name; int stage_id, target_stage_id, iter_id, src_step_id, n_split, ann, extent; int level, factor_iter_id, factor, offset; @@ -311,6 +316,15 @@ struct Handler > { reader->Read(&offset); data->push_back(::tvm::ansor::StorageAlignStepNode::make( stage_id, iter_id, factor, offset)); + } else if (name == "TS") { + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&stage_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&iter_id); + s = reader->NextArrayItem(); CHECK(s); + reader->Read(&ti_func_name); + data->push_back(::tvm::ansor::TensorizeStepNode::make( + stage_id, iter_id, ti_func_name)); } else { LOG(FATAL) << "Invalid step format"; } diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc index 3f59ff736e9d..b0e67a481ae3 100644 --- a/src/ansor/transform_step.cc +++ b/src/ansor/transform_step.cc @@ -26,6 +26,7 @@ #include "transform_step.h" #include +#include #include #include "utils.h" @@ -801,5 +802,40 @@ std::string StorageAlignStepNode::PrintAsPythonAPI( return ss.str(); } +/********** Tensorize **********/ +TensorizeStep TensorizeStepNode::make(int stage_id, int iter_id, + std::string ti_func_name) { + auto node = make_object(); + node->stage_id = stage_id; + node->iter_id = iter_id; + node->ti_func_name = ti_func_name; + return TensorizeStep(node); +} + +void TensorizeStepNode::ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const { + te::Stage& stage = (*stages)[stage_id]; + const std::vector& axes = (*stage_to_axes)[stage]; + auto func = tvm::runtime::Registry::Get(ti_func_name); + CHECK(func != nullptr) << "Cannot find the tensorize intrinsic func"; + tvm::te::TensorIntrin res = (*func)(); + CHECK(res.defined()) << "Tensorize intrinsic func must return a " + << "tvm::te::TensorIntrin object"; + stage.tensorize(axes[iter_id], res); +} + +std::string TensorizeStepNode::PrintAsPythonAPI( + std::vector *stages, StageToAxesMap *stage_to_axes, + te::Schedule *schedule, const std::vector& transform_steps) const { + std::stringstream ss; + const auto& stage = (*stages)[stage_id]; + ss << "s[" << CleanName(stage->op->func_name()) << "].tensorize(" + << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", " + << ti_func_name << "())\n"; + + ApplyToSchedule(stages, stage_to_axes); + return ss.str(); +} + } // namespace ansor } // namespace tvm diff --git a/src/ansor/transform_step.h 
b/src/ansor/transform_step.h
index 8240623ae3b1..9af14429bf61 100644
--- a/src/ansor/transform_step.h
+++ b/src/ansor/transform_step.h
@@ -23,17 +23,18 @@
  *
  * \Note How to add a new transform step.
  * Take fuse for example:
- * 1. Define class FuseStepNode, FuseStep in transform_steps.h, and implement its make function
- *    in FuseStepNode::make(...) transform_steps.cc
- * 2. Implement FuseStepNode::ApplyToSchedule and FuseStepNode::PrintAsPythonAPI.
- *    - In these two functions you need to lower this step with tvm's schedule API
- * 3. Implement State::fuse and State::DoFuseStep.
+ * 1. Define class `FuseStepNode`, `FuseStep` in `transform_step.h`, and implement its make function
+ *    `FuseStepNode::make(...)` in `transform_step.cc`
+ * 2. Implement `FuseStepNode::ApplyToSchedule` and `FuseStepNode::PrintAsPythonAPI`.
+ *    - In these two functions you need to lower this step with tvm's te schedule API
+ * 3. Implement `State::fuse` and `State::DoFuseStep`.
  *    - In these two functions you need to incrementally update all data structures in State with
  *      CopyOnWrite style
- * 4. Add you step to ComputeDAG::ReplaySteps and make sure it works.
+ * 4. Add your step to `ComputeDAG::ReplaySteps` and make sure it works.
 * 5. Add serialization support in `struct Handler<std::vector<::tvm::ansor::Step> >`
- *    (in serialization.cc)
+ *    in `serialization.cc`
 * 6. Add hash support in `struct hash<::tvm::ansor::Step>` (search for this function in this file)
+ * 7. Add its corresponding Python API to `loop_state.py` and necessary unit test
 */

 #ifndef TVM_ANSOR_TRANSFORM_STEP_H_
@@ -365,6 +366,29 @@ class StorageAlignStepNode: public StepNode {
 };
 TVM_DEFINE_COW_OBJECT_REF(StorageAlignStep, Step, StorageAlignStepNode);

+/*! \brief Tensorize step that corresponds to te::Schedule::tensorize
+ * \Note This step takes a globally registered function name as input.
*/ +class TensorizeStepNode: public StepNode { + public: + int iter_id; + std::string ti_func_name; + + static TensorizeStep make(int stage_id, int iter_id, + std::string ti_func_name); + + void ApplyToSchedule(std::vector *stages, + StageToAxesMap *stage_to_axes) const; + + std::string PrintAsPythonAPI(std::vector *stages, + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const final; + + static constexpr const char* _type_key = "ansor.TensorizeStep"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeStepNode, Object); +}; +TVM_DEFINE_COW_OBJECT_REF(TensorizeStep, Step, TensorizeStepNode); + } // namespace ansor } // namespace tvm @@ -451,6 +475,11 @@ struct hash<::tvm::ansor::Step> { ::dmlc::HashCombine(std::hash()(ps->iter_id), ::dmlc::HashCombine(std::hash()(ps->factor), ps->offset)))); + } else if (auto ps = step.as<::tvm::ansor::TensorizeStepNode>()) { + return ::dmlc::HashCombine(15, + ::dmlc::HashCombine(std::hash()(ps->stage_id), + ::dmlc::HashCombine(std::hash()(ps->iter_id), + ps->ti_func_name))); } else { LOG(FATAL) << "Invalid step"; } diff --git a/tests/python/unittest/test_ansor_loop_state.py b/tests/python/unittest/test_ansor_loop_state.py index 612d320036d8..a2c09aafc07b 100644 --- a/tests/python/unittest/test_ansor_loop_state.py +++ b/tests/python/unittest/test_ansor_loop_state.py @@ -17,6 +17,7 @@ """Test loop state and schedule primitives""" +import tvm from tvm import ansor, te import topi @@ -468,9 +469,46 @@ def test_rfactor(): " C.repl = ...\n" +@tvm._ffi.register_func +def test_intrin_gemv(): + m = 16 + l = 64 + a = te.placeholder((l,), name='a') + b = te.placeholder((l, m), name='b') + k = te.reduce_axis((0, l), name='k') + c = te.compute((m,), lambda i: te.sum(a[k] * b[k, i], axis=k), name='c') + Ab = tvm.tir.decl_buffer(a.shape, a.dtype, name="A", + offset_factor=1, strides=[1]) + Bb = tvm.tir.decl_buffer(b.shape, b.dtype, name="B", + offset_factor=1, strides=[te.var("s0"), 1]) + Cb = tvm.tir.decl_buffer(c.shape, c.dtype, name="C", + offset_factor=1, strides=[1]) + def intrin_func(ins, outs): + ib = tvm.tir.ir_builder.create() + aa, bb = ins + cc = outs[0] + ib.emit(tvm.tir.call_extern("float32", "gemv_update", + cc.access_ptr("w"), + aa.access_ptr("r"), + bb.access_ptr("r"))) + return ib.get() + return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb}) + +def test_tensorize(): + dag = ansor.ComputeDAG(matmul_ansor_test(1024, 512, 64)) + s0 = dag.get_init_state() + C = 2 + + its = s0.split(C, s0.stages[C].iters[1], [16]) + s0.tensorize(C, its[1], "test_intrin_gemv") + + sch, tensors = dag.apply_steps_from_state(s0) + tvm.lower(sch, tensors, simple_mode=True) + if __name__ == "__main__": test_split_fuse_reorder_annotation() test_follow_split_follow_fused_split() test_compute_at_root_inline() test_cache_read_write() test_rfactor() + test_tensorize() From c7364df568922d1643d50b85f5e0c3fa3acb64d2 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Fri, 19 Jun 2020 18:19:39 +0800 Subject: [PATCH 28/45] State python api update (#33) * Start to update api * Add compute_dag to state * API update --- python/tvm/ansor/compute_dag.py | 6 +- python/tvm/ansor/loop_state.py | 177 +++++++++-- python/tvm/ansor/serialization.py | 2 +- src/ansor/loop_state.cc | 4 - tests/python/unittest/test_ansor_common.py | 29 +- .../python/unittest/test_ansor_compute_dag.py | 19 +- tests/python/unittest/test_ansor_feature.py | 4 +- .../python/unittest/test_ansor_loop_state.py | 275 ++++++++++++------ .../unittest/test_ansor_search_policy.py | 
11 +- ...t_ansor_vectorized_cooperative_fetching.py | 152 ---------- 10 files changed, 374 insertions(+), 305 deletions(-) delete mode 100644 tests/python/unittest/test_ansor_vectorized_cooperative_fetching.py diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index 23ba1b32f5c4..6d82942aa744 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -49,7 +49,7 @@ def get_init_state(self): ------- state : State """ - return State(_ffi_api.ComputeDAGGetInitState(self)) + return State(_ffi_api.ComputeDAGGetInitState(self), self) def apply_steps_from_state(self, state, layout_rewrite_level=None): """ @@ -98,9 +98,9 @@ def infer_bound_from_state(self, state): state : StateObject """ if isinstance(state, State): - return State(_ffi_api.ComputeDAGInferBoundFromState(self, state.state_object)) + return State(_ffi_api.ComputeDAGInferBoundFromState(self, state.state_object), self) elif isinstance(state, StateObject): - return State(_ffi_api.ComputeDAGInferBoundFromState(self, state)) + return State(_ffi_api.ComputeDAGInferBoundFromState(self, state), self) else: raise ValueError("The input must be a State or StateObject") diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py index 67ec3ed12b05..23289c027293 100644 --- a/python/tvm/ansor/loop_state.py +++ b/python/tvm/ansor/loop_state.py @@ -77,16 +77,48 @@ class State: ----- This is a wrapper class of StateObject to deal with copy-on-write property """ - def __init__(self, state_object): + def __init__(self, state_object, dag): self.state_object = state_object + self.compute_dag = dag self.stages_cache = None + self.stage_id_map = {} + self.__update_tensor_stage_map() + + def __getitem__(self, k): + if not self.stages_cache: + self.stages_cache = _ffi_api.StateGetStages(self.state_object) + if isinstance(k, tvm.te.Tensor): + return self.stages_cache[self.stage_id_map[k.op]] + else: + raise ValueError("Item must be Tensor") + + def __update_tensor_stage_map(self): + if not self.stages_cache: + self.stages_cache = _ffi_api.StateGetStages(self.state_object) + for index, stage in enumerate(self.stages_cache): + self.stage_id_map[stage.op] = index + + def __insert_new_stage(self, new_stage_id): + new_stage_id = int(new_stage_id) + self.stages_cache = _ffi_api.StateGetStages(self.state_object) + added_stage_tensor = self.stages_cache[new_stage_id].op.output(0) + + for key, value in self.stage_id_map.items(): + if value >= new_stage_id: + self.stage_id_map[key] = value + 1 + self.stage_id_map[added_stage_tensor.op] = new_stage_id + self.__update_tensor_stage_map() + + return added_stage_tensor def clear_cache(self): self.stages_cache = None def copy(self): - return State(self.state_object) + state = State(self.state_object, self.compute_dag) + state.stage_id_map = self.stage_id_map.copy() + return state @property def stages(self): @@ -99,6 +131,17 @@ def stages(self): self.stages_cache = _ffi_api.StateGetStages(self.state_object) return self.stages_cache + @property + def stage_tensors(self): + """ + Returns + ------- + Tensor + """ + if not self.stages_cache: + self.stages_cache = _ffi_api.StateGetStages(self.state_object) + return [stage.op.output(0) for stage in self.stages_cache] + def transform_steps_size(self): """ Return the size of transform_steps """ @@ -113,6 +156,11 @@ def reorder(self, stage_id, order): order : List[Iterator] Iterators in the expected order """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not 
isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object = _ffi_api.StateReorder(self.state_object, stage_id, order) self.clear_cache() @@ -135,6 +183,11 @@ def split(self, stage_id, it, lengths, inner_to_outer=True): res_its : List[Iterator] The splitted new Iterators """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object, res = _ffi_api.StateSplit(self.state_object, stage_id, it, lengths, inner_to_outer) self.clear_cache() @@ -158,6 +211,11 @@ def follow_split(self, stage_id, it, src_step_id, n_split): res_its : List[Iterator] The splitted new Iterators """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object, res = _ffi_api.StateFollowSplit(self.state_object, stage_id, it, src_step_id, n_split) self.clear_cache() @@ -185,6 +243,11 @@ def follow_fused_split(self, stage_id, it, src_step_ids, level, res_its : List[Iterator] The splitted new Iterators """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object, stage_id, it, src_step_ids, level, factor_or_nparts) @@ -205,6 +268,11 @@ def fuse(self, stage_id, iters): res_it : Iterator The fused Iterator """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object, res = _ffi_api.StateFuse(self.state_object, stage_id, iters) self.clear_cache() return res @@ -223,6 +291,11 @@ def vectorize(self, stage_id, it): res_it : Iterator The vectorized Iterator """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object, res = _ffi_api.StateVectorize(self.state_object, stage_id, it) self.clear_cache() return res @@ -241,6 +314,11 @@ def parallel(self, stage_id, it): res_it : Iterator The parallelized Iterator """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object, res = _ffi_api.StateParallel(self.state_object, stage_id, it) self.clear_cache() return res @@ -261,6 +339,11 @@ def unroll(self, stage_id, it, max_unroll=-1): res_it : Iterator The unrolled Iterator """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object, res = _ffi_api.StateUnroll(self.state_object, stage_id, it, max_unroll) self.clear_cache() return res @@ -290,6 +373,11 @@ def bind_thread(self, stage_id, it, thread_name): } thread_id = trans_table[thread_name] + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object, res = _ffi_api.StateBindThread(self.state_object, stage_id, it, thread_id) self.clear_cache() return res @@ -305,6 +393,15 @@ def compute_at(self, stage_id, 
target_stage_id, target_iter): target_iter : Iterator The target Iterator of compute_at """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + if isinstance(target_stage_id, tvm.te.Tensor): + target_stage_id = self.stage_id_map[target_stage_id.op] + elif not isinstance(target_stage_id, int): + raise ValueError("target_stage_id must be Tensor or Int") + self.state_object = _ffi_api.StateComputeAt(self.state_object, stage_id, target_stage_id, target_iter) self.clear_cache() @@ -316,6 +413,11 @@ def compute_root(self, stage_id): stage_id : Int The index of the stage to compute root """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object = _ffi_api.StateComputeRoot(self.state_object, stage_id) self.clear_cache() @@ -326,10 +428,15 @@ def compute_inline(self, stage_id): stage_id : Int The index of the stage to compute inline """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object = _ffi_api.StateComputeInline(self.state_object, stage_id) self.clear_cache() - def cache_read(self, stage_id, scope_name, reader_stage_ids, task_dag): + def cache_read(self, stage_id, scope_name, reader_stage_ids): """ Parameters ---------- @@ -337,37 +444,55 @@ def cache_read(self, stage_id, scope_name, reader_stage_ids, task_dag): The index of the stage to do cache_read scope_name : Str reader_stage_ids : List[Int] - task_dag : ComputeDAG Returns ------- new_stage_id : Int The added staged id """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + if isinstance(reader_stage_ids, list): + tmp_list = [] + for reader_stage_id in reader_stage_ids: + if isinstance(reader_stage_id, tvm.te.Tensor): + tmp_list.append(self.stage_id_map[reader_stage_id.op]) + elif isinstance(reader_stage_id, int): + tmp_list.append(reader_stage_id) + else: + raise ValueError("reader_stage_id must be Tensor or Int") + reader_stage_ids = tmp_list + else: + raise ValueError("reader_stage_ids must be list of Tensor or Int") + self.state_object, new_stage_id = _ffi_api.StateCacheRead(self.state_object, stage_id, scope_name, reader_stage_ids, - task_dag) - self.clear_cache() - return int(new_stage_id) + self.compute_dag) + return self.__insert_new_stage(new_stage_id) - def cache_write(self, stage_id, scope_name, task_dag): + def cache_write(self, stage_id, scope_name): """ Parameters ---------- stage_id : Int The index of the stage to do cache read scope_name : Str - task_dag : ComputeDAG Returns ------- new_stage_id : Int The added staged id """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object, new_stage_id = _ffi_api.StateCacheWrite(self.state_object, stage_id, - scope_name, task_dag) - self.clear_cache() - return int(new_stage_id) + scope_name, self.compute_dag) + return self.__insert_new_stage(new_stage_id) def pragma(self, stage_id, it, pragma_type): """ @@ -379,10 +504,15 @@ def pragma(self, stage_id, it, pragma_type): The iterator to add pragma pragma_type : Str 
""" + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object = _ffi_api.StatePragma(self.state_object, stage_id, it, pragma_type) self.clear_cache() - def rfactor(self, stage_id, it, factor_iter_id, task_dag): + def rfactor(self, stage_id, it, factor_iter_id): """ Parameters ---------- @@ -390,17 +520,20 @@ def rfactor(self, stage_id, it, factor_iter_id, task_dag): The index of the stage to do reduction factor it : Iterator factor_iter_id : Int - task_dag : ComputeDAG Returns ------- new_stage_id : Int The added staged id """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object, new_stage_id = _ffi_api.StateRfactor(self.state_object, stage_id, it, - factor_iter_id, task_dag) - self.clear_cache() - return int(new_stage_id) + factor_iter_id, self.compute_dag) + return self.__insert_new_stage(new_stage_id) def storage_align(self, stage_id, it, factor, offset): """ @@ -412,6 +545,11 @@ def storage_align(self, stage_id, it, factor, offset): factor : Int offset : Int """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object = _ffi_api.StateStorageAlign(self.state_object, stage_id, it, factor, offset) self.clear_cache() @@ -433,6 +571,11 @@ def tensorize(self, stage_id, it, ti_func_name): res_it : Iterator The tensorized Iterator """ + if isinstance(stage_id, tvm.te.Tensor): + stage_id = self.stage_id_map[stage_id.op] + elif not isinstance(stage_id, int): + raise ValueError("stage_id must be Tensor or Int") + self.state_object, res = _ffi_api.StateTensorize(self.state_object, stage_id, it, ti_func_name) diff --git a/python/tvm/ansor/serialization.py b/python/tvm/ansor/serialization.py index e11a589a7522..d9b8a2f5c075 100644 --- a/python/tvm/ansor/serialization.py +++ b/python/tvm/ansor/serialization.py @@ -76,7 +76,7 @@ def write_measure_records_to_file(filename, inputs, results): def get_states_from_measure_inputs(inputs, task): """Get states from measure inputs""" state_objects = _ffi_api.GetStatesFromMeasureInputs(inputs, task) - return [State(s) for s in state_objects] + return [State(s, task.compute_dag) for s in state_objects] def best_measure_pair_in_file(filename, workload_key=None, target=None): diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index b6e6d854e3e5..7569c91e3368 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -1063,10 +1063,6 @@ TVM_REGISTER_GLOBAL("ansor.StateGetStages").set_body_typed([](const State& state return Array(state->stages); }); -TVM_REGISTER_GLOBAL("ansor.StateGetStage").set_body_typed([](const State& state, int index) { - return state->stages[index]; -}); - TVM_REGISTER_GLOBAL("ansor.StateGetTransformStepsSize").set_body_typed([](const State& state) { return static_cast(state->transform_steps.size()); }); diff --git a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py index e23dba2aa4e3..083bd2721cb6 100644 --- a/tests/python/unittest/test_ansor_common.py +++ b/tests/python/unittest/test_ansor_common.py @@ -56,26 +56,19 @@ def conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation def get_tiled_matmul(): - dag = 
ansor.ComputeDAG(matmul_ansor_test(512, 512, 512)) + A, B, C = matmul_ansor_test(512, 512, 512) + dag = ansor.ComputeDAG([A, B, C]) s0 = dag.get_init_state() - A, B, C = 0, 1, 2 - C_global = s0.cache_write(C, "global", dag) - C += 1 - its0 = s0.split(C, s0.stages[C].iters[0], [4, 8, 8]) - its1 = s0.split(C, s0.stages[C].iters[4], [8, 4, 4]) + C_global = s0.cache_write(C, "global") + its0 = s0.split(C, s0[C].iters[0], [4, 8, 8]) + its1 = s0.split(C, s0[C].iters[4], [8, 4, 4]) s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], its1[3]]) - s0.compute_at(C_global, C, s0.stages[C].iters[3]) - s0.split(C_global, s0.stages[C_global].iters[2], [16]) - B_global = s0.cache_read(B, "global", [C_global], dag) - C += 1 - C_global += 1 - s0.compute_at(B_global, C_global, s0.stages[C_global].iters[0]) - A_global = s0.cache_read(A, "global", [C_global], dag) - B += 1 - B_global += 1 - C += 1 - C_global += 1 - s0.compute_at(A_global, C_global, s0.stages[C_global].iters[2]) + s0.compute_at(C_global, C, s0[C].iters[3]) + s0.split(C_global, s0[C_global].iters[2], [16]) + B_global = s0.cache_read(B, "global", [C_global]) + s0.compute_at(B_global, C_global, s0[C_global].iters[0]) + A_global = s0.cache_read(A, "global", [C_global]) + s0.compute_at(A_global, C_global, s0[C_global].iters[2]) return dag, s0 diff --git a/tests/python/unittest/test_ansor_compute_dag.py b/tests/python/unittest/test_ansor_compute_dag.py index b8afcc6a5b23..313dc1f89902 100644 --- a/tests/python/unittest/test_ansor_compute_dag.py +++ b/tests/python/unittest/test_ansor_compute_dag.py @@ -34,12 +34,14 @@ def test_infer_bound(): dag, s = get_tiled_matmul() s = dag.infer_bound_from_state(s) - A_global, B_global, C_global = 1, 3, 4 - assert s.stages[B_global].iters[0].range.extent == 512 - assert s.stages[B_global].iters[1].range.extent == 16 - assert s.stages[A_global].iters[0].range.extent == 1 - assert s.stages[A_global].iters[1].range.extent == 16 - assert s.stages[C_global].iters[0].range.extent == 64 + A_global = s.stage_tensors[1] + B_global = s.stage_tensors[3] + C_global = s.stage_tensors[4] + assert s[B_global].iters[0].range.extent == 512 + assert s[B_global].iters[1].range.extent == 16 + assert s[A_global].iters[0].range.extent == 1 + assert s[A_global].iters[1].range.extent == 16 + assert s[C_global].iters[0].range.extent == 64 def test_estimate_flop(): @@ -57,9 +59,8 @@ def test_lower_legalize_invalid_attach(): dag = ansor.ComputeDAG([A, B]) s = dag.get_init_state() - A, B = 0, 1 - s.compute_at(A, B, s.stages[B].iters[1]) - s.split(B, s.stages[B].iters[1], [2]) + s.compute_at(A, B, s[B].iters[1]) + s.split(B, s[B].iters[1], [2]) sch, tensors = dag.apply_steps_from_state(s) stmt = tvm.lower(sch, tensors, simple_mode=True) diff --git a/tests/python/unittest/test_ansor_feature.py b/tests/python/unittest/test_ansor_feature.py index 567fc080c6f8..bb19b84a970d 100644 --- a/tests/python/unittest/test_ansor_feature.py +++ b/tests/python/unittest/test_ansor_feature.py @@ -33,9 +33,9 @@ def fequal(a, b): def test_cpu_matmul(): dag = ansor.ComputeDAG(matmul_ansor_test(512, 512, 512)) s = dag.get_init_state() - C = 2 + C = s.stage_tensors[2] - i, j, k = s.stages[C].iters + i, j, k = s[C].iters io, ii = s.split(C, i, [16]) jo, ji = s.split(C, j, [8]) s.reorder(C, [io, jo, k, ji, ii]) diff --git a/tests/python/unittest/test_ansor_loop_state.py b/tests/python/unittest/test_ansor_loop_state.py index a2c09aafc07b..87688e276469 100644 --- a/tests/python/unittest/test_ansor_loop_state.py +++ 
b/tests/python/unittest/test_ansor_loop_state.py @@ -17,6 +17,8 @@ """Test loop state and schedule primitives""" +import numpy as np + import tvm from tvm import ansor, te import topi @@ -25,16 +27,16 @@ def test_split_fuse_reorder_annotation(): - dag = ansor.ComputeDAG(matmul_ansor_test(512, 512, 512)) + A, B, C = matmul_ansor_test(512, 512, 512) + dag = ansor.ComputeDAG([A, B, C]) s0 = dag.get_init_state() - C = 2 - i, j, k = s0.stages[C].iters + i, j, k = s0[C].iters assert i.range.extent == 512 io, ii = s0.split(C, i, [16]) - assert s0.stages[C].iters[0] == io - assert s0.stages[C].iters[1] == ii + assert s0[C].iters[0] == io + assert s0[C].iters[1] == ii assert io.range.extent == 32 assert ii.range.extent == 16 @@ -43,21 +45,21 @@ def test_split_fuse_reorder_annotation(): assert ji.range.extent == 8 s0.reorder(C, [io, jo, k, ji, ii]) - assert s0.stages[C].iters[2].range.extent == 512 + assert s0[C].iters[2].range.extent == 512 fused_it = s0.fuse(C, [io, jo]) assert fused_it.range.extent == 2048 s1 = dag.get_init_state() - i, j, _ = s1.stages[C].iters + i, j, _ = s1[C].iters i1, i2, i3 = s1.split(C, i, [8, 2]) j1, j2, j3 = s1.split(C, j, [32, 8], False) - assert s1.stages[C].iters[0].range.extent == 32 - assert s1.stages[C].iters[1].range.extent == 8 - assert s1.stages[C].iters[2].range.extent == 2 - assert s1.stages[C].iters[3].range.extent == 32 - assert s1.stages[C].iters[4].range.extent == 8 - assert s1.stages[C].iters[5].range.extent == 2 + assert s1[C].iters[0].range.extent == 32 + assert s1[C].iters[1].range.extent == 8 + assert s1[C].iters[2].range.extent == 2 + assert s1[C].iters[3].range.extent == 32 + assert s1[C].iters[4].range.extent == 8 + assert s1[C].iters[5].range.extent == 2 s1.parallel(C, j1) s1.unroll(C, j2) @@ -68,23 +70,22 @@ def test_split_fuse_reorder_annotation(): def test_follow_split_follow_fused_split(): - dag = ansor.ComputeDAG(matmul_ansor_test(512, 512, 512)) + A, B, C = matmul_ansor_test(512, 512, 512) + dag = ansor.ComputeDAG([A, B, C]) s0 = dag.get_init_state() - C = 2 - C_global = s0.cache_write(C, "global", dag) - C += 1 + C_global = s0.cache_write(C, "global") - its0 = s0.split(C, s0.stages[C].iters[0], [4, 2, 8, 4], True) + its0 = s0.split(C, s0[C].iters[0], [4, 2, 8, 4], True) split_step0 = s0.transform_steps_size() - 1 for level in range(1, 6): tmp = s0.copy() - tmp.follow_split(C_global, tmp.stages[C_global].iters[0], split_step0, level) + tmp.follow_split(C_global, tmp[C_global].iters[0], split_step0, level) for i in range(0, level): - assert tmp.stages[C].iters[i].range.extent == \ - tmp.stages[C_global].iters[i].range.extent + assert tmp[C].iters[i].range.extent == \ + tmp[C_global].iters[i].range.extent - its1 = s0.split(C, s0.stages[C].iters[5], [2, 2, 4, 8]) + its1 = s0.split(C, s0[C].iters[5], [2, 2, 4, 8]) split_step1 = s0.transform_steps_size() - 1 its = [] for i0, i1 in zip(its0, its1): @@ -92,40 +93,41 @@ def test_follow_split_follow_fused_split(): its.append(i1) s0.reorder(C, its) for i in range(0, 5): - s0.fuse(C, [s0.stages[C].iters[i], s0.stages[C].iters[i + 1]]) + s0.fuse(C, [s0[C].iters[i], s0[C].iters[i + 1]]) for level in range(0, 4): tmp = s0.copy() - tmp.follow_fused_split(C_global, tmp.stages[C_global].iters[0], + tmp.follow_fused_split(C_global, tmp[C_global].iters[0], [split_step0, split_step1], level, False) - assert tmp.stages[C].iters[level + 1].range.extent == \ - tmp.stages[C_global].iters[0].range.extent + assert tmp[C].iters[level + 1].range.extent == \ + tmp[C_global].iters[0].range.extent for level in range(0, 
4): tmp = s0.copy() - tmp.follow_fused_split(C_global, tmp.stages[C_global].iters[0], + tmp.follow_fused_split(C_global, tmp[C_global].iters[0], [split_step0, split_step1], level, True) - assert tmp.stages[C].iters[level + 1].range.extent == \ - tmp.stages[C_global].iters[1].range.extent + assert tmp[C].iters[level + 1].range.extent == \ + tmp[C_global].iters[1].range.extent def test_compute_at_root_inline(): dag = ansor.ComputeDAG(conv2d_nchw_bn_relu(1, 224, 224, 3, 64, 7, 2, 3)) + s0 = dag.get_init_state() # data, padding, kernel = 0, 1, 2 - conv = 3 + conv = s0.stage_tensors[3] # bias = 4 - bias_add = 5 + bias_add = s0.stage_tensors[5] # bn_scale = 6 - bn_mul = 7 + bn_mul = s0.stage_tensors[7] # bn_offset = 8 - bn_add, relu = 9, 10 + bn_add = s0.stage_tensors[9] + relu = s0.stage_tensors[10] - s0 = dag.get_init_state() s0.compute_inline(bn_add) s0.compute_inline(bn_mul) s0.compute_inline(bias_add) - s0.compute_at(conv, relu, s0.stages[relu].iters[2]) + s0.compute_at(conv, relu, s0[relu].iters[2]) assert str(s0) == \ "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \ "for i1 (0,3)\n" + \ @@ -186,33 +188,27 @@ def test_cache_read_write(): name='Kernel') conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation=1) relu = topi.nn.relu(conv) - out = topi.add(data, relu) + add = topi.add(data, relu) + + dag = ansor.ComputeDAG([data, kernel_data, add]) + s0 = dag.get_init_state() - dag = ansor.ComputeDAG([data, kernel_data, out]) - data, pad_temp, kernel_data, kernel_split, kernel, conv, relu, add = 0, 1, 2, 3, 4, 5, 6, 7 + pad_temp = s0.stage_tensors[1] + kernel_split = s0.stage_tensors[3] # 0: init state - s0 = dag.get_init_state() - ori_its = s0.stages[add].iters - its = s0.split(add, s0.stages[add].iters[0], [2]) + ori_its = s0[add].iters + its = s0.split(add, s0[add].iters[0], [2]) s0.reorder(add, [its[0], ori_its[1], its[1], ori_its[2], ori_its[3]]) s0.compute_inline(relu) # 1: simple cache_write with compute_at - conv_global = s0.cache_write(conv, "global", dag) - conv += 1 - relu += 1 - add += 1 - s0.compute_at(conv_global, conv, s0.stages[conv].iters[3]) + conv_global = s0.cache_write(conv, "global") + s0.compute_at(conv_global, conv, s0[conv].iters[3]) # 2: simple cache_read with compute_at - kernel_global = s0.cache_read(kernel, "global", [conv_global], dag) - conv_global += 1 - conv += 1 - relu += 1 - add += 1 - s0.compute_at(kernel_global, conv_global, - s0.stages[conv_global].iters[4]) + kernel_global = s0.cache_read(kernel, "global", [conv_global]) + s0.compute_at(kernel_global, conv_global, s0[conv_global].iters[4]) assert str(s0) == \ "Placeholder: Data, Kernel_data\n" + \ "for i0 (0,4)\n" + \ @@ -257,41 +253,14 @@ def test_cache_read_write(): # 3: two level cache_read with compute_at # preparing for GPU's shared memory & local memory - pad_temp_global = s0.cache_read(pad_temp, "global", [conv_global], dag) - kernel_data += 1 - kernel_split += 1 - kernel += 1 - kernel_global += 1 - conv_global += 1 - conv += 1 - relu += 1 - add += 1 - pad_temp_shared = s0.cache_read(pad_temp_global, "shared", [conv_global], dag) - kernel_data += 1 - kernel_split += 1 - kernel += 1 - kernel_global += 1 - conv_global += 1 - conv += 1 - relu += 1 - add += 1 - s0.compute_at(pad_temp_global, conv_global, s0.stages[conv_global].iters[2]) - s0.compute_at(pad_temp_shared, conv_global, s0.stages[conv_global].iters[4]) + pad_temp_global = s0.cache_read(pad_temp, "global", [conv_global]) + pad_temp_shared = s0.cache_read(pad_temp_global, "shared", [conv_global]) + 
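Note that this test no longer renumbers stage indices by hand after each cache_read/cache_write: the Python State wrapper now keys stages by tensor and shifts its op-to-index map whenever a stage is inserted. A simplified sketch of the remapping rule implemented by __insert_new_stage (illustrative, not the exact code):

def shift_stage_ids(stage_id_map, new_stage_id):
    # Every stage at or after the insertion point moves down one slot.
    for op, idx in stage_id_map.items():
        if idx >= new_stage_id:
            stage_id_map[op] = idx + 1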
s0.compute_at(pad_temp_global, conv_global, s0[conv_global].iters[2]) + s0.compute_at(pad_temp_shared, conv_global, s0[conv_global].iters[4]) # 4: cache_read with multi readers # This stage cannot be compute at to its consumer - data_global = s0.cache_read(data, "global", [pad_temp, add], dag) - pad_temp += 1 - pad_temp_global += 1 - pad_temp_shared += 1 - kernel_data += 1 - kernel_split += 1 - kernel += 1 - kernel_global += 1 - conv_global += 1 - conv += 1 - relu += 1 - add += 1 + s0.cache_read(data, "global", [pad_temp, add]) assert str(s0) == \ "Placeholder: Data, Kernel_data\n" + \ "for ax0 (0,4)\n" + \ @@ -364,7 +333,7 @@ def test_cache_read_write(): # Seems there's bug with the input/output tensor. Such multi outputs case # should be unusual, so we make some hack on DoCacheWrite # To be fixed in the future - s0.cache_write(kernel_split, "global", dag) + s0.cache_write(kernel_split, "global") assert str(s0) == \ "Placeholder: Data, Kernel_data\n" + \ "for ax0 (0,4)\n" + \ @@ -434,14 +403,14 @@ def test_cache_read_write(): def test_rfactor(): - dag = ansor.ComputeDAG(matmul_ansor_test(8, 8, 512)) + A, B, C = matmul_ansor_test(8, 8, 512) + dag = ansor.ComputeDAG([A, B, C]) s0 = dag.get_init_state() - C = 2 - ko, ki = s0.split(C, s0.stages[C].iters[2], [16]) + ko, ki = s0.split(C, s0[C].iters[2], [16]) s1 = s0.copy() - s1.rfactor(C, ko, 2, dag) + s1.rfactor(C, ko, 2) assert str(s1) == \ "Placeholder: A, B\n" + \ "for i (0,8)\n" + \ @@ -455,7 +424,7 @@ def test_rfactor(): " C.repl = ...\n" s2 = s0.copy() - s2.rfactor(C, ki, 2, dag) + s2.rfactor(C, ki, 2) assert str(s2) == \ "Placeholder: A, B\n" + \ "for i (0,8)\n" + \ @@ -469,6 +438,122 @@ def test_rfactor(): " C.repl = ...\n" +def vcf_init_common(): + A, B, C = matmul_ansor_test(512, 512, 512) + dag = ansor.ComputeDAG([A, B, C]) + s0 = dag.get_init_state() + B_shared = s0.cache_read(B, "shared", [C]) + B_local = s0.cache_read(B_shared, "local", [C]) + A_shared = s0.cache_read(A, "shared", [C]) + A_local = s0.cache_read(A_shared, "local", [C]) + + return A_shared, A_local, B_shared, B_local, C, dag, s0 + + +def vcf_check_common(dag, state): + s, args = dag.apply_steps_from_state(state) + # To check if every vectorize loop transforms to ramp expr successfully + # TODO(jcf94): Find a better way to process the check in AST + print(tvm.lower(s, args)) + + if tvm.context("cuda", 0).exist: + tgt = tvm.target.cuda() + mod = tvm.build(s, args, tgt) + # To check if every vectorize loop transforms to correct instruction + print(mod.imported_modules[0].get_source()) + + ctx = tvm.context("cuda", 0) + dtype = dag.tensors[0].dtype + a = tvm.nd.array(np.random.uniform(size=(512, 512)).astype(dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=(512, 512)).astype(dtype), ctx) + c = tvm.nd.array(np.zeros((512, 512), dtype=dtype), ctx) + mod(a, b, c) + tvm.testing.assert_allclose(c.asnumpy(), np.dot( + a.asnumpy(), b.asnumpy()), rtol=1e-5) + else: + print("CUDA device not found, skip this test.") + + +def test_vectorized_cooperative_fetching_x(): + A_shared, A_local, B_shared, B_local, C, dag, s0 = vcf_init_common() + + its0 = s0.split(C, s0[C].iters[0], [1, 8, 2, 4]) + its1 = s0.split(C, s0[C].iters[5], [2, 8, 2, 4]) + its2 = s0.split(C, s0[C].iters[10], [8, 8]) + s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its2[0], + its2[1], its0[3], its1[3], its2[2], its0[4], its1[4]]) + s0.fuse(C, [s0[C].iters[0], s0[C].iters[1]]) + s0.bind_thread(C, s0[C].iters[0], "blockIdx.x") + s0.fuse(C, [s0[C].iters[1], s0[C].iters[2]]) + 
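Each cache stage below is wired up with the same cooperative-fetching recipe: fuse all of its axes, split off a vectorized lane, and bind the middle axis to threadIdx.x so the whole thread block loads the tile together. A hedged helper capturing that pattern (`stage` is the cache stage's output tensor; the helper is illustrative, not part of the test):

def cooperative_fetch(state, stage, num_threads=64, vec_len=4):
    fused = state.fuse(stage, state[stage].iters[:])
    _, tx, vec = state.split(stage, fused, [num_threads, vec_len])
    state.bind_thread(stage, tx, "threadIdx.x")
    state.vectorize(stage, vec)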
s0.bind_thread(C, s0[C].iters[1], "vthread") + s0.fuse(C, [s0[C].iters[2], s0[C].iters[3]]) + s0.bind_thread(C, s0[C].iters[2], "threadIdx.x") + s0.vectorize(C, its1[4]) + + s0.compute_at(B_shared, C, s0[C].iters[3]) + fused_it = s0.fuse(B_shared, s0[B_shared].iters[:]) + its = s0.split(B_shared, fused_it, [64, 4]) + s0.bind_thread(B_shared, its[1], "threadIdx.x") + s0.vectorize(B_shared, its[2]) + s0.compute_at(B_local, C, s0[C].iters[4]) + fused_it = s0.fuse(B_local, s0[B_local].iters[:]) + its = s0.split(B_local, fused_it, [4]) + s0.vectorize(B_local, its[1]) + + s0.compute_at(A_shared, C, s0[C].iters[3]) + fused_it = s0.fuse(A_shared, s0[A_shared].iters[:]) + its = s0.split(A_shared, fused_it, [64, 4]) + s0.bind_thread(A_shared, its[1], "threadIdx.x") + s0.vectorize(A_shared, its[2]) + s0.compute_at(A_local, C, s0[C].iters[4]) + fused_it = s0.fuse(A_local, s0[A_local].iters[:]) + its = s0.split(A_local, fused_it, [4]) + s0.vectorize(A_local, its[1]) + + vcf_check_common(dag, s0) + + +def test_vectorized_cooperative_fetching_xy(): + A_shared, A_local, B_shared, B_local, C, dag, s0 = vcf_init_common() + + its0 = s0.split(C, s0[C].iters[0], [1, 8, 2, 4]) + its1 = s0.split(C, s0[C].iters[5], [2, 8, 2, 4]) + its2 = s0.split(C, s0[C].iters[10], [8, 8]) + s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its2[0], + its2[1], its0[3], its1[3], its2[2], its0[4], its1[4]]) + s0.fuse(C, [s0[C].iters[0], s0[C].iters[1]]) + s0.bind_thread(C, s0[C].iters[0], "blockIdx.x") + s0.fuse(C, [s0[C].iters[1], s0[C].iters[2]]) + s0.bind_thread(C, s0[C].iters[1], "vthread") + s0.bind_thread(C, s0[C].iters[2], "threadIdx.x") + s0.bind_thread(C, s0[C].iters[3], "threadIdx.y") + s0.vectorize(C, its1[4]) + + s0.compute_at(B_shared, C, s0[C].iters[4]) + fused_it = s0.fuse(B_shared, s0[B_shared].iters[:]) + its = s0.split(B_shared, fused_it, [8, 8, 4]) + s0.bind_thread(B_shared, its[1], "threadIdx.x") + s0.bind_thread(B_shared, its[2], "threadIdx.y") + s0.vectorize(B_shared, its[3]) + s0.compute_at(B_local, C, s0[C].iters[5]) + fused_it = s0.fuse(B_local, s0[B_local].iters[:]) + its = s0.split(B_local, fused_it, [4]) + s0.vectorize(B_local, its[1]) + + s0.compute_at(A_shared, C, s0[C].iters[4]) + fused_it = s0.fuse(A_shared, s0[A_shared].iters[:]) + its = s0.split(A_shared, fused_it, [8, 8, 4]) + s0.bind_thread(A_shared, its[1], "threadIdx.x") + s0.bind_thread(A_shared, its[2], "threadIdx.y") + s0.vectorize(A_shared, its[3]) + s0.compute_at(A_local, C, s0[C].iters[5]) + fused_it = s0.fuse(A_local, s0[A_local].iters[:]) + its = s0.split(A_local, fused_it, [4]) + s0.vectorize(A_local, its[1]) + + vcf_check_common(dag, s0) + + @tvm._ffi.register_func def test_intrin_gemv(): m = 16 @@ -495,11 +580,11 @@ def intrin_func(ins, outs): return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb}) def test_tensorize(): - dag = ansor.ComputeDAG(matmul_ansor_test(1024, 512, 64)) + A, B, C = matmul_ansor_test(1024, 512, 64) + dag = ansor.ComputeDAG([A, B, C]) s0 = dag.get_init_state() - C = 2 - its = s0.split(C, s0.stages[C].iters[1], [16]) + its = s0.split(C, s0[C].iters[1], [16]) s0.tensorize(C, its[1], "test_intrin_gemv") sch, tensors = dag.apply_steps_from_state(s0) @@ -511,4 +596,6 @@ def test_tensorize(): test_compute_at_root_inline() test_cache_read_write() test_rfactor() + test_vectorized_cooperative_fetching_x() + test_vectorized_cooperative_fetching_xy() test_tensorize() diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py 
index 839992c67e0f..9b1716175b5a 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -80,6 +80,7 @@ def test_search_basic(): t.start() t.join() + def test_search_xgb_model_rpc_runner(): measure_ctx = ansor.LocalRPCMeasureContext() search_common(seed=456787236, cost_model=ansor.XGBModel(), @@ -123,13 +124,13 @@ def apply_func1(meta_policy, state, stage_id): # Stage by stage way ret = [] if stage_id == 2: - state = ansor.loop_state.State(state) + state = ansor.loop_state.State(state, meta_policy.cur_task.compute_dag) state.split(2, state.stages[2].iters[0], [4, 4]) state.split(2, state.stages[2].iters[3], [4, 4]) ret.append([state.state_object, stage_id - 1]) elif stage_id == 1: - state = ansor.loop_state.State(state) - state.cache_read(1, "global", [2], meta_policy.cur_task.compute_dag) + state = ansor.loop_state.State(state, meta_policy.cur_task.compute_dag) + state.cache_read(1, "global", [2]) state.compute_at(2, 3, state.stages[3].iters[4]) ret.append([state.state_object, stage_id - 1]) else: @@ -139,11 +140,11 @@ def apply_func1(meta_policy, state, stage_id): def apply_func2(meta_policy, state, stage_id): # More template like way ret = [] - state = ansor.loop_state.State(state) + state = ansor.loop_state.State(state, meta_policy.cur_task.compute_dag) state.split(2, state.stages[2].iters[0], [4, 4]) state.split(2, state.stages[2].iters[3], [4, 4]) - state.cache_read(1, "global", [2], meta_policy.cur_task.compute_dag) + state.cache_read(1, "global", [2]) state.compute_at(2, 3, state.stages[3].iters[4]) ret.append([state.state_object, -1]) diff --git a/tests/python/unittest/test_ansor_vectorized_cooperative_fetching.py b/tests/python/unittest/test_ansor_vectorized_cooperative_fetching.py deleted file mode 100644 index c41abc7bcb3d..000000000000 --- a/tests/python/unittest/test_ansor_vectorized_cooperative_fetching.py +++ /dev/null @@ -1,152 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
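The file removed below is the pre-refactor version of the cooperative-fetching tests: it tracked stages by integer index and had to renumber them by hand after every cache_read/cache_write. For contrast, a minimal sketch of the refactored loop_state API these tests now use (matmul_ansor_test is the shared helper from test_ansor_common; signatures as read from this diff):

    from tvm import ansor
    from test_ansor_common import matmul_ansor_test

    # A State now keeps a reference to its ComputeDAG, so schedule primitives
    # no longer take a `dag` argument, and stages are addressed by tensor
    # handle instead of by integer id (no manual `C += 1` bookkeeping).
    A, B, C = matmul_ansor_test(512, 512, 512)
    dag = ansor.ComputeDAG([A, B, C])
    s0 = dag.get_init_state()
    B_shared = s0.cache_read(B, "shared", [C])  # was: s0.cache_read(1, "shared", [2], dag)
    ko, ki = s0.split(C, s0[C].iters[2], [16])  # s0[C] replaces s0.stages[C]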
- -""" Test for vectorized cooperative fetching """ - -import numpy as np -import tvm -from tvm import ansor, te -import topi - -from test_ansor_common import matmul_ansor_test, conv2d_nchw_bn_relu - - -def init_common(): - dag = ansor.ComputeDAG(matmul_ansor_test(512, 512, 512)) - s0 = dag.get_init_state() - A, B, C = 0, 1, 2 - B_shared = s0.cache_read(B, "shared", [C], dag) - C += 1 - B_local = s0.cache_read(B_shared, "local", [C], dag) - C += 1 - A_shared = s0.cache_read(A, "shared", [C], dag) - B += 1 - B_shared += 1 - B_local += 1 - C += 1 - A_local = s0.cache_read(A_shared, "local", [C], dag) - B += 1 - B_shared += 1 - B_local += 1 - C += 1 - - return A_shared, A_local, B_shared, B_local, C, dag, s0 - -def check_common(dag, state): - s, args = dag.apply_steps_from_state(state) - # To check if every vectorize loop transforms to ramp expr successfully - # TODO(jcf94): Find a better way to process the check in AST - print(tvm.lower(s, args)) - - if tvm.context("cuda", 0).exist: - tgt = tvm.target.cuda() - mod = tvm.build(s, args, tgt) - # To check if every vectorize loop transforms to correct instruction - print(mod.imported_modules[0].get_source()) - - ctx = tvm.context("cuda", 0) - dtype = dag.tensors[0].dtype - a = tvm.nd.array(np.random.uniform(size=(512, 512)).astype(dtype), ctx) - b = tvm.nd.array(np.random.uniform(size=(512, 512)).astype(dtype), ctx) - c = tvm.nd.array(np.zeros((512, 512), dtype=dtype), ctx) - mod(a, b, c) - tvm.testing.assert_allclose(c.asnumpy(), np.dot( - a.asnumpy(), b.asnumpy()), rtol=1e-5) - else: - print("CUDA device not found, skip this test.") - -def test_vectorized_cooperative_fetching_x(): - A_shared, A_local, B_shared, B_local, C, dag, s0 = init_common() - - its0 = s0.split(C, s0.stages[C].iters[0], [1, 8, 2, 4]) - its1 = s0.split(C, s0.stages[C].iters[5], [2, 8, 2, 4]) - its2 = s0.split(C, s0.stages[C].iters[10], [8, 8]) - s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its2[0], - its2[1], its0[3], its1[3], its2[2], its0[4], its1[4]]) - s0.fuse(C, [s0.stages[C].iters[0], s0.stages[C].iters[1]]) - s0.bind_thread(C, s0.stages[C].iters[0], "blockIdx.x") - s0.fuse(C, [s0.stages[C].iters[1], s0.stages[C].iters[2]]) - s0.bind_thread(C, s0.stages[C].iters[1], "vthread") - s0.fuse(C, [s0.stages[C].iters[2], s0.stages[C].iters[3]]) - s0.bind_thread(C, s0.stages[C].iters[2], "threadIdx.x") - s0.vectorize(C, its1[4]) - - s0.compute_at(B_shared, C, s0.stages[C].iters[3]) - fused_it = s0.fuse(B_shared, s0.stages[B_shared].iters[:]) - its = s0.split(B_shared, fused_it, [64, 4]) - s0.bind_thread(B_shared, its[1], "threadIdx.x") - s0.vectorize(B_shared, its[2]) - s0.compute_at(B_local, C, s0.stages[C].iters[4]) - fused_it = s0.fuse(B_local, s0.stages[B_local].iters[:]) - its = s0.split(B_local, fused_it, [4]) - s0.vectorize(B_local, its[1]) - - s0.compute_at(A_shared, C, s0.stages[C].iters[3]) - fused_it = s0.fuse(A_shared, s0.stages[A_shared].iters[:]) - its = s0.split(A_shared, fused_it, [64, 4]) - s0.bind_thread(A_shared, its[1], "threadIdx.x") - s0.vectorize(A_shared, its[2]) - s0.compute_at(A_local, C, s0.stages[C].iters[4]) - fused_it = s0.fuse(A_local, s0.stages[A_local].iters[:]) - its = s0.split(A_local, fused_it, [4]) - s0.vectorize(A_local, its[1]) - - check_common(dag, s0) - -def test_vectorized_cooperative_fetching_xy(): - A_shared, A_local, B_shared, B_local, C, dag, s0 = init_common() - - its0 = s0.split(C, s0.stages[C].iters[0], [1, 8, 2, 4]) - its1 = s0.split(C, s0.stages[C].iters[5], [2, 8, 2, 4]) - its2 = s0.split(C, 
s0.stages[C].iters[10], [8, 8]) - s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its2[0], - its2[1], its0[3], its1[3], its2[2], its0[4], its1[4]]) - s0.fuse(C, [s0.stages[C].iters[0], s0.stages[C].iters[1]]) - s0.bind_thread(C, s0.stages[C].iters[0], "blockIdx.x") - s0.fuse(C, [s0.stages[C].iters[1], s0.stages[C].iters[2]]) - s0.bind_thread(C, s0.stages[C].iters[1], "vthread") - s0.bind_thread(C, s0.stages[C].iters[2], "threadIdx.x") - s0.bind_thread(C, s0.stages[C].iters[3], "threadIdx.y") - s0.vectorize(C, its1[4]) - - s0.compute_at(B_shared, C, s0.stages[C].iters[4]) - fused_it = s0.fuse(B_shared, s0.stages[B_shared].iters[:]) - its = s0.split(B_shared, fused_it, [8, 8, 4]) - s0.bind_thread(B_shared, its[1], "threadIdx.x") - s0.bind_thread(B_shared, its[2], "threadIdx.y") - s0.vectorize(B_shared, its[3]) - s0.compute_at(B_local, C, s0.stages[C].iters[5]) - fused_it = s0.fuse(B_local, s0.stages[B_local].iters[:]) - its = s0.split(B_local, fused_it, [4]) - s0.vectorize(B_local, its[1]) - - s0.compute_at(A_shared, C, s0.stages[C].iters[4]) - fused_it = s0.fuse(A_shared, s0.stages[A_shared].iters[:]) - its = s0.split(A_shared, fused_it, [8, 8, 4]) - s0.bind_thread(A_shared, its[1], "threadIdx.x") - s0.bind_thread(A_shared, its[2], "threadIdx.y") - s0.vectorize(A_shared, its[3]) - s0.compute_at(A_local, C, s0.stages[C].iters[5]) - fused_it = s0.fuse(A_local, s0.stages[A_local].iters[:]) - its = s0.split(A_local, fused_it, [4]) - s0.vectorize(A_local, its[1]) - - check_common(dag, s0) - -if __name__ == "__main__": - test_vectorized_cooperative_fetching_x() - test_vectorized_cooperative_fetching_xy() From 36cd9ef474664490c9736c43282912df4c48c257 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Minmin=20Sun=20=28=E5=AD=99=E6=95=8F=E6=95=8F=29?= Date: Fri, 19 Jun 2020 18:24:30 +0800 Subject: [PATCH 29/45] kernel layout rewrite (#28) * kernel layout rewrite * remove some hacks * add defuse_ops pass and move kernel_layout_rewrite pass after fuse_ops pass * set TVM_RELAY_DISABLE_BUILD_CACHE for task extraction and prepare_layout_rewrite --- include/tvm/relay/attrs/transform.h | 13 + include/tvm/relay/transform.h | 14 + python/tvm/ansor/__init__.py | 2 +- python/tvm/ansor/compute_dag.py | 9 +- python/tvm/ansor/measure.py | 1 - python/tvm/ansor/relay_integration.py | 7 +- python/tvm/ansor/topi_integration.py | 13 +- python/tvm/relay/op/_transform.py | 2 + python/tvm/relay/op/op_attrs.py | 3 + python/tvm/relay/op/transform.py | 21 + python/tvm/relay/testing/dqn.py | 25 +- python/tvm/relay/testing/resnet.py | 4 + python/tvm/te/tensor.py | 6 +- scripts/tune_network.py | 9 +- src/ansor/compute_dag.cc | 725 ++++++++++-------- src/ansor/compute_dag.h | 2 +- src/relay/analysis/type_solver.cc | 1 + src/relay/backend/build_module.cc | 13 + src/relay/backend/compile_engine.cc | 5 + src/relay/backend/compile_engine.h | 3 + src/relay/transforms/defuse_ops.cc | 98 +++ .../transforms/kernel_layout_transform.cc | 63 ++ .../transforms/kernel_layout_transform.h | 75 ++ src/relay/transforms/pattern_util.h | 2 + topi/python/topi/nn/conv2d.py | 24 +- 25 files changed, 787 insertions(+), 353 deletions(-) create mode 100644 src/relay/transforms/defuse_ops.cc create mode 100644 src/relay/transforms/kernel_layout_transform.cc create mode 100644 src/relay/transforms/kernel_layout_transform.h diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 750a8a43163c..95476ed61bdd 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -296,6 
+296,19 @@ struct LayoutTransformAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes for KernelLayoutTransform operator */ +struct KernelLayoutTransformAttrs : public tvm::AttrsNode { + std::string src_layout; + std::string dst_layout; + + TVM_DECLARE_ATTRS(KernelLayoutTransformAttrs, "relay.attrs.KernelLayoutTransformAttrs") { + TVM_ATTR_FIELD(src_layout) + .describe("The source layout of the tensor. (e.g. 1N32C112H112W)"); + TVM_ATTR_FIELD(dst_layout) + .describe("The destination layout of the tensor. (e.g. 1N2C112H112W16c)"); + } +}; + /*! \brief Attributes for ShapeOf operator */ struct ShapeOfAttrs : public tvm::AttrsNode { DataType dtype; diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 1b8b31aee5d1..5f5d9b643633 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -277,6 +277,20 @@ TVM_DLL Pass CanonicalizeOps(); */ TVM_DLL Pass AlterOpLayout(); +/*! + * \brief Alternate the layouts of kernels. + * + * \return The pass. + */ +TVM_DLL Pass KernelLayoutTransform(); + +/*! + * \brief The reverse of FuseOps. + * + * \return The pass. + */ +TVM_DLL Pass DeFuseOps(); + /*! * \brief Given a dest layout, this pass transforms the expr such that most of the ops input data * layout is changed to the dest layout. In ideal situation, there are only 2 layout transforms, one diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index 6ea8a0ce904f..b43b21a60144 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -44,4 +44,4 @@ FallbackContext, clear_fallback_cache, ApplyGraphBest, BlockingEmptyContext from .topi_integration import register_topi_schedule, TaskExtractEnv from .relay_integration import extract_from_program, extract_from_multiple_program, \ - finish_layout_rewrite + finish_layout_rewrite, prepare_layout_rewrite diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index 6d82942aa744..c54c14ec123a 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -64,12 +64,17 @@ def apply_steps_from_state(self, state, layout_rewrite_level=None): args : List[Tensor] """ if isinstance(state, State): - return _ffi_api.ComputeDAGApplyStepsFromState(self, state.state_object) + return _ffi_api.ComputeDAGApplyStepsFromState(self, state.state_object, + layout_rewrite_level) elif isinstance(state, StateObject): - return _ffi_api.ComputeDAGApplyStepsFromState(self, state) + return _ffi_api.ComputeDAGApplyStepsFromState(self, state, + layout_rewrite_level) else: raise ValueError("The input must be a State or StateObject") + def rewrite_layout_from_state(self, state: State): + return _ffi_api.ComputeDAGRewriteLayoutFromState(self, state) + def print_python_code_from_state(self, state): """ Parameters diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index b82327ec67c4..8b38f91647b2 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -534,4 +534,3 @@ def timed_func(inp, build_res): print("") return measure_results - diff --git a/python/tvm/ansor/relay_integration.py b/python/tvm/ansor/relay_integration.py index 348828eec4b4..383471ee060d 100644 --- a/python/tvm/ansor/relay_integration.py +++ b/python/tvm/ansor/relay_integration.py @@ -54,7 +54,7 @@ def _lower(mod, # If failed to compile, then fallback to use VM compiler. # TODO: Currently VM compiler is likely to stack overflow for large models. 
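The compute_dag.py hunk above makes the layout rewrite level an explicit argument of apply_steps_from_state and adds rewrite_layout_from_state. A usage sketch; the LayoutRewriteLevel spelling is taken from scripts/tune_network.py later in this patch, so treat the exact enum name as an assumption here:

    # Default: no kernel layout rewrite (kNoRewrite on the C++ side).
    sch, args = dag.apply_steps_from_state(state)

    # Explicit level, e.g. rewriting both the placeholder and the compute body:
    sch, args = dag.apply_steps_from_state(
        state, layout_rewrite_level=ansor.LayoutRewriteLevel.BOTH_REWRITE)

    # Rewrite only the "layout free" placeholders of the DAG itself, as the
    # topi integration does when replaying a tuned state:
    new_dag = dag.rewrite_layout_from_state(state)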
try: - with relay.build_config(opt_level=3): + with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): opt_mod, _ = relay.optimize(mod, target, params) grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target) grc.codegen(opt_mod["main"]) @@ -191,7 +191,7 @@ def prepare_layout_rewrite(mod, params, ops, target): """Prepare for kernel layout rewrite. This function will write layout infos to a global static variable, then these layout info will be used by a relay pass `kernel_layout_transform`. """ - from .. import relay + from tvm import relay env = TaskExtractEnv.get(do_layout_rewrite=True) @@ -203,9 +203,8 @@ def prepare_layout_rewrite(mod, params, ops, target): else: warnings.warn("Op %s is not tunable, ignored." % op_name) + env.reset(topi_scheds) with env: - env.reset(topi_scheds) - # wrap build call in thread to avoid multiprocessing problems build_thread = threading.Thread(target=_lower, args=(mod, target, params)) diff --git a/python/tvm/ansor/topi_integration.py b/python/tvm/ansor/topi_integration.py index b4c15f74ea44..77def00cf9ec 100644 --- a/python/tvm/ansor/topi_integration.py +++ b/python/tvm/ansor/topi_integration.py @@ -26,14 +26,17 @@ See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage. """ +import os +import json import tvm.te._ffi_api from tvm import target as _target from tvm.te import tensor from tvm.te.tensor import PlaceholderOp, ComputeOp -from .dispatcher import DispatchContext +from .dispatcher import DispatchContext, BlockingEmptyContext from .workload_registry import register_auto_scheduler_workload_bufs, \ make_workload_key_bufs, compute_dag_hash +from .compute_dag import ComputeDAG def traverse_to_get_io_tensors(outs): layout_free_ops = [] @@ -77,11 +80,14 @@ def __init__(self, do_layout_rewrite=False): def __enter__(self): self.tracing = True self.wkl_key_collection = {} + self.relay_disable_build_cache_ = os.environ.get("TVM_RELAY_DISABLE_BUILD_CACHE", "false") + os.environ["TVM_RELAY_DISABLE_BUILD_CACHE"] = "true" return self def __exit__(self, exc_type, exc_val, exc_tb): self.tracing = False + os.environ["TVM_RELAY_DISABLE_BUILD_CACHE"] = self.relay_disable_build_cache_ def reset(self, wanted_relay_ops=None): """Reset task collections @@ -144,7 +150,7 @@ def get(do_layout_rewrite=False): The single instance of TaskExtractEnv """ if not TaskExtractEnv.current: - TaskExtractEnv.current = TaskExtractEnv() + TaskExtractEnv.current = TaskExtractEnv(do_layout_rewrite) else: TaskExtractEnv.current.do_layout_rewrite = do_layout_rewrite return TaskExtractEnv.current @@ -188,7 +194,7 @@ def wrapper(outs, *args, **kwargs): # Rewrite the dag and update the transform history for # the new dag in DispatchContext dispatch_ctx = DispatchContext.current - tgt = _target.current_target() + tgt = _target.Target.current() state = dispatch_ctx.query(tgt, key) dag = ComputeDAG(outs) new_dag = dag.rewrite_layout_from_state(state) @@ -199,7 +205,6 @@ def wrapper(outs, *args, **kwargs): task_env.layout_rewrite_success_ct += 1 # Call schedule_func under FallbackContext() to avoid layout rewrite - tgt = _target.Target.current() cfg = BlockingEmptyContext().query(tgt, key) return topi_schedule(cfg, outs) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index d104c1b1c2f8..41bd10cabe3e 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -74,6 +74,8 @@ def compute_strided_set(attrs, inputs, output_type): # layout_transform _reg.register_injective_schedule("layout_transform") 
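Earlier in this hunk, TaskExtractEnv.__enter__ and __exit__ save and restore TVM_RELAY_DISABLE_BUILD_CACHE (read by CCacheKey on the C++ side) so tracing cannot leak the setting into later builds. The pattern in isolation, as a standalone sketch rather than the class itself:

    import os

    class DisableRelayBuildCache:
        """Sketch of the save/restore done by TaskExtractEnv."""
        def __enter__(self):
            # Remember the previous value so an outer setting survives.
            self._saved = os.environ.get("TVM_RELAY_DISABLE_BUILD_CACHE", "false")
            os.environ["TVM_RELAY_DISABLE_BUILD_CACHE"] = "true"
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            os.environ["TVM_RELAY_DISABLE_BUILD_CACHE"] = self._saved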
_reg.register_pattern("layout_transform", OpPattern.INJECTIVE) +_reg.register_injective_schedule("kernel_layout_transform") +_reg.register_pattern("kernel_layout_transform", OpPattern.INJECTIVE) # argwhere @_reg.register_compute("argwhere") diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 486d63c36ff0..58b9269a4c48 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -261,6 +261,9 @@ class ClipAttrs(Attrs): class LayoutTransformAttrs(Attrs): """Attributes for transform.layout_transform""" +@tvm._ffi.register_object("relay.attrs.KernelLayoutTransformAttrs") +class KernelLayoutTransformAttrs(Attrs): + """Attributes for transform.kernel_layout_transform""" @tvm._ffi.register_object("relay.attrs.ShapeOfAttrs") class ShapeOfAttrs(Attrs): diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index a37226ea4f58..f2fa2b5f5b90 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -815,6 +815,27 @@ def layout_transform(data, src_layout, dst_layout): """ return _make.layout_transform(data, src_layout, dst_layout) +def kernel_layout_transform(data, src_layout, dst_layout): + """Transform the layout of a kernel + + Parameters + ---------- + data : relay.Expr + The source tensor to be transformed + + src_layout: str + The source layout. (e.g 1N32C112H112W) + + dst_layout: str + The destination layout. (e.g. 1N2C112H112W16c) + + Returns + ------- + ret : relay.Expr + The transformed tensor. + """ + return _make.kernel_layout_transform(data, src_layout, dst_layout) + def reverse_reshape(data, newshape): """Reshapes the input array where the special values are inferred from diff --git a/python/tvm/relay/testing/dqn.py b/python/tvm/relay/testing/dqn.py index 10da37001f12..b65e0ad5cae9 100644 --- a/python/tvm/relay/testing/dqn.py +++ b/python/tvm/relay/testing/dqn.py @@ -26,27 +26,32 @@ from . 
import layers from .init import create_workload -def get_net(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32"): +def get_net(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32", layout="NCHW"): """get symbol of nature dqn""" data_shape = (batch_size,) + image_shape data = relay.var("data", shape=data_shape, dtype=dtype) + bias_axis = layout.index('C') + conv1_bias = relay.var("conv1_bias") conv1 = layers.conv2d(data, kernel_size=(8, 8), strides=(4, 4), padding=(0, 0), - channels=32, name="conv1") - conv1 = relay.nn.bias_add(conv1, conv1_bias) + channels=32, name="conv1", data_layout=layout, + kernel_layout=layers.conv_kernel_layout(layout)) + conv1 = relay.nn.bias_add(conv1, conv1_bias, bias_axis) relu1 = relay.nn.relu(conv1) conv2_bias = relay.var("conv2_bias") conv2 = layers.conv2d(relu1, kernel_size=(4, 4), strides=(2, 2), padding=(0, 0), - channels=64, name="conv2") - conv2 = relay.nn.bias_add(conv2, conv2_bias) + channels=64, name="conv2", data_layout=layout, + kernel_layout=layers.conv_kernel_layout(layout)) + conv2 = relay.nn.bias_add(conv2, conv2_bias, bias_axis) relu2 = relay.nn.relu(conv2) conv3_bias = relay.var("conv3_bias") conv3 = layers.conv2d(relu2, kernel_size=(3, 3), strides=(1, 1), padding=(0, 0), - channels=64, name="conv3") - conv3 = relay.nn.bias_add(conv3, conv3_bias) + channels=64, name="conv3", data_layout=layout, + kernel_layout=layers.conv_kernel_layout(layout)) + conv3 = relay.nn.bias_add(conv3, conv3_bias, bias_axis) relu3 = relay.nn.relu(conv3) bf1 = relay.nn.batch_flatten(relu3) @@ -58,7 +63,7 @@ def get_net(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32" return relay.Function(args, dense2) -def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32"): +def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32", layout="NCHW"): """Get benchmark workload for a Deep Q Network Parameters ---------- @@ -72,10 +77,10 @@ def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="flo The data type Returns ------- - mod : tvm.IRModule + mod : tvm.relay.Module The relay module that contains a DQN network. params : dict of str to NDArray The parameters. 
""" - net = get_net(batch_size, num_actions=num_actions, image_shape=image_shape, dtype=dtype) + net = get_net(batch_size, num_actions=num_actions, image_shape=image_shape, dtype=dtype, layout=layout) return create_workload(net) diff --git a/python/tvm/relay/testing/resnet.py b/python/tvm/relay/testing/resnet.py index b431dd096f9d..8633879465bd 100644 --- a/python/tvm/relay/testing/resnet.py +++ b/python/tvm/relay/testing/resnet.py @@ -162,6 +162,8 @@ def resnet(units, data = relay.var("data", shape=data_shape, dtype=dtype) data = layers.batch_norm_infer(data=data, epsilon=2e-5, scale=False, name='bn_data') (_, _, height, _) = data_shape + if layout == "NHWC": + (_, height, _, _) = data_shape if height <= 32: # such as cifar10 body = layers.conv2d( data=data, channels=filter_list[0], kernel_size=(3, 3), @@ -209,6 +211,8 @@ def get_net(batch_size, Original author Wei Wu """ (_, height, _) = image_shape + if layout == "NHWC": + (height, _, _) = image_shape data_shape = (batch_size,) + image_shape if height <= 28: num_stages = 3 diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py index 7d73bf42ab7d..6539aabaa48f 100644 --- a/python/tvm/te/tensor.py +++ b/python/tvm/te/tensor.py @@ -57,8 +57,10 @@ class Tensor(DataProducer, _expr.ExprOp): def __call__(self, *indices): ndim = self.ndim - if len(indices) != ndim: - raise ValueError("Need to provide %d index in tensor slice" % ndim) + # After ansor kernel layout rewrite, len(indices) <= ndim, + # and the indices will get modified by Ansor during schedule generation. + # if len(indices) != ndim: + # raise ValueError("Need to provide %d index in tensor slice" % ndim) indices = convert_to_object(indices) args = [] for x in indices: diff --git a/scripts/tune_network.py b/scripts/tune_network.py index 5e5a337c7bce..dc17f407d003 100644 --- a/scripts/tune_network.py +++ b/scripts/tune_network.py @@ -49,9 +49,10 @@ def get_network(name, model_path, batch_size, layout): input_shape = (batch_size, 100) mod, params = relay.testing.dcgan.get_workload(batch_size=batch_size) elif name == 'dqn': - image_shape = (4, 84, 84) + layout = "NHWC" + image_shape = (84, 84, 4) input_shape = (batch_size, *image_shape) - mod, params = relay.testing.dqn.get_workload(batch_size=batch_size, image_shape=image_shape, dtype=dtype) + mod, params = relay.testing.dqn.get_workload(batch_size=batch_size, image_shape=image_shape, dtype=dtype, layout=layout) elif name == 'mobilenet': image_shape = (224, 224, 3) if layout == 'NHWC' else (3, 224, 224) input_shape = (batch_size, *image_shape) @@ -229,7 +230,7 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, if measure_ctx: del measure_ctx - kernel_layout_rewrite = False + kernel_layout_rewrite = False # Compile graph with best states found by auto-scheduler print("=============== Compile ===============") @@ -245,7 +246,7 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, ansor.LayoutRewriteLevel.BOTH_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE ansor.LayoutRewriteLevel.COMPUTE_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE - with relay.build_config(opt_level=3): + with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): graph, lib, opt_params = relay.build_module.build( mod, target=target, params=params) diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 5ca0c8503662..fec301dc54bc 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -37,8 +37,8 @@ #include #include #include "transform_step.h" -#include "utils.h" 
-// #include "../relay/pass/kernel_layout_transform.h" +#include "search_policy/utils.h" +#include "../relay/transforms/kernel_layout_transform.h" namespace tvm { namespace ansor { @@ -595,325 +595,383 @@ std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); } -// class IndexRewriter : public ExprMutator { -// public: -// IndexRewriter(const OperationMap >& placeholder_new_names, -// const OperationMap >& placeholder_new_shapes): -// placeholder_new_names_(placeholder_new_names), -// placeholder_new_shapes_(placeholder_new_shapes) {} - -// Expr Mutate_(const Call* op, const Expr& e) { -// Expr op_ = IRMutator::Mutate_(op, e); - -// const Call* call = op_.as(); - -// if (call->call_type == Call::CallType::Halide) { -// Tensor t = Downcast(call->func).output(call->value_index); -// auto it = placeholder_new_names_.find(t->op); -// if (it != placeholder_new_names_.end()) { -// const std::vector& new_names = it->second; -// const Array& new_shape = placeholder_new_shapes_.at(t->op); -// std::unordered_map name_to_arg; -// for (const auto& arg : call->args) { -// std::string axis_name; -// if (const auto* pimm = arg.as()) { -// CHECK_EQ(pimm->value, 0); -// axis_name = "IntImm"; -// } else { -// axis_name = BaseName(CleanName(Downcast(arg)->name_hint)); -// CHECK_EQ(name_to_arg.count(axis_name), 0); -// name_to_arg[axis_name] = arg; -// } -// } - -// std::unordered_map div_factors; -// std::vector r_new_args; -// for (int i = new_names.size() - 1; i >= 0; --i) { -// auto ori_iter_name = new_names[i]; -// auto name_it = name_to_arg.find(ori_iter_name); -// CHECK(name_it != name_to_arg.end()); -// Expr ori_arg = name_it->second; - -// Expr mod_factor = new_shape[i]; - -// Expr div_factor = 1; -// if (div_factors.count(ori_iter_name)) { -// div_factor = div_factors[ori_iter_name]; -// } -// div_factors[ori_iter_name] = div_factor * new_shape[i]; - -// Expr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor); - -// r_new_args.push_back(new_arg); -// } - -// Array new_args(std::make_move_iterator(r_new_args.rbegin()), -// std::make_move_iterator(r_new_args.rend())); - -// return Call::make(call->type, call->name, new_args, call->call_type, -// call->func, call->value_index); -// } -// } -// return op_; -// } - -// private: -// const OperationMap >& placeholder_new_names_; -// const OperationMap >& placeholder_new_shapes_; -// }; - -// // TODO(minminsun): spill out new functions -// void ComputeDAG::RewriteLayout( -// const std::vector &transform_steps, LayoutRewriteLevel layout_rewrite_level) const { -// ComputeDAGNode* pdag = const_cast(this)->CopyOnWrite(); -// const State& state = ReplayAndInferBound(transform_steps); - -// OperationMap > placeholder_new_names; -// OperationMap > placeholder_new_shapes; -// int stage_id = -1; -// for (const auto& stage : state->stages) { -// stage_id += 1; -// const Operation& op = stage->op; -// if (op->IsInstance()) { -// const Map& attrs = op->attrs; -// if (attrs.count(layout_free_placeholders_key)) { -// const ObjectRef& attr_value = attrs[layout_free_placeholders_key]; -// Array placeholders = Downcast>(attr_value); -// for (auto& placeholder : placeholders) { -// const auto placeholder_op = placeholder->op; - -// // Check whether this placeholder has already been handled -// if (placeholder_new_names.count(placeholder_op)) { -// continue; -// } - -// // skip the op that is not direct consumer of this placeholder, -// // mostly due to cache read/write. 
-// bool direct_consumer = false; -// for (auto& t : op->InputTensors()) { -// if (t->op == placeholder_op) { -// direct_consumer = true; -// break; -// } -// } -// if (!direct_consumer) { -// continue; -// } - -// std::set placeholder_axis_names; -// TensorAccessExtractor extractor; -// for (const auto& exp : op.as()->body) { -// extractor.Extract(exp); -// } -// bool rewrite_placeholder = (layout_rewrite_level == kPlaceholderRewrite || -// layout_rewrite_level == kBothRewrite); -// bool rewrite_body = (layout_rewrite_level == kComputeRewrite || -// layout_rewrite_level == kBothRewrite); -// std::ostringstream os; - -// uint i = 0; -// if (extractor.buf_accesses.count(placeholder_op)) { -// for (const auto& ev : extractor.buf_accesses[placeholder_op]) { -// for (const auto& e : ev) { -// // TODO(minminsun): check whether the extents match the shape of placeholder -// std::string axis_name; -// if (const auto* pimm = e.as()) { -// CHECK_EQ(pimm->value, 0); -// // CHECK_EQ(placeholder->shape[i].as()->value, 1); -// axis_name = "IntImm"; -// } else { -// axis_name = BaseName(CleanName(Downcast(e)->name_hint)); -// } - -// placeholder_axis_names.insert(axis_name); -// if (rewrite_placeholder) { -// os << placeholder->shape[i++] << axis_name; -// } -// } -// } - -// if (rewrite_placeholder) { -// CHECK_EQ(placeholder_axis_names.size(), placeholder->shape.size()); -// std::string ori_layout = os.str(); -// os.str(""); -// ::tvm::relay::KernelLayoutVisitor::global_ori_layouts_queue.push_back(ori_layout); -// } -// } - -// std::vector stage_iters; - -// auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id); -// int attach_pos = -1; -// size_t iters_before_attach = 0; -// if (attach_it != state->attach_map->stage_to_attach_iter.end()) { -// auto attach = attach_it->second; -// const auto& attach_stage = state->stages[attach.first]; -// attach_pos = attach.second; -// stage_iters.insert(stage_iters.end(), -// attach_stage->iters.begin(), -// attach_stage->iters.begin() + attach_pos + 1); -// } - -// stage_iters.insert(stage_iters.end(), stage->iters.begin(), stage->iters.end()); - -// std::vector iters; -// for (size_t i = 0; i < stage_iters.size(); ++i) { -// const auto& iter = stage_iters[i]; -// if (iter->ori_iters.empty()) { -// iters.push_back(iter); -// } else { -// for (const Iterator& ori_iter : iter->ori_iters) { -// iters.push_back(ori_iter); -// } -// } -// if (static_cast(i) == attach_pos) { -// iters_before_attach = iters.size(); -// } -// } - -// std::vector new_names; -// Array new_shape; -// std::vector new_axis_names; -// for (const Iterator& iter : iters) { -// std::set ori_iter_names; -// ExtractOriginalIterators(iter->name, &ori_iter_names); -// // fused iters have been replaced with iter->ori_iters. -// // So there should be only one ori iter name extracted from iter->name. 
-// CHECK_EQ(ori_iter_names.size(), 1); -// auto ori_iter_name = BaseName(*ori_iter_names.begin()); -// new_axis_names.push_back(ori_iter_name); -// } -// for (size_t i = 0; i < new_axis_names.size(); ++i) { -// auto iter = iters[i]; -// std::string ori_iter_name; -// if (i < iters_before_attach) { -// ori_iter_name = new_axis_names[i + iters_before_attach]; -// } else { -// ori_iter_name = new_axis_names[i]; -// } -// if (placeholder_axis_names.count(ori_iter_name)) { -// os << iter->range->extent << ori_iter_name; -// new_names.push_back(ori_iter_name); -// new_shape.push_back(iter->range->extent); -// } -// } -// std::string new_layout = os.str(); -// os.str(""); -// ::tvm::relay::KernelLayoutVisitor::global_new_layouts_queue.push_back(new_layout); -// placeholder_new_names[placeholder_op] = new_names; -// placeholder_new_shapes[placeholder_op] = new_shape; - -// Array old_ops = pdag->ops; -// ArrayNode* pops = pdag->ops.CopyOnWrite(); - -// // Create new placeholder -// Operation new_placeholder_op; -// if (rewrite_placeholder) { -// new_placeholder_op = -// te::PlaceholderOpNode::make(placeholder_op->name, -// new_shape, -// placeholder_op.as()->dtype); -// } else { -// new_placeholder_op = placeholder_op; -// } - -// Operation new_compute_op, old_compute_op; -// if (rewrite_body) { -// Array new_body; -// IndexRewriter index_rewriter(placeholder_new_names, -// placeholder_new_shapes); -// for (auto& op : old_ops) { -// if (auto* pop = op.as()) { -// bool need_update = false; -// for (auto& t : op->InputTensors()) { -// if (t->op == placeholder_op) { -// need_update = true; -// break; -// } -// } -// if (need_update) { -// for (auto& body : pop->body) { -// new_body.push_back(index_rewriter.Mutate(body)); -// } -// old_compute_op = op; -// CHECK(!new_compute_op.defined()); -// new_compute_op = ComputeOpNode::make( -// pop->name, pop->tag, pop->attrs, pop->axis, new_body); -// } -// } -// } -// } - -// // construct the map from old_op to new_op -// std::unordered_map updated_ops; -// for (size_t i = 0; i < old_ops.size(); ++i) { -// auto old_op = old_ops[i]; -// if (rewrite_placeholder && old_op == placeholder_op) { -// pops->data[i] = new_placeholder_op; -// updated_ops[placeholder_op] = new_placeholder_op; -// } else if (rewrite_body && old_op == old_compute_op) { -// pops->data[i] = new_compute_op; -// updated_ops[old_compute_op] = new_compute_op; -// } else { -// pops->data[i] = old_op; -// } -// } - -// // Because ops is sorted in topo-order, only do one pass linear scan here. 
-// for (size_t i = 0; i < pops->data.size(); ++i) { -// auto old_op = Downcast(pops->data[i]); -// if (auto* pop = old_op.as()) { -// auto inputs = pop->InputTensors(); -// std::unordered_map rmap; -// for (auto input : inputs) { -// auto it = updated_ops.find(input->op); -// Operation new_op; -// while (it != updated_ops.end()) { -// new_op = it->second; -// it = updated_ops.find(new_op); -// } -// if (new_op.defined()) { -// int index = input->value_index; -// rmap[input] = new_op.output(index); -// } -// } -// if (!rmap.empty()) { -// Operation new_op = pop->ReplaceInputs(old_op, rmap); -// updated_ops[old_op] = new_op; -// pops->data[i] = new_op; -// } -// } -// } - -// pdag->init_state = StateNode::make(pdag->ops); - -// Array old_tensors = pdag->tensors; -// ArrayNode* ptensors = pdag->tensors.CopyOnWrite(); - -// for (size_t i = 0; i < old_tensors.size(); ++i) { -// const auto& old_tensor = old_tensors[i]; -// auto it = updated_ops.find(old_tensor->op); -// Operation new_op; -// while (it != updated_ops.end()) { -// new_op = it->second; -// it = updated_ops.find(new_op); -// } -// if (new_op.defined()) { -// if (layout_rewrite_level == kBothRewrite) { -// auto index = old_tensor->value_index; -// ptensors->data[i] = new_op.output(index); -// } else if (layout_rewrite_level == kComputeRewrite) { -// TensorNode* old_tensor_node = -// const_cast(old_tensor.as()); -// old_tensor_node->op = new_op; -// } -// } -// } -// } // end for placeholder -// } -// } -// } // end for stage -// } +class IndexRewriter : public StmtExprMutator { + public: + IndexRewriter(const OperationMap<std::vector<std::string> >& placeholder_new_names, + const OperationMap<Array<PrimExpr> >& placeholder_new_shapes): + placeholder_new_names_(placeholder_new_names), + placeholder_new_shapes_(placeholder_new_shapes) {} + + PrimExpr Rewrite(PrimExpr expr) { + return this->VisitExpr(expr); + } + + PrimExpr VisitExpr_(const ProducerLoadNode* op) final { + te::Tensor t = Downcast<te::Tensor>(op->producer); + auto it = placeholder_new_names_.find(t->op); + if (it != placeholder_new_names_.end()) { + const std::vector<std::string>& new_names = it->second; + const Array<PrimExpr>& new_shape = placeholder_new_shapes_.at(t->op); + std::unordered_map<std::string, PrimExpr> name_to_arg; + for (const auto& arg : op->indices) { + std::string axis_name; + if (const auto* pimm = arg.as<IntImmNode>()) { + CHECK_EQ(pimm->value, 0); + axis_name = "IntImm"; + } else { + axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint)); + CHECK_EQ(name_to_arg.count(axis_name), 0); + name_to_arg[axis_name] = arg; + } + } + + std::unordered_map<std::string, PrimExpr> div_factors; + std::vector<PrimExpr> r_new_args; + for (int i = new_names.size() - 1; i >= 0; --i) { + auto ori_iter_name = new_names[i]; + auto name_it = name_to_arg.find(ori_iter_name); + CHECK(name_it != name_to_arg.end()); + PrimExpr ori_arg = name_it->second; + + PrimExpr mod_factor = new_shape[i]; + + PrimExpr div_factor = 1; + if (div_factors.count(ori_iter_name)) { + div_factor = div_factors[ori_iter_name]; + } + div_factors[ori_iter_name] = div_factor * new_shape[i]; + + PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor); + + r_new_args.push_back(new_arg); + } + + Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()), + std::make_move_iterator(r_new_args.rend())); + + return ProducerLoad(op->producer, new_args); + } + return GetRef<PrimExpr>(op); + } + + /* + PrimExpr Mutate_(const Call* op, const PrimExpr& e) { + PrimExpr op_ = IRMutator::Mutate_(op, e); + + const Call* call = op_.as(); + + if (call->call_type == Call::CallType::Halide) { + te::Tensor t =
Downcast(call->func).output(call->value_index); + auto it = placeholder_new_names_.find(t->op); + if (it != placeholder_new_names_.end()) { + const std::vector& new_names = it->second; + const Array& new_shape = placeholder_new_shapes_.at(t->op); + std::unordered_map name_to_arg; + for (const auto& arg : call->args) { + std::string axis_name; + if (const auto* pimm = arg.as()) { + CHECK_EQ(pimm->value, 0); + axis_name = "IntImm"; + } else { + axis_name = BaseName(CleanName(Downcast(arg)->name_hint)); + CHECK_EQ(name_to_arg.count(axis_name), 0); + name_to_arg[axis_name] = arg; + } + } + + std::unordered_map div_factors; + std::vector r_new_args; + for (int i = new_names.size() - 1; i >= 0; --i) { + auto ori_iter_name = new_names[i]; + auto name_it = name_to_arg.find(ori_iter_name); + CHECK(name_it != name_to_arg.end()); + PrimExpr ori_arg = name_it->second; + + PrimExpr mod_factor = new_shape[i]; + + PrimExpr div_factor = 1; + if (div_factors.count(ori_iter_name)) { + div_factor = div_factors[ori_iter_name]; + } + div_factors[ori_iter_name] = div_factor * new_shape[i]; + + PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor); + + r_new_args.push_back(new_arg); + } + + Array new_args(std::make_move_iterator(r_new_args.rbegin()), + std::make_move_iterator(r_new_args.rend())); + + return Call::make(call->type, call->name, new_args, call->call_type, + call->func, call->value_index); + } + } + return op_; + } + */ + + private: + const OperationMap >& placeholder_new_names_; + const OperationMap >& placeholder_new_shapes_; +}; + +void ComputeDAG::RewriteLayout( + const std::vector &transform_steps, LayoutRewriteLevel layout_rewrite_level) const { + ComputeDAGNode* pdag = const_cast(this)->CopyOnWrite(); + const State& state = ReplayAndInferBound(transform_steps); + + OperationMap > placeholder_new_names; + OperationMap > placeholder_new_shapes; + int stage_id = -1; + for (const auto& stage : state->stages) { + stage_id += 1; + const te::Operation& op = stage->op; + if (op->IsInstance()) { + const Map& attrs = op->attrs; + if (attrs.count(layout_free_placeholders_key)) { + const ObjectRef& attr_value = attrs[layout_free_placeholders_key]; + Array placeholders = Downcast>(attr_value); + for (auto& placeholder : placeholders) { + const auto placeholder_op = placeholder->op; + + // Check whether this placeholder has already been handled + if (placeholder_new_names.count(placeholder_op)) { + continue; + } + + // skip the op that is not direct consumer of this placeholder, + // mostly due to cache read/write. 
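The IndexRewriter above rebuilds each tensor access right-to-left as indexmod(indexdiv(ori_arg, div_factor), mod_factor), accumulating div_factor per original axis name so that one original axis can feed several new axes (e.g. a C axis split into 2C...16c). A small numeric sketch of that mapping, in plain Python:

    # Sketch: remap an index of an axis of extent 32 into the split layout
    # [2, 16] produced by a kernel layout rewrite.
    def remap(i, new_shape):
        args, div = [], 1
        for extent in reversed(new_shape):      # walk new_names right-to-left
            args.append((i // div) % extent)    # indexmod(indexdiv(i, div), extent)
            div *= extent                       # accumulate div_factor
        return list(reversed(args))

    assert remap(5, [2, 16]) == [0, 5]
    assert remap(21, [2, 16]) == [1, 5]         # 21 == 1 * 16 + 5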
+ bool direct_consumer = false; + for (auto& t : op->InputTensors()) { + if (t->op == placeholder_op) { + direct_consumer = true; + break; + } + } + if (!direct_consumer) { + continue; + } + + std::set placeholder_axis_names; + TensorAccessExtractor extractor; + for (const auto& exp : op.as()->body) { + extractor.Extract(exp); + } + bool rewrite_placeholder = (layout_rewrite_level == kPlaceholderRewrite || + layout_rewrite_level == kBothRewrite); + bool rewrite_body = (layout_rewrite_level == kComputeRewrite || + layout_rewrite_level == kBothRewrite); + std::ostringstream os; + + uint i = 0; + if (extractor.buf_accesses.count(placeholder_op)) { + for (const auto& ev : extractor.buf_accesses[placeholder_op]) { + for (const auto& e : ev) { + // TODO(minminsun): check whether the extents match the shape of placeholder + std::string axis_name; + if (const auto* pimm = e.as()) { + CHECK_EQ(pimm->value, 0); + // CHECK_EQ(placeholder->shape[i].as()->value, 1); + axis_name = "IntImm"; + } else { + axis_name = BaseName(CleanName(Downcast(e)->name_hint)); + } + + placeholder_axis_names.insert(axis_name); + if (rewrite_placeholder) { + os << placeholder->shape[i++] << axis_name; + } + } + } + + if (rewrite_placeholder) { + CHECK_EQ(placeholder_axis_names.size(), placeholder->shape.size()); + std::string ori_layout = os.str(); + os.str(""); + ::tvm::relay::KernelLayoutVisitor::global_ori_layouts_queue.push_back(ori_layout); + } + } + + std::vector stage_iters; + + auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id); + int attach_pos = -1; + size_t iters_before_attach = 0; + if (attach_it != state->attach_map->stage_to_attach_iter.end()) { + auto attach = attach_it->second; + const auto& attach_stage = state->stages[attach.first]; + attach_pos = attach.second; + stage_iters.insert(stage_iters.end(), + attach_stage->iters.begin(), + attach_stage->iters.begin() + attach_pos + 1); + } + + stage_iters.insert(stage_iters.end(), stage->iters.begin(), stage->iters.end()); + + std::vector iters; + for (size_t i = 0; i < stage_iters.size(); ++i) { + const auto& iter = stage_iters[i]; + if (iter->ori_iters.empty()) { + iters.push_back(iter); + } else { + for (const Iterator& ori_iter : iter->ori_iters) { + iters.push_back(ori_iter); + } + } + if (static_cast(i) == attach_pos) { + iters_before_attach = iters.size(); + } + } + + std::vector new_names; + Array new_shape; + std::vector new_axis_names; + for (const Iterator& iter : iters) { + std::set ori_iter_names; + ExtractOriginalIterators(iter->name, &ori_iter_names); + // fused iters have been replaced with iter->ori_iters. + // So there should be only one ori iter name extracted from iter->name. 
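The stream that assembles the layout strings above simply concatenates extent and axis-name pairs in iterator order; a sketch, assuming (extent, name) tuples as input:

    def make_layout(iters):
        """Build strings like '1N2C112H112W16c' from (extent, name) pairs."""
        return "".join("%s%s" % (extent, name) for extent, name in iters)

    # The dst example from the KernelLayoutTransformAttrs docstring:
    assert make_layout([(1, "N"), (2, "C"), (112, "H"), (112, "W"), (16, "c")]) \
        == "1N2C112H112W16c"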
+ CHECK_EQ(ori_iter_names.size(), 1); + auto ori_iter_name = BaseName(*ori_iter_names.begin()); + new_axis_names.push_back(ori_iter_name); + } + for (size_t i = 0; i < new_axis_names.size(); ++i) { + auto iter = iters[i]; + std::string ori_iter_name; + if (i < iters_before_attach) { + ori_iter_name = new_axis_names[i + iters_before_attach]; + } else { + ori_iter_name = new_axis_names[i]; + } + if (placeholder_axis_names.count(ori_iter_name)) { + os << iter->range->extent << ori_iter_name; + new_names.push_back(ori_iter_name); + new_shape.push_back(iter->range->extent); + } + } + std::string new_layout = os.str(); + os.str(""); + ::tvm::relay::KernelLayoutVisitor::global_new_layouts_queue.push_back(new_layout); + placeholder_new_names[placeholder_op] = new_names; + placeholder_new_shapes[placeholder_op] = new_shape; + + Array old_ops = pdag->ops; + ArrayNode* pops = pdag->ops.CopyOnWrite(); + + // Create new placeholder + te::Operation new_placeholder_op; + if (rewrite_placeholder) { + new_placeholder_op = + te::PlaceholderOpNode::make(placeholder_op->name, + new_shape, + placeholder_op.as()->dtype); + } else { + new_placeholder_op = placeholder_op; + } + + te::Operation new_compute_op, old_compute_op; + if (rewrite_body) { + Array new_body; + IndexRewriter index_rewriter(placeholder_new_names, + placeholder_new_shapes); + for (auto& op : old_ops) { + if (auto* pop = op.as()) { + bool need_update = false; + for (auto& t : op->InputTensors()) { + if (t->op == placeholder_op) { + need_update = true; + break; + } + } + if (need_update) { + for (auto& body : pop->body) { + new_body.push_back(index_rewriter.Rewrite(body)); + } + old_compute_op = op; + CHECK(!new_compute_op.defined()); + new_compute_op = te::ComputeOpNode::make( + pop->name, pop->tag, pop->attrs, pop->axis, new_body); + } + } + } + } + + // construct the map from old_op to new_op + std::unordered_map updated_ops; + for (size_t i = 0; i < old_ops.size(); ++i) { + auto old_op = old_ops[i]; + if (rewrite_placeholder && old_op == placeholder_op) { + //pops->data[i] = new_placeholder_op; + pops->SetItem(i, new_placeholder_op); + updated_ops[placeholder_op] = new_placeholder_op; + } else if (rewrite_body && old_op == old_compute_op) { + //pops->data[i] = new_compute_op; + pops->SetItem(i, new_compute_op); + updated_ops[old_compute_op] = new_compute_op; + } else { + //pops->data[i] = old_op; + pops->SetItem(i, old_op); + } + } + + // Because ops is sorted in topo-order, only do one pass linear scan here. 
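Within RewriteLayout an op can be replaced more than once as its inputs are updated, so the linear scan that follows chases updated_ops entries to a fixpoint. The lookup is equivalent to this small sketch:

    # Sketch of the chained lookup over updated_ops (a dict old_op -> new_op).
    def resolve(updated_ops, op):
        new_op = None
        while op in updated_ops:    # mirrors: while (it != updated_ops.end())
            new_op = updated_ops[op]
            op = new_op
        return new_op               # None when op was never replaced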
+ for (size_t i = 0; i < pops->size(); ++i) { + auto old_op = Downcast<te::Operation>(pops->at(i)); + if (auto* pop = old_op.as<te::ComputeOpNode>()) { + auto inputs = pop->InputTensors(); + std::unordered_map<te::Tensor, te::Tensor> rmap; + for (auto input : inputs) { + auto it = updated_ops.find(input->op); + te::Operation new_op; + while (it != updated_ops.end()) { + new_op = it->second; + it = updated_ops.find(new_op); + } + if (new_op.defined()) { + int index = input->value_index; + rmap[input] = new_op.output(index); + } + } + if (!rmap.empty()) { + te::Operation new_op = pop->ReplaceInputs(old_op, rmap); + updated_ops[old_op] = new_op; + //pops->data[i] = new_op; + pops->SetItem(i, new_op); + } + } + } + + pdag->init_state = StateNode::make(pdag->ops); + + Array<te::Tensor> old_tensors = pdag->tensors; + ArrayNode* ptensors = pdag->tensors.CopyOnWrite(); + + for (size_t i = 0; i < old_tensors.size(); ++i) { + const auto& old_tensor = old_tensors[i]; + auto it = updated_ops.find(old_tensor->op); + te::Operation new_op; + while (it != updated_ops.end()) { + new_op = it->second; + it = updated_ops.find(new_op); + } + if (new_op.defined()) { + if (layout_rewrite_level == kBothRewrite) { + auto index = old_tensor->value_index; + //ptensors->data[i] = new_op.output(index); + ptensors->SetItem(i, new_op.output(index)); + } else if (layout_rewrite_level == kComputeRewrite) { + te::TensorNode* old_tensor_node = + const_cast<te::TensorNode*>(old_tensor.as<te::TensorNode>()); + old_tensor_node->op = new_op; + } + } + } + } // end for placeholder + } + } + } // end for stage +} void UpdateStageAxis(const te::Stage& stage, StageToAxesMap *stage_to_axes) { @@ -1273,6 +1331,30 @@ TVM_REGISTER_GLOBAL("ansor.ComputeDAG") TVM_REGISTER_GLOBAL("ansor.ComputeDAGGetInitState") .set_body_method(&ComputeDAG::GetInitState); +TVM_REGISTER_GLOBAL("ansor.ComputeDAGRewriteLayoutFromState") +.set_body([](TVMArgs args, TVMRetValue *ret) { + ComputeDAG dag = args[0]; + State state = args[1]; + + dag.RewriteLayout(state->transform_steps, kPlaceholderRewrite); + *ret = dag; +}); + +TVM_REGISTER_GLOBAL("ansor.ComputeDAGApplyStepsFromState") +.set_body([](TVMArgs args, TVMRetValue *ret) { + ComputeDAG dag = args[0]; + State state = args[1]; + LayoutRewriteLevel layout_rewrite_level = kNoRewrite; + if (args.size() >= 3) { + layout_rewrite_level = LayoutRewriteLevel(static_cast<int>((args[2]))); + } + + te::Schedule sch; + Array<te::Tensor> return_tensors; + std::tie(sch, return_tensors) = dag.ApplySteps(state->transform_steps, layout_rewrite_level); + *ret = Array<ObjectRef>{sch, return_tensors}; +}); +/* TVM_REGISTER_GLOBAL("ansor.ComputeDAGApplyStepsFromState") .set_body_typed([](const ComputeDAG& dag, const State& state) { te::Schedule sch; @@ -1280,6 +1362,7 @@ TVM_REGISTER_GLOBAL("ansor.ComputeDAGApplyStepsFromState") std::tie(sch, return_tensors) = dag.ApplySteps(state->transform_steps); return Array<ObjectRef>{sch, return_tensors}; }); +*/ TVM_REGISTER_GLOBAL("ansor.ComputeDAGPrintPythonCodeFromState") .set_body_typed([](const ComputeDAG& dag, const State& state) { diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index 60c1790a0cfb..c71c4f1b6586 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -146,7 +146,7 @@ class ComputeDAG: public ObjectRef { // Rewrite the layout of "layout free" placeholders according to transform steps void RewriteLayout(const std::vector<Step>& transform_steps, - LayoutRewriteLevel layout_rewrite_level = kNoRewrite) const {} + LayoutRewriteLevel layout_rewrite_level = kNoRewrite) const; // Print transform steps as equivalent Python schedule API std::string PrintStepsAsPython(const std::vector<Step>&
steps) const; diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index a192002825e6..5b063eca4337 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -219,6 +219,7 @@ class TypeSolver::Unifier : public TypeFunctor { return Type(nullptr); } + tt1 = tt2; tvm::Array shape; if (tt1->shape.size() != tt2->shape.size()) { this->solver_->ReportError(ErrorBuilder() << "tensor type `" << PrettyPrint(tt1) << "` has " diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 34c3487e3ef2..8bd5eca7c93d 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -287,6 +287,7 @@ class RelayBuildModule : public runtime::ModuleNode { // Alter layout transformation is only applied to homogeneous execution yet. if (targets.size() == 1) { pass_seqs.push_back(transform::AlterOpLayout()); + //pass_seqs.push_back(transform::KernelLayoutTransform()); } // Fast math optimizations. @@ -315,6 +316,18 @@ class RelayBuildModule : public runtime::ModuleNode { // Fuse the operations if it is needed. relay_module = transform::FuseOps()(relay_module); + + if (targets.size() == 1) { + pass_seqs.push_back(transform::KernelLayoutTransform()); + pass_seqs.push_back(transform::DeFuseOps()); + pass_seqs.push_back(transform::FoldConstant()); + transform::Pass seq = transform::Sequential(pass_seqs); + const auto& it = targets.begin(); + With tctx((*it).second); + relay_module = seq(relay_module); + relay_module = transform::FuseOps()(relay_module); + } + relay_module = transform::InferType()(relay_module); // Inline the functions that have been lifted by the module scope. // diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 2aae8546248f..fde880b10f1d 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -68,6 +68,11 @@ CCacheKey::CCacheKey(Function source_func, Target target) { auto n = make_object(); n->source_func = std::move(source_func); n->target = std::move(target); + n->disabled = false; + char* envar = getenv("TVM_RELAY_DISABLE_BUILD_CACHE"); + if (envar != nullptr && strcmp(envar, "true") == 0) { + n->disabled = true; + } data_ = std::move(n); } diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index a5f3f6359f89..b290462a4b22 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -115,6 +115,8 @@ class CCacheKeyNode : public Object { /*! \brief The hardware target.*/ Target target; + bool disabled; + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("source_func", &source_func); v->Visit("target", &target); @@ -259,6 +261,7 @@ inline size_t CCacheKeyNode::Hash() const { } inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const { + if (disabled) return false; if (Hash() != other->Hash()) return false; return this->target->str() == other->target->str() && tvm::StructuralEqual()(this->source_func, other->source_func); diff --git a/src/relay/transforms/defuse_ops.cc b/src/relay/transforms/defuse_ops.cc new file mode 100644 index 000000000000..f7c9037df687 --- /dev/null +++ b/src/relay/transforms/defuse_ops.cc @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "pattern_util.h" + +namespace tvm { +namespace relay { + +class DefuseOpsMutator : public ExprMutator { + public: + + class FuncBodyMutator : public ExprMutator { + public: + Array<Expr> args_; + + explicit FuncBodyMutator(const Array<Expr>& args) : ExprMutator() { + args_ = args; + } + + Expr VisitExpr_(const VarNode* n) { + const std::string& name = n->name_hint(); + CHECK_EQ(name[0], 'p'); + std::string id_str = name.substr(1); + int id = atoi(id_str.c_str()); + CHECK(id >= 0 && size_t(id) < args_.size()); + return args_[id]; + } + }; + + Expr VisitExpr_(const CallNode* n) { + auto new_n = ExprMutator::VisitExpr_(n); + + const auto* call = new_n.as<CallNode>(); + if (call) { + const auto* func = call->op.as<FunctionNode>(); + if (func) { + const auto& func_call = func->body.as<CallNode>(); + if (func_call) { + return FuncBodyMutator(call->args).Mutate(func->body); + } + } + } + return new_n; + } +}; + +Expr DeFuseOps(const Expr& expr) { + return DefuseOpsMutator().Mutate(expr); +} + +namespace transform { + +Pass DeFuseOps() { + runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast<Function>(relay::DeFuseOps(f)); + }; + return CreateFunctionPass(pass_func, 3, "DeFuseOps", + {"InferType"}); +} + +TVM_REGISTER_GLOBAL("relay._transform.DeFuseOps") +.set_body_typed(DeFuseOps); + +} // namespace transform + +} // namespace relay +} // namespace tvm
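DeFuseOps above inverts FuseOps by inlining each fused function back at its call site; FuncBodyMutator relies on the p0, p1, ... parameter naming that FuseOps produces and substitutes the call arguments back in. A toy sketch of that substitution with plain Python stand-ins (not real Relay objects):

    # fn(p0, p1) { add(p0, p1) } applied to (x, y) becomes add(x, y)
    def substitute(body, args):
        if isinstance(body, str) and body.startswith("p"):  # a VarNode "p<i>"
            return args[int(body[1:])]
        if isinstance(body, tuple):                         # ("op", operands...)
            return (body[0],) + tuple(substitute(b, args) for b in body[1:])
        return body

    assert substitute(("add", "p0", "p1"), ["x", "y"]) == ("add", "x", "y")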
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include "kernel_layout_transform.h"
+
+namespace tvm {
+namespace relay {
+
+// TODO: do not use global variables
+std::deque<std::string> KernelLayoutVisitor::global_ori_layouts_queue;
+std::deque<std::string> KernelLayoutVisitor::global_new_layouts_queue;
+
+Expr KernelLayoutTransform(const Expr& expr) {
+  KernelLayoutVisitor visitor;
+
+  // Do a pre-order DFS to gather the optimal kernel layouts for all conv2d nodes.
+  // These layouts are written to global static variables by the Python function
+  // `prepare_layout_rewrite`.
+  visitor.VisitExpr(expr);
+
+  // Do a post-order DFS to mutate the layout of all conv2d nodes
+  return KernelLayoutTransformer(&visitor).Mutate(expr);
+}
+
+namespace transform {
+
+Pass KernelLayoutTransform() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+    [=](Function f, IRModule m, PassContext pc) {
+      return Downcast<Function>(relay::KernelLayoutTransform(f));
+    };
+  return CreateFunctionPass(pass_func, 3, "KernelLayoutTransform",
+                            {"InferType"});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.KernelLayoutTransform")
+.set_body_typed(KernelLayoutTransform);
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/transforms/kernel_layout_transform.h b/src/relay/transforms/kernel_layout_transform.h
new file mode 100644
index 000000000000..b4b806c20e28
--- /dev/null
+++ b/src/relay/transforms/kernel_layout_transform.h
@@ -0,0 +1,75 @@
+#include
+#include
+#include
+#include
+
+#include "pattern_util.h"
+
+#include "../../ansor/compute_dag.h"
+
+namespace tvm {
+namespace relay {
+
+/*! \brief A visitor to gather the optimal kernel layout for all conv2d nodes. */
+class KernelLayoutVisitor : public ExprVisitor {
+ public:
+  void VisitExpr_(const CallNode *n) {
+    if (n && n->op.as<OpNode>() &&
+        (std::find(op_white_lists.begin(), op_white_lists.end(), n->op.as<OpNode>()->name) !=
+         op_white_lists.end()) &&
+        n->args[1]->type_as<TensorTypeNode>()->shape[3].as<IntImmNode>()->value > 1 &&
+        !global_ori_layouts_queue.empty() && !global_new_layouts_queue.empty()) {
+      ori_layouts_map[n] = global_ori_layouts_queue.front();
+      new_layouts_map[n] = global_new_layouts_queue.front();
+      std::cout << "ori_layout " << global_ori_layouts_queue.front()
+                << " Filter_shape " << n->args[1]->type_as<TensorTypeNode>()->shape << std::endl;
+      global_ori_layouts_queue.pop_front();
+      global_new_layouts_queue.pop_front();
+    }
+    ExprVisitor::VisitExpr_(n);
+  }
+
+  std::unordered_map<const CallNode*, std::string> ori_layouts_map;
+  std::unordered_map<const CallNode*, std::string> new_layouts_map;
+  std::vector<std::string> op_white_lists {"nn.contrib_conv2d_winograd_without_weight_transform",
+                                           "nn.conv2d", "nn.conv3d"};
+
+  static std::deque<std::string> global_ori_layouts_queue;
+  static std::deque<std::string> global_new_layouts_queue;
+};
+
+
+/*!
\brief A mutator to rewrite kernel layout for all conv2d nodes */ +class KernelLayoutTransformer : public ExprMutator { + public: + KernelLayoutTransformer(KernelLayoutVisitor* visitor): ExprMutator(), visitor_(visitor) {} + + Expr VisitExpr_(const CallNode* n) { + auto new_n = ExprMutator::VisitExpr_(n); + + const auto* call = new_n.as(); + std::vector op_white_lists {"nn.contrib_conv2d_winograd_without_weight_transform", + "nn.conv2d", "nn.conv3d"}; + if (call && call->op.as() && + (std::find(op_white_lists.begin(), op_white_lists.end(), n->op.as()->name) != + op_white_lists.end() && n->args[1]->type_as()->shape[3].as()->value > 1)) { + auto ori_layout_iter = visitor_->ori_layouts_map.find(n); + auto new_layout_iter = visitor_->new_layouts_map.find(n); + if (ori_layout_iter != visitor_->ori_layouts_map.end() && + new_layout_iter != visitor_->new_layouts_map.end()) { + const std::string& ori_layout = ori_layout_iter->second; + const std::string& new_layout = new_layout_iter->second; + Expr updated_kernel = MakeKernelLayoutTransform(call->args[1], ori_layout, new_layout); + Array updated_args = {call->args[0], updated_kernel}; + new_n = Call(call->op, updated_args, + call->attrs); + } + } + return new_n; + } + + private: + KernelLayoutVisitor* visitor_; +}; + + +} // namespace relay +} // namespace tvm diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h index 7518eb9ac81a..a9d3b5168e47 100644 --- a/src/relay/transforms/pattern_util.h +++ b/src/relay/transforms/pattern_util.h @@ -685,6 +685,8 @@ Expr MakeExpandDims(Expr data, int axis, int num_newaxis); Expr MakeLayoutTransform(Expr data, String src_layout, String dst_layout); +Expr MakeKernelLayoutTransform(Expr data, String src_layout, String dst_layout); + Expr StopFusion(Expr data); Expr CastHint(Expr data, DataType dtype); diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index 4c7941b49692..de02367a4dff 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -342,7 +342,24 @@ def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'): dilation_h, dilation_w = dilation batch, in_height, in_width, in_channel = Input.shape - kernel_h, kernel_w, channel, num_filter = Filter.shape + if len(Filter.shape) == 10: + kernel_h = Filter.shape[2] * Filter.shape[6] + kernel_w = Filter.shape[3] * Filter.shape[7] + channel = Filter.shape[4] * Filter.shape[8] + num_filter = Filter.shape[0] * Filter.shape[1] * Filter.shape[5] * Filter.shape[9] + elif len(Filter.shape) == 11: + kernel_h = Filter.shape[3] * Filter.shape[7] + kernel_w = Filter.shape[4] * Filter.shape[8] + channel = Filter.shape[5] * Filter.shape[9] + num_filter = Filter.shape[0] * Filter.shape[1] * Filter.shape[2] * Filter.shape[6] * Filter.shape[10] + elif len(Filter.shape) == 12: + kernel_h = Filter.shape[4] * Filter.shape[8] + kernel_w = Filter.shape[5] * Filter.shape[9] + channel = Filter.shape[6] * Filter.shape[10] + num_filter = Filter.shape[0] * Filter.shape[1] * Filter.shape[2] * Filter.shape[3] * Filter.shape[7] * Filter.shape[11] + else: + kernel_h, kernel_w, channel, num_filter = Filter.shape + # compute the output shape dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 @@ -362,8 +379,9 @@ def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'): lambda nn, yy, xx, ff: te.sum( PaddedInput[nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rc].astype(out_dtype) * - Filter[ry, rx, rc, 
ff].astype(out_dtype), axis=[ry, rx, rc]), - name="Conv2dOutput", tag="conv2d_nhwc") + Filter[ry, rx, rc, ff].astype(out_dtype) + , axis=[ry, rx, rc]), + name="Conv2dOutput", tag="conv2d_nhwc", attrs={"layout_free_placeholders": [Filter]}) return Output From 145e61cf072b5e976eac07484beb25711b222c25 Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Sat, 20 Jun 2020 00:29:19 +0800 Subject: [PATCH 30/45] [cache flush] port cache flush to ansor (#32) --- scripts/tune_test.py | 3 ++- src/runtime/rpc/rpc_module.cc | 31 +++++++++++++++++++++++++++++++ src/runtime/threading_backend.cc | 9 +++++++-- 3 files changed, 40 insertions(+), 3 deletions(-) diff --git a/scripts/tune_test.py b/scripts/tune_test.py index a49ecd088afc..7831aea9dd4a 100644 --- a/scripts/tune_test.py +++ b/scripts/tune_test.py @@ -22,7 +22,8 @@ def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose measure_ctx = ansor.LocalRPCMeasureContext(repeat=1, min_repeat_ms=400) runner = measure_ctx.runner else: - runner = ansor.LocalRunner(repeat=1, min_repeat_ms=400) + os.environ['TVM_AUTO_CACHE_FLUSH'] = "1" + runner = ansor.LocalRunner(repeat=10, number=1, min_repeat_ms=0, timeout=run_timeout) else: os.environ['TVM_NDK_CC'] = ndk_cc builder = ansor.LocalBuilder(timeout=build_timeout, build_func='ndk') diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 89f3e7c6c7f8..b95d5ba25926 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -24,9 +24,14 @@ #include #include +#include #include #include +#if defined(_M_X64) || defined(__x86_64__) +#include +#endif + #include "rpc_endpoint.h" #include "rpc_session.h" @@ -300,6 +305,23 @@ std::shared_ptr RPCModuleGetSession(Module mod) { return rmod->sess(); } +inline void CacheFlush(const char* p, unsigned int allocation_size) { +// TODO: (FrozenGene) +// Support ARM. +#if (defined(_M_X64) || defined(__x86_64__)) + size_t cache_line = 64; + + if (p == nullptr || allocation_size <= 0) { + return; + } + + for (size_t i = 0; i < allocation_size; i += cache_line) { + _mm_clflush(static_cast(&p[i])); + } + +#endif +} + PackedFunc WrapTimeEvaluator(PackedFunc pf, TVMContext ctx, int number, int repeat, int min_repeat_ms) { CHECK(pf != nullptr); @@ -313,12 +335,21 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, TVMContext ctx, int number, int repe auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue* rv) mutable { TVMRetValue temp; std::ostringstream os; + const char* cache_flush = std::getenv("TVM_AUTO_CACHE_FLUSH"); // skip first time call, to activate lazy compilation components. 
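+    // Note on the cache-flush path in the repeat loop below: when the
+    // TVM_AUTO_CACHE_FLUSH environment variable is set to a non-zero
+    // value, every tensor argument except the first is flushed out of
+    // the CPU cache before each repeat, so every measured run starts
+    // cold. This is only meaningful with number == 1 (checked below):
+    // with number > 1, the first inner run would re-warm the cache for
+    // the remaining ones and the averaged time would again be a
+    // warm-cache figure.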
pf.CallPacked(args, &temp); DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); for (int i = 0; i < repeat; ++i) { + if (cache_flush && std::atoi(cache_flush) != 0) { + CHECK_EQ(number, 1); + // we want to keep input data + for (int j = 1; j < args.size(); j++) { + CacheFlush((char*)(args[j].operator DLTensor*()->data), + GetDataSize(*(args[j].operator DLTensor*()))); + } + } std::chrono::time_point tbegin, tend; double duration_ms = 0.0; diff --git a/src/runtime/threading_backend.cc b/src/runtime/threading_backend.cc index e5520efe30a6..3b1889aed8ef 100644 --- a/src/runtime/threading_backend.cc +++ b/src/runtime/threading_backend.cc @@ -166,8 +166,13 @@ class ThreadGroup::Impl { #if defined(_M_X64) || defined(__x86_64__) big_count /= 2; // ignore hyper-threading #endif - for (int i = 0; i < big_count; ++i) { - CPU_SET(sorted_order_[i], &cpuset); + const char* bind_master_core_0 = getenv("TVM_BIND_MASTER_CORE_0"); + if (bind_master_core_0 && atoi(bind_master_core_0) != 0) { + CPU_SET(sorted_order_[0], &cpuset); + } else { + for (int i = 0; i < big_count; ++i) { + CPU_SET(sorted_order_[i], &cpuset); + } } } #if defined(__ANDROID__) From 2c2781690313894dd35f578a8be48f940e8d7125 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 19 Jun 2020 17:42:00 -0700 Subject: [PATCH 31/45] Improve relay integration (#34) * tmp checkpoint * Improve relay integration * Improve relay integration --- python/tvm/ansor/__init__.py | 11 +- python/tvm/ansor/auto_schedule.py | 67 +++-- python/tvm/ansor/compute_dag.py | 50 +--- python/tvm/ansor/cost_model/cost_model.py | 2 + python/tvm/ansor/dispatcher.py | 18 +- python/tvm/ansor/env.py | 2 +- python/tvm/ansor/feature.py | 7 +- python/tvm/ansor/loop_state.py | 4 +- python/tvm/ansor/measure.py | 10 +- python/tvm/ansor/relay_integration.py | 259 ++++++++++-------- python/tvm/ansor/serialization.py | 7 +- python/tvm/ansor/topi_integration.py | 220 --------------- python/tvm/ansor/utils.py | 6 +- python/tvm/relay/backend/compile_engine.py | 5 +- python/tvm/relay/build_module.py | 7 + python/tvm/relay/op/strategy/x86.py | 63 +++-- python/tvm/relay/testing/resnet.py | 17 +- scripts/tune_network.py | 15 +- scripts/tune_test.py | 2 +- src/ansor/compute_dag.cc | 5 - src/ansor/compute_dag.h | 2 +- src/ansor/feature.cc | 21 +- src/ansor/measure.cc | 24 +- src/ansor/measure.h | 21 +- .../search_policy/meta_tile_rewrite_policy.cc | 136 ++++----- src/ansor/search_policy/search_policy.cc | 48 ++-- src/ansor/search_policy/search_policy.h | 41 +-- src/ansor/search_policy/utils.cc | 169 ------------ src/ansor/search_policy/utils.h | 8 - src/ansor/serialization.cc | 1 + src/relay/backend/build_module.cc | 21 +- .../transforms/kernel_layout_transform.h | 3 +- ...ion.py => test_ansor_relay_integration.py} | 79 ++++-- topi/python/topi/ansor.py | 95 ------- topi/python/topi/arm_cpu/__init__.py | 5 - topi/python/topi/generic/__init__.py | 5 - topi/python/topi/nn/conv2d.py | 47 ++-- topi/python/topi/x86/__init__.py | 5 - tutorials/ansor/tune_conv2d_cuda.py | 4 +- tutorials/ansor/tune_simple_subgraph.py | 4 +- 40 files changed, 551 insertions(+), 965 deletions(-) delete mode 100644 python/tvm/ansor/topi_integration.py rename tests/python/unittest/{test_ansor_relay_Integration.py => test_ansor_relay_integration.py} (53%) delete mode 100644 topi/python/topi/ansor.py diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index b43b21a60144..977e100e63c6 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -28,10 +28,9 @@ from . 
import task_scheduler # Shortcut -from .compute_dag import ComputeDAG, LayoutRewriteLevel, gen_schedule +from .compute_dag import ComputeDAG, LayoutRewriteLevel from .auto_schedule import SearchTask, MetaTileRewritePolicy, TuneOption, HardwareParams, \ - PreLoadMeasuredStates, PreAddCustomRule -from .auto_schedule import auto_schedule + PreloadMeasuredStates, PreAddCustomRule, auto_schedule from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, LocalRPCMeasureContext from .cost_model import RandomModel from .cost_model.xgb_model import XGBModel @@ -41,7 +40,7 @@ workload_key_to_dag, make_workload_key_func from .task_scheduler import TaskScheduler, SimpleTaskScheduler from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest as apply_history_best, \ - FallbackContext, clear_fallback_cache, ApplyGraphBest, BlockingEmptyContext -from .topi_integration import register_topi_schedule, TaskExtractEnv + FallbackContext, clear_fallback_cache, ApplyGraphBest from .relay_integration import extract_from_program, extract_from_multiple_program, \ - finish_layout_rewrite, prepare_layout_rewrite + finish_layout_rewrite, prepare_layout_rewrite, auto_schedule_topi +from .env import GLOBAL_SCOPE diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index 232c24ee89ea..acf8982d6e89 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -"""Meta information for a search task""" +"""User interface for auto-scheduler""" import random @@ -29,35 +29,36 @@ @tvm._ffi.register_object("ansor.HardwareParams") class HardwareParams(Object): """ + The parameters of target hardware + Parameters ---------- - num_cores : Int - vector_unit_bytes : Int - cache_line_bytes : Int - max_unroll_vec : Int - max_innermost_split_factor : Int + num_cores : int + vector_unit_bytes : int + cache_line_bytes : int + max_unroll_vec : int + max_innermost_split_factor : int """ - def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes, max_unroll_vec, max_innermost_split_factor): self.__init_handle_by_constructor__(_ffi_api.HardwareParams, num_cores, vector_unit_bytes, cache_line_bytes, - max_unroll_vec, - max_innermost_split_factor) + max_unroll_vec, max_innermost_split_factor) @tvm._ffi.register_object("ansor.SearchTask") class SearchTask(Object): """ + The meta-information of a search task + Parameters ---------- dag : ComputeDAG - workload_key : Str - target : tvm.target - target_host : tvm.target + workload_key : str + target : tvm.target.Target + target_host : tvm.target.Target hardware_params : HardwareParams """ - def __init__(self, dag, workload_key, target, target_host=None, hardware_params=None): self.__init_handle_by_constructor__(_ffi_api.SearchTask, dag, @@ -67,10 +68,10 @@ def __init__(self, dag, workload_key, target, target_host=None, @tvm._ffi.register_object("ansor.SearchPolicy") class SearchPolicy(Object): - """ The base search policy class - """ + """ The base class for search policy """ def continue_search(self, task, num_measure, verbose, measurer): - return _ffi_api.SearchPolicyContinueSearchOneRound(self, task, num_measure, verbose, measurer) + return _ffi_api.SearchPolicyContinueSearchOneRound(self, task, + num_measure, verbose, measurer) def set_task(self, task): _ffi_api.SearchPolicySetTask(self, task) @@ -89,7 +90,7 @@ class MetaTileRewritePolicy(SearchPolicy): Parameters ---------- program_cost_model: CostModel - Cost 
model for complete programs + Cost model for programs params: int Parameters of the search policy, go meta_tile_rewrite_policy.h to find the definitions. See code below to find the default values @@ -130,21 +131,22 @@ def __init__(self, @tvm._ffi.register_object("ansor.SearchCallback") class SearchCallback(Object): + """Callback function before or after search process""" pass -@tvm._ffi.register_object("ansor.PreLoadMeasuredStates") -class PreLoadMeasuredStates(SearchCallback): - """ A SearchCallback that used for search policy to load measured hash - from the log file. +@tvm._ffi.register_object("ansor.PreloadMeasuredStates") +class PreloadMeasuredStates(SearchCallback): + """ A SearchCallback to load measured states from the log file for a search policy. + This can resume the state of the search policy. Parameters ---------- - filename: Str + filename: str """ def __init__(self, filename: str): self.__init_handle_by_constructor__( - _ffi_api.PreLoadMeasuredStates, filename) + _ffi_api.PreloadMeasuredStates, filename) @tvm._ffi.register_object("ansor.PreAddCustomRule") @@ -153,8 +155,10 @@ class PreAddCustomRule(SearchCallback): A SearchCallback for MetaTileRewritePolicy that allowing users to add custom sketch rule. - Notice: This is an advanced feature, make sure you're clear how it - works and this should only be used in MetaTileRewritePolicy. + Notes + ----- + This is an advanced feature. Make sure you're clear how it + works and this should only be used in MetaTileRewritePolicy. Parameters ---------- @@ -193,7 +197,7 @@ class TuneOption(Object): pre_search_callbacks: List[SearchCallback] Callback functions called before the search process Candidates: - - ansor.PreLoadMeasuredStates + - ansor.PreloadMeasuredStates - ansor.PreAddCustomRule """ def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_iter=64, @@ -225,7 +229,7 @@ def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_iter=64, def auto_schedule(workload, target=None, target_host=None, search_policy='default', hardware_params=None, tune_option=None): - """ Do auto schedule for a compute declaration. + """ Do auto scheduling for a computation declaration. The workload parameter can be a `string` as workload_key, or directly passing a `SearchTask` as input. @@ -233,21 +237,15 @@ def auto_schedule(workload, target=None, Parameters ---------- workload : Union[SearchTask, str] - target : Target - target_host : Target = None - search_policy : Union[SearchPolicy, str] - hardware_params : HardwareParams - tune_option : TuneOption Returns ------- sch : tvm.Schedule - tensors : List[Tensor] """ if isinstance(search_policy, str): @@ -267,5 +265,4 @@ def auto_schedule(workload, target=None, sch, tensors = _ffi_api.AutoScheduleBySearchTask(workload, search_policy, tune_option) return sch, tensors else: - raise ValueError("Invalid workload: " + workload + - ". Expect a string or SearchTask") + raise ValueError("Invalid workload: " + workload + ". Expect a string or SearchTask") diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index c54c14ec123a..f35c9d8221f3 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -19,7 +19,6 @@ import tvm._ffi from tvm.runtime import Object -from tvm import te from .loop_state import State, StateObject from . 
import _ffi_api @@ -34,11 +33,12 @@ class LayoutRewriteLevel(object): @tvm._ffi.register_object("ansor.ComputeDAG") class ComputeDAG(Object): """ + Computation declaration graph + Parameters ---------- tensors : List[Tensor] """ - def __init__(self, tensors): self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, tensors) @@ -51,29 +51,20 @@ def get_init_state(self): """ return State(_ffi_api.ComputeDAGGetInitState(self), self) - def apply_steps_from_state(self, state, layout_rewrite_level=None): + def apply_steps_from_state(self, state, layout_rewrite_level=LayoutRewriteLevel.NO_REWRITE): """ Parameters ---------- state : StateObject - layout_rewrite_level : LayoutRewriteLevel(***) + layout_rewrite_level : LayoutRewriteLevel Returns ------- sch : Schedule args : List[Tensor] """ - if isinstance(state, State): - return _ffi_api.ComputeDAGApplyStepsFromState(self, state.state_object, - layout_rewrite_level) - elif isinstance(state, StateObject): - return _ffi_api.ComputeDAGApplyStepsFromState(self, state, - layout_rewrite_level) - else: - raise ValueError("The input must be a State or StateObject") - - def rewrite_layout_from_state(self, state: State): - return _ffi_api.ComputeDAGRewriteLayoutFromState(self, state) + state_obj = state if isinstance(state, StateObject) else state.state_object + return _ffi_api.ComputeDAGApplyStepsFromState(self, state_obj, layout_rewrite_level) def print_python_code_from_state(self, state): """ @@ -85,12 +76,8 @@ def print_python_code_from_state(self, state): ------- str : Str """ - if isinstance(state, State): - return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state.state_object) - elif isinstance(state, StateObject): - return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state) - else: - raise ValueError("The input must be a State or StateObject") + state_obj = state if isinstance(state, StateObject) else state.state_object + return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state_obj) def infer_bound_from_state(self, state): """ @@ -102,19 +89,8 @@ def infer_bound_from_state(self, state): ------- state : StateObject """ - if isinstance(state, State): - return State(_ffi_api.ComputeDAGInferBoundFromState(self, state.state_object), self) - elif isinstance(state, StateObject): - return State(_ffi_api.ComputeDAGInferBoundFromState(self, state), self) - else: - raise ValueError("The input must be a State or StateObject") - -def gen_schedule(state, bufs): - if not state or not state.complete: - return te.create_schedule([x.op for x in bufs]) - else: - dag = ComputeDAG(bufs) - # only update compute body, layout_rewrite_level = LayoutRewriteLevel.COMPUTE_REWRITE, - # since kernel layout has already been rewritten in relay pass - schedule, _ = dag.apply_steps_from_state(state, layout_rewrite_level=LayoutRewriteLevel.COMPUTE_REWRITE) - return schedule + state_obj = state if isinstance(state, StateObject) else state.state_object + return State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self) + + def rewrite_layout_from_state(self, state: State): + return _ffi_api.ComputeDAGRewriteLayoutFromState(self, state) diff --git a/python/tvm/ansor/cost_model/cost_model.py b/python/tvm/ansor/cost_model/cost_model.py index fd9b67927185..47ea5092b302 100644 --- a/python/tvm/ansor/cost_model/cost_model.py +++ b/python/tvm/ansor/cost_model/cost_model.py @@ -34,6 +34,7 @@ class RandomModel(Object): def __init__(self): self.__init_handle_by_constructor__(_ffi_api.RandomModel) + # A random number generator func for c++'s RandomModel 
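+# The C++ RandomModel fetches this function from the packed-function
+# registry and calls it with the number of states `n` and a raw pointer to
+# a preallocated float buffer; the scores are written into that buffer in
+# place, so no array has to be returned across the FFI boundary.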
@tvm._ffi.register_func("ansor.cost_model.random_number") def random_number(n, return_ptr): @@ -43,6 +44,7 @@ def random_number(n, return_ptr): array_wrapper = np.ctypeslib.as_array(return_ptr, shape=(n,)) array_wrapper[:] = np.random.uniform(0, 1, (n,)) + @tvm._ffi.register_object("ansor.PythonBasedModel") class PythonBasedModel(CostModel): def __init__(self): diff --git a/python/tvm/ansor/dispatcher.py b/python/tvm/ansor/dispatcher.py index 2f00c355d285..0ef07197ea92 100644 --- a/python/tvm/ansor/dispatcher.py +++ b/python/tvm/ansor/dispatcher.py @@ -36,9 +36,7 @@ from decorator import decorate from tvm import target as _target -from tvm.tir.expr import StringImm, FloatImm - -from .loop_state import State, StateObject +from tvm.tir.expr import FloatImm logger = logging.getLogger('auto_scheduler') @@ -360,19 +358,6 @@ def update(self, target, workload, state): self._best_user_defined[key] = state -class BlockingEmptyContext(DispatchContext): - """ - An empty context which returns emtpy State() for all queries. - This also blocks the queries, so the queries won't affect the global FallbackContext. - """ - def __init__(self): - super(BlockingEmptyContext, self).__init__() - - def query(self, target, workload): - #return StateObject() - return None - - class FallbackContext(DispatchContext): """ A fallback dispatch context. @@ -400,7 +385,6 @@ def _query_inside(self, target, workload): if msg not in self.messages: self.messages.add(msg) logger.warning(msg) - #cfg = StateObject() cfg = None # cache this config diff --git a/python/tvm/ansor/env.py b/python/tvm/ansor/env.py index 6d2bbd2c92af..9e44ad66048b 100644 --- a/python/tvm/ansor/env.py +++ b/python/tvm/ansor/env.py @@ -1,4 +1,4 @@ -""" The scope to store global variables in auto_scheduelr """ +""" The scope to store global variables in ansor """ class AutoschedulerGlobalScope(object): def __init__(self): diff --git a/python/tvm/ansor/feature.py b/python/tvm/ansor/feature.py index 4f9fdeb9e6cd..9496533da6cc 100644 --- a/python/tvm/ansor/feature.py +++ b/python/tvm/ansor/feature.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. -""""Python API for Feature extraction. +"""" +Python API for Feature extraction. The specification of features can be found in `autoscheduler_doc/per_stage_feature.md` """ @@ -28,8 +29,10 @@ from . import _ffi_api +# Maximum number of buffers for one statement to extract feature for DEFAULT_MAX_N_BUFS = 5 +# The length of the feature vector DEFAULT_FEATURE_VEC_LEN = 164 @@ -145,6 +148,6 @@ def get_per_stmt_features_from_states(states, def get_per_stmt_feature_names(max_n_bufs: int = None) -> List[str]: - """Get names of the elements in the flatten feature vector""" + """Get names for the elements in the flatten feature vector""" return [x for x in _ffi_api.GetPerStmtFeatureNames(max_n_bufs or DEFAULT_MAX_N_BUFS)] diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py index 23289c027293..3c60c3f09a8d 100644 --- a/python/tvm/ansor/loop_state.py +++ b/python/tvm/ansor/loop_state.py @@ -26,9 +26,9 @@ We don't use the existing TVM IR because 1. We want fast incremental change to the loop structures 2. We want serializable history for replay and backtracking -3. We may create some Macro schedule primitives +3. We may create some new macro schedule primitives -After search is done, we will lower this IR to TVM IR with TVM schedule primitives. +After search is done, we will lower this IR to TVM IR with TVM's schedule primitives. 
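To make the two levels concrete, here is a minimal round trip through the ComputeDAG API shown in compute_dag.py above, on a toy workload (the element-wise add is only an example):

import tvm
from tvm import te, ansor

A = te.placeholder((128, 128), name="A")
B = te.placeholder((128, 128), name="B")
C = te.compute((128, 128), lambda i, j: A[i, j] + B[i, j], name="C")

dag = ansor.ComputeDAG([A, B, C])
state = dag.get_init_state()                    # initial loop state
state = dag.infer_bound_from_state(state)       # re-infer iterator ranges
print(dag.print_python_code_from_state(state))  # dump as schedule code
sch, args = dag.apply_steps_from_state(state)   # lower to a te.Schedule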
Because we share a lot common objects during search, the transformation is implemented in copy on write style. All objects are immutable, which is similar to TVM IR. diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index 8b38f91647b2..3d9c33860cae 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -38,18 +38,20 @@ from tvm.rpc.tracker import Tracker from tvm.rpc.server import Server from tvm.autotvm.measure.measure_methods import set_cuda_target_arch -from ..contrib import tar, ndk +from tvm.contrib import tar, ndk +from . import _ffi_api from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout, request_remote, check_remote from .compute_dag import LayoutRewriteLevel -from . import _ffi_api logger = logging.getLogger('ansor') +# The maximum length of error message MAX_ERROR_MSG_LEN = 512 @tvm._ffi.register_object("ansor.MeasureCallback") class MeasureCallback(Object): + """Base class for measurement callback function""" pass @tvm._ffi.register_object("ansor.MeasureInput") @@ -103,7 +105,7 @@ def __init__(self, costs, error_no, error_msg, all_cost, timestamp): @tvm._ffi.register_object("ansor.Builder") class Builder(Object): - def build(self, measure_inputs, verbose=0): + def build(self, measure_inputs, verbose=1): """ Parameters ---------- @@ -119,7 +121,7 @@ def build(self, measure_inputs, verbose=0): @tvm._ffi.register_object("ansor.Runner") class Runner(Object): - def run(self, measure_inputs, build_results, verbose=0): + def run(self, measure_inputs, build_results, verbose=1): """ Parameters ---------- diff --git a/python/tvm/ansor/relay_integration.py b/python/tvm/ansor/relay_integration.py index 383471ee060d..85c4d8813f69 100644 --- a/python/tvm/ansor/relay_integration.py +++ b/python/tvm/ansor/relay_integration.py @@ -15,90 +15,33 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=unused-variable,invalid-name -""" -Decorator and utilities for the integration with TOPI and Relay -99.9% copy-paste of implementation by @MerryMercy +""" +Integrate ansor into relay. It implements the following items: +1. Extract search tasks from a relay program +2. Provide auto-scheduling for all TOPI compute functions """ import os -os.environ['TVM_USE_AUTO_SCHEDULER'] = 'true' - +import json import threading -import warnings -import tvm - -from .topi_integration import TaskExtractEnv -from .dispatcher import BlockingEmptyContext +from tvm import target, te, transform +from tvm.te.tensor import PlaceholderOp, ComputeOp +from .dispatcher import DispatchContext +from .workload_registry import register_auto_scheduler_workload_bufs, compute_dag_hash +from .compute_dag import ComputeDAG, LayoutRewriteLevel from .env import GLOBAL_SCOPE -def _lower(mod, - target, - params): - """ Helper to lower VTA properly. - """ +def call_all_topi_funcs(mod, target, params): + """Call all TOPI compute + schedule to extract tasks in a relay program""" # pylint: disable=import-outside-toplevel from tvm import relay - from tvm.relay.backend import graph_runtime_codegen - - if hasattr(target, 'device_name') and target.device_name == "vta": - import vta - with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): - mod, _ = relay.optimize(mod, target, params) - grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target) - grc.codegen(mod["main"]) - return - # default case - # Try graph codegen first to extract autotvm tasks. - # If failed to compile, then fallback to use VM compiler. 
- # TODO: Currently VM compiler is likely to stack overflow for large models. - try: - with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): - opt_mod, _ = relay.optimize(mod, target, params) - grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target) - grc.codegen(opt_mod["main"]) - except tvm.TVMError: - compiler = relay.vm.VMCompiler() - if params: - compiler.set_params(params) - compiler.lower(mod, target=target) - -OP_TO_SCHEDULE = {} - -def init_op_to_schedule_map(): - # init the global map OP_TO_SCHEDULE inside a function, this is used to resolve import issues - global OP_TO_SCHEDULE - from tvm import relay - import topi - - if OP_TO_SCHEDULE: - return - - OP_TO_SCHEDULE = { - relay.op.nn.conv2d: [topi.generic.schedule_conv2d_nchw, - topi.generic.schedule_conv2d_nhwc, - topi.generic.schedule_depthwise_conv2d_nchw, - topi.generic.schedule_depthwise_conv2d_nhwc, - topi.generic.schedule_group_conv2d_nchw, - topi.generic.schedule_conv2d_winograd_without_weight_transform], - relay.op.nn.conv2d_transpose: [topi.generic.schedule_conv2d_transpose_nchw], - relay.op.nn.dense: [topi.generic.schedule_dense], - relay.op.nn.softmax: [topi.generic.schedule_softmax], - relay.op.nn.max_pool2d: [topi.generic.schedule_pool], - relay.op.nn.avg_pool2d: [topi.generic.schedule_pool], - relay.op.nn.global_avg_pool2d: [topi.generic.schedule_adaptive_pool], - relay.op.nn.global_max_pool2d: [topi.generic.schedule_adaptive_pool], - relay.op.nn.deformable_conv2d: [topi.generic.schedule_deformable_conv2d_nchw], - relay.op.mean: [topi.generic.schedule_reduce], - relay.op.prod: [topi.generic.schedule_reduce], - relay.op.nn.conv3d: [topi.generic.schedule_conv3d_ncdhw, - topi.generic.schedule_conv3d_ndhwc], - relay.op.nn.adaptive_avg_pool3d: [topi.generic.schedule_adaptive_pool], - relay.op.nn.batch_matmul: [topi.generic.schedule_batch_matmul], - } - -def extract_from_program(mod, params, target, target_host=None, ops=None): + with transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): + bld_mod = relay.build_module.BuildModule() + bld_mod.call_all_topi_funcs(mod, target=target, params=params) + +def extract_from_program(mod, params, target, target_host=None): """ Extract tuning tasks from a relay program. This function is the single program version of extract_from_multiple_program. @@ -120,14 +63,11 @@ def extract_from_program(mod, params, target, target_host=None, ops=None): ------- workloads: Array of Tuple(wkl_key, target) """ - return extract_from_multiple_program([mod], [params], target, target_host, ops) + return extract_from_multiple_program([mod], [params], target, target_host) -def extract_from_multiple_program(mods, params, target, target_host=None, ops=None): +def extract_from_multiple_program(mods, params, target, target_host=None): """ Extract tuning tasks from multiple relay programs. - This function collects tuning tasks by building a list of programs - with a "tracing" target and tracing all the calls to topi. 
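A minimal sketch of how the new extraction flow is meant to be driven, assuming a ResNet-18 workload from `relay.testing` and an llvm target:

import tvm
from tvm import ansor
from tvm.relay import testing

mod, params = testing.resnet.get_workload(num_layers=18, batch_size=1)
target = tvm.target.create("llvm")

# Trace one compilation to collect the workload keys of all topi calls
wkl_keys, wkl_weights = ansor.extract_from_program(mod, params=params, target=target)

# Turn each key back into a ComputeDAG and build a search task for it
tasks = [ansor.SearchTask(ansor.workload_key_to_dag(key), key, target)
         for key in wkl_keys]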
- Parameters ---------- mods : List of relay.Module @@ -145,35 +85,17 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No ------- workloads: Array of Tuple(wkl_key, target) """ + # pylint: disable=import-outside-toplevel from tvm import relay - env = TaskExtractEnv.get() - - init_op_to_schedule_map() - topi_scheds = [] - - if not ops: - ops = [relay.op.nn.dense, relay.op.nn.softmax, relay.op.nn.conv2d, - relay.op.nn.conv2d_transpose, relay.op.nn.max_pool2d, - relay.op.nn.avg_pool2d, relay.op.nn.global_max_pool2d, - relay.op.nn.global_avg_pool2d, relay.op.nn.conv3d, - relay.op.nn.adaptive_avg_pool3d, relay.op.nn.batch_matmul, - relay.op.mean] - - for op_name in ops: - if op_name in OP_TO_SCHEDULE: - topi_scheds.extend(OP_TO_SCHEDULE[op_name]) - else: - warnings.warn("Op %s is not tunable, ignored." % op_name) - - # run compiler to collect all TOPI calls during compilation - env.reset(topi_scheds) + env = TracingEnvironment(TracingMode.EXTRACT_TASK) with env: + # run compiler to collect all TOPI calls during compilation for mod, param in zip(mods, params): - # wrap build call in thread to avoid multiprocessing problems - with BlockingEmptyContext(): - build_thread = threading.Thread(target=_lower, - args=(mod, target, param)) + # wrap build call in a new thread to avoid the conflict + # between python's multiprocessing and tvm's thread pool + build_thread = threading.Thread(target=call_all_topi_funcs, + args=(mod, target, param)) build_thread.start() build_thread.join() relay.backend.compile_engine.get().clear() @@ -181,32 +103,26 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No # create tasks for target wkl_keys = [] wkl_weights = [] - for wkl_key, wkl_weight in env.get_wkl_keys().items(): + for wkl_key, wkl_weight in env.wkl_key_collection.items(): wkl_keys.append(wkl_key) wkl_weights.append(wkl_weight) return wkl_keys, wkl_weights -def prepare_layout_rewrite(mod, params, ops, target): - """Prepare for kernel layout rewrite. This function will write layout infos to a global static variable, - then these layout info will be used by a relay pass `kernel_layout_transform`. + +def prepare_layout_rewrite(mod, params, target): + """ + Prepare for kernel layout rewrite. This function will write layout infos to a global static variable. + Then these layout info will be used by a relay pass `kernel_layout_transform`. """ + # pylint: disable=import-outside-toplevel from tvm import relay - env = TaskExtractEnv.get(do_layout_rewrite=True) - - init_op_to_schedule_map() - topi_scheds = [] - for op_name in ops: - if op_name in OP_TO_SCHEDULE: - topi_scheds.extend(OP_TO_SCHEDULE[op_name]) - else: - warnings.warn("Op %s is not tunable, ignored." 
% op_name) - - env.reset(topi_scheds) + env = TracingEnvironment(TracingMode.PREPARE_LAYOUT_REWRITE) with env: - # wrap build call in thread to avoid multiprocessing problems - build_thread = threading.Thread(target=_lower, + # wrap build call in a new thread to avoid the conflict + # between python's multiprocessing and tvm's thread pool + build_thread = threading.Thread(target=call_all_topi_funcs, args=(mod, target, params)) build_thread.start() build_thread.join() @@ -218,3 +134,104 @@ def prepare_layout_rewrite(mod, params, ops, target): def finish_layout_rewrite(): """Clear the global flag for layout rewrite""" GLOBAL_SCOPE.topi_in_compute_rewrite_mode = False + + +class TracingMode: + """Two modes for tracing""" + EXTRACT_TASK = 0 # trace all topi calls to extract tasks + PREPARE_LAYOUT_REWRITE = 1 # trace all topi calls to prepare layout rewrite + +class TracingEnvironment: + """Global environment for tracing all topi function calls""" + current = None + + def __init__(self, tracing_mode): + self.tracing_mode = tracing_mode + self.relay_disable_build_cache = "false" + self.layout_rewrite_success_ct = 0 + self.wkl_key_collection = {} + + def __enter__(self): + self.relay_disable_build_cache = os.environ.get("TVM_RELAY_DISABLE_BUILD_CACHE", "false") + os.environ["TVM_RELAY_DISABLE_BUILD_CACHE"] = "true" + TracingEnvironment.current = self + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + os.environ["TVM_RELAY_DISABLE_BUILD_CACHE"] = self.relay_disable_build_cache + TracingEnvironment.current = None + + def add_workload_key(self, key): + """Add the workload key of an Ansor search task + + Parameters + ---------- + key: str + """ + if key in self.wkl_key_collection: + self.wkl_key_collection[key] += 1 + else: + self.wkl_key_collection[key] = 1 + + +def traverse_to_get_io_tensors(outs): + """Traverse from a list of output tensors to get a whole computational DAG""" + layout_free_ops = [] + inputs = [] + + visited = set() + + def traverse(t): + if t in visited: + return + if isinstance(t.op, PlaceholderOp): + inputs.append(t) + elif isinstance(t.op, ComputeOp): + if "layout_free_placeholders" in t.op.attrs: + layout_free_ops.append(t.op) + for x in t.op.input_tensors: + traverse(x) + visited.add(t) + + for t in outs: + traverse(t) + + has_layout_free = (len(layout_free_ops) > 0) + return inputs + [t for t in outs], has_layout_free + + +def auto_schedule_topi(outs): + """ Use ansor to auto-schedule a topi compute declaration """ + io_tensors, has_layout_free = traverse_to_get_io_tensors(outs) + key = register_auto_scheduler_workload_bufs(io_tensors) + + env = TracingEnvironment.current + if env is None: # in the final build mode + state = DispatchContext.current.query(target.Target.current(), key) + dag = ComputeDAG(io_tensors) + # Only update compute body, layout_rewrite_level = LayoutRewriteLevel.COMPUTE_REWRITE, + # Since kernel layout has already been rewritten in relay pass + schedule, _ = dag.apply_steps_from_state(state, + layout_rewrite_level=LayoutRewriteLevel.COMPUTE_REWRITE) + return schedule + elif env.tracing_mode == TracingMode.EXTRACT_TASK: # in the task extraction mode + env.add_workload_key(key) + return te.create_schedule([x.op for x in outs]) + elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE: + # in prepare_layout_rewrite mode + if has_layout_free: + # Rewrite the DAG and update the transform history for + # the new dag in DispatchContext + dispatch_ctx = DispatchContext.current + tgt = target.Target.current() + state = 
dispatch_ctx.query(tgt, key) + assert state is not None + dag = ComputeDAG(outs) + new_dag = dag.rewrite_layout_from_state(state) + new_key = json.dumps((compute_dag_hash(new_dag),)) + dispatch_ctx.update(tgt, new_key, state) + if new_key != key: + env.layout_rewrite_success_ct += 1 + return te.create_schedule([x.op for x in outs]) + else: + raise ValueError("Invalid tracing mode: " + env.tracing_mode) diff --git a/python/tvm/ansor/serialization.py b/python/tvm/ansor/serialization.py index d9b8a2f5c075..97903b38bb0b 100644 --- a/python/tvm/ansor/serialization.py +++ b/python/tvm/ansor/serialization.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -"""Tuning log I/O Utilities""" +"""Serialization and other I/O support for tuning logs (measurement records)""" import numpy as np @@ -29,7 +29,7 @@ @tvm._ffi.register_object("ansor.LogToFile") class LogToFile(MeasureCallback): """ - A measurement callback that writes tuning logs into a file + A measurement callback that writes measurement records into a file Parameters ---------- @@ -65,6 +65,7 @@ def __iter__(self): yield ret[0], ret[1] # (input, result) def load_from_file(filename: str): + """Load measurement records from a file""" return zip(*LogReader(filename).read_lines()) @@ -80,7 +81,7 @@ def get_states_from_measure_inputs(inputs, task): def best_measure_pair_in_file(filename, workload_key=None, target=None): - """ Return best results form log file + """ Return the best measurement pair form a log file Parameters ---------- diff --git a/python/tvm/ansor/topi_integration.py b/python/tvm/ansor/topi_integration.py deleted file mode 100644 index 77def00cf9ec..000000000000 --- a/python/tvm/ansor/topi_integration.py +++ /dev/null @@ -1,220 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-variable,invalid-name,unused-argument -""" -Decorators for registering tunable templates to TOPI. - -These decorators can make your simple implementation be able to use different configurations -for different workloads. -Here we directly use all arguments to the TOPI call as "workload", so make sure all the arguments -(except tvm.te.Tensor) in you calls are hashable. For tvm.te.Tensor, -we will serialize it to a hashable tuple. - -See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage. 
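The record I/O documented in serialization.py above can be exercised on its own; a small sketch, assuming the module is importable as `tvm.ansor.serialization` and using a placeholder log file name:

from tvm.ansor import serialization

# Iterate over all (MeasureInput, MeasureResult) records in a log file
for inp, res in serialization.load_from_file("my_tune.log"):
    print(res.error_no, [c.value for c in res.costs])

# Or directly fetch the best record, optionally filtered by workload/target
inp, res = serialization.best_measure_pair_in_file("my_tune.log")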
-""" -import os -import json -import tvm.te._ffi_api -from tvm import target as _target -from tvm.te import tensor -from tvm.te.tensor import PlaceholderOp, ComputeOp - -from .dispatcher import DispatchContext, BlockingEmptyContext -from .workload_registry import register_auto_scheduler_workload_bufs, \ - make_workload_key_bufs, compute_dag_hash -from .compute_dag import ComputeDAG - -def traverse_to_get_io_tensors(outs): - layout_free_ops = [] - inputs = [] - - visited = set() - - def traverse(t): - if t in visited: - return - if isinstance(t.op, PlaceholderOp): - inputs.append(t) - elif isinstance(t.op, ComputeOp): - if "layout_free_placeholders" in t.op.attrs: - layout_free_ops.append(t.op) - for x in t.op.input_tensors: - traverse(x) - visited.add(t) - - for t in outs: - traverse(t) - - has_layout_free = (len(layout_free_ops) > 0) - return inputs + [t for t in outs], has_layout_free - -# Task extractor for relay program -class TaskExtractEnv: - """Global environment for extracting tuning tasks from graph""" - current = None - registered = None - - def __init__(self, do_layout_rewrite=False): - self.do_layout_rewrite = do_layout_rewrite - self.wanted_relay_ops = None - self.modified_funcs = [] - self.tracing = False - self.relay_disable_build_cache_ = "false" - self.layout_rewrite_success_ct = 0 - self.wkl_key_collection = {} - - def __enter__(self): - self.tracing = True - self.wkl_key_collection = {} - self.relay_disable_build_cache_ = os.environ.get("TVM_RELAY_DISABLE_BUILD_CACHE", "false") - os.environ["TVM_RELAY_DISABLE_BUILD_CACHE"] = "true" - - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.tracing = False - os.environ["TVM_RELAY_DISABLE_BUILD_CACHE"] = self.relay_disable_build_cache_ - - def reset(self, wanted_relay_ops=None): - """Reset task collections - - Parameters - ---------- - wanted_relay_ops: List of tvm.ir.Op - The relay ops to be extracted - """ - self.wanted_relay_ops = wanted_relay_ops - self.relay_disable_build_cache_ = "false" - self.layout_rewrite_success_ct = 0 - self.wkl_key_collection = {} - - def add_task(self, key): - """Add AutoTVM task - - Parameters - ---------- - task_name: str - AutoTVM task name. - - args: tuple - Arguments to the TOPI function. - """ - if key in self.wkl_key_collection: - self.wkl_key_collection[key] += 1 - else: - self.wkl_key_collection[key] = 1 - - def get_tasks(self): - """Get collected tasks - - Returns - ------- - tasks: List of tuple(name, args) - A list of tasks extracted from the graph - """ - return self.wkl_key_collection - - def get_wkl_keys(self): - """Get collected tasks - - Returns - ------- - wkl_keys: List of autoschedule workload_key - """ - return self.wkl_key_collection - - @staticmethod - def get(do_layout_rewrite=False): - """Get the single instance of TaskExtractEnv - - Parameters - ---------- - - Returns - ------- - env: TaskExtractEnv - The single instance of TaskExtractEnv - """ - if not TaskExtractEnv.current: - TaskExtractEnv.current = TaskExtractEnv(do_layout_rewrite) - else: - TaskExtractEnv.current.do_layout_rewrite = do_layout_rewrite - return TaskExtractEnv.current - -def register_topi_schedule(func=None): - """Register a tunable template for a topi schedule function. - - The registration will wrap this topi schedule to take `cfg` as the first argument, - followed by the original argument list. - - Note that this function will try to find "workload" from all the ComputeOp in the input. - You can attach "workload" to your compute op by using :any:`register_topi_compute`. 
- - The task name has to be the same as that of the corresponding topi compute function. - - Parameters - ---------- - task_name: str - The AutoTVM task name - - func: None or callable - If it is None, return a decorator. - If is callable, decorate this function. - - Returns - ------- - decorator: callable - A decorator - - Examples - -------- - See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage. - """ - def _decorate(topi_schedule): - def wrapper(outs, *args, **kwargs): - io_tensors, has_layout_free = traverse_to_get_io_tensors(outs) - key = register_auto_scheduler_workload_bufs(io_tensors) - task_env = TaskExtractEnv.current - if task_env is not None and task_env.tracing: - if task_env.do_layout_rewrite and has_layout_free: - # Rewrite the dag and update the transform history for - # the new dag in DispatchContext - dispatch_ctx = DispatchContext.current - tgt = _target.Target.current() - state = dispatch_ctx.query(tgt, key) - dag = ComputeDAG(outs) - new_dag = dag.rewrite_layout_from_state(state) - new_key = json.dumps((compute_dag_hash(new_dag),)) - dispatch_ctx.update(tgt, new_key, state) - - if new_key != key: - task_env.layout_rewrite_success_ct += 1 - - # Call schedule_func under FallbackContext() to avoid layout rewrite - cfg = BlockingEmptyContext().query(tgt, key) - return topi_schedule(cfg, outs) - - task_env.add_task(key) - - """wrapper function for topi schedule""" - tgt = _target.Target.current() - cfg = DispatchContext.current.query(tgt, key) - return topi_schedule(cfg, outs) - return wrapper - if func: - return _decorate(func) - return _decorate diff --git a/python/tvm/ansor/utils.py b/python/tvm/ansor/utils.py index 5ed9bd46d355..9e3c857aba36 100644 --- a/python/tvm/ansor/utils.py +++ b/python/tvm/ansor/utils.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -"""Common utilities""" +"""Common utilities for ansor""" import multiprocessing import multiprocessing.pool @@ -30,7 +30,7 @@ except ImportError: psutil = None -from .. import rpc as _rpc +from tvm import rpc from tvm.tir import expr from tvm.tir.transform import Simplify from tvm.ir.transform import Sequential @@ -205,7 +205,7 @@ def request_remote(device_key, host=None, port=None, priority=1, timeout=60): host = host or os.environ['TVM_TRACKER_HOST'] port = port or int(os.environ['TVM_TRACKER_PORT']) - tracker = _rpc.connect_tracker(host, port) + tracker = rpc.connect_tracker(host, port) remote = tracker.request(device_key, priority=priority, session_timeout=timeout) return remote diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index 8e6698e4a164..66ef5cd4c852 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -21,6 +21,7 @@ import logging import numpy as np import tvm +import os from tvm import te from tvm.runtime import Object from ... import target as _target @@ -141,7 +142,6 @@ def get_valid_implementations(op, attrs, inputs, out_type, target): ret.append(impl) return ret - def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True): """Select the best implementation from the op strategy. @@ -179,6 +179,9 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True) ret : tuple(relay.op.OpImplementation, List[tvm.te.Tensor]) The best op implementation and the corresponding output tensors. 
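+    Notes
+    -----
+    When the environment variable TVM_USE_AUTOTVM is unset or set to
+    'false' (the default), `use_autotvm` is forced to False below, so
+    implementations backed by the ansor auto-scheduler are preferred
+    over AutoTVM templates.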
""" + if os.environ.get('TVM_USE_AUTOTVM', 'false') == 'false': + use_autotvm = False + all_impls = get_valid_implementations(op, attrs, inputs, out_type, target) best_plevel_impl = None diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 30c5971e32b9..d1a39ceb630e 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -72,6 +72,7 @@ def __init__(self): self._get_module = self.mod["get_module"] self._build = self.mod["build"] self._optimize = self.mod["optimize"] + self._call_all_topi_funcs = self.mod["call_all_topi_funcs"] self._set_params_func = self.mod["set_params"] self._get_params_func = self.mod["get_params"] @@ -160,6 +161,12 @@ def optimize(self, mod, target=None, params=None): return mod, params + def call_all_topi_funcs(self, mod, target=None, target_host=None, params=None): + """Call all topi compute and schedule used in a relay function""" + target = _update_target(target) + if params: + self._set_params(params) + self._call_all_topi_funcs(mod, target, target_host) def _set_params(self, params): self._set_params_func(_convert_param_map(params)) diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index b02db416bdc8..2a0ddd1329b5 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -16,14 +16,16 @@ # under the License. """Definition of x86 operator strategy.""" # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import -import logging -import re -import topi +import os from tvm.te import SpecializedCondition +from tvm import ansor from .generic import * from .. import op as _op +# Set the priority level to use the Ansor auto-scheduler +ansor_plevel = 11 + logger = logging.getLogger('strategy') _NCHWc_matcher = re.compile("^NCHW[0-9]+c$") @@ -39,7 +41,7 @@ def schedule_injective_cpu(attrs, outs, target): def schedule_reduce_cpu(attrs, outs, target): """schedule reduction ops for x86""" with target: - return topi.x86.schedule_reduce(outs) + return ansor.auto_schedule_topi(outs) @schedule_concatenate.register("cpu") def schedule_concatenate_cpu(attrs, outs, target): @@ -51,13 +53,13 @@ def schedule_concatenate_cpu(attrs, outs, target): def schedule_pool_cpu(attrs, outs, target): """schedule pooling ops for x86""" with target: - return topi.x86.schedule_pool(outs, attrs.layout) + return ansor.auto_schedule_topi(outs) @schedule_adaptive_pool.register("cpu") def schedule_adaptive_pool_cpu(attrs, outs, target): """schedule adaptive pooling ops for x86""" with target: - return topi.x86.schedule_adaptive_pool(outs) + return ansor.auto_schedule_topi(outs) @softmax_strategy.register("cpu") def softmax_strategy_cpu(attrs, inputs, out_type, target): @@ -65,15 +67,15 @@ def softmax_strategy_cpu(attrs, inputs, out_type, target): strategy = _op.OpStrategy() strategy.add_implementation( wrap_compute_softmax(topi.nn.softmax), - wrap_topi_schedule(topi.x86.schedule_softmax), - name="softmax.x86") + wrap_topi_schedule(ansor.auto_schedule_topi), + name="ansor") return strategy @schedule_log_softmax.register("cpu") def schedule_log_softmax_cpu(attrs, outs, target): """schedule log_softmax op for x86""" with target: - return topi.x86.schedule_softmax(outs) + return ansor.auto_schedule_topi(outs) @conv2d_strategy.register("cpu") def conv2d_strategy_cpu(attrs, inputs, out_type, target): @@ -105,18 +107,18 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): return conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target) elif 
layout == "NHWC": assert kernel_layout == "HWIO" - logger.warning("For x86 target, NCHW layout is recommended for conv2d.") + #logger.warning("For x86 target, NCHW layout is recommended for conv2d.") strategy.add_implementation( wrap_compute_conv2d(topi.nn.conv2d_nhwc), - wrap_topi_schedule(topi.x86.schedule_conv2d_nhwc), - name="conv2d_nhwc.x86") + wrap_topi_schedule(ansor.auto_schedule_topi), + name="ansor") elif layout == "HWCN": assert kernel_layout == "HWIO" - logger.warning("conv2d HWCN layout is not optimized for x86.") + #logger.warning("conv2d HWCN layout is not optimized for x86.") strategy.add_implementation( wrap_compute_conv2d(topi.nn.conv2d_hwcn), - wrap_topi_schedule(topi.generic.schedule_conv2d_hwcn), - name="conv2d_hwcn.generic") + wrap_topi_schedule(ansor.auto_schedule_topi), + name="ansor") else: raise RuntimeError("Unsupported conv2d layout {} for x86".format(layout)) elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups): @@ -143,8 +145,8 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): logger.warning("depthwise_conv2d NHWC layout is not optimized for x86.") strategy.add_implementation( wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc), - wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc), - name="depthwise_conv2d_nhwc.generic") + wrap_topi_schedule(ansor.auto_schedule_topi), + name="ansor") else: raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout)) else: # group_conv2d @@ -153,8 +155,8 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): logger.warning("group_conv2d is not optimized for x86.") strategy.add_implementation( wrap_compute_conv2d(topi.nn.group_conv2d_nchw, has_groups=True), - wrap_topi_schedule(topi.generic.schedule_group_conv2d_nchw), - name="group_conv2d_nchw.generic") + wrap_topi_schedule(ansor.auto_schedule_topi), + name="ansor") else: raise RuntimeError("Unsupported group_conv2d layout {}".format(layout)) return strategy @@ -231,8 +233,8 @@ def conv3d_strategy_cpu(attrs, inputs, out_type, target): name="conv3d_ncdhw.x86") elif layout == "NDHWC": strategy.add_implementation(wrap_compute_conv3d(topi.x86.conv3d_ndhwc), - wrap_topi_schedule(topi.x86.schedule_conv3d_ndhwc), - name="conv3d_ndhwc.x86") + wrap_topi_schedule(ansor.auto_schedule_topi), + name="ansor") else: raise ValueError("Not support this layout {} yet".format(layout)) return strategy @@ -251,8 +253,8 @@ def conv1d_strategy_cpu(attrs, inputs, out_type, target): name="conv1d_ncw.x86") elif layout == "NWC": strategy.add_implementation(wrap_compute_conv1d(topi.nn.conv1d_nwc), - wrap_topi_schedule(topi.x86.schedule_conv1d_nwc), - name="conv1d_nwc.x86") + wrap_topi_schedule(ansor.auto_schedule_topi), + name="ansor") else: raise ValueError("Unsupported conv1d layout {}".format(layout)) return strategy @@ -261,16 +263,23 @@ def conv1d_strategy_cpu(attrs, inputs, out_type, target): def dense_strategy_cpu(attrs, inputs, out_type, target): """dense x86 strategy""" strategy = _op.OpStrategy() - m, _ = inputs[0].shape + + strategy.add_implementation(wrap_compute_dense(topi.nn.dense), + wrap_topi_schedule(ansor.auto_schedule_topi), + name='ansor', + plevel=ansor_plevel) + strategy.add_implementation(wrap_compute_dense(topi.x86.dense_nopack), wrap_topi_schedule(topi.x86.schedule_dense_nopack), name="dense_nopack.x86", plevel=10) + if "cblas" in target.libs: strategy.add_implementation(wrap_compute_dense(topi.x86.dense_cblas), wrap_topi_schedule(topi.x86.schedule_dense_cblas), name="dense_cblas.x86", plevel=15) + m, _ = 
inputs[0].shape with SpecializedCondition(m >= 16): # this implementation may not be well-optimized, so use plevel=8 for now. strategy.add_implementation(wrap_compute_dense(topi.x86.dense_pack), @@ -283,6 +292,12 @@ def dense_strategy_cpu(attrs, inputs, out_type, target): def batch_matmul_strategy_cpu(attrs, inputs, out_type, target): """batch_matmul x86 strategy""" strategy = _op.OpStrategy() + + strategy.add_implementation(wrap_compute_dense(topi.nn.batch_matmul), + wrap_topi_schedule(ansor.auto_schedule_topi), + name='ansor', + plevel=ansor_plevel) + strategy.add_implementation(wrap_compute_batch_matmul(topi.x86.batch_matmul), wrap_topi_schedule(topi.x86.schedule_batch_matmul), name="batch_matmul.x86", diff --git a/python/tvm/relay/testing/resnet.py b/python/tvm/relay/testing/resnet.py index 8633879465bd..4383157d9f06 100644 --- a/python/tvm/relay/testing/resnet.py +++ b/python/tvm/relay/testing/resnet.py @@ -59,9 +59,11 @@ def residual_unit(data, name : str Base name of the operators """ + bn_axis = data_layout.index('C') if bottle_neck: bn1 = layers.batch_norm_infer(data=data, epsilon=2e-5, + axis=bn_axis, name=name + '_bn1') act1 = relay.nn.relu(data=bn1) conv1 = layers.conv2d( @@ -73,13 +75,13 @@ def residual_unit(data, name=name + '_conv1', data_layout=data_layout, kernel_layout=kernel_layout) - bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, name=name + '_bn2') + bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, axis=bn_axis, name=name + '_bn2') act2 = relay.nn.relu(data=bn2) conv2 = layers.conv2d( data=act2, channels=int(num_filter*0.25), kernel_size=(3, 3), strides=(1, 1), padding=(1, 1), name=name + '_conv2', data_layout=data_layout, kernel_layout=kernel_layout) - bn3 = layers.batch_norm_infer(data=conv2, epsilon=2e-5, name=name + '_bn3') + bn3 = layers.batch_norm_infer(data=conv2, epsilon=2e-5, axis=bn_axis, name=name + '_bn3') act3 = relay.nn.relu(data=bn3) conv3 = layers.conv2d( data=act3, channels=num_filter, kernel_size=(1, 1), @@ -94,13 +96,13 @@ def residual_unit(data, data_layout=data_layout, kernel_layout=kernel_layout) return relay.add(conv3, shortcut) - bn1 = layers.batch_norm_infer(data=data, epsilon=2e-5, name=name + '_bn1') + bn1 = layers.batch_norm_infer(data=data, epsilon=2e-5, axis=bn_axis, name=name + '_bn1') act1 = relay.nn.relu(data=bn1) conv1 = layers.conv2d( data=act1, channels=num_filter, kernel_size=(3, 3), strides=stride, padding=(1, 1), name=name + '_conv1', data_layout=data_layout, kernel_layout=kernel_layout) - bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, name=name + '_bn2') + bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, axis=bn_axis, name=name + '_bn2') act2 = relay.nn.relu(data=bn2) conv2 = layers.conv2d( data=act2, channels=num_filter, kernel_size=(3, 3), @@ -156,11 +158,12 @@ def resnet(units, data_layout = layout kernel_layout = "OIHW" if layout == "NCHW" else "HWIO" + bn_axis = data_layout.index('C') num_unit = len(units) assert num_unit == num_stages data = relay.var("data", shape=data_shape, dtype=dtype) - data = layers.batch_norm_infer(data=data, epsilon=2e-5, scale=False, name='bn_data') + data = layers.batch_norm_infer(data=data, epsilon=2e-5, axis=bn_axis, scale=False, name='bn_data') (_, _, height, _) = data_shape if layout == "NHWC": (_, height, _, _) = data_shape @@ -174,7 +177,7 @@ def resnet(units, data=data, channels=filter_list[0], kernel_size=(7, 7), strides=(2, 2), padding=(3, 3), name="conv0", data_layout=data_layout, kernel_layout=kernel_layout) - body = layers.batch_norm_infer(data=body, 
epsilon=2e-5, name='bn0') + body = layers.batch_norm_infer(data=body, epsilon=2e-5, axis=bn_axis, name='bn0') body = relay.nn.relu(data=body) body = relay.nn.max_pool2d(data=body, pool_size=(3, 3), strides=(2, 2), padding=(1, 1), layout=data_layout) @@ -189,7 +192,7 @@ def resnet(units, body, filter_list[i+1], (1, 1), True, name='stage%d_unit%d' % (i + 1, j + 2), bottle_neck=bottle_neck, data_layout=data_layout, kernel_layout=kernel_layout) - bn1 = layers.batch_norm_infer(data=body, epsilon=2e-5, name='bn1') + bn1 = layers.batch_norm_infer(data=body, epsilon=2e-5, axis=bn_axis, name='bn1') relu1 = relay.nn.relu(data=bn1) # Although kernel is not used here when global_pool=True, we should put one pool1 = relay.nn.global_avg_pool2d(data=relu1, layout=data_layout) diff --git a/scripts/tune_network.py b/scripts/tune_network.py index dc17f407d003..d4f1afd95572 100644 --- a/scripts/tune_network.py +++ b/scripts/tune_network.py @@ -200,14 +200,7 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, if tune: print("=============== Extracting workloads ===============") - workloads, wkl_weights = ansor.extract_from_program(mod, target=target, - params=params, ops=(relay.op.nn.dense, relay.op.nn.softmax, - relay.op.nn.conv2d, relay.op.nn.conv2d_transpose, - relay.op.nn.max_pool2d, relay.op.nn.avg_pool2d, - relay.op.nn.global_max_pool2d, relay.op.nn.global_avg_pool2d, - relay.op.nn.conv3d, relay.op.nn.adaptive_avg_pool3d, - relay.op.nn.batch_matmul, relay.op.mean, - )) + workloads, wkl_weights = ansor.extract_from_program(mod, target=target, params=params) print("Extracted %d workloads in total." % (len(workloads))) # Tune workloads with auto scheduler @@ -238,15 +231,13 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, os.environ['TVM_AUTO_CACHE_FLUSH'] = "0" os.environ['TVM_BIND_MASTER_CORE_0'] = "1" if kernel_layout_rewrite: - ansor.prepare_layout_rewrite(mod, target=target, - params=params, - ops=(relay.op.nn.dense, relay.op.nn.conv2d, relay.op.nn.conv3d)) + ansor.prepare_layout_rewrite(mod, target=target, params=params) else: # disable layout rewrite ansor.LayoutRewriteLevel.BOTH_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE ansor.LayoutRewriteLevel.COMPUTE_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE - with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): + with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): graph, lib, opt_params = relay.build_module.build( mod, target=target, params=params) diff --git a/scripts/tune_test.py b/scripts/tune_test.py index 7831aea9dd4a..86f055caf889 100644 --- a/scripts/tune_test.py +++ b/scripts/tune_test.py @@ -36,7 +36,7 @@ def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose builder=builder, runner=runner, measure_callbacks=[ansor.LogToFile(log_file)], - pre_search_callbacks=[ansor.PreLoadMeasuredStates(log_file)]) + pre_search_callbacks=[ansor.PreloadMeasuredStates(log_file)]) return tune_option, measure_ctx diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index fec301dc54bc..6269b9f16f71 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -902,15 +902,12 @@ void ComputeDAG::RewriteLayout( for (size_t i = 0; i < old_ops.size(); ++i) { auto old_op = old_ops[i]; if (rewrite_placeholder && old_op == placeholder_op) { - //pops->data[i] = new_placeholder_op; pops->SetItem(i, new_placeholder_op); updated_ops[placeholder_op] = new_placeholder_op; } else if (rewrite_body && old_op == old_compute_op) { -
//pops->data[i] = new_compute_op; pops->SetItem(i, new_compute_op); updated_ops[old_compute_op] = new_compute_op; } else { - //pops->data[i] = old_op; pops->SetItem(i, old_op); } } @@ -936,7 +933,6 @@ void ComputeDAG::RewriteLayout( if (!rmap.empty()) { te::Operation new_op = pop->ReplaceInputs(old_op, rmap); updated_ops[old_op] = new_op; - //pops->data[i] = new_op; pops->SetItem(i, new_op); } } @@ -958,7 +954,6 @@ void ComputeDAG::RewriteLayout( if (new_op.defined()) { if (layout_rewrite_level == kBothRewrite) { auto index = old_tensor->value_index; - //ptensors->data[i] = new_op.output(index); ptensors->SetItem(i, new_op.output(index)); } else if (layout_rewrite_level == kComputeRewrite) { te::TensorNode* old_tensor_node = diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index c71c4f1b6586..8da71f005f19 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -105,7 +105,7 @@ typedef std::unordered_map, ObjectHash void UpdateStageAxis(const tvm::te::Stage& stage, StageToAxesMap *stage_to_axes); -/*! \brief Compute declaration graph */ +/*! \brief Computation declaration graph */ class ComputeDAGNode : public Object { public: Array tensors; // Input and output tensors diff --git a/src/ansor/feature.cc b/src/ansor/feature.cc index 497a3ac4222b..3c6976a0e25a 100644 --- a/src/ansor/feature.cc +++ b/src/ansor/feature.cc @@ -653,9 +653,9 @@ class PerStmtFeatureExtractor : public StmtExprVisitor { fea.vec_prod *= GetIntImm(pfor->extent); } fea.vec_type = kPosMixed; - // todo(lmzheng): this feature requires operation (tvm.compute) information - //GetAnnotationPosEncoding(vec_for_stack.back()->loop_var, - //node->args, pcompute->axis, pcompute->reduce_axis); + // todo(lmzheng): this feature requires operation (tvm.compute) information + // GetAnnotationPosEncoding(vec_for_stack.back()->loop_var, + // node->args, pcompute->axis, pcompute->reduce_axis); } fea.unroll_num = unroll_for_stack.size(); @@ -666,8 +666,8 @@ class PerStmtFeatureExtractor : public StmtExprVisitor { fea.unroll_prod *= GetIntImm(pfor->extent); } fea.unroll_type = kPosMixed; - //GetAnnotationPosEncoding(unroll_for_stack.back()->loop_var, - //node->args, pcompute->axis, pcompute->reduce_axis); + // GetAnnotationPosEncoding(unroll_for_stack.back()->loop_var, + // node->args, pcompute->axis, pcompute->reduce_axis); } fea.parallel_num = parallel_for_stack.size(); @@ -678,8 +678,8 @@ class PerStmtFeatureExtractor : public StmtExprVisitor { fea.parallel_prod *= GetIntImm(pfor->extent); } fea.parallel_type = kPosMixed; - //GetAnnotationPosEncoding(parallel_for_stack.back()->loop_var, - //node->args, pcompute->axis, pcompute->reduce_axis); + // GetAnnotationPosEncoding(parallel_for_stack.back()->loop_var, + // node->args, pcompute->axis, pcompute->reduce_axis); } // GPU threads @@ -1213,7 +1213,8 @@ void GetPerStmtFeaturesWorkerFunc(const SearchTask& task, const State& state, const auto& optimize = tir::transform::Sequential(pass_list); optimize(mod); } - const auto& optimize = tir::transform::Sequential(Array{tir::transform::Simplify()}); + const auto& optimize = tir::transform::Sequential( + Array{tir::transform::Simplify()}); mod = optimize(std::move(mod)); const auto& it = mod->functions.find(global_var); CHECK(it != mod->functions.end()); @@ -1241,8 +1242,8 @@ void GetPerStmtFeaturesFromStates(const Array& states, for (size_t i = skip_first_n_feature_extraction; i < states.size(); ++i) { pool.Enqueue(GetPerStmtFeaturesWorkerFunc, task, states[i], max_n_bufs, &(*features)[i], &error_ct); - 
//GetPerStmtFeaturesWorkerFunc(task, states[i], - // max_n_bufs, &(*features)[i], &error_ct); + // GetPerStmtFeaturesWorkerFunc(task, states[i], + // max_n_bufs, &(*features)[i], &error_ct); } pool.WaitBatch(); diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index 73bbade241c5..474ea048ebad 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -1,6 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! - * Copyright (c) 2020 by Contributors + * Copyright (c) 2020 by Contributors + * \file ansor/measure.cc + * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs */ + #include "measure.h" #include #include diff --git a/src/ansor/measure.h b/src/ansor/measure.h index 780a30514d46..6e432ba9c88b 100644 --- a/src/ansor/measure.h +++ b/src/ansor/measure.h @@ -1,6 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + /*! 
* Copyright (c) 2020 by Contributors - * \file ansor/search_task.h + * \file ansor/measure.h * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs */ diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.cc b/src/ansor/search_policy/meta_tile_rewrite_policy.cc index 7e022e3be3c3..8b5b97224c08 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.cc +++ b/src/ansor/search_policy/meta_tile_rewrite_policy.cc @@ -62,8 +62,8 @@ State MetaTileRewritePolicyNode::Search(SearchTask task, int n_trials, int verbose, ProgramMeasurer measurer, Array pre_search_callbacks) { std::vector best_states, random_states; - cur_task_ = task; - verbose_ = verbose; + this->cur_task = task; + this->verbose = verbose; num_measure_per_iter_ = num_measure_per_iter; RunCallbacks(pre_search_callbacks); @@ -85,17 +85,17 @@ State MetaTileRewritePolicyNode::Search(SearchTask task, int n_trials, while (ct < n_trials) { if (!inputs.empty()) { // retrain cost models - PrintTitle("Train cost model", verbose_); + PrintTitle("Train cost model", verbose); program_cost_model->Update(inputs, results); } // Search one round to get promising states - PrintTitle("Search", verbose_); + PrintTitle("Search", verbose); SearchOneRound(&best_states, num_random, &random_states); // Fill correct bound. This is necessary for computing the correct ToStr() for redundancy check - cur_task_->compute_dag.InferBound(&best_states); - cur_task_->compute_dag.InferBound(&random_states); + cur_task->compute_dag.InferBound(&best_states); + cur_task->compute_dag.InferBound(&random_states); // Pick `num_measure_per_iter` states to measure, check hash to remove already measured state // Also pick some random states to do eps-greedy @@ -108,11 +108,11 @@ State MetaTileRewritePolicyNode::Search(SearchTask task, int n_trials, } // Measure candidate states - PrintTitle("Measure", verbose_); - measurer->Measure(cur_task_, GetRef(this), inputs, &results); + PrintTitle("Measure", verbose); + measurer->Measure(cur_task, GetRef(this), inputs, &results); ct += inputs.size(); - if (ct - measurer->best_ct[cur_task_->workload_key] > early_stopping) { + if (ct - measurer->best_ct[cur_task->workload_key] > early_stopping) { StdCout(verbose) << "Meet the early stopping condition." << std::endl; break; } @@ -122,21 +122,21 @@ State MetaTileRewritePolicyNode::Search(SearchTask task, int n_trials, measured_states_throughputs_.push_back(1.0 / FloatArrayMean(res->costs)); } } - PrintTitle("Done", verbose_); + PrintTitle("Done", verbose); - return measurer->best_state[cur_task_->workload_key]; + return measurer->best_state[cur_task->workload_key]; } } std::pair, Array > MetaTileRewritePolicyNode::ContinueSearchOneRound( SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) { - if (cur_task_.defined()) { - CHECK_EQ(cur_task_, task); + if (cur_task.defined()) { + CHECK_EQ(cur_task, task); } else { - cur_task_ = task; + cur_task = task; } - verbose_ = verbose; + this->verbose = verbose; num_measure_per_iter_ = num_measure; std::vector best_states, random_states; @@ -149,8 +149,8 @@ std::pair, Array > SearchOneRound(&best_states, num_random * 2, &random_states); // Fill correct bound.
This is necessary for computing the correct ToStr() for redundancy check - cur_task_->compute_dag.InferBound(&best_states); - cur_task_->compute_dag.InferBound(&random_states); + cur_task->compute_dag.InferBound(&best_states); + cur_task->compute_dag.InferBound(&random_states); // Pick `num_measure` states to measure, check hash to remove already measured state // Also pick some random states to do eps-greedy @@ -158,7 +158,7 @@ std::pair, Array > // Measure candidate states PrintTitle("Measure", verbose); - measurer->Measure(cur_task_, GetRef(this), inputs, &results); + measurer->Measure(cur_task, GetRef(this), inputs, &results); // Update throughputs of measured states. These states will join the LocalMutation in later rounds for (const auto& res : results) { @@ -219,7 +219,7 @@ void MetaTileRewritePolicyNode::PickStatesWithEpsGreedy( if (measured_states_set_.count(state_str)) { continue; } measured_states_set_.insert(state_str); - inputs->push_back(MeasureInputNode::make(cur_task_, *pstate)); + inputs->push_back(MeasureInputNode::make(cur_task, *pstate)); measured_states_vector_.push_back(std::move(*pstate)); } } @@ -288,7 +288,7 @@ class SketchGenerationRule { static inline bool ShouldBeCacheRead( const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) { - const SearchTask& task = policy->cur_task_; + const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; if (HasAttrsFlag(state, stage_id, @@ -320,7 +320,7 @@ static inline bool ShouldBeCacheRead( static inline bool ShouldAlwaysBeInlined( const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) { - const SearchTask& task = policy->cur_task_; + const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; if (stage->op->IsInstance()) { @@ -336,7 +336,7 @@ static inline bool ShouldAlwaysBeInlined( if (HasAttrsFlag(state, stage_id, SearchPolicyNode::always_compute_inline_key) || IsStrictInlineable(task, state, stage->op) || - (IS_GPU(policy->cur_task_) && + (IS_GPU(policy->cur_task) && !ShouldBeCacheRead(policy, state, stage_id))) { return true; } @@ -367,7 +367,7 @@ class RuleSkipStage : public SketchGenerationRule { public: ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task_; + const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; const auto& attrs = stage->op->attrs; @@ -392,16 +392,16 @@ class RuleMultiLevelTiling : public SketchGenerationRule { public: ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task_; + const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; return NeedsMultilevelTiling(task, state, stage->op) ? - (IS_GPU(policy->cur_task_) ? kApplyAndSkipRest : kApply) : kPass; + (IS_GPU(policy->cur_task) ? kApplyAndSkipRest : kApply) : kPass; } std::vector > Apply(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { - std::string multi_level_tiling_structure = IS_GPU(policy->cur_task_) ?
GetStringParam(policy->params, "gpu_multi_level_tiling_structure") : GetStringParam(policy->params, "cpu_multi_level_tiling_structure"); @@ -418,12 +418,12 @@ class RuleMultiLevelTilingWithFusion : public SketchGenerationRule { public: ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task_; + const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; int target_stage_id; - if (IS_GPU(policy->cur_task_)) { + if (IS_GPU(policy->cur_task)) { return NeedsMultilevelTiling(task, state, stage->op) && HasSingleElementwiseMatchedConsumer(task, state, stage, &target_stage_id) && @@ -440,9 +440,9 @@ class RuleMultiLevelTilingWithFusion : public SketchGenerationRule { std::vector > Apply(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task_; + const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; - std::string multi_level_tiling_structure = IS_GPU(policy->cur_task_) ? + std::string multi_level_tiling_structure = IS_GPU(policy->cur_task) ? GetStringParam(policy->params, "gpu_multi_level_tiling_structure") : GetStringParam(policy->params, "cpu_multi_level_tiling_structure"); @@ -457,7 +457,7 @@ class RuleMultiLevelTilingWithFusion : public SketchGenerationRule { base_state = DoMultiLevelTiling(base_state, stage_id, multi_level_tiling_structure, &spatial_split_step_ids); std::vector follow_tiling_levels; - if (IS_GPU(policy->cur_task_)) { + if (IS_GPU(policy->cur_task)) { follow_tiling_levels.push_back(3); } else { follow_tiling_levels.push_back(1); @@ -487,7 +487,7 @@ class RuleAddCacheWrite : public SketchGenerationRule { public: ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task_; + const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; int target_stage_id; @@ -505,7 +505,7 @@ class RuleAddCacheWrite : public SketchGenerationRule { std::vector > Apply(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task_; + const SearchTask& task = policy->cur_task; State tmp_s = state; tmp_s.cache_write(stage_id, "local", task->compute_dag); @@ -526,7 +526,7 @@ class RuleAddCacheRead : public SketchGenerationRule { std::vector > Apply(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task_; + const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; std::unordered_set consumers; @@ -551,7 +551,7 @@ class RuleAddRfactor : public SketchGenerationRule { public: ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task_; + const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; return NeedsRfactor(task, state, stage->op) && @@ -561,7 +561,7 @@ class RuleAddRfactor : public SketchGenerationRule { std::vector > Apply(const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task_; + const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; std::vector > ret; @@ -613,7 +613,7 @@ class RuleAddRfactor : public SketchGenerationRule { void 
MetaTileRewritePolicyNode::GenerateMetaSketch( std::vector* out_states) { - State init_state = cur_task_->compute_dag.GetInitState(); + State init_state = cur_task->compute_dag.GetInitState(); std::string cpu_multi_level_tiling_structure = GetStringParam(params, "cpu_multi_level_tiling_structure"); @@ -644,7 +644,7 @@ void MetaTileRewritePolicyNode::GenerateMetaSketch( sketch_rules.push_back(&rule_multi_level_tiling); sketch_rules.push_back(&rule_add_rfactor); sketch_rules.push_back(&rule_skip_stage); - if (IS_GPU(cur_task_)) { + if (IS_GPU(cur_task)) { // Try cache read first before cache write sketch_rules.insert(sketch_rules.begin() + 1, &rule_add_cache_read_stage); } @@ -705,7 +705,7 @@ void MetaTileRewritePolicyNode::GenerateMetaSketch( } } - StdCout(verbose_) << "Synthesize Meta Structure\t\t#s: " << out_states->size() << std::endl; + StdCout(verbose) << "Synthesize Meta Structure\t\t#s: " << out_states->size() << std::endl; } int InitPopulationFillTileSize(const MetaTileRewritePolicyNode* policy, @@ -728,7 +728,7 @@ int InitPopulationFillTileSize(const MetaTileRewritePolicyNode* policy, const std::vector >& candidate_lens = split_memo->GetFactorizationSchemes( extent, ps->lengths.size(), - policy->cur_task_->hardware_params->max_innermost_split_factor); + policy->cur_task->hardware_params->max_innermost_split_factor); StateNode* pstate = state->CopyOnWrite(); pstate->transform_steps[step_id] = SplitStepNode::make( @@ -771,11 +771,11 @@ int InitPopulationThreadBind(const MetaTileRewritePolicyNode* policy, // Set default vthread=1 & threadIdx.x=default_warp_size // EvolutionarySearch will try more possibilities if (GetExtent(fused_it) <= - policy->cur_task_->hardware_params->warp_size) { + policy->cur_task->hardware_params->warp_size) { state->bind_thread(stage_id, fused_it, kThreadX); } else { const auto& split_its = state->split(stage_id, fused_it, - {1, policy->cur_task_->hardware_params->warp_size}); + {1, policy->cur_task->hardware_params->warp_size}); state->bind_thread(stage_id, split_its[0], kBlockX); state->bind_thread(stage_id, split_its[1], kVThread); state->bind_thread(stage_id, split_its[2], kThreadX); @@ -793,7 +793,7 @@ int InitPopulationThreadBind(const MetaTileRewritePolicyNode* policy, } // TODO(..): Add ThreadBind support for rfactor - if (total_space_extent <= policy->cur_task_->hardware_params->warp_size) { + if (total_space_extent <= policy->cur_task->hardware_params->warp_size) { for (const auto& it : (*state)->stages[stage_id]->iters) { if (it->iter_type == kReduce) { break; } @@ -828,7 +828,7 @@ int InitPopulationThreadBind(const MetaTileRewritePolicyNode* policy, } const auto& vthread_it = state->fuse(stage_id, to_fuse); if (GetExtent(vthread_it) > - policy->cur_task_->hardware_params->max_vthread_extent) { + policy->cur_task->hardware_params->max_vthread_extent) { return -1; } state->bind_thread(stage_id, vthread_it, kVThread); @@ -844,7 +844,7 @@ int InitPopulationThreadBind(const MetaTileRewritePolicyNode* policy, } const auto& threadidx_it = state->fuse(stage_id, to_fuse); if (GetExtent(threadidx_it) < - policy->cur_task_->hardware_params->warp_size) { + policy->cur_task->hardware_params->warp_size) { return -1; } state->bind_thread(stage_id, threadidx_it, kThreadX); @@ -876,7 +876,7 @@ int InitPopulationCooperativeFetching(const MetaTileRewritePolicyNode* policy, // Get spatial_split_step_ids from the root stage std::unordered_set consumers; std::vector spatial_split_step_ids; - GetConsumers(policy->cur_task_, (*state), target_stage->op, &consumers); +
GetConsumers(policy->cur_task, (*state), target_stage->op, &consumers); CHECK_EQ(consumers.size(), 1); int target_stage_id = OperationToStage(*consumers.begin(), (*state)); GetSpaceSplitStepIds((*state), target_stage_id, &spatial_split_step_ids); @@ -915,13 +915,13 @@ int InitPopulationChangeComputeLocation(const MetaTileRewritePolicyNode* policy, continue; } - if (NeedsMultilevelTiling(policy->cur_task_, (*state), stage->op)) { + if (NeedsMultilevelTiling(policy->cur_task, (*state), stage->op)) { continue; } std::unordered_set consumers; - GetConsumers(policy->cur_task_, (*state), stage->op, &consumers); + GetConsumers(policy->cur_task, (*state), stage->op, &consumers); if (consumers.empty()) { continue; } @@ -1083,7 +1083,7 @@ int InitPopulationParallel(const MetaTileRewritePolicyNode* policy, to_fuse.push_back(it); parallel_degree *= GetExtent(it); - if (parallel_degree > policy->cur_task_->hardware_params->num_cores * 16) { + if (parallel_degree > policy->cur_task->hardware_params->num_cores * 16) { break; } @@ -1135,7 +1135,7 @@ int InitPopulationVectorization(const MetaTileRewritePolicyNode* policy, } // Skip cooperative fetching stage - if (IS_GPU(policy->cur_task_) && + if (IS_GPU(policy->cur_task) && HasCacheReadStage((*state), stage_id - 1)) { continue; } @@ -1179,7 +1179,7 @@ int InitPopulationVectorization(const MetaTileRewritePolicyNode* policy, } cum_length_prod *= GetExtent(it); - if (cum_length_prod > policy->cur_task_->hardware_params->max_unroll_vec) { + if (cum_length_prod > policy->cur_task->hardware_params->max_unroll_vec) { break; } @@ -1278,13 +1278,13 @@ void MetaTileRewritePolicyNode::SampleInitPopulation(const std::vector& m InitPopulationFillTileSize(this, &tmp_s, &rand_gen_, &split_memo_); - if (IS_GPU(cur_task_)) { - tmp_s = cur_task_->compute_dag.InferBound(tmp_s); + if (IS_GPU(cur_task)) { + tmp_s = cur_task->compute_dag.InferBound(tmp_s); if (InitPopulationThreadBind(this, &tmp_s)) { continue_count++; if (continue_count == out_size) { - StdCout(verbose_) << "Initial Population Sampling..." << std::endl; + StdCout(verbose) << "Initial Population Sampling..." 
<< std::endl; } continue; } @@ -1293,7 +1293,7 @@ void MetaTileRewritePolicyNode::SampleInitPopulation(const std::vector& m } else { InitPopulationChangeComputeLocation(this, &tmp_s, &rand_gen_); - tmp_s = cur_task_->compute_dag.InferBound(tmp_s); + tmp_s = cur_task->compute_dag.InferBound(tmp_s); InitPopulationParallel(this, &tmp_s); } @@ -1305,8 +1305,8 @@ void MetaTileRewritePolicyNode::SampleInitPopulation(const std::vector& m out_states->push_back(std::move(tmp_s)); } - StdCout(verbose_) << "Sample Initial Population\t\t#s: " - << out_states->size() << std::endl; + StdCout(verbose) << "Sample Initial Population\t\t#s: " + << out_states->size() << std::endl; } void MetaTileRewritePolicyNode::EvolutionarySearch( @@ -1350,9 +1350,9 @@ void MetaTileRewritePolicyNode::EvolutionarySearch( // Genetic Algorithm for (int k = 0; k < num_iters + 1; ++k) { // Maintain the heap - cur_task_->compute_dag.InferBound(pnow); + cur_task->compute_dag.InferBound(pnow); PruneUndefined(pnow); - cost_model->Predict(cur_task_, *pnow, &scores); + cost_model->Predict(cur_task, *pnow, &scores); for (size_t i = 0; i < pnow->size(); ++i) { const State& state = (*pnow)[i]; @@ -1379,10 +1379,10 @@ void MetaTileRewritePolicyNode::EvolutionarySearch( } if (k % 5 == 0 || k == num_iters) { - StdCout(verbose_) << "GA Iter: " << k << std::fixed << std::setprecision(4) - << "\tMax score: " << max_score - << "\tMin score: " << heap.front().second - << "\tPop size: " << pnow->size() << std::endl; + StdCout(verbose) << "GA Iter: " << k << std::fixed << std::setprecision(4) + << "\tMax score: " << max_score + << "\tMin score: " << heap.front().second + << "\tPop size: " << pnow->size() << std::endl; } if (k == num_iters) { @@ -1431,7 +1431,7 @@ void MetaTileRewritePolicyNode::EvolutionarySearch( if (rule_id == 0) { // Mutate Tile Size State tmp_s = RandomMutateTileSize((*pnow)[id], &split_memo_, &rand_gen_, - cur_task_->hardware_params->max_innermost_split_factor); + cur_task->hardware_params->max_innermost_split_factor); if (tmp_s.defined()) { pnext->push_back(std::move(tmp_s)); } else { @@ -1463,9 +1463,9 @@ void MetaTileRewritePolicyNode::EvolutionarySearch( double duration = std::chrono::duration_cast >( std::chrono::high_resolution_clock::now()- tic_begin).count(); - StdCout(verbose_) << "EvolutionarySearch\t\t#s: " << best_states->size() - << "\tTime elapsed: " - << std::fixed << std::setprecision(2) << duration << std::endl; + StdCout(verbose) << "EvolutionarySearch\t\t#s: " << best_states->size() + << "\tTime elapsed: " + << std::fixed << std::setprecision(2) << duration << std::endl; } class RuleCustomSketch : public SketchGenerationRule { @@ -1519,7 +1519,7 @@ void PreAddCustomRuleNode::callback(SearchPolicyNode* policy) { auto meta_policy = dynamic_cast(policy); meta_policy->sketch_rules.emplace_back( new RuleCustomSketch(meet_condition_func, apply_func)); - StdCout(policy->verbose_) << "Custom sketch rule added." << std::endl; + StdCout(policy->verbose) << "Custom sketch rule added." 
<< std::endl; } TVM_REGISTER_GLOBAL("ansor.MetaTileRewritePolicy") diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc index 685052f3f71f..c9bccfdce806 100644 --- a/src/ansor/search_policy/search_policy.cc +++ b/src/ansor/search_policy/search_policy.cc @@ -23,18 +23,16 @@ */ #include "search_policy.h" - #include - #include "../serialization.h" namespace tvm { namespace ansor { TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode); -TVM_REGISTER_OBJECT_TYPE(PreLoadMeasuredStatesNode); +TVM_REGISTER_OBJECT_TYPE(PreloadMeasuredStatesNode); -void SearchPolicyNode::PreLoadMeasuredStates(const std::string& log_file) { +void SearchPolicyNode::PreloadMeasuredStates(const std::string& log_file) { LogReader reader = LogReaderNode::make(log_file); const auto& res = reader->ReadLines(-1); size_t log_size = res.first.size(); @@ -44,18 +42,18 @@ void SearchPolicyNode::PreLoadMeasuredStates(const std::string& log_file) { std::vector measured_throughputs; for (size_t i = 0; i < log_size; i++) { const auto& inp = res.first[i]; - if (inp->task->workload_key == cur_task_->workload_key && + if (inp->task->workload_key == cur_task->workload_key && inp->task->target->target_name.compare( - cur_task_->target->target_name) == 0) { - State state = cur_task_->compute_dag.GetInitState(); + cur_task->target->target_name) == 0) { + State state = cur_task->compute_dag.GetInitState(); state.CopyOnWrite()->transform_steps = inp->state->transform_steps; - state.DoSteps(inp->state->transform_steps, cur_task_->compute_dag); + state.DoSteps(inp->state->transform_steps, cur_task->compute_dag); measured_states.emplace_back(std::move(state)); measured_throughputs.push_back(res.second[i]->error_no == 0 ? (1.0 / FloatArrayMean(res.second[i]->costs)) : 0.0); } } - cur_task_->compute_dag.InferBound(&measured_states); + cur_task->compute_dag.InferBound(&measured_states); for (size_t i = 0; i < measured_states.size(); i ++) { auto& state = measured_states[i]; const auto& state_str = state.ToStr(); @@ -68,33 +66,32 @@ void SearchPolicyNode::PreLoadMeasuredStates(const std::string& log_file) { } } - StdCout(verbose_) << "Measured States Set: " << measured_states_set_.size() - << " state hashes loaded from " << log_file - << " for " << cur_task_->workload_key << std::endl; + StdCout(verbose) << "Successfully loaded " << measured_states_set_.size() + << " measurement records from " << log_file + << " for " << cur_task->workload_key << std::endl; } else { - StdCout(verbose_) << "Measured States Set: no states found from " - << log_file << " for " << cur_task_->workload_key - << std::endl; + StdCout(verbose) << "No measurement records found in " + << log_file << " for " << cur_task->workload_key << std::endl; } } void SearchPolicyNode::RunCallbacks(const Array& callbacks) { if (callbacks.defined() && callbacks.size()) { - PrintTitle("Process search callbacks", verbose_); + PrintTitle("Call search callbacks", verbose); for (const auto& callback : callbacks) { callback->callback(this); } } } -SearchCallback PreLoadMeasuredStatesNode::make(std::string filename) { - auto node = make_object(); +SearchCallback PreloadMeasuredStatesNode::make(std::string filename) { + auto node = make_object(); node->filename = std::move(filename); return SearchCallback(node); } -void PreLoadMeasuredStatesNode::callback(SearchPolicyNode* policy) { - policy->PreLoadMeasuredStates(filename); +void PreloadMeasuredStatesNode::callback(SearchPolicyNode* policy) { + policy->PreloadMeasuredStates(filename); } // Search Policy @@ -103,8
+100,7 @@ TVM_REGISTER_GLOBAL("ansor.SearchPolicyContinueSearchOneRound") int verbose, ProgramMeasurer measurer) { Array inputs; Array results; - std::tie(inputs, results) = policy->ContinueSearchOneRound(task, num_measure, - verbose, measurer); + std::tie(inputs, results) = policy->ContinueSearchOneRound(task, num_measure, verbose, measurer); return Array{inputs, results}; }); @@ -115,17 +111,17 @@ TVM_REGISTER_GLOBAL("ansor.SearchPolicyRunCallbacks") TVM_REGISTER_GLOBAL("ansor.SearchPolicySetTask") .set_body_typed([](SearchPolicy policy, SearchTask task) { - policy->cur_task_ = task; + policy->cur_task = task; }); TVM_REGISTER_GLOBAL("ansor.SearchPolicySetVerbose") .set_body_typed([](SearchPolicy policy, int verbose) { - policy->verbose_ = verbose; + policy->verbose = verbose; }); -TVM_REGISTER_GLOBAL("ansor.PreLoadMeasuredStates") +TVM_REGISTER_GLOBAL("ansor.PreloadMeasuredStates") .set_body_typed([](std::string filename) { - return PreLoadMeasuredStatesNode::make(filename); + return PreloadMeasuredStatesNode::make(filename); }); } // namespace ansor diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h index 6085fd1816e8..f1f6f45fce9a 100644 --- a/src/ansor/search_policy/search_policy.h +++ b/src/ansor/search_policy/search_policy.h @@ -25,12 +25,12 @@ #ifndef TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ #define TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_ +#include "../search_task.h" #include #include #include #include #include -#include "../search_task.h" #include "../measure.h" namespace tvm { @@ -39,6 +39,7 @@ namespace ansor { class SearchPolicy; class SearchPolicyNode; +/*! Callback function to be called before or after the search process */ class SearchCallbackNode : public Object { public: virtual void callback(SearchPolicyNode* policy) = 0; @@ -47,7 +48,9 @@ class SearchCallbackNode : public Object { }; TVM_DEFINE_MUTABLE_OBJECT_REF(SearchCallback, SearchCallbackNode); -class PreLoadMeasuredStatesNode : public SearchCallbackNode { +/*! \brief Preload measured states from a log file. + * This can resume the state of the search policy */ +class PreloadMeasuredStatesNode : public SearchCallbackNode { public: std::string filename; @@ -55,44 +58,48 @@ class PreLoadMeasuredStatesNode : public SearchCallbackNode { void callback(SearchPolicyNode* policy) final; - static constexpr const char *_type_key = "ansor.PreLoadMeasuredStates"; - TVM_DECLARE_FINAL_OBJECT_INFO(PreLoadMeasuredStatesNode, SearchCallbackNode); + static constexpr const char *_type_key = "ansor.PreloadMeasuredStates"; + TVM_DECLARE_FINAL_OBJECT_INFO(PreloadMeasuredStatesNode, SearchCallbackNode); }; /*! \brief The base class for search policy */ class SearchPolicyNode : public Object { public: + SearchTask cur_task; // The current task + int verbose; // Verbose level (0 means silent) + + void VisitAttrs(AttrVisitor* v) { + v->Visit("cur_task", &cur_task); + v->Visit("verbose", &verbose); + } + + // Search for a task virtual State Search(SearchTask task, int n_trials, int early_stopping, int num_measure_per_iter, int verbose, ProgramMeasurer measurer, Array pre_search_callbacks) = 0; + // Continue search one round for a task. + // This is used in the task scheduler for searching for multiple tasks together. 
virtual std::pair, Array > ContinueSearchOneRound( SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) = 0; - void PreLoadMeasuredStates(const std::string& log_file); - void RunCallbacks(const Array& callbacks); - - SearchTask cur_task_; // The current task - int verbose_; // Verbose level (0 means silent) + // Preload measured states from a log file to resume the state of the search policy + void PreloadMeasuredStates(const std::string& log_file); - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("cur_task", &cur_task_); - } + // Run a list of callback functions + void RunCallbacks(const Array& callbacks); - // Dict keys + // Dict keys to give hints to the policy static constexpr const char* always_unroll_inner_key = "ansor_always_unroll_inner"; static constexpr const char* always_unroll_key = "ansor_always_unroll"; static constexpr const char* no_split_at_inner_key = "ansor_no_split_at_inner"; static constexpr const char* no_split_at_outer_key = "ansor_no_split_at_outer"; - static constexpr const char* debug_skip_region_key = "ansor_debug_skip_region"; static constexpr const char* last_split_is_one_key = "ansor_last_split_is_one"; - - // Flag keys + // Flag keys to give hints to the policy static constexpr const char* always_compute_inline_key = "ansor_always_compute_inline"; static constexpr const char* no_cache_write_key = "ansor_no_cache_write"; static constexpr const char* no_cache_read_key = "ansor_no_cache_read"; - static constexpr const char* tensor_core_support_key = "ansor_tensor_core_support"; static constexpr const char *_type_key = "ansor.SearchPolicy"; TVM_DECLARE_BASE_OBJECT_INFO(SearchPolicyNode, Object); diff --git a/src/ansor/search_policy/utils.cc b/src/ansor/search_policy/utils.cc index e0fd00b23e7b..ba42ca55611c 100644 --- a/src/ansor/search_policy/utils.cc +++ b/src/ansor/search_policy/utils.cc @@ -62,7 +62,6 @@ void GetSpaceSplitStepIds(const State& s, int stage_id, std::vector* spatia } } -// Apply multi-tiling structure according to a string format State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format, std::vector* spatial_split_step_ids) { std::vector > space_levels; @@ -187,8 +186,6 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo return tmp_s; } -// Apply tiling structure: space, space -// But use tile sizes from other SplitStep State FollowTiling(const State& state, int stage_id, const std::vector& split_step_ids, int n_split) { if (n_split < 1 || n_split > 3) { @@ -280,7 +277,6 @@ State FollowTiling(const State& state, int stage_id, return tmp_s; } -// Randomly mutate the tile size of one SplitStep State RandomMutateTileSize(const State& old_state, SplitFactorizationMemo* split_memo, std::mt19937* random_gen, int max_innermost_split_factor) { State tmp_s = old_state; @@ -382,7 +378,6 @@ State RandomMutateTileSize(const State& old_state, SplitFactorizationMemo* split return State(); } -// Randomly mutate the value of one auto_unroll_max_step PragmaStep State RandomMutateMaxUnrollStep(const State& old_state, std::mt19937* random_gen, const std::vector& auto_unroll_configs) { State tmp_s = old_state; @@ -411,170 +406,6 @@ State RandomMutateMaxUnrollStep(const State& old_state, std::mt19937* random_gen return tmp_s; } -// Mutate a parallel loop. 
-State MutataParallel(const State& state, SplitFactorizationMemo* split_memo, - std::mt19937* random_gen, const SearchTask& task, int verbose) { - // To make this mutation simple but promising, we only focus on a specific case that - // parallel was added to the outermost loop and the loop is generated by fusing other loops. - // In short, we mutate the step pattern of (fuse -> parallel). - - // Extract all parallel steps. - std::vector parallel_steps; - for (size_t s = 0; s < state->transform_steps.size(); ++s) { - auto ps = state->transform_steps[s].as(); - if (!ps || ps->annotation != kParallel) { - continue; - } - parallel_steps.push_back(s); - } - if (parallel_steps.size() == 0) { - StdCout(verbose) << "Parallel mutation failed: No parallel annotations" << std::endl; - return State(); - } - - // Randomly pick one step. - int retry_ct = 0; - size_t step_id = 0; - size_t stage_id = 0; - do { - step_id = parallel_steps[(*random_gen)() % parallel_steps.size()]; - auto step = state->transform_steps[step_id].as(); - stage_id = step->stage_id; - - // Check assumptions. - auto iter_id = step->iter_id; - if (iter_id == 0 && step_id > 0 && state->transform_steps[step_id - 1].as()) { - break; - } - retry_ct++; - } while (retry_ct <= 3); - - if (retry_ct > 3) { - StdCout(verbose) << "Parallel mutation failed: No valid parallel annotations" << std::endl; - return State(); - } - - // 0: fuse less; 1: fuse more. - std::vector fuse_dir = {0.5, 1.0}; - - // The iter is an attached target so we can only fuse less. - if (state->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, 0)) > 0) { - fuse_dir[0] = 1.0; - } - - // Determine the fuse direction. - auto fuse_step = state->transform_steps[step_id - 1].as(); - std::vector fused_ids = fuse_step->fused_ids; - int iter_offset = 0; - if (RandomChoose(fuse_dir, random_gen) == 0) { - StdCout(verbose) << "Parallel mutation: release iter " << fused_ids.back() << std::endl; - fused_ids.pop_back(); - iter_offset = 1; - } else { - StdCout(verbose) << "Parallel mutation: include iter " << fused_ids.back() + 1 << std::endl; - fused_ids.push_back(fused_ids.back() + 1); - iter_offset = -1; - } - - // Replay a new state. - State tmp_s = task->compute_dag.GetInitState(); - for (size_t s = 0; s < state->transform_steps.size(); ++s) { - auto step = state->transform_steps[s]; - if (s == step_id - 1) { - step = FuseStepNode::make(step->stage_id, fused_ids); - } else if (s > step_id && step->stage_id == static_cast(stage_id)) { - // Since we change the loop structure, iter ID in later steps to the same stage - // has to be adjusted. - auto ps = step.as(); - if (ps) { - CHECK_LE(ps->iter_id + iter_offset, tmp_s->stages[stage_id]->iters.size()); - step = AnnotationStepNode::make(ps->stage_id, ps->iter_id + iter_offset, ps->annotation); - } else { - StdCout(verbose) << "Parallel mutation: Cannot apply " << step << " after fuse" - << std::endl; - return State(); - } - } - tmp_s.CopyOnWrite()->transform_steps.push_back(step); - tmp_s.DoStep(step, task->compute_dag); - } - return state; -} - -// Create all possible tile size states for all SplitStep -void GridMutateTileSize(const State& old_state, std::vector* cands, - SplitFactorizationMemo* split_memo, int max_innermost_split_factor) { - // Extract all SplitStep. 
- std::vector split_step_ids; - for (size_t i = 0; i < old_state->transform_steps.size(); ++i) { - if (old_state->transform_steps[i]->IsInstance()) { - split_step_ids.push_back(i); - } - } - if (split_step_ids.empty()) { - return; - } - - // Move tile sizes and generate candidates. - for (size_t step_id : split_step_ids) { - const SplitStepNode* ps = old_state->transform_steps[step_id].as(); - CHECK(ps != nullptr); - - int extent = GetIntImm(ps->extent); - if (extent == 1) { - continue; - } - - // Get the current tile sizes. - std::vector lengths(ps->lengths.size(), 1); - for (int i = 0; i < static_cast(ps->lengths.size()); ++i) { - lengths[i] = GetIntImm(ps->lengths[i]); - } - - const std::vector& const_factors = split_memo->GetFactors(extent); - CHECK_GE(const_factors.size(), 1); - - // Move tile size. - for (size_t i = 0; i < ps->lengths.size(); ++i) { - int old_length = lengths[i]; - - for (int factor : const_factors) { - if (i == ps->lengths.size() - 1 && factor > max_innermost_split_factor) { - // Limit the innermost factor. - break; - } - - // Make new length experssions and a new state. - std::vector length_exprs; - lengths[i] = factor; - int outermost = extent / ElementProduct(lengths); - if (outermost == 0) { - break; - } - - // std::cout << "Mutated extent " << extent << ": " << outermost; - for (size_t j = 0; j < lengths.size(); ++j) { - // std::cout << ", " << lengths[j]; - length_exprs.emplace_back(lengths[j]); - } - // std::cout << std::endl; - - State tmp_s = old_state; - const SplitStepNode* new_ps = tmp_s->transform_steps[step_id].as(); - auto pstate = tmp_s.CopyOnWrite(); - pstate->transform_steps[step_id] = - SplitStepNode::make(new_ps->stage_id, new_ps->iter_id, new_ps->extent, length_exprs, - new_ps->inner_to_outer); - if (tmp_s.defined()) { - cands->push_back(std::move(tmp_s)); - } - } - lengths[i] = old_length; - } - } -} - -// Prune undefined states. void PruneUndefined(std::vector* states) { size_t pt = 0; for (size_t i = 0; i < states->size(); ++i) { diff --git a/src/ansor/search_policy/utils.h b/src/ansor/search_policy/utils.h index 472e90771879..5f15397e7e90 100644 --- a/src/ansor/search_policy/utils.h +++ b/src/ansor/search_policy/utils.h @@ -464,14 +464,6 @@ State RandomMutateTileSize(const State& old_state, SplitFactorizationMemo* split State RandomMutateMaxUnrollStep(const State& old_state, std::mt19937* random_gen, const std::vector& auto_unroll_configs); -// Mutate a parallel loop. 
-State MutataParallel(const State& old_state, SplitFactorizationMemo* split_memo, - std::mt19937* random_gen, const SearchTask& task, int verbose = 0); - -// Create all possible tile size states for all SplitStep -void GridMutateTileSize(const State& old_state, std::vector* cands, - SplitFactorizationMemo* split_memo, int max_innermost_split_factor); - // GA: Crossover two states State CrossOverState(const State& p1, const State& p2); diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index ed5d4b868c27..454305c04ef5 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -28,6 +28,7 @@ #include #include #include +#include #include #include "serialization.h" #include "loop_state.h" diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 8bd5eca7c93d..a8cd1d3c2462 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -153,6 +153,11 @@ class RelayBuildModule : public runtime::ModuleNode { CHECK_EQ(args.num_args, 2); *rv = this->Optimize(args[0], args[1], this->params_); }); + } else if (name == "call_all_topi_funcs") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue *rv) { + CHECK_EQ(args.num_args, 3); + this->CallAllTopiFuncs(args[0], args[1], args[2]); + }); } else { LOG(FATAL) << "Unknown packed function: " << name; return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); @@ -227,6 +232,21 @@ class RelayBuildModule : public runtime::ModuleNode { BuildRelay(mod, params_); } + /*! \brief Call all used TOPI compute and schedule in a relay function */ + void CallAllTopiFuncs(IRModule mod, + const TargetsMap& targets, + const tvm::Target& target_host) { + targets_ = targets; + target_host_ = target_host; + + IRModule relay_module = Optimize(mod, targets_, params_); + auto func = Downcast(relay_module->Lookup("main")); + + graph_codegen_ = std::unique_ptr(new GraphCodegen()); + graph_codegen_->Init(nullptr, targets_); + graph_codegen_->Codegen(func); + } + protected: /*! * \brief Optimize a Relay IRModule. @@ -287,7 +307,6 @@ class RelayBuildModule : public runtime::ModuleNode { // Alter layout transformation is only applied to homogeneous execution yet. if (targets.size() == 1) { pass_seqs.push_back(transform::AlterOpLayout()); - //pass_seqs.push_back(transform::KernelLayoutTransform()); } // Fast math optimizations. 
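For context on the new "call_all_topi_funcs" packed function added above: it runs the Relay optimization pipeline and graph codegen without producing a compiled module, so every TOPI compute/schedule call in the network can be observed, which is what the Ansor workload extractor needs. A minimal sketch of how the Python side might drive it; the BuildModule accessor pattern, the single-device {0: target} map, and reusing `target` as `target_host` are illustrative assumptions, not the final API:

    from tvm import relay

    def trace_topi_calls(mod, target):
        # Hypothetical driver: fetch the packed function registered by
        # RelayBuildModule::GetFunction and invoke it. It optimizes the
        # module and runs graph codegen, triggering every TOPI compute
        # and schedule call, but skips compilation to machine code.
        bld_mod = relay.build_module.BuildModule()
        f = bld_mod.mod["call_all_topi_funcs"]
        # Signature mirrors CallAllTopiFuncs(mod, targets, target_host).
        f(mod, {0: target}, target)
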
diff --git a/src/relay/transforms/kernel_layout_transform.h b/src/relay/transforms/kernel_layout_transform.h index b4b806c20e28..c82a96b30612 100644 --- a/src/relay/transforms/kernel_layout_transform.h +++ b/src/relay/transforms/kernel_layout_transform.h @@ -20,7 +20,8 @@ class KernelLayoutVisitor : public ExprVisitor { !global_ori_layouts_queue.empty() && !global_new_layouts_queue.empty()) { ori_layouts_map[n] = global_ori_layouts_queue.front(); new_layouts_map[n] = global_new_layouts_queue.front(); - std::cout << "ori_layout " << global_ori_layouts_queue.front() << " Filter_shape " << n->args[1]->type_as()->shape << std::endl; + // std::cout << "ori_layout " << global_ori_layouts_queue.front() + // << " Filter_shape " << n->args[1]->type_as()->shape << std::endl; global_ori_layouts_queue.pop_front(); global_new_layouts_queue.pop_front(); } diff --git a/tests/python/unittest/test_ansor_relay_Integration.py b/tests/python/unittest/test_ansor_relay_integration.py similarity index 53% rename from tests/python/unittest/test_ansor_relay_Integration.py rename to tests/python/unittest/test_ansor_relay_integration.py index 9c423220844c..f3f424ab321b 100644 --- a/tests/python/unittest/test_ansor_relay_Integration.py +++ b/tests/python/unittest/test_ansor_relay_integration.py @@ -22,19 +22,18 @@ import tvm from tvm import ansor, relay import tvm.contrib.graph_runtime as runtime +from tvm.relay.testing import dqn -from test_ansor_common import get_tiled_matmul +def test_tune_dense_graph(): + def dense_graph(N, dtype="float32"): + ori_data = relay.var("data", shape=(N, N), dtype=dtype) + weight = relay.var("weight", shape=(N, N), dtype=dtype) + data = relay.multiply(ori_data, relay.const(2, dtype=dtype)) + dense = relay.nn.dense(data, weight, out_dtype=dtype) + dense = relay.add(dense, weight) + dense = relay.nn.dense(dense, weight, out_dtype=dtype) + return ori_data, weight, dense -def dense_graph(N, dtype="float32"): - ori_data = relay.var("data", shape=(N, N), dtype=dtype) - weight = relay.var("weight", shape=(N, N), dtype=dtype) - data = relay.multiply(ori_data, relay.const(2, dtype=dtype)) - dense = relay.nn.dense(data, weight, out_dtype=dtype) - dense = relay.add(dense, weight) - dense = relay.nn.dense(dense, weight, out_dtype=dtype) - return ori_data, weight, dense - -def test_dense_integration(): N = 128 data, weight, dense = dense_graph(N) mod = relay.Function([data, weight], dense) @@ -44,34 +43,23 @@ def test_dense_integration(): target = tvm.target.create("llvm") d = tvm.nd.array(np.random.uniform(size=(N, N)).astype(data.type_annotation.dtype), ctx) w = tvm.nd.array(np.random.uniform(size=(N, N)).astype(weight.type_annotation.dtype), ctx) - workloads, wkl_weights = ansor.extract_from_program(mod, {}, target=target) + wkl_keys, wkl_weights = ansor.extract_from_program(mod, {}, target=target) - assert len(workloads) == 2 + assert len(wkl_keys) == 2 assert len(wkl_weights) == 2 tasks = [] - for wkl_key in workloads: + for wkl_key in wkl_keys: dag = ansor.workload_key_to_dag(wkl_key) tasks.append(ansor.SearchTask(dag, wkl_key, target)) - assert str(tasks[0].compute_dag) == "placeholder = PLACEHOLDER [128, 128]\n" + \ - "placeholder = PLACEHOLDER [128, 128]\n" + \ - "compute(z, y, x) += (placeholder[z, ((k*16) + x)]*placeholder[y, ((k*16) + x)])\n" + \ - "compute(y, x) += compute[y, x, kk]\n" - - assert str(tasks[1].compute_dag) == "placeholder = PLACEHOLDER [128, 128]\n" + \ - "placeholder = PLACEHOLDER [128, 128]\n" + \ - "compute(z, y, x) += (placeholder[z, ((k*16) + x)]*placeholder[y, 
((k*16) + x)])\n" + \ - "compute(y, x) += compute[y, x, kk]\n" + \ - "T_add(ax0, ax1) = (compute[ax0, ax1] + placeholder[ax0, ax1])\n" - tuner = ansor.SimpleTaskScheduler(tasks) measure_ctx = ansor.LocalRPCMeasureContext() with tempfile.NamedTemporaryFile() as fp: - tuner.tune(ansor.TuneOption(n_trials=4, runner=measure_ctx.runner, + tuner.tune(ansor.TuneOption(n_trials=2, runner=measure_ctx.runner, measure_callbacks=[ansor.LogToFile(fp.name)])) with ansor.apply_history_best(fp.name): - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): graph, lib, opt_params = relay.build_module.build( mod, target=target) @@ -80,8 +68,8 @@ def test_dense_integration(): m.set_input('weight', w) m.run() res = m.get_output(0) - if measure_ctx: - del measure_ctx + + del measure_ctx d = d.asnumpy() d = d * 2 @@ -92,5 +80,36 @@ def test_dense_integration(): tvm.testing.assert_allclose(res.asnumpy(), d, rtol=1e-5) + +def test_tune_dqn(): + mod, params = dqn.get_workload(1, image_shape=(84, 84, 4), layout='NHWC') + target = tvm.target.create('llvm') + ctx = tvm.context("llvm") + + wkl_keys, wkl_weights = ansor.extract_from_program(mod, params, target) + + tasks = [] + for wkl_key in wkl_keys: + dag = ansor.workload_key_to_dag(wkl_key) + tasks.append(ansor.SearchTask(dag, wkl_key, target)) + + assert len(tasks) == 5 + + tuner = ansor.SimpleTaskScheduler(tasks) + measure_ctx = ansor.LocalRPCMeasureContext() + with tempfile.NamedTemporaryFile() as fp: + tuner.tune(ansor.TuneOption(n_trials=len(tasks), runner=measure_ctx.runner, + measure_callbacks=[ansor.LogToFile('tmp.json')]), + search_policy='meta-rewrite.random') + with ansor.apply_history_best('tmp.json'): + ansor.prepare_layout_rewrite(mod, params, target) + with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): + graph, lib, opt_params = relay.build_module.build(mod, target=target) + ansor.finish_layout_rewrite() + + del measure_ctx + if __name__ == "__main__": - test_dense_integration() + test_tune_dense_graph() + test_tune_dqn() + diff --git a/topi/python/topi/ansor.py b/topi/python/topi/ansor.py deleted file mode 100644 index e821fd5bd42f..000000000000 --- a/topi/python/topi/ansor.py +++ /dev/null @@ -1,95 +0,0 @@ -"""All AutoSchedule Supported Operators""" -from __future__ import absolute_import as _abs -from tvm import ansor - -@ansor.register_topi_schedule() -def schedule_dense_nopack(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_conv2d_nhwc(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_conv2d_NCHWc(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_reduce(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_pool(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_adaptive_pool(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_softmax(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_conv2d_nchw_int8(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_conv2d_nchw(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_depthwise_conv2d_nchw(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def 
schedule_depthwise_conv2d_nhwc(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_conv2d_NCHWc_int8(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_depthwise_conv2d_NCHWc(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_conv2d_transpose_nchw(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_conv3d_ncdhw(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_conv3d_ndhwc(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_conv1d_ncw(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_conv1d_nwc(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_dense_pack(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_batch_matmul(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_bitserial_conv2d_nchw(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_bitserial_conv2d_nhwc(cfg, outs): - return ansor.gen_schedule(cfg, outs) - -@ansor.register_topi_schedule() -def schedule_bitserial_dense(cfg, outs): - return ansor.gen_schedule(cfg, outs) diff --git a/topi/python/topi/arm_cpu/__init__.py b/topi/python/topi/arm_cpu/__init__.py index 0c0979763dba..e121fbc7ec6d 100644 --- a/topi/python/topi/arm_cpu/__init__.py +++ b/topi/python/topi/arm_cpu/__init__.py @@ -26,8 +26,3 @@ from .bitserial_dense import * from .injective import * from . import cortex_m7 - -import os -use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "false") -if use_auto_scheduler.lower() == "true": - from ..ansor import * diff --git a/topi/python/topi/generic/__init__.py b/topi/python/topi/generic/__init__.py index d44fca8548d2..6171317cd80f 100644 --- a/topi/python/topi/generic/__init__.py +++ b/topi/python/topi/generic/__init__.py @@ -39,8 +39,3 @@ from .sort import * from .search import * from .image import * - -import os -use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "false") -if use_auto_scheduler.lower() == "true": - from ..ansor import * diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index de02367a4dff..6800129c12aa 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -20,7 +20,7 @@ from __future__ import absolute_import as _abs from collections import namedtuple import tvm -from tvm import te +from tvm import te, ansor from .pad import pad from .util import get_pad_tuple @@ -342,23 +342,36 @@ def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'): dilation_h, dilation_w = dilation batch, in_height, in_width, in_channel = Input.shape - if len(Filter.shape) == 10: - kernel_h = Filter.shape[2] * Filter.shape[6] - kernel_w = Filter.shape[3] * Filter.shape[7] - channel = Filter.shape[4] * Filter.shape[8] - num_filter = Filter.shape[0] * Filter.shape[1] * Filter.shape[5] * Filter.shape[9] - elif len(Filter.shape) == 11: - kernel_h = Filter.shape[3] * Filter.shape[7] - kernel_w = Filter.shape[4] * Filter.shape[8] - channel = Filter.shape[5] * Filter.shape[9] - num_filter = Filter.shape[0] * Filter.shape[1] * Filter.shape[2] * Filter.shape[6] * Filter.shape[10] - elif len(Filter.shape) == 12: - kernel_h = 
Filter.shape[4] * Filter.shape[8] - kernel_w = Filter.shape[5] * Filter.shape[9] - channel = Filter.shape[6] * Filter.shape[10] - num_filter = Filter.shape[0] * Filter.shape[1] * Filter.shape[2] * Filter.shape[3] * Filter.shape[7] * Filter.shape[11] + if ansor.GLOBAL_SCOPE.topi_in_compute_rewrite_mode: + # infer shape for the rewritten layout + if len(Filter.shape) >= 10: + # For cpu tile structure SSRSRS + base = len(Filter.shape) - 10 + kernel_h = Filter.shape[2 + base] * Filter.shape[6 + base] + kernel_w = Filter.shape[3 + base] * Filter.shape[7 + base] + channel = Filter.shape[4 + base] * Filter.shape[8 + base] + num_filter = Filter.shape[5 + base] * Filter.shape[9 + base] + for i in range(base + 2): + num_filter *= Filter.shape[i] + elif len(Filter.shape) == 6: + # For cpu tile structure SRS + num_filter = Filter.shape[0] * Filter.shape[1] * Filter.shape[5] + kernel_h = Filter.shape[2] + kernel_w = Filter.shape[3] + channel = Filter.shape[4] + elif len(Filter.shape) == 5: + # For cpu tile structure SRS + num_filter = Filter.shape[0] * Filter.shape[4] + kernel_h = Filter.shape[1] + kernel_w = Filter.shape[2] + channel = Filter.shape[3] + elif len(Filter.shape) == 4: + num_filter, kernel_h, kernel_w, channel = Filter.shape + else: + raise ValueError("Don't know how to infer layout for filter shape: %s. " \ + "You can add a new branch for it to fix this." % str(Filter)) else: - kernel_h, kernel_w, channel, num_filter = Filter.shape + kernel_h, kernel_w, channel, num_filter = Filter.shape # compute the output shape dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 diff --git a/topi/python/topi/x86/__init__.py b/topi/python/topi/x86/__init__.py index 28e9e862f4d8..659668cbbe4c 100644 --- a/topi/python/topi/x86/__init__.py +++ b/topi/python/topi/x86/__init__.py @@ -39,8 +39,3 @@ from .conv3d_transpose import * from .sparse import * from .conv2d_alter_op import * - -import os -use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "false") -if use_auto_scheduler.lower() == "true": - from ..ansor import * diff --git a/tutorials/ansor/tune_conv2d_cuda.py b/tutorials/ansor/tune_conv2d_cuda.py index 14a6ee797276..437323d79791 100644 --- a/tutorials/ansor/tune_conv2d_cuda.py +++ b/tutorials/ansor/tune_conv2d_cuda.py @@ -124,7 +124,7 @@ def conv2d_nchw(N, H, W, CO, CI, KH, KW, stride, padding): # in the tuning logs. # :code:`ansor.LogToFile` callback will log the tuning results into a # log file, which can be used to get the best config later. -# :code:`ansor.PreLoadMeasuredStates` callback will load measured states +# :code:`ansor.PreloadMeasuredStates` callback will load measured states # from history log before schedule search, we can add this callback to make # sure a same schedule will never be measured for multiple times. 
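The two callbacks above are typically wired together so that a tuning run both records its measurements and skips states it has already measured. A minimal sketch, assuming the `ansor` Python API introduced in this patch series (the `task` object and the log file name here are hypothetical):

    from tvm import ansor

    log_file = 'conv2d.json'  # hypothetical log file name
    tune_option = ansor.TuneOption(
        n_trials=20,
        measure_callbacks=[ansor.LogToFile(log_file)],                 # record every measurement
        pre_search_callbacks=[ansor.PreloadMeasuredStates(log_file)])  # reload records before the search
    # Re-running with the same log file should then avoid re-measuring
    # schedules that already appear in the log.
    s, arg_bufs = ansor.auto_schedule(task, tune_option=tune_option)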
@@ -132,7 +132,7 @@ def conv2d_nchw(N, H, W, CO, CI, KH, KW, stride, padding): tune_option = ansor.TuneOption(n_trials=20, runner=measure_ctx.runner, measure_callbacks=[ansor.LogToFile(log_file)], - pre_search_callbacks=[ansor.PreLoadMeasuredStates(log_file)]) + pre_search_callbacks=[ansor.PreloadMeasuredStates(log_file)]) s, arg_bufs = ansor.auto_schedule(task, search_policy=search_policy, tune_option=tune_option) print("==== Get Lowered Stmt ====") diff --git a/tutorials/ansor/tune_simple_subgraph.py b/tutorials/ansor/tune_simple_subgraph.py index dfd36e89fd4c..08d5628ad8a2 100644 --- a/tutorials/ansor/tune_simple_subgraph.py +++ b/tutorials/ansor/tune_simple_subgraph.py @@ -148,7 +148,7 @@ def matmul_add(N, L, M, dtype): # you can do more trials according to your time budget. # :code:`ansor.LogToFile` callback will log the tuning results into a # log file, which can be used to get the best config later. -# :code:`ansor.PreLoadMeasuredStates` callback will load measured states +# :code:`ansor.PreloadMeasuredStates` callback will load measured states # from history log before schedule search, we can add this callback to make # sure a same schedule will never be measured for multiple times. @@ -161,7 +161,7 @@ def matmul_add(N, L, M, dtype): tune_option = ansor.TuneOption(n_trials=5, measure_callbacks=[ansor.LogToFile(log_file)], - pre_search_callbacks=[ansor.PreLoadMeasuredStates(log_file)]) + pre_search_callbacks=[ansor.PreloadMeasuredStates(log_file)]) ################################################################ # Then just call :code:`ansor.auto_schedule` and Ansor will try to find a high From 0794875b61cea652fede1599b49dd64c81807ce5 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 20 Jun 2020 06:34:49 -0700 Subject: [PATCH 32/45] Fix xgb error & Simplify dispatcher (#35) --- python/tvm/ansor/__init__.py | 2 +- python/tvm/ansor/auto_schedule.py | 1 - python/tvm/ansor/compute_dag.py | 19 +- python/tvm/ansor/cost_model/cost_model.py | 5 +- python/tvm/ansor/cost_model/xgb_model.py | 12 +- python/tvm/ansor/dispatcher.py | 233 ++------------------ python/tvm/ansor/env.py | 18 ++ python/tvm/ansor/feature.py | 1 - python/tvm/ansor/measure.py | 8 +- python/tvm/ansor/serialization.py | 1 + python/tvm/ansor/task_scheduler.py | 5 +- python/tvm/ansor/workload_registry.py | 1 - src/ansor/serialization.cc | 3 +- tests/python/unittest/test_ansor_feature.py | 1 - 14 files changed, 70 insertions(+), 240 deletions(-) diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index 977e100e63c6..90a11820d159 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -40,7 +40,7 @@ workload_key_to_dag, make_workload_key_func from .task_scheduler import TaskScheduler, SimpleTaskScheduler from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest as apply_history_best, \ - FallbackContext, clear_fallback_cache, ApplyGraphBest + FallbackContext from .relay_integration import extract_from_program, extract_from_multiple_program, \ finish_layout_rewrite, prepare_layout_rewrite, auto_schedule_topi from .env import GLOBAL_SCOPE diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index acf8982d6e89..e8108a067b2e 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -97,7 +97,6 @@ class MetaTileRewritePolicy(SearchPolicy): seed: int Random seed """ - def __init__(self, program_cost_model, params=None, diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index 
f35c9d8221f3..6304c7bb0e0a 100644
--- a/python/tvm/ansor/compute_dag.py
+++ b/python/tvm/ansor/compute_dag.py
@@ -53,6 +53,8 @@ def get_init_state(self):
 
     def apply_steps_from_state(self, state, layout_rewrite_level=LayoutRewriteLevel.NO_REWRITE):
         """
+        Apply transform steps according to the history of a state
+
         Parameters
         ----------
         state : StateObject
@@ -68,6 +70,8 @@ def apply_steps_from_state(self, state, layout_rewrite_level=LayoutRewriteLevel.
 
     def print_python_code_from_state(self, state):
         """
+        Print transform steps in the history of a state as TVM's python schedule primitives
+
         Parameters
         ----------
         state : StateObject
@@ -81,16 +85,29 @@ def print_python_code_from_state(self, state):
 
     def infer_bound_from_state(self, state):
         """
+        Infer bound for a state
+
         Parameters
         ----------
         state : StateObject
 
         Returns
         -------
-        state : StateObject
+        state : State
         """
         state_obj = state if isinstance(state, StateObject) else state.state_object
         return State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self)
 
     def rewrite_layout_from_state(self, state: State):
+        """
+        Rewrite the layout according to the transform steps in the history of a state
+
+        Parameters
+        ----------
+        state : StateObject
+
+        Returns
+        -------
+        state : StateObject
+        """
         return _ffi_api.ComputeDAGRewriteLayoutFromState(self, state)
diff --git a/python/tvm/ansor/cost_model/cost_model.py b/python/tvm/ansor/cost_model/cost_model.py
index 47ea5092b302..57cc53853b2e 100644
--- a/python/tvm/ansor/cost_model/cost_model.py
+++ b/python/tvm/ansor/cost_model/cost_model.py
@@ -26,18 +26,20 @@
 
 @tvm._ffi.register_object("ansor.CostModel")
 class CostModel(Object):
+    """The base class for cost models"""
     pass
 
 
 @tvm._ffi.register_object("ansor.RandomModel")
 class RandomModel(Object):
+    """A model that returns random estimations for all inputs"""
     def __init__(self):
         self.__init_handle_by_constructor__(_ffi_api.RandomModel)
 
 
-# A random number generator func for c++'s RandomModel
 @tvm._ffi.register_func("ansor.cost_model.random_number")
 def random_number(n, return_ptr):
+    """ A random number generator func for c++'s RandomModel """
     if n == 0:
         return
     return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_float))
@@ -47,6 +49,7 @@ def random_number(n, return_ptr):
 
 @tvm._ffi.register_object("ansor.PythonBasedModel")
 class PythonBasedModel(CostModel):
+    """Base class for cost models implemented in python"""
     def __init__(self):
         def update_func(inputs, results):
             self.update(inputs, results)
diff --git a/python/tvm/ansor/cost_model/xgb_model.py b/python/tvm/ansor/cost_model/xgb_model.py
index fce3f16d18ba..42af17daae2c 100644
--- a/python/tvm/ansor/cost_model/xgb_model.py
+++ b/python/tvm/ansor/cost_model/xgb_model.py
@@ -16,16 +16,14 @@
"""Cost model based on xgboost""" -from typing import List import multiprocessing import logging -import time from collections import defaultdict import numpy as np import xgboost as xgb -from ...autotvm.tuner.xgboost_cost_model import get_rank, recall_curve, max_curve +from tvm.autotvm.tuner.xgboost_cost_model import get_rank, recall_curve, max_curve from .cost_model import PythonBasedModel from ..feature import get_per_stmt_features_from_measure_pairs, get_per_stmt_features_from_states from ..serialization import LogReader @@ -65,8 +63,8 @@ def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None): # todo(lmzheng): automatically decrease learning rate when the loss is too large 'n_gpus': 0, - 'n_threads': multiprocessing.cpu_count() / 2, - 'silent': 0, + 'nthread': multiprocessing.cpu_count() // 2, + 'verbosity': 0, 'seed': seed or 43, 'disable_default_eval_metric': 1 } @@ -180,7 +178,7 @@ def pack_sum_xgbmatrix_for_prediction(xs): x_flatten.append(row) pack_ids.append(ct) - return xgb.DMatrix(x_flatten), pack_ids + return xgb.DMatrix(np.array(x_flatten)), pack_ids def pack_sum_xgbmatrix(xs, ys, gids=None, weights=None): @@ -214,7 +212,7 @@ def pack_sum_xgbmatrix(xs, ys, gids=None, weights=None): y_flatten.append(y) pack_ids.append(ct) - ret = xgb.DMatrix(x_flatten, y_flatten) + ret = xgb.DMatrix(np.array(x_flatten), y_flatten) if weights is not None: ret.set_weight(weights_flatten) dmatrix_context.put('pack_ids', ret, np.array(pack_ids)) diff --git a/python/tvm/ansor/dispatcher.py b/python/tvm/ansor/dispatcher.py index 0ef07197ea92..0c07fd141bd2 100644 --- a/python/tvm/ansor/dispatcher.py +++ b/python/tvm/ansor/dispatcher.py @@ -15,16 +15,7 @@ # specific language governing permissions and limitations # under the License. """ -Template dispatcher module. - -A dispatcher is a function that can contains multiple behaviors. -Its specific behavior is can be controlled by DispatchContext. - -DispatchContext is used in two ways, usually via different implementation -of the DispatchContext base class. - -- During search, we can use it to pass the current proposal from tuner. -- During evaluation, we can use it to set pick the best policy. +The global context that dispatches best configurations to workloads """ # pylint: disable=invalid-name @@ -33,9 +24,7 @@ import logging import numpy as np -from decorator import decorate -from tvm import target as _target from tvm.tir.expr import FloatImm logger = logging.getLogger('auto_scheduler') @@ -44,9 +33,6 @@ class DispatchContext(object): """ Base class of dispatch context. - - DispatchContext enables the target and workload - specific dispatch mechanism for templates. """ current = None @@ -55,7 +41,7 @@ def __init__(self): def query(self, target, workload): """ - Query the context to get the specific config for a template. + Query the context to get the specific config for a workload. If cannot find the result inside this context, this function will query it from the upper contexts. @@ -63,22 +49,20 @@ def query(self, target, workload): ---------- target: Target The current target - workload : Workload - The current workload. + workload : str + The current workload Returns ------- - cfg : State or str - The specific state for auto scheduler. + cfg : State + The schedule configuration for the workload """ ret = self._query_inside(target, workload) - #if ret is None: - # ret = self._old_ctx.query(target, workload) return ret def update(self, target, workload, cfg): """ - Update context with a specific config. 
+ Update the config for a workload Parameters ---------- @@ -86,46 +70,14 @@ def update(self, target, workload, cfg): The current target workload : Workload The current workload. - cfg : State or str - The specific state for auto scheduler. - - Note - ---- - This interface is for cases when TVM decides to replace an operator in the graph. - For example, `AlterOpLayout` pass (enables when `opt_level = 3`) replaces `NCHW` - convolution with `NCHW[x]c` implementation on x86 CPUs. - Thus in TOPI, we first query schedule using original `NCHW` workload, - then update the dispatcher with the new `NCHW[x]c` workload. - So that later on, `NCHW[x]c` convolution can get schedule from the dispatcher using - its own workload directly. - - .. code-block:: python - - @conv2d_alter_layout.register("cpu") - def _alter_conv2d_layout(attrs, inputs, tinfo): - workload = get_conv2d_workload(...) - dispatch_ctx = auto_scheduler.DispatchContext.current - target = tvm.target.current_target() - config = dispatch_ctx.query(target, workload) - - # Get conv2d_NCHWc workload from config - # new_workload = ... - # new_inputs = ... - # new_attrs = ... - - # Store altered operator's config - dispatch_ctx.update(target, new_workload, config) - return sym.contrib.conv2d_NCHWc(*new_inputs, **new_attrs) - - We directly store `config` back because `conv2d_NCHW` and `conv2d_NCHWc` - share the same schedule parameters. - One can construct a new `State` if this is not the case. + cfg : State + The schedule configuration for the workload """ raise NotImplementedError() def _query_inside(self, target, workload): """ - Query the context to get the specific config for a template. + Query the context to get the specific config for a workload. This function only query config inside this context. Parameters @@ -138,7 +90,7 @@ def _query_inside(self, target, workload): Returns ------- cfg : State or str - The specific state for auto scheduler. + The schedule configuration for the workload """ raise NotImplementedError() @@ -151,78 +103,13 @@ def __exit__(self, ptype, value, trace): DispatchContext.current = self._old_ctx -def dispatcher(fworkload): - """Wrap a workload dispatcher function. - - Parameters - ---------- - fworkload : function - The workload extraction function from arguments. - - Returns - ------- - fdispatcher : function - A wrapped dispatcher function, which will - dispatch based on DispatchContext and - the current workload. - """ - dispatch_dict = {} - func_name = fworkload.__name__ - - def register(key, func=None, override=False): - """Register template function. - - Parameters - ---------- - key : str or List of str - The template key to identify the template - under this dispatcher. - func : function - The function to be registered. - The first argument of the function is always - cfg returned by DispatchContext, - the rest arguments are the same as the fworkload. - override : bool - Whether override existing registration. - - Returns - ------- - The register function if necessary. 
- """ - if isinstance(key, str): - key = [key] - - def _do_reg(myf): - for x in key: - if x in dispatch_dict and not override: - raise ValueError( - "Key %s is already registered for %s" % (x, func_name)) - dispatch_dict[x] = myf - return myf - - if func: - return _do_reg(func) - return _do_reg - - def dispatch_func(func, *args, **kwargs): - """The wrapped dispatch function""" - tgt = _target.current_target() - workload = func(*args, **kwargs) - cfg = DispatchContext.current.query(tgt, workload) - return dispatch_dict['direct'](cfg, *args, **kwargs) - - fdecorate = decorate(fworkload, dispatch_func) - fdecorate.register = register - return fdecorate - - class ApplyConfig(DispatchContext): - """Apply a deterministic config entity for all queries. + """Apply a deterministic config for all queries. Parameters ---------- config : State - The specific state for auto scheduler. + The schedule configuration """ def __init__(self, config): super(ApplyConfig, self).__init__() @@ -361,9 +248,7 @@ def update(self, target, workload, state): class FallbackContext(DispatchContext): """ A fallback dispatch context. - - Any tunable template can be called under this context. - This is the root context. + This is used as the root context. """ def __init__(self): @@ -387,7 +272,7 @@ def _query_inside(self, target, workload): logger.warning(msg) cfg = None - # cache this config + # cache this config to avoid duplicated warning message self.memory[key] = cfg return cfg @@ -412,91 +297,3 @@ def update(self, target, workload, cfg): DispatchContext.current = FallbackContext() - - -def clear_fallback_cache(target, workload): - """Clear fallback cache. Pass the same argument as _query_inside to this function - to clean the cache. - - Parameters - ---------- - target: Target - The current target - workload : Workload - The current workload. - - Note - ---- - This is used in alter_op_layout to clear the bad cache created before call topi compute function - """ - context = DispatchContext.current - while not isinstance(context, FallbackContext): - context = context._old_ctx - context.clear_cache(target, workload) - - -class ApplyGraphBest(DispatchContext): - """Load the graph level tuning optimal schedules. - - The input records should be in the ascending order of - node index for target operator. Usually this can be obtained - with graph tuner. - - This context maintains an internal counter to indicate the current - node index. - """ - def __init__(self, records): - """ - Parameters - ---------- - records : str or iterator of (MeasureInput, MeasureResult) - Collection of tuning records. - If is str, then it should be the filename of a records log file. - Each row of this file is an encoded record pair. - Otherwise, it is an iterator. - """ - from . import load_from_file - - super(ApplyGraphBest, self).__init__() - if isinstance(records, str): - records = load_from_file(records) - self._records = list(records) - self._counter = 0 - self._global_cfg_dict = {} - - def _query_inside(self, target, workload): - """ - Query the context to get config from records. - - Parameters - ---------- - target : Target - The current target - workload : Workload - The current workload. - - Returns - ------- - cfg : State or str - The specific state for auto scheduler. 
- """ - if self._counter < len(self._records): - cfg = self._records[self._counter][0].config - self._counter += 1 - self.update(target, workload, cfg) - return cfg - key = (str(target), workload) - if key not in self._global_cfg_dict: - msg = "Config for target=%s, workload=%s is missing in ApplyGraphBest context. " \ - "A fallback configuration is used, which may bring great performance " \ - "regression." % (target, workload) - logger.warning(msg) - cfg = None - self._global_cfg_dict[key] = cfg - else: - cfg = self._global_cfg_dict[key] - return cfg - - def update(self, target, workload, cfg): - key = (str(target), workload) - self._global_cfg_dict[key] = cfg diff --git a/python/tvm/ansor/env.py b/python/tvm/ansor/env.py index 9e44ad66048b..0f35f92acbbc 100644 --- a/python/tvm/ansor/env.py +++ b/python/tvm/ansor/env.py @@ -1,5 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + """ The scope to store global variables in ansor """ + class AutoschedulerGlobalScope(object): def __init__(self): self.topi_in_compute_rewrite_mode = False diff --git a/python/tvm/ansor/feature.py b/python/tvm/ansor/feature.py index 9496533da6cc..d9f6d297f1af 100644 --- a/python/tvm/ansor/feature.py +++ b/python/tvm/ansor/feature.py @@ -17,7 +17,6 @@ """" Python API for Feature extraction. -The specification of features can be found in `autoscheduler_doc/per_stage_feature.md` """ from typing import List, Tuple diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index 3d9c33860cae..f00fe672505d 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -230,7 +230,8 @@ def __init__(self, key, host, port, priority=1, class LocalRPCMeasureContext: - """ A context wrapper for RPCRunner. + """ A context wrapper for running RPCRunner locally. + This will launch a local RPC Tracker and local RPC Server. Parameters ---------- @@ -276,10 +277,10 @@ class MeasureErrorNo(object): """Error type for MeasureResult""" NO_ERROR = 0 # No error INSTANTIATION_ERROR = 1 # Errors happen when apply transform steps from init state - # Errors happen when compiling code on host (e.g. tvm.build) + # Errors happen when compiling code on host (e.g. tvm.build) COMPILE_HOST = 2 COMPILE_DEVICE = 3 # Errors happen when compiling code on device - # (e.g. OpenCL JIT on the device) + # (e.g. 
OpenCL JIT on the device)
     RUNTIME_DEVICE = 4   # Errors happen when run program on device
     WRONG_ANSWER = 5     # Answer is wrong when compared to a reference output
     BUILD_TIMEOUT = 6    # Timeout during compilation
@@ -288,6 +289,7 @@ class MeasureErrorNo(object):
 
 
 def make_error_msg():
+    """Get the error message from traceback"""
     error_msg = str(traceback.format_exc())
     if len(error_msg) > MAX_ERROR_MSG_LEN:
         error_msg = error_msg[:MAX_ERROR_MSG_LEN//2] + \
diff --git a/python/tvm/ansor/serialization.py b/python/tvm/ansor/serialization.py
index 97903b38bb0b..1bd9d8cf64e6 100644
--- a/python/tvm/ansor/serialization.py
+++ b/python/tvm/ansor/serialization.py
@@ -64,6 +64,7 @@ def __iter__(self):
                 break
             yield ret[0], ret[1]  # (input, result)
 
+
 def load_from_file(filename: str):
     """Load measurement records from a file"""
     return zip(*LogReader(filename).read_lines())
diff --git a/python/tvm/ansor/task_scheduler.py b/python/tvm/ansor/task_scheduler.py
index 89b4afd84e86..3d4d9624d7c2 100644
--- a/python/tvm/ansor/task_scheduler.py
+++ b/python/tvm/ansor/task_scheduler.py
@@ -147,13 +147,12 @@ def __init__(self,
     def tune(self, tune_option: TuneOption,
              search_policy: Union[str, List[SearchPolicy]] = 'default'):
         """ Tune tasks.
-        Notice: This method does not have return value, make sure to set `LogToFile`
-        measure callback in `tune_option`.
+        Notice: This method does not have a return value; make sure to set the `LogToFile`
+        measure callback in `tune_option`.
 
         Parameters
         ----------
         tune_option: TuneOption
-
         search_policy: Str or List[SearchPolicy]
         """
         # init members
diff --git a/python/tvm/ansor/workload_registry.py b/python/tvm/ansor/workload_registry.py
index fccdcf8864be..bcf8269b9490 100644
--- a/python/tvm/ansor/workload_registry.py
+++ b/python/tvm/ansor/workload_registry.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-
 """
 Workload registration and serialization.
 
diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc
index 454305c04ef5..2d8379f56a5f 100644
--- a/src/ansor/serialization.cc
+++ b/src/ansor/serialization.cc
@@ -55,7 +55,6 @@ template <> struct Handler > {
   inline static void Write(dmlc::JSONWriter* writer,
                            const std::vector<::tvm::ansor::Stage> & data) {
-    // todo(lmzheng): support serialization of Stage
     writer->BeginArray(false);
     writer->EndArray();
   }
@@ -456,7 +455,7 @@ namespace ansor {
 TVM_REGISTER_OBJECT_TYPE(LogToFileNode);
 TVM_REGISTER_OBJECT_TYPE(LogReaderNode);
 
-const std::string ANSOR_LOG_VERSION = "v0.1";  // NOLINT(*)
+const std::string ANSOR_LOG_VERSION = "v0.2";  // NOLINT(*)
 
 MeasureCallback LogToFileNode::make(std::string filename) {
   auto node = make_object();
diff --git a/tests/python/unittest/test_ansor_feature.py b/tests/python/unittest/test_ansor_feature.py
index bb19b84a970d..bcc7683b3f4a 100644
--- a/tests/python/unittest/test_ansor_feature.py
+++ b/tests/python/unittest/test_ansor_feature.py
@@ -148,4 +148,3 @@ def test_gpu_feature():
     test_cpu_matmul()
     test_cpu_fusion()
     test_gpu_feature()
-

From a4c4548f2d1da651c8f13f8552e9cc9df2f167eb Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sat, 20 Jun 2020 08:58:41 -0700
Subject: [PATCH 33/45] Rename "MetaTileRewritePolicy" to "SketchPolicy". (#36)

* Rename "MetaTileRewritePolicy" to "SketchPolicy".
* Add a new class for auto_unroll_max_step, storage_offset in StageNode

* fix tune_op_subgraph.py
---
 python/tvm/ansor/__init__.py                  |   6 +-
 python/tvm/ansor/auto_schedule.py             |  28 ++--
 python/tvm/ansor/relay_integration.py         |   7 +-
 python/tvm/ansor/task_scheduler.py            |  18 +--
 python/tvm/ansor/workload_registry.py         |  14 +-
 scripts/common.py                             |  38 ++---
 scripts/shape_configs.py                      |  24 +--
 scripts/tune_network.py                       | 137 ++++++++---------
 scripts/tune_op_subgraph.py                   | 144 ++++++++----------
 scripts/tune_test.py                          |  97 ++++++------
 src/ansor/auto_schedule.cc                    |   2 +-
 src/ansor/compute_dag.cc                      |   3 +-
 src/ansor/loop_state.cc                       |  37 +++--
 src/ansor/loop_state.h                        |  15 +-
 src/ansor/search_policy/search_policy.h       |   1 +
 ...rite_policy.cc => sketch_search_policy.cc} | 132 ++++++++--------
 ...ewrite_policy.h => sketch_search_policy.h} |  53 ++++---
 tests/python/unittest/test_ansor_common.py    |   2 +-
 .../unittest/test_ansor_relay_integration.py  |   3 +-
 .../unittest/test_ansor_search_policy.py      |  15 +-
 tutorials/ansor/tune_conv2d_cuda.py           |   4 +-
 tutorials/ansor/tune_simple_subgraph.py       |   4 +-
 22 files changed, 386 insertions(+), 398 deletions(-)
 rename src/ansor/search_policy/{meta_tile_rewrite_policy.cc => sketch_search_policy.cc} (91%)
 rename src/ansor/search_policy/{meta_tile_rewrite_policy.h => sketch_search_policy.h} (66%)

diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py
index 90a11820d159..c629c1049a87 100644
--- a/python/tvm/ansor/__init__.py
+++ b/python/tvm/ansor/__init__.py
@@ -29,14 +29,14 @@
 
 # Shortcut
 from .compute_dag import ComputeDAG, LayoutRewriteLevel
-from .auto_schedule import SearchTask, MetaTileRewritePolicy, TuneOption, HardwareParams, \
-    PreloadMeasuredStates, PreAddCustomRule, auto_schedule
+from .auto_schedule import SearchTask, SketchSearchPolicy, TuneOption, HardwareParams, \
+    PreloadMeasuredStates, PreloadCustomSketchRule, auto_schedule
 from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, LocalRPCMeasureContext
 from .cost_model import RandomModel
 from .cost_model.xgb_model import XGBModel
 from .serialization import LogToFile, LogReader, best_measure_pair_in_file, \
     load_from_file, write_measure_records_to_file
-from .workload_registry import register_auto_scheduler_workload_func, \
+from .workload_registry import register_workload_func, \
     workload_key_to_dag, make_workload_key_func
 from .task_scheduler import TaskScheduler, SimpleTaskScheduler
 from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest as apply_history_best, \
diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py
index e8108a067b2e..a03d9fdacbc2 100644
--- a/python/tvm/ansor/auto_schedule.py
+++ b/python/tvm/ansor/auto_schedule.py
@@ -83,17 +83,19 @@ def run_callbacks(self, callbacks):
         _ffi_api.SearchPolicyRunCallbacks(self, callbacks)
 
 
-@tvm._ffi.register_object("ansor.MetaTileRewritePolicy")
-class MetaTileRewritePolicy(SearchPolicy):
-    """ The search policy that searches with meta tiling and random rewrite
+@tvm._ffi.register_object("ansor.SketchSearchPolicy")
+class SketchSearchPolicy(SearchPolicy):
+    """ The search policy that searches in a hierarchical search space defined by sketches.
+    The policy randomly samples programs from the space defined by sketches
+    and uses evolutionary search to fine-tune them.
 
     Parameters
     ----------
     program_cost_model: CostModel
         Cost model for programs
     params: int
-        Parameters of the search policy, go meta_tile_rewrite_policy.h to find the
-        definitions. See code below to find the default values
+        Parameters of the search policy. See `src/ansor/search_policy/sketch_search_policy.h`
+        to find the definitions. 
See code below to find the default values
     seed: int
         Random seed
     """
@@ -124,7 +126,7 @@ def __init__(self,
             params[key] = value
 
         self.__init_handle_by_constructor__(
-            _ffi_api.MetaTileRewritePolicy, program_cost_model, params,
+            _ffi_api.SketchSearchPolicy, program_cost_model, params,
             seed or random.randint(1, 1 << 30))
 
 
@@ -148,16 +150,16 @@ def __init__(self, filename: str):
             _ffi_api.PreloadMeasuredStates, filename)
 
 
-@tvm._ffi.register_object("ansor.PreAddCustomRule")
-class PreAddCustomRule(SearchCallback):
+@tvm._ffi.register_object("ansor.PreloadCustomSketchRule")
+class PreloadCustomSketchRule(SearchCallback):
     """
-    A SearchCallback for MetaTileRewritePolicy that allowing users to add
+    A SearchCallback for SketchSearchPolicy that allows users to add
     custom sketch rule.
 
     Notes
     -----
     This is an advanced feature. Make sure you're clear how it
-    works and this should only be used in MetaTileRewritePolicy.
+    works; this should only be used in SketchSearchPolicy.
 
     Parameters
     ----------
@@ -168,7 +170,7 @@ class PreAddCustomRule(SearchCallback):
     """
     def __init__(self, meet_condition_func, apply_func):
         self.__init_handle_by_constructor__(
-            _ffi_api.PreAddCustomRule, meet_condition_func, apply_func)
+            _ffi_api.PreloadCustomSketchRule, meet_condition_func, apply_func)
 
 
 @tvm._ffi.register_object("ansor.TuneOption")
@@ -197,7 +199,7 @@ class TuneOption(Object):
         Callback functions called before the search process
         Candidates:
           - ansor.PreloadMeasuredStates
-          - ansor.PreAddCustomRule
+          - ansor.PreloadCustomSketchRule
     """
     def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_iter=64,
                  verbose=1, builder='local', runner='local', measure_callbacks=None,
@@ -249,7 +251,7 @@ def auto_schedule(workload, target=None,
     """
     if isinstance(search_policy, str):
         if search_policy == 'default':
-            search_policy = MetaTileRewritePolicy(RandomModel())
+            search_policy = SketchSearchPolicy(RandomModel())
         else:
             raise ValueError("Invalid search policy: " + search_policy)
 
diff --git a/python/tvm/ansor/relay_integration.py b/python/tvm/ansor/relay_integration.py
index 85c4d8813f69..3c2eabd3dfac 100644
--- a/python/tvm/ansor/relay_integration.py
+++ b/python/tvm/ansor/relay_integration.py
@@ -28,7 +28,7 @@
 from tvm import target, te, transform
 from tvm.te.tensor import PlaceholderOp, ComputeOp
 from .dispatcher import DispatchContext
-from .workload_registry import register_auto_scheduler_workload_bufs, compute_dag_hash
+from .workload_registry import register_workload_bufs, compute_dag_hash
 from .compute_dag import ComputeDAG, LayoutRewriteLevel
 from .env import GLOBAL_SCOPE
 
@@ -203,11 +203,14 @@ def traverse(t):
 def auto_schedule_topi(outs):
     """ Use ansor to auto-schedule a topi compute declaration """
     io_tensors, has_layout_free = traverse_to_get_io_tensors(outs)
-    key = register_auto_scheduler_workload_bufs(io_tensors)
+    key = register_workload_bufs(io_tensors)
 
     env = TracingEnvironment.current
     if env is None:  # in the final build mode
         state = DispatchContext.current.query(target.Target.current(), key)
+        if state is None:
+            return te.create_schedule([x.op for x in outs])
+
         dag = ComputeDAG(io_tensors)
         # Only update compute body, layout_rewrite_level = LayoutRewriteLevel.COMPUTE_REWRITE,
         # Since kernel layout has already been rewritten in relay pass
diff --git a/python/tvm/ansor/task_scheduler.py b/python/tvm/ansor/task_scheduler.py
index 
3d4d9624d7c2..587fe3121e88 100644 --- a/python/tvm/ansor/task_scheduler.py +++ b/python/tvm/ansor/task_scheduler.py @@ -21,7 +21,7 @@ import numpy as np -from .auto_schedule import SearchTask, SearchPolicy, MetaTileRewritePolicy, TuneOption +from .auto_schedule import SearchTask, SearchPolicy, SketchSearchPolicy, TuneOption from .cost_model import RandomModel, XGBModel from .measure import ProgramMeasurer from .utils import array_mean, to_str_round @@ -42,7 +42,7 @@ def compute_score(self, costs: List[float]) -> float: def get_search_policies(search_policy: Union[str, List[SearchPolicy]], tasks: List[SearchTask], num_measure_per_iter, load_model_file=None, load_log_file=None): if search_policy == 'default': - search_policy = 'meta-rewrite.xgb' + search_policy = 'sketch.xgb' if isinstance(search_policy, str): policy_type, model_type = search_policy.split('.') @@ -58,16 +58,16 @@ def get_search_policies(search_policy: Union[str, List[SearchPolicy]], tasks: Li else: raise ValueError("Invalid search policy: " + search_policy) - if policy_type == 'meta-rewrite': - search_policies = [MetaTileRewritePolicy(cost_model) for _ in range(len(tasks))] + if policy_type == 'sketch': + search_policies = [SketchSearchPolicy(cost_model) for _ in range(len(tasks))] elif policy_type == 'limit-space': - search_policies = [MetaTileRewritePolicy(cost_model, - params={'cpu_multi_level_tiling_structure': 'SRS', - 'disable_change_compute_location': 1}) + search_policies = [SketchSearchPolicy(cost_model, + params={'cpu_multi_level_tiling_structure': 'SRS', + 'disable_change_compute_location': 1}) for _ in range(len(tasks))] elif policy_type == 'beam-search': - search_policies = [MetaTileRewritePolicy(cost_model, - params={'use_beam_search': 1}) + search_policies = [SketchSearchPolicy(cost_model, + params={'use_beam_search': 1}) for _ in range(len(tasks))] else: raise ValueError("Invalid search policy: " + search_policy) diff --git a/python/tvm/ansor/workload_registry.py b/python/tvm/ansor/workload_registry.py index bcf8269b9490..e706c0ec4cf9 100644 --- a/python/tvm/ansor/workload_registry.py +++ b/python/tvm/ansor/workload_registry.py @@ -42,19 +42,19 @@ WORKLOAD_FUNC_REGISTRY = {} -def register_auto_scheduler_workload_func(func: Callable): +def register_workload_func(func: Callable): """Register a workload generation function The input function should take hashable and jsonable arguments (int, float, tuple of int, tvm.tensor.Tensor, ...) and return a list of tvm.tensor.Tensor. 
Examples
     --------
-    @register_auto_scheduler_workload_func
+    @register_workload_func
     def matmul(N, M, K):
-        A = tvm.placeholder((N, K), name='A')
-        B = tvm.placeholder((K, M), name='B')
-        k = tvm.reduce_axis((0, K), name='k')
-        C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i][k] * B[k][j], axis=[k]), name='C')
+        A = te.placeholder((N, K), name='A')
+        B = te.placeholder((K, M), name='B')
+        k = te.reduce_axis((0, K), name='k')
+        C = te.compute((N, M), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
         return [A, B, C]
     """
     func_name = func.__name__
@@ -84,7 +84,7 @@ def compute_dag_hash(dag: ComputeDAG):
     return hashlib.md5(str_key).hexdigest()
 
 
-def register_auto_scheduler_workload_bufs(bufs: List[Tensor]) -> str:
+def register_workload_bufs(bufs: List[Tensor]) -> str:
     """Directly register buffers of a workload and return the workload_key
     The buffers can be looked up with workload_key_to_tensors by the workload_key
     """
diff --git a/scripts/common.py b/scripts/common.py
index 84fbf8d6c731..8f4fbec09dd0 100644
--- a/scripts/common.py
+++ b/scripts/common.py
@@ -14,7 +14,7 @@
 import tvm
 from tvm import te
 from tvm.ansor import (LogReader, make_workload_key_func,
-                       register_auto_scheduler_workload_func,
+                       register_workload_func,
                        write_measure_records_to_file)
 from tvm.contrib import ndk, util
 
@@ -22,28 +22,28 @@
 ######################  Test Workloads  ####################
 ############################################################
 
-@register_auto_scheduler_workload_func
+@register_workload_func
 def min_mn(M, N):
     A = te.placeholder((M, N), name='A')
     B = topi.min(A, axis=1)
 
     return [A, B]
 
-@register_auto_scheduler_workload_func
+@register_workload_func
 def argmin_mn(M, N):
     A = te.placeholder((M, N), name='A')
     B = topi.argmin(A, axis=1)
 
     return [A, B]
 
-@register_auto_scheduler_workload_func
+@register_workload_func
 def softmax_mn(M, N):
     A = te.placeholder((M, N), name='A')
     B = topi.nn.softmax(A, axis=1)
 
     return [A, B]
 
-@register_auto_scheduler_workload_func
+@register_workload_func
 def norm_bmn(B, M, N):
     A = te.placeholder((B, M, N), name='A')
     i = te.reduce_axis((0, M))
@@ -53,7 +53,7 @@ def norm_bmn(B, M, N):
     return [A, D]
 
 
-@register_auto_scheduler_workload_func
+@register_workload_func
 def add_mn(M, N):
     A = te.placeholder((M, N), name='A')
     B = te.placeholder((M, N), name='B')
@@ -61,7 +61,7 @@ def add_mn(M, N):
     return [A, B, C]
 
 
-@register_auto_scheduler_workload_func
+@register_workload_func
 def matmul_nkkm(N, M, K, in_type='float32', out_type='float32',
                 tensor_core_support=False):
     A = te.placeholder((N, K), name='A', dtype=in_type)
@@ -73,7 +73,7 @@ def matmul_nkkm(N, M, K, in_type='float32', out_type='float32',
         C = te.compute((N, M),
                        lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]),
                        name='C',
-                       attrs={"auto_scheduler_tensor_core_support": "True" if tensor_core_support else "False"})
+                       attrs={"ansor_tensor_core_support": "True" if tensor_core_support else "False"})
     else:
         if not ((in_type == 'float16' and out_type == 'float32') or \
                 (in_type == 'int8' and out_type == 'int32')):
@@ -82,11 +82,11 @@ def matmul_nkkm(N, M, K, in_type='float32', out_type='float32',
                        lambda i, j: te.sum(A[i][k].astype(out_type) * B[k][j].astype(out_type),
                                            axis=[k]),
                        name='C',
-                       attrs={"auto_scheduler_tensor_core_support": "True" if tensor_core_support else "False"})
+                       attrs={"ansor_tensor_core_support": "True" if tensor_core_support else "False"})
     return [A, B, C]
 
 
-@register_auto_scheduler_workload_func
+@register_workload_func
 def dense_layer(batch, in_dim, out_dim):
     A = te.placeholder((batch, in_dim), name='A')
     B 
= te.placeholder((out_dim, in_dim), name='B') @@ -95,7 +95,7 @@ def dense_layer(batch, in_dim, out_dim): return [A, B, C] -@register_auto_scheduler_workload_func +@register_workload_func def max_pool_2d_nchw(N, C, H, W): data = te.placeholder((N, C, H, W), name='data') out = topi.nn.pool(data, (2, 2), (1, 1), (0, 0, 0, 0), pool_type='max', ceil_mode=True, @@ -103,7 +103,7 @@ def max_pool_2d_nchw(N, C, H, W): return [data, out] -@register_auto_scheduler_workload_func +@register_workload_func def add_min_relu(M, N): A = te.placeholder((M, N), name='A') B = te.placeholder((M, N), name='B') @@ -112,7 +112,7 @@ def add_min_relu(M, N): out = topi.nn.relu(D) return [A, B, out] -@register_auto_scheduler_workload_func +@register_workload_func def conv2d_relu_softmax_min(N, H, W, CI, CO, KH, KW, strides, padding, dilation): data = te.placeholder((N, CI, H, W), name='data') kernel = te.placeholder((CO, CI, KH, KW), name='kernel') @@ -123,7 +123,7 @@ def conv2d_relu_softmax_min(N, H, W, CI, CO, KH, KW, strides, padding, dilation) return [data, kernel, out] -@register_auto_scheduler_workload_func +@register_workload_func def conv2d_nchw_bias(N, H, W, CI, CO, KH, KW, strides, padding, dilation): data = te.placeholder((N, CI, H, W), name='data') kernel = te.placeholder((CO, CI, KH, KW), name='kernel') @@ -190,7 +190,7 @@ def conv2d_nhwc_without_layout_rewrite(Input, Filter, stride, padding, dilation, return Output -@register_auto_scheduler_workload_func +@register_workload_func def conv2d_nhwc_bias_with_rewrite(N, H, W, CI, CO, KH, KW, strides, padding, dilation): data = te.placeholder((N, H, W, CI), name='data') kernel = te.placeholder((KH, KW, CI, CO), name='kernel') @@ -199,7 +199,7 @@ def conv2d_nhwc_bias_with_rewrite(N, H, W, CI, CO, KH, KW, strides, padding, dil out = topi.add(conv, bias) return [data, kernel, bias, out] -@register_auto_scheduler_workload_func +@register_workload_func def depthwise_conv2d_nhwc_bias_with_rewrite(N, H, W, CI, CO, KH, KW, strides, padding, dilation): data = te.placeholder((N, H, W, CI), name='data') kernel = te.placeholder((KH, KW, CI, 1), name='kernel') @@ -208,7 +208,7 @@ def depthwise_conv2d_nhwc_bias_with_rewrite(N, H, W, CI, CO, KH, KW, strides, pa out = topi.add(conv, bias) return [data, kernel, bias, out] -@register_auto_scheduler_workload_func +@register_workload_func def conv2d_nhwc_bias(N, H, W, CI, CO, KH, KW, strides, padding, dilation): data = te.placeholder((N, H, W, CI), name='data') kernel = te.placeholder((KH, KW, CI, CO), name='kernel') @@ -218,7 +218,7 @@ def conv2d_nhwc_bias(N, H, W, CI, CO, KH, KW, strides, padding, dilation): return [data, kernel, bias, out] -@register_auto_scheduler_workload_func +@register_workload_func def conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation=1): data = te.placeholder((N, CI, H, W), name='data') kernel = te.placeholder((CO, CI, kernel_size, kernel_size), name='kernel') @@ -243,7 +243,7 @@ def conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation return [data, kernel, bias, bn_offset, bn_scale, out] -@register_auto_scheduler_workload_func +@register_workload_func def conv2d_nhwc_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation=1): data = te.placeholder((N, H, W, CI), name='data') kernel = te.placeholder((kernel_size, kernel_size, CI, CO), name='kernel') diff --git a/scripts/shape_configs.py b/scripts/shape_configs.py index 95a1ba69634d..244638f5b29c 100644 --- a/scripts/shape_configs.py +++ b/scripts/shape_configs.py @@ -1,5 +1,5 @@ -""" Shape 
configurations for single operator evaluation -This file is shared by tune_all_single_op.py and scripts in baseline/ +""" Shape configurations for single operator / subgraph evaluation +This file is shared by tune_op_subgraph.py and scripts in scripts/baseline/ """ matmul_shapes = [ @@ -142,13 +142,6 @@ (1, 4096, 1024), ] -softmax_shapes = [ - (1, 1024), - (1, 4096), - (1, 16384), - (1, 65536), -] - single_op_shape_dict = { 'C1D': conv1d_shapes, 'C2D': conv2d_shapes, @@ -160,12 +153,11 @@ 'T2D': conv2d_transpose_shapes, 'CAP': conv2d_capsule_shapes, 'NRM': norm_shapes, - #'SMX': softmax_shapes, # The following workloads are not in our sinle op evaluation plan. # They should be moved to `common.py` and be used by `tune_wkl.py`. # 'C2D_NCHW': conv2d_nchw_shapes, - 'C2DWG_NHWC': conv2d_winograd_nhwc_shapes, +# 'C2DWG_NHWC': conv2d_winograd_nhwc_shapes, # 'C2DWG_NCHW': conv2d_winograd_nchw_shapes, # 'GMM_TC': matmul_tensor_core_shapes, } @@ -192,19 +184,9 @@ (16, 128, 12, 128), ] - -batch_norm_shapes = [ - (16, 256), - (16, 1024), - (16, 4096), - (16, 16384), - (16, 65536), -] - subgraph_shape_dict = { "conv2d_bn_relu": conv2d_bn_relu_shapes, "transpose_batch_matmul": transpose_batch_matmul_shapes, - #"batch_norm": batch_norm_shapes, } resnet_shapes = [ diff --git a/scripts/tune_network.py b/scripts/tune_network.py index d4f1afd95572..1905d8132003 100644 --- a/scripts/tune_network.py +++ b/scripts/tune_network.py @@ -1,13 +1,12 @@ -"""Tune all workloads in a network""" +"""Tune a whole neural network""" import argparse import logging import random import os -import time import numpy as np import tvm -from tvm import _ffi, ansor, relay +from tvm import ansor, relay import tvm.contrib.graph_runtime as runtime from tvm.contrib.debugger import debug_runtime from tvm.contrib import util, ndk @@ -20,8 +19,8 @@ dtype = "float32" -def get_network(name, model_path, batch_size, layout): - """Get the symbol definition and random weight of a network""" +def get_network(name, network_path, batch_size, layout): + """Get the relay module and random weights for a network""" input_shape = (batch_size, 3, 224, 224) output_shape = (batch_size, 1000) input_name = 'data' @@ -95,7 +94,7 @@ def get_network(name, model_path, batch_size, layout): input_shape = (1, 224, 224, 3) output_shape = (1, 1001) input_dtype = "float32" - tflite_model_buf = open(model_path, "rb").read() + tflite_model_buf = open(network_path, "rb").read() tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) mod, params = relay.frontend.from_tflite(tflite_model, shape_dict={input_name: input_shape}, @@ -144,21 +143,17 @@ def get_network(name, model_path, batch_size, layout): def create_module(data_shape, graph, lib, target, input_name, params, debug_profile, - local_measure, ndk_cc, device_key, host, port, run_timeout, num_threads, seed=43): - # Upload parameters to device + local_measure, ndk_cc, rpc_device_key, rpc_host, rpc_port, rpc_num_threads, seed=43): if local_measure: if target.target_name == "cuda": ctx = tvm.gpu() else: ctx = tvm.cpu() - if num_threads: - config_threadpool = _ffi.get_global_func('runtime.config_threadpool') - config_threadpool(0, num_threads) else: print("=============== Request Remote ===============") if 'TVM_NDK_CC' not in os.environ: os.environ['TVM_NDK_CC'] = ndk_cc - remote = request_remote(device_key, host, port, timeout=run_timeout) + remote = request_remote(rpc_device_key, rpc_host, rpc_port) print("=============== Export ===============") ctx = remote.cpu() @@ -171,9 +166,10 @@ def 
create_module(data_shape, graph, lib, target, input_name, params, debug_prof print("=============== Load ===============") lib = remote.load_module("deploy_lib.so") - if num_threads: + + if rpc_num_threads: config_threadpool = remote.get_function('runtime.config_threadpool') - config_threadpool(0, num_threads) + config_threadpool(0, rpc_num_threads) np.random.seed(seed) data_tvm = tvm.nd.array(100 * (np.random.uniform(size=data_shape)).astype(dtype), ctx=ctx) @@ -181,6 +177,7 @@ def create_module(data_shape, graph, lib, target, input_name, params, debug_prof module = debug_runtime.create(graph, lib, ctx) else: module = runtime.create(graph, lib, ctx) + if type(input_name) == list: for name in input_name: module.set_input(name, data_tvm) @@ -192,19 +189,20 @@ def create_module(data_shape, graph, lib, target, input_name, params, debug_prof return module, ctx -def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, - debug_profile, check_correctness, network_parameters, - task_scheduler_parameters, tune_parameters, module_parameters): - # Extract workloads from relay program - mod, params, input_name, data_shape, out_shape = get_network(**network_parameters) +def tune_and_evaluate(network_arguments, target, target_host, + search_policy, task_scheduler_arguments, tune_option_arguments, + tune, debug_profile, check_correctness, log_n_lines): + # Extract tasks from relay program + mod, params, input_name, data_shape, out_shape = get_network(**network_arguments) + # Tune all if tune: - print("=============== Extracting workloads ===============") + print("=============== Extract Workloads ===============") workloads, wkl_weights = ansor.extract_from_program(mod, target=target, params=params) - print("Totally %d workload extracted." % (len(workloads))) + print("Extract %d workloads in total" % (len(workloads))) # Tune workloads with auto scheduler - print("=============== Tuning ===============") + print("=============== Tune ===============") tasks = [] for i, wkl_key in enumerate(workloads): dag = ansor.workload_key_to_dag(wkl_key) @@ -212,24 +210,24 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, tasks.append(ansor.SearchTask(dag, wkl_key, target, target_host)) tuner = ansor.SimpleTaskScheduler(tasks, - lambda costs: sum(c * w for c, w in zip(costs, wkl_weights)), - **task_scheduler_parameters) - tune_option, measure_ctx = create_tune_option(target, **tune_parameters) + lambda costs: sum(c * w for c, w in zip(costs, wkl_weights)), + **task_scheduler_arguments) + tune_option, measure_ctx = create_tune_option(target, **tune_option_arguments) - if tune_parameters['local_measure'] and target.target_name != 'cuda': + if tune_option_arguments['local_measure'] and target.target_name != 'cuda': os.environ['TVM_BIND_MASTER_CORE_0'] = "1" tuner.tune(tune_option, search_policy) if measure_ctx: del measure_ctx - kernel_layout_rewrite = False + kernel_layout_rewrite = True # Compile graph with best states found by auto-scheduler print("=============== Compile ===============") - with ansor.apply_history_best(tune_parameters['log_file'], log_n_lines): + with ansor.apply_history_best(tune_option_arguments['log_file'], log_n_lines): os.environ['TVM_AUTO_CACHE_FLUSH'] = "0" - os.environ['TVM_BIND_MASTER_CORE_0'] = "1" + if kernel_layout_rewrite: ansor.prepare_layout_rewrite(mod, target=target, params=params) else: @@ -245,12 +243,13 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, print("=============== Compile Finish 
===============") module, ctx = create_module(data_shape, graph, lib, target, input_name, - opt_params, debug_profile, **module_parameters) + opt_params, debug_profile, **common_measure_parameters) # Evaluate print("========== Evaluate ==========") ftimer = module.module.time_evaluator("run", ctx, number=10, repeat=3) prof_res = np.array(ftimer().results) + # display profile information if debug_profile or check_correctness: module.run() @@ -273,12 +272,12 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, ansor.LayoutRewriteLevel.BOTH_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE ansor.LayoutRewriteLevel.COMPUTE_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE target = tvm.target.create('llvm') - with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): + with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): graph, lib, opt_params = relay.build_module.build( mod, target=target, params=params) module, _ = create_module(data_shape, graph, lib, target, input_name, - opt_params, debug_profile, **module_parameters) + opt_params, debug_profile, **common_measure_parameters) module.run() expected_output = module.get_output(0).asnumpy() @@ -287,58 +286,58 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, if __name__ == "__main__": parser = argparse.ArgumentParser() - # Task related options + + # Search task related arguments parser.add_argument("--network", type=str, required=True) - parser.add_argument("--model-path", type=str, default=None, help="The path of tflite model") + parser.add_argument("--network-path", type=str, default=None, help="The path of tflite model") parser.add_argument("--batch-size", type=int, default=1) parser.add_argument("--layout", type=str, default='NHWC') parser.add_argument("--target", type=str, default='llvm -mcpu=core-avx2') parser.add_argument("--target-host", type=str, default=None) - parser.add_argument("--n-trials", type=int, default=1000) - parser.add_argument("--num-measure-per-iter", type=int, default=48, - help="The number of programs to be measured at each iteration") - parser.add_argument("--tune", type=str2bool, nargs='?', const=True, default=True) parser.add_argument("--check-correctness", type=str2bool, nargs='?', const=True, default=False) parser.add_argument("--debug-profile", type=str2bool, nargs='?', const=True, default=False) + parser.add_argument("--tune", type=str2bool, nargs='?', const=True, default=True) - # Strategy related options - parser.add_argument("--seed", type=int, default=0, help='random seed') - parser.add_argument("--policy", type=str, choices=['multi-stage', 'meta-rewrite'], - default='meta-rewrite') + # Search strategy related arguments + parser.add_argument("--n-trials", type=int, default=1000) + parser.add_argument("--policy", type=str, choices=['sketch'], default='sketch') parser.add_argument("--model-type", type=str, choices=['xgb', 'random', 'no-share'], default='xgb') parser.add_argument("--task-scheduler", type=str, default='gradient', choices=['no', 'gradient', 'round-robin'], help='The strategy of task scheduler') + parser.add_argument("--seed", type=int, default=0, help='random seed') - # File related options - parser.add_argument("--log-file", type=str, help="Write log of measurement results to this file") + # Log file related arguments + parser.add_argument("--log-file", type=str, help="Write measurement records to this log file") + parser.add_argument("--load-log", type=str, help="Load history log to resume the status of 
search") + parser.add_argument("--log-n-lines", type=int, help="Only load the first n lines for history log") parser.add_argument("--load-model", type=str, help="Load pre trained cost model file") - parser.add_argument("--load-log", type=str, help="Load history log for pre-training the cost model") - parser.add_argument("--out-file", type=str, default='results.tsv') - parser.add_argument("--log-n-lines", type=int) - # Detailed control options + # Measurement related and other arguments + parser.add_argument("--num-measure-per-iter", type=int, default=48, + help="The number of programs to be measured at each iteration") parser.add_argument("--build-timeout", type=int, default=10) parser.add_argument("--run-timeout", type=int, default=10) parser.add_argument("--early-stopping", type=int, default=-1) parser.add_argument("--verbose", type=int, default=1) parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True) - parser.add_argument("--device-key", type=str, default=None) - parser.add_argument("--host", type=str, default='0.0.0.0') - parser.add_argument("--port", type=int, default=9190) + parser.add_argument("--rpc-device-key", type=str, default=None) + parser.add_argument("--rpc-host", type=str, default='0.0.0.0') + parser.add_argument("--rpc-port", type=int, default=9190) + parser.add_argument("--rpc-num-threads", type=int, default=None) parser.add_argument("--n-parallel", type=int, default=1) parser.add_argument("--ndk-cc", type=str, default=None) - parser.add_argument("--num-threads", type=int, default=None) args = parser.parse_args() np.random.seed(args.seed) random.seed(args.seed) logging.basicConfig() logging.getLogger('ansor').setLevel(logging.DEBUG) + os.environ["TOPHUB_LOCATION"] = "NONE" # disable autotvm target = tvm.target.create(args.target) - log_file = args.log_file or "%s-B%d-%s.json" % (args.network, args.batch_size, - target.target_name) + log_file = args.log_file or "%s-B%d-%s.json" % (args.network, args.batch_size, + target.target_name) load_log_file = args.load_log or log_file search_policy = "%s.%s" % (args.policy, args.model_type) if args.layout: @@ -348,9 +347,9 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, else: layout = "NHWC" - network_parameters = { + network_arguments = { 'name': args.network, - 'model_path': args.model_path, + 'network_path': args.network_path, 'batch_size': args.batch_size, 'layout': layout } @@ -362,15 +361,16 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, 'verbose': args.verbose, } - control_parameters = { + common_measure_parameters = { 'local_measure': args.local_measure, - 'device_key': args.device_key, - 'host': args.host, - 'port': args.port, + 'rpc_device_key': args.rpc_device_key, + 'rpc_host': args.rpc_host, + 'rpc_port': args.rpc_port, + 'rpc_num_threads': args.rpc_num_threads, 'ndk_cc': args.ndk_cc, } - tune_parameters = { + tune_option_arguments = { 'log_file': log_file, 'n_trials': args.n_trials, 'num_measure_per_iter': args.num_measure_per_iter, @@ -379,17 +379,10 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, 'build_timeout': args.build_timeout, 'run_timeout': args.run_timeout, 'early_stopping': args.early_stopping, - **control_parameters - } - - module_parameters = { - 'run_timeout': args.run_timeout, - 'num_threads': args.num_threads, - **control_parameters + **common_measure_parameters } - os.environ["TOPHUB_LOCATION"] = "NONE" - tune_and_evaluate(target, args.target_host, args.log_n_lines, 
search_policy, + tune_and_evaluate(network_arguments, target, args.target_host, + search_policy, task_scheduler_parameters, tune_option_arguments, args.tune, args.debug_profile, args.check_correctness, - network_parameters, task_scheduler_parameters, tune_parameters, - module_parameters) + args.log_n_lines) diff --git a/scripts/tune_op_subgraph.py b/scripts/tune_op_subgraph.py index bf5cbe83c952..6574bb77e510 100644 --- a/scripts/tune_op_subgraph.py +++ b/scripts/tune_op_subgraph.py @@ -1,7 +1,6 @@ -"""Tune all operators for single op & subgraph evaluation""" +"""Tune all workloads for single op & subgraph evaluation""" import argparse import logging -import os import random import numpy as np @@ -12,14 +11,13 @@ from topi.nn.winograd_util import winograd_transform_matrices from topi.util import get_const_tuple -from common import measure_schedule, str2bool, \ - norm_bmn, softmax_mn, conv2d_nhwc_bn_relu, conv2d_nchw_bn_relu +from common import measure_schedule, str2bool, norm_bmn, conv2d_nhwc_bn_relu, conv2d_nchw_bn_relu from shape_configs import single_op_shape_dict, subgraph_shape_dict from tune_test import tune_workloads_jointly, replay_workload, create_tune_option # ========================== Single Ops ========================== -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def batch_matmul_nkkm(B, N, M, K): X = te.placeholder((B, N, K), name='A') Y = te.placeholder((B, K, M), name='B') @@ -27,7 +25,7 @@ def batch_matmul_nkkm(B, N, M, K): Z = te.compute((B, N, M), lambda b, i, j: te.sum(X[b][i][k] * Y[b][k][j], axis=[k]), name='C') return [X, Y, Z] -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def conv1d_nlc(N, L, CI, CO, kernel_size, stride=1, padding=0, dilation=1, groups=1): inputs = te.placeholder((N, L, CI), name='inputs') weight = te.placeholder((kernel_size, CI//groups, CO), name='weight') @@ -49,7 +47,7 @@ def conv1d_nlc(N, L, CI, CO, kernel_size, stride=1, padding=0, dilation=1, group ) return [inputs, weight, output] -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def conv2d_nhwc(N, H, W, CI, CO, kernel_size, stride=1, padding=0, dilation=1, groups=1): inputs = te.placeholder((N, H, W, CI), name='inputs') weight = te.placeholder((kernel_size, kernel_size, CI//groups, CO), name='weight') @@ -75,7 +73,7 @@ def conv2d_nhwc(N, H, W, CI, CO, kernel_size, stride=1, padding=0, dilation=1, g ) return [inputs, weight, output] -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def conv2d_nchw(N, CI, H, W, CO, kernel_size, stride=1, padding=0, dilation=1, groups=1): inputs = te.placeholder((N, CI, H, W), name='inputs') weight = te.placeholder((CO, CI//groups, kernel_size, kernel_size), name='weight') @@ -101,7 +99,7 @@ def conv2d_nchw(N, CI, H, W, CO, kernel_size, stride=1, padding=0, dilation=1, g ) return [inputs, weight, output] -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def conv3d_ndhwc(N, D, H, W, CI, CO, kernel_size, stride=1, padding=0, dilation=1, groups=1): inputs = te.placeholder((N, D, H, W, CI)) weight = te.placeholder((kernel_size, kernel_size, kernel_size, CI//groups, CO)) @@ -131,7 +129,7 @@ def conv3d_ndhwc(N, D, H, W, CI, CO, kernel_size, stride=1, padding=0, dilation= ) return [inputs, weight, output] -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def depthwise_conv2d_nhwc(N, H, W, C, kernel_size, stride=1, padding=0, dilation=1, factor=1): inputs = te.placeholder((N, H, W, C)) weight = 
te.placeholder((factor, kernel_size, kernel_size, C)) @@ -159,7 +157,7 @@ def depthwise_conv2d_nhwc(N, H, W, C, kernel_size, stride=1, padding=0, dilation ) return [inputs, weight, output] -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def conv2d_transpose_nhwc(N, H, W, CI, CO, kernel_size, stride=1, padding=0): inputs = te.placeholder((N, H, W, CI), name='inputs') weight = te.placeholder((kernel_size, kernel_size, CI, CO), name='weight') @@ -222,12 +220,12 @@ def _dilate(*indices): weight[filter_h - 1 - rh, filter_w - 1 - rw, rc, co], axis=[rh, rw, rc]), name="conv2d_transpose_nhwc", - attrs={"auto_scheduler_always_unroll_inner": ["h", "w", "rh", "rw", "h_c", "w_c"]}) + attrs={"ansor_always_unroll_inner": ["h", "w", "rh", "rw", "h_c", "w_c"]}) # todo(lmzheng): add constraints on the tile size of h and w return [inputs, weight, output] -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def conv2d_capsule_nhwijc(N, H, W, CI, CO, kernel_size, stride=1, padding=0, capsule_size=4): inputs = te.placeholder((N, H, W, capsule_size, capsule_size, CI), name='inputs') weight = te.placeholder((kernel_size, kernel_size, capsule_size, capsule_size, CI, CO), name='weight') @@ -254,7 +252,7 @@ def conv2d_capsule_nhwijc(N, H, W, CI, CO, kernel_size, stride=1, padding=0, cap return [inputs, weight, output] -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def conv2d_winograd_nhwc(N, H, W, CI, CO, kernel_size=3, stride=1, padding=0, dilation=1): # TODO: implement tile_size tile_size = 4 #_infer_tile_size(data, kernel) @@ -304,10 +302,10 @@ def conv2d_winograd_nhwc(N, H, W, CI, CO, kernel_size=3, stride=1, padding=0, di data_pack = te.compute((alpha, alpha, P, CI), lambda eps, nu, p, ci: te.sum(input_tile[r_a][r_b][p][ci] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b]), name='data_pack', - attrs={"auto_scheduler_no_split_at_inner": ["eps", "nu", "r_a", "r_b"], - "auto_scheduler_last_split_is_one": ["ci", "p"], - "auto_scheduler_always_unroll": ["eps", "nu", "r_a", "r_b"], - "auto_scheduler_no_cache_write": "True", + attrs={"ansor_no_split_at_inner": ["eps", "nu", "r_a", "r_b"], + "ansor_last_split_is_one": ["ci", "p"], + "ansor_always_unroll": ["eps", "nu", "r_a", "r_b"], + "ansor_no_cache_write": "True", }) # do batch gemm @@ -323,10 +321,10 @@ def conv2d_winograd_nhwc(N, H, W, CI, CO, kernel_size=3, stride=1, padding=0, di inverse = te.compute((m, m, P, CO), lambda vh, vw, p, co: te.sum(bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b]), name='inverse', - attrs={"auto_scheduler_no_split_at_inner": ["vh", "vw", "r_a", "r_b"], - "auto_scheduler_always_unroll": ["vh", "vw", "r_a", "r_b"], - "auto_scheduler_last_split_is_one": ["co", "p"], - "auto_scheduler_no_cache_write": "True", + attrs={"ansor_no_split_at_inner": ["vh", "vw", "r_a", "r_b"], + "ansor_always_unroll": ["vh", "vw", "r_a", "r_b"], + "ansor_last_split_is_one": ["co", "p"], + "ansor_no_cache_write": "True", }) # output @@ -337,10 +335,10 @@ def conv2d_winograd_nhwc(N, H, W, CI, CO, kernel_size=3, stride=1, padding=0, di co], name='conv2d_winograd', tag='conv2d_winograd_nhwc', - attrs={"auto_scheduler_no_split_at_outer": ["n", "h", "w", "co"],}) + attrs={"ansor_no_split_at_outer": ["n", "h", "w", "co"],}) return [inputs, kernel_pack, output] -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def conv2d_winograd_nchw(N, CI, H, W, CO, kernel_size=3, stride=1, padding=0, dilation=1, precompute=False): # TODO: implement tile_size 
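# The transform size below is hard-coded instead of being inferred from the
# input shapes. With m = tile_size = 4 and r = kernel_size = 3 this is the
# Winograd F(4x4, 3x3) variant, so the transformed tile extent used by
# data_pack is alpha = m + r - 1 = 6 (standard Winograd arithmetic).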
tile_size = 4 #_infer_tile_size(data, kernel) @@ -390,10 +388,10 @@ def conv2d_winograd_nchw(N, CI, H, W, CO, kernel_size=3, stride=1, padding=0, di data_pack = te.compute((alpha, alpha, CI, P), lambda eps, nu, ci, p: te.sum(input_tile[ci][p][r_a][r_b] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b]), name='data_pack', - attrs={"auto_scheduler_no_split_at_inner": ["eps", "nu", "r_a", "r_b"], - "auto_scheduler_no_split_at_outer": ["ci", "p"], - "auto_scheduler_always_unroll": ["eps", "nu", "r_a", "r_b"], - "auto_scheduler_no_cache_write": "True", + attrs={"ansor_no_split_at_inner": ["eps", "nu", "r_a", "r_b"], + "ansor_no_split_at_outer": ["ci", "p"], + "ansor_always_unroll": ["eps", "nu", "r_a", "r_b"], + "ansor_no_cache_write": "True", }) # do batch gemm @@ -409,9 +407,9 @@ def conv2d_winograd_nchw(N, CI, H, W, CO, kernel_size=3, stride=1, padding=0, di inverse = te.compute((CO, P, m, m), lambda co, p, vh, vw: te.sum(bgemm[r_a][r_b][co][p] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b]), name='inverse', - attrs={"auto_scheduler_no_split_at_outer": ["co", "p", "vh", "vw", "r_a", "r_b"], - "auto_scheduler_always_unroll": ["vh", "vw", "r_a", "r_b"], - "auto_scheduler_no_cache_write": "True"}) + attrs={"ansor_no_split_at_outer": ["co", "p", "vh", "vw", "r_a", "r_b"], + "ansor_always_unroll": ["vh", "vw", "r_a", "r_b"], + "ansor_no_cache_write": "True"}) # output output = te.compute((N, CO, H, W), lambda n, co, h, w: @@ -419,12 +417,12 @@ def conv2d_winograd_nchw(N, CI, H, W, CO, kernel_size=3, stride=1, padding=0, di idxmod(h, m), idxmod(w, m)], name='conv2d_winograd', - attrs={"auto_scheduler_no_split_at_outer": ["n", "co", "h", "w"],}) + attrs={"ansor_no_split_at_outer": ["n", "co", "h", "w"],}) return [inputs, kernel_pack, output] # ========================== Subgraphs ========================== -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def transpose_batch_matmul(batch, seq_len, n_head, n_dim): query = te.placeholder((batch, seq_len, n_head, n_dim), name='query') value = te.placeholder((batch, seq_len, n_head, n_dim), name='value') @@ -433,23 +431,12 @@ def transpose_batch_matmul(batch, seq_len, n_head, n_dim): value_T = te.compute((batch, n_head, n_dim, seq_len), lambda b, h, d, l: value[b, l, h, d], name="value_T") k = te.reduce_axis((0, n_dim), name='k') - out = te.compute((batch, n_head, seq_len, seq_len), lambda b, h, i, j: te.sum(query_T[b][h][i][k] * value_T[b][h][k][j], axis=[k]), name='C') + out = te.compute((batch, n_head, seq_len, seq_len), + lambda b, h, i, j: te.sum(query_T[b][h][i][k] * value_T[b][h][k][j], axis=[k]), + name='C') return [query, value, out] -@ansor.register_auto_scheduler_workload_func -def batch_norm(M, N, eps=1e-5): - A = te.placeholder((M, N), name='A') - k1 = te.reduce_axis((0, M), name='k1') - k2 = te.reduce_axis((0, M), name='k2') - mean = te.compute((N,), lambda j: te.sum(A[k1][j] / M, axis=k1), name="mean") - var = te.compute((N,), - lambda j: te.sum((A[k2][j] - mean[j]) * (A[k2][j] - mean[j]) / (M - 1), k2), - name="var") - B = te.compute((M, N), lambda i, j: (A[i][j] - mean[j]) / te.sqrt(var[j] + eps), name='B') - - return [A, B] - -# ========================== Tune func & Dicts ========================== +# ========================== Tune function & Task dicts ========================== def tune_wkl(task_func_dict, shape_dict, wkl_type, args): target = tvm.target.create(args.target) @@ -464,8 +451,8 @@ def tune_wkl(task_func_dict, shape_dict, wkl_type, args): if shape[0] == 1: shape = list(shape) shape[0] = args.batch_size 
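# The workload key built below serializes the workload function and its shape
# arguments; measurement records are indexed by this key in the log file, so
# replay_workload can later look up the best schedule for exactly this shape.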
- wkl_key = ansor.make_workload_key_func(func, shape) + wkl_key = ansor.make_workload_key_func(func, shape) wkl_keys.append(wkl_key) if args.fast_check: break @@ -473,9 +460,8 @@ def tune_wkl(task_func_dict, shape_dict, wkl_type, args): if not args.tune: cost, gflops = replay_workload( wkl_key, target, args.target_host, log_file, - args.local_measure, args.device_key, args.host, - args.port, args.ndk_cc, False) - # TODO(): Add log record + args.local_measure, args.rpc_device_key, args.rpc_host, + args.rpc_port, args.rpc_num_threads, args.ndk_cc, False) # log_line(BenchmarkRecord(target.name, 'gpu' if target.name == 'cuda' else 'cpu', 'subgraph', # workload_name, "AutoSchedule", "default", # {"costs": [cost]}, time.time()), args.out_file) @@ -489,7 +475,8 @@ def tune_wkl(task_func_dict, shape_dict, wkl_type, args): tune_option, measure_ctx = create_tune_option(target, log_file, n_trials, args.num_measure_per_iter, args.verbose, args.n_parallel, args.build_timeout, args.local_measure, - args.device_key, args.host, args.port, args.ndk_cc) + args.rpc_device_key, args.rpc_host, args.rpc_port, + args.rpc_num_threads, args.ndk_cc) # tune workloads jointly using JointTuner tune_workloads_jointly(wkl_keys, np.ones(len(wkl_keys)), args.task_scheduler, @@ -516,7 +503,7 @@ def tune_wkl(task_func_dict, shape_dict, wkl_type, args): # The following workloads are not in our single op evaluation plan. # They should be moved to `common.py` and be used by `tune_wkl.py`. # 'C2D_NCHW': conv2d_nchw, - 'C2DWG_NHWC': conv2d_winograd_nhwc, +# 'C2DWG_NHWC': conv2d_winograd_nhwc, # 'C2DWG_NCHW': conv2d_winograd_nchw, # 'GMM_TC': matmul_nkkm, } @@ -529,44 +516,43 @@ def tune_wkl(task_func_dict, shape_dict, wkl_type, args): if __name__ == "__main__": parser = argparse.ArgumentParser() - # Task related options - parser.add_argument("--wkl", type=str, required=True, - help="all - For all workloads; \ - op - For all single ops; \ - subgraph - For all subgraphs; \ - Or specific wkl name") + # Search task related arguments + parser.add_argument("--wkl", type=str, required=True, + help="all - Tune all workloads; \ + op - Tune all single ops; \ + subgraph - Tune all subgraphs; \ + specific wkl name - Tune a specific workload") + parser.add_argument("--batch-size", type=int, default=1) parser.add_argument("--target", type=str, default='llvm -mcpu=core-avx2') parser.add_argument("--target-host", type=str, default=None) parser.add_argument("--tune", type=str2bool, nargs='?', const=True, default=True) parser.add_argument("--fast-check", action='store_true', help='Only run one shape for each workload.
This is used for fast checking') - # Strategy related options - parser.add_argument("--seed", type=int, default=0, help='random seed') - parser.add_argument("--policy", type=str, choices=['meta-rewrite', 'beam-search'], default='meta-rewrite') + # Search strategy related arguments + parser.add_argument("--n-trials-per-shape", type=int, default=1000) + parser.add_argument("--policy", type=str, choices=['sketch', 'beam-search'], default='sketch') parser.add_argument("--model-type", type=str, choices=['xgb', 'random', 'no-share'], default='xgb') - parser.add_argument("--task-scheduler", type=str, default='gradient', - choices=['no', 'gradient', 'round-robin'], - help='The strategy of task scheduler') + parser.add_argument("--task-scheduler", type=str, default='round-robin', + choices=['no', 'gradient', 'round-robin'], help='The strategy of task scheduler') + parser.add_argument("--seed", type=int, default=0, help='random seed') - # File related options - parser.add_argument("--log-file", type=str, help="Write log of measurement results to this file") - parser.add_argument("--load-model", type=str, help="Load pre trained cost model file") - parser.add_argument("--load-log", type=str, help="Load history log for pre-training the cost model") - parser.add_argument("--out-file", type=str, default='results.tsv') + # Log file related arguments + parser.add_argument("--log-file", type=str, help="Write measurement records to this log file") + parser.add_argument("--load-log", type=str, help="Load history log to resume the status of search") + parser.add_argument("--load-model", type=str, help="Load pre-trained cost model from this file") - # Detailed control options + # Measurement related and other arguments + parser.add_argument("--num-measure-per-iter", type=int, default=48, + help="The number of programs to be measured at each iteration") parser.add_argument("--build-timeout", type=int, default=10) parser.add_argument("--run-timeout", type=int, default=60) parser.add_argument("--verbose", type=int, default=1) parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True) - parser.add_argument("--device-key", type=str, default=None) - parser.add_argument("--host", type=str, default='0.0.0.0') - parser.add_argument("--port", type=int, default=9190) + parser.add_argument("--rpc-device-key", type=str, default=None) + parser.add_argument("--rpc-host", type=str, default='0.0.0.0') + parser.add_argument("--rpc-port", type=int, default=9190) + parser.add_argument("--rpc-num-threads", type=int, default=None) parser.add_argument("--n-parallel", type=int, default=1) parser.add_argument("--ndk-cc", type=str, default=None) args = parser.parse_args() diff --git a/scripts/tune_test.py b/scripts/tune_test.py index 86f055caf889..67c0526dd624 100644 --- a/scripts/tune_test.py +++ b/scripts/tune_test.py @@ -13,8 +13,8 @@ from common import get_workload_keys, get_workload_weights, measure_schedule, str2bool def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose, - n_parallel, build_timeout, local_measure, device_key, host, - port, ndk_cc, early_stopping=-1, run_timeout=10): + n_parallel, build_timeout, local_measure, rpc_device_key, rpc_host, + rpc_port, rpc_num_threads, ndk_cc, early_stopping=-1, run_timeout=10): builder = runner = measure_ctx = None if local_measure: builder = ansor.LocalBuilder(timeout=build_timeout) @@ -27,8 +27,13 @@ def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose else: os.environ['TVM_NDK_CC'] = ndk_cc 
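# Remote measurement path: programs are cross-compiled (with the Android NDK
# compiler from --ndk-cc when building for mobile targets) and dispatched to
# devices registered under rpc_device_key on the tracker at rpc_host:rpc_port.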
builder = ansor.LocalBuilder(timeout=build_timeout, build_func='ndk') - runner = ansor.RPCRunner(key=device_key, host=host, port=port, timeout=run_timeout, - n_parallel=n_parallel, repeat=1, min_repeat_ms=400) + runner = ansor.RPCRunner(key=rpc_device_key, host=rpc_host, port=rpc_port, + timeout=run_timeout, n_parallel=n_parallel, + repeat=1, min_repeat_ms=200) + remote = request_remote(rpc_device_key, rpc_host, rpc_port) + if rpc_num_threads: + config_threadpool = remote.get_function('runtime.config_threadpool') + config_threadpool(0, rpc_num_threads) tune_option = ansor.TuneOption(n_trials=n_trials, early_stopping=early_stopping, num_measure_per_iter=num_measure_per_iter, @@ -42,16 +47,17 @@ def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose def replay_workload(wkl_key, target, target_host, log_file, - local_measure=True, device_key=None, host="0.0.0.0", - port=9190, ndk_cc=None, show_lower_result=True): + local_measure=True, rpc_device_key=None, rpc_host="0.0.0.0", + rpc_port=9190, rpc_num_threads=None, ndk_cc=None, + show_lower_result=True): cost = gflops = None inp, res = ansor.best_measure_pair_in_file(log_file, wkl_key, target) if inp is None: - print("Cannot find log for: %s" % (wkl_key)) + print("Cannot find log for: %s" % wkl_key) else: dag = ansor.workload_key_to_dag(inp.task.workload_key) - print("Found schedule for: %s" % (wkl_key)) + print("Found schedule for: %s" % wkl_key) s, bufs = dag.apply_steps_from_state(inp.state) if show_lower_result: @@ -60,18 +66,21 @@ def replay_workload(wkl_key, target, target_host, log_file, if local_measure: remote = None else: - remote = request_remote(device_key, host, port, 1) + remote = request_remote(rpc_device_key, rpc_host, rpc_port) + if rpc_num_threads: + config_threadpool = remote.get_function('runtime.config_threadpool') + config_threadpool(0, rpc_num_threads) - cost = np.mean((measure_schedule(s, bufs, target, remote=remote, ndk_cc=ndk_cc))) + cost = np.mean((measure_schedule(s, bufs, target, target_host, + remote=remote, ndk_cc=ndk_cc))) gflops = ansor.ComputeDAG(bufs).flop_ct / cost / 1e9 - print("Best schedule: %.2f GFLOPS\tcost: %.3f ms" % - (gflops, cost * 1e3)) + print("Best schedule: %.2f GFLOPS\tcost: %.3f ms" % (gflops, cost * 1e3)) return cost, gflops -def tune_workload(wkl_key, target, target_host, policy, model_type, load_model_file, - load_log_file, tune_option): +def tune_workload(wkl_key, target, target_host, policy, model_type, + load_model_file, load_log_file, tune_option): """Tune a workload""" if False: @@ -92,11 +101,11 @@ def tune_workload(wkl_key, target, target_host, policy, model_type, load_model_f else: raise ValueError("Invalid model: " + model_type) - if policy == 'meta-rewrite': - policy = ansor.MetaTileRewritePolicy(program_cost_model=model) + if policy == 'sketch': + policy = ansor.SketchSearchPolicy(program_cost_model=model) elif policy == 'beam-search': - policy = ansor.MetaTileRewritePolicy(program_cost_model=model, - params={'use_beam_search': 1}) + policy = ansor.SketchSearchPolicy(program_cost_model=model, + params={'use_beam_search': 1}) else: raise ValueError("Invalid search policy: " + policy) @@ -105,12 +114,10 @@ def tune_workload(wkl_key, target, target_host, policy, model_type, load_model_f search_policy=policy, tune_option=tune_option) - def tune_workloads_jointly(wkl_keys, weights, task_scheduler, target, target_host, search_policy, model_type, load_model_file, load_log_file, tune_option): - """Tune for multiple workloads jointly""" - + """Tune for multiple 
workloads together with TaskScheduler""" tasks = [] for wkl_key in wkl_keys: dag = ansor.workload_key_to_dag(wkl_key) @@ -127,36 +134,37 @@ def objective_func(costs): if __name__ == "__main__": parser = argparse.ArgumentParser() - # Task related options + # Search task related arguments parser.add_argument("--wkl", type=str, required=True) parser.add_argument("--target", type=str, default='llvm -mcpu=core-avx2') parser.add_argument("--target-host", type=str, default=None) - parser.add_argument("--n-trials", type=int, default=1000) - parser.add_argument("--num-measure-per-iter", type=int, default=48, - help="The number of programs to be measured at each iteration") parser.add_argument("--tune", type=str2bool, nargs='?', const=True, default=True) - # Strategy related options - parser.add_argument("--seed", type=int, default=0, help='random seed') - parser.add_argument("--policy", type=str, choices=['meta-rewrite', 'beam-search'], default='meta-rewrite') + # Search strategy related arguments + parser.add_argument("--n-trials", type=int, default=1000) + parser.add_argument("--policy", type=str, choices=['sketch', 'beam-search'], default='sketch') parser.add_argument("--model-type", type=str, choices=['xgb', 'random', 'no-share'], default='xgb') parser.add_argument("--task-scheduler", type=str, default='no', choices=['no', 'gradient', 'round-robin'], help='The strategy of task scheduler') + parser.add_argument("--seed", type=int, default=0, help='random seed') - # File related options - parser.add_argument("--log-file", type=str, help="Write log of measurement results to this file") - parser.add_argument("--load-model", type=str, help="Load pre trained cost model file") - parser.add_argument("--load-log", type=str, help="Load history log for pre-training the cost model") + # Log file related arguments + parser.add_argument("--log-file", type=str, help="Write measurement records to this log file") + parser.add_argument("--load-log", type=str, help="Load history log to resume the status of search") + parser.add_argument("--load-model", type=str, help="Load pre-trained cost model from this file") - # Detailed control options + # Measurement related and other arguments + parser.add_argument("--num-measure-per-iter", type=int, default=48, + help="The number of programs to be measured at each iteration") parser.add_argument("--build-timeout", type=int, default=10) parser.add_argument("--run-timeout", type=int, default=60) parser.add_argument("--verbose", type=int, default=1) parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True) - parser.add_argument("--device-key", type=str, default=None) - parser.add_argument("--host", type=str, default='0.0.0.0') - parser.add_argument("--port", type=int, default=9190) + parser.add_argument("--rpc-device-key", type=str, default=None) + parser.add_argument("--rpc-host", type=str, default='0.0.0.0') + parser.add_argument("--rpc-port", type=int, default=9190) + parser.add_argument("--rpc-num-threads", type=int, default=None) parser.add_argument("--n-parallel", type=int, default=1) parser.add_argument("--ndk-cc", type=str, default=None) args = parser.parse_args() @@ -170,14 +178,16 @@ def objective_func(costs): target = tvm.target.create(args.target) log_file = args.log_file or args.wkl + ".json" + # Tune workloads if args.tune: load_log_file = args.load_log or log_file weights = get_workload_weights(args.wkl) tune_option, measure_ctx = create_tune_option(target, log_file, - args.n_trials, args.num_measure_per_iter, args.verbose, -
args.n_parallel, args.build_timeout, args.local_measure, - args.device_key, args.host, args.port, args.ndk_cc) + args.n_trials, args.num_measure_per_iter, args.verbose, + args.n_parallel, args.build_timeout, args.local_measure, + args.rpc_device_key, args.rpc_host, args.rpc_port, args.rpc_num_threads, + args.ndk_cc) if args.task_scheduler == 'no': # tune workloads one by one @@ -186,7 +196,7 @@ def objective_func(costs): args.model_type, args.load_model, load_log_file, tune_option) else: - # tune workloads jointly using JointTuner + # tune workloads jointly with TaskScheduler tune_workloads_jointly(wkl_keys, weights, args.task_scheduler, target, args.target_host, args.policy, args.model_type, args.load_model, load_log_file, @@ -194,8 +204,9 @@ def objective_func(costs): if measure_ctx: del measure_ctx - if not args.tune or len(wkl_keys) == 1: + # Replay the best found schedule + if len(wkl_keys) == 1 or not args.tune: for wkl_key in wkl_keys: replay_workload(wkl_key, target, args.target_host, log_file, - args.local_measure, args.device_key, args.host, - args.port, args.ndk_cc) + args.local_measure, args.rpc_device_key, args.rpc_host, + args.rpc_port, args.rpc_num_threads, args.ndk_cc) diff --git a/src/ansor/auto_schedule.cc b/src/ansor/auto_schedule.cc index 200118cf708b..7ffc63a03917 100644 --- a/src/ansor/auto_schedule.cc +++ b/src/ansor/auto_schedule.cc @@ -26,7 +26,7 @@ #include #include #include -#include "search_policy/meta_tile_rewrite_policy.h" +#include "search_policy/sketch_search_policy.h" namespace tvm { namespace ansor { diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 6269b9f16f71..95e744a0e777 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -1147,8 +1147,7 @@ void ComputeDAG::InferBoundCommon(StateNode* pstate) const { } pstate->stages[i] = StageNode::make(stage->op, stage->op_type, - std::move(new_iters), stage->compute_at, - stage->auto_unroll_max_step, stage->storage_offset); + std::move(new_iters), stage->compute_at, stage->attrs); } } diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index 7569c91e3368..239f4e6988ac 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -76,35 +76,32 @@ Stage StageNode::make(te::Operation op) { node->compute_at = kRoot; node->op = std::move(op); - node->auto_unroll_max_step = 0; - node->storage_offset = 0; + node->attrs.auto_unroll_max_step = 0; + node->attrs.storage_offset = 0; return Stage(node); } Stage StageNode::make(te::Operation op, StageType op_type, const std::vector& iters, - ComputeAtType compute_at, int auto_unroll_max_step, - int storage_offset) { + ComputeAtType compute_at, StageAttributes attrs) { auto node = make_object(); node->op = std::move(op); node->op_type = op_type; node->iters = iters; node->compute_at = compute_at; - node->auto_unroll_max_step = auto_unroll_max_step; - node->storage_offset = storage_offset; + node->attrs = attrs; return Stage(node); } Stage StageNode::make(te::Operation op, StageType op_type, std::vector&& iters, ComputeAtType compute_at, - int auto_unroll_max_step, int storage_offset) { + StageAttributes attrs) { auto node = make_object(); node->op = std::move(op); node->op_type = op_type; node->iters = std::move(iters); node->compute_at = compute_at; - node->auto_unroll_max_step = auto_unroll_max_step; - node->storage_offset = storage_offset; + node->attrs = attrs; return Stage(node); } @@ -333,7 +330,7 @@ void State::DoReorderStep(const ReorderStep& step) { StateNode* pstate = CopyOnWrite(); 
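// CopyOnWrite() returns a mutable StateNode, cloning the underlying object
// only if it is currently shared; rebuilding one stage therefore cannot
// disturb other State handles that still reference the old node.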
pstate->stages[step->stage_id] = StageNode::make( stage->op, stage->op_type, std::move(iters), stage->compute_at, - stage->auto_unroll_max_step, stage->storage_offset); + stage->attrs); } // common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep @@ -400,7 +397,7 @@ std::vector State::DoSplitStepCommon( StateNode* pstate = CopyOnWrite(); pstate->stages[stage_id] = StageNode::make( stage->op, stage->op_type, std::move(new_iters), stage->compute_at, - stage->auto_unroll_max_step, stage->storage_offset); + stage->attrs); // we have to replace the iterators in attach map, // these two vectors keep the replacement mapping @@ -494,7 +491,7 @@ Iterator State::DoFuseStep(const FuseStep& step) { StateNode* pstate = CopyOnWrite(); pstate->stages[stage_id] = StageNode::make( stage->op, stage->op_type, std::move(new_iters), stage->compute_at, - stage->auto_unroll_max_step, stage->storage_offset); + stage->attrs); // we have to replace the iterators in attach map, // these two vectors keep the replacement mapping @@ -559,7 +556,7 @@ void State::DoComputeAtStep(const ComputeAtStep& step) { StateNode* pstate = CopyOnWrite(); pstate->stages[step->stage_id] = StageNode::make(stage->op, stage->op_type, std::move(new_iters), kIter, - stage->auto_unroll_max_step, stage->storage_offset); + stage->attrs); pstate->attach_map.SetComputeAtIter(step->stage_id, step->target_stage_id, step->target_iter_id); } @@ -581,7 +578,7 @@ void State::DoComputeRootStep(const ComputeRootStep& step) { StateNode* pstate = CopyOnWrite(); pstate->stages[step->stage_id] = StageNode::make(stage->op, stage->op_type, std::move(new_iters), kRoot, - stage->auto_unroll_max_step, stage->storage_offset); + stage->attrs); pstate->attach_map.DeleteStage(step->stage_id); } @@ -716,7 +713,7 @@ void State::DoPragmaStep(const PragmaStep& step) { StateNode* pstate = CopyOnWrite(); StageNode* stage = pstate->stages[step->stage_id].CopyOnWrite(); size_t pos = step->pragma_type.find('$'); - stage->auto_unroll_max_step = atoi(step->pragma_type.c_str() + pos + 1); + stage->attrs.auto_unroll_max_step = atoi(step->pragma_type.c_str() + pos + 1); } else if (step->pragma_type == "tensor_core") { // Nothing needs to be done here } else { @@ -759,7 +756,7 @@ int State::DoRfactorStep(const RfactorStep& step, const ComputeDAG& dag) { void State::DoStorageAlignStep(const StorageAlignStep& step) { StateNode* pstate = CopyOnWrite(); StageNode* stage = pstate->stages[step->stage_id].CopyOnWrite(); - stage->storage_offset = step->offset; + stage->attrs.storage_offset = step->offset; } Iterator State::DoTensorizeStep(const TensorizeStep& step) { @@ -831,19 +828,19 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, size_t base_indent, bool delete_trivial_loop) { const Stage& stage = state->stages[stage_id]; - if (stage->auto_unroll_max_step != 0) { + if (stage->attrs.auto_unroll_max_step != 0) { for (size_t j = 0; j < base_indent; ++j) { *os << " "; } *os << stage->op->func_name() - << " auto_unroll: " << stage->auto_unroll_max_step << "\n"; + << " auto_unroll: " << stage->attrs.auto_unroll_max_step << "\n"; } - if (stage->storage_offset != 0) { + if (stage->attrs.storage_offset != 0) { for (size_t j = 0; j < base_indent; ++j) { *os << " "; } *os << stage->op->func_name() - << " storage_offset: " << stage->storage_offset << "\n"; + << " storage_offset: " << stage->attrs.storage_offset << "\n"; } size_t indent = 0; diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h index 6eef404ae272..31ed5274184d 100644 --- 
a/src/ansor/loop_state.h +++ b/src/ansor/loop_state.h @@ -121,6 +121,12 @@ class CacheReadStep; class CacheWriteStep; class PragmaStep; class RfactorStep; class StorageAlignStep; class TensorizeStep; +/*! \brief Stage-level attributes */ +struct StageAttributes { + int auto_unroll_max_step; + int storage_offset; +}; + /*! * \brief A stage in the compute declaration * Similar to te::Stage in `include/schedule.h` @@ -131,8 +137,7 @@ class StageNode : public Object { StageType op_type; std::vector iters; ComputeAtType compute_at; - int auto_unroll_max_step; - int storage_offset; + StageAttributes attrs; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("op", &op); @@ -141,12 +146,10 @@ class StageNode : public Object { static Stage make(te::Operation op); static Stage make(te::Operation op, StageType op_type, const std::vector& iters, - ComputeAtType compute_at, int auto_unroll_max_step, - int storage_offset); + ComputeAtType compute_at, StageAttributes attrs); static Stage make(te::Operation op, StageType op_type, std::vector&& iters, - ComputeAtType compute_at, int auto_unroll_max_step, - int storage_offset); + ComputeAtType compute_at, StageAttributes attrs); static constexpr const char *_type_key = "ansor.Stage"; TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object); diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h index f1f6f45fce9a..4710cc05ae7f 100644 --- a/src/ansor/search_policy/search_policy.h +++ b/src/ansor/search_policy/search_policy.h @@ -43,6 +43,7 @@ class SearchPolicyNode; class SearchCallbackNode : public Object { public: virtual void callback(SearchPolicyNode* policy) = 0; + static constexpr const char *_type_key = "ansor.SearchCallback"; TVM_DECLARE_BASE_OBJECT_INFO(SearchCallbackNode, Object); }; diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.cc b/src/ansor/search_policy/sketch_search_policy.cc similarity index 91% rename from src/ansor/search_policy/meta_tile_rewrite_policy.cc rename to src/ansor/search_policy/sketch_search_policy.cc index 8b5b97224c08..7e4c3999dce3 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.cc +++ b/src/ansor/search_policy/sketch_search_policy.cc @@ -18,11 +18,13 @@ */ /*! - * \file ansor/search_policy/meta_tile_rewrite_policy.h - * \brief The search policy that searches by program sampling and evolutionary search + * \file ansor/search_policy/sketch_search_policy.h + * \brief The search policy that searches in a hierarchical search space defined by sketches. + * The policy randomly samples programs from the space defined by sketches + * and use evolutionary search to fine-tune them. 
*/ -#include "meta_tile_rewrite_policy.h" +#include "sketch_search_policy.h" #include #include #include @@ -41,23 +43,23 @@ namespace tvm { namespace ansor { -TVM_REGISTER_NODE_TYPE(MetaTileRewritePolicyNode); -TVM_REGISTER_OBJECT_TYPE(PreAddCustomRuleNode); +TVM_REGISTER_NODE_TYPE(SketchSearchPolicyNode); +TVM_REGISTER_OBJECT_TYPE(PreloadCustomSketchRuleNode); // All possible candidates for auto_unroll -const std::vector MetaTileRewritePolicyNode::auto_unroll_configs{0, 16, 64, 512, 1024}; +const std::vector SketchSearchPolicyNode::auto_unroll_configs{0, 16, 64, 512, 1024}; -SearchPolicy MetaTileRewritePolicyNode::make(CostModel program_cost_model, +SearchPolicy SketchSearchPolicyNode::make(CostModel program_cost_model, Map params, int seed) { - auto node = make_object(); + auto node = make_object(); node->program_cost_model = std::move(program_cost_model); node->rand_gen_ = std::mt19937(seed); node->params = std::move(params); return SearchPolicy(node); } -State MetaTileRewritePolicyNode::Search(SearchTask task, int n_trials, +State SketchSearchPolicyNode::Search(SearchTask task, int n_trials, int early_stopping, int num_measure_per_iter, int verbose, ProgramMeasurer measurer, Array pre_search_callbacks) { @@ -129,7 +131,7 @@ State MetaTileRewritePolicyNode::Search(SearchTask task, int n_trials, } std::pair, Array > - MetaTileRewritePolicyNode::ContinueSearchOneRound( + SketchSearchPolicyNode::ContinueSearchOneRound( SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) { if (cur_task.defined()) { CHECK_EQ(cur_task, task); @@ -176,7 +178,7 @@ std::pair, Array > return std::make_pair(std::move(inputs_arr), std::move(results_arr)); } -void MetaTileRewritePolicyNode::PickStatesWithEpsGreedy( +void SketchSearchPolicyNode::PickStatesWithEpsGreedy( std::vector* inputs, const std::vector& best_states, const std::vector& random_states, @@ -224,7 +226,7 @@ void MetaTileRewritePolicyNode::PickStatesWithEpsGreedy( } } -void MetaTileRewritePolicyNode::SearchOneRound(std::vector* best_states, +void SketchSearchPolicyNode::SearchOneRound(std::vector* best_states, int num_random_states, std::vector* random_states) { best_states->clear(); random_states->clear(); @@ -240,16 +242,16 @@ void MetaTileRewritePolicyNode::SearchOneRound(std::vector* best_states, num_use_measured = 0; } - // Synthesize meta structure - std::vector meta_structures; - GenerateMetaSketch(&meta_structures); + // Generate sketches + std::vector sketches; + GenerateSketch(&sketches); - // PrintAllStates(meta_structures); + // PrintAllStates(sketches); // exit(0); // Sample the init population std::vector init_population; - SampleInitPopulation(meta_structures, population - num_use_measured, &init_population); + SampleInitPopulation(sketches, population - num_use_measured, &init_population); // PrintAllStates(init_population); // exit(0); @@ -273,21 +275,21 @@ void MetaTileRewritePolicyNode::SearchOneRound(std::vector* best_states, RandomSampleStates(init_population, &rand_gen_, num_random_states * 10, random_states); } -// The baseclass of derivation rules used in meta sketch generation +// The baseclass of derivation rules used in sketch generation class SketchGenerationRule { public: enum ConditionEnum { kPass, kApply, kApplyAndSkipRest }; - virtual ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, + virtual ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, const State& state, int stage_id) = 0; - virtual std::vector > Apply(const MetaTileRewritePolicyNode* policy, + virtual 
std::vector > Apply(const SketchSearchPolicyNode* policy, const State& state, int stage_id) = 0; }; static inline bool ShouldBeCacheRead( - const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) { + const SketchSearchPolicyNode* policy, const State& state, int stage_id) { const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; @@ -319,7 +321,7 @@ static inline bool ShouldBeCacheRead( } static inline bool ShouldAlwaysBeInlined( - const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) { + const SketchSearchPolicyNode* policy, const State& state, int stage_id) { const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; @@ -348,13 +350,13 @@ static inline bool ShouldAlwaysBeInlined( // The rule that inlines simple elementwise ops class RuleAlwaysInline : public SketchGenerationRule { public: - ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, + ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { return ShouldAlwaysBeInlined(policy, state, stage_id) ? kApplyAndSkipRest : kPass; } - std::vector > Apply(const MetaTileRewritePolicyNode* policy, + std::vector > Apply(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { State tmp_s = state; tmp_s.compute_inline(stage_id); @@ -365,7 +367,7 @@ class RuleAlwaysInline : public SketchGenerationRule { // The rule that simply skip the current stage class RuleSkipStage : public SketchGenerationRule { public: - ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, + ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; @@ -381,7 +383,7 @@ class RuleSkipStage : public SketchGenerationRule { return kApply; } - std::vector > Apply(const MetaTileRewritePolicyNode* policy, + std::vector > Apply(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { return {std::make_pair(state, stage_id - 1)}; } @@ -390,7 +392,7 @@ class RuleSkipStage : public SketchGenerationRule { // The rule that performs multi-level tiling class RuleMultiLevelTiling : public SketchGenerationRule { public: - ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, + ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; @@ -399,7 +401,7 @@ class RuleMultiLevelTiling : public SketchGenerationRule { (IS_GPU(policy->cur_task) ? kApplyAndSkipRest : kApply) : kPass; } - std::vector > Apply(const MetaTileRewritePolicyNode* policy, + std::vector > Apply(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { std::string multi_level_tiling_structure = IS_GPU(policy->cur_task) ? 
GetStringParam(policy->params, "gpu_multi_level_tiling_structure") : @@ -416,7 +418,7 @@ class RuleMultiLevelTiling : public SketchGenerationRule { // The rule that performs multi-level tiling and fuses later consumers class RuleMultiLevelTilingWithFusion : public SketchGenerationRule { public: - ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, + ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; @@ -438,7 +440,7 @@ class RuleMultiLevelTilingWithFusion : public SketchGenerationRule { kApply : kPass; } - std::vector > Apply(const MetaTileRewritePolicyNode* policy, + std::vector > Apply(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; @@ -485,7 +487,7 @@ class RuleMultiLevelTilingWithFusion : public SketchGenerationRule { // The rule that adds a cache write stage class RuleAddCacheWrite : public SketchGenerationRule { public: - ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, + ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; @@ -503,7 +505,7 @@ class RuleAddCacheWrite : public SketchGenerationRule { kApply : kPass; } - std::vector > Apply(const MetaTileRewritePolicyNode* policy, + std::vector > Apply(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { const SearchTask& task = policy->cur_task; @@ -518,13 +520,13 @@ class RuleAddCacheWrite : public SketchGenerationRule { // Currently only support 1 to 1 match cache read class RuleAddCacheRead : public SketchGenerationRule { public: - ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, + ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { return ShouldBeCacheRead(policy, state, stage_id) ? 
kApplyAndSkipRest : kPass; } - std::vector > Apply(const MetaTileRewritePolicyNode* policy, + std::vector > Apply(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; @@ -549,7 +551,7 @@ class RuleAddCacheRead : public SketchGenerationRule { // The rule that adds rfactor stage class RuleAddRfactor : public SketchGenerationRule { public: - ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, + ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; @@ -559,7 +561,7 @@ class RuleAddRfactor : public SketchGenerationRule { kApply : kPass; } - std::vector > Apply(const MetaTileRewritePolicyNode* policy, + std::vector > Apply(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { const SearchTask& task = policy->cur_task; const Stage& stage = state->stages[stage_id]; @@ -611,7 +613,7 @@ class RuleAddRfactor : public SketchGenerationRule { } }; -void MetaTileRewritePolicyNode::GenerateMetaSketch( +void SketchSearchPolicyNode::GenerateSketch( std::vector* out_states) { State init_state = cur_task->compute_dag.GetInitState(); std::string cpu_multi_level_tiling_structure = @@ -705,10 +707,10 @@ void MetaTileRewritePolicyNode::GenerateMetaSketch( } } - StdCout(verbose) << "Synthesize Meta Structure\t\t#s: " << out_states->size() << std::endl; + StdCout(verbose) << "Generate Sketches\t\t#s: " << out_states->size() << std::endl; } -int InitPopulationFillTileSize(const MetaTileRewritePolicyNode* policy, +int InitPopulationFillTileSize(const SketchSearchPolicyNode* policy, State* state, std::mt19937* rand_gen, SplitFactorizationMemo* split_memo) { for (size_t step_id = 0; step_id < (*state)->transform_steps.size(); ++step_id) { @@ -741,7 +743,7 @@ int InitPopulationFillTileSize(const MetaTileRewritePolicyNode* policy, return 0; } -int InitPopulationThreadBind(const MetaTileRewritePolicyNode* policy, +int InitPopulationThreadBind(const SketchSearchPolicyNode* policy, State* state) { for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { const Stage& stage = (*state)->stages[stage_id]; @@ -853,7 +855,7 @@ int InitPopulationThreadBind(const MetaTileRewritePolicyNode* policy, return 0; } -int InitPopulationCooperativeFetching(const MetaTileRewritePolicyNode* policy, +int InitPopulationCooperativeFetching(const SketchSearchPolicyNode* policy, State* state) { for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { // Do cooperative fetching with cache read stage @@ -898,7 +900,7 @@ int InitPopulationCooperativeFetching(const MetaTileRewritePolicyNode* policy, return 0; } -int InitPopulationChangeComputeLocation(const MetaTileRewritePolicyNode* policy, +int InitPopulationChangeComputeLocation(const SketchSearchPolicyNode* policy, State* state, std::mt19937* rand_gen) { if(GetIntParam(policy->params, "disable_change_compute_location")) { return 0; @@ -1060,12 +1062,12 @@ int InitPopulationChangeComputeLocation(const MetaTileRewritePolicyNode* policy, return 0; } -int InitPopulationParallel(const MetaTileRewritePolicyNode* policy, +int InitPopulationParallel(const SketchSearchPolicyNode* policy, State* state) { - std::function annotate_parallel; + std::function annotate_parallel; annotate_parallel = [&annotate_parallel]( - const MetaTileRewritePolicyNode* policy, State* state, int 
stage_id, int iter_offset) { + const SketchSearchPolicyNode* policy, State* state, int stage_id, int iter_offset) { const Stage& stage = (*state)->stages[stage_id]; std::vector to_fuse; @@ -1125,7 +1127,7 @@ int InitPopulationParallel(const MetaTileRewritePolicyNode* policy, return 0; } -int InitPopulationVectorization(const MetaTileRewritePolicyNode* policy, +int InitPopulationVectorization(const SketchSearchPolicyNode* policy, State* state, std::mt19937* rand_gen) { for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { const Stage& stage = (*state)->stages[stage_id]; @@ -1202,7 +1204,7 @@ int InitPopulationVectorization(const MetaTileRewritePolicyNode* policy, return 0; } -int InitPopulationUnroll(const MetaTileRewritePolicyNode* policy, +int InitPopulationUnroll(const SketchSearchPolicyNode* policy, State* state, std::mt19937* rand_gen) { for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { const Stage& stage = (*state)->stages[stage_id]; @@ -1266,7 +1268,7 @@ int InitPopulationUnroll(const MetaTileRewritePolicyNode* policy, return 0; } -void MetaTileRewritePolicyNode::SampleInitPopulation(const std::vector& meta_structures, +void SketchSearchPolicyNode::SampleInitPopulation(const std::vector& sketches, int out_size, std::vector* out_states) { std::uniform_real_distribution<> dis(0.0, 1.0); int continue_count = 0; @@ -1274,7 +1276,7 @@ void MetaTileRewritePolicyNode::SampleInitPopulation(const std::vector& m // TODO(...): Maybe try muti thread here while (static_cast(out_states->size()) < out_size && continue_count < out_size * 10) { - State tmp_s = meta_structures[rand_gen_() % meta_structures.size()]; + State tmp_s = sketches[rand_gen_() % sketches.size()]; InitPopulationFillTileSize(this, &tmp_s, &rand_gen_, &split_memo_); @@ -1305,11 +1307,11 @@ void MetaTileRewritePolicyNode::SampleInitPopulation(const std::vector& m out_states->push_back(std::move(tmp_s)); } - StdCout(verbose) << "Sample Initial Population\t\t#s: " + StdCout(verbose) << "Sample Initial Population\t#s: " << out_states->size() << std::endl; } -void MetaTileRewritePolicyNode::EvolutionarySearch( +void SketchSearchPolicyNode::EvolutionarySearch( const std::vector& init_population, int num_best_states, std::vector* best_states) { auto tic_begin = std::chrono::high_resolution_clock::now(); @@ -1473,10 +1475,10 @@ class RuleCustomSketch : public SketchGenerationRule { RuleCustomSketch(PackedFunc meet_condition_func, PackedFunc apply_func) : meet_condition_func_(meet_condition_func), apply_func_(apply_func) {} - inline ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy, + inline ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { auto ret = meet_condition_func_( - tvm::runtime::GetRef(policy), state, stage_id); + tvm::runtime::GetRef(policy), state, stage_id); if (ret.type_code() == 0) { return ConditionEnum(static_cast(ret)); } else { @@ -1485,12 +1487,12 @@ class RuleCustomSketch : public SketchGenerationRule { } inline std::vector > Apply( - const MetaTileRewritePolicyNode* policy, + const SketchSearchPolicyNode* policy, const State& state, int stage_id) final { std::vector > ret; Array> apply_ret = apply_func_( - tvm::runtime::GetRef(policy), state, stage_id); + tvm::runtime::GetRef(policy), state, stage_id); for (const auto& item : apply_ret) { CHECK_EQ(item.size(), 2); @@ -1506,32 +1508,32 @@ class RuleCustomSketch : public SketchGenerationRule { PackedFunc apply_func_; }; -SearchCallback 
PreAddCustomRuleNode::make(PackedFunc meet_condition_func, +SearchCallback PreloadCustomSketchRuleNode::make(PackedFunc meet_condition_func, PackedFunc apply_func) { - auto node = make_object(); + auto node = make_object(); node->meet_condition_func = meet_condition_func; node->apply_func = apply_func; return SearchCallback(node); } -void PreAddCustomRuleNode::callback(SearchPolicyNode* policy) { - CHECK(policy->IsInstance()); - auto meta_policy = dynamic_cast(policy); - meta_policy->sketch_rules.emplace_back( +void PreloadCustomSketchRuleNode::callback(SearchPolicyNode* policy) { + CHECK(policy->IsInstance()); + auto sketch_policy = dynamic_cast(policy); + sketch_policy->sketch_rules.emplace_back( new RuleCustomSketch(meet_condition_func, apply_func)); StdCout(policy->verbose) << "Custom sketch rule added." << std::endl; } -TVM_REGISTER_GLOBAL("ansor.MetaTileRewritePolicy") +TVM_REGISTER_GLOBAL("ansor.SketchSearchPolicy") .set_body_typed([](CostModel program_cost_model, Map params, int seed){ - return MetaTileRewritePolicyNode::make(program_cost_model, params, seed); + return SketchSearchPolicyNode::make(program_cost_model, params, seed); }); -TVM_REGISTER_GLOBAL("ansor.PreAddCustomRule") +TVM_REGISTER_GLOBAL("ansor.PreloadCustomSketchRule") .set_body_typed([](PackedFunc meet_condition_func, PackedFunc apply_func) { - return PreAddCustomRuleNode::make(meet_condition_func, apply_func); + return PreloadCustomSketchRuleNode::make(meet_condition_func, apply_func); }); } // namespace ansor diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.h b/src/ansor/search_policy/sketch_search_policy.h similarity index 66% rename from src/ansor/search_policy/meta_tile_rewrite_policy.h rename to src/ansor/search_policy/sketch_search_policy.h index 6930a71038a3..60920c5c1fdd 100644 --- a/src/ansor/search_policy/meta_tile_rewrite_policy.h +++ b/src/ansor/search_policy/sketch_search_policy.h @@ -18,12 +18,14 @@ */ /*! - * \file ansor/search_policy/meta_tile_rewrite_policy.h - * \brief The search policy that searches by program sampling and evolutionary search + * \file ansor/search_policy/sketch_search_policy.h + * \brief The search policy that searches in a hierarchical search space defined by sketches. + * The policy randomly samples programs from the space defined by sketches + * and use evolutionary search to fine-tune them. */ -#ifndef TVM_ANSOR_SEARCH_POLICY_META_TILE_REWRITE_POLICY_H_ -#define TVM_ANSOR_SEARCH_POLICY_META_TILE_REWRITE_POLICY_H_ +#ifndef TVM_ANSOR_SEARCH_POLICY_SKETCH_SEARCH_POLICY_H_ +#define TVM_ANSOR_SEARCH_POLICY_SKETCH_SEARCH_POLICY_H_ #include #include @@ -40,12 +42,17 @@ namespace ansor { class SketchGenerationRule; -/*! Multi stage search policy */ -class MetaTileRewritePolicyNode: public SearchPolicyNode { +/*! + * \brief The search policy that searches in a hierarchical search space defined by sketches. + * The policy randomly samples programs from the space defined by sketches + * and use evolutionary search to fine-tune them. + */ +class SketchSearchPolicyNode: public SearchPolicyNode { public: + /*! \brief The cost model for complete programs */ CostModel program_cost_model; - /* this->params is used to store the following arguments + /*! \brief The parameters for search. 
It stores the following parameters: * int evolutionary_search_population // The population size for evolutionary search * int evolutionary_search_mutation_prob // The probability of mutation for evolutionary search * int evolutionary_search_num_iters; // The number of iterations for evolutionary search * @@ -56,30 +63,33 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode { * str gpu_multi_level_tiling_structure // The structure of multi-level tiling for GPU */ Map params; + + /*! \brief The rules to generate sketches */ std::vector sketch_rules; static SearchPolicy make(CostModel program_cost_model, Map params, int seed); - // Search and make n_trails measurements - // Return the best state + /*! \brief Search and make n_trials measurements. + * \returns the best state */ State Search(SearchTask task, int n_trials, int early_stopping, int num_measure_per_iter, int verbose, ProgramMeasurer measurer, Array pre_search_callbacks) final; - // Continue search. This is used by JointTuner + /*! \brief Continue search for one round. This is used by JointTuner + * \returns the measurement pairs */ std::pair, Array > ContinueSearchOneRound( SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) final; - static constexpr const char *_type_key = "ansor.MetaTileRewritePolicy"; + static constexpr const char *_type_key = "ansor.SketchSearchPolicy"; static const std::vector auto_unroll_configs; - TVM_DECLARE_FINAL_OBJECT_INFO(MetaTileRewritePolicyNode, SearchPolicyNode); + TVM_DECLARE_FINAL_OBJECT_INFO(SketchSearchPolicyNode, SearchPolicyNode); protected: - // Pick states from best states and random states with eps-greedy policy + /*! \brief Pick states from best states and random states with eps-greedy policy */ void PickStatesWithEpsGreedy(std::vector* inputs, const std::vector& best_states, const std::vector& random_states, int remaining_n_trials); @@ -89,11 +99,11 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode { void SearchOneRound(std::vector* best_states, int num_random_states, std::vector* random_states); - // Synthesize meta tiling structure without tile size - void GenerateMetaSketch(std::vector* out_states); + // Generate sketches without tile size + void GenerateSketch(std::vector* out_states); // Sample init population - void SampleInitPopulation(const std::vector& meta_structures, + void SampleInitPopulation(const std::vector& sketches, int out_size, std::vector* out_states); // Perform evolutionary search @@ -104,9 +114,10 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode { std::mt19937 rand_gen_; // Random generator int num_measure_per_iter_; // The number of states to measure per iteration }; -TVM_DEFINE_MUTABLE_OBJECT_REF(MetaTileRewritePolicy, MetaTileRewritePolicyNode); +TVM_DEFINE_MUTABLE_OBJECT_REF(SketchSearchPolicy, SketchSearchPolicyNode); -class PreAddCustomRuleNode : public SearchCallbackNode { +/*! \brief Pre-search callback function to load custom rules for sketch generation */ +class PreloadCustomSketchRuleNode : public SearchCallbackNode { public: // TODO(jcf94): Use tvm::runtime::TypedPackedFunc?
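// meet_condition_func(policy, state, stage_id) decides whether the custom
// rule fires for a stage, while apply_func(policy, state, stage_id) returns
// the derived (state, next stage_id) pairs, mirroring SketchGenerationRule.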
PackedFunc meet_condition_func; @@ -117,11 +128,11 @@ class PreAddCustomRuleNode : public SearchCallbackNode { void callback(SearchPolicyNode* policy) final; - static constexpr const char *_type_key = "ansor.PreAddCustomRule"; - TVM_DECLARE_FINAL_OBJECT_INFO(PreAddCustomRuleNode, SearchCallbackNode); + static constexpr const char *_type_key = "ansor.PreloadCustomSketchRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(PreloadCustomSketchRuleNode, SearchCallbackNode); }; } // namespace ansor } // namespace tvm -#endif // TVM_ANSOR_SEARCH_POLICY_META_TILE_REWRITE_POLICY_H_ +#endif // TVM_ANSOR_SEARCH_POLICY_SKETCH_SEARCH_POLICY_H_ diff --git a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py index 083bd2721cb6..485679d6aa4e 100644 --- a/tests/python/unittest/test_ansor_common.py +++ b/tests/python/unittest/test_ansor_common.py @@ -21,7 +21,7 @@ import topi -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def matmul_ansor_test(N, M, K): A = te.placeholder((N, K), name='A') B = te.placeholder((K, M), name='B') diff --git a/tests/python/unittest/test_ansor_relay_integration.py b/tests/python/unittest/test_ansor_relay_integration.py index f3f424ab321b..1ad507e2f371 100644 --- a/tests/python/unittest/test_ansor_relay_integration.py +++ b/tests/python/unittest/test_ansor_relay_integration.py @@ -84,7 +84,6 @@ def dense_graph(N, dtype="float32"): def test_tune_dqn(): mod, params = dqn.get_workload(1, image_shape=(84, 84, 4), layout='NHWC') target = tvm.target.create('llvm') - ctx = tvm.context("llvm") wkl_keys, wkl_weights = ansor.extract_from_program(mod, params, target) @@ -100,7 +99,7 @@ def test_tune_dqn(): with tempfile.NamedTemporaryFile() as fp: tuner.tune(ansor.TuneOption(n_trials=len(tasks), runner=measure_ctx.runner, measure_callbacks=[ansor.LogToFile('tmp.json')]), - search_policy='meta-rewrite.random') + search_policy='sketch.random') with ansor.apply_history_best('tmp.json'): ansor.prepare_layout_rewrite(mod, params, target) with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index 9b1716175b5a..deff561a4547 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -42,8 +42,7 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local' with tempfile.NamedTemporaryFile() as fp: log_file = fp.name - search_policy = ansor.MetaTileRewritePolicy(cost_model, params=params, - seed=seed) + search_policy = ansor.SketchSearchPolicy(cost_model, params=params, seed=seed) tune_option = ansor.TuneOption(n_trials=n_trials, runner=runner, measure_callbacks=[ansor.LogToFile(log_file)], pre_search_callbacks=pre_search_callbacks) @@ -74,8 +73,8 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local' def test_search_basic(): - # Ansor search process with local runner has some modification on thread - # binding, wrap this to a subprocess to eliminate the impacts to other tests + # wrap the search in a new thread to avoid the conflict + # between python's multiprocessing and tvm's thread pool t = threading.Thread(target=search_common, kwargs={'seed': 944563397}) t.start() t.join() @@ -152,12 +151,12 @@ def apply_func2(meta_policy, state, stage_id): measure_ctx = ansor.LocalRPCMeasureContext() search_common(seed=887823438, runner=measure_ctx.runner, - 
pre_search_callbacks=[ansor.PreAddCustomRule(meet_condition_func, - apply_func1)], + pre_search_callbacks=[ansor.PreloadCustomSketchRule( + meet_condition_func, apply_func1)], params={'disable_change_compute_location': 1}) search_common(seed=887823438, runner=measure_ctx.runner, - pre_search_callbacks=[ansor.PreAddCustomRule(meet_condition_func, - apply_func2)], + pre_search_callbacks=[ansor.PreloadCustomSketchRule( + meet_condition_func, apply_func2)], params={'disable_change_compute_location': 1}) diff --git a/tutorials/ansor/tune_conv2d_cuda.py b/tutorials/ansor/tune_conv2d_cuda.py index 437323d79791..03f1b24a768e 100644 --- a/tutorials/ansor/tune_conv2d_cuda.py +++ b/tutorials/ansor/tune_conv2d_cuda.py @@ -80,7 +80,7 @@ # recommended. # Use an extra function decorator to register this workload -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def conv2d_nchw(N, H, W, CO, CI, KH, KW, stride, padding): data = te.placeholder((N, CI, H, W), name='data') kernel = te.placeholder((CO, CI, KH, KW), name='kernel') @@ -111,7 +111,7 @@ def conv2d_nchw(N, H, W, CO, CI, KH, KW, stride, padding): seed = 0 random.seed(seed) cost_model = ansor.XGBModel(seed=seed) -search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed) +search_policy = ansor.SketchSearchPolicy(cost_model, seed=seed) ######################################################################### # The :code:`ansor.LocalRPCMeasureContext` is used to create an RPC runner environment. diff --git a/tutorials/ansor/tune_simple_subgraph.py b/tutorials/ansor/tune_simple_subgraph.py index 08d5628ad8a2..00bef82cf855 100644 --- a/tutorials/ansor/tune_simple_subgraph.py +++ b/tutorials/ansor/tune_simple_subgraph.py @@ -142,7 +142,7 @@ def matmul_add(N, L, M, dtype): ################################################################ # Next, we choose a random model and create a default search policy: -# :code:`ansor.MetaTileRewritePolicy`. +# :code:`ansor.SketchSearchPolicy`. # # We only make 5 trials in this tutorial for demonstration. In practice, # you can do more trials according to your time budget.
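################################################################
# As a compact reference for the renamed Python entry points, the minimal
# sketch below chains them together. It assumes the :code:`matmul_add`
# workload defined above; the shape tuple and the log file name are
# illustrative placeholders rather than values prescribed by this tutorial.

from tvm import ansor

# Build a workload key for matmul_add with assumed (N, L, M, dtype) arguments,
# then recover its ComputeDAG for inspection.
wkl_key = ansor.make_workload_key_func(matmul_add, (128, 128, 128, 'float32'))
dag = ansor.workload_key_to_dag(wkl_key)

# The renamed policy (formerly MetaTileRewritePolicy) paired with a cost model.
cost_model = ansor.RandomModel()
search_policy = ansor.SketchSearchPolicy(cost_model, seed=0)
tune_option = ansor.TuneOption(n_trials=5,
                               measure_callbacks=[ansor.LogToFile('matmul_add.json')])
# search_policy and tune_option then feed the tuning call, as in the rest of
# this tutorial.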
@@ -157,7 +157,7 @@ def matmul_add(N, L, M, dtype): seed = 0 random.seed(seed) cost_model = ansor.RandomModel() -search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed) +search_policy = ansor.SketchSearchPolicy(cost_model, seed=seed) tune_option = ansor.TuneOption(n_trials=5, measure_callbacks=[ansor.LogToFile(log_file)], From 593a2c7f43ed157ca1c1d0be04955e7b3ad9efcd Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 20 Jun 2020 09:15:44 -0700 Subject: [PATCH 34/45] rebase --- src/ansor/compute_dag.cc | 24 +++++------ src/ansor/feature.cc | 2 +- src/ansor/loop_state.cc | 6 +-- src/ansor/transform_step.cc | 54 ++++++++++++------------ src/relay/op/tensor/transform.cc | 56 +++++++++++++++++++++++++ src/tir/transforms/unroll_loop.cc | 2 +- topi/include/topi/transform.h | 69 +++++++++++++++++++++++++++++++ 7 files changed, 169 insertions(+), 44 deletions(-) diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 95e744a0e777..7b4857b34d76 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -863,9 +863,9 @@ void ComputeDAG::RewriteLayout( te::Operation new_placeholder_op; if (rewrite_placeholder) { new_placeholder_op = - te::PlaceholderOpNode::make(placeholder_op->name, - new_shape, - placeholder_op.as()->dtype); + te::PlaceholderOp(placeholder_op->name, + new_shape, + placeholder_op.as()->dtype); } else { new_placeholder_op = placeholder_op; } @@ -890,7 +890,7 @@ void ComputeDAG::RewriteLayout( } old_compute_op = op; CHECK(!new_compute_op.defined()); - new_compute_op = te::ComputeOpNode::make( + new_compute_op = te::ComputeOp( pop->name, pop->tag, pop->attrs, pop->axis, new_body); } } @@ -1028,8 +1028,8 @@ std::string ComputeDAG::PrintStepsAsPython(const std::vector& transform_st ss << ", "; } } - ss << " = " << "tuple(" << stage->op->func_name() << ".op.axis)" - << " + " << "tuple(" << stage->op->func_name() << ".op.reduce_axis)\n"; + ss << " = " << "tuple(" << stage->op->name << ".op.axis)" + << " + " << "tuple(" << stage->op->name << ".op.reduce_axis)\n"; } } @@ -1231,10 +1231,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) for (const auto& op : node->ops) { if (op->IsInstance()) { - ss << op->func_name() << " = PLACEHOLDER " << op.output(0)->shape << "\n"; + ss << op->name << " = PLACEHOLDER " << op.output(0)->shape << "\n"; } else if (auto pop = op.as()) { for (size_t k = 0; k < pop->body.size(); ++k) { - ss << op->func_name() << "("; + ss << op->name << "("; for (size_t i = 0; i < pop->axis.size(); i++) { ss << pop->axis[i]->var->name_hint; if (i != pop->axis.size() - 1) { @@ -1288,14 +1288,14 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "Read from:\t"; for (const auto& pair : node->read_from.at(op)) { for (const auto& index : pair.second) { - p->stream << pair.first->func_name() << Array(index) << ", "; + p->stream << pair.first->name << Array(index) << ", "; } } p->stream << "\n"; p->stream << "Read by:\t"; for (const auto& pair : node->read_by.at(op)) { for (const auto& index : pair.second) { - p->stream << pair.first->func_name() << Array(index) << ", "; + p->stream << pair.first->name << Array(index) << ", "; } } p->stream << "\n"; @@ -1310,8 +1310,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) if (i == j) { continue; } if (ana.ElementWiseMatch(node->ops_topo_order[i], node->ops_topo_order[j])) { - p->stream << node->ops_topo_order[i]->func_name() << " -> " - << node->ops_topo_order[j]->func_name() << "\n"; + p->stream << node->ops_topo_order[i]->name << " -> " + << node->ops_topo_order[j]->name << "\n"; } } } diff 
--git a/src/ansor/feature.cc b/src/ansor/feature.cc index 3c6976a0e25a..3b5849e22262 100644 --- a/src/ansor/feature.cc +++ b/src/ansor/feature.cc @@ -568,7 +568,7 @@ class PerStmtFeatureExtractor : public StmtExprVisitor { is_gpu = true; // make a fake for node for blockIdx.x or threadIdx.x - Stmt fake_for_node = ForNode::make(var, 0, extent, ForType::Parallel, + Stmt fake_for_node = For(var, 0, extent, ForType::Parallel, DeviceAPI::None, node->body); outer_loop_prod *= extent; diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index 239f4e6988ac..23e005503873 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -832,14 +832,14 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, for (size_t j = 0; j < base_indent; ++j) { *os << " "; } - *os << stage->op->func_name() + *os << stage->op->name << " auto_unroll: " << stage->attrs.auto_unroll_max_step << "\n"; } if (stage->attrs.storage_offset != 0) { for (size_t j = 0; j < base_indent; ++j) { *os << " "; } - *os << stage->op->func_name() + *os << stage->op->name << " storage_offset: " << stage->attrs.storage_offset << "\n"; } @@ -915,7 +915,7 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, for (size_t j = 0; j < base_indent + indent; ++j) { *os << " "; } - *os << stage->op->func_name() << " = ...\n"; + *os << stage->op->name << " = ...\n"; } void PrintState(std::ostream* os, const StateNode* node, diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc index b0e67a481ae3..857f3e570de0 100644 --- a/src/ansor/transform_step.cc +++ b/src/ansor/transform_step.cc @@ -63,7 +63,7 @@ std::string ReorderStepNode::PrintAsPythonAPI(std::vector *stages, const te::Stage& stage = (*stages)[stage_id]; std::stringstream ss; - ss << "s[" << CleanName(stage->op->func_name()) << "].reorder("; + ss << "s[" << CleanName(stage->op->name) << "].reorder("; for (size_t i = 0; i < after_ids.size(); ++i) { ss << CleanName((*stage_to_axes)[stage][after_ids[i]]->var->name_hint); if (i != after_ids.size() - 1) { @@ -126,7 +126,7 @@ std::string PrintSplitAsPythonAPI(std::vector *stages, bool inner_to_outer) { te::Stage& stage = (*stages)[stage_id]; auto to_split = (*stage_to_axes)[stage][iter_id]; - const auto& func_name = CleanName(stage->op->func_name()); + const auto& func_name = CleanName(stage->op->name); const auto& outs = ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); @@ -330,7 +330,7 @@ std::string FuseStepNode::PrintAsPythonAPI(std::vector *stages, const auto& fused = ApplyToSchedule(stages, stage_to_axes); ss << CleanName(fused->var->name_hint) << " = s[" - << CleanName(stage->op->func_name()) << "].fuse(" + << CleanName(stage->op->name) << "].fuse(" << to_fuse.str() << ")\n"; return ss.str(); @@ -385,7 +385,7 @@ std::string AnnotationStepNode::PrintAsPythonAPI(std::vector *stages, ss << "thread_x = tvm.thread_axis(\"threadIdx.x\")\n"; } - ss << "s[" << CleanName(stage->op->func_name()) << "]."; + ss << "s[" << CleanName(stage->op->name) << "]."; switch (annotation) { case kUnroll: ss << "unroll("; break; case kVectorize: ss << "vectorize("; break; @@ -417,7 +417,7 @@ std::string AnnotationStepNode::PrintAsPythonAPI(std::vector *stages, ss << ")\n"; if (bind_reduce_iter) { - ss << "s[" << CleanName(stage->op->func_name()) << "]" + ss << "s[" << CleanName(stage->op->name) << "]" << ".set_store_predicate(thread_x.var.equal(0))\n"; } @@ -450,8 +450,8 @@ std::string ComputeAtStepNode::PrintAsPythonAPI(std::vector *stages, const auto& 
stage = (*stages)[stage_id]; const auto& target_stage = (*stages)[target_stage_id]; - ss << "s[" << CleanName(stage->op->func_name()) << "].compute_at(s[" - << CleanName(target_stage->op->func_name()) << "], " + ss << "s[" << CleanName(stage->op->name) << "].compute_at(s[" + << CleanName(target_stage->op->name) << "], " << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint); ss << ")\n"; @@ -478,7 +478,7 @@ std::string ComputeRootStepNode::PrintAsPythonAPI(std::vector *stages std::stringstream ss; const auto& stage = (*stages)[stage_id]; - ss << "s[" << CleanName(stage->op->func_name()) << "].compute_root()\n"; + ss << "s[" << CleanName(stage->op->name) << "].compute_root()\n"; ApplyToSchedule(stages, stage_to_axes); return ss.str(); @@ -504,7 +504,7 @@ std::string ComputeInlineStepNode::PrintAsPythonAPI( std::stringstream ss; const auto& stage = (*stages)[stage_id]; - ss << "s[" << CleanName(stage->op->func_name()) << "].compute_inline()\n"; + ss << "s[" << CleanName(stage->op->name) << "].compute_inline()\n"; ApplyToSchedule(stages, stage_to_axes); return ss.str(); @@ -551,12 +551,12 @@ std::string CacheReadStepNode::PrintAsPythonAPI(std::vector *stages, auto out = ApplyToSchedule(stages, stage_to_axes, schedule); - ss << CleanName(out->op->func_name()) << " = " - << "s.cache_read(" << CleanName(stage->op->func_name()) << ", \"" + ss << CleanName(out->op->name) << " = " + << "s.cache_read(" << CleanName(stage->op->name) << ", \"" << scope_name << "\", [" - << CleanName(reader_stages[0]->op->func_name()); + << CleanName(reader_stages[0]->op->name); for (size_t i = 1; i < reader_stage_ids.size(); ++i) { - ss << ", " << CleanName(reader_stages[i]->op->func_name()); + ss << ", " << CleanName(reader_stages[i]->op->name); } ss << "])\n"; @@ -567,7 +567,7 @@ std::string CacheReadStepNode::PrintAsPythonAPI(std::vector *stages, ss << ", "; } } - ss << " = " << "tuple(" << CleanName(out->op->func_name()) + ss << " = " << "tuple(" << CleanName(out->op->name) << ".op.axis)\n"; return ss.str(); @@ -615,7 +615,7 @@ std::string CacheWriteStepNode::PrintAsPythonAPI(std::vector *stages, auto outs = ApplyToSchedule(stages, stage_to_axes, schedule); for (size_t i = 0; i < outs.size(); ++i) { - ss << CleanName(outs[i]->op->func_name()) << ", "; + ss << CleanName(outs[i]->op->name) << ", "; } ss << "= " << "s.cache_write([" << CleanName(stage->op.output(0)->op->name); @@ -632,9 +632,9 @@ std::string CacheWriteStepNode::PrintAsPythonAPI(std::vector *stages, ss << ", "; } } - ss << " = " << "tuple(" << CleanName(out->op->func_name()) + ss << " = " << "tuple(" << CleanName(out->op->name) << ".op.axis)" - << " + " << "tuple(" << CleanName(out->op->func_name()) + << " + " << "tuple(" << CleanName(out->op->name) << ".op.reduce_axis)\n"; } @@ -675,14 +675,14 @@ std::string PragmaStepNode::PrintAsPythonAPI(std::vector *stages, if (StrStartsWith(pragma_type, "auto_unroll_max_step")) { size_t pos = pragma_type.find('$'); int value = atoi(pragma_type.c_str() + pos + 1); - ss << "s[" << CleanName(stage->op->func_name()) << "].pragma(" + ss << "s[" << CleanName(stage->op->name) << "].pragma(" << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", \"auto_unroll_max_step\", " << value << ")\n"; - ss << "s[" << CleanName(stage->op->func_name()) << "].pragma(" + ss << "s[" << CleanName(stage->op->name) << "].pragma(" << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", \"unroll_explicit\", True)\n"; } else { - ss << "s[" << CleanName(stage->op->func_name()) << "].pragma(" + 
ss << "s[" << CleanName(stage->op->name) << "].pragma(" << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", \"" << pragma_type << "\")\n"; } @@ -731,7 +731,7 @@ std::string RfactorStepNode::PrintAsPythonAPI(std::vector *stages, const auto& outs = ApplyToSchedule(stages, stage_to_axes, schedule); for (size_t i = 0; i < outs.size(); ++i) { - ss << CleanName(outs[i]->op->func_name()); + ss << CleanName(outs[i]->op->name); if (i != outs.size() - 1) { ss << ", "; } @@ -749,9 +749,9 @@ std::string RfactorStepNode::PrintAsPythonAPI(std::vector *stages, ss << ", "; } } - ss << " = " << "tuple(" << CleanName(out->op->func_name()) + ss << " = " << "tuple(" << CleanName(out->op->name) << ".op.axis)" - << " + " << "tuple(" << CleanName(out->op->func_name()) + << " + " << "tuple(" << CleanName(out->op->name) << ".op.reduce_axis)\n"; } @@ -763,9 +763,9 @@ std::string RfactorStepNode::PrintAsPythonAPI(std::vector *stages, ss << ", "; } } - ss << " = " << "tuple(s[" << CleanName(output->op->func_name()) + ss << " = " << "tuple(s[" << CleanName(output->op->name) << "].op.axis)" - << " + " << "tuple(s[" << CleanName(output->op->func_name()) + << " + " << "tuple(s[" << CleanName(output->op->name) << "].op.reduce_axis)\n"; return ss.str(); @@ -794,7 +794,7 @@ std::string StorageAlignStepNode::PrintAsPythonAPI( te::Schedule *schedule, const std::vector& transform_steps) const { std::stringstream ss; const auto& stage = (*stages)[stage_id]; - ss << "s[" << CleanName(stage->op->func_name()) << "].storage_align(" + ss << "s[" << CleanName(stage->op->name) << "].storage_align(" << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", " << factor << ", " << offset << ")\n"; @@ -829,7 +829,7 @@ std::string TensorizeStepNode::PrintAsPythonAPI( te::Schedule *schedule, const std::vector& transform_steps) const { std::stringstream ss; const auto& stage = (*stages)[stage_id]; - ss << "s[" << CleanName(stage->op->func_name()) << "].tensorize(" + ss << "s[" << CleanName(stage->op->name) << "].tensorize(" << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", " << ti_func_name << "())\n"; diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index ee5e291e3d53..18ace14a0b75 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2455,6 +2455,62 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w] .set_support_level(5) .set_attr("FTVMCompute", LayoutTransformCompute); +// relay.kernel_layout_transform +TVM_REGISTER_NODE_TYPE(KernelLayoutTransformAttrs); + +Array KernelLayoutTransformCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type) { + //const Target& target) { + const auto* param = attrs.as(); + CHECK(param != nullptr); + return Array{ + topi::kernel_layout_transform(inputs[0], param->src_layout, param->dst_layout) + }; +} + +bool KernelLayoutTransformRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + + const auto* data = types[0].as(); + CHECK(data != nullptr); + const KernelLayoutTransformAttrs* params = attrs.as(); + + Array dst_shape; + std::vector dst_axes; + + topi::parse_kernel_layout(params->dst_layout, &dst_shape, &dst_axes); + + reporter->Assign(types[1], TensorType(dst_shape, data->dtype)); + return true; +} + +Expr MakeKernelLayoutTransform(Expr data, + String src_layout, + String dst_layout) { + auto attrs = make_object(); + attrs->src_layout = std::move(src_layout); + attrs->dst_layout = std::move(dst_layout); + static 
const Op& op = Op::Get("kernel_layout_transform"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.kernel_layout_transform") +.set_body_typed(MakeKernelLayoutTransform); + +RELAY_REGISTER_OP("kernel_layout_transform") + .describe(R"code(Transform the input kernel layout. +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .add_type_rel("kernel_layout_transform", KernelLayoutTransformRel) + .set_support_level(5) + .set_attr("FTVMCompute", KernelLayoutTransformCompute); + + /* relay._contrib_reverse_reshape */ Expr MakeReverseReshape(Expr data, Array newshape) { auto attrs = make_object(); diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index 1c84304fb0e7..3876d67b7b11 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -173,7 +173,7 @@ class LoopUnroller : public StmtExprMutator { if (explicit_unroll_max_extent_ > 0 && value > explicit_unroll_max_extent_ && explicit_unroll_) { // Do not unroll too long loops ForType for_type = op->for_type == ForType::Unrolled ? ForType::Serial : op->for_type; - return ForNode::make(op->loop_var, op->min, op->extent, for_type, op->device_api, op->body); + return For(op->loop_var, op->min, op->extent, for_type, op->device_api, op->body); } Stmt body = op->body; Map vmap; diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index e0e455667889..7dd782f5b622 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -1295,6 +1295,75 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout, name, tag); } +/*! + * \brief utility function for kernel_layout_transform + */ +inline void parse_kernel_layout(const String& layout, + Array* shape, + std::vector* axes) { + int32_t factor = 0; + std::string axis = ""; + for (char c : std::string(layout)) { + if (c >= 'A' && c <= 'z') { + axis += c; + if (factor != 0) { + shape->push_back(factor); + factor = 0; + } + } else if (c >= '0' && c <= '9') { + factor = factor * 10 + c - '0'; + if (!axis.empty()) { + axes->push_back(axis); + axis = ""; + } + } else { + LOG(FATAL) << "Invalid layout " << layout; + } + } + if (!axis.empty()) { + axes->push_back(axis); + } +} + +/*! + * \brief Transform the kernel layout according to \p src_layout and \p dst_layout + * \param src the source input. + * \param src_layout the source layout. + * \param dst_layout the destination layout. + * \param name output tensor name. + * \param tag output tensor tag. 
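+ *
+ * For example, following the convention implemented by parse_kernel_layout
+ * above (every axis name is prefixed by its integer length), parsing the
+ * layout string "4o16i" yields shape = [4, 16] and axes = {"o", "i"}.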
+ * \return A tensor with shape in \p dst_layout
+ */
+inline Tensor kernel_layout_transform(const Tensor& src,
+                                      const String& src_layout,
+                                      const String& dst_layout,
+                                      const String name = "T_kernel_layout_trans",
+                                      const String tag = kInjective) {
+  Array src_shape;
+  std::vector src_axes;
+  Array dst_shape;
+  std::vector dst_axes;
+
+  parse_kernel_layout(src_layout, &src_shape, &src_axes);
+  parse_kernel_layout(dst_layout, &dst_shape, &dst_axes);
+  return compute(
+    dst_shape, [&](const Array& dst_indices) {
+      Array dst_indices_expr(dst_indices.begin(), dst_indices.end());
+      Array src_indices;
+      for (const std::string& src_axis : src_axes) {
+        PrimExpr src_index = 0;
+        CHECK_EQ(dst_indices_expr.size(), dst_axes.size());
+        for (size_t i = 0; i < dst_axes.size(); ++i) {
+          if (dst_axes[i] == src_axis) {
+            src_index = src_index * dst_shape[i] + dst_indices_expr[i];
+          }
+        }
+        src_indices.push_back(src_index);
+      }
+      return src(src_indices);
+    }, name, tag);
+}
+
 /*!
  * \brief Get the shape of input tensor.
  * \param src the input tensor.

From 53bd591167959a5ae0d85ca27988a826e73c8dcc Mon Sep 17 00:00:00 2001
From: Chenfan
Date: Mon, 22 Jun 2020 15:22:23 +0800
Subject: [PATCH 35/45] Migrate all node::make to noderef's constructor (#37)

* Start to move xxxnode::make to noderef()
* Update
* Update
* Finish transform_step
* Finish compute dag & auto schedule
* Update
* Update
* Update
* Update
* Update
* Code refine
* Code refine
* Code refine
* Update
* Update
---
 src/ansor/auto_schedule.cc                 |  26 +-
 src/ansor/auto_schedule.h                  |  22 +-
 src/ansor/compute_dag.cc                   |  39 ++-
 src/ansor/compute_dag.h                    |  20 +-
 src/ansor/cost_model/cost_model.cc         |  33 ++-
 src/ansor/cost_model/cost_model.h          |  67 ++++-
 src/ansor/feature.cc                       |  18 +-
 src/ansor/loop_state.cc                    | 182 ++++++-------
 src/ansor/loop_state.h                     | 125 +++++----
 src/ansor/measure.cc                       |  92 ++++---
 src/ansor/measure.h                        | 133 ++++++---
 src/ansor/search_policy/search_policy.cc   |   8 +-
 src/ansor/search_policy/search_policy.h    |  17 +-
 .../search_policy/sketch_search_policy.cc  |  34 ++-
 .../search_policy/sketch_search_policy.h   |  43 ++-
 src/ansor/search_policy/utils.cc           |   4 +-
 src/ansor/search_task.cc                   |  35 ++-
 src/ansor/search_task.h                    |  39 ++-
 src/ansor/serialization.cc                 |  54 ++--
 src/ansor/serialization.h                  |  28 +-
 src/ansor/transform_step.cc                |  92 ++++---
 src/ansor/transform_step.h                 | 257 +++++++++++++-----
 tests/cpp/ansor_test.cc                    |   2 +-
 23 files changed, 839 insertions(+), 531 deletions(-)

diff --git a/src/ansor/auto_schedule.cc b/src/ansor/auto_schedule.cc
index 7ffc63a03917..05cb95c2c451 100644
--- a/src/ansor/auto_schedule.cc
+++ b/src/ansor/auto_schedule.cc
@@ -33,11 +33,10 @@ namespace ansor {

 TVM_REGISTER_NODE_TYPE(TuneOptionNode);

-TuneOption TuneOptionNode::make(int n_trials, int early_stopping,
-                                int num_measure_per_iter, int verbose,
-                                Builder builder, Runner runner,
-                                Array measure_callbacks,
-                                Array pre_search_callbacks) {
+TuneOption::TuneOption(int n_trials, int early_stopping,
+                       int num_measure_per_iter, int verbose, Builder builder,
+                       Runner runner, Array measure_callbacks,
+                       Array pre_search_callbacks) {
   auto node = make_object();
   node->n_trials = n_trials;
   node->early_stopping = early_stopping;
@@ -47,16 +46,16 @@ TuneOption TuneOptionNode::make(int n_trials, int early_stopping,
   node->runner = std::move(runner);
   node->measure_callbacks = std::move(measure_callbacks);
   node->pre_search_callbacks = std::move(pre_search_callbacks);
-  return TuneOption(node);
+  data_ = std::move(node);
 }

 std::pair > AutoSchedule(SearchTask task, SearchPolicy search_policy,
TuneOption tune_option) { // Search for the best schedule ProgramMeasurer measurer = - ProgramMeasurerNode::make(tune_option->builder, tune_option->runner, - tune_option->measure_callbacks, - tune_option->verbose); + ProgramMeasurer(tune_option->builder, tune_option->runner, + tune_option->measure_callbacks, + tune_option->verbose); State state = search_policy->Search( task, tune_option->n_trials, tune_option->early_stopping, @@ -70,8 +69,8 @@ std::pair > AutoSchedule( std::string workload_key, Target target, Target target_host, SearchPolicy search_policy, HardwareParams hardware_params, TuneOption tune_option) { - ComputeDAG dag = ComputeDAGNode::make_by_workload_key(workload_key); - SearchTask task = SearchTaskNode::make( + ComputeDAG dag = ComputeDAG(workload_key); + SearchTask task = SearchTask( std::move(dag), std::move(workload_key), std::move(target), std::move(target_host), std::move(hardware_params)); return AutoSchedule(std::move(task), std::move(search_policy), @@ -83,9 +82,8 @@ TVM_REGISTER_GLOBAL("ansor.TuneOption") int num_measure_per_iter, int verbose, Builder builder, Runner runner, Array measure_callbacks, Array pre_search_callbacks) { - return TuneOptionNode::make(n_trials, early_stopping, - num_measure_per_iter, verbose, builder, - runner, measure_callbacks, pre_search_callbacks); + return TuneOption(n_trials, early_stopping, num_measure_per_iter, verbose, + builder, runner, measure_callbacks, pre_search_callbacks); }); TVM_REGISTER_GLOBAL("ansor.AutoScheduleBySearchTask") diff --git a/src/ansor/auto_schedule.h b/src/ansor/auto_schedule.h index 4e70ac0b577a..f17c043cfadd 100644 --- a/src/ansor/auto_schedule.h +++ b/src/ansor/auto_schedule.h @@ -34,7 +34,6 @@ namespace tvm { namespace ansor { /*! \brief Tuning and measurement options */ -class TuneOption; class TuneOptionNode : public Object { public: int n_trials; // Number of total measurement trials @@ -61,15 +60,24 @@ class TuneOptionNode : public Object { v->Visit("pre_search_callbacks", &pre_search_callbacks); } - static TuneOption make(int n_trials, int early_stopping, - int num_measure_per_iter, int verbose, Builder builder, - Runner runner, Array measure_callbacks, - Array pre_search_callbacks); - static constexpr const char* _type_key = "ansor.TuneOption"; TVM_DECLARE_FINAL_OBJECT_INFO(TuneOptionNode, Object); }; -TVM_DEFINE_COW_OBJECT_REF(TuneOption, ObjectRef, TuneOptionNode); + +/*! + * \brief Managed reference to TuneOptionNode. + * \sa TuneOptionNode + */ +class TuneOption : public ObjectRef { + public: + TuneOption(int n_trials, int early_stopping, int num_measure_per_iter, + int verbose, Builder builder, Runner runner, + Array measure_callbacks, + Array pre_search_callbacks); + + TVM_DEFINE_OBJECT_REF_METHODS(TuneOption, ObjectRef, TuneOptionNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(TuneOptionNode); +}; /*! 
\brief Auto schedule for a compute declaration */
std::pair > AutoSchedule(
diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc
index 7b4857b34d76..13f64b2bdc89 100644
--- a/src/ansor/compute_dag.cc
+++ b/src/ansor/compute_dag.cc
@@ -241,7 +241,7 @@ static bool HasExpensiveOp(const PrimExpr& expr) {
   return found;
 }

-AccessAnalyzer AccessAnalyzerNode::make(const Array& tensors) {
+AccessAnalyzer::AccessAnalyzer(const Array& tensors) {
   auto node = make_object();
   OperationMap has_branch;

@@ -290,8 +290,8 @@ AccessAnalyzer AccessAnalyzerNode::make(const Array& tensors) {
       for (const auto& pair : node->read_from[op]) {
         const std::vector >& access = pair.second;
         for (const auto& index : access) {
-          if (!IsInjective(op, index, &axis_missing, &axis_duplicated,
-                           &same_order)) {
+          if (!ansor::IsInjective(op, index, &axis_missing, &axis_duplicated,
+                                  &same_order)) {
             is_injective = false;
             is_strict_inlineable = false;
             break;
@@ -356,7 +356,7 @@ AccessAnalyzer AccessAnalyzerNode::make(const Array& tensors) {
     }
   }

-  return AccessAnalyzer(node);
+  data_ = std::move(node);
 }

 bool AccessAnalyzer::NeedsMultiLevelTiling(const te::Operation &op) const {
@@ -554,7 +554,6 @@ class FlopEstimator: public ExprFunctor {
     return ret;
   }

-
   double VisitExprDefault_(const Object* op) final {
     fail = true;
     return -1.0;
@@ -567,20 +566,20 @@ State ComputeDAG::GetInitState() const {
   return Downcast(operator->()->init_state);
 }

-ComputeDAG ComputeDAGNode::make(Array tensors) {
+ComputeDAG::ComputeDAG(Array tensors) {
   auto node = make_object();
   FlopEstimator estimator;

   node->tensors = std::move(tensors);
-  node->access_analyzer = AccessAnalyzerNode::make(node->tensors);
+  node->access_analyzer = AccessAnalyzer(node->tensors);
   node->ops = Array(node->access_analyzer->ops_topo_order);
   node->flop_ct = estimator.EstimateFlop(node->ops);
-  node->init_state = StateNode::make(node->ops);
+  node->init_state = State(node->ops);

-  return ComputeDAG(node);
+  data_ = std::move(node);
 }

-ComputeDAG ComputeDAGNode::make_by_workload_key(const std::string& workload_key) {
+ComputeDAG::ComputeDAG(const std::string& workload_key) {
   Array tens;
   // Call python function to decode the workload_key and get the I/O tensors
   if (const auto* f = runtime::Registry::Get("ansor.workload_key_to_tensors")) {
@@ -588,7 +587,7 @@ ComputeDAG ComputeDAGNode::make_by_workload_key(const std::string& workload_key)
   } else {
     LOG(FATAL) << "ansor.workload_key_to_tensors is not registered";
   }
-  return ComputeDAGNode::make(std::move(tens));
+  // assign the constructed reference instead of discarding the temporary
+  *this = ComputeDAG(std::move(tens));
 }

 std::string BaseName(const std::string& str) {
@@ -938,7 +937,7 @@ void ComputeDAG::RewriteLayout(
     }
   }

-  pdag->init_state = StateNode::make(pdag->ops);
+  pdag->init_state = State(pdag->ops);

   Array old_tensors = pdag->tensors;
   ArrayNode* ptensors = pdag->tensors.CopyOnWrite();
@@ -1105,7 +1104,7 @@ void ComputeDAG::ReplayAndGetDAG(const std::vector &transform_steps,
     }
   }

-  *task_dag = ComputeDAGNode::make(new_tensors);
+  *task_dag = ComputeDAG(new_tensors);
 }


@@ -1136,18 +1135,16 @@ void ComputeDAG::InferBoundCommon(StateNode* pstate) const {
       auto find_res = bounds.find(axis);
       if (find_res != bounds.end()) {
-        new_iters.push_back(IteratorNode::make(iter->name, (*find_res).second,
-                                               iter->iter_type,
-                                               iter->annotation,
-                                               &iter->ori_iters,
-                                               iter->attr));
+        new_iters.push_back(Iterator(iter->name, (*find_res).second,
+                                     iter->iter_type, iter->annotation,
+                                     &iter->ori_iters, iter->attr));
       } else {
         LOG(FATAL) << "Infer bound fails";
       }
     }

-    pstate->stages[i] = StageNode::make(stage->op,
stage->op_type, - std::move(new_iters), stage->compute_at, stage->attrs); + pstate->stages[i] = Stage(stage->op, stage->op_type, std::move(new_iters), + stage->compute_at, stage->attrs); } } @@ -1319,7 +1316,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_GLOBAL("ansor.ComputeDAG") .set_body_typed([](Array tensors) { - return ComputeDAGNode::make(tensors); + return ComputeDAG(tensors); }); TVM_REGISTER_GLOBAL("ansor.ComputeDAGGetInitState") diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index 8da71f005f19..b1b60e678904 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -37,7 +37,6 @@ namespace tvm { namespace ansor { -class ComputeDAG; class AccessAnalyzer; class StateNode; class State; class Step; /*! \brief Read/Write access static analysis result */ @@ -54,15 +53,17 @@ class AccessAnalyzerNode : public Object { OperationMap is_output; std::vector ops_topo_order; - static AccessAnalyzer make(const Array& tensors); - static constexpr const char* _type_key = "ansor.AccessAnalyzer"; TVM_DECLARE_FINAL_OBJECT_INFO(AccessAnalyzerNode, Object); }; -/*! \brief Read/Write access static analysis result */ +/*! + * \brief Managed reference to AccessAnalyzerNode. + * \sa AccessAnalyzerNode + */ class AccessAnalyzer : public ObjectRef { public: + explicit AccessAnalyzer(const Array& tensors); // read/write access analysis bool NeedsMultiLevelTiling(const te::Operation& op) const; bool IsInjective(const te::Operation& op) const; @@ -121,9 +122,6 @@ class ComputeDAGNode : public Object { v->Visit("access_analyzer", &access_analyzer); } - static ComputeDAG make(Array tensors); - static ComputeDAG make_by_workload_key(const std::string& workload_key); - static constexpr const char* _type_key = "ansor.ComputeDAG"; TVM_DECLARE_FINAL_OBJECT_INFO(ComputeDAGNode, Object); }; @@ -135,9 +133,15 @@ enum LayoutRewriteLevel { kBothRewrite = 3, // Rewrite both placeholder and compute body in the compute dag }; -/*! \brief Compute declaration graph */ +/*! + * \brief Managed reference to ComputeDAGNode. + * \sa ComputeDAGNode + */ class ComputeDAG: public ObjectRef { public: + explicit ComputeDAG(Array tensors); + explicit ComputeDAG(const std::string& workload_key); + // Apply transform steps to the init state of this DAG, and get the equivalent tvm::schedule. 
// The return values can be used as arguments to tvm.build or tvm.lower std::pair > ApplySteps( diff --git a/src/ansor/cost_model/cost_model.cc b/src/ansor/cost_model/cost_model.cc index bbf15a241974..ee7bf8b26053 100644 --- a/src/ansor/cost_model/cost_model.cc +++ b/src/ansor/cost_model/cost_model.cc @@ -48,7 +48,7 @@ void RandomNumber(TVMArgs args, TVMRetValue* rv) { } } -CostModel RandomModelNode::make() { +RandomModel::RandomModel() { ObjectPtr node = make_object(); node->random_number_func = runtime::Registry::Get("ansor.cost_model.random_number"); @@ -58,7 +58,7 @@ CostModel RandomModelNode::make() { static PackedFunc cost_model_random_number(RandomNumber); node->random_number_func = &cost_model_random_number; } - return CostModel(node); + data_ = std::move(node); } void RandomModelNode::Update(const Array& inputs, @@ -71,11 +71,11 @@ void RandomModelNode::Predict(const SearchTask& task, (*random_number_func)(states.size(), static_cast(scores->data())); } -CostModel MeasureModelNode::make(Builder builder, Runner runner) { +MeasureModel::MeasureModel(Builder builder, Runner runner) { ObjectPtr node = make_object(); - node->measurer = ProgramMeasurerNode::make( - std::move(builder), std::move(runner), Array(), 0); - return CostModel(node); + node->measurer = ProgramMeasurer(std::move(builder), std::move(runner), + Array(), 0); + data_ = std::move(node); } void MeasureModelNode::Update(const Array& inputs, @@ -90,7 +90,7 @@ void MeasureModelNode::Predict(const SearchTask& task, inputs.clear(); inputs.reserve(states.size()); for (const auto& state : states) { - inputs.push_back(MeasureInputNode::make(task, state)); + inputs.push_back(MeasureInput(task, state)); } measurer->SilentMeasure(task, inputs, &results); @@ -101,14 +101,14 @@ void MeasureModelNode::Predict(const SearchTask& task, } } -CostModel PythonBasedModelNode::make(PackedFunc update_func, - PackedFunc predict_func, - PackedFunc predict_stage_func) { +PythonBasedModel::PythonBasedModel(PackedFunc update_func, + PackedFunc predict_func, + PackedFunc predict_stage_func) { auto node = make_object(); node->update_func = std::move(update_func); node->predict_func = std::move(predict_func); node->predict_stage_func = std::move(predict_stage_func); - return CostModel(node); + data_ = std::move(node); } void PythonBasedModelNode::Update(const Array& inputs, @@ -124,9 +124,8 @@ void PythonBasedModelNode::Predict(const SearchTask& task, static_cast(scores->data())); } -void PythonBasedModelNode::PredictStages( - const SearchTask& task, const std::vector& states, - std::vector* state_scores, +void PythonBasedModelNode::PredictStages(const SearchTask& task, + const std::vector& states, std::vector* state_scores, std::vector>* stage_scores) { int n_states = states.size(); int n_stages = task->compute_dag.GetInitState()->stages.size(); @@ -185,14 +184,14 @@ void PythonBasedModelNode::PredictStages( } TVM_REGISTER_GLOBAL("ansor.RandomModel").set_body_typed([]() { - return RandomModelNode::make(); + return RandomModel(); }); TVM_REGISTER_GLOBAL("ansor.PythonBasedModel") .set_body_typed([](PackedFunc update_func, PackedFunc predict_func, PackedFunc predict_stage_func) { - return PythonBasedModelNode::make(update_func, predict_func, - predict_stage_func); + return PythonBasedModel(update_func, predict_func, + predict_stage_func); }); } // namespace ansor diff --git a/src/ansor/cost_model/cost_model.h b/src/ansor/cost_model/cost_model.h index 472a3c201068..f38624a3572c 100644 --- a/src/ansor/cost_model/cost_model.h +++ 
b/src/ansor/cost_model/cost_model.h @@ -36,20 +36,20 @@ namespace ansor { using runtime::PackedFunc; -class CostModel; - /*! \brief The base class for cost model */ class CostModelNode: public Object { public: // Update the cost model according to new measurement pairs - virtual void Update(const Array& inputs, const Array& results) = 0; + virtual void Update(const Array& inputs, + const Array& results) = 0; // Predict the scores of states virtual void Predict(const SearchTask& task, const std::vector& states, std::vector* scores) = 0; // Predict the scores of all stages in states - virtual void PredictStages(const SearchTask& task, const std::vector& states, + virtual void PredictStages(const SearchTask& task, + const std::vector& states, std::vector* state_scores, std::vector>* stage_scores) { LOG(FATAL) << "Not Implemented"; @@ -65,9 +65,8 @@ class RandomModelNode: public CostModelNode { public: const PackedFunc* random_number_func; - static CostModel make(); - - void Update(const Array& inputs, const Array& results) final; + void Update(const Array& inputs, + const Array& results) final; void Predict(const SearchTask& task, const std::vector& states, std::vector* scores) final; @@ -75,14 +74,31 @@ class RandomModelNode: public CostModelNode { TVM_DECLARE_FINAL_OBJECT_INFO(RandomModelNode, CostModelNode); }; +/*! + * \brief Managed reference to RandomModelNode. + * \sa RandomModelNode + */ +class RandomModel : public CostModel { + public: + RandomModel(); + explicit RandomModel(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) + : CostModel(n) {} + + RandomModelNode* operator->() const { + return static_cast(data_.get()); + } + + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(RandomModel); + using ContainerType = RandomModelNode; +}; + /*! \brief The cost model returns actual cost by measurement */ class MeasureModelNode : public CostModelNode { public: ProgramMeasurer measurer; - static CostModel make(Builder builder, Runner runner); - - void Update(const Array& inputs, const Array& results) final; + void Update(const Array& inputs, + const Array& results) final; void Predict(const SearchTask& task, const std::vector& states, std::vector* scores) final; @@ -90,6 +106,18 @@ class MeasureModelNode : public CostModelNode { TVM_DECLARE_FINAL_OBJECT_INFO(MeasureModelNode, CostModelNode); }; +/*! + * \brief Managed reference to MeasureModelNode. + * \sa MeasureModelNode + */ +class MeasureModel : public CostModel { + public: + MeasureModel(Builder builder, Runner runner); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureModel, CostModel, + MeasureModelNode); +}; + /*! \brief A wrapper for cost model defined by python code * This class will call python's function */ class PythonBasedModelNode: public CostModelNode { @@ -98,10 +126,8 @@ class PythonBasedModelNode: public CostModelNode { PackedFunc predict_func; PackedFunc predict_stage_func; - static CostModel make(PackedFunc update_func, PackedFunc predict_func, - PackedFunc predict_stage_func); - - void Update(const Array& inputs, const Array& results) final; + void Update(const Array& inputs, + const Array& results) final; void Predict(const SearchTask& task, const std::vector& states, std::vector* scores) final; void PredictStages(const SearchTask& task, const std::vector& states, @@ -112,6 +138,19 @@ class PythonBasedModelNode: public CostModelNode { TVM_DECLARE_FINAL_OBJECT_INFO(PythonBasedModelNode, CostModelNode); }; +/*! + * \brief Managed reference to PythonBasedModelNode. 
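+ * The update_func / predict_func / predict_stage_func members are
+ * PackedFuncs implemented on the Python side; Update and Predict forward
+ * the measurement pairs and candidate states to them, so the actual model
+ * (e.g. the XGBModel used in the tutorials) lives on the Python side.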
+ * \sa PythonBasedModelNode + */ +class PythonBasedModel : public CostModel { + public: + PythonBasedModel(PackedFunc update_func, PackedFunc predict_func, + PackedFunc predict_stage_func); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PythonBasedModel, CostModel, + PythonBasedModelNode); +}; + } // namespace ansor } // namespace tvm diff --git a/src/ansor/feature.cc b/src/ansor/feature.cc index 3b5849e22262..73f6bad0d432 100644 --- a/src/ansor/feature.cc +++ b/src/ansor/feature.cc @@ -1297,7 +1297,7 @@ void GetPerStmtFeaturesFromFile(const std::string& filename, std::vector min_costs; // read from file - LogReader reader = LogReaderNode::make(filename); + LogReader reader = LogReader(filename); auto cur_inp = make_object(); auto cur_res = make_object(); while (reader->ReadNext(cur_inp.get(), cur_res.get())) { @@ -1310,11 +1310,9 @@ void GetPerStmtFeaturesFromFile(const std::string& filename, auto find_res = task_cache.find(key); if (find_res == task_cache.end()) { // rebuild task - task = SearchTaskNode::make(ComputeDAGNode::make_by_workload_key(workload_key), - workload_key, - cur_inp->task->target, - cur_inp->task->target_host, - cur_inp->task->hardware_params); + task = SearchTask(ComputeDAG(workload_key), workload_key, + cur_inp->task->target, cur_inp->task->target_host, + cur_inp->task->hardware_params); task_id = task_cache.size(); // compute min cost for each task @@ -1378,11 +1376,9 @@ void GetPerStmtFeaturesFromMeasurePairs(const Array& inputs, task = inputs[i]->task; } else { // the measure input is incomplete // rebuild task for incomplete measure pairs read from file - task = SearchTaskNode::make(ComputeDAGNode::make_by_workload_key(workload_key), - workload_key, - inputs[i]->task->target, - inputs[i]->task->target_host, - inputs[i]->task->hardware_params); + task = SearchTask(ComputeDAG(workload_key), workload_key, + inputs[i]->task->target, inputs[i]->task->target_host, + inputs[i]->task->hardware_params); } task_id = task_cache.size(); diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index 23e005503873..ef4c4632e9bf 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -37,10 +37,10 @@ TVM_REGISTER_NODE_TYPE(StateNode); TVM_REGISTER_NODE_TYPE(IteratorNode); // Maker for other classes -Iterator IteratorNode::make(std::string name, Range range, - IteratorType iter_type, IteratorAnnotation annotation, - const std::vector* ori_iters, - std::string attr) { +Iterator::Iterator(std::string name, Range range, IteratorType iter_type, + IteratorAnnotation annotation, + const std::vector* ori_iters, + std::string attr) { auto node = make_object(); node->name = std::move(name); node->range = std::move(range); @@ -50,23 +50,22 @@ Iterator IteratorNode::make(std::string name, Range range, node->ori_iters = *ori_iters; } node->attr = std::move(attr); - return Iterator(node); + data_ = std::move(node); } - -Stage StageNode::make(te::Operation op) { +Stage::Stage(te::Operation op) { auto node = make_object(); if (op->IsInstance()) { node->op_type = kCompute; auto* pop = op.as(); for (const auto& axis : pop->axis) { - node->iters.push_back(IteratorNode::make(CleanName(axis->var->name_hint), - axis->dom, kSpace, kNone)); + node->iters.push_back(Iterator(CleanName(axis->var->name_hint), + axis->dom, kSpace, kNone)); } for (const auto& axis : pop->reduce_axis) { - node->iters.push_back(IteratorNode::make(CleanName(axis->var->name_hint), - axis->dom, kReduce, kNone)); + node->iters.push_back(Iterator(CleanName(axis->var->name_hint), + axis->dom, kReduce, kNone)); } } 
else if (op->IsInstance()) { node->op_type = kPlaceholder; @@ -78,67 +77,53 @@ Stage StageNode::make(te::Operation op) { node->op = std::move(op); node->attrs.auto_unroll_max_step = 0; node->attrs.storage_offset = 0; - return Stage(node); + data_ = std::move(node); } -Stage StageNode::make(te::Operation op, StageType op_type, - const std::vector& iters, - ComputeAtType compute_at, StageAttributes attrs) { +Stage::Stage(te::Operation op, StageType op_type, + const std::vector& iters, ComputeAtType compute_at, + StageAttributes attrs) { auto node = make_object(); node->op = std::move(op); node->op_type = op_type; node->iters = iters; node->compute_at = compute_at; node->attrs = attrs; - return Stage(node); + data_ = std::move(node); } -Stage StageNode::make(te::Operation op, StageType op_type, - std::vector&& iters, ComputeAtType compute_at, - StageAttributes attrs) { +Stage::Stage(te::Operation op, StageType op_type, std::vector&& iters, + ComputeAtType compute_at, StageAttributes attrs) { auto node = make_object(); node->op = std::move(op); node->op_type = op_type; node->iters = std::move(iters); node->compute_at = compute_at; node->attrs = attrs; - return Stage(node); -} - -State StateNode::make_empty_state() { - auto node = make_object(); - node->attach_map = AttachMapNode::make(); - node->complete = false; - node->aux_info = ObjectRef(); - return State(node); + data_ = std::move(node); } -State StateNode::make(const Array& ops) { +State::State(const Array& ops) { auto node = make_object(); for (const auto& op : ops) { - node->stages.push_back(StageNode::make(op)); + node->stages.push_back(Stage(op)); } - node->attach_map = AttachMapNode::make(); + node->attach_map = AttachMap(make_object()); node->complete = true; node->aux_info = ObjectRef(); - return State(node); + data_ = std::move(node); } -State StateNode::make(const std::vector& stages, - const std::vector& transform_steps, bool complete, - ObjectRef aux_info) { +State::State(const std::vector& stages, + const std::vector& transform_steps, bool complete, + ObjectRef aux_info) { auto node = make_object(); node->stages = stages; node->transform_steps = transform_steps; - node->attach_map = AttachMapNode::make(); + node->attach_map = AttachMap(make_object()); node->complete = complete; node->aux_info = std::move(aux_info); - return State(node); -} - -AttachMap AttachMapNode::make() { - auto node = make_object(); - return AttachMap(node); + data_ = std::move(node); } // Schedule primitives api @@ -149,7 +134,7 @@ void State::reorder(int stage_id, const std::vector& order) { "should be specified"; std::vector after_ids; GetIndices(stage->iters, order, &after_ids); - ReorderStep step = ReorderStepNode::make(stage_id, after_ids); + ReorderStep step = ReorderStep(stage_id, after_ids); CopyOnWrite()->transform_steps.push_back(step); DoReorderStep(step); } @@ -160,9 +145,9 @@ std::vector State::split(int stage_id, const Iterator& it, const Stage& stage = operator->()->stages[stage_id]; SplitStep step = - SplitStepNode::make(stage_id, GetIndex(stage->iters, it), - it->range.defined() ? it->range->extent : PrimExpr(), - lengths, inner_to_outer); + SplitStep(stage_id, GetIndex(stage->iters, it), + it->range.defined() ? 
it->range->extent : PrimExpr(), + lengths, inner_to_outer); CopyOnWrite()->transform_steps.push_back(step); return DoSplitStep(step); } @@ -171,7 +156,7 @@ std::vector State::follow_split(int stage_id, const Iterator& it, int src_step_id, int n_split) { const Stage& stage = operator->()->stages[stage_id]; - FollowSplitStep step = FollowSplitStepNode::make( + FollowSplitStep step = FollowSplitStep( stage_id, GetIndex(stage->iters, it), src_step_id, n_split); CopyOnWrite()->transform_steps.push_back(step); return DoFollowSplitStep(step); @@ -183,8 +168,8 @@ std::vector State::follow_fused_split( const Stage& stage = operator->()->stages[stage_id]; FollowFusedSplitStep step = - FollowFusedSplitStepNode::make(stage_id, GetIndex(stage->iters, it), - src_step_ids, level, factor_or_nparts); + FollowFusedSplitStep(stage_id, GetIndex(stage->iters, it), + src_step_ids, level, factor_or_nparts); CopyOnWrite()->transform_steps.push_back(step); return DoFollowFusedSplitStep(step); } @@ -193,14 +178,14 @@ Iterator State::fuse(int stage_id, const std::vector& iters) { const Stage& stage = operator->()->stages[stage_id]; std::vector indices; GetIndices(stage->iters, iters, &indices); - FuseStep step = FuseStepNode::make(stage_id, indices); + FuseStep step = FuseStep(stage_id, indices); CopyOnWrite()->transform_steps.push_back(step); return DoFuseStep(step); } Iterator State::vectorize(int stage_id, const Iterator& it) { const Stage& stage = operator->()->stages[stage_id]; - AnnotationStep step = AnnotationStepNode::make( + AnnotationStep step = AnnotationStep( stage_id, GetIndex(stage->iters, it), kVectorize); CopyOnWrite()->transform_steps.push_back(step); return DoAnnotationStep(step); @@ -209,7 +194,7 @@ Iterator State::vectorize(int stage_id, const Iterator& it) { Iterator State::parallel(int stage_id, const Iterator& it) { const Stage& stage = operator->()->stages[stage_id]; AnnotationStep step = - AnnotationStepNode::make(stage_id, GetIndex(stage->iters, it), kParallel); + AnnotationStep(stage_id, GetIndex(stage->iters, it), kParallel); CopyOnWrite()->transform_steps.push_back(step); return DoAnnotationStep(step); } @@ -217,7 +202,7 @@ Iterator State::parallel(int stage_id, const Iterator& it) { Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) { const Stage& stage = operator->()->stages[stage_id]; AnnotationStep step = - AnnotationStepNode::make(stage_id, GetIndex(stage->iters, it), kUnroll); + AnnotationStep(stage_id, GetIndex(stage->iters, it), kUnroll); // don't unroll if the extent is larger than max_unroll if (max_unroll != -1 && it->range.defined()) { @@ -235,20 +220,20 @@ Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) { void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) { const Stage& target_stage = operator->()->stages[target_stage_id]; - ComputeAtStep step = ComputeAtStepNode::make( + ComputeAtStep step = ComputeAtStep( stage_id, target_stage_id, GetIndex(target_stage->iters, target_iter)); CopyOnWrite()->transform_steps.push_back(step); return DoComputeAtStep(step); } void State::compute_root(int stage_id) { - ComputeRootStep step = ComputeRootStepNode::make(stage_id); + ComputeRootStep step = ComputeRootStep(stage_id); CopyOnWrite()->transform_steps.push_back(step); return DoComputeRootStep(step); } void State::compute_inline(int stage_id) { - ComputeInlineStep step = ComputeInlineStepNode::make(stage_id); + ComputeInlineStep step = ComputeInlineStep(stage_id); 
CopyOnWrite()->transform_steps.push_back(step);
   return DoComputeInlineStep(step);
 }

@@ -257,10 +242,10 @@ Iterator State::bind_thread(int stage_id, const Iterator& it,
                             IteratorAnnotation thread_type) {
   const Stage& stage = operator->()->stages[stage_id];
   if (thread_type < kVThread || thread_type > kThreadY) {
-    LOG(FATAL) << "thread_type error, valide: kVThread, kBlockX, kThreadX, "
-               << "kThreadY";
+    LOG(FATAL) << "thread_type error, valid: kVThread, kBlockX, kBlockY, "
+               << "kThreadX, kThreadY";
   }
-  AnnotationStep step = AnnotationStepNode::make(
+  AnnotationStep step = AnnotationStep(
       stage_id, GetIndex(stage->iters, it), thread_type);
   CopyOnWrite()->transform_steps.push_back(step);
   return DoAnnotationStep(step);
@@ -270,14 +255,14 @@ int State::cache_read(int stage_id, const std::string& scope_name,
                       const std::vector& reader_stage_ids,
                       const ComputeDAG& task_dag) {
   CacheReadStep step =
-      CacheReadStepNode::make(stage_id, scope_name, reader_stage_ids);
+      CacheReadStep(stage_id, scope_name, reader_stage_ids);
   CopyOnWrite()->transform_steps.push_back(step);
   return DoCacheReadStep(step, task_dag);
 }

 int State::cache_write(int stage_id, const std::string& scope_name,
                        const ComputeDAG& task_dag) {
-  CacheWriteStep step = CacheWriteStepNode::make(stage_id, scope_name);
+  CacheWriteStep step = CacheWriteStep(stage_id, scope_name);
   CopyOnWrite()->transform_steps.push_back(step);
   return DoCacheWriteStep(step, task_dag);
 }
@@ -286,7 +271,7 @@ void State::pragma(int stage_id, const Iterator& it,
                    const std::string& pragma_type) {
   const Stage& stage = operator->()->stages[stage_id];
   PragmaStep step =
-      PragmaStepNode::make(stage_id, GetIndex(stage->iters, it), pragma_type);
+      PragmaStep(stage_id, GetIndex(stage->iters, it), pragma_type);
   CopyOnWrite()->transform_steps.push_back(step);
   return DoPragmaStep(step);
 }
@@ -294,8 +279,8 @@ int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id,
                    const ComputeDAG& task_dag) {
   const Stage& stage = operator->()->stages[stage_id];
-  RfactorStep step = RfactorStepNode::make(stage_id, GetIndex(stage->iters, it),
-                                           factor_iter_id);
+  RfactorStep step = RfactorStep(stage_id, GetIndex(stage->iters, it),
+                                 factor_iter_id);
   CopyOnWrite()->transform_steps.push_back(step);
   return DoRfactorStep(step, task_dag);
 }
@@ -303,7 +288,7 @@ int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id,
 void State::storage_align(int stage_id, const Iterator& it, int factor,
                           int offset) {
   const Stage& stage = operator->()->stages[stage_id];
-  StorageAlignStep step = StorageAlignStepNode::make(
+  StorageAlignStep step = StorageAlignStep(
       stage_id, GetIndex(stage->iters, it), factor, offset);
   CopyOnWrite()->transform_steps.push_back(step);
   return DoStorageAlignStep(step);
@@ -312,7 +297,7 @@ void State::storage_align(int stage_id, const Iterator& it, int factor,
 Iterator State::tensorize(int stage_id, const Iterator& it,
                           std::string ti_func_name) {
   const Stage& stage = operator->()->stages[stage_id];
-  TensorizeStep step = TensorizeStepNode::make(
+  TensorizeStep step = TensorizeStep(
       stage_id, GetIndex(stage->iters, it), ti_func_name);
   CopyOnWrite()->transform_steps.push_back(step);
   return DoTensorizeStep(step);
@@ -328,7 +313,7 @@ void State::DoReorderStep(const ReorderStep& step) {
   }

   StateNode* pstate = CopyOnWrite();
-  pstate->stages[step->stage_id] = StageNode::make(
+  pstate->stages[step->stage_id] = Stage(
       stage->op, stage->op_type, std::move(iters), stage->compute_at,
       stage->attrs);
 }
@@ -362,12
+347,12 @@ std::vector State::DoSplitStepCommon( } Iterator res; if (l.defined() && tosplit_min.defined() && tosplit_extent.defined()) { - res = IteratorNode::make(name, Range::make_by_min_extent(tosplit_min, l), - it->iter_type, kNone); + res = Iterator(name, Range::make_by_min_extent(tosplit_min, l), + it->iter_type, kNone); tosplit_min = 0; tosplit_extent = indexdiv(tosplit_extent + l - 1, l); } else { - res = IteratorNode::make(name, Range(), it->iter_type, kNone); + res = Iterator(name, Range(), it->iter_type, kNone); tosplit_min = tosplit_extent = PrimExpr(); } outs.push_back(std::move(res)); @@ -379,12 +364,12 @@ std::vector State::DoSplitStepCommon( } if (inner_to_outer) { outs.push_back( - IteratorNode::make(it->name + ".0", range, it->iter_type, kNone)); + Iterator(it->name + ".0", range, it->iter_type, kNone)); std::reverse(outs.begin(), outs.end()); } else { outs.push_back( - IteratorNode::make(it->name + "." + std::to_string(lengths.size()), - range, it->iter_type, kNone)); + Iterator(it->name + "." + std::to_string(lengths.size()), + range, it->iter_type, kNone)); } std::vector new_iters; @@ -395,7 +380,7 @@ std::vector State::DoSplitStepCommon( stage->iters.end()); StateNode* pstate = CopyOnWrite(); - pstate->stages[stage_id] = StageNode::make( + pstate->stages[stage_id] = Stage( stage->op, stage->op_type, std::move(new_iters), stage->compute_at, stage->attrs); @@ -479,7 +464,7 @@ Iterator State::DoFuseStep(const FuseStep& step) { range = Range::make_by_min_extent(0, new_extent); } Iterator new_it = - IteratorNode::make(new_name, range, new_iter_type, kNone, &ori_iters); + Iterator(new_name, range, new_iter_type, kNone, &ori_iters); std::vector new_iters; new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + step->fused_ids.front()); @@ -489,7 +474,7 @@ Iterator State::DoFuseStep(const FuseStep& step) { stage->iters.end()); StateNode* pstate = CopyOnWrite(); - pstate->stages[stage_id] = StageNode::make( + pstate->stages[stage_id] = Stage( stage->op, stage->op_type, std::move(new_iters), stage->compute_at, stage->attrs); @@ -518,9 +503,9 @@ Iterator State::DoAnnotationStep(const AnnotationStep& step) { Iterator it = stage->iters[step->iter_id]; CHECK_EQ(it->annotation, IteratorAnnotation::kNone); - Iterator new_it = IteratorNode::make(it->name, it->range, it->iter_type, - step->annotation, &it->ori_iters, - it->attr); + Iterator new_it = Iterator(it->name, it->range, it->iter_type, + step->annotation, &it->ori_iters, + it->attr); Stage new_stage = stage; new_stage.CopyOnWrite()->iters[step->iter_id] = new_it; StateNode* pstate = CopyOnWrite(); @@ -547,15 +532,14 @@ void State::DoComputeAtStep(const ComputeAtStep& step) { // We do this to keep the AnnotateCPU pass to annotate more efficiently. 
new_iters.push_back(it); } else { - new_iters.push_back(IteratorNode::make(it->name, Range(), it->iter_type, - it->annotation, &it->ori_iters, - it->attr)); + new_iters.push_back(Iterator(it->name, Range(), it->iter_type, + it->annotation, &it->ori_iters, it->attr)); } } StateNode* pstate = CopyOnWrite(); pstate->stages[step->stage_id] = - StageNode::make(stage->op, stage->op_type, std::move(new_iters), kIter, + Stage(stage->op, stage->op_type, std::move(new_iters), kIter, stage->attrs); pstate->attach_map.SetComputeAtIter(step->stage_id, step->target_stage_id, step->target_iter_id); @@ -569,16 +553,15 @@ void State::DoComputeRootStep(const ComputeRootStep& step) { // ComputeDAG::ReplayAndInferBound std::vector new_iters; for (const Iterator& it : stage->iters) { - new_iters.push_back(IteratorNode::make(it->name, Range(), it->iter_type, - it->annotation, &it->ori_iters, - it->attr)); + new_iters.push_back(Iterator(it->name, Range(), it->iter_type, + it->annotation, &it->ori_iters, it->attr)); } // update attach map StateNode* pstate = CopyOnWrite(); - pstate->stages[step->stage_id] = - StageNode::make(stage->op, stage->op_type, std::move(new_iters), kRoot, - stage->attrs); + pstate->stages[step->stage_id] = Stage(stage->op, stage->op_type, + std::move(new_iters), kRoot, + stage->attrs); pstate->attach_map.DeleteStage(step->stage_id); } @@ -647,7 +630,7 @@ int State::DoCacheReadStep(const CacheReadStep& step, const ComputeDAG& dag) { operator->()->task_dag->ops[step->stage_id]; pstate->stages.insert( pstate->stages.begin() + step->stage_id + 1, - StageNode::make(operator->()->task_dag->ops[step->stage_id + 1])); + Stage(operator->()->task_dag->ops[step->stage_id + 1])); for (size_t i = step->stage_id + 2; i < operator->()->stages.size(); ++i) { pstate->stages[i].CopyOnWrite()->op = operator->()->task_dag->ops[i]; } @@ -667,9 +650,8 @@ int State::DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag) { } } - int last_dag_op_size = pstate->task_dag.defined() - ? pstate->task_dag->ops.size() - : dag->ops.size(); + int last_dag_op_size = pstate->task_dag.defined() ? 
+ pstate->task_dag->ops.size() : dag->ops.size(); dag.ReplayAndGetDAG(replay_steps, &(pstate->task_dag)); int added_ops = pstate->task_dag->ops.size() - last_dag_op_size; CHECK_GE(added_ops, 1); @@ -679,9 +661,9 @@ int State::DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag) { // Should insert new stage, update target stage, update the later stage's op pstate->stages.insert( pstate->stages.begin() + step->stage_id, - StageNode::make(operator->()->task_dag->ops[step->stage_id])); + Stage(operator->()->task_dag->ops[step->stage_id])); pstate->stages[step->stage_id + 1] = - StageNode::make(operator->()->task_dag->ops[step->stage_id + 1]); + Stage(operator->()->task_dag->ops[step->stage_id + 1]); int next_stage_id = step->stage_id + 2; // Notice: added_ops should actually assert to be 1 // branch of 2 here is somehow a hack to TVM's cache_write bug with @@ -691,7 +673,7 @@ int State::DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag) { if (added_ops == 2) { pstate->stages.insert( pstate->stages.begin() + next_stage_id, - StageNode::make(operator->()->task_dag->ops[next_stage_id])); + Stage(operator->()->task_dag->ops[next_stage_id])); next_stage_id++; } else if (added_ops > 2) { LOG(ERROR) << "Unexpected behavior of CacheWrite."; @@ -737,10 +719,10 @@ int State::DoRfactorStep(const RfactorStep& step, const ComputeDAG& dag) { // Should insert new stage, update target stage, update the later stage's op pstate->stages.insert( pstate->stages.begin() + step->stage_id, - StageNode::make(operator->()->task_dag->ops[step->stage_id])); + Stage(operator->()->task_dag->ops[step->stage_id])); // maintain the compute_at type of target stage Stage target_stage = - StageNode::make(operator->()->task_dag->ops[step->stage_id + 1]); + Stage(operator->()->task_dag->ops[step->stage_id + 1]); target_stage.CopyOnWrite()->compute_at = compute_at_type; pstate->stages[step->stage_id + 1] = target_stage; @@ -762,7 +744,7 @@ void State::DoStorageAlignStep(const StorageAlignStep& step) { Iterator State::DoTensorizeStep(const TensorizeStep& step) { const Stage& stage = operator->()->stages[step->stage_id]; Iterator it = stage->iters[step->iter_id]; - Iterator new_it = IteratorNode::make(it->name, it->range, it->iter_type, + Iterator new_it = Iterator(it->name, it->range, it->iter_type, IteratorAnnotation::kTensorized, &it->ori_iters, step->ti_func_name); Stage new_stage = stage; new_stage.CopyOnWrite()->iters[step->iter_id] = new_it; @@ -1017,7 +999,7 @@ void AttachMap::DeleteStageEntry(AttachMapNode* pnode, int stage_id) { } AttachMap AttachMap::ApplyStageIdOfffset(int start_id, int offset) const { - AttachMap map = AttachMapNode::make(); + AttachMap map = AttachMap(make_object()); auto pmap = map.CopyOnWrite(); for (const auto& x : operator->()->stage_to_attach_iter) { auto key = x.first; diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h index 31ed5274184d..2d64db11fc18 100644 --- a/src/ansor/loop_state.h +++ b/src/ansor/loop_state.h @@ -93,11 +93,6 @@ class IteratorNode : public Object { std::vector ori_iters; // The original iterators before fusion std::string attr; - static Iterator make(std::string name, Range range, - IteratorType iter_type, IteratorAnnotation annotation, - const std::vector* ori_iters = nullptr, - std::string attr = ""); - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); v->Visit("range", &range); @@ -107,19 +102,21 @@ class IteratorNode : public Object { static constexpr const char *_type_key = "ansor.Iterator"; 
TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object); }; -TVM_DEFINE_COW_OBJECT_REF(Iterator, ObjectRef, IteratorNode); -// Forward decelerations -class Stage; class State; -class AttachMap; +/*! + * \brief Managed reference to IteratorNode. + * \sa IteratorNode + */ +class Iterator : public ObjectRef { + public: + Iterator(std::string name, Range range, IteratorType iter_type, + IteratorAnnotation annotation, + const std::vector* ori_iters = nullptr, + std::string attr = ""); -class ReorderStep; class SplitStep; class FollowSplitStep; -class FollowFusedSplitStep; -class FuseStep; class AnnotationStep; -class ComputeAtStep; class ComputeRootStep; class ComputeInlineStep; -class CacheReadStep; class CacheWriteStep; -class PragmaStep; class RfactorStep; class StorageAlignStep; -class TensorizeStep; + TVM_DEFINE_OBJECT_REF_METHODS(Iterator, ObjectRef, IteratorNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(IteratorNode); +}; /*! \brief Stage-level attributes */ struct StageAttributes { @@ -143,23 +140,34 @@ class StageNode : public Object { v->Visit("op", &op); } - static Stage make(te::Operation op); - static Stage make(te::Operation op, StageType op_type, - const std::vector& iters, - ComputeAtType compute_at, StageAttributes attrs); - static Stage make(te::Operation op, StageType op_type, - std::vector&& iters, - ComputeAtType compute_at, StageAttributes attrs); - static constexpr const char *_type_key = "ansor.Stage"; TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object); }; -TVM_DEFINE_COW_OBJECT_REF(Stage, ObjectRef, StageNode); + +/*! + * \brief Managed reference to StageNode. + * \sa StageNode + */ +class Stage : public ObjectRef { + public: + explicit Stage(te::Operation op); + Stage(te::Operation op, StageType op_type, + const std::vector& iters, + ComputeAtType compute_at, StageAttributes attrs); + Stage(te::Operation op, StageType op_type, + std::vector&& iters, + ComputeAtType compute_at, StageAttributes attrs); + + TVM_DEFINE_OBJECT_REF_METHODS(Stage, ObjectRef, StageNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(StageNode); +}; /*! \brief stores the compute_at relation between stages * This stores a bi-directional mapping from stages and iter: - * 1. Stage to its attached iterator 2. Iterator to the stage attached to it - */ + * 1. Stage to its attached iterator 2. Iterator to the stage attached to it + * + * You can use AttachMapNode::stage_to_attach_iter and AttachMapNode::iter_to_attached_stages + * to query the relations */ class AttachMapNode: public Object { public: using StageKey = int; @@ -168,18 +176,14 @@ class AttachMapNode: public Object { std::unordered_map stage_to_attach_iter; std::unordered_map> iter_to_attached_stages; - static AttachMap make(); - static constexpr const char* _type_key = "ansor.AttachMap"; TVM_DECLARE_FINAL_OBJECT_INFO(AttachMapNode, Object); }; -/*! \brief stores the compute_at relation between stages - * This stores a bi-directional mapping from stages and iter: - * 1. Stage to its attached iterator 2. Iterator to the stage attached to it - * - * You can use AttachMapNode::stage_to_attach_iter and AttachMapNode::iter_to_attached_stages - * to query the relations */ +/*! + * \brief Managed reference to AttachMapNode. + * \sa AttachMapNode + */ class AttachMap : public ObjectRef { public: using StageKey = int; @@ -214,7 +218,17 @@ class StepNode: public Object { }; TVM_DEFINE_MUTABLE_OBJECT_REF(Step, StepNode); -/*! 
\brief The loop state and corresponding history steps to reach this state */
+// Step forward declarations
+class ReorderStep; class SplitStep; class FollowSplitStep;
+class FollowFusedSplitStep;
+class FuseStep; class AnnotationStep;
+class ComputeAtStep; class ComputeRootStep; class ComputeInlineStep;
+class CacheReadStep; class CacheWriteStep;
+class PragmaStep; class RfactorStep; class StorageAlignStep;
+class TensorizeStep;
+
+/*! \brief A state in the search process.
+ * It consists of the current loop structure and the history steps to reach this state. */
 class StateNode: public Object {
  public:
  std::vector<Stage> stages;   // Current stages and loop structures
@@ -232,22 +246,29 @@ class StateNode: public Object {
     v->Visit("task_dag", &task_dag);
   }
 
-  static State make_empty_state();
-  static State make(const Array<te::Operation>& ops);
-  static State make(const std::vector<Stage>& stages,
-                    const std::vector<Step>& transform_steps, bool complete,
-                    ObjectRef aux_info);
-
   static constexpr const char* _type_key = "ansor.State";
   TVM_DECLARE_FINAL_OBJECT_INFO(StateNode, Object);
 };
 
-/*! \brief A state in the search process.
- * It consists of the current loop structure and the history steps to reach this state. */
+/*!
+ * \brief Managed reference to StateNode.
+ * \sa StateNode
+ */
 class State : public ObjectRef {
  public:
+  explicit State(const Array<te::Operation>& ops);
+  State(const std::vector<Stage>& stages,
+        const std::vector<Step>& transform_steps, bool complete,
+        ObjectRef aux_info);
+
   // Schedule primitives
   void reorder(int stage_id, const std::vector<Iterator>& order);
+  void compute_at(int stage_id, int target_stage_id,
+                  const Iterator& target_iter);
+  void compute_root(int stage_id);
+  void compute_inline(int stage_id);
+  void pragma(int stage_id, const Iterator& it, const std::string& pragma_type);
+  void storage_align(int stage_id, const Iterator& it, int factor, int offset);
   std::vector<Iterator> split(int stage_id, const Iterator& it,
                               const std::vector<PrimExpr>& lengths,
                               bool inner_to_outer = true);
@@ -264,12 +285,6 @@ class State : public ObjectRef {
                   IteratorAnnotation thread_type);
   Iterator tensorize(int stage_id, const Iterator& it,
                      std::string ti_func_name);
-  void compute_at(int stage_id, int target_stage_id,
-                  const Iterator& target_iter);
-  void compute_root(int stage_id);
-  void compute_inline(int stage_id);
-  void pragma(int stage_id, const Iterator& it, const std::string& pragma_type);
-  void storage_align(int stage_id, const Iterator& it, int factor, int offset);
   int cache_read(int stage_id, const std::string& scope_name,
                  const std::vector<int>& reader_stage_ids,
                  const ComputeDAG& task_dag);
@@ -283,17 +298,17 @@ class State : public ObjectRef {
    * We separate these functions out,
    * so you can call them for replay easily given history steps */
   void DoReorderStep(const ReorderStep& step);
+  void DoComputeAtStep(const ComputeAtStep& step);
+  void DoComputeRootStep(const ComputeRootStep& step);
+  void DoComputeInlineStep(const ComputeInlineStep& step);
+  void DoPragmaStep(const PragmaStep& step);
+  void DoStorageAlignStep(const StorageAlignStep& step);
   std::vector<Iterator> DoSplitStep(const SplitStep& step);
   std::vector<Iterator> DoFollowSplitStep(const FollowSplitStep& step);
   std::vector<Iterator> DoFollowFusedSplitStep(const FollowFusedSplitStep& step);
   Iterator DoFuseStep(const FuseStep& step);
   Iterator DoAnnotationStep(const AnnotationStep& step);
   Iterator DoTensorizeStep(const TensorizeStep& step);
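  // As the comment above says, these Do*Step functions exist so that a
  // recorded trace can be replayed. A replay loop might look like this
  // (a sketch only; the real dispatch lives in ComputeDAG::ReplaySteps):
  //
  //   for (const Step& step : state->transform_steps) {
  //     if (auto ps = step.as<ReorderStepNode>()) {
  //       state.DoReorderStep(GetRef<ReorderStep>(ps));
  //     } else if (auto ps = step.as<SplitStepNode>()) {
  //       state.DoSplitStep(GetRef<SplitStep>(ps));
  //     }  // ... and likewise for the other step types
  //   }
-  void DoComputeAtStep(const ComputeAtStep& step);
-  void DoComputeRootStep(const ComputeRootStep& step);
-  void DoComputeInlineStep(const ComputeInlineStep& step);
-  void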
DoPragmaStep(const PragmaStep& step); - void DoStorageAlignStep(const StorageAlignStep& step); int DoCacheReadStep(const CacheReadStep& step, const ComputeDAG& dag); int DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag); int DoRfactorStep(const RfactorStep& step, const ComputeDAG& dag); diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index 474ea048ebad..4ae35fb410a9 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -58,11 +58,11 @@ const char* ErrorNoToStr[] = { }; // Measure input and result -MeasureInput MeasureInputNode::make(SearchTask task, State state) { +MeasureInput::MeasureInput(SearchTask task, State state) { auto node = make_object(); node->task = std::move(task); node->state = std::move(state); - return MeasureInput(node); + data_ = std::move(node); } MeasureInput MeasureInputNode::copy() const { @@ -72,28 +72,28 @@ MeasureInput MeasureInputNode::copy() const { return MeasureInput(node); } -BuildResult BuildResultNode::make(std::string filename, Array args, - int error_no, std::string error_msg, - double time_cost) { +BuildResult::BuildResult(std::string filename, Array args, + int error_no, std::string error_msg, + double time_cost) { auto node = make_object(); node->filename = std::move(filename); node->args = std::move(args); node->error_no = error_no; node->error_msg = std::move(error_msg); node->time_cost = time_cost; - return BuildResult(node); + data_ = std::move(node); } -MeasureResult MeasureResultNode::make(Array costs, int error_no, - std::string error_msg, double all_cost, - double timestamp) { +MeasureResult::MeasureResult(Array costs, int error_no, + std::string error_msg, double all_cost, + double timestamp) { auto node = make_object(); node->costs = std::move(costs); node->error_no = error_no; node->error_msg = std::move(error_msg); node->all_cost = all_cost; node->timestamp = timestamp; - return MeasureResult(node); + data_ = std::move(node); } MeasureResult MeasureResultNode::copy() const { @@ -107,13 +107,13 @@ MeasureResult MeasureResultNode::copy() const { } // LocalBuilder -Builder LocalBuilderNode::make(int timeout, int n_parallel, - const std::string& build_func) { +LocalBuilder::LocalBuilder(int timeout, int n_parallel, + const std::string& build_func) { auto node = make_object(); node->timeout = timeout; node->n_parallel = n_parallel; node->build_func = build_func; - return Builder(node); + data_ = std::move(node); } Array LocalBuilderNode::Build(const Array& inputs, @@ -129,10 +129,9 @@ Array LocalBuilderNode::Build(const Array& inputs, } // RPC Runner -Runner RPCRunnerNode::make(const std::string& key, const std::string& host, - int port, int priority, int timeout, int n_parallel, - int number, int repeat, int min_repeat_ms, - double cooldown_interval) { +RPCRunner::RPCRunner(const std::string& key, const std::string& host, int port, + int priority, int timeout, int n_parallel, int number, + int repeat, int min_repeat_ms, double cooldown_interval) { auto node = make_object(); node->key = key; node->host = host; @@ -144,7 +143,7 @@ Runner RPCRunnerNode::make(const std::string& key, const std::string& host, node->repeat = repeat; node->min_repeat_ms = min_repeat_ms; node->cooldown_interval = cooldown_interval; - return Runner(node); + data_ = std::move(node); } Array RPCRunnerNode::Run(const Array& inputs, @@ -162,15 +161,15 @@ Array RPCRunnerNode::Run(const Array& inputs, } // Local Runner -Runner LocalRunnerNode::make(int timeout, int number, int repeat, - int min_repeat_ms, double cooldown_interval) { 
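// Each hunk of this kind applies the same migration: the static
// `XxxNode::make(...)` factory becomes a constructor on the ObjectRef
// subclass `Xxx`, which fills `data_`, the ObjectPtr that every ObjectRef
// carries. A minimal sketch of the pattern with a hypothetical FooNode/Foo
// pair (illustrative names, not part of this patch):
//
//   class FooNode : public Object {
//    public:
//     int x;
//     static constexpr const char* _type_key = "ansor.Foo";
//     TVM_DECLARE_FINAL_OBJECT_INFO(FooNode, Object);
//   };
//
//   class Foo : public ObjectRef {
//    public:
//     explicit Foo(int x) {
//       auto node = make_object<FooNode>();
//       node->x = x;
//       data_ = std::move(node);  // previously: return Foo(node);
//     }
//     TVM_DEFINE_OBJECT_REF_METHODS(Foo, ObjectRef, FooNode);
//   };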
+LocalRunner::LocalRunner(int timeout, int number, int repeat, + int min_repeat_ms, double cooldown_interval) { ObjectPtr node = make_object(); node->timeout = timeout; node->number = number; node->repeat = repeat; node->min_repeat_ms = min_repeat_ms; node->cooldown_interval = cooldown_interval; - return Runner(node); + data_ = std::move(node); } Array LocalRunnerNode::Run( @@ -188,19 +187,17 @@ Array LocalRunnerNode::Run( } // Program Measurer -ProgramMeasurer ProgramMeasurerNode::make(Builder builder, Runner runner, - Array callbacks, - int verbose, - int max_continous_error) { +ProgramMeasurer::ProgramMeasurer(Builder builder, Runner runner, + Array callbacks, int verbose, + int max_continous_error) { auto node = make_object(); node->builder = std::move(builder); node->runner = std::move(runner); node->callbacks = std::move(callbacks); node->verbose = verbose; - node->max_continous_error = max_continous_error < 0 - ? DEFAULT_MAX_CONTINOUS_ERROR - : max_continous_error; - return ProgramMeasurer(node); + node->max_continous_error = max_continous_error < 0 ? + ProgramMeasurerNode::DEFAULT_MAX_CONTINOUS_ERROR : max_continous_error; + data_ = std::move(node); } void ProgramMeasurerNode::Reset() { @@ -346,13 +343,21 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); TVM_REGISTER_GLOBAL("ansor.MeasureInput") -.set_body_typed(MeasureInputNode::make); +.set_body_typed([](SearchTask task, State state) { + return MeasureInput(task, state); +}); TVM_REGISTER_GLOBAL("ansor.BuildResult") -.set_body_typed(BuildResultNode::make); +.set_body_typed([](std::string filename, Array args, + int error_no, std::string error_msg, double time_cost) { + return BuildResult(filename, args, error_no, error_msg, time_cost); +}); TVM_REGISTER_GLOBAL("ansor.MeasureResult") -.set_body_typed(MeasureResultNode::make); +.set_body_typed([](Array costs, int error_no, std::string error_msg, + double all_cost, double timestamp) { + return MeasureResult(costs, error_no, error_msg, all_cost, timestamp); +}); TVM_REGISTER_GLOBAL("ansor.BuilderBuild") .set_body_typed([](const Builder& builder, @@ -367,16 +372,31 @@ TVM_REGISTER_GLOBAL("ansor.RunnerRun") }); TVM_REGISTER_GLOBAL("ansor.LocalBuilder") -.set_body_typed(LocalBuilderNode::make); +.set_body_typed([](int timeout, int n_parallel, const std::string& build_func) { + return LocalBuilder(timeout, n_parallel, build_func); +}); TVM_REGISTER_GLOBAL("ansor.LocalRunner") -.set_body_typed(LocalRunnerNode::make); +.set_body_typed([](int timeout, int number, int repeat, + int min_repeat_ms, double cooldown_interval) { + return LocalRunner(timeout, number, repeat, min_repeat_ms, cooldown_interval); +}); TVM_REGISTER_GLOBAL("ansor.RPCRunner") -.set_body_typed(RPCRunnerNode::make); +.set_body_typed([](const std::string& key, const std::string& host, int port, + int priority, int timeout, int n_parallel, int number, + int repeat, int min_repeat_ms, double cooldown_interval){ + return RPCRunner(key, host, port, priority, timeout, n_parallel, number, + repeat, min_repeat_ms, cooldown_interval); +}); TVM_REGISTER_GLOBAL("ansor.ProgramMeasurer") -.set_body_typed(ProgramMeasurerNode::make); +.set_body_typed([](Builder builder, Runner runner, + Array callbacks, int verbose, + int max_continous_error = -1) { + return ProgramMeasurer(builder, runner, callbacks, verbose, + max_continous_error); +}); } // namespace ansor diff --git a/src/ansor/measure.h b/src/ansor/measure.h index 6e432ba9c88b..a6db55f6181e 100644 --- a/src/ansor/measure.h +++ b/src/ansor/measure.h @@ -56,7 +56,7 @@ extern const 
char *ErrorNoToStr[];
 
 // Inputs and results of one measurement
-/* \brief Store the input of a measurement */
+/*! \brief Store the input of a measurement */
 class MeasureInputNode: public Object {
  public:
  SearchTask task;   // The search task
  State state;       // The program state to be measured
@@ -67,20 +67,30 @@ class MeasureInputNode: public Object {
     v->Visit("state", &state);
   }
 
-  static MeasureInput make(SearchTask task, State state);
   MeasureInput copy() const;  // Do deep copy
 
   static constexpr const char* _type_key = "ansor.MeasureInput";
   TVM_DECLARE_FINAL_OBJECT_INFO(MeasureInputNode, Object);
 };
-TVM_DEFINE_OBJECT_REF(MeasureInput, MeasureInputNode);
 
-/* \brief Store the input of a build */
+/*!
+ * \brief Managed reference to MeasureInputNode.
+ * \sa MeasureInputNode
+ */
+class MeasureInput : public ObjectRef {
+ public:
+  MeasureInput(SearchTask task, State state);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(MeasureInput, ObjectRef, MeasureInputNode);
+};
+
+/*! \brief Store the result of a build */
 class BuildResultNode: public Object {
  public:
  std::string filename;      // The filename of built binary file
  Array<te::Tensor> args;    // The arguments
-  int error_no;              // The error code (see MeasureErrorNO). 0 means no error.
+  int error_no;              // The error code (see MeasureErrorNO).
+                             // 0 means no error.
  std::string error_msg;     // The error message if there is any error
  double time_cost;          // The time cost of build
 
@@ -92,19 +102,27 @@ class BuildResultNode: public Object {
     v->Visit("time_cost", &time_cost);
   }
 
-  static BuildResult make(std::string filename, Array<te::Tensor> args,
-                          int error_no, std::string error_msg, double time_cost);
-
   static constexpr const char* _type_key = "ansor.BuildResult";
   TVM_DECLARE_FINAL_OBJECT_INFO(BuildResultNode, Object);
 };
-TVM_DEFINE_OBJECT_REF(BuildResult, BuildResultNode);
 
-/* \brief Store the results of a measurement */
+/*!
+ * \brief Managed reference to BuildResultNode.
+ * \sa BuildResultNode
+ */
+class BuildResult : public ObjectRef {
+ public:
+  BuildResult(std::string filename, Array<te::Tensor> args,
+              int error_no, std::string error_msg, double time_cost);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(BuildResult, ObjectRef, BuildResultNode);
+};
+
+/*! \brief Store the results of a measurement */
 class MeasureResultNode: public Object {
  public:
  Array<PrimExpr> costs;   // The time costs of execution
-  int error_no;            // The error code (see MeasureErrorNO). 0 means no error.
+  int error_no;            // The error code (see MeasureErrorNO).
+                           // 0 means no error.
  std::string error_msg;   // The error message if there is any error
  double all_cost;         // The time cost of build and run
  double timestamp;        // The time stamps of this measurement
@@ -119,16 +137,23 @@ class MeasureResultNode: public Object {
 
   MeasureResult copy() const;  // Do deep copy
 
-  static MeasureResult make(Array<PrimExpr> costs, int error_no, std::string error_msg,
-                            double all_cost, double timestamp);
-
   static constexpr const char* _type_key = "ansor.MeasureResult";
   TVM_DECLARE_FINAL_OBJECT_INFO(MeasureResultNode, Object);
 };
-TVM_DEFINE_OBJECT_REF(MeasureResult, MeasureResultNode);
 
+/*!
+ * \brief Managed reference to MeasureResultNode.
+ * \sa MeasureResultNode
+ */
+class MeasureResult : public ObjectRef {
+ public:
+  MeasureResult(Array<PrimExpr> costs, int error_no, std::string error_msg,
+                double all_cost, double timestamp);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(MeasureResult, ObjectRef, MeasureResultNode);
+};
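// The record classes above now construct directly. A hedged usage sketch
// (the helper below and its literal values are illustrative, not part of
// this patch):
//
//   // Wrap the outcome of one successful build + run into record objects.
//   std::pair<BuildResult, MeasureResult> MakeRecords(
//       Array<te::Tensor> args, Array<PrimExpr> costs) {
//     // filename, args, error_no (0 = no error), error_msg, time_cost
//     BuildResult bres("kernel.so", args, 0, "", 0.5);
//     // costs, error_no, error_msg, all_cost, timestamp
//     MeasureResult mres(costs, 0, "", 1.0, 0.0);
//     return {bres, mres};
//   }
 
-/* \brief Bass class of measurement callbacks */
+/*! \brief Base class of measurement callbacks */
 class MeasureCallbackNode: public Object {
  public:
  /*!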
\brief Callback function that will be called on measurement input/result pairs
@@ -141,10 +166,8 @@ class MeasureCallbackNode: public Object {
 };
 TVM_DEFINE_MUTABLE_OBJECT_REF(MeasureCallback, MeasureCallbackNode);
 
-
 // Base class for builder and runner
-
-/* \brief Builder that builds the programs */
+/*! \brief Builder that builds the programs */
 class BuilderNode: public Object {
  public:
  int n_parallel;  // The number of tasks to run in parallel
@@ -158,7 +181,7 @@ class BuilderNode: public Object {
 };
 TVM_DEFINE_MUTABLE_OBJECT_REF(Builder, BuilderNode);
 
-/* \brief Runner that runs the built programs and measure the time cost */
+/*! \brief Runner that runs the built programs and measures the time cost */
 class RunnerNode: public Object {
  public:
  int timeout;  // Timeout of a run
@@ -175,20 +198,30 @@ TVM_DEFINE_MUTABLE_OBJECT_REF(Runner, RunnerNode);
 
 // Implementation of various builders and runners
 
-/* \brief LocalBuilder use local CPU cores to build programs in parallel */
+/*! \brief LocalBuilder uses local CPU cores to build programs in parallel */
 class LocalBuilderNode: public BuilderNode {
  public:
  std::string build_func;  // Build function
 
-  static Builder make(int timeout, int n_parallel, const std::string& build_func);
-
   Array<BuildResult> Build(const Array<MeasureInput>& inputs, int verbose) final;
 
   static constexpr const char* _type_key = "ansor.LocalBuilder";
   TVM_DECLARE_FINAL_OBJECT_INFO(LocalBuilderNode, BuilderNode);
 };
 
-/* \brief RPCRunner that uses RPC call to measures the time cost of programs on remote devices */
+/*!
+ * \brief Managed reference to LocalBuilderNode.
+ * \sa LocalBuilderNode
+ */
+class LocalBuilder: public Builder {
+ public:
+  LocalBuilder(int timeout, int n_parallel, const std::string& build_func);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(LocalBuilder, Builder, LocalBuilderNode);
+};
+
+/*! \brief RPCRunner that uses RPC calls to measure the time cost of programs
+ * on remote devices */
 class RPCRunnerNode : public RunnerNode {
  public:
  std::string key;
@@ -201,10 +234,6 @@ class RPCRunnerNode : public RunnerNode {
  int min_repeat_ms;
  double cooldown_interval;
 
-  static Runner make(const std::string& key, const std::string& host, int port,
-                     int priority, int timeout, int n_parallel, int number,
-                     int repeat, int min_repeat_ms, double cooldown_interval);
-
   /*! \brief Run measurement and return results */
   Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
                            const Array<BuildResult>& build_results,
@@ -214,7 +243,20 @@ class RPCRunnerNode : public RunnerNode {
   TVM_DECLARE_FINAL_OBJECT_INFO(RPCRunnerNode, RunnerNode);
 };
 
-/* \brief LocalRunner that uses local CPU/GPU to measures the time cost of programs */
+/*!
+ * \brief Managed reference to RPCRunnerNode.
+ * \sa RPCRunnerNode
+ */
+class RPCRunner : public Runner {
+ public:
+  RPCRunner(const std::string& key, const std::string& host, int port,
+            int priority, int timeout, int n_parallel, int number,
+            int repeat, int min_repeat_ms, double cooldown_interval);
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RPCRunner, Runner, RPCRunnerNode);
+};
+
+/*! \brief LocalRunner that uses local CPU/GPU to measure the time cost of programs */
 class LocalRunnerNode: public RunnerNode {
  public:
  int number;
@@ -222,9 +264,6 @@ class LocalRunnerNode: public RunnerNode {
  int min_repeat_ms;
  double cooldown_interval;
 
-  static Runner make(int timeout, int number, int repeat,
-                     int min_repeat_ms, double cooldown_interval);
-
   /*!
\brief Run measurement and return results */
   Array<MeasureResult> Run(const Array<MeasureInput>& inputs,
                            const Array<BuildResult>& build_results,
@@ -234,6 +273,18 @@
   TVM_DECLARE_FINAL_OBJECT_INFO(LocalRunnerNode, RunnerNode);
 };
 
+/*!
+ * \brief Managed reference to LocalRunnerNode.
+ * \sa LocalRunnerNode
+ */
+class LocalRunner: public Runner {
+ public:
+  LocalRunner(int timeout, int number, int repeat,
+              int min_repeat_ms, double cooldown_interval);
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LocalRunner, Runner,
+                                        LocalRunnerNode);
+};
 
 /*!
  * \brief Measurer that measures the time costs of tvm programs
@@ -254,11 +305,6 @@ class ProgramMeasurerNode: public Object {
  int verbose;
  int max_continous_error;
 
-  static ProgramMeasurer make(Builder builder, Runner runner,
-                              Array<MeasureCallback> callbacks,
-                              int verbose,
-                              int max_continous_error = -1);
-
   /*! \brief Reset bookkeeping variables */
   void Reset();
 
@@ -277,8 +323,19 @@
   static constexpr const char* _type_key = "ansor.ProgramMeasurer";
   TVM_DECLARE_FINAL_OBJECT_INFO(ProgramMeasurerNode, Object);
 };
-TVM_DEFINE_MUTABLE_OBJECT_REF(ProgramMeasurer, ProgramMeasurerNode);
 
+/*!
+ * \brief Managed reference to ProgramMeasurerNode.
+ * \sa ProgramMeasurerNode
+ */
+class ProgramMeasurer : public ObjectRef {
+ public:
+  ProgramMeasurer(Builder builder, Runner runner,
+                  Array<MeasureCallback> callbacks,
+                  int verbose, int max_continous_error = -1);
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ProgramMeasurer, ObjectRef, ProgramMeasurerNode);
+};
 
 }  // namespace ansor
 }  // namespace tvm
diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc
index c9bccfdce806..51a48780813a 100644
--- a/src/ansor/search_policy/search_policy.cc
+++ b/src/ansor/search_policy/search_policy.cc
@@ -33,7 +33,7 @@ TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode);
 TVM_REGISTER_OBJECT_TYPE(PreloadMeasuredStatesNode);
 
 void SearchPolicyNode::PreloadMeasuredStates(const std::string& log_file) {
-  LogReader reader = LogReaderNode::make(log_file);
+  LogReader reader = LogReader(log_file);
   const auto& res = reader->ReadLines(-1);
   size_t log_size = res.first.size();
   CHECK_EQ(log_size, res.second.size());
@@ -84,10 +84,10 @@ void SearchPolicyNode::RunCallbacks(const Array<SearchCallback>& callbacks) {
   }
 }
 
-SearchCallback PreloadMeasuredStatesNode::make(std::string filename) {
+PreloadMeasuredStates::PreloadMeasuredStates(std::string filename) {
   auto node = make_object<PreloadMeasuredStatesNode>();
   node->filename = std::move(filename);
-  return SearchCallback(node);
+  data_ = std::move(node);
 }
 
 void PreloadMeasuredStatesNode::callback(SearchPolicyNode* policy) {
@@ -121,7 +121,7 @@ TVM_REGISTER_GLOBAL("ansor.SearchPolicySetVerbose")
 
 TVM_REGISTER_GLOBAL("ansor.PreloadMeasuredStates")
 .set_body_typed([](std::string filename) {
-  return PreloadMeasuredStatesNode::make(filename);
+  return PreloadMeasuredStates(filename);
 });
 
 }  // namespace ansor
diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h
index 4710cc05ae7f..03e7c3f025df 100644
--- a/src/ansor/search_policy/search_policy.h
+++ b/src/ansor/search_policy/search_policy.h
@@ -36,10 +36,9 @@ namespace tvm {
 namespace ansor {
 
-class SearchPolicy; class SearchPolicyNode;
 
-/*! Callback function to be called before or after the search process */
+/*!
\brief Callback function to be called before or after the search process */ class SearchCallbackNode : public Object { public: virtual void callback(SearchPolicyNode* policy) = 0; @@ -55,14 +54,24 @@ class PreloadMeasuredStatesNode : public SearchCallbackNode { public: std::string filename; - static SearchCallback make(std::string filename); - void callback(SearchPolicyNode* policy) final; static constexpr const char *_type_key = "ansor.PreloadMeasuredStates"; TVM_DECLARE_FINAL_OBJECT_INFO(PreloadMeasuredStatesNode, SearchCallbackNode); }; +/*! + * \brief Managed reference to PreloadMeasuredStatesNode. + * \sa PreloadMeasuredStatesNode + */ +class PreloadMeasuredStates : public SearchCallback { + public: + explicit PreloadMeasuredStates(std::string filename); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PreloadMeasuredStates, SearchCallback, + PreloadMeasuredStatesNode); +}; + /*! \brief The base class for search policy */ class SearchPolicyNode : public Object { public: diff --git a/src/ansor/search_policy/sketch_search_policy.cc b/src/ansor/search_policy/sketch_search_policy.cc index 7e4c3999dce3..5b2c10c08c81 100644 --- a/src/ansor/search_policy/sketch_search_policy.cc +++ b/src/ansor/search_policy/sketch_search_policy.cc @@ -49,20 +49,19 @@ TVM_REGISTER_OBJECT_TYPE(PreloadCustomSketchRuleNode); // All possible candidates for auto_unroll const std::vector SketchSearchPolicyNode::auto_unroll_configs{0, 16, 64, 512, 1024}; -SearchPolicy SketchSearchPolicyNode::make(CostModel program_cost_model, - Map params, - int seed) { +SketchSearchPolicy::SketchSearchPolicy(CostModel program_cost_model, + Map params, + int seed) { auto node = make_object(); node->program_cost_model = std::move(program_cost_model); node->rand_gen_ = std::mt19937(seed); node->params = std::move(params); - return SearchPolicy(node); + data_ = std::move(node); } State SketchSearchPolicyNode::Search(SearchTask task, int n_trials, - int early_stopping, int num_measure_per_iter, - int verbose, ProgramMeasurer measurer, - Array pre_search_callbacks) { + int early_stopping, int num_measure_per_iter, int verbose, + ProgramMeasurer measurer, Array pre_search_callbacks) { std::vector best_states, random_states; this->cur_task = task; this->verbose = verbose; @@ -221,7 +220,7 @@ void SketchSearchPolicyNode::PickStatesWithEpsGreedy( if (measured_states_set_.count(state_str)) { continue; } measured_states_set_.insert(state_str); - inputs->push_back(MeasureInputNode::make(cur_task, *pstate)); + inputs->push_back(MeasureInput(cur_task, *pstate)); measured_states_vector_.push_back(std::move(*pstate)); } } @@ -701,8 +700,8 @@ void SketchSearchPolicyNode::GenerateSketch( auto step = pstate->transform_steps[split_step_id].as(); CHECK(step != nullptr); pstate->transform_steps[split_step_id] - = SplitStepNode::make(step->stage_id, step->iter_id, step->extent, {PrimExpr()}, - step->inner_to_outer); + = SplitStep(step->stage_id, step->iter_id, step->extent, {PrimExpr()}, + step->inner_to_outer); } } } @@ -733,7 +732,7 @@ int InitPopulationFillTileSize(const SketchSearchPolicyNode* policy, policy->cur_task->hardware_params->max_innermost_split_factor); StateNode* pstate = state->CopyOnWrite(); - pstate->transform_steps[step_id] = SplitStepNode::make( + pstate->transform_steps[step_id] = SplitStep( ps->stage_id, ps->iter_id, ps->extent, candidate_lens[(*rand_gen)() % candidate_lens.size()], ps->inner_to_outer); @@ -1508,12 +1507,12 @@ class RuleCustomSketch : public SketchGenerationRule { PackedFunc apply_func_; }; -SearchCallback 
PreloadCustomSketchRuleNode::make(PackedFunc meet_condition_func, - PackedFunc apply_func) { +PreloadCustomSketchRule::PreloadCustomSketchRule(PackedFunc meet_condition_func, + PackedFunc apply_func) { auto node = make_object(); node->meet_condition_func = meet_condition_func; node->apply_func = apply_func; - return SearchCallback(node); + data_ = std::move(node); } void PreloadCustomSketchRuleNode::callback(SearchPolicyNode* policy) { @@ -1525,15 +1524,14 @@ void PreloadCustomSketchRuleNode::callback(SearchPolicyNode* policy) { } TVM_REGISTER_GLOBAL("ansor.SketchSearchPolicy") -.set_body_typed([](CostModel program_cost_model, - Map params, +.set_body_typed([](CostModel program_cost_model, Map params, int seed){ - return SketchSearchPolicyNode::make(program_cost_model, params, seed); + return SketchSearchPolicy(program_cost_model, params, seed); }); TVM_REGISTER_GLOBAL("ansor.PreloadCustomSketchRule") .set_body_typed([](PackedFunc meet_condition_func, PackedFunc apply_func) { - return PreloadCustomSketchRuleNode::make(meet_condition_func, apply_func); + return PreloadCustomSketchRule(meet_condition_func, apply_func); }); } // namespace ansor diff --git a/src/ansor/search_policy/sketch_search_policy.h b/src/ansor/search_policy/sketch_search_policy.h index 60920c5c1fdd..54a5cdd1fa4e 100644 --- a/src/ansor/search_policy/sketch_search_policy.h +++ b/src/ansor/search_policy/sketch_search_policy.h @@ -51,7 +51,8 @@ class SketchSearchPolicyNode: public SearchPolicyNode { public: /*! \brief The cost model for complete programs */ CostModel program_cost_model; - + /*! \brief Random generator */ + std::mt19937 rand_gen_; /*! \brief The parameters for search. It stores the following parameters: * int evolutionary_search_population // The population size for evolutionary search * int evolutionary_search_mutation_prob // The probability of mutation for evolutionary search @@ -63,14 +64,9 @@ class SketchSearchPolicyNode: public SearchPolicyNode { * str gpu_multi_level_tiling_structure // The structure of multi-level tiling for GPU */ Map params; - /*! \brief The rules to generate sketches */ std::vector sketch_rules; - static SearchPolicy make(CostModel program_cost_model, - Map params, - int seed); - /*! \brief Search and make n_trails measurements. * \returns the best state */ State Search(SearchTask task, int n_trials, @@ -92,7 +88,8 @@ class SketchSearchPolicyNode: public SearchPolicyNode { /*! \brief Pick states from best states and random states with eps-greedy policy */ void PickStatesWithEpsGreedy(std::vector* inputs, const std::vector& best_states, - const std::vector& random_states, int remaining_n_trials); + const std::vector& random_states, + int remaining_n_trials); private: // Run one round of the search pipeline @@ -111,10 +108,22 @@ class SketchSearchPolicyNode: public SearchPolicyNode { int num_best_states, std::vector* best_states); SplitFactorizationMemo split_memo_; // Memorize split space for Split - std::mt19937 rand_gen_; // Random generator int num_measure_per_iter_; // The number of states to measure per iteration }; -TVM_DEFINE_MUTABLE_OBJECT_REF(SketchSearchPolicy, SketchSearchPolicyNode); + +/*! + * \brief Managed reference to SketchSearchPolicyNode. + * \sa SketchSearchPolicyNode + */ +class SketchSearchPolicy : public SearchPolicy { + public: + SketchSearchPolicy(CostModel program_cost_model, + Map params, + int seed); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SketchSearchPolicy, SearchPolicy, + SketchSearchPolicyNode); +}; /*! 
\brief Pre-search callback function to load custom rules for sketch generation */ class PreloadCustomSketchRuleNode : public SearchCallbackNode { @@ -123,15 +132,25 @@ class PreloadCustomSketchRuleNode : public SearchCallbackNode { PackedFunc meet_condition_func; PackedFunc apply_func; - static SearchCallback make(PackedFunc meet_condition_func, - PackedFunc apply_func); - void callback(SearchPolicyNode* policy) final; static constexpr const char *_type_key = "ansor.PreloadCustomSketchRule"; TVM_DECLARE_FINAL_OBJECT_INFO(PreloadCustomSketchRuleNode, SearchCallbackNode); }; +/*! + * \brief Managed reference to PreloadCustomSketchRuleNode. + * \sa PreloadCustomSketchRuleNode + */ +class PreloadCustomSketchRule : public SearchCallback { + public: + PreloadCustomSketchRule(PackedFunc meet_condition_func, + PackedFunc apply_func); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PreloadCustomSketchRule, SearchCallback, + PreloadCustomSketchRuleNode); +}; + } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_policy/utils.cc b/src/ansor/search_policy/utils.cc index ba42ca55611c..412d0afcca98 100644 --- a/src/ansor/search_policy/utils.cc +++ b/src/ansor/search_policy/utils.cc @@ -371,7 +371,7 @@ State RandomMutateTileSize(const State& old_state, SplitFactorizationMemo* split auto pstate = tmp_s.CopyOnWrite(); pstate->transform_steps[step_id] = - SplitStepNode::make(ps->stage_id, ps->iter_id, ps->extent, new_lengths, ps->inner_to_outer); + SplitStep(ps->stage_id, ps->iter_id, ps->extent, new_lengths, ps->inner_to_outer); return tmp_s; } @@ -401,7 +401,7 @@ State RandomMutateMaxUnrollStep(const State& old_state, std::mt19937* random_gen auto val = std::to_string(auto_unroll_configs[(*random_gen)() % auto_unroll_configs.size()]); auto pstate = tmp_s.CopyOnWrite(); - pstate->transform_steps[step_id] = PragmaStepNode::make( + pstate->transform_steps[step_id] = PragmaStep( ps->stage_id, ps->iter_id, std::string("auto_unroll_max_step") + "$" + val); return tmp_s; } diff --git a/src/ansor/search_task.cc b/src/ansor/search_task.cc index c65516150f30..17ab73efb6aa 100644 --- a/src/ansor/search_task.cc +++ b/src/ansor/search_task.cc @@ -35,28 +35,27 @@ namespace ansor { TVM_REGISTER_NODE_TYPE(HardwareParamsNode); TVM_REGISTER_NODE_TYPE(SearchTaskNode); -HardwareParams HardwareParamsNode::make(int num_cores, int vector_unit_bytes, - int cache_line_bytes, - int max_unroll_vec, - int max_innermost_split_factor) { +HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, + int cache_line_bytes, int max_unroll_vec, + int max_innermost_split_factor) { auto node = make_object(); node->num_cores = num_cores; node->vector_unit_bytes = vector_unit_bytes; node->cache_line_bytes = cache_line_bytes; node->max_unroll_vec = max_unroll_vec; node->max_innermost_split_factor = max_innermost_split_factor; - return HardwareParams(node); + data_ = std::move(node); } HardwareParams HardwareParamsNode::GetDefaultHardwareParams( const Target& target, const Target& target_host) { if (target->target_name == "llvm") { - return HardwareParamsNode::make(tvm::runtime::threading::MaxConcurrency(), - 32, 64, 16, 64); + return HardwareParams(tvm::runtime::threading::MaxConcurrency(), + 32, 64, 16, 64); } else if (target->device_type == kDLGPU) { // TODO(jcf94): temp implementation, max vectorize size in GPU is related // to the data type - auto hardware_params = HardwareParamsNode::make(100000, 16, 64, 4, 64); + auto hardware_params = HardwareParams(100000, 16, 64, 4, 64); auto* p_hardware_params = 
hardware_params.CopyOnWrite(); auto ctx = TVMContext{kDLGPU, 0}; @@ -87,7 +86,7 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams( return hardware_params; } else if (target->device_type == kDLOpenCL) { // TODO(jcf94): temp implementation - auto hardware_params = HardwareParamsNode::make(100000, 16, 64, 4, 64); + auto hardware_params = HardwareParams(100000, 16, 64, 4, 64); auto p_hardware_params = hardware_params.CopyOnWrite(); auto ctx = TVMContext{kDLOpenCL, 0}; @@ -118,10 +117,9 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams( return HardwareParams(); } -SearchTask SearchTaskNode::make(ComputeDAG compute_dag, - std::string workload_key, Target target, - Target target_host, - HardwareParams hardware_params) { +SearchTask::SearchTask(ComputeDAG compute_dag, std::string workload_key, + Target target, Target target_host, + HardwareParams hardware_params) { auto node = make_object(); node->compute_dag = std::move(compute_dag); node->workload_key = std::move(workload_key); @@ -133,24 +131,23 @@ SearchTask SearchTaskNode::make(ComputeDAG compute_dag, node->hardware_params = HardwareParamsNode::GetDefaultHardwareParams( node->target, node->target_host); } - return SearchTask(node); + data_ = std::move(node); } TVM_REGISTER_GLOBAL("ansor.HardwareParams") .set_body_typed([](int num_cores, int vector_unit_bytes, int cache_line_bytes, int max_unroll_vec, int max_innermost_split_factor) { - return HardwareParamsNode::make(num_cores, vector_unit_bytes, - cache_line_bytes, max_unroll_vec, - max_innermost_split_factor); + return HardwareParams(num_cores, vector_unit_bytes, cache_line_bytes, + max_unroll_vec, max_innermost_split_factor); }); TVM_REGISTER_GLOBAL("ansor.SearchTask") .set_body_typed([](ComputeDAG compute_dag, std::string workload_key, Target target, Target target_host, HardwareParams hardware_params) { - return SearchTaskNode::make(compute_dag, workload_key, target, - target_host, hardware_params); + return SearchTask(compute_dag, workload_key, target, target_host, + hardware_params); }); } // namespace ansor diff --git a/src/ansor/search_task.h b/src/ansor/search_task.h index cfa5500c39f4..c53fdcd0f792 100644 --- a/src/ansor/search_task.h +++ b/src/ansor/search_task.h @@ -32,7 +32,7 @@ namespace tvm { namespace ansor { -class HardwareParams; class SearchTask; +class HardwareParams; /*! \brief Hardware related parameters */ class HardwareParamsNode : public Object { @@ -69,17 +69,25 @@ class HardwareParamsNode : public Object { v->Visit("warp_size", &warp_size); } - static HardwareParams make(int num_cores, int vector_unit_bytes, - int cache_line_bytes, int max_unroll_vec, - int max_innermost_split_factor); - static HardwareParams GetDefaultHardwareParams(const Target& target, const Target& target_host); static constexpr const char* _type_key = "ansor.HardwareParams"; TVM_DECLARE_FINAL_OBJECT_INFO(HardwareParamsNode, Object); }; -TVM_DEFINE_COW_OBJECT_REF(HardwareParams, ObjectRef, HardwareParamsNode); + +/*! + * \brief Managed reference to HardwareParamsNode. + * \sa HardwareParamsNode + */ +class HardwareParams : public ObjectRef { + public: + HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes, + int max_unroll_vec, int max_innermost_split_factor); + + TVM_DEFINE_OBJECT_REF_METHODS(HardwareParams, ObjectRef, HardwareParamsNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(HardwareParamsNode); +}; /*! 
\brief Meta-info for a search task */ class SearchTaskNode : public Object { @@ -98,14 +106,23 @@ class SearchTaskNode : public Object { v->Visit("hardware_params", &hardware_params); } - static SearchTask make(ComputeDAG compute_dag, std::string workload_key, - Target target, Target target_host, - HardwareParams hardware_params); - static constexpr const char* _type_key = "ansor.SearchTask"; TVM_DECLARE_FINAL_OBJECT_INFO(SearchTaskNode, Object); }; -TVM_DEFINE_COW_OBJECT_REF(SearchTask, ObjectRef, SearchTaskNode); + +/*! + * \brief Managed reference to SearchTaskNode. + * \sa SearchTaskNode + */ +class SearchTask : public ObjectRef { + public: + SearchTask(ComputeDAG compute_dag, std::string workload_key, + Target target, Target target_host, + HardwareParams hardware_params); + + TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(SearchTaskNode); +}; } // namespace ansor } // namespace tvm diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index 2d8379f56a5f..71fba764506f 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -199,7 +199,7 @@ struct Handler > { reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&int_list); - data->push_back(::tvm::ansor::ReorderStepNode::make(stage_id, int_list)); + data->push_back(::tvm::ansor::ReorderStep(stage_id, int_list)); } else if (name == "SP") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); @@ -211,7 +211,7 @@ struct Handler > { reader->Read(&int_list); s = reader->NextArrayItem(); CHECK(s); reader->Read(&inner_to_outer); - data->push_back(::tvm::ansor::SplitStepNode::make( + data->push_back(::tvm::ansor::SplitStep( stage_id, iter_id, extent, std::vector<::tvm::PrimExpr>(int_list.begin(), int_list.end()), inner_to_outer)); @@ -224,7 +224,7 @@ struct Handler > { reader->Read(&src_step_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&n_split); - data->push_back(::tvm::ansor::FollowSplitStepNode::make( + data->push_back(::tvm::ansor::FollowSplitStep( stage_id, iter_id, src_step_id, n_split)); } else if (name == "FFSP") { s = reader->NextArrayItem(); CHECK(s); @@ -237,14 +237,14 @@ struct Handler > { reader->Read(&level); s = reader->NextArrayItem(); CHECK(s); reader->Read(&factor_or_nparts); - data->push_back(::tvm::ansor::FollowFusedSplitStepNode::make( + data->push_back(::tvm::ansor::FollowFusedSplitStep( stage_id, iter_id, int_list, level, factor_or_nparts)); } else if (name == "FU") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&int_list); - data->push_back(::tvm::ansor::FuseStepNode::make(stage_id, int_list)); + data->push_back(::tvm::ansor::FuseStep(stage_id, int_list)); } else if (name == "AN") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); @@ -252,7 +252,7 @@ struct Handler > { reader->Read(&iter_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&ann); - data->push_back(::tvm::ansor::AnnotationStepNode::make(stage_id, + data->push_back(::tvm::ansor::AnnotationStep(stage_id, iter_id, ::tvm::ansor::IteratorAnnotation(ann))); } else if (name == "CA") { s = reader->NextArrayItem(); CHECK(s); @@ -261,16 +261,16 @@ struct Handler > { reader->Read(&target_stage_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&iter_id); - data->push_back(::tvm::ansor::ComputeAtStepNode::make( + data->push_back(::tvm::ansor::ComputeAtStep( stage_id, target_stage_id, iter_id)); } else if (name == "CR") { s = 
reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); - data->push_back(::tvm::ansor::ComputeRootStepNode::make(stage_id)); + data->push_back(::tvm::ansor::ComputeRootStep(stage_id)); } else if (name == "CI") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); - data->push_back(::tvm::ansor::ComputeInlineStepNode::make(stage_id)); + data->push_back(::tvm::ansor::ComputeInlineStep(stage_id)); } else if (name == "CHR") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); @@ -278,14 +278,14 @@ struct Handler > { reader->Read(&scope_name); s = reader->NextArrayItem(); CHECK(s); reader->Read(&int_list); - data->push_back(::tvm::ansor::CacheReadStepNode::make( + data->push_back(::tvm::ansor::CacheReadStep( stage_id, scope_name, int_list)); } else if (name == "CHW") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&scope_name); - data->push_back(::tvm::ansor::CacheWriteStepNode::make( + data->push_back(::tvm::ansor::CacheWriteStep( stage_id, scope_name)); } else if (name == "PR") { s = reader->NextArrayItem(); CHECK(s); @@ -294,7 +294,7 @@ struct Handler > { reader->Read(&iter_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&pragma_type); - data->push_back(::tvm::ansor::PragmaStepNode::make( + data->push_back(::tvm::ansor::PragmaStep( stage_id, iter_id, pragma_type)); } else if (name == "RF") { s = reader->NextArrayItem(); CHECK(s); @@ -303,7 +303,7 @@ struct Handler > { reader->Read(&iter_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&factor_iter_id); - data->push_back(::tvm::ansor::RfactorStepNode::make( + data->push_back(::tvm::ansor::RfactorStep( stage_id, iter_id, factor_iter_id)); } else if (name == "SA") { s = reader->NextArrayItem(); CHECK(s); @@ -314,7 +314,7 @@ struct Handler > { reader->Read(&factor); s = reader->NextArrayItem(); CHECK(s); reader->Read(&offset); - data->push_back(::tvm::ansor::StorageAlignStepNode::make( + data->push_back(::tvm::ansor::StorageAlignStep( stage_id, iter_id, factor, offset)); } else if (name == "TS") { s = reader->NextArrayItem(); CHECK(s); @@ -323,7 +323,7 @@ struct Handler > { reader->Read(&iter_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&ti_func_name); - data->push_back(::tvm::ansor::TensorizeStepNode::make( + data->push_back(::tvm::ansor::TensorizeStep( stage_id, iter_id, ti_func_name)); } else { LOG(FATAL) << "Invalid step format"; @@ -457,10 +457,10 @@ TVM_REGISTER_OBJECT_TYPE(LogReaderNode); const std::string ANSOR_LOG_VERSION = "v0.2"; // NOLINT(*) -MeasureCallback LogToFileNode::make(std::string filename) { +LogToFile::LogToFile(std::string filename) { auto node = make_object(); node->filename = std::move(filename); - return MeasureCallback(node); + data_ = std::move(node); } void WriteMeasureRecords(std::ostream* os, @@ -506,11 +506,11 @@ void LogToFileNode::callback(const SearchPolicy& policy, WriteMeasureRecords(&ofs, inputs, results); } -LogReader LogReaderNode::make(std::string filename) { +LogReader::LogReader(std::string filename) { auto node = make_object(); node->filename = filename; node->infile.open(filename, std::ifstream::in); - return LogReader(node); + data_ = std::move(node); } LogReaderNode::~LogReaderNode() { @@ -556,15 +556,15 @@ std::pair, Array > LogReaderNode::ReadLines( return std::make_pair(inputs, results); } -std::pair BestMeasurePairInFile(const std::string& filename, - const std::string& workload_key, - const Target& target) { +std::pair BestMeasurePairInFile( + const std::string& 
filename, const std::string& workload_key, + const Target& target) { std::pair best_pair; double best_cost = 1e30; auto inp = make_object(); auto res = make_object(); - LogReader reader = LogReaderNode::make(filename); + LogReader reader = LogReader(filename); while (reader->ReadNext(inp.get(), res.get())) { if (res->error_no != kNoError || inp->task->workload_key != workload_key @@ -594,12 +594,12 @@ TVM_REGISTER_GLOBAL("ansor.WriteMeasureRecordsToFile") TVM_REGISTER_GLOBAL("ansor.LogToFile") .set_body_typed([](const std::string& filename) { - return LogToFileNode::make(filename); + return LogToFile(filename); }); TVM_REGISTER_GLOBAL("ansor.LogReader") .set_body_typed([](const std::string& filename) { - return LogReaderNode::make(filename); + return LogReader(filename); }); TVM_REGISTER_GLOBAL("ansor.LogReaderReadLines") @@ -648,8 +648,8 @@ TVM_REGISTER_GLOBAL("ansor.GetStatesFromMeasureInputs") ptask = inp->task.operator->(); } else { // the measure input is incomplete // rebuild task for incomplete measure pairs read from file - SearchTask new_task = SearchTaskNode::make( - ComputeDAGNode::make_by_workload_key(workload_key), + SearchTask new_task = SearchTask( + ComputeDAG(workload_key), workload_key, inp->task->target, inp->task->target_host, diff --git a/src/ansor/serialization.h b/src/ansor/serialization.h index d877717db9cb..82dd036991e6 100644 --- a/src/ansor/serialization.h +++ b/src/ansor/serialization.h @@ -38,8 +38,6 @@ class LogToFileNode : public MeasureCallbackNode { public: std::string filename; - static MeasureCallback make(std::string filename); - /*! \brief Log measure pairs to file. This is called by the search policy */ void callback(const SearchPolicy& policy, const Array& inputs, @@ -49,15 +47,23 @@ class LogToFileNode : public MeasureCallbackNode { TVM_DECLARE_FINAL_OBJECT_INFO(LogToFileNode, MeasureCallbackNode); }; -class LogReader; +/*! + * \brief Managed reference to LogToFileNode. + * \sa LogToFileNode + */ +class LogToFile : public MeasureCallback { + public: + explicit LogToFile(std::string filename); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LogToFile, MeasureCallback, LogToFileNode); +}; -/*! \brief Log reader */ +/*! \brief Log reader to load step logs from a target file.*/ class LogReaderNode : public Object { public: std::string filename; std::ifstream infile; - static LogReader make(std::string filename); ~LogReaderNode(); /*! \brief Read next line in the log file @@ -76,7 +82,17 @@ class LogReaderNode : public Object { private: std::string cur_line; }; -TVM_DEFINE_MUTABLE_OBJECT_REF(LogReader, LogReaderNode); + +/*! + * \brief Managed reference to LogReaderNode. + * \sa LogReaderNode + */ +class LogReader : public ObjectRef { + public: + explicit LogReader(std::string filename); + + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(LogReader, ObjectRef, LogReaderNode); +}; /*! 
\brief Write measure records to an output stream */ void WriteMeasureRecords(std::ostream* os, diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc index 857f3e570de0..bd0a7f7165f6 100644 --- a/src/ansor/transform_step.cc +++ b/src/ansor/transform_step.cc @@ -34,11 +34,11 @@ namespace tvm { namespace ansor { /********** Reorder **********/ -ReorderStep ReorderStepNode::make(int stage_id, const std::vector& after_ids) { +ReorderStep::ReorderStep(int stage_id, const std::vector& after_ids) { auto node = make_object(); node->stage_id = stage_id; node->after_ids = after_ids; - return ReorderStep(node); + data_ = std::move(node); } void ReorderStepNode::ApplyToSchedule(std::vector *stages, @@ -155,9 +155,9 @@ std::string PrintSplitAsPythonAPI(std::vector *stages, return ss.str(); } -SplitStep SplitStepNode::make(int stage_id, int iter_id, - PrimExpr extent, const std::vector& lengths, - bool inner_to_outer) { +SplitStep::SplitStep(int stage_id, int iter_id, PrimExpr extent, + const std::vector& lengths, + bool inner_to_outer) { auto node = make_object(); node->stage_id = stage_id; // Extent can be a unreducible expression in some special cases @@ -167,7 +167,7 @@ SplitStep SplitStepNode::make(int stage_id, int iter_id, node->iter_id = iter_id; node->lengths = lengths; node->inner_to_outer = inner_to_outer; - return SplitStep(node); + data_ = std::move(node); } std::vector SplitStepNode::ApplyToSchedule( @@ -184,18 +184,19 @@ std::string SplitStepNode::PrintAsPythonAPI( } /********** Follow Split **********/ -FollowSplitStep FollowSplitStepNode::make(int stage_id, int iter_id, - int src_step_id, int n_split) { +FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, + int src_step_id, int n_split) { auto node = make_object(); node->stage_id = stage_id; node->iter_id = iter_id; node->src_step_id = src_step_id; node->n_split = n_split; - return FollowSplitStep(node); + data_ = std::move(node); } -void FollowSplitStepNode::ExtractSplitLengths(const std::vector& transform_steps, - std::vector* lengths) const { +void FollowSplitStepNode::ExtractSplitLengths( + const std::vector& transform_steps, + std::vector* lengths) const { CHECK_LT(src_step_id, transform_steps.size()); auto ps = transform_steps[src_step_id].as(); CHECK(ps != nullptr); @@ -237,15 +238,15 @@ std::string FollowSplitStepNode::PrintAsPythonAPI( } /********** Follow Fused Split **********/ -FollowFusedSplitStep FollowFusedSplitStepNode::make(int stage_id, int iter_id, - const std::vector& src_step_ids, int level, bool factor_or_nparts) { +FollowFusedSplitStep::FollowFusedSplitStep(int stage_id, int iter_id, + const std::vector& src_step_ids, int level, bool factor_or_nparts) { auto node = make_object(); node->stage_id = stage_id; node->iter_id = iter_id; node->src_step_ids = src_step_ids;; node->level = level; node->factor_or_nparts = factor_or_nparts; - return FollowFusedSplitStep(node); + data_ = std::move(node); } PrimExpr FollowFusedSplitStepNode::ExtractSplitLength( @@ -279,16 +280,16 @@ std::string FollowFusedSplitStepNode::PrintAsPythonAPI( te::Schedule *schedule, const std::vector& transform_steps) const { const PrimExpr& length = ExtractSplitLength(transform_steps); return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, - {length}, factor_or_nparts); + {length}, factor_or_nparts); } /********** Fuse **********/ -FuseStep FuseStepNode::make(int stage_id, const std::vector& fused_ids) { +FuseStep::FuseStep(int stage_id, const std::vector& fused_ids) { auto node = make_object(); 
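  // make_object<FuseStepNode>() heap-allocates the node; the fields are set
  // below, then the constructor hands ownership to data_ (the ObjectPtr
  // inherited from ObjectRef) instead of returning a fresh FuseStep(node).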
node->stage_id = stage_id; node->fused_ids = fused_ids; - return FuseStep(node); + data_ = std::move(node); } IterVar FuseStepNode::ApplyToSchedule(std::vector *stages, @@ -306,7 +307,7 @@ IterVar FuseStepNode::ApplyToSchedule(std::vector *stages, new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids[0]); new_axes.push_back(fused_axis); new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1, - axes.end()); + axes.end()); (*stage_to_axes)[stage] = std::move(new_axes); return fused_axis; @@ -337,12 +338,13 @@ std::string FuseStepNode::PrintAsPythonAPI(std::vector *stages, } /********** Annotation **********/ -AnnotationStep AnnotationStepNode::make(int stage_id, int iter_id, IteratorAnnotation ann) { +AnnotationStep::AnnotationStep(int stage_id, int iter_id, + IteratorAnnotation ann) { auto node = make_object(); node->stage_id = stage_id; node->iter_id = iter_id; node->annotation = ann; - return AnnotationStep(node); + data_ = std::move(node); } void AnnotationStepNode::ApplyToSchedule(std::vector *stages, @@ -426,12 +428,13 @@ std::string AnnotationStepNode::PrintAsPythonAPI(std::vector *stages, } /********** Compute At **********/ -ComputeAtStep ComputeAtStepNode::make(int stage_id, int target_stage_id, int target_iter_id) { +ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, + int target_iter_id) { auto node = make_object(); node->stage_id = stage_id; node->target_stage_id = target_stage_id; node->target_iter_id = target_iter_id; - return ComputeAtStep(node); + data_ = std::move(node); } void ComputeAtStepNode::ApplyToSchedule(std::vector *stages, @@ -460,10 +463,10 @@ std::string ComputeAtStepNode::PrintAsPythonAPI(std::vector *stages, } /********** Compute Root **********/ -ComputeRootStep ComputeRootStepNode::make(int stage_id) { +ComputeRootStep::ComputeRootStep(int stage_id) { auto node = make_object(); node->stage_id = stage_id; - return ComputeRootStep(node); + data_ = std::move(node); } void ComputeRootStepNode::ApplyToSchedule(std::vector *stages, @@ -485,10 +488,10 @@ std::string ComputeRootStepNode::PrintAsPythonAPI(std::vector *stages } /********** Compute Inline **********/ -ComputeInlineStep ComputeInlineStepNode::make(int stage_id) { +ComputeInlineStep::ComputeInlineStep(int stage_id) { auto node = make_object(); node->stage_id = stage_id; - return ComputeInlineStep(node); + data_ = std::move(node); } void ComputeInlineStepNode::ApplyToSchedule(std::vector *stages, @@ -511,13 +514,13 @@ std::string ComputeInlineStepNode::PrintAsPythonAPI( } /********** Cache Read **********/ -CacheReadStep CacheReadStepNode::make(int stage_id, std::string scope_name, - const std::vector& reader_stage_ids) { +CacheReadStep::CacheReadStep(int stage_id, std::string scope_name, + const std::vector& reader_stage_ids) { auto node = make_object(); node->stage_id = stage_id; node->scope_name = std::move(scope_name); node->reader_stage_ids = reader_stage_ids; - return CacheReadStep(node); + data_ = std::move(node); } te::Tensor CacheReadStepNode::ApplyToSchedule(std::vector* stages, @@ -574,11 +577,11 @@ std::string CacheReadStepNode::PrintAsPythonAPI(std::vector *stages, } /********** Cache Write **********/ -CacheWriteStep CacheWriteStepNode::make(int stage_id, std::string scope_name) { +CacheWriteStep::CacheWriteStep(int stage_id, std::string scope_name) { auto node = make_object(); node->stage_id = stage_id; node->scope_name = std::move(scope_name); - return CacheWriteStep(node); + data_ = std::move(node); } Array CacheWriteStepNode::ApplyToSchedule( @@ 
-642,13 +645,12 @@ std::string CacheWriteStepNode::PrintAsPythonAPI(std::vector *stages, } /********** Pragma **********/ -PragmaStep PragmaStepNode::make(int stage_id, int iter_id, - std::string pragma_type) { +PragmaStep::PragmaStep(int stage_id, int iter_id, std::string pragma_type) { auto node = make_object(); node->stage_id = stage_id; node->iter_id = iter_id; node->pragma_type = std::move(pragma_type); - return PragmaStep(node); + data_ = std::move(node); } void PragmaStepNode::ApplyToSchedule(std::vector *stages, @@ -692,12 +694,12 @@ std::string PragmaStepNode::PrintAsPythonAPI(std::vector *stages, } /********** Rfactor **********/ -RfactorStep RfactorStepNode::make(int stage_id, int iter_id, int factor_iter_id) { +RfactorStep::RfactorStep(int stage_id, int iter_id, int factor_iter_id) { auto node = make_object(); node->stage_id = stage_id; node->iter_id = iter_id; node->factor_iter_id = factor_iter_id; - return RfactorStep(node); + data_ = std::move(node); } Array RfactorStepNode::ApplyToSchedule(std::vector *stages, @@ -719,9 +721,9 @@ Array RfactorStepNode::ApplyToSchedule(std::vector *stage } std::string RfactorStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { + StageToAxesMap *stage_to_axes, + te::Schedule *schedule, + const std::vector& transform_steps) const { std::stringstream ss; const auto& stage = (*stages)[stage_id]; @@ -772,14 +774,14 @@ std::string RfactorStepNode::PrintAsPythonAPI(std::vector *stages, } /********** Storage Align **********/ -StorageAlignStep StorageAlignStepNode::make(int stage_id, int iter_id, - int factor, int offset) { +StorageAlignStep::StorageAlignStep(int stage_id, int iter_id, + int factor, int offset) { auto node = make_object(); node->stage_id = stage_id; node->iter_id = iter_id; node->factor = factor; node->offset = offset; - return StorageAlignStep(node); + data_ = std::move(node); } void StorageAlignStepNode::ApplyToSchedule(std::vector *stages, @@ -803,13 +805,13 @@ std::string StorageAlignStepNode::PrintAsPythonAPI( } /********** Tensorize **********/ -TensorizeStep TensorizeStepNode::make(int stage_id, int iter_id, - std::string ti_func_name) { +TensorizeStep::TensorizeStep(int stage_id, int iter_id, + std::string ti_func_name) { auto node = make_object(); node->stage_id = stage_id; node->iter_id = iter_id; node->ti_func_name = ti_func_name; - return TensorizeStep(node); + data_ = std::move(node); } void TensorizeStepNode::ApplyToSchedule(std::vector *stages, diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h index 9af14429bf61..3eb023eb75c8 100644 --- a/src/ansor/transform_step.h +++ b/src/ansor/transform_step.h @@ -21,15 +21,15 @@ * \file ansor/transform_step.h * \brief Transformation steps. For each schedule primitive, there is a corresponding transform step. * - * \Note How to add a new transform step. + * \note How to add a new transform step. * Take fuse for example: - * 1. Define class `FuseStepNode`, `FuseStep` in `transform_steps.h`, and implement its make function - * `FuseStepNode::make(...)` in `transform_steps.cc` + * 1. Define class `FuseStepNode`, `FuseStep` in `transform_steps.h`, and implement its construction + * function `FuseStep::FuseStep(...)` in `transform_steps.cc` * 2. Implement `FuseStepNode::ApplyToSchedule` and `FuseStepNode::PrintAsPythonAPI`. * - In these two functions you need to lower this step with tvm's te schedule API * 3. Implement `State::fuse` and `State::DoFuseStep`. 
 *    - In these two functions you need to incrementally update all data structures in State with
- *    CopyOnWrite style
+ *      CopyOnWrite style
 * 4. Add your step to `ComputeDAG::ReplaySteps` and make sure it works.
 * 5. Add serialization support in `struct Handler<std::vector<Step> >`
 *    in `serialization.cc`
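To make the recipe above concrete, here is a minimal sketch of steps 1-2 for a hypothetical `MyStep` (an illustration, not part of this patch: the payload field and the lowering comment are placeholders, while the object-reference boilerplate mirrors the real steps in this header):

    /*! \brief A hypothetical step following the pattern introduced by this patch */
    class MyStepNode : public StepNode {
     public:
      int iter_id;  // placeholder payload: the iterator this step acts on

      void ApplyToSchedule(std::vector<te::Stage>* stages,
                           StageToAxesMap* stage_to_axes) const {
        // Step 2: lower this step with tvm's te schedule API here,
        // e.g. call a method on (*stages)[stage_id].
      }

      static constexpr const char* _type_key = "ansor.MyStep";
      TVM_DECLARE_FINAL_OBJECT_INFO(MyStepNode, Object);
    };

    /*! \brief Managed reference to MyStepNode. */
    class MyStep : public Step {
     public:
      MyStep(int stage_id, int iter_id) {
        auto node = make_object<MyStepNode>();
        node->stage_id = stage_id;  // stage_id lives in the StepNode base
        node->iter_id = iter_id;
        data_ = std::move(node);    // the constructor replaces the old static make()
      }

      TVM_DEFINE_OBJECT_REF_METHODS(MyStep, Step, MyStepNode);
      TVM_DEFINE_OBJECT_REF_COW_METHOD(MyStepNode);
    };

Steps 3-5 (the `State` methods, `ComputeDAG::ReplaySteps`, and serialization) then only need to route this new object through the existing machinery.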
@@ -56,8 +56,6 @@ class ReorderStepNode: public StepNode {
   std::vector<int> after_ids;  // The iterator ids after reorder.
                                // This array should specify the order of all iterators.
 
-  static ReorderStep make(int stage_id, const std::vector<int>& after_ids);
-
   void ApplyToSchedule(std::vector<te::Stage> *stages,
                        StageToAxesMap *stage_to_axes) const;
 
@@ -69,7 +67,18 @@ class ReorderStepNode: public StepNode {
   static constexpr const char* _type_key = "ansor.ReorderStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, Object);
 };
-TVM_DEFINE_COW_OBJECT_REF(ReorderStep, Step, ReorderStepNode);
+
+/*!
+ * \brief Managed reference to ReorderStepNode.
+ * \sa ReorderStepNode
+ */
+class ReorderStep : public Step {
+ public:
+  ReorderStep(int stage_id, const std::vector<int>& after_ids);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(ReorderStep, Step, ReorderStepNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(ReorderStepNode);
+};
 
 /*! \brief Split step that corresponds to te::Stage::split with additional
  * support of multiple-level of factors */
@@ -81,10 +90,6 @@ class SplitStepNode: public StepNode {
   bool inner_to_outer;  // If true, the `lengths` denote the lengths of
                         // iterators from inner level to outer level
 
-  static SplitStep make(int stage_id, int iter_id, PrimExpr extent,
-                        const std::vector<PrimExpr>& lengths,
-                        bool inner_to_outer);
-
   std::vector<IterVar> ApplyToSchedule(std::vector<te::Stage> *stages,
                                        StageToAxesMap *stage_to_axes) const;
 
@@ -96,7 +101,20 @@ class SplitStepNode: public StepNode {
   static constexpr const char* _type_key = "ansor.SplitStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, Object);
 };
-TVM_DEFINE_COW_OBJECT_REF(SplitStep, Step, SplitStepNode);
+
+/*!
+ * \brief Managed reference to SplitStepNode.
+ * \sa SplitStepNode
+ */
+class SplitStep : public Step {
+ public:
+  SplitStep(int stage_id, int iter_id, PrimExpr extent,
+            const std::vector<PrimExpr>& lengths,
+            bool inner_to_outer);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(SplitStepNode);
+};
 
 /*! \brief Similar to SplitStepNode, but use split factor from another step
  * (i.e. Follow another split step) */
@@ -106,9 +124,6 @@ class FollowSplitStepNode: public StepNode {
   int src_step_id;  // The index of the split step to follow in the history
   int n_split;      // The number of split level
 
-  static FollowSplitStep make(int stage_id, int iter_id,
-                              int src_step_id, int n_split);
-
   void ExtractSplitLengths(const std::vector<Step>& transform_steps,
                            std::vector<PrimExpr>* lengths) const;
 
@@ -124,7 +139,19 @@ class FollowSplitStepNode: public StepNode {
   static constexpr const char* _type_key = "ansor.FollowSplitStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object);
 };
-TVM_DEFINE_COW_OBJECT_REF(FollowSplitStep, Step, FollowSplitStepNode);
+
+/*!
+ * \brief Managed reference to FollowSplitStepNode.
+ * \sa FollowSplitStepNode
+ */
+class FollowSplitStep : public Step {
+ public:
+  FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(FollowSplitStepNode);
+};
+
 /*! \brief Similar to FollowSplitStep, but use split factors from multiple steps.
  * \note This can be used for the split in cooperative fetching
@@ -136,10 +163,6 @@ class FollowFusedSplitStepNode: public StepNode {
   int level;              // Use the length in this split level
   bool factor_or_nparts;  // If this is true, use factor. Otherwise, use nparts
 
-  static FollowFusedSplitStep make(int stage_id, int iter_id,
-                                   const std::vector<int>& src_step_ids,
-                                   int level, bool factor_or_nparts);
-
   PrimExpr ExtractSplitLength(const std::vector<Step>& transform_steps) const;
 
   std::vector<IterVar> ApplyToSchedule(std::vector<te::Stage> *stages,
@@ -154,15 +177,26 @@ class FollowFusedSplitStepNode: public StepNode {
   static constexpr const char* _type_key = "ansor.FollowFusedSplitStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object);
 };
-TVM_DEFINE_COW_OBJECT_REF(FollowFusedSplitStep, Step, FollowFusedSplitStepNode);
+
+/*!
+ * \brief Managed reference to FollowFusedSplitStepNode.
+ * \sa FollowFusedSplitStepNode
+ */
+class FollowFusedSplitStep : public Step {
+ public:
+  FollowFusedSplitStep(int stage_id, int iter_id,
+                       const std::vector<int>& src_step_ids,
+                       int level, bool factor_or_nparts);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(FollowFusedSplitStepNode);
+};
 
 /*! \brief Fuse step that corresponds to te::Stage::fuse */
 class FuseStepNode: public StepNode {
  public:
   std::vector<int> fused_ids;  // The ids of iterators to fuse
 
-  static FuseStep make(int stage_id, const std::vector<int>& fused_ids);
-
   IterVar ApplyToSchedule(std::vector<te::Stage> *stages,
                           StageToAxesMap *stage_to_axes) const;
 
@@ -174,7 +208,18 @@ class FuseStepNode: public StepNode {
   static constexpr const char* _type_key = "ansor.FuseStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object);
 };
-TVM_DEFINE_COW_OBJECT_REF(FuseStep, Step, FuseStepNode);
+
+/*!
+ * \brief Managed reference to FuseStepNode.
+ * \sa FuseStepNode
+ */
+class FuseStep : public Step {
+ public:
+  FuseStep(int stage_id, const std::vector<int>& fused_ids);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(FuseStepNode);
+};
 
 /*! \brief Annotation step that corresponds to vectorize, parallel, unroll and thread binding.
  * (i.e. te::Stage::vectorize, te::Stage::parallel, te::Stage::unroll, te::Stage::bind)
@@ -184,8 +229,6 @@ class AnnotationStepNode: public StepNode {
   int iter_id;
   IteratorAnnotation annotation;
 
-  static AnnotationStep make(int stage_id, int iter_id, IteratorAnnotation ann);
-
   void ApplyToSchedule(std::vector<te::Stage> *stages,
                        StageToAxesMap *stage_to_axes) const;
 
@@ -197,7 +240,18 @@ class AnnotationStepNode: public StepNode {
   static constexpr const char* _type_key = "ansor.AnnotationStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, Object);
 };
-TVM_DEFINE_COW_OBJECT_REF(AnnotationStep, Step, AnnotationStepNode);
+
+/*!
+ * \brief Managed reference to AnnotationStepNode.
+ * \sa AnnotationStepNode
+ */
+class AnnotationStep : public Step {
+ public:
+  AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(AnnotationStep, Step, AnnotationStepNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(AnnotationStepNode);
+};
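The te lowering that step 2 of the note asks for is typically a small dispatch onto `te::Stage`. A hedged sketch of what `AnnotationStepNode::ApplyToSchedule` can look like (the real body lives in transform_step.cc and is not shown in this patch; the `kVectorize`/`kParallel`/`kUnroll` enum spellings are assumptions):

    void AnnotationStepNode::ApplyToSchedule(std::vector<te::Stage>* stages,
                                             StageToAxesMap* stage_to_axes) const {
      te::Stage& stage = (*stages)[stage_id];
      const IterVar& axis = (*stage_to_axes)[stage][iter_id];
      switch (annotation) {
        case kVectorize: stage.vectorize(axis); break;  // te::Stage::vectorize
        case kParallel:  stage.parallel(axis);  break;  // te::Stage::parallel
        case kUnroll:    stage.unroll(axis);    break;  // te::Stage::unroll
        default:         LOG(FATAL) << "Invalid annotation: " << annotation;
      }
    }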
 /*! \brief Compute at step that corresponds to te::Stage::compute_at */
 class ComputeAtStepNode: public StepNode {
@@ -205,9 +259,6 @@
   int target_stage_id;
   int target_iter_id;
 
-  static ComputeAtStep make(int stage_id, int target_stage_id,
-                            int target_iter_id);
-
   void ApplyToSchedule(std::vector<te::Stage> *stages,
                        StageToAxesMap *stage_to_axes) const;
 
@@ -219,12 +270,22 @@ class ComputeAtStepNode: public StepNode {
   static constexpr const char* _type_key = "ansor.ComputeAtStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, Object);
 };
-TVM_DEFINE_COW_OBJECT_REF(ComputeAtStep, Step, ComputeAtStepNode);
+
+/*!
+ * \brief Managed reference to ComputeAtStepNode.
+ * \sa ComputeAtStepNode
+ */
+class ComputeAtStep : public Step {
+ public:
+  ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(ComputeAtStep, Step, ComputeAtStepNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeAtStepNode);
+};
 
 /*! \brief Compute root step that corresponds to te::Stage::compute_root */
 class ComputeRootStepNode: public StepNode {
  public:
-  static ComputeRootStep make(int stage_id);
 
   void ApplyToSchedule(std::vector<te::Stage> *stages,
                        StageToAxesMap *stage_to_axes) const;
 
@@ -237,13 +298,22 @@ class ComputeRootStepNode: public StepNode {
   static constexpr const char* _type_key = "ansor.ComputeRootStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, Object);
 };
-TVM_DEFINE_COW_OBJECT_REF(ComputeRootStep, Step, ComputeRootStepNode);
+
+/*!
+ * \brief Managed reference to ComputeRootStepNode.
+ * \sa ComputeRootStepNode
+ */
+class ComputeRootStep : public Step {
+ public:
+  explicit ComputeRootStep(int stage_id);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeRootStepNode);
+};
 
 /*! \brief Compute inline step that corresponds to te::Stage::compute_inline */
 class ComputeInlineStepNode: public StepNode {
  public:
-  static ComputeInlineStep make(int stage_id);
-
   void ApplyToSchedule(std::vector<te::Stage> *stages,
                        StageToAxesMap *stage_to_axes) const;
 
@@ -255,7 +325,18 @@ class ComputeInlineStepNode: public StepNode {
   static constexpr const char* _type_key = "ansor.ComputeInlineStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, Object);
 };
-TVM_DEFINE_COW_OBJECT_REF(ComputeInlineStep, Step, ComputeInlineStepNode);
+
+/*!
+ * \brief Managed reference to ComputeInlineStepNode.
+ * \sa ComputeInlineStepNode
+ */
+class ComputeInlineStep : public Step {
+ public:
+  explicit ComputeInlineStep(int stage_id);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(ComputeInlineStep, Step, ComputeInlineStepNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeInlineStepNode);
+};
 
 /*! \brief Cache read step that corresponds to te::Schedule::cache_read */
 class CacheReadStepNode: public StepNode {
@@ -263,11 +344,9 @@
   std::string scope_name;
   std::vector<int> reader_stage_ids;
 
-  static CacheReadStep make(int stage_id, std::string scope_name,
-                            const std::vector<int>& reader_stage_id);
-
   te::Tensor ApplyToSchedule(std::vector<te::Stage> *stages,
-      StageToAxesMap *stage_to_axes, te::Schedule *schedule) const;
+                             StageToAxesMap *stage_to_axes,
+                             te::Schedule *schedule) const;
 
   std::string PrintAsPythonAPI(std::vector<te::Stage> *stages,
                                StageToAxesMap *stage_to_axes,
@@ -277,7 +356,19 @@
   static constexpr const char* _type_key = "ansor.CacheReadStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, Object);
 };
-TVM_DEFINE_COW_OBJECT_REF(CacheReadStep, Step, CacheReadStepNode);
+
+/*!
+ * \brief Managed reference to CacheReadStepNode.
+ * \sa CacheReadStepNode
+ */
+class CacheReadStep : public Step {
+ public:
+  CacheReadStep(int stage_id, std::string scope_name,
+                const std::vector<int>& reader_stage_id);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(CacheReadStep, Step, CacheReadStepNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(CacheReadStepNode);
+};
 
 /*! \brief Cache write step that corresponds to te::Schedule::cache_write
 * \note This step will cache_write all output tensors of target stage */
 class CacheWriteStepNode: public StepNode {
@@ -285,10 +376,9 @@
  public:
   std::string scope_name;
 
-  static CacheWriteStep make(int stage_id, std::string scope_name);
-
   Array<te::Tensor> ApplyToSchedule(std::vector<te::Stage> *stages,
-      StageToAxesMap *stage_to_axes, te::Schedule *schedule) const;
+                                    StageToAxesMap *stage_to_axes,
+                                    te::Schedule *schedule) const;
 
   std::string PrintAsPythonAPI(std::vector<te::Stage> *stages,
                                StageToAxesMap *stage_to_axes,
@@ -298,7 +388,18 @@
   static constexpr const char* _type_key = "ansor.CacheWriteStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, Object);
 };
-TVM_DEFINE_COW_OBJECT_REF(CacheWriteStep, Step, CacheWriteStepNode);
+
+/*!
+ * \brief Managed reference to CacheWriteStepNode.
+ * \sa CacheWriteStepNode
+ */
+class CacheWriteStep : public Step {
+ public:
+  CacheWriteStep(int stage_id, std::string scope_name);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(CacheWriteStep, Step, CacheWriteStepNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(CacheWriteStepNode);
+};
 
 /*! \brief Pragma step that corresponds to te::Stage::pragma */
 class PragmaStepNode: public StepNode {
@@ -306,8 +407,6 @@
   int iter_id;
   std::string pragma_type;
 
-  static PragmaStep make(int stage_id, int iter_id, std::string pragma_type);
-
   void ApplyToSchedule(std::vector<te::Stage> *stages,
                        StageToAxesMap *stage_to_axes) const;
 
@@ -319,7 +418,18 @@
   static constexpr const char* _type_key = "ansor.PragmaStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, Object);
 };
-TVM_DEFINE_COW_OBJECT_REF(PragmaStep, Step, PragmaStepNode);
+
+/*!
+ * \brief Managed reference to PragmaStepNode.
+ * \sa PragmaStepNode
+ */
+class PragmaStep : public Step {
+ public:
+  PragmaStep(int stage_id, int iter_id, std::string pragma_type);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(PragmaStep, Step, PragmaStepNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(PragmaStepNode);
+};
 
 /*! \brief Reduction factor step that corresponds to te::Schedule::rfactor */
 class RfactorStepNode: public StepNode {
@@ -327,11 +437,9 @@
   int iter_id;
   int factor_iter_id;
 
-  static RfactorStep make(int stage_id, int iter_id, int factor_iter_id);
-
   Array<te::Tensor> ApplyToSchedule(std::vector<te::Stage> *stages,
-      StageToAxesMap *stage_to_axes,
-      te::Schedule *schedule) const;
+                                    StageToAxesMap *stage_to_axes,
+                                    te::Schedule *schedule) const;
 
   std::string PrintAsPythonAPI(std::vector<te::Stage> *stages,
                                StageToAxesMap *stage_to_axes,
@@ -341,7 +449,18 @@
   static constexpr const char* _type_key = "ansor.RfactorStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, Object);
 };
-TVM_DEFINE_COW_OBJECT_REF(RfactorStep, Step, RfactorStepNode);
+
+/*!
+ * \brief Managed reference to RfactorStepNode.
+ * \sa RfactorStepNode
+ */
+class RfactorStep : public Step {
+ public:
+  RfactorStep(int stage_id, int iter_id, int factor_iter_id);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(RfactorStep, Step, RfactorStepNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(RfactorStepNode);
+};
 
 /*! \brief Storage align step that corresponds to te::Stage::storage_align */
 class StorageAlignStepNode: public StepNode {
@@ -350,9 +469,6 @@
   int factor;
   int offset;
 
-  static StorageAlignStep make(int stage_id, int iter_id, int factor,
-                               int offset);
-
   void ApplyToSchedule(std::vector<te::Stage> *stages,
                        StageToAxesMap *stage_to_axes) const;
 
@@ -364,7 +480,18 @@
   static constexpr const char* _type_key = "ansor.StorageAlignStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, Object);
 };
-TVM_DEFINE_COW_OBJECT_REF(StorageAlignStep, Step, StorageAlignStepNode);
+
+/*!
+ * \brief Managed reference to StorageAlignStepNode.
+ * \sa StorageAlignStepNode
+ */
+class StorageAlignStep : public Step {
+ public:
+  StorageAlignStep(int stage_id, int iter_id, int factor, int offset);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(StorageAlignStep, Step, StorageAlignStepNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(StorageAlignStepNode);
+};
 
 /*! \brief Tensorize step that corresponds to te::Stage::tensorize
 * \note This step takes a global registered function name as input. */
 class TensorizeStepNode: public StepNode {
@@ -373,9 +500,6 @@
   int iter_id;
   std::string ti_func_name;
 
-  static TensorizeStep make(int stage_id, int iter_id,
-                            std::string ti_func_name);
-
   void ApplyToSchedule(std::vector<te::Stage> *stages,
                        StageToAxesMap *stage_to_axes) const;
 
@@ -387,7 +511,18 @@ class TensorizeStepNode: public StepNode {
   static constexpr const char* _type_key = "ansor.TensorizeStep";
   TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeStepNode, Object);
 };
-TVM_DEFINE_COW_OBJECT_REF(TensorizeStep, Step, TensorizeStepNode);
+
+/*!
+ * \brief Managed reference to TensorizeStepNode.
+ * \sa TensorizeStepNode + */ +class TensorizeStep : public Step { + public: + TensorizeStep(int stage_id, int iter_id, std::string ti_func_name); + + TVM_DEFINE_OBJECT_REF_METHODS(TensorizeStep, Step, TensorizeStepNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(TensorizeStepNode); +}; } // namespace ansor } // namespace tvm diff --git a/tests/cpp/ansor_test.cc b/tests/cpp/ansor_test.cc index 5f1dea0f1ea5..36ac46f49551 100644 --- a/tests/cpp/ansor_test.cc +++ b/tests/cpp/ansor_test.cc @@ -79,7 +79,7 @@ using namespace tvm::ansor; // Test Access Analyzer TEST(ComputeDAG, GetProducersConsumers) { const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3); - const auto& dag = tvm::ansor::ComputeDAGNode::make(tensors); + const auto& dag = tvm::ansor::ComputeDAG(tensors); int data = 0, padding = 1, kernel = 2, conv = 3, bias = 4, bias_add = 5; int bn_scale = 6, bn_mul = 7, bn_offset = 8, bn_add = 9, relu = 10; From 8e53d125d9fcdeed6ab5c422d151438a422a12a0 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Tue, 23 Jun 2020 16:28:58 +0800 Subject: [PATCH 36/45] Some lint fix & Recover the double constructor of tvm::PrimExpr (#39) * lint fix * clang-format-fix * pylint fix * Update * Recover the double constructor of tvm::PrimExpr * Fix pylint * pylint fix * pylint fix --- include/tvm/ir/expr.h | 5 -- python/tvm/ansor/__init__.py | 1 - python/tvm/ansor/auto_schedule.py | 8 +- python/tvm/ansor/cost_model/cost_model.py | 1 - python/tvm/ansor/dispatcher.py | 2 +- python/tvm/ansor/env.py | 1 - python/tvm/ansor/feature.py | 46 +++++------ python/tvm/ansor/loop_state.py | 81 ++++++++++--------- python/tvm/ansor/measure.py | 37 ++++++--- python/tvm/ansor/relay_integration.py | 27 ++++--- python/tvm/ansor/task_scheduler.py | 12 ++- python/tvm/ansor/workload_registry.py | 8 +- python/tvm/relay/backend/compile_engine.py | 2 +- python/tvm/relay/op/strategy/x86.py | 1 - python/tvm/relay/testing/dqn.py | 6 +- python/tvm/relay/testing/resnet.py | 3 +- python/tvm/te/tensor.py | 4 +- scripts/common.py | 17 ++++ scripts/shape_configs.py | 17 ++++ scripts/tune_network.py | 17 ++++ scripts/tune_op_subgraph.py | 17 ++++ scripts/tune_test.py | 17 ++++ src/ansor/measure.cc | 1 - src/ansor/measure.h | 1 - .../search_policy/sketch_search_policy.cc | 10 ++- src/ansor/serialization.cc | 2 +- src/ansor/transform_step.h | 5 +- src/ir/expr.cc | 2 - src/relay/op/tensor/transform.cc | 2 - src/relay/transforms/defuse_ops.cc | 33 +++----- .../transforms/kernel_layout_transform.cc | 25 +++--- .../transforms/kernel_layout_transform.h | 60 ++++++++++---- src/runtime/rpc/rpc_module.cc | 5 +- src/tir/transforms/unroll_loop.cc | 5 +- 34 files changed, 300 insertions(+), 181 deletions(-) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index b3e527ca6fd9..b2ce50d91f58 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -112,11 +112,6 @@ class PrimExpr : public BaseExpr { * \param value The value to be constructed. */ TVM_DLL PrimExpr(float value); // NOLINT(*) - /*! - * \brief construct from double. - * \param value The value to be constructed. - */ - TVM_DLL PrimExpr(double value); // NOLINT(*) /*! \return the data type of this expression. */ DataType dtype() const { return static_cast(get())->dtype; } diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index c629c1049a87..edade490018c 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -21,7 +21,6 @@ from . import measure from . import serialization from . import loop_state -from . 
import auto_schedule from . import utils from . import feature from . import workload_registry diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index a03d9fdacbc2..4497bb400703 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -22,7 +22,7 @@ import tvm._ffi from tvm.runtime import Object from .measure import LocalBuilder, LocalRunner -from .cost_model import RandomModel, XGBModel +from .cost_model import RandomModel from . import _ffi_api @@ -133,7 +133,6 @@ def __init__(self, @tvm._ffi.register_object("ansor.SearchCallback") class SearchCallback(Object): """Callback function before or after search process""" - pass @tvm._ffi.register_object("ansor.PreloadMeasuredStates") @@ -262,8 +261,7 @@ def auto_schedule(workload, target=None, sch, tensors = _ffi_api.AutoScheduleByWorkloadKey( workload, target, target_host, search_policy, hardware_params, tune_option) return sch, tensors - elif isinstance(workload, SearchTask): + if isinstance(workload, SearchTask): sch, tensors = _ffi_api.AutoScheduleBySearchTask(workload, search_policy, tune_option) return sch, tensors - else: - raise ValueError("Invalid workload: " + workload + ". Expect a string or SearchTask") + raise ValueError("Invalid workload: " + workload + ". Expect a string or SearchTask") diff --git a/python/tvm/ansor/cost_model/cost_model.py b/python/tvm/ansor/cost_model/cost_model.py index 57cc53853b2e..fbfc8242488b 100644 --- a/python/tvm/ansor/cost_model/cost_model.py +++ b/python/tvm/ansor/cost_model/cost_model.py @@ -27,7 +27,6 @@ @tvm._ffi.register_object("ansor.CostModel") class CostModel(Object): """The base class for cost model""" - pass @tvm._ffi.register_object("ansor.RandomModel") diff --git a/python/tvm/ansor/dispatcher.py b/python/tvm/ansor/dispatcher.py index 0c07fd141bd2..3a5dc4e9e206 100644 --- a/python/tvm/ansor/dispatcher.py +++ b/python/tvm/ansor/dispatcher.py @@ -34,7 +34,7 @@ class DispatchContext(object): """ Base class of dispatch context. """ - current = None + current = None def __init__(self): self._old_ctx = DispatchContext.current diff --git a/python/tvm/ansor/env.py b/python/tvm/ansor/env.py index 0f35f92acbbc..56e76e26ee4f 100644 --- a/python/tvm/ansor/env.py +++ b/python/tvm/ansor/env.py @@ -23,4 +23,3 @@ def __init__(self): self.topi_in_compute_rewrite_mode = False GLOBAL_SCOPE = AutoschedulerGlobalScope() - diff --git a/python/tvm/ansor/feature.py b/python/tvm/ansor/feature.py index d9f6d297f1af..fa1b2cb07dcc 100644 --- a/python/tvm/ansor/feature.py +++ b/python/tvm/ansor/feature.py @@ -40,21 +40,20 @@ def unpack_feature(byte_arr: bytearray) -> Tuple[np.ndarray, np.ndarray, np.ndar size_of_int = 4 size_of_float = 4 - """ - The format for n records is: - { - int n; - int[n+2] sizes - - float[sizes[0]] feature for record 1 - float[sizes[1]] feature for record 2 - ... feature for record i... - float[sizes[n-1]] feature for record n - - float[sizes[n]] normalized throughput for n records - int[sizes[n+1]] task id for n records - } - """ + # The format for n records is: + # { + # int n; + # int[n+2] sizes + + # float[sizes[0]] feature for record 1 + # float[sizes[1]] feature for record 2 + # ... feature for record i... 
+ # float[sizes[n-1]] feature for record n + + # float[sizes[n]] normalized throughput for n records + # int[sizes[n+1]] task id for n records + # } + vec_len = DEFAULT_FEATURE_VEC_LEN # unpack sizes @@ -70,15 +69,14 @@ def unpack_feature(byte_arr: bytearray) -> Tuple[np.ndarray, np.ndarray, np.ndar for size in sizes[:-2]: row = [] - """ - Now we need to unpack the feature for multiple statements. - The format is: - { - int n_stmts - float[n_stmt][vec_len] feature_vecs - } - where vec_len can be calculated by `(size - 1) / n_stmts` - """ + # Now we need to unpack the feature for multiple statements. + # The format is: + # { + # int n_stmts + # float[n_stmt][vec_len] feature_vecs + # } + # where vec_len can be calculated by `(size - 1) / n_stmts` + if size == 0: # failed during lowering features.append(np.zeros((1, vec_len))) diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py index 3c60c3f09a8d..8560a57bc902 100644 --- a/python/tvm/ansor/loop_state.py +++ b/python/tvm/ansor/loop_state.py @@ -42,7 +42,6 @@ @tvm._ffi.register_object("ansor.Iterator") class Iterator(Object): """A for loop iterator""" - pass @tvm._ffi.register_object("ansor.Stage") @@ -90,8 +89,7 @@ def __getitem__(self, k): self.stages_cache = _ffi_api.StateGetStages(self.state_object) if isinstance(k, tvm.te.Tensor): return self.stages_cache[self.stage_id_map[k.op]] - else: - raise ValueError("Item must be Tensor") + raise ValueError("Item must be Tensor") def __update_tensor_stage_map(self): if not self.stages_cache: @@ -164,13 +162,13 @@ def reorder(self, stage_id, order): self.state_object = _ffi_api.StateReorder(self.state_object, stage_id, order) self.clear_cache() - def split(self, stage_id, it, lengths, inner_to_outer=True): + def split(self, stage_id, iterator, lengths, inner_to_outer=True): """ Parameters ---------- stage_id : Int The index of the stage to split - it : Iterator + iterator : Iterator The iterator to split lengths: List[Int] The split factors @@ -188,18 +186,18 @@ def split(self, stage_id, it, lengths, inner_to_outer=True): elif not isinstance(stage_id, int): raise ValueError("stage_id must be Tensor or Int") - self.state_object, res = _ffi_api.StateSplit(self.state_object, stage_id, it, lengths, + self.state_object, res = _ffi_api.StateSplit(self.state_object, stage_id, iterator, lengths, inner_to_outer) self.clear_cache() return res - def follow_split(self, stage_id, it, src_step_id, n_split): + def follow_split(self, stage_id, iterator, src_step_id, n_split): """ Parameters ---------- stage_id : Int The index of the stage to split - it : Iterator + iterator : Iterator The iterator to split src_step_id : Int The index of the split step to follow in the history @@ -216,19 +214,19 @@ def follow_split(self, stage_id, it, src_step_id, n_split): elif not isinstance(stage_id, int): raise ValueError("stage_id must be Tensor or Int") - self.state_object, res = _ffi_api.StateFollowSplit(self.state_object, stage_id, it, + self.state_object, res = _ffi_api.StateFollowSplit(self.state_object, stage_id, iterator, src_step_id, n_split) self.clear_cache() return res - def follow_fused_split(self, stage_id, it, src_step_ids, level, + def follow_fused_split(self, stage_id, iterator, src_step_ids, level, factor_or_nparts): """ Parameters ---------- stage_id : Int The index of the stage to split - it : Iterator + iterator : Iterator The iterator to split src_step_ids : List[Int] The indices of the split steps to follow in the history @@ -248,8 +246,8 @@ def follow_fused_split(self, stage_id, 
it, src_step_ids, level, elif not isinstance(stage_id, int): raise ValueError("stage_id must be Tensor or Int") - self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object, stage_id, it, - src_step_ids, level, + self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object, stage_id, + iterator, src_step_ids, level, factor_or_nparts) self.clear_cache() return res @@ -277,13 +275,13 @@ def fuse(self, stage_id, iters): self.clear_cache() return res - def vectorize(self, stage_id, it): + def vectorize(self, stage_id, iterator): """ Parameters ---------- stage_id : Int The index of the stage to vectorize - it : Iterator + iterator : Iterator The iterator to be vectorized Returns @@ -296,17 +294,17 @@ def vectorize(self, stage_id, it): elif not isinstance(stage_id, int): raise ValueError("stage_id must be Tensor or Int") - self.state_object, res = _ffi_api.StateVectorize(self.state_object, stage_id, it) + self.state_object, res = _ffi_api.StateVectorize(self.state_object, stage_id, iterator) self.clear_cache() return res - def parallel(self, stage_id, it): + def parallel(self, stage_id, iterator): """ Parameters ---------- stage_id : Int The index of the stage to parallel - it : Iterator + iterator : Iterator The iterator to be parallelized Returns @@ -319,17 +317,17 @@ def parallel(self, stage_id, it): elif not isinstance(stage_id, int): raise ValueError("stage_id must be Tensor or Int") - self.state_object, res = _ffi_api.StateParallel(self.state_object, stage_id, it) + self.state_object, res = _ffi_api.StateParallel(self.state_object, stage_id, iterator) self.clear_cache() return res - def unroll(self, stage_id, it, max_unroll=-1): + def unroll(self, stage_id, iterator, max_unroll=-1): """ Parameters ---------- stage_id : Int The index of the stage to unroll - it : Iterator + iterator : Iterator The iterator to be unrolled max_unroll: Int The maximum length of the iterator that can be unrolled @@ -344,17 +342,18 @@ def unroll(self, stage_id, it, max_unroll=-1): elif not isinstance(stage_id, int): raise ValueError("stage_id must be Tensor or Int") - self.state_object, res = _ffi_api.StateUnroll(self.state_object, stage_id, it, max_unroll) + self.state_object, res = _ffi_api.StateUnroll(self.state_object, stage_id, iterator, + max_unroll) self.clear_cache() return res - def bind_thread(self, stage_id, it, thread_name): + def bind_thread(self, stage_id, iterator, thread_name): """ Parameters ---------- stage_id : Int The index of the stage to bind - it : Iterator + iterator : Iterator The iterator to be bound thread_name : str The name of the thread (e.g. 
"blockIdx.x", "threadIdx.y", "vthread") @@ -378,7 +377,8 @@ def bind_thread(self, stage_id, it, thread_name): elif not isinstance(stage_id, int): raise ValueError("stage_id must be Tensor or Int") - self.state_object, res = _ffi_api.StateBindThread(self.state_object, stage_id, it, thread_id) + self.state_object, res = _ffi_api.StateBindThread(self.state_object, stage_id, iterator, + thread_id) self.clear_cache() return res @@ -403,7 +403,7 @@ def compute_at(self, stage_id, target_stage_id, target_iter): raise ValueError("target_stage_id must be Tensor or Int") self.state_object = _ffi_api.StateComputeAt(self.state_object, stage_id, - target_stage_id, target_iter) + target_stage_id, target_iter) self.clear_cache() def compute_root(self, stage_id): @@ -494,13 +494,13 @@ def cache_write(self, stage_id, scope_name): scope_name, self.compute_dag) return self.__insert_new_stage(new_stage_id) - def pragma(self, stage_id, it, pragma_type): + def pragma(self, stage_id, iterator, pragma_type): """ Parameters ---------- stage_id : Int The index of the stage to add pragma - it : Iterator + iterator : Iterator The iterator to add pragma pragma_type : Str """ @@ -509,16 +509,17 @@ def pragma(self, stage_id, it, pragma_type): elif not isinstance(stage_id, int): raise ValueError("stage_id must be Tensor or Int") - self.state_object = _ffi_api.StatePragma(self.state_object, stage_id, it, pragma_type) + self.state_object = _ffi_api.StatePragma(self.state_object, stage_id, iterator, + pragma_type) self.clear_cache() - def rfactor(self, stage_id, it, factor_iter_id): + def rfactor(self, stage_id, iterator, factor_iter_id): """ Parameters ---------- stage_id : Int The index of the stage to do reduction factor - it : Iterator + iterator : Iterator factor_iter_id : Int Returns @@ -531,17 +532,18 @@ def rfactor(self, stage_id, it, factor_iter_id): elif not isinstance(stage_id, int): raise ValueError("stage_id must be Tensor or Int") - self.state_object, new_stage_id = _ffi_api.StateRfactor(self.state_object, stage_id, it, - factor_iter_id, self.compute_dag) + self.state_object, new_stage_id = _ffi_api.StateRfactor(self.state_object, stage_id, + iterator, factor_iter_id, + self.compute_dag) return self.__insert_new_stage(new_stage_id) - def storage_align(self, stage_id, it, factor, offset): + def storage_align(self, stage_id, iterator, factor, offset): """ Parameters ---------- stage_id : Int The index of the stage to do storage align - it : Iterator + iterator : Iterator factor : Int offset : Int """ @@ -550,10 +552,11 @@ def storage_align(self, stage_id, it, factor, offset): elif not isinstance(stage_id, int): raise ValueError("stage_id must be Tensor or Int") - self.state_object = _ffi_api.StateStorageAlign(self.state_object, stage_id, it, factor, offset) + self.state_object = _ffi_api.StateStorageAlign(self.state_object, stage_id, iterator, + factor, offset) self.clear_cache() - def tensorize(self, stage_id, it, ti_func_name): + def tensorize(self, stage_id, iterator, ti_func_name): """ The `ti_func_name` corresponds to a global registered funcion that returns a TensorIntrin @@ -561,7 +564,7 @@ def tensorize(self, stage_id, it, ti_func_name): ---------- stage_id : Int The index of the stage to do storage align - it : Iterator + iterator : Iterator The target iterator ti_func_name : Str Tensorize intrinsic function name @@ -577,7 +580,7 @@ def tensorize(self, stage_id, it, ti_func_name): raise ValueError("stage_id must be Tensor or Int") self.state_object, res = _ffi_api.StateTensorize(self.state_object, - 
stage_id, it, + stage_id, iterator, ti_func_name) self.clear_cache() return res diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index f00fe672505d..be7d69e5ed3a 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -40,10 +40,11 @@ from tvm.autotvm.measure.measure_methods import set_cuda_target_arch from tvm.contrib import tar, ndk from . import _ffi_api -from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout, request_remote, check_remote +from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout, request_remote, \ + check_remote from .compute_dag import LayoutRewriteLevel -logger = logging.getLogger('ansor') +LOGGER = logging.getLogger('ansor') # The maximum length of error message MAX_ERROR_MSG_LEN = 512 @@ -52,7 +53,7 @@ @tvm._ffi.register_object("ansor.MeasureCallback") class MeasureCallback(Object): """Base class for measurement callback function""" - pass + @tvm._ffi.register_object("ansor.MeasureInput") class MeasureInput(Object): @@ -105,6 +106,8 @@ def __init__(self, costs, error_no, error_msg, all_cost, timestamp): @tvm._ffi.register_object("ansor.Builder") class Builder(Object): + """ Base class of Builder + """ def build(self, measure_inputs, verbose=1): """ Parameters @@ -121,6 +124,8 @@ def build(self, measure_inputs, verbose=1): @tvm._ffi.register_object("ansor.Runner") class Runner(Object): + """ Base class of Runner + """ def run(self, measure_inputs, build_results, verbose=1): """ Parameters @@ -221,7 +226,7 @@ def __init__(self, key, host, port, priority=1, number, repeat, min_repeat_ms, cooldown_interval) if check_remote(key, host, port, priority, timeout): - logger.info("Get devices for measurement successfully!") + LOGGER.info("Get devices for measurement successfully!") else: raise RuntimeError("Cannot get remote devices from the tracker. " "Please check the status of tracker by " @@ -260,7 +265,7 @@ def __init__(self, self.tracker = Tracker(host, port=9000, port_end=10000, silent=True) device_key = '$local$device$%d' % self.tracker.port self.server = Server(host, port=self.tracker.port, port_end=10000, - key=device_key, use_popen=True, silent=True, + key=device_key, use_popen=True, silent=True, tracker_addr=(self.tracker.host, self.tracker.port)) self.runner = RPCRunner(device_key, host, self.tracker.port, priority, n_parallel, timeout, number, repeat, @@ -302,6 +307,8 @@ def make_error_msg(): def local_build_worker(index): + """ Local builder function + """ # We use fork to copy arguments from a global variable. # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool measure_inputs, build_func, timeout, verbose = global_build_arguments @@ -362,7 +369,10 @@ def timed_func(): @tvm._ffi.register_func("ansor.local_builder.build") -def local_builder_build(inputs: List[MeasureInput], timeout: float, n_parallel: int, build_func: str, verbose: int): +def local_builder_build(inputs: List[MeasureInput], timeout: float, n_parallel: int, + build_func: str, verbose: int): + """ Local builder build function + """ # We use fork to copy arguments from a global variable. # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool global global_build_arguments @@ -409,6 +419,8 @@ def rpc_runner_run(inputs: List[MeasureInput], build_results: List[BuildResult], def rpc_run_worker(index): + """ ... 
+ """ inputs, build_results, key, host, port, priority, timeout, number, \ repeat, min_repeat_ms, cooldown_interval, verbose = global_run_arguments @@ -417,7 +429,8 @@ def rpc_run_worker(index): build_res = build_results[index] if build_res.error_no != MeasureErrorNo.NO_ERROR: - return (MAX_FLOAT,), build_res.error_no, build_res.error_msg, build_res.time_cost, time.time() + return (MAX_FLOAT,), build_res.error_no, build_res.error_msg, build_res.time_cost, \ + time.time() def timed_func(): tic = time.time() @@ -478,6 +491,8 @@ def timed_func(): def local_run(inputs: List[MeasureInput], build_results: List[BuildResult], timeout: float, number: int, repeat: int, min_repeat_ms: int, cooldown_interval: float, verbose: int): + """ ... + """ MAX_FLOAT = 1e10 # We use 1e10 instead of sys.float_info.max for better readability in log def timed_func(inp, build_res): @@ -522,16 +537,16 @@ def timed_func(inp, build_res): "Measure input size should be equal to build results" for inp, build_res in zip(inputs, build_results): if build_res.error_no != 0: - res = ( - MAX_FLOAT,), build_res.error_no, build_res.error_msg, build_res.time_cost, time.time() + res = (MAX_FLOAT,), build_res.error_no, build_res.error_msg, build_res.time_cost, \ + time.time() else: res = call_func_with_timeout( timeout, timed_func, args=(inp, build_res)) if isinstance(res, TimeoutError): if verbose >= 1: print("*T", end="") # Run timeout - res = ( - MAX_FLOAT,), MeasureErrorNo.RUN_TIMEOUT, None, build_res.time_cost + timeout, time.time() + res = (MAX_FLOAT,), MeasureErrorNo.RUN_TIMEOUT, None, \ + build_res.time_cost + timeout, time.time() measure_results.append(MeasureResult(*res)) if verbose >= 1: diff --git a/python/tvm/ansor/relay_integration.py b/python/tvm/ansor/relay_integration.py index 3c2eabd3dfac..f2873f8c72fd 100644 --- a/python/tvm/ansor/relay_integration.py +++ b/python/tvm/ansor/relay_integration.py @@ -25,21 +25,22 @@ import json import threading -from tvm import target, te, transform +import tvm +from tvm import te, transform from tvm.te.tensor import PlaceholderOp, ComputeOp from .dispatcher import DispatchContext from .workload_registry import register_workload_bufs, compute_dag_hash from .compute_dag import ComputeDAG, LayoutRewriteLevel from .env import GLOBAL_SCOPE -def call_all_topi_funcs(mod, target, params): +def call_all_topi_funcs(mod, target, params, target_host=None): """Call all TOPI compute + schedule to extract tasks in a relay program""" # pylint: disable=import-outside-toplevel from tvm import relay with transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): bld_mod = relay.build_module.BuildModule() - bld_mod.call_all_topi_funcs(mod, target=target, params=params) + bld_mod.call_all_topi_funcs(mod, target=target, params=params, target_host=target_host) def extract_from_program(mod, params, target, target_host=None): """ Extract tuning tasks from a relay program. @@ -95,7 +96,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None): # wrap build call in a new thread to avoid the conflict # between python's multiprocessing and tvm's thread pool build_thread = threading.Thread(target=call_all_topi_funcs, - args=(mod, target, param)) + args=(mod, target, param, target_host)) build_thread.start() build_thread.join() relay.backend.compile_engine.get().clear() @@ -112,7 +113,8 @@ def extract_from_multiple_program(mods, params, target, target_host=None): def prepare_layout_rewrite(mod, params, target): """ - Prepare for kernel layout rewrite. 
This function will write layout infos to a global static variable. + Prepare for kernel layout rewrite. This function will write layout infos to a global static + variable. Then these layout info will be used by a relay pass `kernel_layout_transform`. """ # pylint: disable=import-outside-toplevel @@ -207,26 +209,26 @@ def auto_schedule_topi(outs): env = TracingEnvironment.current if env is None: # in the final build mode - state = DispatchContext.current.query(target.Target.current(), key) + state = DispatchContext.current.query(tvm.target.Target.current(), key) if state is None: return te.create_schedule([x.op for x in outs]) dag = ComputeDAG(io_tensors) # Only update compute body, layout_rewrite_level = LayoutRewriteLevel.COMPUTE_REWRITE, # Since kernel layout has already been rewritten in relay pass - schedule, _ = dag.apply_steps_from_state(state, - layout_rewrite_level=LayoutRewriteLevel.COMPUTE_REWRITE) + schedule, _ = dag.apply_steps_from_state( + state, layout_rewrite_level=LayoutRewriteLevel.COMPUTE_REWRITE) return schedule - elif env.tracing_mode == TracingMode.EXTRACT_TASK: # in the task extraction mode + if env.tracing_mode == TracingMode.EXTRACT_TASK: # in the task extraction mode env.add_workload_key(key) return te.create_schedule([x.op for x in outs]) - elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE: + if env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE: # in prepare_layout_rewrite mode if has_layout_free: # Rewrite the DAG and update the transform history for # the new dag in DispatchContext dispatch_ctx = DispatchContext.current - tgt = target.Target.current() + tgt = tvm.target.Target.current() state = dispatch_ctx.query(tgt, key) assert state is not None dag = ComputeDAG(outs) @@ -236,5 +238,4 @@ def auto_schedule_topi(outs): if new_key != key: env.layout_rewrite_success_ct += 1 return te.create_schedule([x.op for x in outs]) - else: - raise ValueError("Invalid tracing mode: " + env.tracing_mode) + raise ValueError("Invalid tracing mode: " + env.tracing_mode) diff --git a/python/tvm/ansor/task_scheduler.py b/python/tvm/ansor/task_scheduler.py index 587fe3121e88..5b916ed39769 100644 --- a/python/tvm/ansor/task_scheduler.py +++ b/python/tvm/ansor/task_scheduler.py @@ -41,6 +41,8 @@ def compute_score(self, costs: List[float]) -> float: def get_search_policies(search_policy: Union[str, List[SearchPolicy]], tasks: List[SearchTask], num_measure_per_iter, load_model_file=None, load_log_file=None): + """ ... + """ if search_policy == 'default': search_policy = 'sketch.xgb' @@ -98,7 +100,8 @@ class SimpleTaskScheduler(TaskScheduler): load_log_file: str Load history log file to pre-train cost model eps-random: float - Always allocate this percent of n_trials to select tasks randomly. This is for encouraging exploration. + Always allocate this percent of n_trials to select tasks randomly. + This is for encouraging exploration. verbose: int The level of verbosity. 0 means silent. alpha: float @@ -144,7 +147,8 @@ def __init__(self, self.sequential_now_task_idx = 0 self.sequential_now_task_begin_ct = 0 - def tune(self, tune_option: TuneOption, search_policy: Union[str, List[SearchPolicy]] = 'default'): + def tune(self, tune_option: TuneOption, + search_policy: Union[str, List[SearchPolicy]] = 'default'): """ Tune tasks. Notice: This method does not have return value, make sure to set `LogToFile` @@ -252,6 +256,8 @@ def tune(self, tune_option: TuneOption, search_policy: Union[str, List[SearchPol self.tune_task(task_idx) def tune_task(self, task_idx): + """ ... 
+ """ if self.use_debug_measurement_simulator is not None: measure_inputs, measure_results = \ self.use_debug_measurement_simulator.get_next_batch( @@ -282,7 +288,7 @@ def tune_task(self, task_idx): if self.verbose >= 1: print(("TaskScheduler\tct: %d\testimated cost (ms): %.3f\ttime elapsed: %.2f\t" + - "best_costs (ms): %s\ttask_ct: %s") % + "best_costs (ms): %s\ttask_ct: %s") % (self.ct, self.cur_score * 1e3, time.time() - self.tic, to_str_round(self.best_costs * 1e3, decimal=3), self.task_cts)) diff --git a/python/tvm/ansor/workload_registry.py b/python/tvm/ansor/workload_registry.py index e706c0ec4cf9..025b5f03c661 100644 --- a/python/tvm/ansor/workload_registry.py +++ b/python/tvm/ansor/workload_registry.py @@ -23,7 +23,8 @@ The dag should be the return value of this `func_name(*args)`. Rationale: The workload is actually a compute dag defined by tvm dsl. But serializing compute dags -and matching them efficiently is not easy. Therefore, we use the above string to encode a compute dag. +and matching them efficiently is not easy. Therefore, we use the above string to encode a compute +dag. These strings are efficient for serialization/matching and wont' be too long. When we need the dag, we decode the string and call the function, which will return the dag. """ @@ -65,6 +66,8 @@ def matmul(N, M, K): def compute_dag_hash(dag: ComputeDAG): + """ Get hash value for a ComputeDAG + """ # todo: implement this more carefully and move this to c++ as a member function of ComputeDAG str_key = '' for op in dag.ops: @@ -139,8 +142,7 @@ def workload_key_to_tensors(workload_key: str) -> List[Tensor]: if callable(lookup): args = deserialize_args(workload[1:]) return lookup(*args) - else: - return lookup + return lookup @ tvm._ffi.register_func("ansor.workload_key_to_dag") diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index 66ef5cd4c852..b6bedb411540 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -18,10 +18,10 @@ """Backend code generation engine.""" from __future__ import absolute_import +import os import logging import numpy as np import tvm -import os from tvm import te from tvm.runtime import Object from ... import target as _target diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 2a0ddd1329b5..3453b089f373 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -17,7 +17,6 @@ """Definition of x86 operator strategy.""" # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import -import os from tvm.te import SpecializedCondition from tvm import ansor from .generic import * diff --git a/python/tvm/relay/testing/dqn.py b/python/tvm/relay/testing/dqn.py index b65e0ad5cae9..3d6883362c9b 100644 --- a/python/tvm/relay/testing/dqn.py +++ b/python/tvm/relay/testing/dqn.py @@ -63,7 +63,8 @@ def get_net(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32" return relay.Function(args, dense2) -def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32", layout="NCHW"): +def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32", + layout="NCHW"): """Get benchmark workload for a Deep Q Network Parameters ---------- @@ -82,5 +83,6 @@ def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="flo params : dict of str to NDArray The parameters. 
""" - net = get_net(batch_size, num_actions=num_actions, image_shape=image_shape, dtype=dtype, layout=layout) + net = get_net(batch_size, num_actions=num_actions, image_shape=image_shape, dtype=dtype, + layout=layout) return create_workload(net) diff --git a/python/tvm/relay/testing/resnet.py b/python/tvm/relay/testing/resnet.py index 4383157d9f06..ac63afde4cba 100644 --- a/python/tvm/relay/testing/resnet.py +++ b/python/tvm/relay/testing/resnet.py @@ -163,7 +163,8 @@ def resnet(units, num_unit = len(units) assert num_unit == num_stages data = relay.var("data", shape=data_shape, dtype=dtype) - data = layers.batch_norm_infer(data=data, epsilon=2e-5, axis=bn_axis, scale=False, name='bn_data') + data = layers.batch_norm_infer(data=data, epsilon=2e-5, axis=bn_axis, scale=False, + name='bn_data') (_, _, height, _) = data_shape if layout == "NHWC": (_, height, _, _) = data_shape diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py index 6539aabaa48f..6a2120817eb1 100644 --- a/python/tvm/te/tensor.py +++ b/python/tvm/te/tensor.py @@ -56,9 +56,9 @@ class Tensor(DataProducer, _expr.ExprOp): """Tensor object, to construct, see function.Tensor""" def __call__(self, *indices): - ndim = self.ndim + # ndim = self.ndim # After ansor kernel layout rewrite, len(indices) <= ndim, - # and the indices will get modified by Ansor during schedule generation. + # and the indices will get modified by Ansor during schedule generation. # if len(indices) != ndim: # raise ValueError("Need to provide %d index in tensor slice" % ndim) indices = convert_to_object(indices) diff --git a/scripts/common.py b/scripts/common.py index 8f4fbec09dd0..ac25b28e55b1 100644 --- a/scripts/common.py +++ b/scripts/common.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + """Common utility for scripts""" import argparse import math diff --git a/scripts/shape_configs.py b/scripts/shape_configs.py index 244638f5b29c..db6b3b9dc9aa 100644 --- a/scripts/shape_configs.py +++ b/scripts/shape_configs.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + """ Shape configurations for single operator / subgraph evaluation This file is shared by tune_op_subgraph.py and scripts in scripts/baseline/ """ diff --git a/scripts/tune_network.py b/scripts/tune_network.py index 1905d8132003..188da6cbe6e6 100644 --- a/scripts/tune_network.py +++ b/scripts/tune_network.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + """Tune a whole neural network""" import argparse import logging diff --git a/scripts/tune_op_subgraph.py b/scripts/tune_op_subgraph.py index 6574bb77e510..d3e70501873e 100644 --- a/scripts/tune_op_subgraph.py +++ b/scripts/tune_op_subgraph.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + """Tune all workloads for single op & subgraph evaluation""" import argparse import logging diff --git a/scripts/tune_test.py b/scripts/tune_test.py index 67c0526dd624..c98da3eca53b 100644 --- a/scripts/tune_test.py +++ b/scripts/tune_test.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ """Use auto scheduler to tune workloads""" import argparse import logging diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index 4ae35fb410a9..a044acfe5395 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2020 by Contributors * \file ansor/measure.cc * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs */ diff --git a/src/ansor/measure.h b/src/ansor/measure.h index a6db55f6181e..760a1542944f 100644 --- a/src/ansor/measure.h +++ b/src/ansor/measure.h @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2020 by Contributors * \file ansor/measure.h * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs */ diff --git a/src/ansor/search_policy/sketch_search_policy.cc b/src/ansor/search_policy/sketch_search_policy.cc index 5b2c10c08c81..63f75cad1c83 100644 --- a/src/ansor/search_policy/sketch_search_policy.cc +++ b/src/ansor/search_policy/sketch_search_policy.cc @@ -901,7 +901,7 @@ int InitPopulationCooperativeFetching(const SketchSearchPolicyNode* policy, int InitPopulationChangeComputeLocation(const SketchSearchPolicyNode* policy, State* state, std::mt19937* rand_gen) { - if(GetIntParam(policy->params, "disable_change_compute_location")) { + if (GetIntParam(policy->params, "disable_change_compute_location")) { return 0; } @@ -1063,7 +1063,8 @@ int InitPopulationChangeComputeLocation(const SketchSearchPolicyNode* policy, int InitPopulationParallel(const SketchSearchPolicyNode* policy, State* state) { - std::function annotate_parallel; + std::function + annotate_parallel; annotate_parallel = [&annotate_parallel]( const SketchSearchPolicyNode* policy, State* state, int stage_id, int iter_offset) { @@ -1095,7 +1096,8 @@ int InitPopulationParallel(const SketchSearchPolicyNode* policy, } if (parallel_degree == 1) { - auto res = (*state)->attach_map->iter_to_attached_stages.find(std::make_pair(stage_id, iter_id)); + auto res = + (*state)->attach_map->iter_to_attached_stages.find(std::make_pair(stage_id, iter_id)); if (res != (*state)->attach_map->iter_to_attached_stages.end()) { for (int attached_stage_id : res->second) { annotate_parallel(policy, state, attached_stage_id, 0); @@ -1188,7 +1190,7 @@ int InitPopulationVectorization(const SketchSearchPolicyNode* policy, } if (num_fusible > 1) { - num_fusible = 1 + (*rand_gen)() % (num_fusible - 1); // Select a random range to fuse + num_fusible = 1 + (*rand_gen)() % (num_fusible - 1); // Select a random range to fuse } if (num_fusible == 1) { diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index 71fba764506f..c026b9b6251a 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -434,7 +434,7 @@ struct Handler<::tvm::ansor::MeasureResultNode> { reader->Read(&tmp); data->costs.clear(); for (const auto& i : tmp) { - data->costs.push_back(i); + data->costs.push_back(::tvm::FloatImm(::tvm::DataType::Float(64), i)); } s = reader->NextArrayItem(); CHECK(s); reader->Read(&data->error_no); diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h index 3eb023eb75c8..edd71732b3e2 100644 --- a/src/ansor/transform_step.h +++ b/src/ansor/transform_step.h @@ -277,7 +277,7 @@ class ComputeAtStepNode: public StepNode { */ class ComputeAtStep : public Step { public: - ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id); + ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id); TVM_DEFINE_OBJECT_REF_METHODS(ComputeAtStep, Step, ComputeAtStepNode); 
TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeAtStepNode); @@ -286,7 +286,6 @@ class ComputeAtStep : public Step { /*! \brief Fuse step that corresponds to te::Stage::compute_root */ class ComputeRootStepNode: public StepNode { public: - void ApplyToSchedule(std::vector *stages, StageToAxesMap *stage_to_axes) const; @@ -550,8 +549,8 @@ struct hash<::tvm::ansor::Step> { } else { ret = ::dmlc::HashCombine(ret, 0x5D); // a magic number } - return ret; } + return ret; } else if (auto ps = step.as<::tvm::ansor::FollowSplitStepNode>()) { return ::dmlc::HashCombine(3, ::dmlc::HashCombine(std::hash()(ps->stage_id), diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 6e898dd5ddb4..fd380aa33f86 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -38,8 +38,6 @@ PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm(DataType::Int(32), value)) { PrimExpr::PrimExpr(float value) : PrimExpr(FloatImm(DataType::Float(32), value)) {} -PrimExpr::PrimExpr(double value) : PrimExpr(FloatImm(DataType::Float(64), value)) {} - PrimExpr PrimExpr::FromObject_(ObjectRef ref) { using runtime::ObjectTypeChecker; if (auto* ptr = ref.as()) { diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 18ace14a0b75..30269b85795f 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2461,7 +2461,6 @@ TVM_REGISTER_NODE_TYPE(KernelLayoutTransformAttrs); Array KernelLayoutTransformCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - //const Target& target) { const auto* param = attrs.as(); CHECK(param != nullptr); return Array{ @@ -2473,7 +2472,6 @@ bool KernelLayoutTransformRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - const auto* data = types[0].as(); CHECK(data != nullptr); const KernelLayoutTransformAttrs* params = attrs.as(); diff --git a/src/relay/transforms/defuse_ops.cc b/src/relay/transforms/defuse_ops.cc index f7c9037df687..1a108fb08888 100644 --- a/src/relay/transforms/defuse_ops.cc +++ b/src/relay/transforms/defuse_ops.cc @@ -17,19 +17,19 @@ * under the License. 
*/ -#include #include -#include -#include #include +#include +#include #include #include -#include -#include + #include #include -#include +#include #include +#include +#include #include "pattern_util.h" @@ -38,14 +38,11 @@ namespace relay { class DefuseOpsMutator : public ExprMutator { public: - class FuncBodyMutator : public ExprMutator { public: Array args_; - FuncBodyMutator(const Array& args) : ExprMutator() { - args_ = args; - } + explicit FuncBodyMutator(const Array& args) : ExprMutator() { args_ = args; } Expr VisitExpr_(const VarNode* n) { const std::string& name = n->name_hint(); @@ -74,23 +71,19 @@ class DefuseOpsMutator : public ExprMutator { } }; -Expr DeFuseOps(const Expr& expr) { - return DefuseOpsMutator().Mutate(expr); -} +Expr DeFuseOps(const Expr& expr) { return DefuseOpsMutator().Mutate(expr); } namespace transform { Pass DeFuseOps() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(relay::DeFuseOps(f)); - }; - return CreateFunctionPass(pass_func, 3, "DeFuseOps", - {"InferType"}); + [=](Function f, IRModule m, PassContext pc) { + return Downcast(relay::DeFuseOps(f)); + }; + return CreateFunctionPass(pass_func, 3, "DeFuseOps", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.DeFuseOps") -.set_body_typed(DeFuseOps); +TVM_REGISTER_GLOBAL("relay._transform.DeFuseOps").set_body_typed(DeFuseOps); } // namespace transform diff --git a/src/relay/transforms/kernel_layout_transform.cc b/src/relay/transforms/kernel_layout_transform.cc index 681785c8123c..421968b8a6b9 100644 --- a/src/relay/transforms/kernel_layout_transform.cc +++ b/src/relay/transforms/kernel_layout_transform.cc @@ -17,13 +17,17 @@ * under the License. */ +#include "kernel_layout_transform.h" + +#include #include -#include #include -#include +#include #include + +#include #include -#include "kernel_layout_transform.h" +#include namespace tvm { namespace relay { @@ -36,7 +40,8 @@ Expr KernelLayoutTransform(const Expr& expr) { KernelLayoutVisitor visitor; // Do a pre-order DFS to gather the optimal kernel layouts for all conv2d nodes. - // These layouts were written to global static variables in python function `prepare_layout_rewrite` + // These layouts were written to global static variables in python function + // `prepare_layout_rewrite` visitor.VisitExpr(expr); // Do a post-order DSF to mutate layout for all conv2d nodes @@ -47,15 +52,13 @@ namespace transform { Pass KernelLayoutTransform() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(relay::KernelLayoutTransform(f)); - }; - return CreateFunctionPass(pass_func, 3, "KernelLayoutTransform", - {"InferType"}); + [=](Function f, IRModule m, PassContext pc) { + return Downcast(relay::KernelLayoutTransform(f)); + }; + return CreateFunctionPass(pass_func, 3, "KernelLayoutTransform", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.KernelLayoutTransform") -.set_body_typed(KernelLayoutTransform); +TVM_REGISTER_GLOBAL("relay._transform.KernelLayoutTransform").set_body_typed(KernelLayoutTransform); } // namespace transform diff --git a/src/relay/transforms/kernel_layout_transform.h b/src/relay/transforms/kernel_layout_transform.h index c82a96b30612..c6c38fb71cf4 100644 --- a/src/relay/transforms/kernel_layout_transform.h +++ b/src/relay/transforms/kernel_layout_transform.h @@ -1,11 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_RELAY_TRANSFORMS_KERNEL_LAYOUT_TRANSFORM_H_ +#define TVM_RELAY_TRANSFORMS_KERNEL_LAYOUT_TRANSFORM_H_ + #include #include + +#include +#include #include #include - -#include "pattern_util.h" +#include #include "../../ansor/compute_dag.h" +#include "pattern_util.h" namespace tvm { namespace relay { @@ -13,10 +37,11 @@ namespace relay { /*! \brief A visitor to gather the optimal kernel layout for all conv2d nodes. */ class KernelLayoutVisitor : public ExprVisitor { public: - void VisitExpr_(const CallNode *n) { + void VisitExpr_(const CallNode* n) { if (n && n->op.as() && (std::find(op_white_lists.begin(), op_white_lists.end(), n->op.as()->name) != - op_white_lists.end()) && n->args[1]->type_as()->shape[3].as()->value > 1 && + op_white_lists.end()) && + n->args[1]->type_as()->shape[3].as()->value > 1 && !global_ori_layouts_queue.empty() && !global_new_layouts_queue.empty()) { ori_layouts_map[n] = global_ori_layouts_queue.front(); new_layouts_map[n] = global_new_layouts_queue.front(); @@ -28,30 +53,31 @@ class KernelLayoutVisitor : public ExprVisitor { ExprVisitor::VisitExpr_(n); } - std::unordered_map ori_layouts_map; - std::unordered_map new_layouts_map; - std::vector op_white_lists {"nn.contrib_conv2d_winograd_without_weight_transform", - "nn.conv2d", "nn.conv3d"}; + std::unordered_map ori_layouts_map; + std::unordered_map new_layouts_map; + std::vector op_white_lists{"nn.contrib_conv2d_winograd_without_weight_transform", + "nn.conv2d", "nn.conv3d"}; static std::deque global_ori_layouts_queue; static std::deque global_new_layouts_queue; }; - /*! 
\brief A mutator to rewrite kernel layout for all conv2d nodes */ class KernelLayoutTransformer : public ExprMutator { public: - KernelLayoutTransformer(KernelLayoutVisitor* visitor): ExprMutator(), visitor_(visitor) {} + explicit KernelLayoutTransformer(KernelLayoutVisitor* visitor) + : ExprMutator(), visitor_(visitor) {} Expr VisitExpr_(const CallNode* n) { auto new_n = ExprMutator::VisitExpr_(n); const auto* call = new_n.as(); - std::vector op_white_lists {"nn.contrib_conv2d_winograd_without_weight_transform", - "nn.conv2d", "nn.conv3d"}; + std::vector op_white_lists{"nn.contrib_conv2d_winograd_without_weight_transform", + "nn.conv2d", "nn.conv3d"}; if (call && call->op.as() && (std::find(op_white_lists.begin(), op_white_lists.end(), n->op.as()->name) != - op_white_lists.end() && n->args[1]->type_as()->shape[3].as()->value > 1)) { + op_white_lists.end() && + n->args[1]->type_as()->shape[3].as()->value > 1)) { auto ori_layout_iter = visitor_->ori_layouts_map.find(n); auto new_layout_iter = visitor_->new_layouts_map.find(n); if (ori_layout_iter != visitor_->ori_layouts_map.end() && @@ -60,8 +86,7 @@ class KernelLayoutTransformer : public ExprMutator { const std::string& new_layout = new_layout_iter->second; Expr updated_kernel = MakeKernelLayoutTransform(call->args[1], ori_layout, new_layout); Array updated_args = {call->args[0], updated_kernel}; - new_n = Call(call->op, updated_args, - call->attrs); + new_n = Call(call->op, updated_args, call->attrs); } } return new_n; @@ -71,6 +96,7 @@ class KernelLayoutTransformer : public ExprMutator { KernelLayoutVisitor* visitor_; }; +} // namespace relay +} // namespace tvm -} // namespace relay -} // namespace tvm +#endif // TVM_RELAY_TRANSFORMS_KERNEL_LAYOUT_TRANSFORM_H_ diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index b95d5ba25926..d58130d700f4 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -306,8 +306,7 @@ std::shared_ptr RPCModuleGetSession(Module mod) { } inline void CacheFlush(const char* p, unsigned int allocation_size) { -// TODO: (FrozenGene) -// Support ARM. +// TODO(FrozenGene): Support ARM. 
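// Flushing the input buffers out of cache before a timed run keeps repeated
// measurements from reporting warm-cache performance for small kernels; the
// x86 path below walks each allocation one 64-byte cache line at a time.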
#if (defined(_M_X64) || defined(__x86_64__)) size_t cache_line = 64; @@ -346,7 +345,7 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, TVMContext ctx, int number, int repe CHECK_EQ(number, 1); // we want to keep input data for (int j = 1; j < args.size(); j++) { - CacheFlush((char*)(args[j].operator DLTensor*()->data), + CacheFlush(reinterpret_cast(args[j].operator DLTensor*()->data), GetDataSize(*(args[j].operator DLTensor*()))); } } diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index 3876d67b7b11..4f1078165f34 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -59,7 +59,7 @@ struct UnrollLoopConfigNode : public tvm::AttrsNode { .describe("Whether to explicitly unroll the loop instead of setting a pragma") .set_default(true); TVM_ATTR_FIELD(explicit_unroll_max_extent) - .describe("The maximum extent of a loop that can be unrolled explicitly (-1 means infinite)") + .describe("The maximum extent of a loop that can be unrolled explicitly (-1 for infinite)") .set_default(32); } }; @@ -170,7 +170,8 @@ class LoopUnroller : public StmtExprMutator { // For loop must have a constant integer extent CHECK_NE(value, -1) << "loop doesn't have a constant integer extent"; if (value == 0) return Evaluate(0); - if (explicit_unroll_max_extent_ > 0 && value > explicit_unroll_max_extent_ && explicit_unroll_) { + if (explicit_unroll_max_extent_ > 0 && value > explicit_unroll_max_extent_ && + explicit_unroll_) { // Do not unroll too long loops ForType for_type = op->for_type == ForType::Unrolled ? ForType::Serial : op->for_type; return For(op->loop_var, op->min, op->extent, for_type, op->device_api, op->body); From cd5c5ad71dea12d1dd51b9db913c525329949dcf Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 23 Jun 2020 12:09:05 -0700 Subject: [PATCH 37/45] Add MutateComputeLocation and MutateParallel in evolutionary search (#40) * Add MutateComputeLocation and MutateParallel in evolutionary search * fix lint --- src/ansor/auto_schedule.h | 11 +- src/ansor/compute_dag.cc | 66 ---- src/ansor/loop_state.cc | 5 +- src/ansor/loop_state.h | 27 +- src/ansor/measure.cc | 7 +- src/ansor/search_policy/search_policy.cc | 1 - .../search_policy/sketch_search_policy.cc | 9 +- src/ansor/search_policy/utils.cc | 345 +++++++++++++++++- src/ansor/search_policy/utils.h | 21 +- src/ansor/search_task.h | 1 - src/ansor/serialization.cc | 25 +- src/ansor/transform_step.cc | 3 +- src/ansor/transform_step.h | 25 +- src/ansor/utils.h | 18 - 14 files changed, 389 insertions(+), 175 deletions(-) diff --git a/src/ansor/auto_schedule.h b/src/ansor/auto_schedule.h index f17c043cfadd..7ffd2c4d3a70 100644 --- a/src/ansor/auto_schedule.h +++ b/src/ansor/auto_schedule.h @@ -37,14 +37,11 @@ namespace ansor { class TuneOptionNode : public Object { public: int n_trials; // Number of total measurement trials - int early_stopping; // Stops early the tuning if no improvement after n - // measurements - int num_measure_per_iter; // The number of programs to be measured at each - // iteration + int early_stopping; // Stops early the tuning if no improvement after n measurements + int num_measure_per_iter; // The number of programs to be measured at each iteration int verbose; // Verbosity level. 0 means silent. 
Builder builder; // Builder which builds the program - Runner runner; // Runner which runs the program and measure time - // costs + Runner runner; // Runner which runs the program and measure time costs Array measure_callbacks; // MeasureCallback functions Array pre_search_callbacks; // SearchCallback functions // run before search @@ -76,13 +73,13 @@ class TuneOption : public ObjectRef { Array pre_search_callbacks); TVM_DEFINE_OBJECT_REF_METHODS(TuneOption, ObjectRef, TuneOptionNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(TuneOptionNode); }; /*! \brief Auto schedule for a compute declaration */ std::pair > AutoSchedule( SearchTask task, SearchPolicy search_policy, TuneOption tune_option); +/*! \brief Auto schedule for a compute declaration */ std::pair > AutoSchedule( std::string workload_key, Target target, Target target_host, SearchPolicy search_policy, HardwareParams hardware_params, diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 13f64b2bdc89..ee87318cdd84 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -653,63 +653,6 @@ class IndexRewriter : public StmtExprMutator { return GetRef(op); } - /* - PrimExpr Mutate_(const Call* op, const PrimExpr& e) { - PrimExpr op_ = IRMutator::Mutate_(op, e); - - const Call* call = op_.as(); - - if (call->call_type == Call::CallType::Halide) { - te::Tensor t = Downcast(call->func).output(call->value_index); - auto it = placeholder_new_names_.find(t->op); - if (it != placeholder_new_names_.end()) { - const std::vector& new_names = it->second; - const Array& new_shape = placeholder_new_shapes_.at(t->op); - std::unordered_map name_to_arg; - for (const auto& arg : call->args) { - std::string axis_name; - if (const auto* pimm = arg.as()) { - CHECK_EQ(pimm->value, 0); - axis_name = "IntImm"; - } else { - axis_name = BaseName(CleanName(Downcast(arg)->name_hint)); - CHECK_EQ(name_to_arg.count(axis_name), 0); - name_to_arg[axis_name] = arg; - } - } - - std::unordered_map div_factors; - std::vector r_new_args; - for (int i = new_names.size() - 1; i >= 0; --i) { - auto ori_iter_name = new_names[i]; - auto name_it = name_to_arg.find(ori_iter_name); - CHECK(name_it != name_to_arg.end()); - PrimExpr ori_arg = name_it->second; - - PrimExpr mod_factor = new_shape[i]; - - PrimExpr div_factor = 1; - if (div_factors.count(ori_iter_name)) { - div_factor = div_factors[ori_iter_name]; - } - div_factors[ori_iter_name] = div_factor * new_shape[i]; - - PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor); - - r_new_args.push_back(new_arg); - } - - Array new_args(std::make_move_iterator(r_new_args.rbegin()), - std::make_move_iterator(r_new_args.rend())); - - return Call::make(call->type, call->name, new_args, call->call_type, - call->func, call->value_index); - } - } - return op_; - } - */ - private: const OperationMap >& placeholder_new_names_; const OperationMap >& placeholder_new_shapes_; @@ -1345,15 +1288,6 @@ TVM_REGISTER_GLOBAL("ansor.ComputeDAGApplyStepsFromState") std::tie(sch, return_tensors) = dag.ApplySteps(state->transform_steps, layout_rewrite_level); *ret = Array{sch, return_tensors}; }); -/* -TVM_REGISTER_GLOBAL("ansor.ComputeDAGApplyStepsFromState") -.set_body_typed([](const ComputeDAG& dag, const State& state) { - te::Schedule sch; - Array return_tensors; - std::tie(sch, return_tensors) = dag.ApplySteps(state->transform_steps); - return Array{sch, return_tensors}; -}); -*/ TVM_REGISTER_GLOBAL("ansor.ComputeDAGPrintPythonCodeFromState") .set_body_typed([](const ComputeDAG& dag, const State& state) { 
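For reference, a minimal sketch of how the TuneOption/AutoSchedule interface above is meant to be driven; the task, policy, builder, and runner here are assumed to be constructed elsewhere, and the trial counts are illustrative only, not part of this patch:

  // Hedged usage sketch following the declarations in auto_schedule.h above.
  TuneOption option(/*n_trials=*/200, /*early_stopping=*/50, /*num_measure_per_iter=*/32,
                    /*verbose=*/1, builder, runner,
                    /*measure_callbacks=*/{}, /*pre_search_callbacks=*/{});
  te::Schedule sch;
  Array<te::Tensor> tensors;
  std::tie(sch, tensors) = AutoSchedule(task, policy, option);

The second AutoSchedule overload takes a workload key and targets instead of a SearchTask, so callers that only have a serialized workload need not construct the task themselves.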
diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index ef4c4632e9bf..010e5f3dc221 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -18,8 +18,9 @@ */ /*! - * \file ansor/loop_state.h - * \brief An IR (intermediate representation) for loop structures. + * \file ansor/loop_state.cc + * \brief A lightweight IR (intermediate representation) for loop structures. + * See ansor/loop_state.h for more explanation. */ #include "loop_state.h" diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h index 2d64db11fc18..1b7bbc40bb31 100644 --- a/src/ansor/loop_state.h +++ b/src/ansor/loop_state.h @@ -27,10 +27,10 @@ * Basically this is a simplified TVM IR with schedule primitives. * We don't use the existing TVM IR because * 1. We want fast incremental change to the loop structures - * 2. We want serializable history for replay and backtracking - * 3. We may create some Macro schedule primitives + * 2. We want serializable transformation history for replay, backtracking, and mutation. + * 3. We may create some macro schedule primitives * - * After search is done, we will lower this IR to TVM IR with TVM schedule primitives. + * After the search is done, we will lower this IR to TVM IR with TVM schedule primitives. * Because we share a lot of common objects during search, the transformation is * implemented in copy on write style. All objects are immutable, which is * similar to TVM IR. @@ -53,7 +53,8 @@ using namespace tvm::tir; /*! \brief The type of a stage */ enum StageType { - kPlaceholder, kCompute + kPlaceholder, // A placeholder stage + kCompute // A compute stage }; /*! \brief The type of compute location */ @@ -78,6 +79,7 @@ enum IteratorAnnotation { kTensorized }; +// forward declaration class Iterator; /*! @@ -91,7 +93,7 @@ class IteratorNode : public Object { IteratorType iter_type; IteratorAnnotation annotation; std::vector ori_iters; // The original iterators before fusion - std::string attr; + std::string attr; // TODO(jcf94): Document this void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); @@ -115,13 +117,12 @@ class Iterator : public ObjectRef { std::string attr = ""); TVM_DEFINE_OBJECT_REF_METHODS(Iterator, ObjectRef, IteratorNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(IteratorNode); }; /*! \brief Stage-level attributes */ struct StageAttributes { - int auto_unroll_max_step; - int storage_offset; + int auto_unroll_max_step; // The maximum number of steps for the pragma `auto_unroll_max_step` + int storage_offset; // The storage offset for the schedule primitive `storage_align` }; /*!
@@ -130,11 +131,11 @@ struct StageAttributes { */ class StageNode : public Object { public: - te::Operation op; - StageType op_type; - std::vector iters; - ComputeAtType compute_at; - StageAttributes attrs; + te::Operation op; // The operator of this stage + StageType op_type; // The type of this stage + std::vector iters; // The iterators in this stage + ComputeAtType compute_at; // The compute location of this stage + StageAttributes attrs; // Other stage-level attributes void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("op", &op); diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index a044acfe5395..e99f41725077 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -341,8 +341,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ", " << node->time_cost << ")"; }); -TVM_REGISTER_GLOBAL("ansor.MeasureInput") -.set_body_typed([](SearchTask task, State state) { +TVM_REGISTER_GLOBAL("ansor.MeasureInput").set_body_typed([](SearchTask task, State state) { return MeasureInput(task, state); }); @@ -359,8 +358,7 @@ TVM_REGISTER_GLOBAL("ansor.MeasureResult") }); TVM_REGISTER_GLOBAL("ansor.BuilderBuild") -.set_body_typed([](const Builder& builder, - const Array& inputs, int verbose) { +.set_body_typed([](const Builder& builder, const Array& inputs, int verbose) { return builder->Build(inputs, verbose); }); @@ -397,6 +395,5 @@ TVM_REGISTER_GLOBAL("ansor.ProgramMeasurer") max_continous_error); }); - } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc index 51a48780813a..b86bf9490851 100644 --- a/src/ansor/search_policy/search_policy.cc +++ b/src/ansor/search_policy/search_policy.cc @@ -77,7 +77,6 @@ void SearchPolicyNode::PreloadMeasuredStates(const std::string& log_file) { void SearchPolicyNode::RunCallbacks(const Array& callbacks) { if (callbacks.defined() && callbacks.size()) { - PrintTitle("Call search callbacks", verbose); for (const auto& callback : callbacks) { callback->callback(this); } diff --git a/src/ansor/search_policy/sketch_search_policy.cc b/src/ansor/search_policy/sketch_search_policy.cc index 63f75cad1c83..c4365a391865 100644 --- a/src/ansor/search_policy/sketch_search_policy.cc +++ b/src/ansor/search_policy/sketch_search_policy.cc @@ -67,6 +67,7 @@ State SketchSearchPolicyNode::Search(SearchTask task, int n_trials, this->verbose = verbose; num_measure_per_iter_ = num_measure_per_iter; + PrintTitle("Call search callbacks", verbose); RunCallbacks(pre_search_callbacks); if (n_trials <= 1) { // no measurement is allowed @@ -94,7 +95,7 @@ State SketchSearchPolicyNode::Search(SearchTask task, int n_trials, PrintTitle("Search", verbose); SearchOneRound(&best_states, num_random, &random_states); - // Fill correct bound.This is necessary for computing the correct ToStr() for reduncency check + // Infer bound. 
This is necessary for computing the correct ToStr() for redundancy check cur_task->compute_dag.InferBound(&best_states); cur_task->compute_dag.InferBound(&random_states); @@ -218,10 +219,10 @@ void SketchSearchPolicyNode::PickStatesWithEpsGreedy( std::string state_str = pstate->ToStr(); if (measured_states_set_.count(state_str)) { continue; } - measured_states_set_.insert(state_str); + measured_states_set_.insert(std::move(state_str)); inputs->push_back(MeasureInput(cur_task, *pstate)); - measured_states_vector_.push_back(std::move(*pstate)); + measured_states_vector_.push_back(*pstate); } } @@ -274,7 +275,7 @@ void SketchSearchPolicyNode::SearchOneRound(std::vector* best_states, RandomSampleStates(init_population, &rand_gen_, num_random_states * 10, random_states); } -// The baseclass of derivation rules used in sketch generation +// The base class for derivation rules used in sketch generation class SketchGenerationRule { public: enum ConditionEnum { diff --git a/src/ansor/search_policy/utils.cc b/src/ansor/search_policy/utils.cc index 412d0afcca98..2d2f92ecbc20 100644 --- a/src/ansor/search_policy/utils.cc +++ b/src/ansor/search_policy/utils.cc @@ -32,9 +32,10 @@ void GetSpaceSplitStepIds(const State& s, int stage_id, std::vector* spatia auto pop = s->stages[stage_id]->op.as(); CHECK(pop != nullptr); - auto no_split_name_pair = QueryNoSplitAxis(s->stages[stage_id]); - std::set no_split_at_inner_name_set = no_split_name_pair.first; - std::set no_split_at_outer_name_set = no_split_name_pair.second; + const auto& no_split_name_pair = QueryNoSplitAxis(s->stages[stage_id]); + const std::set& no_split_at_inner_name_set = no_split_name_pair.first; + const std::set& no_split_at_outer_name_set = no_split_name_pair.second; + size_t reduce_count = 0; for (const auto axis : pop->reduce_axis) { if (!no_split_at_inner_name_set.count(axis->var->name_hint) && @@ -52,6 +53,8 @@ void GetSpaceSplitStepIds(const State& s, int stage_id, std::vector* spatia } } else if (auto ps = s->transform_steps[i].as()) { if (stage_id == ps->stage_id) { + // Assume SplitStep on reduction axes are always after SplitStep on spatial axes. 
+ // TODO(jcf94): do not rely on this assumption if (reduce_count) { reduce_count--; } else { @@ -75,7 +78,7 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo } else if (tolower(c) == 'r') { reduce_levels.emplace_back(); } else { - LOG(FATAL) << "Invalid multi level tiling format: " << format; + LOG(FATAL) << "Invalid multi-level tiling format: " << format; } } size_t n_space = space_levels.size(); @@ -85,10 +88,10 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo State tmp_s = state; const Stage& stage = state->stages[stage_id]; - auto no_split_name_pair = QueryNoSplitAxis(stage); // handle special split strategy - auto last_split_is_one_name_set = QueryLastSplitIsOneAxis(stage); - std::set no_split_at_inner_name_set = no_split_name_pair.first; - std::set no_split_at_outer_name_set = no_split_name_pair.second; + const auto& no_split_name_pair = QueryNoSplitAxis(stage); // handle special split strategy + const auto& last_split_is_one_name_set = QueryLastSplitIsOneAxis(stage); + const std::set& no_split_at_inner_name_set = no_split_name_pair.first; + const std::set& no_split_at_outer_name_set = no_split_name_pair.second; for (const auto& iter : state->stages[stage_id]->iters) { if (iter->iter_type == kSpace) { @@ -119,10 +122,10 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo } } } else if (iter->iter_type == kReduce) { - // for reduce iterator, split it into two iterators if (!no_split_at_inner_name_set.count(iter->name) && !no_split_at_outer_name_set.count(iter->name)) { CHECK_GE(n_reduce, 1); + if (n_reduce == 1) { reduce_levels[0].push_back(iter); } else { @@ -147,23 +150,27 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo if (!space_outer.empty()) { CHECK(!space_levels.empty()); space_levels.front().insert(space_levels.front().begin(), - space_outer.begin(), space_outer.end()); + std::make_move_iterator(space_outer.begin()), + std::make_move_iterator(space_outer.end())); } if (!space_inner.empty()) { CHECK(!space_levels.empty()); space_levels.back().insert(space_levels.back().begin(), - space_inner.begin(), space_inner.end()); + std::make_move_iterator(space_inner.begin()), + std::make_move_iterator(space_inner.end())); } if (!reduce_outer.empty()) { CHECK(!reduce_levels.empty()); reduce_levels.front().insert(reduce_levels.front().begin(), - reduce_outer.begin(), reduce_outer.end()); + std::make_move_iterator(reduce_outer.begin()), + std::make_move_iterator(reduce_outer.end())); } if (!reduce_inner.empty()) { CHECK(!reduce_levels.empty()); reduce_levels.back().insert(reduce_levels.back().begin(), - reduce_inner.begin(), reduce_inner.end()); + std::make_move_iterator(reduce_inner.begin()), + std::make_move_iterator(reduce_inner.end())); } std::vector order; @@ -198,7 +205,7 @@ State FollowTiling(const State& state, int stage_id, auto pop = state->stages[stage_id]->op.as(); CHECK(pop != nullptr); const Stage& stage = state->stages[stage_id]; - auto no_split_name_pair = QueryNoSplitAxis(stage); // handle special split strategy + const auto& no_split_name_pair = QueryNoSplitAxis(stage); // handle special split strategy const std::set& no_split_at_inner_name_set = no_split_name_pair.first; const std::set& no_split_at_outer_name_set = no_split_name_pair.second; int no_split_at_inner_name_in_stage_cnt = 0; @@ -266,6 +273,7 @@ State FollowTiling(const State& state, int stage_id, LOG(FATAL) << "Invalid iter type: " << iter->iter_type; } } + if (n_split == 3) 
{ ConcatenateMove(&tmp_order, &space_0, &space_1, &space_2, &space_3); } else if (n_split == 2) { @@ -406,13 +414,320 @@ State RandomMutateMaxUnrollStep(const State& old_state, std::mt19937* random_gen return tmp_s; } +State RandomMutateParallel(const State& old_state, std::mt19937* random_gen, + const SearchTask& task, int verbose) { + // To make this mutation simple but promising, we only focus on a specific case where + // parallel was added to the outermost loop and the loop was generated by fusing other loops. + // In short, we mutate the step pattern of (fuse -> parallel). + + // Extract all parallel steps. + std::vector parallel_steps; + for (size_t s = 0; s < old_state->transform_steps.size(); ++s) { + auto ps = old_state->transform_steps[s].as(); + if (!ps || ps->annotation != kParallel) { + continue; + } + parallel_steps.push_back(s); + } + if (parallel_steps.empty()) { + StdCout(verbose) << "Parallel mutation failed: No parallel annotations" << std::endl; + return State(); + } + + // Randomly pick one step. + int retry_ct = 0; + size_t step_id = 0; + size_t stage_id = 0; + do { + step_id = parallel_steps[(*random_gen)() % parallel_steps.size()]; + auto step = old_state->transform_steps[step_id].as(); + stage_id = step->stage_id; + + // Check assumptions. + auto iter_id = step->iter_id; + if (iter_id == 0 && step_id > 0 && old_state->transform_steps[step_id - 1].as()) { + break; + } + retry_ct++; + } while (retry_ct <= 3); + + if (retry_ct > 3) { + StdCout(verbose) << "Parallel mutation failed: No valid parallel annotations" << std::endl; + return State(); + } + + // Replay a new state until the picked fuse step. + State tmp_s = task->compute_dag.GetInitState(); + for (size_t s = 0; s < step_id - 1; ++s) { + auto step = old_state->transform_steps[s]; + tmp_s.CopyOnWrite()->transform_steps.push_back(step); + tmp_s.DoStep(step, task->compute_dag); + } + + // Determine the fuse direction. + // 0: fuse fewer iters; 1: fuse more iters. + auto fuse_step = old_state->transform_steps[step_id - 1].as(); + std::vector fused_ids = fuse_step->fused_ids; + std::vector fuse_dir = {0.5, 1.0}; + + // The case where we can only fuse more. + if (fused_ids.size() == 1) { + fuse_dir[0] = 0.0; + } + + // The cases where we cannot fuse the next iters. + if (old_state->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, 0)) > 0 || + tmp_s->stages[stage_id]->iters.size() == fused_ids.size() || + tmp_s->stages[stage_id]->iters[1]->iter_type == kReduce) { + // If we cannot fuse fewer iters either, give up. + if (fuse_dir[0] == 0.0) { + StdCout(verbose) << "Parallel mutation failed: Cannot fuse more or fewer iters" << std::endl; + return State(); + } + fuse_dir[0] = 1.0; + } + + int iter_offset = 0; + if (RandomChoose(fuse_dir, random_gen) == 0) { + StdCout(verbose) << "Parallel mutation: release iter " << fused_ids.back() << std::endl; + fused_ids.pop_back(); + iter_offset = 1; + } else { + StdCout(verbose) << "Parallel mutation: include iter " << fused_ids.back() + 1 << std::endl; + fused_ids.push_back(fused_ids.back() + 1); + iter_offset = -1; + } + + // Replay the mutated fuse step and the annotation step. + auto new_fuse_step = FuseStep(stage_id, fused_ids); + tmp_s.CopyOnWrite()->transform_steps.push_back(new_fuse_step); + tmp_s.DoStep(new_fuse_step, task->compute_dag); + tmp_s.CopyOnWrite()->transform_steps.push_back(old_state->transform_steps[step_id]); + tmp_s.DoStep(old_state->transform_steps[step_id], task->compute_dag); + + // Replay the rest steps.
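// Note: steps recorded after the mutated fuse still refer to iterator ids of
// the old loop structure. Releasing one iterator from the fuse shifts later
// ids up by one (iter_offset = 1); absorbing one more shifts them down
// (iter_offset = -1). Annotation steps on this stage are therefore re-created
// below with adjusted iter ids.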
+ for (size_t s = step_id + 1; s < old_state->transform_steps.size(); ++s) { + auto step = old_state->transform_steps[s]; + if (step->stage_id == static_cast(stage_id)) { + // Since we change the loop structure, iter ID in later steps to the same stage + // has to be adjusted. + auto ps = step.as(); + if (ps) { + if (ps->iter_id == 0) { + step = AnnotationStep(ps->stage_id, 0, ps->annotation); + } else { + CHECK_LE(ps->iter_id + iter_offset, tmp_s->stages[stage_id]->iters.size()); + step = AnnotationStep(ps->stage_id, ps->iter_id + iter_offset, ps->annotation); + } + } else { + StdCout(verbose) << "Parallel mutation: Cannot apply " << step << " after fuse" + << std::endl; + return State(); + } + } + tmp_s.CopyOnWrite()->transform_steps.push_back(step); + tmp_s.DoStep(step, task->compute_dag); + } + return tmp_s; +} + + +State RandomMutateComputeLocation(const State& old_state, std::mt19937* random_gen, + const SearchTask& task) { + // Extract all compute_at steps. + std::vector compute_at_steps; + for (size_t s = 0; s < old_state->transform_steps.size(); ++s) { + if (auto ps = old_state->transform_steps[s].as()) { + const Stage& stage = old_state->stages[ps->stage_id]; + if (IsTiled(stage)) { + continue; + } + + if (NeedsMultilevelTiling(task, old_state, stage->op)) { + continue; + } + compute_at_steps.push_back(s); + } + } + if (compute_at_steps.empty()) { + return State(); + } + + // Randomly pick one step + size_t step_id = compute_at_steps[(*random_gen)() % compute_at_steps.size()]; + auto ps = old_state->transform_steps[step_id].as(); + CHECK(ps != nullptr); + const Stage& stage = old_state->stages[ps->stage_id]; + + // Randomly pick one tile level + int new_compute_at_stage_id; + int new_compute_at_iter_id; + + // Copied from InitPopulationChangeComputeLocation + { + std::unordered_set consumers; + GetConsumers(task, old_state, stage->op, &consumers); + if (consumers.empty()) { + return State(); + } + + int target_stage_id; + if (consumers.size() == 1) { + target_stage_id = OperationToStage(*consumers.begin(), old_state); + } else { + // check all consumers share a common root + int common_root_id = -1; + bool mismatch = false; + for (const auto& consumer : consumers) { + int consumer_stage_id = OperationToStage(consumer, old_state); + int root_id = -1; + if ((old_state)->stages[consumer_stage_id]->compute_at == kRoot) { + root_id = consumer_stage_id; + } else if ((old_state)->stages[consumer_stage_id]->compute_at == kIter) { + root_id = (old_state)->attach_map->stage_to_attach_iter.at(consumer_stage_id).first; + } else { + LOG(FATAL) << "Invalid case"; + } + + if (common_root_id == -1) { + common_root_id = root_id; + } else { + if (common_root_id != root_id) { + mismatch = true; + break; + } + } + } + + if (mismatch) { + return State(); + } + target_stage_id = common_root_id; + } + + const Stage& target_stage = old_state->stages[target_stage_id]; + std::set to_unroll_name_set; + if (target_stage->op->attrs.count(SearchPolicyNode::always_unroll_key)) { + to_unroll_name_set = GetIterNameSetParam(target_stage->op->attrs, + SearchPolicyNode::always_unroll_key); + } + + std::vector > candidates; + bool target_compute_at_other = target_stage->compute_at == kIter; + bool target_is_tiled = IsTiled(target_stage); + + bool visited_reduce = false; + // enumerate compute_at location at target_stage + int ct = 0; + for (size_t iter_id = 0; iter_id < target_stage->iters.size(); ++iter_id) { + const auto& target_iter = target_stage->iters[iter_id]; + if (target_iter->iter_type == kReduce) { + 
visited_reduce = true; + if (!target_is_tiled) { // do not go into reduce iter + break; + } + } else if (target_iter->iter_type == kSpace) { + if (visited_reduce) { // do not go into inner tile + break; + } + } + + if (to_unroll_name_set.count(target_iter->name)) { + // Do not go into always unroll region + break; + } + + if (GetExtent(target_iter) == 1) { // skip iterators with length of 1 + continue; + } + if (target_compute_at_other && target_iter->iter_type == kSpace && + StrEndsWith(target_iter->name, ".0")) { + // skip the first level iterators if target stage compute_at another stage + // In this case, the lengths of first level iterators are always one + continue; + } + candidates.emplace_back(target_stage_id, iter_id); + + if ((old_state)->attach_map->iter_to_attached_stages.count( + std::make_pair(target_stage_id, ct++))) { + break; + } + } + + // if the target_stage is already compute_at another stage X, try also compute_at X + // We call stage X as `target_target_stage` + if (target_compute_at_other) { + int target_target_stage_id; + target_target_stage_id = (old_state)->attach_map->stage_to_attach_iter.at( + target_stage_id).first; + const Stage& target_target_stage = (old_state)->stages[target_target_stage_id]; + if (target_target_stage->op->attrs.count(SearchPolicyNode::always_unroll_key)) { + to_unroll_name_set = GetIterNameSetParam(target_target_stage->op->attrs, + SearchPolicyNode::always_unroll_key); + } else { + to_unroll_name_set.clear(); + } + + int ct = 0; + for (size_t iter_id = 0; iter_id < target_target_stage->iters.size(); ++iter_id) { + const auto& target_target_iter = target_target_stage->iters[iter_id]; + if (target_target_iter->iter_type == kReduce || + (old_state)->attach_map->iter_to_attached_stages.count( + std::make_pair(target_target_stage_id, ct++))) { + break; + } + + if (to_unroll_name_set.count(target_target_iter->name)) { + // Do not go into always unroll region + break; + } + + if (GetExtent(target_target_iter) == 1) { // skip iterators with length of 1 + continue; + } + + candidates.emplace_back(target_target_stage_id, iter_id); + } + } + + if (candidates.empty()) { + return State(); + } + + int choice = (*random_gen)() % (candidates.size()); + new_compute_at_stage_id = candidates[choice].first; + new_compute_at_iter_id = candidates[choice].second; + } + + // Replay a new state. 
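// All original steps are replayed verbatim except the chosen ComputeAtStep,
// which is swapped for the newly sampled (stage, iterator) location; if any
// replayed step becomes infeasible under the new location, DoStep throws and
// the mutation is abandoned.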
+ State tmp_s = task->compute_dag.GetInitState(); + for (size_t s = 0; s < old_state->transform_steps.size(); ++s) { + if (s == step_id) { + tmp_s.CopyOnWrite()->transform_steps.push_back( + ComputeAtStep(ps->stage_id, new_compute_at_stage_id, new_compute_at_iter_id)); + } else { + tmp_s.CopyOnWrite()->transform_steps.push_back(old_state->transform_steps[s]); + } + try { + tmp_s.DoStep(tmp_s->transform_steps.back(), task->compute_dag); + } catch (dmlc::Error &e) { + return State(); + } + } + + return tmp_s; +} + void PruneUndefined(std::vector* states) { size_t pt = 0; for (size_t i = 0; i < states->size(); ++i) { if (!(*states)[i].defined()) { continue; } - (*states)[pt++] = std::move((*states)[i]); + if (i != pt) { + (*states)[pt] = std::move((*states)[i]); + } + pt++; } if (pt == 0) { diff --git a/src/ansor/search_policy/utils.h b/src/ansor/search_policy/utils.h index 5f15397e7e90..107e2ee72521 100644 --- a/src/ansor/search_policy/utils.h +++ b/src/ansor/search_policy/utils.h @@ -79,8 +79,8 @@ inline std::set GetIterNameSetParam(const Map& a CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; auto names = attr_dict[key].as(); CHECK(names != nullptr); - for (auto name = names->begin(); name != names->end(); name++) { - ret.insert(name->as()->value); + for (const auto & name : *names) { + ret.insert(name.as()->value); } return ret; } @@ -284,9 +284,6 @@ inline bool HasCacheReadStage(const State& s, int stage_id) { return false; } -// Get all split step on spatial iterators -void GetSpaceSplitStepIds(const State& s, int stage_id, std::vector* spatial_split_step_ids); - // Return whether the state did split/follow_split/follow_fused_split in stage_id inline bool HasSplitStep(const State& s, int stage_id) { for (int i = static_cast(s->transform_steps.size()) - 1; i >= 0; --i) { @@ -441,6 +438,9 @@ inline void PrintAllStates(const std::vector& states) { } } +// Get all split steps on spatial iterators for one stage +void GetSpaceSplitStepIds(const State& s, int stage_id, std::vector* spatial_split_step_ids); + // Apply multi-level tiling structure according to a string format, // where "S" stands for a space level, "R" stands for a reduction level. // For example, if the format is "SSRSRS", then we will @@ -451,8 +451,7 @@ inline void PrintAllStates(const std::vector& states) { State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format, std::vector* spatial_split_step_ids); -// Apply tiling structure: space, space -// But use tile sizes from other SplitStep +// Apply tiling structure: space, space, space, ..., with tile sizes from other SplitStep State FollowTiling(const State& state, int stage_id, const std::vector& split_step_ids, int n_split); @@ -464,6 +463,14 @@ inline void PrintAllStates(const std::vector& states) { State RandomMutateTileSize(const State& old_state, SplitFactorizationMemo* split State RandomMutateMaxUnrollStep(const State& old_state, std::mt19937* random_gen, const std::vector& auto_unroll_configs); +// Randomly mutate the parallel degree of one stage. +State RandomMutateParallel(const State& old_state, std::mt19937* random_gen, + const SearchTask& task, int verbose = 0); + +// Randomly mutate the computation location of one stage.
+State RandomMutateComputeLocation(const State& old_state, std::mt19937* random_gen, + const SearchTask& task); + // GA: Crossover two states State CrossOverState(const State& p1, const State& p2); diff --git a/src/ansor/search_task.h b/src/ansor/search_task.h index c53fdcd0f792..0f270d105d73 100644 --- a/src/ansor/search_task.h +++ b/src/ansor/search_task.h @@ -121,7 +121,6 @@ class SearchTask : public ObjectRef { HardwareParams hardware_params); TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(SearchTaskNode); }; } // namespace ansor diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index c026b9b6251a..d84c3c57dc86 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -583,22 +583,11 @@ std::pair BestMeasurePairInFile( return best_pair; } -TVM_REGISTER_GLOBAL("ansor.WriteMeasureRecordsToFile") -.set_body([](TVMArgs args, TVMRetValue *ret) { - std::string filename = args[0]; - Array in = args[1]; - Array res = args[2]; - std::ofstream ofs(filename, std::ofstream::app); - WriteMeasureRecords(&ofs, in, res); -}); - -TVM_REGISTER_GLOBAL("ansor.LogToFile") -.set_body_typed([](const std::string& filename) { +TVM_REGISTER_GLOBAL("ansor.LogToFile").set_body_typed([](const std::string& filename) { return LogToFile(filename); }); -TVM_REGISTER_GLOBAL("ansor.LogReader") -.set_body_typed([](const std::string& filename) { +TVM_REGISTER_GLOBAL("ansor.LogReader").set_body_typed([](const std::string& filename) { return LogReader(filename); }); @@ -619,6 +608,15 @@ TVM_REGISTER_GLOBAL("ansor.LogReaderReadNext") } }); +TVM_REGISTER_GLOBAL("ansor.WriteMeasureRecordsToFile") +.set_body([](TVMArgs args, TVMRetValue *ret) { + std::string filename = args[0]; + Array in = args[1]; + Array res = args[2]; + std::ofstream ofs(filename, std::ofstream::app); + WriteMeasureRecords(&ofs, in, res); +}); + TVM_REGISTER_GLOBAL("ansor.GetStatesFromMeasureInputs") .set_body([](TVMArgs args, TVMRetValue *ret) { Array inputs = args[0]; @@ -672,6 +670,5 @@ TVM_REGISTER_GLOBAL("ansor.GetStatesFromMeasureInputs") *ret = states; }); - } // namespace ansor } // namespace tvm diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc index bd0a7f7165f6..e882a0495263 100644 --- a/src/ansor/transform_step.cc +++ b/src/ansor/transform_step.cc @@ -428,8 +428,7 @@ std::string AnnotationStepNode::PrintAsPythonAPI(std::vector *stages, } /********** Compute At **********/ -ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, - int target_iter_id) { +ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id) { auto node = make_object(); node->stage_id = stage_id; node->target_stage_id = target_stage_id; diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h index edd71732b3e2..f8283b876f18 100644 --- a/src/ansor/transform_step.h +++ b/src/ansor/transform_step.h @@ -77,7 +77,6 @@ class ReorderStep : public Step { ReorderStep(int stage_id, const std::vector& after_ids); TVM_DEFINE_OBJECT_REF_METHODS(ReorderStep, Step, ReorderStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(ReorderStepNode); }; /*! \brief Split step that corresponds to te::Stage::split with additional @@ -113,7 +112,6 @@ class SplitStep : public Step { bool inner_to_outer); TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(SplitStepNode); }; /*! 
\brief Similar to SplitStepNode, but use split factor from another step @@ -149,7 +147,6 @@ class FollowSplitStep : public Step { FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split); TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(FollowSplitStepNode); }; @@ -189,7 +186,6 @@ class FollowFusedSplitStep : public Step { int level, bool factor_or_nparts); TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(FollowFusedSplitStepNode); }; /*! \brief Fuse step that corresponds to te::Stage::fuse */ @@ -218,7 +214,6 @@ class FuseStep : public Step { FuseStep(int stage_id, const std::vector& fused_ids); TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(FuseStepNode); }; /*! \brief Annotation step that corresponds to vectorize, parallel, unroll and thread binding. @@ -250,10 +245,9 @@ class AnnotationStep : public Step { AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann); TVM_DEFINE_OBJECT_REF_METHODS(AnnotationStep, Step, AnnotationStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(AnnotationStepNode); }; -/*! \brief Fuse step that corresponds to te::Stage::compute_at */ +/*! \brief Compute at step that corresponds to te::Stage::compute_at */ class ComputeAtStepNode: public StepNode { public: int target_stage_id; @@ -280,10 +274,9 @@ class ComputeAtStep : public Step { ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id); TVM_DEFINE_OBJECT_REF_METHODS(ComputeAtStep, Step, ComputeAtStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeAtStepNode); }; -/*! \brief Fuse step that corresponds to te::Stage::compute_root */ +/*! \brief Compute root step that corresponds to te::Stage::compute_root */ class ComputeRootStepNode: public StepNode { public: void ApplyToSchedule(std::vector *stages, @@ -307,10 +300,9 @@ class ComputeRootStep : public Step { explicit ComputeRootStep(int stage_id); TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeRootStepNode); }; -/*! \brief Fuse step that corresponds to te::Stage::compute_inline */ +/*! \brief Compute inline step that corresponds to te::Stage::compute_inline */ class ComputeInlineStepNode: public StepNode { public: void ApplyToSchedule(std::vector *stages, @@ -334,7 +326,6 @@ class ComputeInlineStep : public Step { explicit ComputeInlineStep(int stage_id); TVM_DEFINE_OBJECT_REF_METHODS(ComputeInlineStep, Step, ComputeInlineStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeInlineStepNode); }; /*! \brief Cache read step that corresponds to te::Schedule::cache_read */ @@ -366,10 +357,9 @@ class CacheReadStep : public Step { const std::vector& reader_stage_id); TVM_DEFINE_OBJECT_REF_METHODS(CacheReadStep, Step, CacheReadStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(CacheReadStepNode); }; -/*! \brief Cache read step that corresponds to te::Schedule::cache_write +/*! \brief Cache write step that corresponds to te::Schedule::cache_write * \Note This step will cache_write all output tensors of target stage */ class CacheWriteStepNode: public StepNode { public: @@ -397,10 +387,9 @@ class CacheWriteStep : public Step { CacheWriteStep(int stage_id, std::string scope_name); TVM_DEFINE_OBJECT_REF_METHODS(CacheWriteStep, Step, CacheWriteStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(CacheWriteStepNode); }; -/*! \brief Cache read step that corresponds to te::Schedule::pragma */ +/*! 
\brief Pragma step that corresponds to te::Schedule::pragma */ class PragmaStepNode: public StepNode { public: int iter_id; @@ -427,7 +416,6 @@ class PragmaStep : public Step { PragmaStep(int stage_id, int iter_id, std::string pragma_type); TVM_DEFINE_OBJECT_REF_METHODS(PragmaStep, Step, PragmaStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(PragmaStepNode); }; /*! \brief Reduction factor step that corresponds to te::Schedule::rfactor */ @@ -458,7 +446,6 @@ class RfactorStep : public Step { RfactorStep(int stage_id, int iter_id, int factor_iter_id); TVM_DEFINE_OBJECT_REF_METHODS(RfactorStep, Step, RfactorStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(RfactorStepNode); }; /*! \brief Storage align step that corresponds to te::Schedule::storage_align */ @@ -489,7 +476,6 @@ class StorageAlignStep : public Step { StorageAlignStep(int stage_id, int iter_id, int factor, int offset); TVM_DEFINE_OBJECT_REF_METHODS(StorageAlignStep, Step, StorageAlignStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(StorageAlignStepNode); }; /*! \brief Tensorize step that corresponds to te::Schedule::tensorize @@ -520,7 +506,6 @@ class TensorizeStep : public Step { TensorizeStep(int stage_id, int iter_id, std::string ti_func_name); TVM_DEFINE_OBJECT_REF_METHODS(TensorizeStep, Step, TensorizeStepNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(TensorizeStepNode); }; } // namespace ansor diff --git a/src/ansor/utils.h b/src/ansor/utils.h index cb90364b01b5..4e98bb907af9 100644 --- a/src/ansor/utils.h +++ b/src/ansor/utils.h @@ -81,13 +81,6 @@ struct hash > { namespace tvm { namespace ansor { -/*! \brief Macro to make it easy to define object ref type given node */ -#define TVM_DEFINE_OBJECT_REF(TypeName, ObjectName) \ - class TypeName : public ObjectRef { \ - public: \ - TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ObjectRef, ObjectName); \ - }; \ - /*! \brief Macro to make it easy to define mutable object ref type given node */ #define TVM_DEFINE_MUTABLE_OBJECT_REF(TypeName, ObjectName) \ class TypeName : public ObjectRef { \ @@ -95,17 +88,6 @@ namespace ansor { TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ObjectRef, ObjectName); \ }; \ -/*! - * \brief Macro to make it easy to define node ref type that - * has a CopyOnWrite member function. - */ -#define TVM_DEFINE_COW_OBJECT_REF(TypeName, BaseType, ObjectName) \ - class TypeName : public BaseType { \ - public: \ - TVM_DEFINE_OBJECT_REF_METHODS(TypeName, BaseType, ObjectName); \ - TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName); \ - }; - /********** Utilities for std::vector, std::set, std::string **********/ /*! \brief Get the first appearance index of elements in a vector */ template From 58601918b60ebf6bfcc57ec0a9c36c7da21c2de7 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 23 Jun 2020 13:57:32 -0700 Subject: [PATCH 38/45] Improve loop state python API (stage_tensors -> stage_ops) (#41) * improve loop state python API (stage_tensors -> stage_ops) * fix --- python/tvm/ansor/loop_state.py | 324 ++++++++---------- .../python/unittest/test_ansor_compute_dag.py | 6 +- tests/python/unittest/test_ansor_feature.py | 4 +- .../python/unittest/test_ansor_loop_state.py | 14 +- 4 files changed, 153 insertions(+), 195 deletions(-) diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py index 8560a57bc902..7aa5de0e9c1d 100644 --- a/python/tvm/ansor/loop_state.py +++ b/python/tvm/ansor/loop_state.py @@ -25,16 +25,17 @@ Basically this is a simplified TVM IR with schedule primitives. We don't use the existing TVM IR because 1. 
We want fast incremental change to the loop structures -2. We want serializable history for replay and backtracking +2. We want serializable transformation history for replay, backtracking, and mutation 3. We may create some new macro schedule primitives -After search is done, we will lower this IR to TVM IR with TVM's schedule primitives. +After the search is done, we will lower this IR to TVM IR with TVM's schedule primitives. Because we share a lot common objects during search, the transformation is implemented in copy on write style. All objects are immutable, which is similar to TVM IR. """ import tvm._ffi +from tvm.te.tensor import Operation, Tensor from tvm.runtime import Object from . import _ffi_api @@ -80,43 +81,9 @@ def __init__(self, state_object, dag): self.state_object = state_object self.compute_dag = dag - self.stages_cache = None - self.stage_id_map = {} - self.__update_tensor_stage_map() - - def __getitem__(self, k): - if not self.stages_cache: - self.stages_cache = _ffi_api.StateGetStages(self.state_object) - if isinstance(k, tvm.te.Tensor): - return self.stages_cache[self.stage_id_map[k.op]] - raise ValueError("Item must be Tensor") - - def __update_tensor_stage_map(self): - if not self.stages_cache: - self.stages_cache = _ffi_api.StateGetStages(self.state_object) - for index, stage in enumerate(self.stages_cache): - self.stage_id_map[stage.op] = index - - def __insert_new_stage(self, new_stage_id): - new_stage_id = int(new_stage_id) - self.stages_cache = _ffi_api.StateGetStages(self.state_object) - added_stage_tensor = self.stages_cache[new_stage_id].op.output(0) - - for key, value in self.stage_id_map.items(): - if value >= new_stage_id: - self.stage_id_map[key] = value + 1 - self.stage_id_map[added_stage_tensor.op] = new_stage_id - self.__update_tensor_stage_map() - - return added_stage_tensor - - def clear_cache(self): - self.stages_cache = None - - def copy(self): - state = State(self.state_object, self.compute_dag) - state.stage_id_map = self.stage_id_map.copy() - return state + self.stages_cache = None # A list to cache all stages + self.stage_id_map = {} # A dict maps operation to stage id + self._update_stage_id_map() @property def stages(self): @@ -130,15 +97,15 @@ def stages(self): return self.stages_cache @property - def stage_tensors(self): + def stage_ops(self): """ Returns ------- - Tensor + ops: List[Operation] """ if not self.stages_cache: self.stages_cache = _ffi_api.StateGetStages(self.state_object) - return [stage.op.output(0) for stage in self.stages_cache] + return [stage.op for stage in self.stages_cache] def transform_steps_size(self): """ Return the size of transform_steps @@ -149,30 +116,27 @@ def reorder(self, stage_id, order): """ Parameters ---------- - stage_id : Int + stage_id : Union[int, Operation, Tensor] The index of the stage to reorder order : List[Iterator] Iterators in the expected order """ - if isinstance(stage_id, tvm.te.Tensor): - stage_id = self.stage_id_map[stage_id.op] - elif not isinstance(stage_id, int): - raise ValueError("stage_id must be Tensor or Int") + stage_id = self._resolve_stage_id(stage_id) self.state_object = _ffi_api.StateReorder(self.state_object, stage_id, order) - self.clear_cache() + self._clear_cache() def split(self, stage_id, iterator, lengths, inner_to_outer=True): """ Parameters ---------- - stage_id : Int + stage_id : Union[int, Operation, Tensor] The index of the stage to split iterator : Iterator The iterator to split - lengths: List[Int] + lengths: List[int] The split factors - inner_to_outer: 
Bool + inner_to_outer: bool True to use `factor` to split from inner to outer, False to use `nparts` to split from outer to inner @@ -181,27 +145,24 @@ def split(self, stage_id, iterator, lengths, inner_to_outer=True): res_its : List[Iterator] The splitted new Iterators """ - if isinstance(stage_id, tvm.te.Tensor): - stage_id = self.stage_id_map[stage_id.op] - elif not isinstance(stage_id, int): - raise ValueError("stage_id must be Tensor or Int") + stage_id = self._resolve_stage_id(stage_id) self.state_object, res = _ffi_api.StateSplit(self.state_object, stage_id, iterator, lengths, inner_to_outer) - self.clear_cache() + self._clear_cache() return res def follow_split(self, stage_id, iterator, src_step_id, n_split): """ Parameters ---------- - stage_id : Int + stage_id : Union[int, Operation, Tensor] The index of the stage to split iterator : Iterator The iterator to split - src_step_id : Int + src_step_id : int The index of the split step to follow in the history - n_split : Int + n_split : int The number of split level Returns @@ -209,14 +170,11 @@ def follow_split(self, stage_id, iterator, src_step_id, n_split): res_its : List[Iterator] The splitted new Iterators """ - if isinstance(stage_id, tvm.te.Tensor): - stage_id = self.stage_id_map[stage_id.op] - elif not isinstance(stage_id, int): - raise ValueError("stage_id must be Tensor or Int") + stage_id = self._resolve_stage_id(stage_id) self.state_object, res = _ffi_api.StateFollowSplit(self.state_object, stage_id, iterator, src_step_id, n_split) - self.clear_cache() + self._clear_cache() return res def follow_fused_split(self, stage_id, iterator, src_step_ids, level, @@ -224,15 +182,15 @@ def follow_fused_split(self, stage_id, iterator, src_step_ids, level, """ Parameters ---------- - stage_id : Int + stage_id : Union[int, Operation, Tensor] The index of the stage to split iterator : Iterator The iterator to split - src_step_ids : List[Int] + src_step_ids : List[int] The indices of the split steps to follow in the history - level : Int + level : int Use the length in this split level - factor_or_nparts : Bool + factor_or_nparts : bool True to use `factor` for split from inner to outer, False to use `nparts` for split from outer to inner @@ -241,22 +199,19 @@ def follow_fused_split(self, stage_id, iterator, src_step_ids, level, res_its : List[Iterator] The splitted new Iterators """ - if isinstance(stage_id, tvm.te.Tensor): - stage_id = self.stage_id_map[stage_id.op] - elif not isinstance(stage_id, int): - raise ValueError("stage_id must be Tensor or Int") + stage_id = self._resolve_stage_id(stage_id) self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object, stage_id, iterator, src_step_ids, level, factor_or_nparts) - self.clear_cache() + self._clear_cache() return res def fuse(self, stage_id, iters): """ Parameters ---------- - stage_id : Int + stage_id : Union[int, Operation, Tensor] The index of the stage to fuse iters : List[Iterator] The iterators to be fused @@ -266,20 +221,17 @@ def fuse(self, stage_id, iters): res_it : Iterator The fused Iterator """ - if isinstance(stage_id, tvm.te.Tensor): - stage_id = self.stage_id_map[stage_id.op] - elif not isinstance(stage_id, int): - raise ValueError("stage_id must be Tensor or Int") + stage_id = self._resolve_stage_id(stage_id) self.state_object, res = _ffi_api.StateFuse(self.state_object, stage_id, iters) - self.clear_cache() + self._clear_cache() return res def vectorize(self, stage_id, iterator): """ Parameters ---------- - stage_id : Int + stage_id : Union[int, 
Operation, Tensor] The index of the stage to vectorize iterator : Iterator The iterator to be vectorized @@ -289,20 +241,17 @@ def vectorize(self, stage_id, iterator): res_it : Iterator The vectorized Iterator """ - if isinstance(stage_id, tvm.te.Tensor): - stage_id = self.stage_id_map[stage_id.op] - elif not isinstance(stage_id, int): - raise ValueError("stage_id must be Tensor or Int") + stage_id = self._resolve_stage_id(stage_id) self.state_object, res = _ffi_api.StateVectorize(self.state_object, stage_id, iterator) - self.clear_cache() + self._clear_cache() return res def parallel(self, stage_id, iterator): """ Parameters ---------- - stage_id : Int + stage_id : Union[int, Operation, Tensor] The index of the stage to parallel iterator : Iterator The iterator to be parallelized @@ -312,24 +261,21 @@ def parallel(self, stage_id, iterator): res_it : Iterator The parallelized Iterator """ - if isinstance(stage_id, tvm.te.Tensor): - stage_id = self.stage_id_map[stage_id.op] - elif not isinstance(stage_id, int): - raise ValueError("stage_id must be Tensor or Int") + stage_id = self._resolve_stage_id(stage_id) self.state_object, res = _ffi_api.StateParallel(self.state_object, stage_id, iterator) - self.clear_cache() + self._clear_cache() return res def unroll(self, stage_id, iterator, max_unroll=-1): """ Parameters ---------- - stage_id : Int + stage_id : Union[int, Operation, Tensor] The index of the stage to unroll iterator : Iterator The iterator to be unrolled - max_unroll: Int + max_unroll: int The maximum length of the iterator that can be unrolled Returns @@ -337,21 +283,18 @@ def unroll(self, stage_id, iterator, max_unroll=-1): res_it : Iterator The unrolled Iterator """ - if isinstance(stage_id, tvm.te.Tensor): - stage_id = self.stage_id_map[stage_id.op] - elif not isinstance(stage_id, int): - raise ValueError("stage_id must be Tensor or Int") + stage_id = self._resolve_stage_id(stage_id) self.state_object, res = _ffi_api.StateUnroll(self.state_object, stage_id, iterator, max_unroll) - self.clear_cache() + self._clear_cache() return res def bind_thread(self, stage_id, iterator, thread_name): """ Parameters ---------- - stage_id : Int + stage_id : Union[int, Operation, Tensor] The index of the stage to bind iterator : Iterator The iterator to be bound @@ -372,201 +315,167 @@ def bind_thread(self, stage_id, iterator, thread_name): } thread_id = trans_table[thread_name] - if isinstance(stage_id, tvm.te.Tensor): - stage_id = self.stage_id_map[stage_id.op] - elif not isinstance(stage_id, int): - raise ValueError("stage_id must be Tensor or Int") + stage_id = self._resolve_stage_id(stage_id) self.state_object, res = _ffi_api.StateBindThread(self.state_object, stage_id, iterator, thread_id) - self.clear_cache() + self._clear_cache() return res def compute_at(self, stage_id, target_stage_id, target_iter): """ Parameters ---------- - stage_id : Int + stage_id : Union[int, Operation, Tensor] The index of source stage - target_stage_id : Int + target_stage_id : Union[int, Operation, Tensor] The index of the target stage of compute_at target_iter : Iterator The target Iterator of compute_at """ - if isinstance(stage_id, tvm.te.Tensor): - stage_id = self.stage_id_map[stage_id.op] - elif not isinstance(stage_id, int): - raise ValueError("stage_id must be Tensor or Int") - if isinstance(target_stage_id, tvm.te.Tensor): - target_stage_id = self.stage_id_map[target_stage_id.op] - elif not isinstance(target_stage_id, int): - raise ValueError("target_stage_id must be Tensor or Int") + stage_id = 
self._resolve_stage_id(stage_id) + target_stage_id = self._resolve_stage_id(target_stage_id) self.state_object = _ffi_api.StateComputeAt(self.state_object, stage_id, target_stage_id, target_iter) - self.clear_cache() + self._clear_cache() def compute_root(self, stage_id): """ Parameters ---------- - stage_id : Int + stage_id : Union[int, Operation, Tensor] The index of the stage to compute root """ - if isinstance(stage_id, tvm.te.Tensor): - stage_id = self.stage_id_map[stage_id.op] - elif not isinstance(stage_id, int): - raise ValueError("stage_id must be Tensor or Int") + stage_id = self._resolve_stage_id(stage_id) self.state_object = _ffi_api.StateComputeRoot(self.state_object, stage_id) - self.clear_cache() + self._clear_cache() def compute_inline(self, stage_id): """ Parameters ---------- - stage_id : Int + stage_id : Union[int, Operation, Tensor] The index of the stage to compute inline """ - if isinstance(stage_id, tvm.te.Tensor): - stage_id = self.stage_id_map[stage_id.op] - elif not isinstance(stage_id, int): - raise ValueError("stage_id must be Tensor or Int") + stage_id = self._resolve_stage_id(stage_id) self.state_object = _ffi_api.StateComputeInline(self.state_object, stage_id) - self.clear_cache() + self._clear_cache() def cache_read(self, stage_id, scope_name, reader_stage_ids): """ Parameters ---------- - stage_id : Int + stage_id : Union[int, Operation, Tensor] The index of the stage to do cache_read - scope_name : Str - reader_stage_ids : List[Int] + scope_name : str + reader_stage_ids : List[int] Returns ------- - new_stage_id : Int + new_stage_id : int The added stage id """ - if isinstance(stage_id, tvm.te.Tensor): - stage_id = self.stage_id_map[stage_id.op] - elif not isinstance(stage_id, int): - raise ValueError("stage_id must be Tensor or Int") + stage_id = self._resolve_stage_id(stage_id) + if isinstance(reader_stage_ids, list): tmp_list = [] for reader_stage_id in reader_stage_ids: - if isinstance(reader_stage_id, tvm.te.Tensor): - tmp_list.append(self.stage_id_map[reader_stage_id.op]) - elif isinstance(reader_stage_id, int): - tmp_list.append(reader_stage_id) - else: - raise ValueError("reader_stage_id must be Tensor or Int") + tmp_list.append(self._resolve_stage_id(reader_stage_id)) reader_stage_ids = tmp_list else: - raise ValueError("reader_stage_ids must be a list of Tensor or int") self.state_object, new_stage_id = _ffi_api.StateCacheRead(self.state_object, stage_id, scope_name, reader_stage_ids, self.compute_dag) - return self.__insert_new_stage(new_stage_id) + return self._insert_new_stage(new_stage_id) def cache_write(self, stage_id, scope_name): """ Parameters ---------- - stage_id : Int + stage_id : Union[int, Operation, Tensor] The index of the stage to do cache write - scope_name : Str + scope_name : str Returns ------- - new_stage_id : Int + new_stage_id : int The added stage id """ - if isinstance(stage_id, tvm.te.Tensor): - stage_id = self.stage_id_map[stage_id.op] - elif not isinstance(stage_id, int): - raise ValueError("stage_id must be Tensor or Int") + stage_id = self._resolve_stage_id(stage_id) self.state_object, new_stage_id = _ffi_api.StateCacheWrite(self.state_object, stage_id, scope_name, self.compute_dag) - return self.__insert_new_stage(new_stage_id) + return self._insert_new_stage(new_stage_id) def pragma(self, stage_id, iterator, pragma_type): """ Parameters ---------- - stage_id : Int + stage_id : Union[int, Operation, Tensor] The index of the stage to add pragma
iterator : Iterator The iterator to add pragma - pragma_type : Str + pragma_type : str """ - if isinstance(stage_id, tvm.te.Tensor): - stage_id = self.stage_id_map[stage_id.op] - elif not isinstance(stage_id, int): - raise ValueError("stage_id must be Tensor or Int") + stage_id = self._resolve_stage_id(stage_id) self.state_object = _ffi_api.StatePragma(self.state_object, stage_id, iterator, pragma_type) - self.clear_cache() + self._clear_cache() def rfactor(self, stage_id, iterator, factor_iter_id): """ Parameters ---------- - stage_id : Int + stage_id : Union[int, Operation, Tensor] The index of the stage to apply rfactor iterator : Iterator - factor_iter_id : Int + factor_iter_id : int Returns ------- - new_stage_id : Int + new_stage_id : int The added stage id """ - if isinstance(stage_id, tvm.te.Tensor): - stage_id = self.stage_id_map[stage_id.op] - elif not isinstance(stage_id, int): - raise ValueError("stage_id must be Tensor or Int") + stage_id = self._resolve_stage_id(stage_id) self.state_object, new_stage_id = _ffi_api.StateRfactor(self.state_object, stage_id, iterator, factor_iter_id, self.compute_dag) - return self.__insert_new_stage(new_stage_id) + return self._insert_new_stage(new_stage_id) def storage_align(self, stage_id, iterator, factor, offset): """ Parameters ---------- - stage_id : Int + stage_id : Union[int, Operation, Tensor] The index of the stage to do storage align iterator : Iterator - factor : Int - offset : Int + factor : int + offset : int """ - if isinstance(stage_id, tvm.te.Tensor): - stage_id = self.stage_id_map[stage_id.op] - elif not isinstance(stage_id, int): - raise ValueError("stage_id must be Tensor or Int") + stage_id = self._resolve_stage_id(stage_id) self.state_object = _ffi_api.StateStorageAlign(self.state_object, stage_id, iterator, factor, offset) - self.clear_cache() + self._clear_cache() def tensorize(self, stage_id, iterator, ti_func_name): """ The `ti_func_name` corresponds to a globally registered function - that returns a Tensorintrin + that returns a TensorIntrin Parameters ---------- - stage_id : Int + stage_id : Union[int, Operation, Tensor] The index of the stage to be tensorized iterator : Iterator - The target iterator - ti_func_name : Str + The iterator to be tensorized + ti_func_name : str Tensorize intrinsic function name Returns @@ -574,17 +483,66 @@ res_it : Iterator The tensorized Iterator """ - if isinstance(stage_id, tvm.te.Tensor): - stage_id = self.stage_id_map[stage_id.op] - elif not isinstance(stage_id, int): - raise ValueError("stage_id must be Tensor or Int") + stage_id = self._resolve_stage_id(stage_id) self.state_object, res = _ffi_api.StateTensorize(self.state_object, stage_id, iterator, ti_func_name) - self.clear_cache() + self._clear_cache() return res + def _resolve_stage_id(self, stage_id): + if isinstance(stage_id, Operation): + return self.stage_id_map[stage_id] + elif isinstance(stage_id, tvm.te.Tensor): + return self.stage_id_map[stage_id.op] + elif isinstance(stage_id, int): + return stage_id + else: + raise ValueError("Invalid stage_id") + + def _update_stage_id_map(self): + if not self.stages_cache: + self.stages_cache = _ffi_api.StateGetStages(self.state_object) + for index, stage in enumerate(self.stages_cache): + self.stage_id_map[stage.op] = index + + def _insert_new_stage(self, new_stage_id): + new_stage_id = int(new_stage_id) + self.stages_cache = _ffi_api.StateGetStages(self.state_object) + added_op = self.stages_cache[new_stage_id].op + + 
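# Illustration (hypothetical stages): if the state holds [data, conv, relu] and a cache_write on conv inserts the new stage at index 2, relu's cached index 2 must become 3 while the new op takes index 2. +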
# Adding a new stage will change the indices of all stages behind it. But we still want to use the old ops to index stages, # so we keep updating the map and do not remove the old ops. + + # Update stage_id_map for old ops, so we can still use the old ops to index stages. + for key, value in self.stage_id_map.items(): + if value >= new_stage_id: + self.stage_id_map[key] = value + 1 + self.stage_id_map[added_op] = new_stage_id + + # Update stage_id_map for new ops + self._update_stage_id_map() + + return added_op + + def _clear_cache(self): + self.stages_cache = None + + def copy(self): + state = State(self.state_object, self.compute_dag) + state.stage_id_map = self.stage_id_map.copy() + return state + + def __getitem__(self, key): + if not self.stages_cache: + self.stages_cache = _ffi_api.StateGetStages(self.state_object) + if isinstance(key, Tensor): + key = key.op + if isinstance(key, Operation): + return self.stages_cache[self.stage_id_map[key]] + raise ValueError("Key must be a Tensor or Operation") + def __str__(self): return str(self.state_object) diff --git a/tests/python/unittest/test_ansor_compute_dag.py b/tests/python/unittest/test_ansor_compute_dag.py index 313dc1f89902..0768f82b805a 100644 --- a/tests/python/unittest/test_ansor_compute_dag.py +++ b/tests/python/unittest/test_ansor_compute_dag.py @@ -34,9 +34,9 @@ def test_infer_bound(): dag, s = get_tiled_matmul() s = dag.infer_bound_from_state(s) - A_global = s.stage_tensors[1] - B_global = s.stage_tensors[3] - C_global = s.stage_tensors[4] + A_global = s.stage_ops[1] + B_global = s.stage_ops[3] + C_global = s.stage_ops[4] assert s[B_global].iters[0].range.extent == 512 assert s[B_global].iters[1].range.extent == 16 assert s[A_global].iters[0].range.extent == 1 diff --git a/tests/python/unittest/test_ansor_feature.py b/tests/python/unittest/test_ansor_feature.py index bcc7683b3f4a..705556c65edf 100644 --- a/tests/python/unittest/test_ansor_feature.py +++ b/tests/python/unittest/test_ansor_feature.py @@ -33,7 +33,7 @@ def fequal(a, b): def test_cpu_matmul(): dag = ansor.ComputeDAG(matmul_ansor_test(512, 512, 512)) s = dag.get_init_state() - C = s.stage_tensors[2] + C = s.stage_ops[2] i, j, k = s[C].iters io, ii = s.split(C, i, [16]) @@ -42,7 +42,7 @@ def test_cpu_matmul(): s.vectorize(C, ji) s.parallel(C, io) s.parallel(C, jo) - s.unroll(2, k) + s.unroll(C, k) target = tvm.target.create('llvm') task = ansor.SearchTask(dag, "test", target) diff --git a/tests/python/unittest/test_ansor_loop_state.py b/tests/python/unittest/test_ansor_loop_state.py index 87688e276469..d90be1a78421 100644 --- a/tests/python/unittest/test_ansor_loop_state.py +++ b/tests/python/unittest/test_ansor_loop_state.py @@ -115,14 +115,14 @@ def test_compute_at_root_inline(): s0 = dag.get_init_state() # data, padding, kernel = 0, 1, 2 - conv = s0.stage_tensors[3] + conv = s0.stage_ops[3] # bias = 4 - bias_add = s0.stage_tensors[5] + bias_add = s0.stage_ops[5] # bn_scale = 6 - bn_mul = s0.stage_tensors[7] + bn_mul = s0.stage_ops[7] # bn_offset = 8 - bn_add = s0.stage_tensors[9] - relu = s0.stage_tensors[10] + bn_add = s0.stage_ops[9] + relu = s0.stage_ops[10] s0.compute_inline(bn_add) s0.compute_inline(bn_mul) @@ -193,8 +193,8 @@ def test_cache_read_write(): dag = ansor.ComputeDAG([data, kernel_data, add]) s0 = dag.get_init_state() - pad_temp = s0.stage_tensors[1] - kernel_split = s0.stage_tensors[3] + pad_temp = s0.stage_ops[1] + kernel_split = s0.stage_ops[3] # 0: init state ori_its = s0[add].iters
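A minimal usage sketch of the reworked stage-indexing API above (illustrative only; it assumes the test-style imports, reuses the `matmul_ansor_test` helper and `stage_ops` accessor from the tests, and the "local" cache scope is an assumption):

    dag = ansor.ComputeDAG(matmul_ansor_test(512, 512, 512))
    s = dag.get_init_state()
    C = s.stage_ops[2]                   # fetch a stage's Operation by index
    i, j, k = s[C].iters                 # stages can be looked up by Operation
    io, ii = s.split(C, i, [16])         # primitives accept int, Operation, or Tensor
    C_local = s.cache_write(C, "local")  # returns the Operation of the added stage
    s.compute_at(C_local, C, s[C].iters[0])  # the old op C is still a valid key

From 14a19cd9597809801d570228818aea61b7082072 Mon Sep 17 00:00:00 2001 From: Chenfan Date: Wed, 24 Jun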
2020 13:22:45 +0800 Subject: [PATCH 39/45] ComputeDAG bug fix & Add Custom TensorCore Matmul Example (#42) * Bug Fix * Sample example of Custom TensorCore Matmul --- scripts/common.py | 34 ++++---- scripts/tune_test.py | 181 +++++++++++++++++++++++++++++++++++++-- src/ansor/compute_dag.cc | 12 ++- 3 files changed, 199 insertions(+), 28 deletions(-) diff --git a/scripts/common.py b/scripts/common.py index ac25b28e55b1..e9cf58e128bb 100644 --- a/scripts/common.py +++ b/scripts/common.py @@ -81,25 +81,25 @@ def add_mn(M, N): @register_workload_func def matmul_nkkm(N, M, K, in_type='float32', out_type='float32', tensor_core_support=False): - A = te.placeholder((N, K), name='A', dtype=in_type) - B = te.placeholder((K, M), name='B', dtype=in_type) - k = te.reduce_axis((0, K), name='k') - if in_type == out_type: - if not (in_type == 'float16' and out_type == 'float16'): - tensor_core_support = False - C = te.compute((N, M), - lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), - name='C', - attrs={"ansor_tensor_core_support": "True" if tensor_core_support else "False"}) - else: + if tensor_core_support: + A = te.placeholder((N // 16, K // 16, 16, 16), name='A', dtype=in_type) + B = te.placeholder((K // 16, M // 16, 16, 16), name='B', dtype=in_type) + k = te.reduce_axis((0, K // 16), name='k') + kk = te.reduce_axis((0, 16), name='kk') if not ((in_type == 'float16' and out_type == 'float32') or \ - (in_type == 'int8' and out_type == 'int32')): - tensor_core_support = False + (in_type == 'int8' and out_type == 'int32')): + raise ValueError + C = te.compute((N // 16, M // 16, 16, 16), + lambda i, j, ii, jj: te.sum(A[i][k][ii][kk].astype(out_type) * B[k][j][kk][jj].astype(out_type), + axis=[k, kk]), + name='C') + else: + A = te.placeholder((N, K), name='A', dtype=in_type) + B = te.placeholder((K, M), name='B', dtype=in_type) + k = te.reduce_axis((0, K), name='k') C = te.compute((N, M), - lambda i, j: te.sum(A[i][k].astype(out_type) * B[k][j].astype(out_type), - axis=[k]), - name='C', - attrs={"ansor_tensor_core_support": "True" if tensor_core_support else "False"}) + lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), + name='C') return [A, B, C] diff --git a/scripts/tune_test.py b/scripts/tune_test.py index c98da3eca53b..6b39cf5e7865 100644 --- a/scripts/tune_test.py +++ b/scripts/tune_test.py @@ -24,14 +24,169 @@ import numpy as np import tvm -from tvm import ansor +from tvm import ansor, te from tvm.ansor.utils import request_remote from common import get_workload_keys, get_workload_weights, measure_schedule, str2bool +def tensor_core_meet_condition(meta_policy, state, stage_id): + pass + +def intrin_wmma_load_matrix(scope): + n = 16 + A = te.placeholder((n, n), name='A', dtype='float16') + BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope='shared', data_alignment=32, offset_factor=256) + C = te.compute((n, n), lambda i, j: A[i, j], name='C') + BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=256) + + def intrin_func(ins, outs): + ib = tvm.tir.ir_builder.create() + + BA = ins[0] + BC = outs[0] + ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync', + BC.data, n, n, n, BC.elem_offset // 256, + BA.access_ptr('r'), n, 'row_major')) + return ib.get() + + return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) + +@tvm._ffi.register_func +def intrin_wmma_load_matrix_a(): + return intrin_wmma_load_matrix("wmma.matrix_a") + +@tvm._ffi.register_func +def intrin_wmma_load_matrix_b(): + return intrin_wmma_load_matrix("wmma.matrix_b") + 
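+# Note: `state.tensorize(stage, iterator, ti_func_name)` resolves its
+# `ti_func_name` against globally registered functions, so the names these
+# loaders are registered under via `tvm._ffi.register_func` (their function
+# names, e.g. "intrin_wmma_load_matrix_a") must match the strings passed to
+# `state.tensorize` in `tensor_core_apply` below.
+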
+@tvm._ffi.register_func +def intrin_wmma_gemm(): + n = 16 + A = te.placeholder((n, n), name='A', dtype='float16') + B = te.placeholder((n, n), name='B', dtype='float16') + k = te.reduce_axis((0, n), name="k") + C = te.compute((n, n), + lambda ii, jj: + te.sum(A[ii, k].astype('float') * B[k, jj].astype('float'), axis=k), + name='C') + BA = tvm.tir.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, offset_factor=256) + BB = tvm.tir.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, offset_factor=256) + BC = tvm.tir.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, offset_factor=256) + + def intrin_func(ins, outs): + BA, BB = ins + BC, = outs + + def init(): + ib = tvm.tir.ir_builder.create() + ib.emit(tvm.tir.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, n, n, BC.elem_offset // 256, 0.0)) + return ib.get() + + def update(): + ib = tvm.tir.ir_builder.create() + ib.emit(tvm.tir.call_intrin('handle', 'tvm_mma_sync', + BC.data, BC.elem_offset // 256, + BA.data, BA.elem_offset // 256, + BB.data, BB.elem_offset // 256, + BC.data, BC.elem_offset // 256)) + return ib.get() + + return update(), init(), update() + + return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC}) + +@tvm._ffi.register_func +def intrin_wmma_store_matrix(): + n = 16 + A = te.placeholder((n, n), name='A', dtype='float32') + BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator', data_alignment=32, offset_factor=256) + C = te.compute((n, n), lambda i, j: A[i, j], name='C') + BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope='global', data_alignment=32, offset_factor=256) + + def intrin_func(ins, outs): + ib = tvm.tir.ir_builder.create() + BA = ins[0] + BC = outs[0] + ib.emit(tvm.tir.call_intrin('handle', 'tvm_store_matrix_sync', + BA.data, n, n, n, BA.elem_offset // 256, + BC.access_ptr('w'), n, 'row_major')) + return ib.get() + + return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) + +def tensor_core_apply(meta_policy, state, stage_id): + ret = [] + state = ansor.loop_state.State(state, meta_policy.cur_task.compute_dag) + + A, B, C = meta_policy.cur_task.compute_dag.ops + + C_local = state.cache_write(C, "wmma.accumulator") + + its0 = state.split(C_local, state[C_local].iters[0], [None, None]) + split_step0 = state.transform_steps_size() - 1 + its1 = state.split(C_local, state[C_local].iters[3], [None, None]) + split_step1 = state.transform_steps_size() - 1 + its2 = state.split(C_local, state[C_local].iters[8], [None]) + + state.reorder(C_local, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], + its2[0], its2[1], + state[C_local].iters[6], + state[C_local].iters[7], + state[C_local].iters[10]]) + state.fuse(C_local, [state[C_local].iters[0], state[C_local].iters[1]]) + state.fuse(C_local, [state[C_local].iters[1], state[C_local].iters[2]]) + state.fuse(C_local, [state[C_local].iters[2], state[C_local].iters[3]]) + + its0 = state.follow_split(C, state[C].iters[0], split_step0, 2) + its1 = state.follow_split(C, state[C].iters[3], split_step1, 2) + state.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], + state[C].iters[6], state[C].iters[7]]) + state.fuse(C, [state[C].iters[0], state[C].iters[1]]) + state.fuse(C, [state[C].iters[1], state[C].iters[2]]) + local_write_pos = state.fuse(C, [state[C].iters[2], state[C].iters[3]]) + state.compute_at(C_local, C, local_write_pos) + shared_read_pos = state[C_local].iters[3] + local_read_pos = 
state[C_local].iters[4] + state.bind_thread(C, state[C].iters[0], "blockIdx.x") + state.bind_thread(C, state[C].iters[1], "vthread") + state.bind_thread(C, state[C].iters[2], "threadIdx.x") + + B_shared = state.cache_read(B, "shared", [C_local]) + B_local = state.cache_read(B_shared, "wmma.matrix_b", [C_local]) + state.compute_at(B_shared, C_local, shared_read_pos) + state.compute_at(B_local, C_local, local_read_pos) + + it = state.fuse(B_shared, state[B_shared].iters[:]) + its = state.split(B_shared, it, [4]) # vectorize (TODO: add a callback check function) + state.vectorize(B_shared, its[1]) + its = state.follow_fused_split(B_shared, its[0], [split_step0, split_step1], 1, True) + state.bind_thread(B_shared, its[1], "threadIdx.x") + + A_shared = state.cache_read(A, "shared", [C_local]) + A_local = state.cache_read(A_shared, "wmma.matrix_a", [C_local]) + state.compute_at(A_shared, C_local, shared_read_pos) + state.compute_at(A_local, C_local, local_read_pos) + + it = state.fuse(A_shared, state[A_shared].iters[:]) + its = state.split(A_shared, it, [4]) # vectorize (TODO: add a callback check function) + state.vectorize(A_shared, its[1]) + its = state.follow_fused_split(A_shared, its[0], [split_step0, split_step1], 1, True) + state.bind_thread(A_shared, its[1], "threadIdx.x") + + state.tensorize(A_local, state[A_local].iters[-2], "intrin_wmma_load_matrix_a") + state.tensorize(B_local, state[B_local].iters[-2], "intrin_wmma_load_matrix_b") + state.tensorize(C_local, state[C_local].iters[-3], "intrin_wmma_gemm") + state.tensorize(C, state[C].iters[-2], "intrin_wmma_store_matrix") + + print(state) + + ret.append([state.state_object, -1]) + return ret + def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose, n_parallel, build_timeout, local_measure, rpc_device_key, rpc_host, - rpc_port, rpc_num_threads, ndk_cc, early_stopping=-1, run_timeout=10): + rpc_port, rpc_num_threads, ndk_cc, early_stopping=-1, run_timeout=10, + tensor_core_matmul=False): builder = runner = measure_ctx = None if local_measure: builder = ansor.LocalBuilder(timeout=build_timeout) @@ -52,13 +207,16 @@ def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose config_threadpool = remote.get_function('runtime.config_threadpool') config_threadpool(0, rpc_num_threads) + pre_search_callbacks = [ansor.PreloadMeasuredStates(log_file)] + if tensor_core_matmul: + pre_search_callbacks.append(ansor.PreloadCustomSketchRule(tensor_core_meet_condition, tensor_core_apply)) tune_option = ansor.TuneOption(n_trials=n_trials, early_stopping=early_stopping, num_measure_per_iter=num_measure_per_iter, verbose=verbose, builder=builder, runner=runner, measure_callbacks=[ansor.LogToFile(log_file)], - pre_search_callbacks=[ansor.PreloadMeasuredStates(log_file)]) + pre_search_callbacks=pre_search_callbacks) return tune_option, measure_ctx @@ -113,10 +271,10 @@ def tune_workload(wkl_key, target, target_host, policy, model_type, model.load(load_model_file) elif load_log_file: model.load_log_file(load_log_file) - elif model_type == "random": - model = ansor.RandomModel() - else: - raise ValueError("Invalid model: " + model_type) + elif model_type == "random": + model = ansor.RandomModel() + else: + raise ValueError("Invalid model: " + model_type) if policy == 'sketch': policy = ansor.SketchSearchPolicy(program_cost_model=model) @@ -200,11 +358,18 @@ def objective_func(costs): load_log_file = args.load_log or log_file weights = get_workload_weights(args.wkl) + # Special check for tensor core + wkl_key = args.wkl + wkl_key
= wkl_key.split("-") + tensor_core_matmul = False + if wkl_key[0] == "matmul" and wkl_key[6] == "tc": + tensor_core_matmul = True + tune_option, measure_ctx = create_tune_option(target, log_file, args.n_trials, args.num_measure_per_iter, args.verbose, args.n_parallel, args.build_timeout, args.local_measure, args.rpc_device_key, args.rpc_host, args.rpc_port, args.rpc_num_threads, - args.ndk_cc) + args.ndk_cc, tensor_core_matmul=tensor_core_matmul) if args.task_scheduler == 'no': # tune workloads one by one diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index ee87318cdd84..9e6da6ff6f3b 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -569,13 +569,11 @@ State ComputeDAG::GetInitState() const { ComputeDAG::ComputeDAG(Array tensors) { auto node = make_object(); FlopEstimator estimator; - node->tensors = std::move(tensors); node->access_analyzer = AccessAnalyzer(node->tensors); node->ops = Array(node->access_analyzer->ops_topo_order); node->flop_ct = estimator.EstimateFlop(node->ops); node->init_state = State(node->ops); - data_ = std::move(node); } @@ -587,7 +585,15 @@ ComputeDAG::ComputeDAG(const std::string& workload_key) { } else { LOG(FATAL) << "ansor.workload_key_to_tensors is not registered"; } - ComputeDAG(std::move(tens)); + + auto node = make_object(); + FlopEstimator estimator; + node->tensors = std::move(tens); + node->access_analyzer = AccessAnalyzer(node->tensors); + node->ops = Array(node->access_analyzer->ops_topo_order); + node->flop_ct = estimator.EstimateFlop(node->ops); + node->init_state = State(node->ops); + data_ = std::move(node); } std::string BaseName(const std::string& str) { From 59c88d1ecd15c0651a5bd406e25f9e65c07acf46 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 24 Jun 2020 13:58:55 +0800 Subject: [PATCH 40/45] Revert commit --- docs/conf.py | 1 - include/tvm/relay/attrs/transform.h | 13 - include/tvm/relay/transform.h | 14 - include/tvm/runtime/c_runtime_api.h | 23 - include/tvm/runtime/device_api.h | 3 +- include/tvm/runtime/ndarray.h | 12 +- scripts/common.py | 1034 ----------------- scripts/shape_configs.py | 247 ---- scripts/tune_network.py | 405 ------- scripts/tune_op_subgraph.py | 602 ---------- scripts/tune_test.py | 394 ------- src/arith/rewrite_simplify.cc | 71 +- src/relay/analysis/type_solver.cc | 1 - src/relay/op/tensor/transform.cc | 54 - src/relay/transforms/defuse_ops.cc | 91 -- .../transforms/kernel_layout_transform.cc | 66 -- .../transforms/kernel_layout_transform.h | 102 -- src/relay/transforms/pattern_util.h | 2 - src/runtime/cuda/cuda_device_api.cc | 4 - src/runtime/ndarray.cc | 80 +- src/runtime/opencl/opencl_device_api.cc | 3 - src/runtime/rpc/rpc_module.cc | 30 - src/runtime/threading_backend.cc | 9 +- src/te/schedule/schedule_dataflow_rewrite.cc | 66 +- src/tir/analysis/verify_gpu_code.cc | 44 +- src/tir/transforms/unroll_loop.cc | 20 +- tests/python/unittest/test_ansor_feature.py | 150 --- .../unittest/test_ansor_relay_integration.py | 114 -- .../unittest/test_ansor_task_scheduler.py | 52 - .../test_tir_transform_unroll_loop.py | 24 - topi/include/topi/transform.h | 69 -- topi/python/topi/nn/conv2d.py | 39 +- tutorials/ansor/README.txt | 4 - tutorials/ansor/tune_conv2d_cuda.py | 179 --- tutorials/ansor/tune_simple_subgraph.py | 193 --- tutorials/autotvm/README.txt | 4 +- 36 files changed, 31 insertions(+), 4188 deletions(-) delete mode 100644 scripts/common.py delete mode 100644 scripts/shape_configs.py delete mode 100644 scripts/tune_network.py delete mode 100644 
scripts/tune_op_subgraph.py delete mode 100644 scripts/tune_test.py delete mode 100644 src/relay/transforms/defuse_ops.cc delete mode 100644 src/relay/transforms/kernel_layout_transform.cc delete mode 100644 src/relay/transforms/kernel_layout_transform.h delete mode 100644 tests/python/unittest/test_ansor_feature.py delete mode 100644 tests/python/unittest/test_ansor_relay_integration.py delete mode 100644 tests/python/unittest/test_ansor_task_scheduler.py delete mode 100644 tutorials/ansor/README.txt delete mode 100644 tutorials/ansor/tune_conv2d_cuda.py delete mode 100644 tutorials/ansor/tune_simple_subgraph.py diff --git a/docs/conf.py b/docs/conf.py index 5826526d55b0..7ece63bd7aa8 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -198,7 +198,6 @@ '../tutorials/language', '../tutorials/optimize', '../tutorials/autotvm', - '../tutorials/ansor', '../tutorials/dev', '../tutorials/topi', '../tutorials/deployment', diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 95476ed61bdd..750a8a43163c 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -296,19 +296,6 @@ struct LayoutTransformAttrs : public tvm::AttrsNode { } }; -/*! \brief Attributes for KernelLayoutTransform operator */ -struct KernelLayoutTransformAttrs : public tvm::AttrsNode { - std::string src_layout; - std::string dst_layout; - - TVM_DECLARE_ATTRS(KernelLayoutTransformAttrs, "relay.attrs.KernelLayoutTransformAttrs") { - TVM_ATTR_FIELD(src_layout) - .describe("The source layout of the tensor. (e.g. 1N32C112H112W)"); - TVM_ATTR_FIELD(dst_layout) - .describe("The destination layout of the tensor. (e.g. 1N2C112H112W16c)"); - } -}; - /*! \brief Attributes for ShapeOf operator */ struct ShapeOfAttrs : public tvm::AttrsNode { DataType dtype; diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 5f5d9b643633..1b8b31aee5d1 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -277,20 +277,6 @@ TVM_DLL Pass CanonicalizeOps(); */ TVM_DLL Pass AlterOpLayout(); -/*! - * \brief Alternate the layouts of kernels. - * - * \return The pass. - */ -TVM_DLL Pass KernelLayoutTransform(); - -/*! - * \brief The reverse of FuseOps. - * - * \return The pass. - */ -TVM_DLL Pass DeFuseOps(); - /*! * \brief Given a dest layout, this pass transforms the expr such that most of the ops input data * layout is changed to the dest layout. In ideal situation, there are only 2 layout transforms, one diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 5a32ac7d3d9f..213c7059a5f9 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -384,29 +384,6 @@ TVM_DLL int TVMFuncListGlobalNames(int* out_size, const char*** out_array); TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_bits, int dtype_lanes, int device_type, int device_id, TVMArrayHandle* out); -/*! - * \brief Allocate a nd-array's memory of non-empty values, - * including space of shape, of given spec. - * - * \param shape The shape of the array, the data content will be copied to out - * \param ndim The number of dimension of the array. - * \param dtype_code The type code of the dtype - * \param dtype_bits The number of bits of dtype - * \param dtype_lanes The number of lanes in the dtype. - * \param device_type The device type of context - * \param device_id The device id of context. - * \param out The output handle. 
- * \return 0 when success, -1 when failure happens - */ -TVM_DLL int TVMArrayAllocNonEmpty(const tvm_index_t* shape, - int ndim, - int dtype_code, - int dtype_bits, - int dtype_lanes, - int device_type, - int device_id, - TVMArrayHandle* out); - /*! * \brief Free the TVM Array. * \param handle The array handle to be freed. diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index 9b2eb6be2160..421811a52c3b 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -44,8 +44,7 @@ enum DeviceAttrKind : int { kMaxClockRate = 6, kMultiProcessorCount = 7, kMaxThreadDimensions = 8, - kGcnArch = 9, - kMaxRegistersPerBlock = 10 + kGcnArch = 9 }; /*! \brief Number of bytes each allocation must align to */ diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 9cc66a371974..e69d802652fd 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -138,17 +138,7 @@ class NDArray : public ObjectRef { * \param ctx The context of the Array. * \return The created Array */ - TVM_DLL static NDArray Empty(std::vector shape, - DLDataType dtype, DLContext ctx); - /*! - * \brief Create an NDArray with non-empty values. - * \param shape The shape of the new array. - * \param dtype The data type of the new array. - * \param ctx The context of the Array. - * \return The created Array - */ - TVM_DLL static NDArray NonEmpty(std::vector shape, - DLDataType dtype, DLContext ctx); + TVM_DLL static NDArray Empty(std::vector shape, DLDataType dtype, DLContext ctx); /*! * \brief Create a NDArray backed by a dlpack tensor. * diff --git a/scripts/common.py b/scripts/common.py deleted file mode 100644 index e9cf58e128bb..000000000000 --- a/scripts/common.py +++ /dev/null @@ -1,1034 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -"""Common utility for scripts""" -import argparse -import math -import os -import re -import time -from collections import defaultdict, namedtuple -from typing import Dict, List, Tuple - -import numpy as np -import matplotlib.pyplot as plt - -import topi -import tvm -from tvm import te -from tvm.ansor import (LogReader, make_workload_key_func, - register_workload_func, - write_measure_records_to_file) -from tvm.contrib import ndk, util - -############################################################ -###################### Test Workloads #################### -############################################################ - -@register_workload_func -def min_mn(M, N): - A = te.placeholder((M, N), name='A') - B = topi.min(A, axis=1) - - return [A, B] - -@register_workload_func -def argmin_mn(M, N): - A = te.placeholder((M, N), name='A') - B = topi.argmin(A, axis=1) - - return [A, B] - -@register_workload_func -def softmax_mn(M, N): - A = te.placeholder((M, N), name='A') - B = topi.nn.softmax(A, axis=1) - - return [A, B] - -@register_workload_func -def norm_bmn(B, M, N): - A = te.placeholder((B, M, N), name='A') - i = te.reduce_axis((0, M)) - j = te.reduce_axis((0, N)) - C = te.compute((B,), lambda b: te.sum(A[b][i][j] * A[b][i][j], axis=[i, j]), name='C') - D = te.compute((B,), lambda b: te.sqrt(C[b]), name='D') - - return [A, D] - -@register_workload_func -def add_mn(M, N): - A = te.placeholder((M, N), name='A') - B = te.placeholder((M, N), name='B') - C = te.compute((M, N), lambda i, j: A[i][j] + B[i][j], name='C') - - return [A, B, C] - -@register_workload_func -def matmul_nkkm(N, M, K, in_type='float32', out_type='float32', - tensor_core_support=False): - if tensor_core_support: - A = te.placeholder((N // 16, K // 16, 16, 16), name='A', dtype=in_type) - B = te.placeholder((K // 16, M // 16, 16, 16), name='B', dtype=in_type) - k = te.reduce_axis((0, K // 16), name='k') - kk = te.reduce_axis((0, 16), name='kk') - if not ((in_type == 'float16' and out_type == 'float32') or \ - (in_type == 'int8' and out_type == 'int32')): - raise ValueError - C = te.compute((N // 16, M // 16, 16, 16), - lambda i, j, ii, jj: te.sum(A[i][k][ii][kk].astype(out_type) * B[k][j][kk][jj].astype(out_type), - axis=[k, kk]), - name='C') - else: - A = te.placeholder((N, K), name='A', dtype=in_type) - B = te.placeholder((K, M), name='B', dtype=in_type) - k = te.reduce_axis((0, K), name='k') - C = te.compute((N, M), - lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), - name='C') - - return [A, B, C] - -@register_workload_func -def dense_layer(batch, in_dim, out_dim): - A = te.placeholder((batch, in_dim), name='A') - B = te.placeholder((out_dim, in_dim), name='B') - k = te.reduce_axis((0, in_dim), name='k') - C = te.compute((batch, out_dim), lambda i, j: te.sum(A[i][k] * B[j][k], axis=[k]), name='C') - - return [A, B, C] - -@register_workload_func -def max_pool_2d_nchw(N, C, H, W): - data = te.placeholder((N, C, H, W), name='data') - out = topi.nn.pool(data, (2, 2), (1, 1), (0, 0, 0, 0), pool_type='max', ceil_mode=True, - layout="NCHW", count_include_pad=True) - - return [data, out] - -@register_workload_func -def add_min_relu(M, N): - A = te.placeholder((M, N), name='A') - B = te.placeholder((M, N), name='B') - C = topi.add(A, B) - D = topi.min(C, axis=1) - out = topi.nn.relu(D) - return [A, B, out] - -@register_workload_func -def conv2d_relu_softmax_min(N, H, W, CI, CO, KH, KW, strides, padding, dilation): - data = te.placeholder((N, CI, H, W), name='data') - kernel = te.placeholder((CO, CI, KH, KW), name='kernel') - 
conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation) - relu = topi.nn.relu(conv) - softmax = topi.nn.softmax(relu, axis=1) - out = topi.min(softmax, axis=1) - - return [data, kernel, out] - -@register_workload_func -def conv2d_nchw_bias(N, H, W, CI, CO, KH, KW, strides, padding, dilation): - data = te.placeholder((N, CI, H, W), name='data') - kernel = te.placeholder((CO, CI, KH, KW), name='kernel') - bias = te.placeholder((CO, 1, 1), name='bias') - conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation) - #out = topi.nn.relu(conv) - out = topi.add(conv, bias) - return [data, kernel, bias, out] - -def conv2d_nhwc_without_layout_rewrite(Input, Filter, stride, padding, dilation, out_dtype='float32'): - """A copy of `topi.nn.conv2d_nhwc` but without the 'layout_free` attribute. - We use this in single op and subgraph evaluation because we don't want to introduce graph level optimization. - """ - assert isinstance(stride, int) or len(stride) == 2 - assert isinstance(dilation, int) or len(dilation) == 2 - - if isinstance(stride, int): - stride_h = stride_w = stride - else: - stride_h, stride_w = stride - - if isinstance(dilation, int): - dilation_h = dilation_w = dilation - else: - dilation_h, dilation_w = dilation - - batch, in_height, in_width, in_channel = Input.shape - if len(Filter.shape) == 10: - kernel_h = Filter.shape[2] * Filter.shape[6] - kernel_w = Filter.shape[3] * Filter.shape[7] - channel = Filter.shape[4] * Filter.shape[8] - num_filter = Filter.shape[0] * Filter.shape[1] * Filter.shape[5] * Filter.shape[9] - #Filter = te.placeholder([kernel_h, kernel_w, channel, num_filter], Filter.dtype, Filter.name) - elif len(Filter.shape) == 11: - kernel_h = Filter.shape[3] * Filter.shape[7] - kernel_w = Filter.shape[4] * Filter.shape[8] - channel = Filter.shape[5] * Filter.shape[9] - num_filter = Filter.shape[0] * Filter.shape[1] * Filter.shape[2] * Filter.shape[6] * Filter.shape[10] - else: - kernel_h, kernel_w, channel, num_filter = Filter.shape - - # compute the output shape - dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 - dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 - pad_top, pad_left, pad_down, pad_right = topi.nn.get_pad_tuple( - padding, (dilated_kernel_h, dilated_kernel_w)) - out_channel = num_filter - out_height = topi.util.simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) - out_width = topi.util.simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) - pad_before = [0, pad_top, pad_left, 0] - pad_after = [0, pad_down, pad_right, 0] - PaddedInput = topi.nn.pad(Input, pad_before, pad_after, name="PaddedInput") - rc = te.reduce_axis((0, in_channel), name='rc') - ry = te.reduce_axis((0, kernel_h), name='ry') - rx = te.reduce_axis((0, kernel_w), name='rx') - Output = te.compute( - (batch, out_height, out_width, out_channel), - lambda nn, yy, xx, ff: te.sum( - PaddedInput[nn, yy * stride_h + ry * dilation_h, - xx * stride_w + rx * dilation_w, rc].astype(out_dtype) * - Filter[ry, rx, rc, ff].astype(out_dtype) - , axis=[ry, rx, rc]), - name="Conv2dOutput", tag="conv2d_nhwc") - return Output - - -@register_workload_func -def conv2d_nhwc_bias_with_rewrite(N, H, W, CI, CO, KH, KW, strides, padding, dilation): - data = te.placeholder((N, H, W, CI), name='data') - kernel = te.placeholder((KH, KW, CI, CO), name='kernel') - bias = te.placeholder((CO, ), name='bias') - conv = topi.nn.conv2d_nhwc(data, kernel, strides, padding, dilation) - out = topi.add(conv, bias) - return [data, kernel, bias, out] - 
-@register_workload_func -def depthwise_conv2d_nhwc_bias_with_rewrite(N, H, W, CI, CO, KH, KW, strides, padding, dilation): - data = te.placeholder((N, H, W, CI), name='data') - kernel = te.placeholder((KH, KW, CI, 1), name='kernel') - bias = te.placeholder((CO, ), name='bias') - conv = topi.nn.depthwise_conv2d_nhwc(data, kernel, strides, padding, dilation) - out = topi.add(conv, bias) - return [data, kernel, bias, out] - -@register_workload_func -def conv2d_nhwc_bias(N, H, W, CI, CO, KH, KW, strides, padding, dilation): - data = te.placeholder((N, H, W, CI), name='data') - kernel = te.placeholder((KH, KW, CI, CO), name='kernel') - bias = te.placeholder((CO, ), name='bias') - conv = conv2d_nhwc_without_layout_rewrite(data, kernel, strides, padding, dilation) - out = topi.add(conv, bias) - return [data, kernel, bias, out] - - -@register_workload_func -def conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation=1): - data = te.placeholder((N, CI, H, W), name='data') - kernel = te.placeholder((CO, CI, kernel_size, kernel_size), name='kernel') - bias = te.placeholder((CO, 1, 1), name='bias') - bn_scale = te.placeholder((CO, 1, 1), name='bn_scale') - bn_offset = te.placeholder((CO, 1, 1), name='bn_offset') - - OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 - OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 - - conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation) - conv = te.compute((N, CO, OH, OW), - lambda i, j, k, l: conv[i, j, k, l] + bias[j, 0, 0], - name='bias_add') - conv = te.compute((N, CO, OH, OW), - lambda i, j, k, l: conv[i, j, k, l] * bn_scale[j, 0, 0], - name='bn_mul') - conv = te.compute((N, CO, OH, OW), - lambda i, j, k, l: conv[i, j, k, l] + bn_offset[j, 0, 0], - name='bn_add') - out = topi.nn.relu(conv) - - return [data, kernel, bias, bn_offset, bn_scale, out] - -@register_workload_func -def conv2d_nhwc_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation=1): - data = te.placeholder((N, H, W, CI), name='data') - kernel = te.placeholder((kernel_size, kernel_size, CI, CO), name='kernel') - bias = te.placeholder((CO,), name='bias') - bn_scale = te.placeholder((CO,), name='bn_scale') - bn_offset = te.placeholder((CO,), name='bn_offset') - - OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 - OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 - - conv = conv2d_nhwc_without_layout_rewrite(data, kernel, strides, padding, dilation) - conv = te.compute((N, OH, OW, CO), - lambda i, j, k, l: conv[i, j, k, l] + bias[l], - name='bias_add') - conv = te.compute((N, OH, OW, CO), - lambda i, j, k, l: conv[i, j, k, l] * bn_scale[l], - name='bn_mul') - conv = te.compute((N, OH, OW, CO), - lambda i, j, k, l: conv[i, j, k, l] + bn_offset[l], - name='bn_add') - out = topi.nn.relu(conv) - - return [data, kernel, bias, bn_offset, bn_scale, out] - -resnet_conv2d_configs = { - # format : N, H, W, CI, CO, KH, KW, strides, padding, dilation - '18': [ - (1, 224, 224, 3, 64, 7, 7, (2, 2), (3, 3), (1, 1)), - (1, 56, 56, 64, 128, 3, 3, (2, 2), (1, 1), (1, 1)), - (1, 56, 56, 64, 128, 1, 1, (2, 2), (0, 0), (1, 1)), - (1, 56, 56, 64, 64, 3, 3, (1, 1), (1, 1), (1, 1)), - (1, 56, 56, 64, 64, 1, 1, (1, 1), (0, 0), (1, 1)), - (1, 28, 28, 128, 256, 3, 3, (2, 2), (1, 1), (1, 1)), - (1, 28, 28, 128, 256, 1, 1, (2, 2), (0, 0), (1, 1)), - (1, 28, 28, 128, 128, 3, 3, (1, 1), (1, 1), (1, 1)), - (1, 14, 14, 256, 512, 3, 3, (2, 2), (1, 1), (1, 1)), - (1, 14, 14, 256, 512, 1, 1, (2, 
2), (0, 0), (1, 1)), - (1, 14, 14, 256, 256, 3, 3, (1, 1), (1, 1), (1, 1)), - (1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1), (1, 1)), - ], - '50': [ - (1, 224, 224, 3, 64, 7, 7, (2, 2), (3, 3), (1, 1)), - (1, 56, 56, 256, 512, 1, 1, (2, 2), (0, 0), (1, 1)), - (1, 56, 56, 256, 128, 1, 1, (2, 2), (0, 0), (1, 1)), - (1, 56, 56, 256, 64, 1, 1, (1, 1), (0, 0), (1, 1)), - (1, 56, 56, 64, 256, 1, 1, (1, 1), (0, 0), (1, 1)), - (1, 56, 56, 64, 64, 3, 3, (1, 1), (1, 1), (1, 1)), - (1, 56, 56, 64, 64, 1, 1, (1, 1), (0, 0), (1, 1)), - (1, 28, 28, 512, 1024, 1, 1, (2, 2), (0, 0), (1, 1)), - (1, 28, 28, 512, 256, 1, 1, (2, 2), (0, 0), (1, 1)), - (1, 28, 28, 512, 128, 1, 1, (1, 1), (0, 0), (1, 1)), - (1, 28, 28, 128, 512, 1, 1, (1, 1), (0, 0), (1, 1)), - (1, 28, 28, 128, 128, 3, 3, (1, 1), (1, 1), (1, 1)), - (1, 14, 14, 1024, 2048, 1, 1, (2, 2), (0, 0), (1, 1)), - (1, 14, 14, 1024, 512, 1, 1, (2, 2), (0, 0), (1, 1)), - (1, 14, 14, 1024, 256, 1, 1, (1, 1), (0, 0), (1, 1)), - (1, 14, 14, 256, 1024, 1, 1, (1, 1), (0, 0), (1, 1)), - (1, 14, 14, 256, 256, 3, 3, (1, 1), (1, 1), (1, 1)), - (1, 7, 7, 2048, 512, 1, 1, (1, 1), (0, 0), (1, 1)), - (1, 7, 7, 512, 2048, 1, 1, (1, 1), (0, 0), (1, 1)), - (1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1), (1, 1)), - ], -} - -# number of appearance for all conv2ds in resnet -resnet_conv2d_weights = { - '18': [1, 1, 1, 4, 1, 1, 1, 3, 1, 1, 3, 3], - '50': [1, 1, 1, 2, 4, 3, 1, 1, 1, 3, 4, 4, 1, 1, 5, 6, 6, 2, 3, 3], -} - - -def parse_workload_name(name: str) -> List[str]: - """Parse workload name with wildcard character and abbreviation to standard names""" - if name.startswith('matmul-'): # e.g. matmul-512, matmul-1024, matmul-+ - N = name.split('-', maxsplit=1)[1] - if N == '+': - cfg_list = [256, 512, 1024] - else: - cfg_list = [N] - return ["matmul-%s" % x for x in cfg_list] - elif name.startswith('dense-'): # e.g. dense-1-512-1024, dense-16-512-512 - N = name.split('-', maxsplit=1)[1] - if N == '+': - cfg_list = ["1-512-512", "16-512-512"] - else: - cfg_list = [N] - return ["dense-%s" % x for x in cfg_list] - elif name.startswith('min-'): # e.g. min-4096 - N = name.split('-', maxsplit=1)[1] - if N == '+': - cfg_list = [4096, 8192, 16384] - else: - cfg_list = [N] - return ["min-%s" % x for x in cfg_list] - elif name.startswith('argmin-'): # e.g. argmin-4096 - N = name.split('-', maxsplit=1)[1] - if N == '+': - cfg_list = [4096, 8192, 16384] - else: - cfg_list = [N] - return ["argmin-%s" % x for x in cfg_list] - elif name.startswith('softmax-'): # e.g. softmax-4096 - N = name.split('-', maxsplit=1)[1] - if N == '+': - cfg_list = [4096, 8192, 16384] - else: - cfg_list = [N] - return ["softmax-%s" % x for x in cfg_list] - elif name.startswith('add-'): # e.g. add-4096 - N = name.split('-', maxsplit=1)[1] - if N == '+': - cfg_list = [4096, 8192, 16384] - else: - cfg_list = [N] - return ["add-%s" % x for x in cfg_list] - elif name.startswith('norm-'): # e.g. norm-1024 - N = name.split('-', maxsplit=1)[1] - if N == '+': - cfg_list = [4096, 8192, 16384] - else: - cfg_list = [N] - return ["norm-%s" % x for x in cfg_list] - elif name.startswith('add-min-relu'): # e.g. add-min-relu-4096 - N = name.split('-', maxsplit=3)[3] - if N == '+': - cfg_list = [4096, 8192, 16384] - else: - cfg_list = [N] - return ["add-min-relu-%s" % x for x in cfg_list] - elif name.startswith('nhwc-resnet-'): # e.g. 
nhwc-resnet-50.C1 - res = re.match(r'nhwc-resnet-(\d+).C([\d\+]+)(.B(\d+))?', name) - n_layers = res.group(1) - if res.group(2) == '+': - idx_list = range(len(resnet_conv2d_configs[n_layers])) - else: - idx_list = [int(res.group(2))] - - batch_size = 1 if res.group(4) is None else int(res.group(4)) - return ['nhwc-resnet-%s.C%d.B%d' % (n_layers, i, batch_size) for i in idx_list] - elif name.startswith('resnet-'): # e.g. resnet-50.C1, resnet-50.C1.B2, resnet-50.C+.B2 - res = re.match(r'resnet-(\d+).C([\d\+]+)(.B(\d+))?', name) - n_layers = res.group(1) - if res.group(2) == '+': - idx_list = range(len(resnet_conv2d_configs[n_layers])) - else: - idx_list = [int(res.group(2))] - - batch_size = 1 if res.group(4) is None else int(res.group(4)) - return ['resnet-%s.C%d.B%d' % (n_layers, i, batch_size) for i in idx_list] - elif name in ['conv2d-bn-relu', 'conv2d-relu-softmax-min', 'max-pool-2d', 'conv2d-rewrite', 'depthwise-conv2d-rewrite']: - return [name] - else: - raise ValueError("Invalid workload " + name) - - -def get_workload_keys(name: str) -> List[str]: - """Parse workload name and return the workload keys""" - normalized_names = parse_workload_name(name) - - ret = [] - for name in normalized_names: - if name.startswith('matmul-'): - name_split = name.split('-') - in_type = out_type = 'float32' - tensor_core_support = False - if len(name_split) == 2: # e.g. matmul-512 - N = K = M = int(name_split[1]) - elif len(name_split) == 4: # e.g. matmul-32-256-512 - N = int(name_split[1]) - K = int(name_split[2]) - M = int(name_split[3]) - elif len(name_split) == 6: # e.g. matmul-32-512-512-float16-float32 - N = int(name_split[1]) - K = int(name_split[2]) - M = int(name_split[3]) - in_type = name_split[4] - out_type = name_split[5] - elif len(name_split) == 7: # e.g. matmul-32-512-512-float16-float32-tc - N = int(name_split[1]) - K = int(name_split[2]) - M = int(name_split[3]) - in_type = name_split[4] - out_type = name_split[5] - tensor_core_support = name_split[6] == "tc" - else: - raise ValueError("Invalid matmul workload") - ret.append(make_workload_key_func(matmul_nkkm, - (N, M, K, in_type, out_type, tensor_core_support))) - elif name.startswith('dense-'): # e.g. dense-1-512-1024, dense-16-512-512 - name_split = name.split('-') - assert len(name_split) == 4 - batch = int(name_split[1]) - in_dim = int(name_split[2]) - out_dim = int(name_split[3]) - ret.append(make_workload_key_func(dense_layer, (batch, in_dim, out_dim))) - elif name.startswith('min-'): # e.g. min-4096 - name_split = name.split('-') - if len(name_split) == 2: - M = 64 - N = int(name_split[1]) - elif len(name_split) == 3: - M = int(name_split[1]) - N = int(name_split[2]) - else: - raise ValueError("Invalid min workload") - ret.append(make_workload_key_func(min_mn, (M, N))) - elif name.startswith('argmin-'): # e.g. argmin-4096 - name_split = name.split('-') - if len(name_split) == 2: - M = 64 - N = int(name_split[1]) - elif len(name_split) == 3: - M = int(name_split[1]) - N = int(name_split[2]) - else: - raise ValueError("Invalid argmin workload") - ret.append(make_workload_key_func(argmin_mn, (M, N))) - elif name.startswith('softmax-'): # e.g. softmax-4096 - name_split = name.split('-') - if len(name_split) == 2: - M = 64 - N = int(name_split[1]) - elif len(name_split) == 3: - M = int(name_split[1]) - N = int(name_split[2]) - else: - raise ValueError("Invalid softmax workload") - ret.append(make_workload_key_func(softmax_mn, (M, N))) - elif name.startswith('add-min-relu'): # e.g. 
add-min-relu-4096 - name_split = name.split('-') - if len(name_split) == 4: - M = 64 - N = int(name_split[3]) - elif len(name_split) == 5: - M = int(name_split[3]) - N = int(name_split[4]) - else: - raise ValueError("Invalid workload") - ret.append(make_workload_key_func(add_min_relu, (M, N))) - elif name.startswith('add-'): # e.g. add-4096 - name_split = name.split('-') - if len(name_split) == 2: - N = M = int(name_split[1]) - elif len(name_split) == 3: - M = int(name_split[1]) - N = int(name_split[2]) - else: - raise ValueError("Invalid add workload") - ret.append(make_workload_key_func(add_mn, (M, N))) - elif name.startswith('norm-'): # e.g. norm-4096 - name_split = name.split('-') - B = 2 - if len(name_split) == 2: - N = M = int(name_split[1]) - elif len(name_split) == 3: - M = int(name_split[1]) - N = int(name_split[2]) - else: - raise ValueError("Invalid norm workload") - ret.append(make_workload_key_func(norm_bmn, (B, M, N))) - elif name.startswith('nhwc-resnet-'): # e.g. nhwc-resnet-50.C1.B2 - res = re.match(r'nhwc-resnet-(\d+).C(\d+).B(\d+)', name) - n_layers = res.group(1) - idx = int(res.group(2)) - batch_size = 1 if res.group(3) is None else int(res.group(3)) - args = list(resnet_conv2d_configs[n_layers][idx]) - args[0] = batch_size - ret.append(make_workload_key_func(conv2d_nhwc_bias, args)) - elif name.startswith('resnet-'): # e.g. resnet-50.C1.B2 - res = re.match(r'resnet-(\d+).C(\d+).B(\d+)', name) - n_layers = res.group(1) - idx = int(res.group(2)) - batch_size = 1 if res.group(3) is None else int(res.group(3)) - args = list(resnet_conv2d_configs[n_layers][idx]) - args[0] = batch_size - ret.append(make_workload_key_func(conv2d_nchw_bias, args)) - elif name == 'max-pool-2d': - return [make_workload_key_func(max_pool_2d_nchw, (2, 512, 7, 7))] - elif name == 'conv2d-bn-relu': - return [make_workload_key_func(conv2d_nhwc_bn_relu, - (1, 7, 7, 512, 512, 3, 1, 1, 1)) ] - elif name == 'conv2d-rewrite': - return [ make_workload_key_func(conv2d_nhwc_bias_with_rewrite, - (1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1), (1, 1)))] - elif name == 'depthwise-conv2d-rewrite': - return [ make_workload_key_func(depthwise_conv2d_nhwc_bias_with_rewrite, - (1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1), (1, 1)))] - elif name == 'conv2d-relu-softmax-min': - return [make_workload_key_func(conv2d_relu_softmax_min, - (1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1), (1, 1)))] - else: - raise ValueError("Invalid workload " + name) - - return ret - - -def get_workload_weights(name: str) -> List[float]: - """Return weights for workload name""" - if name.startswith('resnet-'): - res = re.match(r'resnet-(\d+).C+', name) - n_layers = res.group(1) - return np.array(resnet_conv2d_weights[n_layers]) - else: - return np.ones(len(get_workload_keys(name))) - - -############################################################ -###################### Measure Tools #################### -############################################################ - - -def measure_schedule(s, - bufs, - target, - target_host=None, - remote=None, - ndk_cc=None, - number=10, - repeat=3, - min_repeat_ms=500): - """Measure the time cost of a schedule""" - func = tvm.build(s, bufs, target=target, target_host=target_host) - if remote: - ctx = remote.context(str(target), 0) - temp = util.tempdir() - remote_path = temp.relpath("tmp_deploy_lib.so") - os.environ['TVM_NDK_CC'] = ndk_cc - func.export_library(remote_path, ndk.create_shared) - remote.upload(remote_path) - func = remote.load_module("tmp_deploy_lib.so") - else: - ctx = tvm.context(str(target), 0) - - if 
os.environ.get('TVM_AUTO_CACHE_FLUSH', '0') == '1': - min_repeat_ms = 0 - number = 1 - - time_f = func.time_evaluator(func.entry_name, - ctx, - number=number, - repeat=repeat, - min_repeat_ms=min_repeat_ms) - - np_args = [np.ones(topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs] - args = [tvm.nd.array(x, ctx=ctx) for x in np_args] - ctx.sync() - - costs = time_f(*args).results - - return costs - -def check_correctness(s, bufs, s_ref, buf_ref, target, target_host=None, remote=None, ndk_cc=None): - """Check the correctness of a schedule against a reference schedule""" - func = tvm.build(s, bufs, target=target, target_host=target_host) - func_ref = tvm.build(s_ref, buf_ref, target='llvm') - - if remote: - raise NotImplemented - else: - ctx = tvm.context(str(target), 0) - ctx_ref = tvm.cpu() - - np_args = [np.ones(topi.get_const_tuple(x.shape)).astype(x.dtype) for x in bufs] - args = [tvm.nd.array(x, ctx=ctx) for x in np_args] - args_ref = [tvm.nd.array(x, ctx=ctx_ref) for x in np_args] - ctx.sync() - - func(*args) - func_ref(*args_ref) - - for arr, arr_ref in zip(args, args_ref): - np.testing.assert_allclose(arr.asnumpy(), arr_ref.asnumpy()) - - -############################################################ -##################### Other Utilities #################### -############################################################ - - -def geomean(xs): - """Compute geometric mean""" - return math.exp(math.fsum(math.log(x) for x in xs) / len(xs)) - - -def str2bool(v): - if isinstance(v, bool): - return v - if v.lower() in ('yes', 'true', 't', 'y', '1'): - return True - elif v.lower() in ('no', 'false', 'f', 'n', '0'): - return False - else: - raise argparse.ArgumentTypeError('Boolean value expected.') - - -global last_tic -last_tic = None - - -def PRINT_TIME(msg): - """Print time interval between differnt calls. This is for debug so we make the name letters capital""" - global last_tic - now = time.time() - - if last_tic is None: - last_tic = now - - print(msg, now - last_tic) - last_tic = now - - -############################################################ -###################### I/O Utilities ##################### -############################################################ - -# The format for a line in resulst file -BenchmarkRecord = namedtuple("BenchmarkRecord", [ - 'device', 'backend', 'workload_type', 'workload_name', 'library', 'algorithm', 'value', - 'time_stamp' -]) - - -class BaselineDatabase: - """A class for query records in baseline database""" - def __init__(self, filename): - self.filename = filename - - self.lines = [] - for line in open(filename): - if line.startswith('#') or line.isspace(): - continue - self.lines.append(line.split('\t')) - - def filter_records(self, devices=None, backends=None, wkl_names=None, libraries=None): - ret = [] - for line in self.lines: - line = BenchmarkRecord(*line) - - if devices is not None and line.device not in devices: - continue - if backends is not None and line.backend not in backends: - continue - if wkl_names is not None and line.workload_name not in wkl_names: - continue - if libraries is not None and line.library not in libraries: - continue - - ret.append(line) - return ret - - def get_data_dict(self, device, target, wkl_names) -> Tuple[Dict, List]: - """Return a data dict s.t. 
data[wkl][library] = cost""" - data = defaultdict(lambda: defaultdict(lambda: 1e10)) - - all_libraries = set() - - if "cpu" in target.keys: - backends = ['cpu'] - elif "gpu" in target.keys: - backends = ['gpu'] - else: - raise ValueError("Invalid target: " + target) - - # Read costs for baselines - records = self.filter_records(devices=[device], backends=backends, wkl_names=wkl_names) - for record in records: - # use min over (possible) multiple algorithms - all_libraries.add(record.library) - data[record.workload_name][record.library] = \ - min(data[record.workload_name][record.library], - np.mean(eval(record.value)['costs'])) - - return data, list(all_libraries) - - -class LogFileDatabase: - """A class for indexing best records in a log file""" - def __init__(self, filename: str, n_lines: int = -1): - inputs, results = LogReader(filename).read_lines(n_lines) - - # best records, search by (target_key, workload_key). e.g. ('gpu', 'conv2d...') - self.best_by_targetkey = {} - - # best according to (model, workload_key). e.g. ('1080ti', 'conv2d...')) - self.best_by_model = {} - - # find best records and build the index - for inp, res in zip(inputs, results): - if res.error_no != 0: - continue - - # use target keys in tvm target system as key to build best map - for target_key in inp.task.target.keys: - key = (target_key, inp.task.workload_key) - if key not in self.best_by_targetkey: - self.best_by_targetkey[key] = (inp, res) - else: - _, other_res = self.best_by_targetkey[key] - if np.mean([x.value for x in other_res.costs]) > \ - np.mean([x.value for x in res.costs]): - self.best_by_targetkey[key] = (inp, res) - - # use model as key to build best map - key = (inp.task.target.model, inp.task.workload_key) - if key not in self.best_by_model: - if inp.task.target.model != 'unknown': - self.best_by_model[key] = (inp, res) - else: - _, other_res = self.best_by_model[key] - if np.mean([x.value for x in other_res.costs]) > \ - np.mean([x.value for x in res.costs]): - self.best_by_model[key] = (inp, res) - - def write_best(self, filename: str): - best_records = list(self.best_by_targetkey.values()) - inputs = [x[0] for x in best_records] - results = [x[1] for x in best_records] - write_measure_records_to_file(filename, inputs, results) - - -############################################################ -###################### Plot Utilities #################### -############################################################ - -def max_curve(raw_curve): - """Return b[i] = max(a[:i]) """ - ret = [] - cur_max = -np.inf - for x in raw_curve: - cur_max = max(cur_max, x) - ret.append(cur_max) - return ret - -def min_curve(raw_curve): - """Return b[i] = min(a[:i]) """ - ret = [] - cur_min = np.inf - for x in raw_curve: - cur_min = min(cur_min, x) - ret.append(cur_min) - return ret - -def mean_curve(raw_curve, window_size=None): - """Return b[i] = mean(a[:i]) """ - ret = [] - mean = 0 - if window_size is None: - for i, x in enumerate(raw_curve): - mean = (mean * i + x) / (i + 1) - ret.append(mean) - else: - for i, x in enumerate(raw_curve): - if i >= window_size: - mean = (mean * window_size + x - raw_curve[i - window_size]) / window_size - else: - mean = (mean * i + x) / (i + 1) - ret.append(mean) - return ret - - -def enhance_color(color, h=1, l=1, s=1): - """Make color looks better for pyplot""" - import matplotlib.colors as mc - import colorsys - try: - c = mc.cnames[color] - except: - c = color - c = np.array(colorsys.rgb_to_hls(*mc.to_rgb(c))) - - h, l, s = h * c[0], l * c[1], s * c[2] - h, l, s = 
[max(min(x, 1), 0) for x in [h, l, s]] - - return colorsys.hls_to_rgb(h, l, s) - - -method_color_dict = { - 'ours': 'C0', - 'AutoTVM': 'C1', - - 'tensorflow': 'C2', - 'tensorflow-tensorrt': 'C9', - 'tflite': 'C2', - - 'pytorch': enhance_color('C3', l=1.1, s=0.9), - - 'FlexTensor': enhance_color('C5'), - 'halide': enhance_color('teal', l=1.25), - - 'Limit space': 'C7', - 'No fine-tuning': 'C8', - 'No task scheduler': 'C1', -} - -def method2color(method): - if '-batch-' in method: - method, batch_size = method.split('-batch-') - #return enhance_color(method_color_dict[method], s=1.1, l=1.5) - return method_color_dict[method] - else: - return method_color_dict[method] - -method_order_list = [ - 'pytorch', 'tensorflow', 'tensorflow-xla', 'tensorflow-tensorrt', - 'tflite', 'halide', 'FlexTensor', 'AutoTVM', - - 'Limit space', 'No fine-tuning', - 'ours', -] - -def method2order(method): - if '-batch-' in method: - method, batch_size = method.split('-batch-') - batch_size = int(batch_size) - return method_order_list.index(method) + batch_size / 100 - else: - return method_order_list.index(method) - -show_name_replace_dict = { - 'pytorch': "PyTorch", - 'tensorflow-tensorrt': 'TensorRT-TF', - 'tensorflow': 'TensorFlow', - 'tflite': 'TensorFlow Lite', - 'halide': 'Halide', - - 'ours': 'Ansor (ours)', - 'batch-16': 'batch', - - 'resnet_50': 'ResNet-50', - 'mobilenet_v2': 'Mobilenet V2', - 'resnet_18_3d': '3D-ResNet', - 'dcgan': 'DCGAN', - 'dqn': 'DQN', - 'bert': 'BERT', -} - -def show_name(name): - # if name.startswith('resnet-'): - # return name.split('.')[1] - for key, value in show_name_replace_dict.items(): - name = name.replace(key, value) - - return name - -def draw_grouped_bar_chart(data, baseline='pytorch', output='out.png', - yscale_log=False, yticks=None, y_max=None, - legend_bbox_to_anchor=None, legend_nrow=None, - figure_size=None, figax=None, draw_ylabel=True, draw_legend=True): - width = 1 - gap = 1.5 - fontsize = 19 - xticks_font_size = fontsize - 2 - - figure_size = figure_size or (11, 4) - legend_bbox_to_anchor = legend_bbox_to_anchor or (0.45, 1.35) - - all_methods = set() - legend_set = {} - - if figax is None: - fig, ax = plt.subplots() - axes = [] - axes.append(ax) - else: - ax = figax - - x0 = 0 - xticks = [] - xlabels = [] - - workloads = list(data.keys()) - for wkl in workloads: - ys = [] - colors = [] - - methods = list(data[wkl].keys()) - - if baseline in data[wkl]: - baseline_cost = data[wkl][baseline] - else: - # normalize to best library - baseline_cost = 1e10 - for method in methods: - if data[wkl][method] < baseline_cost: - baseline_cost = data[wkl][method] - - methods.sort(key=lambda x: method2order(x)) - for method in methods: - relative_speedup = baseline_cost / data[wkl][method] - if yticks is None: - ys.append(relative_speedup) - else: - ys.append(max(relative_speedup, yticks[0] * 1.1)) - colors.append(method2color(method)) - - # draw the bars - xs = np.arange(x0, x0 + len(ys)) - bars = ax.bar(xs, ys, width=width, color=colors) - - for method, bar_obj in zip(methods, bars): - all_methods.add(method) - if method not in legend_set: - legend_set[method] = bar_obj - - # tick and label - x0 += len(ys) + gap - - xticks.append(x0 - gap - len(ys)*width/2.0 - width/2.0) - xlabels.append(show_name(wkl)) - - ax.set_xticks(xticks) - ax.set_xticklabels(xlabels, fontsize=xticks_font_size) - plt.tick_params(axis='x', which='both', bottom='off', top='off') - - if draw_ylabel is True: - ax.set_ylabel('Relative Speedup', fontsize=fontsize) - elif isinstance(draw_ylabel, str): - 
ax.set_ylabel(draw_ylabel, fontsize=fontsize) - - if yscale_log: - ax.set_yscale('log', basey=2) - if yticks is not None: - ax.set_yticks(yticks) - if y_max: - ax.set_ylim(top=y_max) - - from matplotlib.ticker import FormatStrFormatter - ax.set_yticklabels(ax.get_yticks(), fontsize=fontsize) - ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f')) - ax.yaxis.grid(linewidth=0.4, linestyle='dotted') # draw grid line - ax.set_axisbelow(True) # grid lines are behind the rest - ax.tick_params(bottom=False, top=False, right=False) - - # put legend outside the plot - all_methods = list(all_methods) - all_methods.sort(key=lambda x : method2order(x)) - - if draw_legend: - legend_nrow = legend_nrow or 2 - ncol = (len(all_methods) + legend_nrow - 1)// legend_nrow - ax.legend([legend_set[x] for x in all_methods], - [show_name(x) for x in all_methods], - fontsize=fontsize-1, - loc='upper center', - bbox_to_anchor=legend_bbox_to_anchor, - ncol=ncol, - handlelength=1.0, - handletextpad=0.5, - columnspacing=1.1) - - if figax is None: - fig.set_size_inches(figure_size) - fig.savefig(output, bbox_inches='tight') - print("Output the plot to %s" % output) - - -def to_str_round(x, decimal=6): - if isinstance(x, str): - return x - if isinstance(x, (list, tuple)) or isinstance(x, np.ndarray): - return "[" + ", ".join([to_str_round(y, decimal=decimal) - for y in x]) + "]" - if isinstance(x, dict): - return str({k: eval(to_str_round(v)) for k, v in x.items()}) - if isinstance(x, int): - return str(x) - if isinstance(x, float): - format_str = "%%.%df" % decimal - return format_str % x - raise ValueError("Invalid value: " + str(x)) - diff --git a/scripts/shape_configs.py b/scripts/shape_configs.py deleted file mode 100644 index db6b3b9dc9aa..000000000000 --- a/scripts/shape_configs.py +++ /dev/null @@ -1,247 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -""" Shape configurations for single operator / subgraph evaluation -This file is shared by tune_op_subgraph.py and scripts in scripts/baseline/ -""" - -matmul_shapes = [ - (1, 128, 128, 128), - (1, 512, 32, 512), - (1, 512, 512, 512), - (1, 1024, 1024, 1024), -] - -conv1d_shapes = [ - # derived from conv2d_shapes - (1, 256, 64, 128, 3, 2, 1), -# (1, 256, 64, 128, 1, 2, 0), -# (1, 256, 64, 64, 1, 1, 0), -# (1, 128, 128, 256, 3, 2, 1), - (1, 128, 128, 256, 1, 2, 0), -# (1, 128, 128, 128, 3, 1, 1), -# (1, 64, 256, 512, 3, 2, 1), -# (1, 64, 256, 512, 1, 2, 0), - (1, 64, 256, 256, 5, 1, 2), - (1, 32, 512, 512, 3, 1, 1), -] - -conv2d_shapes = [ - # all conv2d layers in resnet-18 - (1, 224, 224, 3, 64, 7, 2, 3), -# (1, 56, 56, 64, 128, 3, 2, 1), -# (1, 56, 56, 64, 128, 1, 2, 0), -# (1, 56, 56, 64, 64, 3, 1, 1), - (1, 56, 56, 64, 64, 1, 1, 0), -# (1, 28, 28, 128, 256, 3, 2, 1), -# (1, 28, 28, 128, 256, 1, 2, 0), -# (1, 28, 28, 128, 128, 3, 1, 1), -# (1, 14, 14, 256, 512, 3, 2, 1), -# (1, 14, 14, 256, 512, 1, 2, 0), - (1, 14, 14, 256, 256, 3, 1, 1), - (1, 7, 7, 512, 512, 3, 1, 1), -] - -conv3d_shapes = [ - # Derived from cnov2d_shapes. Use depth=16 for all configurations - (1, 16, 224, 224, 3, 64, 7, 2, 3), -# (1, 16, 56, 56, 64, 128, 3, 2, 1), -# (1, 16, 56, 56, 64, 128, 1, 2, 0), -# (1, 16, 56, 56, 64, 64, 3, 1, 1), - (1, 16, 56, 56, 64, 64, 1, 1, 0), -# (1, 16, 28, 28, 128, 256, 3, 2, 1), -# (1, 16, 28, 28, 128, 256, 1, 2, 0), -# (1, 16, 28, 28, 128, 128, 3, 1, 1), -# (1, 16, 14, 14, 256, 512, 3, 2, 1), -# (1, 16, 14, 14, 256, 512, 1, 2, 0), - (1, 16, 14, 14, 256, 256, 3, 1, 1), - (1, 16, 7, 7, 512, 512, 3, 1, 1), -] - -group_conv2d_shapes = [ - # Derived from cnov2d_shapes. Use group=4 for all configurations - (1, 56, 56, 64, 128, 3, 2, 1 , 1, 4), -# (1, 56, 56, 64, 128, 1, 2, 0 , 1, 4), -# (1, 56, 56, 64, 64, 3, 1, 1 , 1, 4), - (1, 56, 56, 64, 64, 1, 1, 0 , 1, 4), -# (1, 28, 28, 128, 256, 3, 2, 1, 1, 4), -# (1, 28, 28, 128, 256, 1, 2, 0, 1, 4), -# (1, 28, 28, 128, 128, 3, 1, 1, 1, 4), -# (1, 14, 14, 256, 512, 3, 2, 1, 1, 4), -# (1, 14, 14, 256, 512, 1, 2, 0, 1, 4), - (1, 14, 14, 256, 256, 3, 1, 1, 1, 4), - (1, 7, 7, 512, 512, 3, 1, 1 , 1, 4), -] - -dilation_conv2d_shapes = [ - # Derived from cnov2d_shapes. 
-
-dilation_conv2d_shapes = [
-    # Derived from conv2d_shapes. Use dilation=2 for all configurations
-    (1, 224, 224, 3, 64, 7, 2, 3, 2),
-#    (1, 56, 56, 64, 128, 3, 2, 1, 2),
-#    (1, 56, 56, 64, 128, 1, 2, 0, 2),
-#    (1, 56, 56, 64, 64, 3, 1, 1, 2),
-    (1, 56, 56, 64, 64, 1, 1, 0, 2),
-#    (1, 28, 28, 128, 256, 3, 2, 1, 2),
-#    (1, 28, 28, 128, 256, 1, 2, 0, 2),
-#    (1, 28, 28, 128, 128, 3, 1, 1, 2),
-#    (1, 14, 14, 256, 512, 3, 2, 1, 2),
-#    (1, 14, 14, 256, 512, 1, 2, 0, 2),
-    (1, 14, 14, 256, 256, 3, 1, 1, 2),
-    (1, 7, 7, 512, 512, 3, 1, 1, 2),
-]
-
-depthwise_conv2d_shapes = [
-    # all depthwise conv2d layers in mobilenet
-    (1, 112, 112, 32, 3, 1, 1),
-    (1, 112, 112, 64, 3, 2, 1),
-#    (1, 56, 56, 128, 3, 1, 1),
-#    (1, 56, 56, 128, 3, 2, 1),
-#    (1, 28, 28, 256, 3, 1, 1),
-#    (1, 28, 28, 256, 3, 2, 1),
-#    (1, 14, 14, 512, 3, 1, 1),
-    (1, 14, 14, 512, 3, 2, 1),
-    (1, 7, 7, 1024, 3, 1, 1),
-]
-
-conv2d_transpose_shapes = [
-    # all conv2d transpose layers in DCGAN
-    (1, 4, 4, 512, 256, 4, 2, 1),
-    (1, 8, 8, 256, 128, 4, 2, 1),
-    (1, 16, 16, 128, 64, 4, 2, 1),
-    (1, 32, 32, 64, 3, 4, 2, 1),
-]
-
-conv2d_capsule_shapes = [
-    # all conv2d capsule layers in matrix capsules with EM routing (ICLR 2018)
-    (1, 16, 16, 32, 32, 3, 2, 1),
-    (1, 8, 8, 32, 32, 3, 1, 1),
-    (1, 16, 16, 8, 16, 3, 2, 1),
-    (1, 8, 8, 16, 16, 3, 1, 1),
-]
-
-conv2d_winograd_nhwc_shapes = [
-    (1, 56, 56, 64, 64, 3, 1, 1),
-    (1, 28, 28, 128, 128, 3, 1, 1),
-    (1, 14, 14, 256, 256, 3, 1, 1),
-    (1, 7, 7, 512, 512, 3, 1, 1),
-]
-
-conv2d_winograd_nchw_shapes = [
-    (1, 64, 56, 56, 64, 3, 1, 1),
-    (1, 128, 28, 28, 128, 3, 1, 1),
-    (1, 256, 14, 14, 256, 3, 1, 1),
-    (1, 512, 7, 7, 512, 3, 1, 1),
-]
-
-matmul_tensor_core_shapes = [
-    (16, 512, 512, 'float16', 'float32', True),
-    (32, 512, 512, 'float16', 'float32', True),
-    (512, 512, 512, 'float16', 'float32', True),
-]
-
-norm_shapes = [
-    (1, 256, 256),
-    (1, 512, 512),
-    (1, 1024, 1024),
-    (1, 4096, 1024),
-]
-
-single_op_shape_dict = {
-    'C1D': conv1d_shapes,
-    'C2D': conv2d_shapes,
-    'C3D': conv3d_shapes,
-    'GMM': matmul_shapes,
-    'GRP': group_conv2d_shapes,
-    'DIL': dilation_conv2d_shapes,
-    'DEP': depthwise_conv2d_shapes,
-    'T2D': conv2d_transpose_shapes,
-    'CAP': conv2d_capsule_shapes,
-    'NRM': norm_shapes,
-
-# The following workloads are not in our single op evaluation plan.
-# They should be moved to `common.py` and be used by `tune_wkl.py`.
-# 'C2D_NCHW': conv2d_nchw_shapes, -# 'C2DWG_NHWC': conv2d_winograd_nhwc_shapes, -# 'C2DWG_NCHW': conv2d_winograd_nchw_shapes, -# 'GMM_TC': matmul_tensor_core_shapes, -} - -conv2d_bn_relu_shapes = [ - (1, 224, 224, 3, 64, 7, 2, 3), - (1, 56, 56, 64, 128, 3, 2, 1), - (1, 28, 28, 128, 256, 1, 2, 0), - (1, 7, 7, 512, 512, 3, 1, 1, 1), - (16, 224, 224, 3, 64, 7, 2, 3), - (16, 56, 56, 64, 128, 3, 2, 1), - (16, 28, 28, 128, 256, 1, 2, 0), - (16, 7, 7, 512, 512, 3, 1, 1, 1), -] - -transpose_batch_matmul_shapes = [ - (1, 128, 12, 64), - (1, 128, 16, 64), - (1, 64, 12, 128), - (1, 128, 12, 128), - (16, 128, 12, 64), - (16, 128, 16, 64), - (16, 64, 12, 128), - (16, 128, 12, 128), -] - -subgraph_shape_dict = { - "conv2d_bn_relu": conv2d_bn_relu_shapes, - "transpose_batch_matmul": transpose_batch_matmul_shapes, -} - -resnet_shapes = [ - (1, ), - (16, ), -] - -mobilenet_v2_shapes = [ - (1, ), - (16, ), -] - -dcgan_shapes = [ - (1, ), - (16, ), -] - -dqn_shapes = [ - (1, ), - (16, ), -] - -bert_shapes = [ - (1, ), - (16, ), -] - -resnet18_3d_shapes = [ - (1, ), - (16, ), -] - -network_shape_dict = { - 'resnet_50': resnet_shapes, - 'mobilenet_v2': mobilenet_v2_shapes, - 'dcgan': dcgan_shapes, - 'dqn': dqn_shapes, - 'bert': bert_shapes, - 'resnet_18_3d': resnet18_3d_shapes, -} - diff --git a/scripts/tune_network.py b/scripts/tune_network.py deleted file mode 100644 index 188da6cbe6e6..000000000000 --- a/scripts/tune_network.py +++ /dev/null @@ -1,405 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -"""Tune a whole neural network""" -import argparse -import logging -import random -import os -import numpy as np - -import tvm -from tvm import ansor, relay -import tvm.contrib.graph_runtime as runtime -from tvm.contrib.debugger import debug_runtime -from tvm.contrib import util, ndk -from tvm.relay import testing -from tvm.ansor.utils import request_remote -#from baseline.utils import log_line, BenchmarkRecord - -from common import str2bool -from tune_test import create_tune_option - -dtype = "float32" - -def get_network(name, network_path, batch_size, layout): - """Get the relay module and random weights for a network""" - input_shape = (batch_size, 3, 224, 224) - output_shape = (batch_size, 1000) - input_name = 'data' - - if name.startswith("resnet3d"): - n_layer = int(name.split('-')[1]) - layout = "NDHWC" - image_shape = (16, 112, 112, 3) - input_shape = (batch_size, *image_shape) - mod, params = relay.testing.resnet3d.get_workload(num_layers=n_layer, batch_size=batch_size, image_shape=image_shape, dtype=dtype, layout=layout) - elif name.startswith("resnet"): - n_layer = int(name.split('-')[1]) - image_shape = (224, 224, 3) if layout == 'NHWC' else (3, 224, 224) - input_shape = (batch_size, *image_shape) - mod, params = relay.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size, layout=layout, image_shape=image_shape, dtype=dtype) - elif "lstm" in name: - mod, params = relay.testing.lstm.get_workload(iterations=10, num_hidden=512, batch_size=batch_size, dtype=dtype) - elif "mlp" in name: - input_shape = (batch_size, 1, 28, 28) - mod, params = relay.testing.mlp.get_workload(batch_size=batch_size, dtype=dtype) - elif "vgg" in name: - n_layer = int(name.split('-')[1]) - mod, params = relay.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size, dtype=dtype) - elif name == 'dcgan': - input_shape = (batch_size, 100) - mod, params = relay.testing.dcgan.get_workload(batch_size=batch_size) - elif name == 'dqn': - layout = "NHWC" - image_shape = (84, 84, 4) - input_shape = (batch_size, *image_shape) - mod, params = relay.testing.dqn.get_workload(batch_size=batch_size, image_shape=image_shape, dtype=dtype, layout=layout) - elif name == 'mobilenet': - image_shape = (224, 224, 3) if layout == 'NHWC' else (3, 224, 224) - input_shape = (batch_size, *image_shape) - mod, params = relay.testing.mobilenet.get_workload(batch_size=batch_size, layout=layout, image_shape=image_shape, dtype=dtype) - elif name == 'r3d_18': - import torch - import torchvision - - model = getattr(torchvision.models.video, name)(pretrained=False) - model = model.eval() - - # We grab the TorchScripted model via tracing - input_shape = [batch_size, 3, 16, 112, 112] - input_data = torch.randn(input_shape) - scripted_model = torch.jit.trace(model, input_data).eval() - - input_name = 'input0' # only one input, set it to this name - shape_list = {input_name: input_shape} - mod, params = relay.frontend.from_pytorch(scripted_model, - shape_list) - elif name == 'squeezenet_v1.1': - mod, params = relay.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1', dtype=dtype) - elif name == 'inception_v3': - input_shape = (batch_size, 3, 299, 299) - mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype) - elif name == 'mxnet': - # an example for mxnet model - from mxnet.gluon.model_zoo.vision import get_model - block = get_model('resnet18_v1', pretrained=True) - mod, params = relay.frontend.from_mxnet(block, shape={"input_name": input_shape}, dtype=dtype) - net = 
mod["main"] - net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs) - mod = relay.Module.from_expr(net) - elif name == 'tflite-mobilenet-v2' or name == 'tflite-resnet-v2-50': - try: - import tflite.Model - except ImportError: - raise ImportError("The tflite package must be installed") - input_name = "input" - input_shape = (1, 224, 224, 3) - output_shape = (1, 1001) - input_dtype = "float32" - tflite_model_buf = open(network_path, "rb").read() - tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) - mod, params = relay.frontend.from_tflite(tflite_model, - shape_dict={input_name: input_shape}, - dtype_dict={input_name: input_dtype}) - elif name == 'pytorch-mobilenet-v2': - import torch - - model = torch.hub.load('pytorch/vision:v0.5.0', 'mobilenet_v2', pretrained=False) - model.eval() - - input_shape = [batch_size, 3, 224, 224] - input_data = torch.randn(input_shape) - scripted_model = torch.jit.trace(model, input_data).eval() - - input_name = 'input0' - shape_list = {input_name: input_shape} - mod, params = relay.frontend.from_pytorch(scripted_model, - shape_list) - elif name == 'bert': - import tensorflow as tf - - bert_pb = './baseline/tensorflow/tf_models/bert/bert-B%d.pb' % batch_size - try: - with tf.compat.v1.gfile.GFile(bert_pb, 'rb') as f: - graph_def = tf.compat.v1.GraphDef() - graph_def.ParseFromString(f.read()) - except: - raise ValueError("Need to run ./baseline/tensorflow/bert/generate_bert_pb.py to get model first") - - input_shape = (batch_size, 128) - input_name = ['input'] - shape_dict = { - 'input': input_shape - } - out_names = [ - 'bert/pooler/dense/Tanh' - ] - - mod, params = relay.frontend.from_tensorflow(graph_def, - shape=shape_dict, - outputs=out_names) - else: - raise ValueError("Unsupported network: " + name) - - return mod, params, input_name, input_shape, output_shape - - -def create_module(data_shape, graph, lib, target, input_name, params, debug_profile, - local_measure, ndk_cc, rpc_device_key, rpc_host, rpc_port, rpc_num_threads, seed=43): - if local_measure: - if target.target_name == "cuda": - ctx = tvm.gpu() - else: - ctx = tvm.cpu() - else: - print("=============== Request Remote ===============") - if 'TVM_NDK_CC' not in os.environ: - os.environ['TVM_NDK_CC'] = ndk_cc - remote = request_remote(rpc_device_key, rpc_host, rpc_port) - - print("=============== Export ===============") - ctx = remote.cpu() - temp = util.tempdir() - path_lib = temp.relpath("deploy_lib.so") - lib.export_library(path_lib, ndk.create_shared) - - print("=============== Upload ===============") - remote.upload(path_lib) - - print("=============== Load ===============") - lib = remote.load_module("deploy_lib.so") - - if rpc_num_threads: - config_threadpool = remote.get_function('runtime.config_threadpool') - config_threadpool(0, rpc_num_threads) - - np.random.seed(seed) - data_tvm = tvm.nd.array(100 * (np.random.uniform(size=data_shape)).astype(dtype), ctx=ctx) - if debug_profile: - module = debug_runtime.create(graph, lib, ctx) - else: - module = runtime.create(graph, lib, ctx) - - if type(input_name) == list: - for name in input_name: - module.set_input(name, data_tvm) - else: - module.set_input(input_name, data_tvm) - for k, v in params.items(): - module.set_input(k, v) - - return module, ctx - - -def tune_and_evaluate(network_arguments, target, target_host, - search_policy, task_scheduler_arguments, tune_option_arguments, - tune, debug_profile, check_correctness, log_n_lines): - # Extract tasks from relay program - mod, 
params, input_name, data_shape, out_shape = get_network(**network_arguments) - - # Tune all - if tune: - print("=============== Extract Workloads ===============") - workloads, wkl_weights = ansor.extract_from_program(mod, target=target, params=params) - print("Extract %d workloads in total" % (len(workloads))) - - # Tune workloads with auto scheduler - print("=============== Tune ===============") - tasks = [] - for i, wkl_key in enumerate(workloads): - dag = ansor.workload_key_to_dag(wkl_key) - print("[========= Task %d =========]\n" % i, dag) - tasks.append(ansor.SearchTask(dag, wkl_key, target, target_host)) - - tuner = ansor.SimpleTaskScheduler(tasks, - lambda costs: sum(c * w for c, w in zip(costs, wkl_weights)), - **task_scheduler_arguments) - tune_option, measure_ctx = create_tune_option(target, **tune_option_arguments) - - if tune_option_arguments['local_measure'] and target.target_name != 'cuda': - os.environ['TVM_BIND_MASTER_CORE_0'] = "1" - tuner.tune(tune_option, search_policy) - - if measure_ctx: - del measure_ctx - - kernel_layout_rewrite = True - - # Compile graph with best states found by auto-scheduler - print("=============== Compile ===============") - with ansor.apply_history_best(tune_option_arguments['log_file'], log_n_lines): - os.environ['TVM_AUTO_CACHE_FLUSH'] = "0" - - if kernel_layout_rewrite: - ansor.prepare_layout_rewrite(mod, target=target, params=params) - else: - # disable layout rewrite - ansor.LayoutRewriteLevel.BOTH_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE - ansor.LayoutRewriteLevel.COMPUTE_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE - - with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): - graph, lib, opt_params = relay.build_module.build( - mod, target=target, params=params) - - ansor.finish_layout_rewrite() - print("=============== Compile Finish ===============") - - module, ctx = create_module(data_shape, graph, lib, target, input_name, - opt_params, debug_profile, **common_measure_parameters) - - # Evaluate - print("========== Evaluate ==========") - ftimer = module.module.time_evaluator("run", ctx, number=10, repeat=3) - prof_res = np.array(ftimer().results) - - # display profile information - if debug_profile or check_correctness: - module.run() - if check_correctness: - actual_output = module.get_output(0).asnumpy() - print(actual_output) - - print("Mean inference time (std dev): %.2f ms (%.2f ms)" % - (np.mean(prof_res) * 1000, np.std(prof_res) * 1000)) - #log_line(BenchmarkRecord(target.target_name, 'gpu' if target.target_name == 'cuda' else 'cpu', 'network', - # "%s.B%d" % (network_name, batch_size), 'AutoSchedule', layout, - # {"costs": prof_res}, time.time()), record_file) - - if check_correctness: - print("========== Check Correctness ==========") - # clean relay cache - relay.backend.compile_engine.get().clear() - - # disable layout rewrite - ansor.LayoutRewriteLevel.BOTH_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE - ansor.LayoutRewriteLevel.COMPUTE_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE - target = tvm.target.create('llvm') - with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): - graph, lib, opt_params = relay.build_module.build( - mod, target=target, params=params) - - module, _ = create_module(data_shape, graph, lib, target, input_name, - opt_params, debug_profile, **common_measure_parameters) - module.run() - - expected_output = module.get_output(0).asnumpy() - np.testing.assert_allclose(actual_output, expected_output, rtol=1e-3, atol=1e-3) - - -if __name__ == 
"__main__": - parser = argparse.ArgumentParser() - - # Search task related arguments - parser.add_argument("--network", type=str, required=True) - parser.add_argument("--network-path", type=str, default=None, help="The path of tflite model") - parser.add_argument("--batch-size", type=int, default=1) - parser.add_argument("--layout", type=str, default='NHWC') - parser.add_argument("--target", type=str, default='llvm -mcpu=core-avx2') - parser.add_argument("--target-host", type=str, default=None) - parser.add_argument("--check-correctness", type=str2bool, nargs='?', const=True, default=False) - parser.add_argument("--debug-profile", type=str2bool, nargs='?', const=True, default=False) - parser.add_argument("--tune", type=str2bool, nargs='?', const=True, default=True) - - # Search strategy related arguments - parser.add_argument("--n-trials", type=int, default=1000) - parser.add_argument("--policy", type=str, choices=['sketch'], default='sketch') - parser.add_argument("--model-type", type=str, choices=['xgb', 'random', 'no-share'], default='xgb') - parser.add_argument("--task-scheduler", type=str, default='gradient', - choices=['no', 'gradient', 'round-robin'], - help='The strategy of task scheduler') - parser.add_argument("--seed", type=int, default=0, help='random seed') - - # Log file related arguments - parser.add_argument("--log-file", type=str, help="Write measurement records to this log file") - parser.add_argument("--load-log", type=str, help="Load history log to resume the status of search") - parser.add_argument("--log-n-lines", type=int, help="Only load the first n lines for history log") - parser.add_argument("--load-model", type=str, help="Load pre trained cost model file") - - # Measurement related and other arguments - parser.add_argument("--num-measure-per-iter", type=int, default=48, - help="The number of programs to be measured at each iteration") - parser.add_argument("--build-timeout", type=int, default=10) - parser.add_argument("--run-timeout", type=int, default=10) - parser.add_argument("--early-stopping", type=int, default=-1) - parser.add_argument("--verbose", type=int, default=1) - parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True) - parser.add_argument("--rpc-device-key", type=str, default=None) - parser.add_argument("--rpc-host", type=str, default='0.0.0.0') - parser.add_argument("--rpc-port", type=int, default=9190) - parser.add_argument("--rpc-num-threads", type=int, default=None) - parser.add_argument("--n-parallel", type=int, default=1) - parser.add_argument("--ndk-cc", type=str, default=None) - args = parser.parse_args() - - np.random.seed(args.seed) - random.seed(args.seed) - logging.basicConfig() - logging.getLogger('ansor').setLevel(logging.DEBUG) - os.environ["TOPHUB_LOCATION"] = "NONE" # disable autotvm - - target = tvm.target.create(args.target) - log_file = args.log_file or "%s-B%d-%s.json" % (args.network, args.batch_size, - target.target_name) - load_log_file = args.load_log or log_file - search_policy = "%s.%s" % (args.policy, args.model_type) - if args.layout: - layout = args.layout - elif target.target_name == "cuda": - layout = "NCHW" - else: - layout = "NHWC" - - network_arguments = { - 'name': args.network, - 'network_path': args.network_path, - 'batch_size': args.batch_size, - 'layout': layout - } - - task_scheduler_parameters = { - 'strategy': args.task_scheduler, - 'load_log_file': load_log_file, - 'load_model_file': args.load_model, - 'verbose': args.verbose, - } - - common_measure_parameters = { - 
'local_measure': args.local_measure, - 'rpc_device_key': args.rpc_device_key, - 'rpc_host': args.rpc_host, - 'rpc_port': args.rpc_port, - 'rpc_num_threads': args.rpc_num_threads, - 'ndk_cc': args.ndk_cc, - } - - tune_option_arguments = { - 'log_file': log_file, - 'n_trials': args.n_trials, - 'num_measure_per_iter': args.num_measure_per_iter, - 'verbose': args.verbose, - 'n_parallel': args.n_parallel, - 'build_timeout': args.build_timeout, - 'run_timeout': args.run_timeout, - 'early_stopping': args.early_stopping, - **common_measure_parameters - } - - tune_and_evaluate(network_arguments, target, args.target_host, - search_policy, task_scheduler_parameters, tune_option_arguments, - args.tune, args.debug_profile, args.check_correctness, - args.log_n_lines) diff --git a/scripts/tune_op_subgraph.py b/scripts/tune_op_subgraph.py deleted file mode 100644 index d3e70501873e..000000000000 --- a/scripts/tune_op_subgraph.py +++ /dev/null @@ -1,602 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Tune all workloads for single op & subgraph evaluation""" -import argparse -import logging -import random - -import numpy as np - -import tvm -from tvm import te, ansor -import topi -from topi.nn.winograd_util import winograd_transform_matrices -from topi.util import get_const_tuple - -from common import measure_schedule, str2bool, norm_bmn, conv2d_nhwc_bn_relu, conv2d_nchw_bn_relu -from shape_configs import single_op_shape_dict, subgraph_shape_dict -from tune_test import tune_workloads_jointly, replay_workload, create_tune_option - -# ========================== Single Ops ========================== - -@ansor.register_workload_func -def batch_matmul_nkkm(B, N, M, K): - X = te.placeholder((B, N, K), name='A') - Y = te.placeholder((B, K, M), name='B') - k = te.reduce_axis((0, K), name='k') - Z = te.compute((B, N, M), lambda b, i, j: te.sum(X[b][i][k] * Y[b][k][j], axis=[k]), name='C') - return [X, Y, Z] - -@ansor.register_workload_func -def conv1d_nlc(N, L, CI, CO, kernel_size, stride=1, padding=0, dilation=1, groups=1): - inputs = te.placeholder((N, L, CI), name='inputs') - weight = te.placeholder((kernel_size, CI//groups, CO), name='weight') - - batch_size, in_len, in_channel = inputs.shape - k_len, channel_per_group, out_channel = weight.shape - out_channel_per_group = out_channel // groups - out_len = (in_len + 2 * padding - dilation * (k_len - 1) - 1) // stride + 1 - rc = te.reduce_axis((0, channel_per_group), name='rc') - rl = te.reduce_axis((0, k_len), name='rl') - - padded = topi.nn.pad(inputs, [0, padding, 0]) - output = te.compute( - (batch_size, out_len, out_channel), - lambda n, l, co: te.sum( - (padded[n, l * stride + rl * dilation, co // out_channel_per_group * channel_per_group + rc] * - weight[rl, rc, co]), axis=[rl, rc]), - name='conv1d_nlc' - ) - 
return [inputs, weight, output]
-
-@ansor.register_workload_func
-def conv2d_nhwc(N, H, W, CI, CO, kernel_size, stride=1, padding=0, dilation=1, groups=1):
-    inputs = te.placeholder((N, H, W, CI), name='inputs')
-    weight = te.placeholder((kernel_size, kernel_size, CI//groups, CO), name='weight')
-    batch_size, in_h, in_w, in_channel = inputs.shape
-    k_h, k_w, channel_per_group, out_channel = weight.shape
-    out_channel_per_group = out_channel // groups
-
-    out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1
-    out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1
-    rh = te.reduce_axis((0, k_h), name="rh")
-    rw = te.reduce_axis((0, k_w), name="rw")
-    rc = te.reduce_axis((0, channel_per_group), name="rc")
-
-    padded = topi.nn.pad(inputs, [0, padding, padding, 0])
-    output = te.compute(
-        (batch_size, out_h, out_w, out_channel),
-        lambda n, h, w, co: te.sum(
-            (padded[n, h * stride + rh * dilation, w * stride + rw * dilation,
-                    co // out_channel_per_group * channel_per_group + rc]
-             * weight[rh, rw, rc, co]), axis=[rh, rw, rc]
-        ),
-        name='conv2d_nhwc'
-    )
-    return [inputs, weight, output]
-
-@ansor.register_workload_func
-def conv2d_nchw(N, CI, H, W, CO, kernel_size, stride=1, padding=0, dilation=1, groups=1):
-    inputs = te.placeholder((N, CI, H, W), name='inputs')
-    weight = te.placeholder((CO, CI//groups, kernel_size, kernel_size), name='weight')
-    batch_size, in_channel, in_h, in_w = inputs.shape
-    out_channel, channel_per_group, k_h, k_w = weight.shape
-    out_channel_per_group = out_channel // groups
-
-    out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1
-    out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1
-    rc = te.reduce_axis((0, channel_per_group), name="rc")
-    rh = te.reduce_axis((0, k_h), name="rh")
-    rw = te.reduce_axis((0, k_w), name="rw")
-
-    padded = topi.nn.pad(inputs, [0, 0, padding, padding])
-    output = te.compute(
-        (batch_size, out_channel, out_h, out_w),
-        lambda n, co, h, w: te.sum(
-            (padded[n, co // out_channel_per_group * channel_per_group + rc,
-                    h * stride + rh * dilation, w * stride + rw * dilation]
-             * weight[co, rc, rh, rw]), axis=[rc, rh, rw]
-        ),
-        name='conv2d_nchw'
-    )
-    return [inputs, weight, output]
-
-@ansor.register_workload_func
-def conv3d_ndhwc(N, D, H, W, CI, CO, kernel_size, stride=1, padding=0, dilation=1, groups=1):
-    inputs = te.placeholder((N, D, H, W, CI))
-    weight = te.placeholder((kernel_size, kernel_size, kernel_size, CI//groups, CO))
-    batch_size, in_d, in_h, in_w, in_channel = inputs.shape
-    k_d, k_h, k_w, channel_per_group, out_channel = weight.shape
-    out_channel_per_group = out_channel // groups
-
-    out_d = (in_d + 2 * padding - dilation * (k_d - 1) - 1) // stride + 1
-    out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1
-    out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1
-    rd = te.reduce_axis((0, k_d), name='rd')
-    rh = te.reduce_axis((0, k_h), name='rh')
-    rw = te.reduce_axis((0, k_w), name='rw')
-    rc = te.reduce_axis((0, channel_per_group), name='rc')
-
-    padded = topi.nn.pad(inputs, [0, padding, padding, padding, 0])
-    output = te.compute(
-        (batch_size, out_d, out_h, out_w, out_channel),
-        lambda n, d, h, w, co: te.sum(
-            (padded[n, d * stride + rd * dilation,
-                    h * stride + rh * dilation, w * stride + rw * dilation,
-                    co // out_channel_per_group * channel_per_group + rc]
-             * weight[rd, rh, rw, rc, co]),
-            axis=[rd, rh, rw, rc]
-        ),
-        name='conv3d_ndhwc'
-    )
-    return [inputs, weight, output]
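Annotation, not part of the deleted file: every convolution workload above sizes its output with the same closed-form expression, which is worth checking once with concrete numbers; a worked example using the ResNet stem shape from shape_configs.py.

# Worked example of the output-size formula used by the workloads above:
#   out = (in + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
# For (N, H, W, CI, CO, kernel, stride, padding) = (1, 224, 224, 3, 64, 7, 2, 3), dilation=1:
in_h, padding, dilation, kernel, stride = 224, 3, 1, 7, 2
out_h = (in_h + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
assert out_h == 112  # (224 + 6 - 6 - 1) // 2 + 1 = 112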
-
-@ansor.register_workload_func
-def depthwise_conv2d_nhwc(N, H, W, C, kernel_size, stride=1, padding=0, dilation=1, factor=1):
-    inputs = te.placeholder((N, H, W, C))
-    weight = te.placeholder((factor, kernel_size, kernel_size, C))
-
-    batch_size, in_h, in_w, in_channel = inputs.shape
-    factor, k_h, k_w, in_channel = weight.shape
-    out_channel = in_channel * factor
-
-    assert factor.value == 1, "Not optimized for factor != 1"
-
-    out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1
-    out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1
-    rh = te.reduce_axis((0, k_h), name='rh')
-    rw = te.reduce_axis((0, k_w), name='rw')
-
-    padded = topi.nn.pad(inputs, [0, padding, padding, 0])
-    output = te.compute(
-        (batch_size, out_h, out_w, out_channel),
-        lambda n, h, w, c: te.sum(
-            (padded[n, h * stride + rh * dilation, w * stride + rw * dilation, c // factor]
-             * weight[c % factor, rh, rw, c // factor]),
-            axis=[rh, rw]
-        ),
-        name="depth_conv2d_nhwc"
-    )
-    return [inputs, weight, output]
-
-@ansor.register_workload_func
-def conv2d_transpose_nhwc(N, H, W, CI, CO, kernel_size, stride=1, padding=0):
-    inputs = te.placeholder((N, H, W, CI), name='inputs')
-    weight = te.placeholder((kernel_size, kernel_size, CI, CO), name='weight')
-
-    batch, in_h, in_w, in_c = inputs.shape
-    filter_h, filter_w, in_c, out_c = weight.shape
-    stride_h, stride_w = (stride, stride)
-
-    # compute padding
-    fpad_top, fpad_left, fpad_bottom, fpad_right = topi.nn.get_pad_tuple(padding, (filter_h, filter_w))
-    bpad_top = filter_h - 1 - fpad_top
-    bpad_bottom = filter_h - 1 - fpad_bottom
-    bpad_left = filter_w - 1 - fpad_left
-    bpad_right = filter_w - 1 - fpad_right
-
-    # padding stage
-    padded = topi.nn.pad(inputs,
-                         [0, (bpad_top + stride_h - 1) // stride_h,
-                          (bpad_left + stride_w - 1) // stride_w, 0],
-                         [0, (bpad_bottom + stride_h - 1) // stride_h,
-                          (bpad_right + stride_w - 1) // stride_w, 0])
-
-    # remove extra padding introduced by dilation
-    idxdiv = te.indexdiv
-    idxmod = te.indexmod
-    border_h = idxmod(stride_h - idxmod(bpad_top, stride_h), stride_h)
-    border_w = idxmod(stride_w - idxmod(bpad_left, stride_w), stride_w)
-
-    # dilation stage
-    strides = [1, stride_h, stride_w, 1]
-    n = len(padded.shape)
-
-    # We should embed this dilation directly into te.compute rather than creating a new te.compute.
-    # Only in this way can we use unroll to eliminate the multiplication of zeros.
- def _dilate(*indices): - not_zero = [] - index_tuple = [] - for i in range(n): - if not strides[i] == 1: - index_tuple.append(idxdiv(indices[i], strides[i])) - not_zero.append(idxmod(indices[i], strides[i]).equal(0)) - else: - index_tuple.append(indices[i]) - if not_zero: - not_zero = te.all(*not_zero) - return te.if_then_else(not_zero, padded(*index_tuple), tvm.tir.const(0.0, padded.dtype)) - return padded(*index_tuple) - - # convolution stage - out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h - out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w - rc = te.reduce_axis((0, in_c), name='rc') - rh = te.reduce_axis((0, filter_h), name='rh') - rw = te.reduce_axis((0, filter_w), name='rw') - - output = te.compute( - (batch, out_h, out_w, out_c), - lambda n, h, w, co: te.sum( - _dilate(n, h + rh + border_h, w + rw + border_w, rc) * - weight[filter_h - 1 - rh, filter_w - 1 - rw, rc, co], - axis=[rh, rw, rc]), - name="conv2d_transpose_nhwc", - attrs={"ansor_always_unroll_inner": ["h", "w", "rh", "rw", "h_c", "w_c"]}) - # todo(lmzheng): add constraints on the tile size of h and w - - return [inputs, weight, output] - -@ansor.register_workload_func -def conv2d_capsule_nhwijc(N, H, W, CI, CO, kernel_size, stride=1, padding=0, capsule_size=4): - inputs = te.placeholder((N, H, W, capsule_size, capsule_size, CI), name='inputs') - weight = te.placeholder((kernel_size, kernel_size, capsule_size, capsule_size, CI, CO), name='weight') - batch_size, in_h, in_w, _, _, in_channel = inputs.shape - k_h, k_w, _, _, _, out_channel = weight.shape - - out_h = (in_h + 2 * padding - kernel_size) // stride + 1 - out_w = (in_w + 2 * padding - kernel_size) // stride + 1 - - rh = te.reduce_axis((0, k_h), name="rh") - rw = te.reduce_axis((0, k_w), name="rw") - cap_k = te.reduce_axis((0, capsule_size), name='cap_k') - rc = te.reduce_axis((0, in_channel), name="rc") - - padded = topi.nn.pad(inputs, [0, padding, padding, 0, 0, 0]) - output = te.compute( - (batch_size, out_h, out_w, capsule_size, capsule_size, out_channel), - lambda n, h, w, cap_i, cap_j, co: te.sum( - (padded[n, h * stride + rh, w * stride + rw, cap_i, cap_k, rc] - * weight[rh, rw, cap_k, cap_j, rc, co]), axis=[rh, rw, cap_k, rc] - ), - name='conv2d_capsule_nhwijc' - ) - return [inputs, weight, output] - - -@ansor.register_workload_func -def conv2d_winograd_nhwc(N, H, W, CI, CO, kernel_size=3, stride=1, padding=0, dilation=1): - # TODO: implement tile_size - tile_size = 4 #_infer_tile_size(data, kernel) - inputs = te.placeholder((N, H, W, CI), name='inputs') - #weight = te.placeholder((kernel_size, kernel_size, CI, CO), name='weight') - N, H, W, CI = get_const_tuple(inputs.shape) - if isinstance(dilation, int): - dilation_h = dilation_w = dilation - else: - dilation_h, dilation_w = dilation - # if dilation_h != 1 or dilation_w != 1: - # weight = topi.nn.dilate(weight, (1, 1, dilation_h, dilation_w)) - KH = KW = kernel_size - HPAD, WPAD, _, _ = topi.nn.get_pad_tuple(padding, (KH, KW)) - HSTR, WSTR = (stride, stride) if isinstance(stride, int) else stride - assert HSTR == 1 and WSTR == 1 and KH == KW - - data_pad = topi.nn.pad(inputs, (0, HPAD, WPAD, 0), (0, HPAD, WPAD, 0), name="data_pad") - - r = KW - m = tile_size - alpha = m + r - 1 - A, B, G = winograd_transform_matrices(m, r, 'float32') - - H = (H + 2 * HPAD - KH) // HSTR + 1 - W = (W + 2 * WPAD - KW) // WSTR + 1 - nH, nW = (H + m - 1) // m, (W + m - 1) // m - P = N * nH * nW - r_kh = te.reduce_axis((0, KH), name='r_kh') - r_kw = te.reduce_axis((0, KW), name='r_kw') - # 
kernel_pack = te.compute((alpha, alpha, CO, CI), lambda eps, nu, co, ci: - # weight[0][0][0][0], - # name='kernel_pack') - kshape = (alpha, alpha, CO, CI) - kernel_pack = te.placeholder(kshape, inputs.dtype, name="weight") - - idxdiv = te.indexdiv - idxmod = te.indexmod - # pack input tile - input_tile = te.compute((alpha, alpha, P, CI), lambda eps, nu, p, ci: - data_pad[idxdiv(p, (nH * nW))][idxmod(idxdiv(p, nW), nH) * m + eps] - [idxmod(p, nW) * m + nu][ci], name='input_tile',) - - # transform data - r_a = te.reduce_axis((0, alpha), 'r_a') - r_b = te.reduce_axis((0, alpha), 'r_b') - data_pack = te.compute((alpha, alpha, P, CI), lambda eps, nu, p, ci: - te.sum(input_tile[r_a][r_b][p][ci] * B[r_a][eps] * B[r_b][nu], - axis=[r_a, r_b]), name='data_pack', - attrs={"ansor_no_split_at_inner": ["eps", "nu", "r_a", "r_b"], - "ansor_last_split_is_one": ["ci", "p"], - "ansor_always_unroll": ["eps", "nu", "r_a", "r_b"], - "ansor_no_cache_write": "True", - }) - - # do batch gemm - ci = te.reduce_axis((0, CI), name='ci') - bgemm = te.compute((alpha, alpha, P, CO), lambda eps, nu, p, co: - te.sum(data_pack[eps][nu][p][ci] * - kernel_pack[eps][nu][co][ci], - axis=[ci]), name='bgemm') - - # inverse transform - r_a = te.reduce_axis((0, alpha), 'r_a') - r_b = te.reduce_axis((0, alpha), 'r_b') - inverse = te.compute((m, m, P, CO), lambda vh, vw, p, co: - te.sum(bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw], - axis=[r_a, r_b]), name='inverse', - attrs={"ansor_no_split_at_inner": ["vh", "vw", "r_a", "r_b"], - "ansor_always_unroll": ["vh", "vw", "r_a", "r_b"], - "ansor_last_split_is_one": ["co", "p"], - "ansor_no_cache_write": "True", - }) - - # output - output = te.compute((N, H, W, CO), lambda n, h, w, co: - inverse[idxmod(h, m), - idxmod(w, m), - n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m), - co], - name='conv2d_winograd', - tag='conv2d_winograd_nhwc', - attrs={"ansor_no_split_at_outer": ["n", "h", "w", "co"],}) - return [inputs, kernel_pack, output] - -@ansor.register_workload_func -def conv2d_winograd_nchw(N, CI, H, W, CO, kernel_size=3, stride=1, padding=0, dilation=1, precompute=False): - # TODO: implement tile_size - tile_size = 4 #_infer_tile_size(data, kernel) - inputs = te.placeholder((N, CI, H, W), name='inputs') - #weight = te.placeholder((CO, CI, kernel_size, kernel_size), name='weight') - N, CI, H, W = get_const_tuple(inputs.shape) - # if isinstance(dilation, int): - # dilation_h = dilation_w = dilation - # else: - # dilation_h, dilation_w = dilation - # if dilation_h != 1 or dilation_w != 1: - # weight = topi.nn.dilate(weight, (1, 1, dilation_h, dilation_w)) - KH = KW = kernel_size - HPAD, WPAD, _, _ = topi.nn.get_pad_tuple(padding, (KH, KW)) - HSTR, WSTR = (stride, stride) if isinstance(stride, int) else stride - assert HSTR == 1 and WSTR == 1 and KH == KW - - data_pad = topi.nn.pad(inputs, (0, 0, HPAD, WPAD), (0, 0, HPAD, WPAD), name="data_pad") - - r = KW - m = tile_size - alpha = m + r - 1 - A, B, G = winograd_transform_matrices(m, r, 'float32') - - H = (H + 2 * HPAD - KH) // HSTR + 1 - W = (W + 2 * WPAD - KW) // WSTR + 1 - nH, nW = (H + m - 1) // m, (W + m - 1) // m - P = N * nH * nW - r_kh = te.reduce_axis((0, KH), name='r_kh') - r_kw = te.reduce_axis((0, KW), name='r_kw') - # kernel_pack = te.compute((alpha, alpha, CI, CO), lambda eps, nu, ci, co: - # weight[0][0][0][0], - # name='kernel_pack') - kshape = (alpha, alpha, CI, CO) - kernel_pack = te.placeholder(kshape, inputs.dtype, name="weight") - - idxdiv = te.indexdiv - idxmod = te.indexmod - # pack input tile - input_tile = 
te.compute((CI, P, alpha, alpha), lambda ci, p, eps, nu:
-                             data_pad[idxdiv(p, (nH * nW))][ci][idxmod(idxdiv(p, nW), nH) * m + eps]
-                             [idxmod(p, nW) * m + nu], name='input_tile')
-
-    # transform data
-    r_a = te.reduce_axis((0, alpha), 'r_a')
-    r_b = te.reduce_axis((0, alpha), 'r_b')
-    data_pack = te.compute((alpha, alpha, CI, P), lambda eps, nu, ci, p:
-                           te.sum(input_tile[ci][p][r_a][r_b] * B[r_a][eps] * B[r_b][nu],
-                                  axis=[r_a, r_b]), name='data_pack',
-                           attrs={"ansor_no_split_at_inner": ["eps", "nu", "r_a", "r_b"],
-                                  "ansor_no_split_at_outer": ["ci", "p"],
-                                  "ansor_always_unroll": ["eps", "nu", "r_a", "r_b"],
-                                  "ansor_no_cache_write": "True",
-                                  })
-
-    # do batch gemm
-    ci = te.reduce_axis((0, CI), name='ci')
-    bgemm = te.compute((alpha, alpha, CO, P), lambda eps, nu, co, p:
-                       te.sum(data_pack[eps][nu][ci][p] *
-                              kernel_pack[eps][nu][ci][co],
-                              axis=[ci]), name='bgemm')
-
-    # inverse transform
-    r_a = te.reduce_axis((0, alpha), 'r_a')
-    r_b = te.reduce_axis((0, alpha), 'r_b')
-    inverse = te.compute((CO, P, m, m), lambda co, p, vh, vw:
-                         te.sum(bgemm[r_a][r_b][co][p] * A[r_a][vh] * A[r_b][vw],
-                                axis=[r_a, r_b]), name='inverse',
-                         attrs={"ansor_no_split_at_outer": ["co", "p", "vh", "vw", "r_a", "r_b"],
-                                "ansor_always_unroll": ["vh", "vw", "r_a", "r_b"],
-                                "ansor_no_cache_write": "True"})
-
-    # output
-    output = te.compute((N, CO, H, W), lambda n, co, h, w:
-                        inverse[co, n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m),
-                                idxmod(h, m),
-                                idxmod(w, m)],
-                        name='conv2d_winograd',
-                        attrs={"ansor_no_split_at_outer": ["n", "co", "h", "w"],})
-    return [inputs, kernel_pack, output]
-
-# ========================== Subgraphs ==========================
-
-@ansor.register_workload_func
-def transpose_batch_matmul(batch, seq_len, n_head, n_dim):
-    query = te.placeholder((batch, seq_len, n_head, n_dim), name='query')
-    value = te.placeholder((batch, seq_len, n_head, n_dim), name='value')
-    query_T = te.compute((batch, n_head, seq_len, n_dim),
-                         lambda b, h, l, d: query[b, l, h, d], name="query_T")
-    value_T = te.compute((batch, n_head, n_dim, seq_len),
-                         lambda b, h, d, l: value[b, l, h, d], name="value_T")
-    k = te.reduce_axis((0, n_dim), name='k')
-    out = te.compute((batch, n_head, seq_len, seq_len),
-                     lambda b, h, i, j: te.sum(query_T[b][h][i][k] * value_T[b][h][k][j], axis=[k]),
-                     name='C')
-    return [query, value, out]
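Annotation, not part of the deleted file: transpose_batch_matmul folds both layout transposes into the batched matmul, so out[b, h, i, j] = sum_k query[b, i, h, k] * value[b, j, h, k]; a NumPy cross-check with hypothetical small sizes.

# NumPy reference for transpose_batch_matmul (sizes here are made up, just for checking).
import numpy as np
batch, seq_len, n_head, n_dim = 2, 4, 3, 5
q = np.random.randn(batch, seq_len, n_head, n_dim)
v = np.random.randn(batch, seq_len, n_head, n_dim)
ref = np.einsum('bihk,bjhk->bhij', q, v)  # same contraction as the te.compute above
assert ref.shape == (batch, n_head, seq_len, seq_len)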
-
-# ========================== Tune function & Task dicts ==========================
-
-def tune_wkl(task_func_dict, shape_dict, wkl_type, args):
-    target = tvm.target.create(args.target)
-
-    for wkl_meta_name, func in task_func_dict.items():
-        if not args.wkl in ["all", wkl_type, wkl_meta_name]:
-            continue
-
-        log_file = args.log_file or wkl_meta_name + ".json"
-        wkl_keys = []
-        for shape in shape_dict[wkl_meta_name]:
-            if shape[0] == 1:
-                shape = list(shape)
-                shape[0] = args.batch_size
-
-            wkl_key = ansor.make_workload_key_func(func, shape)
-            wkl_keys.append(wkl_key)
-            if args.fast_check:
-                break
-
-            if not args.tune:
-                cost, gflops = replay_workload(
-                    wkl_key, target, args.target_host, log_file,
-                    args.local_measure, args.rpc_device_key, args.rpc_host,
-                    args.rpc_port, args.rpc_num_threads, args.ndk_cc, False)
-                # log_line(BenchmarkRecord(target.name, 'gpu' if target.name == 'cuda' else 'cpu', 'subgraph',
-                #                          workload_name, "AutoSchedule", "default",
-                #                          {"costs": [cost]}, time.time()), args.out_file)
-
-        if args.tune:
-            print("========== Tune for %s (%d shapes) ========== " % (wkl_meta_name, len(wkl_keys)))
-
-            load_log_file = args.load_log or log_file
-            n_trials = args.n_trials_per_shape * len(wkl_keys)
-
-            tune_option, measure_ctx = create_tune_option(target, log_file,
-                n_trials, args.num_measure_per_iter, args.verbose,
-                args.n_parallel, args.build_timeout, args.local_measure,
-                args.rpc_device_key, args.rpc_host, args.rpc_port,
-                args.rpc_num_threads, args.ndk_cc)
-
-            # tune workloads jointly using the task scheduler
-            tune_workloads_jointly(wkl_keys, np.ones(len(wkl_keys)), args.task_scheduler,
-                                   target, args.target_host, args.policy, args.model_type,
-                                   args.load_model, load_log_file, tune_option)
-
-            if measure_ctx:
-                del measure_ctx
-
-
-single_op_task_func_dict = {
-    'GMM': batch_matmul_nkkm,
-    'C1D': conv1d_nlc,
-    'C2D': conv2d_nhwc,
-    'C3D': conv3d_ndhwc,
-    'GRP': conv2d_nhwc,
-    'DIL': conv2d_nhwc,
-    'DEP': depthwise_conv2d_nhwc,
-    'T2D': conv2d_transpose_nhwc,
-    'CAP': conv2d_capsule_nhwijc,
-    'NRM': norm_bmn,
-    #'SMX': softmax_mn,
-
-# The following workloads are not in our single op evaluation plan.
-# They should be moved to `common.py` and be used by `tune_wkl.py`.
-#    'C2D_NCHW': conv2d_nchw,
-#    'C2DWG_NHWC': conv2d_winograd_nhwc,
-#    'C2DWG_NCHW': conv2d_winograd_nchw,
-#    'GMM_TC': matmul_nkkm,
-}
-
-subgraph_task_func_dict = {
-    'conv2d_bn_relu': conv2d_nhwc_bn_relu,
-    #'conv2d_bn_relu': conv2d_nchw_bn_relu,  # some old logs use conv2d_nchw_bn_relu
-    'transpose_batch_matmul': transpose_batch_matmul,
-}
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    # Search task related arguments
-    parser.add_argument("--wkl", type=str, required=True,
-                        help="all - Tune all workloads; \
-                              op - Tune all single ops; \
-                              subgraph - Tune all subgraphs; \
-                              specific wkl name - Tune a specific workload")
-    parser.add_argument("--batch-size", type=int, default=1)
-    parser.add_argument("--target", type=str, default='llvm -mcpu=core-avx2')
-    parser.add_argument("--target-host", type=str, default=None)
-    parser.add_argument("--tune", type=str2bool, nargs='?', const=True, default=True)
-    parser.add_argument("--fast-check", action='store_true',
-                        help='Only run one shape for each workload.
This is used for fast checking') - - # Search strategy related arguments - parser.add_argument("--n-trials-per-shape", type=int, default=1000) - parser.add_argument("--policy", type=str, choices=['sketch', 'beam-search'], default='sketch') - parser.add_argument("--model-type", type=str, choices=['xgb', 'random', 'no-share'], default='xgb') - parser.add_argument("--task-scheduler", type=str, default='round-robin', - choices=['no', 'gradient', 'round-robin'], help='The strategy of task scheduler') - parser.add_argument("--seed", type=int, default=0, help='random seed') - - # Log file related arguments - parser.add_argument("--log-file", type=str, help="Write measurement records to this log file") - parser.add_argument("--load-log", type=str, help="Load history log to resume the status of search") - parser.add_argument("--load-model", type=str, help="Load pre-trained cost model from this file") - - # Measurement related and other arguments - parser.add_argument("--num-measure-per-iter", type=int, default=48, - help="The number of programs to be measured at each iteration") - parser.add_argument("--build-timeout", type=int, default=10) - parser.add_argument("--run-timeout", type=int, default=60) - parser.add_argument("--verbose", type=int, default=1) - parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True) - parser.add_argument("--rpc-device-key", type=str, default=None) - parser.add_argument("--rpc-host", type=str, default='0.0.0.0') - parser.add_argument("--rpc-port", type=int, default=9190) - parser.add_argument("--rpc-num-threads", type=int, default=None) - parser.add_argument("--n-parallel", type=int, default=1) - parser.add_argument("--ndk-cc", type=str, default=None) - args = parser.parse_args() - - np.random.seed(args.seed) - random.seed(args.seed) - logging.basicConfig() - logging.getLogger('ansor').setLevel(logging.DEBUG) - - # compute the number of tasks - num_tasks = 0 - for wkl_meta_name in single_op_task_func_dict: - if not args.wkl in ["all", "op", wkl_meta_name]: - continue - if args.fast_check: - num_tasks += 1 - else: - num_tasks += len(single_op_shape_dict[wkl_meta_name]) - for wkl_meta_name in subgraph_task_func_dict: - if not args.wkl in ["all", "subgraph", wkl_meta_name]: - continue - if args.fast_check: - num_tasks += 1 - else: - num_tasks += len(subgraph_shape_dict[wkl_meta_name]) - print("Number of tasks: %d\tTotal trials: %d" % (num_tasks, num_tasks * args.n_trials_per_shape)) - - # tune for tasks - tune_wkl(single_op_task_func_dict, single_op_shape_dict, "op", args) - tune_wkl(subgraph_task_func_dict, subgraph_shape_dict, "subgraph", args) diff --git a/scripts/tune_test.py b/scripts/tune_test.py deleted file mode 100644 index 6b39cf5e7865..000000000000 --- a/scripts/tune_test.py +++ /dev/null @@ -1,394 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Use auto scheduler to tune workloads""" -import argparse -import logging -import os -import random - -import numpy as np - -import tvm -from tvm import ansor, te -from tvm.ansor.utils import request_remote - -from common import get_workload_keys, get_workload_weights, measure_schedule, str2bool - -def tensor_core_meet_condition(meta_policy, state, stage_id): - pass - -def intrin_wmma_load_matrix(scope): - n = 16 - A = te.placeholder((n, n), name='A', dtype='float16') - BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope='shared', data_alignment=32, offset_factor=256) - C = te.compute((n, n), lambda i, j: A[i, j], name='C') - BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=256) - - def intrin_func(ins, outs): - ib = tvm.tir.ir_builder.create() - - BA = ins[0] - BC = outs[0] - ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync', - BC.data, n, n, n, BC.elem_offset // 256, - BA.access_ptr('r'), n, 'row_major')) - return ib.get() - - return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) - -@tvm._ffi.register_func -def intrin_wmma_load_matrix_a(): - return intrin_wmma_load_matrix("wmma.matrix_a") - -@tvm._ffi.register_func -def intrin_wmma_load_matrix_b(): - return intrin_wmma_load_matrix("wmma.matrix_b") - -@tvm._ffi.register_func -def intrin_wmma_gemm(): - n = 16 - A = te.placeholder((n, n), name='A', dtype='float16') - B = te.placeholder((n, n), name='B', dtype='float16') - k = te.reduce_axis((0, n), name="k") - C = te.compute((n, n), - lambda ii, jj: - te.sum(A[ii, k].astype('float') * B[k, jj].astype('float'), axis=k), - name='C') - BA = tvm.tir.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, offset_factor=256) - BB = tvm.tir.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, offset_factor=256) - BC = tvm.tir.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, offset_factor=256) - - def intrin_func(ins, outs): - BA, BB = ins - BC, = outs - - def init(): - ib = tvm.tir.ir_builder.create() - ib.emit(tvm.tir.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, n, n, BC.elem_offset // 256, 0.0)) - return ib.get() - - def update(): - ib = tvm.tir.ir_builder.create() - ib.emit(tvm.tir.call_intrin('handle', 'tvm_mma_sync', - BC.data, BC.elem_offset // 256, - BA.data, BA.elem_offset // 256, - BB.data, BB.elem_offset // 256, - BC.data, BC.elem_offset // 256)) - return ib.get() - - return update(), init(), update() - - return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC}) - -@tvm._ffi.register_func -def intrin_wmma_store_matrix(): - n = 16 - A = te.placeholder((n, n), name='A', dtype='float32') - BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator', data_alignment=32, offset_factor=256) - C = te.compute((n, n), lambda i, j: A[i, j], name='C') - BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope='global', data_alignment=32, offset_factor=256) - - def intrin_func(ins, outs): - ib = tvm.tir.ir_builder.create() - BA = ins[0] - BC = outs[0] - ib.emit(tvm.tir.call_intrin('handle', 'tvm_store_matrix_sync', - BA.data, n, n, n, BA.elem_offset // 256, - BC.access_ptr('w'), n, 'row_major')) - return ib.get() - - return te.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC}) - -def tensor_core_apply(meta_policy, state, stage_id): - ret = [] - state = ansor.loop_state.State(state, 
meta_policy.cur_task.compute_dag) - - A, B, C = meta_policy.cur_task.compute_dag.ops - - C_local = state.cache_write(C, "wmma.accumulator") - - its0 = state.split(C_local, state[C_local].iters[0], [None, None]) - split_step0 = state.transform_steps_size() - 1 - its1 = state.split(C_local, state[C_local].iters[3], [None, None]) - split_step1 = state.transform_steps_size() - 1 - its2 = state.split(C_local, state[C_local].iters[8], [None]) - - state.reorder(C_local, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], - its2[0], its2[1], - state[C_local].iters[6], - state[C_local].iters[7], - state[C_local].iters[10]]) - state.fuse(C_local, [state[C_local].iters[0], state[C_local].iters[1]]) - state.fuse(C_local, [state[C_local].iters[1], state[C_local].iters[2]]) - state.fuse(C_local, [state[C_local].iters[2], state[C_local].iters[3]]) - - its0 = state.follow_split(C, state[C].iters[0], split_step0, 2) - its1 = state.follow_split(C, state[C].iters[3], split_step1, 2) - state.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], - state[C].iters[6], state[C].iters[7]]) - state.fuse(C, [state[C].iters[0], state[C].iters[1]]) - state.fuse(C, [state[C].iters[1], state[C].iters[2]]) - local_write_pos = state.fuse(C, [state[C].iters[2], state[C].iters[3]]) - state.compute_at(C_local, C, local_write_pos) - shared_read_pos = state[C_local].iters[3] - local_read_pos = state[C_local].iters[4] - state.bind_thread(C, state[C].iters[0], "blockIdx.x") - state.bind_thread(C, state[C].iters[1], "vthread") - state.bind_thread(C, state[C].iters[2], "threadIdx.x") - - B_shared = state.cache_read(B, "shared", [C_local]) - B_local = state.cache_read(B_shared, "wmma.matrix_b", [C_local]) - state.compute_at(B_shared, C_local, shared_read_pos) - state.compute_at(B_local, C_local, local_read_pos) - - it = state.fuse(B_shared, state[B_shared].iters[:]) - its = state.split(B_shared, it, [4]) # vectorize add a callback check function - state.vectorize(B_shared, its[1]) - its = state.follow_fused_split(B_shared, its[0], [split_step0, split_step1], 1, True) - state.bind_thread(B_shared, its[1], "threadIdx.x") - - A_shared = state.cache_read(A, "shared", [C_local]) - A_local = state.cache_read(A_shared, "wmma.matrix_a", [C_local]) - state.compute_at(A_shared, C_local, shared_read_pos) - state.compute_at(A_local, C_local, local_read_pos) - - it = state.fuse(A_shared, state[A_shared].iters[:]) - its = state.split(A_shared, it, [4]) # vectorize add a callback check function - state.vectorize(A_shared, its[1]) - its = state.follow_fused_split(A_shared, its[0], [split_step0, split_step1], 1, True) - state.bind_thread(A_shared, its[1], "threadIdx.x") - - state.tensorize(A_local, state[A_local].iters[-2], "intrin_wmma_load_matrix_a") - state.tensorize(B_local, state[B_local].iters[-2], "intrin_wmma_load_matrix_b") - state.tensorize(C_local, state[C_local].iters[-3], "intrin_wmma_gemm") - state.tensorize(C, state[C].iters[-2], "intrin_wmma_store_matrix") - - print(state) - - ret.append([state.state_object, -1]) - return ret - -def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose, - n_parallel, build_timeout, local_measure, rpc_device_key, rpc_host, - rpc_port, rpc_num_threads, ndk_cc, early_stopping=-1, run_timeout=10, - tensor_core_matmul=False): - builder = runner = measure_ctx = None - if local_measure: - builder = ansor.LocalBuilder(timeout=build_timeout) - if target.target_name == "cuda": - measure_ctx = ansor.LocalRPCMeasureContext(repeat=1, min_repeat_ms=400) - runner = 
measure_ctx.runner - else: - os.environ['TVM_AUTO_CACHE_FLUSH'] = "1" - runner = ansor.LocalRunner(repeat=10, number=1, min_repeat_ms=0, timeout=run_timeout) - else: - os.environ['TVM_NDK_CC'] = ndk_cc - builder = ansor.LocalBuilder(timeout=build_timeout, build_func='ndk') - runner = ansor.RPCRunner(key=rpc_device_key, host=rpc_host, port=rpc_port, - timeout=run_timeout, n_parallel=n_parallel, - repeat=1, min_repeat_ms=200) - remote = request_remote(rpc_device_key, rpc_host, rpc_port) - if rpc_num_threads: - config_threadpool = remote.get_function('runtime.config_threadpool') - config_threadpool(0, rpc_num_threads) - - pre_search_callbacks = [ansor.PreloadMeasuredStates(log_file)] - if tensor_core_matmul: - pre_search_callbacks.append(ansor.PreloadCustomSketchRule(tensor_core_meet_condition, tensor_core_apply)) - tune_option = ansor.TuneOption(n_trials=n_trials, early_stopping=early_stopping, - num_measure_per_iter=num_measure_per_iter, - verbose=verbose, - builder=builder, - runner=runner, - measure_callbacks=[ansor.LogToFile(log_file)], - pre_search_callbacks=pre_search_callbacks) - - return tune_option, measure_ctx - - -def replay_workload(wkl_key, target, target_host, log_file, - local_measure=True, rpc_device_key=None, rpc_host="0.0.0.0", - rpc_port=9190, rpc_num_threads=None, ndk_cc=None, - show_lower_result=True): - cost = gflops = None - - inp, res = ansor.best_measure_pair_in_file(log_file, wkl_key, target) - if inp is None: - print("Cannot find log for: %s" % wkl_key) - else: - dag = ansor.workload_key_to_dag(inp.task.workload_key) - print("Found schedule for: %s" % wkl_key) - - s, bufs = dag.apply_steps_from_state(inp.state) - if show_lower_result: - print(tvm.lower(s, bufs, simple_mode=True)) - - if local_measure: - remote = None - else: - remote = request_remote(rpc_device_key, rpc_host, rpc_port) - if rpc_num_threads: - config_threadpool = remote.get_function('runtime.config_threadpool') - config_threadpool(0, rpc_num_threads) - - cost = np.mean((measure_schedule(s, bufs, target, target_host, - remote=remote, ndk_cc=ndk_cc))) - gflops = ansor.ComputeDAG(bufs).flop_ct / cost / 1e9 - print("Best schedule: %.2f GFLOPS\tcost: %.3f ms" % (gflops, cost * 1e3)) - - return cost, gflops - - -def tune_workload(wkl_key, target, target_host, policy, model_type, - load_model_file, load_log_file, tune_option): - """Tune a workload""" - - if False: - # Debug info. 
Print the static analysis results from the access analyzer
-        dag = ansor.workload_key_to_dag(wkl_key)
-        print(dag.access_analyzer)
-        exit()
-
-    if model_type == 'xgb':
-        model = ansor.XGBModel()
-        if load_model_file:
-            print("Load pretrained model...")
-            model.load(load_model_file)
-        elif load_log_file:
-            model.load_log_file(load_log_file)
-    elif model_type == "random":
-        model = ansor.RandomModel()
-    else:
-        raise ValueError("Invalid model: " + model_type)
-
-    if policy == 'sketch':
-        policy = ansor.SketchSearchPolicy(program_cost_model=model)
-    elif policy == 'beam-search':
-        policy = ansor.SketchSearchPolicy(program_cost_model=model,
-                                          params={'use_beam_search': 1})
-    else:
-        raise ValueError("Invalid search policy: " + policy)
-
-    s, bufs = ansor.auto_schedule(wkl_key,
-                                  target=target, target_host=target_host,
-                                  search_policy=policy,
-                                  tune_option=tune_option)
-
-def tune_workloads_jointly(wkl_keys, weights, task_scheduler, target, target_host,
-                           search_policy, model_type, load_model_file, load_log_file,
-                           tune_option):
-    """Tune multiple workloads together with a TaskScheduler"""
-    tasks = []
-    for wkl_key in wkl_keys:
-        dag = ansor.workload_key_to_dag(wkl_key)
-        tasks.append(ansor.SearchTask(dag, wkl_key, target, target_host))
-
-    def objective_func(costs):
-        return sum(c * w for c, w in zip(costs, weights))
-
-    tuner = ansor.SimpleTaskScheduler(tasks, objective_func, strategy=task_scheduler,
-                                      load_log_file=load_log_file, load_model_file=load_model_file)
-    search_policy = "%s.%s" % (search_policy, model_type)
-    tuner.tune(tune_option, search_policy)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    # Search task related arguments
-    parser.add_argument("--wkl", type=str, required=True)
-    parser.add_argument("--target", type=str, default='llvm -mcpu=core-avx2')
-    parser.add_argument("--target-host", type=str, default=None)
-    parser.add_argument("--tune", type=str2bool, nargs='?', const=True, default=True)
-
-    # Search strategy related arguments
-    parser.add_argument("--n-trials", type=int, default=1000)
-    parser.add_argument("--policy", type=str, choices=['sketch', 'beam-search'], default='sketch')
-    parser.add_argument("--model-type", type=str, choices=['xgb', 'random', 'no-share'], default='xgb')
-    parser.add_argument("--task-scheduler", type=str, default='no',
-                        choices=['no', 'gradient', 'round-robin'],
-                        help='The strategy of the task scheduler')
-    parser.add_argument("--seed", type=int, default=0, help='random seed')
-
-    # Log file related arguments
-    parser.add_argument("--log-file", type=str, help="Write measurement records to this log file")
-    parser.add_argument("--load-log", type=str, help="Load history log to resume the search status")
-    parser.add_argument("--load-model", type=str, help="Load a pre-trained cost model from this file")
-
-    # Measurement related and other arguments
-    parser.add_argument("--num-measure-per-iter", type=int, default=48,
-                        help="The number of programs to be measured at each iteration")
-    parser.add_argument("--build-timeout", type=int, default=10)
-    parser.add_argument("--run-timeout", type=int, default=60)
-    parser.add_argument("--verbose", type=int, default=1)
-    parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True)
-    parser.add_argument("--rpc-device-key", type=str, default=None)
-    parser.add_argument("--rpc-host", type=str, default='0.0.0.0')
-    parser.add_argument("--rpc-port", type=int, default=9190)
-    parser.add_argument("--rpc-num-threads", type=int, default=None)
-
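Stepping back to `tune_workloads_jointly` above: the task scheduler optimizes a single scalar, the weighted sum of per-task latencies. A minimal standalone sketch of that objective, with hypothetical costs and weights:

# Minimal sketch of the joint-tuning objective (numbers are hypothetical).
# SimpleTaskScheduler allocates trials so as to minimize this weighted sum
# of per-task latencies.
def make_objective(weights):
    def objective_func(costs):
        # costs[i] is the measured latency (in seconds) of task i
        return sum(c * w for c, w in zip(costs, weights))
    return objective_func

objective = make_objective([1.0, 2.0])  # the second task counts twice
print(objective([0.010, 0.005]))        # 0.02
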
parser.add_argument("--n-parallel", type=int, default=1) - parser.add_argument("--ndk-cc", type=str, default=None) - args = parser.parse_args() - - np.random.seed(args.seed) - random.seed(args.seed) - logging.basicConfig() - logging.getLogger('ansor').setLevel(logging.DEBUG) - - wkl_keys = get_workload_keys(args.wkl) - target = tvm.target.create(args.target) - log_file = args.log_file or args.wkl + ".json" - - # Tune workloads - if args.tune: - load_log_file = args.load_log or log_file - weights = get_workload_weights(args.wkl) - - # Special check for tensor core - wkl_key = args.wkl - wkl_key = wkl_key.split("-") - tensor_core_matmul = False - if wkl_key[0] == "matmul" and wkl_key[6] == "tc": - tensor_core_matmul = True - - tune_option, measure_ctx = create_tune_option(target, log_file, - args.n_trials, args.num_measure_per_iter, args.verbose, - args.n_parallel, args.build_timeout, args.local_measure, - args.rpc_device_key, args.rpc_host, args.rpc_port, args.rpc_num_threads, - args.ndk_cc, tensor_core_matmul=tensor_core_matmul) - - if args.task_scheduler == 'no': - # tune workloads one by one - for wkl_key in wkl_keys: - tune_workload(wkl_key, target, args.target_host, args.policy, - args.model_type, args.load_model, load_log_file, - tune_option) - else: - # tune workloads jointly with TaskScheduler - tune_workloads_jointly(wkl_keys, weights, args.task_scheduler, - target, args.target_host, args.policy, - args.model_type, args.load_model, load_log_file, - tune_option) - if measure_ctx: - del measure_ctx - - # Replay the best found schedule - if len(wkl_keys) == 1 or not args.tune: - for wkl_key in wkl_keys: - replay_workload(wkl_key, target, args.target_host, log_file, - args.local_measure, args.rpc_device_key, args.rpc_host, - args.rpc_port, args.rpc_num_threads, args.ndk_cc) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index d3af64a4f576..4887ef0ee47d 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -132,13 +132,6 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { TVM_TRY_REWRITE(ramp(b1, s1, lanes) + broadcast(x, lanes), ramp(b1 + x, s1, lanes)); TVM_TRY_REWRITE(broadcast(x, lanes) + ramp(b1, s1, lanes), ramp(x + b1, s1, lanes)); TVM_TRY_REWRITE(broadcast(x, lanes) + broadcast(y, lanes), broadcast(x + y, lanes)); - if ((x + broadcast(y, lanes)).Match(ret)) { - if (auto ps = y.Eval().as()) { - if (ps->value == 0.0) { - return x.Eval(); - } - } - } } if (IsIndexType(op->dtype)) { @@ -429,13 +422,6 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) { TVM_TRY_REWRITE(broadcast(x, lanes) * broadcast(y, lanes), broadcast(x * y, lanes)); TVM_TRY_REWRITE(ramp(b1, s1, lanes) * broadcast(x, lanes), ramp(b1 * x, s1 * x, lanes)); TVM_TRY_REWRITE(broadcast(x, lanes) * ramp(b1, s1, lanes), ramp(b1 * x, s1 * x, lanes)); - if ((broadcast(x, lanes) * y).Match(ret)) { - if (auto ps = x.Eval().as()) { - if (ps->value == 0.0) { - return make_const(op->dtype, 0.0); - } - } - } } if (IsIndexType(op->dtype)) { @@ -714,9 +700,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression - PVar w, x, y, z, b1; + PVar x, y, z, b1; // Pattern var match IntImm - PVar c1, c2, c3, c4; + PVar c1, c2, c3; // Pattern var for lanes in broadcast and ramp PVar lanes; @@ -781,11 +767,6 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { 
TVM_TRY_REWRITE_IF(floordiv(max(y, x * c1), c2), max(floordiv(y, c2), x * floordiv(c1, c2)), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); - TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), floordiv(x * c1, c2), - c1.Eval()->value > 0 && c2.Eval()->value > 0 && - c2.Eval()->value % c1.Eval()->value == 0 && - CanProveGreaterEqual(-y.Eval(), -c1.Eval()->value + 1)); - // Rules involving 3-operands. TVM_TRY_REWRITE_IF(floordiv(x * c1 + y + z, c2), x * floordiv(c1, c2) + floordiv(y + z, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); @@ -802,13 +783,6 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE_IF(floordiv(x + c1, c2), floordiv(x, c2) + floordiv(c1, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); - TVM_TRY_REWRITE_IF(floordiv(x * c1 + y * c2 + z, c3), floordiv(x * c1 + y * c2, c3), - c1.Eval()->value > 0 && c2.Eval()->value > 0 && c3.Eval()->value > 0 && - c3.Eval()->value % c1.Eval()->value == 0 && - c3.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(-z.Eval(), - std::max(-c1.Eval()->value, -c2.Eval()->value) + 1)); - TVM_TRY_REWRITE_IF(floordiv(x + y, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); TVM_TRY_REWRITE_IF(floordiv(y + x, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); @@ -833,18 +807,6 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { CanProveGreaterEqual(z.Eval(), 0)); TVM_TRY_REWRITE_IF(floordiv(y + z * x, z), floordiv(y, z) + x, CanProveGreaterEqual(z.Eval(), 0)); - - // Rules involving 4-operands - TVM_TRY_REWRITE_IF(floordiv(w * c1 + x * c2 + y * c3 + z, c4), - floordiv(w * c1 + x * c2 + y * c3, c4), - c1.Eval()->value > 0 && c2.Eval()->value > 0 && - c3.Eval()->value > 0 && c4.Eval()->value > 0 && - c4.Eval()->value % c1.Eval()->value == 0 && - c4.Eval()->value % c2.Eval()->value == 0 && - c4.Eval()->value % c3.Eval()->value == 0 && - CanProveGreaterEqual(-z.Eval(), - std::max(-c1.Eval()->value, - std::max(-c2.Eval()->value, -c3.Eval()->value)) + 1)); } return ret; } @@ -856,9 +818,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { if (const_res.defined()) return const_res; // Pattern var to match any expression - PVar w, x, y, z, b1; + PVar x, y, z, b1; // Pattern var match IntImm - PVar c1, c2, c3, c4; + PVar c1, c2; // Pattern var for lanes in broadcast and ramp PVar lanes; @@ -902,31 +864,6 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); - TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x, floordiv(c2, c1)) * c1 + y, - c1.Eval()->value > 0 && c2.Eval()->value > 0 && - c2.Eval()->value % c1.Eval()->value == 0 && - CanProveGreaterEqual(-y.Eval(), -c1.Eval()->value + 1)); - - // TODO(jcf94): For the next three rules, better use the max common factor - // of c1, c2, c3 to do the simplify - TVM_TRY_REWRITE_IF(floormod(x * c1 + y * c2 + z, c3), - floormod(x * floordiv(c1, c2) + y, floordiv(c3, c2)) * c2 + z, - c1.Eval()->value > 0 && c2.Eval()->value > 0 && - c3.Eval()->value > 0 && - c3.Eval()->value % c2.Eval()->value == 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(-z.Eval(), -c2.Eval()->value + 1)); - - TVM_TRY_REWRITE_IF(floormod(w * c1 + x * c2 + y * c3 + z, c4), - floormod(w * floordiv(c1, c3) + x * floordiv(c2, c3) + y, - floordiv(c4, c3)) * c3 + z, - c1.Eval()->value > 0 && 
c2.Eval()->value > 0 && - c3.Eval()->value > 0 && c4.Eval()->value > 0 && - c4.Eval()->value % c3.Eval()->value == 0 && - c1.Eval()->value % c3.Eval()->value == 0 && - c2.Eval()->value % c3.Eval()->value == 0 && - CanProveGreaterEqual(-z.Eval(), -c3.Eval()->value + 1)); - // try modular analysis if (floormod(x, c1).Match(ret)) { ModularSet mod = analyzer_->modular_set(x.Eval()); diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index 5b063eca4337..a192002825e6 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -219,7 +219,6 @@ class TypeSolver::Unifier : public TypeFunctor { return Type(nullptr); } - tt1 = tt2; tvm::Array shape; if (tt1->shape.size() != tt2->shape.size()) { this->solver_->ReportError(ErrorBuilder() << "tensor type `" << PrettyPrint(tt1) << "` has " diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 30269b85795f..ee5e291e3d53 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2455,60 +2455,6 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w] .set_support_level(5) .set_attr("FTVMCompute", LayoutTransformCompute); -// relay.kernel_layout_transform -TVM_REGISTER_NODE_TYPE(KernelLayoutTransformAttrs); - -Array KernelLayoutTransformCompute(const Attrs& attrs, - const Array& inputs, - const Type& out_type) { - const auto* param = attrs.as(); - CHECK(param != nullptr); - return Array{ - topi::kernel_layout_transform(inputs[0], param->src_layout, param->dst_layout) - }; -} - -bool KernelLayoutTransformRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - const auto* data = types[0].as(); - CHECK(data != nullptr); - const KernelLayoutTransformAttrs* params = attrs.as(); - - Array dst_shape; - std::vector dst_axes; - - topi::parse_kernel_layout(params->dst_layout, &dst_shape, &dst_axes); - - reporter->Assign(types[1], TensorType(dst_shape, data->dtype)); - return true; -} - -Expr MakeKernelLayoutTransform(Expr data, - String src_layout, - String dst_layout) { - auto attrs = make_object(); - attrs->src_layout = std::move(src_layout); - attrs->dst_layout = std::move(dst_layout); - static const Op& op = Op::Get("kernel_layout_transform"); - return Call(op, {data}, Attrs(attrs), {}); -} - -TVM_REGISTER_GLOBAL("relay.op._make.kernel_layout_transform") -.set_body_typed(MakeKernelLayoutTransform); - -RELAY_REGISTER_OP("kernel_layout_transform") - .describe(R"code(Transform the input kernel layout. -)code" TVM_ADD_FILELINE) - .set_attrs_type() - .set_num_inputs(1) - .add_argument("data", "Tensor", "The input tensor.") - .add_type_rel("kernel_layout_transform", KernelLayoutTransformRel) - .set_support_level(5) - .set_attr("FTVMCompute", KernelLayoutTransformCompute); - - /* relay._contrib_reverse_reshape */ Expr MakeReverseReshape(Expr data, Array newshape) { auto attrs = make_object(); diff --git a/src/relay/transforms/defuse_ops.cc b/src/relay/transforms/defuse_ops.cc deleted file mode 100644 index 1a108fb08888..000000000000 --- a/src/relay/transforms/defuse_ops.cc +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include "pattern_util.h" - -namespace tvm { -namespace relay { - -class DefuseOpsMutator : public ExprMutator { - public: - class FuncBodyMutator : public ExprMutator { - public: - Array args_; - - explicit FuncBodyMutator(const Array& args) : ExprMutator() { args_ = args; } - - Expr VisitExpr_(const VarNode* n) { - const std::string& name = n->name_hint(); - CHECK_EQ(name[0], 'p'); - std::string id_str = name.substr(1); - int id = atoi(id_str.c_str()); - CHECK(id >= 0 && size_t(id) < args_.size()); - return args_[id]; - } - }; - - Expr VisitExpr_(const CallNode* n) { - auto new_n = ExprMutator::VisitExpr_(n); - - const auto* call = new_n.as(); - if (call) { - const auto* func = call->op.as(); - if (func) { - const auto& func_call = func->body.as(); - if (func_call) { - return FuncBodyMutator(call->args).Mutate(func->body); - } - } - } - return new_n; - } -}; - -Expr DeFuseOps(const Expr& expr) { return DefuseOpsMutator().Mutate(expr); } - -namespace transform { - -Pass DeFuseOps() { - runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(relay::DeFuseOps(f)); - }; - return CreateFunctionPass(pass_func, 3, "DeFuseOps", {"InferType"}); -} - -TVM_REGISTER_GLOBAL("relay._transform.DeFuseOps").set_body_typed(DeFuseOps); - -} // namespace transform - -} // namespace relay -} // namespace tvm diff --git a/src/relay/transforms/kernel_layout_transform.cc b/src/relay/transforms/kernel_layout_transform.cc deleted file mode 100644 index 421968b8a6b9..000000000000 --- a/src/relay/transforms/kernel_layout_transform.cc +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include "kernel_layout_transform.h" - -#include -#include -#include -#include -#include - -#include -#include -#include - -namespace tvm { -namespace relay { - -// Todo: do not use global variables -std::deque KernelLayoutVisitor::global_ori_layouts_queue; -std::deque KernelLayoutVisitor::global_new_layouts_queue; - -Expr KernelLayoutTransform(const Expr& expr) { - KernelLayoutVisitor visitor; - - // Do a pre-order DFS to gather the optimal kernel layouts for all conv2d nodes. 
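Before the kernel-layout comment continues below, a look back at the `DefuseOpsMutator` above: it re-inlines a fused function by substituting each parameter `p0`, `p1`, ... with the matching call-site argument, so the substitution reduces to parsing an index out of the parameter name. A standalone Python sketch of that mapping (argument values are hypothetical):

# Sketch of FuncBodyMutator's lookup: fused Relay functions name their
# parameters "p<idx>", so a parameter resolves to the idx-th call argument.
def resolve_fused_param(name, args):
    assert name.startswith('p'), "fused params follow the p<idx> convention"
    idx = int(name[1:])
    assert 0 <= idx < len(args)
    return args[idx]

print(resolve_fused_param('p1', ['%data', '%weight']))  # prints %weight
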
-  // These layouts were written to global static variables in the python function
-  // `prepare_layout_rewrite`
-  visitor.VisitExpr(expr);
-
-  // Do a post-order DFS to mutate the layout for all conv2d nodes
-  return KernelLayoutTransformer(&visitor).Mutate(expr);
-}
-
-namespace transform {
-
-Pass KernelLayoutTransform() {
-  runtime::TypedPackedFunc pass_func =
-      [=](Function f, IRModule m, PassContext pc) {
-        return Downcast(relay::KernelLayoutTransform(f));
-      };
-  return CreateFunctionPass(pass_func, 3, "KernelLayoutTransform", {"InferType"});
-}
-
-TVM_REGISTER_GLOBAL("relay._transform.KernelLayoutTransform").set_body_typed(KernelLayoutTransform);
-
-}  // namespace transform
-
-}  // namespace relay
-}  // namespace tvm
diff --git a/src/relay/transforms/kernel_layout_transform.h b/src/relay/transforms/kernel_layout_transform.h
deleted file mode 100644
index c6c38fb71cf4..000000000000
--- a/src/relay/transforms/kernel_layout_transform.h
+++ /dev/null
@@ -1,102 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-#ifndef TVM_RELAY_TRANSFORMS_KERNEL_LAYOUT_TRANSFORM_H_
-#define TVM_RELAY_TRANSFORMS_KERNEL_LAYOUT_TRANSFORM_H_
-
-#include
-#include
-
-#include
-#include
-#include
-#include
-#include
-
-#include "../../ansor/compute_dag.h"
-#include "pattern_util.h"
-
-namespace tvm {
-namespace relay {
-
-/*! \brief A visitor to gather the optimal kernel layout for all conv2d nodes. */
-class KernelLayoutVisitor : public ExprVisitor {
- public:
-  void VisitExpr_(const CallNode* n) {
-    if (n && n->op.as() &&
-        (std::find(op_white_lists.begin(), op_white_lists.end(), n->op.as()->name) !=
-         op_white_lists.end()) &&
-        n->args[1]->type_as()->shape[3].as()->value > 1 &&
-        !global_ori_layouts_queue.empty() && !global_new_layouts_queue.empty()) {
-      ori_layouts_map[n] = global_ori_layouts_queue.front();
-      new_layouts_map[n] = global_new_layouts_queue.front();
-      // std::cout << "ori_layout " << global_ori_layouts_queue.front()
-      //           << " Filter_shape " << n->args[1]->type_as()->shape << std::endl;
-      global_ori_layouts_queue.pop_front();
-      global_new_layouts_queue.pop_front();
-    }
-    ExprVisitor::VisitExpr_(n);
-  }
-
-  std::unordered_map ori_layouts_map;
-  std::unordered_map new_layouts_map;
-  std::vector op_white_lists{"nn.contrib_conv2d_winograd_without_weight_transform",
-                             "nn.conv2d", "nn.conv3d"};
-
-  static std::deque global_ori_layouts_queue;
-  static std::deque global_new_layouts_queue;
-};
-
-/*!
\brief A mutator to rewrite kernel layout for all conv2d nodes */ -class KernelLayoutTransformer : public ExprMutator { - public: - explicit KernelLayoutTransformer(KernelLayoutVisitor* visitor) - : ExprMutator(), visitor_(visitor) {} - - Expr VisitExpr_(const CallNode* n) { - auto new_n = ExprMutator::VisitExpr_(n); - - const auto* call = new_n.as(); - std::vector op_white_lists{"nn.contrib_conv2d_winograd_without_weight_transform", - "nn.conv2d", "nn.conv3d"}; - if (call && call->op.as() && - (std::find(op_white_lists.begin(), op_white_lists.end(), n->op.as()->name) != - op_white_lists.end() && - n->args[1]->type_as()->shape[3].as()->value > 1)) { - auto ori_layout_iter = visitor_->ori_layouts_map.find(n); - auto new_layout_iter = visitor_->new_layouts_map.find(n); - if (ori_layout_iter != visitor_->ori_layouts_map.end() && - new_layout_iter != visitor_->new_layouts_map.end()) { - const std::string& ori_layout = ori_layout_iter->second; - const std::string& new_layout = new_layout_iter->second; - Expr updated_kernel = MakeKernelLayoutTransform(call->args[1], ori_layout, new_layout); - Array updated_args = {call->args[0], updated_kernel}; - new_n = Call(call->op, updated_args, call->attrs); - } - } - return new_n; - } - - private: - KernelLayoutVisitor* visitor_; -}; - -} // namespace relay -} // namespace tvm - -#endif // TVM_RELAY_TRANSFORMS_KERNEL_LAYOUT_TRANSFORM_H_ diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h index a9d3b5168e47..7518eb9ac81a 100644 --- a/src/relay/transforms/pattern_util.h +++ b/src/relay/transforms/pattern_util.h @@ -685,8 +685,6 @@ Expr MakeExpandDims(Expr data, int axis, int num_newaxis); Expr MakeLayoutTransform(Expr data, String src_layout, String dst_layout); -Expr MakeKernelLayoutTransform(Expr data, String src_layout, String dst_layout); - Expr StopFusion(Expr data); Expr CastHint(Expr data, DataType dtype); diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 4e71383cc1bb..a6d4a5499469 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -94,10 +94,6 @@ class CUDADeviceAPI final : public DeviceAPI { } case kGcnArch: return; - case kMaxRegistersPerBlock: { - CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrMaxRegistersPerBlock, ctx.device_id)); - break; - } } *rv = value; } diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index 714535ecc8a6..800a9167dadc 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -26,9 +26,6 @@ #include #include -#include -#include - #include "runtime_base.h" extern "C" { @@ -183,8 +180,7 @@ NDArray NDArray::CreateView(std::vector shape, DLDataType dtype) { DLManagedTensor* NDArray::ToDLPack() const { return Internal::ToDLPack(get_mutable()); } -NDArray NDArray::Empty(std::vector shape, DLDataType dtype, - DLContext ctx) { +NDArray NDArray::Empty(std::vector shape, DLDataType dtype, DLContext ctx) { NDArray ret = Internal::Create(shape, dtype, ctx); // setup memory content size_t size = GetDataSize(ret.get_mutable()->dl_tensor); @@ -194,59 +190,6 @@ NDArray NDArray::Empty(std::vector shape, DLDataType dtype, return ret; } - -NDArray NDArray::NonEmpty(std::vector shape, DLDataType dtype, - DLContext ctx) { - NDArray ret = Internal::Create(shape, dtype, ctx); - NDArray dummy_cpu_arr = Internal::Create(shape, dtype, {kDLCPU, 0}); - - // setup memory content - size_t size = GetDataSize(ret.get_mutable()->dl_tensor); - size_t alignment = GetDataAlignment(ret.get_mutable()->dl_tensor); - 
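The `NDArray::NonEmpty` helper being deleted in this hunk fills a host buffer with random values in [1, 10) and copies it to the target device, so measured kernels never run on all-zero (or lazily mapped) memory. A rough numpy equivalent, assuming a float-representable dtype:

# Rough Python equivalent of NDArray::NonEmpty: a device array pre-filled
# with random nonzero values instead of zeros.
import numpy as np
import tvm

def non_empty(shape, dtype, ctx):
    host = np.random.uniform(1.0, 10.0, size=shape).astype(dtype)
    return tvm.nd.array(host, ctx)

arr = non_empty((4, 4), 'float32', tvm.cpu(0))
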
dummy_cpu_arr.get_mutable()->dl_tensor.data = - DeviceAPI::Get(dummy_cpu_arr->ctx)->AllocDataSpace( - {kDLCPU, 0}, size, alignment, dummy_cpu_arr->dtype); - size_t elem_cnt = 1; - for (tvm_index_t i = 0; i < dummy_cpu_arr->ndim; ++i) { - elem_cnt *= static_cast(dummy_cpu_arr->shape[i]); - } - - // TODO(..): maybe we could have better solution for assigning values - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_real_distribution<> dis(1.0, 10.0); - // Use float representation could make us work well on float / int type too. - for (size_t i = 0; i < elem_cnt; ++i) { - if (dummy_cpu_arr->dtype.bits == 1) { - (reinterpret_cast( - dummy_cpu_arr.get_mutable()->dl_tensor.data))[i] = dis(gen); - } else if (dummy_cpu_arr->dtype.bits == 8) { - (reinterpret_cast( - dummy_cpu_arr.get_mutable()->dl_tensor.data))[i] = dis(gen); - } else if (dummy_cpu_arr->dtype.bits == 16) { - (reinterpret_cast( - dummy_cpu_arr.get_mutable()->dl_tensor.data))[i] = - __truncXfYf2__( - static_cast(dis(gen))); - } else if (dummy_cpu_arr->dtype.bits == 32) { - (reinterpret_cast( - dummy_cpu_arr.get_mutable()->dl_tensor.data))[i] = dis(gen); - } else if (dummy_cpu_arr->dtype.bits == 64) { - (reinterpret_cast( - dummy_cpu_arr.get_mutable()->dl_tensor.data))[i] = dis(gen); - } else { - LOG(FATAL) << "Doesn't support dtype code " << dtype.code - << " dtype bits " << dtype.bits; - } - } - ret.get_mutable()->dl_tensor.data = - DeviceAPI::Get(ret->ctx)->AllocDataSpace( - ret->ctx, size, alignment, ret->dtype); - CopyFromTo(&(dummy_cpu_arr.get_mutable()->dl_tensor), - &(ret.get_mutable()->dl_tensor)); - return ret; -} - NDArray NDArray::FromDLPack(DLManagedTensor* tensor) { NDArray::Container* data = new NDArray::Container(); // construct header @@ -314,9 +257,8 @@ int TVMArrayGetTypeIndex(TVMArrayHandle handle, unsigned* out_tindex) { API_END(); } -int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, - int dtype_bits, int dtype_lanes, int device_type, - int device_id, TVMArrayHandle* out) { +int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_bits, + int dtype_lanes, int device_type, int device_id, TVMArrayHandle* out) { API_BEGIN(); DLDataType dtype; dtype.code = static_cast(dtype_code); @@ -330,22 +272,6 @@ int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, API_END(); } -int TVMArrayAllocNonEmpty(const tvm_index_t* shape, int ndim, int dtype_code, - int dtype_bits, int dtype_lanes, int device_type, - int device_id, TVMArrayHandle* out) { - API_BEGIN(); - DLDataType dtype; - dtype.code = static_cast(dtype_code); - dtype.bits = static_cast(dtype_bits); - dtype.lanes = static_cast(dtype_lanes); - DLContext ctx; - ctx.device_type = static_cast(device_type); - ctx.device_id = device_id; - *out = NDArray::Internal::MoveToFFIHandle( - NDArray::NonEmpty(std::vector(shape, shape + ndim), dtype, ctx)); - API_END(); -} - int TVMArrayFree(TVMArrayHandle handle) { API_BEGIN(); NDArray::Internal::FFIDecRef(handle); diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index 71d3232ca4d5..6d9835e6231c 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -109,9 +109,6 @@ void OpenCLWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* } case kGcnArch: return; - default: { - LOG(WARNING) << "Attr not implemented."; - } } } diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index d58130d700f4..89f3e7c6c7f8 100644 --- 
a/src/runtime/rpc/rpc_module.cc
+++ b/src/runtime/rpc/rpc_module.cc
@@ -24,14 +24,9 @@
 #include
 #include
-#include
 #include
 #include
-#if defined(_M_X64) || defined(__x86_64__)
-#include
-#endif
-
 #include "rpc_endpoint.h"
 #include "rpc_session.h"
@@ -305,22 +300,6 @@ std::shared_ptr RPCModuleGetSession(Module mod) {
   return rmod->sess();
 }
-
-inline void CacheFlush(const char* p, unsigned int allocation_size) {
-// TODO(FrozenGene): Support ARM.
-#if (defined(_M_X64) || defined(__x86_64__))
-  size_t cache_line = 64;
-
-  if (p == nullptr || allocation_size <= 0) {
-    return;
-  }
-
-  for (size_t i = 0; i < allocation_size; i += cache_line) {
-    _mm_clflush(static_cast(&p[i]));
-  }
-
-#endif
-}
-
 PackedFunc WrapTimeEvaluator(PackedFunc pf, TVMContext ctx, int number, int repeat,
                              int min_repeat_ms) {
   CHECK(pf != nullptr);
@@ -334,21 +313,12 @@ PackedFunc WrapTimeEvaluator(PackedFunc pf, TVMContext ctx, int number, int repe
   auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue* rv) mutable {
     TVMRetValue temp;
     std::ostringstream os;
-    const char* cache_flush = std::getenv("TVM_AUTO_CACHE_FLUSH");
     // skip the first call, to activate lazy compilation components.
     pf.CallPacked(args, &temp);
     DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
     for (int i = 0; i < repeat; ++i) {
-      if (cache_flush && std::atoi(cache_flush) != 0) {
-        CHECK_EQ(number, 1);
-        // we want to keep the input data intact
-        for (int j = 1; j < args.size(); j++) {
-          CacheFlush(reinterpret_cast(args[j].operator DLTensor*()->data),
-                     GetDataSize(*(args[j].operator DLTensor*())));
-        }
-      }
       std::chrono::time_point tbegin, tend;
       double duration_ms = 0.0;
diff --git a/src/runtime/threading_backend.cc b/src/runtime/threading_backend.cc
index 3b1889aed8ef..e5520efe30a6 100644
--- a/src/runtime/threading_backend.cc
+++ b/src/runtime/threading_backend.cc
@@ -166,13 +166,8 @@ class ThreadGroup::Impl {
 #if defined(_M_X64) || defined(__x86_64__)
       big_count /= 2;  // ignore hyper-threading
 #endif
-      const char* bind_master_core_0 = getenv("TVM_BIND_MASTER_CORE_0");
-      if (bind_master_core_0 && atoi(bind_master_core_0) != 0) {
-        CPU_SET(sorted_order_[0], &cpuset);
-      } else {
-        for (int i = 0; i < big_count; ++i) {
-          CPU_SET(sorted_order_[i], &cpuset);
-        }
+      for (int i = 0; i < big_count; ++i) {
+        CPU_SET(sorted_order_[i], &cpuset);
       }
     }
 #if defined(__ANDROID__)
diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc
index 04a3f0b25bee..af72d3b1a1df 100644
--- a/src/te/schedule/schedule_dataflow_rewrite.cc
+++ b/src/te/schedule/schedule_dataflow_rewrite.cc
@@ -461,7 +461,7 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
   for (IterVar iv : root_iter_vars) {
     size_t idx = FindNodeRef(leaf_vars, iv);
     auto it = s->iter_var_attrs.find(iv);
-    // don't need to rebase path that are binded.
+    // don't need to rebase paths that are bound.
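Looking back at the measurement change earlier in this hunk: the deleted `TVM_AUTO_CACHE_FLUSH` path flushed every input tensor's cache lines between repeats so each run starts from a cold cache. A self-contained sketch of the idea (a sketch only: Python has no `clflush`, so eviction is simulated by streaming through a buffer larger than the last-level cache, whose size is assumed here):

# Sketch: cold-cache timing by evicting inputs between repeats.
# The real code walked each input with _mm_clflush; here eviction is
# approximated by overwriting a buffer larger than the LLC.
import time
import numpy as np

def time_cold(fn, inputs, repeat=10, llc_bytes=32 * 1024 * 1024):
    evict = np.empty(llc_bytes // 8, dtype=np.float64)
    costs = []
    for _ in range(repeat):
        evict[:] = 0.0  # push cached input data out of the cache hierarchy
        t0 = time.perf_counter()
        fn(*inputs)
        costs.append(time.perf_counter() - t0)
    return min(costs)
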
if (it != s->iter_var_attrs.end() && (*it).second->bind_thread.defined()) { continue; } @@ -614,74 +614,10 @@ void InjectInline(ScheduleNode* sch) { } } -void LegalizeInvalidAttach(ScheduleNode* sch) { - std::unordered_map replace_map; - - for (Stage stage : sch->stages) { - for (Stage s = stage; s.defined();) { - Stage spec = s.GetAttachSpec(); - if (spec->attach_type != kScope) { - break; - } - bool start_attach = false; - IterVar attach_ivar = spec->attach_ivar; - s = spec->attach_stage; - CHECK(attach_ivar.defined()); - CHECK(s.defined()); - - for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) { - IterVar iv = s->leaf_iter_vars[i - 1]; - if (!start_attach && iv.same_as(attach_ivar)) { - start_attach = true; - } - } - if (!start_attach) { - // If the attach_var is fused into another iter_var, update the - // attach_var to be the fused one - // Do this recursively. - IterVar new_attach_ivar = attach_ivar;; - bool updated = true; - while (updated) { - updated = false; - for (const auto& rel : s->relations) { - if (const FuseNode* r = rel.as()) { - if (new_attach_ivar.same_as(r->inner)) { - new_attach_ivar = r->fused; - updated = true; - } - } else if (const SplitNode* r = rel.as()) { - if (new_attach_ivar.same_as(r->parent)) { - new_attach_ivar = r->inner; - updated = true; - } - } - } - replace_map[attach_ivar] = new_attach_ivar; - } - } - } - } - - // remap the parent relation - for (Stage s : sch->stages) { - if (s->attach_type != kScope) continue; - if (replace_map.count(s->attach_ivar)) { - s->attach_ivar = replace_map.at(s->attach_ivar); - } - } - for (Stage s : sch->groups) { - if (s->attach_type != kScope) continue; - if (replace_map.count(s->attach_ivar)) { - s->attach_ivar = replace_map.at(s->attach_ivar); - } - } -} - Schedule Schedule::normalize() { Schedule sn = copy(); InjectInline(sn.operator->()); RebaseNonZeroMinLoop(sn); - LegalizeInvalidAttach(sn.operator->()); return sn; } diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index f6a8ad034aa5..1fbae0fd2dcd 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -33,22 +33,20 @@ namespace tvm { namespace tir { -class GPUCodeVerifier : public StmtExprVisitor { +class GPUCodeVerifier : public StmtVisitor { public: bool Verify(Stmt stmt, int64_t max_local_memory_per_block, int64_t max_shared_memory_per_block, int64_t max_threads_per_block, int64_t max_thread_x, int64_t max_thread_y, - int64_t max_thread_z, int64_t max_vector_bytes) { + int64_t max_thread_z) { max_local_memory_per_block_ = static_cast(max_local_memory_per_block); max_shared_memory_per_block_ = static_cast(max_shared_memory_per_block); max_threads_per_block_ = static_cast(max_threads_per_block); max_thread_x_ = static_cast(max_thread_x); max_thread_y_ = static_cast(max_thread_y); max_thread_z_ = static_cast(max_thread_z); - max_vector_bytes_ = static_cast(max_vector_bytes); Reset_(); - // TODO(jcf94): Add support of detecting CUDA Misaligned Address error this->VisitStmt(stmt); return valid_; @@ -64,10 +62,6 @@ class GPUCodeVerifier : public StmtExprVisitor { size_t size = static_cast(op->constant_allocation_size()); shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes(); } - - if (op->dtype.lanes() > 1) { - valid_ &= op->dtype.lanes() * op->dtype.bytes() <= static_cast(max_vector_bytes_); - } } void VisitStmt_(const AttrStmtNode* op) final { @@ -135,18 +129,6 @@ class GPUCodeVerifier : public StmtExprVisitor { } } - void VisitExpr_(const LoadNode* op) { - // 
Currently not able to check: - // if the index expression failed to be simplified to a Ramp - if (op->index->IsInstance()) { - if (op->dtype.lanes() > 1) { - valid_ &= op->dtype.lanes() * op->dtype.bytes() <= - static_cast(max_vector_bytes_); - } - } - ExprVisitor::VisitExpr_(op); - } - private: int nest_level_{0}; @@ -164,7 +146,6 @@ class GPUCodeVerifier : public StmtExprVisitor { size_t max_shared_memory_per_block_; size_t max_threads_per_block_; size_t max_thread_x_, max_thread_y_, max_thread_z_; - size_t max_vector_bytes_; bool valid_{true}; @@ -188,32 +169,27 @@ bool VerifyGPUCode(const PrimFunc& func, Map constraints) { int64_t max_thread_x = INT64_MAX; int64_t max_thread_y = INT64_MAX; int64_t max_thread_z = INT64_MAX; - int64_t max_vector_bytes = INT64_MAX; for (auto iter : constraints) { const IntImmNode* val = iter.second.as(); - if (iter.first == "max_local_memory_per_block") { + if (iter.first == "max_local_memory_per_block") max_local_memory_per_block = val->value; - } else if (iter.first == "max_shared_memory_per_block") { + else if (iter.first == "max_shared_memory_per_block") max_shared_memory_per_block = val->value; - } else if (iter.first == "max_threads_per_block") { + else if (iter.first == "max_threads_per_block") max_threads_per_block = val->value; - } else if (iter.first == "max_thread_x") { + else if (iter.first == "max_thread_x") max_thread_x = val->value; - } else if (iter.first == "max_thread_y") { + else if (iter.first == "max_thread_y") max_thread_y = val->value; - } else if (iter.first == "max_thread_z") { + else if (iter.first == "max_thread_z") max_thread_z = val->value; - } else if (iter.first == "max_vector_bytes") { - max_vector_bytes = val->value; - } else { + else LOG(FATAL) << "Invalid check item: " << iter.first; - } } return verifier.Verify(func->body, max_local_memory_per_block, max_shared_memory_per_block, - max_threads_per_block, max_thread_x, max_thread_y, max_thread_z, - max_vector_bytes); + max_threads_per_block, max_thread_x, max_thread_y, max_thread_z); } TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode); diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index 4f1078165f34..a15190665949 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -43,7 +43,6 @@ struct UnrollLoopConfigNode : public tvm::AttrsNode { int auto_max_depth; int auto_max_extent; int explicit_unroll; - int explicit_unroll_max_extent; TVM_DECLARE_ATTRS(UnrollLoopConfigNode, "tir.transform.UnrollLoopConfig") { TVM_ATTR_FIELD(auto_max_step) @@ -58,9 +57,6 @@ struct UnrollLoopConfigNode : public tvm::AttrsNode { TVM_ATTR_FIELD(explicit_unroll) .describe("Whether to explicitly unroll the loop instead of setting a pragma") .set_default(true); - TVM_ATTR_FIELD(explicit_unroll_max_extent) - .describe("The maximum extent of a loop that can be unrolled explicitly (-1 for infinite)") - .set_default(32); } }; @@ -75,12 +71,11 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig); class LoopUnroller : public StmtExprMutator { public: explicit LoopUnroller(int auto_max_step, int auto_max_depth, int auto_max_extent, - bool explicit_unroll, int explicit_unroll_max_extent) + bool explicit_unroll) : auto_max_step_(auto_max_step), auto_max_depth_(auto_max_depth), auto_max_extent_(auto_max_extent), - explicit_unroll_(explicit_unroll), - explicit_unroll_max_extent_(explicit_unroll_max_extent) {} + explicit_unroll_(explicit_unroll) {} Stmt VisitStmt_(const AttrStmtNode* op) final 
{
     if (op->attr_key == "pragma_auto_unroll_max_step") {
@@ -170,12 +165,6 @@ class LoopUnroller : public StmtExprMutator {
     // For loop must have a constant integer extent
     CHECK_NE(value, -1) << "loop doesn't have a constant integer extent";
     if (value == 0) return Evaluate(0);
-    if (explicit_unroll_max_extent_ > 0 && value > explicit_unroll_max_extent_ &&
-        explicit_unroll_) {
-      // Do not unroll loops that are too long
-      ForType for_type = op->for_type == ForType::Unrolled ? ForType::Serial : op->for_type;
-      return For(op->loop_var, op->min, op->extent, for_type, op->device_api, op->body);
-    }
     Stmt body = op->body;
     Map vmap;
     Array unrolled;
@@ -208,10 +197,7 @@ class LoopUnroller : public StmtExprMutator {
   // max extent of loop to auto unroll
   // this does not count the total steps, only the number of loops
   int auto_max_extent_;
-  // Whether to explicitly unroll the loop instead of setting a pragma
   bool explicit_unroll_;
-  // The maximum extent of a loop that can be unrolled explicitly (-1 means infinite)
-  int explicit_unroll_max_extent_;
   // Number of normal loops in scope
   int normal_loop_depth_{0};
   // number of unrolled cases in current scope.
@@ -224,7 +210,7 @@ class LoopUnroller : public StmtExprMutator {
 Stmt UnrollLoop(Stmt stmt, UnrollLoopConfig cfg) {
   Stmt ret = LoopUnroller(cfg->auto_max_step, cfg->auto_max_depth, cfg->auto_max_extent,
-                          cfg->explicit_unroll, cfg->explicit_unroll_max_extent)(stmt);
+                          cfg->explicit_unroll)(stmt);
   if (!ret.same_as(stmt)) {
     return ConvertSSA(ret);
   } else {
diff --git a/tests/python/unittest/test_ansor_feature.py b/tests/python/unittest/test_ansor_feature.py
deleted file mode 100644
index 705556c65edf..000000000000
--- a/tests/python/unittest/test_ansor_feature.py
+++ /dev/null
@@ -1,150 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
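For reference, the `tir.UnrollLoop` options that survive this patch are still driven through `PassContext`, the same way the deleted `test_unroll_explicitly_max_extent` (further below) exercised the removed option. A minimal sketch using only the options kept by this change, mirroring the construction used in that test:

# Sketch: configuring the remaining UnrollLoop options via PassContext.
import tvm
from tvm import te

A = te.placeholder((8,), name='A')
B = te.compute((8,), lambda i: A[i] + 1.0, name='B')
s = te.create_schedule(B.op)
s[B].unroll(B.op.axis[0])
s = s.normalize()
dom_map = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, dom_map)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt))

with tvm.transform.PassContext(config={
        "tir.UnrollLoop": {"auto_max_step": 16, "explicit_unroll": True}}):
    print(tvm.tir.transform.UnrollLoop()(mod)["main"].body)
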
- -"""Test feature extraction""" - -import math -import tempfile - -import tvm -from tvm import te, ansor - -from test_ansor_common import matmul_ansor_test - - -def fequal(a, b): - return math.fabs(a - b) < 1e-6 - - -def test_cpu_matmul(): - dag = ansor.ComputeDAG(matmul_ansor_test(512, 512, 512)) - s = dag.get_init_state() - C = s.stage_ops[2] - - i, j, k = s[C].iters - io, ii = s.split(C, i, [16]) - jo, ji = s.split(C, j, [8]) - s.reorder(C, [io, jo, k, ji, ii]) - s.vectorize(C, ji) - s.parallel(C, io) - s.parallel(C, jo) - s.unroll(C, k) - - target = tvm.target.create('llvm') - task = ansor.SearchTask(dag, "test", target) - names = ansor.feature.get_per_stmt_feature_names() - fea = ansor.feature.get_per_stmt_features_from_states([s], task)[0] - - stage_0 = fea[0] - assert len(stage_0) == len(names), "%d vs %d" % (len(stage_0), len(names)) - fea_dict = {} - for name, value in zip(names, stage_0): - fea_dict[name] = value - - for name in ["B0", "B1", "B2"]: - if fequal(fea_dict[name + ".acc_type.kReadWrite"], 1.0): - c_name = name - if fequal(fea_dict[name + ".acc_type.kRead"], 1.0): - if fequal(fea_dict[name + ".stride"], 0.0): - b_name = name - else: - a_name = name - - assert fequal(fea_dict[c_name + ".bytes"], math.log2(512 ** 3 * 4 + 1)) - assert fequal(fea_dict[b_name + ".unique_bytes"], math.log2(512 ** 2 * 4 + 1)) - assert fequal(fea_dict[c_name + ".reuse_dis_iter"], math.log2(8 * 16 + 1)) - assert fequal(fea_dict[c_name + ".reuse_dis_bytes"], math.log2((8 * 16 + 8 + 16) * 4 + 1)) - assert fequal(fea_dict[c_name + ".reuse_ct"], math.log2(512 + 1)) - - assert fequal(fea_dict["unroll_num"], math.log2(1 + 1)) - # assert fequal(fea_dict["unroll_type.kPosInnerReduce"], 1.0) - assert fequal(fea_dict["vec_num"], math.log2(1 + 1)) - assert fequal(fea_dict["parallel_num"], math.log2(2 + 1)) - assert fequal(fea_dict["parallel_prod"], math.log2((512 * 512 / 16 / 8) + 1)) - - -def test_cpu_fusion(): - def fusion_test(N, M): - A = te.placeholder((N, M), name='A') - B = te.compute((N, M), lambda i, j: A[i][j], name='B') - C = te.compute((N, M), lambda i, j: B[i][j], name='C') - return [A, B, C] - - dag = ansor.ComputeDAG(fusion_test(64, 32)) - s = dag.get_init_state() - s.compute_at(1, 2, s.stages[2].iters[1]) - - target = tvm.target.create('llvm') - task = ansor.SearchTask(dag, "test", target) - names = ansor.feature.get_per_stmt_feature_names() - fea = ansor.feature.get_per_stmt_features_from_states([s], task)[0] - - found = False - for stage_fea in fea: - for i, (name, value) in enumerate(zip(names, stage_fea)): - if 'reuse_type.kSerialMultipleReadWrite' in name and value > 0.5: - assert fequal(stage_fea[i + 2], 1.0) - assert fequal(stage_fea[i + 3], math.log2(16 + 1)) - found = True - assert found - - -def test_gpu_feature(): - ctx = tvm.context("cuda", 0) - if not ctx.exist: - return - - json_records = "\n".join(( - """{"i": [["[\\"matmul_ansor_test\\", 512, 512, 512]", "cuda"], [[], [["CHW", 2, "local"], ["SP", 2, 0, 512, [1, 16, 32, 1], 1], ["SP", 2, 5, 512, [4, 1, 1, 16], 1], ["SP", 2, 10, 512, [1, 2], 1], ["RE", 2, [0, 5, 1, 6, 2, 7, 10, 11, 3, 8, 12, 4, 9]], ["FSP", 3, 0, 1, 3], ["FSP", 3, 4, 2, 3], ["RE", 3, [0, 4, 1, 5, 2, 6, 3, 7]], ["FU", 2, [0, 1]], ["FU", 3, [0, 1]], ["FU", 2, [1, 2]], ["FU", 3, [1, 2]], ["FU", 2, [2, 3]], ["FU", 3, [2, 3]], ["CA", 2, 3, 2], ["CHR", 1, "shared", [2]], ["CA", 2, 3, 3], ["FU", 2, [0, 1]], ["FFSP", 2, 0, [1, 2], 1, 1], ["AN", 2, 1, 6], ["CHR", 0, "shared", [3]], ["CA", 1, 4, 3], ["FU", 1, [0, 1]], ["FFSP", 1, 0, [1, 2], 1, 1], ["AN", 1, 1, 6], 
["AN", 5, 0, 5], ["AN", 5, 1, 4], ["AN", 5, 2, 6], ["PR", 4, 0, "auto_unroll_max_step$1024"]]]], "r": [[0.00536798], 0, 2.49277, 1585564852], "v": "v0.1"}""", - )) - - # load states - with tempfile.NamedTemporaryFile(mode='w') as f: - f.write(json_records) - f.flush() - inputs, results = ansor.LogReader(f.name).read_lines() - - inp = inputs[0] - dag = ansor.workload_key_to_dag(inp.task.workload_key) - task = ansor.SearchTask(dag, inp.task.workload_key, inp.task.target, None, ansor.HardwareParams(100000, 16, 64, 4, 64)) - - state = ansor.serialization.get_states_from_measure_inputs(inputs, task)[0] - state = dag.infer_bound_from_state(state) - fea = ansor.feature.get_per_stmt_features_from_states([state], task)[0] - names = ansor.feature.get_per_stmt_feature_names() - - # build feature dict - fea_dicts = [] - for i in range(len(fea)): - tmp_dict = {} - for j in range(len(names)): - tmp_dict[names[j]] = fea[i][j] - fea_dicts.append(tmp_dict) - - # check values - assert fequal(fea_dicts[0]['blockIdx_x_len'], math.log2(8 + 1)) - assert fequal(fea_dicts[0]['vthread_len'], math.log2(4 + 1)) - assert fequal(fea_dicts[1]['threadIdx_x_len'], math.log2(16 + 1)) - assert fequal(fea_dicts[0]['threadIdx_y_len'], math.log2(1 + 1)) - assert fequal(fea_dicts[2]['blockIdx_z_len'], math.log2(1 + 1)) - assert fequal(fea_dicts[0]['is_gpu'], 1.0) - - -if __name__ == "__main__": - test_cpu_matmul() - test_cpu_fusion() - test_gpu_feature() diff --git a/tests/python/unittest/test_ansor_relay_integration.py b/tests/python/unittest/test_ansor_relay_integration.py deleted file mode 100644 index 1ad507e2f371..000000000000 --- a/tests/python/unittest/test_ansor_relay_integration.py +++ /dev/null @@ -1,114 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-""" Test Relay Integration """ - -import tempfile -import numpy as np - -import tvm -from tvm import ansor, relay -import tvm.contrib.graph_runtime as runtime -from tvm.relay.testing import dqn - -def test_tune_dense_graph(): - def dense_graph(N, dtype="float32"): - ori_data = relay.var("data", shape=(N, N), dtype=dtype) - weight = relay.var("weight", shape=(N, N), dtype=dtype) - data = relay.multiply(ori_data, relay.const(2, dtype=dtype)) - dense = relay.nn.dense(data, weight, out_dtype=dtype) - dense = relay.add(dense, weight) - dense = relay.nn.dense(dense, weight, out_dtype=dtype) - return ori_data, weight, dense - - N = 128 - data, weight, dense = dense_graph(N) - mod = relay.Function([data, weight], dense) - mod = tvm.IRModule.from_expr(mod) - - ctx = tvm.context("llvm") - target = tvm.target.create("llvm") - d = tvm.nd.array(np.random.uniform(size=(N, N)).astype(data.type_annotation.dtype), ctx) - w = tvm.nd.array(np.random.uniform(size=(N, N)).astype(weight.type_annotation.dtype), ctx) - wkl_keys, wkl_weights = ansor.extract_from_program(mod, {}, target=target) - - assert len(wkl_keys) == 2 - assert len(wkl_weights) == 2 - - tasks = [] - for wkl_key in wkl_keys: - dag = ansor.workload_key_to_dag(wkl_key) - tasks.append(ansor.SearchTask(dag, wkl_key, target)) - - tuner = ansor.SimpleTaskScheduler(tasks) - measure_ctx = ansor.LocalRPCMeasureContext() - with tempfile.NamedTemporaryFile() as fp: - tuner.tune(ansor.TuneOption(n_trials=2, runner=measure_ctx.runner, - measure_callbacks=[ansor.LogToFile(fp.name)])) - with ansor.apply_history_best(fp.name): - with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): - graph, lib, opt_params = relay.build_module.build( - mod, target=target) - - m = runtime.create(graph, lib, ctx) - m.set_input('data', d) - m.set_input('weight', w) - m.run() - res = m.get_output(0) - - del measure_ctx - - d = d.asnumpy() - d = d * 2 - w = w.asnumpy() - d = np.dot(d, np.transpose(w)) - d = d + w - d = np.dot(d, np.transpose(w)) - - tvm.testing.assert_allclose(res.asnumpy(), d, rtol=1e-5) - - -def test_tune_dqn(): - mod, params = dqn.get_workload(1, image_shape=(84, 84, 4), layout='NHWC') - target = tvm.target.create('llvm') - - wkl_keys, wkl_weights = ansor.extract_from_program(mod, params, target) - - tasks = [] - for wkl_key in wkl_keys: - dag = ansor.workload_key_to_dag(wkl_key) - tasks.append(ansor.SearchTask(dag, wkl_key, target)) - - assert len(tasks) == 5 - - tuner = ansor.SimpleTaskScheduler(tasks) - measure_ctx = ansor.LocalRPCMeasureContext() - with tempfile.NamedTemporaryFile() as fp: - tuner.tune(ansor.TuneOption(n_trials=len(tasks), runner=measure_ctx.runner, - measure_callbacks=[ansor.LogToFile('tmp.json')]), - search_policy='sketch.random') - with ansor.apply_history_best('tmp.json'): - ansor.prepare_layout_rewrite(mod, params, target) - with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): - graph, lib, opt_params = relay.build_module.build(mod, target=target) - ansor.finish_layout_rewrite() - - del measure_ctx - -if __name__ == "__main__": - test_tune_dense_graph() - test_tune_dqn() - diff --git a/tests/python/unittest/test_ansor_task_scheduler.py b/tests/python/unittest/test_ansor_task_scheduler.py deleted file mode 100644 index 53cf2059c1f3..000000000000 --- a/tests/python/unittest/test_ansor_task_scheduler.py +++ /dev/null @@ -1,52 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Test the task scheduler """ - -import threading - -import tvm -from tvm import ansor - -from test_ansor_common import matmul_ansor_test - -def test_task_scheduler_basic(): - N = 128 - A, B, C = matmul_ansor_test(N, N, N) - dag = ansor.ComputeDAG([A, B, C]) - tgt = tvm.target.create("llvm") - task1 = ansor.SearchTask(dag, "test", tgt) - task2 = ansor.SearchTask(dag, "test", tgt) - - def basic_test_func(task1, task2): - def objective(costs): - return sum(costs) - - task_scheduler = ansor.SimpleTaskScheduler([task1, task2], objective) - tune_option = ansor.TuneOption(n_trials=3, runner='local') - task_scheduler.tune(tune_option) - - # Ansor search process with local runner has some modification on thread - # binding, wrap this to a subprocess to eliminate the impacts to other tests - t = threading.Thread(target=basic_test_func, - kwargs={'task1': task1, 'task2': task2}) - t.start() - t.join() - - -if __name__ == "__main__": - test_task_scheduler_basic() diff --git a/tests/python/unittest/test_tir_transform_unroll_loop.py b/tests/python/unittest/test_tir_transform_unroll_loop.py index 12c686634548..68639940bb05 100644 --- a/tests/python/unittest/test_tir_transform_unroll_loop.py +++ b/tests/python/unittest/test_tir_transform_unroll_loop.py @@ -110,31 +110,7 @@ def test_unroll_single_count_loops(): ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body assert ret == stmt -def test_unroll_explicitly_max_extent(): - n = 64 - A = te.placeholder((n,), name='A') - B = te.compute((n,), lambda *i: A(*i), name='B') - s = te.create_schedule(B.op) - s = s.normalize() - dom_map = tvm.te.schedule.InferBound(s) - stmt = tvm.te.schedule.ScheduleOps(s, dom_map) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) - - with tvm.transform.PassContext(config={ - "tir.UnrollLoop": {"explicit_unroll_max_extent": n-1} - }): - ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body - assert tvm.ir.structural_equal(ret, stmt) - - with tvm.transform.PassContext(config={ - "tir.UnrollLoop": {"explicit_unroll_max_extent": n} - }): - ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body - assert not tvm.ir.structural_equal(ret, stmt) - - if __name__ == "__main__": test_unroll_loop() test_unroll_fake_loop() test_unroll_single_count_loops() - test_unroll_explicitly_max_extent() diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 7dd782f5b622..e0e455667889 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -1295,75 +1295,6 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout, name, tag); } -/*! 
- * \brief utility function for kernel_layout_transform - */ -inline void parse_kernel_layout(const String& layout, - Array* shape, - std::vector* axes) { - int32_t factor = 0; - std::string axis = ""; - for (char c : std::string(layout)) { - if (c >= 'A' && c <= 'z') { - axis += c; - if (factor != 0) { - shape->push_back(factor); - factor = 0; - } - } else if (c >= '0' && c <= '9') { - factor = factor * 10 + c - '0'; - if (!axis.empty()) { - axes->push_back(axis); - axis = ""; - } - } else { - LOG(FATAL) << "Invalid layout " << layout; - } - } - if (!axis.empty()) { - axes->push_back(axis); - } -} - -/*! - * \brief Transform the kernel layout according to \p src_layout and \p dst_layout - * \param src the source input. - * \param src_layout the source layout. - * \param dst_layout the destination layout. - * \param name output tensor name. - * \param tag output tensor tag. - * \return A tensor with shape in \p dst_layout - */ -inline Tensor kernel_layout_transform(const Tensor& src, - const String& src_layout, - const String& dst_layout, - const String name = "T_kernel_layout_trans", - const String tag = kInjective) { - Array src_shape; - std::vector src_axes; - Array dst_shape; - std::vector dst_axes; - - parse_kernel_layout(src_layout, &src_shape, &src_axes); - parse_kernel_layout(dst_layout, &dst_shape, &dst_axes); - return compute( - dst_shape, [&](const Array& dst_indices) { - Array dst_indices_expr(dst_indices.begin(), dst_indices.end()); - Array src_indices; - for (const std::string& src_axis : src_axes) { - PrimExpr src_index = 0; - CHECK_EQ(dst_indices_expr.size(), dst_axes.size()); - for (size_t i = 0; i < dst_axes.size(); ++i) { - if (dst_axes[i] == src_axis) { - src_index = src_index * dst_shape[i] + dst_indices_expr[i]; - } - } - src_indices.push_back(src_index); - } - return src(src_indices); - }, name, tag); -} - /*! * \brief Get the shape of input tensor. * \param src the input tensor. 
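The `parse_kernel_layout` helper removed above treats a layout string as a sequence of `<factor><axis-name>` pairs: a digit run accumulates into the factor attached to the following axis, and a letter run extends the current axis name. A direct Python port (the example layout string is made up for illustration):

# Python port of the removed topi parse_kernel_layout helper.
def parse_kernel_layout(layout):
    shape, axes = [], []
    factor, axis = 0, ""
    for c in layout:
        if c.isalpha():
            axis += c
            if factor != 0:       # a letter closes the pending numeric factor
                shape.append(factor)
                factor = 0
        elif c.isdigit():
            factor = factor * 10 + int(c)
            if axis:              # a digit closes the pending axis name
                axes.append(axis)
                axis = ""
        else:
            raise ValueError("Invalid layout " + layout)
    if axis:
        axes.append(axis)
    return shape, axes

print(parse_kernel_layout("16i4o"))  # ([16, 4], ['i', 'o'])
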
diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index 6800129c12aa..4c7941b49692 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -20,7 +20,7 @@ from __future__ import absolute_import as _abs from collections import namedtuple import tvm -from tvm import te, ansor +from tvm import te from .pad import pad from .util import get_pad_tuple @@ -342,37 +342,7 @@ def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'): dilation_h, dilation_w = dilation batch, in_height, in_width, in_channel = Input.shape - if ansor.GLOBAL_SCOPE.topi_in_compute_rewrite_mode: - # infer shape for the rewritten layout - if len(Filter.shape) >= 10: - # For cpu tile structure SSRSRS - base = len(Filter.shape) - 10 - kernel_h = Filter.shape[2 + base] * Filter.shape[6 + base] - kernel_w = Filter.shape[3 + base] * Filter.shape[7 + base] - channel = Filter.shape[4 + base] * Filter.shape[8 + base] - num_filter = Filter.shape[5 + base] * Filter.shape[9 + base] - for i in range(base + 2): - num_filter *= Filter.shape[i] - elif len(Filter.shape) == 6: - # For cpu tile structure SRS - num_filter = Filter.shape[0] * Filter.shape[1] * Filter.shape[5] - kernel_h = Filter.shape[2] - kernel_w = Filter.shape[3] - channel = Filter.shape[4] - elif len(Filter.shape) == 5: - # For cpu tile structure SRS - num_filter = Filter.shape[0] * Filter.shape[4] - kernel_h = Filter.shape[1] - kernel_w = Filter.shape[2] - channel = Filter.shape[3] - elif len(Filter.shape) == 4: - num_filter, kernel_h, kernel_w, channel = Filter.shape - else: - raise ValueError("Don't know how to infer layout for filter shape: %s. " \ - "You can add a new branch for it to fix this." % str(Filter)) - else: - kernel_h, kernel_w, channel, num_filter = Filter.shape - + kernel_h, kernel_w, channel, num_filter = Filter.shape # compute the output shape dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 @@ -392,9 +362,8 @@ def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'): lambda nn, yy, xx, ff: te.sum( PaddedInput[nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rc].astype(out_dtype) * - Filter[ry, rx, rc, ff].astype(out_dtype) - , axis=[ry, rx, rc]), - name="Conv2dOutput", tag="conv2d_nhwc", attrs={"layout_free_placeholders": [Filter]}) + Filter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc]), + name="Conv2dOutput", tag="conv2d_nhwc") return Output diff --git a/tutorials/ansor/README.txt b/tutorials/ansor/README.txt deleted file mode 100644 index 85b6ba401dae..000000000000 --- a/tutorials/ansor/README.txt +++ /dev/null @@ -1,4 +0,0 @@ -.. _tutorial-ansor-auto-schedule: - -Ansor: Template Free Auto Scheduling ------------------------------------- diff --git a/tutorials/ansor/tune_conv2d_cuda.py b/tutorials/ansor/tune_conv2d_cuda.py deleted file mode 100644 index 03f1b24a768e..000000000000 --- a/tutorials/ansor/tune_conv2d_cuda.py +++ /dev/null @@ -1,179 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-Auto-scheduling High Performance Convolution on NVIDIA GPUs
-===========================================================
-**Author**: `Lianmin Zheng `_, \
-            `Chengfan Jia `_, \
-            `Minmin Sun `_, \
-            `Zhao Wu `_
-
-This is a tutorial on searching for high-performance schedules for NVIDIA GPUs with the
-Ansor auto-scheduler. By running Ansor on this workload, we can outperform the
-vendor-provided library cuDNN in many cases.
-"""
-
-######################################################################
-# Install dependencies
-# --------------------
-# To use the ansor package in TVM, we need to install some extra dependencies.
-# (change "3" to "2" if you use python2):
-#
-# .. code-block:: bash
-#
-#   pip3 install --user psutil xgboost tornado
-#
-# To make TVM run faster in tuning, it is recommended to use cython
-# as the FFI of TVM. In the root directory of TVM, execute
-#
-# .. code-block:: bash
-#
-#   pip3 install --user cython
-#   sudo make cython3
-#
-# Now return to python code. Import packages.
-
-import random
-import sys
-
-import numpy as np
-import tvm
-import topi
-from topi.testing import conv2d_nchw_python
-from tvm import te
-
-# the module is called `ansor`
-from tvm import ansor
-
-######################################################################
-# Step 1:  Define the search task
-# -------------------------------
-# There are plenty of useful schedule primitives in tvm. You can also find
-# some tutorials that describe them in more detail, such as
-# (1). :ref:`opt-conv-gpu`
-# (2). `Optimizing DepthwiseConv on NVIDIA GPU `_
-#
-# It is usually a hard job to get a high-performance schedule for a
-# specific workload. Even writing an AutoTVM tunable template requires the user to have
-# expertise in how each schedule primitive works as well as how the primitives finally
-# map to the hardware architecture.
-#
-# However, with Ansor this becomes quite simple. First, define the target workload.
-# Either the :code:`tvm.te` API or the topi op API can be used.
-#
-# We can use the returned :code:`Tensors` to create a ComputeDAG just like what we do
-# in :ref:`ansor-simple-subgraph`, but using the workload registry, as shown below, is
-# the recommended way.
-
-# Use a function decorator to register this workload
-@ansor.register_workload_func
-def conv2d_nchw(N, H, W, CO, CI, KH, KW, stride, padding):
-    data = te.placeholder((N, CI, H, W), name='data')
-    kernel = te.placeholder((CO, CI, KH, KW), name='kernel')
-    conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype='float32')
-
-    return [data, kernel, conv]
-
-######################################################################
-# Step 2:  Search through the schedule space
-# ------------------------------------------
-# We pick the last layer of ResNet as the test case.
-# Since the search space is very large, :code:`XGBModel` is the most suitable
-# cost model for this case. Here we only run 20 trials for demonstration.
-# In practice, running about 1000 trials can usually find some good kernels
-# for this workload.
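As a side note on the ComputeDAG alternative mentioned above, a minimal sketch of building the task directly from the returned tensors could look like the following. This assumes the decorated function still returns its tensor list when called directly; the task name "conv2d_direct" is illustrative:

    # Hypothetical sketch: bypass the workload registry and build the DAG
    # directly from the tensors, as in the simple-subgraph tutorial.
    data, kernel, conv = conv2d_nchw(1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1))
    dag = ansor.ComputeDAG([data, kernel, conv])
    task = ansor.SearchTask(dag, "conv2d_direct", tvm.target.cuda())

The registry path used below is preferred because the workload key is serialized into the tuning logs and can be mapped back to the ComputeDAG later (see ansor.workload_key_to_dag further down).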
-
-tgt = tvm.target.cuda()
-
-# The last layer in resnet
-N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
-# Generate workload key with the ansor API
-wkl_key = ansor.make_workload_key_func(conv2d_nchw, (N, H, W, CO, CI, KH, KW, strides, padding))
-# Generate ComputeDAG using the workload key
-dag = ansor.workload_key_to_dag(wkl_key)
-task = ansor.SearchTask(dag, wkl_key, target=tgt)
-
-log_file = "conv2d_nchw.json"
-seed = 0
-random.seed(seed)
-cost_model = ansor.XGBModel(seed=seed)
-search_policy = ansor.SketchSearchPolicy(cost_model, seed=seed)
-
-#########################################################################
-# The :code:`ansor.LocalRPCMeasureContext` is used to create an RPC runner environment.
-#
-# We use the local GPU and measure each schedule multiple times to reduce variance.
-# The timeout for each run is set to 4 seconds.
-#
-# During the search process, we may generate several invalid schedules, and they
-# will be filtered out. It is fine to see "Encountered errors during feature extraction."
-# in the tuning logs.
-# The :code:`ansor.LogToFile` callback will log the tuning results into a
-# log file, which can be used to get the best config later.
-# The :code:`ansor.PreloadMeasuredStates` callback will load measured states
-# from the history log before the schedule search; adding this callback ensures
-# that the same schedule is never measured multiple times.
-
-measure_ctx = ansor.LocalRPCMeasureContext(repeat=3, min_repeat_ms=100, timeout=4)
-tune_option = ansor.TuneOption(n_trials=20,
-                               runner=measure_ctx.runner,
-                               measure_callbacks=[ansor.LogToFile(log_file)],
-                               pre_search_callbacks=[ansor.PreloadMeasuredStates(log_file)])
-s, arg_bufs = ansor.auto_schedule(task, search_policy=search_policy, tune_option=tune_option)
-
-print("==== Get Lowered Stmt ====")
-print(tvm.lower(s, arg_bufs, simple_mode=True))
-
-# Release the RPC runner environment
-del measure_ctx
-
-#########################################################################
-# From the lowered result shown above, we can see that Ansor has applied
-# techniques such as `Shared Memory Cooperative Fetching`, `Kernel Fusion`,
-# `Axis unroll`, `Axis Vectorize` and so on. There is no need for users to care
-# about these details; Ansor handles them automatically.
-#
-# Finally, we can directly use the returned result to get the generated schedule.
-# In the following, we show how to inspect the best config from the
-# log file, check correctness, and measure the running time.
-
-# Get history best from log file
-inp, res = ansor.best_measure_pair_in_file(log_file)
-# Get the task ComputeDAG from log result
-dag = ansor.workload_key_to_dag(inp.task.workload_key)
-# Apply log result to TVM schedule
-s, arg_bufs = dag.apply_steps_from_state(inp.state)
-func = tvm.build(s, arg_bufs, target=tgt)
-
-# check correctness
-a_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32)
-w_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32)
-c_np = conv2d_nchw_python(a_np, w_np, strides, padding)
-
-ctx = tvm.gpu()
-a_tvm = tvm.nd.array(a_np, ctx=ctx)
-w_tvm = tvm.nd.array(w_np, ctx=ctx)
-c_tvm = tvm.nd.empty(c_np.shape, ctx=ctx)
-func(a_tvm, w_tvm, c_tvm)
-
-tvm.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2)
-
-# Evaluate the running time. Here we choose a large repeat number (400) to reduce the
-# noise and the overhead of kernel launch. You can also use nvprof to validate the result.
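As a side note on the measurement step that follows: the evaluator defined immediately below returns the mean latency in seconds, which can also be converted into an achieved throughput. A minimal sketch, reusing the tutorial's variables and meant to run after the evaluator call below (the 2 * output-elements * reduction-size FLOP count is the standard convolution estimate):

    # Approximate FLOP count of this conv2d. With stride 1, padding (1, 1) and a
    # 3x3 kernel, the output keeps the 7x7 spatial size, so there are
    # N * CO * H * W output elements, each needing CI * KH * KW multiply-adds.
    flops = 2 * N * H * W * CO * CI * KH * KW
    latency = evaluator(a_tvm, w_tvm, c_tvm).mean
    print('Achieved throughput: %.2f GFLOP/s' % (flops / latency / 1e9))

For the nvprof validation mentioned above, profiling the whole script (e.g. `nvprof python3 tune_conv2d_cuda.py`) reports per-kernel timings that can be cross-checked against the evaluator output.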
-evaluator = func.time_evaluator(func.entry_name, ctx, number=400)
-print('Time cost of this operator: %f s' % evaluator(a_tvm, w_tvm, c_tvm).mean)
-
diff --git a/tutorials/ansor/tune_simple_subgraph.py b/tutorials/ansor/tune_simple_subgraph.py
deleted file mode 100644
index 00bef82cf855..000000000000
--- a/tutorials/ansor/tune_simple_subgraph.py
+++ /dev/null
@@ -1,193 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-.. _ansor-simple-subgraph:
-
-Writing a Compute Expression and Using the Ansor Auto-scheduler
-===============================================================
-**Author**: `Lianmin Zheng `_, \
-            `Chengfan Jia `_, \
-            `Minmin Sun `_, \
-            `Zhao Wu `_
-
-This is an introductory tutorial to the auto-scheduler module in TVM.
-
-There are two steps in auto-scheduling.
-The first step is defining the target task.
-The second step is running a search algorithm to automatically explore the schedule space.
-In this tutorial, you can learn how to perform these two steps in TVM.
-The whole workflow is illustrated by a matrix multiplication with bias add example.
-"""
-
-######################################################################
-# Install dependencies
-# --------------------
-# To use the Ansor package in TVM, we need to install some extra dependencies.
-# Installing xgboost can be skipped here, since this tutorial uses the
-# RandomModel, which does not need XGBoost
-# (change "3" to "2" if you use python2):
-#
-# .. code-block:: bash
-#
-#   pip3 install --user psutil xgboost
-#
-# To make TVM run faster in tuning, it is recommended to use cython
-# as the FFI of TVM. In the root directory of TVM, execute
-# (change "3" to "2" if you use python2):
-#
-# .. code-block:: bash
-#
-#   pip3 install --user cython
-#   sudo make cython3
-#
-# Now return to python code. Import packages.
-
-import random
-import sys
-
-import numpy as np
-import tvm
-from tvm import te
-
-# the module is called `ansor`
-from tvm import ansor
-
-######################################################################
-# Step 1:  Define the target compute subgraph
-# -------------------------------------------
-# In this section, we will write a deterministic TVM compute expression
-# for a compute subgraph.
-#
-# .. note:: Compared to :ref:`tutorials-autotvm-sec`
-#
-#  In Ansor, users do not need to provide a schedule template; the only input
-#  is the compute expression, written with the :code:`tvm.te` API or the topi op API.
-#
-# Here is how we implement a matrix multiplication subgraph in TVM.
-
-# Matmul with bias add
-def matmul_add(N, L, M, dtype):
-    A = te.placeholder((N, L), name='A', dtype=dtype)
-    B = te.placeholder((L, M), name='B', dtype=dtype)
-    C = te.placeholder((N, M), name='C', dtype=dtype)
-
-    k = te.reduce_axis((0, L), name='k')
-    mul = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
-                     name='Mul')
-    D = te.compute((N, M), lambda i, j: C[i, j] + mul[i, j], name='D')
-
-    return [A, B, C, D]
-
-######################################################################
-# Step 2:  Search through the schedule space
-# ------------------------------------------
-# In step 1, we built the compute subgraph.
-# The next step is to pick a cost model as well as a search policy and explore
-# the possible schedules.
-#
-# Auto-scheduler in TVM
-# ^^^^^^^^^^^^^^^^^^^^^
-# The job of the Ansor auto-scheduler can be described by the following pseudo code
-#
-# .. code-block:: c
-#
-#   ct = 0
-#   while ct < max_number_of_trials:
-#       auto generate a batch of schedules
-#       measure this batch of schedules on real hardware and get results
-#       ct += batch_size
-#
-# When proposing the next batch of schedules, Ansor can take different cost models to
-# guide the schedule generation process.
-#
-# * :code:`RandomModel`: Generate and pick new schedules randomly
-# * :code:`XGBModel`: Use an XGBoost model to estimate the performance of potential schedules, and try to pick schedules with better performance in each step
-#
-# XGBModel can explore more efficiently and find better schedules.

-################################################################
-# Begin tuning
-# ^^^^^^^^^^^^
-# Here we continue our matrix multiplication example.
-#
-# The :code:`ansor.ComputeDAG` takes the Tensor list as input and generates
-# a DAG structure. During this process, :code:`ansor.ComputeDAG` performs
-# some analyses on the target subgraph, and the results are used by the
-# search policy later.
-#
-# Then we create the :code:`tvm.target` and a tuning task.
-
-N, L, M = 128, 128, 128
-A, B, C, D = matmul_add(N, L, M, 'float32')
-dag = ansor.ComputeDAG([A, B, C, D])
-
-print(dag)
-print(dag.access_analyzer)
-
-tgt = tvm.target.create("llvm")
-task = ansor.SearchTask(dag, "test", tgt)
-
-################################################################
-# Next, we choose the random cost model and create a default search policy:
-# :code:`ansor.SketchSearchPolicy`.
-#
-# We only run 5 trials in this tutorial for demonstration. In practice,
-# you can do more trials according to your time budget.
-# The :code:`ansor.LogToFile` callback will log the tuning results into a
-# log file, which can be used to get the best config later.
-# The :code:`ansor.PreloadMeasuredStates` callback will load measured states
-# from the history log before the schedule search; adding this callback ensures
-# that the same schedule is never measured multiple times.
-
-log_file = "matmul_add.json"
-
-seed = 0
-random.seed(seed)
-cost_model = ansor.RandomModel()
-search_policy = ansor.SketchSearchPolicy(cost_model, seed=seed)
-
-tune_option = ansor.TuneOption(n_trials=5,
-                               measure_callbacks=[ansor.LogToFile(log_file)],
-                               pre_search_callbacks=[ansor.PreloadMeasuredStates(log_file)])
-
-################################################################
-# Then just call :code:`ansor.auto_schedule` and Ansor will try to find a
-# high-performance schedule for the target subgraph automatically.
-# -# The returned result will be a :code:`te.schedule` and a list of :code:`te.Tensor`, -# which can be used as the input of :code:`tvm.lower` or :code:`tvm.build`. - -s, arg_bufs = ansor.auto_schedule(task, search_policy=search_policy, - tune_option=tune_option) - -print("==== Get Lowered Stmt ====") -print(tvm.lower(s, arg_bufs, simple_mode=True)) - -######################################################################### -# Check the correctness to make sure we generate a right schedule. - -func = tvm.build(s, arg_bufs) - -# check correctness -a_np = np.random.uniform(size=(N, L)).astype(np.float32) -b_np = np.random.uniform(size=(L, M)).astype(np.float32) -c_np = np.random.uniform(size=(N, M)).astype(np.float32) -d_np = a_np.dot(b_np) + c_np - -d_tvm = tvm.nd.empty(d_np.shape) -func(tvm.nd.array(a_np), tvm.nd.array(b_np), tvm.nd.array(c_np), d_tvm) - -tvm.testing.assert_allclose(d_np, d_tvm.asnumpy(), rtol=1e-2) diff --git a/tutorials/autotvm/README.txt b/tutorials/autotvm/README.txt index 4ad36c000e3c..38e3b3343f4e 100644 --- a/tutorials/autotvm/README.txt +++ b/tutorials/autotvm/README.txt @@ -1,4 +1,4 @@ .. _tutorials-autotvm-sec: -AutoTVM: Template Based Auto Tuning ------------------------------------ +Auto tuning +----------- From 86bfd8fbe541766f669697134964e8325d8c535a Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 24 Jun 2020 14:12:33 +0800 Subject: [PATCH 41/45] Revert commits --- python/tvm/ansor/__init__.py | 11 +- python/tvm/ansor/auto_schedule.py | 86 ---- python/tvm/ansor/compute_dag.py | 21 - python/tvm/ansor/cost_model/__init__.py | 3 +- python/tvm/ansor/cost_model/cost_model.py | 31 -- python/tvm/ansor/cost_model/xgb_model.py | 474 --------------------- python/tvm/ansor/dispatcher.py | 299 ------------- python/tvm/ansor/env.py | 25 -- python/tvm/ansor/feature.py | 150 ------- python/tvm/ansor/loop_state.py | 339 --------------- python/tvm/ansor/measure.py | 3 +- python/tvm/ansor/relay_integration.py | 241 ----------- python/tvm/ansor/task_scheduler.py | 299 ------------- python/tvm/relay/backend/compile_engine.py | 5 +- python/tvm/relay/build_module.py | 7 - python/tvm/relay/op/_transform.py | 2 - python/tvm/relay/op/op_attrs.py | 3 - python/tvm/relay/op/strategy/x86.py | 62 ++- python/tvm/relay/op/transform.py | 21 - python/tvm/relay/testing/dqn.py | 27 +- python/tvm/relay/testing/resnet.py | 22 +- python/tvm/runtime/ndarray.py | 33 -- python/tvm/te/tensor.py | 8 +- src/relay/backend/build_module.cc | 32 -- src/relay/backend/compile_engine.cc | 5 - src/relay/backend/compile_engine.h | 3 - 26 files changed, 48 insertions(+), 2164 deletions(-) delete mode 100644 python/tvm/ansor/cost_model/xgb_model.py delete mode 100644 python/tvm/ansor/dispatcher.py delete mode 100644 python/tvm/ansor/env.py delete mode 100644 python/tvm/ansor/feature.py delete mode 100644 python/tvm/ansor/relay_integration.py delete mode 100644 python/tvm/ansor/task_scheduler.py diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index edade490018c..9cce63b2840d 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -22,24 +22,15 @@ from . import serialization from . import loop_state from . import utils -from . import feature from . import workload_registry -from . 
import task_scheduler # Shortcut -from .compute_dag import ComputeDAG, LayoutRewriteLevel +from .compute_dag import ComputeDAG from .auto_schedule import SearchTask, SketchSearchPolicy, TuneOption, HardwareParams, \ PreloadMeasuredStates, PreloadCustomSketchRule, auto_schedule from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, LocalRPCMeasureContext from .cost_model import RandomModel -from .cost_model.xgb_model import XGBModel from .serialization import LogToFile, LogReader, best_measure_pair_in_file, \ load_from_file, write_measure_records_to_file from .workload_registry import register_workload_func, \ workload_key_to_dag, make_workload_key_func -from .task_scheduler import TaskScheduler, SimpleTaskScheduler -from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest as apply_history_best, \ - FallbackContext -from .relay_integration import extract_from_program, extract_from_multiple_program, \ - finish_layout_rewrite, prepare_layout_rewrite, auto_schedule_topi -from .env import GLOBAL_SCOPE diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index 4497bb400703..37e622018658 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -82,96 +82,10 @@ def set_verbose(self, verbose): def run_callbacks(self, callbacks): _ffi_api.SearchPolicyRunCallbacks(self, callbacks) - -@tvm._ffi.register_object("ansor.SketchSearchPolicy") -class SketchSearchPolicy(SearchPolicy): - """ The search policy that searches in a hierarchical search space defined by sketches. - The policy randomly samples programs from the space defined by sketches - and use evolutionary search to fine-tune them. - - Parameters - ---------- - program_cost_model: CostModel - Cost model for programs - params: int - Parameters of the search policy. See `src/ansor/search_policy/sketch_search_policy.h` - to find the definitions. See code below to find the default values - seed: int - Random seed - """ - def __init__(self, - program_cost_model, - params=None, - seed=None): - # set default parameters - default_params = { - "eps_greedy": 0.05, - - 'evolutionary_search_population': 2048, - 'evolutionary_search_num_iters': 15, - "evolutionary_search_mutation_prob": 0.85, - "evolutionary_search_use_measured_ratio": 0.2, - - 'cpu_multi_level_tiling_structure': 'SSRSRS', - 'gpu_multi_level_tiling_structure': 'SSSRRSRS', - - 'disable_change_compute_location': 0, - } - - if params is None: - params = default_params - else: - for key, value in default_params.items(): - if key not in params: - params[key] = value - - self.__init_handle_by_constructor__( - _ffi_api.SketchSearchPolicy, program_cost_model, params, - seed or random.randint(1, 1 << 30)) - - @tvm._ffi.register_object("ansor.SearchCallback") class SearchCallback(Object): """Callback function before or after search process""" - -@tvm._ffi.register_object("ansor.PreloadMeasuredStates") -class PreloadMeasuredStates(SearchCallback): - """ A SearchCallback to load measured states from the log file for a search policy. - This can resume the state of the search policy. - - Parameters - ---------- - filename: str - """ - def __init__(self, filename: str): - self.__init_handle_by_constructor__( - _ffi_api.PreloadMeasuredStates, filename) - - -@tvm._ffi.register_object("ansor.PreloadCustomSketchRule") -class PreloadCustomSketchRule(SearchCallback): - """ - A SearchCallback for SketchSearchPolicy that allowing users to add - custom sketch rule. - - Notes - ----- - This is an advanced feature. 
Make sure you're clear how it - works and this should only be used in SketchSearchPolicy. - - Parameters - ---------- - meet_condition_func: Function - A function with `(policy, state, stage_id) -> int` - apply_func: Function - A function with `(policy, state, stage_id) -> [[State, int], ...]` - """ - def __init__(self, meet_condition_func, apply_func): - self.__init_handle_by_constructor__( - _ffi_api.PreloadCustomSketchRule, meet_condition_func, apply_func) - - @tvm._ffi.register_object("ansor.TuneOption") class TuneOption(Object): """ The options for tuning diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index 6304c7bb0e0a..994c3ae3ab97 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -23,13 +23,6 @@ from . import _ffi_api -class LayoutRewriteLevel(object): - NO_REWRITE = 0 # No layout rewrite - PLACEHOLDER_REWRITE = 1 # Only rewrite layout of placeholder in the compute dag - COMPUTE_REWRITE = 2 # Only rewrite compute body for new layout in the compute dag - BOTH_REWRITE = 3 # Rewrite both placeholder and compute body in the compute dag - - @tvm._ffi.register_object("ansor.ComputeDAG") class ComputeDAG(Object): """ @@ -97,17 +90,3 @@ def infer_bound_from_state(self, state): """ state_obj = state if isinstance(state, StateObject) else state.state_object return State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self) - - def rewrite_layout_from_state(self, state: State): - """ - Rewrite the layout according to the transform steps in the history of a state - - Parameters - ---------- - state : StateObject - - Returns - ------- - state : StateObject - """ - return _ffi_api.ComputeDAGRewriteLayoutFromState(self, state) diff --git a/python/tvm/ansor/cost_model/__init__.py b/python/tvm/ansor/cost_model/__init__.py index 56e4a5f9128b..1454da451b61 100644 --- a/python/tvm/ansor/cost_model/__init__.py +++ b/python/tvm/ansor/cost_model/__init__.py @@ -17,5 +17,4 @@ # pylint: disable=unused-import, redefined-builtin """ Cost model that estimates the performance of programs """ -from .cost_model import RandomModel -from .xgb_model import XGBModel +from .cost_model import RandomModel \ No newline at end of file diff --git a/python/tvm/ansor/cost_model/cost_model.py b/python/tvm/ansor/cost_model/cost_model.py index fbfc8242488b..605db14c19c3 100644 --- a/python/tvm/ansor/cost_model/cost_model.py +++ b/python/tvm/ansor/cost_model/cost_model.py @@ -44,34 +44,3 @@ def random_number(n, return_ptr): return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_float)) array_wrapper = np.ctypeslib.as_array(return_ptr, shape=(n,)) array_wrapper[:] = np.random.uniform(0, 1, (n,)) - - -@tvm._ffi.register_object("ansor.PythonBasedModel") -class PythonBasedModel(CostModel): - """Base class for cost models implemented in python""" - def __init__(self): - def update_func(inputs, results): - self.update(inputs, results) - - def predict_func(task, states, return_ptr): - return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_float)) - array_wrapper = np.ctypeslib.as_array(return_ptr, shape=(len(states),)) - array_wrapper[:] = self.predict(task, states) - - def predict_stage_func(task, states, return_ptr): - ret = self.predict_stages(task, states) - return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_float)) - array_wrapper = np.ctypeslib.as_array(return_ptr, shape=ret.shape) - array_wrapper[:] = ret - - self.__init_handle_by_constructor__(_ffi_api.PythonBasedModel, update_func, - predict_func, predict_stage_func) - - def 
update(self, inputs, results):
-        raise NotImplementedError
-
-    def predict(self, task, states):
-        raise NotImplementedError
-
-    def predict_stages(self, task, states):
-        raise NotImplementedError
diff --git a/python/tvm/ansor/cost_model/xgb_model.py b/python/tvm/ansor/cost_model/xgb_model.py
deleted file mode 100644
index 42af17daae2c..000000000000
--- a/python/tvm/ansor/cost_model/xgb_model.py
+++ /dev/null
@@ -1,474 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""Cost model based on xgboost"""
-import multiprocessing
-import logging
-from collections import defaultdict
-
-import numpy as np
-import xgboost as xgb
-
-from tvm.autotvm.tuner.xgboost_cost_model import get_rank, recall_curve, max_curve
-from .cost_model import PythonBasedModel
-from ..feature import get_per_stmt_features_from_measure_pairs, get_per_stmt_features_from_states
-from ..serialization import LogReader
-
-logger = logging.getLogger('ansor')
-
-class XGBDMatrixContext:
-    """Context to hold additional attributes of xgb.DMatrix"""
-    def __init__(self):
-        self.context_dict = defaultdict(dict)
-
-    def get(self, key, matrix, default=None):
-        return self.context_dict[key].get(matrix.handle.value, default)
-
-    def put(self, key, matrix, value):
-        self.context_dict[key][matrix.handle.value] = value
-
-dmatrix_context = XGBDMatrixContext()
-
-class XGBModel(PythonBasedModel):
-    """Train an XGBoost model to predict the runtime cost of a program.
-    The cost of a program is the sum of the costs of all stages in this program,
-    i.e. Cost(p) = cost_s0 + cost_s1 + ... + cost_sn, where cost_si is the cost of Stage i.
-
-    The xgboost model makes a prediction per stage, and we then sum these up.
-    The final prediction made by this class is the normalized throughput (from 0 to 1, larger is better).
-
-    To support this stage decomposition, we have to implement a custom loss function for
-    XGBoost, which is the `pack_sum` in the code below.
- """ - def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None): - self.xgb_params = { - 'max_depth': 10, - 'gamma': 0.001, - 'min_child_weight': 0, - 'eta': 0.2, - # todo(lmzheng): automatically decrease learning rate when the loss is too large - - 'n_gpus': 0, - 'nthread': multiprocessing.cpu_count() // 2, - 'verbosity': 0, - 'seed': seed or 43, - 'disable_default_eval_metric': 1 - } - self.bst = None - self.plan_size = 32 - self.num_warmup_sample = num_warmup_sample - self.verbose_eval = verbose_eval - - super().__init__() - - # measurement input/result pairs - self.inputs = [] - self.results = [] - self.inputs_feature_cache = [] - - def update(self, inputs, results): - if len(inputs) <= 0: - return - - self.inputs.extend(inputs) - self.results.extend(results) - - # extract feature - n_cached = len(self.inputs_feature_cache) - features, normalized_throughputs, task_ids = \ - get_per_stmt_features_from_measure_pairs(self.inputs, self.results, - skip_first_n_feature_extraction=n_cached) - if n_cached > 0: - features = list(features) - features[:n_cached] = self.inputs_feature_cache - features = np.array(features) - self.inputs_feature_cache = features - dtrain = pack_sum_xgbmatrix(features, normalized_throughputs, - task_ids, normalized_throughputs) - - # train xgb model - self.bst = xgb.train(self.xgb_params, dtrain, - num_boost_round=10000, - obj=pack_sum_square_error, - callbacks=[custom_callback( - stopping_rounds=50, - metric='tr-p-rmse', - fevals=[ - pack_sum_rmse, pack_sum_average_peak_score(self.plan_size), - ], - evals=[(dtrain, 'tr')], - maximize=False, - verbose_eval=self.verbose_eval)]) - - def predict(self, task, states): - features = get_per_stmt_features_from_states(states, task) - if self.bst is not None and len(self.inputs) > self.num_warmup_sample: - dtest, pack_ids = pack_sum_xgbmatrix_for_prediction(features) - raw_preds = self.bst.predict(dtest) - ret = pack_sum_predict_throughput(raw_preds, pack_ids) - else: - ret = np.random.uniform(0, 1, (len(states),)) - - # Predict 0 for invalid states that failed to be lowered. - for idx, feature in enumerate(features): - if feature.min() == feature.max() == 0: - ret[idx] = float('-inf') - - return ret - - def predict_stages(self, task, states): - # Format: (s0 score, ..., sN score, s0 n_stage, s0 stage 0, ..., s1 n_stage, s1 stage 0,) - features = get_per_stmt_features_from_states(states, task) - if self.bst is not None and len(self.inputs) > self.num_warmup_sample: - dtest, pack_ids = pack_sum_xgbmatrix_for_prediction(features) - raw_preds = self.bst.predict(dtest) - breakdown = pack_sum_predict_throughput(raw_preds, pack_ids) - stage_scores = [[] for _ in range(len(states))] - for pred, pack_id in zip(raw_preds, pack_ids): - stage_scores[pack_id].append(pred) - for idx, stage_score in enumerate(stage_scores): - breakdown = np.append(breakdown, len(stage_score)) - breakdown = np.concatenate((breakdown, -np.array(stage_score))) - else: - breakdown = np.concatenate( - (np.random.uniform(0, 1, (len(states), )), np.zeros(len(states), ))) - - # Predict 0 for invalid states that failed to be lowered. 
- for idx, feature in enumerate(features): - if feature.min() == feature.max() == 0: - breakdown[idx] = float('-inf') - - return breakdown - - def load_log_file(self, file_name, n_lines=-1): - inputs, results = LogReader(file_name).read_lines(n_lines) - logger.info("XGBModel: Loaded %s lines of history log from %s", len(inputs), file_name) - self.update(inputs, results) - - def save(self, file_name: str): - self.bst.save_model(file_name) - - def load(self, file_name: str): - if self.bst is None: - self.bst = xgb.Booster(self.xgb_params) - self.bst.load_model(file_name) - self.num_warmup_sample = -1 - - -def pack_sum_xgbmatrix_for_prediction(xs): - x_flatten = [] - pack_ids = [] - - for ct, x in enumerate(xs): - for row in x: - x_flatten.append(row) - pack_ids.append(ct) - - return xgb.DMatrix(np.array(x_flatten)), pack_ids - - -def pack_sum_xgbmatrix(xs, ys, gids=None, weights=None): - if gids is not None: - # sort by group - indices = gids.argsort() - xs, ys = xs[indices], ys[indices] - group_sizes = np.bincount(gids) - if weights is not None: - weights = weights[indices] - else: - # assume it has only one group - group_sizes = [len(xs)] - - x_flatten = [] - y_flatten = [] - weights_flatten = [] - pack_ids = [] - - if weights is not None: - for ct, (x, y, w) in enumerate(zip(xs, ys, weights)): - for row in x: - x_flatten.append(row) - y_flatten.append(y) - weights_flatten.append(w) - pack_ids.append(ct) - else: - for ct, (x, y) in enumerate(zip(xs, ys)): - for row in x: - x_flatten.append(row) - y_flatten.append(y) - pack_ids.append(ct) - - ret = xgb.DMatrix(np.array(x_flatten), y_flatten) - if weights is not None: - ret.set_weight(weights_flatten) - dmatrix_context.put('pack_ids', ret, np.array(pack_ids)) - dmatrix_context.put('group_sizes', ret, group_sizes) - return ret - -LOSS_TYPE = 3 - -# Type 0 -# The model predicts cost. Use square error of throughput as loss -# loss = 1/2 * (1 / sum(x_i) - y) ^ 2 -# -# Type 1 -# The model predicts cost. Use square error of cost as loss -# loss = 1/2 * (sum(x_i) - 1 / y) ^ 2 -# -# Type 2 -# The model predicts throughput. Use square error of throughput as loss. -# loss = 1/2 * (1 / sum(1 / x_i) - y) ^ 2 -# -# Type 3 -# The model predicts throughput. Use square error of throughput as loss. -# But approximate 1 / (1 / a_1 + 1 / a_2 + ... + 1 / a_n) with -(b_1 + b_2 + b_3) -# loss = 1/2 * (-sum(x_i) - y) ^ 2 -# -# Type 4 -# The model predicts throughput. Use square error of throughput as loss. -# But approximate 1 / (1 / a_1 + 1 / a_2 + ... 
+ 1 / a_n) with -(b_1 + b_2 + b_3) -# Also add a sigmoid to force the prediction to be within the range of (0, 1) -# loss = 1/2 * (sigmoid(-sum(x_i)) - y) ^ 2 -# - -def pack_sum_predict_throughput(raw_preds, pack_ids): - if LOSS_TYPE == 0: - sum_pred = np.bincount(pack_ids, weights=raw_preds) - return 1 / sum_pred - elif LOSS_TYPE == 1: - sum_pred = np.bincount(pack_ids, weights=raw_preds) - return 1 / sum_pred - elif LOSS_TYPE == 2: - sum_inverse_preds = np.bincount(pack_ids, weights=1 / raw_preds) - return 1 / sum_inverse_preds - elif LOSS_TYPE == 3: - sum_pred = np.bincount(pack_ids, weights=raw_preds) - return - sum_pred # pylint: disable=invalid-unary-operand-type - elif LOSS_TYPE == 4: - sum_pred = np.bincount(pack_ids, weights=raw_preds) - return 1 / (1 + np.exp(sum_pred)) - else: - raise ValueError("Invalid loss type: " + LOSS_TYPE) - -def pack_sum_square_error(preds, dtrain): - pack_ids = dmatrix_context.get("pack_ids", dtrain) - weight = dtrain.get_weight() - - if LOSS_TYPE == 0: - sum_pred = np.bincount(pack_ids, weights=preds) - x = sum_pred[pack_ids] - y = dtrain.get_label() - gradient = (x * y - 1) / np.power(x, 3) - hessian = (3 - 2 * x * y) / np.power(x, 4) - elif LOSS_TYPE == 1: - sum_pred = np.bincount(pack_ids, weights=preds) - x = sum_pred[pack_ids] - y = dtrain.get_label() - gradient = x - 1 / np.minimum(y, 1e6) - hessian = np.ones_like(gradient) - elif LOSS_TYPE == 2: - sum_inverse_preds = np.bincount(pack_ids, weights=1 / preds)[pack_ids] - y = dtrain.get_label() - gradient = (1 / sum_inverse_preds - y) / (np.power(preds * sum_inverse_preds, 2)) - hessian = (2 * preds * y * np.power(sum_inverse_preds, 2) - 2 * y * sum_inverse_preds - 2 * preds * sum_inverse_preds + 3) / (np.power(preds * sum_inverse_preds, 4)) - elif LOSS_TYPE == 3: - sum_pred = np.bincount(pack_ids, weights=preds) - x = sum_pred[pack_ids] - y = dtrain.get_label() - gradient = x + y - hessian = np.ones_like(gradient) - elif LOSS_TYPE == 4: - sum_pred = np.bincount(pack_ids, weights=preds) - exp_x = np.exp(sum_pred[pack_ids]) - exp_2x = np.power(exp_x, 2) - y = dtrain.get_label() - gradient = exp_x * (exp_x * y + y - 1) / np.power(exp_x + 1, 3) - hessian = exp_x * (-exp_2x * y + 2 * exp_x + y - 1) / np.power(exp_x + 1, 4) - else: - raise ValueError("Invalid loss type: " + LOSS_TYPE) - - if len(weight) == 0: - return gradient, hessian - else: - return gradient * weight, hessian * weight - -def pack_sum_rmse(raw_preds, dtrain): - pack_ids = dmatrix_context.get("pack_ids", dtrain) - preds = pack_sum_predict_throughput(raw_preds, pack_ids)[pack_ids] - return 'p-rmse', np.sqrt(np.mean(np.square((preds - dtrain.get_label())))) - -def pack_sum_average_peak_score(N): - """Evaluate pack sum average peak score for xgb""" - - def feval(preds, labels): - group_sizes = dmatrix_context.get('group_sizes', labels, [len(preds)]) - pack_ids = dmatrix_context.get("pack_ids", labels) - - preds = pack_sum_predict_throughput(preds, pack_ids) - labels = (np.bincount(pack_ids, weights=labels.get_label()) - / np.unique(pack_ids, return_counts=True)[1]) - - scores = [] - offset = 0 - for size in group_sizes: - preds_group = preds[offset:offset + size] - labels_group = labels[offset:offset + size] - offset += size - - trials = np.argsort(preds_group)[::-1][:N] - trial_scores = labels_group[trials] - curve = max_curve(trial_scores) / np.max(labels_group) - scores.append(np.mean(curve)) - return "a-peak@%d" % N, np.mean(scores) - return feval - -def pack_sum_average_recall_score(N): - """Evaluate average recall score for xgb""" - 
- def feval(preds, labels): - group_sizes = dmatrix_context.get('group_sizes', labels, [len(preds)]) - pack_ids = dmatrix_context.get("pack_ids", labels) - - preds = pack_sum_predict_throughput(preds, pack_ids) - labels = (np.bincount(pack_ids, weights=labels.get_label()) - / np.unique(pack_ids, return_counts=True)[1]) - - scores = [] - offset = 0 - for size in group_sizes: - preds_group = preds[offset:offset + size] - labels_group = labels[offset:offset + size] - offset += size - - trials = np.argsort(preds_group)[::-1] - ranks = get_rank(labels_group[trials])[:N] - curve = recall_curve(ranks) - scores.append(np.mean(curve)) - return "a-recall@%d" % N, np.mean(scores) - return feval - - -def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None, - maximize=False, verbose_eval=True, skip_every=2): - """Callback function for xgboost to support multiple custom evaluation functions""" - from xgboost.core import EarlyStopException - from xgboost.callback import _fmt_metric - from xgboost.training import aggcv - - state = {} - metric_shortname = metric.split("-")[1] - - def init(env): - """internal function""" - bst = env.model - - state['maximize_score'] = maximize - state['best_iteration'] = 0 - if maximize: - state['best_score'] = float('-inf') - else: - state['best_score'] = float('inf') - - if bst is not None: - if bst.attr('best_score') is not None: - state['best_score'] = float(bst.attr('best_score')) - state['best_iteration'] = int(bst.attr('best_iteration')) - state['best_msg'] = bst.attr('best_msg') - else: - bst.set_attr(best_iteration=str(state['best_iteration'])) - bst.set_attr(best_score=str(state['best_score'])) - else: - assert env.cvfolds is not None - - def callback(env): - """internal function""" - if not state: - init(env) - - bst = env.model - i = env.iteration - cvfolds = env.cvfolds - - res_dict = {} - - if i % skip_every == 1: - return - - ##### evaluation ##### - if cvfolds is not None: - for feval in fevals: - tmp = aggcv([f.eval(i, feval) for f in cvfolds]) - for k, mean, std in tmp: - res_dict[k] = [mean, std] - else: - for feval in fevals: - bst_eval = bst.eval_set(evals, i, feval) - res = [x.split(':') for x in bst_eval.split()] - for kv in res[1:]: - res_dict[kv[0]] = [float(kv[1])] - - eval_res = [] - keys = list(res_dict.keys()) - keys.sort(key=lambda x: x if metric_shortname not in x else "a" + x) - for key in keys: - v = res_dict[key] - eval_res.append([key] + v) - - ##### print eval result ##### - if not isinstance(verbose_eval, bool) and verbose_eval and i % verbose_eval == 0: - infos = ["XGB iter: %3d" % i] - for item in eval_res: - if 'null' in item[0]: - continue - infos.append("%s: %.6f" % (item[0], item[1])) - - logger.debug("\t".join(infos)) - if log_file: - with open(log_file, "a") as fout: - fout.write("\t".join(infos) + '\n') - - ##### choose score and do early stopping ##### - score = None - for item in eval_res: - if item[0] == metric: - score = item[1] - break - assert score is not None - - best_score = state['best_score'] - best_iteration = state['best_iteration'] - maximize_score = state['maximize_score'] - if (maximize_score and score > best_score) or \ - (not maximize_score and score < best_score): - msg = '[%d] %s' % ( - env.iteration, - '\t'.join([_fmt_metric(x) for x in eval_res])) - state['best_msg'] = msg - state['best_score'] = score - state['best_iteration'] = env.iteration - # save the property to attributes, so they will occur in checkpoint. 
- if env.model is not None: - env.model.set_attr(best_score=str(state['best_score']), - best_iteration=str(state['best_iteration']), - best_msg=state['best_msg']) - elif env.iteration - best_iteration >= stopping_rounds: - best_msg = state['best_msg'] - if verbose_eval and env.rank == 0: - logger.debug("XGB stopped. Best iteration: %s ", best_msg) - raise EarlyStopException(best_iteration) - - return callback diff --git a/python/tvm/ansor/dispatcher.py b/python/tvm/ansor/dispatcher.py deleted file mode 100644 index 3a5dc4e9e206..000000000000 --- a/python/tvm/ansor/dispatcher.py +++ /dev/null @@ -1,299 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -The global context that dispatches best configurations to workloads -""" -# pylint: disable=invalid-name - -from __future__ import absolute_import as _abs - -import logging - -import numpy as np - -from tvm.tir.expr import FloatImm - -logger = logging.getLogger('auto_scheduler') - - -class DispatchContext(object): - """ - Base class of dispatch context. - """ - current = None - - def __init__(self): - self._old_ctx = DispatchContext.current - - def query(self, target, workload): - """ - Query the context to get the specific config for a workload. - If cannot find the result inside this context, this function will query it - from the upper contexts. - - Parameters - ---------- - target: Target - The current target - workload : str - The current workload - - Returns - ------- - cfg : State - The schedule configuration for the workload - """ - ret = self._query_inside(target, workload) - return ret - - def update(self, target, workload, cfg): - """ - Update the config for a workload - - Parameters - ---------- - target: Target - The current target - workload : Workload - The current workload. - cfg : State - The schedule configuration for the workload - """ - raise NotImplementedError() - - def _query_inside(self, target, workload): - """ - Query the context to get the specific config for a workload. - This function only query config inside this context. - - Parameters - ---------- - target: Target - The current target - workload : Workload - The current workload. - - Returns - ------- - cfg : State or str - The schedule configuration for the workload - """ - raise NotImplementedError() - - def __enter__(self): - self._old_ctx = DispatchContext.current - DispatchContext.current = self - return self - - def __exit__(self, ptype, value, trace): - DispatchContext.current = self._old_ctx - - -class ApplyConfig(DispatchContext): - """Apply a deterministic config for all queries. 
- - Parameters - ---------- - config : State - The schedule configuration - """ - def __init__(self, config): - super(ApplyConfig, self).__init__() - self._config = config - self.workload = None - - def _query_inside(self, target, workload): - """Override query""" - self.workload = workload - return self._config - - def update(self, target, workload, cfg): - """Override update""" - self.workload = workload - self._config = cfg - - -class ApplyHistoryBest(DispatchContext): - """ - Apply the history best config - - Parameters - ---------- - records : str or iterator of (MeasureInput, MeasureResult) - Collection of tuning records. - If is str, then it should be the filename of a records log file. - Each row of this file is an encoded record pair. - Otherwise, it is an iterator. - n_lines: int (optional) - if it is not None, only load the first `n_lines` lines of log - """ - def __init__(self, records, n_lines=None): - super(ApplyHistoryBest, self).__init__() - - self.best_by_targetkey = {} - self.best_by_model = {} - self._best_user_defined = {} - - if records: - self.load(records, n_lines) - - def load(self, records, n_lines=None): - """Load records to this dispatch context - - Parameters - ---------- - records : str or iterator of (MeasureInput, MeasureResult) - Collection of tuning records. - If is str, then it should be the filename of a records log file. - Each row of this file is an encoded record pair. - Otherwise, it is an iterator. - n_lines: int (optional) - if it is not None, only load the first `n_lines` lines of log - """ - from pathlib import Path - from . import load_from_file - - if isinstance(records, Path): - records = str(records) - - if isinstance(records, str): - records = load_from_file(records) - if not records: - return - - best_by_targetkey = self.best_by_targetkey - best_by_model = self.best_by_model - - counter = 0 - for inp, res in records: - if n_lines is not None and counter >= n_lines: - break - counter += 1 - if res.error_no != 0: - continue - - # use target keys in tvm target system as key to build best map - for k in inp.task.target.keys: - key = (k, inp.task.workload_key) - if key not in best_by_targetkey: - best_by_targetkey[key] = (inp, res) - else: - _, other_res = best_by_targetkey[key] - other_costs = [x.value for x in other_res.costs if isinstance(x, FloatImm)] - costs = [x.value for x in res.costs if isinstance(x, FloatImm)] - if np.mean(other_costs) > np.mean(costs): - best_by_targetkey[key] = (inp, res) - - # use model as key to build best map - key = (inp.task.target.model, inp.task.workload_key) - if key not in best_by_model: - if inp.task.target.model != 'unknown': - best_by_model[key] = (inp, res) - else: - _, other_res = best_by_model[key] - other_costs = [x.value for x in other_res.costs if isinstance(x, FloatImm)] - costs = [x.value for x in res.costs if isinstance(x, FloatImm)] - if np.mean(other_costs) > np.mean(costs): - best_by_model[key] = (inp, res) - - logger.debug("Finish loading %d records", counter) - - def _query_inside(self, target, workload): - if target is None: - raise RuntimeError("Need a target context to find the history best. " - "Hint: If your target is llvm, use `with tvm.target.create('llvm'):`" - " above the dispatcher call. So does other target. 
") - - # first try matching by model - key = (target.model, workload) - if key in self._best_user_defined: - return self._best_user_defined[key] - if key in self.best_by_model: - return self.best_by_model[key][0].state - - # then try matching by target key - for k in target.keys: - key = (k, workload) - if key in self._best_user_defined: - return self._best_user_defined[key] - if key in self.best_by_targetkey: - return self.best_by_targetkey[key][0].state - - return None - - def update(self, target, workload, state): - model = target.model - key = (model, workload) - self._best_user_defined[key] = state - - for k in target.keys: - key = (k, workload) - self._best_user_defined[key] = state - - -class FallbackContext(DispatchContext): - """ - A fallback dispatch context. - This is used as the root context. - """ - - def __init__(self): - super(FallbackContext, self).__init__() - self.memory = {} - self.silent = False - - # a set to prevent print duplicated message - self.messages = set() - - def _query_inside(self, target, workload): - key = (str(target), workload) - if key in self.memory: - return self.memory[key] - - if not self.silent: - msg = "Cannot find config for target=%s, workload=%s. A fallback configuration "\ - "is used, which may bring great performance regression." % (target, workload) - if msg not in self.messages: - self.messages.add(msg) - logger.warning(msg) - cfg = None - - # cache this config to avoid duplicated warning message - self.memory[key] = cfg - return cfg - - def clear_cache(self, target, workload): - """Clear fallback cache. Pass the same argument as _query_inside to this function - to clean the cache. - - Parameters - ---------- - target: Target - The current target - workload : Workload - The current workload. - """ - key = (str(target), workload) - if key in self.memory: - del self.memory[key] - - def update(self, target, workload, cfg): - key = (str(target), workload) - self.memory[key] = cfg - - -DispatchContext.current = FallbackContext() diff --git a/python/tvm/ansor/env.py b/python/tvm/ansor/env.py deleted file mode 100644 index 56e76e26ee4f..000000000000 --- a/python/tvm/ansor/env.py +++ /dev/null @@ -1,25 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" The scope to store global variables in ansor """ - - -class AutoschedulerGlobalScope(object): - def __init__(self): - self.topi_in_compute_rewrite_mode = False - -GLOBAL_SCOPE = AutoschedulerGlobalScope() diff --git a/python/tvm/ansor/feature.py b/python/tvm/ansor/feature.py deleted file mode 100644 index fa1b2cb07dcc..000000000000 --- a/python/tvm/ansor/feature.py +++ /dev/null @@ -1,150 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""" -Python API for Feature extraction. -""" - -from typing import List, Tuple -import struct -import numpy as np - -from .loop_state import State, StateObject -from .measure import MeasureInput, MeasureResult -from . import _ffi_api - - -# Maximum number of buffers for one statement to extract feature for -DEFAULT_MAX_N_BUFS = 5 - -# The length of the feature vector -DEFAULT_FEATURE_VEC_LEN = 164 - - -def unpack_feature(byte_arr: bytearray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """Unpack the encoded feature (in byte array format) of from c++""" - size_of_int = 4 - size_of_float = 4 - - # The format for n records is: - # { - # int n; - # int[n+2] sizes - - # float[sizes[0]] feature for record 1 - # float[sizes[1]] feature for record 2 - # ... feature for record i... - # float[sizes[n-1]] feature for record n - - # float[sizes[n]] normalized throughput for n records - # int[sizes[n+1]] task id for n records - # } - - vec_len = DEFAULT_FEATURE_VEC_LEN - - # unpack sizes - offset = 0 - n = struct.unpack_from("1i", byte_arr, offset=offset)[0] - offset += size_of_int - - sizes = struct.unpack_from("%di" % (n+2), byte_arr, offset=offset) - offset += size_of_int * (n+2) - - # unpack features - features = [] - for size in sizes[:-2]: - row = [] - - # Now we need to unpack the feature for multiple statements. - # The format is: - # { - # int n_stmts - # float[n_stmt][vec_len] feature_vecs - # } - # where vec_len can be calculated by `(size - 1) / n_stmts` - - if size == 0: - # failed during lowering - features.append(np.zeros((1, vec_len))) - else: - n_stmts = struct.unpack_from("f", byte_arr, offset=offset) - offset += size_of_float - - n_stmts = int(n_stmts[0] + 0.5) - tmp_vec_len = (size - 1) // n_stmts - assert tmp_vec_len == vec_len, "The lenght of feature vector is wrong. " \ - "Expected %d but got %d." 
% (vec_len, tmp_vec_len) - assert (size - 1) % n_stmts == 0 - for _ in range(n_stmts): - x = struct.unpack_from("%df" % vec_len, byte_arr, offset=offset) - offset += vec_len * size_of_float - row.append(x) - - features.append(np.array(row)) - - # unpack normalized_throughputs - m = sizes[-2] - normalized_throughputs = struct.unpack_from("%df" % m, byte_arr, offset=offset) - offset += m * size_of_int - - # unpack task_ids - m = sizes[-1] - task_ids = struct.unpack_from("%di" % m, byte_arr, offset=offset) - offset += m * size_of_int - - assert offset == len(byte_arr), "%d vs %d" % (offset, len(byte_arr)) - return np.array(features), np.array(normalized_throughputs), np.array(task_ids) - - -def get_per_stmt_features_from_file(filename: str, - n_lines: int, - max_n_bufs: int = None) \ - -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """Get per_stmt features from a log file""" - byte_arr = _ffi_api.GetPerStmtFeaturesFromFile( - filename, n_lines, max_n_bufs or DEFAULT_MAX_N_BUFS) - return unpack_feature(byte_arr) - - -def get_per_stmt_features_from_measure_pairs(inputs: List[MeasureInput], - results: List[MeasureResult], - skip_first_n_feature_extraction: int = 0, - max_n_bufs: int = None) \ - -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """Get per_stmt features from measurement pairs""" - byte_arr = _ffi_api.GetPerStmtFeaturesFromMeasurePairs( - inputs, results, skip_first_n_feature_extraction, max_n_bufs or DEFAULT_MAX_N_BUFS) - return unpack_feature(byte_arr) - - -def get_per_stmt_features_from_states(states, - task: "SearchTask", - max_n_bufs: int = None) -> List[np.ndarray]: - """Get per_stmt features from states""" - if isinstance(states[0], State): - state_objects = [s.state_object for s in states] - elif isinstance(states[0], StateObject): - state_objects = states - byte_arr = _ffi_api.GetPerStmtFeaturesFromStates( - state_objects, task, max_n_bufs or DEFAULT_MAX_N_BUFS) - return unpack_feature(byte_arr)[0] - - -def get_per_stmt_feature_names(max_n_bufs: int = None) -> List[str]: - """Get names for the elements in the flatten feature vector""" - return [x for x in - _ffi_api.GetPerStmtFeatureNames(max_n_bufs or DEFAULT_MAX_N_BUFS)] diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py index 7aa5de0e9c1d..470ae40f5278 100644 --- a/python/tvm/ansor/loop_state.py +++ b/python/tvm/ansor/loop_state.py @@ -152,345 +152,6 @@ def split(self, stage_id, iterator, lengths, inner_to_outer=True): self._clear_cache() return res - def follow_split(self, stage_id, iterator, src_step_id, n_split): - """ - Parameters - ---------- - stage_id : Union[int, Operation, Tensor] - The index of the stage to split - iterator : Iterator - The iterator to split - src_step_id : int - The index of the split step to follow in the history - n_split : int - The number of split level - - Returns - ------- - res_its : List[Iterator] - The splitted new Iterators - """ - stage_id = self._resolve_stage_id(stage_id) - - self.state_object, res = _ffi_api.StateFollowSplit(self.state_object, stage_id, iterator, - src_step_id, n_split) - self._clear_cache() - return res - - def follow_fused_split(self, stage_id, iterator, src_step_ids, level, - factor_or_nparts): - """ - Parameters - ---------- - stage_id : Union[int, Operation, Tensor] - The index of the stage to split - iterator : Iterator - The iterator to split - src_step_ids : List[int] - The indices of the split steps to follow in the history - level : int - Use the length in this split level - factor_or_nparts : bool - True to use `factor` for 
split from inner to outer, - False to use `nparts` for split from outer to inner - - Returns - ------- - res_its : List[Iterator] - The splitted new Iterators - """ - stage_id = self._resolve_stage_id(stage_id) - - self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object, stage_id, - iterator, src_step_ids, level, - factor_or_nparts) - self._clear_cache() - return res - - def fuse(self, stage_id, iters): - """ - Parameters - ---------- - stage_id : Union[int, Operation, Tensor] - The index of the stage to fuse - iters : List[Iterator] - The iterators to be fused - - Returns - ------- - res_it : Iterator - The fused Iterator - """ - stage_id = self._resolve_stage_id(stage_id) - - self.state_object, res = _ffi_api.StateFuse(self.state_object, stage_id, iters) - self._clear_cache() - return res - - def vectorize(self, stage_id, iterator): - """ - Parameters - ---------- - stage_id : Union[int, Operation, Tensor] - The index of the stage to vectorize - iterator : Iterator - The iterator to be vectorized - - Returns - ------- - res_it : Iterator - The vectorized Iterator - """ - stage_id = self._resolve_stage_id(stage_id) - - self.state_object, res = _ffi_api.StateVectorize(self.state_object, stage_id, iterator) - self._clear_cache() - return res - - def parallel(self, stage_id, iterator): - """ - Parameters - ---------- - stage_id : Union[int, Operation, Tensor] - The index of the stage to parallel - iterator : Iterator - The iterator to be parallelized - - Returns - ------- - res_it : Iterator - The parallelized Iterator - """ - stage_id = self._resolve_stage_id(stage_id) - - self.state_object, res = _ffi_api.StateParallel(self.state_object, stage_id, iterator) - self._clear_cache() - return res - - def unroll(self, stage_id, iterator, max_unroll=-1): - """ - Parameters - ---------- - stage_id : Union[int, Operation, Tensor] - The index of the stage to unroll - iterator : Iterator - The iterator to be unrolled - max_unroll: int - The maximum length of the iterator that can be unrolled - - Returns - ------- - res_it : Iterator - The unrolled Iterator - """ - stage_id = self._resolve_stage_id(stage_id) - - self.state_object, res = _ffi_api.StateUnroll(self.state_object, stage_id, iterator, - max_unroll) - self._clear_cache() - return res - - def bind_thread(self, stage_id, iterator, thread_name): - """ - Parameters - ---------- - stage_id : Union[int, Operation, Tensor] - The index of the stage to bind - iterator : Iterator - The iterator to be bound - thread_name : str - The name of the thread (e.g. 
"blockIdx.x", "threadIdx.y", "vthread") - - Returns - ------- - res_it : Iterator - The bound Iterator - """ - trans_table = { - "vthread": 4, - "blockIdx.x": 5, - "threadIdx.x": 6, - "blockIdx.y": 7, - "threadIdx.y": 8, - } - thread_id = trans_table[thread_name] - - stage_id = self._resolve_stage_id(stage_id) - - self.state_object, res = _ffi_api.StateBindThread(self.state_object, stage_id, iterator, - thread_id) - self._clear_cache() - return res - - def compute_at(self, stage_id, target_stage_id, target_iter): - """ - Parameters - ---------- - stage_id : Union[int, Operation, Tensor] - The index of source stage - target_stage_id : Union[int, Operation, Tensor] - The index of the target stage of compute_at - target_iter : Iterator - The target Iterator of compute_at - """ - stage_id = self._resolve_stage_id(stage_id) - target_stage_id = self._resolve_stage_id(target_stage_id) - - self.state_object = _ffi_api.StateComputeAt(self.state_object, stage_id, - target_stage_id, target_iter) - self._clear_cache() - - def compute_root(self, stage_id): - """ - Parameters - ---------- - stage_id : Union[int, Operation, Tensor] - The index of the stage to compute root - """ - stage_id = self._resolve_stage_id(stage_id) - - self.state_object = _ffi_api.StateComputeRoot(self.state_object, stage_id) - self._clear_cache() - - def compute_inline(self, stage_id): - """ - Parameters - ---------- - stage_id : Union[int, Operation, Tensor] - The index of the stage to compute inline - """ - stage_id = self._resolve_stage_id(stage_id) - - self.state_object = _ffi_api.StateComputeInline(self.state_object, stage_id) - self._clear_cache() - - def cache_read(self, stage_id, scope_name, reader_stage_ids): - """ - Parameters - ---------- - stage_id : Union[int, Operation, Tensor] - The index of the stage to do cache_read - scope_name : str - reader_stage_ids : List[int] - - Returns - ------- - new_stage_id : int - The added staged id - """ - stage_id = self._resolve_stage_id(stage_id) - - if isinstance(reader_stage_ids, list): - tmp_list = [] - for reader_stage_id in reader_stage_ids: - tmp_list.append(self._resolve_stage_id(reader_stage_id)) - reader_stage_ids = tmp_list - else: - raise ValueError("reader_stage_ids must be list of Tensor or int") - - self.state_object, new_stage_id = _ffi_api.StateCacheRead(self.state_object, stage_id, - scope_name, reader_stage_ids, - self.compute_dag) - return self._insert_new_stage(new_stage_id) - - def cache_write(self, stage_id, scope_name): - """ - Parameters - ---------- - stage_id : Union[int, Operation, Tensor] - The index of the stage to do cache read - scope_name : str - - Returns - ------- - new_stage_id : int - The added staged id - """ - stage_id = self._resolve_stage_id(stage_id) - - self.state_object, new_stage_id = _ffi_api.StateCacheWrite(self.state_object, stage_id, - scope_name, self.compute_dag) - return self._insert_new_stage(new_stage_id) - - def pragma(self, stage_id, iterator, pragma_type): - """ - Parameters - ---------- - stage_id : Union[int, Operation, Tensor] - The index of the stage to add pragma - iterator : Iterator - The iterator to add pragma - pragma_type : str - """ - stage_id = self._resolve_stage_id(stage_id) - - self.state_object = _ffi_api.StatePragma(self.state_object, stage_id, iterator, - pragma_type) - self._clear_cache() - - def rfactor(self, stage_id, iterator, factor_iter_id): - """ - Parameters - ---------- - stage_id : Union[int, Operation, Tensor] - The index of the stage to do reduction factor - iterator : Iterator - factor_iter_id 
: int - - Returns - ------- - new_stage_id : int - The added staged id - """ - stage_id = self._resolve_stage_id(stage_id) - - self.state_object, new_stage_id = _ffi_api.StateRfactor(self.state_object, stage_id, - iterator, factor_iter_id, - self.compute_dag) - return self._insert_new_stage(new_stage_id) - - def storage_align(self, stage_id, iterator, factor, offset): - """ - Parameters - ---------- - stage_id : Union[int, Operation, Tensor] - The index of the stage to do storage align - iterator : Iterator - factor : int - offset : int - """ - stage_id = self._resolve_stage_id(stage_id) - - self.state_object = _ffi_api.StateStorageAlign(self.state_object, stage_id, iterator, - factor, offset) - self._clear_cache() - - def tensorize(self, stage_id, iterator, ti_func_name): - """ The `ti_func_name` corresponds to a global registered funcion - that returns a Tensorintrin - - Parameters - ---------- - stage_id : Union[int, Operation, Tensor] - The index of the stage to do storage align - iterator : Iterator - The iterator to be tensorized - ti_func_name : str - Tensorize intrinsic function name - - Returns - ------- - res_it : Iterator - The tensorized Iterator - """ - stage_id = self._resolve_stage_id(stage_id) - - self.state_object, res = _ffi_api.StateTensorize(self.state_object, - stage_id, iterator, - ti_func_name) - self._clear_cache() - return res - def _resolve_stage_id(self, stage_id): if isinstance(stage_id, Operation): return self.stage_id_map[stage_id] diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index be7d69e5ed3a..46c3e3aabd5d 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -42,7 +42,6 @@ from . import _ffi_api from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout, request_remote, \ check_remote -from .compute_dag import LayoutRewriteLevel LOGGER = logging.getLogger('ansor') @@ -331,7 +330,7 @@ def timed_func(): try: sch, args = task.compute_dag.apply_steps_from_state( - inp.state, LayoutRewriteLevel.BOTH_REWRITE) + inp.state) except Exception: error_no = MeasureErrorNo.INSTANTIATION_ERROR error_msg = make_error_msg() diff --git a/python/tvm/ansor/relay_integration.py b/python/tvm/ansor/relay_integration.py deleted file mode 100644 index f2873f8c72fd..000000000000 --- a/python/tvm/ansor/relay_integration.py +++ /dev/null @@ -1,241 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-variable,invalid-name - -""" -Integrate ansor into relay. It implements the following items: -1. Extract search tasks from a relay program -2. 
Provide auto-scheduling for all TOPI compute functions -""" -import os -import json -import threading - -import tvm -from tvm import te, transform -from tvm.te.tensor import PlaceholderOp, ComputeOp -from .dispatcher import DispatchContext -from .workload_registry import register_workload_bufs, compute_dag_hash -from .compute_dag import ComputeDAG, LayoutRewriteLevel -from .env import GLOBAL_SCOPE - -def call_all_topi_funcs(mod, target, params, target_host=None): - """Call all TOPI compute + schedule to extract tasks in a relay program""" - # pylint: disable=import-outside-toplevel - from tvm import relay - - with transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): - bld_mod = relay.build_module.BuildModule() - bld_mod.call_all_topi_funcs(mod, target=target, params=params, target_host=target_host) - -def extract_from_program(mod, params, target, target_host=None): - """ Extract tuning tasks from a relay program. - - This function is the single program version of extract_from_multiple_program. - - Parameters - ---------- - mod : relay.Module - The module to extract. - params: dict of str to numpy array - The associated parameters of the program - ops: List of relay op - List of relay ops to be tuned - target: tvm.target.Target - The compilation target - target_host: tvm.target.Target - The host compilation target - - Returns - ------- - workloads: Array of Tuple(wkl_key, target) - """ - return extract_from_multiple_program([mod], [params], target, target_host) - -def extract_from_multiple_program(mods, params, target, target_host=None): - """ Extract tuning tasks from multiple relay programs. - - Parameters - ---------- - mods : List of relay.Module - The modules to extract. - params: List of dict of str to numpy array - The associated parameters of the programs - ops: List of relay op - List of relay ops to be tuned - target: tvm.target.Target - The compilation target - target_host: tvm.target.Target - The host compilation target - - Returns - ------- - workloads: Array of Tuple(wkl_key, target) - """ - # pylint: disable=import-outside-toplevel - from tvm import relay - - env = TracingEnvironment(TracingMode.EXTRACT_TASK) - with env: - # run compiler to collect all TOPI calls during compilation - for mod, param in zip(mods, params): - # wrap build call in a new thread to avoid the conflict - # between python's multiprocessing and tvm's thread pool - build_thread = threading.Thread(target=call_all_topi_funcs, - args=(mod, target, param, target_host)) - build_thread.start() - build_thread.join() - relay.backend.compile_engine.get().clear() - - # create tasks for target - wkl_keys = [] - wkl_weights = [] - for wkl_key, wkl_weight in env.wkl_key_collection.items(): - wkl_keys.append(wkl_key) - wkl_weights.append(wkl_weight) - - return wkl_keys, wkl_weights - - -def prepare_layout_rewrite(mod, params, target): - """ - Prepare for kernel layout rewrite. This function will write layout infos to a global static - variable. - Then these layout info will be used by a relay pass `kernel_layout_transform`. 
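All of the removed helpers in this file share one pattern worth noting: the relay build is executed on a fresh thread so that TVM's internal thread pool does not collide with the Python multiprocessing used for measurement. A minimal sketch of that pattern (the function name and arguments here are illustrative, not part of the patch):

import threading

def run_build_isolated(build_func, *args):
    # Run the relay build in its own thread; joining keeps the call synchronous
    # while isolating TVM's thread pool from multiprocessing-based measurement.
    build_thread = threading.Thread(target=build_func, args=args)
    build_thread.start()
    build_thread.join()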
- """ - # pylint: disable=import-outside-toplevel - from tvm import relay - - env = TracingEnvironment(TracingMode.PREPARE_LAYOUT_REWRITE) - with env: - # wrap build call in a new thread to avoid the conflict - # between python's multiprocessing and tvm's thread pool - build_thread = threading.Thread(target=call_all_topi_funcs, - args=(mod, target, params)) - build_thread.start() - build_thread.join() - relay.backend.compile_engine.get().clear() - - if env.layout_rewrite_success_ct > 0: - GLOBAL_SCOPE.topi_in_compute_rewrite_mode = True - -def finish_layout_rewrite(): - """Clear the global flag for layout rewrite""" - GLOBAL_SCOPE.topi_in_compute_rewrite_mode = False - - -class TracingMode: - """Two modes for tracing""" - EXTRACT_TASK = 0 # trace all topi calls to extract tasks - PREPARE_LAYOUT_REWRITE = 1 # trace all topi calls to prepare layout rewrite - -class TracingEnvironment: - """Global environment for tracing all topi function calls""" - current = None - - def __init__(self, tracing_mode): - self.tracing_mode = tracing_mode - self.relay_disable_build_cache = "false" - self.layout_rewrite_success_ct = 0 - self.wkl_key_collection = {} - - def __enter__(self): - self.relay_disable_build_cache = os.environ.get("TVM_RELAY_DISABLE_BUILD_CACHE", "false") - os.environ["TVM_RELAY_DISABLE_BUILD_CACHE"] = "true" - TracingEnvironment.current = self - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - os.environ["TVM_RELAY_DISABLE_BUILD_CACHE"] = self.relay_disable_build_cache - TracingEnvironment.current = None - - def add_workload_key(self, key): - """Add the workload key of an Ansor search task - - Parameters - ---------- - key: str - """ - if key in self.wkl_key_collection: - self.wkl_key_collection[key] += 1 - else: - self.wkl_key_collection[key] = 1 - - -def traverse_to_get_io_tensors(outs): - """Traverse from a list of output tensors to get a whole computational DAG""" - layout_free_ops = [] - inputs = [] - - visited = set() - - def traverse(t): - if t in visited: - return - if isinstance(t.op, PlaceholderOp): - inputs.append(t) - elif isinstance(t.op, ComputeOp): - if "layout_free_placeholders" in t.op.attrs: - layout_free_ops.append(t.op) - for x in t.op.input_tensors: - traverse(x) - visited.add(t) - - for t in outs: - traverse(t) - - has_layout_free = (len(layout_free_ops) > 0) - return inputs + [t for t in outs], has_layout_free - - -def auto_schedule_topi(outs): - """ Use ansor to auto-schedule a topi compute declaration """ - io_tensors, has_layout_free = traverse_to_get_io_tensors(outs) - key = register_workload_bufs(io_tensors) - - env = TracingEnvironment.current - if env is None: # in the final build mode - state = DispatchContext.current.query(tvm.target.Target.current(), key) - if state is None: - return te.create_schedule([x.op for x in outs]) - - dag = ComputeDAG(io_tensors) - # Only update compute body, layout_rewrite_level = LayoutRewriteLevel.COMPUTE_REWRITE, - # Since kernel layout has already been rewritten in relay pass - schedule, _ = dag.apply_steps_from_state( - state, layout_rewrite_level=LayoutRewriteLevel.COMPUTE_REWRITE) - return schedule - if env.tracing_mode == TracingMode.EXTRACT_TASK: # in the task extraction mode - env.add_workload_key(key) - return te.create_schedule([x.op for x in outs]) - if env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE: - # in prepare_layout_rewrite mode - if has_layout_free: - # Rewrite the DAG and update the transform history for - # the new dag in DispatchContext - dispatch_ctx = DispatchContext.current - 
tgt = tvm.target.Target.current() - state = dispatch_ctx.query(tgt, key) - assert state is not None - dag = ComputeDAG(outs) - new_dag = dag.rewrite_layout_from_state(state) - new_key = json.dumps((compute_dag_hash(new_dag),)) - dispatch_ctx.update(tgt, new_key, state) - if new_key != key: - env.layout_rewrite_success_ct += 1 - return te.create_schedule([x.op for x in outs]) - raise ValueError("Invalid tracing mode: " + env.tracing_mode) diff --git a/python/tvm/ansor/task_scheduler.py b/python/tvm/ansor/task_scheduler.py deleted file mode 100644 index 5b916ed39769..000000000000 --- a/python/tvm/ansor/task_scheduler.py +++ /dev/null @@ -1,299 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""TaskScheduler that allocates the time resources when tuning multiple tasks together""" -from typing import List, Union, Callable -import time - -import numpy as np - -from .auto_schedule import SearchTask, SearchPolicy, SketchSearchPolicy, TuneOption -from .cost_model import RandomModel, XGBModel -from .measure import ProgramMeasurer -from .utils import array_mean, to_str_round - - -class TaskScheduler: - """Allocate the time resources when tuning multiple tasks together""" - def __init__(self, - tasks: List[SearchTask], - objective_func: Callable = None): - self.tasks = tasks - self.objective_func = objective_func or sum - - def compute_score(self, costs: List[float]) -> float: - return self.objective_func(costs) - - -def get_search_policies(search_policy: Union[str, List[SearchPolicy]], tasks: List[SearchTask], - num_measure_per_iter, load_model_file=None, load_log_file=None): - """ ... 
- """ - if search_policy == 'default': - search_policy = 'sketch.xgb' - - if isinstance(search_policy, str): - policy_type, model_type = search_policy.split('.') - if model_type == 'xgb': - cost_model = XGBModel(num_warmup_sample=len(tasks) * num_measure_per_iter) - if load_model_file: - print("Load pretrained model...") - cost_model.load(load_model_file) - elif load_log_file: - cost_model.load_log_file(load_log_file) - elif model_type == 'random': - cost_model = RandomModel() - else: - raise ValueError("Invalid search policy: " + search_policy) - - if policy_type == 'sketch': - search_policies = [SketchSearchPolicy(cost_model) for _ in range(len(tasks))] - elif policy_type == 'limit-space': - search_policies = [SketchSearchPolicy(cost_model, - params={'cpu_multi_level_tiling_structure': 'SRS', - 'disable_change_compute_location': 1}) - for _ in range(len(tasks))] - elif policy_type == 'beam-search': - search_policies = [SketchSearchPolicy(cost_model, - params={'use_beam_search': 1}) - for _ in range(len(tasks))] - else: - raise ValueError("Invalid search policy: " + search_policy) - else: - # check type - assert isinstance(search_policy, (tuple, list)) - for item in search_policy: - assert isinstance(item, SearchPolicy) - search_policies = search_policy - - return search_policies - - -class SimpleTaskScheduler(TaskScheduler): - """The default task scheduler with several strategies - - Parameters - ---------- - tasks: List[SearchTask] - All workloads to tune - weights: List[float] - Weights of tasks (i.e. the number of occurrence of a task in the whole network) - strategy: str - The joint tuning strategy. - "sequential" : Tune tasks sequentially. Divide n_trials equally to every task. - "round-robin": Tune tasks in round robin order. - "gradient" : Tune tasks with gradient descent. - load_log_file: str - Load history log file to pre-train cost model - eps-random: float - Always allocate this percent of n_trials to select tasks randomly. - This is for encouraging exploration. - verbose: int - The level of verbosity. 0 means silent. - alpha: float - The parameter used for 'gradient' strategy - beta: float - The parameter used for 'gradient' strategy - backward_window_size: int - The parameter used for 'gradient' strategy - """ - def __init__(self, - tasks: List[SearchTask], - objective_func: Callable = None, - strategy: str = 'gradient', - load_log_file: str = None, - load_model_file: str = None, - eps_random: float = 0.05, - verbose: int = 1, - alpha: float = 0.2, - beta: float = 2, - gamma: float = 0.5, - backward_window_size: int = 3, - use_debug_measurement_simulator=None): - super().__init__(tasks, objective_func) - self.strategy = strategy - self.eps_random = eps_random - self.verbose = verbose - self.load_log_file = load_log_file - self.load_model_file = load_model_file - self.alpha = alpha - self.beta = beta - self.gamma = gamma - self.backward_window_size = backward_window_size - self.use_debug_measurement_simulator = use_debug_measurement_simulator - - assert self.strategy in ['round-robin', 'gradient'] - - self.task_cts = [] - self.task_costs_history = [] - self.best_costs = self.cur_score = None - self.tune_option = self.measurer = self.search_policies = self.ct = self.tic = None - self.num_measure_per_iter = None - self.dead_tasks = set() - self.sequential_now_task_idx = 0 - self.sequential_now_task_begin_ct = 0 - - def tune(self, tune_option: TuneOption, - search_policy: Union[str, List[SearchPolicy]] = 'default'): - """ Tune tasks. 
- - Notice: This method does not have return value, make sure to set `LogToFile` - measure callback in `tune_option`. - - Parameters - ---------- - tune_option: TuneOption - search_policy: Str or List[SearchPolicy] - """ - # init members - self.task_cts = [0 for _ in range(len(self.tasks))] - self.task_costs_history = [[] for _ in range(len(self.tasks))] - self.best_costs = 1e10 * np.ones(len(self.tasks)) - self.cur_score = self.compute_score(self.best_costs) - self.tune_option = tune_option - if self.use_debug_measurement_simulator is None: - self.measurer = ProgramMeasurer(tune_option.builder, tune_option.runner, - tune_option.measure_callbacks, tune_option.verbose) - self.ct = 0 - self.tic = time.time() - # reset num_measure_per_iter to make sure every task is tuned at least once - self.num_measure_per_iter = min(tune_option.num_measure_per_iter, - tune_option.n_trials // len(self.tasks)) - self.search_policies = get_search_policies(search_policy, self.tasks, - self.num_measure_per_iter, - self.load_model_file, - self.load_log_file) - self.dead_tasks = set() - self.sequential_now_task_idx = 0 - self.sequential_now_task_begin_ct = 0 - - for i in range(len(self.tasks)): - search_policy = self.search_policies[i] - task = self.tasks[i] - search_policy.set_task(task) - search_policy.set_verbose(tune_option.verbose) - search_policy.run_callbacks(tune_option.pre_search_callbacks) - - # do a round robin first - if self.strategy != 'sequential': - for i in range(len(self.tasks)): - self.tune_task(i) - - # use the specific strategy to choose workload to tune - task_idx = -1 - while self.ct < tune_option.n_trials and len(self.dead_tasks) < len(self.tasks): - if self.strategy == 'sequential': - allocated_total_ct = ((tune_option.n_trials - self.sequential_now_task_begin_ct) - / (len(self.tasks) - self.sequential_now_task_idx)) - used_ct = self.ct - self.sequential_now_task_begin_ct - - if self.sequential_now_task_idx in self.dead_tasks or used_ct >= allocated_total_ct: - self.sequential_now_task_idx += 1 - self.sequential_now_task_begin_ct = self.ct - task_idx = self.sequential_now_task_idx - if task_idx >= len(self.tasks): - break - elif self.strategy == 'round-robin': - task_idx = (task_idx + 1) % len(self.tasks) - while task_idx in self.dead_tasks: - task_idx = (task_idx + 1) % len(self.tasks) - elif self.strategy == 'gradient': - gradients = [] - for i in range(len(self.tasks)): - if i in self.dead_tasks: - gradients.append(0) - continue - - # compute gradient from chain rule : (delta f / delta g_i) - delta = 1e-7 - new_costs = list(self.best_costs) - new_costs[i] -= delta - chain_grad = (self.compute_score(self.best_costs) - self.compute_score(new_costs)) / delta - - # compute (g_i(t_i) - g(t_i - \Delta t)) / (\Delta t) - if self.task_cts[i] - 1 - self.backward_window_size >= 0: - backward_grad = (self.task_costs_history[i][self.task_cts[i] - 1] - - self.task_costs_history[i][self.task_cts[i] - 1 - self.backward_window_size]) \ - / self.backward_window_size - else: - backward_grad = 0 - - # compute (g_i(t_i + \Delta t) - g(t_i)) / (\Delta t) - g_next_1 = self.best_costs[i] - (self.best_costs[i] / self.task_cts[i]) - # todo(lmzheng): this needs adding attribute to topi.compute for similarity check - g_next_2 = self.beta * 1e20 - g_next = min(g_next_1, g_next_2) - forward_grad = g_next - self.best_costs[i] - - # combine all grads - grad = chain_grad * (self.alpha * backward_grad + (1 - self.alpha) * forward_grad) - assert grad <= 0 - gradients.append(grad) - - if max(gradients) == 
min(gradients): - task_idx = np.random.choice(len(gradients)) - else: - task_idx = np.argmin(gradients) - else: - raise ValueError("Invalid strategy: " + self.strategy) - - if self.verbose >= 1: - print("Next tuning task: %d" % task_idx) - self.tune_task(task_idx) - - def tune_task(self, task_idx): - """ ... - """ - if self.use_debug_measurement_simulator is not None: - measure_inputs, measure_results = \ - self.use_debug_measurement_simulator.get_next_batch( - self.tasks[task_idx], - self.num_measure_per_iter, - ) - else: - measure_inputs, measure_results = \ - self.search_policies[task_idx].continue_search( - self.tasks[task_idx], - self.num_measure_per_iter, - self.tune_option.verbose, - self.measurer) - - for inp, res in zip(measure_inputs, measure_results): - cost = array_mean(res.costs) - if cost < self.best_costs[task_idx]: - self.best_costs[task_idx] = cost - - if len(measure_inputs) == 0: - self.dead_tasks.add(task_idx) - - self.task_cts[task_idx] += 1 - self.task_costs_history[task_idx].append(self.best_costs[task_idx]) - - self.ct += len(measure_inputs) - self.cur_score = self.compute_score(self.best_costs) - - if self.verbose >= 1: - print(("TaskScheduler\tct: %d\testimated cost (ms): %.3f\ttime elapsed: %.2f\t" + - "best_costs (ms): %s\ttask_ct: %s") % - (self.ct, self.cur_score * 1e3, time.time() - self.tic, - to_str_round(self.best_costs * 1e3, decimal=3), - self.task_cts)) - - def remove_dead_task(self, prob): - for idx in self.dead_tasks: - prob[idx] = 0 - return prob / prob.sum() diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index b6bedb411540..8e6698e4a164 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -18,7 +18,6 @@ """Backend code generation engine.""" from __future__ import absolute_import -import os import logging import numpy as np import tvm @@ -142,6 +141,7 @@ def get_valid_implementations(op, attrs, inputs, out_type, target): ret.append(impl) return ret + def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True): """Select the best implementation from the op strategy. @@ -179,9 +179,6 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True) ret : tuple(relay.op.OpImplementation, List[tvm.te.Tensor]) The best op implementation and the corresponding output tensors. 
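To make the 'gradient' strategy removed above concrete, here is a toy numeric sketch of its per-task gradient rule; every value is invented for illustration, and the backward history window is omitted:

import numpy as np

alpha, delta = 0.2, 1e-7
best_costs = np.array([1e-3, 2e-3])  # best measured cost per task (seconds)
task_cts = np.array([10, 10])        # tuning rounds already spent per task
objective = np.sum                   # the default objective function

gradients = []
for i in range(len(best_costs)):
    # (delta f / delta g_i), estimated by a finite difference
    new_costs = best_costs.copy()
    new_costs[i] -= delta
    chain_grad = (objective(best_costs) - objective(new_costs)) / delta

    # forward term: expected gain from giving this task one more round
    g_next = best_costs[i] - best_costs[i] / task_cts[i]
    forward_grad = g_next - best_costs[i]
    backward_grad = 0                # no history window in this toy example

    gradients.append(chain_grad * (alpha * backward_grad +
                                   (1 - alpha) * forward_grad))

next_task = int(np.argmin(gradients))  # steepest descent picks the next task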
""" - if os.environ.get('TVM_USE_AUTOTVM', 'false') == 'false': - use_autotvm = False - all_impls = get_valid_implementations(op, attrs, inputs, out_type, target) best_plevel_impl = None diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index d1a39ceb630e..30c5971e32b9 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -72,7 +72,6 @@ def __init__(self): self._get_module = self.mod["get_module"] self._build = self.mod["build"] self._optimize = self.mod["optimize"] - self._call_all_topi_funcs = self.mod["call_all_topi_funcs"] self._set_params_func = self.mod["set_params"] self._get_params_func = self.mod["get_params"] @@ -161,12 +160,6 @@ def optimize(self, mod, target=None, params=None): return mod, params - def call_all_topi_funcs(self, mod, target=None, target_host=None, params=None): - """Call all topi compute and schedule used in a relay function""" - target = _update_target(target) - if params: - self._set_params(params) - self._call_all_topi_funcs(mod, target, target_host) def _set_params(self, params): self._set_params_func(_convert_param_map(params)) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 41bd10cabe3e..d104c1b1c2f8 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -74,8 +74,6 @@ def compute_strided_set(attrs, inputs, output_type): # layout_transform _reg.register_injective_schedule("layout_transform") _reg.register_pattern("layout_transform", OpPattern.INJECTIVE) -_reg.register_injective_schedule("kernel_layout_transform") -_reg.register_pattern("kernel_layout_transform", OpPattern.INJECTIVE) # argwhere @_reg.register_compute("argwhere") diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 58b9269a4c48..486d63c36ff0 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -261,9 +261,6 @@ class ClipAttrs(Attrs): class LayoutTransformAttrs(Attrs): """Attributes for transform.layout_transform""" -@tvm._ffi.register_object("relay.attrs.KernelLayoutTransformAttrs") -class KernelLayoutTransformAttrs(Attrs): - """Attributes for transform.kernel_layout_transform""" @tvm._ffi.register_object("relay.attrs.ShapeOfAttrs") class ShapeOfAttrs(Attrs): diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 3453b089f373..b02db416bdc8 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -16,15 +16,14 @@ # under the License. """Definition of x86 operator strategy.""" # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import +import logging +import re +import topi from tvm.te import SpecializedCondition -from tvm import ansor from .generic import * from .. 
import op as _op -# Set the priority level to use the Ansor auto-scheduler -ansor_plevel = 11 - logger = logging.getLogger('strategy') _NCHWc_matcher = re.compile("^NCHW[0-9]+c$") @@ -40,7 +39,7 @@ def schedule_injective_cpu(attrs, outs, target): def schedule_reduce_cpu(attrs, outs, target): """schedule reduction ops for x86""" with target: - return ansor.auto_schedule_topi(outs) + return topi.x86.schedule_reduce(outs) @schedule_concatenate.register("cpu") def schedule_concatenate_cpu(attrs, outs, target): @@ -52,13 +51,13 @@ def schedule_concatenate_cpu(attrs, outs, target): def schedule_pool_cpu(attrs, outs, target): """schedule pooling ops for x86""" with target: - return ansor.auto_schedule_topi(outs) + return topi.x86.schedule_pool(outs, attrs.layout) @schedule_adaptive_pool.register("cpu") def schedule_adaptive_pool_cpu(attrs, outs, target): """schedule adaptive pooling ops for x86""" with target: - return ansor.auto_schedule_topi(outs) + return topi.x86.schedule_adaptive_pool(outs) @softmax_strategy.register("cpu") def softmax_strategy_cpu(attrs, inputs, out_type, target): @@ -66,15 +65,15 @@ def softmax_strategy_cpu(attrs, inputs, out_type, target): strategy = _op.OpStrategy() strategy.add_implementation( wrap_compute_softmax(topi.nn.softmax), - wrap_topi_schedule(ansor.auto_schedule_topi), - name="ansor") + wrap_topi_schedule(topi.x86.schedule_softmax), + name="softmax.x86") return strategy @schedule_log_softmax.register("cpu") def schedule_log_softmax_cpu(attrs, outs, target): """schedule log_softmax op for x86""" with target: - return ansor.auto_schedule_topi(outs) + return topi.x86.schedule_softmax(outs) @conv2d_strategy.register("cpu") def conv2d_strategy_cpu(attrs, inputs, out_type, target): @@ -106,18 +105,18 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): return conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target) elif layout == "NHWC": assert kernel_layout == "HWIO" - #logger.warning("For x86 target, NCHW layout is recommended for conv2d.") + logger.warning("For x86 target, NCHW layout is recommended for conv2d.") strategy.add_implementation( wrap_compute_conv2d(topi.nn.conv2d_nhwc), - wrap_topi_schedule(ansor.auto_schedule_topi), - name="ansor") + wrap_topi_schedule(topi.x86.schedule_conv2d_nhwc), + name="conv2d_nhwc.x86") elif layout == "HWCN": assert kernel_layout == "HWIO" - #logger.warning("conv2d HWCN layout is not optimized for x86.") + logger.warning("conv2d HWCN layout is not optimized for x86.") strategy.add_implementation( wrap_compute_conv2d(topi.nn.conv2d_hwcn), - wrap_topi_schedule(ansor.auto_schedule_topi), - name="ansor") + wrap_topi_schedule(topi.generic.schedule_conv2d_hwcn), + name="conv2d_hwcn.generic") else: raise RuntimeError("Unsupported conv2d layout {} for x86".format(layout)) elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups): @@ -144,8 +143,8 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): logger.warning("depthwise_conv2d NHWC layout is not optimized for x86.") strategy.add_implementation( wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc), - wrap_topi_schedule(ansor.auto_schedule_topi), - name="ansor") + wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc), + name="depthwise_conv2d_nhwc.generic") else: raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout)) else: # group_conv2d @@ -154,8 +153,8 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): logger.warning("group_conv2d is not optimized for x86.") strategy.add_implementation( 
wrap_compute_conv2d(topi.nn.group_conv2d_nchw, has_groups=True), - wrap_topi_schedule(ansor.auto_schedule_topi), - name="ansor") + wrap_topi_schedule(topi.generic.schedule_group_conv2d_nchw), + name="group_conv2d_nchw.generic") else: raise RuntimeError("Unsupported group_conv2d layout {}".format(layout)) return strategy @@ -232,8 +231,8 @@ def conv3d_strategy_cpu(attrs, inputs, out_type, target): name="conv3d_ncdhw.x86") elif layout == "NDHWC": strategy.add_implementation(wrap_compute_conv3d(topi.x86.conv3d_ndhwc), - wrap_topi_schedule(ansor.auto_schedule_topi), - name="ansor") + wrap_topi_schedule(topi.x86.schedule_conv3d_ndhwc), + name="conv3d_ndhwc.x86") else: raise ValueError("Not support this layout {} yet".format(layout)) return strategy @@ -252,8 +251,8 @@ def conv1d_strategy_cpu(attrs, inputs, out_type, target): name="conv1d_ncw.x86") elif layout == "NWC": strategy.add_implementation(wrap_compute_conv1d(topi.nn.conv1d_nwc), - wrap_topi_schedule(ansor.auto_schedule_topi), - name="ansor") + wrap_topi_schedule(topi.x86.schedule_conv1d_nwc), + name="conv1d_nwc.x86") else: raise ValueError("Unsupported conv1d layout {}".format(layout)) return strategy @@ -262,23 +261,16 @@ def conv1d_strategy_cpu(attrs, inputs, out_type, target): def dense_strategy_cpu(attrs, inputs, out_type, target): """dense x86 strategy""" strategy = _op.OpStrategy() - - strategy.add_implementation(wrap_compute_dense(topi.nn.dense), - wrap_topi_schedule(ansor.auto_schedule_topi), - name='ansor', - plevel=ansor_plevel) - + m, _ = inputs[0].shape strategy.add_implementation(wrap_compute_dense(topi.x86.dense_nopack), wrap_topi_schedule(topi.x86.schedule_dense_nopack), name="dense_nopack.x86", plevel=10) - if "cblas" in target.libs: strategy.add_implementation(wrap_compute_dense(topi.x86.dense_cblas), wrap_topi_schedule(topi.x86.schedule_dense_cblas), name="dense_cblas.x86", plevel=15) - m, _ = inputs[0].shape with SpecializedCondition(m >= 16): # this implementation may not be well-optimized, so use plevel=8 for now. strategy.add_implementation(wrap_compute_dense(topi.x86.dense_pack), @@ -291,12 +283,6 @@ def dense_strategy_cpu(attrs, inputs, out_type, target): def batch_matmul_strategy_cpu(attrs, inputs, out_type, target): """batch_matmul x86 strategy""" strategy = _op.OpStrategy() - - strategy.add_implementation(wrap_compute_dense(topi.nn.batch_matmul), - wrap_topi_schedule(ansor.auto_schedule_topi), - name='ansor', - plevel=ansor_plevel) - strategy.add_implementation(wrap_compute_batch_matmul(topi.x86.batch_matmul), wrap_topi_schedule(topi.x86.schedule_batch_matmul), name="batch_matmul.x86", diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index f2fa2b5f5b90..a37226ea4f58 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -815,27 +815,6 @@ def layout_transform(data, src_layout, dst_layout): """ return _make.layout_transform(data, src_layout, dst_layout) -def kernel_layout_transform(data, src_layout, dst_layout): - """Transform the layout of a kernel - - Parameters - ---------- - data : relay.Expr - The source tensor to be transformed - - src_layout: str - The source layout. (e.g 1N32C112H112W) - - dst_layout: str - The destination layout. (e.g. 1N2C112H112W16c) - - Returns - ------- - ret : relay.Expr - The transformed tensor. 
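For reference, before this patch removes the operator, a call following the docstring above would have looked like the sketch below; the variable name and shape are illustrative only:

from tvm import relay

# A 4-D weight in the source layout 1N32C112H112W (N=1, C=32, H=112, W=112).
weight = relay.var("weight", shape=(1, 32, 112, 112), dtype="float32")
packed = relay.kernel_layout_transform(weight,
                                       src_layout="1N32C112H112W",
                                       dst_layout="1N2C112H112W16c")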
- """ - return _make.kernel_layout_transform(data, src_layout, dst_layout) - def reverse_reshape(data, newshape): """Reshapes the input array where the special values are inferred from diff --git a/python/tvm/relay/testing/dqn.py b/python/tvm/relay/testing/dqn.py index 3d6883362c9b..10da37001f12 100644 --- a/python/tvm/relay/testing/dqn.py +++ b/python/tvm/relay/testing/dqn.py @@ -26,32 +26,27 @@ from . import layers from .init import create_workload -def get_net(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32", layout="NCHW"): +def get_net(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32"): """get symbol of nature dqn""" data_shape = (batch_size,) + image_shape data = relay.var("data", shape=data_shape, dtype=dtype) - bias_axis = layout.index('C') - conv1_bias = relay.var("conv1_bias") conv1 = layers.conv2d(data, kernel_size=(8, 8), strides=(4, 4), padding=(0, 0), - channels=32, name="conv1", data_layout=layout, - kernel_layout=layers.conv_kernel_layout(layout)) - conv1 = relay.nn.bias_add(conv1, conv1_bias, bias_axis) + channels=32, name="conv1") + conv1 = relay.nn.bias_add(conv1, conv1_bias) relu1 = relay.nn.relu(conv1) conv2_bias = relay.var("conv2_bias") conv2 = layers.conv2d(relu1, kernel_size=(4, 4), strides=(2, 2), padding=(0, 0), - channels=64, name="conv2", data_layout=layout, - kernel_layout=layers.conv_kernel_layout(layout)) - conv2 = relay.nn.bias_add(conv2, conv2_bias, bias_axis) + channels=64, name="conv2") + conv2 = relay.nn.bias_add(conv2, conv2_bias) relu2 = relay.nn.relu(conv2) conv3_bias = relay.var("conv3_bias") conv3 = layers.conv2d(relu2, kernel_size=(3, 3), strides=(1, 1), padding=(0, 0), - channels=64, name="conv3", data_layout=layout, - kernel_layout=layers.conv_kernel_layout(layout)) - conv3 = relay.nn.bias_add(conv3, conv3_bias, bias_axis) + channels=64, name="conv3") + conv3 = relay.nn.bias_add(conv3, conv3_bias) relu3 = relay.nn.relu(conv3) bf1 = relay.nn.batch_flatten(relu3) @@ -63,8 +58,7 @@ def get_net(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32" return relay.Function(args, dense2) -def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32", - layout="NCHW"): +def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32"): """Get benchmark workload for a Deep Q Network Parameters ---------- @@ -78,11 +72,10 @@ def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="flo The data type Returns ------- - mod : tvm.relay.Module + mod : tvm.IRModule The relay module that contains a DQN network. params : dict of str to NDArray The parameters. 
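A short usage sketch for the simplified signature above, with the defaults written out explicitly:

from tvm.relay.testing import dqn

# Build the Nature-DQN benchmark; after this patch the layout is always NCHW.
mod, params = dqn.get_workload(batch_size=1, num_actions=18,
                               image_shape=(4, 84, 84), dtype="float32")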
""" - net = get_net(batch_size, num_actions=num_actions, image_shape=image_shape, dtype=dtype, - layout=layout) + net = get_net(batch_size, num_actions=num_actions, image_shape=image_shape, dtype=dtype) return create_workload(net) diff --git a/python/tvm/relay/testing/resnet.py b/python/tvm/relay/testing/resnet.py index ac63afde4cba..b431dd096f9d 100644 --- a/python/tvm/relay/testing/resnet.py +++ b/python/tvm/relay/testing/resnet.py @@ -59,11 +59,9 @@ def residual_unit(data, name : str Base name of the operators """ - bn_axis = data_layout.index('C') if bottle_neck: bn1 = layers.batch_norm_infer(data=data, epsilon=2e-5, - axis=bn_axis, name=name + '_bn1') act1 = relay.nn.relu(data=bn1) conv1 = layers.conv2d( @@ -75,13 +73,13 @@ def residual_unit(data, name=name + '_conv1', data_layout=data_layout, kernel_layout=kernel_layout) - bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, axis=bn_axis, name=name + '_bn2') + bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, name=name + '_bn2') act2 = relay.nn.relu(data=bn2) conv2 = layers.conv2d( data=act2, channels=int(num_filter*0.25), kernel_size=(3, 3), strides=(1, 1), padding=(1, 1), name=name + '_conv2', data_layout=data_layout, kernel_layout=kernel_layout) - bn3 = layers.batch_norm_infer(data=conv2, epsilon=2e-5, axis=bn_axis, name=name + '_bn3') + bn3 = layers.batch_norm_infer(data=conv2, epsilon=2e-5, name=name + '_bn3') act3 = relay.nn.relu(data=bn3) conv3 = layers.conv2d( data=act3, channels=num_filter, kernel_size=(1, 1), @@ -96,13 +94,13 @@ def residual_unit(data, data_layout=data_layout, kernel_layout=kernel_layout) return relay.add(conv3, shortcut) - bn1 = layers.batch_norm_infer(data=data, epsilon=2e-5, axis=bn_axis, name=name + '_bn1') + bn1 = layers.batch_norm_infer(data=data, epsilon=2e-5, name=name + '_bn1') act1 = relay.nn.relu(data=bn1) conv1 = layers.conv2d( data=act1, channels=num_filter, kernel_size=(3, 3), strides=stride, padding=(1, 1), name=name + '_conv1', data_layout=data_layout, kernel_layout=kernel_layout) - bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, axis=bn_axis, name=name + '_bn2') + bn2 = layers.batch_norm_infer(data=conv1, epsilon=2e-5, name=name + '_bn2') act2 = relay.nn.relu(data=bn2) conv2 = layers.conv2d( data=act2, channels=num_filter, kernel_size=(3, 3), @@ -158,16 +156,12 @@ def resnet(units, data_layout = layout kernel_layout = "OIHW" if layout == "NCHW" else "HWIO" - bn_axis = data_layout.index('C') num_unit = len(units) assert num_unit == num_stages data = relay.var("data", shape=data_shape, dtype=dtype) - data = layers.batch_norm_infer(data=data, epsilon=2e-5, axis=bn_axis, scale=False, - name='bn_data') + data = layers.batch_norm_infer(data=data, epsilon=2e-5, scale=False, name='bn_data') (_, _, height, _) = data_shape - if layout == "NHWC": - (_, height, _, _) = data_shape if height <= 32: # such as cifar10 body = layers.conv2d( data=data, channels=filter_list[0], kernel_size=(3, 3), @@ -178,7 +172,7 @@ def resnet(units, data=data, channels=filter_list[0], kernel_size=(7, 7), strides=(2, 2), padding=(3, 3), name="conv0", data_layout=data_layout, kernel_layout=kernel_layout) - body = layers.batch_norm_infer(data=body, epsilon=2e-5, axis=bn_axis, name='bn0') + body = layers.batch_norm_infer(data=body, epsilon=2e-5, name='bn0') body = relay.nn.relu(data=body) body = relay.nn.max_pool2d(data=body, pool_size=(3, 3), strides=(2, 2), padding=(1, 1), layout=data_layout) @@ -193,7 +187,7 @@ def resnet(units, body, filter_list[i+1], (1, 1), True, name='stage%d_unit%d' % (i + 1, j + 2), 
bottle_neck=bottle_neck, data_layout=data_layout, kernel_layout=kernel_layout) - bn1 = layers.batch_norm_infer(data=body, epsilon=2e-5, axis=bn_axis, name='bn1') + bn1 = layers.batch_norm_infer(data=body, epsilon=2e-5, name='bn1') relu1 = relay.nn.relu(data=bn1) # Although kernel is not used here when global_pool=True, we should put one pool1 = relay.nn.global_avg_pool2d(data=relu1, layout=data_layout) @@ -215,8 +209,6 @@ def get_net(batch_size, Original author Wei Wu """ (_, height, _) = image_shape - if layout == "NHWC": - (height, _, _) = image_shape data_shape = (batch_size,) + image_shape if height <= 28: num_stages = 3 diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index 967bfcdd3cde..060673dc19c6 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -279,39 +279,6 @@ def empty(shape, dtype="float32", ctx=context(1, 0)): return _make_array(handle, False, False) -def non_empty(shape, dtype="float32", ctx=context(1, 0)): - """Create an non-empty array given shape and device - - Parameters - ---------- - shape : tuple of int - The shape of the array - - dtype : type or str - The data type of the array. - - ctx : TVMContext - The context of the array - - Returns - ------- - arr : tvm.nd.NDArray - The array tvm supported. - """ - shape = c_array(tvm_shape_index_t, shape) - ndim = ctypes.c_int(len(shape)) - handle = TVMArrayHandle() - dtype = DataType(dtype) - check_call(_LIB.TVMArrayAllocNonEmpty( - shape, ndim, - ctypes.c_int(dtype.type_code), - ctypes.c_int(dtype.bits), - ctypes.c_int(dtype.lanes), - ctx.device_type, - ctx.device_id, - ctypes.byref(handle))) - return _make_array(handle, False, False) - def from_dlpack(dltensor): """Produce an array from a DLPack tensor without memory copy. Retreives the underlying DLPack tensor's pointer to create an array from the diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py index 6a2120817eb1..7d73bf42ab7d 100644 --- a/python/tvm/te/tensor.py +++ b/python/tvm/te/tensor.py @@ -56,11 +56,9 @@ class Tensor(DataProducer, _expr.ExprOp): """Tensor object, to construct, see function.Tensor""" def __call__(self, *indices): - # ndim = self.ndim - # After ansor kernel layout rewrite, len(indices) <= ndim, - # and the indices will get modified by Ansor during schedule generation. - # if len(indices) != ndim: - # raise ValueError("Need to provide %d index in tensor slice" % ndim) + ndim = self.ndim + if len(indices) != ndim: + raise ValueError("Need to provide %d index in tensor slice" % ndim) indices = convert_to_object(indices) args = [] for x in indices: diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index a8cd1d3c2462..34c3487e3ef2 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -153,11 +153,6 @@ class RelayBuildModule : public runtime::ModuleNode { CHECK_EQ(args.num_args, 2); *rv = this->Optimize(args[0], args[1], this->params_); }); - } else if (name == "call_all_topi_funcs") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue *rv) { - CHECK_EQ(args.num_args, 3); - this->CallAllTopiFuncs(args[0], args[1], args[2]); - }); } else { LOG(FATAL) << "Unknown packed function: " << name; return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); @@ -232,21 +227,6 @@ class RelayBuildModule : public runtime::ModuleNode { BuildRelay(mod, params_); } - /*! 
\brief Call all used TOPI compute and schedule in a relay function */ - void CallAllTopiFuncs(IRModule mod, - const TargetsMap& targets, - const tvm::Target& target_host) { - targets_ = targets; - target_host_ = target_host; - - IRModule relay_module = Optimize(mod, targets_, params_); - auto func = Downcast(relay_module->Lookup("main")); - - graph_codegen_ = std::unique_ptr(new GraphCodegen()); - graph_codegen_->Init(nullptr, targets_); - graph_codegen_->Codegen(func); - } - protected: /*! * \brief Optimize a Relay IRModule. @@ -335,18 +315,6 @@ class RelayBuildModule : public runtime::ModuleNode { // Fuse the operations if it is needed. relay_module = transform::FuseOps()(relay_module); - - if (targets.size() == 1) { - pass_seqs.push_back(transform::KernelLayoutTransform()); - pass_seqs.push_back(transform::DeFuseOps()); - pass_seqs.push_back(transform::FoldConstant()); - transform::Pass seq = transform::Sequential(pass_seqs); - const auto& it = targets.begin(); - With tctx((*it).second); - relay_module = seq(relay_module); - relay_module = transform::FuseOps()(relay_module); - } - relay_module = transform::InferType()(relay_module); // Inline the functions that have been lifted by the module scope. // diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index fde880b10f1d..2aae8546248f 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -68,11 +68,6 @@ CCacheKey::CCacheKey(Function source_func, Target target) { auto n = make_object(); n->source_func = std::move(source_func); n->target = std::move(target); - n->disabled = false; - char* envar = getenv("TVM_RELAY_DISABLE_BUILD_CACHE"); - if (envar != nullptr && strcmp(envar, "true") == 0) { - n->disabled = true; - } data_ = std::move(n); } diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index b290462a4b22..a5f3f6359f89 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -115,8 +115,6 @@ class CCacheKeyNode : public Object { /*! 
\brief The hardware target.*/ Target target; - bool disabled; - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("source_func", &source_func); v->Visit("target", &target); @@ -261,7 +259,6 @@ inline size_t CCacheKeyNode::Hash() const { } inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const { - if (disabled) return false; if (Hash() != other->Hash()) return false; return this->target->str() == other->target->str() && tvm::StructuralEqual()(this->source_func, other->source_func); From 910964edcf977ebbab411f41f1f0da0959a7e187 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 24 Jun 2020 14:39:13 +0800 Subject: [PATCH 42/45] Rever Commits, Start to build minimum Ansor system --- python/tvm/ansor/__init__.py | 4 +- python/tvm/ansor/compute_dag.py | 2 +- python/tvm/ansor/loop_state.py | 20 + src/ansor/compute_dag.cc | 5 +- src/ansor/search_task.cc | 59 -- .../python/unittest/test_ansor_loop_state.py | 540 +----------------- tests/python/unittest/test_ansor_measure.py | 18 - .../unittest/test_ansor_search_policy.py | 85 --- 8 files changed, 27 insertions(+), 706 deletions(-) diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index 9cce63b2840d..ccd8f27b71c1 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -26,8 +26,8 @@ # Shortcut from .compute_dag import ComputeDAG -from .auto_schedule import SearchTask, SketchSearchPolicy, TuneOption, HardwareParams, \ - PreloadMeasuredStates, PreloadCustomSketchRule, auto_schedule +from .auto_schedule import SearchTask, TuneOption, HardwareParams, \ + auto_schedule from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, LocalRPCMeasureContext from .cost_model import RandomModel from .serialization import LogToFile, LogReader, best_measure_pair_in_file, \ diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index 994c3ae3ab97..acfec66a166a 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -44,7 +44,7 @@ def get_init_state(self): """ return State(_ffi_api.ComputeDAGGetInitState(self), self) - def apply_steps_from_state(self, state, layout_rewrite_level=LayoutRewriteLevel.NO_REWRITE): + def apply_steps_from_state(self, state): """ Apply transform steps according to the history of a state diff --git a/python/tvm/ansor/loop_state.py b/python/tvm/ansor/loop_state.py index 470ae40f5278..bf81311ed664 100644 --- a/python/tvm/ansor/loop_state.py +++ b/python/tvm/ansor/loop_state.py @@ -152,6 +152,26 @@ def split(self, stage_id, iterator, lengths, inner_to_outer=True): self._clear_cache() return res + def fuse(self, stage_id, iters): + """ + Parameters + ---------- + stage_id : Union[int, Operation, Tensor] + The index of the stage to fuse + iters : List[Iterator] + The iterators to be fused + + Returns + ------- + res_it : Iterator + The fused Iterator + """ + stage_id = self._resolve_stage_id(stage_id) + + self.state_object, res = _ffi_api.StateFuse(self.state_object, stage_id, iters) + self._clear_cache() + return res + def _resolve_stage_id(self, stage_id): if isinstance(stage_id, Operation): return self.stage_id_map[stage_id] diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 9e6da6ff6f3b..d7af8b94729a 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -38,7 +38,6 @@ #include #include "transform_step.h" #include "search_policy/utils.h" -#include "../relay/transforms/kernel_layout_transform.h" namespace tvm { namespace ansor { @@ -737,7 +736,7 @@ void ComputeDAG::RewriteLayout( 
CHECK_EQ(placeholder_axis_names.size(), placeholder->shape.size()); std::string ori_layout = os.str(); os.str(""); - ::tvm::relay::KernelLayoutVisitor::global_ori_layouts_queue.push_back(ori_layout); + // ::tvm::relay::KernelLayoutVisitor::global_ori_layouts_queue.push_back(ori_layout); } } @@ -800,7 +799,7 @@ void ComputeDAG::RewriteLayout( } std::string new_layout = os.str(); os.str(""); - ::tvm::relay::KernelLayoutVisitor::global_new_layouts_queue.push_back(new_layout); + // ::tvm::relay::KernelLayoutVisitor::global_new_layouts_queue.push_back(new_layout); placeholder_new_names[placeholder_op] = new_names; placeholder_new_shapes[placeholder_op] = new_shape; diff --git a/src/ansor/search_task.cc b/src/ansor/search_task.cc index 17ab73efb6aa..6be4773fe780 100644 --- a/src/ansor/search_task.cc +++ b/src/ansor/search_task.cc @@ -52,65 +52,6 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams( if (target->target_name == "llvm") { return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 32, 64, 16, 64); - } else if (target->device_type == kDLGPU) { - // TODO(jcf94): temp implementation, max vectorize size in GPU is related - // to the data type - auto hardware_params = HardwareParams(100000, 16, 64, 4, 64); - auto* p_hardware_params = hardware_params.CopyOnWrite(); - - auto ctx = TVMContext{kDLGPU, 0}; - auto func = tvm::runtime::Registry::Get("device_api.gpu"); - CHECK(func != nullptr) << "Cannot find GPU device_api in registry"; - auto device_api = - static_cast(((*func)()).operator void*()); - - tvm::runtime::TVMRetValue ret; - device_api->GetAttr( - ctx, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret); - p_hardware_params->max_shared_memory_per_block = ret; - - device_api->GetAttr( - ctx, tvm::runtime::DeviceAttrKind::kMaxRegistersPerBlock, &ret); - p_hardware_params->max_registers_per_block = ret; - - device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, - &ret); - p_hardware_params->max_threads_per_block = ret; - - device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kWarpSize, &ret); - p_hardware_params->warp_size = ret; - - // Manually set now - p_hardware_params->max_vthread_extent = 4; - - return hardware_params; - } else if (target->device_type == kDLOpenCL) { - // TODO(jcf94): temp implementation - auto hardware_params = HardwareParams(100000, 16, 64, 4, 64); - auto p_hardware_params = hardware_params.CopyOnWrite(); - - auto ctx = TVMContext{kDLOpenCL, 0}; - auto func = tvm::runtime::Registry::Get("device_api.opencl"); - CHECK(func != nullptr) << "Cannot find GPU device_api in registry"; - auto device_api = - static_cast(((*func)()).operator void*()); - - tvm::runtime::TVMRetValue ret; - device_api->GetAttr( - ctx, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret); - p_hardware_params->max_shared_memory_per_block = ret; - - device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, - &ret); - p_hardware_params->max_threads_per_block = ret; - - device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kWarpSize, &ret); - p_hardware_params->warp_size = ret; - - // Manually set now - p_hardware_params->max_vthread_extent = 4; - - return hardware_params; } else { LOG(FATAL) << "No default hardware parameters for target: " << target; } diff --git a/tests/python/unittest/test_ansor_loop_state.py b/tests/python/unittest/test_ansor_loop_state.py index d90be1a78421..35894354349f 100644 --- a/tests/python/unittest/test_ansor_loop_state.py +++ b/tests/python/unittest/test_ansor_loop_state.py @@ -26,7 +26,7 @@ 
from test_ansor_common import matmul_ansor_test, conv2d_nchw_bn_relu -def test_split_fuse_reorder_annotation(): +def test_split_fuse_reorder(): A, B, C = matmul_ansor_test(512, 512, 512) dag = ansor.ComputeDAG([A, B, C]) s0 = dag.get_init_state() @@ -61,541 +61,5 @@ def test_split_fuse_reorder_annotation(): assert s1[C].iters[4].range.extent == 8 assert s1[C].iters[5].range.extent == 2 - s1.parallel(C, j1) - s1.unroll(C, j2) - s1.vectorize(C, j3) - s1.bind_thread(C, i1, "blockIdx.x") - s1.bind_thread(C, i2, "vthread") - s1.bind_thread(C, i3, "threadIdx.y") - - -def test_follow_split_follow_fused_split(): - A, B, C = matmul_ansor_test(512, 512, 512) - dag = ansor.ComputeDAG([A, B, C]) - s0 = dag.get_init_state() - - C_global = s0.cache_write(C, "global") - - its0 = s0.split(C, s0[C].iters[0], [4, 2, 8, 4], True) - split_step0 = s0.transform_steps_size() - 1 - for level in range(1, 6): - tmp = s0.copy() - tmp.follow_split(C_global, tmp[C_global].iters[0], split_step0, level) - for i in range(0, level): - assert tmp[C].iters[i].range.extent == \ - tmp[C_global].iters[i].range.extent - - its1 = s0.split(C, s0[C].iters[5], [2, 2, 4, 8]) - split_step1 = s0.transform_steps_size() - 1 - its = [] - for i0, i1 in zip(its0, its1): - its.append(i0) - its.append(i1) - s0.reorder(C, its) - for i in range(0, 5): - s0.fuse(C, [s0[C].iters[i], s0[C].iters[i + 1]]) - - for level in range(0, 4): - tmp = s0.copy() - tmp.follow_fused_split(C_global, tmp[C_global].iters[0], - [split_step0, split_step1], level, False) - assert tmp[C].iters[level + 1].range.extent == \ - tmp[C_global].iters[0].range.extent - - for level in range(0, 4): - tmp = s0.copy() - tmp.follow_fused_split(C_global, tmp[C_global].iters[0], - [split_step0, split_step1], level, True) - assert tmp[C].iters[level + 1].range.extent == \ - tmp[C_global].iters[1].range.extent - - -def test_compute_at_root_inline(): - dag = ansor.ComputeDAG(conv2d_nchw_bn_relu(1, 224, 224, 3, 64, 7, 2, 3)) - s0 = dag.get_init_state() - - # data, padding, kernel = 0, 1, 2 - conv = s0.stage_ops[3] - # bias = 4 - bias_add = s0.stage_ops[5] - # bn_scale = 6 - bn_mul = s0.stage_ops[7] - # bn_offset = 8 - bn_add = s0.stage_ops[9] - relu = s0.stage_ops[10] - - s0.compute_inline(bn_add) - s0.compute_inline(bn_mul) - s0.compute_inline(bias_add) - s0.compute_at(conv, relu, s0[relu].iters[2]) - assert str(s0) == \ - "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \ - "for i1 (0,3)\n" + \ - " for i2 (0,230)\n" + \ - " for i3 (0,230)\n" + \ - " pad_temp = ...\n" + \ - "for i1 (0,64)\n" + \ - " for i2 (0,112)\n" + \ - " for nn (None)\n" + \ - " for ff (None)\n" + \ - " for yy (None)\n" + \ - " for xx (None)\n" + \ - " for rc (None)\n" + \ - " for ry (None)\n" + \ - " for rx (None)\n" + \ - " compute = ...\n" + \ - " for i3 (0,112)\n" + \ - " compute = ...\n" - - s0.compute_root(conv) - s0.compute_root(bn_mul) - assert str(s0) == \ - "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \ - "for i1 (0,3)\n" + \ - " for i2 (0,230)\n" + \ - " for i3 (0,230)\n" + \ - " pad_temp = ...\n" + \ - "for nn (None)\n" + \ - " for ff (None)\n" + \ - " for yy (None)\n" + \ - " for xx (None)\n" + \ - " for rc (None)\n" + \ - " for ry (None)\n" + \ - " for rx (None)\n" + \ - " compute = ...\n" + \ - "for i (None)\n" + \ - " for j (None)\n" + \ - " for k (None)\n" + \ - " for l (None)\n" + \ - " Bn_mul = ...\n" + \ - "for i1 (0,64)\n" + \ - " for i2 (0,112)\n" + \ - " for i3 (0,112)\n" + \ - " compute = ...\n" - - -def test_cache_read_write(): - N, H, W, CO, CI, KH, KW, strides, 
padding = 4, 7, 7, 512, 512, 3, 3, ( - 1, 1), (1, 1) - - data = te.placeholder((N, CI, H, W), name='Data') - kernel_data = te.placeholder((CO, CI, KH, KW), name='Kernel_data') - k0, k1 = te.compute(kernel_data.shape, - lambda *i: (kernel_data(*i)+1, kernel_data(*i)/2), - name='Kernel_split') - kernel = te.compute(kernel_data.shape, - lambda *i: k0(*i) + k1(*i), - name='Kernel') - conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation=1) - relu = topi.nn.relu(conv) - add = topi.add(data, relu) - - dag = ansor.ComputeDAG([data, kernel_data, add]) - s0 = dag.get_init_state() - - pad_temp = s0.stage_ops[1] - kernel_split = s0.stage_ops[3] - - # 0: init state - ori_its = s0[add].iters - its = s0.split(add, s0[add].iters[0], [2]) - s0.reorder(add, [its[0], ori_its[1], its[1], ori_its[2], ori_its[3]]) - s0.compute_inline(relu) - - # 1: simple cache_write with compute_at - conv_global = s0.cache_write(conv, "global") - s0.compute_at(conv_global, conv, s0[conv].iters[3]) - - # 2: simple cache_read with compute_at - kernel_global = s0.cache_read(kernel, "global", [conv_global]) - s0.compute_at(kernel_global, conv_global, s0[conv_global].iters[4]) - assert str(s0) == \ - "Placeholder: Data, Kernel_data\n" + \ - "for i0 (0,4)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,9)\n" + \ - " for i3 (0,9)\n" + \ - " pad_temp = ...\n" + \ - "for i0 (0,512)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,3)\n" + \ - " for i3 (0,3)\n" + \ - " Kernel_split = ...\n" + \ - "for i0 (0,512)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,3)\n" + \ - " for i3 (0,3)\n" + \ - " Kernel = ...\n" + \ - "for nn (0,4)\n" + \ - " for ff (0,512)\n" + \ - " for yy (0,7)\n" + \ - " for xx (0,7)\n" + \ - " for nn_c (None)\n" + \ - " for ff_c (None)\n" + \ - " for yy_c (None)\n" + \ - " for xx_c (None)\n" + \ - " for rc (None)\n" + \ - " for ax0 (None)\n" + \ - " for ax1 (None)\n" + \ - " for ax2 (None)\n" + \ - " for ax3 (None)\n" + \ - " Kernel.global = ...\n" + \ - " for ry (None)\n" + \ - " for rx (None)\n" + \ - " compute.global = ...\n" + \ - " compute = ...\n" + \ - "for ax0.0 (0,2)\n" + \ - " for ax1 (0,512)\n" + \ - " for ax0.1 (0,2)\n" + \ - " for ax2 (0,7)\n" + \ - " for ax3 (0,7)\n" + \ - " T_add = ...\n" - - # 3: two level cache_read with compute_at - # preparing for GPU's shared memory & local memory - pad_temp_global = s0.cache_read(pad_temp, "global", [conv_global]) - pad_temp_shared = s0.cache_read(pad_temp_global, "shared", [conv_global]) - s0.compute_at(pad_temp_global, conv_global, s0[conv_global].iters[2]) - s0.compute_at(pad_temp_shared, conv_global, s0[conv_global].iters[4]) - - # 4: cache_read with multi readers - # This stage cannot be compute at to its consumer - s0.cache_read(data, "global", [pad_temp, add]) - assert str(s0) == \ - "Placeholder: Data, Kernel_data\n" + \ - "for ax0 (0,4)\n" + \ - " for ax1 (0,512)\n" + \ - " for ax2 (0,7)\n" + \ - " for ax3 (0,7)\n" + \ - " Data.global = ...\n" + \ - "for i0 (0,4)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,9)\n" + \ - " for i3 (0,9)\n" + \ - " pad_temp = ...\n" + \ - "for i0 (0,512)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,3)\n" + \ - " for i3 (0,3)\n" + \ - " Kernel_split = ...\n" + \ - "for i0 (0,512)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,3)\n" + \ - " for i3 (0,3)\n" + \ - " Kernel = ...\n" + \ - "for nn (0,4)\n" + \ - " for ff (0,512)\n" + \ - " for yy (0,7)\n" + \ - " for xx (0,7)\n" + \ - " for nn_c (None)\n" + \ - " for ff_c (None)\n" + \ - " for yy_c (None)\n" + \ - " for ax0 (None)\n" + \ - " for ax1 (None)\n" + \ - 
" for ax2 (None)\n" + \ - " for ax3 (None)\n" + \ - " pad_temp.global = ...\n" + \ - " for xx_c (None)\n" + \ - " for rc (None)\n" + \ - " for ax0 (None)\n" + \ - " for ax1 (None)\n" + \ - " for ax2 (None)\n" + \ - " for ax3 (None)\n" + \ - " Kernel.global = ...\n" + \ - " for ax0 (None)\n" + \ - " for ax1 (None)\n" + \ - " for ax2 (None)\n" + \ - " for ax3 (None)\n" + \ - " pad_temp.global.shared = ...\n" + \ - " for ry (None)\n" + \ - " for rx (None)\n" + \ - " compute.global = ...\n" + \ - " compute = ...\n" + \ - "for ax0.0 (0,2)\n" + \ - " for ax1 (0,512)\n" + \ - " for ax0.1 (0,2)\n" + \ - " for ax2 (0,7)\n" + \ - " for ax3 (0,7)\n" + \ - " T_add = ...\n" - - # 5: cache_write with multi outputs - # TVM's cache_write actually has a bug with this case: - # - # After schedule.cache_write, TVM generate one new stage: - # From: kernel_data -> kernel_split -> kernel - # To: kernel_data -> kernel_split_global -> kernel_split -> kernel - # - # But with topo sort analyse, we get: - # // kernel_data -> kernel_split_global -> kernel_split -> kernel - # \ / - # ----------------> kernel_split ----------------> - # - # Seems there's bug with the input/output tensor. Such multi outputs case - # should be unusual, so we make some hack on DoCacheWrite - # To be fixed in the future - s0.cache_write(kernel_split, "global") - assert str(s0) == \ - "Placeholder: Data, Kernel_data\n" + \ - "for ax0 (0,4)\n" + \ - " for ax1 (0,512)\n" + \ - " for ax2 (0,7)\n" + \ - " for ax3 (0,7)\n" + \ - " Data.global = ...\n" + \ - "for i0 (0,4)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,9)\n" + \ - " for i3 (0,9)\n" + \ - " pad_temp = ...\n" + \ - "for i0_c (0,512)\n" + \ - " for i1_c (0,512)\n" + \ - " for i2_c (0,3)\n" + \ - " for i3_c (0,3)\n" + \ - " Kernel_split.global = ...\n" + \ - "for i0 (0,512)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,3)\n" + \ - " for i3 (0,3)\n" + \ - " Kernel_split = ...\n" + \ - "for i0 (0,512)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,3)\n" + \ - " for i3 (0,3)\n" + \ - " Kernel_split = ...\n" + \ - "for i0 (0,512)\n" + \ - " for i1 (0,512)\n" + \ - " for i2 (0,3)\n" + \ - " for i3 (0,3)\n" + \ - " Kernel = ...\n" + \ - "for nn (0,4)\n" + \ - " for ff (0,512)\n" + \ - " for yy (0,7)\n" + \ - " for xx (0,7)\n" + \ - " for nn_c (None)\n" + \ - " for ff_c (None)\n" + \ - " for yy_c (None)\n" + \ - " for ax0 (None)\n" + \ - " for ax1 (None)\n" + \ - " for ax2 (None)\n" + \ - " for ax3 (None)\n" + \ - " pad_temp.global = ...\n" + \ - " for xx_c (None)\n" + \ - " for rc (None)\n" + \ - " for ax0 (None)\n" + \ - " for ax1 (None)\n" + \ - " for ax2 (None)\n" + \ - " for ax3 (None)\n" + \ - " Kernel.global = ...\n" + \ - " for ax0 (None)\n" + \ - " for ax1 (None)\n" + \ - " for ax2 (None)\n" + \ - " for ax3 (None)\n" + \ - " pad_temp.global.shared = ...\n" + \ - " for ry (None)\n" + \ - " for rx (None)\n" + \ - " compute.global = ...\n" + \ - " compute = ...\n" + \ - "for ax0.0 (0,2)\n" + \ - " for ax1 (0,512)\n" + \ - " for ax0.1 (0,2)\n" + \ - " for ax2 (0,7)\n" + \ - " for ax3 (0,7)\n" + \ - " T_add = ...\n" - - -def test_rfactor(): - A, B, C = matmul_ansor_test(8, 8, 512) - dag = ansor.ComputeDAG([A, B, C]) - s0 = dag.get_init_state() - - ko, ki = s0.split(C, s0[C].iters[2], [16]) - - s1 = s0.copy() - s1.rfactor(C, ko, 2) - assert str(s1) == \ - "Placeholder: A, B\n" + \ - "for i (0,8)\n" + \ - " for j (0,8)\n" + \ - " for k_o (0,32)\n" + \ - " for k_i (0,16)\n" + \ - " C.rf = ...\n" + \ - "for ax0 (0,8)\n" + \ - " for ax1 (0,8)\n" + \ - " for k_o_v (0,32)\n" + \ - " C.repl = 
...\n" - - s2 = s0.copy() - s2.rfactor(C, ki, 2) - assert str(s2) == \ - "Placeholder: A, B\n" + \ - "for i (0,8)\n" + \ - " for j (0,8)\n" + \ - " for k_i (0,16)\n" + \ - " for k_o (0,32)\n" + \ - " C.rf = ...\n" + \ - "for ax0 (0,8)\n" + \ - " for ax1 (0,8)\n" + \ - " for k_i_v (0,16)\n" + \ - " C.repl = ...\n" - - -def vcf_init_common(): - A, B, C = matmul_ansor_test(512, 512, 512) - dag = ansor.ComputeDAG([A, B, C]) - s0 = dag.get_init_state() - B_shared = s0.cache_read(B, "shared", [C]) - B_local = s0.cache_read(B_shared, "local", [C]) - A_shared = s0.cache_read(A, "shared", [C]) - A_local = s0.cache_read(A_shared, "local", [C]) - - return A_shared, A_local, B_shared, B_local, C, dag, s0 - - -def vcf_check_common(dag, state): - s, args = dag.apply_steps_from_state(state) - # To check if every vectorize loop transforms to ramp expr successfully - # TODO(jcf94): Find a better way to process the check in AST - print(tvm.lower(s, args)) - - if tvm.context("cuda", 0).exist: - tgt = tvm.target.cuda() - mod = tvm.build(s, args, tgt) - # To check if every vectorize loop transforms to correct instruction - print(mod.imported_modules[0].get_source()) - - ctx = tvm.context("cuda", 0) - dtype = dag.tensors[0].dtype - a = tvm.nd.array(np.random.uniform(size=(512, 512)).astype(dtype), ctx) - b = tvm.nd.array(np.random.uniform(size=(512, 512)).astype(dtype), ctx) - c = tvm.nd.array(np.zeros((512, 512), dtype=dtype), ctx) - mod(a, b, c) - tvm.testing.assert_allclose(c.asnumpy(), np.dot( - a.asnumpy(), b.asnumpy()), rtol=1e-5) - else: - print("CUDA device not found, skip this test.") - - -def test_vectorized_cooperative_fetching_x(): - A_shared, A_local, B_shared, B_local, C, dag, s0 = vcf_init_common() - - its0 = s0.split(C, s0[C].iters[0], [1, 8, 2, 4]) - its1 = s0.split(C, s0[C].iters[5], [2, 8, 2, 4]) - its2 = s0.split(C, s0[C].iters[10], [8, 8]) - s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its2[0], - its2[1], its0[3], its1[3], its2[2], its0[4], its1[4]]) - s0.fuse(C, [s0[C].iters[0], s0[C].iters[1]]) - s0.bind_thread(C, s0[C].iters[0], "blockIdx.x") - s0.fuse(C, [s0[C].iters[1], s0[C].iters[2]]) - s0.bind_thread(C, s0[C].iters[1], "vthread") - s0.fuse(C, [s0[C].iters[2], s0[C].iters[3]]) - s0.bind_thread(C, s0[C].iters[2], "threadIdx.x") - s0.vectorize(C, its1[4]) - - s0.compute_at(B_shared, C, s0[C].iters[3]) - fused_it = s0.fuse(B_shared, s0[B_shared].iters[:]) - its = s0.split(B_shared, fused_it, [64, 4]) - s0.bind_thread(B_shared, its[1], "threadIdx.x") - s0.vectorize(B_shared, its[2]) - s0.compute_at(B_local, C, s0[C].iters[4]) - fused_it = s0.fuse(B_local, s0[B_local].iters[:]) - its = s0.split(B_local, fused_it, [4]) - s0.vectorize(B_local, its[1]) - - s0.compute_at(A_shared, C, s0[C].iters[3]) - fused_it = s0.fuse(A_shared, s0[A_shared].iters[:]) - its = s0.split(A_shared, fused_it, [64, 4]) - s0.bind_thread(A_shared, its[1], "threadIdx.x") - s0.vectorize(A_shared, its[2]) - s0.compute_at(A_local, C, s0[C].iters[4]) - fused_it = s0.fuse(A_local, s0[A_local].iters[:]) - its = s0.split(A_local, fused_it, [4]) - s0.vectorize(A_local, its[1]) - - vcf_check_common(dag, s0) - - -def test_vectorized_cooperative_fetching_xy(): - A_shared, A_local, B_shared, B_local, C, dag, s0 = vcf_init_common() - - its0 = s0.split(C, s0[C].iters[0], [1, 8, 2, 4]) - its1 = s0.split(C, s0[C].iters[5], [2, 8, 2, 4]) - its2 = s0.split(C, s0[C].iters[10], [8, 8]) - s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its2[0], - its2[1], its0[3], its1[3], its2[2], its0[4], 
its1[4]]) - s0.fuse(C, [s0[C].iters[0], s0[C].iters[1]]) - s0.bind_thread(C, s0[C].iters[0], "blockIdx.x") - s0.fuse(C, [s0[C].iters[1], s0[C].iters[2]]) - s0.bind_thread(C, s0[C].iters[1], "vthread") - s0.bind_thread(C, s0[C].iters[2], "threadIdx.x") - s0.bind_thread(C, s0[C].iters[3], "threadIdx.y") - s0.vectorize(C, its1[4]) - - s0.compute_at(B_shared, C, s0[C].iters[4]) - fused_it = s0.fuse(B_shared, s0[B_shared].iters[:]) - its = s0.split(B_shared, fused_it, [8, 8, 4]) - s0.bind_thread(B_shared, its[1], "threadIdx.x") - s0.bind_thread(B_shared, its[2], "threadIdx.y") - s0.vectorize(B_shared, its[3]) - s0.compute_at(B_local, C, s0[C].iters[5]) - fused_it = s0.fuse(B_local, s0[B_local].iters[:]) - its = s0.split(B_local, fused_it, [4]) - s0.vectorize(B_local, its[1]) - - s0.compute_at(A_shared, C, s0[C].iters[4]) - fused_it = s0.fuse(A_shared, s0[A_shared].iters[:]) - its = s0.split(A_shared, fused_it, [8, 8, 4]) - s0.bind_thread(A_shared, its[1], "threadIdx.x") - s0.bind_thread(A_shared, its[2], "threadIdx.y") - s0.vectorize(A_shared, its[3]) - s0.compute_at(A_local, C, s0[C].iters[5]) - fused_it = s0.fuse(A_local, s0[A_local].iters[:]) - its = s0.split(A_local, fused_it, [4]) - s0.vectorize(A_local, its[1]) - - vcf_check_common(dag, s0) - - -@tvm._ffi.register_func -def test_intrin_gemv(): - m = 16 - l = 64 - a = te.placeholder((l,), name='a') - b = te.placeholder((l, m), name='b') - k = te.reduce_axis((0, l), name='k') - c = te.compute((m,), lambda i: te.sum(a[k] * b[k, i], axis=k), name='c') - Ab = tvm.tir.decl_buffer(a.shape, a.dtype, name="A", - offset_factor=1, strides=[1]) - Bb = tvm.tir.decl_buffer(b.shape, b.dtype, name="B", - offset_factor=1, strides=[te.var("s0"), 1]) - Cb = tvm.tir.decl_buffer(c.shape, c.dtype, name="C", - offset_factor=1, strides=[1]) - def intrin_func(ins, outs): - ib = tvm.tir.ir_builder.create() - aa, bb = ins - cc = outs[0] - ib.emit(tvm.tir.call_extern("float32", "gemv_update", - cc.access_ptr("w"), - aa.access_ptr("r"), - bb.access_ptr("r"))) - return ib.get() - return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb}) - -def test_tensorize(): - A, B, C = matmul_ansor_test(1024, 512, 64) - dag = ansor.ComputeDAG([A, B, C]) - s0 = dag.get_init_state() - - its = s0.split(C, s0[C].iters[1], [16]) - s0.tensorize(C, its[1], "test_intrin_gemv") - - sch, tensors = dag.apply_steps_from_state(s0) - tvm.lower(sch, tensors, simple_mode=True) - if __name__ == "__main__": - test_split_fuse_reorder_annotation() - test_follow_split_follow_fused_split() - test_compute_at_root_inline() - test_cache_read_write() - test_rfactor() - test_vectorized_cooperative_fetching_x() - test_vectorized_cooperative_fetching_xy() - test_tensorize() + test_split_fuse_reorder() diff --git a/tests/python/unittest/test_ansor_measure.py b/tests/python/unittest/test_ansor_measure.py index d457dd2c55cc..f8d41edd27dd 100644 --- a/tests/python/unittest/test_ansor_measure.py +++ b/tests/python/unittest/test_ansor_measure.py @@ -62,24 +62,6 @@ def test_measure_local_builder_runner(): assert mress[0].error_no == 0 -def test_measure_local_builder_rpc_runner(): - dag, s0 = get_tiled_matmul() - - tgt = tvm.target.create("llvm") - task = ansor.SearchTask(dag, "test", tgt) - minp = ansor.MeasureInput(task, s0) - - local_builder = ansor.LocalBuilder() - measure_ctx = ansor.LocalRPCMeasureContext() - rpc_runner = measure_ctx.runner - - bress = local_builder.build([minp]) - assert bress[0].error_no == 0 - mress = rpc_runner.run([minp], bress) - assert mress[0].error_no == 0 - - if 
__name__ == "__main__": test_serialization() test_measure_local_builder_runner() - test_measure_local_builder_rpc_runner() diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index deff561a4547..984434b9c58b 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -79,90 +79,5 @@ def test_search_basic(): t.start() t.join() - -def test_search_xgb_model_rpc_runner(): - measure_ctx = ansor.LocalRPCMeasureContext() - search_common(seed=456787236, cost_model=ansor.XGBModel(), - runner=measure_ctx.runner) - - -def test_search_opencl(): - if tvm.context("opencl", 0).exist: - measure_ctx = ansor.LocalRPCMeasureContext() - search_common("opencl", 380344973, measure_ctx.runner) - else: - print("OpenCL device not found, skip this test.") - - -def test_search_cuda(): - if tvm.context("cuda", 0).exist: - measure_ctx = ansor.LocalRPCMeasureContext() - search_common("cuda", 903667810, measure_ctx.runner) - else: - print("CUDA device not found, skip this test.") - - -def test_search_custom_sketch_rule(): - def meet_condition_func(meta_policy, state, stage_id): - # Apply and Skip the Rest if this function does not return - pass - - # Expecting: - # i.0 - # i.1 - # i.2 - # j.0 - # j.1 - # ax0 - # ax1 - # B.global - # j.2 - # k - # C - def apply_func1(meta_policy, state, stage_id): - # Stage by stage way - ret = [] - if stage_id == 2: - state = ansor.loop_state.State(state, meta_policy.cur_task.compute_dag) - state.split(2, state.stages[2].iters[0], [4, 4]) - state.split(2, state.stages[2].iters[3], [4, 4]) - ret.append([state.state_object, stage_id - 1]) - elif stage_id == 1: - state = ansor.loop_state.State(state, meta_policy.cur_task.compute_dag) - state.cache_read(1, "global", [2]) - state.compute_at(2, 3, state.stages[3].iters[4]) - ret.append([state.state_object, stage_id - 1]) - else: - ret.append([state, stage_id - 1]) - return ret - - def apply_func2(meta_policy, state, stage_id): - # More template like way - ret = [] - state = ansor.loop_state.State(state, meta_policy.cur_task.compute_dag) - - state.split(2, state.stages[2].iters[0], [4, 4]) - state.split(2, state.stages[2].iters[3], [4, 4]) - state.cache_read(1, "global", [2]) - state.compute_at(2, 3, state.stages[3].iters[4]) - - ret.append([state.state_object, -1]) - return ret - - measure_ctx = ansor.LocalRPCMeasureContext() - search_common(seed=887823438, runner=measure_ctx.runner, - pre_search_callbacks=[ansor.PreloadCustomSketchRule( - meet_condition_func, apply_func1)], - params={'disable_change_compute_location': 1}) - search_common(seed=887823438, runner=measure_ctx.runner, - pre_search_callbacks=[ansor.PreloadCustomSketchRule( - meet_condition_func, apply_func2)], - params={'disable_change_compute_location': 1}) - - if __name__ == "__main__": test_search_basic() - test_search_xgb_model_rpc_runner() - test_search_opencl() - test_search_cuda() - test_search_custom_sketch_rule() From d567617912f162e756b778442a7b9198088dd780 Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 24 Jun 2020 15:14:43 +0800 Subject: [PATCH 43/45] Code clean for minimum Ansor system --- python/tvm/ansor/compute_dag.py | 2 +- python/tvm/ansor/measure.py | 99 +- src/ansor/auto_schedule.cc | 1 - src/ansor/compute_dag.cc | 359 +--- src/ansor/compute_dag.h | 4 - src/ansor/cost_model/cost_model.cc | 121 -- src/ansor/cost_model/cost_model.h | 59 - src/ansor/feature.cc | 1573 ----------------- src/ansor/feature.h | 80 - 
src/ansor/loop_state.cc | 549 ------ src/ansor/loop_state.h | 45 +- .../search_policy/sketch_search_policy.cc | 1541 ---------------- .../search_policy/sketch_search_policy.h | 157 -- src/ansor/search_policy/utils.cc | 744 -------- src/ansor/search_policy/utils.h | 483 ----- src/ansor/serialization.cc | 175 +- src/ansor/transform_step.cc | 602 ------- src/ansor/transform_step.h | 427 +---- tests/python/unittest/test_ansor_common.py | 11 +- .../python/unittest/test_ansor_compute_dag.py | 27 - 20 files changed, 10 insertions(+), 7049 deletions(-) delete mode 100644 src/ansor/feature.cc delete mode 100644 src/ansor/feature.h delete mode 100644 src/ansor/search_policy/sketch_search_policy.cc delete mode 100644 src/ansor/search_policy/sketch_search_policy.h delete mode 100644 src/ansor/search_policy/utils.cc delete mode 100644 src/ansor/search_policy/utils.h diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index acfec66a166a..e57fbbc08843 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -59,7 +59,7 @@ def apply_steps_from_state(self, state): args : List[Tensor] """ state_obj = state if isinstance(state, StateObject) else state.state_object - return _ffi_api.ComputeDAGApplyStepsFromState(self, state_obj, layout_rewrite_level) + return _ffi_api.ComputeDAGApplyStepsFromState(self, state_obj) def print_python_code_from_state(self, state): """ diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index 46c3e3aabd5d..c9a5ef013cc7 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -389,103 +389,6 @@ def local_builder_build(inputs: List[MeasureInput], timeout: float, n_parallel: return results - -@tvm._ffi.register_func("ansor.rpc_runner.run") -def rpc_runner_run(inputs: List[MeasureInput], build_results: List[BuildResult], - key: str, host: str, port: int, priority: int, timeout: float, - n_parallel: int, number: int, repeat: int, min_repeat_ms: int, - cooldown_interval: float, verbose: int): - global global_run_arguments - global_run_arguments = (inputs, build_results, key, host, port, priority, timeout, number, - repeat, min_repeat_ms, cooldown_interval, verbose) - - assert len(inputs) == len(build_results), \ - "Measure input size should be equal to build results" - pool = NoDaemonPool(n_parallel) - tuple_res = pool.map(rpc_run_worker, range(len(build_results))) - pool.terminate() - pool.join() - del pool - - results = [] - for res in tuple_res: - results.append(MeasureResult(*res)) - - if verbose >= 1: - print("") - - return results - - -def rpc_run_worker(index): - """ ... 
- """ - inputs, build_results, key, host, port, priority, timeout, number, \ - repeat, min_repeat_ms, cooldown_interval, verbose = global_run_arguments - - MAX_FLOAT = 1e10 # We use 1e10 instead of sys.float_info.max for better readability in log - inp = inputs[index] - build_res = build_results[index] - - if build_res.error_no != MeasureErrorNo.NO_ERROR: - return (MAX_FLOAT,), build_res.error_no, build_res.error_msg, build_res.time_cost, \ - time.time() - - def timed_func(): - tic = time.time() - error_no = 0 - error_msg = None - try: - # upload built module - remote = request_remote(key, host, port, priority, timeout) - remote.upload(build_res.filename) - func = remote.load_module(os.path.split(build_res.filename)[1]) - ctx = remote.context(str(inp.task.target), 0) - time_f = func.time_evaluator( - func.entry_name, ctx, number=number, repeat=repeat, min_repeat_ms=min_repeat_ms) - except Exception: - costs = (MAX_FLOAT,) - error_no = MeasureErrorNo.COMPILE_DEVICE - error_msg = make_error_msg() - - if error_no == 0: - try: - args = [ndarray.non_empty(get_const_tuple(x.shape), x.dtype, ctx) for x in - build_res.args] - ctx.sync() - - costs = time_f(*args).results - # clean up remote files - remote.remove(build_res.filename) - remote.remove(os.path.splitext(build_res.filename)[0] + '.so') - remote.remove('') - except Exception: - costs = (MAX_FLOAT,) - error_no = MeasureErrorNo.RUNTIME_DEVICE - error_msg = make_error_msg() - - shutil.rmtree(os.path.dirname(build_res.filename)) - toc = time.time() - - time.sleep(cooldown_interval) - if verbose >= 1: - if error_no == MeasureErrorNo.NO_ERROR: - print("*", end="") - else: - print("*E", end="") # Run error - - return costs, error_no, error_msg, toc - tic + build_res.time_cost, toc - - res = call_func_with_timeout(timeout, timed_func) - - if isinstance(res, TimeoutError): - if verbose >= 1: - print("*T", end="") # Run timeout - res = (MAX_FLOAT,), MeasureErrorNo.RUN_TIMEOUT, None, build_res.time_cost + \ - timeout, time.time() - return res - - @tvm._ffi.register_func("ansor.local_runner.run") def local_run(inputs: List[MeasureInput], build_results: List[BuildResult], timeout: float, number: int, repeat: int, min_repeat_ms: int, @@ -510,7 +413,7 @@ def timed_func(inp, build_res): if error_no == 0: try: - args = [ndarray.non_empty(get_const_tuple(x.shape), x.dtype, ctx) for x in + args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args] ctx.sync() diff --git a/src/ansor/auto_schedule.cc b/src/ansor/auto_schedule.cc index 05cb95c2c451..82ec07930adc 100644 --- a/src/ansor/auto_schedule.cc +++ b/src/ansor/auto_schedule.cc @@ -26,7 +26,6 @@ #include #include #include -#include "search_policy/sketch_search_policy.h" namespace tvm { namespace ansor { diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index d7af8b94729a..7638f98e65ea 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -37,7 +37,6 @@ #include #include #include "transform_step.h" -#include "search_policy/utils.h" namespace tvm { namespace ansor { @@ -599,323 +598,6 @@ std::string BaseName(const std::string& str) { return str.substr(0, str.rfind("_")); } -class IndexRewriter : public StmtExprMutator { - public: - IndexRewriter(const OperationMap >& placeholder_new_names, - const OperationMap >& placeholder_new_shapes): - placeholder_new_names_(placeholder_new_names), - placeholder_new_shapes_(placeholder_new_shapes) {} - - PrimExpr Rewrite(PrimExpr expr) { - return this->VisitExpr(expr); - } - - PrimExpr VisitExpr_(const 
ProducerLoadNode* op) final { - te::Tensor t = Downcast(op->producer); - auto it = placeholder_new_names_.find(t->op); - if (it != placeholder_new_names_.end()) { - const std::vector& new_names = it->second; - const Array& new_shape = placeholder_new_shapes_.at(t->op); - std::unordered_map name_to_arg; - for (const auto& arg : op->indices) { - std::string axis_name; - if (const auto* pimm = arg.as()) { - CHECK_EQ(pimm->value, 0); - axis_name = "IntImm"; - } else { - axis_name = BaseName(CleanName(Downcast(arg)->name_hint)); - CHECK_EQ(name_to_arg.count(axis_name), 0); - name_to_arg[axis_name] = arg; - } - } - - std::unordered_map div_factors; - std::vector r_new_args; - for (int i = new_names.size() - 1; i >= 0; --i) { - auto ori_iter_name = new_names[i]; - auto name_it = name_to_arg.find(ori_iter_name); - CHECK(name_it != name_to_arg.end()); - PrimExpr ori_arg = name_it->second; - - PrimExpr mod_factor = new_shape[i]; - - PrimExpr div_factor = 1; - if (div_factors.count(ori_iter_name)) { - div_factor = div_factors[ori_iter_name]; - } - div_factors[ori_iter_name] = div_factor * new_shape[i]; - - PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor); - - r_new_args.push_back(new_arg); - } - - Array new_args(std::make_move_iterator(r_new_args.rbegin()), - std::make_move_iterator(r_new_args.rend())); - - return ProducerLoad(op->producer, new_args); - } - return GetRef(op); - } - - private: - const OperationMap >& placeholder_new_names_; - const OperationMap >& placeholder_new_shapes_; -}; - -void ComputeDAG::RewriteLayout( - const std::vector &transform_steps, LayoutRewriteLevel layout_rewrite_level) const { - ComputeDAGNode* pdag = const_cast(this)->CopyOnWrite(); - const State& state = ReplayAndInferBound(transform_steps); - - OperationMap > placeholder_new_names; - OperationMap > placeholder_new_shapes; - int stage_id = -1; - for (const auto& stage : state->stages) { - stage_id += 1; - const te::Operation& op = stage->op; - if (op->IsInstance()) { - const Map& attrs = op->attrs; - if (attrs.count(layout_free_placeholders_key)) { - const ObjectRef& attr_value = attrs[layout_free_placeholders_key]; - Array placeholders = Downcast>(attr_value); - for (auto& placeholder : placeholders) { - const auto placeholder_op = placeholder->op; - - // Check whether this placeholder has already been handled - if (placeholder_new_names.count(placeholder_op)) { - continue; - } - - // skip the op that is not direct consumer of this placeholder, - // mostly due to cache read/write. 
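A minimal Python sketch of the direct-consumer filter implemented next (plain objects, not the Ansor API): an op qualifies only if it reads the placeholder itself, since cache_read/cache_write stages interpose a copy.

def is_direct_consumer(op_input_tensor_ops, placeholder_op):
    # True only when the op reads the placeholder directly, not a cached copy
    return any(t_op is placeholder_op for t_op in op_input_tensor_ops)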
- bool direct_consumer = false; - for (auto& t : op->InputTensors()) { - if (t->op == placeholder_op) { - direct_consumer = true; - break; - } - } - if (!direct_consumer) { - continue; - } - - std::set placeholder_axis_names; - TensorAccessExtractor extractor; - for (const auto& exp : op.as()->body) { - extractor.Extract(exp); - } - bool rewrite_placeholder = (layout_rewrite_level == kPlaceholderRewrite || - layout_rewrite_level == kBothRewrite); - bool rewrite_body = (layout_rewrite_level == kComputeRewrite || - layout_rewrite_level == kBothRewrite); - std::ostringstream os; - - uint i = 0; - if (extractor.buf_accesses.count(placeholder_op)) { - for (const auto& ev : extractor.buf_accesses[placeholder_op]) { - for (const auto& e : ev) { - // TODO(minminsun): check whether the extents match the shape of placeholder - std::string axis_name; - if (const auto* pimm = e.as()) { - CHECK_EQ(pimm->value, 0); - // CHECK_EQ(placeholder->shape[i].as()->value, 1); - axis_name = "IntImm"; - } else { - axis_name = BaseName(CleanName(Downcast(e)->name_hint)); - } - - placeholder_axis_names.insert(axis_name); - if (rewrite_placeholder) { - os << placeholder->shape[i++] << axis_name; - } - } - } - - if (rewrite_placeholder) { - CHECK_EQ(placeholder_axis_names.size(), placeholder->shape.size()); - std::string ori_layout = os.str(); - os.str(""); - // ::tvm::relay::KernelLayoutVisitor::global_ori_layouts_queue.push_back(ori_layout); - } - } - - std::vector stage_iters; - - auto attach_it = state->attach_map->stage_to_attach_iter.find(stage_id); - int attach_pos = -1; - size_t iters_before_attach = 0; - if (attach_it != state->attach_map->stage_to_attach_iter.end()) { - auto attach = attach_it->second; - const auto& attach_stage = state->stages[attach.first]; - attach_pos = attach.second; - stage_iters.insert(stage_iters.end(), - attach_stage->iters.begin(), - attach_stage->iters.begin() + attach_pos + 1); - } - - stage_iters.insert(stage_iters.end(), stage->iters.begin(), stage->iters.end()); - - std::vector iters; - for (size_t i = 0; i < stage_iters.size(); ++i) { - const auto& iter = stage_iters[i]; - if (iter->ori_iters.empty()) { - iters.push_back(iter); - } else { - for (const Iterator& ori_iter : iter->ori_iters) { - iters.push_back(ori_iter); - } - } - if (static_cast(i) == attach_pos) { - iters_before_attach = iters.size(); - } - } - - std::vector new_names; - Array new_shape; - std::vector new_axis_names; - for (const Iterator& iter : iters) { - std::set ori_iter_names; - ExtractOriginalIterators(iter->name, &ori_iter_names); - // fused iters have been replaced with iter->ori_iters. - // So there should be only one ori iter name extracted from iter->name. 
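To illustrate the invariant checked next, here is a toy version of the naming convention assumed above (an assumption based on how fused iterators are printed, e.g. "i.0@j.0"): splitting on '@' and dropping any ".<n>" suffix recovers the original iterator names, and after the ori_iters substitution each name should map back to exactly one original axis.

def extract_original_iterators(name):
    # assumed name format: original axis name plus an optional ".<n>" suffix,
    # with fused names joined by '@'
    return {part.split(".")[0] for part in name.split("@")}

assert extract_original_iterators("ff.0") == {"ff"}
assert extract_original_iterators("i.0@j.0") == {"i", "j"}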
- CHECK_EQ(ori_iter_names.size(), 1); - auto ori_iter_name = BaseName(*ori_iter_names.begin()); - new_axis_names.push_back(ori_iter_name); - } - for (size_t i = 0; i < new_axis_names.size(); ++i) { - auto iter = iters[i]; - std::string ori_iter_name; - if (i < iters_before_attach) { - ori_iter_name = new_axis_names[i + iters_before_attach]; - } else { - ori_iter_name = new_axis_names[i]; - } - if (placeholder_axis_names.count(ori_iter_name)) { - os << iter->range->extent << ori_iter_name; - new_names.push_back(ori_iter_name); - new_shape.push_back(iter->range->extent); - } - } - std::string new_layout = os.str(); - os.str(""); - // ::tvm::relay::KernelLayoutVisitor::global_new_layouts_queue.push_back(new_layout); - placeholder_new_names[placeholder_op] = new_names; - placeholder_new_shapes[placeholder_op] = new_shape; - - Array old_ops = pdag->ops; - ArrayNode* pops = pdag->ops.CopyOnWrite(); - - // Create new placeholder - te::Operation new_placeholder_op; - if (rewrite_placeholder) { - new_placeholder_op = - te::PlaceholderOp(placeholder_op->name, - new_shape, - placeholder_op.as()->dtype); - } else { - new_placeholder_op = placeholder_op; - } - - te::Operation new_compute_op, old_compute_op; - if (rewrite_body) { - Array new_body; - IndexRewriter index_rewriter(placeholder_new_names, - placeholder_new_shapes); - for (auto& op : old_ops) { - if (auto* pop = op.as()) { - bool need_update = false; - for (auto& t : op->InputTensors()) { - if (t->op == placeholder_op) { - need_update = true; - break; - } - } - if (need_update) { - for (auto& body : pop->body) { - new_body.push_back(index_rewriter.Rewrite(body)); - } - old_compute_op = op; - CHECK(!new_compute_op.defined()); - new_compute_op = te::ComputeOp( - pop->name, pop->tag, pop->attrs, pop->axis, new_body); - } - } - } - } - - // construct the map from old_op to new_op - std::unordered_map updated_ops; - for (size_t i = 0; i < old_ops.size(); ++i) { - auto old_op = old_ops[i]; - if (rewrite_placeholder && old_op == placeholder_op) { - pops->SetItem(i, new_placeholder_op); - updated_ops[placeholder_op] = new_placeholder_op; - } else if (rewrite_body && old_op == old_compute_op) { - pops->SetItem(i, new_compute_op); - updated_ops[old_compute_op] = new_compute_op; - } else { - pops->SetItem(i, old_op); - } - } - - // Because ops is sorted in topo-order, only do one pass linear scan here. 
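The one-pass scan below relies on following op -> replacement chains; a Python sketch of that lookup (a plain dict stands in for the OperationMap):

def resolve(updated_ops, op):
    # Follow the replacement chain to the newest operation. Because ops are
    # visited in topological order, a single linear pass with this lookup
    # propagates every replacement to all downstream ops.
    new_op = None
    while op in updated_ops:
        new_op = updated_ops[op]
        op = new_op
    return new_op  # None when op was never replaced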
- for (size_t i = 0; i < pops->size(); ++i) { - auto old_op = Downcast(pops->at(i)); - if (auto* pop = old_op.as()) { - auto inputs = pop->InputTensors(); - std::unordered_map rmap; - for (auto input : inputs) { - auto it = updated_ops.find(input->op); - te::Operation new_op; - while (it != updated_ops.end()) { - new_op = it->second; - it = updated_ops.find(new_op); - } - if (new_op.defined()) { - int index = input->value_index; - rmap[input] = new_op.output(index); - } - } - if (!rmap.empty()) { - te::Operation new_op = pop->ReplaceInputs(old_op, rmap); - updated_ops[old_op] = new_op; - pops->SetItem(i, new_op); - } - } - } - - pdag->init_state = State(pdag->ops); - - Array old_tensors = pdag->tensors; - ArrayNode* ptensors = pdag->tensors.CopyOnWrite(); - - for (size_t i = 0; i < old_tensors.size(); ++i) { - const auto& old_tensor = old_tensors[i]; - auto it = updated_ops.find(old_tensor->op); - te::Operation new_op; - while (it != updated_ops.end()) { - new_op = it->second; - it = updated_ops.find(new_op); - } - if (new_op.defined()) { - if (layout_rewrite_level == kBothRewrite) { - auto index = old_tensor->value_index; - ptensors->SetItem(i, new_op.output(index)); - } else if (layout_rewrite_level == kComputeRewrite) { - te::TensorNode* old_tensor_node = - const_cast(old_tensor.as()); - old_tensor_node->op = new_op; - } - } - } - } // end for placeholder - } - } - } // end for stage -} - - void UpdateStageAxis(const te::Stage& stage, StageToAxesMap *stage_to_axes) { if (auto pop = stage->op.as()) { std::vector& axes = (*stage_to_axes)[stage]; @@ -938,13 +620,7 @@ std::pair > ComputeDAG::ApplySteps( LayoutRewriteLevel layout_rewrite_level) const { std::vector stages; StageToAxesMap stage_to_axes; - if (layout_rewrite_level != kNoRewrite && !transform_steps.empty()) { - ComputeDAG new_dag = *this; - new_dag.RewriteLayout(transform_steps, layout_rewrite_level); - return new_dag.ReplaySteps(transform_steps, &stages, &stage_to_axes); - } else { - return ReplaySteps(transform_steps, &stages, &stage_to_axes); - } + return ReplaySteps(transform_steps, &stages, &stage_to_axes); } std::string ComputeDAG::PrintStepsAsPython(const std::vector& transform_steps) const { @@ -1135,32 +811,8 @@ std::pair > ComputeDAG::ReplaySteps( ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes, transform_steps); } else if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes, &schedule); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes, &schedule); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes, &schedule); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes); - } else if (auto ps = step.as()) { - ps->ApplyToSchedule(stages, stage_to_axes); } else { LOG(FATAL) << "Invalid Step"; } @@ 
-1270,15 +922,6 @@ TVM_REGISTER_GLOBAL("ansor.ComputeDAG") TVM_REGISTER_GLOBAL("ansor.ComputeDAGGetInitState") .set_body_method(&ComputeDAG::GetInitState); -TVM_REGISTER_GLOBAL("ansor.ComputeDAGRewriteLayoutFromState") -.set_body([](TVMArgs args, TVMRetValue *ret) { - ComputeDAG dag = args[0]; - State state = args[1]; - - dag.RewriteLayout(state->transform_steps, kPlaceholderRewrite); - *ret = dag; -}); - TVM_REGISTER_GLOBAL("ansor.ComputeDAGApplyStepsFromState") .set_body([](TVMArgs args, TVMRetValue *ret) { ComputeDAG dag = args[0]; diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index b1b60e678904..2f1330d612dd 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -148,10 +148,6 @@ class ComputeDAG: public ObjectRef { const std::vector& transform_steps, LayoutRewriteLevel layout_rewrite_level = kNoRewrite) const; - // Rewrite the the layout of "layout free" placeholders according to transform steps - void RewriteLayout(const std::vector& transform_steps, - LayoutRewriteLevel layout_rewrite_level = kNoRewrite) const; - // Print transform steps as equivalent python schedule API std::string PrintStepsAsPython(const std::vector& steps) const; diff --git a/src/ansor/cost_model/cost_model.cc b/src/ansor/cost_model/cost_model.cc index ee7bf8b26053..d0ae30e20a9a 100644 --- a/src/ansor/cost_model/cost_model.cc +++ b/src/ansor/cost_model/cost_model.cc @@ -36,8 +36,6 @@ using ::tvm::runtime::NDArray; TVM_REGISTER_OBJECT_TYPE(CostModelNode); TVM_REGISTER_OBJECT_TYPE(RandomModelNode); -TVM_REGISTER_OBJECT_TYPE(MeasureModelNode); -TVM_REGISTER_OBJECT_TYPE(PythonBasedModelNode); void RandomNumber(TVMArgs args, TVMRetValue* rv) { int n = args[0]; @@ -71,128 +69,9 @@ void RandomModelNode::Predict(const SearchTask& task, (*random_number_func)(states.size(), static_cast(scores->data())); } -MeasureModel::MeasureModel(Builder builder, Runner runner) { - ObjectPtr node = make_object(); - node->measurer = ProgramMeasurer(std::move(builder), std::move(runner), - Array(), 0); - data_ = std::move(node); -} - -void MeasureModelNode::Update(const Array& inputs, - const Array& results) {} - -void MeasureModelNode::Predict(const SearchTask& task, - const std::vector& states, - std::vector* scores) { - std::vector inputs; - std::vector results; - - inputs.clear(); - inputs.reserve(states.size()); - for (const auto& state : states) { - inputs.push_back(MeasureInput(task, state)); - } - measurer->SilentMeasure(task, inputs, &results); - - scores->clear(); - scores->reserve(results.size()); - for (const auto& res : results) { - scores->push_back(1.0 / FloatArrayMean(res->costs)); - } -} - -PythonBasedModel::PythonBasedModel(PackedFunc update_func, - PackedFunc predict_func, - PackedFunc predict_stage_func) { - auto node = make_object(); - node->update_func = std::move(update_func); - node->predict_func = std::move(predict_func); - node->predict_stage_func = std::move(predict_stage_func); - data_ = std::move(node); -} - -void PythonBasedModelNode::Update(const Array& inputs, - const Array& results) { - update_func(inputs, results); -} - -void PythonBasedModelNode::Predict(const SearchTask& task, - const std::vector& states, - std::vector* scores) { - scores->resize(states.size()); - predict_func(task, Array(states.begin(), states.end()), - static_cast(scores->data())); -} - -void PythonBasedModelNode::PredictStages(const SearchTask& task, - const std::vector& states, std::vector* state_scores, - std::vector>* stage_scores) { - int n_states = states.size(); - int n_stages = 
task->compute_dag.GetInitState()->stages.size(); - std::vector flatten_scores; - // Allocate sufficient spaces. - flatten_scores.resize(n_states * n_stages * 2); - predict_stage_func(task, Array(states.begin(), states.end()), - static_cast(flatten_scores.data())); - - // Unpack flatten scores. - state_scores->clear(); - stage_scores->clear(); - - // Score of each states. - for (int i = 0; i < n_states; ++i) { - state_scores->push_back(flatten_scores[i]); - } - - // Score of each stage in each states. - size_t idx = n_states; - for (int i = 0; i < n_states; ++i) { - CHECK_LE(idx, flatten_scores.size()); - - // Number of scored stages of this state. - int s_length = static_cast(flatten_scores[idx++]); - - if (s_length > 0) { - std::vector scores; - int offset = 0; - - if ((*state_scores)[i] > -INFINITY) { - // If the score is valid. Copy scored stages and assign 0 to placeholder - // and inlined stages. If the score is 0, meaning this state failed to - // be lowered. Just bypass to update offset. - for (const Stage& stage : states[i]->stages) { - if (stage->op_type == kPlaceholder) { - scores.push_back(0); - continue; - } - if (stage->compute_at == kInlined) { - scores.push_back(0); - continue; - } - scores.push_back(flatten_scores[idx + offset]); - offset++; - } - CHECK_EQ(offset, s_length); - stage_scores->push_back(std::move(scores)); - } - idx += s_length; - } else { - // Cost model does not provide any stage score details. - stage_scores->push_back({}); - } - } -} - TVM_REGISTER_GLOBAL("ansor.RandomModel").set_body_typed([]() { return RandomModel(); }); -TVM_REGISTER_GLOBAL("ansor.PythonBasedModel") -.set_body_typed([](PackedFunc update_func, PackedFunc predict_func, - PackedFunc predict_stage_func) { - return PythonBasedModel(update_func, predict_func, - predict_stage_func); -}); - } // namespace ansor } // namespace tvm diff --git a/src/ansor/cost_model/cost_model.h b/src/ansor/cost_model/cost_model.h index f38624a3572c..03b7fb5f3399 100644 --- a/src/ansor/cost_model/cost_model.h +++ b/src/ansor/cost_model/cost_model.h @@ -92,65 +92,6 @@ class RandomModel : public CostModel { using ContainerType = RandomModelNode; }; -/*! \brief The cost model returns actual cost by measurement */ -class MeasureModelNode : public CostModelNode { - public: - ProgramMeasurer measurer; - - void Update(const Array& inputs, - const Array& results) final; - void Predict(const SearchTask& task, const std::vector& states, - std::vector* scores) final; - - static constexpr const char* _type_key = "ansor.MeasureModel"; - TVM_DECLARE_FINAL_OBJECT_INFO(MeasureModelNode, CostModelNode); -}; - -/*! - * \brief Managed reference to MeasureModelNode. - * \sa MeasureModelNode - */ -class MeasureModel : public CostModel { - public: - MeasureModel(Builder builder, Runner runner); - - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureModel, CostModel, - MeasureModelNode); -}; - -/*! 
\brief A wrapper for cost model defined by python code - * This class will call python's function */ -class PythonBasedModelNode: public CostModelNode { - public: - PackedFunc update_func; - PackedFunc predict_func; - PackedFunc predict_stage_func; - - void Update(const Array& inputs, - const Array& results) final; - void Predict(const SearchTask& task, const std::vector& states, - std::vector* scores) final; - void PredictStages(const SearchTask& task, const std::vector& states, - std::vector* state_scores, - std::vector>* stage_scores) final; - - static constexpr const char *_type_key = "ansor.PythonBasedModel"; - TVM_DECLARE_FINAL_OBJECT_INFO(PythonBasedModelNode, CostModelNode); -}; - -/*! - * \brief Managed reference to PythonBasedModelNode. - * \sa PythonBasedModelNode - */ -class PythonBasedModel : public CostModel { - public: - PythonBasedModel(PackedFunc update_func, PackedFunc predict_func, - PackedFunc predict_stage_func); - - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PythonBasedModel, CostModel, - PythonBasedModelNode); -}; - } // namespace ansor } // namespace tvm diff --git a/src/ansor/feature.cc b/src/ansor/feature.cc deleted file mode 100644 index 73f6bad0d432..000000000000 --- a/src/ansor/feature.cc +++ /dev/null @@ -1,1573 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! 
- * \file ansor/feature.cc - * \brief Feature extraction for the cost model - */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "measure.h" -#include "serialization.h" -#include "utils.h" - -namespace tvm { -/* Import the function from driver_api.cc */ -extern void GetBinds(const Array& args, bool compact, - const std::unordered_map& binds, - Map* out_binds, Array* out_arg_list); -} // namespace tvm - - -namespace tvm { -namespace ansor { - -using namespace tvm::tir; -using arith::ConstIntBound; -using arith::Analyzer; - -template -using BufferMap = std::unordered_map; - -static const int ARITH_INTENSITY_CURVE_SAMPLE_N = 10; - -// Annotation position encoding -enum AnnotationPosType { - kPosNone, kPosInnerSpatial, kPosMiddleSpatial, kPosOuterSpatial, - kPosInnerReduce, kPosMiddleReduce, kPosOuterReduce, kPosMixed -}; - -// Buffer access type -enum BufferAccessType { - kRead, kWrite, kReadWrite, kUnknownRW -}; - -// Accesses to a buffer -struct BufferAccess { - BufferAccessType acc_type{kUnknownRW}; - std::vector > indices; -}; - -// Data reuse type -enum ReuseType { - kLoopMultipleRead, kSerialMultipleReadWrite, kNoReuse -}; - -// Feature for an access of a buffer -struct BufferAccessFeature { - std::string buffer_name; - BufferAccessType acc_type; - float bytes; - float unique_bytes; - float lines; - float unique_lines; - ReuseType reuse_type; - float reuse_dis_iter; // reuse distance in iterator number - float reuse_dis_bytes; // reuse distance in total touched bytes - float reuse_ct; // reuse times - float bytes_d_reuse_ct; - float unique_bytes_d_reuse_ct; - float lines_d_reuse_ct; - float unique_lines_d_reuse_ct; - float stride; -}; - -// Feature set of a statement -struct FeatureSet { - // compute feature - float float_mad; - float float_addsub; - float float_mul; - float float_divmod; - float float_cmp; - float float_math_func; - float float_other_func; - float int_mad; - float int_addsub; - float int_mul; - float int_divmod; - float int_cmp; - float int_math_func; - float int_other_func; - float bool_op; - float select_op; - float vec_num; // The number of vectorized iterators - float vec_prod; // The product of the lengths of vectorized iterators - float vec_len; // The length of the innermost vectorized iterator - AnnotationPosType vec_type; - float unroll_num; // The number of unrolled iterators - float unroll_prod; // The product of the lengths of vectorized iterators - float unroll_len; // The length of the innermost unrolled iterator - AnnotationPosType unroll_type; - float parallel_num; // The number of paralleled iterators - float parallel_prod; // The product of the lengths of paralleled iterators - float parallel_len; // The length of the innermost paralleled iterators - AnnotationPosType parallel_type; - float is_gpu; - float blockIdx_x_len; - float blockIdx_y_len; - float blockIdx_z_len; - float threadIdx_x_len; - float threadIdx_y_len; - float threadIdx_z_len; - float vthread_len; - - float arith_intensity_curve[ARITH_INTENSITY_CURVE_SAMPLE_N]; - - // buffer access feature (per buffer) - std::vector access_feas; - - // allocation feature - float alloc_size; - float alloc_prod; - float alloc_outer_prod; - float alloc_inner_prod; - - // overall feature - float outer_prod; - float num_loops; - float auto_unroll_max_step; -}; - -// Return whether a var is in an expr -bool VarInExpr(const Var& var, const PrimExpr& expr) { - bool find = false; - - PostOrderVisit(expr, [&find, &var](const ObjectRef 
&node) {
-    if (find) {
-      return;
-    }
-
-    if (const VarNode* op = node.as<VarNode>()) {
-      if (op == var.get()) {
-        find = true;
-      }
-    }
-  });
-
-  return find;
-}
-
-// Get position encoding for annotation
-AnnotationPosType GetAnnotationPosEncoding(
-    const Var& var, const Array<PrimExpr>& spatial_args,
-    const Array<IterVar>& axis, const Array<IterVar>& reduce_axis) {
-  // Try to match spatial args first
-  size_t find_i = 0;
-  size_t find_ct = 0;
-  for (size_t i = 0; i < spatial_args.size(); ++i) {
-    if (VarInExpr(var, spatial_args[i])) {
-      find_i = i;
-      find_ct += 1;
-    }
-  }
-
-  if (find_ct == 0) {
-    // If it is not found in the spatial args, then it is a reduce iterator.
-    // Use its name to match
-    const std::string& var_name = var->name_hint;
-    for (size_t i = 0; i < reduce_axis.size(); ++i) {
-      if (var_name.find(reduce_axis[i]->var->name_hint) != std::string::npos) {
-        find_i = i;
-        find_ct++;
-      }
-    }
-    if (find_ct >= 1) {
-      if (find_i == 0) {
-        return kPosInnerReduce;
-      } else if (find_i == reduce_axis.size() - 1) {
-        return kPosOuterReduce;
-      } else {
-        return kPosMiddleReduce;
-      }
-    } else {
-      // If the axis is found in neither the spatial args nor the reduce axis,
-      // then this stage must compute_at somewhere under this axis and the axis is simplified out.
-      // We assume it is an outer spatial axis.
-      return kPosOuterSpatial;
-    }
-  } else if (find_ct == 1) {
-    if (find_i == spatial_args.size() - 1) {
-      return kPosInnerSpatial;
-    } else if (find_i == 0) {
-      return kPosOuterSpatial;
-    } else {
-      return kPosMiddleSpatial;
-    }
-  } else {
-    return kPosMixed;
-  }
-}
-
-// Count math ops in an expr
-class MathOpCounter : public StmtExprVisitor {
- public:
-#define VisitBinary(Type, float_ct, int_ct) \
-  void VisitExpr_(const Type* op) final {   \
-    if (op->a.dtype().is_float()) {         \
-      float_ct++;                           \
-    } else {                                \
-      int_ct++;                             \
-    }                                       \
-    StmtExprVisitor::VisitExpr_(op);        \
-  }                                         \
-
-  VisitBinary(AddNode, float_addsub, int_addsub);
-  VisitBinary(SubNode, float_addsub, int_addsub);
-  VisitBinary(MulNode, float_mul, int_mul);
-  VisitBinary(DivNode, float_divmod, int_divmod);
-  VisitBinary(ModNode, float_divmod, int_divmod);
-  VisitBinary(FloorDivNode, float_divmod, int_divmod);
-  VisitBinary(FloorModNode, float_divmod, int_divmod);
-  VisitBinary(MaxNode, float_cmp, int_cmp);
-  VisitBinary(MinNode, float_cmp, int_cmp);
-  VisitBinary(EQNode, float_cmp, int_cmp);
-  VisitBinary(NENode, float_cmp, int_cmp);
-  VisitBinary(LTNode, float_cmp, int_cmp);
-  VisitBinary(LENode, float_cmp, int_cmp);
-  VisitBinary(GTNode, float_cmp, int_cmp);
-  VisitBinary(GENode, float_cmp, int_cmp);
-
-  void VisitExpr_(const AndNode* op) final { bool_op++; StmtExprVisitor::VisitExpr_(op); }
-  void VisitExpr_(const OrNode* op) final { bool_op++; StmtExprVisitor::VisitExpr_(op); }
-  void VisitExpr_(const NotNode* op) final { bool_op++; StmtExprVisitor::VisitExpr_(op); }
-  void VisitExpr_(const SelectNode* op) final { select_op++; StmtExprVisitor::VisitExpr_(op); }
-
-  void VisitExpr_(const CallNode* op) final {
-    if (op->call_type == CallNode::CallType::PureIntrinsic) {
-      if (op->dtype.is_float()) {
-        float_math_func++;
-      } else {
-        int_math_func++;
-      }
-    } else {
-      if (op->dtype.is_float()) {
-        float_other_func++;
-      } else {
-        int_other_func++;
-      }
-    }
-    StmtExprVisitor::VisitExpr_(op);
-  }
-
-  // todo(lmzheng): detect mad
-  size_t float_mad{0}, float_addsub{0}, float_mul{0}, float_divmod{0},
-      float_cmp{0}, float_math_func{0}, float_other_func{0};
-  size_t int_mad{0}, int_addsub{0}, int_mul{0}, int_divmod{0},
-      int_cmp{0}, int_math_func{0}, int_other_func{0};
-  size_t bool_op{0}, 
select_op{0};
-};
-
-
-// Extract all buffer accesses in an expr
-class BufferAccessExtractor : public StmtExprVisitor {
- public:
-  void ExtractReads(const PrimExpr& expr) {
-    this->VisitExpr(expr);
-  }
-
-  void InsertAccess(const Buffer& buf, BufferAccessType acc_type,
-                    const Array<PrimExpr>& indices) {
-    BufferAccess& acc = buf_accesses[buf];
-    acc.acc_type = acc_type;
-    acc.indices.push_back(std::vector<PrimExpr>(indices.begin(), indices.end()));
-  }
-
-  void VisitExpr_(const BufferLoadNode *op) final {
-    BufferAccess& acc = buf_accesses[op->buffer];
-    switch (acc.acc_type) {
-      case kRead:
-        break;
-      case kWrite:
-        acc.acc_type = kReadWrite; break;
-      case kReadWrite:
-        break;
-      case kUnknownRW:
-      default:
-        acc.acc_type = kRead; break;
-    }
-
-    if (acc.acc_type != kReadWrite) {
-      // If a buffer is both read and written, it must be an update in the TVM DSL,
-      // so the indices should be the same and we can skip appending them.
-      // Otherwise we do the following.
-      buf_accesses[op->buffer].indices.push_back(
-          std::vector<PrimExpr>(op->indices.begin(), op->indices.end()));
-    }
-    StmtExprVisitor::VisitExpr_(op);
-  }
-
-  BufferMap<BufferAccess> buf_accesses;
-};
-
-// Compute the coefficient for a loop iterator in an expression
-// Note: we use an approximation strategy to find the coefficient.
-// Hopefully, it is faster than DetectLinearEquation and can handle more cases (non-linear)
-class CoefficientExtractor : public StmtExprVisitor {
- public:
-  void VisitExpr_(const MulNode *node) final {
-    StmtExprVisitor::VisitExpr_(node);
-    if (visited_var) {
-      if (!visited_add) {
-        if (auto a = node->a.as<IntImmNode>()) {
-          visited_mul = true;
-          stride = a->value;
-        } else if (auto b = node->b.as<IntImmNode>()) {
-          visited_mul = true;
-          stride = b->value;
-        }
-      }
-    }
-  }
-
-  void VisitExpr_(const AddNode *node) final {
-    StmtExprVisitor::VisitExpr_(node);
-    if (visited_var) {
-      if (!visited_mul) {
-        visited_add = true;
-        stride = 1;
-      }
-    }
-  }
-
-  void VisitExpr_(const VarNode *node) final {
-    if (node == var_) {
-      visited_var = true;
-      // This is a magic default stride in case our approximation strategy fails
-      stride = 2;
-    }
-  }
-
-  int ExtractCoefficient(const PrimExpr& expr, const VarNode* var) {
-    visited_var = visited_mul = visited_add = false;
-    var_ = var;
-
-    this->VisitExpr(expr);
-
-    if (visited_var && !visited_mul && !visited_add) {
-      return 1;
-    } else {
-      return stride;
-    }
-  }
-
-  bool visited_var{false};
-  bool visited_mul{false};
-  bool visited_add{false};
-  int stride{0};
-
- private:
-  const VarNode* var_{nullptr};
-};
-
-// Compute stride for the accesses to a buffer
-int64_t ComputeStride(const std::vector<std::vector<PrimExpr> >& indices,
-                      const std::vector<int>& shape,
-                      const VarNode* stride_var) {
-  int64_t min_stride = std::numeric_limits<int64_t>::max();
-  bool find = false;
-  CoefficientExtractor extractor;
-
-  for (const auto &index : indices) {
-    int64_t shape_stride = 1;
-    for (int i = static_cast<int>(index.size()) - 1; i >= 0; i--) {
-      int coefficient = extractor.ExtractCoefficient(index[i], stride_var);
-      if (extractor.visited_var) {
-        find = true;
-        min_stride = std::min(min_stride, std::abs(coefficient) * shape_stride);
-        break;
-      }
-      shape_stride *= shape[i];
-    }
-  }
-
-  return find ? 
min_stride : 0; -} - -// Compute touched bytes and cache lines for accesses to a buffer -void ComputeRegion( - const std::vector > &indices, - arith::Analyzer* ana, - std::vector* region) { - region->clear(); - - if (indices.empty()) { - return; - } - - region->reserve(indices[0].size()); - - if (indices.size() == 1) { - for (const auto& index : indices[0]) { - ConstIntBound bound = ana->const_int_bound(index); - region->push_back(bound->max_value - bound->min_value + 1); - } - } else { - // future(lmzheng): implement a more accurate IntSet? - for (size_t i = 0; i < indices[0].size(); ++i) { - int64_t minimum = ConstIntBound::kPosInf, maximum = ConstIntBound::kNegInf; - for (size_t j = 0; j < indices.size(); ++j) { - ConstIntBound bound = ana->const_int_bound(indices[j][i]); - - minimum = std::min(minimum, bound->min_value); - maximum = std::max(maximum, bound->max_value); - } - region->push_back(maximum - minimum + 1); - } - } -} - -// Compute reuse distance and reuse ratio for accesses to a buffer -// return values: reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct -std::tuple ComputeReuse( - const Buffer& buf, - const std::vector >& indices, - const std::vector& for_loop_stack, - const std::unordered_map > > >& for_touch_regions) { - float reuse_dis_iter = 1.0f; - float reuse_dis_bytes = -1.0f; - - for (int i = static_cast(for_loop_stack.size()) - 1; i >= 0; --i) { - const ForNode* cur_for = for_loop_stack[i]; - bool find = false; - - for (size_t j = 0; j < indices.size(); j++) { - for (size_t k = 0; k < indices[j].size(); k++) { - if (VarInExpr(cur_for->loop_var, indices[j][k])) { - find = true; - break; - } - } - if (find) { - break; - } - } - - int64_t extent = GetIntImm(for_loop_stack[i]->extent); - if (find) { - // accumulate/update reuse distance - reuse_dis_iter *= extent; - reuse_dis_bytes = 0.0f; - for (const auto& iter : for_touch_regions.at(cur_for)) { - for (const auto& access : iter.second) { - reuse_dis_bytes += std::get<1>(access) * std::get<2>(access); - } - } - } else { - // Have LoopMultipleRead reuse - if (reuse_dis_bytes < 0) { - // For the reuse in the innermost axis, the above code won't be executed. 
- // So we compute bytes here - reuse_dis_bytes = 0.0f; - for (const auto& iter : for_touch_regions.at(cur_for)) { - for (const auto& access : iter.second) { - reuse_dis_bytes += 1 * std::get<2>(access); - } - } - } - return std::make_tuple(kLoopMultipleRead, reuse_dis_iter, reuse_dis_bytes, extent); - } - - const BufferMap > >& buffer_map - = for_touch_regions.at(cur_for); - - int serial_reuse = static_cast(buffer_map.at(buf).size()) - 1; - if (serial_reuse > 0) { - int64_t extent = GetIntImm(cur_for->extent); - - // Have SerialMultipleReadWrite reuse - reuse_dis_iter = std::numeric_limits::max(); - for (const auto& acc_info : buffer_map.at(buf)) { - reuse_dis_iter = std::min(reuse_dis_iter, static_cast(std::get<1>(acc_info))); - } - - reuse_dis_bytes = 0.0f; - for (const auto& iter : for_touch_regions.at(cur_for)) { - for (const auto& access : iter.second) { - reuse_dis_bytes += std::get<1>(access) * std::get<2>(access); - } - } - - return std::make_tuple(kSerialMultipleReadWrite, - reuse_dis_iter / extent, reuse_dis_bytes / extent, serial_reuse); - } - } - - return std::make_tuple(kNoReuse, 0, 0, 0); -} - -// Extract features for every Provide statement -class PerStmtFeatureExtractor : public StmtExprVisitor { - public: - explicit PerStmtFeatureExtractor(int cache_line_size) : - cache_line_size_(cache_line_size) {} - - void VisitStmt_(const AttrStmtNode* node) final { - if (node->attr_key == tir::attr::thread_extent || - node->attr_key == tir::attr::virtual_thread) { - const Var& var = node->node.as()->var; - int extent = GetIntImm(node->value); - - int* plen = nullptr; - - const std::string& name = var.get()->name_hint; - if (node->attr_key == tir::attr::thread_extent) { - if (name == "blockIdx.x") { - plen = &blockIdx_x_len; - } else if (name == "blockIdx.y") { - plen = &blockIdx_y_len; - } else if (name == "blockIdx.z") { - plen = &blockIdx_z_len; - } else if (name == "threadIdx.x") { - plen = &threadIdx_x_len; - } else if (name == "threadIdx.y") { - plen = &threadIdx_y_len; - } else if (name == "threadIdx.z") { - plen = &threadIdx_z_len; - } else { - LOG(FATAL) << "invalid thread itervar " + name; - } - } else { - plen = &vthread_len; - } - - int extent_before = *plen; - if (node->attr_key == tir::attr::thread_extent) { - *plen = extent; - } else { - *plen *= extent; - } - - is_gpu = true; - - // make a fake for node for blockIdx.x or threadIdx.x - Stmt fake_for_node = For(var, 0, extent, ForType::Parallel, - DeviceAPI::None, node->body); - - outer_loop_prod *= extent; - for_loop_stack.push_back(fake_for_node.as()); - StmtExprVisitor::VisitStmt_(node); - for_loop_stack.pop_back(); - outer_loop_prod /= extent; - - *plen = extent_before; - } else if (node->attr_key == "pragma_auto_unroll_max_step") { - int value = GetIntImm(node->value); - - int16_t old_value = cur_auto_unroll_max_step; - cur_auto_unroll_max_step = value; - StmtExprVisitor::VisitStmt_(node); - cur_auto_unroll_max_step = old_value; - } else { - StmtExprVisitor::VisitStmt_(node); - } - } - - void VisitStmt_(const ForNode* node) final { - int64_t loop_extent = GetIntImm(node->extent); - - if (node->for_type == ForType::Vectorized) { - vec_for_stack.push_back(node); - } else if (node->for_type == ForType::Unrolled) { - unroll_for_stack.push_back(node); - } else if (node->for_type == ForType::Parallel) { - parallel_for_stack.push_back(node); - } - - outer_loop_prod *= loop_extent; - for_loop_stack.push_back(node); - StmtExprVisitor::VisitStmt_(node); - for_loop_stack.pop_back(); - outer_loop_prod /= loop_extent; - - if 
(node->for_type == ForType::Vectorized) { - vec_for_stack.pop_back(); - } else if (node->for_type == ForType::Unrolled) { - unroll_for_stack.pop_back(); - } else if (node->for_type == ForType::Parallel) { - parallel_for_stack.pop_back(); - } - } - - void VisitStmt_(const BufferStoreNode* node) final { - FeatureSet &fea = buffer_features[node->buffer]; - - // compute feature - MathOpCounter mathops; - mathops(node->value); - fea.float_mad = outer_loop_prod * mathops.float_mad; - fea.float_addsub = outer_loop_prod * mathops.float_addsub; - fea.float_mul = outer_loop_prod * mathops.float_mul; - fea.float_divmod = outer_loop_prod * mathops.float_divmod; - fea.float_cmp = outer_loop_prod * mathops.float_cmp; - fea.float_math_func = outer_loop_prod * mathops.float_math_func; - fea.float_other_func = outer_loop_prod * mathops.float_other_func; - fea.int_mad = outer_loop_prod * mathops.int_mad; - fea.int_addsub = outer_loop_prod * mathops.int_addsub; - fea.int_mul = outer_loop_prod * mathops.int_mul; - fea.int_divmod = outer_loop_prod * mathops.int_divmod; - fea.int_math_func = outer_loop_prod * mathops.int_math_func; - fea.int_cmp = outer_loop_prod * mathops.int_cmp; - fea.int_other_func = outer_loop_prod * mathops.int_other_func; - fea.bool_op = outer_loop_prod * mathops.bool_op; - fea.select_op = outer_loop_prod * mathops.select_op; - - fea.outer_prod = outer_loop_prod; - fea.num_loops = for_loop_stack.size(); - fea.auto_unroll_max_step = cur_auto_unroll_max_step; - fea.vec_len = fea.unroll_len = fea.parallel_len = 0.0f; - fea.vec_type = fea.unroll_type = fea.parallel_type = kPosNone; - - fea.vec_num = vec_for_stack.size(); - if (!vec_for_stack.empty()) { - fea.vec_len = GetIntImm(vec_for_stack.back()->extent); - fea.vec_prod = 1.0; - for (const ForNode* pfor : vec_for_stack) { - fea.vec_prod *= GetIntImm(pfor->extent); - } - fea.vec_type = kPosMixed; - // todo(lmzheng): this feature requires operation (tvm.compute) information - // GetAnnotationPosEncoding(vec_for_stack.back()->loop_var, - // node->args, pcompute->axis, pcompute->reduce_axis); - } - - fea.unroll_num = unroll_for_stack.size(); - if (!unroll_for_stack.empty()) { - fea.unroll_len = GetIntImm(unroll_for_stack.back()->extent); - fea.unroll_prod = 1.0; - for (const ForNode* pfor : unroll_for_stack) { - fea.unroll_prod *= GetIntImm(pfor->extent); - } - fea.unroll_type = kPosMixed; - // GetAnnotationPosEncoding(unroll_for_stack.back()->loop_var, - // node->args, pcompute->axis, pcompute->reduce_axis); - } - - fea.parallel_num = parallel_for_stack.size(); - if (!parallel_for_stack.empty()) { - fea.parallel_len = GetIntImm(parallel_for_stack.back()->extent); - fea.parallel_prod = 1.0; - for (const ForNode* pfor : parallel_for_stack) { - fea.parallel_prod *= GetIntImm(pfor->extent); - } - fea.parallel_type = kPosMixed; - // GetAnnotationPosEncoding(parallel_for_stack.back()->loop_var, - // node->args, pcompute->axis, pcompute->reduce_axis); - } - - // GPU threads - fea.is_gpu = is_gpu; - fea.blockIdx_x_len = blockIdx_x_len; - fea.blockIdx_y_len = blockIdx_y_len; - fea.blockIdx_z_len = blockIdx_z_len; - fea.threadIdx_x_len = threadIdx_x_len; - fea.threadIdx_y_len = threadIdx_y_len; - fea.threadIdx_z_len = threadIdx_z_len; - fea.vthread_len = vthread_len; - - // Extract all buffer access - std::vector acc_feas; - BufferAccessExtractor buf_extractor; - buf_extractor.InsertAccess(node->buffer, kWrite, node->indices); - buf_extractor.ExtractReads(node->value); - - // Compute touched region for all outer loops - Analyzer ana; - for (auto x : 
for_loop_stack) { - ana.Bind(x->loop_var, Range::make_by_min_extent(x->min, 1), true); - } - - std::vector mem_bytes_list; - std::vector compute_ops_list; - - mem_bytes_list.reserve(for_loop_stack.size()); - compute_ops_list.reserve(for_loop_stack.size()); - - int cur_compute_ops = mathops.float_mad + mathops.float_addsub + mathops.float_mul + - mathops.float_divmod + mathops.float_cmp + - mathops.float_math_func + mathops.float_other_func; - - std::vector tmp_region; - for (int i = static_cast(for_loop_stack.size()) - 1; i >= 0; i--) { - const ForNode* p_for = for_loop_stack[i]; - - ana.Bind(p_for->loop_var, - Range::make_by_min_extent(for_loop_stack[i]->min, for_loop_stack[i]->extent), true); - - // Note, here we do overwrite. - // So if there are multiple Provides, the last one will overwrite the first few. - // e.g. The update part in gemm will overwrite the init part. - BufferMap > >& - buffer_regions_map = for_touch_regions[p_for]; - - int64_t mem_bytes = 0; - for (const auto &x : buf_extractor.buf_accesses) { - const Buffer& t = x.first; - const BufferAccess& acc = x.second; - - ComputeRegion(acc.indices, &ana, &tmp_region); - int64_t touched_size = ElementProduct(tmp_region); - buffer_regions_map[t].push_back(std::make_tuple(acc.acc_type, - touched_size, t->dtype.bytes())); - mem_bytes += touched_size * t->dtype.bytes(); - } - - mem_bytes_list.push_back(std::log2(mem_bytes)); - cur_compute_ops *= GetIntImm(for_loop_stack[i]->extent); - compute_ops_list.push_back(std::log2(cur_compute_ops)); - } - - // Compute arithmetic intensity curve (y axis : arithmetic intensity, x axis : flops). - // We use piecewise linear interpolation to fit this curve. - int pt = 0; - if (cur_compute_ops <= 0 || compute_ops_list.empty()) { - std::fill(fea.arith_intensity_curve, - fea.arith_intensity_curve + ARITH_INTENSITY_CURVE_SAMPLE_N, 0.0); - } else { - for (size_t i = 0; i < ARITH_INTENSITY_CURVE_SAMPLE_N; ++i) { - float cur_compute_ops = compute_ops_list.back() * (i+1) / ARITH_INTENSITY_CURVE_SAMPLE_N; - while (compute_ops_list[pt] < cur_compute_ops - 1e-4) { - pt++; - } - CHECK_LT(pt, compute_ops_list.size()); - - float value; - if (pt == 0) { - value = compute_ops_list[pt] / mem_bytes_list[pt]; - } else { - float base = compute_ops_list[pt-1] / mem_bytes_list[pt-1]; - float slope = (compute_ops_list[pt] / mem_bytes_list[pt] - - compute_ops_list[pt-1] / mem_bytes_list[pt-1]) / - (compute_ops_list[pt] - compute_ops_list[pt-1]); - value = base + slope * (cur_compute_ops - compute_ops_list[pt-1]); - } - fea.arith_intensity_curve[i] = value; - } - } - - // Compute buffer access feature - for (const auto &x : buf_extractor.buf_accesses) { - const Buffer& t = x.first; - const BufferAccess& acc = x.second; - - std::vector int_shape; - for (const auto& dim : t->shape) { - int_shape.push_back(GetIntImm(dim)); - } - - size_t ele_bytes = t->dtype.bytes(); - - // calculate bytes - float bytes = outer_loop_prod * ele_bytes; - float unique_bytes; - - // calculate cache lines - int64_t stride; - float lines; - float unique_lines; - - if (for_loop_stack.empty()) { - unique_bytes = ele_bytes; - stride = 0; - lines = 1.0f; - unique_lines = 1.0f; - } else { - unique_bytes = std::get<1>(for_touch_regions[for_loop_stack.front()][t].front()) - * ele_bytes; - - stride = 0; - int64_t reduce_ratio = 1; - - int i; - for (i = static_cast(for_loop_stack.size()) - 1; i >= 0; i--) { - stride = ComputeStride(acc.indices, int_shape, for_loop_stack[i]->loop_var.get()); - if (stride != 0) { - break; - } - reduce_ratio *= 
GetIntImm(for_loop_stack.back()->extent); - } - - lines = outer_loop_prod / reduce_ratio * - std::min(1.0f, 1.0f * stride * ele_bytes / cache_line_size_); - lines = std::max(lines, 1.0f); - - // convert `stride` back to the stride of the innermost iterator - stride = (i == static_cast(for_loop_stack.size()) - 1 ? stride : 0); - - float n_continuous = ele_bytes; - for (int i = static_cast(tmp_region.size()) - 1; i >= 0; i--) { - if (tmp_region[i] == int_shape[i]) { - n_continuous *= tmp_region[i]; - break; - } - } - unique_lines = unique_bytes / std::min(n_continuous, - static_cast(cache_line_size_)); - unique_lines = std::max(unique_lines, 1.0f); - } - - ReuseType reuse_type; - float reuse_dis_iter, reuse_dis_bytes, reuse_ct; - std::tie(reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct) = - ComputeReuse(t, acc.indices, for_loop_stack, for_touch_regions); - - acc_feas.emplace_back(); - BufferAccessFeature& acc_fea = acc_feas.back(); - - acc_fea.buffer_name = t->name; - acc_fea.acc_type = acc.acc_type; - acc_fea.stride = stride; - acc_fea.bytes = bytes; - acc_fea.unique_bytes = unique_bytes; - acc_fea.lines = lines; - acc_fea.unique_lines = unique_lines; - acc_fea.reuse_type = reuse_type; - acc_fea.reuse_dis_iter = reuse_dis_iter; - acc_fea.reuse_dis_bytes = reuse_dis_bytes; - acc_fea.reuse_ct = reuse_ct; - if (acc_fea.reuse_ct > 0.5) { - acc_fea.bytes_d_reuse_ct = bytes / reuse_ct; - acc_fea.unique_bytes_d_reuse_ct = unique_bytes / reuse_ct; - acc_fea.lines_d_reuse_ct = lines / reuse_ct; - acc_fea.unique_lines_d_reuse_ct = unique_lines / reuse_ct; - } else { - // no reuse, multiply by a magic number '2' - acc_fea.bytes_d_reuse_ct = bytes * 2; - acc_fea.unique_bytes_d_reuse_ct = unique_bytes * 2; - acc_fea.lines_d_reuse_ct = lines * 2; - acc_fea.unique_lines_d_reuse_ct = unique_lines* 2; - } - } - - fea.access_feas = acc_feas; - } - - void VisitStmt_(const BufferRealizeNode *node) final { - StmtExprVisitor::VisitStmt_(node); - - FeatureSet& fea = buffer_features[node->buffer]; - - float allocation_size = 1.0f; - for (const auto& x : node->bounds) { - allocation_size *= GetIntImm(x->extent); - } - // allocation feature - fea.alloc_size = allocation_size * node->buffer->dtype.bytes(); - fea.alloc_prod = allocation_size * outer_loop_prod; - fea.alloc_outer_prod = outer_loop_prod; - fea.alloc_inner_prod = fea.outer_prod / outer_loop_prod; - } - - float outer_loop_prod = 1.0f; - - std::vector for_loop_stack; - std::vector parallel_for_stack; - std::vector vec_for_stack; - std::vector unroll_for_stack; - - bool is_gpu; - int blockIdx_x_len{1}; - int blockIdx_y_len{1}; - int blockIdx_z_len{1}; - int threadIdx_x_len{1}; - int threadIdx_y_len{1}; - int threadIdx_z_len{1}; - int vthread_len{1}; - int16_t cur_auto_unroll_max_step{0}; - - BufferMap buffer_features; - - // for a loop, for all its touched buffers, for all different accesses to the buffers, - // its (access type, number of touched elements, number of bytes of single element) - std::unordered_map > > > for_touch_regions; - - private: - const int cache_line_size_ = 64; -}; - -// shifted log to incorporate the property that slog(0) = 0 -inline float slog(float x) { - return x < 0 ? -std::log2(-x+1) : std::log2(x+1); -} - -// Get features for all ir::Provide statements in a TVM program. 
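-// (Editor's note, illustrative: slog is the sign-preserving shifted log2, so
-//  slog(0) = 0, slog(7) = log2(8) = 3, and slog(-7) = -3. Nearly every raw
-//  count and byte quantity emitted below is passed through it to compress
-//  large magnitudes before reaching the cost model.)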
-// So we call it `PerStmt` feature -void GetPerStmtFeature(const Stmt& stmt, - int cache_line_size, - int max_n_bufs, - std::vector* ret) { - PerStmtFeatureExtractor extractor(cache_line_size); - extractor(stmt); - - ret->push_back(extractor.buffer_features.size()); - - for (const auto& x : extractor.buffer_features) { - const FeatureSet& fea_set = x.second; - - /***** compute feature *****/ - ret->push_back(slog(fea_set.float_mad)); - ret->push_back(slog(fea_set.float_addsub)); - ret->push_back(slog(fea_set.float_mul)); - ret->push_back(slog(fea_set.float_divmod)); - ret->push_back(slog(fea_set.float_cmp)); - ret->push_back(slog(fea_set.float_math_func)); - ret->push_back(slog(fea_set.float_other_func)); - ret->push_back(slog(fea_set.int_mad)); - ret->push_back(slog(fea_set.int_addsub)); - ret->push_back(slog(fea_set.int_mul)); - ret->push_back(slog(fea_set.int_divmod)); - ret->push_back(slog(fea_set.int_cmp)); - ret->push_back(slog(fea_set.int_math_func)); - ret->push_back(slog(fea_set.int_other_func)); - ret->push_back(slog(fea_set.bool_op)); - ret->push_back(slog(fea_set.select_op)); - - ret->push_back(slog(fea_set.vec_num)); - ret->push_back(slog(fea_set.vec_prod)); - ret->push_back(slog(fea_set.vec_len)); - for (int i = 0; i <= kPosMixed; i++) { - ret->push_back(i == fea_set.vec_type); - } - - ret->push_back(slog(fea_set.unroll_num)); - ret->push_back(slog(fea_set.unroll_prod)); - ret->push_back(slog(fea_set.unroll_len)); - for (int i = 0; i <= kPosMixed; i++) { - ret->push_back(i == fea_set.unroll_type); - } - - ret->push_back(slog(fea_set.parallel_num)); - ret->push_back(slog(fea_set.parallel_prod)); - ret->push_back(slog(fea_set.parallel_len)); - for (int i = 0; i <= kPosMixed; i++) { - ret->push_back(i == fea_set.parallel_type); - } - - ret->push_back(fea_set.is_gpu); - ret->push_back(slog(fea_set.blockIdx_x_len)); - ret->push_back(slog(fea_set.blockIdx_y_len)); - ret->push_back(slog(fea_set.blockIdx_z_len)); - ret->push_back(slog(fea_set.threadIdx_x_len)); - ret->push_back(slog(fea_set.threadIdx_y_len)); - ret->push_back(slog(fea_set.threadIdx_z_len)); - ret->push_back(slog(fea_set.vthread_len)); - - for (size_t i = 0; i < ARITH_INTENSITY_CURVE_SAMPLE_N; ++i) { - ret->push_back(fea_set.arith_intensity_curve[i]); - } - - /***** access feature *****/ - // sort according to pair (lines, bytes) - std::vector > buf_order_key; - for (const auto& acc_fea : fea_set.access_feas) { - buf_order_key.emplace_back(acc_fea.lines, acc_fea.bytes); - } - std::vector buf_order(buf_order_key.size()); - std::iota(buf_order.begin(), buf_order.end(), 0); - - auto cmp = [&buf_order_key](int l, int r) { - return buf_order_key[l].first > buf_order_key[r].first - || (buf_order_key[l].first == buf_order_key[r].first - && buf_order_key[l].second > buf_order_key[r].second); - }; - std::sort(buf_order.begin(), buf_order.end(), cmp); - int n_bufs = std::min(max_n_bufs, static_cast(buf_order.size())); - buf_order.resize(n_bufs); - - for (int idx : buf_order) { - const auto& acc_fea = fea_set.access_feas[idx]; - for (int j = 0; j <= kReadWrite; ++j) { - ret->push_back(j == acc_fea.acc_type); - } - ret->push_back(slog(acc_fea.bytes)); - ret->push_back(slog(acc_fea.unique_bytes)); - ret->push_back(slog(acc_fea.lines)); - ret->push_back(slog(acc_fea.unique_lines)); - for (int j = 0; j <= kNoReuse; ++j) { - ret->push_back(acc_fea.reuse_type == j); - } - ret->push_back(slog(acc_fea.reuse_dis_iter)); - ret->push_back(slog(acc_fea.reuse_dis_bytes)); - ret->push_back(slog(acc_fea.reuse_ct)); - 
ret->push_back(slog(acc_fea.bytes_d_reuse_ct)); - ret->push_back(slog(acc_fea.unique_bytes_d_reuse_ct)); - ret->push_back(slog(acc_fea.lines_d_reuse_ct)); - ret->push_back(slog(acc_fea.unique_lines_d_reuse_ct)); - ret->push_back(slog(acc_fea.stride)); - } - // - fill padding - for (int i = 0; i < max_n_bufs - n_bufs; ++i) { - for (int j = 0; j <= kReadWrite; ++j) { // 3 - ret->push_back(0.0f); - } - ret->push_back(0.0f); ret->push_back(0.0f); ret->push_back(0.0f); ret->push_back(0.0f); - for (int j = 0; j <= kNoReuse; ++j) { // 3 - ret->push_back(0.0f); - } - ret->push_back(0.0f); ret->push_back(0.0f); ret->push_back(0.0f); ret->push_back(0.0f); - ret->push_back(0.0f); ret->push_back(0.0f); ret->push_back(0.0f); ret->push_back(0.0f); - } - - /***** allocation feature *****/ - ret->push_back(slog(fea_set.alloc_size)); - ret->push_back(slog(fea_set.alloc_prod)); - ret->push_back(slog(fea_set.alloc_outer_prod)); - ret->push_back(slog(fea_set.alloc_inner_prod)); - - /***** overall feature *****/ - ret->push_back(slog(fea_set.outer_prod)); - ret->push_back(slog(fea_set.num_loops)); - ret->push_back(slog(fea_set.auto_unroll_max_step)); - } -} - - -/* \brief Get the name of every element in the feature vector. Use this for debug and inspection */ -void GetPerStmtFeatureName(int max_n_bufs, std::vector *ret) { - /***** compute feature *****/ - ret->push_back(("float_mad")); - ret->push_back(("float_addsub")); - ret->push_back(("float_mul")); - ret->push_back(("float_divmod")); - ret->push_back(("float_cmp")); - ret->push_back(("float_mathfunc")); - ret->push_back(("float_otherfunc")); - ret->push_back(("int_mad")); - ret->push_back(("int_addsub")); - ret->push_back(("int_mul")); - ret->push_back(("int_divmod")); - ret->push_back(("int_cmp")); - ret->push_back(("int_mathfunc")); - ret->push_back(("int_otherfunc")); - ret->push_back(("bool_op")); - ret->push_back(("select_op")); - ret->push_back(("vec_num")); - ret->push_back(("vec_prod")); - ret->push_back(("vec_len")); - ret->push_back(("vec_type.kPosNone")); - ret->push_back(("vec_type.kPosInnerSpatial")); - ret->push_back(("vec_type.kPosMiddleSpatial")); - ret->push_back(("vec_type.kPosOuterSpatial")); - ret->push_back(("vec_type.kPosInnerReduce")); - ret->push_back(("vec_type.kPosMiddleReduce")); - ret->push_back(("vec_type.kPosOuterReduce")); - ret->push_back(("vec_type.kPosMixed")); - ret->push_back(("unroll_num")); - ret->push_back(("unroll_prod")); - ret->push_back(("unroll_len")); - ret->push_back(("unroll_type.kPosNone")); - ret->push_back(("unroll_type.kPosInnerSpatial")); - ret->push_back(("unroll_type.kPosMiddleSpatial")); - ret->push_back(("unroll_type.kPosOuterSpatial")); - ret->push_back(("unroll_type.kPosInnerReduce")); - ret->push_back(("unroll_type.kPosMiddleReduce")); - ret->push_back(("unroll_type.kPosOuterReduce")); - ret->push_back(("unroll_type.kPosMixed")); - ret->push_back(("parallel_num")); - ret->push_back(("parallel_prod")); - ret->push_back(("parallel_len")); - ret->push_back(("parallel_type.kPosNone")); - ret->push_back(("parallel_type.kPosInnerSpatial")); - ret->push_back(("parallel_type.kPosMiddleSpatial")); - ret->push_back(("parallel_type.kPosOuterSpatial")); - ret->push_back(("parallel_type.kPosInnerReduce")); - ret->push_back(("parallel_type.kPosMiddleReduce")); - ret->push_back(("parallel_type.kPosOuterReduce")); - ret->push_back(("parallel_type.kPosMixed")); - ret->push_back(("is_gpu")); - ret->push_back(("blockIdx_x_len")); - ret->push_back(("blockIdx_y_len")); - ret->push_back(("blockIdx_z_len")); - 
ret->push_back(("threadIdx_x_len")); - ret->push_back(("threadIdx_y_len")); - ret->push_back(("threadIdx_z_len")); - ret->push_back(("vthread_len")); - for (size_t i = 0; i < ARITH_INTENSITY_CURVE_SAMPLE_N; ++i) { - ret->push_back(("arith_intensity_curve_" + std::to_string(i))); - } - // section total: 55 + ARITH_INTENSITY_CURVE_SAMPLE_N = 65 - - /***** access feature *****/ - for (size_t i = 0; i < static_cast(max_n_bufs); ++i) { - std::string prefix = "B" + std::to_string(i) + "."; - ret->push_back((prefix + "acc_type.kRead")); - ret->push_back((prefix + "acc_type.kWrite")); - ret->push_back((prefix + "acc_type.kReadWrite")); - ret->push_back((prefix + "bytes")); - ret->push_back((prefix + "unique_bytes")); - ret->push_back((prefix + "lines")); - ret->push_back((prefix + "unique_lines")); - ret->push_back((prefix + "reuse_type.kLoopMultipleRead")); - ret->push_back((prefix + "reuse_type.kSerialMultipleReadWrite")); - ret->push_back((prefix + "reuse_type.kNoReuse")); - ret->push_back((prefix + "reuse_dis_iter")); - ret->push_back((prefix + "reuse_dis_bytes")); - ret->push_back((prefix + "reuse_ct")); - ret->push_back((prefix + "bytes_d_reuse_ct")); - ret->push_back((prefix + "unique_bytes_d_reuse_ct")); - ret->push_back((prefix + "lines_d_reuse_ct")); - ret->push_back((prefix + "unique_lines_d_reuse_ct")); - ret->push_back((prefix + "stride")); - } - // section total : max_n_bufs * 18 - - /***** allocation feature *****/ - ret->push_back(("alloc_size")); - ret->push_back(("alloc_prod")); - ret->push_back(("alloc_outer_prod")); - ret->push_back(("alloc_inner_prod")); - // section total : 4 - - /***** overall feature *****/ - ret->push_back(("outer_prod")); - ret->push_back(("num_loops")); - ret->push_back(("auto_unroll_max_step")); - // section total : 2 -} - -void GetPerStmtFeaturesWorkerFunc(const SearchTask& task, const State& state, - int max_n_bufs, std::vector* feature, std::atomic* error_ct) { - te::Schedule sch; - Array tensors; - - std::tie(sch, tensors) = task->compute_dag.ApplySteps(state->transform_steps); - sch = sch.normalize(); - auto bounds = te::InferBound(sch); - - try { - auto stmt = te::ScheduleOps(sch, bounds, false); - Map out_binds; Array out_arg_list; - bool compact = te::VerifyCompactBuffer(stmt); - const std::string& name = "main"; - GlobalVar global_var(name); - - // Copied from driver_api.cc::lower - auto pass_ctx = tvm::transform::PassContext::Current(); - GetBinds(tensors, compact, std::unordered_map(), - &out_binds, &out_arg_list); - tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); - f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); - - bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); - bool disable_vectorize = - pass_ctx->GetConfig("tir.disable_vectorize", Bool(false)).value(); - bool instrument_bound_checkers = - pass_ctx->GetConfig("tir.instrument_bound_checkers", Bool(false)).value(); - - if (noalias) { - f = WithAttr(std::move(f), "tir.noalias", Bool(true)); - } - auto mod = IRModule(Map({{global_var, f}})); - - if (task->target->device_type == kDLGPU) { - auto pass_list = Array(); - // Phase 0 - pass_list.push_back(tir::transform::InjectPrefetch()); - pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); - // Phase 1 - pass_list.push_back(tir::transform::NarrowDataType(32)); - pass_list.push_back(tir::transform::Simplify()); - pass_list.push_back(tir::transform::VectorizeLoop(disable_vectorize)); - 
pass_list.push_back(tir::transform::InjectVirtualThread()); - pass_list.push_back(tir::transform::StorageRewrite()); - pass_list.push_back(tir::transform::Simplify()); - tvm::Map gpu_params { - {"max_shared_memory_per_block", - task->hardware_params->max_shared_memory_per_block}, - {"max_local_memory_per_block", - task->hardware_params->max_registers_per_block}, - {"max_threads_per_block", - task->hardware_params->max_threads_per_block}, - {"max_vector_bytes", - task->hardware_params->vector_unit_bytes} - }; - pass_list.push_back(tir::transform::VerifyGPUCode(gpu_params)); - const auto& optimize = tir::transform::Sequential(pass_list); - optimize(mod); - } - const auto& optimize = tir::transform::Sequential( - Array{tir::transform::Simplify()}); - mod = optimize(std::move(mod)); - const auto& it = mod->functions.find(global_var); - CHECK(it != mod->functions.end()); - const auto& prim_func = (*it).second.as(); - GetPerStmtFeature(prim_func->body, - task->hardware_params->cache_line_bytes, - max_n_bufs, feature); - } catch (dmlc::Error &e) { - (*error_ct)++; - } -} - -void GetPerStmtFeaturesFromStates(const Array& states, - const SearchTask& task, - int skip_first_n_feature_extraction, - int max_n_bufs, - std::vector >* features) { - // extract features - features->assign(states.size(), std::vector()); - - std::atomic error_ct(0); - - ThreadPool& pool = ThreadPool::Global(); - pool.BeginBatch(static_cast(states.size()) - skip_first_n_feature_extraction); - for (size_t i = skip_first_n_feature_extraction; i < states.size(); ++i) { - pool.Enqueue(GetPerStmtFeaturesWorkerFunc, task, states[i], - max_n_bufs, &(*features)[i], &error_ct); - // GetPerStmtFeaturesWorkerFunc(task, states[i], - // max_n_bufs, &(*features)[i], &error_ct); - } - pool.WaitBatch(); - - if (error_ct > 0) { - std::cerr << "Encountered " << error_ct - << " errors during feature extraction. Ignored." << std::endl; - } -} - - -void GetPerStmtFeaturesFromStates(const Array& states, - const std::vector& tasks, - int skip_first_n_feature_extraction, - int max_n_bufs, - std::vector >* features) { - // extract features - features->assign(states.size(), std::vector()); - - std::atomic error_ct(0); - - ThreadPool& pool = ThreadPool::Global(); - pool.BeginBatch(static_cast(states.size()) - skip_first_n_feature_extraction); - for (size_t i = skip_first_n_feature_extraction; i < states.size(); ++i) { - pool.Enqueue(GetPerStmtFeaturesWorkerFunc, tasks[i], states[i], - max_n_bufs, &(*features)[i], &error_ct); - } - pool.WaitBatch(); - - if (error_ct > 0) { - std::cerr << "Encountered " << error_ct - << " errors during feature extraction. Ignored." 
<< std::endl; - } -} - -void GetPerStmtFeaturesFromFile(const std::string& filename, - int n_lines, - int max_n_bufs, - std::vector >* features, - std::vector* normalized_throughputs, - std::vector* task_ids) { - Array states; - // ArrayNode* pstates = states.CopyOnWrite(); - std::vector tasks; - - normalized_throughputs->clear(); - task_ids->clear(); - - // (workload_key, target) -> (search_task, task_id) - std::unordered_map, std::pair> task_cache; - // task_id -> min_cost - std::vector min_costs; - - // read from file - LogReader reader = LogReader(filename); - auto cur_inp = make_object(); - auto cur_res = make_object(); - while (reader->ReadNext(cur_inp.get(), cur_res.get())) { - float cost = static_cast(FloatArrayMean(cur_res->costs)); - const std::string& workload_key = cur_inp->task->workload_key; - - SearchTask task; - size_t task_id; - std::pair key(workload_key, cur_inp->task->target->str()); - auto find_res = task_cache.find(key); - if (find_res == task_cache.end()) { - // rebuild task - task = SearchTask(ComputeDAG(workload_key), workload_key, - cur_inp->task->target, cur_inp->task->target_host, - cur_inp->task->hardware_params); - task_id = task_cache.size(); - - // compute min cost for each task - task_cache.insert(std::make_pair(key, std::make_pair(task, task_id))); - min_costs.push_back(cost); - } else { - std::tie(task, task_id) = find_res->second; - min_costs[task_id] = std::min(min_costs[task_id], cost); - } - - tasks.push_back(std::move(task)); - task_ids->push_back(task_id); - // pstates->data.push_back(cur_inp->state); - states.push_back(cur_inp->state); - normalized_throughputs->push_back(cost); - - if (n_lines > 0 && static_cast(states.size()) >= n_lines) { - break; - } - } - - for (size_t i = 0; i < normalized_throughputs->size(); ++i) { - (*normalized_throughputs)[i] = min_costs[(*task_ids)[i]] / (*normalized_throughputs)[i]; - } - - GetPerStmtFeaturesFromStates(states, tasks, 0, max_n_bufs, features); -} - -void GetPerStmtFeaturesFromMeasurePairs(const Array& inputs, - const Array& results, - int skip_first_n_feature_extraction, - int max_n_bufs, - std::vector >* features, - std::vector* normalized_throughputs, - std::vector* task_ids) { - Array states; - // ArrayNode* pstates = states.CopyOnWrite(); - std::vector tasks; - - normalized_throughputs->clear(); - task_ids->clear(); - - // (workload_key, target) -> (search_task, task_id) - std::unordered_map, std::pair> task_cache; - // task_id -> min_cost - std::vector min_costs; - - tasks.reserve(inputs.size()); - normalized_throughputs->reserve(inputs.size()); - task_ids->reserve(inputs.size()); - for (size_t i = 0; i < inputs.size(); ++i) { - float cost = static_cast(FloatArrayMean(results[i]->costs)); - const std::string& workload_key = inputs[i]->task->workload_key; - SearchTask task; - - size_t task_id; - std::pair key(workload_key, inputs[i]->task->target->str()); - auto find_res = task_cache.find(key); - if (find_res == task_cache.end()) { - if (inputs[i]->task->compute_dag.defined()) { // the measure input is complete - task = inputs[i]->task; - } else { // the measure input is incomplete - // rebuild task for incomplete measure pairs read from file - task = SearchTask(ComputeDAG(workload_key), workload_key, - inputs[i]->task->target, inputs[i]->task->target_host, - inputs[i]->task->hardware_params); - } - task_id = task_cache.size(); - - // compute min cost for each task - task_cache.insert(std::make_pair(key, std::make_pair(task, task_id))); - min_costs.push_back(cost); - } else { - std::tie(task, 
task_id) = find_res->second;
-      min_costs[task_id] = std::min(min_costs[task_id], cost);
-    }
-
-    tasks.push_back(std::move(task));
-    task_ids->push_back(task_id);
-    // pstates->data.push_back(inputs[i]->state);
-    states.push_back(inputs[i]->state);
-    normalized_throughputs->push_back(cost);
-  }
-
-  for (size_t i = 0; i < normalized_throughputs->size(); ++i) {
-    (*normalized_throughputs)[i] = min_costs[(*task_ids)[i]] / (*normalized_throughputs)[i];
-  }
-
-  GetPerStmtFeaturesFromStates(states, tasks, skip_first_n_feature_extraction,
-                               max_n_bufs, features);
-}
-
-TVMByteArray SerializeFeatures(std::vector<std::vector<float>>&& features,
-                               std::vector<float>&& normalized_throughputs,
-                               std::vector<int>&& task_ids,
-                               std::vector<char>* out_data) {
-  size_t total_bytes = 0;
-  std::vector<int> size_vector;
-
-  int n = static_cast<int>(features.size());
-
-  // serialize sizes
-  size_t size_vector_size = 1 + n + 2;
-  total_bytes += size_vector_size * sizeof(int);
-
-  size_vector.reserve(size_vector_size);
-  size_vector.push_back(n);
-  for (const auto& x : features) {
-    size_vector.push_back(static_cast<int>(x.size()));
-    total_bytes += sizeof(float) * x.size();
-  }
-  size_vector.push_back(static_cast<int>(normalized_throughputs.size()));
-  total_bytes += sizeof(float) * normalized_throughputs.size();
-  size_vector.push_back(static_cast<int>(task_ids.size()));
-  total_bytes += sizeof(int) * task_ids.size();
-
-  CHECK_EQ(size_vector.size(), size_vector_size);
-
-  // allocate memory (resize, not reserve: we write through data() below)
-  out_data->resize(total_bytes);
-  char* ptr = out_data->data();
-
-  // serialize size_vector
-  memmove(ptr, reinterpret_cast<char*>(size_vector.data()), size_vector.size() * sizeof(int));
-  ptr += size_vector.size() * sizeof(int);
-
-  // serialize features
-  for (auto& x : features) {
-    memmove(ptr, x.data(), sizeof(float) * x.size());
-    ptr += sizeof(float) * x.size();
-    x.clear();
-  }
-
-  // serialize normalized_throughputs (these are floats, not ints)
-  memmove(ptr, reinterpret_cast<char*>(normalized_throughputs.data()),
-          normalized_throughputs.size() * sizeof(float));
-  ptr += normalized_throughputs.size() * sizeof(float);
-
-  // serialize task_ids
-  memmove(ptr, reinterpret_cast<char*>(task_ids.data()), task_ids.size() * sizeof(int));
-  ptr += task_ids.size() * sizeof(int);
-
-  CHECK_EQ(ptr - out_data->data(), total_bytes);
-
-  return TVMByteArray{out_data->data(), total_bytes};
-}
-
-
-TVM_REGISTER_GLOBAL("ansor.GetPerStmtFeaturesFromFile")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-  std::string filename = args[0];
-  int n_lines = args[1];
-  int max_n_bufs = args[2];
-
-  std::vector<std::vector<float>> features;
-  std::vector<float> normalized_throughputs;
-  std::vector<int> task_ids;
-
-  GetPerStmtFeaturesFromFile(filename, n_lines, max_n_bufs,
-                             &features, &normalized_throughputs, &task_ids);
-
-  // serialization format for n records:
-  //
-  // int   n;
-  // int[n+2]  sizes
-  //
-  // float[sizes[0]]    feature for record 1
-  // float[sizes[1]]    feature for record 2
-  // ... feature for record i...
- // float[sizes[n-1]] feature for record n - // - // float[sizes[n]] normalized throughput for n records - // int[sizes[n+1]] task id for n records - - std::vector byte_data; - *ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs), - std::move(task_ids), &byte_data); -}); - -TVM_REGISTER_GLOBAL("ansor.GetPerStmtFeaturesFromMeasurePairs") -.set_body([](TVMArgs args, TVMRetValue *ret) { - Array inputs = args[0]; - Array results = args[1]; - int skip_first_n_feature_extraction = args[2]; - int max_n_bufs = args[3]; - - std::vector > features; - std::vector normalized_throughputs; - std::vector task_ids; - - GetPerStmtFeaturesFromMeasurePairs(inputs, results, skip_first_n_feature_extraction, max_n_bufs, - &features, &normalized_throughputs, &task_ids); - - // serialization format for n records: - // - // int n; - // int[n+2] sizes - // - // float[sizes[0]] feature for record 1 - // float[sizes[1]] feature for record 2 - // ... feature for record i... - // float[sizes[n-1]] feature for record n - // - // float[sizes[n]] normalized throughput for n records - // int[sizes[n+1]] task id for n records - - std::vector byte_data; - *ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs), - std::move(task_ids), &byte_data); -}); - -TVM_REGISTER_GLOBAL("ansor.GetPerStmtFeaturesFromStates") -.set_body([](TVMArgs args, TVMRetValue *ret) { - Array states = args[0]; - SearchTask task = args[1]; - int max_n_bufs = args[2]; - - std::vector > features; - std::vector normalized_throughputs; - std::vector task_ids; - - GetPerStmtFeaturesFromStates(states, task, 0, max_n_bufs, &features); - - // serialization format for n records: - // - // int n; - // int[n+2] sizes - // - // float[sizes[0]] feature for record 1 - // float[sizes[1]] feature for record 2 - // ... feature for record i... - // float[sizes[n-1]] feature for record n - // - // float[sizes[n]] normalized throughput for n records - // int[sizes[n+1]] task id for n records - - std::vector byte_data; - *ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs), - std::move(task_ids), &byte_data); -}); - -TVM_REGISTER_GLOBAL("ansor.GetPerStmtFeatureNames") - .set_body([](TVMArgs args, TVMRetValue *ret) { - int max_n_bufs = args[0]; - std::vector names; - - GetPerStmtFeatureName(max_n_bufs, &names); - - Array arr; - for (const auto& x : names) { - arr.push_back(x); - } - *ret = arr; -}); - - -} // namespace ansor -} // namespace tvm diff --git a/src/ansor/feature.h b/src/ansor/feature.h deleted file mode 100644 index e507149643e2..000000000000 --- a/src/ansor/feature.h +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! 
- * \file ansor/feature.h - * \brief Feature extraction for the cost model - */ - -#ifndef TVM_ANSOR_FEATURE_H_ -#define TVM_ANSOR_FEATURE_H_ - -#include -#include -#include "compute_dag.h" -#include "measure.h" - -namespace tvm { -namespace ansor { - -/*! \brief Get PerStmt feature from a tvm IR stmt */ -void GetPerStmtFeature(const Stmt& stmt, - int cache_line_size, - int max_n_bufs, - std::vector* ret); - -/* \brief Get the name of every element in the feature vector. Use this for debug and inspection */ -void GetPerStmtFeatureName(int max_n_bufs, std::vector *ret); - - -/*! \brief Get PerStmt feature from states and the same task */ -void GetPerStmtFeaturesFromStates(const Array& states, - const SearchTask& task, - int skip_first_n_feature_extraction, - int max_n_bufs, - std::vector >* features); - -/*! \brief Get PerStmt feature from states and different tasks */ -void GetPerStmtFeaturesFromStates(const Array& states, - const std::vector& tasks, - int skip_first_n_feature_extraction, - int max_n_bufs, - std::vector >* features); - -/*! \brief Get PerStmt feature from a log file */ -void GetPerStmtFeaturesFromFile(const std::string& filename, - int n_lines, - int max_n_bufs, - std::vector >* features, - std::vector* normalized_throughputs, - std::vector* task_ids); - -/*! \brief Get PerStmt feature from measure pairs */ -void GetPerStmtFeaturesFromMeasurePairs(const Array& inputs, - const Array& results, - int skip_first_n_feature_extraction, - int max_n_bufs, - std::vector >* features, - std::vector* normalized_throughputs, - std::vector* task_ids); - -} // namespace ansor -} // namespace tvm - -#endif // TVM_ANSOR_FEATURE_H_ diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index 010e5f3dc221..787e4256a181 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -153,28 +153,6 @@ std::vector State::split(int stage_id, const Iterator& it, return DoSplitStep(step); } -std::vector State::follow_split(int stage_id, const Iterator& it, - int src_step_id, int n_split) { - const Stage& stage = operator->()->stages[stage_id]; - - FollowSplitStep step = FollowSplitStep( - stage_id, GetIndex(stage->iters, it), src_step_id, n_split); - CopyOnWrite()->transform_steps.push_back(step); - return DoFollowSplitStep(step); -} - -std::vector State::follow_fused_split( - int stage_id, const Iterator& it, const std::vector& src_step_ids, - int level, bool factor_or_nparts) { - const Stage& stage = operator->()->stages[stage_id]; - - FollowFusedSplitStep step = - FollowFusedSplitStep(stage_id, GetIndex(stage->iters, it), - src_step_ids, level, factor_or_nparts); - CopyOnWrite()->transform_steps.push_back(step); - return DoFollowFusedSplitStep(step); -} - Iterator State::fuse(int stage_id, const std::vector& iters) { const Stage& stage = operator->()->stages[stage_id]; std::vector indices; @@ -184,126 +162,6 @@ Iterator State::fuse(int stage_id, const std::vector& iters) { return DoFuseStep(step); } -Iterator State::vectorize(int stage_id, const Iterator& it) { - const Stage& stage = operator->()->stages[stage_id]; - AnnotationStep step = AnnotationStep( - stage_id, GetIndex(stage->iters, it), kVectorize); - CopyOnWrite()->transform_steps.push_back(step); - return DoAnnotationStep(step); -} - -Iterator State::parallel(int stage_id, const Iterator& it) { - const Stage& stage = operator->()->stages[stage_id]; - AnnotationStep step = - AnnotationStep(stage_id, GetIndex(stage->iters, it), kParallel); - CopyOnWrite()->transform_steps.push_back(step); - return DoAnnotationStep(step); -} 
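-// (Editor's note: a minimal usage sketch of the record-and-apply primitives
-//  in this section, assuming a State `state` obtained from a ComputeDAG; the
-//  stage/iterator indices are made up for illustration:
-//
-//    const Stage& s = state->stages[0];
-//    std::vector<Iterator> tiles = state.split(0, s->iters[0], {PrimExpr(8)});
-//    Iterator fused = state.fuse(0, {tiles[0], tiles[1]});
-//    state.parallel(0, fused);
-//
-//  After this, state->transform_steps holds a SplitStep, a FuseStep, and an
-//  AnnotationStep that can be replayed or serialized later.)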
-
-Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) {
-  const Stage& stage = operator->()->stages[stage_id];
-  AnnotationStep step =
-      AnnotationStep(stage_id, GetIndex(stage->iters, it), kUnroll);
-
-  // don't unroll if the extent is larger than max_unroll
-  if (max_unroll != -1 && it->range.defined()) {
-    if (auto imm = it->range->extent.as<IntImmNode>()) {
-      if (imm->value > max_unroll) {
-        return it;
-      }
-    }
-  }
-
-  CopyOnWrite()->transform_steps.push_back(step);
-  return DoAnnotationStep(step);
-}
-
-void State::compute_at(int stage_id, int target_stage_id,
-                       const Iterator& target_iter) {
-  const Stage& target_stage = operator->()->stages[target_stage_id];
-  ComputeAtStep step = ComputeAtStep(
-      stage_id, target_stage_id, GetIndex(target_stage->iters, target_iter));
-  CopyOnWrite()->transform_steps.push_back(step);
-  return DoComputeAtStep(step);
-}
-
-void State::compute_root(int stage_id) {
-  ComputeRootStep step = ComputeRootStep(stage_id);
-  CopyOnWrite()->transform_steps.push_back(step);
-  return DoComputeRootStep(step);
-}
-
-void State::compute_inline(int stage_id) {
-  ComputeInlineStep step = ComputeInlineStep(stage_id);
-  CopyOnWrite()->transform_steps.push_back(step);
-  return DoComputeInlineStep(step);
-}
-
-Iterator State::bind_thread(int stage_id, const Iterator& it,
-                            IteratorAnnotation thread_type) {
-  const Stage& stage = operator->()->stages[stage_id];
-  if (thread_type < kVThread || thread_type > kThreadY) {
-    LOG(FATAL) << "Invalid thread_type; valid values are: kVThread, kBlockX, "
-               << "kBlockY, kThreadX, kThreadY";
-  }
-  AnnotationStep step = AnnotationStep(
-      stage_id, GetIndex(stage->iters, it), thread_type);
-  CopyOnWrite()->transform_steps.push_back(step);
-  return DoAnnotationStep(step);
-}
-
-int State::cache_read(int stage_id, const std::string& scope_name,
-                      const std::vector<int>& reader_stage_ids,
-                      const ComputeDAG& task_dag) {
-  CacheReadStep step =
-      CacheReadStep(stage_id, scope_name, reader_stage_ids);
-  CopyOnWrite()->transform_steps.push_back(step);
-  return DoCacheReadStep(step, task_dag);
-}
-
-int State::cache_write(int stage_id, const std::string& scope_name,
-                       const ComputeDAG& task_dag) {
-  CacheWriteStep step = CacheWriteStep(stage_id, scope_name);
-  CopyOnWrite()->transform_steps.push_back(step);
-  return DoCacheWriteStep(step, task_dag);
-}
-
-void State::pragma(int stage_id, const Iterator& it,
-                   const std::string& pragma_type) {
-  const Stage& stage = operator->()->stages[stage_id];
-  PragmaStep step =
-      PragmaStep(stage_id, GetIndex(stage->iters, it), pragma_type);
-  CopyOnWrite()->transform_steps.push_back(step);
-  return DoPragmaStep(step);
-}
-
-int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id,
-                   const ComputeDAG& task_dag) {
-  const Stage& stage = operator->()->stages[stage_id];
-  RfactorStep step = RfactorStep(stage_id, GetIndex(stage->iters, it),
-                                 factor_iter_id);
-  CopyOnWrite()->transform_steps.push_back(step);
-  return DoRfactorStep(step, task_dag);
-}
-
-void State::storage_align(int stage_id, const Iterator& it, int factor,
-                          int offset) {
-  const Stage& stage = operator->()->stages[stage_id];
-  StorageAlignStep step = StorageAlignStep(
-      stage_id, GetIndex(stage->iters, it), factor, offset);
-  CopyOnWrite()->transform_steps.push_back(step);
-  return DoStorageAlignStep(step);
-}
-
-Iterator State::tensorize(int stage_id, const Iterator& it,
-                          std::string ti_func_name) {
-  const Stage& stage = operator->()->stages[stage_id];
-  TensorizeStep step = TensorizeStep(
-      stage_id, GetIndex(stage->iters, it), ti_func_name);
-  CopyOnWrite()->transform_steps.push_back(step);
-  return DoTensorizeStep(step);
-}
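-// (Editor's note: the Do*Step implementations below mutate the loop state
-//  without appending to transform_steps, which is what makes replay possible;
-//  a recorded history can be re-applied to a fresh State `init` built from
-//  the same ComputeDAG `dag` along these lines:
-//
-//    for (const Step& step : recorded->transform_steps) {
-//      init.DoStep(step, dag);
-//    }
-//  )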
-
 // Steps' implementations
 void State::DoReorderStep(const ReorderStep& step) {
   const Stage& stage = operator->()->stages[step->stage_id];
@@ -402,20 +260,6 @@ std::vector<Iterator> State::DoSplitStep(const SplitStep& step) {
                            step->inner_to_outer);
 }
 
-std::vector<Iterator> State::DoFollowSplitStep(const FollowSplitStep& step) {
-  std::vector<PrimExpr> lengths;
-  step->ExtractSplitLengths(operator->()->transform_steps, &lengths);
-  return DoSplitStepCommon(step->stage_id, step->iter_id, lengths, true);
-}
-
-std::vector<Iterator> State::DoFollowFusedSplitStep(
-    const FollowFusedSplitStep& step) {
-  const PrimExpr& length =
-      step->ExtractSplitLength(operator->()->transform_steps);
-  return DoSplitStepCommon(step->stage_id, step->iter_id, {length},
-                           step->factor_or_nparts);
-}
-
 Iterator State::DoFuseStep(const FuseStep& step) {
   int stage_id = step->stage_id;
   const Stage& stage = operator->()->stages[stage_id];
@@ -499,292 +343,13 @@ Iterator State::DoFuseStep(const FuseStep& step) {
   return new_it;
 }
 
-Iterator State::DoAnnotationStep(const AnnotationStep& step) {
-  const Stage& stage = operator->()->stages[step->stage_id];
-  Iterator it = stage->iters[step->iter_id];
-
-  CHECK_EQ(it->annotation, IteratorAnnotation::kNone);
-  Iterator new_it = Iterator(it->name, it->range, it->iter_type,
-                             step->annotation, &it->ori_iters,
-                             it->attr);
-  Stage new_stage = stage;
-  new_stage.CopyOnWrite()->iters[step->iter_id] = new_it;
-  StateNode* pstate = CopyOnWrite();
-  pstate->stages[step->stage_id] = std::move(new_stage);
-  return new_it;
-}
-
-void State::DoComputeAtStep(const ComputeAtStep& step) {
-  const Stage& stage = operator->()->stages[step->stage_id];
-
-  // after compute_at, we don't know the accurate length information any more
-  // If we do want to know the accurate lengths, we can call
-  // ComputeDAG::ReplayAndInferBound
-  std::vector<Iterator> new_iters;
-  for (const Iterator& it : stage->iters) {
-    size_t s = it->name.size();
-    if (s >= 2 && it->name[s - 2] == '.' && it->name[s - 1] >= '1' &&
-        it->name[s - 1] <= '4') {
-      // We use a dangerous heuristic rule here: for multi-level split
-      // iterators, we assume their lengths do not change after compute_at.
-      // Reason: these iterators are generated by multi-level tiling in
-      // MultiStagePolicy and are carefully attached (compute_at) to their
-      // consumers, so their lengths stay the same.
-      // We keep their lengths so that the AnnotateCPU pass can annotate
-      // more efficiently.
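-      // (Editor's example: an iterator named "i.2", a tile of axis i produced
-      //  by multi-level tiling, matches the rule above and keeps its Range,
-      //  while an untiled iterator such as "k" falls through to the else
-      //  branch and has its Range reset to undefined.)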
- new_iters.push_back(it); - } else { - new_iters.push_back(Iterator(it->name, Range(), it->iter_type, - it->annotation, &it->ori_iters, it->attr)); - } - } - - StateNode* pstate = CopyOnWrite(); - pstate->stages[step->stage_id] = - Stage(stage->op, stage->op_type, std::move(new_iters), kIter, - stage->attrs); - pstate->attach_map.SetComputeAtIter(step->stage_id, step->target_stage_id, - step->target_iter_id); -} - -void State::DoComputeRootStep(const ComputeRootStep& step) { - const Stage& stage = operator->()->stages[step->stage_id]; - - // after compute_root, we don't know the accurate length information any more - // If we do want to know the accurate lengths, we can call - // ComputeDAG::ReplayAndInferBound - std::vector new_iters; - for (const Iterator& it : stage->iters) { - new_iters.push_back(Iterator(it->name, Range(), it->iter_type, - it->annotation, &it->ori_iters, it->attr)); - } - - // update attach map - StateNode* pstate = CopyOnWrite(); - pstate->stages[step->stage_id] = Stage(stage->op, stage->op_type, - std::move(new_iters), kRoot, - stage->attrs); - pstate->attach_map.DeleteStage(step->stage_id); -} - -void State::DoComputeInlineStep(const ComputeInlineStep& step) { - const Stage& stage = operator->()->stages[step->stage_id]; - - StateNode* pstate = CopyOnWrite(); - - // CHECK the validity of compute_inline - const auto& iter_to_attached_stages = - pstate->attach_map->iter_to_attached_stages; - for (size_t i = 0; i < stage->iters.size(); ++i) { - CHECK_EQ(iter_to_attached_stages.count(std::make_pair(step->stage_id, i)), - 0) - << "Invalid compute_inline: Because there are some other stages " - "that are attached to the target stage"; - } - - pstate->stages[step->stage_id].CopyOnWrite()->compute_at = kInlined; - pstate->attach_map.DeleteStage(step->stage_id); -} - -// Common part for steps that add new stages -// (e.g. 
CacheReadStep, CacheWriteStep, RfactorStep) -void AddStageModificationSteps(size_t step_id, - const std::vector& transform_steps, - std::vector* replay_steps) { - const Step& step = transform_steps[step_id]; - if (step->IsInstance() || - step->IsInstance()) { - replay_steps->push_back(step); - } else if (step->IsInstance()) { - // add FuseStepNode required by rfactor - if (step_id >= 2 && - transform_steps[step_id - 2]->IsInstance()) { - const Step& fuse_step = transform_steps[step_id - 2]; - if (fuse_step->stage_id == step->stage_id) { - replay_steps->push_back(fuse_step); - } - } - // add SplitStepNode required by rfactor - CHECK_GE(step_id, 1); - CHECK(transform_steps[step_id - 1]->IsInstance()); - const Step& split_step = transform_steps[step_id - 1]; - CHECK_EQ(split_step->stage_id, step->stage_id); - replay_steps->push_back(split_step); - // add RfactorStepNode - replay_steps->push_back(step); - } -} - -int State::DoCacheReadStep(const CacheReadStep& step, const ComputeDAG& dag) { - StateNode* pstate = CopyOnWrite(); - std::vector replay_steps; - for (size_t i = 0; i < pstate->transform_steps.size(); ++i) { - AddStageModificationSteps(i, pstate->transform_steps, &replay_steps); - if (pstate->transform_steps[i].same_as(step)) { - break; - } - } - dag.ReplayAndGetDAG(replay_steps, &(pstate->task_dag)); - - // target -> target + target_store - // Should update target's op, insert new stage, update the later stage's op - pstate->stages[step->stage_id].CopyOnWrite()->op = - operator->()->task_dag->ops[step->stage_id]; - pstate->stages.insert( - pstate->stages.begin() + step->stage_id + 1, - Stage(operator->()->task_dag->ops[step->stage_id + 1])); - for (size_t i = step->stage_id + 2; i < operator->()->stages.size(); ++i) { - pstate->stages[i].CopyOnWrite()->op = operator->()->task_dag->ops[i]; - } - pstate->attach_map = operator->()->attach_map.ApplyStageIdOfffset( - step->stage_id + 1, 1); - - return step->stage_id + 1; -} - -int State::DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag) { - StateNode* pstate = CopyOnWrite(); - std::vector replay_steps; - for (size_t i = 0; i < pstate->transform_steps.size(); ++i) { - AddStageModificationSteps(i, pstate->transform_steps, &replay_steps); - if (pstate->transform_steps[i].same_as(step)) { - break; - } - } - - int last_dag_op_size = pstate->task_dag.defined() ? 
- pstate->task_dag->ops.size() : dag->ops.size(); - dag.ReplayAndGetDAG(replay_steps, &(pstate->task_dag)); - int added_ops = pstate->task_dag->ops.size() - last_dag_op_size; - CHECK_GE(added_ops, 1); - - // target -> target_compute + target - // Assume target stage has never been applied any steps before cache_write - // Should insert new stage, update target stage, update the later stage's op - pstate->stages.insert( - pstate->stages.begin() + step->stage_id, - Stage(operator->()->task_dag->ops[step->stage_id])); - pstate->stages[step->stage_id + 1] = - Stage(operator->()->task_dag->ops[step->stage_id + 1]); - int next_stage_id = step->stage_id + 2; - // Notice: added_ops should actually assert to be 1 - // branch of 2 here is somehow a hack to TVM's cache_write bug with - // multi outputs, see test/cpp/ansor_test.cc: CacheReadWrite test - // for more information - // TODO(jcf94): Fix this - if (added_ops == 2) { - pstate->stages.insert( - pstate->stages.begin() + next_stage_id, - Stage(operator->()->task_dag->ops[next_stage_id])); - next_stage_id++; - } else if (added_ops > 2) { - LOG(ERROR) << "Unexpected behavior of CacheWrite."; - } - for (size_t i = next_stage_id; i < operator->()->task_dag->ops.size(); ++i) { - pstate->stages[i].CopyOnWrite()->op = operator->()->task_dag->ops[i]; - } - pstate->attach_map = operator->()->attach_map.ApplyStageIdOfffset( - step->stage_id, added_ops); - - return step->stage_id; -} - -void State::DoPragmaStep(const PragmaStep& step) { - if (step->pragma_type == "debug_skip_region") { - StateNode* pstate = CopyOnWrite(); - pstate->attach_map.DeleteStage(step->stage_id); - } else if (StrStartsWith(step->pragma_type, "auto_unroll_max_step")) { - StateNode* pstate = CopyOnWrite(); - StageNode* stage = pstate->stages[step->stage_id].CopyOnWrite(); - size_t pos = step->pragma_type.find('$'); - stage->attrs.auto_unroll_max_step = atoi(step->pragma_type.c_str() + pos + 1); - } else if (step->pragma_type == "tensor_core") { - // Nothing needs to be done here - } else { - LOG(FATAL) << "Invalid pragma: " << step->pragma_type; - } -} - -int State::DoRfactorStep(const RfactorStep& step, const ComputeDAG& dag) { - StateNode* pstate = CopyOnWrite(); - const auto compute_at_type = pstate->stages[step->stage_id]->compute_at; - std::vector replay_steps; - for (size_t i = 0; i < pstate->transform_steps.size(); ++i) { - AddStageModificationSteps(i, pstate->transform_steps, &replay_steps); - if (pstate->transform_steps[i].same_as(step)) { - break; - } - } - dag.ReplayAndGetDAG(replay_steps, &(pstate->task_dag)); - - // target -> target_compute + target - // Should insert new stage, update target stage, update the later stage's op - pstate->stages.insert( - pstate->stages.begin() + step->stage_id, - Stage(operator->()->task_dag->ops[step->stage_id])); - // maintain the compute_at type of target stage - Stage target_stage = - Stage(operator->()->task_dag->ops[step->stage_id + 1]); - target_stage.CopyOnWrite()->compute_at = compute_at_type; - pstate->stages[step->stage_id + 1] = target_stage; - - for (size_t i = step->stage_id + 2; i < operator->()->stages.size(); ++i) { - pstate->stages[i].CopyOnWrite()->op = operator->()->task_dag->ops[i]; - } - pstate->attach_map = operator->()->attach_map.ApplyStageIdOfffset( - step->stage_id, 1); - - return step->stage_id; -} - -void State::DoStorageAlignStep(const StorageAlignStep& step) { - StateNode* pstate = CopyOnWrite(); - StageNode* stage = pstate->stages[step->stage_id].CopyOnWrite(); - stage->attrs.storage_offset = step->offset; 
-} - -Iterator State::DoTensorizeStep(const TensorizeStep& step) { - const Stage& stage = operator->()->stages[step->stage_id]; - Iterator it = stage->iters[step->iter_id]; - Iterator new_it = Iterator(it->name, it->range, it->iter_type, - IteratorAnnotation::kTensorized, &it->ori_iters, step->ti_func_name); - Stage new_stage = stage; - new_stage.CopyOnWrite()->iters[step->iter_id] = new_it; - StateNode* pstate = CopyOnWrite(); - pstate->stages[step->stage_id] = std::move(new_stage); - return new_it; -} - void State::DoStep(const Step& step, const ComputeDAG& dag) { if (auto ps = step.as()) { DoReorderStep(GetRef(ps)); } else if (auto ps = step.as()) { DoSplitStep(GetRef(ps)); - } else if (auto ps = step.as()) { - DoFollowSplitStep(GetRef(ps)); - } else if (auto ps = step.as()) { - DoFollowFusedSplitStep(GetRef(ps)); } else if (auto ps = step.as()) { DoFuseStep(GetRef(ps)); - } else if (auto ps = step.as()) { - DoAnnotationStep(GetRef(ps)); - } else if (auto ps = step.as()) { - DoComputeAtStep(GetRef(ps)); - } else if (auto ps = step.as()) { - DoComputeRootStep(GetRef(ps)); - } else if (auto ps = step.as()) { - DoComputeInlineStep(GetRef(ps)); - } else if (auto ps = step.as()) { - DoCacheReadStep(GetRef(ps), dag); - } else if (auto ps = step.as()) { - DoCacheWriteStep(GetRef(ps), dag); - } else if (auto ps = step.as()) { - DoPragmaStep(GetRef(ps)); - } else if (auto ps = step.as()) { - DoRfactorStep(GetRef(ps), dag); - } else if (auto ps = step.as()) { - DoStorageAlignStep(GetRef(ps)); - } else if (auto ps = step.as()) { - DoTensorizeStep(GetRef(ps)); } else { LOG(FATAL) << "Invalid step: " << step; } @@ -1068,26 +633,6 @@ TVM_REGISTER_GLOBAL("ansor.StateSplit") return Array{state, Array(res)}; }); -TVM_REGISTER_GLOBAL("ansor.StateFollowSplit") -.set_body_typed([](State state, int stage_id, const Iterator& it, - int src_step_id, int n_split) { - const auto& res = state.follow_split(stage_id, it, src_step_id, n_split); - return Array{state, Array(res)}; -}); - -TVM_REGISTER_GLOBAL("ansor.StateFollowFusedSplit") -.set_body_typed([](State state, int stage_id, const Iterator& it, - const Array& src_step_ids, int level, - bool factor_or_nparts) { - std::vector array_src_step_ids; - for (const auto& i : src_step_ids) { - array_src_step_ids.push_back(i->value); - } - const auto& res = state.follow_fused_split( - stage_id, it, array_src_step_ids, level, factor_or_nparts); - return Array{state, Array(res)}; -}); - TVM_REGISTER_GLOBAL("ansor.StateFuse") .set_body_typed([](State state, int stage_id, const Array& iters) { @@ -1099,100 +644,6 @@ TVM_REGISTER_GLOBAL("ansor.StateFuse") return Array{state, res}; }); -TVM_REGISTER_GLOBAL("ansor.StateVectorize") -.set_body_typed([](State state, int stage_id, const Iterator& it) { - const auto& res = state.vectorize(stage_id, it); - return Array{state, res}; -}); - -TVM_REGISTER_GLOBAL("ansor.StateParallel") -.set_body_typed([](State state, int stage_id, const Iterator& it) { - const auto& res = state.parallel(stage_id, it); - return Array{state, res}; -}); - -TVM_REGISTER_GLOBAL("ansor.StateUnroll") -.set_body_typed([](State state, int stage_id, const Iterator& it, - int max_unroll) { - const auto& res = state.unroll(stage_id, it, max_unroll); - return Array{state, res}; -}); - -TVM_REGISTER_GLOBAL("ansor.StateBindThread") -.set_body_typed([](State state, int stage_id, const Iterator& it, - int thread_type) { - const auto& res = - state.bind_thread(stage_id, it, IteratorAnnotation(thread_type)); - return Array{state, res}; -}); - 
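-// (Editor's note: PackedFunc has a single return slot, so the globals above
-//  and below that must hand back both the mutated State and a primitive's
-//  result pack them into an Array; callers on the Python side are expected
-//  to unpack the pair.)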
-TVM_REGISTER_GLOBAL("ansor.StateComputeAt") -.set_body_typed([](State state, int stage_id, int target_stage_id, - const Iterator& target_iter) { - state.compute_at(stage_id, target_stage_id, target_iter); - return state; -}); - -TVM_REGISTER_GLOBAL("ansor.StateComputeRoot") -.set_body_typed([](State state, int stage_id) { - state.compute_root(stage_id); - return state; -}); - -TVM_REGISTER_GLOBAL("ansor.StateComputeInline") -.set_body_typed([](State state, int stage_id) { - state.compute_inline(stage_id); - return state; -}); - -TVM_REGISTER_GLOBAL("ansor.StateCacheRead") -.set_body_typed([](State state, int stage_id, const std::string& scope_name, - const Array& reader_stage_ids, - const ComputeDAG& task_dag) { - std::vector array_reader_stage_ids; - for (const auto& i : reader_stage_ids) { - array_reader_stage_ids.push_back(i->value); - } - int res = state.cache_read(stage_id, scope_name, array_reader_stage_ids, - task_dag); - return Array{state, IntImm(DataType::Int(32), res)}; -}); - -TVM_REGISTER_GLOBAL("ansor.StateCacheWrite") -.set_body_typed([](State state, int stage_id, const std::string& scope_name, - const ComputeDAG& task_dag) { - int res = state.cache_write(stage_id, scope_name, task_dag); - return Array{state, IntImm(DataType::Int(32), res)}; -}); - -TVM_REGISTER_GLOBAL("ansor.StatePragma") -.set_body_typed([](State state, int stage_id, const Iterator& it, - const std::string& pragma_type) { - state.pragma(stage_id, it, pragma_type); - return state; -}); - -TVM_REGISTER_GLOBAL("ansor.StateRfactor") -.set_body_typed([](State state, int stage_id, const Iterator& it, - int factor_iter_id, const ComputeDAG& task_dag) { - int res = state.rfactor(stage_id, it, factor_iter_id, task_dag); - return Array{state, IntImm(DataType::Int(32), res)}; -}); - -TVM_REGISTER_GLOBAL("ansor.StateStorageAlign") -.set_body_typed([](State state, int stage_id, const Iterator& it, - int factor, int offset) { - state.storage_align(stage_id, it, factor, offset); - return state; -}); - -TVM_REGISTER_GLOBAL("ansor.StateTensorize") -.set_body_typed([](State state, int stage_id, const Iterator& it, - std::string ti_func) { - const auto& res = state.tensorize(stage_id, it, ti_func); - return Array{state, res}; -}); - TVM_REGISTER_GLOBAL("ansor.StateEqual") .set_body_typed([](State state1, State state2) { return std::equal_to()(state1, state2); diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h index 1b7bbc40bb31..2d6c85db0247 100644 --- a/src/ansor/loop_state.h +++ b/src/ansor/loop_state.h @@ -220,13 +220,7 @@ class StepNode: public Object { TVM_DEFINE_MUTABLE_OBJECT_REF(Step, StepNode); // Step forward decelerations -class ReorderStep; class SplitStep; class FollowSplitStep; -class FollowFusedSplitStep; -class FuseStep; class AnnotationStep; -class ComputeAtStep; class ComputeRootStep; class ComputeInlineStep; -class CacheReadStep; class CacheWriteStep; -class PragmaStep; class RfactorStep; class StorageAlignStep; -class TensorizeStep; +class ReorderStep; class SplitStep; class FuseStep; /*! \brief A state in the search process. * It consists of the current loop structure and the history steps to reach this state. 
*/ @@ -264,55 +258,18 @@ class State : public ObjectRef { // Schedule primitives void reorder(int stage_id, const std::vector& order); - void compute_at(int stage_id, int target_stage_id, - const Iterator& target_iter); - void compute_root(int stage_id); - void compute_inline(int stage_id); - void pragma(int stage_id, const Iterator& it, const std::string& pragma_type); - void storage_align(int stage_id, const Iterator& it, int factor, int offset); std::vector split(int stage_id, const Iterator& it, const std::vector& lengths, bool inner_to_outer = true); - std::vector follow_split(int stage_id, const Iterator& it, - int src_step_id, int n_split); - std::vector follow_fused_split(int stage_id, const Iterator& it, - const std::vector& src_step_ids, - int level, bool factor_or_nparts); Iterator fuse(int stage_id, const std::vector& iters); - Iterator vectorize(int stage_id, const Iterator& it); - Iterator parallel(int stage_id, const Iterator& it); - Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1); - Iterator bind_thread(int stage_id, const Iterator& it, - IteratorAnnotation thread_type); - Iterator tensorize(int stage_id, const Iterator& it, - std::string ti_func_name); - int cache_read(int stage_id, const std::string& scope_name, - const std::vector& reader_stage_ids, - const ComputeDAG& task_dag); - int cache_write(int stage_id, const std::string& scope_name, - const ComputeDAG& task_dag); - int rfactor(int stage_id, const Iterator& it, int factor_iter_id, - const ComputeDAG& task_dag); /* Do transform steps * Note: The following functions only change loop state but do not change transform_history. * We separate these functions out, * so you can call them for replay easily given history steps */ void DoReorderStep(const ReorderStep& step); - void DoComputeAtStep(const ComputeAtStep& step); - void DoComputeRootStep(const ComputeRootStep& step); - void DoComputeInlineStep(const ComputeInlineStep& step); - void DoPragmaStep(const PragmaStep& step); - void DoStorageAlignStep(const StorageAlignStep& step); std::vector DoSplitStep(const SplitStep& step); - std::vector DoFollowSplitStep(const FollowSplitStep& step); - std::vector DoFollowFusedSplitStep(const FollowFusedSplitStep& step); Iterator DoFuseStep(const FuseStep& step); - Iterator DoAnnotationStep(const AnnotationStep& step); - Iterator DoTensorizeStep(const TensorizeStep& step); - int DoCacheReadStep(const CacheReadStep& step, const ComputeDAG& dag); - int DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag); - int DoRfactorStep(const RfactorStep& step, const ComputeDAG& dag); // General do step functions with a runtime dynamic dispatcher void DoStep(const Step& step, const ComputeDAG& dag); diff --git a/src/ansor/search_policy/sketch_search_policy.cc b/src/ansor/search_policy/sketch_search_policy.cc deleted file mode 100644 index c4365a391865..000000000000 --- a/src/ansor/search_policy/sketch_search_policy.cc +++ /dev/null @@ -1,1541 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file ansor/search_policy/sketch_search_policy.h - * \brief The search policy that searches in a hierarchical search space defined by sketches. - * The policy randomly samples programs from the space defined by sketches - * and use evolutionary search to fine-tune them. - */ - -#include "sketch_search_policy.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "utils.h" - -#define IS_GPU(task) ((task)->target->device_type == kDLGPU || \ - (task)->target->device_type == kDLOpenCL) - -namespace tvm { -namespace ansor { - -TVM_REGISTER_NODE_TYPE(SketchSearchPolicyNode); -TVM_REGISTER_OBJECT_TYPE(PreloadCustomSketchRuleNode); - -// All possible candidates for auto_unroll -const std::vector SketchSearchPolicyNode::auto_unroll_configs{0, 16, 64, 512, 1024}; - -SketchSearchPolicy::SketchSearchPolicy(CostModel program_cost_model, - Map params, - int seed) { - auto node = make_object(); - node->program_cost_model = std::move(program_cost_model); - node->rand_gen_ = std::mt19937(seed); - node->params = std::move(params); - data_ = std::move(node); -} - -State SketchSearchPolicyNode::Search(SearchTask task, int n_trials, - int early_stopping, int num_measure_per_iter, int verbose, - ProgramMeasurer measurer, Array pre_search_callbacks) { - std::vector best_states, random_states; - this->cur_task = task; - this->verbose = verbose; - num_measure_per_iter_ = num_measure_per_iter; - - PrintTitle("Call search callbacks", verbose); - RunCallbacks(pre_search_callbacks); - - if (n_trials <= 1) { // no measurement is allowed - SearchOneRound(&best_states, 0, &random_states); - CHECK_GT(best_states.size(), 0); - return best_states[0]; - } else { - std::vector inputs; - std::vector results; - int num_random = static_cast(GetDoubleParam(params, "eps_greedy") * num_measure_per_iter); - - measurer->Reset(); - - early_stopping = early_stopping < 0 ? std::numeric_limits::max() >> 1 : early_stopping; - - int ct = 0; - while (ct < n_trials) { - if (!inputs.empty()) { - // retrain cost models - PrintTitle("Train cost model", verbose); - program_cost_model->Update(inputs, results); - } - - // Search one round to get promising states - PrintTitle("Search", verbose); - SearchOneRound(&best_states, num_random, &random_states); - - // Infer bound. This is necessary for computing the correct ToStr() for redundancy check - cur_task->compute_dag.InferBound(&best_states); - cur_task->compute_dag.InferBound(&random_states); - - // Pick `num_measure_per_iter` states to measure, check hash to remove already measured state - // Also pick some random states to do eps-greedy - PickStatesWithEpsGreedy(&inputs, best_states, random_states, n_trials - ct); - - // Have traversed all of search space - if (inputs.empty()) { - StdCout(verbose) << "All candidates in the search space have been measured." 
<< std::endl; - break; - } - - // Measure candidate states - PrintTitle("Measure", verbose); - measurer->Measure(cur_task, GetRef(this), inputs, &results); - ct += inputs.size(); - - if (ct - measurer->best_ct[cur_task->workload_key] > early_stopping) { - StdCout(verbose) << "Meet the early stopping condition." << std::endl; - break; - } - - // Update measured states. These states will join the LocalMutation in later rounds - for (const auto& res : results) { - measured_states_throughputs_.push_back(1.0 / FloatArrayMean(res->costs)); - } - } - PrintTitle("Done", verbose); - - return measurer->best_state[cur_task->workload_key]; - } -} - -std::pair, Array > - SketchSearchPolicyNode::ContinueSearchOneRound( - SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) { - if (cur_task.defined()) { - CHECK_EQ(cur_task, task); - } else { - cur_task = task; - } - this->verbose = verbose; - num_measure_per_iter_ = num_measure; - - std::vector best_states, random_states; - std::vector inputs; - std::vector results; - int num_random = static_cast(GetDoubleParam(params, "eps_greedy") * num_measure); - - // Search one round to get promising states - PrintTitle("Search", verbose); - SearchOneRound(&best_states, num_random * 2, &random_states); - - // Fill in correct bounds. This is necessary for computing the correct ToStr() for redundancy check - cur_task->compute_dag.InferBound(&best_states); - cur_task->compute_dag.InferBound(&random_states); - - // Pick `num_measure` states to measure, check hash to remove already measured states - // Also pick some random states to do eps-greedy - PickStatesWithEpsGreedy(&inputs, best_states, random_states, num_measure); - - // Measure candidate states - PrintTitle("Measure", verbose); - measurer->Measure(cur_task, GetRef(this), inputs, &results); - - // Update throughputs of measured states.
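Both search entry points above score a measured state by the reciprocal of its mean measured cost before appending it to measured_states_throughputs_. A minimal standalone sketch of that bookkeeping, assuming the per-run costs are plain doubles in seconds rather than the FloatImm array used here:

#include <iostream>
#include <numeric>
#include <vector>

// Throughput proxy used for ranking: 1 / mean(per-run costs in seconds).
double ThroughputOf(const std::vector<double>& costs) {
  double mean = std::accumulate(costs.begin(), costs.end(), 0.0) /
                static_cast<double>(costs.size());
  return 1.0 / mean;
}

int main() {
  std::vector<double> throughputs;
  throughputs.push_back(ThroughputOf({0.012, 0.011, 0.013}));  // one measured state
  std::cout << throughputs.back() << " runs/sec\n";            // ~83.3
}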
These states will join the LocalMutation in later rounds - for (const auto& res : results) { - measured_states_throughputs_.push_back(1.0 / FloatArrayMean(res->costs)); - } - - // Update the cost model - Array inputs_arr(std::make_move_iterator(inputs.begin()), - std::make_move_iterator(inputs.end())); - Array results_arr(std::make_move_iterator(results.begin()), - std::make_move_iterator(results.end())); - - PrintTitle("Train cost model", verbose); - program_cost_model->Update(inputs_arr, results_arr); - return std::make_pair(std::move(inputs_arr), std::move(results_arr)); -} - -void SketchSearchPolicyNode::PickStatesWithEpsGreedy( - std::vector* inputs, - const std::vector& best_states, - const std::vector& random_states, - int remaining_n_trials) { - int num_random = static_cast(GetDoubleParam(params, "eps_greedy") * num_measure_per_iter_); - int num_good = num_measure_per_iter_ - num_random; - - inputs->clear(); - size_t offset_best = 0, offset_random = 0; - - while (static_cast(inputs->size()) < std::min(num_measure_per_iter_, remaining_n_trials)) { - const State* pstate; - - bool has_best = offset_best < best_states.size(); - bool has_random = offset_random < random_states.size(); - - if (static_cast(inputs->size()) < num_good) { - // prefer best states - if (has_best) { - pstate = &best_states[offset_best++]; - } else if (has_random) { - pstate = &random_states[offset_random++]; - } else { - break; - } - } else { - // prefer random states - if (has_random) { - pstate = &random_states[offset_random++]; - } else if (has_best) { - pstate = &best_states[offset_best++]; - } else { - break; - } - } - - // Check if it has already been measured - std::string state_str = pstate->ToStr(); - - if (measured_states_set_.count(state_str)) { continue; } - measured_states_set_.insert(std::move(state_str)); - - inputs->push_back(MeasureInput(cur_task, *pstate)); - measured_states_vector_.push_back(*pstate); - } -} - -void SketchSearchPolicyNode::SearchOneRound(std::vector* best_states, - int num_random_states, std::vector* random_states) { - best_states->clear(); - random_states->clear(); - - // Get parameters - int population = GetIntParam(params, "evolutionary_search_population"); - int num_use_measured = std::min(static_cast(measured_states_vector_.size()), - static_cast( - GetDoubleParam(params, "evolutionary_search_use_measured_ratio") * population)); - bool have_cost_model = !program_cost_model->IsInstance(); - - if (!have_cost_model) { - num_use_measured = 0; - } - - // Generate sketches - std::vector sketches; - GenerateSketch(&sketches); - - // PrintAllStates(sketches); - // exit(0); - - // Sample the init population - std::vector init_population; - SampleInitPopulation(sketches, population - num_use_measured, &init_population); - - // PrintAllStates(init_population); - // exit(0); - - if (have_cost_model) { - // Also insert already measured good states to the initial population - std::vector indices; - Argsort(measured_states_throughputs_, &indices); - for (int i = 0; i < num_use_measured; i++) { - init_population.push_back(measured_states_vector_[indices[i]]); - } - - // Perform evolutionary search - EvolutionarySearch(init_population, num_measure_per_iter_ * 2, best_states); - } else { - // If the cost model is useless (i.e. 
RandomCostModel), skip evolutionary search - RandomSampleStates(init_population, &rand_gen_, num_measure_per_iter_ * 3, best_states); - } - - // Sample some random states for eps-greedy - RandomSampleStates(init_population, &rand_gen_, num_random_states * 10, random_states); -} - -// The base class for derivation rules used in sketch generation -class SketchGenerationRule { - public: - enum ConditionEnum { - kPass, kApply, kApplyAndSkipRest - }; - - virtual ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) = 0; - virtual std::vector > Apply(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) = 0; -}; - -static inline bool ShouldBeCacheRead( - const SketchSearchPolicyNode* policy, const State& state, int stage_id) { - const SearchTask& task = policy->cur_task; - const Stage& stage = state->stages[stage_id]; - - if (HasAttrsFlag(state, stage_id, - SearchPolicyNode::no_cache_read_key)) { - return false; - } - - std::unordered_set consumers; - GetConsumers(task, state, stage->op, &consumers); - if (consumers.size() != 1) { - return false; - } - - int target_stage_id = OperationToStage(*consumers.begin(), state); - if (!NeedsMultilevelTiling(task, state, - state->stages[target_stage_id]->op)) { - return false; - } - - std::unordered_set producers; - GetProducers(task, state, state->stages[target_stage_id]->op, &producers); - // Only those directly mapped stages can do CacheRead - if (producers.find(stage->op) == producers.end()) { - return false; - } - - return true; -} - -static inline bool ShouldAlwaysBeInlined( - const SketchSearchPolicyNode* policy, const State& state, int stage_id) { - const SearchTask& task = policy->cur_task; - const Stage& stage = state->stages[stage_id]; - - if (stage->op->IsInstance()) { - return false; - } - - // Inline limitation of TVM - if (!IsOutputOp(task, state, stage->op) && !HasReduceIter(stage)) { - // Always inline condition: - // 1. Has attrs that this must be inlined - // 2. Analyse shows this is strict inlineable - // 3. A GPU stage can be inlined(If it should be cache read, do it first) - if (HasAttrsFlag(state, stage_id, - SearchPolicyNode::always_compute_inline_key) || - IsStrictInlineable(task, state, stage->op) || - (IS_GPU(policy->cur_task) && - !ShouldBeCacheRead(policy, state, stage_id))) { - return true; - } - } - - return false; -} - -// The rule that inlines simple elementwise ops -class RuleAlwaysInline : public SketchGenerationRule { - public: - ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - return ShouldAlwaysBeInlined(policy, state, stage_id) ? 
- kApplyAndSkipRest : kPass; - } - - std::vector > Apply(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - State tmp_s = state; - tmp_s.compute_inline(stage_id); - return {std::make_pair(std::move(tmp_s), stage_id - 1)}; - } -}; - -// The rule that simply skip the current stage -class RuleSkipStage : public SketchGenerationRule { - public: - ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task; - const Stage& stage = state->stages[stage_id]; - - const auto& attrs = stage->op->attrs; - if ((attrs.count(SearchPolicyNode::no_split_at_inner_key) || - attrs.count(SearchPolicyNode::no_split_at_outer_key)) && - NeedsMultilevelTiling(task, state, stage->op)) { - // for the transform stages in Winograd - return kPass; - } - - return kApply; - } - - std::vector > Apply(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - return {std::make_pair(state, stage_id - 1)}; - } -}; - -// The rule that performs multi-level tiling -class RuleMultiLevelTiling : public SketchGenerationRule { - public: - ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task; - const Stage& stage = state->stages[stage_id]; - - return NeedsMultilevelTiling(task, state, stage->op) ? - (IS_GPU(policy->cur_task) ? kApplyAndSkipRest : kApply) : kPass; - } - - std::vector > Apply(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - std::string multi_level_tiling_structure = IS_GPU(policy->cur_task) ? - GetStringParam(policy->params, "gpu_multi_level_tiling_structure") : - GetStringParam(policy->params, "cpu_multi_level_tiling_structure"); - - std::vector spatial_split_step_ids; - State tmp_s = state; - tmp_s = DoMultiLevelTiling(tmp_s, stage_id, multi_level_tiling_structure, - &spatial_split_step_ids); - return {std::make_pair(std::move(tmp_s), stage_id-1)}; - } -}; - -// The rule that performs multi-level tiling and fuses later consumers -class RuleMultiLevelTilingWithFusion : public SketchGenerationRule { - public: - ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task; - const Stage& stage = state->stages[stage_id]; - - int target_stage_id; - - if (IS_GPU(policy->cur_task)) { - return NeedsMultilevelTiling(task, state, stage->op) && - HasSingleElementwiseMatchedConsumer(task, state, stage, - &target_stage_id) && - (!HasCacheReadStage(state, stage_id) || - HasCacheWriteStage(state, stage_id)) ? - kApplyAndSkipRest : kPass; - } - - return NeedsMultilevelTiling(task, state, stage->op) && - HasSingleElementwiseMatchedConsumer(task, state, stage, - &target_stage_id) ? - kApply : kPass; - } - - std::vector > Apply(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task; - const Stage& stage = state->stages[stage_id]; - std::string multi_level_tiling_structure = IS_GPU(policy->cur_task) ? 
- GetStringParam(policy->params, "gpu_multi_level_tiling_structure") : - GetStringParam(policy->params, "cpu_multi_level_tiling_structure"); - - std::vector spatial_split_step_ids; - int target_stage_id; - std::unordered_set consumers; - - GetConsumers(task, state, state->stages[stage_id]->op, &consumers); - CHECK(HasSingleElementwiseMatchedConsumer(task, state, stage, &target_stage_id)); - - State base_state = state; - base_state = DoMultiLevelTiling(base_state, stage_id, - multi_level_tiling_structure, &spatial_split_step_ids); - std::vector follow_tiling_levels; - if (IS_GPU(policy->cur_task)) { - follow_tiling_levels.push_back(3); - } else { - follow_tiling_levels.push_back(1); - follow_tiling_levels.push_back(2); - } - - std::vector > ret; - for (int level : follow_tiling_levels) { - if (tolower(multi_level_tiling_structure[level-1]) != 's') { - continue; - } - State tmp_s = base_state; - tmp_s = FollowTiling(tmp_s, target_stage_id, spatial_split_step_ids, level); - const Iterator &target_iter = tmp_s->stages[target_stage_id]->iters[ - level * spatial_split_step_ids.size() - 1]; - tmp_s.compute_at(stage_id, target_stage_id, target_iter); - - ret.emplace_back(std::move(tmp_s), stage_id - 1); - } - - return ret; - } -}; - -// The rule that adds a cache write stage -class RuleAddCacheWrite : public SketchGenerationRule { - public: - ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task; - const Stage& stage = state->stages[stage_id]; - - int target_stage_id; - - // Add cache write if a stage needs multi-level tiling, - // but does not have a element-wise matched consumer - return NeedsMultilevelTiling(task, state, stage->op) && - !HasAttrsFlag(state, stage_id, SearchPolicyNode::no_cache_write_key) && - (!HasSingleElementwiseMatchedConsumer(task, state, stage, - &target_stage_id) || - (HasCacheReadStage(state, stage_id) && - !HasCacheWriteStage(state, stage_id))) ? - kApply : kPass; - } - - std::vector > Apply(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task; - - State tmp_s = state; - tmp_s.cache_write(stage_id, "local", task->compute_dag); - return {std::make_pair(std::move(tmp_s), stage_id)}; - } -}; - -// The rule that adds a cache read stage -// Mainly used for GPU cooperative fetching -// Currently only support 1 to 1 match cache read -class RuleAddCacheRead : public SketchGenerationRule { - public: - ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - return ShouldBeCacheRead(policy, state, stage_id) ? 
- kApplyAndSkipRest : kPass; - } - - std::vector > Apply(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task; - const Stage& stage = state->stages[stage_id]; - - std::unordered_set consumers; - GetConsumers(task, state, stage->op, &consumers); - CHECK_EQ(consumers.size(), 1); - int target_stage_id = OperationToStage(*consumers.begin(), state); - State tmp_s = state; - int added_stage_id = tmp_s.cache_read(stage_id, "shared", - {target_stage_id}, - task->compute_dag); - target_stage_id++; - const auto& share_read_pos = GetLastReduceIteratorInOutermostReduceTile( - tmp_s->stages[target_stage_id]); - tmp_s.compute_at(added_stage_id, target_stage_id, share_read_pos); - - return {std::make_pair(std::move(tmp_s), stage_id)}; - } -}; - -// The rule that adds rfactor stage -class RuleAddRfactor : public SketchGenerationRule { - public: - ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task; - const Stage& stage = state->stages[stage_id]; - - return NeedsRfactor(task, state, stage->op) && - !HasCacheWriteStage(state, stage_id) ? - kApply : kPass; - } - - std::vector > Apply(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - const SearchTask& task = policy->cur_task; - const Stage& stage = state->stages[stage_id]; - - std::vector > ret; - - State tmp_s = state; - - // fuse reduce iters - std::vector space_iters, reduce_iters; - for (const auto &iter : stage->iters) { - if (iter->iter_type == kSpace) { - space_iters.push_back(iter); - } else if (iter->iter_type == kReduce) { - reduce_iters.push_back(iter); - } - } - CHECK(!reduce_iters.empty()); - Iterator fused_reduce_iter; - if (reduce_iters.size() > 1) { - fused_reduce_iter = tmp_s.fuse(stage_id, reduce_iters); - } else { - fused_reduce_iter = reduce_iters[0]; - } - - // split reduce iters - const auto &split_res = tmp_s.split(stage_id, fused_reduce_iter, {1}); - int factor_axis_id = static_cast(space_iters.size()); - State base_state = tmp_s; - for (const auto &split_iter : split_res) { - tmp_s = base_state; - tmp_s.rfactor(stage_id, split_iter, factor_axis_id, task->compute_dag); - - // reorder the space iterator to innermost for vectorization - if (split_iter == split_res[1]) { - std::vector new_order; - for (size_t i = 0; i < tmp_s->stages[stage_id]->iters.size(); ++i) { - if (i != space_iters.size()) { - new_order.push_back(tmp_s->stages[stage_id]->iters[i]); - } - } - new_order.push_back(tmp_s->stages[stage_id]->iters[space_iters.size()]); - tmp_s.reorder(stage_id, new_order); - } - ret.emplace_back(std::move(tmp_s), stage_id - 1); - } - - return ret; - } -}; - -void SketchSearchPolicyNode::GenerateSketch( - std::vector* out_states) { - State init_state = cur_task->compute_dag.GetInitState(); - std::string cpu_multi_level_tiling_structure = - GetStringParam(params, "cpu_multi_level_tiling_structure"); - - // two ping pong buffers to avoid copy - std::vector states_buf1, states_buf2; - std::vector *pnow, *pnext; - pnow = &states_buf1; - pnext = &states_buf2; - pnow->push_back(init_state); - - // A map that maps state to its current working position (stage_id) - std::unordered_map cur_stage_id_map; - cur_stage_id_map[init_state] = static_cast(init_state->stages.size() - 1); - - static RuleSkipStage rule_skip_stage; - static RuleAlwaysInline rule_always_inline; - static RuleMultiLevelTiling rule_multi_level_tiling; - static 
RuleMultiLevelTilingWithFusion rule_multi_level_tiling_with_fusion; - static RuleAddCacheWrite rule_add_cache_write_stage; - static RuleAddCacheRead rule_add_cache_read_stage; - static RuleAddRfactor rule_add_rfactor; - if (sketch_rules.empty()) { - // We may apply and skip the rest when processing some rules, - // should take care of the rule vector order here - sketch_rules.push_back(&rule_always_inline); - sketch_rules.push_back(&rule_add_cache_write_stage); - sketch_rules.push_back(&rule_multi_level_tiling_with_fusion); - sketch_rules.push_back(&rule_multi_level_tiling); - sketch_rules.push_back(&rule_add_rfactor); - sketch_rules.push_back(&rule_skip_stage); - if (IS_GPU(cur_task)) { - // Try cache read first before cache write - sketch_rules.insert(sketch_rules.begin() + 1, &rule_add_cache_read_stage); - } - // TODO(xian): Add a new rule to try combination of multi-level - // tiling + rfactor - } - - // Derivation rule based synthesizer - while (!pnow->empty()) { - pnext->clear(); - - for (const State& state : *pnow) { - int stage_id = cur_stage_id_map[state]; - - // Reached the terminal stage - if (stage_id < 0) { - out_states->push_back(state); - continue; - } - - // Try all derivation rules - for (const auto& rule : sketch_rules) { - auto rule_check = rule->MeetCondition(this, state, stage_id); - if (rule_check > SketchGenerationRule::ConditionEnum::kPass) { - for (const auto& pair : rule->Apply(this, state, stage_id)) { - cur_stage_id_map[pair.first] = pair.second; - pnext->push_back(pair.first); - } - // Skip the rest of the rules - if (rule_check == SketchGenerationRule::ConditionEnum::kApplyAndSkipRest) { - break; - } - } - } - } - - std::swap(pnow, pnext); - } - - // Hack for rfactor: Replace the split factor for rfactor with the undefined Expr(), - // so later we can sample a random value for the split factor. - // Why don't we use Expr() when doing the split for rfactor at the first time? - // Because during ApplySteps, a rfactor with undefined Expr() will crash TVM.
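The derivation loop above expands every state in the current ping-pong buffer by the first rules whose conditions hold, and kApplyAndSkipRest short-circuits the remaining rules for that state. A self-contained sketch of this control flow, with integers standing in for schedule states (all names here are illustrative only):

#include <iostream>
#include <utility>
#include <vector>

enum Condition { kPass, kApply, kApplyAndSkipRest };

struct Rule {
  Condition (*meet)(int state);
  std::vector<int> (*apply)(int state);
};

int main() {
  // Rule 1: even states spawn one child and skip the rest of the rules.
  // Rule 2: everything else spawns two children.
  std::vector<Rule> rules = {
      {[](int s) { return s % 2 == 0 ? kApplyAndSkipRest : kPass; },
       [](int s) { return std::vector<int>{s / 2}; }},
      {[](int) { return kApply; },
       [](int s) { return std::vector<int>{s - 1, s - 2}; }}};

  std::vector<int> buf1{9}, buf2;  // two ping-pong buffers to avoid copies
  std::vector<int>* pnow = &buf1;
  std::vector<int>* pnext = &buf2;
  std::vector<int> out;            // finished "sketches"

  while (!pnow->empty()) {
    pnext->clear();
    for (int state : *pnow) {
      if (state <= 0) { out.push_back(state); continue; }  // terminal state
      for (const Rule& rule : rules) {
        Condition c = rule.meet(state);
        if (c > kPass) {
          for (int next : rule.apply(state)) pnext->push_back(next);
          if (c == kApplyAndSkipRest) break;  // skip the rest of the rules
        }
      }
    }
    std::swap(pnow, pnext);
  }
  std::cout << "generated " << out.size() << " terminal states\n";
}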
- // So rfactor with undefined Expr() will conflict with cache_write, cache_read, rfactor - // in other stages - for (size_t i = 0; i < out_states->size(); ++i) { - auto pstate = (*out_states)[i].CopyOnWrite(); - for (size_t step_id = 0; step_id < pstate->transform_steps.size(); ++step_id) { - if (pstate->transform_steps[step_id]->IsInstance()) { - CHECK_GE(step_id, 1); - int split_step_id = step_id - 1; - auto step = pstate->transform_steps[split_step_id].as(); - CHECK(step != nullptr); - pstate->transform_steps[split_step_id] - = SplitStep(step->stage_id, step->iter_id, step->extent, {PrimExpr()}, - step->inner_to_outer); - } - } - } - - StdCout(verbose) << "Generate Sketches\t\t#s: " << out_states->size() << std::endl; -} - -int InitPopulationFillTileSize(const SketchSearchPolicyNode* policy, - State* state, std::mt19937* rand_gen, - SplitFactorizationMemo* split_memo) { - for (size_t step_id = 0; step_id < (*state)->transform_steps.size(); ++step_id) { - if (auto ps = (*state)->transform_steps[step_id].as()) { - bool defined = true; - for (const PrimExpr& len : ps->lengths) { - if (!len.defined()) { - defined = false; - } - } - - if (defined) { - continue; - } - - int extent = GetIntImm(ps->extent); - const std::vector >& candidate_lens = - split_memo->GetFactorizationSchemes( - extent, ps->lengths.size(), - policy->cur_task->hardware_params->max_innermost_split_factor); - - StateNode* pstate = state->CopyOnWrite(); - pstate->transform_steps[step_id] = SplitStep( - ps->stage_id, ps->iter_id, ps->extent, - candidate_lens[(*rand_gen)() % candidate_lens.size()], - ps->inner_to_outer); - } - } - - return 0; -} - -int InitPopulationThreadBind(const SketchSearchPolicyNode* policy, - State* state) { - for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { - const Stage& stage = (*state)->stages[stage_id]; - auto pop = stage->op.as(); - - if (stage->compute_at != kRoot || stage->op_type == kPlaceholder) { - continue; - } - - if (HasAnnotationIter(stage, IteratorAnnotation::kThreadX)) { - // Skip if this stage has already done thread binding - continue; - } - - std::vector to_fuse; - - // This stage has not been tiled, but in a GPU schedule, we must tile it - // to do thread binding - if (!HasSplitStep(*state, stage_id)) { - for (const auto& it : (*state)->stages[stage_id]->iters) { - if (it->iter_type == kReduce) { - break; - } - to_fuse.push_back(it); - } - const auto& fused_it = state->fuse(stage_id, to_fuse); - // Set default vthread=1 & threadIdx.x=default_warp_size - // EvolutionarySearch will try more possibilities - if (GetExtent(fused_it) <= - policy->cur_task->hardware_params->warp_size) { - state->bind_thread(stage_id, fused_it, kThreadX); - } else { - const auto& split_its = state->split(stage_id, fused_it, - {1, policy->cur_task->hardware_params->warp_size}); - state->bind_thread(stage_id, split_its[0], kBlockX); - state->bind_thread(stage_id, split_its[1], kVThread); - state->bind_thread(stage_id, split_its[2], kThreadX); - } - - continue; - } - - int total_space_extent = 1; - for (const auto& i : pop->root_iter_vars()) { - CHECK(i->dom.defined()); - const auto& pint = i->dom->extent.as(); - CHECK(pint); - total_space_extent *= pint->value; - } - - // TODO(..): Add ThreadBind support for rfactor - if (total_space_extent <= policy->cur_task->hardware_params->warp_size) { - for (const auto& it : (*state)->stages[stage_id]->iters) { - if (it->iter_type == kReduce) { - break; - } - to_fuse.push_back(it); - } - const auto& fused_it = state->fuse(stage_id, to_fuse); -
state->bind_thread(stage_id, fused_it, kThreadX); - - continue; - } - - // Fuse the outermost space tile as blockIdx - for (size_t i = 0; i < pop->axis.size(); i++) { - const auto& it = (*state)->stages[stage_id]->iters[i]; - if (!StrEndsWith(it->name, ".0")) { - break; - } - to_fuse.push_back(it); - } - const auto& blockidx_it = state->fuse(stage_id, to_fuse); - state->bind_thread(stage_id, blockidx_it, kBlockX); - - // Fuse the second outermost space tile as vthread - to_fuse.clear(); - for (size_t i = 1; i < pop->axis.size() + 1; i++) { - const auto& it = (*state)->stages[stage_id]->iters[i]; - if (!StrEndsWith(it->name, ".1")) { - break; - } - to_fuse.push_back((*state)->stages[stage_id]->iters[i]); - } - const auto& vthread_it = state->fuse(stage_id, to_fuse); - if (GetExtent(vthread_it) > - policy->cur_task->hardware_params->max_vthread_extent) { - return -1; - } - state->bind_thread(stage_id, vthread_it, kVThread); - - // Fuse the third outermost space tile as threadIdx - to_fuse.clear(); - for (size_t i = 2; i < pop->axis.size() + 2; i++) { - const auto& it = (*state)->stages[stage_id]->iters[i]; - if (!StrEndsWith(it->name, ".2")) { - break; - } - to_fuse.push_back((*state)->stages[stage_id]->iters[i]); - } - const auto& threadidx_it = state->fuse(stage_id, to_fuse); - if (GetExtent(threadidx_it) < - policy->cur_task->hardware_params->warp_size) { - return -1; - } - state->bind_thread(stage_id, threadidx_it, kThreadX); - } - - return 0; -} - -int InitPopulationCooperativeFetching(const SketchSearchPolicyNode* policy, - State* state) { - for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { - // Do cooperative fetching with cache read stage - // For two stages: A -> B - // 1. A -> A_cache_read -> B - // * - // 2. A -> A_cache_write -> A_cache_read -> B - // * - if ((stage_id > 0 && HasCacheReadStage((*state), stage_id - 1) && - !HasCacheWriteStage((*state), stage_id - 1)) || - (stage_id > 1 && HasCacheReadStage((*state), stage_id - 2) && - HasCacheWriteStage((*state), stage_id - 2))) { - const Stage& target_stage = (*state)->stages[stage_id]; - if (HasAnnotationIter(target_stage, IteratorAnnotation::kThreadX) || - HasAnnotationIter(target_stage, IteratorAnnotation::kTensorized)) { - // Skip if this stage has already done thread binding or has been - // tensorized - continue; - } - // Get spatial_split_step_ids from the root stage - std::unordered_set consumers; - std::vector spatial_split_step_ids; - GetConsumers(policy->cur_task, (*state), target_stage->op, &consumers); - CHECK_EQ(consumers.size(), 1); - int target_stage_id = OperationToStage(*consumers.begin(), (*state)); - GetSpaceSplitStepIds((*state), target_stage_id, &spatial_split_step_ids); - - // Fuse all axes to do cooperative fetching - Iterator fused = state->fuse(stage_id, - (*state)->stages[stage_id]->iters); - // Leave a vectorized cooperative fetching split placeholder - const auto& iters0 = state->split(stage_id, fused, {1}); - state->vectorize(stage_id, iters0[1]); - // Follow split to keep the same thread extent as the root stage - const auto& iters1 = state->follow_fused_split(stage_id, iters0[0], - spatial_split_step_ids, - 1, true); - state->bind_thread(stage_id, iters1[1], kThreadX); - } - } - - return 0; -} - -int InitPopulationChangeComputeLocation(const SketchSearchPolicyNode* policy, - State* state, std::mt19937* rand_gen) { - if (GetIntParam(policy->params, "disable_change_compute_location")) { - return 0; - } - - for (int stage_id = static_cast((*state)->stages.size()) - 1; stage_id
>= 0; stage_id--) { - const Stage& stage = (*state)->stages[stage_id]; - - if (stage->op_type == kPlaceholder) { - continue; - } - - if (IsTiled(stage) || stage->compute_at == kInlined) { - continue; - } - - if (NeedsMultilevelTiling(policy->cur_task, (*state), stage->op)) { - continue; - } - - std::unordered_set consumers; - - GetConsumers(policy->cur_task, (*state), stage->op, &consumers); - if (consumers.empty()) { - continue; - } - - int target_stage_id; - if (consumers.size() == 1) { - target_stage_id = OperationToStage(*consumers.begin(), *state); - } else { - // check all consumers share a common root - int common_root_id = -1; - bool mismatch = false; - for (const auto& consumer : consumers) { - int consumer_stage_id = OperationToStage(consumer, *state); - int root_id = -1; - if ((*state)->stages[consumer_stage_id]->compute_at == kRoot) { - root_id = consumer_stage_id; - } else if ((*state)->stages[consumer_stage_id]->compute_at == kIter) { - root_id = (*state)->attach_map->stage_to_attach_iter.at(consumer_stage_id).first; - } else { - LOG(FATAL) << "Invalid case"; - } - - if (common_root_id == -1) { - common_root_id = root_id; - } else { - if (common_root_id != root_id) { - mismatch = true; - break; - } - } - } - - if (mismatch) { - continue; - } - target_stage_id = common_root_id; - } - - const Stage& target_stage = (*state)->stages[target_stage_id]; - std::set to_unroll_name_set; - if (target_stage->op->attrs.count(policy->always_unroll_key)) { - to_unroll_name_set = GetIterNameSetParam(target_stage->op->attrs, - policy->always_unroll_key); - } - - std::vector > candidates; - bool target_compute_at_other = target_stage->compute_at == kIter; - bool target_is_tiled = IsTiled(target_stage); - - bool visited_reduce = false; - // enumerate compute_at location at target_stage - int ct = 0; - for (const auto& target_iter : target_stage->iters) { - if (target_iter->iter_type == kReduce) { - visited_reduce = true; - if (!target_is_tiled) { // do not go into reduce iter - break; - } - } else if (target_iter->iter_type == kSpace) { - if (visited_reduce) { // do not go into inner tile - break; - } - } - - if (to_unroll_name_set.count(target_iter->name)) { - // Do not go into always unroll region - break; - } - - if (GetExtent(target_iter) == 1) { // skip iterators with length of 1 - continue; - } - if (target_compute_at_other && target_iter->iter_type == kSpace && - StrEndsWith(target_iter->name, ".0")) { - // skip the first level iterators if target stage compute_at another stage - // In this case, the lengths of first level iterators are always one - continue; - } - candidates.emplace_back(target_stage_id, target_iter); - - if ((*state)->attach_map->iter_to_attached_stages.count( - std::make_pair(target_stage_id, ct++))) { - break; - } - } - - // if the target_stage is already compute_at another stage X, try also compute_at X - // We call stage X as `target_target_stage` - if (target_compute_at_other) { - int target_target_stage_id; - target_target_stage_id = (*state)->attach_map->stage_to_attach_iter.at( - target_stage_id).first; - const Stage& target_target_stage = (*state)->stages[target_target_stage_id]; - if (target_target_stage->op->attrs.count(policy->always_unroll_key)) { - to_unroll_name_set = GetIterNameSetParam(target_target_stage->op->attrs, - policy->always_unroll_key); - } else { - to_unroll_name_set.clear(); - } - - int ct = 0; - for (const auto& target_target_iter : target_target_stage->iters) { - if (target_target_iter->iter_type == kReduce || - 
(*state)->attach_map->iter_to_attached_stages.count( - std::make_pair(target_target_stage_id, ct++))) { - break; - } - - if (to_unroll_name_set.count(target_target_iter->name)) { - // Do not go into always unroll region - break; - } - - if (GetExtent(target_target_iter) == 1) { // skip iterators with length of 1 - continue; - } - - candidates.push_back(std::make_pair(target_target_stage_id, target_target_iter)); - } - } - - int choice = (*rand_gen)() % (candidates.size() + 2); - - if (choice == 0) { - if (!HasReduceIter(stage)) { - state->compute_inline(stage_id); - } - } else if (choice == 1) { - state->compute_root(stage_id); - } else { - choice = choice - 2; - state->compute_at(stage_id, candidates[choice].first, candidates[choice].second); - } - } - - return 0; -} - -int InitPopulationParallel(const SketchSearchPolicyNode* policy, - State* state) { - std::function - annotate_parallel; - - annotate_parallel = [&annotate_parallel]( - const SketchSearchPolicyNode* policy, State* state, int stage_id, int iter_offset) { - const Stage& stage = (*state)->stages[stage_id]; - - std::vector to_fuse; - int64_t parallel_degree = 1; - - // strategy: try to fuse and parallel the outermost n iterators - // Stop if we meet reduce iterator or we have enough parallel degree - size_t iter_id = iter_offset; - for (; iter_id < stage->iters.size(); ++iter_id) { - const Iterator& it = stage->iters[iter_id]; - if (it->iter_type == kReduce || it->annotation != kNone) { - break; - } - - to_fuse.push_back(it); - parallel_degree *= GetExtent(it); - - if (parallel_degree > policy->cur_task->hardware_params->num_cores * 16) { - break; - } - - if ((*state)->attach_map->iter_to_attached_stages.count( - std::make_pair(stage_id, iter_id))) { - break; - } - } - - if (parallel_degree == 1) { - auto res = - (*state)->attach_map->iter_to_attached_stages.find(std::make_pair(stage_id, iter_id)); - if (res != (*state)->attach_map->iter_to_attached_stages.end()) { - for (int attached_stage_id : res->second) { - annotate_parallel(policy, state, attached_stage_id, 0); - } - annotate_parallel(policy, state, stage_id, iter_id + 1); - } - } - - if (!to_fuse.empty()) { - if (to_fuse.size() == 1) { - state->parallel(stage_id, to_fuse[0]); - } else { - Iterator fused_iter = state->fuse(stage_id, to_fuse); - state->parallel(stage_id, fused_iter); - } - } - }; - - for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { - const Stage& stage = (*state)->stages[stage_id]; - if (stage->compute_at != kRoot || stage->op_type == kPlaceholder) { - continue; - } - - annotate_parallel(policy, state, stage_id, 0); - } - - return 0; -} - -int InitPopulationVectorization(const SketchSearchPolicyNode* policy, - State* state, std::mt19937* rand_gen) { - for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { - const Stage& stage = (*state)->stages[stage_id]; - - if (stage->op_type == kPlaceholder) { - continue; - } - - // Skip cooperative fetching stage - if (IS_GPU(policy->cur_task) && - HasCacheReadStage((*state), stage_id - 1)) { - continue; - } - - if (HasAnnotationIter(stage, IteratorAnnotation::kTensorized)) { - // Skip if this stage has been tensorized - continue; - } - - // try to fuse and vectorize the space iterators in the inner most tile - int cum_length_prod = 1; - - std::set to_unroll_name_set; - if (stage->op->attrs.count(policy->always_unroll_key)) { - to_unroll_name_set = GetIterNameSetParam(stage->op->attrs, - policy->always_unroll_key); - } - - int num_fusible = 0; - while (num_fusible < 
static_cast(stage->iters.size())) { - int iter_id = static_cast(stage->iters.size()) - 1 - num_fusible; - if ((*state)->attach_map->iter_to_attached_stages.count( - std::make_pair(stage_id, iter_id))) { - break; - } - - const Iterator& it = stage->iters[iter_id]; - - // Stop if we meet a reduce iterator - if (it->iter_type == kReduce || it->annotation != kNone || - to_unroll_name_set.count(it->name)) { - break; - } - - // Stop if the memory access is not continuous (vectorizable) - // Note: The check is too hard, so we use heuristic here - if (IsTiled(stage) && num_fusible != 0) { - // If the stage is tiled, then the memory access must not be continuous - // for the innermost two iterators - break; - } - - cum_length_prod *= GetExtent(it); - if (cum_length_prod > policy->cur_task->hardware_params->max_unroll_vec) { - break; - } - - num_fusible++; - } - - if (num_fusible > 1) { - num_fusible = 1 + (*rand_gen)() % (num_fusible - 1); // Select a random range to fuse - } - - if (num_fusible == 1) { - state->vectorize(stage_id, stage->iters.back()); - } else if (num_fusible > 1) { - std::vector to_fuse(stage->iters.end() - num_fusible, - stage->iters.end()); - state->vectorize(stage_id, state->fuse(stage_id, to_fuse)); - } - } - - return 0; -} - -int InitPopulationUnroll(const SketchSearchPolicyNode* policy, - State* state, std::mt19937* rand_gen) { - for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { - const Stage& stage = (*state)->stages[stage_id]; - - if (stage->op_type == kPlaceholder) { - continue; - } - - if (stage->op->attrs.count(policy->always_unroll_inner_key)) { - // Special unroll policy - auto to_unroll_name_set = GetIterNameSetParam(stage->op->attrs, - policy->always_unroll_inner_key); - std::set visited_names; - - // Unroll the space iterators and reduce iterators listed in the attrs - // in the innermost tile - int n = static_cast(stage->iters.size()) - 1; - visited_names.clear(); - while (n >= 0) { - const Iterator& it = stage->iters[n]; - - // If we meet two iterators that come from a same original iterator, - // then we are out of the innermost tile - size_t size_before = visited_names.size(); - ExtractOriginalIterators(it->name, &visited_names); - if (size_before == visited_names.size()) { - break; - } - - std::set name; - ExtractOriginalIterators(it->name, &name); - if (name.size() == 1 && to_unroll_name_set.count(*name.begin())) { - state->unroll(stage_id, it); - } - - n--; - } - } else if (stage->op->attrs.count(policy->always_unroll_key)) { - // Special unroll policy - auto to_unroll_name_set = GetIterNameSetParam(stage->op->attrs, - policy->always_unroll_key); - - // Unroll the space iterators and reduce iterators listed in the attrs - int n = static_cast(stage->iters.size()) - 1; - while (n >= 0) { - const Iterator& it = stage->iters[n]; - if (to_unroll_name_set.count(it->name)) { - state->unroll(stage_id, it); - } - n--; - } - } else if (HasReduceIter(stage)) { - // use auto unroll for multi level tiled stage - int value = policy->auto_unroll_configs[ - (*rand_gen)() % policy->auto_unroll_configs.size()]; - state->pragma(stage_id, (*state)->stages[stage_id]->iters[0], - std::string("auto_unroll_max_step") + "$" + std::to_string(value)); - } - } - - return 0; -} - -void SketchSearchPolicyNode::SampleInitPopulation(const std::vector& sketches, - int out_size, std::vector* out_states) { - std::uniform_real_distribution<> dis(0.0, 1.0); - int continue_count = 0; - - // TODO(...): Maybe try muti thread here - while (static_cast(out_states->size()) 
< out_size && - continue_count < out_size * 10) { - State tmp_s = sketches[rand_gen_() % sketches.size()]; - - InitPopulationFillTileSize(this, &tmp_s, &rand_gen_, &split_memo_); - - if (IS_GPU(cur_task)) { - tmp_s = cur_task->compute_dag.InferBound(tmp_s); - - if (InitPopulationThreadBind(this, &tmp_s)) { - continue_count++; - if (continue_count == out_size) { - StdCout(verbose) << "Initial Population Sampling..." << std::endl; - } - continue; - } - - InitPopulationCooperativeFetching(this, &tmp_s); - } else { - InitPopulationChangeComputeLocation(this, &tmp_s, &rand_gen_); - - tmp_s = cur_task->compute_dag.InferBound(tmp_s); - - InitPopulationParallel(this, &tmp_s); - } - - InitPopulationVectorization(this, &tmp_s, &rand_gen_); - - InitPopulationUnroll(this, &tmp_s, &rand_gen_); - - out_states->push_back(std::move(tmp_s)); - } - - StdCout(verbose) << "Sample Initial Population\t#s: " - << out_states->size() << std::endl; -} - -void SketchSearchPolicyNode::EvolutionarySearch( - const std::vector& init_population, - int num_best_states, std::vector* best_states) { - auto tic_begin = std::chrono::high_resolution_clock::now(); - - // Set parameters for genetic algorithm - int population = GetIntParam(params, "evolutionary_search_population"); - int num_iters = GetIntParam(params, "evolutionary_search_num_iters"); - double mutation_prob = GetDoubleParam(params, "evolutionary_search_mutation_prob"); - int num_cross_over = static_cast(population * 0.0); // NOT IMPLEMENTED currently - int num_cross_over_trial_upper_bound = num_cross_over * 3; - CostModel cost_model = program_cost_model; - - // Two ping pong buffers to avoid copy - std::vector states_buf1, states_buf2; - std::vector *pnow = &states_buf1, *pnext = &states_buf2; - states_buf1.reserve(population); - states_buf2.reserve(population); - states_buf1.insert(states_buf1.begin(), init_population.begin(), init_population.end()); - - // A heap to keep the best states during evolution - using StateItem = std::pair; - auto cmp = [](const StateItem& left, const StateItem& right) { - return left.second > right.second; - }; - std::vector heap; - std::unordered_set in_heap(measured_states_set_); - heap.reserve(num_best_states); - - // auxiliary global variables - std::vector scores; - std::vector prefix_sum_probs; - double max_score = 0.0; - scores.reserve(population); - prefix_sum_probs.reserve(population); - std::uniform_real_distribution<> dis(0.0, 1.0); - int mutation_fail_ct = 0; - - // Genetic Algorithm - for (int k = 0; k < num_iters + 1; ++k) { - // Maintain the heap - cur_task->compute_dag.InferBound(pnow); - PruneUndefined(pnow); - cost_model->Predict(cur_task, *pnow, &scores); - - for (size_t i = 0; i < pnow->size(); ++i) { - const State& state = (*pnow)[i]; - std::string state_str = state.ToStr(); - - if (in_heap.count(state_str) == 0) { - if (static_cast(heap.size()) < num_best_states) { - heap.emplace_back((*pnow)[i], scores[i]); - std::push_heap(heap.begin(), heap.end(), cmp); - in_heap.insert(state_str); - } else if (scores[i] > heap.front().second) { - std::string old_state_str = heap.front().first.ToStr(); - in_heap.erase(old_state_str); - in_heap.insert(state_str); - - std::pop_heap(heap.begin(), heap.end(), cmp); - heap.back() = StateItem(state, scores[i]); - std::push_heap(heap.begin(), heap.end(), cmp); - } - if (scores[i] > max_score) { - max_score = scores[i]; - } - } - } - - if (k % 5 == 0 || k == num_iters) { - StdCout(verbose) << "GA Iter: " << k << std::fixed << std::setprecision(4) - << "\tMax score: " << max_score - 
<< "\tMin score: " << heap.front().second - << "\tPop size: " << pnow->size() << std::endl; - } - - if (k == num_iters) { - break; - } - - // Compute selection probability - double sum = 0.0; - prefix_sum_probs.resize(scores.size()); - for (size_t i = 0; i < scores.size(); ++i) { - sum += std::max(scores[i], 0.0f); - prefix_sum_probs[i] = sum; - } - for (size_t i = 0; i < scores.size(); ++i) { - prefix_sum_probs[i] = prefix_sum_probs[i] / sum; - } - - // Do cross over - int ct = 0; - while (static_cast(pnext->size()) < num_cross_over - && ct < num_cross_over_trial_upper_bound) { - int p1 = RandomChoose(prefix_sum_probs, &rand_gen_); - int p2 = RandomChoose(prefix_sum_probs, &rand_gen_); - - if (p1 == p2) { - pnext->push_back((*pnow)[p1]); - } else { - State tmp_s = CrossOverState((*pnow)[p1], (*pnow)[p2]); - if (tmp_s.defined()) { - pnext->push_back(std::move(tmp_s)); - } - } - ct++; - } - - // Do mutation - mutation_fail_ct = 0; - while (static_cast(pnext->size()) < population) { - int id = RandomChoose(prefix_sum_probs, &rand_gen_); - - if (dis(rand_gen_) < mutation_prob) { - const std::vector rule_prefix_sum_probs{0.9, 1.0}; - - int rule_id = RandomChoose(rule_prefix_sum_probs, &rand_gen_); - - if (rule_id == 0) { - // Mutate Tile Size - State tmp_s = RandomMutateTileSize((*pnow)[id], &split_memo_, &rand_gen_, - cur_task->hardware_params->max_innermost_split_factor); - if (tmp_s.defined()) { - pnext->push_back(std::move(tmp_s)); - } else { - mutation_fail_ct++; - } - } else if (rule_id == 1) { - // Mutate auto-unroll max step. - State tmp_s = RandomMutateMaxUnrollStep((*pnow)[id], &rand_gen_, auto_unroll_configs); - if (tmp_s.defined()) { - pnext->push_back(std::move(tmp_s)); - } else { - mutation_fail_ct++; - } - } - } else { - pnext->push_back((*pnow)[id]); - } - } - - std::swap(pnext, pnow); pnext->clear(); - } - - // Copy best states in the heap to out_states - std::sort(heap.begin(), heap.end(), cmp); - best_states->clear(); - for (auto& item : heap) { - best_states->push_back(std::move(item.first)); - } - - double duration = std::chrono::duration_cast >( - std::chrono::high_resolution_clock::now()- tic_begin).count(); - StdCout(verbose) << "EvolutionarySearch\t\t#s: " << best_states->size() - << "\tTime elapsed: " - << std::fixed << std::setprecision(2) << duration << std::endl; -} - -class RuleCustomSketch : public SketchGenerationRule { - public: - RuleCustomSketch(PackedFunc meet_condition_func, PackedFunc apply_func) : - meet_condition_func_(meet_condition_func), apply_func_(apply_func) {} - - inline ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - auto ret = meet_condition_func_( - tvm::runtime::GetRef(policy), state, stage_id); - if (ret.type_code() == 0) { - return ConditionEnum(static_cast(ret)); - } else { - return kApplyAndSkipRest; - } - } - - inline std::vector > Apply( - const SketchSearchPolicyNode* policy, - const State& state, int stage_id) final { - std::vector > ret; - - Array> apply_ret = apply_func_( - tvm::runtime::GetRef(policy), state, stage_id); - - for (const auto& item : apply_ret) { - CHECK_EQ(item.size(), 2); - State state = Downcast(item[0]); - auto next = item[1].as(); - ret.emplace_back(state, next->value); - } - return ret; - } - - private: - PackedFunc meet_condition_func_; - PackedFunc apply_func_; -}; - -PreloadCustomSketchRule::PreloadCustomSketchRule(PackedFunc meet_condition_func, - PackedFunc apply_func) { - auto node = make_object(); - node->meet_condition_func = 
meet_condition_func; - node->apply_func = apply_func; - data_ = std::move(node); -} - -void PreloadCustomSketchRuleNode::callback(SearchPolicyNode* policy) { - CHECK(policy->IsInstance()); - auto sketch_policy = dynamic_cast(policy); - sketch_policy->sketch_rules.emplace_back( - new RuleCustomSketch(meet_condition_func, apply_func)); - StdCout(policy->verbose) << "Custom sketch rule added." << std::endl; -} - -TVM_REGISTER_GLOBAL("ansor.SketchSearchPolicy") -.set_body_typed([](CostModel program_cost_model, Map params, - int seed){ - return SketchSearchPolicy(program_cost_model, params, seed); -}); - -TVM_REGISTER_GLOBAL("ansor.PreloadCustomSketchRule") -.set_body_typed([](PackedFunc meet_condition_func, PackedFunc apply_func) { - return PreloadCustomSketchRule(meet_condition_func, apply_func); -}); - -} // namespace ansor -} // namespace tvm diff --git a/src/ansor/search_policy/sketch_search_policy.h b/src/ansor/search_policy/sketch_search_policy.h deleted file mode 100644 index 54a5cdd1fa4e..000000000000 --- a/src/ansor/search_policy/sketch_search_policy.h +++ /dev/null @@ -1,157 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file ansor/search_policy/sketch_search_policy.h - * \brief The search policy that searches in a hierarchical search space defined by sketches. - * The policy randomly samples programs from the space defined by sketches - * and use evolutionary search to fine-tune them. - */ - -#ifndef TVM_ANSOR_SEARCH_POLICY_SKETCH_SEARCH_POLICY_H_ -#define TVM_ANSOR_SEARCH_POLICY_SKETCH_SEARCH_POLICY_H_ - -#include -#include -#include -#include -#include -#include "search_policy.h" -#include "../cost_model/cost_model.h" -#include "../utils.h" - - -namespace tvm { -namespace ansor { - -class SketchGenerationRule; - -/*! - * \brief The search policy that searches in a hierarchical search space defined by sketches. - * The policy randomly samples programs from the space defined by sketches - * and use evolutionary search to fine-tune them. - */ -class SketchSearchPolicyNode: public SearchPolicyNode { - public: - /*! \brief The cost model for complete programs */ - CostModel program_cost_model; - /*! \brief Random generator */ - std::mt19937 rand_gen_; - /*! \brief The parameters for search. 
It stores the following parameters: - * int evolutionary_search_population // The population size for evolutionary search - * double evolutionary_search_mutation_prob // The probability of mutation for evolutionary search - * int evolutionary_search_num_iters; // The number of iterations for evolutionary search - * double evolutionary_search_use_measured_ratio; // The maximum percentage of measured states in the initial - * // population for evolutionary search - * double eps_greedy; // Always allocate this percentage of measurements to random sampled states - * str cpu_multi_level_tiling_structure // The structure of multi-level tiling for CPU - * str gpu_multi_level_tiling_structure // The structure of multi-level tiling for GPU - */ - Map params; - /*! \brief The rules to generate sketches */ - std::vector sketch_rules; - - /*! \brief Search and make n_trials measurements. - * \returns the best state */ - State Search(SearchTask task, int n_trials, - int early_stopping, int num_measure_per_iter, - int verbose, ProgramMeasurer measurer, - Array pre_search_callbacks) final; - - /*! \brief Continue search for one round. This is used by JointTuner - * \returns the measurement pairs */ - std::pair, Array > ContinueSearchOneRound( - SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) final; - - static constexpr const char *_type_key = "ansor.SketchSearchPolicy"; - static const std::vector auto_unroll_configs; - - TVM_DECLARE_FINAL_OBJECT_INFO(SketchSearchPolicyNode, SearchPolicyNode); - - protected: - /*! \brief Pick states from best states and random states with eps-greedy policy */ - void PickStatesWithEpsGreedy(std::vector* inputs, - const std::vector& best_states, - const std::vector& random_states, - int remaining_n_trials); - - private: - // Run one round of the search pipeline - void SearchOneRound(std::vector* best_states, - int num_random_states, std::vector* random_states); - - // Generate sketches without tile size - void GenerateSketch(std::vector* out_states); - - // Sample init population - void SampleInitPopulation(const std::vector& sketches, - int out_size, std::vector* out_states); - - // Perform evolutionary search - void EvolutionarySearch(const std::vector& init_population, - int num_best_states, std::vector* best_states); - - SplitFactorizationMemo split_memo_; // Memorize split space for Split - int num_measure_per_iter_; // The number of states to measure per iteration -}; - -/*! - * \brief Managed reference to SketchSearchPolicyNode. - * \sa SketchSearchPolicyNode - */ -class SketchSearchPolicy : public SearchPolicy { - public: - SketchSearchPolicy(CostModel program_cost_model, - Map params, - int seed); - - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SketchSearchPolicy, SearchPolicy, - SketchSearchPolicyNode); -}; - -/*! \brief Pre-search callback function to load custom rules for sketch generation */ -class PreloadCustomSketchRuleNode : public SearchCallbackNode { - public: - // TODO(jcf94): Use tvm::runtime::TypedPackedFunc? - PackedFunc meet_condition_func; - PackedFunc apply_func; - - void callback(SearchPolicyNode* policy) final; - - static constexpr const char *_type_key = "ansor.PreloadCustomSketchRule"; - TVM_DECLARE_FINAL_OBJECT_INFO(PreloadCustomSketchRuleNode, SearchCallbackNode); -}; - -/*! - * \brief Managed reference to PreloadCustomSketchRuleNode.
- * \sa PreloadCustomSketchRuleNode - */ -class PreloadCustomSketchRule : public SearchCallback { - public: - PreloadCustomSketchRule(PackedFunc meet_condition_func, - PackedFunc apply_func); - - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PreloadCustomSketchRule, SearchCallback, - PreloadCustomSketchRuleNode); -}; - -} // namespace ansor -} // namespace tvm - -#endif // TVM_ANSOR_SEARCH_POLICY_SKETCH_SEARCH_POLICY_H_ diff --git a/src/ansor/search_policy/utils.cc b/src/ansor/search_policy/utils.cc deleted file mode 100644 index 2d2f92ecbc20..000000000000 --- a/src/ansor/search_policy/utils.cc +++ /dev/null @@ -1,744 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file ansor/search_policy/utils.cc - * \brief Common utilities for search policies - */ - -#include "utils.h" -#include "search_policy.h" - -namespace tvm { -namespace ansor { - -void GetSpaceSplitStepIds(const State& s, int stage_id, std::vector* spatial_split_step_ids) { - auto pop = s->stages[stage_id]->op.as(); - CHECK(pop != nullptr); - - const auto& no_split_name_pair = QueryNoSplitAxis(s->stages[stage_id]); - const std::set& no_split_at_inner_name_set = no_split_name_pair.first; - const std::set& no_split_at_outer_name_set = no_split_name_pair.second; - - size_t reduce_count = 0; - for (const auto axis : pop->reduce_axis) { - if (!no_split_at_inner_name_set.count(axis->var->name_hint) && - !no_split_at_outer_name_set.count(axis->var->name_hint)) { - reduce_count++; - } - } - - for (int i = static_cast(s->transform_steps.size()) - 1; i >= 0; --i) { - if (s->transform_steps[i]->IsInstance() || - s->transform_steps[i]->IsInstance() || - s->transform_steps[i]->IsInstance()) { - if (stage_id > s->transform_steps[i]->stage_id) { - stage_id--; - } - } else if (auto ps = s->transform_steps[i].as()) { - if (stage_id == ps->stage_id) { - // Assume SplitStep on reduction axes are always after SplitStep on spatial axes. 
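// A sketch of why this assumption usually holds: for a te::ComputeOp the
// stage's iterators start out as all spatial axes followed by all reduction
// axes, and DoMultiLevelTiling (below in this file) splits them in that
// order, so the SplitSteps on reduction axes land after the spatial ones in
// transform_steps. The backward walk here therefore consumes the
// `reduce_count` reduction splits first and records the rest as spatial.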
- // TODO(jcf94): do not rely on this assumption - if (reduce_count) { - reduce_count--; - } else { - spatial_split_step_ids->push_back(i); - } - } - } - } -} - -State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format, - std::vector* spatial_split_step_ids) { - std::vector > space_levels; - std::vector > reduce_levels; - std::vector space_outer, space_inner, reduce_outer, reduce_inner; - std::vector split_res; - - for (const auto c : format) { - if (tolower(c) == 's') { - space_levels.emplace_back(); - } else if (tolower(c) == 'r') { - reduce_levels.emplace_back(); - } else { - LOG(FATAL) << "Invalid multi-level tiling format: " << format; - } - } - size_t n_space = space_levels.size(); - size_t n_reduce = reduce_levels.size(); - - spatial_split_step_ids->clear(); - - State tmp_s = state; - const Stage& stage = state->stages[stage_id]; - const auto& no_split_name_pair = QueryNoSplitAxis(stage); // handle special split strategy - const auto& last_split_is_one_name_set = QueryLastSplitIsOneAxis(stage); - const std::set& no_split_at_inner_name_set = no_split_name_pair.first; - const std::set& no_split_at_outer_name_set = no_split_name_pair.second; - - for (const auto& iter : state->stages[stage_id]->iters) { - if (iter->iter_type == kSpace) { - if (!no_split_at_inner_name_set.count(iter->name) && - !no_split_at_outer_name_set.count(iter->name)) { - CHECK_GE(n_space, 1); - int tmp_n_space = n_space; - - if (last_split_is_one_name_set.count(iter->name)) { - tmp_n_space--; - } - - if (tmp_n_space == 1) { - space_levels[0].push_back(iter); - } else { - split_res = tmp_s.split(stage_id, iter, std::vector(tmp_n_space - 1)); - for (int i = 0; i < tmp_n_space; i++) { - space_levels[i].push_back(std::move(split_res[i])); - } - spatial_split_step_ids->push_back(tmp_s->transform_steps.size() - 1); - } - } else { - if (no_split_at_inner_name_set.count(iter->name)) { - space_inner.push_back(iter); - } - if (no_split_at_outer_name_set.count(iter->name)) { - space_outer.push_back(iter); - } - } - } else if (iter->iter_type == kReduce) { - if (!no_split_at_inner_name_set.count(iter->name) && - !no_split_at_outer_name_set.count(iter->name)) { - CHECK_GE(n_reduce, 1); - - if (n_reduce == 1) { - reduce_levels[0].push_back(iter); - } else { - split_res = tmp_s.split(stage_id, iter, std::vector(n_reduce - 1)); - for (size_t i = 0; i < n_reduce; i++) { - reduce_levels[i].push_back(std::move(split_res[i])); - } - } - } else { - if (no_split_at_inner_name_set.count(iter->name)) { - reduce_inner.push_back(iter); - } - if (no_split_at_outer_name_set.count(iter->name)) { - reduce_outer.push_back(iter); - } - } - } else { - LOG(FATAL) << "Invalid iter type: " << iter->iter_type; - } - } - - if (!space_outer.empty()) { - CHECK(!space_levels.empty()); - space_levels.front().insert(space_levels.front().begin(), - std::make_move_iterator(space_outer.begin()), - std::make_move_iterator(space_outer.end())); - } - if (!space_inner.empty()) { - CHECK(!space_levels.empty()); - space_levels.back().insert(space_levels.back().begin(), - std::make_move_iterator(space_inner.begin()), - std::make_move_iterator(space_inner.end())); - } - - if (!reduce_outer.empty()) { - CHECK(!reduce_levels.empty()); - reduce_levels.front().insert(reduce_levels.front().begin(), - std::make_move_iterator(reduce_outer.begin()), - std::make_move_iterator(reduce_outer.end())); - } - if (!reduce_inner.empty()) { - CHECK(!reduce_levels.empty()); - reduce_levels.back().insert(reduce_levels.back().begin(), - 
std::make_move_iterator(reduce_inner.begin()), - std::make_move_iterator(reduce_inner.end())); - } - - std::vector order; - int space_ct = 0, reduce_ct = 0; - for (const auto c : format) { - if (tolower(c) == 's') { - order.insert(order.end(), std::make_move_iterator(space_levels[space_ct].begin()), - std::make_move_iterator(space_levels[space_ct].end())); - space_ct++; - } else if (tolower(c) == 'r') { - order.insert(order.end(), std::make_move_iterator(reduce_levels[reduce_ct].begin()), - std::make_move_iterator(reduce_levels[reduce_ct].end())); - reduce_ct++; - } else { - LOG(FATAL) << "Invalid multi-level tiling format: " << format; - } - } - - tmp_s.reorder(stage_id, order); - return tmp_s; -} - -State FollowTiling(const State& state, int stage_id, - const std::vector& split_step_ids, int n_split) { - if (n_split < 1 || n_split > 3) { - LOG(FATAL) << "Invalid split parts, currently only support 1, 2 and 3"; - } - // Apply up to three-level tiling structure: space_L0, space_L1, space_L2 - std::vector space_0, space_1, space_2, space_3; - std::vector split_res, tmp_order; - - auto pop = state->stages[stage_id]->op.as(); - CHECK(pop != nullptr); - const Stage& stage = state->stages[stage_id]; - const auto& no_split_name_pair = QueryNoSplitAxis(stage); // handle special split strategy - const std::set& no_split_at_inner_name_set = no_split_name_pair.first; - const std::set& no_split_at_outer_name_set = no_split_name_pair.second; - int no_split_at_inner_name_in_stage_cnt = 0; - int no_split_at_outer_name_in_stage_cnt = 0; - for (const auto& iter : state->stages[stage_id]->iters) { - no_split_at_inner_name_in_stage_cnt += no_split_at_inner_name_set.count(iter->name); - no_split_at_outer_name_in_stage_cnt += no_split_at_outer_name_set.count(iter->name); - } - - CHECK_EQ(state->stages[stage_id]->iters.size() - - no_split_at_inner_name_in_stage_cnt - - no_split_at_outer_name_in_stage_cnt, - split_step_ids.size()); - - State tmp_s = state; - int ct = 0; - for (const auto& iter : state->stages[stage_id]->iters) { - if (iter->iter_type == kSpace) { - // For a spatial iterator, split it into multiple iterators - if (!no_split_at_inner_name_set.count(iter->name) && - !no_split_at_outer_name_set.count(iter->name)) { - IteratorAnnotation ann_type = iter->annotation; - split_res = tmp_s.follow_split(stage_id, iter, split_step_ids[ct], - n_split); - // Restore annotation. 
Move unroll and vectorize to inner, move parallel - // to outer - switch (ann_type) { - case kUnroll: - split_res[n_split] = tmp_s.unroll(stage_id, split_res[n_split]); - break; - case kVectorize: - split_res[n_split] = tmp_s.vectorize(stage_id, split_res[n_split]); - break; - case kParallel: - split_res[0] = tmp_s.parallel(stage_id, split_res[0]); break; - default: - break; - } - - space_0.push_back(std::move(split_res[0])); - space_1.push_back(std::move(split_res[1])); - if (n_split >= 2) { - space_2.push_back(std::move(split_res[2])); - if (n_split == 3) { - space_3.push_back(std::move(split_res[3])); - } - } - ct++; - } else { - if (no_split_at_outer_name_set.count(iter->name)) { - space_0.push_back(iter); - } - if (no_split_at_inner_name_set.count(iter->name)) { - if (n_split == 1) { - space_1.push_back(iter); - } else if (n_split == 2) { - space_2.push_back(iter); - } else { - CHECK_EQ(n_split, 3); - space_3.push_back(iter); - } - } - } - } else { - LOG(FATAL) << "Invalid iter type: " << iter->iter_type; - } - } - - if (n_split == 3) { - ConcatenateMove(&tmp_order, &space_0, &space_1, &space_2, &space_3); - } else if (n_split == 2) { - ConcatenateMove(&tmp_order, &space_0, &space_1, &space_2); - } else { - ConcatenateMove(&tmp_order, &space_0, &space_1); - } - tmp_s.reorder(stage_id, tmp_order); - return tmp_s; -} - -State RandomMutateTileSize(const State& old_state, SplitFactorizationMemo* split_memo, - std::mt19937* random_gen, int max_innermost_split_factor) { - State tmp_s = old_state; - - // Extract all SplitStep - std::vector split_step_ids; - for (size_t i = 0; i < tmp_s->transform_steps.size(); ++i) { - if (auto ps = tmp_s->transform_steps[i].as()) { - if (ps->extent.defined() && ps->extent->IsInstance() && - GetIntImm(ps->lengths.back()) <= max_innermost_split_factor) { - split_step_ids.push_back(i); - } - } - } - if (split_step_ids.empty()) { - return State(); - } - - // Find a SplitStep with extent != 1 - int retry_ct = 0; - int64_t extent = 1; - int step_id; - const SplitStepNode* ps; - - do { - step_id = split_step_ids[(*random_gen)() % split_step_ids.size()]; - ps = tmp_s->transform_steps[step_id].as(); - CHECK(ps != nullptr); - extent = GetIntImm(ps->extent); - retry_ct += 1; - } while (retry_ct < static_cast(split_step_ids.size()) << 2 && - (extent == 1 || extent == 0)); - - if (extent == 0 || extent == 1) { - return State(); - } - - // Mutate tile size - std::vector lengths(ps->lengths.size() + 1, 1); - for (int i = 0; i < static_cast(ps->lengths.size()); ++i) { - lengths[i + 1] = GetIntImm(ps->lengths[i]); - } - lengths[0] = extent / ElementProduct(lengths); - - std::vector random_perm; - RandomPermutation(lengths.size(), &random_perm, random_gen); - - for (size_t i = 0; i < random_perm.size(); ++i) { - size_t src_idx = random_perm[i]; - int length = lengths[src_idx]; - - if (length == 1) { - continue; - } - - // Divide one factor from lengths[src_idx] and multiply it to lengths[dst_idx] - size_t dst_idx = random_perm[(i + 1) % random_perm.size()]; - - const std::vector& factors = split_memo->GetFactors(length); - CHECK_GE(factors.size(), 1); - - int divide_factor; - if (dst_idx == lengths.size() - 1) { - // Maintain the restriction of hardware_params.max_innermost_split_factor - int max_factor_index = static_cast(factors.size()) - 1; - for (; max_factor_index >= 1; max_factor_index--) { - if (factors[max_factor_index] * lengths[dst_idx] <= max_innermost_split_factor) { - break; - } - } - if (max_factor_index == 0) { - // failed on this dst_idx, try next one - 
continue; - } - divide_factor = factors[1 + (*random_gen)() % (max_factor_index)]; - } else { - divide_factor = factors[1 + (*random_gen)() % (factors.size() - 1)]; - } - - std::vector new_lengths; - for (size_t j = 1; j < lengths.size(); ++j) { - if (j == src_idx) { - new_lengths.emplace_back(lengths[j] / divide_factor); - } else if (j == dst_idx) { - new_lengths.emplace_back(lengths[j] * divide_factor); - } else { - new_lengths.emplace_back(lengths[j]); - } - } - - CHECK_LE(GetIntImm(new_lengths.back()), max_innermost_split_factor); - - auto pstate = tmp_s.CopyOnWrite(); - pstate->transform_steps[step_id] = - SplitStep(ps->stage_id, ps->iter_id, ps->extent, new_lengths, ps->inner_to_outer); - return tmp_s; - } - - return State(); -} - -State RandomMutateMaxUnrollStep(const State& old_state, std::mt19937* random_gen, - const std::vector& auto_unroll_configs) { - State tmp_s = old_state; - - // Extract all auto_unroll_max_step pragma steps. - std::vector annotate_steps; - for (size_t i = 0; i < old_state->transform_steps.size(); ++i) { - if (auto ps = tmp_s->transform_steps[i].as()) { - if (ps->pragma_type.find("auto_unroll_max_step") != std::string::npos) { - annotate_steps.push_back(i); - } - } - } - if (annotate_steps.empty()) { - return State(); - } - - // Randomly pick one step. - auto step_id = annotate_steps[(*random_gen)() % annotate_steps.size()]; - auto ps = tmp_s->transform_steps[step_id].as(); - auto val = std::to_string(auto_unroll_configs[(*random_gen)() % auto_unroll_configs.size()]); - - auto pstate = tmp_s.CopyOnWrite(); - pstate->transform_steps[step_id] = PragmaStep( - ps->stage_id, ps->iter_id, std::string("auto_unroll_max_step") + "$" + val); - return tmp_s; -} - -State RandomMutateParallel(const State& old_state, std::mt19937* random_gen, - const SearchTask& task, int verbose) { - // To make this mutation simple but promising, we only focus on a specific case that - // parallel was added to the outermost loop and the loop is generated by fusing other loops. - // In short, we mutate the step pattern of (fuse -> parallel). - - // Extract all parallel steps. - std::vector parallel_steps; - for (size_t s = 0; s < old_state->transform_steps.size(); ++s) { - auto ps = old_state->transform_steps[s].as(); - if (!ps || ps->annotation != kParallel) { - continue; - } - parallel_steps.push_back(s); - } - if (parallel_steps.empty()) { - StdCout(verbose) << "Parallel mutation failed: No parallel annotations" << std::endl; - return State(); - } - - // Randomly pick one step. - int retry_ct = 0; - size_t step_id = 0; - size_t stage_id = 0; - do { - step_id = parallel_steps[(*random_gen)() % parallel_steps.size()]; - auto step = old_state->transform_steps[step_id].as(); - stage_id = step->stage_id; - - // Check assumptions. - auto iter_id = step->iter_id; - if (iter_id == 0 && step_id > 0 && old_state->transform_steps[step_id - 1].as()) { - break; - } - retry_ct++; - } while (retry_ct <= 3); - - if (retry_ct > 3) { - StdCout(verbose) << "Parallel mutation failed: No valid parallel annotations" << std::endl; - return State(); - } - - // Replay a new state until the picked fuse step. - State tmp_s = task->compute_dag.GetInitState(); - for (size_t s = 0; s < step_id - 1; ++s) { - auto step = old_state->transform_steps[s]; - tmp_s.CopyOnWrite()->transform_steps.push_back(step); - tmp_s.DoStep(step, task->compute_dag); - } - - // Determine the fuse direction. - // 0: fuse less; 1: fuse more. 
- auto fuse_step = old_state->transform_steps[step_id - 1].as(); - std::vector fused_ids = fuse_step->fused_ids; - std::vector fuse_dir = {0.5, 1.0}; - - // The case where we can only fuse more. - if (fused_ids.size() == 1) { - fuse_dir[0] = 0.0; - } - - // The cases where we cannot fuse the next iterators. - if (old_state->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, 0)) > 0 || - tmp_s->stages[stage_id]->iters.size() == fused_ids.size() || - tmp_s->stages[stage_id]->iters[1]->iter_type == kReduce) { - // If we cannot fuse fewer iterators either, give up. - if (fuse_dir[0] == 0.0) { - StdCout(verbose) << "Parallel mutation failed: Cannot fuse more or less iters" << std::endl; - return State(); - } - fuse_dir[0] = 1.0; - } - - int iter_offset = 0; - if (RandomChoose(fuse_dir, random_gen) == 0) { - StdCout(verbose) << "Parallel mutation: release iter " << fused_ids.back() << std::endl; - fused_ids.pop_back(); - iter_offset = 1; - } else { - StdCout(verbose) << "Parallel mutation: include iter " << fused_ids.back() + 1 << std::endl; - fused_ids.push_back(fused_ids.back() + 1); - iter_offset = -1; - } - - // Replay the mutated fuse and annotation steps. - auto new_fuse_step = FuseStep(stage_id, fused_ids); - tmp_s.CopyOnWrite()->transform_steps.push_back(new_fuse_step); - tmp_s.DoStep(new_fuse_step, task->compute_dag); - tmp_s.CopyOnWrite()->transform_steps.push_back(old_state->transform_steps[step_id]); - tmp_s.DoStep(old_state->transform_steps[step_id], task->compute_dag); - - // Replay the remaining steps. - for (size_t s = step_id + 1; s < old_state->transform_steps.size(); ++s) { - auto step = old_state->transform_steps[s]; - if (step->stage_id == static_cast(stage_id)) { - // Since we changed the loop structure, iterator IDs in later steps applied to - // the same stage have to be adjusted. - auto ps = step.as(); - if (ps) { - if (ps->iter_id == 0) { - step = AnnotationStep(ps->stage_id, 0, ps->annotation); - } else { - CHECK_LE(ps->iter_id + iter_offset, tmp_s->stages[stage_id]->iters.size()); - step = AnnotationStep(ps->stage_id, ps->iter_id + iter_offset, ps->annotation); - } - } else { - StdCout(verbose) << "Parallel mutation: Cannot apply " << step << " after fuse" - << std::endl; - return State(); - } - } - tmp_s.CopyOnWrite()->transform_steps.push_back(step); - tmp_s.DoStep(step, task->compute_dag); - } - return tmp_s; -} - - -State RandomMutateComputeLocation(const State& old_state, std::mt19937* random_gen, - const SearchTask& task) { - // Extract all compute_at steps. 
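// The loop below keeps only ComputeAtSteps whose stage is neither tiled nor
// in need of multi-level tiling; one of them is then picked at random, a new
// (stage, iterator) attach location is chosen among its consumers, and the
// whole step history is replayed, returning an undefined State if the
// replay throws.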
- std::vector compute_at_steps; - for (size_t s = 0; s < old_state->transform_steps.size(); ++s) { - if (auto ps = old_state->transform_steps[s].as()) { - const Stage& stage = old_state->stages[ps->stage_id]; - if (IsTiled(stage)) { - continue; - } - - if (NeedsMultilevelTiling(task, old_state, stage->op)) { - continue; - } - compute_at_steps.push_back(s); - } - } - if (compute_at_steps.empty()) { - return State(); - } - - // Randomly pick one step - size_t step_id = compute_at_steps[(*random_gen)() % compute_at_steps.size()]; - auto ps = old_state->transform_steps[step_id].as(); - CHECK(ps != nullptr); - const Stage& stage = old_state->stages[ps->stage_id]; - - // Randomly pick one tile level - int new_compute_at_stage_id; - int new_compute_at_iter_id; - - // Copied from InitPopulationChangeComputeLocation - { - std::unordered_set consumers; - GetConsumers(task, old_state, stage->op, &consumers); - if (consumers.empty()) { - return State(); - } - - int target_stage_id; - if (consumers.size() == 1) { - target_stage_id = OperationToStage(*consumers.begin(), old_state); - } else { - // check all consumers share a common root - int common_root_id = -1; - bool mismatch = false; - for (const auto& consumer : consumers) { - int consumer_stage_id = OperationToStage(consumer, old_state); - int root_id = -1; - if ((old_state)->stages[consumer_stage_id]->compute_at == kRoot) { - root_id = consumer_stage_id; - } else if ((old_state)->stages[consumer_stage_id]->compute_at == kIter) { - root_id = (old_state)->attach_map->stage_to_attach_iter.at(consumer_stage_id).first; - } else { - LOG(FATAL) << "Invalid case"; - } - - if (common_root_id == -1) { - common_root_id = root_id; - } else { - if (common_root_id != root_id) { - mismatch = true; - break; - } - } - } - - if (mismatch) { - return State(); - } - target_stage_id = common_root_id; - } - - const Stage& target_stage = old_state->stages[target_stage_id]; - std::set to_unroll_name_set; - if (target_stage->op->attrs.count(SearchPolicyNode::always_unroll_key)) { - to_unroll_name_set = GetIterNameSetParam(target_stage->op->attrs, - SearchPolicyNode::always_unroll_key); - } - - std::vector > candidates; - bool target_compute_at_other = target_stage->compute_at == kIter; - bool target_is_tiled = IsTiled(target_stage); - - bool visited_reduce = false; - // enumerate compute_at location at target_stage - int ct = 0; - for (size_t iter_id = 0; iter_id < target_stage->iters.size(); ++iter_id) { - const auto& target_iter = target_stage->iters[iter_id]; - if (target_iter->iter_type == kReduce) { - visited_reduce = true; - if (!target_is_tiled) { // do not go into reduce iter - break; - } - } else if (target_iter->iter_type == kSpace) { - if (visited_reduce) { // do not go into inner tile - break; - } - } - - if (to_unroll_name_set.count(target_iter->name)) { - // Do not go into always unroll region - break; - } - - if (GetExtent(target_iter) == 1) { // skip iterators with length of 1 - continue; - } - if (target_compute_at_other && target_iter->iter_type == kSpace && - StrEndsWith(target_iter->name, ".0")) { - // skip the first level iterators if target stage compute_at another stage - // In this case, the lengths of first level iterators are always one - continue; - } - candidates.emplace_back(target_stage_id, iter_id); - - if ((old_state)->attach_map->iter_to_attached_stages.count( - std::make_pair(target_stage_id, ct++))) { - break; - } - } - - // if the target_stage is already compute_at another stage X, try also compute_at X - // We call stage X as 
`target_target_stage` - if (target_compute_at_other) { - int target_target_stage_id; - target_target_stage_id = (old_state)->attach_map->stage_to_attach_iter.at( - target_stage_id).first; - const Stage& target_target_stage = (old_state)->stages[target_target_stage_id]; - if (target_target_stage->op->attrs.count(SearchPolicyNode::always_unroll_key)) { - to_unroll_name_set = GetIterNameSetParam(target_target_stage->op->attrs, - SearchPolicyNode::always_unroll_key); - } else { - to_unroll_name_set.clear(); - } - - int ct = 0; - for (size_t iter_id = 0; iter_id < target_target_stage->iters.size(); ++iter_id) { - const auto& target_target_iter = target_target_stage->iters[iter_id]; - if (target_target_iter->iter_type == kReduce || - (old_state)->attach_map->iter_to_attached_stages.count( - std::make_pair(target_target_stage_id, ct++))) { - break; - } - - if (to_unroll_name_set.count(target_target_iter->name)) { - // Do not go into always unroll region - break; - } - - if (GetExtent(target_target_iter) == 1) { // skip iterators with length of 1 - continue; - } - - candidates.emplace_back(target_target_stage_id, iter_id); - } - } - - if (candidates.empty()) { - return State(); - } - - int choice = (*random_gen)() % (candidates.size()); - new_compute_at_stage_id = candidates[choice].first; - new_compute_at_iter_id = candidates[choice].second; - } - - // Replay a new state. - State tmp_s = task->compute_dag.GetInitState(); - for (size_t s = 0; s < old_state->transform_steps.size(); ++s) { - if (s == step_id) { - tmp_s.CopyOnWrite()->transform_steps.push_back( - ComputeAtStep(ps->stage_id, new_compute_at_stage_id, new_compute_at_iter_id)); - } else { - tmp_s.CopyOnWrite()->transform_steps.push_back(old_state->transform_steps[s]); - } - try { - tmp_s.DoStep(tmp_s->transform_steps.back(), task->compute_dag); - } catch (dmlc::Error &e) { - return State(); - } - } - - return tmp_s; -} - -void PruneUndefined(std::vector* states) { - size_t pt = 0; - for (size_t i = 0; i < states->size(); ++i) { - if (!(*states)[i].defined()) { - continue; - } - if (i != pt) { - (*states)[pt] = std::move((*states)[i]); - } - pt++; - } - - if (pt == 0) { - LOG(FATAL) << "All states are undefined."; - } else { - states->resize(pt); - } -} - -State CrossOverState(const State& p1, const State& p2) { return State(); } - -} // namespace ansor -} // namespace tvm - diff --git a/src/ansor/search_policy/utils.h b/src/ansor/search_policy/utils.h deleted file mode 100644 index 107e2ee72521..000000000000 --- a/src/ansor/search_policy/utils.h +++ /dev/null @@ -1,483 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! 
- * \file ansor/search_policy/utils.h - * \brief Common utilities for search policies - */ - -#ifndef TVM_ANSOR_SEARCH_POLICY_UTILS_H_ -#define TVM_ANSOR_SEARCH_POLICY_UTILS_H_ - -#include -#include -#include -#include -#include -#include -#include "../cost_model/cost_model.h" -#include "../utils.h" -#include "../loop_state.h" -#include "../transform_step.h" -#include "search_policy.h" - -namespace tvm { -namespace ansor { - -// Get an integer from a tvm str Map -inline int GetIntParam(const Map& attr_dict, - const std::string& key) { - CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; - auto pint = attr_dict[key].as(); - CHECK(pint != nullptr); - return pint->value; -} - -// Get a double from a tvm str Map -inline double GetDoubleParam(const Map& attr_dict, - const std::string& key) { - CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; - auto pdouble = attr_dict[key].as(); - CHECK(pdouble != nullptr); - return pdouble->value; -} - -// Get a string from a tvm str Map -inline std::string GetStringParam(const Map& attr_dict, - const std::string& key) { - CHECK_GT(attr_dict.count(key), 0) - << "Cannot find key: \"" << key << "\" in " << attr_dict; - const auto& target = attr_dict[key]; - if (auto pstr = target.as()) { - return pstr->value; - } - auto pstr = target.as(); - CHECK(pstr != nullptr); - return pstr->data; -} - -// Get an iterator name set from a tvm str Map -inline std::set GetIterNameSetParam(const Map& attr_dict, - const std::string& key) { - std::set ret; - CHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; - auto names = attr_dict[key].as(); - CHECK(names != nullptr); - for (const auto & name : *names) { - ret.insert(name.as()->value); - } - return ret; -} - -// Convert operation to stage id -inline int OperationToStage(const te::Operation& op, const State& state) { - for (size_t i = 0; i < state->stages.size(); ++i) { - if (op == state->stages[i]->op) { - return i; - } - } - LOG(FATAL) << "Cannot find op: " << op; - return -1; -} - -// Return the extent of an iterator -inline int64_t GetExtent(const Iterator& it) { - if (it->range.defined()) { - if (auto pint = it->range->extent.as()) { - return pint->value; - } - } - return -1; -} - -// Return whether an op is strictly inlineable -inline bool IsStrictInlineable(const SearchTask& task, - const State& state, const te::Operation& op) { - if (state->task_dag.defined()) { - return state->task_dag->access_analyzer.IsStrictInlineable(op); - } else { - return task->compute_dag->access_analyzer.IsStrictInlineable(op); - } -} - -// Return whether an op is an output op -inline bool IsOutputOp(const SearchTask& task, const State& state, const te::Operation& op) { - if (state->task_dag.defined()) { - return state->task_dag->access_analyzer.IsOutput(op); - } else { - return task->compute_dag->access_analyzer.IsOutput(op); - } -} - -// Return whether the stage has an attribute flag -inline bool HasAttrsFlag(const State& state, int stage_id, const char* target) { - if (state->stages[stage_id]->op->attrs.count(target)) { - return GetStringParam(state->stages[stage_id]->op->attrs, target) == "True"; - } - return false; -} - -// Return whether the stage has reduce iterators -inline bool HasReduceIter(const Stage& stage) { - for (const auto& iter : stage->iters) { - if (iter->iter_type != kSpace) { - return true; - } - } - return false; -} - -// Return whether the stage has specific annotated iterators -inline bool 
HasAnnotationIter(const Stage& stage, IteratorAnnotation type) { - for (const auto& iter : stage->iters) { - if (iter->annotation == type) { - return true; - } - } - return false; -} - -// Return whether an op needs multi level tiling -inline bool NeedsMultilevelTiling(const SearchTask& task, - const State& state, const te::Operation& op) { - if (state->task_dag.defined()) { - return state->task_dag->access_analyzer.NeedsMultiLevelTiling(op); - } else { - return task->compute_dag->access_analyzer.NeedsMultiLevelTiling(op); - } -} - -// Get all consumers for an op. This will take inline into consideration -inline void GetConsumers(const SearchTask& task, const State& state, const te::Operation& op, - std::unordered_set* consumers) { - if (state->task_dag.defined()) { - state->task_dag->access_analyzer.GetConsumers(state, op, consumers); - } else { - task->compute_dag->access_analyzer.GetConsumers(state, op, consumers); - } -} - -inline void GetProducers(const SearchTask& task, const State& state, const te::Operation& op, - std::unordered_set* producers) { - if (state->task_dag.defined()) { - state->task_dag->access_analyzer.GetProducers(state, op, producers); - } else { - task->compute_dag->access_analyzer.GetProducers(state, op, producers); - } -} - -// Return whether two ops are elementwise-matched -inline bool ElementwiseMatch(const SearchTask& task, const State& state, const te::Operation& op, - const te::Operation& target_op) { - if (state->task_dag.defined()) { - return state->task_dag->access_analyzer.ElementWiseMatch(op, target_op); - } else { - return task->compute_dag->access_analyzer.ElementWiseMatch(op, target_op); - } -} - -// Return whether the stage has only one consumer and they are elementwise-matched -inline bool HasSingleElementwiseMatchedConsumer(const SearchTask& task, - const State& state, const Stage& stage, int* target_stage_id) { - std::unordered_set consumers; - - GetConsumers(task, state, stage->op, &consumers); - if (consumers.size() == 1) { - *target_stage_id = OperationToStage(*consumers.begin(), state); - const Stage& target_stage = state->stages[*target_stage_id]; - if (ElementwiseMatch(task, state, stage->op, target_stage->op) && - (!(HasReduceIter(stage) && HasReduceIter(target_stage)))) { - return true; - } - } - return false; -} - -// Return whether this stage needs rfactor -inline bool NeedsRfactor(const SearchTask& task, const State& state, const te::Operation& op) { - if (op->IsInstance()) { - // Compute the product of lengths of all space iters and all reduce iters - int64_t cum_space_len = 1, cum_reduce_len = 1; - int stage_id = OperationToStage(op, state); - for (const auto& iter : state->stages[stage_id]->iters) { - if (iter->iter_type == kSpace) { - cum_space_len *= GetExtent(iter); - } else if (iter->iter_type == kReduce) { - cum_reduce_len *= GetExtent(iter); - } - } - - if (NeedsMultilevelTiling(task, state, op)) { - // Do not use rfactor if we have enough parallelism on space iters - if (cum_space_len > cum_reduce_len || - cum_space_len > task->hardware_params->num_cores * 16) { - return false; - } else { - return true; - } - } else if (cum_reduce_len > 1) { - // Always try rfactor for reduction ops - return true; - } - } - - return false; -} - -// Return whether the state did cache_write for stage_id -inline bool HasCacheWriteStage(const State& s, int stage_id) { - for (int i = static_cast(s->transform_steps.size()) - 1; i >= 0; --i) { - if (auto ps = s->transform_steps[i].as()) { - if (stage_id > ps->stage_id) { - stage_id--; - } else if 
(stage_id == ps->stage_id) { - return true; - } - } else if (auto ps = s->transform_steps[i].as()) { - if (stage_id > ps->stage_id) { - stage_id--; - } - } else if (auto ps = s->transform_steps[i].as()) { - if (stage_id > ps->stage_id) { - stage_id--; - } - } - } - return false; -} - -// Return whether the state did cache_read for stage_id -inline bool HasCacheReadStage(const State& s, int stage_id) { - for (int i = static_cast(s->transform_steps.size()) - 1; i >= 0; --i) { - if (auto ps = s->transform_steps[i].as()) { - if (stage_id > ps->stage_id) { - stage_id--; - } - } else if (auto ps = s->transform_steps[i].as()) { - if (stage_id > ps->stage_id) { - stage_id--; - } else if (stage_id == ps->stage_id) { - return true; - } - } else if (auto ps = s->transform_steps[i].as()) { - if (stage_id > ps->stage_id) { - stage_id--; - } - } - } - return false; -} - -// Return whether the state did split/follow_split/follow_fused_split in stage_id -inline bool HasSplitStep(const State& s, int stage_id) { - for (int i = static_cast(s->transform_steps.size()) - 1; i >= 0; --i) { - if (s->transform_steps[i]->IsInstance() || - s->transform_steps[i]->IsInstance() || - s->transform_steps[i]->IsInstance()) { - if (stage_id > s->transform_steps[i]->stage_id) { - stage_id--; - } - } else if (s->transform_steps[i]->IsInstance() || - s->transform_steps[i]->IsInstance() || - s->transform_steps[i]->IsInstance()) { - if (stage_id == s->transform_steps[i]->stage_id) { - return true; - } - } - } - return false; -} - -// Return whether the stage has been tiled already -inline bool IsTiled(const Stage& stage) { - auto op = stage->op.as(); - CHECK(op != nullptr); - return stage->iters.size() != op->axis.size() + op->reduce_axis.size(); -} - -// Query axes that should not be split according to the attribute from tvm.compute -inline std::pair, std::set > QueryNoSplitAxis( - const Stage& stage) { - std::pair, std::set > ret; - if (stage->op->attrs.count(SearchPolicyNode::no_split_at_inner_key)) { - ret.first = GetIterNameSetParam(stage->op->attrs, SearchPolicyNode::no_split_at_inner_key); - } - if (stage->op->attrs.count(SearchPolicyNode::no_split_at_outer_key)) { - ret.second = GetIterNameSetParam(stage->op->attrs, SearchPolicyNode::no_split_at_outer_key); - } - return ret; -} - -// Query axes whose last split factor is one -inline std::set QueryLastSplitIsOneAxis(const Stage& stage) { - std::set ret; - if (stage->op->attrs.count(SearchPolicyNode::last_split_is_one_key)) { - ret = GetIterNameSetParam(stage->op->attrs, SearchPolicyNode::last_split_is_one_key); - } - return ret; -} - -// Extract primitive iterators from a nested fused or split iterator's name -inline void ExtractOriginalIterators(const std::string& name, std::set* rets) { - size_t last_pos = 0; - for (size_t i = 0; i < name.size(); ++i) { - if (name[i] == '@' || name[i] == '.') { // '@' for fuse and '.' 
for split - if (!isdigit(name[last_pos]) && name[last_pos] != '@' && name[last_pos] != '.') { - rets->insert(name.substr(last_pos, i - last_pos)); - } - last_pos = i + 1; - } - } - - if (last_pos < name.size() && !isdigit(name[last_pos]) && - name[last_pos] != '@' && name[last_pos] != '.') { - rets->insert(name.substr(last_pos, name.size() - last_pos)); - } -} - -// Get the last space iterator in the outermost tile -inline const Iterator& GetLastSpaceIteratorInOutermostTile(const Stage& stage) { - auto pop = stage->op.as(); - CHECK(pop != nullptr); - std::set original_names; - - for (const auto& iter : stage->iters) { - ExtractOriginalIterators(iter->name, &original_names); - if (original_names.size() == pop->axis.size()) { - return iter; - } - } - - LOG(FATAL) << "Cannot find the iterator."; - return stage->iters[0]; -} - -// Get the last reduce iterator in the outermost reduce tile -inline const Iterator& GetLastReduceIteratorInOutermostReduceTile(const Stage& stage) { - auto pop = stage->op.as(); - CHECK(pop != nullptr); - std::set original_names; - - auto no_split_name_pair = QueryNoSplitAxis(stage); - std::set no_split_at_inner_name_set = no_split_name_pair.first; - size_t axis_size = 0; - for (const auto axis : pop->axis) { - if (!no_split_at_inner_name_set.count(axis->var->name_hint)) { - axis_size++; - } - } - size_t reduce_axis_size = 0; - for (const auto axis : pop->reduce_axis) { - if (!no_split_at_inner_name_set.count(axis->var->name_hint)) { - reduce_axis_size++; - } - } - - if (reduce_axis_size) { - for (const auto& iter : stage->iters) { - ExtractOriginalIterators(iter->name, &original_names); - if (original_names.size() == axis_size + reduce_axis_size) { - return iter; - } - } - } else { - for (size_t i = 0; i < stage->iters.size(); i++) { - ExtractOriginalIterators(stage->iters[i]->name, &original_names); - if (original_names.size() == axis_size + 1) { - return stage->iters[i-1]; - } - } - } - - LOG(FATAL) << "Cannot find the iterator."; - return stage->iters[0]; -} - -// Randomly sample states -inline void RandomSampleStates(const std::vector& in_states, std::mt19937* random_gen, - size_t out_size, std::vector* out_states) { - out_states->clear(); - for (size_t i = 0; i < out_size; i++) { - out_states->push_back(in_states[(*random_gen)() % in_states.size()]); - } -} - -// Randomly choose an index according to a prefix-sum probability -inline int RandomChoose(const std::vector& prefix_sum_probs, std::mt19937* random_gen) { - std::uniform_real_distribution<> dis(0.0, 1.0); - double x = dis(*random_gen); - - CHECK(!prefix_sum_probs.empty()); - - return std::lower_bound(prefix_sum_probs.begin(), prefix_sum_probs.end(), x) - - prefix_sum_probs.begin(); -} - -// Print all states -inline void PrintAllStates(const std::vector& states) { - for (size_t i = 0; i < states.size(); ++i) { - std::cerr << i << std::endl; - std::cerr << states[i]; - std::cerr << "==============================================" << std::endl; - } -} - -// Get all split steps on spatial iterators for one stage -void GetSpaceSplitStepIds(const State& s, int stage_id, std::vector* spatial_split_step_ids); - -// Apply multi-level tiling structure according to a string format, -// where "S" stands for a space level and "R" stands for a reduction level. -// For example, if the format is "SSRSRS", then we will -// use tiling structure: space_L0, space_L1, reduce_L0, space_L2, reduce_L1, space_L3 -// For example, if we apply "SSRSRS" to matrix multiplication, -// we have space iterators i and j, reduce iterator k. 
-// Then the tiling structure is : i0, j0, i1, j1, k0, i2, j2, k1, i3, j3 -State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format, - std::vector* spatial_split_step_ids); - -// Apply tiling structure: space, space, space, ..., with tile sizes from other SplitStep -State FollowTiling(const State& state, int stage_id, - const std::vector& split_step_ids, int n_split); - -// Randomly mutate the tile size of one SplitStep -State RandomMutateTileSize(const State& old_state, SplitFactorizationMemo* split_memo, - std::mt19937* random_gen, int max_innermost_split_factor); - -// Randomly mutate the value of one auto_unroll_max_step PragmaStep -State RandomMutateMaxUnrollStep(const State& old_state, std::mt19937* random_gen, - const std::vector& auto_unroll_configs); - -// Randomly mutate the parallel degree of one stage. -State RandomMutateParallel(const State& old_state, std::mt19937* random_gen, - const SearchTask& task, int verbose = 0); - -// Randomly mutate the computation location of one stage. -State RandomMutateComputeLocation(const State& old_state, std::mt19937* random_gen, - const SearchTask& task); - -// GA: Crossover two states -State CrossOverState(const State& p1, const State& p2); - -// Prune undefined states. -void PruneUndefined(std::vector* states); - -} // namespace ansor -} // namespace tvm - -#endif // TVM_ANSOR_SEARCH_POLICY_UTILS_H_ diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc index d84c3c57dc86..939fca83f1fb 100644 --- a/src/ansor/serialization.cc +++ b/src/ansor/serialization.cc @@ -96,26 +96,6 @@ struct Handler > { } writer->WriteArrayItem(IntArrayToVector(&tmp, ps->lengths)); writer->WriteArrayItem(static_cast(ps->inner_to_outer)); - } else if (auto ps = data[i].as<::tvm::ansor::FollowSplitStepNode>()) { - writer->WriteArrayItem(std::string("FSP")); - writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(ps->iter_id); - writer->WriteArrayItem(ps->src_step_id); - writer->WriteArrayItem(ps->n_split); - } else if (auto ps = data[i].as<::tvm::ansor::FollowFusedSplitStepNode>()) { - writer->WriteArrayItem(std::string("FFSP")); - writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(ps->iter_id); - - writer->WriteArraySeperator(); - writer->BeginArray(false); - for (int x : ps->src_step_ids) { - writer->WriteArrayItem(x); - } - writer->EndArray(); - - writer->WriteArrayItem(ps->level); - writer->WriteArrayItem(static_cast(ps->factor_or_nparts)); } else if (auto ps = data[i].as<::tvm::ansor::FuseStepNode>()) { writer->WriteArrayItem(std::string("FU")); writer->WriteArrayItem(ps->stage_id); @@ -126,52 +106,6 @@ struct Handler > { writer->WriteArrayItem(x); } writer->EndArray(); - } else if (auto ps = data[i].as<::tvm::ansor::AnnotationStepNode>()) { - writer->WriteArrayItem(std::string("AN")); - writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(ps->iter_id); - writer->WriteArrayItem(static_cast(ps->annotation)); - } else if (auto ps = data[i].as<::tvm::ansor::ComputeAtStepNode>()) { - writer->WriteArrayItem(std::string("CA")); - writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(ps->target_stage_id); - writer->WriteArrayItem(ps->target_iter_id); - } else if (auto ps = data[i].as<::tvm::ansor::ComputeRootStepNode>()) { - writer->WriteArrayItem(std::string("CR")); - writer->WriteArrayItem(ps->stage_id); - } else if (auto ps = data[i].as<::tvm::ansor::ComputeInlineStepNode>()) { - writer->WriteArrayItem(std::string("CI")); - writer->WriteArrayItem(ps->stage_id); - } else if (auto ps = 
data[i].as<::tvm::ansor::CacheReadStepNode>()) { - writer->WriteArrayItem(std::string("CHR")); - writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(ps->scope_name); - writer->WriteArrayItem(ps->reader_stage_ids); - } else if (auto ps = data[i].as<::tvm::ansor::CacheWriteStepNode>()) { - writer->WriteArrayItem(std::string("CHW")); - writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(ps->scope_name); - } else if (auto ps = data[i].as<::tvm::ansor::PragmaStepNode>()) { - writer->WriteArrayItem(std::string("PR")); - writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(ps->iter_id); - writer->WriteArrayItem(ps->pragma_type); - } else if (auto ps = data[i].as<::tvm::ansor::RfactorStepNode>()) { - writer->WriteArrayItem(std::string("RF")); - writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(ps->iter_id); - writer->WriteArrayItem(ps->factor_iter_id); - } else if (auto ps = data[i].as<::tvm::ansor::StorageAlignStepNode>()) { - writer->WriteArrayItem(std::string("SA")); - writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(ps->iter_id); - writer->WriteArrayItem(ps->factor); - writer->WriteArrayItem(ps->offset); - } else if (auto ps = data[i].as<::tvm::ansor::TensorizeStepNode>()) { - writer->WriteArrayItem(std::string("TS")); - writer->WriteArrayItem(ps->stage_id); - writer->WriteArrayItem(ps->iter_id); - writer->WriteArrayItem(ps->ti_func_name); } else { LOG(FATAL) << "Invalid step: " << data[i]; } @@ -183,10 +117,9 @@ struct Handler > { inline static void Read(dmlc::JSONReader* reader, std::vector<::tvm::ansor::Step> * data) { std::vector int_list; - bool s, inner_to_outer, factor_or_nparts; + bool s, inner_to_outer; std::string name, scope_name, pragma_type, ti_func_name; - int stage_id, target_stage_id, iter_id, src_step_id, n_split, ann, extent; - int level, factor_iter_id, factor, offset; + int stage_id, iter_id, extent; reader->BeginArray(); data->clear(); @@ -215,116 +148,12 @@ struct Handler > { stage_id, iter_id, extent, std::vector<::tvm::PrimExpr>(int_list.begin(), int_list.end()), inner_to_outer)); - } else if (name == "FSP") { - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&stage_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&iter_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&src_step_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&n_split); - data->push_back(::tvm::ansor::FollowSplitStep( - stage_id, iter_id, src_step_id, n_split)); - } else if (name == "FFSP") { - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&stage_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&iter_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&int_list); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&level); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&factor_or_nparts); - data->push_back(::tvm::ansor::FollowFusedSplitStep( - stage_id, iter_id, int_list, level, factor_or_nparts)); } else if (name == "FU") { s = reader->NextArrayItem(); CHECK(s); reader->Read(&stage_id); s = reader->NextArrayItem(); CHECK(s); reader->Read(&int_list); data->push_back(::tvm::ansor::FuseStep(stage_id, int_list)); - } else if (name == "AN") { - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&stage_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&iter_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&ann); - data->push_back(::tvm::ansor::AnnotationStep(stage_id, - iter_id, ::tvm::ansor::IteratorAnnotation(ann))); - } else if (name == 
"CA") { - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&stage_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&target_stage_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&iter_id); - data->push_back(::tvm::ansor::ComputeAtStep( - stage_id, target_stage_id, iter_id)); - } else if (name == "CR") { - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&stage_id); - data->push_back(::tvm::ansor::ComputeRootStep(stage_id)); - } else if (name == "CI") { - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&stage_id); - data->push_back(::tvm::ansor::ComputeInlineStep(stage_id)); - } else if (name == "CHR") { - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&stage_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&scope_name); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&int_list); - data->push_back(::tvm::ansor::CacheReadStep( - stage_id, scope_name, int_list)); - } else if (name == "CHW") { - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&stage_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&scope_name); - data->push_back(::tvm::ansor::CacheWriteStep( - stage_id, scope_name)); - } else if (name == "PR") { - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&stage_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&iter_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&pragma_type); - data->push_back(::tvm::ansor::PragmaStep( - stage_id, iter_id, pragma_type)); - } else if (name == "RF") { - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&stage_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&iter_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&factor_iter_id); - data->push_back(::tvm::ansor::RfactorStep( - stage_id, iter_id, factor_iter_id)); - } else if (name == "SA") { - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&stage_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&iter_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&factor); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&offset); - data->push_back(::tvm::ansor::StorageAlignStep( - stage_id, iter_id, factor, offset)); - } else if (name == "TS") { - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&stage_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&iter_id); - s = reader->NextArrayItem(); CHECK(s); - reader->Read(&ti_func_name); - data->push_back(::tvm::ansor::TensorizeStep( - stage_id, iter_id, ti_func_name)); } else { LOG(FATAL) << "Invalid step format"; } diff --git a/src/ansor/transform_step.cc b/src/ansor/transform_step.cc index e882a0495263..1bcea3f690c9 100644 --- a/src/ansor/transform_step.cc +++ b/src/ansor/transform_step.cc @@ -183,107 +183,6 @@ std::string SplitStepNode::PrintAsPythonAPI( lengths, inner_to_outer); } -/********** Follow Split **********/ -FollowSplitStep::FollowSplitStep(int stage_id, int iter_id, - int src_step_id, int n_split) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->src_step_id = src_step_id; - node->n_split = n_split; - data_ = std::move(node); -} - -void FollowSplitStepNode::ExtractSplitLengths( - const std::vector& transform_steps, - std::vector* lengths) const { - CHECK_LT(src_step_id, transform_steps.size()); - auto ps = transform_steps[src_step_id].as(); - CHECK(ps != nullptr); - - // get lengths from src step - lengths->reserve(n_split); - int j = 0; - for (; j < n_split - 1; ++j) { - lengths->push_back(ps->lengths[j]); - } - 
PrimExpr last_factor = 1; - for (; j < static_cast(ps->lengths.size()); ++j) { - if (ps->lengths[j].defined()) { - last_factor *= ps->lengths[j]; - } else { - last_factor = PrimExpr(); - break; - } - } - lengths->push_back(std::move(last_factor)); -} - -std::vector FollowSplitStepNode::ApplyToSchedule( - std::vector *stages, StageToAxesMap *stage_to_axes, - const std::vector& transform_steps) const { - std::vector lengths; - ExtractSplitLengths(transform_steps, &lengths); - return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, - lengths, true); - } - -std::string FollowSplitStepNode::PrintAsPythonAPI( - std::vector *stages, StageToAxesMap *stage_to_axes, - te::Schedule *schedule, const std::vector& transform_steps) const { - std::vector lengths; - ExtractSplitLengths(transform_steps, &lengths); - return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, - lengths, true); -} - -/********** Follow Fused Split **********/ -FollowFusedSplitStep::FollowFusedSplitStep(int stage_id, int iter_id, - const std::vector& src_step_ids, int level, bool factor_or_nparts) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->src_step_ids = src_step_ids; - node->level = level; - node->factor_or_nparts = factor_or_nparts; - data_ = std::move(node); -} - -PrimExpr FollowFusedSplitStepNode::ExtractSplitLength( - const std::vector& transform_steps) const { - PrimExpr ret(1); - - for (int src_step_id : src_step_ids) { - CHECK_LT(src_step_id, transform_steps.size()); - auto ps = transform_steps[src_step_id].as(); - CHECK(ps != nullptr); - if (ps->lengths[level].defined() && ret.defined()) { - ret *= ps->lengths[level]; - } else { - return PrimExpr(); - } - } - - return ret; -} - -std::vector FollowFusedSplitStepNode::ApplyToSchedule( - std::vector *stages, StageToAxesMap *stage_to_axes, - const std::vector& transform_steps) const { - const PrimExpr& length = ExtractSplitLength(transform_steps); - return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, - {length}, factor_or_nparts); -} - -std::string FollowFusedSplitStepNode::PrintAsPythonAPI( - std::vector *stages, StageToAxesMap *stage_to_axes, - te::Schedule *schedule, const std::vector& transform_steps) const { - const PrimExpr& length = ExtractSplitLength(transform_steps); - return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, - {length}, factor_or_nparts); -} - - /********** Fuse **********/ -FuseStep::FuseStep(int stage_id, const std::vector& fused_ids) { - auto node = make_object(); @@ -337,506 +236,5 @@ std::string FuseStepNode::PrintAsPythonAPI(std::vector *stages, return ss.str(); } -/********** Annotation **********/ -AnnotationStep::AnnotationStep(int stage_id, int iter_id, - IteratorAnnotation ann) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->annotation = ann; - data_ = std::move(node); -} - -void AnnotationStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { - te::Stage& stage = (*stages)[stage_id]; - const std::vector& axes = (*stage_to_axes)[stage]; - - switch (annotation) { - case kUnroll: stage.unroll(axes[iter_id]); break; - case kVectorize: stage.vectorize(axes[iter_id]); break; - case kParallel: stage.parallel(axes[iter_id]); break; - case kVThread: stage.bind(axes[iter_id], te::thread_axis(Range(), "vthread")); break; - case kBlockX: stage.bind(axes[iter_id], te::thread_axis(Range(), "blockIdx.x")); break; - case kBlockY: stage.bind(axes[iter_id], 
te::thread_axis(Range(), "blockIdx.y")); break; - case kThreadX: - if (axes[iter_id]->iter_type == kCommReduce) { - const auto &thread_x = te::thread_axis(Range(), "threadIdx.x"); - stage.bind(axes[iter_id], thread_x); - stage.set_store_predicate(thread_x->var == 0); - } else { - stage.bind(axes[iter_id], te::thread_axis(Range(), "threadIdx.x")); - } - break; - case kThreadY: stage.bind(axes[iter_id], te::thread_axis(Range(), "threadIdx.y")); break; - case kNone: break; - default: LOG(FATAL) << "Invalid Annotation " << annotation; break; - } -} - -std::string AnnotationStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - const auto& stage = (*stages)[stage_id]; - const auto& iter = (*stage_to_axes)[stage][iter_id]; - - bool bind_reduce_iter = iter->iter_type == kCommReduce && annotation == kThreadX; - if (bind_reduce_iter) { - ss << "thread_x = tvm.thread_axis(\"threadIdx.x\")\n"; - } - - ss << "s[" << CleanName(stage->op->name) << "]."; - switch (annotation) { - case kUnroll: ss << "unroll("; break; - case kVectorize: ss << "vectorize("; break; - case kParallel: ss << "parallel("; break; - case kVThread: - case kBlockX: - case kBlockY: - case kThreadX: - case kThreadY: ss << "bind("; break; - case kNone: break; - default: - LOG(FATAL) << "Invalid annotation " << annotation; break; - } - ss << CleanName(iter->var->name_hint); - switch (annotation) { - case kVThread: ss << ", tvm.thread_axis(\"vthread\")"; break; - case kBlockX: ss << ", tvm.thread_axis(\"blockIdx.x\")"; break; - case kBlockY: ss << ", tvm.thread_axis(\"blockIdx.y\")"; break; - case kThreadX: - if (bind_reduce_iter) { - ss << ", thread_x"; - } else { - ss << ", tvm.thread_axis(\"threadIdx.x\")"; - } - break; - case kThreadY: ss << ", tvm.thread_axis(\"threadIdx.y\")"; break; - default: break; - } - ss << ")\n"; - - if (bind_reduce_iter) { - ss << "s[" << CleanName(stage->op->name) << "]" - << ".set_store_predicate(thread_x.var.equal(0))\n"; - } - - ApplyToSchedule(stages, stage_to_axes); - return ss.str(); -} - -/********** Compute At **********/ -ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id) { - auto node = make_object(); - node->stage_id = stage_id; - node->target_stage_id = target_stage_id; - node->target_iter_id = target_iter_id; - data_ = std::move(node); -} - -void ComputeAtStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { - te::Stage& stage = (*stages)[stage_id]; - const IterVar& target_axis = - (*stage_to_axes)[(*stages)[target_stage_id]][target_iter_id]; - stage.compute_at((*stages)[target_stage_id], target_axis); -} - -std::string ComputeAtStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - const auto& stage = (*stages)[stage_id]; - const auto& target_stage = (*stages)[target_stage_id]; - - ss << "s[" << CleanName(stage->op->name) << "].compute_at(s[" - << CleanName(target_stage->op->name) << "], " - << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint); - - ss << ")\n"; - ApplyToSchedule(stages, stage_to_axes); - return ss.str(); -} - -/********** Compute Root **********/ -ComputeRootStep::ComputeRootStep(int stage_id) { - auto node = make_object(); - node->stage_id = stage_id; - data_ = std::move(node); -} - -void ComputeRootStepNode::ApplyToSchedule(std::vector 
*stages, - StageToAxesMap *stage_to_axes) const { - (*stages)[stage_id].compute_root(); -} - -std::string ComputeRootStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - const auto& stage = (*stages)[stage_id]; - - ss << "s[" << CleanName(stage->op->name) << "].compute_root()\n"; - ApplyToSchedule(stages, stage_to_axes); - - return ss.str(); -} - -/********** Compute Inline **********/ -ComputeInlineStep::ComputeInlineStep(int stage_id) { - auto node = make_object(); - node->stage_id = stage_id; - data_ = std::move(node); -} - -void ComputeInlineStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { - (*stages)[stage_id].compute_inline(); -} - -std::string ComputeInlineStepNode::PrintAsPythonAPI( - std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - const auto& stage = (*stages)[stage_id]; - - ss << "s[" << CleanName(stage->op->name) << "].compute_inline()\n"; - ApplyToSchedule(stages, stage_to_axes); - - return ss.str(); -} - -/********** Cache Read **********/ -CacheReadStep::CacheReadStep(int stage_id, std::string scope_name, - const std::vector& reader_stage_ids) { - auto node = make_object(); - node->stage_id = stage_id; - node->scope_name = std::move(scope_name); - node->reader_stage_ids = reader_stage_ids; - data_ = std::move(node); -} - -te::Tensor CacheReadStepNode::ApplyToSchedule(std::vector* stages, - StageToAxesMap *stage_to_axes, te::Schedule *schedule) const { - te::Stage& stage = (*stages)[stage_id]; - - Array readers; - for (const auto& i : reader_stage_ids) { - readers.push_back((*stages)[i]->origin_op); - } - auto out = schedule->cache_read(stage->origin_op.output(0), scope_name, readers); - - const auto& new_stage = (*schedule)[out->op]; - UpdateStageAxis(new_stage, stage_to_axes); - stages->insert(stages->begin() + stage_id + 1, new_stage); - - return out; -} - -std::string CacheReadStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - // copy stage here, for the original stage will change after apply - auto stage = (*stages)[stage_id]; - std::vector reader_stages; - for (size_t i = 0; i < reader_stage_ids.size(); ++i) { - reader_stages.push_back((*stages)[reader_stage_ids[i]]); - } - - auto out = ApplyToSchedule(stages, stage_to_axes, schedule); - - ss << CleanName(out->op->name) << " = " - << "s.cache_read(" << CleanName(stage->op->name) << ", \"" - << scope_name << "\", [" - << CleanName(reader_stages[0]->op->name); - for (size_t i = 1; i < reader_stage_ids.size(); ++i) { - ss << ", " << CleanName(reader_stages[i]->op->name); - } - ss << "])\n"; - - const auto& iters = out->op->root_iter_vars(); - for (size_t i = 0; i < iters.size(); ++i) { - ss << CleanName(iters[i]->var->name_hint); - if (i != iters.size() - 1) { - ss << ", "; - } - } - ss << " = " << "tuple(" << CleanName(out->op->name) - << ".op.axis)\n"; - - return ss.str(); -} - -/********** Cache Write **********/ -CacheWriteStep::CacheWriteStep(int stage_id, std::string scope_name) { - auto node = make_object(); - node->stage_id = stage_id; - node->scope_name = std::move(scope_name); - data_ = std::move(node); -} - -Array CacheWriteStepNode::ApplyToSchedule( - std::vector *stages, StageToAxesMap *stage_to_axes, - te::Schedule 
*schedule) const { - te::Stage& stage = (*stages)[stage_id]; - - Array tensor_array; - // If the target stage has multi outputs, TVM requires to cache_write - // all of them or schedule.cache_write will raise an error - for (auto i = 0; i < stage->op->num_outputs(); ++i) { - tensor_array.push_back(stage->origin_op.output(i)); - } - auto outs = schedule->cache_write(tensor_array, scope_name); - - UpdateStageAxis(stage, stage_to_axes); - // Even if there is multi outputs, TVM schedule only generate one - // new stage - const auto& new_stage = (*schedule)[outs[0]->op]; - UpdateStageAxis(new_stage, stage_to_axes); - stages->insert(stages->begin() + stage_id, new_stage); - - return outs; -} - -std::string CacheWriteStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - // copy stage here, for the original stage will change after apply - te::Stage stage = (*stages)[stage_id]; - - auto outs = ApplyToSchedule(stages, stage_to_axes, schedule); - - for (size_t i = 0; i < outs.size(); ++i) { - ss << CleanName(outs[i]->op->name) << ", "; - } - ss << "= " << "s.cache_write([" - << CleanName(stage->op.output(0)->op->name); - for (auto i = 1; i < stage->op->num_outputs(); ++i) { - ss << ", " << CleanName(stage->op.output(i)->op->name); - } - ss << "], \"" << scope_name << "\")\n"; - - for (const auto& out : outs) { - const auto& iters = out->op->root_iter_vars(); - for (size_t i = 0; i < iters.size(); ++i) { - ss << CleanName(iters[i]->var->name_hint); - if (i != iters.size() - 1) { - ss << ", "; - } - } - ss << " = " << "tuple(" << CleanName(out->op->name) - << ".op.axis)" - << " + " << "tuple(" << CleanName(out->op->name) - << ".op.reduce_axis)\n"; - } - - return ss.str(); -} - -/********** Pragma **********/ -PragmaStep::PragmaStep(int stage_id, int iter_id, std::string pragma_type) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->pragma_type = std::move(pragma_type); - data_ = std::move(node); -} - -void PragmaStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { - te::Stage& stage = (*stages)[stage_id]; - const std::vector& axes = (*stage_to_axes)[stage]; - if (StrStartsWith(pragma_type, "auto_unroll_max_step")) { - size_t pos = pragma_type.find('$'); - int value = atoi(pragma_type.c_str() + pos + 1); - stage.pragma(axes[iter_id], "auto_unroll_max_step", value); - stage.pragma(axes[iter_id], "unroll_explicit", true); - } else { - stage.pragma(axes[iter_id], pragma_type); - } -} - -std::string PragmaStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - const auto& stage = (*stages)[stage_id]; - - if (StrStartsWith(pragma_type, "auto_unroll_max_step")) { - size_t pos = pragma_type.find('$'); - int value = atoi(pragma_type.c_str() + pos + 1); - ss << "s[" << CleanName(stage->op->name) << "].pragma(" - << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) - << ", \"auto_unroll_max_step\", " << value << ")\n"; - ss << "s[" << CleanName(stage->op->name) << "].pragma(" - << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) - << ", \"unroll_explicit\", True)\n"; - } else { - ss << "s[" << CleanName(stage->op->name) << "].pragma(" - << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", \"" - << pragma_type << "\")\n"; - } - - ApplyToSchedule(stages, 
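// Likewise, the cache_write printer covers every output of the stage in one call and
// unpacks both spatial and reduction axes of each new tensor; a sketch for a
// single-output matmul-like stage C (names illustrative):
//
//   C_global, = s.cache_write([C], "global")
//   i, j, k = tuple(C_global.op.axis) + tuple(C_global.op.reduce_axis)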
stage_to_axes); - return ss.str(); -} - -/********** Rfactor **********/ -RfactorStep::RfactorStep(int stage_id, int iter_id, int factor_iter_id) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->factor_iter_id = factor_iter_id; - data_ = std::move(node); -} - -Array RfactorStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, te::Schedule *schedule) const { - const auto& stage = (*stages)[stage_id]; - const std::vector& axes = (*stage_to_axes)[stage]; - - const te::Tensor& tensor = stage->origin_op.output(0); - const IterVar& axis = axes[iter_id]; - auto outs = schedule->rfactor(tensor, axis, factor_iter_id); - - UpdateStageAxis(stage, stage_to_axes); - - const auto& new_stage = (*schedule)[outs[0]->op]; - UpdateStageAxis(new_stage, stage_to_axes); - stages->insert(stages->begin() + stage_id, new_stage); - - return outs; -} - -std::string RfactorStepNode::PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const { - std::stringstream ss; - const auto& stage = (*stages)[stage_id]; - - const auto& tensor_name = CleanName(stage->origin_op.output(0)->op->name); - const auto& axis_name = CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint); - - const auto& outs = ApplyToSchedule(stages, stage_to_axes, schedule); - - for (size_t i = 0; i < outs.size(); ++i) { - ss << CleanName(outs[i]->op->name); - if (i != outs.size() - 1) { - ss << ", "; - } - } - ss << " = " << "s.rfactor(" - << tensor_name << ", " - << axis_name << ", " - << factor_iter_id << ")\n"; - - for (const auto& out : outs) { - const auto& iters = out->op->root_iter_vars(); - for (size_t i = 0; i < iters.size(); ++i) { - ss << CleanName(iters[i]->var->name_hint); - if (i != iters.size() - 1) { - ss << ", "; - } - } - ss << " = " << "tuple(" << CleanName(out->op->name) - << ".op.axis)" - << " + " << "tuple(" << CleanName(out->op->name) - << ".op.reduce_axis)\n"; - } - - const auto& output = (*stages)[stage_id + 1]->op.output(0); - const auto& iters = output->op->root_iter_vars(); - for (size_t i = 0; i < iters.size(); ++i) { - ss << CleanName(iters[i]->var->name_hint); - if (i != iters.size() - 1) { - ss << ", "; - } - } - ss << " = " << "tuple(s[" << CleanName(output->op->name) - << "].op.axis)" - << " + " << "tuple(s[" << CleanName(output->op->name) - << "].op.reduce_axis)\n"; - - return ss.str(); -} - -/********** Storage Align **********/ -StorageAlignStep::StorageAlignStep(int stage_id, int iter_id, - int factor, int offset) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->factor = factor; - node->offset = offset; - data_ = std::move(node); -} - -void StorageAlignStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { - te::Stage& stage = (*stages)[stage_id]; - const std::vector& axes = (*stage_to_axes)[stage]; - stage.storage_align(axes[iter_id], factor, offset); -} - -std::string StorageAlignStepNode::PrintAsPythonAPI( - std::vector *stages, StageToAxesMap *stage_to_axes, - te::Schedule *schedule, const std::vector& transform_steps) const { - std::stringstream ss; - const auto& stage = (*stages)[stage_id]; - ss << "s[" << CleanName(stage->op->name) << "].storage_align(" - << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", " - << factor << ", " << offset << ")\n"; - - ApplyToSchedule(stages, stage_to_axes); - return ss.str(); -} - -/********** Tensorize **********/ 
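// The rfactor printer additionally re-derives the axis tuples of both the new stage
// and the original one; a sketch of its output together with a storage_align line,
// for a reduction stage C with reduction axis ko (factor_iter_id = 0; alignment
// values are arbitrary):
//
//   C_rf = s.rfactor(C, ko, 0)
//   i, k_o, k_i = tuple(C_rf.op.axis) + tuple(C_rf.op.reduce_axis)
//   i_c, k_c = tuple(s[C].op.axis) + tuple(s[C].op.reduce_axis)
//   s[C_rf].storage_align(i, 16, 1)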
-TensorizeStep::TensorizeStep(int stage_id, int iter_id, - std::string ti_func_name) { - auto node = make_object(); - node->stage_id = stage_id; - node->iter_id = iter_id; - node->ti_func_name = ti_func_name; - data_ = std::move(node); -} - -void TensorizeStepNode::ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const { - te::Stage& stage = (*stages)[stage_id]; - const std::vector& axes = (*stage_to_axes)[stage]; - auto func = tvm::runtime::Registry::Get(ti_func_name); - CHECK(func != nullptr) << "Cannot find the tensorize intrinsic func"; - tvm::te::TensorIntrin res = (*func)(); - CHECK(res.defined()) << "Tensorize intrinsic func must return a " - << "tvm::te::TensorIntrin object"; - stage.tensorize(axes[iter_id], res); -} - -std::string TensorizeStepNode::PrintAsPythonAPI( - std::vector *stages, StageToAxesMap *stage_to_axes, - te::Schedule *schedule, const std::vector& transform_steps) const { - std::stringstream ss; - const auto& stage = (*stages)[stage_id]; - ss << "s[" << CleanName(stage->op->name) << "].tensorize(" - << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", " - << ti_func_name << "())\n"; - - ApplyToSchedule(stages, stage_to_axes); - return ss.str(); -} - } // namespace ansor } // namespace tvm diff --git a/src/ansor/transform_step.h b/src/ansor/transform_step.h index f8283b876f18..8eff6a4e7536 100644 --- a/src/ansor/transform_step.h +++ b/src/ansor/transform_step.h @@ -114,80 +114,6 @@ class SplitStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode); }; -/*! \brief Similar to SplitStepNode, but use split factor from another step - * (i.e. Follow another split step) */ -class FollowSplitStepNode: public StepNode { - public: - int iter_id; // The id of the iter to split - int src_step_id; // The index of the split step to follow in the history - int n_split; // The number of split level - - void ExtractSplitLengths(const std::vector& transform_steps, - std::vector* lengths) const; - - std::vector ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, - const std::vector& transform_steps) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.FollowSplitStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object); -}; - -/*! - * \brief Managed reference to FollowSplitStepNode. - * \sa FollowSplitStepNode - */ -class FollowSplitStep : public Step { - public: - FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split); - - TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode); -}; - - -/*! \brief Similar to FollowSplitStep, but use split factors from multiple steps. - * \Note This can be used for the split in cooperative fetching - */ -class FollowFusedSplitStepNode: public StepNode { - public: - int iter_id; // The id of the iter to split - std::vector src_step_ids; // The indices of the split steps to follow in the history - int level; // Use the length in this split level - bool factor_or_nparts; // If this is true, use factor. 
Otherwise, use nparts - - PrimExpr ExtractSplitLength(const std::vector& transform_steps) const; - - std::vector ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, - const std::vector& transform_steps) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.FollowFusedSplitStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object); -}; - -/*! - * \brief Managed reference to FollowFusedSplitStepNode. - * \sa FollowFusedSplitStepNode - */ -class FollowFusedSplitStep : public Step { - public: - FollowFusedSplitStep(int stage_id, int iter_id, - const std::vector& src_step_ids, - int level, bool factor_or_nparts); - - TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode); -}; - /*! \brief Fuse step that corresponds to te::Stage::fuse */ class FuseStepNode: public StepNode { public: @@ -216,298 +142,6 @@ class FuseStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode); }; -/*! \brief Annotation step that corresponds to vectorize, parallel, unroll and thread binding. - * (i.e. te::Stage::vectorize, te::Stage::parallel, te::Stage::vectorize, te::Stage::bind) - */ -class AnnotationStepNode: public StepNode { - public: - int iter_id; - IteratorAnnotation annotation; - - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.AnnotationStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, Object); -}; - -/*! - * \brief Managed reference to AnnotationStepNode. - * \sa AnnotationStepNode - */ -class AnnotationStep : public Step { - public: - AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann); - - TVM_DEFINE_OBJECT_REF_METHODS(AnnotationStep, Step, AnnotationStepNode); -}; - -/*! \brief Compute at step that corresponds to te::Stage::compute_at */ -class ComputeAtStepNode: public StepNode { - public: - int target_stage_id; - int target_iter_id; - - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.ComputeAtStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, Object); -}; - -/*! - * \brief Managed reference to ComputeAtStepNode. - * \sa ComputeAtStepNode - */ -class ComputeAtStep : public Step { - public: - ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id); - - TVM_DEFINE_OBJECT_REF_METHODS(ComputeAtStep, Step, ComputeAtStepNode); -}; - -/*! \brief Compute root step that corresponds to te::Stage::compute_root */ -class ComputeRootStepNode: public StepNode { - public: - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.ComputeRootStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, Object); -}; - -/*! - * \brief Managed reference to ComputeRootStepNode. 
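// The FollowSplitStep removed above existed so that a later stage could replay the
// split factors recorded by an earlier SplitStep in the transform history, keeping
// two stages tiled identically. Conceptually it reduces to issuing a second split
// with the same lengths (state-API sketch; names illustrative):
//
//   its0 = s0.split(C, s0[C].iters[0], [4, 8])        # source split, kept in history
//   its1 = s0.split(C_rf, s0[C_rf].iters[0], [4, 8])  # follow-split replays [4, 8]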
- * \sa ComputeRootStepNode - */ -class ComputeRootStep : public Step { - public: - explicit ComputeRootStep(int stage_id); - - TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode); -}; - -/*! \brief Compute inline step that corresponds to te::Stage::compute_inline */ -class ComputeInlineStepNode: public StepNode { - public: - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.ComputeInlineStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, Object); -}; - -/*! - * \brief Managed reference to ComputeInlineStepNode. - * \sa ComputeInlineStepNode - */ -class ComputeInlineStep : public Step { - public: - explicit ComputeInlineStep(int stage_id); - - TVM_DEFINE_OBJECT_REF_METHODS(ComputeInlineStep, Step, ComputeInlineStepNode); -}; - -/*! \brief Cache read step that corresponds to te::Schedule::cache_read */ -class CacheReadStepNode: public StepNode { - public: - std::string scope_name; - std::vector reader_stage_ids; - - te::Tensor ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.CacheReadStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, Object); -}; - -/*! - * \brief Managed reference to CacheReadStepNode. - * \sa CacheReadStepNode - */ -class CacheReadStep : public Step { - public: - CacheReadStep(int stage_id, std::string scope_name, - const std::vector& reader_stage_id); - - TVM_DEFINE_OBJECT_REF_METHODS(CacheReadStep, Step, CacheReadStepNode); -}; - -/*! \brief Cache write step that corresponds to te::Schedule::cache_write - * \Note This step will cache_write all output tensors of target stage */ -class CacheWriteStepNode: public StepNode { - public: - std::string scope_name; - - Array ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.CacheWriteStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, Object); -}; - -/*! - * \brief Managed reference to CacheWriteStepNode. - * \sa CacheWriteStepNode - */ -class CacheWriteStep : public Step { - public: - CacheWriteStep(int stage_id, std::string scope_name); - - TVM_DEFINE_OBJECT_REF_METHODS(CacheWriteStep, Step, CacheWriteStepNode); -}; - -/*! \brief Pragma step that corresponds to te::Schedule::pragma */ -class PragmaStepNode: public StepNode { - public: - int iter_id; - std::string pragma_type; - - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.PragmaStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, Object); -}; - -/*! - * \brief Managed reference to PragmaStepNode. 
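// PragmaStepNode packs the unroll limit into pragma_type after a '$', so a step
// carrying "auto_unroll_max_step$16" prints (and applies) two pragmas; the stage and
// axis names below are illustrative:
//
//   s[C].pragma(xo, "auto_unroll_max_step", 16)
//   s[C].pragma(xo, "unroll_explicit", True)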
- * \sa PragmaStepNode - */ -class PragmaStep : public Step { - public: - PragmaStep(int stage_id, int iter_id, std::string pragma_type); - - TVM_DEFINE_OBJECT_REF_METHODS(PragmaStep, Step, PragmaStepNode); -}; - -/*! \brief Reduction factor step that corresponds to te::Schedule::rfactor */ -class RfactorStepNode: public StepNode { - public: - int iter_id; - int factor_iter_id; - - Array ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.RfactorStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, Object); -}; - -/*! - * \brief Managed reference to RfactorStepNode. - * \sa RfactorStepNode - */ -class RfactorStep : public Step { - public: - RfactorStep(int stage_id, int iter_id, int factor_iter_id); - - TVM_DEFINE_OBJECT_REF_METHODS(RfactorStep, Step, RfactorStepNode); -}; - -/*! \brief Storage align step that corresponds to te::Schedule::storage_align */ -class StorageAlignStepNode: public StepNode { - public: - int iter_id; - int factor; - int offset; - - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.StorageAlignStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, Object); -}; - -/*! - * \brief Managed reference to StorageAlignStepNode. - * \sa StorageAlignStepNode - */ -class StorageAlignStep : public Step { - public: - StorageAlignStep(int stage_id, int iter_id, int factor, int offset); - - TVM_DEFINE_OBJECT_REF_METHODS(StorageAlignStep, Step, StorageAlignStepNode); -}; - -/*! \brief Tensorize step that corresponds to te::Schedule::tensorize - * \Note This step takes a global registered function name as input. */ -class TensorizeStepNode: public StepNode { - public: - int iter_id; - std::string ti_func_name; - - void ApplyToSchedule(std::vector *stages, - StageToAxesMap *stage_to_axes) const; - - std::string PrintAsPythonAPI(std::vector *stages, - StageToAxesMap *stage_to_axes, - te::Schedule *schedule, - const std::vector& transform_steps) const final; - - static constexpr const char* _type_key = "ansor.TensorizeStep"; - TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeStepNode, Object); -}; - -/*! - * \brief Managed reference to TensorizeStepNode. 
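// TensorizeStep resolves ti_func_name through the global packed-function registry and
// requires the call to return a te::TensorIntrin. A hedged Python-side sketch of
// providing one ("my_gemm_intrin" is a hypothetical name):
//
//   @tvm._ffi.register_func("my_gemm_intrin")
//   def my_gemm_intrin():
//       # build and return a tvm.te.TensorIntrin,
//       # e.g. via te.decl_tensor_intrin(...)
//       ...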
- * \sa TensorizeStepNode - */ -class TensorizeStep : public Step { - public: - TensorizeStep(int stage_id, int iter_id, std::string ti_func_name); - - TVM_DEFINE_OBJECT_REF_METHODS(TensorizeStep, Step, TensorizeStepNode); -}; - } // namespace ansor } // namespace tvm @@ -536,69 +170,10 @@ struct hash<::tvm::ansor::Step> { } } return ret; - } else if (auto ps = step.as<::tvm::ansor::FollowSplitStepNode>()) { - return ::dmlc::HashCombine(3, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - ::dmlc::HashCombine(std::hash()(ps->src_step_id), - ps->n_split)))); - } else if (auto ps = step.as<::tvm::ansor::FollowFusedSplitStepNode>()) { - return ::dmlc::HashCombine(4, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - ::dmlc::HashCombine(std::hash>()(ps->src_step_ids), - ::dmlc::HashCombine(std::hash()(ps->level), - ps->factor_or_nparts))))); } else if (auto ps = step.as<::tvm::ansor::FuseStepNode>()) { - return ::dmlc::HashCombine(5, + return ::dmlc::HashCombine(3, ::dmlc::HashCombine(std::hash()(ps->stage_id), ps->fused_ids)); - } else if (auto ps = step.as<::tvm::ansor::AnnotationStepNode>()) { - return ::dmlc::HashCombine(6, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - static_cast(ps->annotation)))); - } else if (auto ps = step.as<::tvm::ansor::ComputeAtStepNode>()) { - return ::dmlc::HashCombine(7, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->target_stage_id), - ps->target_iter_id))); - } else if (auto ps = step.as<::tvm::ansor::ComputeRootStepNode>()) { - return ::dmlc::HashCombine(8, - ps->stage_id); - } else if (auto ps = step.as<::tvm::ansor::ComputeInlineStepNode>()) { - return ::dmlc::HashCombine(9, - ps->stage_id); - } else if (auto ps = step.as<::tvm::ansor::CacheReadStepNode>()) { - return ::dmlc::HashCombine(10, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->scope_name), - ps->reader_stage_ids))); - } else if (auto ps = step.as<::tvm::ansor::CacheWriteStepNode>()) { - return ::dmlc::HashCombine(11, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ps->scope_name)); - } else if (auto ps = step.as<::tvm::ansor::PragmaStepNode>()) { - return ::dmlc::HashCombine(12, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - ps->pragma_type))); - } else if (auto ps = step.as<::tvm::ansor::RfactorStepNode>()) { - return ::dmlc::HashCombine(13, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - ps->factor_iter_id))); - } else if (auto ps = step.as<::tvm::ansor::StorageAlignStepNode>()) { - return ::dmlc::HashCombine(14, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - ::dmlc::HashCombine(std::hash()(ps->factor), - ps->offset)))); - } else if (auto ps = step.as<::tvm::ansor::TensorizeStepNode>()) { - return ::dmlc::HashCombine(15, - ::dmlc::HashCombine(std::hash()(ps->stage_id), - ::dmlc::HashCombine(std::hash()(ps->iter_id), - ps->ti_func_name))); } else { LOG(FATAL) << "Invalid step"; } diff --git a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py index 485679d6aa4e..62ebeb99a6c8 100644 --- a/tests/python/unittest/test_ansor_common.py +++ b/tests/python/unittest/test_ansor_common.py @@ -60,15 +60,10 @@ def get_tiled_matmul(): dag = ansor.ComputeDAG([A, B, C]) s0 = 
dag.get_init_state() - C_global = s0.cache_write(C, "global") its0 = s0.split(C, s0[C].iters[0], [4, 8, 8]) its1 = s0.split(C, s0[C].iters[4], [8, 4, 4]) - s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], its1[3]]) - s0.compute_at(C_global, C, s0[C].iters[3]) - s0.split(C_global, s0[C_global].iters[2], [16]) - B_global = s0.cache_read(B, "global", [C_global]) - s0.compute_at(B_global, C_global, s0[C_global].iters[0]) - A_global = s0.cache_read(A, "global", [C_global]) - s0.compute_at(A_global, C_global, s0[C_global].iters[2]) + s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], its1[3], + s0[C].iters[8]]) + return dag, s0 diff --git a/tests/python/unittest/test_ansor_compute_dag.py b/tests/python/unittest/test_ansor_compute_dag.py index 0768f82b805a..e5af07b31e0d 100644 --- a/tests/python/unittest/test_ansor_compute_dag.py +++ b/tests/python/unittest/test_ansor_compute_dag.py @@ -34,15 +34,6 @@ def test_infer_bound(): dag, s = get_tiled_matmul() s = dag.infer_bound_from_state(s) - A_global = s.stage_ops[1] - B_global = s.stage_ops[3] - C_global = s.stage_ops[4] - assert s[B_global].iters[0].range.extent == 512 - assert s[B_global].iters[1].range.extent == 16 - assert s[A_global].iters[0].range.extent == 1 - assert s[A_global].iters[1].range.extent == 16 - assert s[C_global].iters[0].range.extent == 64 - def test_estimate_flop(): dag, s = get_tiled_matmul() @@ -50,25 +41,7 @@ def test_estimate_flop(): assert abs(dag.flop_ct - 2 * 512 ** 3) < 0.5 -def test_lower_legalize_invalid_attach(): - N, M = 10, 10 - - A = te.compute((N, M), lambda i, j: 1.0, name='A') - B = te.compute((N, M), lambda i, j: A[i][j], name='B') - - dag = ansor.ComputeDAG([A, B]) - s = dag.get_init_state() - - s.compute_at(A, B, s[B].iters[1]) - s.split(B, s[B].iters[1], [2]) - - sch, tensors = dag.apply_steps_from_state(s) - stmt = tvm.lower(sch, tensors, simple_mode=True) - - if __name__ == "__main__": test_apply_steps() test_infer_bound() test_estimate_flop() - test_lower_legalize_invalid_attach() - From a8e589e88cbc0c78832ecc44efe8514bb1b65d5f Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 24 Jun 2020 16:34:24 +0800 Subject: [PATCH 44/45] UT ready --- python/tvm/ansor/__init__.py | 5 +- python/tvm/ansor/auto_schedule.py | 13 ++- python/tvm/ansor/cost_model/__init__.py | 20 ---- python/tvm/ansor/cost_model/cost_model.py | 46 --------- python/tvm/ansor/measure.py | 98 ------------------- src/ansor/measure.cc | 42 -------- src/ansor/measure.h | 36 ------- src/ansor/search_policy/empty_policy.cc | 90 +++++++++++++++++ src/ansor/search_policy/empty_policy.h | 79 +++++++++++++++ src/ansor/search_policy/search_policy.cc | 60 +----------- src/ansor/search_policy/search_policy.h | 38 ------- .../unittest/test_ansor_search_policy.py | 5 +- 12 files changed, 186 insertions(+), 346 deletions(-) delete mode 100644 python/tvm/ansor/cost_model/__init__.py delete mode 100644 python/tvm/ansor/cost_model/cost_model.py create mode 100644 src/ansor/search_policy/empty_policy.cc create mode 100644 src/ansor/search_policy/empty_policy.h diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index ccd8f27b71c1..93a82f073ac3 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -27,9 +27,8 @@ # Shortcut from .compute_dag import ComputeDAG from .auto_schedule import SearchTask, TuneOption, HardwareParams, \ - auto_schedule -from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, LocalRPCMeasureContext -from .cost_model 
import RandomModel + auto_schedule, EmptyPolicy +from .measure import MeasureInput, LocalBuilder, LocalRunner from .serialization import LogToFile, LogReader, best_measure_pair_in_file, \ load_from_file, write_measure_records_to_file from .workload_registry import register_workload_func, \ diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index 37e622018658..33f6dbcba7ff 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -22,7 +22,6 @@ import tvm._ffi from tvm.runtime import Object from .measure import LocalBuilder, LocalRunner -from .cost_model import RandomModel from . import _ffi_api @@ -82,10 +81,20 @@ def set_verbose(self, verbose): def run_callbacks(self, callbacks): _ffi_api.SearchPolicyRunCallbacks(self, callbacks) + +@tvm._ffi.register_object("ansor.EmptyPolicy") +class EmptyPolicy(SearchPolicy): + """ The example search policy + """ + def __init__(self): + self.__init_handle_by_constructor__(_ffi_api.EmptyPolicy) + + @tvm._ffi.register_object("ansor.SearchCallback") class SearchCallback(Object): """Callback function before or after search process""" + @tvm._ffi.register_object("ansor.TuneOption") class TuneOption(Object): """ The options for tuning @@ -164,7 +173,7 @@ def auto_schedule(workload, target=None, """ if isinstance(search_policy, str): if search_policy == 'default': - search_policy = SketchSearchPolicy(RandomModel()) + search_policy = EmptyPolicy() else: raise ValueError("Invalid search policy: " + search_policy) diff --git a/python/tvm/ansor/cost_model/__init__.py b/python/tvm/ansor/cost_model/__init__.py deleted file mode 100644 index 1454da451b61..000000000000 --- a/python/tvm/ansor/cost_model/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-import, redefined-builtin -""" Cost model that estimates the performance of programs """ - -from .cost_model import RandomModel \ No newline at end of file diff --git a/python/tvm/ansor/cost_model/cost_model.py b/python/tvm/ansor/cost_model/cost_model.py deleted file mode 100644 index 605db14c19c3..000000000000 --- a/python/tvm/ansor/cost_model/cost_model.py +++ /dev/null @@ -1,46 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
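# With the cost-model classes removed, the string 'default' in auto_schedule now
# resolves to EmptyPolicy, so the two calls below are equivalent (a sketch; workload
# and target are assumed to be prepared as in auto_schedule's signature above, and
# search_policy is assumed to default to 'default' as the dispatch code suggests):
#
#   auto_schedule(workload, target)  # 'default' -> EmptyPolicy()
#   auto_schedule(workload, target, search_policy=EmptyPolicy())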
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" Cost model that estimates the performance of programs """ -import ctypes -import numpy as np - -import tvm._ffi -from tvm.runtime import Object -from .. import _ffi_api - - -@tvm._ffi.register_object("ansor.CostModel") -class CostModel(Object): - """The base class for cost model""" - - -@tvm._ffi.register_object("ansor.RandomModel") -class RandomModel(Object): - """A model returns random estimation for all inputs""" - def __init__(self): - self.__init_handle_by_constructor__(_ffi_api.RandomModel) - - -@tvm._ffi.register_func("ansor.cost_model.random_number") -def random_number(n, return_ptr): - """ A random number generator func for c++'s RandomModel """ - if n == 0: - return - return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_float)) - array_wrapper = np.ctypeslib.as_array(return_ptr, shape=(n,)) - array_wrapper[:] = np.random.uniform(0, 1, (n,)) diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index c9a5ef013cc7..b8b02e2b85df 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -178,104 +178,6 @@ def __init__(self, self.__init_handle_by_constructor__( _ffi_api.LocalRunner, timeout, number, repeat, min_repeat_ms, cooldown_interval) -@tvm._ffi.register_object("ansor.ProgramMeasurer") -class ProgramMeasurer(Object): - """ - Parameters - ---------- - builder : Builder - runner : Runner - callbacks : List[MeasureCallback] - verbose : Int - max_continuous_error : Float - """ - - def __init__(self, builder: Builder, runner: Runner, - callbacks: List[MeasureCallback], - verbose: int, max_continuous_error: int = -1): - self.__init_handle_by_constructor__( - _ffi_api.ProgramMeasurer, builder, runner, callbacks, verbose, max_continuous_error) - -@tvm._ffi.register_object("ansor.RPCRunner") -class RPCRunner(Runner): - """ - Parameters - ---------- - key : Str - host : Str - port : Int - priority : Int - n_parallel : Int - timeout : Int - number : Int - repeat : Int - min_repeat_ms : Int - cooldown_interval : Float - """ - - def __init__(self, key, host, port, priority=1, - n_parallel=1, - timeout=10, - number=3, - repeat=1, - min_repeat_ms=0, - cooldown_interval=0.0): - self.__init_handle_by_constructor__( - _ffi_api.RPCRunner, key, host, port, priority, timeout, n_parallel, - number, repeat, min_repeat_ms, cooldown_interval) - - if check_remote(key, host, port, priority, timeout): - LOGGER.info("Get devices for measurement successfully!") - else: - raise RuntimeError("Cannot get remote devices from the tracker. " - "Please check the status of tracker by " - "'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' " - "and make sure you have free devices on the queue status.") - - -class LocalRPCMeasureContext: - """ A context wrapper for running RPCRunner locally. - This will launch a local RPC Tracker and local RPC Server. 
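# With the RPC paths deleted here, measurement runs through the local pipeline only.
# A minimal sketch wiring the constructors kept in this file into a TuneOption
# (argument values arbitrary; log_file is a placeholder):
#
#   runner = LocalRunner(timeout=10, number=3, repeat=1, min_repeat_ms=0)
#   tune_option = TuneOption(n_trials=2, runner=runner,
#                            measure_callbacks=[LogToFile(log_file)])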
- - Parameters - ---------- - priority : Int - n_parallel : Int - timeout : Int - number : Int - repeat : Int - min_repeat_ms : Int - cooldown_interval : Float - """ - - def __init__(self, - priority=1, - n_parallel=1, - timeout=10, - number=10, - repeat=1, - min_repeat_ms=0, - cooldown_interval=0.0): - ctx = tvm.context("cuda", 0) - if ctx.exist: - cuda_arch = "sm_" + "".join(ctx.compute_version.split('.')) - set_cuda_target_arch(cuda_arch) - host = '0.0.0.0' - self.tracker = Tracker(host, port=9000, port_end=10000, silent=True) - device_key = '$local$device$%d' % self.tracker.port - self.server = Server(host, port=self.tracker.port, port_end=10000, - key=device_key, use_popen=True, silent=True, - tracker_addr=(self.tracker.host, self.tracker.port)) - self.runner = RPCRunner(device_key, host, self.tracker.port, priority, - n_parallel, timeout, number, repeat, - min_repeat_ms, cooldown_interval) - # wait for the processes to start - time.sleep(0.5) - - def __del__(self): - self.server.terminate() - self.tracker.terminate() - class MeasureErrorNo(object): """Error type for MeasureResult""" diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index e99f41725077..c50191813b2e 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -40,9 +40,7 @@ TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode); TVM_REGISTER_OBJECT_TYPE(RunnerNode); TVM_REGISTER_OBJECT_TYPE(BuilderNode); TVM_REGISTER_OBJECT_TYPE(LocalBuilderNode); -TVM_REGISTER_OBJECT_TYPE(RPCRunnerNode); TVM_REGISTER_OBJECT_TYPE(LocalRunnerNode); -TVM_REGISTER_OBJECT_TYPE(ProgramMeasurerNode); const char* ErrorNoToStr[] = { "NoError", @@ -127,38 +125,6 @@ Array LocalBuilderNode::Build(const Array& inputs, return Array(); } -// RPC Runner -RPCRunner::RPCRunner(const std::string& key, const std::string& host, int port, - int priority, int timeout, int n_parallel, int number, - int repeat, int min_repeat_ms, double cooldown_interval) { - auto node = make_object(); - node->key = key; - node->host = host; - node->port = port; - node->priority = priority; - node->timeout = timeout; - node->n_parallel = n_parallel; - node->number = number; - node->repeat = repeat; - node->min_repeat_ms = min_repeat_ms; - node->cooldown_interval = cooldown_interval; - data_ = std::move(node); -} - -Array RPCRunnerNode::Run(const Array& inputs, - const Array& build_results, - int verbose) { - if (const auto* f = runtime::Registry::Get("ansor.rpc_runner.run")) { - Array results = (*f)( - inputs, build_results, key, host, port, priority, timeout, n_parallel, - number, repeat, min_repeat_ms, cooldown_interval, verbose); - return results; - } else { - LOG(FATAL) << "ansor.rpc_runner.run is not registered"; - } - return Array(); -} - // Local Runner LocalRunner::LocalRunner(int timeout, int number, int repeat, int min_repeat_ms, double cooldown_interval) { @@ -379,14 +345,6 @@ TVM_REGISTER_GLOBAL("ansor.LocalRunner") return LocalRunner(timeout, number, repeat, min_repeat_ms, cooldown_interval); }); -TVM_REGISTER_GLOBAL("ansor.RPCRunner") -.set_body_typed([](const std::string& key, const std::string& host, int port, - int priority, int timeout, int n_parallel, int number, - int repeat, int min_repeat_ms, double cooldown_interval){ - return RPCRunner(key, host, port, priority, timeout, n_parallel, number, - repeat, min_repeat_ms, cooldown_interval); -}); - TVM_REGISTER_GLOBAL("ansor.ProgramMeasurer") .set_body_typed([](Builder builder, Runner runner, Array callbacks, int verbose, diff --git a/src/ansor/measure.h b/src/ansor/measure.h index 760a1542944f..630365512eb6 
100644 --- a/src/ansor/measure.h +++ b/src/ansor/measure.h @@ -219,42 +219,6 @@ class LocalBuilder: public Builder { TVM_DEFINE_OBJECT_REF_METHODS(LocalBuilder, Builder, LocalBuilderNode); }; -/*! \brief RPCRunner that uses RPC call to measures the time cost of programs - * on remote devices */ -class RPCRunnerNode : public RunnerNode { - public: - std::string key; - std::string host; - int port; - int priority; - int n_parallel; - int number; - int repeat; - int min_repeat_ms; - double cooldown_interval; - - /*! \biref Run measurement and return results */ - Array Run(const Array& inputs, - const Array& build_results, - int verbose) final; - - static constexpr const char* _type_key = "ansor.RPCRunner"; - TVM_DECLARE_FINAL_OBJECT_INFO(RPCRunnerNode, RunnerNode); -}; - -/*! - * \brief Managed reference to RPCRunnerNode. - * \sa RPCRunnerNode - */ -class RPCRunner : public Runner { - public: - RPCRunner(const std::string& key, const std::string& host, int port, - int priority, int timeout, int n_parallel, int number, - int repeat, int min_repeat_ms, double cooldown_interval); - - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RPCRunner, Runner, RPCRunnerNode); -}; - /*! \brief LocalRunner that uses local CPU/GPU to measures the time cost of programs */ class LocalRunnerNode: public RunnerNode { public: diff --git a/src/ansor/search_policy/empty_policy.cc b/src/ansor/search_policy/empty_policy.cc new file mode 100644 index 000000000000..16b22c27ef92 --- /dev/null +++ b/src/ansor/search_policy/empty_policy.cc @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+#include "empty_policy.h"
+
+#include
+
+namespace tvm {
+namespace ansor {
+
+TVM_REGISTER_NODE_TYPE(EmptyPolicyNode);
+
+State EmptyPolicyNode::Search(SearchTask task, int n_trials, int early_stopping,
+                              int num_measure_per_iter, int verbose, ProgramMeasurer measurer,
+                              Array pre_search_callbacks) {
+  cur_task = task;
+
+  // Run pre_search_callbacks before the search process
+  // This Interface is used to set some init status
+  RunCallbacks(pre_search_callbacks);
+
+  if (n_trials <= 1) {
+    const auto& res = SearchOneRound();
+    CHECK_GT(res.size(), 0);
+    return res[0];
+  } else {
+    std::vector inputs;
+    std::vector results;
+
+    measurer->Reset();
+    int ct = 0;
+    while (ct < n_trials) {
+      const auto& res = SearchOneRound();
+      ct += res.size();
+      inputs.clear();
+      for (const auto& state : res) {
+        inputs.emplace_back(cur_task, state);
+      }
+      measurer->Measure(cur_task, GetRef(this), inputs, &results);
+    }
+
+    return measurer->best_state[cur_task->workload_key];
+  }
+}
+
+std::pair, Array > EmptyPolicyNode::ContinueSearchOneRound(
+    SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) {
+  std::vector inputs;
+  std::vector results;
+
+  const auto& res = SearchOneRound();
+  for (const auto& state : res) {
+    inputs.emplace_back(cur_task, state);
+  }
+  measurer->Measure(cur_task, GetRef(this), inputs, &results);
+
+  Array inputs_arr(std::make_move_iterator(inputs.begin()),
+                   std::make_move_iterator(inputs.end()));
+  Array results_arr(std::make_move_iterator(results.begin()),
+                    std::make_move_iterator(results.end()));
+  return std::make_pair(std::move(inputs_arr), std::move(results_arr));
+}
+
+std::vector EmptyPolicyNode::SearchOneRound() {
+  std::vector res;
+  res.push_back(cur_task->compute_dag.GetInitState());
+  return res;
+}
+
+TVM_REGISTER_GLOBAL("ansor.EmptyPolicy")
+.set_body_typed([]() { return EmptyPolicy(make_object()); });
+
+}  // namespace ansor
+}  // namespace tvm
diff --git a/src/ansor/search_policy/empty_policy.h b/src/ansor/search_policy/empty_policy.h
new file mode 100644
index 000000000000..78a3838f862a
--- /dev/null
+++ b/src/ansor/search_policy/empty_policy.h
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/search_policy/empty_policy.h
+ * \brief This is a basic example of search policy
+ */
+
+#ifndef TVM_ANSOR_SEARCH_POLICY_EMPTY_POLICY_H_
+#define TVM_ANSOR_SEARCH_POLICY_EMPTY_POLICY_H_
+
+#include
+#include
+
+#include "search_policy.h"
+
+namespace tvm {
+namespace ansor {
+
+/*!
+ * \brief This is a basic example of search policy
+ */
+class EmptyPolicyNode : public SearchPolicyNode {
+ public:
+  /*! \brief Search and make n_trials measurements.
+ * \returns the best state + */ + State Search(SearchTask task, int n_trials, + int early_stopping, int num_measure_per_iter, + int verbose, ProgramMeasurer measurer, + Array pre_search_callbacks) final; + + /*! \brief Continue search for one round. This is used by JointTuner + * \returns the measurement pairs + */ + std::pair, Array > ContinueSearchOneRound( + SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) final; + + static constexpr const char *_type_key = "ansor.EmptyPolicy"; + TVM_DECLARE_FINAL_OBJECT_INFO(EmptyPolicyNode, SearchPolicyNode); + + private: + /*! + * \brief Usually we need a sub function to generate several candidate states in each + * search round. + * \returns Several generated states + */ + std::vector SearchOneRound(); +}; + +/*! + * \brief Managed reference to EmptyPolicyNode. + * \sa EmptyPolicyNode + */ +class EmptyPolicy : public SearchPolicy { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EmptyPolicy, SearchPolicy, EmptyPolicyNode); +}; + +} // namespace ansor +} // namespace tvm + +#endif // TVM_ANSOR_SEARCH_POLICY_EMPTY_POLICY_H_ \ No newline at end of file diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc index b86bf9490851..e7a12702ba70 100644 --- a/src/ansor/search_policy/search_policy.cc +++ b/src/ansor/search_policy/search_policy.cc @@ -29,51 +29,8 @@ namespace tvm { namespace ansor { +TVM_REGISTER_OBJECT_TYPE(SearchCallbackNode); TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode); -TVM_REGISTER_OBJECT_TYPE(PreloadMeasuredStatesNode); - -void SearchPolicyNode::PreloadMeasuredStates(const std::string& log_file) { - LogReader reader = LogReader(log_file); - const auto& res = reader->ReadLines(-1); - size_t log_size = res.first.size(); - CHECK_EQ(log_size, res.second.size()); - if (log_size) { - std::vector measured_states; - std::vector measured_throughputs; - for (size_t i = 0; i < log_size; i++) { - const auto& inp = res.first[i]; - if (inp->task->workload_key == cur_task->workload_key && - inp->task->target->target_name.compare( - cur_task->target->target_name) == 0) { - State state = cur_task->compute_dag.GetInitState(); - state.CopyOnWrite()->transform_steps = inp->state->transform_steps; - state.DoSteps(inp->state->transform_steps, cur_task->compute_dag); - measured_states.emplace_back(std::move(state)); - measured_throughputs.push_back(res.second[i]->error_no == 0 ? 
- (1.0 / FloatArrayMean(res.second[i]->costs)) : 0.0); - } - } - cur_task->compute_dag.InferBound(&measured_states); - for (size_t i = 0; i < measured_states.size(); i ++) { - auto& state = measured_states[i]; - const auto& state_str = state.ToStr(); - if (!measured_states_set_.count(state_str)) { - measured_states_set_.insert(state_str); - if (measured_throughputs[i] != 0.0) { - measured_states_vector_.emplace_back(std::move(state)); - measured_states_throughputs_.emplace_back(measured_throughputs[i]); - } - } - } - - StdCout(verbose) << "Successfully load " << measured_states_set_.size() - << " measurement records from " << log_file - << " for " << cur_task->workload_key << std::endl; - } else { - StdCout(verbose) << "No measurement records found in " - << log_file << " for " << cur_task->workload_key << std::endl; - } -} void SearchPolicyNode::RunCallbacks(const Array& callbacks) { if (callbacks.defined() && callbacks.size()) { @@ -83,16 +40,6 @@ void SearchPolicyNode::RunCallbacks(const Array& callbacks) { } } -PreloadMeasuredStates::PreloadMeasuredStates(std::string filename) { - auto node = make_object(); - node->filename = std::move(filename); - data_ = std::move(node); -} - -void PreloadMeasuredStatesNode::callback(SearchPolicyNode* policy) { - policy->PreloadMeasuredStates(filename); -} - // Search Policy TVM_REGISTER_GLOBAL("ansor.SearchPolicyContinueSearchOneRound") .set_body_typed([](SearchPolicy policy, SearchTask task, int num_measure, @@ -118,10 +65,5 @@ TVM_REGISTER_GLOBAL("ansor.SearchPolicySetVerbose") policy->verbose = verbose; }); -TVM_REGISTER_GLOBAL("ansor.PreloadMeasuredStates") -.set_body_typed([](std::string filename) { - return PreloadMeasuredStates(filename); -}); - } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h index 03e7c3f025df..eb4703be1914 100644 --- a/src/ansor/search_policy/search_policy.h +++ b/src/ansor/search_policy/search_policy.h @@ -48,30 +48,6 @@ class SearchCallbackNode : public Object { }; TVM_DEFINE_MUTABLE_OBJECT_REF(SearchCallback, SearchCallbackNode); -/*! \brief Preload measured states from a log file. - * This can resume the state of the search policy */ -class PreloadMeasuredStatesNode : public SearchCallbackNode { - public: - std::string filename; - - void callback(SearchPolicyNode* policy) final; - - static constexpr const char *_type_key = "ansor.PreloadMeasuredStates"; - TVM_DECLARE_FINAL_OBJECT_INFO(PreloadMeasuredStatesNode, SearchCallbackNode); -}; - -/*! - * \brief Managed reference to PreloadMeasuredStatesNode. - * \sa PreloadMeasuredStatesNode - */ -class PreloadMeasuredStates : public SearchCallback { - public: - explicit PreloadMeasuredStates(std::string filename); - - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PreloadMeasuredStates, SearchCallback, - PreloadMeasuredStatesNode); -}; - /*! 
\brief The base class for search policy */ class SearchPolicyNode : public Object { public: @@ -94,23 +70,9 @@ class SearchPolicyNode : public Object { virtual std::pair, Array > ContinueSearchOneRound( SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) = 0; - // Preload measured states from a log file to resume the state of the search policy - void PreloadMeasuredStates(const std::string& log_file); - // Run a list of callback functions void RunCallbacks(const Array& callbacks); - // Dict keys to give hints to the policy - static constexpr const char* always_unroll_inner_key = "ansor_always_unroll_inner"; - static constexpr const char* always_unroll_key = "ansor_always_unroll"; - static constexpr const char* no_split_at_inner_key = "ansor_no_split_at_inner"; - static constexpr const char* no_split_at_outer_key = "ansor_no_split_at_outer"; - static constexpr const char* last_split_is_one_key = "ansor_last_split_is_one"; - // Flag keys to give hints to the policy - static constexpr const char* always_compute_inline_key = "ansor_always_compute_inline"; - static constexpr const char* no_cache_write_key = "ansor_no_cache_write"; - static constexpr const char* no_cache_read_key = "ansor_no_cache_read"; - static constexpr const char *_type_key = "ansor.SearchPolicy"; TVM_DECLARE_BASE_OBJECT_INFO(SearchPolicyNode, Object); diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py index 984434b9c58b..b701dad6d8c0 100644 --- a/tests/python/unittest/test_ansor_search_policy.py +++ b/tests/python/unittest/test_ansor_search_policy.py @@ -28,7 +28,7 @@ from test_ansor_common import matmul_ansor_test def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local', - cost_model=ansor.RandomModel(), n_trials=2, params=None, + cost_model=None, n_trials=2, params=None, pre_search_callbacks=None): print("Test %s schedule search with the default search policy" % (target)) @@ -42,7 +42,8 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local' with tempfile.NamedTemporaryFile() as fp: log_file = fp.name - search_policy = ansor.SketchSearchPolicy(cost_model, params=params, seed=seed) + search_policy = ansor.EmptyPolicy() + # search_policy = ansor.SketchSearchPolicy(cost_model, params=params, seed=seed) tune_option = ansor.TuneOption(n_trials=n_trials, runner=runner, measure_callbacks=[ansor.LogToFile(log_file)], pre_search_callbacks=pre_search_callbacks) From 2456c3e0ab49c0a4c75bd98c39ffb6dd0ca6a3fc Mon Sep 17 00:00:00 2001 From: "chengfan.jcf" Date: Wed, 24 Jun 2020 17:01:48 +0800 Subject: [PATCH 45/45] Update --- python/tvm/ansor/auto_schedule.py | 24 ++---- python/tvm/ansor/compute_dag.py | 15 ---- python/tvm/ansor/measure.py | 8 +- python/tvm/ansor/utils.py | 66 ----------------- src/ansor/cost_model/cost_model.cc | 77 ------------------- src/ansor/cost_model/cost_model.h | 98 ------------------------- src/ansor/search_policy/empty_policy.cc | 10 ++- src/ansor/search_policy/empty_policy.h | 4 +- 8 files changed, 20 insertions(+), 282 deletions(-) delete mode 100644 src/ansor/cost_model/cost_model.cc delete mode 100644 src/ansor/cost_model/cost_model.h diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index 33f6dbcba7ff..8fddac567529 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -27,8 +27,8 @@ @tvm._ffi.register_object("ansor.HardwareParams") class HardwareParams(Object): - """ - The parameters of target 
hardware + """ The parameters of target hardware, this is used to guide the search process of + SearchPolicy. Parameters ---------- @@ -47,8 +47,7 @@ def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes, @tvm._ffi.register_object("ansor.SearchTask") class SearchTask(Object): - """ - The meta-information of a search task + """ The meta-information of a search task Parameters ---------- @@ -68,23 +67,12 @@ def __init__(self, dag, workload_key, target, target_host=None, @tvm._ffi.register_object("ansor.SearchPolicy") class SearchPolicy(Object): """ The base class for search policy """ - def continue_search(self, task, num_measure, verbose, measurer): - return _ffi_api.SearchPolicyContinueSearchOneRound(self, task, - num_measure, verbose, measurer) - - def set_task(self, task): - _ffi_api.SearchPolicySetTask(self, task) - - def set_verbose(self, verbose): - _ffi_api.SearchPolicySetVerbose(self, verbose) - - def run_callbacks(self, callbacks): - _ffi_api.SearchPolicyRunCallbacks(self, callbacks) @tvm._ffi.register_object("ansor.EmptyPolicy") class EmptyPolicy(SearchPolicy): - """ The example search policy + """ This is an example empty search policy which will always generate + the init state of target ComputeDAG. """ def __init__(self): self.__init_handle_by_constructor__(_ffi_api.EmptyPolicy) @@ -92,7 +80,7 @@ def __init__(self): @tvm._ffi.register_object("ansor.SearchCallback") class SearchCallback(Object): - """Callback function before or after search process""" + """ Callback function before or after search process """ @tvm._ffi.register_object("ansor.TuneOption") diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index e57fbbc08843..d591d615d1c5 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -75,18 +75,3 @@ def print_python_code_from_state(self, state): """ state_obj = state if isinstance(state, StateObject) else state.state_object return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state_obj) - - def infer_bound_from_state(self, state): - """ - Infer bound for a state - - Parameters - ---------- - state : StateObject - - Returns - ------- - state : State - """ - state_obj = state if isinstance(state, StateObject) else state.state_object - return State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self) diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index b8b02e2b85df..af0eddc59653 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -35,13 +35,9 @@ from tvm.runtime import Object, module, ndarray from tvm.driver import build_module from tvm.ir import transform -from tvm.rpc.tracker import Tracker -from tvm.rpc.server import Server -from tvm.autotvm.measure.measure_methods import set_cuda_target_arch -from tvm.contrib import tar, ndk + from . 
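# Following the signatures documented in this hunk, a task/policy pair can be set up
# as below; dag and workload_key are placeholders, the target is built with the
# then-current tvm.target.create API, and EmptyPolicy simply yields the DAG's initial
# state each round:
#
#   task = SearchTask(dag, workload_key, tvm.target.create("llvm"))
#   policy = EmptyPolicy()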
import _ffi_api -from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout, request_remote, \ - check_remote +from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout LOGGER = logging.getLogger('ansor') diff --git a/python/tvm/ansor/utils.py b/python/tvm/ansor/utils.py index 9e3c857aba36..b406824ba811 100644 --- a/python/tvm/ansor/utils.py +++ b/python/tvm/ansor/utils.py @@ -177,69 +177,3 @@ def func_wrapper(que): del que return res - - -def request_remote(device_key, host=None, port=None, priority=1, timeout=60): - """Request a remote session - - Parameters - ---------- - device_key: string - The device key of registered device in tracker - host: host, optional - The host address of rpc tracker. - If is none, will use environment variable "TVM_TRACKER_HOST" - port: int, optional - The port of rpc tracker. - If is none, will use environment variable "TVM_TRACKER_PORT" - priority: int, optional - The priority of this request, larger is more prior - timeout: float, optional - The timeout of this session (units: second) - - Returns - ------ - session: RPCSession - """ - # connect to the tracker - host = host or os.environ['TVM_TRACKER_HOST'] - port = port or int(os.environ['TVM_TRACKER_PORT']) - - tracker = rpc.connect_tracker(host, port) - remote = tracker.request(device_key, priority=priority, - session_timeout=timeout) - return remote - - -def check_remote(device_key, host=None, port=None, priority=100, timeout=10): - """ - Check the availability of a remote device - - Parameters - ---------- - device_key: string - device key of registered device in tracker - host: host, optional - The host address of rpc tracker. - If is none, will use environment variable "TVM_TRACKER_HOST" - port: int, optional - The port address of rpc tracker. - If is none, will use environment variable "TVM_TRACKER_PORT" - priority: int, optional - The priority of this request, larger is more prior - timeout: float, optional - The timeout of this check (units: seconds). - - Returns - ------- - available: bool - True if can find available device - """ - - def _check(): - remote = request_remote(device_key, host, port, priority) - - t = threading.Thread(target=_check, ) - t.start() - t.join(timeout) - return not t.is_alive() diff --git a/src/ansor/cost_model/cost_model.cc b/src/ansor/cost_model/cost_model.cc deleted file mode 100644 index d0ae30e20a9a..000000000000 --- a/src/ansor/cost_model/cost_model.cc +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! 
diff --git a/src/ansor/cost_model/cost_model.cc b/src/ansor/cost_model/cost_model.cc
deleted file mode 100644
index d0ae30e20a9a..000000000000
--- a/src/ansor/cost_model/cost_model.cc
+++ /dev/null
@@ -1,77 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file ansor/cost_model.h
- * \brief Cost model that estimates the performance of programs
- */
-
-#include "cost_model.h"
-
-#include <tvm/runtime/registry.h>
-
-#include <tvm/runtime/packed_func.h>
-
-#include <utility>
-
-namespace tvm {
-namespace ansor {
-
-using ::tvm::runtime::NDArray;
-
-TVM_REGISTER_OBJECT_TYPE(CostModelNode);
-TVM_REGISTER_OBJECT_TYPE(RandomModelNode);
-
-void RandomNumber(TVMArgs args, TVMRetValue* rv) {
-  int n = args[0];
-  void* data = args[1];
-  float* fdata = reinterpret_cast<float*>(data);
-  for (int i = 0; i < n; i++) {
-    fdata[i] = static_cast<float>(rand_r(nullptr)) / (static_cast<float>(RAND_MAX));
-  }
-}
-
-RandomModel::RandomModel() {
-  ObjectPtr<RandomModelNode> node = make_object<RandomModelNode>();
-  node->random_number_func =
-      runtime::Registry::Get("ansor.cost_model.random_number");
-  if (node->random_number_func == nullptr) {
-    LOG(WARNING) << "ansor.cost_model.random_number is not registered, "
-                 << "use C++ default random_number func instead.";
-    static PackedFunc cost_model_random_number(RandomNumber);
-    node->random_number_func = &cost_model_random_number;
-  }
-  data_ = std::move(node);
-}
-
-void RandomModelNode::Update(const Array<MeasureInput>& inputs,
-                             const Array<MeasureResult>& results) {}
-
-void RandomModelNode::Predict(const SearchTask& task,
-                              const std::vector<State>& states,
-                              std::vector<float>* scores) {
-  scores->resize(states.size());
-  (*random_number_func)(states.size(), static_cast<void*>(scores->data()));
-}
-
-TVM_REGISTER_GLOBAL("ansor.RandomModel").set_body_typed([]() {
-  return RandomModel();
-});
-
-}  // namespace ansor
-}  // namespace tvm
diff --git a/src/ansor/cost_model/cost_model.h b/src/ansor/cost_model/cost_model.h
deleted file mode 100644
index 03b7fb5f3399..000000000000
--- a/src/ansor/cost_model/cost_model.h
+++ /dev/null
@@ -1,98 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file ansor/cost_model.h
- * \brief Cost model that estimates the performance of programs
-*/
-
-#ifndef TVM_ANSOR_COST_MODEL_COST_MODEL_H_
-#define TVM_ANSOR_COST_MODEL_COST_MODEL_H_
-
-#include <tvm/node/node.h>
-#include <tvm/runtime/packed_func.h>
-#include <utility>
-#include <vector>
-#include "../measure.h"
-
-namespace tvm {
-namespace ansor {
-
-using runtime::PackedFunc;
-
-/*! \brief The base class for cost model */
-class CostModelNode: public Object {
- public:
-  // Update the cost model according to new measurement pairs
-  virtual void Update(const Array<MeasureInput>& inputs,
-                      const Array<MeasureResult>& results) = 0;
-
-  // Predict the scores of states
-  virtual void Predict(const SearchTask& task, const std::vector<State>& states,
-                       std::vector<float>* scores) = 0;
-
-  // Predict the scores of all stages in states
-  virtual void PredictStages(const SearchTask& task,
-                             const std::vector<State>& states,
-                             std::vector<float>* state_scores,
-                             std::vector<std::vector<float>>* stage_scores) {
-    LOG(FATAL) << "Not Implemented";
-  }
-
-  static constexpr const char *_type_key = "ansor.CostModel";
-  TVM_DECLARE_BASE_OBJECT_INFO(CostModelNode, Object);
-};
-TVM_DEFINE_MUTABLE_OBJECT_REF(CostModel, CostModelNode);
-
-/*! \brief The cost model returns random value for all predictions */
-class RandomModelNode: public CostModelNode {
- public:
-  const PackedFunc* random_number_func;
-
-  void Update(const Array<MeasureInput>& inputs,
-              const Array<MeasureResult>& results) final;
-  void Predict(const SearchTask& task, const std::vector<State>& states,
-               std::vector<float>* scores) final;
-
-  static constexpr const char *_type_key = "ansor.RandomModel";
-  TVM_DECLARE_FINAL_OBJECT_INFO(RandomModelNode, CostModelNode);
-};
-
-/*!
- * \brief Managed reference to RandomModelNode.
- * \sa RandomModelNode
- */
-class RandomModel : public CostModel {
- public:
-  RandomModel();
-  explicit RandomModel(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n)
-      : CostModel(n) {}
-
-  RandomModelNode* operator->() const {
-    return static_cast<RandomModelNode*>(data_.get());
-  }
-
-  TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(RandomModel);
-  using ContainerType = RandomModelNode;
-};
-
-}  // namespace ansor
-}  // namespace tvm
-
-#endif  // TVM_ANSOR_COST_MODEL_COST_MODEL_H_
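RandomModelNode::Predict above fills the score buffer through a PackedFunc looked up under "ansor.cost_model.random_number", and only falls back to the C++ RandomNumber default when nothing is registered. A Python-side override would look roughly like the sketch below; the registry name and the (n, data) calling convention come from the deleted code, while the ctypes handling of the raw float pointer is an assumption:

    import ctypes
    import numpy as np
    import tvm._ffi

    @tvm._ffi.register_func("ansor.cost_model.random_number")
    def random_number(n, data):
        # `data` is assumed to arrive as a raw pointer to n floats,
        # matching the (n, data) signature of the C++ RandomNumber().
        ptr = ctypes.cast(data, ctypes.POINTER(ctypes.c_float))
        buf = np.ctypeslib.as_array(ptr, shape=(n,))
        buf[:] = np.random.uniform(0, 1, (n,))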
diff --git a/src/ansor/search_policy/empty_policy.cc b/src/ansor/search_policy/empty_policy.cc
index 16b22c27ef92..ba861f333c78 100644
--- a/src/ansor/search_policy/empty_policy.cc
+++ b/src/ansor/search_policy/empty_policy.cc
@@ -32,7 +32,7 @@ State EmptyPolicyNode::Search(SearchTask task, int n_trials, int early_stopping,
   cur_task = task;

   // Run pre_search_callbacks before the search process
-  // This Interface is used to set some init status
+  // This interface is usually used to set some initial status
   RunCallbacks(pre_search_callbacks);

   if (n_trials <= 1) {
@@ -45,6 +45,8 @@ State EmptyPolicyNode::Search(SearchTask task, int n_trials, int early_stopping,
     measurer->Reset();

     int ct = 0;
+    // In each round, we call SearchOneRound to get several candidate states,
+    // then use ProgramMeasurer to test their performance
     while (ct < n_trials) {
       const auto& res = SearchOneRound();
       ct += res.size();
@@ -55,12 +57,16 @@ State EmptyPolicyNode::Search(SearchTask task, int n_trials, int early_stopping,
       measurer->Measure(cur_task, GetRef<SearchPolicy>(this), inputs, &results);
     }

+    // Return the state with the best measured performance
     return measurer->best_state[cur_task->workload_key];
   }
 }

 std::pair<Array<MeasureInput>, Array<MeasureResult>> EmptyPolicyNode::ContinueSearchOneRound(
     SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) {
+  // The whole process is almost the same as Search, but this function is designed to be
+  // called and managed by another global task scheduler
+
   std::vector<MeasureInput> inputs;
   std::vector<MeasureResult> results;

@@ -70,6 +76,7 @@ std::pair<Array<MeasureInput>, Array<MeasureResult>> EmptyPolicyNode::ContinueS
   }
   measurer->Measure(cur_task, GetRef<SearchPolicy>(this), inputs, &results);

+  // Return a pair of a MeasureInput array and a MeasureResult array
   Array<MeasureInput> inputs_arr(std::make_move_iterator(inputs.begin()),
                                  std::make_move_iterator(inputs.end()));
   Array<MeasureResult> results_arr(std::make_move_iterator(results.begin()),
@@ -80,6 +87,7 @@ std::pair<Array<MeasureInput>, Array<MeasureResult>> EmptyPolicyNode::ContinueS
 std::vector<State> EmptyPolicyNode::SearchOneRound() {
   std::vector<State> res;
   res.push_back(cur_task->compute_dag.GetInitState());
+  // As an example policy, EmptyPolicy always returns an init state
   return res;
 }
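The comments added to Search above describe a plain measure loop: each round asks the policy for candidate states, measures them, and keeps the best state seen so far (which is what ProgramMeasurer::best_state tracks in C++). The same control flow in stand-alone Python, as a sketch only; search_one_round and measure are stand-in callables, not TVM APIs:

    def search(n_trials, search_one_round, measure):
        best_state, best_cost = None, float("inf")
        ct = 0
        while ct < n_trials:
            states = search_one_round()  # candidate states for this round
            ct += len(states)
            # Measure every candidate and keep the cheapest one.
            for state, cost in zip(states, measure(states)):
                if cost < best_cost:
                    best_state, best_cost = state, cost
        return best_state

For example, search(4, lambda: [1, 2, 3], lambda s: [abs(x - 2) for x in s]) runs two rounds and returns 2.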
diff --git a/src/ansor/search_policy/empty_policy.h b/src/ansor/search_policy/empty_policy.h
index 78a3838f862a..5c2f52608fe0 100644
--- a/src/ansor/search_policy/empty_policy.h
+++ b/src/ansor/search_policy/empty_policy.h
@@ -34,7 +34,9 @@ namespace tvm {
 namespace ansor {

 /*!
- * \brief This is an basic example of search policy
+ * \file ansor/search_policy/empty_policy.h
+ * \brief This is a basic example of a search policy. The EmptyPolicy will
+ * always generate the init state of a ComputeDAG.
  */
 class EmptyPolicyNode : public SearchPolicyNode {
  public: