From cf84177c1bfb379d60dd710836c2a343b77f9ac0 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 20 Jul 2020 19:41:20 -0700 Subject: [PATCH 1/8] add access analyzer --- src/auto_scheduler/compute_dag.cc | 451 +++++++++++++++++++++++++++++- src/auto_scheduler/compute_dag.h | 113 ++++++++ 2 files changed, 559 insertions(+), 5 deletions(-) diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index d81dff66d402..ea271af6db13 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -44,6 +45,10 @@ namespace auto_scheduler { using namespace tvm::tir; +template +using OperationMap = AccessAnalyzerNode::OperationMap; +using OperationSet = std::unordered_set; + TVM_REGISTER_NODE_TYPE(ComputeDAGNode); // Topo-sort ops from tensors according to their read-write relations. @@ -114,7 +119,434 @@ Array TopoSortOps(const Array& tensors) { return ops; } -// Estimate number of float operations in an expression +// Extract all tensor accesses in an expr +class TensorAccessExtractor : public StmtExprVisitor { + public: + void Extract(PrimExpr expr) { this->VisitExpr(expr); } + + void VisitExpr_(const CallNode* op) final { + if (op->op.same_as(builtin::if_then_else())) { + has_branch = true; + } + StmtExprVisitor::VisitExpr_(op); + } + + void VisitExpr_(const ProducerLoadNode* op) final { + buf_accesses[Downcast(op->producer)->op].emplace_back(op->indices.begin(), + op->indices.end()); + StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const IfThenElseNode* op) final { + has_branch = true; + StmtExprVisitor::VisitStmt_(op); + } + + void VisitExpr_(const SelectNode* op) final { + has_branch = true; + StmtExprVisitor::VisitExpr_(op); + } + + OperationMap>> buf_accesses; + bool has_branch{false}; +}; + +// Returns whether the expr equals to the var with a const shift +bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) { + if (auto pv = expr.as()) { + return pv == var.get(); + } else if (auto padd = expr.as()) { + return ((padd->a.get() == var.get() && padd->b->IsInstance()) || + (padd->b.get() == var.get() && padd->a->IsInstance())); + } else if (auto psub = expr.as()) { + return ((psub->a.get() == var.get() && psub->b->IsInstance()) || + (psub->b.get() == var.get() && psub->a->IsInstance())); + } else { + return false; + } +} + +// Return whether the access is injective +bool IsInjective(const te::Operation& op, const std::vector& index, bool* axis_missing, + bool* axis_duplicated, bool* same_order) { + auto cop = op.as(); + if (cop == nullptr) { + return false; + } + + std::vector index_to_var_idx; + std::vector var_idx_ct(cop->axis.size(), 0); + + for (const auto& expr : index) { + if (!is_const_int(expr)) { + bool found = false; + for (size_t i = 0; i < cop->axis.size(); ++i) { + if (IsConstShiftEqual(cop->axis[i]->var, expr)) { + index_to_var_idx.push_back(i); + var_idx_ct[i]++; + found = true; + break; + } + } + if (!found) { + return false; + } + } + } + + *axis_missing = false; // Some axes are missing + *axis_duplicated = false; // Some axes appear more than once + *same_order = true; // The axis order is the same as op->axis + for (int ct : var_idx_ct) { + if (ct == 0) { + *axis_missing = true; + } else if (ct > 1) { + *axis_duplicated = true; + } + } + for (size_t i = 1; i < index_to_var_idx.size(); ++i) { + if (index_to_var_idx[i] < index_to_var_idx[i - 1]) { + *same_order = false; + break; + } + } + + return true; +} + +// Gather all VarNodes in an expr +static void GatherVars(const PrimExpr& expr, std::unordered_set* vars) { + PostOrderVisit(expr, [&vars](const ObjectRef& node) { + if (const VarNode* op = node.as()) { + vars->insert(op); + } + }); +} + +// Check whether an expr has expensive operations (e.g. exp) +static bool HasExpensiveOp(const PrimExpr& expr) { + bool found = false; + PostOrderVisit(expr, [&found](const ObjectRef& node) { + if (const CallNode* op = node.as()) { + if (op->op.as()->name == "tir.exp") { + found = true; + } + } + }); + return found; +} + +AccessAnalyzer::AccessAnalyzer(const Array& tensors) { + auto node = make_object(); + OperationMap has_branch; + + // get all ops + node->ops_topo_order = TopoSortOps(tensors); + + arith::Analyzer analyzer; + + // build read & write access map + for (const auto& op : node->ops_topo_order) { + if (op->IsInstance()) { + node->read_from[op] = OperationMap>>(); + } else if (auto cop = op.as()) { + TensorAccessExtractor extractor; + for (const auto& exp : cop->body) { + extractor.Extract(exp); + } + + // read_by and read_from map + for (const auto& iter : extractor.buf_accesses) { + std::vector>& accesses = node->read_by[iter.first][op]; + accesses.insert(accesses.begin(), iter.second.begin(), iter.second.end()); + } + + node->read_from[op] = std::move(extractor.buf_accesses); + has_branch[op] = extractor.has_branch; + + // compute number of common outer iterators + for (const auto& pair : node->read_from[op]) { + const te::Operation& producer = pair.first; + const std::vector>& access_list = pair.second; + const Array& output_shape = op->output_shape(0); + const Array& producer_shape = producer->output_shape(0); + + int n_common; + for (n_common = 0; + n_common < static_cast(std::min(output_shape.size(), producer_shape.size())); + n_common++) { + if (!is_zero(analyzer.Simplify(output_shape[n_common] - producer_shape[n_common]))) { + break; + } + + bool direct_access = true; + for (const auto& access : access_list) { + if (!IsConstShiftEqual(cop->axis[n_common]->var, access[n_common])) { + direct_access = false; + break; + } + } + + if (!direct_access) { + break; + } + } + + node->num_common_outer_iterators[op][producer] = n_common; + node->num_common_outer_iterators[producer][op] = n_common; + } + } else { + LOG(FATAL) << "Invalid op: " << op; + } + } + + // do some static analysis + for (const auto& op : node->ops_topo_order) { + if (op->IsInstance()) { + node->is_injective[op] = true; + node->needs_multi_level_tiling[op] = false; + node->is_strict_inlineable[op] = false; + node->is_output[op] = false; + } else if (auto pop = op.as()) { + // check whether is element-wise and strict-inlineable + // (see definition in compute_dag.h) + bool is_injective = true; + bool is_strict_inlineable = true; + + bool axis_missing, axis_duplicated, same_order; + for (const auto& pair : node->read_from[op]) { + const std::vector>& access = pair.second; + for (const auto& index : access) { + if (!auto_scheduler::IsInjective(op, index, &axis_missing, &axis_duplicated, + &same_order)) { + is_injective = false; + is_strict_inlineable = false; + break; + } + if (!same_order || axis_duplicated) { + // do not strictly inline transpose + is_strict_inlineable = false; + } + } + if (!is_injective) { + break; + } + } + if (has_branch[op]) { + is_strict_inlineable = false; + } + + // don't strictly inline expensive op (e.g. exp) + bool has_expensive_op = false; + for (const auto& expr : pop->body) { + has_expensive_op |= HasExpensiveOp(expr); + } + + node->is_injective[op] = is_injective; + node->is_strict_inlineable[op] = is_strict_inlineable && !has_expensive_op; + + // check whether the op needs multi-level tiling + // (see definition in compute_dag.h) + bool needs_multi_level_tiling = false; + int n_missing = 0; + + for (const auto& pair : node->read_from[op]) { + const std::vector>& access = pair.second; + std::unordered_set vars; + for (const std::vector& indices : access) { + for (const PrimExpr& expr : indices) { + GatherVars(expr, &vars); + } + } + bool missing = false; + for (const auto& axis : pop->axis) { + if (GetIntImm(axis->dom->extent) > 1 && vars.count(axis->var.get()) == 0) { + missing = true; + } + } + if (missing) { + n_missing++; + } + + if (n_missing >= 2 || (n_missing >= 1 && !pop->reduce_axis.empty())) { + needs_multi_level_tiling = true; + break; + } + } + + node->needs_multi_level_tiling[op] = needs_multi_level_tiling; + + // check whether is output + node->is_output[op] = node->read_by[op].empty(); + } else { + LOG(FATAL) << "Invalid op" << op; + } + } + + data_ = std::move(node); +} + +bool AccessAnalyzer::NeedsMultiLevelTiling(const te::Operation& op) const { + return operator->()->needs_multi_level_tiling.at(op); +} + +bool AccessAnalyzer::IsOutput(const te::Operation& op) const { + return operator->()->is_output.at(op); +} + +bool AccessAnalyzer::IsInjective(const te::Operation& op) const { + return operator->()->is_injective.at(op); +} + +bool AccessAnalyzer::IsStrictInlineable(const te::Operation& op) const { + return operator->()->is_strict_inlineable.at(op); +} + +void AccessAnalyzer::GetConsumers(const State& state, const te::Operation& op, + OperationSet* consumers) const { + OperationSet inlined_ops; + for (const auto& stage : state->stages) { + if (stage->compute_at == ComputeAtKind::kInlined) { + inlined_ops.insert(stage->op); + } + } + + std::function collect; + collect = [this, &collect, &inlined_ops, &consumers](const te::Operation& op) { + for (const auto& iter : operator->()->read_by.at(op)) { + if (inlined_ops.count(iter.first)) { + collect(iter.first); + } else { + consumers->insert(iter.first); + } + } + }; + + consumers->clear(); + collect(op); +} + +void AccessAnalyzer::GetDirectProducers(const te::Operation& op, OperationSet* producers) const { + producers->clear(); + for (const auto& iter : operator->()->read_from.at(op)) { + producers->insert(iter.first); + } +} + +void AccessAnalyzer::GetProducers(const State& state, const te::Operation& op, + OperationSet* producers) const { + OperationSet inlined_ops; + for (const auto& stage : state->stages) { + if (stage->compute_at == ComputeAtKind::kInlined) { + inlined_ops.insert(stage->op); + } + } + + std::function collect; + collect = [this, &collect, &inlined_ops, &producers](const te::Operation& op) { + for (const auto& iter : operator->()->read_from.at(op)) { + if (inlined_ops.count(iter.first)) { + collect(iter.first); + } else { + producers->insert(iter.first); + } + } + }; + + producers->clear(); + collect(op); +} + +int AccessAnalyzer::GetNumCommonOuterIterator(const State& state, const te::Operation& op, + const te::Operation& target_op) const { + int ret = INT32_MAX; + bool meet = false; + + std::function traverse; + traverse = [this, &traverse, &target_op, &ret, &meet](const te::Operation& cur_op, int cur_num) { + if (cur_op == target_op) { + ret = std::min(ret, cur_num); + meet = true; + return; + } + + for (const auto& iter : operator->()->read_by.at(cur_op)) { + traverse( + iter.first, + std::min(cur_num, operator->()->num_common_outer_iterators.at(cur_op).at(iter.first))); + } + }; + + traverse(op, op->output_shape(0).size()); + return meet ? ret : 0; +} + +// Return whether two int arrays are elementwise-equal +bool IntArrayEqual(const Array& arr1, const Array& arr2) { + if (arr1.size() != arr2.size()) { + return false; + } + + for (size_t i = 0; i < arr1.size(); ++i) { + auto int1 = arr1[i].as(); + auto int2 = arr2[i].as(); + CHECK(int1 != nullptr); + CHECK(int2 != nullptr); + if (int1->value != int2->value) { + return false; + } + } + return true; +} + +bool AccessAnalyzer::ElementWiseMatch(const te::Operation& op, + const te::Operation& target_op) const { + te::Operation cur_op = op; + while (cur_op != target_op) { + const AccessAnalyzerNode::OperationMap>>& map = + operator->()->read_by.at(cur_op); + + if (map.size() != 1) { + return false; + } + te::Operation next_op = map.begin()->first; + + // Check condition 1: has the same output size + auto p_cur = cur_op.as(); + auto p_next = next_op.as(); + if (p_cur == nullptr || p_next == nullptr) { + return false; + } + + Array output_shape = p_cur->output_shape(0); + for (int i = 1; i < p_cur->num_outputs(); ++i) { + if (!IntArrayEqual(p_cur->output_shape(i), output_shape)) { + return false; + } + } + for (int i = 0; i < p_next->num_outputs(); ++i) { + if (!IntArrayEqual(p_next->output_shape(i), output_shape)) { + return false; + } + } + + // Check condition 2: read is elementwise + const std::vector> reads = map.begin()->second; + bool is_injective, axis_missing, axis_duplicated, same_order; + for (const auto& read : reads) { + is_injective = + auto_scheduler::IsInjective(next_op, read, &axis_missing, &axis_duplicated, &same_order); + if (!is_injective || axis_missing || axis_duplicated || !same_order) { + return false; + } + } + + cur_op = std::move(next_op); + } + return true; +} + +// Estimate the number of float operations in an expression class FlopEstimator : public ExprFunctor { public: double EstimateFlop(const Array& ops) { @@ -126,6 +558,7 @@ class FlopEstimator : public ExprFunctor { fail_ = true; break; } + cur_type_code_ = pop->output_dtype(0).code(); double op_per_element = 0; for (const auto& x : pop->body) { op_per_element += VisitExpr(x); @@ -171,10 +604,17 @@ class FlopEstimator : public ExprFunctor { std::max(VisitExpr(op->true_value), VisitExpr(op->false_value)); } -#define VisitBinary(Node) \ - double VisitExpr_(const Node* op) final { return 1.0 + VisitExpr(op->a) + VisitExpr(op->b); } -#define VisitUnary(Node) \ - double VisitExpr_(const Node* op) final { return 1.0 + VisitExpr(op->a); } +#define VisitBinary(Node) \ + double VisitExpr_(const Node* op) final { \ + double base = op->dtype.code() == cur_type_code_ ? 1.0 : 0.0; \ + return base + VisitExpr(op->a) + VisitExpr(op->b); \ + } + +#define VisitUnary(Node) \ + double VisitExpr_(const Node* op) final { \ + double base = op->dtype.code() == cur_type_code_ ? 1.0 : 0.0; \ + return base + VisitExpr(op->a); \ + } VisitBinary(AddNode); VisitBinary(SubNode); @@ -210,6 +650,7 @@ class FlopEstimator : public ExprFunctor { private: bool fail_{false}; + int cur_type_code_; }; ComputeDAG::ComputeDAG(Array tensors) { diff --git a/src/auto_scheduler/compute_dag.h b/src/auto_scheduler/compute_dag.h index 2417d72983b0..c7399884cbc8 100644 --- a/src/auto_scheduler/compute_dag.h +++ b/src/auto_scheduler/compute_dag.h @@ -37,13 +37,126 @@ #include +#include +#include #include +#include #include "loop_state.h" namespace tvm { namespace auto_scheduler { +/*! \brief Static analysis result for a ComputeDAG */ +class AccessAnalyzerNode : public Object { + public: + template + using OperationMap = std::unordered_map; + + /*! \brief Map an operation to all operations it reads from. + * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/ + OperationMap>>> read_from; + /*! \brief Map an operation to all operations it is read by. + * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/ + OperationMap>>> read_by; + /*! \brief Store the number of common outer iterators for operation pairs that have read-write + * relations. */ + OperationMap> num_common_outer_iterators; + /*! \brief Store whether the operation is injective */ + OperationMap is_injective; + /*! \brief Store whether the operation is strictly-inlineable */ + OperationMap is_strict_inlineable; + /*! \brief Store whether the operation needs multi-level tiling */ + OperationMap needs_multi_level_tiling; + /*! \brief Store whether the operation is an output operation */ + OperationMap is_output; + /*! \brief Store the topological order of operations */ + Array ops_topo_order; + + static constexpr const char* _type_key = "auto_scheduler.AccessAnalyzer"; + TVM_DECLARE_FINAL_OBJECT_INFO(AccessAnalyzerNode, Object); +}; + +/*! + * \brief Managed reference to AccessAnalyzerNode. + * \sa AccessAnalyzerNode + */ +class AccessAnalyzer : public ObjectRef { + public: + explicit AccessAnalyzer(const Array& tensors); + + /*! + * \brief Return whether this operation needs multi-level tiling + * \param op The operation + */ + bool NeedsMultiLevelTiling(const te::Operation& op) const; + + /*! + * \brief Return whether this operation is an injective operation + * \param op The operation + */ + bool IsInjective(const te::Operation& op) const; + + /*! + * \brief Return whether this operation is strictly inlinable + * \param op The operation + */ + bool IsStrictInlineable(const te::Operation& op) const; + + /*! + * \brief Return whether this operation is an output op + * \param op The operation + */ + bool IsOutput(const te::Operation& op) const; + + /*! + * \brief Get all consumers of on operation + * \param state The current loop state + * \param op The operation + * \param consumers The return consumer set + * \note This function propagates the relation for inlined ops + */ + void GetConsumers(const State& state, const te::Operation& op, + std::unordered_set* consumers) const; + + /*! + * \brief Get all producers of on operation + * \param state The current loop state + * \param op The operation + * \param producers The return producer set + * \note This function propagates the relation for inlined ops + */ + void GetProducers(const State& state, const te::Operation& op, + std::unordered_set* producers) const; + + /*! + * \brief Get all direct producers of on operation + * \param op The operation + * \param producers The return producer set + * \note This function DOES NOT propagate the relation for inlined ops + */ + void GetDirectProducers( + const te::Operation& op, + std::unordered_set* producers) const; + + /*! + * \brief Get the number of common outer iterators. + * \param op The operation + * \param target_op The target operation + * \note This function propagates the relation for chains with multiple ops. + */ + int GetNumCommonOuterIterator(const State& state, const te::Operation& op, + const te::Operation& target_op) const; + + /*! + * \brief Return whether two operations are elementwise-matched + * (e.g. conv2d and relu are elementwise matched) + */ + bool ElementWiseMatch(const te::Operation& op, const te::Operation& target_op) const; + + TVM_DEFINE_OBJECT_REF_METHODS(AccessAnalyzer, ObjectRef, AccessAnalyzerNode); +}; + /*! \brief The TVM Auto-scheduler computational graph and related program analyses. */ class ComputeDAGNode : public Object { public: From 3a8b4b4e4313b399306110741ac23924d0b232d4 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 20 Jul 2020 20:47:07 -0700 Subject: [PATCH 2/8] add test cases --- src/auto_scheduler/compute_dag.cc | 9 +- src/auto_scheduler/compute_dag.h | 11 +- tests/cpp/auto_scheduler_test.cc | 159 ++++++++++++++++++ .../unittest/test_auto_scheduler_common.py | 2 +- .../test_auto_scheduler_compute_dag.py | 19 ++- 5 files changed, 185 insertions(+), 15 deletions(-) create mode 100644 tests/cpp/auto_scheduler_test.cc diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index ea271af6db13..ccea18c80c0f 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -308,8 +308,7 @@ AccessAnalyzer::AccessAnalyzer(const Array& tensors) { node->is_strict_inlineable[op] = false; node->is_output[op] = false; } else if (auto pop = op.as()) { - // check whether is element-wise and strict-inlineable - // (see definition in compute_dag.h) + // check whether this op is element-wise and strict-inlineable bool is_injective = true; bool is_strict_inlineable = true; @@ -346,7 +345,6 @@ AccessAnalyzer::AccessAnalyzer(const Array& tensors) { node->is_strict_inlineable[op] = is_strict_inlineable && !has_expensive_op; // check whether the op needs multi-level tiling - // (see definition in compute_dag.h) bool needs_multi_level_tiling = false; int n_missing = 0; @@ -457,7 +455,7 @@ void AccessAnalyzer::GetProducers(const State& state, const te::Operation& op, collect(op); } -int AccessAnalyzer::GetNumCommonOuterIterator(const State& state, const te::Operation& op, +int AccessAnalyzer::GetNumCommonOuterIterator(const te::Operation& op, const te::Operation& target_op) const { int ret = INT32_MAX; bool meet = false; @@ -656,7 +654,8 @@ class FlopEstimator : public ExprFunctor { ComputeDAG::ComputeDAG(Array tensors) { auto node = make_object(); node->tensors = std::move(tensors); - node->ops = TopoSortOps(node->tensors); + node->access_analyzer = AccessAnalyzer(node->tensors); + node->ops = node->access_analyzer->ops_topo_order; node->flop_ct = FlopEstimator().EstimateFlop(node->ops); node->init_state = State(node->ops); data_ = std::move(node); diff --git a/src/auto_scheduler/compute_dag.h b/src/auto_scheduler/compute_dag.h index c7399884cbc8..6e272d40e930 100644 --- a/src/auto_scheduler/compute_dag.h +++ b/src/auto_scheduler/compute_dag.h @@ -59,8 +59,8 @@ class AccessAnalyzerNode : public Object { /*! \brief Map an operation to all operations it is read by. * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/ OperationMap>>> read_by; - /*! \brief Store the number of common outer iterators for operation pairs that have read-write - * relations. */ + /*! \brief Store the number of common outer iterators for operation pairs that have + * read-write relations. */ OperationMap> num_common_outer_iterators; /*! \brief Store whether the operation is injective */ OperationMap is_injective; @@ -145,12 +145,12 @@ class AccessAnalyzer : public ObjectRef { * \param target_op The target operation * \note This function propagates the relation for chains with multiple ops. */ - int GetNumCommonOuterIterator(const State& state, const te::Operation& op, - const te::Operation& target_op) const; + int GetNumCommonOuterIterator(const te::Operation& op, const te::Operation& target_op) const; /*! * \brief Return whether two operations are elementwise-matched * (e.g. conv2d and relu are elementwise matched) + * \note This function propagates the relation for chains with multiple ops. */ bool ElementWiseMatch(const te::Operation& op, const te::Operation& target_op) const; @@ -171,7 +171,8 @@ class ComputeDAGNode : public Object { double flop_ct; /*! \brief The initial state without any transform steps. */ State init_state; - // TODO(merrymercy): Add more analyses later. + /*! \brief Static read-write access analyzer */ + AccessAnalyzer access_analyzer; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("tensors", &tensors); diff --git a/tests/cpp/auto_scheduler_test.cc b/tests/cpp/auto_scheduler_test.cc new file mode 100644 index 000000000000..67e54da43f8a --- /dev/null +++ b/tests/cpp/auto_scheduler_test.cc @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#include + +// todo(merrymercy): expose auto_scheduler header files to `include/tvm` +// and do not use relative path here +#include "../../src/auto_scheduler/compute_dag.h" +#include "../../src/auto_scheduler/loop_state.h" + +// Compute declaration for test +tvm::Array conv2d_nchw_bn_relu_func(int N, int H, int W, int CI, int CO, + int kernel_size, int strides, int padding, + int dilation = 1) { + using namespace tvm; + using namespace tvm::te; + + Tensor data = placeholder({N, CI, H, W}, DataType::Float(32), "Data"); + Tensor kernel = placeholder({CO, CI, kernel_size, kernel_size}, DataType::Float(32), "Kernel"); + Tensor bias = placeholder({CO, 1, 1}, DataType::Float(32), "Bias"); + Tensor bn_scale = placeholder({CO, 1, 1}, DataType::Float(32), "Bn_scale"); + Tensor bn_offset = placeholder({CO, 1, 1}, DataType::Float(32), "Bn_offset"); + + int OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) / strides + 1; + int OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) / strides + 1; + + const auto& conv = topi::conv2d_nchw(data, kernel, padding, padding, strides, strides); + CHECK(conv->shape[2].as()->value == OH); + CHECK(conv->shape[3].as()->value == OW); + + const auto& bias_add = compute( + {N, CO, OH, OW}, [&](Var i, Var j, Var k, Var l) { return conv[i][j][k][l] + bias[j][0][0]; }, + "Bias_add"); + const auto& bn_mul = compute( + {N, CO, OH, OW}, + [&](Var i, Var j, Var k, Var l) { return bias_add[i][j][k][l] * bn_scale[j][0][0]; }, + "Bn_mul"); + const auto& bn_add = compute( + {N, CO, OH, OW}, + [&](Var i, Var j, Var k, Var l) { return bn_mul[i][j][k][l] + bn_offset[j][0][0]; }, + "Bn_add"); + const auto& out = topi::relu(bn_add); + + return {data, kernel, bias, bn_scale, bn_offset, out}; +} + +using namespace tvm::auto_scheduler; + +// Test Access Analyzer +TEST(ComputeDAG, AccessAnalyzer) { + const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3); + const auto& dag = tvm::auto_scheduler::ComputeDAG(tensors); + const auto& s0 = dag->init_state; + + int data = 0, padding = 1, kernel = 2, conv = 3, bias = 4, bias_add = 5; + int bn_scale = 6, bn_mul = 7, bn_offset = 8, bn_add = 9, relu = 10; + + std::set needs_multi_level_tiling = {conv}; + for (size_t stage_id = 0; stage_id < dag->ops.size(); stage_id++) { + if (needs_multi_level_tiling.count(stage_id)) { + CHECK(dag->access_analyzer.NeedsMultiLevelTiling(dag->ops[stage_id])); + } else { + CHECK(!dag->access_analyzer.NeedsMultiLevelTiling(dag->ops[stage_id])); + } + } + + std::set is_injective = {data, padding, kernel, bias, bias_add, + bn_scale, bn_mul, bn_offset, bn_add, relu}; + for (size_t stage_id = 0; stage_id < dag->ops.size(); stage_id++) { + if (is_injective.count(stage_id)) { + CHECK(dag->access_analyzer.IsInjective(dag->ops[stage_id])); + } else { + CHECK(!dag->access_analyzer.IsInjective(dag->ops[stage_id])); + } + } + + std::set is_strictly_inlinable = {bias_add, bn_mul, bn_add, relu}; + for (size_t stage_id = 0; stage_id < dag->ops.size(); stage_id++) { + if (is_strictly_inlinable.count(stage_id)) { + CHECK(dag->access_analyzer.IsStrictInlineable(dag->ops[stage_id])); + } else { + CHECK(!dag->access_analyzer.IsStrictInlineable(dag->ops[stage_id])); + } + } + + std::set is_output = {relu}; + for (size_t stage_id = 0; stage_id < dag->ops.size(); stage_id++) { + if (is_output.count(stage_id)) { + CHECK(dag->access_analyzer.IsOutput(dag->ops[stage_id])); + } else { + CHECK(!dag->access_analyzer.IsOutput(dag->ops[stage_id])); + } + } + + CHECK_EQ(dag->access_analyzer.GetNumCommonOuterIterator(dag->ops[conv], dag->ops[bias_add]), 4); + CHECK_EQ(dag->access_analyzer.GetNumCommonOuterIterator(dag->ops[conv], dag->ops[relu]), 4); + CHECK_EQ(dag->access_analyzer.GetNumCommonOuterIterator(dag->ops[data], dag->ops[relu]), 1); + + CHECK(dag->access_analyzer.ElementWiseMatch(dag->ops[conv], dag->ops[bias_add])); + CHECK(dag->access_analyzer.ElementWiseMatch(dag->ops[conv], dag->ops[relu])); + CHECK(!dag->access_analyzer.ElementWiseMatch(dag->ops[data], dag->ops[padding])); + + std::unordered_set op_set; + { + std::vector> consumer_list = { + {data, padding}, {padding, conv}, {kernel, conv}, {conv, bias_add}, + {bias, bias_add}, {bias_add, bn_mul}, {bn_scale, bn_mul}, {bn_mul, bn_add}, + {bn_offset, bn_add}, {bn_add, relu}}; + for (const auto& pair : consumer_list) { + dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op, &op_set); + CHECK_EQ(op_set.size(), 1); + CHECK_EQ((*op_set.begin()), s0->stages[pair.second]->op); + } + std::vector>> producer_list = {{padding, {data}}, + {conv, {padding, kernel}}, + {bias_add, {conv, bias}}, + {bn_mul, {bias_add, bn_scale}}, + {bn_add, {bn_mul, bn_offset}}, + {relu, {bn_add}}}; + for (const auto& pair : producer_list) { + dag->access_analyzer.GetProducers(s0, s0->stages[pair.first]->op, &op_set); + CHECK_EQ(op_set.size(), pair.second.size()); + for (const auto& target : pair.second) { + CHECK(op_set.count(s0->stages[target]->op)); + } + } + } + + // todo(lmzheng): Add more test cases for GetConsumer and GetProducesr after we have + // compute_inline +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +} diff --git a/tests/python/unittest/test_auto_scheduler_common.py b/tests/python/unittest/test_auto_scheduler_common.py index fa22fdc5597c..1114fb4e75b8 100644 --- a/tests/python/unittest/test_auto_scheduler_common.py +++ b/tests/python/unittest/test_auto_scheduler_common.py @@ -33,7 +33,7 @@ def matmul_auto_scheduler_test(N, M, K): @auto_scheduler.register_workload("matmul_auto_scheduler_test_rename_1") -def matmul_auto_scheduler_test_rename_0(N, M, K): +def matmul_auto_scheduler_test_rename_1(N, M, K): A = te.placeholder((N, K), name='A') B = te.placeholder((K, M), name='B') k = te.reduce_axis((0, K), name='k') diff --git a/tests/python/unittest/test_auto_scheduler_compute_dag.py b/tests/python/unittest/test_auto_scheduler_compute_dag.py index 49344634a8f0..d9c24b97171e 100644 --- a/tests/python/unittest/test_auto_scheduler_compute_dag.py +++ b/tests/python/unittest/test_auto_scheduler_compute_dag.py @@ -17,10 +17,10 @@ """Test ComputeDAG (replay, infer bound)""" -import tvm +import tvm, topi from tvm import auto_scheduler, te -from test_auto_scheduler_common import get_tiled_matmul +from test_auto_scheduler_common import get_tiled_matmul, matmul_auto_scheduler_test def test_apply_steps(): @@ -36,8 +36,19 @@ def test_infer_bound(): def test_estimate_flop(): - dag, s = get_tiled_matmul() - assert abs(dag.flop_ct - 2 * 512 ** 3) < 0.5 + N = 512 + A, B, C = matmul_auto_scheduler_test(N, N, N) + dag = auto_scheduler.ComputeDAG([A, B, C]) + assert abs(dag.flop_ct - 2 * N ** 3) < 0.5 + + D = topi.nn.relu(C) + dag = auto_scheduler.ComputeDAG([A, B, D]) + assert abs(dag.flop_ct - 2 * N ** 3 - N * N) < 0.5 + + # should not count the comparison operations in padding + D = topi.nn.pad(C, [1, 1]) + dag = auto_scheduler.ComputeDAG([A, B, D]) + assert abs(dag.flop_ct - 2 * N ** 3) < 0.5 if __name__ == "__main__": From 966c3cc94449f74cc9a194e9c7df0302c0f0ef98 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 21 Jul 2020 22:54:10 -0700 Subject: [PATCH 3/8] move header files and polish comments --- .../tvm}/auto_scheduler/auto_schedule.h | 28 ++++----- .../tvm}/auto_scheduler/compute_dag.h | 60 ++++++++++--------- .../tvm}/auto_scheduler/loop_state.h | 36 ++++++----- {src => include/tvm}/auto_scheduler/measure.h | 34 ++++++----- .../tvm}/auto_scheduler/measure_record.h | 28 ++++----- .../tvm/auto_scheduler}/search_policy.h | 37 +++++------- .../tvm}/auto_scheduler/search_task.h | 5 +- .../tvm}/auto_scheduler/transform_step.h | 13 ++-- python/tvm/auto_scheduler/auto_schedule.py | 5 +- .../tvm/auto_scheduler/workload_registry.py | 2 +- src/auto_scheduler/auto_schedule.cc | 3 +- src/auto_scheduler/compute_dag.cc | 5 +- src/auto_scheduler/loop_state.cc | 5 +- src/auto_scheduler/measure.cc | 3 +- src/auto_scheduler/measure_record.cc | 7 +-- .../search_policy/empty_policy.cc | 3 +- .../search_policy/empty_policy.h | 4 +- .../search_policy/search_policy.cc | 3 +- src/auto_scheduler/search_task.cc | 3 +- src/auto_scheduler/transform_step.cc | 16 ++--- tests/cpp/auto_scheduler_test.cc | 35 ++++++++--- .../unittest/test_auto_scheduler_common.py | 2 +- 22 files changed, 173 insertions(+), 164 deletions(-) rename {src => include/tvm}/auto_scheduler/auto_schedule.h (81%) rename {src => include/tvm}/auto_scheduler/compute_dag.h (81%) rename {src => include/tvm}/auto_scheduler/loop_state.h (96%) rename {src => include/tvm}/auto_scheduler/measure.h (93%) rename {src => include/tvm}/auto_scheduler/measure_record.h (83%) rename {src/auto_scheduler/search_policy => include/tvm/auto_scheduler}/search_policy.h (79%) rename {src => include/tvm}/auto_scheduler/search_task.h (97%) rename {src => include/tvm}/auto_scheduler/transform_step.h (98%) diff --git a/src/auto_scheduler/auto_schedule.h b/include/tvm/auto_scheduler/auto_schedule.h similarity index 81% rename from src/auto_scheduler/auto_schedule.h rename to include/tvm/auto_scheduler/auto_schedule.h index 55c6992dfd4e..8477966c0247 100644 --- a/src/auto_scheduler/auto_schedule.h +++ b/include/tvm/auto_scheduler/auto_schedule.h @@ -18,19 +18,17 @@ */ /*! - * \file auto_scheduler/auto_schedule.h - * \brief The user interface of the TVM Auto-scheduler. This is the entry structure to get - * schedule search requirements from upper level (Python API), and returns a high performance - * schedule after search process. + * \file tvm/auto_scheduler/auto_schedule.h + * \brief The user interface of the auto scheduler. */ #ifndef TVM_AUTO_SCHEDULER_AUTO_SCHEDULE_H_ #define TVM_AUTO_SCHEDULER_AUTO_SCHEDULE_H_ -#include +#include +#include -#include "measure.h" -#include "search_policy/search_policy.h" +#include namespace tvm { namespace auto_scheduler { @@ -38,9 +36,9 @@ namespace auto_scheduler { /*! \brief Tuning and measurement options. */ class TuningOptionsNode : public Object { public: - /*! \brief Number of total measurement trials. */ + /*! \brief The number of total measurement trials. */ int num_measure_trials; - /*! \brief Stops early the tuning if no improvement after n measurements. */ + /*! \brief Stops the tuning early if no improvement after n measurements. */ int early_stopping; /*! \brief The number of programs to be measured at each search round. */ int num_measures_per_round; @@ -51,7 +49,7 @@ class TuningOptionsNode : public Object { int verbose; /*! \brief ProgramBuilder which builds the program */ ProgramBuilder builder; - /*! \brief ProgramRunner which runs the program and measure time costs */ + /*! \brief ProgramRunner which runs the program and measures time costs */ ProgramRunner runner; /*! \brief MeasureCallback functions to be called after each measure batch */ Optional> measure_callbacks; @@ -81,8 +79,8 @@ class TuningOptions : public ObjectRef { public: /*! * \brief The constructor - * \param num_measure_trials Number of total measurement trials. - * \param early_stopping Stops early the tuning if no improvement after n measurements. + * \param num_measure_trials The number of total measurement trials. + * \param early_stopping Stops the tuning early if no improvement after n measurements. * \param num_measures_per_round The number of programs to be measured at each search round. * \param verbose Verbosity level. 0 for silent, 1 to output information during schedule * search. @@ -100,11 +98,11 @@ class TuningOptions : public ObjectRef { }; /*! - * \brief Auto schedule search for a given compute declaration. + * \brief Run schedule search for a given compute declaration. * \param task The search task of the compute declaration. - * \param search_policy The search policy to be used for schedule search. + * \param search_policy The search policy to be used. * \param tuning_options Tuning and measurement options. - * \return A `te::schedule` and the a Array of `te::Tensor` to be used in `tvm.lower` or + * \return A `te::schedule` and the an Array of `te::Tensor` to be used in `tvm.lower` or * `tvm.build`. */ TVM_DLL std::pair> AutoSchedule(SearchTask task, diff --git a/src/auto_scheduler/compute_dag.h b/include/tvm/auto_scheduler/compute_dag.h similarity index 81% rename from src/auto_scheduler/compute_dag.h rename to include/tvm/auto_scheduler/compute_dag.h index 6e272d40e930..3cfccf4cd352 100644 --- a/src/auto_scheduler/compute_dag.h +++ b/include/tvm/auto_scheduler/compute_dag.h @@ -1,4 +1,4 @@ -/* +/*r * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information @@ -18,8 +18,8 @@ */ /*! - * \file auto_scheduler/compute_dag.h - * \brief The TVM Auto-scheduler computational graph and related program analyses. + * \file tvm/auto_scheduler/compute_dag.h + * \brief The auto-scheduler's computational graph and related program analyses. * * We convert a compute declaration described by `tvm.compute` (could be a single operator or a * subgraph) to a ComputeDAG. It keeps the input/output tensors of the compute declaration, @@ -35,6 +35,8 @@ #ifndef TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_ #define TVM_AUTO_SCHEDULER_COMPUTE_DAG_H_ +#include +#include #include #include @@ -42,8 +44,6 @@ #include #include -#include "loop_state.h" - namespace tvm { namespace auto_scheduler { @@ -89,25 +89,25 @@ class AccessAnalyzer : public ObjectRef { * \brief Return whether this operation needs multi-level tiling * \param op The operation */ - bool NeedsMultiLevelTiling(const te::Operation& op) const; + TVM_DLL bool NeedsMultiLevelTiling(const te::Operation& op) const; /*! * \brief Return whether this operation is an injective operation * \param op The operation */ - bool IsInjective(const te::Operation& op) const; + TVM_DLL bool IsInjective(const te::Operation& op) const; /*! * \brief Return whether this operation is strictly inlinable * \param op The operation */ - bool IsStrictInlineable(const te::Operation& op) const; + TVM_DLL bool IsStrictInlineable(const te::Operation& op) const; /*! * \brief Return whether this operation is an output op * \param op The operation */ - bool IsOutput(const te::Operation& op) const; + TVM_DLL bool IsOutput(const te::Operation& op) const; /*! * \brief Get all consumers of on operation @@ -116,8 +116,9 @@ class AccessAnalyzer : public ObjectRef { * \param consumers The return consumer set * \note This function propagates the relation for inlined ops */ - void GetConsumers(const State& state, const te::Operation& op, - std::unordered_set* consumers) const; + TVM_DLL void GetConsumers( + const State& state, const te::Operation& op, + std::unordered_set* consumers) const; /*! * \brief Get all producers of on operation @@ -126,8 +127,9 @@ class AccessAnalyzer : public ObjectRef { * \param producers The return producer set * \note This function propagates the relation for inlined ops */ - void GetProducers(const State& state, const te::Operation& op, - std::unordered_set* producers) const; + TVM_DLL void GetProducers( + const State& state, const te::Operation& op, + std::unordered_set* producers) const; /*! * \brief Get all direct producers of on operation @@ -135,7 +137,7 @@ class AccessAnalyzer : public ObjectRef { * \param producers The return producer set * \note This function DOES NOT propagate the relation for inlined ops */ - void GetDirectProducers( + TVM_DLL void GetDirectProducers( const te::Operation& op, std::unordered_set* producers) const; @@ -145,14 +147,15 @@ class AccessAnalyzer : public ObjectRef { * \param target_op The target operation * \note This function propagates the relation for chains with multiple ops. */ - int GetNumCommonOuterIterator(const te::Operation& op, const te::Operation& target_op) const; + TVM_DLL int GetNumCommonOuterIterator( + const te::Operation& op, const te::Operation& target_op) const; /*! * \brief Return whether two operations are elementwise-matched * (e.g. conv2d and relu are elementwise matched) * \note This function propagates the relation for chains with multiple ops. */ - bool ElementWiseMatch(const te::Operation& op, const te::Operation& target_op) const; + TVM_DLL bool ElementWiseMatch(const te::Operation& op, const te::Operation& target_op) const; TVM_DEFINE_OBJECT_REF_METHODS(AccessAnalyzer, ObjectRef, AccessAnalyzerNode); }; @@ -167,11 +170,11 @@ class ComputeDAGNode : public Object { Array tensors; /*! \brief All related operations in topo order. */ Array ops; - /*! \brief Number of total float operations for this ComputeDAG. */ + /*! \brief The number of total float operations for this ComputeDAG. */ double flop_ct; /*! \brief The initial state without any transform steps. */ State init_state; - /*! \brief Static read-write access analyzer */ + /*! \brief The static read-write access analyzer */ AccessAnalyzer access_analyzer; void VisitAttrs(tvm::AttrVisitor* v) { @@ -194,16 +197,17 @@ class ComputeDAG : public ObjectRef { /*! \brief The constructor. * \param tensors `te::Tensor`s for a compute declaration. */ - explicit ComputeDAG(Array tensors); + TVM_DLL explicit ComputeDAG(Array tensors); /*! - * \brief Apply the history transform steps from a State to get a TVM schedule. + * \brief Apply the history transform steps to get a TVM schedule. * \param transform_steps Transform steps of a state. - * \param stages A pointer to a `te::Stage` Array, default to be nullptr. - * Pass a valid pointer if these information needs to be used outside this function. - * \param stage_to_axes A pointer to a StageToAxesMap, default to be nullptr. - * Pass a valid pointer if these information needs to be used outside this function. - * \return A `te.schedule` and the a list of `te.Tensor` to be used in `tvm.lower` or `tvm.build`. + * \param stages The list of stages after applying the steps. + * Pass a valid pointer if this information needs to be used outside this function. + * \param stage_to_axes The map that stores all axes for one stage. + * Pass a valid pointer if this information needs to be used outside this function. + * \return A `te.schedule` and the an Array of `te.Tensor` to be used in `tvm.lower` + * or `tvm.build`. */ std::pair> ApplySteps( const Array& transform_steps, Array* stages = nullptr, @@ -222,9 +226,9 @@ class ComputeDAG : public ObjectRef { * The states can lose complete bound information after some transform steps (e.g., compute_at). * We can call this function to infer and fill all the bound information. * This function calls TVM InferBound pass internally to get the bound. - * The returned state of this function is guaranteed to have complete iterator extent information. - * \param state The state to. - * \return The State after inferbound. + * The returned state of this function is guaranteed to have complete bound information. + * \param state The input state. + * \return The State with complete bound information */ State InferBound(const State& state) const; diff --git a/src/auto_scheduler/loop_state.h b/include/tvm/auto_scheduler/loop_state.h similarity index 96% rename from src/auto_scheduler/loop_state.h rename to include/tvm/auto_scheduler/loop_state.h index 4d6477b92b0f..ab7a52081b93 100644 --- a/src/auto_scheduler/loop_state.h +++ b/include/tvm/auto_scheduler/loop_state.h @@ -48,6 +48,8 @@ #ifndef TVM_AUTO_SCHEDULER_LOOP_STATE_H_ #define TVM_AUTO_SCHEDULER_LOOP_STATE_H_ +#include +#include #include #include @@ -55,8 +57,6 @@ #include #include -#include "transform_step.h" - namespace tvm { namespace auto_scheduler { @@ -159,10 +159,16 @@ using IterKey = std::pair; */ class AttachMapNode : public Object { public: + struct key_hash : public std::function { + std::size_t operator()(const IterKey& k) const { + return ::dmlc::HashCombine(std::hash()(k.first), std::hash()(k.second)); + } + }; + /*! \brief A Map to store the mapping of stage to its attached iterator. */ std::unordered_map stage_to_attach_iter; /*! \brief A Map to store the mapping of iterator to the stage attached to it. */ - std::unordered_map> iter_to_attached_stages; + std::unordered_map, key_hash> iter_to_attached_stages; static constexpr const char* _type_key = "auto_scheduler.AttachMap"; TVM_DECLARE_FINAL_OBJECT_INFO(AttachMapNode, Object); @@ -381,21 +387,11 @@ class State : public ObjectRef { // Hash and equal function for State namespace std { -/*! \brief The hash function for auto_scheduler::State. */ -template <> -struct hash<::tvm::auto_scheduler::State> { - std::size_t operator()(const ::tvm::auto_scheduler::State& state) const { - return tvm::runtime::ObjectHash()(state.ToStr()); - } -}; - /*! * \brief The equal_to function for auto_scheduler::State. - * We use the schedule result(its string format) of a state to check if two states are `euqal`. - * Equal States: 1. the transform steps are totally the same; 2. even with different steps, two - * states may still result in a same schedule. e.g. To split a axis with extent 512 to 3 parts - * [8, 16, 4]. We can split from inner to outter by factors [16, 4], while we can get a same result - * to split from outter to inner by factors [8, 16]) + * This function checkes the equality by looking at the lowered string format of states. + * If two states with different transform history have the same lowered string format, + * they will be considered being equal. */ template <> struct equal_to<::tvm::auto_scheduler::State> { @@ -405,6 +401,14 @@ struct equal_to<::tvm::auto_scheduler::State> { } }; +/*! \brief The hash function for auto_scheduler::State. */ +template <> +struct hash<::tvm::auto_scheduler::State> { + std::size_t operator()(const ::tvm::auto_scheduler::State& state) const { + return tvm::runtime::ObjectHash()(state.ToStr()); + } +}; + } // namespace std #endif // TVM_AUTO_SCHEDULER_LOOP_STATE_H_ diff --git a/src/auto_scheduler/measure.h b/include/tvm/auto_scheduler/measure.h similarity index 93% rename from src/auto_scheduler/measure.h rename to include/tvm/auto_scheduler/measure.h index 02d6e879a1cd..83d7c8d0d3e9 100644 --- a/src/auto_scheduler/measure.h +++ b/include/tvm/auto_scheduler/measure.h @@ -23,26 +23,28 @@ * These functions are responsible for building the tvm module, uploading it to remote devices, * recording the running time costs, and checking the correctness of the output. * - * We separate the measurement into two steps: build and run. + * The measurement is separated into two steps: build and run. * A builder builds the executable binary files and a runner runs the binary files to get the * measurement results. The flow of data structures is * * `ProgramBuilder` `ProgramRunner` * `MeasureInput` -----------------> `BuildResult` ----------------> `MeasureResult` * - * We implement these in python to utilize python's multiprocessing and error handling. + * The core functions is implemented in python to utilize python's multiprocessing + * and error handling (see also `python/tvm/auto_scheduler/measure.py`). + * This c++ file is just a wrapper for the python functions. */ #ifndef TVM_AUTO_SCHEDULER_MEASURE_H_ #define TVM_AUTO_SCHEDULER_MEASURE_H_ +#include +#include + #include #include #include -#include "loop_state.h" -#include "search_task.h" - namespace tvm { namespace auto_scheduler { @@ -209,7 +211,7 @@ class MeasureCallbackNode : public Object { public: /*! * \brief Callback function that will be called on measurement input/result pairs - * after measurement. + * after each measurement batch. * \param policy The current search policy. * \param inputs An Array of MeasureInput. * \param results An Array of MeasureResult. @@ -234,7 +236,7 @@ class MeasureCallback : public ObjectRef { /*! \brief ProgramBuilder that builds the programs */ class ProgramBuilderNode : public Object { public: - /*! \brief The number of tasks to run in parallel */ + /*! \brief The number of build processes to run in parallel */ int n_parallel; /*! \brief Timeout of a build */ int timeout; @@ -323,15 +325,15 @@ class LocalBuilder : public ProgramBuilder { * \brief The constructor. * \param timeout The timeout limit (in second) for each build thread. * This will be used in a wrapper of the multiprocessing.Process.join(). - * \param n_parallel Number of threads used to build in parallel. - * \param build_func The name of registered build function. + * \param n_parallel The number of threads used to build in parallel. + * \param build_func The name of the registered build function. */ LocalBuilder(int timeout, int n_parallel, const String& build_func); TVM_DEFINE_OBJECT_REF_METHODS(LocalBuilder, ProgramBuilder, LocalBuilderNode); }; -/*! \brief LocalRunner that uses local CPU/GPU to measures the time cost of programs */ +/*! \brief LocalRunner that uses local CPU/GPU to measure the time cost of programs */ class LocalRunnerNode : public ProgramRunnerNode { public: Array Run(const Array& inputs, @@ -373,13 +375,12 @@ class RPCRunnerNode : public ProgramRunnerNode { String key; /*! \brief The host address of the RPC Tracker. */ String host; - /*! \brief The port of RPC Tracker. */ + /*! \brief The port of the RPC Tracker. */ int port; /*! \brief The priority of this run request, larger is more prior. */ int priority; /*! \brief The number of tasks run in parallel. */ int n_parallel; - /*! \brief The number of times to run the generated code for taking average. */ Array Run(const Array& inputs, const Array& build_results, int verbose) final; @@ -395,10 +396,11 @@ class RPCRunnerNode : public ProgramRunnerNode { class RPCRunner : public ProgramRunner { public: /*! - * \brief The constructor. + * \brief The constructor. See the corresponding class in python/tvm/auto_scheduler/measure.py + * for more detailed parameter explaination. * \param key The key of the device registered in the RPC tracker. * \param host The host address of the RPC Tracker. - * \param prot The port of RPC Tracker. + * \param port The port of RPC Tracker. * \param priority The priority of this run request, larger is more prior. * \param n_parallel The number of tasks run in parallel. * \param timeout Timeout of a run. @@ -415,7 +417,7 @@ class RPCRunner : public ProgramRunner { /*! * \brief Measurer that measures the time costs of tvm programs - * This class combines ProgramBuilder and ProgramRunner, and provides a simpler API */ + * This class combines ProgramBuilder and ProgramRunner and provides a simpler API */ class ProgramMeasurerNode : public Object { public: /*! \brief Measured programs counter. */ @@ -483,7 +485,7 @@ class ProgramMeasurer : public ObjectRef { * \param callbacks MeasureCallback to be called after each measure batch. * \param verbose Verbosity level. 0 for silent, 1 to output information during program * measuring. - * \param max_continous_error The number of max continuous error. + * \param max_continous_error The number of allowed maximum continuous error. */ ProgramMeasurer(ProgramBuilder builder, ProgramRunner runner, Optional> callbacks, int verbose, diff --git a/src/auto_scheduler/measure_record.h b/include/tvm/auto_scheduler/measure_record.h similarity index 83% rename from src/auto_scheduler/measure_record.h rename to include/tvm/auto_scheduler/measure_record.h index 1cfeab07a400..fa8fe2b1b455 100644 --- a/src/auto_scheduler/measure_record.h +++ b/include/tvm/auto_scheduler/measure_record.h @@ -18,26 +18,26 @@ */ /*! - * \file auto_scheduler/measure_record.h - * \brief Json serialization format for dumping and loading tuning records. + * \file tvm/auto_scheduler/measure_record.h + * \brief Json serialization format for dumping and loading measurement records. */ #ifndef TVM_AUTO_SCHEDULER_MEASURE_RECORD_H_ #define TVM_AUTO_SCHEDULER_MEASURE_RECORD_H_ +#include + #include #include #include -#include "measure.h" - namespace tvm { namespace auto_scheduler { /*! \brief Callback for logging the input and results of measurements to file */ class RecordToFileNode : public MeasureCallbackNode { public: - /*! \brief File name for this callback to write log to. */ + /*! \brief The name of output file. */ String filename; void Callback(const SearchPolicy& policy, const Array& inputs, @@ -55,7 +55,7 @@ class RecordToFile : public MeasureCallback { public: /*! * \brief The constructor. - * \param filename File name for this callback to write log. + * \param filename The name of output file */ explicit RecordToFile(String filename); @@ -65,7 +65,7 @@ class RecordToFile : public MeasureCallback { /*! \brief Log reader to load step logs from a file.*/ class RecordReaderNode : public Object { public: - /*! \brief File name for this reader to load log from. */ + /*! \brief The name of input file. */ String filename; /*! \brief The reading file stream. */ std::ifstream infile; @@ -92,7 +92,7 @@ class RecordReaderNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(RecordReaderNode, Object); private: - /*! \brief A string object to store the next line. */ + /*! \brief A string storing the current line. */ std::string cur_line_; }; @@ -104,7 +104,7 @@ class RecordReader : public ObjectRef { public: /*! * \brief The constructor. - * \param filename File name for this callback to write log. + * \param filename The name of input file */ explicit RecordReader(String filename); @@ -112,7 +112,7 @@ class RecordReader : public ObjectRef { }; /*! - * \brief Write measure records to an output stream. + * \brief Append measure records to an output stream. * \param os A pointer to a output stream. * \param inputs The MeasureInputs to be written. * \param results The MeasureResults to be written. @@ -122,10 +122,10 @@ void WriteMeasureRecords(std::ostream* os, const Array& inputs, /*! * \brief Read one measure record from a string. - * \param str The record string to be extract. - * \param inp A pointer to a MeasureInputNode, this is used as output. - * \param res A pointer to a MeasureResultNode, this is used as output. - * \param log_version A pointer to a log version string. + * \param str The record string to be parsed. + * \param inp A pointer to a MeasureInputNode used to store the return value. + * \param res A pointer to a MeasureResultNode used to store the return value. + * \param log_version A pointer to a string used to store the log version. */ void ReadMeasureRecord(const std::string& str, MeasureInputNode* inp, MeasureResultNode* res, std::string* log_version); diff --git a/src/auto_scheduler/search_policy/search_policy.h b/include/tvm/auto_scheduler/search_policy.h similarity index 79% rename from src/auto_scheduler/search_policy/search_policy.h rename to include/tvm/auto_scheduler/search_policy.h index 70f94ad65b94..457aca1e8f2e 100644 --- a/src/auto_scheduler/search_policy/search_policy.h +++ b/include/tvm/auto_scheduler/search_policy.h @@ -18,11 +18,11 @@ */ /*! - * \file auto_scheduler/search_policy/search_policy.h + * \file tvm/auto_scheduler/search_policy.h * \brief The base class of search policies, including the abstract definition of search policy and * other supporting data structures. * - * The basic schedule search process for TVM Auto-scheduler is design to be: + * The basic schedule search process for the auto-scheduler is design to be: * `Program sampling` -> `Performance Tuning`. * * In `Program sampling`, we use some predefined precise or heuristic rules to generate several @@ -31,7 +31,7 @@ * * Candidate schedules are measured against the specific hardware target. * - * \note Adding a new search policy. + * \note How to add a new search policy. * In design, there's no need for users to implement their own search policy, our formal search * policy(will be brought later) should be enough to cover most use cases. Meanwhile, a custom rule * mechanism will be provided to enable user-defined template search to serve the same functionality @@ -48,16 +48,15 @@ * during the search process. */ -#ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_SEARCH_POLICY_H_ -#define TVM_AUTO_SCHEDULER_SEARCH_POLICY_SEARCH_POLICY_H_ +#ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_ +#define TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_ +#include #include #include #include -#include "../search_task.h" - namespace tvm { namespace auto_scheduler { @@ -110,16 +109,16 @@ class SearchPolicyNode : public Object { /*! * \brief Do schedule search for a task. Takes the SearchTask as input and returns the best state - * get during the search process. - * \param task The SearchTask or workload key for the computation declaration - * \param num_measure_trials Total schedules to be tried during this search. - * \param early_stopping Early stop if no better schedule is found. - * \param num_measures_per_round Max measure batch in one search round. + * found during the search. + * \param task The SearchTask for the computation declaration + * \param num_measure_trials The number of total measurement trials. + * \param early_stopping Stops the tuning early if no improvement after n measurements. + * \param num_measures_per_round The number of programs to be measured at each search round. * \param verbose Verbose level. 0 for silent, 1 to output information during schedule * search. - * \param measurer A ProgramMeasurer which packs ProgramBuilder & ProgramRunner inside. + * \param measurer A ProgramMeasurer to build and measure programs * \param pre_search_callbacks SearchCallback to be called before schedule search. - * \return The best state get. + * \return The best state found. */ virtual State Search(SearchTask task, int num_measure_trials, int early_stopping, int num_measures_per_round, int verbose, ProgramMeasurer measurer, @@ -137,16 +136,12 @@ class SearchPolicyNode : public Object { protected: /*! * \brief The set of already measured states. - * During the schedule search process, we may generate `equal states` through different search - * branches. (Equal States: 1. the transform steps are totally the same; 2. even with different - * steps, two states may still result in a same schedule. e.g. To split a axis with extent 512 - * to 3 parts [8, 16, 4]. We can split from inner to outter by factors [16, 4], while we can - * get a same result to split from outter to inner by factors [8, 16]) * We store the string format of a state for redundancy check. This is used to make sure a * measured state will never be measured again. */ std::unordered_set measured_states_set_; - /*! \brief The array of already measured states. This can be used in evolutionary search. */ + /*! \brief The array of already measured states. + * The good states can be used as the initial population in evolutionary search. */ std::vector measured_states_vector_; /*! \brief The throughputs of already measured states */ std::vector measured_states_throughputs_; @@ -164,4 +159,4 @@ class SearchPolicy : public ObjectRef { } // namespace auto_scheduler } // namespace tvm -#endif // TVM_AUTO_SCHEDULER_SEARCH_POLICY_SEARCH_POLICY_H_ +#endif // TVM_AUTO_SCHEDULER_SEARCH_POLICY_H_ diff --git a/src/auto_scheduler/search_task.h b/include/tvm/auto_scheduler/search_task.h similarity index 97% rename from src/auto_scheduler/search_task.h rename to include/tvm/auto_scheduler/search_task.h index ca313500cc8f..85154b5e406b 100644 --- a/src/auto_scheduler/search_task.h +++ b/include/tvm/auto_scheduler/search_task.h @@ -25,16 +25,15 @@ #ifndef TVM_AUTO_SCHEDULER_SEARCH_TASK_H_ #define TVM_AUTO_SCHEDULER_SEARCH_TASK_H_ +#include #include -#include "compute_dag.h" - namespace tvm { namespace auto_scheduler { class HardwareParams; -/*! \brief The parameters of target hardware used to guide the search process of SearchPolicy. */ +/*! \brief The parameters of target hardware used to guide the SearchPolicy. */ class HardwareParamsNode : public Object { public: /*! \brief The number of cores. */ diff --git a/src/auto_scheduler/transform_step.h b/include/tvm/auto_scheduler/transform_step.h similarity index 98% rename from src/auto_scheduler/transform_step.h rename to include/tvm/auto_scheduler/transform_step.h index ce3ca50ffae6..b23137a9ba5d 100644 --- a/src/auto_scheduler/transform_step.h +++ b/include/tvm/auto_scheduler/transform_step.h @@ -19,10 +19,10 @@ /*! * \file auto_scheduler/transform_step.h - * \brief Transformation steps. For each schedule primitive, there is a corresponding transform - * step. + * \brief Transformation steps. These steps are used to manipulate the LoopState. + * They are similar to the schedule primitives in te::Stage. * - * \note To add a new transform step: + * \note How to add a new transform step: * Take fuse step for example: * 1. Define class `FuseStepNode`, `FuseStep` in `transform_steps.h`, and implement its first * construction function `FuseStep::FuseStep()` in `transform_steps.cc`. @@ -51,8 +51,6 @@ #include #include -#include "utils.h" - namespace tvm { namespace auto_scheduler { @@ -187,7 +185,6 @@ Step StepReadFromRecord(dmlc::JSONReader* reader); * \param step The step to be applied to State. * \param state A mutable pointer to State. * \param dag The original ComputeDAG of this state. - * \return The iterator result after annotate. */ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag); @@ -209,7 +206,7 @@ void StepApplyToSchedule(const Step& step, Array* stages, StageToAxes String StepPrintAsPythonAPI(const Step& step, Array* stages, StageToAxesMap* stage_to_axes); -/********** Primitives working on single stage **********/ +/********** Steps working on single stage **********/ /*! * \brief Annotation step that corresponds to vectorize, parallel, unroll and thread binding. @@ -478,7 +475,7 @@ class SplitStep : public Step { TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode); }; -/********** Primitives working on multiple stages **********/ +/********** Steps working on multiple stages **********/ /*! \brief Compute at step that corresponds to te::Stage::compute_at */ class ComputeAtStepNode : public StepNode { diff --git a/python/tvm/auto_scheduler/auto_schedule.py b/python/tvm/auto_scheduler/auto_schedule.py index d45dbf8d0aaa..52aa62baf56f 100644 --- a/python/tvm/auto_scheduler/auto_schedule.py +++ b/python/tvm/auto_scheduler/auto_schedule.py @@ -57,7 +57,7 @@ def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes): @tvm._ffi.register_object("auto_scheduler.SearchTask") class SearchTask(Object): - """ The computation information and hardware parameters for a specific schedule search task. + """ The computation information and hardware parameters for a schedule search task. Parameters ---------- @@ -158,9 +158,6 @@ def __init__(self, num_measure_trials=0, early_stopping=None, num_measures_per_r def auto_schedule(task, search_policy='default', tuning_options=None): """ Do auto scheduling for a computation declaration. - The task parameter can be a `string` as workload_key, or directly - passing a `SearchTask` as input. - Parameters ---------- task : SearchTask diff --git a/python/tvm/auto_scheduler/workload_registry.py b/python/tvm/auto_scheduler/workload_registry.py index 36c203781073..045720a037ea 100644 --- a/python/tvm/auto_scheduler/workload_registry.py +++ b/python/tvm/auto_scheduler/workload_registry.py @@ -95,7 +95,7 @@ def make_workload_key(func, args): Returns ------- - workload_key : Str + workload_key : str The workload key of the function. """ global WORKLOAD_FUNC_REGISTRY diff --git a/src/auto_scheduler/auto_schedule.cc b/src/auto_scheduler/auto_schedule.cc index b515b3accf7a..c537ca702b9d 100644 --- a/src/auto_scheduler/auto_schedule.cc +++ b/src/auto_scheduler/auto_schedule.cc @@ -24,8 +24,7 @@ * schedule after search process. */ -#include "auto_schedule.h" - +#include #include namespace tvm { diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index ccea18c80c0f..92239e51101a 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -22,8 +22,8 @@ * \brief Compute declaration graph and its related analysis tools. */ -#include "compute_dag.h" - +#include +#include #include #include #include @@ -37,7 +37,6 @@ #include #include -#include "loop_state.h" #include "utils.h" namespace tvm { diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc index bfe547864ed1..35d899ad561f 100644 --- a/src/auto_scheduler/loop_state.cc +++ b/src/auto_scheduler/loop_state.cc @@ -23,14 +23,13 @@ * see auto_scheduler/loop_state.h for more explanation. */ -#include "loop_state.h" - +#include +#include #include #include #include -#include "transform_step.h" #include "utils.h" namespace tvm { diff --git a/src/auto_scheduler/measure.cc b/src/auto_scheduler/measure.cc index 6198f60da5a6..e249f7bd7d28 100644 --- a/src/auto_scheduler/measure.cc +++ b/src/auto_scheduler/measure.cc @@ -22,8 +22,7 @@ * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs. */ -#include "measure.h" - +#include #include #include diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc index 39f9ad86c958..02f244f93de5 100644 --- a/src/auto_scheduler/measure_record.cc +++ b/src/auto_scheduler/measure_record.cc @@ -22,9 +22,10 @@ * \brief Json serialization format for dumping and loading tuning records. */ -#include "measure_record.h" - #include +#include +#include +#include #include #include @@ -33,8 +34,6 @@ #include #include -#include "loop_state.h" -#include "transform_step.h" #include "utils.h" // Json serialization handler for MeasureInput, MeasureResult diff --git a/src/auto_scheduler/search_policy/empty_policy.cc b/src/auto_scheduler/search_policy/empty_policy.cc index 1886203593a9..4c85af486a61 100644 --- a/src/auto_scheduler/search_policy/empty_policy.cc +++ b/src/auto_scheduler/search_policy/empty_policy.cc @@ -24,10 +24,9 @@ #include "empty_policy.h" +#include #include -#include "../measure.h" - namespace tvm { namespace auto_scheduler { diff --git a/src/auto_scheduler/search_policy/empty_policy.h b/src/auto_scheduler/search_policy/empty_policy.h index 4ccc9c1042ea..ef7d38ddf116 100644 --- a/src/auto_scheduler/search_policy/empty_policy.h +++ b/src/auto_scheduler/search_policy/empty_policy.h @@ -26,8 +26,8 @@ #ifndef TVM_AUTO_SCHEDULER_SEARCH_POLICY_EMPTY_POLICY_H_ #define TVM_AUTO_SCHEDULER_SEARCH_POLICY_EMPTY_POLICY_H_ -#include "../loop_state.h" -#include "search_policy.h" +#include +#include namespace tvm { namespace auto_scheduler { diff --git a/src/auto_scheduler/search_policy/search_policy.cc b/src/auto_scheduler/search_policy/search_policy.cc index fba5155edaea..764b0a7fb97a 100644 --- a/src/auto_scheduler/search_policy/search_policy.cc +++ b/src/auto_scheduler/search_policy/search_policy.cc @@ -22,8 +22,7 @@ * \brief The base class of search policies. */ -#include "search_policy.h" - +#include #include namespace tvm { diff --git a/src/auto_scheduler/search_task.cc b/src/auto_scheduler/search_task.cc index 912a31046540..9cc21f2dfedc 100644 --- a/src/auto_scheduler/search_task.cc +++ b/src/auto_scheduler/search_task.cc @@ -22,8 +22,7 @@ * \brief Meta information and hardware parameters for a search task. */ -#include "search_task.h" - +#include #include #include diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index 6c672a5215f2..b1b3b9437006 100644 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -19,12 +19,12 @@ /*! * \file auto_scheduler/transform_step.cc - * \brief Transformation steps. For each schedule primitive, there is a corresponding transform - * step. + * \brief Transformation steps. These steps are used to manipulate the LoopState. + * They are similar to the schedule primitives in te::Stage. */ -#include "transform_step.h" - +#include +#include #include #include @@ -32,7 +32,6 @@ #include #include -#include "loop_state.h" #include "utils.h" namespace tvm { @@ -80,6 +79,7 @@ Step StepReadFromRecord(dmlc::JSONReader* reader) { } void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) { + // We need this runtime dispatcher because different steps have different function signatures if (auto ps = step.as()) { ps->ApplyToState(state); } else if (auto ps = step.as()) { @@ -101,6 +101,7 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) { void StepApplyToSchedule(const Step& step, Array* stages, StageToAxesMap* stage_to_axes) { + // We need this runtime dispatcher because different steps have different function signatures if (auto ps = step.as()) { ps->ApplyToSchedule(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -122,6 +123,7 @@ void StepApplyToSchedule(const Step& step, Array* stages, String StepPrintAsPythonAPI(const Step& step, Array* stages, StageToAxesMap* stage_to_axes) { + // We need this runtime dispatcher because different steps have different function signatures if (auto ps = step.as()) { return ps->PrintAsPythonAPI(stages, stage_to_axes); } else if (auto ps = step.as()) { @@ -142,7 +144,7 @@ String StepPrintAsPythonAPI(const Step& step, Array* stages, return ""; } -/********** Primitives working on single stage **********/ +/********** Steps working on single stage **********/ /********** Annotation **********/ AnnotationStep::AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann) { @@ -741,7 +743,7 @@ String SplitStepNode::PrintAsPythonAPI(Array* stages, return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer); } -/********** Primitives working on multiple stages **********/ +/********** Steps working on multiple stages **********/ /********** Compute At **********/ ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id) { diff --git a/tests/cpp/auto_scheduler_test.cc b/tests/cpp/auto_scheduler_test.cc index 67e54da43f8a..f21fe1f5c57b 100644 --- a/tests/cpp/auto_scheduler_test.cc +++ b/tests/cpp/auto_scheduler_test.cc @@ -20,16 +20,12 @@ #include #include #include +#include #include #include #include -// todo(merrymercy): expose auto_scheduler header files to `include/tvm` -// and do not use relative path here -#include "../../src/auto_scheduler/compute_dag.h" -#include "../../src/auto_scheduler/loop_state.h" - // Compute declaration for test tvm::Array conv2d_nchw_bn_relu_func(int N, int H, int W, int CI, int CO, int kernel_size, int strides, int padding, @@ -72,7 +68,7 @@ using namespace tvm::auto_scheduler; TEST(ComputeDAG, AccessAnalyzer) { const auto& tensors = conv2d_nchw_bn_relu_func(1, 224, 224, 3, 64, 7, 2, 3); const auto& dag = tvm::auto_scheduler::ComputeDAG(tensors); - const auto& s0 = dag->init_state; + State s0 = dag->init_state; int data = 0, padding = 1, kernel = 2, conv = 3, bias = 4, bias_add = 5; int bn_scale = 6, bn_mul = 7, bn_offset = 8, bn_add = 9, relu = 10; @@ -148,8 +144,31 @@ TEST(ComputeDAG, AccessAnalyzer) { } } - // todo(lmzheng): Add more test cases for GetConsumer and GetProducesr after we have - // compute_inline + s0.compute_inline(bn_add); + s0.compute_inline(bn_mul); + s0.compute_inline(bias_add); + s0.compute_inline(padding); + { + std::vector> consumer_list = {{data, conv}, {kernel, conv}, {conv, relu}}; + for (const auto& pair : consumer_list) { + dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op, &op_set); + CHECK_EQ(op_set.size(), 1); + CHECK_EQ((*op_set.begin()), s0->stages[pair.second]->op); + } + std::vector>> producer_list = {{padding, {data}}, + {conv, {padding, kernel}}, + {bias_add, {conv, bias}}, + {bn_mul, {bias_add, bn_scale}}, + {bn_add, {bn_mul, bn_offset}}, + {relu, {bn_add}}}; + for (const auto& pair : producer_list) { + dag->access_analyzer.GetDirectProducers(s0->stages[pair.first]->op, &op_set); + CHECK_EQ(op_set.size(), pair.second.size()); + for (const auto& target : pair.second) { + CHECK(op_set.count(s0->stages[target]->op)); + } + } + } } int main(int argc, char** argv) { diff --git a/tests/python/unittest/test_auto_scheduler_common.py b/tests/python/unittest/test_auto_scheduler_common.py index 1114fb4e75b8..fa22fdc5597c 100644 --- a/tests/python/unittest/test_auto_scheduler_common.py +++ b/tests/python/unittest/test_auto_scheduler_common.py @@ -33,7 +33,7 @@ def matmul_auto_scheduler_test(N, M, K): @auto_scheduler.register_workload("matmul_auto_scheduler_test_rename_1") -def matmul_auto_scheduler_test_rename_1(N, M, K): +def matmul_auto_scheduler_test_rename_0(N, M, K): A = te.placeholder((N, K), name='A') B = te.placeholder((K, M), name='B') k = te.reduce_axis((0, K), name='k') From 1b0f69307ab516f6e9f74d84caee035f50ee4280 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 21 Jul 2020 22:57:35 -0700 Subject: [PATCH 4/8] fix lint --- include/tvm/auto_scheduler/compute_dag.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/tvm/auto_scheduler/compute_dag.h b/include/tvm/auto_scheduler/compute_dag.h index 3cfccf4cd352..de158915d266 100644 --- a/include/tvm/auto_scheduler/compute_dag.h +++ b/include/tvm/auto_scheduler/compute_dag.h @@ -147,8 +147,8 @@ class AccessAnalyzer : public ObjectRef { * \param target_op The target operation * \note This function propagates the relation for chains with multiple ops. */ - TVM_DLL int GetNumCommonOuterIterator( - const te::Operation& op, const te::Operation& target_op) const; + TVM_DLL int GetNumCommonOuterIterator(const te::Operation& op, + const te::Operation& target_op) const; /*! * \brief Return whether two operations are elementwise-matched From 64244ab172a300057fe709acfc33a7c0c51b654d Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 21 Jul 2020 23:16:18 -0700 Subject: [PATCH 5/8] update --- include/tvm/auto_scheduler/loop_state.h | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/include/tvm/auto_scheduler/loop_state.h b/include/tvm/auto_scheduler/loop_state.h index ab7a52081b93..8bbeaecbdf89 100644 --- a/include/tvm/auto_scheduler/loop_state.h +++ b/include/tvm/auto_scheduler/loop_state.h @@ -297,14 +297,14 @@ class State : public ObjectRef { * this input. * \return The iterator result after binded. */ - Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type); + TVM_DLL Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type); /*! * \brief Schedule primitive corresponds to te.parallel. * \param stage_id The index of the stage to be paralleled. * \param it The iterator to be paralleled. * \return The iterator result after parallel. */ - Iterator parallel(int stage_id, const Iterator& it); + TVM_DLL Iterator parallel(int stage_id, const Iterator& it); /*! * \brief Schedule primitive corresponds to te.unroll. * \param stage_id The index of the stage to be unrolled. @@ -313,14 +313,14 @@ class State : public ObjectRef { * skipped. * \return The iterator result after unrolled. */ - Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1); + TVM_DLL Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1); /*! * \brief Schedule primitive corresponds to te.vectorize. * \param stage_id The index of the stage to be vectorized. * \param it The iterator to be vectorized. * \return The iterator result after vectorize. */ - Iterator vectorize(int stage_id, const Iterator& it); + TVM_DLL Iterator vectorize(int stage_id, const Iterator& it); /*! * \brief Schedule primitive corresponds to te.fuse. * \param stage_id The index of the stage to be fused. @@ -329,13 +329,13 @@ class State : public ObjectRef { * \note If the iterators to be fused have stages attached at them(by compute_at), the fused * result will become the new attach point. */ - Iterator fuse(int stage_id, const Array& iters); + TVM_DLL Iterator fuse(int stage_id, const Array& iters); /*! * \brief Schedule primitive corresponds to te.reorder. * \param stage_id The index of the stage to be reordered. * \param order The expected iterator order. */ - void reorder(int stage_id, const Array& order); + TVM_DLL void reorder(int stage_id, const Array& order); /*! * \brief Schedule primitive corresponds to te.split. * \param stage_id The index of the stage to be split. @@ -346,8 +346,9 @@ class State : public ObjectRef { * \note If we do split on an iterator which has stages attached at it(by compute_at), the inner * most iterator of split results will become the new attach point. */ - Array split(int stage_id, const Iterator& it, const Array>& lengths, - bool inner_to_outer = true); + TVM_DLL Array split( + int stage_id, const Iterator& it, const Array>& lengths, + bool inner_to_outer = true); /********** Step APIs working on multiple stages **********/ @@ -361,12 +362,12 @@ class State : public ObjectRef { * bound for the newly created iterators. * Call ComputeDAG::InferBound on the updated state to get the complete bound information. */ - void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter); + TVM_DLL void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter); /*! * \brief Schedule primitive corresponds to te.compute_inline. * \param stage_id The index of the stage to be reordered. */ - void compute_inline(int stage_id); + TVM_DLL void compute_inline(int stage_id); /*! * \brief Schedule primitive corresponds to te.compute_root. * \param stage_id The index of the stage to be reordered. @@ -375,7 +376,7 @@ class State : public ObjectRef { * bound for the newly created iterators. * Call ComputeDAG::InferBound on the updated state to get the complete bound information. */ - void compute_root(int stage_id); + TVM_DLL void compute_root(int stage_id); TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode); From bea7dcfa6c29d50725383d03acd9e42f9e629569 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 21 Jul 2020 23:18:42 -0700 Subject: [PATCH 6/8] fix lint --- include/tvm/auto_scheduler/loop_state.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/tvm/auto_scheduler/loop_state.h b/include/tvm/auto_scheduler/loop_state.h index 8bbeaecbdf89..00de0c801568 100644 --- a/include/tvm/auto_scheduler/loop_state.h +++ b/include/tvm/auto_scheduler/loop_state.h @@ -346,9 +346,9 @@ class State : public ObjectRef { * \note If we do split on an iterator which has stages attached at it(by compute_at), the inner * most iterator of split results will become the new attach point. */ - TVM_DLL Array split( - int stage_id, const Iterator& it, const Array>& lengths, - bool inner_to_outer = true); + TVM_DLL Array split(int stage_id, const Iterator& it, + const Array>& lengths, + bool inner_to_outer = true); /********** Step APIs working on multiple stages **********/ From c690c32252249504e3ae2be25b80170476861fa6 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 24 Jul 2020 17:03:33 -0700 Subject: [PATCH 7/8] address comments --- include/tvm/auto_scheduler/compute_dag.h | 55 +++--- include/tvm/auto_scheduler/loop_state.h | 4 +- python/tvm/autotvm/task/relay_integration.py | 1 + src/auto_scheduler/compute_dag.cc | 171 +++++++++---------- src/auto_scheduler/utils.h | 18 ++ tests/cpp/auto_scheduler_test.cc | 18 +- 6 files changed, 138 insertions(+), 129 deletions(-) diff --git a/include/tvm/auto_scheduler/compute_dag.h b/include/tvm/auto_scheduler/compute_dag.h index de158915d266..b9c1f9e8c45c 100644 --- a/include/tvm/auto_scheduler/compute_dag.h +++ b/include/tvm/auto_scheduler/compute_dag.h @@ -51,22 +51,27 @@ namespace auto_scheduler { class AccessAnalyzerNode : public Object { public: template - using OperationMap = std::unordered_map; + using OperationMap = std::unordered_map; /*! \brief Map an operation to all operations it reads from. - * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/ + * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses + * The inner vector represents the indices of multi-dimensional access.*/ OperationMap>>> read_from; /*! \brief Map an operation to all operations it is read by. - * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses*/ + * For each operation pair, use a two-dimentional array to multiple multi-dimentional accesses + * The inner vector represents the indices of multi-dimensional access.*/ OperationMap>>> read_by; /*! \brief Store the number of common outer iterators for operation pairs that have * read-write relations. */ OperationMap> num_common_outer_iterators; - /*! \brief Store whether the operation is injective */ - OperationMap is_injective; - /*! \brief Store whether the operation is strictly-inlineable */ + /*! \brief Store whether the operation is an op with only simple access. + * (e.g., injective, broadcast and elementwise ops without reduction) */ + OperationMap is_simple_access; + /*! \brief Store whether the operation is strictly-inlineable + * (e.g., injective, broadcast and elementwise without reduction, branch or expenive operations) */ OperationMap is_strict_inlineable; - /*! \brief Store whether the operation needs multi-level tiling */ + /*! \brief Store whether the operation needs multi-level tiling + * (e.g., computation-intensive ops with data reuse opportunity like matmul, conv2d) */ OperationMap needs_multi_level_tiling; /*! \brief Store whether the operation is an output operation */ OperationMap is_output; @@ -86,22 +91,25 @@ class AccessAnalyzer : public ObjectRef { explicit AccessAnalyzer(const Array& tensors); /*! - * \brief Return whether this operation needs multi-level tiling + * \brief Return whether this operation is an injective operation + * (e.g., injective, broadcast and elementwise ops without reduction) * \param op The operation */ - TVM_DLL bool NeedsMultiLevelTiling(const te::Operation& op) const; + TVM_DLL bool IsSimpleAccess(const te::Operation& op) const; /*! - * \brief Return whether this operation is an injective operation + * \brief Return whether this operation is strictly inlinable + * (e.g., injective, broadcast and elementwise without reduction, branch or expenive operations) * \param op The operation */ - TVM_DLL bool IsInjective(const te::Operation& op) const; + TVM_DLL bool IsStrictInlineable(const te::Operation& op) const; /*! - * \brief Return whether this operation is strictly inlinable + * \brief Return whether this operation needs multi-level tiling + * (e.g., computation-intensive ops with data reuse opportunity like matmul, conv2d) * \param op The operation */ - TVM_DLL bool IsStrictInlineable(const te::Operation& op) const; + TVM_DLL bool NeedsMultiLevelTiling(const te::Operation& op) const; /*! * \brief Return whether this operation is an output op @@ -113,33 +121,30 @@ class AccessAnalyzer : public ObjectRef { * \brief Get all consumers of on operation * \param state The current loop state * \param op The operation - * \param consumers The return consumer set + * \return The set of consumers * \note This function propagates the relation for inlined ops */ - TVM_DLL void GetConsumers( - const State& state, const te::Operation& op, - std::unordered_set* consumers) const; + TVM_DLL std::unordered_set GetConsumers( + const State& state, const te::Operation& op) const; /*! * \brief Get all producers of on operation * \param state The current loop state * \param op The operation - * \param producers The return producer set + * \return The set of producers * \note This function propagates the relation for inlined ops */ - TVM_DLL void GetProducers( - const State& state, const te::Operation& op, - std::unordered_set* producers) const; + TVM_DLL std::unordered_set GetProducers( + const State& state, const te::Operation& op) const; /*! * \brief Get all direct producers of on operation * \param op The operation - * \param producers The return producer set + * \return The set of direct producers * \note This function DOES NOT propagate the relation for inlined ops */ - TVM_DLL void GetDirectProducers( - const te::Operation& op, - std::unordered_set* producers) const; + TVM_DLL std::unordered_set GetDirectProducers( + const te::Operation& op) const; /*! * \brief Get the number of common outer iterators. diff --git a/include/tvm/auto_scheduler/loop_state.h b/include/tvm/auto_scheduler/loop_state.h index 00de0c801568..4e9cb9bd7d20 100644 --- a/include/tvm/auto_scheduler/loop_state.h +++ b/include/tvm/auto_scheduler/loop_state.h @@ -159,7 +159,7 @@ using IterKey = std::pair; */ class AttachMapNode : public Object { public: - struct key_hash : public std::function { + struct IterKeyHash { std::size_t operator()(const IterKey& k) const { return ::dmlc::HashCombine(std::hash()(k.first), std::hash()(k.second)); } @@ -168,7 +168,7 @@ class AttachMapNode : public Object { /*! \brief A Map to store the mapping of stage to its attached iterator. */ std::unordered_map stage_to_attach_iter; /*! \brief A Map to store the mapping of iterator to the stage attached to it. */ - std::unordered_map, key_hash> iter_to_attached_stages; + std::unordered_map, IterKeyHash> iter_to_attached_stages; static constexpr const char* _type_key = "auto_scheduler.AttachMap"; TVM_DECLARE_FINAL_OBJECT_INFO(AttachMapNode, Object); diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index 9a43f2f1ad95..70f32eb81a75 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -26,6 +26,7 @@ import tvm from .task import create from .topi_integration import TaskExtractEnv +from .dispatcher import FallbackContext logger = logging.getLogger('autotvm') diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index 92239e51101a..a7e0923285c8 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -38,6 +38,7 @@ #include #include "utils.h" +#include "../arith/pattern_match.h" namespace tvm { namespace auto_scheduler { @@ -119,7 +120,7 @@ Array TopoSortOps(const Array& tensors) { } // Extract all tensor accesses in an expr -class TensorAccessExtractor : public StmtExprVisitor { +class ReadAccessExtractor : public StmtExprVisitor { public: void Extract(PrimExpr expr) { this->VisitExpr(expr); } @@ -131,8 +132,8 @@ class TensorAccessExtractor : public StmtExprVisitor { } void VisitExpr_(const ProducerLoadNode* op) final { - buf_accesses[Downcast(op->producer)->op].emplace_back(op->indices.begin(), - op->indices.end()); + read_access[Downcast(op->producer)->op].emplace_back(op->indices.begin(), + op->indices.end()); StmtExprVisitor::VisitExpr_(op); } @@ -146,28 +147,31 @@ class TensorAccessExtractor : public StmtExprVisitor { StmtExprVisitor::VisitExpr_(op); } - OperationMap>> buf_accesses; + // All read accesses to all operations + // The innermost vector stores mulit-dimentional indices. + // The middle vector stores possible multiple accesses + OperationMap>> read_access; + // Whether this expression has branch bool has_branch{false}; }; -// Returns whether the expr equals to the var with a const shift +// Returns whether the expr equals to the var with an optional const shift bool IsConstShiftEqual(const Var& var, const PrimExpr& expr) { - if (auto pv = expr.as()) { - return pv == var.get(); - } else if (auto padd = expr.as()) { - return ((padd->a.get() == var.get() && padd->b->IsInstance()) || - (padd->b.get() == var.get() && padd->a->IsInstance())); - } else if (auto psub = expr.as()) { - return ((psub->a.get() == var.get() && psub->b->IsInstance()) || - (psub->b.get() == var.get() && psub->a->IsInstance())); - } else { - return false; + arith::PVar x; + arith::PVar c; + + if (((x + c).Match(expr) || (x - c).Match(expr) || (c + x).Match(expr) || x.Match(expr)) && + x.Eval().same_as(var)) { + return true; } + return false; } -// Return whether the access is injective -bool IsInjective(const te::Operation& op, const std::vector& index, bool* axis_missing, - bool* axis_duplicated, bool* same_order) { +// Return whether the access to an operation is a simple access +// (i.e. all index is just a variable with an optional constant shift) +// For example, A[i][j], A[i+1][j] are simple accesses but A[i][j+i] is not. +bool IsSimpleAccess(const te::Operation& op, const std::vector& indices, + bool* axis_missing, bool* axis_duplicated, bool* same_order) { auto cop = op.as(); if (cop == nullptr) { return false; @@ -176,7 +180,7 @@ bool IsInjective(const te::Operation& op, const std::vector& index, bo std::vector index_to_var_idx; std::vector var_idx_ct(cop->axis.size(), 0); - for (const auto& expr : index) { + for (const auto& expr : indices) { if (!is_const_int(expr)) { bool found = false; for (size_t i = 0; i < cop->axis.size(); ++i) { @@ -214,7 +218,7 @@ bool IsInjective(const te::Operation& op, const std::vector& index, bo } // Gather all VarNodes in an expr -static void GatherVars(const PrimExpr& expr, std::unordered_set* vars) { +void GatherVars(const PrimExpr& expr, std::unordered_set* vars) { PostOrderVisit(expr, [&vars](const ObjectRef& node) { if (const VarNode* op = node.as()) { vars->insert(op); @@ -223,7 +227,7 @@ static void GatherVars(const PrimExpr& expr, std::unordered_set* } // Check whether an expr has expensive operations (e.g. exp) -static bool HasExpensiveOp(const PrimExpr& expr) { +bool HasExpensiveOp(const PrimExpr& expr) { bool found = false; PostOrderVisit(expr, [&found](const ObjectRef& node) { if (const CallNode* op = node.as()) { @@ -239,28 +243,28 @@ AccessAnalyzer::AccessAnalyzer(const Array& tensors) { auto node = make_object(); OperationMap has_branch; - // get all ops + // Get all ops in topological order node->ops_topo_order = TopoSortOps(tensors); arith::Analyzer analyzer; - // build read & write access map + // Build read & write access map for (const auto& op : node->ops_topo_order) { if (op->IsInstance()) { node->read_from[op] = OperationMap>>(); } else if (auto cop = op.as()) { - TensorAccessExtractor extractor; + ReadAccessExtractor extractor; for (const auto& exp : cop->body) { extractor.Extract(exp); } // read_by and read_from map - for (const auto& iter : extractor.buf_accesses) { + for (const auto& iter : extractor.read_access) { std::vector>& accesses = node->read_by[iter.first][op]; accesses.insert(accesses.begin(), iter.second.begin(), iter.second.end()); } - node->read_from[op] = std::move(extractor.buf_accesses); + node->read_from[op] = std::move(extractor.read_access); has_branch[op] = extractor.has_branch; // compute number of common outer iterators @@ -278,15 +282,15 @@ AccessAnalyzer::AccessAnalyzer(const Array& tensors) { break; } - bool direct_access = true; + bool injective = true; for (const auto& access : access_list) { if (!IsConstShiftEqual(cop->axis[n_common]->var, access[n_common])) { - direct_access = false; + injective = false; break; } } - if (!direct_access) { + if (!injective) { break; } } @@ -299,25 +303,25 @@ AccessAnalyzer::AccessAnalyzer(const Array& tensors) { } } - // do some static analysis + // Do some static analysis on ComputeOps for (const auto& op : node->ops_topo_order) { if (op->IsInstance()) { - node->is_injective[op] = true; + node->is_simple_access[op] = true; node->needs_multi_level_tiling[op] = false; node->is_strict_inlineable[op] = false; node->is_output[op] = false; - } else if (auto pop = op.as()) { + } else if (auto cop = op.as()) { // check whether this op is element-wise and strict-inlineable - bool is_injective = true; + bool is_simple_access = true; bool is_strict_inlineable = true; bool axis_missing, axis_duplicated, same_order; for (const auto& pair : node->read_from[op]) { - const std::vector>& access = pair.second; - for (const auto& index : access) { - if (!auto_scheduler::IsInjective(op, index, &axis_missing, &axis_duplicated, - &same_order)) { - is_injective = false; + const std::vector>& access_list = pair.second; + for (const auto& access : access_list) { + if (!auto_scheduler::IsSimpleAccess(op, access, &axis_missing, &axis_duplicated, + &same_order)) { + is_simple_access = false; is_strict_inlineable = false; break; } @@ -326,46 +330,44 @@ AccessAnalyzer::AccessAnalyzer(const Array& tensors) { is_strict_inlineable = false; } } - if (!is_injective) { + if (!is_simple_access) { break; } } - if (has_branch[op]) { - is_strict_inlineable = false; - } // don't strictly inline expensive op (e.g. exp) bool has_expensive_op = false; - for (const auto& expr : pop->body) { + for (const auto& expr : cop->body) { has_expensive_op |= HasExpensiveOp(expr); } + if (has_expensive_op || has_branch[op]) { + is_strict_inlineable = false; + } - node->is_injective[op] = is_injective; - node->is_strict_inlineable[op] = is_strict_inlineable && !has_expensive_op; + node->is_simple_access[op] = is_simple_access; + node->is_strict_inlineable[op] = is_strict_inlineable; // check whether the op needs multi-level tiling bool needs_multi_level_tiling = false; int n_missing = 0; for (const auto& pair : node->read_from[op]) { - const std::vector>& access = pair.second; + const std::vector>& access_list = pair.second; std::unordered_set vars; - for (const std::vector& indices : access) { - for (const PrimExpr& expr : indices) { + for (const std::vector& access : access_list) { + for (const PrimExpr& expr : access) { GatherVars(expr, &vars); } } - bool missing = false; - for (const auto& axis : pop->axis) { + + for (const auto& axis : cop->axis) { if (GetIntImm(axis->dom->extent) > 1 && vars.count(axis->var.get()) == 0) { - missing = true; + n_missing++; + break; } } - if (missing) { - n_missing++; - } - if (n_missing >= 2 || (n_missing >= 1 && !pop->reduce_axis.empty())) { + if (n_missing >= 2 || (n_missing >= 1 && !cop->reduce_axis.empty())) { needs_multi_level_tiling = true; break; } @@ -373,7 +375,7 @@ AccessAnalyzer::AccessAnalyzer(const Array& tensors) { node->needs_multi_level_tiling[op] = needs_multi_level_tiling; - // check whether is output + // check whether the op is output node->is_output[op] = node->read_by[op].empty(); } else { LOG(FATAL) << "Invalid op" << op; @@ -391,16 +393,15 @@ bool AccessAnalyzer::IsOutput(const te::Operation& op) const { return operator->()->is_output.at(op); } -bool AccessAnalyzer::IsInjective(const te::Operation& op) const { - return operator->()->is_injective.at(op); +bool AccessAnalyzer::IsSimpleAccess(const te::Operation& op) const { + return operator->()->is_simple_access.at(op); } bool AccessAnalyzer::IsStrictInlineable(const te::Operation& op) const { return operator->()->is_strict_inlineable.at(op); } -void AccessAnalyzer::GetConsumers(const State& state, const te::Operation& op, - OperationSet* consumers) const { +OperationSet AccessAnalyzer::GetConsumers(const State& state, const te::Operation& op) const { OperationSet inlined_ops; for (const auto& stage : state->stages) { if (stage->compute_at == ComputeAtKind::kInlined) { @@ -408,30 +409,31 @@ void AccessAnalyzer::GetConsumers(const State& state, const te::Operation& op, } } + OperationSet consumers; std::function collect; collect = [this, &collect, &inlined_ops, &consumers](const te::Operation& op) { for (const auto& iter : operator->()->read_by.at(op)) { if (inlined_ops.count(iter.first)) { collect(iter.first); } else { - consumers->insert(iter.first); + consumers.insert(iter.first); } } }; - consumers->clear(); collect(op); + return consumers; } -void AccessAnalyzer::GetDirectProducers(const te::Operation& op, OperationSet* producers) const { - producers->clear(); +OperationSet AccessAnalyzer::GetDirectProducers(const te::Operation& op) const { + OperationSet producers; for (const auto& iter : operator->()->read_from.at(op)) { - producers->insert(iter.first); + producers.insert(iter.first); } + return producers; } -void AccessAnalyzer::GetProducers(const State& state, const te::Operation& op, - OperationSet* producers) const { +OperationSet AccessAnalyzer::GetProducers(const State& state, const te::Operation& op) const { OperationSet inlined_ops; for (const auto& stage : state->stages) { if (stage->compute_at == ComputeAtKind::kInlined) { @@ -439,19 +441,20 @@ void AccessAnalyzer::GetProducers(const State& state, const te::Operation& op, } } + OperationSet producers; std::function collect; collect = [this, &collect, &inlined_ops, &producers](const te::Operation& op) { for (const auto& iter : operator->()->read_from.at(op)) { if (inlined_ops.count(iter.first)) { collect(iter.first); } else { - producers->insert(iter.first); + producers.insert(iter.first); } } }; - producers->clear(); collect(op); + return producers; } int AccessAnalyzer::GetNumCommonOuterIterator(const te::Operation& op, @@ -478,24 +481,6 @@ int AccessAnalyzer::GetNumCommonOuterIterator(const te::Operation& op, return meet ? ret : 0; } -// Return whether two int arrays are elementwise-equal -bool IntArrayEqual(const Array& arr1, const Array& arr2) { - if (arr1.size() != arr2.size()) { - return false; - } - - for (size_t i = 0; i < arr1.size(); ++i) { - auto int1 = arr1[i].as(); - auto int2 = arr2[i].as(); - CHECK(int1 != nullptr); - CHECK(int2 != nullptr); - if (int1->value != int2->value) { - return false; - } - } - return true; -} - bool AccessAnalyzer::ElementWiseMatch(const te::Operation& op, const te::Operation& target_op) const { te::Operation cur_op = op; @@ -508,7 +493,7 @@ bool AccessAnalyzer::ElementWiseMatch(const te::Operation& op, } te::Operation next_op = map.begin()->first; - // Check condition 1: has the same output size + // Check condition 1: They have the same output size auto p_cur = cur_op.as(); auto p_next = next_op.as(); if (p_cur == nullptr || p_next == nullptr) { @@ -527,13 +512,13 @@ bool AccessAnalyzer::ElementWiseMatch(const te::Operation& op, } } - // Check condition 2: read is elementwise + // Check condition 2: The read is elementwise const std::vector> reads = map.begin()->second; - bool is_injective, axis_missing, axis_duplicated, same_order; + bool is_simple_access, axis_missing, axis_duplicated, same_order; for (const auto& read : reads) { - is_injective = - auto_scheduler::IsInjective(next_op, read, &axis_missing, &axis_duplicated, &same_order); - if (!is_injective || axis_missing || axis_duplicated || !same_order) { + is_simple_access = auto_scheduler::IsSimpleAccess(next_op, read, &axis_missing, + &axis_duplicated, &same_order); + if (!is_simple_access || axis_missing || axis_duplicated || !same_order) { return false; } } diff --git a/src/auto_scheduler/utils.h b/src/auto_scheduler/utils.h index de800da13b64..da5032e11c97 100644 --- a/src/auto_scheduler/utils.h +++ b/src/auto_scheduler/utils.h @@ -128,6 +128,24 @@ inline std::vector IntArrayToVector( return out; } +/*! \brief Return whether two int arrays are elementwise-equal */ +inline bool IntArrayEqual(const Array& arr1, const Array& arr2) { + if (arr1.size() != arr2.size()) { + return false; + } + + for (size_t i = 0; i < arr1.size(); ++i) { + auto int1 = arr1[i].as(); + auto int2 = arr2[i].as(); + CHECK(int1 != nullptr); + CHECK(int2 != nullptr); + if (int1->value != int2->value) { + return false; + } + } + return true; +} + /********** Utilities for TVM Containers / ByteArray **********/ /*! \brief Compute mean of a FloatImm array */ inline double FloatArrayMean(const Array& float_array) { diff --git a/tests/cpp/auto_scheduler_test.cc b/tests/cpp/auto_scheduler_test.cc index f21fe1f5c57b..85266057548c 100644 --- a/tests/cpp/auto_scheduler_test.cc +++ b/tests/cpp/auto_scheduler_test.cc @@ -82,13 +82,13 @@ TEST(ComputeDAG, AccessAnalyzer) { } } - std::set is_injective = {data, padding, kernel, bias, bias_add, - bn_scale, bn_mul, bn_offset, bn_add, relu}; + std::set is_simple_access = {data, padding, kernel, bias, bias_add, + bn_scale, bn_mul, bn_offset, bn_add, relu}; for (size_t stage_id = 0; stage_id < dag->ops.size(); stage_id++) { - if (is_injective.count(stage_id)) { - CHECK(dag->access_analyzer.IsInjective(dag->ops[stage_id])); + if (is_simple_access.count(stage_id)) { + CHECK(dag->access_analyzer.IsSimpleAccess(dag->ops[stage_id])); } else { - CHECK(!dag->access_analyzer.IsInjective(dag->ops[stage_id])); + CHECK(!dag->access_analyzer.IsSimpleAccess(dag->ops[stage_id])); } } @@ -125,7 +125,7 @@ TEST(ComputeDAG, AccessAnalyzer) { {bias, bias_add}, {bias_add, bn_mul}, {bn_scale, bn_mul}, {bn_mul, bn_add}, {bn_offset, bn_add}, {bn_add, relu}}; for (const auto& pair : consumer_list) { - dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op, &op_set); + op_set = dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op); CHECK_EQ(op_set.size(), 1); CHECK_EQ((*op_set.begin()), s0->stages[pair.second]->op); } @@ -136,7 +136,7 @@ TEST(ComputeDAG, AccessAnalyzer) { {bn_add, {bn_mul, bn_offset}}, {relu, {bn_add}}}; for (const auto& pair : producer_list) { - dag->access_analyzer.GetProducers(s0, s0->stages[pair.first]->op, &op_set); + op_set = dag->access_analyzer.GetProducers(s0, s0->stages[pair.first]->op); CHECK_EQ(op_set.size(), pair.second.size()); for (const auto& target : pair.second) { CHECK(op_set.count(s0->stages[target]->op)); @@ -151,7 +151,7 @@ TEST(ComputeDAG, AccessAnalyzer) { { std::vector> consumer_list = {{data, conv}, {kernel, conv}, {conv, relu}}; for (const auto& pair : consumer_list) { - dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op, &op_set); + op_set = dag->access_analyzer.GetConsumers(s0, s0->stages[pair.first]->op); CHECK_EQ(op_set.size(), 1); CHECK_EQ((*op_set.begin()), s0->stages[pair.second]->op); } @@ -162,7 +162,7 @@ TEST(ComputeDAG, AccessAnalyzer) { {bn_add, {bn_mul, bn_offset}}, {relu, {bn_add}}}; for (const auto& pair : producer_list) { - dag->access_analyzer.GetDirectProducers(s0->stages[pair.first]->op, &op_set); + op_set = dag->access_analyzer.GetDirectProducers(s0->stages[pair.first]->op); CHECK_EQ(op_set.size(), pair.second.size()); for (const auto& target : pair.second) { CHECK(op_set.count(s0->stages[target]->op)); From e03268681838a8857457f1cced201be658f5fa3e Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 24 Jul 2020 17:07:46 -0700 Subject: [PATCH 8/8] fix lint --- include/tvm/auto_scheduler/compute_dag.h | 3 ++- python/tvm/autotvm/task/relay_integration.py | 1 - src/auto_scheduler/compute_dag.cc | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/include/tvm/auto_scheduler/compute_dag.h b/include/tvm/auto_scheduler/compute_dag.h index b9c1f9e8c45c..71652fd692fa 100644 --- a/include/tvm/auto_scheduler/compute_dag.h +++ b/include/tvm/auto_scheduler/compute_dag.h @@ -68,7 +68,8 @@ class AccessAnalyzerNode : public Object { * (e.g., injective, broadcast and elementwise ops without reduction) */ OperationMap is_simple_access; /*! \brief Store whether the operation is strictly-inlineable - * (e.g., injective, broadcast and elementwise without reduction, branch or expenive operations) */ + * (e.g., injective, broadcast and elementwise without reduction, branch or expenive operations) + */ OperationMap is_strict_inlineable; /*! \brief Store whether the operation needs multi-level tiling * (e.g., computation-intensive ops with data reuse opportunity like matmul, conv2d) */ diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index 70f32eb81a75..9a43f2f1ad95 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -26,7 +26,6 @@ import tvm from .task import create from .topi_integration import TaskExtractEnv -from .dispatcher import FallbackContext logger = logging.getLogger('autotvm') diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index a7e0923285c8..68d1bb42c493 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -37,8 +37,8 @@ #include #include -#include "utils.h" #include "../arith/pattern_match.h" +#include "utils.h" namespace tvm { namespace auto_scheduler { @@ -517,7 +517,7 @@ bool AccessAnalyzer::ElementWiseMatch(const te::Operation& op, bool is_simple_access, axis_missing, axis_duplicated, same_order; for (const auto& read : reads) { is_simple_access = auto_scheduler::IsSimpleAccess(next_op, read, &axis_missing, - &axis_duplicated, &same_order); + &axis_duplicated, &same_order); if (!is_simple_access || axis_missing || axis_duplicated || !same_order) { return false; }