Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
1a62c23
CombineContextCall
tqchen Jan 1, 2020
a33ee1c
Migrate BoundChecker
tqchen Jan 1, 2020
3bbc29b
Migrate CoprocSync
tqchen Jan 1, 2020
9dcaaa9
Migrate detect_device
tqchen Jan 1, 2020
335c2aa
Migrate loop_partition
tqchen Jan 1, 2020
5b58263
Migrate infer_fragement
tqchen Jan 1, 2020
6e853f3
Migrate inject_copy_intrin
tqchen Jan 1, 2020
5407c8d
Migrate inject double buffer
tqchen Jan 1, 2020
eefd692
Migrate lower_intrin and simplify
tqchen Jan 1, 2020
f3618af
Migrate storage flatten
tqchen Jan 1, 2020
a2be373
Migrate inject prefetch
tqchen Jan 1, 2020
5b8f962
Migrate inject_virtual_thread
tqchen Jan 2, 2020
bb68e4e
migrate inline
tqchen Jan 2, 2020
9dd6283
Migrate lift attr scope
tqchen Jan 2, 2020
b22c095
Migrate custom datatypes
tqchen Jan 2, 2020
2c0c7e8
migrate lower_thread_all_reduce
tqchen Jan 2, 2020
2776c44
Migrate lower_tvm_builtin
tqchen Jan 2, 2020
15be2bf
migrate lower_warp memory
tqchen Jan 2, 2020
47cccf6
Migrate make_api.cc
tqchen Jan 2, 2020
9400333
Migrate remap_thread_axis
tqchen Jan 2, 2020
5b870ba
Migrate remove_no_op
tqchen Jan 2, 2020
6fff863
migrate rewrite_unsafe_select
tqchen Jan 2, 2020
bf21e7e
Migrate skip_assert simple_passes
tqchen Jan 2, 2020
42bf9e9
Migrate split_host_device
tqchen Jan 2, 2020
fdd5a74
Migrate ssa
tqchen Jan 2, 2020
2bae517
Migrate storage_access
tqchen Jan 2, 2020
0397c76
Migrate storage_rewrite
tqchen Jan 2, 2020
d6b476b
Migrate tensor_core
tqchen Jan 2, 2020
6518850
Migrate unroll_loop
tqchen Jan 2, 2020
ef1d1e8
Migrate vectorize
tqchen Jan 2, 2020
4c2c348
Migrate verify compact_buffer gpu_code
tqchen Jan 2, 2020
ff0e27a
Migrate verify_memory
tqchen Jan 2, 2020
c38e155
Migrate storage_sync
tqchen Jan 2, 2020
9a1b932
Remove unused refs to mutator
tqchen Jan 2, 2020
0dbabeb
Migrate hybrid_op
tqchen Jan 2, 2020
d3b980a
Migrate tensorize
tqchen Jan 2, 2020
6145036
Migrate schedule ops
tqchen Jan 2, 2020
eb39e63
Migrate schedule_dataflow_rewrite
tqchen Jan 2, 2020
3b5b63b
Migrate auto_inline_elemwise
tqchen Jan 2, 2020
34888bb
Remove unecessary ref to visitor
tqchen Jan 2, 2020
0ab8bdc
remove unecessary ref
tqchen Jan 2, 2020
bb2256b
Migrate bound_deducer
tqchen Jan 2, 2020
9209cca
Migrate domain_touched
tqchen Jan 2, 2020
1645d55
Migrate autotvm feature touch extractor
tqchen Jan 2, 2020
b48e228
Add annotations
tqchen Jan 2, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions include/tvm/ir_functor_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,35 @@ class StmtExprMutator :
}
};

/*!
* \brief recursively visit the ir in post DFS order node, and transform it
*
* \param node The ir to be transformed.
* \param preorder The function called in before recursive mutation
* If preorder returns None, then the transform will proceed to recursive call.
* If preorder returns a not None Stmt/Expr, the transformer will simply return it and
* won't do further recursion.
* \param postorder The function called after recursive mutation.
* The recursive mutation result is passed to postorder for further mutation.
* \param only_enable List of StringImm.
* If it is empty, all IRNode will call preorder/postorder
* If it is not empty, preorder/postorder will only be called
* when the IRNode's type key is in the list.
*/
TVM_DLL Stmt IRTransform(Stmt node,
const runtime::PackedFunc& preorder,
const runtime::PackedFunc& postorder,
const Array<Expr>& only_enable = {});

/*!
* \brief recursively visit the ir in post DFS order node, apply fvisit
* Each node is guaranteed to be visited only once.
* \param node The ir to be visited.
* \param fvisit The visitor function to be applied.
*/
TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit);


} // namespace ir
} // namespace tvm
#endif // TVM_IR_FUNCTOR_EXT_H_
21 changes: 0 additions & 21 deletions include/tvm/ir_mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,27 +122,6 @@ class TVM_DLL IRMutator {
virtual Expr Mutate_(const StringImm* op, const Expr& e);
virtual Expr Mutate_(const Shuffle* op, const Expr& e);
};


/*!
* \brief recursively visit the ir in post DFS order node, and transform it
*
* \param node The ir to be transformed.
* \param preorder The function called in before recursive mutation
* If preorder returns None, then the transform will proceed to recursive call.
* If preorder returns a not None Stmt/Expr, the transformer will simply return it and
* won't do further recursion.
* \param postorder The function called after recursive mutation.
* The recursive mutation result is passed to postorder for further mutation.
* \param only_enable List of StringImm.
* If it is empty, all IRNode will call preorder/postorder
* If it is not empty, preorder/postorder will only be called
* when the IRNode's type key is in the list.
*/
Stmt IRTransform(Stmt node,
const runtime::PackedFunc& preorder,
const runtime::PackedFunc& postorder,
const Array<Expr>& only_enable = {});
} // namespace ir
} // namespace tvm
#endif // TVM_IR_MUTATOR_H_
9 changes: 0 additions & 9 deletions include/tvm/ir_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,6 @@ class TVM_DLL IRVisitor {
virtual void Visit_(const FloatImm* op);
virtual void Visit_(const StringImm* op);
};

/*!
* \brief recursively visit the ir in post DFS order node, apply fvisit
* Each node is guaranteed to be visited only once.
* \param node The ir to be visited.
* \param fvisit The visitor function to be applied.
*/
TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit);

} // namespace ir
} // namespace tvm

Expand Down
3 changes: 1 addition & 2 deletions src/api/api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
#include <tvm/ir.h>
#include <tvm/attrs.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/api_registry.h>

namespace tvm {
Expand Down
46 changes: 23 additions & 23 deletions src/arithmetic/bound_deducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
*/
#include <tvm/expr.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/arithmetic.h>
#include <tvm/api_registry.h>

Expand All @@ -38,17 +38,17 @@ using namespace ir;

// a visitor to find the path to the target variable
// from a expression.
class VariablePathFinder: public IRVisitor {
class VariablePathFinder: public ExprVisitor {
public:
explicit VariablePathFinder(Expr target) : target_(target) {}

void Visit(const ObjectRef& node) final {
void VisitExpr(const Expr& node) final {
if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get());

if (!found_) path_.push_back(node.get());
if (node.same_as(target_)) found_ = true;
IRVisitor::Visit(node);
ExprVisitor::VisitExpr(node);
if (!found_) path_.pop_back();
}

Expand All @@ -64,14 +64,14 @@ class VariablePathFinder: public IRVisitor {
// return empty vector to represent failure
std::vector<const Object*> GetPath(Expr target, Expr expr) {
VariablePathFinder v(target);
v.Visit(expr);
v(expr);
return v.path_;
}

enum CompareOp {kGreater, kLess, kEqual};

// a visitor to deduce the bound of a variable from a expression
class BoundDeducer: public IRVisitor {
class BoundDeducer: public ExprVisitor {
public:
friend class BoundDeduceInputChecker;
friend class Converter;
Expand All @@ -82,39 +82,39 @@ class BoundDeducer: public IRVisitor {

void Deduce();

void Visit(const ObjectRef& e) final {
void VisitExpr(const Expr& e) final {
if (!success_) return;
if (e.get() == path_[iter_++]) {
IRVisitor::Visit(e);
ExprVisitor::VisitExpr(e);
} else {
success_ = false;
return;
}
}

void Visit_(const LT* op) final {
void VisitExpr_(const LT* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}

void Visit_(const LE* op) final {
void VisitExpr_(const LE* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}

void Visit_(const GT* op) final {
void VisitExpr_(const GT* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}

void Visit_(const GE* op) final {
void VisitExpr_(const GE* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}

void Visit_(const Add* op) final {
void VisitExpr_(const Add* op) final {
bool left = op->a.get() == path_[iter_];
result_ -= left ? op->b : op->a;
Visit(left ? op->a : op->b);
this->VisitExpr(left ? op->a : op->b);
}

void Visit_(const Sub* op) final {
void VisitExpr_(const Sub* op) final {
bool left = op->a.get() == path_[iter_];
if (left) {
result_ += op->b;
Expand All @@ -123,10 +123,10 @@ class BoundDeducer: public IRVisitor {
result_ = - result_;
comp_op = ReverseOp(comp_op);
}
Visit(left ? op->a : op->b);
this->VisitExpr(left ? op->a : op->b);
}

void Visit_(const Mul* op) final {
void VisitExpr_(const Mul* op) final {
bool left = op->a.get() == path_[iter_];
Expr operand = left ? op->b : op->a;
Expr target_var = left ? op->a : op->b;
Expand Down Expand Up @@ -171,7 +171,7 @@ class BoundDeducer: public IRVisitor {
// ( x <= -3/-2 --> x <= 1)
}
}
Visit(left ? op->a : op->b);
this->VisitExpr(left ? op->a : op->b);
}

Expr result_;
Expand All @@ -194,17 +194,17 @@ class BoundDeducer: public IRVisitor {
Analyzer analyzer_;
};

class BoundDeduceInputChecker: public IRVisitor {
class BoundDeduceInputChecker: public ExprVisitor {
public:
bool Check(BoundDeducer* deducer) {
deducer_ = deducer;
Visit(deducer_->expr_);
this->VisitExpr(deducer_->expr_);
return target_count == 1;
}

void Visit(const ObjectRef& e) final {
void VisitExpr(const Expr& e) final {
if (e.same_as(deducer_->target_)) ++target_count;
IRVisitor::Visit(e);
ExprVisitor::VisitExpr(e);
}

private:
Expand Down Expand Up @@ -305,7 +305,7 @@ void BoundDeducer::Deduce() {
}
expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_);

Visit(expr_);
this->VisitExpr(expr_);
}

void BoundDeducer::Relax() {
Expand Down
Loading